
Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention

URL

TL;DR

  • This paper proposes a new sparse attention mechanism, Native Sparse Attention (NSA), which is hardware-aligned and natively trainable end to end, without requiring an extra sparsification step on top of a dense model
  • There is no official open-source code yet, so the method can only be reimplemented from the formulas and descriptions in the paper

Algorithm

Overall Flow

(Figure: nsa.png)

  • At its core, NSA replaces full attention with a combination of sparse attention patterns (block compression, block selection, and a sliding window) to reduce computation; the gated combination of the three branches is sketched below
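
In the paper's notation (approximately), each query q_t attends separately to three key/value sets (block-compressed cmp, top-k selected slc, and sliding-window win), and the branch outputs are mixed by learned gates g_t^c predicted from the query token:

o_t = \sum_{c \in \{\mathrm{cmp},\, \mathrm{slc},\, \mathrm{win}\}} g_t^{c} \cdot \mathrm{Attn}\left(q_t,\ \tilde{K}_t^{c},\ \tilde{V}_t^{c}\right)

where \tilde{K}_t^{c}, \tilde{V}_t^{c} are the keys and values visible to branch c at position t.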

Code Implementation

  • Implemented by DeepSeek-R1 from the formulas and descriptions in the paper, so it is a reference sketch rather than the official implementation
import torch
import torch.nn as nn
import torch.nn.functional as F


class NSAModule(nn.Module):
    def __init__(
        self,
        d_model,
        n_heads,
        compress_block=32,
        select_block=64,
        select_topk=16,
        window_size=512,
        dropout=0.1,
    ):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        # Hyper-parameters of the three branches
        # (select_block is kept for parity with the paper but unused in this simplified sketch)
        self.compress_block = compress_block
        self.select_block = select_block
        self.select_topk = select_topk
        self.window_size = window_size
        # MLP that compresses each KV block into a single token
        self.compress_mlp = nn.Sequential(
            nn.Linear(self.head_dim * compress_block, 256),
            nn.GELU(),
            nn.Linear(256, self.head_dim),
        )
        # Gating network: one gate per head for each of the three branches
        self.gate_mlp = nn.Sequential(nn.Linear(d_model, 3 * n_heads), nn.Sigmoid())
        # Projection layers
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def _compress_tokens(self, k, v):
        """Compress the KV sequence to block level."""
        # Input layout: (batch, seq_len, n_heads, head_dim)
        b, t, nh, hd = k.shape
        block_size = self.compress_block
        num_blocks = (t + block_size - 1) // block_size
        pad_len = num_blocks * block_size - t
        # Pad the sequence dimension so it splits evenly into blocks
        k = F.pad(k, (0, 0, 0, 0, 0, pad_len))
        v = F.pad(v, (0, 0, 0, 0, 0, pad_len))
        # Reshape to [batch, num_blocks, block_size, n_heads, head_dim]
        k_blocks = k.view(b, num_blocks, block_size, nh, hd)
        v_blocks = v.view(b, num_blocks, block_size, nh, hd)
        # Compress each block per head: flatten (block_size, head_dim) and run the MLP
        k_compressed = self.compress_mlp(
            k_blocks.permute(0, 1, 3, 2, 4).flatten(3)
        )  # [b, num_blocks, nh, hd]
        v_compressed = self.compress_mlp(v_blocks.permute(0, 1, 3, 2, 4).flatten(3))
        return k_compressed, v_compressed

    def _select_blocks(self, q, k_compressed, v_compressed):
        """Select the most relevant blocks based on compressed attention scores."""
        # Attention scores against the compressed keys: [b, t, nh, num_blocks]
        scores = torch.einsum("bthd,bkhd->bthk", q, k_compressed) / (
            self.head_dim ** 0.5
        )
        probs = F.softmax(scores, dim=-1)
        # Average scores over queries and heads so one block set is shared by the
        # whole sequence (a simplification of the paper's per-query selection),
        # and never request more blocks than exist.
        num_blocks = k_compressed.size(1)
        topk = min(self.select_topk, num_blocks)
        _, topk_indices = torch.topk(probs.mean(dim=(1, 2)), topk, dim=-1)  # [b, topk]
        # Gather the selected blocks: [b, topk, nh, hd]
        idx = topk_indices[:, :, None, None].expand(-1, -1, self.n_heads, self.head_dim)
        k_selected = torch.gather(k_compressed, 1, idx)
        v_selected = torch.gather(v_compressed, 1, idx)
        return k_selected, v_selected

    def forward(self, x, attn_mask=None):
        b, t, d = x.shape
        # Project QKV and keep the per-head layout [b, t, nh, hd]
        q = self.q_proj(x).view(b, t, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(b, t, self.n_heads, self.head_dim)
        v = self.v_proj(x).view(b, t, self.n_heads, self.head_dim)
        # Compression branch
        k_compressed, v_compressed = self._compress_tokens(k, v)
        # Selection branch
        k_selected, v_selected = self._select_blocks(q, k_compressed, v_compressed)
        # Sliding-window branch: the most recent window_size tokens
        k_window = k[:, max(0, t - self.window_size):]
        v_window = v[:, max(0, t - self.window_size):]
        # Gate weights: [3, b, t, nh, 1] so they broadcast over head_dim
        gate = (
            self.gate_mlp(x)
            .view(b, t, 3, self.n_heads)
            .permute(2, 0, 1, 3)
            .unsqueeze(-1)
        )
        # Attention over the three branches
        attn_outputs = []
        for branch_k, branch_v in [
            (k_compressed, v_compressed),
            (k_selected, v_selected),
            (k_window, v_window),
        ]:
            scores = torch.einsum("bthd,bkhd->bthk", q, branch_k) / (
                self.head_dim ** 0.5
            )
            if attn_mask is not None:
                # attn_mask must be broadcastable to the [b, t, nh, num_keys]
                # score shape of each branch
                scores = scores.masked_fill(attn_mask == 0, -1e9)
            probs = F.softmax(scores, dim=-1)
            probs = self.dropout(probs)
            output = torch.einsum("bthk,bkhd->bthd", probs, branch_v)
            attn_outputs.append(output)
        # Gated fusion of the three branches, then merge heads
        weighted = sum(g * o for g, o in zip(gate, attn_outputs))
        output = weighted.contiguous().view(b, t, d)
        return self.out_proj(output)


def test_nsa_attention_shapes():
    # Test configuration
    batch_size = 2
    seq_len = 128
    d_model = 256
    n_heads = 8
    nsa_attn = NSAModule(d_model, n_heads)
    x = torch.randn(batch_size, seq_len, d_model)
    print("input shape:", x.shape)
    # Print intermediate shapes
    q = nsa_attn.q_proj(x).view(batch_size, seq_len, n_heads, -1)
    k = nsa_attn.k_proj(x).view(batch_size, seq_len, n_heads, -1)
    v = nsa_attn.v_proj(x).view(batch_size, seq_len, n_heads, -1)
    print("Q/K/V shape:", q.shape)  # expected [2, 128, 8, 32]
    k_comp, v_comp = nsa_attn._compress_tokens(k, v)
    print("compressed KV shape:", k_comp.shape)  # expected [2, 4, 8, 32]
    out = nsa_attn(x)
    print("output shape:", out.shape)  # expected [2, 128, 256]


if __name__ == "__main__":
    test_nsa_attention_shapes()

Thoughts

  • The speedup is not surprising, but it is striking that NSA reportedly outperforms full attention in quality as well; if the experiments are solid, NSA looks genuinely promising
  • Since there is no open-source code, the method has to be reimplemented from the formulas and descriptions in the paper, which makes reproduction and engineering adoption harder