Refactor code to add ppo easier

This commit is contained in:
Ilya Kostrikov
2017-09-16 17:04:14 -04:00
parent 6d17d59f36
commit eb110220d9
+22 -22
View File
@@ -164,34 +164,34 @@ def main():
episode_rewards *= masks[step].cpu()
# Reshape to do in a single forward pass for all steps
values, logits = actor_critic(Variable(states.view(-1, *states.size()[-3:])))
log_probs = F.log_softmax(logits)
probs = F.softmax(logits)
# Unreshape
logits_size = (args.num_steps + 1, args.num_processes, logits.size(-1))
log_probs = F.log_softmax(logits).view(logits_size)[:-1]
probs = F.softmax(logits).view(logits_size)[:-1]
values = values.view(args.num_steps + 1, args.num_processes, 1)
logits = logits.view(logits_size)[:-1]
action_log_probs = log_probs.gather(2, Variable(actions.unsqueeze(2)))
dist_entropy = -(log_probs * probs).sum(-1).mean()
returns[-1] = values[-1].data
returns[-1] = actor_critic(Variable(states[-1], volatile=True))[0].data
for step in reversed(range(args.num_steps)):
returns[step] = returns[step + 1] * \
args.gamma * masks[step] + rewards[step]
value_loss = (values[:-1] - Variable(returns[:-1])).pow(2).mean()
# Reshape to do in a single forward pass for all steps
values, logits = actor_critic(Variable(states[:-1].view(-1, *states.size()[-3:])))
log_probs = F.log_softmax(logits)
probs = F.softmax(logits)
advantages = returns[:-1] - values[:-1].data
action_loss = -(Variable(advantages) * action_log_probs).mean()
# Unreshape
logits_size = (args.num_steps, args.num_processes, logits.size(-1))
log_probs = F.log_softmax(logits).view(logits_size)
probs = F.softmax(logits).view(logits_size)
values = values.view(args.num_steps, args.num_processes, 1)
logits = logits.view(logits_size)
action_log_probs = log_probs.gather(2, Variable(actions.unsqueeze(2)))
dist_entropy = -(log_probs * probs).sum(-1).mean()
advantages = Variable(returns[:-1]) - values
value_loss = advantages.pow(2).mean()
action_loss = -(Variable(advantages.data) * action_log_probs).mean()
optimizer.zero_grad()
(value_loss * args.value_loss_coef + action_loss - dist_entropy * args.entropy_coef).backward()