Zhangzhe's Blog

The projection of my life.

0%

URL

TL;DR

  • 蒙特卡洛树搜索(Monte Carlo Tree Search, 简称 MCTS)旨在解决在拥有海量状态空间的复杂决策环境下的 “维度灾难” 问题,即:当我们无法穷举所有可能性,且不知道如何准确评估当前状态时,如何利用有限的计算资源,找到当前的最优解?
  • 核心思想是:不再穷举所有分支,而是像一个有经验的玩家一样:把更多的算力(搜索时间)倾斜到那些看起来赢面更大、更有希望的分支上,同时偶尔去探索一下没怎么走过的分支,防止错过奇招。这在算法中被称为探索与利用(Exploration vs. Exploitation)的平衡。

Algorithm

  • MCTS 通过不断的循环迭代来构建一棵搜索树。每次迭代包含四个经典的步骤:
    • 选择 (Selection)
    • 扩展 (Expansion)
    • 模拟 (Simulation / Rollout)
    • 回溯 / 反向传播 (Backpropagation)

1. 选择

  • 算法从树的根节点(当前局面)出发,向下遍历,直到找到一个还没有被完全展开的节点(即这个节点下还有未尝试过的动作)。
  • 在向下走的过程中,如果遇到多个子节点,它会使用一个公式来决定走哪边,最常用的公式是 UCT (Upper Confidence Bound applied to Trees):

    UCT=WiNi+clnNNiUCT= \frac{W_i}{N_i} + c \sqrt{\frac{\ln N}{N_i}}

    • WiW_i:该节点下的胜利次数
    • NiN_i:该节点被访问的次数
    • Wi/NiW_i / N_i利用(Exploitation) 项,即胜率
    • 胜率越高的节点,越容易被选中。NN:父节点的总访问次数
    • cc:探索常数(通常取 2\sqrt{2}
    • clnNNic \sqrt{\frac{\ln N}{N_i}}探索(Exploration) 项。一个节点被访问的次数 NiN_i 越少,这项的值就越大,就越容易被选中去探索。

2. 扩展

  • 当根据 UCT 公式找到一个未完全展开的叶子节点时,就给这个节点增加一个新的子节点,代表采取了一个从未尝试过的合法动作,进入了一个新的状态。

3. 模拟

  • 从新扩展出的新节点开始,完全随机地(或者基于某种简单策略) 往下模拟对局,直到分出胜负(到达终止状态,比如游戏结束)
  • 因为是随机模拟(这就是“蒙特卡洛”一词的来源),计算速度非常快。它不需要评估函数,而是直接把游戏“快进”到结局,看看是赢是输

4. 回溯 / 反向传播

  • 根据上一步模拟出的最终结果(比如:赢了记为 1,输了记为 0,平局记为 0.5),从这个新节点沿着刚才的路径一直往回走到根节点
  • 在回退的过程中,更新沿途每一个节点的统计数据:该节点的访问次数 NiN_i 加 1,如果模拟结果是胜利,该节点的胜利次数 WiW_i 加 1

循环与决策

  • MCTS 会在给定的时间限制或计算资源内,疯狂地重复这四个步骤(成千上万次)。每次循环都会让这棵搜索树长得更大,且统计数据越来越准确。
  • 当时间耗尽时,算法会看一眼根节点的所有子节点。此时不需要再看 UCT 公式了,直接选择访问次数最多(或者胜率最高)的那个子节点,作为最终的决策动作。

MCTS 在井字棋上的示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import math
import random
import copy
from abc import ABC, abstractmethod

# ==========================================
# 1. 定义状态接口与井字棋实现
# ==========================================
class GameState(ABC):
@abstractmethod
def get_legal_actions(self): pass
@abstractmethod
def apply_action(self, action): pass
@abstractmethod
def is_terminal(self): pass

class TicTacToe(GameState):
def __init__(self, board=None, current_player=1):
# board: 0 为空,1 为玩家 X,-1 为玩家 O
self.board = board if board is not None else [0] * 9
self.current_player = current_player

def get_legal_actions(self):
"""返回所有空位的索引"""
return [i for i, x in enumerate(self.board) if x == 0]

def apply_action(self, action):
"""落子,并返回切换了玩家的新状态"""
new_board = copy.deepcopy(self.board)
new_board[action] = self.current_player
return TicTacToe(new_board, current_player=-self.current_player)

def get_winner(self):
"""检查是否有玩家获胜,返回 1, -1, 或 None"""
lines = [
(0, 1, 2), (3, 4, 5), (6, 7, 8), # 横
(0, 3, 6), (1, 4, 7), (2, 5, 8), # 竖
(0, 4, 8), (2, 4, 6) # 交叉
]
for a, b, c in lines:
if self.board[a] != 0 and self.board[a] == self.board[b] == self.board[c]:
return self.board[a]
return None

def is_terminal(self):
"""判断游戏是否结束(有人赢了,或者没空位了)"""
return self.get_winner() is not None or len(self.get_legal_actions()) == 0

def get_result(self, player):
"""
核心方法:返回特定 player 在当前终局下的得分。
赢=1.0,输=0.0,平局=0.5
"""
winner = self.get_winner()
if winner is None:
return 0.5 # 平局
if winner == player:
return 1.0 # 该玩家获胜
else:
return 0.0 # 该玩家落败

def __str__(self):
symbols = {1: 'X', -1: 'O', 0: '.'}
res = ""
for i in range(3):
res += " ".join([symbols[self.board[i*3 + j]] for j in range(3)]) + "\n"
return res

# ==========================================
# 2. MCTS 树节点与主算法
# ==========================================
class MCTSNode:
def __init__(self, state: TicTacToe, parent=None, action=None):
self.state = state
self.parent = parent
self.action = action
self.children = []
self.untried_actions = state.get_legal_actions()

self.visits = 0 # 记录此节点到达的总次数
self.wins = 0.0 # 记录“导致到达此节点的动作”所带来的胜场

def is_fully_expanded(self):
return len(self.untried_actions) == 0

def best_child(self, c_param=1.414):
"""使用 UCT 公式选择最优子节点"""
choices_weights = []
for child in self.children:
exploit = child.wins / child.visits
explore = c_param * math.sqrt((2 * math.log(self.visits) / child.visits))
choices_weights.append(exploit + explore)
return self.children[choices_weights.index(max(choices_weights))]

def mcts_search(root_state: TicTacToe, iterations: int = 2000):
root_node = MCTSNode(state=root_state)

for _ in range(iterations):
node = root_node

# 第一步:选择 (Selection)
while node.is_fully_expanded() and not node.state.is_terminal():
node = node.best_child()

# 第二步:扩展 (Expansion)
if not node.state.is_terminal():
action = random.choice(node.untried_actions)
node.untried_actions.remove(action)
next_state = node.state.apply_action(action)

child_node = MCTSNode(state=next_state, parent=node, action=action)
node.children.append(child_node)
node = child_node

# 第三步:模拟 (Rollout) - 纯随机下棋直到结束
current_rollout_state = node.state
while not current_rollout_state.is_terminal():
possible_actions = current_rollout_state.get_legal_actions()
action = random.choice(possible_actions)
current_rollout_state = current_rollout_state.apply_action(action)

# 第四步:反向传播 (Backpropagation)
while node is not None:
node.visits += 1
# 如果 node 有父节点,说明 node 是由父节点的玩家执行 action 得到的
if node.parent is not None:
# 获取做出该动作的玩家
player_who_made_the_move = node.parent.state.current_player
# 将终局结果站在该玩家的视角累加到该节点
node.wins += current_rollout_state.get_result(player_who_made_the_move)
node = node.parent

# 搜索结束,选择访问量最大(最稳健)的子节点动作
best_final_node = max(root_node.children, key=lambda c: c.visits)
return best_final_node.action

def play_game():
game = TicTacToe()
print("=====================================")
print("欢迎来到 MCTS 井字棋对战!")
print("AI 执 X (先手),你执 O (后手)。")
print("棋盘的位置索引 (0-8) 如下所示:")
print(" 0 | 1 | 2 ")
print("---+---+---")
print(" 3 | 4 | 5 ")
print("---+---+---")
print(" 6 | 7 | 8 ")
print("=====================================\n")

print("初始棋盘:")
print(game)

while not game.is_terminal():
if game.current_player == 1:
# AI 的回合
print("AI (MCTS) 正在思考中... (这可能需要一两秒)")
# 这里的 iterations 控制 AI 的算力(难度)。
# 对于井字棋,2000 次迭代 AI 已经几乎是“不败”的了。
best_move = mcts_search(game, iterations=2000)
print(f"AI 落子在位置: {best_move}\n")
game = game.apply_action(best_move)
else:
# 玩家的回合
legal_moves = game.get_legal_actions()
while True:
try:
move = input(f"轮到你了,请输入你要落子的位置 {legal_moves}: ")
move = int(move)
if move in legal_moves:
break
else:
print("⚠ 错误:该位置已有棋子或超出范围,请重新输入。")
except ValueError:
print("⚠ 错误:请输入一个有效的数字。")

print(f"你落子在位置: {move}\n")
game = game.apply_action(move)

# 打印当前棋盘状态
print(game)

# 游戏结束,判定胜负
print("=====================================")
winner = game.get_winner()
if winner == 1:
print("🤖 游戏结束:AI (X) 获胜!(MCTS 确实很难被击败哦)")
elif winner == -1:
print("🎉 游戏结束:恭喜你 (O) 获胜!(这说明你找到了算法的盲区!)")
else:
print("🤝 游戏结束:平局!(面对完美的对手,平局已经是最好的结果了)")
print("=====================================")

if __name__ == "__main__":
play_game()

Thoughts

  • 对于这种简单的游戏,只经过简单的蒙特卡洛树搜索,AI 先手基本可以做到完全不败,但对于像围棋这种搜索空间的游戏,单用 MCTS 实际上无法做到穷举例
  • MCTS 在论文中是每个回合都重建一棵新树,在下个回合就抛弃,工程上通常是一局对弈只建立一棵树,随着对弈根节点向下,只有出现之前未搜索到的局面,才会建新树

URL

TL;DR

  • 本文提出了一种新的残差连接方法,称为注意力残差(Attention Residuals),本质是一种可学习的残差连接机制,能够动态调整残差连接的权重,从而提高 Transformer 模型的性能和稳定性。

Algorithm

Architecture

pemQqEt.png

code

  • 伪代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
def block_attn_res(
blocks: list[Tensor], # List[N]: each [B, T, D]
partial_block: Tensor, # [B, T, D] 或 None
proj: Linear, # weight: [1, D] (pseudo-query wₗ)
norm: RMSNorm # RMSNorm over last dim
) -> Tensor: # Return: [B, T, D]
"""
跨块注意力:在 [已完成块] + [当前块部分和] 上进行选择性聚合
"""

# ── Step 1: 拼接所有候选源 ──────────────────────────────────
# blocks: N × [B, T, D]
# partial_block: [B, T, D]
# 拼接后列表长度: N+1
# stack 沿新维度 0 堆叠
V = torch.stack(blocks + [partial_block])
# V.shape: [N+1, B, T, D]
# ├─ dim 0: source index (历史块 0~N-1 + 当前部分和)
# ├─ dim 1: batch
# ├─ dim 2: token position
# └─ dim 3: hidden feature

# ── Step 2: RMSNorm 作为 Key ───────────────────────────────
# 防止大模值块天然主导注意力权重
K = norm(V)
# K.shape: [N+1, B, T, D] (与 V 相同)

# ── Step 3: 计算注意力 logits ──────────────────────────────
# proj.weight: [1, D] → squeeze() → [D] = wₗ (层特定的伪查询)
# einsum: 'd, n b t d -> n b t'
# 对 D 维度做点积: wₗᵀ · K[n,b,t,:]
logits = torch.einsum('d, n b t d -> n b t', proj.weight.squeeze(), K)
# logits.shape: [N+1, B, T]
# 每个 (n,b,t) 表示: 源 n 对位置 (b,t) 的未归一化注意力分数

# ── Step 4: Softmax + 加权求和 ─────────────────────────────
# softmax(0): 沿 source 维度归一化,使 ∑ₙ αₙ = 1
# α.shape: [N+1, B, T]
#
# einsum: 'n b t, n b t d -> b t d'
# 对每个 (b,t): h[b,t] = Σₙ αₙ[b,t] · V[n,b,t,:]
h = torch.einsum('n b t, n b t d -> b t d', logits.softmax(0), V)
# h.shape: [B, T, D] ← 聚合后的隐藏状态,作为下一子层输入
return h


def forward(
self,
blocks: list[Tensor], # List[N]: each [B, T, D] (历史已完成块)
hidden_states: Tensor # [B, T, D] (当前层输入)
) -> tuple[list[Tensor], Tensor]:
"""
单层前向传播,维护块内残差累积 + 块间注意力聚合
"""

# ── 阶段 0: 初始化当前块的部分和 ─────────────────────────────
partial_block = hidden_states
# partial_block.shape: [B, T, D]

# ── 阶段 1: Pre-Attention 的 Block AttnRes ─────────────────
# 聚合历史块 + 当前部分和 → 得到 h 作为 Attention 输入
h = block_attn_res(
blocks, # List[N] × [B,T,D]
partial_block, # [B,T,D]
self.attn_res_proj, # Linear: weight [1,D]
self.attn_res_norm # RMSNorm
)
# h.shape: [B, T, D]

# ── 阶段 2: 块边界检测 ─────────────────────────────────────
# block_size: 每块包含的子层数 (ATTN+MLP=2 per Transformer layer)
# 例如 block_size=6 → 每块 3 个 Transformer 层
if self.layer_number % (self.block_size // 2) == 0:
# 当前块完成,将其部分和加入历史块列表
blocks.append(partial_block)
# blocks: List[N+1] × [B,T,D]

# 重置部分和,新块从头累积
partial_block = None
# partial_block: None

# ── 阶段 3: 标准 Self-Attention ───────────────────────────
# self.attn_norm: RMSNorm
# h: [B,T,D] → attn_norm(h): [B,T,D]
# self.attn: 标准多头自注意力
# Q/K/V: [B,T,D] → attn_out: [B,T,D]
attn_out = self.attn(self.attn_norm(h))
# attn_out.shape: [B, T, D]

# 块内残差累加 (标准残差连接)
if partial_block is not None:
partial_block = partial_block + attn_out
# [B,T,D] + [B,T,D] = [B,T,D]
else:
partial_block = attn_out
# 新块第一层: 直接初始化为 attn_out
# partial_block.shape: [B, T, D]

# ── 阶段 4: Pre-MLP 的 Block AttnRes ──────────────────────
# 注意: 此时 partial_block 已包含 attn_out,信息更丰富
# 使用独立的投影参数 (mlp_res_proj ≠ attn_res_proj)
h = block_attn_res(
blocks, # List[N] or List[N+1] × [B,T,D]
partial_block, # [B,T,D]
self.mlp_res_proj, # Linear: weight [1,D] (独立参数!)
self.mlp_res_norm # RMSNorm
)
# h.shape: [B, T, D]

# ── 阶段 5: 标准 MLP ──────────────────────────────────────
# self.mlp_norm: RMSNorm
# self.mlp: 例如 SwiGLU: [B,T,D] → [B,T,D]
mlp_out = self.mlp(self.mlp_norm(h))
# mlp_out.shape: [B, T, D]

# 块内残差累加
partial_block = partial_block + mlp_out
# [B,T,D] + [B,T,D] = [B,T,D]

# ── 返回 ─────────────────────────────────────────────────
return blocks, partial_block
# blocks: List[M] × [B,T,D] (M = 已完成块数,≤ N)
# partial_block: [B,T,D] or None (若刚重置)

一些细节

  1. 每个层都有独立的投影参数(attn_res_projmlp_res_proj),允许每个层学习不同的注意力残差权重。
  2. Attention 前和 MLP 前都应用了 block_attn_res,确保每个子层都能从历史块和当前部分和中动态聚合信息。
  3. 块边界检测基于层数和预设的块大小,确保每块包含固定数量的子层(例如 3 层),并在块完成时将部分和加入历史块列表。
  4. attn_res 的最后是一个 softmax,确保所有源(历史块 + 当前部分和)的权重和为 1,使得聚合后的隐藏状态是一个加权平均,不容易出现数值不稳定。

最终结果

pem1YF0.png

Thoughts

  • 很多大模型公司都在探索改进残差连接的方法,比如 DeepSeekmHC,确实在大模型上残差连接的设计非常重要,直接影响模型的训练稳定性和性能,在大模型结构设计进入深水区之后,这种细节确实该琢磨了。
  • 这种这么长生命周期的残差,感觉会让训练和推理的显存管理变困难。

URL

TL;DR

  • 本文提出一种单步图像生成的算法 drifting,可以做到一步从噪声得到图片,无需像 diffusion / flow matching 一样多轮迭代得到图片

Algorithm

损失函数设计

L=Eϵ[fθ(ϵ)predictionstopgrad ⁣(fθ(ϵ)+Vp,qθ ⁣(fθ(ϵ)))frozen target2]\mathcal{L} = \mathbb{E}_{\epsilon} \left[ \left\| \underbrace{ f_\theta(\epsilon) }_{\text{prediction}} - \underbrace{ \operatorname{stopgrad}\!\left( f_\theta(\epsilon) + V_{p,q_\theta}\!\left(f_\theta(\epsilon)\right) \right) }_{\text{frozen target}} \right\|^2 \right]

  • ϵ\epsilon 是随机噪声
  • fθ()f_\theta() 是生成模型
  • Vp,qθV_{p,q_\theta} 是漂移场

漂移场的计算

Vp,q(x)=Vp+(x)正样本吸引Vq(x)负样本排斥V_{p,q}(x) = \underbrace{V_p^+(x)}_{正样本吸引} \underbrace{-V_q^-(x)}_{负样本排斥}

  • pp 表示真实数据分布
  • qq 表示模型生成的数据分布

正样本吸引(来自真实数据)

Vp+(x)=1ZpEy+p[k(x,y+)标量权重(y+x)向量方向]V_p^{+}(x) = \frac{1}{Z_p} \mathbb{E}_{y^{+} \sim p} \left[ \underbrace{k(x, y^{+})}_{标量权重} \, \underbrace{(y^{+} - x)}_{向量方向} \right]

负样本排斥(来自生成分布)

Vq(x)=1ZqEyq[k(x,y)标量权重(yx)向量方向]V_q^{-}(x) = \frac{1}{Z_q} \mathbb{E}_{y^{-} \sim q} \left[ \underbrace{k(x, y^{-})}_{标量权重} \, \underbrace{(y^{-} - x)}_{向量方向} \right]

  • y+y^+ 真实样本
  • yy^- 生成样本
  • k(x,y)k(x,y) 相似度核函数,输出标量

相似度核函数

k(x,y)=exp(xyτ)k(x,y)=exp(-\frac{\|x-y\|}{\tau})

  • τ\tau 表示温度

pZHQZYq.png

pZHQVkn.png

用一个例子说明损失函数的作用方式

  • 当前生成样本(要算 drift 的点)

x=(0,0)x=(0,0)

  • 正样本(来自真实分布 pp)两个点

y1+=(2,0), y2+=(2,2)y_1^+=(2,0),\ y_2^+=(2,2)

  • 负样本(来自生成分布 qq)两个点(通常就是 batch 里的其他生成样本)

y1=(1,0), y2=(0,1)y_1^-=(-1,0),\ y_2^-=(0,-1)

Step 1: 算 kernel 权重(谁更“像” x)

  • 正样本的权重
    • 距离:
      • xy1+=2\|x-y_1^+\|=2
      • xy2+=82.83\|x-y_2^+\|=\sqrt 8\approx 2.83
    • kernel:
      • k(x,y1+)=e20.135k(x,y_1^+)=e^{-2}\approx0.135
      • k(x,y2+)=e2.830.059k(x,y_2^+)=e^{-2.83}\approx0.059
  • 负样本的权重:
    • k(x,y1)=k(x,y2)=e10.368k(x,y_1^-)=k(x,y_2^-)=e^{-1}\approx0.368

Step 2: 算方向向量(往哪里走)

  • 正样本方向:

y1+x=(2,0)y2+x=(2,2)y_1^+-x=(2, 0)\\ y_2^+-x=(2, 2)

  • 负样本方向:

y1x=(1,0)y2x=(0,1)y_1^--x=(-1, 0)\\ y_2^--x=(0, -1)

Step 3: 算正向漂移 Vp+V_p^+

  • 加权求和

k(x,y+)(y+x)=0.135(2,0)+0.059(2,2)=(0.388,0.118)\sum k(x,y^+)(y^+-x)=0.135(2, 0)+0.059(2, 2)=(0.388, 0.118)

  • 归一化

Zp=0.135+0.059=0.194Vp+(x)=10.194(0.388,0.118)(2.0,0.61)Z_p=0.135+0.059=0.194\\ V_p^+(x)=\frac{1}{0.194}(0.388, 0.118)\approx (2.0, 0.61)

Step 4: 算负向漂移 VqV_q^-

  • Vp+V_p^+ 同理

Vq(x)=(0.5,0.5)V_q^-(x) = (-0.5,-0.5)

Step 5: 合并成为最终漂移 Vp,qV_{p,q}

Vp,q(x)=Vp+Vq=(2.0,0.61)(0.5,0.5)=(2.5,1.11)V_{p,q}(x)=V_p^+-V_q^-=(2.0, 0.61)-(-0.5,-0.5)=(2.5, 1.11)

每个模块的组成

  • Generator: 采用 DiT (Diffusion Transformer) 架构。输入是噪声 + 类别条件。
  • Tokenizer: 使用 Stable DiffusionVAE 将图像压缩到 Latent Space (32×32×432 \times 32 \times 4)。
  • Conditioning: 使用 AdaLN-Zero (Adaptive Layer Norm) 注入类别信息。这与标准 DiT 一致。
  • Feature Extractor: 用于计算损失函数。论文指出损失是在特征空间中计算的,而非直接在像素空间。用的模型是 MAE 自监督训练得到的。

Thoughts

  • 有很多 MoCo 的影子,比如 stop_grade 作为监督,对比学习等
  • 还有 MAE,kaiming 大佬把自己的作品都串起来了,tql…

URL

TL;DR

  • 论文题目是:“通过可扩展查找的条件内存:大语言模型的新稀疏轴”,目标是找到一种条件记忆方法,解决目前大语言模型靠大量计算来拟合记忆的问题,换句话说:目前的大模型都是纯计算,没有记忆,但本论文将计算和记忆在某种程度上进行解耦
  • N-gram 是一种在自然语言处理领域非常古老的算法,目的是抽取自然语言局部相关性,类似于卷积之于图像
  • 提出 Engram 架构,通过现代化的 N-gram 嵌入技术,实现了常数级时间复杂度 O(1)O(1) 的静态知识查找,释放了模型主干的计算压力,可以看成是一种新的稀疏

总体流程

 2026-01-17 15-23-05.png

  • 标准 Transformer 包含 Attention + MoE,增强 Transformer 包含 Engram + Attention + MoE

1. 词表压缩

  • 具体做法是规范化:
1
2
3
4
5
6
7
8
9
10
normalizers.Sequence([
normalizers.NFKC(),
normalizers.NFD(),
normalizers.StripAccents(),
normalizers.Lowercase(),
normalizers.Replace(Regex(r"[ \t\r\n]+"), " "),
normalizers.Replace(Regex(r"^ $"), SENTINEL),
normalizers.Strip(),
normalizers.Replace(SENTINEL, " "),
])
  • 128815 词表大小降低到 98627

2. N-gram 哈希

1. 先做 shift

  • 输入:x = [t0, t1, t2, t3, t4]
  • 构造:
    • shift_0: [t0, t1, t2, t3, t4]
    • shift_1: [PAD, t0, t1, t2, t3]
    • shift_2: [PAD, PAD, t0, t1, t2]

2. 做 N-gram 哈希

1
2
3
4
5
6
7
mix = (
token[t] * m0
) XOR (
token[t-1] * m1
) XOR (
token[t-2] * m2
)

其中:

  • m0, m1, m2
    • 奇数
    • layer_id 相关
    • 随机但可复现
  • 为什么是 XOR
    • 顺序敏感
    • 分布均匀
    • 比加法抗碰撞
    • MurmurHash 便宜

3. 多 Head:同一个 n-gram,多种 hash 视角

  • 一个 n-gram 对应多个 embedding lookup,这和 Multi-Head Attention 的思想是完全一致的
  • 对素数词表取余数
1
2
3
# n_head_per_ngram = 8
for head in range(n_head_per_ngram):
head_hash = mix % prime_vocab_size[head]

3. 根据 hash 后的 index lookup embedding

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
@dataclass
class EngramConfig:
tokenizer_name_or_path: str = "/data/Models/deepseek-ai/DeepSeek-V3/"
# 2-gram 和 3-gram 每个词表 129280*5 行
engram_vocab_size: List[int] = field(default_factory=lambda: [129280*5, 129280*5])
# 只做 2-gram / 3-gram
max_ngram_size: int = 3
# embedding dim
n_embed_per_ngram: int = 512
# 每个 ngram 8 个头(8 种 hash 算法),每个头的 dim= 512 / 8 = 64
n_head_per_ngram: int = 8
# 只对第 1 层和第 15 层用 engram,其他层是标准 transformer layer
layer_ids: List[int] = field(default_factory=lambda: [1, 15])
pad_id: int = 2
seed: int = 0
kernel_size: int = 4
  • 对模型的第 1 层和第 15 层使用 engram 模块
  • 每个 engram 模块用 2-gram / 3-gram
  • 每个 n-gram8 个头
  • 每个头对应一个约等于 129280 * 5embedding table,且完全不共享
维度 是否共享
不同 layer(1 vs 15) ❌ 不共享
不同 n-gram(2 vs 3) ❌ 不共享
同一 n-gram 的不同 head ❌ 不共享
embedding 表 全部独立

4. 查到 embedding 之后,计算得到输出

  • embedding 用来生成一个 gate factor,去调制原始 input hidden
  • 换句话说:Engram 不是“提供新特征”,而是“决定哪些原始 hidden 值值得被放大 / 抑制”
1
2
3
4
5
6
7
8
9
embeddings = embedding_lookup(hash_ids)     # (B, T, E)

key = W_k(embeddings)
query = hidden_states

gate = f(key · query) # 标量 gate
value = gate * W_v(embeddings)

output = value + short_conv(value)

推理优化

 2026-01-17 15-23-34.png

  • 由于哈希索引仅依赖于输入 tokenEngram 的查找操作具有确定性。
  • 在推理过程中,系统可以异步地从主机内存(CPU RAM)通过 PCIe 总线预取所需的嵌入向量
    • 通信隐藏:通过将 Engram 模块放置在主干网络的较深层(如第 15 层),可以利用前续层的计算时间来掩盖内存读取的延迟
    • 吞吐量损耗极低:实测显示,在将 100B 参数的嵌入表卸载到主机内存的情况下,推理吞吐量的下降幅度小于 3%
  • 多级缓存层次结构:
    • 基于 N-gramZipfian 分布规律(长尾效应),研究者提出了一种多级缓存设计
      • GPU HBM:存放访问频率最高的 N-gram 嵌入
      • 主机 DRAM:存放大部分中等频次的嵌入
      • NVMe SSD:存放长尾的、极其罕见的模式
    • 这种层次结构使得模型可以支持近乎无限的参数扩展,而不会触及 GPU 显存的硬上限

Thought

  • 这种对 transformer layer 的改动确实很牛,O(1) 复杂度太强了
  • 还考虑了推理存储的 memory hierarchy,不得不说 DeepSeek 确实不讲故事,真挑不出毛病

URL

TL;DR

  • 本文提出了一种新的残差连接方式:“流形约束超连接”,通过扩展残差流的宽度来突破传统残差连接带来的模型表达能力,同时解决了无约束超连接带来的恒等映射丧失和训练不稳定的问题。

Algorithm

大纲

mhc.png

  • 从左到右依次是:
    • 残差连接:xl+1=xl+F(xl,Wl)x_{l+1}=x_l+\mathcal{F}(x_l,\mathcal{W}_l)
    • 超连接:xl+1=Hlresxl+HlpostTF(Hlprexl,Wl)x_{l+1}=\mathcal{H}^{res}_lx_l+{\mathcal{H}^{post}_l}^T\mathcal{F}(\mathcal{H}^{pre}_lx_l,\mathcal{W}_l),其中:
      • HlpreR1×n\mathcal{H}^{pre}_l\in\mathbb{R}^{1\times n} 表示预映射,负责从 nn 条流中聚合信息,形成标准维度 CC 的输入供当前层的 F\mathcal{F} 处理。
      • HlpostR1×n\mathcal{H}^{post}_l\in\mathbb{R}^{1\times n}:负责将当前层 F\mathcal{F} 的输出(维度 CC)广播或分发回 nn 条流中。
      • HlresRn×n\mathcal{H}^{res}_l\in\mathbb{R}^{n\times n}:负责在 nn 条流之间进行信息的混合与路由。
    • 流形约束超连接:在超连接的基础上,对三个映射矩阵做了 必须是双随机矩阵 的限制,用于解决训练不稳定问题,表示为 PM()\mathcal{P_M}(),这里的:
      • M\mathcal{M} 表示双随机矩阵
      • PM\mathcal{P_M} 表示将任意一个矩阵变成双随机矩阵的函数

双随机矩阵

什么是双随机矩阵

  • 一个矩阵 MRn×nM \in \mathbb{R}^{n \times n} 被称为双随机矩阵,当且仅当它满足以下三个条件:
    • 非负性 (Non-negativity): 所有元素 Mij0M_{ij} \ge 0
    • 行和为 1 (Row Sum Unity): M1n=1nM \mathbf{1}_n = \mathbf{1}_n
    • 列和为 1 (Column Sum Unity): 1nM=1n\mathbf{1}_n^\top M = \mathbf{1}_n^\top

双随机矩阵的性质

  • 范数保持 (Norm Preservation) 与非扩张性
    • 双随机矩阵的最大奇异值(即谱范数 2\|\cdot\|_2)严格等于 1。这对应于全 1 向量 1\mathbf{1} 是其主特征向量。
    • 这意味着 Hlres\mathcal{H}_l^{res} 是一个非扩张算子 (Non-expansive Operator)。无论输入信号 xlx_l 的强度如何,经过 Hlresxl\mathcal{H}_l^{res} x_l 变换后,其能量(范数)不会被放大。这直接切断了梯度爆炸的源头。
  • 凸组合 (Convex Combination) 与特征均值守恒
    • 由于矩阵元素非负且归一化,Hlresxl\mathcal{H}_l^{res} x_l 的每一行实际上是输入流特征的加权平均(凸组合)。
    • 几何解释: 输出特征向量必然位于输入特征向量构成的凸包 (Convex Hull) 内部。
    • 物理意义: 这种操作起到了平滑和混合的作用,而不是缩放。它保证了多流系统中,特征的全局均值在传播过程中保持守恒 (Mean Conservation)。
  • 乘法封闭性 (Compositional Closure)
    • 这是 mHC 能够扩展到任意深度的数学基石。双随机矩阵的集合在矩阵乘法下是封闭的。
    • 证明: 设 A,BA, B 为双随机矩阵,则 (AB)1=A(B1)=A1=1(AB)\mathbf{1} = A(B\mathbf{1}) = A\mathbf{1} = \mathbf{1},且 1(AB)=(1A)B=1B=1\mathbf{1}^\top (AB) = (\mathbf{1}^\top A)B = \mathbf{1}^\top B = \mathbf{1}^\top
    • 这一性质意味着,无论网络堆叠多少层,复合映射 i=lL1Hires\prod_{i=l}^{L-1} \mathcal{H}_i^{res} 始终是一个双随机矩阵。
    • 因此,范数保持和均值守恒的性质在整个网络的深度方向上是全局有效的。这使得 mHC 成功恢复了类似 ResNet 的恒等映射稳定性,同时保留了流间信息交互的能力。

双随机矩阵生成的方法:Sinkhorn-Knopp 可微投影

  • 实现代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
class MHCMappings(nn.Module):
def __init__(self, n, C):
super().__init__()
self.n = n
self.C = C
self.nc = n * C

# ========== 动态映射参数 φ ==========
self.phi_pre = nn.Linear(self.nc, n, bias=False)
self.phi_post = nn.Linear(self.nc, n, bias=False)
self.phi_res = nn.Linear(self.nc, n * n, bias=False)

# ========== 静态偏置 b ==========
self.b_pre = nn.Parameter(torch.zeros(n))
self.b_post = nn.Parameter(torch.zeros(n))
self.b_res = nn.Parameter(torch.zeros(n, n))

# ========== gating scalars α ==========
self.alpha_pre = nn.Parameter(torch.tensor(0.01))
self.alpha_post = nn.Parameter(torch.tensor(0.01))
self.alpha_res = nn.Parameter(torch.tensor(0.01))

self.norm = RMSNorm(self.nc)

def forward(self, x):
"""
x: (n, C)
return:
H_pre : (1, n)
H_post : (1, n)
H_res : (n, n)
"""

# ===== Step 1: flatten & RMSNorm (公式 7) =====
x_flat = x.reshape(1, self.nc) # (1, nC)
x_norm = self.norm(x_flat) # (1, nC)

# ===== Step 2: dynamic + tanh + static =====
H_pre_tilde = (
self.alpha_pre
* torch.tanh(self.phi_pre(x_norm))
+ self.b_pre
) # (1, n)

H_post_tilde = (
self.alpha_post
* torch.tanh(self.phi_post(x_norm))
+ self.b_post
) # (1, n)

H_res_tilde = (
self.alpha_res
* torch.tanh(self.phi_res(x_norm))
).reshape(self.n, self.n) + self.b_res # (n, n)

# ===== Step 3: manifold projection (公式 8,9) =====
H_pre = torch.sigmoid(H_pre_tilde)
H_post = 2.0 * torch.sigmoid(H_post_tilde)
H_res = sinkhorn_knopp(H_res_tilde)

return H_pre, H_post, H_res

infra 优化

  • 在算法之外还做了很多实现上的深度优化(infra 超强已经是 deepseek 传统艺能了)

Experiments

mhc_v2.png

  • DeepSeek 团队在 3B、9B 和 27B 参数量的模型上,基于 DeepSeek-V3 的 MoE + MLA 架构进行了广泛的实验。所有 mHC 实验均设定扩展率 n=4n=4
  • Loss 曲线对比: 标准 HC 模型在训练至 1.2 万步左右时出现了灾难性的 Loss 突刺 (Spike),这与梯度范数的突然爆炸高度吻合。相比之下,mHC 的 Loss 曲线极其平滑,且绝对值始终低于标准残差基线 (Baseline)。
  • 梯度范数: mHC 的梯度范数在整个训练过程中保持稳定,消除了 HC 的剧烈震荡现象。
  • 27B 模型在 8 个主流基准测试上的结果。mHC 展现了对 baseline 和 HC 的全面的超越。

Thoughts

  • LLM 模型的本质大概有这么几块
    • Transformer / SSM
    • MLP
    • Norm
    • 残差
  • 前两块决定了模型的表现能力上限,deepseek 都已经进行过深度优化了(MLA、MoE)
  • 后两部分是梯度下降算法不得不面对的训练稳定性课题,所以这次 deepseek 对残差动手是非常 make sense 的,而且这两块属于传统基于设计的深度学习擅长的两块,纯粹的力大砖飞是行不通的
  • 其实从实现角度看,并没有什么颠覆性的改变,效果这么好确实挺牛的

URL

TL;DR

  • 本文提出一种多模态理解和生成统一的架构,在 Qwen3-VL-8B 的基础上加入了 generation experts(3B)DiT(2B),在不牺牲任何多模态理解能力的情况下,实现了生成和理解的统一。
  • 可以解决图文输入图文输出的多模态任务,比如:
    1. 文生图
    2. 图像编辑
    3. 多模态理解等

Algorithm

1. 硬路由混合专家(Hard-Routing Mixture-of-Experts)

  • 为了增加图像生成能力,Mammothmoda2Qwen3-VL-8B 的内部加入了随即初始化的 generation experts
  • 为了不牺牲 pretrain 多模态理解的能力,pretrain 模型的参数和 generation experts 的参数是通过 硬路由混合专家 的方式选择 token 的,具体来说就是:多模态理解的 token 激活原 pretrain 模型的参数(也被叫做 understanding experts),图像生成的 token 激活新增的随机初始化的 generation experts
  • 这样做的好处是啥?

mammothmoda.png

  • 从上图可以看出在 pre-stage 1 / 2 阶段,模型的 backbone 参数是冻结的,也就是说在第三列的 SFT 阶段之前,qwen3-vl-8b 的多模态理解能力是不会受到任何影响的
  • 硬路由混合专家的具体实现方式:
    1. 每层 transformer 初始化一个新的 ffn 层,被称为 generation experts;原始的 ffn 层叫做 understanding experts
    2. generation 任务扩展词表和扩展 vocab embedding 参数随机初始化
    3. 根据每一个 token 是否属于 generation token 得到一个 gen token mask
    4. 每一层根据 gen token mask 来为每一个 token 选择是激活哪个 ffn
  • 作者通过消融实验对比了 ffn moe / attention moe / ffn-attention moe 以及全层使用和仅深层使用,最终发现 14 层之后用 ffn moe 效果很好

2. 扩散生成器(Diffusion Generator)

  • qwen3-vl-8b 作为大脑,DiT 就是画图的手
  • 本文使用了一种单流扩散架构 DiT,将处理后的条件信号和噪声潜变量(由 VAE 编码)作为统一输入(而不是两个输入),通过全序列自注意力机制进行生成

3. AR-Diffusion 特征对齐模块

  • 如果说 Qwen3-VL 是大脑,DiT 是手,那么 AR-Diffusion 特征对齐模块 就是神经系统
  • 总体来说,特征对齐分为三个步骤:
    1. 多层级特征融合:不仅仅用 qwen3-vl-8b 的最后一层输出,而是使用了模型深层的多层特征
    2. 统一条件编码:将 backbone 输出的特征重新按照模态拆分,并分别压缩,再用双向 transformer 融合,作为 condition encoding
    3. 上下文条件注入:将原图通过 vae 编码为 noise,和 condition encoding 合并作为单流 DiT 的输入
  • 用一个例子说明:图像编辑任务,图像是一只猫,shape = (256, 256, 3),文本是:“给这只猫戴个红色的帽子”
    1. 假设图像经过 vit 压缩之后的 token 长度 L_v = 224,文本 token 长度 L_t = 32,总长度 L_seq = 256
    2. qwen3-vl-8b6 层的特征,每层 shape = (256, 4096)4096qwen3-vl-8bhidden size
    3. 6 层特征作平均池化 (6, 256, 4096) -> (256, 4096)
    4. 分离图文特征:(256, 4096) -> (32, 4096) + (224, 4096)
    5. 文本特征压缩:(32, 4096) -> mlp -> (32, 1024)1024DiThidden size
    6. 图像特征压缩:(224, 4096) -> QFormer -> (64, 4096)QFormer 可以将任意长度的视觉特征编码成固定长度(64)的特征,做法是用 64querycross attention
    7. 图文特征重新融合得到 条件特征 :文本 (32, 1024) + 视觉 (64, 1024) -> 拼接为 (96, 1024) -> 双向 Transformer 编码 -> 输出 条件特征 (96, 1024)
    8. 原图用 vae 压缩为 噪声潜变量(256, 256, 3) -> vae -> (32, 32, 4) -> flatten -> (1024, 4) -> mlp -> (1024, 1024)
    9. 噪声潜变量和条件特征合并作为单流 DiT 输入:(96, 1024) + (1024, 1024) -> (1120, 1024) -> DiT -> Target Image

4. 训练策略

4.1 预训练

预训练过程冻结 backbone

  • 第一阶段:生成基础(Pre-Stage 1
    • 目标:建立文本到图像(T2I)的基础映射关系。
    • 数据:仅使用 T2I 数据,分辨率限制在 512×512512 \times 512
    • 策略:DiT 仅接收来自 AR 骨干网的文本特征 作为条件。此时,生成专家(Generation Experts)和 DiT 参数从零开始训练,而理解参数被冻结。
    • 逻辑:在低分辨率和单一模态条件下,模型更容易收敛,学会基本的物体形状和构图。
  • 第二阶段:复杂指令与编辑(Pre-Stage 2
    • 目标:引入图像编辑能力和高分辨率生成。
    • 数据:引入图像编辑(Image Editing)数据,分辨率提升至 1024×10241024 \times 1024
    • 策略:DiT 开始接收文本+视觉的双重特征。此时,模型开始学习如何理解输入图像(例如“把图中的猫变成狗”),并保留原图的背景结构。
    • 关键点:这一阶段是 Mammoth2 区别于普通 T2I 模型的关键,它开始具备“看图改图”的能力。

4.2 SFT

  • 所有参数都解冻,全部参与更新,通过联合优化,AR 部分学会了生成更适合 DiT 解码的语义规划,而 DiT 也学会了更好地适应 AR 的意图

4.3 RL

  • Mammoth2 引入了 DiffusionNFT (Negative-aware Fine-Tuning),这是一种针对流匹配模型的开创性 RL 算法 。

Thought

  • 为了保留 vlm 模型的能力作了很大的创新,理论上作生成任务的潜力很大
  • 理解和生成统一架构可能是未来发展方向

URL

TL;DR

  • 本文提出一个图/文生 3D 的模型 TRELLIS 和一种 3D 表征方式 SLAT (Structured LATent),具体来说 SLAT 表征包含两个部分:
    • coords: 整数坐标 (N, 4) int32 - [batch, x, y, z],用于描述 16316^3 voxel 空间下的物体 表面 占据的体素坐标
    • feats: 浮点特征 (N, 8) float32 - 8通道特征向量,用于描述每一个激活体素坐标上的特征编码,编码长度为 8
  • 关键就是两个 DiT 模型,两个模型的输出结果合并就是 SLAT 表征,两个 DiT 都是使用流匹配范式训练的
    • 第一个 DiT 模型从图像/文本特征中得到 3D 物体的表明占据体素坐标,这个模型被称为 SparseStructureFlowModel
    • 第二个 DiT 模型从图像/文本特征以及体素占据坐标得到每个激活体素的特征向量,这个模型被称为 ElasticSLatFlowModel
  • 模型在 3D 数据集上通过自监督方式训练得到,可从图片或文本得到不同的 3D 输出格式(3D Gaussian、Radiance Field、mesh 等)

Algorithm

2025-10-22_12-32.png

1. 推理过程

  • 以图片 -> 3D 为例:
flowchart TD
    A[输入图片] --> B[预处理 518x518]
    B --> C[DINOv2编码 Bx1369x1024]
    C --> D[Stage 1: Sparse Structure Flow]
    D --> E[噪声 Bx8x16x16x16]
    E --> F[DiT Transformer 24层]
    C --> F
    F --> G[Sparse Structure Latent Bx8x16x16x16]
    G --> H[Decoder]
    H --> I[稀疏坐标 Nx4]
    I --> J[Stage 2: SLAT Flow]
    J --> K[稀疏噪声 SparseTensor Nx8]
    K --> L[稀疏DiT Transformer 24层]
    C --> L
    L --> M[SLAT表征 SparseTensor Nx8 分辨率64^3]
    M --> N[Gaussian解码器]
    M --> O[Radiance Field解码器]
    M --> P[Mesh解码器]
    N --> Q[3D Gaussians]
    O --> S[Strivect辐射场]
    P --> T[三角网格]

0: 图像预处理

  • 功能:去除背景、裁剪、归一化
  • 输入:原始PIL图像(任意尺寸)
  • 输出:518×518 RGB图像,带alpha通道
  • 处理步骤:
    • 使用 rembg (U2-Net) 去除背景
    • 根据前景内容裁剪并居中
    • 调整大小到 518×518

1: 图像条件编码

  • 模型:DINOv2 ViT-L/14-reg
  • 输入:(B, 3, 518, 518) - 预处理后的图像
  • 输出:(B, N, 1024) - 图像特征tokens,其中 N = (518/14)² = 1369(37 x 37)
  • 处理:通过 DINOv2 提取 patch tokens

2: 稀疏结构生成 (Stage 1)

  • 模型:SparseStructureFlowModel(DiT)
  • 输入:
    • Noise: (B, 8, 16, 16, 16) - 3D噪声张量
    • Image condition: (B, 1369, 1024) - DINOv2特征
    • Timestep: (B,) - Flow matching时间步
  • 处理流程:
    • 将3D噪声 patchify(如果patch_size>1)
    • 添加3D位置编码
    • 通过24层 DiT transformer blocks,进行图像条件的交叉注意力
    • Unpatchify 回3D体积
  • 输出:(B, 8, 16, 16, 16) - 稀疏结构latent
  • 解码到坐标:
    • 通过 SparseStructureDecoder 解码为occupancy grid
    • Resolution: 16³
    • 提取 occupancy > 0 的体素坐标
  • 解码器输出:(B, 1, 16, 16, 16) - occupancy概率
  • 提取坐标:(N_occupied, 4) - [batch_idx, x, y, z],其中 N_occupied 是非空体素数量

3: SLAT 生成 (Stage 2) - 核心表征

  • 模型:ElasticSLatFlowModel
  • 输入:
    • Sparse noise: SparseTensor
      • coords: (N_voxels, 4) - 从Stage 1得到的坐标
      • feats: (N_voxels, 8) - 随机噪声特征
    • Image condition: (B, 1369, 1024) - DINOv2特征
    • Timestep: (B,) - Flow matching时间步
  • 处理流程:
    • 通过sparse linear层和下采样ResBlocks处理输入 (2倍下采样)
    • 添加稀疏3D位置编码
    • 通过24层稀疏DiT transformer blocks,带图像交叉注意力
    • 通过上采样ResBlocks和skip connections恢复分辨率
    • 输出层产生8通道特征
  • 输出 - SLAT表征:SparseTensor
    • coords: (N_voxels, 4) - [batch_idx, x, y, z],坐标在64³空间
    • feats: (N_voxels, 8) - 8通道特征,应用归一化

4: SLAT 解码到多种3D表征

  • SLAT (Structured Latent) 是TRELLIS的统一3D表征,以稀疏张量格式存储,可以解码为三种不同的3D资产格式:
    • Gaussian Splatting 解码
    • Radiance Field 解码
    • Mesh 解码

2. 训练过程

0. 数据准备与多视角特征提取

  • 输入:
    • 3D 资产模型(mesh 或 point cloud)
    • 渲染器生成的多视角 RGB 图像(约 150 views)
    • 每张图的相机姿态(R, t, intrinsics)
  • 过程:
    • 使用预训练视觉模型 DINOv2 提取每张图的 patch-level 特征(每张图被切成 37x37 个 patch,每个 patch 提取成 1024 长度的特征向量);
    • 将 3D 资产 voxel 化(64 x 64 x 64 分辨率)
    • 体素稀疏化,只保留表面区域激活体素,平均每个样本保留 20k 左右的体素;
    • 通过已知相机位姿把这些特征反投影到 3D voxel grid;
    • 在每个 voxel 聚合多视角特征(平均 / max pooling)→ 得到 voxel feature grid f(x, y, z);
  • 输出:
    • 稀疏 3D 特征 (pi, fi) 组成的序列(active voxels),pi 表示体素 i 的位置坐标,fi 表示体素 i 的特征向量

1. 训练稀疏结构生成模型 SparseStructureFlowModel

  • 输入:
    名称 Shape 含义
    noise (B, 8, 16, 16, 16) 初始高斯噪声
    image_feats (B, 1369, 1024) DINOv2 提取的图像特征
    t (B,) flow matching 时间步
  • 输出:
    名称 Shape 含义
    latent (B, 8, 16, 16, 16) 生成的体素 latent
    occupancy (B, 1, 16, 16, 16) occupancy 概率 (经解码器预测)
  • 监督信号:监督目标来自 ground-truth occupancy grids(稀疏体素结构),根据 3d 资产可以得到(64 下采样到 16)

2. 训练 SLAT 核心生成 (ElasticSLatFlowModel)

  • 输入:
    输入 说明
    Sparse coords Stage 1 输出的非空体素坐标
    Sparse feats 随机噪声特征 (8维)
    Image feats DINOv2 图像特征
    t 时间步,用于 flow matching
  • 输出:
    输出 Shape 含义
    SLAT coords (N_voxels, 4) 稀疏坐标 (batch,x,y,z)
    SLAT feats (N_voxels, 8) 稀疏latent特征
  • 监督信号(SLAT 表征无法直接监督,需要 decode 为 3D 才可以):
    • 多视图渲染监督
      • 对每个 3D 资产有多视角图像
      • SLAT → 解码为 NeRF / Gaussian → 渲染成图像
      • 与真实视图计算重建损失
    • 几何一致性监督:SLAT → 解码为 mesh / occupancy → 与真实 mesh 计算 Chamfer 距离
    • 稀疏特征正则化:约束 latent 特征的分布,促进稀疏性和稳定性

3. 训练多格式解码器

  • SLAT 是统一稀疏latent论文中提到它可被解码为三种不同的 3D 资产类型:Gaussian / Radiance Field / Mesh
  • 这三个解码器各自独立训练,共享同一 SLAT latent 空间

Thought

  • 3D 生成基本都几何结构 + 纹理信息分开训练,都是用 Diffusion 模型在隐空间做,图像 / 文本作为 Diffusioncondition
  • 生成的 3D 模型 -> 渲染得到某个视角的 2D 图片,和原始 3D 模型对应视角渲染图片做 pixel-level 的损失,是自监督 3D 生成模型绕不开的

URL

TL;DR

  • EMMAWaymo 提出的端到端自动驾驶 VLA 模型,不像 RT-1RT-2 模型输出离散化的动作 tokenEMMA 的输出是纯文本。
  • 模型架构特别简单,就是一个标准的 VLM({image, text} -> {text}),在预训练基础上进行自动驾驶数据和通用多模态数据共训练得到。
  • VLA 架构,没有除图和文之外的其他任何模态(例如激光雷达、毫米波雷达、高精度地图定位等)。

Algorithms

核心思想

  • 核心思想:将自动驾驶任务转化为语言建模问题
  • 将所有文本类数据都统一到纯自然语言空间,不做任何的编码,这样所有任务都能在一个“语言空间”中统一表示与训练,利用预训练权重中的语言知识和推理能力。
类型 表示方式
视觉输入 (V) 多摄像头拼接图像或视频帧
非视觉输入 (T) 文本(包括导航意图、历史状态等)
输出 (O) 文本(轨迹点、3D框、车道线等)

模型架构

2025-10-23_21-04.png

模型能力

  • EMMA 有三大能力:
    • 端到端运动规划(E2E Planning)
    • 推理增强 (Chain-of-Thought)
    • 通用多任务(Generalist EMMA)
  • 这三种能力均来自于同一个模型的同一组参数,区别只是 prompt 中描述的任务内容不同。

1. 端到端运动规划能力

  • 输入:
    • 视频图像 V(多相机)
    • 高层导航意图 Tintent(如 “turn left”)
    • 历史自车轨迹 Tego(BEV坐标文本序列)
  • 输出:
    • 未来轨迹 Otrajectory:未来 Tf 个时刻的坐标 {(xt, yt)},纯文本形式输出
  • 监督:
    • 自监督,仅需未来轨迹真实值(无需人工标注)
    • 无需 HD 地图,无需 LiDAR

2. 思维链推理增强能力

  • 在输出轨迹前,模型生成 4 层“推理文本”,如下
层级 含义 示例
R1 场景描述(天气、路况等) “It’s sunny and the road is four-lane…”
R2 关键物体及其坐标 “Pedestrian at [9.01, 3.22], vehicle at [11.58, 0.35]”
R3 关键物体行为 “The pedestrian looks toward the road and may cross.”
R4 元驾驶决策 “I should keep my current low speed.”
  • 训练标签自动生成:
    • 由专家感知模型+启发式算法+Gemini生成描述 → 无需人工标注
  • 效果:
    • CoT 提升规划质量约 6.7%
    • 关键贡献来自 meta-decision (+3.0%) 与关键对象识别 (+1.5%)

3. 通用多任务

2025-10-23_22-59.png

模型表现

  • 一句话概括:在所有任务上,超越所有 SOTA

Thought

  • 和某车企复杂晦涩的 VLA 框架相比,EMMA 简直太优雅了,没有复杂的设计,统一框架上解决所有自动驾驶问题
  • 如果只做端到端,甚至不需要数据标注,只需要把驾驶数据的数量和质量堆上去就行
  • 可以只输出感知结果,也可以端到端,甚至 CoT 还可以做到白盒端到端,太雅了…

URL

TL;DR

  • 本文算是第一个真正意义上的 VLA 模型,和 RT-1 相比,有两个显著特征:
    • 使用了大量互联网的 VL 数据和 Robotic 数据混合训练做 co-fine-tuning
    • 模型更大了,5 - 55B 规模参数量,且 co-fine-tuning 的起点模型就是在互联网数据上训练收敛的模型

Algorithm

RT-1 的意义是什么?

  • RT-1 最主要的意义是证明了:将机器人控制信号离散化为动作 token,用 transformer 架构直接预测是可行的
  • 既然架构是可行的,那就顺其自然做 scaling(放大数据量和参数量)

RT-2 做了什么?

1. 起点是个 VLM

  • RT-2 分别用了谷歌自家的 PaLM-EPaLI-X 两个多模态模型做为起点模型,两个模型都是单独使用的
  • 二者都是 ViT + LLM 架构,且都是 {image, text} -> {text} 范式

2. 统一 VLRobotic 数据格式

2025-10-23_13-47.png

  • 方法和 RT-1 一样,将机器人动作信号离散化,并编码成动作 token
  • 具体来说:
    • 8-DoF 动作空间:6 个自由度(位置 + 旋转)+ 手爪开合 + “终止”标志
    • 每个连续维度被离散化为 256 个 bin
    • 每个 bin 对应一个 token
  • 将 VLM 已有的 tokens 与这 256 个离散量联系起来,才能实现 VLM 到 VLA 的转换
  • PaLI-X 和 PaLM-E 使用不同的 tokenization 方法,action tokens 需要与分别与其保持一致:
    • PaLI-X:1000 以内的每个数字都有一个相应的token,因此只需将 256 个离散量等于 256 个整数即可
    • PaLM-E:将最少出现的 256 个 tokens 覆盖掉,分别对应 256 个离散量

3. 混合数据集

  • vision-language datasets:来源于 PaLI-x 和 Palm-e 所使用的数据集,数据包括:
    • VQA: visual question answering
    • Captioning
    • unstructured interwoven image and text examples
    • PaLI数据集(WebLI)大小:10B images and covering over 109 anguages
    • Palm-e 使用多个数据集联合训练,其中 WebLI 占比52.4%
  • robotics dataset:RT1 dataset(13个机器人,17个月收集得到的数据)
  • 混合方式:co-fine-tuning 过程中,对数据做加权
    • RT-2-PaLI-X 中 robotics 数据占 50%
    • RT-2-PaLM-E 中 robotics 数据占 66%

泛化能力如何?

1. 在没见过的物体、背景、环境下的泛化能力

2025-10-23_14-08.png

2025-10-23_14-09.png

在同场景下,RT-2RT-1 差别很小,在没见过的物体、背景、环境下,RT-2 遥遥领先

2. 涌现能力

2025-10-23_14-19.png

3. CoT 能力

  • robotic 数据的动作 token 之前,加入 resoning token,会提高模型解决问题的能力

2025-10-23_14-20.png

resoning 即图中的 plan 部分

局限性和展望

  • 局限:
    • 物理技能仍局限于机器人数据分布(无法生成全新动作)。
    • 推理频率低,云端计算成本高。
  • 未来方向:
    • 结合人类视频学习新技能;
    • 模型量化与蒸馏以提升实时性;
    • 更多开源 VLM 融合(如 LLaVA、InternVL 等)。

Thought

  • 很有趣的工作,ppl 非常简单,引领了之后 VLA 的发展

URL

TL;DR

  • 本文的标题很直白《通过反向传播进行无监督领域自适应》,目标是实现无监督领域自适应。
  • 提出了梯度反转层 (Gradient Reversal Layer, GRL),用于无监督领域自适应任务中,通过在特征提取器和领域分类器之间插入 GRL,实现特征提取器在训练过程中 最大化 领域分类器的损失,从而通过对抗学习到领域不变的特征表示。
  • GAN 的思想有些类似,不过 GAN 的目标是生成,GRL 的目标是表征。

Algorithm

什么是无监督领域自适应

  • 领域自适应 (Domain Adaptation) 是迁移学习 (Transfer Learning) 的核心问题之一,旨在将从源领域 (Source Domain) 学习到的知识迁移到目标领域 (Target Domain)。
    GRL_1.png
  • 比如上图所示,有一批 MNIST 黑底白字的手写数字识别数据集,包含图像和标签,作为源领域;还有一批背景和字体颜色都随机的 MNIST-M 手写数字识别数据集,只有图像没有标签,作为目标领域。
  • 目标是利用源领域的有标签数据训练一个分类器,并使其在目标领域上也能有较好的性能。

如何实现无监督领域自适应

  • 最直接的方法:在源领域数据上训练,直接迁移到目标领域。但效果显然会很差,因为两个领域的数据分布差异较大。
  • 本文提出的方法:通过对抗训练学习领域不变的特征表示。
    GRL_2.png
  • 模型分成三个部分:
    • 特征提取器 (Feature Extractor, G_f): 提取输入数据的特征表示,上图绿色部分。
    • 标签分类器 (Label Classifier, G_y): 根据特征表示预测标签,上图蓝色部分。
    • 领域分类器 (Domain Classifier, G_d): 根据特征表示预测数据来自哪个领域(源领域或目标领域),上图红色部分。
  • 训练目标:
    • 最小化标签分类器的损失,使其在源领域数据上表现良好。
    • 最大化(注意这里是最大化,不是最小化)领域分类器的损失,使特征提取器学习到领域不变的特征表示。
  • 损失函数:
    • 标签分类器只对源领域数据计算交叉熵损失 LyL_y
    • 领域分类器对源领域和目标领域数据都计算交叉熵损失 LdL_d
    • 两个损失直接求和就是总损失 L=Ly+LdL=L_y + L_d(一定要注意这里没有 λ-\lambda,别被论文带沟里去)。
  • 梯度反转层 (Gradient Reversal Layer, GRL) 如何起作用:
    • 位置:插入在特征提取器和领域分类器之间。
    • 前向传播:GRL 不改变输入,直接将特征传递给领域分类器。
    • 反向传播:GRL 将从领域分类器传回的梯度乘以一个负的常数 ,使得特征提取器在更新时朝着最大化领域分类器损失的方向调整参数。
  • 梯度反转层的数学表达(设输入特征为 xGRL 的输出为 y):
    • 前向传播:y = x
    • 反向传播:dL/dx = -λ * dL/dy
  • 梯度反转层的 PyTorch 实现:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch
from torch import nn
from torch.autograd import Function

class GradientReversal(Function):
@staticmethod
def forward(ctx, x, alpha):
ctx.save_for_backward(x, alpha)
# 啥都不做,透传
return x

@staticmethod
def backward(ctx, grad_output):
grad_input = None
_, alpha = ctx.saved_tensors
if ctx.needs_input_grad[0]:
# 梯度乘以 -alpha 实现反转
grad_input = -alpha * grad_output
return grad_input, None
  • λ\lambda 的调节:从 0 逐渐增加到 1,意义:
    • 训练初期,λ=0\lambda=0,模型主要关注源领域的标签分类任务,确保分类器能学到有用的特征。
    • 随着训练进行,逐渐增加 λ\lambda,模型开始更多地关注领域分类器的对抗任务,促使特征提取器学习到领域不变的特征表示。

Thoughts

  • 很有趣的对抗学习设计,很简单但很有效。
  • 本质就是让领域分类器掌握的领域分类知识泄露给特征提取器,从而让特征提取器不断调整自己提取的特征,使得领域分类器无法区分源领域和目标领域的数据。