mirror of
https://github.com/wassname/pytorch-a2c-ppo-acktr.git
synced 2026-06-27 16:20:05 +08:00
Recompute old probabilities for PPO, to make continuous actions work with obs filter
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import copy
|
||||
import glob
|
||||
import os
|
||||
|
||||
@@ -91,10 +92,13 @@ def main():
|
||||
current_state = current_state.cuda()
|
||||
rollouts.cuda()
|
||||
|
||||
if args.algo == 'ppo':
|
||||
old_model = copy.deepcopy(actor_critic)
|
||||
|
||||
for j in range(num_updates):
|
||||
for step in range(args.num_steps):
|
||||
# Sample actions
|
||||
value, action, action_log_probs = actor_critic.act(Variable(rollouts.states[step], volatile=True))
|
||||
value, action = actor_critic.act(Variable(rollouts.states[step], volatile=True))
|
||||
cpu_actions = action.data.cpu().numpy()
|
||||
|
||||
# Observe reward and next state
|
||||
@@ -114,7 +118,7 @@ def main():
|
||||
current_state *= masks.unsqueeze(2).unsqueeze(2)
|
||||
|
||||
update_current_state(state)
|
||||
rollouts.insert(step, current_state, action.data, value.data, action_log_probs.data, reward, masks)
|
||||
rollouts.insert(step, current_state, action.data, value.data, reward, masks)
|
||||
|
||||
next_value = actor_critic(Variable(rollouts.states[-1], volatile=True))[0].data
|
||||
|
||||
@@ -141,7 +145,7 @@ def main():
|
||||
value_noise = value_noise.cuda()
|
||||
|
||||
sample_values = values + value_noise
|
||||
vf_fisher_loss = - (values - Variable(sample_values.data)).pow(2).mean()
|
||||
vf_fisher_loss = -(values - Variable(sample_values.data)).pow(2).mean()
|
||||
|
||||
fisher_loss = pg_fisher_loss + vf_fisher_loss
|
||||
optimizer.acc_stats = True
|
||||
@@ -158,6 +162,9 @@ def main():
|
||||
elif args.algo == 'ppo':
|
||||
advantages = rollouts.returns[:-1] - rollouts.value_preds[:-1]
|
||||
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-5)
|
||||
|
||||
old_model.load_state_dict(actor_critic.state_dict())
|
||||
|
||||
for _ in range(args.ppo_epoch):
|
||||
sampler = BatchSampler(SubsetRandomSampler(range(args.num_processes * args.num_steps)), args.batch_size * args.num_processes, drop_last=False)
|
||||
for indices in sampler:
|
||||
@@ -171,9 +178,9 @@ def main():
|
||||
# Reshape so that all steps are evaluated in a single forward pass
|
||||
values, action_log_probs, dist_entropy = actor_critic.evaluate_actions(Variable(states_batch), Variable(actions_batch))
|
||||
|
||||
old_action_log_probs = rollouts.action_log_probs.view(-1, rollouts.action_log_probs.size(-1))[indices]
|
||||
_, old_action_log_probs, _ = old_model.evaluate_actions(Variable(states_batch, volatile=True), Variable(actions_batch, volatile=True))
|
||||
|
||||
ratio = torch.exp(action_log_probs - Variable(old_action_log_probs))
|
||||
ratio = torch.exp(action_log_probs - Variable(old_action_log_probs.data))
|
||||
adv_targ = Variable(advantages.view(-1, 1)[indices])
|
||||
surr1 = ratio * adv_targ
|
||||
surr2 = ratio.clamp(1.0 - args.clip_param, 1.0 + args.clip_param) * adv_targ
|
||||
|
||||
@@ -82,8 +82,7 @@ class ActorCritic(torch.nn.Module):
|
||||
value, logits = self(inputs)
|
||||
probs = F.softmax(logits)
|
||||
action = probs.multinomial()
|
||||
action_log_probs = F.log_softmax(logits).gather(1, action)
|
||||
return value, action, action_log_probs
|
||||
return value, action
|
||||
|
||||
def evaluate_actions(self, inputs, actions):
|
||||
assert inputs.dim() == 4, "Expect to have inputs in num_processes * num_steps x ... format"
|
||||
|
||||
+1
-5
@@ -6,7 +6,6 @@ class RolloutStorage(object):
|
||||
self.states = torch.zeros(num_steps + 1, num_processes, *obs_shape)
|
||||
self.rewards = torch.zeros(num_steps, num_processes, 1)
|
||||
self.value_preds = torch.zeros(num_steps + 1, num_processes, 1)
|
||||
self.action_log_probs = torch.zeros(num_steps, num_processes, 1)
|
||||
self.returns = torch.zeros(num_steps + 1, num_processes, 1)
|
||||
self.actions = torch.LongTensor(num_steps, num_processes, 1)
|
||||
self.masks = torch.zeros(num_steps, num_processes, 1)
|
||||
@@ -15,17 +14,14 @@ class RolloutStorage(object):
|
||||
self.states = self.states.cuda()
|
||||
self.rewards = self.rewards.cuda()
|
||||
self.value_preds = self.value_preds.cuda()
|
||||
self.action_log_probs = self.action_log_probs.cuda()
|
||||
self.returns = self.returns.cuda()
|
||||
self.actions = self.actions.cuda()
|
||||
self.masks = self.masks.cuda()
|
||||
|
||||
def insert(self, step, current_state, action, value_pred, action_log_probs,
|
||||
reward, mask):
|
||||
def insert(self, step, current_state, action, value_pred, reward, mask):
|
||||
self.states[step + 1].copy_(current_state)
|
||||
self.actions[step].copy_(action)
|
||||
self.value_preds[step].copy_(value_pred)
|
||||
self.action_log_probs[step].copy_(action_log_probs)
|
||||
self.rewards[step].copy_(reward)
|
||||
self.masks[step].copy_(mask)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user