Zhangzhe's Blog

The projection of my life.

0%

GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection

URL

TL;DR

  • 本文提出一种梯度低秩映射方法,和 LoRA 思想类似,但支持全参数优化,无需冻结模型也无需额外的训练参数,只需在梯度更新时进行低秩映射,从而在训练时节省显存。
  • 由于优化器状态是大模型微调过程中的主要显存消耗,因此本算法本质是在优化器中记录梯度矩阵的低秩近似,从而减少显存消耗。

Algorithm

总体流程

galore.png

  1. 计算参数的梯度 GRm×nG\in\mathbb R^{m\times n}
  2. 计算梯度的低秩近似(通常用 SVD 奇异值分解法), GUVT,URm×r,VRn×rG\approx UV^T,U\in \mathbb R^{m\times r},V\in \mathbb R^{n\times r}
  3. 记录低秩近似的优化器状态信息,例如:
    1. UUVV
    2. UUVV 的动量信息
    3. UUVV 的二阶动量信息
  4. 将低秩近似的梯度 UVTUV^T 合并变成近似的梯度 GG',并用 GG' 更新参数
  5. 由于梯度具有一定程度的稳定性,为了节省奇异值分解带来的损失,UUVV 不会在每次迭代中都重新计算,而是每隔 TT 次迭代更新一次

用公式表述

  1. 分解梯度 GGUUVV 的乘积
    1. GRm×nG\in \mathbb R^{m\times n}
    2. U,Σ,VT=SVD(G)U,\Sigma,V^T=SVD(G)
    3. URm×m,VRn×n,ΣRmU\in \mathbb R^{m\times m},V\in \mathbb R^{n\times n},\Sigma\in\mathbb R^m
    4. rr 是低秩近似的秩
    5. Ur=U[:,:r],Vr=V[:,:r],UrRm×r,VrRn×rU_r=U[:,:r],V_r=V[:,:r],U_r\in \mathbb R^{m\times r},V_r\in \mathbb R^{n \times r}
    6. Ur=UrΣU_r = U_r \Sigma
    7. G=UrVrTG' = U_r V_r^T
  2. 记录优化器状态信息,以 Adam 优化器为例
    1. mtU=β1mt1U+(1β1)Ur ,  mtV=β1mt1V+(1β1)Vrm_t^U = \beta_1 m_{t-1}^U + (1-\beta_1)U_r\ ,\ \ m_t^V = \beta_1 m_{t-1}^V + (1-\beta_1)V_r
    2. vtU=β2vt1U+(1β2)Ur2 ,  vtV=β2vt1V+(1β2)Vr2v_t^U = \beta_2 v_{t-1}^U + (1-\beta_2)U_r^2\ ,\ \ v_t^V = \beta_2 v_{t-1}^V + (1-\beta_2)V_r^2
    3. mtU=mtU/(1β1t) ,  mtV=mtV/(1β1t)m_t^U = m_t^U/(1-\beta_1^t)\ ,\ \ m_t^V = m_t^V/(1-\beta_1^t)
    4. vtU=vtU/(1β2t) ,  vtV=vtV/(1β2t)v_t^U = v_t^U/(1-\beta_2^t)\ ,\ \ v_t^V = v_t^V/(1-\beta_2^t)
  3. 计算低秩近似梯度并更新参数
    1. G=U×mtVvtV+ϵ+V×mtUvtU+ϵG'=U\times \frac{m_t^V}{\sqrt{v_t^V}+\epsilon} + V\times \frac{m_t^U}{\sqrt{v_t^U}+\epsilon}
    2. ΘΘηG\Theta \leftarrow \Theta - \eta G'
    3. Θ\Theta 是模型参数,η\eta 是学习率
  4. 更新分解矩阵 UUVV
    1. 每隔 TT 次迭代更新一次,也就是说 GGUUVV 可能不是同一个时间步的,GG 是当前时间步的梯度,而 UUVV 是之前更新的
    2. U,Σ,VT=SVD(G)U,\Sigma,V^T=SVD(G)
    3. Ur=U[:,:r],Vr=V[:,:r],UrRm×r,VrRn×rU_r=U[:,:r],V_r=V[:,:r],U_r\in \mathbb R^{m\times r},V_r\in \mathbb R^{n \times r}
  5. 正交化和秩调整(可选)
    1. 为了保持 UUVV 各自的正交性,可以对 UUVV 分别进行正交化处理
    2. UTU=I,VTV=IU^TU=I,V^TV=I
    3. 动态调整秩 rr 以适应训练过程中的梯度变化

Thought

  • 听上去很像 LoRA,终究是效果和显存占用的平衡问题。
  • 大模型微调能改的东西不多,所以传统机器学习用到的一些算法又可以在大模型炒炒冷饭了…