URL
TL;DR
- 传统的
DDP
训练大模型有个很大的问题是:由于DDP
训练中每个GPU
上都需要存储模型状态信息和其他状态信息:- 模型状态信息包括:
Prameter
+Gradient
+Optimizer states
- 其他状态信息包括:激活值(
Activation
)、各种临时缓冲区(buffer
)以及无法使用的显存碎片(fragmentation
)
所以显存占用较大,使得不能通过DDP
训练较大的模型。
- 模型状态信息包括:
- 本文提出来
Zero-DP
和Zero-R
两种大模型训练优化策略,分别解决模型状态信息过大和其他状态信息过大的问题。 Zero-DP
全称是Zero Redundancy Optimizer Data Parallelism
,主要优化点是将GPU
间冗余存储的模型状态信息(Parmeter + Gradient + Optimizer states
)删除,每个GPU
只保留一小部分信息,所有GPU
上的信息汇聚之后才是完整模型。Zero-R
的目的是解决Tensor Parallelism
训练时,Activation
信息占用过大的问题。
Algorithm
ZeRO-DP
ZeRO-DP
想要解决什么问题?
- 解决
DDP
训练过程中,Parmeter + Gradient + Optimizer states
对显存的占用过大,导致无法训练很大的模型的问题。
没有其他方法可以降低显存占用吗?
- 有的,比如混合精度训练 (https://arxiv.org/pdf/1710.03740) 就是一个常用的有效的手段
- 已知
LLM
通常使用Adam / AdamW
优化器进行优化,假设模型参数量为 ,那么用混合精度训练过程中,模型状态信息实际的显存占用是:Parameter
占用 字节(fp16
类型存储)Gradient
占用 字节(fp16
类型存储)Adam Optimizer state
占用:fp32 parameter
备份占用 字节fp32 momentum
一阶动量占用 字节fp32 variance
二阶动量占用 字节
- 共计 字节,其中
Optimizer states
占用75%
,因此Optimizer states
是本文主要想解决的问题
ZeRO-DP
是如何解决模型状态信息占用过大问题的?
ZeRO-DP
的核心思想是Data Parallelism
时,每个模型只保留 的模型状态信息,N
表示GPU
数量,在需要信息聚合时使用Ring Reduce
来聚合。- 有三种程度的优化:
- 只删除冗余的优化器状态,所有优化器的信息加起来才是完整优化器状态。
- 删除冗余的优化器状态和梯度。
- 删除冗余优化器状态、梯度和权重。
分块存储 + Ring Reduce
会带来额外的通信代价吗?
All Reduce
运算来计算设备间梯度均值(求和)
Reduce Scatter
来求和,All Gather
做broadcast
同步到所有设备上
- 使用
Ring Reduce
实现Reduce Scatter
颜色逐渐累积代表梯度累加
- 使用
Ring Reduce
实现All Gather
做broadcast
深色逐渐蔓延到所有设备,代表
broadcast
- 和 不会带来额外的通信代价, 会
- 证明:
- 传统
Data Parallelism
在计算梯度后,需要进行一次All Reduce
来计算设备间的梯度均值,All Reduce
可分为Reduce Scatter
和All Gather
两步,每一步实际上都是一次Ring Reduce
Reduce Scatter
如图三所示,把所有设备梯度累加,分散到各个设备上All Gather
如图四所示,把分散到各个设备上的梯度累加值同步到所有设备上Ring Reduce
是一种理论最优通信算法,每个周期内每个设备的发送和接收都被占用,因此梯度总字节数为 (fp16
),由于做了两次Ring Reduce
,所以每个设备的通信总字节数为 (发送和接收算一次)
ZeRO-DP
也需要通过All Reduce
计算设备间梯度均值,也需要Reduce Scatter
和All Gather
两个步骤,也是通过两次Ring Reduce
实现,具体做法和传统Data Parallelism
没有什么区别,因此每个设备的通信总字节数也是ZeRO-DP
在前向传播时,也需要用All Reduce
计算得到全量Parameter
,这一步在传统Data Parallelism
中并不需要,因此会有额外通信量
- 传统
ZeRO-R
ZeRO-R
想要解决哪些问题?
Tensor Parallelism
中Activation
冗余存储问题- 临时缓冲区占用问题
- 显存碎片导致无法使用的问题
ZeRO-R
如何解决 Tensor Parallelism
中 Activation
冗余存储问题?
- 是通过激活值检查点分片存储(
Partitioned Activation Checkpointing
,简称 )来解决Tensor Parallelism
中Activation
冗余存储问题。 - 其中的 分片 思想和
ZeRO-DP
是一样的,N
个设备每个设备只存储 的信息,通过All Reduce
进行信息聚合 - 激活值检查点 思想是来自 https://arxiv.org/pdf/1604.06174 这篇论文,实际上是传统前向反向传播过程和
ReCompute
前向反向传播过程的折衷,也是空间和时间的折衷。 - 传统前向反向传播过程:存储每个中间算子的前向传播计算结果,用于反向传播计算梯度
ReCompute
前向反向传播过程:不存储任何中间算子的前向计算结果,反向传播需要用时重新计算
Activation Checkpointing
前向反向传播过程:存储部分中间算子前向计算结果,反向传播需要用时,从拓扑上游最近的检查点开始重新计算
ZeRO-R
如何解决临时缓存区空间占用问题?
- 通常情况下,临时缓存区是用来做数据聚合的,例如
All Reduce
操作,模型越大,需要的临时缓存区也越大。 ZeRO-R
用了一个非常简单的策略:在模型大到一定程度后,使用固定尺寸的融合缓存区,即常数尺寸缓存区(Constant Size Buffers
,简称 )
ZeRO-R
如何解决显存碎片化问题?
- 问:为什么会产生显存碎片化问题?
- 答:因为
Activation Checkpointing
的存在,不同的Activation
有不同的生命周期,非checkpoint
的Activation
用完即弃,checkpoint
的Activation
需要反向传播之后才可以丢弃,不同生命周期的存储空间交织就会出现显存碎片 - 问:显存碎片化会带来什么问题?
- 答:显存碎片化通常会有两个问题:
- 大量的碎片无法使用,使得实际空间占用不多,但出现
OOM
- 大量存储碎片会导致系统在分配内存时需要大量的搜索来找到合适的空间,会带来额外耗时
- 大量的碎片无法使用,使得实际空间占用不多,但出现
- 问:那
ZeRO-R
如何解决显存碎片化问题? - 答:预先给不同生命周期的
Activation
划分不同的存储空间,避免产生交织。例如:把显存划分成两段连续的存储空间,其中一段存储Checkpoint Activation
另外一段存储Temporary Activation
。
Thought
- 当大模型时代到来,算力和存储不够用了,大家也开始审视之前做的东西是不是时间/空间低效的,是不是还能榨点油水。