```python
import numpy as np
import torch
import torch.utils.data
from torch import nn
from torch.distributions.normal import Normal


class PPOActorCriticNetwork(nn.Module):
    """PPO actor-critic network: a shared trunk with a policy head and a value head."""

    def __init__(self, obs_space_dims: int, action_space_dims: int):
        super().__init__()
        hidden_space1 = 16
        hidden_space2 = 32
        # Shared feature extractor used by both the policy and value heads.
        self.shared_net = nn.Sequential(
            nn.Linear(obs_space_dims, hidden_space1),
            nn.Tanh(),
            nn.Linear(hidden_space1, hidden_space2),
            nn.Tanh(),
        )
        self.policy_mean_net = nn.Linear(hidden_space2, action_space_dims)
        self.policy_stddev_net = nn.Linear(hidden_space2, action_space_dims)
        self.value_net = nn.Linear(hidden_space2, 1)

    def forward(self, x: torch.Tensor):
        """Forward pass: returns action distribution parameters and the state value."""
        shared_features = self.shared_net(x.float())
        action_means = self.policy_mean_net(shared_features)
        # Softplus keeps the standard deviation positive; the clamp prevents it from
        # collapsing to zero.
        action_stddevs = torch.nn.functional.softplus(
            self.policy_stddev_net(shared_features)
        )
        action_stddevs = torch.clamp(action_stddevs, min=1e-3)
        state_values = self.value_net(shared_features)
        return action_means, action_stddevs, state_values


class PPOAgent:
    """PPO agent with a clipped surrogate objective, GAE, and an entropy bonus."""

    def __init__(self, obs_space_dims: int, action_space_dims: int):
        self.learning_rate = 1e-5
        self.gamma = 0.99          # discount factor
        self.lam = 0.95            # GAE lambda
        self.eps_clip = 0.1        # PPO clipping range
        self.entropy_coeff = 0.01  # entropy bonus weight
        self.epochs = 10           # optimisation epochs per update
        self.batch_size = 64

        self.net = PPOActorCriticNetwork(obs_space_dims, action_space_dims)
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=self.learning_rate)
        self.memory = []

    def sample_action(self, state: np.ndarray):
        """Sample an action from the current policy."""
        state = torch.tensor(np.array([state]), dtype=torch.float32)
        # No gradients are needed while collecting rollouts; returning plain Python
        # scalars keeps the stored transitions detached from the graph.
        with torch.no_grad():
            action_means, action_stddevs, state_value = self.net(state)
        distrib = Normal(action_means[0], action_stddevs[0] + 1e-3)
        action = distrib.sample()
        action = torch.clamp(action, -3.0, 3.0)
        log_prob = distrib.log_prob(action).sum()
        return action.numpy(), log_prob.item(), state_value.item()

    def store_transition(self, transition):
        """Store one transition (state, action, log_prob, reward, done, value)."""
        self.memory.append(transition)

    def process_batch(self):
        """Compute GAE advantages and return targets from the stored trajectory."""
        states, actions, log_probs, rewards, dones, state_values = zip(*self.memory)

        states = torch.tensor(np.array(states), dtype=torch.float32)
        actions = torch.tensor(np.array(actions), dtype=torch.float32)
        log_probs = torch.tensor(log_probs, dtype=torch.float32)
        state_values = torch.tensor(state_values, dtype=torch.float32).squeeze()
        rewards = torch.tensor(rewards, dtype=torch.float32)
        dones = torch.tensor(dones, dtype=torch.float32)

        # Bootstrap with the next stored state's value; terminal steps contribute zero.
        next_state_values = torch.cat([state_values[1:], torch.tensor([0.0])]) * (
            1 - dones
        )
        deltas = rewards + self.gamma * next_state_values * (1 - dones) - state_values

        # Generalized Advantage Estimation (GAE), accumulated backwards in time.
        advantages = torch.zeros_like(rewards)
        running_adv = 0.0
        for t in reversed(range(len(rewards))):
            running_adv = deltas[t] + self.gamma * self.lam * running_adv * (
                1 - dones[t]
            )
            advantages[t] = running_adv

        returns = advantages + state_values
        # Normalise advantages for more stable policy updates.
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        self.memory = []
        return states, actions, log_probs, returns, advantages

    def update(self):
        """Update the policy and value networks with the clipped PPO objective."""
        states, actions, old_log_probs, returns, advantages = self.process_batch()

        dataset = torch.utils.data.TensorDataset(
            states, actions, old_log_probs, returns, advantages
        )
        loader = torch.utils.data.DataLoader(
            dataset, batch_size=self.batch_size, shuffle=True
        )

        for _ in range(self.epochs):
            for (
                batch_states,
                batch_actions,
                batch_old_log_probs,
                batch_returns,
                batch_advantages,
            ) in loader:
                action_means, action_stddevs, state_values = self.net(batch_states)
                distrib = Normal(action_means, action_stddevs + 1e-3)
                log_probs = distrib.log_prob(batch_actions).sum(axis=-1)
                entropy = distrib.entropy().sum(axis=-1)

                # Probability ratio between the new policy and the one that
                # collected the data, then the clipped surrogate objective.
                ratios = torch.exp(log_probs - batch_old_log_probs)
                surr1 = ratios * batch_advantages
                surr2 = (
                    torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip)
                    * batch_advantages
                )
                policy_loss = -torch.min(surr1, surr2).mean()
                value_loss = nn.functional.mse_loss(
                    state_values.squeeze(), batch_returns
                )
                entropy_loss = -entropy.mean()

                loss = (
                    policy_loss
                    + 0.5 * value_loss
                    + self.entropy_coeff * entropy_loss
                )

                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.net.parameters(), max_norm=0.1)
                self.optimizer.step()
```
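The listing above defines only the agent; how it is driven is not shown. Below is a minimal training-loop sketch, not part of the original code: it assumes the Gymnasium API and the `InvertedPendulum-v4` environment (whose 1-dimensional action range of [-3, 3] matches the clamp in `sample_action` and requires the MuJoCo extras), and the rollout length `update_every = 2048` is an arbitrary illustrative choice.

```python
# Hypothetical usage sketch: collect rollouts and update the agent periodically.
# Environment name and update_every are assumptions, not part of the original listing.
import gymnasium as gym

env = gym.make("InvertedPendulum-v4")
obs_dims = env.observation_space.shape[0]
act_dims = env.action_space.shape[0]
agent = PPOAgent(obs_dims, act_dims)

update_every = 2048  # assumed number of environment steps between PPO updates
step_count = 0

for episode in range(1000):
    state, _ = env.reset()
    done = False
    while not done:
        action, log_prob, value = agent.sample_action(state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        agent.store_transition((state, action, log_prob, reward, float(done), value))
        state = next_state
        step_count += 1
        if step_count % update_every == 0:
            agent.update()  # clears the buffer after each update
```

Updating on a fixed step budget rather than per episode keeps each PPO update working on a reasonably large batch, which matters because the advantages are normalised over the whole buffer.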