【d2l】【常见函数】【11】 nn.GRU()

发布时间 2023-07-27 16:05:14作者: zz子木zz

门控循环神经网络的API

问题来源

【动手学深度学习】【9.7 序列到序列学习】

这个output和state的形状硬是没看懂

问题解决

参考:
1 B站弹幕

2 官方文档
https://pytorch.org/docs/stable/generated/torch.nn.GRU.html

做了张图,这两段代码大概做了这样的事

self.rnn = nn.GRU(embed_size, num_hiddens, num_layers,
                          dropout=dropout)
def forward(self, X, *args):
      # 输出'X'的形状:(batch_size,num_steps,embed_size)
      X = self.embedding(X)
      # 在循环神经网络模型中,第一个轴对应于时间步
      X = X.permute(1, 0, 2)
      # 如果未提及状态,则默认为0
      output, state = self.rnn(X)
      # output的形状:(num_steps,batch_size,num_hiddens)
      # state的形状:(num_layers,batch_size,num_hiddens)
        return output, state