[RLlib] Model Annotations: Tensorflow (#11964)

This commit is contained in:
Michael Luo
2020-11-12 03:18:50 -08:00
committed by GitHub
parent b2984d1c34
commit 59ccbc0fc7
11 changed files with 214 additions and 125 deletions
+34 -23
View File
@@ -9,6 +9,8 @@
https://www.aclweb.org/anthology/P19-1285.pdf
"""
import numpy as np
import gym
from typing import Optional, Any
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.layers import GRUGate, RelativeMultiHeadAttention, \
@@ -16,6 +18,7 @@ from ray.rllib.models.tf.layers import GRUGate, RelativeMultiHeadAttention, \
from ray.rllib.models.tf.recurrent_net import RecurrentNetwork
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import ModelConfigDict, TensorType, List
tf1, tf, tfv = try_import_tf()
@@ -28,7 +31,11 @@ class PositionwiseFeedforward(tf.keras.layers.Layer):
layer separately.
"""
def __init__(self, out_dim, hidden_dim, output_activation=None, **kwargs):
def __init__(self,
out_dim: int,
hidden_dim: int,
output_activation: Optional[Any] = None,
**kwargs):
super().__init__(**kwargs)
self._hidden_layer = tf.keras.layers.Dense(
@@ -39,7 +46,7 @@ class PositionwiseFeedforward(tf.keras.layers.Layer):
self._output_layer = tf.keras.layers.Dense(
out_dim, activation=output_activation)
def call(self, inputs, **kwargs):
def call(self, inputs: TensorType, **kwargs) -> TensorType:
    """Pass `inputs` through the hidden layer, then the output layer.

    Extra keyword arguments are accepted but discarded.
    """
    del kwargs  # Unused.
    return self._output_layer(self._hidden_layer(inputs))
@@ -48,9 +55,11 @@ class PositionwiseFeedforward(tf.keras.layers.Layer):
class TrXLNet(RecurrentNetwork):
"""A TrXL net Model described in [1]."""
def __init__(self, observation_space, action_space, num_outputs,
model_config, name, num_transformer_units, attn_dim,
num_heads, head_dim, ff_hidden_dim):
def __init__(self, observation_space: gym.spaces.Space,
action_space: gym.spaces.Space, num_outputs: int,
model_config: ModelConfigDict, name: str,
num_transformer_units: int, attn_dim: int, num_heads: int,
head_dim: int, ff_hidden_dim: int):
"""Initializes a TrXLNet object.
Args:
@@ -109,7 +118,8 @@ class TrXLNet(RecurrentNetwork):
self.register_variables(self.base_model.variables)
@override(RecurrentNetwork)
def forward_rnn(self, inputs, state, seq_lens):
def forward_rnn(self, inputs: TensorType, state: List[TensorType],
seq_lens: TensorType) -> (TensorType, List[TensorType]):
# To make Attention work with current RLlib's ModelV2 API:
# We assume `state` is the history of L recent observations (all
# concatenated into one tensor) and append the current inputs to the
@@ -126,7 +136,7 @@ class TrXLNet(RecurrentNetwork):
return logits, [observations]
@override(RecurrentNetwork)
def get_initial_state(self):
def get_initial_state(self) -> List[np.ndarray]:
# State is the T last observations concat'd together into one Tensor.
# Plus all Transformer blocks' E(l) outputs concat'd together (up to
# tau timesteps).
@@ -156,18 +166,18 @@ class GTrXLNet(RecurrentNetwork):
"""
def __init__(self,
observation_space,
action_space,
num_outputs,
model_config,
name,
num_transformer_units,
attn_dim,
num_heads,
memory_tau,
head_dim,
ff_hidden_dim,
init_gate_bias=2.0):
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
num_outputs: int,
model_config: ModelConfigDict,
name: str,
num_transformer_units: int,
attn_dim: int,
num_heads: int,
memory_tau: int,
head_dim: int,
ff_hidden_dim: int,
init_gate_bias: float = 2.0):
"""Initializes a GTrXLNet.
Args:
@@ -271,7 +281,8 @@ class GTrXLNet(RecurrentNetwork):
self.trxl_model.summary()
@override(RecurrentNetwork)
def forward_rnn(self, inputs, state, seq_lens):
def forward_rnn(self, inputs: TensorType, state: List[TensorType],
seq_lens: TensorType) -> (TensorType, List[TensorType]):
# To make Attention work with current RLlib's ModelV2 API:
# We assume `state` is the history of L recent observations (all
# concatenated into one tensor) and append the current inputs to the
@@ -303,7 +314,7 @@ class GTrXLNet(RecurrentNetwork):
return logits, [observations] + memory_outs
@override(RecurrentNetwork)
def get_initial_state(self):
def get_initial_state(self) -> List[np.ndarray]:
# State is the T last observations concat'd together into one Tensor.
# Plus all Transformer blocks' E(l) outputs concat'd together (up to
# tau timesteps).
@@ -312,11 +323,11 @@ class GTrXLNet(RecurrentNetwork):
for _ in range(self.num_transformer_units)]
@override(ModelV2)
def value_function(self):
def value_function(self) -> TensorType:
    """Return `self._value_out` reshaped into a 1-D tensor."""
    value_out = self._value_out
    return tf.reshape(value_out, [-1])
def relative_position_embedding(seq_length, out_dim):
def relative_position_embedding(seq_length: int, out_dim: int) -> TensorType:
"""Creates a [seq_length x seq_length] matrix for rel. pos encoding.
Denoted as Phi in [2] and [3]. Phi is the standard sinusoid encoding
+9 -4
View File
@@ -1,8 +1,10 @@
import numpy as np
import gym
from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils.framework import get_activation_fn, try_import_tf
from ray.rllib.utils.typing import Dict, TensorType, List, ModelConfigDict
tf1, tf, tfv = try_import_tf()
@@ -10,8 +12,9 @@ tf1, tf, tfv = try_import_tf()
class FullyConnectedNetwork(TFModelV2):
"""Generic fully connected network implemented in ModelV2 API."""
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
def __init__(self, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space, num_outputs: int,
model_config: ModelConfigDict, name: str):
super(FullyConnectedNetwork, self).__init__(
obs_space, action_space, num_outputs, model_config, name)
@@ -113,9 +116,11 @@ class FullyConnectedNetwork(TFModelV2):
if logits_out is not None else last_layer), value_out])
self.register_variables(self.base_model.variables)
def forward(self, input_dict, state, seq_lens):
def forward(self, input_dict: Dict[str, TensorType],
            state: List[TensorType],
            seq_lens: TensorType) -> (TensorType, List[TensorType]):
    """Run the base model on `input_dict["obs_flat"]`.

    The second base-model output is stored in `self._value_out`.
    `state` is returned unchanged and `seq_lens` is not used.
    """
    # NOTE(review): the return annotation is a tuple literal rather than
    # `typing.Tuple[...]`; kept as-is since `Tuple` is not imported here.
    flat_obs = input_dict["obs_flat"]
    outputs = self.base_model(flat_obs)
    model_out, self._value_out = outputs
    return model_out, state
def value_function(self):
def value_function(self) -> TensorType:
    """Return `self._value_out` (set by `forward`) as a 1-D tensor."""
    value_out = self._value_out
    return tf.reshape(value_out, [-1])
+4 -3
View File
@@ -1,14 +1,15 @@
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import TensorType, TensorShape
tf1, tf, tfv = try_import_tf()
class GRUGate(tf.keras.layers.Layer if tf else object):
def __init__(self, init_bias=0., **kwargs):
def __init__(self, init_bias: float = 0., **kwargs):
    """Create the gate layer.

    Args:
        init_bias: Stored as `self._init_bias` for later use.
        **kwargs: Forwarded to the parent class constructor.
    """
    super(GRUGate, self).__init__(**kwargs)
    self._init_bias = init_bias
def build(self, input_shape):
def build(self, input_shape: TensorShape):
h_shape, x_shape = input_shape
if x_shape[-1] != h_shape[-1]:
raise ValueError(
@@ -29,7 +30,7 @@ class GRUGate(tf.keras.layers.Layer if tf else object):
self._bias_z = self.add_weight(
shape=(dim, ), initializer=bias_initializer)
def call(self, inputs, **kwargs):
def call(self, inputs: TensorType, **kwargs) -> TensorType:
# Pass in internal state first.
h, X = inputs
@@ -4,6 +4,7 @@
https://arxiv.org/pdf/1706.03762.pdf
"""
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import TensorType
tf1, tf, tfv = try_import_tf()
@@ -11,7 +12,7 @@ tf1, tf, tfv = try_import_tf()
class MultiHeadAttention(tf.keras.layers.Layer if tf else object):
"""A multi-head attention layer described in [1]."""
def __init__(self, out_dim, num_heads, head_dim, **kwargs):
def __init__(self, out_dim: int, num_heads: int, head_dim: int, **kwargs):
super().__init__(**kwargs)
# No bias or non-linearity.
@@ -22,7 +23,7 @@ class MultiHeadAttention(tf.keras.layers.Layer if tf else object):
self._linear_layer = tf.keras.layers.TimeDistributed(
tf.keras.layers.Dense(out_dim, use_bias=False))
def call(self, inputs):
def call(self, inputs: TensorType) -> TensorType:
L = tf.shape(inputs)[1] # length of segment
H = self._num_heads # number of attention heads
D = self._head_dim # attention head dimension
+12 -7
View File
@@ -2,6 +2,7 @@ import numpy as np
from ray.rllib.utils.framework import get_activation_fn, get_variable, \
try_import_tf
from ray.rllib.utils.framework import TensorType, TensorShape
tf1, tf, tfv = try_import_tf()
@@ -18,14 +19,18 @@ class NoisyLayer(tf.keras.layers.Layer if tf else object):
vanish along the training procedure.
"""
def __init__(self, prefix, out_size, sigma0, activation="relu"):
def __init__(self,
prefix: str,
out_size: int,
sigma0: float,
activation: str = "relu"):
"""Initializes a NoisyLayer object.
Args:
prefix:
out_size:
sigma0:
non_linear:
out_size: Output size for Noisy Layer
sigma0: Initialization value for sigma_b (bias noise)
activation: Non-linear activation for Noisy Layer
"""
super().__init__()
self.prefix = prefix
@@ -41,7 +46,7 @@ class NoisyLayer(tf.keras.layers.Layer if tf else object):
self.sigma_w = None # Noise for weight matrix
self.sigma_b = None # Noise for biases.
def build(self, input_shape):
def build(self, input_shape: TensorShape):
in_size = int(input_shape[1])
self.sigma_w = get_variable(
@@ -78,7 +83,7 @@ class NoisyLayer(tf.keras.layers.Layer if tf else object):
dtype=tf.float32,
)
def call(self, inputs):
def call(self, inputs: TensorType) -> TensorType:
in_size = int(inputs.shape[1])
epsilon_in = tf.random.normal(shape=[in_size])
epsilon_out = tf.random.normal(shape=[self.out_size])
@@ -98,5 +103,5 @@ class NoisyLayer(tf.keras.layers.Layer if tf else object):
action_activation = fn(action_activation)
return action_activation
def _f_epsilon(self, x):
def _f_epsilon(self, x: TensorType) -> TensorType:
    """Return sign(x) * sqrt(|x|), built from `tf.math` ops."""
    magnitude = tf.math.sqrt(tf.math.abs(x))
    return tf.math.sign(x) * magnitude
@@ -1,4 +1,7 @@
from typing import Optional, Any
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import TensorType
tf1, tf, tfv = try_import_tf()
@@ -10,12 +13,12 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object):
"""
def __init__(self,
out_dim,
num_heads,
head_dim,
rel_pos_encoder,
input_layernorm=False,
output_activation=None,
out_dim: int,
num_heads: int,
head_dim: int,
rel_pos_encoder: Any,
input_layernorm: bool = False,
output_activation: Optional[Any] = None,
**kwargs):
"""Initializes a RelativeMultiHeadAttention keras Layer object.
@@ -55,7 +58,8 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object):
if input_layernorm:
self._input_layernorm = tf.keras.layers.LayerNormalization(axis=-1)
def call(self, inputs, memory=None):
def call(self, inputs: TensorType,
memory: Optional[TensorType] = None) -> TensorType:
T = tf.shape(inputs)[1] # length of segment (time)
H = self._num_heads # number of attention heads
d = self._head_dim # attention head dimension
@@ -105,7 +109,7 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object):
return self._linear_layer(out)
@staticmethod
def rel_shift(x):
def rel_shift(x: TensorType) -> TensorType:
# Transposed version of the shift approach described in [3].
# https://github.com/kimiyoung/transformer-xl/blob/
# 44781ed21dbaec88b280f74d9ae2877f52b492a5/tf/model.py#L31
+9 -2
View File
@@ -1,4 +1,7 @@
from typing import Optional, Any
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import TensorType
tf1, tf, tfv = try_import_tf()
@@ -10,7 +13,11 @@ class SkipConnection(tf.keras.layers.Layer if tf else object):
input as hidden state input to a given fan_in_layer.
"""
def __init__(self, layer, fan_in_layer=None, add_memory=False, **kwargs):
def __init__(self,
layer: Any,
fan_in_layer: Optional[Any] = None,
add_memory: bool = False,
**kwargs):
"""Initializes a SkipConnection keras layer object.
Args:
@@ -23,7 +30,7 @@ class SkipConnection(tf.keras.layers.Layer if tf else object):
self._layer = layer
self._fan_in_layer = fan_in_layer
def call(self, inputs, **kwargs):
def call(self, inputs: TensorType, **kwargs) -> TensorType:
# del kwargs
outputs = self._layer(inputs, **kwargs)
# Residual case, just add inputs to outputs.
+18 -11
View File
@@ -1,10 +1,13 @@
import numpy as np
from typing import Tuple, Any, Optional
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import TensorType
tf1, tf, tfv = try_import_tf()
def normc_initializer(std=1.0):
def normc_initializer(std: float = 1.0) -> Any:
def _initializer(shape, dtype=None, partition_info=None):
out = np.random.randn(*shape).astype(np.float32)
out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True))
@@ -13,14 +16,14 @@ def normc_initializer(std=1.0):
return _initializer
def conv2d(x,
num_filters,
name,
filter_size=(3, 3),
stride=(1, 1),
pad="SAME",
dtype=None,
collections=None):
def conv2d(x: TensorType,
num_filters: int,
name: str,
filter_size: Tuple[int, int] = (3, 3),
stride: Tuple[int, int] = (1, 1),
pad: str = "SAME",
dtype: Optional[Any] = None,
collections: Optional[Any] = None) -> TensorType:
if dtype is None:
dtype = tf.float32
@@ -53,7 +56,11 @@ def conv2d(x,
return tf1.nn.conv2d(x, w, stride_shape, pad) + b
def linear(x, size, name, initializer=None, bias_init=0):
def linear(x: TensorType,
size: int,
name: str,
initializer: Optional[Any] = None,
bias_init: float = 0.0) -> TensorType:
w = tf1.get_variable(
name + "/w", [x.get_shape()[1], size], initializer=initializer)
b = tf1.get_variable(
@@ -61,5 +68,5 @@ def linear(x, size, name, initializer=None, bias_init=0):
return tf.matmul(x, w) + b
def flatten(x):
def flatten(x: TensorType) -> TensorType:
    """Reshape `x` to 2-D: axis 0 is kept, all other axes are merged."""
    trailing_dims = x.get_shape().as_list()[1:]
    return tf.reshape(x, [-1, np.prod(trailing_dims)])
+19 -9
View File
@@ -1,4 +1,6 @@
import numpy as np
import gym
from typing import Dict, List, Tuple
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
@@ -6,6 +8,7 @@ from ray.rllib.policy.rnn_sequencing import add_time_dimension
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import ModelConfigDict, TensorType
tf1, tf, tfv = try_import_tf()
@@ -49,7 +52,9 @@ class RecurrentNetwork(TFModelV2):
"""
@override(ModelV2)
def forward(self, input_dict, state, seq_lens):
def forward(self, input_dict: Dict[str, TensorType],
state: List[TensorType],
seq_lens: TensorType) -> (TensorType, List[TensorType]):
"""Adds time dimension to batch before sending inputs to forward_rnn().
You should implement forward_rnn() in your subclass."""
@@ -62,7 +67,8 @@ class RecurrentNetwork(TFModelV2):
seq_lens)
return tf.reshape(output, [-1, self.num_outputs]), new_state
def forward_rnn(self, inputs, state, seq_lens):
def forward_rnn(self, inputs: TensorType, state: List[TensorType],
seq_lens: TensorType) -> (TensorType, List[TensorType]):
"""Call the model with the given input tensors and state.
Args:
@@ -83,7 +89,7 @@ class RecurrentNetwork(TFModelV2):
"""
raise NotImplementedError("You must implement this for a RNN model")
def get_initial_state(self):
def get_initial_state(self) -> List[TensorType]:
"""Get the initial recurrent state values for the model.
Returns:
@@ -104,8 +110,9 @@ class LSTMWrapper(RecurrentNetwork):
"""An LSTM wrapper serving as an interface for ModelV2s that set use_lstm.
"""
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
def __init__(self, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space, num_outputs: int,
model_config: ModelConfigDict, name: str):
super(LSTMWrapper, self).__init__(obs_space, action_space, None,
model_config, name)
@@ -157,7 +164,9 @@ class LSTMWrapper(RecurrentNetwork):
self._rnn_model.summary()
@override(RecurrentNetwork)
def forward(self, input_dict, state, seq_lens):
def forward(self, input_dict: Dict[str, TensorType],
state: List[TensorType],
seq_lens: TensorType) -> (TensorType, List[TensorType]):
assert seq_lens is not None
# Push obs through "unwrapped" net's `forward()` first.
wrapped_out, _ = self._wrapped_forward(input_dict, [], None)
@@ -181,18 +190,19 @@ class LSTMWrapper(RecurrentNetwork):
return super().forward(input_dict, state, seq_lens)
@override(RecurrentNetwork)
def forward_rnn(self, inputs, state, seq_lens):
def forward_rnn(self, inputs: TensorType, state: List[TensorType],
                seq_lens: TensorType
                ) -> Tuple[TensorType, List[TensorType]]:
    """Run the wrapped RNN model on the given inputs and state.

    Args:
        inputs: Input tensor handed to `self._rnn_model`.
        state: Incoming state tensors, appended after `inputs` and
            `seq_lens` in the model's input list.
        seq_lens: Sequence lengths handed to `self._rnn_model`.

    Returns:
        The model output and the new state list `[h, c]`.
    """
    # The model yields four values; the second one is kept on the instance
    # for `value_function()`.
    model_out, self._value_out, h, c = self._rnn_model(
        [inputs, seq_lens] + state)
    return model_out, [h, c]
@override(ModelV2)
def get_initial_state(self):
def get_initial_state(self) -> List[np.ndarray]:
    """Return two separate zero-filled float32 arrays of shape
    `self.cell_size`."""
    return [np.zeros(self.cell_size, np.float32) for _ in range(2)]
@override(ModelV2)
def value_function(self):
def value_function(self) -> TensorType:
    """Return `self._value_out` (set by `forward_rnn`) as a 1-D tensor."""
    value_out = self._value_out
    return tf.reshape(value_out, [-1])
+83 -52
View File
@@ -2,6 +2,7 @@ from math import log
import numpy as np
import functools
import tree
import gym
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
@@ -10,7 +11,8 @@ from ray.rllib.utils import MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT, \
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.framework import try_import_tf, try_import_tfp
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.typing import TensorType, List
from ray.rllib.utils.typing import TensorType, List, Union, \
Tuple, ModelConfigDict
tf1, tf, tfv = try_import_tf()
tfp = try_import_tfp()
@@ -50,23 +52,26 @@ class Categorical(TFActionDistribution):
"""Categorical distribution for discrete action spaces."""
@DeveloperAPI
def __init__(self, inputs, model=None, temperature=1.0):
def __init__(self,
             inputs: List[TensorType],
             model: ModelV2 = None,
             temperature: float = 1.0):
    """Set up the distribution from temperature-scaled inputs.

    Args:
        inputs: Logits; divided by `temperature` before being handed to
            the parent constructor.
        model: Passed to the parent constructor unchanged.
        temperature: Divisor applied to `inputs`; must be > 0.0.
    """
    assert temperature > 0.0, "Categorical `temperature` must be > 0.0!"
    # Softmax with temperature != 1.0 is obtained by scaling the inputs.
    scaled_inputs = inputs / temperature
    super().__init__(scaled_inputs, model)
@override(ActionDistribution)
def deterministic_sample(self):
def deterministic_sample(self) -> TensorType:
    """Return the argmax of `self.inputs` along axis 1."""
    logits = self.inputs
    return tf.math.argmax(logits, axis=1)
@override(ActionDistribution)
def logp(self, x):
def logp(self, x: TensorType) -> TensorType:
    """Return the negated sparse softmax cross-entropy between
    `self.inputs` (as logits) and `x` cast to int32 (as labels)."""
    labels = tf.cast(x, tf.int32)
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=self.inputs, labels=labels)
    return -cross_entropy
@override(ActionDistribution)
def entropy(self):
def entropy(self) -> TensorType:
a0 = self.inputs - tf.reduce_max(self.inputs, axis=1, keepdims=True)
ea0 = tf.exp(a0)
z0 = tf.reduce_sum(ea0, axis=1, keepdims=True)
@@ -74,7 +79,7 @@ class Categorical(TFActionDistribution):
return tf.reduce_sum(p0 * (tf.math.log(z0) - a0), axis=1)
@override(ActionDistribution)
def kl(self, other):
def kl(self, other: ActionDistribution) -> TensorType:
a0 = self.inputs - tf.reduce_max(self.inputs, axis=1, keepdims=True)
a1 = other.inputs - tf.reduce_max(other.inputs, axis=1, keepdims=True)
ea0 = tf.exp(a0)
@@ -86,7 +91,7 @@ class Categorical(TFActionDistribution):
p0 * (a0 - tf.math.log(z0) - a1 + tf.math.log(z1)), axis=1)
@override(TFActionDistribution)
def _build_sample_op(self):
def _build_sample_op(self) -> TensorType:
    """Draw one categorical sample per row of `self.inputs` and squeeze
    away axis 1."""
    draws = tf.random.categorical(self.inputs, 1)
    return tf.squeeze(draws, axis=1)
@staticmethod
@@ -98,7 +103,8 @@ class Categorical(TFActionDistribution):
class MultiCategorical(TFActionDistribution):
"""MultiCategorical distribution for MultiDiscrete action spaces."""
def __init__(self, inputs, model, input_lens):
def __init__(self, inputs: List[TensorType], model: ModelV2,
input_lens: Union[List[int], np.ndarray, Tuple[int, ...]]):
# skip TFActionDistribution init
ActionDistribution.__init__(self, inputs, model)
self.cats = [
@@ -109,12 +115,12 @@ class MultiCategorical(TFActionDistribution):
self.sampled_action_logp_op = self.logp(self.sample_op)
@override(ActionDistribution)
def deterministic_sample(self):
def deterministic_sample(self) -> TensorType:
    """Stack the deterministic sample of every sub-distribution in
    `self.cats` along axis 1."""
    per_cat = []
    for cat in self.cats:
        per_cat.append(cat.deterministic_sample())
    return tf.stack(per_cat, axis=1)
@override(ActionDistribution)
def logp(self, actions):
def logp(self, actions: TensorType) -> TensorType:
# If tensor is provided, unstack it into list.
if isinstance(actions, tf.Tensor):
actions = tf.unstack(tf.cast(actions, tf.int32), axis=1)
@@ -123,30 +129,32 @@ class MultiCategorical(TFActionDistribution):
return tf.reduce_sum(logps, axis=0)
@override(ActionDistribution)
def multi_entropy(self):
def multi_entropy(self) -> TensorType:
    """Stack the entropy of every sub-distribution in `self.cats` along
    axis 1."""
    entropies = []
    for cat in self.cats:
        entropies.append(cat.entropy())
    return tf.stack(entropies, axis=1)
@override(ActionDistribution)
def entropy(self):
def entropy(self) -> TensorType:
    """Sum the result of `multi_entropy()` over axis 1."""
    per_cat_entropy = self.multi_entropy()
    return tf.reduce_sum(per_cat_entropy, axis=1)
@override(ActionDistribution)
def multi_kl(self, other):
def multi_kl(self, other: ActionDistribution) -> TensorType:
    """Stack the KL of each sub-distribution against the matching one in
    `other.cats` along axis 1."""
    pairwise = []
    for cat, oth_cat in zip(self.cats, other.cats):
        pairwise.append(cat.kl(oth_cat))
    return tf.stack(pairwise, axis=1)
@override(ActionDistribution)
def kl(self, other):
def kl(self, other: ActionDistribution) -> TensorType:
    """Sum the result of `multi_kl(other)` over axis 1."""
    per_cat_kl = self.multi_kl(other)
    return tf.reduce_sum(per_cat_kl, axis=1)
@override(TFActionDistribution)
def _build_sample_op(self):
def _build_sample_op(self) -> TensorType:
    """Stack one `sample()` from every sub-distribution in `self.cats`
    along axis 1."""
    samples = []
    for cat in self.cats:
        samples.append(cat.sample())
    return tf.stack(samples, axis=1)
@staticmethod
@override(ActionDistribution)
def required_model_output_shape(action_space, model_config):
def required_model_output_shape(
        action_space: gym.Space,
        model_config: ModelConfigDict) -> Union[int, np.ndarray]:
    """Return the sum of `action_space.nvec`; `model_config` is unused."""
    total_logits = np.sum(action_space.nvec)
    return total_logits
@@ -168,7 +176,10 @@ class GumbelSoftmax(TFActionDistribution):
"""
@DeveloperAPI
def __init__(self, inputs, model=None, temperature=1.0):
def __init__(self,
inputs: List[TensorType],
model: ModelV2 = None,
temperature: float = 1.0):
"""Initializes a GumbelSoftmax distribution.
Args:
@@ -184,12 +195,12 @@ class GumbelSoftmax(TFActionDistribution):
super().__init__(inputs, model)
@override(ActionDistribution)
def deterministic_sample(self):
def deterministic_sample(self) -> TensorType:
    """Return `self.probs` (the dist object's prob values)."""
    probs = self.probs
    return probs
@override(ActionDistribution)
def logp(self, x):
def logp(self, x: TensorType) -> TensorType:
# Override since the implementation of tfp.RelaxedOneHotCategorical
# yields positive values.
if x.shape != self.dist.logits.shape:
@@ -204,12 +215,14 @@ class GumbelSoftmax(TFActionDistribution):
-x * tf.nn.log_softmax(self.dist.logits, axis=-1), axis=-1)
@override(TFActionDistribution)
def _build_sample_op(self):
def _build_sample_op(self) -> TensorType:
    """Return one sample drawn from `self.dist`."""
    sample = self.dist.sample()
    return sample
@staticmethod
@override(ActionDistribution)
def required_model_output_shape(action_space, model_config):
def required_model_output_shape(
        action_space: gym.Space,
        model_config: ModelConfigDict) -> Union[int, np.ndarray]:
    """Return `action_space.n`; `model_config` is unused."""
    num_categories = action_space.n
    return num_categories
@@ -220,7 +233,7 @@ class DiagGaussian(TFActionDistribution):
second half the gaussian standard deviations.
"""
def __init__(self, inputs, model):
def __init__(self, inputs: List[TensorType], model: ModelV2):
mean, log_std = tf.split(inputs, 2, axis=1)
self.mean = mean
self.log_std = log_std
@@ -228,11 +241,11 @@ class DiagGaussian(TFActionDistribution):
super().__init__(inputs, model)
@override(ActionDistribution)
def deterministic_sample(self):
def deterministic_sample(self) -> TensorType:
    """Return the stored mean (`self.mean`)."""
    mean = self.mean
    return mean
@override(ActionDistribution)
def logp(self, x):
def logp(self, x: TensorType) -> TensorType:
return -0.5 * tf.reduce_sum(
tf.math.square((tf.cast(x, tf.float32) - self.mean) / self.std),
axis=1
@@ -240,7 +253,7 @@ class DiagGaussian(TFActionDistribution):
tf.reduce_sum(self.log_std, axis=1)
@override(ActionDistribution)
def kl(self, other):
def kl(self, other: ActionDistribution) -> TensorType:
assert isinstance(other, DiagGaussian)
return tf.reduce_sum(
other.log_std - self.log_std +
@@ -249,17 +262,19 @@ class DiagGaussian(TFActionDistribution):
axis=1)
@override(ActionDistribution)
def entropy(self):
def entropy(self) -> TensorType:
    """Sum `log_std + 0.5 * log(2 * pi * e)` over axis 1."""
    gaussian_const = .5 * np.log(2.0 * np.pi * np.e)
    return tf.reduce_sum(self.log_std + gaussian_const, axis=1)
@override(TFActionDistribution)
def _build_sample_op(self):
def _build_sample_op(self) -> TensorType:
    """Return `mean + std * noise`, with `noise` drawn from
    `tf.random.normal` in the shape of `self.mean`."""
    noise = tf.random.normal(tf.shape(self.mean))
    return self.mean + self.std * noise
@staticmethod
@override(ActionDistribution)
def required_model_output_shape(action_space, model_config):
def required_model_output_shape(
        action_space: gym.Space,
        model_config: ModelConfigDict) -> Union[int, np.ndarray]:
    """Return twice the product of `action_space.shape`;
    `model_config` is unused."""
    num_elements = np.prod(action_space.shape)
    return num_elements * 2
@@ -270,7 +285,11 @@ class SquashedGaussian(TFActionDistribution):
`low`+SMALL_NUMBER or `high`-SMALL_NUMBER respectively.
"""
def __init__(self, inputs, model, low=-1.0, high=1.0):
def __init__(self,
inputs: List[TensorType],
model: ModelV2,
low: float = -1.0,
high: float = 1.0):
"""Parameterizes the distribution via `inputs`.
Args:
@@ -292,16 +311,16 @@ class SquashedGaussian(TFActionDistribution):
super().__init__(inputs, model)
@override(ActionDistribution)
def deterministic_sample(self):
def deterministic_sample(self) -> TensorType:
    """Return the mean of `self.distr` passed through `_squash`."""
    return self._squash(self.distr.mean())
@override(TFActionDistribution)
def _build_sample_op(self):
def _build_sample_op(self) -> TensorType:
    """Return a sample of `self.distr` passed through `_squash`."""
    raw_sample = self.distr.sample()
    return self._squash(raw_sample)
@override(ActionDistribution)
def logp(self, x):
def logp(self, x: TensorType) -> TensorType:
# Unsquash values (from [low,high] to ]-inf,inf[)
unsquashed_values = self._unsquash(x)
# Get log prob of unsquashed values from our Normal.
@@ -316,13 +335,13 @@ class SquashedGaussian(TFActionDistribution):
axis=-1)
return log_prob
def _squash(self, raw_values):
def _squash(self, raw_values: TensorType) -> TensorType:
    """Map `raw_values` through tanh into `[self.low, self.high]`.

    The result is clipped, so it lies within [low, high] inclusive.
    """
    value_range = self.high - self.low
    # tanh output in [-1, 1] is shifted and halved before scaling.
    unit_interval = (tf.math.tanh(raw_values) + 1.0) / 2.0
    squashed = unit_interval * value_range + self.low
    return tf.clip_by_value(squashed, self.low, self.high)
def _unsquash(self, values):
def _unsquash(self, values: TensorType) -> TensorType:
normed_values = (values - self.low) / (self.high - self.low) * 2.0 - \
1.0
# Stabilize input to atanh.
@@ -333,7 +352,9 @@ class SquashedGaussian(TFActionDistribution):
@staticmethod
@override(ActionDistribution)
def required_model_output_shape(action_space, model_config):
def required_model_output_shape(
        action_space: gym.Space,
        model_config: ModelConfigDict) -> Union[int, np.ndarray]:
    """Return twice the product of `action_space.shape`;
    `model_config` is unused."""
    num_elements = np.prod(action_space.shape)
    return num_elements * 2
@@ -347,7 +368,11 @@ class Beta(TFActionDistribution):
and Gamma(n) = (n - 1)!
"""
def __init__(self, inputs, model, low=0.0, high=1.0):
def __init__(self,
inputs: List[TensorType],
model: ModelV2,
low: float = 0.0,
high: float = 1.0):
# Stabilize input parameters (possibly coming from a linear layer).
inputs = tf.clip_by_value(inputs, log(SMALL_NUMBER),
-log(SMALL_NUMBER))
@@ -361,29 +386,31 @@ class Beta(TFActionDistribution):
super().__init__(inputs, model)
@override(ActionDistribution)
def deterministic_sample(self):
def deterministic_sample(self) -> TensorType:
    """Return the mean of `self.dist` passed through `_squash`."""
    return self._squash(self.dist.mean())
@override(TFActionDistribution)
def _build_sample_op(self):
def _build_sample_op(self) -> TensorType:
    """Return a sample of `self.dist` passed through `_squash`."""
    raw_sample = self.dist.sample()
    return self._squash(raw_sample)
@override(ActionDistribution)
def logp(self, x):
def logp(self, x: TensorType) -> TensorType:
    """Un-squash `x`, then sum `self.dist.log_prob` of the result over
    the last axis."""
    per_dim_logp = self.dist.log_prob(self._unsquash(x))
    return tf.math.reduce_sum(per_dim_logp, axis=-1)
def _squash(self, raw_values):
def _squash(self, raw_values: TensorType) -> TensorType:
    """Scale `raw_values` by `(high - low)` and shift by `low`."""
    value_range = self.high - self.low
    return raw_values * value_range + self.low
def _unsquash(self, values):
def _unsquash(self, values: TensorType) -> TensorType:
    """Shift `values` by `-low` and divide by `(high - low)`."""
    offset = values - self.low
    value_range = self.high - self.low
    return offset / value_range
@staticmethod
@override(ActionDistribution)
def required_model_output_shape(action_space, model_config):
def required_model_output_shape(
        action_space: gym.Space,
        model_config: ModelConfigDict) -> Union[int, np.ndarray]:
    """Return twice the product of `action_space.shape`;
    `model_config` is unused."""
    num_elements = np.prod(action_space.shape)
    return num_elements * 2
@@ -395,20 +422,22 @@ class Deterministic(TFActionDistribution):
"""
@override(ActionDistribution)
def deterministic_sample(self):
def deterministic_sample(self) -> TensorType:
    """Return `self.inputs` unchanged."""
    action = self.inputs
    return action
@override(TFActionDistribution)
def logp(self, x):
def logp(self, x: TensorType) -> TensorType:
    """Return zeros shaped like `self.inputs`; `x` is not used."""
    zeros = tf.zeros_like(self.inputs)
    return zeros
@override(TFActionDistribution)
def _build_sample_op(self):
def _build_sample_op(self) -> TensorType:
    """Return `self.inputs` unchanged as the sample."""
    sample = self.inputs
    return sample
@staticmethod
@override(ActionDistribution)
def required_model_output_shape(action_space, model_config):
def required_model_output_shape(
        action_space: gym.Space,
        model_config: ModelConfigDict) -> Union[int, np.ndarray]:
    """Return the product of `action_space.shape`;
    `model_config` is unused."""
    num_elements = np.prod(action_space.shape)
    return num_elements
@@ -503,7 +532,7 @@ class Dirichlet(TFActionDistribution):
e.g. actions that represent resource allocation."""
def __init__(self, inputs, model):
def __init__(self, inputs: List[TensorType], model: ModelV2):
"""Input is a tensor of logits. The exponential of logits is used to
parametrize the Dirichlet distribution as all parameters need to be
positive. An arbitrary small epsilon is added to the concentration
@@ -525,7 +554,7 @@ class Dirichlet(TFActionDistribution):
return tf.nn.softmax(self.dist.concentration)
@override(ActionDistribution)
def logp(self, x):
def logp(self, x: TensorType) -> TensorType:
# Support of Dirichlet are positive real numbers. x is already
# an array of positive numbers, but we clip to avoid zeros due to
# numerical errors.
@@ -534,18 +563,20 @@ class Dirichlet(TFActionDistribution):
return self.dist.log_prob(x)
@override(ActionDistribution)
def entropy(self):
def entropy(self) -> TensorType:
    """Return the entropy reported by `self.dist`."""
    dist_entropy = self.dist.entropy()
    return dist_entropy
@override(ActionDistribution)
def kl(self, other):
def kl(self, other: ActionDistribution) -> TensorType:
    """Return `self.dist.kl_divergence(other.dist)`."""
    other_dist = other.dist
    return self.dist.kl_divergence(other_dist)
@override(TFActionDistribution)
def _build_sample_op(self):
def _build_sample_op(self) -> TensorType:
    """Return one sample drawn from `self.dist`."""
    sample = self.dist.sample()
    return sample
@staticmethod
@override(ActionDistribution)
def required_model_output_shape(action_space, model_config):
def required_model_output_shape(
        action_space: gym.Space,
        model_config: ModelConfigDict) -> Union[int, np.ndarray]:
    """Return the product of `action_space.shape`;
    `model_config` is unused."""
    num_elements = np.prod(action_space.shape)
    return num_elements
+11 -4
View File
@@ -1,7 +1,11 @@
from typing import Dict, List
import gym
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.models.utils import get_filter_config
from ray.rllib.utils.framework import get_activation_fn, try_import_tf
from ray.rllib.utils.typing import ModelConfigDict, TensorType
tf1, tf, tfv = try_import_tf()
@@ -9,8 +13,9 @@ tf1, tf, tfv = try_import_tf()
class VisionNetwork(TFModelV2):
"""Generic vision network implemented in ModelV2 API."""
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
def __init__(self, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space, num_outputs: int,
model_config: ModelConfigDict, name: str):
if not model_config.get("conv_filters"):
model_config["conv_filters"] = get_filter_config(obs_space.shape)
@@ -137,7 +142,9 @@ class VisionNetwork(TFModelV2):
self.base_model = tf.keras.Model(inputs, [conv_out, value_out])
self.register_variables(self.base_model.variables)
def forward(self, input_dict, state, seq_lens):
def forward(self, input_dict: Dict[str, TensorType],
state: List[TensorType],
seq_lens: TensorType) -> (TensorType, List[TensorType]):
# Explicit cast to float32 needed in eager.
model_out, self._value_out = self.base_model(
tf.cast(input_dict["obs"], tf.float32))
@@ -148,5 +155,5 @@ class VisionNetwork(TFModelV2):
else:
return tf.squeeze(model_out, axis=[1, 2]), state
def value_function(self):
def value_function(self) -> TensorType:
    """Return `self._value_out` reshaped into a 1-D tensor."""
    value_out = self._value_out
    return tf.reshape(value_out, [-1])