免注意力Transformer (AFT):使用逐元素乘积而不是点积

发布时间 2023-05-17 10:33:35作者: 鸽鸽的书房

注意力机制作为现代深度学习模型的基石,能够毫不费力地对长期依赖进行建模,并关注输入序列中的相关信息。然而,需要点积自注意力 - 广泛使用在Transformer架构中的一个关键组件 - 已被证明在序列长度方面具有二次空间复杂度,因此不适用于处理长输入。在本文中,我们介绍了Attention Free Transformer(AFT),这是Transformer的一个新变种,消除了点积自注意力的需要,使内存复杂度从\(O\left(T d^2\right)\)变成\(O(Td)\),其中T是序列长度,d是嵌入的维数。

代码:https://nn.labml.ai/transformers/aft/index.html

多头自注意力

在Transformer内核中,多头自注意力(Multi-Head Attention,MHA)操作是其关键之一。在自注意力模式下,给定输入序列\(X \in R^{T \times d}\)和头数\(h\),MHA为每个头\(i\)执行缩放点积注意力,定义为:

\[f_i(X)=\sigma\left(\frac{Q_i\left(K_i\right)^T}{\sqrt{d_k}}\right) V_i \]

其中\(Q_i=XW_i^Q,K_i=XW_i^K,V_i=XW_i^V\)

\(W_i^Q \in R^{d \times d_k}, W_i^K \in R^{d \times d_k}, W_i^V \in R^{d \times d_v}\)是头\(i\)的线性变换,\(\sigma\)是默认设置为Softmax函数(应用于矩阵的每一行)的非线性函数。\(d_k\)\(d_v\)分别是键和值的维度。

MHA沿通道维度连接\(h\)个注意力头的输出,结果的特征维度为\(hd_v\)。除非另有说明,我们假设\(d_k=d_v\)\(h=\frac{d}{d_k}\)。这意味着每个头的查询、键和值都是相同的维度,并且输出维度与输入维度匹配。

AFT

我们现在定义无注意力变换器(AFT),它是MHA的插件替换,无需更改Transformer的其他架构方面。给定输入\(X\),AFT首先线性地将其转换为\(Q=XW^Q,K=XW^K,V=XW^V\),然后执行以下操作:

\[Y=f(X) ; Y_t=\sigma_q\left(Q_t\right) \odot \frac{\sum_{t^{\prime}=1}^T \exp \left(K_{t^{\prime}}+w_{t, t^{\prime}}\right) \odot V_{t^{\prime}}}{\sum_{t^{\prime}=1}^T \exp \left(K_{t^{\prime}}+w_{t, t^{\prime}}\right)} \]

其中, \(\odot\) 是逐元素相乘;\(\sigma_q\) 是应用于默认为sigmoid的查询的非线性函数;\(w \in R^{T×T}\) 是经过学习的成对位置偏差(详见图2的说明)。

用语言解释,对于每个目标位置 \(t\),AFT 执行值的加权平均,其结果与查询进行逐元素乘法。特别地,权重仅由键和一组经过学习的成对位置偏差组成。这提供了一个直接的优势,即无需计算和存储昂贵的注意力矩阵,同时保持MHA所做的查询和值之间的全局交互。 为了进一步了解 AFT 与 MHA 的关系,我们可以将方程2重写为:

\[Y_t^i=<a_t^i, V^i>\text {, s.t. } a_t^i=\frac{\sigma_q\left(Q_t^i\right) \exp \left(K^i+w_t\right)}{\sum_{t^{\prime}=1}^T \exp \left(K_{t^{\prime}}^i+w_{t, t^{\prime}}\right)}, i=1,2, \ldots, d, t=1,2, \ldots, T . \]

这里我们使用上标 \(i\) 来索引矩阵的特征维度; $ ;<\cdot, \cdot> $ 表示向量的点积。在这种重新排列的形式中,我们能够再次用注意力来表示 AFT。具体来说,对于每个位置,我们有一个注意力向量 \(a_t^i \in R^T\),对于每个维度,由\(Q,K,w\)组成。换句话说,AFT 可以解释为使用与特征维度同样多的头数执行隐式注意力,其中注意力矩阵采用分解的形式。

简单来说,在AFT层中,首先使用一组学习到的位置偏置将键和值组合,然后将这个结果与查询逐元素相乘。

Explained in words, for each target position t, AFT performs a weighted average of values, the result of which is combined with the query with element-wise multiplication. In particular, the weighting is simply composed of the keys and a set of learned pair-wise position biases. This provides the immediate advantage of not needing to compute and store the expensive attention matrix, while maintaining the global interactions between query and values as MHA does.

https://paperswithcode.com/method/attention-free-transformer

模型变种

我们还提出了两种模型变种,AFT-local和AFT-conv。AFT-local利用地方性思想降低长序列中注意力计算的计算成本,而AFT-conv则利用数据的空间结构进一步提高效率。

AFT-local

在AFT-local这个变体中,我们仅在本地应用一组学习的相对位置偏差:

\[w_{t, t^{\prime}}= \begin{cases}w_{t, t^{\prime}}, & \text { if }\left|t-t^{\prime}\right|<s \\ 0, & \text { otherwise }\end{cases} \]

这里, \(s \leq T\) 是局部窗口大小。AFT-local 提供了进一步的计算节省,包括参数数量以及时间/空间复杂度。请注意,与本地 Transformer不同,AFT-local不管窗口大小\(s\)如何都会保持全局连通性。

AFT-simple

AFT-simple是AFT-local的一种极端形式,其中s = 0表示不学习位置偏差。这产生了一种非常简单的AFT版本:

\[Y_t=\sigma_q\left(Q_t\right) \odot \frac{\sum_{t^{\prime}=1}^T \exp \left(K_{t^{\prime}}\right) \odot V_{t^{\prime}}}{\sum_{t^{\prime}=1}^T \exp \left(K_{t^{\prime}}\right)}=\sigma_q\left(Q_t\right) \odot \sum_{t^{\prime}=1}^T(\operatorname{softmax}(K) \odot V)_{t^{\prime}} . \]

在这个版本中,上下文缩减进一步简化为逐元素操作和全局池化。AFT-simple类似于线性化注意力,其公式为:

\(Y_t=\) \(\frac{\phi\left(Q_t\right) \sum_{t^{\prime}=1}^T\left(\phi\left(K_{t^{\prime}}\right)^T V_{t^{\prime}}\right)}{\phi\left(Q_t\right) \sum_{t^{\prime}=1}^T \phi\left(K_t\right)^T}\).

AFT-conv

我们还可以进一步扩展到局部权重共享,即卷积。这种变体在视觉任务中特别相关,因为通常希望将预训练模型扩展到变量大小的输入上。具体来说,我们使 \(w_{t, t^{\prime}}\) 的值仅取决于t和t′相对于给定空间网格(\(1d\)\(2d\))的相对位置。与CNN类似,我们也可以学习多组位置偏差(我们重复使用heads的概念)。为了解决heads数量增加时参数数量增长的问题,我们选择将\(K\)的维数与heads数量相关联。这使得AFT-conv适合深度可分离卷积、全局池化和逐元素操作的实现。

我们现在展示具有 \(1d\) 输入的AFT-conv的示例,\(2d\)\(3d\) 输入可以类似地推导出来。我们将模型配置表示为AFT-conv-h-s,其中 \(h\) 是heads的数量,\(s\) 是1d局部窗口大小。我们现在有\(w \in R^{h \times s}, Q, V \in R^{T \times h \times \frac{d}{n}}, K \in R^{T \times h}\). 对于每个头 \(i=1,2, \ldots, h\), 我们有:

\[Y_t^i=\sigma_q\left(Q_t^i\right) \odot \frac{\operatorname{conv} \operatorname{ld}\left(\exp \left(K^i\right) \odot V^i, \exp \left(w^i\right)-1\right)+\sum_{t^{\prime}=1}^T \exp \left(K_{t^{\prime}}^i\right) \odot V_{t^{\prime}}^i}{\operatorname{convld}\left(\exp \left(K^i\right), \exp \left(w^i\right)-1\right)+\sum_{t^{\prime}=1}^T \exp \left(K_{t^{\prime}}^i\right)} . \]

这里\(Y_t^i \in R^{\frac{d}{h}}, Q^i, V^i \in R^{T \times \frac{d}{h}}, K^i \in R^T, w^i \in R^s ;\) conv1d \((x, w)\) 是深度可分离1d卷积操作,其中卷积过滤器w在通道维度上共享。注意方程6可以直接解释为具有以下三个方面的专用卷积层:1)全局连接,2)非负卷积权重和3)复杂的除法/乘法门控机制。我们通过实验表明,这三个方面都对AFT-conv的性能有显著贡献。

参数化

根据经验,我们发现适当地参数化位置偏差 \(w\) 是很重要的。对于 AFT-full 和 AFT-local,我们采用了 \(w\) 的因式分解形式:

\[w_{t, t^{\prime}}=u_t^T v_t^{\prime}, u \in R^{T \times d^{\prime}}, v \in R^{T \times d^{\prime}} \]

其中 \(d^{\prime}\) 是一个小的嵌入维度(例如128)。这个简单的因式分解不仅大大减少了参数数量(\(2 T d^{\prime}\) 对于 \(T^2\)),而且在模型的训练和测试中经验证明可以提高性能。

对于 AFT-conv,因数分解技巧不适用。我们采用了一个简单的重新参数化方法,对于每个头 \(i\),我们让

\[w^i=\gamma^i \frac{w^i-\operatorname{mean}\left(w^i\right)}{\operatorname{std}\left(w^i\right)}+\beta^i, \]

其中 \(\gamma \in R^h\)\(\beta \in R^h\) 是可学习的增益和偏置参数,均初始化为0。


以上转译自论文原文,仅为作者的阅读笔记而不是主观想法。

资料来源:

https://arxiv.org/pdf/2105.14103.pdf