Root Mean Square Layer Normalization

发布时间 2023-07-24 10:53:34作者: 馒头and花卷

Zhang B. and Sennrich R. Root mean square layer normalization. NIPS, 2019.

RMSNorm 节省时间.

RMSNorm

  • 假设输入为 \(\mathbf{x} \in \mathbb{R}^m\), 然后

    \[\mathbf{a} = \mathbf{W} \mathbf{x} \in \mathbb{R}^{n}, \\ \mathbf{y} = f(\text{Norm}(\mathbf{a}) + \mathbf{b}) \in \mathbb{R}^{n}. \]

    其中 \(f(\cdot)\) 是 element-wise 的激活函数.

  • LayerNorm 采取的是如下的方式 (注意, 下面的 \(/\) 是 element-wise 的):

    \[\text{LayerNorm}(\mathbf{a}) = \frac{\mathbf{a} - \bm{\mu}}{\bm{\sigma}} \odot \mathbf{g}, \]

    其中

    \[\bm{\mu} = \text{mean}(\mathbf{a}), \\ \bm{\sigma} = \sqrt{\text{mean}((\mathbf{a} - \bm{\mu})^2)}. \]

  • RMSNorm 采用的是如下的方式:

    \[\text{RMSNorm}(\mathbf{a}) = \frac{\mathbf{a}}{\text{RMS}(\mathbf{a})} \odot \mathbf{g}, \]

    其中

    \[\text{RMS}(\mathbf{a}) = \sqrt{\text{mean}(\mathbf{a}^2)}. \]

  • 由于不用计算均值, RMSNorm 所需的计算时间会少一点, 但是效果是差不多的:

  • 此外, RMSNorm 保留了一些重要的不变性:

代码

[official]