mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 16:58:23 +08:00
[autoscaler] cache stopped nodes, no screen on attach (#5741)
This commit is contained in:
@@ -77,6 +77,11 @@ CLUSTER_CONFIG_SCHEMA = {
|
||||
"worker_ips": (list, OPTIONAL), # local cluster worker nodes
|
||||
"use_internal_ips": (bool, OPTIONAL), # don't require public ips
|
||||
"extra_config": (dict, OPTIONAL), # provider-specific config
|
||||
|
||||
# Whether to try to reuse previously stopped nodes instead of
|
||||
# launching nodes. This will also cause the autoscaler to stop
|
||||
# nodes instead of terminating them. Only implemented for AWS.
|
||||
"cache_stopped_nodes": (bool, OPTIONAL),
|
||||
},
|
||||
REQUIRED),
|
||||
|
||||
@@ -545,9 +550,9 @@ class StandardAutoscaler(object):
|
||||
T = [
|
||||
threading.Thread(
|
||||
target=self.spawn_updater,
|
||||
args=(node_id, commands),
|
||||
) for node_id, commands in (self.should_update(node_id)
|
||||
for node_id in nodes)
|
||||
args=(node_id, commands, ray_start),
|
||||
) for node_id, commands, ray_start in (self.should_update(node_id)
|
||||
for node_id in nodes)
|
||||
if node_id is not None
|
||||
]
|
||||
for t in T:
|
||||
@@ -648,7 +653,8 @@ class StandardAutoscaler(object):
|
||||
cluster_name=self.config["cluster_name"],
|
||||
file_mounts={},
|
||||
initialization_commands=[],
|
||||
setup_commands=with_head_node_ip(
|
||||
setup_commands=[],
|
||||
ray_start_commands=with_head_node_ip(
|
||||
self.config["worker_start_ray_commands"]),
|
||||
runtime_hash=self.runtime_hash,
|
||||
process_runner=self.process_runner,
|
||||
@@ -658,23 +664,25 @@ class StandardAutoscaler(object):
|
||||
|
||||
def should_update(self, node_id):
|
||||
if not self.can_update(node_id):
|
||||
return (None, None)
|
||||
return (None, None, None)
|
||||
|
||||
if self.files_up_to_date(node_id):
|
||||
return (None, None)
|
||||
return (None, None, None)
|
||||
|
||||
successful_updated = self.num_successful_updates.get(node_id, 0) > 0
|
||||
if successful_updated and self.config.get("restart_only", False):
|
||||
init_commands = self.config["worker_start_ray_commands"]
|
||||
init_commands = []
|
||||
ray_commands = self.config["worker_start_ray_commands"]
|
||||
elif successful_updated and self.config.get("no_restart", False):
|
||||
init_commands = self.config["worker_setup_commands"]
|
||||
ray_commands = []
|
||||
else:
|
||||
init_commands = (self.config["worker_setup_commands"] +
|
||||
self.config["worker_start_ray_commands"])
|
||||
init_commands = self.config["worker_setup_commands"]
|
||||
ray_commands = self.config["worker_start_ray_commands"]
|
||||
|
||||
return (node_id, init_commands)
|
||||
return (node_id, init_commands, ray_commands)
|
||||
|
||||
def spawn_updater(self, node_id, init_commands):
|
||||
def spawn_updater(self, node_id, init_commands, ray_start_commands):
|
||||
updater = NodeUpdaterThread(
|
||||
node_id=node_id,
|
||||
provider_config=self.config["provider"],
|
||||
@@ -685,6 +693,7 @@ class StandardAutoscaler(object):
|
||||
initialization_commands=with_head_node_ip(
|
||||
self.config["initialization_commands"]),
|
||||
setup_commands=with_head_node_ip(init_commands),
|
||||
ray_start_commands=with_head_node_ip(ray_start_commands),
|
||||
runtime_hash=self.runtime_hash,
|
||||
process_runner=self.process_runner,
|
||||
use_internal_ip=True)
|
||||
|
||||
@@ -12,7 +12,8 @@ import botocore
|
||||
from botocore.config import Config
|
||||
|
||||
from ray.autoscaler.node_provider import NodeProvider
|
||||
from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME
|
||||
from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME, \
|
||||
TAG_RAY_LAUNCH_CONFIG, TAG_RAY_NODE_TYPE
|
||||
from ray.ray_constants import BOTO_MAX_RETRIES
|
||||
from ray.autoscaler.log_timer import LogTimer
|
||||
|
||||
@@ -41,6 +42,8 @@ class AWSNodeProvider(NodeProvider):
|
||||
def __init__(self, provider_config, cluster_name):
|
||||
NodeProvider.__init__(self, provider_config, cluster_name)
|
||||
config = Config(retries={"max_attempts": BOTO_MAX_RETRIES})
|
||||
self.cache_stopped_nodes = provider_config.get("cache_stopped_nodes",
|
||||
True)
|
||||
self.ec2 = boto3.resource(
|
||||
"ec2", region_name=provider_config["region"], config=config)
|
||||
|
||||
@@ -61,7 +64,7 @@ class AWSNodeProvider(NodeProvider):
|
||||
self.cached_nodes = {}
|
||||
|
||||
def _node_tag_update_loop(self):
|
||||
""" Update the AWS tags for a cluster periodically.
|
||||
"""Update the AWS tags for a cluster periodically.
|
||||
|
||||
The purpose of this loop is to avoid excessive EC2 calls when a large
|
||||
number of nodes are being launched simultaneously.
|
||||
@@ -171,6 +174,55 @@ class AWSNodeProvider(NodeProvider):
|
||||
self.tag_cache_update_event.set()
|
||||
|
||||
def create_node(self, node_config, tags, count):
|
||||
# Try to reuse previously stopped nodes with compatible configs
|
||||
if self.cache_stopped_nodes:
|
||||
filters = [
|
||||
{
|
||||
"Name": "instance-state-name",
|
||||
"Values": ["stopped", "stopping"],
|
||||
},
|
||||
{
|
||||
"Name": "tag:{}".format(TAG_RAY_CLUSTER_NAME),
|
||||
"Values": [self.cluster_name],
|
||||
},
|
||||
{
|
||||
"Name": "tag:{}".format(TAG_RAY_NODE_TYPE),
|
||||
"Values": [tags[TAG_RAY_NODE_TYPE]],
|
||||
},
|
||||
{
|
||||
"Name": "tag:{}".format(TAG_RAY_LAUNCH_CONFIG),
|
||||
"Values": [tags[TAG_RAY_LAUNCH_CONFIG]],
|
||||
},
|
||||
]
|
||||
|
||||
reuse_nodes = list(
|
||||
self.ec2.instances.filter(Filters=filters))[:count]
|
||||
reuse_node_ids = [n.id for n in reuse_nodes]
|
||||
if reuse_nodes:
|
||||
logger.info("AWSNodeProvider: reusing instances {}. "
|
||||
"To disable reuse, set "
|
||||
"'cache_stopped_nodes: False' in the provider "
|
||||
"config.".format(reuse_node_ids))
|
||||
|
||||
for node in reuse_nodes:
|
||||
self.tag_cache[node.id] = from_aws_format(
|
||||
{x["Key"]: x["Value"]
|
||||
for x in node.tags})
|
||||
if node.state["Name"] == "stopping":
|
||||
logger.info("AWSNodeProvider: waiting for instance "
|
||||
"{} to fully stop...".format(node.id))
|
||||
node.wait_until_stopped()
|
||||
|
||||
self.ec2.meta.client.start_instances(
|
||||
InstanceIds=reuse_node_ids)
|
||||
for node_id in reuse_node_ids:
|
||||
self.set_node_tags(node_id, tags)
|
||||
count -= len(reuse_node_ids)
|
||||
|
||||
if count:
|
||||
self._create_node(node_config, tags, count)
|
||||
|
||||
def _create_node(self, node_config, tags, count):
|
||||
tags = to_aws_format(tags)
|
||||
conf = node_config.copy()
|
||||
|
||||
@@ -248,13 +300,43 @@ class AWSNodeProvider(NodeProvider):
|
||||
|
||||
def terminate_node(self, node_id):
|
||||
node = self._get_cached_node(node_id)
|
||||
node.terminate()
|
||||
if self.cache_stopped_nodes:
|
||||
if node.spot_instance_request_id:
|
||||
logger.info(
|
||||
"AWSNodeProvider: terminating node {} (spot nodes cannot "
|
||||
"be stopped, only terminated)".format(node_id))
|
||||
node.terminate()
|
||||
else:
|
||||
logger.info(
|
||||
"AWSNodeProvider: stopping node {}. To terminate nodes "
|
||||
"on stop, set 'cache_stopped_nodes: False' in the "
|
||||
"provider config.".format(node_id))
|
||||
node.stop()
|
||||
else:
|
||||
node.terminate()
|
||||
|
||||
self.tag_cache.pop(node_id, None)
|
||||
self.tag_cache_pending.pop(node_id, None)
|
||||
|
||||
def terminate_nodes(self, node_ids):
|
||||
self.ec2.meta.client.terminate_instances(InstanceIds=node_ids)
|
||||
if not node_ids:
|
||||
return
|
||||
|
||||
node0 = self._get_cached_node(node_ids[0])
|
||||
if self.cache_stopped_nodes:
|
||||
if node0.spot_instance_request_id:
|
||||
logger.info(
|
||||
"AWSNodeProvider: terminating nodes {} (spot nodes cannot "
|
||||
"be stopped, only terminated)".format(node_ids))
|
||||
self.ec2.meta.client.terminate_instances(InstanceIds=node_ids)
|
||||
else:
|
||||
logger.info(
|
||||
"AWSNodeProvider: stopping nodes {}. To terminate nodes "
|
||||
"on stop, set 'cache_stopped_nodes: False' in the "
|
||||
"provider config.".format(node_ids))
|
||||
self.ec2.meta.client.stop_instances(InstanceIds=node_ids)
|
||||
else:
|
||||
self.ec2.meta.client.terminate_instances(InstanceIds=node_ids)
|
||||
|
||||
for node_id in node_ids:
|
||||
self.tag_cache.pop(node_id, None)
|
||||
|
||||
@@ -105,10 +105,10 @@ def teardown_cluster(config_file, yes, workers_only, override_cluster_name):
|
||||
# Loop here to check that both the head and worker nodes are actually
|
||||
# really gone
|
||||
A = remaining_nodes()
|
||||
with LogTimer("teardown_cluster: Termination done."):
|
||||
with LogTimer("teardown_cluster: done."):
|
||||
while A:
|
||||
logger.info("teardown_cluster: "
|
||||
"Terminating {} nodes...".format(len(A)))
|
||||
"Shutting down {} nodes...".format(len(A)))
|
||||
provider.terminate_nodes(A)
|
||||
time.sleep(1)
|
||||
A = remaining_nodes()
|
||||
@@ -130,7 +130,7 @@ def kill_node(config_file, yes, hard, override_cluster_name):
|
||||
try:
|
||||
nodes = provider.non_terminated_nodes({TAG_RAY_NODE_TYPE: "worker"})
|
||||
node = random.choice(nodes)
|
||||
logger.info("kill_node: Terminating worker {}".format(node))
|
||||
logger.info("kill_node: Shutdown worker {}".format(node))
|
||||
if hard:
|
||||
provider.terminate_node(node)
|
||||
else:
|
||||
@@ -143,6 +143,7 @@ def kill_node(config_file, yes, hard, override_cluster_name):
|
||||
file_mounts=config["file_mounts"],
|
||||
initialization_commands=[],
|
||||
setup_commands=[],
|
||||
ray_start_commands=[],
|
||||
runtime_hash="")
|
||||
|
||||
_exec(updater, "ray stop", False, False)
|
||||
@@ -170,6 +171,7 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
|
||||
override_cluster_name):
|
||||
"""Create the cluster head node, which in turn creates the workers."""
|
||||
provider = get_node_provider(config["provider"], config["cluster_name"])
|
||||
config_file = os.path.abspath(config_file)
|
||||
try:
|
||||
head_node_tags = {
|
||||
TAG_RAY_NODE_TYPE: "head",
|
||||
@@ -193,7 +195,7 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
|
||||
yes)
|
||||
logger.info(
|
||||
"get_or_create_head_node: "
|
||||
"Terminating outdated head node {}".format(head_node))
|
||||
"Shutting down outdated head node {}".format(head_node))
|
||||
provider.terminate_node(head_node)
|
||||
logger.info("get_or_create_head_node: Launching new head node...")
|
||||
head_node_tags[TAG_RAY_LAUNCH_CONFIG] = launch_hash
|
||||
@@ -233,12 +235,14 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
|
||||
})
|
||||
|
||||
if restart_only:
|
||||
init_commands = config["head_start_ray_commands"]
|
||||
init_commands = []
|
||||
ray_start_commands = config["head_start_ray_commands"]
|
||||
elif no_restart:
|
||||
init_commands = config["head_setup_commands"]
|
||||
ray_start_commands = []
|
||||
else:
|
||||
init_commands = (config["head_setup_commands"] +
|
||||
config["head_start_ray_commands"])
|
||||
init_commands = config["head_setup_commands"]
|
||||
ray_start_commands = config["head_start_ray_commands"]
|
||||
|
||||
updater = NodeUpdaterThread(
|
||||
node_id=head_node,
|
||||
@@ -249,6 +253,7 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
|
||||
file_mounts=config["file_mounts"],
|
||||
initialization_commands=config["initialization_commands"],
|
||||
setup_commands=init_commands,
|
||||
ray_start_commands=ray_start_commands,
|
||||
runtime_hash=runtime_hash,
|
||||
)
|
||||
updater.start()
|
||||
@@ -279,7 +284,7 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
|
||||
modifiers = ""
|
||||
print("To monitor auto-scaling activity, you can run:\n\n"
|
||||
" ray exec {} {}{}{}\n".format(
|
||||
config_file, "--docker " if use_docker else " ",
|
||||
config_file, "--docker " if use_docker else "",
|
||||
quote(monitor_str), modifiers))
|
||||
print("To open a console on the cluster:\n\n"
|
||||
" ray attach {}{}\n".format(config_file, modifiers))
|
||||
@@ -292,12 +297,14 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
|
||||
provider.cleanup()
|
||||
|
||||
|
||||
def attach_cluster(config_file, start, use_tmux, override_cluster_name, new):
|
||||
def attach_cluster(config_file, start, use_screen, use_tmux,
|
||||
override_cluster_name, new):
|
||||
"""Attaches to a screen for the specified cluster.
|
||||
|
||||
Arguments:
|
||||
config_file: path to the cluster yaml
|
||||
start: whether to start the cluster if it isn't up
|
||||
use_screen: whether to use screen as multiplexer
|
||||
use_tmux: whether to use tmux as multiplexer
|
||||
override_cluster_name: set the name of the cluster
|
||||
new: whether to force a new screen
|
||||
@@ -308,11 +315,16 @@ def attach_cluster(config_file, start, use_tmux, override_cluster_name, new):
|
||||
cmd = "tmux new"
|
||||
else:
|
||||
cmd = "tmux attach || tmux new"
|
||||
else:
|
||||
elif use_screen:
|
||||
if new:
|
||||
cmd = "screen -L"
|
||||
else:
|
||||
cmd = "screen -L -xRR"
|
||||
else:
|
||||
if new:
|
||||
raise ValueError(
|
||||
"--new only makes sense if passing --screen or --tmux")
|
||||
cmd = "$SHELL"
|
||||
|
||||
exec_cluster(config_file, cmd, False, False, False, False, start,
|
||||
override_cluster_name, None)
|
||||
@@ -354,6 +366,7 @@ def exec_cluster(config_file, cmd, docker, screen, tmux, stop, start,
|
||||
file_mounts=config["file_mounts"],
|
||||
initialization_commands=[],
|
||||
setup_commands=[],
|
||||
ray_start_commands=[],
|
||||
runtime_hash="",
|
||||
)
|
||||
|
||||
@@ -447,6 +460,7 @@ def rsync(config_file, source, target, override_cluster_name, down):
|
||||
file_mounts=config["file_mounts"],
|
||||
initialization_commands=[],
|
||||
setup_commands=[],
|
||||
ray_start_commands=[],
|
||||
runtime_hash="",
|
||||
)
|
||||
if down:
|
||||
|
||||
@@ -53,6 +53,7 @@ class NodeUpdater(object):
|
||||
file_mounts,
|
||||
initialization_commands,
|
||||
setup_commands,
|
||||
ray_start_commands,
|
||||
runtime_hash,
|
||||
process_runner=subprocess,
|
||||
exit_on_update_fail=False,
|
||||
@@ -80,6 +81,7 @@ class NodeUpdater(object):
|
||||
}
|
||||
self.initialization_commands = initialization_commands
|
||||
self.setup_commands = setup_commands
|
||||
self.ray_start_commands = ray_start_commands
|
||||
self.exit_on_update_fail = exit_on_update_fail
|
||||
self.runtime_hash = runtime_hash
|
||||
|
||||
@@ -222,21 +224,32 @@ class NodeUpdater(object):
|
||||
ssh_ok = self.wait_for_ssh(deadline)
|
||||
assert ssh_ok, "Unable to SSH to node"
|
||||
|
||||
self.provider.set_node_tags(self.node_id,
|
||||
{TAG_RAY_NODE_STATUS: "syncing-files"})
|
||||
self.sync_file_mounts(self.rsync_up)
|
||||
node_tags = self.provider.node_tags(self.node_id)
|
||||
if node_tags.get(TAG_RAY_RUNTIME_CONFIG) == self.runtime_hash:
|
||||
logger.info(
|
||||
"NodeUpdater: {} already up-to-date, skip to ray start".format(
|
||||
self.node_id))
|
||||
else:
|
||||
self.provider.set_node_tags(self.node_id,
|
||||
{TAG_RAY_NODE_STATUS: "syncing-files"})
|
||||
self.sync_file_mounts(self.rsync_up)
|
||||
|
||||
# Run init commands
|
||||
self.provider.set_node_tags(self.node_id,
|
||||
{TAG_RAY_NODE_STATUS: "setting-up"})
|
||||
m = "{}: Initialization commands completed".format(self.node_id)
|
||||
with LogTimer("NodeUpdater: {}".format(m)):
|
||||
for cmd in self.initialization_commands:
|
||||
self.ssh_cmd(cmd, exit_on_fail=self.exit_on_update_fail)
|
||||
# Run init commands
|
||||
self.provider.set_node_tags(self.node_id,
|
||||
{TAG_RAY_NODE_STATUS: "setting-up"})
|
||||
m = "{}: Initialization commands completed".format(self.node_id)
|
||||
with LogTimer("NodeUpdater: {}".format(m)):
|
||||
for cmd in self.initialization_commands:
|
||||
self.ssh_cmd(cmd, exit_on_fail=self.exit_on_update_fail)
|
||||
|
||||
m = "{}: Setup commands completed".format(self.node_id)
|
||||
m = "{}: Setup commands completed".format(self.node_id)
|
||||
with LogTimer("NodeUpdater: {}".format(m)):
|
||||
for cmd in self.setup_commands:
|
||||
self.ssh_cmd(cmd, exit_on_fail=self.exit_on_update_fail)
|
||||
|
||||
m = "{}: Ray start commands completed".format(self.node_id)
|
||||
with LogTimer("NodeUpdater: {}".format(m)):
|
||||
for cmd in self.setup_commands:
|
||||
for cmd in self.ray_start_commands:
|
||||
self.ssh_cmd(cmd, exit_on_fail=self.exit_on_update_fail)
|
||||
|
||||
def rsync_up(self, source, target, redirect=None):
|
||||
|
||||
@@ -568,6 +568,8 @@ def monitor(cluster_config_file, lines, cluster_name):
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Start the cluster if needed.")
|
||||
@click.option(
|
||||
"--screen", is_flag=True, default=False, help="Run the command in screen.")
|
||||
@click.option(
|
||||
"--tmux", is_flag=True, default=False, help="Run the command in tmux.")
|
||||
@click.option(
|
||||
@@ -578,8 +580,8 @@ def monitor(cluster_config_file, lines, cluster_name):
|
||||
help="Override the configured cluster name.")
|
||||
@click.option(
|
||||
"--new", "-N", is_flag=True, help="Force creation of a new screen.")
|
||||
def attach(cluster_config_file, start, tmux, cluster_name, new):
|
||||
attach_cluster(cluster_config_file, start, tmux, cluster_name, new)
|
||||
def attach(cluster_config_file, start, screen, tmux, cluster_name, new):
|
||||
attach_cluster(cluster_config_file, start, screen, tmux, cluster_name, new)
|
||||
|
||||
|
||||
@cli.command()
|
||||
|
||||
Reference in New Issue
Block a user