From d5604eaba321c11c1b9616c283262c4ddea55049 Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Mon, 21 Dec 2020 21:38:34 -0500 Subject: [PATCH] [RLlib] Attention nets PyTorch support and cleanup (using traj. view API). (#12029) --- rllib/BUILD | 19 ++- rllib/agents/ppo/ppo_torch_policy.py | 3 +- .../collectors/simple_list_collector.py | 8 +- .../tests/test_trajectory_view_api.py | 97 +++++++++++- rllib/examples/attention_net.py | 5 +- .../examples/custom_metrics_and_callbacks.py | 1 + rllib/examples/env/debug_counter_env.py | 19 ++- .../models/centralized_critic_models.py | 9 +- rllib/models/torch/attention_net.py | 140 +++++++++--------- rllib/models/torch/modules/gru_gate.py | 25 ++-- .../modules/relative_multi_head_attention.py | 84 ++++++++--- rllib/policy/eager_tf_policy.py | 12 +- rllib/policy/tf_policy_template.py | 3 - rllib/policy/torch_policy.py | 15 +- 14 files changed, 292 insertions(+), 148 deletions(-) diff --git a/rllib/BUILD b/rllib/BUILD index c645c27a0..44a147b6d 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1480,20 +1480,19 @@ py_test( name = "examples/attention_net_tf", main = "examples/attention_net.py", tags = ["examples", "examples_A"], - size = "large", + size = "medium", srcs = ["examples/attention_net.py"], args = ["--as-test", "--stop-reward=80"] ) -# TODO(sven): GTrXL PyTorch. 
-# py_test( -# name = "examples/attention_net_torch", -# main = "examples/attention_net.py", -# tags = ["examples", "examples_A"], -# size = "large", -# srcs = ["examples/attention_net.py"], -# args = ["--as-test", "--torch", "--stop-reward=90"] -# ) +py_test( + name = "examples/attention_net_torch", + main = "examples/attention_net.py", + tags = ["examples", "examples_A"], + size = "medium", + srcs = ["examples/attention_net.py"], + args = ["--as-test", "--stop-reward=80", "--torch"] +) py_test( name = "examples/autoregressive_action_dist_tf", diff --git a/rllib/agents/ppo/ppo_torch_policy.py b/rllib/agents/ppo/ppo_torch_policy.py index fa2ca6c1d..d99251298 100644 --- a/rllib/agents/ppo/ppo_torch_policy.py +++ b/rllib/agents/ppo/ppo_torch_policy.py @@ -49,7 +49,8 @@ def ppo_surrogate_loss( # RNN case: Mask away 0-padded chunks at end of time axis. if state: - max_seq_len = torch.max(train_batch["seq_lens"]) + B = len(train_batch["seq_lens"]) + max_seq_len = logits.shape[0] // B mask = sequence_mask( train_batch["seq_lens"], max_seq_len, diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index 1d5fe3f76..96e6d0624 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -35,9 +35,6 @@ def to_float_np_array(v: List[Any]) -> np.ndarray: return arr -_INIT_COLS = [SampleBatch.OBS] - - class _AgentCollector: """Collects samples for one agent in one trajectory (episode). @@ -55,8 +52,9 @@ class _AgentCollector: # or internal state inputs. self.shift_before = -min( (int(vr.shift.split(":")[0]) - if isinstance(vr.shift, str) else vr.shift) + - (-1 if vr.data_col in _INIT_COLS or k in _INIT_COLS else 0) + if isinstance(vr.shift, str) else vr.shift) - + (1 + if vr.data_col == SampleBatch.OBS or k == SampleBatch.OBS else 0) for k, vr in view_reqs.items()) # The actual data buffers (lists holding each timestep's data). 
self.buffers: Dict[str, List] = {} diff --git a/rllib/evaluation/tests/test_trajectory_view_api.py b/rllib/evaluation/tests/test_trajectory_view_api.py index a50978bfd..1a13300de 100644 --- a/rllib/evaluation/tests/test_trajectory_view_api.py +++ b/rllib/evaluation/tests/test_trajectory_view_api.py @@ -7,19 +7,41 @@ import unittest import ray from ray import tune +from ray.rllib.agents.callbacks import DefaultCallbacks import ray.rllib.agents.dqn as dqn import ray.rllib.agents.ppo as ppo from ray.rllib.examples.env.debug_counter_env import MultiAgentDebugCounterEnv from ray.rllib.examples.env.multi_agent import MultiAgentCartPole from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.examples.policy.episode_env_aware_policy import \ - EpisodeEnvAwareLSTMPolicy + EpisodeEnvAwareAttentionPolicy, EpisodeEnvAwareLSTMPolicy +from ray.rllib.models.tf.attention_net import GTrXLNet from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.view_requirement import ViewRequirement +from ray.rllib.utils.annotations import override from ray.rllib.utils.test_utils import framework_iterator, check +class MyCallbacks(DefaultCallbacks): + @override(DefaultCallbacks) + def on_learn_on_batch(self, *, policy, train_batch, **kwargs): + assert train_batch.count == 201 + assert sum(train_batch.seq_lens) == 201 + for k, v in train_batch.data.items(): + if k == "state_in_0": + assert len(v) == len(train_batch.seq_lens) + else: + assert len(v) == 201 + current = None + for o in train_batch[SampleBatch.OBS]: + if current: + assert o == current + 1 + current = o + if o == 15: + current = None + + class TestTrajectoryViewAPI(unittest.TestCase): @classmethod def setUpClass(cls) -> None: @@ -116,6 +138,45 @@ class TestTrajectoryViewAPI(unittest.TestCase): assert view_req_policy[key].shift == 1 trainer.stop() + def test_traj_view_attention_net(self): + config = 
ppo.DEFAULT_CONFIG.copy() + # Setup attention net. + config["model"] = config["model"].copy() + config["model"]["max_seq_len"] = 50 + config["model"]["custom_model"] = GTrXLNet + config["model"]["custom_model_config"] = { + "num_transformer_units": 1, + "attn_dim": 64, + "num_heads": 2, + "memory_inference": 50, + "memory_training": 50, + "head_dim": 32, + "ff_hidden_dim": 32, + } + # Test with odd batch numbers. + config["train_batch_size"] = 1031 + config["sgd_minibatch_size"] = 201 + config["num_sgd_iter"] = 5 + config["num_workers"] = 0 + config["callbacks"] = MyCallbacks + config["env_config"] = { + "config": { + "start_at_t": 1 + } + } # first obs is [1.0] + + for _ in framework_iterator(config, frameworks="tf2"): + trainer = ppo.PPOTrainer( + config, + env="ray.rllib.examples.env.debug_counter_env.DebugCounterEnv", + ) + rw = trainer.workers.local_worker() + sample = rw.sample() + assert sample.count == config["rollout_fragment_length"] + results = trainer.train() + assert results["train_batch_size"] == config["train_batch_size"] + trainer.stop() + def test_traj_view_simple_performance(self): """Test whether PPOTrainer runs faster w/ `_use_trajectory_view_api`. 
""" @@ -298,6 +359,40 @@ class TestTrajectoryViewAPI(unittest.TestCase): pol_batch_wo = result.policy_batches["pol0"] check(pol_batch_w.data, pol_batch_wo.data) + def test_traj_view_attention_functionality(self): + action_space = Box(-float("inf"), float("inf"), shape=(3, )) + obs_space = Box(float("-inf"), float("inf"), (4, )) + max_seq_len = 50 + rollout_fragment_length = 201 + policies = { + "pol0": (EpisodeEnvAwareAttentionPolicy, obs_space, action_space, + {}), + } + + def policy_fn(agent_id): + return "pol0" + + config = { + "multiagent": { + "policies": policies, + "policy_mapping_fn": policy_fn, + }, + "model": { + "max_seq_len": max_seq_len, + }, + }, + + rollout_worker_w_api = RolloutWorker( + env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}), + policy_config=dict(config, **{"_use_trajectory_view_api": True}), + rollout_fragment_length=rollout_fragment_length, + policy_spec=policies, + policy_mapping_fn=policy_fn, + num_envs=1, + ) + batch = rollout_worker_w_api.sample() + print(batch) + def test_counting_by_agent_steps(self): """Test whether a PPOTrainer can be built with all frameworks.""" config = copy.deepcopy(ppo.DEFAULT_CONFIG) diff --git a/rllib/examples/attention_net.py b/rllib/examples/attention_net.py index de3f06c29..a490b73e9 100644 --- a/rllib/examples/attention_net.py +++ b/rllib/examples/attention_net.py @@ -4,6 +4,7 @@ import os import ray from ray import tune from ray.rllib.models.tf.attention_net import GTrXLNet +from ray.rllib.models.torch.attention_net import GTrXLNet as TorchGTrXLNet from ray.rllib.examples.env.look_and_push import LookAndPush, OneHot from ray.rllib.examples.env.repeat_after_me_env import RepeatAfterMeEnv from ray.rllib.examples.env.repeat_initial_obs_env import RepeatInitialObsEnv @@ -27,8 +28,6 @@ parser.add_argument("--stop-reward", type=float, default=80) if __name__ == "__main__": args = parser.parse_args() - assert not args.torch, "PyTorch not supported for AttentionNets yet!" 
- ray.init(num_cpus=args.num_cpus or None) registry.register_env("RepeatAfterMeEnv", lambda c: RepeatAfterMeEnv(c)) @@ -52,7 +51,7 @@ if __name__ == "__main__": "num_sgd_iter": 10, "vf_loss_coeff": 1e-5, "model": { - "custom_model": GTrXLNet, + "custom_model": TorchGTrXLNet if args.torch else GTrXLNet, "max_seq_len": 50, "custom_model_config": { "num_transformer_units": 1, diff --git a/rllib/examples/custom_metrics_and_callbacks.py b/rllib/examples/custom_metrics_and_callbacks.py index d7a2c849d..745a94029 100644 --- a/rllib/examples/custom_metrics_and_callbacks.py +++ b/rllib/examples/custom_metrics_and_callbacks.py @@ -28,6 +28,7 @@ class MyCallbacks(DefaultCallbacks): episode: MultiAgentEpisode, env_index: int, **kwargs): print("episode {} (env-idx={}) started.".format( episode.episode_id, env_index)) + episode.user_data["pole_angles"] = [] episode.hist_data["pole_angles"] = [] diff --git a/rllib/examples/env/debug_counter_env.py b/rllib/examples/env/debug_counter_env.py index c14d49951..aa3a9b3b7 100644 --- a/rllib/examples/env/debug_counter_env.py +++ b/rllib/examples/env/debug_counter_env.py @@ -12,18 +12,25 @@ class DebugCounterEnv(gym.Env): Reward is always: current ts % 3. 
""" - def __init__(self): + def __init__(self, config=None): + config = config or {} self.action_space = gym.spaces.Discrete(2) - self.observation_space = gym.spaces.Box(0, 100, (1, )) - self.i = 0 + self.observation_space = \ + gym.spaces.Box(0, 100, (1, ), dtype=np.float32) + self.start_at_t = int(config.get("start_at_t", 0)) + self.i = self.start_at_t def reset(self): - self.i = 0 - return [self.i] + self.i = self.start_at_t + return self._get_obs() def step(self, action): self.i += 1 - return [self.i], self.i % 3, self.i >= 15, {} + return self._get_obs(), float(self.i % 3), \ + self.i >= 15 + self.start_at_t, {} + + def _get_obs(self): + return np.array([self.i], dtype=np.float32) class MultiAgentDebugCounterEnv(MultiAgentEnv): diff --git a/rllib/examples/models/centralized_critic_models.py b/rllib/examples/models/centralized_critic_models.py index 276f42381..23f1e8b92 100644 --- a/rllib/examples/models/centralized_critic_models.py +++ b/rllib/examples/models/centralized_critic_models.py @@ -45,9 +45,10 @@ class CentralizedCriticModel(TFModelV2): def central_value_function(self, obs, opponent_obs, opponent_actions): return tf.reshape( - self.central_vf( - [obs, opponent_obs, - tf.one_hot(opponent_actions, 2)]), [-1]) + self.central_vf([ + obs, opponent_obs, + tf.one_hot(tf.cast(opponent_actions, tf.int32), 2) + ]), [-1]) @override(ModelV2) def value_function(self): @@ -124,7 +125,7 @@ class TorchCentralizedCriticModel(TorchModelV2, nn.Module): def central_value_function(self, obs, opponent_obs, opponent_actions): input_ = torch.cat([ obs, opponent_obs, - torch.nn.functional.one_hot(opponent_actions, 2).float() + torch.nn.functional.one_hot(opponent_actions.long(), 2).float() ], 1) return torch.reshape(self.central_vf(input_), [-1]) diff --git a/rllib/models/torch/attention_net.py b/rllib/models/torch/attention_net.py index 58480a64b..27d2d494e 100644 --- a/rllib/models/torch/attention_net.py +++ b/rllib/models/torch/attention_net.py @@ -10,12 +10,15 @@ """ 
import numpy as np import gym +from gym.spaces import Box from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.torch.misc import SlimFC from ray.rllib.models.torch.modules import GRUGate, \ RelativeMultiHeadAttention, SkipConnection from ray.rllib.models.torch.recurrent_net import RecurrentNetwork +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.view_requirement import ViewRequirement from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.typing import ModelConfigDict, TensorType, List @@ -23,26 +26,6 @@ from ray.rllib.utils.typing import ModelConfigDict, TensorType, List torch, nn = try_import_torch() -def relative_position_embedding(seq_length: int, out_dim: int) -> TensorType: - """Creates a [seq_length x seq_length] matrix for rel. pos encoding. - - Denoted as Phi in [2] and [3]. Phi is the standard sinusoid encoding - matrix. - - Args: - seq_length (int): The max. sequence length (time axis). - out_dim (int): The number of nodes to go into the first Tranformer - layer with. - - Returns: - torch.Tensor: The encoding matrix Phi. - """ - inverse_freq = 1 / (10000**(torch.arange(0, out_dim, 2.0) / out_dim)) - pos_offsets = torch.arange(seq_length - 1, -1, -1) - inputs = pos_offsets[:, None] * inverse_freq[None, :] - return torch.cat((torch.sin(inputs), torch.cos(inputs)), dim=-1) - - class GTrXLNet(RecurrentNetwork, nn.Module): """A GTrXL net Model described in [2]. @@ -74,7 +57,8 @@ class GTrXLNet(RecurrentNetwork, nn.Module): num_transformer_units: int, attn_dim: int, num_heads: int, - memory_tau: int, + memory_inference: int, + memory_training: int, head_dim: int, ff_hidden_dim: int, init_gate_bias: float = 2.0): @@ -87,9 +71,15 @@ class GTrXLNet(RecurrentNetwork, nn.Module): unit. num_heads (int): The number of attention heads to use in parallel. Denoted as `H` in [3]. 
- memory_tau (int): The number of timesteps to store in each - transformer block's memory M (concat'd over time and fed into - next transformer block as input). + memory_inference (int): The number of timesteps to concat (time + axis) and feed into the next transformer unit as inference + input. The first transformer unit will receive this number of + past observations (plus the current one), instead. + memory_training (int): The number of timesteps to concat (time + axis) and feed into the next transformer unit as training + input (plus the actual input sequence of len=max_seq_len). + The first transformer unit will receive this number of + past observations (plus the input sequence), instead. head_dim (int): The dimension of a single(!) head. Denoted as `d` in [3]. ff_hidden_dim (int): The dimension of the hidden layer within @@ -110,20 +100,18 @@ class GTrXLNet(RecurrentNetwork, nn.Module): self.num_transformer_units = num_transformer_units self.attn_dim = attn_dim self.num_heads = num_heads - self.memory_tau = memory_tau + self.memory_inference = memory_inference + self.memory_training = memory_training self.head_dim = head_dim self.max_seq_len = model_config["max_seq_len"] self.obs_dim = observation_space.shape[0] - # Constant (non-trainable) sinusoid rel pos encoding matrix. - Phi = relative_position_embedding(self.max_seq_len + self.memory_tau, - self.attn_dim) - self.linear_layer = SlimFC( in_size=self.obs_dim, out_size=self.attn_dim) self.layers = [self.linear_layer] + attention_layers = [] # 2) Create L Transformer blocks according to [2]. for i in range(self.num_transformer_units): # RelativeMultiHeadAttention part. 
@@ -133,7 +121,6 @@ class GTrXLNet(RecurrentNetwork, nn.Module): out_dim=self.attn_dim, num_heads=num_heads, head_dim=head_dim, - rel_pos_encoder=Phi, input_layernorm=True, output_activation=nn.ReLU), fan_in_layer=GRUGate(self.attn_dim, init_gate_bias)) @@ -154,8 +141,13 @@ class GTrXLNet(RecurrentNetwork, nn.Module): activation_fn=nn.ReLU)), fan_in_layer=GRUGate(self.attn_dim, init_gate_bias)) - # Build a list of all layers in order. - self.layers.extend([MHA_layer, E_layer]) + # Build a list of all attention layers in order. + attention_layers.extend([MHA_layer, E_layer]) + + # Create a Sequential such that all parameters inside the attention + # layers are automatically registered with this top-level model. + self.attention_layers = nn.Sequential(*attention_layers) + self.layers.extend(attention_layers) # Postprocess GTrXL output with another hidden layer. self.logits = SlimFC( @@ -168,62 +160,64 @@ class GTrXLNet(RecurrentNetwork, nn.Module): self.values_out = SlimFC( in_size=self.attn_dim, out_size=1, activation_fn=None) - @override(RecurrentNetwork) - def forward_rnn(self, inputs: TensorType, state: List[TensorType], - seq_lens: TensorType) -> (TensorType, List[TensorType]): - # To make Attention work with current RLlib's ModelV2 API: - # We assume `state` is the history of L recent observations (all - # concatenated into one tensor) and append the current inputs to the - # end and only keep the most recent (up to `max_seq_len`). This allows - # us to deal with timestep-wise inference and full sequence training - # within the same logic. - state = [torch.from_numpy(item) for item in state] - observations = state[0] - memory = state[1:] + # Setup inference view (`memory-inference` x past observations + + # current one (0)) + # 1 to `num_transformer_units`: Memory data (one per transformer unit). 
+ for i in range(self.num_transformer_units): + space = Box(-1.0, 1.0, shape=(self.attn_dim, )) + self.inference_view_requirements["state_in_{}".format(i)] = \ + ViewRequirement( + "state_out_{}".format(i), + shift="-{}:-1".format(self.memory_inference), + # Repeat the incoming state every max-seq-len times. + batch_repeat_value=self.max_seq_len, + space=space) + self.inference_view_requirements["state_out_{}".format(i)] = \ + ViewRequirement( + space=space, + used_for_training=False) - inputs = torch.reshape(inputs, [1, -1, observations.shape[-1]]) - observations = torch.cat( - (observations, inputs), axis=1)[:, -self.max_seq_len:] + @override(ModelV2) + def forward(self, input_dict, state: List[TensorType], + seq_lens: TensorType) -> (TensorType, List[TensorType]): + assert seq_lens is not None + + # Add the needed batch rank (tf Models' Input requires this). + observations = input_dict[SampleBatch.OBS] + # Add the time dim to observations. + B = len(seq_lens) + T = observations.shape[0] // B + observations = torch.reshape(observations, + [-1, T] + list(observations.shape[1:])) all_out = observations + memory_outs = [] for i in range(len(self.layers)): # MHA layers which need memory passed in. if i % 2 == 1: - all_out = self.layers[i](all_out, memory=memory[i // 2]) - # Either linear layers or MultiLayerPerceptrons. + all_out = self.layers[i](all_out, memory=state[i // 2]) + # Either self.linear_layer (initial obs -> attn. dim layer) or + # MultiLayerPerceptrons. The output of these layers is always the + # memory for the next forward pass. else: all_out = self.layers[i](all_out) + memory_outs.append(all_out) + + # Discard last output (not needed as a memory since it's the last + # layer). + memory_outs = memory_outs[:-1] logits = self.logits(all_out) self._value_out = self.values_out(all_out) - memory_outs = all_out[2:] - # If memory_tau > max_seq_len -> overlap w/ previous `memory` input. 
- if self.memory_tau > self.max_seq_len: - memory_outs = [ - torch.cat( - [memory[i][:, -(self.memory_tau - self.max_seq_len):], m], - axis=1) for i, m in enumerate(memory_outs) - ] - else: - memory_outs = [m[:, -self.memory_tau:] for m in memory_outs] - - T = list(inputs.size())[1] # Length of input segment (time). - - # Postprocessing final output. - logits = logits[:, -T:] - self._value_out = self._value_out[:, -T:] - - return logits, [observations] + memory_outs + return torch.reshape(logits, [-1, self.num_outputs]), [ + torch.reshape(m, [-1, self.attn_dim]) for m in memory_outs + ] + # TODO: (sven) Deprecate this once trajectory view API has fully matured. @override(RecurrentNetwork) def get_initial_state(self) -> List[np.ndarray]: - # State is the T last observations concat'd together into one Tensor. - # Plus all Transformer blocks' E(l) outputs concat'd together (up to - # tau timesteps). - return [np.zeros((self.max_seq_len, self.obs_dim), np.float32)] + \ - [np.zeros((self.memory_tau, self.attn_dim), np.float32) - for _ in range(self.num_transformer_units)] + return [] @override(ModelV2) def value_function(self) -> TensorType: diff --git a/rllib/models/torch/modules/gru_gate.py b/rllib/models/torch/modules/gru_gate.py index 4cabc5eb2..724c41464 100644 --- a/rllib/models/torch/modules/gru_gate.py +++ b/rllib/models/torch/modules/gru_gate.py @@ -13,26 +13,29 @@ class GRUGate(nn.Module): init_bias (int): Bias added to every input to stabilize training """ super().__init__(**kwargs) - self._init_bias = init_bias - # Xavier initialization of torch tensors - self._w_r = torch.zeros(dim, dim) - self._w_z = torch.zeros(dim, dim) - self._w_h = torch.zeros(dim, dim) - - self._u_r = torch.zeros(dim, dim) - self._u_z = torch.zeros(dim, dim) - self._u_h = torch.zeros(dim, dim) - + self._w_r = nn.Parameter(torch.zeros(dim, dim)) + self._w_z = nn.Parameter(torch.zeros(dim, dim)) + self._w_h = nn.Parameter(torch.zeros(dim, dim)) nn.init.xavier_uniform_(self._w_r) 
nn.init.xavier_uniform_(self._w_z) nn.init.xavier_uniform_(self._w_h) + self.register_parameter("_w_r", self._w_r) + self.register_parameter("_w_z", self._w_z) + self.register_parameter("_w_h", self._w_h) + self._u_r = nn.Parameter(torch.zeros(dim, dim)) + self._u_z = nn.Parameter(torch.zeros(dim, dim)) + self._u_h = nn.Parameter(torch.zeros(dim, dim)) nn.init.xavier_uniform_(self._u_r) nn.init.xavier_uniform_(self._u_z) nn.init.xavier_uniform_(self._u_h) + self.register_parameter("_u_r", self._u_r) + self.register_parameter("_u_z", self._u_z) + self.register_parameter("_u_h", self._u_h) - self._bias_z = torch.zeros(dim, ).fill_(self._init_bias) + self._bias_z = nn.Parameter(torch.zeros(dim, ).fill_(init_bias)) + self.register_parameter("_bias_z", self._bias_z) def forward(self, inputs: TensorType, **kwargs) -> TensorType: # Pass in internal state first. diff --git a/rllib/models/torch/modules/relative_multi_head_attention.py b/rllib/models/torch/modules/relative_multi_head_attention.py index fe28d6f73..3efa9c664 100644 --- a/rllib/models/torch/modules/relative_multi_head_attention.py +++ b/rllib/models/torch/modules/relative_multi_head_attention.py @@ -1,11 +1,47 @@ +from typing import Union + from ray.rllib.utils.framework import try_import_torch from ray.rllib.models.torch.misc import SlimFC from ray.rllib.utils.torch_ops import sequence_mask -from ray.rllib.utils.typing import TensorType, Any +from ray.rllib.utils.typing import TensorType torch, nn = try_import_torch() +class RelativePositionEmbedding(nn.Module): + """Creates a [seq_length x seq_length] matrix for rel. pos encoding. + + Denoted as Phi in [2] and [3]. Phi is the standard sinusoid encoding + matrix. + + Args: + seq_length (int): The max. sequence length (time axis). + out_dim (int): The number of nodes to go into the first Transformer + layer with. + + Returns: + torch.Tensor: The encoding matrix Phi. 
+ """ + + def __init__(self, out_dim, **kwargs): + super().__init__() + self.out_dim = out_dim + + out_range = torch.arange(0, self.out_dim, 2.0) + inverse_freq = 1 / (10000**(out_range / self.out_dim)) + self.register_buffer("inverse_freq", inverse_freq) + + def forward(self, seq_length): + pos_input = torch.arange( + seq_length - 1, -1, -1.0, + dtype=torch.float).to(self.inverse_freq.device) + sinusoid_input = torch.einsum("i,j->ij", pos_input, self.inverse_freq) + pos_embeddings = torch.cat( + [torch.sin(sinusoid_input), + torch.cos(sinusoid_input)], dim=-1) + return pos_embeddings[:, None, :] + + class RelativeMultiHeadAttention(nn.Module): """A RelativeMultiHeadAttention layer as described in [3]. @@ -17,24 +53,24 @@ class RelativeMultiHeadAttention(nn.Module): out_dim: int, num_heads: int, head_dim: int, - rel_pos_encoder: Any, input_layernorm: bool = False, - output_activation: Any = None, + output_activation: Union[str, callable] = None, **kwargs): """Initializes a RelativeMultiHeadAttention nn.Module object. Args: in_dim (int): - out_dim (int): + out_dim (int): The output dimension of this module. Also known as + "attention dim". num_heads (int): The number of attention heads to use. Denoted `H` in [2]. head_dim (int): The dimension of a single(!) attention head Denoted `D` in [2]. - rel_pos_encoder (: input_layernorm (bool): Whether to prepend a LayerNorm before everything else. Should be True for building a GTrXL. - output_activation (Optional[tf.nn.activation]): Optional tf.nn - activation function. Should be relu for GTrXL. + output_activation (Union[str, callable]): Optional activation + function or activation function specifier (str). + Should be "relu" for GTrXL. 
**kwargs: """ super().__init__(**kwargs) @@ -53,17 +89,18 @@ class RelativeMultiHeadAttention(nn.Module): use_bias=False, activation_fn=output_activation) - self._pos_proj = SlimFC( - in_size=in_dim, out_size=num_heads * head_dim, use_bias=False) - - self._uvar = torch.zeros(num_heads, head_dim) - self._vvar = torch.zeros(num_heads, head_dim) + self._uvar = nn.Parameter(torch.zeros(num_heads, head_dim)) + self._vvar = nn.Parameter(torch.zeros(num_heads, head_dim)) nn.init.xavier_uniform_(self._uvar) nn.init.xavier_uniform_(self._vvar) + self.register_parameter("_uvar", self._uvar) + self.register_parameter("_vvar", self._vvar) + + self._pos_proj = SlimFC( + in_size=in_dim, out_size=num_heads * head_dim, use_bias=False) + self._rel_pos_embedding = RelativePositionEmbedding(out_dim) - self._rel_pos_encoder = rel_pos_encoder self._input_layernorm = None - if input_layernorm: self._input_layernorm = torch.nn.LayerNorm(in_dim) @@ -75,10 +112,8 @@ class RelativeMultiHeadAttention(nn.Module): # Add previous memory chunk (as const, w/o gradient) to input. # Tau (number of (prev) time slices in each memory chunk). - Tau = list(memory.shape)[1] if memory is not None else 0 - if memory is not None: - memory.requires_grad_(False) - inputs = torch.cat((memory, inputs), dim=1) + Tau = list(memory.shape)[1] + inputs = torch.cat((memory.detach(), inputs), dim=1) # Apply the Layer-Norm. 
if self._input_layernorm is not None: @@ -91,11 +126,11 @@ class RelativeMultiHeadAttention(nn.Module): queries = queries[:, -T:] queries = torch.reshape(queries, [-1, T, H, d]) - keys = torch.reshape(keys, [-1, T + Tau, H, d]) - values = torch.reshape(values, [-1, T + Tau, H, d]) + keys = torch.reshape(keys, [-1, Tau + T, H, d]) + values = torch.reshape(values, [-1, Tau + T, H, d]) - R = self._pos_proj(self._rel_pos_encoder) - R = torch.reshape(R, [T + Tau, H, d]) + R = self._pos_proj(self._rel_pos_embedding(Tau + T)) + R = torch.reshape(R, [Tau + T, H, d]) # b=batch # i and j=time indices (i=max-timesteps (inputs); j=Tau memory space) @@ -108,10 +143,11 @@ class RelativeMultiHeadAttention(nn.Module): # causal mask of the same length as the sequence mask = sequence_mask( - torch.arange(Tau + 1, T + Tau + 1), dtype=score.dtype) + torch.arange(Tau + 1, Tau + T + 1), + dtype=score.dtype).to(score.device) mask = mask[None, :, :, None] - masked_score = score * mask + 1e30 * (mask.to(torch.float32) - 1.) + masked_score = score * mask + 1e30 * (mask.float() - 1.) wmat = nn.functional.softmax(masked_score, dim=2) out = torch.einsum("bijh,bjhd->bihd", wmat, values) diff --git a/rllib/policy/eager_tf_policy.py b/rllib/policy/eager_tf_policy.py index f17d60e06..af4fa512c 100644 --- a/rllib/policy/eager_tf_policy.py +++ b/rllib/policy/eager_tf_policy.py @@ -259,10 +259,8 @@ def build_eager_tf_policy(name, self._update_model_inference_view_requirements_from_init_state() self.exploration = self._create_exploration() - self._state_in = [ - tf.convert_to_tensor([s]) - for s in self.model.get_initial_state() - ] + self._state_inputs = self.model.get_initial_state() + self._is_recurrent = len(self._state_inputs) > 0 # Combine view_requirements for Model and Policy. self.view_requirements.update( @@ -375,6 +373,8 @@ def build_eager_tf_policy(name, # TODO: remove python side effect to cull sources of bugs. 
self._is_training = False + self._is_recurrent = \ + state_batches is not None and state_batches != [] self._state_in = state_batches or [] if not tf1.executing_eagerly(): @@ -552,11 +552,11 @@ def build_eager_tf_policy(name, @override(Policy) def is_recurrent(self): - return len(self._state_in) > 0 + return self._is_recurrent @override(Policy) def num_state_tensors(self): - return len(self._state_in) + return len(self._state_inputs) @override(Policy) def get_initial_state(self): diff --git a/rllib/policy/tf_policy_template.py b/rllib/policy/tf_policy_template.py index a4f5e12b2..34e7da360 100644 --- a/rllib/policy/tf_policy_template.py +++ b/rllib/policy/tf_policy_template.py @@ -170,9 +170,6 @@ def build_tf_policy( mixins (Optional[List[type]]): Optional list of any class mixins for the returned policy class. These mixins will be applied in order and will have higher precedence than the DynamicTFPolicy class. - view_requirements_fn (Callable[[Policy], - Dict[str, ViewRequirement]]): An optional callable to retrieve - additional train view requirements for this policy. get_batch_divisibility_req (Optional[Callable[[Policy], int]]): Optional callable that returns the divisibility requirement for sample batches. If None, will assume a value of 1. diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index c27a7603d..10e875d50 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -110,6 +110,8 @@ class TorchPolicy(Policy): logger.info("TorchPolicy running on CPU.") self.device = torch.device("cpu") self.model = model.to(self.device) + self._state_inputs = self.model.get_initial_state() + self._is_recurrent = len(self._state_inputs) > 0 # Auto-update model's inference view requirements, if recurrent. self._update_model_inference_view_requirements_from_init_state() # Combine view_requirements for Model and Policy. @@ -203,6 +205,11 @@ class TorchPolicy(Policy): Tuple: - actions, state_out, extra_fetches, logp. 
""" + self._is_recurrent = state_batches is not None and state_batches != [] + # Switch to eval mode. + if self.model: + self.model.eval() + if self.action_sampler_fn: action_dist = dist_inputs = None state_out = state_batches @@ -325,6 +332,9 @@ class TorchPolicy(Policy): @DeveloperAPI def learn_on_batch( self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]: + # Set Model to train mode. + if self.model: + self.model.train() # Callback handling. self.callbacks.on_learn_on_batch( policy=self, train_batch=postprocessed_batch) @@ -354,6 +364,9 @@ class TorchPolicy(Policy): view_requirements=self.view_requirements, ) + # Mark the batch as "is_training" so the Model can use this + # information. + postprocessed_batch["is_training"] = True train_batch = self._lazy_tensor_dict(postprocessed_batch) # Calculate the actual policy loss. @@ -448,7 +461,7 @@ class TorchPolicy(Policy): @override(Policy) @DeveloperAPI def is_recurrent(self) -> bool: - return len(self.model.get_initial_state()) > 0 + return self._is_recurrent @override(Policy) @DeveloperAPI