transformer中解码器的实现细节

发布时间 2023-07-26 18:42:20作者: 大师兄啊哈

1. 前言

17年google团队发表l了论文《Attention Is All You Need》,transformer横空出世,并引领了AI学术圈的研发风向,以Transformer为基础模型的新模型层出不穷,无论是NLP还是CV或者是多模态,attention遍地开花。

这篇文章遵循encoder-decoder架构,并在其中使用了self-attention和cross-attention,如下图所示:

transformer架构图

其中,encoder的行为还是非常好理解的,至于decoder,则相关细节在原文中都只草草提过,令人留下很多疑问,譬如,

decoder第一个attention为什么需要使用masked?

decoder在训练阶段和测试阶段有什么区别?

decoder在测试阶段,decoder的query输入是将目前所有的预测输入,还是只输入上一次decoder的输出?

2. 问题探讨

decoder第一个attention为什么需要使用masked?

Transformer模型属于自回归模型,也就是说后面的token的推断是基于前面的token的。Decoder端的Mask的功能是为了保证训练阶段和推理阶段的一致性。
在推理阶段,token是按照从左往右的顺序推理的。也就是说,在推理timestep=T的token时,decoder只能“看到”timestep < T的 T-1 个Token, 不能和timestep大于它自身的token做attention(因为根本还不知道后面的token是什么)。为了保证训练时和推理时的一致性,所以,训练时要同样防止token与它之后的token去做attention。

 

decoder在训练阶段和测试阶段有什么区别?

在训练阶段,预测序列是直接全部喂到decoder的输入的,只是在算self-attention的时候加了一个mask,前面时间步的不能看到后面时间步的词,decoder的预测也是一次就全部出来了,也就是Teacher Forcing机制,如下图所示,在训练的时候,需要预测一段语音,decoder端的input,就直接把gt喂进去了,当然加进去前还需要shift right,在序列最左边增加一个Begin的特殊字符(为了和预测阶段保持一致),然后这些gt作为query,进行进入第一层mask multi-head attention层(根据时间步增加mask,以免在self-attention阶段前面的词可以看到后面的),然后以这层的输出为query,来自encoder的输出为key-value pair输入第二个子层multi-head attention,输出作为下层的输入,继续前面的过程,重复N次。

下载 (1)

如果是测试阶段,则就不一样,首先decoder会先输入Begin,预测出下一个词,然后再以已经预测的词作为输入,再进入decoder预测下一个词,直到遇到预测出的词是表示结束的特殊次元,才结束这个过程,参考以下视频:

https://www.zhihu.com/zvideo/1330559583777939456

 

decoder在测试阶段,decoder的query输入是将目前所有的预测输入,还是只输入上一次decoder的输出?

两种实现都有,具体来说,分别是:

a. 每次都将当前预测全部输入,在self-attention和cross-attention中均进行全量计算,优点是实现简单,缺点是计算量大,如下面的代码实现:

class DecoderLayer(nn.Module):
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, src_mask, tgt_mask):
        m = memory
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
        x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
        return self.sublayer[2](x, self.feed_forward)

可以看到每次的做self-attention的时候query,key,value都是目前所有的词(query 做了mask操作)。

完全版可以查看:https://zhuanlan.zhihu.com/p/398039366

b. 还有另外一种实现就是增量进行计算,李沐在《动手学深度学习》中就用了这种实现,优点是每次只需要计算一个query,但是因为在self-attention中需要与其他的词进行attention操作,因此需要在每层中保存之前的词作为key和value,如下面代码所示:

class DecoderBlock(nn.Module):
    """解码器中第i个块"""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                 dropout, i, **kwargs):
        super(DecoderBlock, self).__init__(**kwargs)
        self.i = i
        self.attention1 = d2l.MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.attention2 = d2l.MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm2 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens,
                                   num_hiddens)
        self.addnorm3 = AddNorm(norm_shape, dropout)

    def forward(self, X, state):
        enc_outputs, enc_valid_lens = state[0], state[1]
        # 训练阶段,输出序列的所有词元都在同一时间处理,
        # 因此state[2][self.i]初始化为None。
        # 预测阶段,输出序列是通过词元一个接着一个解码的,
        # 因此state[2][self.i]包含着直到当前时间步第i个块解码的输出表示
        if state[2][self.i] is None:
            key_values = X
        else:
            key_values = torch.cat((state[2][self.i], X), axis=1)
        state[2][self.i] = key_values
        if self.training:
            batch_size, num_steps, _ = X.shape
            # dec_valid_lens的开头:(batch_size,num_steps),
            # 其中每一行是[1,2,...,num_steps]
            dec_valid_lens = torch.arange(
                1, num_steps + 1, device=X.device).repeat(batch_size, 1)
        else:
            dec_valid_lens = None

        # 自注意力
        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, X2)
        # 编码器-解码器注意力。
        # enc_outputs的开头:(batch_size,num_steps,num_hiddens)
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, Y2)
        return self.addnorm3(Z, self.ffn(Z)), state

其中state[2][self.i]就存储了目前为止所有预测到的词。

完整版可以查看:https://zh-v2.d2l.ai/chapter_attention-mechanisms/transformer.html

3. 参考

[1] Transformer源码详解(Pytorch版本)

[2] 10.7. Transformer

(完)