[论文阅读] Replacing softmax with ReLU in Vision Transformers

发布时间 2023-12-12 10:53:40作者: NoNoe

Pre

title: Replacing softmax with ReLU in Vision Transformers
accepted: Arxiv 2023
paper: https://export.arxiv.org/abs/2309.08586
code: None

关键词:attention, parallelization
阅读理由:Google Deepmind,标题挺有意思

Idea

序列缩放能缓解ReLU等激活函数在attention中替换Softmax提高并行性时导致的性能下降,但不确定为何

Motivation&Solution

  1. 注意力中的softmax很重要,但妨碍了并行计算

Background

先前的研究发现如果把注意力的softmax换成逐点激活函数ReLU,准确度会有所下降,作者认为是他们没有将其除以序列长度导致的,同时那些方法仍然依靠normalization来使注意力权重总和为1,(仍保留了无法并行的缺点)

Method(Model)

Overview

用ReLU替换,同时将值除以序列长度,能缓解性能的下降

Attention 原本的注意力权重\(\alpha_{i,j}\)计算如下:

\[\alpha_{i j}=\phi\left(\frac{1}{\sqrt{d}}\left[q_{i}^{\top}k_{1},\ldots,q_{i}^{\top}k_{L}\right]\right)_{j}, \tag{1} \]

其中 L 是序列长度, \(\phi\) 是经典的softmax,而本文就是要探索它的 point-wise 替代。

ReLU-attention 观察到 \(\phi = L^{-1}relu\) 是公式1中softmax有希望的替代,将其称为ReLU-attention

Scaled point-wise attention 本文探索更一般的形式: \(\phi = L^{-\alpha}h,\; \alpha \in [0,1],\; h \in \{relu, relu^2, gelu, softplus, identity, relu6, sigmoid\}\)

Sequence length scaling 除以序列长度这事主要是实验的结果,但作者也给了一定的理论分析:现在的Transformer用的 sotfmax 注意力都要求 \(\sum^L_{j=1}\),这实际上就隐含了 \(\mathbb{E}_j[\alpha_{ij}] = L^{-1}\)。虽然可能不必要,但 \(\phi = L^{-1}relu\) 使得 \(\mathbb{E}_j[\alpha_{ij}]\) 在初始时在 \(O(L^{-1})\) 的量级。维持该条件可能会减轻替换掉softmax后对其他超参数的调整需求。

初始时 q, k 都是 \(O(1)\),因此 \(\frac{\left \langle q_i,k_j \right \rangle }{\sqrt{d}}\) 也是 \(O(1)\)。ReLU指着激活函数能维持\(O(1)\),因此 \(L^{-1}\)对于 \(\mathbb{E}_j[\alpha_{ij}]\) 维持 \(O(L^{-1})\) 是必要的

Experiment

Settings

用了BigVision codebase的训练配置,没有修改超参数。ImageNet-21k训练30epoch,ImageNet-1k训练300epoch,二者都差不多训练了9e5个step

Dataset

ImageNet-21k, ImageNet-1k

Results

图1 将sotfmax替换为relu/seqlen或是用qk-layernorm匹敌视觉Transformer的传统注意力缩放性能。该图展示了模型从小到大训30epoch的结果。

图2 将softmax替换为 Scaled point-wise attention 那种更为一般的形式,观察到\alpha在接近1的时候结果最好,但在此情况下选不出最好的激活函数,为了速度用了ReLU

图3 注意力去掉qk-layernorm,并且使用 $L^{-\alpha}$ 缩放的影响。

图4 使用门控注意力单元加上 $L^{-\alpha}$ 缩放的影响

Main experiment. 图1展示了ReLU-attention在ImageNet-21k上训练时与softmax attention一样的缩放趋势。但看起来性能似乎是稳定地不如。

Effect of sequence length scaling. 见图2

Effect of qk-layernorm. 实验用的qk-layernorm是在计算注意力权重之前把 q,k 拿去过layernorm,据说在提升模型尺寸时能防止不稳定。图3实验了去掉它的效果,表明似乎影响不大。

Effect of adding a gate. 研究加一个门控能否替代用序列长度缩放,但图4表明仍然要缩放才有最好的效果,门不门控没差别(横轴在0的时候并非最优,而且加上门控的线也不是平的)

Conclusion

仍留下许多开放问题,例如仍不确定为何因子 \(L^{-1}\) 会有用,以及该项是否可学。而且也可能有比ReLU更好的激活函数。

Critique

好短。看着google还以为有好东西,结论就是序列缩放有用但不知为何,而且说是提高并行性,不得比较一下吞吐和训练时间?图1有些太笼统了。