diff --git a/python/ray/autoscaler/_private/providers.py b/python/ray/autoscaler/_private/providers.py index c9ffccaaa..54be5dfd4 100644 --- a/python/ray/autoscaler/_private/providers.py +++ b/python/ray/autoscaler/_private/providers.py @@ -131,18 +131,47 @@ def _load_class(path): return getattr(module, class_str) -def _get_node_provider(provider_config: Dict[str, Any], - cluster_name: str, - use_cache: bool = True) -> Any: +def _get_node_provider_cls(provider_config: Dict[str, Any]): + """Get the node provider class for a given provider config. + + Note that this may be used by private node providers that proxy methods to + built-in node providers, so we should maintain backwards compatibility. + + Args: + provider_config: provider section of the autoscaler config. + + Returns: + NodeProvider class + """ importer = _NODE_PROVIDERS.get(provider_config["type"]) if importer is None: raise NotImplementedError("Unsupported node provider: {}".format( provider_config["type"])) - provider_cls = importer(provider_config) + return importer(provider_config) + + +def _get_node_provider(provider_config: Dict[str, Any], + cluster_name: str, + use_cache: bool = True) -> Any: + """Get the instantiated node provider for a given provider config. + + Note that this may be used by private node providers that proxy methods to + built-in node providers, so we should maintain backwards compatibility. + + Args: + provider_config: provider section of the autoscaler config. + cluster_name: cluster name from the autoscaler config. + use_cache: whether or not to use a cached definition if available. If + False, the returned object will also not be stored in the cache. + + Returns: + NodeProvider + """ provider_key = (json.dumps(provider_config, sort_keys=True), cluster_name) if use_cache and provider_key in _provider_instances: return _provider_instances[provider_key] + provider_cls = _get_node_provider_cls(provider_config) new_provider = provider_cls(provider_config, cluster_name) if use_cache: