[RLlib] Policy.compute_single_action() broken for nested actions (Issue 8411). (#8514)

This commit is contained in:
Sven Mika
2020-05-20 22:29:08 +02:00
committed by GitHub
parent ebf060d484
commit d76578700d
6 changed files with 54 additions and 50 deletions
+2 -2
View File
@@ -7,11 +7,11 @@ import numpy as np
import ray
import ray.experimental.tf_utils
from ray.rllib.agents.es.es_tf_policy import make_session
from ray.rllib.evaluation.sampler import unbatch_actions
from ray.rllib.models import ModelCatalog
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.filter import get_filter
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.space_utils import unbatch
tf = try_import_tf()
@@ -59,7 +59,7 @@ class ARSTFPolicy:
observation = self.observation_filter(observation[None], update=update)
action = self.sess.run(
self.sampler, feed_dict={self.inputs: observation})
action = unbatch_actions(action)
action = unbatch(action)
if add_noise and isinstance(self.action_space, gym.spaces.Box):
action += np.random.randn(*action.shape) * self.action_noise_std
return action
+2 -3
View File
@@ -6,13 +6,12 @@ import numpy as np
import ray
import ray.experimental.tf_utils
from ray.rllib.evaluation.sampler import unbatch_actions
from ray.rllib.models import ModelCatalog
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import try_import_tree
from ray.rllib.utils.filter import get_filter
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.space_utils import get_base_struct_from_space
from ray.rllib.utils.space_utils import get_base_struct_from_space, unbatch
tf = try_import_tf()
tree = try_import_tree()
@@ -111,7 +110,7 @@ class ESTFPolicy:
self.action_space_struct)
# Convert `flat_actions` to a list of lists of action components
# (list of single actions).
actions = unbatch_actions(actions)
actions = unbatch(actions)
return actions
def _add_noise(self, single_action, single_action_space):
+2 -2
View File
@@ -5,12 +5,12 @@ import gym
import numpy as np
import ray
from ray.rllib.evaluation.sampler import unbatch_actions
from ray.rllib.models import ModelCatalog
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_template import build_torch_policy
from ray.rllib.utils.filter import get_filter
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.space_utils import unbatch
from ray.rllib.utils.torch_ops import convert_to_torch_tensor
torch, _ = try_import_torch()
@@ -61,7 +61,7 @@ def before_init(policy, observation_space, action_space, config):
}, [], None)
dist = policy.dist_class(dist_inputs, policy.model)
action = dist.sample().detach().numpy()
action = unbatch_actions(action)
action = unbatch(action)
if add_noise and isinstance(policy.action_space, gym.spaces.Box):
action += np.random.randn(*action.shape) * policy.action_noise_std
return action
+2 -39
View File
@@ -18,8 +18,8 @@ from ray.rllib.offline import InputReader
from ray.rllib.utils import try_import_tree
from ray.rllib.utils.annotations import override
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.space_utils import flatten_to_single_ndarray, unbatch
from ray.rllib.utils.tf_run_builder import TFRunBuilder
from ray.rllib.utils.space_utils import flatten_to_single_ndarray
tree = try_import_tree()
@@ -688,7 +688,7 @@ def _process_policy_eval_results(to_eval, eval_results, active_episodes,
if clip_actions:
actions = clip_action(actions, policy.action_space_struct)
# Split action-component batches into single action rows.
actions = unbatch_actions(actions)
actions = unbatch(actions)
for i, action in enumerate(actions):
env_id = eval_data[i].env_id
agent_id = eval_data[i].agent_id
@@ -726,43 +726,6 @@ def _fetch_atari_metrics(base_env):
return atari_out
def unbatch_actions(action_batches):
"""Converts action_batches from list of batches to batch of lists.
Input: Struct of batches:
{"a": [1, 2, 3], "b": ([4, 5, 6], [7.0, 8.0, 9.0])}
Output: Batch (list) of structs (each of these structs representing a
single action):
[
{"a": 1, "b": (4, 7.0)}, <- action 1
{"a": 2, "b": (5, 8.0)}, <- action 2
{"a": 3, "b": (6, 9.0)}, <- action 3
]
Args:
action_batches (any): The list of action-component batches. Each item
in this list represents the batch for a single action component
(in case action is Tuple/Dict), meaning the list is already
flattened.
Alternatively, `action_batches` may also simply be a batch of
primitive actions (non Tuple/Dict).
Returns:
List[List[action-components]]: The list of action rows. Each item
in the returned list represents a single (maybe complex) action.
"""
flat_action_batches = tree.flatten(action_batches)
out = []
for batch_pos in range(len(flat_action_batches[0])):
out.append(
tree.unflatten_as(action_batches, [
flat_action_batches[i][batch_pos]
for i in range(len(flat_action_batches))
]))
return out
def _to_column_format(rnn_state_rows):
num_cols = len(rnn_state_rows[0])
return [[row[i] for row in rnn_state_rows] for i in range(num_cols)]
+10 -4
View File
@@ -7,7 +7,8 @@ from ray.rllib.utils import try_import_tree
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.exploration.exploration import Exploration
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.space_utils import get_base_struct_from_space
from ray.rllib.utils.space_utils import get_base_struct_from_space, \
unbatch
tree = try_import_tree()
@@ -158,7 +159,7 @@ class Policy(metaclass=ABCMeta):
if state is not None:
state_batch = [[s] for s in state]
[action], state_out, info = self.compute_actions(
batched_action, state_out, info = self.compute_actions(
[obs],
state_batch,
prev_action_batch=prev_action_batch,
@@ -168,11 +169,16 @@ class Policy(metaclass=ABCMeta):
explore=explore,
timestep=timestep)
single_action = unbatch(batched_action)
assert len(single_action) == 1
single_action = single_action[0]
if clip_actions:
action = clip_action(action, self.action_space_struct)
single_action = clip_action(single_action,
self.action_space_struct)
# Return action, internal state(s), infos.
return action, [s[0] for s in state_out], \
return single_action, [s[0] for s in state_out], \
{k: v[0] for k, v in info.items()}
@DeveloperAPI
+36
View File
@@ -92,3 +92,39 @@ def flatten_to_single_ndarray(input_):
expanded.append(np.reshape(in_, [-1]))
input_ = np.concatenate(expanded, axis=0).flatten()
return input_
def unbatch(batches_struct):
"""Converts input from (nested) struct of batches to batch of structs.
Input: Struct of different batches (each batch has size=3):
{"a": [1, 2, 3], "b": ([4, 5, 6], [7.0, 8.0, 9.0])}
Output: Batch (list) of structs (each of these structs representing a
single action):
[
{"a": 1, "b": (4, 7.0)}, <- action 1
{"a": 2, "b": (5, 8.0)}, <- action 2
{"a": 3, "b": (6, 9.0)}, <- action 3
]
Args:
batches_struct (any): The struct of component batches. Each leaf item
in this struct represents the batch for a single component
(in case struct is tuple/dict).
Alternatively, `batches_struct` may also simply be a batch of
primitives (non tuple/dict).
Returns:
List[struct[components]]: The list of rows. Each item
in the returned list represents a single (maybe complex) struct.
"""
flat_batches = tree.flatten(batches_struct)
out = []
for batch_pos in range(len(flat_batches[0])):
out.append(
tree.unflatten_as(
batches_struct,
[flat_batches[i][batch_pos]
for i in range(len(flat_batches))]))
return out