From eb110220d9a39a294479433cefc274e42506737e Mon Sep 17 00:00:00 2001 From: Ilya Kostrikov Date: Sat, 16 Sep 2017 17:04:14 -0400 Subject: [PATCH] Refactor code to add ppo easier --- main.py | 44 ++++++++++++++++++++++---------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/main.py b/main.py index a4a4531..1898f36 100755 --- a/main.py +++ b/main.py @@ -164,34 +164,34 @@ def main(): episode_rewards *= masks[step].cpu() - # Reshape to do in a single forward pass for all steps - values, logits = actor_critic(Variable(states.view(-1, *states.size()[-3:]))) - log_probs = F.log_softmax(logits) - probs = F.softmax(logits) - - # Unreshape - logits_size = (args.num_steps + 1, args.num_processes, logits.size(-1)) - - log_probs = F.log_softmax(logits).view(logits_size)[:-1] - probs = F.softmax(logits).view(logits_size)[:-1] - - values = values.view(args.num_steps + 1, args.num_processes, 1) - logits = logits.view(logits_size)[:-1] - - action_log_probs = log_probs.gather(2, Variable(actions.unsqueeze(2))) - - dist_entropy = -(log_probs * probs).sum(-1).mean() - - returns[-1] = values[-1].data + returns[-1] = actor_critic(Variable(states[-1], volatile=True))[0].data for step in reversed(range(args.num_steps)): returns[step] = returns[step + 1] * \ args.gamma * masks[step] + rewards[step] - value_loss = (values[:-1] - Variable(returns[:-1])).pow(2).mean() + # Reshape to do in a single forward pass for all steps + values, logits = actor_critic(Variable(states[:-1].view(-1, *states.size()[-3:]))) + log_probs = F.log_softmax(logits) + probs = F.softmax(logits) - advantages = returns[:-1] - values[:-1].data - action_loss = -(Variable(advantages) * action_log_probs).mean() + # Unreshape + logits_size = (args.num_steps, args.num_processes, logits.size(-1)) + + log_probs = F.log_softmax(logits).view(logits_size) + probs = F.softmax(logits).view(logits_size) + + values = values.view(args.num_steps, args.num_processes, 1) + logits = logits.view(logits_size) + + action_log_probs = log_probs.gather(2, Variable(actions.unsqueeze(2))) + + dist_entropy = -(log_probs * probs).sum(-1).mean() + + advantages = Variable(returns[:-1]) - values + value_loss = advantages.pow(2).mean() + + action_loss = -(Variable(advantages.data) * action_log_probs).mean() optimizer.zero_grad() (value_loss * args.value_loss_coef + action_loss - dist_entropy * args.entropy_coef).backward()