[RLlib] Attention nets PyTorch support and cleanup (using traj. view API). (#12029)

This commit is contained in:
Sven Mika
2020-12-21 21:38:34 -05:00
committed by GitHub
parent 8068041006
commit d5604eaba3
14 changed files with 292 additions and 148 deletions
+9 -10
View File
@@ -1480,20 +1480,19 @@ py_test(
name = "examples/attention_net_tf",
main = "examples/attention_net.py",
tags = ["examples", "examples_A"],
size = "large",
size = "medium",
srcs = ["examples/attention_net.py"],
args = ["--as-test", "--stop-reward=80"]
)
# TODO(sven): GTrXL PyTorch.
# py_test(
# name = "examples/attention_net_torch",
# main = "examples/attention_net.py",
# tags = ["examples", "examples_A"],
# size = "large",
# srcs = ["examples/attention_net.py"],
# args = ["--as-test", "--torch", "--stop-reward=90"]
# )
# Runs the attention-net example script (examples/attention_net.py) as a
# test with the `--torch` flag, i.e. the PyTorch counterpart of
# `examples/attention_net_tf`.
py_test(
name = "examples/attention_net_torch",
main = "examples/attention_net.py",
tags = ["examples", "examples_A"],
size = "medium",
srcs = ["examples/attention_net.py"],
args = ["--as-test", "--stop-reward=80", "--torch"]
)
py_test(
name = "examples/autoregressive_action_dist_tf",
+2 -1
View File
@@ -49,7 +49,8 @@ def ppo_surrogate_loss(
# RNN case: Mask away 0-padded chunks at end of time axis.
if state:
max_seq_len = torch.max(train_batch["seq_lens"])
B = len(train_batch["seq_lens"])
max_seq_len = logits.shape[0] // B
mask = sequence_mask(
train_batch["seq_lens"],
max_seq_len,
@@ -35,9 +35,6 @@ def to_float_np_array(v: List[Any]) -> np.ndarray:
return arr
_INIT_COLS = [SampleBatch.OBS]
class _AgentCollector:
"""Collects samples for one agent in one trajectory (episode).
@@ -55,8 +52,9 @@ class _AgentCollector:
# or internal state inputs.
self.shift_before = -min(
(int(vr.shift.split(":")[0])
if isinstance(vr.shift, str) else vr.shift) +
(-1 if vr.data_col in _INIT_COLS or k in _INIT_COLS else 0)
if isinstance(vr.shift, str) else vr.shift) -
(1
if vr.data_col == SampleBatch.OBS or k == SampleBatch.OBS else 0)
for k, vr in view_reqs.items())
# The actual data buffers (lists holding each timestep's data).
self.buffers: Dict[str, List] = {}
@@ -7,19 +7,41 @@ import unittest
import ray
from ray import tune
from ray.rllib.agents.callbacks import DefaultCallbacks
import ray.rllib.agents.dqn as dqn
import ray.rllib.agents.ppo as ppo
from ray.rllib.examples.env.debug_counter_env import MultiAgentDebugCounterEnv
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.examples.policy.episode_env_aware_policy import \
EpisodeEnvAwareLSTMPolicy
EpisodeEnvAwareAttentionPolicy, EpisodeEnvAwareLSTMPolicy
from ray.rllib.models.tf.attention_net import GTrXLNet
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override
from ray.rllib.utils.test_utils import framework_iterator, check
class MyCallbacks(DefaultCallbacks):
    """Callbacks that sanity-check every train batch handed to the policy.

    The checks are hard-wired to a batch size of 201 and to observations
    that count upwards by one until the value 15 is seen.
    """

    @override(DefaultCallbacks)
    def on_learn_on_batch(self, *, policy, train_batch, **kwargs):
        """Assert batch size, per-column lengths and observation order.

        Args:
            policy: The policy about to learn on `train_batch` (unused).
            train_batch: The batch to validate. Must expose `count`,
                `seq_lens`, `data` and item access by column name.
            **kwargs: Ignored.
        """
        expected_count = 201
        assert train_batch.count == expected_count
        assert sum(train_batch.seq_lens) == expected_count

        # "state_in_0" carries one entry per sequence; every other column
        # carries one entry per timestep.
        for column, values in train_batch.data.items():
            if column != "state_in_0":
                assert len(values) == expected_count
            else:
                assert len(values) == len(train_batch.seq_lens)

        # Observations must increase by exactly 1 from one row to the next.
        # After an observation equal to 15 the comparison is skipped once.
        # NOTE: This is a truthiness check, so a falsy `previous` (None or
        # 0) also skips the comparison.
        previous = None
        for obs in train_batch[SampleBatch.OBS]:
            if previous:
                assert obs == previous + 1
            previous = None if obs == 15 else obs
class TestTrajectoryViewAPI(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
@@ -116,6 +138,45 @@ class TestTrajectoryViewAPI(unittest.TestCase):
assert view_req_policy[key].shift == 1
trainer.stop()
def test_traj_view_attention_net(self):
    """Sample from and train a PPOTrainer that uses a GTrXL attention net.

    Deliberately odd batch sizes are used and `MyCallbacks` is installed
    as the trainer's callbacks class. Only the "tf2" framework is
    iterated over.
    """
    config = ppo.DEFAULT_CONFIG.copy()

    # Setup attention net (on a copy of the default model config).
    model_config = config["model"].copy()
    model_config.update({
        "max_seq_len": 50,
        "custom_model": GTrXLNet,
        "custom_model_config": {
            "num_transformer_units": 1,
            "attn_dim": 64,
            "num_heads": 2,
            "memory_inference": 50,
            "memory_training": 50,
            "head_dim": 32,
            "ff_hidden_dim": 32,
        },
    })

    config.update({
        "model": model_config,
        # Test with odd batch numbers.
        "train_batch_size": 1031,
        "sgd_minibatch_size": 201,
        "num_sgd_iter": 5,
        "num_workers": 0,
        "callbacks": MyCallbacks,
        # first obs is [1.0]
        "env_config": {
            "config": {
                "start_at_t": 1
            }
        },
    })

    env_spec = "ray.rllib.examples.env.debug_counter_env.DebugCounterEnv"
    for _ in framework_iterator(config, frameworks="tf2"):
        trainer = ppo.PPOTrainer(config, env=env_spec)

        # A single sample from the local worker must span exactly one
        # rollout fragment.
        local_worker = trainer.workers.local_worker()
        fragment = local_worker.sample()
        assert fragment.count == config["rollout_fragment_length"]

        train_results = trainer.train()
        assert train_results["train_batch_size"] == \
            config["train_batch_size"]

        trainer.stop()
def test_traj_view_simple_performance(self):
"""Test whether PPOTrainer runs faster w/ `_use_trajectory_view_api`.
"""
@@ -298,6 +359,40 @@ class TestTrajectoryViewAPI(unittest.TestCase):
pol_batch_wo = result.policy_batches["pol0"]
check(pol_batch_w.data, pol_batch_wo.data)
def test_traj_view_attention_functionality(self):
    """Sample one batch from a multi-agent worker using attention policies.

    Builds a RolloutWorker (trajectory view API enabled) around a
    4-agent MultiAgentDebugCounterEnv in which every agent is mapped to
    the single policy "pol0" (an EpisodeEnvAwareAttentionPolicy), then
    samples one batch and prints it.
    """
    action_space = Box(-float("inf"), float("inf"), shape=(3, ))
    obs_space = Box(float("-inf"), float("inf"), (4, ))
    max_seq_len = 50
    rollout_fragment_length = 201
    policies = {
        "pol0": (EpisodeEnvAwareAttentionPolicy, obs_space, action_space,
                 {}),
    }

    def policy_fn(agent_id):
        # All agents share the one policy.
        return "pol0"

    # NOTE: There must be no trailing comma after the closing brace.
    # With one, `config` becomes a 1-tuple and `dict(config, ...)` below
    # silently produces {"multiagent": "model", ...} instead of the
    # intended settings.
    config = {
        "multiagent": {
            "policies": policies,
            "policy_mapping_fn": policy_fn,
        },
        "model": {
            "max_seq_len": max_seq_len,
        },
    }

    rollout_worker_w_api = RolloutWorker(
        env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}),
        policy_config=dict(config, **{"_use_trajectory_view_api": True}),
        rollout_fragment_length=rollout_fragment_length,
        policy_spec=policies,
        policy_mapping_fn=policy_fn,
        num_envs=1,
    )
    batch = rollout_worker_w_api.sample()
    print(batch)
def test_counting_by_agent_steps(self):
"""Test whether a PPOTrainer can be built with all frameworks."""
config = copy.deepcopy(ppo.DEFAULT_CONFIG)
+2 -3
View File
@@ -4,6 +4,7 @@ import os
import ray
from ray import tune
from ray.rllib.models.tf.attention_net import GTrXLNet
from ray.rllib.models.torch.attention_net import GTrXLNet as TorchGTrXLNet
from ray.rllib.examples.env.look_and_push import LookAndPush, OneHot
from ray.rllib.examples.env.repeat_after_me_env import RepeatAfterMeEnv
from ray.rllib.examples.env.repeat_initial_obs_env import RepeatInitialObsEnv
@@ -27,8 +28,6 @@ parser.add_argument("--stop-reward", type=float, default=80)
if __name__ == "__main__":
args = parser.parse_args()
assert not args.torch, "PyTorch not supported for AttentionNets yet!"
ray.init(num_cpus=args.num_cpus or None)
registry.register_env("RepeatAfterMeEnv", lambda c: RepeatAfterMeEnv(c))
@@ -52,7 +51,7 @@ if __name__ == "__main__":
"num_sgd_iter": 10,
"vf_loss_coeff": 1e-5,
"model": {
"custom_model": GTrXLNet,
"custom_model": TorchGTrXLNet if args.torch else GTrXLNet,
"max_seq_len": 50,
"custom_model_config": {
"num_transformer_units": 1,
@@ -28,6 +28,7 @@ class MyCallbacks(DefaultCallbacks):
episode: MultiAgentEpisode, env_index: int, **kwargs):
print("episode {} (env-idx={}) started.".format(
episode.episode_id, env_index))
episode.user_data["pole_angles"] = []
episode.hist_data["pole_angles"] = []
+13 -6
View File
@@ -12,18 +12,25 @@ class DebugCounterEnv(gym.Env):
Reward is always: current ts % 3.
"""
def __init__(self):
def __init__(self, config=None):
config = config or {}
self.action_space = gym.spaces.Discrete(2)
self.observation_space = gym.spaces.Box(0, 100, (1, ))
self.i = 0
self.observation_space = \
gym.spaces.Box(0, 100, (1, ), dtype=np.float32)
self.start_at_t = int(config.get("start_at_t", 0))
self.i = self.start_at_t
def reset(self):
self.i = 0
return [self.i]
self.i = self.start_at_t
return self._get_obs()
def step(self, action):
self.i += 1
return [self.i], self.i % 3, self.i >= 15, {}
return self._get_obs(), float(self.i % 3), \
self.i >= 15 + self.start_at_t, {}
def _get_obs(self):
return np.array([self.i], dtype=np.float32)
class MultiAgentDebugCounterEnv(MultiAgentEnv):
@@ -45,9 +45,10 @@ class CentralizedCriticModel(TFModelV2):
def central_value_function(self, obs, opponent_obs, opponent_actions):
return tf.reshape(
self.central_vf(
[obs, opponent_obs,
tf.one_hot(opponent_actions, 2)]), [-1])
self.central_vf([
obs, opponent_obs,
tf.one_hot(tf.cast(opponent_actions, tf.int32), 2)
]), [-1])
@override(ModelV2)
def value_function(self):
@@ -124,7 +125,7 @@ class TorchCentralizedCriticModel(TorchModelV2, nn.Module):
def central_value_function(self, obs, opponent_obs, opponent_actions):
input_ = torch.cat([
obs, opponent_obs,
torch.nn.functional.one_hot(opponent_actions, 2).float()
torch.nn.functional.one_hot(opponent_actions.long(), 2).float()
], 1)
return torch.reshape(self.central_vf(input_), [-1])
+67 -73
View File
@@ -10,12 +10,15 @@
"""
import numpy as np
import gym
from gym.spaces import Box
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.modules import GRUGate, \
RelativeMultiHeadAttention, SkipConnection
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ModelConfigDict, TensorType, List
@@ -23,26 +26,6 @@ from ray.rllib.utils.typing import ModelConfigDict, TensorType, List
torch, nn = try_import_torch()
def relative_position_embedding(seq_length: int, out_dim: int) -> TensorType:
"""Creates a [seq_length x seq_length] matrix for rel. pos encoding.
Denoted as Phi in [2] and [3]. Phi is the standard sinusoid encoding
matrix.
Args:
seq_length (int): The max. sequence length (time axis).
out_dim (int): The number of nodes to go into the first Tranformer
layer with.
Returns:
torch.Tensor: The encoding matrix Phi.
"""
inverse_freq = 1 / (10000**(torch.arange(0, out_dim, 2.0) / out_dim))
pos_offsets = torch.arange(seq_length - 1, -1, -1)
inputs = pos_offsets[:, None] * inverse_freq[None, :]
return torch.cat((torch.sin(inputs), torch.cos(inputs)), dim=-1)
class GTrXLNet(RecurrentNetwork, nn.Module):
"""A GTrXL net Model described in [2].
@@ -74,7 +57,8 @@ class GTrXLNet(RecurrentNetwork, nn.Module):
num_transformer_units: int,
attn_dim: int,
num_heads: int,
memory_tau: int,
memory_inference: int,
memory_training: int,
head_dim: int,
ff_hidden_dim: int,
init_gate_bias: float = 2.0):
@@ -87,9 +71,15 @@ class GTrXLNet(RecurrentNetwork, nn.Module):
unit.
num_heads (int): The number of attention heads to use in parallel.
Denoted as `H` in [3].
memory_tau (int): The number of timesteps to store in each
transformer block's memory M (concat'd over time and fed into
next transformer block as input).
memory_inference (int): The number of timesteps to concat (time
axis) and feed into the next transformer unit as inference
input. The first transformer unit will receive this number of
past observations (plus the current one), instead.
memory_training (int): The number of timesteps to concat (time
axis) and feed into the next transformer unit as training
input (plus the actual input sequence of len=max_seq_len).
The first transformer unit will receive this number of
past observations (plus the input sequence), instead.
head_dim (int): The dimension of a single(!) head.
Denoted as `d` in [3].
ff_hidden_dim (int): The dimension of the hidden layer within
@@ -110,20 +100,18 @@ class GTrXLNet(RecurrentNetwork, nn.Module):
self.num_transformer_units = num_transformer_units
self.attn_dim = attn_dim
self.num_heads = num_heads
self.memory_tau = memory_tau
self.memory_inference = memory_inference
self.memory_training = memory_training
self.head_dim = head_dim
self.max_seq_len = model_config["max_seq_len"]
self.obs_dim = observation_space.shape[0]
# Constant (non-trainable) sinusoid rel pos encoding matrix.
Phi = relative_position_embedding(self.max_seq_len + self.memory_tau,
self.attn_dim)
self.linear_layer = SlimFC(
in_size=self.obs_dim, out_size=self.attn_dim)
self.layers = [self.linear_layer]
attention_layers = []
# 2) Create L Transformer blocks according to [2].
for i in range(self.num_transformer_units):
# RelativeMultiHeadAttention part.
@@ -133,7 +121,6 @@ class GTrXLNet(RecurrentNetwork, nn.Module):
out_dim=self.attn_dim,
num_heads=num_heads,
head_dim=head_dim,
rel_pos_encoder=Phi,
input_layernorm=True,
output_activation=nn.ReLU),
fan_in_layer=GRUGate(self.attn_dim, init_gate_bias))
@@ -154,8 +141,13 @@ class GTrXLNet(RecurrentNetwork, nn.Module):
activation_fn=nn.ReLU)),
fan_in_layer=GRUGate(self.attn_dim, init_gate_bias))
# Build a list of all layers in order.
self.layers.extend([MHA_layer, E_layer])
# Build a list of all attention layers in order.
attention_layers.extend([MHA_layer, E_layer])
# Create a Sequential such that all parameters inside the attention
# layers are automatically registered with this top-level model.
self.attention_layers = nn.Sequential(*attention_layers)
self.layers.extend(attention_layers)
# Postprocess GTrXL output with another hidden layer.
self.logits = SlimFC(
@@ -168,62 +160,64 @@ class GTrXLNet(RecurrentNetwork, nn.Module):
self.values_out = SlimFC(
in_size=self.attn_dim, out_size=1, activation_fn=None)
@override(RecurrentNetwork)
def forward_rnn(self, inputs: TensorType, state: List[TensorType],
seq_lens: TensorType) -> (TensorType, List[TensorType]):
# To make Attention work with current RLlib's ModelV2 API:
# We assume `state` is the history of L recent observations (all
# concatenated into one tensor) and append the current inputs to the
# end and only keep the most recent (up to `max_seq_len`). This allows
# us to deal with timestep-wise inference and full sequence training
# within the same logic.
state = [torch.from_numpy(item) for item in state]
observations = state[0]
memory = state[1:]
# Setup inference view (`memory-inference` x past observations +
# current one (0))
# 1 to `num_transformer_units`: Memory data (one per transformer unit).
for i in range(self.num_transformer_units):
space = Box(-1.0, 1.0, shape=(self.attn_dim, ))
self.inference_view_requirements["state_in_{}".format(i)] = \
ViewRequirement(
"state_out_{}".format(i),
shift="-{}:-1".format(self.memory_inference),
# Repeat the incoming state every max-seq-len times.
batch_repeat_value=self.max_seq_len,
space=space)
self.inference_view_requirements["state_out_{}".format(i)] = \
ViewRequirement(
space=space,
used_for_training=False)
inputs = torch.reshape(inputs, [1, -1, observations.shape[-1]])
observations = torch.cat(
(observations, inputs), axis=1)[:, -self.max_seq_len:]
@override(ModelV2)
def forward(self, input_dict, state: List[TensorType],
seq_lens: TensorType) -> (TensorType, List[TensorType]):
assert seq_lens is not None
# Add the needed batch rank (tf Models' Input requires this).
observations = input_dict[SampleBatch.OBS]
# Add the time dim to observations.
B = len(seq_lens)
T = observations.shape[0] // B
observations = torch.reshape(observations,
[-1, T] + list(observations.shape[1:]))
all_out = observations
memory_outs = []
for i in range(len(self.layers)):
# MHA layers which need memory passed in.
if i % 2 == 1:
all_out = self.layers[i](all_out, memory=memory[i // 2])
# Either linear layers or MultiLayerPerceptrons.
all_out = self.layers[i](all_out, memory=state[i // 2])
# Either self.linear_layer (initial obs -> attn. dim layer) or
# MultiLayerPerceptrons. The output of these layers is always the
# memory for the next forward pass.
else:
all_out = self.layers[i](all_out)
memory_outs.append(all_out)
# Discard last output (not needed as a memory since it's the last
# layer).
memory_outs = memory_outs[:-1]
logits = self.logits(all_out)
self._value_out = self.values_out(all_out)
memory_outs = all_out[2:]
# If memory_tau > max_seq_len -> overlap w/ previous `memory` input.
if self.memory_tau > self.max_seq_len:
memory_outs = [
torch.cat(
[memory[i][:, -(self.memory_tau - self.max_seq_len):], m],
axis=1) for i, m in enumerate(memory_outs)
]
else:
memory_outs = [m[:, -self.memory_tau:] for m in memory_outs]
T = list(inputs.size())[1] # Length of input segment (time).
# Postprocessing final output.
logits = logits[:, -T:]
self._value_out = self._value_out[:, -T:]
return logits, [observations] + memory_outs
return torch.reshape(logits, [-1, self.num_outputs]), [
torch.reshape(m, [-1, self.attn_dim]) for m in memory_outs
]
# TODO: (sven) Deprecate this once trajectory view API has fully matured.
@override(RecurrentNetwork)
def get_initial_state(self) -> List[np.ndarray]:
# State is the T last observations concat'd together into one Tensor.
# Plus all Transformer blocks' E(l) outputs concat'd together (up to
# tau timesteps).
return [np.zeros((self.max_seq_len, self.obs_dim), np.float32)] + \
[np.zeros((self.memory_tau, self.attn_dim), np.float32)
for _ in range(self.num_transformer_units)]
return []
@override(ModelV2)
def value_function(self) -> TensorType:
+14 -11
View File
@@ -13,26 +13,29 @@ class GRUGate(nn.Module):
init_bias (int): Bias added to every input to stabilize training
"""
super().__init__(**kwargs)
self._init_bias = init_bias
# Xavier initialization of torch tensors
self._w_r = torch.zeros(dim, dim)
self._w_z = torch.zeros(dim, dim)
self._w_h = torch.zeros(dim, dim)
self._u_r = torch.zeros(dim, dim)
self._u_z = torch.zeros(dim, dim)
self._u_h = torch.zeros(dim, dim)
self._w_r = nn.Parameter(torch.zeros(dim, dim))
self._w_z = nn.Parameter(torch.zeros(dim, dim))
self._w_h = nn.Parameter(torch.zeros(dim, dim))
nn.init.xavier_uniform_(self._w_r)
nn.init.xavier_uniform_(self._w_z)
nn.init.xavier_uniform_(self._w_h)
self.register_parameter("_w_r", self._w_r)
self.register_parameter("_w_z", self._w_z)
self.register_parameter("_w_h", self._w_h)
self._u_r = nn.Parameter(torch.zeros(dim, dim))
self._u_z = nn.Parameter(torch.zeros(dim, dim))
self._u_h = nn.Parameter(torch.zeros(dim, dim))
nn.init.xavier_uniform_(self._u_r)
nn.init.xavier_uniform_(self._u_z)
nn.init.xavier_uniform_(self._u_h)
self.register_parameter("_u_r", self._u_r)
self.register_parameter("_u_z", self._u_z)
self.register_parameter("_u_h", self._u_h)
self._bias_z = torch.zeros(dim, ).fill_(self._init_bias)
self._bias_z = nn.Parameter(torch.zeros(dim, ).fill_(init_bias))
self.register_parameter("_bias_z", self._bias_z)
def forward(self, inputs: TensorType, **kwargs) -> TensorType:
# Pass in internal state first.
@@ -1,11 +1,47 @@
from typing import Union
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.utils.torch_ops import sequence_mask
from ray.rllib.utils.typing import TensorType, Any
from ray.rllib.utils.typing import TensorType
torch, nn = try_import_torch()
class RelativePositionEmbedding(nn.Module):
    """Sinusoidal relative position encoding.

    Denoted as Phi in [2] and [3]. Phi is the standard sinusoid encoding
    matrix. Calling the module with a sequence length returns the encoding
    of the position offsets `seq_length - 1, ..., 1, 0`.

    Args:
        out_dim (int): The number of nodes to go into the first Transformer
            layer with. Determines the number of sinusoid frequencies
            (one per even index in `[0, out_dim)`).
        **kwargs: Accepted but unused.
    """

    def __init__(self, out_dim, **kwargs):
        super().__init__()
        self.out_dim = out_dim

        out_range = torch.arange(0, self.out_dim, 2.0)
        inverse_freq = 1 / (10000**(out_range / self.out_dim))
        # Registered as a buffer (not a parameter): it is constant, but
        # follows the module across devices via `.to()`.
        self.register_buffer("inverse_freq", inverse_freq)

    def forward(self, seq_length):
        """Returns the encoding matrix Phi for `seq_length` positions.

        Args:
            seq_length (int): The number of positions (time axis) to
                encode.

        Returns:
            torch.Tensor: Tensor of shape
                [seq_length, 1, 2 * len(inverse_freq)], which is
                [seq_length, 1, out_dim] for an even `out_dim`. Along the
                last axis, all sine terms come first, followed by all
                cosine terms.
        """
        # Descending offsets (seq_length - 1 ... 0), created directly on
        # the buffer's device to avoid a CPU allocation plus copy on every
        # forward pass.
        pos_input = torch.arange(
            seq_length - 1,
            -1,
            -1.0,
            dtype=torch.float,
            device=self.inverse_freq.device)
        # Outer product: [seq_length, num_frequencies].
        sinusoid_input = torch.einsum("i,j->ij", pos_input, self.inverse_freq)
        pos_embeddings = torch.cat(
            [torch.sin(sinusoid_input),
             torch.cos(sinusoid_input)], dim=-1)
        # Insert a singleton axis at position 1.
        return pos_embeddings[:, None, :]
class RelativeMultiHeadAttention(nn.Module):
"""A RelativeMultiHeadAttention layer as described in [3].
@@ -17,24 +53,24 @@ class RelativeMultiHeadAttention(nn.Module):
out_dim: int,
num_heads: int,
head_dim: int,
rel_pos_encoder: Any,
input_layernorm: bool = False,
output_activation: Any = None,
output_activation: Union[str, callable] = None,
**kwargs):
"""Initializes a RelativeMultiHeadAttention nn.Module object.
Args:
in_dim (int):
out_dim (int):
out_dim (int): The output dimension of this module. Also known as
"attention dim".
num_heads (int): The number of attention heads to use.
Denoted `H` in [2].
head_dim (int): The dimension of a single(!) attention head
Denoted `D` in [2].
rel_pos_encoder (:
input_layernorm (bool): Whether to prepend a LayerNorm before
everything else. Should be True for building a GTrXL.
output_activation (Optional[tf.nn.activation]): Optional tf.nn
activation function. Should be relu for GTrXL.
output_activation (Union[str, callable]): Optional activation
function or activation function specifier (str).
Should be "relu" for GTrXL.
**kwargs:
"""
super().__init__(**kwargs)
@@ -53,17 +89,18 @@ class RelativeMultiHeadAttention(nn.Module):
use_bias=False,
activation_fn=output_activation)
self._pos_proj = SlimFC(
in_size=in_dim, out_size=num_heads * head_dim, use_bias=False)
self._uvar = torch.zeros(num_heads, head_dim)
self._vvar = torch.zeros(num_heads, head_dim)
self._uvar = nn.Parameter(torch.zeros(num_heads, head_dim))
self._vvar = nn.Parameter(torch.zeros(num_heads, head_dim))
nn.init.xavier_uniform_(self._uvar)
nn.init.xavier_uniform_(self._vvar)
self.register_parameter("_uvar", self._uvar)
self.register_parameter("_vvar", self._vvar)
self._pos_proj = SlimFC(
in_size=in_dim, out_size=num_heads * head_dim, use_bias=False)
self._rel_pos_embedding = RelativePositionEmbedding(out_dim)
self._rel_pos_encoder = rel_pos_encoder
self._input_layernorm = None
if input_layernorm:
self._input_layernorm = torch.nn.LayerNorm(in_dim)
@@ -75,10 +112,8 @@ class RelativeMultiHeadAttention(nn.Module):
# Add previous memory chunk (as const, w/o gradient) to input.
# Tau (number of (prev) time slices in each memory chunk).
Tau = list(memory.shape)[1] if memory is not None else 0
if memory is not None:
memory.requires_grad_(False)
inputs = torch.cat((memory, inputs), dim=1)
Tau = list(memory.shape)[1]
inputs = torch.cat((memory.detach(), inputs), dim=1)
# Apply the Layer-Norm.
if self._input_layernorm is not None:
@@ -91,11 +126,11 @@ class RelativeMultiHeadAttention(nn.Module):
queries = queries[:, -T:]
queries = torch.reshape(queries, [-1, T, H, d])
keys = torch.reshape(keys, [-1, T + Tau, H, d])
values = torch.reshape(values, [-1, T + Tau, H, d])
keys = torch.reshape(keys, [-1, Tau + T, H, d])
values = torch.reshape(values, [-1, Tau + T, H, d])
R = self._pos_proj(self._rel_pos_encoder)
R = torch.reshape(R, [T + Tau, H, d])
R = self._pos_proj(self._rel_pos_embedding(Tau + T))
R = torch.reshape(R, [Tau + T, H, d])
# b=batch
# i and j=time indices (i=max-timesteps (inputs); j=Tau memory space)
@@ -108,10 +143,11 @@ class RelativeMultiHeadAttention(nn.Module):
# causal mask of the same length as the sequence
mask = sequence_mask(
torch.arange(Tau + 1, T + Tau + 1), dtype=score.dtype)
torch.arange(Tau + 1, Tau + T + 1),
dtype=score.dtype).to(score.device)
mask = mask[None, :, :, None]
masked_score = score * mask + 1e30 * (mask.to(torch.float32) - 1.)
masked_score = score * mask + 1e30 * (mask.float() - 1.)
wmat = nn.functional.softmax(masked_score, dim=2)
out = torch.einsum("bijh,bjhd->bihd", wmat, values)
+6 -6
View File
@@ -259,10 +259,8 @@ def build_eager_tf_policy(name,
self._update_model_inference_view_requirements_from_init_state()
self.exploration = self._create_exploration()
self._state_in = [
tf.convert_to_tensor([s])
for s in self.model.get_initial_state()
]
self._state_inputs = self.model.get_initial_state()
self._is_recurrent = len(self._state_inputs) > 0
# Combine view_requirements for Model and Policy.
self.view_requirements.update(
@@ -375,6 +373,8 @@ def build_eager_tf_policy(name,
# TODO: remove python side effect to cull sources of bugs.
self._is_training = False
self._is_recurrent = \
state_batches is not None and state_batches != []
self._state_in = state_batches or []
if not tf1.executing_eagerly():
@@ -552,11 +552,11 @@ def build_eager_tf_policy(name,
@override(Policy)
def is_recurrent(self):
return len(self._state_in) > 0
return self._is_recurrent
@override(Policy)
def num_state_tensors(self):
return len(self._state_in)
return len(self._state_inputs)
@override(Policy)
def get_initial_state(self):
-3
View File
@@ -170,9 +170,6 @@ def build_tf_policy(
mixins (Optional[List[type]]): Optional list of any class mixins for
the returned policy class. These mixins will be applied in order
and will have higher precedence than the DynamicTFPolicy class.
view_requirements_fn (Callable[[Policy],
Dict[str, ViewRequirement]]): An optional callable to retrieve
additional train view requirements for this policy.
get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
Optional callable that returns the divisibility requirement for
sample batches. If None, will assume a value of 1.
+14 -1
View File
@@ -110,6 +110,8 @@ class TorchPolicy(Policy):
logger.info("TorchPolicy running on CPU.")
self.device = torch.device("cpu")
self.model = model.to(self.device)
self._state_inputs = self.model.get_initial_state()
self._is_recurrent = len(self._state_inputs) > 0
# Auto-update model's inference view requirements, if recurrent.
self._update_model_inference_view_requirements_from_init_state()
# Combine view_requirements for Model and Policy.
@@ -203,6 +205,11 @@ class TorchPolicy(Policy):
Tuple:
- actions, state_out, extra_fetches, logp.
"""
self._is_recurrent = state_batches is not None and state_batches != []
# Switch to eval mode.
if self.model:
self.model.eval()
if self.action_sampler_fn:
action_dist = dist_inputs = None
state_out = state_batches
@@ -325,6 +332,9 @@ class TorchPolicy(Policy):
@DeveloperAPI
def learn_on_batch(
self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
# Set Model to train mode.
if self.model:
self.model.train()
# Callback handling.
self.callbacks.on_learn_on_batch(
policy=self, train_batch=postprocessed_batch)
@@ -354,6 +364,9 @@ class TorchPolicy(Policy):
view_requirements=self.view_requirements,
)
# Mark the batch as "is_training" so the Model can use this
# information.
postprocessed_batch["is_training"] = True
train_batch = self._lazy_tensor_dict(postprocessed_batch)
# Calculate the actual policy loss.
@@ -448,7 +461,7 @@ class TorchPolicy(Policy):
@override(Policy)
@DeveloperAPI
def is_recurrent(self) -> bool:
return len(self.model.get_initial_state()) > 0
return self._is_recurrent
@override(Policy)
@DeveloperAPI