0%
TL;DR
- 本文是大模型RLHF的入门教程,目的是用尽可能简单的方式介绍大模型RLHF的基本概念和原理,不涉及过多实现细节。
什么是 RLHF
RLHF
是 Reinforcement Learning with Human Feedback (人类反馈强化学习)
的缩写,目的是在强化学习任务中引入人类反馈,以提高强化学习模型的性能
- 大多数分类体系下,
SFT
属于 RLHF
中的一个步骤,但本文不讨论 SFT
,假设起点模型已经做好了 pre-train
和 SFT
RLHF
的目标是:
RLHF 相关术语
policy model
:策略模型,即待训练的强化学习模型
reference model
:参考模型(基准模型),通常是冻结参数的原始 policy model
,防止模型在 RL
过程中出现退化(可理解为正则化项)
reward model
:奖励模型,用于评估 policy model
生成的结果的好坏,相当于强化学习的 environment
value model
:价值模型,通常是 policy model
的一个子模块,用于评估状态的价值(也可以理解为对环境的估计),辅助 policy model
的决策,仅仅在 RLHF
过程中有效,训练结束后会被丢弃
RLHF 的基本实现原理
1. 训练奖励模型阶段
- 数据格式:
prompt + response 1 + response 2 + label
prompt
:问题描述
response 1
:模型生成的回答 1
response 2
:模型生成的回答 2
label
:人类标注的哪个回答更好(0 / 1
)
- 模型:
- 模型结构:通常是
LLM
,可以由待 RLHF
的策略模型初始化,也可以重新定义模型结构和随机初始化
- 输入:
prompt + response
- 输出:
reward
,是一个标量,表示 response
相对于 prompt
的好坏(logits
)
- 损失函数:L=−N1∑i=1Nlogσ(R(prompt,responsechosen)−R(prompt,responsereject))
- R(prompt,response):奖励模型就当前
prompt
和 response
的奖励(优秀程度)
- σ:
sigmoid
函数
- N:
batch size
- responsechosen:人类标注的更好的回答
- responsereject:人类标注的更差的回答
2. 强化学习阶段
RLHF
的强化学习阶段通常使用 PPO
算法
RLHF
中使用的 PPO
与传统 PPO
有以下几点不同:
- 数据格式不同:
- 传统
PPO
需要的数据格式为: (st,at,rt,st+1,done)
RLHF
需要的数据格式为: (st,at,rt),其中:
- st:
prompt
- at:
response
- rt:
reward(prompt, response)
reward
的计算方式不同:
- 传统
PPO
中的 reward
由环境给出
RLHF
中的 reward
由 reward model
给出
policy model
的更新方式不同:
- 为了防止
policy model
在 RL
过程中退化,需要在损失函数中加入 policy model
和 reference model
的 KL
散度作为正则化项
reference model
通常是冻结参数的原始 policy model
或者是一个不同架构的经过 pre-train + SFT
的模型
- 综上,
RLHF
强化学习阶段的损失函数为:
- LRLHF=LPPO+β⋅DKL(policy,reference)
- LPPO:传统
PPO
的损失函数
- β:正则化系数
- KLdiv(policy,reference):
policy model
和 reference model
的 KL
散度
- LPPO=E[min(rt(θ)A^t,clip(rt(θ),1−ϵ,1+ϵ)A^t)]+c1⋅MSE(Vt,Vttarg)−c2⋅H[πθ(at∣st)]
- At=λδt+l:优势函数
- δt=Reward(st,at)−V(st)
- rt(θ)=πθold(at∣st)πθ(at∣st):
policy model
新旧策略概率比值
- Vt:状态价值函数
- Vttarg:目标状态价值函数,由于
RLHF
属于单步强化学习,所以 Vttarg=Reward(st,at)
- H[πθ(at∣st)]:熵
- c1,c2:权重系数
Thoughts
- 用对比学习的方式训练奖励模型,可以减少人工标注数据量,同时降低数据标注的主观性噪声
- 第二步强化学习阶段仅仅相当于单步强化学习(无需关注 st+1),因此需要对
PPO
算法进行一定的修改(简化)