From f4fc4c60646738d289aa09b4451912551ca5122e Mon Sep 17 00:00:00 2001 From: Ilya Kostrikov Date: Fri, 22 Sep 2017 12:29:21 -0400 Subject: [PATCH] Create an act function --- main.py | 9 +++------ model.py | 7 +++++++ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/main.py b/main.py index 53e6129..b983a7b 100755 --- a/main.py +++ b/main.py @@ -96,11 +96,8 @@ def main(): for j in range(num_updates): for step in range(args.num_steps): # Sample actions - value, logits = actor_critic(Variable(rollouts.states[step], volatile=True)) - probs = F.softmax(logits) - action = probs.multinomial().data - action_log_probs = F.log_softmax(logits).gather(1, action).data - cpu_actions = action.cpu().numpy() + value, action, action_log_probs = actor_critic.act(Variable(rollouts.states[step], volatile=True)) + cpu_actions = action.data.cpu().numpy() # Obser reward and next state state, reward, done, info = envs.step(cpu_actions) @@ -119,7 +116,7 @@ def main(): current_state *= masks.unsqueeze(2).unsqueeze(2) update_current_state(state) - rollouts.insert(step, current_state, action, value.data, action_log_probs, reward, masks) + rollouts.insert(step, current_state, action.data, value.data, action_log_probs.data, reward, masks) next_value = actor_critic(Variable(rollouts.states[-1], volatile=True))[0].data diff --git a/model.py b/model.py index c4e6565..c533b36 100755 --- a/model.py +++ b/model.py @@ -77,3 +77,10 @@ class ActorCritic(torch.nn.Module): return self.ab_fc2(self.critic_linear(x)), self.ab_fc3( self.actor_linear(x)) + + def act(self, inputs): + value, logits = self(inputs) + probs = F.softmax(logits) + action = probs.multinomial() + action_log_probs = F.log_softmax(logits).gather(1, action) + return value, action, action_log_probs