mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 19:16:19 +08:00
[RLlib] Nested action space PR (minimally invasive; torch only + test). (#8101)
- Add TorchMultiActionDistribution class. - Add framework-agnostic test cases for TorchMultiActionDistribution.
This commit is contained in:
+2
-2
@@ -186,7 +186,7 @@ matrix:
|
||||
- . ./ci/travis/ci.sh build
|
||||
script:
|
||||
- . ./ci/travis/ci.sh preload
|
||||
- travis_wait 90 bazel test --config=ci --test_output=streamed --build_tests_only --test_tag_filters=learning_tests rllib/...
|
||||
- travis_wait 120 bazel test --config=ci --test_output=streamed --build_tests_only --test_tag_filters=learning_tests rllib/...
|
||||
|
||||
# RLlib: Learning tests with tf=1.x (from rllib/tuned_examples/regression_tests/*.yaml).
|
||||
# Requested by Edi (MS): Test all learning capabilities with tf1.x
|
||||
@@ -205,7 +205,7 @@ matrix:
|
||||
- . ./ci/travis/ci.sh build
|
||||
script:
|
||||
- . ./ci/travis/ci.sh preload
|
||||
- travis_wait 90 bazel test --config=ci --test_output=streamed --build_tests_only --test_tag_filters=learning_tests rllib/...
|
||||
- travis_wait 120 bazel test --config=ci --test_output=streamed --build_tests_only --test_tag_filters=learning_tests rllib/...
|
||||
|
||||
# RLlib: Quick Agent train.py runs (compilation & running, no(!) learning).
|
||||
# Agent single tests (compilation, loss-funcs, etc..).
|
||||
|
||||
+1
-1
@@ -144,7 +144,7 @@ py_test(
|
||||
py_test(
|
||||
name = "test_sac",
|
||||
tags = ["agents_dir"],
|
||||
size = "medium",
|
||||
size = "large",
|
||||
srcs = ["agents/sac/tests/test_sac.py"]
|
||||
)
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import numpy as np
|
||||
|
||||
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.utils.space_utils import flatten_to_single_ndarray
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
@@ -110,10 +111,11 @@ class MultiAgentEpisode:
|
||||
"""Returns the last action for the specified agent, or zeros."""
|
||||
|
||||
if agent_id in self._agent_to_last_action:
|
||||
return _flatten_action(self._agent_to_last_action[agent_id])
|
||||
return flatten_to_single_ndarray(
|
||||
self._agent_to_last_action[agent_id])
|
||||
else:
|
||||
policy = self._policies[self.policy_for(agent_id)]
|
||||
flat = _flatten_action(policy.action_space.sample())
|
||||
flat = flatten_to_single_ndarray(policy.action_space.sample())
|
||||
return np.zeros_like(flat)
|
||||
|
||||
@DeveloperAPI
|
||||
@@ -121,7 +123,8 @@ class MultiAgentEpisode:
|
||||
"""Returns the previous action for the specified agent."""
|
||||
|
||||
if agent_id in self._agent_to_prev_action:
|
||||
return _flatten_action(self._agent_to_prev_action[agent_id])
|
||||
return flatten_to_single_ndarray(
|
||||
self._agent_to_prev_action[agent_id])
|
||||
else:
|
||||
# We're at t=0, so return all zeros.
|
||||
return np.zeros_like(self.last_action_for(agent_id))
|
||||
@@ -186,13 +189,3 @@ class MultiAgentEpisode:
|
||||
self._agent_to_index[agent_id] = self._next_agent_index
|
||||
self._next_agent_index += 1
|
||||
return self._agent_to_index[agent_id]
|
||||
|
||||
|
||||
def _flatten_action(action):
|
||||
# Concatenate tuple actions
|
||||
if isinstance(action, list) or isinstance(action, tuple):
|
||||
expanded = []
|
||||
for a in action:
|
||||
expanded.append(np.reshape(a, [-1]))
|
||||
action = np.concatenate(expanded, axis=0).flatten()
|
||||
return action
|
||||
|
||||
@@ -6,7 +6,7 @@ import threading
|
||||
import time
|
||||
|
||||
from ray.util.debug import log_once
|
||||
from ray.rllib.evaluation.episode import MultiAgentEpisode, _flatten_action
|
||||
from ray.rllib.evaluation.episode import MultiAgentEpisode
|
||||
from ray.rllib.evaluation.rollout_metrics import RolloutMetrics
|
||||
from ray.rllib.evaluation.sample_batch_builder import \
|
||||
MultiAgentSampleBatchBuilder
|
||||
@@ -18,6 +18,7 @@ from ray.rllib.offline import InputReader
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.debug import summarize
|
||||
from ray.rllib.utils.tuple_actions import TupleActions
|
||||
from ray.rllib.utils.space_utils import flatten_to_single_ndarray
|
||||
from ray.rllib.utils.tf_run_builder import TFRunBuilder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -545,8 +546,8 @@ def _process_observations(worker, base_env, policies, batch_builder_pool,
|
||||
episode.last_info_for(agent_id) or {},
|
||||
episode.rnn_state_for(agent_id),
|
||||
np.zeros_like(
|
||||
_flatten_action(policy.action_space.sample())),
|
||||
0.0))
|
||||
flatten_to_single_ndarray(
|
||||
policy.action_space.sample())), 0.0))
|
||||
|
||||
return active_envs, to_eval, outputs
|
||||
|
||||
|
||||
@@ -1,19 +1,23 @@
|
||||
import numpy as np
|
||||
from gym.spaces import Box
|
||||
from gym.spaces import Box, Tuple
|
||||
from scipy.stats import norm, beta
|
||||
import unittest
|
||||
|
||||
from ray.rllib.models.tf.tf_action_dist import Categorical, MultiCategorical, \
|
||||
SquashedGaussian, GumbelSoftmax
|
||||
from ray.rllib.models.tf.tf_action_dist import Categorical, \
|
||||
DiagGaussian, GumbelSoftmax, MultiActionDistribution, MultiCategorical, \
|
||||
SquashedGaussian
|
||||
from ray.rllib.models.torch.torch_action_dist import TorchMultiCategorical, \
|
||||
TorchSquashedGaussian, TorchBeta
|
||||
from ray.rllib.utils import try_import_tf, try_import_torch
|
||||
TorchSquashedGaussian, TorchBeta, TorchCategorical, \
|
||||
TorchMultiActionDistribution, TorchDiagGaussian
|
||||
from ray.rllib.utils import try_import_tree
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.numpy import MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT, \
|
||||
softmax, SMALL_NUMBER, LARGE_INTEGER
|
||||
from ray.rllib.utils.test_utils import check, framework_iterator
|
||||
|
||||
tf = try_import_tf()
|
||||
torch, _ = try_import_torch()
|
||||
tree = try_import_tree()
|
||||
|
||||
|
||||
class TestDistributions(unittest.TestCase):
|
||||
@@ -195,9 +199,7 @@ class TestDistributions(unittest.TestCase):
|
||||
1.0 - SMALL_NUMBER)
|
||||
unsquashed_values = np.arctanh(save_normed_values)
|
||||
log_prob_unsquashed = np.sum(
|
||||
np.log(norm.pdf(unsquashed_values, means,
|
||||
stds)),
|
||||
-1)
|
||||
np.log(norm.pdf(unsquashed_values, means, stds)), -1)
|
||||
log_prob = log_prob_unsquashed - \
|
||||
np.sum(np.log(1 - np.tanh(unsquashed_values) ** 2),
|
||||
axis=-1)
|
||||
@@ -281,7 +283,7 @@ class TestDistributions(unittest.TestCase):
|
||||
# TODO(sven): Test entropy outputs (against scipy).
|
||||
|
||||
def test_gumbel_softmax(self):
|
||||
"""Tests the GumbelSoftmax ActionDistribution (tf-eager only)."""
|
||||
"""Tests the GumbelSoftmax ActionDistribution (tf + eager only)."""
|
||||
for fw, sess in framework_iterator(
|
||||
frameworks=["tf", "eager"], session=True):
|
||||
batch_size = 1000
|
||||
@@ -307,6 +309,126 @@ class TestDistributions(unittest.TestCase):
|
||||
outs = sess.run(outs)
|
||||
check(np.mean(np.argmax(outs, -1)), expected_mean, rtol=0.08)
|
||||
|
||||
def test_multi_action_distribution(self):
    """Tests the MultiActionDistribution (only torch so far).

    Builds a two-component distribution (a 4-way categorical plus a
    3-dimensional diagonal Gaussian) from random inputs and checks
    deterministic sampling, stochastic sampling and log-likelihoods
    against numpy/scipy reference calculations.
    """
    batch_size = 1000
    # Raw distribution inputs: 4 values for the categorical child and
    # 6 values for the Gaussian child (see `input_lens=[4, 6]` below).
    input_space = Tuple([
        Box(-10.0, 10.0, shape=(batch_size, 4)),
        Box(-2.0, 2.0, shape=(
            batch_size,
            6,
        ))
    ])
    # Small values used to overwrite columns 3: of the Gaussian inputs in
    # the stochastic-sampling check.
    std_space = Box(
        -0.05, 0.05, shape=(
            batch_size,
            3,
        ))

    # Space the action values (for the logp checks) are drawn from.
    value_space = Tuple([
        Box(0, 3, shape=(batch_size, ), dtype=np.int32),
        Box(-2.0, 2.0, shape=(batch_size, 3), dtype=np.float32)
    ])

    # NOTE(review): only "torch" is requested here, so the `else` branch
    # below is presumably never taken for now - confirm against
    # `framework_iterator`.
    for fw, sess in framework_iterator(frameworks="torch", session=True):
        if fw == "torch":
            cls = TorchMultiActionDistribution
            child_distr_cls = [TorchCategorical, TorchDiagGaussian]
        else:
            cls = MultiActionDistribution
            child_distr_cls = [Categorical, DiagGaussian]

        inputs = list(input_space.sample())
        distr = cls(
            np.concatenate([inputs[0], inputs[1]], axis=1),
            model={},
            action_space=value_space,
            child_distributions=child_distr_cls,
            input_lens=[4, 6])

        # Sample deterministically.
        expected_det = [
            np.argmax(inputs[0], axis=-1),
            inputs[1][:, :3],  # [:3]=Mean values.
        ]
        out = distr.deterministic_sample()
        if sess:
            out = sess.run(out)
        check(out[0], expected_det[0])
        check(out[1], expected_det[1])

        # Stochastic sampling -> expect roughly the mean.
        inputs = list(input_space.sample())
        # Fix categorical inputs (not needed for distribution itself, but
        # for our expectation calculations).
        inputs[0] = softmax(inputs[0], -1)
        # Fix std inputs (shouldn't be too large for this test).
        inputs[1][:, 3:] = std_space.sample()
        # NOTE(review): `distr` is built from the already softmax-ed
        # categorical inputs; if the categorical child treats its inputs
        # as logits, the sampled distribution is not exactly `inputs[0]`
        # - confirm that the loose `decimals=1` check is intended to
        # absorb this.
        distr = cls(
            np.concatenate([inputs[0], inputs[1]], axis=1),
            model={},
            action_space=value_space,
            child_distributions=child_distr_cls,
            input_lens=[4, 6])
        expected_mean = [
            np.mean(np.sum(inputs[0] * np.array([0, 1, 2, 3]), -1)),
            inputs[1][:, :3],  # [:3]=Mean values.
        ]
        out = distr.sample()
        if sess:
            out = sess.run(out)
        out = list(out)
        if fw == "torch":
            out[0] = out[0].numpy()
            out[1] = out[1].numpy()
        check(np.mean(out[0]), expected_mean[0], decimals=1)
        check(np.mean(out[1], 0), np.mean(expected_mean[1], 0), decimals=1)

        # Test log-likelihood outputs.
        # The distribution is built from the raw sample first; the
        # softmax/exp transforms further down are applied afterwards and
        # only feed the numpy/scipy reference calculation.
        inputs = list(input_space.sample())
        distr = cls(
            np.concatenate([inputs[0], inputs[1]], axis=1),
            model={},
            action_space=value_space,
            child_distributions=child_distr_cls,
            input_lens=[4, 6])
        inputs[0] = softmax(inputs[0], -1)
        values = list(value_space.sample())
        # Presumably columns 3: are log-std values for the Gaussian child
        # (hence the exp before handing them to scipy) - TODO confirm.
        inputs[1][:, 3:] = np.exp(inputs[1][:, 3:])
        # Reference: log-prob of the chosen category plus the per-dimension
        # Gaussian log-densities, summed over the last axis.
        expected_log_llh = np.sum(
            np.concatenate([
                np.expand_dims(
                    np.log(
                        [i[values[0][j]]
                         for j, i in enumerate(inputs[0])]), -1),
                np.log(
                    norm.pdf(values[1], inputs[1][:, :3],
                             inputs[1][:, 3:]))
            ], -1), -1)

        # Give the categorical values a trailing dim of size 1 so they can
        # be concatenated with the Gaussian values along the last axis.
        values[0] = np.expand_dims(values[0], -1)
        if fw == "torch":
            values = tree.map_structure(lambda s: torch.Tensor(s), values)
        # Test all flattened input.
        concat = np.concatenate(tree.flatten(values),
                                -1).astype(np.float32)
        out = distr.logp(concat)
        if sess:
            out = sess.run(out)
        # NOTE(review): atol=15 is a very loose tolerance for a
        # log-likelihood comparison - confirm this is intentional.
        check(out, expected_log_llh, atol=15)
        # Test structured input.
        out = distr.logp(values)
        if sess:
            out = sess.run(out)
        check(out, expected_log_llh, atol=15)
        # Test flattened input.
        out = distr.logp(tree.flatten(values))
        if sess:
            out = sess.run(out)
        check(out, expected_log_llh, atol=15)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
|
||||
@@ -228,8 +228,8 @@ class DiagGaussian(TFActionDistribution):
|
||||
def logp(self, x):
|
||||
return -0.5 * tf.reduce_sum(
|
||||
tf.square((x - self.mean) / self.std), axis=1) - \
|
||||
0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[1]) - \
|
||||
tf.reduce_sum(self.log_std, axis=1)
|
||||
0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[1]) - \
|
||||
tf.reduce_sum(self.log_std, axis=1)
|
||||
|
||||
@override(ActionDistribution)
|
||||
def kl(self, other):
|
||||
|
||||
@@ -1,14 +1,18 @@
|
||||
import functools
|
||||
from math import log
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.models.action_dist import ActionDistribution
|
||||
from ray.rllib.utils import try_import_tree
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils import try_import_torch
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.numpy import SMALL_NUMBER, MIN_LOG_NN_OUTPUT, \
|
||||
MAX_LOG_NN_OUTPUT
|
||||
from ray.rllib.utils.space_utils import get_base_struct_from_space
|
||||
from ray.rllib.utils.torch_ops import atanh
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
tree = try_import_tree()
|
||||
|
||||
|
||||
class TorchDistributionWrapper(ActionDistribution):
|
||||
@@ -305,3 +309,102 @@ class TorchDeterministic(TorchDistributionWrapper):
|
||||
@override(ActionDistribution)
|
||||
def required_model_output_shape(action_space, model_config):
|
||||
return np.prod(action_space.shape)
|
||||
|
||||
|
||||
class TorchMultiActionDistribution(TorchDistributionWrapper):
    """Action distribution that operates on multiple, possibly nested actions.

    The single input tensor is split column-wise into one chunk per child
    distribution. Joint quantities (logp, kl, entropy) are the sums of the
    respective child quantities; samples are returned in the structure of
    the given action space.
    """

    def __init__(self, inputs, model, *, child_distributions, input_lens,
                 action_space):
        """Initializes a TorchMultiActionDistribution object.

        Args:
            inputs (torch.Tensor): A single tensor of shape [BATCH, size].
                Non-tensor inputs are converted via `torch.Tensor(inputs)`.
            model (ModelV2): The ModelV2 object used to produce inputs for this
                distribution.
            child_distributions (any[type]): Any struct that contains the
                child distribution classes to use to instantiate the child
                distributions from `inputs`. This could be an already
                flattened list or a struct according to `action_space`.
            input_lens (any[int]): A flat list or a nested struct of input
                split lengths used to split `inputs`.
            action_space (Union[gym.spaces.Dict,gym.spaces.Tuple]): The complex
                and possibly nested action space.
        """
        if not isinstance(inputs, torch.Tensor):
            inputs = torch.Tensor(inputs)
        super().__init__(inputs, model)

        # Plain python tuple/dict skeleton of the action space; used to
        # re-nest the flat list of children when sampling.
        self.action_space_struct = get_base_struct_from_space(action_space)

        input_lens = tree.flatten(input_lens)
        flat_child_distributions = tree.flatten(child_distributions)
        # One chunk of columns per child distribution.
        split_inputs = torch.split(inputs, input_lens, dim=1)
        self.flat_child_distributions = tree.map_structure(
            lambda dist, input_: dist(input_, model), flat_child_distributions,
            list(split_inputs))

    @override(ActionDistribution)
    def logp(self, x):
        """Returns the summed log-likelihood of `x` under all children.

        Args:
            x: Either one merged array/tensor with all action components
                concatenated along dim 1, or a (nested or flat) struct
                holding one value per child distribution.

        Returns:
            The sum of the child distributions' `logp` outputs.
        """
        if isinstance(x, np.ndarray):
            x = torch.Tensor(x)
        # Single tensor input (all merged).
        if isinstance(x, torch.Tensor):
            split_indices = []
            for dist in self.flat_child_distributions:
                if isinstance(dist, TorchCategorical):
                    # A categorical action occupies exactly one column.
                    split_indices.append(1)
                else:
                    # NOTE(review): the column count is taken from a
                    # throw-away sample. If the child caches its last
                    # sample, this call would overwrite it - confirm
                    # against TorchDistributionWrapper.sample().
                    split_indices.append(dist.sample().size()[1])
            split_x = list(torch.split(x, split_indices, dim=1))
        # Structured or flattened (by single action component) input.
        else:
            split_x = tree.flatten(x)

        def map_(val, dist):
            # Remove extra categorical dimension.
            if isinstance(dist, TorchCategorical):
                val = torch.squeeze(val, dim=-1).int()
            return dist.logp(val)

        # Remove extra categorical dimension and take the logp of each
        # component.
        flat_logps = tree.map_structure(map_, split_x,
                                        self.flat_child_distributions)

        return functools.reduce(lambda a, b: a + b, flat_logps)

    @override(ActionDistribution)
    def kl(self, other):
        """Returns the sum of the pairwise child KL terms w.r.t. `other`.

        Children are paired positionally with
        `other.flat_child_distributions`.
        """
        kl_list = [
            d.kl(o) for d, o in zip(self.flat_child_distributions,
                                    other.flat_child_distributions)
        ]
        return functools.reduce(lambda a, b: a + b, kl_list)

    @override(ActionDistribution)
    def entropy(self):
        """Returns the sum of all child distributions' entropies."""
        entropy_list = [d.entropy() for d in self.flat_child_distributions]
        return functools.reduce(lambda a, b: a + b, entropy_list)

    @override(ActionDistribution)
    def sample(self):
        """Samples every child; result is structured like the action space."""
        child_distributions = tree.unflatten_as(self.action_space_struct,
                                                self.flat_child_distributions)
        return tree.map_structure(lambda s: s.sample(), child_distributions)

    @override(ActionDistribution)
    def deterministic_sample(self):
        """Like `sample()`, but uses each child's `deterministic_sample()`."""
        child_distributions = tree.unflatten_as(self.action_space_struct,
                                                self.flat_child_distributions)
        return tree.map_structure(lambda s: s.deterministic_sample(),
                                  child_distributions)

    @override(TorchDistributionWrapper)
    def sampled_action_logp(self):
        """Returns the sum of all children's `sampled_action_logp()`."""
        child_logps = [
            d.sampled_action_logp() for d in self.flat_child_distributions
        ]
        # Out-of-place addition (as in `kl`/`entropy`): never mutate a
        # tensor handed back by a child distribution.
        return functools.reduce(lambda a, b: a + b, child_logps)
|
||||
|
||||
@@ -7,13 +7,13 @@ import logging
|
||||
import numpy as np
|
||||
|
||||
from ray.util.debug import log_once
|
||||
from ray.rllib.evaluation.episode import _flatten_action
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils import add_mixins
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.space_utils import flatten_to_single_ndarray
|
||||
|
||||
tf = try_import_tf()
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -241,7 +241,7 @@ def build_eager_tf_policy(name,
|
||||
SampleBatch.CUR_OBS: tf.convert_to_tensor(
|
||||
np.array([observation_space.sample()])),
|
||||
SampleBatch.PREV_ACTIONS: tf.convert_to_tensor(
|
||||
[_flatten_action(action_space.sample())]),
|
||||
[flatten_to_single_ndarray(action_space.sample())]),
|
||||
SampleBatch.PREV_REWARDS: tf.convert_to_tensor([0.]),
|
||||
}
|
||||
|
||||
|
||||
+3
-3
@@ -13,10 +13,10 @@ import gym
|
||||
import ray
|
||||
from ray.rllib.env import MultiAgentEnv
|
||||
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
|
||||
from ray.rllib.evaluation.episode import _flatten_action
|
||||
from ray.rllib.evaluation.worker_set import WorkerSet
|
||||
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.utils.deprecation import deprecation_warning
|
||||
from ray.rllib.utils.space_utils import flatten_to_single_ndarray
|
||||
from ray.tune.utils import merge_dicts
|
||||
from ray.tune.registry import get_trainable_cls
|
||||
|
||||
@@ -362,7 +362,7 @@ def rollout(agent,
|
||||
use_lstm = {DEFAULT_POLICY_ID: False}
|
||||
|
||||
action_init = {
|
||||
p: _flatten_action(m.action_space.sample())
|
||||
p: flatten_to_single_ndarray(m.action_space.sample())
|
||||
for p, m in policy_map.items()
|
||||
}
|
||||
|
||||
@@ -411,7 +411,7 @@ def rollout(agent,
|
||||
prev_action=prev_actions[agent_id],
|
||||
prev_reward=prev_rewards[agent_id],
|
||||
policy_id=policy_id)
|
||||
a_action = _flatten_action(a_action) # tuple actions
|
||||
a_action = flatten_to_single_ndarray(a_action)
|
||||
action_dict[agent_id] = a_action
|
||||
prev_actions[agent_id] = a_action
|
||||
action = action_dict
|
||||
|
||||
+12
-4
@@ -14,8 +14,6 @@ from ray.rllib.utils.policy_server import PolicyServer
|
||||
from ray.rllib.utils.schedules import LinearSchedule, PiecewiseSchedule, \
|
||||
PolynomialSchedule, ExponentialSchedule, ConstantSchedule
|
||||
from ray.rllib.utils.test_utils import check, framework_iterator
|
||||
from ray.rllib.utils.torch_ops import convert_to_non_torch_type, \
|
||||
convert_to_torch_tensor
|
||||
from ray.tune.utils import merge_dicts, deep_update
|
||||
|
||||
|
||||
@@ -58,12 +56,21 @@ def force_list(elements=None, to_tuple=False):
|
||||
|
||||
force_tuple = partial(force_list, to_tuple=True)
|
||||
|
||||
|
||||
# TODO(sven): remove at some point.
|
||||
def try_import_tree():
    """Returns the `tree` module (dm-tree) or raises if it is unavailable.

    Returns:
        module: The imported `tree` module.

    Raises:
        ModuleNotFoundError: If `tree` cannot be imported. The underlying
            import error is attached as the exception's cause.
    """
    try:
        import tree
    # ModuleNotFoundError is a subclass of ImportError, so catching the
    # parent class covers both cases.
    except ImportError as err:
        raise ModuleNotFoundError(
            "`dm-tree` is not installed! Run `pip install dm-tree`.") from err
    return tree
|
||||
|
||||
|
||||
__all__ = [
|
||||
"add_mixins",
|
||||
"check",
|
||||
"check_framework",
|
||||
"convert_to_non_torch_type",
|
||||
"convert_to_torch_tensor",
|
||||
"deprecation_warning",
|
||||
"fc",
|
||||
"force_list",
|
||||
@@ -83,6 +90,7 @@ __all__ = [
|
||||
"try_import_tf",
|
||||
"try_import_tfp",
|
||||
"try_import_torch",
|
||||
"try_import_tree",
|
||||
"ConstantSchedule",
|
||||
"DeveloperAPI",
|
||||
"ExponentialSchedule",
|
||||
|
||||
@@ -0,0 +1,94 @@
|
||||
from gym.spaces import Tuple, Dict
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.utils import try_import_tree
|
||||
|
||||
tree = try_import_tree()
|
||||
|
||||
|
||||
def flatten_space(space):
    """Breaks a gym.Space down into the primitive Spaces it is made of.

    "Primitive" means any Space that is neither a Tuple nor a Dict.

    Args:
        space (gym.Space): The Space to take apart. Tuples and Dicts may be
            nested arbitrarily.

    Returns:
        List[gym.Space]: All primitive Spaces found, in depth-first order.
            The list itself contains no Tuple or Dict Spaces.
    """

    def _primitives(node):
        # Walk depth-first, emitting every non-container Space encountered.
        if isinstance(node, Tuple):
            for child in node:
                yield from _primitives(child)
        elif isinstance(node, Dict):
            for key in node.spaces:
                yield from _primitives(node[key])
        else:
            yield node

    return list(_primitives(space))
|
||||
|
||||
|
||||
def get_base_struct_from_space(space):
    """Converts a Tuple/Dict Space into an equally structured py tuple/dict.

    Args:
        space (gym.Space): The Space to convert.

    Returns:
        Union[dict,tuple,gym.Space]: A plain python tuple (for Tuple Spaces)
            or dict (for Dict Spaces), nested like `space`. The leaves are
            the original primitive Spaces (e.g. Box, Discrete); a primitive
            `space` is returned as-is.

    Examples:
        >>> get_base_struct_from_space(Dict({
        >>>     "a": Box(),
        >>>     "b": Tuple([Discrete(2), Discrete(3)])
        >>> }))
        >>> # Will return: dict(a=Box(), b=tuple(Discrete(2), Discrete(3)))
    """
    if isinstance(space, Tuple):
        return tuple(get_base_struct_from_space(sub) for sub in space)
    if isinstance(space, Dict):
        return {
            key: get_base_struct_from_space(space[key])
            for key in space.spaces
        }
    # Primitive Space: nothing to unpack.
    return space
|
||||
|
||||
|
||||
def _iter_leaves(struct):
    """Yields the leaves of a (possibly nested) list/tuple/dict depth-first."""
    if isinstance(struct, dict):
        # Visit keys in sorted order so the resulting layout is
        # deterministic (same convention dm-tree uses for mappings).
        for key in sorted(struct):
            yield from _iter_leaves(struct[key])
    elif isinstance(struct, (list, tuple)):
        for item in struct:
            yield from _iter_leaves(item)
    else:
        yield struct


def flatten_to_single_ndarray(input_):
    """Returns a single 1D np.ndarray given a struct of np.ndarrays.

    Args:
        input_ (Union[List[np.ndarray],Tuple,dict,np.ndarray]): A (possibly
            nested) list/tuple/dict of ndarrays, or a single ndarray.

    Returns:
        np.ndarray: The result after flattening and concatenating all
            leaves of `input_`. Anything that is not a list, tuple or dict
            is returned unchanged.

    Raises:
        ValueError: If `input_` is a container without any leaves (there
            is nothing to concatenate).

    Examples:
        >>> flatten_to_single_ndarray([
        >>>     np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
        >>>     np.array([7, 8, 9]),
        >>> ])
        >>> # Will return:
        >>> # np.array([
        >>> #     1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0
        >>> # ])
    """
    # Concatenate the components of complex (e.g. Tuple/Dict) actions.
    if isinstance(input_, (list, tuple, dict)):
        expanded = [np.reshape(leaf, [-1]) for leaf in _iter_leaves(input_)]
        input_ = np.concatenate(expanded, axis=0).flatten()
    return input_
|
||||
Reference in New Issue
Block a user