
FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning

URL

TL;DR

  • FlashAttention-2 is an update to FlashAttention by the same first author. The core idea is unchanged; only parts of the algorithm flow are optimized, which makes it roughly twice as fast as V1.

Algorithm

flash_attention_2.png

  • The Flash Attention 2 forward pass mainly optimizes a few things (the per-block update rules are sketched after the next list):

    1. In Flash Attention v1, the computation of P_ij was convoluted and required updating the global statistics twice; v2 streamlines this step.
    2. In Flash Attention v1, the inner and outer loops were arranged poorly; v2 swaps them (outer loop over Q blocks, inner loop over K/V blocks), which has two benefits:
      • O no longer needs to be repeatedly read in and written out
      • max_value becomes a purely intermediate variable in the computation, so it no longer needs to be read in or written out, and the backward pass does not depend on it
  • IO volume comparison between Flash Attention v2 and Flash Attention v1 ($N$ = sequence length, $d$ = head dimension, $T_r$ / $T_c$ = number of Q / K,V blocks); a small worked example follows this list:

    • v1
      • reads: $2Nd + 2T_cNd$
      • writes: $T_cNd + 2T_cN$
    • v2
      • reads: $Nd + 2T_rNd$
      • writes: $Nd + N$
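To make the first point concrete, here is a sketch of the per-block update in v2, transcribed from the code below rather than quoted from the paper (for query block $i$ and key/value block $j$, with running statistics $m_i$, $l_i$):

$$
\begin{aligned}
S_{ij} &= Q_i K_j^\top / \sqrt{d} \\
m_i^{\text{new}} &= \max\!\big(m_i,\ \mathrm{rowmax}(S_{ij})\big) \\
P_{ij} &= \exp(S_{ij} - m_i^{\text{new}}) \\
l_i &\leftarrow e^{\,m_i - m_i^{\text{new}}}\, l_i + \mathrm{rowsum}(P_{ij}) \\
O_i &\leftarrow e^{\,m_i - m_i^{\text{new}}}\, O_i + P_{ij} V_j
\end{aligned}
$$

Only after the inner loop over $j$ finishes is the output normalized once, $O_i \leftarrow O_i / l_i$, so $m_i$ never has to leave on-chip memory.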
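As a rough worked example of these IO formulas (pure element counts, ignoring datatype size and constant factors; the shapes are assumptions chosen to match the demo at the bottom of this post):

import math

# assumed demo shapes, matching the test code below
N, d, block = 4096, 128, 1024
Tr = math.ceil(N / block)  # number of Q (row) blocks
Tc = math.ceil(N / block)  # number of K/V (column) blocks

# v1: outer loop over K/V blocks, inner loop over Q blocks
v1_reads = 2 * N * d + 2 * Tc * N * d  # K,V once + Q,O re-read every outer step
v1_writes = Tc * N * d + 2 * Tc * N    # O re-written every outer step + l, m

# v2: outer loop over Q blocks, inner loop over K/V blocks
v2_reads = N * d + 2 * Tr * N * d      # Q once + K,V re-read every outer step
v2_writes = N * d + N                  # O and l written exactly once

print(f"v1 reads {v1_reads:,} elements, writes {v1_writes:,}")
print(f"v2 reads {v2_reads:,} elements, writes {v2_writes:,}")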
import torch
import math

BLOCK_SIZE = 1024
NEG_INF = -1e10  # -infinity
EPSILON = 1e-10


def flash_attention2_forward(Q, K, V):
    batch, seq_len, dim = Q.shape
    O = torch.zeros_like(Q, requires_grad=True, device=Q.device)
    l = torch.zeros((batch, seq_len, 1), device=Q.device)  # 1 for broadcasting
    m = torch.ones((batch, seq_len, 1), device=Q.device) * NEG_INF

    # split Q/K/V and the running statistics into blocks along the sequence dim
    Q = list(torch.split(Q, BLOCK_SIZE, dim=1))
    K = list(torch.split(K, BLOCK_SIZE, dim=1))
    V = list(torch.split(V, BLOCK_SIZE, dim=1))
    O = list(torch.split(O, BLOCK_SIZE, dim=1))
    l = list(torch.split(l, BLOCK_SIZE, dim=1))
    m = list(torch.split(m, BLOCK_SIZE, dim=1))

    Tr = len(Q)  # number of Q (row) blocks
    Tc = len(K)  # number of K/V (column) blocks
    scale = 1 / math.sqrt(dim)
    for i in range(Tr):  # outer loop over Q blocks (swapped relative to v1)
        Qi_scaled = Q[i] * scale
        for j in range(Tc):  # inner loop over K/V blocks
            S_ij = torch.einsum("bnd,bmd->bnm", Qi_scaled, K[j])
            # running row max and unnormalized softmax of the current block
            m_ij, _ = torch.max(S_ij, dim=-1, keepdim=True)
            mi_new = torch.maximum(m_ij, m[i])
            P_ij = torch.exp(S_ij - mi_new)
            P_ijV_j = torch.einsum("bnm,bmd->bnd", P_ij, V[j])
            # rescale the running statistics by exp(old_max - new_max)
            exp_row_max_diff = torch.exp(m[i] - mi_new)
            li_new = l[i] * exp_row_max_diff + P_ij.sum(dim=-1, keepdim=True)
            O[i] = O[i] * exp_row_max_diff + P_ijV_j
            m[i] = mi_new
            l[i] = li_new
        # normalize only once per Q block, after the inner loop finishes
        O[i] = O[i] / l[i]

    O = torch.cat(O, dim=1)
    l = torch.cat(l, dim=1)

    return O, l


def flash_attention2_backward(Q, K, V, O, l, dO):
    # Implemented in the reference code but omitted here; the point to note
    # is the set of inputs the backward pass needs: Q, K, V, O, l, dO.
    pass


if __name__ == "__main__":
    torch.random.manual_seed(0)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # batch, seq_len, dim
    Q = torch.randn(2, 4096, 128, requires_grad=True).to(device)
    K = torch.randn(2, 4096, 128, requires_grad=True).to(device)
    V = torch.randn(2, 4096, 128, requires_grad=True).to(device)

    flash2_out = flash_attention2_forward(Q, K, V)
    print(flash2_out[0].sum().item())
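A quick way to sanity-check the forward pass (my own test sketch, not part of the original reference code) is to append the following to the script above, reusing the Q, K, V already defined, and compare against a plain PyTorch attention:

    def reference_attention(Q, K, V):
        # plain softmax(Q K^T / sqrt(d)) V, materializing the full N x N matrix
        S = torch.einsum("bnd,bmd->bnm", Q, K) / math.sqrt(Q.shape[-1])
        return torch.einsum("bnm,bmd->bnd", torch.softmax(S, dim=-1), V)

    O_flash, _ = flash_attention2_forward(Q, K, V)
    O_ref = reference_attention(Q, K, V)
    # block-wise fp32 accumulation differs slightly from the fused softmax
    print(torch.allclose(O_flash, O_ref, atol=1e-4))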
  • The backward pass is written out in detail in the v2 paper, as shown below:

flash2_2.png
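Because the block-wise forward above is written entirely in differentiable PyTorch ops, one way to check that it is gradient-equivalent to standard attention, without implementing the tiled backward, is to compare autograd gradients against the reference. This is only a numerical check of my own (it reuses flash_attention2_forward and reference_attention from above), not the recomputation-based backward from the paper:

import torch

Q = torch.randn(1, 2048, 64, requires_grad=True)
K = torch.randn(1, 2048, 64, requires_grad=True)
V = torch.randn(1, 2048, 64, requires_grad=True)

O_flash, _ = flash_attention2_forward(Q, K, V)
O_ref = reference_attention(Q, K, V)

g = torch.randn_like(O_flash)  # arbitrary upstream gradient dO
dQ1, dK1, dV1 = torch.autograd.grad(O_flash, (Q, K, V), g)
dQ2, dK2, dV2 = torch.autograd.grad(O_ref, (Q, K, V), g)
print(torch.allclose(dQ1, dQ2, atol=1e-4),
      torch.allclose(dK1, dK2, atol=1e-4),
      torch.allclose(dV1, dV2, atol=1e-4))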

  • The speed improvement is very noticeable:

flash2_3.png

Thought

  • Flash Attention v2 reads a lot like a hotfix release of Flash Attention, simply deleting the redundancy that was there before
  • Even someone with no understanding of the Flash Attention idea at all could, just by reading the Flash Attention implementation code, optimize it into Flash Attention v2; it is really just straightforward code optimization…
  • Flash Attention v2 is what finally brings out the full advantage of the tiling approach