详细了解Transformer:Attention Is All You Need

发布时间 2023-09-29 12:52:39作者: zh-jp

原文链接:Attention Is All You Need

1. 背景

在机器翻译任务下,RNN、LSTM、GRU等序列模型在NLP中取得了巨大的成功,但是这些模型的训练是通常沿着输入和输出序列的符号位置进行计算的顺序计算,无法并行。

文中提出了名为Transformer的模型架构,完全依赖注意力机制(Attention Mechanisms),构建输入与输出间的依赖关系,并且能够并行计算,使得模型训练速度大大提高,能够在较短的训练时间内达到新的SOTA水平。

2. 模型架构

2.1 编码器(Encoder)与解码器(Decoder)

先放下具体的细节,从上图Transformer的模型架构中可以发现,模型被分为左右两块。

对于左侧的Encoder,实际是有\(N=6\)层相同的Encoder堆叠而成。这N层中,每层都有两个子层分别为:多头注意力层(Muti-Head Attention)前馈网络层(Feed Forward),两个子层都带有残差连接(Add)归一化层(Norm)。因此每一个子层的输出表示为:\(LayerNorm(x+SubLayer(x))\)\(x\)表示输入,\(SubLayer(x)\)表示子层(多头注意力层或前馈网络层)的计算结果。为了实现残差连接,所有子层、嵌入层还有最后的模型输出结果的维度\(d_{model}\)统一设置为\(512\)

对于右侧的Decoder,同样有着\(N=6\)层相同的Decoder堆叠而成。从上至下看:

  • 第一层前馈网络层与Encoder部分一致;
  • 第二层多头注意力层,参与计算的数据由两部分(下文将介绍这两部分)来自Encoder,这两部分的数据是N层的Encoder的运算结果,经过N层计算后输入Decoder,Encoder的这一运算结果会在N层Decoder的每一层多头注意力层参与计算;
  • 第三层是带掩码的多头注意力层,它通过对偏移的输出Embedding添加掩码,确保位置为\(i\)的预测仅依赖于\(i\)之前的数据。

Embedding:就是将现实的信息如文本、音频、图像转化为向量供模型使用,就像它的本意:嵌入进神经网络一样。嵌入层通常为模型的第一层,负责将现实信息转为向量。

Inputs和Outputs:以机器翻译任务:将中文翻译为English为例,Inputs就是“我是一只猫”,Outputs为“I am a cat”即训练的Label,而模型的输出应为模型预测的“我是一只猫”的翻译。

Outputs(shifted right):在图中,Outputs向右偏移了一位,这是因为在Outputs的第一位插入了Token(词元:文本中的基本单位也可称为“分词”),这个Token为翻译开始符一般为写为<Begin>。在本例中,偏移后的Outputs为“<Begin> I am a cat”,这样做的目的是为了让模型预测的第一个词为“我”,第二个词为“是”,以此类推。

2.2 注意力(Attention)

观察图一,Encoder和Decoder的多头注意力层(包括Decoder中带掩码的多头注意力层)的输入,都是由三部分组成:查询(Query)键(Key)值(Value),简称QKV。所谓的“注意力”便是由这三部分经过“注意力函数”计算得到。

关于Q、K、V,不需要太在意它们的名称,它们来自于输入\(X\)和对应的权重矩阵\(W\)的乘积结果,权重矩阵由训练得到。
\(Q=XW^Q\)  \(K=XW^K\)  \(V=XW^V\)

2.2.1 缩放点积注意力(Scaled Dot-Product Attention)

文中,作者将Transformer中的注意力称为“缩放点积注意力”。输入的组成包括:\(d_k\)维的Q和K,以及\(d_v\)维的V。运算步骤如图2左侧所示:

  1. Q与K进行矩阵相乘(MatMul);
  2. 进行缩放(Scale),即除\(\sqrt{d_k}\)
  3. 添加掩码(Mask(opt.)),这一步针对Decoder中的Masked Multi-Head Attention层,文中添加掩码的方式是将需要添加掩码的部分设置为\(-\infin\)
  4. 通过Softmax函数计算得到权重;
  5. 将权重与V进行矩阵相乘。

之所以要进行缩放文中给出了解释:如果\(d_k\)过大,那么点积\(QK^T\)的值(magnitude)就会过大,导致Softmax函数过于陡峭(因为某一点过大)也就是方差过大。
假设\(q\)\(k\)是均值为0、方差为1的独立随机变量,当维度为\(d_k\)时,它们的点积为\(q\cdot k=\sum^{d_k}_{i=1}q_ik_i\),它们的均值为0,方差为\(d_k\)
这会导致梯度消失,为了避免这种情况需要通过第二步对点积结果进行缩放。

实践中,将多组Q打包为一个矩阵\(Q\),同理K和V也被打包为矩阵\(K\)\(V\)。结合上述计算步骤,矩阵运算的结果可以表示为:

\[Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k}})V \tag{1} \]

2.2.2 多头注意力(Multi-Head Attention)

文中,在使用Attention函数前,还对Q、K、V分别进行\(d_k\)\(d_k\)\(d_v\)维的线性变换(Linear),如图2右侧从下往上的第一层所示。同时有\(h\)组这样的Q、K、V并行计算,每组得到\(d_v\)维的输出,这也是称为“多头”的原因。将这些输出连接并再进行一次线性变换(Linear)得到最后的输出。

多头注意力使得模型联合关注来自不同位置的信息,弥补单个Attention函数的局限。

\[MultiHead(Q,K,V)=Concat(head_1,...,head_h)W^O \\ head_i=Attention(QW_i^Q,KW_i^K,VW_i^V) \]

这里的线性变换参数矩阵:\(W_i^Q\in \mathbb{R}^{d_{model}\times d_k},W_i^K\in \mathbb{R}^{d_{model}\times d_k},W_i^V\in \mathbb{R}^{d_{model}\times d_v}\)

文中,\(h=8\)层注意力函数并行计算,其中变量的维度设置为:\(d_k=d_v=d_{model}/h=64\),由于每个头的维度减少,总的计算成本与之前的单个注意函数相似。

2.3 以位置为单元的前馈神经网络(Position-wise Feed-Forward Networks)

在每一层Encoder和Decoder中,都有一层全连接前馈神经网络,它分别作用于每个位置,进行相同的处理。它包含两个线性变换,中间包含着一层ReLU激活。

\[FFF(x)=\max(0,xW_1+b_1)W_2+b_2\tag{2} \]

不同位置的线性变换方式相同,但层与层之间的参数不同。通常是两个核为1的卷积层,它们输出输出的维度\(d_{model}=512\),中间的隐藏层维度\(d_{ff}=2048\)

2.4 Embeddings与Softmax

与其他处理同一类问题的的模型相同,Transformer需要先用嵌入层(Embedding)将输入输出的词元转化为\(d_{model}=512\)维的向量。在Decoder中,使用线性变换和Softmax函数将模型的输出转化为预测的下一个词元的概率分布。两个嵌入层使用相同的权重矩阵,它们的权重还将乘以\(\sqrt{d_{model}}\)

2.5 位置编码(Positional Encoding)

Transformer不含卷积和循环,但是序列的绝对位置和相对位置信息不可或缺,因此需要将位置编码添加到Encoder和Decoder的Embedding层之后,位置编码的维度同样是\(d_{model}\)(毕竟同一维度才能相加)。位置编码的计算方式有许多,文中采用了不同频率的余弦、正弦函数计算。

\[PE_{(pos,2i)}=sin(pos/10000^{2i/d_{model}}) \\ PE_{(pos,2i+1)}=cos(pos/10000^{2i/d_{model}}) \]

其中,\(pos\)是位置,\(i\)表示维度(\(i\le d_{model}\))。每一个位置编码的维度都对应一个正弦曲线。波长范围为\([2\pi,20000\cdot\pi]\)的几何级数。使用这种编码方式的原因是:对于任何固定的偏移量\(k\)\(PE_{pos+k}\)可以表示为\(PE_{pos}\)的线性函数。因此,Transformer可以轻松地学习相对位置的偏移量。另外,这种位置编码方法在当模型遇到比训练时更长的序列也能有较好的表现。

3. 为什么是Attention

作者在使用Attention机制时,考虑了以下三个方面:

  1. 每一层的总计算复杂度;
  2. 能够用于并行计算的计算量;
  3. 网络中,向量依赖的路径长度。在机器翻译等序列任务中,这是一项关键。需要考虑当前位置的前向信号和后向信号在网络中的路径长度

第二点已经被多头注意力机制解决。对于第一点与第三点,文中提到在处理长度较大的序列任务时,可以限制注意力的范围,以序列中相应的输出范围为中心,设置大小为\(r\)的邻域,这时,循环层需要的\(O(n)\)的序列操作,在这里只需要\(O(n/r)\)

参考文献