Create an act function

This commit is contained in:
Ilya Kostrikov
2017-09-22 12:29:21 -04:00
parent 6c949f291e
commit f4fc4c6064
2 changed files with 10 additions and 6 deletions
+3 -6
View File
@@ -96,11 +96,8 @@ def main():
for j in range(num_updates):
for step in range(args.num_steps):
# Sample actions
value, logits = actor_critic(Variable(rollouts.states[step], volatile=True))
probs = F.softmax(logits)
action = probs.multinomial().data
action_log_probs = F.log_softmax(logits).gather(1, action).data
cpu_actions = action.cpu().numpy()
value, action, action_log_probs = actor_critic.act(Variable(rollouts.states[step], volatile=True))
cpu_actions = action.data.cpu().numpy()
# Obser reward and next state
state, reward, done, info = envs.step(cpu_actions)
@@ -119,7 +116,7 @@ def main():
current_state *= masks.unsqueeze(2).unsqueeze(2)
update_current_state(state)
rollouts.insert(step, current_state, action, value.data, action_log_probs, reward, masks)
rollouts.insert(step, current_state, action.data, value.data, action_log_probs.data, reward, masks)
next_value = actor_critic(Variable(rollouts.states[-1], volatile=True))[0].data
+7
View File
@@ -77,3 +77,10 @@ class ActorCritic(torch.nn.Module):
return self.ab_fc2(self.critic_linear(x)), self.ab_fc3(
self.actor_linear(x))
def act(self, inputs):
value, logits = self(inputs)
probs = F.softmax(logits)
action = probs.multinomial()
action_log_probs = F.log_softmax(logits).gather(1, action)
return value, action, action_log_probs