mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 19:16:19 +08:00
Split _get_node_provider_cls off from _get_node_provider (#11949)
This commit is contained in:
@@ -131,18 +131,47 @@ def _load_class(path):
|
||||
return getattr(module, class_str)
|
||||
|
||||
|
||||
def _get_node_provider(provider_config: Dict[str, Any],
|
||||
cluster_name: str,
|
||||
use_cache: bool = True) -> Any:
|
||||
def _get_node_provider_cls(provider_config: Dict[str, Any]):
|
||||
"""Get the node provider class for a given provider config.
|
||||
|
||||
Note that this may be used by private node providers that proxy methods to
|
||||
built-in node providers, so we should maintain backwards compatibility.
|
||||
|
||||
Args:
|
||||
provider_config: provider section of the autoscaler config.
|
||||
|
||||
Returns:
|
||||
NodeProvider class
|
||||
"""
|
||||
importer = _NODE_PROVIDERS.get(provider_config["type"])
|
||||
if importer is None:
|
||||
raise NotImplementedError("Unsupported node provider: {}".format(
|
||||
provider_config["type"]))
|
||||
provider_cls = importer(provider_config)
|
||||
return importer(provider_config)
|
||||
|
||||
|
||||
def _get_node_provider(provider_config: Dict[str, Any],
|
||||
cluster_name: str,
|
||||
use_cache: bool = True) -> Any:
|
||||
"""Get the instantiated node provider for a given provider config.
|
||||
|
||||
Note that this may be used by private node providers that proxy methods to
|
||||
built-in node providers, so we should maintain backwards compatibility.
|
||||
|
||||
Args:
|
||||
provider_config: provider section of the autoscaler config.
|
||||
cluster_name: cluster name from the autoscaler config.
|
||||
use_cache: whether or not to use a cached definition if available. If
|
||||
False, the returned object will also not be stored in the cache.
|
||||
|
||||
Returns:
|
||||
NodeProvider
|
||||
"""
|
||||
provider_key = (json.dumps(provider_config, sort_keys=True), cluster_name)
|
||||
if use_cache and provider_key in _provider_instances:
|
||||
return _provider_instances[provider_key]
|
||||
|
||||
provider_cls = _get_node_provider_cls(provider_config)
|
||||
new_provider = provider_cls(provider_config, cluster_name)
|
||||
|
||||
if use_cache:
|
||||
|
||||
Reference in New Issue
Block a user