Pytorch-Vanilla Transformer的实现

发布时间 2023-05-05 13:22:58作者: 信海

Vanilla Transformer

注意力提示

我们可以将是否包含自主性提示作为将注意力机制与全连接层或汇聚层区别的标准。

​ 定义外部输入至感官的信息为键-值,键是表征值的非自主提示,关注信息为查询(自主性提示)

  • 非自主提示:决策选择偏向于感官输入值,可使用参数化的全连接层或非参数化的最大汇聚层或平均汇聚层来提取信息。
  • 自主提示:通过特定的注意力汇聚方式, 将给定的查询(自主性提示)与键(非自主性提示)进行匹配, 引导决策得出最匹配的值(感官输入)。

注意力机制是指某一时刻将注意力集中到某件事,而忽略周围其他的一切事物;可以描述为将查询和一组键值对映射到输出,查询、键、值和输出均为向量,输出被计算为值的加权和,其中分配给每个值的权重由查询与相应关键字的兼容性函数计算。

注意力分数是Q和K的相似度,注意力权重是分数的softmax结果。

../_images/attention-output.svg

​ 假设有一个查询$ q∈R^q$和 \(m\)个“键-值”对
\((k_1,v_1),…,(k_m,v_m)\), 其中\(k_i∈R^k,v_i∈R^v\)。 注意力汇聚函数\(f\)就被表示成值的加权和:\(f(\mathbf{q},(\mathbf{k}_1,\mathbf{v}_1),\ldots,(\mathbf{k}_m,\mathbf{v}_m))=\sum\limits_{i=1}^m\omega(\mathbf{q},\mathbf{k}_i)\mathbf{v}_i\in\mathbb{R}^o\)

​ 其中查询\(q\)和键\(k_i\)的注意力权重(标量) 是通过注意力评分函数\(a\)将两个向量映射成标量, 再经过\(softmax\)运算得到的:\(\alpha(\mathbf{q},\mathbf{k}_i)=\mathrm{softmax}(a(\mathbf{q},\mathbf{k}_i))=\dfrac{\exp(a(\mathbf{q},\mathbf{k}_i))}{\sum_{j=1}^{m}\exp(a(\mathbf{q},\mathbf{k}_j))}\in\mathbb{R}\)

注意力类型

  • 将注意力汇聚的输出计算可以作为值的加权平均,选择不同的注意力评分函数会带来不同的注意力汇聚操作。注意力评分函数包括加性注意力、点积注意力、缩放点积注意力和双线性注意力。

  • 当查询和键是不同长度的矢量时,可以使用可加性注意力评分函数。当它们的长度相同时,使用缩放的“点-积”注意力评分函数的计算效率更高。

加法注意力:通过使用具有单个隐藏层的前向网络来计算兼容性函数(compatibility function)。虽然两者在理论复杂度上类似,但是点积注意力在实际中快很多,空间效率高,因为可以使用高度优化的矩阵乘法实现。

\(W_k\in \mathbb{R}^{h\times k},W_q\in \mathbb{R}^{h\times q},v\in\mathbb{R}^h,a(k,q)=v^T tanh(W_k k+W_qq)\)

点积注意力:输入维度较高时,模式存在较大方差,softmax函数梯度较小,\(s(h,q)=h^Tq\).

缩放点积注意力:计算查询集中元素与所有键的点积,将每个元素除以\(\sqrt{d_{k}}\),并应用softmax获得权值。实际上,会同时计算一组查询上的注意力函数,这些查询会一起打包成矩阵\(Q\),键和值也一起打包成矩阵\(K\)\(V\)通过平方根\(\sqrt{d_k}\)来平滑分数数值,如果不约束的话,当d很大时内积会很大导致softmax分布陡峭,不利于梯度反向传播;通过引入平方根能够解耦softmax(A)的分布和d的关系

双线性注意力:\(s(h_i,q)=h^TWq=h^T(U^TV)q=(Uh)^T(Vq)\)分别对查询向量\(q\)和原始输入向量\(h\)进行线性变换后计算点积模型;相比点积模型,在计算相似度时引入了非对称性。

键值对注意力机制:输入为键值对形式\((K,V)=[(k_1,v_1),(k_2,v_2),....,(k_n.v_n)]\),注意力权值为\(\alpha_i\)\(\alpha_i=softmax(s(k_i,q))=\frac{exp(s(k_i,q))}{\sum_{j=1}^nexp(s(k_j,q))}\)

自注意力机制/内部注意力:(intra-attention)查询、键和值均来自同一组输入,通过某种运算来直接计算得到句子在编码过程中每个位置上的注意力权重;然后再以权重和的形式来计算得到整个句子的隐含向量表示。但是,当模型在对当前位置的信息进行编码时,会过度将注意力集中于自身的位置。针对查询向量q,可以通过输入信息本身生成,不再选择前一时刻的查询向量,模型根据输入信息自己决定输入信息的那部分信息为重要信息。

​ 事实上,当给定相同的查询、键和值的集合时, 我们希望模型可以基于相同的注意力机制学习到不同的行为, 然后将不同的行为作为知识组合起来, 捕获序列内各种范围的依赖关系 (例如,短距离依赖和长距离依赖关系)。允许注意力机制组合使用查询、键和值的不同 子空间表示(representation subspaces)可能是有益的

多头注意力:可以用独立学习得到的h组不同的线性投影(linear projections)来变换查询、键和值。 然后,这h组变换后的查询、键和值将并行地送到注意力汇聚中。最后,将这h个注意力汇聚的输出拼接在一起, 并且通过另一个可以学习的线性投影进行变换, 以产生最终输出。 对于h个注意力汇聚输出,每一个注意力汇聚都被称作一个头(head)。好处是可以捕捉到不同角度的信息,最后将不同子空间的信息合并起来,对比普通的注意力机制,能提取到更全面、更丰富的特征。

位置编码

​ 为什么需要引入位置编码?在传统的RNN或CNN处理序列任务中,模型具有位置信息的提取能力,但是注意力机制由于采用了并行运算而放弃了顺序操作。为了使用序列的顺序信息,通过在输入表示中加入位置编码来注入绝对的或相对的位置信息。

​ 在Vanilla Transfor中位置编码采用了正余弦的形式表征,假设输入表示\(X\in R^{n\times d}\)包含1个序列中\(n\)个词元的\(d\)维嵌入表示,位置编码采用相同形状的位置嵌入矩阵\(P\in R^{n\times d}\)输出\(X+P\),矩阵第\(i\)行、第\(2j\)和第\(2j+1\)列上的元素为\(p_{i,2j}=sin(\frac{i}{1000^{\frac{2j}{d}}}),p_{i,2j+1}=cos(\frac{i}{1000^{\frac{2j}{d}}})\).

​ 在上面表示中,行代表了词元在序列中的位置,列代表了编码的不同维度。

将1个数表征为二进制形式,每个数字、每两个数字和每四个数字的比特值在第一个最低位、第二个最低位和第三个最低位上分别交替。较高比特位的交替频率要低于较低比特位。如下图所示,横轴为编码维,对同一行而言,编码维度越高,交替频率越低。

1683189074119

完整模型实现

​ Transformer的编码器和解码器是基于自注意力模块堆叠而成,源输入序列和目标输出序列的嵌入(embedding)表示将加上位置编码(position encoding),再分别输入到编码器和解码器中。

​ Transformer的编码器是由多个相同的层叠加而成的,编码器中任何层不改变输入的形状,每个层都有两个子层(子层表示为sublayer)。第一个子层是多头自注意力(multi-head self-attention)汇聚;第二个子层是基于位置的前馈网络(positionwise feed-forward network)。在计算编码器的自注意力时,查询、键和值都来自前一个编码器层的输出。每个子层采用了残差连接,对于序列中任何位置的任何输入\(x\in R^d\),要求满足\(sublayer(x) \in R^d\),以便残差连接满足\(x+sublayer(x)\in R^d\),在残差连接的加法计算之后,紧接着应用层规一化。

​ Transformer解码器也是由多个相同的层叠加而成的,并且层中使用了残差连接和层规范化。层规范化和批量规范化的目标相同,但层规范化是基于特征维度进行规范化。批量规范层在CV中使用较多,但是在NLP任务中由于输入是变长序列,批量规范化效果通常不如层规范化效果,这是因为批量规范化可能会破坏词向量的内在联系。

  • 批量规范化是以学习时的小批量样本为单位,使数据分布实现均值为0、方差为1的正规化,即\(\mu_b←\frac{1}{m}\sum_{i=1}{m}x_i;\sigma_B^2←\frac{1}{m}\sum_{i=1}^{m}(x_i-\mu_B)^2;\hat{x}_i←\frac{x_i-\mu_B}{\sqrt{\sigma_B^2+\epsilon}}\)

  • 层规范化和批量规范化的计算公式类似。

  • CV中使用的LayerNorm:对所有channel的所有像素进行计算,计算一个batch中所有channel中所有参数的均值和方差,然后进行归一化,即对CxHxW维度上的元素进行归一化(如下图蓝色区域部分所示,蓝色区域部分元素使用相同的meanvar进行归一化操作)

  • NLP中使用的LayerNorm计算一个batch中所有channel中的每一个参数的均值和方差进行归一化,即只在C维度上进行归一化计算(与CV中在CxHxW维度上计算不同),这里是N×L×C,只在C维度上进行计算。

  • 后面有空补一下这个 https://blog.csdn.net/qq_23981335/article/details/106572171

1683198704620

​ 除了编码器中描述的两个子层之外,解码器还在这两个子层之间插入了第三个子层,称为编码器-解码器注意力层。在编码器-解码器注意力中,查询来自前一个解码器层的输出,而键和值来自整个编码器的输出。在解码器自注意力中,查询、键和值都来自上一个解码器层的输出。但是,解码器中的每个位置只能考虑该位置之前的所有位置。这种掩蔽(masked)注意力保留了自回归属性,确保预测仅依赖于已生成的输出词元。

../_images/transformer.svg

李沐版本实现

注意力汇聚

非参数的Nadaraya-Watson核回归具有一致性(consistency)的优点: 如果有足够的数据,此模型会收敛到最优结果。

\(\begin{aligned}f(x)&=\sum_{i=1}^n\alpha(x_ix_i)y_i\\ &=\sum_{i=1}^n\frac{\exp\left(-\frac{1}{2}(x-x_i)^2\right)}{\sum_{j=1}^n\exp\left(-\frac{1}{2}(x-x_j)^2\right)}y_i\\ &=\sum_{i=1}^n\text{soft}\max\left(-\frac{1}{2}\left(x-x_i\right)^2\right)y_i.\end{aligned}\)

在下面的查询\(x\)和键\(x_i\)之间的距离乘以可学习参数\(w\),可以建立有参数的Nadaraya-Watson核回归模型。

\(\begin{aligned}f(x)&=\sum_{i=1}^n\omega(x,x_i)y_i\\ &=\sum_{i=1}^n\frac{\exp\left(-\frac{1}{2}\left((x-x_i)w\right)^2\right)}{\sum_{j=1}^n\exp\left(-\frac{1}{2}\left((x-x_j)w\right)^2\right)}y_i\\ &=\sum_{i=1}^n\text{softmax}\left(-\frac{1}{2}\left((x-x_i)\omega\right)^2\right)y_i.\end{aligned}\)

代码实现如下:

torch.repeat_interleave(input, repeats, dim=None) → Tensor
# FUNC: 沿着指定的维度重复张量的元素

'''
1)input (类型:torch.Tensor):输入张量
2)repeats(类型:int或torch.Tensor):每个元素的重复次数
3)dim(类型:int)需要重复的维度。默认情况下dim=None,表示将把给定的输入张量展平(flatten)为向量,然后将每个元素重复repeats次,并返回重复后的张量。
'''
import torch
from torch import nn
from d2l import torch as d2l
from matplotlib import pyplot as plt
from  torch.nn import functional as F
from utils import show_heatmaps
import numpy as np


def f(x):
    return 2 * torch.sin(x) + x ** 0.8

n_train = 50  # train samples
x_train, _ = torch.sort(torch.rand(n_train) * 5)  # sorted samples
y_train = f(x_train) + torch.normal(mean=0.0, std=0.5, size=(n_train,))  # output of train samples

x_test = torch.arange(0, 5, 0.1)  # ordered test samples
y_test = f(x_test)  # output of test samples
n_test = len(x_test)  # number of test samples
print("The number of test samples is ", n_test)

def plot_kernel_reg(y_hat):
    d2l.plot(x_test, [y_test, y_hat], 'x', 'y', legend=['Truth', 'Pred'],
             xlim=[0, 5], ylim=[-1, 5])
    d2l.plt.plot(x_train, y_train, 'o', alpha=0.5)
    plt.show()

'''
# average pooling
y_hat = torch.repeat_interleave(y_train.mean(), n_test)
plot_kernel_reg(y_hat)

# non-parametric pooling using gaussian kernel
X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))  # n_test * n_train
attention_weights = F.softmax(-(X_repeat - x_train)**2 / 2, dim=1)  # n_test * 1
y_hat = torch.matmul(attention_weights, y_train)
plot_kernel_reg(y_hat)

show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),
              xlabel='Sorted training inputs', ylabel='Sorted testing inputs')
'''

# parametric pooling using gaussian kernel
# batch matrix manipulation e.g.
X0 = torch.ones((2, 1, 4))
Y0 = torch.ones((2, 4, 6))
print("X0 * Y0: ", torch.bmm(X0, Y0).shape)

# 在注意力机制的背景中,我们可以使用小批量矩阵乘法来计算小批量数据中的加权平均值。
weights = torch.ones((2, 10)) * 0.1
values = torch.arange(20.0).reshape((2, 10))
print(torch.bmm(weights.unsqueeze(1), values.unsqueeze(-1)).shape)  # 2 * 1 * 1


class MWKernelRegression(nn.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.w = nn.Parameter(torch.rand((1, )), requires_grad=True)

    def forward(self, queries, keys, values):
        # queries和attention_weights的形状为(查询个数,“键-值”对个数)
        queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))
        self.attention_weights = F.softmax(-((queries - keys) * self.w) ** 2 / 2, dim=1)

        return torch.bmm(self.attention_weights.unsqueeze(1), values.unsqueeze(-1)).reshape(-1)

# 将训练数据集变换为键和值用于训练注意力模型
X_tile = x_train.repeat((n_train, 1))  # (n_train, n_train)
Y_tile = y_train.repeat((n_train, 1))  # (n_train, n_train)
keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))  # num_train * (num_train-1)
values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))  # num_train * (num_train-1)

net = MWKernelRegression()
loss = nn.MSELoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=0.5)
animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])

max_iter = 5
for epoch in range(max_iter):
    trainer.zero_grad()
    l = loss(net(x_train, keys, values), y_train)
    l.sum().backward()
    trainer.step()
    print(f'epoch {epoch+1}, loss {float(l.sum()):.6f}')
    animator.add(epoch + 1, float(l.sum()))

keys = x_train.repeat((n_test, 1))
values = y_train.repeat((n_test, 1))
y_hat = net(x_test, keys, values).unsqueeze(1).detach()
plot_kernel_reg(y_hat)

show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),
              xlabel='Sorted training inputs', ylabel='Sorted testing inputs')

1、增加训练数据的样本数量,能否得到更好的非参数的Nadaraya-Watson​核回归模型?

​ 不能。仅仅增加数据量的话,变化的仅仅是权重矩阵的规模,而权重矩阵仅仅是由训练数据和测试数据的差值经过一层线性变换得到,其表达能力不够(欠拟合)。

2、在带参数的注意力汇聚的实验中学习得到的参数w的价值是什么?为什么在可视化注意力权重时,它会使加权区域更加尖锐?

w相当于惩罚函数。当键和查询的差异较大时,通过w加权能够使该项值趋近于0,达到过滤键和值差异大的pair,保留键和值差异小的pair,从而达到注意力效果。

不同注意力形式

这里需要引入掩码softmax​,目的是为了滤除不关心部分的影响,通过掩蔽实现不关心部分的权重输出为0。

def masked_softmax(X, valid_lens):
    # 在最后一个轴上掩蔽元素来执行softmax操作

    if valid_lens is None:
        return F.softmax(X, dim=-1)
    else:
        shape = X.shape

        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1) # 保证valid_lens为一维输出向量
        # 最后一轴上被掩蔽的元素使用一个非常大的负值替换,从而其softmax输出为0
        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6) # valid_lens元素值为V,表示不掩蔽前V个元素,从第V+1个元素开始掩蔽。 
        return F.softmax(X.reshape(shape), dim=-1)
# 查询、键和值的形状为(批量大小,步数或词元序列长度,特征大小)
# 注意力汇聚输出形状为(批量大小,查询的步数,值的维度)

# 加法注意力
class AdditiveAttention(nn.Module):
    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
        super(AdditiveAttention, self).__init__(**kwargs)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        queries, keys = self.W_q(queries), self.W_k(keys)
        # 在维度扩展后,
        # queries的形状:(batch_size,查询的个数,1,num_hidden)
        # key的形状:(batch_size,1,“键-值”对的个数,num_hiddens)
        # 使用广播方式进行求和
        features = queries.unsqueeze(2) + keys.unsqueeze(1)
        features = torch.tanh(features)
        # self.w_v仅有一个输出,因此从形状中移除最后那个维度。
        # scores的形状:(batch_size,查询的个数,“键-值”对的个数)
        scores = self.w_v(features).squeeze(-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # values的形状:(batch_size,“键-值”对的个数,值的维度)
        return torch.bmm(self.dropout(self.attention_weights), values)
    
# To test addictive attention
queries, keys = torch.normal(0, 1, (2, 1, 20)), torch.ones((2, 10, 2))
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
valid_lens = torch.tensor([2, 6])  # masked position
attention = AddictiveAttention(key_size=2, query_size=20, num_hiddens=8,
                              dropout=0.1)
attention.eval()
print(attention(queries, keys, values, valid_lens))

# 缩放点积注意力
class DotProductAttention(nn.Module):
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__()
        self.dropout = nn.Dropout(dropout)

    # queries : batch * num_q * d
    # keys    : batch * num_k * d
    # values  : batch * num_k * v
    # valid_lens: (batch_size, ) / (batch_size, num_q)
    def forward(self, queires, keys, values, valid_lens=None):
        d = queires.shape[-1]

        scores = torch.bmm(queires, keys.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)
    
# To test dotproduct attention
queries, keys = torch.normal(0, 1, (2, 1, 2)), torch.ones((2, 10, 2))
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
valid_lens = torch.tensor([2, 6])  # masked position
attention = DotProductAttention(dropout=0.1)
attention.eval()
print(attention(queries, keys, values, valid_lens))
show_heatmaps(attention.attention_weights.reshape((2, 1, 1, 10)),
                  xlabel='Keys', ylabel='Queries')

# 双线性注意力
class DoubleLinearAttention(nn.Module):
    def __init__(self, n_hiddens, dropout, bias=False):
        super(DoubleLinearAttention, self).__init__()

        self.W_h = nn.Linear(n_hiddens, n_hiddens, bias=bias)
        self.dropout = nn.Dropout(dropout)

    # queries : batch * num_q * d
    # keys    : batch * num_k * d
    # values  : batch * num_k * v
    def forward(self, queries, keys, values, valid_lens):
        scores = torch.bmm(queries, self.W_h(keys).transpose(-1, -2))
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)
    
# To test doublelinear attention
# queries, keys = torch.normal(0, 1, (2, 1, 2)), torch.ones((2, 10, 2))
queries = keys = values = torch.randn((3, 10, 100))
valid_lens = torch.tensor([2, 6, 3])  # masked position
attention = DoubleLinearAttention(n_hiddens=100, dropout=0.1)
attention.eval()
print(attention(queries, keys, values, valid_lens))
show_heatmaps(attention.attention_weights.reshape((3, 1, 10, 10)),
              xlabel='Keys', ylabel='Queries')
# 多头注意力的实现
# 将所有注意力头里面的参数拼起来,变成了一个大的全连接层
# 这种写法和常规的写法不太一样

def transpose_qkv(X, num_heads):
    # input : batch * num_q/num_k * num_hidden
    # output: batch * num_q/num_k * num_heads * (num_hiddens/num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)

    # output: batch * num_heads * num_q/num_k * (num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)
    # output: (batch*num_heads) * num_q/num_k * (num_hiddens/num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])

def transpose_out(X, num_heads): 
    # reverse the output of QKV
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return  X.reshape(X.shape[0], X.shape[1], -1)  # batch * num_k/num_q * num_hidden

class MultiHeadAttention(nn.Module):
    def __init__(self, key_size, query_size, value_size, n_hiddens, num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)

        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, n_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, n_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, n_hiddens, bias=bias)
        self.W_o = nn.Linear(n_hiddens, n_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)

        output = self.attention(queries, keys, values, valid_lens) # 缩放点积注意力
        output = transpose_out(output, self.num_heads)  # batch * num_q/num_k * num_hidden

        return self.W_o(output)
    
# To test multi-head attention
num_hiddens, num_heads = 100, 5 # 这里实际上每个头的数目为20 只不过为了并行化处理 将多头的参量进行了合并,也就是变成了20*5=100
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, 0.5)
attention.eval()
print(attention)

batch_size, num_queries = 2, 4
num_kvpairs, valid_lens =  6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))  # query set
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))  # key-value set
print(attention(X, Y, Y, valid_lens).shape) # batch * num_q * num_hidden

# 多头注意力可视化
multi_attention_set = attention.attention.attention_weights.reshape((batch_size, num_heads, num_queries, num_kvpairs))
show_heatmaps(multi_attention_set.reshape((batch_size, num_heads, num_queries, num_kvpairs)), xlabel='Keys', ylabel='Queries')

# 自注意力和多头注意力的区别在于QKV均采用的是同一输入
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, 0.5)
attention.eval()
batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
print(attention(X, X, X, valid_lens).shape)
multi_attention_set = attention.attention.attention_weights.reshape((batch_size, num_heads, num_queries, num_queries))
show_heatmaps(multi_attention_set.reshape((batch_size, num_heads, num_queries, num_queries)), xlabel='Keys', ylabel='Queries')

模型实现
基于位置的前馈网络

​ 基于位置的前馈网络对序列中的所有位置的表示进行变换时使用的是同一个多层感知机(MLP),这就是称前馈网络是基于位置的(position-wise)的原因。

class PositionWiseFFN(nn.Module):
    def __init__(self, ffn_num_inputs, ffn_num_hiddens, ffn_num_outputs, **kwargs):
        super(PositionWiseFFN, self).__init__(**kwargs)

        self.mlp1 = nn.Linear(ffn_num_inputs, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.mlp2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        return self.mlp2(self.relu(self.mlp1(X)))
残差连接和层规范层
# add & layernorm
class AddNorm(nn.Module):
    def __init__(self, norm_shape, dropout, **kwargs):
        super(AddNorm, self).__init__(**kwargs)
        
        self.dropout = dropout
        self.ln = nn.LayerNorm(norm_shape)
        
    def forward(self, X, Y):
        return self.ln(self.dropout(Y) + X)  # the shape of X and Y should be equal
编码器/解码器
# d2l中encoder-decoder的源代码

class EncoderDecoder(nn.Module):
    """The base class for the encoder-decoder architecture.

    Defined in :numref:`sec_encoder-decoder`"""
    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_X, dec_state)
    
# 训练源代码
def train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device):
    """Train a model for sequence to sequence.

    Defined in :numref:`sec_seq2seq_decoder`"""
    def xavier_init_weights(m):
        if type(m) == nn.Linear:
            nn.init.xavier_uniform_(m.weight)
        if type(m) == nn.GRU:
            for param in m._flat_weights_names:
                if "weight" in param:
                    nn.init.xavier_uniform_(m._parameters[param])
    net.apply(xavier_init_weights)
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    loss = MaskedSoftmaxCELoss()
    net.train()
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[10, num_epochs])
    for epoch in range(num_epochs):
        timer = d2l.Timer()
        metric = d2l.Accumulator(2)  # Sum of training loss, no. of tokens
        for batch in data_iter:
            optimizer.zero_grad()
            X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]
            bos = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0],
                               device=device).reshape(-1, 1)
            dec_input = d2l.concat([bos, Y[:, :-1]], 1)  # Teacher forcing
            Y_hat, _ = net(X, dec_input, X_valid_len)
            l = loss(Y_hat, Y, Y_valid_len)
            l.sum().backward()  # Make the loss scalar for `backward`
            d2l.grad_clipping(net, 1)
            num_tokens = Y_valid_len.sum()
            optimizer.step()
            with torch.no_grad():
                metric.add(l.sum(), num_tokens)
        if (epoch + 1) % 10 == 0:
            animator.add(epoch + 1, (metric[0] / metric[1],))
    print(f'loss {metric[0] / metric[1]:.3f}, {metric[1] / timer.stop():.1f} '
          f'tokens/sec on {str(device)}')
    
# TRM编码器的组成部分

class EncoderBlock(nn.Module):
    def __init__(self, key_size, query_size, value_size, n_hiddens, norm_shape, ffn_num_inputs, ffn_num_hiddens,
                 num_heads, dropout, use_bias=False, **kwargs):
        super(EncoderBlock, self).__init__(**kwargs)

        self.attention = MultiHeadAttention(key_size, query_size, value_size, n_hiddens, num_heads, dropout, use_bias)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_inputs, ffn_num_hiddens, n_hiddens)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
        return self.addnorm2(Y, self.ffn(Y))
    
# 多层编码器模块堆叠的TRM-ENCODER
class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, key_size, query_size, value_size, n_hiddens, norm_shape, ffn_num_inputs, ffn_num_hiddens, num_heads, num_layers, dropout, use_bias=False, **kwargs):
        super(TransformerEncoder, self).__init__(**kwargs)

        self.n_hiddens = n_hiddens
        self.embedding = nn.Embedding(vocab_size, n_hiddens)  # batch * vocab_size * n_embed
        self.pos_encoding = PositionEmbedding(n_hiddens, dropout)

        self.encoder_set = nn.Sequential()
        for i in range(num_layers):
            self.encoder_set.add_module("block " + str(i), EncoderBlock(key_size, query_size, value_size, n_hiddens, norm_shape, ffn_num_inputs, ffn_num_hiddens, num_heads, dropout, use_bias))

    def forward(self, X, valid_lens, *args):
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.n_hiddens))  # batch * num_q * num_embedding

        self.attention_weights = [None] * len(self.encoder_set)  # 存放每层编码器注意力权重
        for i, enc_set in enumerate(self.encoder_set):
            X = enc_set(X, valid_lens)
            self.attention_weights[i] = enc_set.attention.attention.attention_weights
        return X
# TRM解码器的组成部分

class DecoderBlock(nn.Module):
    def __init__(self, key_size, query_size, value_size, n_hiddens, norm_shape, ffn_num_inputs, ffn_num_hiddens, num_heads, dropout, flag, **kwargs):
        super(DecoderBlock, self).__init__(**kwargs)

        self.flag = flag  # 在这个版本无用
        self.attention1 = MultiHeadAttention(key_size, query_size, value_size, n_hiddens, num_heads, dropout) 
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.attention2 = MultiHeadAttention(key_size, query_size, value_size, n_hiddens, num_heads, dropout)
        self.addnorm2 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_inputs, ffn_num_hiddens, n_hiddens)
        self.addnorm3 = AddNorm(norm_shape, dropout)

    def forward(self, X, state):
        enc_outputs, enc_valid_lens = state[0], state[1]
        # state[0]: encoder output, batch * num_steps * num_embedding
        # state[1]: masked size

        # 训练阶段,输出序列的所有词元都在同一时间处理,
        # 预测阶段,输出序列是通过词元一个接着一个解码的,
        
        batch_size, num_steps, _ = X.shape
        dec_valid_lens = torch.arange(1, num_steps + 1, device=X.device).repeat(batch_size, 1)  
        # batch_size * num_steps
        # valid_lens作用于masked_softmax中用于掩蔽不相干信息;这里dec_valid_lens内容为[1 2 3 ... steps],表示与当前时间步有关,防止每个时间的词元不对未来产生注意力而设置,即解码器的自注意力计算只关注已生成的词元位置,第N时间步的单词只能对前N个时间步的单词有注意力。
        
        # self-attention
        X1 = self.attention1(X, X, X, dec_valid_lens)
        Y1 = self.addnorm1(X, X1)

        # encoder-decoder attention
        X2 = self.attention2(Y1, enc_outputs, enc_outputs, enc_valid_lens)
        Y2 = self.addnorm2(Y1, X2)

        return self.addnorm3(Y2, self.ffn(Y2)), state
    
# 多层解码器模块堆叠的TRM-DECODER

class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, key_size, query_size, value_size, n_hiddens, norm_shape, ffn_num_inputs, ffn_num_hiddens, num_heads, num_layers, dropout, **kwargs):
        super(TransformerDecoder, self).__init__(**kwargs)

        self.n_hiddens = n_hiddens
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, n_hiddens)
        self.pos_encoding = PositionEmbedding(n_hiddens, dropout)
        self.decoder_set = nn.Sequential()
        for i in range(num_layers):
            self.decoder_set.add_module("block " + str(i), DecoderBlock(key_size, query_size, value_size, n_hiddens,norm_shape, ffn_num_inputs, ffn_num_hiddens, num_heads, dropout, i))
        self.fc = nn.Linear(n_hiddens, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        self.seqX = None
        return [enc_outputs, enc_valid_lens]

    def forward(self, X, state):
        # 在预测时把之前每一步的预测结果拼接在一起保存,使得预测第t步目标时输入解码器的X是从第0步到第(t-1)步的张量,而最后输出的结果只取与(t-1)步对应的内容。
        if not self.training:  # 如果非训练,需要保存每一步的预测结果
            self.seqX = X if self.seqX is None else torch.cat((self.seqX, X), dim=1)
            X = self.seqX

        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.n_hiddens))
        self._attention_weights = [[None] * len(self.decoder_set) for _ in range(2)]
        for i, dec_set in enumerate(self.decoder_set):
            X, state = dec_set(X, state)
            self._attention_weights[0][i] = dec_set.attention1.attention.attention_weights 
            # decoder attention weights
            self._attention_weights[1][i] = dec_set.attention2.attention.attention_weights 
            # encoder-decoder attention weights

        if not self.training:  # 非训练 返回时只取最后一个时间步结果
            return self.fc(X)[:, -1:, :], state

        return self.fc(X), state

    @property # only read
    def attention_weights(self):
        return self._attention_weights
# 模型训练部署

num_steps, num_hiddens, num_layers, batch_size = 10, 32, 2, 64
dropout = 0.1
lr, num_epochs, device = 0.005, 200, d2l.try_gpu()
ffn_num_inputs, ffn_num_hiddens, num_heads = 32, 64, 4
key_size, query_size, value_size = 32, 32, 32
norm_shape = [32]

train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)  # load data

encoder = TransformerEncoder(len(src_vocab), key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_inputs,
                             ffn_num_hiddens, num_heads, num_layers, dropout)
decoder = TransformerDecoder(len(tgt_vocab), key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_inputs,
                             ffn_num_hiddens, num_heads, num_layers, dropout)

# 调用d2l内嵌的encoder-decoder来初始化
network = d2l.EncoderDecoder(encoder, decoder)

d2l.train_seq2seq(network, train_iter, lr, num_epochs, tgt_vocab, device)  # training
plt.show()

model_save_dir = './model_logs/'
model_save_path = os.path.join(model_save_dir, 'model.pt')
torch.save(network.state_dict(), model_save_path)
# 模型预测部署

if os.path.exists(model_save_path):
    loaded_paras = torch.load(model_save_path)
    network.load_state_dict(loaded_paras)
    network.to(device)

engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
    translation, dec_attention_weight_seq = d2l.predict_seq2seq(
        network, eng, src_vocab, tgt_vocab, num_steps, device, True)
    print(f'{eng} => {translation}, ',
          f'bleu {d2l.bleu(translation, fra, k=2):.3f}')