Store a single log probability of actions

This commit is contained in:
Ilya Kostrikov
2017-09-21 19:25:16 -04:00
parent 475de22519
commit 6c949f291e
2 changed files with 7 additions and 10 deletions
+3 -5
View File
@@ -98,8 +98,8 @@ def main():
# Sample actions
value, logits = actor_critic(Variable(rollouts.states[step], volatile=True))
probs = F.softmax(logits)
log_probs = F.log_softmax(logits).data
action = probs.multinomial().data
action_log_probs = F.log_softmax(logits).gather(1, action).data
cpu_actions = action.cpu().numpy()
# Obser reward and next state
@@ -119,7 +119,7 @@ def main():
current_state *= masks.unsqueeze(2).unsqueeze(2)
update_current_state(state)
rollouts.insert(step, current_state, action, value.data, log_probs, reward, masks)
rollouts.insert(step, current_state, action, value.data, action_log_probs, reward, masks)
next_value = actor_critic(Variable(rollouts.states[-1], volatile=True))[0].data
@@ -187,8 +187,7 @@ def main():
log_probs = F.log_softmax(logits)
action_log_probs = log_probs.gather(1, Variable(actions_batch))
old_log_probs_batch = rollouts.old_log_probs.view(-1, rollouts.old_log_probs.size(-1))[indices]
old_action_log_probs = old_log_probs_batch.gather(1, actions_batch)
old_action_log_probs = rollouts.action_log_probs.view(-1, rollouts.action_log_probs.size(-1))[indices]
ratio = torch.exp(action_log_probs - Variable(old_action_log_probs))
adv_targ = Variable(advantages.view(-1, 1)[indices])
@@ -196,7 +195,6 @@ def main():
surr2 = ratio.clamp(1.0 - args.clip_param, 1.0 + args.clip_param) * adv_targ
action_loss = -torch.min(surr1, surr2).mean() # PPO's pessimistic surrogate (L^CLIP)
log_probs = F.log_softmax(logits)
probs = F.softmax(logits)
dist_entropy = -(log_probs * probs).sum(-1).mean()
+4 -5
View File
@@ -6,8 +6,7 @@ class RolloutStorage(object):
self.states = torch.zeros(num_steps + 1, num_processes, *obs_shape)
self.rewards = torch.zeros(num_steps, num_processes, 1)
self.value_preds = torch.zeros(num_steps + 1, num_processes, 1)
self.old_log_probs = torch.zeros(num_steps, num_processes,
action_shape)
self.action_log_probs = torch.zeros(num_steps, num_processes, 1)
self.returns = torch.zeros(num_steps + 1, num_processes, 1)
self.actions = torch.LongTensor(num_steps, num_processes, 1)
self.masks = torch.zeros(num_steps, num_processes, 1)
@@ -16,17 +15,17 @@ class RolloutStorage(object):
self.states = self.states.cuda()
self.rewards = self.rewards.cuda()
self.value_preds = self.value_preds.cuda()
self.old_log_probs = self.old_log_probs.cuda()
self.action_log_probs = self.action_log_probs.cuda()
self.returns = self.returns.cuda()
self.actions = self.actions.cuda()
self.masks = self.masks.cuda()
def insert(self, step, current_state, action, value_pred, old_log_probs,
def insert(self, step, current_state, action, value_pred, action_log_probs,
reward, mask):
self.states[step + 1].copy_(current_state)
self.actions[step].copy_(action)
self.value_preds[step].copy_(value_pred)
self.old_log_probs[step].copy_(old_log_probs)
self.action_log_probs[step].copy_(action_log_probs)
self.rewards[step].copy_(reward)
self.masks[step].copy_(mask)