diff --git a/python/ray/rllib/agents/agent.py b/python/ray/rllib/agents/agent.py index d83e325f2..2e330f9e1 100644 --- a/python/ray/rllib/agents/agent.py +++ b/python/ray/rllib/agents/agent.py @@ -216,7 +216,11 @@ class Agent(Trainable): Arguments: observation (obj): observation from the environment. - state (list): RNN hidden state, if any. + state (list): RNN hidden state, if any. If state is not None, + then all of compute_single_action(...) is returned + (computed action, rnn state, logits dictionary). + Otherwise compute_single_action(...)[0] is + returned (computed action). policy_id (str): policy to query (only applies to multi-agent). """ @@ -224,10 +228,15 @@ class Agent(Trainable): state = [] filtered_obs = self.local_evaluator.filters[policy_id]( observation, update=False) + if state: + return self.local_evaluator.for_policy( + lambda p: p.compute_single_action( + filtered_obs, state, is_training=False), + policy_id=policy_id) return self.local_evaluator.for_policy( - lambda p: p.compute_single_action( - filtered_obs, state, is_training=False)[0], - policy_id=policy_id) + lambda p: p.compute_single_action( + filtered_obs, state, is_training=False)[0], + policy_id=policy_id) def get_weights(self, policies=None): """Return a dictionary of policy ids to weights.