This commit is contained in:
Ilya Kostrikov
2017-09-16 20:53:16 -04:00
parent eb110220d9
commit f09b3a75e4
+22 -6
View File
@@ -15,7 +15,7 @@ from envs import make_env
from model import ActorCritic
from vizualize_atari import visdom_plot
parser = argparse.ArgumentParser(description='A3C')
parser = argparse.ArgumentParser(description='A2C')
parser.add_argument('--lr', type=float, default=7e-4,
help='learning rate (default: 7e-4)')
parser.add_argument('--eps', type=float, default=1e-5,
@@ -24,6 +24,10 @@ parser.add_argument('--alpha', type=float, default=0.99,
help='RMSprop optimizer apha (default: 0.99)')
parser.add_argument('--gamma', type=float, default=0.99,
help='discount factor for rewards (default: 0.99)')
parser.add_argument('--use-gae', action='store_true', default=False,
help='use generalized advantage estimation')
parser.add_argument('--tau', type=float, default=0.95,
help='gae parameter (default: 0.95)')
parser.add_argument('--entropy-coef', type=float, default=0.01,
help='entropy term coefficient (default: 0.01)')
parser.add_argument('--value-loss-coef', type=float, default=0.5,
@@ -112,6 +116,7 @@ def main():
update_current_state(state)
rewards = torch.zeros(args.num_steps, args.num_processes, 1)
value_preds = torch.zeros(args.num_steps + 1, args.num_processes, 1)
returns = torch.zeros(args.num_steps + 1, args.num_processes, 1)
actions = torch.LongTensor(args.num_steps, args.num_processes)
@@ -125,6 +130,7 @@ def main():
states = states.cuda()
current_state = current_state.cuda()
rewards = rewards.cuda()
value_preds = value_preds.cuda()
returns = returns.cuda()
actions = actions.cuda()
masks = masks.cuda()
@@ -132,7 +138,7 @@ def main():
for j in range(num_updates):
for step in range(args.num_steps):
# Sample actions
_, logits = actor_critic(Variable(states[step], volatile=True))
value, logits = actor_critic(Variable(states[step], volatile=True))
probs = F.softmax(logits)
log_probs = F.log_softmax(logits).data
actions[step] = probs.multinomial().data
@@ -156,6 +162,7 @@ def main():
update_current_state(state)
states[step + 1].copy_(current_state)
value_preds[step].copy_(value.data)
rewards[step].copy_(reward)
masks[step].copy_(torch.from_numpy(np_masks).unsqueeze(1))
@@ -164,11 +171,20 @@ def main():
episode_rewards *= masks[step].cpu()
returns[-1] = actor_critic(Variable(states[-1], volatile=True))[0].data
if args.use_gae:
value_preds[-1] = actor_critic(Variable(states[-1], volatile=True))[0].data
gae = 0
for step in reversed(range(args.num_steps)):
delta = rewards[step] + args.gamma * value_preds[step + 1] * masks[step] - value_preds[step]
gae = delta + args.gamma * args.tau * masks[step] * gae
returns[step] = gae + value_preds[step]
else:
returns[-1] = actor_critic(Variable(states[-1], volatile=True))[0].data
for step in reversed(range(args.num_steps)):
returns[step] = returns[step + 1] * \
args.gamma * masks[step] + rewards[step]
for step in reversed(range(args.num_steps)):
returns[step] = returns[step + 1] * \
args.gamma * masks[step] + rewards[step]
# Reshape to do in a single forward pass for all steps
values, logits = actor_critic(Variable(states[:-1].view(-1, *states.size()[-3:])))