From 315edab08508e7b4ca07ce22467d76dad4031a89 Mon Sep 17 00:00:00 2001 From: Daniel Edgecumbe <45787862+ls-daniel@users.noreply.github.com> Date: Fri, 1 Feb 2019 10:46:32 +0000 Subject: [PATCH] [autoscaler] Speedups (#3720) - NodeUpdater gets its IP in parallel now (no longer in __init__) - We use persistent connections in SSH (temp folder created only for ray; ControlMaster) - hash_runtime_conf was performing a pointless hexlify step, wasting time on large files - We use NodeUpdaterThreads and share the NodeProvider; NodeUpdaterProcess is removed - AWSNodeProvider caches nodes more aggressively - NodeProvider now has a shim batch terminate_nodes() call; AWSNodeProvider parallelises it; the autoscaler uses it - AWSNodeProvider batches EC2 update_tags calls - Logging changes throughout to provide standardised timing information for profiling - Pulled out a few unnecessary is_running calls (NodeUpdater will loop waiting for SSH anyway) ## Related issue number Issue #3599 --- python/ray/autoscaler/autoscaler.py | 178 ++++++------ python/ray/autoscaler/aws/config.py | 41 +-- python/ray/autoscaler/aws/node_provider.py | 135 +++++++--- python/ray/autoscaler/commands.py | 120 ++++++--- python/ray/autoscaler/docker.py | 3 +- python/ray/autoscaler/gcp/config.py | 30 ++- python/ray/autoscaler/gcp/node_provider.py | 14 +- python/ray/autoscaler/local/node_provider.py | 10 +- python/ray/autoscaler/log_timer.py | 21 ++ python/ray/autoscaler/node_provider.py | 14 + python/ray/autoscaler/updater.py | 269 +++++++++++-------- python/ray/monitor.py | 31 ++- test/autoscaler_test.py | 23 +- 13 files changed, 545 insertions(+), 344 deletions(-) create mode 100644 python/ray/autoscaler/log_timer.py diff --git a/python/ray/autoscaler/autoscaler.py b/python/ray/autoscaler/autoscaler.py index da0f4eab3..9c5c37976 100644 --- a/python/ray/autoscaler/autoscaler.py +++ b/python/ray/autoscaler/autoscaler.py @@ -2,21 +2,19 @@ from __future__ import absolute_import from __future__ import division 
from __future__ import print_function -import binascii import copy import json import hashlib -import logging import math import os from six import string_types from six.moves import queue import subprocess import threading +import logging import time from collections import defaultdict -from datetime import datetime import numpy as np import yaml @@ -26,11 +24,12 @@ from ray.ray_constants import AUTOSCALER_MAX_NUM_FAILURES, \ AUTOSCALER_UPDATE_INTERVAL_S, AUTOSCALER_HEARTBEAT_TIMEOUT_S from ray.autoscaler.node_provider import get_node_provider, \ get_default_config -from ray.autoscaler.updater import NodeUpdaterProcess +from ray.autoscaler.updater import NodeUpdaterThread from ray.autoscaler.docker import dockerize_if_needed from ray.autoscaler.tags import (TAG_RAY_LAUNCH_CONFIG, TAG_RAY_RUNTIME_CONFIG, TAG_RAY_NODE_STATUS, TAG_RAY_NODE_TYPE, TAG_RAY_NODE_NAME) + import ray.services as services logger = logging.getLogger(__name__) @@ -158,11 +157,13 @@ class LoadMetrics(object): def prune(mapping): unwanted = set(mapping) - active_ips for unwanted_key in unwanted: - logger.info("Removed mapping: {} - {}".format( - unwanted_key, mapping[unwanted_key])) + logger.info("LoadMetrics: " + "Removed mapping: {} - {}".format( + unwanted_key, mapping[unwanted_key])) del mapping[unwanted_key] if unwanted: logger.info( + "LoadMetrics: " "Removed {} stale ip mappings: {} not in {}".format( len(unwanted), unwanted, active_ips)) @@ -174,8 +175,8 @@ class LoadMetrics(object): return self._info()["NumNodesUsed"] def info_string(self): - return " - {}".format("\n - ".join( - ["{}: {}".format(k, v) for k, v in sorted(self._info().items())])) + return ", ".join( + ["{}={}".format(k, v) for k, v in sorted(self._info().items())]) def _info(self): nodes_used = 0.0 @@ -223,17 +224,13 @@ class LoadMetrics(object): class NodeLauncher(threading.Thread): - def __init__(self, queue, pending, *args, **kwargs): + def __init__(self, provider, queue, pending, *args, **kwargs): self.queue = queue 
self.pending = pending - self.provider = None + self.provider = provider super(NodeLauncher, self).__init__(*args, **kwargs) def _launch_node(self, config, count): - if self.provider is None: - self.provider = get_node_provider(config["provider"], - config["cluster_name"]) - tag_filters = {TAG_RAY_NODE_TYPE: "worker"} before = self.provider.nodes(tag_filters=tag_filters) launch_hash = hash_launch_conf(config["worker_nodes"], config["auth"]) @@ -247,7 +244,8 @@ class NodeLauncher(threading.Thread): }, count) after = self.provider.nodes(tag_filters=tag_filters) if set(after).issubset(before): - logger.error("No new nodes reported after node creation") + logger.error("NodeLauncher: " + "No new nodes reported after node creation") def run(self): while True: @@ -305,8 +303,6 @@ class StandardAutoscaler(object): max_concurrent_launches=AUTOSCALER_MAX_CONCURRENT_LAUNCHES, max_failures=AUTOSCALER_MAX_NUM_FAILURES, process_runner=subprocess, - verbose_updates=True, - node_updater_cls=NodeUpdaterProcess, update_interval_s=AUTOSCALER_UPDATE_INTERVAL_S): self.config_path = config_path self.reload_config(errors_fatal=True) @@ -317,9 +313,7 @@ class StandardAutoscaler(object): self.max_failures = max_failures self.max_launch_batch = max_launch_batch self.max_concurrent_launches = max_concurrent_launches - self.verbose_updates = verbose_updates self.process_runner = process_runner - self.node_updater_cls = node_updater_cls # Map from node_id to NodeUpdater processes self.updaters = {} @@ -337,7 +331,9 @@ class StandardAutoscaler(object): max_concurrent_launches / float(max_launch_batch)) for i in range(int(max_batches)): node_launcher = NodeLauncher( - queue=self.launch_queue, pending=self.num_launches_pending) + provider=self.provider, + queue=self.launch_queue, + pending=self.num_launches_pending) node_launcher.daemon = True node_launcher.start() @@ -359,11 +355,12 @@ class StandardAutoscaler(object): self.reload_config(errors_fatal=False) self._update() except Exception as e: 
- logger.exception("Error during autoscaling.") + logger.exception("StandardAutoscaler: " + "Error during autoscaling.") self.num_failures += 1 if self.num_failures > self.max_failures: - logger.critical( - "*** StandardAutoscaler: Too many errors, abort. ***") + logger.critical("StandardAutoscaler: " + "Too many errors, abort.") raise e def _update(self): @@ -377,7 +374,7 @@ class StandardAutoscaler(object): self.last_update_time = now num_pending = self.num_launches_pending.value nodes = self.workers() - logger.info(self.info_string(nodes)) + self.log_info_string(nodes) self.load_metrics.prune_active_ips( [self.provider.internal_ip(node_id) for node_id in nodes]) target_workers = self.target_num_workers() @@ -385,35 +382,37 @@ class StandardAutoscaler(object): # Terminate any idle or out of date nodes last_used = self.load_metrics.last_used_time_by_ip horizon = now - (60 * self.config["idle_timeout_minutes"]) - num_terminated = 0 + + nodes_to_terminate = [] for node_id in nodes: node_ip = self.provider.internal_ip(node_id) if node_ip in last_used and last_used[node_ip] < horizon and \ - len(nodes) - num_terminated > target_workers: - num_terminated += 1 - logger.info("StandardAutoscaler: Terminating idle node: " - "{}".format(node_id)) - self.provider.terminate_node(node_id) + len(nodes) - len(nodes_to_terminate) > target_workers: + logger.info("StandardAutoscaler: " + "{}: Terminating idle node".format(node_id)) + nodes_to_terminate.append(node_id) elif not self.launch_config_ok(node_id): - num_terminated += 1 - logger.info("StandardAutoscaler: Terminating outdated node: " - "{}".format(node_id)) - self.provider.terminate_node(node_id) - if num_terminated > 0: + logger.info("StandardAutoscaler: " + "{}: Terminating outdated node".format(node_id)) + nodes_to_terminate.append(node_id) + + if nodes_to_terminate: + self.provider.terminate_nodes(nodes_to_terminate) nodes = self.workers() - logger.info(self.info_string(nodes)) + self.log_info_string(nodes) # Terminate 
nodes if there are too many - num_terminated = 0 + nodes_to_terminate = [] while len(nodes) > self.config["max_workers"]: - num_terminated += 1 - logger.info("StandardAutoscaler: Terminating unneeded node: " - "{}".format(nodes[-1])) - self.provider.terminate_node(nodes[-1]) + logger.info("StandardAutoscaler: " + "{}: Terminating unneeded node".format(nodes[-1])) + nodes_to_terminate.append(nodes[-1]) nodes = nodes[:-1] - if num_terminated > 0: + + if nodes_to_terminate: + self.provider.terminate_nodes(nodes_to_terminate) nodes = self.workers() - logger.info(self.info_string(nodes)) + self.log_info_string(nodes) # Launch new nodes if needed num_workers = len(nodes) + num_pending @@ -422,7 +421,8 @@ class StandardAutoscaler(object): self.max_concurrent_launches - num_pending) num_launches = min(max_allowed, target_workers - num_workers) self.launch_new_node(num_launches) - logger.info(self.info_string()) + nodes = self.workers() + self.log_info_string(nodes) else: self.bringup = False @@ -442,11 +442,21 @@ class StandardAutoscaler(object): # immediately trying to restart Ray on the new node. 
self.load_metrics.mark_active(self.provider.internal_ip(node_id)) nodes = self.workers() - logger.info(self.info_string(nodes)) + self.log_info_string(nodes) # Update nodes with out-of-date files - for node_id in nodes: - self.update_if_needed(node_id) + T = [ + threading.Thread( + target=self.spawn_updater, + args=(node_id, commands), + ) for node_id, commands in (self.should_update(node_id) + for node_id in nodes) + if node_id is not None + ] + for t in T: + t.start() + for t in T: + t.join() # Attempt to recover unhealthy nodes for node_id in nodes: @@ -471,7 +481,8 @@ class StandardAutoscaler(object): if errors_fatal: raise e else: - logger.exception("StandardAutoscaler: Error parsing config.") + logger.exception("StandardAutoscaler: " + "Error parsing config.") def target_num_workers(self): initial_workers = self.config["initial_workers"] @@ -496,9 +507,9 @@ class StandardAutoscaler(object): def files_up_to_date(self, node_id): applied = self.provider.node_tags(node_id).get(TAG_RAY_RUNTIME_CONFIG) if applied != self.runtime_hash: - logger.info( - "StandardAutoscaler: {} has runtime state {}, want {}".format( - node_id, applied, self.runtime_hash)) + logger.info("StandardAutoscaler: " + "{}: Runtime state is {}, want {}".format( + node_id, applied, self.runtime_hash)) return False return True @@ -512,27 +523,29 @@ class StandardAutoscaler(object): delta = now - last_heartbeat_time if delta < AUTOSCALER_HEARTBEAT_TIMEOUT_S: return - logger.warning("StandardAutoscaler: No heartbeat from node " - "{} in {} seconds, restarting Ray to recover...".format( - node_id, delta)) - updater = self.node_updater_cls( + logger.warning("StandardAutoscaler: " + "{}: No heartbeat in {}s, " + "restarting Ray to recover...".format(node_id, delta)) + updater = NodeUpdaterThread( node_id, self.config["provider"], + self.provider, self.config["auth"], self.config["cluster_name"], {}, with_head_node_ip(self.config["worker_start_ray_commands"]), self.runtime_hash, - redirect_output=not 
self.verbose_updates, process_runner=self.process_runner, use_internal_ip=True) updater.start() self.updaters[node_id] = updater - def update_if_needed(self, node_id): + def should_update(self, node_id): if not self.can_update(node_id): - return + return (None, None) + if self.files_up_to_date(node_id): - return + return (None, None) + successful_updated = self.num_successful_updates.get(node_id, 0) > 0 if successful_updated and self.config.get("restart_only", False): init_commands = self.config["worker_start_ray_commands"] @@ -544,33 +557,35 @@ class StandardAutoscaler(object): self.config["worker_setup_commands"] + self.config["worker_start_ray_commands"]) - updater = self.node_updater_cls( + return (node_id, init_commands) + + def spawn_updater(self, node_id, init_commands): + updater = NodeUpdaterThread( node_id, self.config["provider"], + self.provider, self.config["auth"], self.config["cluster_name"], self.config["file_mounts"], with_head_node_ip(init_commands), self.runtime_hash, - redirect_output=not self.verbose_updates, process_runner=self.process_runner, use_internal_ip=True) updater.start() self.updaters[node_id] = updater def can_update(self, node_id): - if not self.provider.is_running(node_id): + if node_id in self.updaters: return False if not self.launch_config_ok(node_id): return False - if node_id in self.updaters: - return False if self.num_failed_updates.get(node_id, 0) > 0: # TODO(ekl) retry? 
return False return True def launch_new_node(self, count): - logger.info("StandardAutoscaler: Launching {} new nodes".format(count)) + logger.info("StandardAutoscaler: " + "Launching {} new nodes".format(count)) self.num_launches_pending.inc(count) config = copy.deepcopy(self.config) self.launch_queue.put((config, count)) @@ -578,9 +593,11 @@ class StandardAutoscaler(object): def workers(self): return self.provider.nodes(tag_filters={TAG_RAY_NODE_TYPE: "worker"}) - def info_string(self, nodes=None): - if nodes is None: - nodes = self.workers() + def log_info_string(self, nodes): + logger.info("StandardAutoscaler: {}".format(self.info_string(nodes))) + logger.info("LoadMetrics: {}".format(self.load_metrics.info_string())) + + def info_string(self, nodes): suffix = "" if self.num_launches_pending: suffix += " ({} pending)".format(self.num_launches_pending.value) @@ -591,9 +608,9 @@ class StandardAutoscaler(object): len(self.num_failed_updates)) if self.bringup: suffix += " (bringup=True)" - return "StandardAutoscaler [{}]: {}/{} target nodes{}\n{}".format( - datetime.now(), len(nodes), self.target_num_workers(), suffix, - self.load_metrics.info_string()) + + return "{}/{} target nodes{}".format( + len(nodes), self.target_num_workers(), suffix) def typename(v): @@ -685,6 +702,11 @@ def hash_runtime_conf(file_mounts, extra_objs): hasher = hashlib.sha1() def add_content_hashes(path): + def add_hash_of_file(fpath): + with open(fpath, "rb") as f: + for chunk in iter(lambda: f.read(2**20), b''): + hasher.update(chunk) + path = os.path.expanduser(path) if os.path.isdir(path): dirs = [] @@ -694,18 +716,12 @@ def hash_runtime_conf(file_mounts, extra_objs): hasher.update(dirpath.encode("utf-8")) for name in filenames: hasher.update(name.encode("utf-8")) - with open(os.path.join(dirpath, name), "rb") as f: - if os.path.getsize(os.path.join(dirpath, - name)) < 1000000000: - hasher.update(binascii.hexlify(f.read())) - else: - for chunk in iter(lambda: f.read(8192), b''): - 
hasher.update(binascii.hexlify(chunk)) + fpath = os.path.join(dirpath, name) + add_hash_of_file(fpath) else: - with open(path, "rb") as f: - hasher.update(binascii.hexlify(f.read())) + add_hash_of_file(path) - conf_str = (json.dumps(sorted(file_mounts.items())).encode("utf-8") + + conf_str = (json.dumps(file_mounts, sort_keys=True).encode("utf-8") + json.dumps(extra_objs, sort_keys=True).encode("utf-8")) # Important: only hash the files once. Otherwise, we can end up restarting diff --git a/python/ray/autoscaler/aws/config.py b/python/ray/autoscaler/aws/config.py index c0493df0d..22268ea27 100644 --- a/python/ray/autoscaler/aws/config.py +++ b/python/ray/autoscaler/aws/config.py @@ -4,9 +4,9 @@ from __future__ import print_function from distutils.version import StrictVersion import json -import logging import os import time +import logging import boto3 from botocore.config import Config @@ -14,6 +14,8 @@ import botocore from ray.ray_constants import BOTO_MAX_RETRIES +logger = logging.getLogger(__name__) + RAY = "ray-autoscaler" DEFAULT_RAY_INSTANCE_PROFILE = RAY + "-v1" DEFAULT_RAY_IAM_ROLE = RAY + "-v1" @@ -34,7 +36,6 @@ def key_pair(i, region): # Suppress excessive connection dropped logs from boto logging.getLogger("botocore").setLevel(logging.WARNING) -logger = logging.getLogger(__name__) def bootstrap_aws(config): @@ -62,8 +63,9 @@ def _configure_iam_role(config): profile = _get_instance_profile(DEFAULT_RAY_INSTANCE_PROFILE, config) if profile is None: - logger.info("Creating new instance profile {}".format( - DEFAULT_RAY_INSTANCE_PROFILE)) + logger.info("_configure_iam_role: " + "Creating new instance profile {}".format( + DEFAULT_RAY_INSTANCE_PROFILE)) client = _client("iam", config) client.create_instance_profile( InstanceProfileName=DEFAULT_RAY_INSTANCE_PROFILE) @@ -75,7 +77,8 @@ def _configure_iam_role(config): if not profile.roles: role = _get_role(DEFAULT_RAY_IAM_ROLE, config) if role is None: - logger.info("Creating new role 
{}".format(DEFAULT_RAY_IAM_ROLE)) + logger.info("_configure_iam_role: " + "Creating new role {}".format(DEFAULT_RAY_IAM_ROLE)) iam = _resource("iam", config) iam.create_role( RoleName=DEFAULT_RAY_IAM_ROLE, @@ -99,8 +102,9 @@ def _configure_iam_role(config): profile.add_role(RoleName=role.name) time.sleep(15) # wait for propagation - logger.info("Role not specified for head node, using {}".format( - profile.arn)) + logger.info("_configure_iam_role: " + "Role not specified for head node, using {}".format( + profile.arn)) config["head_node"]["IamInstanceProfile"] = {"Arn": profile.arn} return config @@ -126,7 +130,8 @@ def _configure_key_pair(config): # We can safely create a new key. if not key and not os.path.exists(key_path): - logger.info("Creating new key pair {}".format(key_name)) + logger.info("_configure_key_pair: " + "Creating new key pair {}".format(key_name)) key = ec2.create_key_pair(KeyName=key_name) with open(key_path, "w") as f: f.write(key.key_material) @@ -142,7 +147,8 @@ def _configure_key_pair(config): assert os.path.exists(key_path), \ "Private key file {} not found for {}".format(key_path, key_name) - logger.info("KeyName not specified for nodes, using {}".format(key_name)) + logger.info("_configure_key_pair: " + "KeyName not specified for nodes, using {}".format(key_name)) config["auth"]["ssh_private_key"] = key_path config["head_node"]["KeyName"] = key_name @@ -174,19 +180,21 @@ def _configure_subnet(config): "No usable subnets matching availability zone {} " "found. Choose a different availability zone or try " "manually creating an instance in your specified region " - "to populate the list of subnets and trying this again." - .format(config["provider"]["availability_zone"])) + "to populate the list of subnets and trying this again.". 
+ format(config["provider"]["availability_zone"])) subnet_ids = [s.subnet_id for s in subnets] subnet_descr = [(s.subnet_id, s.availability_zone) for s in subnets] if "SubnetIds" not in config["head_node"]: config["head_node"]["SubnetIds"] = subnet_ids - logger.info("SubnetIds not specified for head node," - " using {}".format(subnet_descr)) + logger.info("_configure_subnet: " + "SubnetIds not specified for head node, using {}".format( + subnet_descr)) if "SubnetIds" not in config["worker_nodes"]: config["worker_nodes"]["SubnetIds"] = subnet_ids - logger.info("SubnetId not specified for workers," + logger.info("_configure_subnet: " + "SubnetId not specified for workers," " using {}".format(subnet_descr)) return config @@ -202,7 +210,8 @@ def _configure_security_group(config): security_group = _get_security_group(config, vpc_id, group_name) if security_group is None: - logger.info("Creating new security group {}".format(group_name)) + logger.info("_configure_security_group: " + "Creating new security group {}".format(group_name)) client = _client("ec2", config) client.create_security_group( Description="Auto-created security group for Ray workers", @@ -230,12 +239,14 @@ def _configure_security_group(config): if "SecurityGroupIds" not in config["head_node"]: logger.info( + "_configure_security_group: " "SecurityGroupIds not specified for head node, using {}".format( security_group.group_name)) config["head_node"]["SecurityGroupIds"] = [security_group.id] if "SecurityGroupIds" not in config["worker_nodes"]: logger.info( + "_configure_security_group: " "SecurityGroupIds not specified for workers, using {}".format( security_group.group_name)) config["worker_nodes"]["SecurityGroupIds"] = [security_group.id] diff --git a/python/ray/autoscaler/aws/node_provider.py b/python/ray/autoscaler/aws/node_provider.py index c756d9afd..7aedacba1 100644 --- a/python/ray/autoscaler/aws/node_provider.py +++ b/python/ray/autoscaler/aws/node_provider.py @@ -3,6 +3,8 @@ from __future__ 
import division from __future__ import print_function import random +import threading +from collections import defaultdict import boto3 from botocore.config import Config @@ -10,6 +12,7 @@ from botocore.config import Config from ray.autoscaler.node_provider import NodeProvider from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME from ray.ray_constants import BOTO_MAX_RETRIES +from ray.autoscaler.log_timer import LogTimer def to_aws_format(tags): @@ -40,15 +43,59 @@ class AWSNodeProvider(NodeProvider): # Try availability zones round-robin, starting from random offset self.subnet_idx = random.randint(0, 100) + self.tag_cache = {} # Tags that we believe to actually be on EC2. + self.tag_cache_pending = {} # Tags that we will soon upload. + self.tag_cache_lock = threading.Lock() + self.tag_cache_update_event = threading.Event() + self.tag_cache_kill_event = threading.Event() + self.tag_update_thread = threading.Thread( + target=self._node_tag_update_loop) + self.tag_update_thread.start() + # Cache of node objects from the last nodes() call. This avoids # excessive DescribeInstances requests. self.cached_nodes = {} - # Cache of ip lookups. We assume IPs never change once assigned. - self.internal_ip_cache = {} - self.external_ip_cache = {} + def _node_tag_update_loop(self): + """ Update the AWS tags for a cluster periodically. + + The purpose of this loop is to avoid excessive EC2 calls when a large + number of nodes are being launched simultaneously. 
+ """ + while True: + self.tag_cache_update_event.wait() + self.tag_cache_update_event.clear() + + batch_updates = defaultdict(list) + + with self.tag_cache_lock: + for node_id, tags in self.tag_cache_pending.items(): + for x in tags.items(): + batch_updates[x].append(node_id) + self.tag_cache[node_id].update(tags) + + self.tag_cache_pending = {} + + for (k, v), node_ids in batch_updates.items(): + m = "Set tag {}={} on {}".format(k, v, node_ids) + with LogTimer("AWSNodeProvider: {}".format(m)): + if k == TAG_RAY_NODE_NAME: + k = "Name" + self.ec2.meta.client.create_tags( + Resources=node_ids, + Tags=[{ + "Key": k, + "Value": v + }], + ) + + self.tag_cache_kill_event.wait(timeout=5) + if self.tag_cache_kill_event.is_set(): + return def nodes(self, tag_filters): + # Note that these filters are acceptable because they are set on + # node initialization, and so can never be sitting in the cache. tag_filters = to_aws_format(tag_filters) filters = [ { @@ -65,9 +112,19 @@ class AWSNodeProvider(NodeProvider): "Name": "tag:{}".format(k), "Values": [v], }) - instances = list(self.ec2.instances.filter(Filters=filters)) - self.cached_nodes = {i.id: i for i in instances} - return [i.id for i in instances] + + nodes = list(self.ec2.instances.filter(Filters=filters)) + # Populate the tag cache with initial information if necessary + for node in nodes: + if node.id in self.tag_cache: + continue + + self.tag_cache[node.id] = from_aws_format( + {x["Key"]: x["Value"] + for x in node.tags}) + + self.cached_nodes = {node.id: node for node in nodes} + return [node.id for node in nodes] def is_running(self, node_id): node = self._node(node_id) @@ -79,40 +136,25 @@ class AWSNodeProvider(NodeProvider): return state not in ["running", "pending"] def node_tags(self, node_id): - node = self._node(node_id) - tags = {} - for tag in node.tags: - tags[tag["Key"]] = tag["Value"] - return from_aws_format(tags) + with self.tag_cache_lock: + d1 = self.tag_cache[node_id] + d2 = 
self.tag_cache_pending.get(node_id, {}) + return dict(d1, **d2) def external_ip(self, node_id): - if node_id in self.external_ip_cache: - return self.external_ip_cache[node_id] - node = self._node(node_id) - ip = node.public_ip_address - if ip: - self.external_ip_cache[node_id] = ip - return ip + return self._node(node_id).public_ip_address def internal_ip(self, node_id): - if node_id in self.internal_ip_cache: - return self.internal_ip_cache[node_id] - node = self._node(node_id) - ip = node.private_ip_address - if ip: - self.internal_ip_cache[node_id] = ip - return ip + return self._node(node_id).private_ip_address def set_node_tags(self, node_id, tags): - tags = to_aws_format(tags) - node = self._node(node_id) - tag_pairs = [] - for k, v in tags.items(): - tag_pairs.append({ - "Key": k, - "Value": v, - }) - node.create_tags(Tags=tag_pairs) + with self.tag_cache_lock: + try: + self.tag_cache_pending[node_id].update(tags) + except KeyError: + self.tag_cache_pending[node_id] = tags + + self.tag_cache_update_event.set() def create_node(self, node_config, tags, count): tags = to_aws_format(tags) @@ -166,9 +208,24 @@ class AWSNodeProvider(NodeProvider): node = self._node(node_id) node.terminate() + self.tag_cache.pop(node_id, None) + self.tag_cache_pending.pop(node_id, None) + + def terminate_nodes(self, node_ids): + self.ec2.meta.client.terminate_instances(InstanceIds=node_ids) + + for node_id in node_ids: + self.tag_cache.pop(node_id, None) + self.tag_cache_pending.pop(node_id, None) + def _node(self, node_id): - if node_id in self.cached_nodes: - return self.cached_nodes[node_id] - matches = list(self.ec2.instances.filter(InstanceIds=[node_id])) - assert len(matches) == 1, "Invalid instance id {}".format(node_id) - return matches[0] + if node_id not in self.cached_nodes: + self.nodes({}) # Side effect: should cache it. 
+ + assert node_id in self.cached_nodes, "Invalid instance id {}".format( + node_id) + return self.cached_nodes[node_id] + + def cleanup(self): + self.tag_cache_update_event.set() + self.tag_cache_kill_event.set() diff --git a/python/ray/autoscaler/commands.py b/python/ray/autoscaler/commands.py index 37f286248..350a6250c 100644 --- a/python/ray/autoscaler/commands.py +++ b/python/ray/autoscaler/commands.py @@ -8,9 +8,9 @@ import json import os import tempfile import time +import logging import sys import click -import logging import random import yaml @@ -24,7 +24,8 @@ from ray.autoscaler.autoscaler import validate_config, hash_runtime_conf, \ from ray.autoscaler.node_provider import get_node_provider, NODE_PROVIDERS from ray.autoscaler.tags import TAG_RAY_NODE_TYPE, TAG_RAY_LAUNCH_CONFIG, \ TAG_RAY_NODE_NAME -from ray.autoscaler.updater import NodeUpdaterProcess +from ray.autoscaler.updater import NodeUpdaterThread +from ray.autoscaler.log_timer import LogTimer logger = logging.getLogger(__name__) @@ -81,18 +82,35 @@ def teardown_cluster(config_file, yes, workers_only, override_cluster_name): provider = get_node_provider(config["provider"], config["cluster_name"]) - if not workers_only: - for node in provider.nodes({TAG_RAY_NODE_TYPE: "head"}): - logger.info("Terminating head node {}".format(node)) - provider.terminate_node(node) + def remaining_nodes(): + if workers_only: + A = [] + else: + A = [ + node_id for node_id in provider.nodes({ + TAG_RAY_NODE_TYPE: "head" + }) + ] - nodes = provider.nodes({TAG_RAY_NODE_TYPE: "worker"}) - while nodes: - for node in nodes: - logger.info("Terminating worker {}".format(node)) - provider.terminate_node(node) - time.sleep(5) - nodes = provider.nodes({TAG_RAY_NODE_TYPE: "worker"}) + A += [ + node_id for node_id in provider.nodes({ + TAG_RAY_NODE_TYPE: "worker" + }) + ] + return A + + # Loop here to check that both the head and worker nodes are actually + # really gone + A = remaining_nodes() + with LogTimer("teardown_cluster: 
Termination done."): + while A: + logger.info("teardown_cluster: " + "Terminating {} nodes...".format(len(A))) + provider.terminate_nodes(A) + time.sleep(1) + A = remaining_nodes() + + provider.cleanup() def kill_node(config_file, yes, override_cluster_name): @@ -108,26 +126,26 @@ def kill_node(config_file, yes, override_cluster_name): provider = get_node_provider(config["provider"], config["cluster_name"]) nodes = provider.nodes({TAG_RAY_NODE_TYPE: "worker"}) node = random.choice(nodes) - logger.info("Terminating worker {}".format(node)) - updater = NodeUpdaterProcess( - node, - config["provider"], - config["auth"], - config["cluster_name"], - config["file_mounts"], [], - "", - redirect_output=False) + logger.info("kill_node: Terminating worker {}".format(node)) + + updater = NodeUpdaterThread(node, config["provider"], provider, + config["auth"], config["cluster_name"], + config["file_mounts"], [], "") _exec(updater, "ray stop", False, False) time.sleep(5) + + iip = provider.internal_ip(node) + if iip: + return iip + return provider.external_ip(node) def get_or_create_head_node(config, config_file, no_restart, restart_only, yes, override_cluster_name): """Create the cluster head node, which in turn creates the workers.""" - provider = get_node_provider(config["provider"], config["cluster_name"]) head_node_tags = { TAG_RAY_NODE_TYPE: "head", @@ -148,9 +166,10 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes, TAG_RAY_LAUNCH_CONFIG) != launch_hash: if head_node is not None: confirm("Head node config out-of-date. 
It will be terminated", yes) - logger.info("Terminating outdated head node {}".format(head_node)) + logger.info("get_or_create_head_node: " + "Terminating outdated head node {}".format(head_node)) provider.terminate_node(head_node) - logger.info("Launching new head node...") + logger.info("get_or_create_head_node: Launching new head node...") head_node_tags[TAG_RAY_LAUNCH_CONFIG] = launch_hash head_node_tags[TAG_RAY_NODE_NAME] = "ray-{}-head".format( config["cluster_name"]) @@ -163,7 +182,7 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes, # TODO(ekl) right now we always update the head node even if the hash # matches. We could prompt the user for what they want to do in this case. runtime_hash = hash_runtime_conf(config["file_mounts"], config) - logger.info("Updating files on head node...") + logger.info("get_or_create_head_node: Updating files on head node...") # Rewrite the auth config so that the head node can update the workers remote_key_path = "~/ray_bootstrap_key.pem" @@ -197,15 +216,16 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes, config["setup_commands"] + config["head_setup_commands"] + config["head_start_ray_commands"]) - updater = NodeUpdaterProcess( + updater = NodeUpdaterThread( head_node, config["provider"], + provider, config["auth"], config["cluster_name"], config["file_mounts"], init_commands, runtime_hash, - redirect_output=False) + ) updater.start() updater.join() @@ -213,11 +233,13 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes, provider.nodes(head_node_tags) if updater.exitcode != 0: - logger.error("Updating {} failed".format( - provider.external_ip(head_node))) + logger.error("get_or_create_head_node: " + "Updating {} failed".format( + provider.external_ip(head_node))) sys.exit(1) - logger.info("Head node up-to-date, IP address is: {}".format( - provider.external_ip(head_node))) + logger.info("get_or_create_head_node: " + "Head node 
up-to-date, IP address is: {}".format( + provider.external_ip(head_node))) monitor_str = "tail -n 100 -f /tmp/ray/session_*/logs/monitor*" for s in init_commands: @@ -239,6 +261,8 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes, config["auth"]["ssh_user"], provider.external_ip(head_node))) + provider.cleanup() + def attach_cluster(config_file, start, use_tmux, override_cluster_name, new): """Attaches to a screen for the specified cluster. @@ -288,14 +312,18 @@ def exec_cluster(config_file, cmd, screen, tmux, stop, start, config = _bootstrap_config(config) head_node = _get_head_node( config, config_file, override_cluster_name, create_if_needed=start) - updater = NodeUpdaterProcess( + + provider = get_node_provider(config["provider"], config["cluster_name"]) + updater = NodeUpdaterThread( head_node, config["provider"], + provider, config["auth"], config["cluster_name"], - config["file_mounts"], [], + config["file_mounts"], + [], "", - redirect_output=False) + ) if stop: cmd += ("; ray stop; ray teardown ~/ray_bootstrap_config.yaml --yes " "--workers-only; sudo shutdown -h now") @@ -322,6 +350,8 @@ def exec_cluster(config_file, cmd, screen, tmux, stop, start, attach_command) logger.info(attach_info) + provider.cleanup() + def _exec(updater, cmd, screen, tmux, expect_error=False, port_forward=None): if cmd: @@ -363,20 +393,26 @@ def rsync(config_file, source, target, override_cluster_name, down): config = _bootstrap_config(config) head_node = _get_head_node( config, config_file, override_cluster_name, create_if_needed=False) - updater = NodeUpdaterProcess( + + provider = get_node_provider(config["provider"], config["cluster_name"]) + updater = NodeUpdaterThread( head_node, config["provider"], + provider, config["auth"], config["cluster_name"], - config["file_mounts"], [], + config["file_mounts"], + [], "", - redirect_output=False) + ) if down: rsync = updater.rsync_down else: rsync = updater.rsync_up rsync(source, target, 
check_error=False) + provider.cleanup() + def get_head_node_ip(config_file, override_cluster_name): """Returns head node IP for given configuration file if exists.""" @@ -384,9 +420,13 @@ def get_head_node_ip(config_file, override_cluster_name): config = yaml.load(open(config_file).read()) if override_cluster_name is not None: config["cluster_name"] = override_cluster_name + provider = get_node_provider(config["provider"], config["cluster_name"]) head_node = _get_head_node(config, config_file, override_cluster_name) - return provider.external_ip(head_node) + ip = provider.external_ip(head_node) + provider.cleanup() + + return ip def get_worker_node_ips(config_file, override_cluster_name): @@ -409,6 +449,8 @@ def _get_head_node(config, TAG_RAY_NODE_TYPE: "head", } nodes = provider.nodes(head_node_tags) + provider.cleanup() + if len(nodes) > 0: head_node = nodes[0] return head_node diff --git a/python/ray/autoscaler/docker.py b/python/ray/autoscaler/docker.py index 7ec121ae4..6db0fedd1 100644 --- a/python/ray/autoscaler/docker.py +++ b/python/ray/autoscaler/docker.py @@ -2,8 +2,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import logging import os +import logging try: # py3 from shlex import quote except ImportError: # py2 @@ -20,6 +20,7 @@ def dockerize_if_needed(config): if not docker_image: if cname: logger.warning( + "dockerize_if_needed: " "Container name given but no Docker image - continuing...") return config else: diff --git a/python/ray/autoscaler/gcp/config.py b/python/ray/autoscaler/gcp/config.py index 619c493e7..0a51e6a11 100644 --- a/python/ray/autoscaler/gcp/config.py +++ b/python/ray/autoscaler/gcp/config.py @@ -11,6 +11,8 @@ from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.backends import default_backend from googleapiclient import discovery, errors +logger = logging.getLogger(__name__) + crm = discovery.build("cloudresourcemanager", "v1") iam = 
discovery.build("iam", "v1") compute = discovery.build("compute", "v1") @@ -30,12 +32,11 @@ DEFAULT_SERVICE_ACCOUNT_ROLES = ("roles/storage.objectAdmin", MAX_POLLS = 12 POLL_INTERVAL = 5 -logger = logging.getLogger(__name__) - def wait_for_crm_operation(operation): """Poll for cloud resource manager operation until finished.""" - logger.info("Waiting for operation {} to finish...".format(operation)) + logger.info("wait_for_crm_operation: " + "Waiting for operation {} to finish...".format(operation)) for _ in range(MAX_POLLS): result = crm.operations().get(name=operation["name"]).execute() @@ -43,7 +44,7 @@ def wait_for_crm_operation(operation): raise Exception(result["error"]) if "done" in result and result["done"]: - logger.info("Done.") + logger.info("wait_for_crm_operation: Operation done.") break time.sleep(POLL_INTERVAL) @@ -53,8 +54,9 @@ def wait_for_crm_operation(operation): def wait_for_compute_global_operation(project_name, operation): """Poll for global compute operation until finished.""" - logger.info("Waiting for operation {} to finish...".format( - operation["name"])) + logger.info("wait_for_compute_global_operation: " + "Waiting for operation {} to finish...".format( + operation["name"])) for _ in range(MAX_POLLS): result = compute.globalOperations().get( @@ -65,7 +67,8 @@ def wait_for_compute_global_operation(project_name, operation): raise Exception(result["error"]) if result["status"] == "DONE": - logger.info("Done.") + logger.info("wait_for_compute_global_operation: " + "Operation done.") break time.sleep(POLL_INTERVAL) @@ -158,8 +161,9 @@ def _configure_iam_role(config): service_account = _get_service_account(email, config) if service_account is None: - logger.info("Creating new service account {}".format( - DEFAULT_SERVICE_ACCOUNT_ID)) + logger.info("_configure_iam_role: " + "Creating new service account {}".format( + DEFAULT_SERVICE_ACCOUNT_ID)) service_account = _create_service_account( DEFAULT_SERVICE_ACCOUNT_ID, 
DEFAULT_SERVICE_ACCOUNT_CONFIG, config) @@ -231,7 +235,8 @@ def _configure_key_pair(config): # Create a key since it doesn't exist locally or in GCP if not key_found and not os.path.exists(private_key_path): - logger.info("Creating new key pair {}".format(key_name)) + logger.info("_configure_key_pair: " + "Creating new key pair {}".format(key_name)) public_key, private_key = generate_rsa_key_pair() _create_project_ssh_key_pair(project, public_key, ssh_user) @@ -256,8 +261,9 @@ def _configure_key_pair(config): "Private key file {} not found for user {}" "".format(private_key_path, ssh_user)) - logger.info("Private key not specified in config, using {}" - "".format(private_key_path)) + logger.info("_configure_key_pair: " + "Private key not specified in config, using" + "{}".format(private_key_path)) config["auth"]["ssh_private_key"] = private_key_path diff --git a/python/ray/autoscaler/gcp/node_provider.py b/python/ray/autoscaler/gcp/node_provider.py index a191dea34..b31ab1af6 100644 --- a/python/ray/autoscaler/gcp/node_provider.py +++ b/python/ray/autoscaler/gcp/node_provider.py @@ -4,24 +4,25 @@ from __future__ import print_function from uuid import uuid4 import time - import logging + from googleapiclient import discovery from ray.autoscaler.node_provider import NodeProvider from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME from ray.autoscaler.gcp.config import MAX_POLLS, POLL_INTERVAL +logger = logging.getLogger(__name__) + INSTANCE_NAME_MAX_LEN = 64 INSTANCE_NAME_UUID_LEN = 8 -logger = logging.getLogger(__name__) - def wait_for_compute_zone_operation(compute, project_name, operation, zone): """Poll for compute zone operation until finished.""" - logger.info("Waiting for operation {} to finish...".format( - operation["name"])) + logger.info("wait_for_compute_zone_operation: " + "Waiting for operation {} to finish...".format( + operation["name"])) for _ in range(MAX_POLLS): result = compute.zoneOperations().get( @@ -31,7 +32,8 @@ def 
wait_for_compute_zone_operation(compute, project_name, operation, zone): raise Exception(result["error"]) if result["status"] == "DONE": - logger.info("Done.") + logger.info("wait_for_compute_zone_operation: " + "Operation {} finished.".format(operation["name"])) break time.sleep(POLL_INTERVAL) diff --git a/python/ray/autoscaler/local/node_provider.py b/python/ray/autoscaler/local/node_provider.py index f20da8d22..291f2189c 100644 --- a/python/ray/autoscaler/local/node_provider.py +++ b/python/ray/autoscaler/local/node_provider.py @@ -12,6 +12,7 @@ from ray.autoscaler.node_provider import NodeProvider from ray.autoscaler.tags import TAG_RAY_NODE_TYPE logger = logging.getLogger(__name__) + filelock_logger = logging.getLogger("filelock") filelock_logger.setLevel(logging.WARNING) @@ -26,7 +27,8 @@ class ClusterState(object): workers = json.loads(open(self.save_path).read()) else: workers = {} - logger.info("Loaded cluster state: {}".format(workers)) + logger.info("ClusterState: " + "Loaded cluster state: {}".format(workers)) for worker_ip in provider_config["worker_ips"]: if worker_ip not in workers: workers[worker_ip] = { @@ -50,7 +52,8 @@ class ClusterState(object): TAG_RAY_NODE_TYPE] == "head" assert len(workers) == len(provider_config["worker_ips"]) + 1 with open(self.save_path, "w") as f: - logger.info("Writing cluster state: {}".format(workers)) + logger.info("ClusterState: " + "Writing cluster state: {}".format(workers)) f.write(json.dumps(workers)) def get(self): @@ -65,7 +68,8 @@ class ClusterState(object): workers = self.get() workers[worker_id] = info with open(self.save_path, "w") as f: - logger.info("Writing cluster state: {}".format(workers)) + logger.info("ClusterState: " + "Writing cluster state: {}".format(workers)) f.write(json.dumps(workers)) diff --git a/python/ray/autoscaler/log_timer.py b/python/ray/autoscaler/log_timer.py new file mode 100644 index 000000000..935f46b61 --- /dev/null +++ b/python/ray/autoscaler/log_timer.py @@ -0,0 +1,21 @@ +from 
__future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import datetime +import logging + +logger = logging.getLogger(__name__) + + +class LogTimer(): + def __init__(self, message): + self._message = message + + def __enter__(self): + self._start_time = datetime.datetime.utcnow() + + def __exit__(self, *_): + td = datetime.datetime.utcnow() - self._start_time + logger.info(self._message + + " [LogTimer={:.0f}ms]".format(td.total_seconds() * 1000)) diff --git a/python/ray/autoscaler/node_provider.py b/python/ray/autoscaler/node_provider.py index acbe96772..0b295ae6e 100644 --- a/python/ray/autoscaler/node_provider.py +++ b/python/ray/autoscaler/node_provider.py @@ -3,9 +3,12 @@ from __future__ import division from __future__ import print_function import importlib +import logging import os import yaml +logger = logging.getLogger(__name__) + def import_aws(): from ray.autoscaler.aws.config import bootstrap_aws @@ -174,3 +177,14 @@ class NodeProvider(object): def terminate_node(self, node_id): """Terminates the specified node.""" raise NotImplementedError + + def terminate_nodes(self, node_ids): + """Terminates a set of nodes. 
May be overridden with a batch method.""" + for node_id in node_ids: + logger.info("NodeProvider: " + "{}: Terminating node".format(node_id)) + self.terminate_node(node_id) + + def cleanup(self): + """Clean-up when a Provider is no longer required.""" + pass diff --git a/python/ray/autoscaler/updater.py b/python/ray/autoscaler/updater.py index 1d6b5e23b..b79f05909 100644 --- a/python/ray/autoscaler/updater.py +++ b/python/ray/autoscaler/updater.py @@ -10,24 +10,33 @@ import logging import os import subprocess import sys -import tempfile import time -from multiprocessing import Process from threading import Thread -from ray.autoscaler.node_provider import get_node_provider from ray.autoscaler.tags import TAG_RAY_NODE_STATUS, TAG_RAY_RUNTIME_CONFIG +from ray.autoscaler.log_timer import LogTimer + +logger = logging.getLogger(__name__) # How long to wait for a node to start, in seconds NODE_START_WAIT_S = 300 SSH_CHECK_INTERVAL = 5 - -logger = logging.getLogger(__name__) +SSH_CONTROL_PATH = "/tmp/ray_ssh_sockets" -def pretty_cmd(cmd_str): - return "\n\n\t{}\n\n".format(cmd_str) +def get_default_ssh_options(private_key, connect_timeout): + OPTS = [ + ("ConnectTimeout", "{}s".format(connect_timeout)), + ("StrictHostKeyChecking", "no"), + ("ControlMaster", "auto"), + ("ControlPath", "{}/%C".format(SSH_CONTROL_PATH)), + ("ControlPersist", "yes"), + ] + + return ["-i", private_key] + [ + x for y in (["-o", "{}={}".format(k, v)] for k, v in OPTS) for x in y + ] class NodeUpdater(object): @@ -36,12 +45,12 @@ class NodeUpdater(object): def __init__(self, node_id, provider_config, + provider, auth_config, cluster_name, file_mounts, setup_cmds, runtime_hash, - redirect_output=True, process_runner=subprocess, use_internal_ip=False): self.daemon = True @@ -49,31 +58,22 @@ class NodeUpdater(object): self.node_id = node_id self.use_internal_ip = (use_internal_ip or provider_config.get( "use_internal_ips", False)) - self.provider = get_node_provider(provider_config, cluster_name) + 
self.provider = provider self.ssh_private_key = auth_config["ssh_private_key"] self.ssh_user = auth_config["ssh_user"] - self.ssh_ip = self.get_node_ip() + self.ssh_ip = None self.file_mounts = { remote: os.path.expanduser(local) for remote, local in file_mounts.items() } self.setup_cmds = setup_cmds self.runtime_hash = runtime_hash - self.logger = logger.getChild(str(node_id)) - if redirect_output: - self.logfile = tempfile.NamedTemporaryFile( - mode="w", prefix="node-updater-", delete=False) - handler = logging.StreamHandler(stream=self.logfile) - handler.setLevel(logging.INFO) - self.logger.addHandler(handler) - self.output_name = self.logfile.name - self.stdout = self.logfile - self.stderr = self.logfile + + def get_caller(self, check_error): + if check_error: + return self.process_runner.call else: - self.logfile = None - self.output_name = "(console)" - self.stdout = sys.stdout - self.stderr = sys.stderr + return self.process_runner.check_call def get_node_ip(self): if self.use_internal_ip: @@ -81,128 +81,173 @@ class NodeUpdater(object): else: return self.provider.external_ip(self.node_id) + def wait_for_ip(self, deadline): + while time.time() < deadline and \ + not self.provider.is_terminated(self.node_id): + logger.info("NodeUpdater: " + "Waiting for IP of {}...".format(self.node_id)) + ip = self.get_node_ip() + if ip is not None: + return ip + time.sleep(10) + + return None + + def set_ssh_ip_if_required(self): + if self.ssh_ip is not None: + return + + # We assume that this never changes. + # I think that's reasonable. + deadline = time.time() + NODE_START_WAIT_S + with LogTimer("NodeUpdater: {}: Got IP".format(self.node_id)): + ip = self.wait_for_ip(deadline) + assert ip is not None, "Unable to find IP of node" + + self.ssh_ip = ip + + # This should run before any SSH commands and therefore ensure that + # the ControlPath directory exists, allowing SSH to maintain + # persistent sessions later on. 
+ with open("/dev/null", "w") as redirect: + self.get_caller(False)( + ["mkdir", "-p", SSH_CONTROL_PATH], + stdout=redirect, + stderr=redirect) + + self.get_caller(False)( + ["chmod", "0700", SSH_CONTROL_PATH], + stdout=redirect, + stderr=redirect) + def run(self): - self.logger.info( - "NodeUpdater: Updating {} to {}, logging to {}".format( - self.node_id, self.runtime_hash, self.output_name)) + logger.info("NodeUpdater: " + "{}: Updating to {}".format(self.node_id, + self.runtime_hash)) try: - self.do_update() + m = "{}: Applied config {}".format(self.node_id, self.runtime_hash) + with LogTimer("NodeUpdater: {}".format(m)): + self.do_update() except Exception as e: error_str = str(e) if hasattr(e, "cmd"): error_str = "(Exit Status {}) {}".format( - e.returncode, pretty_cmd(" ".join(e.cmd))) - self.logger.error("NodeUpdater: Error updating {}" - "See {} for remote logs.".format( - error_str, self.output_name)) + e.returncode, " ".join(e.cmd)) + logger.error("NodeUpdater: " + "{}: Error updating {}".format( + self.node_id, error_str)) self.provider.set_node_tags(self.node_id, {TAG_RAY_NODE_STATUS: "update-failed"}) - if self.logfile is not None: - self.logger.info("----- BEGIN REMOTE LOGS -----\n" + - open(self.logfile.name).read() + - "\n----- END REMOTE LOGS -----") raise e + self.provider.set_node_tags( self.node_id, { TAG_RAY_NODE_STATUS: "up-to-date", TAG_RAY_RUNTIME_CONFIG: self.runtime_hash }) - self.logger.info("NodeUpdater: Applied config {} to node {}".format( - self.runtime_hash, self.node_id)) - def do_update(self): - self.provider.set_node_tags(self.node_id, - {TAG_RAY_NODE_STATUS: "waiting-for-ssh"}) - deadline = time.time() + NODE_START_WAIT_S + self.exitcode = 0 - # Wait for external IP - while time.time() < deadline and \ - not self.provider.is_terminated(self.node_id): - self.logger.info("NodeUpdater: Waiting for IP of {}...".format( - self.node_id)) - self.ssh_ip = self.get_node_ip() - if self.ssh_ip is not None: - break - time.sleep(10) - assert 
self.ssh_ip is not None, "Unable to find IP of node" + def wait_for_ssh(self, deadline): + logger.info("NodeUpdater: " + "{}: Waiting for SSH...".format(self.node_id)) - # Wait for SSH access - ssh_ok = False while time.time() < deadline and \ not self.provider.is_terminated(self.node_id): try: - self.logger.info( - "NodeUpdater: Waiting for SSH to {}...".format( - self.node_id)) - if not self.provider.is_running(self.node_id): - raise Exception("Node not running yet...") + logger.debug("NodeUpdater: " + "{}: Waiting for SSH...".format(self.node_id)) self.ssh_cmd( "uptime", connect_timeout=5, redirect=open("/dev/null", "w")) - ssh_ok = True + return True + except Exception as e: retry_str = str(e) if hasattr(e, "cmd"): retry_str = "(Exit Status {}): {}".format( - e.returncode, pretty_cmd(" ".join(e.cmd))) - self.logger.debug( - "NodeUpdater: SSH not up, retrying: {}".format(retry_str), - ) + e.returncode, " ".join(e.cmd)) + logger.debug("NodeUpdater: " + "{}: SSH not up, retrying: {}".format( + self.node_id, retry_str)) time.sleep(SSH_CHECK_INTERVAL) - else: - break - assert ssh_ok, "Unable to SSH to node" + + return False + + def do_update(self): + self.provider.set_node_tags(self.node_id, + {TAG_RAY_NODE_STATUS: "waiting-for-ssh"}) + + deadline = time.time() + NODE_START_WAIT_S + self.set_ssh_ip_if_required() + + # Wait for SSH access + with LogTimer("NodeUpdater: " "{}: Got SSH".format(self.node_id)): + ssh_ok = self.wait_for_ssh(deadline) + assert ssh_ok, "Unable to SSH to node" # Rsync file mounts self.provider.set_node_tags(self.node_id, {TAG_RAY_NODE_STATUS: "syncing-files"}) for remote_path, local_path in self.file_mounts.items(): - self.logger.info("NodeUpdater: Syncing {} to {}...".format( - local_path, remote_path)) + logger.info("NodeUpdater: " + "{}: Syncing {} to {}...".format( + self.node_id, local_path, remote_path)) assert os.path.exists(local_path), local_path if os.path.isdir(local_path): if not local_path.endswith("/"): local_path += "/" if not 
remote_path.endswith("/"): remote_path += "/" - self.ssh_cmd("mkdir -p {}".format(os.path.dirname(remote_path))) - self.rsync_up(local_path, remote_path) + + m = "{}: Synced {} to {}".format(self.node_id, local_path, + remote_path) + with LogTimer("NodeUpdater {}".format(m)): + self.ssh_cmd( + "mkdir -p {}".format(os.path.dirname(remote_path)), + redirect=open("/dev/null", "w"), + ) + self.rsync_up( + local_path, remote_path, redirect=open("/dev/null", "w")) # Run init commands self.provider.set_node_tags(self.node_id, {TAG_RAY_NODE_STATUS: "setting-up"}) - for cmd in self.setup_cmds: - self.ssh_cmd(cmd, verbose=True) - def rsync_up(self, source, target, check_error=True): - if check_error: - call = self.process_runner.call - else: - call = self.process_runner.check_call - call( + m = "{}: Setup commands completed".format(self.node_id) + with LogTimer("NodeUpdater: {}".format(m)): + for cmd in self.setup_cmds: + self.ssh_cmd( + cmd, + # verbose=True, + redirect=open("/dev/null", "w")) + + def rsync_up(self, source, target, redirect=None, check_error=True): + self.set_ssh_ip_if_required() + self.get_caller(check_error)( [ - "rsync", "-e", "ssh -i {} ".format(self.ssh_private_key) + - "-o ConnectTimeout=120s -o StrictHostKeyChecking=no", + "rsync", "-e", + " ".join(["ssh"] + + get_default_ssh_options(self.ssh_private_key, 120)), "--delete", "-avz", source, "{}@{}:{}".format( self.ssh_user, self.ssh_ip, target) ], - stdout=self.stdout, - stderr=self.stderr) + stdout=redirect or sys.stdout, + stderr=redirect or sys.stderr) - def rsync_down(self, source, target, check_error=True): - if check_error: - call = self.process_runner.call - else: - call = self.process_runner.check_call - call( + def rsync_down(self, source, target, redirect=None, check_error=True): + self.set_ssh_ip_if_required() + self.get_caller(check_error)( [ - "rsync", "-e", "ssh -i {} ".format(self.ssh_private_key) + - "-o ConnectTimeout=120s -o StrictHostKeyChecking=no", "-avz", - 
"{}@{}:{}".format(self.ssh_user, self.ssh_ip, source), target + "rsync", "-e", + " ".join(["ssh"] + + get_default_ssh_options(self.ssh_private_key, 120)), + "-avz", "{}@{}:{}".format(self.ssh_user, self.ssh_ip, + source), target ], - stdout=self.stdout, - stderr=self.stderr) + stdout=redirect or sys.stdout, + stderr=redirect or sys.stderr) def ssh_cmd(self, cmd, @@ -213,9 +258,12 @@ class NodeUpdater(object): emulate_interactive=True, expect_error=False, port_forward=None): + + self.set_ssh_ip_if_required() + if verbose: - self.logger.info("NodeUpdater: running {} on {}...".format( - pretty_cmd(cmd), self.ssh_ip)) + logger.info("NodeUpdater: " + "Running {} on {}...".format(cmd, self.ssh_ip)) ssh = ["ssh"] if allocate_tty: ssh.append("-tt") @@ -224,35 +272,24 @@ class NodeUpdater(object): "set -i || true && source ~/.bashrc && " "export OMP_NUM_THREADS=1 PYTHONWARNINGS=ignore && ") cmd = "bash --login -c {}".format(quote(force_interactive + cmd)) - if expect_error: - call = self.process_runner.call - else: - call = self.process_runner.check_call + if port_forward is None: ssh_opt = [] else: ssh_opt = [ "-L", "{}:localhost:{}".format(port_forward, port_forward) ] - call( - ssh + ssh_opt + [ - "-o", "ConnectTimeout={}s".format(connect_timeout), "-o", - "StrictHostKeyChecking=no", "-i", self.ssh_private_key, - "{}@{}".format(self.ssh_user, self.ssh_ip), cmd - ], - stdout=redirect or self.stdout, - stderr=redirect or self.stderr) + + self.get_caller(expect_error)( + ssh + ssh_opt + get_default_ssh_options(self.ssh_private_key, + connect_timeout) + + ["{}@{}".format(self.ssh_user, self.ssh_ip), cmd], + stdout=redirect or sys.stdout, + stderr=redirect or sys.stderr) -class NodeUpdaterProcess(NodeUpdater, Process): - def __init__(self, *args, **kwargs): - Process.__init__(self) - NodeUpdater.__init__(self, *args, **kwargs) - - -# Single-threaded version for unit tests class NodeUpdaterThread(NodeUpdater, Thread): def __init__(self, *args, **kwargs): Thread.__init__(self) 
NodeUpdater.__init__(self, *args, **kwargs) - self.exitcode = 0 + self.exitcode = -1 diff --git a/python/ray/monitor.py b/python/ray/monitor.py index 0b6d05d40..238d0dc96 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -20,7 +20,6 @@ from ray.services import get_ip_address, get_port from ray.utils import (binary_to_hex, binary_to_object_id, hex_to_binary, setup_logger) -# Set up logging. logger = logging.getLogger(__name__) @@ -69,8 +68,10 @@ class Monitor(object): # that redis server. addr_port = self.redis.lrange("RedisShards", 0, -1) if len(addr_port) > 1: - logger.warning("TODO: if launching > 1 redis shard, flushing " - "needs to touch shards in parallel.") + logger.warning( + "Monitor: " + "TODO: if launching > 1 redis shard, flushing needs to " + "touch shards in parallel.") self.issue_gcs_flushes = False else: addr_port = addr_port[0].split(b":") @@ -82,6 +83,7 @@ class Monitor(object): self.redis_shard.execute_command("HEAD.FLUSH 0") except redis.exceptions.ResponseError as e: logger.info( + "Monitor: " "Turning off flushing due to exception: {}".format( str(e))) self.issue_gcs_flushes = False @@ -128,8 +130,9 @@ class Monitor(object): self.load_metrics.update(ip, static_resources, dynamic_resources) else: - print("Warning: could not find ip for client {} in {}.".format( - client_id, self.local_scheduler_id_to_ip_map)) + logger.warning( + "Monitor: " + "could not find ip for client {}".format(client_id)) def _xray_clean_up_entries_for_driver(self, driver_id): """Remove this driver's object/task entries from redis. 
@@ -185,11 +188,14 @@ class Monitor(object): continue redis = self.state.redis_clients[shard_index] num_deleted = redis.delete(*keys) - logger.info("Removed {} dead redis entries of the driver from" - " redis shard {}.".format(num_deleted, shard_index)) + logger.info("Monitor: " + "Removed {} dead redis entries of the " + "driver from redis shard {}.".format( + num_deleted, shard_index)) if num_deleted != len(keys): - logger.warning("Failed to remove {} relevant redis entries" - " from redis shard {}.".format( + logger.warning("Monitor: " + "Failed to remove {} relevant redis " + "entries from redis shard {}.".format( len(keys) - num_deleted, shard_index)) def xray_driver_removed_handler(self, unused_channel, data): @@ -205,8 +211,9 @@ class Monitor(object): message = ray.gcs_utils.DriverTableData.GetRootAsDriverTableData( driver_data, 0) driver_id = message.DriverId() - logger.info("XRay Driver {} has been removed.".format( - binary_to_hex(driver_id))) + logger.info("Monitor: " + "XRay Driver {} has been removed.".format( + binary_to_hex(driver_id))) self._xray_clean_up_entries_for_driver(driver_id) def process_messages(self, max_messages=10000): @@ -281,7 +288,7 @@ class Monitor(object): max_entries_to_flush = self.gcs_flush_policy.num_entries_to_flush() num_flushed = self.redis_shard.execute_command( "HEAD.FLUSH {}".format(max_entries_to_flush)) - logger.info("num_flushed {}".format(num_flushed)) + logger.info("Monitor: num_flushed {}".format(num_flushed)) # This flushes event log and log files. 
ray.experimental.flush_redis_unsafe(self.redis) diff --git a/test/autoscaler_test.py b/test/autoscaler_test.py index fef5b839d..6793898f2 100644 --- a/test/autoscaler_test.py +++ b/test/autoscaler_test.py @@ -17,7 +17,6 @@ from ray.autoscaler.autoscaler import StandardAutoscaler, LoadMetrics, \ fillout_defaults, validate_config from ray.autoscaler.tags import TAG_RAY_NODE_TYPE, TAG_RAY_NODE_STATUS from ray.autoscaler.node_provider import NODE_PROVIDERS, NodeProvider -from ray.autoscaler.updater import NodeUpdaterThread import pytest @@ -166,9 +165,9 @@ class LoadMetricsTest(unittest.TestCase): lm.update("1.1.1.1", {"CPU": 2}, {"CPU": 0}) lm.update("2.2.2.2", {"CPU": 2, "GPU": 16}, {"CPU": 2, "GPU": 2}) debug = lm.info_string() - assert "ResourceUsage: 2.0/4.0 CPU, 14.0/16.0 GPU" in debug - assert "NumNodesConnected: 2" in debug - assert "NumNodesUsed: 1.88" in debug + assert "ResourceUsage=2.0/4.0 CPU, 14.0/16.0 GPU" in debug + assert "NumNodesConnected=2" in debug + assert "NumNodesUsed=1.88" in debug class AutoscalingTest(unittest.TestCase): @@ -528,18 +527,12 @@ class AutoscalingTest(unittest.TestCase): LoadMetrics(), max_failures=0, process_runner=runner, - verbose_updates=True, - node_updater_cls=NodeUpdaterThread, update_interval_s=0) autoscaler.update() autoscaler.update() self.waitForNodes(2) for node in self.provider.mock_nodes.values(): node.state = "running" - assert len( - self.provider.nodes({ - TAG_RAY_NODE_STATUS: "uninitialized" - })) == 2 autoscaler.update() self.waitForNodes(2, tag_filters={TAG_RAY_NODE_STATUS: "up-to-date"}) @@ -552,18 +545,12 @@ class AutoscalingTest(unittest.TestCase): LoadMetrics(), max_failures=0, process_runner=runner, - verbose_updates=True, - node_updater_cls=NodeUpdaterThread, update_interval_s=0) autoscaler.update() autoscaler.update() self.waitForNodes(2) for node in self.provider.mock_nodes.values(): node.state = "running" - assert len( - self.provider.nodes({ - TAG_RAY_NODE_STATUS: "uninitialized" - })) == 2 
autoscaler.update() self.waitForNodes( 2, tag_filters={TAG_RAY_NODE_STATUS: "update-failed"}) @@ -577,8 +564,6 @@ class AutoscalingTest(unittest.TestCase): LoadMetrics(), max_failures=0, process_runner=runner, - verbose_updates=True, - node_updater_cls=NodeUpdaterThread, update_interval_s=0) autoscaler.update() autoscaler.update() @@ -686,8 +671,6 @@ class AutoscalingTest(unittest.TestCase): lm, max_failures=0, process_runner=runner, - verbose_updates=True, - node_updater_cls=NodeUpdaterThread, update_interval_s=0) autoscaler.update() self.waitForNodes(2)