diff --git a/python/ray/autoscaler/aws/node_provider.py b/python/ray/autoscaler/aws/node_provider.py index 6b30870bf..44dbc21c8 100644 --- a/python/ray/autoscaler/aws/node_provider.py +++ b/python/ray/autoscaler/aws/node_provider.py @@ -9,6 +9,7 @@ import botocore from botocore.config import Config from ray.autoscaler.node_provider import NodeProvider +from ray.autoscaler.aws.config import bootstrap_aws from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME, \ TAG_RAY_LAUNCH_CONFIG, TAG_RAY_NODE_TYPE, TAG_RAY_INSTANCE_TYPE from ray.ray_constants import BOTO_MAX_RETRIES, BOTO_CREATE_MAX_RETRIES @@ -399,3 +400,7 @@ class AWSNodeProvider(NodeProvider): def cleanup(self): self.tag_cache_update_event.set() self.tag_cache_kill_event.set() + + @staticmethod + def bootstrap_config(cluster_config): + return bootstrap_aws(cluster_config) diff --git a/python/ray/autoscaler/azure/node_provider.py b/python/ray/autoscaler/azure/node_provider.py index 62b09393f..69c197e61 100644 --- a/python/ray/autoscaler/azure/node_provider.py +++ b/python/ray/autoscaler/azure/node_provider.py @@ -13,6 +13,7 @@ from azure.mgmt.resource.resources.models import DeploymentMode from knack.util import CLIError from ray.autoscaler.node_provider import NodeProvider +from ray.autoscaler.azure.config import bootstrap_azure from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME VM_NAME_MAX_LEN = 64 @@ -291,3 +292,7 @@ class AzureNodeProvider(NodeProvider): if node_id in self.cached_nodes: return self.cached_nodes[node_id] return self._get_node(node_id=node_id) + + @staticmethod + def bootstrap_config(cluster_config): + return bootstrap_azure(cluster_config) diff --git a/python/ray/autoscaler/commands.py b/python/ray/autoscaler/commands.py index 1840dcdc6..157fc82c1 100644 --- a/python/ray/autoscaler/commands.py +++ b/python/ray/autoscaler/commands.py @@ -117,8 +117,8 @@ def _bootstrap_config(config, no_config_cache=False): raise NotImplementedError("Unsupported provider {}".format( config["provider"])) - bootstrap_config, _ = importer() - resolved_config = bootstrap_config(config) + provider_cls = importer(config["provider"]) + resolved_config = provider_cls.bootstrap_config(config) if not no_config_cache: with open(cache_key, "w") as f: f.write(json.dumps(resolved_config)) diff --git a/python/ray/autoscaler/gcp/node_provider.py b/python/ray/autoscaler/gcp/node_provider.py index 58f8fff38..cbc4aecca 100644 --- a/python/ray/autoscaler/gcp/node_provider.py +++ b/python/ray/autoscaler/gcp/node_provider.py @@ -4,6 +4,7 @@ import time import logging from ray.autoscaler.node_provider import NodeProvider +from ray.autoscaler.gcp.config import bootstrap_gcp from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME from ray.autoscaler.gcp.config import MAX_POLLS, POLL_INTERVAL, \ construct_clients_from_provider_config @@ -238,3 +239,7 @@ class GCPNodeProvider(NodeProvider): return self.cached_nodes[node_id] return self._get_node(node_id) + + @staticmethod + def bootstrap_config(cluster_config): + return bootstrap_gcp(cluster_config) diff --git a/python/ray/autoscaler/kubernetes/node_provider.py b/python/ray/autoscaler/kubernetes/node_provider.py index 88c3a6d83..2c4a4e40f 100644 --- a/python/ray/autoscaler/kubernetes/node_provider.py +++ b/python/ray/autoscaler/kubernetes/node_provider.py @@ -2,6 +2,7 @@ import logging from ray.autoscaler.kubernetes import core_api, log_prefix from ray.autoscaler.node_provider import NodeProvider +from ray.autoscaler.kubernetes.config import bootstrap_kubernetes from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME from ray.autoscaler.updater import KubernetesCommandRunner @@ -97,3 +98,7 @@ class KubernetesNodeProvider(NodeProvider): docker_config=None): return KubernetesCommandRunner(log_prefix, self.namespace, node_id, auth_config, process_runner) + + @staticmethod + def bootstrap_config(cluster_config): + return bootstrap_kubernetes(cluster_config) diff --git a/python/ray/autoscaler/local/node_provider.py b/python/ray/autoscaler/local/node_provider.py index b9961f9f7..12414b43b 100644 --- a/python/ray/autoscaler/local/node_provider.py +++ b/python/ray/autoscaler/local/node_provider.py @@ -6,6 +6,7 @@ import socket import logging from ray.autoscaler.node_provider import NodeProvider +from ray.autoscaler.local.config import bootstrap_local from ray.autoscaler.tags import TAG_RAY_NODE_TYPE, NODE_TYPE_WORKER, \ NODE_TYPE_HEAD @@ -147,3 +148,7 @@ class LocalNodeProvider(NodeProvider): info = workers[node_id] info["state"] = "terminated" self.state.put(node_id, info) + + @staticmethod + def bootstrap_config(cluster_config): + return bootstrap_local(cluster_config) diff --git a/python/ray/autoscaler/node_provider.py b/python/ray/autoscaler/node_provider.py index 56344b402..7698a0bd4 100644 --- a/python/ray/autoscaler/node_provider.py +++ b/python/ray/autoscaler/node_provider.py @@ -8,34 +8,29 @@ from ray.autoscaler.command_runner import SSHCommandRunner, DockerCommandRunner logger = logging.getLogger(__name__) -def import_aws(): - from ray.autoscaler.aws.config import bootstrap_aws +def import_aws(provider_config): from ray.autoscaler.aws.node_provider import AWSNodeProvider - return bootstrap_aws, AWSNodeProvider + return AWSNodeProvider -def import_gcp(): - from ray.autoscaler.gcp.config import bootstrap_gcp +def import_gcp(provider_config): from ray.autoscaler.gcp.node_provider import GCPNodeProvider - return bootstrap_gcp, GCPNodeProvider + return GCPNodeProvider -def import_azure(): - from ray.autoscaler.azure.config import bootstrap_azure +def import_azure(provider_config): from ray.autoscaler.azure.node_provider import AzureNodeProvider - return bootstrap_azure, AzureNodeProvider + return AzureNodeProvider -def import_local(): - from ray.autoscaler.local.config import bootstrap_local +def import_local(provider_config): from ray.autoscaler.local.node_provider import LocalNodeProvider - return bootstrap_local, LocalNodeProvider + return LocalNodeProvider -def import_kubernetes(): - from ray.autoscaler.kubernetes.config import bootstrap_kubernetes +def import_kubernetes(provider_config): from ray.autoscaler.kubernetes.node_provider import KubernetesNodeProvider - return bootstrap_kubernetes, KubernetesNodeProvider + return KubernetesNodeProvider def load_local_example_config(): @@ -66,13 +61,9 @@ def load_azure_example_config(): os.path.dirname(ray_azure.__file__), "example-full.yaml") -def import_external(): - """Mock a normal provider importer.""" - - def return_it_back(config): - return config - - return return_it_back, None +def import_external(provider_config): + provider_cls = load_class(path=provider_config["module"]) + return provider_cls NODE_PROVIDERS = { @@ -112,16 +103,11 @@ def load_class(path): def get_node_provider(provider_config, cluster_name): - if provider_config["type"] == "external": - provider_cls = load_class(path=provider_config["module"]) - return provider_cls(provider_config, cluster_name) - importer = NODE_PROVIDERS.get(provider_config["type"]) - if importer is None: raise NotImplementedError("Unsupported node provider: {}".format( provider_config["type"])) - _, provider_cls = importer() + provider_cls = importer(provider_config) return provider_cls(provider_config, cluster_name) @@ -227,6 +213,11 @@ class NodeProvider: demand scheduler.""" return None + @staticmethod + def bootstrap_config(cluster_config): + """Bootstraps the cluster config by adding env defaults if needed.""" + return cluster_config + def get_command_runner(self, log_prefix, node_id, diff --git a/python/ray/tests/test_autoscaler.py b/python/ray/tests/test_autoscaler.py index 504985cb8..0d8882f8f 100644 --- a/python/ray/tests/test_autoscaler.py +++ b/python/ray/tests/test_autoscaler.py @@ -280,7 +280,7 @@ class LoadMetricsTest(unittest.TestCase): class AutoscalingTest(unittest.TestCase): def setUp(self): NODE_PROVIDERS["mock"] = \ - lambda: (None, self.create_provider) + lambda config: self.create_provider self.provider = None self.tmpdir = tempfile.mkdtemp()