RNN(循环神经网络)

发布时间 2023-04-05 22:48:48作者: longRookie

1.递归神经网络的历史版本

递归神经网络有两种类型:Jordan network和Elman network;现在常用的RNN(包括LSTM、GRU等)都是使用Elman network。

Elman network是在Jordan network的基础上进行了创新,并且简化了它的结构。

它们之间的区别:

  • Jordan network是将网络的输出层反馈回网络的输入层;
  • Elman network的每一个循环层都是互相独立的,因此网络结构的设计可以更加灵活;当Jordan network的输出层与循环层的维度不一致时还需要额外的调整;

因此当前的主流的循环神经网络都是基于Elman network。

2.RNN

2.1 RNN原理

循环神经网络(recurrent neural networks,RNNs)是具有隐状态的神经网络。

循环神经网络原理图:

设在时间步t有小批量输入\(\mathbf{X}_{t} \in \mathbb{R}^{n \times d}\),其中批量大小为n,输入维度为d。对于n个序列样本的小批量,\(\mathbf{X}_{t}\)的每一行对应于来自该序列的时间步\(t\)处的一个样本。

\(\mathbf{H}_{t} \in \mathbb{R}^{n \times h}\)表示时间步t的隐藏变量。与多层感知机不同,这里保存了前一个时间步的隐藏变量\(\mathbf{H}_{t-1}\),并引入了一个新的权重参数\(\mathbf{W}_{h h} \in \mathbb{R}^{h \times h}\),用于描述如何在当前时间步中使用前一个时间步的隐藏变量。

当前时间步隐藏变量的计算公式:

\[\mathbf{H}_{t}=\phi\left(\mathbf{X}_{t} \mathbf{W}_{x h}+\mathbf{H}_{t-1} \mathbf{W}_{h h}+\mathbf{b}_{h}\right) \]

对于时间步t,输入层的输出为:

\[\mathbf{O}_{t}=\mathbf{H}_{t} \mathbf{W}_{h q}+\mathbf{b}_{q} \]

循环神经网络的参数包括隐藏层的权重\(\mathbf{W}_{x h} \in \mathbb{R}^{d \times h}\)\(\mathbf{W}_{h h} \in \mathbb{R}^{h \times h}\)和偏置\(\mathbf{b}_{h} \in \mathbb{R}^{1 \times h}\),以及输出层的权重\(\mathbf{W}_{h q} \in \mathbb{R}^{h \times q}\)和偏置\(\mathbf{b}_{q} \in \mathbb{R}^{1 \times q}\)。即使在不同的时间步,循环神经网络参数共享。因此,循环神经网络的参数开销不会随着时间步的增加而增加。

2.2 pytorch实现RNN

torch.nn.RNN(*args,**kwargs)

隐藏层输出公式为:

\[h_{t}=\tanh \left(x_{t} W_{i h}^{T}+b_{i h}+h_{t-1} W_{h h}^{T}+b_{h h}\right) \]

参数:

  • input_size – The number of expected features in the input x
  • hidden_size – The number of features in the hidden state h
  • num_layers – Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two RNNs together to form a stacked RNN, with the second RNN taking in outputs of the first RNN and computing the final results. Default: 1
  • nonlinearity – The non-linearity to use. Can be either 'tanh' or 'relu'. Default: 'tanh'
  • bias – If False, then the layer does not use bias weights b_ih and b_hh. Default: True

参考链接:
[1] https://en.wikipedia.org/wiki/Recurrent_neural_network