Zhangzhe's Blog

The projection of my life.


GQA: Grouped-Query Attention

URL

TL;DR

  • Essentially a compromise between standard Multi-Head Attention (MHA) and Multi-Query Attention (MQA): each query head keeps its own projection, while keys/values are shared within groups of heads (a rough KV-cache comparison follows the figure below).

GQA.png
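
To make the trade-off concrete, here is a minimal back-of-the-envelope sketch (not from the original post; the head and dimension numbers simply match the example in the Algorithm section below) of the per-token KV-cache size under MHA, GQA, and MQA. Only the number of key/value heads changes.

# Per-token KV-cache entries = num_kv_heads * (key_dim + value_dim)
num_heads, num_groups, key_dim, value_dim = 8, 4, 32, 48

mha_kv = num_heads * (key_dim + value_dim)   # MHA: one K/V head per query head -> 640
gqa_kv = num_groups * (key_dim + value_dim)  # GQA: one K/V head per group      -> 320
mqa_kv = 1 * (key_dim + value_dim)           # MQA: a single K/V head for all   -> 80

print(mha_kv, gqa_kv, mqa_kv)  # 640 320 80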

Algorithm

import torch
import torch.nn.functional as F


def multiquery_attention_batched(X, M, mask, P_q, P_k, P_v, P_o, num_groups):
    # b: batch size
    # n: number of queries
    # m: number of keys / values
    # d: input dimension
    # H: number of heads
    # g: number of groups
    # h: number of heads per group
    # k: key / query dimension
    # v: value dimension

    assert P_q.size(0) % num_groups == 0
    heads_per_group = P_q.size(0) // num_groups

    # Linear projection for the queries Q (one projection per head)
    Q = torch.einsum("bnd,Hdk->bHkn", X, P_q).view(
        -1, num_groups, heads_per_group, P_q.size(2), X.size(1)
    )

    # Linear projections for keys K and values V (one projection per group)
    K = torch.einsum("bmd,gdk->bgmk", M, P_k)
    V = torch.einsum("bmd,gdv->bgmv", M, P_v)

    # Attention scores: every head in a group attends over that group's K
    logits = torch.einsum("bghkn,bgmk->bghnm", Q, K)

    # Reshape the mask to the grouped layout and apply it
    mask = mask.view(-1, num_groups, heads_per_group, mask.size(-2), mask.size(-1))
    weights = F.softmax(logits + mask, dim=-1)

    # Weighted sum of the values
    O = torch.einsum("bghnm,bgmv->bghnv", weights, V)

    # Reshape the output projection to the grouped layout
    P_o = P_o.view(num_groups, heads_per_group, P_o.size(-2), P_o.size(-1))

    # Output projection back to the model dimension
    Y = torch.einsum("bghnv,ghvd->bnd", O, P_o)

    return Y


# Example parameters and dimensions
batch_size = 2
num_queries = 5
num_keys = 7  # also num_values
dim = 64
key_dim = 32  # also query_dim
value_dim = 48
num_heads = 8
num_groups = 4  # number of head groups; must divide num_heads evenly

# Randomly initialize inputs and parameters
X = torch.randn(batch_size, num_queries, dim)
M = torch.randn(batch_size, num_keys, dim)
mask = torch.randn(batch_size, num_heads, num_queries, num_keys)
P_q = torch.randn(num_heads, dim, key_dim)
P_k = torch.randn(num_groups, dim, key_dim)  # key / value projections are per group, not shared across groups
P_v = torch.randn(num_groups, dim, value_dim)
P_o = torch.randn(num_heads, value_dim, dim)

# Run grouped-query attention
output = multiquery_attention_batched(X, M, mask, P_q, P_k, P_v, P_o, num_groups)
print(output.shape)  # expected: torch.Size([2, 5, 64])
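
As a sanity check (not from the original post), the same computation can be written as plain per-head attention in which each group's K/V is simply broadcast to every head of that group. gqa_via_kv_broadcast below is a hypothetical helper that reuses the tensors defined above and should match output up to float32 rounding.

def gqa_via_kv_broadcast(X, M, mask, P_q, P_k, P_v, P_o, num_groups):
    heads_per_group = P_q.size(0) // num_groups
    Q = torch.einsum("bnd,Hdk->bHnk", X, P_q)  # (b, H, n, k)
    K = torch.einsum("bmd,gdk->bgmk", M, P_k)  # (b, g, m, k)
    V = torch.einsum("bmd,gdv->bgmv", M, P_v)  # (b, g, m, v)
    # Broadcast each group's K / V to all heads in that group
    K = K.repeat_interleave(heads_per_group, dim=1)  # (b, H, m, k)
    V = V.repeat_interleave(heads_per_group, dim=1)  # (b, H, m, v)
    logits = torch.einsum("bHnk,bHmk->bHnm", Q, K)
    weights = F.softmax(logits + mask, dim=-1)
    O = torch.einsum("bHnm,bHmv->bHnv", weights, V)
    return torch.einsum("bHnv,Hvd->bnd", O, P_o)


ref = gqa_via_kv_broadcast(X, M, mask, P_q, P_k, P_v, P_o, num_groups)
print((output - ref).abs().max())  # should be ~0 (floating-point rounding only)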

Thought

  • GQA sits between MHA and MQA: within each group it is effectively MQA (all heads in a group share one set of keys/values). Setting num_groups equal to num_heads recovers MHA, and a single group recovers MQA, as sketched below.
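
A small usage sketch (an illustration only, reusing the tensors from the example above; the *_mha / *_mqa projection names are made up): both endpoints fall out of the same function by changing only num_groups and the shape of the K/V projections.

# MHA endpoint: one K/V projection per head (num_groups == num_heads)
P_k_mha = torch.randn(num_heads, dim, key_dim)
P_v_mha = torch.randn(num_heads, dim, value_dim)
y_mha = multiquery_attention_batched(X, M, mask, P_q, P_k_mha, P_v_mha, P_o, num_heads)

# MQA endpoint: a single K/V projection shared by all heads (num_groups == 1)
P_k_mqa = torch.randn(1, dim, key_dim)
P_v_mqa = torch.randn(1, dim, value_dim)
y_mqa = multiquery_attention_batched(X, M, mask, P_q, P_k_mqa, P_v_mqa, P_o, 1)

print(y_mha.shape, y_mqa.shape)  # both torch.Size([2, 5, 64])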