mirror of
https://github.com/wassname/pytorch-a2c-ppo-acktr.git
synced 2026-06-27 16:20:05 +08:00
Store a single log probability of actions
This commit is contained in:
+4
-5
@@ -6,8 +6,7 @@ class RolloutStorage(object):
|
||||
self.states = torch.zeros(num_steps + 1, num_processes, *obs_shape)
|
||||
self.rewards = torch.zeros(num_steps, num_processes, 1)
|
||||
self.value_preds = torch.zeros(num_steps + 1, num_processes, 1)
|
||||
self.old_log_probs = torch.zeros(num_steps, num_processes,
|
||||
action_shape)
|
||||
self.action_log_probs = torch.zeros(num_steps, num_processes, 1)
|
||||
self.returns = torch.zeros(num_steps + 1, num_processes, 1)
|
||||
self.actions = torch.LongTensor(num_steps, num_processes, 1)
|
||||
self.masks = torch.zeros(num_steps, num_processes, 1)
|
||||
@@ -16,17 +15,17 @@ class RolloutStorage(object):
|
||||
self.states = self.states.cuda()
|
||||
self.rewards = self.rewards.cuda()
|
||||
self.value_preds = self.value_preds.cuda()
|
||||
self.old_log_probs = self.old_log_probs.cuda()
|
||||
self.action_log_probs = self.action_log_probs.cuda()
|
||||
self.returns = self.returns.cuda()
|
||||
self.actions = self.actions.cuda()
|
||||
self.masks = self.masks.cuda()
|
||||
|
||||
def insert(self, step, current_state, action, value_pred, old_log_probs,
|
||||
def insert(self, step, current_state, action, value_pred, action_log_probs,
|
||||
reward, mask):
|
||||
self.states[step + 1].copy_(current_state)
|
||||
self.actions[step].copy_(action)
|
||||
self.value_preds[step].copy_(value_pred)
|
||||
self.old_log_probs[step].copy_(old_log_probs)
|
||||
self.action_log_probs[step].copy_(action_log_probs)
|
||||
self.rewards[step].copy_(reward)
|
||||
self.masks[step].copy_(mask)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user