diff --git a/doc/source/rllib-models.rst b/doc/source/rllib-models.rst index 59678af7e..279256de4 100644 --- a/doc/source/rllib-models.rst +++ b/doc/source/rllib-models.rst @@ -453,7 +453,7 @@ with the remaining non-image (flat) inputs (the 1D Box and discrete/one-hot comp Take a look at this model example that does exactly that: -.. literalinclude:: ../../rllib/examples/models/cnn_plus_fc_concat_model.py +.. literalinclude:: ../../rllib/models/tf/complex_input_net.py :language: python :start-after: __sphinx_doc_begin__ :end-before: __sphinx_doc_end__ diff --git a/rllib/agents/sac/sac.py b/rllib/agents/sac/sac.py index 5c476248c..97d0f7d77 100644 --- a/rllib/agents/sac/sac.py +++ b/rllib/agents/sac/sac.py @@ -16,6 +16,7 @@ from ray.rllib.agents.trainer import with_common_config from ray.rllib.agents.dqn.dqn import GenericOffPolicyTrainer from ray.rllib.agents.sac.sac_tf_policy import SACTFPolicy from ray.rllib.policy.policy import Policy +from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning from ray.rllib.utils.typing import TrainerConfigDict logger = logging.getLogger(__name__) @@ -39,16 +40,37 @@ DEFAULT_CONFIG = with_common_config({ # Use a e.g. conv2D state preprocessing network before concatenating the # resulting (feature) vector with the action input for the input to # the Q-networks. - "use_state_preprocessor": False, - # Model options for the Q network(s). + "use_state_preprocessor": DEPRECATED_VALUE, + # Model options for the Q network(s). These will override MODEL_DEFAULTS. + # The `Q_model` dict is treated just as the top-level `model` dict in + # setting up the Q-network(s) (2 if twin_q=True). + # That means, you can do for different observation spaces: + # obs=Box(1D) -> Tuple(Box(1D) + Action) -> concat -> post_fcnet + # obs=Box(3D) -> Tuple(Box(3D) + Action) -> vision-net -> concat w/ action + # -> post_fcnet + # obs=Tuple(Box(1D), Box(3D)) -> Tuple(Box(1D), Box(3D), Action) + # -> vision-net -> concat w/ Box(1D) and action -> post_fcnet + # You can also have SAC use your custom_model as Q-model(s), by simply + # specifying the `custom_model` sub-key in below dict (just like you would + # do in the top-level `model` dict. "Q_model": { - "fcnet_activation": "relu", "fcnet_hiddens": [256, 256], + "fcnet_activation": "relu", + "post_fcnet_hiddens": [], + "post_fcnet_activation": None, + "custom_model": None, # Use this to define custom Q-model(s). + "custom_model_config": {}, }, - # Model options for the policy function. + # Model options for the policy function (see `Q_model` above for details). + # The difference to `Q_model` above is that no action concat'ing is + # performed before the post_fcnet stack. "policy_model": { - "fcnet_activation": "relu", "fcnet_hiddens": [256, 256], + "fcnet_activation": "relu", + "post_fcnet_hiddens": [], + "post_fcnet_activation": None, + "custom_model": None, # Use this to define a custom policy model. + "custom_model_config": {}, }, # Unsquash actions to the upper and lower bounds of env's action space. # Ignored for discrete action spaces. @@ -145,11 +167,10 @@ def validate_config(config: TrainerConfigDict) -> None: Raises: ValueError: In case something is wrong with the config. """ - if config["model"].get("custom_model"): - logger.warning( - "Setting use_state_preprocessor=True since a custom model " - "was specified.") - config["use_state_preprocessor"] = True + if config["use_state_preprocessor"] != DEPRECATED_VALUE: + deprecation_warning( + old="config['use_state_preprocessor']", error=False) + config["use_state_preprocessor"] = DEPRECATED_VALUE if config["grad_clip"] is not None and config["grad_clip"] <= 0.0: raise ValueError("`grad_clip` value must be > 0.0!") diff --git a/rllib/agents/sac/sac_tf_model.py b/rllib/agents/sac/sac_tf_model.py index 4c890385f..e2c56b521 100644 --- a/rllib/agents/sac/sac_tf_model.py +++ b/rllib/agents/sac/sac_tf_model.py @@ -1,9 +1,12 @@ import gym from gym.spaces import Box, Discrete import numpy as np -from typing import Optional, Tuple +from typing import Dict, List, Optional +from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.tf.tf_modelv2 import TFModelV2 +from ray.rllib.utils import force_list +from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.spaces.simplex import Simplex from ray.rllib.utils.typing import ModelConfigDict, TensorType @@ -14,14 +17,21 @@ tf1, tf, tfv = try_import_tf() class SACTFModel(TFModelV2): """Extension of the standard TFModelV2 for SAC. - Instances of this Model get created via wrapping this class around another - default- or custom model (inside - rllib/agents/sac/sac_tf_policy.py::build_sac_model). Doing so simply adds - this class' methods (`get_q_values`, etc..) to the wrapped model, such that - the wrapped model can be used by the SAC algorithm. + To customize, do one of the following: + - sub-class SACTFModel and override one or more of its methods. + - Use SAC's `Q_model` and `policy_model` keys to tweak the default model + behaviors (e.g. fcnet_hiddens, conv_filters, etc..). + - Use SAC's `Q_model->custom_model` and `policy_model->custom_model` keys + to specify your own custom Q-model(s) and policy-models, which will be + created within this SACTFModel (see `build_policy_model` and + `build_q_model`. + + Note: It is not recommended to override the `forward` method for SAC. This + would lead to shared weights (between policy and Q-nets), which will then + not be optimized by either of the critic- or actor-optimizers! Data flow: - `obs` -> forward() -> `model_out` + `obs` -> forward() (should stay a noop method!) -> `model_out` `model_out` -> get_policy_output() -> pi(actions|obs) `model_out`, `actions` -> get_q_values() -> Q(s, a) `model_out`, `actions` -> get_twin_q_values() -> Q_twin(s, a) @@ -33,20 +43,18 @@ class SACTFModel(TFModelV2): num_outputs: Optional[int], model_config: ModelConfigDict, name: str, - actor_hidden_activation: str = "relu", - actor_hiddens: Tuple[int] = (256, 256), - critic_hidden_activation: str = "relu", - critic_hiddens: Tuple[int] = (256, 256), + policy_model_config: ModelConfigDict = None, + q_model_config: ModelConfigDict = None, twin_q: bool = False, initial_alpha: float = 1.0, target_entropy: Optional[float] = None): """Initialize a SACTFModel instance. Args: - actor_hidden_activation (str): Activation for the actor network. - actor_hiddens (list): Hidden layers sizes for the actor network. - critic_hidden_activation (str): Activation for the critic network. - critic_hiddens (list): Hidden layers sizes for the critic network. + policy_model_config (ModelConfigDict): The config dict for the + policy network. + q_model_config (ModelConfigDict): The config dict for the + Q-network(s) (2 if twin_q=True). twin_q (bool): Build twin Q networks (Q-net and target) for more stable Q-learning. initial_alpha (float): The initial value for the to-be-optimized @@ -77,54 +85,15 @@ class SACTFModel(TFModelV2): action_outs = self.action_dim q_outs = 1 - self.model_out = tf.keras.layers.Input( - shape=(self.num_outputs, ), name="model_out") - self.action_model = tf.keras.Sequential([ - tf.keras.layers.Dense( - units=hidden, - activation=getattr(tf.nn, actor_hidden_activation, None), - name="action_{}".format(i + 1)) - for i, hidden in enumerate(actor_hiddens) - ] + [ - tf.keras.layers.Dense( - units=action_outs, activation=None, name="action_out") - ]) - self.shift_and_log_scale_diag = self.action_model(self.model_out) - - self.actions_input = None - if not self.discrete: - self.actions_input = tf.keras.layers.Input( - shape=(self.action_dim, ), name="actions") - - def build_q_net(name, observations, actions): - # For continuous actions: Feed obs and actions (concatenated) - # through the NN. For discrete actions, only obs. - q_net = tf.keras.Sequential(([ - tf.keras.layers.Concatenate(axis=1), - ] if not self.discrete else []) + [ - tf.keras.layers.Dense( - units=units, - activation=getattr(tf.nn, critic_hidden_activation, None), - name="{}_hidden_{}".format(name, i)) - for i, units in enumerate(critic_hiddens) - ] + [ - tf.keras.layers.Dense( - units=q_outs, activation=None, name="{}_out".format(name)) - ]) - - # TODO(hartikainen): Remove the unnecessary Model calls here - if self.discrete: - q_net = tf.keras.Model(observations, q_net(observations)) - else: - q_net = tf.keras.Model([observations, actions], - q_net([observations, actions])) - return q_net - - self.q_net = build_q_net("q", self.model_out, self.actions_input) + self.action_model = self.build_policy_model( + self.obs_space, action_outs, policy_model_config, "policy_model") + self.q_net = self.build_q_model(self.obs_space, self.action_space, + q_outs, q_model_config, "q") if twin_q: - self.twin_q_net = build_q_net("twin_q", self.model_out, - self.actions_input) + self.twin_q_net = self.build_q_model(self.obs_space, + self.action_space, q_outs, + q_model_config, "twin_q") else: self.twin_q_net = None @@ -143,6 +112,80 @@ class SACTFModel(TFModelV2): target_entropy = -np.prod(action_space.shape) self.target_entropy = target_entropy + @override(TFModelV2) + def forward(self, input_dict: Dict[str, TensorType], + state: List[TensorType], + seq_lens: TensorType) -> (TensorType, List[TensorType]): + """The common (Q-net and policy-net) forward pass. + + NOTE: It is not(!) recommended to override this method as it would + introduce a shared pre-network, which would be updated by both + actor- and critic optimizers. + """ + return input_dict["obs"], state + + def build_policy_model(self, obs_space, num_outputs, policy_model_config, + name): + """Builds the policy model used by this SAC. + + Override this method in a sub-class of SACTFModel to implement your + own policy net. Alternatively, simply set `custom_model` within the + top level SAC `policy_model` config key to make this default + implementation of `build_policy_model` use your custom policy network. + + Returns: + TFModelV2: The TFModelV2 policy sub-model. + """ + model = ModelCatalog.get_model_v2( + obs_space, + self.action_space, + num_outputs, + policy_model_config, + framework="tf", + name=name) + return model + + def build_q_model(self, obs_space, action_space, num_outputs, + q_model_config, name): + """Builds one of the (twin) Q-nets used by this SAC. + + Override this method in a sub-class of SACTFModel to implement your + own Q-nets. Alternatively, simply set `custom_model` within the + top level SAC `Q_model` config key to make this default implementation + of `build_q_model` use your custom Q-nets. + + Returns: + TFModelV2: The TFModelV2 Q-net sub-model. + """ + self.concat_obs_and_actions = False + if self.discrete: + input_space = obs_space + else: + orig_space = getattr(obs_space, "original_space", obs_space) + if isinstance(orig_space, Box) and len(orig_space.shape) == 1: + input_space = Box( + float("-inf"), + float("inf"), + shape=(orig_space.shape[0] + action_space.shape[0], )) + self.concat_obs_and_actions = True + else: + if isinstance(orig_space, gym.spaces.Tuple): + spaces = orig_space.spaces + elif isinstance(orig_space, gym.spaces.Dict): + spaces = list(orig_space.spaces.values()) + else: + spaces = [obs_space] + input_space = gym.spaces.Tuple(spaces + [action_space]) + + model = ModelCatalog.get_model_v2( + input_space, + action_space, + num_outputs, + q_model_config, + framework="tf", + name=name) + return model + def get_q_values(self, model_out: TensorType, actions: Optional[TensorType] = None) -> TensorType: @@ -161,12 +204,7 @@ class SACTFModel(TFModelV2): Returns: TensorType: Q-values tensor of shape [BATCH_SIZE, 1]. """ - # Continuous case -> concat actions to model_out. - if actions is not None: - return self.q_net([model_out, actions]) - # Discrete case -> return q-vals for all actions. - else: - return self.q_net(model_out) + return self._get_q_value(model_out, actions, self.q_net) def get_twin_q_values(self, model_out: TensorType, @@ -185,12 +223,32 @@ class SACTFModel(TFModelV2): Returns: TensorType: Q-values tensor of shape [BATCH_SIZE, 1]. """ + return self._get_q_value(model_out, actions, self.twin_q_net) + + def _get_q_value(self, model_out, actions, net): + # Model outs may come as original Tuple/Dict observations, concat them + # here if this is the case. + if isinstance(net.obs_space, Box): + if isinstance(model_out, (list, tuple)): + model_out = tf.concat(model_out, axis=-1) + elif isinstance(model_out, dict): + model_out = list(model_out.values()) + # Continuous case -> concat actions to model_out. if actions is not None: - return self.twin_q_net([model_out, actions]) + if self.concat_obs_and_actions: + input_dict = {"obs": tf.concat([model_out, actions], axis=-1)} + else: + input_dict = {"obs": force_list(model_out) + [actions]} # Discrete case -> return q-vals for all actions. else: - return self.twin_q_net(model_out) + input_dict = {"obs": model_out} + # Switch on training mode (when getting Q-values, we are usually in + # training). + input_dict["is_training"] = True + + out, _ = net(input_dict, [], None) + return out def get_policy_output(self, model_out: TensorType) -> TensorType: """Returns policy outputs, given the output of self.__call__(). @@ -207,15 +265,23 @@ class SACTFModel(TFModelV2): Returns: TensorType: Distribution inputs for sampling actions. """ - return self.action_model(model_out) + # Model outs may come as original Tuple observations, concat them + # here if this is the case. + if isinstance(self.action_model.obs_space, Box): + if isinstance(model_out, (list, tuple)): + model_out = tf.concat(model_out, axis=-1) + elif isinstance(model_out, dict): + model_out = tf.concat(list(model_out.values()), axis=-1) + out, _ = self.action_model({"obs": model_out}, [], None) + return out def policy_variables(self): """Return the list of variables for the policy net.""" - return list(self.action_model.variables) + return self.action_model.variables() def q_variables(self): """Return the list of variables for Q / twin Q nets.""" - return self.q_net.variables + (self.twin_q_net.variables - if self.twin_q_net else []) + return self.q_net.variables() + (self.twin_q_net.variables() + if self.twin_q_net else []) diff --git a/rllib/agents/sac/sac_tf_policy.py b/rllib/agents/sac/sac_tf_policy.py index 44ddbff1f..83fa076ed 100644 --- a/rllib/agents/sac/sac_tf_policy.py +++ b/rllib/agents/sac/sac_tf_policy.py @@ -6,6 +6,7 @@ import gym from gym.spaces import Box, Discrete from functools import partial import logging +import numpy as np from typing import Dict, List, Optional, Tuple, Type, Union import ray @@ -17,7 +18,7 @@ from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio, \ from ray.rllib.agents.sac.sac_tf_model import SACTFModel from ray.rllib.agents.sac.sac_torch_model import SACTorchModel from ray.rllib.evaluation.episode import MultiAgentEpisode -from ray.rllib.models import ModelCatalog +from ray.rllib.models import ModelCatalog, MODEL_DEFAULTS from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.tf.tf_action_dist import Beta, Categorical, \ DiagGaussian, Dirichlet, SquashedGaussian, TFActionDistribution @@ -55,40 +56,35 @@ def build_sac_model(policy: Policy, obs_space: gym.spaces.Space, `policy.target_model`. """ # With separate state-preprocessor (before obs+action concat). - if config["use_state_preprocessor"]: - num_outputs = 256 # Flatten last Conv2D to this many nodes. - # No separate state-preprocessor: concat obs+actions right away. - else: - num_outputs = 0 - # No state preprocessor: fcnet_hiddens should be empty. - if config["model"]["fcnet_hiddens"]: - logger.warning( - "When not using a state-preprocessor with SAC, `fcnet_hiddens`" - " will be set to an empty list! Any hidden layer sizes are " - "defined via `policy_model.fcnet_hiddens` and " - "`Q_model.fcnet_hiddens`.") - config["model"]["fcnet_hiddens"] = [] + num_outputs = int(np.product(obs_space.shape)) # Force-ignore any additionally provided hidden layer sizes. # Everything should be configured using SAC's "Q_model" and "policy_model" # settings. + policy_model_config = MODEL_DEFAULTS.copy() + policy_model_config.update(config["policy_model"]) + q_model_config = MODEL_DEFAULTS.copy() + q_model_config.update(config["Q_model"]) + + default_model_cls = SACTorchModel if config["framework"] == "torch" \ + else SACTFModel + model = ModelCatalog.get_model_v2( obs_space=obs_space, action_space=action_space, num_outputs=num_outputs, model_config=config["model"], framework=config["framework"], - model_interface=SACTorchModel - if config["framework"] == "torch" else SACTFModel, + default_model=default_model_cls, name="sac_model", - actor_hidden_activation=config["policy_model"]["fcnet_activation"], - actor_hiddens=config["policy_model"]["fcnet_hiddens"], - critic_hidden_activation=config["Q_model"]["fcnet_activation"], - critic_hiddens=config["Q_model"]["fcnet_hiddens"], + policy_model_config=policy_model_config, + q_model_config=q_model_config, twin_q=config["twin_q"], initial_alpha=config["initial_alpha"], target_entropy=config["target_entropy"]) + assert isinstance(model, default_model_cls) + # Create an exact copy of the model and store it in `policy.target_model`. # This will be used for tau-synched Q-target models that run behind the # actual Q-networks and are used for target q-value calculations in the @@ -99,17 +95,16 @@ def build_sac_model(policy: Policy, obs_space: gym.spaces.Space, num_outputs=num_outputs, model_config=config["model"], framework=config["framework"], - model_interface=SACTorchModel - if config["framework"] == "torch" else SACTFModel, + default_model=default_model_cls, name="target_sac_model", - actor_hidden_activation=config["policy_model"]["fcnet_activation"], - actor_hiddens=config["policy_model"]["fcnet_hiddens"], - critic_hidden_activation=config["Q_model"]["fcnet_activation"], - critic_hiddens=config["Q_model"]["fcnet_hiddens"], + policy_model_config=policy_model_config, + q_model_config=q_model_config, twin_q=config["twin_q"], initial_alpha=config["initial_alpha"], target_entropy=config["target_entropy"]) + assert isinstance(policy.target_model, default_model_cls) + return model @@ -198,14 +193,14 @@ def get_distribution_inputs_and_class( dist inputs, dist class, and a list of internal state outputs (in the RNN case). """ - # Get base-model output (w/o the SAC specific parts of the network). - model_out, state_out = model({ + # Get base-model (forward) output (this should be a noop call). + forward_out, state_out = model({ "obs": obs_batch, "is_training": policy._get_is_training_placeholder(), }, [], None) # Use the base output to get the policy outputs from the SAC model's # policy components. - distribution_inputs = model.get_policy_output(model_out) + distribution_inputs = model.get_policy_output(forward_out) # Get a distribution class to be used with the just calculated dist-inputs. action_dist_class = _get_dist_class(policy.config, policy.action_space) diff --git a/rllib/agents/sac/sac_torch_model.py b/rllib/agents/sac/sac_torch_model.py index 5f8b05980..f3fe34e23 100644 --- a/rllib/agents/sac/sac_torch_model.py +++ b/rllib/agents/sac/sac_torch_model.py @@ -1,11 +1,12 @@ import gym from gym.spaces import Box, Discrete import numpy as np -from typing import Optional, Tuple +from typing import Dict, List, Optional -from ray.rllib.models.torch.misc import SlimFC +from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 -from ray.rllib.models.utils import get_activation_fn +from ray.rllib.utils import force_list +from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.spaces.simplex import Simplex from ray.rllib.utils.typing import ModelConfigDict, TensorType @@ -16,14 +17,21 @@ torch, nn = try_import_torch() class SACTorchModel(TorchModelV2, nn.Module): """Extension of the standard TorchModelV2 for SAC. - Instances of this Model get created via wrapping this class around another - default- or custom model (inside - rllib/agents/sac/sac_torch_policy.py::build_sac_model). Doing so simply - adds this class' methods (`get_q_values`, etc..) to the wrapped model, such - that the wrapped model can be used by the SAC algorithm. + To customize, do one of the following: + - sub-class SACTorchModel and override one or more of its methods. + - Use SAC's `Q_model` and `policy_model` keys to tweak the default model + behaviors (e.g. fcnet_hiddens, conv_filters, etc..). + - Use SAC's `Q_model->custom_model` and `policy_model->custom_model` keys + to specify your own custom Q-model(s) and policy-models, which will be + created within this SACTFModel (see `build_policy_model` and + `build_q_model`. + + Note: It is not recommended to override the `forward` method for SAC. This + would lead to shared weights (between policy and Q-nets), which will then + not be optimized by either of the critic- or actor-optimizers! Data flow: - `obs` -> forward() -> `model_out` + `obs` -> forward() (should stay a noop method!) -> `model_out` `model_out` -> get_policy_output() -> pi(actions|obs) `model_out`, `actions` -> get_q_values() -> Q(s, a) `model_out`, `actions` -> get_twin_q_values() -> Q_twin(s, a) @@ -35,20 +43,18 @@ class SACTorchModel(TorchModelV2, nn.Module): num_outputs: Optional[int], model_config: ModelConfigDict, name: str, - actor_hidden_activation: str = "relu", - actor_hiddens: Tuple[int] = (256, 256), - critic_hidden_activation: str = "relu", - critic_hiddens: Tuple[int] = (256, 256), + policy_model_config: ModelConfigDict = None, + q_model_config: ModelConfigDict = None, twin_q: bool = False, initial_alpha: float = 1.0, target_entropy: Optional[float] = None): """Initializes a SACTorchModel instance. 7 Args: - actor_hidden_activation (str): Activation for the actor network. - actor_hiddens (list): Hidden layers sizes for the actor network. - critic_hidden_activation (str): Activation for the critic network. - critic_hiddens (list): Hidden layers sizes for the critic network. + policy_model_config (ModelConfigDict): The config dict for the + policy network. + q_model_config (ModelConfigDict): The config dict for the + Q-network(s) (2 if twin_q=True). twin_q (bool): Build twin Q networks (Q-net and target) for more stable Q-learning. initial_alpha (float): The initial value for the to-be-optimized @@ -69,74 +75,29 @@ class SACTorchModel(TorchModelV2, nn.Module): self.action_dim = action_space.n self.discrete = True action_outs = q_outs = self.action_dim - action_ins = None # No action inputs for the discrete case. elif isinstance(action_space, Box): self.action_dim = np.product(action_space.shape) self.discrete = False action_outs = 2 * self.action_dim - action_ins = self.action_dim q_outs = 1 else: assert isinstance(action_space, Simplex) self.action_dim = np.product(action_space.shape) self.discrete = False action_outs = self.action_dim - action_ins = self.action_dim q_outs = 1 # Build the policy network. - self.action_model = nn.Sequential() - ins = self.num_outputs - self.obs_ins = ins - activation = get_activation_fn( - actor_hidden_activation, framework="torch") - for i, n in enumerate(actor_hiddens): - self.action_model.add_module( - "action_{}".format(i), - SlimFC( - ins, - n, - initializer=torch.nn.init.xavier_uniform_, - activation_fn=activation)) - ins = n - self.action_model.add_module( - "action_out", - SlimFC( - ins, - action_outs, - initializer=torch.nn.init.xavier_uniform_, - activation_fn=None)) + self.action_model = self.build_policy_model( + self.obs_space, action_outs, policy_model_config, "policy_model") - # Build the Q-net(s), including target Q-net(s). - def build_q_net(name_): - activation = get_activation_fn( - critic_hidden_activation, framework="torch") - # For continuous actions: Feed obs and actions (concatenated) - # through the NN. For discrete actions, only obs. - q_net = nn.Sequential() - ins = self.obs_ins + (0 if self.discrete else action_ins) - for i, n in enumerate(critic_hiddens): - q_net.add_module( - "{}_hidden_{}".format(name_, i), - SlimFC( - ins, - n, - initializer=torch.nn.init.xavier_uniform_, - activation_fn=activation)) - ins = n - - q_net.add_module( - "{}_out".format(name_), - SlimFC( - ins, - q_outs, - initializer=torch.nn.init.xavier_uniform_, - activation_fn=None)) - return q_net - - self.q_net = build_q_net("q") + # Build the Q-network(s). + self.q_net = self.build_q_model(self.obs_space, self.action_space, + q_outs, q_model_config, "q") if twin_q: - self.twin_q_net = build_q_net("twin_q") + self.twin_q_net = self.build_q_model(self.obs_space, + self.action_space, q_outs, + q_model_config, "twin_q") else: self.twin_q_net = None @@ -157,6 +118,80 @@ class SACTorchModel(TorchModelV2, nn.Module): self.target_entropy = torch.tensor( data=[target_entropy], dtype=torch.float32, requires_grad=False) + @override(TorchModelV2) + def forward(self, input_dict: Dict[str, TensorType], + state: List[TensorType], + seq_lens: TensorType) -> (TensorType, List[TensorType]): + """The common (Q-net and policy-net) forward pass. + + NOTE: It is not(!) recommended to override this method as it would + introduce a shared pre-network, which would be updated by both + actor- and critic optimizers. + """ + return input_dict["obs"], state + + def build_policy_model(self, obs_space, num_outputs, policy_model_config, + name): + """Builds the policy model used by this SAC. + + Override this method in a sub-class of SACTFModel to implement your + own policy net. Alternatively, simply set `custom_model` within the + top level SAC `policy_model` config key to make this default + implementation of `build_policy_model` use your custom policy network. + + Returns: + TorchModelV2: The TorchModelV2 policy sub-model. + """ + model = ModelCatalog.get_model_v2( + obs_space, + self.action_space, + num_outputs, + policy_model_config, + framework="torch", + name=name) + return model + + def build_q_model(self, obs_space, action_space, num_outputs, + q_model_config, name): + """Builds one of the (twin) Q-nets used by this SAC. + + Override this method in a sub-class of SACTFModel to implement your + own Q-nets. Alternatively, simply set `custom_model` within the + top level SAC `Q_model` config key to make this default implementation + of `build_q_model` use your custom Q-nets. + + Returns: + TorchModelV2: The TorchModelV2 Q-net sub-model. + """ + self.concat_obs_and_actions = False + if self.discrete: + input_space = obs_space + else: + orig_space = getattr(obs_space, "original_space", obs_space) + if isinstance(orig_space, Box) and len(orig_space.shape) == 1: + input_space = Box( + float("-inf"), + float("inf"), + shape=(orig_space.shape[0] + action_space.shape[0], )) + self.concat_obs_and_actions = True + else: + if isinstance(orig_space, gym.spaces.Tuple): + spaces = orig_space.spaces + elif isinstance(orig_space, gym.spaces.Dict): + spaces = list(orig_space.spaces.values()) + else: + spaces = [obs_space] + input_space = gym.spaces.Tuple(spaces + [action_space]) + + model = ModelCatalog.get_model_v2( + input_space, + action_space, + num_outputs, + q_model_config, + framework="torch", + name=name) + return model + def get_q_values(self, model_out: TensorType, actions: Optional[TensorType] = None) -> TensorType: @@ -175,12 +210,7 @@ class SACTorchModel(TorchModelV2, nn.Module): Returns: TensorType: Q-values tensor of shape [BATCH_SIZE, 1]. """ - # Continuous case -> concat actions to model_out. - if actions is not None: - return self.q_net(torch.cat([model_out, actions], -1)) - # Discrete case -> return q-vals for all actions. - else: - return self.q_net(model_out) + return self._get_q_value(model_out, actions, self.q_net) def get_twin_q_values(self, model_out: TensorType, @@ -199,12 +229,32 @@ class SACTorchModel(TorchModelV2, nn.Module): Returns: TensorType: Q-values tensor of shape [BATCH_SIZE, 1]. """ + return self._get_q_value(model_out, actions, self.twin_q_net) + + def _get_q_value(self, model_out, actions, net): + # Model outs may come as original Tuple observations, concat them + # here if this is the case. + if isinstance(net.obs_space, Box): + if isinstance(model_out, (list, tuple)): + model_out = torch.cat(model_out, dim=-1) + elif isinstance(model_out, dict): + model_out = list(model_out.values()) + # Continuous case -> concat actions to model_out. if actions is not None: - return self.twin_q_net(torch.cat([model_out, actions], -1)) + if self.concat_obs_and_actions: + input_dict = {"obs": torch.cat([model_out, actions], dim=-1)} + else: + input_dict = {"obs": force_list(model_out) + [actions]} # Discrete case -> return q-vals for all actions. else: - return self.twin_q_net(model_out) + input_dict = {"obs": model_out} + # Switch on training mode (when getting Q-values, we are usually in + # training). + input_dict["is_training"] = True + + out, _ = net(input_dict, [], None) + return out def get_policy_output(self, model_out: TensorType) -> TensorType: """Returns policy outputs, given the output of self.__call__(). @@ -221,15 +271,23 @@ class SACTorchModel(TorchModelV2, nn.Module): Returns: TensorType: Distribution inputs for sampling actions. """ - return self.action_model(model_out) + # Model outs may come as original Tuple observations, concat them + # here if this is the case. + if isinstance(self.action_model.obs_space, Box): + if isinstance(model_out, (list, tuple)): + model_out = torch.cat(model_out, dim=-1) + elif isinstance(model_out, dict): + model_out = torch.cat(list(model_out.values()), dim=-1) + out, _ = self.action_model({"obs": model_out}, [], None) + return out def policy_variables(self): """Return the list of variables for the policy net.""" - return list(self.action_model.parameters()) + return self.action_model.variables() def q_variables(self): """Return the list of variables for Q / twin Q nets.""" - return list(self.q_net.parameters()) + \ - (list(self.twin_q_net.parameters()) if self.twin_q_net else []) + return self.q_net.variables() + (self.twin_q_net.variables() + if self.twin_q_net else []) diff --git a/rllib/agents/sac/tests/test_sac.py b/rllib/agents/sac/tests/test_sac.py index 6a84b19c7..1ec873709 100644 --- a/rllib/agents/sac/tests/test_sac.py +++ b/rllib/agents/sac/tests/test_sac.py @@ -1,5 +1,5 @@ from gym import Env -from gym.spaces import Box +from gym.spaces import Box, Discrete, Tuple import numpy as np import re import unittest @@ -9,6 +9,10 @@ import ray.rllib.agents.sac as sac from ray.rllib.agents.sac.sac_tf_policy import sac_actor_critic_loss as tf_loss from ray.rllib.agents.sac.sac_torch_policy import actor_critic_loss as \ loss_torch +from ray.rllib.examples.env.random_env import RandomEnv +from ray.rllib.examples.models.batch_norm_model import KerasBatchNormModel, \ + TorchBatchNormModel +from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.tf.tf_action_dist import Dirichlet from ray.rllib.models.torch.torch_action_dist import TorchDirichlet from ray.rllib.execution.replay_buffer import LocalReplayBuffer @@ -52,7 +56,7 @@ class SimpleEnv(Env): class TestSAC(unittest.TestCase): @classmethod def setUpClass(cls) -> None: - ray.init() + ray.init(local_mode=True) @classmethod def tearDownClass(cls) -> None: @@ -61,22 +65,46 @@ class TestSAC(unittest.TestCase): def test_sac_compilation(self): """Tests whether an SACTrainer can be built with all frameworks.""" config = sac.DEFAULT_CONFIG.copy() + config["Q_model"] = sac.DEFAULT_CONFIG["Q_model"].copy() config["num_workers"] = 0 # Run locally. config["twin_q"] = True - config["soft_horizon"] = True config["clip_actions"] = False config["normalize_actions"] = True config["learning_starts"] = 0 config["prioritized_replay"] = True + config["rollout_fragment_length"] = 10 + config["train_batch_size"] = 10 num_iterations = 1 - for _ in framework_iterator(config): + + ModelCatalog.register_custom_model("batch_norm", KerasBatchNormModel) + ModelCatalog.register_custom_model("batch_norm_torch", + TorchBatchNormModel) + + image_space = Box(-1.0, 1.0, shape=(84, 84, 3)) + simple_space = Box(-1.0, 1.0, shape=(3, )) + + for fw in framework_iterator(config): # Test for different env types (discrete w/ and w/o image, + cont). for env in [ - "Pendulum-v0", "MsPacmanNoFrameskip-v4", "CartPole-v0" + RandomEnv, + "MsPacmanNoFrameskip-v4", + "CartPole-v0", ]: print("Env={}".format(env)) - config["use_state_preprocessor"] = \ - env == "MsPacmanNoFrameskip-v4" + if env == RandomEnv: + config["env_config"] = { + "observation_space": Tuple( + [simple_space, + Discrete(2), image_space]), + "action_space": Box(-1.0, 1.0, shape=(1, )), + } + else: + config["env_config"] = {} + # Test making the Q-model a custom one for CartPole, otherwise, + # use the default model. + config["Q_model"]["custom_model"] = "batch_norm{}".format( + "_torch" + if fw == "torch" else "") if env == "CartPole-v0" else None trainer = sac.SACTrainer(config=config, env=env) for i in range(num_iterations): results = trainer.train() @@ -103,49 +131,56 @@ class TestSAC(unittest.TestCase): config["env_config"] = {"simplex_actions": True} map_ = { - # Normal net. - "default_policy/sequential/action_1/kernel": "action_model." - "action_0._model.0.weight", - "default_policy/sequential/action_1/bias": "action_model." - "action_0._model.0.bias", - "default_policy/sequential/action_out/kernel": "action_model." - "action_out._model.0.weight", - "default_policy/sequential/action_out/bias": "action_model." - "action_out._model.0.bias", - "default_policy/sequential_1/q_hidden_0/kernel": "q_net." - "q_hidden_0._model.0.weight", - "default_policy/sequential_1/q_hidden_0/bias": "q_net." - "q_hidden_0._model.0.bias", - "default_policy/sequential_1/q_out/kernel": "q_net." - "q_out._model.0.weight", - "default_policy/sequential_1/q_out/bias": "q_net." - "q_out._model.0.bias", - "default_policy/value_out/kernel": "_value_branch." + # Action net. + "default_policy/fc_1/kernel": "action_model._hidden_layers.0." "_model.0.weight", - "default_policy/value_out/bias": "_value_branch." + "default_policy/fc_1/bias": "action_model._hidden_layers.0." "_model.0.bias", + "default_policy/fc_out/kernel": "action_model." + "_logits._model.0.weight", + "default_policy/fc_out/bias": "action_model._logits._model.0.bias", + "default_policy/value_out/kernel": "action_model." + "_value_branch._model.0.weight", + "default_policy/value_out/bias": "action_model." + "_value_branch._model.0.bias", + # Q-net. + "default_policy/fc_1_1/kernel": "q_net." + "_hidden_layers.0._model.0.weight", + "default_policy/fc_1_1/bias": "q_net." + "_hidden_layers.0._model.0.bias", + "default_policy/fc_out_1/kernel": "q_net._logits._model.0.weight", + "default_policy/fc_out_1/bias": "q_net._logits._model.0.bias", + "default_policy/value_out_1/kernel": "q_net." + "_value_branch._model.0.weight", + "default_policy/value_out_1/bias": "q_net." + "_value_branch._model.0.bias", "default_policy/log_alpha": "log_alpha", - # Target net. - "default_policy/sequential_2/action_1/kernel": "action_model." - "action_0._model.0.weight", - "default_policy/sequential_2/action_1/bias": "action_model." - "action_0._model.0.bias", - "default_policy/sequential_2/action_out/kernel": "action_model." - "action_out._model.0.weight", - "default_policy/sequential_2/action_out/bias": "action_model." - "action_out._model.0.bias", - "default_policy/sequential_3/q_hidden_0/kernel": "q_net." - "q_hidden_0._model.0.weight", - "default_policy/sequential_3/q_hidden_0/bias": "q_net." - "q_hidden_0._model.0.bias", - "default_policy/sequential_3/q_out/kernel": "q_net." - "q_out._model.0.weight", - "default_policy/sequential_3/q_out/bias": "q_net." - "q_out._model.0.bias", - "default_policy/value_out_1/kernel": "_value_branch." - "_model.0.weight", - "default_policy/value_out_1/bias": "_value_branch." - "_model.0.bias", + # Target action-net. + "default_policy/fc_1_2/kernel": "action_model." + "_hidden_layers.0._model.0.weight", + "default_policy/fc_1_2/bias": "action_model." + "_hidden_layers.0._model.0.bias", + "default_policy/fc_out_2/kernel": "action_model." + "_logits._model.0.weight", + "default_policy/fc_out_2/bias": "action_model." + "_logits._model.0.bias", + "default_policy/value_out_2/kernel": "action_model." + "_value_branch._model.0.weight", + "default_policy/value_out_2/bias": "action_model." + "_value_branch._model.0.bias", + # Target Q-net + "default_policy/fc_1_3/kernel": "q_net." + "_hidden_layers.0._model.0.weight", + "default_policy/fc_1_3/bias": "q_net." + "_hidden_layers.0._model.0.bias", + "default_policy/fc_out_3/kernel": "q_net." + "_logits._model.0.weight", + "default_policy/fc_out_3/bias": "q_net." + "_logits._model.0.bias", + "default_policy/value_out_3/kernel": "q_net." + "_value_branch._model.0.weight", + "default_policy/value_out_3/bias": "q_net." + "_value_branch._model.0.bias", "default_policy/log_alpha_1": "log_alpha", } @@ -225,10 +260,12 @@ class TestSAC(unittest.TestCase): policy.td_error, policy.optimizer().compute_gradients( policy.critic_loss[0], - policy.model.q_variables()), + [v for v in policy.model.q_variables() if + "value_" not in v.name]), policy.optimizer().compute_gradients( policy.actor_loss, - policy.model.policy_variables()), + [v for v in policy.model.policy_variables() if + "value_" not in v.name]), policy.optimizer().compute_gradients( policy.alpha_loss, policy.model.log_alpha)], feed_dict=policy._get_loss_inputs_dict( @@ -261,8 +298,6 @@ class TestSAC(unittest.TestCase): a.backward() # `actor_loss` depends on Q-net vars (but these grads must # be ignored and overridden in critic_loss.backward!). - assert not any(v.grad is None - for v in policy.model.q_variables()) assert not all( torch.mean(v.grad) == 0 for v in policy.model.policy_variables()) @@ -273,45 +308,38 @@ class TestSAC(unittest.TestCase): # Compare with tf ones. torch_a_grads = [ v.grad for v in policy.model.policy_variables() + if v.grad is not None ] - for tf_g, torch_g in zip(tf_a_grads, torch_a_grads): - if tf_g.shape != torch_g.shape: - check(tf_g, np.transpose(torch_g.detach().cpu())) - else: - check(tf_g, torch_g) + check(tf_a_grads[2], + np.transpose(torch_a_grads[0].detach().cpu())) # Test critic gradients. policy.critic_optims[0].zero_grad() assert all( torch.mean(v.grad) == 0.0 - for v in policy.model.q_variables()) + for v in policy.model.q_variables() if v.grad is not None) assert all( torch.min(v.grad) == 0.0 - for v in policy.model.q_variables()) + for v in policy.model.q_variables() if v.grad is not None) assert policy.model.log_alpha.grad is None c[0].backward() assert not all( torch.mean(v.grad) == 0 - for v in policy.model.q_variables()) + for v in policy.model.q_variables() if v.grad is not None) assert not all( - torch.min(v.grad) == 0 for v in policy.model.q_variables()) + torch.min(v.grad) == 0 for v in policy.model.q_variables() + if v.grad is not None) assert policy.model.log_alpha.grad is None # Compare with tf ones. torch_c_grads = [v.grad for v in policy.model.q_variables()] - for tf_g, torch_g in zip(tf_c_grads, torch_c_grads): - if tf_g.shape != torch_g.shape: - check(tf_g, np.transpose(torch_g.detach().cpu())) - else: - check(tf_g, torch_g) + check(tf_c_grads[0], + np.transpose(torch_c_grads[2].detach().cpu())) # Compare (unchanged(!) actor grads) with tf ones. torch_a_grads = [ v.grad for v in policy.model.policy_variables() ] - for tf_g, torch_g in zip(tf_a_grads, torch_a_grads): - if tf_g.shape != torch_g.shape: - check(tf_g, np.transpose(torch_g.detach().cpu())) - else: - check(tf_g, torch_g) + check(tf_a_grads[2], + np.transpose(torch_a_grads[0].detach().cpu())) # Test alpha gradient. policy.alpha_optim.zero_grad() @@ -336,7 +364,7 @@ class TestSAC(unittest.TestCase): prev_fw_loss = (c, a, e, t) # Update weights from our batch (n times). - for update_iteration in range(10): + for update_iteration in range(5): print("train iteration {}".format(update_iteration)) if fw == "tf": in_ = self._get_batch_helper(obs_size, actions, batch_size) @@ -350,10 +378,9 @@ class TestSAC(unittest.TestCase): # Net must have changed. if tf_updated_weights: check( - updated_weights[ - "default_policy/sequential/action_1/kernel"], + updated_weights["default_policy/fc_1/kernel"], tf_updated_weights[-1][ - "default_policy/sequential/action_1/kernel"], + "default_policy/fc_1/kernel"], false=True) tf_updated_weights.append(updated_weights) @@ -367,7 +394,9 @@ class TestSAC(unittest.TestCase): buf._fake_batch = in_ trainer.train() # Compare updated model. - for tf_key in sorted(tf_weights.keys())[2:10]: + for tf_key in sorted(tf_weights.keys()): + if re.search("_[23]|alpha", tf_key): + continue tf_var = tf_weights[tf_key] torch_var = policy.model.state_dict()[map_[tf_key]] if tf_var.shape != torch_var.shape: @@ -381,7 +410,9 @@ class TestSAC(unittest.TestCase): check(policy.model.log_alpha, tf_weights["default_policy/log_alpha"]) # Compare target nets. - for tf_key in sorted(tf_weights.keys())[10:18]: + for tf_key in sorted(tf_weights.keys()): + if not re.search("_[23]", tf_key): + continue tf_var = tf_weights[tf_key] torch_var = policy.target_model.state_dict()[map_[ tf_key]] @@ -437,9 +468,9 @@ class TestSAC(unittest.TestCase): fc( relu( fc(model_out_t, - weights[ks[3]], - weights[ks[2]], - framework=fw)), weights[ks[5]], weights[ks[4]]), None) + weights[ks[1]], + weights[ks[0]], + framework=fw)), weights[ks[9]], weights[ks[8]]), None) policy_t = action_dist_t.deterministic_sample() log_pis_t = action_dist_t.logp(policy_t) if sess: @@ -452,9 +483,9 @@ class TestSAC(unittest.TestCase): fc( relu( fc(model_out_tp1, - weights[ks[3]], - weights[ks[2]], - framework=fw)), weights[ks[5]], weights[ks[4]]), None) + weights[ks[1]], + weights[ks[0]], + framework=fw)), weights[ks[9]], weights[ks[8]]), None) policy_tp1 = action_dist_tp1.deterministic_sample() log_pis_tp1 = action_dist_tp1.logp(policy_tp1) if sess: @@ -468,11 +499,11 @@ class TestSAC(unittest.TestCase): relu( fc(np.concatenate( [model_out_t, train_batch[SampleBatch.ACTIONS]], -1), - weights[ks[7]], - weights[ks[6]], + weights[ks[3]], + weights[ks[2]], framework=fw)), - weights[ks[9]], - weights[ks[8]], + weights[ks[11]], + weights[ks[10]], framework=fw) # Q-values for current policy in given current state. @@ -480,11 +511,11 @@ class TestSAC(unittest.TestCase): q_t_det_policy = fc( relu( fc(np.concatenate([model_out_t, policy_t], -1), - weights[ks[7]], - weights[ks[6]], + weights[ks[3]], + weights[ks[2]], framework=fw)), - weights[ks[9]], - weights[ks[8]], + weights[ks[11]], + weights[ks[10]], framework=fw) # Target q network evaluation. @@ -493,11 +524,11 @@ class TestSAC(unittest.TestCase): q_tp1 = fc( relu( fc(np.concatenate([target_model_out_tp1, policy_tp1], -1), - weights[ks[15]], - weights[ks[14]], + weights[ks[7]], + weights[ks[6]], framework=fw)), - weights[ks[17]], - weights[ks[16]], + weights[ks[15]], + weights[ks[14]], framework=fw) else: assert fw == "tfe" @@ -538,9 +569,9 @@ class TestSAC(unittest.TestCase): map_[k]: convert_to_torch_tensor( np.transpose(v) if re.search("kernel", k) else np.array([v]) if re.search("log_alpha", k) else v) - for k, v in weights_dict.items() - if re.search("(sequential(/|_1)|value_out/|log_alpha)", k) + for i, (k, v) in enumerate(weights_dict.items()) if i < 13 } + return model_dict def _translate_tfe_weights(self, weights_dict, map_): diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index d0770cdf7..39d4bef77 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -32,7 +32,7 @@ from ray.rllib.policy.torch_policy import TorchPolicy from ray.rllib.utils import merge_dicts from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.utils.debug import summarize -from ray.rllib.utils.deprecation import deprecation_warning +from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning from ray.rllib.utils.filter import get_filter, Filter from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.sgd import do_minibatch_sgd @@ -396,15 +396,22 @@ class RolloutWorker(ParallelIteratorWorker): if clip_rewards is None: clip_rewards = True - # framestacking via trajectory view API is enabled. - num_framestacks = model_config.get("num_framestacks", 0) - if not policy_config["_use_trajectory_view_api"]: - model_config["num_framestacks"] = num_framestacks = 0 - elif num_framestacks == "auto": - model_config["num_framestacks"] = num_framestacks = 4 - framestack_traj_view = num_framestacks > 1 # Deprecated way of framestacking is used. framestack = model_config.get("framestack") is True + # framestacking via trajectory view API is enabled. + num_framestacks = model_config.get("num_framestacks", 0) + + # No trajectory view API: No traj. view based framestacking. + if not policy_config["_use_trajectory_view_api"]: + model_config["num_framestacks"] = num_framestacks = 0 + # Trajectory view API is on and num_framestacks=auto: Only + # stack traj. view based if old `framestack=[invalid value]`. + elif num_framestacks == "auto": + if framestack == DEPRECATED_VALUE: + model_config["num_framestacks"] = num_framestacks = 4 + else: + model_config["num_framestacks"] = num_framestacks = 0 + framestack_traj_view = num_framestacks > 1 def wrap(env): env = wrap_deepmind( diff --git a/rllib/examples/models/cnn_plus_fc_concat_model.py b/rllib/examples/models/cnn_plus_fc_concat_model.py deleted file mode 100644 index 6f8e3d85e..000000000 --- a/rllib/examples/models/cnn_plus_fc_concat_model.py +++ /dev/null @@ -1,218 +0,0 @@ -from gym.spaces import Discrete, Tuple - -from ray.rllib.examples.models.impala_vision_nets import TorchImpalaVisionNet -from ray.rllib.models.catalog import ModelCatalog -from ray.rllib.models.modelv2 import ModelV2 -from ray.rllib.models.tf.misc import normc_initializer -from ray.rllib.models.tf.tf_modelv2 import TFModelV2 -from ray.rllib.models.torch.misc import normc_initializer as \ - torch_normc_initializer, SlimFC -from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 -from ray.rllib.models.utils import get_filter_config -from ray.rllib.utils.annotations import override -from ray.rllib.utils.framework import try_import_tf, try_import_torch - -tf1, tf, tfv = try_import_tf() -torch, nn = try_import_torch() - - -# __sphinx_doc_begin__ -class CNNPlusFCConcatModel(TFModelV2): - """TFModelV2 concat'ing CNN outputs to flat input(s), followed by FC(s). - - Note: This model should be used for complex (Dict or Tuple) observation - spaces that have one or more image components. - """ - - def __init__(self, obs_space, action_space, num_outputs, model_config, - name): - # TODO: (sven) Support Dicts as well. - assert isinstance(obs_space.original_space, (Tuple)), \ - "`obs_space.original_space` must be Tuple!" - - super().__init__(obs_space, action_space, num_outputs, model_config, - name) - - # Build the CNN(s) given obs_space's image components. - self.cnns = {} - concat_size = 0 - for i, component in enumerate(obs_space.original_space): - # Image space. - if len(component.shape) == 3: - config = { - "conv_filters": model_config.get( - "conv_filters", get_filter_config(component.shape)), - "conv_activation": model_config.get("conv_activation"), - } - cnn = ModelCatalog.get_model_v2( - component, - action_space, - num_outputs=None, - model_config=config, - framework="tf", - name="cnn_{}".format(i)) - concat_size += cnn.num_outputs - self.cnns[i] = cnn - # Discrete inputs -> One-hot encode. - elif isinstance(component, Discrete): - concat_size += component.n - # TODO: (sven) Multidiscrete (see e.g. our auto-LSTM wrappers). - # Everything else (1D Box). - else: - assert len(component.shape) == 1, \ - "Only input Box 1D or 3D spaces allowed!" - concat_size += component.shape[-1] - - self.logits_and_value_model = None - self._value_out = None - if num_outputs: - # Action-distribution head. - concat_layer = tf.keras.layers.Input((concat_size, )) - logits_layer = tf.keras.layers.Dense( - num_outputs, - activation=tf.keras.activations.linear, - name="logits")(concat_layer) - - # Create the value branch model. - value_layer = tf.keras.layers.Dense( - 1, - name="value_out", - activation=None, - kernel_initializer=normc_initializer(0.01))(concat_layer) - self.logits_and_value_model = tf.keras.models.Model( - concat_layer, [logits_layer, value_layer]) - else: - self.num_outputs = concat_size - - @override(ModelV2) - def forward(self, input_dict, state, seq_lens): - # Push image observations through our CNNs. - outs = [] - for i, component in enumerate(input_dict["obs"]): - if i in self.cnns: - cnn_out, _ = self.cnns[i]({"obs": component}) - outs.append(cnn_out) - else: - outs.append(component) - # Concat all outputs and the non-image inputs. - out = tf.concat(outs, axis=1) - if not self.logits_and_value_model: - return out, [] - - # Value branch. - logits, values = self.logits_and_value_model(out) - self._value_out = tf.reshape(values, [-1]) - return logits, [] - - @override(ModelV2) - def value_function(self): - return self._value_out - - -# __sphinx_doc_end__ - - -class TorchCNNPlusFCConcatModel(TorchModelV2, nn.Module): - """TorchModelV2 concat'ing CNN outputs to flat input(s), followed by FC(s). - - Note: This model should be used for complex (Dict or Tuple) observation - spaces that have one or more image components. - """ - - def __init__(self, obs_space, action_space, num_outputs, model_config, - name): - # TODO: (sven) Support Dicts as well. - assert isinstance(obs_space.original_space, (Tuple)), \ - "`obs_space.original_space` must be Tuple!" - - nn.Module.__init__(self) - TorchModelV2.__init__(self, obs_space, action_space, num_outputs, - model_config, name) - - # Atari type CNNs or IMPALA type CNNs (with residual layers)? - self.cnn_type = self.model_config["custom_model_config"].get( - "conv_type", "atari") - - # Build the CNN(s) given obs_space's image components. - self.cnns = {} - concat_size = 0 - for i, component in enumerate(obs_space.original_space): - # Image space. - if len(component.shape) == 3: - config = { - "conv_filters": model_config.get( - "conv_filters", get_filter_config(component.shape)), - "conv_activation": model_config.get("conv_activation"), - } - if self.cnn_type == "atari": - cnn = ModelCatalog.get_model_v2( - component, - action_space, - num_outputs=None, - model_config=config, - framework="torch", - name="cnn_{}".format(i)) - else: - cnn = TorchImpalaVisionNet( - component, - action_space, - num_outputs=None, - model_config=config, - name="cnn_{}".format(i)) - - concat_size += cnn.num_outputs - self.cnns[i] = cnn - self.add_module("cnn_{}".format(i), cnn) - # Discrete inputs -> One-hot encode. - elif isinstance(component, Discrete): - concat_size += component.n - # TODO: (sven) Multidiscrete (see e.g. our auto-LSTM wrappers). - # Everything else (1D Box). - else: - assert len(component.shape) == 1, \ - "Only input Box 1D or 3D spaces allowed!" - concat_size += component.shape[-1] - - self.logits_layer = None - self.value_layer = None - self._value_out = None - - if num_outputs: - # Action-distribution head. - self.logits_layer = SlimFC( - in_size=concat_size, - out_size=num_outputs, - activation_fn=None, - ) - # Create the value branch model. - self.value_layer = SlimFC( - in_size=concat_size, - out_size=1, - activation_fn=None, - initializer=torch_normc_initializer(0.01)) - else: - self.num_outputs = concat_size - - @override(ModelV2) - def forward(self, input_dict, state, seq_lens): - # Push image observations through our CNNs. - outs = [] - for i, component in enumerate(input_dict["obs"]): - if i in self.cnns: - cnn_out, _ = self.cnns[i]({"obs": component}) - outs.append(cnn_out) - else: - outs.append(component) - # Concat all outputs and the non-image inputs. - out = torch.cat(outs, dim=1) - if self.logits_layer is None: - return out, [] - - # Value branch. - logits, values = self.logits_layer(out), self.value_layer(out) - self._value_out = torch.reshape(values, [-1]) - return logits, [] - - @override(ModelV2) - def value_function(self): - return self._value_out diff --git a/rllib/models/catalog.py b/rllib/models/catalog.py index 66796d71f..74ddcbeab 100644 --- a/rllib/models/catalog.py +++ b/rllib/models/catalog.py @@ -19,7 +19,7 @@ from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \ TorchDeterministic, TorchDiagGaussian, \ TorchMultiActionDistribution, TorchMultiCategorical from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI -from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning +from ray.rllib.utils.deprecation import DEPRECATED_VALUE from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.spaces.simplex import Simplex @@ -56,6 +56,18 @@ MODEL_DEFAULTS: ModelConfigDict = { # "linear" (or None). "conv_activation": "relu", + # Some default models support a final FC stack of n Dense layers with given + # activation: + # - Complex observation spaces: Image components are fed through + # VisionNets, flat Boxes are left as-is, Discrete are one-hot'd, then + # everything is concated and pushed through this final FC stack. + # - VisionNets (CNNs), e.g. after the CNN stack, there may be + # additional Dense layers. + # - FullyConnectedNetworks will have this additional FCStack as well + # (that's why it's empty by default). + "post_fcnet_hiddens": [], + "post_fcnet_activation": "relu", + # For DiagGaussian action distributions, make the second half of the model # outputs floating bias variables instead of state-dependent. This only # has an effect is using the default fully connected net. @@ -688,17 +700,22 @@ class ModelCatalog: framework: str = "tf") -> Type[ModelV2]: VisionNet = None + ComplexNet = None if framework in ["tf2", "tf", "tfe"]: from ray.rllib.models.tf.fcnet import \ FullyConnectedNetwork as FCNet from ray.rllib.models.tf.visionnet import \ VisionNetwork as VisionNet + from ray.rllib.models.tf.complex_input_net import \ + ComplexInputNetwork as ComplexNet elif framework == "torch": from ray.rllib.models.torch.fcnet import (FullyConnectedNetwork as FCNet) from ray.rllib.models.torch.visionnet import (VisionNetwork as VisionNet) + from ray.rllib.models.torch.complex_input_net import \ + ComplexInputNetwork as ComplexNet elif framework == "jax": from ray.rllib.models.jax.fcnet import (FullyConnectedNetwork as FCNet) @@ -710,16 +727,29 @@ class ModelCatalog: # Discrete/1D obs-spaces or 2D obs space but traj. view framestacking # disabled. num_framestacks = model_config.get("num_framestacks", "auto") + + # Tuple space, where at least one sub-space is image. + # -> Complex input model. + space_to_check = input_space if not hasattr( + input_space, "original_space") else input_space.original_space + if isinstance(input_space, + Tuple) or (isinstance(space_to_check, Tuple) and any( + isinstance(s, Box) and len(s.shape) >= 2 + for s in space_to_check.spaces)): + return ComplexNet + + # Single, flattenable/one-hot-abe space -> Simple FCNet. if isinstance(input_space, (Discrete, MultiDiscrete)) or \ len(input_space.shape) == 1 or ( len(input_space.shape) == 2 and ( num_framestacks == "auto" or num_framestacks <= 1)): return FCNet - # Default Conv2D net. - else: - if framework == "jax": - raise NotImplementedError("No Conv2D default net for JAX yet!") - return VisionNet + + elif framework == "jax": + raise NotImplementedError("No non-FC default net for JAX yet!") + + # Last resort: Conv2D stack for single image spaces. + return VisionNet @staticmethod def _get_multi_action_distribution(dist_class, action_space, config, @@ -768,8 +798,8 @@ class ModelCatalog: "framework=jax so far!") if config.get("framestack") != DEPRECATED_VALUE: - deprecation_warning( - old="framestack", new="num_framestacks (int)", error=False) + # deprecation_warning( + # old="framestack", new="num_framestacks (int)", error=False) # If old behavior is desired, disable traj. view-style # framestacking. config["num_framestacks"] = 0 diff --git a/rllib/models/modelv2.py b/rllib/models/modelv2.py index 70ad50202..bd5ee1132 100644 --- a/rllib/models/modelv2.py +++ b/rllib/models/modelv2.py @@ -203,9 +203,13 @@ class ModelV2: restored = input_dict.copy() restored["obs"] = restore_original_dimensions( input_dict["obs"], self.obs_space, self.framework) - if len(input_dict["obs"].shape) > 2: - restored["obs_flat"] = flatten(input_dict["obs"], self.framework) - else: + try: + if len(input_dict["obs"].shape) > 2: + restored["obs_flat"] = flatten(input_dict["obs"], + self.framework) + else: + restored["obs_flat"] = input_dict["obs"] + except AttributeError: restored["obs_flat"] = input_dict["obs"] with self.context(): res = self.forward(restored, state or [], seq_lens) @@ -216,15 +220,6 @@ class ModelV2: "got {}".format(res)) outputs, state = res - try: - shape = outputs.shape - except AttributeError: - raise ValueError("Output is not a tensor: {}".format(outputs)) - else: - if len(shape) != 2 or int(shape[1]) != self.num_outputs: - raise ValueError( - "Expected output shape of [None, {}], got {}".format( - self.num_outputs, shape)) if not isinstance(state, list): raise ValueError("State output is not a list: {}".format(state)) @@ -418,15 +413,15 @@ def restore_original_dimensions(obs: TensorType, observation space. """ - if hasattr(obs_space, "original_space"): - if tensorlib == "tf": - tensorlib = tf - elif tensorlib == "torch": - assert torch is not None - tensorlib = torch - return _unpack_obs(obs, obs_space.original_space, tensorlib=tensorlib) - else: + if tensorlib == "tf": + tensorlib = tf + elif tensorlib == "torch": + assert torch is not None + tensorlib = torch + original_space = getattr(obs_space, "original_space", obs_space) + if original_space is obs_space: return obs + return _unpack_obs(obs, original_space, tensorlib=tensorlib) # Cache of preprocessors, for if the user is calling unpack obs often. @@ -490,7 +485,8 @@ def _unpack_obs(obs: TensorType, space: gym.Space, tensorlib.reshape(obs_slice, batch_dims + list(p.shape)), v, tensorlib=tensorlib) - elif isinstance(space, Repeated): + # Repeated space. + else: assert isinstance(prep, RepeatedValuesPreprocessor), prep child_size = prep.child_preprocessor.size # The list lengths are stored in the first slot of the flat obs. @@ -503,8 +499,6 @@ def _unpack_obs(obs: TensorType, space: gym.Space, with_repeat_dim, space.child_space, tensorlib=tensorlib) return RepeatedValues( u, lengths=lengths, max_len=prep._obs_space.max_len) - else: - assert False, space return u else: return obs diff --git a/rllib/models/tf/complex_input_net.py b/rllib/models/tf/complex_input_net.py new file mode 100644 index 000000000..8bc691e24 --- /dev/null +++ b/rllib/models/tf/complex_input_net.py @@ -0,0 +1,156 @@ +from gym.spaces import Box, Discrete, Tuple +import numpy as np + +from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.models.modelv2 import ModelV2, restore_original_dimensions +from ray.rllib.models.tf.misc import normc_initializer +from ray.rllib.models.tf.tf_modelv2 import TFModelV2 +from ray.rllib.models.utils import get_filter_config +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.tf_ops import one_hot + +tf1, tf, tfv = try_import_tf() + + +# __sphinx_doc_begin__ +class ComplexInputNetwork(TFModelV2): + """TFModelV2 concat'ing CNN outputs to flat input(s), followed by FC(s). + + Note: This model should be used for complex (Dict or Tuple) observation + spaces that have one or more image components. + + The data flow is as follows: + + `obs` (e.g. Tuple[img0, img1, discrete0]) -> `CNN0 + CNN1 + ONE-HOT` + `CNN0 + CNN1 + ONE-HOT` -> concat all flat outputs -> `out` + `out` -> (optional) FC-stack -> `out2` + `out2` -> action (logits) and vaulue heads. + """ + + def __init__(self, obs_space, action_space, num_outputs, model_config, + name): + # TODO: (sven) Support Dicts as well. + self.original_space = obs_space.original_space if \ + hasattr(obs_space, "original_space") else obs_space + assert isinstance(self.original_space, (Tuple)), \ + "`obs_space.original_space` must be Tuple!" + + super().__init__(self.original_space, action_space, num_outputs, + model_config, name) + + # Build the CNN(s) given obs_space's image components. + self.cnns = {} + self.one_hot = {} + self.flatten = {} + concat_size = 0 + for i, component in enumerate(self.original_space): + # Image space. + if len(component.shape) == 3: + config = { + "conv_filters": model_config.get( + "conv_filters", get_filter_config(component.shape)), + "conv_activation": model_config.get("conv_activation"), + "post_fcnet_hiddens": [], + } + cnn = ModelCatalog.get_model_v2( + component, + action_space, + num_outputs=None, + model_config=config, + framework="tf", + name="cnn_{}".format(i)) + concat_size += cnn.num_outputs + self.cnns[i] = cnn + # Discrete inputs -> One-hot encode. + elif isinstance(component, Discrete): + self.one_hot[i] = True + concat_size += component.n + # TODO: (sven) Multidiscrete (see e.g. our auto-LSTM wrappers). + # Everything else (1D Box). + else: + self.flatten[i] = int(np.product(component.shape)) + concat_size += self.flatten[i] + + # Optional post-concat FC-stack. + post_fc_stack_config = { + "fcnet_hiddens": model_config.get("post_fcnet_hiddens", []), + "fcnet_activation": model_config.get("post_fcnet_activation", + "relu") + } + self.post_fc_stack = ModelCatalog.get_model_v2( + Box(float("-inf"), + float("inf"), + shape=(concat_size, ), + dtype=np.float32), + self.action_space, + None, + post_fc_stack_config, + framework="tf", + name="post_fc_stack") + + # Actions and value heads. + self.logits_and_value_model = None + self._value_out = None + if num_outputs: + # Action-distribution head. + concat_layer = tf.keras.layers.Input( + (self.post_fc_stack.num_outputs, )) + logits_layer = tf.keras.layers.Dense( + num_outputs, + activation=tf.keras.activations.linear, + name="logits")(concat_layer) + + # Create the value branch model. + value_layer = tf.keras.layers.Dense( + 1, + name="value_out", + activation=None, + kernel_initializer=normc_initializer(0.01))(concat_layer) + self.logits_and_value_model = tf.keras.models.Model( + concat_layer, [logits_layer, value_layer]) + else: + self.num_outputs = self.post_fc_stack.num_outputs + + @override(ModelV2) + def forward(self, input_dict, state, seq_lens): + if SampleBatch.OBS in input_dict and "obs_flat" in input_dict: + orig_obs = input_dict[SampleBatch.OBS] + else: + orig_obs = restore_original_dimensions(input_dict[SampleBatch.OBS], + self.obs_space, "tf") + # Push image observations through our CNNs. + outs = [] + for i, component in enumerate(orig_obs): + if i in self.cnns: + cnn_out, _ = self.cnns[i]({SampleBatch.OBS: component}) + outs.append(cnn_out) + elif i in self.one_hot: + if component.dtype in [tf.int32, tf.int64, tf.uint8]: + outs.append( + one_hot(component, self.original_space.spaces[i])) + else: + outs.append(component) + else: + outs.append(tf.reshape(component, [-1, self.flatten[i]])) + # Concat all outputs and the non-image inputs. + out = tf.concat(outs, axis=1) + # Push through (optional) FC-stack (this may be an empty stack). + out, _ = self.post_fc_stack({SampleBatch.OBS: out}, [], None) + + # No logits/value branches. + if not self.logits_and_value_model: + return out, [] + + # Logits- and value branches. + logits, values = self.logits_and_value_model(out) + self._value_out = tf.reshape(values, [-1]) + return logits, [] + + @override(ModelV2) + def value_function(self): + return self._value_out + + +# __sphinx_doc_end__ diff --git a/rllib/models/tf/fcnet.py b/rllib/models/tf/fcnet.py index eea01014d..9b0e8c565 100644 --- a/rllib/models/tf/fcnet.py +++ b/rllib/models/tf/fcnet.py @@ -19,8 +19,12 @@ class FullyConnectedNetwork(TFModelV2): super(FullyConnectedNetwork, self).__init__( obs_space, action_space, num_outputs, model_config, name) - activation = get_activation_fn(model_config.get("fcnet_activation")) - hiddens = model_config.get("fcnet_hiddens", []) + hiddens = model_config.get("fcnet_hiddens", []) + \ + model_config.get("post_fcnet_hiddens", []) + activation = model_config.get("fcnet_activation") + if not model_config.get("fcnet_hiddens", []): + activation = model_config.get("post_fcnet_activation") + activation = get_activation_fn(activation) no_final_linear = model_config.get("no_final_linear") vf_share_layers = model_config.get("vf_share_layers") free_log_std = model_config.get("free_log_std") diff --git a/rllib/models/tf/tf_modelv2.py b/rllib/models/tf/tf_modelv2.py index 4394d3213..dfb850a33 100644 --- a/rllib/models/tf/tf_modelv2.py +++ b/rllib/models/tf/tf_modelv2.py @@ -107,7 +107,8 @@ class TFModelV2(ModelV2): if isinstance(struct, tf.keras.models.Model): ret = {} for var in struct.variables: - key = current_key + "." + re.sub("/", ".", var.name) + name = re.sub("/", ".", var.name) + key = current_key + "." + name ret[key] = var return ret # Other TFModelV2: Include its vars into ours. @@ -118,7 +119,7 @@ class TFModelV2(ModelV2): } # tf.Variable elif isinstance(struct, tf.Variable): - return {current_key + "." + struct.name: struct} + return {current_key: struct} # List/Tuple. elif isinstance(struct, (tuple, list)): ret = {} @@ -133,7 +134,7 @@ class TFModelV2(ModelV2): current_key += "_" ret = {} for key, value in struct.items(): - sub_vars = TFModelV2._find_sub_modules(current_key + key, + sub_vars = TFModelV2._find_sub_modules(current_key + str(key), value) ret.update(sub_vars) return ret diff --git a/rllib/models/tf/visionnet.py b/rllib/models/tf/visionnet.py index b83e867b6..955ac1e52 100644 --- a/rllib/models/tf/visionnet.py +++ b/rllib/models/tf/visionnet.py @@ -13,7 +13,17 @@ tf1, tf, tfv = try_import_tf() class VisionNetwork(TFModelV2): - """Generic vision network implemented in ModelV2 API.""" + """Generic vision network implemented in ModelV2 API. + + An additional post-conv fully connected stack can be added and configured + via the config keys: + `post_fcnet_hiddens`: Dense layer sizes after the Conv2D stack. + `post_fcnet_activation`: Activation function to use for this FC stack. + + Examples: + + + """ def __init__(self, obs_space: gym.spaces.Space, action_space: gym.spaces.Space, num_outputs: int, @@ -29,6 +39,12 @@ class VisionNetwork(TFModelV2): filters = self.model_config["conv_filters"] assert len(filters) > 0,\ "Must provide at least 1 entry in `conv_filters`!" + + # Post FC net config. + post_fcnet_hiddens = model_config.get("post_fcnet_hiddens", []) + post_fcnet_activation = get_activation_fn( + model_config.get("post_fcnet_activation"), framework="tf") + no_final_linear = self.model_config.get("no_final_linear") vf_share_layers = self.model_config.get("vf_share_layers") self.traj_view_framestacking = False @@ -62,17 +78,29 @@ class VisionNetwork(TFModelV2): out_size, kernel, stride = filters[-1] - # No final linear: Last layer is a Conv2D and uses num_outputs. + # No final linear: Last layer has activation function and exits with + # num_outputs nodes (this could be a 1x1 conv or a FC layer, depending + # on `post_fcnet_...` settings). if no_final_linear and num_outputs: last_layer = tf.keras.layers.Conv2D( - num_outputs, + out_size if post_fcnet_hiddens else num_outputs, kernel, strides=(stride, stride), activation=activation, padding="valid", data_format="channels_last", name="conv_out")(last_layer) - conv_out = last_layer + # Add (optional) post-fc-stack after last Conv2D layer. + layer_sizes = post_fcnet_hiddens[:-1] + ([num_outputs] + if post_fcnet_hiddens else + []) + for i, out_size in enumerate(layer_sizes): + last_layer = tf.keras.layers.Dense( + out_size, + name="post_fcnet_{}".format(i), + activation=post_fcnet_activation, + kernel_initializer=normc_initializer(1.0))(last_layer) + # Finish network normally (w/o overriding last layer size with # `num_outputs`), then add another linear one of size `num_outputs`. else: @@ -88,29 +116,56 @@ class VisionNetwork(TFModelV2): # num_outputs defined. Use that to create an exact # `num_output`-sized (1,1)-Conv2D. if num_outputs: - conv_out = tf.keras.layers.Conv2D( - num_outputs, [1, 1], - activation=None, - padding="same", - data_format="channels_last", - name="conv_out")(last_layer) + if post_fcnet_hiddens: + last_cnn = last_layer = tf.keras.layers.Conv2D( + post_fcnet_hiddens[0], [1, 1], + activation=post_fcnet_activation, + padding="same", + data_format="channels_last", + name="conv_out")(last_layer) + # Add (optional) post-fc-stack after last Conv2D layer. + for i, out_size in enumerate(post_fcnet_hiddens[1:] + + [num_outputs]): + last_layer = tf.keras.layers.Dense( + out_size, + name="post_fcnet_{}".format(i + 1), + activation=post_fcnet_activation + if i < len(post_fcnet_hiddens) - 1 else None, + kernel_initializer=normc_initializer(1.0))( + last_layer) + else: + last_cnn = last_layer = tf.keras.layers.Conv2D( + num_outputs, [1, 1], + activation=None, + padding="same", + data_format="channels_last", + name="conv_out")(last_layer) - if conv_out.shape[1] != 1 or conv_out.shape[2] != 1: + if last_cnn.shape[1] != 1 or last_cnn.shape[2] != 1: raise ValueError( "Given `conv_filters` ({}) do not result in a [B, 1, " "1, {} (`num_outputs`)] shape (but in {})! Please " "adjust your Conv2D stack such that the dims 1 and 2 " "are both 1.".format(self.model_config["conv_filters"], self.num_outputs, - list(conv_out.shape))) + list(last_cnn.shape))) # num_outputs not known -> Flatten, then set self.num_outputs # to the resulting number of nodes. else: self.last_layer_is_flattened = True - conv_out = tf.keras.layers.Flatten( + last_layer = tf.keras.layers.Flatten( data_format="channels_last")(last_layer) - self.num_outputs = conv_out.shape[1] + + # Add (optional) post-fc-stack after last Conv2D layer. + for i, out_size in enumerate(post_fcnet_hiddens): + last_layer = tf.keras.layers.Dense( + out_size, + name="post_fcnet_{}".format(i), + activation=post_fcnet_activation, + kernel_initializer=normc_initializer(1.0))(last_layer) + self.num_outputs = last_layer.shape[1] + logits_out = last_layer # Build the value layers if vf_share_layers: @@ -151,7 +206,7 @@ class VisionNetwork(TFModelV2): value_out = tf.keras.layers.Lambda( lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer) - self.base_model = tf.keras.Model(inputs, [conv_out, value_out]) + self.base_model = tf.keras.Model(inputs, [logits_out, value_out]) # Optional: framestacking obs/new_obs for Atari. if self.traj_view_framestacking: diff --git a/rllib/models/torch/complex_input_net.py b/rllib/models/torch/complex_input_net.py new file mode 100644 index 000000000..2b9601947 --- /dev/null +++ b/rllib/models/torch/complex_input_net.py @@ -0,0 +1,163 @@ +from gym.spaces import Box, Discrete, Tuple +import numpy as np + +# TODO (sven): add IMPALA-style option. +# from ray.rllib.examples.models.impala_vision_nets import TorchImpalaVisionNet +from ray.rllib.models.torch.misc import normc_initializer as \ + torch_normc_initializer, SlimFC +from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 +from ray.rllib.models.utils import get_filter_config +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.torch_ops import one_hot + +torch, nn = try_import_torch() + + +class ComplexInputNetwork(TorchModelV2, nn.Module): + """TorchModelV2 concat'ing CNN outputs to flat input(s), followed by FC(s). + + Note: This model should be used for complex (Dict or Tuple) observation + spaces that have one or more image components. + + The data flow is as follows: + + `obs` (e.g. Tuple[img0, img1, discrete0]) -> `CNN0 + CNN1 + ONE-HOT` + `CNN0 + CNN1 + ONE-HOT` -> concat all flat outputs -> `out` + `out` -> (optional) FC-stack -> `out2` + `out2` -> action (logits) and vaulue heads. + """ + + def __init__(self, obs_space, action_space, num_outputs, model_config, + name): + # TODO: (sven) Support Dicts as well. + self.original_space = obs_space.original_space if \ + hasattr(obs_space, "original_space") else obs_space + assert isinstance(self.original_space, (Tuple)), \ + "`obs_space.original_space` must be Tuple!" + + nn.Module.__init__(self) + TorchModelV2.__init__(self, self.original_space, action_space, + num_outputs, model_config, name) + + # Atari type CNNs or IMPALA type CNNs (with residual layers)? + # self.cnn_type = self.model_config["custom_model_config"].get( + # "conv_type", "atari") + + # Build the CNN(s) given obs_space's image components. + self.cnns = {} + self.one_hot = {} + self.flatten = {} + concat_size = 0 + for i, component in enumerate(self.original_space): + # Image space. + if len(component.shape) == 3: + config = { + "conv_filters": model_config.get( + "conv_filters", get_filter_config(component.shape)), + "conv_activation": model_config.get("conv_activation"), + "post_fcnet_hiddens": [], + } + # if self.cnn_type == "atari": + cnn = ModelCatalog.get_model_v2( + component, + action_space, + num_outputs=None, + model_config=config, + framework="torch", + name="cnn_{}".format(i)) + # TODO (sven): add IMPALA-style option. + # else: + # cnn = TorchImpalaVisionNet( + # component, + # action_space, + # num_outputs=None, + # model_config=config, + # name="cnn_{}".format(i)) + + concat_size += cnn.num_outputs + self.cnns[i] = cnn + self.add_module("cnn_{}".format(i), cnn) + # Discrete inputs -> One-hot encode. + elif isinstance(component, Discrete): + self.one_hot[i] = True + concat_size += component.n + # TODO: (sven) Multidiscrete (see e.g. our auto-LSTM wrappers). + # Everything else (1D Box). + else: + self.flatten[i] = int(np.product(component.shape)) + concat_size += self.flatten[i] + + # Optional post-concat FC-stack. + post_fc_stack_config = { + "fcnet_hiddens": model_config.get("post_fcnet_hiddens", []), + "fcnet_activation": model_config.get("post_fcnet_activation", + "relu") + } + self.post_fc_stack = ModelCatalog.get_model_v2( + Box(float("-inf"), + float("inf"), + shape=(concat_size, ), + dtype=np.float32), + self.action_space, + None, + post_fc_stack_config, + framework="torch", + name="post_fc_stack") + + # Actions and value heads. + self.logits_layer = None + self.value_layer = None + self._value_out = None + + if num_outputs: + # Action-distribution head. + self.logits_layer = SlimFC( + in_size=self.post_fc_stack.num_outputs, + out_size=num_outputs, + activation_fn=None, + ) + # Create the value branch model. + self.value_layer = SlimFC( + in_size=self.post_fc_stack.num_outputs, + out_size=1, + activation_fn=None, + initializer=torch_normc_initializer(0.01)) + else: + self.num_outputs = concat_size + + @override(ModelV2) + def forward(self, input_dict, state, seq_lens): + # Push image observations through our CNNs. + outs = [] + for i, component in enumerate(input_dict["obs"]): + if i in self.cnns: + cnn_out, _ = self.cnns[i]({"obs": component}) + outs.append(cnn_out) + elif i in self.one_hot: + if component.dtype in [torch.int32, torch.int64, torch.uint8]: + outs.append( + one_hot(component, self.original_space.spaces[i])) + else: + outs.append(component) + else: + outs.append(torch.reshape(component, [-1, self.flatten[i]])) + # Concat all outputs and the non-image inputs. + out = torch.cat(outs, dim=1) + # Push through (optional) FC-stack (this may be an empty stack). + out, _ = self.post_fc_stack({"obs": out}, [], None) + + # No logits/value branches. + if self.logits_layer is None: + return out, [] + + # Logits- and value branches. + logits, values = self.logits_layer(out), self.value_layer(out) + self._value_out = torch.reshape(values, [-1]) + return logits, [] + + @override(ModelV2) + def value_function(self): + return self._value_out diff --git a/rllib/models/torch/fcnet.py b/rllib/models/torch/fcnet.py index 58fbb6bc4..91b9c0e1d 100644 --- a/rllib/models/torch/fcnet.py +++ b/rllib/models/torch/fcnet.py @@ -24,8 +24,11 @@ class FullyConnectedNetwork(TorchModelV2, nn.Module): model_config, name) nn.Module.__init__(self) + hiddens = model_config.get("fcnet_hiddens", []) + \ + model_config.get("post_fcnet_hiddens", []) activation = model_config.get("fcnet_activation") - hiddens = model_config.get("fcnet_hiddens", []) + if not model_config.get("fcnet_hiddens", []): + activation = model_config.get("post_fcnet_activation") no_final_linear = model_config.get("no_final_linear") self.vf_share_layers = model_config.get("vf_share_layers") self.free_log_std = model_config.get("free_log_std") diff --git a/rllib/models/torch/visionnet.py b/rllib/models/torch/visionnet.py index cd6352acd..133c851f5 100644 --- a/rllib/models/torch/visionnet.py +++ b/rllib/models/torch/visionnet.py @@ -5,7 +5,7 @@ import gym from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.models.torch.misc import normc_initializer, same_padding, \ SlimConv2d, SlimFC -from ray.rllib.models.utils import get_filter_config +from ray.rllib.models.utils import get_activation_fn, get_filter_config from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.view_requirement import ViewRequirement from ray.rllib.utils.annotations import override @@ -33,6 +33,12 @@ class VisionNetwork(TorchModelV2, nn.Module): filters = self.model_config["conv_filters"] assert len(filters) > 0,\ "Must provide at least 1 entry in `conv_filters`!" + + # Post FC net config. + post_fcnet_hiddens = model_config.get("post_fcnet_hiddens", []) + post_fcnet_activation = get_activation_fn( + model_config.get("post_fcnet_activation"), framework="torch") + no_final_linear = self.model_config.get("no_final_linear") vf_share_layers = self.model_config.get("vf_share_layers") @@ -68,17 +74,33 @@ class VisionNetwork(TorchModelV2, nn.Module): out_channels, kernel, stride = filters[-1] - # No final linear: Last layer is a Conv2D and uses num_outputs. + # No final linear: Last layer has activation function and exits with + # num_outputs nodes (this could be a 1x1 conv or a FC layer, depending + # on `post_fcnet_...` settings). if no_final_linear and num_outputs: + out_channels = out_channels if post_fcnet_hiddens else num_outputs layers.append( SlimConv2d( in_channels, - num_outputs, + out_channels, kernel, stride, None, # padding=valid activation_fn=activation)) - out_channels = num_outputs + + # Add (optional) post-fc-stack after last Conv2D layer. + layer_sizes = post_fcnet_hiddens[:-1] + ([num_outputs] + if post_fcnet_hiddens else + []) + for i, out_size in enumerate(layer_sizes): + layers.append( + SlimFC( + in_size=out_channels, + out_size=out_size, + activation_fn=post_fcnet_activation, + initializer=normc_initializer(1.0))) + out_channels = out_size + # Finish network normally (w/o overriding last layer size with # `num_outputs`), then add another linear one of size `num_outputs`. else: @@ -99,12 +121,31 @@ class VisionNetwork(TorchModelV2, nn.Module): np.ceil((in_size[1] - kernel[1]) / stride) ] padding, _ = same_padding(in_size, [1, 1], [1, 1]) - self._logits = SlimConv2d( - out_channels, - num_outputs, [1, 1], - 1, - padding, - activation_fn=None) + if post_fcnet_hiddens: + layers.append(nn.Flatten()) + in_size = out_channels + # Add (optional) post-fc-stack after last Conv2D layer. + for i, out_size in enumerate(post_fcnet_hiddens + + [num_outputs]): + layers.append( + SlimFC( + in_size=in_size, + out_size=out_size, + activation_fn=post_fcnet_activation + if i < len(post_fcnet_hiddens) - 1 else None, + initializer=normc_initializer(1.0))) + in_size = out_size + # Last layer is logits layer. + self._logits = layers.pop() + + else: + self._logits = SlimConv2d( + out_channels, + num_outputs, [1, 1], + 1, + padding, + activation_fn=None) + # num_outputs not known -> Flatten, then set self.num_outputs # to the resulting number of nodes. else: @@ -196,16 +237,19 @@ class VisionNetwork(TorchModelV2, nn.Module): if not self.last_layer_is_flattened: if self._logits: conv_out = self._logits(conv_out) - if conv_out.shape[2] != 1 or conv_out.shape[3] != 1: - raise ValueError( - "Given `conv_filters` ({}) do not result in a [B, {} " - "(`num_outputs`), 1, 1] shape (but in {})! Please adjust " - "your Conv2D stack such that the last 2 dims are both " - "1.".format(self.model_config["conv_filters"], - self.num_outputs, list(conv_out.shape))) - logits = conv_out.squeeze(3) - logits = logits.squeeze(2) - + if len(conv_out.shape) == 4: + if conv_out.shape[2] != 1 or conv_out.shape[3] != 1: + raise ValueError( + "Given `conv_filters` ({}) do not result in a [B, {} " + "(`num_outputs`), 1, 1] shape (but in {})! Please " + "adjust your Conv2D stack such that the last 2 dims " + "are both 1.".format(self.model_config["conv_filters"], + self.num_outputs, + list(conv_out.shape))) + logits = conv_out.squeeze(3) + logits = logits.squeeze(2) + else: + logits = conv_out return logits, state else: return conv_out, state diff --git a/rllib/policy/tests/test_compute_log_likelihoods.py b/rllib/policy/tests/test_compute_log_likelihoods.py index b64eabd47..77c52d44b 100644 --- a/rllib/policy/tests/test_compute_log_likelihoods.py +++ b/rllib/policy/tests/test_compute_log_likelihoods.py @@ -177,8 +177,8 @@ class TestComputeLogLikelihood(unittest.TestCase): config, prev_a, continuous=True, - layer_key=("sequential/action", (2, 4), - ("action_model.action_0.", "action_model.action_out.")), + layer_key=("fc", (0, 2), ("action_model._hidden_layers.0.", + "action_model._logits.")), logp_func=logp_func) def test_sac_discr(self): @@ -188,12 +188,7 @@ class TestComputeLogLikelihood(unittest.TestCase): config["policy_model"]["fcnet_activation"] = "linear" prev_a = np.array(0) - do_test_log_likelihood( - sac.SACTrainer, - config, - prev_a, - layer_key=("sequential/action", (0, 2), - ("action_model.action_0.", "action_model.action_out."))) + do_test_log_likelihood(sac.SACTrainer, config, prev_a) if __name__ == "__main__": diff --git a/rllib/tests/run_regression_tests.py b/rllib/tests/run_regression_tests.py index 3f42147e4..cc2650425 100644 --- a/rllib/tests/run_regression_tests.py +++ b/rllib/tests/run_regression_tests.py @@ -37,6 +37,10 @@ parser.add_argument( "--yaml-dir", type=str, help="The directory in which to find all yamls to test.") +parser.add_argument( + "--local-mode", + action="store_true", + help="Run ray in local mode for easier debugging.") # Obsoleted arg, use --framework=torch instead. parser.add_argument( @@ -92,7 +96,7 @@ if __name__ == "__main__": passed = False for i in range(3): try: - ray.init(num_cpus=5) + ray.init(num_cpus=5, local_mode=args.local_mode) trials = run_experiments(experiments, resume=False, verbose=2) finally: ray.shutdown() diff --git a/rllib/tests/test_nested_observation_spaces.py b/rllib/tests/test_nested_observation_spaces.py index 1a10e8c71..e1aac7b42 100644 --- a/rllib/tests/test_nested_observation_spaces.py +++ b/rllib/tests/test_nested_observation_spaces.py @@ -333,7 +333,7 @@ class NestedSpacesTest(unittest.TestCase): def test_invalid_model2(self): ModelCatalog.register_custom_model("invalid2", InvalidModel2) self.assertRaisesRegexp( - ValueError, "Expected output shape of", + ValueError, "State output is not a list", lambda: PGTrainer( env="CartPole-v0", config={ "model": { diff --git a/rllib/tests/test_supported_spaces.py b/rllib/tests/test_supported_spaces.py index 39a7ebb93..40bba43b2 100644 --- a/rllib/tests/test_supported_spaces.py +++ b/rllib/tests/test_supported_spaces.py @@ -15,7 +15,7 @@ from ray.rllib.utils.test_utils import framework_iterator ACTION_SPACES_TO_TEST = { "discrete": Discrete(5), "vector": Box(-1.0, 1.0, (5, ), dtype=np.float32), - # "vector2": Box(-1.0, 1.0, (5, 5), dtype=np.float32), + "vector2": Box(-1.0, 1.0, (5, 5), dtype=np.float32), "multidiscrete": MultiDiscrete([1, 2, 3, 4]), "tuple": Tuple( [Discrete(2), @@ -63,8 +63,6 @@ def check_support(alg, config, train=True, check_bounds=False, tfe=False): p_done=1.0, check_action_bounds=check_bounds))) stat = "ok" - if alg == "SAC": - config["use_state_preprocessor"] = o_name in ["atari", "image"] try: a = get_agent_class(alg)(config=config, env=RandomEnv) diff --git a/rllib/tuned_examples/sac/atari-sac.yaml b/rllib/tuned_examples/sac/atari-sac.yaml index 28c6d26db..4efca8620 100644 --- a/rllib/tuned_examples/sac/atari-sac.yaml +++ b/rllib/tuned_examples/sac/atari-sac.yaml @@ -14,8 +14,6 @@ atari-sac-tf-and-torch: framework: grid_search: [tf, torch] gamma: 0.99 - # state-preprocessor=Our default Atari Conv2D-net. - use_state_preprocessor: true Q_model: hidden_activation: relu hidden_layer_sizes: [512] diff --git a/rllib/tuned_examples/sac/mspacman-sac.yaml b/rllib/tuned_examples/sac/mspacman-sac.yaml index 50883b114..9d563884b 100644 --- a/rllib/tuned_examples/sac/mspacman-sac.yaml +++ b/rllib/tuned_examples/sac/mspacman-sac.yaml @@ -11,8 +11,6 @@ mspacman-sac-tf: # Works for both torch and tf. framework: tf gamma: 0.99 - # state-preprocessor=Our default Atari Conv2D-net. - use_state_preprocessor: true Q_model: fcnet_hiddens: [512] fcnet_activation: relu diff --git a/rllib/utils/test_utils.py b/rllib/utils/test_utils.py index eda9d1cfa..89a402117 100644 --- a/rllib/utils/test_utils.py +++ b/rllib/utils/test_utils.py @@ -301,13 +301,10 @@ def check_compute_single_action(trainer, assert worker_set if isinstance(worker_set, list): obs_space = trainer.get_policy().observation_space - try: - obs_space = obs_space.original_space - except AttributeError: - pass else: obs_space = worker_set.local_worker().for_policy( lambda p: p.observation_space) + obs_space = getattr(obs_space, "original_space", obs_space) else: method_to_test = pol.compute_single_action obs_space = pol.observation_space diff --git a/rllib/utils/threading.py b/rllib/utils/threading.py index 7361dad65..adc7dfe10 100644 --- a/rllib/utils/threading.py +++ b/rllib/utils/threading.py @@ -22,6 +22,6 @@ def with_lock(func: Callable): except AttributeError: raise AttributeError( "Object {} must have a `self._lock` property (assigned to a " - "threading.Lock() object in its constructor)!".format(self)) + "threading.RLock() object in its constructor)!".format(self)) return wrapper