Zhangzhe's Blog

The projection of my life.

0%

Monte Carlo Tree Search: A Review of Recent Modifications and Applications

URL

TL;DR

  • 蒙特卡洛树搜索(Monte Carlo Tree Search, 简称 MCTS)旨在解决在拥有海量状态空间的复杂决策环境下的 “维度灾难” 问题,即:当我们无法穷举所有可能性,且不知道如何准确评估当前状态时,如何利用有限的计算资源,找到当前的最优解?
  • 核心思想是:不再穷举所有分支,而是像一个有经验的玩家一样:把更多的算力(搜索时间)倾斜到那些看起来赢面更大、更有希望的分支上,同时偶尔去探索一下没怎么走过的分支,防止错过奇招。这在算法中被称为探索与利用(Exploration vs. Exploitation)的平衡。

Algorithm

  • MCTS 通过不断的循环迭代来构建一棵搜索树。每次迭代包含四个经典的步骤:
    • 选择 (Selection)
    • 扩展 (Expansion)
    • 模拟 (Simulation / Rollout)
    • 回溯 / 反向传播 (Backpropagation)

1. 选择

  • 算法从树的根节点(当前局面)出发,向下遍历,直到找到一个还没有被完全展开的节点(即这个节点下还有未尝试过的动作)。
  • 在向下走的过程中,如果遇到多个子节点,它会使用一个公式来决定走哪边,最常用的公式是 UCT (Upper Confidence Bound applied to Trees):

    UCT=WiNi+clnNNiUCT= \frac{W_i}{N_i} + c \sqrt{\frac{\ln N}{N_i}}

    • WiW_i:该节点下的胜利次数
    • NiN_i:该节点被访问的次数
    • Wi/NiW_i / N_i利用(Exploitation) 项,即胜率
    • 胜率越高的节点,越容易被选中。NN:父节点的总访问次数
    • cc:探索常数(通常取 2\sqrt{2}
    • clnNNic \sqrt{\frac{\ln N}{N_i}}探索(Exploration) 项。一个节点被访问的次数 NiN_i 越少,这项的值就越大,就越容易被选中去探索。

2. 扩展

  • 当根据 UCT 公式找到一个未完全展开的叶子节点时,就给这个节点增加一个新的子节点,代表采取了一个从未尝试过的合法动作,进入了一个新的状态。

3. 模拟

  • 从新扩展出的新节点开始,完全随机地(或者基于某种简单策略) 往下模拟对局,直到分出胜负(到达终止状态,比如游戏结束)
  • 因为是随机模拟(这就是“蒙特卡洛”一词的来源),计算速度非常快。它不需要评估函数,而是直接把游戏“快进”到结局,看看是赢是输

4. 回溯 / 反向传播

  • 根据上一步模拟出的最终结果(比如:赢了记为 1,输了记为 0,平局记为 0.5),从这个新节点沿着刚才的路径一直往回走到根节点
  • 在回退的过程中,更新沿途每一个节点的统计数据:该节点的访问次数 NiN_i 加 1,如果模拟结果是胜利,该节点的胜利次数 WiW_i 加 1

循环与决策

  • MCTS 会在给定的时间限制或计算资源内,疯狂地重复这四个步骤(成千上万次)。每次循环都会让这棵搜索树长得更大,且统计数据越来越准确。
  • 当时间耗尽时,算法会看一眼根节点的所有子节点。此时不需要再看 UCT 公式了,直接选择访问次数最多(或者胜率最高)的那个子节点,作为最终的决策动作。

MCTS 在井字棋上的示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import math
import random
import copy
from abc import ABC, abstractmethod

# ==========================================
# 1. 定义状态接口与井字棋实现
# ==========================================
class GameState(ABC):
@abstractmethod
def get_legal_actions(self): pass
@abstractmethod
def apply_action(self, action): pass
@abstractmethod
def is_terminal(self): pass

class TicTacToe(GameState):
def __init__(self, board=None, current_player=1):
# board: 0 为空,1 为玩家 X,-1 为玩家 O
self.board = board if board is not None else [0] * 9
self.current_player = current_player

def get_legal_actions(self):
"""返回所有空位的索引"""
return [i for i, x in enumerate(self.board) if x == 0]

def apply_action(self, action):
"""落子,并返回切换了玩家的新状态"""
new_board = copy.deepcopy(self.board)
new_board[action] = self.current_player
return TicTacToe(new_board, current_player=-self.current_player)

def get_winner(self):
"""检查是否有玩家获胜,返回 1, -1, 或 None"""
lines = [
(0, 1, 2), (3, 4, 5), (6, 7, 8), # 横
(0, 3, 6), (1, 4, 7), (2, 5, 8), # 竖
(0, 4, 8), (2, 4, 6) # 交叉
]
for a, b, c in lines:
if self.board[a] != 0 and self.board[a] == self.board[b] == self.board[c]:
return self.board[a]
return None

def is_terminal(self):
"""判断游戏是否结束(有人赢了,或者没空位了)"""
return self.get_winner() is not None or len(self.get_legal_actions()) == 0

def get_result(self, player):
"""
核心方法:返回特定 player 在当前终局下的得分。
赢=1.0,输=0.0,平局=0.5
"""
winner = self.get_winner()
if winner is None:
return 0.5 # 平局
if winner == player:
return 1.0 # 该玩家获胜
else:
return 0.0 # 该玩家落败

def __str__(self):
symbols = {1: 'X', -1: 'O', 0: '.'}
res = ""
for i in range(3):
res += " ".join([symbols[self.board[i*3 + j]] for j in range(3)]) + "\n"
return res

# ==========================================
# 2. MCTS 树节点与主算法
# ==========================================
class MCTSNode:
def __init__(self, state: TicTacToe, parent=None, action=None):
self.state = state
self.parent = parent
self.action = action
self.children = []
self.untried_actions = state.get_legal_actions()

self.visits = 0 # 记录此节点到达的总次数
self.wins = 0.0 # 记录“导致到达此节点的动作”所带来的胜场

def is_fully_expanded(self):
return len(self.untried_actions) == 0

def best_child(self, c_param=1.414):
"""使用 UCT 公式选择最优子节点"""
choices_weights = []
for child in self.children:
exploit = child.wins / child.visits
explore = c_param * math.sqrt((2 * math.log(self.visits) / child.visits))
choices_weights.append(exploit + explore)
return self.children[choices_weights.index(max(choices_weights))]

def mcts_search(root_state: TicTacToe, iterations: int = 2000):
root_node = MCTSNode(state=root_state)

for _ in range(iterations):
node = root_node

# 第一步:选择 (Selection)
while node.is_fully_expanded() and not node.state.is_terminal():
node = node.best_child()

# 第二步:扩展 (Expansion)
if not node.state.is_terminal():
action = random.choice(node.untried_actions)
node.untried_actions.remove(action)
next_state = node.state.apply_action(action)

child_node = MCTSNode(state=next_state, parent=node, action=action)
node.children.append(child_node)
node = child_node

# 第三步:模拟 (Rollout) - 纯随机下棋直到结束
current_rollout_state = node.state
while not current_rollout_state.is_terminal():
possible_actions = current_rollout_state.get_legal_actions()
action = random.choice(possible_actions)
current_rollout_state = current_rollout_state.apply_action(action)

# 第四步:反向传播 (Backpropagation)
while node is not None:
node.visits += 1
# 如果 node 有父节点,说明 node 是由父节点的玩家执行 action 得到的
if node.parent is not None:
# 获取做出该动作的玩家
player_who_made_the_move = node.parent.state.current_player
# 将终局结果站在该玩家的视角累加到该节点
node.wins += current_rollout_state.get_result(player_who_made_the_move)
node = node.parent

# 搜索结束,选择访问量最大(最稳健)的子节点动作
best_final_node = max(root_node.children, key=lambda c: c.visits)
return best_final_node.action

def play_game():
game = TicTacToe()
print("=====================================")
print("欢迎来到 MCTS 井字棋对战!")
print("AI 执 X (先手),你执 O (后手)。")
print("棋盘的位置索引 (0-8) 如下所示:")
print(" 0 | 1 | 2 ")
print("---+---+---")
print(" 3 | 4 | 5 ")
print("---+---+---")
print(" 6 | 7 | 8 ")
print("=====================================\n")

print("初始棋盘:")
print(game)

while not game.is_terminal():
if game.current_player == 1:
# AI 的回合
print("AI (MCTS) 正在思考中... (这可能需要一两秒)")
# 这里的 iterations 控制 AI 的算力(难度)。
# 对于井字棋,2000 次迭代 AI 已经几乎是“不败”的了。
best_move = mcts_search(game, iterations=2000)
print(f"AI 落子在位置: {best_move}\n")
game = game.apply_action(best_move)
else:
# 玩家的回合
legal_moves = game.get_legal_actions()
while True:
try:
move = input(f"轮到你了,请输入你要落子的位置 {legal_moves}: ")
move = int(move)
if move in legal_moves:
break
else:
print("⚠ 错误:该位置已有棋子或超出范围,请重新输入。")
except ValueError:
print("⚠ 错误:请输入一个有效的数字。")

print(f"你落子在位置: {move}\n")
game = game.apply_action(move)

# 打印当前棋盘状态
print(game)

# 游戏结束,判定胜负
print("=====================================")
winner = game.get_winner()
if winner == 1:
print("🤖 游戏结束:AI (X) 获胜!(MCTS 确实很难被击败哦)")
elif winner == -1:
print("🎉 游戏结束:恭喜你 (O) 获胜!(这说明你找到了算法的盲区!)")
else:
print("🤝 游戏结束:平局!(面对完美的对手,平局已经是最好的结果了)")
print("=====================================")

if __name__ == "__main__":
play_game()

Thoughts

  • 对于这种简单的游戏,只经过简单的蒙特卡洛树搜索,AI 先手基本可以做到完全不败,但对于像围棋这种搜索空间的游戏,单用 MCTS 实际上无法做到穷举例
  • MCTS 在论文中是每个回合都重建一棵新树,在下个回合就抛弃,工程上通常是一局对弈只建立一棵树,随着对弈根节点向下,只有出现之前未搜索到的局面,才会建新树