URL
- paper: https://arxiv.org/pdf/2205.14135
- code: https://github.com/shreyansh26/FlashAttention-PyTorch/blob/master/flash_attention.py
TL;DR
Flash Attention是标准Attention的一种完全数学等价的高效实现- 标准
Self-attention是 时间复杂度的,随着大模型的发展,序列长度n越来越大,带来的时间复杂度和IO量越来越恐怖 GPU或NPU等计算设备的SRAM空间有限,无法放下所有后续依赖的计算中间结果,而softmax运算又需要全局信息(无法直接简单tiling),所以这里存在巨大的IO量(HBM (high bandwidth memory)和SRAM之间)- 本文设计了一种
tiling的局部计算self-attention的方法,这种方法在forward和backward中都可以使用 - 在
forward中,FlashAttention会计算并存储softmax操作的归一化因子(下图中的l和m),在backward中,利用这些归一化因子,重新计算所需的中间注意力矩阵,而不需要从HBM中读取存储的完整矩阵 - 上述两点改进可以保证和标准
self-attention计算结果(包括forward和backward)完全一致的情况下,极大降低IO量(代价是计算量的小幅上涨)
Algorithm

1. Standard attention implementation

- 这里之所以要频繁和
HBM以block为单位交互是因为有一个限制是 ,即SRAM可存储的数据量在d和Nd之间(N和d分别是seq_len和model_dim) IO总量:- 读:
- 读
Q / K / V - 读
S / P
- 读
- 写:
- 写
S / P - 写
O
- 写
- 读:
2. Flash attention forward implementation

IO总量:- 读:
- 读
K / V一次 - 读
Q / O次
- 读
- 写:
- 写
O次 - 写
l / m次
- 写
- 读:
- 由上述可知:
- 表示
Q的分块数, 表示K和V的分块数 - 因此,读写上下限分别为:
- 读:
- 上限:
- 下限:
- 写:
- 上限:
- 下限:
Block size越大,总IO量越小
- 表示
3. 详解 Flash attention 过程
- 标准
softmax过程:,为了计算过程的数值稳定,需要对每个值减去序列最大值 Flash attention中,会将QKV都在时序维度切分成若干块,然后在小块之间计算Attention,直接计算是不等价的,因为需要全局信息,例如:- 全局最大值
m(x) - 指数累加和
l(x)
- 全局最大值
1 | import torch |
- 这里最重要也最难理解的一行代码是:
O[i] = (O[i] * l[i] * torch.exp(m[i] - mi_new) + P_ij_Vj) / li_new,这个仔细分析一下O[i] * l[i]是计算得到上次外循环(j-1)时的分子O[i] * l[i] * torch.exp(m[i] - mi_new)是把j-1时刻存储的信息更新到最新(调整max_value)O[i] * l[i] * torch.exp(m[i] - mi_new) + P_ij_Vj因为O的初始值是0,因此这个公式可以拆解为:,通过不断更新max_value和累加求和得到最终结果
Thought
Flash Attention的效率和block size相关,block size越大,效率提升越明显;在一些block size较小的情况下,效率可能会比标准Attention更差Flash Attention目标是解决Memory hierarchy情况下标注Attention计算访存比较低导致IO开销大的问题,是一种以计算换空间的做法;当SRAM相比模型足够大时,标准Attention效率比Flash Attention更高backward过程也是通过几乎相同的方式计算的,forward阶段计算得到的中间l / m都会用于计算backward















