diff --git a/doc/source/rllib-models.rst b/doc/source/rllib-models.rst index b1f933aee..87f08bd29 100644 --- a/doc/source/rllib-models.rst +++ b/doc/source/rllib-models.rst @@ -191,35 +191,44 @@ Similarly, you can create and register custom PyTorch models for use with PyTorc import ray from ray.rllib.agents import a3c from ray.rllib.models import ModelCatalog - from ray.rllib.models.torch.model import TorchModel + from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 - class CustomTorchModel(TorchModel): + class CustomTorchModel(TorchModelV2): - def __init__(self, obs_space, num_outputs, options): - TorchModel.__init__(self, obs_space, num_outputs, options) + def __init__(self, obs_space, action_space, num_outputs, model_config, + name): + super(CustomTorchModel, self).__init__( + obs_space, action_space, num_outputs, model_config, name) ... # setup hidden layers - def _forward(self, input_dict, hidden_state): - """Forward pass for the model. + def forward(self, input_dict, state, seq_lens): + """Call the model with the given input tensors and state. - Prefer implementing this instead of forward() directly for proper - handling of Dict and Tuple observations. + Any complex observations (dicts, tuples, etc.) will be unpacked by + __call__ before being passed to forward(). To access the flattened + observation tensor, refer to input_dict["obs_flat"]. + + This method can be called any number of times. In eager execution, + each call to forward() will eagerly evaluate the model. In symbolic + execution, each call to forward creates a computation graph that + operates over the variables of this model (i.e., shares weights). + + Custom models should override this instead of __call__. Arguments: - input_dict (dict): Dictionary of tensor inputs, commonly - including "obs", "prev_action", "prev_reward", each of shape - [BATCH_SIZE, ...]. - hidden_state (list): List of hidden state tensors, each of shape - [BATCH_SIZE, h_size]. + input_dict (dict): dictionary of input tensors, including "obs", + "obs_flat", "prev_action", "prev_reward", "is_training" + state (list): list of state tensors with sizes matching those + returned by get_initial_state + the batch dimension + seq_lens (Tensor): 1d tensor holding input sequence lengths Returns: - (outputs, feature_layer, values, state): Tensors of size - [BATCH_SIZE, num_outputs], [BATCH_SIZE, desired_feature_size], - [BATCH_SIZE], and [len(hidden_state), BATCH_SIZE, h_size]. + (outputs, state): The model output tensor of size + [BATCH, num_outputs] """ obs = input_dict["obs"] ... - return logits, features, value, hidden_state + return logits, state ModelCatalog.register_custom_model("my_model", CustomTorchModel) diff --git a/python/ray/rllib/agents/a3c/a3c_torch_policy.py b/python/ray/rllib/agents/a3c/a3c_torch_policy.py index f11ff51be..8045c397f 100644 --- a/python/ray/rllib/agents/a3c/a3c_torch_policy.py +++ b/python/ray/rllib/agents/a3c/a3c_torch_policy.py @@ -14,9 +14,10 @@ from ray.rllib.policy.torch_policy_template import build_torch_policy def actor_critic_loss(policy, batch_tensors): - logits, _, values, _ = policy.model({ + logits, _ = policy.model({ SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS] - }, []) + }) # TODO(ekl) seq lens shouldn't be None + values = policy.model.value_function() dist = policy.dist_class(logits) log_probs = dist.logp(batch_tensors[SampleBatch.ACTIONS]) policy.entropy = dist.entropy().mean() @@ -53,8 +54,8 @@ def add_advantages(policy, policy.config["lambda"]) -def model_value_predictions(policy, input_dict, state_batches, model_out): - return {SampleBatch.VF_PREDS: model_out[2].cpu().numpy()} +def model_value_predictions(policy, input_dict, state_batches, model): + return {SampleBatch.VF_PREDS: model.value_function().cpu().numpy()} def apply_grad_clipping(policy): @@ -74,8 +75,8 @@ class ValueNetworkMixin(object): def _value(self, obs): with self.lock: obs = torch.from_numpy(obs).float().unsqueeze(0).to(self.device) - _, _, vf, _ = self.model({"obs": obs}, []) - return vf.detach().cpu().numpy().squeeze() + _ = self.model({"obs": obs}, [], [1]) + return self.model.value_function().detach().cpu().numpy().squeeze() A3CTorchPolicy = build_torch_policy( diff --git a/python/ray/rllib/agents/pg/torch_pg_policy.py b/python/ray/rllib/agents/pg/torch_pg_policy.py index d0f1cda71..442c57f48 100644 --- a/python/ray/rllib/agents/pg/torch_pg_policy.py +++ b/python/ray/rllib/agents/pg/torch_pg_policy.py @@ -10,9 +10,9 @@ from ray.rllib.policy.torch_policy_template import build_torch_policy def pg_torch_loss(policy, batch_tensors): - logits, _, values, _ = policy.model({ + logits, _ = policy.model({ SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS] - }, []) + }) action_dist = policy.dist_class(logits) log_probs = action_dist.logp(batch_tensors[SampleBatch.ACTIONS]) # save the error in the policy object diff --git a/python/ray/rllib/agents/qmix/model.py b/python/ray/rllib/agents/qmix/model.py index dd4377a4a..81e9dce1c 100644 --- a/python/ray/rllib/agents/qmix/model.py +++ b/python/ray/rllib/agents/qmix/model.py @@ -6,33 +6,35 @@ from torch import nn import torch.nn.functional as F from ray.rllib.models.preprocessors import get_preprocessor -from ray.rllib.models.torch.model import TorchModel +from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.utils.annotations import override -class RNNModel(TorchModel): +class RNNModel(TorchModelV2): """The default RNN model for QMIX.""" - def __init__(self, obs_space, num_outputs, options): - TorchModel.__init__(self, obs_space, num_outputs, options) + def __init__(self, obs_space, action_space, num_outputs, model_config, + name): + super(RNNModel, self).__init__(obs_space, action_space, num_outputs, + model_config, name) self.obs_size = _get_size(obs_space) - self.rnn_hidden_dim = options["lstm_cell_size"] + self.rnn_hidden_dim = model_config["lstm_cell_size"] self.fc1 = nn.Linear(self.obs_size, self.rnn_hidden_dim) self.rnn = nn.GRUCell(self.rnn_hidden_dim, self.rnn_hidden_dim) self.fc2 = nn.Linear(self.rnn_hidden_dim, num_outputs) - @override(TorchModel) - def state_init(self): + @override(TorchModelV2) + def get_initial_state(self): # make hidden states on same device as model return [self.fc1.weight.new(1, self.rnn_hidden_dim).zero_().squeeze(0)] - @override(TorchModel) - def _forward(self, input_dict, hidden_state): - x = F.relu(self.fc1(input_dict["obs"])) + @override(TorchModelV2) + def forward(self, input_dict, hidden_state, seq_lens): + x = F.relu(self.fc1(input_dict["obs_flat"].float())) h_in = hidden_state[0].reshape(-1, self.rnn_hidden_dim) h = self.rnn(x, h_in) q = self.fc2(h) - return q, h, None, [h] + return q, [h] def _get_size(obs_space): diff --git a/python/ray/rllib/agents/qmix/qmix_policy.py b/python/ray/rllib/agents/qmix/qmix_policy.py index 990458996..87020d222 100644 --- a/python/ray/rllib/agents/qmix/qmix_policy.py +++ b/python/ray/rllib/agents/qmix/qmix_policy.py @@ -65,7 +65,10 @@ class QMixLoss(nn.Module): # Calculate estimated Q-Values mac_out = [] - h = [s.expand([B, self.n_agents, -1]) for s in self.model.state_init()] + h = [ + s.expand([B, self.n_agents, -1]) + for s in self.model.get_initial_state() + ] for t in range(T): q, h = _mac(self.model, obs[:, t], h) mac_out.append(q) @@ -79,7 +82,7 @@ class QMixLoss(nn.Module): target_mac_out = [] target_h = [ s.expand([B, self.n_agents, -1]) - for s in self.target_model.state_init() + for s in self.target_model.get_initial_state() ] for t in range(T): target_q, target_h = _mac(self.target_model, next_obs[:, t], @@ -171,16 +174,23 @@ class QMixTorchPolicy(Policy): self.has_action_mask = False self.obs_size = _get_size(agent_obs_space) - self.model = ModelCatalog.get_torch_model( + self.model = ModelCatalog.get_model_v2( agent_obs_space, + action_space.spaces[0], self.n_actions, config["model"], - default_model_cls=RNNModel) - self.target_model = ModelCatalog.get_torch_model( + framework="torch", + name="model", + default_model=RNNModel) + + self.target_model = ModelCatalog.get_model_v2( agent_obs_space, + action_space.spaces[0], self.n_actions, config["model"], - default_model_cls=RNNModel) + framework="torch", + name="target_model", + default_model=RNNModel) # Setup the mixer network. # The global state is just the stacked agent observations for now. @@ -320,7 +330,7 @@ class QMixTorchPolicy(Policy): def get_initial_state(self): return [ s.expand([self.n_agents, -1]).numpy() - for s in self.model.state_init() + for s in self.model.get_initial_state() ] @override(Policy) @@ -425,7 +435,7 @@ def _mac(model, obs, h): """Forward pass of the multi-agent controller. Arguments: - model: TorchModel class + model: TorchModelV2 class obs: Tensor of shape [B, n_agents, obs_size] h: List of tensors of shape [B, n_agents, h_size] @@ -436,6 +446,6 @@ def _mac(model, obs, h): B, n_agents = obs.size(0), obs.size(1) obs_flat = obs.reshape([B * n_agents, -1]) h_flat = [s.reshape([B * n_agents, -1]) for s in h] - q_flat, _, _, h_flat = model.forward({"obs": obs_flat}, h_flat) + q_flat, h_flat = model({"obs": obs_flat}, h_flat, None) return q_flat.reshape( [B, n_agents, -1]), [s.reshape([B, n_agents, -1]) for s in h_flat] diff --git a/python/ray/rllib/examples/custom_torch_policy.py b/python/ray/rllib/examples/custom_torch_policy.py index 7ab2786cf..e9b30876d 100644 --- a/python/ray/rllib/examples/custom_torch_policy.py +++ b/python/ray/rllib/examples/custom_torch_policy.py @@ -15,9 +15,9 @@ parser.add_argument("--iters", type=int, default=200) def policy_gradient_loss(policy, batch_tensors): - logits, _, values, _ = policy.model({ + logits, _ = policy.model({ SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS] - }, []) + }) action_dist = policy.dist_class(logits) log_probs = action_dist.logp(batch_tensors[SampleBatch.ACTIONS]) return -batch_tensors[SampleBatch.REWARDS].dot(log_probs) diff --git a/python/ray/rllib/models/catalog.py b/python/ray/rllib/models/catalog.py index 4bdda4ecf..9b222fd91 100644 --- a/python/ray/rllib/models/catalog.py +++ b/python/ray/rllib/models/catalog.py @@ -208,7 +208,7 @@ class ModelCatalog(object): action_space, num_outputs, model_config, - framework="tf", + framework, name=None, model_interface=None, default_model=None, @@ -240,29 +240,37 @@ class ModelCatalog(object): model_interface): raise ValueError("The given model must subclass", model_interface) - created = set() - # Track and warn if variables were created but no registered - def track_var_creation(next_creator, **kw): - v = next_creator(**kw) - created.add(v) - return v + if framework == "tf": + created = set() - with tf.variable_creator_scope(track_var_creation): + # Track and warn if vars were created but not registered + def track_var_creation(next_creator, **kw): + v = next_creator(**kw) + created.add(v) + return v + + with tf.variable_creator_scope(track_var_creation): + instance = model_cls(obs_space, action_space, + num_outputs, model_config, name, + **model_kwargs) + registered = set(instance.variables()) + not_registered = set() + for var in created: + if var not in registered: + not_registered.add(var) + if not_registered: + raise ValueError( + "It looks like variables {} were created as part " + "of {} but does not appear in model.variables() " + "({}). Did you forget to call " + "model.register_variables() on the variables in " + "question?".format(not_registered, instance, + registered)) + else: + # no variable tracking instance = model_cls(obs_space, action_space, num_outputs, model_config, name, **model_kwargs) - registered = set(instance.variables()) - not_registered = set() - for var in created: - if var not in registered: - not_registered.add(var) - if not_registered: - raise ValueError( - "It looks like variables {} were created as part of " - "{} but does not appear in model.variables() ({}). " - "Did you forget to call model.register_variables() " - "on the variables in question?".format( - not_registered, instance, registered)) return instance if framework == "tf": @@ -271,8 +279,15 @@ class ModelCatalog(object): make_v1_wrapper(legacy_model_cls), model_interface) return wrapper(obs_space, action_space, num_outputs, model_config, name, **model_kwargs) - - raise NotImplementedError("TODO: support {} models".format(framework)) + elif framework == "torch": + if default_model: + return default_model(obs_space, action_space, num_outputs, + model_config, name) + return ModelCatalog._get_default_torch_model_v2( + obs_space, action_space, num_outputs, model_config, name) + else: + raise NotImplementedError( + "Framework must be 'tf' or 'torch': {}".format(framework)) @staticmethod def _wrap_if_needed(model_cls, model_interface): @@ -367,46 +382,32 @@ class ModelCatalog(object): num_outputs, options=None, default_model_cls=None): - """Returns a custom model for PyTorch algorithms. + raise DeprecationWarning("Please use get_model_v2() instead.") - Args: - obs_space (Space): The input observation space. - num_outputs (int): The size of the output vector of the model. - options (dict): Optional args to pass to the model constructor. - default_model_cls (cls): Optional class to use if no custom model. - - Returns: - model (models.Model): Neural network model. - """ + def _get_default_torch_model_v2(obs_space, action_space, num_outputs, + model_config, name): from ray.rllib.models.torch.fcnet import (FullyConnectedNetwork as PyTorchFCNet) from ray.rllib.models.torch.visionnet import (VisionNetwork as PyTorchVisionNet) - options = options or MODEL_DEFAULTS + model_config = model_config or MODEL_DEFAULTS - if options.get("custom_model"): - model = options["custom_model"] - logger.debug("Using custom torch model {}".format(model)) - return _global_registry.get(RLLIB_MODEL, - model)(obs_space, num_outputs, options) - - if options.get("use_lstm"): + if model_config.get("use_lstm"): raise NotImplementedError( "LSTM auto-wrapping not implemented for torch") - if default_model_cls: - return default_model_cls(obs_space, num_outputs, options) - if isinstance(obs_space, gym.spaces.Discrete): obs_rank = 1 else: obs_rank = len(obs_space.shape) if obs_rank > 1: - return PyTorchVisionNet(obs_space, num_outputs, options) + return PyTorchVisionNet(obs_space, action_space, num_outputs, + model_config, name) - return PyTorchFCNet(obs_space, num_outputs, options) + return PyTorchFCNet(obs_space, action_space, num_outputs, model_config, + name) @staticmethod @DeveloperAPI diff --git a/python/ray/rllib/models/fcnet.py b/python/ray/rllib/models/fcnet.py index 76d054e3d..01469e659 100644 --- a/python/ray/rllib/models/fcnet.py +++ b/python/ray/rllib/models/fcnet.py @@ -10,6 +10,7 @@ from ray.rllib.utils import try_import_tf tf = try_import_tf() +# TODO(ekl) rewrite this using ModelV2 class FullyConnectedNetwork(Model): """Generic fully connected network.""" diff --git a/python/ray/rllib/models/model.py b/python/ray/rllib/models/model.py index a43b088cc..06a746184 100644 --- a/python/ray/rllib/models/model.py +++ b/python/ray/rllib/models/model.py @@ -14,10 +14,12 @@ from ray.rllib.utils import try_import_tf tf = try_import_tf() -@PublicAPI +# Deprecated: use TFModelV2 instead class Model(object): """Defines an abstract network model for use with RLlib. + This class is deprecated: please use TFModelV2 instead. + Models convert input tensors to a number of output features. These features can then be interpreted by ActionDistribution classes to determine e.g. agent action values. diff --git a/python/ray/rllib/models/modelv2.py b/python/ray/rllib/models/modelv2.py index a5162be37..77cab1e57 100644 --- a/python/ray/rllib/models/modelv2.py +++ b/python/ray/rllib/models/modelv2.py @@ -3,8 +3,10 @@ from __future__ import division from __future__ import print_function from ray.rllib.models.model import restore_original_dimensions +from ray.rllib.utils.annotations import PublicAPI +@PublicAPI class ModelV2(object): """Defines a Keras-style abstract network model for use with RLlib. @@ -39,7 +41,6 @@ class ModelV2(object): self.model_config = model_config self.name = name or "default_model" self.framework = framework - self.var_list = [] def get_initial_state(self): """Get the initial recurrent state values for the model. @@ -118,19 +119,7 @@ class ModelV2(object): """ return {} - def register_variables(self, variables): - """Register the given list of variables with this model.""" - self.var_list.extend(variables) - - def variables(self): - """Returns the list of variables for this model.""" - return list(self.var_list) - - def trainable_variables(self): - """Returns the list of trainable variables for this model.""" - return [v for v in self.variables() if v.trainable] - - def __call__(self, input_dict, state, seq_lens): + def __call__(self, input_dict, state=None, seq_lens=None): """Call the model with the given input tensors and state. This is the method used by RLlib to execute the forward pass. It calls @@ -156,7 +145,7 @@ class ModelV2(object): restored["obs"] = restore_original_dimensions( input_dict["obs"], self.obs_space, self.framework) restored["obs_flat"] = input_dict["obs"] - outputs, state = self.forward(restored, state, seq_lens) + outputs, state = self.forward(restored, state or [], seq_lens) try: shape = outputs.shape diff --git a/python/ray/rllib/models/tf/modelv1_compat.py b/python/ray/rllib/models/tf/modelv1_compat.py index b55a6cf98..4291ad1a4 100644 --- a/python/ray/rllib/models/tf/modelv1_compat.py +++ b/python/ray/rllib/models/tf/modelv1_compat.py @@ -98,7 +98,7 @@ def make_v1_wrapper(legacy_model_cls): "Cannot get update ops before wrapped v1 model init") return list(self._update_ops) - @override(ModelV2) + @override(TFModelV2) def variables(self): var_list = super(ModelV1Wrapper, self).variables() for v in scope_vars(self.variable_scope): diff --git a/python/ray/rllib/models/tf/tf_modelv2.py b/python/ray/rllib/models/tf/tf_modelv2.py index 26fb117d6..9a7b8cfa2 100644 --- a/python/ray/rllib/models/tf/tf_modelv2.py +++ b/python/ray/rllib/models/tf/tf_modelv2.py @@ -3,13 +3,18 @@ from __future__ import division from __future__ import print_function from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.utils.annotations import PublicAPI from ray.rllib.utils import try_import_tf tf = try_import_tf() +@PublicAPI class TFModelV2(ModelV2): - """TF version of ModelV2.""" + """TF version of ModelV2. + + Note that this class by itself is not a valid model unless you + implement forward() in a subclass.""" def __init__(self, obs_space, action_space, num_outputs, model_config, name): @@ -21,9 +26,22 @@ class TFModelV2(ModelV2): model_config, name, framework="tf") + self.var_list = [] def update_ops(self): """Return the list of update ops for this model. For example, this should include any BatchNorm update ops.""" return [] + + def register_variables(self, variables): + """Register the given list of variables with this model.""" + self.var_list.extend(variables) + + def variables(self): + """Returns the list of variables for this model.""" + return list(self.var_list) + + def trainable_variables(self): + """Returns the list of trainable variables for this model.""" + return [v for v in self.variables() if v.trainable] diff --git a/python/ray/rllib/models/torch/fcnet.py b/python/ray/rllib/models/torch/fcnet.py index 68957ba11..efdfd5d77 100644 --- a/python/ray/rllib/models/torch/fcnet.py +++ b/python/ray/rllib/models/torch/fcnet.py @@ -6,7 +6,7 @@ import logging import numpy as np import torch.nn as nn -from ray.rllib.models.torch.model import TorchModel +from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.models.torch.misc import normc_initializer, SlimFC, \ _get_activation_fn from ray.rllib.utils.annotations import override @@ -14,13 +14,16 @@ from ray.rllib.utils.annotations import override logger = logging.getLogger(__name__) -class FullyConnectedNetwork(TorchModel): +class FullyConnectedNetwork(TorchModelV2): """Generic fully connected network.""" - def __init__(self, obs_space, num_outputs, options): - TorchModel.__init__(self, obs_space, num_outputs, options) - hiddens = options.get("fcnet_hiddens") - activation = _get_activation_fn(options.get("fcnet_activation")) + def __init__(self, obs_space, action_space, num_outputs, model_config, + name): + super(FullyConnectedNetwork, self).__init__( + obs_space, action_space, num_outputs, model_config, name) + + hiddens = model_config.get("fcnet_hiddens") + activation = _get_activation_fn(model_config.get("fcnet_activation")) logger.debug("Constructing fcnet {} {}".format(hiddens, activation)) layers = [] last_layer_size = np.product(obs_space.shape) @@ -45,13 +48,17 @@ class FullyConnectedNetwork(TorchModel): out_size=1, initializer=normc_initializer(1.0), activation_fn=None) + self._cur_value = None - @override(nn.Module) - def forward(self, input_dict, hidden_state): - # Note that we override forward() and not _forward() to get the - # flattened obs here. - obs = input_dict["obs"] + @override(TorchModelV2) + def forward(self, input_dict, state, seq_lens): + obs = input_dict["obs_flat"] features = self._hidden_layers(obs.reshape(obs.shape[0], -1)) logits = self._logits(features) - value = self._value_branch(features).squeeze(1) - return logits, features, value, hidden_state + self._cur_value = self._value_branch(features).squeeze(1) + return logits, state + + @override(TorchModelV2) + def value_function(self): + assert self._cur_value is not None, "must call forward() first" + return self._cur_value diff --git a/python/ray/rllib/models/torch/model.py b/python/ray/rllib/models/torch/model.py deleted file mode 100644 index e06e3f1bb..000000000 --- a/python/ray/rllib/models/torch/model.py +++ /dev/null @@ -1,65 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import torch -import torch.nn as nn - -from ray.rllib.models.model import restore_original_dimensions -from ray.rllib.utils.annotations import PublicAPI - - -# TODO(ekl) rewrite using modelv2 -@PublicAPI -class TorchModel(nn.Module): - """Defines an abstract network model for use with RLlib / PyTorch.""" - - def __init__(self, obs_space, num_outputs, options): - """All custom RLlib torch models must support this constructor. - - Arguments: - obs_space (gym.Space): Input observation space. - num_outputs (int): Output tensor must be of size - [BATCH_SIZE, num_outputs]. - options (dict): Dictionary of model options. - """ - nn.Module.__init__(self) - self.obs_space = obs_space - self.num_outputs = num_outputs - self.options = options - - @PublicAPI - def forward(self, input_dict, hidden_state): - """Wraps _forward() to unpack flattened Dict and Tuple observations.""" - input_dict["obs"] = input_dict["obs"].float() # TODO(ekl): avoid cast - input_dict["obs_flat"] = input_dict["obs"] - input_dict["obs"] = restore_original_dimensions( - input_dict["obs"], self.obs_space, tensorlib=torch) - outputs, features, vf, h = self._forward(input_dict, hidden_state) - return outputs, features, vf, h - - @PublicAPI - def state_init(self): - """Returns a list of initial hidden state tensors, if any.""" - return [] - - @PublicAPI - def _forward(self, input_dict, hidden_state): - """Forward pass for the model. - - Prefer implementing this instead of forward() directly for proper - handling of Dict and Tuple observations. - - Arguments: - input_dict (dict): Dictionary of tensor inputs, commonly - including "obs", "prev_action", "prev_reward", each of shape - [BATCH_SIZE, ...]. - hidden_state (list): List of hidden state tensors, each of shape - [BATCH_SIZE, h_size]. - - Returns: - (outputs, feature_layer, values, state): Tensors of size - [BATCH_SIZE, num_outputs], [BATCH_SIZE, desired_feature_size], - [BATCH_SIZE], and [len(hidden_state), BATCH_SIZE, h_size]. - """ - raise NotImplementedError diff --git a/python/ray/rllib/models/torch/torch_modelv2.py b/python/ray/rllib/models/torch/torch_modelv2.py index 901cf91cf..f02f0ffc5 100644 --- a/python/ray/rllib/models/torch/torch_modelv2.py +++ b/python/ray/rllib/models/torch/torch_modelv2.py @@ -2,19 +2,27 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import torch.nn as nn + from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.utils.annotations import PublicAPI -class TorchModelV2(ModelV2): - """Torch version of ModelV2.""" +@PublicAPI +class TorchModelV2(ModelV2, nn.Module): + """Torch version of ModelV2. - def __init__(self, obs_space, action_space, output_spec, model_config, + Note that this class by itself is not a valid model unless you + implement forward() in a subclass.""" + + def __init__(self, obs_space, action_space, num_outputs, model_config, name): ModelV2.__init__( self, obs_space, action_space, - output_spec, + num_outputs, model_config, name, framework="torch") + nn.Module.__init__(self) diff --git a/python/ray/rllib/models/torch/visionnet.py b/python/ray/rllib/models/torch/visionnet.py index 9851b91ab..64efc10a3 100644 --- a/python/ray/rllib/models/torch/visionnet.py +++ b/python/ray/rllib/models/torch/visionnet.py @@ -4,19 +4,22 @@ from __future__ import print_function import torch.nn as nn -from ray.rllib.models.torch.model import TorchModel +from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.models.torch.misc import normc_initializer, valid_padding, \ SlimConv2d, SlimFC from ray.rllib.models.visionnet import _get_filter_config from ray.rllib.utils.annotations import override -class VisionNetwork(TorchModel): +class VisionNetwork(TorchModelV2): """Generic vision network.""" - def __init__(self, obs_space, num_outputs, options): - TorchModel.__init__(self, obs_space, num_outputs, options) - filters = options.get("conv_filters") + def __init__(self, obs_space, action_space, num_outputs, model_config, + name): + super(VisionNetwork, self).__init__(obs_space, action_space, + num_outputs, model_config, name) + + filters = model_config.get("conv_filters") if not filters: filters = _get_filter_config(obs_space.shape) layers = [] @@ -40,13 +43,19 @@ class VisionNetwork(TorchModel): out_channels, num_outputs, initializer=nn.init.xavier_uniform_) self._value_branch = SlimFC( out_channels, 1, initializer=normc_initializer()) + self._cur_value = None - @override(TorchModel) - def _forward(self, input_dict, hidden_state): - features = self._hidden_layers(input_dict["obs"]) + @override(TorchModelV2) + def forward(self, input_dict, state, seq_lens): + features = self._hidden_layers(input_dict["obs"].float()) logits = self._logits(features) - value = self._value_branch(features).squeeze(1) - return logits, features, value, hidden_state + self._cur_value = self._value_branch(features).squeeze(1) + return logits, state + + @override(TorchModelV2) + def value_function(self): + assert self._cur_value is not None, "must call forward() first" + return self._cur_value def _hidden_layers(self, obs): res = self._convs(obs.permute(0, 3, 1, 2)) # switch to channel-major diff --git a/python/ray/rllib/models/visionnet.py b/python/ray/rllib/models/visionnet.py index fabaf4c03..ff64b4a4b 100644 --- a/python/ray/rllib/models/visionnet.py +++ b/python/ray/rllib/models/visionnet.py @@ -10,6 +10,7 @@ from ray.rllib.utils import try_import_tf tf = try_import_tf() +# TODO(ekl) rewrite this using ModelV2 class VisionNetwork(Model): """Generic vision network.""" diff --git a/python/ray/rllib/policy/torch_policy.py b/python/ray/rllib/policy/torch_policy.py index 045902621..16eeb6a0a 100644 --- a/python/ray/rllib/policy/torch_policy.py +++ b/python/ray/rllib/policy/torch_policy.py @@ -76,14 +76,14 @@ class TorchPolicy(Policy): input_dict["prev_actions"] = prev_action_batch if prev_reward_batch: input_dict["prev_rewards"] = prev_reward_batch - model_out = self._model(input_dict, state_batches) - logits, _, vf, state = model_out + model_out = self._model(input_dict, state_batches, [1]) + logits, state = model_out action_dist = self._action_dist_cls(logits) actions = action_dist.sample() return (actions.cpu().numpy(), [h.cpu().numpy() for h in state], self.extra_action_out(input_dict, state_batches, - model_out)) + self._model)) @override(Policy) def learn_on_batch(self, postprocessed_batch): @@ -145,20 +145,20 @@ class TorchPolicy(Policy): @override(Policy) def get_initial_state(self): - return [s.numpy() for s in self._model.state_init()] + return [s.numpy() for s in self._model.get_initial_state()] def extra_grad_process(self): """Allow subclass to do extra processing on gradients and return processing info.""" return {} - def extra_action_out(self, input_dict, state_batches, model_out): + def extra_action_out(self, input_dict, state_batches, model): """Returns dict of extra info to include in experience batch. Arguments: input_dict (dict): Dict of model input tensors. state_batches (list): List of state tensors. - model_out (list): Outputs of the policy model module.""" + model (TorchModelV2): Reference to the model.""" return {} def extra_grad_info(self, batch_tensors): diff --git a/python/ray/rllib/policy/torch_policy_template.py b/python/ray/rllib/policy/torch_policy_template.py index f1b0c0c68..d3d7e6987 100644 --- a/python/ray/rllib/policy/torch_policy_template.py +++ b/python/ray/rllib/policy/torch_policy_template.py @@ -74,8 +74,12 @@ def build_torch_policy(name, else: self.dist_class, logit_dim = ModelCatalog.get_action_dist( action_space, self.config["model"], torch=True) - self.model = ModelCatalog.get_torch_model( - obs_space, logit_dim, self.config["model"]) + self.model = ModelCatalog.get_model_v2( + obs_space, + action_space, + logit_dim, + self.config["model"], + framework="torch") TorchPolicy.__init__(self, obs_space, action_space, self.model, loss_fn, self.dist_class) @@ -101,13 +105,13 @@ def build_torch_policy(name, return TorchPolicy.extra_grad_process(self) @override(TorchPolicy) - def extra_action_out(self, input_dict, state_batches, model_out): + def extra_action_out(self, input_dict, state_batches, model): if extra_action_out_fn: return extra_action_out_fn(self, input_dict, state_batches, - model_out) + model) else: return TorchPolicy.extra_action_out(self, input_dict, - state_batches, model_out) + state_batches, model) @override(TorchPolicy) def optimizer(self): diff --git a/python/ray/rllib/tests/test_nested_spaces.py b/python/ray/rllib/tests/test_nested_spaces.py index 610fad2f1..548b2c2b4 100644 --- a/python/ray/rllib/tests/test_nested_spaces.py +++ b/python/ray/rllib/tests/test_nested_spaces.py @@ -19,7 +19,7 @@ from ray.rllib.env.vector_env import VectorEnv from ray.rllib.models import ModelCatalog from ray.rllib.models.model import Model from ray.rllib.models.torch.fcnet import FullyConnectedNetwork -from ray.rllib.models.torch.model import TorchModel +from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.rollout import rollout from ray.rllib.tests.test_external_env import SimpleServing from ray.tune.registry import register_env @@ -133,16 +133,18 @@ class InvalidModel2(Model): return tf.constant(0), tf.constant(0) -class TorchSpyModel(TorchModel): +class TorchSpyModel(TorchModelV2): capture_index = 0 - def __init__(self, obs_space, num_outputs, options): - TorchModel.__init__(self, obs_space, num_outputs, options) + def __init__(self, obs_space, action_space, num_outputs, model_config, + name): + super(TorchSpyModel, self).__init__(obs_space, action_space, + num_outputs, model_config, name) self.fc = FullyConnectedNetwork( obs_space.original_space.spaces["sensors"].spaces["position"], - num_outputs, options) + action_space, num_outputs, model_config, name) - def _forward(self, input_dict, hidden_state): + def forward(self, input_dict, state, seq_lens): pos = input_dict["obs"]["sensors"]["position"].numpy() front_cam = input_dict["obs"]["sensors"]["front_cam"][0].numpy() task = input_dict["obs"]["inner_state"]["job_status"]["task"].numpy() @@ -153,7 +155,10 @@ class TorchSpyModel(TorchModel): TorchSpyModel.capture_index += 1 return self.fc({ "obs": input_dict["obs"]["sensors"]["position"] - }, hidden_state) + }, state, seq_lens) + + def value_function(self): + return self.fc.value_function() class DictSpyModel(Model):