【Transformer 基础系列】手推显存占用

发布时间 2023-12-26 22:48:47作者: China Soft

https://zhuanlan.zhihu.com/p/648924115

 

本文试图以最清晰的方式手动推导 Transformers 每一步的参数量到显存、计算量问题。理解底层,才能更好的做训练和优化。可能是目前最全的大模型显存优化方案分析。

本文内容包括
(1)模型训练和推理过程中的显存占用
(2)KV cache、中间激活值等显存占用
(3)模型状态显存优化方案: Megatron(3D) + Deepspeed(ZeRO)(更新于2023-09-11)
(4)激活值显存优化方案:重计算 + 3D 并行(更新于2023-08-11)
(5)KV Cache 显存优化方案:MQA 和 GQA(更新于2023-09-11)

关于计算量、参数量的分析在本系列其他文章记录。

乞力马扎罗不说话:【Transformer 基础系列】手推计算量FLOPS和训练时间
乞力马扎罗不说话:【Transformer 基础系列】模型参数量

0 前置知识和标记

  1. 显存占用 = 参数数量 x 该参数精度占用的 bytes 数
  2. 换算关系:Int8 需1 bytes, fp16 / bf16 数需 2 bytes, fp32 需要 4 bytes
  • transformer 模型的层数为 
  • 隐藏层维度为 
  • 注意力头数为 
  • 词表大小为 
  • 批次大小为 
  • 序列长度为 

1 训练过程

训练中的显存占用分两块,分别是:

  1. 模型状态,参数、梯度和优化器状态
  2. 剩余状态, 中间激活值、临时buffer、显存碎片等

1-1 模型状态显存

模型状态指的是和模型参数、梯度和优化器状态相关的显存占用。

设模型参数量为 Φ ,模型参数(fp16)、模型梯度(fp16)和优化器状态(fp32),总参数量 = 2Φ+2Φ+�Φ=(4+�)Φ 。参数量和模型配置之间的关系可以看另一篇文章推导,合计约 �ℎ+�(12ℎ2+13ℎ) 。

一般是混合精度训练,梯度/权重为 fp16,但所有涉及累加操作都需要 fp32 防止误差累计,同时优化器也要存 fp32 主权重。以 Adam 系列为例,总数为 2Φ+2Φ+(4+4+4)Φ=16Φ 。

  1. 这部分比较固定,主要和参数量有关,和输入大小无关。
  2. 在整个训练过程中都要存在显存中。 模型参数一般只能通过并行切分(Tensor Parallelism/Pipeline Parallism)能减少。优化器状态一般通过ZeRO 来减少。
  3. 不同优化器的 K 值不同,算法的中间变量、框架的实现都有可能有一定区别。复旦 LOMO 的方法也是基于类似的思路重新改进 SGD 来减少 K 值和梯度部分显存。

不同优化器的 K 值

优化器K值构成
adamw 12 fp32 主权重 4 + 动量 4 +方差 4
SGD 8 fp32 主权重 4 + 动量 4
bitsandbytes 6 fp32 主权重 + 动量 1 + 方差 1
LOMO 0  

1-2 中间激活值显存

激活(activations)指的是前向传递过程中计算得到的,并在后向传递过程中需要用到的所有张量。

中间激活值占用显存分两个部分分析:Attention 和 MLP,Embedding 没有中间值。最终合计 (34��ℎ+5��2�)∗�=(13��ℎ+5��2�+21��ℎ)∗� 。

  1. 这部分比较灵活,激活值与输入数据的大小(批次大小 b 和序列长度 )成正相关。
  2. 在训练过程中是变化值,特别是 batch size 大的时候成倍增长很容易导致 OOM。
  3. 可以通过重计算、并行切分策略减少。

直接看公式不太直观,下面是 GPT-3 和 LLaMA 为例计算的模型显存和中间激活值显存占用比例。

Attention 层中间显存表

self-attention 块的计算公式如下: �=���,�=���,�=��� ����=�������(���ℎ)⋅�⋅��+�

Attention 层单步中间激活值显存表

MLP 层中间显存表

MLP 块的计算公式如下:

�=�����(�����1)�2+����

MLP 层单步中间激活值显存表

2 模型状态显存优化方案

如 1-1 所推,模型状态占用 2Φ+2Φ+�Φ=(4+�)Φ,其中一般只能通过各种各样的并行来解决。比如模型参数显存优化一般是 模型并行,包括张量并行 (tensor parallel, TP) 和流水线并行(pipeline parallel, PP),业内通用方案参考 Megatron。只做数据并行 (data parallelism, DP) 情况下,模型参数和优化器状态一般通过 Deepspeed ZeRO 来均摊到所有卡上。

总的来说,都是用通信时间换显存空间。业内很多框架也是基于 Megatron+Deepspeed 这一套比较成熟的底层上改的。

2-1 Megatron-LM 3D Parallel

Megatron-LM 里称之为 Model Parallel,也叫 Tensor Parallel。

Q / K / V 矩阵做列切分(纵刀流),对Dropout做行切分(横刀流),方便GPU 中间计算各算各的,减少额外通信

不切分的时候各层参数如下表

Model Parallel 需要切分所有参数 embedding / attention / mlp 为 � 份,其中 embedding 层 V 在 Megatron 中会补全到最小的 � 倍数以便于切分。因此,显存为 �′ℎ+(12ℎ2+13ℎ)�� 。

Pipeline Parallel 需要按层切分所有参数,一般是 � 层均分 � 份,embedding 在最前面一层或单独一层。不过针对一些奇特结构不能整除的(比如44层的 NeoX)可能需要设计特定切分策略。每层显存为 (12ℎ2+13ℎ)�/� 。

这里显存都没什么好说的,主要是通信量值得分析。

2-2 ZeRO Stage 1-3

Deepspeed ZeRO 本质上都是在数据并行层面对模型状态一步步做分片(partition),系统内只维护一份模型状态,需要全量状态时就执行通信。

ZeRO Stage: 不同 stage 区别主要是切什么,显存占用论文里这张图就很直观了。

  • Stage 1(P os): fp32 optimizer state
  • Stage 2(P os+g): fp32 gradient + Stage 1
  • Stage 3(P os+g+p): fp16 parameters + Stage 1 + Stage 2

  1. 总卡数越多越省。stage 1 下,按照一个节点估算,模型状态从 16Φ→5.5Φ ,如果是现在一般规模的预训练规模,卡数至少上百,优化器状态可以忽略不计,模型状态基本接近。
  2. Stage 1和2 不会额外增加通信量,Stage 3 会额外增加 50%(forward 和 backward 时分别一次 broadcast 参数以获得全量参数),因此后面 Deepspeed ZeRO++ 支持了 stage 3 量化和参数分层存储来降低通信量。
  3. ZeRO 除了分片还支持 offload,显存不够内存来凑,但是内存显存之间的 I/O 成本也不可忽视,因此实际训练中还是很少用。
  4. All-reduce 通信到底怎么充分利用设备和设备之间的带宽也很有趣,请参考袁老师文章 OneFlow:手把手推导Ring All-reduce的数学性质

3 中间激活值显存优化

1-2 中中间激活值式子可以看到,激活值与输入数据的大小成正相关,batch size 较大时远超过模型参数占用。因此主要显存优化是优化中间激活值,有重计算和并行两个思路

(34��ℎ+5��2�)∗�=(13��ℎ+5��2�+21��ℎ)∗�

3-1 重计算

  • activation checkpoint (recompute) :时间换空间,前向的时候重新计算一次来避免存储。计算量的增加参考另一篇博客。
    • 全部重计算可以减少到只有每个 attn 层输入的 2��ℎ�
    • 部分重计算可以减少 �(�2) 项相关的 QK 乘法中间结果,其他不变,减少到 34��ℎ�

3-2 TP 中间激活值

Tensor Parallelism 通过切 attention/mlp 层减少中间值

  • attention (8��ℎ+5��2�)/�
  • mlp 16��ℎ/�
  • dropout/layernorm 6��ℎ (外层的不受影响,但 softmax dropout 也要切 t)
  • attention/mlp input 2��ℎ+2��ℎ (f' 表示需要在 forward/backward 中需要 all reduce因此attn, mlp 输入也是完整的)

显存合计 ��ℎ(10+24�+5��ℎ�)

3-3 SP+TP 中间激活值

Sequence Parallelism 输入沿着 seq 维度切,从而进一步减少两个输入和 layernorm,dropout 的中间激活值

  • attention (8��ℎ+5��2�)/� 不变
  • mlp 16��ℎ/� 不变
  • dropout/layernorm 6��ℎ/� 外层的 sequence parallel 也切 t 份
  • attention/mlp input (2��ℎ+2��ℎ)/� ,外层g, g' 是 all-gather 操作

显存合计 ��ℎ�(34+5��ℎ)

3-4 PP+SP+TP 中间激活值

Pipeline Parallelism 没有减少

  • 和 pp size 无关, 1F1B pp 同时有 L/p 个 microbatch,即便参数只有 L/p 这么多,但是激活状态需要整个 batch 全保留才能backward 时用
  • Megatron 里 interleaving 如果开了需要存 �(1+(�−1)/��) 层的,m 为 interleaving stage

显存合计 ��ℎ��(34+5��ℎ)

3-5 总结

上述优化方案和组合方案优化后的中间激活值如下表

以 LLaMA 和 GPT 预估部分重算情况下模型显存和中间激活值比例

感兴趣也可以根据公式算全部方案下中间激活值节省。以下是博客:

4 推理过程

推理显存没有梯度和优化器,主要是模型参数,一般总显存经验值估算为 1.2 倍参数量

  1. 模型参数 fp16 下推理参数占 2Φ bytes
  2. KV Cache (如有) 缓存 KV Cache 加速方法
  3. 中间结果和输入数据 比较少,一般 20% 内

4-1 KV Cache 显存分析

KV Cache 是典型的推理加速方法,推理时缓存第 n 个 token 及前计算结果,第 n+1 个 token 相当于增量计算从而加速。

  1. 预填充阶段:输入一个 prompt 序列,为每个 transformer 层生成 key cache 和 value cache(KV cache) ()��⋅���(�) ,其中 ��∈[�,���],��∈[ℎ,���],���∈[�,���,ℎ] 。这里是简化后的单头,多头时 ��∈[�,���,�,ℎ/�] 。
  2. 解码阶段:拼接并 concat 更新 KV cache,一个接一个地生成词,当前生成的词依赖于之前已经生成的词。假设输入序列的长度为 s ,输出序列的长度为 n ,最后一个token 推理时长度为 (s+n), KV Cache 占用峰值。

所以每层 2个K/V 各(s+n)bh ,每个 fp16 占 2 个 bytes,KV cache 的峰值显存占用大小为 �(�+�)ℎ∗�∗2∗2=4��ℎ(�+�)

KV Cache 占模型显存比例

4-2 MQA & GQA

面向推理的显存(和速度)的优化主要是 Multi-Query Attention (MQA) 和 Group-Query Attention (GQA),本质上是通过多头共用 KV Cache 减少内存 I/O 时间占总时间比例。已经应用或支持的包括 ChatGLM2、LLaMA2、和 flash attention v2 解决方案。

MHA(n:n) vs GQA(n/t : n) vs MQA(1:n)

这里显存的节省比较简单,如 MQA KV Cache 为原来的 1/n 倍,GQA 为原来的 1/������ 倍。主要是为了加速而不是显存优化提出的方法,推内存时间减少的比较值得一看。

5 参考

[1] Reducing activation recomputation in large transformer models
[2] ZeRO: Memory Optimizations Toward Training Trillion Parameter Models
[3] Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelis
[4] 
[5] 
[6] OneFlow:手把手推导Ring All-reduce的数学性质