diff --git a/main.py b/main.py index 1898f36..9e49385 100755 --- a/main.py +++ b/main.py @@ -15,7 +15,7 @@ from envs import make_env from model import ActorCritic from vizualize_atari import visdom_plot -parser = argparse.ArgumentParser(description='A3C') +parser = argparse.ArgumentParser(description='A2C') parser.add_argument('--lr', type=float, default=7e-4, help='learning rate (default: 7e-4)') parser.add_argument('--eps', type=float, default=1e-5, @@ -24,6 +24,10 @@ parser.add_argument('--alpha', type=float, default=0.99, help='RMSprop optimizer apha (default: 0.99)') parser.add_argument('--gamma', type=float, default=0.99, help='discount factor for rewards (default: 0.99)') +parser.add_argument('--use-gae', action='store_true', default=False, + help='use generalized advantage estimation') +parser.add_argument('--tau', type=float, default=0.95, + help='gae parameter (default: 0.95)') parser.add_argument('--entropy-coef', type=float, default=0.01, help='entropy term coefficient (default: 0.01)') parser.add_argument('--value-loss-coef', type=float, default=0.5, @@ -112,6 +116,7 @@ def main(): update_current_state(state) rewards = torch.zeros(args.num_steps, args.num_processes, 1) + value_preds = torch.zeros(args.num_steps + 1, args.num_processes, 1) returns = torch.zeros(args.num_steps + 1, args.num_processes, 1) actions = torch.LongTensor(args.num_steps, args.num_processes) @@ -125,6 +130,7 @@ def main(): states = states.cuda() current_state = current_state.cuda() rewards = rewards.cuda() + value_preds = value_preds.cuda() returns = returns.cuda() actions = actions.cuda() masks = masks.cuda() @@ -132,7 +138,7 @@ def main(): for j in range(num_updates): for step in range(args.num_steps): # Sample actions - _, logits = actor_critic(Variable(states[step], volatile=True)) + value, logits = actor_critic(Variable(states[step], volatile=True)) probs = F.softmax(logits) log_probs = F.log_softmax(logits).data actions[step] = probs.multinomial().data @@ -156,6 +162,7 @@ def main(): update_current_state(state) states[step + 1].copy_(current_state) + value_preds[step].copy_(value.data) rewards[step].copy_(reward) masks[step].copy_(torch.from_numpy(np_masks).unsqueeze(1)) @@ -164,11 +171,20 @@ def main(): episode_rewards *= masks[step].cpu() - returns[-1] = actor_critic(Variable(states[-1], volatile=True))[0].data + if args.use_gae: + value_preds[-1] = actor_critic(Variable(states[-1], volatile=True))[0].data + gae = 0 + for step in reversed(range(args.num_steps)): + delta = rewards[step] + args.gamma * value_preds[step + 1] * masks[step] - value_preds[step] + gae = delta + args.gamma * args.tau * masks[step] * gae + + returns[step] = gae + value_preds[step] + else: + returns[-1] = actor_critic(Variable(states[-1], volatile=True))[0].data + for step in reversed(range(args.num_steps)): + returns[step] = returns[step + 1] * \ + args.gamma * masks[step] + rewards[step] - for step in reversed(range(args.num_steps)): - returns[step] = returns[step + 1] * \ - args.gamma * masks[step] + rewards[step] # Reshape to do in a single forward pass for all steps values, logits = actor_critic(Variable(states[:-1].view(-1, *states.size()[-3:])))