URL
- paper: https://arxiv.org/pdf/2211.17192
- code: https://github.com/feifeibear/LLMSpeculativeSampling (民间实现)
TL;DR
- 本文提出一种新颖的大模型推理加速的新范式,叫做投机解码,用一个参数量很小的近似模型预测得到一段序列,让大模型评判 接受 还是 拒绝 这段序列。
- 从数学上可以证明投机解码的结果和大模型逐个
token
预测的结果完全相同。 - 可以加速推理
2 - 3
倍,据说GPT-4
使用了这种方法。
Algorithm
算法流程
-
给定
prefix_seq
序列,用近似小模型(approx model
,M_q
)预测 长度的后续序列,记做predict_seq
,并记录predict_seq
中每一个token
的预测概率,记作q
-
大模型(
target model
,M_p
)输入prefix_seq + predict_seq
,得到predict_seq
每个token
的概率(记作p
)和下一个token
(记作t
) -
对于每一个
token
,如果p >= q
,则表示大模型接受小模型对这个token
的提议;如果p < q
,则表示 大模型以 的概率接受这个token
,用随机均匀采样来确定是否接受这个token
-
添加大模型连续接受的
predict_seq
到prefix_seq
之后,如果predict_seq
每个token
都接受了,则prefix_seq = prefix_seq + predict_seq + t
。重复步骤1
代码实现
1 |
|
收益分析
-
大模型推理过程中,预填充阶段计算访存比较高,增量生成阶段计算访存比较低,因此增量生成阶段是主要优化方向
-
本论文的投机解码算法在某种程度上是将
target model
的增量生成阶段变成approx model
的增量生成 +target model
的预填充,收益主要来自于这里 -
且投机解码算法和其他大模型推理算法(例如
kv cache
、flash attention
、MQA / GQA
)并不冲突,可以合并使用达到最优加速效果
Thought
- 从另外一种新颖且符合逻辑的方面优化了大模型推理