
MQA: Multi-Query Attention

TL;DR

  • The idea is very simple: take standard Multi-Head Attention and reduce the Key and Value projections to a single shared head, while the Query stays multi-head (see the shape sketch below).
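
A minimal shape sketch of what this means for the projection parameters (my own illustration, with made-up sizes, not from the original post): in MHA every projection carries a head axis, while in MQA only the Query projection does.

import torch

d, h, k, v = 64, 8, 32, 48  # input dim, heads, key/query dim, value dim (toy sizes)

# Standard Multi-Head Attention: Q, K, V projections all have a per-head axis.
mha_P_q = torch.randn(h, d, k)
mha_P_k = torch.randn(h, d, k)
mha_P_v = torch.randn(h, d, v)

# Multi-Query Attention: only the Query projection keeps the head axis;
# Key and Value collapse to a single head shared by all query heads.
mqa_P_q = torch.randn(h, d, k)
mqa_P_k = torch.randn(d, k)  # no head axis
mqa_P_v = torch.randn(d, v)  # no head axis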

Algorithm

import torch
import torch.nn.functional as F


def multiquery_attention_batched(X, M, mask, P_q, P_k, P_v, P_o):
    # b: batch size
    # n: number of queries
    # m: number of keys / values
    # d: input dimension
    # h: number of heads
    # k: key / query dimension
    # v: value dimension

    # Linear projection for the queries Q (multi-head)
    Q = torch.einsum("bnd,hdk->bhkn", X, P_q)

    # Linear projections for the keys K and values V (single-head, shared by all heads)
    K = torch.einsum("bmd,dk->bmk", M, P_k)
    V = torch.einsum("bmd,dv->bmv", M, P_v)

    # Attention scores with the additive mask applied
    logits = torch.einsum("bhkn,bmk->bhnm", Q, K)
    weights = F.softmax(logits + mask, dim=-1)

    # Weighted sum of the values
    O = torch.einsum("bhnm,bmv->bhnv", weights, V)

    # Output projection
    Y = torch.einsum("bhnv,hvd->bnd", O, P_o)

    return Y


# Example parameters and dimensions
batch_size = 2
num_queries = 5
num_keys = 7  # also the number of values
dim = 64
key_dim = 32  # also the query dimension
value_dim = 48
num_heads = 8

# Randomly initialize inputs and parameters
X = torch.randn(batch_size, num_queries, dim)
M = torch.randn(batch_size, num_keys, dim)
mask = torch.randn(batch_size, num_heads, num_queries, num_keys)
P_q = torch.randn(num_heads, dim, key_dim)
P_k = torch.randn(dim, key_dim)
P_v = torch.randn(dim, value_dim)
P_o = torch.randn(num_heads, value_dim, dim)

# Run multi-query attention
output = multiquery_attention_batched(X, M, mask, P_q, P_k, P_v, P_o)
print(output.shape)  # Output: torch.Size([2, 5, 64])

Thought

  • In essence, it is standard Attention where only part of the computation (the Query side) is multi-head; the check below makes this concrete.
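
A quick way to verify this "partial multi-head" view (my own sketch, not from the original post): tile the single-head P_k / P_v over a head axis and run a standard multi-head attention; the result should match the multiquery_attention_batched output above. The helper name multihead_attention_batched and the tolerances are my own choices, and the tensors from the Algorithm block are assumed to still be in scope.

def multihead_attention_batched(X, M, mask, P_q, P_k, P_v, P_o):
    # Standard MHA: every projection carries its own head axis h.
    Q = torch.einsum("bnd,hdk->bhnk", X, P_q)
    K = torch.einsum("bmd,hdk->bhmk", M, P_k)
    V = torch.einsum("bmd,hdv->bhmv", M, P_v)
    logits = torch.einsum("bhnk,bhmk->bhnm", Q, K)
    weights = F.softmax(logits + mask, dim=-1)
    O = torch.einsum("bhnm,bhmv->bhnv", weights, V)
    return torch.einsum("bhnv,hvd->bnd", O, P_o)

# Broadcast the shared single-head K/V projections across all heads.
P_k_tiled = P_k.unsqueeze(0).expand(num_heads, -1, -1)
P_v_tiled = P_v.unsqueeze(0).expand(num_heads, -1, -1)

reference = multihead_attention_batched(X, M, mask, P_q, P_k_tiled, P_v_tiled, P_o)
print(torch.allclose(output, reference, rtol=1e-4, atol=1e-4))  # expected: True (up to float noise)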