mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:53:14 +08:00
[RLlib] Policy.compute_single_action() broken for nested actions (Issue 8411). (#8514)
This commit is contained in:
@@ -7,11 +7,11 @@ import numpy as np
|
||||
import ray
|
||||
import ray.experimental.tf_utils
|
||||
from ray.rllib.agents.es.es_tf_policy import make_session
|
||||
from ray.rllib.evaluation.sampler import unbatch_actions
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.filter import get_filter
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.space_utils import unbatch
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
@@ -59,7 +59,7 @@ class ARSTFPolicy:
|
||||
observation = self.observation_filter(observation[None], update=update)
|
||||
action = self.sess.run(
|
||||
self.sampler, feed_dict={self.inputs: observation})
|
||||
action = unbatch_actions(action)
|
||||
action = unbatch(action)
|
||||
if add_noise and isinstance(self.action_space, gym.spaces.Box):
|
||||
action += np.random.randn(*action.shape) * self.action_noise_std
|
||||
return action
|
||||
|
||||
@@ -6,13 +6,12 @@ import numpy as np
|
||||
|
||||
import ray
|
||||
import ray.experimental.tf_utils
|
||||
from ray.rllib.evaluation.sampler import unbatch_actions
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils import try_import_tree
|
||||
from ray.rllib.utils.filter import get_filter
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.space_utils import get_base_struct_from_space
|
||||
from ray.rllib.utils.space_utils import get_base_struct_from_space, unbatch
|
||||
|
||||
tf = try_import_tf()
|
||||
tree = try_import_tree()
|
||||
@@ -111,7 +110,7 @@ class ESTFPolicy:
|
||||
self.action_space_struct)
|
||||
# Convert `flat_actions` to a list of lists of action components
|
||||
# (list of single actions).
|
||||
actions = unbatch_actions(actions)
|
||||
actions = unbatch(actions)
|
||||
return actions
|
||||
|
||||
def _add_noise(self, single_action, single_action_space):
|
||||
|
||||
@@ -5,12 +5,12 @@ import gym
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
from ray.rllib.evaluation.sampler import unbatch_actions
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.torch_policy_template import build_torch_policy
|
||||
from ray.rllib.utils.filter import get_filter
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.space_utils import unbatch
|
||||
from ray.rllib.utils.torch_ops import convert_to_torch_tensor
|
||||
|
||||
torch, _ = try_import_torch()
|
||||
@@ -61,7 +61,7 @@ def before_init(policy, observation_space, action_space, config):
|
||||
}, [], None)
|
||||
dist = policy.dist_class(dist_inputs, policy.model)
|
||||
action = dist.sample().detach().numpy()
|
||||
action = unbatch_actions(action)
|
||||
action = unbatch(action)
|
||||
if add_noise and isinstance(policy.action_space, gym.spaces.Box):
|
||||
action += np.random.randn(*action.shape) * policy.action_noise_std
|
||||
return action
|
||||
|
||||
@@ -18,8 +18,8 @@ from ray.rllib.offline import InputReader
|
||||
from ray.rllib.utils import try_import_tree
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.debug import summarize
|
||||
from ray.rllib.utils.space_utils import flatten_to_single_ndarray, unbatch
|
||||
from ray.rllib.utils.tf_run_builder import TFRunBuilder
|
||||
from ray.rllib.utils.space_utils import flatten_to_single_ndarray
|
||||
|
||||
tree = try_import_tree()
|
||||
|
||||
@@ -688,7 +688,7 @@ def _process_policy_eval_results(to_eval, eval_results, active_episodes,
|
||||
if clip_actions:
|
||||
actions = clip_action(actions, policy.action_space_struct)
|
||||
# Split action-component batches into single action rows.
|
||||
actions = unbatch_actions(actions)
|
||||
actions = unbatch(actions)
|
||||
for i, action in enumerate(actions):
|
||||
env_id = eval_data[i].env_id
|
||||
agent_id = eval_data[i].agent_id
|
||||
@@ -726,43 +726,6 @@ def _fetch_atari_metrics(base_env):
|
||||
return atari_out
|
||||
|
||||
|
||||
def unbatch_actions(action_batches):
    """Turns a struct of action-component batches into a list of actions.

    Example (batch size 3):
        Input, a struct of batches:
            {"a": [1, 2, 3], "b": ([4, 5, 6], [7.0, 8.0, 9.0])}
        Output, a list of structs (one single action each):
            [
                {"a": 1, "b": (4, 7.0)},  <- action 1
                {"a": 2, "b": (5, 8.0)},  <- action 2
                {"a": 3, "b": (6, 9.0)},  <- action 3
            ]

    Args:
        action_batches (any): The (possibly nested Tuple/Dict) struct whose
            leaves are the batches of the individual action components.
            May also simply be one batch of primitive (non Tuple/Dict)
            actions.

    Returns:
        List[struct]: One item per batch position; each item has the same
            structure as `action_batches` and represents a single (maybe
            complex) action.
    """
    component_batches = tree.flatten(action_batches)
    # The first component's batch determines how many rows are produced.
    num_rows = len(component_batches[0])
    return [
        tree.unflatten_as(
            action_batches,
            [component[row] for component in component_batches])
        for row in range(num_rows)
    ]
|
||||
|
||||
|
||||
def _to_column_format(rnn_state_rows):
|
||||
num_cols = len(rnn_state_rows[0])
|
||||
return [[row[i] for row in rnn_state_rows] for i in range(num_cols)]
|
||||
|
||||
+10
-4
@@ -7,7 +7,8 @@ from ray.rllib.utils import try_import_tree
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.utils.exploration.exploration import Exploration
|
||||
from ray.rllib.utils.from_config import from_config
|
||||
from ray.rllib.utils.space_utils import get_base_struct_from_space
|
||||
from ray.rllib.utils.space_utils import get_base_struct_from_space, \
|
||||
unbatch
|
||||
|
||||
tree = try_import_tree()
|
||||
|
||||
@@ -158,7 +159,7 @@ class Policy(metaclass=ABCMeta):
|
||||
if state is not None:
|
||||
state_batch = [[s] for s in state]
|
||||
|
||||
[action], state_out, info = self.compute_actions(
|
||||
batched_action, state_out, info = self.compute_actions(
|
||||
[obs],
|
||||
state_batch,
|
||||
prev_action_batch=prev_action_batch,
|
||||
@@ -168,11 +169,16 @@ class Policy(metaclass=ABCMeta):
|
||||
explore=explore,
|
||||
timestep=timestep)
|
||||
|
||||
single_action = unbatch(batched_action)
|
||||
assert len(single_action) == 1
|
||||
single_action = single_action[0]
|
||||
|
||||
if clip_actions:
|
||||
action = clip_action(action, self.action_space_struct)
|
||||
single_action = clip_action(single_action,
|
||||
self.action_space_struct)
|
||||
|
||||
# Return action, internal state(s), infos.
|
||||
return action, [s[0] for s in state_out], \
|
||||
return single_action, [s[0] for s in state_out], \
|
||||
{k: v[0] for k, v in info.items()}
|
||||
|
||||
@DeveloperAPI
|
||||
|
||||
@@ -92,3 +92,39 @@ def flatten_to_single_ndarray(input_):
|
||||
expanded.append(np.reshape(in_, [-1]))
|
||||
input_ = np.concatenate(expanded, axis=0).flatten()
|
||||
return input_
|
||||
|
||||
|
||||
def unbatch(batches_struct):
    """Converts a (nested) struct of batches into a batch of structs.

    Example (each batch has size=3):
        Input, a struct of different batches:
            {"a": [1, 2, 3], "b": ([4, 5, 6], [7.0, 8.0, 9.0])}
        Output, a batch (list) of structs, each one a single item:
            [
                {"a": 1, "b": (4, 7.0)},  <- item 1
                {"a": 2, "b": (5, 8.0)},  <- item 2
                {"a": 3, "b": (6, 9.0)},  <- item 3
            ]

    Args:
        batches_struct (any): The struct of component batches. Each leaf
            of this struct is the batch for one component (in case the
            struct is a tuple/dict). May also simply be a single batch of
            primitives (non tuple/dict).

    Returns:
        List[struct[components]]: The list of rows. Each row has the same
            structure as `batches_struct` and represents a single (maybe
            complex) struct.
    """
    leaf_batches = tree.flatten(batches_struct)
    # The batch size is read off the first leaf.
    batch_size = len(leaf_batches[0])

    rows = []
    for pos in range(batch_size):
        # Pick the `pos`-th entry of every leaf, then restore the nesting.
        row_leaves = [leaf[pos] for leaf in leaf_batches]
        rows.append(tree.unflatten_as(batches_struct, row_leaves))
    return rows
|
||||
|
||||
Reference in New Issue
Block a user