Zhangzhe's Blog

The projection of my life.

0%

大模型RLHF入门

TL;DR

  • 本文是大模型RLHF的入门教程,目的是用尽可能简单的方式介绍大模型RLHF的基本概念和原理,不涉及过多实现细节。

什么是 RLHF

  • RLHFReinforcement Learning with Human Feedback (人类反馈强化学习) 的缩写,目的是在强化学习任务中引入人类反馈,以提高强化学习模型的性能
  • 大多数分类体系下,SFT 属于 RLHF 中的一个步骤,但本文不讨论 SFT,假设起点模型已经做好了 pre-trainSFT
  • RLHF 的目标是:
    • helpful
    • honst
    • harmless

RLHF 相关术语

  • policy model:策略模型,即待训练的强化学习模型
  • reference model:参考模型(基准模型),通常是冻结参数的原始 policy model,防止模型在 RL 过程中出现退化(可理解为正则化项)
  • reward model:奖励模型,用于评估 policy model 生成的结果的好坏,相当于强化学习的 environment
  • value model:价值模型,通常是 policy model 的一个子模块,用于评估状态的价值(也可以理解为对环境的估计),辅助 policy model 的决策,仅仅在 RLHF 过程中有效,训练结束后会被丢弃

RLHF 的基本实现原理

  • RLHF 的实现包括两个阶段:
    • 训练奖励模型阶段
    • 强化学习阶段

1. 训练奖励模型阶段

  • 数据格式:prompt + response 1 + response 2 + label
    • prompt:问题描述
    • response 1:模型生成的回答 1
    • response 2:模型生成的回答 2
    • label:人类标注的哪个回答更好(0 / 1
  • 模型:
    • 模型结构:通常是 LLM,可以由待 RLHF 的策略模型初始化,也可以重新定义模型结构和随机初始化
    • 输入:prompt + response
    • 输出:reward,是一个标量,表示 response 相对于 prompt 的好坏(logits
    • 损失函数:L=1Ni=1Nlogσ(R(prompt,responsechosen)R(prompt,responsereject))L=-\frac{1}{N}\sum_{i=1}^N \log \sigma(R(prompt, response_{chosen}) - R(prompt, response_{reject}))
      • R(prompt,response)R(prompt, response):奖励模型就当前 promptresponse 的奖励(优秀程度)
      • σ\sigmasigmoid 函数
      • NNbatch size
      • responsechosenresponse_{chosen}:人类标注的更好的回答
      • responserejectresponse_{reject}:人类标注的更差的回答

2. 强化学习阶段

  • RLHF 的强化学习阶段通常使用 PPO 算法
  • RLHF 中使用的 PPO 与传统 PPO 有以下几点不同:
    1. 数据格式不同:
      • 传统 PPO 需要的数据格式为: (st,at,rt,st+1,done)(s_t, a_t, r_t, s_{t+1}, done)
      • RLHF 需要的数据格式为: (st,at,rt)(s_t, a_t, r_t),其中:
        • sts_tprompt
        • ata_tresponse
        • rtr_treward(prompt, response)
    2. reward 的计算方式不同:
      • 传统 PPO 中的 reward 由环境给出
      • RLHF 中的 rewardreward model 给出
    3. policy model 的更新方式不同:
      • 为了防止 policy modelRL 过程中退化,需要在损失函数中加入 policy modelreference modelKL 散度作为正则化项
      • reference model 通常是冻结参数的原始 policy model 或者是一个不同架构的经过 pre-train + SFT 的模型
  • 综上,RLHF 强化学习阶段的损失函数为:
    • LRLHF=LPPO+βDKL(policy,reference)L_{RLHF} = L_{PPO} + \beta \cdot D_{KL}(policy, reference)
      • LPPOL_{PPO}:传统 PPO 的损失函数
      • β\beta:正则化系数
      • KLdiv(policy,reference)KL_{div}(policy, reference)policy modelreference modelKL 散度
      • LPPO=E[min(rt(θ)A^t,clip(rt(θ),1ϵ,1+ϵ)A^t)]+c1MSE(Vt,Vttarg)c2H[πθ(atst)]L_{PPO} = \mathbb{E}[\min(r_t(\theta) \hat{A}_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) \hat{A}_t)] + c_1 \cdot \text{MSE}(V_t, V_t^{targ}) - c_2 \cdot H[\pi_{\theta}(a_t|s_t)]
        • At=λδt+lA_t = \lambda \delta_{t+l}:优势函数
        • δt=Reward(st,at)V(st)\delta_t=Reward(s_t,a_t)- V(s_t)
        • rt(θ)=πθ(atst)πθold(atst)r_t(\theta)=\frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}policy model 新旧策略概率比值
        • VtV_t:状态价值函数
        • VttargV_t^{targ}:目标状态价值函数,由于 RLHF 属于单步强化学习,所以 Vttarg=Reward(st,at)V_t^{targ}=Reward(s_t,a_t)
        • H[πθ(atst)]H[\pi_{\theta}(a_t|s_t)]:熵
        • c1,c2c_1, c_2:权重系数

Thoughts

  • 用对比学习的方式训练奖励模型,可以减少人工标注数据量,同时降低数据标注的主观性噪声
  • 第二步强化学习阶段仅仅相当于单步强化学习(无需关注 st+1s_{t+1}),因此需要对 PPO 算法进行一定的修改(简化)