mirror of
https://github.com/wassname/pytorch-a2c-ppo-acktr.git
synced 2026-06-27 16:20:05 +08:00
Create an act function
This commit is contained in:
@@ -96,11 +96,8 @@ def main():
|
||||
for j in range(num_updates):
|
||||
for step in range(args.num_steps):
|
||||
# Sample actions
|
||||
value, logits = actor_critic(Variable(rollouts.states[step], volatile=True))
|
||||
probs = F.softmax(logits)
|
||||
action = probs.multinomial().data
|
||||
action_log_probs = F.log_softmax(logits).gather(1, action).data
|
||||
cpu_actions = action.cpu().numpy()
|
||||
value, action, action_log_probs = actor_critic.act(Variable(rollouts.states[step], volatile=True))
|
||||
cpu_actions = action.data.cpu().numpy()
|
||||
|
||||
# Obser reward and next state
|
||||
state, reward, done, info = envs.step(cpu_actions)
|
||||
@@ -119,7 +116,7 @@ def main():
|
||||
current_state *= masks.unsqueeze(2).unsqueeze(2)
|
||||
|
||||
update_current_state(state)
|
||||
rollouts.insert(step, current_state, action, value.data, action_log_probs, reward, masks)
|
||||
rollouts.insert(step, current_state, action.data, value.data, action_log_probs.data, reward, masks)
|
||||
|
||||
next_value = actor_critic(Variable(rollouts.states[-1], volatile=True))[0].data
|
||||
|
||||
|
||||
@@ -77,3 +77,10 @@ class ActorCritic(torch.nn.Module):
|
||||
|
||||
return self.ab_fc2(self.critic_linear(x)), self.ab_fc3(
|
||||
self.actor_linear(x))
|
||||
|
||||
def act(self, inputs):
|
||||
value, logits = self(inputs)
|
||||
probs = F.softmax(logits)
|
||||
action = probs.multinomial()
|
||||
action_log_probs = F.log_softmax(logits).gather(1, action)
|
||||
return value, action, action_log_probs
|
||||
|
||||
Reference in New Issue
Block a user