[autoscaler] Speedups (#3720)

- NodeUpdater gets its IP in parallel now (no longer in __init__)
- We use persistent SSH connections (ControlMaster; temp folder created only for Ray)
- hash_runtime_conf was performing a pointless hexlify step, wasting time on large files
- We use NodeUpdaterThreads and share the NodeProvider; NodeUpdaterProcess is removed
- AWSNodeProvider caches nodes more aggressively
- NodeProvider now has a shim batch terminate_nodes() call; AWSNodeProvider parallelises it; the autoscaler uses it
- AWSNodeProvider batches EC2 update_tags calls
- Logging changes throughout to provide standardised timing information for profiling
- Pulled out a few unnecessary is_running calls (NodeUpdater will loop waiting for SSH anyway)

## Related issue number
Issue #3599
This commit is contained in:
Daniel Edgecumbe
2019-02-01 10:46:32 +00:00
committed by Richard Liaw
parent ff3c6af1d6
commit 315edab085
13 changed files with 545 additions and 344 deletions
+97 -81
View File
@@ -2,21 +2,19 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import binascii
import copy
import json
import hashlib
import logging
import math
import os
from six import string_types
from six.moves import queue
import subprocess
import threading
import logging
import time
from collections import defaultdict
from datetime import datetime
import numpy as np
import yaml
@@ -26,11 +24,12 @@ from ray.ray_constants import AUTOSCALER_MAX_NUM_FAILURES, \
AUTOSCALER_UPDATE_INTERVAL_S, AUTOSCALER_HEARTBEAT_TIMEOUT_S
from ray.autoscaler.node_provider import get_node_provider, \
get_default_config
from ray.autoscaler.updater import NodeUpdaterProcess
from ray.autoscaler.updater import NodeUpdaterThread
from ray.autoscaler.docker import dockerize_if_needed
from ray.autoscaler.tags import (TAG_RAY_LAUNCH_CONFIG, TAG_RAY_RUNTIME_CONFIG,
TAG_RAY_NODE_STATUS, TAG_RAY_NODE_TYPE,
TAG_RAY_NODE_NAME)
import ray.services as services
logger = logging.getLogger(__name__)
@@ -158,11 +157,13 @@ class LoadMetrics(object):
def prune(mapping):
unwanted = set(mapping) - active_ips
for unwanted_key in unwanted:
logger.info("Removed mapping: {} - {}".format(
unwanted_key, mapping[unwanted_key]))
logger.info("LoadMetrics: "
"Removed mapping: {} - {}".format(
unwanted_key, mapping[unwanted_key]))
del mapping[unwanted_key]
if unwanted:
logger.info(
"LoadMetrics: "
"Removed {} stale ip mappings: {} not in {}".format(
len(unwanted), unwanted, active_ips))
@@ -174,8 +175,8 @@ class LoadMetrics(object):
return self._info()["NumNodesUsed"]
def info_string(self):
return " - {}".format("\n - ".join(
["{}: {}".format(k, v) for k, v in sorted(self._info().items())]))
return ", ".join(
["{}={}".format(k, v) for k, v in sorted(self._info().items())])
def _info(self):
nodes_used = 0.0
@@ -223,17 +224,13 @@ class LoadMetrics(object):
class NodeLauncher(threading.Thread):
def __init__(self, queue, pending, *args, **kwargs):
def __init__(self, provider, queue, pending, *args, **kwargs):
self.queue = queue
self.pending = pending
self.provider = None
self.provider = provider
super(NodeLauncher, self).__init__(*args, **kwargs)
def _launch_node(self, config, count):
if self.provider is None:
self.provider = get_node_provider(config["provider"],
config["cluster_name"])
tag_filters = {TAG_RAY_NODE_TYPE: "worker"}
before = self.provider.nodes(tag_filters=tag_filters)
launch_hash = hash_launch_conf(config["worker_nodes"], config["auth"])
@@ -247,7 +244,8 @@ class NodeLauncher(threading.Thread):
}, count)
after = self.provider.nodes(tag_filters=tag_filters)
if set(after).issubset(before):
logger.error("No new nodes reported after node creation")
logger.error("NodeLauncher: "
"No new nodes reported after node creation")
def run(self):
while True:
@@ -305,8 +303,6 @@ class StandardAutoscaler(object):
max_concurrent_launches=AUTOSCALER_MAX_CONCURRENT_LAUNCHES,
max_failures=AUTOSCALER_MAX_NUM_FAILURES,
process_runner=subprocess,
verbose_updates=True,
node_updater_cls=NodeUpdaterProcess,
update_interval_s=AUTOSCALER_UPDATE_INTERVAL_S):
self.config_path = config_path
self.reload_config(errors_fatal=True)
@@ -317,9 +313,7 @@ class StandardAutoscaler(object):
self.max_failures = max_failures
self.max_launch_batch = max_launch_batch
self.max_concurrent_launches = max_concurrent_launches
self.verbose_updates = verbose_updates
self.process_runner = process_runner
self.node_updater_cls = node_updater_cls
# Map from node_id to NodeUpdater processes
self.updaters = {}
@@ -337,7 +331,9 @@ class StandardAutoscaler(object):
max_concurrent_launches / float(max_launch_batch))
for i in range(int(max_batches)):
node_launcher = NodeLauncher(
queue=self.launch_queue, pending=self.num_launches_pending)
provider=self.provider,
queue=self.launch_queue,
pending=self.num_launches_pending)
node_launcher.daemon = True
node_launcher.start()
@@ -359,11 +355,12 @@ class StandardAutoscaler(object):
self.reload_config(errors_fatal=False)
self._update()
except Exception as e:
logger.exception("Error during autoscaling.")
logger.exception("StandardAutoscaler: "
"Error during autoscaling.")
self.num_failures += 1
if self.num_failures > self.max_failures:
logger.critical(
"*** StandardAutoscaler: Too many errors, abort. ***")
logger.critical("StandardAutoscaler: "
"Too many errors, abort.")
raise e
def _update(self):
@@ -377,7 +374,7 @@ class StandardAutoscaler(object):
self.last_update_time = now
num_pending = self.num_launches_pending.value
nodes = self.workers()
logger.info(self.info_string(nodes))
self.log_info_string(nodes)
self.load_metrics.prune_active_ips(
[self.provider.internal_ip(node_id) for node_id in nodes])
target_workers = self.target_num_workers()
@@ -385,35 +382,37 @@ class StandardAutoscaler(object):
# Terminate any idle or out of date nodes
last_used = self.load_metrics.last_used_time_by_ip
horizon = now - (60 * self.config["idle_timeout_minutes"])
num_terminated = 0
nodes_to_terminate = []
for node_id in nodes:
node_ip = self.provider.internal_ip(node_id)
if node_ip in last_used and last_used[node_ip] < horizon and \
len(nodes) - num_terminated > target_workers:
num_terminated += 1
logger.info("StandardAutoscaler: Terminating idle node: "
"{}".format(node_id))
self.provider.terminate_node(node_id)
len(nodes) - len(nodes_to_terminate) > target_workers:
logger.info("StandardAutoscaler: "
"{}: Terminating idle node".format(node_id))
nodes_to_terminate.append(node_id)
elif not self.launch_config_ok(node_id):
num_terminated += 1
logger.info("StandardAutoscaler: Terminating outdated node: "
"{}".format(node_id))
self.provider.terminate_node(node_id)
if num_terminated > 0:
logger.info("StandardAutoscaler: "
"{}: Terminating outdated node".format(node_id))
nodes_to_terminate.append(node_id)
if nodes_to_terminate:
self.provider.terminate_nodes(nodes_to_terminate)
nodes = self.workers()
logger.info(self.info_string(nodes))
self.log_info_string(nodes)
# Terminate nodes if there are too many
num_terminated = 0
nodes_to_terminate = []
while len(nodes) > self.config["max_workers"]:
num_terminated += 1
logger.info("StandardAutoscaler: Terminating unneeded node: "
"{}".format(nodes[-1]))
self.provider.terminate_node(nodes[-1])
logger.info("StandardAutoscaler: "
"{}: Terminating unneeded node".format(nodes[-1]))
nodes_to_terminate.append(nodes[-1])
nodes = nodes[:-1]
if num_terminated > 0:
if nodes_to_terminate:
self.provider.terminate_nodes(nodes_to_terminate)
nodes = self.workers()
logger.info(self.info_string(nodes))
self.log_info_string(nodes)
# Launch new nodes if needed
num_workers = len(nodes) + num_pending
@@ -422,7 +421,8 @@ class StandardAutoscaler(object):
self.max_concurrent_launches - num_pending)
num_launches = min(max_allowed, target_workers - num_workers)
self.launch_new_node(num_launches)
logger.info(self.info_string())
nodes = self.workers()
self.log_info_string(nodes)
else:
self.bringup = False
@@ -442,11 +442,21 @@ class StandardAutoscaler(object):
# immediately trying to restart Ray on the new node.
self.load_metrics.mark_active(self.provider.internal_ip(node_id))
nodes = self.workers()
logger.info(self.info_string(nodes))
self.log_info_string(nodes)
# Update nodes with out-of-date files
for node_id in nodes:
self.update_if_needed(node_id)
T = [
threading.Thread(
target=self.spawn_updater,
args=(node_id, commands),
) for node_id, commands in (self.should_update(node_id)
for node_id in nodes)
if node_id is not None
]
for t in T:
t.start()
for t in T:
t.join()
# Attempt to recover unhealthy nodes
for node_id in nodes:
@@ -471,7 +481,8 @@ class StandardAutoscaler(object):
if errors_fatal:
raise e
else:
logger.exception("StandardAutoscaler: Error parsing config.")
logger.exception("StandardAutoscaler: "
"Error parsing config.")
def target_num_workers(self):
initial_workers = self.config["initial_workers"]
@@ -496,9 +507,9 @@ class StandardAutoscaler(object):
def files_up_to_date(self, node_id):
applied = self.provider.node_tags(node_id).get(TAG_RAY_RUNTIME_CONFIG)
if applied != self.runtime_hash:
logger.info(
"StandardAutoscaler: {} has runtime state {}, want {}".format(
node_id, applied, self.runtime_hash))
logger.info("StandardAutoscaler: "
"{}: Runtime state is {}, want {}".format(
node_id, applied, self.runtime_hash))
return False
return True
@@ -512,27 +523,29 @@ class StandardAutoscaler(object):
delta = now - last_heartbeat_time
if delta < AUTOSCALER_HEARTBEAT_TIMEOUT_S:
return
logger.warning("StandardAutoscaler: No heartbeat from node "
"{} in {} seconds, restarting Ray to recover...".format(
node_id, delta))
updater = self.node_updater_cls(
logger.warning("StandardAutoscaler: "
"{}: No heartbeat in {}s, "
"restarting Ray to recover...".format(node_id, delta))
updater = NodeUpdaterThread(
node_id,
self.config["provider"],
self.provider,
self.config["auth"],
self.config["cluster_name"], {},
with_head_node_ip(self.config["worker_start_ray_commands"]),
self.runtime_hash,
redirect_output=not self.verbose_updates,
process_runner=self.process_runner,
use_internal_ip=True)
updater.start()
self.updaters[node_id] = updater
def update_if_needed(self, node_id):
def should_update(self, node_id):
if not self.can_update(node_id):
return
return (None, None)
if self.files_up_to_date(node_id):
return
return (None, None)
successful_updated = self.num_successful_updates.get(node_id, 0) > 0
if successful_updated and self.config.get("restart_only", False):
init_commands = self.config["worker_start_ray_commands"]
@@ -544,33 +557,35 @@ class StandardAutoscaler(object):
self.config["worker_setup_commands"] +
self.config["worker_start_ray_commands"])
updater = self.node_updater_cls(
return (node_id, init_commands)
def spawn_updater(self, node_id, init_commands):
updater = NodeUpdaterThread(
node_id,
self.config["provider"],
self.provider,
self.config["auth"],
self.config["cluster_name"],
self.config["file_mounts"],
with_head_node_ip(init_commands),
self.runtime_hash,
redirect_output=not self.verbose_updates,
process_runner=self.process_runner,
use_internal_ip=True)
updater.start()
self.updaters[node_id] = updater
def can_update(self, node_id):
if not self.provider.is_running(node_id):
if node_id in self.updaters:
return False
if not self.launch_config_ok(node_id):
return False
if node_id in self.updaters:
return False
if self.num_failed_updates.get(node_id, 0) > 0: # TODO(ekl) retry?
return False
return True
def launch_new_node(self, count):
logger.info("StandardAutoscaler: Launching {} new nodes".format(count))
logger.info("StandardAutoscaler: "
"Launching {} new nodes".format(count))
self.num_launches_pending.inc(count)
config = copy.deepcopy(self.config)
self.launch_queue.put((config, count))
@@ -578,9 +593,11 @@ class StandardAutoscaler(object):
def workers(self):
return self.provider.nodes(tag_filters={TAG_RAY_NODE_TYPE: "worker"})
def info_string(self, nodes=None):
if nodes is None:
nodes = self.workers()
def log_info_string(self, nodes):
logger.info("StandardAutoscaler: {}".format(self.info_string(nodes)))
logger.info("LoadMetrics: {}".format(self.load_metrics.info_string()))
def info_string(self, nodes):
suffix = ""
if self.num_launches_pending:
suffix += " ({} pending)".format(self.num_launches_pending.value)
@@ -591,9 +608,9 @@ class StandardAutoscaler(object):
len(self.num_failed_updates))
if self.bringup:
suffix += " (bringup=True)"
return "StandardAutoscaler [{}]: {}/{} target nodes{}\n{}".format(
datetime.now(), len(nodes), self.target_num_workers(), suffix,
self.load_metrics.info_string())
return "{}/{} target nodes{}".format(
len(nodes), self.target_num_workers(), suffix)
def typename(v):
@@ -685,6 +702,11 @@ def hash_runtime_conf(file_mounts, extra_objs):
hasher = hashlib.sha1()
def add_content_hashes(path):
def add_hash_of_file(fpath):
with open(fpath, "rb") as f:
for chunk in iter(lambda: f.read(2**20), b''):
hasher.update(chunk)
path = os.path.expanduser(path)
if os.path.isdir(path):
dirs = []
@@ -694,18 +716,12 @@ def hash_runtime_conf(file_mounts, extra_objs):
hasher.update(dirpath.encode("utf-8"))
for name in filenames:
hasher.update(name.encode("utf-8"))
with open(os.path.join(dirpath, name), "rb") as f:
if os.path.getsize(os.path.join(dirpath,
name)) < 1000000000:
hasher.update(binascii.hexlify(f.read()))
else:
for chunk in iter(lambda: f.read(8192), b''):
hasher.update(binascii.hexlify(chunk))
fpath = os.path.join(dirpath, name)
add_hash_of_file(fpath)
else:
with open(path, "rb") as f:
hasher.update(binascii.hexlify(f.read()))
add_hash_of_file(path)
conf_str = (json.dumps(sorted(file_mounts.items())).encode("utf-8") +
conf_str = (json.dumps(file_mounts, sort_keys=True).encode("utf-8") +
json.dumps(extra_objs, sort_keys=True).encode("utf-8"))
# Important: only hash the files once. Otherwise, we can end up restarting
+26 -15
View File
@@ -4,9 +4,9 @@ from __future__ import print_function
from distutils.version import StrictVersion
import json
import logging
import os
import time
import logging
import boto3
from botocore.config import Config
@@ -14,6 +14,8 @@ import botocore
from ray.ray_constants import BOTO_MAX_RETRIES
logger = logging.getLogger(__name__)
RAY = "ray-autoscaler"
DEFAULT_RAY_INSTANCE_PROFILE = RAY + "-v1"
DEFAULT_RAY_IAM_ROLE = RAY + "-v1"
@@ -34,7 +36,6 @@ def key_pair(i, region):
# Suppress excessive connection dropped logs from boto
logging.getLogger("botocore").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
def bootstrap_aws(config):
@@ -62,8 +63,9 @@ def _configure_iam_role(config):
profile = _get_instance_profile(DEFAULT_RAY_INSTANCE_PROFILE, config)
if profile is None:
logger.info("Creating new instance profile {}".format(
DEFAULT_RAY_INSTANCE_PROFILE))
logger.info("_configure_iam_role: "
"Creating new instance profile {}".format(
DEFAULT_RAY_INSTANCE_PROFILE))
client = _client("iam", config)
client.create_instance_profile(
InstanceProfileName=DEFAULT_RAY_INSTANCE_PROFILE)
@@ -75,7 +77,8 @@ def _configure_iam_role(config):
if not profile.roles:
role = _get_role(DEFAULT_RAY_IAM_ROLE, config)
if role is None:
logger.info("Creating new role {}".format(DEFAULT_RAY_IAM_ROLE))
logger.info("_configure_iam_role: "
"Creating new role {}".format(DEFAULT_RAY_IAM_ROLE))
iam = _resource("iam", config)
iam.create_role(
RoleName=DEFAULT_RAY_IAM_ROLE,
@@ -99,8 +102,9 @@ def _configure_iam_role(config):
profile.add_role(RoleName=role.name)
time.sleep(15) # wait for propagation
logger.info("Role not specified for head node, using {}".format(
profile.arn))
logger.info("_configure_iam_role: "
"Role not specified for head node, using {}".format(
profile.arn))
config["head_node"]["IamInstanceProfile"] = {"Arn": profile.arn}
return config
@@ -126,7 +130,8 @@ def _configure_key_pair(config):
# We can safely create a new key.
if not key and not os.path.exists(key_path):
logger.info("Creating new key pair {}".format(key_name))
logger.info("_configure_key_pair: "
"Creating new key pair {}".format(key_name))
key = ec2.create_key_pair(KeyName=key_name)
with open(key_path, "w") as f:
f.write(key.key_material)
@@ -142,7 +147,8 @@ def _configure_key_pair(config):
assert os.path.exists(key_path), \
"Private key file {} not found for {}".format(key_path, key_name)
logger.info("KeyName not specified for nodes, using {}".format(key_name))
logger.info("_configure_key_pair: "
"KeyName not specified for nodes, using {}".format(key_name))
config["auth"]["ssh_private_key"] = key_path
config["head_node"]["KeyName"] = key_name
@@ -174,19 +180,21 @@ def _configure_subnet(config):
"No usable subnets matching availability zone {} "
"found. Choose a different availability zone or try "
"manually creating an instance in your specified region "
"to populate the list of subnets and trying this again."
.format(config["provider"]["availability_zone"]))
"to populate the list of subnets and trying this again.".
format(config["provider"]["availability_zone"]))
subnet_ids = [s.subnet_id for s in subnets]
subnet_descr = [(s.subnet_id, s.availability_zone) for s in subnets]
if "SubnetIds" not in config["head_node"]:
config["head_node"]["SubnetIds"] = subnet_ids
logger.info("SubnetIds not specified for head node,"
" using {}".format(subnet_descr))
logger.info("_configure_subnet: "
"SubnetIds not specified for head node, using {}".format(
subnet_descr))
if "SubnetIds" not in config["worker_nodes"]:
config["worker_nodes"]["SubnetIds"] = subnet_ids
logger.info("SubnetId not specified for workers,"
logger.info("_configure_subnet: "
"SubnetId not specified for workers,"
" using {}".format(subnet_descr))
return config
@@ -202,7 +210,8 @@ def _configure_security_group(config):
security_group = _get_security_group(config, vpc_id, group_name)
if security_group is None:
logger.info("Creating new security group {}".format(group_name))
logger.info("_configure_security_group: "
"Creating new security group {}".format(group_name))
client = _client("ec2", config)
client.create_security_group(
Description="Auto-created security group for Ray workers",
@@ -230,12 +239,14 @@ def _configure_security_group(config):
if "SecurityGroupIds" not in config["head_node"]:
logger.info(
"_configure_security_group: "
"SecurityGroupIds not specified for head node, using {}".format(
security_group.group_name))
config["head_node"]["SecurityGroupIds"] = [security_group.id]
if "SecurityGroupIds" not in config["worker_nodes"]:
logger.info(
"_configure_security_group: "
"SecurityGroupIds not specified for workers, using {}".format(
security_group.group_name))
config["worker_nodes"]["SecurityGroupIds"] = [security_group.id]
+96 -39
View File
@@ -3,6 +3,8 @@ from __future__ import division
from __future__ import print_function
import random
import threading
from collections import defaultdict
import boto3
from botocore.config import Config
@@ -10,6 +12,7 @@ from botocore.config import Config
from ray.autoscaler.node_provider import NodeProvider
from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME
from ray.ray_constants import BOTO_MAX_RETRIES
from ray.autoscaler.log_timer import LogTimer
def to_aws_format(tags):
@@ -40,15 +43,59 @@ class AWSNodeProvider(NodeProvider):
# Try availability zones round-robin, starting from random offset
self.subnet_idx = random.randint(0, 100)
self.tag_cache = {} # Tags that we believe to actually be on EC2.
self.tag_cache_pending = {} # Tags that we will soon upload.
self.tag_cache_lock = threading.Lock()
self.tag_cache_update_event = threading.Event()
self.tag_cache_kill_event = threading.Event()
self.tag_update_thread = threading.Thread(
target=self._node_tag_update_loop)
self.tag_update_thread.start()
# Cache of node objects from the last nodes() call. This avoids
# excessive DescribeInstances requests.
self.cached_nodes = {}
# Cache of ip lookups. We assume IPs never change once assigned.
self.internal_ip_cache = {}
self.external_ip_cache = {}
def _node_tag_update_loop(self):
""" Update the AWS tags for a cluster periodically.
The purpose of this loop is to avoid excessive EC2 calls when a large
number of nodes are being launched simultaneously.
"""
while True:
self.tag_cache_update_event.wait()
self.tag_cache_update_event.clear()
batch_updates = defaultdict(list)
with self.tag_cache_lock:
for node_id, tags in self.tag_cache_pending.items():
for x in tags.items():
batch_updates[x].append(node_id)
self.tag_cache[node_id].update(tags)
self.tag_cache_pending = {}
for (k, v), node_ids in batch_updates.items():
m = "Set tag {}={} on {}".format(k, v, node_ids)
with LogTimer("AWSNodeProvider: {}".format(m)):
if k == TAG_RAY_NODE_NAME:
k = "Name"
self.ec2.meta.client.create_tags(
Resources=node_ids,
Tags=[{
"Key": k,
"Value": v
}],
)
self.tag_cache_kill_event.wait(timeout=5)
if self.tag_cache_kill_event.is_set():
return
def nodes(self, tag_filters):
# Note that these filters are acceptable because they are set on
# node initialization, and so can never be sitting in the cache.
tag_filters = to_aws_format(tag_filters)
filters = [
{
@@ -65,9 +112,19 @@ class AWSNodeProvider(NodeProvider):
"Name": "tag:{}".format(k),
"Values": [v],
})
instances = list(self.ec2.instances.filter(Filters=filters))
self.cached_nodes = {i.id: i for i in instances}
return [i.id for i in instances]
nodes = list(self.ec2.instances.filter(Filters=filters))
# Populate the tag cache with initial information if necessary
for node in nodes:
if node.id in self.tag_cache:
continue
self.tag_cache[node.id] = from_aws_format(
{x["Key"]: x["Value"]
for x in node.tags})
self.cached_nodes = {node.id: node for node in nodes}
return [node.id for node in nodes]
def is_running(self, node_id):
node = self._node(node_id)
@@ -79,40 +136,25 @@ class AWSNodeProvider(NodeProvider):
return state not in ["running", "pending"]
def node_tags(self, node_id):
node = self._node(node_id)
tags = {}
for tag in node.tags:
tags[tag["Key"]] = tag["Value"]
return from_aws_format(tags)
with self.tag_cache_lock:
d1 = self.tag_cache[node_id]
d2 = self.tag_cache_pending.get(node_id, {})
return dict(d1, **d2)
def external_ip(self, node_id):
if node_id in self.external_ip_cache:
return self.external_ip_cache[node_id]
node = self._node(node_id)
ip = node.public_ip_address
if ip:
self.external_ip_cache[node_id] = ip
return ip
return self._node(node_id).public_ip_address
def internal_ip(self, node_id):
if node_id in self.internal_ip_cache:
return self.internal_ip_cache[node_id]
node = self._node(node_id)
ip = node.private_ip_address
if ip:
self.internal_ip_cache[node_id] = ip
return ip
return self._node(node_id).private_ip_address
def set_node_tags(self, node_id, tags):
tags = to_aws_format(tags)
node = self._node(node_id)
tag_pairs = []
for k, v in tags.items():
tag_pairs.append({
"Key": k,
"Value": v,
})
node.create_tags(Tags=tag_pairs)
with self.tag_cache_lock:
try:
self.tag_cache_pending[node_id].update(tags)
except KeyError:
self.tag_cache_pending[node_id] = tags
self.tag_cache_update_event.set()
def create_node(self, node_config, tags, count):
tags = to_aws_format(tags)
@@ -166,9 +208,24 @@ class AWSNodeProvider(NodeProvider):
node = self._node(node_id)
node.terminate()
self.tag_cache.pop(node_id, None)
self.tag_cache_pending.pop(node_id, None)
def terminate_nodes(self, node_ids):
self.ec2.meta.client.terminate_instances(InstanceIds=node_ids)
for node_id in node_ids:
self.tag_cache.pop(node_id, None)
self.tag_cache_pending.pop(node_id, None)
def _node(self, node_id):
if node_id in self.cached_nodes:
return self.cached_nodes[node_id]
matches = list(self.ec2.instances.filter(InstanceIds=[node_id]))
assert len(matches) == 1, "Invalid instance id {}".format(node_id)
return matches[0]
if node_id not in self.cached_nodes:
self.nodes({}) # Side effect: should cache it.
assert node_id in self.cached_nodes, "Invalid instance id {}".format(
node_id)
return self.cached_nodes[node_id]
def cleanup(self):
self.tag_cache_update_event.set()
self.tag_cache_kill_event.set()
+81 -39
View File
@@ -8,9 +8,9 @@ import json
import os
import tempfile
import time
import logging
import sys
import click
import logging
import random
import yaml
@@ -24,7 +24,8 @@ from ray.autoscaler.autoscaler import validate_config, hash_runtime_conf, \
from ray.autoscaler.node_provider import get_node_provider, NODE_PROVIDERS
from ray.autoscaler.tags import TAG_RAY_NODE_TYPE, TAG_RAY_LAUNCH_CONFIG, \
TAG_RAY_NODE_NAME
from ray.autoscaler.updater import NodeUpdaterProcess
from ray.autoscaler.updater import NodeUpdaterThread
from ray.autoscaler.log_timer import LogTimer
logger = logging.getLogger(__name__)
@@ -81,18 +82,35 @@ def teardown_cluster(config_file, yes, workers_only, override_cluster_name):
provider = get_node_provider(config["provider"], config["cluster_name"])
if not workers_only:
for node in provider.nodes({TAG_RAY_NODE_TYPE: "head"}):
logger.info("Terminating head node {}".format(node))
provider.terminate_node(node)
def remaining_nodes():
if workers_only:
A = []
else:
A = [
node_id for node_id in provider.nodes({
TAG_RAY_NODE_TYPE: "head"
})
]
nodes = provider.nodes({TAG_RAY_NODE_TYPE: "worker"})
while nodes:
for node in nodes:
logger.info("Terminating worker {}".format(node))
provider.terminate_node(node)
time.sleep(5)
nodes = provider.nodes({TAG_RAY_NODE_TYPE: "worker"})
A += [
node_id for node_id in provider.nodes({
TAG_RAY_NODE_TYPE: "worker"
})
]
return A
# Loop here to check that both the head and worker nodes are actually
# really gone
A = remaining_nodes()
with LogTimer("teardown_cluster: Termination done."):
while A:
logger.info("teardown_cluster: "
"Terminating {} nodes...".format(len(A)))
provider.terminate_nodes(A)
time.sleep(1)
A = remaining_nodes()
provider.cleanup()
def kill_node(config_file, yes, override_cluster_name):
@@ -108,26 +126,26 @@ def kill_node(config_file, yes, override_cluster_name):
provider = get_node_provider(config["provider"], config["cluster_name"])
nodes = provider.nodes({TAG_RAY_NODE_TYPE: "worker"})
node = random.choice(nodes)
logger.info("Terminating worker {}".format(node))
updater = NodeUpdaterProcess(
node,
config["provider"],
config["auth"],
config["cluster_name"],
config["file_mounts"], [],
"",
redirect_output=False)
logger.info("kill_node: Terminating worker {}".format(node))
updater = NodeUpdaterThread(node, config["provider"], provider,
config["auth"], config["cluster_name"],
config["file_mounts"], [], "")
_exec(updater, "ray stop", False, False)
time.sleep(5)
iip = provider.internal_ip(node)
if iip:
return iip
return provider.external_ip(node)
def get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
override_cluster_name):
"""Create the cluster head node, which in turn creates the workers."""
provider = get_node_provider(config["provider"], config["cluster_name"])
head_node_tags = {
TAG_RAY_NODE_TYPE: "head",
@@ -148,9 +166,10 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
TAG_RAY_LAUNCH_CONFIG) != launch_hash:
if head_node is not None:
confirm("Head node config out-of-date. It will be terminated", yes)
logger.info("Terminating outdated head node {}".format(head_node))
logger.info("get_or_create_head_node: "
"Terminating outdated head node {}".format(head_node))
provider.terminate_node(head_node)
logger.info("Launching new head node...")
logger.info("get_or_create_head_node: Launching new head node...")
head_node_tags[TAG_RAY_LAUNCH_CONFIG] = launch_hash
head_node_tags[TAG_RAY_NODE_NAME] = "ray-{}-head".format(
config["cluster_name"])
@@ -163,7 +182,7 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
# TODO(ekl) right now we always update the head node even if the hash
# matches. We could prompt the user for what they want to do in this case.
runtime_hash = hash_runtime_conf(config["file_mounts"], config)
logger.info("Updating files on head node...")
logger.info("get_or_create_head_node: Updating files on head node...")
# Rewrite the auth config so that the head node can update the workers
remote_key_path = "~/ray_bootstrap_key.pem"
@@ -197,15 +216,16 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
config["setup_commands"] + config["head_setup_commands"] +
config["head_start_ray_commands"])
updater = NodeUpdaterProcess(
updater = NodeUpdaterThread(
head_node,
config["provider"],
provider,
config["auth"],
config["cluster_name"],
config["file_mounts"],
init_commands,
runtime_hash,
redirect_output=False)
)
updater.start()
updater.join()
@@ -213,11 +233,13 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
provider.nodes(head_node_tags)
if updater.exitcode != 0:
logger.error("Updating {} failed".format(
provider.external_ip(head_node)))
logger.error("get_or_create_head_node: "
"Updating {} failed".format(
provider.external_ip(head_node)))
sys.exit(1)
logger.info("Head node up-to-date, IP address is: {}".format(
provider.external_ip(head_node)))
logger.info("get_or_create_head_node: "
"Head node up-to-date, IP address is: {}".format(
provider.external_ip(head_node)))
monitor_str = "tail -n 100 -f /tmp/ray/session_*/logs/monitor*"
for s in init_commands:
@@ -239,6 +261,8 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
config["auth"]["ssh_user"],
provider.external_ip(head_node)))
provider.cleanup()
def attach_cluster(config_file, start, use_tmux, override_cluster_name, new):
"""Attaches to a screen for the specified cluster.
@@ -288,14 +312,18 @@ def exec_cluster(config_file, cmd, screen, tmux, stop, start,
config = _bootstrap_config(config)
head_node = _get_head_node(
config, config_file, override_cluster_name, create_if_needed=start)
updater = NodeUpdaterProcess(
provider = get_node_provider(config["provider"], config["cluster_name"])
updater = NodeUpdaterThread(
head_node,
config["provider"],
provider,
config["auth"],
config["cluster_name"],
config["file_mounts"], [],
config["file_mounts"],
[],
"",
redirect_output=False)
)
if stop:
cmd += ("; ray stop; ray teardown ~/ray_bootstrap_config.yaml --yes "
"--workers-only; sudo shutdown -h now")
@@ -322,6 +350,8 @@ def exec_cluster(config_file, cmd, screen, tmux, stop, start,
attach_command)
logger.info(attach_info)
provider.cleanup()
def _exec(updater, cmd, screen, tmux, expect_error=False, port_forward=None):
if cmd:
@@ -363,20 +393,26 @@ def rsync(config_file, source, target, override_cluster_name, down):
config = _bootstrap_config(config)
head_node = _get_head_node(
config, config_file, override_cluster_name, create_if_needed=False)
updater = NodeUpdaterProcess(
provider = get_node_provider(config["provider"], config["cluster_name"])
updater = NodeUpdaterThread(
head_node,
config["provider"],
provider,
config["auth"],
config["cluster_name"],
config["file_mounts"], [],
config["file_mounts"],
[],
"",
redirect_output=False)
)
if down:
rsync = updater.rsync_down
else:
rsync = updater.rsync_up
rsync(source, target, check_error=False)
provider.cleanup()
def get_head_node_ip(config_file, override_cluster_name):
"""Returns head node IP for given configuration file if exists."""
@@ -384,9 +420,13 @@ def get_head_node_ip(config_file, override_cluster_name):
config = yaml.load(open(config_file).read())
if override_cluster_name is not None:
config["cluster_name"] = override_cluster_name
provider = get_node_provider(config["provider"], config["cluster_name"])
head_node = _get_head_node(config, config_file, override_cluster_name)
return provider.external_ip(head_node)
ip = provider.external_ip(head_node)
provider.cleanup()
return ip
def get_worker_node_ips(config_file, override_cluster_name):
@@ -409,6 +449,8 @@ def _get_head_node(config,
TAG_RAY_NODE_TYPE: "head",
}
nodes = provider.nodes(head_node_tags)
provider.cleanup()
if len(nodes) > 0:
head_node = nodes[0]
return head_node
+2 -1
View File
@@ -2,8 +2,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import os
import logging
try: # py3
from shlex import quote
except ImportError: # py2
@@ -20,6 +20,7 @@ def dockerize_if_needed(config):
if not docker_image:
if cname:
logger.warning(
"dockerize_if_needed: "
"Container name given but no Docker image - continuing...")
return config
else:
+18 -12
View File
@@ -11,6 +11,8 @@ from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.backends import default_backend
from googleapiclient import discovery, errors
logger = logging.getLogger(__name__)
crm = discovery.build("cloudresourcemanager", "v1")
iam = discovery.build("iam", "v1")
compute = discovery.build("compute", "v1")
@@ -30,12 +32,11 @@ DEFAULT_SERVICE_ACCOUNT_ROLES = ("roles/storage.objectAdmin",
MAX_POLLS = 12
POLL_INTERVAL = 5
logger = logging.getLogger(__name__)
def wait_for_crm_operation(operation):
"""Poll for cloud resource manager operation until finished."""
logger.info("Waiting for operation {} to finish...".format(operation))
logger.info("wait_for_crm_operation: "
"Waiting for operation {} to finish...".format(operation))
for _ in range(MAX_POLLS):
result = crm.operations().get(name=operation["name"]).execute()
@@ -43,7 +44,7 @@ def wait_for_crm_operation(operation):
raise Exception(result["error"])
if "done" in result and result["done"]:
logger.info("Done.")
logger.info("wait_for_crm_operation: Operation done.")
break
time.sleep(POLL_INTERVAL)
@@ -53,8 +54,9 @@ def wait_for_crm_operation(operation):
def wait_for_compute_global_operation(project_name, operation):
"""Poll for global compute operation until finished."""
logger.info("Waiting for operation {} to finish...".format(
operation["name"]))
logger.info("wait_for_compute_global_operation: "
"Waiting for operation {} to finish...".format(
operation["name"]))
for _ in range(MAX_POLLS):
result = compute.globalOperations().get(
@@ -65,7 +67,8 @@ def wait_for_compute_global_operation(project_name, operation):
raise Exception(result["error"])
if result["status"] == "DONE":
logger.info("Done.")
logger.info("wait_for_compute_global_operation: "
"Operation done.")
break
time.sleep(POLL_INTERVAL)
@@ -158,8 +161,9 @@ def _configure_iam_role(config):
service_account = _get_service_account(email, config)
if service_account is None:
logger.info("Creating new service account {}".format(
DEFAULT_SERVICE_ACCOUNT_ID))
logger.info("_configure_iam_role: "
"Creating new service account {}".format(
DEFAULT_SERVICE_ACCOUNT_ID))
service_account = _create_service_account(
DEFAULT_SERVICE_ACCOUNT_ID, DEFAULT_SERVICE_ACCOUNT_CONFIG, config)
@@ -231,7 +235,8 @@ def _configure_key_pair(config):
# Create a key since it doesn't exist locally or in GCP
if not key_found and not os.path.exists(private_key_path):
logger.info("Creating new key pair {}".format(key_name))
logger.info("_configure_key_pair: "
"Creating new key pair {}".format(key_name))
public_key, private_key = generate_rsa_key_pair()
_create_project_ssh_key_pair(project, public_key, ssh_user)
@@ -256,8 +261,9 @@ def _configure_key_pair(config):
"Private key file {} not found for user {}"
"".format(private_key_path, ssh_user))
logger.info("Private key not specified in config, using {}"
"".format(private_key_path))
# The literals below are concatenated before .format() runs, so "using"
# needs its own trailing space; without it the key path is glued onto the
# preceding word in the log output.
logger.info("_configure_key_pair: "
            "Private key not specified in config, using "
            "{}".format(private_key_path))
config["auth"]["ssh_private_key"] = private_key_path
+8 -6
View File
@@ -4,24 +4,25 @@ from __future__ import print_function
from uuid import uuid4
import time
import logging
from googleapiclient import discovery
from ray.autoscaler.node_provider import NodeProvider
from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME
from ray.autoscaler.gcp.config import MAX_POLLS, POLL_INTERVAL
logger = logging.getLogger(__name__)
INSTANCE_NAME_MAX_LEN = 64
INSTANCE_NAME_UUID_LEN = 8
logger = logging.getLogger(__name__)
def wait_for_compute_zone_operation(compute, project_name, operation, zone):
"""Poll for compute zone operation until finished."""
logger.info("Waiting for operation {} to finish...".format(
operation["name"]))
logger.info("wait_for_compute_zone_operation: "
"Waiting for operation {} to finish...".format(
operation["name"]))
for _ in range(MAX_POLLS):
result = compute.zoneOperations().get(
@@ -31,7 +32,8 @@ def wait_for_compute_zone_operation(compute, project_name, operation, zone):
raise Exception(result["error"])
if result["status"] == "DONE":
logger.info("Done.")
logger.info("wait_for_compute_zone_operation: "
"Operation {} finished.".format(operation["name"]))
break
time.sleep(POLL_INTERVAL)
+7 -3
View File
@@ -12,6 +12,7 @@ from ray.autoscaler.node_provider import NodeProvider
from ray.autoscaler.tags import TAG_RAY_NODE_TYPE
logger = logging.getLogger(__name__)
filelock_logger = logging.getLogger("filelock")
filelock_logger.setLevel(logging.WARNING)
@@ -26,7 +27,8 @@ class ClusterState(object):
workers = json.loads(open(self.save_path).read())
else:
workers = {}
logger.info("Loaded cluster state: {}".format(workers))
logger.info("ClusterState: "
"Loaded cluster state: {}".format(workers))
for worker_ip in provider_config["worker_ips"]:
if worker_ip not in workers:
workers[worker_ip] = {
@@ -50,7 +52,8 @@ class ClusterState(object):
TAG_RAY_NODE_TYPE] == "head"
assert len(workers) == len(provider_config["worker_ips"]) + 1
with open(self.save_path, "w") as f:
logger.info("Writing cluster state: {}".format(workers))
logger.info("ClusterState: "
"Writing cluster state: {}".format(workers))
f.write(json.dumps(workers))
def get(self):
@@ -65,7 +68,8 @@ class ClusterState(object):
workers = self.get()
workers[worker_id] = info
with open(self.save_path, "w") as f:
logger.info("Writing cluster state: {}".format(workers))
logger.info("ClusterState: "
"Writing cluster state: {}".format(workers))
f.write(json.dumps(workers))
+21
View File
@@ -0,0 +1,21 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import datetime
import logging
logger = logging.getLogger(__name__)


class LogTimer(object):
    """Context manager that logs a message together with the elapsed time.

    On exit, ``message`` is written at INFO level followed by a
    ``[LogTimer=<n>ms]`` suffix holding the time spent inside the ``with``
    body. If the body raised, `` (failed)`` is inserted before the suffix so
    that a success-worded message (e.g. "Applied config") is not reported
    unqualified for an operation that did not complete. The exception itself
    is never suppressed.

    Args:
        message: Text to log when the block finishes.
    """

    def __init__(self, message):
        self._message = message
        self._start_time = None

    def __enter__(self):
        self._start_time = datetime.datetime.utcnow()
        # Return the timer so ``with LogTimer(...) as t`` binds something.
        return self

    def __exit__(self, *exc_info):
        td = datetime.datetime.utcnow() - self._start_time
        # exc_info is (type, value, traceback); type is None on clean exit.
        failed = bool(exc_info) and exc_info[0] is not None
        status = " (failed)" if failed else ""
        logger.info(self._message + status +
                    " [LogTimer={:.0f}ms]".format(td.total_seconds() * 1000))
+14
View File
@@ -3,9 +3,12 @@ from __future__ import division
from __future__ import print_function
import importlib
import logging
import os
import yaml
logger = logging.getLogger(__name__)
def import_aws():
from ray.autoscaler.aws.config import bootstrap_aws
@@ -174,3 +177,14 @@ class NodeProvider(object):
def terminate_node(self, node_id):
"""Terminates the specified node."""
raise NotImplementedError
def terminate_nodes(self, node_ids):
    """Terminate every node in ``node_ids``, one at a time.

    This default implementation logs each node id and delegates to
    ``terminate_node``; providers may override it with a batch call.
    """
    for nid in node_ids:
        message = "NodeProvider: " "{}: Terminating node".format(nid)
        logger.info(message)
        self.terminate_node(nid)
def cleanup(self):
    """Hook invoked when a Provider is no longer required.

    The base implementation does nothing.
    """
    return None
+153 -116
View File
@@ -10,24 +10,33 @@ import logging
import os
import subprocess
import sys
import tempfile
import time
from multiprocessing import Process
from threading import Thread
from ray.autoscaler.node_provider import get_node_provider
from ray.autoscaler.tags import TAG_RAY_NODE_STATUS, TAG_RAY_RUNTIME_CONFIG
from ray.autoscaler.log_timer import LogTimer
logger = logging.getLogger(__name__)
# How long to wait for a node to start, in seconds
NODE_START_WAIT_S = 300
SSH_CHECK_INTERVAL = 5
logger = logging.getLogger(__name__)
SSH_CONTROL_PATH = "/tmp/ray_ssh_sockets"
def pretty_cmd(cmd_str):
return "\n\n\t{}\n\n".format(cmd_str)
def get_default_ssh_options(private_key, connect_timeout):
    """Return the ssh command-line arguments shared by all ssh/rsync calls.

    The result starts with ``-i <private_key>`` and is followed by one
    ``-o Key=Value`` pair per option: the connect timeout (in seconds), host
    key checking disabled, and the ControlMaster/ControlPath/ControlPersist
    settings that keep the control socket under ``SSH_CONTROL_PATH``.

    Args:
        private_key: Path passed to ssh's ``-i`` flag.
        connect_timeout: Value for ``ConnectTimeout``; an ``s`` is appended.

    Returns:
        A flat list of argument strings.
    """
    settings = (
        ("ConnectTimeout", "{}s".format(connect_timeout)),
        ("StrictHostKeyChecking", "no"),
        ("ControlMaster", "auto"),
        ("ControlPath", "{}/%C".format(SSH_CONTROL_PATH)),
        ("ControlPersist", "yes"),
    )
    args = ["-i", private_key]
    for name, value in settings:
        args.extend(["-o", "{}={}".format(name, value)])
    return args
class NodeUpdater(object):
@@ -36,12 +45,12 @@ class NodeUpdater(object):
def __init__(self,
node_id,
provider_config,
provider,
auth_config,
cluster_name,
file_mounts,
setup_cmds,
runtime_hash,
redirect_output=True,
process_runner=subprocess,
use_internal_ip=False):
self.daemon = True
@@ -49,31 +58,22 @@ class NodeUpdater(object):
self.node_id = node_id
self.use_internal_ip = (use_internal_ip or provider_config.get(
"use_internal_ips", False))
self.provider = get_node_provider(provider_config, cluster_name)
self.provider = provider
self.ssh_private_key = auth_config["ssh_private_key"]
self.ssh_user = auth_config["ssh_user"]
self.ssh_ip = self.get_node_ip()
self.ssh_ip = None
self.file_mounts = {
remote: os.path.expanduser(local)
for remote, local in file_mounts.items()
}
self.setup_cmds = setup_cmds
self.runtime_hash = runtime_hash
self.logger = logger.getChild(str(node_id))
if redirect_output:
self.logfile = tempfile.NamedTemporaryFile(
mode="w", prefix="node-updater-", delete=False)
handler = logging.StreamHandler(stream=self.logfile)
handler.setLevel(logging.INFO)
self.logger.addHandler(handler)
self.output_name = self.logfile.name
self.stdout = self.logfile
self.stderr = self.logfile
def get_caller(self, check_error):
if check_error:
return self.process_runner.call
else:
self.logfile = None
self.output_name = "(console)"
self.stdout = sys.stdout
self.stderr = sys.stderr
return self.process_runner.check_call
def get_node_ip(self):
if self.use_internal_ip:
@@ -81,128 +81,173 @@ class NodeUpdater(object):
else:
return self.provider.external_ip(self.node_id)
def wait_for_ip(self, deadline):
    """Poll the provider for this node's IP until ``deadline``.

    Args:
        deadline: Absolute ``time.time()`` value after which polling stops.

    Returns:
        The IP reported by ``get_node_ip``, or None if the deadline passed
        or the provider reports the node as terminated first.
    """
    while True:
        # Check the deadline first so the provider is not queried once
        # time has run out.
        if time.time() >= deadline:
            break
        if self.provider.is_terminated(self.node_id):
            break
        logger.info("NodeUpdater: "
                    "Waiting for IP of {}...".format(self.node_id))
        found = self.get_node_ip()
        if found is not None:
            return found
        time.sleep(10)
    return None
def set_ssh_ip_if_required(self):
    """Resolve and cache ``self.ssh_ip`` the first time it is needed.

    Returns immediately if ``self.ssh_ip`` is already set. Otherwise waits
    up to ``NODE_START_WAIT_S`` seconds for the node's IP, stores it, and
    creates the ``SSH_CONTROL_PATH`` directory with mode 0700.

    Raises:
        AssertionError: if no IP was obtained before the deadline.
    """
    if self.ssh_ip is not None:
        return

    # We assume that this never changes.
    #   I think that's reasonable.
    deadline = time.time() + NODE_START_WAIT_S
    with LogTimer("NodeUpdater: {}: Got IP".format(self.node_id)):
        ip = self.wait_for_ip(deadline)
        assert ip is not None, "Unable to find IP of node"

    self.ssh_ip = ip

    # This should run before any SSH commands and therefore ensure that
    #   the ControlPath directory exists, allowing SSH to maintain
    #   persistent sessions later on.
    # NOTE(review): get_caller(False) is presumably the variant that raises
    #   on a non-zero exit status - confirm against get_caller.
    with open("/dev/null", "w") as redirect:
        self.get_caller(False)(
            ["mkdir", "-p", SSH_CONTROL_PATH],
            stdout=redirect,
            stderr=redirect)
        self.get_caller(False)(
            ["chmod", "0700", SSH_CONTROL_PATH],
            stdout=redirect,
            stderr=redirect)
def run(self):
self.logger.info(
"NodeUpdater: Updating {} to {}, logging to {}".format(
self.node_id, self.runtime_hash, self.output_name))
logger.info("NodeUpdater: "
"{}: Updating to {}".format(self.node_id,
self.runtime_hash))
try:
self.do_update()
m = "{}: Applied config {}".format(self.node_id, self.runtime_hash)
with LogTimer("NodeUpdater: {}".format(m)):
self.do_update()
except Exception as e:
error_str = str(e)
if hasattr(e, "cmd"):
error_str = "(Exit Status {}) {}".format(
e.returncode, pretty_cmd(" ".join(e.cmd)))
self.logger.error("NodeUpdater: Error updating {}"
"See {} for remote logs.".format(
error_str, self.output_name))
e.returncode, " ".join(e.cmd))
logger.error("NodeUpdater: "
"{}: Error updating {}".format(
self.node_id, error_str))
self.provider.set_node_tags(self.node_id,
{TAG_RAY_NODE_STATUS: "update-failed"})
if self.logfile is not None:
self.logger.info("----- BEGIN REMOTE LOGS -----\n" +
open(self.logfile.name).read() +
"\n----- END REMOTE LOGS -----")
raise e
self.provider.set_node_tags(
self.node_id, {
TAG_RAY_NODE_STATUS: "up-to-date",
TAG_RAY_RUNTIME_CONFIG: self.runtime_hash
})
self.logger.info("NodeUpdater: Applied config {} to node {}".format(
self.runtime_hash, self.node_id))
def do_update(self):
self.provider.set_node_tags(self.node_id,
{TAG_RAY_NODE_STATUS: "waiting-for-ssh"})
deadline = time.time() + NODE_START_WAIT_S
self.exitcode = 0
# Wait for external IP
while time.time() < deadline and \
not self.provider.is_terminated(self.node_id):
self.logger.info("NodeUpdater: Waiting for IP of {}...".format(
self.node_id))
self.ssh_ip = self.get_node_ip()
if self.ssh_ip is not None:
break
time.sleep(10)
assert self.ssh_ip is not None, "Unable to find IP of node"
def wait_for_ssh(self, deadline):
logger.info("NodeUpdater: "
"{}: Waiting for SSH...".format(self.node_id))
# Wait for SSH access
ssh_ok = False
while time.time() < deadline and \
not self.provider.is_terminated(self.node_id):
try:
self.logger.info(
"NodeUpdater: Waiting for SSH to {}...".format(
self.node_id))
if not self.provider.is_running(self.node_id):
raise Exception("Node not running yet...")
logger.debug("NodeUpdater: "
"{}: Waiting for SSH...".format(self.node_id))
self.ssh_cmd(
"uptime",
connect_timeout=5,
redirect=open("/dev/null", "w"))
ssh_ok = True
return True
except Exception as e:
retry_str = str(e)
if hasattr(e, "cmd"):
retry_str = "(Exit Status {}): {}".format(
e.returncode, pretty_cmd(" ".join(e.cmd)))
self.logger.debug(
"NodeUpdater: SSH not up, retrying: {}".format(retry_str),
)
e.returncode, " ".join(e.cmd))
logger.debug("NodeUpdater: "
"{}: SSH not up, retrying: {}".format(
self.node_id, retry_str))
time.sleep(SSH_CHECK_INTERVAL)
else:
break
assert ssh_ok, "Unable to SSH to node"
return False
def do_update(self):
self.provider.set_node_tags(self.node_id,
{TAG_RAY_NODE_STATUS: "waiting-for-ssh"})
deadline = time.time() + NODE_START_WAIT_S
self.set_ssh_ip_if_required()
# Wait for SSH access
with LogTimer("NodeUpdater: " "{}: Got SSH".format(self.node_id)):
ssh_ok = self.wait_for_ssh(deadline)
assert ssh_ok, "Unable to SSH to node"
# Rsync file mounts
self.provider.set_node_tags(self.node_id,
{TAG_RAY_NODE_STATUS: "syncing-files"})
for remote_path, local_path in self.file_mounts.items():
self.logger.info("NodeUpdater: Syncing {} to {}...".format(
local_path, remote_path))
logger.info("NodeUpdater: "
"{}: Syncing {} to {}...".format(
self.node_id, local_path, remote_path))
assert os.path.exists(local_path), local_path
if os.path.isdir(local_path):
if not local_path.endswith("/"):
local_path += "/"
if not remote_path.endswith("/"):
remote_path += "/"
self.ssh_cmd("mkdir -p {}".format(os.path.dirname(remote_path)))
self.rsync_up(local_path, remote_path)
m = "{}: Synced {} to {}".format(self.node_id, local_path,
                                 remote_path)
# Use the same "NodeUpdater: " prefix as every other timer message, and
# open /dev/null once in a context manager so the handle is closed after
# the sync instead of leaking two descriptors per file mount.
with LogTimer("NodeUpdater: {}".format(m)):
    with open("/dev/null", "w") as redirect:
        self.ssh_cmd(
            "mkdir -p {}".format(os.path.dirname(remote_path)),
            redirect=redirect,
        )
        self.rsync_up(local_path, remote_path, redirect=redirect)
# Run init commands
self.provider.set_node_tags(self.node_id,
{TAG_RAY_NODE_STATUS: "setting-up"})
for cmd in self.setup_cmds:
self.ssh_cmd(cmd, verbose=True)
def rsync_up(self, source, target, check_error=True):
if check_error:
call = self.process_runner.call
else:
call = self.process_runner.check_call
call(
m = "{}: Setup commands completed".format(self.node_id)
with LogTimer("NodeUpdater: {}".format(m)):
for cmd in self.setup_cmds:
self.ssh_cmd(
cmd,
# verbose=True,
redirect=open("/dev/null", "w"))
def rsync_up(self, source, target, redirect=None, check_error=True):
self.set_ssh_ip_if_required()
self.get_caller(check_error)(
[
"rsync", "-e", "ssh -i {} ".format(self.ssh_private_key) +
"-o ConnectTimeout=120s -o StrictHostKeyChecking=no",
"rsync", "-e",
" ".join(["ssh"] +
get_default_ssh_options(self.ssh_private_key, 120)),
"--delete", "-avz", source, "{}@{}:{}".format(
self.ssh_user, self.ssh_ip, target)
],
stdout=self.stdout,
stderr=self.stderr)
stdout=redirect or sys.stdout,
stderr=redirect or sys.stderr)
def rsync_down(self, source, target, check_error=True):
if check_error:
call = self.process_runner.call
else:
call = self.process_runner.check_call
call(
def rsync_down(self, source, target, redirect=None, check_error=True):
self.set_ssh_ip_if_required()
self.get_caller(check_error)(
[
"rsync", "-e", "ssh -i {} ".format(self.ssh_private_key) +
"-o ConnectTimeout=120s -o StrictHostKeyChecking=no", "-avz",
"{}@{}:{}".format(self.ssh_user, self.ssh_ip, source), target
"rsync", "-e",
" ".join(["ssh"] +
get_default_ssh_options(self.ssh_private_key, 120)),
"-avz", "{}@{}:{}".format(self.ssh_user, self.ssh_ip,
source), target
],
stdout=self.stdout,
stderr=self.stderr)
stdout=redirect or sys.stdout,
stderr=redirect or sys.stderr)
def ssh_cmd(self,
cmd,
@@ -213,9 +258,12 @@ class NodeUpdater(object):
emulate_interactive=True,
expect_error=False,
port_forward=None):
self.set_ssh_ip_if_required()
if verbose:
self.logger.info("NodeUpdater: running {} on {}...".format(
pretty_cmd(cmd), self.ssh_ip))
logger.info("NodeUpdater: "
"Running {} on {}...".format(cmd, self.ssh_ip))
ssh = ["ssh"]
if allocate_tty:
ssh.append("-tt")
@@ -224,35 +272,24 @@ class NodeUpdater(object):
"set -i || true && source ~/.bashrc && "
"export OMP_NUM_THREADS=1 PYTHONWARNINGS=ignore && ")
cmd = "bash --login -c {}".format(quote(force_interactive + cmd))
if expect_error:
call = self.process_runner.call
else:
call = self.process_runner.check_call
if port_forward is None:
ssh_opt = []
else:
ssh_opt = [
"-L", "{}:localhost:{}".format(port_forward, port_forward)
]
call(
ssh + ssh_opt + [
"-o", "ConnectTimeout={}s".format(connect_timeout), "-o",
"StrictHostKeyChecking=no", "-i", self.ssh_private_key,
"{}@{}".format(self.ssh_user, self.ssh_ip), cmd
],
stdout=redirect or self.stdout,
stderr=redirect or self.stderr)
self.get_caller(expect_error)(
ssh + ssh_opt + get_default_ssh_options(self.ssh_private_key,
connect_timeout) +
["{}@{}".format(self.ssh_user, self.ssh_ip), cmd],
stdout=redirect or sys.stdout,
stderr=redirect or sys.stderr)
class NodeUpdaterProcess(NodeUpdater, Process):
def __init__(self, *args, **kwargs):
Process.__init__(self)
NodeUpdater.__init__(self, *args, **kwargs)
# Single-threaded version for unit tests
class NodeUpdaterThread(NodeUpdater, Thread):
def __init__(self, *args, **kwargs):
Thread.__init__(self)
NodeUpdater.__init__(self, *args, **kwargs)
self.exitcode = 0
self.exitcode = -1
+19 -12
View File
@@ -20,7 +20,6 @@ from ray.services import get_ip_address, get_port
from ray.utils import (binary_to_hex, binary_to_object_id, hex_to_binary,
setup_logger)
# Set up logging.
logger = logging.getLogger(__name__)
@@ -69,8 +68,10 @@ class Monitor(object):
# that redis server.
addr_port = self.redis.lrange("RedisShards", 0, -1)
if len(addr_port) > 1:
logger.warning("TODO: if launching > 1 redis shard, flushing "
"needs to touch shards in parallel.")
logger.warning(
"Monitor: "
"TODO: if launching > 1 redis shard, flushing needs to "
"touch shards in parallel.")
self.issue_gcs_flushes = False
else:
addr_port = addr_port[0].split(b":")
@@ -82,6 +83,7 @@ class Monitor(object):
self.redis_shard.execute_command("HEAD.FLUSH 0")
except redis.exceptions.ResponseError as e:
logger.info(
"Monitor: "
"Turning off flushing due to exception: {}".format(
str(e)))
self.issue_gcs_flushes = False
@@ -128,8 +130,9 @@ class Monitor(object):
self.load_metrics.update(ip, static_resources,
dynamic_resources)
else:
print("Warning: could not find ip for client {} in {}.".format(
client_id, self.local_scheduler_id_to_ip_map))
logger.warning(
"Monitor: "
"could not find ip for client {}".format(client_id))
def _xray_clean_up_entries_for_driver(self, driver_id):
"""Remove this driver's object/task entries from redis.
@@ -185,11 +188,14 @@ class Monitor(object):
continue
redis = self.state.redis_clients[shard_index]
num_deleted = redis.delete(*keys)
logger.info("Removed {} dead redis entries of the driver from"
" redis shard {}.".format(num_deleted, shard_index))
logger.info("Monitor: "
"Removed {} dead redis entries of the "
"driver from redis shard {}.".format(
num_deleted, shard_index))
if num_deleted != len(keys):
logger.warning("Failed to remove {} relevant redis entries"
" from redis shard {}.".format(
logger.warning("Monitor: "
"Failed to remove {} relevant redis "
"entries from redis shard {}.".format(
len(keys) - num_deleted, shard_index))
def xray_driver_removed_handler(self, unused_channel, data):
@@ -205,8 +211,9 @@ class Monitor(object):
message = ray.gcs_utils.DriverTableData.GetRootAsDriverTableData(
driver_data, 0)
driver_id = message.DriverId()
logger.info("XRay Driver {} has been removed.".format(
binary_to_hex(driver_id)))
logger.info("Monitor: "
"XRay Driver {} has been removed.".format(
binary_to_hex(driver_id)))
self._xray_clean_up_entries_for_driver(driver_id)
def process_messages(self, max_messages=10000):
@@ -281,7 +288,7 @@ class Monitor(object):
max_entries_to_flush = self.gcs_flush_policy.num_entries_to_flush()
num_flushed = self.redis_shard.execute_command(
"HEAD.FLUSH {}".format(max_entries_to_flush))
logger.info("num_flushed {}".format(num_flushed))
logger.info("Monitor: num_flushed {}".format(num_flushed))
# This flushes event log and log files.
ray.experimental.flush_redis_unsafe(self.redis)