From 592c161032bfd697ef6f2f334c5cd007062bec4b Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Wed, 25 Nov 2020 20:27:46 +0100 Subject: [PATCH] [RLlib] Issue 12118: LSTM prev-a/r should be separately configurable. Fix missing prev-a one-hot encoding. (#12397) * WIP. * Fix and LINT. --- rllib/BUILD | 2 +- rllib/agents/impala/tests/test_impala.py | 3 +- rllib/agents/ppo/tests/test_ppo.py | 3 +- rllib/agents/trainer.py | 14 ++++- .../tests/test_trajectory_view_api.py | 3 +- rllib/examples/cartpole_lstm.py | 6 ++- rllib/models/catalog.py | 11 +++- rllib/models/tf/recurrent_net.py | 54 ++++++++++++------- rllib/models/torch/recurrent_net.py | 53 +++++++++++------- rllib/tests/test_multi_agent_env.py | 27 +++------- rllib/tests/test_nested_observation_spaces.py | 18 +++---- rllib/utils/numpy.py | 15 ++++-- rllib/utils/tf_ops.py | 12 +++++ 13 files changed, 140 insertions(+), 81 deletions(-) diff --git a/rllib/BUILD b/rllib/BUILD index 85e1125c6..89c784e25 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1604,7 +1604,7 @@ py_test( tags = ["examples", "examples_C"], size = "large", srcs = ["examples/cartpole_lstm.py"], - args = ["--as-test", "--run=PPO", "--stop-reward=40", "--use-prev-action-reward", "--num-cpus=4"] + args = ["--as-test", "--run=PPO", "--stop-reward=40", "--use-prev-action", "--use-prev-reward", "--num-cpus=4"] ) py_test( diff --git a/rllib/agents/impala/tests/test_impala.py b/rllib/agents/impala/tests/test_impala.py index 128b67c68..a9697c50b 100644 --- a/rllib/agents/impala/tests/test_impala.py +++ b/rllib/agents/impala/tests/test_impala.py @@ -40,7 +40,8 @@ class TestIMPALA(unittest.TestCase): # Test w/ LSTM. print("w/ LSTM") local_cfg["model"]["use_lstm"] = True - local_cfg["model"]["lstm_use_prev_action_reward"] = True + local_cfg["model"]["lstm_use_prev_action"] = True + local_cfg["model"]["lstm_use_prev_reward"] = True local_cfg["num_aggregation_workers"] = 2 trainer = impala.ImpalaTrainer(config=local_cfg, env=env) for i in range(num_iterations): diff --git a/rllib/agents/ppo/tests/test_ppo.py b/rllib/agents/ppo/tests/test_ppo.py index 29d47fa61..50e3b99bc 100644 --- a/rllib/agents/ppo/tests/test_ppo.py +++ b/rllib/agents/ppo/tests/test_ppo.py @@ -62,7 +62,8 @@ class TestPPO(unittest.TestCase): for lstm in [True, False]: print("LSTM={}".format(lstm)) config["model"]["use_lstm"] = lstm - config["model"]["lstm_use_prev_action_reward"] = lstm + config["model"]["lstm_use_prev_action"] = lstm + config["model"]["lstm_use_prev_reward"] = lstm trainer = ppo.PPOTrainer(config=config, env=env) for i in range(num_iterations): trainer.train() diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index a86d1107c..a404e7ecb 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -24,6 +24,7 @@ from ray.rllib.utils import FilterManager, deep_update, merge_dicts from ray.rllib.utils.spaces import space_utils from ray.rllib.utils.framework import try_import_tf, TensorStructType from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI +from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE from ray.rllib.utils.from_config import from_config from ray.rllib.utils.typing import TrainerConfigDict, \ PartialTrainerConfigDict, EnvInfoDict, ResultDict, EnvType, PolicyID @@ -1054,9 +1055,20 @@ class Trainer(Trainable): if type(config["input_evaluation"]) != list: raise ValueError( - "`input_evaluation` must be a list of strings, got {}".format( + "`input_evaluation` must be a list of strings, got {}!".format( config["input_evaluation"])) + # Check model config. + prev_a_r = config.get("model", {}).get("lstm_use_prev_action_reward", + DEPRECATED_VALUE) + if prev_a_r != DEPRECATED_VALUE: + deprecation_warning( + "model.lstm_use_prev_action_reward", + "model.lstm_use_prev_action and model.lstm_use_prev_reward", + error=False) + config["model"]["lstm_use_prev_action"] = prev_a_r + config["model"]["lstm_use_prev_reward"] = prev_a_r + def _try_recover(self): """Try to identify and remove any unhealthy workers. diff --git a/rllib/evaluation/tests/test_trajectory_view_api.py b/rllib/evaluation/tests/test_trajectory_view_api.py index ae9fd2248..7897b3226 100644 --- a/rllib/evaluation/tests/test_trajectory_view_api.py +++ b/rllib/evaluation/tests/test_trajectory_view_api.py @@ -77,7 +77,8 @@ class TestTrajectoryViewAPI(unittest.TestCase): config["model"] = config["model"].copy() # Activate LSTM + prev-action + rewards. config["model"]["use_lstm"] = True - config["model"]["lstm_use_prev_action_reward"] = True + config["model"]["lstm_use_prev_action"] = True + config["model"]["lstm_use_prev_reward"] = True for _ in framework_iterator(config): trainer = ppo.PPOTrainer(config, env="CartPole-v0") diff --git a/rllib/examples/cartpole_lstm.py b/rllib/examples/cartpole_lstm.py index 2df09d73f..1c9edc655 100644 --- a/rllib/examples/cartpole_lstm.py +++ b/rllib/examples/cartpole_lstm.py @@ -10,7 +10,8 @@ parser.add_argument("--num-cpus", type=int, default=0) parser.add_argument( "--framework", choices=["tf2", "tf", "tfe", "torch"], default="tf") parser.add_argument("--as-test", action="store_true") -parser.add_argument("--use-prev-action-reward", action="store_true") +parser.add_argument("--use-prev-action", action="store_true") +parser.add_argument("--use-prev-reward", action="store_true") parser.add_argument("--stop-iters", type=int, default=200) parser.add_argument("--stop-timesteps", type=int, default=100000) parser.add_argument("--stop-reward", type=float, default=150.0) @@ -44,7 +45,8 @@ if __name__ == "__main__": "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")), "model": { "use_lstm": True, - "lstm_use_prev_action_reward": args.use_prev_action_reward, + "lstm_use_prev_action": args.use_prev_action, + "lstm_use_prev_reward": args.use_prev_reward, }, "framework": args.framework, # Run with tracing enabled for tfe/tf2. diff --git a/rllib/models/catalog.py b/rllib/models/catalog.py index 8c241a3fa..43d115f4f 100644 --- a/rllib/models/catalog.py +++ b/rllib/models/catalog.py @@ -18,6 +18,7 @@ from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \ TorchDeterministic, TorchDiagGaussian, \ TorchMultiActionDistribution, TorchMultiCategorical from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI +from ray.rllib.utils.deprecation import DEPRECATED_VALUE from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.spaces.simplex import Simplex @@ -58,8 +59,10 @@ MODEL_DEFAULTS: ModelConfigDict = { "max_seq_len": 20, # Size of the LSTM cell. "lstm_cell_size": 256, - # Whether to feed a_{t-1}, r_{t-1} to LSTM. - "lstm_use_prev_action_reward": False, + # Whether to feed a_{t-1} to LSTM (one-hot encoded if discrete). + "lstm_use_prev_action": False, + # Whether to feed r_{t-1} to LSTM. + "lstm_use_prev_reward": False, # Experimental (only works with `_use_trajectory_view_api`=True): # Whether the LSTM is time-major (TxBx..) or batch-major (BxTx..). "_time_major": False, @@ -87,6 +90,10 @@ MODEL_DEFAULTS: ModelConfigDict = { # Custom preprocessors are deprecated. Please use a wrapper class around # your environment instead to preprocess observations. "custom_preprocessor": None, + + # Deprecated keys: + # Use `lstm_use_prev_action` or `lstm_use_prev_reward` instead. + "lstm_use_prev_action_reward": DEPRECATED_VALUE, } # __sphinx_doc_end__ # yapf: enable diff --git a/rllib/models/tf/recurrent_net.py b/rllib/models/tf/recurrent_net.py index d931c7ac6..f939c7ae3 100644 --- a/rllib/models/tf/recurrent_net.py +++ b/rllib/models/tf/recurrent_net.py @@ -1,5 +1,6 @@ import numpy as np import gym +from gym.spaces import Discrete, MultiDiscrete from typing import Dict, List from ray.rllib.models.modelv2 import ModelV2 @@ -9,6 +10,7 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.view_requirement import ViewRequirement from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.tf_ops import one_hot from ray.rllib.utils.typing import ModelConfigDict, TensorType tf1, tf, tfv = try_import_tf() @@ -119,15 +121,23 @@ class LSTMWrapper(RecurrentNetwork): model_config, name) self.cell_size = model_config["lstm_cell_size"] - self.use_prev_action_reward = model_config[ - "lstm_use_prev_action_reward"] - if action_space.shape is not None: + self.use_prev_action = model_config["lstm_use_prev_action"] + self.use_prev_reward = model_config["lstm_use_prev_reward"] + + if isinstance(action_space, Discrete): + self.action_dim = action_space.n + elif isinstance(action_space, MultiDiscrete): + self.action_dim = np.product(action_space.nvec) + elif action_space.shape is not None: self.action_dim = int(np.product(action_space.shape)) else: self.action_dim = int(len(action_space)) + # Add prev-action/reward nodes to input to LSTM. - if self.use_prev_action_reward: - self.num_outputs += 1 + self.action_dim + if self.use_prev_action: + self.num_outputs += self.action_dim + if self.use_prev_reward: + self.num_outputs += 1 # Define input layers. input_layer = tf.keras.layers.Input( @@ -165,12 +175,13 @@ class LSTMWrapper(RecurrentNetwork): self._rnn_model.summary() # Add prev-a/r to this model's view, if required. - if model_config["lstm_use_prev_action_reward"]: - self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \ - ViewRequirement(SampleBatch.REWARDS, shift=-1) + if model_config["lstm_use_prev_action"]: self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \ ViewRequirement(SampleBatch.ACTIONS, space=self.action_space, shift=-1) + if model_config["lstm_use_prev_reward"]: + self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \ + ViewRequirement(SampleBatch.REWARDS, shift=-1) @override(RecurrentNetwork) def forward(self, input_dict: Dict[str, TensorType], @@ -181,18 +192,21 @@ class LSTMWrapper(RecurrentNetwork): wrapped_out, _ = self._wrapped_forward(input_dict, [], None) # Concat. prev-action/reward if required. - if self.model_config["lstm_use_prev_action_reward"]: - wrapped_out = tf.concat( - [ - wrapped_out, - tf.reshape( - tf.cast(input_dict[SampleBatch.PREV_ACTIONS], - tf.float32), [-1, self.action_dim]), - tf.reshape( - tf.cast(input_dict[SampleBatch.PREV_REWARDS], - tf.float32), [-1, 1]), - ], - axis=1) + prev_a_r = [] + if self.model_config["lstm_use_prev_action"]: + prev_a = input_dict[SampleBatch.PREV_ACTIONS] + if isinstance(self.action_space, (Discrete, MultiDiscrete)): + prev_a = one_hot(prev_a, self.action_space) + prev_a_r.append( + tf.reshape(tf.cast(prev_a, tf.float32), [-1, self.action_dim])) + if self.model_config["lstm_use_prev_reward"]: + prev_a_r.append( + tf.reshape( + tf.cast(input_dict[SampleBatch.PREV_REWARDS], tf.float32), + [-1, 1])) + + if prev_a_r: + wrapped_out = tf.concat([wrapped_out] + prev_a_r, axis=1) # Then through our LSTM. input_dict["obs_flat"] = wrapped_out diff --git a/rllib/models/torch/recurrent_net.py b/rllib/models/torch/recurrent_net.py index 6dcead835..d558bf3db 100644 --- a/rllib/models/torch/recurrent_net.py +++ b/rllib/models/torch/recurrent_net.py @@ -1,5 +1,6 @@ import numpy as np import gym +from gym.spaces import Discrete, MultiDiscrete from typing import Dict, List, Union from ray.rllib.models.modelv2 import ModelV2 @@ -10,6 +11,7 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.view_requirement import ViewRequirement from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.torch_ops import one_hot from ray.rllib.utils.typing import ModelConfigDict, TensorType torch, nn = try_import_torch() @@ -118,15 +120,24 @@ class LSTMWrapper(RecurrentNetwork, nn.Module): self.cell_size = model_config["lstm_cell_size"] self.time_major = model_config.get("_time_major", False) - self.use_prev_action_reward = model_config[ - "lstm_use_prev_action_reward"] - if action_space.shape is not None: + self.use_prev_action = model_config["lstm_use_prev_action"] + self.use_prev_reward = model_config["lstm_use_prev_reward"] + + if isinstance(action_space, Discrete): + self.action_dim = action_space.n + elif isinstance(action_space, MultiDiscrete): + self.action_dim = np.product(action_space.nvec) + elif action_space.shape is not None: self.action_dim = int(np.product(action_space.shape)) else: self.action_dim = int(len(action_space)) + # Add prev-action/reward nodes to input to LSTM. - if self.use_prev_action_reward: - self.num_outputs += 1 + self.action_dim + if self.use_prev_action: + self.num_outputs += self.action_dim + if self.use_prev_reward: + self.num_outputs += 1 + self.lstm = nn.LSTM( self.num_outputs, self.cell_size, batch_first=not self.time_major) @@ -145,12 +156,13 @@ class LSTMWrapper(RecurrentNetwork, nn.Module): initializer=torch.nn.init.xavier_uniform_) # Add prev-a/r to this model's view, if required. - if model_config["lstm_use_prev_action_reward"]: - self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \ - ViewRequirement(SampleBatch.REWARDS, shift=-1) + if model_config["lstm_use_prev_action"]: self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \ ViewRequirement(SampleBatch.ACTIONS, space=self.action_space, shift=-1) + if model_config["lstm_use_prev_reward"]: + self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \ + ViewRequirement(SampleBatch.REWARDS, shift=-1) @override(RecurrentNetwork) def forward(self, input_dict: Dict[str, TensorType], @@ -161,16 +173,21 @@ class LSTMWrapper(RecurrentNetwork, nn.Module): wrapped_out, _ = self._wrapped_forward(input_dict, [], None) # Concat. prev-action/reward if required. - if self.model_config["lstm_use_prev_action_reward"]: - wrapped_out = torch.cat( - [ - wrapped_out, - torch.reshape(input_dict[SampleBatch.PREV_ACTIONS].float(), - [-1, self.action_dim]), - torch.reshape(input_dict[SampleBatch.PREV_REWARDS].float(), - [-1, 1]), - ], - dim=1) + prev_a_r = [] + if self.model_config["lstm_use_prev_action"]: + if isinstance(self.action_space, (Discrete, MultiDiscrete)): + prev_a = one_hot(input_dict[SampleBatch.PREV_ACTIONS].float(), + self.action_space) + else: + prev_a = input_dict[SampleBatch.PREV_ACTIONS].float() + prev_a_r.append(torch.reshape(prev_a, [-1, self.action_dim])) + if self.model_config["lstm_use_prev_reward"]: + prev_a_r.append( + torch.reshape(input_dict[SampleBatch.PREV_REWARDS].float(), + [-1, 1])) + + if prev_a_r: + wrapped_out = torch.cat([wrapped_out] + prev_a_r, dim=1) # Then through our LSTM. input_dict["obs_flat"] = wrapped_out diff --git a/rllib/tests/test_multi_agent_env.py b/rllib/tests/test_multi_agent_env.py index 6fe112909..a8c5bf3be 100644 --- a/rllib/tests/test_multi_agent_env.py +++ b/rllib/tests/test_multi_agent_env.py @@ -1,4 +1,5 @@ import gym +import numpy as np import random import unittest @@ -12,12 +13,8 @@ from ray.rllib.examples.env.multi_agent import MultiAgentCartPole, \ from ray.rllib.tests.test_rollout_worker import MockPolicy from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.env.base_env import _MultiAgentEnvToBaseEnv - - -def one_hot(i, n): - out = [0.0] * n - out[i] = 1.0 - return out +from ray.rllib.utils.numpy import one_hot +from ray.rllib.utils.test_utils import check class TestMultiAgentEnv(unittest.TestCase): @@ -270,20 +267,10 @@ class TestMultiAgentEnv(unittest.TestCase): # since we round robin introduce agents into the env, some of the env # steps don't count as proper transitions self.assertEqual(batch.policy_batches["p0"].count, 42) - self.assertEqual(batch.policy_batches["p0"]["obs"].tolist()[:10], [ - one_hot(0, 10), - one_hot(1, 10), - one_hot(2, 10), - one_hot(3, 10), - one_hot(4, 10), - ] * 2) - self.assertEqual(batch.policy_batches["p0"]["new_obs"].tolist()[:10], [ - one_hot(1, 10), - one_hot(2, 10), - one_hot(3, 10), - one_hot(4, 10), - one_hot(5, 10), - ] * 2) + check(batch.policy_batches["p0"]["obs"][:10], + one_hot(np.array([0, 1, 2, 3, 4] * 2), 10)) + check(batch.policy_batches["p0"]["new_obs"][:10], + one_hot(np.array([1, 2, 3, 4, 5] * 2), 10)) self.assertEqual(batch.policy_batches["p0"]["rewards"].tolist()[:10], [100, 100, 100, 100, 0] * 2) self.assertEqual(batch.policy_batches["p0"]["dones"].tolist()[:10], diff --git a/rllib/tests/test_nested_observation_spaces.py b/rllib/tests/test_nested_observation_spaces.py index 40849efd3..74814fd17 100644 --- a/rllib/tests/test_nested_observation_spaces.py +++ b/rllib/tests/test_nested_observation_spaces.py @@ -20,7 +20,9 @@ from ray.rllib.rollout import rollout from ray.rllib.tests.test_external_env import SimpleServing from ray.tune.registry import register_env from ray.rllib.utils.framework import try_import_tf, try_import_torch +from ray.rllib.utils.numpy import one_hot from ray.rllib.utils.spaces.repeated import Repeated +from ray.rllib.utils.test_utils import check tf1, tf, tfv = try_import_tf() _, nn = try_import_torch() @@ -69,12 +71,6 @@ REPEATED_SPACE = Repeated(PLAYER_SPACE, max_len=MAX_PLAYERS) REPEATED_SAMPLES = [REPEATED_SPACE.sample() for _ in range(10)] -def one_hot(i, n): - out = [0.0] * n - out[i] = 1.0 - return out - - class NestedDictEnv(gym.Env): def __init__(self): self.action_space = spaces.Discrete(2) @@ -354,7 +350,7 @@ class NestedSpacesTest(unittest.TestCase): DICT_SAMPLES[i]["inner_state"]["job_status"]["task"], 5) self.assertEqual(seen[0][0].tolist(), pos_i) self.assertEqual(seen[1][0].tolist(), cam_i) - self.assertEqual(seen[2][0].tolist(), task_i) + check(seen[2][0], task_i) def do_test_nested_tuple(self, make_env): ModelCatalog.register_custom_model("composite2", TupleSpyModel) @@ -385,7 +381,7 @@ class NestedSpacesTest(unittest.TestCase): task_i = one_hot(TUPLE_SAMPLES[i][2], 5) self.assertEqual(seen[0][0].tolist(), pos_i) self.assertEqual(seen[1][0].tolist(), cam_i) - self.assertEqual(seen[2][0].tolist(), task_i) + check(seen[2][0], task_i) def test_nested_dict_gym(self): self.do_test_nested_dict(lambda _: NestedDictEnv()) @@ -459,7 +455,7 @@ class NestedSpacesTest(unittest.TestCase): DICT_SAMPLES[i]["inner_state"]["job_status"]["task"], 5) self.assertEqual(seen[0][0].tolist(), pos_i) self.assertEqual(seen[1][0].tolist(), cam_i) - self.assertEqual(seen[2][0].tolist(), task_i) + check(seen[2][0], task_i) for i in range(4): seen = pickle.loads( @@ -470,7 +466,7 @@ class NestedSpacesTest(unittest.TestCase): task_i = one_hot(TUPLE_SAMPLES[i][2], 5) self.assertEqual(seen[0][0].tolist(), pos_i) self.assertEqual(seen[1][0].tolist(), cam_i) - self.assertEqual(seen[2][0].tolist(), task_i) + check(seen[2][0], task_i) def test_rollout_dict_space(self): register_env("nested", lambda _: NestedDictEnv()) @@ -521,7 +517,7 @@ class NestedSpacesTest(unittest.TestCase): # the ray-kv indices before training. self.assertEqual(seen[0][-1].tolist(), pos_i) self.assertEqual(seen[1][-1].tolist(), cam_i) - self.assertEqual(seen[2][-1].tolist(), task_i) + check(seen[2][-1], task_i) # TODO(ekl) should probably also add a test for TF/eager def test_torch_repeated(self): diff --git a/rllib/utils/numpy.py b/rllib/utils/numpy.py index cdb78f982..61fcb2a11 100644 --- a/rllib/utils/numpy.py +++ b/rllib/utils/numpy.py @@ -2,6 +2,7 @@ import numpy as np import tree from ray.rllib.utils.framework import try_import_tf, try_import_torch +from ray.rllib.utils.typing import TensorType, Union tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() @@ -88,14 +89,17 @@ def relu(x, alpha=0.0): return np.maximum(x, x * alpha, x) -def one_hot(x, depth=0, on_value=1, off_value=0): +def one_hot(x: Union[TensorType, int], + depth: int = 0, + on_value: int = 1.0, + off_value: float = 0.0): """ One-hot utility function for numpy. Thanks to qianyizhang: https://gist.github.com/qianyizhang/07ee1c15cad08afb03f5de69349efc30. Args: - x (np.ndarray): The input to be one-hot encoded. + x (TensorType): The input to be one-hot encoded. depth (int): The max. number to be one-hot encoded (size of last rank). on_value (float): The value to use for on. Default: 1.0. off_value (float): The value to use for off. Default: 0.0. @@ -103,8 +107,12 @@ def one_hot(x, depth=0, on_value=1, off_value=0): Returns: np.ndarray: The one-hot encoded equivalent of the input array. """ + + # Handle simple ints properly. + if isinstance(x, int): + x = np.array(x, dtype=np.int32) # Handle torch arrays properly. - if torch and isinstance(x, torch.Tensor): + elif torch and isinstance(x, torch.Tensor): x = x.numpy() # Handle bool arrays correctly. @@ -112,6 +120,7 @@ def one_hot(x, depth=0, on_value=1, off_value=0): x = x.astype(np.int) depth = 2 + # If depth is not given, try to infer it from the values in the array. if depth == 0: depth = np.max(x) + 1 assert np.max(x) < depth, \ diff --git a/rllib/utils/tf_ops.py b/rllib/utils/tf_ops.py index b93a25e82..39d2ce003 100644 --- a/rllib/utils/tf_ops.py +++ b/rllib/utils/tf_ops.py @@ -1,4 +1,5 @@ import gym +from gym.spaces import Discrete, MultiDiscrete import numpy as np import tree @@ -65,6 +66,17 @@ def huber_loss(x, delta=1.0): tf.math.square(x) * 0.5, delta * (tf.abs(x) - 0.5 * delta)) +def one_hot(x, space): + if isinstance(space, Discrete): + return tf.one_hot(x, space.n) + elif isinstance(space, MultiDiscrete): + return tf.concat( + [tf.one_hot(x[:, i], n) for i, n in enumerate(space.nvec)], + axis=-1) + else: + raise ValueError("Unsupported space for `one_hot`: {}".format(space)) + + def reduce_mean_ignore_inf(x, axis): """Same as tf.reduce_mean() but ignores -inf values.""" mask = tf.not_equal(x, tf.float32.min)