diff --git a/.travis.yml b/.travis.yml index bb50822fe..689177906 100644 --- a/.travis.yml +++ b/.travis.yml @@ -149,7 +149,7 @@ matrix: - os: linux env: - RLLIB_TESTING=1 RLLIB_REGRESSION_TESTS=1 - - TF_VERSION=2.0.0b1 + - TF_VERSION=2.1.0 - TFP_VERSION=0.8 - TORCH_VERSION=1.4 - PYTHON=3.6 @@ -182,7 +182,7 @@ matrix: - os: linux env: - RLLIB_TESTING=1 RLLIB_REGRESSION_TESTS_TORCH=1 - - TF_VERSION=2.0.0b1 + - TF_VERSION=2.1.0 - TFP_VERSION=0.8 - TORCH_VERSION=1.4 - PYTHON=3.6 @@ -200,7 +200,7 @@ matrix: env: - RLLIB_TESTING=1 RLLIB_QUICK_TRAIN_AND_MISC_TESTS=1 - PYTHON=3.6 - - TF_VERSION=2.0.0b1 + - TF_VERSION=2.1.0 - TFP_VERSION=0.8 - TORCH_VERSION=1.4 - PYTHONWARNINGS=ignore @@ -220,7 +220,7 @@ matrix: env: - RLLIB_TESTING=1 RLLIB_EXAMPLE_DIR_TESTS=1 - PYTHON=3.6 - - TF_VERSION=2.0.0b1 + - TF_VERSION=2.1.0 - TFP_VERSION=0.8 - TORCH_VERSION=1.4 - PYTHONWARNINGS=ignore @@ -239,7 +239,7 @@ matrix: env: - RLLIB_TESTING=1 RLLIB_TESTS_DIR_TESTS_A_TO_L=1 - PYTHON=3.6 - - TF_VERSION=2.0.0b1 + - TF_VERSION=2.1.0 - TFP_VERSION=0.8 - TORCH_VERSION=1.4 - PYTHONWARNINGS=ignore @@ -255,7 +255,7 @@ matrix: env: - RLLIB_TESTING=1 RLLIB_TESTS_DIR_TESTS_M_TO_Z=1 - PYTHON=3.6 - - TF_VERSION=2.0.0b1 + - TF_VERSION=2.1.0 - TFP_VERSION=0.8 - TORCH_VERSION=1.4 - PYTHONWARNINGS=ignore diff --git a/ci/travis/install-dependencies.sh b/ci/travis/install-dependencies.sh index 5467afbb1..dc820b75d 100755 --- a/ci/travis/install-dependencies.sh +++ b/ci/travis/install-dependencies.sh @@ -209,7 +209,7 @@ install_dependencies() { msys*) pip install torch==1.5.0+cpu torchvision==0.6.0+cpu -f "${torch_url}";; esac - pip_packages=(scipy tensorflow=="${TF_VERSION:-2.0.0b1}" cython==0.29.0 gym \ + pip_packages=(scipy tensorflow=="${TF_VERSION:-2.1.0}" cython==0.29.0 gym \ opencv-python-headless pyyaml pandas==0.24.2 requests feather-format lxml openpyxl xlrd \ py-spy pytest pytest-timeout networkx tabulate aiohttp uvicorn dataclasses pygments werkzeug \ kubernetes flask grpcio pytest-sugar pytest-rerunfailures pytest-asyncio scikit-learn==0.22.2 numba \ diff --git a/doc/source/rllib-models.rst b/doc/source/rllib-models.rst index bd1d9b27a..eacda45ab 100644 --- a/doc/source/rllib-models.rst +++ b/doc/source/rllib-models.rst @@ -80,9 +80,9 @@ For a full example of a custom model in code, see the `keras model example `__ model as an example to implement your own model: +Instead of using the ``use_lstm: True`` option, it can be preferable use a custom recurrent model. This provides more control over postprocessing of the LSTM output and can also allow the use of multiple LSTM cells to process different portions of the input. For an RNN model it is preferred to subclass ``RecurrentNetwork`` to implement ``__init__()``, ``get_initial_state()``, and ``forward_rnn()``. You can check out the `custom_rnn_model.py `__ model as an example to implement your own model: -.. autoclass:: ray.rllib.models.tf.recurrent_tf_modelv2.RecurrentTFModelV2 +.. autoclass:: ray.rllib.models.tf.recurrent_net.RecurrentNetwork .. automethod:: __init__ .. automethod:: forward_rnn diff --git a/rllib/BUILD b/rllib/BUILD index 900eca88b..3dc416c9a 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1452,6 +1452,25 @@ py_test( # -------------------------------------------------------------------- +py_test( + name = "examples/attention_net_tf", + main = "examples/attention_net.py", + tags = ["examples", "examples_A"], + size = "large", + srcs = ["examples/attention_net.py"], + args = ["--as-test", "--stop-reward=80"] +) + +# TODO(sven): GTrXL PyTorch. +# py_test( +# name = "examples/attention_net_torch", +# main = "examples/attention_net.py", +# tags = ["examples", "examples_A"], +# size = "large", +# srcs = ["examples/attention_net.py"], +# args = ["--as-test", "--torch", "--stop-reward=90"] +# ) + py_test( name = "examples/autoregressive_action_dist_tf", main = "examples/autoregressive_action_dist.py", @@ -1492,7 +1511,7 @@ py_test( name = "examples/batch_norm_model_dqn_tf", main = "examples/batch_norm_model.py", tags = ["examples", "examples_B"], - size = "medium", # DQN learns much slower with BatchNorm. + size = "large", # DQN learns much slower with BatchNorm. srcs = ["examples/batch_norm_model.py"], args = ["--as-test", "--run=DQN", "--stop-reward=70"] ) @@ -1501,7 +1520,7 @@ py_test( name = "examples/batch_norm_model_dqn_torch", main = "examples/batch_norm_model.py", tags = ["examples", "examples_B"], - size = "medium", # DQN learns much slower with BatchNorm. + size = "large", # DQN learns much slower with BatchNorm. srcs = ["examples/batch_norm_model.py"], args = ["--as-test", "--torch", "--run=DQN", "--stop-reward=70"] ) @@ -1555,7 +1574,7 @@ py_test( name = "examples/cartpole_lstm_ppo_torch", main = "examples/cartpole_lstm.py", tags = ["examples", "examples_C"], - size = "small", + size = "medium", srcs = ["examples/cartpole_lstm.py"], args = ["--as-test", "--torch", "--run=PPO", "--stop-reward=40", "--num-cpus=4"] ) @@ -1871,7 +1890,7 @@ py_test( name = "examples/multi_agent_two_trainers_mixed_torch_tf", main = "examples/multi_agent_two_trainers.py", tags = ["examples", "examples_M"], - size = "small", + size = "medium", srcs = ["examples/multi_agent_two_trainers.py"], args = ["--as-test", "--mixed-torch-tf", "--stop-reward=70"] ) diff --git a/rllib/agents/sac/README.md b/rllib/agents/sac/README.md index 7b600ec1f..a85329f13 100644 --- a/rllib/agents/sac/README.md +++ b/rllib/agents/sac/README.md @@ -1,6 +1,6 @@ Implementation of the Soft Actor-Critic algorithm: -[1] Soft Actor-Critic Algorithms and Applications - T. Haarnoja, A. Zhou, K. Hartikainen, et. al +[1] Soft Actor-Critic Algorithms and Applications - T. Haarnoja, A. Zhou, K. Hartikainen, et al. https://arxiv.org/abs/1812.05905.pdf For supporting discrete action spaces, we implemented this patch on top of the original algorithm: diff --git a/rllib/examples/attention_net.py b/rllib/examples/attention_net.py new file mode 100644 index 000000000..8360b4b53 --- /dev/null +++ b/rllib/examples/attention_net.py @@ -0,0 +1,75 @@ +import argparse + +import ray +from ray import tune +from ray.rllib.utils import try_import_tf +from ray.rllib.models.tf.attention_net import GTrXLNet +from ray.rllib.examples.env.look_and_push import LookAndPush, OneHot +from ray.rllib.examples.env.repeat_after_me_env import RepeatAfterMeEnv +from ray.rllib.examples.env.repeat_initial_obs_env import RepeatInitialObsEnv +from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole +from ray.rllib.utils.test_utils import check_learning_achieved +from ray.tune import registry + +tf = try_import_tf() + +parser = argparse.ArgumentParser() +parser.add_argument("--run", type=str, default="PPO") +parser.add_argument("--env", type=str, default="RepeatAfterMeEnv") +parser.add_argument("--num-cpus", type=int, default=0) +parser.add_argument("--torch", action="store_true") +parser.add_argument("--as-test", action="store_true") +parser.add_argument("--stop-iters", type=int, default=200) +parser.add_argument("--stop-timesteps", type=int, default=500000) +parser.add_argument("--stop-reward", type=float, default=80) + +if __name__ == "__main__": + args = parser.parse_args() + + assert not args.torch, "PyTorch not supported for AttentionNets yet!" + + ray.init(num_cpus=args.num_cpus or None, local_mode=True) + + registry.register_env("RepeatAfterMeEnv", lambda c: RepeatAfterMeEnv(c)) + registry.register_env("RepeatInitialObsEnv", + lambda _: RepeatInitialObsEnv()) + registry.register_env("LookAndPush", lambda _: OneHot(LookAndPush())) + registry.register_env("StatelessCartPole", lambda _: StatelessCartPole()) + + config = { + "env": args.env, + "env_config": { + "repeat_delay": 2, + }, + "gamma": 0.99, + "num_workers": 0, + "num_envs_per_worker": 20, + "entropy_coeff": 0.001, + "num_sgd_iter": 5, + "vf_loss_coeff": 1e-5, + "model": { + "custom_model": GTrXLNet, + "max_seq_len": 50, + "custom_options": { + "num_transformer_units": 1, + "attn_dim": 64, + "num_heads": 2, + "memory_tau": 50, + "head_dim": 32, + "ff_hidden_dim": 32, + }, + }, + "use_pytorch": args.torch, + } + + stop = { + "training_iteration": args.stop_iters, + "timesteps_total": args.stop_timesteps, + "episode_reward_mean": args.stop_reward, + } + + results = tune.run(args.run, config=config, stop=stop, verbose=1) + + if args.as_test: + check_learning_achieved(results, args.stop_reward) + ray.shutdown() diff --git a/rllib/examples/supervised_attention.py b/rllib/examples/attention_net_supervised.py similarity index 63% rename from rllib/examples/supervised_attention.py rename to rllib/examples/attention_net_supervised.py index a4947aa53..642bed2db 100644 --- a/rllib/examples/supervised_attention.py +++ b/rllib/examples/attention_net_supervised.py @@ -1,11 +1,8 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - +from gym.spaces import Box, Discrete import numpy as np -from rllib.models.tf import attention -from ray.rllib.utils import try_import_tf +from rllib.models.tf.attention_net import TrXLNet +from ray.rllib.utils.framework import try_import_tf tf = try_import_tf() @@ -19,16 +16,6 @@ def bit_shift_generator(seq_length, shift, batch_size): yield seq, targets -def make_model(seq_length, num_tokens, num_layers, attn_dim, num_heads, - head_dim, ff_hidden_dim): - - return tf.keras.Sequential(( - attention.make_TrXL(seq_length, num_layers, attn_dim, num_heads, - head_dim, ff_hidden_dim), - tf.keras.layers.Dense(num_tokens), - )) - - def train_loss(targets, outputs): loss = tf.nn.sparse_softmax_cross_entropy_with_logits( labels=targets, logits=outputs) @@ -39,10 +26,13 @@ def train_bit_shift(seq_length, num_iterations, print_every_n): optimizer = tf.keras.optimizers.Adam(1e-3) - model = make_model( - seq_length, - num_tokens=2, - num_layers=1, + model = TrXLNet( + observation_space=Box(low=0, high=1, shape=(1, ), dtype=np.int32), + action_space=Discrete(2), + num_outputs=2, + model_config={"max_seq_len": seq_length}, + name="trxl", + num_transformer_units=1, attn_dim=10, num_heads=5, head_dim=20, @@ -59,13 +49,20 @@ def train_bit_shift(seq_length, num_iterations, print_every_n): @tf.function def update_step(inputs, targets): - - optimizer.minimize(lambda: train_loss(targets, model(inputs)), + model_out = model( + { + "obs": inputs + }, + state=[tf.reshape(inputs, [-1, seq_length, 1])], + seq_lens=np.full(shape=(train_batch, ), fill_value=seq_length)) + optimizer.minimize(lambda: train_loss(targets, model_out), lambda: model.trainable_variables) for i, (inputs, targets) in zip(range(num_iterations), data_gen): + inputs_in = np.reshape(inputs, [-1, 1]) + targets_in = np.reshape(targets, [-1]) update_step( - tf.convert_to_tensor(inputs), tf.convert_to_tensor(targets)) + tf.convert_to_tensor(inputs_in), tf.convert_to_tensor(targets_in)) if i % print_every_n == 0: test_inputs, test_targets = next(test_gen) diff --git a/rllib/examples/custom_env.py b/rllib/examples/custom_env.py index 148e174b5..09e25b363 100644 --- a/rllib/examples/custom_env.py +++ b/rllib/examples/custom_env.py @@ -8,16 +8,16 @@ This example shows: You can visualize experiment results in ~/ray_results using TensorBoard. """ import argparse -import numpy as np import gym from gym.spaces import Discrete, Box +import numpy as np import ray from ray import tune from ray.tune import grid_search from ray.rllib.models import ModelCatalog from ray.rllib.models.tf.tf_modelv2 import TFModelV2 -from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork +from ray.rllib.models.tf.fcnet import FullyConnectedNetwork from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC from ray.rllib.utils.framework import try_import_tf, try_import_torch diff --git a/rllib/examples/custom_keras_model.py b/rllib/examples/custom_keras_model.py index 57a30977d..4f77f2bd8 100644 --- a/rllib/examples/custom_keras_model.py +++ b/rllib/examples/custom_keras_model.py @@ -10,7 +10,7 @@ from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.agents.dqn.distributional_q_tf_model import \ DistributionalQTFModel from ray.rllib.utils import try_import_tf -from ray.rllib.models.tf.visionnet_v2 import VisionNetwork as MyVisionNetwork +from ray.rllib.models.tf.visionnet import VisionNetwork as MyVisionNetwork tf = try_import_tf() diff --git a/rllib/examples/env/debug_counter_env.py b/rllib/examples/env/debug_counter_env.py new file mode 100644 index 000000000..0fc3ad972 --- /dev/null +++ b/rllib/examples/env/debug_counter_env.py @@ -0,0 +1,23 @@ +import gym + + +class DebugCounterEnv(gym.Env): + """Simple Env that yields a ts counter as observation (0-based). + + Actions have no effect. + The episode length is always 15. + Reward is always: current ts % 3. + """ + + def __init__(self): + self.action_space = gym.spaces.Discrete(2) + self.observation_space = gym.spaces.Box(0, 100, (1, )) + self.i = 0 + + def reset(self): + self.i = 0 + return [self.i] + + def step(self, action): + self.i += 1 + return [self.i], self.i % 3, self.i >= 15, {} diff --git a/rllib/examples/env/look_and_push.py b/rllib/examples/env/look_and_push.py new file mode 100644 index 000000000..d8ac05594 --- /dev/null +++ b/rllib/examples/env/look_and_push.py @@ -0,0 +1,66 @@ +import gym +import numpy as np + + +class LookAndPush(gym.Env): + """Memory-requiring Env: Best sequence of actions depends on prev. states. + + Optimal behavior: + 0) a=0 -> observe next state (s'), which is the "hidden" state. + If a=1 here, the hidden state is not observed. + 1) a=1 to always jump to s=2 (not matter what the prev. state was). + 2) a=1 to move to s=3. + 3) a=1 to move to s=4. + 4) a=0 OR 1 depending on s' observed after 0): +10 reward and done. + otherwise: -10 reward and done. + """ + + def __init__(self): + self.action_space = gym.spaces.Discrete(2) + self.observation_space = gym.spaces.Discrete(5) + self._state = None + self._case = None + + def reset(self): + self._state = 2 + self._case = np.random.choice(2) + return self._state + + def step(self, action): + assert self.action_space.contains(action) + + if self._state == 4: + if action and self._case: + return self._state, 10., True, {} + else: + return self._state, -10, True, {} + else: + if action: + if self._state == 0: + self._state = 2 + else: + self._state += 1 + elif self._state == 2: + self._state = self._case + + return self._state, -1, False, {} + + +class OneHot(gym.Wrapper): + def __init__(self, env): + super(OneHot, self).__init__(env) + self.observation_space = gym.spaces.Box(0., 1., + (env.observation_space.n, )) + + def reset(self, **kwargs): + obs = self.env.reset(**kwargs) + return self._encode_obs(obs) + + def step(self, action): + obs, reward, done, info = self.env.step(action) + return self._encode_obs(obs), reward, done, info + + def _encode_obs(self, obs): + new_obs = np.ones(self.env.observation_space.n) + new_obs[obs] = 1.0 + return new_obs diff --git a/rllib/examples/models/centralized_critic_models.py b/rllib/examples/models/centralized_critic_models.py index 325b6b493..030ab66fe 100644 --- a/rllib/examples/models/centralized_critic_models.py +++ b/rllib/examples/models/centralized_critic_models.py @@ -2,7 +2,7 @@ from gym.spaces import Box from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.tf.tf_modelv2 import TFModelV2 -from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork +from ray.rllib.models.tf.fcnet import FullyConnectedNetwork from ray.rllib.models.torch.misc import SlimFC from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC diff --git a/rllib/examples/models/eager_model.py b/rllib/examples/models/eager_model.py index 3a6ae83c6..6e2d44c04 100644 --- a/rllib/examples/models/eager_model.py +++ b/rllib/examples/models/eager_model.py @@ -1,7 +1,7 @@ import random from ray.rllib.models.modelv2 import ModelV2 -from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork +from ray.rllib.models.tf.fcnet import FullyConnectedNetwork from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf diff --git a/rllib/examples/models/mobilenet_v2_with_lstm_models.py b/rllib/examples/models/mobilenet_v2_with_lstm_models.py index 4ce78f4dc..3bc7052be 100644 --- a/rllib/examples/models/mobilenet_v2_with_lstm_models.py +++ b/rllib/examples/models/mobilenet_v2_with_lstm_models.py @@ -1,9 +1,9 @@ import numpy as np from ray.rllib.models.modelv2 import ModelV2 -from ray.rllib.models.tf.recurrent_tf_modelv2 import RecurrentTFModelV2 +from ray.rllib.models.tf.recurrent_net import RecurrentNetwork from ray.rllib.models.torch.misc import SlimFC -from ray.rllib.models.torch.recurrent_torch_model import RecurrentTorchModel +from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf, try_import_torch @@ -11,7 +11,7 @@ tf = try_import_tf() torch, nn = try_import_torch() -class MobileV2PlusRNNModel(RecurrentTFModelV2): +class MobileV2PlusRNNModel(RecurrentNetwork): """A conv. + recurrent keras net example using a pre-trained MobileNet.""" def __init__(self, obs_space, action_space, num_outputs, model_config, @@ -71,7 +71,7 @@ class MobileV2PlusRNNModel(RecurrentTFModelV2): self.register_variables(self.rnn_model.variables) self.rnn_model.summary() - @override(RecurrentTFModelV2) + @override(RecurrentNetwork) def forward_rnn(self, inputs, state, seq_lens): model_out, self._value_out, h, c = self.rnn_model([inputs, seq_lens] + state) @@ -89,7 +89,7 @@ class MobileV2PlusRNNModel(RecurrentTFModelV2): return tf.reshape(self._value_out, [-1]) -class TorchMobileV2PlusRNNModel(RecurrentTorchModel): +class TorchMobileV2PlusRNNModel(TorchRNN): """A conv. + recurrent torch net example using a pre-trained MobileNet.""" def __init__(self, obs_space, action_space, num_outputs, model_config, @@ -117,7 +117,7 @@ class TorchMobileV2PlusRNNModel(RecurrentTorchModel): # Holds the current "base" output (before logits layer). self._features = None - @override(RecurrentTFModelV2) + @override(TorchRNN) def forward_rnn(self, inputs, state, seq_lens): # Create image dims. vision_in = torch.reshape(inputs, [-1] + self.cnn_shape) diff --git a/rllib/examples/models/parametric_actions_model.py b/rllib/examples/models/parametric_actions_model.py index 15918d28b..1a4e91562 100644 --- a/rllib/examples/models/parametric_actions_model.py +++ b/rllib/examples/models/parametric_actions_model.py @@ -4,7 +4,7 @@ from ray.rllib.agents.dqn.distributional_q_tf_model import \ DistributionalQTFModel from ray.rllib.agents.dqn.dqn_torch_model import \ DQNTorchModel -from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork +from ray.rllib.models.tf.fcnet import FullyConnectedNetwork from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.numpy import LARGE_INTEGER diff --git a/rllib/examples/models/rnn_model.py b/rllib/examples/models/rnn_model.py index e13eee24c..fc6d31c16 100644 --- a/rllib/examples/models/rnn_model.py +++ b/rllib/examples/models/rnn_model.py @@ -2,8 +2,8 @@ import numpy as np from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.preprocessors import get_preprocessor -from ray.rllib.models.tf.recurrent_tf_modelv2 import RecurrentTFModelV2 -from ray.rllib.models.torch.recurrent_torch_model import RecurrentTorchModel +from ray.rllib.models.tf.recurrent_net import RecurrentNetwork +from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf, try_import_torch @@ -11,7 +11,7 @@ tf = try_import_tf() torch, nn = try_import_torch() -class RNNModel(RecurrentTFModelV2): +class RNNModel(RecurrentNetwork): """Example of using the Keras functional API to define a RNN model.""" def __init__(self, @@ -57,7 +57,7 @@ class RNNModel(RecurrentTFModelV2): self.register_variables(self.rnn_model.variables) self.rnn_model.summary() - @override(RecurrentTFModelV2) + @override(RecurrentNetwork) def forward_rnn(self, inputs, state, seq_lens): model_out, self._value_out, h, c = self.rnn_model([inputs, seq_lens] + state) @@ -75,7 +75,7 @@ class RNNModel(RecurrentTFModelV2): return tf.reshape(self._value_out, [-1]) -class TorchRNNModel(RecurrentTorchModel): +class TorchRNNModel(TorchRNN): def __init__(self, obs_space, action_space, @@ -114,7 +114,7 @@ class TorchRNNModel(RecurrentTorchModel): assert self._features is not None, "must call forward() first" return torch.reshape(self.value_branch(self._features), [-1]) - @override(RecurrentTorchModel) + @override(TorchRNN) def forward_rnn(self, inputs, state, seq_lens): """Feeds `inputs` (B x T x ..) through the Gru Unit. diff --git a/rllib/examples/models/rnn_spy_model.py b/rllib/examples/models/rnn_spy_model.py new file mode 100644 index 000000000..18f06f202 --- /dev/null +++ b/rllib/examples/models/rnn_spy_model.py @@ -0,0 +1,131 @@ +import numpy as np +import pickle + +import ray +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.tf.misc import normc_initializer +from ray.rllib.models.tf.recurrent_net import RecurrentNetwork +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_tf + +tf = try_import_tf() + + +class SpyLayer(tf.keras.layers.Layer): + """A keras Layer, which intercepts its inputs and stored them as pickled. + """ + + def __init__(self, num_outputs, **kwargs): + super().__init__(**kwargs) + + self.dense = tf.keras.layers.Dense( + units=num_outputs, kernel_initializer=normc_initializer(0.01)) + + def call(self, inputs, **kwargs): + """Does a forward pass through our Dense, but also intercepts inputs. + """ + + del kwargs + spy_fn = tf.py_func( + self.spy, + [ + inputs[0], # observations + inputs[2], # seq_lens + inputs[3], # h_in + inputs[4], # c_in + inputs[5], # h_out + inputs[6], # c_out + ], + tf.int64, + stateful=True) + + # Compute outputs + with tf.control_dependencies([spy_fn]): + return self.dense(inputs[1]) + + @staticmethod + def spy(inputs, seq_lens, h_in, c_in, h_out, c_out): + """The actual spy operation: Store inputs in internal_kv.""" + + if len(inputs) == 1: + return 0 # don't capture inference inputs + # TF runs this function in an isolated context, so we have to use + # redis to communicate back to our suite + ray.experimental.internal_kv._internal_kv_put( + "rnn_spy_in_{}".format(RNNSpyModel.capture_index), + pickle.dumps({ + "sequences": inputs, + "seq_lens": seq_lens, + "state_in": [h_in, c_in], + "state_out": [h_out, c_out] + }), + overwrite=True) + RNNSpyModel.capture_index += 1 + return 0 + + +class RNNSpyModel(RecurrentNetwork): + capture_index = 0 + cell_size = 3 + + def __init__(self, obs_space, action_space, num_outputs, model_config, + name): + super().__init__(obs_space, action_space, num_outputs, model_config, + name) + self.cell_size = RNNSpyModel.cell_size + + # Create a keras LSTM model. + inputs = tf.keras.layers.Input( + shape=(None, ) + obs_space.shape, name="input") + state_in_h = tf.keras.layers.Input(shape=(self.cell_size, ), name="h") + state_in_c = tf.keras.layers.Input(shape=(self.cell_size, ), name="c") + seq_lens = tf.keras.layers.Input( + shape=(), name="seq_lens", dtype=tf.int32) + + lstm_out, state_out_h, state_out_c = tf.keras.layers.LSTM( + self.cell_size, + return_sequences=True, + return_state=True, + name="lstm")( + inputs=inputs, + mask=tf.sequence_mask(seq_lens), + initial_state=[state_in_h, state_in_c]) + + logits = SpyLayer(num_outputs=self.num_outputs)([ + inputs, lstm_out, seq_lens, state_in_h, state_in_c, state_out_h, + state_out_c + ]) + + # Value branch. + value_out = tf.keras.layers.Dense( + units=1, kernel_initializer=normc_initializer(1.0))(lstm_out) + + self.base_model = tf.keras.Model( + [inputs, seq_lens, state_in_h, state_in_c], + [logits, value_out, state_out_h, state_out_c]) + self.base_model.summary() + self.register_variables(self.base_model.variables) + + @override(RecurrentNetwork) + def forward_rnn(self, inputs, state, seq_lens): + # Previously, a new class object was created during + # deserialization and this `capture_index` + # variable would be refreshed between class instantiations. + # This behavior is no longer the case, so we manually refresh + # the variable. + RNNSpyModel.capture_index = 0 + model_out, value_out, h, c = self.base_model( + [inputs, seq_lens, state[0], state[1]]) + self._value_out = value_out + return model_out, [h, c] + + @override(ModelV2) + def value_function(self): + return tf.reshape(self._value_out, [-1]) + + @override(ModelV2) + def get_initial_state(self): + return [ + np.zeros(self.cell_size, np.float32), + np.zeros(self.cell_size, np.float32) + ] diff --git a/rllib/examples/parametric_actions_cartpole.py b/rllib/examples/parametric_actions_cartpole.py index c83c588a3..ae0287a9a 100644 --- a/rllib/examples/parametric_actions_cartpole.py +++ b/rllib/examples/parametric_actions_cartpole.py @@ -18,13 +18,13 @@ import argparse import ray from ray import tune -from ray.tune.registry import register_env from ray.rllib.examples.env.parametric_actions_cartpole import \ ParametricActionsCartPole from ray.rllib.examples.models.parametric_actions_model import \ ParametricActionsModel, TorchParametricActionsModel from ray.rllib.models import ModelCatalog from ray.rllib.utils.test_utils import check_learning_achieved +from ray.tune.registry import register_env parser = argparse.ArgumentParser() parser.add_argument("--run", type=str, default="PPO") diff --git a/rllib/examples/rl_attention.py b/rllib/examples/rl_attention.py deleted file mode 100644 index a18d38be6..000000000 --- a/rllib/examples/rl_attention.py +++ /dev/null @@ -1,173 +0,0 @@ -import argparse - -import gym - -import numpy as np - -import ray -from ray import tune - -from ray.tune import registry - -from ray.rllib import models -from ray.rllib.utils import try_import_tf -from ray.rllib.models.tf import attention -from ray.rllib.models.tf import recurrent_tf_modelv2 -from ray.rllib.examples.custom_keras_rnn_model import RepeatAfterMeEnv -from ray.rllib.examples.custom_keras_rnn_model import RepeatInitialEnv - -tf = try_import_tf() - -parser = argparse.ArgumentParser() -parser.add_argument("--run", type=str, default="PPO") -parser.add_argument("--env", type=str, default="RepeatAfterMeEnv") -parser.add_argument("--stop", type=int, default=90) -parser.add_argument("--num-cpus", type=int, default=0) - - -class OneHot(gym.Wrapper): - - def __init__(self, env): - super(OneHot, self).__init__(env) - self.observation_space = gym.spaces.Box(0., 1., - (env.observation_space.n,)) - - def reset(self, **kwargs): - obs = self.env.reset(**kwargs) - return self._encode_obs(obs) - - def step(self, action): - obs, reward, done, info = self.env.step(action) - return self._encode_obs(obs), reward, done, info - - def _encode_obs(self, obs): - new_obs = np.ones(self.env.observation_space.n) - new_obs[obs] = 1.0 - return new_obs - - -class LookAndPush(gym.Env): - def __init__(self): - self.action_space = gym.spaces.Discrete(2) - self.observation_space = gym.spaces.Discrete(5) - self._state = None - self._case = None - - def reset(self): - self._state = 2 - self._case = np.random.choice(2) - return self._state - - def step(self, action): - assert self.action_space.contains(action) - - if self._state == 4: - if action and self._case: - return self._state, 10., True, {} - else: - return self._state, -10, True, {} - else: - if action: - if self._state == 0: - self._state = 2 - else: - self._state += 1 - elif self._state == 2: - self._state = self._case - - return self._state, -1, False, {} - - -class GRUTrXL(recurrent_tf_modelv2.RecurrentTFModelV2): - - def __init__(self, obs_space, action_space, num_outputs, model_config, - name): - super(GRUTrXL, self).__init__(obs_space, action_space, num_outputs, - model_config, name) - self.max_seq_len = model_config["max_seq_len"] - self.obs_dim = obs_space.shape[0] - input_layer = tf.keras.layers.Input( - shape=(self.max_seq_len, obs_space.shape[0]), - name="inputs", - ) - - trxl_out = attention.make_GRU_TrXL( - seq_length=model_config["max_seq_len"], - num_layers=model_config["custom_options"]["num_layers"], - attn_dim=model_config["custom_options"]["attn_dim"], - num_heads=model_config["custom_options"]["num_heads"], - head_dim=model_config["custom_options"]["head_dim"], - ff_hidden_dim=model_config["custom_options"]["ff_hidden_dim"], - )(input_layer) - - # Postprocess TrXL output with another hidden layer and compute values - logits = tf.keras.layers.Dense( - self.num_outputs, - activation=tf.keras.activations.linear, - name="logits")(trxl_out) - values_out = tf.keras.layers.Dense( - 1, activation=None, name="values")(trxl_out) - - self.trxl_model = tf.keras.Model( - inputs=[input_layer], - outputs=[logits, values_out], - ) - self.register_variables(self.trxl_model.variables) - self.trxl_model.summary() - - def forward_rnn(self, inputs, state, seq_lens): - state = state[0] - - # We assume state is the history of recent observations and append - # the current inputs to the end and only keep the most recent (up to - # max_seq_len). This allows us to deal with timestep-wise inference - # and full sequence training with the same logic. - state = tf.concat((state, inputs), axis=1)[:, -self.max_seq_len:] - logits, self._value_out = self.trxl_model(state) - - in_T = tf.shape(inputs)[1] - logits = logits[:, -in_T:] - self._value_out = self._value_out[:, -in_T:] - - return logits, [state] - - def get_initial_state(self): - return [np.zeros((self.max_seq_len, self.obs_dim), np.float32)] - - def value_function(self): - return tf.reshape(self._value_out, [-1]) - - -if __name__ == "__main__": - args = parser.parse_args() - ray.init(num_cpus=args.num_cpus or None) - models.ModelCatalog.register_custom_model("trxl", GRUTrXL) - registry.register_env("RepeatAfterMeEnv", lambda c: RepeatAfterMeEnv(c)) - registry.register_env("RepeatInitialEnv", lambda _: RepeatInitialEnv()) - registry.register_env("LookAndPush", lambda _: OneHot(LookAndPush())) - tune.run( - args.run, - stop={"episode_reward_mean": args.stop}, - config={ - "env": args.env, - "env_config": { - "repeat_delay": 2, - }, - "gamma": 0.99, - "num_workers": 0, - "num_envs_per_worker": 20, - "entropy_coeff": 0.001, - "num_sgd_iter": 5, - "vf_loss_coeff": 1e-5, - "model": { - "custom_model": "trxl", - "max_seq_len": 10, - "custom_options": { - "num_layers": 1, - "attn_dim": 10, - "num_heads": 1, - "head_dim": 10, - "ff_hidden_dim": 20, - }, - }, - }) diff --git a/rllib/models/catalog.py b/rllib/models/catalog.py index b646fc452..f226b4532 100644 --- a/rllib/models/catalog.py +++ b/rllib/models/catalog.py @@ -280,10 +280,13 @@ class ModelCatalog: """ if model_config.get("custom_model"): - model_cls = _global_registry.get(RLLIB_MODEL, - model_config["custom_model"]) + if isinstance(model_config["custom_model"], type): + model_cls = model_config["custom_model"] + else: + model_cls = _global_registry.get(RLLIB_MODEL, + model_config["custom_model"]) + # TODO(sven): Hard-deprecate Model(V1). if issubclass(model_cls, ModelV2): - logger.info("Wrapping {} as {}".format(model_cls, model_interface)) model_cls = ModelCatalog._wrap_if_needed( @@ -299,9 +302,26 @@ class ModelCatalog: return v with tf.variable_creator_scope(track_var_creation): - instance = model_cls(obs_space, action_space, - num_outputs, model_config, name, - **model_kwargs) + # Try calling with kwargs first (custom ModelV2 should + # accept these as kwargs, not get them from + # config["custom_options"] anymore) + try: + instance = model_cls(obs_space, action_space, + num_outputs, model_config, + name, **model_kwargs) + except TypeError as e: + # Keyword error: Try old way w/o kwargs. + if "__init__() got an unexpected " in e.args[0]: + logger.warning( + "Custom ModelV2 should accept all custom " + "options as **kwargs, instead of expecting" + " them in config['custom_options']!") + instance = model_cls(obs_space, action_space, + num_outputs, model_config, + name) + # Other error -> re-raise. + else: + raise e registered = set(instance.variables()) not_registered = set() for var in created: @@ -322,7 +342,8 @@ class ModelCatalog: instance = model_cls(obs_space, action_space, num_outputs, model_config, name, **model_kwargs) return instance - + # TODO(sven): Hard-deprecate Model(V1). This check will be + # superflous then. elif tf.executing_eagerly(): raise ValueError( "Eager execution requires a TFModelV2 model to be " @@ -536,9 +557,9 @@ class ModelCatalog: from ray.rllib.models.torch.visionnet import (VisionNetwork as VisionNet) else: - from ray.rllib.models.tf.fcnet_v2 import \ + from ray.rllib.models.tf.fcnet import \ FullyConnectedNetwork as FCNet - from ray.rllib.models.tf.visionnet_v2 import \ + from ray.rllib.models.tf.visionnet import \ VisionNetwork as VisionNet # Discrete/1D obs-spaces. diff --git a/rllib/models/tf/__init__.py b/rllib/models/tf/__init__.py index 2ca11563f..d43e8f403 100644 --- a/rllib/models/tf/__init__.py +++ b/rllib/models/tf/__init__.py @@ -1,12 +1,12 @@ from ray.rllib.models.tf.tf_modelv2 import TFModelV2 -from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork -from ray.rllib.models.tf.recurrent_tf_modelv2 import \ - RecurrentTFModelV2 -from ray.rllib.models.tf.visionnet_v2 import VisionNetwork +from ray.rllib.models.tf.fcnet import FullyConnectedNetwork +from ray.rllib.models.tf.recurrent_net import \ + RecurrentNetwork +from ray.rllib.models.tf.visionnet import VisionNetwork __all__ = [ "FullyConnectedNetwork", - "RecurrentTFModelV2", + "RecurrentNetwork", "TFModelV2", "VisionNetwork", ] diff --git a/rllib/models/tf/attention.py b/rllib/models/tf/attention.py deleted file mode 100644 index e0ff6b1d7..000000000 --- a/rllib/models/tf/attention.py +++ /dev/null @@ -1,284 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from ray.rllib.utils import try_import_tf - -tf = try_import_tf() - - -def relative_position_embedding(seq_length, out_dim): - inverse_freq = 1 / (10000**(tf.range(0, out_dim, 2.0) / out_dim)) - pos_offsets = tf.range(seq_length - 1., -1., -1.) - inputs = pos_offsets[:, None] * inverse_freq[None, :] - return tf.concat((tf.sin(inputs), tf.cos(inputs)), axis=-1) - - -def rel_shift(x): - # Transposed version of the shift approach implemented by Dai et al. 2019 - # https://github.com/kimiyoung/transformer-xl/blob/ - # 44781ed21dbaec88b280f74d9ae2877f52b492a5/tf/model.py#L31 - x_size = tf.shape(x) - - x = tf.pad(x, [[0, 0], [0, 0], [1, 0], [0, 0]]) - x = tf.reshape(x, [x_size[0], x_size[2] + 1, x_size[1], x_size[3]]) - x = tf.slice(x, [0, 1, 0, 0], [-1, -1, -1, -1]) - x = tf.reshape(x, x_size) - - return x - - -class MultiHeadAttention(tf.keras.layers.Layer): - def __init__(self, out_dim, num_heads, head_dim, **kwargs): - super(MultiHeadAttention, self).__init__(**kwargs) - - # no bias or non-linearity - self._num_heads = num_heads - self._head_dim = head_dim - self._qkv_layer = tf.keras.layers.Dense( - 3 * num_heads * head_dim, use_bias=False) - self._linear_layer = tf.keras.layers.TimeDistributed( - tf.keras.layers.Dense(out_dim, use_bias=False)) - - def call(self, inputs): - L = tf.shape(inputs)[1] # length of segment - H = self._num_heads # number of attention heads - D = self._head_dim # attention head dimension - - qkv = self._qkv_layer(inputs) - - queries, keys, values = tf.split(qkv, 3, -1) - queries = queries[:, -L:] # only query based on the segment - - queries = tf.reshape(queries, [-1, L, H, D]) - keys = tf.reshape(keys, [-1, L, H, D]) - values = tf.reshape(values, [-1, L, H, D]) - - score = tf.einsum("bihd,bjhd->bijh", queries, keys) - score = score / D**0.5 - - # causal mask of the same length as the sequence - mask = tf.sequence_mask(tf.range(1, L + 1), dtype=score.dtype) - mask = mask[None, :, :, None] - - masked_score = score * mask + 1e30 * (mask - 1.) - wmat = tf.nn.softmax(masked_score, axis=2) - - out = tf.einsum("bijh,bjhd->bihd", wmat, values) - out = tf.reshape(out, tf.concat((tf.shape(out)[:2], [H * D]), axis=0)) - return self._linear_layer(out) - - -class RelativeMultiHeadAttention(tf.keras.layers.Layer): - def __init__(self, - out_dim, - num_heads, - head_dim, - rel_pos_encoder, - input_layernorm=False, - output_activation=None, - **kwargs): - super(RelativeMultiHeadAttention, self).__init__(**kwargs) - - # no bias or non-linearity - self._num_heads = num_heads - self._head_dim = head_dim - self._qkv_layer = tf.keras.layers.Dense( - 3 * num_heads * head_dim, use_bias=False) - self._linear_layer = tf.keras.layers.TimeDistributed( - tf.keras.layers.Dense( - out_dim, use_bias=False, activation=output_activation)) - - self._uvar = self.add_weight(shape=(num_heads, head_dim)) - self._vvar = self.add_weight(shape=(num_heads, head_dim)) - - self._pos_proj = tf.keras.layers.Dense( - num_heads * head_dim, use_bias=False) - self._rel_pos_encoder = rel_pos_encoder - - self._input_layernorm = None - if input_layernorm: - self._input_layernorm = tf.keras.layers.LayerNormalization(axis=-1) - - def call(self, inputs, memory=None): - L = tf.shape(inputs)[1] # length of segment - H = self._num_heads # number of attention heads - D = self._head_dim # attention head dimension - - # length of the memory segment - M = memory.shape[0] if memory is not None else 0 - - if memory is not None: - inputs = np.concatenate((tf.stop_gradient(memory), inputs), axis=1) - - if self._input_layernorm is not None: - inputs = self._input_layernorm(inputs) - - qkv = self._qkv_layer(inputs) - - queries, keys, values = tf.split(qkv, 3, -1) - queries = queries[:, -L:] # only query based on the segment - - queries = tf.reshape(queries, [-1, L, H, D]) - keys = tf.reshape(keys, [-1, L + M, H, D]) - values = tf.reshape(values, [-1, L + M, H, D]) - - rel = self._pos_proj(self._rel_pos_encoder) - rel = tf.reshape(rel, [L, H, D]) - - score = tf.einsum("bihd,bjhd->bijh", queries + self._uvar, keys) - pos_score = tf.einsum("bihd,jhd->bijh", queries + self._vvar, rel) - score = score + rel_shift(pos_score) - score = score / D**0.5 - - # causal mask of the same length as the sequence - mask = tf.sequence_mask(tf.range(M + 1, L + M + 1), dtype=score.dtype) - mask = mask[None, :, :, None] - - masked_score = score * mask + 1e30 * (mask - 1.) - wmat = tf.nn.softmax(masked_score, axis=2) - - out = tf.einsum("bijh,bjhd->bihd", wmat, values) - out = tf.reshape(out, tf.concat((tf.shape(out)[:2], [H * D]), axis=0)) - return self._linear_layer(out) - - -class PositionwiseFeedforward(tf.keras.layers.Layer): - def __init__(self, out_dim, hidden_dim, output_activation=None, **kwargs): - super(PositionwiseFeedforward, self).__init__(**kwargs) - - self._hidden_layer = tf.keras.layers.Dense( - hidden_dim, - activation=tf.nn.relu, - ) - self._output_layer = tf.keras.layers.Dense( - out_dim, activation=output_activation) - - def call(self, inputs, **kwargs): - del kwargs - output = self._hidden_layer(inputs) - return self._output_layer(output) - - -class SkipConnection(tf.keras.layers.Layer): - """Skip connection layer. - - If no fan-in layer is specified, then this layer behaves as a regular - residual layer. - """ - - def __init__(self, layer, fan_in_layer=None, **kwargs): - super(SkipConnection, self).__init__(**kwargs) - self._fan_in_layer = fan_in_layer - self._layer = layer - - def call(self, inputs, **kwargs): - del kwargs - outputs = self._layer(inputs) - if self._fan_in_layer is None: - outputs = outputs + inputs - else: - outputs = self._fan_in_layer((inputs, outputs)) - - return outputs - - -class GRUGate(tf.keras.layers.Layer): - def __init__(self, init_bias=0., **kwargs): - super(GRUGate, self).__init__(**kwargs) - self._init_bias = init_bias - - def build(self, input_shape): - x_shape, y_shape = input_shape - if x_shape[-1] != y_shape[-1]: - raise ValueError( - "Both inputs to GRUGate must equal size last axis.") - - self._w_r = self.add_weight(shape=(y_shape[-1], y_shape[-1])) - self._w_z = self.add_weight(shape=(y_shape[-1], y_shape[-1])) - self._w_h = self.add_weight(shape=(y_shape[-1], y_shape[-1])) - self._u_r = self.add_weight(shape=(x_shape[-1], x_shape[-1])) - self._u_z = self.add_weight(shape=(x_shape[-1], x_shape[-1])) - self._u_h = self.add_weight(shape=(x_shape[-1], x_shape[-1])) - - def bias_initializer(shape, dtype): - return tf.fill(shape, tf.cast(self._init_bias, dtype=dtype)) - - self._bias_z = self.add_weight( - shape=(x_shape[-1], ), initializer=bias_initializer) - - def call(self, inputs, **kwargs): - x, y = inputs - r = (tf.tensordot(y, self._w_r, axes=1) + tf.tensordot( - x, self._u_r, axes=1)) - r = tf.nn.sigmoid(r) - - z = (tf.tensordot(y, self._w_z, axes=1) + tf.tensordot( - x, self._u_z, axes=1) + self._bias_z) - z = tf.nn.sigmoid(z) - - h = (tf.tensordot(y, self._w_h, axes=1) + tf.tensordot( - (x * r), self._u_h, axes=1)) - h = tf.nn.tanh(h) - - return (1 - z) * x + z * h - - -def make_TrXL(seq_length, num_layers, attn_dim, num_heads, head_dim, - ff_hidden_dim): - pos_embedding = relative_position_embedding(seq_length, attn_dim) - - layers = [tf.keras.layers.Dense(attn_dim)] - for _ in range(num_layers): - layers.append( - SkipConnection( - RelativeMultiHeadAttention(attn_dim, num_heads, head_dim, - pos_embedding))) - layers.append(tf.keras.layers.LayerNormalization(axis=-1)) - - layers.append( - SkipConnection(PositionwiseFeedforward(attn_dim, ff_hidden_dim))) - layers.append(tf.keras.layers.LayerNormalization(axis=-1)) - - return tf.keras.Sequential(layers) - - -def make_GRU_TrXL(seq_length, - num_layers, - attn_dim, - num_heads, - head_dim, - ff_hidden_dim, - init_gate_bias=2.): - # Default initial bias for the gate taken from - # Parisotto, Emilio, et al. "Stabilizing Transformers for Reinforcement - # Learning." arXiv preprint arXiv:1910.06764 (2019). - pos_embedding = relative_position_embedding(seq_length, attn_dim) - - layers = [tf.keras.layers.Dense(attn_dim)] - for _ in range(num_layers): - layers.append( - SkipConnection( - RelativeMultiHeadAttention( - attn_dim, - num_heads, - head_dim, - pos_embedding, - input_layernorm=True, - output_activation=tf.nn.relu), - fan_in_layer=GRUGate(init_gate_bias), - )) - - layers.append( - SkipConnection( - tf.keras.Sequential( - (tf.keras.layers.LayerNormalization(axis=-1), - PositionwiseFeedforward( - attn_dim, ff_hidden_dim, - output_activation=tf.nn.relu))), - fan_in_layer=GRUGate(init_gate_bias), - )) - - return tf.keras.Sequential(layers) diff --git a/rllib/models/tf/attention_net.py b/rllib/models/tf/attention_net.py new file mode 100644 index 000000000..d4e25630c --- /dev/null +++ b/rllib/models/tf/attention_net.py @@ -0,0 +1,336 @@ +""" +[1] - Attention Is All You Need - Vaswani, Jones, Shazeer, Parmar, + Uszkoreit, Gomez, Kaiser - Google Brain/Research, U Toronto - 2017. + https://arxiv.org/pdf/1706.03762.pdf +[2] - Stabilizing Transformers for Reinforcement Learning - E. Parisotto + et al. - DeepMind - 2019. https://arxiv.org/pdf/1910.06764.pdf +[3] - Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context. + Z. Dai, Z. Yang, et al. - Carnegie Mellon U - 2019. + https://www.aclweb.org/anthology/P19-1285.pdf +""" +import numpy as np + +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.tf.layers import GRUGate, RelativeMultiHeadAttention, \ + SkipConnection +from ray.rllib.models.tf.recurrent_net import RecurrentNetwork +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_tf + +tf = try_import_tf() + + +# TODO(sven): Use RLlib's FCNet instead. +class PositionwiseFeedforward(tf.keras.layers.Layer): + """A 2x linear layer with ReLU activation in between described in [1]. + + Each timestep coming from the attention head will be passed through this + layer separately. + """ + + def __init__(self, out_dim, hidden_dim, output_activation=None, **kwargs): + super().__init__(**kwargs) + + self._hidden_layer = tf.keras.layers.Dense( + hidden_dim, + activation=tf.nn.relu, + ) + + self._output_layer = tf.keras.layers.Dense( + out_dim, activation=output_activation) + + def call(self, inputs, **kwargs): + del kwargs + output = self._hidden_layer(inputs) + return self._output_layer(output) + + +class TrXLNet(RecurrentNetwork): + """A TrXL net Model described in [1].""" + + def __init__(self, observation_space, action_space, num_outputs, + model_config, name, num_transformer_units, attn_dim, + num_heads, head_dim, ff_hidden_dim): + """Initializes a TfXLNet object. + + Args: + num_transformer_units (int): The number of Transformer repeats to + use (denoted L in [2]). + attn_dim (int): The input and output dimensions of one Transformer + unit. + num_heads (int): The number of attention heads to use in parallel. + Denoted as `H` in [3]. + head_dim (int): The dimension of a single(!) head. + Denoted as `d` in [3]. + ff_hidden_dim (int): The dimension of the hidden layer within + the position-wise MLP (after the multi-head attention block + within one Transformer unit). This is the size of the first + of the two layers within the PositionwiseFeedforward. The + second layer always has size=`attn_dim`. + """ + + super().__init__(observation_space, action_space, num_outputs, + model_config, name) + + self.num_transformer_units = num_transformer_units + self.attn_dim = attn_dim + self.num_heads = num_heads + self.head_dim = head_dim + self.max_seq_len = model_config["max_seq_len"] + self.obs_dim = observation_space.shape[0] + + pos_embedding = relative_position_embedding(self.max_seq_len, attn_dim) + + inputs = tf.keras.layers.Input( + shape=(self.max_seq_len, self.obs_dim), name="inputs") + E_out = tf.keras.layers.Dense(attn_dim)(inputs) + + for _ in range(self.num_transformer_units): + MHA_out = SkipConnection( + RelativeMultiHeadAttention( + out_dim=attn_dim, + num_heads=num_heads, + head_dim=head_dim, + rel_pos_encoder=pos_embedding, + input_layernorm=False, + output_activation=None), + fan_in_layer=None)(E_out) + E_out = SkipConnection( + PositionwiseFeedforward(attn_dim, ff_hidden_dim))(MHA_out) + E_out = tf.keras.layers.LayerNormalization(axis=-1)(E_out) + + # Postprocess TrXL output with another hidden layer and compute values. + logits = tf.keras.layers.Dense( + self.num_outputs, + activation=tf.keras.activations.linear, + name="logits")(E_out) + + self.base_model = tf.keras.models.Model([inputs], [logits]) + self.register_variables(self.base_model.variables) + + @override(RecurrentNetwork) + def forward_rnn(self, inputs, state, seq_lens): + # To make Attention work with current RLlib's ModelV2 API: + # We assume `state` is the history of L recent observations (all + # concatenated into one tensor) and append the current inputs to the + # end and only keep the most recent (up to `max_seq_len`). This allows + # us to deal with timestep-wise inference and full sequence training + # within the same logic. + observations = state[0] + observations = tf.concat( + (observations, inputs), axis=1)[:, -self.max_seq_len:] + logits = self.base_model([observations]) + T = tf.shape(inputs)[1] # Length of input segment (time). + logits = logits[:, -T:] + + return logits, [observations] + + @override(RecurrentNetwork) + def get_initial_state(self): + # State is the T last observations concat'd together into one Tensor. + # Plus all Transformer blocks' E(l) outputs concat'd together (up to + # tau timesteps). + return [np.zeros((self.max_seq_len, self.obs_dim), np.float32)] + + +class GTrXLNet(RecurrentNetwork): + """A GTrXL net Model described in [2]. + + This is still in an experimental phase. + Can be used as a drop-in replacement for LSTMs in PPO and IMPALA. + For an example script, see: `ray/rllib/examples/attention_net.py`. + + To use this network as a replacement for an RNN, configure your Trainer + as follows: + + Examples: + >> config["model"]["custom_model"] = GTrXLNet + >> config["model"]["max_seq_len"] = 10 + >> config["model"]["custom_options"] = { + >> num_transformer_units=1, + >> attn_dim=32, + >> num_heads=2, + >> memory_tau=50, + >> etc.. + >> } + """ + + def __init__(self, + observation_space, + action_space, + num_outputs, + model_config, + name, + num_transformer_units, + attn_dim, + num_heads, + memory_tau, + head_dim, + ff_hidden_dim, + init_gate_bias=2.0): + """Initializes a GTrXLNet. + + Args: + num_transformer_units (int): The number of Transformer repeats to + use (denoted L in [2]). + attn_dim (int): The input and output dimensions of one Transformer + unit. + num_heads (int): The number of attention heads to use in parallel. + Denoted as `H` in [3]. + memory_tau (int): The number of timesteps to store in each + transformer block's memory M (concat'd over time and fed into + next transformer block as input). + head_dim (int): The dimension of a single(!) head. + Denoted as `d` in [3]. + ff_hidden_dim (int): The dimension of the hidden layer within + the position-wise MLP (after the multi-head attention block + within one Transformer unit). This is the size of the first + of the two layers within the PositionwiseFeedforward. The + second layer always has size=`attn_dim`. + init_gate_bias (float): Initial bias values for the GRU gates (two + GRUs per Transformer unit, one after the MHA, one after the + position-wise MLP). + """ + + super().__init__(observation_space, action_space, num_outputs, + model_config, name) + + self.num_transformer_units = num_transformer_units + self.attn_dim = attn_dim + self.num_heads = num_heads + self.memory_tau = memory_tau + self.head_dim = head_dim + self.max_seq_len = model_config["max_seq_len"] + self.obs_dim = observation_space.shape[0] + + # Constant (non-trainable) sinusoid rel pos encoding matrix. + Phi = relative_position_embedding(self.max_seq_len + self.memory_tau, + self.attn_dim) + + # Raw observation input. + input_layer = tf.keras.layers.Input( + shape=(self.max_seq_len, self.obs_dim), name="inputs") + memory_ins = [ + tf.keras.layers.Input( + shape=(self.memory_tau, self.attn_dim), + dtype=tf.float32, + name="memory_in_{}".format(i)) + for i in range(self.num_transformer_units) + ] + + # Map observation dim to input/output transformer (attention) dim. + E_out = tf.keras.layers.Dense(self.attn_dim)(input_layer) + # Output, collected and concat'd to build the internal, tau-len + # Memory units used for additional contextual information. + memory_outs = [E_out] + + # 2) Create L Transformer blocks according to [2]. + for i in range(self.num_transformer_units): + # RelativeMultiHeadAttention part. + MHA_out = SkipConnection( + RelativeMultiHeadAttention( + out_dim=self.attn_dim, + num_heads=num_heads, + head_dim=head_dim, + rel_pos_encoder=Phi, + input_layernorm=True, + output_activation=tf.nn.relu), + fan_in_layer=GRUGate(init_gate_bias), + name="mha_{}".format(i + 1))( + E_out, memory=memory_ins[i]) + # Position-wise MLP part. + E_out = SkipConnection( + tf.keras.Sequential( + (tf.keras.layers.LayerNormalization(axis=-1), + PositionwiseFeedforward( + out_dim=self.attn_dim, + hidden_dim=ff_hidden_dim, + output_activation=tf.nn.relu))), + fan_in_layer=GRUGate(init_gate_bias), + name="pos_wise_mlp_{}".format(i + 1))(MHA_out) + # Output of position-wise MLP == E(l-1), which is concat'd + # to the current Mem block (M(l-1)) to yield E~(l-1), which is then + # used by the next transformer block. + memory_outs.append(E_out) + + # Postprocess TrXL output with another hidden layer and compute values. + logits = tf.keras.layers.Dense( + self.num_outputs, + activation=tf.keras.activations.linear, + name="logits")(E_out) + + self._value_out = None + values_out = tf.keras.layers.Dense( + 1, activation=None, name="values")(E_out) + + self.trxl_model = tf.keras.Model( + inputs=[input_layer] + memory_ins, + outputs=[logits, values_out] + memory_outs[:-1]) + + self.register_variables(self.trxl_model.variables) + self.trxl_model.summary() + + @override(RecurrentNetwork) + def forward_rnn(self, inputs, state, seq_lens): + # To make Attention work with current RLlib's ModelV2 API: + # We assume `state` is the history of L recent observations (all + # concatenated into one tensor) and append the current inputs to the + # end and only keep the most recent (up to `max_seq_len`). This allows + # us to deal with timestep-wise inference and full sequence training + # within the same logic. + observations = state[0] + memory = state[1:] + + observations = tf.concat( + (observations, inputs), axis=1)[:, -self.max_seq_len:] + all_out = self.trxl_model([observations] + memory) + logits, self._value_out = all_out[0], all_out[1] + memory_outs = all_out[2:] + # If memory_tau > max_seq_len -> overlap w/ previous `memory` input. + if self.memory_tau > self.max_seq_len: + memory_outs = [ + tf.concat( + [memory[i][:, -(self.memory_tau - self.max_seq_len):], m], + axis=1) for i, m in enumerate(memory_outs) + ] + else: + memory_outs = [m[:, -self.memory_tau:] for m in memory_outs] + + T = tf.shape(inputs)[1] # Length of input segment (time). + logits = logits[:, -T:] + self._value_out = self._value_out[:, -T:] + + return logits, [observations] + memory_outs + + @override(RecurrentNetwork) + def get_initial_state(self): + # State is the T last observations concat'd together into one Tensor. + # Plus all Transformer blocks' E(l) outputs concat'd together (up to + # tau timesteps). + return [np.zeros((self.max_seq_len, self.obs_dim), np.float32)] + \ + [np.zeros((self.memory_tau, self.attn_dim), np.float32) + for _ in range(self.num_transformer_units)] + + @override(ModelV2) + def value_function(self): + return tf.reshape(self._value_out, [-1]) + + +def relative_position_embedding(seq_length, out_dim): + """Creates a [seq_length x seq_length] matrix for rel. pos encoding. + + Denoted as Phi in [2] and [3]. Phi is the standard sinusoid encoding + matrix. + + Args: + seq_length (int): The max. sequence length (time axis). + out_dim (int): The number of nodes to go into the first Tranformer + layer with. + + Returns: + tf.Tensor: The encoding matrix Phi. + """ + inverse_freq = 1 / (10000**(tf.range(0, out_dim, 2.0) / out_dim)) + pos_offsets = tf.range(seq_length - 1., -1., -1.) + inputs = pos_offsets[:, None] * inverse_freq[None, :] + return tf.concat((tf.sin(inputs), tf.cos(inputs)), axis=-1) diff --git a/rllib/models/tf/fcnet.py b/rllib/models/tf/fcnet.py new file mode 100644 index 000000000..745639a98 --- /dev/null +++ b/rllib/models/tf/fcnet.py @@ -0,0 +1,114 @@ +import numpy as np + +from ray.rllib.models.tf.misc import normc_initializer +from ray.rllib.models.tf.tf_modelv2 import TFModelV2 +from ray.rllib.utils.framework import get_activation_fn, try_import_tf + +tf = try_import_tf() + + +class FullyConnectedNetwork(TFModelV2): + """Generic fully connected network implemented in ModelV2 API.""" + + def __init__(self, obs_space, action_space, num_outputs, model_config, + name): + super(FullyConnectedNetwork, self).__init__( + obs_space, action_space, num_outputs, model_config, name) + + activation = get_activation_fn(model_config.get("fcnet_activation")) + hiddens = model_config.get("fcnet_hiddens", []) + no_final_linear = model_config.get("no_final_linear") + vf_share_layers = model_config.get("vf_share_layers") + free_log_std = model_config.get("free_log_std") + + # Maybe generate free-floating bias variables for the second half of + # the outputs. + if free_log_std: + assert num_outputs % 2 == 0, ( + "num_outputs must be divisible by two", num_outputs) + num_outputs = num_outputs // 2 + self.log_std_var = tf.Variable( + [0.0] * num_outputs, dtype=tf.float32, name="log_std") + self.register_variables([self.log_std_var]) + + # We are using obs_flat, so take the flattened shape as input. + inputs = tf.keras.layers.Input( + shape=(np.product(obs_space.shape), ), name="observations") + last_layer = layer_out = inputs + i = 1 + + # Create layers 0 to second-last. + for size in hiddens[:-1]: + last_layer = tf.keras.layers.Dense( + size, + name="fc_{}".format(i), + activation=activation, + kernel_initializer=normc_initializer(1.0))(last_layer) + i += 1 + + # The last layer is adjusted to be of size num_outputs, but it's a + # layer with activation. + if no_final_linear and num_outputs: + layer_out = tf.keras.layers.Dense( + num_outputs, + name="fc_out", + activation=activation, + kernel_initializer=normc_initializer(1.0))(last_layer) + # Finish the layers with the provided sizes (`hiddens`), plus - + # iff num_outputs > 0 - a last linear layer of size num_outputs. + else: + if len(hiddens) > 0: + last_layer = tf.keras.layers.Dense( + hiddens[-1], + name="fc_{}".format(i), + activation=activation, + kernel_initializer=normc_initializer(1.0))(last_layer) + if num_outputs: + layer_out = tf.keras.layers.Dense( + num_outputs, + name="fc_out", + activation=None, + kernel_initializer=normc_initializer(0.01))(last_layer) + # Adjust num_outputs to be the number of nodes in the last layer. + else: + self.num_outputs = ( + [np.product(obs_space.shape)] + hiddens[-1:-1])[-1] + + # Concat the log std vars to the end of the state-dependent means. + if free_log_std: + + def tiled_log_std(x): + return tf.tile( + tf.expand_dims(self.log_std_var, 0), [tf.shape(x)[0], 1]) + + log_std_out = tf.keras.layers.Lambda(tiled_log_std)(inputs) + layer_out = tf.keras.layers.Concatenate(axis=1)( + [layer_out, log_std_out]) + + if not vf_share_layers: + # build a parallel set of hidden layers for the value net + last_layer = inputs + i = 1 + for size in hiddens: + last_layer = tf.keras.layers.Dense( + size, + name="fc_value_{}".format(i), + activation=activation, + kernel_initializer=normc_initializer(1.0))(last_layer) + i += 1 + + value_out = tf.keras.layers.Dense( + 1, + name="value_out", + activation=None, + kernel_initializer=normc_initializer(0.01))(last_layer) + + self.base_model = tf.keras.Model(inputs, [layer_out, value_out]) + self.register_variables(self.base_model.variables) + + def forward(self, input_dict, state, seq_lens): + model_out, self._value_out = self.base_model(input_dict["obs_flat"]) + return model_out, state + + def value_function(self): + return tf.reshape(self._value_out, [-1]) diff --git a/rllib/models/tf/fcnet_v1.py b/rllib/models/tf/fcnet_v1.py index b5a8b075d..54746111f 100644 --- a/rllib/models/tf/fcnet_v1.py +++ b/rllib/models/tf/fcnet_v1.py @@ -7,7 +7,7 @@ from ray.rllib.utils.framework import get_activation_fn, try_import_tf tf = try_import_tf() -# Deprecated: see as an alternative models/tf/fcnet_v2.py +# Deprecated: see as an alternative models/tf.fcnet.py class FullyConnectedNetwork(Model): """Generic fully connected network.""" diff --git a/rllib/models/tf/fcnet_v2.py b/rllib/models/tf/fcnet_v2.py index 745639a98..4e07e078e 100644 --- a/rllib/models/tf/fcnet_v2.py +++ b/rllib/models/tf/fcnet_v2.py @@ -1,114 +1,7 @@ -import numpy as np +from ray.rllib.models.tf.fcnet import FullyConnectedNetwork as TFFCNet +from ray.rllib.utils.deprecation import renamed_class -from ray.rllib.models.tf.misc import normc_initializer -from ray.rllib.models.tf.tf_modelv2 import TFModelV2 -from ray.rllib.utils.framework import get_activation_fn, try_import_tf - -tf = try_import_tf() - - -class FullyConnectedNetwork(TFModelV2): - """Generic fully connected network implemented in ModelV2 API.""" - - def __init__(self, obs_space, action_space, num_outputs, model_config, - name): - super(FullyConnectedNetwork, self).__init__( - obs_space, action_space, num_outputs, model_config, name) - - activation = get_activation_fn(model_config.get("fcnet_activation")) - hiddens = model_config.get("fcnet_hiddens", []) - no_final_linear = model_config.get("no_final_linear") - vf_share_layers = model_config.get("vf_share_layers") - free_log_std = model_config.get("free_log_std") - - # Maybe generate free-floating bias variables for the second half of - # the outputs. - if free_log_std: - assert num_outputs % 2 == 0, ( - "num_outputs must be divisible by two", num_outputs) - num_outputs = num_outputs // 2 - self.log_std_var = tf.Variable( - [0.0] * num_outputs, dtype=tf.float32, name="log_std") - self.register_variables([self.log_std_var]) - - # We are using obs_flat, so take the flattened shape as input. - inputs = tf.keras.layers.Input( - shape=(np.product(obs_space.shape), ), name="observations") - last_layer = layer_out = inputs - i = 1 - - # Create layers 0 to second-last. - for size in hiddens[:-1]: - last_layer = tf.keras.layers.Dense( - size, - name="fc_{}".format(i), - activation=activation, - kernel_initializer=normc_initializer(1.0))(last_layer) - i += 1 - - # The last layer is adjusted to be of size num_outputs, but it's a - # layer with activation. - if no_final_linear and num_outputs: - layer_out = tf.keras.layers.Dense( - num_outputs, - name="fc_out", - activation=activation, - kernel_initializer=normc_initializer(1.0))(last_layer) - # Finish the layers with the provided sizes (`hiddens`), plus - - # iff num_outputs > 0 - a last linear layer of size num_outputs. - else: - if len(hiddens) > 0: - last_layer = tf.keras.layers.Dense( - hiddens[-1], - name="fc_{}".format(i), - activation=activation, - kernel_initializer=normc_initializer(1.0))(last_layer) - if num_outputs: - layer_out = tf.keras.layers.Dense( - num_outputs, - name="fc_out", - activation=None, - kernel_initializer=normc_initializer(0.01))(last_layer) - # Adjust num_outputs to be the number of nodes in the last layer. - else: - self.num_outputs = ( - [np.product(obs_space.shape)] + hiddens[-1:-1])[-1] - - # Concat the log std vars to the end of the state-dependent means. - if free_log_std: - - def tiled_log_std(x): - return tf.tile( - tf.expand_dims(self.log_std_var, 0), [tf.shape(x)[0], 1]) - - log_std_out = tf.keras.layers.Lambda(tiled_log_std)(inputs) - layer_out = tf.keras.layers.Concatenate(axis=1)( - [layer_out, log_std_out]) - - if not vf_share_layers: - # build a parallel set of hidden layers for the value net - last_layer = inputs - i = 1 - for size in hiddens: - last_layer = tf.keras.layers.Dense( - size, - name="fc_value_{}".format(i), - activation=activation, - kernel_initializer=normc_initializer(1.0))(last_layer) - i += 1 - - value_out = tf.keras.layers.Dense( - 1, - name="value_out", - activation=None, - kernel_initializer=normc_initializer(0.01))(last_layer) - - self.base_model = tf.keras.Model(inputs, [layer_out, value_out]) - self.register_variables(self.base_model.variables) - - def forward(self, input_dict, state, seq_lens): - model_out, self._value_out = self.base_model(input_dict["obs_flat"]) - return model_out, state - - def value_function(self): - return tf.reshape(self._value_out, [-1]) +FullyConnectedNetwork = renamed_class( + cls=TFFCNet, + old_name="ray.rllib.models.tf.fcnet_v2.FullyConnectedNetwork", +) diff --git a/rllib/models/tf/layers/__init__.py b/rllib/models/tf/layers/__init__.py new file mode 100644 index 000000000..54b82ec1e --- /dev/null +++ b/rllib/models/tf/layers/__init__.py @@ -0,0 +1,6 @@ +from ray.rllib.models.tf.layers.gru_gate import GRUGate +from ray.rllib.models.tf.layers.relative_multi_head_attention import \ + RelativeMultiHeadAttention +from ray.rllib.models.tf.layers.skip_connection import SkipConnection + +__all__ = ["GRUGate", "RelativeMultiHeadAttention", "SkipConnection"] diff --git a/rllib/models/tf/layers/gru_gate.py b/rllib/models/tf/layers/gru_gate.py new file mode 100644 index 000000000..f738626a8 --- /dev/null +++ b/rllib/models/tf/layers/gru_gate.py @@ -0,0 +1,48 @@ +from ray.rllib.utils.framework import try_import_tf + +tf = try_import_tf() + + +class GRUGate(tf.keras.layers.Layer): + def __init__(self, init_bias=0., **kwargs): + super().__init__(**kwargs) + self._init_bias = init_bias + + def build(self, input_shape): + h_shape, x_shape = input_shape + if x_shape[-1] != h_shape[-1]: + raise ValueError( + "Both inputs to GRUGate must have equal size in last axis!") + + dim = int(h_shape[-1]) + self._w_r = self.add_weight(shape=(dim, dim)) + self._w_z = self.add_weight(shape=(dim, dim)) + self._w_h = self.add_weight(shape=(dim, dim)) + + self._u_r = self.add_weight(shape=(dim, dim)) + self._u_z = self.add_weight(shape=(dim, dim)) + self._u_h = self.add_weight(shape=(dim, dim)) + + def bias_initializer(shape, dtype): + return tf.fill(shape, tf.cast(self._init_bias, dtype=dtype)) + + self._bias_z = self.add_weight( + shape=(dim, ), initializer=bias_initializer) + + def call(self, inputs, **kwargs): + # Pass in internal state first. + h, X = inputs + + r = tf.tensordot(X, self._w_r, axes=1) + \ + tf.tensordot(h, self._u_r, axes=1) + r = tf.nn.sigmoid(r) + + z = tf.tensordot(X, self._w_z, axes=1) + \ + tf.tensordot(h, self._u_z, axes=1) - self._bias_z + z = tf.nn.sigmoid(z) + + h_next = tf.tensordot(X, self._w_h, axes=1) + \ + tf.tensordot((h * r), self._u_h, axes=1) + h_next = tf.nn.tanh(h_next) + + return (1 - z) * h + z * h_next diff --git a/rllib/models/tf/layers/multi_head_attention.py b/rllib/models/tf/layers/multi_head_attention.py new file mode 100644 index 000000000..074f53c6a --- /dev/null +++ b/rllib/models/tf/layers/multi_head_attention.py @@ -0,0 +1,51 @@ +""" +[1] - Attention Is All You Need - Vaswani, Jones, Shazeer, Parmar, + Uszkoreit, Gomez, Kaiser - Google Brain/Research, U Toronto - 2017. + https://arxiv.org/pdf/1706.03762.pdf +""" +from ray.rllib.utils.framework import try_import_tf + +tf = try_import_tf() + + +class MultiHeadAttention(tf.keras.layers.Layer): + """A multi-head attention layer described in [1].""" + + def __init__(self, out_dim, num_heads, head_dim, **kwargs): + super().__init__(**kwargs) + + # No bias or non-linearity. + self._num_heads = num_heads + self._head_dim = head_dim + self._qkv_layer = tf.keras.layers.Dense( + 3 * num_heads * head_dim, use_bias=False) + self._linear_layer = tf.keras.layers.TimeDistributed( + tf.keras.layers.Dense(out_dim, use_bias=False)) + + def call(self, inputs): + L = tf.shape(inputs)[1] # length of segment + H = self._num_heads # number of attention heads + D = self._head_dim # attention head dimension + + qkv = self._qkv_layer(inputs) + + queries, keys, values = tf.split(qkv, 3, -1) + queries = queries[:, -L:] # only query based on the segment + + queries = tf.reshape(queries, [-1, L, H, D]) + keys = tf.reshape(keys, [-1, L, H, D]) + values = tf.reshape(values, [-1, L, H, D]) + + score = tf.einsum("bihd,bjhd->bijh", queries, keys) + score = score / D**0.5 + + # causal mask of the same length as the sequence + mask = tf.sequence_mask(tf.range(1, L + 1), dtype=score.dtype) + mask = mask[None, :, :, None] + + masked_score = score * mask + 1e30 * (mask - 1.) + wmat = tf.nn.softmax(masked_score, axis=2) + + out = tf.einsum("bijh,bjhd->bihd", wmat, values) + out = tf.reshape(out, tf.concat((tf.shape(out)[:2], [H * D]), axis=0)) + return self._linear_layer(out) diff --git a/rllib/models/tf/layers/relative_multi_head_attention.py b/rllib/models/tf/layers/relative_multi_head_attention.py new file mode 100644 index 000000000..f9837e5f1 --- /dev/null +++ b/rllib/models/tf/layers/relative_multi_head_attention.py @@ -0,0 +1,119 @@ +from ray.rllib.utils.framework import try_import_tf + +tf = try_import_tf() + + +class RelativeMultiHeadAttention(tf.keras.layers.Layer): + """A RelativeMultiHeadAttention layer as described in [3]. + + Uses segment level recurrence with state reuse. + """ + + def __init__(self, + out_dim, + num_heads, + head_dim, + rel_pos_encoder, + input_layernorm=False, + output_activation=None, + **kwargs): + """Initializes a RelativeMultiHeadAttention keras Layer object. + + Args: + out_dim (int): + num_heads (int): The number of attention heads to use. + Denoted `H` in [2]. + head_dim (int): The dimension of a single(!) attention head + Denoted `D` in [2]. + rel_pos_encoder (: + input_layernorm (bool): Whether to prepend a LayerNorm before + everything else. Should be True for building a GTrXL. + output_activation (Optional[tf.nn.activation]): Optional tf.nn + activation function. Should be relu for GTrXL. + **kwargs: + """ + super().__init__(**kwargs) + + # No bias or non-linearity. + self._num_heads = num_heads + self._head_dim = head_dim + # 3=Query, key, and value inputs. + self._qkv_layer = tf.keras.layers.Dense( + 3 * num_heads * head_dim, use_bias=False) + self._linear_layer = tf.keras.layers.TimeDistributed( + tf.keras.layers.Dense( + out_dim, use_bias=False, activation=output_activation)) + + self._uvar = self.add_weight(shape=(num_heads, head_dim)) + self._vvar = self.add_weight(shape=(num_heads, head_dim)) + + self._pos_proj = tf.keras.layers.Dense( + num_heads * head_dim, use_bias=False) + self._rel_pos_encoder = rel_pos_encoder + + self._input_layernorm = None + if input_layernorm: + self._input_layernorm = tf.keras.layers.LayerNormalization(axis=-1) + + def call(self, inputs, memory=None): + T = tf.shape(inputs)[1] # length of segment (time) + H = self._num_heads # number of attention heads + d = self._head_dim # attention head dimension + + # Add previous memory chunk (as const, w/o gradient) to input. + # Tau (number of (prev) time slices in each memory chunk). + Tau = memory.shape.as_list()[1] if memory is not None else 0 + if memory is not None: + inputs = tf.concat((tf.stop_gradient(memory), inputs), axis=1) + + # Apply the Layer-Norm. + if self._input_layernorm is not None: + inputs = self._input_layernorm(inputs) + + qkv = self._qkv_layer(inputs) + + queries, keys, values = tf.split(qkv, 3, -1) + # Cut out Tau memory timesteps from query. + queries = queries[:, -T:] + + queries = tf.reshape(queries, [-1, T, H, d]) + keys = tf.reshape(keys, [-1, T + Tau, H, d]) + values = tf.reshape(values, [-1, T + Tau, H, d]) + + R = self._pos_proj(self._rel_pos_encoder) + R = tf.reshape(R, [T + Tau, H, d]) + + # b=batch + # i and j=time indices (i=max-timesteps (inputs); j=Tau memory space) + # h=head + # d=head-dim (over which we will reduce-sum) + score = tf.einsum("bihd,bjhd->bijh", queries + self._uvar, keys) + pos_score = tf.einsum("bihd,jhd->bijh", queries + self._vvar, R) + score = score + self.rel_shift(pos_score) + score = score / d**0.5 + + # causal mask of the same length as the sequence + mask = tf.sequence_mask( + tf.range(Tau + 1, T + Tau + 1), dtype=score.dtype) + mask = mask[None, :, :, None] + + masked_score = score * mask + 1e30 * (mask - 1.) + wmat = tf.nn.softmax(masked_score, axis=2) + + out = tf.einsum("bijh,bjhd->bihd", wmat, values) + out = tf.reshape(out, tf.concat((tf.shape(out)[:2], [H * d]), axis=0)) + return self._linear_layer(out) + + @staticmethod + def rel_shift(x): + # Transposed version of the shift approach described in [3]. + # https://github.com/kimiyoung/transformer-xl/blob/ + # 44781ed21dbaec88b280f74d9ae2877f52b492a5/tf/model.py#L31 + x_size = tf.shape(x) + + x = tf.pad(x, [[0, 0], [0, 0], [1, 0], [0, 0]]) + x = tf.reshape(x, [x_size[0], x_size[2] + 1, x_size[1], x_size[3]]) + x = tf.slice(x, [0, 1, 0, 0], [-1, -1, -1, -1]) + x = tf.reshape(x, x_size) + + return x diff --git a/rllib/models/tf/layers/skip_connection.py b/rllib/models/tf/layers/skip_connection.py new file mode 100644 index 000000000..f56c7b9ac --- /dev/null +++ b/rllib/models/tf/layers/skip_connection.py @@ -0,0 +1,37 @@ +from ray.rllib.utils.framework import try_import_tf + +tf = try_import_tf() + + +class SkipConnection(tf.keras.layers.Layer): + """Skip connection layer. + + Adds the original input to the output (regular residual layer) OR uses + input as hidden state input to a given fan_in_layer. + """ + + def __init__(self, layer, fan_in_layer=None, add_memory=False, **kwargs): + """Initializes a SkipConnection keras layer object. + + Args: + layer (tf.keras.layers.Layer): Any layer processing inputs. + fan_in_layer (Optional[tf.keras.layers.Layer]): An optional + layer taking two inputs: The original input and the output + of `layer`. + """ + super().__init__(**kwargs) + self._layer = layer + self._fan_in_layer = fan_in_layer + + def call(self, inputs, **kwargs): + # del kwargs + outputs = self._layer(inputs, **kwargs) + # Residual case, just add inputs to outputs. + if self._fan_in_layer is None: + outputs = outputs + inputs + # Fan-in e.g. RNN: Call fan-in with `inputs` and `outputs`. + else: + # NOTE: In the GRU case, `inputs` is the state input. + outputs = self._fan_in_layer((inputs, outputs)) + + return outputs diff --git a/rllib/models/tf/lstm_v1.py b/rllib/models/tf/lstm_v1.py index 972e9aedd..c1889340f 100644 --- a/rllib/models/tf/lstm_v1.py +++ b/rllib/models/tf/lstm_v1.py @@ -10,7 +10,7 @@ from ray.rllib.utils.framework import try_import_tf tf = try_import_tf() -# Deprecated: see as an alternative models/tf/recurrent_tf_modelv2.py +# Deprecated: see as an alternative models/tf/recurrent_net.py class LSTM(Model): """Adds a LSTM cell on top of some other model output. @@ -24,7 +24,7 @@ class LSTM(Model): def _build_layers_v2(self, input_dict, num_outputs, options): # Hard deprecate this class. All Models should use the ModelV2 # API from here on. - deprecation_warning("Model->LSTM", "RecurrentTFModelV2", error=False) + deprecation_warning("Model->LSTM", "RecurrentNetwork", error=False) cell_size = options.get("lstm_cell_size") if options.get("lstm_use_prev_action_reward"): diff --git a/rllib/models/tf/recurrent_tf_modelv2.py b/rllib/models/tf/recurrent_net.py similarity index 97% rename from rllib/models/tf/recurrent_tf_modelv2.py rename to rllib/models/tf/recurrent_net.py index 1dd8a43ed..f1abca6b8 100644 --- a/rllib/models/tf/recurrent_tf_modelv2.py +++ b/rllib/models/tf/recurrent_net.py @@ -8,14 +8,14 @@ tf = try_import_tf() @DeveloperAPI -class RecurrentTFModelV2(TFModelV2): +class RecurrentNetwork(TFModelV2): """Helper class to simplify implementing RNN models with TFModelV2. Instead of implementing forward(), you can implement forward_rnn() which takes batches with the time dimension added already. Here is an example implementation for a subclass - ``MyRNNClass(RecurrentTFModelV2)``:: + ``MyRNNClass(RecurrentNetwork)``:: def __init__(self, *args, **kwargs): super(MyModelClass, self).__init__(*args, **kwargs) @@ -50,6 +50,8 @@ class RecurrentTFModelV2(TFModelV2): """Adds time dimension to batch before sending inputs to forward_rnn(). You should implement forward_rnn() in your subclass.""" + assert seq_lens is not None + output, new_state = self.forward_rnn( add_time_dimension( input_dict["obs_flat"], seq_lens, framework="tf"), state, diff --git a/rllib/models/tf/recurrent_tf_model_v2.py b/rllib/models/tf/recurrent_tf_model_v2.py new file mode 100644 index 000000000..de99edb96 --- /dev/null +++ b/rllib/models/tf/recurrent_tf_model_v2.py @@ -0,0 +1,7 @@ +from ray.rllib.models.tf.recurrent_net import RecurrentNetwork +from ray.rllib.utils.deprecation import renamed_class + +RecurrentTFModelV2 = renamed_class( + cls=RecurrentNetwork, + old_name="ray.rllib.models.tf.recurrent_tf_model_v2.RecurrentTFModelV2", +) diff --git a/rllib/models/tf/visionnet.py b/rllib/models/tf/visionnet.py new file mode 100644 index 000000000..4887d04b7 --- /dev/null +++ b/rllib/models/tf/visionnet.py @@ -0,0 +1,118 @@ +from ray.rllib.models.tf.tf_modelv2 import TFModelV2 +from ray.rllib.models.tf.visionnet_v1 import _get_filter_config +from ray.rllib.models.tf.misc import normc_initializer +from ray.rllib.utils.framework import get_activation_fn, try_import_tf + +tf = try_import_tf() + + +class VisionNetwork(TFModelV2): + """Generic vision network implemented in ModelV2 API.""" + + def __init__(self, obs_space, action_space, num_outputs, model_config, + name): + super(VisionNetwork, self).__init__(obs_space, action_space, + num_outputs, model_config, name) + + activation = get_activation_fn(model_config.get("conv_activation")) + filters = model_config.get("conv_filters") + if not filters: + filters = _get_filter_config(obs_space.shape) + no_final_linear = model_config.get("no_final_linear") + vf_share_layers = model_config.get("vf_share_layers") + + inputs = tf.keras.layers.Input( + shape=obs_space.shape, name="observations") + last_layer = inputs + + # Build the action layers + for i, (out_size, kernel, stride) in enumerate(filters[:-1], 1): + last_layer = tf.keras.layers.Conv2D( + out_size, + kernel, + strides=(stride, stride), + activation=activation, + padding="same", + data_format="channels_last", + name="conv{}".format(i))(last_layer) + out_size, kernel, stride = filters[-1] + + # No final linear: Last layer is a Conv2D and uses num_outputs. + if no_final_linear: + last_layer = tf.keras.layers.Conv2D( + num_outputs, + kernel, + strides=(stride, stride), + activation=activation, + padding="valid", + data_format="channels_last", + name="conv_out")(last_layer) + conv_out = last_layer + # Finish network normally (w/o overriding last layer size with + # `num_outputs`), then add another linear one of size `num_outputs`. + else: + last_layer = tf.keras.layers.Conv2D( + out_size, + kernel, + strides=(stride, stride), + activation=activation, + padding="valid", + data_format="channels_last", + name="conv{}".format(i + 1))(last_layer) + conv_out = tf.keras.layers.Conv2D( + num_outputs, [1, 1], + activation=None, + padding="same", + data_format="channels_last", + name="conv_out")(last_layer) + + # Build the value layers + if vf_share_layers: + last_layer = tf.keras.layers.Lambda( + lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer) + value_out = tf.keras.layers.Dense( + 1, + name="value_out", + activation=None, + kernel_initializer=normc_initializer(0.01))(last_layer) + else: + # build a parallel set of hidden layers for the value net + last_layer = inputs + for i, (out_size, kernel, stride) in enumerate(filters[:-1], 1): + last_layer = tf.keras.layers.Conv2D( + out_size, + kernel, + strides=(stride, stride), + activation=activation, + padding="same", + data_format="channels_last", + name="conv_value_{}".format(i))(last_layer) + out_size, kernel, stride = filters[-1] + last_layer = tf.keras.layers.Conv2D( + out_size, + kernel, + strides=(stride, stride), + activation=activation, + padding="valid", + data_format="channels_last", + name="conv_value_{}".format(i + 1))(last_layer) + last_layer = tf.keras.layers.Conv2D( + 1, [1, 1], + activation=None, + padding="same", + data_format="channels_last", + name="conv_value_out")(last_layer) + value_out = tf.keras.layers.Lambda( + lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer) + + self.base_model = tf.keras.Model(inputs, [conv_out, value_out]) + self.register_variables(self.base_model.variables) + + def forward(self, input_dict, state, seq_lens): + # explicit cast to float32 needed in eager + model_out, self._value_out = self.base_model( + tf.cast(input_dict["obs"], tf.float32)) + return tf.squeeze(model_out, axis=[1, 2]), state + + def value_function(self): + return tf.reshape(self._value_out, [-1]) diff --git a/rllib/models/tf/visionnet_v1.py b/rllib/models/tf/visionnet_v1.py index 02d5328ec..539e84e9c 100644 --- a/rllib/models/tf/visionnet_v1.py +++ b/rllib/models/tf/visionnet_v1.py @@ -7,7 +7,7 @@ from ray.rllib.utils.framework import get_activation_fn, try_import_tf tf = try_import_tf() -# Deprecated: see as an alternative models/tf/visionnet_v2.py +# Deprecated: see as an alternative models/tf.visionnet.py class VisionNetwork(Model): """Generic vision network.""" diff --git a/rllib/models/tf/visionnet_v2.py b/rllib/models/tf/visionnet_v2.py index 4887d04b7..1f1e2db92 100644 --- a/rllib/models/tf/visionnet_v2.py +++ b/rllib/models/tf/visionnet_v2.py @@ -1,118 +1,7 @@ -from ray.rllib.models.tf.tf_modelv2 import TFModelV2 -from ray.rllib.models.tf.visionnet_v1 import _get_filter_config -from ray.rllib.models.tf.misc import normc_initializer -from ray.rllib.utils.framework import get_activation_fn, try_import_tf +from ray.rllib.models.tf.vision_net import VisionNetwork as TFVision +from ray.rllib.utils.deprecation import renamed_class -tf = try_import_tf() - - -class VisionNetwork(TFModelV2): - """Generic vision network implemented in ModelV2 API.""" - - def __init__(self, obs_space, action_space, num_outputs, model_config, - name): - super(VisionNetwork, self).__init__(obs_space, action_space, - num_outputs, model_config, name) - - activation = get_activation_fn(model_config.get("conv_activation")) - filters = model_config.get("conv_filters") - if not filters: - filters = _get_filter_config(obs_space.shape) - no_final_linear = model_config.get("no_final_linear") - vf_share_layers = model_config.get("vf_share_layers") - - inputs = tf.keras.layers.Input( - shape=obs_space.shape, name="observations") - last_layer = inputs - - # Build the action layers - for i, (out_size, kernel, stride) in enumerate(filters[:-1], 1): - last_layer = tf.keras.layers.Conv2D( - out_size, - kernel, - strides=(stride, stride), - activation=activation, - padding="same", - data_format="channels_last", - name="conv{}".format(i))(last_layer) - out_size, kernel, stride = filters[-1] - - # No final linear: Last layer is a Conv2D and uses num_outputs. - if no_final_linear: - last_layer = tf.keras.layers.Conv2D( - num_outputs, - kernel, - strides=(stride, stride), - activation=activation, - padding="valid", - data_format="channels_last", - name="conv_out")(last_layer) - conv_out = last_layer - # Finish network normally (w/o overriding last layer size with - # `num_outputs`), then add another linear one of size `num_outputs`. - else: - last_layer = tf.keras.layers.Conv2D( - out_size, - kernel, - strides=(stride, stride), - activation=activation, - padding="valid", - data_format="channels_last", - name="conv{}".format(i + 1))(last_layer) - conv_out = tf.keras.layers.Conv2D( - num_outputs, [1, 1], - activation=None, - padding="same", - data_format="channels_last", - name="conv_out")(last_layer) - - # Build the value layers - if vf_share_layers: - last_layer = tf.keras.layers.Lambda( - lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer) - value_out = tf.keras.layers.Dense( - 1, - name="value_out", - activation=None, - kernel_initializer=normc_initializer(0.01))(last_layer) - else: - # build a parallel set of hidden layers for the value net - last_layer = inputs - for i, (out_size, kernel, stride) in enumerate(filters[:-1], 1): - last_layer = tf.keras.layers.Conv2D( - out_size, - kernel, - strides=(stride, stride), - activation=activation, - padding="same", - data_format="channels_last", - name="conv_value_{}".format(i))(last_layer) - out_size, kernel, stride = filters[-1] - last_layer = tf.keras.layers.Conv2D( - out_size, - kernel, - strides=(stride, stride), - activation=activation, - padding="valid", - data_format="channels_last", - name="conv_value_{}".format(i + 1))(last_layer) - last_layer = tf.keras.layers.Conv2D( - 1, [1, 1], - activation=None, - padding="same", - data_format="channels_last", - name="conv_value_out")(last_layer) - value_out = tf.keras.layers.Lambda( - lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer) - - self.base_model = tf.keras.Model(inputs, [conv_out, value_out]) - self.register_variables(self.base_model.variables) - - def forward(self, input_dict, state, seq_lens): - # explicit cast to float32 needed in eager - model_out, self._value_out = self.base_model( - tf.cast(input_dict["obs"], tf.float32)) - return tf.squeeze(model_out, axis=[1, 2]), state - - def value_function(self): - return tf.reshape(self._value_out, [-1]) +VisionNetwork = renamed_class( + cls=TFVision, + old_name="ray.rllib.models.tf.visionnet_v2.VisionNetwork", +) diff --git a/rllib/models/torch/__init__.py b/rllib/models/torch/__init__.py index c7575920c..36471e586 100644 --- a/rllib/models/torch/__init__.py +++ b/rllib/models/torch/__init__.py @@ -2,13 +2,13 @@ # dependencies b/c of that. # from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 # from ray.rllib.models.torch.fcnet import FullyConnectedNetwork -# from ray.rllib.models.torch.recurrent_torch_model import \ -# RecurrentTorchModel +# from ray.rllib.models.torch.recurrent_net import \ +# RecurrentNetwork # from ray.rllib.models.torch.visionnet import VisionNetwork # __all__ = [ # "FullyConnectedNetwork", -# "RecurrentTorchModel", +# "RecurrentNetwork", # "TorchModelV2", # "VisionNetwork", # ] diff --git a/rllib/models/torch/attention_net.py b/rllib/models/torch/attention_net.py new file mode 100644 index 000000000..e69de29bb diff --git a/rllib/models/torch/recurrent_torch_model.py b/rllib/models/torch/recurrent_net.py similarity index 96% rename from rllib/models/torch/recurrent_torch_model.py rename to rllib/models/torch/recurrent_net.py index dce8eb95d..934f89402 100644 --- a/rllib/models/torch/recurrent_torch_model.py +++ b/rllib/models/torch/recurrent_net.py @@ -10,14 +10,14 @@ torch, nn = try_import_torch() @DeveloperAPI -class RecurrentTorchModel(TorchModelV2, nn.Module): +class RecurrentNetwork(TorchModelV2, nn.Module): """Helper class to simplify implementing RNN models with TorchModelV2. Instead of implementing forward(), you can implement forward_rnn() which takes batches with the time dimension added already. Here is an example implementation for a subclass - ``MyRNNClass(nn.Module, RecurrentTorchModel)``:: + ``MyRNNClass(nn.Module, RecurrentNetwork)``:: def __init__(self, obs_space, num_outputs): self.obs_size = _get_size(obs_space) @@ -41,7 +41,7 @@ class RecurrentTorchModel(TorchModelV2, nn.Module): assert self._cur_value is not None, "must call forward() first" return self._cur_value - @override(RecurrentTorchModel) + @override(RecurrentNetwork) def forward_rnn(self, input_dict, state, seq_lens): x = nn.functional.relu(self.fc1(input_dict["obs_flat"].float())) h_in = state[0].reshape(-1, self.rnn_hidden_dim) diff --git a/rllib/policy/dynamic_tf_policy.py b/rllib/policy/dynamic_tf_policy.py index e301c2084..072181e62 100644 --- a/rllib/policy/dynamic_tf_policy.py +++ b/rllib/policy/dynamic_tf_policy.py @@ -160,11 +160,8 @@ class DynamicTFPolicy(TFPolicy): action_space=action_space, num_outputs=logit_dim, model_config=self.config["model"], - framework="tf") - # NOTE: Adding below line will break existing custom models - # that do not expect extra options in **kwargs but rather in - # model_config["custom_options"]. - # **self.config["model"].get("custom_options", {})) + framework="tf", + **self.config["model"].get("custom_options", {})) # Create the Exploration object to use for this Policy. self.exploration = self._create_exploration() diff --git a/rllib/tests/test_catalog.py b/rllib/tests/test_catalog.py index c7f8ab7d0..f80a13a81 100644 --- a/rllib/tests/test_catalog.py +++ b/rllib/tests/test_catalog.py @@ -9,8 +9,8 @@ from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.models.tf.tf_action_dist import TFActionDistribution from ray.rllib.models.preprocessors import (NoPreprocessor, OneHotPreprocessor, Preprocessor) -from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork -from ray.rllib.models.tf.visionnet_v2 import VisionNetwork +from ray.rllib.models.tf.fcnet import FullyConnectedNetwork +from ray.rllib.models.tf.visionnet import VisionNetwork from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf diff --git a/rllib/tests/test_lstm.py b/rllib/tests/test_lstm.py index bb3246e04..a2120c516 100644 --- a/rllib/tests/test_lstm.py +++ b/rllib/tests/test_lstm.py @@ -1,20 +1,14 @@ -import gym import numpy as np import pickle import unittest import ray from ray.rllib.agents.ppo import PPOTrainer -from ray.rllib.policy.rnn_sequencing import chop_into_sequences +from ray.rllib.examples.env.debug_counter_env import DebugCounterEnv +from ray.rllib.examples.models.rnn_spy_model import RNNSpyModel from ray.rllib.models import ModelCatalog -from ray.rllib.models.modelv2 import ModelV2 -from ray.rllib.models.tf.misc import normc_initializer -from ray.rllib.models.tf.recurrent_tf_modelv2 import RecurrentTFModelV2 +from ray.rllib.policy.rnn_sequencing import chop_into_sequences from ray.tune.registry import register_env -from ray.rllib.utils import try_import_tf -from ray.rllib.utils.annotations import override - -tf = try_import_tf() class TestLSTMUtils(unittest.TestCase): @@ -92,125 +86,6 @@ class TestLSTMUtils(unittest.TestCase): self.assertEqual(seq_lens.tolist(), [1, 2]) -class RNNSpyModel(RecurrentTFModelV2): - capture_index = 0 - cell_size = 3 - - def __init__(self, obs_space, action_space, num_outputs, model_config, - name): - super().__init__(obs_space, action_space, num_outputs, model_config, - name) - self.cell_size = RNNSpyModel.cell_size - - def spy(inputs, seq_lens, h_in, c_in, h_out, c_out): - if len(inputs) == 1: - return 0 # don't capture inference inputs - # TF runs this function in an isolated context, so we have to use - # redis to communicate back to our suite - ray.experimental.internal_kv._internal_kv_put( - "rnn_spy_in_{}".format(RNNSpyModel.capture_index), - pickle.dumps({ - "sequences": inputs, - "seq_lens": seq_lens, - "state_in": [h_in, c_in], - "state_out": [h_out, c_out] - }), - overwrite=True) - RNNSpyModel.capture_index += 1 - return 0 - - # Create a keras LSTM model. - inputs = tf.keras.layers.Input( - shape=(None, ) + obs_space.shape, name="input") - state_in_h = tf.keras.layers.Input(shape=(self.cell_size, ), name="h") - state_in_c = tf.keras.layers.Input(shape=(self.cell_size, ), name="c") - seq_lens = tf.keras.layers.Input( - shape=(), name="seq_lens", dtype=tf.int32) - - lstm_out, state_out_h, state_out_c = tf.keras.layers.LSTM( - self.cell_size, - return_sequences=True, - return_state=True, - name="lstm")( - inputs=inputs, - mask=tf.sequence_mask(seq_lens), - initial_state=[state_in_h, state_in_c]) - self.dense = tf.keras.layers.Dense( - units=self.num_outputs, kernel_initializer=normc_initializer(0.01)) - - def lambda_(inputs): - spy_fn = tf.py_func( - spy, - [ - inputs[0], # observations - inputs[2], # seq_lens - inputs[3], # h_in - inputs[4], # c_in - inputs[5], # h_out - inputs[6], # c_out - ], - tf.int64, - stateful=True) - - # Compute outputs - with tf.control_dependencies([spy_fn]): - return self.dense(inputs[1]) # lstm_out - - logits = tf.keras.layers.Lambda(lambda_)([ - inputs, lstm_out, seq_lens, state_in_h, state_in_c, state_out_h, - state_out_c - ]) - - # Value branch. - value_out = tf.keras.layers.Dense( - units=1, kernel_initializer=normc_initializer(1.0))(lstm_out) - - self.base_model = tf.keras.Model( - [inputs, seq_lens, state_in_h, state_in_c], - [logits, value_out, state_out_h, state_out_c]) - self.base_model.summary() - self.register_variables(self.base_model.variables) - - @override(RecurrentTFModelV2) - def forward_rnn(self, inputs, state, seq_lens): - # Previously, a new class object was created during - # deserialization and this `capture_index` - # variable would be refreshed between class instantiations. - # This behavior is no longer the case, so we manually refresh - # the variable. - RNNSpyModel.capture_index = 0 - model_out, value_out, h, c = self.base_model( - [inputs, seq_lens, state[0], state[1]]) - self._value_out = value_out - return model_out, [h, c] - - @override(ModelV2) - def value_function(self): - return tf.reshape(self._value_out, [-1]) - - @override(ModelV2) - def get_initial_state(self): - return [ - np.zeros(self.cell_size, np.float32), - np.zeros(self.cell_size, np.float32) - ] - - -class DebugCounterEnv(gym.Env): - def __init__(self): - self.action_space = gym.spaces.Discrete(2) - self.observation_space = gym.spaces.Box(0, 100, (1, )) - self.i = 0 - - def reset(self): - self.i = 0 - return [self.i] - - def step(self, action): - self.i += 1 - return [self.i], self.i % 3, self.i >= 15, {} - - class TestRNNSequencing(unittest.TestCase): def setUp(self) -> None: ray.init(num_cpus=4) diff --git a/rllib/tests/test_supported_spaces.py b/rllib/tests/test_supported_spaces.py index 5c850e736..0097a8240 100644 --- a/rllib/tests/test_supported_spaces.py +++ b/rllib/tests/test_supported_spaces.py @@ -9,8 +9,8 @@ from ray.rllib.agents.registry import get_agent_class from ray.rllib.examples.env.multi_agent import MultiAgentCartPole, \ MultiAgentMountainCar from ray.rllib.examples.env.random_env import RandomEnv -from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork as FCNetV2 -from ray.rllib.models.tf.visionnet_v2 import VisionNetwork as VisionNetV2 +from ray.rllib.models.tf.fcnet import FullyConnectedNetwork as FCNetV2 +from ray.rllib.models.tf.visionnet import VisionNetwork as VisionNetV2 from ray.rllib.models.torch.visionnet import VisionNetwork as TorchVisionNetV2 from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFCNetV2 from ray.rllib.utils.error import UnsupportedSpaceException @@ -253,11 +253,12 @@ class ModelSupportedSpaces(unittest.TestCase): }) def test_ddpg_multiagent(self): - check_support_multiagent("DDPG", { - "timesteps_per_iteration": 1, - "use_state_preprocessor": True, - "learning_starts": 500, - }) + check_support_multiagent( + "DDPG", { + "timesteps_per_iteration": 1, + "use_state_preprocessor": True, + "learning_starts": 500, + }) def test_dqn_multiagent(self): check_support_multiagent("DQN", {"timesteps_per_iteration": 1})