From baba62437399003bc4f320b329d3ac452bab4616 Mon Sep 17 00:00:00 2001 From: efang96 Date: Mon, 13 Aug 2018 18:04:42 -0700 Subject: [PATCH] updated agent.compute_action to return rnn state (#2581) * updated agent.compute_action to return rnn state * updated compute_action method, added case for state=None * fixing lint --- python/ray/rllib/agents/agent.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/python/ray/rllib/agents/agent.py b/python/ray/rllib/agents/agent.py index d83e325f2..2e330f9e1 100644 --- a/python/ray/rllib/agents/agent.py +++ b/python/ray/rllib/agents/agent.py @@ -216,7 +216,11 @@ class Agent(Trainable): Arguments: observation (obj): observation from the environment. - state (list): RNN hidden state, if any. + state (list): RNN hidden state, if any. If state is not None, + then all of compute_single_action(...) is returned + (computed action, rnn state, logits dictionary). + Otherwise compute_single_action(...)[0] is + returned (computed action). policy_id (str): policy to query (only applies to multi-agent). """ @@ -224,10 +228,15 @@ class Agent(Trainable): state = [] filtered_obs = self.local_evaluator.filters[policy_id]( observation, update=False) + if state: + return self.local_evaluator.for_policy( + lambda p: p.compute_single_action( + filtered_obs, state, is_training=False), + policy_id=policy_id) return self.local_evaluator.for_policy( - lambda p: p.compute_single_action( - filtered_obs, state, is_training=False)[0], - policy_id=policy_id) + lambda p: p.compute_single_action( + filtered_obs, state, is_training=False)[0], + policy_id=policy_id) def get_weights(self, policies=None): """Return a dictionary of policy ids to weights.