From 59ccbc0fc7b5dcbaf25c85e83da165bcff92e06c Mon Sep 17 00:00:00 2001 From: Michael Luo Date: Thu, 12 Nov 2020 03:18:50 -0800 Subject: [PATCH] [RLlib] Model Annotations: Tensorflow (#11964) --- rllib/models/tf/attention_net.py | 57 +++++--- rllib/models/tf/fcnet.py | 13 +- rllib/models/tf/layers/gru_gate.py | 7 +- .../models/tf/layers/multi_head_attention.py | 5 +- rllib/models/tf/layers/noisy_layer.py | 19 ++- .../layers/relative_multi_head_attention.py | 20 +-- rllib/models/tf/layers/skip_connection.py | 11 +- rllib/models/tf/misc.py | 29 ++-- rllib/models/tf/recurrent_net.py | 28 ++-- rllib/models/tf/tf_action_dist.py | 135 +++++++++++------- rllib/models/tf/visionnet.py | 15 +- 11 files changed, 214 insertions(+), 125 deletions(-) diff --git a/rllib/models/tf/attention_net.py b/rllib/models/tf/attention_net.py index c96cf6c48..48d25c132 100644 --- a/rllib/models/tf/attention_net.py +++ b/rllib/models/tf/attention_net.py @@ -9,6 +9,8 @@ https://www.aclweb.org/anthology/P19-1285.pdf """ import numpy as np +import gym +from typing import Optional, Any from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.tf.layers import GRUGate, RelativeMultiHeadAttention, \ @@ -16,6 +18,7 @@ from ray.rllib.models.tf.layers import GRUGate, RelativeMultiHeadAttention, \ from ray.rllib.models.tf.recurrent_net import RecurrentNetwork from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.typing import ModelConfigDict, TensorType, List tf1, tf, tfv = try_import_tf() @@ -28,7 +31,11 @@ class PositionwiseFeedforward(tf.keras.layers.Layer): layer separately. """ - def __init__(self, out_dim, hidden_dim, output_activation=None, **kwargs): + def __init__(self, + out_dim: int, + hidden_dim: int, + output_activation: Optional[Any] = None, + **kwargs): super().__init__(**kwargs) self._hidden_layer = tf.keras.layers.Dense( @@ -39,7 +46,7 @@ class PositionwiseFeedforward(tf.keras.layers.Layer): self._output_layer = tf.keras.layers.Dense( out_dim, activation=output_activation) - def call(self, inputs, **kwargs): + def call(self, inputs: TensorType, **kwargs) -> TensorType: del kwargs output = self._hidden_layer(inputs) return self._output_layer(output) @@ -48,9 +55,11 @@ class PositionwiseFeedforward(tf.keras.layers.Layer): class TrXLNet(RecurrentNetwork): """A TrXL net Model described in [1].""" - def __init__(self, observation_space, action_space, num_outputs, - model_config, name, num_transformer_units, attn_dim, - num_heads, head_dim, ff_hidden_dim): + def __init__(self, observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, num_outputs: int, + model_config: ModelConfigDict, name: str, + num_transformer_units: int, attn_dim: int, num_heads: int, + head_dim: int, ff_hidden_dim: int): """Initializes a TfXLNet object. Args: @@ -109,7 +118,8 @@ class TrXLNet(RecurrentNetwork): self.register_variables(self.base_model.variables) @override(RecurrentNetwork) - def forward_rnn(self, inputs, state, seq_lens): + def forward_rnn(self, inputs: TensorType, state: List[TensorType], + seq_lens: TensorType) -> (TensorType, List[TensorType]): # To make Attention work with current RLlib's ModelV2 API: # We assume `state` is the history of L recent observations (all # concatenated into one tensor) and append the current inputs to the @@ -126,7 +136,7 @@ class TrXLNet(RecurrentNetwork): return logits, [observations] @override(RecurrentNetwork) - def get_initial_state(self): + def get_initial_state(self) -> List[np.ndarray]: # State is the T last observations concat'd together into one Tensor. # Plus all Transformer blocks' E(l) outputs concat'd together (up to # tau timesteps). @@ -156,18 +166,18 @@ class GTrXLNet(RecurrentNetwork): """ def __init__(self, - observation_space, - action_space, - num_outputs, - model_config, - name, - num_transformer_units, - attn_dim, - num_heads, - memory_tau, - head_dim, - ff_hidden_dim, - init_gate_bias=2.0): + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + num_outputs: int, + model_config: ModelConfigDict, + name: str, + num_transformer_units: int, + attn_dim: int, + num_heads: int, + memory_tau: int, + head_dim: int, + ff_hidden_dim: int, + init_gate_bias: float = 2.0): """Initializes a GTrXLNet. Args: @@ -271,7 +281,8 @@ class GTrXLNet(RecurrentNetwork): self.trxl_model.summary() @override(RecurrentNetwork) - def forward_rnn(self, inputs, state, seq_lens): + def forward_rnn(self, inputs: TensorType, state: List[TensorType], + seq_lens: TensorType) -> (TensorType, List[TensorType]): # To make Attention work with current RLlib's ModelV2 API: # We assume `state` is the history of L recent observations (all # concatenated into one tensor) and append the current inputs to the @@ -303,7 +314,7 @@ class GTrXLNet(RecurrentNetwork): return logits, [observations] + memory_outs @override(RecurrentNetwork) - def get_initial_state(self): + def get_initial_state(self) -> List[np.ndarray]: # State is the T last observations concat'd together into one Tensor. # Plus all Transformer blocks' E(l) outputs concat'd together (up to # tau timesteps). @@ -312,11 +323,11 @@ class GTrXLNet(RecurrentNetwork): for _ in range(self.num_transformer_units)] @override(ModelV2) - def value_function(self): + def value_function(self) -> TensorType: return tf.reshape(self._value_out, [-1]) -def relative_position_embedding(seq_length, out_dim): +def relative_position_embedding(seq_length: int, out_dim: int) -> TensorType: """Creates a [seq_length x seq_length] matrix for rel. pos encoding. Denoted as Phi in [2] and [3]. Phi is the standard sinusoid encoding diff --git a/rllib/models/tf/fcnet.py b/rllib/models/tf/fcnet.py index 36f0f3819..0f72c546b 100644 --- a/rllib/models/tf/fcnet.py +++ b/rllib/models/tf/fcnet.py @@ -1,8 +1,10 @@ import numpy as np +import gym from ray.rllib.models.tf.misc import normc_initializer from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.utils.framework import get_activation_fn, try_import_tf +from ray.rllib.utils.typing import Dict, TensorType, List, ModelConfigDict tf1, tf, tfv = try_import_tf() @@ -10,8 +12,9 @@ tf1, tf, tfv = try_import_tf() class FullyConnectedNetwork(TFModelV2): """Generic fully connected network implemented in ModelV2 API.""" - def __init__(self, obs_space, action_space, num_outputs, model_config, - name): + def __init__(self, obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, num_outputs: int, + model_config: ModelConfigDict, name: str): super(FullyConnectedNetwork, self).__init__( obs_space, action_space, num_outputs, model_config, name) @@ -113,9 +116,11 @@ class FullyConnectedNetwork(TFModelV2): if logits_out is not None else last_layer), value_out]) self.register_variables(self.base_model.variables) - def forward(self, input_dict, state, seq_lens): + def forward(self, input_dict: Dict[str, TensorType], + state: List[TensorType], + seq_lens: TensorType) -> (TensorType, List[TensorType]): model_out, self._value_out = self.base_model(input_dict["obs_flat"]) return model_out, state - def value_function(self): + def value_function(self) -> TensorType: return tf.reshape(self._value_out, [-1]) diff --git a/rllib/models/tf/layers/gru_gate.py b/rllib/models/tf/layers/gru_gate.py index 89dbf652b..6c054058d 100644 --- a/rllib/models/tf/layers/gru_gate.py +++ b/rllib/models/tf/layers/gru_gate.py @@ -1,14 +1,15 @@ from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.typing import TensorType, TensorShape tf1, tf, tfv = try_import_tf() class GRUGate(tf.keras.layers.Layer if tf else object): - def __init__(self, init_bias=0., **kwargs): + def __init__(self, init_bias: float = 0., **kwargs): super().__init__(**kwargs) self._init_bias = init_bias - def build(self, input_shape): + def build(self, input_shape: TensorShape): h_shape, x_shape = input_shape if x_shape[-1] != h_shape[-1]: raise ValueError( @@ -29,7 +30,7 @@ class GRUGate(tf.keras.layers.Layer if tf else object): self._bias_z = self.add_weight( shape=(dim, ), initializer=bias_initializer) - def call(self, inputs, **kwargs): + def call(self, inputs: TensorType, **kwargs) -> TensorType: # Pass in internal state first. h, X = inputs diff --git a/rllib/models/tf/layers/multi_head_attention.py b/rllib/models/tf/layers/multi_head_attention.py index 0971f186f..4ed4fb5a6 100644 --- a/rllib/models/tf/layers/multi_head_attention.py +++ b/rllib/models/tf/layers/multi_head_attention.py @@ -4,6 +4,7 @@ https://arxiv.org/pdf/1706.03762.pdf """ from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.typing import TensorType tf1, tf, tfv = try_import_tf() @@ -11,7 +12,7 @@ tf1, tf, tfv = try_import_tf() class MultiHeadAttention(tf.keras.layers.Layer if tf else object): """A multi-head attention layer described in [1].""" - def __init__(self, out_dim, num_heads, head_dim, **kwargs): + def __init__(self, out_dim: int, num_heads: int, head_dim: int, **kwargs): super().__init__(**kwargs) # No bias or non-linearity. @@ -22,7 +23,7 @@ class MultiHeadAttention(tf.keras.layers.Layer if tf else object): self._linear_layer = tf.keras.layers.TimeDistributed( tf.keras.layers.Dense(out_dim, use_bias=False)) - def call(self, inputs): + def call(self, inputs: TensorType) -> TensorType: L = tf.shape(inputs)[1] # length of segment H = self._num_heads # number of attention heads D = self._head_dim # attention head dimension diff --git a/rllib/models/tf/layers/noisy_layer.py b/rllib/models/tf/layers/noisy_layer.py index f7945d3f8..49b11e0c6 100644 --- a/rllib/models/tf/layers/noisy_layer.py +++ b/rllib/models/tf/layers/noisy_layer.py @@ -2,6 +2,7 @@ import numpy as np from ray.rllib.utils.framework import get_activation_fn, get_variable, \ try_import_tf +from ray.rllib.utils.framework import TensorType, TensorShape tf1, tf, tfv = try_import_tf() @@ -18,14 +19,18 @@ class NoisyLayer(tf.keras.layers.Layer if tf else object): vanish along the training procedure. """ - def __init__(self, prefix, out_size, sigma0, activation="relu"): + def __init__(self, + prefix: str, + out_size: int, + sigma0: float, + activation: str = "relu"): """Initializes a NoisyLayer object. Args: prefix: - out_size: - sigma0: - non_linear: + out_size: Output size for Noisy Layer + sigma0: Initialization value for sigma_b (bias noise) + non_linear: Non-linear activation for Noisy Layer """ super().__init__() self.prefix = prefix @@ -41,7 +46,7 @@ class NoisyLayer(tf.keras.layers.Layer if tf else object): self.sigma_w = None # Noise for weight matrix self.sigma_b = None # Noise for biases. - def build(self, input_shape): + def build(self, input_shape: TensorShape): in_size = int(input_shape[1]) self.sigma_w = get_variable( @@ -78,7 +83,7 @@ class NoisyLayer(tf.keras.layers.Layer if tf else object): dtype=tf.float32, ) - def call(self, inputs): + def call(self, inputs: TensorType) -> TensorType: in_size = int(inputs.shape[1]) epsilon_in = tf.random.normal(shape=[in_size]) epsilon_out = tf.random.normal(shape=[self.out_size]) @@ -98,5 +103,5 @@ class NoisyLayer(tf.keras.layers.Layer if tf else object): action_activation = fn(action_activation) return action_activation - def _f_epsilon(self, x): + def _f_epsilon(self, x: TensorType) -> TensorType: return tf.math.sign(x) * tf.math.sqrt(tf.math.abs(x)) diff --git a/rllib/models/tf/layers/relative_multi_head_attention.py b/rllib/models/tf/layers/relative_multi_head_attention.py index bd52c0bf7..f7d70ab60 100644 --- a/rllib/models/tf/layers/relative_multi_head_attention.py +++ b/rllib/models/tf/layers/relative_multi_head_attention.py @@ -1,4 +1,7 @@ +from typing import Optional, Any + from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.typing import TensorType tf1, tf, tfv = try_import_tf() @@ -10,12 +13,12 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object): """ def __init__(self, - out_dim, - num_heads, - head_dim, - rel_pos_encoder, - input_layernorm=False, - output_activation=None, + out_dim: int, + num_heads: int, + head_dim: int, + rel_pos_encoder: Any, + input_layernorm: bool = False, + output_activation: Optional[Any] = None, **kwargs): """Initializes a RelativeMultiHeadAttention keras Layer object. @@ -55,7 +58,8 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object): if input_layernorm: self._input_layernorm = tf.keras.layers.LayerNormalization(axis=-1) - def call(self, inputs, memory=None): + def call(self, inputs: TensorType, + memory: Optional[TensorType] = None) -> TensorType: T = tf.shape(inputs)[1] # length of segment (time) H = self._num_heads # number of attention heads d = self._head_dim # attention head dimension @@ -105,7 +109,7 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object): return self._linear_layer(out) @staticmethod - def rel_shift(x): + def rel_shift(x: TensorType) -> TensorType: # Transposed version of the shift approach described in [3]. # https://github.com/kimiyoung/transformer-xl/blob/ # 44781ed21dbaec88b280f74d9ae2877f52b492a5/tf/model.py#L31 diff --git a/rllib/models/tf/layers/skip_connection.py b/rllib/models/tf/layers/skip_connection.py index 9d6b766e4..efb89f2e3 100644 --- a/rllib/models/tf/layers/skip_connection.py +++ b/rllib/models/tf/layers/skip_connection.py @@ -1,4 +1,7 @@ +from typing import Optional, Any + from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.typing import TensorType tf1, tf, tfv = try_import_tf() @@ -10,7 +13,11 @@ class SkipConnection(tf.keras.layers.Layer if tf else object): input as hidden state input to a given fan_in_layer. """ - def __init__(self, layer, fan_in_layer=None, add_memory=False, **kwargs): + def __init__(self, + layer: Any, + fan_in_layer: Optional[Any] = None, + add_memory: bool = False, + **kwargs): """Initializes a SkipConnection keras layer object. Args: @@ -23,7 +30,7 @@ class SkipConnection(tf.keras.layers.Layer if tf else object): self._layer = layer self._fan_in_layer = fan_in_layer - def call(self, inputs, **kwargs): + def call(self, inputs: TensorType, **kwargs) -> TensorType: # del kwargs outputs = self._layer(inputs, **kwargs) # Residual case, just add inputs to outputs. diff --git a/rllib/models/tf/misc.py b/rllib/models/tf/misc.py index 1da1bbb86..4e219aff2 100644 --- a/rllib/models/tf/misc.py +++ b/rllib/models/tf/misc.py @@ -1,10 +1,13 @@ import numpy as np +from typing import Tuple, Any, Optional + from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.typing import TensorType tf1, tf, tfv = try_import_tf() -def normc_initializer(std=1.0): +def normc_initializer(std: float = 1.0) -> Any: def _initializer(shape, dtype=None, partition_info=None): out = np.random.randn(*shape).astype(np.float32) out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True)) @@ -13,14 +16,14 @@ def normc_initializer(std=1.0): return _initializer -def conv2d(x, - num_filters, - name, - filter_size=(3, 3), - stride=(1, 1), - pad="SAME", - dtype=None, - collections=None): +def conv2d(x: TensorType, + num_filters: int, + name: str, + filter_size: Tuple[int, int] = (3, 3), + stride: Tuple[int, int] = (1, 1), + pad: str = "SAME", + dtype: Optional[Any] = None, + collections: Optional[Any] = None) -> TensorType: if dtype is None: dtype = tf.float32 @@ -53,7 +56,11 @@ def conv2d(x, return tf1.nn.conv2d(x, w, stride_shape, pad) + b -def linear(x, size, name, initializer=None, bias_init=0): +def linear(x: TensorType, + size: int, + name: str, + initializer: Optional[Any] = None, + bias_init: float = 0.0) -> TensorType: w = tf1.get_variable( name + "/w", [x.get_shape()[1], size], initializer=initializer) b = tf1.get_variable( @@ -61,5 +68,5 @@ def linear(x, size, name, initializer=None, bias_init=0): return tf.matmul(x, w) + b -def flatten(x): +def flatten(x: TensorType) -> TensorType: return tf.reshape(x, [-1, np.prod(x.get_shape().as_list()[1:])]) diff --git a/rllib/models/tf/recurrent_net.py b/rllib/models/tf/recurrent_net.py index 9a0ee7438..5be621882 100644 --- a/rllib/models/tf/recurrent_net.py +++ b/rllib/models/tf/recurrent_net.py @@ -1,4 +1,6 @@ import numpy as np +import gym +from typing import Dict, List from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.tf.tf_modelv2 import TFModelV2 @@ -6,6 +8,7 @@ from ray.rllib.policy.rnn_sequencing import add_time_dimension from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.typing import ModelConfigDict, TensorType tf1, tf, tfv = try_import_tf() @@ -49,7 +52,9 @@ class RecurrentNetwork(TFModelV2): """ @override(ModelV2) - def forward(self, input_dict, state, seq_lens): + def forward(self, input_dict: Dict[str, TensorType], + state: List[TensorType], + seq_lens: TensorType) -> (TensorType, List[TensorType]): """Adds time dimension to batch before sending inputs to forward_rnn(). You should implement forward_rnn() in your subclass.""" @@ -62,7 +67,8 @@ class RecurrentNetwork(TFModelV2): seq_lens) return tf.reshape(output, [-1, self.num_outputs]), new_state - def forward_rnn(self, inputs, state, seq_lens): + def forward_rnn(self, inputs: TensorType, state: List[TensorType], + seq_lens: TensorType) -> (TensorType, List[TensorType]): """Call the model with the given input tensors and state. Args: @@ -83,7 +89,7 @@ class RecurrentNetwork(TFModelV2): """ raise NotImplementedError("You must implement this for a RNN model") - def get_initial_state(self): + def get_initial_state(self) -> List[TensorType]: """Get the initial recurrent state values for the model. Returns: @@ -104,8 +110,9 @@ class LSTMWrapper(RecurrentNetwork): """An LSTM wrapper serving as an interface for ModelV2s that set use_lstm. """ - def __init__(self, obs_space, action_space, num_outputs, model_config, - name): + def __init__(self, obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, num_outputs: int, + model_config: ModelConfigDict, name: str): super(LSTMWrapper, self).__init__(obs_space, action_space, None, model_config, name) @@ -157,7 +164,9 @@ class LSTMWrapper(RecurrentNetwork): self._rnn_model.summary() @override(RecurrentNetwork) - def forward(self, input_dict, state, seq_lens): + def forward(self, input_dict: Dict[str, TensorType], + state: List[TensorType], + seq_lens: TensorType) -> (TensorType, List[TensorType]): assert seq_lens is not None # Push obs through "unwrapped" net's `forward()` first. wrapped_out, _ = self._wrapped_forward(input_dict, [], None) @@ -181,18 +190,19 @@ class LSTMWrapper(RecurrentNetwork): return super().forward(input_dict, state, seq_lens) @override(RecurrentNetwork) - def forward_rnn(self, inputs, state, seq_lens): + def forward_rnn(self, inputs: TensorType, state: List[TensorType], + seq_lens: TensorType) -> (TensorType, List[TensorType]): model_out, self._value_out, h, c = self._rnn_model([inputs, seq_lens] + state) return model_out, [h, c] @override(ModelV2) - def get_initial_state(self): + def get_initial_state(self) -> List[np.ndarray]: return [ np.zeros(self.cell_size, np.float32), np.zeros(self.cell_size, np.float32), ] @override(ModelV2) - def value_function(self): + def value_function(self) -> TensorType: return tf.reshape(self._value_out, [-1]) diff --git a/rllib/models/tf/tf_action_dist.py b/rllib/models/tf/tf_action_dist.py index d125b99d0..27813ddd3 100644 --- a/rllib/models/tf/tf_action_dist.py +++ b/rllib/models/tf/tf_action_dist.py @@ -2,6 +2,7 @@ from math import log import numpy as np import functools import tree +import gym from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.models.modelv2 import ModelV2 @@ -10,7 +11,8 @@ from ray.rllib.utils import MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT, \ from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.framework import try_import_tf, try_import_tfp from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space -from ray.rllib.utils.typing import TensorType, List +from ray.rllib.utils.typing import TensorType, List, Union, \ + Tuple, ModelConfigDict tf1, tf, tfv = try_import_tf() tfp = try_import_tfp() @@ -50,23 +52,26 @@ class Categorical(TFActionDistribution): """Categorical distribution for discrete action spaces.""" @DeveloperAPI - def __init__(self, inputs, model=None, temperature=1.0): + def __init__(self, + inputs: List[TensorType], + model: ModelV2 = None, + temperature: float = 1.0): assert temperature > 0.0, "Categorical `temperature` must be > 0.0!" # Allow softmax formula w/ temperature != 1.0: # Divide inputs by temperature. super().__init__(inputs / temperature, model) @override(ActionDistribution) - def deterministic_sample(self): + def deterministic_sample(self) -> TensorType: return tf.math.argmax(self.inputs, axis=1) @override(ActionDistribution) - def logp(self, x): + def logp(self, x: TensorType) -> TensorType: return -tf.nn.sparse_softmax_cross_entropy_with_logits( logits=self.inputs, labels=tf.cast(x, tf.int32)) @override(ActionDistribution) - def entropy(self): + def entropy(self) -> TensorType: a0 = self.inputs - tf.reduce_max(self.inputs, axis=1, keepdims=True) ea0 = tf.exp(a0) z0 = tf.reduce_sum(ea0, axis=1, keepdims=True) @@ -74,7 +79,7 @@ class Categorical(TFActionDistribution): return tf.reduce_sum(p0 * (tf.math.log(z0) - a0), axis=1) @override(ActionDistribution) - def kl(self, other): + def kl(self, other: ActionDistribution) -> TensorType: a0 = self.inputs - tf.reduce_max(self.inputs, axis=1, keepdims=True) a1 = other.inputs - tf.reduce_max(other.inputs, axis=1, keepdims=True) ea0 = tf.exp(a0) @@ -86,7 +91,7 @@ class Categorical(TFActionDistribution): p0 * (a0 - tf.math.log(z0) - a1 + tf.math.log(z1)), axis=1) @override(TFActionDistribution) - def _build_sample_op(self): + def _build_sample_op(self) -> TensorType: return tf.squeeze(tf.random.categorical(self.inputs, 1), axis=1) @staticmethod @@ -98,7 +103,8 @@ class Categorical(TFActionDistribution): class MultiCategorical(TFActionDistribution): """MultiCategorical distribution for MultiDiscrete action spaces.""" - def __init__(self, inputs, model, input_lens): + def __init__(self, inputs: List[TensorType], model: ModelV2, + input_lens: Union[List[int], np.ndarray, Tuple[int, ...]]): # skip TFActionDistribution init ActionDistribution.__init__(self, inputs, model) self.cats = [ @@ -109,12 +115,12 @@ class MultiCategorical(TFActionDistribution): self.sampled_action_logp_op = self.logp(self.sample_op) @override(ActionDistribution) - def deterministic_sample(self): + def deterministic_sample(self) -> TensorType: return tf.stack( [cat.deterministic_sample() for cat in self.cats], axis=1) @override(ActionDistribution) - def logp(self, actions): + def logp(self, actions: TensorType) -> TensorType: # If tensor is provided, unstack it into list. if isinstance(actions, tf.Tensor): actions = tf.unstack(tf.cast(actions, tf.int32), axis=1) @@ -123,30 +129,32 @@ class MultiCategorical(TFActionDistribution): return tf.reduce_sum(logps, axis=0) @override(ActionDistribution) - def multi_entropy(self): + def multi_entropy(self) -> TensorType: return tf.stack([cat.entropy() for cat in self.cats], axis=1) @override(ActionDistribution) - def entropy(self): + def entropy(self) -> TensorType: return tf.reduce_sum(self.multi_entropy(), axis=1) @override(ActionDistribution) - def multi_kl(self, other): + def multi_kl(self, other: ActionDistribution) -> TensorType: return tf.stack( [cat.kl(oth_cat) for cat, oth_cat in zip(self.cats, other.cats)], axis=1) @override(ActionDistribution) - def kl(self, other): + def kl(self, other: ActionDistribution) -> TensorType: return tf.reduce_sum(self.multi_kl(other), axis=1) @override(TFActionDistribution) - def _build_sample_op(self): + def _build_sample_op(self) -> TensorType: return tf.stack([cat.sample() for cat in self.cats], axis=1) @staticmethod @override(ActionDistribution) - def required_model_output_shape(action_space, model_config): + def required_model_output_shape( + action_space: gym.Space, + model_config: ModelConfigDict) -> Union[int, np.ndarray]: return np.sum(action_space.nvec) @@ -168,7 +176,10 @@ class GumbelSoftmax(TFActionDistribution): """ @DeveloperAPI - def __init__(self, inputs, model=None, temperature=1.0): + def __init__(self, + inputs: List[TensorType], + model: ModelV2 = None, + temperature: float = 1.0): """Initializes a GumbelSoftmax distribution. Args: @@ -184,12 +195,12 @@ class GumbelSoftmax(TFActionDistribution): super().__init__(inputs, model) @override(ActionDistribution) - def deterministic_sample(self): + def deterministic_sample(self) -> TensorType: # Return the dist object's prob values. return self.probs @override(ActionDistribution) - def logp(self, x): + def logp(self, x: TensorType) -> TensorType: # Override since the implementation of tfp.RelaxedOneHotCategorical # yields positive values. if x.shape != self.dist.logits.shape: @@ -204,12 +215,14 @@ class GumbelSoftmax(TFActionDistribution): -x * tf.nn.log_softmax(self.dist.logits, axis=-1), axis=-1) @override(TFActionDistribution) - def _build_sample_op(self): + def _build_sample_op(self) -> TensorType: return self.dist.sample() @staticmethod @override(ActionDistribution) - def required_model_output_shape(action_space, model_config): + def required_model_output_shape( + action_space: gym.Space, + model_config: ModelConfigDict) -> Union[int, np.ndarray]: return action_space.n @@ -220,7 +233,7 @@ class DiagGaussian(TFActionDistribution): second half the gaussian standard deviations. """ - def __init__(self, inputs, model): + def __init__(self, inputs: List[TensorType], model: ModelV2): mean, log_std = tf.split(inputs, 2, axis=1) self.mean = mean self.log_std = log_std @@ -228,11 +241,11 @@ class DiagGaussian(TFActionDistribution): super().__init__(inputs, model) @override(ActionDistribution) - def deterministic_sample(self): + def deterministic_sample(self) -> TensorType: return self.mean @override(ActionDistribution) - def logp(self, x): + def logp(self, x: TensorType) -> TensorType: return -0.5 * tf.reduce_sum( tf.math.square((tf.cast(x, tf.float32) - self.mean) / self.std), axis=1 @@ -240,7 +253,7 @@ class DiagGaussian(TFActionDistribution): tf.reduce_sum(self.log_std, axis=1) @override(ActionDistribution) - def kl(self, other): + def kl(self, other: ActionDistribution) -> TensorType: assert isinstance(other, DiagGaussian) return tf.reduce_sum( other.log_std - self.log_std + @@ -249,17 +262,19 @@ class DiagGaussian(TFActionDistribution): axis=1) @override(ActionDistribution) - def entropy(self): + def entropy(self) -> TensorType: return tf.reduce_sum( self.log_std + .5 * np.log(2.0 * np.pi * np.e), axis=1) @override(TFActionDistribution) - def _build_sample_op(self): + def _build_sample_op(self) -> TensorType: return self.mean + self.std * tf.random.normal(tf.shape(self.mean)) @staticmethod @override(ActionDistribution) - def required_model_output_shape(action_space, model_config): + def required_model_output_shape( + action_space: gym.Space, + model_config: ModelConfigDict) -> Union[int, np.ndarray]: return np.prod(action_space.shape) * 2 @@ -270,7 +285,11 @@ class SquashedGaussian(TFActionDistribution): `low`+SMALL_NUMBER or `high`-SMALL_NUMBER respectively. """ - def __init__(self, inputs, model, low=-1.0, high=1.0): + def __init__(self, + inputs: List[TensorType], + model: ModelV2, + low: float = -1.0, + high: float = 1.0): """Parameterizes the distribution via `inputs`. Args: @@ -292,16 +311,16 @@ class SquashedGaussian(TFActionDistribution): super().__init__(inputs, model) @override(ActionDistribution) - def deterministic_sample(self): + def deterministic_sample(self) -> TensorType: mean = self.distr.mean() return self._squash(mean) @override(TFActionDistribution) - def _build_sample_op(self): + def _build_sample_op(self) -> TensorType: return self._squash(self.distr.sample()) @override(ActionDistribution) - def logp(self, x): + def logp(self, x: TensorType) -> TensorType: # Unsquash values (from [low,high] to ]-inf,inf[) unsquashed_values = self._unsquash(x) # Get log prob of unsquashed values from our Normal. @@ -316,13 +335,13 @@ class SquashedGaussian(TFActionDistribution): axis=-1) return log_prob - def _squash(self, raw_values): + def _squash(self, raw_values: TensorType) -> TensorType: # Returned values are within [low, high] (including `low` and `high`). squashed = ((tf.math.tanh(raw_values) + 1.0) / 2.0) * \ (self.high - self.low) + self.low return tf.clip_by_value(squashed, self.low, self.high) - def _unsquash(self, values): + def _unsquash(self, values: TensorType) -> TensorType: normed_values = (values - self.low) / (self.high - self.low) * 2.0 - \ 1.0 # Stabilize input to atanh. @@ -333,7 +352,9 @@ class SquashedGaussian(TFActionDistribution): @staticmethod @override(ActionDistribution) - def required_model_output_shape(action_space, model_config): + def required_model_output_shape( + action_space: gym.Space, + model_config: ModelConfigDict) -> Union[int, np.ndarray]: return np.prod(action_space.shape) * 2 @@ -347,7 +368,11 @@ class Beta(TFActionDistribution): and Gamma(n) = (n - 1)! """ - def __init__(self, inputs, model, low=0.0, high=1.0): + def __init__(self, + inputs: List[TensorType], + model: ModelV2, + low: float = 0.0, + high: float = 1.0): # Stabilize input parameters (possibly coming from a linear layer). inputs = tf.clip_by_value(inputs, log(SMALL_NUMBER), -log(SMALL_NUMBER)) @@ -361,29 +386,31 @@ class Beta(TFActionDistribution): super().__init__(inputs, model) @override(ActionDistribution) - def deterministic_sample(self): + def deterministic_sample(self) -> TensorType: mean = self.dist.mean() return self._squash(mean) @override(TFActionDistribution) - def _build_sample_op(self): + def _build_sample_op(self) -> TensorType: return self._squash(self.dist.sample()) @override(ActionDistribution) - def logp(self, x): + def logp(self, x: TensorType) -> TensorType: unsquashed_values = self._unsquash(x) return tf.math.reduce_sum( self.dist.log_prob(unsquashed_values), axis=-1) - def _squash(self, raw_values): + def _squash(self, raw_values: TensorType) -> TensorType: return raw_values * (self.high - self.low) + self.low - def _unsquash(self, values): + def _unsquash(self, values: TensorType) -> TensorType: return (values - self.low) / (self.high - self.low) @staticmethod @override(ActionDistribution) - def required_model_output_shape(action_space, model_config): + def required_model_output_shape( + action_space: gym.Space, + model_config: ModelConfigDict) -> Union[int, np.ndarray]: return np.prod(action_space.shape) * 2 @@ -395,20 +422,22 @@ class Deterministic(TFActionDistribution): """ @override(ActionDistribution) - def deterministic_sample(self): + def deterministic_sample(self) -> TensorType: return self.inputs @override(TFActionDistribution) - def logp(self, x): + def logp(self, x: TensorType) -> TensorType: return tf.zeros_like(self.inputs) @override(TFActionDistribution) - def _build_sample_op(self): + def _build_sample_op(self) -> TensorType: return self.inputs @staticmethod @override(ActionDistribution) - def required_model_output_shape(action_space, model_config): + def required_model_output_shape( + action_space: gym.Space, + model_config: ModelConfigDict) -> Union[int, np.ndarray]: return np.prod(action_space.shape) @@ -503,7 +532,7 @@ class Dirichlet(TFActionDistribution): e.g. actions that represent resource allocation.""" - def __init__(self, inputs, model): + def __init__(self, inputs: List[TensorType], model: ModelV2): """Input is a tensor of logits. The exponential of logits is used to parametrize the Dirichlet distribution as all parameters need to be positive. An arbitrary small epsilon is added to the concentration @@ -525,7 +554,7 @@ class Dirichlet(TFActionDistribution): return tf.nn.softmax(self.dist.concentration) @override(ActionDistribution) - def logp(self, x): + def logp(self, x: TensorType) -> TensorType: # Support of Dirichlet are positive real numbers. x is already # an array of positive numbers, but we clip to avoid zeros due to # numerical errors. @@ -534,18 +563,20 @@ class Dirichlet(TFActionDistribution): return self.dist.log_prob(x) @override(ActionDistribution) - def entropy(self): + def entropy(self) -> TensorType: return self.dist.entropy() @override(ActionDistribution) - def kl(self, other): + def kl(self, other: ActionDistribution) -> TensorType: return self.dist.kl_divergence(other.dist) @override(TFActionDistribution) - def _build_sample_op(self): + def _build_sample_op(self) -> TensorType: return self.dist.sample() @staticmethod @override(ActionDistribution) - def required_model_output_shape(action_space, model_config): + def required_model_output_shape( + action_space: gym.Space, + model_config: ModelConfigDict) -> Union[int, np.ndarray]: return np.prod(action_space.shape) diff --git a/rllib/models/tf/visionnet.py b/rllib/models/tf/visionnet.py index 857929b11..e09668b49 100644 --- a/rllib/models/tf/visionnet.py +++ b/rllib/models/tf/visionnet.py @@ -1,7 +1,11 @@ +from typing import Dict, List +import gym + from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.models.tf.misc import normc_initializer from ray.rllib.models.utils import get_filter_config from ray.rllib.utils.framework import get_activation_fn, try_import_tf +from ray.rllib.utils.typing import ModelConfigDict, TensorType tf1, tf, tfv = try_import_tf() @@ -9,8 +13,9 @@ tf1, tf, tfv = try_import_tf() class VisionNetwork(TFModelV2): """Generic vision network implemented in ModelV2 API.""" - def __init__(self, obs_space, action_space, num_outputs, model_config, - name): + def __init__(self, obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, num_outputs: int, + model_config: ModelConfigDict, name: str): if not model_config.get("conv_filters"): model_config["conv_filters"] = get_filter_config(obs_space.shape) @@ -137,7 +142,9 @@ class VisionNetwork(TFModelV2): self.base_model = tf.keras.Model(inputs, [conv_out, value_out]) self.register_variables(self.base_model.variables) - def forward(self, input_dict, state, seq_lens): + def forward(self, input_dict: Dict[str, TensorType], + state: List[TensorType], + seq_lens: TensorType) -> (TensorType, List[TensorType]): # Explicit cast to float32 needed in eager. model_out, self._value_out = self.base_model( tf.cast(input_dict["obs"], tf.float32)) @@ -148,5 +155,5 @@ class VisionNetwork(TFModelV2): else: return tf.squeeze(model_out, axis=[1, 2]), state - def value_function(self): + def value_function(self) -> TensorType: return tf.reshape(self._value_out, [-1])