diff --git a/rllib/agents/ars/ars_tf_policy.py b/rllib/agents/ars/ars_tf_policy.py
index b3f5dfa1f..b47cef75d 100644
--- a/rllib/agents/ars/ars_tf_policy.py
+++ b/rllib/agents/ars/ars_tf_policy.py
@@ -7,11 +7,11 @@ import numpy as np
 import ray
 import ray.experimental.tf_utils
 from ray.rllib.agents.es.es_tf_policy import make_session
-from ray.rllib.evaluation.sampler import unbatch_actions
 from ray.rllib.models import ModelCatalog
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.filter import get_filter
 from ray.rllib.utils.framework import try_import_tf
+from ray.rllib.utils.space_utils import unbatch
 
 tf = try_import_tf()
 
@@ -59,7 +59,7 @@ class ARSTFPolicy:
         observation = self.observation_filter(observation[None], update=update)
         action = self.sess.run(
             self.sampler, feed_dict={self.inputs: observation})
-        action = unbatch_actions(action)
+        action = unbatch(action)
         if add_noise and isinstance(self.action_space, gym.spaces.Box):
             action += np.random.randn(*action.shape) * self.action_noise_std
         return action
diff --git a/rllib/agents/es/es_tf_policy.py b/rllib/agents/es/es_tf_policy.py
index 483e32fff..da2c274e5 100644
--- a/rllib/agents/es/es_tf_policy.py
+++ b/rllib/agents/es/es_tf_policy.py
@@ -6,13 +6,12 @@ import numpy as np
 import ray
 import ray.experimental.tf_utils
 
-from ray.rllib.evaluation.sampler import unbatch_actions
 from ray.rllib.models import ModelCatalog
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils import try_import_tree
 from ray.rllib.utils.filter import get_filter
 from ray.rllib.utils.framework import try_import_tf
-from ray.rllib.utils.space_utils import get_base_struct_from_space
+from ray.rllib.utils.space_utils import get_base_struct_from_space, unbatch
 
 tf = try_import_tf()
 tree = try_import_tree()
@@ -111,7 +110,7 @@ class ESTFPolicy:
                                          self.action_space_struct)
         # Convert `flat_actions` to a list of lists of action components
         # (list of single actions).
-        actions = unbatch_actions(actions)
+        actions = unbatch(actions)
         return actions
 
     def _add_noise(self, single_action, single_action_space):
diff --git a/rllib/agents/es/es_torch_policy.py b/rllib/agents/es/es_torch_policy.py
index 34bcd10bc..95d844584 100644
--- a/rllib/agents/es/es_torch_policy.py
+++ b/rllib/agents/es/es_torch_policy.py
@@ -5,12 +5,12 @@ import gym
 import numpy as np
 
 import ray
-from ray.rllib.evaluation.sampler import unbatch_actions
 from ray.rllib.models import ModelCatalog
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.utils.filter import get_filter
 from ray.rllib.utils.framework import try_import_torch
+from ray.rllib.utils.space_utils import unbatch
 from ray.rllib.utils.torch_ops import convert_to_torch_tensor
 
 torch, _ = try_import_torch()
@@ -61,7 +61,7 @@ def before_init(policy, observation_space, action_space, config):
         }, [], None)
         dist = policy.dist_class(dist_inputs, policy.model)
         action = dist.sample().detach().numpy()
-        action = unbatch_actions(action)
+        action = unbatch(action)
         if add_noise and isinstance(policy.action_space, gym.spaces.Box):
             action += np.random.randn(*action.shape) * policy.action_noise_std
         return action
diff --git a/rllib/evaluation/sampler.py b/rllib/evaluation/sampler.py
index d19395cd9..c8e50b1f3 100644
--- a/rllib/evaluation/sampler.py
+++ b/rllib/evaluation/sampler.py
@@ -18,8 +18,8 @@ from ray.rllib.offline import InputReader
 from ray.rllib.utils import try_import_tree
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.debug import summarize
+from ray.rllib.utils.space_utils import flatten_to_single_ndarray, unbatch
 from ray.rllib.utils.tf_run_builder import TFRunBuilder
-from ray.rllib.utils.space_utils import flatten_to_single_ndarray
 
 tree = try_import_tree()
 
@@ -688,7 +688,7 @@ def _process_policy_eval_results(to_eval, eval_results, active_episodes,
         if clip_actions:
             actions = clip_action(actions, policy.action_space_struct)
         # Split action-component batches into single action rows.
-        actions = unbatch_actions(actions)
+        actions = unbatch(actions)
         for i, action in enumerate(actions):
             env_id = eval_data[i].env_id
             agent_id = eval_data[i].agent_id
@@ -726,43 +726,6 @@ def _fetch_atari_metrics(base_env):
     return atari_out
 
 
-def unbatch_actions(action_batches):
-    """Converts action_batches from list of batches to batch of lists.
-
-    Input: Struct of batches:
-        {"a": [1, 2, 3], "b": ([4, 5, 6], [7.0, 8.0, 9.0])}
-    Output: Batch (list) of structs (each of these structs representing a
-        single action):
-        [
-            {"a": 1, "b": (4, 7.0)},  <- action 1
-            {"a": 2, "b": (5, 8.0)},  <- action 2
-            {"a": 3, "b": (6, 9.0)},  <- action 3
-        ]
-
-    Args:
-        action_batches (any): The list of action-component batches. Each item
-            in this list represents the batch for a single action component
-            (in case action is Tuple/Dict), meaning the list is already
-            flattened.
-            Alternatively, `action_batches` may also simply be a batch of
-            primitive actions (non Tuple/Dict).
-
-    Returns:
-        List[List[action-components]]: The list of action rows. Each item
-            in the returned list represents a single (maybe complex) action.
-    """
-    flat_action_batches = tree.flatten(action_batches)
-
-    out = []
-    for batch_pos in range(len(flat_action_batches[0])):
-        out.append(
-            tree.unflatten_as(action_batches, [
-                flat_action_batches[i][batch_pos]
-                for i in range(len(flat_action_batches))
-            ]))
-    return out
-
-
 def _to_column_format(rnn_state_rows):
     num_cols = len(rnn_state_rows[0])
     return [[row[i] for row in rnn_state_rows] for i in range(num_cols)]
diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py
index dc037d109..786fcbbb0 100644
--- a/rllib/policy/policy.py
+++ b/rllib/policy/policy.py
@@ -7,7 +7,8 @@ from ray.rllib.utils import try_import_tree
 from ray.rllib.utils.annotations import DeveloperAPI
 from ray.rllib.utils.exploration.exploration import Exploration
 from ray.rllib.utils.from_config import from_config
-from ray.rllib.utils.space_utils import get_base_struct_from_space
+from ray.rllib.utils.space_utils import get_base_struct_from_space, \
+    unbatch
 
 tree = try_import_tree()
 
@@ -158,7 +159,7 @@ class Policy(metaclass=ABCMeta):
         if state is not None:
             state_batch = [[s] for s in state]
 
-        [action], state_out, info = self.compute_actions(
+        batched_action, state_out, info = self.compute_actions(
             [obs],
             state_batch,
             prev_action_batch=prev_action_batch,
@@ -168,11 +169,16 @@ class Policy(metaclass=ABCMeta):
             explore=explore,
             timestep=timestep)
 
+        single_action = unbatch(batched_action)
+        assert len(single_action) == 1
+        single_action = single_action[0]
+
         if clip_actions:
-            action = clip_action(action, self.action_space_struct)
+            single_action = clip_action(single_action,
+                                        self.action_space_struct)
 
         # Return action, internal state(s), infos.
-        return action, [s[0] for s in state_out], \
+        return single_action, [s[0] for s in state_out], \
             {k: v[0] for k, v in info.items()}
 
     @DeveloperAPI
diff --git a/rllib/utils/space_utils.py b/rllib/utils/space_utils.py
index 9081d7398..8111723e3 100644
--- a/rllib/utils/space_utils.py
+++ b/rllib/utils/space_utils.py
@@ -92,3 +92,43 @@ def flatten_to_single_ndarray(input_):
             expanded.append(np.reshape(in_, [-1]))
         input_ = np.concatenate(expanded, axis=0).flatten()
     return input_
+
+
+def unbatch(batches_struct):
+    """Converts input from (nested) struct of batches to batch of structs.
+
+    Input: Struct of different batches (each batch has size=3):
+        {"a": [1, 2, 3], "b": ([4, 5, 6], [7.0, 8.0, 9.0])}
+    Output: Batch (list) of structs (each of these structs representing a
+        single row of the batch):
+        [
+            {"a": 1, "b": (4, 7.0)},  <- row 1
+            {"a": 2, "b": (5, 8.0)},  <- row 2
+            {"a": 3, "b": (6, 9.0)},  <- row 3
+        ]
+
+    Args:
+        batches_struct (any): The struct of component batches. Each leaf item
+            in this struct represents the batch for a single component
+            (in case struct is tuple/dict).
+            Alternatively, `batches_struct` may also simply be a batch of
+            primitives (non tuple/dict).
+
+    Returns:
+        List[struct[components]]: The list of rows. Each item
+            in the returned list represents a single (maybe complex) struct.
+            Empty if `batches_struct` contains no leaf batches at all.
+    """
+    flat_batches = tree.flatten(batches_struct)
+    # A struct without any leaves (e.g. `()` or `{}`) has no first batch to
+    # take the batch size from; it unbatches to zero rows.
+    if not flat_batches:
+        return []
+
+    # The first leaf's length is used as the batch size for all leaves.
+    batch_size = len(flat_batches[0])
+    return [
+        tree.unflatten_as(batches_struct,
+                          [batch[row] for batch in flat_batches])
+        for row in range(batch_size)
+    ]