import numpy as np
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import random
from tqdm import tqdm
class DQN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(DQN, self).__init__()
        self.input_dim = input_dim
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, output_dim)

    def forward(self, x):
        # Observations are integer state indices; one-hot encode them into a
        # fixed-size vector before the first linear layer.
        x = torch.nn.functional.one_hot(
            x.to(torch.int64), num_classes=self.input_dim
        ).float()
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)


class DQNAgent:
    def __init__(
        self,
        env,
        learning_rate: float,
        initial_epsilon: float,
        epsilon_decay: float,
        final_epsilon: float,
        discount_factor: float = 0.95,
        batch_size: int = 64,
        memory_size: int = 10000,
    ):
        self.env = env
        self.lr = learning_rate
        self.discount_factor = discount_factor
        self.epsilon = initial_epsilon
        self.epsilon_decay = epsilon_decay
        self.final_epsilon = final_epsilon
        self.batch_size = batch_size
        self.memory = deque(maxlen=memory_size)
        self.training_error = []

        self.q_network = DQN(
            input_dim=env.observation_space.n, output_dim=env.action_space.n
        )
        self.target_network = DQN(
            input_dim=env.observation_space.n, output_dim=env.action_space.n
        )
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=self.lr)

    def get_action(self, obs: int) -> int:
        """Choose an action with an epsilon-greedy policy."""
        if np.random.random() < self.epsilon:
            return self.env.action_space.sample()
        else:
            obs_tensor = torch.tensor([obs], dtype=torch.int64)
            q_values = self.q_network(obs_tensor)
            return torch.argmax(q_values).item()

    def update(self):
        """Sample a random batch from the replay buffer and update the Q-network."""
        if len(self.memory) < self.batch_size:
            return

        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.tensor(states, dtype=torch.int64)
        actions = torch.tensor(actions, dtype=torch.int64)
        rewards = torch.tensor(rewards, dtype=torch.float32)
        next_states = torch.tensor(next_states, dtype=torch.int64)
        dones = torch.tensor(dones, dtype=torch.float32)

        # Q-values predicted for the actions that were actually taken
        q_values = self.q_network(states)
        q_value = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        # TD target computed from the frozen target network
        with torch.no_grad():
            next_q_values = self.target_network(next_states)
            max_next_q_values = next_q_values.max(dim=1)[0]
            target_q_values = (
                rewards + (1 - dones) * self.discount_factor * max_next_q_values
            )

        loss = nn.functional.mse_loss(q_value, target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.training_error.append(loss.item())

    def decay_epsilon(self):
        """Linearly decay epsilon, but never below final_epsilon."""
        self.epsilon = max(self.final_epsilon, self.epsilon - self.epsilon_decay)

    def store_experience(self, obs, action, reward, next_obs, done):
        """Store a transition in the replay buffer."""
        self.memory.append((obs, action, reward, next_obs, done))

    def update_target_network(self):
        """Copy the Q-network weights into the target network."""
        self.target_network.load_state_dict(self.q_network.state_dict())
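Taxi-v3 has a Discrete(500) observation space and a Discrete(6) action space, which is why the network sizes are read directly from the environment. Since each observation is just an integer state index, forward() first one-hot encodes it. A standalone illustration of that encoding (the example state index here is arbitrary, not from the original code):

import torch
import torch.nn.functional as F

obs = torch.tensor([42])                                   # one Taxi-v3 state index
encoded = F.one_hot(obs.to(torch.int64), num_classes=500).float()
print(encoded.shape)           # torch.Size([1, 500])
print(encoded[0, 42].item())   # 1.0 at the state's index, 0.0 elsewhere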
learning_rate = 0.001
n_episodes = 1000 + 2  # two extra episodes so the video trigger at episode 1000 still fires
start_epsilon = 1.0
epsilon_decay = start_epsilon / (n_episodes / 2)  # linear decay over the first half of training
final_epsilon = 0.1
batch_size = 64
memory_size = 10000
target_update_freq = 10  # sync the target network every 10 episodes
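A quick sanity check on this schedule (a standalone snippet, not part of the training script): epsilon is decayed once per episode, so it reaches the floor roughly halfway through training.

start_epsilon, final_epsilon = 1.0, 0.1
epsilon_decay = start_epsilon / ((1000 + 2) / 2)            # ≈ 0.002 per episode
print((start_epsilon - final_epsilon) / epsilon_decay)      # ≈ 451 episodes to reach 0.1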
env = gym.make("Taxi-v3", render_mode="rgb_array")
env = gym.wrappers.RecordVideo(
    env,
    video_folder="./Taxi_video",
    episode_trigger=lambda episode_id: episode_id % 200 == 0,
    name_prefix="episode",
)
agent = DQNAgent(
    env=env,
    learning_rate=learning_rate,
    initial_epsilon=start_epsilon,
    epsilon_decay=epsilon_decay,
    final_epsilon=final_epsilon,
    batch_size=batch_size,
    memory_size=memory_size,
)
for episode in tqdm(range(n_episodes)):
    obs, info = env.reset()
    done = False

    while not done:
        action = agent.get_action(obs)
        next_obs, reward, terminated, truncated, info = env.step(action)
        agent.store_experience(obs, action, reward, next_obs, terminated or truncated)
        agent.update()
        obs = next_obs
        done = terminated or truncated

    # Periodically sync the target network, and decay exploration once per episode
    if episode % target_update_freq == 0:
        agent.update_target_network()
    agent.decay_epsilon()

env.close()
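After training it is worth inspecting the losses the agent stored in training_error. A minimal plotting sketch (assuming matplotlib is installed; the moving-average window of 100 is an arbitrary choice, not from the original script):

import numpy as np
import matplotlib.pyplot as plt

losses = np.array(agent.training_error)
window = 100
# Smooth the per-update MSE losses so the overall trend is visible
smoothed = np.convolve(losses, np.ones(window) / window, mode="valid")

plt.plot(smoothed)
plt.xlabel("gradient update")
plt.ylabel("smoothed TD loss (MSE)")
plt.title("DQN training error on Taxi-v3")
plt.show()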