Zhangzhe's Blog

The projection of my life.

0%

Attention Residuals

URL

TL;DR

  • 本文提出了一种新的残差连接方法,称为注意力残差(Attention Residuals),本质是一种可学习的残差连接机制,能够动态调整残差连接的权重,从而提高 Transformer 模型的性能和稳定性。

Algorithm

Architecture

pemQqEt.png

code

  • 伪代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
def block_attn_res(
blocks: list[Tensor], # List[N]: each [B, T, D]
partial_block: Tensor, # [B, T, D] 或 None
proj: Linear, # weight: [1, D] (pseudo-query wₗ)
norm: RMSNorm # RMSNorm over last dim
) -> Tensor: # Return: [B, T, D]
"""
跨块注意力:在 [已完成块] + [当前块部分和] 上进行选择性聚合
"""

# ── Step 1: 拼接所有候选源 ──────────────────────────────────
# blocks: N × [B, T, D]
# partial_block: [B, T, D]
# 拼接后列表长度: N+1
# stack 沿新维度 0 堆叠
V = torch.stack(blocks + [partial_block])
# V.shape: [N+1, B, T, D]
# ├─ dim 0: source index (历史块 0~N-1 + 当前部分和)
# ├─ dim 1: batch
# ├─ dim 2: token position
# └─ dim 3: hidden feature

# ── Step 2: RMSNorm 作为 Key ───────────────────────────────
# 防止大模值块天然主导注意力权重
K = norm(V)
# K.shape: [N+1, B, T, D] (与 V 相同)

# ── Step 3: 计算注意力 logits ──────────────────────────────
# proj.weight: [1, D] → squeeze() → [D] = wₗ (层特定的伪查询)
# einsum: 'd, n b t d -> n b t'
# 对 D 维度做点积: wₗᵀ · K[n,b,t,:]
logits = torch.einsum('d, n b t d -> n b t', proj.weight.squeeze(), K)
# logits.shape: [N+1, B, T]
# 每个 (n,b,t) 表示: 源 n 对位置 (b,t) 的未归一化注意力分数

# ── Step 4: Softmax + 加权求和 ─────────────────────────────
# softmax(0): 沿 source 维度归一化,使 ∑ₙ αₙ = 1
# α.shape: [N+1, B, T]
#
# einsum: 'n b t, n b t d -> b t d'
# 对每个 (b,t): h[b,t] = Σₙ αₙ[b,t] · V[n,b,t,:]
h = torch.einsum('n b t, n b t d -> b t d', logits.softmax(0), V)
# h.shape: [B, T, D] ← 聚合后的隐藏状态,作为下一子层输入
return h


def forward(
self,
blocks: list[Tensor], # List[N]: each [B, T, D] (历史已完成块)
hidden_states: Tensor # [B, T, D] (当前层输入)
) -> tuple[list[Tensor], Tensor]:
"""
单层前向传播,维护块内残差累积 + 块间注意力聚合
"""

# ── 阶段 0: 初始化当前块的部分和 ─────────────────────────────
partial_block = hidden_states
# partial_block.shape: [B, T, D]

# ── 阶段 1: Pre-Attention 的 Block AttnRes ─────────────────
# 聚合历史块 + 当前部分和 → 得到 h 作为 Attention 输入
h = block_attn_res(
blocks, # List[N] × [B,T,D]
partial_block, # [B,T,D]
self.attn_res_proj, # Linear: weight [1,D]
self.attn_res_norm # RMSNorm
)
# h.shape: [B, T, D]

# ── 阶段 2: 块边界检测 ─────────────────────────────────────
# block_size: 每块包含的子层数 (ATTN+MLP=2 per Transformer layer)
# 例如 block_size=6 → 每块 3 个 Transformer 层
if self.layer_number % (self.block_size // 2) == 0:
# 当前块完成,将其部分和加入历史块列表
blocks.append(partial_block)
# blocks: List[N+1] × [B,T,D]

# 重置部分和,新块从头累积
partial_block = None
# partial_block: None

# ── 阶段 3: 标准 Self-Attention ───────────────────────────
# self.attn_norm: RMSNorm
# h: [B,T,D] → attn_norm(h): [B,T,D]
# self.attn: 标准多头自注意力
# Q/K/V: [B,T,D] → attn_out: [B,T,D]
attn_out = self.attn(self.attn_norm(h))
# attn_out.shape: [B, T, D]

# 块内残差累加 (标准残差连接)
if partial_block is not None:
partial_block = partial_block + attn_out
# [B,T,D] + [B,T,D] = [B,T,D]
else:
partial_block = attn_out
# 新块第一层: 直接初始化为 attn_out
# partial_block.shape: [B, T, D]

# ── 阶段 4: Pre-MLP 的 Block AttnRes ──────────────────────
# 注意: 此时 partial_block 已包含 attn_out,信息更丰富
# 使用独立的投影参数 (mlp_res_proj ≠ attn_res_proj)
h = block_attn_res(
blocks, # List[N] or List[N+1] × [B,T,D]
partial_block, # [B,T,D]
self.mlp_res_proj, # Linear: weight [1,D] (独立参数!)
self.mlp_res_norm # RMSNorm
)
# h.shape: [B, T, D]

# ── 阶段 5: 标准 MLP ──────────────────────────────────────
# self.mlp_norm: RMSNorm
# self.mlp: 例如 SwiGLU: [B,T,D] → [B,T,D]
mlp_out = self.mlp(self.mlp_norm(h))
# mlp_out.shape: [B, T, D]

# 块内残差累加
partial_block = partial_block + mlp_out
# [B,T,D] + [B,T,D] = [B,T,D]

# ── 返回 ─────────────────────────────────────────────────
return blocks, partial_block
# blocks: List[M] × [B,T,D] (M = 已完成块数,≤ N)
# partial_block: [B,T,D] or None (若刚重置)

一些细节

  1. 每个层都有独立的投影参数(attn_res_projmlp_res_proj),允许每个层学习不同的注意力残差权重。
  2. Attention 前和 MLP 前都应用了 block_attn_res,确保每个子层都能从历史块和当前部分和中动态聚合信息。
  3. 块边界检测基于层数和预设的块大小,确保每块包含固定数量的子层(例如 3 层),并在块完成时将部分和加入历史块列表。
  4. attn_res 的最后是一个 softmax,确保所有源(历史块 + 当前部分和)的权重和为 1,使得聚合后的隐藏状态是一个加权平均,不容易出现数值不稳定。

最终结果

pem1YF0.png

Thoughts

  • 很多大模型公司都在探索改进残差连接的方法,比如 DeepSeekmHC,确实在大模型上残差连接的设计非常重要,直接影响模型的训练稳定性和性能,在大模型结构设计进入深水区之后,这种细节确实该琢磨了。
  • 这种这么长生命周期的残差,感觉会让训练和推理的显存管理变困难。