URL
- paper: https://arxiv.org/pdf/2205.14135
- code: https://github.com/shreyansh26/FlashAttention-PyTorch/blob/master/flash_attention.py
TL;DR
Flash Attention
是标准Attention
的一种完全数学等价的高效实现- 标准
Self-attention
是 时间复杂度的,随着大模型的发展,序列长度n
越来越大,带来的时间复杂度和IO
量越来越恐怖 GPU
或NPU
等计算设备的SRAM
空间有限,无法放下所有后续依赖的计算中间结果,而softmax
运算又需要全局信息(无法直接简单tiling
),所以这里存在巨大的IO
量(HBM (high bandwidth memory)
和SRAM
之间)- 本文设计了一种
tiling
的局部计算self-attention
的方法,这种方法在forward
和backward
中都可以使用 - 在
forward
中,FlashAttention
会计算并存储softmax
操作的归一化因子(下图中的l
和m
),在backward
中,利用这些归一化因子,重新计算所需的中间注意力矩阵,而不需要从HBM
中读取存储的完整矩阵 - 上述两点改进可以保证和标准
self-attention
计算结果(包括forward
和backward
)完全一致的情况下,极大降低IO
量(代价是计算量的小幅上涨)
Algorithm
1. Standard attention implementation
- 这里之所以要频繁和
HBM
以block
为单位交互是因为有一个限制是 ,即SRAM
可存储的数据量在d
和Nd
之间 IO
总量:- 读:
- 读
Q / K / V
- 读
S / P
- 读
- 写:
- 写
S / P
- 写
O
- 写
- 读:
2. Flash attention forward implementation
-
IO
总量:- 读:
- 读
K / V
一次 - 读
Q / O
次
- 读
- 写:
- 写
O
次 - 写
l / m
次
- 写
- 读:
-
由上述可知:
- 因此,读写上下限分别为:
- 读:
- 上限:
- 下限:
- 写:
- 上限:
- 下限:
Block size
越大,总IO
量越小
3. 详解 Flash attention
过程
- 标准
softmax
过程:,为了计算过程的数值稳定,需要对每个值减去序列最大值 Flash attention
中,会将QKV
都在时序维度切分成若干块,然后在小块之间计算Attention
,直接计算是不等价的,因为需要全局信息,例如:- 全局最大值
m(x)
- 指数累加和
l(x)
- 全局最大值
1 | import torch |
- 这里最重要也最难理解的一行代码是:
O[i] = (O[i] * l[i] * torch.exp(m[i] - mi_new) + P_ij_Vj) / li_new
,这个仔细分析一下O[i] * l[i]
是计算得到上次外循环(j-1
)时的分子O[i] * l[i] * torch.exp(m[i] - mi_new)
是把j-1
时刻存储的信息更新到最新(调整max_value
)O[i] * l[i] * torch.exp(m[i] - mi_new) + P_ij_Vj
因为O
的初始值是0
,因此这个公式可以拆解为:,通过不断更新max_value
和累加求和得到最终结果
Thought
Flash Attention
的效率和block size
相关,block size
越大,效率提升越明显;在一些block size
较小的情况下,效率可能会比标准Attention
更差Flash Attention
目标是解决Memory hierarchy
情况下标注Attention
计算访存比较低导致IO
开销大的问题,是一种以时间换空间的做法;当SRAM
相比模型足够大时,标准Attention
效率比Flash Attention
更高backward
过程也是通过几乎相同的方式计算的,forward
阶段计算得到的中间l / m
都会用于计算backward