mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 22:17:21 +08:00
[autoscaler] Speedups (#3720)
- NodeUpdater gets its' IP in parallel now (no longer in __init__) - We use persistent connections in SSH (temp folder created only for ray; ControlMaster) - hash_runtime_conf was performing a pointless hexlify step, wasting time on large files - We use NodeUpdaterThreads and share the NodeProvider; NodeUpdaterProcess is removed - AWSNodeProvider caches nodes more aggressively - NodeProvider now has a shim batch terminate_nodes() call; AWSNodeProvider parallelises it; the autoscaler uses it - AWSNodeProvider batches EC2 update_tags calls - Logging changes throughout to provide standardised timing information for profiling - Pulled out a few unnecessary is_running calls (NodeUpdater will loop waiting for SSH anyway) ## Related issue number Issue #3599
This commit is contained in:
committed by
Richard Liaw
parent
ff3c6af1d6
commit
315edab085
@@ -2,21 +2,19 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import binascii
|
||||
import copy
|
||||
import json
|
||||
import hashlib
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from six import string_types
|
||||
from six.moves import queue
|
||||
import subprocess
|
||||
import threading
|
||||
import logging
|
||||
import time
|
||||
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
import yaml
|
||||
@@ -26,11 +24,12 @@ from ray.ray_constants import AUTOSCALER_MAX_NUM_FAILURES, \
|
||||
AUTOSCALER_UPDATE_INTERVAL_S, AUTOSCALER_HEARTBEAT_TIMEOUT_S
|
||||
from ray.autoscaler.node_provider import get_node_provider, \
|
||||
get_default_config
|
||||
from ray.autoscaler.updater import NodeUpdaterProcess
|
||||
from ray.autoscaler.updater import NodeUpdaterThread
|
||||
from ray.autoscaler.docker import dockerize_if_needed
|
||||
from ray.autoscaler.tags import (TAG_RAY_LAUNCH_CONFIG, TAG_RAY_RUNTIME_CONFIG,
|
||||
TAG_RAY_NODE_STATUS, TAG_RAY_NODE_TYPE,
|
||||
TAG_RAY_NODE_NAME)
|
||||
|
||||
import ray.services as services
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -158,11 +157,13 @@ class LoadMetrics(object):
|
||||
def prune(mapping):
|
||||
unwanted = set(mapping) - active_ips
|
||||
for unwanted_key in unwanted:
|
||||
logger.info("Removed mapping: {} - {}".format(
|
||||
unwanted_key, mapping[unwanted_key]))
|
||||
logger.info("LoadMetrics: "
|
||||
"Removed mapping: {} - {}".format(
|
||||
unwanted_key, mapping[unwanted_key]))
|
||||
del mapping[unwanted_key]
|
||||
if unwanted:
|
||||
logger.info(
|
||||
"LoadMetrics: "
|
||||
"Removed {} stale ip mappings: {} not in {}".format(
|
||||
len(unwanted), unwanted, active_ips))
|
||||
|
||||
@@ -174,8 +175,8 @@ class LoadMetrics(object):
|
||||
return self._info()["NumNodesUsed"]
|
||||
|
||||
def info_string(self):
|
||||
return " - {}".format("\n - ".join(
|
||||
["{}: {}".format(k, v) for k, v in sorted(self._info().items())]))
|
||||
return ", ".join(
|
||||
["{}={}".format(k, v) for k, v in sorted(self._info().items())])
|
||||
|
||||
def _info(self):
|
||||
nodes_used = 0.0
|
||||
@@ -223,17 +224,13 @@ class LoadMetrics(object):
|
||||
|
||||
|
||||
class NodeLauncher(threading.Thread):
|
||||
def __init__(self, queue, pending, *args, **kwargs):
|
||||
def __init__(self, provider, queue, pending, *args, **kwargs):
|
||||
self.queue = queue
|
||||
self.pending = pending
|
||||
self.provider = None
|
||||
self.provider = provider
|
||||
super(NodeLauncher, self).__init__(*args, **kwargs)
|
||||
|
||||
def _launch_node(self, config, count):
|
||||
if self.provider is None:
|
||||
self.provider = get_node_provider(config["provider"],
|
||||
config["cluster_name"])
|
||||
|
||||
tag_filters = {TAG_RAY_NODE_TYPE: "worker"}
|
||||
before = self.provider.nodes(tag_filters=tag_filters)
|
||||
launch_hash = hash_launch_conf(config["worker_nodes"], config["auth"])
|
||||
@@ -247,7 +244,8 @@ class NodeLauncher(threading.Thread):
|
||||
}, count)
|
||||
after = self.provider.nodes(tag_filters=tag_filters)
|
||||
if set(after).issubset(before):
|
||||
logger.error("No new nodes reported after node creation")
|
||||
logger.error("NodeLauncher: "
|
||||
"No new nodes reported after node creation")
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
@@ -305,8 +303,6 @@ class StandardAutoscaler(object):
|
||||
max_concurrent_launches=AUTOSCALER_MAX_CONCURRENT_LAUNCHES,
|
||||
max_failures=AUTOSCALER_MAX_NUM_FAILURES,
|
||||
process_runner=subprocess,
|
||||
verbose_updates=True,
|
||||
node_updater_cls=NodeUpdaterProcess,
|
||||
update_interval_s=AUTOSCALER_UPDATE_INTERVAL_S):
|
||||
self.config_path = config_path
|
||||
self.reload_config(errors_fatal=True)
|
||||
@@ -317,9 +313,7 @@ class StandardAutoscaler(object):
|
||||
self.max_failures = max_failures
|
||||
self.max_launch_batch = max_launch_batch
|
||||
self.max_concurrent_launches = max_concurrent_launches
|
||||
self.verbose_updates = verbose_updates
|
||||
self.process_runner = process_runner
|
||||
self.node_updater_cls = node_updater_cls
|
||||
|
||||
# Map from node_id to NodeUpdater processes
|
||||
self.updaters = {}
|
||||
@@ -337,7 +331,9 @@ class StandardAutoscaler(object):
|
||||
max_concurrent_launches / float(max_launch_batch))
|
||||
for i in range(int(max_batches)):
|
||||
node_launcher = NodeLauncher(
|
||||
queue=self.launch_queue, pending=self.num_launches_pending)
|
||||
provider=self.provider,
|
||||
queue=self.launch_queue,
|
||||
pending=self.num_launches_pending)
|
||||
node_launcher.daemon = True
|
||||
node_launcher.start()
|
||||
|
||||
@@ -359,11 +355,12 @@ class StandardAutoscaler(object):
|
||||
self.reload_config(errors_fatal=False)
|
||||
self._update()
|
||||
except Exception as e:
|
||||
logger.exception("Error during autoscaling.")
|
||||
logger.exception("StandardAutoscaler: "
|
||||
"Error during autoscaling.")
|
||||
self.num_failures += 1
|
||||
if self.num_failures > self.max_failures:
|
||||
logger.critical(
|
||||
"*** StandardAutoscaler: Too many errors, abort. ***")
|
||||
logger.critical("StandardAutoscaler: "
|
||||
"Too many errors, abort.")
|
||||
raise e
|
||||
|
||||
def _update(self):
|
||||
@@ -377,7 +374,7 @@ class StandardAutoscaler(object):
|
||||
self.last_update_time = now
|
||||
num_pending = self.num_launches_pending.value
|
||||
nodes = self.workers()
|
||||
logger.info(self.info_string(nodes))
|
||||
self.log_info_string(nodes)
|
||||
self.load_metrics.prune_active_ips(
|
||||
[self.provider.internal_ip(node_id) for node_id in nodes])
|
||||
target_workers = self.target_num_workers()
|
||||
@@ -385,35 +382,37 @@ class StandardAutoscaler(object):
|
||||
# Terminate any idle or out of date nodes
|
||||
last_used = self.load_metrics.last_used_time_by_ip
|
||||
horizon = now - (60 * self.config["idle_timeout_minutes"])
|
||||
num_terminated = 0
|
||||
|
||||
nodes_to_terminate = []
|
||||
for node_id in nodes:
|
||||
node_ip = self.provider.internal_ip(node_id)
|
||||
if node_ip in last_used and last_used[node_ip] < horizon and \
|
||||
len(nodes) - num_terminated > target_workers:
|
||||
num_terminated += 1
|
||||
logger.info("StandardAutoscaler: Terminating idle node: "
|
||||
"{}".format(node_id))
|
||||
self.provider.terminate_node(node_id)
|
||||
len(nodes) - len(nodes_to_terminate) > target_workers:
|
||||
logger.info("StandardAutoscaler: "
|
||||
"{}: Terminating idle node".format(node_id))
|
||||
nodes_to_terminate.append(node_id)
|
||||
elif not self.launch_config_ok(node_id):
|
||||
num_terminated += 1
|
||||
logger.info("StandardAutoscaler: Terminating outdated node: "
|
||||
"{}".format(node_id))
|
||||
self.provider.terminate_node(node_id)
|
||||
if num_terminated > 0:
|
||||
logger.info("StandardAutoscaler: "
|
||||
"{}: Terminating outdated node".format(node_id))
|
||||
nodes_to_terminate.append(node_id)
|
||||
|
||||
if nodes_to_terminate:
|
||||
self.provider.terminate_nodes(nodes_to_terminate)
|
||||
nodes = self.workers()
|
||||
logger.info(self.info_string(nodes))
|
||||
self.log_info_string(nodes)
|
||||
|
||||
# Terminate nodes if there are too many
|
||||
num_terminated = 0
|
||||
nodes_to_terminate = []
|
||||
while len(nodes) > self.config["max_workers"]:
|
||||
num_terminated += 1
|
||||
logger.info("StandardAutoscaler: Terminating unneeded node: "
|
||||
"{}".format(nodes[-1]))
|
||||
self.provider.terminate_node(nodes[-1])
|
||||
logger.info("StandardAutoscaler: "
|
||||
"{}: Terminating unneeded node".format(nodes[-1]))
|
||||
nodes_to_terminate.append(nodes[-1])
|
||||
nodes = nodes[:-1]
|
||||
if num_terminated > 0:
|
||||
|
||||
if nodes_to_terminate:
|
||||
self.provider.terminate_nodes(nodes_to_terminate)
|
||||
nodes = self.workers()
|
||||
logger.info(self.info_string(nodes))
|
||||
self.log_info_string(nodes)
|
||||
|
||||
# Launch new nodes if needed
|
||||
num_workers = len(nodes) + num_pending
|
||||
@@ -422,7 +421,8 @@ class StandardAutoscaler(object):
|
||||
self.max_concurrent_launches - num_pending)
|
||||
num_launches = min(max_allowed, target_workers - num_workers)
|
||||
self.launch_new_node(num_launches)
|
||||
logger.info(self.info_string())
|
||||
nodes = self.workers()
|
||||
self.log_info_string(nodes)
|
||||
else:
|
||||
self.bringup = False
|
||||
|
||||
@@ -442,11 +442,21 @@ class StandardAutoscaler(object):
|
||||
# immediately trying to restart Ray on the new node.
|
||||
self.load_metrics.mark_active(self.provider.internal_ip(node_id))
|
||||
nodes = self.workers()
|
||||
logger.info(self.info_string(nodes))
|
||||
self.log_info_string(nodes)
|
||||
|
||||
# Update nodes with out-of-date files
|
||||
for node_id in nodes:
|
||||
self.update_if_needed(node_id)
|
||||
T = [
|
||||
threading.Thread(
|
||||
target=self.spawn_updater,
|
||||
args=(node_id, commands),
|
||||
) for node_id, commands in (self.should_update(node_id)
|
||||
for node_id in nodes)
|
||||
if node_id is not None
|
||||
]
|
||||
for t in T:
|
||||
t.start()
|
||||
for t in T:
|
||||
t.join()
|
||||
|
||||
# Attempt to recover unhealthy nodes
|
||||
for node_id in nodes:
|
||||
@@ -471,7 +481,8 @@ class StandardAutoscaler(object):
|
||||
if errors_fatal:
|
||||
raise e
|
||||
else:
|
||||
logger.exception("StandardAutoscaler: Error parsing config.")
|
||||
logger.exception("StandardAutoscaler: "
|
||||
"Error parsing config.")
|
||||
|
||||
def target_num_workers(self):
|
||||
initial_workers = self.config["initial_workers"]
|
||||
@@ -496,9 +507,9 @@ class StandardAutoscaler(object):
|
||||
def files_up_to_date(self, node_id):
|
||||
applied = self.provider.node_tags(node_id).get(TAG_RAY_RUNTIME_CONFIG)
|
||||
if applied != self.runtime_hash:
|
||||
logger.info(
|
||||
"StandardAutoscaler: {} has runtime state {}, want {}".format(
|
||||
node_id, applied, self.runtime_hash))
|
||||
logger.info("StandardAutoscaler: "
|
||||
"{}: Runtime state is {}, want {}".format(
|
||||
node_id, applied, self.runtime_hash))
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -512,27 +523,29 @@ class StandardAutoscaler(object):
|
||||
delta = now - last_heartbeat_time
|
||||
if delta < AUTOSCALER_HEARTBEAT_TIMEOUT_S:
|
||||
return
|
||||
logger.warning("StandardAutoscaler: No heartbeat from node "
|
||||
"{} in {} seconds, restarting Ray to recover...".format(
|
||||
node_id, delta))
|
||||
updater = self.node_updater_cls(
|
||||
logger.warning("StandardAutoscaler: "
|
||||
"{}: No heartbeat in {}s, "
|
||||
"restarting Ray to recover...".format(node_id, delta))
|
||||
updater = NodeUpdaterThread(
|
||||
node_id,
|
||||
self.config["provider"],
|
||||
self.provider,
|
||||
self.config["auth"],
|
||||
self.config["cluster_name"], {},
|
||||
with_head_node_ip(self.config["worker_start_ray_commands"]),
|
||||
self.runtime_hash,
|
||||
redirect_output=not self.verbose_updates,
|
||||
process_runner=self.process_runner,
|
||||
use_internal_ip=True)
|
||||
updater.start()
|
||||
self.updaters[node_id] = updater
|
||||
|
||||
def update_if_needed(self, node_id):
|
||||
def should_update(self, node_id):
|
||||
if not self.can_update(node_id):
|
||||
return
|
||||
return (None, None)
|
||||
|
||||
if self.files_up_to_date(node_id):
|
||||
return
|
||||
return (None, None)
|
||||
|
||||
successful_updated = self.num_successful_updates.get(node_id, 0) > 0
|
||||
if successful_updated and self.config.get("restart_only", False):
|
||||
init_commands = self.config["worker_start_ray_commands"]
|
||||
@@ -544,33 +557,35 @@ class StandardAutoscaler(object):
|
||||
self.config["worker_setup_commands"] +
|
||||
self.config["worker_start_ray_commands"])
|
||||
|
||||
updater = self.node_updater_cls(
|
||||
return (node_id, init_commands)
|
||||
|
||||
def spawn_updater(self, node_id, init_commands):
|
||||
updater = NodeUpdaterThread(
|
||||
node_id,
|
||||
self.config["provider"],
|
||||
self.provider,
|
||||
self.config["auth"],
|
||||
self.config["cluster_name"],
|
||||
self.config["file_mounts"],
|
||||
with_head_node_ip(init_commands),
|
||||
self.runtime_hash,
|
||||
redirect_output=not self.verbose_updates,
|
||||
process_runner=self.process_runner,
|
||||
use_internal_ip=True)
|
||||
updater.start()
|
||||
self.updaters[node_id] = updater
|
||||
|
||||
def can_update(self, node_id):
|
||||
if not self.provider.is_running(node_id):
|
||||
if node_id in self.updaters:
|
||||
return False
|
||||
if not self.launch_config_ok(node_id):
|
||||
return False
|
||||
if node_id in self.updaters:
|
||||
return False
|
||||
if self.num_failed_updates.get(node_id, 0) > 0: # TODO(ekl) retry?
|
||||
return False
|
||||
return True
|
||||
|
||||
def launch_new_node(self, count):
|
||||
logger.info("StandardAutoscaler: Launching {} new nodes".format(count))
|
||||
logger.info("StandardAutoscaler: "
|
||||
"Launching {} new nodes".format(count))
|
||||
self.num_launches_pending.inc(count)
|
||||
config = copy.deepcopy(self.config)
|
||||
self.launch_queue.put((config, count))
|
||||
@@ -578,9 +593,11 @@ class StandardAutoscaler(object):
|
||||
def workers(self):
|
||||
return self.provider.nodes(tag_filters={TAG_RAY_NODE_TYPE: "worker"})
|
||||
|
||||
def info_string(self, nodes=None):
|
||||
if nodes is None:
|
||||
nodes = self.workers()
|
||||
def log_info_string(self, nodes):
|
||||
logger.info("StandardAutoscaler: {}".format(self.info_string(nodes)))
|
||||
logger.info("LoadMetrics: {}".format(self.load_metrics.info_string()))
|
||||
|
||||
def info_string(self, nodes):
|
||||
suffix = ""
|
||||
if self.num_launches_pending:
|
||||
suffix += " ({} pending)".format(self.num_launches_pending.value)
|
||||
@@ -591,9 +608,9 @@ class StandardAutoscaler(object):
|
||||
len(self.num_failed_updates))
|
||||
if self.bringup:
|
||||
suffix += " (bringup=True)"
|
||||
return "StandardAutoscaler [{}]: {}/{} target nodes{}\n{}".format(
|
||||
datetime.now(), len(nodes), self.target_num_workers(), suffix,
|
||||
self.load_metrics.info_string())
|
||||
|
||||
return "{}/{} target nodes{}".format(
|
||||
len(nodes), self.target_num_workers(), suffix)
|
||||
|
||||
|
||||
def typename(v):
|
||||
@@ -685,6 +702,11 @@ def hash_runtime_conf(file_mounts, extra_objs):
|
||||
hasher = hashlib.sha1()
|
||||
|
||||
def add_content_hashes(path):
|
||||
def add_hash_of_file(fpath):
|
||||
with open(fpath, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(2**20), b''):
|
||||
hasher.update(chunk)
|
||||
|
||||
path = os.path.expanduser(path)
|
||||
if os.path.isdir(path):
|
||||
dirs = []
|
||||
@@ -694,18 +716,12 @@ def hash_runtime_conf(file_mounts, extra_objs):
|
||||
hasher.update(dirpath.encode("utf-8"))
|
||||
for name in filenames:
|
||||
hasher.update(name.encode("utf-8"))
|
||||
with open(os.path.join(dirpath, name), "rb") as f:
|
||||
if os.path.getsize(os.path.join(dirpath,
|
||||
name)) < 1000000000:
|
||||
hasher.update(binascii.hexlify(f.read()))
|
||||
else:
|
||||
for chunk in iter(lambda: f.read(8192), b''):
|
||||
hasher.update(binascii.hexlify(chunk))
|
||||
fpath = os.path.join(dirpath, name)
|
||||
add_hash_of_file(fpath)
|
||||
else:
|
||||
with open(path, "rb") as f:
|
||||
hasher.update(binascii.hexlify(f.read()))
|
||||
add_hash_of_file(path)
|
||||
|
||||
conf_str = (json.dumps(sorted(file_mounts.items())).encode("utf-8") +
|
||||
conf_str = (json.dumps(file_mounts, sort_keys=True).encode("utf-8") +
|
||||
json.dumps(extra_objs, sort_keys=True).encode("utf-8"))
|
||||
|
||||
# Important: only hash the files once. Otherwise, we can end up restarting
|
||||
|
||||
@@ -4,9 +4,9 @@ from __future__ import print_function
|
||||
|
||||
from distutils.version import StrictVersion
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
|
||||
import boto3
|
||||
from botocore.config import Config
|
||||
@@ -14,6 +14,8 @@ import botocore
|
||||
|
||||
from ray.ray_constants import BOTO_MAX_RETRIES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
RAY = "ray-autoscaler"
|
||||
DEFAULT_RAY_INSTANCE_PROFILE = RAY + "-v1"
|
||||
DEFAULT_RAY_IAM_ROLE = RAY + "-v1"
|
||||
@@ -34,7 +36,6 @@ def key_pair(i, region):
|
||||
|
||||
# Suppress excessive connection dropped logs from boto
|
||||
logging.getLogger("botocore").setLevel(logging.WARNING)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def bootstrap_aws(config):
|
||||
@@ -62,8 +63,9 @@ def _configure_iam_role(config):
|
||||
profile = _get_instance_profile(DEFAULT_RAY_INSTANCE_PROFILE, config)
|
||||
|
||||
if profile is None:
|
||||
logger.info("Creating new instance profile {}".format(
|
||||
DEFAULT_RAY_INSTANCE_PROFILE))
|
||||
logger.info("_configure_iam_role: "
|
||||
"Creating new instance profile {}".format(
|
||||
DEFAULT_RAY_INSTANCE_PROFILE))
|
||||
client = _client("iam", config)
|
||||
client.create_instance_profile(
|
||||
InstanceProfileName=DEFAULT_RAY_INSTANCE_PROFILE)
|
||||
@@ -75,7 +77,8 @@ def _configure_iam_role(config):
|
||||
if not profile.roles:
|
||||
role = _get_role(DEFAULT_RAY_IAM_ROLE, config)
|
||||
if role is None:
|
||||
logger.info("Creating new role {}".format(DEFAULT_RAY_IAM_ROLE))
|
||||
logger.info("_configure_iam_role: "
|
||||
"Creating new role {}".format(DEFAULT_RAY_IAM_ROLE))
|
||||
iam = _resource("iam", config)
|
||||
iam.create_role(
|
||||
RoleName=DEFAULT_RAY_IAM_ROLE,
|
||||
@@ -99,8 +102,9 @@ def _configure_iam_role(config):
|
||||
profile.add_role(RoleName=role.name)
|
||||
time.sleep(15) # wait for propagation
|
||||
|
||||
logger.info("Role not specified for head node, using {}".format(
|
||||
profile.arn))
|
||||
logger.info("_configure_iam_role: "
|
||||
"Role not specified for head node, using {}".format(
|
||||
profile.arn))
|
||||
config["head_node"]["IamInstanceProfile"] = {"Arn": profile.arn}
|
||||
|
||||
return config
|
||||
@@ -126,7 +130,8 @@ def _configure_key_pair(config):
|
||||
|
||||
# We can safely create a new key.
|
||||
if not key and not os.path.exists(key_path):
|
||||
logger.info("Creating new key pair {}".format(key_name))
|
||||
logger.info("_configure_key_pair: "
|
||||
"Creating new key pair {}".format(key_name))
|
||||
key = ec2.create_key_pair(KeyName=key_name)
|
||||
with open(key_path, "w") as f:
|
||||
f.write(key.key_material)
|
||||
@@ -142,7 +147,8 @@ def _configure_key_pair(config):
|
||||
assert os.path.exists(key_path), \
|
||||
"Private key file {} not found for {}".format(key_path, key_name)
|
||||
|
||||
logger.info("KeyName not specified for nodes, using {}".format(key_name))
|
||||
logger.info("_configure_key_pair: "
|
||||
"KeyName not specified for nodes, using {}".format(key_name))
|
||||
|
||||
config["auth"]["ssh_private_key"] = key_path
|
||||
config["head_node"]["KeyName"] = key_name
|
||||
@@ -174,19 +180,21 @@ def _configure_subnet(config):
|
||||
"No usable subnets matching availability zone {} "
|
||||
"found. Choose a different availability zone or try "
|
||||
"manually creating an instance in your specified region "
|
||||
"to populate the list of subnets and trying this again."
|
||||
.format(config["provider"]["availability_zone"]))
|
||||
"to populate the list of subnets and trying this again.".
|
||||
format(config["provider"]["availability_zone"]))
|
||||
|
||||
subnet_ids = [s.subnet_id for s in subnets]
|
||||
subnet_descr = [(s.subnet_id, s.availability_zone) for s in subnets]
|
||||
if "SubnetIds" not in config["head_node"]:
|
||||
config["head_node"]["SubnetIds"] = subnet_ids
|
||||
logger.info("SubnetIds not specified for head node,"
|
||||
" using {}".format(subnet_descr))
|
||||
logger.info("_configure_subnet: "
|
||||
"SubnetIds not specified for head node, using {}".format(
|
||||
subnet_descr))
|
||||
|
||||
if "SubnetIds" not in config["worker_nodes"]:
|
||||
config["worker_nodes"]["SubnetIds"] = subnet_ids
|
||||
logger.info("SubnetId not specified for workers,"
|
||||
logger.info("_configure_subnet: "
|
||||
"SubnetId not specified for workers,"
|
||||
" using {}".format(subnet_descr))
|
||||
|
||||
return config
|
||||
@@ -202,7 +210,8 @@ def _configure_security_group(config):
|
||||
security_group = _get_security_group(config, vpc_id, group_name)
|
||||
|
||||
if security_group is None:
|
||||
logger.info("Creating new security group {}".format(group_name))
|
||||
logger.info("_configure_security_group: "
|
||||
"Creating new security group {}".format(group_name))
|
||||
client = _client("ec2", config)
|
||||
client.create_security_group(
|
||||
Description="Auto-created security group for Ray workers",
|
||||
@@ -230,12 +239,14 @@ def _configure_security_group(config):
|
||||
|
||||
if "SecurityGroupIds" not in config["head_node"]:
|
||||
logger.info(
|
||||
"_configure_security_group: "
|
||||
"SecurityGroupIds not specified for head node, using {}".format(
|
||||
security_group.group_name))
|
||||
config["head_node"]["SecurityGroupIds"] = [security_group.id]
|
||||
|
||||
if "SecurityGroupIds" not in config["worker_nodes"]:
|
||||
logger.info(
|
||||
"_configure_security_group: "
|
||||
"SecurityGroupIds not specified for workers, using {}".format(
|
||||
security_group.group_name))
|
||||
config["worker_nodes"]["SecurityGroupIds"] = [security_group.id]
|
||||
|
||||
@@ -3,6 +3,8 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import random
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
|
||||
import boto3
|
||||
from botocore.config import Config
|
||||
@@ -10,6 +12,7 @@ from botocore.config import Config
|
||||
from ray.autoscaler.node_provider import NodeProvider
|
||||
from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME
|
||||
from ray.ray_constants import BOTO_MAX_RETRIES
|
||||
from ray.autoscaler.log_timer import LogTimer
|
||||
|
||||
|
||||
def to_aws_format(tags):
|
||||
@@ -40,15 +43,59 @@ class AWSNodeProvider(NodeProvider):
|
||||
# Try availability zones round-robin, starting from random offset
|
||||
self.subnet_idx = random.randint(0, 100)
|
||||
|
||||
self.tag_cache = {} # Tags that we believe to actually be on EC2.
|
||||
self.tag_cache_pending = {} # Tags that we will soon upload.
|
||||
self.tag_cache_lock = threading.Lock()
|
||||
self.tag_cache_update_event = threading.Event()
|
||||
self.tag_cache_kill_event = threading.Event()
|
||||
self.tag_update_thread = threading.Thread(
|
||||
target=self._node_tag_update_loop)
|
||||
self.tag_update_thread.start()
|
||||
|
||||
# Cache of node objects from the last nodes() call. This avoids
|
||||
# excessive DescribeInstances requests.
|
||||
self.cached_nodes = {}
|
||||
|
||||
# Cache of ip lookups. We assume IPs never change once assigned.
|
||||
self.internal_ip_cache = {}
|
||||
self.external_ip_cache = {}
|
||||
def _node_tag_update_loop(self):
|
||||
""" Update the AWS tags for a cluster periodically.
|
||||
|
||||
The purpose of this loop is to avoid excessive EC2 calls when a large
|
||||
number of nodes are being launched simultaneously.
|
||||
"""
|
||||
while True:
|
||||
self.tag_cache_update_event.wait()
|
||||
self.tag_cache_update_event.clear()
|
||||
|
||||
batch_updates = defaultdict(list)
|
||||
|
||||
with self.tag_cache_lock:
|
||||
for node_id, tags in self.tag_cache_pending.items():
|
||||
for x in tags.items():
|
||||
batch_updates[x].append(node_id)
|
||||
self.tag_cache[node_id].update(tags)
|
||||
|
||||
self.tag_cache_pending = {}
|
||||
|
||||
for (k, v), node_ids in batch_updates.items():
|
||||
m = "Set tag {}={} on {}".format(k, v, node_ids)
|
||||
with LogTimer("AWSNodeProvider: {}".format(m)):
|
||||
if k == TAG_RAY_NODE_NAME:
|
||||
k = "Name"
|
||||
self.ec2.meta.client.create_tags(
|
||||
Resources=node_ids,
|
||||
Tags=[{
|
||||
"Key": k,
|
||||
"Value": v
|
||||
}],
|
||||
)
|
||||
|
||||
self.tag_cache_kill_event.wait(timeout=5)
|
||||
if self.tag_cache_kill_event.is_set():
|
||||
return
|
||||
|
||||
def nodes(self, tag_filters):
|
||||
# Note that these filters are acceptable because they are set on
|
||||
# node initialization, and so can never be sitting in the cache.
|
||||
tag_filters = to_aws_format(tag_filters)
|
||||
filters = [
|
||||
{
|
||||
@@ -65,9 +112,19 @@ class AWSNodeProvider(NodeProvider):
|
||||
"Name": "tag:{}".format(k),
|
||||
"Values": [v],
|
||||
})
|
||||
instances = list(self.ec2.instances.filter(Filters=filters))
|
||||
self.cached_nodes = {i.id: i for i in instances}
|
||||
return [i.id for i in instances]
|
||||
|
||||
nodes = list(self.ec2.instances.filter(Filters=filters))
|
||||
# Populate the tag cache with initial information if necessary
|
||||
for node in nodes:
|
||||
if node.id in self.tag_cache:
|
||||
continue
|
||||
|
||||
self.tag_cache[node.id] = from_aws_format(
|
||||
{x["Key"]: x["Value"]
|
||||
for x in node.tags})
|
||||
|
||||
self.cached_nodes = {node.id: node for node in nodes}
|
||||
return [node.id for node in nodes]
|
||||
|
||||
def is_running(self, node_id):
|
||||
node = self._node(node_id)
|
||||
@@ -79,40 +136,25 @@ class AWSNodeProvider(NodeProvider):
|
||||
return state not in ["running", "pending"]
|
||||
|
||||
def node_tags(self, node_id):
|
||||
node = self._node(node_id)
|
||||
tags = {}
|
||||
for tag in node.tags:
|
||||
tags[tag["Key"]] = tag["Value"]
|
||||
return from_aws_format(tags)
|
||||
with self.tag_cache_lock:
|
||||
d1 = self.tag_cache[node_id]
|
||||
d2 = self.tag_cache_pending.get(node_id, {})
|
||||
return dict(d1, **d2)
|
||||
|
||||
def external_ip(self, node_id):
|
||||
if node_id in self.external_ip_cache:
|
||||
return self.external_ip_cache[node_id]
|
||||
node = self._node(node_id)
|
||||
ip = node.public_ip_address
|
||||
if ip:
|
||||
self.external_ip_cache[node_id] = ip
|
||||
return ip
|
||||
return self._node(node_id).public_ip_address
|
||||
|
||||
def internal_ip(self, node_id):
|
||||
if node_id in self.internal_ip_cache:
|
||||
return self.internal_ip_cache[node_id]
|
||||
node = self._node(node_id)
|
||||
ip = node.private_ip_address
|
||||
if ip:
|
||||
self.internal_ip_cache[node_id] = ip
|
||||
return ip
|
||||
return self._node(node_id).private_ip_address
|
||||
|
||||
def set_node_tags(self, node_id, tags):
|
||||
tags = to_aws_format(tags)
|
||||
node = self._node(node_id)
|
||||
tag_pairs = []
|
||||
for k, v in tags.items():
|
||||
tag_pairs.append({
|
||||
"Key": k,
|
||||
"Value": v,
|
||||
})
|
||||
node.create_tags(Tags=tag_pairs)
|
||||
with self.tag_cache_lock:
|
||||
try:
|
||||
self.tag_cache_pending[node_id].update(tags)
|
||||
except KeyError:
|
||||
self.tag_cache_pending[node_id] = tags
|
||||
|
||||
self.tag_cache_update_event.set()
|
||||
|
||||
def create_node(self, node_config, tags, count):
|
||||
tags = to_aws_format(tags)
|
||||
@@ -166,9 +208,24 @@ class AWSNodeProvider(NodeProvider):
|
||||
node = self._node(node_id)
|
||||
node.terminate()
|
||||
|
||||
self.tag_cache.pop(node_id, None)
|
||||
self.tag_cache_pending.pop(node_id, None)
|
||||
|
||||
def terminate_nodes(self, node_ids):
|
||||
self.ec2.meta.client.terminate_instances(InstanceIds=node_ids)
|
||||
|
||||
for node_id in node_ids:
|
||||
self.tag_cache.pop(node_id, None)
|
||||
self.tag_cache_pending.pop(node_id, None)
|
||||
|
||||
def _node(self, node_id):
|
||||
if node_id in self.cached_nodes:
|
||||
return self.cached_nodes[node_id]
|
||||
matches = list(self.ec2.instances.filter(InstanceIds=[node_id]))
|
||||
assert len(matches) == 1, "Invalid instance id {}".format(node_id)
|
||||
return matches[0]
|
||||
if node_id not in self.cached_nodes:
|
||||
self.nodes({}) # Side effect: should cache it.
|
||||
|
||||
assert node_id in self.cached_nodes, "Invalid instance id {}".format(
|
||||
node_id)
|
||||
return self.cached_nodes[node_id]
|
||||
|
||||
def cleanup(self):
|
||||
self.tag_cache_update_event.set()
|
||||
self.tag_cache_kill_event.set()
|
||||
|
||||
@@ -8,9 +8,9 @@ import json
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
import logging
|
||||
import sys
|
||||
import click
|
||||
import logging
|
||||
import random
|
||||
|
||||
import yaml
|
||||
@@ -24,7 +24,8 @@ from ray.autoscaler.autoscaler import validate_config, hash_runtime_conf, \
|
||||
from ray.autoscaler.node_provider import get_node_provider, NODE_PROVIDERS
|
||||
from ray.autoscaler.tags import TAG_RAY_NODE_TYPE, TAG_RAY_LAUNCH_CONFIG, \
|
||||
TAG_RAY_NODE_NAME
|
||||
from ray.autoscaler.updater import NodeUpdaterProcess
|
||||
from ray.autoscaler.updater import NodeUpdaterThread
|
||||
from ray.autoscaler.log_timer import LogTimer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -81,18 +82,35 @@ def teardown_cluster(config_file, yes, workers_only, override_cluster_name):
|
||||
|
||||
provider = get_node_provider(config["provider"], config["cluster_name"])
|
||||
|
||||
if not workers_only:
|
||||
for node in provider.nodes({TAG_RAY_NODE_TYPE: "head"}):
|
||||
logger.info("Terminating head node {}".format(node))
|
||||
provider.terminate_node(node)
|
||||
def remaining_nodes():
|
||||
if workers_only:
|
||||
A = []
|
||||
else:
|
||||
A = [
|
||||
node_id for node_id in provider.nodes({
|
||||
TAG_RAY_NODE_TYPE: "head"
|
||||
})
|
||||
]
|
||||
|
||||
nodes = provider.nodes({TAG_RAY_NODE_TYPE: "worker"})
|
||||
while nodes:
|
||||
for node in nodes:
|
||||
logger.info("Terminating worker {}".format(node))
|
||||
provider.terminate_node(node)
|
||||
time.sleep(5)
|
||||
nodes = provider.nodes({TAG_RAY_NODE_TYPE: "worker"})
|
||||
A += [
|
||||
node_id for node_id in provider.nodes({
|
||||
TAG_RAY_NODE_TYPE: "worker"
|
||||
})
|
||||
]
|
||||
return A
|
||||
|
||||
# Loop here to check that both the head and worker nodes are actually
|
||||
# really gone
|
||||
A = remaining_nodes()
|
||||
with LogTimer("teardown_cluster: Termination done."):
|
||||
while A:
|
||||
logger.info("teardown_cluster: "
|
||||
"Terminating {} nodes...".format(len(A)))
|
||||
provider.terminate_nodes(A)
|
||||
time.sleep(1)
|
||||
A = remaining_nodes()
|
||||
|
||||
provider.cleanup()
|
||||
|
||||
|
||||
def kill_node(config_file, yes, override_cluster_name):
|
||||
@@ -108,26 +126,26 @@ def kill_node(config_file, yes, override_cluster_name):
|
||||
provider = get_node_provider(config["provider"], config["cluster_name"])
|
||||
nodes = provider.nodes({TAG_RAY_NODE_TYPE: "worker"})
|
||||
node = random.choice(nodes)
|
||||
logger.info("Terminating worker {}".format(node))
|
||||
updater = NodeUpdaterProcess(
|
||||
node,
|
||||
config["provider"],
|
||||
config["auth"],
|
||||
config["cluster_name"],
|
||||
config["file_mounts"], [],
|
||||
"",
|
||||
redirect_output=False)
|
||||
logger.info("kill_node: Terminating worker {}".format(node))
|
||||
|
||||
updater = NodeUpdaterThread(node, config["provider"], provider,
|
||||
config["auth"], config["cluster_name"],
|
||||
config["file_mounts"], [], "")
|
||||
|
||||
_exec(updater, "ray stop", False, False)
|
||||
|
||||
time.sleep(5)
|
||||
|
||||
iip = provider.internal_ip(node)
|
||||
if iip:
|
||||
return iip
|
||||
|
||||
return provider.external_ip(node)
|
||||
|
||||
|
||||
def get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
|
||||
override_cluster_name):
|
||||
"""Create the cluster head node, which in turn creates the workers."""
|
||||
|
||||
provider = get_node_provider(config["provider"], config["cluster_name"])
|
||||
head_node_tags = {
|
||||
TAG_RAY_NODE_TYPE: "head",
|
||||
@@ -148,9 +166,10 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
|
||||
TAG_RAY_LAUNCH_CONFIG) != launch_hash:
|
||||
if head_node is not None:
|
||||
confirm("Head node config out-of-date. It will be terminated", yes)
|
||||
logger.info("Terminating outdated head node {}".format(head_node))
|
||||
logger.info("get_or_create_head_node: "
|
||||
"Terminating outdated head node {}".format(head_node))
|
||||
provider.terminate_node(head_node)
|
||||
logger.info("Launching new head node...")
|
||||
logger.info("get_or_create_head_node: Launching new head node...")
|
||||
head_node_tags[TAG_RAY_LAUNCH_CONFIG] = launch_hash
|
||||
head_node_tags[TAG_RAY_NODE_NAME] = "ray-{}-head".format(
|
||||
config["cluster_name"])
|
||||
@@ -163,7 +182,7 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
|
||||
# TODO(ekl) right now we always update the head node even if the hash
|
||||
# matches. We could prompt the user for what they want to do in this case.
|
||||
runtime_hash = hash_runtime_conf(config["file_mounts"], config)
|
||||
logger.info("Updating files on head node...")
|
||||
logger.info("get_or_create_head_node: Updating files on head node...")
|
||||
|
||||
# Rewrite the auth config so that the head node can update the workers
|
||||
remote_key_path = "~/ray_bootstrap_key.pem"
|
||||
@@ -197,15 +216,16 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
|
||||
config["setup_commands"] + config["head_setup_commands"] +
|
||||
config["head_start_ray_commands"])
|
||||
|
||||
updater = NodeUpdaterProcess(
|
||||
updater = NodeUpdaterThread(
|
||||
head_node,
|
||||
config["provider"],
|
||||
provider,
|
||||
config["auth"],
|
||||
config["cluster_name"],
|
||||
config["file_mounts"],
|
||||
init_commands,
|
||||
runtime_hash,
|
||||
redirect_output=False)
|
||||
)
|
||||
updater.start()
|
||||
updater.join()
|
||||
|
||||
@@ -213,11 +233,13 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
|
||||
provider.nodes(head_node_tags)
|
||||
|
||||
if updater.exitcode != 0:
|
||||
logger.error("Updating {} failed".format(
|
||||
provider.external_ip(head_node)))
|
||||
logger.error("get_or_create_head_node: "
|
||||
"Updating {} failed".format(
|
||||
provider.external_ip(head_node)))
|
||||
sys.exit(1)
|
||||
logger.info("Head node up-to-date, IP address is: {}".format(
|
||||
provider.external_ip(head_node)))
|
||||
logger.info("get_or_create_head_node: "
|
||||
"Head node up-to-date, IP address is: {}".format(
|
||||
provider.external_ip(head_node)))
|
||||
|
||||
monitor_str = "tail -n 100 -f /tmp/ray/session_*/logs/monitor*"
|
||||
for s in init_commands:
|
||||
@@ -239,6 +261,8 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
|
||||
config["auth"]["ssh_user"],
|
||||
provider.external_ip(head_node)))
|
||||
|
||||
provider.cleanup()
|
||||
|
||||
|
||||
def attach_cluster(config_file, start, use_tmux, override_cluster_name, new):
|
||||
"""Attaches to a screen for the specified cluster.
|
||||
@@ -288,14 +312,18 @@ def exec_cluster(config_file, cmd, screen, tmux, stop, start,
|
||||
config = _bootstrap_config(config)
|
||||
head_node = _get_head_node(
|
||||
config, config_file, override_cluster_name, create_if_needed=start)
|
||||
updater = NodeUpdaterProcess(
|
||||
|
||||
provider = get_node_provider(config["provider"], config["cluster_name"])
|
||||
updater = NodeUpdaterThread(
|
||||
head_node,
|
||||
config["provider"],
|
||||
provider,
|
||||
config["auth"],
|
||||
config["cluster_name"],
|
||||
config["file_mounts"], [],
|
||||
config["file_mounts"],
|
||||
[],
|
||||
"",
|
||||
redirect_output=False)
|
||||
)
|
||||
if stop:
|
||||
cmd += ("; ray stop; ray teardown ~/ray_bootstrap_config.yaml --yes "
|
||||
"--workers-only; sudo shutdown -h now")
|
||||
@@ -322,6 +350,8 @@ def exec_cluster(config_file, cmd, screen, tmux, stop, start,
|
||||
attach_command)
|
||||
logger.info(attach_info)
|
||||
|
||||
provider.cleanup()
|
||||
|
||||
|
||||
def _exec(updater, cmd, screen, tmux, expect_error=False, port_forward=None):
|
||||
if cmd:
|
||||
@@ -363,20 +393,26 @@ def rsync(config_file, source, target, override_cluster_name, down):
|
||||
config = _bootstrap_config(config)
|
||||
head_node = _get_head_node(
|
||||
config, config_file, override_cluster_name, create_if_needed=False)
|
||||
updater = NodeUpdaterProcess(
|
||||
|
||||
provider = get_node_provider(config["provider"], config["cluster_name"])
|
||||
updater = NodeUpdaterThread(
|
||||
head_node,
|
||||
config["provider"],
|
||||
provider,
|
||||
config["auth"],
|
||||
config["cluster_name"],
|
||||
config["file_mounts"], [],
|
||||
config["file_mounts"],
|
||||
[],
|
||||
"",
|
||||
redirect_output=False)
|
||||
)
|
||||
if down:
|
||||
rsync = updater.rsync_down
|
||||
else:
|
||||
rsync = updater.rsync_up
|
||||
rsync(source, target, check_error=False)
|
||||
|
||||
provider.cleanup()
|
||||
|
||||
|
||||
def get_head_node_ip(config_file, override_cluster_name):
|
||||
"""Returns head node IP for given configuration file if exists."""
|
||||
@@ -384,9 +420,13 @@ def get_head_node_ip(config_file, override_cluster_name):
|
||||
config = yaml.load(open(config_file).read())
|
||||
if override_cluster_name is not None:
|
||||
config["cluster_name"] = override_cluster_name
|
||||
|
||||
provider = get_node_provider(config["provider"], config["cluster_name"])
|
||||
head_node = _get_head_node(config, config_file, override_cluster_name)
|
||||
return provider.external_ip(head_node)
|
||||
ip = provider.external_ip(head_node)
|
||||
provider.cleanup()
|
||||
|
||||
return ip
|
||||
|
||||
|
||||
def get_worker_node_ips(config_file, override_cluster_name):
|
||||
@@ -409,6 +449,8 @@ def _get_head_node(config,
|
||||
TAG_RAY_NODE_TYPE: "head",
|
||||
}
|
||||
nodes = provider.nodes(head_node_tags)
|
||||
provider.cleanup()
|
||||
|
||||
if len(nodes) > 0:
|
||||
head_node = nodes[0]
|
||||
return head_node
|
||||
|
||||
@@ -2,8 +2,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import logging
|
||||
import os
|
||||
import logging
|
||||
try: # py3
|
||||
from shlex import quote
|
||||
except ImportError: # py2
|
||||
@@ -20,6 +20,7 @@ def dockerize_if_needed(config):
|
||||
if not docker_image:
|
||||
if cname:
|
||||
logger.warning(
|
||||
"dockerize_if_needed: "
|
||||
"Container name given but no Docker image - continuing...")
|
||||
return config
|
||||
else:
|
||||
|
||||
@@ -11,6 +11,8 @@ from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from googleapiclient import discovery, errors
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
crm = discovery.build("cloudresourcemanager", "v1")
|
||||
iam = discovery.build("iam", "v1")
|
||||
compute = discovery.build("compute", "v1")
|
||||
@@ -30,12 +32,11 @@ DEFAULT_SERVICE_ACCOUNT_ROLES = ("roles/storage.objectAdmin",
|
||||
MAX_POLLS = 12
|
||||
POLL_INTERVAL = 5
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def wait_for_crm_operation(operation):
|
||||
"""Poll for cloud resource manager operation until finished."""
|
||||
logger.info("Waiting for operation {} to finish...".format(operation))
|
||||
logger.info("wait_for_crm_operation: "
|
||||
"Waiting for operation {} to finish...".format(operation))
|
||||
|
||||
for _ in range(MAX_POLLS):
|
||||
result = crm.operations().get(name=operation["name"]).execute()
|
||||
@@ -43,7 +44,7 @@ def wait_for_crm_operation(operation):
|
||||
raise Exception(result["error"])
|
||||
|
||||
if "done" in result and result["done"]:
|
||||
logger.info("Done.")
|
||||
logger.info("wait_for_crm_operation: Operation done.")
|
||||
break
|
||||
|
||||
time.sleep(POLL_INTERVAL)
|
||||
@@ -53,8 +54,9 @@ def wait_for_crm_operation(operation):
|
||||
|
||||
def wait_for_compute_global_operation(project_name, operation):
|
||||
"""Poll for global compute operation until finished."""
|
||||
logger.info("Waiting for operation {} to finish...".format(
|
||||
operation["name"]))
|
||||
logger.info("wait_for_compute_global_operation: "
|
||||
"Waiting for operation {} to finish...".format(
|
||||
operation["name"]))
|
||||
|
||||
for _ in range(MAX_POLLS):
|
||||
result = compute.globalOperations().get(
|
||||
@@ -65,7 +67,8 @@ def wait_for_compute_global_operation(project_name, operation):
|
||||
raise Exception(result["error"])
|
||||
|
||||
if result["status"] == "DONE":
|
||||
logger.info("Done.")
|
||||
logger.info("wait_for_compute_global_operation: "
|
||||
"Operation done.")
|
||||
break
|
||||
|
||||
time.sleep(POLL_INTERVAL)
|
||||
@@ -158,8 +161,9 @@ def _configure_iam_role(config):
|
||||
service_account = _get_service_account(email, config)
|
||||
|
||||
if service_account is None:
|
||||
logger.info("Creating new service account {}".format(
|
||||
DEFAULT_SERVICE_ACCOUNT_ID))
|
||||
logger.info("_configure_iam_role: "
|
||||
"Creating new service account {}".format(
|
||||
DEFAULT_SERVICE_ACCOUNT_ID))
|
||||
|
||||
service_account = _create_service_account(
|
||||
DEFAULT_SERVICE_ACCOUNT_ID, DEFAULT_SERVICE_ACCOUNT_CONFIG, config)
|
||||
@@ -231,7 +235,8 @@ def _configure_key_pair(config):
|
||||
|
||||
# Create a key since it doesn't exist locally or in GCP
|
||||
if not key_found and not os.path.exists(private_key_path):
|
||||
logger.info("Creating new key pair {}".format(key_name))
|
||||
logger.info("_configure_key_pair: "
|
||||
"Creating new key pair {}".format(key_name))
|
||||
public_key, private_key = generate_rsa_key_pair()
|
||||
|
||||
_create_project_ssh_key_pair(project, public_key, ssh_user)
|
||||
@@ -256,8 +261,9 @@ def _configure_key_pair(config):
|
||||
"Private key file {} not found for user {}"
|
||||
"".format(private_key_path, ssh_user))
|
||||
|
||||
logger.info("Private key not specified in config, using {}"
|
||||
"".format(private_key_path))
|
||||
logger.info("_configure_key_pair: "
|
||||
"Private key not specified in config, using"
|
||||
"{}".format(private_key_path))
|
||||
|
||||
config["auth"]["ssh_private_key"] = private_key_path
|
||||
|
||||
|
||||
@@ -4,24 +4,25 @@ from __future__ import print_function
|
||||
|
||||
from uuid import uuid4
|
||||
import time
|
||||
|
||||
import logging
|
||||
|
||||
from googleapiclient import discovery
|
||||
|
||||
from ray.autoscaler.node_provider import NodeProvider
|
||||
from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME
|
||||
from ray.autoscaler.gcp.config import MAX_POLLS, POLL_INTERVAL
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
INSTANCE_NAME_MAX_LEN = 64
|
||||
INSTANCE_NAME_UUID_LEN = 8
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def wait_for_compute_zone_operation(compute, project_name, operation, zone):
|
||||
"""Poll for compute zone operation until finished."""
|
||||
logger.info("Waiting for operation {} to finish...".format(
|
||||
operation["name"]))
|
||||
logger.info("wait_for_compute_zone_operation: "
|
||||
"Waiting for operation {} to finish...".format(
|
||||
operation["name"]))
|
||||
|
||||
for _ in range(MAX_POLLS):
|
||||
result = compute.zoneOperations().get(
|
||||
@@ -31,7 +32,8 @@ def wait_for_compute_zone_operation(compute, project_name, operation, zone):
|
||||
raise Exception(result["error"])
|
||||
|
||||
if result["status"] == "DONE":
|
||||
logger.info("Done.")
|
||||
logger.info("wait_for_compute_zone_operation: "
|
||||
"Operation {} finished.".format(operation["name"]))
|
||||
break
|
||||
|
||||
time.sleep(POLL_INTERVAL)
|
||||
|
||||
@@ -12,6 +12,7 @@ from ray.autoscaler.node_provider import NodeProvider
|
||||
from ray.autoscaler.tags import TAG_RAY_NODE_TYPE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
filelock_logger = logging.getLogger("filelock")
|
||||
filelock_logger.setLevel(logging.WARNING)
|
||||
|
||||
@@ -26,7 +27,8 @@ class ClusterState(object):
|
||||
workers = json.loads(open(self.save_path).read())
|
||||
else:
|
||||
workers = {}
|
||||
logger.info("Loaded cluster state: {}".format(workers))
|
||||
logger.info("ClusterState: "
|
||||
"Loaded cluster state: {}".format(workers))
|
||||
for worker_ip in provider_config["worker_ips"]:
|
||||
if worker_ip not in workers:
|
||||
workers[worker_ip] = {
|
||||
@@ -50,7 +52,8 @@ class ClusterState(object):
|
||||
TAG_RAY_NODE_TYPE] == "head"
|
||||
assert len(workers) == len(provider_config["worker_ips"]) + 1
|
||||
with open(self.save_path, "w") as f:
|
||||
logger.info("Writing cluster state: {}".format(workers))
|
||||
logger.info("ClusterState: "
|
||||
"Writing cluster state: {}".format(workers))
|
||||
f.write(json.dumps(workers))
|
||||
|
||||
def get(self):
|
||||
@@ -65,7 +68,8 @@ class ClusterState(object):
|
||||
workers = self.get()
|
||||
workers[worker_id] = info
|
||||
with open(self.save_path, "w") as f:
|
||||
logger.info("Writing cluster state: {}".format(workers))
|
||||
logger.info("ClusterState: "
|
||||
"Writing cluster state: {}".format(workers))
|
||||
f.write(json.dumps(workers))
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LogTimer():
|
||||
def __init__(self, message):
|
||||
self._message = message
|
||||
|
||||
def __enter__(self):
|
||||
self._start_time = datetime.datetime.utcnow()
|
||||
|
||||
def __exit__(self, *_):
|
||||
td = datetime.datetime.utcnow() - self._start_time
|
||||
logger.info(self._message +
|
||||
" [LogTimer={:.0f}ms]".format(td.total_seconds() * 1000))
|
||||
@@ -3,9 +3,12 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import yaml
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def import_aws():
|
||||
from ray.autoscaler.aws.config import bootstrap_aws
|
||||
@@ -174,3 +177,14 @@ class NodeProvider(object):
|
||||
def terminate_node(self, node_id):
|
||||
"""Terminates the specified node."""
|
||||
raise NotImplementedError
|
||||
|
||||
def terminate_nodes(self, node_ids):
|
||||
"""Terminates a set of nodes. May be overridden with a batch method."""
|
||||
for node_id in node_ids:
|
||||
logger.info("NodeProvider: "
|
||||
"{}: Terminating node".format(node_id))
|
||||
self.terminate_node(node_id)
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean-up when a Provider is no longer required."""
|
||||
pass
|
||||
|
||||
+153
-116
@@ -10,24 +10,33 @@ import logging
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
from multiprocessing import Process
|
||||
from threading import Thread
|
||||
|
||||
from ray.autoscaler.node_provider import get_node_provider
|
||||
from ray.autoscaler.tags import TAG_RAY_NODE_STATUS, TAG_RAY_RUNTIME_CONFIG
|
||||
from ray.autoscaler.log_timer import LogTimer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# How long to wait for a node to start, in seconds
|
||||
NODE_START_WAIT_S = 300
|
||||
SSH_CHECK_INTERVAL = 5
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
SSH_CONTROL_PATH = "/tmp/ray_ssh_sockets"
|
||||
|
||||
|
||||
def pretty_cmd(cmd_str):
|
||||
return "\n\n\t{}\n\n".format(cmd_str)
|
||||
def get_default_ssh_options(private_key, connect_timeout):
|
||||
OPTS = [
|
||||
("ConnectTimeout", "{}s".format(connect_timeout)),
|
||||
("StrictHostKeyChecking", "no"),
|
||||
("ControlMaster", "auto"),
|
||||
("ControlPath", "{}/%C".format(SSH_CONTROL_PATH)),
|
||||
("ControlPersist", "yes"),
|
||||
]
|
||||
|
||||
return ["-i", private_key] + [
|
||||
x for y in (["-o", "{}={}".format(k, v)] for k, v in OPTS) for x in y
|
||||
]
|
||||
|
||||
|
||||
class NodeUpdater(object):
|
||||
@@ -36,12 +45,12 @@ class NodeUpdater(object):
|
||||
def __init__(self,
|
||||
node_id,
|
||||
provider_config,
|
||||
provider,
|
||||
auth_config,
|
||||
cluster_name,
|
||||
file_mounts,
|
||||
setup_cmds,
|
||||
runtime_hash,
|
||||
redirect_output=True,
|
||||
process_runner=subprocess,
|
||||
use_internal_ip=False):
|
||||
self.daemon = True
|
||||
@@ -49,31 +58,22 @@ class NodeUpdater(object):
|
||||
self.node_id = node_id
|
||||
self.use_internal_ip = (use_internal_ip or provider_config.get(
|
||||
"use_internal_ips", False))
|
||||
self.provider = get_node_provider(provider_config, cluster_name)
|
||||
self.provider = provider
|
||||
self.ssh_private_key = auth_config["ssh_private_key"]
|
||||
self.ssh_user = auth_config["ssh_user"]
|
||||
self.ssh_ip = self.get_node_ip()
|
||||
self.ssh_ip = None
|
||||
self.file_mounts = {
|
||||
remote: os.path.expanduser(local)
|
||||
for remote, local in file_mounts.items()
|
||||
}
|
||||
self.setup_cmds = setup_cmds
|
||||
self.runtime_hash = runtime_hash
|
||||
self.logger = logger.getChild(str(node_id))
|
||||
if redirect_output:
|
||||
self.logfile = tempfile.NamedTemporaryFile(
|
||||
mode="w", prefix="node-updater-", delete=False)
|
||||
handler = logging.StreamHandler(stream=self.logfile)
|
||||
handler.setLevel(logging.INFO)
|
||||
self.logger.addHandler(handler)
|
||||
self.output_name = self.logfile.name
|
||||
self.stdout = self.logfile
|
||||
self.stderr = self.logfile
|
||||
|
||||
def get_caller(self, check_error):
|
||||
if check_error:
|
||||
return self.process_runner.call
|
||||
else:
|
||||
self.logfile = None
|
||||
self.output_name = "(console)"
|
||||
self.stdout = sys.stdout
|
||||
self.stderr = sys.stderr
|
||||
return self.process_runner.check_call
|
||||
|
||||
def get_node_ip(self):
|
||||
if self.use_internal_ip:
|
||||
@@ -81,128 +81,173 @@ class NodeUpdater(object):
|
||||
else:
|
||||
return self.provider.external_ip(self.node_id)
|
||||
|
||||
def wait_for_ip(self, deadline):
|
||||
while time.time() < deadline and \
|
||||
not self.provider.is_terminated(self.node_id):
|
||||
logger.info("NodeUpdater: "
|
||||
"Waiting for IP of {}...".format(self.node_id))
|
||||
ip = self.get_node_ip()
|
||||
if ip is not None:
|
||||
return ip
|
||||
time.sleep(10)
|
||||
|
||||
return None
|
||||
|
||||
def set_ssh_ip_if_required(self):
|
||||
if self.ssh_ip is not None:
|
||||
return
|
||||
|
||||
# We assume that this never changes.
|
||||
# I think that's reasonable.
|
||||
deadline = time.time() + NODE_START_WAIT_S
|
||||
with LogTimer("NodeUpdater: {}: Got IP".format(self.node_id)):
|
||||
ip = self.wait_for_ip(deadline)
|
||||
assert ip is not None, "Unable to find IP of node"
|
||||
|
||||
self.ssh_ip = ip
|
||||
|
||||
# This should run before any SSH commands and therefore ensure that
|
||||
# the ControlPath directory exists, allowing SSH to maintain
|
||||
# persistent sessions later on.
|
||||
with open("/dev/null", "w") as redirect:
|
||||
self.get_caller(False)(
|
||||
["mkdir", "-p", SSH_CONTROL_PATH],
|
||||
stdout=redirect,
|
||||
stderr=redirect)
|
||||
|
||||
self.get_caller(False)(
|
||||
["chmod", "0700", SSH_CONTROL_PATH],
|
||||
stdout=redirect,
|
||||
stderr=redirect)
|
||||
|
||||
def run(self):
|
||||
self.logger.info(
|
||||
"NodeUpdater: Updating {} to {}, logging to {}".format(
|
||||
self.node_id, self.runtime_hash, self.output_name))
|
||||
logger.info("NodeUpdater: "
|
||||
"{}: Updating to {}".format(self.node_id,
|
||||
self.runtime_hash))
|
||||
try:
|
||||
self.do_update()
|
||||
m = "{}: Applied config {}".format(self.node_id, self.runtime_hash)
|
||||
with LogTimer("NodeUpdater: {}".format(m)):
|
||||
self.do_update()
|
||||
except Exception as e:
|
||||
error_str = str(e)
|
||||
if hasattr(e, "cmd"):
|
||||
error_str = "(Exit Status {}) {}".format(
|
||||
e.returncode, pretty_cmd(" ".join(e.cmd)))
|
||||
self.logger.error("NodeUpdater: Error updating {}"
|
||||
"See {} for remote logs.".format(
|
||||
error_str, self.output_name))
|
||||
e.returncode, " ".join(e.cmd))
|
||||
logger.error("NodeUpdater: "
|
||||
"{}: Error updating {}".format(
|
||||
self.node_id, error_str))
|
||||
self.provider.set_node_tags(self.node_id,
|
||||
{TAG_RAY_NODE_STATUS: "update-failed"})
|
||||
if self.logfile is not None:
|
||||
self.logger.info("----- BEGIN REMOTE LOGS -----\n" +
|
||||
open(self.logfile.name).read() +
|
||||
"\n----- END REMOTE LOGS -----")
|
||||
raise e
|
||||
|
||||
self.provider.set_node_tags(
|
||||
self.node_id, {
|
||||
TAG_RAY_NODE_STATUS: "up-to-date",
|
||||
TAG_RAY_RUNTIME_CONFIG: self.runtime_hash
|
||||
})
|
||||
self.logger.info("NodeUpdater: Applied config {} to node {}".format(
|
||||
self.runtime_hash, self.node_id))
|
||||
|
||||
def do_update(self):
|
||||
self.provider.set_node_tags(self.node_id,
|
||||
{TAG_RAY_NODE_STATUS: "waiting-for-ssh"})
|
||||
deadline = time.time() + NODE_START_WAIT_S
|
||||
self.exitcode = 0
|
||||
|
||||
# Wait for external IP
|
||||
while time.time() < deadline and \
|
||||
not self.provider.is_terminated(self.node_id):
|
||||
self.logger.info("NodeUpdater: Waiting for IP of {}...".format(
|
||||
self.node_id))
|
||||
self.ssh_ip = self.get_node_ip()
|
||||
if self.ssh_ip is not None:
|
||||
break
|
||||
time.sleep(10)
|
||||
assert self.ssh_ip is not None, "Unable to find IP of node"
|
||||
def wait_for_ssh(self, deadline):
|
||||
logger.info("NodeUpdater: "
|
||||
"{}: Waiting for SSH...".format(self.node_id))
|
||||
|
||||
# Wait for SSH access
|
||||
ssh_ok = False
|
||||
while time.time() < deadline and \
|
||||
not self.provider.is_terminated(self.node_id):
|
||||
try:
|
||||
self.logger.info(
|
||||
"NodeUpdater: Waiting for SSH to {}...".format(
|
||||
self.node_id))
|
||||
if not self.provider.is_running(self.node_id):
|
||||
raise Exception("Node not running yet...")
|
||||
logger.debug("NodeUpdater: "
|
||||
"{}: Waiting for SSH...".format(self.node_id))
|
||||
self.ssh_cmd(
|
||||
"uptime",
|
||||
connect_timeout=5,
|
||||
redirect=open("/dev/null", "w"))
|
||||
ssh_ok = True
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
retry_str = str(e)
|
||||
if hasattr(e, "cmd"):
|
||||
retry_str = "(Exit Status {}): {}".format(
|
||||
e.returncode, pretty_cmd(" ".join(e.cmd)))
|
||||
self.logger.debug(
|
||||
"NodeUpdater: SSH not up, retrying: {}".format(retry_str),
|
||||
)
|
||||
e.returncode, " ".join(e.cmd))
|
||||
logger.debug("NodeUpdater: "
|
||||
"{}: SSH not up, retrying: {}".format(
|
||||
self.node_id, retry_str))
|
||||
time.sleep(SSH_CHECK_INTERVAL)
|
||||
else:
|
||||
break
|
||||
assert ssh_ok, "Unable to SSH to node"
|
||||
|
||||
return False
|
||||
|
||||
def do_update(self):
|
||||
self.provider.set_node_tags(self.node_id,
|
||||
{TAG_RAY_NODE_STATUS: "waiting-for-ssh"})
|
||||
|
||||
deadline = time.time() + NODE_START_WAIT_S
|
||||
self.set_ssh_ip_if_required()
|
||||
|
||||
# Wait for SSH access
|
||||
with LogTimer("NodeUpdater: " "{}: Got SSH".format(self.node_id)):
|
||||
ssh_ok = self.wait_for_ssh(deadline)
|
||||
assert ssh_ok, "Unable to SSH to node"
|
||||
|
||||
# Rsync file mounts
|
||||
self.provider.set_node_tags(self.node_id,
|
||||
{TAG_RAY_NODE_STATUS: "syncing-files"})
|
||||
for remote_path, local_path in self.file_mounts.items():
|
||||
self.logger.info("NodeUpdater: Syncing {} to {}...".format(
|
||||
local_path, remote_path))
|
||||
logger.info("NodeUpdater: "
|
||||
"{}: Syncing {} to {}...".format(
|
||||
self.node_id, local_path, remote_path))
|
||||
assert os.path.exists(local_path), local_path
|
||||
if os.path.isdir(local_path):
|
||||
if not local_path.endswith("/"):
|
||||
local_path += "/"
|
||||
if not remote_path.endswith("/"):
|
||||
remote_path += "/"
|
||||
self.ssh_cmd("mkdir -p {}".format(os.path.dirname(remote_path)))
|
||||
self.rsync_up(local_path, remote_path)
|
||||
|
||||
m = "{}: Synced {} to {}".format(self.node_id, local_path,
|
||||
remote_path)
|
||||
with LogTimer("NodeUpdater {}".format(m)):
|
||||
self.ssh_cmd(
|
||||
"mkdir -p {}".format(os.path.dirname(remote_path)),
|
||||
redirect=open("/dev/null", "w"),
|
||||
)
|
||||
self.rsync_up(
|
||||
local_path, remote_path, redirect=open("/dev/null", "w"))
|
||||
|
||||
# Run init commands
|
||||
self.provider.set_node_tags(self.node_id,
|
||||
{TAG_RAY_NODE_STATUS: "setting-up"})
|
||||
for cmd in self.setup_cmds:
|
||||
self.ssh_cmd(cmd, verbose=True)
|
||||
|
||||
def rsync_up(self, source, target, check_error=True):
|
||||
if check_error:
|
||||
call = self.process_runner.call
|
||||
else:
|
||||
call = self.process_runner.check_call
|
||||
call(
|
||||
m = "{}: Setup commands completed".format(self.node_id)
|
||||
with LogTimer("NodeUpdater: {}".format(m)):
|
||||
for cmd in self.setup_cmds:
|
||||
self.ssh_cmd(
|
||||
cmd,
|
||||
# verbose=True,
|
||||
redirect=open("/dev/null", "w"))
|
||||
|
||||
def rsync_up(self, source, target, redirect=None, check_error=True):
|
||||
self.set_ssh_ip_if_required()
|
||||
self.get_caller(check_error)(
|
||||
[
|
||||
"rsync", "-e", "ssh -i {} ".format(self.ssh_private_key) +
|
||||
"-o ConnectTimeout=120s -o StrictHostKeyChecking=no",
|
||||
"rsync", "-e",
|
||||
" ".join(["ssh"] +
|
||||
get_default_ssh_options(self.ssh_private_key, 120)),
|
||||
"--delete", "-avz", source, "{}@{}:{}".format(
|
||||
self.ssh_user, self.ssh_ip, target)
|
||||
],
|
||||
stdout=self.stdout,
|
||||
stderr=self.stderr)
|
||||
stdout=redirect or sys.stdout,
|
||||
stderr=redirect or sys.stderr)
|
||||
|
||||
def rsync_down(self, source, target, check_error=True):
|
||||
if check_error:
|
||||
call = self.process_runner.call
|
||||
else:
|
||||
call = self.process_runner.check_call
|
||||
call(
|
||||
def rsync_down(self, source, target, redirect=None, check_error=True):
|
||||
self.set_ssh_ip_if_required()
|
||||
self.get_caller(check_error)(
|
||||
[
|
||||
"rsync", "-e", "ssh -i {} ".format(self.ssh_private_key) +
|
||||
"-o ConnectTimeout=120s -o StrictHostKeyChecking=no", "-avz",
|
||||
"{}@{}:{}".format(self.ssh_user, self.ssh_ip, source), target
|
||||
"rsync", "-e",
|
||||
" ".join(["ssh"] +
|
||||
get_default_ssh_options(self.ssh_private_key, 120)),
|
||||
"-avz", "{}@{}:{}".format(self.ssh_user, self.ssh_ip,
|
||||
source), target
|
||||
],
|
||||
stdout=self.stdout,
|
||||
stderr=self.stderr)
|
||||
stdout=redirect or sys.stdout,
|
||||
stderr=redirect or sys.stderr)
|
||||
|
||||
def ssh_cmd(self,
|
||||
cmd,
|
||||
@@ -213,9 +258,12 @@ class NodeUpdater(object):
|
||||
emulate_interactive=True,
|
||||
expect_error=False,
|
||||
port_forward=None):
|
||||
|
||||
self.set_ssh_ip_if_required()
|
||||
|
||||
if verbose:
|
||||
self.logger.info("NodeUpdater: running {} on {}...".format(
|
||||
pretty_cmd(cmd), self.ssh_ip))
|
||||
logger.info("NodeUpdater: "
|
||||
"Running {} on {}...".format(cmd, self.ssh_ip))
|
||||
ssh = ["ssh"]
|
||||
if allocate_tty:
|
||||
ssh.append("-tt")
|
||||
@@ -224,35 +272,24 @@ class NodeUpdater(object):
|
||||
"set -i || true && source ~/.bashrc && "
|
||||
"export OMP_NUM_THREADS=1 PYTHONWARNINGS=ignore && ")
|
||||
cmd = "bash --login -c {}".format(quote(force_interactive + cmd))
|
||||
if expect_error:
|
||||
call = self.process_runner.call
|
||||
else:
|
||||
call = self.process_runner.check_call
|
||||
|
||||
if port_forward is None:
|
||||
ssh_opt = []
|
||||
else:
|
||||
ssh_opt = [
|
||||
"-L", "{}:localhost:{}".format(port_forward, port_forward)
|
||||
]
|
||||
call(
|
||||
ssh + ssh_opt + [
|
||||
"-o", "ConnectTimeout={}s".format(connect_timeout), "-o",
|
||||
"StrictHostKeyChecking=no", "-i", self.ssh_private_key,
|
||||
"{}@{}".format(self.ssh_user, self.ssh_ip), cmd
|
||||
],
|
||||
stdout=redirect or self.stdout,
|
||||
stderr=redirect or self.stderr)
|
||||
|
||||
self.get_caller(expect_error)(
|
||||
ssh + ssh_opt + get_default_ssh_options(self.ssh_private_key,
|
||||
connect_timeout) +
|
||||
["{}@{}".format(self.ssh_user, self.ssh_ip), cmd],
|
||||
stdout=redirect or sys.stdout,
|
||||
stderr=redirect or sys.stderr)
|
||||
|
||||
|
||||
class NodeUpdaterProcess(NodeUpdater, Process):
|
||||
def __init__(self, *args, **kwargs):
|
||||
Process.__init__(self)
|
||||
NodeUpdater.__init__(self, *args, **kwargs)
|
||||
|
||||
|
||||
# Single-threaded version for unit tests
|
||||
class NodeUpdaterThread(NodeUpdater, Thread):
|
||||
def __init__(self, *args, **kwargs):
|
||||
Thread.__init__(self)
|
||||
NodeUpdater.__init__(self, *args, **kwargs)
|
||||
self.exitcode = 0
|
||||
self.exitcode = -1
|
||||
|
||||
+19
-12
@@ -20,7 +20,6 @@ from ray.services import get_ip_address, get_port
|
||||
from ray.utils import (binary_to_hex, binary_to_object_id, hex_to_binary,
|
||||
setup_logger)
|
||||
|
||||
# Set up logging.
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -69,8 +68,10 @@ class Monitor(object):
|
||||
# that redis server.
|
||||
addr_port = self.redis.lrange("RedisShards", 0, -1)
|
||||
if len(addr_port) > 1:
|
||||
logger.warning("TODO: if launching > 1 redis shard, flushing "
|
||||
"needs to touch shards in parallel.")
|
||||
logger.warning(
|
||||
"Monitor: "
|
||||
"TODO: if launching > 1 redis shard, flushing needs to "
|
||||
"touch shards in parallel.")
|
||||
self.issue_gcs_flushes = False
|
||||
else:
|
||||
addr_port = addr_port[0].split(b":")
|
||||
@@ -82,6 +83,7 @@ class Monitor(object):
|
||||
self.redis_shard.execute_command("HEAD.FLUSH 0")
|
||||
except redis.exceptions.ResponseError as e:
|
||||
logger.info(
|
||||
"Monitor: "
|
||||
"Turning off flushing due to exception: {}".format(
|
||||
str(e)))
|
||||
self.issue_gcs_flushes = False
|
||||
@@ -128,8 +130,9 @@ class Monitor(object):
|
||||
self.load_metrics.update(ip, static_resources,
|
||||
dynamic_resources)
|
||||
else:
|
||||
print("Warning: could not find ip for client {} in {}.".format(
|
||||
client_id, self.local_scheduler_id_to_ip_map))
|
||||
logger.warning(
|
||||
"Monitor: "
|
||||
"could not find ip for client {}".format(client_id))
|
||||
|
||||
def _xray_clean_up_entries_for_driver(self, driver_id):
|
||||
"""Remove this driver's object/task entries from redis.
|
||||
@@ -185,11 +188,14 @@ class Monitor(object):
|
||||
continue
|
||||
redis = self.state.redis_clients[shard_index]
|
||||
num_deleted = redis.delete(*keys)
|
||||
logger.info("Removed {} dead redis entries of the driver from"
|
||||
" redis shard {}.".format(num_deleted, shard_index))
|
||||
logger.info("Monitor: "
|
||||
"Removed {} dead redis entries of the "
|
||||
"driver from redis shard {}.".format(
|
||||
num_deleted, shard_index))
|
||||
if num_deleted != len(keys):
|
||||
logger.warning("Failed to remove {} relevant redis entries"
|
||||
" from redis shard {}.".format(
|
||||
logger.warning("Monitor: "
|
||||
"Failed to remove {} relevant redis "
|
||||
"entries from redis shard {}.".format(
|
||||
len(keys) - num_deleted, shard_index))
|
||||
|
||||
def xray_driver_removed_handler(self, unused_channel, data):
|
||||
@@ -205,8 +211,9 @@ class Monitor(object):
|
||||
message = ray.gcs_utils.DriverTableData.GetRootAsDriverTableData(
|
||||
driver_data, 0)
|
||||
driver_id = message.DriverId()
|
||||
logger.info("XRay Driver {} has been removed.".format(
|
||||
binary_to_hex(driver_id)))
|
||||
logger.info("Monitor: "
|
||||
"XRay Driver {} has been removed.".format(
|
||||
binary_to_hex(driver_id)))
|
||||
self._xray_clean_up_entries_for_driver(driver_id)
|
||||
|
||||
def process_messages(self, max_messages=10000):
|
||||
@@ -281,7 +288,7 @@ class Monitor(object):
|
||||
max_entries_to_flush = self.gcs_flush_policy.num_entries_to_flush()
|
||||
num_flushed = self.redis_shard.execute_command(
|
||||
"HEAD.FLUSH {}".format(max_entries_to_flush))
|
||||
logger.info("num_flushed {}".format(num_flushed))
|
||||
logger.info("Monitor: num_flushed {}".format(num_flushed))
|
||||
|
||||
# This flushes event log and log files.
|
||||
ray.experimental.flush_redis_unsafe(self.redis)
|
||||
|
||||
Reference in New Issue
Block a user