mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 20:24:03 +08:00
[Autoscaler] Making bootstrap config part of the node provider interface (#9443)
* supporting custom bootstrap config for external node providers
* bootstrap config
* renamed config to cluster_config
* lint
* remove 2 args from importer
* complete move of bootstrap to node_provider
* renamed provider_cls
* move imports outside functions
* lint
* Update python/ray/autoscaler/node_provider.py
  Co-authored-by: Eric Liang <ekhliang@gmail.com>
* final fixes
* keeping lines to reduce diff
* lint
* lambda config
* filling in -> adding for lint

Co-authored-by: Ameer Haj Ali <ameerhajali@Ameers-MacBook-Pro.local>
Co-authored-by: Eric Liang <ekhliang@gmail.com>
This commit is contained in:
@@ -9,6 +9,7 @@ import botocore
|
||||
from botocore.config import Config
|
||||
|
||||
from ray.autoscaler.node_provider import NodeProvider
|
||||
from ray.autoscaler.aws.config import bootstrap_aws
|
||||
from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME, \
|
||||
TAG_RAY_LAUNCH_CONFIG, TAG_RAY_NODE_TYPE, TAG_RAY_INSTANCE_TYPE
|
||||
from ray.ray_constants import BOTO_MAX_RETRIES, BOTO_CREATE_MAX_RETRIES
|
||||
@@ -399,3 +400,7 @@ class AWSNodeProvider(NodeProvider):
|
||||
def cleanup(self):
|
||||
self.tag_cache_update_event.set()
|
||||
self.tag_cache_kill_event.set()
|
||||
|
||||
@staticmethod
|
||||
def bootstrap_config(cluster_config):
|
||||
return bootstrap_aws(cluster_config)
|
||||
|
||||
@@ -13,6 +13,7 @@ from azure.mgmt.resource.resources.models import DeploymentMode
|
||||
from knack.util import CLIError
|
||||
|
||||
from ray.autoscaler.node_provider import NodeProvider
|
||||
from ray.autoscaler.azure.config import bootstrap_azure
|
||||
from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME
|
||||
|
||||
VM_NAME_MAX_LEN = 64
|
||||
@@ -291,3 +292,7 @@ class AzureNodeProvider(NodeProvider):
|
||||
if node_id in self.cached_nodes:
|
||||
return self.cached_nodes[node_id]
|
||||
return self._get_node(node_id=node_id)
|
||||
|
||||
@staticmethod
|
||||
def bootstrap_config(cluster_config):
|
||||
return bootstrap_azure(cluster_config)
|
||||
|
||||
@@ -117,8 +117,8 @@ def _bootstrap_config(config, no_config_cache=False):
|
||||
raise NotImplementedError("Unsupported provider {}".format(
|
||||
config["provider"]))
|
||||
|
||||
bootstrap_config, _ = importer()
|
||||
resolved_config = bootstrap_config(config)
|
||||
provider_cls = importer(config["provider"])
|
||||
resolved_config = provider_cls.bootstrap_config(config)
|
||||
if not no_config_cache:
|
||||
with open(cache_key, "w") as f:
|
||||
f.write(json.dumps(resolved_config))
|
||||
|
||||
@@ -4,6 +4,7 @@ import time
|
||||
import logging
|
||||
|
||||
from ray.autoscaler.node_provider import NodeProvider
|
||||
from ray.autoscaler.gcp.config import bootstrap_gcp
|
||||
from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME
|
||||
from ray.autoscaler.gcp.config import MAX_POLLS, POLL_INTERVAL, \
|
||||
construct_clients_from_provider_config
|
||||
@@ -238,3 +239,7 @@ class GCPNodeProvider(NodeProvider):
|
||||
return self.cached_nodes[node_id]
|
||||
|
||||
return self._get_node(node_id)
|
||||
|
||||
@staticmethod
|
||||
def bootstrap_config(cluster_config):
|
||||
return bootstrap_gcp(cluster_config)
|
||||
|
||||
@@ -2,6 +2,7 @@ import logging
|
||||
|
||||
from ray.autoscaler.kubernetes import core_api, log_prefix
|
||||
from ray.autoscaler.node_provider import NodeProvider
|
||||
from ray.autoscaler.kubernetes.config import bootstrap_kubernetes
|
||||
from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME
|
||||
from ray.autoscaler.updater import KubernetesCommandRunner
|
||||
|
||||
@@ -97,3 +98,7 @@ class KubernetesNodeProvider(NodeProvider):
|
||||
docker_config=None):
|
||||
return KubernetesCommandRunner(log_prefix, self.namespace, node_id,
|
||||
auth_config, process_runner)
|
||||
|
||||
@staticmethod
|
||||
def bootstrap_config(cluster_config):
|
||||
return bootstrap_kubernetes(cluster_config)
|
||||
|
||||
@@ -6,6 +6,7 @@ import socket
|
||||
import logging
|
||||
|
||||
from ray.autoscaler.node_provider import NodeProvider
|
||||
from ray.autoscaler.local.config import bootstrap_local
|
||||
from ray.autoscaler.tags import TAG_RAY_NODE_TYPE, NODE_TYPE_WORKER, \
|
||||
NODE_TYPE_HEAD
|
||||
|
||||
@@ -147,3 +148,7 @@ class LocalNodeProvider(NodeProvider):
|
||||
info = workers[node_id]
|
||||
info["state"] = "terminated"
|
||||
self.state.put(node_id, info)
|
||||
|
||||
@staticmethod
|
||||
def bootstrap_config(cluster_config):
|
||||
return bootstrap_local(cluster_config)
|
||||
|
||||
@@ -8,34 +8,29 @@ from ray.autoscaler.command_runner import SSHCommandRunner, DockerCommandRunner
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def import_aws():
|
||||
from ray.autoscaler.aws.config import bootstrap_aws
|
||||
def import_aws(provider_config):
|
||||
from ray.autoscaler.aws.node_provider import AWSNodeProvider
|
||||
return bootstrap_aws, AWSNodeProvider
|
||||
return AWSNodeProvider
|
||||
|
||||
|
||||
def import_gcp():
|
||||
from ray.autoscaler.gcp.config import bootstrap_gcp
|
||||
def import_gcp(provider_config):
|
||||
from ray.autoscaler.gcp.node_provider import GCPNodeProvider
|
||||
return bootstrap_gcp, GCPNodeProvider
|
||||
return GCPNodeProvider
|
||||
|
||||
|
||||
def import_azure():
|
||||
from ray.autoscaler.azure.config import bootstrap_azure
|
||||
def import_azure(provider_config):
|
||||
from ray.autoscaler.azure.node_provider import AzureNodeProvider
|
||||
return bootstrap_azure, AzureNodeProvider
|
||||
return AzureNodeProvider
|
||||
|
||||
|
||||
def import_local():
|
||||
from ray.autoscaler.local.config import bootstrap_local
|
||||
def import_local(provider_config):
|
||||
from ray.autoscaler.local.node_provider import LocalNodeProvider
|
||||
return bootstrap_local, LocalNodeProvider
|
||||
return LocalNodeProvider
|
||||
|
||||
|
||||
def import_kubernetes():
|
||||
from ray.autoscaler.kubernetes.config import bootstrap_kubernetes
|
||||
def import_kubernetes(provider_config):
|
||||
from ray.autoscaler.kubernetes.node_provider import KubernetesNodeProvider
|
||||
return bootstrap_kubernetes, KubernetesNodeProvider
|
||||
return KubernetesNodeProvider
|
||||
|
||||
|
||||
def load_local_example_config():
|
||||
@@ -66,13 +61,9 @@ def load_azure_example_config():
|
||||
os.path.dirname(ray_azure.__file__), "example-full.yaml")
|
||||
|
||||
|
||||
def import_external():
|
||||
"""Mock a normal provider importer."""
|
||||
|
||||
def return_it_back(config):
|
||||
return config
|
||||
|
||||
return return_it_back, None
|
||||
def import_external(provider_config):
|
||||
provider_cls = load_class(path=provider_config["module"])
|
||||
return provider_cls
|
||||
|
||||
|
||||
NODE_PROVIDERS = {
|
||||
@@ -112,16 +103,11 @@ def load_class(path):
|
||||
|
||||
|
||||
def get_node_provider(provider_config, cluster_name):
|
||||
if provider_config["type"] == "external":
|
||||
provider_cls = load_class(path=provider_config["module"])
|
||||
return provider_cls(provider_config, cluster_name)
|
||||
|
||||
importer = NODE_PROVIDERS.get(provider_config["type"])
|
||||
|
||||
if importer is None:
|
||||
raise NotImplementedError("Unsupported node provider: {}".format(
|
||||
provider_config["type"]))
|
||||
_, provider_cls = importer()
|
||||
provider_cls = importer(provider_config)
|
||||
return provider_cls(provider_config, cluster_name)
|
||||
|
||||
|
||||
@@ -227,6 +213,11 @@ class NodeProvider:
|
||||
demand scheduler."""
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def bootstrap_config(cluster_config):
|
||||
"""Bootstraps the cluster config by adding env defaults if needed."""
|
||||
return cluster_config
|
||||
|
||||
def get_command_runner(self,
|
||||
log_prefix,
|
||||
node_id,
|
||||
|
||||
@@ -280,7 +280,7 @@ class LoadMetricsTest(unittest.TestCase):
|
||||
class AutoscalingTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
NODE_PROVIDERS["mock"] = \
|
||||
lambda: (None, self.create_provider)
|
||||
lambda config: self.create_provider
|
||||
self.provider = None
|
||||
self.tmpdir = tempfile.mkdtemp()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user