diff --git a/python/ray/autoscaler/autoscaler.py b/python/ray/autoscaler/autoscaler.py index 75f0b19b0..fc25046b0 100644 --- a/python/ray/autoscaler/autoscaler.py +++ b/python/ray/autoscaler/autoscaler.py @@ -22,7 +22,8 @@ from ray.autoscaler.node_provider import get_node_provider, \ get_default_config from ray.autoscaler.tags import (TAG_RAY_LAUNCH_CONFIG, TAG_RAY_RUNTIME_CONFIG, TAG_RAY_NODE_STATUS, TAG_RAY_NODE_TYPE, - TAG_RAY_NODE_NAME) + TAG_RAY_NODE_NAME, STATUS_UP_TO_DATE, + STATUS_UNINITIALIZED, NODE_TYPE_WORKER) from ray.autoscaler.updater import NodeUpdaterThread from ray.ray_constants import AUTOSCALER_MAX_NUM_FAILURES, \ AUTOSCALER_MAX_LAUNCH_BATCH, AUTOSCALER_MAX_CONCURRENT_LAUNCHES, \ @@ -315,7 +316,7 @@ class NodeLauncher(threading.Thread): super(NodeLauncher, self).__init__(*args, **kwargs) def _launch_node(self, config, count): - worker_filter = {TAG_RAY_NODE_TYPE: "worker"} + worker_filter = {TAG_RAY_NODE_TYPE: NODE_TYPE_WORKER} before = self.provider.non_terminated_nodes(tag_filters=worker_filter) launch_hash = hash_launch_conf(config["worker_nodes"], config["auth"]) self.log("Launching {} nodes.".format(count)) @@ -323,8 +324,8 @@ class NodeLauncher(threading.Thread): config["worker_nodes"], { TAG_RAY_NODE_NAME: "ray-{}-worker".format( config["cluster_name"]), - TAG_RAY_NODE_TYPE: "worker", - TAG_RAY_NODE_STATUS: "uninitialized", + TAG_RAY_NODE_TYPE: NODE_TYPE_WORKER, + TAG_RAY_NODE_STATUS: STATUS_UNINITIALIZED, TAG_RAY_LAUNCH_CONFIG: launch_hash, }, count) after = self.provider.non_terminated_nodes(tag_filters=worker_filter) @@ -547,18 +548,10 @@ class StandardAutoscaler(object): self.log_info_string(nodes, target_workers) # Update nodes with out-of-date files - T = [ - threading.Thread( - target=self.spawn_updater, - args=(node_id, commands, ray_start), - ) for node_id, commands, ray_start in (self.should_update(node_id) - for node_id in nodes) - if node_id is not None - ] - for t in T: - t.start() - for t in T: - t.join() + for node_id, commands, ray_start in (self.should_update(node_id) + for node_id in nodes): + if node_id is not None: + self.spawn_updater(node_id, commands, ray_start) # Attempt to recover unhealthy nodes for node_id in nodes: @@ -664,10 +657,11 @@ class StandardAutoscaler(object): def should_update(self, node_id): if not self.can_update(node_id): - return (None, None, None) + return None, None, None # no update - if self.files_up_to_date(node_id): - return (None, None, None) + status = self.provider.node_tags(node_id).get(TAG_RAY_NODE_STATUS) + if status == STATUS_UP_TO_DATE and self.files_up_to_date(node_id): + return None, None, None # no update successful_updated = self.num_successful_updates.get(node_id, 0) > 0 if successful_updated and self.config.get("restart_only", False): @@ -718,7 +712,7 @@ class StandardAutoscaler(object): def workers(self): return self.provider.non_terminated_nodes( - tag_filters={TAG_RAY_NODE_TYPE: "worker"}) + tag_filters={TAG_RAY_NODE_TYPE: NODE_TYPE_WORKER}) def log_info_string(self, nodes, target): logger.info("StandardAutoscaler: {}".format( diff --git a/python/ray/autoscaler/commands.py b/python/ray/autoscaler/commands.py index 9f3544e5d..24bc1a2b5 100644 --- a/python/ray/autoscaler/commands.py +++ b/python/ray/autoscaler/commands.py @@ -23,7 +23,7 @@ from ray.autoscaler.autoscaler import validate_config, hash_runtime_conf, \ hash_launch_conf, fillout_defaults from ray.autoscaler.node_provider import get_node_provider, NODE_PROVIDERS from ray.autoscaler.tags import TAG_RAY_NODE_TYPE, TAG_RAY_LAUNCH_CONFIG, \ - TAG_RAY_NODE_NAME + TAG_RAY_NODE_NAME, NODE_TYPE_WORKER, NODE_TYPE_HEAD from ray.autoscaler.updater import NodeUpdaterThread from ray.autoscaler.log_timer import LogTimer from ray.autoscaler.docker import with_docker_exec @@ -91,13 +91,13 @@ def teardown_cluster(config_file, yes, workers_only, override_cluster_name): else: A = [ node_id for node_id in provider.non_terminated_nodes({ - TAG_RAY_NODE_TYPE: "head" + TAG_RAY_NODE_TYPE: NODE_TYPE_HEAD }) ] A += [ node_id for node_id in provider.non_terminated_nodes({ - TAG_RAY_NODE_TYPE: "worker" + TAG_RAY_NODE_TYPE: NODE_TYPE_WORKER }) ] return A @@ -128,7 +128,9 @@ def kill_node(config_file, yes, hard, override_cluster_name): provider = get_node_provider(config["provider"], config["cluster_name"]) try: - nodes = provider.non_terminated_nodes({TAG_RAY_NODE_TYPE: "worker"}) + nodes = provider.non_terminated_nodes({ + TAG_RAY_NODE_TYPE: NODE_TYPE_WORKER + }) node = random.choice(nodes) logger.info("kill_node: Shutdown worker {}".format(node)) if hard: @@ -174,7 +176,7 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes, config_file = os.path.abspath(config_file) try: head_node_tags = { - TAG_RAY_NODE_TYPE: "head", + TAG_RAY_NODE_TYPE: NODE_TYPE_HEAD, } nodes = provider.non_terminated_nodes(head_node_tags) if len(nodes) > 0: @@ -506,7 +508,9 @@ def get_worker_node_ips(config_file, override_cluster_name): provider = get_node_provider(config["provider"], config["cluster_name"]) try: - nodes = provider.non_terminated_nodes({TAG_RAY_NODE_TYPE: "worker"}) + nodes = provider.non_terminated_nodes({ + TAG_RAY_NODE_TYPE: NODE_TYPE_WORKER + }) if config.get("provider", {}).get("use_internal_ips", False) is True: return [provider.internal_ip(node) for node in nodes] @@ -523,7 +527,7 @@ def _get_head_node(config, provider = get_node_provider(config["provider"], config["cluster_name"]) try: head_node_tags = { - TAG_RAY_NODE_TYPE: "head", + TAG_RAY_NODE_TYPE: NODE_TYPE_HEAD, } nodes = provider.non_terminated_nodes(head_node_tags) finally: diff --git a/python/ray/autoscaler/local/node_provider.py b/python/ray/autoscaler/local/node_provider.py index a298bf26c..5996c0ca9 100644 --- a/python/ray/autoscaler/local/node_provider.py +++ b/python/ray/autoscaler/local/node_provider.py @@ -10,7 +10,8 @@ import socket import logging from ray.autoscaler.node_provider import NodeProvider -from ray.autoscaler.tags import TAG_RAY_NODE_TYPE +from ray.autoscaler.tags import TAG_RAY_NODE_TYPE, NODE_TYPE_WORKER, \ + NODE_TYPE_HEAD logger = logging.getLogger(__name__) @@ -29,8 +30,9 @@ class ClusterState(object): if os.path.exists(self.save_path): workers = json.loads(open(self.save_path).read()) head_config = workers.get(provider_config["head_ip"]) - if not head_config or head_config.get( - "tags", {}).get(TAG_RAY_NODE_TYPE) != "head": + if (not head_config or + head_config.get("tags", {}).get(TAG_RAY_NODE_TYPE) + != NODE_TYPE_HEAD): workers = {} logger.info("Head IP changed - recreating cluster.") else: @@ -41,23 +43,23 @@ class ClusterState(object): if worker_ip not in workers: workers[worker_ip] = { "tags": { - TAG_RAY_NODE_TYPE: "worker" + TAG_RAY_NODE_TYPE: NODE_TYPE_WORKER }, "state": "terminated", } else: assert workers[worker_ip]["tags"][ - TAG_RAY_NODE_TYPE] == "worker" + TAG_RAY_NODE_TYPE] == NODE_TYPE_WORKER if provider_config["head_ip"] not in workers: workers[provider_config["head_ip"]] = { "tags": { - TAG_RAY_NODE_TYPE: "head" + TAG_RAY_NODE_TYPE: NODE_TYPE_HEAD }, "state": "terminated", } else: assert workers[provider_config["head_ip"]]["tags"][ - TAG_RAY_NODE_TYPE] == "head" + TAG_RAY_NODE_TYPE] == NODE_TYPE_HEAD assert len(workers) == len(provider_config["worker_ips"]) + 1 with open(self.save_path, "w") as f: logger.debug("ClusterState: " diff --git a/python/ray/autoscaler/tags.py b/python/ray/autoscaler/tags.py index 1912d675b..d96f542ee 100644 --- a/python/ray/autoscaler/tags.py +++ b/python/ray/autoscaler/tags.py @@ -9,9 +9,17 @@ TAG_RAY_NODE_NAME = "ray-node-name" # Tag for the type of node (e.g. Head, Worker) TAG_RAY_NODE_TYPE = "ray-node-type" +NODE_TYPE_HEAD = "head" +NODE_TYPE_WORKER = "worker" # Tag that reports the current state of the node (e.g. Updating, Up-to-date) TAG_RAY_NODE_STATUS = "ray-node-status" +STATUS_UNINITIALIZED = "uninitialized" +STATUS_WAITING_FOR_SSH = "waiting-for-ssh" +STATUS_SYNCING_FILES = "syncing-files" +STATUS_SETTING_UP = "setting-up" +STATUS_UPDATE_FAILED = "update-failed" +STATUS_UP_TO_DATE = "up-to-date" # Tag uniquely identifying all nodes of a cluster TAG_RAY_CLUSTER_NAME = "ray-cluster-name" diff --git a/python/ray/autoscaler/updater.py b/python/ray/autoscaler/updater.py index 6402b9ecd..7e9934898 100644 --- a/python/ray/autoscaler/updater.py +++ b/python/ray/autoscaler/updater.py @@ -16,7 +16,9 @@ import time from threading import Thread from getpass import getuser -from ray.autoscaler.tags import TAG_RAY_NODE_STATUS, TAG_RAY_RUNTIME_CONFIG +from ray.autoscaler.tags import TAG_RAY_NODE_STATUS, TAG_RAY_RUNTIME_CONFIG, \ + STATUS_UP_TO_DATE, STATUS_UPDATE_FAILED, STATUS_WAITING_FOR_SSH, \ + STATUS_SETTING_UP, STATUS_SYNCING_FILES from ray.autoscaler.log_timer import LogTimer logger = logging.getLogger(__name__) @@ -56,7 +58,6 @@ class NodeUpdater(object): ray_start_commands, runtime_hash, process_runner=subprocess, - exit_on_update_fail=False, use_internal_ip=False): ssh_control_hash = hashlib.md5(cluster_name.encode()).hexdigest() @@ -82,7 +83,6 @@ class NodeUpdater(object): self.initialization_commands = initialization_commands self.setup_commands = setup_commands self.ray_start_commands = ray_start_commands - self.exit_on_update_fail = exit_on_update_fail self.runtime_hash = runtime_hash def get_node_ip(self): @@ -152,13 +152,13 @@ class NodeUpdater(object): logger.error("NodeUpdater: " "{}: Error updating {}".format( self.node_id, error_str)) - self.provider.set_node_tags(self.node_id, - {TAG_RAY_NODE_STATUS: "update-failed"}) + self.provider.set_node_tags( + self.node_id, {TAG_RAY_NODE_STATUS: STATUS_UPDATE_FAILED}) raise e self.provider.set_node_tags( self.node_id, { - TAG_RAY_NODE_STATUS: "up-to-date", + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, TAG_RAY_RUNTIME_CONFIG: self.runtime_hash }) @@ -213,8 +213,8 @@ class NodeUpdater(object): sync_cmd(local_path, remote_path, redirect=None) def do_update(self): - self.provider.set_node_tags(self.node_id, - {TAG_RAY_NODE_STATUS: "waiting-for-ssh"}) + self.provider.set_node_tags( + self.node_id, {TAG_RAY_NODE_STATUS: STATUS_WAITING_FOR_SSH}) deadline = time.time() + NODE_START_WAIT_S self.set_ssh_ip_if_required() @@ -230,27 +230,27 @@ class NodeUpdater(object): "NodeUpdater: {} already up-to-date, skip to ray start".format( self.node_id)) else: - self.provider.set_node_tags(self.node_id, - {TAG_RAY_NODE_STATUS: "syncing-files"}) + self.provider.set_node_tags( + self.node_id, {TAG_RAY_NODE_STATUS: STATUS_SYNCING_FILES}) self.sync_file_mounts(self.rsync_up) # Run init commands - self.provider.set_node_tags(self.node_id, - {TAG_RAY_NODE_STATUS: "setting-up"}) + self.provider.set_node_tags( + self.node_id, {TAG_RAY_NODE_STATUS: STATUS_SETTING_UP}) m = "{}: Initialization commands completed".format(self.node_id) with LogTimer("NodeUpdater: {}".format(m)): for cmd in self.initialization_commands: - self.ssh_cmd(cmd, exit_on_fail=self.exit_on_update_fail) + self.ssh_cmd(cmd) m = "{}: Setup commands completed".format(self.node_id) with LogTimer("NodeUpdater: {}".format(m)): for cmd in self.setup_commands: - self.ssh_cmd(cmd, exit_on_fail=self.exit_on_update_fail) + self.ssh_cmd(cmd) m = "{}: Ray start commands completed".format(self.node_id) with LogTimer("NodeUpdater: {}".format(m)): for cmd in self.ray_start_commands: - self.ssh_cmd(cmd, exit_on_fail=self.exit_on_update_fail) + self.ssh_cmd(cmd) def rsync_up(self, source, target, redirect=None): logger.info("NodeUpdater: " @@ -321,6 +321,8 @@ class NodeUpdater(object): stderr=redirect or sys.stderr) except subprocess.CalledProcessError: if exit_on_fail: + # Only reason we need this exit flag here is because here we + # know the final command and can print it nicely before exit() logger.error("Command failed: \n\n {}\n".format( " ".join(final_cmd))) sys.exit(1) diff --git a/python/ray/tests/test_autoscaler.py b/python/ray/tests/test_autoscaler.py index a6ebcaad3..bb1524c68 100644 --- a/python/ray/tests/test_autoscaler.py +++ b/python/ray/tests/test_autoscaler.py @@ -2,7 +2,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from flaky import flaky import shutil import tempfile import threading @@ -15,7 +14,8 @@ import ray import ray.services as services from ray.autoscaler.autoscaler import StandardAutoscaler, LoadMetrics, \ fillout_defaults, validate_config -from ray.autoscaler.tags import TAG_RAY_NODE_TYPE, TAG_RAY_NODE_STATUS +from ray.autoscaler.tags import TAG_RAY_NODE_TYPE, TAG_RAY_NODE_STATUS, \ + STATUS_UP_TO_DATE, STATUS_UPDATE_FAILED from ray.autoscaler.node_provider import NODE_PROVIDERS, NodeProvider from ray.tests.utils import RayTestTimeoutException import pytest @@ -47,22 +47,53 @@ class MockProcessRunner(object): raise Exception("Failing command on purpose") self.calls.append(cmd) + def assert_has_call(self, ip, pattern): + out = "" + for cmd in self.calls: + msg = " ".join(cmd) + if ip in msg: + out += msg + out += "\n" + if pattern in out: + return True + else: + raise Exception("Did not find [{}] in [{}] for {}".format( + pattern, out, ip)) + + def assert_not_has_call(self, ip, pattern): + out = "" + for cmd in self.calls: + msg = " ".join(cmd) + if ip in msg: + out += msg + out += "\n" + if pattern in out: + raise Exception("Found [{}] in [{}] for {}".format( + pattern, out, ip)) + else: + return True + + def clear_history(self): + self.calls = [] + class MockProvider(NodeProvider): - def __init__(self): + def __init__(self, cache_stopped=False): self.mock_nodes = {} self.next_id = 0 self.throw = False self.fail_creates = False self.ready_to_create = threading.Event() self.ready_to_create.set() + self.cache_stopped = cache_stopped def non_terminated_nodes(self, tag_filters): if self.throw: raise Exception("oops") return [ n.node_id for n in self.mock_nodes.values() - if n.matches(tag_filters) and n.state != "terminated" + if n.matches(tag_filters) + and n.state not in ["stopped", "terminated"] ] def non_terminated_node_ips(self, tag_filters): @@ -70,14 +101,15 @@ class MockProvider(NodeProvider): raise Exception("oops") return [ n.internal_ip for n in self.mock_nodes.values() - if n.matches(tag_filters) and n.state != "terminated" + if n.matches(tag_filters) + and n.state not in ["stopped", "terminated"] ] def is_running(self, node_id): return self.mock_nodes[node_id].state == "running" def is_terminated(self, node_id): - return self.mock_nodes[node_id].state == "terminated" + return self.mock_nodes[node_id].state in ["stopped", "terminated"] def node_tags(self, node_id): return self.mock_nodes[node_id].tags @@ -92,15 +124,29 @@ class MockProvider(NodeProvider): self.ready_to_create.wait() if self.fail_creates: return + if self.cache_stopped: + for node in self.mock_nodes.values(): + if node.state == "stopped" and count > 0: + count -= 1 + node.state = "pending" + node.tags.update(tags) for _ in range(count): - self.mock_nodes[self.next_id] = MockNode(self.next_id, tags) + self.mock_nodes[self.next_id] = MockNode(self.next_id, tags.copy()) self.next_id += 1 def set_node_tags(self, node_id, tags): self.mock_nodes[node_id].tags.update(tags) def terminate_node(self, node_id): - self.mock_nodes[node_id].state = "terminated" + if self.cache_stopped: + self.mock_nodes[node_id].state = "stopped" + else: + self.mock_nodes[node_id].state = "terminated" + + def finish_starting_nodes(self): + for node in self.mock_nodes.values(): + if node.state == "pending": + node.state = "running" SMALL_CLUSTER = { @@ -131,10 +177,10 @@ SMALL_CLUSTER = { "TestProp": 2, }, "file_mounts": {}, - "initialization_commands": ["cmd0"], - "setup_commands": ["cmd1"], - "head_setup_commands": ["cmd2"], - "worker_setup_commands": ["cmd3"], + "initialization_commands": ["init_cmd"], + "setup_commands": ["setup_cmd"], + "head_setup_commands": ["head_setup_cmd"], + "worker_setup_commands": ["worker_setup_cmd"], "head_start_ray_commands": ["start_ray_head"], "worker_start_ray_commands": ["start_ray_worker"], } @@ -344,11 +390,13 @@ class AutoscalingTest(unittest.TestCase): autoscaler.update() self.waitForNodes(0) autoscaler.request_resources({"CPU": cores_per_node * 10}) - for _ in range(3): # Maximum launch batch is 5 + for _ in range(5): # Maximum launch batch is 5 + time.sleep(0.01) autoscaler.update() self.waitForNodes(10) autoscaler.request_resources({"CPU": cores_per_node * 30}) for _ in range(4): # Maximum launch batch is 5 + time.sleep(0.01) autoscaler.update() self.waitForNodes(30) @@ -700,10 +748,10 @@ class AutoscalingTest(unittest.TestCase): autoscaler.update() autoscaler.update() self.waitForNodes(2) - for node in self.provider.mock_nodes.values(): - node.state = "running" + self.provider.finish_starting_nodes() autoscaler.update() - self.waitForNodes(2, tag_filters={TAG_RAY_NODE_STATUS: "up-to-date"}) + self.waitForNodes( + 2, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE}) def testReportsConfigFailures(self): config = copy.deepcopy(SMALL_CLUSTER) @@ -712,7 +760,7 @@ class AutoscalingTest(unittest.TestCase): config["provider"]["type"] = "mock" config_path = self.write_config(config) self.provider = MockProvider() - runner = MockProcessRunner(fail_cmds=["cmd1"]) + runner = MockProcessRunner(fail_cmds=["setup_cmd"]) autoscaler = StandardAutoscaler( config_path, LoadMetrics(), @@ -722,11 +770,10 @@ class AutoscalingTest(unittest.TestCase): autoscaler.update() autoscaler.update() self.waitForNodes(2) - for node in self.provider.mock_nodes.values(): - node.state = "running" + self.provider.finish_starting_nodes() autoscaler.update() self.waitForNodes( - 2, tag_filters={TAG_RAY_NODE_STATUS: "update-failed"}) + 2, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UPDATE_FAILED}) def testConfiguresOutdatedNodes(self): config_path = self.write_config(SMALL_CLUSTER) @@ -741,10 +788,10 @@ class AutoscalingTest(unittest.TestCase): autoscaler.update() autoscaler.update() self.waitForNodes(2) - for node in self.provider.mock_nodes.values(): - node.state = "running" + self.provider.finish_starting_nodes() autoscaler.update() - self.waitForNodes(2, tag_filters={TAG_RAY_NODE_STATUS: "up-to-date"}) + self.waitForNodes( + 2, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE}) runner.calls = [] new_config = SMALL_CLUSTER.copy() new_config["worker_setup_commands"] = ["cmdX", "cmdY"] @@ -843,7 +890,6 @@ class AutoscalingTest(unittest.TestCase): autoscaler.update() assert len(self.provider.non_terminated_nodes({})) == 0 - @flaky(max_runs=4) def testRecoverUnhealthyWorkers(self): config_path = self.write_config(SMALL_CLUSTER) self.provider = MockProvider() @@ -857,14 +903,19 @@ class AutoscalingTest(unittest.TestCase): update_interval_s=0) autoscaler.update() self.waitForNodes(2) - for node in self.provider.mock_nodes.values(): - node.state = "running" + self.provider.finish_starting_nodes() autoscaler.update() - self.waitForNodes(2, tag_filters={TAG_RAY_NODE_STATUS: "up-to-date"}) + self.waitForNodes( + 2, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE}) # Mark a node as unhealthy - lm.last_heartbeat_time_by_ip["172.0.0.0"] = 0 + for _ in range(5): + if autoscaler.updaters: + time.sleep(0.05) + autoscaler.update() + assert not autoscaler.updaters num_calls = len(runner.calls) + lm.last_heartbeat_time_by_ip["172.0.0.0"] = 0 autoscaler.update() self.waitFor(lambda: len(runner.calls) > num_calls, num_retries=150) @@ -901,6 +952,136 @@ class AutoscalingTest(unittest.TestCase): StandardAutoscaler( invalid_provider, LoadMetrics(), update_interval_s=0) + def testSetupCommandsWithNoNodeCaching(self): + config = SMALL_CLUSTER.copy() + config["min_workers"] = 1 + config["max_workers"] = 1 + config_path = self.write_config(config) + self.provider = MockProvider(cache_stopped=False) + runner = MockProcessRunner() + lm = LoadMetrics() + autoscaler = StandardAutoscaler( + config_path, + lm, + max_failures=0, + process_runner=runner, + update_interval_s=0) + autoscaler.update() + self.waitForNodes(1) + self.provider.finish_starting_nodes() + autoscaler.update() + self.waitForNodes( + 1, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE}) + runner.assert_has_call("172.0.0.0", "init_cmd") + runner.assert_has_call("172.0.0.0", "setup_cmd") + runner.assert_has_call("172.0.0.0", "worker_setup_cmd") + runner.assert_has_call("172.0.0.0", "start_ray_worker") + + # Check the node was not reused + self.provider.terminate_node(0) + autoscaler.update() + self.waitForNodes(1) + runner.clear_history() + self.provider.finish_starting_nodes() + autoscaler.update() + self.waitForNodes( + 1, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE}) + runner.assert_has_call("172.0.0.1", "init_cmd") + runner.assert_has_call("172.0.0.1", "setup_cmd") + runner.assert_has_call("172.0.0.1", "worker_setup_cmd") + runner.assert_has_call("172.0.0.1", "start_ray_worker") + + def testSetupCommandsWithStoppedNodeCaching(self): + config = SMALL_CLUSTER.copy() + config["min_workers"] = 1 + config["max_workers"] = 1 + config_path = self.write_config(config) + self.provider = MockProvider(cache_stopped=True) + runner = MockProcessRunner() + lm = LoadMetrics() + autoscaler = StandardAutoscaler( + config_path, + lm, + max_failures=0, + process_runner=runner, + update_interval_s=0) + autoscaler.update() + self.waitForNodes(1) + self.provider.finish_starting_nodes() + autoscaler.update() + self.waitForNodes( + 1, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE}) + runner.assert_has_call("172.0.0.0", "init_cmd") + runner.assert_has_call("172.0.0.0", "setup_cmd") + runner.assert_has_call("172.0.0.0", "worker_setup_cmd") + runner.assert_has_call("172.0.0.0", "start_ray_worker") + + # Check the node was indeed reused + self.provider.terminate_node(0) + autoscaler.update() + self.waitForNodes(1) + runner.clear_history() + self.provider.finish_starting_nodes() + autoscaler.update() + self.waitForNodes( + 1, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE}) + runner.assert_not_has_call("172.0.0.0", "init_cmd") + runner.assert_not_has_call("172.0.0.0", "setup_cmd") + runner.assert_not_has_call("172.0.0.0", "worker_setup_cmd") + runner.assert_has_call("172.0.0.0", "start_ray_worker") + + runner.clear_history() + autoscaler.update() + runner.assert_not_has_call("172.0.0.0", "setup_cmd") + + # We did not start any other nodes + runner.assert_not_has_call("172.0.0.1", " ") + + def testMultiNodeReuse(self): + config = SMALL_CLUSTER.copy() + config["min_workers"] = 3 + config["max_workers"] = 3 + config_path = self.write_config(config) + self.provider = MockProvider(cache_stopped=True) + runner = MockProcessRunner() + lm = LoadMetrics() + autoscaler = StandardAutoscaler( + config_path, + lm, + max_failures=0, + process_runner=runner, + update_interval_s=0) + autoscaler.update() + self.waitForNodes(3) + self.provider.finish_starting_nodes() + autoscaler.update() + self.waitForNodes( + 3, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE}) + + self.provider.terminate_node(0) + self.provider.terminate_node(1) + self.provider.terminate_node(2) + runner.clear_history() + + # Scale up to 10 nodes, check we reuse the first 3 and add 7 more. + config["min_workers"] = 10 + config["max_workers"] = 10 + self.write_config(config) + autoscaler.update() + autoscaler.update() + self.waitForNodes(10) + self.provider.finish_starting_nodes() + autoscaler.update() + self.waitForNodes( + 10, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE}) + autoscaler.update() + for i in [0, 1, 2]: + runner.assert_not_has_call("172.0.0.{}".format(i), "setup_cmd") + runner.assert_has_call("172.0.0.{}".format(i), "start_ray_worker") + for i in [3, 4, 5, 6, 7, 8, 9]: + runner.assert_has_call("172.0.0.{}".format(i), "setup_cmd") + runner.assert_has_call("172.0.0.{}".format(i), "start_ray_worker") + if __name__ == "__main__": unittest.main(verbosity=2)