Recompute old probabilities for PPO, to make continuous actions work with obs filter

This commit is contained in:
Ilya Kostrikov
2017-09-24 23:00:14 -04:00
parent 6ee53d245d
commit 54a0f98180
3 changed files with 14 additions and 12 deletions
+12 -5
View File
@@ -1,3 +1,4 @@
import copy
import glob
import os
@@ -91,10 +92,13 @@ def main():
current_state = current_state.cuda()
rollouts.cuda()
if args.algo == 'ppo':
old_model = copy.deepcopy(actor_critic)
for j in range(num_updates):
for step in range(args.num_steps):
# Sample actions
value, action, action_log_probs = actor_critic.act(Variable(rollouts.states[step], volatile=True))
value, action = actor_critic.act(Variable(rollouts.states[step], volatile=True))
cpu_actions = action.data.cpu().numpy()
# Observe reward and next state
@@ -114,7 +118,7 @@ def main():
current_state *= masks.unsqueeze(2).unsqueeze(2)
update_current_state(state)
rollouts.insert(step, current_state, action.data, value.data, action_log_probs.data, reward, masks)
rollouts.insert(step, current_state, action.data, value.data, reward, masks)
next_value = actor_critic(Variable(rollouts.states[-1], volatile=True))[0].data
@@ -141,7 +145,7 @@ def main():
value_noise = value_noise.cuda()
sample_values = values + value_noise
vf_fisher_loss = - (values - Variable(sample_values.data)).pow(2).mean()
vf_fisher_loss = -(values - Variable(sample_values.data)).pow(2).mean()
fisher_loss = pg_fisher_loss + vf_fisher_loss
optimizer.acc_stats = True
@@ -158,6 +162,9 @@ def main():
elif args.algo == 'ppo':
advantages = rollouts.returns[:-1] - rollouts.value_preds[:-1]
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-5)
old_model.load_state_dict(actor_critic.state_dict())
for _ in range(args.ppo_epoch):
sampler = BatchSampler(SubsetRandomSampler(range(args.num_processes * args.num_steps)), args.batch_size * args.num_processes, drop_last=False)
for indices in sampler:
@@ -171,9 +178,9 @@ def main():
# Reshape to do in a single forward pass for all steps
values, action_log_probs, dist_entropy = actor_critic.evaluate_actions(Variable(states_batch), Variable(actions_batch))
old_action_log_probs = rollouts.action_log_probs.view(-1, rollouts.action_log_probs.size(-1))[indices]
_, old_action_log_probs, _ = old_model.evaluate_actions(Variable(states_batch, volatile=True), Variable(actions_batch, volatile=True))
ratio = torch.exp(action_log_probs - Variable(old_action_log_probs))
ratio = torch.exp(action_log_probs - Variable(old_action_log_probs.data))
adv_targ = Variable(advantages.view(-1, 1)[indices])
surr1 = ratio * adv_targ
surr2 = ratio.clamp(1.0 - args.clip_param, 1.0 + args.clip_param) * adv_targ
+1 -2
View File
@@ -82,8 +82,7 @@ class ActorCritic(torch.nn.Module):
value, logits = self(inputs)
probs = F.softmax(logits)
action = probs.multinomial()
action_log_probs = F.log_softmax(logits).gather(1, action)
return value, action, action_log_probs
return value, action
def evaluate_actions(self, inputs, actions):
assert inputs.dim() == 4, "Expect to have inputs in num_processes * num_steps x ... format"
+1 -5
View File
@@ -6,7 +6,6 @@ class RolloutStorage(object):
self.states = torch.zeros(num_steps + 1, num_processes, *obs_shape)
self.rewards = torch.zeros(num_steps, num_processes, 1)
self.value_preds = torch.zeros(num_steps + 1, num_processes, 1)
self.action_log_probs = torch.zeros(num_steps, num_processes, 1)
self.returns = torch.zeros(num_steps + 1, num_processes, 1)
self.actions = torch.LongTensor(num_steps, num_processes, 1)
self.masks = torch.zeros(num_steps, num_processes, 1)
@@ -15,17 +14,14 @@ class RolloutStorage(object):
self.states = self.states.cuda()
self.rewards = self.rewards.cuda()
self.value_preds = self.value_preds.cuda()
self.action_log_probs = self.action_log_probs.cuda()
self.returns = self.returns.cuda()
self.actions = self.actions.cuda()
self.masks = self.masks.cuda()
def insert(self, step, current_state, action, value_pred, action_log_probs,
reward, mask):
def insert(self, step, current_state, action, value_pred, reward, mask):
self.states[step + 1].copy_(current_state)
self.actions[step].copy_(action)
self.value_preds[step].copy_(value_pred)
self.action_log_probs[step].copy_(action_log_probs)
self.rewards[step].copy_(reward)
self.masks[step].copy_(mask)