EC2 cluster setup scripts and initial version of auto-scaler (#1311)

This commit is contained in:
Eric Liang
2017-12-15 23:56:39 -08:00
committed by Robert Nishihara
parent 76b6b4a2d3
commit f5ea44338e
20 changed files with 1665 additions and 16 deletions
View File
+330
View File
@@ -0,0 +1,330 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import hashlib
import os
import subprocess
import traceback
from collections import defaultdict
import yaml
from ray.autoscaler.node_provider import get_node_provider
from ray.autoscaler.updater import NodeUpdaterProcess
from ray.autoscaler.tags import TAG_RAY_LAUNCH_CONFIG, \
TAG_RAY_RUNTIME_CONFIG, TAG_RAY_NODE_STATUS, TAG_RAY_NODE_TYPE, TAG_NAME
import ray.services as services
# Schema for the cluster config file. validate_config() requires every key
# listed here to be present and rejects keys that are not listed. A value that
# is a type is checked with isinstance; a value that is a dict is a nested
# schema that is validated recursively.
CLUSTER_CONFIG_SCHEMA = {
    # A unique identifier for the head node and workers of this cluster.
    "cluster_name": str,
    # The minimum number of worker nodes to launch in addition to the head
    # node. This number should be >= 0.
    "min_workers": int,
    # The maximum number of worker nodes to launch in addition to the head
    # node. This takes precedence over min_workers.
    "max_workers": int,
    # Cloud-provider specific configuration.
    "provider": {
        "type": str,  # e.g. aws
        "region": str,  # e.g. us-east-1
    },
    # How Ray will authenticate with newly launched nodes.
    "auth": dict,
    # Provider-specific config for the head node, e.g. instance type.
    "head_node": dict,
    # Provider-specific config for worker nodes. e.g. instance type.
    "worker_nodes": dict,
    # Map of remote paths to local paths, e.g. {"/tmp/data": "/my/local/data"}
    "file_mounts": dict,
    # List of shell commands to run to initialize the head node.
    "head_init_commands": list,
    # List of shell commands to run to initialize workers.
    "worker_init_commands": list,
}
# Abort autoscaling if more than this number of errors are encountered. This
# is a safety feature to prevent e.g. runaway node launches.
MAX_NUM_FAILURES = 5
# Max number of nodes to launch at a time.
MAX_CONCURRENT_LAUNCHES = 10
class StandardAutoscaler(object):
    """The autoscaling control loop for a Ray cluster.

    There are two ways to start an autoscaling cluster: manually by running
    `ray start --head --autoscaling-config=/path/to/config.yaml` on an
    instance that has permission to launch other instances, or you can also
    use `ray create_or_update /path/to/config.yaml` from your laptop, which
    will configure the right AWS/Cloud roles automatically.

    StandardAutoscaler's `update` method is periodically called by
    `monitor.py` to add and remove nodes as necessary. Currently, load-based
    autoscaling is not implemented, so all this class does is try to maintain
    a constant cluster size.

    StandardAutoscaler is also used to bootstrap clusters (by adding workers
    until the target cluster size is met).
    """

    def __init__(
            self, config_path,
            max_concurrent_launches=MAX_CONCURRENT_LAUNCHES,
            max_failures=MAX_NUM_FAILURES, process_runner=subprocess,
            verbose_updates=False, node_updater_cls=NodeUpdaterProcess):
        """Load the config and create the node provider.

        Args:
            config_path: Path of the YAML cluster config. It is re-read on
                every call to `update`.
            max_concurrent_launches: Upper bound on the number of nodes
                requested by a single `update` call.
            max_failures: Number of failed `update` calls tolerated before
                the error is re-raised to the caller.
            process_runner: Object handed to node updaters as
                `process_runner` (defaults to the `subprocess` module).
            verbose_updates: If True, updater output is not redirected.
            node_updater_cls: Class used to construct node updaters.

        Raises:
            Exception: If the config cannot be read or fails validation.
            AssertionError: If a local file mount path does not exist.
        """
        self.config_path = config_path
        self.reload_config(errors_fatal=True)
        self.provider = get_node_provider(
            self.config["provider"], self.config["cluster_name"])

        self.max_failures = max_failures
        self.max_concurrent_launches = max_concurrent_launches
        self.verbose_updates = verbose_updates
        self.process_runner = process_runner
        self.node_updater_cls = node_updater_cls

        # Map from node_id to NodeUpdater processes
        self.updaters = {}
        # Map from node_id to the number of failed update attempts
        self.num_failed_updates = defaultdict(int)
        # Number of `update` calls that raised an exception
        self.num_failures = 0

        for local_path in self.config["file_mounts"].values():
            assert os.path.exists(local_path), \
                "File mount source {} does not exist".format(local_path)

        print("StandardAutoscaler: {}".format(self.config))

    def update(self):
        """Reload the config and run one autoscaling iteration.

        Errors are logged and counted. The exception is only re-raised once
        more than `max_failures` iterations have failed.
        """
        try:
            self.reload_config(errors_fatal=False)
            self._update()
        except Exception:
            # The traceback must go through format(); passing it as a second
            # print() argument leaves a literal "{}" in the output.
            print(
                "StandardAutoscaler: Error during autoscaling: {}".format(
                    traceback.format_exc()))
            self.num_failures += 1
            if self.num_failures > self.max_failures:
                print("*** StandardAutoscaler: Too many errors, abort. ***")
                raise

    def _update(self):
        """Move the cluster one step towards `max_workers` worker nodes."""
        nodes = self.workers()
        target_num_workers = self.config["max_workers"]

        # Terminate nodes while there are too many
        while len(nodes) > target_num_workers:
            print(
                "StandardAutoscaler: Terminating unneeded node: "
                "{}".format(nodes[-1]))
            self.provider.terminate_node(nodes[-1])
            nodes = self.workers()
            print(self.debug_string())

        if target_num_workers == 0:
            return

        # Update nodes with out-of-date files
        for node_id in nodes:
            self.update_if_needed(node_id)

        # Launch a new node if needed
        if len(nodes) < target_num_workers:
            self.launch_new_node(
                min(
                    self.max_concurrent_launches,
                    target_num_workers - len(nodes)))
            print(self.debug_string())
            return
        else:
            # If enough nodes, terminate an out-of-date node.
            for node_id in nodes:
                if not self.launch_config_ok(node_id):
                    print(
                        "StandardAutoscaler: Terminating outdated node: "
                        "{}".format(node_id))
                    self.provider.terminate_node(node_id)
                    print(self.debug_string())
                    return

        # Process any completed updates
        completed = [
            node_id for node_id, updater in self.updaters.items()
            if not updater.is_alive()]
        if completed:
            for node_id in completed:
                if self.updaters[node_id].exitcode != 0:
                    self.num_failed_updates[node_id] += 1
                del self.updaters[node_id]
            print(self.debug_string())

    def reload_config(self, errors_fatal=False):
        """Re-read and validate the config file, then refresh both hashes.

        The new config replaces the current one only if it loads, validates
        and hashes successfully.

        Args:
            errors_fatal: If True, re-raise any error. Otherwise log it and
                keep the previously loaded config.
        """
        try:
            with open(self.config_path) as f:
                # safe_load: the schema only allows plain scalars, lists and
                # dicts, and yaml.load() without an explicit Loader can
                # construct arbitrary objects (and fails on PyYAML >= 6).
                new_config = yaml.safe_load(f)
            validate_config(new_config)
            new_launch_hash = hash_launch_conf(
                new_config["worker_nodes"], new_config["auth"])
            new_runtime_hash = hash_runtime_conf(
                new_config["file_mounts"], new_config["worker_init_commands"])
            self.config = new_config
            self.launch_hash = new_launch_hash
            self.runtime_hash = new_runtime_hash
        except Exception:
            if errors_fatal:
                raise
            else:
                print(
                    "StandardAutoscaler: Error parsing config: {}".format(
                        traceback.format_exc()))

    def launch_config_ok(self, node_id):
        """Return whether the node's launch-config tag matches our hash."""
        launch_conf = self.provider.node_tags(node_id).get(
            TAG_RAY_LAUNCH_CONFIG)
        return self.launch_hash == launch_conf

    def files_up_to_date(self, node_id):
        """Return whether the node's runtime-config tag matches our hash."""
        applied = self.provider.node_tags(node_id).get(TAG_RAY_RUNTIME_CONFIG)
        if applied != self.runtime_hash:
            print(
                "StandardAutoscaler: {} has runtime state {}, want {}".format(
                    node_id, applied, self.runtime_hash))
            return False
        return True

    def update_if_needed(self, node_id):
        """Start an updater for the node unless one is not warranted.

        No updater is started if the node is not running, has an outdated
        launch config, is already being updated, has previously failed an
        update, or already has the current runtime config.
        """
        if not self.provider.is_running(node_id):
            return
        if not self.launch_config_ok(node_id):
            return
        if node_id in self.updaters:
            return
        if self.num_failed_updates.get(node_id, 0) > 0:  # TODO(ekl) retry?
            return
        if self.files_up_to_date(node_id):
            return
        updater = self.node_updater_cls(
            node_id,
            self.config["provider"],
            self.config["auth"],
            self.config["cluster_name"],
            self.config["file_mounts"],
            with_head_node_ip(self.config["worker_init_commands"]),
            self.runtime_hash,
            redirect_output=not self.verbose_updates,
            process_runner=self.process_runner)
        updater.start()
        self.updaters[node_id] = updater

    def launch_new_node(self, count):
        """Ask the provider for `count` new worker nodes.

        Raises:
            AssertionError: If the number of workers did not increase.
        """
        print("StandardAutoscaler: Launching {} new nodes".format(count))
        num_before = len(self.workers())
        self.provider.create_node(
            self.config["worker_nodes"],
            {
                TAG_NAME: "ray-{}-worker".format(self.config["cluster_name"]),
                TAG_RAY_NODE_TYPE: "Worker",
                TAG_RAY_NODE_STATUS: "Uninitialized",
                TAG_RAY_LAUNCH_CONFIG: self.launch_hash,
            },
            count)
        # TODO(ekl) be less conservative in this check
        assert len(self.workers()) > num_before, \
            "Num nodes failed to increase after creating a new node"

    def workers(self):
        """Return the ids of the nodes tagged as workers."""
        return self.provider.nodes(tag_filters={
            TAG_RAY_NODE_TYPE: "Worker",
        })

    def debug_string(self, nodes=None):
        """Return a one-line summary of the cluster size and update state."""
        if nodes is None:
            nodes = self.workers()
        target_num_workers = self.config["max_workers"]
        suffix = ""
        if self.updaters:
            suffix += " ({} updating)".format(len(self.updaters))
        if self.num_failed_updates:
            suffix += " ({} failed to update)".format(
                len(self.num_failed_updates))
        return "StandardAutoscaler: Have {} / {} target nodes{}".format(
            len(nodes), target_num_workers, suffix)
def validate_config(config, schema=None):
    """Check that `config` has exactly the keys and types of `schema`.

    Args:
        config: The parsed cluster config.
        schema: Dict mapping each required key to either a type (checked
            with isinstance) or a nested dict schema (checked recursively).
            Defaults to CLUSTER_CONFIG_SCHEMA.

    Raises:
        ValueError: If `config` is not a dict, a required key is missing, a
            value has the wrong type, or an unknown key is present.
    """
    if schema is None:
        schema = CLUSTER_CONFIG_SCHEMA
    if not isinstance(config, dict):
        raise ValueError("Config is not a dictionary")
    for k, v in schema.items():
        is_leaf = isinstance(v, type)
        if k not in config:
            # A nested schema is a plain dict and has no __name__, so name
            # the expected type explicitly instead of failing here.
            expected = v.__name__ if is_leaf else dict.__name__
            raise ValueError(
                "Missing required config key `{}` of type {}".format(
                    k, expected))
        if is_leaf:
            if not isinstance(config[k], v):
                raise ValueError(
                    "Config key `{}` has wrong type {}, expected {}".format(
                        k, type(config[k]).__name__, v.__name__))
        else:
            validate_config(config[k], v)
    for k in config.keys():
        if k not in schema:
            raise ValueError(
                "Unexpected config key `{}` not in {}".format(
                    k, schema.keys()))
def with_head_node_ip(cmds):
    """Prefix every command with an export of RAY_HEAD_IP.

    The exported value is the address returned by
    `services.get_node_ip_address()`.
    """
    ip_address = services.get_node_ip_address()
    return [
        "export RAY_HEAD_IP={}; {}".format(ip_address, command)
        for command in cmds]
def hash_launch_conf(node_conf, auth):
    """Return a SHA-1 hex digest identifying a node launch configuration.

    Args:
        node_conf: Provider-specific node config (JSON-serializable dict).
        auth: Auth config (JSON-serializable dict).

    Returns:
        A 40-character hex string. It is stored in the launch-config tag and
        compared against it to detect out-of-date nodes.
    """
    # sort_keys makes the digest independent of dict key order, so two
    # processes that loaded the same config always compute the same hash.
    # NOTE: digests differ from those computed without sort_keys, so nodes
    # tagged by such a version are seen as outdated once.
    payload = json.dumps([node_conf, auth], sort_keys=True)
    return hashlib.sha1(payload.encode("utf-8")).hexdigest()
def hash_runtime_conf(file_mounts, extra_objs):
    """Return a SHA-1 hex digest of the file mounts and extra config.

    The digest covers the mount mapping itself, `extra_objs`, and the
    contents of every local path. Directories are walked recursively and
    their path names and file names are included.

    Args:
        file_mounts: Dict mapping remote paths to local paths.
        extra_objs: Any JSON-serializable object to mix into the hash.

    Returns:
        A 40-character hex string.
    """
    hasher = hashlib.sha1()

    def add_file_hash(file_path):
        # Hash raw bytes in chunks: text mode fails on binary files and
        # reading the whole file at once is wasteful for large ones.
        with open(file_path, "rb") as f:
            for chunk in iter(lambda: f.read(2 ** 20), b""):
                hasher.update(chunk)

    def add_content_hashes(path):
        if os.path.isdir(path):
            dirs = [
                (dirpath, sorted(filenames))
                for dirpath, _, filenames in os.walk(path)]
            # Sort so the digest does not depend on os.walk ordering.
            for dirpath, filenames in sorted(dirs):
                hasher.update(dirpath.encode("utf-8"))
                for name in filenames:
                    hasher.update(name.encode("utf-8"))
                    add_file_hash(os.path.join(dirpath, name))
        else:
            add_file_hash(path)

    hasher.update(json.dumps(sorted(file_mounts.items())).encode("utf-8"))
    # sort_keys keeps the digest stable when extra_objs contains dicts.
    hasher.update(json.dumps(extra_objs, sort_keys=True).encode("utf-8"))
    for local_path in sorted(file_mounts.values()):
        add_content_hashes(local_path)
    return hasher.hexdigest()
+249
View File
@@ -0,0 +1,249 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import logging
import os
import time
import boto3
# Base name shared by the default AWS resources configured by this module.
RAY = "ray-autoscaler"
# Names of the instance profile and IAM role used when the head node config
# does not specify an IamInstanceProfile.
DEFAULT_RAY_INSTANCE_PROFILE = RAY
DEFAULT_RAY_IAM_ROLE = RAY
# Security group name template; it is formatted with the cluster name.
SECURITY_GROUP_TEMPLATE = RAY + "-{}"
def key_pair(i):
    """Returns the ith default (aws_key_pair_name, key_pair_path).

    Index 0 uses the bare base name; any other index is appended to the
    base name with an underscore.
    """
    if i != 0:
        indexed_name = "{}_{}".format(RAY, i)
        indexed_path = os.path.expanduser("~/.ssh/{}_{}.pem".format(RAY, i))
        return indexed_name, indexed_path
    return RAY, os.path.expanduser("~/.ssh/{}.pem".format(RAY))
# Suppress excessive connection dropped logs from boto by only letting
# WARNING and above through the "botocore" logger.
logging.getLogger("botocore").setLevel(logging.WARNING)
def bootstrap_aws(config):
    """Fill in unspecified AWS settings of `config` and return the result.

    Each step receives the config returned by the previous one.
    """
    steps = (
        # The head node needs to have an IAM role that allows it to create
        # further EC2 instances.
        _configure_iam_role,
        # Configure SSH access, using an existing key pair if possible.
        _configure_key_pair,
        # Pick a reasonable subnet if not specified by the user.
        _configure_subnet,
        # Cluster workers should be in a security group that permits traffic
        # within the group, and also SSH access from outside.
        _configure_security_group,
    )
    for configure in steps:
        config = configure(config)
    return config
def _configure_iam_role(config):
if "IamInstanceProfile" in config["head_node"]:
return config
profile = _get_instance_profile(DEFAULT_RAY_INSTANCE_PROFILE, config)
if profile is None:
print("Creating new instance profile {}".format(
DEFAULT_RAY_INSTANCE_PROFILE))
client = _client("iam", config)
client.create_instance_profile(
InstanceProfileName=DEFAULT_RAY_INSTANCE_PROFILE)
profile = _get_instance_profile(DEFAULT_RAY_INSTANCE_PROFILE, config)
time.sleep(15) # wait for propagation
assert profile is not None, "Failed to create instance profile"
if not profile.roles:
role = _get_role(DEFAULT_RAY_IAM_ROLE, config)
if role is None:
print("Creating new role {}".format(DEFAULT_RAY_IAM_ROLE))
iam = _resource("iam", config)
iam.create_role(
RoleName=DEFAULT_RAY_IAM_ROLE,
AssumeRolePolicyDocument=json.dumps({
"Statement": [
{
"Effect": "Allow",
"Principal": {"Service": "ec2.amazonaws.com"},
"Action": "sts:AssumeRole",
},
],
}))
role = _get_role(DEFAULT_RAY_IAM_ROLE, config)
assert role is not None, "Failed to create role"
role.attach_policy(
PolicyArn="arn:aws:iam::aws:policy/AmazonEC2FullAccess")
profile.add_role(RoleName=role.name)
time.sleep(15) # wait for propagation
print("Role not specified for head node, using {}".format(
profile.arn))
config["head_node"]["IamInstanceProfile"] = {"Arn": profile.arn}
return config
def _configure_key_pair(config):
if "ssh_private_key" in config["auth"]:
assert "KeyName" in config["head_node"]
assert "KeyName" in config["worker_nodes"]
return config
ec2 = _resource("ec2", config)
# Try a few times to get or create a good key pair.
for i in range(10):
key_name, key_path = key_pair(i)
key = _get_key(key_name, config)
# Found a good key.
if key and os.path.exists(key_path):
break
# We can safely create a new key.
if not key and not os.path.exists(key_path):
print("Creating new key pair {}".format(key_name))
key = ec2.create_key_pair(KeyName=key_name)
with open(key_path, "w") as f:
f.write(key.key_material)
os.chmod(key_path, 0o600)
break
assert key, "AWS keypair {} not found for {}".format(key_name, key_path)
assert os.path.exists(key_path), \
"Private key file {} not found for {}".format(key_path, key_name)
print("KeyName not specified for nodes, using {}".format(key_name))
config["auth"]["ssh_private_key"] = key_path
config["head_node"]["KeyName"] = key_name
config["worker_nodes"]["KeyName"] = key_name
return config
def _configure_subnet(config):
ec2 = _resource("ec2", config)
subnets = sorted(
[s for s in ec2.subnets.all() if s.state == "available"],
reverse=True, # sort from Z-A
key=lambda subnet: subnet.availability_zone)
default_subnet = subnets[0]
if "SubnetId" not in config["head_node"]:
config["head_node"]["SubnetId"] = default_subnet.id
print("SubnetId not specified for head node, using {} in {}".format(
default_subnet.id, default_subnet.availability_zone))
if "SubnetId" not in config["worker_nodes"]:
config["worker_nodes"]["SubnetId"] = default_subnet.id
print("SubnetId not specified for workers, using {} in {}".format(
default_subnet.id, default_subnet.availability_zone))
return config
def _configure_security_group(config):
if "SecurityGroupIds" in config["head_node"] and \
"SecurityGroupIds" in config["worker_nodes"]:
return config # have user-defined groups
group_name = SECURITY_GROUP_TEMPLATE.format(config["cluster_name"])
subnet = _get_subnet_or_die(config, config["worker_nodes"]["SubnetId"])
security_group = _get_security_group(config, subnet.vpc_id, group_name)
if security_group is None:
print("Creating new security group {}".format(group_name))
client = _client("ec2", config)
client.create_security_group(
Description="Auto-created security group for Ray workers",
GroupName=group_name,
VpcId=subnet.vpc_id)
security_group = _get_security_group(config, subnet.vpc_id, group_name)
assert security_group, "Failed to create security group"
if not security_group.ip_permissions:
security_group.authorize_ingress(
IpPermissions=[
{"FromPort": -1, "ToPort": -1, "IpProtocol": "-1",
"UserIdGroupPairs": [{"GroupId": security_group.id}]},
{"FromPort": 22, "ToPort": 22, "IpProtocol": "TCP",
"IpRanges": [{"CidrIp": "0.0.0.0/0"}]}])
if "SecurityGroupIds" not in config["head_node"]:
print("SecurityGroupIds not specified for head node, using {}".format(
security_group.group_name))
config["head_node"]["SecurityGroupIds"] = [security_group.id]
if "SecurityGroupIds" not in config["worker_nodes"]:
print("SecurityGroupIds not specified for workers, using {}".format(
security_group.group_name))
config["worker_nodes"]["SecurityGroupIds"] = [security_group.id]
return config
def _get_subnet_or_die(config, subnet_id):
    """Return the subnet with the given id.

    Raises:
        AssertionError: Unless exactly one subnet matches.
    """
    id_filter = {"Name": "subnet-id", "Values": [subnet_id]}
    matches = list(
        _resource("ec2", config).subnets.filter(Filters=[id_filter]))
    assert len(matches) == 1, "Subnet not found"
    return matches[0]
def _get_security_group(config, vpc_id, group_name):
    """Return the security group called `group_name` in the VPC, or None."""
    vpc_filter = {"Name": "vpc-id", "Values": [vpc_id]}
    groups_in_vpc = list(
        _resource("ec2", config).security_groups.filter(
            Filters=[vpc_filter]))
    named = (sg for sg in groups_in_vpc if sg.group_name == group_name)
    return next(named, None)
def _get_role(role_name, config):
    """Return the IAM role with this name, or None if loading it fails."""
    candidate = _resource("iam", config).Role(role_name)
    try:
        candidate.load()
    except Exception:
        # Any failure to load is reported to the caller as "no such role".
        return None
    else:
        return candidate
def _get_instance_profile(profile_name, config):
    """Return the IAM instance profile with this name, or None.

    None is returned whenever loading the profile fails.
    """
    candidate = _resource("iam", config).InstanceProfile(profile_name)
    try:
        candidate.load()
    except Exception:
        # Any failure to load is reported as "no such profile".
        return None
    else:
        return candidate
def _get_key(key_name, config):
    """Return the EC2 key pair named exactly `key_name`, or None."""
    name_filter = {"Name": "key-name", "Values": [key_name]}
    candidates = _resource("ec2", config).key_pairs.filter(
        Filters=[name_filter])
    for candidate in candidates:
        if candidate.name == key_name:
            return candidate
    return None
def _client(name, config):
    """Return a boto3 client for service `name` in the configured region."""
    region = config["provider"]["region"]
    return boto3.client(name, region)
def _resource(name, config):
    """Return a boto3 resource for service `name` in the configured region."""
    region = config["provider"]["region"]
    return boto3.resource(name, region)
+72
View File
@@ -0,0 +1,72 @@
# A unique identifier for the head node and workers of this cluster.
cluster_name: default
# The minimum number of worker nodes to launch in addition to the head
# node. This number should be >= 0.
min_workers: 0
# The maximum number of worker nodes to launch in addition to the head
# node. This takes precedence over min_workers.
max_workers: 2
# Cloud-provider specific configuration.
provider:
type: aws
region: us-east-1
# How Ray will authenticate with newly launched nodes.
auth:
ssh_user: ubuntu
# By default Ray creates a new private keypair, but you can also use your own.
# If you do so, make sure to also set "KeyName" in the head and worker node
# configurations below.
# ssh_private_key: /path/to/your/key.pem
# Provider-specific config for the head node, e.g. instance type. By default
# Ray will auto-configure unspecified fields such as SubnetId and KeyName.
# For more documentation on available fields, see:
# http://boto3.readthedocs.io/en/latest/reference/services/ec2.html#EC2.ServiceResource.create_instances
head_node:
InstanceType: m5.large
ImageId: ami-212d465b
# Additional options in the boto docs.
# Provider-specific config for worker nodes, e.g. instance type. By default
# Ray will auto-configure unspecified fields such as SubnetId and KeyName.
# For more documentation on available fields, see:
# http://boto3.readthedocs.io/en/latest/reference/services/ec2.html#EC2.ServiceResource.create_instances
worker_nodes:
InstanceType: m5.large
ImageId: ami-212d465b
# Run workers on spot by default. Comment this out to use on-demand.
InstanceMarketOptions:
MarketType: spot
# Additional options can be found in the boto docs, e.g.
# SpotOptions:
# MaxPrice: MAX_HOURLY_PRICE
# Additional options in the boto docs.
# Files or directories to copy to the head and worker nodes. The format is a
# dictionary from REMOTE_PATH: LOCAL_PATH, e.g.
file_mounts: {
# "/path1/on/remote/machine": "/path1/on/local/machine",
# "/path2/on/remote/machine": "/path2/on/local/machine",
}
# List of shell commands to run to initialize the head node.
head_init_commands:
- cd ~/ray; git remote add eric https://github.com/ericl/ray.git || true
- cd ~/ray; git fetch eric && git reset --hard e1e97b3
- yes | ~/anaconda3/bin/conda install boto3=1.4.8 # 1.4.8 adds InstanceMarketOptions
- ~/.local/bin/ray stop
- ~/.local/bin/ray start --head --redis-port=6379 --autoscaling-config=~/ray_bootstrap_config.yaml
# List of shell commands to run to initialize workers.
worker_init_commands:
- cd ~/ray; git remote add eric https://github.com/ericl/ray.git || true
- cd ~/ray; git fetch eric && git reset --hard e1e97b3
- ~/.local/bin/ray stop
- ~/.local/bin/ray start --redis-address=$RAY_HEAD_IP:6379
@@ -0,0 +1,95 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import boto3
from ray.autoscaler.node_provider import NodeProvider
from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME
class AWSNodeProvider(NodeProvider):
    """NodeProvider implementation backed by EC2 instances.

    Instances belong to the cluster if they carry the cluster-name tag with
    this provider's `cluster_name`.
    """

    def __init__(self, provider_config, cluster_name):
        NodeProvider.__init__(self, provider_config, cluster_name)
        self.ec2 = boto3.resource("ec2", region_name=provider_config["region"])

    def nodes(self, tag_filters):
        """Return ids of pending or running cluster instances.

        Args:
            tag_filters: Dict of tag name to required tag value.
        """
        filters = [
            {
                "Name": "instance-state-name",
                "Values": ["pending", "running"],
            },
            {
                "Name": "tag:{}".format(TAG_RAY_CLUSTER_NAME),
                "Values": [self.cluster_name],
            },
        ]
        for k, v in tag_filters.items():
            filters.append({
                "Name": "tag:{}".format(k),
                "Values": [v],
            })
        return [i.id for i in self.ec2.instances.filter(Filters=filters)]

    def is_running(self, node_id):
        """Return whether the instance state is "running"."""
        node = self._node(node_id)
        return node.state["Name"] == "running"

    def is_terminated(self, node_id):
        """Return whether the instance is neither running nor pending."""
        node = self._node(node_id)
        state = node.state["Name"]
        return state not in ["running", "pending"]

    def node_tags(self, node_id):
        """Return the instance's tags as a dict of key to value."""
        node = self._node(node_id)
        # boto3 reports `tags` as None, not [], for an untagged instance.
        return {tag["Key"]: tag["Value"] for tag in (node.tags or [])}

    def external_ip(self, node_id):
        """Return the instance's public IP address attribute."""
        node = self._node(node_id)
        return node.public_ip_address

    def set_node_tags(self, node_id, tags):
        """Create or overwrite the given tags on the instance."""
        node = self._node(node_id)
        tag_pairs = [{"Key": k, "Value": v} for k, v in tags.items()]
        node.create_tags(Tags=tag_pairs)

    def create_node(self, node_config, tags, count):
        """Request between 1 and `count` instances with the given tags.

        The cluster-name tag is always added. `node_config` is not modified;
        any TagSpecifications it contains are kept and extended.
        """
        conf = node_config.copy()
        tag_pairs = [{
            "Key": TAG_RAY_CLUSTER_NAME,
            "Value": self.cluster_name,
        }]
        for k, v in tags.items():
            tag_pairs.append({
                "Key": k,
                "Value": v,
            })
        conf.update({
            "MinCount": 1,
            "MaxCount": count,
            "TagSpecifications": conf.get("TagSpecifications", []) + [
                {
                    "ResourceType": "instance",
                    "Tags": tag_pairs,
                }
            ]
        })
        self.ec2.create_instances(**conf)

    def terminate_node(self, node_id):
        """Terminate the instance."""
        node = self._node(node_id)
        node.terminate()

    def _node(self, node_id):
        """Look up the instance object for `node_id`.

        Raises:
            AssertionError: Unless exactly one instance matches.
        """
        matches = list(self.ec2.instances.filter(InstanceIds=[node_id]))
        assert len(matches) == 1, "Invalid instance id {}".format(node_id)
        return matches[0]
+146
View File
@@ -0,0 +1,146 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import json
import tempfile
import time
import sys
import yaml
from ray.autoscaler.autoscaler import validate_config, hash_runtime_conf, \
hash_launch_conf
from ray.autoscaler.node_provider import get_node_provider, NODE_PROVIDERS
from ray.autoscaler.tags import TAG_RAY_NODE_TYPE, TAG_RAY_LAUNCH_CONFIG, \
TAG_RAY_RUNTIME_CONFIG, TAG_NAME
from ray.autoscaler.updater import NodeUpdaterProcess
def create_or_update_cluster(
        config_file, override_min_workers, override_max_workers, sync_only):
    """Create or update an autoscaling Ray cluster from a YAML config file.

    Args:
        config_file: Path of the cluster config file.
        override_min_workers: If not None, replaces `min_workers`.
        override_max_workers: If not None, replaces `max_workers`.
        sync_only: If True, both init command lists are cleared before the
            head node is created or updated.

    Raises:
        ValueError: If the config fails validation.
        NotImplementedError: If the provider type is not supported.
    """
    with open(config_file) as f:
        # safe_load: the config only holds plain data, and yaml.load()
        # without an explicit Loader can construct arbitrary objects.
        config = yaml.safe_load(f)

    validate_config(config)
    if override_min_workers is not None:
        config["min_workers"] = override_min_workers
    if override_max_workers is not None:
        config["max_workers"] = override_max_workers

    if sync_only:
        config["worker_init_commands"] = []
        config["head_init_commands"] = []

    importer = NODE_PROVIDERS.get(config["provider"]["type"])
    if not importer:
        raise NotImplementedError(
            "Unsupported provider {}".format(config["provider"]))

    # Let the provider fill in unspecified settings before launching.
    bootstrap_config, _ = importer()
    config = bootstrap_config(config)
    get_or_create_head_node(config)
def teardown_cluster(config_file):
    """Destroy all nodes of a Ray cluster described by a YAML config file.

    Head nodes are terminated first. The remaining nodes are then
    terminated repeatedly, with a 5 second pause between rounds, until the
    provider reports none left.

    Raises:
        ValueError: If the config fails validation.
    """
    with open(config_file) as f:
        # safe_load: the config only holds plain data, and yaml.load()
        # without an explicit Loader can construct arbitrary objects.
        config = yaml.safe_load(f)

    validate_config(config)
    provider = get_node_provider(config["provider"], config["cluster_name"])
    head_node_tags = {
        TAG_RAY_NODE_TYPE: "Head",
    }
    for node in provider.nodes(head_node_tags):
        print("Terminating head node {}".format(node))
        provider.terminate_node(node)

    nodes = provider.nodes({})
    while nodes:
        for node in nodes:
            print("Terminating worker {}".format(node))
            provider.terminate_node(node)
        time.sleep(5)
        nodes = provider.nodes({})
def get_or_create_head_node(config):
    """Create the cluster head node, which in turn creates the workers.

    An existing head node is reused if its launch-config tag matches the
    hash of the current head node and auth config; otherwise it is
    terminated and a new one is launched. The head node is then updated
    (files synced, head init commands run) if its runtime-config tag does
    not match the current runtime hash. The process exits with status 1 if
    that update fails.

    Note that config["file_mounts"] is modified in place when an update is
    performed.
    """
    provider = get_node_provider(config["provider"], config["cluster_name"])
    head_node_tags = {
        TAG_RAY_NODE_TYPE: "Head",
    }
    nodes = provider.nodes(head_node_tags)
    if len(nodes) > 0:
        head_node = nodes[0]
    else:
        head_node = None

    # Launch (or replace) the head node if it is missing or was launched
    # from a different launch config.
    launch_hash = hash_launch_conf(config["head_node"], config["auth"])
    if head_node is None or provider.node_tags(head_node).get(
            TAG_RAY_LAUNCH_CONFIG) != launch_hash:
        if head_node is not None:
            print("Terminating outdated head node {}".format(head_node))
            provider.terminate_node(head_node)
        print("Launching new head node...")
        head_node_tags[TAG_RAY_LAUNCH_CONFIG] = launch_hash
        head_node_tags[TAG_NAME] = "ray-{}-head".format(config["cluster_name"])
        provider.create_node(config["head_node"], head_node_tags, 1)

    nodes = provider.nodes(head_node_tags)
    assert len(nodes) == 1, "Failed to create head node."
    head_node = nodes[0]

    # The hash is computed before the bootstrap key and config are added to
    # the file mounts below, so those two entries are not part of it.
    runtime_hash = hash_runtime_conf(config["file_mounts"], config)

    if provider.node_tags(head_node).get(
            TAG_RAY_RUNTIME_CONFIG) != runtime_hash:
        print("Updating files on head node...")

        # Rewrite the auth config so that the head node can update the workers
        remote_key_path = "~/ray_bootstrap_key.pem"
        remote_config = copy.deepcopy(config)
        remote_config["auth"]["ssh_private_key"] = remote_key_path

        # Adjust for new file locations
        new_mounts = {}
        for remote_path in config["file_mounts"]:
            new_mounts[remote_path] = remote_path
        remote_config["file_mounts"] = new_mounts

        # Now inject the rewritten config and SSH key into the head node
        remote_config_file = tempfile.NamedTemporaryFile(
            "w", prefix="ray-bootstrap-")
        remote_config_file.write(json.dumps(remote_config))
        # Flush so the full content is on disk when the updater syncs it.
        remote_config_file.flush()
        config["file_mounts"].update({
            remote_key_path: config["auth"]["ssh_private_key"],
            "~/ray_bootstrap_config.yaml": remote_config_file.name
        })

        updater = NodeUpdaterProcess(
            head_node,
            config["provider"],
            config["auth"],
            config["cluster_name"],
            config["file_mounts"],
            config["head_init_commands"],
            runtime_hash,
            redirect_output=False)
        updater.start()
        # Block until the update has finished.
        updater.join()
        if updater.exitcode != 0:
            print("Error: updating {} failed".format(
                provider.external_ip(head_node)))
            sys.exit(1)
    print(
        "Head node up-to-date, IP address is: {}".format(
            provider.external_ip(head_node)))
    print(
        "To monitor auto-scaling activity, you can run:\n\n"
        "  ssh -i {} {}@{} 'tail -f /tmp/raylogs/monitor-*'\n".format(
            config["auth"]["ssh_private_key"],
            config["auth"]["ssh_user"],
            provider.external_ip(head_node)))
+83
View File
@@ -0,0 +1,83 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
def import_aws():
    """Return the (bootstrap_config function, NodeProvider class) for AWS.

    The AWS modules are imported inside the function, so they are only
    loaded when this importer is actually called.
    """
    from ray.autoscaler.aws.config import bootstrap_aws
    from ray.autoscaler.aws.node_provider import AWSNodeProvider
    return bootstrap_aws, AWSNodeProvider
# Map from provider type to an importer function returning a
# (bootstrap_config function, NodeProvider class) pair. A value of None marks
# a provider type that is not implemented yet.
NODE_PROVIDERS = {
    "aws": import_aws,
    "gce": None,  # TODO: support more node providers
    "azure": None,
    "kubernetes": None,
    "docker": None,
    "local_cluster": None,
}
def get_node_provider(provider_config, cluster_name):
    """Instantiate the NodeProvider for the configured provider type.

    Raises:
        NotImplementedError: If the provider type has no importer.
    """
    load_provider = NODE_PROVIDERS.get(provider_config["type"])
    if load_provider is not None:
        _bootstrap, provider_cls = load_provider()
        return provider_cls(provider_config, cluster_name)
    raise NotImplementedError(
        "Unsupported node provider: {}".format(provider_config["type"]))
class NodeProvider(object):
    """Interface for getting and returning nodes from a Cloud.

    NodeProviders are namespaced by the `cluster_name` parameter; they only
    operate on nodes within that namespace.

    Nodes may be in one of three states: {pending, running, terminated}. Nodes
    appear immediately once started by `create_node`, and transition
    immediately to terminated when `terminate_node` is called.

    Every method other than `__init__` raises NotImplementedError here and
    must be overridden by a concrete provider.
    """

    def __init__(self, provider_config, cluster_name):
        # Provider section of the cluster config.
        self.provider_config = provider_config
        # Namespace this provider operates in.
        self.cluster_name = cluster_name

    def nodes(self, tag_filters):
        """Return a list of node ids filtered by the specified tags dict.

        This list must not include terminated nodes.

        Examples:
            >>> provider.nodes({TAG_RAY_NODE_TYPE: "Worker"})
            ["node-1", "node-2"]
        """
        raise NotImplementedError

    def is_running(self, node_id):
        """Return whether the specified node is running."""
        raise NotImplementedError

    def is_terminated(self, node_id):
        """Return whether the specified node is terminated."""
        raise NotImplementedError

    def node_tags(self, node_id):
        """Returns the tags of the given node (string dict)."""
        raise NotImplementedError

    def external_ip(self, node_id):
        """Returns the external ip of the given node."""
        raise NotImplementedError

    def create_node(self, node_config, tags, count):
        """Creates a number of nodes within the namespace.

        Args:
            node_config: Provider-specific node configuration.
            tags: Tags (string dict) to set on the new nodes.
            count: Number of nodes to create.
        """
        raise NotImplementedError

    def set_node_tags(self, node_id, tags):
        """Sets the tag values (string dict) for the specified node."""
        raise NotImplementedError

    def terminate_node(self, node_id):
        """Terminates the specified node."""
        raise NotImplementedError
+23
View File
@@ -0,0 +1,23 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""The Ray autoscaler uses tags to associate metadata with instances."""
# Tag for the name of the node
TAG_NAME = "Name"
# Tag uniquely identifying all nodes of a cluster
TAG_RAY_CLUSTER_NAME = "ray:ClusterName"
# Tag for the type of node (e.g. Head, Worker)
TAG_RAY_NODE_TYPE = "ray:NodeType"
# Tag that reports the current state of the node (e.g. Uninitialized,
# WaitingForSSH, SyncingFiles, RunningInitCmds, Up-to-date, UpdateFailed)
TAG_RAY_NODE_STATUS = "ray:NodeStatus"
# Hash of the node launch config, used to identify out-of-date nodes
TAG_RAY_LAUNCH_CONFIG = "ray:LaunchConfig"
# Hash of the node runtime config, used to determine if updates are needed
TAG_RAY_RUNTIME_CONFIG = "ray:RuntimeConfig"
+172
View File
@@ -0,0 +1,172 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import subprocess
import sys
import tempfile
import time
from multiprocessing import Process
from threading import Thread
from ray.autoscaler.node_provider import get_node_provider
from ray.autoscaler.tags import TAG_RAY_NODE_STATUS, TAG_RAY_RUNTIME_CONFIG
# How long to wait for a node to start, in seconds. NodeUpdater.do_update
# uses this as a single deadline that covers both waiting for an external IP
# and waiting for SSH access.
NODE_START_WAIT_S = 300
class NodeUpdater(object):
"""A process for syncing files and running init commands on a node."""
def __init__(
        self, node_id, provider_config, auth_config, cluster_name,
        file_mounts, init_cmds, runtime_hash, redirect_output=True,
        process_runner=subprocess):
    """Record what to apply to the node and where to send the output.

    Args:
        node_id: Id of the node to update.
        provider_config: Provider section of the cluster config.
        auth_config: Dict with "ssh_private_key" and "ssh_user".
        cluster_name: Name of the cluster the node belongs to.
        file_mounts: Dict of remote path to local path to sync.
        init_cmds: Shell commands to run on the node after syncing.
        runtime_hash: Hash recorded on the node once the update succeeds.
        redirect_output: If True, output goes to a new temporary log file
            instead of the console.
        process_runner: Object used to run local subprocesses.
    """
    # NOTE(review): presumably read by the Process/Thread based subclasses;
    # confirm against their definitions.
    self.daemon = True
    self.process_runner = process_runner
    self.provider = get_node_provider(provider_config, cluster_name)
    self.ssh_private_key = auth_config["ssh_private_key"]
    self.ssh_user = auth_config["ssh_user"]
    self.ssh_ip = self.provider.external_ip(node_id)
    self.node_id = node_id
    self.file_mounts = file_mounts
    self.init_cmds = init_cmds
    self.runtime_hash = runtime_hash
    if not redirect_output:
        self.logfile = None
        self.output_name = "(console)"
        self.stdout, self.stderr = sys.stdout, sys.stderr
    else:
        # delete=False keeps the log on disk so it can be read afterwards.
        log = tempfile.NamedTemporaryFile(
            mode="w", prefix="node-updater-", delete=False)
        self.logfile = log
        self.output_name = log.name
        self.stdout = self.stderr = log
def run(self):
    """Run the update and record the outcome in the node's tags.

    On success the node is tagged "Up-to-date" together with the applied
    runtime hash. On failure it is tagged "UpdateFailed", the captured log
    (if output was redirected) is printed, and the error is re-raised.
    """
    print("NodeUpdater: Updating {} to {}, logging to {}".format(
        self.node_id, self.runtime_hash, self.output_name))
    try:
        self.do_update()
    except Exception as e:
        print(
            "NodeUpdater: Error updating {}, "
            "see {} for remote logs".format(e, self.output_name),
            file=self.stdout)
        self.provider.set_node_tags(
            self.node_id, {TAG_RAY_NODE_STATUS: "UpdateFailed"})
        if self.logfile is not None:
            # Flush buffered writes first; the log is read back through a
            # separate handle, which is closed again right after.
            self.logfile.flush()
            with open(self.logfile.name) as log:
                remote_logs = log.read()
            print(
                "----- BEGIN REMOTE LOGS -----\n" +
                remote_logs +
                "\n----- END REMOTE LOGS -----")
        raise
    self.provider.set_node_tags(
        self.node_id, {
            TAG_RAY_NODE_STATUS: "Up-to-date",
            TAG_RAY_RUNTIME_CONFIG: self.runtime_hash
        })
    print(
        "NodeUpdater: Applied config {} to node {}".format(
            self.runtime_hash, self.node_id),
        file=self.stdout)
def do_update(self):
self.provider.set_node_tags(
self.node_id, {TAG_RAY_NODE_STATUS: "WaitingForSSH"})
deadline = time.time() + NODE_START_WAIT_S
# Wait for external IP
while time.time() < deadline and \
not self.provider.is_terminated(self.node_id):
print(
"NodeUpdater: Waiting for IP of {}...".format(self.node_id),
file=self.stdout)
self.ssh_ip = self.provider.external_ip(self.node_id)
if self.ssh_ip is not None:
break
time.sleep(5)
assert self.ssh_ip is not None, "Unable to find IP of node"
# Wait for SSH access
while time.time() < deadline and \
not self.provider.is_terminated(self.node_id):
try:
print(
"NodeUpdater: Waiting for SSH to {}...".format(
self.node_id),
file=self.stdout)
if not self.provider.is_running(self.node_id):
raise Exception()
self.ssh_cmd(
"uptime",
connect_timeout=5, redirect=open("/dev/null", "w"))
except Exception as e:
print(
"NodeUpdater: SSH not up, retrying: {}".format(e),
file=self.stdout)
time.sleep(5)
else:
break
assert not self.provider.is_terminated(self.node_id)
# Rsync file mounts
self.provider.set_node_tags(
self.node_id, {TAG_RAY_NODE_STATUS: "SyncingFiles"})
for remote_path, local_path in self.file_mounts.items():
print(
"NodeUpdater: Syncing {} to {}...".format(
local_path, remote_path),
file=self.stdout)
assert os.path.exists(local_path)
if os.path.isdir(local_path):
if not local_path.endswith("/"):
local_path += "/"
if not remote_path.endswith("/"):
remote_path += "/"
self.ssh_cmd(
"mkdir -p {}".format(os.path.dirname(remote_path)))
self.process_runner.check_call([
"rsync", "-e", "ssh -i {} ".format(self.ssh_private_key) +
"-o ConnectTimeout=60s -o StrictHostKeyChecking=no",
"--delete", "-avz", "{}".format(local_path),
"{}@{}:{}".format(self.ssh_user, self.ssh_ip, remote_path)
], stdout=self.stdout, stderr=self.stderr)
# Run init commands
self.provider.set_node_tags(
self.node_id, {TAG_RAY_NODE_STATUS: "RunningInitCmds"})
for cmd in self.init_cmds:
self.ssh_cmd(cmd, verbose=True)
def ssh_cmd(self, cmd, connect_timeout=60, redirect=None, verbose=False):
if verbose:
print(
"NodeUpdater: running {} on {}...".format(
cmd, self.ssh_ip),
file=self.stdout)
self.process_runner.check_call([
"ssh", "-o", "ConnectTimeout={}s".format(connect_timeout),
"-o", "StrictHostKeyChecking=no",
"-i", self.ssh_private_key,
"{}@{}".format(self.ssh_user, self.ssh_ip),
cmd,
], stdout=redirect or self.stdout, stderr=redirect or self.stderr)
class NodeUpdaterProcess(NodeUpdater, Process):
    """NodeUpdater combined with multiprocessing.Process.

    NodeUpdater precedes Process in the base list, so NodeUpdater.run is the
    run() method that the process executes.
    """

    def __init__(self, *args, **kwargs):
        # Order matters: NodeUpdater.__init__ assigns self.daemon, and on a
        # Process that attribute is a property backed by state which
        # Process.__init__ creates, so the Process initializer must go first.
        Process.__init__(self)
        NodeUpdater.__init__(self, *args, **kwargs)
# Single-threaded version for unit tests
class NodeUpdaterThread(NodeUpdater, Thread):
    """NodeUpdater combined with threading.Thread instead of a Process."""

    def __init__(self, *args, **kwargs):
        # Order matters: NodeUpdater.__init__ assigns self.daemon, which
        # Thread only accepts after Thread.__init__ has run.
        Thread.__init__(self)
        NodeUpdater.__init__(self, *args, **kwargs)
        # Thread has no exitcode attribute of its own; presumably set so
        # callers can read it as they would on NodeUpdaterProcess -- TODO
        # confirm against the autoscaler code that consumes it.
        self.exitcode = 0
+21 -2
View File
@@ -5,6 +5,7 @@ from __future__ import print_function
import argparse
import json
import logging
import os
import time
from collections import Counter, defaultdict
@@ -15,6 +16,7 @@ import redis
from ray.core.generated.DriverTableMessage import DriverTableMessage
from ray.core.generated.SubscribeToDBClientTableReply import \
SubscribeToDBClientTableReply
from ray.autoscaler.autoscaler import StandardAutoscaler
from ray.core.generated.TaskInfo import TaskInfo
from ray.services import get_ip_address, get_port
from ray.utils import binary_to_hex, binary_to_object_id, hex_to_binary
@@ -75,7 +77,7 @@ class Monitor(object):
managers that were up at one point and have died since then.
"""
def __init__(self, redis_address, redis_port):
def __init__(self, redis_address, redis_port, autoscaling_config):
# Initialize the Redis clients.
self.state = ray.experimental.state.GlobalState()
self.state._initialize_global_state(redis_address, redis_port)
@@ -90,6 +92,10 @@ class Monitor(object):
self.dead_local_schedulers = set()
self.live_plasma_managers = Counter()
self.dead_plasma_managers = set()
if autoscaling_config:
self.autoscaler = StandardAutoscaler(autoscaling_config)
else:
self.autoscaler = None
def subscribe(self, channel):
"""Subscribe to the given channel.
@@ -556,6 +562,9 @@ class Monitor(object):
# Handle messages from the subscription channels.
while True:
# Process autoscaling actions
if self.autoscaler:
self.autoscaler.update()
# Record how many dead local schedulers and plasma managers we had
# at the beginning of this round.
num_dead_local_schedulers = len(self.dead_local_schedulers)
@@ -604,6 +613,11 @@ if __name__ == "__main__":
required=True,
type=str,
help="the address to use for Redis")
parser.add_argument(
"--autoscaling-config",
required=False,
type=str,
help="the path to the autoscaling config file")
args = parser.parse_args()
redis_ip_address = get_ip_address(args.redis_address)
@@ -612,5 +626,10 @@ if __name__ == "__main__":
# Initialize the global state.
ray.global_state._initialize_global_state(redis_ip_address, redis_port)
monitor = Monitor(redis_ip_address, redis_port)
if args.autoscaling_config:
autoscaling_config = os.path.expanduser(args.autoscaling_config)
else:
autoscaling_config = None
monitor = Monitor(redis_ip_address, redis_port, autoscaling_config)
monitor.run()
+33 -2
View File
@@ -7,6 +7,7 @@ import json
import subprocess
import ray.services as services
from ray.autoscaler.commands import create_or_update_cluster, teardown_cluster
def check_no_existing_redis_clients(node_ip_address, redis_client):
@@ -76,10 +77,12 @@ def cli():
help="object store directory for memory mapped files")
@click.option("--huge-pages", is_flag=True, default=False,
help="enable support for huge pages in the object store")
@click.option("--autoscaling-config", required=False, type=str,
help="the file that contains the autoscaling config")
def start(node_ip_address, redis_address, redis_port, num_redis_shards,
redis_max_clients, object_manager_port, num_workers, num_cpus,
num_gpus, resources, head, no_ui, block, plasma_directory,
huge_pages):
huge_pages, autoscaling_config):
# Note that we redirect stdout and stderr to /dev/null because otherwise
# attempts to print may cause exceptions if a process is started inside of
# an SSH connection and the SSH connection dies. TODO(rkn): This is a
@@ -138,7 +141,8 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
redis_max_clients=redis_max_clients,
include_webui=(not no_ui),
plasma_directory=plasma_directory,
huge_pages=huge_pages)
huge_pages=huge_pages,
autoscaling_config=autoscaling_config)
print(address_info)
print("\nStarted Ray on this node. You can add additional nodes to "
"the cluster by calling\n\n"
@@ -238,8 +242,35 @@ def stop():
"awk '{ print $2 }') 2> /dev/null"], shell=True)
# CLI command that forwards its arguments to
# ray.autoscaler.commands.create_or_update_cluster. Documented with comments
# rather than a docstring because click would surface a docstring as the
# command's --help text.
@click.command()
@click.argument("cluster_config_file", required=True, type=str)
@click.option(
    "--sync-only",
    is_flag=True,
    default=False,
    help="Whether to only perform the file sync stage when updating nodes. "
         "This avoids interrupting running jobs. "
         "You can use this when resizing the cluster with the "
         "min/max_workers flag.")
@click.option(
    "--min-workers",
    required=False,
    type=int,
    help="Override the configured min worker node count for the cluster.")
@click.option(
    "--max-workers",
    required=False,
    type=int,
    help="Override the configured max worker node count for the cluster.")
def create_or_update(
        cluster_config_file, min_workers, max_workers, sync_only):
    # min_workers / max_workers are None when the flags are not supplied.
    create_or_update_cluster(
        cluster_config_file,
        min_workers,
        max_workers,
        sync_only)
# CLI command that passes the given cluster config file path straight to
# ray.autoscaler.commands.teardown_cluster. Kept as a comment (not a
# docstring) so that click's generated --help text is unchanged.
@click.command()
@click.argument("cluster_config_file", required=True, type=str)
def teardown(cluster_config_file):
    teardown_cluster(cluster_config_file)
cli.add_command(start)
cli.add_command(stop)
cli.add_command(create_or_update)
cli.add_command(teardown)
def main():
+19 -8
View File
@@ -4,6 +4,7 @@ from __future__ import print_function
import binascii
from collections import namedtuple, OrderedDict
from datetime import datetime
import cloudpickle
import json
import os
@@ -863,7 +864,7 @@ def start_worker(node_ip_address, object_store_name, object_store_manager_name,
def start_monitor(redis_address, node_ip_address, stdout_file=None,
stderr_file=None, cleanup=True):
stderr_file=None, cleanup=True, autoscaling_config=None):
"""Run a process to monitor the other processes.
Args:
@@ -878,12 +879,15 @@ def start_monitor(redis_address, node_ip_address, stdout_file=None,
then this process will be killed by services.cleanup() when the
Python process that imported services exits. This is True by
default.
autoscaling_config: path to autoscaling config file.
"""
monitor_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
"monitor.py")
command = [sys.executable,
monitor_path,
"--redis-address=" + str(redis_address)]
if autoscaling_config:
command.append("--autoscaling-config=" + str(autoscaling_config))
p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
if cleanup:
all_processes[PROCESS_TYPE_WORKER].append(p)
@@ -908,7 +912,8 @@ def start_ray_processes(address_info=None,
start_workers_from_local_scheduler=True,
resources=None,
plasma_directory=None,
huge_pages=False):
huge_pages=False,
autoscaling_config=None):
"""Helper method to start Ray processes.
Args:
@@ -956,6 +961,7 @@ def start_ray_processes(address_info=None,
be created.
huge_pages: Boolean flag indicating whether to start the Object
Store with hugetlbfs support. Requires plasma_directory.
autoscaling_config: path to autoscaling config file.
Returns:
A dictionary of the address information for the processes that were
@@ -1006,7 +1012,8 @@ def start_ray_processes(address_info=None,
node_ip_address,
stdout_file=monitor_stdout_file,
stderr_file=monitor_stderr_file,
cleanup=cleanup)
cleanup=cleanup,
autoscaling_config=autoscaling_config)
if redis_shards == []:
# Get redis shards from primary redis instance.
@@ -1221,7 +1228,8 @@ def start_ray_head(address_info=None,
redis_max_clients=None,
include_webui=True,
plasma_directory=None,
huge_pages=False):
huge_pages=False,
autoscaling_config=None):
"""Start Ray in local mode.
Args:
@@ -1263,6 +1271,7 @@ def start_ray_head(address_info=None,
be created.
huge_pages: Boolean flag indicating whether to start the Object
Store with hugetlbfs support. Requires plasma_directory.
autoscaling_config: path to autoscaling config file.
Returns:
A dictionary of the address information for the processes that were
@@ -1287,7 +1296,8 @@ def start_ray_head(address_info=None,
num_redis_shards=num_redis_shards,
redis_max_clients=redis_max_clients,
plasma_directory=plasma_directory,
huge_pages=huge_pages)
huge_pages=huge_pages,
autoscaling_config=autoscaling_config)
def try_to_create_directory(directory_path):
@@ -1333,9 +1343,10 @@ def new_log_files(name, redirect_output):
# Create another directory that will be used by some of the RL algorithms.
try_to_create_directory("/tmp/ray")
log_id = random.randint(0, 1000000000)
log_stdout = "{}/{}-{:010d}.out".format(logs_dir, name, log_id)
log_stderr = "{}/{}-{:010d}.err".format(logs_dir, name, log_id)
log_id = random.randint(0, 10000)
date_str = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
log_stdout = "{}/{}-{}-{:05d}.out".format(logs_dir, name, date_str, log_id)
log_stderr = "{}/{}-{}-{:05d}.err".format(logs_dir, name, date_str, log_id)
log_stdout_file = open(log_stdout, "a")
log_stderr_file = open(log_stderr, "a")
return log_stdout_file, log_stderr_file
+1
View File
@@ -112,6 +112,7 @@ setup(name="ray",
"colorama",
"psutil",
"pytest",
"pyyaml",
"redis",
"cloudpickle == 0.5.2",
# The six module is required by pyarrow.