diff --git a/main.py b/main.py index 0021104..53e6129 100755 --- a/main.py +++ b/main.py @@ -98,8 +98,8 @@ def main(): # Sample actions value, logits = actor_critic(Variable(rollouts.states[step], volatile=True)) probs = F.softmax(logits) - log_probs = F.log_softmax(logits).data action = probs.multinomial().data + action_log_probs = F.log_softmax(logits).gather(1, action).data cpu_actions = action.cpu().numpy() # Obser reward and next state @@ -119,7 +119,7 @@ def main(): current_state *= masks.unsqueeze(2).unsqueeze(2) update_current_state(state) - rollouts.insert(step, current_state, action, value.data, log_probs, reward, masks) + rollouts.insert(step, current_state, action, value.data, action_log_probs, reward, masks) next_value = actor_critic(Variable(rollouts.states[-1], volatile=True))[0].data @@ -187,8 +187,7 @@ def main(): log_probs = F.log_softmax(logits) action_log_probs = log_probs.gather(1, Variable(actions_batch)) - old_log_probs_batch = rollouts.old_log_probs.view(-1, rollouts.old_log_probs.size(-1))[indices] - old_action_log_probs = old_log_probs_batch.gather(1, actions_batch) + old_action_log_probs = rollouts.action_log_probs.view(-1, rollouts.action_log_probs.size(-1))[indices] ratio = torch.exp(action_log_probs - Variable(old_action_log_probs)) adv_targ = Variable(advantages.view(-1, 1)[indices]) @@ -196,7 +195,6 @@ def main(): surr2 = ratio.clamp(1.0 - args.clip_param, 1.0 + args.clip_param) * adv_targ action_loss = -torch.min(surr1, surr2).mean() # PPO's pessimistic surrogate (L^CLIP) - log_probs = F.log_softmax(logits) probs = F.softmax(logits) dist_entropy = -(log_probs * probs).sum(-1).mean() diff --git a/storage.py b/storage.py index 1032edd..27ad7fd 100644 --- a/storage.py +++ b/storage.py @@ -6,8 +6,7 @@ class RolloutStorage(object): self.states = torch.zeros(num_steps + 1, num_processes, *obs_shape) self.rewards = torch.zeros(num_steps, num_processes, 1) self.value_preds = torch.zeros(num_steps + 1, num_processes, 1) - self.old_log_probs = torch.zeros(num_steps, num_processes, - action_shape) + self.action_log_probs = torch.zeros(num_steps, num_processes, 1) self.returns = torch.zeros(num_steps + 1, num_processes, 1) self.actions = torch.LongTensor(num_steps, num_processes, 1) self.masks = torch.zeros(num_steps, num_processes, 1) @@ -16,17 +15,17 @@ class RolloutStorage(object): self.states = self.states.cuda() self.rewards = self.rewards.cuda() self.value_preds = self.value_preds.cuda() - self.old_log_probs = self.old_log_probs.cuda() + self.action_log_probs = self.action_log_probs.cuda() self.returns = self.returns.cuda() self.actions = self.actions.cuda() self.masks = self.masks.cuda() - def insert(self, step, current_state, action, value_pred, old_log_probs, + def insert(self, step, current_state, action, value_pred, action_log_probs, reward, mask): self.states[step + 1].copy_(current_state) self.actions[step].copy_(action) self.value_preds[step].copy_(value_pred) - self.old_log_probs[step].copy_(old_log_probs) + self.action_log_probs[step].copy_(action_log_probs) self.rewards[step].copy_(reward) self.masks[step].copy_(mask)