mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:22:39 +08:00
[RLlib] Fix bug in ModelCatalog when using custom action distribution (#12846)
* return tuple returned from _get_multi_action_distribution when using custom action dict * Always return dst_class and required_model_output_shape in _get_multi_action_distribution * pass model config to _get_multi_action_distribution
This commit is contained in:
@@ -204,8 +204,8 @@ class ModelCatalog:
|
||||
"Using custom action distribution {}".format(action_dist_name))
|
||||
dist_cls = _global_registry.get(RLLIB_ACTION_DIST,
|
||||
action_dist_name)
|
||||
dist_cls = ModelCatalog._get_multi_action_distribution(
|
||||
dist_cls, action_space, {}, framework)
|
||||
return ModelCatalog._get_multi_action_distribution(
|
||||
dist_cls, action_space, config, framework)
|
||||
|
||||
# Dist_type is given directly as a class.
|
||||
elif type(dist_type) is type and \
|
||||
@@ -740,7 +740,8 @@ class ModelCatalog:
|
||||
action_space=action_space,
|
||||
child_distributions=child_dists,
|
||||
input_lens=input_lens), int(sum(input_lens))
|
||||
return dist_class
|
||||
return dist_class, dist_class.required_model_output_shape(
|
||||
action_space, config)
|
||||
|
||||
@staticmethod
|
||||
def _validate_config(config: ModelConfigDict, framework: str) -> None:
|
||||
|
||||
Reference in New Issue
Block a user