mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 05:45:24 +08:00
[RLlib] Add informative error message when bad Conv2D stack is used with fixed num_outputs (no flattening at end). (#9966)
This commit is contained in:
@@ -11,15 +11,17 @@ class VisionNetwork(TFModelV2):
|
||||
|
||||
def __init__(self, obs_space, action_space, num_outputs, model_config,
|
||||
name):
|
||||
if not model_config.get("conv_filters"):
|
||||
model_config["conv_filters"] = _get_filter_config(obs_space.shape)
|
||||
|
||||
super(VisionNetwork, self).__init__(obs_space, action_space,
|
||||
num_outputs, model_config, name)
|
||||
|
||||
activation = get_activation_fn(model_config.get("conv_activation"))
|
||||
filters = model_config.get("conv_filters")
|
||||
if not filters:
|
||||
filters = _get_filter_config(obs_space.shape)
|
||||
no_final_linear = model_config.get("no_final_linear")
|
||||
vf_share_layers = model_config.get("vf_share_layers")
|
||||
activation = get_activation_fn(
|
||||
self.model_config.get("conv_activation"), framework="tf")
|
||||
filters = self.model_config["conv_filters"]
|
||||
no_final_linear = self.model_config.get("no_final_linear")
|
||||
vf_share_layers = self.model_config.get("vf_share_layers")
|
||||
|
||||
inputs = tf.keras.layers.Input(
|
||||
shape=obs_space.shape, name="observations")
|
||||
@@ -73,6 +75,16 @@ class VisionNetwork(TFModelV2):
|
||||
padding="same",
|
||||
data_format="channels_last",
|
||||
name="conv_out")(last_layer)
|
||||
|
||||
if conv_out.shape[1] != 1 or conv_out.shape[2] != 1:
|
||||
raise ValueError(
|
||||
"Given `conv_filters` ({}) do not result in a [B, 1, "
|
||||
"1, {} (`num_outputs`)] shape (but in {})! Please "
|
||||
"adjust your Conv2D stack such that the dims 1 and 2 "
|
||||
"are both 1.".format(
|
||||
self.model_config["conv_filters"],
|
||||
self.num_outputs, list(conv_out.shape)))
|
||||
|
||||
# num_outputs not known -> Flatten, then set self.num_outputs
|
||||
# to the resulting number of nodes.
|
||||
else:
|
||||
|
||||
@@ -15,17 +15,18 @@ class VisionNetwork(TorchModelV2, nn.Module):
|
||||
|
||||
def __init__(self, obs_space, action_space, num_outputs, model_config,
|
||||
name):
|
||||
if not model_config.get("conv_filters"):
|
||||
model_config["conv_filters"] = _get_filter_config(obs_space.shape)
|
||||
|
||||
TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
|
||||
model_config, name)
|
||||
nn.Module.__init__(self)
|
||||
|
||||
activation = get_activation_fn(
|
||||
model_config.get("conv_activation"), framework="torch")
|
||||
filters = model_config.get("conv_filters")
|
||||
if not filters:
|
||||
filters = _get_filter_config(obs_space.shape)
|
||||
no_final_linear = model_config.get("no_final_linear")
|
||||
vf_share_layers = model_config.get("vf_share_layers")
|
||||
self.model_config.get("conv_activation"), framework="torch")
|
||||
filters = self.model_config["conv_filters"]
|
||||
no_final_linear = self.model_config.get("no_final_linear")
|
||||
vf_share_layers = self.model_config.get("vf_share_layers")
|
||||
|
||||
# Whether the last layer is the output of a Flattened (rather than
|
||||
# a n x (1,1) Conv2D).
|
||||
@@ -152,8 +153,17 @@ class VisionNetwork(TorchModelV2, nn.Module):
|
||||
if not self.last_layer_is_flattened:
|
||||
if self._logits:
|
||||
conv_out = self._logits(conv_out)
|
||||
if conv_out.shape[2] != 1 or conv_out.shape[3] != 1:
|
||||
raise ValueError(
|
||||
"Given `conv_filters` ({}) do not result in a [B, {} "
|
||||
"(`num_outputs`), 1, 1] shape (but in {})! Please adjust "
|
||||
"your Conv2D stack such that the last 2 dims are both "
|
||||
"1.".format(
|
||||
self.model_config["conv_filters"], self.num_outputs,
|
||||
list(conv_out.shape)))
|
||||
logits = conv_out.squeeze(3)
|
||||
logits = logits.squeeze(2)
|
||||
|
||||
return logits, state
|
||||
else:
|
||||
return conv_out, state
|
||||
|
||||
Reference in New Issue
Block a user