diff --git a/python/ray/autoscaler/autoscaler.py b/python/ray/autoscaler/autoscaler.py index af6719019..75f0b19b0 100644 --- a/python/ray/autoscaler/autoscaler.py +++ b/python/ray/autoscaler/autoscaler.py @@ -77,6 +77,11 @@ CLUSTER_CONFIG_SCHEMA = { "worker_ips": (list, OPTIONAL), # local cluster worker nodes "use_internal_ips": (bool, OPTIONAL), # don't require public ips "extra_config": (dict, OPTIONAL), # provider-specific config + + # Whether to try to reuse previously stopped nodes instead of + # launching nodes. This will also cause the autoscaler to stop + # nodes instead of terminating them. Only implemented for AWS. + "cache_stopped_nodes": (bool, OPTIONAL), }, REQUIRED), @@ -545,9 +550,9 @@ class StandardAutoscaler(object): T = [ threading.Thread( target=self.spawn_updater, - args=(node_id, commands), - ) for node_id, commands in (self.should_update(node_id) - for node_id in nodes) + args=(node_id, commands, ray_start), + ) for node_id, commands, ray_start in (self.should_update(node_id) + for node_id in nodes) if node_id is not None ] for t in T: @@ -648,7 +653,8 @@ class StandardAutoscaler(object): cluster_name=self.config["cluster_name"], file_mounts={}, initialization_commands=[], - setup_commands=with_head_node_ip( + setup_commands=[], + ray_start_commands=with_head_node_ip( self.config["worker_start_ray_commands"]), runtime_hash=self.runtime_hash, process_runner=self.process_runner, @@ -658,23 +664,25 @@ class StandardAutoscaler(object): def should_update(self, node_id): if not self.can_update(node_id): - return (None, None) + return (None, None, None) if self.files_up_to_date(node_id): - return (None, None) + return (None, None, None) successful_updated = self.num_successful_updates.get(node_id, 0) > 0 if successful_updated and self.config.get("restart_only", False): - init_commands = self.config["worker_start_ray_commands"] + init_commands = [] + ray_commands = self.config["worker_start_ray_commands"] elif successful_updated and self.config.get("no_restart", False): init_commands = self.config["worker_setup_commands"] + ray_commands = [] else: - init_commands = (self.config["worker_setup_commands"] + - self.config["worker_start_ray_commands"]) + init_commands = self.config["worker_setup_commands"] + ray_commands = self.config["worker_start_ray_commands"] - return (node_id, init_commands) + return (node_id, init_commands, ray_commands) - def spawn_updater(self, node_id, init_commands): + def spawn_updater(self, node_id, init_commands, ray_start_commands): updater = NodeUpdaterThread( node_id=node_id, provider_config=self.config["provider"], @@ -685,6 +693,7 @@ class StandardAutoscaler(object): initialization_commands=with_head_node_ip( self.config["initialization_commands"]), setup_commands=with_head_node_ip(init_commands), + ray_start_commands=with_head_node_ip(ray_start_commands), runtime_hash=self.runtime_hash, process_runner=self.process_runner, use_internal_ip=True) diff --git a/python/ray/autoscaler/aws/node_provider.py b/python/ray/autoscaler/aws/node_provider.py index 387c250c9..c5f5438f1 100644 --- a/python/ray/autoscaler/aws/node_provider.py +++ b/python/ray/autoscaler/aws/node_provider.py @@ -12,7 +12,8 @@ import botocore from botocore.config import Config from ray.autoscaler.node_provider import NodeProvider -from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME +from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME, \ + TAG_RAY_LAUNCH_CONFIG, TAG_RAY_NODE_TYPE from ray.ray_constants import BOTO_MAX_RETRIES from ray.autoscaler.log_timer import LogTimer @@ -41,6 +42,8 @@ class AWSNodeProvider(NodeProvider): def __init__(self, provider_config, cluster_name): NodeProvider.__init__(self, provider_config, cluster_name) config = Config(retries={"max_attempts": BOTO_MAX_RETRIES}) + self.cache_stopped_nodes = provider_config.get("cache_stopped_nodes", + True) self.ec2 = boto3.resource( "ec2", region_name=provider_config["region"], config=config) @@ -61,7 +64,7 @@ class AWSNodeProvider(NodeProvider): self.cached_nodes = {} def _node_tag_update_loop(self): - """ Update the AWS tags for a cluster periodically. + """Update the AWS tags for a cluster periodically. The purpose of this loop is to avoid excessive EC2 calls when a large number of nodes are being launched simultaneously. @@ -171,6 +174,55 @@ class AWSNodeProvider(NodeProvider): self.tag_cache_update_event.set() def create_node(self, node_config, tags, count): + # Try to reuse previously stopped nodes with compatible configs + if self.cache_stopped_nodes: + filters = [ + { + "Name": "instance-state-name", + "Values": ["stopped", "stopping"], + }, + { + "Name": "tag:{}".format(TAG_RAY_CLUSTER_NAME), + "Values": [self.cluster_name], + }, + { + "Name": "tag:{}".format(TAG_RAY_NODE_TYPE), + "Values": [tags[TAG_RAY_NODE_TYPE]], + }, + { + "Name": "tag:{}".format(TAG_RAY_LAUNCH_CONFIG), + "Values": [tags[TAG_RAY_LAUNCH_CONFIG]], + }, + ] + + reuse_nodes = list( + self.ec2.instances.filter(Filters=filters))[:count] + reuse_node_ids = [n.id for n in reuse_nodes] + if reuse_nodes: + logger.info("AWSNodeProvider: reusing instances {}. " + "To disable reuse, set " + "'cache_stopped_nodes: False' in the provider " + "config.".format(reuse_node_ids)) + + for node in reuse_nodes: + self.tag_cache[node.id] = from_aws_format( + {x["Key"]: x["Value"] + for x in node.tags}) + if node.state["Name"] == "stopping": + logger.info("AWSNodeProvider: waiting for instance " + "{} to fully stop...".format(node.id)) + node.wait_until_stopped() + + self.ec2.meta.client.start_instances( + InstanceIds=reuse_node_ids) + for node_id in reuse_node_ids: + self.set_node_tags(node_id, tags) + count -= len(reuse_node_ids) + + if count: + self._create_node(node_config, tags, count) + + def _create_node(self, node_config, tags, count): tags = to_aws_format(tags) conf = node_config.copy() @@ -248,13 +300,43 @@ class AWSNodeProvider(NodeProvider): def terminate_node(self, node_id): node = self._get_cached_node(node_id) - node.terminate() + if self.cache_stopped_nodes: + if node.spot_instance_request_id: + logger.info( + "AWSNodeProvider: terminating node {} (spot nodes cannot " + "be stopped, only terminated)".format(node_id)) + node.terminate() + else: + logger.info( + "AWSNodeProvider: stopping node {}. To terminate nodes " + "on stop, set 'cache_stopped_nodes: False' in the " + "provider config.".format(node_id)) + node.stop() + else: + node.terminate() self.tag_cache.pop(node_id, None) self.tag_cache_pending.pop(node_id, None) def terminate_nodes(self, node_ids): - self.ec2.meta.client.terminate_instances(InstanceIds=node_ids) + if not node_ids: + return + + node0 = self._get_cached_node(node_ids[0]) + if self.cache_stopped_nodes: + if node0.spot_instance_request_id: + logger.info( + "AWSNodeProvider: terminating nodes {} (spot nodes cannot " + "be stopped, only terminated)".format(node_ids)) + self.ec2.meta.client.terminate_instances(InstanceIds=node_ids) + else: + logger.info( + "AWSNodeProvider: stopping nodes {}. To terminate nodes " + "on stop, set 'cache_stopped_nodes: False' in the " + "provider config.".format(node_ids)) + self.ec2.meta.client.stop_instances(InstanceIds=node_ids) + else: + self.ec2.meta.client.terminate_instances(InstanceIds=node_ids) for node_id in node_ids: self.tag_cache.pop(node_id, None) diff --git a/python/ray/autoscaler/commands.py b/python/ray/autoscaler/commands.py index 91718d8fc..9f3544e5d 100644 --- a/python/ray/autoscaler/commands.py +++ b/python/ray/autoscaler/commands.py @@ -105,10 +105,10 @@ def teardown_cluster(config_file, yes, workers_only, override_cluster_name): # Loop here to check that both the head and worker nodes are actually # really gone A = remaining_nodes() - with LogTimer("teardown_cluster: Termination done."): + with LogTimer("teardown_cluster: done."): while A: logger.info("teardown_cluster: " - "Terminating {} nodes...".format(len(A))) + "Shutting down {} nodes...".format(len(A))) provider.terminate_nodes(A) time.sleep(1) A = remaining_nodes() @@ -130,7 +130,7 @@ def kill_node(config_file, yes, hard, override_cluster_name): try: nodes = provider.non_terminated_nodes({TAG_RAY_NODE_TYPE: "worker"}) node = random.choice(nodes) - logger.info("kill_node: Terminating worker {}".format(node)) + logger.info("kill_node: Shutdown worker {}".format(node)) if hard: provider.terminate_node(node) else: @@ -143,6 +143,7 @@ def kill_node(config_file, yes, hard, override_cluster_name): file_mounts=config["file_mounts"], initialization_commands=[], setup_commands=[], + ray_start_commands=[], runtime_hash="") _exec(updater, "ray stop", False, False) @@ -170,6 +171,7 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes, override_cluster_name): """Create the cluster head node, which in turn creates the workers.""" provider = get_node_provider(config["provider"], config["cluster_name"]) + config_file = os.path.abspath(config_file) try: head_node_tags = { TAG_RAY_NODE_TYPE: "head", @@ -193,7 +195,7 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes, yes) logger.info( "get_or_create_head_node: " - "Terminating outdated head node {}".format(head_node)) + "Shutting down outdated head node {}".format(head_node)) provider.terminate_node(head_node) logger.info("get_or_create_head_node: Launching new head node...") head_node_tags[TAG_RAY_LAUNCH_CONFIG] = launch_hash @@ -233,12 +235,14 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes, }) if restart_only: - init_commands = config["head_start_ray_commands"] + init_commands = [] + ray_start_commands = config["head_start_ray_commands"] elif no_restart: init_commands = config["head_setup_commands"] + ray_start_commands = [] else: - init_commands = (config["head_setup_commands"] + - config["head_start_ray_commands"]) + init_commands = config["head_setup_commands"] + ray_start_commands = config["head_start_ray_commands"] updater = NodeUpdaterThread( node_id=head_node, @@ -249,6 +253,7 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes, file_mounts=config["file_mounts"], initialization_commands=config["initialization_commands"], setup_commands=init_commands, + ray_start_commands=ray_start_commands, runtime_hash=runtime_hash, ) updater.start() @@ -279,7 +284,7 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes, modifiers = "" print("To monitor auto-scaling activity, you can run:\n\n" " ray exec {} {}{}{}\n".format( - config_file, "--docker " if use_docker else " ", + config_file, "--docker " if use_docker else "", quote(monitor_str), modifiers)) print("To open a console on the cluster:\n\n" " ray attach {}{}\n".format(config_file, modifiers)) @@ -292,12 +297,14 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes, provider.cleanup() -def attach_cluster(config_file, start, use_tmux, override_cluster_name, new): +def attach_cluster(config_file, start, use_screen, use_tmux, + override_cluster_name, new): """Attaches to a screen for the specified cluster. Arguments: config_file: path to the cluster yaml start: whether to start the cluster if it isn't up + use_screen: whether to use screen as multiplexer use_tmux: whether to use tmux as multiplexer override_cluster_name: set the name of the cluster new: whether to force a new screen @@ -308,11 +315,16 @@ def attach_cluster(config_file, start, use_tmux, override_cluster_name, new): cmd = "tmux new" else: cmd = "tmux attach || tmux new" - else: + elif use_screen: if new: cmd = "screen -L" else: cmd = "screen -L -xRR" + else: + if new: + raise ValueError( + "--new only makes sense if passing --screen or --tmux") + cmd = "$SHELL" exec_cluster(config_file, cmd, False, False, False, False, start, override_cluster_name, None) @@ -354,6 +366,7 @@ def exec_cluster(config_file, cmd, docker, screen, tmux, stop, start, file_mounts=config["file_mounts"], initialization_commands=[], setup_commands=[], + ray_start_commands=[], runtime_hash="", ) @@ -447,6 +460,7 @@ def rsync(config_file, source, target, override_cluster_name, down): file_mounts=config["file_mounts"], initialization_commands=[], setup_commands=[], + ray_start_commands=[], runtime_hash="", ) if down: diff --git a/python/ray/autoscaler/updater.py b/python/ray/autoscaler/updater.py index 146fe9dfb..6402b9ecd 100644 --- a/python/ray/autoscaler/updater.py +++ b/python/ray/autoscaler/updater.py @@ -53,6 +53,7 @@ class NodeUpdater(object): file_mounts, initialization_commands, setup_commands, + ray_start_commands, runtime_hash, process_runner=subprocess, exit_on_update_fail=False, @@ -80,6 +81,7 @@ class NodeUpdater(object): } self.initialization_commands = initialization_commands self.setup_commands = setup_commands + self.ray_start_commands = ray_start_commands self.exit_on_update_fail = exit_on_update_fail self.runtime_hash = runtime_hash @@ -222,21 +224,32 @@ class NodeUpdater(object): ssh_ok = self.wait_for_ssh(deadline) assert ssh_ok, "Unable to SSH to node" - self.provider.set_node_tags(self.node_id, - {TAG_RAY_NODE_STATUS: "syncing-files"}) - self.sync_file_mounts(self.rsync_up) + node_tags = self.provider.node_tags(self.node_id) + if node_tags.get(TAG_RAY_RUNTIME_CONFIG) == self.runtime_hash: + logger.info( + "NodeUpdater: {} already up-to-date, skip to ray start".format( + self.node_id)) + else: + self.provider.set_node_tags(self.node_id, + {TAG_RAY_NODE_STATUS: "syncing-files"}) + self.sync_file_mounts(self.rsync_up) - # Run init commands - self.provider.set_node_tags(self.node_id, - {TAG_RAY_NODE_STATUS: "setting-up"}) - m = "{}: Initialization commands completed".format(self.node_id) - with LogTimer("NodeUpdater: {}".format(m)): - for cmd in self.initialization_commands: - self.ssh_cmd(cmd, exit_on_fail=self.exit_on_update_fail) + # Run init commands + self.provider.set_node_tags(self.node_id, + {TAG_RAY_NODE_STATUS: "setting-up"}) + m = "{}: Initialization commands completed".format(self.node_id) + with LogTimer("NodeUpdater: {}".format(m)): + for cmd in self.initialization_commands: + self.ssh_cmd(cmd, exit_on_fail=self.exit_on_update_fail) - m = "{}: Setup commands completed".format(self.node_id) + m = "{}: Setup commands completed".format(self.node_id) + with LogTimer("NodeUpdater: {}".format(m)): + for cmd in self.setup_commands: + self.ssh_cmd(cmd, exit_on_fail=self.exit_on_update_fail) + + m = "{}: Ray start commands completed".format(self.node_id) with LogTimer("NodeUpdater: {}".format(m)): - for cmd in self.setup_commands: + for cmd in self.ray_start_commands: self.ssh_cmd(cmd, exit_on_fail=self.exit_on_update_fail) def rsync_up(self, source, target, redirect=None): diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index f66fd8019..5f25b29e0 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -568,6 +568,8 @@ def monitor(cluster_config_file, lines, cluster_name): is_flag=True, default=False, help="Start the cluster if needed.") +@click.option( + "--screen", is_flag=True, default=False, help="Run the command in screen.") @click.option( "--tmux", is_flag=True, default=False, help="Run the command in tmux.") @click.option( @@ -578,8 +580,8 @@ def monitor(cluster_config_file, lines, cluster_name): help="Override the configured cluster name.") @click.option( "--new", "-N", is_flag=True, help="Force creation of a new screen.") -def attach(cluster_config_file, start, tmux, cluster_name, new): - attach_cluster(cluster_config_file, start, tmux, cluster_name, new) +def attach(cluster_config_file, start, screen, tmux, cluster_name, new): + attach_cluster(cluster_config_file, start, screen, tmux, cluster_name, new) @cli.command()