From 73a1cb702bd54a43ef7d5f41f4b0b4e9aa879c24 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Wed, 11 Nov 2020 16:10:46 -0600 Subject: [PATCH] Split _get_node_provider_cls off from _get_node_provider (#11949) --- python/ray/autoscaler/_private/providers.py | 37 ++++++++++++++++++--- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/python/ray/autoscaler/_private/providers.py b/python/ray/autoscaler/_private/providers.py index c9ffccaaa..54be5dfd4 100644 --- a/python/ray/autoscaler/_private/providers.py +++ b/python/ray/autoscaler/_private/providers.py @@ -131,18 +131,47 @@ def _load_class(path): return getattr(module, class_str) -def _get_node_provider(provider_config: Dict[str, Any], - cluster_name: str, - use_cache: bool = True) -> Any: +def _get_node_provider_cls(provider_config: Dict[str, Any]): + """Get the node provider class for a given provider config. + + Note that this may be used by private node providers that proxy methods to + built-in node providers, so we should maintain backwards compatibility. + + Args: + provider_config: provider section of the autoscaler config. + + Returns: + NodeProvider class + """ importer = _NODE_PROVIDERS.get(provider_config["type"]) if importer is None: raise NotImplementedError("Unsupported node provider: {}".format( provider_config["type"])) - provider_cls = importer(provider_config) + return importer(provider_config) + + +def _get_node_provider(provider_config: Dict[str, Any], + cluster_name: str, + use_cache: bool = True) -> Any: + """Get the instantiated node provider for a given provider config. + + Note that this may be used by private node providers that proxy methods to + built-in node providers, so we should maintain backwards compatibility. + + Args: + provider_config: provider section of the autoscaler config. + cluster_name: cluster name from the autoscaler config. + use_cache: whether or not to use a cached definition if available. If + False, the returned object will also not be stored in the cache. + + Returns: + NodeProvider + """ provider_key = (json.dumps(provider_config, sort_keys=True), cluster_name) if use_cache and provider_key in _provider_instances: return _provider_instances[provider_key] + provider_cls = _get_node_provider_cls(provider_config) new_provider = provider_cls(provider_config, cluster_name) if use_cache: