Zhangzhe's Blog

The projection of my life.

0%

Fast Inference from Transformers via Speculative Decoding

URL

TL;DR

  • 本文提出一种新颖的大模型推理加速的新范式,叫做投机解码,用一个参数量很小的近似模型预测得到一段序列,让大模型评判 接受 还是 拒绝 这段序列。
  • 从数学上可以证明投机解码的结果和大模型逐个 token 预测的结果完全相同。
  • 可以加速推理 2 - 3 倍,据说 GPT-4 使用了这种方法。

Algorithm

算法流程

  1. 给定 prefix_seq 序列,用近似小模型(approx modelM_q)预测 γ\gamma 长度的后续序列,记做 predict_seq,并记录 predict_seq 中每一个 token 的预测概率,记作 q

  2. 大模型(target modelM_p)输入 prefix_seq + predict_seq,得到 predict_seq 每个 token 的概率(记作 p)和下一个 token(记作 t

  3. 对于每一个 token,如果 p >= q,则表示大模型接受小模型对这个 token 的提议;如果 p < q,则表示 大模型以 1pq1 - \frac{p}{q} 的概率接受这个 token,用随机均匀采样来确定是否接受这个 token

  4. 添加大模型连续接受的 predict_seqprefix_seq 之后,如果 predict_seq 每个 token 都接受了,则 prefix_seq = prefix_seq + predict_seq + t。重复步骤 1

代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66

@torch.no_grad()
def speculative_sampling(
prefix: torch.Tensor,
approx_model: torch.nn.Module,
target_model: torch.nn.Module,
max_len: int,
gamma: int = 4,
temperature: float = 1,
top_k: int = 0,
top_p: float = 0,
) -> torch.Tensor:
seq_len = prefix.shape[1]
T = seq_len + max_len

approx_model_cache = KVCacheModel(approx_model, temperature, top_k, top_p)
target_model_cache = KVCacheModel(target_model, temperature, top_k, top_p)

resample_count = 0
target_sample_count = 0
accepted_count = 0

while prefix.shape[1] < T:
prefix_len = prefix.shape[1]

x = approx_model_cache.generate(prefix, gamma)
_ = target_model_cache.generate(x, 1)

n = prefix_len + gamma - 1

for i in range(gamma):
r = torch.rand(1, device=target_model_cache.device)
j = x[:, prefix_len + i]

if r > (target_model_cache._prob_history[:, prefix_len + i - 1, j]) / (
approx_model_cache._prob_history[:, prefix_len + i - 1, j]
):
# reject
n = prefix_len + i - 1
break

accepted_count += 1

prefix = x[:, : n + 1]

approx_model_cache.rollback(n + 1)

if n < prefix_len + gamma - 1:
# reject someone, sample from the pos n
t = sample(
max_fn(
target_model_cache._prob_history[:, n, :]
- approx_model_cache._prob_history[:, n, :]
)
)
resample_count += 1
target_model_cache.rollback(n + 1)
else:
# all approx model decoding accepted
assert n == target_model_cache._prob_history.shape[1] - 1
t = sample(target_model_cache._prob_history[:, -1, :])
target_sample_count += 1

prefix = torch.cat((prefix, t), dim=1)

return prefix

收益分析

  • 大模型推理过程中,预填充阶段计算访存比较高,增量生成阶段计算访存比较低,因此增量生成阶段是主要优化方向

  • 本论文的投机解码算法在某种程度上是将 target model 的增量生成阶段变成 approx model 的增量生成 + target model 的预填充,收益主要来自于这里

  • 且投机解码算法和其他大模型推理算法(例如 kv cacheflash attentionMQA / GQA)并不冲突,可以合并使用达到最优加速效果

Thought

  • 从另外一种新颖且符合逻辑的方面优化了大模型推理