Zhangzhe's Blog

The projection of my life.

0%

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

URL

TL;DR

  • Flash Attention 是标准 Attention 的一种完全数学等价的高效实现
  • 标准 Self-attention Attention(Q,K,V)=softmax(QKTd)VAttention(Q,K,V)=softmax({\frac{QK^T}{\sqrt d}})VO(n2)\mathcal{O}(n^2) 时间复杂度的,随着大模型的发展,序列长度 n 越来越大,带来的时间复杂度和 IO 量越来越恐怖
  • GPUNPU 等计算设备的 SRAM 空间有限,无法放下所有后续依赖的计算中间结果,而 softmax 运算又需要全局信息(无法直接简单 tiling),所以这里存在巨大的 IO 量(HBM (high bandwidth memory)SRAM 之间)
  • 本文设计了一种 tiling 的局部计算 self-attention 的方法,这种方法在 forwardbackward 中都可以使用
  • forward 中,FlashAttention 会计算并存储 softmax 操作的归一化因子(下图中的 lm),在 backward 中,利用这些归一化因子,重新计算所需的中间注意力矩阵,而不需要从 HBM 中读取存储的完整矩阵
  • 上述两点改进可以保证和标准 self-attention 计算结果(包括 forwardbackward)完全一致的情况下,极大降低 IO 量(代价是计算量的小幅上涨)

Algorithm

1. Standard attention implementation

flash_1.png

  • 这里之所以要频繁和 HBMblock 为单位交互是因为有一个限制是 dMNdd\le M \le Nd,即 SRAM 可存储的数据量在 dNd 之间(Nd 分别是 seq_lenmodel_dim
  • IO 总量:
    • 读:2N2+3Nd2N^2+3Nd
      • Q / K / V
      • S / P
    • 写:2N2+Nd2N^2+Nd
      • S / P
      • O

2. Flash attention forward implementation

flash_2.png

  • IO 总量:
    • 读:2Nd+2TcNd2Nd+2T_cNd
      • K / V 一次
      • Q / O TcT_c
    • 写:TcNd+2TcNT_cNd+2T_cN
      • O TcT_c
      • l / m TcT_c
  • 由上述可知:
    • 4Tc4N4\le T_c\le 4N
    • 因此,读写上下限分别为:
    • 读:
      • 上限:2Nd+8N2d2Nd+8N^2d
      • 下限:10Nd10Nd
    • 写:
      • 上限:4N2d+8N24N^2d+8N^2
      • 下限:4Nd+8N4Nd+8N
    • Block size 越大,总 IO 量越小

3. 详解 Flash attention 过程

  • 标准 softmax 过程:softmax(xij)=exijmax(x)i=0Nj=0Nexijmax(x)softmax(x_{ij})=\frac{e^{x_{ij}-max(x)}}{\sum_{i=0}^{N}\sum_{j=0}^{N}e^{x_{ij}-max(x)}},为了计算过程的数值稳定,需要对每个值减去序列最大值
  • Flash attention 中,会将 QKV 都在时序维度切分成若干块,然后在小块之间计算 Attention,直接计算是不等价的,因为需要全局信息,例如:
    • 全局最大值 m(x)
    • 指数累加和 l(x)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
import math
BLOCK_SIZE = 1024
NEG_INF = -1e10 # -infinity
EPSILON = 1e-10
def flash_attention_forward(Q, K, V):
batch, seq_len, dim = Q.shape
O = torch.zeros_like(Q, requires_grad=True, device=Q.device)
l = torch.zeros((batch, seq_len, 1), device=Q.device) # 1 for broadcasting
m = torch.ones((batch, seq_len, 1), device=Q.device) * NEG_INF
Q_BLOCK_SIZE = min(BLOCK_SIZE, dim)
KV_BLOCK_SIZE = BLOCK_SIZE
Q = torch.split(Q, Q_BLOCK_SIZE, dim=1)
K = torch.split(K, KV_BLOCK_SIZE, dim=1)
V = torch.split(V, KV_BLOCK_SIZE, dim=1)
O = list(torch.split(O, Q_BLOCK_SIZE, dim=1))
l = list(torch.split(l, Q_BLOCK_SIZE, dim=1))
m = list(torch.split(m, Q_BLOCK_SIZE, dim=1))
Tr = len(Q)
Tc = len(K)
scale = 1 / math.sqrt(dim)
for j in range(Tc):
for i in range(Tr):
Qi_scaled = Q[i] * scale
S_ij = torch.einsum("... i d, ... j d -> ... i j", Qi_scaled, K[j])
m_ij, _ = torch.max(S_ij, dim=-1, keepdim=True)
P_ij = torch.exp(S_ij - m_ij)
mi_new = torch.maximum(m_ij, m[i])
l_ij = (
torch.sum(P_ij * torch.exp(m_ij - mi_new), dim=-1, keepdim=True)
+ EPSILON
)
P_ij_Vj = torch.einsum(
"... i j, ... j d -> ... i d", P_ij * torch.exp(m_ij - mi_new), V[j]
)
li_new = torch.exp(m[i] - mi_new) * l[i] + l_ij
O[i] = (O[i] * l[i] * torch.exp(m[i] - mi_new) + P_ij_Vj) / li_new
l[i] = li_new
m[i] = mi_new
O = torch.cat(O, dim=1)
l = torch.cat(l, dim=1)
m = torch.cat(m, dim=1)
return O, l, m
def flash_attention_backward(Q, K, V, O, l, m, dO):
# 参考代码中进行了实现,但这里省略,重点关注一下 backward 过程的输入
pass
if __name__ == "__main__":
# batch, seq_len, dim
Q = torch.randn(2, 4096, 128, requires_grad=True).to(device="cuda")
K = torch.randn(2, 4096, 128, requires_grad=True).to(device="cuda")
V = torch.randn(2, 4096, 128, requires_grad=True).to(device="cuda")
out = flash_attention_forward(Q, K, V)
print(out[0].sum().item())
  • 这里最重要也最难理解的一行代码是:O[i] = (O[i] * l[i] * torch.exp(m[i] - mi_new) + P_ij_Vj) / li_new,这个仔细分析一下
    • O[i] * l[i] 是计算得到上次外循环(j-1)时的分子
    • O[i] * l[i] * torch.exp(m[i] - mi_new) 是把 j-1 时刻存储的信息更新到最新(调整 max_value
    • O[i] * l[i] * torch.exp(m[i] - mi_new) + P_ij_Vj 因为 O 的初始值是 0,因此这个公式可以拆解为:Oi=j=1TcPijVjliO_i = \sum_{j=1}^{T_c} \frac{P_{ij} * V_j}{l_i},通过不断更新 max_value 和累加求和得到最终结果

Thought

  • Flash Attention 的效率和 block size 相关,block size 越大,效率提升越明显;在一些 block size 较小的情况下,效率可能会比标准 Attention 更差
  • Flash Attention 目标是解决 Memory hierarchy 情况下标注 Attention 计算访存比较低导致 IO 开销大的问题,是一种以时间换空间的做法;当 SRAM 相比模型足够大时,标准 Attention 效率比 Flash Attention 更高
  • backward 过程也是通过几乎相同的方式计算的,forward 阶段计算得到的中间 l / m 都会用于计算 backward