TL;DR
- 本文是大模型DPO的入门教程,目的是用尽可能简单的方式介绍大模型DPO的基本概念和原理,不涉及过多实现细节。
什么是 DPO
DPO
是 Direct Performance Optimization (直接偏好优化)
的缩写,是 RLHF
的替代方案,目的是使用不涉及强化学习的方法来优化模型性能。
DPO 的数据格式
- 和
RLHF
一致,DPO
的数据格式为:prompt + response 1 + response 2 + label
prompt
:问题描述
response 1
:模型生成的回答 1
response 2
:模型生成的回答 2
label
:人类标注的哪个回答更好(0 / 1
)
DPO 损失函数的设计
L(θ)=E(x,y+,y−)∼D[logσ(sθ(x,y+)−sθ(x,y−))]
sθ(x,y)=logpθ(y∣x)
pθ(y∣x)=t=1∏Tpθ(yt∣y<t,x)
sθ(x,y)=t=1∑Tlogpθ(yt∣y<t,x)
- L(θ):损失函数
- θ:模型参数
- D:数据集
- x:输入
- y+:人类标注的更好的回答
- y−:人类标注的更差的回答
- sθ(x,y):模型对 (x,y) 的得分
- pθ(y∣x):模型对 x 生成 y 的概率
- σ:
sigmoid
函数
- T:序列长度
- y<t:y 的前 t−1 个元素
- sθ(x,y) 可经过修改变成大名鼎鼎的 困惑度 (
Perplexity
, PPL
):PPL(θ)=exp(−len(y)sθ(x,y))
Thoughts
DPO
事实上就是利用正负样本的困惑度来对比学习,从而优化模型性能
- 原理非常简单,收敛速度比
RLHF
快
- 但可以预见的是:
DPO
的损失计算方式比 RLHF
更 硬,那么必然对数据集的质量要求更高,否则会导致模型性能下降
- 一个常用的做法是在
SFT
之后,在 RLHF
之前使用 DPO
快速将模型性能提升到一个较高水平,然后再使用 RLHF
进一步优化