mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 13:37:39 +08:00
[RLlib] Model Annotations: Tensorflow (#11964)
This commit is contained in:
@@ -9,6 +9,8 @@
|
||||
https://www.aclweb.org/anthology/P19-1285.pdf
|
||||
"""
|
||||
import numpy as np
|
||||
import gym
|
||||
from typing import Optional, Any
|
||||
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.tf.layers import GRUGate, RelativeMultiHeadAttention, \
|
||||
@@ -16,6 +18,7 @@ from ray.rllib.models.tf.layers import GRUGate, RelativeMultiHeadAttention, \
|
||||
from ray.rllib.models.tf.recurrent_net import RecurrentNetwork
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.typing import ModelConfigDict, TensorType, List
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
@@ -28,7 +31,11 @@ class PositionwiseFeedforward(tf.keras.layers.Layer):
|
||||
layer separately.
|
||||
"""
|
||||
|
||||
def __init__(self, out_dim, hidden_dim, output_activation=None, **kwargs):
|
||||
def __init__(self,
|
||||
out_dim: int,
|
||||
hidden_dim: int,
|
||||
output_activation: Optional[Any] = None,
|
||||
**kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._hidden_layer = tf.keras.layers.Dense(
|
||||
@@ -39,7 +46,7 @@ class PositionwiseFeedforward(tf.keras.layers.Layer):
|
||||
self._output_layer = tf.keras.layers.Dense(
|
||||
out_dim, activation=output_activation)
|
||||
|
||||
def call(self, inputs, **kwargs):
|
||||
def call(self, inputs: TensorType, **kwargs) -> TensorType:
|
||||
del kwargs
|
||||
output = self._hidden_layer(inputs)
|
||||
return self._output_layer(output)
|
||||
@@ -48,9 +55,11 @@ class PositionwiseFeedforward(tf.keras.layers.Layer):
|
||||
class TrXLNet(RecurrentNetwork):
|
||||
"""A TrXL net Model described in [1]."""
|
||||
|
||||
def __init__(self, observation_space, action_space, num_outputs,
|
||||
model_config, name, num_transformer_units, attn_dim,
|
||||
num_heads, head_dim, ff_hidden_dim):
|
||||
def __init__(self, observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space, num_outputs: int,
|
||||
model_config: ModelConfigDict, name: str,
|
||||
num_transformer_units: int, attn_dim: int, num_heads: int,
|
||||
head_dim: int, ff_hidden_dim: int):
|
||||
"""Initializes a TfXLNet object.
|
||||
|
||||
Args:
|
||||
@@ -109,7 +118,8 @@ class TrXLNet(RecurrentNetwork):
|
||||
self.register_variables(self.base_model.variables)
|
||||
|
||||
@override(RecurrentNetwork)
|
||||
def forward_rnn(self, inputs, state, seq_lens):
|
||||
def forward_rnn(self, inputs: TensorType, state: List[TensorType],
|
||||
seq_lens: TensorType) -> (TensorType, List[TensorType]):
|
||||
# To make Attention work with current RLlib's ModelV2 API:
|
||||
# We assume `state` is the history of L recent observations (all
|
||||
# concatenated into one tensor) and append the current inputs to the
|
||||
@@ -126,7 +136,7 @@ class TrXLNet(RecurrentNetwork):
|
||||
return logits, [observations]
|
||||
|
||||
@override(RecurrentNetwork)
|
||||
def get_initial_state(self):
|
||||
def get_initial_state(self) -> List[np.ndarray]:
|
||||
# State is the T last observations concat'd together into one Tensor.
|
||||
# Plus all Transformer blocks' E(l) outputs concat'd together (up to
|
||||
# tau timesteps).
|
||||
@@ -156,18 +166,18 @@ class GTrXLNet(RecurrentNetwork):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
observation_space,
|
||||
action_space,
|
||||
num_outputs,
|
||||
model_config,
|
||||
name,
|
||||
num_transformer_units,
|
||||
attn_dim,
|
||||
num_heads,
|
||||
memory_tau,
|
||||
head_dim,
|
||||
ff_hidden_dim,
|
||||
init_gate_bias=2.0):
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
num_outputs: int,
|
||||
model_config: ModelConfigDict,
|
||||
name: str,
|
||||
num_transformer_units: int,
|
||||
attn_dim: int,
|
||||
num_heads: int,
|
||||
memory_tau: int,
|
||||
head_dim: int,
|
||||
ff_hidden_dim: int,
|
||||
init_gate_bias: float = 2.0):
|
||||
"""Initializes a GTrXLNet.
|
||||
|
||||
Args:
|
||||
@@ -271,7 +281,8 @@ class GTrXLNet(RecurrentNetwork):
|
||||
self.trxl_model.summary()
|
||||
|
||||
@override(RecurrentNetwork)
|
||||
def forward_rnn(self, inputs, state, seq_lens):
|
||||
def forward_rnn(self, inputs: TensorType, state: List[TensorType],
|
||||
seq_lens: TensorType) -> (TensorType, List[TensorType]):
|
||||
# To make Attention work with current RLlib's ModelV2 API:
|
||||
# We assume `state` is the history of L recent observations (all
|
||||
# concatenated into one tensor) and append the current inputs to the
|
||||
@@ -303,7 +314,7 @@ class GTrXLNet(RecurrentNetwork):
|
||||
return logits, [observations] + memory_outs
|
||||
|
||||
@override(RecurrentNetwork)
|
||||
def get_initial_state(self):
|
||||
def get_initial_state(self) -> List[np.ndarray]:
|
||||
# State is the T last observations concat'd together into one Tensor.
|
||||
# Plus all Transformer blocks' E(l) outputs concat'd together (up to
|
||||
# tau timesteps).
|
||||
@@ -312,11 +323,11 @@ class GTrXLNet(RecurrentNetwork):
|
||||
for _ in range(self.num_transformer_units)]
|
||||
|
||||
@override(ModelV2)
|
||||
def value_function(self):
|
||||
def value_function(self) -> TensorType:
|
||||
return tf.reshape(self._value_out, [-1])
|
||||
|
||||
|
||||
def relative_position_embedding(seq_length, out_dim):
|
||||
def relative_position_embedding(seq_length: int, out_dim: int) -> TensorType:
|
||||
"""Creates a [seq_length x seq_length] matrix for rel. pos encoding.
|
||||
|
||||
Denoted as Phi in [2] and [3]. Phi is the standard sinusoid encoding
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import numpy as np
|
||||
import gym
|
||||
|
||||
from ray.rllib.models.tf.misc import normc_initializer
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.utils.framework import get_activation_fn, try_import_tf
|
||||
from ray.rllib.utils.typing import Dict, TensorType, List, ModelConfigDict
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
@@ -10,8 +12,9 @@ tf1, tf, tfv = try_import_tf()
|
||||
class FullyConnectedNetwork(TFModelV2):
|
||||
"""Generic fully connected network implemented in ModelV2 API."""
|
||||
|
||||
def __init__(self, obs_space, action_space, num_outputs, model_config,
|
||||
name):
|
||||
def __init__(self, obs_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space, num_outputs: int,
|
||||
model_config: ModelConfigDict, name: str):
|
||||
super(FullyConnectedNetwork, self).__init__(
|
||||
obs_space, action_space, num_outputs, model_config, name)
|
||||
|
||||
@@ -113,9 +116,11 @@ class FullyConnectedNetwork(TFModelV2):
|
||||
if logits_out is not None else last_layer), value_out])
|
||||
self.register_variables(self.base_model.variables)
|
||||
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
def forward(self, input_dict: Dict[str, TensorType],
|
||||
state: List[TensorType],
|
||||
seq_lens: TensorType) -> (TensorType, List[TensorType]):
|
||||
model_out, self._value_out = self.base_model(input_dict["obs_flat"])
|
||||
return model_out, state
|
||||
|
||||
def value_function(self):
|
||||
def value_function(self) -> TensorType:
|
||||
return tf.reshape(self._value_out, [-1])
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.typing import TensorType, TensorShape
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
|
||||
class GRUGate(tf.keras.layers.Layer if tf else object):
|
||||
def __init__(self, init_bias=0., **kwargs):
|
||||
def __init__(self, init_bias: float = 0., **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._init_bias = init_bias
|
||||
|
||||
def build(self, input_shape):
|
||||
def build(self, input_shape: TensorShape):
|
||||
h_shape, x_shape = input_shape
|
||||
if x_shape[-1] != h_shape[-1]:
|
||||
raise ValueError(
|
||||
@@ -29,7 +30,7 @@ class GRUGate(tf.keras.layers.Layer if tf else object):
|
||||
self._bias_z = self.add_weight(
|
||||
shape=(dim, ), initializer=bias_initializer)
|
||||
|
||||
def call(self, inputs, **kwargs):
|
||||
def call(self, inputs: TensorType, **kwargs) -> TensorType:
|
||||
# Pass in internal state first.
|
||||
h, X = inputs
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
https://arxiv.org/pdf/1706.03762.pdf
|
||||
"""
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.typing import TensorType
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
@@ -11,7 +12,7 @@ tf1, tf, tfv = try_import_tf()
|
||||
class MultiHeadAttention(tf.keras.layers.Layer if tf else object):
|
||||
"""A multi-head attention layer described in [1]."""
|
||||
|
||||
def __init__(self, out_dim, num_heads, head_dim, **kwargs):
|
||||
def __init__(self, out_dim: int, num_heads: int, head_dim: int, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# No bias or non-linearity.
|
||||
@@ -22,7 +23,7 @@ class MultiHeadAttention(tf.keras.layers.Layer if tf else object):
|
||||
self._linear_layer = tf.keras.layers.TimeDistributed(
|
||||
tf.keras.layers.Dense(out_dim, use_bias=False))
|
||||
|
||||
def call(self, inputs):
|
||||
def call(self, inputs: TensorType) -> TensorType:
|
||||
L = tf.shape(inputs)[1] # length of segment
|
||||
H = self._num_heads # number of attention heads
|
||||
D = self._head_dim # attention head dimension
|
||||
|
||||
@@ -2,6 +2,7 @@ import numpy as np
|
||||
|
||||
from ray.rllib.utils.framework import get_activation_fn, get_variable, \
|
||||
try_import_tf
|
||||
from ray.rllib.utils.framework import TensorType, TensorShape
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
@@ -18,14 +19,18 @@ class NoisyLayer(tf.keras.layers.Layer if tf else object):
|
||||
vanish along the training procedure.
|
||||
"""
|
||||
|
||||
def __init__(self, prefix, out_size, sigma0, activation="relu"):
|
||||
def __init__(self,
|
||||
prefix: str,
|
||||
out_size: int,
|
||||
sigma0: float,
|
||||
activation: str = "relu"):
|
||||
"""Initializes a NoisyLayer object.
|
||||
|
||||
Args:
|
||||
prefix:
|
||||
out_size:
|
||||
sigma0:
|
||||
non_linear:
|
||||
out_size: Output size for Noisy Layer
|
||||
sigma0: Initialization value for sigma_b (bias noise)
|
||||
non_linear: Non-linear activation for Noisy Layer
|
||||
"""
|
||||
super().__init__()
|
||||
self.prefix = prefix
|
||||
@@ -41,7 +46,7 @@ class NoisyLayer(tf.keras.layers.Layer if tf else object):
|
||||
self.sigma_w = None # Noise for weight matrix
|
||||
self.sigma_b = None # Noise for biases.
|
||||
|
||||
def build(self, input_shape):
|
||||
def build(self, input_shape: TensorShape):
|
||||
in_size = int(input_shape[1])
|
||||
|
||||
self.sigma_w = get_variable(
|
||||
@@ -78,7 +83,7 @@ class NoisyLayer(tf.keras.layers.Layer if tf else object):
|
||||
dtype=tf.float32,
|
||||
)
|
||||
|
||||
def call(self, inputs):
|
||||
def call(self, inputs: TensorType) -> TensorType:
|
||||
in_size = int(inputs.shape[1])
|
||||
epsilon_in = tf.random.normal(shape=[in_size])
|
||||
epsilon_out = tf.random.normal(shape=[self.out_size])
|
||||
@@ -98,5 +103,5 @@ class NoisyLayer(tf.keras.layers.Layer if tf else object):
|
||||
action_activation = fn(action_activation)
|
||||
return action_activation
|
||||
|
||||
def _f_epsilon(self, x):
|
||||
def _f_epsilon(self, x: TensorType) -> TensorType:
|
||||
return tf.math.sign(x) * tf.math.sqrt(tf.math.abs(x))
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from typing import Optional, Any
|
||||
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.typing import TensorType
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
@@ -10,12 +13,12 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
out_dim,
|
||||
num_heads,
|
||||
head_dim,
|
||||
rel_pos_encoder,
|
||||
input_layernorm=False,
|
||||
output_activation=None,
|
||||
out_dim: int,
|
||||
num_heads: int,
|
||||
head_dim: int,
|
||||
rel_pos_encoder: Any,
|
||||
input_layernorm: bool = False,
|
||||
output_activation: Optional[Any] = None,
|
||||
**kwargs):
|
||||
"""Initializes a RelativeMultiHeadAttention keras Layer object.
|
||||
|
||||
@@ -55,7 +58,8 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object):
|
||||
if input_layernorm:
|
||||
self._input_layernorm = tf.keras.layers.LayerNormalization(axis=-1)
|
||||
|
||||
def call(self, inputs, memory=None):
|
||||
def call(self, inputs: TensorType,
|
||||
memory: Optional[TensorType] = None) -> TensorType:
|
||||
T = tf.shape(inputs)[1] # length of segment (time)
|
||||
H = self._num_heads # number of attention heads
|
||||
d = self._head_dim # attention head dimension
|
||||
@@ -105,7 +109,7 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object):
|
||||
return self._linear_layer(out)
|
||||
|
||||
@staticmethod
|
||||
def rel_shift(x):
|
||||
def rel_shift(x: TensorType) -> TensorType:
|
||||
# Transposed version of the shift approach described in [3].
|
||||
# https://github.com/kimiyoung/transformer-xl/blob/
|
||||
# 44781ed21dbaec88b280f74d9ae2877f52b492a5/tf/model.py#L31
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from typing import Optional, Any
|
||||
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.typing import TensorType
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
@@ -10,7 +13,11 @@ class SkipConnection(tf.keras.layers.Layer if tf else object):
|
||||
input as hidden state input to a given fan_in_layer.
|
||||
"""
|
||||
|
||||
def __init__(self, layer, fan_in_layer=None, add_memory=False, **kwargs):
|
||||
def __init__(self,
|
||||
layer: Any,
|
||||
fan_in_layer: Optional[Any] = None,
|
||||
add_memory: bool = False,
|
||||
**kwargs):
|
||||
"""Initializes a SkipConnection keras layer object.
|
||||
|
||||
Args:
|
||||
@@ -23,7 +30,7 @@ class SkipConnection(tf.keras.layers.Layer if tf else object):
|
||||
self._layer = layer
|
||||
self._fan_in_layer = fan_in_layer
|
||||
|
||||
def call(self, inputs, **kwargs):
|
||||
def call(self, inputs: TensorType, **kwargs) -> TensorType:
|
||||
# del kwargs
|
||||
outputs = self._layer(inputs, **kwargs)
|
||||
# Residual case, just add inputs to outputs.
|
||||
|
||||
+18
-11
@@ -1,10 +1,13 @@
|
||||
import numpy as np
|
||||
from typing import Tuple, Any, Optional
|
||||
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.typing import TensorType
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
|
||||
def normc_initializer(std=1.0):
|
||||
def normc_initializer(std: float = 1.0) -> Any:
|
||||
def _initializer(shape, dtype=None, partition_info=None):
|
||||
out = np.random.randn(*shape).astype(np.float32)
|
||||
out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True))
|
||||
@@ -13,14 +16,14 @@ def normc_initializer(std=1.0):
|
||||
return _initializer
|
||||
|
||||
|
||||
def conv2d(x,
|
||||
num_filters,
|
||||
name,
|
||||
filter_size=(3, 3),
|
||||
stride=(1, 1),
|
||||
pad="SAME",
|
||||
dtype=None,
|
||||
collections=None):
|
||||
def conv2d(x: TensorType,
|
||||
num_filters: int,
|
||||
name: str,
|
||||
filter_size: Tuple[int, int] = (3, 3),
|
||||
stride: Tuple[int, int] = (1, 1),
|
||||
pad: str = "SAME",
|
||||
dtype: Optional[Any] = None,
|
||||
collections: Optional[Any] = None) -> TensorType:
|
||||
if dtype is None:
|
||||
dtype = tf.float32
|
||||
|
||||
@@ -53,7 +56,11 @@ def conv2d(x,
|
||||
return tf1.nn.conv2d(x, w, stride_shape, pad) + b
|
||||
|
||||
|
||||
def linear(x, size, name, initializer=None, bias_init=0):
|
||||
def linear(x: TensorType,
|
||||
size: int,
|
||||
name: str,
|
||||
initializer: Optional[Any] = None,
|
||||
bias_init: float = 0.0) -> TensorType:
|
||||
w = tf1.get_variable(
|
||||
name + "/w", [x.get_shape()[1], size], initializer=initializer)
|
||||
b = tf1.get_variable(
|
||||
@@ -61,5 +68,5 @@ def linear(x, size, name, initializer=None, bias_init=0):
|
||||
return tf.matmul(x, w) + b
|
||||
|
||||
|
||||
def flatten(x):
|
||||
def flatten(x: TensorType) -> TensorType:
|
||||
return tf.reshape(x, [-1, np.prod(x.get_shape().as_list()[1:])])
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import numpy as np
|
||||
import gym
|
||||
from typing import Dict, List
|
||||
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
@@ -6,6 +8,7 @@ from ray.rllib.policy.rnn_sequencing import add_time_dimension
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.typing import ModelConfigDict, TensorType
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
@@ -49,7 +52,9 @@ class RecurrentNetwork(TFModelV2):
|
||||
"""
|
||||
|
||||
@override(ModelV2)
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
def forward(self, input_dict: Dict[str, TensorType],
|
||||
state: List[TensorType],
|
||||
seq_lens: TensorType) -> (TensorType, List[TensorType]):
|
||||
"""Adds time dimension to batch before sending inputs to forward_rnn().
|
||||
|
||||
You should implement forward_rnn() in your subclass."""
|
||||
@@ -62,7 +67,8 @@ class RecurrentNetwork(TFModelV2):
|
||||
seq_lens)
|
||||
return tf.reshape(output, [-1, self.num_outputs]), new_state
|
||||
|
||||
def forward_rnn(self, inputs, state, seq_lens):
|
||||
def forward_rnn(self, inputs: TensorType, state: List[TensorType],
|
||||
seq_lens: TensorType) -> (TensorType, List[TensorType]):
|
||||
"""Call the model with the given input tensors and state.
|
||||
|
||||
Args:
|
||||
@@ -83,7 +89,7 @@ class RecurrentNetwork(TFModelV2):
|
||||
"""
|
||||
raise NotImplementedError("You must implement this for a RNN model")
|
||||
|
||||
def get_initial_state(self):
|
||||
def get_initial_state(self) -> List[TensorType]:
|
||||
"""Get the initial recurrent state values for the model.
|
||||
|
||||
Returns:
|
||||
@@ -104,8 +110,9 @@ class LSTMWrapper(RecurrentNetwork):
|
||||
"""An LSTM wrapper serving as an interface for ModelV2s that set use_lstm.
|
||||
"""
|
||||
|
||||
def __init__(self, obs_space, action_space, num_outputs, model_config,
|
||||
name):
|
||||
def __init__(self, obs_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space, num_outputs: int,
|
||||
model_config: ModelConfigDict, name: str):
|
||||
|
||||
super(LSTMWrapper, self).__init__(obs_space, action_space, None,
|
||||
model_config, name)
|
||||
@@ -157,7 +164,9 @@ class LSTMWrapper(RecurrentNetwork):
|
||||
self._rnn_model.summary()
|
||||
|
||||
@override(RecurrentNetwork)
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
def forward(self, input_dict: Dict[str, TensorType],
|
||||
state: List[TensorType],
|
||||
seq_lens: TensorType) -> (TensorType, List[TensorType]):
|
||||
assert seq_lens is not None
|
||||
# Push obs through "unwrapped" net's `forward()` first.
|
||||
wrapped_out, _ = self._wrapped_forward(input_dict, [], None)
|
||||
@@ -181,18 +190,19 @@ class LSTMWrapper(RecurrentNetwork):
|
||||
return super().forward(input_dict, state, seq_lens)
|
||||
|
||||
@override(RecurrentNetwork)
|
||||
def forward_rnn(self, inputs, state, seq_lens):
|
||||
def forward_rnn(self, inputs: TensorType, state: List[TensorType],
|
||||
seq_lens: TensorType) -> (TensorType, List[TensorType]):
|
||||
model_out, self._value_out, h, c = self._rnn_model([inputs, seq_lens] +
|
||||
state)
|
||||
return model_out, [h, c]
|
||||
|
||||
@override(ModelV2)
|
||||
def get_initial_state(self):
|
||||
def get_initial_state(self) -> List[np.ndarray]:
|
||||
return [
|
||||
np.zeros(self.cell_size, np.float32),
|
||||
np.zeros(self.cell_size, np.float32),
|
||||
]
|
||||
|
||||
@override(ModelV2)
|
||||
def value_function(self):
|
||||
def value_function(self) -> TensorType:
|
||||
return tf.reshape(self._value_out, [-1])
|
||||
|
||||
@@ -2,6 +2,7 @@ from math import log
|
||||
import numpy as np
|
||||
import functools
|
||||
import tree
|
||||
import gym
|
||||
|
||||
from ray.rllib.models.action_dist import ActionDistribution
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
@@ -10,7 +11,8 @@ from ray.rllib.utils import MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT, \
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_tfp
|
||||
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
|
||||
from ray.rllib.utils.typing import TensorType, List
|
||||
from ray.rllib.utils.typing import TensorType, List, Union, \
|
||||
Tuple, ModelConfigDict
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
tfp = try_import_tfp()
|
||||
@@ -50,23 +52,26 @@ class Categorical(TFActionDistribution):
|
||||
"""Categorical distribution for discrete action spaces."""
|
||||
|
||||
@DeveloperAPI
|
||||
def __init__(self, inputs, model=None, temperature=1.0):
|
||||
def __init__(self,
|
||||
inputs: List[TensorType],
|
||||
model: ModelV2 = None,
|
||||
temperature: float = 1.0):
|
||||
assert temperature > 0.0, "Categorical `temperature` must be > 0.0!"
|
||||
# Allow softmax formula w/ temperature != 1.0:
|
||||
# Divide inputs by temperature.
|
||||
super().__init__(inputs / temperature, model)
|
||||
|
||||
@override(ActionDistribution)
|
||||
def deterministic_sample(self):
|
||||
def deterministic_sample(self) -> TensorType:
|
||||
return tf.math.argmax(self.inputs, axis=1)
|
||||
|
||||
@override(ActionDistribution)
|
||||
def logp(self, x):
|
||||
def logp(self, x: TensorType) -> TensorType:
|
||||
return -tf.nn.sparse_softmax_cross_entropy_with_logits(
|
||||
logits=self.inputs, labels=tf.cast(x, tf.int32))
|
||||
|
||||
@override(ActionDistribution)
|
||||
def entropy(self):
|
||||
def entropy(self) -> TensorType:
|
||||
a0 = self.inputs - tf.reduce_max(self.inputs, axis=1, keepdims=True)
|
||||
ea0 = tf.exp(a0)
|
||||
z0 = tf.reduce_sum(ea0, axis=1, keepdims=True)
|
||||
@@ -74,7 +79,7 @@ class Categorical(TFActionDistribution):
|
||||
return tf.reduce_sum(p0 * (tf.math.log(z0) - a0), axis=1)
|
||||
|
||||
@override(ActionDistribution)
|
||||
def kl(self, other):
|
||||
def kl(self, other: ActionDistribution) -> TensorType:
|
||||
a0 = self.inputs - tf.reduce_max(self.inputs, axis=1, keepdims=True)
|
||||
a1 = other.inputs - tf.reduce_max(other.inputs, axis=1, keepdims=True)
|
||||
ea0 = tf.exp(a0)
|
||||
@@ -86,7 +91,7 @@ class Categorical(TFActionDistribution):
|
||||
p0 * (a0 - tf.math.log(z0) - a1 + tf.math.log(z1)), axis=1)
|
||||
|
||||
@override(TFActionDistribution)
|
||||
def _build_sample_op(self):
|
||||
def _build_sample_op(self) -> TensorType:
|
||||
return tf.squeeze(tf.random.categorical(self.inputs, 1), axis=1)
|
||||
|
||||
@staticmethod
|
||||
@@ -98,7 +103,8 @@ class Categorical(TFActionDistribution):
|
||||
class MultiCategorical(TFActionDistribution):
|
||||
"""MultiCategorical distribution for MultiDiscrete action spaces."""
|
||||
|
||||
def __init__(self, inputs, model, input_lens):
|
||||
def __init__(self, inputs: List[TensorType], model: ModelV2,
|
||||
input_lens: Union[List[int], np.ndarray, Tuple[int, ...]]):
|
||||
# skip TFActionDistribution init
|
||||
ActionDistribution.__init__(self, inputs, model)
|
||||
self.cats = [
|
||||
@@ -109,12 +115,12 @@ class MultiCategorical(TFActionDistribution):
|
||||
self.sampled_action_logp_op = self.logp(self.sample_op)
|
||||
|
||||
@override(ActionDistribution)
|
||||
def deterministic_sample(self):
|
||||
def deterministic_sample(self) -> TensorType:
|
||||
return tf.stack(
|
||||
[cat.deterministic_sample() for cat in self.cats], axis=1)
|
||||
|
||||
@override(ActionDistribution)
|
||||
def logp(self, actions):
|
||||
def logp(self, actions: TensorType) -> TensorType:
|
||||
# If tensor is provided, unstack it into list.
|
||||
if isinstance(actions, tf.Tensor):
|
||||
actions = tf.unstack(tf.cast(actions, tf.int32), axis=1)
|
||||
@@ -123,30 +129,32 @@ class MultiCategorical(TFActionDistribution):
|
||||
return tf.reduce_sum(logps, axis=0)
|
||||
|
||||
@override(ActionDistribution)
|
||||
def multi_entropy(self):
|
||||
def multi_entropy(self) -> TensorType:
|
||||
return tf.stack([cat.entropy() for cat in self.cats], axis=1)
|
||||
|
||||
@override(ActionDistribution)
|
||||
def entropy(self):
|
||||
def entropy(self) -> TensorType:
|
||||
return tf.reduce_sum(self.multi_entropy(), axis=1)
|
||||
|
||||
@override(ActionDistribution)
|
||||
def multi_kl(self, other):
|
||||
def multi_kl(self, other: ActionDistribution) -> TensorType:
|
||||
return tf.stack(
|
||||
[cat.kl(oth_cat) for cat, oth_cat in zip(self.cats, other.cats)],
|
||||
axis=1)
|
||||
|
||||
@override(ActionDistribution)
|
||||
def kl(self, other):
|
||||
def kl(self, other: ActionDistribution) -> TensorType:
|
||||
return tf.reduce_sum(self.multi_kl(other), axis=1)
|
||||
|
||||
@override(TFActionDistribution)
|
||||
def _build_sample_op(self):
|
||||
def _build_sample_op(self) -> TensorType:
|
||||
return tf.stack([cat.sample() for cat in self.cats], axis=1)
|
||||
|
||||
@staticmethod
|
||||
@override(ActionDistribution)
|
||||
def required_model_output_shape(action_space, model_config):
|
||||
def required_model_output_shape(
|
||||
action_space: gym.Space,
|
||||
model_config: ModelConfigDict) -> Union[int, np.ndarray]:
|
||||
return np.sum(action_space.nvec)
|
||||
|
||||
|
||||
@@ -168,7 +176,10 @@ class GumbelSoftmax(TFActionDistribution):
|
||||
"""
|
||||
|
||||
@DeveloperAPI
|
||||
def __init__(self, inputs, model=None, temperature=1.0):
|
||||
def __init__(self,
|
||||
inputs: List[TensorType],
|
||||
model: ModelV2 = None,
|
||||
temperature: float = 1.0):
|
||||
"""Initializes a GumbelSoftmax distribution.
|
||||
|
||||
Args:
|
||||
@@ -184,12 +195,12 @@ class GumbelSoftmax(TFActionDistribution):
|
||||
super().__init__(inputs, model)
|
||||
|
||||
@override(ActionDistribution)
|
||||
def deterministic_sample(self):
|
||||
def deterministic_sample(self) -> TensorType:
|
||||
# Return the dist object's prob values.
|
||||
return self.probs
|
||||
|
||||
@override(ActionDistribution)
|
||||
def logp(self, x):
|
||||
def logp(self, x: TensorType) -> TensorType:
|
||||
# Override since the implementation of tfp.RelaxedOneHotCategorical
|
||||
# yields positive values.
|
||||
if x.shape != self.dist.logits.shape:
|
||||
@@ -204,12 +215,14 @@ class GumbelSoftmax(TFActionDistribution):
|
||||
-x * tf.nn.log_softmax(self.dist.logits, axis=-1), axis=-1)
|
||||
|
||||
@override(TFActionDistribution)
|
||||
def _build_sample_op(self):
|
||||
def _build_sample_op(self) -> TensorType:
|
||||
return self.dist.sample()
|
||||
|
||||
@staticmethod
|
||||
@override(ActionDistribution)
|
||||
def required_model_output_shape(action_space, model_config):
|
||||
def required_model_output_shape(
|
||||
action_space: gym.Space,
|
||||
model_config: ModelConfigDict) -> Union[int, np.ndarray]:
|
||||
return action_space.n
|
||||
|
||||
|
||||
@@ -220,7 +233,7 @@ class DiagGaussian(TFActionDistribution):
|
||||
second half the gaussian standard deviations.
|
||||
"""
|
||||
|
||||
def __init__(self, inputs, model):
|
||||
def __init__(self, inputs: List[TensorType], model: ModelV2):
|
||||
mean, log_std = tf.split(inputs, 2, axis=1)
|
||||
self.mean = mean
|
||||
self.log_std = log_std
|
||||
@@ -228,11 +241,11 @@ class DiagGaussian(TFActionDistribution):
|
||||
super().__init__(inputs, model)
|
||||
|
||||
@override(ActionDistribution)
|
||||
def deterministic_sample(self):
|
||||
def deterministic_sample(self) -> TensorType:
|
||||
return self.mean
|
||||
|
||||
@override(ActionDistribution)
|
||||
def logp(self, x):
|
||||
def logp(self, x: TensorType) -> TensorType:
|
||||
return -0.5 * tf.reduce_sum(
|
||||
tf.math.square((tf.cast(x, tf.float32) - self.mean) / self.std),
|
||||
axis=1
|
||||
@@ -240,7 +253,7 @@ class DiagGaussian(TFActionDistribution):
|
||||
tf.reduce_sum(self.log_std, axis=1)
|
||||
|
||||
@override(ActionDistribution)
|
||||
def kl(self, other):
|
||||
def kl(self, other: ActionDistribution) -> TensorType:
|
||||
assert isinstance(other, DiagGaussian)
|
||||
return tf.reduce_sum(
|
||||
other.log_std - self.log_std +
|
||||
@@ -249,17 +262,19 @@ class DiagGaussian(TFActionDistribution):
|
||||
axis=1)
|
||||
|
||||
@override(ActionDistribution)
|
||||
def entropy(self):
|
||||
def entropy(self) -> TensorType:
|
||||
return tf.reduce_sum(
|
||||
self.log_std + .5 * np.log(2.0 * np.pi * np.e), axis=1)
|
||||
|
||||
@override(TFActionDistribution)
|
||||
def _build_sample_op(self):
|
||||
def _build_sample_op(self) -> TensorType:
|
||||
return self.mean + self.std * tf.random.normal(tf.shape(self.mean))
|
||||
|
||||
@staticmethod
|
||||
@override(ActionDistribution)
|
||||
def required_model_output_shape(action_space, model_config):
|
||||
def required_model_output_shape(
|
||||
action_space: gym.Space,
|
||||
model_config: ModelConfigDict) -> Union[int, np.ndarray]:
|
||||
return np.prod(action_space.shape) * 2
|
||||
|
||||
|
||||
@@ -270,7 +285,11 @@ class SquashedGaussian(TFActionDistribution):
|
||||
`low`+SMALL_NUMBER or `high`-SMALL_NUMBER respectively.
|
||||
"""
|
||||
|
||||
def __init__(self, inputs, model, low=-1.0, high=1.0):
|
||||
def __init__(self,
|
||||
inputs: List[TensorType],
|
||||
model: ModelV2,
|
||||
low: float = -1.0,
|
||||
high: float = 1.0):
|
||||
"""Parameterizes the distribution via `inputs`.
|
||||
|
||||
Args:
|
||||
@@ -292,16 +311,16 @@ class SquashedGaussian(TFActionDistribution):
|
||||
super().__init__(inputs, model)
|
||||
|
||||
@override(ActionDistribution)
|
||||
def deterministic_sample(self):
|
||||
def deterministic_sample(self) -> TensorType:
|
||||
mean = self.distr.mean()
|
||||
return self._squash(mean)
|
||||
|
||||
@override(TFActionDistribution)
|
||||
def _build_sample_op(self):
|
||||
def _build_sample_op(self) -> TensorType:
|
||||
return self._squash(self.distr.sample())
|
||||
|
||||
@override(ActionDistribution)
|
||||
def logp(self, x):
|
||||
def logp(self, x: TensorType) -> TensorType:
|
||||
# Unsquash values (from [low,high] to ]-inf,inf[)
|
||||
unsquashed_values = self._unsquash(x)
|
||||
# Get log prob of unsquashed values from our Normal.
|
||||
@@ -316,13 +335,13 @@ class SquashedGaussian(TFActionDistribution):
|
||||
axis=-1)
|
||||
return log_prob
|
||||
|
||||
def _squash(self, raw_values):
|
||||
def _squash(self, raw_values: TensorType) -> TensorType:
|
||||
# Returned values are within [low, high] (including `low` and `high`).
|
||||
squashed = ((tf.math.tanh(raw_values) + 1.0) / 2.0) * \
|
||||
(self.high - self.low) + self.low
|
||||
return tf.clip_by_value(squashed, self.low, self.high)
|
||||
|
||||
def _unsquash(self, values):
|
||||
def _unsquash(self, values: TensorType) -> TensorType:
|
||||
normed_values = (values - self.low) / (self.high - self.low) * 2.0 - \
|
||||
1.0
|
||||
# Stabilize input to atanh.
|
||||
@@ -333,7 +352,9 @@ class SquashedGaussian(TFActionDistribution):
|
||||
|
||||
@staticmethod
|
||||
@override(ActionDistribution)
|
||||
def required_model_output_shape(action_space, model_config):
|
||||
def required_model_output_shape(
|
||||
action_space: gym.Space,
|
||||
model_config: ModelConfigDict) -> Union[int, np.ndarray]:
|
||||
return np.prod(action_space.shape) * 2
|
||||
|
||||
|
||||
@@ -347,7 +368,11 @@ class Beta(TFActionDistribution):
|
||||
and Gamma(n) = (n - 1)!
|
||||
"""
|
||||
|
||||
def __init__(self, inputs, model, low=0.0, high=1.0):
|
||||
def __init__(self,
|
||||
inputs: List[TensorType],
|
||||
model: ModelV2,
|
||||
low: float = 0.0,
|
||||
high: float = 1.0):
|
||||
# Stabilize input parameters (possibly coming from a linear layer).
|
||||
inputs = tf.clip_by_value(inputs, log(SMALL_NUMBER),
|
||||
-log(SMALL_NUMBER))
|
||||
@@ -361,29 +386,31 @@ class Beta(TFActionDistribution):
|
||||
super().__init__(inputs, model)
|
||||
|
||||
@override(ActionDistribution)
|
||||
def deterministic_sample(self):
|
||||
def deterministic_sample(self) -> TensorType:
|
||||
mean = self.dist.mean()
|
||||
return self._squash(mean)
|
||||
|
||||
@override(TFActionDistribution)
|
||||
def _build_sample_op(self):
|
||||
def _build_sample_op(self) -> TensorType:
|
||||
return self._squash(self.dist.sample())
|
||||
|
||||
@override(ActionDistribution)
|
||||
def logp(self, x):
|
||||
def logp(self, x: TensorType) -> TensorType:
|
||||
unsquashed_values = self._unsquash(x)
|
||||
return tf.math.reduce_sum(
|
||||
self.dist.log_prob(unsquashed_values), axis=-1)
|
||||
|
||||
def _squash(self, raw_values):
|
||||
def _squash(self, raw_values: TensorType) -> TensorType:
|
||||
return raw_values * (self.high - self.low) + self.low
|
||||
|
||||
def _unsquash(self, values):
|
||||
def _unsquash(self, values: TensorType) -> TensorType:
|
||||
return (values - self.low) / (self.high - self.low)
|
||||
|
||||
@staticmethod
|
||||
@override(ActionDistribution)
|
||||
def required_model_output_shape(action_space, model_config):
|
||||
def required_model_output_shape(
|
||||
action_space: gym.Space,
|
||||
model_config: ModelConfigDict) -> Union[int, np.ndarray]:
|
||||
return np.prod(action_space.shape) * 2
|
||||
|
||||
|
||||
@@ -395,20 +422,22 @@ class Deterministic(TFActionDistribution):
|
||||
"""
|
||||
|
||||
@override(ActionDistribution)
|
||||
def deterministic_sample(self):
|
||||
def deterministic_sample(self) -> TensorType:
|
||||
return self.inputs
|
||||
|
||||
@override(TFActionDistribution)
|
||||
def logp(self, x):
|
||||
def logp(self, x: TensorType) -> TensorType:
|
||||
return tf.zeros_like(self.inputs)
|
||||
|
||||
@override(TFActionDistribution)
|
||||
def _build_sample_op(self):
|
||||
def _build_sample_op(self) -> TensorType:
|
||||
return self.inputs
|
||||
|
||||
@staticmethod
|
||||
@override(ActionDistribution)
|
||||
def required_model_output_shape(action_space, model_config):
|
||||
def required_model_output_shape(
|
||||
action_space: gym.Space,
|
||||
model_config: ModelConfigDict) -> Union[int, np.ndarray]:
|
||||
return np.prod(action_space.shape)
|
||||
|
||||
|
||||
@@ -503,7 +532,7 @@ class Dirichlet(TFActionDistribution):
|
||||
|
||||
e.g. actions that represent resource allocation."""
|
||||
|
||||
def __init__(self, inputs, model):
|
||||
def __init__(self, inputs: List[TensorType], model: ModelV2):
|
||||
"""Input is a tensor of logits. The exponential of logits is used to
|
||||
parametrize the Dirichlet distribution as all parameters need to be
|
||||
positive. An arbitrary small epsilon is added to the concentration
|
||||
@@ -525,7 +554,7 @@ class Dirichlet(TFActionDistribution):
|
||||
return tf.nn.softmax(self.dist.concentration)
|
||||
|
||||
@override(ActionDistribution)
|
||||
def logp(self, x):
|
||||
def logp(self, x: TensorType) -> TensorType:
|
||||
# Support of Dirichlet are positive real numbers. x is already
|
||||
# an array of positive numbers, but we clip to avoid zeros due to
|
||||
# numerical errors.
|
||||
@@ -534,18 +563,20 @@ class Dirichlet(TFActionDistribution):
|
||||
return self.dist.log_prob(x)
|
||||
|
||||
@override(ActionDistribution)
|
||||
def entropy(self):
|
||||
def entropy(self) -> TensorType:
|
||||
return self.dist.entropy()
|
||||
|
||||
@override(ActionDistribution)
|
||||
def kl(self, other):
|
||||
def kl(self, other: ActionDistribution) -> TensorType:
|
||||
return self.dist.kl_divergence(other.dist)
|
||||
|
||||
@override(TFActionDistribution)
|
||||
def _build_sample_op(self):
|
||||
def _build_sample_op(self) -> TensorType:
|
||||
return self.dist.sample()
|
||||
|
||||
@staticmethod
|
||||
@override(ActionDistribution)
|
||||
def required_model_output_shape(action_space, model_config):
|
||||
def required_model_output_shape(
|
||||
action_space: gym.Space,
|
||||
model_config: ModelConfigDict) -> Union[int, np.ndarray]:
|
||||
return np.prod(action_space.shape)
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
from typing import Dict, List
|
||||
import gym
|
||||
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.models.tf.misc import normc_initializer
|
||||
from ray.rllib.models.utils import get_filter_config
|
||||
from ray.rllib.utils.framework import get_activation_fn, try_import_tf
|
||||
from ray.rllib.utils.typing import ModelConfigDict, TensorType
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
@@ -9,8 +13,9 @@ tf1, tf, tfv = try_import_tf()
|
||||
class VisionNetwork(TFModelV2):
|
||||
"""Generic vision network implemented in ModelV2 API."""
|
||||
|
||||
def __init__(self, obs_space, action_space, num_outputs, model_config,
|
||||
name):
|
||||
def __init__(self, obs_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space, num_outputs: int,
|
||||
model_config: ModelConfigDict, name: str):
|
||||
if not model_config.get("conv_filters"):
|
||||
model_config["conv_filters"] = get_filter_config(obs_space.shape)
|
||||
|
||||
@@ -137,7 +142,9 @@ class VisionNetwork(TFModelV2):
|
||||
self.base_model = tf.keras.Model(inputs, [conv_out, value_out])
|
||||
self.register_variables(self.base_model.variables)
|
||||
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
def forward(self, input_dict: Dict[str, TensorType],
|
||||
state: List[TensorType],
|
||||
seq_lens: TensorType) -> (TensorType, List[TensorType]):
|
||||
# Explicit cast to float32 needed in eager.
|
||||
model_out, self._value_out = self.base_model(
|
||||
tf.cast(input_dict["obs"], tf.float32))
|
||||
@@ -148,5 +155,5 @@ class VisionNetwork(TFModelV2):
|
||||
else:
|
||||
return tf.squeeze(model_out, axis=[1, 2]), state
|
||||
|
||||
def value_function(self):
|
||||
def value_function(self) -> TensorType:
|
||||
return tf.reshape(self._value_out, [-1])
|
||||
|
||||
Reference in New Issue
Block a user