mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 08:27:21 +08:00
[RLlib] Issue 12118: LSTM prev-a/r should be separately configurable. Fix missing prev-a one-hot encoding. (#12397)
* WIP. * Fix and LINT.
This commit is contained in:
+1
-1
@@ -1604,7 +1604,7 @@ py_test(
|
||||
tags = ["examples", "examples_C"],
|
||||
size = "large",
|
||||
srcs = ["examples/cartpole_lstm.py"],
|
||||
args = ["--as-test", "--run=PPO", "--stop-reward=40", "--use-prev-action-reward", "--num-cpus=4"]
|
||||
args = ["--as-test", "--run=PPO", "--stop-reward=40", "--use-prev-action", "--use-prev-reward", "--num-cpus=4"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
|
||||
@@ -40,7 +40,8 @@ class TestIMPALA(unittest.TestCase):
|
||||
# Test w/ LSTM.
|
||||
print("w/ LSTM")
|
||||
local_cfg["model"]["use_lstm"] = True
|
||||
local_cfg["model"]["lstm_use_prev_action_reward"] = True
|
||||
local_cfg["model"]["lstm_use_prev_action"] = True
|
||||
local_cfg["model"]["lstm_use_prev_reward"] = True
|
||||
local_cfg["num_aggregation_workers"] = 2
|
||||
trainer = impala.ImpalaTrainer(config=local_cfg, env=env)
|
||||
for i in range(num_iterations):
|
||||
|
||||
@@ -62,7 +62,8 @@ class TestPPO(unittest.TestCase):
|
||||
for lstm in [True, False]:
|
||||
print("LSTM={}".format(lstm))
|
||||
config["model"]["use_lstm"] = lstm
|
||||
config["model"]["lstm_use_prev_action_reward"] = lstm
|
||||
config["model"]["lstm_use_prev_action"] = lstm
|
||||
config["model"]["lstm_use_prev_reward"] = lstm
|
||||
trainer = ppo.PPOTrainer(config=config, env=env)
|
||||
for i in range(num_iterations):
|
||||
trainer.train()
|
||||
|
||||
+13
-1
@@ -24,6 +24,7 @@ from ray.rllib.utils import FilterManager, deep_update, merge_dicts
|
||||
from ray.rllib.utils.spaces import space_utils
|
||||
from ray.rllib.utils.framework import try_import_tf, TensorStructType
|
||||
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
|
||||
from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
|
||||
from ray.rllib.utils.from_config import from_config
|
||||
from ray.rllib.utils.typing import TrainerConfigDict, \
|
||||
PartialTrainerConfigDict, EnvInfoDict, ResultDict, EnvType, PolicyID
|
||||
@@ -1054,9 +1055,20 @@ class Trainer(Trainable):
|
||||
|
||||
if type(config["input_evaluation"]) != list:
|
||||
raise ValueError(
|
||||
"`input_evaluation` must be a list of strings, got {}".format(
|
||||
"`input_evaluation` must be a list of strings, got {}!".format(
|
||||
config["input_evaluation"]))
|
||||
|
||||
# Check model config.
|
||||
prev_a_r = config.get("model", {}).get("lstm_use_prev_action_reward",
|
||||
DEPRECATED_VALUE)
|
||||
if prev_a_r != DEPRECATED_VALUE:
|
||||
deprecation_warning(
|
||||
"model.lstm_use_prev_action_reward",
|
||||
"model.lstm_use_prev_action and model.lstm_use_prev_reward",
|
||||
error=False)
|
||||
config["model"]["lstm_use_prev_action"] = prev_a_r
|
||||
config["model"]["lstm_use_prev_reward"] = prev_a_r
|
||||
|
||||
def _try_recover(self):
|
||||
"""Try to identify and remove any unhealthy workers.
|
||||
|
||||
|
||||
@@ -77,7 +77,8 @@ class TestTrajectoryViewAPI(unittest.TestCase):
|
||||
config["model"] = config["model"].copy()
|
||||
# Activate LSTM + prev-action + rewards.
|
||||
config["model"]["use_lstm"] = True
|
||||
config["model"]["lstm_use_prev_action_reward"] = True
|
||||
config["model"]["lstm_use_prev_action"] = True
|
||||
config["model"]["lstm_use_prev_reward"] = True
|
||||
|
||||
for _ in framework_iterator(config):
|
||||
trainer = ppo.PPOTrainer(config, env="CartPole-v0")
|
||||
|
||||
@@ -10,7 +10,8 @@ parser.add_argument("--num-cpus", type=int, default=0)
|
||||
parser.add_argument(
|
||||
"--framework", choices=["tf2", "tf", "tfe", "torch"], default="tf")
|
||||
parser.add_argument("--as-test", action="store_true")
|
||||
parser.add_argument("--use-prev-action-reward", action="store_true")
|
||||
parser.add_argument("--use-prev-action", action="store_true")
|
||||
parser.add_argument("--use-prev-reward", action="store_true")
|
||||
parser.add_argument("--stop-iters", type=int, default=200)
|
||||
parser.add_argument("--stop-timesteps", type=int, default=100000)
|
||||
parser.add_argument("--stop-reward", type=float, default=150.0)
|
||||
@@ -44,7 +45,8 @@ if __name__ == "__main__":
|
||||
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
|
||||
"model": {
|
||||
"use_lstm": True,
|
||||
"lstm_use_prev_action_reward": args.use_prev_action_reward,
|
||||
"lstm_use_prev_action": args.use_prev_action,
|
||||
"lstm_use_prev_reward": args.use_prev_reward,
|
||||
},
|
||||
"framework": args.framework,
|
||||
# Run with tracing enabled for tfe/tf2.
|
||||
|
||||
@@ -18,6 +18,7 @@ from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \
|
||||
TorchDeterministic, TorchDiagGaussian, \
|
||||
TorchMultiActionDistribution, TorchMultiCategorical
|
||||
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
|
||||
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.spaces.simplex import Simplex
|
||||
@@ -58,8 +59,10 @@ MODEL_DEFAULTS: ModelConfigDict = {
|
||||
"max_seq_len": 20,
|
||||
# Size of the LSTM cell.
|
||||
"lstm_cell_size": 256,
|
||||
# Whether to feed a_{t-1}, r_{t-1} to LSTM.
|
||||
"lstm_use_prev_action_reward": False,
|
||||
# Whether to feed a_{t-1} to LSTM (one-hot encoded if discrete).
|
||||
"lstm_use_prev_action": False,
|
||||
# Whether to feed r_{t-1} to LSTM.
|
||||
"lstm_use_prev_reward": False,
|
||||
# Experimental (only works with `_use_trajectory_view_api`=True):
|
||||
# Whether the LSTM is time-major (TxBx..) or batch-major (BxTx..).
|
||||
"_time_major": False,
|
||||
@@ -87,6 +90,10 @@ MODEL_DEFAULTS: ModelConfigDict = {
|
||||
# Custom preprocessors are deprecated. Please use a wrapper class around
|
||||
# your environment instead to preprocess observations.
|
||||
"custom_preprocessor": None,
|
||||
|
||||
# Deprecated keys:
|
||||
# Use `lstm_use_prev_action` or `lstm_use_prev_reward` instead.
|
||||
"lstm_use_prev_action_reward": DEPRECATED_VALUE,
|
||||
}
|
||||
# __sphinx_doc_end__
|
||||
# yapf: enable
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import numpy as np
|
||||
import gym
|
||||
from gym.spaces import Discrete, MultiDiscrete
|
||||
from typing import Dict, List
|
||||
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
@@ -9,6 +10,7 @@ from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.view_requirement import ViewRequirement
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.tf_ops import one_hot
|
||||
from ray.rllib.utils.typing import ModelConfigDict, TensorType
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
@@ -119,15 +121,23 @@ class LSTMWrapper(RecurrentNetwork):
|
||||
model_config, name)
|
||||
|
||||
self.cell_size = model_config["lstm_cell_size"]
|
||||
self.use_prev_action_reward = model_config[
|
||||
"lstm_use_prev_action_reward"]
|
||||
if action_space.shape is not None:
|
||||
self.use_prev_action = model_config["lstm_use_prev_action"]
|
||||
self.use_prev_reward = model_config["lstm_use_prev_reward"]
|
||||
|
||||
if isinstance(action_space, Discrete):
|
||||
self.action_dim = action_space.n
|
||||
elif isinstance(action_space, MultiDiscrete):
|
||||
self.action_dim = np.product(action_space.nvec)
|
||||
elif action_space.shape is not None:
|
||||
self.action_dim = int(np.product(action_space.shape))
|
||||
else:
|
||||
self.action_dim = int(len(action_space))
|
||||
|
||||
# Add prev-action/reward nodes to input to LSTM.
|
||||
if self.use_prev_action_reward:
|
||||
self.num_outputs += 1 + self.action_dim
|
||||
if self.use_prev_action:
|
||||
self.num_outputs += self.action_dim
|
||||
if self.use_prev_reward:
|
||||
self.num_outputs += 1
|
||||
|
||||
# Define input layers.
|
||||
input_layer = tf.keras.layers.Input(
|
||||
@@ -165,12 +175,13 @@ class LSTMWrapper(RecurrentNetwork):
|
||||
self._rnn_model.summary()
|
||||
|
||||
# Add prev-a/r to this model's view, if required.
|
||||
if model_config["lstm_use_prev_action_reward"]:
|
||||
self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \
|
||||
ViewRequirement(SampleBatch.REWARDS, shift=-1)
|
||||
if model_config["lstm_use_prev_action"]:
|
||||
self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \
|
||||
ViewRequirement(SampleBatch.ACTIONS, space=self.action_space,
|
||||
shift=-1)
|
||||
if model_config["lstm_use_prev_reward"]:
|
||||
self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \
|
||||
ViewRequirement(SampleBatch.REWARDS, shift=-1)
|
||||
|
||||
@override(RecurrentNetwork)
|
||||
def forward(self, input_dict: Dict[str, TensorType],
|
||||
@@ -181,18 +192,21 @@ class LSTMWrapper(RecurrentNetwork):
|
||||
wrapped_out, _ = self._wrapped_forward(input_dict, [], None)
|
||||
|
||||
# Concat. prev-action/reward if required.
|
||||
if self.model_config["lstm_use_prev_action_reward"]:
|
||||
wrapped_out = tf.concat(
|
||||
[
|
||||
wrapped_out,
|
||||
tf.reshape(
|
||||
tf.cast(input_dict[SampleBatch.PREV_ACTIONS],
|
||||
tf.float32), [-1, self.action_dim]),
|
||||
tf.reshape(
|
||||
tf.cast(input_dict[SampleBatch.PREV_REWARDS],
|
||||
tf.float32), [-1, 1]),
|
||||
],
|
||||
axis=1)
|
||||
prev_a_r = []
|
||||
if self.model_config["lstm_use_prev_action"]:
|
||||
prev_a = input_dict[SampleBatch.PREV_ACTIONS]
|
||||
if isinstance(self.action_space, (Discrete, MultiDiscrete)):
|
||||
prev_a = one_hot(prev_a, self.action_space)
|
||||
prev_a_r.append(
|
||||
tf.reshape(tf.cast(prev_a, tf.float32), [-1, self.action_dim]))
|
||||
if self.model_config["lstm_use_prev_reward"]:
|
||||
prev_a_r.append(
|
||||
tf.reshape(
|
||||
tf.cast(input_dict[SampleBatch.PREV_REWARDS], tf.float32),
|
||||
[-1, 1]))
|
||||
|
||||
if prev_a_r:
|
||||
wrapped_out = tf.concat([wrapped_out] + prev_a_r, axis=1)
|
||||
|
||||
# Then through our LSTM.
|
||||
input_dict["obs_flat"] = wrapped_out
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import numpy as np
|
||||
import gym
|
||||
from gym.spaces import Discrete, MultiDiscrete
|
||||
from typing import Dict, List, Union
|
||||
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
@@ -10,6 +11,7 @@ from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.view_requirement import ViewRequirement
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.torch_ops import one_hot
|
||||
from ray.rllib.utils.typing import ModelConfigDict, TensorType
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
@@ -118,15 +120,24 @@ class LSTMWrapper(RecurrentNetwork, nn.Module):
|
||||
|
||||
self.cell_size = model_config["lstm_cell_size"]
|
||||
self.time_major = model_config.get("_time_major", False)
|
||||
self.use_prev_action_reward = model_config[
|
||||
"lstm_use_prev_action_reward"]
|
||||
if action_space.shape is not None:
|
||||
self.use_prev_action = model_config["lstm_use_prev_action"]
|
||||
self.use_prev_reward = model_config["lstm_use_prev_reward"]
|
||||
|
||||
if isinstance(action_space, Discrete):
|
||||
self.action_dim = action_space.n
|
||||
elif isinstance(action_space, MultiDiscrete):
|
||||
self.action_dim = np.product(action_space.nvec)
|
||||
elif action_space.shape is not None:
|
||||
self.action_dim = int(np.product(action_space.shape))
|
||||
else:
|
||||
self.action_dim = int(len(action_space))
|
||||
|
||||
# Add prev-action/reward nodes to input to LSTM.
|
||||
if self.use_prev_action_reward:
|
||||
self.num_outputs += 1 + self.action_dim
|
||||
if self.use_prev_action:
|
||||
self.num_outputs += self.action_dim
|
||||
if self.use_prev_reward:
|
||||
self.num_outputs += 1
|
||||
|
||||
self.lstm = nn.LSTM(
|
||||
self.num_outputs, self.cell_size, batch_first=not self.time_major)
|
||||
|
||||
@@ -145,12 +156,13 @@ class LSTMWrapper(RecurrentNetwork, nn.Module):
|
||||
initializer=torch.nn.init.xavier_uniform_)
|
||||
|
||||
# Add prev-a/r to this model's view, if required.
|
||||
if model_config["lstm_use_prev_action_reward"]:
|
||||
self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \
|
||||
ViewRequirement(SampleBatch.REWARDS, shift=-1)
|
||||
if model_config["lstm_use_prev_action"]:
|
||||
self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \
|
||||
ViewRequirement(SampleBatch.ACTIONS, space=self.action_space,
|
||||
shift=-1)
|
||||
if model_config["lstm_use_prev_reward"]:
|
||||
self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \
|
||||
ViewRequirement(SampleBatch.REWARDS, shift=-1)
|
||||
|
||||
@override(RecurrentNetwork)
|
||||
def forward(self, input_dict: Dict[str, TensorType],
|
||||
@@ -161,16 +173,21 @@ class LSTMWrapper(RecurrentNetwork, nn.Module):
|
||||
wrapped_out, _ = self._wrapped_forward(input_dict, [], None)
|
||||
|
||||
# Concat. prev-action/reward if required.
|
||||
if self.model_config["lstm_use_prev_action_reward"]:
|
||||
wrapped_out = torch.cat(
|
||||
[
|
||||
wrapped_out,
|
||||
torch.reshape(input_dict[SampleBatch.PREV_ACTIONS].float(),
|
||||
[-1, self.action_dim]),
|
||||
torch.reshape(input_dict[SampleBatch.PREV_REWARDS].float(),
|
||||
[-1, 1]),
|
||||
],
|
||||
dim=1)
|
||||
prev_a_r = []
|
||||
if self.model_config["lstm_use_prev_action"]:
|
||||
if isinstance(self.action_space, (Discrete, MultiDiscrete)):
|
||||
prev_a = one_hot(input_dict[SampleBatch.PREV_ACTIONS].float(),
|
||||
self.action_space)
|
||||
else:
|
||||
prev_a = input_dict[SampleBatch.PREV_ACTIONS].float()
|
||||
prev_a_r.append(torch.reshape(prev_a, [-1, self.action_dim]))
|
||||
if self.model_config["lstm_use_prev_reward"]:
|
||||
prev_a_r.append(
|
||||
torch.reshape(input_dict[SampleBatch.PREV_REWARDS].float(),
|
||||
[-1, 1]))
|
||||
|
||||
if prev_a_r:
|
||||
wrapped_out = torch.cat([wrapped_out] + prev_a_r, dim=1)
|
||||
|
||||
# Then through our LSTM.
|
||||
input_dict["obs_flat"] = wrapped_out
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import gym
|
||||
import numpy as np
|
||||
import random
|
||||
import unittest
|
||||
|
||||
@@ -12,12 +13,8 @@ from ray.rllib.examples.env.multi_agent import MultiAgentCartPole, \
|
||||
from ray.rllib.tests.test_rollout_worker import MockPolicy
|
||||
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
||||
from ray.rllib.env.base_env import _MultiAgentEnvToBaseEnv
|
||||
|
||||
|
||||
def one_hot(i, n):
|
||||
out = [0.0] * n
|
||||
out[i] = 1.0
|
||||
return out
|
||||
from ray.rllib.utils.numpy import one_hot
|
||||
from ray.rllib.utils.test_utils import check
|
||||
|
||||
|
||||
class TestMultiAgentEnv(unittest.TestCase):
|
||||
@@ -270,20 +267,10 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
# since we round robin introduce agents into the env, some of the env
|
||||
# steps don't count as proper transitions
|
||||
self.assertEqual(batch.policy_batches["p0"].count, 42)
|
||||
self.assertEqual(batch.policy_batches["p0"]["obs"].tolist()[:10], [
|
||||
one_hot(0, 10),
|
||||
one_hot(1, 10),
|
||||
one_hot(2, 10),
|
||||
one_hot(3, 10),
|
||||
one_hot(4, 10),
|
||||
] * 2)
|
||||
self.assertEqual(batch.policy_batches["p0"]["new_obs"].tolist()[:10], [
|
||||
one_hot(1, 10),
|
||||
one_hot(2, 10),
|
||||
one_hot(3, 10),
|
||||
one_hot(4, 10),
|
||||
one_hot(5, 10),
|
||||
] * 2)
|
||||
check(batch.policy_batches["p0"]["obs"][:10],
|
||||
one_hot(np.array([0, 1, 2, 3, 4] * 2), 10))
|
||||
check(batch.policy_batches["p0"]["new_obs"][:10],
|
||||
one_hot(np.array([1, 2, 3, 4, 5] * 2), 10))
|
||||
self.assertEqual(batch.policy_batches["p0"]["rewards"].tolist()[:10],
|
||||
[100, 100, 100, 100, 0] * 2)
|
||||
self.assertEqual(batch.policy_batches["p0"]["dones"].tolist()[:10],
|
||||
|
||||
@@ -20,7 +20,9 @@ from ray.rllib.rollout import rollout
|
||||
from ray.rllib.tests.test_external_env import SimpleServing
|
||||
from ray.tune.registry import register_env
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.numpy import one_hot
|
||||
from ray.rllib.utils.spaces.repeated import Repeated
|
||||
from ray.rllib.utils.test_utils import check
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
_, nn = try_import_torch()
|
||||
@@ -69,12 +71,6 @@ REPEATED_SPACE = Repeated(PLAYER_SPACE, max_len=MAX_PLAYERS)
|
||||
REPEATED_SAMPLES = [REPEATED_SPACE.sample() for _ in range(10)]
|
||||
|
||||
|
||||
def one_hot(i, n):
|
||||
out = [0.0] * n
|
||||
out[i] = 1.0
|
||||
return out
|
||||
|
||||
|
||||
class NestedDictEnv(gym.Env):
|
||||
def __init__(self):
|
||||
self.action_space = spaces.Discrete(2)
|
||||
@@ -354,7 +350,7 @@ class NestedSpacesTest(unittest.TestCase):
|
||||
DICT_SAMPLES[i]["inner_state"]["job_status"]["task"], 5)
|
||||
self.assertEqual(seen[0][0].tolist(), pos_i)
|
||||
self.assertEqual(seen[1][0].tolist(), cam_i)
|
||||
self.assertEqual(seen[2][0].tolist(), task_i)
|
||||
check(seen[2][0], task_i)
|
||||
|
||||
def do_test_nested_tuple(self, make_env):
|
||||
ModelCatalog.register_custom_model("composite2", TupleSpyModel)
|
||||
@@ -385,7 +381,7 @@ class NestedSpacesTest(unittest.TestCase):
|
||||
task_i = one_hot(TUPLE_SAMPLES[i][2], 5)
|
||||
self.assertEqual(seen[0][0].tolist(), pos_i)
|
||||
self.assertEqual(seen[1][0].tolist(), cam_i)
|
||||
self.assertEqual(seen[2][0].tolist(), task_i)
|
||||
check(seen[2][0], task_i)
|
||||
|
||||
def test_nested_dict_gym(self):
|
||||
self.do_test_nested_dict(lambda _: NestedDictEnv())
|
||||
@@ -459,7 +455,7 @@ class NestedSpacesTest(unittest.TestCase):
|
||||
DICT_SAMPLES[i]["inner_state"]["job_status"]["task"], 5)
|
||||
self.assertEqual(seen[0][0].tolist(), pos_i)
|
||||
self.assertEqual(seen[1][0].tolist(), cam_i)
|
||||
self.assertEqual(seen[2][0].tolist(), task_i)
|
||||
check(seen[2][0], task_i)
|
||||
|
||||
for i in range(4):
|
||||
seen = pickle.loads(
|
||||
@@ -470,7 +466,7 @@ class NestedSpacesTest(unittest.TestCase):
|
||||
task_i = one_hot(TUPLE_SAMPLES[i][2], 5)
|
||||
self.assertEqual(seen[0][0].tolist(), pos_i)
|
||||
self.assertEqual(seen[1][0].tolist(), cam_i)
|
||||
self.assertEqual(seen[2][0].tolist(), task_i)
|
||||
check(seen[2][0], task_i)
|
||||
|
||||
def test_rollout_dict_space(self):
|
||||
register_env("nested", lambda _: NestedDictEnv())
|
||||
@@ -521,7 +517,7 @@ class NestedSpacesTest(unittest.TestCase):
|
||||
# the ray-kv indices before training.
|
||||
self.assertEqual(seen[0][-1].tolist(), pos_i)
|
||||
self.assertEqual(seen[1][-1].tolist(), cam_i)
|
||||
self.assertEqual(seen[2][-1].tolist(), task_i)
|
||||
check(seen[2][-1], task_i)
|
||||
|
||||
# TODO(ekl) should probably also add a test for TF/eager
|
||||
def test_torch_repeated(self):
|
||||
|
||||
+12
-3
@@ -2,6 +2,7 @@ import numpy as np
|
||||
import tree
|
||||
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.typing import TensorType, Union
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
torch, _ = try_import_torch()
|
||||
@@ -88,14 +89,17 @@ def relu(x, alpha=0.0):
|
||||
return np.maximum(x, x * alpha, x)
|
||||
|
||||
|
||||
def one_hot(x, depth=0, on_value=1, off_value=0):
|
||||
def one_hot(x: Union[TensorType, int],
|
||||
depth: int = 0,
|
||||
on_value: int = 1.0,
|
||||
off_value: float = 0.0):
|
||||
"""
|
||||
One-hot utility function for numpy.
|
||||
Thanks to qianyizhang:
|
||||
https://gist.github.com/qianyizhang/07ee1c15cad08afb03f5de69349efc30.
|
||||
|
||||
Args:
|
||||
x (np.ndarray): The input to be one-hot encoded.
|
||||
x (TensorType): The input to be one-hot encoded.
|
||||
depth (int): The max. number to be one-hot encoded (size of last rank).
|
||||
on_value (float): The value to use for on. Default: 1.0.
|
||||
off_value (float): The value to use for off. Default: 0.0.
|
||||
@@ -103,8 +107,12 @@ def one_hot(x, depth=0, on_value=1, off_value=0):
|
||||
Returns:
|
||||
np.ndarray: The one-hot encoded equivalent of the input array.
|
||||
"""
|
||||
|
||||
# Handle simple ints properly.
|
||||
if isinstance(x, int):
|
||||
x = np.array(x, dtype=np.int32)
|
||||
# Handle torch arrays properly.
|
||||
if torch and isinstance(x, torch.Tensor):
|
||||
elif torch and isinstance(x, torch.Tensor):
|
||||
x = x.numpy()
|
||||
|
||||
# Handle bool arrays correctly.
|
||||
@@ -112,6 +120,7 @@ def one_hot(x, depth=0, on_value=1, off_value=0):
|
||||
x = x.astype(np.int)
|
||||
depth = 2
|
||||
|
||||
# If depth is not given, try to infer it from the values in the array.
|
||||
if depth == 0:
|
||||
depth = np.max(x) + 1
|
||||
assert np.max(x) < depth, \
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import gym
|
||||
from gym.spaces import Discrete, MultiDiscrete
|
||||
import numpy as np
|
||||
import tree
|
||||
|
||||
@@ -65,6 +66,17 @@ def huber_loss(x, delta=1.0):
|
||||
tf.math.square(x) * 0.5, delta * (tf.abs(x) - 0.5 * delta))
|
||||
|
||||
|
||||
def one_hot(x, space):
|
||||
if isinstance(space, Discrete):
|
||||
return tf.one_hot(x, space.n)
|
||||
elif isinstance(space, MultiDiscrete):
|
||||
return tf.concat(
|
||||
[tf.one_hot(x[:, i], n) for i, n in enumerate(space.nvec)],
|
||||
axis=-1)
|
||||
else:
|
||||
raise ValueError("Unsupported space for `one_hot`: {}".format(space))
|
||||
|
||||
|
||||
def reduce_mean_ignore_inf(x, axis):
|
||||
"""Same as tf.reduce_mean() but ignores -inf values."""
|
||||
mask = tf.not_equal(x, tf.float32.min)
|
||||
|
||||
Reference in New Issue
Block a user