From 9aa1cd613d2457176c731807b237d8747b539b1b Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sat, 1 Jun 2019 16:58:49 +0800 Subject: [PATCH] [rllib] Allow Torch policies access to full action input dict in extra_action_out_fn (#4894) * fix torch extra out * preserve setitem * fix docs --- doc/source/rllib-concepts.rst | 2 +- .../ray/rllib/agents/a3c/a3c_torch_policy.py | 2 +- python/ray/rllib/policy/torch_policy.py | 30 ++++++++++++++----- .../ray/rllib/policy/torch_policy_template.py | 8 +++-- python/ray/rllib/utils/tracking_dict.py | 5 ++++ 5 files changed, 34 insertions(+), 13 deletions(-) diff --git a/doc/source/rllib-concepts.rst b/doc/source/rllib-concepts.rst index 8556e419a..2f9603b69 100644 --- a/doc/source/rllib-concepts.rst +++ b/doc/source/rllib-concepts.rst @@ -413,7 +413,7 @@ Now, building on the TF examples above, let's look at how the `A3C torch policy .. code-block:: python - def model_value_predictions(policy, model_out): + def model_value_predictions(policy, input_dict, state_batches, model_out): return {SampleBatch.VF_PREDS: model_out[2].cpu().numpy()} ``postprocess_fn`` and ``mixins``: Similar to the PPO example, we need access to the value function during postprocessing (i.e., ``add_advantages`` below calls ``policy._value()``. The value function is exposed through a mixin class that defines the method: diff --git a/python/ray/rllib/agents/a3c/a3c_torch_policy.py b/python/ray/rllib/agents/a3c/a3c_torch_policy.py index 6ccf6c48d..f11ff51be 100644 --- a/python/ray/rllib/agents/a3c/a3c_torch_policy.py +++ b/python/ray/rllib/agents/a3c/a3c_torch_policy.py @@ -53,7 +53,7 @@ def add_advantages(policy, policy.config["lambda"]) -def model_value_predictions(policy, model_out): +def model_value_predictions(policy, input_dict, state_batches, model_out): return {SampleBatch.VF_PREDS: model_out[2].cpu().numpy()} diff --git a/python/ray/rllib/policy/torch_policy.py b/python/ray/rllib/policy/torch_policy.py index 633e438c5..045902621 100644 --- a/python/ray/rllib/policy/torch_policy.py +++ b/python/ray/rllib/policy/torch_policy.py @@ -2,9 +2,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np import os -import numpy as np from threading import Lock try: @@ -69,15 +69,21 @@ class TorchPolicy(Policy): **kwargs): with self.lock: with torch.no_grad(): - ob = torch.from_numpy(np.array(obs_batch)) \ - .float().to(self.device) - model_out = self._model({"obs": ob}, state_batches) + input_dict = self._lazy_tensor_dict({ + "obs": obs_batch, + }) + if prev_action_batch: + input_dict["prev_actions"] = prev_action_batch + if prev_reward_batch: + input_dict["prev_rewards"] = prev_reward_batch + model_out = self._model(input_dict, state_batches) logits, _, vf, state = model_out action_dist = self._action_dist_cls(logits) actions = action_dist.sample() return (actions.cpu().numpy(), [h.cpu().numpy() for h in state], - self.extra_action_out(model_out)) + self.extra_action_out(input_dict, state_batches, + model_out)) @override(Policy) def learn_on_batch(self, postprocessed_batch): @@ -146,10 +152,12 @@ class TorchPolicy(Policy): return processing info.""" return {} - def extra_action_out(self, model_out): + def extra_action_out(self, input_dict, state_batches, model_out): """Returns dict of extra info to include in experience batch. Arguments: + input_dict (dict): Dict of model input tensors. + state_batches (list): List of state tensors. model_out (list): Outputs of the policy model module.""" return {} @@ -168,6 +176,12 @@ class TorchPolicy(Policy): def _lazy_tensor_dict(self, postprocessed_batch): batch_tensors = UsageTrackingDict(postprocessed_batch) - batch_tensors.set_get_interceptor( - lambda arr: torch.from_numpy(arr).to(self.device)) + + def convert(arr): + tensor = torch.from_numpy(np.asarray(arr)) + if tensor.dtype == torch.double: + tensor = tensor.float() + return tensor.to(self.device) + + batch_tensors.set_get_interceptor(convert) return batch_tensors diff --git a/python/ray/rllib/policy/torch_policy_template.py b/python/ray/rllib/policy/torch_policy_template.py index 049591c04..19e943600 100644 --- a/python/ray/rllib/policy/torch_policy_template.py +++ b/python/ray/rllib/policy/torch_policy_template.py @@ -108,11 +108,13 @@ def build_torch_policy(name, return TorchPolicy.extra_grad_process(self) @override(TorchPolicy) - def extra_action_out(self, model_out): + def extra_action_out(self, input_dict, state_batches, model_out): if extra_action_out_fn: - return extra_action_out_fn(self, model_out) + return extra_action_out_fn(self, input_dict, state_batches, + model_out) else: - return TorchPolicy.extra_action_out(self, model_out) + return TorchPolicy.extra_action_out(self, input_dict, + state_batches, model_out) @override(TorchPolicy) def optimizer(self): diff --git a/python/ray/rllib/utils/tracking_dict.py b/python/ray/rllib/utils/tracking_dict.py index c0f145734..9b64925dc 100644 --- a/python/ray/rllib/utils/tracking_dict.py +++ b/python/ray/rllib/utils/tracking_dict.py @@ -30,3 +30,8 @@ class UsageTrackingDict(dict): self.intercepted_values[key] = self.get_interceptor(value) value = self.intercepted_values[key] return value + + def __setitem__(self, key, value): + dict.__setitem__(self, key, value) + if key in self.intercepted_values: + self.intercepted_values[key] = value