import torch
import math

BLOCK_SIZE = 1024   # rows/columns per tile
NEG_INF = -1e10     # stand-in for -inf when initializing the running row max
EPSILON = 1e-10     # small constant for numerical safety (unused in this forward pass)


def flash_attention2_forward(Q, K, V):
    """Blockwise (FlashAttention-2 style) forward pass of O = softmax(QK^T / sqrt(d)) V,
    computed tile by tile with an online softmax so the full score matrix is never materialized."""
    batch, seq_len, dim = Q.shape

    # Output accumulator, running row sum l, and running row max m.
    O = torch.zeros_like(Q)
    l = torch.zeros((batch, seq_len, 1), device=Q.device)
    m = torch.full((batch, seq_len, 1), NEG_INF, device=Q.device)

    # Split every tensor into blocks of BLOCK_SIZE rows along the sequence dimension.
    Q = list(torch.split(Q, BLOCK_SIZE, dim=1))
    K = list(torch.split(K, BLOCK_SIZE, dim=1))
    V = list(torch.split(V, BLOCK_SIZE, dim=1))
    O = list(torch.split(O, BLOCK_SIZE, dim=1))
    l = list(torch.split(l, BLOCK_SIZE, dim=1))
    m = list(torch.split(m, BLOCK_SIZE, dim=1))

    Tr = len(Q)   # number of query (row) blocks
    Tc = len(K)   # number of key/value (column) blocks
    scale = 1 / math.sqrt(dim)

    for i in range(Tr):
        Qi_scaled = Q[i] * scale
        for j in range(Tc):
            # Scores for this tile: (batch, Br, Bc).
            S_ij = torch.einsum("bnd,bmd->bnm", Qi_scaled, K[j])

            # Online softmax: update the running row max and exponentiate against it.
            m_ij, _ = torch.max(S_ij, dim=-1, keepdim=True)
            mi_new = torch.maximum(m_ij, m[i])
            P_ij = torch.exp(S_ij - mi_new)
            P_ijV_j = torch.einsum("bnm,bmd->bnd", P_ij, V[j])

            # Rescale the previously accumulated sum and output by exp(old_max - new_max).
            exp_row_max_diff = torch.exp(m[i] - mi_new)
            li_new = l[i] * exp_row_max_diff + P_ij.sum(dim=-1, keepdim=True)

            # FlashAttention-2 keeps O unnormalized inside the inner loop.
            O[i] = O[i] * exp_row_max_diff + P_ijV_j
            m[i] = mi_new
            l[i] = li_new

        # Normalize once per row block, after all column blocks have been processed.
        O[i] = O[i] / l[i]

    O = torch.cat(O, dim=1)
    l = torch.cat(l, dim=1)
    return O, l


def flash_attention2_backward(Q, K, V, O, l, dO):
    # Placeholder for the FlashAttention-2 backward pass.
    pass


if __name__ == "__main__":
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Create leaf tensors directly on the target device so gradients would accumulate
    # on them; randn(...).to(device) would return non-leaf copies instead.
    Q = torch.randn(2, 4096, 128, device=device, requires_grad=True)
    K = torch.randn(2, 4096, 128, device=device, requires_grad=True)
    V = torch.randn(2, 4096, 128, device=device, requires_grad=True)

    flash2_out = flash_attention2_forward(Q, K, V)
    print(flash2_out[0].sum().item())
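
# A minimal sanity-check sketch (not part of the listing above): the blockwise result
# should match a naive softmax(QK^T / sqrt(d)) V reference up to floating-point
# reordering. `naive_attention` and the tolerance below are assumptions made for
# illustration; note that the reference materializes the full 4096 x 4096 score
# matrix per batch element, which is exactly what the blockwise version avoids.

def naive_attention(Q, K, V):
    scale = 1 / math.sqrt(Q.shape[-1])
    S = torch.einsum("bnd,bmd->bnm", Q, K) * scale   # full score matrix
    P = torch.softmax(S, dim=-1)
    return torch.einsum("bnm,bmd->bnd", P, V)

# Example usage, e.g. appended to the __main__ block above:
#     with torch.no_grad():
#         ref = naive_attention(Q, K, V)
#         print(torch.allclose(flash2_out[0], ref, atol=1e-4))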