[RLlib] Attention Net integration into ModelV2 and learning RL example. (#8371)

This commit is contained in:
Sven Mika
2020-05-18 17:26:40 +02:00
committed by GitHub
parent 9347a5d10c
commit 796a834c48
44 changed files with 1279 additions and 911 deletions
+6 -6
View File
@@ -149,7 +149,7 @@ matrix:
- os: linux
env:
- RLLIB_TESTING=1 RLLIB_REGRESSION_TESTS=1
- TF_VERSION=2.0.0b1
- TF_VERSION=2.1.0
- TFP_VERSION=0.8
- TORCH_VERSION=1.4
- PYTHON=3.6
@@ -182,7 +182,7 @@ matrix:
- os: linux
env:
- RLLIB_TESTING=1 RLLIB_REGRESSION_TESTS_TORCH=1
- TF_VERSION=2.0.0b1
- TF_VERSION=2.1.0
- TFP_VERSION=0.8
- TORCH_VERSION=1.4
- PYTHON=3.6
@@ -200,7 +200,7 @@ matrix:
env:
- RLLIB_TESTING=1 RLLIB_QUICK_TRAIN_AND_MISC_TESTS=1
- PYTHON=3.6
- TF_VERSION=2.0.0b1
- TF_VERSION=2.1.0
- TFP_VERSION=0.8
- TORCH_VERSION=1.4
- PYTHONWARNINGS=ignore
@@ -220,7 +220,7 @@ matrix:
env:
- RLLIB_TESTING=1 RLLIB_EXAMPLE_DIR_TESTS=1
- PYTHON=3.6
- TF_VERSION=2.0.0b1
- TF_VERSION=2.1.0
- TFP_VERSION=0.8
- TORCH_VERSION=1.4
- PYTHONWARNINGS=ignore
@@ -239,7 +239,7 @@ matrix:
env:
- RLLIB_TESTING=1 RLLIB_TESTS_DIR_TESTS_A_TO_L=1
- PYTHON=3.6
- TF_VERSION=2.0.0b1
- TF_VERSION=2.1.0
- TFP_VERSION=0.8
- TORCH_VERSION=1.4
- PYTHONWARNINGS=ignore
@@ -255,7 +255,7 @@ matrix:
env:
- RLLIB_TESTING=1 RLLIB_TESTS_DIR_TESTS_M_TO_Z=1
- PYTHON=3.6
- TF_VERSION=2.0.0b1
- TF_VERSION=2.1.0
- TFP_VERSION=0.8
- TORCH_VERSION=1.4
- PYTHONWARNINGS=ignore
+1 -1
View File
@@ -209,7 +209,7 @@ install_dependencies() {
msys*) pip install torch==1.5.0+cpu torchvision==0.6.0+cpu -f "${torch_url}";;
esac
pip_packages=(scipy tensorflow=="${TF_VERSION:-2.0.0b1}" cython==0.29.0 gym \
pip_packages=(scipy tensorflow=="${TF_VERSION:-2.1.0}" cython==0.29.0 gym \
opencv-python-headless pyyaml pandas==0.24.2 requests feather-format lxml openpyxl xlrd \
py-spy pytest pytest-timeout networkx tabulate aiohttp uvicorn dataclasses pygments werkzeug \
kubernetes flask grpcio pytest-sugar pytest-rerunfailures pytest-asyncio scikit-learn==0.22.2 numba \
+2 -2
View File
@@ -80,9 +80,9 @@ For a full example of a custom model in code, see the `keras model example <http
Recurrent Models
~~~~~~~~~~~~~~~~
Instead of using the ``use_lstm: True`` option, it can be preferable use a custom recurrent model. This provides more control over postprocessing of the LSTM output and can also allow the use of multiple LSTM cells to process different portions of the input. For a RNN model it is preferred to subclass ``RecurrentTFModelV2`` to implement ``__init__()``, ``get_initial_state()``, and ``forward_rnn()``. You can check out the `custom_keras_rnn_model.py <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_keras_rnn_model.py>`__ model as an example to implement your own model:
Instead of using the ``use_lstm: True`` option, it can be preferable use a custom recurrent model. This provides more control over postprocessing of the LSTM output and can also allow the use of multiple LSTM cells to process different portions of the input. For an RNN model it is preferred to subclass ``RecurrentNetwork`` to implement ``__init__()``, ``get_initial_state()``, and ``forward_rnn()``. You can check out the `custom_rnn_model.py <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_rnn_model.py>`__ model as an example to implement your own model:
.. autoclass:: ray.rllib.models.tf.recurrent_tf_modelv2.RecurrentTFModelV2
.. autoclass:: ray.rllib.models.tf.recurrent_net.RecurrentNetwork
.. automethod:: __init__
.. automethod:: forward_rnn
+23 -4
View File
@@ -1452,6 +1452,25 @@ py_test(
# --------------------------------------------------------------------
py_test(
name = "examples/attention_net_tf",
main = "examples/attention_net.py",
tags = ["examples", "examples_A"],
size = "large",
srcs = ["examples/attention_net.py"],
args = ["--as-test", "--stop-reward=80"]
)
# TODO(sven): GTrXL PyTorch.
# py_test(
# name = "examples/attention_net_torch",
# main = "examples/attention_net.py",
# tags = ["examples", "examples_A"],
# size = "large",
# srcs = ["examples/attention_net.py"],
# args = ["--as-test", "--torch", "--stop-reward=90"]
# )
py_test(
name = "examples/autoregressive_action_dist_tf",
main = "examples/autoregressive_action_dist.py",
@@ -1492,7 +1511,7 @@ py_test(
name = "examples/batch_norm_model_dqn_tf",
main = "examples/batch_norm_model.py",
tags = ["examples", "examples_B"],
size = "medium", # DQN learns much slower with BatchNorm.
size = "large", # DQN learns much slower with BatchNorm.
srcs = ["examples/batch_norm_model.py"],
args = ["--as-test", "--run=DQN", "--stop-reward=70"]
)
@@ -1501,7 +1520,7 @@ py_test(
name = "examples/batch_norm_model_dqn_torch",
main = "examples/batch_norm_model.py",
tags = ["examples", "examples_B"],
size = "medium", # DQN learns much slower with BatchNorm.
size = "large", # DQN learns much slower with BatchNorm.
srcs = ["examples/batch_norm_model.py"],
args = ["--as-test", "--torch", "--run=DQN", "--stop-reward=70"]
)
@@ -1555,7 +1574,7 @@ py_test(
name = "examples/cartpole_lstm_ppo_torch",
main = "examples/cartpole_lstm.py",
tags = ["examples", "examples_C"],
size = "small",
size = "medium",
srcs = ["examples/cartpole_lstm.py"],
args = ["--as-test", "--torch", "--run=PPO", "--stop-reward=40", "--num-cpus=4"]
)
@@ -1871,7 +1890,7 @@ py_test(
name = "examples/multi_agent_two_trainers_mixed_torch_tf",
main = "examples/multi_agent_two_trainers.py",
tags = ["examples", "examples_M"],
size = "small",
size = "medium",
srcs = ["examples/multi_agent_two_trainers.py"],
args = ["--as-test", "--mixed-torch-tf", "--stop-reward=70"]
)
+1 -1
View File
@@ -1,6 +1,6 @@
Implementation of the Soft Actor-Critic algorithm:
[1] Soft Actor-Critic Algorithms and Applications - T. Haarnoja, A. Zhou, K. Hartikainen, et. al
[1] Soft Actor-Critic Algorithms and Applications - T. Haarnoja, A. Zhou, K. Hartikainen, et al.
https://arxiv.org/abs/1812.05905.pdf
For supporting discrete action spaces, we implemented this patch on top of the original algorithm:
+75
View File
@@ -0,0 +1,75 @@
import argparse
import ray
from ray import tune
from ray.rllib.utils import try_import_tf
from ray.rllib.models.tf.attention_net import GTrXLNet
from ray.rllib.examples.env.look_and_push import LookAndPush, OneHot
from ray.rllib.examples.env.repeat_after_me_env import RepeatAfterMeEnv
from ray.rllib.examples.env.repeat_initial_obs_env import RepeatInitialObsEnv
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
from ray.rllib.utils.test_utils import check_learning_achieved
from ray.tune import registry
tf = try_import_tf()
parser = argparse.ArgumentParser()
parser.add_argument("--run", type=str, default="PPO")
parser.add_argument("--env", type=str, default="RepeatAfterMeEnv")
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument("--torch", action="store_true")
parser.add_argument("--as-test", action="store_true")
parser.add_argument("--stop-iters", type=int, default=200)
parser.add_argument("--stop-timesteps", type=int, default=500000)
parser.add_argument("--stop-reward", type=float, default=80)
if __name__ == "__main__":
args = parser.parse_args()
assert not args.torch, "PyTorch not supported for AttentionNets yet!"
ray.init(num_cpus=args.num_cpus or None, local_mode=True)
registry.register_env("RepeatAfterMeEnv", lambda c: RepeatAfterMeEnv(c))
registry.register_env("RepeatInitialObsEnv",
lambda _: RepeatInitialObsEnv())
registry.register_env("LookAndPush", lambda _: OneHot(LookAndPush()))
registry.register_env("StatelessCartPole", lambda _: StatelessCartPole())
config = {
"env": args.env,
"env_config": {
"repeat_delay": 2,
},
"gamma": 0.99,
"num_workers": 0,
"num_envs_per_worker": 20,
"entropy_coeff": 0.001,
"num_sgd_iter": 5,
"vf_loss_coeff": 1e-5,
"model": {
"custom_model": GTrXLNet,
"max_seq_len": 50,
"custom_options": {
"num_transformer_units": 1,
"attn_dim": 64,
"num_heads": 2,
"memory_tau": 50,
"head_dim": 32,
"ff_hidden_dim": 32,
},
},
"use_pytorch": args.torch,
}
stop = {
"training_iteration": args.stop_iters,
"timesteps_total": args.stop_timesteps,
"episode_reward_mean": args.stop_reward,
}
results = tune.run(args.run, config=config, stop=stop, verbose=1)
if args.as_test:
check_learning_achieved(results, args.stop_reward)
ray.shutdown()
@@ -1,11 +1,8 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from gym.spaces import Box, Discrete
import numpy as np
from rllib.models.tf import attention
from ray.rllib.utils import try_import_tf
from rllib.models.tf.attention_net import TrXLNet
from ray.rllib.utils.framework import try_import_tf
tf = try_import_tf()
@@ -19,16 +16,6 @@ def bit_shift_generator(seq_length, shift, batch_size):
yield seq, targets
def make_model(seq_length, num_tokens, num_layers, attn_dim, num_heads,
head_dim, ff_hidden_dim):
return tf.keras.Sequential((
attention.make_TrXL(seq_length, num_layers, attn_dim, num_heads,
head_dim, ff_hidden_dim),
tf.keras.layers.Dense(num_tokens),
))
def train_loss(targets, outputs):
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=targets, logits=outputs)
@@ -39,10 +26,13 @@ def train_bit_shift(seq_length, num_iterations, print_every_n):
optimizer = tf.keras.optimizers.Adam(1e-3)
model = make_model(
seq_length,
num_tokens=2,
num_layers=1,
model = TrXLNet(
observation_space=Box(low=0, high=1, shape=(1, ), dtype=np.int32),
action_space=Discrete(2),
num_outputs=2,
model_config={"max_seq_len": seq_length},
name="trxl",
num_transformer_units=1,
attn_dim=10,
num_heads=5,
head_dim=20,
@@ -59,13 +49,20 @@ def train_bit_shift(seq_length, num_iterations, print_every_n):
@tf.function
def update_step(inputs, targets):
optimizer.minimize(lambda: train_loss(targets, model(inputs)),
model_out = model(
{
"obs": inputs
},
state=[tf.reshape(inputs, [-1, seq_length, 1])],
seq_lens=np.full(shape=(train_batch, ), fill_value=seq_length))
optimizer.minimize(lambda: train_loss(targets, model_out),
lambda: model.trainable_variables)
for i, (inputs, targets) in zip(range(num_iterations), data_gen):
inputs_in = np.reshape(inputs, [-1, 1])
targets_in = np.reshape(targets, [-1])
update_step(
tf.convert_to_tensor(inputs), tf.convert_to_tensor(targets))
tf.convert_to_tensor(inputs_in), tf.convert_to_tensor(targets_in))
if i % print_every_n == 0:
test_inputs, test_targets = next(test_gen)
+2 -2
View File
@@ -8,16 +8,16 @@ This example shows:
You can visualize experiment results in ~/ray_results using TensorBoard.
"""
import argparse
import numpy as np
import gym
from gym.spaces import Discrete, Box
import numpy as np
import ray
from ray import tune
from ray.tune import grid_search
from ray.rllib.models import ModelCatalog
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.utils.framework import try_import_tf, try_import_torch
+1 -1
View File
@@ -10,7 +10,7 @@ from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.agents.dqn.distributional_q_tf_model import \
DistributionalQTFModel
from ray.rllib.utils import try_import_tf
from ray.rllib.models.tf.visionnet_v2 import VisionNetwork as MyVisionNetwork
from ray.rllib.models.tf.visionnet import VisionNetwork as MyVisionNetwork
tf = try_import_tf()
+23
View File
@@ -0,0 +1,23 @@
import gym
class DebugCounterEnv(gym.Env):
"""Simple Env that yields a ts counter as observation (0-based).
Actions have no effect.
The episode length is always 15.
Reward is always: current ts % 3.
"""
def __init__(self):
self.action_space = gym.spaces.Discrete(2)
self.observation_space = gym.spaces.Box(0, 100, (1, ))
self.i = 0
def reset(self):
self.i = 0
return [self.i]
def step(self, action):
self.i += 1
return [self.i], self.i % 3, self.i >= 15, {}
+66
View File
@@ -0,0 +1,66 @@
import gym
import numpy as np
class LookAndPush(gym.Env):
"""Memory-requiring Env: Best sequence of actions depends on prev. states.
Optimal behavior:
0) a=0 -> observe next state (s'), which is the "hidden" state.
If a=1 here, the hidden state is not observed.
1) a=1 to always jump to s=2 (not matter what the prev. state was).
2) a=1 to move to s=3.
3) a=1 to move to s=4.
4) a=0 OR 1 depending on s' observed after 0): +10 reward and done.
otherwise: -10 reward and done.
"""
def __init__(self):
self.action_space = gym.spaces.Discrete(2)
self.observation_space = gym.spaces.Discrete(5)
self._state = None
self._case = None
def reset(self):
self._state = 2
self._case = np.random.choice(2)
return self._state
def step(self, action):
assert self.action_space.contains(action)
if self._state == 4:
if action and self._case:
return self._state, 10., True, {}
else:
return self._state, -10, True, {}
else:
if action:
if self._state == 0:
self._state = 2
else:
self._state += 1
elif self._state == 2:
self._state = self._case
return self._state, -1, False, {}
class OneHot(gym.Wrapper):
def __init__(self, env):
super(OneHot, self).__init__(env)
self.observation_space = gym.spaces.Box(0., 1.,
(env.observation_space.n, ))
def reset(self, **kwargs):
obs = self.env.reset(**kwargs)
return self._encode_obs(obs)
def step(self, action):
obs, reward, done, info = self.env.step(action)
return self._encode_obs(obs), reward, done, info
def _encode_obs(self, obs):
new_obs = np.ones(self.env.observation_space.n)
new_obs[obs] = 1.0
return new_obs
@@ -2,7 +2,7 @@ from gym.spaces import Box
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
+1 -1
View File
@@ -1,7 +1,7 @@
import random
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
@@ -1,9 +1,9 @@
import numpy as np
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.recurrent_tf_modelv2 import RecurrentTFModelV2
from ray.rllib.models.tf.recurrent_net import RecurrentNetwork
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.recurrent_torch_model import RecurrentTorchModel
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf, try_import_torch
@@ -11,7 +11,7 @@ tf = try_import_tf()
torch, nn = try_import_torch()
class MobileV2PlusRNNModel(RecurrentTFModelV2):
class MobileV2PlusRNNModel(RecurrentNetwork):
"""A conv. + recurrent keras net example using a pre-trained MobileNet."""
def __init__(self, obs_space, action_space, num_outputs, model_config,
@@ -71,7 +71,7 @@ class MobileV2PlusRNNModel(RecurrentTFModelV2):
self.register_variables(self.rnn_model.variables)
self.rnn_model.summary()
@override(RecurrentTFModelV2)
@override(RecurrentNetwork)
def forward_rnn(self, inputs, state, seq_lens):
model_out, self._value_out, h, c = self.rnn_model([inputs, seq_lens] +
state)
@@ -89,7 +89,7 @@ class MobileV2PlusRNNModel(RecurrentTFModelV2):
return tf.reshape(self._value_out, [-1])
class TorchMobileV2PlusRNNModel(RecurrentTorchModel):
class TorchMobileV2PlusRNNModel(TorchRNN):
"""A conv. + recurrent torch net example using a pre-trained MobileNet."""
def __init__(self, obs_space, action_space, num_outputs, model_config,
@@ -117,7 +117,7 @@ class TorchMobileV2PlusRNNModel(RecurrentTorchModel):
# Holds the current "base" output (before logits layer).
self._features = None
@override(RecurrentTFModelV2)
@override(TorchRNN)
def forward_rnn(self, inputs, state, seq_lens):
# Create image dims.
vision_in = torch.reshape(inputs, [-1] + self.cnn_shape)
@@ -4,7 +4,7 @@ from ray.rllib.agents.dqn.distributional_q_tf_model import \
DistributionalQTFModel
from ray.rllib.agents.dqn.dqn_torch_model import \
DQNTorchModel
from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.numpy import LARGE_INTEGER
+6 -6
View File
@@ -2,8 +2,8 @@ import numpy as np
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.tf.recurrent_tf_modelv2 import RecurrentTFModelV2
from ray.rllib.models.torch.recurrent_torch_model import RecurrentTorchModel
from ray.rllib.models.tf.recurrent_net import RecurrentNetwork
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf, try_import_torch
@@ -11,7 +11,7 @@ tf = try_import_tf()
torch, nn = try_import_torch()
class RNNModel(RecurrentTFModelV2):
class RNNModel(RecurrentNetwork):
"""Example of using the Keras functional API to define a RNN model."""
def __init__(self,
@@ -57,7 +57,7 @@ class RNNModel(RecurrentTFModelV2):
self.register_variables(self.rnn_model.variables)
self.rnn_model.summary()
@override(RecurrentTFModelV2)
@override(RecurrentNetwork)
def forward_rnn(self, inputs, state, seq_lens):
model_out, self._value_out, h, c = self.rnn_model([inputs, seq_lens] +
state)
@@ -75,7 +75,7 @@ class RNNModel(RecurrentTFModelV2):
return tf.reshape(self._value_out, [-1])
class TorchRNNModel(RecurrentTorchModel):
class TorchRNNModel(TorchRNN):
def __init__(self,
obs_space,
action_space,
@@ -114,7 +114,7 @@ class TorchRNNModel(RecurrentTorchModel):
assert self._features is not None, "must call forward() first"
return torch.reshape(self.value_branch(self._features), [-1])
@override(RecurrentTorchModel)
@override(TorchRNN)
def forward_rnn(self, inputs, state, seq_lens):
"""Feeds `inputs` (B x T x ..) through the Gru Unit.
+131
View File
@@ -0,0 +1,131 @@
import numpy as np
import pickle
import ray
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.models.tf.recurrent_net import RecurrentNetwork
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
tf = try_import_tf()
class SpyLayer(tf.keras.layers.Layer):
"""A keras Layer, which intercepts its inputs and stored them as pickled.
"""
def __init__(self, num_outputs, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
units=num_outputs, kernel_initializer=normc_initializer(0.01))
def call(self, inputs, **kwargs):
"""Does a forward pass through our Dense, but also intercepts inputs.
"""
del kwargs
spy_fn = tf.py_func(
self.spy,
[
inputs[0], # observations
inputs[2], # seq_lens
inputs[3], # h_in
inputs[4], # c_in
inputs[5], # h_out
inputs[6], # c_out
],
tf.int64,
stateful=True)
# Compute outputs
with tf.control_dependencies([spy_fn]):
return self.dense(inputs[1])
@staticmethod
def spy(inputs, seq_lens, h_in, c_in, h_out, c_out):
"""The actual spy operation: Store inputs in internal_kv."""
if len(inputs) == 1:
return 0 # don't capture inference inputs
# TF runs this function in an isolated context, so we have to use
# redis to communicate back to our suite
ray.experimental.internal_kv._internal_kv_put(
"rnn_spy_in_{}".format(RNNSpyModel.capture_index),
pickle.dumps({
"sequences": inputs,
"seq_lens": seq_lens,
"state_in": [h_in, c_in],
"state_out": [h_out, c_out]
}),
overwrite=True)
RNNSpyModel.capture_index += 1
return 0
class RNNSpyModel(RecurrentNetwork):
capture_index = 0
cell_size = 3
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
super().__init__(obs_space, action_space, num_outputs, model_config,
name)
self.cell_size = RNNSpyModel.cell_size
# Create a keras LSTM model.
inputs = tf.keras.layers.Input(
shape=(None, ) + obs_space.shape, name="input")
state_in_h = tf.keras.layers.Input(shape=(self.cell_size, ), name="h")
state_in_c = tf.keras.layers.Input(shape=(self.cell_size, ), name="c")
seq_lens = tf.keras.layers.Input(
shape=(), name="seq_lens", dtype=tf.int32)
lstm_out, state_out_h, state_out_c = tf.keras.layers.LSTM(
self.cell_size,
return_sequences=True,
return_state=True,
name="lstm")(
inputs=inputs,
mask=tf.sequence_mask(seq_lens),
initial_state=[state_in_h, state_in_c])
logits = SpyLayer(num_outputs=self.num_outputs)([
inputs, lstm_out, seq_lens, state_in_h, state_in_c, state_out_h,
state_out_c
])
# Value branch.
value_out = tf.keras.layers.Dense(
units=1, kernel_initializer=normc_initializer(1.0))(lstm_out)
self.base_model = tf.keras.Model(
[inputs, seq_lens, state_in_h, state_in_c],
[logits, value_out, state_out_h, state_out_c])
self.base_model.summary()
self.register_variables(self.base_model.variables)
@override(RecurrentNetwork)
def forward_rnn(self, inputs, state, seq_lens):
# Previously, a new class object was created during
# deserialization and this `capture_index`
# variable would be refreshed between class instantiations.
# This behavior is no longer the case, so we manually refresh
# the variable.
RNNSpyModel.capture_index = 0
model_out, value_out, h, c = self.base_model(
[inputs, seq_lens, state[0], state[1]])
self._value_out = value_out
return model_out, [h, c]
@override(ModelV2)
def value_function(self):
return tf.reshape(self._value_out, [-1])
@override(ModelV2)
def get_initial_state(self):
return [
np.zeros(self.cell_size, np.float32),
np.zeros(self.cell_size, np.float32)
]
@@ -18,13 +18,13 @@ import argparse
import ray
from ray import tune
from ray.tune.registry import register_env
from ray.rllib.examples.env.parametric_actions_cartpole import \
ParametricActionsCartPole
from ray.rllib.examples.models.parametric_actions_model import \
ParametricActionsModel, TorchParametricActionsModel
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.test_utils import check_learning_achieved
from ray.tune.registry import register_env
parser = argparse.ArgumentParser()
parser.add_argument("--run", type=str, default="PPO")
-173
View File
@@ -1,173 +0,0 @@
import argparse
import gym
import numpy as np
import ray
from ray import tune
from ray.tune import registry
from ray.rllib import models
from ray.rllib.utils import try_import_tf
from ray.rllib.models.tf import attention
from ray.rllib.models.tf import recurrent_tf_modelv2
from ray.rllib.examples.custom_keras_rnn_model import RepeatAfterMeEnv
from ray.rllib.examples.custom_keras_rnn_model import RepeatInitialEnv
tf = try_import_tf()
parser = argparse.ArgumentParser()
parser.add_argument("--run", type=str, default="PPO")
parser.add_argument("--env", type=str, default="RepeatAfterMeEnv")
parser.add_argument("--stop", type=int, default=90)
parser.add_argument("--num-cpus", type=int, default=0)
class OneHot(gym.Wrapper):
def __init__(self, env):
super(OneHot, self).__init__(env)
self.observation_space = gym.spaces.Box(0., 1.,
(env.observation_space.n,))
def reset(self, **kwargs):
obs = self.env.reset(**kwargs)
return self._encode_obs(obs)
def step(self, action):
obs, reward, done, info = self.env.step(action)
return self._encode_obs(obs), reward, done, info
def _encode_obs(self, obs):
new_obs = np.ones(self.env.observation_space.n)
new_obs[obs] = 1.0
return new_obs
class LookAndPush(gym.Env):
def __init__(self):
self.action_space = gym.spaces.Discrete(2)
self.observation_space = gym.spaces.Discrete(5)
self._state = None
self._case = None
def reset(self):
self._state = 2
self._case = np.random.choice(2)
return self._state
def step(self, action):
assert self.action_space.contains(action)
if self._state == 4:
if action and self._case:
return self._state, 10., True, {}
else:
return self._state, -10, True, {}
else:
if action:
if self._state == 0:
self._state = 2
else:
self._state += 1
elif self._state == 2:
self._state = self._case
return self._state, -1, False, {}
class GRUTrXL(recurrent_tf_modelv2.RecurrentTFModelV2):
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
super(GRUTrXL, self).__init__(obs_space, action_space, num_outputs,
model_config, name)
self.max_seq_len = model_config["max_seq_len"]
self.obs_dim = obs_space.shape[0]
input_layer = tf.keras.layers.Input(
shape=(self.max_seq_len, obs_space.shape[0]),
name="inputs",
)
trxl_out = attention.make_GRU_TrXL(
seq_length=model_config["max_seq_len"],
num_layers=model_config["custom_options"]["num_layers"],
attn_dim=model_config["custom_options"]["attn_dim"],
num_heads=model_config["custom_options"]["num_heads"],
head_dim=model_config["custom_options"]["head_dim"],
ff_hidden_dim=model_config["custom_options"]["ff_hidden_dim"],
)(input_layer)
# Postprocess TrXL output with another hidden layer and compute values
logits = tf.keras.layers.Dense(
self.num_outputs,
activation=tf.keras.activations.linear,
name="logits")(trxl_out)
values_out = tf.keras.layers.Dense(
1, activation=None, name="values")(trxl_out)
self.trxl_model = tf.keras.Model(
inputs=[input_layer],
outputs=[logits, values_out],
)
self.register_variables(self.trxl_model.variables)
self.trxl_model.summary()
def forward_rnn(self, inputs, state, seq_lens):
state = state[0]
# We assume state is the history of recent observations and append
# the current inputs to the end and only keep the most recent (up to
# max_seq_len). This allows us to deal with timestep-wise inference
# and full sequence training with the same logic.
state = tf.concat((state, inputs), axis=1)[:, -self.max_seq_len:]
logits, self._value_out = self.trxl_model(state)
in_T = tf.shape(inputs)[1]
logits = logits[:, -in_T:]
self._value_out = self._value_out[:, -in_T:]
return logits, [state]
def get_initial_state(self):
return [np.zeros((self.max_seq_len, self.obs_dim), np.float32)]
def value_function(self):
return tf.reshape(self._value_out, [-1])
if __name__ == "__main__":
args = parser.parse_args()
ray.init(num_cpus=args.num_cpus or None)
models.ModelCatalog.register_custom_model("trxl", GRUTrXL)
registry.register_env("RepeatAfterMeEnv", lambda c: RepeatAfterMeEnv(c))
registry.register_env("RepeatInitialEnv", lambda _: RepeatInitialEnv())
registry.register_env("LookAndPush", lambda _: OneHot(LookAndPush()))
tune.run(
args.run,
stop={"episode_reward_mean": args.stop},
config={
"env": args.env,
"env_config": {
"repeat_delay": 2,
},
"gamma": 0.99,
"num_workers": 0,
"num_envs_per_worker": 20,
"entropy_coeff": 0.001,
"num_sgd_iter": 5,
"vf_loss_coeff": 1e-5,
"model": {
"custom_model": "trxl",
"max_seq_len": 10,
"custom_options": {
"num_layers": 1,
"attn_dim": 10,
"num_heads": 1,
"head_dim": 10,
"ff_hidden_dim": 20,
},
},
})
+30 -9
View File
@@ -280,10 +280,13 @@ class ModelCatalog:
"""
if model_config.get("custom_model"):
model_cls = _global_registry.get(RLLIB_MODEL,
model_config["custom_model"])
if isinstance(model_config["custom_model"], type):
model_cls = model_config["custom_model"]
else:
model_cls = _global_registry.get(RLLIB_MODEL,
model_config["custom_model"])
# TODO(sven): Hard-deprecate Model(V1).
if issubclass(model_cls, ModelV2):
logger.info("Wrapping {} as {}".format(model_cls,
model_interface))
model_cls = ModelCatalog._wrap_if_needed(
@@ -299,9 +302,26 @@ class ModelCatalog:
return v
with tf.variable_creator_scope(track_var_creation):
instance = model_cls(obs_space, action_space,
num_outputs, model_config, name,
**model_kwargs)
# Try calling with kwargs first (custom ModelV2 should
# accept these as kwargs, not get them from
# config["custom_options"] anymore)
try:
instance = model_cls(obs_space, action_space,
num_outputs, model_config,
name, **model_kwargs)
except TypeError as e:
# Keyword error: Try old way w/o kwargs.
if "__init__() got an unexpected " in e.args[0]:
logger.warning(
"Custom ModelV2 should accept all custom "
"options as **kwargs, instead of expecting"
" them in config['custom_options']!")
instance = model_cls(obs_space, action_space,
num_outputs, model_config,
name)
# Other error -> re-raise.
else:
raise e
registered = set(instance.variables())
not_registered = set()
for var in created:
@@ -322,7 +342,8 @@ class ModelCatalog:
instance = model_cls(obs_space, action_space, num_outputs,
model_config, name, **model_kwargs)
return instance
# TODO(sven): Hard-deprecate Model(V1). This check will be
# superflous then.
elif tf.executing_eagerly():
raise ValueError(
"Eager execution requires a TFModelV2 model to be "
@@ -536,9 +557,9 @@ class ModelCatalog:
from ray.rllib.models.torch.visionnet import (VisionNetwork as
VisionNet)
else:
from ray.rllib.models.tf.fcnet_v2 import \
from ray.rllib.models.tf.fcnet import \
FullyConnectedNetwork as FCNet
from ray.rllib.models.tf.visionnet_v2 import \
from ray.rllib.models.tf.visionnet import \
VisionNetwork as VisionNet
# Discrete/1D obs-spaces.
+5 -5
View File
@@ -1,12 +1,12 @@
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork
from ray.rllib.models.tf.recurrent_tf_modelv2 import \
RecurrentTFModelV2
from ray.rllib.models.tf.visionnet_v2 import VisionNetwork
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.rllib.models.tf.recurrent_net import \
RecurrentNetwork
from ray.rllib.models.tf.visionnet import VisionNetwork
__all__ = [
"FullyConnectedNetwork",
"RecurrentTFModelV2",
"RecurrentNetwork",
"TFModelV2",
"VisionNetwork",
]
-284
View File
@@ -1,284 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
def relative_position_embedding(seq_length, out_dim):
inverse_freq = 1 / (10000**(tf.range(0, out_dim, 2.0) / out_dim))
pos_offsets = tf.range(seq_length - 1., -1., -1.)
inputs = pos_offsets[:, None] * inverse_freq[None, :]
return tf.concat((tf.sin(inputs), tf.cos(inputs)), axis=-1)
def rel_shift(x):
# Transposed version of the shift approach implemented by Dai et al. 2019
# https://github.com/kimiyoung/transformer-xl/blob/
# 44781ed21dbaec88b280f74d9ae2877f52b492a5/tf/model.py#L31
x_size = tf.shape(x)
x = tf.pad(x, [[0, 0], [0, 0], [1, 0], [0, 0]])
x = tf.reshape(x, [x_size[0], x_size[2] + 1, x_size[1], x_size[3]])
x = tf.slice(x, [0, 1, 0, 0], [-1, -1, -1, -1])
x = tf.reshape(x, x_size)
return x
class MultiHeadAttention(tf.keras.layers.Layer):
def __init__(self, out_dim, num_heads, head_dim, **kwargs):
super(MultiHeadAttention, self).__init__(**kwargs)
# no bias or non-linearity
self._num_heads = num_heads
self._head_dim = head_dim
self._qkv_layer = tf.keras.layers.Dense(
3 * num_heads * head_dim, use_bias=False)
self._linear_layer = tf.keras.layers.TimeDistributed(
tf.keras.layers.Dense(out_dim, use_bias=False))
def call(self, inputs):
L = tf.shape(inputs)[1] # length of segment
H = self._num_heads # number of attention heads
D = self._head_dim # attention head dimension
qkv = self._qkv_layer(inputs)
queries, keys, values = tf.split(qkv, 3, -1)
queries = queries[:, -L:] # only query based on the segment
queries = tf.reshape(queries, [-1, L, H, D])
keys = tf.reshape(keys, [-1, L, H, D])
values = tf.reshape(values, [-1, L, H, D])
score = tf.einsum("bihd,bjhd->bijh", queries, keys)
score = score / D**0.5
# causal mask of the same length as the sequence
mask = tf.sequence_mask(tf.range(1, L + 1), dtype=score.dtype)
mask = mask[None, :, :, None]
masked_score = score * mask + 1e30 * (mask - 1.)
wmat = tf.nn.softmax(masked_score, axis=2)
out = tf.einsum("bijh,bjhd->bihd", wmat, values)
out = tf.reshape(out, tf.concat((tf.shape(out)[:2], [H * D]), axis=0))
return self._linear_layer(out)
class RelativeMultiHeadAttention(tf.keras.layers.Layer):
def __init__(self,
out_dim,
num_heads,
head_dim,
rel_pos_encoder,
input_layernorm=False,
output_activation=None,
**kwargs):
super(RelativeMultiHeadAttention, self).__init__(**kwargs)
# no bias or non-linearity
self._num_heads = num_heads
self._head_dim = head_dim
self._qkv_layer = tf.keras.layers.Dense(
3 * num_heads * head_dim, use_bias=False)
self._linear_layer = tf.keras.layers.TimeDistributed(
tf.keras.layers.Dense(
out_dim, use_bias=False, activation=output_activation))
self._uvar = self.add_weight(shape=(num_heads, head_dim))
self._vvar = self.add_weight(shape=(num_heads, head_dim))
self._pos_proj = tf.keras.layers.Dense(
num_heads * head_dim, use_bias=False)
self._rel_pos_encoder = rel_pos_encoder
self._input_layernorm = None
if input_layernorm:
self._input_layernorm = tf.keras.layers.LayerNormalization(axis=-1)
def call(self, inputs, memory=None):
L = tf.shape(inputs)[1] # length of segment
H = self._num_heads # number of attention heads
D = self._head_dim # attention head dimension
# length of the memory segment
M = memory.shape[0] if memory is not None else 0
if memory is not None:
inputs = np.concatenate((tf.stop_gradient(memory), inputs), axis=1)
if self._input_layernorm is not None:
inputs = self._input_layernorm(inputs)
qkv = self._qkv_layer(inputs)
queries, keys, values = tf.split(qkv, 3, -1)
queries = queries[:, -L:] # only query based on the segment
queries = tf.reshape(queries, [-1, L, H, D])
keys = tf.reshape(keys, [-1, L + M, H, D])
values = tf.reshape(values, [-1, L + M, H, D])
rel = self._pos_proj(self._rel_pos_encoder)
rel = tf.reshape(rel, [L, H, D])
score = tf.einsum("bihd,bjhd->bijh", queries + self._uvar, keys)
pos_score = tf.einsum("bihd,jhd->bijh", queries + self._vvar, rel)
score = score + rel_shift(pos_score)
score = score / D**0.5
# causal mask of the same length as the sequence
mask = tf.sequence_mask(tf.range(M + 1, L + M + 1), dtype=score.dtype)
mask = mask[None, :, :, None]
masked_score = score * mask + 1e30 * (mask - 1.)
wmat = tf.nn.softmax(masked_score, axis=2)
out = tf.einsum("bijh,bjhd->bihd", wmat, values)
out = tf.reshape(out, tf.concat((tf.shape(out)[:2], [H * D]), axis=0))
return self._linear_layer(out)
class PositionwiseFeedforward(tf.keras.layers.Layer):
def __init__(self, out_dim, hidden_dim, output_activation=None, **kwargs):
super(PositionwiseFeedforward, self).__init__(**kwargs)
self._hidden_layer = tf.keras.layers.Dense(
hidden_dim,
activation=tf.nn.relu,
)
self._output_layer = tf.keras.layers.Dense(
out_dim, activation=output_activation)
def call(self, inputs, **kwargs):
del kwargs
output = self._hidden_layer(inputs)
return self._output_layer(output)
class SkipConnection(tf.keras.layers.Layer):
"""Skip connection layer.
If no fan-in layer is specified, then this layer behaves as a regular
residual layer.
"""
def __init__(self, layer, fan_in_layer=None, **kwargs):
super(SkipConnection, self).__init__(**kwargs)
self._fan_in_layer = fan_in_layer
self._layer = layer
def call(self, inputs, **kwargs):
del kwargs
outputs = self._layer(inputs)
if self._fan_in_layer is None:
outputs = outputs + inputs
else:
outputs = self._fan_in_layer((inputs, outputs))
return outputs
class GRUGate(tf.keras.layers.Layer):
def __init__(self, init_bias=0., **kwargs):
super(GRUGate, self).__init__(**kwargs)
self._init_bias = init_bias
def build(self, input_shape):
x_shape, y_shape = input_shape
if x_shape[-1] != y_shape[-1]:
raise ValueError(
"Both inputs to GRUGate must equal size last axis.")
self._w_r = self.add_weight(shape=(y_shape[-1], y_shape[-1]))
self._w_z = self.add_weight(shape=(y_shape[-1], y_shape[-1]))
self._w_h = self.add_weight(shape=(y_shape[-1], y_shape[-1]))
self._u_r = self.add_weight(shape=(x_shape[-1], x_shape[-1]))
self._u_z = self.add_weight(shape=(x_shape[-1], x_shape[-1]))
self._u_h = self.add_weight(shape=(x_shape[-1], x_shape[-1]))
def bias_initializer(shape, dtype):
return tf.fill(shape, tf.cast(self._init_bias, dtype=dtype))
self._bias_z = self.add_weight(
shape=(x_shape[-1], ), initializer=bias_initializer)
def call(self, inputs, **kwargs):
x, y = inputs
r = (tf.tensordot(y, self._w_r, axes=1) + tf.tensordot(
x, self._u_r, axes=1))
r = tf.nn.sigmoid(r)
z = (tf.tensordot(y, self._w_z, axes=1) + tf.tensordot(
x, self._u_z, axes=1) + self._bias_z)
z = tf.nn.sigmoid(z)
h = (tf.tensordot(y, self._w_h, axes=1) + tf.tensordot(
(x * r), self._u_h, axes=1))
h = tf.nn.tanh(h)
return (1 - z) * x + z * h
def make_TrXL(seq_length, num_layers, attn_dim, num_heads, head_dim,
ff_hidden_dim):
pos_embedding = relative_position_embedding(seq_length, attn_dim)
layers = [tf.keras.layers.Dense(attn_dim)]
for _ in range(num_layers):
layers.append(
SkipConnection(
RelativeMultiHeadAttention(attn_dim, num_heads, head_dim,
pos_embedding)))
layers.append(tf.keras.layers.LayerNormalization(axis=-1))
layers.append(
SkipConnection(PositionwiseFeedforward(attn_dim, ff_hidden_dim)))
layers.append(tf.keras.layers.LayerNormalization(axis=-1))
return tf.keras.Sequential(layers)
def make_GRU_TrXL(seq_length,
num_layers,
attn_dim,
num_heads,
head_dim,
ff_hidden_dim,
init_gate_bias=2.):
# Default initial bias for the gate taken from
# Parisotto, Emilio, et al. "Stabilizing Transformers for Reinforcement
# Learning." arXiv preprint arXiv:1910.06764 (2019).
pos_embedding = relative_position_embedding(seq_length, attn_dim)
layers = [tf.keras.layers.Dense(attn_dim)]
for _ in range(num_layers):
layers.append(
SkipConnection(
RelativeMultiHeadAttention(
attn_dim,
num_heads,
head_dim,
pos_embedding,
input_layernorm=True,
output_activation=tf.nn.relu),
fan_in_layer=GRUGate(init_gate_bias),
))
layers.append(
SkipConnection(
tf.keras.Sequential(
(tf.keras.layers.LayerNormalization(axis=-1),
PositionwiseFeedforward(
attn_dim, ff_hidden_dim,
output_activation=tf.nn.relu))),
fan_in_layer=GRUGate(init_gate_bias),
))
return tf.keras.Sequential(layers)
+336
View File
@@ -0,0 +1,336 @@
"""
[1] - Attention Is All You Need - Vaswani, Jones, Shazeer, Parmar,
Uszkoreit, Gomez, Kaiser - Google Brain/Research, U Toronto - 2017.
https://arxiv.org/pdf/1706.03762.pdf
[2] - Stabilizing Transformers for Reinforcement Learning - E. Parisotto
et al. - DeepMind - 2019. https://arxiv.org/pdf/1910.06764.pdf
[3] - Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context.
Z. Dai, Z. Yang, et al. - Carnegie Mellon U - 2019.
https://www.aclweb.org/anthology/P19-1285.pdf
"""
import numpy as np
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.layers import GRUGate, RelativeMultiHeadAttention, \
SkipConnection
from ray.rllib.models.tf.recurrent_net import RecurrentNetwork
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
tf = try_import_tf()
# TODO(sven): Use RLlib's FCNet instead.
class PositionwiseFeedforward(tf.keras.layers.Layer):
"""A 2x linear layer with ReLU activation in between described in [1].
Each timestep coming from the attention head will be passed through this
layer separately.
"""
def __init__(self, out_dim, hidden_dim, output_activation=None, **kwargs):
super().__init__(**kwargs)
self._hidden_layer = tf.keras.layers.Dense(
hidden_dim,
activation=tf.nn.relu,
)
self._output_layer = tf.keras.layers.Dense(
out_dim, activation=output_activation)
def call(self, inputs, **kwargs):
del kwargs
output = self._hidden_layer(inputs)
return self._output_layer(output)
class TrXLNet(RecurrentNetwork):
"""A TrXL net Model described in [1]."""
def __init__(self, observation_space, action_space, num_outputs,
model_config, name, num_transformer_units, attn_dim,
num_heads, head_dim, ff_hidden_dim):
"""Initializes a TfXLNet object.
Args:
num_transformer_units (int): The number of Transformer repeats to
use (denoted L in [2]).
attn_dim (int): The input and output dimensions of one Transformer
unit.
num_heads (int): The number of attention heads to use in parallel.
Denoted as `H` in [3].
head_dim (int): The dimension of a single(!) head.
Denoted as `d` in [3].
ff_hidden_dim (int): The dimension of the hidden layer within
the position-wise MLP (after the multi-head attention block
within one Transformer unit). This is the size of the first
of the two layers within the PositionwiseFeedforward. The
second layer always has size=`attn_dim`.
"""
super().__init__(observation_space, action_space, num_outputs,
model_config, name)
self.num_transformer_units = num_transformer_units
self.attn_dim = attn_dim
self.num_heads = num_heads
self.head_dim = head_dim
self.max_seq_len = model_config["max_seq_len"]
self.obs_dim = observation_space.shape[0]
pos_embedding = relative_position_embedding(self.max_seq_len, attn_dim)
inputs = tf.keras.layers.Input(
shape=(self.max_seq_len, self.obs_dim), name="inputs")
E_out = tf.keras.layers.Dense(attn_dim)(inputs)
for _ in range(self.num_transformer_units):
MHA_out = SkipConnection(
RelativeMultiHeadAttention(
out_dim=attn_dim,
num_heads=num_heads,
head_dim=head_dim,
rel_pos_encoder=pos_embedding,
input_layernorm=False,
output_activation=None),
fan_in_layer=None)(E_out)
E_out = SkipConnection(
PositionwiseFeedforward(attn_dim, ff_hidden_dim))(MHA_out)
E_out = tf.keras.layers.LayerNormalization(axis=-1)(E_out)
# Postprocess TrXL output with another hidden layer and compute values.
logits = tf.keras.layers.Dense(
self.num_outputs,
activation=tf.keras.activations.linear,
name="logits")(E_out)
self.base_model = tf.keras.models.Model([inputs], [logits])
self.register_variables(self.base_model.variables)
@override(RecurrentNetwork)
def forward_rnn(self, inputs, state, seq_lens):
# To make Attention work with current RLlib's ModelV2 API:
# We assume `state` is the history of L recent observations (all
# concatenated into one tensor) and append the current inputs to the
# end and only keep the most recent (up to `max_seq_len`). This allows
# us to deal with timestep-wise inference and full sequence training
# within the same logic.
observations = state[0]
observations = tf.concat(
(observations, inputs), axis=1)[:, -self.max_seq_len:]
logits = self.base_model([observations])
T = tf.shape(inputs)[1] # Length of input segment (time).
logits = logits[:, -T:]
return logits, [observations]
@override(RecurrentNetwork)
def get_initial_state(self):
# State is the T last observations concat'd together into one Tensor.
# Plus all Transformer blocks' E(l) outputs concat'd together (up to
# tau timesteps).
return [np.zeros((self.max_seq_len, self.obs_dim), np.float32)]
class GTrXLNet(RecurrentNetwork):
"""A GTrXL net Model described in [2].
This is still in an experimental phase.
Can be used as a drop-in replacement for LSTMs in PPO and IMPALA.
For an example script, see: `ray/rllib/examples/attention_net.py`.
To use this network as a replacement for an RNN, configure your Trainer
as follows:
Examples:
>> config["model"]["custom_model"] = GTrXLNet
>> config["model"]["max_seq_len"] = 10
>> config["model"]["custom_options"] = {
>> num_transformer_units=1,
>> attn_dim=32,
>> num_heads=2,
>> memory_tau=50,
>> etc..
>> }
"""
def __init__(self,
observation_space,
action_space,
num_outputs,
model_config,
name,
num_transformer_units,
attn_dim,
num_heads,
memory_tau,
head_dim,
ff_hidden_dim,
init_gate_bias=2.0):
"""Initializes a GTrXLNet.
Args:
num_transformer_units (int): The number of Transformer repeats to
use (denoted L in [2]).
attn_dim (int): The input and output dimensions of one Transformer
unit.
num_heads (int): The number of attention heads to use in parallel.
Denoted as `H` in [3].
memory_tau (int): The number of timesteps to store in each
transformer block's memory M (concat'd over time and fed into
next transformer block as input).
head_dim (int): The dimension of a single(!) head.
Denoted as `d` in [3].
ff_hidden_dim (int): The dimension of the hidden layer within
the position-wise MLP (after the multi-head attention block
within one Transformer unit). This is the size of the first
of the two layers within the PositionwiseFeedforward. The
second layer always has size=`attn_dim`.
init_gate_bias (float): Initial bias values for the GRU gates (two
GRUs per Transformer unit, one after the MHA, one after the
position-wise MLP).
"""
super().__init__(observation_space, action_space, num_outputs,
model_config, name)
self.num_transformer_units = num_transformer_units
self.attn_dim = attn_dim
self.num_heads = num_heads
self.memory_tau = memory_tau
self.head_dim = head_dim
self.max_seq_len = model_config["max_seq_len"]
self.obs_dim = observation_space.shape[0]
# Constant (non-trainable) sinusoid rel pos encoding matrix.
Phi = relative_position_embedding(self.max_seq_len + self.memory_tau,
self.attn_dim)
# Raw observation input.
input_layer = tf.keras.layers.Input(
shape=(self.max_seq_len, self.obs_dim), name="inputs")
memory_ins = [
tf.keras.layers.Input(
shape=(self.memory_tau, self.attn_dim),
dtype=tf.float32,
name="memory_in_{}".format(i))
for i in range(self.num_transformer_units)
]
# Map observation dim to input/output transformer (attention) dim.
E_out = tf.keras.layers.Dense(self.attn_dim)(input_layer)
# Output, collected and concat'd to build the internal, tau-len
# Memory units used for additional contextual information.
memory_outs = [E_out]
# 2) Create L Transformer blocks according to [2].
for i in range(self.num_transformer_units):
# RelativeMultiHeadAttention part.
MHA_out = SkipConnection(
RelativeMultiHeadAttention(
out_dim=self.attn_dim,
num_heads=num_heads,
head_dim=head_dim,
rel_pos_encoder=Phi,
input_layernorm=True,
output_activation=tf.nn.relu),
fan_in_layer=GRUGate(init_gate_bias),
name="mha_{}".format(i + 1))(
E_out, memory=memory_ins[i])
# Position-wise MLP part.
E_out = SkipConnection(
tf.keras.Sequential(
(tf.keras.layers.LayerNormalization(axis=-1),
PositionwiseFeedforward(
out_dim=self.attn_dim,
hidden_dim=ff_hidden_dim,
output_activation=tf.nn.relu))),
fan_in_layer=GRUGate(init_gate_bias),
name="pos_wise_mlp_{}".format(i + 1))(MHA_out)
# Output of position-wise MLP == E(l-1), which is concat'd
# to the current Mem block (M(l-1)) to yield E~(l-1), which is then
# used by the next transformer block.
memory_outs.append(E_out)
# Postprocess TrXL output with another hidden layer and compute values.
logits = tf.keras.layers.Dense(
self.num_outputs,
activation=tf.keras.activations.linear,
name="logits")(E_out)
self._value_out = None
values_out = tf.keras.layers.Dense(
1, activation=None, name="values")(E_out)
self.trxl_model = tf.keras.Model(
inputs=[input_layer] + memory_ins,
outputs=[logits, values_out] + memory_outs[:-1])
self.register_variables(self.trxl_model.variables)
self.trxl_model.summary()
@override(RecurrentNetwork)
def forward_rnn(self, inputs, state, seq_lens):
# To make Attention work with current RLlib's ModelV2 API:
# We assume `state` is the history of L recent observations (all
# concatenated into one tensor) and append the current inputs to the
# end and only keep the most recent (up to `max_seq_len`). This allows
# us to deal with timestep-wise inference and full sequence training
# within the same logic.
observations = state[0]
memory = state[1:]
observations = tf.concat(
(observations, inputs), axis=1)[:, -self.max_seq_len:]
all_out = self.trxl_model([observations] + memory)
logits, self._value_out = all_out[0], all_out[1]
memory_outs = all_out[2:]
# If memory_tau > max_seq_len -> overlap w/ previous `memory` input.
if self.memory_tau > self.max_seq_len:
memory_outs = [
tf.concat(
[memory[i][:, -(self.memory_tau - self.max_seq_len):], m],
axis=1) for i, m in enumerate(memory_outs)
]
else:
memory_outs = [m[:, -self.memory_tau:] for m in memory_outs]
T = tf.shape(inputs)[1] # Length of input segment (time).
logits = logits[:, -T:]
self._value_out = self._value_out[:, -T:]
return logits, [observations] + memory_outs
@override(RecurrentNetwork)
def get_initial_state(self):
# State is the T last observations concat'd together into one Tensor.
# Plus all Transformer blocks' E(l) outputs concat'd together (up to
# tau timesteps).
return [np.zeros((self.max_seq_len, self.obs_dim), np.float32)] + \
[np.zeros((self.memory_tau, self.attn_dim), np.float32)
for _ in range(self.num_transformer_units)]
@override(ModelV2)
def value_function(self):
return tf.reshape(self._value_out, [-1])
def relative_position_embedding(seq_length, out_dim):
"""Creates a [seq_length x seq_length] matrix for rel. pos encoding.
Denoted as Phi in [2] and [3]. Phi is the standard sinusoid encoding
matrix.
Args:
seq_length (int): The max. sequence length (time axis).
out_dim (int): The number of nodes to go into the first Tranformer
layer with.
Returns:
tf.Tensor: The encoding matrix Phi.
"""
inverse_freq = 1 / (10000**(tf.range(0, out_dim, 2.0) / out_dim))
pos_offsets = tf.range(seq_length - 1., -1., -1.)
inputs = pos_offsets[:, None] * inverse_freq[None, :]
return tf.concat((tf.sin(inputs), tf.cos(inputs)), axis=-1)
+114
View File
@@ -0,0 +1,114 @@
import numpy as np
from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils.framework import get_activation_fn, try_import_tf
tf = try_import_tf()
class FullyConnectedNetwork(TFModelV2):
"""Generic fully connected network implemented in ModelV2 API."""
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
super(FullyConnectedNetwork, self).__init__(
obs_space, action_space, num_outputs, model_config, name)
activation = get_activation_fn(model_config.get("fcnet_activation"))
hiddens = model_config.get("fcnet_hiddens", [])
no_final_linear = model_config.get("no_final_linear")
vf_share_layers = model_config.get("vf_share_layers")
free_log_std = model_config.get("free_log_std")
# Maybe generate free-floating bias variables for the second half of
# the outputs.
if free_log_std:
assert num_outputs % 2 == 0, (
"num_outputs must be divisible by two", num_outputs)
num_outputs = num_outputs // 2
self.log_std_var = tf.Variable(
[0.0] * num_outputs, dtype=tf.float32, name="log_std")
self.register_variables([self.log_std_var])
# We are using obs_flat, so take the flattened shape as input.
inputs = tf.keras.layers.Input(
shape=(np.product(obs_space.shape), ), name="observations")
last_layer = layer_out = inputs
i = 1
# Create layers 0 to second-last.
for size in hiddens[:-1]:
last_layer = tf.keras.layers.Dense(
size,
name="fc_{}".format(i),
activation=activation,
kernel_initializer=normc_initializer(1.0))(last_layer)
i += 1
# The last layer is adjusted to be of size num_outputs, but it's a
# layer with activation.
if no_final_linear and num_outputs:
layer_out = tf.keras.layers.Dense(
num_outputs,
name="fc_out",
activation=activation,
kernel_initializer=normc_initializer(1.0))(last_layer)
# Finish the layers with the provided sizes (`hiddens`), plus -
# iff num_outputs > 0 - a last linear layer of size num_outputs.
else:
if len(hiddens) > 0:
last_layer = tf.keras.layers.Dense(
hiddens[-1],
name="fc_{}".format(i),
activation=activation,
kernel_initializer=normc_initializer(1.0))(last_layer)
if num_outputs:
layer_out = tf.keras.layers.Dense(
num_outputs,
name="fc_out",
activation=None,
kernel_initializer=normc_initializer(0.01))(last_layer)
# Adjust num_outputs to be the number of nodes in the last layer.
else:
self.num_outputs = (
[np.product(obs_space.shape)] + hiddens[-1:-1])[-1]
# Concat the log std vars to the end of the state-dependent means.
if free_log_std:
def tiled_log_std(x):
return tf.tile(
tf.expand_dims(self.log_std_var, 0), [tf.shape(x)[0], 1])
log_std_out = tf.keras.layers.Lambda(tiled_log_std)(inputs)
layer_out = tf.keras.layers.Concatenate(axis=1)(
[layer_out, log_std_out])
if not vf_share_layers:
# build a parallel set of hidden layers for the value net
last_layer = inputs
i = 1
for size in hiddens:
last_layer = tf.keras.layers.Dense(
size,
name="fc_value_{}".format(i),
activation=activation,
kernel_initializer=normc_initializer(1.0))(last_layer)
i += 1
value_out = tf.keras.layers.Dense(
1,
name="value_out",
activation=None,
kernel_initializer=normc_initializer(0.01))(last_layer)
self.base_model = tf.keras.Model(inputs, [layer_out, value_out])
self.register_variables(self.base_model.variables)
def forward(self, input_dict, state, seq_lens):
model_out, self._value_out = self.base_model(input_dict["obs_flat"])
return model_out, state
def value_function(self):
return tf.reshape(self._value_out, [-1])
+1 -1
View File
@@ -7,7 +7,7 @@ from ray.rllib.utils.framework import get_activation_fn, try_import_tf
tf = try_import_tf()
# Deprecated: see as an alternative models/tf/fcnet_v2.py
# Deprecated: see as an alternative models/tf.fcnet.py
class FullyConnectedNetwork(Model):
"""Generic fully connected network."""
+6 -113
View File
@@ -1,114 +1,7 @@
import numpy as np
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork as TFFCNet
from ray.rllib.utils.deprecation import renamed_class
from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils.framework import get_activation_fn, try_import_tf
tf = try_import_tf()
class FullyConnectedNetwork(TFModelV2):
"""Generic fully connected network implemented in ModelV2 API."""
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
super(FullyConnectedNetwork, self).__init__(
obs_space, action_space, num_outputs, model_config, name)
activation = get_activation_fn(model_config.get("fcnet_activation"))
hiddens = model_config.get("fcnet_hiddens", [])
no_final_linear = model_config.get("no_final_linear")
vf_share_layers = model_config.get("vf_share_layers")
free_log_std = model_config.get("free_log_std")
# Maybe generate free-floating bias variables for the second half of
# the outputs.
if free_log_std:
assert num_outputs % 2 == 0, (
"num_outputs must be divisible by two", num_outputs)
num_outputs = num_outputs // 2
self.log_std_var = tf.Variable(
[0.0] * num_outputs, dtype=tf.float32, name="log_std")
self.register_variables([self.log_std_var])
# We are using obs_flat, so take the flattened shape as input.
inputs = tf.keras.layers.Input(
shape=(np.product(obs_space.shape), ), name="observations")
last_layer = layer_out = inputs
i = 1
# Create layers 0 to second-last.
for size in hiddens[:-1]:
last_layer = tf.keras.layers.Dense(
size,
name="fc_{}".format(i),
activation=activation,
kernel_initializer=normc_initializer(1.0))(last_layer)
i += 1
# The last layer is adjusted to be of size num_outputs, but it's a
# layer with activation.
if no_final_linear and num_outputs:
layer_out = tf.keras.layers.Dense(
num_outputs,
name="fc_out",
activation=activation,
kernel_initializer=normc_initializer(1.0))(last_layer)
# Finish the layers with the provided sizes (`hiddens`), plus -
# iff num_outputs > 0 - a last linear layer of size num_outputs.
else:
if len(hiddens) > 0:
last_layer = tf.keras.layers.Dense(
hiddens[-1],
name="fc_{}".format(i),
activation=activation,
kernel_initializer=normc_initializer(1.0))(last_layer)
if num_outputs:
layer_out = tf.keras.layers.Dense(
num_outputs,
name="fc_out",
activation=None,
kernel_initializer=normc_initializer(0.01))(last_layer)
# Adjust num_outputs to be the number of nodes in the last layer.
else:
self.num_outputs = (
[np.product(obs_space.shape)] + hiddens[-1:-1])[-1]
# Concat the log std vars to the end of the state-dependent means.
if free_log_std:
def tiled_log_std(x):
return tf.tile(
tf.expand_dims(self.log_std_var, 0), [tf.shape(x)[0], 1])
log_std_out = tf.keras.layers.Lambda(tiled_log_std)(inputs)
layer_out = tf.keras.layers.Concatenate(axis=1)(
[layer_out, log_std_out])
if not vf_share_layers:
# build a parallel set of hidden layers for the value net
last_layer = inputs
i = 1
for size in hiddens:
last_layer = tf.keras.layers.Dense(
size,
name="fc_value_{}".format(i),
activation=activation,
kernel_initializer=normc_initializer(1.0))(last_layer)
i += 1
value_out = tf.keras.layers.Dense(
1,
name="value_out",
activation=None,
kernel_initializer=normc_initializer(0.01))(last_layer)
self.base_model = tf.keras.Model(inputs, [layer_out, value_out])
self.register_variables(self.base_model.variables)
def forward(self, input_dict, state, seq_lens):
model_out, self._value_out = self.base_model(input_dict["obs_flat"])
return model_out, state
def value_function(self):
return tf.reshape(self._value_out, [-1])
FullyConnectedNetwork = renamed_class(
cls=TFFCNet,
old_name="ray.rllib.models.tf.fcnet_v2.FullyConnectedNetwork",
)
+6
View File
@@ -0,0 +1,6 @@
from ray.rllib.models.tf.layers.gru_gate import GRUGate
from ray.rllib.models.tf.layers.relative_multi_head_attention import \
RelativeMultiHeadAttention
from ray.rllib.models.tf.layers.skip_connection import SkipConnection
__all__ = ["GRUGate", "RelativeMultiHeadAttention", "SkipConnection"]
+48
View File
@@ -0,0 +1,48 @@
from ray.rllib.utils.framework import try_import_tf
tf = try_import_tf()
class GRUGate(tf.keras.layers.Layer):
def __init__(self, init_bias=0., **kwargs):
super().__init__(**kwargs)
self._init_bias = init_bias
def build(self, input_shape):
h_shape, x_shape = input_shape
if x_shape[-1] != h_shape[-1]:
raise ValueError(
"Both inputs to GRUGate must have equal size in last axis!")
dim = int(h_shape[-1])
self._w_r = self.add_weight(shape=(dim, dim))
self._w_z = self.add_weight(shape=(dim, dim))
self._w_h = self.add_weight(shape=(dim, dim))
self._u_r = self.add_weight(shape=(dim, dim))
self._u_z = self.add_weight(shape=(dim, dim))
self._u_h = self.add_weight(shape=(dim, dim))
def bias_initializer(shape, dtype):
return tf.fill(shape, tf.cast(self._init_bias, dtype=dtype))
self._bias_z = self.add_weight(
shape=(dim, ), initializer=bias_initializer)
def call(self, inputs, **kwargs):
# Pass in internal state first.
h, X = inputs
r = tf.tensordot(X, self._w_r, axes=1) + \
tf.tensordot(h, self._u_r, axes=1)
r = tf.nn.sigmoid(r)
z = tf.tensordot(X, self._w_z, axes=1) + \
tf.tensordot(h, self._u_z, axes=1) - self._bias_z
z = tf.nn.sigmoid(z)
h_next = tf.tensordot(X, self._w_h, axes=1) + \
tf.tensordot((h * r), self._u_h, axes=1)
h_next = tf.nn.tanh(h_next)
return (1 - z) * h + z * h_next
@@ -0,0 +1,51 @@
"""
[1] - Attention Is All You Need - Vaswani, Jones, Shazeer, Parmar,
Uszkoreit, Gomez, Kaiser - Google Brain/Research, U Toronto - 2017.
https://arxiv.org/pdf/1706.03762.pdf
"""
from ray.rllib.utils.framework import try_import_tf
tf = try_import_tf()
class MultiHeadAttention(tf.keras.layers.Layer):
"""A multi-head attention layer described in [1]."""
def __init__(self, out_dim, num_heads, head_dim, **kwargs):
super().__init__(**kwargs)
# No bias or non-linearity.
self._num_heads = num_heads
self._head_dim = head_dim
self._qkv_layer = tf.keras.layers.Dense(
3 * num_heads * head_dim, use_bias=False)
self._linear_layer = tf.keras.layers.TimeDistributed(
tf.keras.layers.Dense(out_dim, use_bias=False))
def call(self, inputs):
L = tf.shape(inputs)[1] # length of segment
H = self._num_heads # number of attention heads
D = self._head_dim # attention head dimension
qkv = self._qkv_layer(inputs)
queries, keys, values = tf.split(qkv, 3, -1)
queries = queries[:, -L:] # only query based on the segment
queries = tf.reshape(queries, [-1, L, H, D])
keys = tf.reshape(keys, [-1, L, H, D])
values = tf.reshape(values, [-1, L, H, D])
score = tf.einsum("bihd,bjhd->bijh", queries, keys)
score = score / D**0.5
# causal mask of the same length as the sequence
mask = tf.sequence_mask(tf.range(1, L + 1), dtype=score.dtype)
mask = mask[None, :, :, None]
masked_score = score * mask + 1e30 * (mask - 1.)
wmat = tf.nn.softmax(masked_score, axis=2)
out = tf.einsum("bijh,bjhd->bihd", wmat, values)
out = tf.reshape(out, tf.concat((tf.shape(out)[:2], [H * D]), axis=0))
return self._linear_layer(out)
@@ -0,0 +1,119 @@
from ray.rllib.utils.framework import try_import_tf
tf = try_import_tf()
class RelativeMultiHeadAttention(tf.keras.layers.Layer):
"""A RelativeMultiHeadAttention layer as described in [3].
Uses segment level recurrence with state reuse.
"""
def __init__(self,
out_dim,
num_heads,
head_dim,
rel_pos_encoder,
input_layernorm=False,
output_activation=None,
**kwargs):
"""Initializes a RelativeMultiHeadAttention keras Layer object.
Args:
out_dim (int):
num_heads (int): The number of attention heads to use.
Denoted `H` in [2].
head_dim (int): The dimension of a single(!) attention head
Denoted `D` in [2].
rel_pos_encoder (:
input_layernorm (bool): Whether to prepend a LayerNorm before
everything else. Should be True for building a GTrXL.
output_activation (Optional[tf.nn.activation]): Optional tf.nn
activation function. Should be relu for GTrXL.
**kwargs:
"""
super().__init__(**kwargs)
# No bias or non-linearity.
self._num_heads = num_heads
self._head_dim = head_dim
# 3=Query, key, and value inputs.
self._qkv_layer = tf.keras.layers.Dense(
3 * num_heads * head_dim, use_bias=False)
self._linear_layer = tf.keras.layers.TimeDistributed(
tf.keras.layers.Dense(
out_dim, use_bias=False, activation=output_activation))
self._uvar = self.add_weight(shape=(num_heads, head_dim))
self._vvar = self.add_weight(shape=(num_heads, head_dim))
self._pos_proj = tf.keras.layers.Dense(
num_heads * head_dim, use_bias=False)
self._rel_pos_encoder = rel_pos_encoder
self._input_layernorm = None
if input_layernorm:
self._input_layernorm = tf.keras.layers.LayerNormalization(axis=-1)
def call(self, inputs, memory=None):
T = tf.shape(inputs)[1] # length of segment (time)
H = self._num_heads # number of attention heads
d = self._head_dim # attention head dimension
# Add previous memory chunk (as const, w/o gradient) to input.
# Tau (number of (prev) time slices in each memory chunk).
Tau = memory.shape.as_list()[1] if memory is not None else 0
if memory is not None:
inputs = tf.concat((tf.stop_gradient(memory), inputs), axis=1)
# Apply the Layer-Norm.
if self._input_layernorm is not None:
inputs = self._input_layernorm(inputs)
qkv = self._qkv_layer(inputs)
queries, keys, values = tf.split(qkv, 3, -1)
# Cut out Tau memory timesteps from query.
queries = queries[:, -T:]
queries = tf.reshape(queries, [-1, T, H, d])
keys = tf.reshape(keys, [-1, T + Tau, H, d])
values = tf.reshape(values, [-1, T + Tau, H, d])
R = self._pos_proj(self._rel_pos_encoder)
R = tf.reshape(R, [T + Tau, H, d])
# b=batch
# i and j=time indices (i=max-timesteps (inputs); j=Tau memory space)
# h=head
# d=head-dim (over which we will reduce-sum)
score = tf.einsum("bihd,bjhd->bijh", queries + self._uvar, keys)
pos_score = tf.einsum("bihd,jhd->bijh", queries + self._vvar, R)
score = score + self.rel_shift(pos_score)
score = score / d**0.5
# causal mask of the same length as the sequence
mask = tf.sequence_mask(
tf.range(Tau + 1, T + Tau + 1), dtype=score.dtype)
mask = mask[None, :, :, None]
masked_score = score * mask + 1e30 * (mask - 1.)
wmat = tf.nn.softmax(masked_score, axis=2)
out = tf.einsum("bijh,bjhd->bihd", wmat, values)
out = tf.reshape(out, tf.concat((tf.shape(out)[:2], [H * d]), axis=0))
return self._linear_layer(out)
@staticmethod
def rel_shift(x):
# Transposed version of the shift approach described in [3].
# https://github.com/kimiyoung/transformer-xl/blob/
# 44781ed21dbaec88b280f74d9ae2877f52b492a5/tf/model.py#L31
x_size = tf.shape(x)
x = tf.pad(x, [[0, 0], [0, 0], [1, 0], [0, 0]])
x = tf.reshape(x, [x_size[0], x_size[2] + 1, x_size[1], x_size[3]])
x = tf.slice(x, [0, 1, 0, 0], [-1, -1, -1, -1])
x = tf.reshape(x, x_size)
return x
+37
View File
@@ -0,0 +1,37 @@
from ray.rllib.utils.framework import try_import_tf
tf = try_import_tf()
class SkipConnection(tf.keras.layers.Layer):
"""Skip connection layer.
Adds the original input to the output (regular residual layer) OR uses
input as hidden state input to a given fan_in_layer.
"""
def __init__(self, layer, fan_in_layer=None, add_memory=False, **kwargs):
"""Initializes a SkipConnection keras layer object.
Args:
layer (tf.keras.layers.Layer): Any layer processing inputs.
fan_in_layer (Optional[tf.keras.layers.Layer]): An optional
layer taking two inputs: The original input and the output
of `layer`.
"""
super().__init__(**kwargs)
self._layer = layer
self._fan_in_layer = fan_in_layer
def call(self, inputs, **kwargs):
# del kwargs
outputs = self._layer(inputs, **kwargs)
# Residual case, just add inputs to outputs.
if self._fan_in_layer is None:
outputs = outputs + inputs
# Fan-in e.g. RNN: Call fan-in with `inputs` and `outputs`.
else:
# NOTE: In the GRU case, `inputs` is the state input.
outputs = self._fan_in_layer((inputs, outputs))
return outputs
+2 -2
View File
@@ -10,7 +10,7 @@ from ray.rllib.utils.framework import try_import_tf
tf = try_import_tf()
# Deprecated: see as an alternative models/tf/recurrent_tf_modelv2.py
# Deprecated: see as an alternative models/tf/recurrent_net.py
class LSTM(Model):
"""Adds a LSTM cell on top of some other model output.
@@ -24,7 +24,7 @@ class LSTM(Model):
def _build_layers_v2(self, input_dict, num_outputs, options):
# Hard deprecate this class. All Models should use the ModelV2
# API from here on.
deprecation_warning("Model->LSTM", "RecurrentTFModelV2", error=False)
deprecation_warning("Model->LSTM", "RecurrentNetwork", error=False)
cell_size = options.get("lstm_cell_size")
if options.get("lstm_use_prev_action_reward"):
@@ -8,14 +8,14 @@ tf = try_import_tf()
@DeveloperAPI
class RecurrentTFModelV2(TFModelV2):
class RecurrentNetwork(TFModelV2):
"""Helper class to simplify implementing RNN models with TFModelV2.
Instead of implementing forward(), you can implement forward_rnn() which
takes batches with the time dimension added already.
Here is an example implementation for a subclass
``MyRNNClass(RecurrentTFModelV2)``::
``MyRNNClass(RecurrentNetwork)``::
def __init__(self, *args, **kwargs):
super(MyModelClass, self).__init__(*args, **kwargs)
@@ -50,6 +50,8 @@ class RecurrentTFModelV2(TFModelV2):
"""Adds time dimension to batch before sending inputs to forward_rnn().
You should implement forward_rnn() in your subclass."""
assert seq_lens is not None
output, new_state = self.forward_rnn(
add_time_dimension(
input_dict["obs_flat"], seq_lens, framework="tf"), state,
+7
View File
@@ -0,0 +1,7 @@
from ray.rllib.models.tf.recurrent_net import RecurrentNetwork
from ray.rllib.utils.deprecation import renamed_class
RecurrentTFModelV2 = renamed_class(
cls=RecurrentNetwork,
old_name="ray.rllib.models.tf.recurrent_tf_model_v2.RecurrentTFModelV2",
)
+118
View File
@@ -0,0 +1,118 @@
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.visionnet_v1 import _get_filter_config
from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.utils.framework import get_activation_fn, try_import_tf
tf = try_import_tf()
class VisionNetwork(TFModelV2):
"""Generic vision network implemented in ModelV2 API."""
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
super(VisionNetwork, self).__init__(obs_space, action_space,
num_outputs, model_config, name)
activation = get_activation_fn(model_config.get("conv_activation"))
filters = model_config.get("conv_filters")
if not filters:
filters = _get_filter_config(obs_space.shape)
no_final_linear = model_config.get("no_final_linear")
vf_share_layers = model_config.get("vf_share_layers")
inputs = tf.keras.layers.Input(
shape=obs_space.shape, name="observations")
last_layer = inputs
# Build the action layers
for i, (out_size, kernel, stride) in enumerate(filters[:-1], 1):
last_layer = tf.keras.layers.Conv2D(
out_size,
kernel,
strides=(stride, stride),
activation=activation,
padding="same",
data_format="channels_last",
name="conv{}".format(i))(last_layer)
out_size, kernel, stride = filters[-1]
# No final linear: Last layer is a Conv2D and uses num_outputs.
if no_final_linear:
last_layer = tf.keras.layers.Conv2D(
num_outputs,
kernel,
strides=(stride, stride),
activation=activation,
padding="valid",
data_format="channels_last",
name="conv_out")(last_layer)
conv_out = last_layer
# Finish network normally (w/o overriding last layer size with
# `num_outputs`), then add another linear one of size `num_outputs`.
else:
last_layer = tf.keras.layers.Conv2D(
out_size,
kernel,
strides=(stride, stride),
activation=activation,
padding="valid",
data_format="channels_last",
name="conv{}".format(i + 1))(last_layer)
conv_out = tf.keras.layers.Conv2D(
num_outputs, [1, 1],
activation=None,
padding="same",
data_format="channels_last",
name="conv_out")(last_layer)
# Build the value layers
if vf_share_layers:
last_layer = tf.keras.layers.Lambda(
lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer)
value_out = tf.keras.layers.Dense(
1,
name="value_out",
activation=None,
kernel_initializer=normc_initializer(0.01))(last_layer)
else:
# build a parallel set of hidden layers for the value net
last_layer = inputs
for i, (out_size, kernel, stride) in enumerate(filters[:-1], 1):
last_layer = tf.keras.layers.Conv2D(
out_size,
kernel,
strides=(stride, stride),
activation=activation,
padding="same",
data_format="channels_last",
name="conv_value_{}".format(i))(last_layer)
out_size, kernel, stride = filters[-1]
last_layer = tf.keras.layers.Conv2D(
out_size,
kernel,
strides=(stride, stride),
activation=activation,
padding="valid",
data_format="channels_last",
name="conv_value_{}".format(i + 1))(last_layer)
last_layer = tf.keras.layers.Conv2D(
1, [1, 1],
activation=None,
padding="same",
data_format="channels_last",
name="conv_value_out")(last_layer)
value_out = tf.keras.layers.Lambda(
lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer)
self.base_model = tf.keras.Model(inputs, [conv_out, value_out])
self.register_variables(self.base_model.variables)
def forward(self, input_dict, state, seq_lens):
# explicit cast to float32 needed in eager
model_out, self._value_out = self.base_model(
tf.cast(input_dict["obs"], tf.float32))
return tf.squeeze(model_out, axis=[1, 2]), state
def value_function(self):
return tf.reshape(self._value_out, [-1])
+1 -1
View File
@@ -7,7 +7,7 @@ from ray.rllib.utils.framework import get_activation_fn, try_import_tf
tf = try_import_tf()
# Deprecated: see as an alternative models/tf/visionnet_v2.py
# Deprecated: see as an alternative models/tf.visionnet.py
class VisionNetwork(Model):
"""Generic vision network."""
+6 -117
View File
@@ -1,118 +1,7 @@
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.visionnet_v1 import _get_filter_config
from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.utils.framework import get_activation_fn, try_import_tf
from ray.rllib.models.tf.vision_net import VisionNetwork as TFVision
from ray.rllib.utils.deprecation import renamed_class
tf = try_import_tf()
class VisionNetwork(TFModelV2):
"""Generic vision network implemented in ModelV2 API."""
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
super(VisionNetwork, self).__init__(obs_space, action_space,
num_outputs, model_config, name)
activation = get_activation_fn(model_config.get("conv_activation"))
filters = model_config.get("conv_filters")
if not filters:
filters = _get_filter_config(obs_space.shape)
no_final_linear = model_config.get("no_final_linear")
vf_share_layers = model_config.get("vf_share_layers")
inputs = tf.keras.layers.Input(
shape=obs_space.shape, name="observations")
last_layer = inputs
# Build the action layers
for i, (out_size, kernel, stride) in enumerate(filters[:-1], 1):
last_layer = tf.keras.layers.Conv2D(
out_size,
kernel,
strides=(stride, stride),
activation=activation,
padding="same",
data_format="channels_last",
name="conv{}".format(i))(last_layer)
out_size, kernel, stride = filters[-1]
# No final linear: Last layer is a Conv2D and uses num_outputs.
if no_final_linear:
last_layer = tf.keras.layers.Conv2D(
num_outputs,
kernel,
strides=(stride, stride),
activation=activation,
padding="valid",
data_format="channels_last",
name="conv_out")(last_layer)
conv_out = last_layer
# Finish network normally (w/o overriding last layer size with
# `num_outputs`), then add another linear one of size `num_outputs`.
else:
last_layer = tf.keras.layers.Conv2D(
out_size,
kernel,
strides=(stride, stride),
activation=activation,
padding="valid",
data_format="channels_last",
name="conv{}".format(i + 1))(last_layer)
conv_out = tf.keras.layers.Conv2D(
num_outputs, [1, 1],
activation=None,
padding="same",
data_format="channels_last",
name="conv_out")(last_layer)
# Build the value layers
if vf_share_layers:
last_layer = tf.keras.layers.Lambda(
lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer)
value_out = tf.keras.layers.Dense(
1,
name="value_out",
activation=None,
kernel_initializer=normc_initializer(0.01))(last_layer)
else:
# build a parallel set of hidden layers for the value net
last_layer = inputs
for i, (out_size, kernel, stride) in enumerate(filters[:-1], 1):
last_layer = tf.keras.layers.Conv2D(
out_size,
kernel,
strides=(stride, stride),
activation=activation,
padding="same",
data_format="channels_last",
name="conv_value_{}".format(i))(last_layer)
out_size, kernel, stride = filters[-1]
last_layer = tf.keras.layers.Conv2D(
out_size,
kernel,
strides=(stride, stride),
activation=activation,
padding="valid",
data_format="channels_last",
name="conv_value_{}".format(i + 1))(last_layer)
last_layer = tf.keras.layers.Conv2D(
1, [1, 1],
activation=None,
padding="same",
data_format="channels_last",
name="conv_value_out")(last_layer)
value_out = tf.keras.layers.Lambda(
lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer)
self.base_model = tf.keras.Model(inputs, [conv_out, value_out])
self.register_variables(self.base_model.variables)
def forward(self, input_dict, state, seq_lens):
# explicit cast to float32 needed in eager
model_out, self._value_out = self.base_model(
tf.cast(input_dict["obs"], tf.float32))
return tf.squeeze(model_out, axis=[1, 2]), state
def value_function(self):
return tf.reshape(self._value_out, [-1])
VisionNetwork = renamed_class(
cls=TFVision,
old_name="ray.rllib.models.tf.visionnet_v2.VisionNetwork",
)
+3 -3
View File
@@ -2,13 +2,13 @@
# dependencies b/c of that.
# from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
# from ray.rllib.models.torch.fcnet import FullyConnectedNetwork
# from ray.rllib.models.torch.recurrent_torch_model import \
# RecurrentTorchModel
# from ray.rllib.models.torch.recurrent_net import \
# RecurrentNetwork
# from ray.rllib.models.torch.visionnet import VisionNetwork
# __all__ = [
# "FullyConnectedNetwork",
# "RecurrentTorchModel",
# "RecurrentNetwork",
# "TorchModelV2",
# "VisionNetwork",
# ]
View File
@@ -10,14 +10,14 @@ torch, nn = try_import_torch()
@DeveloperAPI
class RecurrentTorchModel(TorchModelV2, nn.Module):
class RecurrentNetwork(TorchModelV2, nn.Module):
"""Helper class to simplify implementing RNN models with TorchModelV2.
Instead of implementing forward(), you can implement forward_rnn() which
takes batches with the time dimension added already.
Here is an example implementation for a subclass
``MyRNNClass(nn.Module, RecurrentTorchModel)``::
``MyRNNClass(nn.Module, RecurrentNetwork)``::
def __init__(self, obs_space, num_outputs):
self.obs_size = _get_size(obs_space)
@@ -41,7 +41,7 @@ class RecurrentTorchModel(TorchModelV2, nn.Module):
assert self._cur_value is not None, "must call forward() first"
return self._cur_value
@override(RecurrentTorchModel)
@override(RecurrentNetwork)
def forward_rnn(self, input_dict, state, seq_lens):
x = nn.functional.relu(self.fc1(input_dict["obs_flat"].float()))
h_in = state[0].reshape(-1, self.rnn_hidden_dim)
+2 -5
View File
@@ -160,11 +160,8 @@ class DynamicTFPolicy(TFPolicy):
action_space=action_space,
num_outputs=logit_dim,
model_config=self.config["model"],
framework="tf")
# NOTE: Adding below line will break existing custom models
# that do not expect extra options in **kwargs but rather in
# model_config["custom_options"].
# **self.config["model"].get("custom_options", {}))
framework="tf",
**self.config["model"].get("custom_options", {}))
# Create the Exploration object to use for this Policy.
self.exploration = self._create_exploration()
+2 -2
View File
@@ -9,8 +9,8 @@ from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
from ray.rllib.models.preprocessors import (NoPreprocessor, OneHotPreprocessor,
Preprocessor)
from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork
from ray.rllib.models.tf.visionnet_v2 import VisionNetwork
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.rllib.models.tf.visionnet import VisionNetwork
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
+3 -128
View File
@@ -1,20 +1,14 @@
import gym
import numpy as np
import pickle
import unittest
import ray
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.policy.rnn_sequencing import chop_into_sequences
from ray.rllib.examples.env.debug_counter_env import DebugCounterEnv
from ray.rllib.examples.models.rnn_spy_model import RNNSpyModel
from ray.rllib.models import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.models.tf.recurrent_tf_modelv2 import RecurrentTFModelV2
from ray.rllib.policy.rnn_sequencing import chop_into_sequences
from ray.tune.registry import register_env
from ray.rllib.utils import try_import_tf
from ray.rllib.utils.annotations import override
tf = try_import_tf()
class TestLSTMUtils(unittest.TestCase):
@@ -92,125 +86,6 @@ class TestLSTMUtils(unittest.TestCase):
self.assertEqual(seq_lens.tolist(), [1, 2])
class RNNSpyModel(RecurrentTFModelV2):
capture_index = 0
cell_size = 3
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
super().__init__(obs_space, action_space, num_outputs, model_config,
name)
self.cell_size = RNNSpyModel.cell_size
def spy(inputs, seq_lens, h_in, c_in, h_out, c_out):
if len(inputs) == 1:
return 0 # don't capture inference inputs
# TF runs this function in an isolated context, so we have to use
# redis to communicate back to our suite
ray.experimental.internal_kv._internal_kv_put(
"rnn_spy_in_{}".format(RNNSpyModel.capture_index),
pickle.dumps({
"sequences": inputs,
"seq_lens": seq_lens,
"state_in": [h_in, c_in],
"state_out": [h_out, c_out]
}),
overwrite=True)
RNNSpyModel.capture_index += 1
return 0
# Create a keras LSTM model.
inputs = tf.keras.layers.Input(
shape=(None, ) + obs_space.shape, name="input")
state_in_h = tf.keras.layers.Input(shape=(self.cell_size, ), name="h")
state_in_c = tf.keras.layers.Input(shape=(self.cell_size, ), name="c")
seq_lens = tf.keras.layers.Input(
shape=(), name="seq_lens", dtype=tf.int32)
lstm_out, state_out_h, state_out_c = tf.keras.layers.LSTM(
self.cell_size,
return_sequences=True,
return_state=True,
name="lstm")(
inputs=inputs,
mask=tf.sequence_mask(seq_lens),
initial_state=[state_in_h, state_in_c])
self.dense = tf.keras.layers.Dense(
units=self.num_outputs, kernel_initializer=normc_initializer(0.01))
def lambda_(inputs):
spy_fn = tf.py_func(
spy,
[
inputs[0], # observations
inputs[2], # seq_lens
inputs[3], # h_in
inputs[4], # c_in
inputs[5], # h_out
inputs[6], # c_out
],
tf.int64,
stateful=True)
# Compute outputs
with tf.control_dependencies([spy_fn]):
return self.dense(inputs[1]) # lstm_out
logits = tf.keras.layers.Lambda(lambda_)([
inputs, lstm_out, seq_lens, state_in_h, state_in_c, state_out_h,
state_out_c
])
# Value branch.
value_out = tf.keras.layers.Dense(
units=1, kernel_initializer=normc_initializer(1.0))(lstm_out)
self.base_model = tf.keras.Model(
[inputs, seq_lens, state_in_h, state_in_c],
[logits, value_out, state_out_h, state_out_c])
self.base_model.summary()
self.register_variables(self.base_model.variables)
@override(RecurrentTFModelV2)
def forward_rnn(self, inputs, state, seq_lens):
# Previously, a new class object was created during
# deserialization and this `capture_index`
# variable would be refreshed between class instantiations.
# This behavior is no longer the case, so we manually refresh
# the variable.
RNNSpyModel.capture_index = 0
model_out, value_out, h, c = self.base_model(
[inputs, seq_lens, state[0], state[1]])
self._value_out = value_out
return model_out, [h, c]
@override(ModelV2)
def value_function(self):
return tf.reshape(self._value_out, [-1])
@override(ModelV2)
def get_initial_state(self):
return [
np.zeros(self.cell_size, np.float32),
np.zeros(self.cell_size, np.float32)
]
class DebugCounterEnv(gym.Env):
def __init__(self):
self.action_space = gym.spaces.Discrete(2)
self.observation_space = gym.spaces.Box(0, 100, (1, ))
self.i = 0
def reset(self):
self.i = 0
return [self.i]
def step(self, action):
self.i += 1
return [self.i], self.i % 3, self.i >= 15, {}
class TestRNNSequencing(unittest.TestCase):
def setUp(self) -> None:
ray.init(num_cpus=4)
+8 -7
View File
@@ -9,8 +9,8 @@ from ray.rllib.agents.registry import get_agent_class
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole, \
MultiAgentMountainCar
from ray.rllib.examples.env.random_env import RandomEnv
from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork as FCNetV2
from ray.rllib.models.tf.visionnet_v2 import VisionNetwork as VisionNetV2
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork as FCNetV2
from ray.rllib.models.tf.visionnet import VisionNetwork as VisionNetV2
from ray.rllib.models.torch.visionnet import VisionNetwork as TorchVisionNetV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFCNetV2
from ray.rllib.utils.error import UnsupportedSpaceException
@@ -253,11 +253,12 @@ class ModelSupportedSpaces(unittest.TestCase):
})
def test_ddpg_multiagent(self):
check_support_multiagent("DDPG", {
"timesteps_per_iteration": 1,
"use_state_preprocessor": True,
"learning_starts": 500,
})
check_support_multiagent(
"DDPG", {
"timesteps_per_iteration": 1,
"use_state_preprocessor": True,
"learning_starts": 500,
})
def test_dqn_multiagent(self):
check_support_multiagent("DQN", {"timesteps_per_iteration": 1})