Store a single log probability of actions

This commit is contained in:
Ilya Kostrikov
2017-09-21 19:25:16 -04:00
parent 475de22519
commit 6c949f291e
2 changed files with 7 additions and 10 deletions
+4 -5
View File
@@ -6,8 +6,7 @@ class RolloutStorage(object):
self.states = torch.zeros(num_steps + 1, num_processes, *obs_shape)
self.rewards = torch.zeros(num_steps, num_processes, 1)
self.value_preds = torch.zeros(num_steps + 1, num_processes, 1)
self.old_log_probs = torch.zeros(num_steps, num_processes,
action_shape)
self.action_log_probs = torch.zeros(num_steps, num_processes, 1)
self.returns = torch.zeros(num_steps + 1, num_processes, 1)
self.actions = torch.LongTensor(num_steps, num_processes, 1)
self.masks = torch.zeros(num_steps, num_processes, 1)
@@ -16,17 +15,17 @@ class RolloutStorage(object):
self.states = self.states.cuda()
self.rewards = self.rewards.cuda()
self.value_preds = self.value_preds.cuda()
self.old_log_probs = self.old_log_probs.cuda()
self.action_log_probs = self.action_log_probs.cuda()
self.returns = self.returns.cuda()
self.actions = self.actions.cuda()
self.masks = self.masks.cuda()
def insert(self, step, current_state, action, value_pred, old_log_probs,
def insert(self, step, current_state, action, value_pred, action_log_probs,
reward, mask):
self.states[step + 1].copy_(current_state)
self.actions[step].copy_(action)
self.value_preds[step].copy_(value_pred)
self.old_log_probs[step].copy_(old_log_probs)
self.action_log_probs[step].copy_(action_log_probs)
self.rewards[step].copy_(reward)
self.masks[step].copy_(mask)