mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 03:18:59 +08:00
EC2 cluster setup scripts and initial version of auto-scaler (#1311)
This commit is contained in:
committed by
Robert Nishihara
parent
76b6b4a2d3
commit
f5ea44338e
@@ -0,0 +1,330 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
import hashlib
|
||||
import os
|
||||
import subprocess
|
||||
import traceback
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
import yaml
|
||||
|
||||
from ray.autoscaler.node_provider import get_node_provider
|
||||
from ray.autoscaler.updater import NodeUpdaterProcess
|
||||
from ray.autoscaler.tags import TAG_RAY_LAUNCH_CONFIG, \
|
||||
TAG_RAY_RUNTIME_CONFIG, TAG_RAY_NODE_STATUS, TAG_RAY_NODE_TYPE, TAG_NAME
|
||||
import ray.services as services
|
||||
|
||||
|
||||
CLUSTER_CONFIG_SCHEMA = {
|
||||
# An unique identifier for the head node and workers of this cluster.
|
||||
"cluster_name": str,
|
||||
|
||||
# The minimum number of workers nodes to launch in addition to the head
|
||||
# node. This number should be >= 0.
|
||||
"min_workers": int,
|
||||
|
||||
# The maximum number of workers nodes to launch in addition to the head
|
||||
# node. This takes precedence over min_workers.
|
||||
"max_workers": int,
|
||||
|
||||
# Cloud-provider specific configuration.
|
||||
"provider": {
|
||||
"type": str, # e.g. aws
|
||||
"region": str, # e.g. us-east-1
|
||||
},
|
||||
|
||||
# How Ray will authenticate with newly launched nodes.
|
||||
"auth": dict,
|
||||
|
||||
# Provider-specific config for the head node, e.g. instance type.
|
||||
"head_node": dict,
|
||||
|
||||
# Provider-specific config for worker nodes. e.g. instance type.
|
||||
"worker_nodes": dict,
|
||||
|
||||
# Map of remote paths to local paths, e.g. {"/tmp/data": "/my/local/data"}
|
||||
"file_mounts": dict,
|
||||
|
||||
# List of shell commands to run to initialize the head node.
|
||||
"head_init_commands": list,
|
||||
|
||||
# List of shell commands to run to initialize workers.
|
||||
"worker_init_commands": list,
|
||||
}
|
||||
|
||||
|
||||
# Abort autoscaling if more than this number of errors are encountered. This
|
||||
# is a safety feature to prevent e.g. runaway node launches.
|
||||
MAX_NUM_FAILURES = 5
|
||||
|
||||
# Max number of nodes to launch at a time.
|
||||
MAX_CONCURRENT_LAUNCHES = 10
|
||||
|
||||
|
||||
class StandardAutoscaler(object):
|
||||
"""The autoscaling control loop for a Ray cluster.
|
||||
|
||||
There are two ways to start an autoscaling cluster: manually by running
|
||||
`ray start --head --autoscaling-config=/path/to/config.yaml` on a
|
||||
instance that has permission to launch other instances, or you can also use
|
||||
`ray create_or_update /path/to/config.yaml` from your laptop, which will
|
||||
configure the right AWS/Cloud roles automatically.
|
||||
|
||||
StandardAutoscaler's `update` method is periodically called by `monitor.py`
|
||||
to add and remove nodes as necessary. Currently, load-based autoscaling is
|
||||
not implemented, so all this class does is try to maintain a constant
|
||||
cluster size.
|
||||
|
||||
StandardAutoscaler is also used to bootstrap clusters (by adding workers
|
||||
until the target cluster size is met).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, config_path,
|
||||
max_concurrent_launches=MAX_CONCURRENT_LAUNCHES,
|
||||
max_failures=MAX_NUM_FAILURES, process_runner=subprocess,
|
||||
verbose_updates=False, node_updater_cls=NodeUpdaterProcess):
|
||||
self.config_path = config_path
|
||||
self.reload_config(errors_fatal=True)
|
||||
self.provider = get_node_provider(
|
||||
self.config["provider"], self.config["cluster_name"])
|
||||
|
||||
self.max_failures = max_failures
|
||||
self.max_concurrent_launches = max_concurrent_launches
|
||||
self.verbose_updates = verbose_updates
|
||||
self.process_runner = process_runner
|
||||
self.node_updater_cls = node_updater_cls
|
||||
|
||||
# Map from node_id to NodeUpdater processes
|
||||
self.updaters = {}
|
||||
self.num_failed_updates = defaultdict(int)
|
||||
self.num_failures = 0
|
||||
|
||||
for local_path in self.config["file_mounts"].values():
|
||||
assert os.path.exists(local_path)
|
||||
|
||||
print("StandardAutoscaler: {}".format(self.config))
|
||||
|
||||
def update(self):
|
||||
try:
|
||||
self.reload_config(errors_fatal=False)
|
||||
self._update()
|
||||
except Exception as e:
|
||||
print(
|
||||
"StandardAutoscaler: Error during autoscaling: {}",
|
||||
traceback.format_exc())
|
||||
self.num_failures += 1
|
||||
if self.num_failures > self.max_failures:
|
||||
print("*** StandardAutoscaler: Too many errors, abort. ***")
|
||||
raise e
|
||||
|
||||
def _update(self):
|
||||
nodes = self.workers()
|
||||
target_num_workers = self.config["max_workers"]
|
||||
|
||||
# Terminate nodes while there are too many
|
||||
while len(nodes) > target_num_workers:
|
||||
print(
|
||||
"StandardAutoscaler: Terminating unneeded node: "
|
||||
"{}".format(nodes[-1]))
|
||||
self.provider.terminate_node(nodes[-1])
|
||||
nodes = self.workers()
|
||||
print(self.debug_string())
|
||||
|
||||
if target_num_workers == 0:
|
||||
return
|
||||
|
||||
# Update nodes with out-of-date files
|
||||
for node_id in nodes:
|
||||
self.update_if_needed(node_id)
|
||||
|
||||
# Launch a new node if needed
|
||||
if len(nodes) < target_num_workers:
|
||||
self.launch_new_node(
|
||||
min(
|
||||
self.max_concurrent_launches,
|
||||
target_num_workers - len(nodes)))
|
||||
print(self.debug_string())
|
||||
return
|
||||
else:
|
||||
# If enough nodes, terminate an out-of-date node.
|
||||
for node_id in nodes:
|
||||
if not self.launch_config_ok(node_id):
|
||||
print(
|
||||
"StandardAutoscaler: Terminating outdated node: "
|
||||
"{}".format(node_id))
|
||||
self.provider.terminate_node(node_id)
|
||||
print(self.debug_string())
|
||||
return
|
||||
|
||||
# Process any completed updates
|
||||
completed = []
|
||||
for node_id, updater in self.updaters.items():
|
||||
if not updater.is_alive():
|
||||
completed.append(node_id)
|
||||
if completed:
|
||||
for node_id in completed:
|
||||
if self.updaters[node_id].exitcode != 0:
|
||||
self.num_failed_updates[node_id] += 1
|
||||
del self.updaters[node_id]
|
||||
print(self.debug_string())
|
||||
|
||||
def reload_config(self, errors_fatal=False):
|
||||
try:
|
||||
with open(self.config_path) as f:
|
||||
new_config = yaml.load(f.read())
|
||||
validate_config(new_config)
|
||||
new_launch_hash = hash_launch_conf(
|
||||
new_config["worker_nodes"], new_config["auth"])
|
||||
new_runtime_hash = hash_runtime_conf(
|
||||
new_config["file_mounts"], new_config["worker_init_commands"])
|
||||
self.config = new_config
|
||||
self.launch_hash = new_launch_hash
|
||||
self.runtime_hash = new_runtime_hash
|
||||
except Exception as e:
|
||||
if errors_fatal:
|
||||
raise e
|
||||
else:
|
||||
print(
|
||||
"StandardAutoscaler: Error parsing config: {}",
|
||||
traceback.format_exc())
|
||||
|
||||
def launch_config_ok(self, node_id):
|
||||
launch_conf = self.provider.node_tags(node_id).get(
|
||||
TAG_RAY_LAUNCH_CONFIG)
|
||||
if self.launch_hash != launch_conf:
|
||||
return False
|
||||
return True
|
||||
|
||||
def files_up_to_date(self, node_id):
|
||||
applied = self.provider.node_tags(node_id).get(TAG_RAY_RUNTIME_CONFIG)
|
||||
if applied != self.runtime_hash:
|
||||
print(
|
||||
"StandardAutoscaler: {} has runtime state {}, want {}".format(
|
||||
node_id, applied, self.runtime_hash))
|
||||
return False
|
||||
return True
|
||||
|
||||
def update_if_needed(self, node_id):
|
||||
if not self.provider.is_running(node_id):
|
||||
return
|
||||
if not self.launch_config_ok(node_id):
|
||||
return
|
||||
if node_id in self.updaters:
|
||||
return
|
||||
if self.num_failed_updates.get(node_id, 0) > 0: # TODO(ekl) retry?
|
||||
return
|
||||
if self.files_up_to_date(node_id):
|
||||
return
|
||||
updater = self.node_updater_cls(
|
||||
node_id,
|
||||
self.config["provider"],
|
||||
self.config["auth"],
|
||||
self.config["cluster_name"],
|
||||
self.config["file_mounts"],
|
||||
with_head_node_ip(self.config["worker_init_commands"]),
|
||||
self.runtime_hash,
|
||||
redirect_output=not self.verbose_updates,
|
||||
process_runner=self.process_runner)
|
||||
updater.start()
|
||||
self.updaters[node_id] = updater
|
||||
|
||||
def launch_new_node(self, count):
|
||||
print("StandardAutoscaler: Launching {} new nodes".format(count))
|
||||
num_before = len(self.workers())
|
||||
self.provider.create_node(
|
||||
self.config["worker_nodes"],
|
||||
{
|
||||
TAG_NAME: "ray-{}-worker".format(self.config["cluster_name"]),
|
||||
TAG_RAY_NODE_TYPE: "Worker",
|
||||
TAG_RAY_NODE_STATUS: "Uninitialized",
|
||||
TAG_RAY_LAUNCH_CONFIG: self.launch_hash,
|
||||
},
|
||||
count)
|
||||
# TODO(ekl) be less conservative in this check
|
||||
assert len(self.workers()) > num_before, \
|
||||
"Num nodes failed to increase after creating a new node"
|
||||
|
||||
def workers(self):
|
||||
return self.provider.nodes(tag_filters={
|
||||
TAG_RAY_NODE_TYPE: "Worker",
|
||||
})
|
||||
|
||||
def debug_string(self, nodes=None):
|
||||
if nodes is None:
|
||||
nodes = self.workers()
|
||||
target_num_workers = self.config["max_workers"]
|
||||
suffix = ""
|
||||
if self.updaters:
|
||||
suffix += " ({} updating)".format(len(self.updaters))
|
||||
if self.num_failed_updates:
|
||||
suffix += " ({} failed to update)".format(
|
||||
len(self.num_failed_updates))
|
||||
return "StandardAutoscaler: Have {} / {} target nodes{}".format(
|
||||
len(nodes), target_num_workers, suffix)
|
||||
|
||||
|
||||
def validate_config(config, schema=CLUSTER_CONFIG_SCHEMA):
|
||||
if type(config) is not dict:
|
||||
raise ValueError("Config is not a dictionary")
|
||||
for k, v in schema.items():
|
||||
if k not in config:
|
||||
raise ValueError(
|
||||
"Missing required config key `{}` of type {}".format(
|
||||
k, v.__name__))
|
||||
if isinstance(v, type):
|
||||
if not isinstance(config[k], v):
|
||||
raise ValueError(
|
||||
"Config key `{}` has wrong type {}, expected {}".format(
|
||||
k, type(config[k]).__name__, v.__name__))
|
||||
else:
|
||||
validate_config(config[k], schema[k])
|
||||
for k in config.keys():
|
||||
if k not in schema:
|
||||
raise ValueError(
|
||||
"Unexpected config key `{}` not in {}".format(
|
||||
k, schema.keys()))
|
||||
|
||||
|
||||
def with_head_node_ip(cmds):
|
||||
head_ip = services.get_node_ip_address()
|
||||
out = []
|
||||
for cmd in cmds:
|
||||
out.append("export RAY_HEAD_IP={}; {}".format(head_ip, cmd))
|
||||
return out
|
||||
|
||||
|
||||
def hash_launch_conf(node_conf, auth):
|
||||
hasher = hashlib.sha1()
|
||||
hasher.update(json.dumps([node_conf, auth]).encode("utf-8"))
|
||||
return hasher.hexdigest()
|
||||
|
||||
|
||||
def hash_runtime_conf(file_mounts, extra_objs):
|
||||
hasher = hashlib.sha1()
|
||||
|
||||
def add_content_hashes(path):
|
||||
if os.path.isdir(path):
|
||||
dirs = []
|
||||
for dirpath, _, filenames in os.walk(path):
|
||||
dirs.append((dirpath, sorted(filenames)))
|
||||
for dirpath, filenames in sorted(dirs):
|
||||
hasher.update(dirpath.encode("utf-8"))
|
||||
for name in filenames:
|
||||
hasher.update(name.encode("utf-8"))
|
||||
with open(os.path.join(dirpath, name), "rb") as f:
|
||||
hasher.update(f.read())
|
||||
else:
|
||||
with open(path, 'r') as f:
|
||||
hasher.update(f.read().encode("utf-8"))
|
||||
|
||||
hasher.update(json.dumps(sorted(file_mounts.items())).encode("utf-8"))
|
||||
hasher.update(json.dumps(extra_objs).encode("utf-8"))
|
||||
for local_path in sorted(file_mounts.values()):
|
||||
add_content_hashes(local_path)
|
||||
|
||||
return hasher.hexdigest()
|
||||
@@ -0,0 +1,249 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
|
||||
import boto3
|
||||
|
||||
RAY = "ray-autoscaler"
|
||||
DEFAULT_RAY_INSTANCE_PROFILE = RAY
|
||||
DEFAULT_RAY_IAM_ROLE = RAY
|
||||
SECURITY_GROUP_TEMPLATE = RAY + "-{}"
|
||||
|
||||
|
||||
def key_pair(i):
|
||||
"""Returns the ith default (aws_key_pair_name, key_pair_path)."""
|
||||
if i == 0:
|
||||
return RAY, os.path.expanduser("~/.ssh/{}.pem".format(RAY))
|
||||
return (
|
||||
"{}_{}".format(RAY, i),
|
||||
os.path.expanduser("~/.ssh/{}_{}.pem".format(RAY, i)))
|
||||
|
||||
|
||||
# Suppress excessive connection dropped logs from boto
|
||||
logging.getLogger("botocore").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
def bootstrap_aws(config):
|
||||
# The head node needs to have an IAM role that allows it to create further
|
||||
# EC2 instances.
|
||||
config = _configure_iam_role(config)
|
||||
|
||||
# Configure SSH access, using an existing key pair if possible.
|
||||
config = _configure_key_pair(config)
|
||||
|
||||
# Pick a reasonable subnet if not specified by the user.
|
||||
config = _configure_subnet(config)
|
||||
|
||||
# Cluster workers should be in a security group that permits traffic within
|
||||
# the group, and also SSH access from outside.
|
||||
config = _configure_security_group(config)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def _configure_iam_role(config):
|
||||
if "IamInstanceProfile" in config["head_node"]:
|
||||
return config
|
||||
|
||||
profile = _get_instance_profile(DEFAULT_RAY_INSTANCE_PROFILE, config)
|
||||
|
||||
if profile is None:
|
||||
print("Creating new instance profile {}".format(
|
||||
DEFAULT_RAY_INSTANCE_PROFILE))
|
||||
client = _client("iam", config)
|
||||
client.create_instance_profile(
|
||||
InstanceProfileName=DEFAULT_RAY_INSTANCE_PROFILE)
|
||||
profile = _get_instance_profile(DEFAULT_RAY_INSTANCE_PROFILE, config)
|
||||
time.sleep(15) # wait for propagation
|
||||
|
||||
assert profile is not None, "Failed to create instance profile"
|
||||
|
||||
if not profile.roles:
|
||||
role = _get_role(DEFAULT_RAY_IAM_ROLE, config)
|
||||
if role is None:
|
||||
print("Creating new role {}".format(DEFAULT_RAY_IAM_ROLE))
|
||||
iam = _resource("iam", config)
|
||||
iam.create_role(
|
||||
RoleName=DEFAULT_RAY_IAM_ROLE,
|
||||
AssumeRolePolicyDocument=json.dumps({
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Principal": {"Service": "ec2.amazonaws.com"},
|
||||
"Action": "sts:AssumeRole",
|
||||
},
|
||||
],
|
||||
}))
|
||||
role = _get_role(DEFAULT_RAY_IAM_ROLE, config)
|
||||
assert role is not None, "Failed to create role"
|
||||
role.attach_policy(
|
||||
PolicyArn="arn:aws:iam::aws:policy/AmazonEC2FullAccess")
|
||||
profile.add_role(RoleName=role.name)
|
||||
time.sleep(15) # wait for propagation
|
||||
|
||||
print("Role not specified for head node, using {}".format(
|
||||
profile.arn))
|
||||
config["head_node"]["IamInstanceProfile"] = {"Arn": profile.arn}
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def _configure_key_pair(config):
|
||||
if "ssh_private_key" in config["auth"]:
|
||||
assert "KeyName" in config["head_node"]
|
||||
assert "KeyName" in config["worker_nodes"]
|
||||
return config
|
||||
|
||||
ec2 = _resource("ec2", config)
|
||||
|
||||
# Try a few times to get or create a good key pair.
|
||||
for i in range(10):
|
||||
key_name, key_path = key_pair(i)
|
||||
key = _get_key(key_name, config)
|
||||
|
||||
# Found a good key.
|
||||
if key and os.path.exists(key_path):
|
||||
break
|
||||
|
||||
# We can safely create a new key.
|
||||
if not key and not os.path.exists(key_path):
|
||||
print("Creating new key pair {}".format(key_name))
|
||||
key = ec2.create_key_pair(KeyName=key_name)
|
||||
with open(key_path, "w") as f:
|
||||
f.write(key.key_material)
|
||||
os.chmod(key_path, 0o600)
|
||||
break
|
||||
|
||||
assert key, "AWS keypair {} not found for {}".format(key_name, key_path)
|
||||
assert os.path.exists(key_path), \
|
||||
"Private key file {} not found for {}".format(key_path, key_name)
|
||||
|
||||
print("KeyName not specified for nodes, using {}".format(key_name))
|
||||
|
||||
config["auth"]["ssh_private_key"] = key_path
|
||||
config["head_node"]["KeyName"] = key_name
|
||||
config["worker_nodes"]["KeyName"] = key_name
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def _configure_subnet(config):
|
||||
ec2 = _resource("ec2", config)
|
||||
subnets = sorted(
|
||||
[s for s in ec2.subnets.all() if s.state == "available"],
|
||||
reverse=True, # sort from Z-A
|
||||
key=lambda subnet: subnet.availability_zone)
|
||||
default_subnet = subnets[0]
|
||||
|
||||
if "SubnetId" not in config["head_node"]:
|
||||
config["head_node"]["SubnetId"] = default_subnet.id
|
||||
print("SubnetId not specified for head node, using {} in {}".format(
|
||||
default_subnet.id, default_subnet.availability_zone))
|
||||
|
||||
if "SubnetId" not in config["worker_nodes"]:
|
||||
config["worker_nodes"]["SubnetId"] = default_subnet.id
|
||||
print("SubnetId not specified for workers, using {} in {}".format(
|
||||
default_subnet.id, default_subnet.availability_zone))
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def _configure_security_group(config):
|
||||
if "SecurityGroupIds" in config["head_node"] and \
|
||||
"SecurityGroupIds" in config["worker_nodes"]:
|
||||
return config # have user-defined groups
|
||||
|
||||
group_name = SECURITY_GROUP_TEMPLATE.format(config["cluster_name"])
|
||||
subnet = _get_subnet_or_die(config, config["worker_nodes"]["SubnetId"])
|
||||
security_group = _get_security_group(config, subnet.vpc_id, group_name)
|
||||
|
||||
if security_group is None:
|
||||
print("Creating new security group {}".format(group_name))
|
||||
client = _client("ec2", config)
|
||||
client.create_security_group(
|
||||
Description="Auto-created security group for Ray workers",
|
||||
GroupName=group_name,
|
||||
VpcId=subnet.vpc_id)
|
||||
security_group = _get_security_group(config, subnet.vpc_id, group_name)
|
||||
assert security_group, "Failed to create security group"
|
||||
|
||||
if not security_group.ip_permissions:
|
||||
security_group.authorize_ingress(
|
||||
IpPermissions=[
|
||||
{"FromPort": -1, "ToPort": -1, "IpProtocol": "-1",
|
||||
"UserIdGroupPairs": [{"GroupId": security_group.id}]},
|
||||
{"FromPort": 22, "ToPort": 22, "IpProtocol": "TCP",
|
||||
"IpRanges": [{"CidrIp": "0.0.0.0/0"}]}])
|
||||
|
||||
if "SecurityGroupIds" not in config["head_node"]:
|
||||
print("SecurityGroupIds not specified for head node, using {}".format(
|
||||
security_group.group_name))
|
||||
config["head_node"]["SecurityGroupIds"] = [security_group.id]
|
||||
|
||||
if "SecurityGroupIds" not in config["worker_nodes"]:
|
||||
print("SecurityGroupIds not specified for workers, using {}".format(
|
||||
security_group.group_name))
|
||||
config["worker_nodes"]["SecurityGroupIds"] = [security_group.id]
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def _get_subnet_or_die(config, subnet_id):
|
||||
ec2 = _resource("ec2", config)
|
||||
subnet = list(
|
||||
ec2.subnets.filter(Filters=[
|
||||
{"Name": "subnet-id", "Values": [subnet_id]}]))
|
||||
assert len(subnet) == 1, "Subnet not found"
|
||||
subnet = subnet[0]
|
||||
return subnet
|
||||
|
||||
|
||||
def _get_security_group(config, vpc_id, group_name):
|
||||
ec2 = _resource("ec2", config)
|
||||
existing_groups = list(
|
||||
ec2.security_groups.filter(Filters=[
|
||||
{"Name": "vpc-id", "Values": [vpc_id]}]))
|
||||
for sg in existing_groups:
|
||||
if sg.group_name == group_name:
|
||||
return sg
|
||||
|
||||
|
||||
def _get_role(role_name, config):
|
||||
iam = _resource("iam", config)
|
||||
role = iam.Role(role_name)
|
||||
try:
|
||||
role.load()
|
||||
return role
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _get_instance_profile(profile_name, config):
|
||||
iam = _resource("iam", config)
|
||||
profile = iam.InstanceProfile(profile_name)
|
||||
try:
|
||||
profile.load()
|
||||
return profile
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _get_key(key_name, config):
|
||||
ec2 = _resource("ec2", config)
|
||||
for key in ec2.key_pairs.filter(
|
||||
Filters=[{"Name": "key-name", "Values": [key_name]}]):
|
||||
if key.name == key_name:
|
||||
return key
|
||||
|
||||
|
||||
def _client(name, config):
|
||||
return boto3.client(name, config["provider"]["region"])
|
||||
|
||||
|
||||
def _resource(name, config):
|
||||
return boto3.resource(name, config["provider"]["region"])
|
||||
@@ -0,0 +1,72 @@
|
||||
# An unique identifier for the head node and workers of this cluster.
|
||||
cluster_name: default
|
||||
|
||||
# The minimum number of workers nodes to launch in addition to the head
|
||||
# node. This number should be >= 0.
|
||||
min_workers: 0
|
||||
|
||||
# The maximum number of workers nodes to launch in addition to the head
|
||||
# node. This takes precedence over min_workers.
|
||||
max_workers: 2
|
||||
|
||||
# Cloud-provider specific configuration.
|
||||
provider:
|
||||
type: aws
|
||||
region: us-east-1
|
||||
|
||||
# How Ray will authenticate with newly launched nodes.
|
||||
auth:
|
||||
ssh_user: ubuntu
|
||||
# By default Ray creates a new private keypair, but you can also use your own.
|
||||
# If you do so, make sure to also set "KeyName" in the head and worker node
|
||||
# configurations below.
|
||||
# ssh_private_key: /path/to/your/key.pem
|
||||
|
||||
# Provider-specific config for the head node, e.g. instance type. By default
|
||||
# Ray will auto-configure unspecified fields such as SubnetId and KeyName.
|
||||
# For more documentation on available fields, see:
|
||||
# http://boto3.readthedocs.io/en/latest/reference/services/ec2.html#EC2.ServiceResource.create_instances
|
||||
head_node:
|
||||
InstanceType: m5.large
|
||||
ImageId: ami-212d465b
|
||||
|
||||
# Additional options in the boto docs.
|
||||
|
||||
# Provider-specific config for worker nodes, e.g. instance type. By default
|
||||
# Ray will auto-configure unspecified fields such as SubnetId and KeyName.
|
||||
# For more documentation on available fields, see:
|
||||
# http://boto3.readthedocs.io/en/latest/reference/services/ec2.html#EC2.ServiceResource.create_instances
|
||||
worker_nodes:
|
||||
InstanceType: m5.large
|
||||
ImageId: ami-212d465b
|
||||
|
||||
# Run workers on spot by default. Comment this out to use on-demand.
|
||||
InstanceMarketOptions:
|
||||
MarketType: spot
|
||||
# Additional options can be found in the boto docs, e.g.
|
||||
# SpotOptions:
|
||||
# MaxPrice: MAX_HOURLY_PRICE
|
||||
|
||||
# Additional options in the boto docs.
|
||||
|
||||
# Files or directories to copy to the head and worker nodes. The format is a
|
||||
# dictionary from REMOTE_PATH: LOCAL_PATH, e.g.
|
||||
file_mounts: {
|
||||
# "/path1/on/remote/machine": "/path1/on/local/machine",
|
||||
# "/path2/on/remote/machine": "/path2/on/local/machine",
|
||||
}
|
||||
|
||||
# List of shell commands to run to initialize the head node.
|
||||
head_init_commands:
|
||||
- cd ~/ray; git remote add eric https://github.com/ericl/ray.git || true
|
||||
- cd ~/ray; git fetch eric && git reset --hard e1e97b3
|
||||
- yes | ~/anaconda3/bin/conda install boto3=1.4.8 # 1.4.8 adds InstanceMarketOptions
|
||||
- ~/.local/bin/ray stop
|
||||
- ~/.local/bin/ray start --head --redis-port=6379 --autoscaling-config=~/ray_bootstrap_config.yaml
|
||||
|
||||
# List of shell commands to run to initialize workers.
|
||||
worker_init_commands:
|
||||
- cd ~/ray; git remote add eric https://github.com/ericl/ray.git || true
|
||||
- cd ~/ray; git fetch eric && git reset --hard e1e97b3
|
||||
- ~/.local/bin/ray stop
|
||||
- ~/.local/bin/ray start --redis-address=$RAY_HEAD_IP:6379
|
||||
@@ -0,0 +1,95 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import boto3
|
||||
|
||||
from ray.autoscaler.node_provider import NodeProvider
|
||||
from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME
|
||||
|
||||
|
||||
class AWSNodeProvider(NodeProvider):
|
||||
def __init__(self, provider_config, cluster_name):
|
||||
NodeProvider.__init__(self, provider_config, cluster_name)
|
||||
self.ec2 = boto3.resource("ec2", region_name=provider_config["region"])
|
||||
|
||||
def nodes(self, tag_filters):
|
||||
filters = [
|
||||
{
|
||||
"Name": "instance-state-name",
|
||||
"Values": ["pending", "running"],
|
||||
},
|
||||
{
|
||||
"Name": "tag:{}".format(TAG_RAY_CLUSTER_NAME),
|
||||
"Values": [self.cluster_name],
|
||||
},
|
||||
]
|
||||
for k, v in tag_filters.items():
|
||||
filters.append({
|
||||
"Name": "tag:{}".format(k),
|
||||
"Values": [v],
|
||||
})
|
||||
instances = list(self.ec2.instances.filter(Filters=filters))
|
||||
return [i.id for i in instances]
|
||||
|
||||
def is_running(self, node_id):
|
||||
node = self._node(node_id)
|
||||
return node.state["Name"] == "running"
|
||||
|
||||
def is_terminated(self, node_id):
|
||||
node = self._node(node_id)
|
||||
state = node.state["Name"]
|
||||
return state not in ["running", "pending"]
|
||||
|
||||
def node_tags(self, node_id):
|
||||
node = self._node(node_id)
|
||||
tags = {}
|
||||
for tag in node.tags:
|
||||
tags[tag["Key"]] = tag["Value"]
|
||||
return tags
|
||||
|
||||
def external_ip(self, node_id):
|
||||
node = self._node(node_id)
|
||||
return node.public_ip_address
|
||||
|
||||
def set_node_tags(self, node_id, tags):
|
||||
node = self._node(node_id)
|
||||
tag_pairs = []
|
||||
for k, v in tags.items():
|
||||
tag_pairs.append({
|
||||
"Key": k, "Value": v,
|
||||
})
|
||||
node.create_tags(Tags=tag_pairs)
|
||||
|
||||
def create_node(self, node_config, tags, count):
|
||||
conf = node_config.copy()
|
||||
tag_pairs = [{
|
||||
"Key": TAG_RAY_CLUSTER_NAME,
|
||||
"Value": self.cluster_name,
|
||||
}]
|
||||
for k, v in tags.items():
|
||||
tag_pairs.append(
|
||||
{
|
||||
"Key": k,
|
||||
"Value": v,
|
||||
})
|
||||
conf.update({
|
||||
"MinCount": 1,
|
||||
"MaxCount": count,
|
||||
"TagSpecifications": conf.get("TagSpecifications", []) + [
|
||||
{
|
||||
"ResourceType": "instance",
|
||||
"Tags": tag_pairs,
|
||||
}
|
||||
]
|
||||
})
|
||||
self.ec2.create_instances(**conf)
|
||||
|
||||
def terminate_node(self, node_id):
|
||||
node = self._node(node_id)
|
||||
node.terminate()
|
||||
|
||||
def _node(self, node_id):
|
||||
matches = list(self.ec2.instances.filter(InstanceIds=[node_id]))
|
||||
assert len(matches) == 1, "Invalid instance id {}".format(node_id)
|
||||
return matches[0]
|
||||
@@ -0,0 +1,146 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import json
|
||||
import tempfile
|
||||
import time
|
||||
import sys
|
||||
|
||||
import yaml
|
||||
|
||||
from ray.autoscaler.autoscaler import validate_config, hash_runtime_conf, \
|
||||
hash_launch_conf
|
||||
from ray.autoscaler.node_provider import get_node_provider, NODE_PROVIDERS
|
||||
from ray.autoscaler.tags import TAG_RAY_NODE_TYPE, TAG_RAY_LAUNCH_CONFIG, \
|
||||
TAG_RAY_RUNTIME_CONFIG, TAG_NAME
|
||||
from ray.autoscaler.updater import NodeUpdaterProcess
|
||||
|
||||
|
||||
def create_or_update_cluster(
|
||||
config_file, override_min_workers, override_max_workers, sync_only):
|
||||
"""Create or updates an autoscaling Ray cluster from a config json."""
|
||||
|
||||
config = yaml.load(open(config_file).read())
|
||||
|
||||
validate_config(config)
|
||||
if override_min_workers is not None:
|
||||
config["min_workers"] = override_min_workers
|
||||
if override_max_workers is not None:
|
||||
config["max_workers"] = override_max_workers
|
||||
if sync_only:
|
||||
config["worker_init_commands"] = []
|
||||
config["head_init_commands"] = []
|
||||
|
||||
importer = NODE_PROVIDERS.get(config["provider"]["type"])
|
||||
if not importer:
|
||||
raise NotImplementedError(
|
||||
"Unsupported provider {}".format(config["provider"]))
|
||||
|
||||
bootstrap_config, _ = importer()
|
||||
config = bootstrap_config(config)
|
||||
get_or_create_head_node(config)
|
||||
|
||||
|
||||
def teardown_cluster(config_file):
|
||||
"""Destroys all nodes of a Ray cluster described by a config json."""
|
||||
|
||||
config = yaml.load(open(config_file).read())
|
||||
|
||||
validate_config(config)
|
||||
provider = get_node_provider(config["provider"], config["cluster_name"])
|
||||
head_node_tags = {
|
||||
TAG_RAY_NODE_TYPE: "Head",
|
||||
}
|
||||
for node in provider.nodes(head_node_tags):
|
||||
print("Terminating head node {}".format(node))
|
||||
provider.terminate_node(node)
|
||||
nodes = provider.nodes({})
|
||||
while nodes:
|
||||
for node in nodes:
|
||||
print("Terminating worker {}".format(node))
|
||||
provider.terminate_node(node)
|
||||
time.sleep(5)
|
||||
nodes = provider.nodes({})
|
||||
|
||||
|
||||
def get_or_create_head_node(config):
|
||||
"""Create the cluster head node, which in turn creates the workers."""
|
||||
|
||||
provider = get_node_provider(config["provider"], config["cluster_name"])
|
||||
head_node_tags = {
|
||||
TAG_RAY_NODE_TYPE: "Head",
|
||||
}
|
||||
nodes = provider.nodes(head_node_tags)
|
||||
if len(nodes) > 0:
|
||||
head_node = nodes[0]
|
||||
else:
|
||||
head_node = None
|
||||
|
||||
launch_hash = hash_launch_conf(config["head_node"], config["auth"])
|
||||
if head_node is None or provider.node_tags(head_node).get(
|
||||
TAG_RAY_LAUNCH_CONFIG) != launch_hash:
|
||||
if head_node is not None:
|
||||
print("Terminating outdated head node {}".format(head_node))
|
||||
provider.terminate_node(head_node)
|
||||
print("Launching new head node...")
|
||||
head_node_tags[TAG_RAY_LAUNCH_CONFIG] = launch_hash
|
||||
head_node_tags[TAG_NAME] = "ray-{}-head".format(config["cluster_name"])
|
||||
provider.create_node(config["head_node"], head_node_tags, 1)
|
||||
|
||||
nodes = provider.nodes(head_node_tags)
|
||||
assert len(nodes) == 1, "Failed to create head node."
|
||||
head_node = nodes[0]
|
||||
|
||||
runtime_hash = hash_runtime_conf(config["file_mounts"], config)
|
||||
|
||||
if provider.node_tags(head_node).get(
|
||||
TAG_RAY_RUNTIME_CONFIG) != runtime_hash:
|
||||
print("Updating files on head node...")
|
||||
|
||||
# Rewrite the auth config so that the head node can update the workers
|
||||
remote_key_path = "~/ray_bootstrap_key.pem"
|
||||
remote_config = copy.deepcopy(config)
|
||||
remote_config["auth"]["ssh_private_key"] = remote_key_path
|
||||
|
||||
# Adjust for new file locations
|
||||
new_mounts = {}
|
||||
for remote_path in config["file_mounts"]:
|
||||
new_mounts[remote_path] = remote_path
|
||||
remote_config["file_mounts"] = new_mounts
|
||||
|
||||
# Now inject the rewritten config and SSH key into the head node
|
||||
remote_config_file = tempfile.NamedTemporaryFile(
|
||||
"w", prefix="ray-bootstrap-")
|
||||
remote_config_file.write(json.dumps(remote_config))
|
||||
remote_config_file.flush()
|
||||
config["file_mounts"].update({
|
||||
remote_key_path: config["auth"]["ssh_private_key"],
|
||||
"~/ray_bootstrap_config.yaml": remote_config_file.name
|
||||
})
|
||||
|
||||
updater = NodeUpdaterProcess(
|
||||
head_node,
|
||||
config["provider"],
|
||||
config["auth"],
|
||||
config["cluster_name"],
|
||||
config["file_mounts"],
|
||||
config["head_init_commands"],
|
||||
runtime_hash,
|
||||
redirect_output=False)
|
||||
updater.start()
|
||||
updater.join()
|
||||
if updater.exitcode != 0:
|
||||
print("Error: updating {} failed".format(
|
||||
provider.external_ip(head_node)))
|
||||
sys.exit(1)
|
||||
print(
|
||||
"Head node up-to-date, IP address is: {}".format(
|
||||
provider.external_ip(head_node)))
|
||||
print(
|
||||
"To monitor auto-scaling activity, you can run:\n\n"
|
||||
" ssh -i {} {}@{} 'tail -f /tmp/raylogs/monitor-*'\n".format(
|
||||
config["auth"]["ssh_private_key"],
|
||||
config["auth"]["ssh_user"],
|
||||
provider.external_ip(head_node)))
|
||||
@@ -0,0 +1,83 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
def import_aws():
|
||||
from ray.autoscaler.aws.config import bootstrap_aws
|
||||
from ray.autoscaler.aws.node_provider import AWSNodeProvider
|
||||
return bootstrap_aws, AWSNodeProvider
|
||||
|
||||
|
||||
NODE_PROVIDERS = {
|
||||
"aws": import_aws,
|
||||
"gce": None, # TODO: support more node providers
|
||||
"azure": None,
|
||||
"kubernetes": None,
|
||||
"docker": None,
|
||||
"local_cluster": None,
|
||||
}
|
||||
|
||||
|
||||
def get_node_provider(provider_config, cluster_name):
|
||||
importer = NODE_PROVIDERS.get(provider_config["type"])
|
||||
if importer is None:
|
||||
raise NotImplementedError(
|
||||
"Unsupported node provider: {}".format(provider_config["type"]))
|
||||
_, provider_cls = importer()
|
||||
return provider_cls(provider_config, cluster_name)
|
||||
|
||||
|
||||
class NodeProvider(object):
|
||||
"""Interface for getting and returning nodes from a Cloud.
|
||||
|
||||
NodeProviders are namespaced by the `cluster_name` parameter; they only
|
||||
operate on nodes within that namespace.
|
||||
|
||||
Nodes may be in one of three states: {pending, running, terminated}. Nodes
|
||||
appear immediately once started by `create_node`, and transition
|
||||
immediately to terminated when `terminate_node` is called.
|
||||
"""
|
||||
|
||||
def __init__(self, provider_config, cluster_name):
|
||||
self.provider_config = provider_config
|
||||
self.cluster_name = cluster_name
|
||||
|
||||
def nodes(self, tag_filters):
|
||||
"""Return a list of node ids filtered by the specified tags dict.
|
||||
|
||||
This list must not include terminated nodes.
|
||||
|
||||
Examples:
|
||||
>>> provider.nodes({TAG_RAY_NODE_TYPE: "Worker"})
|
||||
["node-1", "node-2"]
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def is_running(self, node_id):
|
||||
"""Return whether the specified node is running."""
|
||||
raise NotImplementedError
|
||||
|
||||
def is_terminated(self, node_id):
|
||||
"""Return whether the specified node is terminated."""
|
||||
raise NotImplementedError
|
||||
|
||||
def node_tags(self, node_id):
|
||||
"""Returns the tags of the given node (string dict)."""
|
||||
raise NotImplementedError
|
||||
|
||||
def external_ip(self, node_id):
|
||||
"""Returns the external ip of the given node."""
|
||||
raise NotImplementedError
|
||||
|
||||
def create_node(self, node_config, tags, count):
|
||||
"""Creates a number of nodes within the namespace."""
|
||||
raise NotImplementedError
|
||||
|
||||
def set_node_tags(self, node_id, tags):
|
||||
"""Sets the tag values (string dict) for the specified node."""
|
||||
raise NotImplementedError
|
||||
|
||||
def terminate_node(self, node_id):
|
||||
"""Terminates the specified node."""
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,23 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
"""The Ray autoscaler uses tags to associate metadata with instances."""
|
||||
|
||||
# Tag for the name of the node
|
||||
TAG_NAME = "Name"
|
||||
|
||||
# Tag uniquely identifying all nodes of a cluster
|
||||
TAG_RAY_CLUSTER_NAME = "ray:ClusterName"
|
||||
|
||||
# Tag for the type of node (e.g. Head, Worker)
|
||||
TAG_RAY_NODE_TYPE = "ray:NodeType"
|
||||
|
||||
# Tag that reports the current state of the node (e.g. Updating, Up-to-date)
|
||||
TAG_RAY_NODE_STATUS = "ray:NodeStatus"
|
||||
|
||||
# Hash of the node launch config, used to identify out-of-date nodes
|
||||
TAG_RAY_LAUNCH_CONFIG = "ray:LaunchConfig"
|
||||
|
||||
# Hash of the node runtime config, used to determine if updates are needed
|
||||
TAG_RAY_RUNTIME_CONFIG = "ray:RuntimeConfig"
|
||||
@@ -0,0 +1,172 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
from multiprocessing import Process
|
||||
from threading import Thread
|
||||
|
||||
from ray.autoscaler.node_provider import get_node_provider
|
||||
from ray.autoscaler.tags import TAG_RAY_NODE_STATUS, TAG_RAY_RUNTIME_CONFIG
|
||||
|
||||
# How long to wait for a node to start, in seconds
|
||||
NODE_START_WAIT_S = 300
|
||||
|
||||
|
||||
class NodeUpdater(object):
|
||||
"""A process for syncing files and running init commands on a node."""
|
||||
|
||||
def __init__(
|
||||
self, node_id, provider_config, auth_config, cluster_name,
|
||||
file_mounts, init_cmds, runtime_hash, redirect_output=True,
|
||||
process_runner=subprocess):
|
||||
self.daemon = True
|
||||
self.process_runner = process_runner
|
||||
self.provider = get_node_provider(provider_config, cluster_name)
|
||||
self.ssh_private_key = auth_config["ssh_private_key"]
|
||||
self.ssh_user = auth_config["ssh_user"]
|
||||
self.ssh_ip = self.provider.external_ip(node_id)
|
||||
self.node_id = node_id
|
||||
self.file_mounts = file_mounts
|
||||
self.init_cmds = init_cmds
|
||||
self.runtime_hash = runtime_hash
|
||||
if redirect_output:
|
||||
self.logfile = tempfile.NamedTemporaryFile(
|
||||
mode="w", prefix="node-updater-", delete=False)
|
||||
self.output_name = self.logfile.name
|
||||
self.stdout = self.logfile
|
||||
self.stderr = self.logfile
|
||||
else:
|
||||
self.logfile = None
|
||||
self.output_name = "(console)"
|
||||
self.stdout = sys.stdout
|
||||
self.stderr = sys.stderr
|
||||
|
||||
def run(self):
|
||||
print("NodeUpdater: Updating {} to {}, logging to {}".format(
|
||||
self.node_id, self.runtime_hash, self.output_name))
|
||||
try:
|
||||
self.do_update()
|
||||
except Exception as e:
|
||||
print(
|
||||
"NodeUpdater: Error updating {}, "
|
||||
"see {} for remote logs".format(e, self.output_name),
|
||||
file=self.stdout)
|
||||
self.provider.set_node_tags(
|
||||
self.node_id, {TAG_RAY_NODE_STATUS: "UpdateFailed"})
|
||||
if self.logfile is not None:
|
||||
print(
|
||||
"----- BEGIN REMOTE LOGS -----\n" +
|
||||
open(self.logfile.name).read() +
|
||||
"\n----- END REMOTE LOGS -----")
|
||||
raise e
|
||||
self.provider.set_node_tags(
|
||||
self.node_id, {
|
||||
TAG_RAY_NODE_STATUS: "Up-to-date",
|
||||
TAG_RAY_RUNTIME_CONFIG: self.runtime_hash
|
||||
})
|
||||
print(
|
||||
"NodeUpdater: Applied config {} to node {}".format(
|
||||
self.runtime_hash, self.node_id),
|
||||
file=self.stdout)
|
||||
|
||||
def do_update(self):
|
||||
self.provider.set_node_tags(
|
||||
self.node_id, {TAG_RAY_NODE_STATUS: "WaitingForSSH"})
|
||||
deadline = time.time() + NODE_START_WAIT_S
|
||||
|
||||
# Wait for external IP
|
||||
while time.time() < deadline and \
|
||||
not self.provider.is_terminated(self.node_id):
|
||||
print(
|
||||
"NodeUpdater: Waiting for IP of {}...".format(self.node_id),
|
||||
file=self.stdout)
|
||||
self.ssh_ip = self.provider.external_ip(self.node_id)
|
||||
if self.ssh_ip is not None:
|
||||
break
|
||||
time.sleep(5)
|
||||
assert self.ssh_ip is not None, "Unable to find IP of node"
|
||||
|
||||
# Wait for SSH access
|
||||
while time.time() < deadline and \
|
||||
not self.provider.is_terminated(self.node_id):
|
||||
try:
|
||||
print(
|
||||
"NodeUpdater: Waiting for SSH to {}...".format(
|
||||
self.node_id),
|
||||
file=self.stdout)
|
||||
if not self.provider.is_running(self.node_id):
|
||||
raise Exception()
|
||||
self.ssh_cmd(
|
||||
"uptime",
|
||||
connect_timeout=5, redirect=open("/dev/null", "w"))
|
||||
except Exception as e:
|
||||
print(
|
||||
"NodeUpdater: SSH not up, retrying: {}".format(e),
|
||||
file=self.stdout)
|
||||
time.sleep(5)
|
||||
else:
|
||||
break
|
||||
assert not self.provider.is_terminated(self.node_id)
|
||||
|
||||
# Rsync file mounts
|
||||
self.provider.set_node_tags(
|
||||
self.node_id, {TAG_RAY_NODE_STATUS: "SyncingFiles"})
|
||||
for remote_path, local_path in self.file_mounts.items():
|
||||
print(
|
||||
"NodeUpdater: Syncing {} to {}...".format(
|
||||
local_path, remote_path),
|
||||
file=self.stdout)
|
||||
assert os.path.exists(local_path)
|
||||
if os.path.isdir(local_path):
|
||||
if not local_path.endswith("/"):
|
||||
local_path += "/"
|
||||
if not remote_path.endswith("/"):
|
||||
remote_path += "/"
|
||||
self.ssh_cmd(
|
||||
"mkdir -p {}".format(os.path.dirname(remote_path)))
|
||||
self.process_runner.check_call([
|
||||
"rsync", "-e", "ssh -i {} ".format(self.ssh_private_key) +
|
||||
"-o ConnectTimeout=60s -o StrictHostKeyChecking=no",
|
||||
"--delete", "-avz", "{}".format(local_path),
|
||||
"{}@{}:{}".format(self.ssh_user, self.ssh_ip, remote_path)
|
||||
], stdout=self.stdout, stderr=self.stderr)
|
||||
|
||||
# Run init commands
|
||||
self.provider.set_node_tags(
|
||||
self.node_id, {TAG_RAY_NODE_STATUS: "RunningInitCmds"})
|
||||
for cmd in self.init_cmds:
|
||||
self.ssh_cmd(cmd, verbose=True)
|
||||
|
||||
def ssh_cmd(self, cmd, connect_timeout=60, redirect=None, verbose=False):
|
||||
if verbose:
|
||||
print(
|
||||
"NodeUpdater: running {} on {}...".format(
|
||||
cmd, self.ssh_ip),
|
||||
file=self.stdout)
|
||||
self.process_runner.check_call([
|
||||
"ssh", "-o", "ConnectTimeout={}s".format(connect_timeout),
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-i", self.ssh_private_key,
|
||||
"{}@{}".format(self.ssh_user, self.ssh_ip),
|
||||
cmd,
|
||||
], stdout=redirect or self.stdout, stderr=redirect or self.stderr)
|
||||
|
||||
|
||||
class NodeUpdaterProcess(NodeUpdater, Process):
|
||||
def __init__(self, *args, **kwargs):
|
||||
Process.__init__(self)
|
||||
NodeUpdater.__init__(self, *args, **kwargs)
|
||||
|
||||
|
||||
# Single-threaded version for unit tests
|
||||
class NodeUpdaterThread(NodeUpdater, Thread):
|
||||
def __init__(self, *args, **kwargs):
|
||||
Thread.__init__(self)
|
||||
NodeUpdater.__init__(self, *args, **kwargs)
|
||||
self.exitcode = 0
|
||||
+21
-2
@@ -5,6 +5,7 @@ from __future__ import print_function
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from collections import Counter, defaultdict
|
||||
|
||||
@@ -15,6 +16,7 @@ import redis
|
||||
from ray.core.generated.DriverTableMessage import DriverTableMessage
|
||||
from ray.core.generated.SubscribeToDBClientTableReply import \
|
||||
SubscribeToDBClientTableReply
|
||||
from ray.autoscaler.autoscaler import StandardAutoscaler
|
||||
from ray.core.generated.TaskInfo import TaskInfo
|
||||
from ray.services import get_ip_address, get_port
|
||||
from ray.utils import binary_to_hex, binary_to_object_id, hex_to_binary
|
||||
@@ -75,7 +77,7 @@ class Monitor(object):
|
||||
managers that were up at one point and have died since then.
|
||||
"""
|
||||
|
||||
def __init__(self, redis_address, redis_port):
|
||||
def __init__(self, redis_address, redis_port, autoscaling_config):
|
||||
# Initialize the Redis clients.
|
||||
self.state = ray.experimental.state.GlobalState()
|
||||
self.state._initialize_global_state(redis_address, redis_port)
|
||||
@@ -90,6 +92,10 @@ class Monitor(object):
|
||||
self.dead_local_schedulers = set()
|
||||
self.live_plasma_managers = Counter()
|
||||
self.dead_plasma_managers = set()
|
||||
if autoscaling_config:
|
||||
self.autoscaler = StandardAutoscaler(autoscaling_config)
|
||||
else:
|
||||
self.autoscaler = None
|
||||
|
||||
def subscribe(self, channel):
|
||||
"""Subscribe to the given channel.
|
||||
@@ -556,6 +562,9 @@ class Monitor(object):
|
||||
|
||||
# Handle messages from the subscription channels.
|
||||
while True:
|
||||
# Process autoscaling actions
|
||||
if self.autoscaler:
|
||||
self.autoscaler.update()
|
||||
# Record how many dead local schedulers and plasma managers we had
|
||||
# at the beginning of this round.
|
||||
num_dead_local_schedulers = len(self.dead_local_schedulers)
|
||||
@@ -604,6 +613,11 @@ if __name__ == "__main__":
|
||||
required=True,
|
||||
type=str,
|
||||
help="the address to use for Redis")
|
||||
parser.add_argument(
|
||||
"--autoscaling-config",
|
||||
required=False,
|
||||
type=str,
|
||||
help="the path to the autoscaling config file")
|
||||
args = parser.parse_args()
|
||||
|
||||
redis_ip_address = get_ip_address(args.redis_address)
|
||||
@@ -612,5 +626,10 @@ if __name__ == "__main__":
|
||||
# Initialize the global state.
|
||||
ray.global_state._initialize_global_state(redis_ip_address, redis_port)
|
||||
|
||||
monitor = Monitor(redis_ip_address, redis_port)
|
||||
if args.autoscaling_config:
|
||||
autoscaling_config = os.path.expanduser(args.autoscaling_config)
|
||||
else:
|
||||
autoscaling_config = None
|
||||
|
||||
monitor = Monitor(redis_ip_address, redis_port, autoscaling_config)
|
||||
monitor.run()
|
||||
|
||||
@@ -7,6 +7,7 @@ import json
|
||||
import subprocess
|
||||
|
||||
import ray.services as services
|
||||
from ray.autoscaler.commands import create_or_update_cluster, teardown_cluster
|
||||
|
||||
|
||||
def check_no_existing_redis_clients(node_ip_address, redis_client):
|
||||
@@ -76,10 +77,12 @@ def cli():
|
||||
help="object store directory for memory mapped files")
|
||||
@click.option("--huge-pages", is_flag=True, default=False,
|
||||
help="enable support for huge pages in the object store")
|
||||
@click.option("--autoscaling-config", required=False, type=str,
|
||||
help="the file that contains the autoscaling config")
|
||||
def start(node_ip_address, redis_address, redis_port, num_redis_shards,
|
||||
redis_max_clients, object_manager_port, num_workers, num_cpus,
|
||||
num_gpus, resources, head, no_ui, block, plasma_directory,
|
||||
huge_pages):
|
||||
huge_pages, autoscaling_config):
|
||||
# Note that we redirect stdout and stderr to /dev/null because otherwise
|
||||
# attempts to print may cause exceptions if a process is started inside of
|
||||
# an SSH connection and the SSH connection dies. TODO(rkn): This is a
|
||||
@@ -138,7 +141,8 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
|
||||
redis_max_clients=redis_max_clients,
|
||||
include_webui=(not no_ui),
|
||||
plasma_directory=plasma_directory,
|
||||
huge_pages=huge_pages)
|
||||
huge_pages=huge_pages,
|
||||
autoscaling_config=autoscaling_config)
|
||||
print(address_info)
|
||||
print("\nStarted Ray on this node. You can add additional nodes to "
|
||||
"the cluster by calling\n\n"
|
||||
@@ -238,8 +242,35 @@ def stop():
|
||||
"awk '{ print $2 }') 2> /dev/null"], shell=True)
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.argument("cluster_config_file", required=True, type=str)
|
||||
@click.option(
|
||||
"--sync-only", is_flag=True, default=False, help=(
|
||||
"Whether to only perform the file sync stage when updating nodes. "
|
||||
"This avoids interrupting running jobs. You can use this when "
|
||||
"resizing the cluster with the min/max_workers flag."))
|
||||
@click.option(
|
||||
"--min-workers", required=False, type=int, help=(
|
||||
"Override the configured min worker node count for the cluster."))
|
||||
@click.option(
|
||||
"--max-workers", required=False, type=int, help=(
|
||||
"Override the configured max worker node count for the cluster."))
|
||||
def create_or_update(
|
||||
cluster_config_file, min_workers, max_workers, sync_only):
|
||||
create_or_update_cluster(
|
||||
cluster_config_file, min_workers, max_workers, sync_only)
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.argument("cluster_config_file", required=True, type=str)
|
||||
def teardown(cluster_config_file):
|
||||
teardown_cluster(cluster_config_file)
|
||||
|
||||
|
||||
cli.add_command(start)
|
||||
cli.add_command(stop)
|
||||
cli.add_command(create_or_update)
|
||||
cli.add_command(teardown)
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
+19
-8
@@ -4,6 +4,7 @@ from __future__ import print_function
|
||||
|
||||
import binascii
|
||||
from collections import namedtuple, OrderedDict
|
||||
from datetime import datetime
|
||||
import cloudpickle
|
||||
import json
|
||||
import os
|
||||
@@ -863,7 +864,7 @@ def start_worker(node_ip_address, object_store_name, object_store_manager_name,
|
||||
|
||||
|
||||
def start_monitor(redis_address, node_ip_address, stdout_file=None,
|
||||
stderr_file=None, cleanup=True):
|
||||
stderr_file=None, cleanup=True, autoscaling_config=None):
|
||||
"""Run a process to monitor the other processes.
|
||||
|
||||
Args:
|
||||
@@ -878,12 +879,15 @@ def start_monitor(redis_address, node_ip_address, stdout_file=None,
|
||||
then this process will be killed by services.cleanup() when the
|
||||
Python process that imported services exits. This is True by
|
||||
default.
|
||||
autoscaling_config: path to autoscaling config file.
|
||||
"""
|
||||
monitor_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
||||
"monitor.py")
|
||||
command = [sys.executable,
|
||||
monitor_path,
|
||||
"--redis-address=" + str(redis_address)]
|
||||
if autoscaling_config:
|
||||
command.append("--autoscaling-config=" + str(autoscaling_config))
|
||||
p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
|
||||
if cleanup:
|
||||
all_processes[PROCESS_TYPE_WORKER].append(p)
|
||||
@@ -908,7 +912,8 @@ def start_ray_processes(address_info=None,
|
||||
start_workers_from_local_scheduler=True,
|
||||
resources=None,
|
||||
plasma_directory=None,
|
||||
huge_pages=False):
|
||||
huge_pages=False,
|
||||
autoscaling_config=None):
|
||||
"""Helper method to start Ray processes.
|
||||
|
||||
Args:
|
||||
@@ -956,6 +961,7 @@ def start_ray_processes(address_info=None,
|
||||
be created.
|
||||
huge_pages: Boolean flag indicating whether to start the Object
|
||||
Store with hugetlbfs support. Requires plasma_directory.
|
||||
autoscaling_config: path to autoscaling config file.
|
||||
|
||||
Returns:
|
||||
A dictionary of the address information for the processes that were
|
||||
@@ -1006,7 +1012,8 @@ def start_ray_processes(address_info=None,
|
||||
node_ip_address,
|
||||
stdout_file=monitor_stdout_file,
|
||||
stderr_file=monitor_stderr_file,
|
||||
cleanup=cleanup)
|
||||
cleanup=cleanup,
|
||||
autoscaling_config=autoscaling_config)
|
||||
|
||||
if redis_shards == []:
|
||||
# Get redis shards from primary redis instance.
|
||||
@@ -1221,7 +1228,8 @@ def start_ray_head(address_info=None,
|
||||
redis_max_clients=None,
|
||||
include_webui=True,
|
||||
plasma_directory=None,
|
||||
huge_pages=False):
|
||||
huge_pages=False,
|
||||
autoscaling_config=None):
|
||||
"""Start Ray in local mode.
|
||||
|
||||
Args:
|
||||
@@ -1263,6 +1271,7 @@ def start_ray_head(address_info=None,
|
||||
be created.
|
||||
huge_pages: Boolean flag indicating whether to start the Object
|
||||
Store with hugetlbfs support. Requires plasma_directory.
|
||||
autoscaling_config: path to autoscaling config file.
|
||||
|
||||
Returns:
|
||||
A dictionary of the address information for the processes that were
|
||||
@@ -1287,7 +1296,8 @@ def start_ray_head(address_info=None,
|
||||
num_redis_shards=num_redis_shards,
|
||||
redis_max_clients=redis_max_clients,
|
||||
plasma_directory=plasma_directory,
|
||||
huge_pages=huge_pages)
|
||||
huge_pages=huge_pages,
|
||||
autoscaling_config=autoscaling_config)
|
||||
|
||||
|
||||
def try_to_create_directory(directory_path):
|
||||
@@ -1333,9 +1343,10 @@ def new_log_files(name, redirect_output):
|
||||
# Create another directory that will be used by some of the RL algorithms.
|
||||
try_to_create_directory("/tmp/ray")
|
||||
|
||||
log_id = random.randint(0, 1000000000)
|
||||
log_stdout = "{}/{}-{:010d}.out".format(logs_dir, name, log_id)
|
||||
log_stderr = "{}/{}-{:010d}.err".format(logs_dir, name, log_id)
|
||||
log_id = random.randint(0, 10000)
|
||||
date_str = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
log_stdout = "{}/{}-{}-{:05d}.out".format(logs_dir, name, date_str, log_id)
|
||||
log_stderr = "{}/{}-{}-{:05d}.err".format(logs_dir, name, date_str, log_id)
|
||||
log_stdout_file = open(log_stdout, "a")
|
||||
log_stderr_file = open(log_stderr, "a")
|
||||
return log_stdout_file, log_stderr_file
|
||||
|
||||
@@ -112,6 +112,7 @@ setup(name="ray",
|
||||
"colorama",
|
||||
"psutil",
|
||||
"pytest",
|
||||
"pyyaml",
|
||||
"redis",
|
||||
"cloudpickle == 0.5.2",
|
||||
# The six module is required by pyarrow.
|
||||
|
||||
Reference in New Issue
Block a user