Auto-scale ray clusters based on GCS load metrics (#1348)

This adds (experimental) auto-scaling support for Ray clusters based on GCS load metrics. The auto-scaling algorithm is as follows:

Based on current (instantaneous) load information, we compute the approximate number of "used workers". This is based on the bottleneck resource, e.g. if 8/8 GPUs are used in an 8-node cluster but all the CPUs are idle, the number of used workers is still counted as 8. This number can also be fractional.
We scale that number by 1 / target_utilization_fraction and round up to determine the target cluster size (subject to the max_workers constraint). The autoscaler control loop takes care of launching new nodes until the target cluster size is met.
When a node is idle for more than idle_timeout_minutes, we remove it from the cluster if that would not drop the cluster size below min_workers.
Note that we'll need to update the wheel in the example yaml file after this PR is merged.
This commit is contained in:
Eric Liang
2017-12-31 14:39:57 -08:00
committed by GitHub
parent e970e24ea5
commit b6c42f96be
12 changed files with 657 additions and 176 deletions
+248 -55
View File
@@ -6,12 +6,18 @@ import json
import hashlib
import os
import subprocess
import time
import traceback
from collections import defaultdict
from datetime import datetime
import numpy as np
import yaml
from ray.ray_constants import AUTOSCALER_MAX_NUM_FAILURES, \
AUTOSCALER_MAX_CONCURRENT_LAUNCHES, AUTOSCALER_UPDATE_INTERVAL_S, \
AUTOSCALER_HEARTBEAT_TIMEOUT_S
from ray.autoscaler.node_provider import get_node_provider
from ray.autoscaler.updater import NodeUpdaterProcess
from ray.autoscaler.tags import TAG_RAY_LAUNCH_CONFIG, \
@@ -31,6 +37,14 @@ CLUSTER_CONFIG_SCHEMA = {
# node. This takes precedence over min_workers.
"max_workers": int,
# The autoscaler will scale up the cluster to this target fraction of
# resource usage. For example, if a cluster of 8 nodes is 100% busy
# and target_utilization_fraction is 0.8, it would resize the cluster to 10.
"target_utilization_fraction": float,
# If a node is idle for this many minutes, it will be removed.
"idle_timeout_minutes": int,
# Cloud-provider specific configuration.
"provider": {
"type": str, # e.g. aws
@@ -49,20 +63,114 @@ CLUSTER_CONFIG_SCHEMA = {
# Map of remote paths to local paths, e.g. {"/tmp/data": "/my/local/data"}
"file_mounts": dict,
# List of shell commands to run to initialize the head node.
"head_init_commands": list,
# List of common shell commands to run to initialize nodes.
"setup_commands": list,
# List of shell commands to run to initialize workers.
"worker_init_commands": list,
# Commands that will be run on the head node after common setup.
"head_setup_commands": list,
# Commands that will be run on worker nodes after common setup.
"worker_setup_commands": list,
# Command to start ray on the head node. You shouldn't need to modify this.
"head_start_ray_commands": list,
# Command to start ray on worker nodes. You shouldn't need to modify this.
"worker_start_ray_commands": list,
# Whether to avoid restarting the cluster during updates. This field is
# controlled by the ray --no-restart flag and cannot be set by the user.
"no_restart": None,
}
# Abort autoscaling if more than this number of errors are encountered. This
# is a safety feature to prevent e.g. runaway node launches.
MAX_NUM_FAILURES = 5
class LoadMetrics(object):
"""Container for cluster load metrics.
# Max number of nodes to launch at a time.
MAX_CONCURRENT_LAUNCHES = 10
Metrics here are updated from local scheduler heartbeats. The autoscaler
queries these metrics to determine when to scale up, and which nodes
can be removed.
"""
def __init__(self):
    """Create empty per-IP metric tables and record this node's own IP."""
    # IP of the node running this code; prune_active_ips() always keeps it.
    self.local_ip = services.get_node_ip_address()
    # ip -> resource dicts exactly as last passed to update().
    self.static_resources_by_ip = {}
    self.dynamic_resources_by_ip = {}
    # ip -> time.time() of the last update() that either was the first for
    # that ip or reported static != dynamic resources.
    self.last_used_time_by_ip = {}
    # ip -> time.time() of the latest update() or mark_active() call.
    self.last_heartbeat_time_by_ip = {}
def update(self, ip, static_resources, dynamic_resources):
self.static_resources_by_ip[ip] = static_resources
self.dynamic_resources_by_ip[ip] = dynamic_resources
now = time.time()
if ip not in self.last_used_time_by_ip or \
static_resources != dynamic_resources:
self.last_used_time_by_ip[ip] = now
self.last_heartbeat_time_by_ip[ip] = now
def mark_active(self, ip):
self.last_heartbeat_time_by_ip[ip] = time.time()
def prune_active_ips(self, active_ips):
active_ips = set(active_ips)
active_ips.add(self.local_ip)
def prune(mapping):
unwanted = set(mapping) - active_ips
for unwanted_key in unwanted:
del mapping[unwanted_key]
if unwanted:
print(
"Removed {} stale ip mappings: {} not in {}".format(
len(unwanted), unwanted, active_ips))
prune(self.last_used_time_by_ip)
prune(self.static_resources_by_ip)
prune(self.dynamic_resources_by_ip)
def approx_workers_used(self):
return self._info()["NumNodesUsed"]
def debug_string(self):
return " - {}".format(
"\n - ".join(
["{}: {}".format(k, v)
for k, v in sorted(self._info().items())]))
def _info(self):
nodes_used = 0.0
resources_used = {}
resources_total = {}
now = time.time()
for ip, max_resources in self.static_resources_by_ip.items():
avail_resources = self.dynamic_resources_by_ip[ip]
max_frac = 0.0
for resource_id, amount in max_resources.items():
used = amount - avail_resources[resource_id]
if resource_id not in resources_used:
resources_used[resource_id] = 0.0
resources_total[resource_id] = 0.0
resources_used[resource_id] += used
resources_total[resource_id] += amount
assert used >= 0
if amount > 0:
frac = used / float(amount)
if frac > max_frac:
max_frac = frac
nodes_used += max_frac
idle_times = [now - t for t in self.last_used_time_by_ip.values()]
return {
"ResourceUsage": ", ".join([
"{}/{} {}".format(
round(resources_used[rid], 2),
round(resources_total[rid], 2), rid)
for rid in sorted(resources_used)]),
"NumNodesConnected": len(self.static_resources_by_ip),
"NumNodesUsed": round(nodes_used, 2),
"NodeIdleSeconds": "Min={} Mean={} Max={}".format(
int(np.min(idle_times)) if idle_times else -1,
int(np.mean(idle_times)) if idle_times else -1,
int(np.max(idle_times)) if idle_times else -1),
}
class StandardAutoscaler(object):
@@ -84,12 +192,15 @@ class StandardAutoscaler(object):
"""
def __init__(
self, config_path,
max_concurrent_launches=MAX_CONCURRENT_LAUNCHES,
max_failures=MAX_NUM_FAILURES, process_runner=subprocess,
verbose_updates=False, node_updater_cls=NodeUpdaterProcess):
self, config_path, load_metrics,
max_concurrent_launches=AUTOSCALER_MAX_CONCURRENT_LAUNCHES,
max_failures=AUTOSCALER_MAX_NUM_FAILURES,
process_runner=subprocess, verbose_updates=False,
node_updater_cls=NodeUpdaterProcess,
update_interval_s=AUTOSCALER_UPDATE_INTERVAL_S):
self.config_path = config_path
self.reload_config(errors_fatal=True)
self.load_metrics = load_metrics
self.provider = get_node_provider(
self.config["provider"], self.config["cluster_name"])
@@ -102,7 +213,10 @@ class StandardAutoscaler(object):
# Map from node_id to NodeUpdater processes
self.updaters = {}
self.num_failed_updates = defaultdict(int)
self.num_successful_updates = defaultdict(int)
self.num_failures = 0
self.last_update_time = 0.0
self.update_interval_s = update_interval_s
for local_path in self.config["file_mounts"].values():
assert os.path.exists(local_path)
@@ -123,43 +237,59 @@ class StandardAutoscaler(object):
raise e
def _update(self):
nodes = self.workers()
target_num_workers = self.config["max_workers"]
# Throttle autoscaling updates to this interval to avoid exceeding
# rate limits on API calls.
if time.time() - self.last_update_time < self.update_interval_s:
return
# Terminate nodes while there are too many
while len(nodes) > target_num_workers:
self.last_update_time = time.time()
nodes = self.workers()
print(self.debug_string(nodes))
self.load_metrics.prune_active_ips(
[self.provider.internal_ip(node_id) for node_id in nodes])
# Terminate any idle or out of date nodes
last_used = self.load_metrics.last_used_time_by_ip
horizon = time.time() - (60 * self.config["idle_timeout_minutes"])
num_terminated = 0
for node_id in nodes:
node_ip = self.provider.internal_ip(node_id)
if node_ip in last_used and last_used[node_ip] < horizon and \
len(nodes) - num_terminated > self.config["min_workers"]:
num_terminated += 1
print(
"StandardAutoscaler: Terminating idle node: "
"{}".format(node_id))
self.provider.terminate_node(node_id)
elif not self.launch_config_ok(node_id):
num_terminated += 1
print(
"StandardAutoscaler: Terminating outdated node: "
"{}".format(node_id))
self.provider.terminate_node(node_id)
if num_terminated > 0:
nodes = self.workers()
print(self.debug_string(nodes))
# Terminate nodes if there are too many
num_terminated = 0
while len(nodes) > self.config["max_workers"]:
num_terminated += 1
print(
"StandardAutoscaler: Terminating unneeded node: "
"{}".format(nodes[-1]))
self.provider.terminate_node(nodes[-1])
nodes = nodes[:-1]
if num_terminated > 0:
nodes = self.workers()
print(self.debug_string())
print(self.debug_string(nodes))
if target_num_workers == 0:
return
# Update nodes with out-of-date files
for node_id in nodes:
self.update_if_needed(node_id)
# Launch a new node if needed
if len(nodes) < target_num_workers:
# Launch new nodes if needed
target_num = self.target_num_workers()
if len(nodes) < target_num:
self.launch_new_node(
min(
self.max_concurrent_launches,
target_num_workers - len(nodes)))
min(self.max_concurrent_launches, target_num - len(nodes)))
print(self.debug_string())
return
else:
# If enough nodes, terminate an out-of-date node.
for node_id in nodes:
if not self.launch_config_ok(node_id):
print(
"StandardAutoscaler: Terminating outdated node: "
"{}".format(node_id))
self.provider.terminate_node(node_id)
print(self.debug_string())
return
# Process any completed updates
completed = []
@@ -168,10 +298,24 @@ class StandardAutoscaler(object):
completed.append(node_id)
if completed:
for node_id in completed:
if self.updaters[node_id].exitcode != 0:
if self.updaters[node_id].exitcode == 0:
self.num_successful_updates[node_id] += 1
else:
self.num_failed_updates[node_id] += 1
del self.updaters[node_id]
print(self.debug_string())
# Mark the node as active to prevent the node recovery logic
# immediately trying to restart Ray on the new node.
self.load_metrics.mark_active(self.provider.internal_ip(node_id))
nodes = self.workers()
print(self.debug_string(nodes))
# Update nodes with out-of-date files
for node_id in nodes:
self.update_if_needed(node_id)
# Attempt to recover unhealthy nodes
for node_id in nodes:
self.recover_if_needed(node_id)
def reload_config(self, errors_fatal=False):
try:
@@ -181,7 +325,10 @@ class StandardAutoscaler(object):
new_launch_hash = hash_launch_conf(
new_config["worker_nodes"], new_config["auth"])
new_runtime_hash = hash_runtime_conf(
new_config["file_mounts"], new_config["worker_init_commands"])
new_config["file_mounts"],
[new_config["setup_commands"],
new_config["worker_setup_commands"],
new_config["worker_start_ray_commands"]])
self.config = new_config
self.launch_hash = new_launch_hash
self.runtime_hash = new_runtime_hash
@@ -193,6 +340,14 @@ class StandardAutoscaler(object):
"StandardAutoscaler: Error parsing config: {}",
traceback.format_exc())
def target_num_workers(self):
target_frac = self.config["target_utilization_fraction"]
cur_used = self.load_metrics.approx_workers_used()
ideal_num_workers = int(np.ceil(cur_used / float(target_frac)))
return min(
self.config["max_workers"],
max(self.config["min_workers"], ideal_num_workers))
def launch_config_ok(self, node_id):
launch_conf = self.provider.node_tags(node_id).get(
TAG_RAY_LAUNCH_CONFIG)
@@ -209,30 +364,66 @@ class StandardAutoscaler(object):
return False
return True
def recover_if_needed(self, node_id):
if not self.can_update(node_id):
return
last_heartbeat_time = self.load_metrics.last_heartbeat_time_by_ip.get(
self.provider.internal_ip(node_id), 0)
if time.time() - last_heartbeat_time < AUTOSCALER_HEARTBEAT_TIMEOUT_S:
return
print("StandardAutoscaler: Restarting Ray on {}".format(node_id))
updater = self.node_updater_cls(
node_id,
self.config["provider"],
self.config["auth"],
self.config["cluster_name"],
{},
with_head_node_ip(self.config["worker_start_ray_commands"]),
self.runtime_hash,
redirect_output=not self.verbose_updates,
process_runner=self.process_runner)
updater.start()
self.updaters[node_id] = updater
def update_if_needed(self, node_id):
if not self.provider.is_running(node_id):
return
if not self.launch_config_ok(node_id):
return
if node_id in self.updaters:
return
if self.num_failed_updates.get(node_id, 0) > 0: # TODO(ekl) retry?
if not self.can_update(node_id):
return
if self.files_up_to_date(node_id):
return
if self.config.get("no_restart", False) and \
self.num_successful_updates.get(node_id, 0) > 0:
init_commands = (
self.config["setup_commands"] +
self.config["worker_setup_commands"])
else:
init_commands = (
self.config["setup_commands"] +
self.config["worker_setup_commands"] +
self.config["worker_start_ray_commands"])
updater = self.node_updater_cls(
node_id,
self.config["provider"],
self.config["auth"],
self.config["cluster_name"],
self.config["file_mounts"],
with_head_node_ip(self.config["worker_init_commands"]),
with_head_node_ip(init_commands),
self.runtime_hash,
redirect_output=not self.verbose_updates,
process_runner=self.process_runner)
updater.start()
self.updaters[node_id] = updater
def can_update(self, node_id):
if not self.provider.is_running(node_id):
return False
if not self.launch_config_ok(node_id):
return False
if node_id in self.updaters:
return False
if self.num_failed_updates.get(node_id, 0) > 0: # TODO(ekl) retry?
return False
return True
def launch_new_node(self, count):
print("StandardAutoscaler: Launching {} new nodes".format(count))
num_before = len(self.workers())
@@ -257,21 +448,23 @@ class StandardAutoscaler(object):
def debug_string(self, nodes=None):
if nodes is None:
nodes = self.workers()
target_num_workers = self.config["max_workers"]
suffix = ""
if self.updaters:
suffix += " ({} updating)".format(len(self.updaters))
if self.num_failed_updates:
suffix += " ({} failed to update)".format(
len(self.num_failed_updates))
return "StandardAutoscaler: Have {} / {} target nodes{}".format(
len(nodes), target_num_workers, suffix)
return "StandardAutoscaler [{}]: {}/{} target nodes{}\n{}".format(
datetime.now(), len(nodes), self.target_num_workers(),
suffix, self.load_metrics.debug_string())
def validate_config(config, schema=CLUSTER_CONFIG_SCHEMA):
if type(config) is not dict:
raise ValueError("Config is not a dictionary")
for k, v in schema.items():
if v is None:
continue # None means we don't validate the field
if k not in config:
raise ValueError(
"Missing required config key `{}` of type {}".format(
@@ -9,6 +9,15 @@ min_workers: 2
# node. This takes precedence over min_workers.
max_workers: 2
# The autoscaler will scale up the cluster to this target fraction of resource
# usage. For example, if a cluster of 10 nodes is 100% busy and
# target_utilization is 0.8, it would resize the cluster to 13. This fraction
# can be decreased to increase the aggressiveness of upscaling.
target_utilization_fraction: 0.8
# If a node is idle for this many minutes, it will be removed.
idle_timeout_minutes: 5
# Cloud-provider specific configuration.
provider:
type: aws
@@ -56,37 +65,32 @@ file_mounts: {
# "/path2/on/remote/machine": "/path2/on/local/machine",
}
# List of shell commands to run to initialize the head node.
head_init_commands:
# List of shell commands to run to set up nodes.
setup_commands:
# Install basics.
- sudo apt-get update
- sudo apt-get install -y cmake pkg-config build-essential autoconf curl libtool unzip python
# Install Anaconda.
- wget https://repo.continuum.io/archive/Anaconda3-5.0.1-Linux-x86_64.sh
- bash Anaconda3-5.0.1-Linux-x86_64.sh -b -p $HOME/anaconda3
- wget https://repo.continuum.io/archive/Anaconda3-5.0.1-Linux-x86_64.sh || true
- bash Anaconda3-5.0.1-Linux-x86_64.sh -b -p $HOME/anaconda3 || true
- echo 'export PATH="$HOME/anaconda3/bin:$PATH"' >> ~/.bashrc
# Build Ray.
- git clone https://github.com/ray-project/ray
- PATH=/home/ubuntu/anaconda3/bin:$PATH pip install -U cloudpickle boto3==1.4.8
- cd ray/python; PATH=/home/ubuntu/anaconda3/bin:$PATH python setup.py develop
# Start Ray.
- PATH=/home/ubuntu/anaconda3/bin:$PATH ray stop
- PATH=/home/ubuntu/anaconda3/bin:$PATH ray start --head --redis-port=6379 --autoscaling-config=~/ray_bootstrap_config.yaml
- git clone https://github.com/ray-project/ray || true
- pip install -U cloudpickle boto3==1.4.8
- cd ray/python; python setup.py develop
# List of shell commands to run to initialize workers.
worker_init_commands:
# Install basics.
- sudo apt-get update
- sudo apt-get install -y cmake pkg-config build-essential autoconf curl libtool unzip python
# Install Anaconda.
- sudo apt-get update
- wget https://repo.continuum.io/archive/Anaconda3-5.0.1-Linux-x86_64.sh
- bash Anaconda3-5.0.1-Linux-x86_64.sh -b -p $HOME/anaconda3
- echo 'export PATH="$HOME/anaconda3/bin:$PATH"' >> ~/.bashrc
# Build Ray.
- git clone https://github.com/ray-project/ray
- PATH=/home/ubuntu/anaconda3/bin:$PATH pip install -U cloudpickle boto3==1.4.8
- cd ray/python; PATH=/home/ubuntu/anaconda3/bin:$PATH python setup.py develop
# Start Ray.
- PATH=/home/ubuntu/anaconda3/bin:$PATH ray stop
- PATH=/home/ubuntu/anaconda3/bin:$PATH ray start --head --redis-address=$RAY_HEAD_IP:6379
# Custom commands that will be run on the head node after common setup.
head_setup_commands: []
# Custom commands that will be run on worker nodes after common setup.
worker_setup_commands: []
# Command to start ray on the head node. You don't need to change this.
head_start_ray_commands:
- ray stop
- ray start --head --redis-port=6379 --autoscaling-config=~/ray_bootstrap_config.yaml
# Command to start ray on worker nodes. You don't need to change this.
worker_start_ray_commands:
- ray stop
- ray start --redis-address=$RAY_HEAD_IP:6379
+30 -12
View File
@@ -7,7 +7,16 @@ min_workers: 0
# The maximum number of worker nodes to launch in addition to the head
# node. This takes precedence over min_workers.
max_workers: 2
max_workers: 4
# The autoscaler will scale up the cluster to this target fraction of resource
# usage. For example, if a cluster of 10 nodes is 100% busy and
# target_utilization is 0.8, it would resize the cluster to 13. This fraction
# can be decreased to increase the aggressiveness of upscaling.
target_utilization_fraction: 0.8
# If a node is idle for this many minutes, it will be removed.
idle_timeout_minutes: 5
# Cloud-provider specific configuration.
provider:
@@ -56,18 +65,27 @@ file_mounts: {
# "/path2/on/remote/machine": "/path2/on/local/machine",
}
# List of shell commands to run to initialize the head node.
head_init_commands:
# List of shell commands to run to set up nodes.
setup_commands:
# Note: if you're developing Ray, you probably want to create an AMI that
# has your Ray repo pre-cloned. Then, you can replace the pip installs
# below with a git checkout <your_sha> (and possibly a recompile).
- PATH=/home/ubuntu/anaconda3/bin:$PATH pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/f5ea44338eca392df3a868035df3901829cc2ca1/ray-0.3.0-cp36-cp36m-manylinux1_x86_64.whl
- PATH=/home/ubuntu/anaconda3/bin:$PATH pip install boto3==1.4.8 # 1.4.8 adds InstanceMarketOptions
- PATH=/home/ubuntu/anaconda3/bin:$PATH ray stop
- PATH=/home/ubuntu/anaconda3/bin:$PATH ray start --head --redis-port=6379 --autoscaling-config=~/ray_bootstrap_config.yaml
# TODO(ekl) update this to a wheel from master
- pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/f5ea44338eca392df3a868035df3901829cc2ca1/ray-0.3.0-cp36-cp36m-manylinux1_x86_64.whl
# List of shell commands to run to initialize workers.
worker_init_commands:
- PATH=/home/ubuntu/anaconda3/bin:$PATH pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/f5ea44338eca392df3a868035df3901829cc2ca1/ray-0.3.0-cp36-cp36m-manylinux1_x86_64.whl
- PATH=/home/ubuntu/anaconda3/bin:$PATH ray stop
- PATH=/home/ubuntu/anaconda3/bin:$PATH ray start --redis-address=$RAY_HEAD_IP:6379
# Custom commands that will be run on the head node after common setup.
head_setup_commands:
- pip install boto3==1.4.8 # 1.4.8 adds InstanceMarketOptions
# Custom commands that will be run on worker nodes after common setup.
worker_setup_commands: []
# Command to start ray on the head node. You don't need to change this.
head_start_ray_commands:
- ray stop
- ray start --head --redis-port=6379 --autoscaling-config=~/ray_bootstrap_config.yaml
# Command to start ray on worker nodes. You don't need to change this.
worker_start_ray_commands:
- ray stop
- ray start --redis-address=$RAY_HEAD_IP:6379
+26 -1
View File
@@ -13,6 +13,14 @@ class AWSNodeProvider(NodeProvider):
NodeProvider.__init__(self, provider_config, cluster_name)
self.ec2 = boto3.resource("ec2", region_name=provider_config["region"])
# Cache of node objects from the last nodes() call. This avoids
# excessive DescribeInstances requests.
self.cached_nodes = {}
# Cache of ip lookups. We assume IPs never change once assigned.
self.internal_ip_cache = {}
self.external_ip_cache = {}
def nodes(self, tag_filters):
filters = [
{
@@ -30,6 +38,7 @@ class AWSNodeProvider(NodeProvider):
"Values": [v],
})
instances = list(self.ec2.instances.filter(Filters=filters))
self.cached_nodes = {i.id: i for i in instances}
return [i.id for i in instances]
def is_running(self, node_id):
@@ -49,8 +58,22 @@ class AWSNodeProvider(NodeProvider):
return tags
def external_ip(self, node_id):
if node_id in self.external_ip_cache:
return self.external_ip_cache[node_id]
node = self._node(node_id)
return node.public_ip_address
ip = node.public_ip_address
if ip:
self.external_ip_cache[node_id] = ip
return ip
def internal_ip(self, node_id):
if node_id in self.internal_ip_cache:
return self.internal_ip_cache[node_id]
node = self._node(node_id)
ip = node.private_ip_address
if ip:
self.internal_ip_cache[node_id] = ip
return ip
def set_node_tags(self, node_id, tags):
node = self._node(node_id)
@@ -90,6 +113,8 @@ class AWSNodeProvider(NodeProvider):
node.terminate()
def _node(self, node_id):
if node_id in self.cached_nodes:
return self.cached_nodes[node_id]
matches = list(self.ec2.instances.filter(InstanceIds=[node_id]))
assert len(matches) == 1, "Invalid instance id {}".format(node_id)
return matches[0]
+67 -43
View File
@@ -14,12 +14,12 @@ from ray.autoscaler.autoscaler import validate_config, hash_runtime_conf, \
hash_launch_conf
from ray.autoscaler.node_provider import get_node_provider, NODE_PROVIDERS
from ray.autoscaler.tags import TAG_RAY_NODE_TYPE, TAG_RAY_LAUNCH_CONFIG, \
TAG_RAY_RUNTIME_CONFIG, TAG_NAME
TAG_NAME
from ray.autoscaler.updater import NodeUpdaterProcess
def create_or_update_cluster(
config_file, override_min_workers, override_max_workers, sync_only):
config_file, override_min_workers, override_max_workers, no_restart):
"""Create or updates an autoscaling Ray cluster from a config json."""
config = yaml.load(open(config_file).read())
@@ -29,9 +29,6 @@ def create_or_update_cluster(
config["min_workers"] = override_min_workers
if override_max_workers is not None:
config["max_workers"] = override_max_workers
if sync_only:
config["worker_init_commands"] = []
config["head_init_commands"] = []
importer = NODE_PROVIDERS.get(config["provider"]["type"])
if not importer:
@@ -40,7 +37,7 @@ def create_or_update_cluster(
bootstrap_config, _ = importer()
config = bootstrap_config(config)
get_or_create_head_node(config)
get_or_create_head_node(config, no_restart)
def teardown_cluster(config_file):
@@ -48,6 +45,8 @@ def teardown_cluster(config_file):
config = yaml.load(open(config_file).read())
confirm("This will destroy your cluster")
validate_config(config)
provider = get_node_provider(config["provider"], config["cluster_name"])
head_node_tags = {
@@ -65,7 +64,7 @@ def teardown_cluster(config_file):
nodes = provider.nodes({})
def get_or_create_head_node(config):
def get_or_create_head_node(config, no_restart):
"""Create the cluster head node, which in turn creates the workers."""
provider = get_node_provider(config["provider"], config["cluster_name"])
@@ -78,6 +77,11 @@ def get_or_create_head_node(config):
else:
head_node = None
if not head_node:
confirm("This will create a new cluster")
elif not no_restart:
confirm("This will restart your cluster")
launch_hash = hash_launch_conf(config["head_node"], config["auth"])
if head_node is None or provider.node_tags(head_node).get(
TAG_RAY_LAUNCH_CONFIG) != launch_hash:
@@ -93,48 +97,60 @@ def get_or_create_head_node(config):
assert len(nodes) == 1, "Failed to create head node."
head_node = nodes[0]
# TODO(ekl) right now we always update the head node even if the hash
# matches. We could prompt the user for what they want to do in this case.
runtime_hash = hash_runtime_conf(config["file_mounts"], config)
print("Updating files on head node...")
if provider.node_tags(head_node).get(
TAG_RAY_RUNTIME_CONFIG) != runtime_hash:
print("Updating files on head node...")
# Rewrite the auth config so that the head node can update the workers
remote_key_path = "~/ray_bootstrap_key.pem"
remote_config = copy.deepcopy(config)
remote_config["auth"]["ssh_private_key"] = remote_key_path
# Rewrite the auth config so that the head node can update the workers
remote_key_path = "~/ray_bootstrap_key.pem"
remote_config = copy.deepcopy(config)
remote_config["auth"]["ssh_private_key"] = remote_key_path
# Adjust for new file locations
new_mounts = {}
for remote_path in config["file_mounts"]:
new_mounts[remote_path] = remote_path
remote_config["file_mounts"] = new_mounts
remote_config["no_restart"] = no_restart
# Adjust for new file locations
new_mounts = {}
for remote_path in config["file_mounts"]:
new_mounts[remote_path] = remote_path
remote_config["file_mounts"] = new_mounts
# Now inject the rewritten config and SSH key into the head node
remote_config_file = tempfile.NamedTemporaryFile(
"w", prefix="ray-bootstrap-")
remote_config_file.write(json.dumps(remote_config))
remote_config_file.flush()
config["file_mounts"].update({
remote_key_path: config["auth"]["ssh_private_key"],
"~/ray_bootstrap_config.yaml": remote_config_file.name
})
# Now inject the rewritten config and SSH key into the head node
remote_config_file = tempfile.NamedTemporaryFile(
"w", prefix="ray-bootstrap-")
remote_config_file.write(json.dumps(remote_config))
remote_config_file.flush()
config["file_mounts"].update({
remote_key_path: config["auth"]["ssh_private_key"],
"~/ray_bootstrap_config.yaml": remote_config_file.name
})
if no_restart:
init_commands = (
config["setup_commands"] + config["head_setup_commands"])
else:
init_commands = (
config["setup_commands"] + config["head_setup_commands"] +
config["head_start_ray_commands"])
updater = NodeUpdaterProcess(
head_node,
config["provider"],
config["auth"],
config["cluster_name"],
config["file_mounts"],
config["head_init_commands"],
runtime_hash,
redirect_output=False)
updater.start()
updater.join()
if updater.exitcode != 0:
print("Error: updating {} failed".format(
provider.external_ip(head_node)))
sys.exit(1)
updater = NodeUpdaterProcess(
head_node,
config["provider"],
config["auth"],
config["cluster_name"],
config["file_mounts"],
init_commands,
runtime_hash,
redirect_output=False)
updater.start()
updater.join()
# Refresh the node cache so we see the external ip if available
provider.nodes(head_node_tags)
if updater.exitcode != 0:
print("Error: updating {} failed".format(
provider.external_ip(head_node)))
sys.exit(1)
print(
"Head node up-to-date, IP address is: {}".format(
provider.external_ip(head_node)))
@@ -150,3 +166,11 @@ def get_or_create_head_node(config):
config["auth"]["ssh_private_key"],
config["auth"]["ssh_user"],
provider.external_ip(head_node)))
def confirm(msg):
    """Ask the user to approve an action; exit the process on refusal.

    Prints `msg` followed by a y/N prompt and reads one line of input.
    Returns normally only when the reply, stripped and lower-cased, is
    exactly "y".

    Args:
        msg: Description of the action that needs confirmation.

    Raises:
        SystemExit: with exit code 1 for any other reply.
    """
    print("{}. Do you want to continue [y/N]? ".format(msg), end="")
    try:
        # Python 2: the builtin input() would eval() the reply, so a plain
        # "y" raises NameError there. raw_input() returns the text as-is.
        read_reply = raw_input
    except NameError:
        read_reply = input  # Python 3
    answer = read_reply()
    if answer.strip().lower() != "y":
        print("Abort.")
        # sys.exit rather than the bare exit(): the latter is injected by
        # the site module and is not guaranteed to exist.
        sys.exit(1)
+8 -1
View File
@@ -46,7 +46,10 @@ class NodeProvider(object):
def nodes(self, tag_filters):
"""Return a list of node ids filtered by the specified tags dict.
This list must not include terminated nodes.
This list must not include terminated nodes. For performance reasons,
providers are allowed to cache the result of a call to nodes() to
serve single-node queries (e.g. is_running(node_id)). This means that
nodes() must be called again to refresh results.
Examples:
>>> provider.nodes({TAG_RAY_NODE_TYPE: "Worker"})
@@ -70,6 +73,10 @@ class NodeProvider(object):
"""Returns the external ip of the given node."""
raise NotImplementedError
def internal_ip(self, node_id):
    """Returns the internal ip (Ray ip) of the given node.

    This base implementation is a stub that always raises
    NotImplementedError.
    """
    raise NotImplementedError
def create_node(self, node_config, tags, count):
"""Creates a number of nodes within the namespace."""
raise NotImplementedError
+10 -6
View File
@@ -2,6 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import pipes
import os
import subprocess
import sys
@@ -23,7 +24,7 @@ class NodeUpdater(object):
def __init__(
self, node_id, provider_config, auth_config, cluster_name,
file_mounts, init_cmds, runtime_hash, redirect_output=True,
file_mounts, setup_cmds, runtime_hash, redirect_output=True,
process_runner=subprocess):
self.daemon = True
self.process_runner = process_runner
@@ -33,7 +34,7 @@ class NodeUpdater(object):
self.ssh_ip = self.provider.external_ip(node_id)
self.node_id = node_id
self.file_mounts = file_mounts
self.init_cmds = init_cmds
self.setup_cmds = setup_cmds
self.runtime_hash = runtime_hash
if redirect_output:
self.logfile = tempfile.NamedTemporaryFile(
@@ -93,6 +94,7 @@ class NodeUpdater(object):
assert self.ssh_ip is not None, "Unable to find IP of node"
# Wait for SSH access
ssh_ok = False
while time.time() < deadline and \
not self.provider.is_terminated(self.node_id):
try:
@@ -105,6 +107,7 @@ class NodeUpdater(object):
self.ssh_cmd(
"uptime",
connect_timeout=5, redirect=open("/dev/null", "w"))
ssh_ok = True
except Exception as e:
print(
"NodeUpdater: SSH not up, retrying: {}".format(e),
@@ -112,7 +115,7 @@ class NodeUpdater(object):
time.sleep(5)
else:
break
assert not self.provider.is_terminated(self.node_id)
assert ssh_ok, "Unable to SSH to node"
# Rsync file mounts
self.provider.set_node_tags(
@@ -139,8 +142,8 @@ class NodeUpdater(object):
# Run init commands
self.provider.set_node_tags(
self.node_id, {TAG_RAY_NODE_STATUS: "RunningInitCmds"})
for cmd in self.init_cmds:
self.node_id, {TAG_RAY_NODE_STATUS: "SettingUp"})
for cmd in self.setup_cmds:
self.ssh_cmd(cmd, verbose=True)
def ssh_cmd(self, cmd, connect_timeout=60, redirect=None, verbose=False):
@@ -149,12 +152,13 @@ class NodeUpdater(object):
"NodeUpdater: running {} on {}...".format(
cmd, self.ssh_ip),
file=self.stdout)
force_interactive = "set -i && source ~/.bashrc && "
self.process_runner.check_call([
"ssh", "-o", "ConnectTimeout={}s".format(connect_timeout),
"-o", "StrictHostKeyChecking=no",
"-i", self.ssh_private_key,
"{}@{}".format(self.ssh_user, self.ssh_ip),
cmd,
"bash --login -c {}".format(pipes.quote(force_interactive + cmd))
], stdout=redirect or self.stdout, stderr=redirect or self.stderr)
+44 -3
View File
@@ -3,6 +3,7 @@ from __future__ import division
from __future__ import print_function
import argparse
import binascii
import json
import logging
import os
@@ -14,9 +15,11 @@ import ray.utils
import redis
# Import flatbuffer bindings.
from ray.core.generated.DriverTableMessage import DriverTableMessage
from ray.core.generated.LocalSchedulerInfoMessage import \
LocalSchedulerInfoMessage
from ray.core.generated.SubscribeToDBClientTableReply import \
SubscribeToDBClientTableReply
from ray.autoscaler.autoscaler import StandardAutoscaler
SubscribeToDBClientTableReply
from ray.autoscaler.autoscaler import LoadMetrics, StandardAutoscaler
from ray.core.generated.TaskInfo import TaskInfo
from ray.services import get_ip_address, get_port
from ray.utils import binary_to_hex, binary_to_object_id, hex_to_binary
@@ -31,6 +34,7 @@ NIL_ID = b"\xff" * DB_CLIENT_ID_SIZE
TASK_STATUS_LOST = 32
# common/state/redis.cc
LOCAL_SCHEDULER_INFO_CHANNEL = b"local_schedulers"
PLASMA_MANAGER_HEARTBEAT_CHANNEL = b"plasma_managers"
DRIVER_DEATH_CHANNEL = b"driver_deaths"
@@ -92,8 +96,10 @@ class Monitor(object):
self.dead_local_schedulers = set()
self.live_plasma_managers = Counter()
self.dead_plasma_managers = set()
self.load_metrics = LoadMetrics()
if autoscaling_config:
self.autoscaler = StandardAutoscaler(autoscaling_config)
self.autoscaler = StandardAutoscaler(
autoscaling_config, self.load_metrics)
else:
self.autoscaler = None
@@ -286,6 +292,36 @@ class Monitor(object):
# already dead.
del self.live_plasma_managers[db_client_id]
def local_scheduler_info_handler(self, unused_channel, data):
    """Handle a local scheduler heartbeat from Redis.

    Decodes the LocalSchedulerInfoMessage carried in ``data``, collects its
    static and dynamic resource key/value pairs, resolves the sender's DB
    client id to an IP address through the client table, and records the
    resources for that IP in ``self.load_metrics``. If no live local
    scheduler entry matches the client id, a warning is printed and the
    heartbeat is otherwise ignored.

    Args:
        unused_channel: The channel the message was published on. Ignored.
        data: The serialized LocalSchedulerInfoMessage payload.
    """
    message = LocalSchedulerInfoMessage.GetRootAsLocalSchedulerInfoMessage(
        data, 0)

    # Walk each resource vector by its own length. The flatbuffer accessors
    # do not bounds-check, so reusing the dynamic vector's length to index
    # the static vector would read out of range (or drop entries) whenever
    # the two vectors differ in size.
    static_resources = {}
    for i in range(message.StaticResourcesLength()):
        static = message.StaticResources(i)
        static_resources[static.Key().decode("utf-8")] = static.Value()

    dynamic_resources = {}
    for i in range(message.DynamicResourcesLength()):
        dyn = message.DynamicResources(i)
        dynamic_resources[dyn.Key().decode("utf-8")] = dyn.Value()

    # Hex-encode the binary client id so it can be compared with the
    # "DBClientID" values of the client table entries.
    client_id = binascii.hexlify(message.DbClientId()).decode("utf-8")

    # Only local schedulers that have not been marked deleted are candidates.
    clients = ray.global_state.client_table()
    local_schedulers = [
        entry for client in clients.values() for entry in client
        if (entry["ClientType"] == "local_scheduler" and not
            entry["Deleted"])
    ]

    ip = None
    for ls in local_schedulers:
        if ls["DBClientID"] == client_id:
            # "AuxAddress" is split on ":"; the part before the first colon
            # is used as the node's IP.
            ip = ls["AuxAddress"].split(":")[0]

    if ip:
        self.load_metrics.update(ip, static_resources, dynamic_resources)
    else:
        print("Warning: could not find ip for client {} in {}".format(
            client_id, local_schedulers))
def plasma_manager_heartbeat_handler(self, unused_channel, data):
"""Handle a plasma manager heartbeat from Redis.
@@ -513,6 +549,10 @@ class Monitor(object):
assert self.subscribed[channel]
# The message was a heartbeat from a plasma manager.
message_handler = self.plasma_manager_heartbeat_handler
elif channel == LOCAL_SCHEDULER_INFO_CHANNEL:
assert self.subscribed[channel]
# The message was a heartbeat from a local scheduler.
message_handler = self.local_scheduler_info_handler
elif channel == DB_CLIENT_TABLE_NAME:
assert self.subscribed[channel]
# The message was a notification from the db_client table.
@@ -537,6 +577,7 @@ class Monitor(object):
"""
# Initialize the subscription channel.
self.subscribe(DB_CLIENT_TABLE_NAME)
self.subscribe(LOCAL_SCHEDULER_INFO_CHANNEL)
self.subscribe(PLASMA_MANAGER_HEARTBEAT_CHANNEL)
self.subscribe(DRIVER_DEATH_CHANNEL)
+20
View File
@@ -0,0 +1,20 @@
"""Ray constants used in the Python code."""

# The docstring above must be the first statement in the module to be picked
# up as the module's __doc__; "from __future__" imports are allowed to follow
# it.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Abort autoscaling if more than this number of errors are encountered. This
# is a safety feature to prevent e.g. runaway node launches.
AUTOSCALER_MAX_NUM_FAILURES = 5

# Max number of nodes to launch at a time.
AUTOSCALER_MAX_CONCURRENT_LAUNCHES = 10

# Interval at which to perform autoscaling updates.
AUTOSCALER_UPDATE_INTERVAL_S = 5

# The autoscaler will attempt to restart Ray on nodes it hasn't heard from
# in more than this interval.
AUTOSCALER_HEARTBEAT_TIMEOUT_S = 30
+5 -6
View File
@@ -245,10 +245,9 @@ def stop():
@click.command()
@click.argument("cluster_config_file", required=True, type=str)
@click.option(
"--sync-only", is_flag=True, default=False, help=(
"Whether to only perform the file sync stage when updating nodes. "
"This avoids interrupting running jobs. You can use this when "
"resizing the cluster with the min/max_workers flag."))
"--no-restart", is_flag=True, default=False, help=(
"Whether to skip restarting Ray services during the update. "
"This avoids interrupting running jobs."))
@click.option(
"--min-workers", required=False, type=int, help=(
"Override the configured min worker node count for the cluster."))
@@ -256,9 +255,9 @@ def stop():
"--max-workers", required=False, type=int, help=(
"Override the configured max worker node count for the cluster."))
def create_or_update(
cluster_config_file, min_workers, max_workers, sync_only):
cluster_config_file, min_workers, max_workers, no_restart):
create_or_update_cluster(
cluster_config_file, min_workers, max_workers, sync_only)
cluster_config_file, min_workers, max_workers, no_restart)
@click.command()
+6 -3
View File
@@ -554,7 +554,7 @@ def start_log_monitor(redis_address, node_ip_address, stdout_file=None,
log_monitor_filepath = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"log_monitor.py")
p = subprocess.Popen([sys.executable, log_monitor_filepath,
p = subprocess.Popen([sys.executable, "-u", log_monitor_filepath,
"--redis-address", redis_address,
"--node-ip-address", node_ip_address],
stdout=stdout_file, stderr=stderr_file)
@@ -850,6 +850,7 @@ def start_worker(node_ip_address, object_store_name, object_store_manager_name,
default.
"""
command = [sys.executable,
"-u",
worker_path,
"--node-ip-address=" + node_ip_address,
"--object-store-name=" + object_store_name,
@@ -884,6 +885,7 @@ def start_monitor(redis_address, node_ip_address, stdout_file=None,
monitor_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
"monitor.py")
command = [sys.executable,
"-u",
monitor_path,
"--redis-address=" + str(redis_address)]
if autoscaling_config:
@@ -1347,6 +1349,7 @@ def new_log_files(name, redirect_output):
date_str = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
log_stdout = "{}/{}-{}-{:05d}.out".format(logs_dir, name, date_str, log_id)
log_stderr = "{}/{}-{}-{:05d}.err".format(logs_dir, name, date_str, log_id)
log_stdout_file = open(log_stdout, "a")
log_stderr_file = open(log_stderr, "a")
# Line-buffer the output (mode 1)
log_stdout_file = open(log_stdout, "a", buffering=1)
log_stderr_file = open(log_stderr, "a", buffering=1)
return log_stdout_file, log_stderr_file