mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 13:12:46 +08:00
[rllib] Allow Torch policies access to full action input dict in extra_action_out_fn (#4894)
* fix torch extra out * preserve setitem * fix docs
This commit is contained in:
@@ -53,7 +53,7 @@ def add_advantages(policy,
|
||||
policy.config["lambda"])
|
||||
|
||||
|
||||
def model_value_predictions(policy, model_out):
|
||||
def model_value_predictions(policy, input_dict, state_batches, model_out):
|
||||
return {SampleBatch.VF_PREDS: model_out[2].cpu().numpy()}
|
||||
|
||||
|
||||
|
||||
@@ -2,9 +2,9 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from threading import Lock
|
||||
|
||||
try:
|
||||
@@ -69,15 +69,21 @@ class TorchPolicy(Policy):
|
||||
**kwargs):
|
||||
with self.lock:
|
||||
with torch.no_grad():
|
||||
ob = torch.from_numpy(np.array(obs_batch)) \
|
||||
.float().to(self.device)
|
||||
model_out = self._model({"obs": ob}, state_batches)
|
||||
input_dict = self._lazy_tensor_dict({
|
||||
"obs": obs_batch,
|
||||
})
|
||||
if prev_action_batch:
|
||||
input_dict["prev_actions"] = prev_action_batch
|
||||
if prev_reward_batch:
|
||||
input_dict["prev_rewards"] = prev_reward_batch
|
||||
model_out = self._model(input_dict, state_batches)
|
||||
logits, _, vf, state = model_out
|
||||
action_dist = self._action_dist_cls(logits)
|
||||
actions = action_dist.sample()
|
||||
return (actions.cpu().numpy(),
|
||||
[h.cpu().numpy() for h in state],
|
||||
self.extra_action_out(model_out))
|
||||
self.extra_action_out(input_dict, state_batches,
|
||||
model_out))
|
||||
|
||||
@override(Policy)
|
||||
def learn_on_batch(self, postprocessed_batch):
|
||||
@@ -146,10 +152,12 @@ class TorchPolicy(Policy):
|
||||
return processing info."""
|
||||
return {}
|
||||
|
||||
def extra_action_out(self, model_out):
|
||||
def extra_action_out(self, input_dict, state_batches, model_out):
|
||||
"""Returns dict of extra info to include in experience batch.
|
||||
|
||||
Arguments:
|
||||
input_dict (dict): Dict of model input tensors.
|
||||
state_batches (list): List of state tensors.
|
||||
model_out (list): Outputs of the policy model module."""
|
||||
return {}
|
||||
|
||||
@@ -168,6 +176,12 @@ class TorchPolicy(Policy):
|
||||
|
||||
def _lazy_tensor_dict(self, postprocessed_batch):
|
||||
batch_tensors = UsageTrackingDict(postprocessed_batch)
|
||||
batch_tensors.set_get_interceptor(
|
||||
lambda arr: torch.from_numpy(arr).to(self.device))
|
||||
|
||||
def convert(arr):
|
||||
tensor = torch.from_numpy(np.asarray(arr))
|
||||
if tensor.dtype == torch.double:
|
||||
tensor = tensor.float()
|
||||
return tensor.to(self.device)
|
||||
|
||||
batch_tensors.set_get_interceptor(convert)
|
||||
return batch_tensors
|
||||
|
||||
@@ -108,11 +108,13 @@ def build_torch_policy(name,
|
||||
return TorchPolicy.extra_grad_process(self)
|
||||
|
||||
@override(TorchPolicy)
|
||||
def extra_action_out(self, model_out):
|
||||
def extra_action_out(self, input_dict, state_batches, model_out):
|
||||
if extra_action_out_fn:
|
||||
return extra_action_out_fn(self, model_out)
|
||||
return extra_action_out_fn(self, input_dict, state_batches,
|
||||
model_out)
|
||||
else:
|
||||
return TorchPolicy.extra_action_out(self, model_out)
|
||||
return TorchPolicy.extra_action_out(self, input_dict,
|
||||
state_batches, model_out)
|
||||
|
||||
@override(TorchPolicy)
|
||||
def optimizer(self):
|
||||
|
||||
@@ -30,3 +30,8 @@ class UsageTrackingDict(dict):
|
||||
self.intercepted_values[key] = self.get_interceptor(value)
|
||||
value = self.intercepted_values[key]
|
||||
return value
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
dict.__setitem__(self, key, value)
|
||||
if key in self.intercepted_values:
|
||||
self.intercepted_values[key] = value
|
||||
|
||||
Reference in New Issue
Block a user