From 74dc14d1fc4da4773d317f8570402cbe69d84f51 Mon Sep 17 00:00:00 2001 From: Kristian Hartikainen Date: Thu, 31 May 2018 09:00:03 -0700 Subject: [PATCH] [autoscaler] GCP node provider (#2061) * Google Cloud Platform scaffolding * Add minimal gcp config example * Add googleapiclient discoveries, update gcp.config constants * Rename and update gcp.config key pair name function * Implement gcp.config._configure_project * Fix the create project get project flow * Implement gcp.config._configure_iam_role * Implement service account iam binding * Implement gcp.config._configure_key_pair * Implement rsa key pair generation * Implement gcp.config._configure_subnet * Save work-in-progress gcp.config._configure_firewall_rules. These are likely to be not needed at all. Saving them if we happen to need them later. * Remove unnecessary firewall configuration * Update example-minimal.yaml configuration * Add new wait_for_compute_operation, rename old wait_for_operation * Temporarily rename autoscaler tags due to gcp incompatibility * Implement initial gcp.node_provider.nodes * Still missing filter support * Implement initial gcp.node_provider.create_node * Implement another compute wait operation (wait_For_compute_zone_operation). TODO: figure out if we can remove the function. * Implement initial gcp.node_provider._node and node status functions * Implement initial gcp.node_provider.terminate_node * Implement node tagging and ip getter methods for nodes * Temporarily rename tags due to gcp incompatibility * Tiny tweaks for autoscaler.updater * Remove unused config from gcp node_provider * Add new example-full example to gcp, update load_gcp_example_config * Implement label filtering for gcp.node_provider.nodes * Revert unnecessary change in ssh command * Revert "Temporarily rename tags due to gcp incompatibility" This reverts commit e2fe634c5d11d705c0f5d3e76c80c37394bb23fb. 
* Revert "Temporarily rename autoscaler tags due to gcp incompatibility" This reverts commit c938ee435f4b75854a14e78242ad7f1d1ed8ad4b. * Refactor autoscaler tagging to support multiple tag specs * Remove missing cryptography imports * Update quote function import * Fix threading issue in gcp.config with the compute discovery object * Add gcs support for log_sync * Fix the labels/tags naming discrepancy * Add expanduser to file_mounts hashing * Fix gcp.node_provider.internal_ip * Add uuid to node name * Remove 'set -i' from updater ssh command * Also add TODO with the context and reason for the change. * Update ssh key creation in autoscaler.gcp.config * Fix wait_for_compute_zone_operation's threading issue Google discovery api's compute object is not thread safe, and thus needs to be recreated for each thread. This moves the `wait_for_compute_zone_operation` under `autoscaler.gcp.config`, and adds compute as its argument. * Address pr feedback from @ericl * Expand local file mount paths in NodeUpdater * Add ssh_user name to key names * Update updater ssh to attempt 'set -i' and fall back if that fails * Update gcp/example-full.yaml * Fix wait crm operation in gcp.config * Update gcp/example-minimal.yaml to match aws/example-minimal.yaml * Fix gcp/example-full.yaml comment indentation * Add gcp/example-full.yaml to setup files * Update example-full.yaml command * Revert "Refactor autoscaler tagging to support multiple tag specs" This reverts commit 9cf48409ca2e5b66f800153853072c706fa502f6. * Update tag spec to only use characters [0-9a-z_-] * Change the tag values to conform gcp spec * Add project_id in the ssh key name * Replace '_' with '-' in autoscaler tag names * Revert "Update updater ssh to attempt 'set -i' and fall back if that fails" This reverts commit 23a0066c5254449e49746bd5e43b94b66f32bfb4. * Revert "Remove 'set -i' from updater ssh command" This reverts commit 5fa034cdf79fa7f8903691518c0d75699c630172. 
* Add fallback to `set -i` in force_interactive command * Update autoscaler tests to match current implementation * Update GCPNodeProvider.create_node to include hash in instance name * Add support for creating multiple instance on one create_node call * Clean TODOs * Update styles * Replace single quotes with double quotes * Some minor indentation fixes etc. * Remove unnecessary comment. Fix indentation. * Yapfify files that fail flake8 test * Yapfify more files * Update project_id handling in gcp node provider * temporary yapf mod * Revert "temporary yapf mod" This reverts commit b6744e4e15d4d936d1a14f4bf155ed1d3bb14126. * Fix autoscaler/updater.py lint error, remove unused variable --- python/ray/autoscaler/autoscaler.py | 29 +- python/ray/autoscaler/commands.py | 11 +- python/ray/autoscaler/gcp/__init__.py | 0 python/ray/autoscaler/gcp/config.py | 427 ++++++++++++++++++ python/ray/autoscaler/gcp/example-full.yaml | 161 +++++++ .../ray/autoscaler/gcp/example-minimal.yaml | 17 + python/ray/autoscaler/gcp/node_provider.py | 213 +++++++++ python/ray/autoscaler/node_provider.py | 25 +- python/ray/autoscaler/tags.py | 19 +- python/ray/autoscaler/updater.py | 34 +- python/ray/tune/log_sync.py | 43 +- python/setup.py | 5 +- test/autoscaler_test.py | 14 +- 13 files changed, 936 insertions(+), 62 deletions(-) create mode 100644 python/ray/autoscaler/gcp/__init__.py create mode 100644 python/ray/autoscaler/gcp/config.py create mode 100644 python/ray/autoscaler/gcp/example-full.yaml create mode 100644 python/ray/autoscaler/gcp/example-minimal.yaml create mode 100644 python/ray/autoscaler/gcp/node_provider.py diff --git a/python/ray/autoscaler/autoscaler.py b/python/ray/autoscaler/autoscaler.py index a3f64c2d2..354081fdc 100644 --- a/python/ray/autoscaler/autoscaler.py +++ b/python/ray/autoscaler/autoscaler.py @@ -22,8 +22,9 @@ from ray.autoscaler.node_provider import get_node_provider, \ get_default_config from ray.autoscaler.updater import NodeUpdaterProcess from 
ray.autoscaler.docker import dockerize_if_needed -from ray.autoscaler.tags import TAG_RAY_LAUNCH_CONFIG, \ - TAG_RAY_RUNTIME_CONFIG, TAG_RAY_NODE_STATUS, TAG_RAY_NODE_TYPE, TAG_NAME +from ray.autoscaler.tags import (TAG_RAY_LAUNCH_CONFIG, TAG_RAY_RUNTIME_CONFIG, + TAG_RAY_NODE_STATUS, TAG_RAY_NODE_TYPE, + TAG_RAY_NODE_NAME) import ray.services as services REQUIRED, OPTIONAL = True, False @@ -58,6 +59,7 @@ CLUSTER_CONFIG_SCHEMA = { "availability_zone": (str, OPTIONAL), # e.g. us-east-1a "module": (str, OPTIONAL), # module, if using external node provider + "project_id": (None, OPTIONAL), # gcp project id, if using gcp }, REQUIRED), @@ -244,6 +246,14 @@ class StandardAutoscaler(object): self.last_update_time = 0.0 self.update_interval_s = update_interval_s + # Expand local file_mounts to allow ~ in the paths. This can't be done + # earlier when the config is written since we might be on different + # platform and the expansion would result in wrong path. + self.config["file_mounts"] = { + remote: os.path.expanduser(local) + for remote, local in self.config["file_mounts"].items() + } + for local_path in self.config["file_mounts"].values(): assert os.path.exists(local_path) @@ -254,8 +264,8 @@ class StandardAutoscaler(object): self.reload_config(errors_fatal=False) self._update() except Exception as e: - print("StandardAutoscaler: Error during autoscaling: {}", - traceback.format_exc()) + print("StandardAutoscaler: Error during autoscaling: {}" + "".format(traceback.format_exc())) self.num_failures += 1 if self.num_failures > self.max_failures: print("*** StandardAutoscaler: Too many errors, abort. 
***") @@ -446,9 +456,10 @@ class StandardAutoscaler(object): num_before = len(self.workers()) self.provider.create_node( self.config["worker_nodes"], { - TAG_NAME: "ray-{}-worker".format(self.config["cluster_name"]), - TAG_RAY_NODE_TYPE: "Worker", - TAG_RAY_NODE_STATUS: "Uninitialized", + TAG_RAY_NODE_NAME: "ray-{}-worker".format( + self.config["cluster_name"]), + TAG_RAY_NODE_TYPE: "worker", + TAG_RAY_NODE_STATUS: "uninitialized", TAG_RAY_LAUNCH_CONFIG: self.launch_hash, }, count) if len(self.workers()) <= num_before: @@ -456,7 +467,7 @@ class StandardAutoscaler(object): def workers(self): return self.provider.nodes(tag_filters={ - TAG_RAY_NODE_TYPE: "Worker", + TAG_RAY_NODE_TYPE: "worker", }) def debug_string(self, nodes=None): @@ -565,7 +576,7 @@ def hash_runtime_conf(file_mounts, extra_objs): with open(os.path.join(dirpath, name), "rb") as f: hasher.update(f.read()) else: - with open(path, 'r') as f: + with open(os.path.expanduser(path), "r") as f: hasher.update(f.read().encode("utf-8")) hasher.update(json.dumps(sorted(file_mounts.items())).encode("utf-8")) diff --git a/python/ray/autoscaler/commands.py b/python/ray/autoscaler/commands.py index f38e99ff4..2765b44a8 100644 --- a/python/ray/autoscaler/commands.py +++ b/python/ray/autoscaler/commands.py @@ -19,7 +19,7 @@ from ray.autoscaler.autoscaler import validate_config, hash_runtime_conf, \ hash_launch_conf, fillout_defaults from ray.autoscaler.node_provider import get_node_provider, NODE_PROVIDERS from ray.autoscaler.tags import TAG_RAY_NODE_TYPE, TAG_RAY_LAUNCH_CONFIG, \ - TAG_NAME + TAG_RAY_NODE_NAME from ray.autoscaler.updater import NodeUpdaterProcess @@ -57,7 +57,7 @@ def teardown_cluster(config_file, yes): provider = get_node_provider(config["provider"], config["cluster_name"]) head_node_tags = { - TAG_RAY_NODE_TYPE: "Head", + TAG_RAY_NODE_TYPE: "head", } for node in provider.nodes(head_node_tags): print("Terminating head node {}".format(node)) @@ -76,7 +76,7 @@ def get_or_create_head_node(config, 
no_restart, yes): provider = get_node_provider(config["provider"], config["cluster_name"]) head_node_tags = { - TAG_RAY_NODE_TYPE: "Head", + TAG_RAY_NODE_TYPE: "head", } nodes = provider.nodes(head_node_tags) if len(nodes) > 0: @@ -98,7 +98,8 @@ def get_or_create_head_node(config, no_restart, yes): provider.terminate_node(head_node) print("Launching new head node...") head_node_tags[TAG_RAY_LAUNCH_CONFIG] = launch_hash - head_node_tags[TAG_NAME] = "ray-{}-head".format(config["cluster_name"]) + head_node_tags[TAG_RAY_NODE_NAME] = "ray-{}-head".format( + config["cluster_name"]) provider.create_node(config["head_node"], head_node_tags, 1) nodes = provider.nodes(head_node_tags) @@ -185,7 +186,7 @@ def get_head_node_ip(config_file): config = yaml.load(open(config_file).read()) provider = get_node_provider(config["provider"], config["cluster_name"]) head_node_tags = { - TAG_RAY_NODE_TYPE: "Head", + TAG_RAY_NODE_TYPE: "head", } nodes = provider.nodes(head_node_tags) if len(nodes) > 0: diff --git a/python/ray/autoscaler/gcp/__init__.py b/python/ray/autoscaler/gcp/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/ray/autoscaler/gcp/config.py b/python/ray/autoscaler/gcp/config.py new file mode 100644 index 000000000..6df3cb777 --- /dev/null +++ b/python/ray/autoscaler/gcp/config.py @@ -0,0 +1,427 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import time + +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.backends import default_backend +from googleapiclient import discovery, errors + +crm = discovery.build("cloudresourcemanager", "v1") +iam = discovery.build("iam", "v1") +compute = discovery.build("compute", "v1") + +VERSION = "v1" + +RAY = "ray-autoscaler" +DEFAULT_SERVICE_ACCOUNT_ID = RAY + "-sa-" + VERSION +SERVICE_ACCOUNT_EMAIL_TEMPLATE = ( + 
"{account_id}@{project_id}.iam.gserviceaccount.com") +DEFAULT_SERVICE_ACCOUNT_CONFIG = { + "displayName": "Ray Autoscaler Service Account ({})".format(VERSION), +} +DEFAULT_SERVICE_ACCOUNT_ROLES = ("roles/storage.objectAdmin", + "roles/compute.admin") + +MAX_POLLS = 12 +POLL_INTERVAL = 5 + + +def wait_for_crm_operation(operation): + """Poll for cloud resource manager operation until finished.""" + print("Waiting for operation {} to finish...".format(operation)) + + for _ in range(MAX_POLLS): + result = crm.operations().get(name=operation["name"]).execute() + if "error" in result: + raise Exception(result["error"]) + + if "done" in result and result["done"]: + print("Done.") + break + + time.sleep(POLL_INTERVAL) + + return result + + +def wait_for_compute_global_operation(project_name, operation): + """Poll for global compute operation until finished.""" + print("Waiting for operation {} to finish...".format(operation["name"])) + + for _ in range(MAX_POLLS): + result = compute.globalOperations().get( + project=project_name, + operation=operation["name"], + ).execute() + if "error" in result: + raise Exception(result["error"]) + + if result["status"] == "DONE": + print("Done.") + break + + time.sleep(POLL_INTERVAL) + + return result + + +def key_pair_name(i, region, project_id, ssh_user): + """Returns the ith default gcp_key_pair_name.""" + # One placeholder per argument: str.format() silently ignores surplus + # positional arguments, so a missing "{}" would drop the index i and + # make every attempt produce the same key name. + key_name = "{}_gcp_{}_{}_{}_{}".format(RAY, region, project_id, ssh_user, + i) + return key_name + + +def key_pair_paths(key_name): + """Returns public and private key paths for a given key_name.""" + public_key_path = os.path.expanduser("~/.ssh/{}.pub".format(key_name)) + private_key_path = os.path.expanduser("~/.ssh/{}.pem".format(key_name)) + return public_key_path, private_key_path + + +def generate_rsa_key_pair(): + """Create public and private ssh-keys.""" + + key = rsa.generate_private_key( + backend=default_backend(), public_exponent=65537, key_size=2048) + + public_key = key.public_key().public_bytes( +
serialization.Encoding.OpenSSH, + serialization.PublicFormat.OpenSSH).decode("utf-8") + + pem = key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption()).decode("utf-8") + + return public_key, pem + + +def bootstrap_gcp(config): + config = _configure_project(config) + config = _configure_iam_role(config) + config = _configure_key_pair(config) + config = _configure_subnet(config) + + return config + + +def _configure_project(config): + """Setup a Google Cloud Platform Project. + + Google Cloud Platform organizes all the resources, such as storage + buckets, users, and instances under projects. This is different from + aws ec2 where everything is global. + """ + project_id = config["provider"].get("project_id") + # Assert on the .get() result so that a missing key produces the message + # below instead of a bare KeyError. + assert project_id is not None, ( + "'project_id' must be set in the 'provider' section of the autoscaler" + " config. Notice that the project id must be globally unique.") + + project = _get_project(project_id) + + if project is None: + # Project not found, try creating it + _create_project(project_id) + project = _get_project(project_id) + + assert project is not None, "Failed to create project" + assert project["lifecycleState"] == "ACTIVE", ( + "Project status needs to be ACTIVE, got {}".format( + project["lifecycleState"])) + + config["provider"]["project_id"] = project["projectId"] + + return config + + +def _configure_iam_role(config): + """Setup a gcp service account with IAM roles. + + Creates a gcp service account and binds IAM roles which allow it to control + storage/compute services. Specifically, the head node needs to have + an IAM role that allows it to create further gce instances and store items + in google cloud storage.
+ + TODO: Allow the name/id of the service account to be configured + """ + email = SERVICE_ACCOUNT_EMAIL_TEMPLATE.format( + account_id=DEFAULT_SERVICE_ACCOUNT_ID, + project_id=config["provider"]["project_id"]) + service_account = _get_service_account(email, config) + + if service_account is None: + print("Creating new service account {}".format( + DEFAULT_SERVICE_ACCOUNT_ID)) + + service_account = _create_service_account( + DEFAULT_SERVICE_ACCOUNT_ID, DEFAULT_SERVICE_ACCOUNT_CONFIG, config) + + assert service_account is not None, "Failed to create service account" + + _add_iam_policy_binding(service_account, DEFAULT_SERVICE_ACCOUNT_ROLES) + + config["head_node"]["serviceAccounts"] = [{ + "email": service_account["email"], + # NOTE: The amount of access is determined by the scope + IAM + # role of the service account. Even if the cloud-platform scope + # gives (scope) access to the whole cloud-platform, the service + # account is limited by the IAM rights specified below. + "scopes": ["https://www.googleapis.com/auth/cloud-platform"] + }] + + return config + + +def _configure_key_pair(config): + """Configure SSH access, using an existing key pair if possible. + + Creates a project-wide ssh key that can be used to access all the instances + unless explicitly prohibited by instance config. + + The ssh-keys created by ray are of format: + + [USERNAME]:ssh-rsa [KEY_VALUE] [USERNAME] + + where: + + [USERNAME] is the user for the SSH key, specified in the config. + [KEY_VALUE] is the public SSH key value. + """ + + if "ssh_private_key" in config["auth"]: + return config + + ssh_user = config["auth"]["ssh_user"] + + project = compute.projects().get( + project=config["provider"]["project_id"]).execute() + + # Key pairs associated with project meta data. The key pairs are general, + # and not just ssh keys. 
+ ssh_keys_str = next( + (item for item in project["commonInstanceMetadata"].get("items", []) + if item["key"] == "ssh-keys"), {}).get("value", "") + + ssh_keys = ssh_keys_str.split("\n") if ssh_keys_str else [] + + # Try a few times to get or create a good key pair. + key_found = False + for i in range(10): + key_name = key_pair_name(i, config["provider"]["region"], + config["provider"]["project_id"], ssh_user) + public_key_path, private_key_path = key_pair_paths(key_name) + + for ssh_key in ssh_keys: + key_parts = ssh_key.split(" ") + if len(key_parts) != 3: + continue + + if key_parts[2] == ssh_user and os.path.exists(private_key_path): + # Found a key + key_found = True + break + + # Create a key since it doesn't exist locally or in GCP + if not key_found and not os.path.exists(private_key_path): + print("Creating new key pair {}".format(key_name)) + public_key, private_key = generate_rsa_key_pair() + + _create_project_ssh_key_pair(project, public_key, ssh_user) + + with open(private_key_path, "w") as f: + f.write(private_key) + os.chmod(private_key_path, 0o600) + + with open(public_key_path, "w") as f: + f.write(public_key) + + key_found = True + + break + + if key_found: + break + + assert key_found, "SSH keypair for user {} not found for {}".format( + ssh_user, private_key_path) + assert os.path.exists(private_key_path), ( + "Private key file {} not found for user {}" + "".format(private_key_path, ssh_user)) + + print("Private key not specified in config, using {}" + "".format(private_key_path)) + + config["auth"]["ssh_private_key"] = private_key_path + + return config + + +def _configure_subnet(config): + """Pick a reasonable subnet if not specified by the config.""" + + subnets = _list_subnets(config) + + if not subnets: + raise NotImplementedError("Should be able to create subnet.") + + # TODO: make sure that we have usable subnet. Maybe call + # compute.subnetworks().listUsable? 
For some reason it didn't + # work out-of-the-box + default_subnet = subnets[0] + + if "networkInterfaces" not in config["head_node"]: + config["head_node"]["networkInterfaces"] = [{ + "subnetwork": default_subnet["selfLink"], + "accessConfigs": [{ + "name": "External NAT", + "type": "ONE_TO_ONE_NAT", + }], + }] + + if "networkInterfaces" not in config["worker_nodes"]: + config["worker_nodes"]["networkInterfaces"] = [{ + "subnetwork": default_subnet["selfLink"], + "accessConfigs": [{ + "name": "External NAT", + "type": "ONE_TO_ONE_NAT", + }], + }] + + return config + + +def _list_subnets(config): + response = compute.subnetworks().list( + project=config["provider"]["project_id"], + region=config["provider"]["region"]).execute() + + return response["items"] + + +def _get_subnet(config, subnet_id): + subnet = compute.subnetworks().get( + project=config["provider"]["project_id"], + region=config["provider"]["region"], + subnetwork=subnet_id, + ).execute() + + return subnet + + +def _get_project(project_id): + try: + project = crm.projects().get(projectId=project_id).execute() + except errors.HttpError as e: + if e.resp.status != 403: + raise + project = None + + return project + + +def _create_project(project_id): + operation = crm.projects().create(body={ + "projectId": project_id, + "name": project_id + }).execute() + + result = wait_for_crm_operation(operation) + + return result + + +def _get_service_account(account, config): + project_id = config["provider"]["project_id"] + full_name = ("projects/{project_id}/serviceAccounts/{account}" + "".format(project_id=project_id, account=account)) + try: + service_account = iam.projects().serviceAccounts().get( + name=full_name).execute() + except errors.HttpError as e: + if e.resp.status != 404: + raise + service_account = None + + return service_account + + +def _create_service_account(account_id, account_config, config): + project_id = config["provider"]["project_id"] + + service_account = 
iam.projects().serviceAccounts().create( + name="projects/{project_id}".format(project_id=project_id), + body={ + "accountId": account_id, + "serviceAccount": account_config, + }).execute() + + return service_account + + +def _add_iam_policy_binding(service_account, roles): + """Add new IAM roles for the service account.""" + project_id = service_account["projectId"] + email = service_account["email"] + member_id = "serviceAccount:" + email + + policy = crm.projects().getIamPolicy(resource=project_id).execute() + + for role in roles: + role_exists = False + for binding in policy["bindings"]: + if binding["role"] == role: + if member_id not in binding["members"]: + binding["members"].append(member_id) + role_exists = True + + if not role_exists: + policy["bindings"].append({ + "members": [member_id], + "role": role, + }) + + result = crm.projects().setIamPolicy( + resource=project_id, body={ + "policy": policy, + }).execute() + + return result + + +def _create_project_ssh_key_pair(project, public_key, ssh_user): + """Inserts an ssh-key into project commonInstanceMetadata""" + + key_parts = public_key.split(" ") + + # Sanity checks to make sure that the generated key matches expectation + assert len(key_parts) == 2, key_parts + assert key_parts[0] == "ssh-rsa", key_parts + + new_ssh_meta = "{ssh_user}:ssh-rsa {key_value} {ssh_user}".format( + ssh_user=ssh_user, key_value=key_parts[1]) + + common_instance_metadata = project["commonInstanceMetadata"] + items = common_instance_metadata.get("items", []) + + ssh_keys_i = next( + (i for i, item in enumerate(items) if item["key"] == "ssh-keys"), None) + + if ssh_keys_i is None: + items.append({"key": "ssh-keys", "value": new_ssh_meta}) + else: + ssh_keys = items[ssh_keys_i] + ssh_keys["value"] += "\n" + new_ssh_meta + items[ssh_keys_i] = ssh_keys + + common_instance_metadata["items"] = items + + operation = compute.projects().setCommonInstanceMetadata( + project=project["name"], body=common_instance_metadata).execute() + + 
response = wait_for_compute_global_operation(project["name"], operation) + + return response diff --git a/python/ray/autoscaler/gcp/example-full.yaml b/python/ray/autoscaler/gcp/example-full.yaml new file mode 100644 index 000000000..57f5dd282 --- /dev/null +++ b/python/ray/autoscaler/gcp/example-full.yaml @@ -0,0 +1,161 @@ +# An unique identifier for the head node and workers of this cluster. +cluster_name: default + +# The minimum number of workers nodes to launch in addition to the head +# node. This number should be >= 0. +min_workers: 0 + +# The maximum number of workers nodes to launch in addition to the head +# node. This takes precedence over min_workers. +max_workers: 2 + +# This executes all commands on all nodes in the docker container, +# and opens all the necessary ports to support the Ray cluster. +# Empty string means disabled. +docker: + image: "" # e.g., tensorflow/tensorflow:1.5.0-py3 + container_name: "" # e.g. ray_docker + + +# The autoscaler will scale up the cluster to this target fraction of resource +# usage. For example, if a cluster of 10 nodes is 100% busy and +# target_utilization is 0.8, it would resize the cluster to 13. This fraction +# can be decreased to increase the aggressiveness of upscaling. +# This value must be less than 1.0 for scaling to happen. +target_utilization_fraction: 0.8 + +# If a node is idle for this many minutes, it will be removed. +idle_timeout_minutes: 5 + +# Cloud-provider specific configuration. +provider: + type: gcp + region: us-west1 + availability_zone: us-west1-a + project_id: null # Globally unique project id + +# How Ray will authenticate with newly launched nodes. +auth: + ssh_user: ubuntu +# By default Ray creates a new private keypair, but you can also use your own. +# If you do so, make sure to also set "KeyName" in the head and worker node +# configurations below. This requires that you have added the key into the +# project wide meta-data. 
+# ssh_private_key: /path/to/your/key.pem + +# Provider-specific config for the head node, e.g. instance type. By default +# Ray will auto-configure unspecified fields such as subnets and ssh-keys. +# For more documentation on available fields, see: +# https://cloud.google.com/compute/docs/reference/rest/v1/instances/insert +head_node: + machineType: n1-standard-2 + disks: + - boot: true + autoDelete: true + type: PERSISTENT + initializeParams: + diskSizeGb: 50 + # See https://cloud.google.com/compute/docs/images for more images + sourceImage: projects/ubuntu-os-cloud/global/images/family/ubuntu-1604-lts # Ubuntu + + # Additional options can be found in the compute docs at + # https://cloud.google.com/compute/docs/reference/rest/v1/instances/insert + +worker_nodes: + machineType: n1-standard-2 + disks: + - boot: true + autoDelete: true + type: PERSISTENT + initializeParams: + diskSizeGb: 50 + # See https://cloud.google.com/compute/docs/images for more images + sourceImage: projects/ubuntu-os-cloud/global/images/family/ubuntu-1604-lts # Ubuntu + # Run workers on preemptible instances by default. + # Comment this out to use on-demand. + scheduling: + - preemptible: true + + # Additional options can be found in the compute docs at + # https://cloud.google.com/compute/docs/reference/rest/v1/instances/insert + +# Files or directories to copy to the head and worker nodes. The format is a +# dictionary from REMOTE_PATH: LOCAL_PATH, e.g. +file_mounts: { +# "/path1/on/remote/machine": "/path1/on/local/machine", +# "/path2/on/remote/machine": "/path2/on/local/machine", +} + +# List of shell commands to run to set up nodes. +setup_commands: + # Consider uncommenting these if you also want to run apt-get commands during setup + # - sudo pkill -9 apt-get || true + # - sudo pkill -9 dpkg || true + # - sudo dpkg --configure -a + + # Install basics.
+ - sudo apt-get update + - >- + sudo apt-get install -y + cmake + pkg-config + build-essential + autoconf + curl + libtool + unzip + flex + bison + python + # Install Anaconda. + - >- + wget https://repo.continuum.io/archive/Anaconda3-5.0.1-Linux-x86_64.sh -O ~/anaconda3.sh + || true + - bash ~/anaconda3.sh -b -p ~/anaconda3 || true + - rm ~/anaconda3.sh + - echo 'export PATH="$HOME/anaconda3/bin:$PATH"' >> ~/.bashrc + + # Build Ray. + # Note: if you're developing Ray, you probably want to create a boot-disk + # that has your Ray repo pre-cloned. Then, you can replace the pip installs + # below with a git checkout (and possibly a recompile). + - echo 'export PATH="$HOME/anaconda3/envs/tensorflow_p36/bin:$PATH"' >> ~/.bashrc + - >- + pip install + google-api-python-client==1.6.7 + cython==0.27.3 + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.4.0-cp27-cp27mu-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.4.0-cp35-cp35m-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.4.0-cp36-cp36m-manylinux1_x86_64.whl + - >- + cd ~ + && git clone https://github.com/ray-project/ray || true + - >- + cd ~/ray/python + && pip install -e . --verbose + +# Custom commands that will be run on the head node after common setup. +head_setup_commands: [] + +# Custom commands that will be run on worker nodes after common setup. +worker_setup_commands: [] + +# Command to start ray on the head node. You don't need to change this. +head_start_ray_commands: + - ray stop + - >- + ulimit -n 65536; + ray start + --head + --redis-port=6379 + --object-manager-port=8076 + --autoscaling-config=~/ray_bootstrap_config.yaml + +# Command to start ray on worker nodes. You don't need to change this. 
+worker_start_ray_commands:
+    - ray stop
+    - >-
+      ulimit -n 65536;
+      ray start
+      --redis-address=$RAY_HEAD_IP:6379
+      --object-manager-port=8076
diff --git a/python/ray/autoscaler/gcp/example-minimal.yaml b/python/ray/autoscaler/gcp/example-minimal.yaml
new file mode 100644
index 000000000..b4e8dfc6a
--- /dev/null
+++ b/python/ray/autoscaler/gcp/example-minimal.yaml
@@ -0,0 +1,17 @@
+# A unique identifier for the head node and workers of this cluster.
+cluster_name: minimal
+
+# The maximum number of worker nodes to launch in addition to the head
+# node. This takes precedence over min_workers. min_workers defaults to 0.
+max_workers: 1
+
+# Cloud-provider specific configuration.
+provider:
+    type: gcp
+    region: us-west1
+    availability_zone: us-west1-a
+    project_id: null # Globally unique project id
+
+# How Ray will authenticate with newly launched nodes.
+auth:
+    ssh_user: ubuntu
diff --git a/python/ray/autoscaler/gcp/node_provider.py b/python/ray/autoscaler/gcp/node_provider.py
new file mode 100644
index 000000000..e3f1b1c5d
--- /dev/null
+++ b/python/ray/autoscaler/gcp/node_provider.py
@@ -0,0 +1,311 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from uuid import uuid4
+import time
+
+from googleapiclient import discovery
+
+from ray.autoscaler.node_provider import NodeProvider
+from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME
+from ray.autoscaler.gcp.config import MAX_POLLS, POLL_INTERVAL
+
+# Compute Engine resource names may be at most 63 characters long.
+INSTANCE_NAME_MAX_LEN = 63
+INSTANCE_NAME_UUID_LEN = 8
+
+# Instance statuses that count as alive. Every other status is treated as
+# terminated.
+NON_TERMINATED_STATUSES = ("PROVISIONING", "STAGING", "RUNNING")
+
+
+def wait_for_compute_zone_operation(compute, project_name, operation, zone):
+    """Poll a compute zone operation until it finishes.
+
+    The operation is polled at most MAX_POLLS times, sleeping POLL_INTERVAL
+    seconds between polls.
+
+    Args:
+        compute: Compute discovery client used to query the operation.
+        project_name: Id of the project the operation belongs to.
+        operation: Operation resource (dict) returned by a compute API
+            call. Only its "name" field is used.
+        zone: Name of the zone the operation runs in.
+
+    Returns:
+        The last polled operation resource, or None if no poll was made.
+        If the operation did not finish within MAX_POLLS polls, the status
+        of the returned resource is not "DONE".
+
+    Raises:
+        Exception: If the polled operation reports an error.
+    """
+    print("Waiting for operation {} to finish...".format(operation["name"]))
+
+    result = None
+    for _ in range(MAX_POLLS):
+        result = compute.zoneOperations().get(
+            project=project_name, operation=operation["name"],
+            zone=zone).execute()
+        if "error" in result:
+            raise Exception(result["error"])
+
+        if result["status"] == "DONE":
+            print("Done.")
+            break
+
+        time.sleep(POLL_INTERVAL)
+    else:
+        # Keep the best-effort contract of returning the last result, but do
+        # not let the timeout pass silently.
+        print("Operation {} did not finish after {} polls.".format(
+            operation["name"], MAX_POLLS))
+
+    return result
+
+
+class GCPNodeProvider(NodeProvider):
+    """Node provider that manages cluster nodes as GCP compute instances.
+
+    Autoscaler tags are stored as instance labels, and the instance name is
+    used as the node id.
+    """
+
+    def __init__(self, provider_config, cluster_name):
+        NodeProvider.__init__(self, provider_config, cluster_name)
+
+        self.compute = discovery.build("compute", "v1")
+
+        # Cache of node objects from the last nodes() call. This avoids
+        # excessive instances.get requests.
+        self.cached_nodes = {}
+
+        # Cache of ip lookups. We assume IPs never change once assigned.
+        self.internal_ip_cache = {}
+        self.external_ip_cache = {}
+
+    def nodes(self, tag_filters):
+        """Return the names of the non-terminated nodes of this cluster.
+
+        Args:
+            tag_filters: Dict of label key to label value. Only instances
+                carrying all of these labels are returned.
+
+        Returns:
+            List of instance names. As a side effect the listed instance
+            resources replace the contents of self.cached_nodes.
+        """
+        filter_exprs = []
+
+        if tag_filters:
+            filter_exprs.append("(" + " AND ".join([
+                "(labels.{key} = {value})".format(key=key, value=value)
+                for key, value in tag_filters.items()
+            ]) + ")")
+
+        filter_exprs.append("(" + " OR ".join([
+            "(status = {status})".format(status=status)
+            for status in NON_TERMINATED_STATUSES
+        ]) + ")")
+
+        filter_exprs.append("(labels.{key} = {value})".format(
+            key=TAG_RAY_CLUSTER_NAME, value=self.cluster_name))
+
+        instances_api = self.compute.instances()
+        request = instances_api.list(
+            project=self.provider_config["project_id"],
+            zone=self.provider_config["availability_zone"],
+            filter=" AND ".join(filter_exprs),
+        )
+
+        # instances.list is paginated. Follow the page tokens so that large
+        # clusters are not truncated to the first page of results.
+        instances = []
+        while request is not None:
+            response = request.execute()
+            instances.extend(response.get("items", []))
+            request = instances_api.list_next(
+                previous_request=request, previous_response=response)
+
+        # Note: All the operations use "name" as the unique instance identifier
+        self.cached_nodes = {i["name"]: i for i in instances}
+
+        return [i["name"] for i in instances]
+
+    def is_running(self, node_id):
+        """Return whether the instance status is RUNNING."""
+        node = self._node(node_id)
+        return node["status"] == "RUNNING"
+
+    def is_terminated(self, node_id):
+        """Return whether the instance is in none of the alive statuses."""
+        node = self._node(node_id)
+        return node["status"] not in NON_TERMINATED_STATUSES
+
+    def node_tags(self, node_id):
+        """Return the labels (autoscaler tags) of the instance."""
+        node = self._node(node_id)
+        labels = node.get("labels", {})
+        return labels
+
+    def set_node_tags(self, node_id, tags):
+        """Merge the given tags into the labels of the instance.
+
+        Args:
+            node_id: Name of the instance to label.
+            tags: Dict of label key to label value. Existing labels with
+                other keys are kept.
+
+        Returns:
+            The result of waiting for the setLabels operation.
+        """
+        labels = tags
+        project_id = self.provider_config["project_id"]
+        availability_zone = self.provider_config["availability_zone"]
+
+        # setLabels needs the latest label fingerprint. The cached node may
+        # predate an earlier set_node_tags call, so read the current labels
+        # and fingerprint from the API instead of going through the cache.
+        node = self.compute.instances().get(
+            project=project_id,
+            zone=availability_zone,
+            instance=node_id,
+        ).execute()
+
+        merged_labels = dict(node.get("labels", {}))
+        merged_labels.update(labels)
+
+        operation = self.compute.instances().setLabels(
+            project=project_id,
+            zone=availability_zone,
+            instance=node_id,
+            body={
+                "labels": merged_labels,
+                "labelFingerprint": node["labelFingerprint"]
+            }).execute()
+
+        result = wait_for_compute_zone_operation(self.compute, project_id,
+                                                 operation, availability_zone)
+
+        # The cached entry now holds outdated labels. Drop it so that the
+        # next lookup fetches the instance again.
+        self.cached_nodes.pop(node_id, None)
+
+        return result
+
+    def external_ip(self, node_id):
+        """Return the external (NAT) ip of the instance, or None."""
+        if node_id in self.external_ip_cache:
+            return self.external_ip_cache[node_id]
+        node = self._node(node_id)
+        # TODO: Is there a better and more reliable way to do this?
+        # "or [{}]" also covers keys that hold an empty list.
+        network_interface = (node.get("networkInterfaces") or [{}])[0]
+        access_config = (network_interface.get("accessConfigs") or [{}])[0]
+        ip = access_config.get("natIP", None)
+        if ip:
+            self.external_ip_cache[node_id] = ip
+        return ip
+
+    def internal_ip(self, node_id):
+        """Return the internal network ip of the instance, or None."""
+        if node_id in self.internal_ip_cache:
+            return self.internal_ip_cache[node_id]
+        node = self._node(node_id)
+        # "or [{}]" also covers a key that holds an empty list.
+        network_interface = (node.get("networkInterfaces") or [{}])[0]
+        ip = network_interface.get("networkIP")
+        if ip:
+            self.internal_ip_cache[node_id] = ip
+        return ip
+
+    def create_node(self, base_config, tags, count):
+        """Launch new instances and wait for their insert operations.
+
+        Args:
+            base_config: Instance configuration (dict) used as the body of
+                the insert requests. Its "machineType" entry is formatted
+                into a zonal "zones/.../machineTypes/..." path.
+            tags: Dict of label key to label value for the new instances.
+                Must contain TAG_RAY_NODE_NAME, which becomes the prefix of
+                the instance names.
+            count: Number of instances to launch.
+
+        Returns:
+            List with the result of waiting for each insert operation.
+        """
+        labels = tags  # gcp uses "labels" instead of aws "tags"
+        project_id = self.provider_config["project_id"]
+        availability_zone = self.provider_config["availability_zone"]
+
+        config = base_config.copy()
+
+        name_label = labels[TAG_RAY_NODE_NAME]
+        assert (len(name_label) <=
+                (INSTANCE_NAME_MAX_LEN - INSTANCE_NAME_UUID_LEN - 1)), (
+                    name_label, len(name_label))
+
+        # Merge the labels step by step. Passing several "**" mappings in
+        # one call is a syntax error on python 2, and raises a TypeError on
+        # python 3 when the mappings share a key.
+        instance_labels = dict(config.get("labels", {}))
+        instance_labels.update(labels)
+        instance_labels[TAG_RAY_CLUSTER_NAME] = self.cluster_name
+
+        config.update({
+            "machineType": ("zones/{zone}/machineTypes/{machine_type}"
+                            "".format(
+                                zone=availability_zone,
+                                machine_type=base_config["machineType"])),
+            "labels": instance_labels,
+        })
+
+        operations = []
+        for _ in range(count):
+            # The uuid suffix keeps the names of instances that share a
+            # name label distinct.
+            name = "{name_label}-{uuid}".format(
+                name_label=name_label,
+                uuid=uuid4().hex[:INSTANCE_NAME_UUID_LEN])
+            operations.append(self.compute.instances().insert(
+                project=project_id,
+                zone=availability_zone,
+                body=dict(config, name=name)).execute())
+
+        results = [
+            wait_for_compute_zone_operation(self.compute, project_id,
+                                            operation, availability_zone)
+            for operation in operations
+        ]
+
+        return results
+
+    def terminate_node(self, node_id):
+        """Delete the instance and wait for the delete operation."""
+        project_id = self.provider_config["project_id"]
+        availability_zone = self.provider_config["availability_zone"]
+
+        operation = self.compute.instances().delete(
+            project=project_id,
+            zone=availability_zone,
+            instance=node_id,
+        ).execute()
+
+        result = wait_for_compute_zone_operation(self.compute, project_id,
+                                                 operation, availability_zone)
+
+        return result
+
+    def _node(self, node_id):
+        """Return the instance resource, preferring the nodes() cache."""
+        if node_id in self.cached_nodes:
+            return self.cached_nodes[node_id]
+
+        instance = self.compute.instances().get(
+            project=self.provider_config["project_id"],
+            zone=self.provider_config["availability_zone"],
+            instance=node_id,
+        ).execute()
+
+        return instance
diff --git a/python/ray/autoscaler/node_provider.py b/python/ray/autoscaler/node_provider.py
index 9aa4bb994..463346551 100644
--- a/python/ray/autoscaler/node_provider.py
+++ b/python/ray/autoscaler/node_provider.py
@@ -13,11 +13,22 @@ def import_aws():
     return bootstrap_aws, AWSNodeProvider
 
 
-def load_aws_config():
+def import_gcp():
+    from ray.autoscaler.gcp.config import bootstrap_gcp
+    from ray.autoscaler.gcp.node_provider import GCPNodeProvider
+    return bootstrap_gcp, GCPNodeProvider
+
+
+def load_aws_example_config():
     import ray.autoscaler.aws as ray_aws
     return os.path.join(os.path.dirname(ray_aws.__file__), "example-full.yaml")
 
 
+def load_gcp_example_config():
+    import ray.autoscaler.gcp as ray_gcp
+    return os.path.join(os.path.dirname(ray_gcp.__file__), "example-full.yaml")
+
+
 def import_external():
     """Mock a normal provider importer."""
 
@@ -29,8 +40,8 @@ def import_external():
NODE_PROVIDERS = { "aws": import_aws, - "gce": None, # TODO: support more node providers - "azure": None, + "gcp": import_gcp, + "azure": None, # TODO: support more node providers "kubernetes": None, "docker": None, "local_cluster": None, @@ -38,9 +49,9 @@ NODE_PROVIDERS = { } DEFAULT_CONFIGS = { - "aws": load_aws_config, - "gce": None, # TODO: support more node providers - "azure": None, + "aws": load_aws_example_config, + "gcp": load_gcp_example_config, + "azure": None, # TODO: support more node providers "kubernetes": None, "docker": None, "local_cluster": None, @@ -115,7 +126,7 @@ class NodeProvider(object): nodes() must be called again to refresh results. Examples: - >>> provider.nodes({TAG_RAY_NODE_TYPE: "Worker"}) + >>> provider.nodes({TAG_RAY_NODE_TYPE: "worker"}) ["node-1", "node-2"] """ raise NotImplementedError diff --git a/python/ray/autoscaler/tags.py b/python/ray/autoscaler/tags.py index 646f59d44..1912d675b 100644 --- a/python/ray/autoscaler/tags.py +++ b/python/ray/autoscaler/tags.py @@ -1,22 +1,23 @@ +"""The Ray autoscaler uses tags/labels to associate metadata with instances.""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function -"""The Ray autoscaler uses tags to associate metadata with instances.""" # Tag for the name of the node -TAG_NAME = "Name" - -# Tag uniquely identifying all nodes of a cluster -TAG_RAY_CLUSTER_NAME = "ray:ClusterName" +TAG_RAY_NODE_NAME = "ray-node-name" # Tag for the type of node (e.g. Head, Worker) -TAG_RAY_NODE_TYPE = "ray:NodeType" +TAG_RAY_NODE_TYPE = "ray-node-type" # Tag that reports the current state of the node (e.g. 
Updating, Up-to-date) -TAG_RAY_NODE_STATUS = "ray:NodeStatus" +TAG_RAY_NODE_STATUS = "ray-node-status" + +# Tag uniquely identifying all nodes of a cluster +TAG_RAY_CLUSTER_NAME = "ray-cluster-name" # Hash of the node launch config, used to identify out-of-date nodes -TAG_RAY_LAUNCH_CONFIG = "ray:LaunchConfig" +TAG_RAY_LAUNCH_CONFIG = "ray-launch-config" # Hash of the node runtime config, used to determine if updates are needed -TAG_RAY_RUNTIME_CONFIG = "ray:RuntimeConfig" +TAG_RAY_RUNTIME_CONFIG = "ray-runtime-config" diff --git a/python/ray/autoscaler/updater.py b/python/ray/autoscaler/updater.py index 8d518f8ab..84a7e8494 100644 --- a/python/ray/autoscaler/updater.py +++ b/python/ray/autoscaler/updater.py @@ -2,7 +2,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import pipes +try: # py3 + from shlex import quote +except ImportError: # py2 + from pipes import quote import os import subprocess import sys @@ -17,6 +20,7 @@ from ray.autoscaler.tags import TAG_RAY_NODE_STATUS, TAG_RAY_RUNTIME_CONFIG # How long to wait for a node to start, in seconds NODE_START_WAIT_S = 300 +SSH_CHECK_INTERVAL = 5 def pretty_cmd(cmd_str): @@ -43,7 +47,10 @@ class NodeUpdater(object): self.ssh_user = auth_config["ssh_user"] self.ssh_ip = self.provider.external_ip(node_id) self.node_id = node_id - self.file_mounts = file_mounts + self.file_mounts = { + remote: os.path.expanduser(local) + for remote, local in file_mounts.items() + } self.setup_cmds = setup_cmds self.runtime_hash = runtime_hash if redirect_output: @@ -73,7 +80,7 @@ class NodeUpdater(object): "See {} for remote logs.".format(error_str, self.output_name), file=self.stdout) self.provider.set_node_tags(self.node_id, - {TAG_RAY_NODE_STATUS: "UpdateFailed"}) + {TAG_RAY_NODE_STATUS: "update-failed"}) if self.logfile is not None: print("----- BEGIN REMOTE LOGS -----\n" + open(self.logfile.name).read() + @@ -81,7 +88,7 @@ class NodeUpdater(object): raise e 
self.provider.set_node_tags( self.node_id, { - TAG_RAY_NODE_STATUS: "Up-to-date", + TAG_RAY_NODE_STATUS: "up-to-date", TAG_RAY_RUNTIME_CONFIG: self.runtime_hash }) print( @@ -91,7 +98,7 @@ class NodeUpdater(object): def do_update(self): self.provider.set_node_tags(self.node_id, - {TAG_RAY_NODE_STATUS: "WaitingForSSH"}) + {TAG_RAY_NODE_STATUS: "waiting-for-ssh"}) deadline = time.time() + NODE_START_WAIT_S # Wait for external IP @@ -130,20 +137,20 @@ class NodeUpdater(object): print( "NodeUpdater: SSH not up, retrying: {}".format(retry_str), file=self.stdout) - time.sleep(5) + time.sleep(SSH_CHECK_INTERVAL) else: break assert ssh_ok, "Unable to SSH to node" # Rsync file mounts self.provider.set_node_tags(self.node_id, - {TAG_RAY_NODE_STATUS: "SyncingFiles"}) + {TAG_RAY_NODE_STATUS: "syncing-files"}) for remote_path, local_path in self.file_mounts.items(): print( "NodeUpdater: Syncing {} to {}...".format( local_path, remote_path), file=self.stdout) - assert os.path.exists(local_path) + assert os.path.exists(local_path), local_path if os.path.isdir(local_path): if not local_path.endswith("/"): local_path += "/" @@ -162,7 +169,7 @@ class NodeUpdater(object): # Run init commands self.provider.set_node_tags(self.node_id, - {TAG_RAY_NODE_STATUS: "SettingUp"}) + {TAG_RAY_NODE_STATUS: "setting-up"}) for cmd in self.setup_cmds: self.ssh_cmd(cmd, verbose=True) @@ -172,14 +179,13 @@ class NodeUpdater(object): "NodeUpdater: running {} on {}...".format( pretty_cmd(cmd), self.ssh_ip), file=self.stdout) - force_interactive = "set -i && source ~/.bashrc && " + force_interactive = "set -i || true && source ~/.bashrc && " self.process_runner.check_call( [ "ssh", "-o", "ConnectTimeout={}s".format(connect_timeout), - "-o", "StrictHostKeyChecking=no", - "-i", self.ssh_private_key, "{}@{}".format( - self.ssh_user, self.ssh_ip), "bash --login -c {}".format( - pipes.quote(force_interactive + cmd)) + "-o", "StrictHostKeyChecking=no", "-i", self.ssh_private_key, + 
"{}@{}".format(self.ssh_user, self.ssh_ip), + "bash --login -c {}".format(quote(force_interactive + cmd)) ], stdout=redirect or self.stdout, stderr=redirect or self.stderr) diff --git a/python/ray/tune/log_sync.py b/python/ray/tune/log_sync.py index d47d0e4e2..6b71b4303 100644 --- a/python/ray/tune/log_sync.py +++ b/python/ray/tune/log_sync.py @@ -4,10 +4,14 @@ from __future__ import print_function import distutils.spawn import os -import pipes import subprocess import time +try: # py3 + from shlex import quote +except ImportError: # py2 + from pipes import quote + import ray from ray.tune.cluster_info import get_ssh_key, get_ssh_user from ray.tune.error import TuneError @@ -16,14 +20,29 @@ from ray.tune.result import DEFAULT_RESULTS_DIR # Map from (logdir, remote_dir) -> syncer _syncers = {} +S3_PREFIX = "s3://" +GCS_PREFIX = "gs://" +ALLOWED_REMOTE_PREFIXES = (S3_PREFIX, GCS_PREFIX) + def get_syncer(local_dir, remote_dir=None): if remote_dir: - if not remote_dir.startswith("s3://"): - raise TuneError("Upload uri must start with s3://") + if not any( + remote_dir.startswith(prefix) + for prefix in ALLOWED_REMOTE_PREFIXES): + raise TuneError("Upload uri must start with one of: {}" + "".format(ALLOWED_REMOTE_PREFIXES)) - if not distutils.spawn.find_executable("aws"): - raise TuneError("Upload uri requires awscli tool to be installed") + if (remote_dir.startswith(S3_PREFIX) + and not distutils.spawn.find_executable("aws")): + raise TuneError( + "Upload uri starting with '{}' requires awscli tool" + " to be installed".format(S3_PREFIX)) + elif (remote_dir.startswith(GCS_PREFIX) + and not distutils.spawn.find_executable("gsutil")): + raise TuneError( + "Upload uri starting with '{}' requires gsutil tool" + " to be installed".format(GCS_PREFIX)) if local_dir.startswith(DEFAULT_RESULTS_DIR + "/"): rel_path = os.path.relpath(local_dir, DEFAULT_RESULTS_DIR) @@ -85,14 +104,18 @@ class _LogSyncer(object): print("Error: log sync requires rsync to be installed.") return 
worker_to_local_sync_cmd = (( - """rsync -avz -e "ssh -i '{}' -o ConnectTimeout=120s """ + """rsync -avz -e "ssh -i {} -o ConnectTimeout=120s """ """-o StrictHostKeyChecking=no" '{}@{}:{}/' '{}/'""").format( - ssh_key, ssh_user, self.worker_ip, - pipes.quote(self.local_dir), pipes.quote(self.local_dir))) + quote(ssh_key), ssh_user, self.worker_ip, + quote(self.local_dir), quote(self.local_dir))) if self.remote_dir: - local_to_remote_sync_cmd = ("aws s3 sync '{}' '{}'".format( - pipes.quote(self.local_dir), pipes.quote(self.remote_dir))) + if self.remote_dir.startswith(S3_PREFIX): + local_to_remote_sync_cmd = ("aws s3 sync {} {}".format( + quote(self.local_dir), quote(self.remote_dir))) + elif self.remote_dir.startswith(GCS_PREFIX): + local_to_remote_sync_cmd = ("gsutil rsync -r {} {}".format( + quote(self.local_dir), quote(self.remote_dir))) else: local_to_remote_sync_cmd = None diff --git a/python/setup.py b/python/setup.py index 565880682..94a62666e 100644 --- a/python/setup.py +++ b/python/setup.py @@ -40,7 +40,10 @@ ray_ui_files = [ "ray/core/src/catapult_files/trace_viewer_full.html" ] -ray_autoscaler_files = ["ray/autoscaler/aws/example-full.yaml"] +ray_autoscaler_files = [ + "ray/autoscaler/aws/example-full.yaml", + "ray/autoscaler/gcp/example-full.yaml", +] if "RAY_USE_NEW_GCS" in os.environ and os.environ["RAY_USE_NEW_GCS"] == "on": ray_files += [ diff --git a/test/autoscaler_test.py b/test/autoscaler_test.py index c56bd3ce5..ea317aac8 100644 --- a/test/autoscaler_test.py +++ b/test/autoscaler_test.py @@ -251,7 +251,7 @@ class AutoscalingTest(unittest.TestCase): config["max_workers"] = 5 config_path = self.write_config(config) self.provider = MockProvider() - self.provider.create_node({}, {TAG_RAY_NODE_TYPE: "Worker"}, 10) + self.provider.create_node({}, {TAG_RAY_NODE_TYPE: "worker"}, 10) autoscaler = StandardAutoscaler( config_path, LoadMetrics(), max_failures=0, update_interval_s=0) self.assertEqual(len(self.provider.nodes({})), 10) @@ -398,12 +398,12 @@ 
class AutoscalingTest(unittest.TestCase): node.state = "running" assert len( self.provider.nodes({ - TAG_RAY_NODE_STATUS: "Uninitialized" + TAG_RAY_NODE_STATUS: "uninitialized" })) == 2 autoscaler.update() self.waitFor( lambda: len(self.provider.nodes( - {TAG_RAY_NODE_STATUS: "Up-to-date"})) == 2) + {TAG_RAY_NODE_STATUS: "up-to-date"})) == 2) def testReportsConfigFailures(self): config_path = self.write_config(SMALL_CLUSTER) @@ -424,12 +424,12 @@ class AutoscalingTest(unittest.TestCase): node.state = "running" assert len( self.provider.nodes({ - TAG_RAY_NODE_STATUS: "Uninitialized" + TAG_RAY_NODE_STATUS: "uninitialized" })) == 2 autoscaler.update() self.waitFor( lambda: len(self.provider.nodes( - {TAG_RAY_NODE_STATUS: "UpdateFailed"})) == 2) + {TAG_RAY_NODE_STATUS: "update-failed"})) == 2) def testConfiguresOutdatedNodes(self): config_path = self.write_config(SMALL_CLUSTER) @@ -451,7 +451,7 @@ class AutoscalingTest(unittest.TestCase): autoscaler.update() self.waitFor( lambda: len(self.provider.nodes( - {TAG_RAY_NODE_STATUS: "Up-to-date"})) == 2) + {TAG_RAY_NODE_STATUS: "up-to-date"})) == 2) runner.calls = [] new_config = SMALL_CLUSTER.copy() new_config["worker_setup_commands"] = ["cmdX", "cmdY"] @@ -520,7 +520,7 @@ class AutoscalingTest(unittest.TestCase): autoscaler.update() self.waitFor( lambda: len(self.provider.nodes( - {TAG_RAY_NODE_STATUS: "Up-to-date"})) == 2) + {TAG_RAY_NODE_STATUS: "up-to-date"})) == 2) # Mark a node as unhealthy lm.last_heartbeat_time_by_ip["172.0.0.0"] = 0