RetNet:万众期待的 Transformers 杀手

发布时间 2023-09-14 00:01:26作者: 冷冻工厂

动动发财的小手,点个赞吧!

Transformer 已成为大语言模型上的架构,因为它有效地克服了循环神经网络 (RNN) 的顺序训练问题。然而,Transformer也并不完美,因为它们仅解决了所谓“impossible triangle”的两条臂。微软的 RetNet 声称位于这个“impossible triangle”的正中心,胜过了所有尝试过但未能实现这一壮举的方法。突破:

  • RetNet 具有更好的语言建模性能
  • RetNet 内存消耗降低了 3.4 倍
  • ….8.4 倍更高的吞吐量
  • …延迟降低 15.6 倍

这些速度比当前的 SOTA 快几个数量级,同时还提供更好的性能!如果其他团队能够复制这一点并且进入开源领域,这将是巨大的进步,但目前微软绝对是遥遥领先

但问题是,是什么让它如此伟大?我们将在这篇博文中揭晓这个问题的答案。我们将切开每个方程以更深入地研究并可视化正在发生的事情。我们将用一个已完成的示例来处理 RetNet,看看它如何推翻Transformer并显示出成为新王的巨大希望。

目的

“不可能三角”代表当前的序列模型无法同时实现训练并行性、低成本推理以及强大性能的所有三个期望维度。三角上的方法表示它们实现的两个维度,但缺少第三个顶点的所需属性。然而,RetNet 设法在单个框架下实现所有属性。

让我们更详细地理解这一点,因为这是开发该架构的核心动机。

训练并行度

顾名思义,RNN 循环处理序列,即按顺序一个接一个地处理。某个时间步的输入处理取决于前一个时间步的隐藏状态,因此在处理完所有先前的步骤之前无法并行计算。这会显着减慢训练速度。

由于 Transformers 部署了高度并行化的自注意力机制,因此每个时间步的输出都可以使用 Q、K、V 矩阵并行处理。然而,这种帮助 Transformer 在 GPU 上很好地并行的自注意力却成为了推理时最大的敌人,我们稍后会看到。

RetNet 借鉴了两全其美的优点,因为它配备了三种处理范例——并行训练、循环/分块推理。它采用了 Transformer 的可并行自注意力机制,尽管有一些非常巧妙的技巧可以帮助它克服缺点!

推理成本+内存复杂度

推理成本(每个时间步)是指 GPU 内存、吞吐量和延迟,而内存复杂性是指内存占用相对于序列长度的缩放法则。由于 RNN 使用简单且廉价的运算(例如矩阵乘法),因此它们的推理成本不会随序列长度而变化,而是恒定的(即 O(1))。同时,它们的内存复杂性与序列长度呈线性关系。

另一方面,由于 Transformer 使用自注意力块,因此它们需要在推理时维护“NxN”矩阵,您可以看到推理成本呈线性比例 (O(N)),内存复杂度呈二次比例 (O(N2) )。

虽然 RetNet 使用 Transformer 的自注意力模块来并行化训练并实现最先进的性能,但它不会遇到上述推理成本和内存复杂性问题。这是由于它调整了自注意力模块,它用保留模块+它用来在推理时模仿自注意力的循环推理范式取代了它。

性能

Transformer 相对于 RNN 的主要优势在于,由于其自注意力头,它们能够处理较长的序列,而不会发生灾难性遗忘。 RetNet 实现了与 Transformers 类似或更好的性能。

RetNet:概述

RetNet的主要贡献可以概括为两大点。然而,美妙之处在于它们如何从 A 点到达 B 点的细节,我们将在随后详细讨论:

  1. RetNet引入多尺度保留机制来替代多头注意力。这是消除自注意力机制中的魔鬼这一组成部分的关键。尽管如此,这种保留机制有一个小小的理论上的缺点。

  2. RetNet 适用于三种计算范式,而只有一种 Transformer 在训练和推理过程中使用相同的序列处理范式。

A. 并行表示使训练并行性能够充分利用 GPU 设备。

B. 循环表示在内存和计算方面可实现高效的 O(1) 推理。可以显着降低部署成本和延迟。此外,在没有键值缓存技巧的情况下,实现也得到了极大的简化。

C. 分块循环表示可以执行有效的长序列建模。我们对每个本地块进行并行编码以提高计算速度,同时对全局块进行循环编码以节省 GPU 内存。

RetNet 与 Transformers

RetNet 建议充分利用两个领域的优点,并展示我们如何才能实现这一目标。它使用 Transformer 的可并行训练范例,而不是 RNN 低效且缓慢的自回归步骤。然而,在推理时,由于保留机制而不是自注意力机制,RetNet 顺利地采用了 RNN 的更多内存和计算效率更高的循环范式。

并行表示

RetNet 在训练期间部署原始 Transformer 的并行表示学习,以摆脱 RNN 的限制性自回归序列处理。然而,它对整个过程做了一些改变。

我们可以看到,RetNet 放弃了 Hadamard 产品的 softmax 运算,采用新引入的 D 矩阵,然后进行 GroupNorm 运算。这不奇怪吗?

Softmax 操作是自注意力的整个基础,Transformers 从中获得了最先进的性能——softmax 为输入序列中的每个标记赋予相对注意力权重,帮助模型学习和保留长期依赖关系。然而,如果您还记得的话,这种 softmax 计算正是导致 Transformer 推理时间性能不佳的确切原因,因为它们必须将 softmax(Q.KT) 保存在内存中,该内存是一个 NxN 矩阵,并且与序列长度呈二次方增长!在训练和下游性能中赋予 Transformers 优越优势的一件事是它在推理过程中最大的敌人!

之前的许多工作都试图通过引入近似此 softmax 操作的方法来绕过此步骤,但最终的架构最终会受到性能影响。但随后 RetNet 凭借其神奇的 D 矩阵和 GroupNorm 出现,最终表现出与 Transformers 类似或更好的性能,同时在推理过程中速度更快、内存效率更高,并且还能够在训练过程中使用并行化进行高效训练!

  • 那么这个 D 矩阵 + GroupNorm 运算是什么?它有什么帮助?

我们将在接下来的部分中详细探讨这一点,据我所知,Transformers 中的 Softmax 实现了两个目标:

  1. 对不同的时间步进行不同的加权。这有助于模型“关注”序列的不同部分并拾取正确的信号。这也是其性能优于 RNN 的重要因素之一。提议的 D 矩阵负责这部分,但有一个限制性假设(在我看来)。 D 矩阵是一个因果掩模,可以说具有已定义的预定义权重因子。具体来说,它可以防止每个时间步关注未来的步骤,同时它相对于之前的所有时间进行加权-步骤但以预定义的指数方式。 D 矩阵假设最近的时间步骤比过去的时间步骤更重要,因此对先前的步骤部署指数衰减权重。因此,虽然 softmax 足够灵活,可以对不同的步骤进行不同的权重,但 D 矩阵以固定的预定义方式(指数衰减)对所有步骤进行权重。虽然这是直观的,甚至对于大多数顺序情况可能都是如此,但它仍然不如 softmax 灵活。但代价是高效的 O(1) 推理和 O(N) 内存复杂度。看看结果,看起来这确实是现实用例中 softmax 操作的一个非常好的近似!
  2. 引入非线性。在没有softmax的情况下,Q.KT操作只是一种仿射变换,无论堆叠多少层,都会极大地限制其学习能力。 GroupNorm 运算引入了急需的非线性。

新的保留机制

保留机制本质上是 RNN 和 Transformer 核心原理的融合:REcurrent + self-attention = RETENTION

现在让我们更详细地看看 Transformer 之间的差异/相似之处,如下图 所示。

如果您还记得的话,原始 Transformers 输出是通过首先将仿射变换应用到带有 WQ、WK 和 WV 矩阵的输入嵌入 X,然后对结果 (Q.KT) 进行 softmax 计算,最后将结果与 V 相乘来生成的。矩阵。它看起来像这样,其中 O 是包含输入矩阵 X 的上下文嵌入的输出矩阵:

由于 RetNet 在循环范式和并行范式中运行,作者首先在循环设置中激发 RetNet“保留”块(即单独处理每个“n”输入元素)。然后,他们对提出的循环保留块进行矢量化。因此,最初的循环公式看起来像这样:

我们可以清楚地看到,尽管有一些变化,但这看起来与原始 Transformer 公式非常相似。我们看到 softmax 已被位置嵌入项 (pos) 取代。RetNet 用 pos 矩阵替换了原始 Transformer 的 softmax。上面的等式可以扩展如下,以更多地了解 pos 正在做什么:

其中 pos’ 是 pos 的复共轭。使用 γ 作为标量值进一步简化上述方程,我们可以在训练迭代期间轻松并行化此计算,如下所示:

我们可以清楚地看到,获取Q、K和V的第一步与原始Transformer相同。除此之外,现在我们将 pos/pos 的嵌入按元素乘以 Q 和 K 矩阵。但是看看这个并行训练阶段公式的最后一步,我们可以看到它与原始 Transformer 计算非常相似(尽管是 softmax à D 矩阵替换),因此是完全可并行的(D 矩阵可以预先计算)因为它只是一个相对位置嵌入+因果掩码表示)。现在我们知道,经过微小改动的 RetNet 可以在并行范式中进行训练。

相对位置嵌入 (pos/pos’)

我们不需要过多讨论这些位置嵌入的细节,因为它们借用了 Transformers/LLM 的原始位置嵌入的直觉和功能。但是,为了了解该方程式中到底发生了什么,让我们深入了解一下。

从欧拉公式我们已经知道:

因此,上面等式 4 中的 θ 通过向量旋转将“相对位置信息”编码到 Q 和 K 矩阵的每个向量中。这本质上使它们“位置感知”,并通过 Q、K 向量及其各自位置特定向量旋转之间的哈达玛积来实现,如下所示:

每个位置处的 Qn 和 Km 向量均按红色箭头所示的旋转向量进行旋转。从上面方程 5 中附带的矢量旋转图可以看出,当 n=m=1 时 einθ/ eimθ 有一次旋转。这些是旋转矢量 Q1 和 K1 位置。类似地,对于 n=2,m=2 位置,矢量具有双旋转。具有相同旋转(即对角线上的所有位置)的向量之间的点积将 =1。此外,当n=1时,m=2点积位于两个不同旋转的向量之间,并且将对应于该位置处的向量的特定位置值。请注意,随着我们进一步移动(例如,m=2、n=1、..、n),矢量点积趋于 0,因为矢量趋向于彼此正交。

好的,这就是方程 4 中的 θ 与 Q 和 K 中的每个向量逐元素相乘的部分,以使它们“位置感知”。接下来我们将看看所提出的 D 矩阵的作用。

因果掩蔽和指数衰减 (D)

D 矩阵充当因果掩模以及过去位置的指数衰减加权方案。

从上面Eq6中D的定义,我们可以看到保留计算中的D实现了masked Attention和softmax在self-attention中所做的任务。

Masked-attention — 因果掩码:对于 n>m 的位置,(Q.KT) 的向量乘以 0,以确保序列处理的因果假设到位。这确保了未来时间步长的信息不会泄露。

Softmax 指数衰减:对于 n<=m 的位置,(Q.KT) 的向量使用指数衰减因子 γ 进行加权。这意味着标记的过去越远,它对于当前时间步的重要性就越低。这样就实现了对之前时间步的信息进行不同权衡的任务,这一点在self-attention中是通过softmax实现的。虽然由于其限制性假设,这比 softmax 操作受到更多限制和不灵活,但作者已经证明它同样有效!

因此,D 矩阵最终看起来像这样:

下一步是查看“位置感知”(Q.KT) 和 D 矩阵如何结合在一起,给出 X 中每个输入标记的最终输出嵌入。

结合

现在,我们可以使用给定的 Hadamard 产品组合上面的操作,以获得并行操作的最后一步,详细信息如方程 4 所示:

您现在明白为什么我们不关心“位置感知”(Q.KT) 的上三角,因为在使用 D 进行此操作后这些值被设置为 0!您现在可以清楚地看到整个操作在训练期间如何完全并行化。

并行训练——工作示例

假设我们只有两个标记序列,即 N=2,嵌入大小 D=3。假设在此示例中,这为我们提供了以下 NxD 维度的 Q、K 和 V 矩阵(第一行是每个矩阵中的第一个标记,依此类推):

我们使用训练期间使用的 RetNet 并行范例获得了 2 个输入标记的最终上下文嵌入。

推理的循环保留

RetNet 的循环保留范式是通过解构并行计算获得的,使得循环表示在推理过程中的工作原理完全相同,但内存复杂度只有一小部分。这是这项工作的主要贡献之一,也非常有趣。让我们看看如何:

这看起来很熟悉——具有 RNN 的一般流程,但单元内部的操作类似于 Transformer!让我们更详细地看看这里发生了什么,并添加一些注释以使事情更清楚:

我们注意到的第一件事是 Q,K,V 矩阵现在是时间步长索引的(n 个下标),因此是 1xD 维度的向量,而不是之前的 NxD 矩阵。这是有道理的,因为它是一个循环设置,并且给定的块显示了特定令牌的处理。我们注意到的第二件事是,状态向量 S 从前一个时间步向前传递以传达时间/位置信息。该 Sn-1 在每个时间步乘以指数衰减/折扣因子 γ,以循环实现 D 的任务。这控制了状态向量 S 中保留的信息类型,以供将来的步骤使用。

第三个也是最有趣的部分是,如果您查看单元内的计算,我们会发现循环设置中的第一个操作现在是 KT.V 而不是 Q.KT,并且 Q 稍后会相乘。这些矩阵在并行训练阶段以不同的顺序进行训练,但是在推理时它们的计算方式不同,我们仍然希望这能起作用?这是一个巧妙的技巧,也是本文的主要贡献之一,它展示了如何通过一些不直观的修改在循环范式中解构 Transformer 设置。让我们看看循环范式的具体操作是什么:

等式 7 中的操作总结了我们看到的内容。首先,使用先前状态向量上的折扣因子并将其与 KT.V 操作相加来更新状态向量。最后,将更新后的状态向量与Q相乘,得到本步骤的最终输出。随后整理所有输出以形成最终输出矩阵。由于我们已经从上面的示例中了解了 γ 和 KT.V 类型的运算如何工作,因此这已经非常直观了。只是一个棘手的问题,KT.V 如何代替 Q.KT 来达到相同的结果?

循环推理——工作示例

为了保持一致性,我们继续使用之前的两个标记序列 (N=2),嵌入大小 D=3 的示例。因此,我们的旧 Q、K、V 矩阵为:

步骤 1:计算 n=1 时的 KT.V。如果您没有注意的话,KT.V 并不是像 Q.KT 那样的点积,而是两个向量之间的外积,它给出一个矩阵而不是一个标量!此外,现在我们将迭代地处理令牌。因此对于 n=1:

第2步:获取S1。由于没有 S0,S1 与上一步相同,没有添加任何内容:

步骤3:将Q和S1相乘得到最终输出。这里有一个问题。虽然图表和方程没有明确提及这一点,但伪代码表明我们需要进行逐元素乘法,然后进行逐列加法,以获得每个时间步的最终输出向量,如下突出显示:

因此,在匹配形状一段时间后,伪代码终于有所帮助,我们得到第一个标记的以下输出:

您是否注意到,此处通过循环保留获得的第一个标记嵌入与前面方程中的并行训练计算相同?因此,即使对循环块计算进行了不直观的更改,结果也与第一步完全匹配。但是让我们完成另一个步骤,看看如何使用此步骤中计算的 S1。

步骤 4:计算 n=2 时的 KT.V。重复相同的外积过程,我们得到:

第5步:获取S2。这里的计算稍微复杂一些,因为我们必须将 S1 与折扣因子 γ 相乘,然后再将其添加到上述步骤的结果中:

第6步:获得最终输出。现在我们已经利用了先前的状态信息并用当前状态的 K、V 信息更新了它,我们可以将其与 Q 相乘以获得最终输出:

瞧!我们再次看到第二个标记的嵌入与方程 8 中的并行保留完全相同。对计算所做的细微更改使 RetNet 能够将并行训练计算解构为循环计算完全没有任何近似!
对我个人来说,看到一切都井井有条,真是非常非常令人兴奋和美好。感谢作者实现了这一壮举!我们现在拥有一个结合了 Transformer 的训练优势和 RNN 的推理有效性的架构。

总结

这是一篇很长的文章,信息量很大,希望你没有中途睡着。虽然我们深入研究了 RetNet 每个组件的内部工作原理,通过工作示例来理解直觉,但这仍然不是完整的故事。本博客中故意遗漏了许多更有趣的细节和组件,您可以在原始论文中找到它们。该博客旨在为您提供所有必要的知识和数学知识,以便您自己更详细地阅读本文并进一步与大家分享您的学习和想法。

本文由mdnice多平台发布