Zhangzhe's Blog

The projection of my life.

0%

URL

TL;DR

  • DINO 一样,都是做自监督视觉预训练的,对于 DINO 的主要升级是构建了一个大规模自动化数据处理管线,构建了 LVD-142M 高质量数据集
  • 用这些数据集预训练了一个 1B 参数的 ViT 模型,通过无监督蒸馏的方式,得到用于不同任务的小模型

Algorithm

自动化数据处理管线

dino_v2.png

  • 自动化处理流程包括如下几个步骤,不断重复迭代

1. 数据收集

  • DINOv2 的数据源包括一个大型的未筛选图像数据集和一个较小的经过筛选的图像数据集
  • 未筛选数据集来自网络爬取
  • 筛选数据集来自 ImageNet-22k / Google Landmarks

2. 图像嵌入

  • 对于未筛选数据集,用一个训练好的 ViT-H/16 计算得到图像 embedding vector

3. 图像去重

  • 用特征空间下去重算法,将未筛选的数据集去重

4. 图像检索

  • 在特征空间下聚类,得到与筛选数据集类似的未筛选数据样本

5. 数据增强

  • 让这些聚类得到的类似的未筛选样本作为筛选样本,不断扩大筛选样本的数量和场景丰富度

模型架构

  • DINO v1 中教师和学生使用动量更新的方式不同,DINO v2 使用了常见的 “大老师,小学生” 架构
  • 先训练一个 1B 参数的 ViT 模型作为老师模型
  • 然后再在各个不同任务数据上蒸馏得到小模型

训练策略优化

  • 由于老师模型很大(1B 参数量),所以需要 LM 常用的训练加速方法,包括:
    • FlashAttention
    • Fully-shared Data Parallel (FSDP)

Thought

  • 这套数据处理管线是本文重点,所有的自监督任务,自动化数据处理流程都是必不可少的

URL

TL;DR

  • CLIPOpenAI 提出的一种图文多模态对齐算法,在收集到的 4 亿对图片文本数据对上,将文本和图像编码在同一表达空间下,实现了图文模态的对齐
  • 可以 zero-shot 迁移到其他计算机视觉任务上

Algorithm

CLIP.png

训练时

  1. N 对图片和文本各自编码
  2. 计算得到不同模态之间两两编码的 余弦相似度 RN×N\in \mathbb{R}^{N\times N}
  3. 使用对比学习的方式,提高 N 个正样本的相似度,降低剩余的 N2NN^2-N 个样本的相似度

推理时(以 ImageNet 分类任务为例)

  1. ImageNet-1k 的所有 1000 种类别标签,通过训练好的文本编码器,转换到特征空间中
  2. 将需要分类的图片,通过训练好的图片编码器,转换到特征空间中
  3. 图像编码找到余弦相似度最高的文本编码,对应的类别就是图片类别

模型选型

  • 图像编码器:
    • Vision Transformer (ViT)
    • ResNet-50
  • 文本编码器:Transformer
    • 63M 参数
    • 12
    • 512
    • 49152 词表大小
    • BPE 文本编码方式

Thought

  • 简洁高效,像 OpenAI 固有的风格
  • 有没有可能在 GPT-4 的多模态中用到呢?

URL

TL;DR

  • DINO(Distillation with No Labels) 是一种自监督学习方法,主要用于 Vision Transformer (ViT) 的训练
  • 在无标签的图片数据上训练模型,让模型学习图像的表示意义
  • 利用 MoCo 提出的 Momentum Teacher 算法做蒸馏

Algorithm

dino_1.png

训练流程

  1. 创建两个完全一样的网络,命名为教师 teacher 网络和学生 student 网络
  2. 对同一个输入 x,进行不同的数据增强,得到 x1x2
  3. 交叉计算对比损失,再求均值得到 loss for student
  4. 只对 student 网络进行反向传播和梯度更新
  5. 基于 student 网络的参数更新 teacher 的参数,更新方式是 EMA (exponential moving average),即:θt=λθt+(1λ)θs\theta_t=\lambda \theta_t+(1-\lambda)\theta_s
  6. 更新 teacher 网络输出的中心点:C=mC+(1m)mean(t1,t2)C = m*C + (1 - m)*mean(t1, t2)

中心化和锐化

dino.png

  • 两种操作本质上是互补的,防止模型训练崩溃

中心化(centering

  • 中心化的目的是防止特征向量的某个维度占主导地位,从而导致模型输出分布过于集中
  • 本质就是一种均值为 0 的归一化,可以提高模型训练的稳定性

锐化(Sharpening

  • 锐化操作的目的是增加教师网络输出的概率分布的锐度,使得输出的概率更加集中在少数几个维度上
  • 实现上,锐化通过修改蒸馏温度系数实现

模型效果

  • 比一众视觉自监督模型效果都好,比如:MoCo v1/v2SimCLR v1/v2

Thought

  • 感觉是 MoCo 系列的升级,框架本身不变,加了数据,稳定了训练过程,增加了些许 trick

URL

TL;DR

  • 本文提出一种提示词微调的方法,是对 P-Tuning 的升级
  • P-Tuning v2 要解决的问题是:对于所有(或大多数)任务类型和模型参数规模,将提示词微调的精度达到和整体参数微调同样的效果,这也是这篇论文的题目

Algorithm

p_tuning_v2.png

P-Tuning 的优化

1. P-Tuning 只在模型输入层添加可学习的连续嵌入,P-Tuning v2 在模型的每一层都添加

  1. Prefix Tuning / Prompt Tuning / P-Tuning 三种方法都是在模型输入中加入连续嵌入
    • 添加方式可能是前缀,也可能是其他 Concat Pattern
    • 通过 Self-Attentionembedding 间信息融合机制让虚拟连续嵌入影响整个模型
  2. P-Tuning V2 的做法则完全不同,是对模型的 每一层 都添加了可学习的虚拟连续嵌入
    • 具体来说是通过初始化虚拟 past_key_values 来实现的
    • GPT2-small 来举例(12transformer,每层 12 个头,每个头的 dim = 64
      • 假设 virtual_prompt_seq_len=3input_prompt_seq_len=10
      • 那么需要先初始化 past_key_valuesshape = [12, 2, 12, 3, 64],分别表示:
        • num_layers
        • key and value
        • num_heads
        • virtual_prompt_seq_len
        • dim
        • shape = [12, 2, 12, 3, 64] 以及修改后的输出层参数是可训练的所有参数
    • 然后将序列长度为 10input token embeddings 输入模型,第一层输出长度还是 10
    • 第二层以及之后的每一层都将上一层输出的长度为 10 的序列和长度为 3virtual_prompt_key_values 合并计算,并输出长度为 10 的序列

2. 不再使用 Verbalizer,而是使用 class head

  1. 什么是 Verbalizer ?
    • 传统的预训练语言模型(例如 Bert)的输入是一个 token 序列,输出是一个 token,也就是说词表中每个词都有可能输出
    • 现在有个下游任务需要用 Bert 做情感分类,输入是一段话,输出是:{正面,负面,中性} 中的一种,而且用 P-Tuning 方法微调,那么直接把输入附加上一些虚拟连续提示嵌入,输出的结果还是整个词表,不是三分类
    • 这时候就需要 Verbalizer 的存在了,它的作用是将 Bert 模型的输出从词表空间映射到三分类空间,它的实现形式可以是规则,也可以是深度学习模型
  2. P-Tuning V2 如何抛弃 Verbalizer?
    • 抛弃 Verbalizer 的方式很简单,就是打破 Prompt Tuning 模型时不应修改模型参数和结构 的限制
    • 直接删除预训练模型输出层,改成任务相关的层并随机初始化,然后微调

Thought

  • 看起来比 P-Tuning v2 更优雅,和 kv cache attention 结合起来,推理耗时增加较小
  • 据说对大模型来讲,这种方法和 Prompt Tuning 相比并没有显著精度优势(模型参数量小时,设计很重要;模型参数量大时,参数量几乎可以弥补一切设计上的非最优)

URL

TL;DR

  • 本文提出一种 Prompt Tuning 的方法 P-Tuning
  • Prefix TuningPrompt Tuning 这种连续词嵌入作为前缀的方法不同, P-Tuning 把连续词嵌入分段插入到 输入标签 之间

Algorithm

提出问题

  • 模型对人工设计的 Prompt 很敏感(指预训练模型,非大模型),同一个模型同一个数据集,只要稍微改变问题的问法,评测指标就差非常多,如图:
    p-tuning_1.png
  • P-Tuning 可解决此类问题

解决问题

  • 使用连续词嵌入(可训练)和离散词嵌入(不可训练)相结合的方法,做 Prompt Tuning 微调
    p-tuning_2.png
  • 上图左侧是传统全部用离散词嵌入 Prompt 过程
  • 上图右侧是离散词嵌入和连续词嵌入相结合的方法,其中 capticalBritain 两个问题中最关键的词使用离散词嵌入(来自于词表,固定不可训练),并在离散词嵌入周围插入若干连续词嵌入(可通过反向传播梯度下降训练)

数学表述

  • P-Tuning 中输入序列为 T={[P0:i],x,[P(i+1):j],y,[P(j+1),k]}T=\{[P_{0:i}],x,[P_{(i+1):j}],y,[P_{(j+1),k}]\},其中:
    • x 表示原始输入的离散词文本(还没有变成词向量)
    • y 表示原始的 label 文本
    • [P][P] 表示连续词向量
  • 输入序列 T 需要通过一种特殊的 Prompt Encoder 变成真实的词嵌入输入 {h0,...,hi,e(x),hi+1,...,hj,e(y),hj+1,...,hk}\{h_0,...,h_i,e(x),h_{i+1},...,h_j,e(y),h_{j+1},...,h_k\},其中:
    • e(x),e(y)e(x),e(y) 是通过查词表得到的离散词嵌入
    • hh 是通过 MLP/LSTM 等方法得到的连续向量的词嵌入,向量的长度和离散词嵌入一致

Thought

  • Prefix Tuning 插入连续词嵌入的自由度更高,因此理应效果更好,但总感觉解决问题的方法不优雅,因为离散和连续嵌入结合的模板是人为规定的,包含了较多先验知识在里面

URL

TL;DR

  • 本文提出的 prompt tuningprefix tuning 非常相似,是一种通过给不同任务输入前添加不同前缀,同时冻结原预训练模型参数的微调方式
  • prefix tuning 区别主要在前缀词向量的设置和初始化方式方面

Algorithm

prompt_tuning.png

Prompt tuning 的前缀词向量长度应该设置多少?

  • 作者实验了 {1, 5, 20, 100, 150} 等长度的前缀长度,结论是 20 最合适,超过 20 收益可忽略

Prompt tuning 的前缀初始化方式

  • 作者实验了三种前缀初始化方式:
    1. 随机初始化(和 prefix tuning 一致)
    2. 从词表中随机选择常见词初始化
    3. 用自然语言描述任务,并将其根据词表转化为词向量
  • 实验结论是:第三种方式最优

其他部分

  • 我没看出来和 prefix tuning 有任何不同,甚至本文对 prefix tuning 的理解都是错的

Thought

  • 我认为这篇论文在 prefix tuning 的基础上改动较小,比较水

URL

TL;DR

  • LLM 作为基础预训练模型,由于其参数量极大(以 B 计)的特点,无法像传统视觉或语言预训练模型一样,在下游任务数据集上全局微调
  • 本文提出一种非常新颖的 LLM 微调方法,不对原始模型参数进行改进,而是在每个任务的 Prompt 输入之前加入不同长度的虚拟 prefix token seqence 来实现
  • 其中 prefix token seqence 实际上是一个 Rlen×dims\mathbb{R}^{len\times dims} 的可学习矩阵
    • 其中,lenlen 表示 prefix token seqence 的长度
    • dimsdims 表示原始模型的词嵌入的长度
    • 虚拟 token 的含义是不是真正来自词表的 token,而是无意义的可学习的连续的任务相关的 token 表征
    • 不同任务的可学习矩阵不共享

Algorithm

传统 fine tuningprefix tuning 的对比

prefix1.png

  • 传统 fine tuning 更新全部模型参数
  • prefix tuning 不更新任何模型参数,只是额外加入了部分参数(虚拟 token 表征参数)

任务相关应该怎么理解?

  • 任务相关是指,每一个任务都需要独立的无法共享的 prefix token 表征,而且表征的长度可能不同
  • 例如,用预训练的 GPT-2 small + prefix tuning,使其可以做内容总结任务,就需要初始化一个 R200×768\mathbb{R}^{200\times 768} 的矩阵,然后通过有监督训练优化这个虚拟 embedding 层相关的参数
    • 其中:200 是通过实验发现的最合适的总结任务 prefix token sequence 长度
    • 768GPT-2 small 模型真实 token embedding 的维度
  • 如果想用预训练的 GPT-2 small + prefix tuning,使其可以做表格转文本的任务,就需要重新初始化虚拟 token embedding 矩阵,重新训练
    prefix3.png

实验证明,对于总结任务,200 是个合适的 prefix token sequence 长度,对于表格转文本,10 长度更合适

prefix tuning 如何适配不同架构的预训练模型?

prefix2.png

  • 对于 GPT 系列的 Casual Decoder 架构,只需要在输入文本 token sequence 之前加入 prefix token sequence
  • 对于 Encoder Decoder 架构,需要在 encoderdecoder 输入之前都加入 prefixencoderdecoderprefix token 表征可以共享,也可以不共享,具体看任务效果

fine tuning 相比,效果如何?

  • 在数据充足且任务难度不大的情况下,prefix tuning 效果不差于 fine tuning
  • 由于 prefix 没有修改模型原始参数,所以不容易出现 full fine tuning 中常出现的模型退化问题

Thought

  • prefix tuning 看起来是非常有希望的 topicprompt engineering 可挖掘的空间还很大
  • 这种微调方式实际上还需要存储大模型的所有参数和推理的中间结果用于梯度回传,所以实际微调代价不如想象中小

URL

TL;DR

  • 本文提出一种强化学习算法 PPOProximal Policy Optimization,近端策略优化),旨在解决传统策略梯度方法中存在的数据效率低、鲁棒性差等问题。
  • PPOTRPOTrust Region Policy Optimization,信任域策略优化)进行了简化。

Algorithm

1. 算法核心思想

策略梯度方法

  • PPO 基于策略梯度方法,通过估计策略的梯度并使用随机梯度上升来优化策略。

裁剪概率比

  • PPO 引入了裁剪概率比(clipped probability ratio)的概念,通过对概率比进行裁剪,限制策略更新的幅度,从而避免过大的策略更新导致的性能下降。

近端目标函数

  • PPO 使用近端目标函数(surrogate objective function),它是一个关于策略参数的函数,用于在策略更新时进行优化。
  • 这个目标函数考虑了策略更新前后的概率比和优势函数(advantage function)的乘积,并通过裁剪概率比来调整这个乘积,使得策略更新更加稳健。

多轮更新

  • 与传统的策略梯度方法每次只使用一个数据样本进行一次梯度更新不同,PPO 允许对每个数据样本进行多轮(epochs)的梯度更新,提高了数据的使用效率。

简化实现

  • 相比于 TRPO 等算法,PPO 的实现更加简单直观,易于与其他算法和框架集成。

2. 伪代码实现

PPO_1.png

  1. 在每次迭代中,多个 actor 并行地在环境中运行旧策略 πθold\pi_{\theta_{old}},收集数据并计算优势估计。
  2. 然后,使用这些数据来优化近端目标函数,通常使用 K epochs(重复使用收集到的数据更新 K 次) 和 SGDAdam 优化器来进行优化。
  3. 优化完成后,更新策略参数 θ\theta,并在下一次迭代中使用新的策略。

3. 公式分析

  • PPO 的目标函数公式分为三部分:
    • 裁剪的目标函数(CLIP),在时间步 t 的裁剪目标函数,用于优化策略参数 θ\theta
    • 价值函数损失(VF, value function
    • 策略熵(S, entropy bouns),用于鼓励探索。
      PPO_2.png

裁剪概率比的优势

  • LtCLIP+VF+S(θ)=E^t[LtCLIP(θ)c1LtVF(θ)c2S[πθ](st)]L_t^{CLIP+VF+S}(\theta)=\hat E_t[L_t^{CLIP}(\theta)-c_1L_t^{VF}(\theta)-c_2S[\pi_\theta](s_t)]
    • LtCLIP(θ)=E^t[min(rt(θ)A^t, clip(rt(θ),1ϵ,1+ϵ)A^t)]L_t^{CLIP}(\theta)=\hat E_t[\min(r_t(\theta)\hat A_t,\ clip(r_t(\theta), 1-\epsilon,1+\epsilon)\hat A_t)]rt(θ)=πθ(atst)πθold(atst)r_t(\theta)=\frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}
      • r(θ)r(\theta) :概率比率表示新旧策略在状态 s 下采取动作 a 的概率之比
      • πθ(as)\pi_\theta(a|s) :在状态 s 下采取动作 a 的概率,有策略参数 θ\theta 决定的随机策略
      • πθold(as)\pi_{\theta_{old}}(a|s) :在策略更新前,状态 s 下采取动作 a 的概率
      • ϵ\epsilon :超参数,用于控制裁剪的严格程度
      • A^\hat A :优势函数(advantage function)的估计值,它估计了采取行动 a 相比于平均状况能带来多少额外回报
    • LtVF(θ)=(Vθ(st)Vttarget)2L_t^{VF}(\theta)=(V_\theta(s_t)-V_t^{target})^2
      • Vθ(s)V_\theta(s) :策略参数 θ\theta 下,状态 s 的期望回报,V 表示价值函数
      • VtargetV^{target} :折扣回报的估计
      • 价值函数损失有助于确保价值函数的预测尽可能接近真实回报,从而为策略提供准确的价值估计
    • S[πθ](st)=Eaπθ(st)[logπθ(ast)]S[\pi_\theta](s_t)=-\mathbb{E}_{a\sim \pi_\theta(\cdot|s_t)}[\log \pi_\theta(a|s_t)]
      • 表示在状态 s 下,根据策略 πθ\pi_\theta 采取动作的熵
      • 熵是衡量随机性的一个指标,策略的熵越高,表示策略在选择动作时越随机,这有助于算法探索更多的状态-动作空间

Thought

  • 对强化学习了解不够深,无法完全 getPPO 的精髓
  • InstructGPT / ChatGPT 使用了 PPO 强化学习策略,所以做 LLM 需要掌握此技能

URL

TL;DR

  • 大模型预训练后,在很多任务上已有非常强的能力,但输出的结果可能是不可信、有害或无帮助的。
  • 本文提出 InstructGPT 算法和微调流程,旨在让大模型通过人类反馈的微调方式,在大量任务上对齐人类意图。
  • 目标是做到 3H:
    • Helpful
    • Honest
    • Harmless

Algorithm

instructgtp_1.png

  • 具体来说,InstructGPT 微调分为三步:
    • 根据采集的 SFT (Supervised FineTune) 数据集对 GPT-3 进行有监督的微调;
    • 收集人工标注的 对比数据,训练奖励模型 RM (Reword Model)
    • 使用 RM 作为强化学习的优化目标,利用 PPO 算法(Proximal Policy Optimization 最近策略优化)微调 SFT 模型。

数据集

instructgpt_2.png

1. SFT 数据集

  • SFT 数据一部分来自使用 OpenAIPlayGround 的用户,另一部分来自 OpenAI 雇佣的 40 名标注工。
  • 在这个数据集中,标注工的工作是根据内容自己编写指示,并且要求编写的指示满足下面三点:
    • 简单任务:labeler 给出任意一个简单的任务,同时要确保任务的多样性;
    • Few-shot 任务:labeler 给出一个指示,以及该指示的多个查询-响应对;
    • 用户相关的:从接口中获取用例,然后让 labeler 根据这些用例编写指示。

2. RM 数据集

  • RM 数据集是对比数据类型,用于训练一个奖励模型(Reward Model)。
  • InstructGPT/ChatGPT 的做法是先让模型生成一批候选文本,让后通过 labeler 根据生成数据的质量对这些生成内容进行排序。
  • 具体做法是对于每一个 Prompt,模型给出 K 个候选输出(4K94\le K\le9),然后将 K 个输出中挑选 2 个样本给 labelerlabeler 给出哪个更好,一个 Prompt 一共需要 labeler 标注 CK2C_K^2 次,最终得到 K 个候选输出的排序结果。

3. PPO 数据集

  • InstructGPTPPO 数据没有进行标注,它均来自 GPT-3API 用户。
  • 有不同用户提供的不同种类的生成任务,其中占比最高的包括生成任务(45.6%),QA(12.4%),头脑风暴(11.2%),对话(8.4%)等。

微调过程

1. SFT 有监督微调过程

  • GPT-3 的预训练过程保持一致,只是监督信号从监督一段文本中下一个 token 变成监督模型对 prompt 输出和 ground truth 之间的差异。

2. Reward Model 训练过程

  • Reward Model 的输入是 Prompt + Response,输出是对应的奖励值,是个回归模型,其中 ResponseRM 数据集中的排序的模型候选输出。
  • Reward Model 实际上是原始 LLM 模型将最后一层的分类层改成回归层,用来回归奖励值。
  • 损失函数是:loss(θ)=1CK2E(x,yw,yl)D[log(σ(rθ(x,yw),rθ(x,yl)))]loss(\theta)=-\frac{1}{C_K^2}E_{(x,y_w,y_l) \sim D}[\log(\sigma(r_\theta(x,y_w), r_\theta(x,y_l)))],使用对比学习的方式训练。
  • 其中:
    • rθr_\theta 表示参数为 θ\thetaReward Model
    • x 表示 Prompt
    • yw ,yly_w\ ,y_l 分别表示对 Prompt xlabeler 喜欢和不喜欢的一对 Response

3. 使用近端策略优化强化学习算法(PPO) + Reward Model 继续优化 SFT 模型

InstructGPT/ChatGPT 的分析和计划

1. 收益

  1. 效果比 GPT-3 更加真实
  2. 在模型的无害性上比GPT-3 效果要有些许提升
  3. 具有很强的 Coding 能力

2. 损失

  1. 会降低模型在通用 NLP 任务上的效果
  2. 有时候会给出一些荒谬的输出
  3. 模型对指示非常敏感
  4. 模型对简单概念的过分解读
  5. 对有害的指示可能会输出有害的答复

3. 未来计划

  1. 人工标注的降本增效:将人类表现和模型表现有机和巧妙的结合起来是非常重要的
  2. 模型对指示的泛化/纠错等能力:这不仅可以让模型能够拥有更广泛的应用场景,还可以让模型变得更“智能”
  3. 避免通用任务性能下降:需要方案来让生成结果的 3H 和通用 NLP 任务的性能达到平衡

Thought

  • ChatGPT 爆发之前,Open-AI 多年来一直在默默研究这个领域,包括 GPT 系列、RLHFPPO 等,这些东西的有机结合就成了 InstructGPT,放大就变成了 ChatGPTInstructGPT 是基于 GPT-3 对齐的,ChatGPT 是基于 GPT-3.5 对齐的)
  • 想在一个领域做到断崖式的领先,小聪明是没用的,需要长期积累 + 修炼内功

URL

TL;DR

  • 传统的 DDP 训练大模型有个很大的问题是:由于 DDP 训练中每个 GPU 上都需要存储模型状态信息和其他状态信息:
    • 模型状态信息包括:Prameter + Gradient + Optimizer states
    • 其他状态信息包括:激活值(Activation)、各种临时缓冲区(buffer)以及无法使用的显存碎片(fragmentation
      所以显存占用较大,使得不能通过 DDP 训练较大的模型。
  • 本文提出来 Zero-DPZero-R 两种大模型训练优化策略,分别解决模型状态信息过大和其他状态信息过大的问题。
  • Zero-DP 全称是 Zero Redundancy Optimizer Data Parallelism,主要优化点是将 GPU 间冗余存储的模型状态信息(Parmeter + Gradient + Optimizer states)删除,每个 GPU 只保留一小部分信息,所有 GPU 上的信息汇聚之后才是完整模型。
  • Zero-R 的目的是解决 Tensor Parallelism 训练时,Activation 信息占用过大的问题。

Algorithm

ZeRO-DP

ZeRO-DP 想要解决什么问题?

  • 解决 DDP 训练过程中,Parmeter + Gradient + Optimizer states 对显存的占用过大,导致无法训练很大的模型的问题。

没有其他方法可以降低显存占用吗?

  • 有的,比如混合精度训练 (https://arxiv.org/pdf/1710.03740) 就是一个常用的有效的手段
    mixed_precision_1.png
  • 已知 LLM 通常使用 Adam / AdamW 优化器进行优化,假设模型参数量为 Φ\Phi,那么用混合精度训练过程中,模型状态信息实际的显存占用是:
    • Parameter 占用 2Φ2\Phi 字节(fp16 类型存储)
    • Gradient 占用 2Φ2\Phi 字节(fp16 类型存储)
    • Adam Optimizer state 占用:
      • fp32 parameter 备份占用 4Φ4\Phi 字节
      • fp32 momentum 一阶动量占用 4Φ4\Phi 字节
      • fp32 variance 二阶动量占用 4Φ4\Phi 字节
    • 共计 16Φ16\Phi 字节,其中 Optimizer states 占用 75%,因此 Optimizer states 是本文主要想解决的问题

ZeRO-DP 是如何解决模型状态信息占用过大问题的?

ZeRO-DP_1.png

  • ZeRO-DP 的核心思想是 Data Parallelism 时,每个模型只保留 1N\frac{1}{N} 的模型状态信息,N 表示 GPU 数量,在需要信息聚合时使用 Ring Reduce 来聚合。
  • 有三种程度的优化:
    • PosP_{os} 只删除冗余的优化器状态,所有优化器的信息加起来才是完整优化器状态。
    • Pos+gP_{os+g} 删除冗余的优化器状态和梯度。
    • Pos+g+pP_{os+g+p} 删除冗余优化器状态、梯度和权重。

分块存储 + Ring Reduce 会带来额外的通信代价吗?

  • All Reduce 运算来计算设备间梯度均值(求和)
    ZeRO-DP_2.png
  • Reduce Scatter 来求和,All Gatherbroadcast 同步到所有设备上
    ZeRO-DP_3.png
  • 使用 Ring Reduce 实现 Reduce Scatter
    ZeRO-DP_4.png

颜色逐渐累积代表梯度累加

  • 使用 Ring Reduce 实现 All Gatherbroadcast
    ZeRO-DP_5.png

深色逐渐蔓延到所有设备,代表 broadcast

  • PosP_{os}Pos+gP_{os+g} 不会带来额外的通信代价,Pos+g+pP_{os+g+p}
  • 证明:
    • 传统 Data Parallelism 在计算梯度后,需要进行一次 All Reduce 来计算设备间的梯度均值,All Reduce 可分为 Reduce ScatterAll Gather 两步,每一步实际上都是一次 Ring Reduce
      • Reduce Scatter 如图三所示,把所有设备梯度累加,分散到各个设备上
      • All Gather 如图四所示,把分散到各个设备上的梯度累加值同步到所有设备上
      • Ring Reduce 是一种理论最优通信算法,每个周期内每个设备的发送和接收都被占用,因此梯度总字节数为 2Φ2\Phi (fp16),由于做了两次 Ring Reduce,所以每个设备的通信总字节数为 4Φ4\Phi(发送和接收算一次)
    • ZeRO-DP Pos+gP_{os+g} 也需要通过 All Reduce 计算设备间梯度均值,也需要 Reduce ScatterAll Gather 两个步骤,也是通过两次 Ring Reduce 实现,具体做法和传统 Data Parallelism 没有什么区别,因此每个设备的通信总字节数也是 4Φ4\Phi
    • ZeRO-DP Pos+g+pP_{os+g+p} 在前向传播时,也需要用 All Reduce 计算得到全量 Parameter,这一步在传统 Data Parallelism 中并不需要,因此会有额外通信量

ZeRO-R

ZeRO-R 想要解决哪些问题?

  1. Tensor ParallelismActivation 冗余存储问题
  2. 临时缓冲区占用问题
  3. 显存碎片导致无法使用的问题

ZeRO-R 如何解决 Tensor ParallelismActivation 冗余存储问题?

  • 是通过激活值检查点分片存储(Partitioned Activation Checkpointing,简称 PAP_A)来解决 Tensor ParallelismActivation 冗余存储问题。
  • 其中的 分片 思想和 ZeRO-DP 是一样的,N 个设备每个设备只存储 1N\frac{1}{N} 的信息,通过 All Reduce 进行信息聚合
  • 激活值检查点 思想是来自 https://arxiv.org/pdf/1604.06174 这篇论文,实际上是传统前向反向传播过程和 ReCompute 前向反向传播过程的折衷,也是空间和时间的折衷。
  • 传统前向反向传播过程:存储每个中间算子的前向传播计算结果,用于反向传播计算梯度
    ZeRO-R_1.webp
  • ReCompute 前向反向传播过程:不存储任何中间算子的前向计算结果,反向传播需要用时重新计算
    ZeRO-R_2.webp
  • Activation Checkpointing 前向反向传播过程:存储部分中间算子前向计算结果,反向传播需要用时,从拓扑上游最近的检查点开始重新计算
    ZeRO-R_3.webp

ZeRO-R 如何解决临时缓存区空间占用问题?

  • 通常情况下,临时缓存区是用来做数据聚合的,例如 All Reduce 操作,模型越大,需要的临时缓存区也越大。
  • ZeRO-R 用了一个非常简单的策略:在模型大到一定程度后,使用固定尺寸的融合缓存区,即常数尺寸缓存区(Constant Size Buffers,简称 CBC_B

ZeRO-R 如何解决显存碎片化问题?

  • 问:为什么会产生显存碎片化问题?
  • 答:因为 Activation Checkpointing 的存在,不同的 Activation 有不同的生命周期,非 checkpointActivation 用完即弃,checkpointActivation 需要反向传播之后才可以丢弃,不同生命周期的存储空间交织就会出现显存碎片
  • 问:显存碎片化会带来什么问题?
  • 答:显存碎片化通常会有两个问题:
    1. 大量的碎片无法使用,使得实际空间占用不多,但出现 OOM
    2. 大量存储碎片会导致系统在分配内存时需要大量的搜索来找到合适的空间,会带来额外耗时
  • 问:那 ZeRO-R 如何解决显存碎片化问题?
  • 答:预先给不同生命周期的 Activation 划分不同的存储空间,避免产生交织。例如:把显存划分成两段连续的存储空间,其中一段存储 Checkpoint Activation 另外一段存储 Temporary Activation

Thought

  • 当大模型时代到来,算力和存储不够用了,大家也开始审视之前做的东西是不是时间/空间低效的,是不是还能榨点油水。