[机器学习]对transformer使用padding mask

发布时间 2023-08-14 17:18:14作者: 溡沭

注:本文是对GPT4的回答的整理校正补充。

在处理序列数据时,由于不同的序列可能具有不同的长度,我们经常需要对较短的序列进行填充(padding)以使它们具有相同的长度。但是,在模型的计算过程中,这些填充值是没有实际意义的,因此我们需要一种方法来确保模型在其计算中忽略这些填充值。这就是padding mask的作用。

比如常用的就是在数据集准备中,想用batch来训练,就得将一个batch的数据的长度全部对齐。

1. 什么是Padding Mask?

Padding mask是一个与输入序列形状相同的二进制矩阵,用于指示哪些位置是真实的数据,哪些位置是填充值。

  • 真实数据位置的mask值为0。
  • 填充位置的mask值为1。

2. 如何使用Padding Mask?

在自注意力机制中,我们计算查询和键的点积来得到注意力分数。在应用softmax函数之前,我们可以使用padding mask来确保填充位置的注意力分数为一个非常大的负数(例如,乘以-1e9)。这样,当应用softmax函数时,这些位置的权重将接近于零,从而确保模型在其计算中忽略这些填充值。

3. 示例

假设我们有一个长度为4的序列:[A, B, C, <pad>],其中<pad>是填充标记。对应的padding mask是:[0, 0, 0, 1]

在计算注意力分数后,我们可以使用以下方法应用padding mask:

attention_scores = attention_scores.masked_fill(mask == 1, -1e9)

这里,masked_fill是一个PyTorch函数,它会将mask中值为1的位置替换为-1e9

看图,这里的attention_scores就是Q×K的矩阵,把尾部多余的部分变成-inf,再过SoftMax,这样就是0了。这样,即使V的后半部分有padding的部分,也会因为乘0而变回0。这样被padding掉的部分就从计算图上被剥离了,由此不会影响模型的训练。

4. 代码

笔者自己写的,不保证靠谱哈。

import torch.nn as nn

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, mask=None):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        
        # Apply the padding mask
        if mask is not None:
            attn = attn.masked_fill(mask.unsqueeze(1).unsqueeze(2) == 1, float('-inf'))
        
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

5. 为什么需要Padding Mask?

  • 忽略无关信息:通过使用padding mask,我们可以确保模型在其计算中忽略填充值,从而避免这些无关的信息对模型的输出产生影响。

  • 稳定性:如果不使用padding mask,填充值可能会对模型的输出产生不稳定的影响,尤其是在使用softmax函数时。

  • 解释性:使用padding mask可以提高模型的解释性,因为我们可以确保模型的输出只与真实的输入数据有关,而不是与填充值有关。

总之,padding mask是处理序列数据时的一个重要工具,它确保模型在其计算中忽略填充值,从而提高模型的性能和稳定性。