mirror of
https://github.com/wassname/ray.git
synced 2026-07-06 03:25:12 +08:00
[RLlib] Attention nets PyTorch support and cleanup (using traj. view API). (#12029)
This commit is contained in:
+9
-10
@@ -1480,20 +1480,19 @@ py_test(
|
||||
name = "examples/attention_net_tf",
|
||||
main = "examples/attention_net.py",
|
||||
tags = ["examples", "examples_A"],
|
||||
size = "large",
|
||||
size = "medium",
|
||||
srcs = ["examples/attention_net.py"],
|
||||
args = ["--as-test", "--stop-reward=80"]
|
||||
)
|
||||
|
||||
# TODO(sven): GTrXL PyTorch.
|
||||
# py_test(
|
||||
# name = "examples/attention_net_torch",
|
||||
# main = "examples/attention_net.py",
|
||||
# tags = ["examples", "examples_A"],
|
||||
# size = "large",
|
||||
# srcs = ["examples/attention_net.py"],
|
||||
# args = ["--as-test", "--torch", "--stop-reward=90"]
|
||||
# )
|
||||
py_test(
|
||||
name = "examples/attention_net_torch",
|
||||
main = "examples/attention_net.py",
|
||||
tags = ["examples", "examples_A"],
|
||||
size = "medium",
|
||||
srcs = ["examples/attention_net.py"],
|
||||
args = ["--as-test", "--stop-reward=80", "--torch"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "examples/autoregressive_action_dist_tf",
|
||||
|
||||
@@ -49,7 +49,8 @@ def ppo_surrogate_loss(
|
||||
|
||||
# RNN case: Mask away 0-padded chunks at end of time axis.
|
||||
if state:
|
||||
max_seq_len = torch.max(train_batch["seq_lens"])
|
||||
B = len(train_batch["seq_lens"])
|
||||
max_seq_len = logits.shape[0] // B
|
||||
mask = sequence_mask(
|
||||
train_batch["seq_lens"],
|
||||
max_seq_len,
|
||||
|
||||
@@ -35,9 +35,6 @@ def to_float_np_array(v: List[Any]) -> np.ndarray:
|
||||
return arr
|
||||
|
||||
|
||||
_INIT_COLS = [SampleBatch.OBS]
|
||||
|
||||
|
||||
class _AgentCollector:
|
||||
"""Collects samples for one agent in one trajectory (episode).
|
||||
|
||||
@@ -55,8 +52,9 @@ class _AgentCollector:
|
||||
# or internal state inputs.
|
||||
self.shift_before = -min(
|
||||
(int(vr.shift.split(":")[0])
|
||||
if isinstance(vr.shift, str) else vr.shift) +
|
||||
(-1 if vr.data_col in _INIT_COLS or k in _INIT_COLS else 0)
|
||||
if isinstance(vr.shift, str) else vr.shift) -
|
||||
(1
|
||||
if vr.data_col == SampleBatch.OBS or k == SampleBatch.OBS else 0)
|
||||
for k, vr in view_reqs.items())
|
||||
# The actual data buffers (lists holding each timestep's data).
|
||||
self.buffers: Dict[str, List] = {}
|
||||
|
||||
@@ -7,19 +7,41 @@ import unittest
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib.agents.callbacks import DefaultCallbacks
|
||||
import ray.rllib.agents.dqn as dqn
|
||||
import ray.rllib.agents.ppo as ppo
|
||||
from ray.rllib.examples.env.debug_counter_env import MultiAgentDebugCounterEnv
|
||||
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
|
||||
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
||||
from ray.rllib.examples.policy.episode_env_aware_policy import \
|
||||
EpisodeEnvAwareLSTMPolicy
|
||||
EpisodeEnvAwareAttentionPolicy, EpisodeEnvAwareLSTMPolicy
|
||||
from ray.rllib.models.tf.attention_net import GTrXLNet
|
||||
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.view_requirement import ViewRequirement
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.test_utils import framework_iterator, check
|
||||
|
||||
|
||||
class MyCallbacks(DefaultCallbacks):
|
||||
@override(DefaultCallbacks)
|
||||
def on_learn_on_batch(self, *, policy, train_batch, **kwargs):
|
||||
assert train_batch.count == 201
|
||||
assert sum(train_batch.seq_lens) == 201
|
||||
for k, v in train_batch.data.items():
|
||||
if k == "state_in_0":
|
||||
assert len(v) == len(train_batch.seq_lens)
|
||||
else:
|
||||
assert len(v) == 201
|
||||
current = None
|
||||
for o in train_batch[SampleBatch.OBS]:
|
||||
if current:
|
||||
assert o == current + 1
|
||||
current = o
|
||||
if o == 15:
|
||||
current = None
|
||||
|
||||
|
||||
class TestTrajectoryViewAPI(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
@@ -116,6 +138,45 @@ class TestTrajectoryViewAPI(unittest.TestCase):
|
||||
assert view_req_policy[key].shift == 1
|
||||
trainer.stop()
|
||||
|
||||
def test_traj_view_attention_net(self):
|
||||
config = ppo.DEFAULT_CONFIG.copy()
|
||||
# Setup attention net.
|
||||
config["model"] = config["model"].copy()
|
||||
config["model"]["max_seq_len"] = 50
|
||||
config["model"]["custom_model"] = GTrXLNet
|
||||
config["model"]["custom_model_config"] = {
|
||||
"num_transformer_units": 1,
|
||||
"attn_dim": 64,
|
||||
"num_heads": 2,
|
||||
"memory_inference": 50,
|
||||
"memory_training": 50,
|
||||
"head_dim": 32,
|
||||
"ff_hidden_dim": 32,
|
||||
}
|
||||
# Test with odd batch numbers.
|
||||
config["train_batch_size"] = 1031
|
||||
config["sgd_minibatch_size"] = 201
|
||||
config["num_sgd_iter"] = 5
|
||||
config["num_workers"] = 0
|
||||
config["callbacks"] = MyCallbacks
|
||||
config["env_config"] = {
|
||||
"config": {
|
||||
"start_at_t": 1
|
||||
}
|
||||
} # first obs is [1.0]
|
||||
|
||||
for _ in framework_iterator(config, frameworks="tf2"):
|
||||
trainer = ppo.PPOTrainer(
|
||||
config,
|
||||
env="ray.rllib.examples.env.debug_counter_env.DebugCounterEnv",
|
||||
)
|
||||
rw = trainer.workers.local_worker()
|
||||
sample = rw.sample()
|
||||
assert sample.count == config["rollout_fragment_length"]
|
||||
results = trainer.train()
|
||||
assert results["train_batch_size"] == config["train_batch_size"]
|
||||
trainer.stop()
|
||||
|
||||
def test_traj_view_simple_performance(self):
|
||||
"""Test whether PPOTrainer runs faster w/ `_use_trajectory_view_api`.
|
||||
"""
|
||||
@@ -298,6 +359,40 @@ class TestTrajectoryViewAPI(unittest.TestCase):
|
||||
pol_batch_wo = result.policy_batches["pol0"]
|
||||
check(pol_batch_w.data, pol_batch_wo.data)
|
||||
|
||||
def test_traj_view_attention_functionality(self):
|
||||
action_space = Box(-float("inf"), float("inf"), shape=(3, ))
|
||||
obs_space = Box(float("-inf"), float("inf"), (4, ))
|
||||
max_seq_len = 50
|
||||
rollout_fragment_length = 201
|
||||
policies = {
|
||||
"pol0": (EpisodeEnvAwareAttentionPolicy, obs_space, action_space,
|
||||
{}),
|
||||
}
|
||||
|
||||
def policy_fn(agent_id):
|
||||
return "pol0"
|
||||
|
||||
config = {
|
||||
"multiagent": {
|
||||
"policies": policies,
|
||||
"policy_mapping_fn": policy_fn,
|
||||
},
|
||||
"model": {
|
||||
"max_seq_len": max_seq_len,
|
||||
},
|
||||
},
|
||||
|
||||
rollout_worker_w_api = RolloutWorker(
|
||||
env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}),
|
||||
policy_config=dict(config, **{"_use_trajectory_view_api": True}),
|
||||
rollout_fragment_length=rollout_fragment_length,
|
||||
policy_spec=policies,
|
||||
policy_mapping_fn=policy_fn,
|
||||
num_envs=1,
|
||||
)
|
||||
batch = rollout_worker_w_api.sample()
|
||||
print(batch)
|
||||
|
||||
def test_counting_by_agent_steps(self):
|
||||
"""Test whether a PPOTrainer can be built with all frameworks."""
|
||||
config = copy.deepcopy(ppo.DEFAULT_CONFIG)
|
||||
|
||||
@@ -4,6 +4,7 @@ import os
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib.models.tf.attention_net import GTrXLNet
|
||||
from ray.rllib.models.torch.attention_net import GTrXLNet as TorchGTrXLNet
|
||||
from ray.rllib.examples.env.look_and_push import LookAndPush, OneHot
|
||||
from ray.rllib.examples.env.repeat_after_me_env import RepeatAfterMeEnv
|
||||
from ray.rllib.examples.env.repeat_initial_obs_env import RepeatInitialObsEnv
|
||||
@@ -27,8 +28,6 @@ parser.add_argument("--stop-reward", type=float, default=80)
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
assert not args.torch, "PyTorch not supported for AttentionNets yet!"
|
||||
|
||||
ray.init(num_cpus=args.num_cpus or None)
|
||||
|
||||
registry.register_env("RepeatAfterMeEnv", lambda c: RepeatAfterMeEnv(c))
|
||||
@@ -52,7 +51,7 @@ if __name__ == "__main__":
|
||||
"num_sgd_iter": 10,
|
||||
"vf_loss_coeff": 1e-5,
|
||||
"model": {
|
||||
"custom_model": GTrXLNet,
|
||||
"custom_model": TorchGTrXLNet if args.torch else GTrXLNet,
|
||||
"max_seq_len": 50,
|
||||
"custom_model_config": {
|
||||
"num_transformer_units": 1,
|
||||
|
||||
@@ -28,6 +28,7 @@ class MyCallbacks(DefaultCallbacks):
|
||||
episode: MultiAgentEpisode, env_index: int, **kwargs):
|
||||
print("episode {} (env-idx={}) started.".format(
|
||||
episode.episode_id, env_index))
|
||||
|
||||
episode.user_data["pole_angles"] = []
|
||||
episode.hist_data["pole_angles"] = []
|
||||
|
||||
|
||||
+13
-6
@@ -12,18 +12,25 @@ class DebugCounterEnv(gym.Env):
|
||||
Reward is always: current ts % 3.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, config=None):
|
||||
config = config or {}
|
||||
self.action_space = gym.spaces.Discrete(2)
|
||||
self.observation_space = gym.spaces.Box(0, 100, (1, ))
|
||||
self.i = 0
|
||||
self.observation_space = \
|
||||
gym.spaces.Box(0, 100, (1, ), dtype=np.float32)
|
||||
self.start_at_t = int(config.get("start_at_t", 0))
|
||||
self.i = self.start_at_t
|
||||
|
||||
def reset(self):
|
||||
self.i = 0
|
||||
return [self.i]
|
||||
self.i = self.start_at_t
|
||||
return self._get_obs()
|
||||
|
||||
def step(self, action):
|
||||
self.i += 1
|
||||
return [self.i], self.i % 3, self.i >= 15, {}
|
||||
return self._get_obs(), float(self.i % 3), \
|
||||
self.i >= 15 + self.start_at_t, {}
|
||||
|
||||
def _get_obs(self):
|
||||
return np.array([self.i], dtype=np.float32)
|
||||
|
||||
|
||||
class MultiAgentDebugCounterEnv(MultiAgentEnv):
|
||||
|
||||
@@ -45,9 +45,10 @@ class CentralizedCriticModel(TFModelV2):
|
||||
|
||||
def central_value_function(self, obs, opponent_obs, opponent_actions):
|
||||
return tf.reshape(
|
||||
self.central_vf(
|
||||
[obs, opponent_obs,
|
||||
tf.one_hot(opponent_actions, 2)]), [-1])
|
||||
self.central_vf([
|
||||
obs, opponent_obs,
|
||||
tf.one_hot(tf.cast(opponent_actions, tf.int32), 2)
|
||||
]), [-1])
|
||||
|
||||
@override(ModelV2)
|
||||
def value_function(self):
|
||||
@@ -124,7 +125,7 @@ class TorchCentralizedCriticModel(TorchModelV2, nn.Module):
|
||||
def central_value_function(self, obs, opponent_obs, opponent_actions):
|
||||
input_ = torch.cat([
|
||||
obs, opponent_obs,
|
||||
torch.nn.functional.one_hot(opponent_actions, 2).float()
|
||||
torch.nn.functional.one_hot(opponent_actions.long(), 2).float()
|
||||
], 1)
|
||||
return torch.reshape(self.central_vf(input_), [-1])
|
||||
|
||||
|
||||
@@ -10,12 +10,15 @@
|
||||
"""
|
||||
import numpy as np
|
||||
import gym
|
||||
from gym.spaces import Box
|
||||
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.torch.misc import SlimFC
|
||||
from ray.rllib.models.torch.modules import GRUGate, \
|
||||
RelativeMultiHeadAttention, SkipConnection
|
||||
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.view_requirement import ViewRequirement
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.typing import ModelConfigDict, TensorType, List
|
||||
@@ -23,26 +26,6 @@ from ray.rllib.utils.typing import ModelConfigDict, TensorType, List
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
|
||||
def relative_position_embedding(seq_length: int, out_dim: int) -> TensorType:
|
||||
"""Creates a [seq_length x seq_length] matrix for rel. pos encoding.
|
||||
|
||||
Denoted as Phi in [2] and [3]. Phi is the standard sinusoid encoding
|
||||
matrix.
|
||||
|
||||
Args:
|
||||
seq_length (int): The max. sequence length (time axis).
|
||||
out_dim (int): The number of nodes to go into the first Tranformer
|
||||
layer with.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The encoding matrix Phi.
|
||||
"""
|
||||
inverse_freq = 1 / (10000**(torch.arange(0, out_dim, 2.0) / out_dim))
|
||||
pos_offsets = torch.arange(seq_length - 1, -1, -1)
|
||||
inputs = pos_offsets[:, None] * inverse_freq[None, :]
|
||||
return torch.cat((torch.sin(inputs), torch.cos(inputs)), dim=-1)
|
||||
|
||||
|
||||
class GTrXLNet(RecurrentNetwork, nn.Module):
|
||||
"""A GTrXL net Model described in [2].
|
||||
|
||||
@@ -74,7 +57,8 @@ class GTrXLNet(RecurrentNetwork, nn.Module):
|
||||
num_transformer_units: int,
|
||||
attn_dim: int,
|
||||
num_heads: int,
|
||||
memory_tau: int,
|
||||
memory_inference: int,
|
||||
memory_training: int,
|
||||
head_dim: int,
|
||||
ff_hidden_dim: int,
|
||||
init_gate_bias: float = 2.0):
|
||||
@@ -87,9 +71,15 @@ class GTrXLNet(RecurrentNetwork, nn.Module):
|
||||
unit.
|
||||
num_heads (int): The number of attention heads to use in parallel.
|
||||
Denoted as `H` in [3].
|
||||
memory_tau (int): The number of timesteps to store in each
|
||||
transformer block's memory M (concat'd over time and fed into
|
||||
next transformer block as input).
|
||||
memory_inference (int): The number of timesteps to concat (time
|
||||
axis) and feed into the next transformer unit as inference
|
||||
input. The first transformer unit will receive this number of
|
||||
past observations (plus the current one), instead.
|
||||
memory_training (int): The number of timesteps to concat (time
|
||||
axis) and feed into the next transformer unit as training
|
||||
input (plus the actual input sequence of len=max_seq_len).
|
||||
The first transformer unit will receive this number of
|
||||
past observations (plus the input sequence), instead.
|
||||
head_dim (int): The dimension of a single(!) head.
|
||||
Denoted as `d` in [3].
|
||||
ff_hidden_dim (int): The dimension of the hidden layer within
|
||||
@@ -110,20 +100,18 @@ class GTrXLNet(RecurrentNetwork, nn.Module):
|
||||
self.num_transformer_units = num_transformer_units
|
||||
self.attn_dim = attn_dim
|
||||
self.num_heads = num_heads
|
||||
self.memory_tau = memory_tau
|
||||
self.memory_inference = memory_inference
|
||||
self.memory_training = memory_training
|
||||
self.head_dim = head_dim
|
||||
self.max_seq_len = model_config["max_seq_len"]
|
||||
self.obs_dim = observation_space.shape[0]
|
||||
|
||||
# Constant (non-trainable) sinusoid rel pos encoding matrix.
|
||||
Phi = relative_position_embedding(self.max_seq_len + self.memory_tau,
|
||||
self.attn_dim)
|
||||
|
||||
self.linear_layer = SlimFC(
|
||||
in_size=self.obs_dim, out_size=self.attn_dim)
|
||||
|
||||
self.layers = [self.linear_layer]
|
||||
|
||||
attention_layers = []
|
||||
# 2) Create L Transformer blocks according to [2].
|
||||
for i in range(self.num_transformer_units):
|
||||
# RelativeMultiHeadAttention part.
|
||||
@@ -133,7 +121,6 @@ class GTrXLNet(RecurrentNetwork, nn.Module):
|
||||
out_dim=self.attn_dim,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
rel_pos_encoder=Phi,
|
||||
input_layernorm=True,
|
||||
output_activation=nn.ReLU),
|
||||
fan_in_layer=GRUGate(self.attn_dim, init_gate_bias))
|
||||
@@ -154,8 +141,13 @@ class GTrXLNet(RecurrentNetwork, nn.Module):
|
||||
activation_fn=nn.ReLU)),
|
||||
fan_in_layer=GRUGate(self.attn_dim, init_gate_bias))
|
||||
|
||||
# Build a list of all layers in order.
|
||||
self.layers.extend([MHA_layer, E_layer])
|
||||
# Build a list of all attanlayers in order.
|
||||
attention_layers.extend([MHA_layer, E_layer])
|
||||
|
||||
# Create a Sequential such that all parameters inside the attention
|
||||
# layers are automatically registered with this top-level model.
|
||||
self.attention_layers = nn.Sequential(*attention_layers)
|
||||
self.layers.extend(attention_layers)
|
||||
|
||||
# Postprocess GTrXL output with another hidden layer.
|
||||
self.logits = SlimFC(
|
||||
@@ -168,62 +160,64 @@ class GTrXLNet(RecurrentNetwork, nn.Module):
|
||||
self.values_out = SlimFC(
|
||||
in_size=self.attn_dim, out_size=1, activation_fn=None)
|
||||
|
||||
@override(RecurrentNetwork)
|
||||
def forward_rnn(self, inputs: TensorType, state: List[TensorType],
|
||||
seq_lens: TensorType) -> (TensorType, List[TensorType]):
|
||||
# To make Attention work with current RLlib's ModelV2 API:
|
||||
# We assume `state` is the history of L recent observations (all
|
||||
# concatenated into one tensor) and append the current inputs to the
|
||||
# end and only keep the most recent (up to `max_seq_len`). This allows
|
||||
# us to deal with timestep-wise inference and full sequence training
|
||||
# within the same logic.
|
||||
state = [torch.from_numpy(item) for item in state]
|
||||
observations = state[0]
|
||||
memory = state[1:]
|
||||
# Setup inference view (`memory-inference` x past observations +
|
||||
# current one (0))
|
||||
# 1 to `num_transformer_units`: Memory data (one per transformer unit).
|
||||
for i in range(self.num_transformer_units):
|
||||
space = Box(-1.0, 1.0, shape=(self.attn_dim, ))
|
||||
self.inference_view_requirements["state_in_{}".format(i)] = \
|
||||
ViewRequirement(
|
||||
"state_out_{}".format(i),
|
||||
shift="-{}:-1".format(self.memory_inference),
|
||||
# Repeat the incoming state every max-seq-len times.
|
||||
batch_repeat_value=self.max_seq_len,
|
||||
space=space)
|
||||
self.inference_view_requirements["state_out_{}".format(i)] = \
|
||||
ViewRequirement(
|
||||
space=space,
|
||||
used_for_training=False)
|
||||
|
||||
inputs = torch.reshape(inputs, [1, -1, observations.shape[-1]])
|
||||
observations = torch.cat(
|
||||
(observations, inputs), axis=1)[:, -self.max_seq_len:]
|
||||
@override(ModelV2)
|
||||
def forward(self, input_dict, state: List[TensorType],
|
||||
seq_lens: TensorType) -> (TensorType, List[TensorType]):
|
||||
assert seq_lens is not None
|
||||
|
||||
# Add the needed batch rank (tf Models' Input requires this).
|
||||
observations = input_dict[SampleBatch.OBS]
|
||||
# Add the time dim to observations.
|
||||
B = len(seq_lens)
|
||||
T = observations.shape[0] // B
|
||||
observations = torch.reshape(observations,
|
||||
[-1, T] + list(observations.shape[1:]))
|
||||
|
||||
all_out = observations
|
||||
memory_outs = []
|
||||
for i in range(len(self.layers)):
|
||||
# MHA layers which need memory passed in.
|
||||
if i % 2 == 1:
|
||||
all_out = self.layers[i](all_out, memory=memory[i // 2])
|
||||
# Either linear layers or MultiLayerPerceptrons.
|
||||
all_out = self.layers[i](all_out, memory=state[i // 2])
|
||||
# Either self.linear_layer (initial obs -> attn. dim layer) or
|
||||
# MultiLayerPerceptrons. The output of these layers is always the
|
||||
# memory for the next forward pass.
|
||||
else:
|
||||
all_out = self.layers[i](all_out)
|
||||
memory_outs.append(all_out)
|
||||
|
||||
# Discard last output (not needed as a memory since it's the last
|
||||
# layer).
|
||||
memory_outs = memory_outs[:-1]
|
||||
|
||||
logits = self.logits(all_out)
|
||||
self._value_out = self.values_out(all_out)
|
||||
|
||||
memory_outs = all_out[2:]
|
||||
# If memory_tau > max_seq_len -> overlap w/ previous `memory` input.
|
||||
if self.memory_tau > self.max_seq_len:
|
||||
memory_outs = [
|
||||
torch.cat(
|
||||
[memory[i][:, -(self.memory_tau - self.max_seq_len):], m],
|
||||
axis=1) for i, m in enumerate(memory_outs)
|
||||
]
|
||||
else:
|
||||
memory_outs = [m[:, -self.memory_tau:] for m in memory_outs]
|
||||
|
||||
T = list(inputs.size())[1] # Length of input segment (time).
|
||||
|
||||
# Postprocessing final output.
|
||||
logits = logits[:, -T:]
|
||||
self._value_out = self._value_out[:, -T:]
|
||||
|
||||
return logits, [observations] + memory_outs
|
||||
return torch.reshape(logits, [-1, self.num_outputs]), [
|
||||
torch.reshape(m, [-1, self.attn_dim]) for m in memory_outs
|
||||
]
|
||||
|
||||
# TODO: (sven) Deprecate this once trajectory view API has fully matured.
|
||||
@override(RecurrentNetwork)
|
||||
def get_initial_state(self) -> List[np.ndarray]:
|
||||
# State is the T last observations concat'd together into one Tensor.
|
||||
# Plus all Transformer blocks' E(l) outputs concat'd together (up to
|
||||
# tau timesteps).
|
||||
return [np.zeros((self.max_seq_len, self.obs_dim), np.float32)] + \
|
||||
[np.zeros((self.memory_tau, self.attn_dim), np.float32)
|
||||
for _ in range(self.num_transformer_units)]
|
||||
return []
|
||||
|
||||
@override(ModelV2)
|
||||
def value_function(self) -> TensorType:
|
||||
|
||||
@@ -13,26 +13,29 @@ class GRUGate(nn.Module):
|
||||
init_bias (int): Bias added to every input to stabilize training
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._init_bias = init_bias
|
||||
|
||||
# Xavier initialization of torch tensors
|
||||
self._w_r = torch.zeros(dim, dim)
|
||||
self._w_z = torch.zeros(dim, dim)
|
||||
self._w_h = torch.zeros(dim, dim)
|
||||
|
||||
self._u_r = torch.zeros(dim, dim)
|
||||
self._u_z = torch.zeros(dim, dim)
|
||||
self._u_h = torch.zeros(dim, dim)
|
||||
|
||||
self._w_r = nn.Parameter(torch.zeros(dim, dim))
|
||||
self._w_z = nn.Parameter(torch.zeros(dim, dim))
|
||||
self._w_h = nn.Parameter(torch.zeros(dim, dim))
|
||||
nn.init.xavier_uniform_(self._w_r)
|
||||
nn.init.xavier_uniform_(self._w_z)
|
||||
nn.init.xavier_uniform_(self._w_h)
|
||||
self.register_parameter("_w_r", self._w_r)
|
||||
self.register_parameter("_w_z", self._w_z)
|
||||
self.register_parameter("_w_h", self._w_h)
|
||||
|
||||
self._u_r = nn.Parameter(torch.zeros(dim, dim))
|
||||
self._u_z = nn.Parameter(torch.zeros(dim, dim))
|
||||
self._u_h = nn.Parameter(torch.zeros(dim, dim))
|
||||
nn.init.xavier_uniform_(self._u_r)
|
||||
nn.init.xavier_uniform_(self._u_z)
|
||||
nn.init.xavier_uniform_(self._u_h)
|
||||
self.register_parameter("_u_r", self._u_r)
|
||||
self.register_parameter("_u_z", self._u_z)
|
||||
self.register_parameter("_u_h", self._u_h)
|
||||
|
||||
self._bias_z = torch.zeros(dim, ).fill_(self._init_bias)
|
||||
self._bias_z = nn.Parameter(torch.zeros(dim, ).fill_(init_bias))
|
||||
self.register_parameter("_bias_z", self._bias_z)
|
||||
|
||||
def forward(self, inputs: TensorType, **kwargs) -> TensorType:
|
||||
# Pass in internal state first.
|
||||
|
||||
@@ -1,11 +1,47 @@
|
||||
from typing import Union
|
||||
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.models.torch.misc import SlimFC
|
||||
from ray.rllib.utils.torch_ops import sequence_mask
|
||||
from ray.rllib.utils.typing import TensorType, Any
|
||||
from ray.rllib.utils.typing import TensorType
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
|
||||
class RelativePositionEmbedding(nn.Module):
|
||||
"""Creates a [seq_length x seq_length] matrix for rel. pos encoding.
|
||||
|
||||
Denoted as Phi in [2] and [3]. Phi is the standard sinusoid encoding
|
||||
matrix.
|
||||
|
||||
Args:
|
||||
seq_length (int): The max. sequence length (time axis).
|
||||
out_dim (int): The number of nodes to go into the first Tranformer
|
||||
layer with.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The encoding matrix Phi.
|
||||
"""
|
||||
|
||||
def __init__(self, out_dim, **kwargs):
|
||||
super().__init__()
|
||||
self.out_dim = out_dim
|
||||
|
||||
out_range = torch.arange(0, self.out_dim, 2.0)
|
||||
inverse_freq = 1 / (10000**(out_range / self.out_dim))
|
||||
self.register_buffer("inverse_freq", inverse_freq)
|
||||
|
||||
def forward(self, seq_length):
|
||||
pos_input = torch.arange(
|
||||
seq_length - 1, -1, -1.0,
|
||||
dtype=torch.float).to(self.inverse_freq.device)
|
||||
sinusoid_input = torch.einsum("i,j->ij", pos_input, self.inverse_freq)
|
||||
pos_embeddings = torch.cat(
|
||||
[torch.sin(sinusoid_input),
|
||||
torch.cos(sinusoid_input)], dim=-1)
|
||||
return pos_embeddings[:, None, :]
|
||||
|
||||
|
||||
class RelativeMultiHeadAttention(nn.Module):
|
||||
"""A RelativeMultiHeadAttention layer as described in [3].
|
||||
|
||||
@@ -17,24 +53,24 @@ class RelativeMultiHeadAttention(nn.Module):
|
||||
out_dim: int,
|
||||
num_heads: int,
|
||||
head_dim: int,
|
||||
rel_pos_encoder: Any,
|
||||
input_layernorm: bool = False,
|
||||
output_activation: Any = None,
|
||||
output_activation: Union[str, callable] = None,
|
||||
**kwargs):
|
||||
"""Initializes a RelativeMultiHeadAttention nn.Module object.
|
||||
|
||||
Args:
|
||||
in_dim (int):
|
||||
out_dim (int):
|
||||
out_dim (int): The output dimension of this module. Also known as
|
||||
"attention dim".
|
||||
num_heads (int): The number of attention heads to use.
|
||||
Denoted `H` in [2].
|
||||
head_dim (int): The dimension of a single(!) attention head
|
||||
Denoted `D` in [2].
|
||||
rel_pos_encoder (:
|
||||
input_layernorm (bool): Whether to prepend a LayerNorm before
|
||||
everything else. Should be True for building a GTrXL.
|
||||
output_activation (Optional[tf.nn.activation]): Optional tf.nn
|
||||
activation function. Should be relu for GTrXL.
|
||||
output_activation (Union[str, callable]): Optional activation
|
||||
function or activation function specifier (str).
|
||||
Should be "relu" for GTrXL.
|
||||
**kwargs:
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
@@ -53,17 +89,18 @@ class RelativeMultiHeadAttention(nn.Module):
|
||||
use_bias=False,
|
||||
activation_fn=output_activation)
|
||||
|
||||
self._pos_proj = SlimFC(
|
||||
in_size=in_dim, out_size=num_heads * head_dim, use_bias=False)
|
||||
|
||||
self._uvar = torch.zeros(num_heads, head_dim)
|
||||
self._vvar = torch.zeros(num_heads, head_dim)
|
||||
self._uvar = nn.Parameter(torch.zeros(num_heads, head_dim))
|
||||
self._vvar = nn.Parameter(torch.zeros(num_heads, head_dim))
|
||||
nn.init.xavier_uniform_(self._uvar)
|
||||
nn.init.xavier_uniform_(self._vvar)
|
||||
self.register_parameter("_uvar", self._uvar)
|
||||
self.register_parameter("_vvar", self._vvar)
|
||||
|
||||
self._pos_proj = SlimFC(
|
||||
in_size=in_dim, out_size=num_heads * head_dim, use_bias=False)
|
||||
self._rel_pos_embedding = RelativePositionEmbedding(out_dim)
|
||||
|
||||
self._rel_pos_encoder = rel_pos_encoder
|
||||
self._input_layernorm = None
|
||||
|
||||
if input_layernorm:
|
||||
self._input_layernorm = torch.nn.LayerNorm(in_dim)
|
||||
|
||||
@@ -75,10 +112,8 @@ class RelativeMultiHeadAttention(nn.Module):
|
||||
|
||||
# Add previous memory chunk (as const, w/o gradient) to input.
|
||||
# Tau (number of (prev) time slices in each memory chunk).
|
||||
Tau = list(memory.shape)[1] if memory is not None else 0
|
||||
if memory is not None:
|
||||
memory.requires_grad_(False)
|
||||
inputs = torch.cat((memory, inputs), dim=1)
|
||||
Tau = list(memory.shape)[1]
|
||||
inputs = torch.cat((memory.detach(), inputs), dim=1)
|
||||
|
||||
# Apply the Layer-Norm.
|
||||
if self._input_layernorm is not None:
|
||||
@@ -91,11 +126,11 @@ class RelativeMultiHeadAttention(nn.Module):
|
||||
queries = queries[:, -T:]
|
||||
|
||||
queries = torch.reshape(queries, [-1, T, H, d])
|
||||
keys = torch.reshape(keys, [-1, T + Tau, H, d])
|
||||
values = torch.reshape(values, [-1, T + Tau, H, d])
|
||||
keys = torch.reshape(keys, [-1, Tau + T, H, d])
|
||||
values = torch.reshape(values, [-1, Tau + T, H, d])
|
||||
|
||||
R = self._pos_proj(self._rel_pos_encoder)
|
||||
R = torch.reshape(R, [T + Tau, H, d])
|
||||
R = self._pos_proj(self._rel_pos_embedding(Tau + T))
|
||||
R = torch.reshape(R, [Tau + T, H, d])
|
||||
|
||||
# b=batch
|
||||
# i and j=time indices (i=max-timesteps (inputs); j=Tau memory space)
|
||||
@@ -108,10 +143,11 @@ class RelativeMultiHeadAttention(nn.Module):
|
||||
|
||||
# causal mask of the same length as the sequence
|
||||
mask = sequence_mask(
|
||||
torch.arange(Tau + 1, T + Tau + 1), dtype=score.dtype)
|
||||
torch.arange(Tau + 1, Tau + T + 1),
|
||||
dtype=score.dtype).to(score.device)
|
||||
mask = mask[None, :, :, None]
|
||||
|
||||
masked_score = score * mask + 1e30 * (mask.to(torch.float32) - 1.)
|
||||
masked_score = score * mask + 1e30 * (mask.float() - 1.)
|
||||
wmat = nn.functional.softmax(masked_score, dim=2)
|
||||
|
||||
out = torch.einsum("bijh,bjhd->bihd", wmat, values)
|
||||
|
||||
@@ -259,10 +259,8 @@ def build_eager_tf_policy(name,
|
||||
self._update_model_inference_view_requirements_from_init_state()
|
||||
|
||||
self.exploration = self._create_exploration()
|
||||
self._state_in = [
|
||||
tf.convert_to_tensor([s])
|
||||
for s in self.model.get_initial_state()
|
||||
]
|
||||
self._state_inputs = self.model.get_initial_state()
|
||||
self._is_recurrent = len(self._state_inputs) > 0
|
||||
|
||||
# Combine view_requirements for Model and Policy.
|
||||
self.view_requirements.update(
|
||||
@@ -375,6 +373,8 @@ def build_eager_tf_policy(name,
|
||||
|
||||
# TODO: remove python side effect to cull sources of bugs.
|
||||
self._is_training = False
|
||||
self._is_recurrent = \
|
||||
state_batches is not None and state_batches != []
|
||||
self._state_in = state_batches or []
|
||||
|
||||
if not tf1.executing_eagerly():
|
||||
@@ -552,11 +552,11 @@ def build_eager_tf_policy(name,
|
||||
|
||||
@override(Policy)
|
||||
def is_recurrent(self):
|
||||
return len(self._state_in) > 0
|
||||
return self._is_recurrent
|
||||
|
||||
@override(Policy)
|
||||
def num_state_tensors(self):
|
||||
return len(self._state_in)
|
||||
return len(self._state_inputs)
|
||||
|
||||
@override(Policy)
|
||||
def get_initial_state(self):
|
||||
|
||||
@@ -170,9 +170,6 @@ def build_tf_policy(
|
||||
mixins (Optional[List[type]]): Optional list of any class mixins for
|
||||
the returned policy class. These mixins will be applied in order
|
||||
and will have higher precedence than the DynamicTFPolicy class.
|
||||
view_requirements_fn (Callable[[Policy],
|
||||
Dict[str, ViewRequirement]]): An optional callable to retrieve
|
||||
additional train view requirements for this policy.
|
||||
get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
|
||||
Optional callable that returns the divisibility requirement for
|
||||
sample batches. If None, will assume a value of 1.
|
||||
|
||||
@@ -110,6 +110,8 @@ class TorchPolicy(Policy):
|
||||
logger.info("TorchPolicy running on CPU.")
|
||||
self.device = torch.device("cpu")
|
||||
self.model = model.to(self.device)
|
||||
self._state_inputs = self.model.get_initial_state()
|
||||
self._is_recurrent = len(self._state_inputs) > 0
|
||||
# Auto-update model's inference view requirements, if recurrent.
|
||||
self._update_model_inference_view_requirements_from_init_state()
|
||||
# Combine view_requirements for Model and Policy.
|
||||
@@ -203,6 +205,11 @@ class TorchPolicy(Policy):
|
||||
Tuple:
|
||||
- actions, state_out, extra_fetches, logp.
|
||||
"""
|
||||
self._is_recurrent = state_batches is not None and state_batches != []
|
||||
# Switch to eval mode.
|
||||
if self.model:
|
||||
self.model.eval()
|
||||
|
||||
if self.action_sampler_fn:
|
||||
action_dist = dist_inputs = None
|
||||
state_out = state_batches
|
||||
@@ -325,6 +332,9 @@ class TorchPolicy(Policy):
|
||||
@DeveloperAPI
|
||||
def learn_on_batch(
|
||||
self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
|
||||
# Set Model to train mode.
|
||||
if self.model:
|
||||
self.model.train()
|
||||
# Callback handling.
|
||||
self.callbacks.on_learn_on_batch(
|
||||
policy=self, train_batch=postprocessed_batch)
|
||||
@@ -354,6 +364,9 @@ class TorchPolicy(Policy):
|
||||
view_requirements=self.view_requirements,
|
||||
)
|
||||
|
||||
# Mark the batch as "is_training" so the Model can use this
|
||||
# information.
|
||||
postprocessed_batch["is_training"] = True
|
||||
train_batch = self._lazy_tensor_dict(postprocessed_batch)
|
||||
|
||||
# Calculate the actual policy loss.
|
||||
@@ -448,7 +461,7 @@ class TorchPolicy(Policy):
|
||||
@override(Policy)
|
||||
@DeveloperAPI
|
||||
def is_recurrent(self) -> bool:
|
||||
return len(self.model.get_initial_state()) > 0
|
||||
return self._is_recurrent
|
||||
|
||||
@override(Policy)
|
||||
@DeveloperAPI
|
||||
|
||||
Reference in New Issue
Block a user