From 5d5643e63375928b941e4e6f9d2def60cfe7a86f Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Fri, 7 Aug 2020 12:04:17 +0200 Subject: [PATCH] [RLlib] Add informative error message when bad Conv2D stack is used with fixed `num_outputs` (no flattening at end). (#9966) --- rllib/models/tf/visionnet.py | 24 ++++++++++++++++++------ rllib/models/torch/visionnet.py | 22 ++++++++++++++++------ 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/rllib/models/tf/visionnet.py b/rllib/models/tf/visionnet.py index 97f8bcf5d..86136515d 100644 --- a/rllib/models/tf/visionnet.py +++ b/rllib/models/tf/visionnet.py @@ -11,15 +11,17 @@ class VisionNetwork(TFModelV2): def __init__(self, obs_space, action_space, num_outputs, model_config, name): + if not model_config.get("conv_filters"): + model_config["conv_filters"] = _get_filter_config(obs_space.shape) + super(VisionNetwork, self).__init__(obs_space, action_space, num_outputs, model_config, name) - activation = get_activation_fn(model_config.get("conv_activation")) - filters = model_config.get("conv_filters") - if not filters: - filters = _get_filter_config(obs_space.shape) - no_final_linear = model_config.get("no_final_linear") - vf_share_layers = model_config.get("vf_share_layers") + activation = get_activation_fn( + self.model_config.get("conv_activation"), framework="tf") + filters = self.model_config["conv_filters"] + no_final_linear = self.model_config.get("no_final_linear") + vf_share_layers = self.model_config.get("vf_share_layers") inputs = tf.keras.layers.Input( shape=obs_space.shape, name="observations") @@ -73,6 +75,16 @@ class VisionNetwork(TFModelV2): padding="same", data_format="channels_last", name="conv_out")(last_layer) + + if conv_out.shape[1] != 1 or conv_out.shape[2] != 1: + raise ValueError( + "Given `conv_filters` ({}) do not result in a [B, 1, " + "1, {} (`num_outputs`)] shape (but in {})! Please " + "adjust your Conv2D stack such that the dims 1 and 2 " + "are both 1.".format( + self.model_config["conv_filters"], + self.num_outputs, list(conv_out.shape))) + # num_outputs not known -> Flatten, then set self.num_outputs # to the resulting number of nodes. else: diff --git a/rllib/models/torch/visionnet.py b/rllib/models/torch/visionnet.py index 31e262cc7..b9aff6a9e 100644 --- a/rllib/models/torch/visionnet.py +++ b/rllib/models/torch/visionnet.py @@ -15,17 +15,18 @@ class VisionNetwork(TorchModelV2, nn.Module): def __init__(self, obs_space, action_space, num_outputs, model_config, name): + if not model_config.get("conv_filters"): + model_config["conv_filters"] = _get_filter_config(obs_space.shape) + TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name) nn.Module.__init__(self) activation = get_activation_fn( - model_config.get("conv_activation"), framework="torch") - filters = model_config.get("conv_filters") - if not filters: - filters = _get_filter_config(obs_space.shape) - no_final_linear = model_config.get("no_final_linear") - vf_share_layers = model_config.get("vf_share_layers") + self.model_config.get("conv_activation"), framework="torch") + filters = self.model_config["conv_filters"] + no_final_linear = self.model_config.get("no_final_linear") + vf_share_layers = self.model_config.get("vf_share_layers") # Whether the last layer is the output of a Flattened (rather than # a n x (1,1) Conv2D). @@ -152,8 +153,17 @@ class VisionNetwork(TorchModelV2, nn.Module): if not self.last_layer_is_flattened: if self._logits: conv_out = self._logits(conv_out) + if conv_out.shape[2] != 1 or conv_out.shape[3] != 1: + raise ValueError( + "Given `conv_filters` ({}) do not result in a [B, {} " + "(`num_outputs`), 1, 1] shape (but in {})! Please adjust " + "your Conv2D stack such that the last 2 dims are both " + "1.".format( + self.model_config["conv_filters"], self.num_outputs, + list(conv_out.shape))) logits = conv_out.squeeze(3) logits = logits.squeeze(2) + return logits, state else: return conv_out, state