mirror of
https://github.com/wassname/pytorch-a2c-ppo-acktr.git
synced 2026-06-27 16:20:05 +08:00
Store a single log probability of actions
This commit is contained in:
@@ -98,8 +98,8 @@ def main():
|
||||
# Sample actions
|
||||
value, logits = actor_critic(Variable(rollouts.states[step], volatile=True))
|
||||
probs = F.softmax(logits)
|
||||
log_probs = F.log_softmax(logits).data
|
||||
action = probs.multinomial().data
|
||||
action_log_probs = F.log_softmax(logits).gather(1, action).data
|
||||
cpu_actions = action.cpu().numpy()
|
||||
|
||||
# Obser reward and next state
|
||||
@@ -119,7 +119,7 @@ def main():
|
||||
current_state *= masks.unsqueeze(2).unsqueeze(2)
|
||||
|
||||
update_current_state(state)
|
||||
rollouts.insert(step, current_state, action, value.data, log_probs, reward, masks)
|
||||
rollouts.insert(step, current_state, action, value.data, action_log_probs, reward, masks)
|
||||
|
||||
next_value = actor_critic(Variable(rollouts.states[-1], volatile=True))[0].data
|
||||
|
||||
@@ -187,8 +187,7 @@ def main():
|
||||
log_probs = F.log_softmax(logits)
|
||||
action_log_probs = log_probs.gather(1, Variable(actions_batch))
|
||||
|
||||
old_log_probs_batch = rollouts.old_log_probs.view(-1, rollouts.old_log_probs.size(-1))[indices]
|
||||
old_action_log_probs = old_log_probs_batch.gather(1, actions_batch)
|
||||
old_action_log_probs = rollouts.action_log_probs.view(-1, rollouts.action_log_probs.size(-1))[indices]
|
||||
|
||||
ratio = torch.exp(action_log_probs - Variable(old_action_log_probs))
|
||||
adv_targ = Variable(advantages.view(-1, 1)[indices])
|
||||
@@ -196,7 +195,6 @@ def main():
|
||||
surr2 = ratio.clamp(1.0 - args.clip_param, 1.0 + args.clip_param) * adv_targ
|
||||
action_loss = -torch.min(surr1, surr2).mean() # PPO's pessimistic surrogate (L^CLIP)
|
||||
|
||||
log_probs = F.log_softmax(logits)
|
||||
probs = F.softmax(logits)
|
||||
|
||||
dist_entropy = -(log_probs * probs).sum(-1).mean()
|
||||
|
||||
+4
-5
@@ -6,8 +6,7 @@ class RolloutStorage(object):
|
||||
self.states = torch.zeros(num_steps + 1, num_processes, *obs_shape)
|
||||
self.rewards = torch.zeros(num_steps, num_processes, 1)
|
||||
self.value_preds = torch.zeros(num_steps + 1, num_processes, 1)
|
||||
self.old_log_probs = torch.zeros(num_steps, num_processes,
|
||||
action_shape)
|
||||
self.action_log_probs = torch.zeros(num_steps, num_processes, 1)
|
||||
self.returns = torch.zeros(num_steps + 1, num_processes, 1)
|
||||
self.actions = torch.LongTensor(num_steps, num_processes, 1)
|
||||
self.masks = torch.zeros(num_steps, num_processes, 1)
|
||||
@@ -16,17 +15,17 @@ class RolloutStorage(object):
|
||||
self.states = self.states.cuda()
|
||||
self.rewards = self.rewards.cuda()
|
||||
self.value_preds = self.value_preds.cuda()
|
||||
self.old_log_probs = self.old_log_probs.cuda()
|
||||
self.action_log_probs = self.action_log_probs.cuda()
|
||||
self.returns = self.returns.cuda()
|
||||
self.actions = self.actions.cuda()
|
||||
self.masks = self.masks.cuda()
|
||||
|
||||
def insert(self, step, current_state, action, value_pred, old_log_probs,
|
||||
def insert(self, step, current_state, action, value_pred, action_log_probs,
|
||||
reward, mask):
|
||||
self.states[step + 1].copy_(current_state)
|
||||
self.actions[step].copy_(action)
|
||||
self.value_preds[step].copy_(value_pred)
|
||||
self.old_log_probs[step].copy_(old_log_probs)
|
||||
self.action_log_probs[step].copy_(action_log_probs)
|
||||
self.rewards[step].copy_(reward)
|
||||
self.masks[step].copy_(mask)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user