mirror of
https://github.com/wassname/pytorch-a2c-ppo-acktr.git
synced 2026-06-27 16:20:05 +08:00
Refactor code to add ppo easier
This commit is contained in:
@@ -164,34 +164,34 @@ def main():
|
||||
|
||||
episode_rewards *= masks[step].cpu()
|
||||
|
||||
# Reshape to do in a single forward pass for all steps
|
||||
values, logits = actor_critic(Variable(states.view(-1, *states.size()[-3:])))
|
||||
log_probs = F.log_softmax(logits)
|
||||
probs = F.softmax(logits)
|
||||
|
||||
# Unreshape
|
||||
logits_size = (args.num_steps + 1, args.num_processes, logits.size(-1))
|
||||
|
||||
log_probs = F.log_softmax(logits).view(logits_size)[:-1]
|
||||
probs = F.softmax(logits).view(logits_size)[:-1]
|
||||
|
||||
values = values.view(args.num_steps + 1, args.num_processes, 1)
|
||||
logits = logits.view(logits_size)[:-1]
|
||||
|
||||
action_log_probs = log_probs.gather(2, Variable(actions.unsqueeze(2)))
|
||||
|
||||
dist_entropy = -(log_probs * probs).sum(-1).mean()
|
||||
|
||||
returns[-1] = values[-1].data
|
||||
returns[-1] = actor_critic(Variable(states[-1], volatile=True))[0].data
|
||||
|
||||
for step in reversed(range(args.num_steps)):
|
||||
returns[step] = returns[step + 1] * \
|
||||
args.gamma * masks[step] + rewards[step]
|
||||
|
||||
value_loss = (values[:-1] - Variable(returns[:-1])).pow(2).mean()
|
||||
# Reshape to do in a single forward pass for all steps
|
||||
values, logits = actor_critic(Variable(states[:-1].view(-1, *states.size()[-3:])))
|
||||
log_probs = F.log_softmax(logits)
|
||||
probs = F.softmax(logits)
|
||||
|
||||
advantages = returns[:-1] - values[:-1].data
|
||||
action_loss = -(Variable(advantages) * action_log_probs).mean()
|
||||
# Unreshape
|
||||
logits_size = (args.num_steps, args.num_processes, logits.size(-1))
|
||||
|
||||
log_probs = F.log_softmax(logits).view(logits_size)
|
||||
probs = F.softmax(logits).view(logits_size)
|
||||
|
||||
values = values.view(args.num_steps, args.num_processes, 1)
|
||||
logits = logits.view(logits_size)
|
||||
|
||||
action_log_probs = log_probs.gather(2, Variable(actions.unsqueeze(2)))
|
||||
|
||||
dist_entropy = -(log_probs * probs).sum(-1).mean()
|
||||
|
||||
advantages = Variable(returns[:-1]) - values
|
||||
value_loss = advantages.pow(2).mean()
|
||||
|
||||
action_loss = -(Variable(advantages.data) * action_log_probs).mean()
|
||||
|
||||
optimizer.zero_grad()
|
||||
(value_loss * args.value_loss_coef + action_loss - dist_entropy * args.entropy_coef).backward()
|
||||
|
||||
Reference in New Issue
Block a user