Zhangzhe's Blog

The projection of my life.

0%

URL

TL;DR

  • 常用的 top-p random sampling decoding method 可能会导致输出的长文本质量较差
  • 本文设计了一种基于熵的动态概率阈值优化 top-p 随机采样算法,叫做 ηsampling\eta-sampling

Algorithm

公式表示

Ax<i={xVPθ(xx<i)>η}A_{x \lt i}=\{x \in V | P_{\theta}(x | x \lt i) \gt \eta \}

η=min(ϵ,α(hη,x<i))\eta=min(\epsilon,\alpha(-h_{\eta,x<i}))

其中第一行公式表示随机采样概率值截断到 η\eta
第二行中 ϵ,α\epsilon,\alpha 是超参数,通常 α=ϵ\alpha=\sqrt{\epsilon}h 表示输出的熵,h=(plog(p)).sum()h=-(p*log(p)).sum()p 表示概率

代码表示

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class EtaWarper(transformers.LogitsWarper):
"""Our proposed eta sampling warper."""
def __init__(self, epsilon):
self.epsilon = epsilon
self.filter_value = -float("Inf")
def __call__(self, input_ids, scores) -> torch.FloatTensor:
probabilities = scores.softmax(dim=-1)
entropy = - (probabilities * torch.log(probabilities)).sum()
eta = min(self.epsilon, torch.sqrt(torch.tensor(self.epsilon))*torch.exp(-entropy))
indices_to_remove = probabilities < eta
max_word = torch.argmax(scores,dim=-1)
# 防止把候选词全删了,至少留一个
indices_to_remove[...,max_word.squeeze()] = 0
new_scores = scores.masked_fill(indices_to_remove, self.filter_value)
return new_scores

Thought

  • top-p 的小升级,计算量较小,算是一个小 trick

URL

TL;DR

  • 本文提出一种叫做 Contrastive search 的解码方法,比 top-k / top-p 等随机解码方法和 argmax 这种贪心解码方法更优,模型输出内容语意更连贯,模型更不容易退化
  • Contrastive search 是在 top-k 基础上增加了 degeneration penalty(退化惩罚项),算是 random sampling 类算法的升级版
  • 但需要将候选的 top-k 输出追加到序列末尾作为输入重新输入到模型中输出 top-k 对应的 hidden state,因此时间复杂度较高

Algorithm

总体架构

contrastive_search_1.png

其中 α\alpha 是超参数,用来平衡模型预测结果和退化惩罚项
h 表示 hidden state,用 hvh_v 和序列中每一个 tokenhidden state 计算相似度
由于 hvh_v 只能在下一次 inference 时产生,所以需要在 decoding 阶段再推理一次,这一步比较麻烦

用一段代码来描述一下解码过程

  1. 首先在 t-1 时刻输入 seq_len 长度的序列,输出 vocab_size 长度的 logits,选 top-k
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# compute the logits in this step
prev_hidden_states, logits = model.compute_logits_and_hidden_states(input_ids)
_, seqlen, embed_dim = prev_hidden_states.size()
_, _, vocab_size = logits.size()
logit_for_next_step = logits[:,-1,:]
assert logit_for_next_step.size() == torch.Size([1, vocab_size])
# normalize with the softmax function
next_probs = F.softmax(logit_for_next_step, dim = -1)
assert next_probs.size() == logit_for_next_step.size()
# collecte top-k candidate tokens and their logits (model confidence)
_, next_top_k_ids = torch.topk(logit_for_next_step, dim = -1, k = beam_width)
assert next_top_k_ids.size() == torch.Size([1, beam_width])

next_top_k_probs = torch.gather(next_probs, dim = 1, index=next_top_k_ids)
assert next_top_k_probs.size() == next_top_k_ids.size()
  1. top-k 预测的每一种可能追加到原始序列,并在 batch size 维度 concat 成一个 shape = [beam_width, seq_len + 1, dim] 的输入
1
2
3
4
5
6
7
# concatenate each candidate token with prefix
expanded_context = [input_ids for _ in range(beam_width)]
expanded_context = torch.cat(expanded_context, dim = 0)
assert expanded_context.size() == torch.Size([beam_width, seqlen])
top_k_ids = top_k_ids.view(beam_width, 1)
next_input_ids = torch.cat([expanded_context, top_k_ids], dim = -1)
assert next_input_ids.size() == torch.Size([beam_width, seqlen+1])
  1. shape = [beam_width, seq_len + 1, dim] 输入到模型中,计算得到每个 tokenhidden state
1
2
3
4
5
6
7
# feed these candidates into next round to get their hidden states
new_hidden_states, next_logits = model.compute_logits_and_hidden_states(next_input_ids)
assert new_hidden_states.size() == torch.Size([beam_width, seqlen+1, embed_dim])
context_hidden = new_hidden_states[:,:seqlen,:]
assert context_hidden.size() == torch.Size([beam_width, seqlen, embed_dim])
next_hidden = new_hidden_states[:,seqlen:,:]
assert next_hidden.size() == torch.Size([beam_width, 1, embed_dim])
  1. 计算每一个 next hidden statecontext hidden state 之间的相关性,相关性越高,说明重复性越高,惩罚越重,选择此 token 作为 next input id 的可能性越低
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def ranking(context_hidden, next_hidden, next_top_k_ids, next_top_k_probs, alpha):
'''
context_hidden: beam_width x context_len x embed_dim
next_hidden: beam_width x 1 x embed_dim
next_top_k_ids: beam_width x 1
'''
beam_width, context_len, embed_dim = context_hidden.size()
assert next_hidden.size() == torch.Size([beam_width, 1, embed_dim])
norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1,2)).squeeze(-1)
assert cosine_matrix.size() == torch.Size([beam_width, context_len])
scores, _ = torch.max(cosine_matrix, dim = -1)
assert scores.size() == torch.Size([beam_width])
next_top_k_probs = next_top_k_probs.view(-1)
scores = (1.0 - alpha) * next_top_k_probs - alpha * scores
_, selected_idx = torch.topk(scores, k = 1)
assert selected_idx.size() == torch.Size([1])
selected_idx = selected_idx.unsqueeze(0)
assert selected_idx.size() == torch.Size([1,1])
next_id = torch.gather(next_top_k_ids, dim = 0, index=selected_idx)
assert next_id.size() == torch.Size([1,1])
return next_id

Thought

  • 道理上讲得通,但多推理的一次是实实在在的计算代价,总计算量直接翻倍,真的会有大模型选择这种解码方式吗?

URL

TL;DR

  • 本文提出一种比 T5 bias 更简单的 position embedding 方法叫做 ALiBi (Attention with Linear Bias),简单好用
  • 可以在短数据集上训练,在长数据集上测试,即具有外推性

Algorithm

T5 bias

  • 先讲一下 T5 bias 是如何实现 position embedding 的,主要分三步:
    1. 计算 query / keyn * n 相对位置矩阵,形如:
      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      [[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
      [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8],
      [-2, -1, 0, 1, 2, 3, 4, 5, 6, 7],
      [-3, -2, -1, 0, 1, 2, 3, 4, 5, 6],
      [-4, -3, -2, -1, 0, 1, 2, 3, 4, 5],
      [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4],
      [-6, -5, -4, -3, -2, -1, 0, 1, 2, 3],
      [-7, -6, -5, -4, -3, -2, -1, 0, 1, 2],
      [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1],
      [-9, -8, -7, -6, -5, -4, -3, -2, -1, 0]]
    2. 将相对位置矩阵分桶(超过 num_buckets 的饱和到 num_buckets
      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      [[ 0, 17, 18, 19, 20, 21, 22, 23, 24, 24],
      [ 1, 0, 17, 18, 19, 20, 21, 22, 23, 24],
      [ 2, 1, 0, 17, 18, 19, 20, 21, 22, 23],
      [ 3, 2, 1, 0, 17, 18, 19, 20, 21, 22],
      [ 4, 3, 2, 1, 0, 17, 18, 19, 20, 21],
      [ 5, 4, 3, 2, 1, 0, 17, 18, 19, 20],
      [ 6, 5, 4, 3, 2, 1, 0, 17, 18, 19],
      [ 7, 6, 5, 4, 3, 2, 1, 0, 17, 18],
      [ 8, 7, 6, 5, 4, 3, 2, 1, 0, 17],
      [ 8, 8, 7, 6, 5, 4, 3, 2, 1, 0]]

      这里上三角和下三角都有值是因为 encoder bidirection=True,如果是 decoder,则如下:

      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
      [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
      [2, 1, 0, 0, 0, 0, 0, 0, 0, 0],
      [3, 2, 1, 0, 0, 0, 0, 0, 0, 0],
      [4, 3, 2, 1, 0, 0, 0, 0, 0, 0],
      [5, 4, 3, 2, 1, 0, 0, 0, 0, 0],
      [6, 5, 4, 3, 2, 1, 0, 0, 0, 0],
      [7, 6, 5, 4, 3, 2, 1, 0, 0, 0],
      [8, 7, 6, 5, 4, 3, 2, 1, 0, 0],
      [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]]
    3. 最后是将此 n * nrelative position bucket 通过可学习的 embedding 函数变成 n * n * num_heads 的向量,和每个头的 attention score(softmax 之前) 相加,然后通过逐行 softmax 得到 attention weight

ALiBi

alibi_1.png

  • 用数学公式表示:softmax(qiKT+m[(i1),...,2,1,0])softmax(q_iK^T+m\cdot[-(i-1),...,-2,-1,0])
  • ALiBi 的计算和 T5 bias 的前两步几乎一模一样
  • 第三步不再使用可学习的 embedding 函数映射到每个头上,而是将距离矩阵的值和每个头独立的 不可学习的 常量 m 值相乘,然后和 attention score 相加
  • mh=b(2(8/H)b)hm_h = \frac{b}{(2^{(8/H)} \cdot b)^h}
    • b 是一个基数
    • H 是注意力头的数量
    • h 是注意力头的索引(从 0H-1

Thought

  • 标准 attentionpeRn×dpe \in \mathbb{R}^{n\times d} 已经慢慢被淘汰了,不管是 RoPE / T5 Bias / ALiBi 都已经逐渐演变成 peRn×npe \in \mathbb{R}^{n\times n} 直接作用在 attention score 上了
  • ALiBi 的外推性其实本质是强行饱和掉远距离,有点过于粗暴了…

URL

TL;DR

  • GPT like 大语言模型目前在生成式大模型领域占了主导地位,这些模型每一次推理只得到一个 token,下一次推理会将得到的 token 放到上一次输入 token 序列的末尾作为新输入,直到达到 max_seq_length 或输出的 token 表示 end of sequence
  • 这种 GPT like 的生成机制意味着每次推理需要将之前所有 token 再重复算一次,序列长度越长,计算量越大,耗时越长
  • KV Cache 通过缓存上一次推理过程中每一层 transformerkey / value,将推理的时间复杂度从 O(n2)\mathcal{O}(n^2) 降低到了 O(n)\mathcal{O}(n),且计算结果完全等价

Algorithm

  • 可以在本时间步使用上一个时间步缓存的 key / value 的必要理论依据是:
    1. 每个时间步输入的 token 序列除最后一个 token 之外都和上一个时间步输入相同
    2. attention mask 是下三角矩阵
    3. attention score 矩阵做 softmax 是逐行的而不是全局的或逐列的
    4. self-attention 操作之外,GPT like LLM 的其他层(layer norm, gelu, FFN 等)对 seq_len 维度来说都是 element-wise 的,且在推理阶段是静态的
  • 以上四个理论依据在 GPT like 模型中都满足,下面解释四个条件为什么是必不可少的:
    1. 如果上面第一个条件不满足,那么每次推理都是新的内容,缓存将毫无意义
    2. 下三角矩阵和逐行做 softmax 是必不可少的,随着 token 长度加一,每一层的 attention score 矩阵都会在最右和最下分别加一列和一行,由于下三角的存在,除最后一行外,其他每一行的值都不会改变(新加的一列不参与计算),而且逐行做,新加的一列也不影响其他行
    3. self-attentionGPT like LLM 中唯一一个在 token feature 间做信息融合的算子,其他算子都是在 token feature 内做且都是静态操作,所以随着序列长度变长,之前的序列特征都是静态的
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import torch
import torch.nn as nn
class SelfAttention(nn.Module):
def __init__(self, dim, layer_idx):
super(SelfAttention, self).__init__()
self.dim = dim
self.layer_idx = layer_idx
self.query = nn.Linear(dim, dim)
self.key = nn.Linear(dim, dim)
self.value = nn.Linear(dim, dim)
self.softmax = nn.Softmax(dim=-1) # 重点是这里,逐行进行 softmax,而不是全局
self.dropout = nn.Dropout(0.1)
self.layer_norm = nn.LayerNorm(dim)
def forward(self, x, attn_mask=None):
q = self.query(x)
k = self.key(x)
v = self.value(x)
scores = torch.matmul(q, k.transpose(-2, -1))
scores = scores / torch.sqrt(torch.tensor(self.dim).float())
if attn_mask is not None:
scores = scores.masked_fill(attn_mask == 0, float("-inf"))
attention_weights = self.softmax(scores)
attention_weights = self.dropout(attention_weights)
print("layer_idx:", self.layer_idx)
print("q.sum(-1) =", q.squeeze().abs().sum(-1).tolist())
print("k.sum(-1) =", k.squeeze().abs().sum(-1).tolist())
print("v.sum(-1) =", v.squeeze().abs().sum(-1).tolist())
print("attention_weights =\n", attention_weights.squeeze())
print("#" * 100)
output = torch.matmul(attention_weights, v)
output = self.layer_norm(output + x)
return output
class GPTModel(nn.Module):
def __init__(self, dim, num_layers):
super(GPTModel, self).__init__()
self.dim = dim
self.num_layers = num_layers
self.attentions = nn.ModuleList(
[SelfAttention(dim, i) for i in range(num_layers)]
)
self.linear1 = nn.Linear(dim, dim)
self.linear2 = nn.Linear(dim, dim)
self.dropout = nn.Dropout(0.1)
self.layer_norm = nn.LayerNorm(dim)
def forward(self, x, attn_mask=None):
for i in range(self.num_layers):
x = x + self.attentions[i](x, attn_mask=attn_mask)
x = self.layer_norm(x)
x = x + self.linear2(self.dropout(torch.relu(self.linear1(x))))
x = self.layer_norm(x)
return x
# Example usage
model = GPTModel(dim=64, num_layers=2)
model.eval()
input_data = torch.randn(
1, 3, 64
) # Example input data with shape (batch_size, sequence_length, dim)
attn_mask = torch.tril(torch.ones(3, 3)) # Lower triangular attention mask
print("time step 1:")
output = model(input_data, attn_mask=attn_mask)
print("\n\ntime step 2:")
input_data = torch.cat((input_data, torch.randn(1, 1, 64)), 1)
attn_mask = torch.tril(torch.ones(4, 4)) # Lower triangular attention mask
output = model(input_data, attn_mask=attn_mask)
# 输出结果:
# time step 1:
# layer_idx: 0
# q.sum(-1) = [31.20478057861328, 34.04328536987305, 33.6611328125]
# k.sum(-1) = [31.446550369262695, 29.15508460998535, 29.265644073486328]
# v.sum(-1) = [30.789640426635742, 33.80058670043945, 32.4859619140625]
# attention_weights =
# tensor([[1.0000, 0.0000, 0.0000],
# [0.6760, 0.3240, 0.0000],
# [0.2452, 0.4188, 0.3360]], grad_fn=<SqueezeBackward0>)
# ####################################################################################################
# layer_idx: 1
# q.sum(-1) = [30.214162826538086, 34.45262145996094, 26.357181549072266]
# k.sum(-1) = [29.88446617126465, 31.186325073242188, 31.727235794067383]
# v.sum(-1) = [23.99290657043457, 35.66062545776367, 33.28850173950195]
# attention_weights =
# tensor([[1.0000, 0.0000, 0.0000],
# [0.3310, 0.6690, 0.0000],
# [0.4525, 0.3270, 0.2204]], grad_fn=<SqueezeBackward0>)
# ####################################################################################################
# time step 2:
# layer_idx: 0
# q.sum(-1) = [31.20478057861328, 34.04328536987305, 33.6611328125, 27.077472686767578]
# k.sum(-1) = [31.446550369262695, 29.15508460998535, 29.265644073486328, 27.211992263793945]
# v.sum(-1) = [30.789640426635742, 33.80058670043945, 32.4859619140625, 29.695066452026367]
# attention_weights =
# tensor([[1.0000, 0.0000, 0.0000, 0.0000],
# [0.6760, 0.3240, 0.0000, 0.0000],
# [0.2452, 0.4188, 0.3360, 0.0000],
# [0.1818, 0.2616, 0.2813, 0.2752]], grad_fn=<SqueezeBackward0>)
# ####################################################################################################
# layer_idx: 1
# q.sum(-1) = [30.214162826538086, 34.45262145996094, 26.357181549072266, 29.209003448486328]
# k.sum(-1) = [29.88446617126465, 31.186325073242188, 31.727235794067383, 26.988731384277344]
# v.sum(-1) = [23.99290657043457, 35.66062545776367, 33.28850173950195, 27.92920684814453]
# attention_weights =
# tensor([[1.0000, 0.0000, 0.0000, 0.0000],
# [0.3310, 0.6690, 0.0000, 0.0000],
# [0.4525, 0.3270, 0.2204, 0.0000],
# [0.2740, 0.0698, 0.4069, 0.2492]], grad_fn=<SqueezeBackward0>)
# ####################################################################################################

可以看到随着时间步的推移,每一层的 x / q / k / v / attention weight 的前 n-1token feature 都和上一个时间步相同,因此可以缓存

Thought

  • 由于在 t 时间步,需要计算 QtQ^tK0,1,...,t, V0,1,...,tK^{0,1,...,t},\ V^{0,1,...,t} 之间的内积,所以需要换成 key / value 的缓存,而不是 query 的缓存
  • paged attention 是通过类似内存分页管理的方式管理 kv cache 的一种方法,动态分配显存,速度更快,支持更高 batch

URL

TL;DR

  • FlashAttention-2FlashAttention 一作对前一版的更新,基本思想不变,只优化了部分算法流程,可以比 V1 提速一倍

Algorithm

flash_attention_2.png

  • Flash Attention 2 前向过程主要优化了几个方面:
    1. Flash Attention v1 中计算 P_ij 计算过程很绕,需要更新两次全局信息,v2 中进行了优化
    2. Flash Attention v1 中内外层循环设置不合理,v2 对调了内外层循环,这样做的好处是:
    3. 不再需要重复读入写出 O
    4. max_value 信息变成一个计算中间变量,不再需要读入写出,backward 过程也无需依赖
  • Flash Attention v2Flash Attention v1IO 量对比
    • v1
      • 读:2Nd+2TcNd2Nd+2T_cNd
      • 写:TcNd+2TcNT_cNd+2T_cN
    • v2
      • 读:Nd+2TrNdNd+2T_rNd
      • 写:Nd+NNd+N
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
import math
BLOCK_SIZE = 1024
NEG_INF = -1e10 # -infinity
EPSILON = 1e-10
def flash_attention2_forward(Q, K, V):
batch, seq_len, dim = Q.shape
O = torch.zeros_like(Q, requires_grad=True, device=Q.device)
l = torch.zeros((batch, seq_len, 1), device=Q.device) # 1 for broadcasting
m = torch.ones((batch, seq_len, 1), device=Q.device) * NEG_INF
Q = list(torch.split(Q, BLOCK_SIZE, dim=1))
K = list(torch.split(K, BLOCK_SIZE, dim=1))
V = list(torch.split(V, BLOCK_SIZE, dim=1))
O = list(torch.split(O, BLOCK_SIZE, dim=1))
l = list(torch.split(l, BLOCK_SIZE, dim=1))
m = list(torch.split(m, BLOCK_SIZE, dim=1))
Tr = len(Q)
Tc = len(K)
scale = 1 / math.sqrt(dim)
for i in range(Tr):
Qi_scaled = Q[i] * scale
for j in range(Tc):
S_ij = torch.einsum("bnd,bmd->bnm", Qi_scaled, K[j])
m_ij, _ = torch.max(S_ij, dim=-1, keepdim=True)
mi_new = torch.maximum(m_ij, m[i])
P_ij = torch.exp(S_ij - mi_new)
P_ijV_j = torch.einsum("bnm,bmd->bnd", P_ij, V[j])
exp_row_max_diff = torch.exp(m[i] - mi_new)
li_new = l[i] * exp_row_max_diff + P_ij.sum(dim=-1, keepdim=True)
O[i] = O[i] * exp_row_max_diff + P_ijV_j
m[i] = mi_new
l[i] = li_new
O[i] = O[i] / l[i]
O = torch.cat(O, dim=1)
l = torch.cat(l, dim=1)
return O, l
def flash_attention2_backward(Q, K, V, O, l, dO):
# 参考代码中进行了实现,但这里省略,重点关注一下 backward 过程的输入
pass
if __name__ == "__main__":
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# batch, seq_len, dim
Q = torch.randn(2, 4096, 128, requires_grad=True).to(device)
K = torch.randn(2, 4096, 128, requires_grad=True).to(device)
V = torch.randn(2, 4096, 128, requires_grad=True).to(device)
flash2_out = flash_attention2_forward(Q, K, V)
print(flash2_out[0].sum().item())
  • backward 过程在 v2 论文中详细的写出来了,如下:
    flash2_2.png
  • 速度效果提升很明显
    flash2_3.png

Thought

  • Flash Attention v2 看起来特别像 Flash Attention 的紧急修复版本,把之前写的冗余删掉了
  • 甚至完全不了解 Flash Attention 思想的人看 Flash Attention 的实现代码,都可以将其优化到 Flash Attention v2,就是简单的代码优化…
  • Flash Attention v2 彻底把这个 tiling 思路的优势发挥出来了

URL

TL;DR

  • Flash Attention 是标准 Attention 的一种完全数学等价的高效实现
  • 标准 Self-attention Attention(Q,K,V)=softmax(QKTd)VAttention(Q,K,V)=softmax({\frac{QK^T}{\sqrt d}})VO(n2)\mathcal{O}(n^2) 时间复杂度的,随着大模型的发展,序列长度 n 越来越大,带来的时间复杂度和 IO 量越来越恐怖
  • GPUNPU 等计算设备的 SRAM 空间有限,无法放下所有后续依赖的计算中间结果,而 softmax 运算又需要全局信息(无法直接简单 tiling),所以这里存在巨大的 IO 量(HBM (high bandwidth memory)SRAM 之间)
  • 本文设计了一种 tiling 的局部计算 self-attention 的方法,这种方法在 forwardbackward 中都可以使用
  • forward 中,FlashAttention 会计算并存储 softmax 操作的归一化因子(下图中的 lm),在 backward 中,利用这些归一化因子,重新计算所需的中间注意力矩阵,而不需要从 HBM 中读取存储的完整矩阵
  • 上述两点改进可以保证和标准 self-attention 计算结果(包括 forwardbackward)完全一致的情况下,极大降低 IO 量(代价是计算量的小幅上涨)

Algorithm

1. Standard attention implementation

flash_1.png

  • 这里之所以要频繁和 HBMblock 为单位交互是因为有一个限制是 dMNdd\le M \le Nd,即 SRAM 可存储的数据量在 dNd 之间(Nd 分别是 seq_lenmodel_dim
  • IO 总量:
    • 读:2N2+3Nd2N^2+3Nd
      • Q / K / V
      • S / P
    • 写:2N2+Nd2N^2+Nd
      • S / P
      • O

2. Flash attention forward implementation

flash_2.png

  • IO 总量:
    • 读:2Nd+2TcNd2Nd+2T_cNd
      • K / V 一次
      • Q / O TcT_c
    • 写:TcNd+2TcNT_cNd+2T_cN
      • O TcT_c
      • l / m TcT_c
  • 由上述可知:
    • 4Tc4N4\le T_c\le 4N
    • 因此,读写上下限分别为:
    • 读:
      • 上限:2Nd+8N2d2Nd+8N^2d
      • 下限:10Nd10Nd
    • 写:
      • 上限:4N2d+8N24N^2d+8N^2
      • 下限:4Nd+8N4Nd+8N
    • Block size 越大,总 IO 量越小

3. 详解 Flash attention 过程

  • 标准 softmax 过程:softmax(xij)=exijmax(x)i=0Nj=0Nexijmax(x)softmax(x_{ij})=\frac{e^{x_{ij}-max(x)}}{\sum_{i=0}^{N}\sum_{j=0}^{N}e^{x_{ij}-max(x)}},为了计算过程的数值稳定,需要对每个值减去序列最大值
  • Flash attention 中,会将 QKV 都在时序维度切分成若干块,然后在小块之间计算 Attention,直接计算是不等价的,因为需要全局信息,例如:
    • 全局最大值 m(x)
    • 指数累加和 l(x)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
import math
BLOCK_SIZE = 1024
NEG_INF = -1e10 # -infinity
EPSILON = 1e-10
def flash_attention_forward(Q, K, V):
batch, seq_len, dim = Q.shape
O = torch.zeros_like(Q, requires_grad=True, device=Q.device)
l = torch.zeros((batch, seq_len, 1), device=Q.device) # 1 for broadcasting
m = torch.ones((batch, seq_len, 1), device=Q.device) * NEG_INF
Q_BLOCK_SIZE = min(BLOCK_SIZE, dim)
KV_BLOCK_SIZE = BLOCK_SIZE
Q = torch.split(Q, Q_BLOCK_SIZE, dim=1)
K = torch.split(K, KV_BLOCK_SIZE, dim=1)
V = torch.split(V, KV_BLOCK_SIZE, dim=1)
O = list(torch.split(O, Q_BLOCK_SIZE, dim=1))
l = list(torch.split(l, Q_BLOCK_SIZE, dim=1))
m = list(torch.split(m, Q_BLOCK_SIZE, dim=1))
Tr = len(Q)
Tc = len(K)
scale = 1 / math.sqrt(dim)
for j in range(Tc):
for i in range(Tr):
Qi_scaled = Q[i] * scale
S_ij = torch.einsum("... i d, ... j d -> ... i j", Qi_scaled, K[j])
m_ij, _ = torch.max(S_ij, dim=-1, keepdim=True)
P_ij = torch.exp(S_ij - m_ij)
mi_new = torch.maximum(m_ij, m[i])
l_ij = (
torch.sum(P_ij * torch.exp(m_ij - mi_new), dim=-1, keepdim=True)
+ EPSILON
)
P_ij_Vj = torch.einsum(
"... i j, ... j d -> ... i d", P_ij * torch.exp(m_ij - mi_new), V[j]
)
li_new = torch.exp(m[i] - mi_new) * l[i] + l_ij
O[i] = (O[i] * l[i] * torch.exp(m[i] - mi_new) + P_ij_Vj) / li_new
l[i] = li_new
m[i] = mi_new
O = torch.cat(O, dim=1)
l = torch.cat(l, dim=1)
m = torch.cat(m, dim=1)
return O, l, m
def flash_attention_backward(Q, K, V, O, l, m, dO):
# 参考代码中进行了实现,但这里省略,重点关注一下 backward 过程的输入
pass
if __name__ == "__main__":
# batch, seq_len, dim
Q = torch.randn(2, 4096, 128, requires_grad=True).to(device="cuda")
K = torch.randn(2, 4096, 128, requires_grad=True).to(device="cuda")
V = torch.randn(2, 4096, 128, requires_grad=True).to(device="cuda")
out = flash_attention_forward(Q, K, V)
print(out[0].sum().item())
  • 这里最重要也最难理解的一行代码是:O[i] = (O[i] * l[i] * torch.exp(m[i] - mi_new) + P_ij_Vj) / li_new,这个仔细分析一下
    • O[i] * l[i] 是计算得到上次外循环(j-1)时的分子
    • O[i] * l[i] * torch.exp(m[i] - mi_new) 是把 j-1 时刻存储的信息更新到最新(调整 max_value
    • O[i] * l[i] * torch.exp(m[i] - mi_new) + P_ij_Vj 因为 O 的初始值是 0,因此这个公式可以拆解为:Oi=j=1TcPijVjliO_i = \sum_{j=1}^{T_c} \frac{P_{ij} * V_j}{l_i},通过不断更新 max_value 和累加求和得到最终结果

Thought

  • Flash Attention 的效率和 block size 相关,block size 越大,效率提升越明显;在一些 block size 较小的情况下,效率可能会比标准 Attention 更差
  • Flash Attention 目标是解决 Memory hierarchy 情况下标注 Attention 计算访存比较低导致 IO 开销大的问题,是一种以时间换空间的做法;当 SRAM 相比模型足够大时,标准 Attention 效率比 Flash Attention 更高
  • backward 过程也是通过几乎相同的方式计算的,forward 阶段计算得到的中间 l / m 都会用于计算 backward

URL

TL;DR

  • 本质是标准 Multi-Head AttentionMulti-Query Attention 的一个折衷。
    GQA.png
  • 目的是 降低 Attention 过程内存带宽成本,并没有降低计算复杂度,计算复杂度依然是 O(n2)\mathcal{O}(n^2)

Algorithm

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
import torch.nn.functional as F
def multiquery_attention_batched(X, M, mask, P_q, P_k, P_v, P_o, num_groups):
# b: batch size
# n: number of querys
# m: number of keys / values
# d: input dimension
# H: number of heads
# g: number of groups
# h: number of heads per group
# k: key / query dimention
# v: value dimention
assert 0 == P_q.size(0) % num_groups
heads_per_group = P_q.size(0) // num_groups
# 应用线性变换获取查询Q(multi-head)
Q = torch.einsum("bnd,Hdk->bHkn", X, P_q).view(
-1, num_groups, heads_per_group, P_q.size(2), X.size(1)
)
# 应用线性变换获取键K和值V(single-head)
K = torch.einsum("bmd,gdk->bgmk", M, P_k)
V = torch.einsum("bmd,gdv->bgmv", M, P_v)
# 计算注意力得分并应用掩码
logits = torch.einsum("bghkn,bgmk->bghnm", Q, K)
# mask 分组
mask = mask.view(-1, num_groups, heads_per_group, mask.size(-2), mask.size(-1))
weights = F.softmax(logits + mask, dim=-1)
# 计算加权的值
O = torch.einsum("bghnm,bgmv->bghnv", weights, V)
# P_o 分组
P_o = P_o.view(num_groups, heads_per_group, P_o.size(-2), P_o.size(-1))
# 应用输出变换
Y = torch.einsum("bghnv,ghvd->bnd", O, P_o)
return Y
# 假设参数和维度
batch_size = 2
num_queries = 5
num_keys = 7 # 同时也是 num_values
dim = 64
key_dim = 32 # 同时也是 query_dim
value_dim = 48
num_heads = 8
num_groups = 4 # head 分成多少组,必须整除 num_heads
# 随机初始化输入数据和参数
X = torch.randn(batch_size, num_queries, dim)
M = torch.randn(batch_size, num_keys, dim)
mask = torch.randn(batch_size, num_heads, num_queries, num_keys)
P_q = torch.randn(num_heads, dim, key_dim)
P_k = torch.randn(num_groups, dim, key_dim) # 组件 key / value 不共享
P_v = torch.randn(num_groups, dim, value_dim)
P_o = torch.randn(num_heads, value_dim, dim)
# 运行多查询注意力机制
output = multiquery_attention_batched(X, M, mask, P_q, P_k, P_v, P_o, num_groups)
print(output.shape) # 应该输出: torch.Size([2, 5, 64])

Thought

  • GQA 介于 MHAMQA 之间,每个 group 内部实际上是 MQA

URL

TL;DR

  • 本质很简单,就是把标准的 Multi-Head Attention 中的 KeyValue 退化为 Single-HeadQuery 保留 Multi-Head
  • 目的是 降低 Attention 过程内存带宽成本,并没有降低计算复杂度,计算复杂度依然是 O(n2)\mathcal{O}(n^2)

Algorithm

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
import torch.nn.functional as F
def multiquery_attention_batched(X, M, mask, P_q, P_k, P_v, P_o):
# b: batch size
# n: number of querys
# m: number of keys / values
# d: input dimension
# h: number of heads
# k: key / query dimention
# v: value dimention
# 应用线性变换获取查询Q(multi-head)
Q = torch.einsum("bnd,hdk->bhkn", X, P_q)
# 应用线性变换获取键K和值V(single-head)
K = torch.einsum("bmd,dk->bmk", M, P_k)
V = torch.einsum("bmd,dv->bmv", M, P_v)
# 计算注意力得分并应用掩码
logits = torch.einsum("bhkn,bmk->bhnm", Q, K)
weights = F.softmax(logits + mask, dim=-1)
# 计算加权的值
O = torch.einsum("bhnm,bmv->bhnv", weights, V)
# 应用输出变换
Y = torch.einsum("bhnv,hvd->bnd", O, P_o)
return Y
# 假设参数和维度
batch_size = 2
num_queries = 5
num_keys = 7 # 同时也是 num_values
dim = 64
key_dim = 32 # 同时也是 query_dim
value_dim = 48
num_heads = 8
# 随机初始化输入数据和参数
X = torch.randn(batch_size, num_queries, dim)
M = torch.randn(batch_size, num_keys, dim)
mask = torch.randn(batch_size, num_heads, num_queries, num_keys)
P_q = torch.randn(num_heads, dim, key_dim)
P_k = torch.randn(dim, key_dim)
P_v = torch.randn(dim, value_dim)
P_o = torch.randn(num_heads, value_dim, dim)
# 运行多查询注意力机制
output = multiquery_attention_batched(X, M, mask, P_q, P_k, P_v, P_o)
print(output.shape) # 输出: torch.Size([2, 5, 64])

Thought

  • 本质是标准 Attention 的部分 multi-head

LLM / 自动驾驶相关

  1. LN和BN
  2. 介绍RLHF
  3. 介绍MoE和变体
  4. 介绍LoRA和变体
  5. LoRA 参数更新机制
  6. 如何减轻LLM中的幻觉现象?
  7. MLM和MIM的关系和区别?
  8. Stable Diffusion的技术原理
  9. 解決LLM Hallucination的方法
  10. Occupancy预测的出发点是什么?
  11. 介绍RWKV、Mamba和Mamba-2
  12. 2D图像预训练怎么迁移到3D点云任务
  13. 为什么现在的LLM都是Decoder only的架构?
  14. 把Transformer模型训深的问题有哪些?怎么解决
  15. 现在车道线检测的主流的loss是什么?你有哪些想法?
  16. 为什么GAN中经常遇到mode collapse,而Diffusion比较少?

Transformer 相关

  1. 介绍Transformer和ViT
  2. 介绍Transformer的QKV
  3. 介绍Layer Normalization
  4. Transformer训练和部署技巧
  5. 介绍Transformer的位置编码
  6. 介绍自注意力机制和数学公式
  7. 介绍Transformer的Encoder模块
  8. 介绍Transformer的Decoder模块
  9. Transformer和Mamba(SSM)的区别
  10. Transformer中的残差结构以及意义
  11. 为什么Transformer适合多模态任务?
  12. Transformer的并行化体现在哪个地方?
  13. 为什么Transformer一般使用LayerNorm?
  14. Transformer为什么使用多头注意力机制?
  15. Transformer训练的Dropout是如何设定的?

URL

TL;DR

  • HiPPO 全称是 High-order Polynomial Projection Operators,是 SSM (State Space Model) 的开山之作,作者延续 SSM 思路,后续做出了可以和 Transformer 结构掰手腕的 Mamba
  • HiPPO 的目标是用一个有限维的向量来储存这一段 u(t) 的信息,实现方式是将 u(t) 通过 Legendre (勒让德)多项式展开,用有限维的向量存储勒让德多项式系数,且这些 向量的值通过求解勒让德多项式得出,不在训练过程中通过梯度下降更新
  • HiPPO 可以给 RNN 提供一种记忆表示方法,因此一个实际的用处是使用 HiPPO 作为 RNN 的记忆存储算子

Algorithm

  • 勒让德多项式本身是定义在连续函数上的,但实际使用中需要记忆的内容是离散的,因此需要离散化过程
  • HiPPO 的记忆更新公式是 m(t+1)=Am(t)+Bu(t)m(t+1) = Am(t)+Bu(t),其中 ABHiPPO 参数,m(t) 表示记忆向量,u(t) 表示更新向量,m(t+1) 表示更新后的记忆向量
    • ARN×NA\in\mathbb R^{N\times N}
    • BRN×1B\in\mathbb R^{N\times 1}
    • N 表示记忆单元的参数尺度,类似于 Transformerhidden size,越大记忆能力越强
  • HiPPOLegT (Translated Legendre Measure)LegS (Scaled Legendre Measure) 两种度量方法,二者都使用上述记忆更新公式,只是 AB 参数不同
    • LegT 使用 翻译 任务的勒让德多项式,本质是一个滑动窗口,只记忆当前时刻前 θ\theta 窗口内容,θ\theta 为超参数
      • Ank=1θ{(1)nk(2n+1)ifnk2n+1ifn<kA_{nk}=\frac{1}{\theta} \begin{cases} (-1)^{n-k}(2n+1) & if & n\ge k\\ 2n+1 & if & n < k \end{cases}
      • Bn=1θ(2n+1)(1)nB_n=\frac{1}{\theta}(2n+1)(-1)^n
    • LegS 使用 缩放 的勒让德多项式,记忆全部时刻的序列内容
      • Ank=1θ{(2n+1)12(2k+1)12ifn>kn+1ifn=k0ifn<kA_{nk}=\frac{1}{\theta} \begin{cases} (2n+1)^{\frac{1}{2}}(2k+1)^{\frac{1}{2}} & if & n\gt k\\ n+1 & if & n = k\\ 0 & if & n < k \end{cases}
      • Bn=(2n+1)12B_n=(2n+1)^{\frac{1}{2}}
  • Permute Mnist 分类任务的例子讲解 HiPPO 如何作为 RNN 的单元参与计算,以及HiPPO 的记忆单元如何更新
  • Permute Mnist 任务是将 28x28Mnist 图像的每一类按照同一种 pattern 进行 shuffle,训练并分类
  • 下图为使用 HiPPO 作为记忆单元的 RNN 网络解决 Permute Mnist 任务的计算过程,input_t 是每次顺序输入图片的一个像素值,是一个时间步总长为 28 * 28 = 784RNN 网络,最后一个 hidden state 输出映射到 class dim 上进行分类
graph TD
    subgraph input;
    input_t([input_t]);
    h_t;
    end;
    subgraph fully_connect;
    W_hxm;
    W_gxm;
    W_uxh;
    end;
    input_t([input_t])-->|1|Concat_1[Concat];
    h_t([h_t])-->|512|Concat_1-->|513|W_uxh-->|1|u_t([u_t]);
    subgraph update_memory;
    A([A])-->|max_length, 512, 512|get_index_A[get_index]-->|512, 512|A_t;
    timestep([timestep])-->get_index_A;
    A_t([A_t])-->|512, 512|MatMul_A[MatMul];
    m_t([m_t])-->|1, 512|MatMul_A-->|1, 512|Add;
    m_t([m_t])-->|1, 512|Add;
    B([B])-->|max_length, 512|get_index_B[get_index]-->|512, 1|B_t;
    timestep([timestep])-->get_index_B;
    B_t([B_t])-->|512, 1|MatMul_B[MatMul];
    u_t([u_t])-->|1|MatMul_B-->|1, 512|Add-->|1, 512|m_t+1([m_t+1]);
    end;
    m_t+1-->|512|Concat_2;
    input_t-->|1|Concat_2[Concat]-->|513|W_hxm-->|512|Tanh-->|512|hidden([hidden]);
    Concat_2-->|513|W_gxm-->|512|gate([gate]);
    h_t-->|512|Alpha_Blending-->|512|h_t+1([h_t+1])-->|512|until_last_h{until_last_h};
    hidden-->|512|Alpha_Blending;
    gate-->|512|Alpha_Blending;
    subgraph output;
    until_last_h-->|512|map_to_class_dim-->|10|classification_result([classification_result]);
    end;

Thought

  • LegTLegS 的参数计算过程需要较强的数学功底才能完全理解
  • 如果只把 AB 当做 万能的不需要梯度下降更新的神经网络记忆力更新参数,那么实际上并不复杂