[RLlib] Nested action space PR (minimally invasive; torch only + test). (#8101)

- Add TorchMultiActionDistribution class.
- Add framework-agnostic test cases for TorchMultiActionDistribution.
This commit is contained in:
Sven Mika
2020-04-23 09:09:22 +02:00
committed by GitHub
parent a9d8d16b6b
commit e9ee5c4e5f
11 changed files with 361 additions and 40 deletions
+2 -2
View File
@@ -186,7 +186,7 @@ matrix:
- . ./ci/travis/ci.sh build
script:
- . ./ci/travis/ci.sh preload
- travis_wait 90 bazel test --config=ci --test_output=streamed --build_tests_only --test_tag_filters=learning_tests rllib/...
- travis_wait 120 bazel test --config=ci --test_output=streamed --build_tests_only --test_tag_filters=learning_tests rllib/...
# RLlib: Learning tests with tf=1.x (from rllib/tuned_examples/regression_tests/*.yaml).
# Requested by Edi (MS): Test all learning capabilities with tf1.x
@@ -205,7 +205,7 @@ matrix:
- . ./ci/travis/ci.sh build
script:
- . ./ci/travis/ci.sh preload
- travis_wait 90 bazel test --config=ci --test_output=streamed --build_tests_only --test_tag_filters=learning_tests rllib/...
- travis_wait 120 bazel test --config=ci --test_output=streamed --build_tests_only --test_tag_filters=learning_tests rllib/...
# RLlib: Quick Agent train.py runs (compilation & running, no(!) learning).
# Agent single tests (compilation, loss-funcs, etc..).
+1 -1
View File
@@ -144,7 +144,7 @@ py_test(
py_test(
name = "test_sac",
tags = ["agents_dir"],
size = "medium",
size = "large",
srcs = ["agents/sac/tests/test_sac.py"]
)
+6 -13
View File
@@ -5,6 +5,7 @@ import numpy as np
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.space_utils import flatten_to_single_ndarray
@DeveloperAPI
@@ -110,10 +111,11 @@ class MultiAgentEpisode:
"""Returns the last action for the specified agent, or zeros."""
if agent_id in self._agent_to_last_action:
return _flatten_action(self._agent_to_last_action[agent_id])
return flatten_to_single_ndarray(
self._agent_to_last_action[agent_id])
else:
policy = self._policies[self.policy_for(agent_id)]
flat = _flatten_action(policy.action_space.sample())
flat = flatten_to_single_ndarray(policy.action_space.sample())
return np.zeros_like(flat)
@DeveloperAPI
@@ -121,7 +123,8 @@ class MultiAgentEpisode:
"""Returns the previous action for the specified agent."""
if agent_id in self._agent_to_prev_action:
return _flatten_action(self._agent_to_prev_action[agent_id])
return flatten_to_single_ndarray(
self._agent_to_prev_action[agent_id])
else:
# We're at t=0, so return all zeros.
return np.zeros_like(self.last_action_for(agent_id))
@@ -186,13 +189,3 @@ class MultiAgentEpisode:
self._agent_to_index[agent_id] = self._next_agent_index
self._next_agent_index += 1
return self._agent_to_index[agent_id]
def _flatten_action(action):
# Concatenate tuple actions
if isinstance(action, list) or isinstance(action, tuple):
expanded = []
for a in action:
expanded.append(np.reshape(a, [-1]))
action = np.concatenate(expanded, axis=0).flatten()
return action
+4 -3
View File
@@ -6,7 +6,7 @@ import threading
import time
from ray.util.debug import log_once
from ray.rllib.evaluation.episode import MultiAgentEpisode, _flatten_action
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.rollout_metrics import RolloutMetrics
from ray.rllib.evaluation.sample_batch_builder import \
MultiAgentSampleBatchBuilder
@@ -18,6 +18,7 @@ from ray.rllib.offline import InputReader
from ray.rllib.utils.annotations import override
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.tuple_actions import TupleActions
from ray.rllib.utils.space_utils import flatten_to_single_ndarray
from ray.rllib.utils.tf_run_builder import TFRunBuilder
logger = logging.getLogger(__name__)
@@ -545,8 +546,8 @@ def _process_observations(worker, base_env, policies, batch_builder_pool,
episode.last_info_for(agent_id) or {},
episode.rnn_state_for(agent_id),
np.zeros_like(
_flatten_action(policy.action_space.sample())),
0.0))
flatten_to_single_ndarray(
policy.action_space.sample())), 0.0))
return active_envs, to_eval, outputs
+131 -9
View File
@@ -1,19 +1,23 @@
import numpy as np
from gym.spaces import Box
from gym.spaces import Box, Tuple
from scipy.stats import norm, beta
import unittest
from ray.rllib.models.tf.tf_action_dist import Categorical, MultiCategorical, \
SquashedGaussian, GumbelSoftmax
from ray.rllib.models.tf.tf_action_dist import Categorical, \
DiagGaussian, GumbelSoftmax, MultiActionDistribution, MultiCategorical, \
SquashedGaussian
from ray.rllib.models.torch.torch_action_dist import TorchMultiCategorical, \
TorchSquashedGaussian, TorchBeta
from ray.rllib.utils import try_import_tf, try_import_torch
TorchSquashedGaussian, TorchBeta, TorchCategorical, \
TorchMultiActionDistribution, TorchDiagGaussian
from ray.rllib.utils import try_import_tree
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.numpy import MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT, \
softmax, SMALL_NUMBER, LARGE_INTEGER
from ray.rllib.utils.test_utils import check, framework_iterator
tf = try_import_tf()
torch, _ = try_import_torch()
tree = try_import_tree()
class TestDistributions(unittest.TestCase):
@@ -195,9 +199,7 @@ class TestDistributions(unittest.TestCase):
1.0 - SMALL_NUMBER)
unsquashed_values = np.arctanh(save_normed_values)
log_prob_unsquashed = np.sum(
np.log(norm.pdf(unsquashed_values, means,
stds)),
-1)
np.log(norm.pdf(unsquashed_values, means, stds)), -1)
log_prob = log_prob_unsquashed - \
np.sum(np.log(1 - np.tanh(unsquashed_values) ** 2),
axis=-1)
@@ -281,7 +283,7 @@ class TestDistributions(unittest.TestCase):
# TODO(sven): Test entropy outputs (against scipy).
def test_gumbel_softmax(self):
"""Tests the GumbelSoftmax ActionDistribution (tf-eager only)."""
"""Tests the GumbelSoftmax ActionDistribution (tf + eager only)."""
for fw, sess in framework_iterator(
frameworks=["tf", "eager"], session=True):
batch_size = 1000
@@ -307,6 +309,126 @@ class TestDistributions(unittest.TestCase):
outs = sess.run(outs)
check(np.mean(np.argmax(outs, -1)), expected_mean, rtol=0.08)
def test_multi_action_distribution(self):
"""Tests the MultiActionDistribution (only torch so far)."""
batch_size = 1000
input_space = Tuple([
Box(-10.0, 10.0, shape=(batch_size, 4)),
Box(-2.0, 2.0, shape=(
batch_size,
6,
))
])
std_space = Box(
-0.05, 0.05, shape=(
batch_size,
3,
))
value_space = Tuple([
Box(0, 3, shape=(batch_size, ), dtype=np.int32),
Box(-2.0, 2.0, shape=(batch_size, 3), dtype=np.float32)
])
for fw, sess in framework_iterator(frameworks="torch", session=True):
if fw == "torch":
cls = TorchMultiActionDistribution
child_distr_cls = [TorchCategorical, TorchDiagGaussian]
else:
cls = MultiActionDistribution
child_distr_cls = [Categorical, DiagGaussian]
inputs = list(input_space.sample())
distr = cls(
np.concatenate([inputs[0], inputs[1]], axis=1),
model={},
action_space=value_space,
child_distributions=child_distr_cls,
input_lens=[4, 6])
# Sample deterministically.
expected_det = [
np.argmax(inputs[0], axis=-1),
inputs[1][:, :3], # [:3]=Mean values.
]
out = distr.deterministic_sample()
if sess:
out = sess.run(out)
check(out[0], expected_det[0])
check(out[1], expected_det[1])
# Stochastic sampling -> expect roughly the mean.
inputs = list(input_space.sample())
# Fix categorical inputs (not needed for distribution itself, but
# for our expectation calculations).
inputs[0] = softmax(inputs[0], -1)
# Fix std inputs (shouldn't be too large for this test).
inputs[1][:, 3:] = std_space.sample()
distr = cls(
np.concatenate([inputs[0], inputs[1]], axis=1),
model={},
action_space=value_space,
child_distributions=child_distr_cls,
input_lens=[4, 6])
expected_mean = [
np.mean(np.sum(inputs[0] * np.array([0, 1, 2, 3]), -1)),
inputs[1][:, :3], # [:3]=Mean values.
]
out = distr.sample()
if sess:
out = sess.run(out)
out = list(out)
if fw == "torch":
out[0] = out[0].numpy()
out[1] = out[1].numpy()
check(np.mean(out[0]), expected_mean[0], decimals=1)
check(np.mean(out[1], 0), np.mean(expected_mean[1], 0), decimals=1)
# Test log-likelihood outputs.
# Make sure beta-values are within 0.0 and 1.0 for the numpy
# calculation (which doesn't have scaling).
inputs = list(input_space.sample())
distr = cls(
np.concatenate([inputs[0], inputs[1]], axis=1),
model={},
action_space=value_space,
child_distributions=child_distr_cls,
input_lens=[4, 6])
inputs[0] = softmax(inputs[0], -1)
values = list(value_space.sample())
inputs[1][:, 3:] = np.exp(inputs[1][:, 3:])
expected_log_llh = np.sum(
np.concatenate([
np.expand_dims(
np.log(
[i[values[0][j]]
for j, i in enumerate(inputs[0])]), -1),
np.log(
norm.pdf(values[1], inputs[1][:, :3],
inputs[1][:, 3:]))
], -1), -1)
values[0] = np.expand_dims(values[0], -1)
if fw == "torch":
values = tree.map_structure(lambda s: torch.Tensor(s), values)
# Test all flattened input.
concat = np.concatenate(tree.flatten(values),
-1).astype(np.float32)
out = distr.logp(concat)
if sess:
out = sess.run(out)
check(out, expected_log_llh, atol=15)
# Test structured input.
out = distr.logp(values)
if sess:
out = sess.run(out)
check(out, expected_log_llh, atol=15)
# Test flattened input.
out = distr.logp(tree.flatten(values))
if sess:
out = sess.run(out)
check(out, expected_log_llh, atol=15)
if __name__ == "__main__":
import pytest
+2 -2
View File
@@ -228,8 +228,8 @@ class DiagGaussian(TFActionDistribution):
def logp(self, x):
return -0.5 * tf.reduce_sum(
tf.square((x - self.mean) / self.std), axis=1) - \
0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[1]) - \
tf.reduce_sum(self.log_std, axis=1)
0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[1]) - \
tf.reduce_sum(self.log_std, axis=1)
@override(ActionDistribution)
def kl(self, other):
+104 -1
View File
@@ -1,14 +1,18 @@
import functools
from math import log
import numpy as np
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.utils import try_import_tree
from ray.rllib.utils.annotations import override
from ray.rllib.utils import try_import_torch
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import SMALL_NUMBER, MIN_LOG_NN_OUTPUT, \
MAX_LOG_NN_OUTPUT
from ray.rllib.utils.space_utils import get_base_struct_from_space
from ray.rllib.utils.torch_ops import atanh
torch, nn = try_import_torch()
tree = try_import_tree()
class TorchDistributionWrapper(ActionDistribution):
@@ -305,3 +309,102 @@ class TorchDeterministic(TorchDistributionWrapper):
@override(ActionDistribution)
def required_model_output_shape(action_space, model_config):
return np.prod(action_space.shape)
class TorchMultiActionDistribution(TorchDistributionWrapper):
"""Action distribution that operates on multiple, possibly nested actions.
"""
def __init__(self, inputs, model, *, child_distributions, input_lens,
action_space):
"""Initializes a TorchMultiActionDistribution object.
Args:
inputs (torch.Tensor): A single tensor of shape [BATCH, size].
model (ModelV2): The ModelV2 object used to produce inputs for this
distribution.
child_distributions (any[torch.Tensor]): Any struct
that contains the child distribution classes to use to
instantiate the child distributions from `inputs`. This could
be an already flattened list or a struct according to
`action_space`.
input_lens (any[int]): A flat list or a nested struct of input
split lengths used to split `inputs`.
action_space (Union[gym.spaces.Dict,gym.spaces.Tuple]): The complex
and possibly nested action space.
"""
if not isinstance(inputs, torch.Tensor):
inputs = torch.Tensor(inputs)
super().__init__(inputs, model)
self.action_space_struct = get_base_struct_from_space(action_space)
input_lens = tree.flatten(input_lens)
flat_child_distributions = tree.flatten(child_distributions)
split_inputs = torch.split(inputs, input_lens, dim=1)
self.flat_child_distributions = tree.map_structure(
lambda dist, input_: dist(input_, model), flat_child_distributions,
list(split_inputs))
@override(ActionDistribution)
def logp(self, x):
if isinstance(x, np.ndarray):
x = torch.Tensor(x)
# Single tensor input (all merged).
if isinstance(x, torch.Tensor):
split_indices = []
for dist in self.flat_child_distributions:
if isinstance(dist, TorchCategorical):
split_indices.append(1)
else:
split_indices.append(dist.sample().size()[1])
split_x = list(torch.split(x, split_indices, dim=1))
# Structured or flattened (by single action component) input.
else:
split_x = tree.flatten(x)
def map_(val, dist):
# Remove extra categorical dimension.
if isinstance(dist, TorchCategorical):
val = torch.squeeze(val, dim=-1).int()
return dist.logp(val)
# Remove extra categorical dimension and take the logp of each
# component.
flat_logps = tree.map_structure(map_, split_x,
self.flat_child_distributions)
return functools.reduce(lambda a, b: a + b, flat_logps)
@override(ActionDistribution)
def kl(self, other):
kl_list = [
d.kl(o) for d, o in zip(self.flat_child_distributions,
other.flat_child_distributions)
]
return functools.reduce(lambda a, b: a + b, kl_list)
@override(ActionDistribution)
def entropy(self):
entropy_list = [d.entropy() for d in self.flat_child_distributions]
return functools.reduce(lambda a, b: a + b, entropy_list)
@override(ActionDistribution)
def sample(self):
child_distributions = tree.unflatten_as(self.action_space_struct,
self.flat_child_distributions)
return tree.map_structure(lambda s: s.sample(), child_distributions)
@override(ActionDistribution)
def deterministic_sample(self):
child_distributions = tree.unflatten_as(self.action_space_struct,
self.flat_child_distributions)
return tree.map_structure(lambda s: s.deterministic_sample(),
child_distributions)
@override(TorchDistributionWrapper)
def sampled_action_logp(self):
p = self.flat_child_distributions[0].sampled_action_logp()
for c in self.flat_child_distributions[1:]:
p += c.sampled_action_logp()
return p
+2 -2
View File
@@ -7,13 +7,13 @@ import logging
import numpy as np
from ray.util.debug import log_once
from ray.rllib.evaluation.episode import _flatten_action
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import add_mixins
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.space_utils import flatten_to_single_ndarray
tf = try_import_tf()
logger = logging.getLogger(__name__)
@@ -241,7 +241,7 @@ def build_eager_tf_policy(name,
SampleBatch.CUR_OBS: tf.convert_to_tensor(
np.array([observation_space.sample()])),
SampleBatch.PREV_ACTIONS: tf.convert_to_tensor(
[_flatten_action(action_space.sample())]),
[flatten_to_single_ndarray(action_space.sample())]),
SampleBatch.PREV_REWARDS: tf.convert_to_tensor([0.]),
}
+3 -3
View File
@@ -13,10 +13,10 @@ import gym
import ray
from ray.rllib.env import MultiAgentEnv
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.evaluation.episode import _flatten_action
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.space_utils import flatten_to_single_ndarray
from ray.tune.utils import merge_dicts
from ray.tune.registry import get_trainable_cls
@@ -362,7 +362,7 @@ def rollout(agent,
use_lstm = {DEFAULT_POLICY_ID: False}
action_init = {
p: _flatten_action(m.action_space.sample())
p: flatten_to_single_ndarray(m.action_space.sample())
for p, m in policy_map.items()
}
@@ -411,7 +411,7 @@ def rollout(agent,
prev_action=prev_actions[agent_id],
prev_reward=prev_rewards[agent_id],
policy_id=policy_id)
a_action = _flatten_action(a_action) # tuple actions
a_action = flatten_to_single_ndarray(a_action)
action_dict[agent_id] = a_action
prev_actions[agent_id] = a_action
action = action_dict
+12 -4
View File
@@ -14,8 +14,6 @@ from ray.rllib.utils.policy_server import PolicyServer
from ray.rllib.utils.schedules import LinearSchedule, PiecewiseSchedule, \
PolynomialSchedule, ExponentialSchedule, ConstantSchedule
from ray.rllib.utils.test_utils import check, framework_iterator
from ray.rllib.utils.torch_ops import convert_to_non_torch_type, \
convert_to_torch_tensor
from ray.tune.utils import merge_dicts, deep_update
@@ -58,12 +56,21 @@ def force_list(elements=None, to_tuple=False):
force_tuple = partial(force_list, to_tuple=True)
# TODO(sven): remove at some point.
def try_import_tree():
try:
import tree
return tree
except (ImportError, ModuleNotFoundError):
raise ModuleNotFoundError(
"`dm-tree` is not installed! Run `pip install dm-tree`.")
__all__ = [
"add_mixins",
"check",
"check_framework",
"convert_to_non_torch_type",
"convert_to_torch_tensor",
"deprecation_warning",
"fc",
"force_list",
@@ -83,6 +90,7 @@ __all__ = [
"try_import_tf",
"try_import_tfp",
"try_import_torch",
"try_import_tree",
"ConstantSchedule",
"DeveloperAPI",
"ExponentialSchedule",
+94
View File
@@ -0,0 +1,94 @@
from gym.spaces import Tuple, Dict
import numpy as np
from ray.rllib.utils import try_import_tree
tree = try_import_tree()
def flatten_space(space):
"""Flattens a gym.Space into its primitive components.
Primitive components are any non Tuple/Dict spaces.
Args:
space(gym.Space): The gym.Space to flatten. This may be any
supported type (including nested Tuples and Dicts).
Returns:
List[gym.Space]: The flattened list of primitive Spaces. This list
does not contain Tuples or Dicts anymore.
"""
def _helper_flatten(space_, l):
if isinstance(space_, Tuple):
for s in space_:
_helper_flatten(s, l)
elif isinstance(space_, Dict):
for k in space_.spaces:
_helper_flatten(space_[k], l)
else:
l.append(space_)
ret = []
_helper_flatten(space, ret)
return ret
def get_base_struct_from_space(space):
"""Returns a Tuple/Dict Space as native (equally structured) py tuple/dict.
Args:
space (gym.Space): The Space to get the python struct for.
Returns:
Union[dict,tuple,gym.Space]: The struct equivalent to the given Space.
Note that the returned struct still contains all original
"primitive" Spaces (e.g. Box, Discrete).
Examples:
>>> get_base_struct_from_space(Dict({
>>> "a": Box(),
>>> "b": Tuple([Discrete(2), Discrete(3)])
>>> }))
>>> # Will return: dict(a=Box(), b=tuple(Discrete(2), Discrete(3)))
"""
def _helper_struct(space_):
if isinstance(space_, Tuple):
return tuple(_helper_struct(s) for s in space_)
elif isinstance(space_, Dict):
return {k: _helper_struct(space_[k]) for k in space_.spaces}
else:
return space_
return _helper_struct(space)
def flatten_to_single_ndarray(input_):
"""Returns a single np.ndarray given a list/tuple of np.ndarrays.
Args:
input_ (Union[List[np.ndarray],np.ndarray]): The list of ndarrays or
a single ndarray.
Returns:
np.ndarray: The result after concatenating all single arrays in input_.
Examples:
>>> flatten_to_single_ndarray([
>>> np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
>>> np.array([7, 8, 9]),
>>> ])
>>> # Will return:
>>> # np.array([
>>> # 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0
>>> # ])
"""
# Concatenate tuple actions
if isinstance(input_, (list, tuple)):
expanded = []
for in_ in input_:
expanded.append(np.reshape(in_, [-1]))
input_ = np.concatenate(expanded, axis=0).flatten()
return input_