[autoscaler] cache stopped nodes, no screen on attach (#5741)

This commit is contained in:
Eric Liang
2019-09-22 17:30:35 -07:00
committed by GitHub
parent 5f5873b182
commit 56ab9a00bb
5 changed files with 159 additions and 39 deletions
+20 -11
View File
@@ -77,6 +77,11 @@ CLUSTER_CONFIG_SCHEMA = {
"worker_ips": (list, OPTIONAL), # local cluster worker nodes
"use_internal_ips": (bool, OPTIONAL), # don't require public ips
"extra_config": (dict, OPTIONAL), # provider-specific config
# Whether to try to reuse previously stopped nodes instead of
# launching nodes. This will also cause the autoscaler to stop
# nodes instead of terminating them. Only implemented for AWS.
"cache_stopped_nodes": (bool, OPTIONAL),
},
REQUIRED),
@@ -545,9 +550,9 @@ class StandardAutoscaler(object):
T = [
threading.Thread(
target=self.spawn_updater,
args=(node_id, commands),
) for node_id, commands in (self.should_update(node_id)
for node_id in nodes)
args=(node_id, commands, ray_start),
) for node_id, commands, ray_start in (self.should_update(node_id)
for node_id in nodes)
if node_id is not None
]
for t in T:
@@ -648,7 +653,8 @@ class StandardAutoscaler(object):
cluster_name=self.config["cluster_name"],
file_mounts={},
initialization_commands=[],
setup_commands=with_head_node_ip(
setup_commands=[],
ray_start_commands=with_head_node_ip(
self.config["worker_start_ray_commands"]),
runtime_hash=self.runtime_hash,
process_runner=self.process_runner,
@@ -658,23 +664,25 @@ class StandardAutoscaler(object):
def should_update(self, node_id):
if not self.can_update(node_id):
return (None, None)
return (None, None, None)
if self.files_up_to_date(node_id):
return (None, None)
return (None, None, None)
successful_updated = self.num_successful_updates.get(node_id, 0) > 0
if successful_updated and self.config.get("restart_only", False):
init_commands = self.config["worker_start_ray_commands"]
init_commands = []
ray_commands = self.config["worker_start_ray_commands"]
elif successful_updated and self.config.get("no_restart", False):
init_commands = self.config["worker_setup_commands"]
ray_commands = []
else:
init_commands = (self.config["worker_setup_commands"] +
self.config["worker_start_ray_commands"])
init_commands = self.config["worker_setup_commands"]
ray_commands = self.config["worker_start_ray_commands"]
return (node_id, init_commands)
return (node_id, init_commands, ray_commands)
def spawn_updater(self, node_id, init_commands):
def spawn_updater(self, node_id, init_commands, ray_start_commands):
updater = NodeUpdaterThread(
node_id=node_id,
provider_config=self.config["provider"],
@@ -685,6 +693,7 @@ class StandardAutoscaler(object):
initialization_commands=with_head_node_ip(
self.config["initialization_commands"]),
setup_commands=with_head_node_ip(init_commands),
ray_start_commands=with_head_node_ip(ray_start_commands),
runtime_hash=self.runtime_hash,
process_runner=self.process_runner,
use_internal_ip=True)
+86 -4
View File
@@ -12,7 +12,8 @@ import botocore
from botocore.config import Config
from ray.autoscaler.node_provider import NodeProvider
from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME
from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME, \
TAG_RAY_LAUNCH_CONFIG, TAG_RAY_NODE_TYPE
from ray.ray_constants import BOTO_MAX_RETRIES
from ray.autoscaler.log_timer import LogTimer
@@ -41,6 +42,8 @@ class AWSNodeProvider(NodeProvider):
def __init__(self, provider_config, cluster_name):
NodeProvider.__init__(self, provider_config, cluster_name)
config = Config(retries={"max_attempts": BOTO_MAX_RETRIES})
self.cache_stopped_nodes = provider_config.get("cache_stopped_nodes",
True)
self.ec2 = boto3.resource(
"ec2", region_name=provider_config["region"], config=config)
@@ -61,7 +64,7 @@ class AWSNodeProvider(NodeProvider):
self.cached_nodes = {}
def _node_tag_update_loop(self):
""" Update the AWS tags for a cluster periodically.
"""Update the AWS tags for a cluster periodically.
The purpose of this loop is to avoid excessive EC2 calls when a large
number of nodes are being launched simultaneously.
@@ -171,6 +174,55 @@ class AWSNodeProvider(NodeProvider):
self.tag_cache_update_event.set()
def create_node(self, node_config, tags, count):
# Try to reuse previously stopped nodes with compatible configs
if self.cache_stopped_nodes:
filters = [
{
"Name": "instance-state-name",
"Values": ["stopped", "stopping"],
},
{
"Name": "tag:{}".format(TAG_RAY_CLUSTER_NAME),
"Values": [self.cluster_name],
},
{
"Name": "tag:{}".format(TAG_RAY_NODE_TYPE),
"Values": [tags[TAG_RAY_NODE_TYPE]],
},
{
"Name": "tag:{}".format(TAG_RAY_LAUNCH_CONFIG),
"Values": [tags[TAG_RAY_LAUNCH_CONFIG]],
},
]
reuse_nodes = list(
self.ec2.instances.filter(Filters=filters))[:count]
reuse_node_ids = [n.id for n in reuse_nodes]
if reuse_nodes:
logger.info("AWSNodeProvider: reusing instances {}. "
"To disable reuse, set "
"'cache_stopped_nodes: False' in the provider "
"config.".format(reuse_node_ids))
for node in reuse_nodes:
self.tag_cache[node.id] = from_aws_format(
{x["Key"]: x["Value"]
for x in node.tags})
if node.state["Name"] == "stopping":
logger.info("AWSNodeProvider: waiting for instance "
"{} to fully stop...".format(node.id))
node.wait_until_stopped()
self.ec2.meta.client.start_instances(
InstanceIds=reuse_node_ids)
for node_id in reuse_node_ids:
self.set_node_tags(node_id, tags)
count -= len(reuse_node_ids)
if count:
self._create_node(node_config, tags, count)
def _create_node(self, node_config, tags, count):
tags = to_aws_format(tags)
conf = node_config.copy()
@@ -248,13 +300,43 @@ class AWSNodeProvider(NodeProvider):
def terminate_node(self, node_id):
node = self._get_cached_node(node_id)
node.terminate()
if self.cache_stopped_nodes:
if node.spot_instance_request_id:
logger.info(
"AWSNodeProvider: terminating node {} (spot nodes cannot "
"be stopped, only terminated)".format(node_id))
node.terminate()
else:
logger.info(
"AWSNodeProvider: stopping node {}. To terminate nodes "
"on stop, set 'cache_stopped_nodes: False' in the "
"provider config.".format(node_id))
node.stop()
else:
node.terminate()
self.tag_cache.pop(node_id, None)
self.tag_cache_pending.pop(node_id, None)
def terminate_nodes(self, node_ids):
self.ec2.meta.client.terminate_instances(InstanceIds=node_ids)
if not node_ids:
return
node0 = self._get_cached_node(node_ids[0])
if self.cache_stopped_nodes:
if node0.spot_instance_request_id:
logger.info(
"AWSNodeProvider: terminating nodes {} (spot nodes cannot "
"be stopped, only terminated)".format(node_ids))
self.ec2.meta.client.terminate_instances(InstanceIds=node_ids)
else:
logger.info(
"AWSNodeProvider: stopping nodes {}. To terminate nodes "
"on stop, set 'cache_stopped_nodes: False' in the "
"provider config.".format(node_ids))
self.ec2.meta.client.stop_instances(InstanceIds=node_ids)
else:
self.ec2.meta.client.terminate_instances(InstanceIds=node_ids)
for node_id in node_ids:
self.tag_cache.pop(node_id, None)
+24 -10
View File
@@ -105,10 +105,10 @@ def teardown_cluster(config_file, yes, workers_only, override_cluster_name):
# Loop here to check that both the head and worker nodes are actually
# really gone
A = remaining_nodes()
with LogTimer("teardown_cluster: Termination done."):
with LogTimer("teardown_cluster: done."):
while A:
logger.info("teardown_cluster: "
"Terminating {} nodes...".format(len(A)))
"Shutting down {} nodes...".format(len(A)))
provider.terminate_nodes(A)
time.sleep(1)
A = remaining_nodes()
@@ -130,7 +130,7 @@ def kill_node(config_file, yes, hard, override_cluster_name):
try:
nodes = provider.non_terminated_nodes({TAG_RAY_NODE_TYPE: "worker"})
node = random.choice(nodes)
logger.info("kill_node: Terminating worker {}".format(node))
logger.info("kill_node: Shutdown worker {}".format(node))
if hard:
provider.terminate_node(node)
else:
@@ -143,6 +143,7 @@ def kill_node(config_file, yes, hard, override_cluster_name):
file_mounts=config["file_mounts"],
initialization_commands=[],
setup_commands=[],
ray_start_commands=[],
runtime_hash="")
_exec(updater, "ray stop", False, False)
@@ -170,6 +171,7 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
override_cluster_name):
"""Create the cluster head node, which in turn creates the workers."""
provider = get_node_provider(config["provider"], config["cluster_name"])
config_file = os.path.abspath(config_file)
try:
head_node_tags = {
TAG_RAY_NODE_TYPE: "head",
@@ -193,7 +195,7 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
yes)
logger.info(
"get_or_create_head_node: "
"Terminating outdated head node {}".format(head_node))
"Shutting down outdated head node {}".format(head_node))
provider.terminate_node(head_node)
logger.info("get_or_create_head_node: Launching new head node...")
head_node_tags[TAG_RAY_LAUNCH_CONFIG] = launch_hash
@@ -233,12 +235,14 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
})
if restart_only:
init_commands = config["head_start_ray_commands"]
init_commands = []
ray_start_commands = config["head_start_ray_commands"]
elif no_restart:
init_commands = config["head_setup_commands"]
ray_start_commands = []
else:
init_commands = (config["head_setup_commands"] +
config["head_start_ray_commands"])
init_commands = config["head_setup_commands"]
ray_start_commands = config["head_start_ray_commands"]
updater = NodeUpdaterThread(
node_id=head_node,
@@ -249,6 +253,7 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
file_mounts=config["file_mounts"],
initialization_commands=config["initialization_commands"],
setup_commands=init_commands,
ray_start_commands=ray_start_commands,
runtime_hash=runtime_hash,
)
updater.start()
@@ -279,7 +284,7 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
modifiers = ""
print("To monitor auto-scaling activity, you can run:\n\n"
" ray exec {} {}{}{}\n".format(
config_file, "--docker " if use_docker else " ",
config_file, "--docker " if use_docker else "",
quote(monitor_str), modifiers))
print("To open a console on the cluster:\n\n"
" ray attach {}{}\n".format(config_file, modifiers))
@@ -292,12 +297,14 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
provider.cleanup()
def attach_cluster(config_file, start, use_tmux, override_cluster_name, new):
def attach_cluster(config_file, start, use_screen, use_tmux,
override_cluster_name, new):
"""Attaches to a screen for the specified cluster.
Arguments:
config_file: path to the cluster yaml
start: whether to start the cluster if it isn't up
use_screen: whether to use screen as multiplexer
use_tmux: whether to use tmux as multiplexer
override_cluster_name: set the name of the cluster
new: whether to force a new screen
@@ -308,11 +315,16 @@ def attach_cluster(config_file, start, use_tmux, override_cluster_name, new):
cmd = "tmux new"
else:
cmd = "tmux attach || tmux new"
else:
elif use_screen:
if new:
cmd = "screen -L"
else:
cmd = "screen -L -xRR"
else:
if new:
raise ValueError(
"--new only makes sense if passing --screen or --tmux")
cmd = "$SHELL"
exec_cluster(config_file, cmd, False, False, False, False, start,
override_cluster_name, None)
@@ -354,6 +366,7 @@ def exec_cluster(config_file, cmd, docker, screen, tmux, stop, start,
file_mounts=config["file_mounts"],
initialization_commands=[],
setup_commands=[],
ray_start_commands=[],
runtime_hash="",
)
@@ -447,6 +460,7 @@ def rsync(config_file, source, target, override_cluster_name, down):
file_mounts=config["file_mounts"],
initialization_commands=[],
setup_commands=[],
ray_start_commands=[],
runtime_hash="",
)
if down:
+25 -12
View File
@@ -53,6 +53,7 @@ class NodeUpdater(object):
file_mounts,
initialization_commands,
setup_commands,
ray_start_commands,
runtime_hash,
process_runner=subprocess,
exit_on_update_fail=False,
@@ -80,6 +81,7 @@ class NodeUpdater(object):
}
self.initialization_commands = initialization_commands
self.setup_commands = setup_commands
self.ray_start_commands = ray_start_commands
self.exit_on_update_fail = exit_on_update_fail
self.runtime_hash = runtime_hash
@@ -222,21 +224,32 @@ class NodeUpdater(object):
ssh_ok = self.wait_for_ssh(deadline)
assert ssh_ok, "Unable to SSH to node"
self.provider.set_node_tags(self.node_id,
{TAG_RAY_NODE_STATUS: "syncing-files"})
self.sync_file_mounts(self.rsync_up)
node_tags = self.provider.node_tags(self.node_id)
if node_tags.get(TAG_RAY_RUNTIME_CONFIG) == self.runtime_hash:
logger.info(
"NodeUpdater: {} already up-to-date, skip to ray start".format(
self.node_id))
else:
self.provider.set_node_tags(self.node_id,
{TAG_RAY_NODE_STATUS: "syncing-files"})
self.sync_file_mounts(self.rsync_up)
# Run init commands
self.provider.set_node_tags(self.node_id,
{TAG_RAY_NODE_STATUS: "setting-up"})
m = "{}: Initialization commands completed".format(self.node_id)
with LogTimer("NodeUpdater: {}".format(m)):
for cmd in self.initialization_commands:
self.ssh_cmd(cmd, exit_on_fail=self.exit_on_update_fail)
# Run init commands
self.provider.set_node_tags(self.node_id,
{TAG_RAY_NODE_STATUS: "setting-up"})
m = "{}: Initialization commands completed".format(self.node_id)
with LogTimer("NodeUpdater: {}".format(m)):
for cmd in self.initialization_commands:
self.ssh_cmd(cmd, exit_on_fail=self.exit_on_update_fail)
m = "{}: Setup commands completed".format(self.node_id)
m = "{}: Setup commands completed".format(self.node_id)
with LogTimer("NodeUpdater: {}".format(m)):
for cmd in self.setup_commands:
self.ssh_cmd(cmd, exit_on_fail=self.exit_on_update_fail)
m = "{}: Ray start commands completed".format(self.node_id)
with LogTimer("NodeUpdater: {}".format(m)):
for cmd in self.setup_commands:
for cmd in self.ray_start_commands:
self.ssh_cmd(cmd, exit_on_fail=self.exit_on_update_fail)
def rsync_up(self, source, target, redirect=None):
+4 -2
View File
@@ -568,6 +568,8 @@ def monitor(cluster_config_file, lines, cluster_name):
is_flag=True,
default=False,
help="Start the cluster if needed.")
@click.option(
"--screen", is_flag=True, default=False, help="Run the command in screen.")
@click.option(
"--tmux", is_flag=True, default=False, help="Run the command in tmux.")
@click.option(
@@ -578,8 +580,8 @@ def monitor(cluster_config_file, lines, cluster_name):
help="Override the configured cluster name.")
@click.option(
"--new", "-N", is_flag=True, help="Force creation of a new screen.")
def attach(cluster_config_file, start, tmux, cluster_name, new):
attach_cluster(cluster_config_file, start, tmux, cluster_name, new)
def attach(cluster_config_file, start, screen, tmux, cluster_name, new):
attach_cluster(cluster_config_file, start, screen, tmux, cluster_name, new)
@cli.command()