Zhangzhe's Blog

The projection of my life.

0%

GPipe: Easy Scaling with Micro-Batch Pipeline Parallelism

URL

TL;DR

  • 随着模型和数据量的增长,单 GPU 已经无法满足需求,所以需要多 GPU 并行
  • Data parallelism 是最常用的,即将数据划分到多个 GPU 上,每个 GPU 单独前向反向传播,仅在梯度更新时聚合,例如 pytorch ddp
  • 本论文提出一种新的并行方式 Pipeline parallelism,这种方式是将模型划分为多段子图,每个设备上加载一段,各个子图直接串联
  • 直接实现的 Pipeline parallelismGPU 利用率较低,本文通过一些方法可大幅提高 GPU 利用率

Algorithm

gpipe_1.png
Gpipe 方案主要包含三个部分

划分 Stage

  1. 模型切成若干 stage,每个 GPU 只加载一段
  2. 每个 stage 串起来形成一个 pipeline

划分 Micro-Batch

  1. 传统模型训练过程中使用 Mini-BatchMini-Batchpipeline parallelism 中会出现大量气泡,因此 Gpipe 提出 Micro-Batch 概念
  2. Micro-Batch 是将 Mini-Batch 再划分成多份,用于排计算流水,Mini-Batch 变成最小计算单元,有利于减小气泡,上图 2.c 中横向就是一个 Mini-Batch 划分得到的一组 Micro-Batch
  3. 为了保证梯度更新和 Mini-Batch 的一致性,Gpipe 会将 Micro-Batch 梯度累积,在一个 Mini-Batch 的全部 Micro-Batch 计算结束后再更新(图 2.c 中一行只有一次 update

重计算

  1. 传统模型训练过程中,计算的中间结果都需要保存,用于反向传播计算梯度,非常消耗显存,同时会导致 pipeline parallelism 的数据依赖问题,上图 2.a 中的横向连接
  2. 重计算是一种用计算换空间的操作,在前向传播过程中,丢弃 stage 内所有中间结果,只保留 stage 间中间结果;在反向传播时,stage 内中间结果会再做一次前向传播计算得到
  3. 重计算可大幅降低显存占用,使得可以放更大的 batch size 和更大的模型,且可以改善数据依赖问题

Thought

  • 这种方案前瞻性比较强,提出时还没有 LLM,且拥有和传统非并行训练在数学上一致的参数更新机制
  • 给并行训练提供了一种新思路,非常灵活