[rllib] Allow Torch policies access to full action input dict in extra_action_out_fn (#4894)

* fix torch extra out

* preserve setitem

* fix docs
This commit is contained in:
Eric Liang
2019-06-01 16:58:49 +08:00
committed by GitHub
parent 1c073e92e4
commit 9aa1cd613d
5 changed files with 34 additions and 13 deletions
@@ -53,7 +53,7 @@ def add_advantages(policy,
policy.config["lambda"])
def model_value_predictions(policy, model_out):
def model_value_predictions(policy, input_dict, state_batches, model_out):
return {SampleBatch.VF_PREDS: model_out[2].cpu().numpy()}
+22 -8
View File
@@ -2,9 +2,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import os
import numpy as np
from threading import Lock
try:
@@ -69,15 +69,21 @@ class TorchPolicy(Policy):
**kwargs):
with self.lock:
with torch.no_grad():
ob = torch.from_numpy(np.array(obs_batch)) \
.float().to(self.device)
model_out = self._model({"obs": ob}, state_batches)
input_dict = self._lazy_tensor_dict({
"obs": obs_batch,
})
if prev_action_batch:
input_dict["prev_actions"] = prev_action_batch
if prev_reward_batch:
input_dict["prev_rewards"] = prev_reward_batch
model_out = self._model(input_dict, state_batches)
logits, _, vf, state = model_out
action_dist = self._action_dist_cls(logits)
actions = action_dist.sample()
return (actions.cpu().numpy(),
[h.cpu().numpy() for h in state],
self.extra_action_out(model_out))
self.extra_action_out(input_dict, state_batches,
model_out))
@override(Policy)
def learn_on_batch(self, postprocessed_batch):
@@ -146,10 +152,12 @@ class TorchPolicy(Policy):
return processing info."""
return {}
def extra_action_out(self, model_out):
def extra_action_out(self, input_dict, state_batches, model_out):
"""Returns dict of extra info to include in experience batch.
Arguments:
input_dict (dict): Dict of model input tensors.
state_batches (list): List of state tensors.
model_out (list): Outputs of the policy model module."""
return {}
@@ -168,6 +176,12 @@ class TorchPolicy(Policy):
def _lazy_tensor_dict(self, postprocessed_batch):
batch_tensors = UsageTrackingDict(postprocessed_batch)
batch_tensors.set_get_interceptor(
lambda arr: torch.from_numpy(arr).to(self.device))
def convert(arr):
tensor = torch.from_numpy(np.asarray(arr))
if tensor.dtype == torch.double:
tensor = tensor.float()
return tensor.to(self.device)
batch_tensors.set_get_interceptor(convert)
return batch_tensors
@@ -108,11 +108,13 @@ def build_torch_policy(name,
return TorchPolicy.extra_grad_process(self)
@override(TorchPolicy)
def extra_action_out(self, model_out):
def extra_action_out(self, input_dict, state_batches, model_out):
if extra_action_out_fn:
return extra_action_out_fn(self, model_out)
return extra_action_out_fn(self, input_dict, state_batches,
model_out)
else:
return TorchPolicy.extra_action_out(self, model_out)
return TorchPolicy.extra_action_out(self, input_dict,
state_batches, model_out)
@override(TorchPolicy)
def optimizer(self):
+5
View File
@@ -30,3 +30,8 @@ class UsageTrackingDict(dict):
self.intercepted_values[key] = self.get_interceptor(value)
value = self.intercepted_values[key]
return value
def __setitem__(self, key, value):
dict.__setitem__(self, key, value)
if key in self.intercepted_values:
self.intercepted_values[key] = value