diff --git a/rllib/models/catalog.py b/rllib/models/catalog.py index 8e3e43dd0..6d0bfd111 100644 --- a/rllib/models/catalog.py +++ b/rllib/models/catalog.py @@ -204,8 +204,8 @@ class ModelCatalog: "Using custom action distribution {}".format(action_dist_name)) dist_cls = _global_registry.get(RLLIB_ACTION_DIST, action_dist_name) - dist_cls = ModelCatalog._get_multi_action_distribution( - dist_cls, action_space, {}, framework) + return ModelCatalog._get_multi_action_distribution( + dist_cls, action_space, config, framework) # Dist_type is given directly as a class. elif type(dist_type) is type and \ @@ -740,7 +740,8 @@ class ModelCatalog: action_space=action_space, child_distributions=child_dists, input_lens=input_lens), int(sum(input_lens)) - return dist_class + return dist_class, dist_class.required_model_output_shape( + action_space, config) @staticmethod def _validate_config(config: ModelConfigDict, framework: str) -> None: