updated agent.compute_action to return rnn state (#2581)

* updated agent.compute_action to return rnn state

* updated compute_action method, added case for state=None

* fixing lint
This commit is contained in:
efang96
2018-08-13 18:04:42 -07:00
committed by Eric Liang
parent 8769b8ac32
commit baba624373
+13 -4
View File
@@ -216,7 +216,11 @@ class Agent(Trainable):
Arguments:
observation (obj): observation from the environment.
state (list): RNN hidden state, if any.
state (list): RNN hidden state, if any. If state is not None,
then all of compute_single_action(...) is returned
(computed action, rnn state, logits dictionary).
Otherwise compute_single_action(...)[0] is
returned (computed action).
policy_id (str): policy to query (only applies to multi-agent).
"""
@@ -224,10 +228,15 @@ class Agent(Trainable):
state = []
filtered_obs = self.local_evaluator.filters[policy_id](
observation, update=False)
if state:
return self.local_evaluator.for_policy(
lambda p: p.compute_single_action(
filtered_obs, state, is_training=False),
policy_id=policy_id)
return self.local_evaluator.for_policy(
lambda p: p.compute_single_action(
filtered_obs, state, is_training=False)[0],
policy_id=policy_id)
lambda p: p.compute_single_action(
filtered_obs, state, is_training=False)[0],
policy_id=policy_id)
def get_weights(self, policies=None):
"""Return a dictionary of policy ids to weights.