KV Cache Transformer

TL;DR

  • GPT-like large language models currently dominate generative modeling. Each inference step produces only one token; the next step appends that token to the end of the previous input sequence and runs again, until max_seq_length is reached or the generated token marks end of sequence
  • This GPT-like generation scheme means every inference step recomputes all previous tokens from scratch: the longer the sequence, the more computation and the longer each step takes
  • KV Cache stores each transformer layer's key / value from the previous inference step, lowering the per-step inference complexity from $\mathcal{O}(n^2)$ to $\mathcal{O}(n)$ while keeping the computed result exactly equivalent (the two decoding loops are sketched below)
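
The difference between the two decoding loops can be sketched as follows; this is a minimal illustration rather than code from this post. `model_full` and `model_cached` are hypothetical callables: the first maps token ids to logits, the second additionally accepts and returns a per-layer key / value cache.

import torch

def generate_no_cache(model_full, prompt_ids, max_new_tokens):
    # Every step re-runs the whole prefix: attention alone costs O(t^2) at step t.
    tokens = prompt_ids
    for _ in range(max_new_tokens):
        logits = model_full(tokens)                         # full forward over all t tokens
        next_id = logits[:, -1].argmax(-1, keepdim=True)    # greedy choice of the next token
        tokens = torch.cat([tokens, next_id], dim=1)
    return tokens

def generate_with_cache(model_cached, prompt_ids, max_new_tokens):
    # Only the newest token is fed in; cached K/V cover the prefix, so step t costs O(t).
    tokens = prompt_ids
    logits, past_kv = model_cached(prompt_ids, None)        # prefill: build the cache once
    for _ in range(max_new_tokens):
        next_id = logits[:, -1].argmax(-1, keepdim=True)
        tokens = torch.cat([tokens, next_id], dim=1)
        logits, past_kv = model_cached(next_id, past_kv)    # single-token decode step
    return tokens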

Algorithm

  • The theoretical prerequisites for reusing the key / value cached at the previous time step are:
    1. Except for the last token, the token sequence fed in at each time step is identical to the input of the previous time step
    2. The attention mask is a lower triangular matrix
    3. The softmax over the attention score matrix is applied row-wise, not globally or column-wise
    4. Apart from self-attention, the other layers of a GPT-like LLM (layer norm, GELU, FFN, etc.) are element-wise along the seq_len dimension and static at inference time
  • All four prerequisites hold in GPT-like models. Here is why each of them is indispensable:
    1. If the first condition did not hold, every inference step would process entirely new content and caching would be pointless
    2. The lower triangular mask and row-wise softmax are essential: when the token count grows by one, every layer's attention score matrix gains one column on the right and one row at the bottom. Because of the lower triangular mask, every row except the last keeps its values (the new column takes no part in them), and because softmax is row-wise, the new column does not affect the other rows either (see the small check right after this list)
    3. Self-attention is the only operator in a GPT-like LLM that mixes information across token features; all other operators act within each token's feature and are static, so as the sequence grows, the features of earlier tokens remain unchanged
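
Before the full-model demo below, here is a small standalone check of point 2 (my own sketch, not from the post): with a causal mask and row-wise softmax, appending one token leaves the earlier rows of the attention-weight matrix untouched.

import torch

torch.manual_seed(0)
q = torch.randn(4, 8)   # queries for 4 tokens, head dim 8
k = torch.randn(4, 8)   # keys for the same 4 tokens

def causal_attn_weights(q, k):
    scores = q @ k.T / k.shape[-1] ** 0.5
    mask = torch.tril(torch.ones(q.shape[0], k.shape[0], dtype=torch.bool))
    scores = scores.masked_fill(~mask, float("-inf"))      # lower triangular mask
    return torch.softmax(scores, dim=-1)                   # row-wise softmax

w3 = causal_attn_weights(q[:3], k[:3])   # sequence length 3
w4 = causal_attn_weights(q, k)           # one more token appended
print(torch.allclose(w3, w4[:3, :3]))    # True: the first 3 rows are unchanged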
import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    def __init__(self, dim, layer_idx):
        super(SelfAttention, self).__init__()
        self.dim = dim
        self.layer_idx = layer_idx
        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)
        self.softmax = nn.Softmax(dim=-1)  # the key point: softmax is applied row-wise, not globally
        self.dropout = nn.Dropout(0.1)
        self.layer_norm = nn.LayerNorm(dim)

    def forward(self, x, attn_mask=None):
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        scores = torch.matmul(q, k.transpose(-2, -1))
        scores = scores / torch.sqrt(torch.tensor(self.dim).float())
        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask == 0, float("-inf"))
        attention_weights = self.softmax(scores)
        attention_weights = self.dropout(attention_weights)
        # Print per-token statistics so the two time steps can be compared.
        print("layer_idx:", self.layer_idx)
        print("q.sum(-1) =", q.squeeze().abs().sum(-1).tolist())
        print("k.sum(-1) =", k.squeeze().abs().sum(-1).tolist())
        print("v.sum(-1) =", v.squeeze().abs().sum(-1).tolist())
        print("attention_weights =\n", attention_weights.squeeze())
        print("#" * 100)
        output = torch.matmul(attention_weights, v)
        output = self.layer_norm(output + x)
        return output


class GPTModel(nn.Module):
    def __init__(self, dim, num_layers):
        super(GPTModel, self).__init__()
        self.dim = dim
        self.num_layers = num_layers
        self.attentions = nn.ModuleList(
            [SelfAttention(dim, i) for i in range(num_layers)]
        )
        self.linear1 = nn.Linear(dim, dim)
        self.linear2 = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(0.1)
        self.layer_norm = nn.LayerNorm(dim)

    def forward(self, x, attn_mask=None):
        for i in range(self.num_layers):
            x = x + self.attentions[i](x, attn_mask=attn_mask)
            x = self.layer_norm(x)
            x = x + self.linear2(self.dropout(torch.relu(self.linear1(x))))
            x = self.layer_norm(x)
        return x


# Example usage
model = GPTModel(dim=64, num_layers=2)
model.eval()  # disable dropout so the two time steps are directly comparable
input_data = torch.randn(
    1, 3, 64
)  # Example input data with shape (batch_size, sequence_length, dim)
attn_mask = torch.tril(torch.ones(3, 3))  # Lower triangular attention mask
print("time step 1:")
output = model(input_data, attn_mask=attn_mask)

print("\n\ntime step 2:")
input_data = torch.cat((input_data, torch.randn(1, 1, 64)), 1)  # append one new token
attn_mask = torch.tril(torch.ones(4, 4))  # Lower triangular attention mask
output = model(input_data, attn_mask=attn_mask)
# Output:
# time step 1:
# layer_idx: 0
# q.sum(-1) = [31.20478057861328, 34.04328536987305, 33.6611328125]
# k.sum(-1) = [31.446550369262695, 29.15508460998535, 29.265644073486328]
# v.sum(-1) = [30.789640426635742, 33.80058670043945, 32.4859619140625]
# attention_weights =
# tensor([[1.0000, 0.0000, 0.0000],
# [0.6760, 0.3240, 0.0000],
# [0.2452, 0.4188, 0.3360]], grad_fn=<SqueezeBackward0>)
# ####################################################################################################
# layer_idx: 1
# q.sum(-1) = [30.214162826538086, 34.45262145996094, 26.357181549072266]
# k.sum(-1) = [29.88446617126465, 31.186325073242188, 31.727235794067383]
# v.sum(-1) = [23.99290657043457, 35.66062545776367, 33.28850173950195]
# attention_weights =
# tensor([[1.0000, 0.0000, 0.0000],
# [0.3310, 0.6690, 0.0000],
# [0.4525, 0.3270, 0.2204]], grad_fn=<SqueezeBackward0>)
# ####################################################################################################
# time step 2:
# layer_idx: 0
# q.sum(-1) = [31.20478057861328, 34.04328536987305, 33.6611328125, 27.077472686767578]
# k.sum(-1) = [31.446550369262695, 29.15508460998535, 29.265644073486328, 27.211992263793945]
# v.sum(-1) = [30.789640426635742, 33.80058670043945, 32.4859619140625, 29.695066452026367]
# attention_weights =
# tensor([[1.0000, 0.0000, 0.0000, 0.0000],
# [0.6760, 0.3240, 0.0000, 0.0000],
# [0.2452, 0.4188, 0.3360, 0.0000],
# [0.1818, 0.2616, 0.2813, 0.2752]], grad_fn=<SqueezeBackward0>)
# ####################################################################################################
# layer_idx: 1
# q.sum(-1) = [30.214162826538086, 34.45262145996094, 26.357181549072266, 29.209003448486328]
# k.sum(-1) = [29.88446617126465, 31.186325073242188, 31.727235794067383, 26.988731384277344]
# v.sum(-1) = [23.99290657043457, 35.66062545776367, 33.28850173950195, 27.92920684814453]
# attention_weights =
# tensor([[1.0000, 0.0000, 0.0000, 0.0000],
# [0.3310, 0.6690, 0.0000, 0.0000],
# [0.4525, 0.3270, 0.2204, 0.0000],
# [0.2740, 0.0698, 0.4069, 0.2492]], grad_fn=<SqueezeBackward0>)
# ####################################################################################################

As the time steps advance, for every layer the features of the first n-1 tokens of x / q / k / v / attention weights are identical to those of the previous time step, which is exactly why they can be cached.
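
The following sketch shows how a decode step would actually consume such a cache, reusing the SelfAttention layer defined above. The `kv_cache` tuple and the single-new-token assumption are mine, not part of the original code; dropout is ignored, which matches eval mode.

def cached_attention_step(attn: SelfAttention, x_new, kv_cache=None):
    """x_new: (batch, 1, dim) -- only the newest token is projected."""
    q = attn.query(x_new)
    k_new = attn.key(x_new)
    v_new = attn.value(x_new)
    if kv_cache is not None:
        k = torch.cat([kv_cache[0], k_new], dim=1)   # reuse cached keys
        v = torch.cat([kv_cache[1], v_new], dim=1)   # reuse cached values
    else:
        k, v = k_new, v_new
    # With a single new query no explicit mask is needed: the last row of the
    # lower triangular mask is all ones, and softmax is still applied row-wise.
    scores = torch.matmul(q, k.transpose(-2, -1)) / attn.dim ** 0.5
    weights = attn.softmax(scores)
    output = attn.layer_norm(torch.matmul(weights, v) + x_new)
    return output, (k, v)   # the grown cache is passed to the next time step

In eval mode this reproduces exactly the last row of the full forward pass above, while projecting only one new query / key / value instead of recomputing them for the whole prefix.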

Thought

  • In fact q / k / v / attention weights could all be cached; caching only k / v is a trade-off between computation and IO
  • Paged attention manages the KV cache in a way similar to an operating system's paged memory management: it allocates GPU memory dynamically, runs faster, and supports larger batch sizes (a toy sketch follows)
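
A toy illustration of the paged idea (my own simplification under loose assumptions; it mirrors the block-table concept, not vLLM's actual implementation): K/V tensors live in fixed-size physical blocks drawn from a shared pool, and each sequence keeps a small block table, so memory grows with the real sequence length instead of being reserved for max_seq_length up front.

import torch

BLOCK_SIZE, NUM_BLOCKS, DIM = 4, 16, 64
k_pool = torch.zeros(NUM_BLOCKS, BLOCK_SIZE, DIM)   # shared physical key storage
v_pool = torch.zeros(NUM_BLOCKS, BLOCK_SIZE, DIM)   # shared physical value storage
free_blocks = list(range(NUM_BLOCKS))               # blocks not yet owned by any sequence

class SequenceKVCache:
    def __init__(self):
        self.block_table = []   # logical block index -> physical block id
        self.length = 0         # number of cached tokens

    def append(self, k_vec, v_vec):
        # Allocate a new physical block only when the current one is full.
        if self.length % BLOCK_SIZE == 0:
            self.block_table.append(free_blocks.pop())
        blk, off = self.block_table[-1], self.length % BLOCK_SIZE
        k_pool[blk, off] = k_vec
        v_pool[blk, off] = v_vec
        self.length += 1

    def gather(self):
        # Materialize the logical (length, DIM) K/V views for attention.
        k = torch.cat([k_pool[b] for b in self.block_table])[: self.length]
        v = torch.cat([v_pool[b] for b in self.block_table])[: self.length]
        return k, v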