import torch
import torch.nn as nn
import torch.nn.functional as F


class NSAModule(nn.Module):
    """Simplified Native Sparse Attention (NSA) block.

    Combines three attention branches (block-compressed attention,
    selected-block attention, and window attention) with per-head learned gates.
    """

    def __init__(
        self,
        d_model,
        n_heads,
        compress_block=32,
        select_block=64,
        select_topk=16,
        window_size=512,
        dropout=0.1,
    ):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.compress_block = compress_block
        self.select_block = select_block  # kept for interface parity; unused in this simplified version
        self.select_topk = select_topk
        self.window_size = window_size
        # Maps each block of (compress_block * head_dim) key/value features to a single head_dim vector.
        self.compress_mlp = nn.Sequential(
            nn.Linear(self.head_dim * compress_block, 256),
            nn.GELU(),
            nn.Linear(256, self.head_dim),
        )
        # One gate per head per branch (3 branches).
        self.gate_mlp = nn.Sequential(nn.Linear(d_model, 3 * n_heads), nn.Sigmoid())
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def _compress_tokens(self, k, v):
        """Compress the KV sequence to block level."""
        b, t, nh, hd = k.shape
        block_size = self.compress_block
        num_blocks = (t + block_size - 1) // block_size
        pad_len = num_blocks * block_size - t
        # Pad the sequence dimension so it divides evenly into blocks.
        k = F.pad(k, (0, 0, 0, 0, 0, pad_len))
        v = F.pad(v, (0, 0, 0, 0, 0, pad_len))
        k_blocks = k.view(b, num_blocks, block_size, nh, hd)
        v_blocks = v.view(b, num_blocks, block_size, nh, hd)
        # (b, num_blocks, nh, block_size * hd) -> (b, num_blocks, nh, hd)
        k_compressed = self.compress_mlp(k_blocks.permute(0, 1, 3, 2, 4).flatten(3))
        v_compressed = self.compress_mlp(v_blocks.permute(0, 1, 3, 2, 4).flatten(3))
        return k_compressed, v_compressed

    def _select_blocks(self, q, k_compressed, v_compressed):
        """Select important blocks based on attention scores.

        Block importance is averaged over all query positions and heads, so
        every query attends to the same top-k blocks; this keeps the gather
        shapes simple at the cost of per-query selection.
        """
        # (b, t, nh, hd) x (b, num_blocks, nh, hd) -> (b, t, nh, num_blocks)
        scores = torch.einsum("bthd,bkhd->bthk", q, k_compressed) / (
            self.head_dim ** 0.5
        )
        probs = F.softmax(scores, dim=-1)
        block_importance = probs.mean(dim=(1, 2))  # (b, num_blocks)
        topk = min(self.select_topk, block_importance.size(-1))
        _, topk_indices = torch.topk(block_importance, topk, dim=-1)  # (b, topk)
        idx = topk_indices[:, :, None, None].expand(
            -1, -1, k_compressed.size(2), self.head_dim
        )
        k_selected = torch.gather(k_compressed, 1, idx)
        v_selected = torch.gather(v_compressed, 1, idx)
        return k_selected, v_selected

    def forward(self, x, attn_mask=None):
        b, t, d = x.shape
        q = self.q_proj(x).view(b, t, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(b, t, self.n_heads, self.head_dim)
        v = self.v_proj(x).view(b, t, self.n_heads, self.head_dim)

        # Branch 1: block-compressed KV.
        k_compressed, v_compressed = self._compress_tokens(k, v)
        # Branch 2: top-k selected blocks.
        k_selected, v_selected = self._select_blocks(q, k_compressed, v_compressed)
        # Branch 3: the most recent window_size tokens (every query attends to the
        # full window; no per-query sliding or causal masking in this simplified version).
        k_window = k[:, max(0, t - self.window_size):]
        v_window = v[:, max(0, t - self.window_size):]

        # Gate weights per branch and head: (3, b, t, n_heads, 1), broadcast over head_dim.
        gate = (
            self.gate_mlp(x)
            .view(b, t, 3, self.n_heads)
            .permute(2, 0, 1, 3)
            .unsqueeze(-1)
        )

        attn_outputs = []
        for branch_k, branch_v in [
            (k_compressed, v_compressed),
            (k_selected, v_selected),
            (k_window, v_window),
        ]:
            scores = torch.einsum("bthd,bkhd->bthk", q, branch_k) / (
                self.head_dim ** 0.5
            )
            if attn_mask is not None:
                # attn_mask must be broadcastable to (b, t, n_heads, K) for each branch.
                scores = scores.masked_fill(attn_mask == 0, -1e9)
            probs = F.softmax(scores, dim=-1)
            probs = self.dropout(probs)
            output = torch.einsum("bthk,bkhd->bthd", probs, branch_v)
            attn_outputs.append(output)

        # Gated sum of the three branch outputs.
        weighted = sum(g * o for g, o in zip(gate, attn_outputs))
        output = weighted.contiguous().view(b, t, d)
        return self.out_proj(output)


def test_nsa_attention_shapes():
    batch_size = 2
    seq_len = 128
    d_model = 256
    n_heads = 8
    nsa_attn = NSAModule(d_model, n_heads)
    x = torch.randn(batch_size, seq_len, d_model)
    print("Input shape:", x.shape)
    q = nsa_attn.q_proj(x).view(batch_size, seq_len, n_heads, -1)
    k = nsa_attn.k_proj(x).view(batch_size, seq_len, n_heads, -1)
    v = nsa_attn.v_proj(x).view(batch_size, seq_len, n_heads, -1)
    print("\nQ/K/V shape:", q.shape)
    k_comp, v_comp = nsa_attn._compress_tokens(k, v)
    print("Compressed KV shape:", k_comp.shape)


if __name__ == "__main__":
    test_nsa_attention_shapes()
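

# A minimal end-to-end sketch, assuming the NSAModule defined above: run a full
# forward pass (all three branches plus gating) and confirm the output keeps the
# (batch, seq_len, d_model) shape. The hyperparameters here are illustrative only.
def test_nsa_full_forward():
    torch.manual_seed(0)
    model = NSAModule(d_model=256, n_heads=8).eval()
    x = torch.randn(2, 128, 256)
    with torch.no_grad():
        out = model(x)
    print("Full forward output shape:", out.shape)  # expected: torch.Size([2, 128, 256])
    assert out.shape == x.shape


if __name__ == "__main__":
    test_nsa_full_forward()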