[Autoscaler] Making bootstrap config part of the node provider interface (#9443)

* supporting custom bootstrap config for external node providers

* bootstrap config

* renamed config to cluster_config

* lint

* remove 2 args from importer

* complete move of bootstrap to node_provider

* renamed provider_cls

* move imports outside functions

* lint

* Update python/ray/autoscaler/node_provider.py

Co-authored-by: Eric Liang <ekhliang@gmail.com>

* final fixes

* keeping lines to reduce diff

* lint

* lamba config

* filling in -> adding for lint

Co-authored-by: Ameer Haj Ali <ameerhajali@Ameers-MacBook-Pro.local>
Co-authored-by: Eric Liang <ekhliang@gmail.com>
This commit is contained in:
Ameer Haj Ali
2020-07-16 19:54:20 +03:00
committed by GitHub
parent 63e052a5f3
commit 1e46d4e29f
8 changed files with 47 additions and 31 deletions
@@ -9,6 +9,7 @@ import botocore
from botocore.config import Config
from ray.autoscaler.node_provider import NodeProvider
from ray.autoscaler.aws.config import bootstrap_aws
from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME, \
TAG_RAY_LAUNCH_CONFIG, TAG_RAY_NODE_TYPE, TAG_RAY_INSTANCE_TYPE
from ray.ray_constants import BOTO_MAX_RETRIES, BOTO_CREATE_MAX_RETRIES
@@ -399,3 +400,7 @@ class AWSNodeProvider(NodeProvider):
def cleanup(self):
self.tag_cache_update_event.set()
self.tag_cache_kill_event.set()
@staticmethod
def bootstrap_config(cluster_config):
return bootstrap_aws(cluster_config)
@@ -13,6 +13,7 @@ from azure.mgmt.resource.resources.models import DeploymentMode
from knack.util import CLIError
from ray.autoscaler.node_provider import NodeProvider
from ray.autoscaler.azure.config import bootstrap_azure
from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME
VM_NAME_MAX_LEN = 64
@@ -291,3 +292,7 @@ class AzureNodeProvider(NodeProvider):
if node_id in self.cached_nodes:
return self.cached_nodes[node_id]
return self._get_node(node_id=node_id)
@staticmethod
def bootstrap_config(cluster_config):
return bootstrap_azure(cluster_config)
+2 -2
View File
@@ -117,8 +117,8 @@ def _bootstrap_config(config, no_config_cache=False):
raise NotImplementedError("Unsupported provider {}".format(
config["provider"]))
bootstrap_config, _ = importer()
resolved_config = bootstrap_config(config)
provider_cls = importer(config["provider"])
resolved_config = provider_cls.bootstrap_config(config)
if not no_config_cache:
with open(cache_key, "w") as f:
f.write(json.dumps(resolved_config))
@@ -4,6 +4,7 @@ import time
import logging
from ray.autoscaler.node_provider import NodeProvider
from ray.autoscaler.gcp.config import bootstrap_gcp
from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME
from ray.autoscaler.gcp.config import MAX_POLLS, POLL_INTERVAL, \
construct_clients_from_provider_config
@@ -238,3 +239,7 @@ class GCPNodeProvider(NodeProvider):
return self.cached_nodes[node_id]
return self._get_node(node_id)
@staticmethod
def bootstrap_config(cluster_config):
return bootstrap_gcp(cluster_config)
@@ -2,6 +2,7 @@ import logging
from ray.autoscaler.kubernetes import core_api, log_prefix
from ray.autoscaler.node_provider import NodeProvider
from ray.autoscaler.kubernetes.config import bootstrap_kubernetes
from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME
from ray.autoscaler.updater import KubernetesCommandRunner
@@ -97,3 +98,7 @@ class KubernetesNodeProvider(NodeProvider):
docker_config=None):
return KubernetesCommandRunner(log_prefix, self.namespace, node_id,
auth_config, process_runner)
@staticmethod
def bootstrap_config(cluster_config):
return bootstrap_kubernetes(cluster_config)
@@ -6,6 +6,7 @@ import socket
import logging
from ray.autoscaler.node_provider import NodeProvider
from ray.autoscaler.local.config import bootstrap_local
from ray.autoscaler.tags import TAG_RAY_NODE_TYPE, NODE_TYPE_WORKER, \
NODE_TYPE_HEAD
@@ -147,3 +148,7 @@ class LocalNodeProvider(NodeProvider):
info = workers[node_id]
info["state"] = "terminated"
self.state.put(node_id, info)
@staticmethod
def bootstrap_config(cluster_config):
return bootstrap_local(cluster_config)
+19 -28
View File
@@ -8,34 +8,29 @@ from ray.autoscaler.command_runner import SSHCommandRunner, DockerCommandRunner
logger = logging.getLogger(__name__)
def import_aws():
from ray.autoscaler.aws.config import bootstrap_aws
def import_aws(provider_config):
from ray.autoscaler.aws.node_provider import AWSNodeProvider
return bootstrap_aws, AWSNodeProvider
return AWSNodeProvider
def import_gcp():
from ray.autoscaler.gcp.config import bootstrap_gcp
def import_gcp(provider_config):
from ray.autoscaler.gcp.node_provider import GCPNodeProvider
return bootstrap_gcp, GCPNodeProvider
return GCPNodeProvider
def import_azure():
from ray.autoscaler.azure.config import bootstrap_azure
def import_azure(provider_config):
from ray.autoscaler.azure.node_provider import AzureNodeProvider
return bootstrap_azure, AzureNodeProvider
return AzureNodeProvider
def import_local():
from ray.autoscaler.local.config import bootstrap_local
def import_local(provider_config):
from ray.autoscaler.local.node_provider import LocalNodeProvider
return bootstrap_local, LocalNodeProvider
return LocalNodeProvider
def import_kubernetes():
from ray.autoscaler.kubernetes.config import bootstrap_kubernetes
def import_kubernetes(provider_config):
from ray.autoscaler.kubernetes.node_provider import KubernetesNodeProvider
return bootstrap_kubernetes, KubernetesNodeProvider
return KubernetesNodeProvider
def load_local_example_config():
@@ -66,13 +61,9 @@ def load_azure_example_config():
os.path.dirname(ray_azure.__file__), "example-full.yaml")
def import_external():
"""Mock a normal provider importer."""
def return_it_back(config):
return config
return return_it_back, None
def import_external(provider_config):
provider_cls = load_class(path=provider_config["module"])
return provider_cls
NODE_PROVIDERS = {
@@ -112,16 +103,11 @@ def load_class(path):
def get_node_provider(provider_config, cluster_name):
if provider_config["type"] == "external":
provider_cls = load_class(path=provider_config["module"])
return provider_cls(provider_config, cluster_name)
importer = NODE_PROVIDERS.get(provider_config["type"])
if importer is None:
raise NotImplementedError("Unsupported node provider: {}".format(
provider_config["type"]))
_, provider_cls = importer()
provider_cls = importer(provider_config)
return provider_cls(provider_config, cluster_name)
@@ -227,6 +213,11 @@ class NodeProvider:
demand scheduler."""
return None
@staticmethod
def bootstrap_config(cluster_config):
"""Bootstraps the cluster config by adding env defaults if needed."""
return cluster_config
def get_command_runner(self,
log_prefix,
node_id,
+1 -1
View File
@@ -280,7 +280,7 @@ class LoadMetricsTest(unittest.TestCase):
class AutoscalingTest(unittest.TestCase):
def setUp(self):
NODE_PROVIDERS["mock"] = \
lambda: (None, self.create_provider)
lambda config: self.create_provider
self.provider = None
self.tmpdir = tempfile.mkdtemp()