From 964689b280dd63b3192148dbfabf27db45d7e40b Mon Sep 17 00:00:00 2001 From: Jan Blumenkamp Date: Mon, 25 Jan 2021 11:42:39 +0000 Subject: [PATCH] [RLlib] Fix bug in ModelCatalog when using custom action distribution (#12846) * return tuple returned from _get_multi_action_distribution when using custom action dict * Always return dst_class and required_model_output_shape in _get_multi_action_distribution * pass model config to _get_multi_action_distribution --- rllib/models/catalog.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/rllib/models/catalog.py b/rllib/models/catalog.py index 8e3e43dd0..6d0bfd111 100644 --- a/rllib/models/catalog.py +++ b/rllib/models/catalog.py @@ -204,8 +204,8 @@ class ModelCatalog: "Using custom action distribution {}".format(action_dist_name)) dist_cls = _global_registry.get(RLLIB_ACTION_DIST, action_dist_name) - dist_cls = ModelCatalog._get_multi_action_distribution( - dist_cls, action_space, {}, framework) + return ModelCatalog._get_multi_action_distribution( + dist_cls, action_space, config, framework) # Dist_type is given directly as a class. elif type(dist_type) is type and \ @@ -740,7 +740,8 @@ class ModelCatalog: action_space=action_space, child_distributions=child_dists, input_lens=input_lens), int(sum(input_lens)) - return dist_class + return dist_class, dist_class.required_model_output_shape( + action_space, config) @staticmethod def _validate_config(config: ModelConfigDict, framework: str) -> None: