[RLlib] Fix bug in ModelCatalog when using custom action distribution (#12846)

* return tuple returned from _get_multi_action_distribution when using custom action dict

* Always return dst_class and required_model_output_shape in _get_multi_action_distribution

* pass model config to _get_multi_action_distribution
This commit is contained in:
Jan Blumenkamp
2021-01-25 11:42:39 +00:00
committed by GitHub
parent 9423930bcc
commit 964689b280
+4 -3
View File
@@ -204,8 +204,8 @@ class ModelCatalog:
"Using custom action distribution {}".format(action_dist_name))
dist_cls = _global_registry.get(RLLIB_ACTION_DIST,
action_dist_name)
dist_cls = ModelCatalog._get_multi_action_distribution(
dist_cls, action_space, {}, framework)
return ModelCatalog._get_multi_action_distribution(
dist_cls, action_space, config, framework)
# Dist_type is given directly as a class.
elif type(dist_type) is type and \
@@ -740,7 +740,8 @@ class ModelCatalog:
action_space=action_space,
child_distributions=child_dists,
input_lens=input_lens), int(sum(input_lens))
return dist_class
return dist_class, dist_class.required_model_output_shape(
action_space, config)
@staticmethod
def _validate_config(config: ModelConfigDict, framework: str) -> None: