KV Cache Transformer

TL;DR

  • GPT-like large language models currently dominate generative modeling. Each inference step produces only one token; the next step appends that token to the end of the previous input token sequence as the new input, until max_seq_length is reached or the output token signals end of sequence

  • This GPT-like generation scheme means every inference step recomputes all previous tokens again; the longer the sequence, the larger the compute and the higher the latency

  • KV Cache stores the key / value of every transformer layer from the previous inference step, reducing the time complexity of inference from $\mathcal{O}(n^2)$ to $\mathcal{O}(n)$ while producing exactly the same results
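
To make the complexity difference concrete, here is a minimal sketch of the two decoding loops. The model interface (a past_kv argument that returns an updated cache) and sample_next_token are hypothetical placeholders for illustration, not an actual API:

def generate_no_cache(model, prompt_ids, max_new_tokens):
    # Without KV Cache: step t re-encodes all t tokens,
    # so the total work is 1 + 2 + ... + n = O(n^2) token forward passes.
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        logits = model(ids)  # the whole sequence is recomputed every step
        ids.append(sample_next_token(logits[-1]))
    return ids


def generate_with_cache(model, prompt_ids, max_new_tokens):
    # With KV Cache: after the prefill, each step feeds only the newest token and
    # reuses the cached key / value of all earlier tokens -> n single-token passes = O(n).
    logits, kv_cache = model(prompt_ids, past_kv=None)  # prefill the prompt once
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        next_id = sample_next_token(logits[-1])
        ids.append(next_id)
        logits, kv_cache = model([next_id], past_kv=kv_cache)  # only 1 new token
    return ids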

Algorithm

  • The theoretical prerequisites for reusing, at the current time step, the key / value cached at the previous time step are:

    1. At each time step, the input token sequence is identical to the previous time step's input except for the last token
    2. The attention mask is a lower-triangular matrix
    3. The softmax over the attention score matrix is applied row-wise, not globally or column-wise
    4. Apart from the self-attention operation, the other layers of a GPT-like LLM (layer norm, GELU, FFN, etc.) are element-wise along the seq_len dimension and static at inference time
  • All four conditions hold in GPT-like models. Here is why each one is indispensable:

    1. If the first condition does not hold, every inference step sees entirely new content and caching would be pointless
    2. The lower-triangular mask and the row-wise softmax are essential: as the token length grows by one, each layer's attention score matrix gains one column on the right and one row at the bottom; thanks to the lower-triangular mask, every row except the last keeps its values (the newly added column does not take part in them), and because softmax is row-wise, the new column does not affect the other rows either
    3. self-attention is the only operator in a GPT-like LLM that mixes information across token features; every other operator works within a single token's features and is static, so as the sequence grows, the features of earlier tokens stay unchanged. The script below verifies this numerically
import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    def __init__(self, dim, layer_idx):
        super(SelfAttention, self).__init__()
        self.dim = dim
        self.layer_idx = layer_idx
        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)
        self.softmax = nn.Softmax(dim=-1)  # the key point: softmax is applied row-wise, not globally
        self.dropout = nn.Dropout(0.1)
        self.layer_norm = nn.LayerNorm(dim)

    def forward(self, x, attn_mask=None):
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        scores = torch.matmul(q, k.transpose(-2, -1))
        scores = scores / torch.sqrt(torch.tensor(self.dim).float())

        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask == 0, float("-inf"))

        attention_weights = self.softmax(scores)
        attention_weights = self.dropout(attention_weights)

        print("layer_idx:", self.layer_idx)
        print("q.sum(-1) =", q.squeeze().abs().sum(-1).tolist())
        print("k.sum(-1) =", k.squeeze().abs().sum(-1).tolist())
        print("v.sum(-1) =", v.squeeze().abs().sum(-1).tolist())
        print("attention_weights =\n", attention_weights.squeeze())
        print("#" * 100)

        output = torch.matmul(attention_weights, v)
        output = self.layer_norm(output + x)

        return output


class GPTModel(nn.Module):
    def __init__(self, dim, num_layers):
        super(GPTModel, self).__init__()
        self.dim = dim
        self.num_layers = num_layers
        self.attentions = nn.ModuleList(
            [SelfAttention(dim, i) for i in range(num_layers)]
        )
        self.linear1 = nn.Linear(dim, dim)
        self.linear2 = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(0.1)
        self.layer_norm = nn.LayerNorm(dim)

    def forward(self, x, attn_mask=None):
        for i in range(self.num_layers):
            x = x + self.attentions[i](x, attn_mask=attn_mask)
            x = self.layer_norm(x)
            x = x + self.linear2(self.dropout(torch.relu(self.linear1(x))))
            x = self.layer_norm(x)

        return x


torch.random.manual_seed(0)
# Example usage
model = GPTModel(dim=64, num_layers=2)
model.eval()
input_data = torch.randn(
    1, 3, 64
)  # Example input data with shape (batch_size, sequence_length, dim)
attn_mask = torch.tril(torch.ones(3, 3))  # Lower triangular attention mask
print("time step 1:")
output = model(input_data, attn_mask=attn_mask)

print("\n\ntime step 2:")
input_data = torch.cat((input_data, torch.randn(1, 1, 64)), 1)
attn_mask = torch.tril(torch.ones(4, 4))  # Lower triangular attention mask
output = model(input_data, attn_mask=attn_mask)

# Output:
# time step 1:
# layer_idx: 0
# q.sum(-1) = [31.20478057861328, 34.04328536987305, 33.6611328125]
# k.sum(-1) = [31.446550369262695, 29.15508460998535, 29.265644073486328]
# v.sum(-1) = [30.789640426635742, 33.80058670043945, 32.4859619140625]
# attention_weights =
# tensor([[1.0000, 0.0000, 0.0000],
# [0.6760, 0.3240, 0.0000],
# [0.2452, 0.4188, 0.3360]], grad_fn=<SqueezeBackward0>)
# ####################################################################################################
# layer_idx: 1
# q.sum(-1) = [30.214162826538086, 34.45262145996094, 26.357181549072266]
# k.sum(-1) = [29.88446617126465, 31.186325073242188, 31.727235794067383]
# v.sum(-1) = [23.99290657043457, 35.66062545776367, 33.28850173950195]
# attention_weights =
# tensor([[1.0000, 0.0000, 0.0000],
# [0.3310, 0.6690, 0.0000],
# [0.4525, 0.3270, 0.2204]], grad_fn=<SqueezeBackward0>)
# ####################################################################################################


# time step 2:
# layer_idx: 0
# q.sum(-1) = [31.20478057861328, 34.04328536987305, 33.6611328125, 27.077472686767578]
# k.sum(-1) = [31.446550369262695, 29.15508460998535, 29.265644073486328, 27.211992263793945]
# v.sum(-1) = [30.789640426635742, 33.80058670043945, 32.4859619140625, 29.695066452026367]
# attention_weights =
# tensor([[1.0000, 0.0000, 0.0000, 0.0000],
# [0.6760, 0.3240, 0.0000, 0.0000],
# [0.2452, 0.4188, 0.3360, 0.0000],
# [0.1818, 0.2616, 0.2813, 0.2752]], grad_fn=<SqueezeBackward0>)
# ####################################################################################################
# layer_idx: 1
# q.sum(-1) = [30.214162826538086, 34.45262145996094, 26.357181549072266, 29.209003448486328]
# k.sum(-1) = [29.88446617126465, 31.186325073242188, 31.727235794067383, 26.988731384277344]
# v.sum(-1) = [23.99290657043457, 35.66062545776367, 33.28850173950195, 27.92920684814453]
# attention_weights =
# tensor([[1.0000, 0.0000, 0.0000, 0.0000],
# [0.3310, 0.6690, 0.0000, 0.0000],
# [0.4525, 0.3270, 0.2204, 0.0000],
# [0.2740, 0.0698, 0.4069, 0.2492]], grad_fn=<SqueezeBackward0>)
# ####################################################################################################

As the time steps advance, the first n-1 token features of x / q / k / v / attention weights in every layer are identical to those of the previous time step, so they can be cached.
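
Building on that observation, here is a minimal sketch of how the SelfAttention class above could consume and update a cache. The past_kv argument, the returned (k, v) tuple and the single-token restriction are assumptions for illustration, not the interface used in the script above:

class CachedSelfAttention(SelfAttention):
    # Sketch only: assumes decode-time usage where x_new holds exactly one new token,
    # so no causal mask is needed (the cache contains only earlier tokens).
    # Dropout is omitted because the model runs in eval mode.
    def forward(self, x_new, past_kv=None):
        q = self.query(x_new)  # the query is needed only for the new token
        k_new = self.key(x_new)
        v_new = self.value(x_new)
        if past_kv is not None:
            k = torch.cat([past_kv[0], k_new], dim=1)  # reuse cached keys
            v = torch.cat([past_kv[1], v_new], dim=1)  # reuse cached values
        else:
            k, v = k_new, v_new

        # only the last row of the attention matrix is computed
        scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(
            torch.tensor(self.dim).float()
        )
        attention_weights = self.softmax(scores)  # still row-wise softmax
        output = torch.matmul(attention_weights, v)
        output = self.layer_norm(output + x_new)
        return output, (k, v)  # (k, v) is the updated cache for the next step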

Thought

  • In fact, q / k / v and even the attention weights could all be cached; caching only k / v is a tradeoff between compute and IO

  • paged attention manages the KV cache in a way analogous to paged memory management in an operating system: GPU memory is allocated dynamically in blocks, which makes inference faster and supports larger batch sizes
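
A minimal sketch of the paging idea: KV cache memory is carved into fixed-size physical blocks, and a per-sequence block table maps logical token positions to physical blocks. The block size, class and method names here are illustrative assumptions, not vLLM's actual data structures:

BLOCK_SIZE = 16  # tokens per block (illustrative)


class PagedKVCache:
    def __init__(self, num_blocks):
        self.free_blocks = list(range(num_blocks))  # pool of physical blocks
        self.block_tables = {}  # seq_id -> list of physical block ids
        self.seq_lens = {}      # seq_id -> number of cached tokens

    def append_token(self, seq_id):
        """Reserve cache space for one more token of sequence seq_id."""
        table = self.block_tables.setdefault(seq_id, [])
        length = self.seq_lens.get(seq_id, 0)
        if length % BLOCK_SIZE == 0:  # current block is full (or this is the first token)
            table.append(self.free_blocks.pop())  # allocate a new block on demand
        self.seq_lens[seq_id] = length + 1
        # return where this token's key / value should be written
        return table[length // BLOCK_SIZE], length % BLOCK_SIZE

    def free(self, seq_id):
        """Return all blocks of a finished sequence to the pool."""
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.seq_lens.pop(seq_id, None)

Because blocks are allocated on demand rather than reserving max_seq_length slots per sequence up front, more sequences fit in the same GPU memory, which is what enables the larger batch sizes mentioned above.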