diff --git a/doc/source/rllib-toc.rst b/doc/source/rllib-toc.rst index 20cc9b484..fca6a7623 100644 --- a/doc/source/rllib-toc.rst +++ b/doc/source/rllib-toc.rst @@ -18,6 +18,8 @@ Training APIs - `Custom Training Workflows `__ + - `Computing Actions `__ + - `Accessing Policy State `__ - `Accessing Model State `__ diff --git a/doc/source/rllib-training.rst b/doc/source/rllib-training.rst index 7d7772657..ae80727f5 100644 --- a/doc/source/rllib-training.rst +++ b/doc/source/rllib-training.rst @@ -184,6 +184,78 @@ In the `basic training example `__ directly to implement `fully customized training workflows `__. +Computing Actions +~~~~~~~~~~~~~~~~~ + +The simplest way to programmatically compute actions from a trained agent is to use ``trainer.compute_action()``. +This method preprocesses and filters the observation before passing it to the agent policy. +For more advanced usage, you can access the ``workers`` and policies held by the trainer +directly as ``compute_action()`` does: + +.. code-block:: python + + class Trainer(Trainable): + + @PublicAPI + def compute_action(self, + observation, + state=None, + prev_action=None, + prev_reward=None, + info=None, + policy_id=DEFAULT_POLICY_ID, + full_fetch=False): + """Computes an action for the specified policy. + + Note that you can also access the policy object through + self.get_policy(policy_id) and call compute_actions() on it directly. + + Arguments: + observation (obj): observation from the environment. + state (list): RNN hidden state, if any. If state is not None, + then all of compute_single_action(...) is returned + (computed action, rnn state, logits dictionary). + Otherwise compute_single_action(...)[0] is + returned (computed action). + prev_action (obj): previous action value, if any + prev_reward (int): previous reward, if any + info (dict): info object, if any + policy_id (str): policy to query (only applies to multi-agent). + full_fetch (bool): whether to return extra action fetch results. + This is always set to true if RNN state is specified. + + Returns: + Just the computed action if full_fetch=False, or the full output + of policy.compute_actions() otherwise. + """ + + if state is None: + state = [] + preprocessed = self.workers.local_worker().preprocessors[ + policy_id].transform(observation) + filtered_obs = self.workers.local_worker().filters[policy_id]( + preprocessed, update=False) + if state: + return self.get_policy(policy_id).compute_single_action( + filtered_obs, + state, + prev_action, + prev_reward, + info, + clip_actions=self.config["clip_actions"]) + res = self.get_policy(policy_id).compute_single_action( + filtered_obs, + state, + prev_action, + prev_reward, + info, + clip_actions=self.config["clip_actions"]) + if full_fetch: + return res + else: + return res[0] # backwards compatibility + + Accessing Policy State ~~~~~~~~~~~~~~~~~~~~~~ It is common to need to access a trainer's internal state, e.g., to set or get internal weights. In RLlib trainer state is replicated across multiple *rollout workers* (Ray actors) in the cluster. However, you can easily get and update this state between calls to ``train()`` via ``trainer.workers.foreach_worker()`` or ``trainer.workers.foreach_worker_with_index()``. These functions take a lambda function that is applied with the worker as an arg. You can also return values from these functions and those will be returned as a list.