mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 18:11:38 +08:00
updated agent.compute_action to return rnn state (#2581)
* Updated agent.compute_action to return the RNN state. * Updated the compute_action method and added a case for state=None. * Fixed lint.
This commit is contained in:
@@ -216,7 +216,11 @@ class Agent(Trainable):
|
||||
|
||||
Arguments:
|
||||
observation (obj): observation from the environment.
|
||||
state (list): RNN hidden state, if any.
|
||||
state (list): RNN hidden state, if any. If state is non-empty,
|
||||
then all of compute_single_action(...) is returned
|
||||
(computed action, rnn state, logits dictionary).
|
||||
Otherwise compute_single_action(...)[0] is
|
||||
returned (computed action).
|
||||
policy_id (str): policy to query (only applies to multi-agent).
|
||||
"""
|
||||
|
||||
@@ -224,10 +228,15 @@ class Agent(Trainable):
|
||||
state = []
|
||||
filtered_obs = self.local_evaluator.filters[policy_id](
|
||||
observation, update=False)
|
||||
if state:
|
||||
return self.local_evaluator.for_policy(
|
||||
lambda p: p.compute_single_action(
|
||||
filtered_obs, state, is_training=False),
|
||||
policy_id=policy_id)
|
||||
return self.local_evaluator.for_policy(
|
||||
lambda p: p.compute_single_action(
|
||||
filtered_obs, state, is_training=False)[0],
|
||||
policy_id=policy_id)
|
||||
lambda p: p.compute_single_action(
|
||||
filtered_obs, state, is_training=False)[0],
|
||||
policy_id=policy_id)
|
||||
|
||||
def get_weights(self, policies=None):
|
||||
"""Return a dictionary of policy ids to weights.
|
||||
|
||||
Reference in New Issue
Block a user