import torch
import torch.nn.functional as F

def multiquery_attention_batched(X, M, mask, P_q, P_k, P_v, P_o):
    # Per-head query projections; keys and values use a single shared projection.
    Q = torch.einsum("bnd,hdk->bhkn", X, P_q)   # (batch, heads, key_dim, n_queries)
    K = torch.einsum("bmd,dk->bmk", M, P_k)     # (batch, n_keys, key_dim), shared across heads
    V = torch.einsum("bmd,dv->bmv", M, P_v)     # (batch, n_keys, value_dim), shared across heads
    logits = torch.einsum("bhkn,bmk->bhnm", Q, K)
    weights = F.softmax(logits + mask, dim=-1)  # attention weights over the keys
    O = torch.einsum("bhnm,bmv->bhnv", weights, V)
    Y = torch.einsum("bhnv,hvd->bnd", O, P_o)   # combine heads via per-head output projections
    return Y
batch_size = 2
num_queries = 5
num_keys = 7
dim = 64
key_dim = 32
value_dim = 48
num_heads = 8

X = torch.randn(batch_size, num_queries, dim)
M = torch.randn(batch_size, num_keys, dim)
mask = torch.randn(batch_size, num_heads, num_queries, num_keys)
P_q = torch.randn(num_heads, dim, key_dim)
P_k = torch.randn(dim, key_dim)
P_v = torch.randn(dim, value_dim)
P_o = torch.randn(num_heads, value_dim, dim)

output = multiquery_attention_batched(X, M, mask, P_q, P_k, P_v, P_o)
print(output.shape)
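
# A minimal cross-check, not part of the original listing: the same computation
# written as an explicit loop over heads (the name "multiquery_attention_loop"
# is assumed here), to make the single shared K/V projection of multi-query
# attention visible. It should match the batched einsum version up to
# floating-point tolerance.
def multiquery_attention_loop(X, M, mask, P_q, P_k, P_v, P_o):
    K = torch.einsum("bmd,dk->bmk", M, P_k)  # one K for all heads
    V = torch.einsum("bmd,dv->bmv", M, P_v)  # one V for all heads
    Y = torch.zeros_like(X)
    for h in range(P_q.shape[0]):
        Q_h = torch.einsum("bnd,dk->bnk", X, P_q[h])       # per-head queries
        logits_h = torch.einsum("bnk,bmk->bnm", Q_h, K)
        weights_h = F.softmax(logits_h + mask[:, h], dim=-1)
        O_h = torch.einsum("bnm,bmv->bnv", weights_h, V)
        Y = Y + torch.einsum("bnv,vd->bnd", O_h, P_o[h])   # accumulate per-head outputs
    return Y

print(torch.allclose(output, multiquery_attention_loop(X, M, mask, P_q, P_k, P_v, P_o), atol=1e-4))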