[autoscaler] Add unit tests for stopped node caching, fix flaky tests (#5793)

This commit is contained in:
Eric Liang
2019-09-27 22:36:09 -07:00
committed by GitHub
parent 86610a30c9
commit 493364d3bd
6 changed files with 268 additions and 77 deletions
+14 -20
View File
@@ -22,7 +22,8 @@ from ray.autoscaler.node_provider import get_node_provider, \
get_default_config
from ray.autoscaler.tags import (TAG_RAY_LAUNCH_CONFIG, TAG_RAY_RUNTIME_CONFIG,
TAG_RAY_NODE_STATUS, TAG_RAY_NODE_TYPE,
TAG_RAY_NODE_NAME)
TAG_RAY_NODE_NAME, STATUS_UP_TO_DATE,
STATUS_UNINITIALIZED, NODE_TYPE_WORKER)
from ray.autoscaler.updater import NodeUpdaterThread
from ray.ray_constants import AUTOSCALER_MAX_NUM_FAILURES, \
AUTOSCALER_MAX_LAUNCH_BATCH, AUTOSCALER_MAX_CONCURRENT_LAUNCHES, \
@@ -315,7 +316,7 @@ class NodeLauncher(threading.Thread):
super(NodeLauncher, self).__init__(*args, **kwargs)
def _launch_node(self, config, count):
worker_filter = {TAG_RAY_NODE_TYPE: "worker"}
worker_filter = {TAG_RAY_NODE_TYPE: NODE_TYPE_WORKER}
before = self.provider.non_terminated_nodes(tag_filters=worker_filter)
launch_hash = hash_launch_conf(config["worker_nodes"], config["auth"])
self.log("Launching {} nodes.".format(count))
@@ -323,8 +324,8 @@ class NodeLauncher(threading.Thread):
config["worker_nodes"], {
TAG_RAY_NODE_NAME: "ray-{}-worker".format(
config["cluster_name"]),
TAG_RAY_NODE_TYPE: "worker",
TAG_RAY_NODE_STATUS: "uninitialized",
TAG_RAY_NODE_TYPE: NODE_TYPE_WORKER,
TAG_RAY_NODE_STATUS: STATUS_UNINITIALIZED,
TAG_RAY_LAUNCH_CONFIG: launch_hash,
}, count)
after = self.provider.non_terminated_nodes(tag_filters=worker_filter)
@@ -547,18 +548,10 @@ class StandardAutoscaler(object):
self.log_info_string(nodes, target_workers)
# Update nodes with out-of-date files
T = [
threading.Thread(
target=self.spawn_updater,
args=(node_id, commands, ray_start),
) for node_id, commands, ray_start in (self.should_update(node_id)
for node_id in nodes)
if node_id is not None
]
for t in T:
t.start()
for t in T:
t.join()
for node_id, commands, ray_start in (self.should_update(node_id)
for node_id in nodes):
if node_id is not None:
self.spawn_updater(node_id, commands, ray_start)
# Attempt to recover unhealthy nodes
for node_id in nodes:
@@ -664,10 +657,11 @@ class StandardAutoscaler(object):
def should_update(self, node_id):
if not self.can_update(node_id):
return (None, None, None)
return None, None, None # no update
if self.files_up_to_date(node_id):
return (None, None, None)
status = self.provider.node_tags(node_id).get(TAG_RAY_NODE_STATUS)
if status == STATUS_UP_TO_DATE and self.files_up_to_date(node_id):
return None, None, None # no update
successful_updated = self.num_successful_updates.get(node_id, 0) > 0
if successful_updated and self.config.get("restart_only", False):
@@ -718,7 +712,7 @@ class StandardAutoscaler(object):
def workers(self):
return self.provider.non_terminated_nodes(
tag_filters={TAG_RAY_NODE_TYPE: "worker"})
tag_filters={TAG_RAY_NODE_TYPE: NODE_TYPE_WORKER})
def log_info_string(self, nodes, target):
logger.info("StandardAutoscaler: {}".format(
+11 -7
View File
@@ -23,7 +23,7 @@ from ray.autoscaler.autoscaler import validate_config, hash_runtime_conf, \
hash_launch_conf, fillout_defaults
from ray.autoscaler.node_provider import get_node_provider, NODE_PROVIDERS
from ray.autoscaler.tags import TAG_RAY_NODE_TYPE, TAG_RAY_LAUNCH_CONFIG, \
TAG_RAY_NODE_NAME
TAG_RAY_NODE_NAME, NODE_TYPE_WORKER, NODE_TYPE_HEAD
from ray.autoscaler.updater import NodeUpdaterThread
from ray.autoscaler.log_timer import LogTimer
from ray.autoscaler.docker import with_docker_exec
@@ -91,13 +91,13 @@ def teardown_cluster(config_file, yes, workers_only, override_cluster_name):
else:
A = [
node_id for node_id in provider.non_terminated_nodes({
TAG_RAY_NODE_TYPE: "head"
TAG_RAY_NODE_TYPE: NODE_TYPE_HEAD
})
]
A += [
node_id for node_id in provider.non_terminated_nodes({
TAG_RAY_NODE_TYPE: "worker"
TAG_RAY_NODE_TYPE: NODE_TYPE_WORKER
})
]
return A
@@ -128,7 +128,9 @@ def kill_node(config_file, yes, hard, override_cluster_name):
provider = get_node_provider(config["provider"], config["cluster_name"])
try:
nodes = provider.non_terminated_nodes({TAG_RAY_NODE_TYPE: "worker"})
nodes = provider.non_terminated_nodes({
TAG_RAY_NODE_TYPE: NODE_TYPE_WORKER
})
node = random.choice(nodes)
logger.info("kill_node: Shutdown worker {}".format(node))
if hard:
@@ -174,7 +176,7 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
config_file = os.path.abspath(config_file)
try:
head_node_tags = {
TAG_RAY_NODE_TYPE: "head",
TAG_RAY_NODE_TYPE: NODE_TYPE_HEAD,
}
nodes = provider.non_terminated_nodes(head_node_tags)
if len(nodes) > 0:
@@ -506,7 +508,9 @@ def get_worker_node_ips(config_file, override_cluster_name):
provider = get_node_provider(config["provider"], config["cluster_name"])
try:
nodes = provider.non_terminated_nodes({TAG_RAY_NODE_TYPE: "worker"})
nodes = provider.non_terminated_nodes({
TAG_RAY_NODE_TYPE: NODE_TYPE_WORKER
})
if config.get("provider", {}).get("use_internal_ips", False) is True:
return [provider.internal_ip(node) for node in nodes]
@@ -523,7 +527,7 @@ def _get_head_node(config,
provider = get_node_provider(config["provider"], config["cluster_name"])
try:
head_node_tags = {
TAG_RAY_NODE_TYPE: "head",
TAG_RAY_NODE_TYPE: NODE_TYPE_HEAD,
}
nodes = provider.non_terminated_nodes(head_node_tags)
finally:
+9 -7
View File
@@ -10,7 +10,8 @@ import socket
import logging
from ray.autoscaler.node_provider import NodeProvider
from ray.autoscaler.tags import TAG_RAY_NODE_TYPE
from ray.autoscaler.tags import TAG_RAY_NODE_TYPE, NODE_TYPE_WORKER, \
NODE_TYPE_HEAD
logger = logging.getLogger(__name__)
@@ -29,8 +30,9 @@ class ClusterState(object):
if os.path.exists(self.save_path):
workers = json.loads(open(self.save_path).read())
head_config = workers.get(provider_config["head_ip"])
if not head_config or head_config.get(
"tags", {}).get(TAG_RAY_NODE_TYPE) != "head":
if (not head_config or
head_config.get("tags", {}).get(TAG_RAY_NODE_TYPE)
!= NODE_TYPE_HEAD):
workers = {}
logger.info("Head IP changed - recreating cluster.")
else:
@@ -41,23 +43,23 @@ class ClusterState(object):
if worker_ip not in workers:
workers[worker_ip] = {
"tags": {
TAG_RAY_NODE_TYPE: "worker"
TAG_RAY_NODE_TYPE: NODE_TYPE_WORKER
},
"state": "terminated",
}
else:
assert workers[worker_ip]["tags"][
TAG_RAY_NODE_TYPE] == "worker"
TAG_RAY_NODE_TYPE] == NODE_TYPE_WORKER
if provider_config["head_ip"] not in workers:
workers[provider_config["head_ip"]] = {
"tags": {
TAG_RAY_NODE_TYPE: "head"
TAG_RAY_NODE_TYPE: NODE_TYPE_HEAD
},
"state": "terminated",
}
else:
assert workers[provider_config["head_ip"]]["tags"][
TAG_RAY_NODE_TYPE] == "head"
TAG_RAY_NODE_TYPE] == NODE_TYPE_HEAD
assert len(workers) == len(provider_config["worker_ips"]) + 1
with open(self.save_path, "w") as f:
logger.debug("ClusterState: "
+8
View File
@@ -9,9 +9,17 @@ TAG_RAY_NODE_NAME = "ray-node-name"
# Tag for the type of node (e.g. Head, Worker)
TAG_RAY_NODE_TYPE = "ray-node-type"
# Values stored under TAG_RAY_NODE_TYPE; used both when tagging new nodes and
# when filtering providers' node listings by type.
NODE_TYPE_HEAD = "head"
NODE_TYPE_WORKER = "worker"

# Tag that reports the current state of the node (e.g. Updating, Up-to-date)
TAG_RAY_NODE_STATUS = "ray-node-status"
# Values stored under TAG_RAY_NODE_STATUS.
# Assigned by NodeLauncher when it creates a worker node.
STATUS_UNINITIALIZED = "uninitialized"
# Set at the very start of NodeUpdater.do_update().
STATUS_WAITING_FOR_SSH = "waiting-for-ssh"
# Set right before NodeUpdater syncs the file mounts to the node.
STATUS_SYNCING_FILES = "syncing-files"
# Set right before the initialization commands are run on the node.
STATUS_SETTING_UP = "setting-up"
# Set when the update raises; the error is re-raised afterwards.
STATUS_UPDATE_FAILED = "update-failed"
# Set together with the runtime config hash once the update succeeded.
STATUS_UP_TO_DATE = "up-to-date"
# Tag uniquely identifying all nodes of a cluster
TAG_RAY_CLUSTER_NAME = "ray-cluster-name"
+17 -15
View File
@@ -16,7 +16,9 @@ import time
from threading import Thread
from getpass import getuser
from ray.autoscaler.tags import TAG_RAY_NODE_STATUS, TAG_RAY_RUNTIME_CONFIG
from ray.autoscaler.tags import TAG_RAY_NODE_STATUS, TAG_RAY_RUNTIME_CONFIG, \
STATUS_UP_TO_DATE, STATUS_UPDATE_FAILED, STATUS_WAITING_FOR_SSH, \
STATUS_SETTING_UP, STATUS_SYNCING_FILES
from ray.autoscaler.log_timer import LogTimer
logger = logging.getLogger(__name__)
@@ -56,7 +58,6 @@ class NodeUpdater(object):
ray_start_commands,
runtime_hash,
process_runner=subprocess,
exit_on_update_fail=False,
use_internal_ip=False):
ssh_control_hash = hashlib.md5(cluster_name.encode()).hexdigest()
@@ -82,7 +83,6 @@ class NodeUpdater(object):
self.initialization_commands = initialization_commands
self.setup_commands = setup_commands
self.ray_start_commands = ray_start_commands
self.exit_on_update_fail = exit_on_update_fail
self.runtime_hash = runtime_hash
def get_node_ip(self):
@@ -152,13 +152,13 @@ class NodeUpdater(object):
logger.error("NodeUpdater: "
"{}: Error updating {}".format(
self.node_id, error_str))
self.provider.set_node_tags(self.node_id,
{TAG_RAY_NODE_STATUS: "update-failed"})
self.provider.set_node_tags(
self.node_id, {TAG_RAY_NODE_STATUS: STATUS_UPDATE_FAILED})
raise e
self.provider.set_node_tags(
self.node_id, {
TAG_RAY_NODE_STATUS: "up-to-date",
TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE,
TAG_RAY_RUNTIME_CONFIG: self.runtime_hash
})
@@ -213,8 +213,8 @@ class NodeUpdater(object):
sync_cmd(local_path, remote_path, redirect=None)
def do_update(self):
self.provider.set_node_tags(self.node_id,
{TAG_RAY_NODE_STATUS: "waiting-for-ssh"})
self.provider.set_node_tags(
self.node_id, {TAG_RAY_NODE_STATUS: STATUS_WAITING_FOR_SSH})
deadline = time.time() + NODE_START_WAIT_S
self.set_ssh_ip_if_required()
@@ -230,27 +230,27 @@ class NodeUpdater(object):
"NodeUpdater: {} already up-to-date, skip to ray start".format(
self.node_id))
else:
self.provider.set_node_tags(self.node_id,
{TAG_RAY_NODE_STATUS: "syncing-files"})
self.provider.set_node_tags(
self.node_id, {TAG_RAY_NODE_STATUS: STATUS_SYNCING_FILES})
self.sync_file_mounts(self.rsync_up)
# Run init commands
self.provider.set_node_tags(self.node_id,
{TAG_RAY_NODE_STATUS: "setting-up"})
self.provider.set_node_tags(
self.node_id, {TAG_RAY_NODE_STATUS: STATUS_SETTING_UP})
m = "{}: Initialization commands completed".format(self.node_id)
with LogTimer("NodeUpdater: {}".format(m)):
for cmd in self.initialization_commands:
self.ssh_cmd(cmd, exit_on_fail=self.exit_on_update_fail)
self.ssh_cmd(cmd)
m = "{}: Setup commands completed".format(self.node_id)
with LogTimer("NodeUpdater: {}".format(m)):
for cmd in self.setup_commands:
self.ssh_cmd(cmd, exit_on_fail=self.exit_on_update_fail)
self.ssh_cmd(cmd)
m = "{}: Ray start commands completed".format(self.node_id)
with LogTimer("NodeUpdater: {}".format(m)):
for cmd in self.ray_start_commands:
self.ssh_cmd(cmd, exit_on_fail=self.exit_on_update_fail)
self.ssh_cmd(cmd)
def rsync_up(self, source, target, redirect=None):
logger.info("NodeUpdater: "
@@ -321,6 +321,8 @@ class NodeUpdater(object):
stderr=redirect or sys.stderr)
except subprocess.CalledProcessError:
if exit_on_fail:
# Only reason we need this exit flag here is because here we
# know the final command and can print it nicely before exit()
logger.error("Command failed: \n\n {}\n".format(
" ".join(final_cmd)))
sys.exit(1)
+209 -28
View File
@@ -2,7 +2,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from flaky import flaky
import shutil
import tempfile
import threading
@@ -15,7 +14,8 @@ import ray
import ray.services as services
from ray.autoscaler.autoscaler import StandardAutoscaler, LoadMetrics, \
fillout_defaults, validate_config
from ray.autoscaler.tags import TAG_RAY_NODE_TYPE, TAG_RAY_NODE_STATUS
from ray.autoscaler.tags import TAG_RAY_NODE_TYPE, TAG_RAY_NODE_STATUS, \
STATUS_UP_TO_DATE, STATUS_UPDATE_FAILED
from ray.autoscaler.node_provider import NODE_PROVIDERS, NodeProvider
from ray.tests.utils import RayTestTimeoutException
import pytest
@@ -47,22 +47,53 @@ class MockProcessRunner(object):
raise Exception("Failing command on purpose")
self.calls.append(cmd)
def assert_has_call(self, ip, pattern):
    """Assert that the commands recorded for `ip` contain `pattern`.

    Every recorded command is joined with single spaces; the ones whose
    joined text contains `ip` are concatenated, each followed by a newline,
    and `pattern` is searched for in that transcript.

    Returns:
        True when `pattern` occurs in the transcript.

    Raises:
        Exception: when `pattern` does not occur; the message carries the
            pattern, the transcript and the ip.
    """
    rendered = [" ".join(cmd) for cmd in self.calls]
    transcript = "".join(line + "\n" for line in rendered if ip in line)
    if pattern not in transcript:
        raise Exception("Did not find [{}] in [{}] for {}".format(
            pattern, transcript, ip))
    return True
def assert_not_has_call(self, ip, pattern):
    """Assert that no command recorded for `ip` contains `pattern`.

    The transcript is built from the recorded commands: each one is joined
    with single spaces and, if the joined text contains `ip`, appended with
    a trailing newline. `pattern` must not occur anywhere in it.

    Returns:
        True when `pattern` is absent from the transcript.

    Raises:
        Exception: when `pattern` is found; the message carries the
            pattern, the transcript and the ip.
    """
    transcript = ""
    for argv in self.calls:
        line = " ".join(argv)
        if ip not in line:
            continue
        transcript += line + "\n"
    if pattern not in transcript:
        return True
    raise Exception("Found [{}] in [{}] for {}".format(
        pattern, transcript, ip))
def clear_history(self):
    """Forget all recorded commands by rebinding `calls` to a new list."""
    # Rebind instead of clearing in place so a previously handed-out
    # reference to the old list keeps its contents.
    self.calls = list()
class MockProvider(NodeProvider):
def __init__(self):
def __init__(self, cache_stopped=False):
self.mock_nodes = {}
self.next_id = 0
self.throw = False
self.fail_creates = False
self.ready_to_create = threading.Event()
self.ready_to_create.set()
self.cache_stopped = cache_stopped
def non_terminated_nodes(self, tag_filters):
if self.throw:
raise Exception("oops")
return [
n.node_id for n in self.mock_nodes.values()
if n.matches(tag_filters) and n.state != "terminated"
if n.matches(tag_filters)
and n.state not in ["stopped", "terminated"]
]
def non_terminated_node_ips(self, tag_filters):
@@ -70,14 +101,15 @@ class MockProvider(NodeProvider):
raise Exception("oops")
return [
n.internal_ip for n in self.mock_nodes.values()
if n.matches(tag_filters) and n.state != "terminated"
if n.matches(tag_filters)
and n.state not in ["stopped", "terminated"]
]
def is_running(self, node_id):
return self.mock_nodes[node_id].state == "running"
def is_terminated(self, node_id):
return self.mock_nodes[node_id].state == "terminated"
return self.mock_nodes[node_id].state in ["stopped", "terminated"]
def node_tags(self, node_id):
return self.mock_nodes[node_id].tags
@@ -92,15 +124,29 @@ class MockProvider(NodeProvider):
self.ready_to_create.wait()
if self.fail_creates:
return
if self.cache_stopped:
for node in self.mock_nodes.values():
if node.state == "stopped" and count > 0:
count -= 1
node.state = "pending"
node.tags.update(tags)
for _ in range(count):
self.mock_nodes[self.next_id] = MockNode(self.next_id, tags)
self.mock_nodes[self.next_id] = MockNode(self.next_id, tags.copy())
self.next_id += 1
def set_node_tags(self, node_id, tags):
self.mock_nodes[node_id].tags.update(tags)
def terminate_node(self, node_id):
self.mock_nodes[node_id].state = "terminated"
if self.cache_stopped:
self.mock_nodes[node_id].state = "stopped"
else:
self.mock_nodes[node_id].state = "terminated"
def finish_starting_nodes(self):
    """Move every mock node whose state is "pending" to "running".

    Nodes in any other state are left untouched.
    """
    still_pending = (candidate for candidate in self.mock_nodes.values()
                     if candidate.state == "pending")
    for mock_node in still_pending:
        mock_node.state = "running"
SMALL_CLUSTER = {
@@ -131,10 +177,10 @@ SMALL_CLUSTER = {
"TestProp": 2,
},
"file_mounts": {},
"initialization_commands": ["cmd0"],
"setup_commands": ["cmd1"],
"head_setup_commands": ["cmd2"],
"worker_setup_commands": ["cmd3"],
"initialization_commands": ["init_cmd"],
"setup_commands": ["setup_cmd"],
"head_setup_commands": ["head_setup_cmd"],
"worker_setup_commands": ["worker_setup_cmd"],
"head_start_ray_commands": ["start_ray_head"],
"worker_start_ray_commands": ["start_ray_worker"],
}
@@ -344,11 +390,13 @@ class AutoscalingTest(unittest.TestCase):
autoscaler.update()
self.waitForNodes(0)
autoscaler.request_resources({"CPU": cores_per_node * 10})
for _ in range(3): # Maximum launch batch is 5
for _ in range(5): # Maximum launch batch is 5
time.sleep(0.01)
autoscaler.update()
self.waitForNodes(10)
autoscaler.request_resources({"CPU": cores_per_node * 30})
for _ in range(4): # Maximum launch batch is 5
time.sleep(0.01)
autoscaler.update()
self.waitForNodes(30)
@@ -700,10 +748,10 @@ class AutoscalingTest(unittest.TestCase):
autoscaler.update()
autoscaler.update()
self.waitForNodes(2)
for node in self.provider.mock_nodes.values():
node.state = "running"
self.provider.finish_starting_nodes()
autoscaler.update()
self.waitForNodes(2, tag_filters={TAG_RAY_NODE_STATUS: "up-to-date"})
self.waitForNodes(
2, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE})
def testReportsConfigFailures(self):
config = copy.deepcopy(SMALL_CLUSTER)
@@ -712,7 +760,7 @@ class AutoscalingTest(unittest.TestCase):
config["provider"]["type"] = "mock"
config_path = self.write_config(config)
self.provider = MockProvider()
runner = MockProcessRunner(fail_cmds=["cmd1"])
runner = MockProcessRunner(fail_cmds=["setup_cmd"])
autoscaler = StandardAutoscaler(
config_path,
LoadMetrics(),
@@ -722,11 +770,10 @@ class AutoscalingTest(unittest.TestCase):
autoscaler.update()
autoscaler.update()
self.waitForNodes(2)
for node in self.provider.mock_nodes.values():
node.state = "running"
self.provider.finish_starting_nodes()
autoscaler.update()
self.waitForNodes(
2, tag_filters={TAG_RAY_NODE_STATUS: "update-failed"})
2, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UPDATE_FAILED})
def testConfiguresOutdatedNodes(self):
config_path = self.write_config(SMALL_CLUSTER)
@@ -741,10 +788,10 @@ class AutoscalingTest(unittest.TestCase):
autoscaler.update()
autoscaler.update()
self.waitForNodes(2)
for node in self.provider.mock_nodes.values():
node.state = "running"
self.provider.finish_starting_nodes()
autoscaler.update()
self.waitForNodes(2, tag_filters={TAG_RAY_NODE_STATUS: "up-to-date"})
self.waitForNodes(
2, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE})
runner.calls = []
new_config = SMALL_CLUSTER.copy()
new_config["worker_setup_commands"] = ["cmdX", "cmdY"]
@@ -843,7 +890,6 @@ class AutoscalingTest(unittest.TestCase):
autoscaler.update()
assert len(self.provider.non_terminated_nodes({})) == 0
@flaky(max_runs=4)
def testRecoverUnhealthyWorkers(self):
config_path = self.write_config(SMALL_CLUSTER)
self.provider = MockProvider()
@@ -857,14 +903,19 @@ class AutoscalingTest(unittest.TestCase):
update_interval_s=0)
autoscaler.update()
self.waitForNodes(2)
for node in self.provider.mock_nodes.values():
node.state = "running"
self.provider.finish_starting_nodes()
autoscaler.update()
self.waitForNodes(2, tag_filters={TAG_RAY_NODE_STATUS: "up-to-date"})
self.waitForNodes(
2, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE})
# Mark a node as unhealthy
lm.last_heartbeat_time_by_ip["172.0.0.0"] = 0
for _ in range(5):
if autoscaler.updaters:
time.sleep(0.05)
autoscaler.update()
assert not autoscaler.updaters
num_calls = len(runner.calls)
lm.last_heartbeat_time_by_ip["172.0.0.0"] = 0
autoscaler.update()
self.waitFor(lambda: len(runner.calls) > num_calls, num_retries=150)
@@ -901,6 +952,136 @@ class AutoscalingTest(unittest.TestCase):
StandardAutoscaler(
invalid_provider, LoadMetrics(), update_interval_s=0)
def testSetupCommandsWithNoNodeCaching(self):
    """A replacement worker gets the full command sequence without caching.

    Uses a single-worker cluster on a MockProvider created with
    cache_stopped=False. After the first worker is terminated, the commands
    recorded for 172.0.0.1 must include the init, setup, worker-setup and
    ray-start commands again.
    """
    full_sequence = ("init_cmd", "setup_cmd", "worker_setup_cmd",
                     "start_ray_worker")

    config = dict(SMALL_CLUSTER, min_workers=1, max_workers=1)
    config_path = self.write_config(config)
    self.provider = MockProvider(cache_stopped=False)
    runner = MockProcessRunner()
    autoscaler = StandardAutoscaler(
        config_path,
        LoadMetrics(),
        max_failures=0,
        process_runner=runner,
        update_interval_s=0)

    # Launch the first worker and let it become up to date.
    autoscaler.update()
    self.waitForNodes(1)
    self.provider.finish_starting_nodes()
    autoscaler.update()
    self.waitForNodes(
        1, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE})
    for expected in full_sequence:
        runner.assert_has_call("172.0.0.0", expected)

    # Check the node was not reused
    self.provider.terminate_node(0)
    autoscaler.update()
    self.waitForNodes(1)
    # Drop the first worker's commands so only the replacement's count.
    runner.clear_history()
    self.provider.finish_starting_nodes()
    autoscaler.update()
    self.waitForNodes(
        1, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE})
    for expected in full_sequence:
        runner.assert_has_call("172.0.0.1", expected)
def testSetupCommandsWithStoppedNodeCaching(self):
    """A reused stopped worker only gets the ray start command.

    Uses a single-worker cluster on a MockProvider created with
    cache_stopped=True. After the worker is terminated and the autoscaler
    brings capacity back, the commands recorded for 172.0.0.0 must contain
    the ray start command but none of the init/setup commands, and nothing
    at all may have been recorded for 172.0.0.1.
    """
    config = dict(SMALL_CLUSTER, min_workers=1, max_workers=1)
    config_path = self.write_config(config)
    self.provider = MockProvider(cache_stopped=True)
    runner = MockProcessRunner()
    autoscaler = StandardAutoscaler(
        config_path,
        LoadMetrics(),
        max_failures=0,
        process_runner=runner,
        update_interval_s=0)

    # Launch the first worker and let it become up to date.
    autoscaler.update()
    self.waitForNodes(1)
    self.provider.finish_starting_nodes()
    autoscaler.update()
    self.waitForNodes(
        1, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE})
    for expected in ("init_cmd", "setup_cmd", "worker_setup_cmd",
                     "start_ray_worker"):
        runner.assert_has_call("172.0.0.0", expected)

    # Check the node was indeed reused
    self.provider.terminate_node(0)
    autoscaler.update()
    self.waitForNodes(1)
    runner.clear_history()
    self.provider.finish_starting_nodes()
    autoscaler.update()
    self.waitForNodes(
        1, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE})
    for skipped in ("init_cmd", "setup_cmd", "worker_setup_cmd"):
        runner.assert_not_has_call("172.0.0.0", skipped)
    runner.assert_has_call("172.0.0.0", "start_ray_worker")

    # One more update round must not trigger setup on the reused node.
    runner.clear_history()
    autoscaler.update()
    runner.assert_not_has_call("172.0.0.0", "setup_cmd")

    # We did not start any other nodes
    runner.assert_not_has_call("172.0.0.1", " ")
def testMultiNodeReuse(self):
    """Scaling from 3 to 10 workers reuses the 3 stopped nodes.

    On a MockProvider created with cache_stopped=True, three workers are
    brought up and terminated. After the config is rewritten to require ten
    workers, 172.0.0.0-2 must show the ray start command without the setup
    command, while 172.0.0.3-9 must show both.
    """
    config = dict(SMALL_CLUSTER, min_workers=3, max_workers=3)
    config_path = self.write_config(config)
    self.provider = MockProvider(cache_stopped=True)
    runner = MockProcessRunner()
    autoscaler = StandardAutoscaler(
        config_path,
        LoadMetrics(),
        max_failures=0,
        process_runner=runner,
        update_interval_s=0)

    # Bring up the initial three workers.
    autoscaler.update()
    self.waitForNodes(3)
    self.provider.finish_starting_nodes()
    autoscaler.update()
    self.waitForNodes(
        3, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE})

    for node_id in range(3):
        self.provider.terminate_node(node_id)
    runner.clear_history()

    # Scale up to 10 nodes, check we reuse the first 3 and add 7 more.
    config["min_workers"] = 10
    config["max_workers"] = 10
    self.write_config(config)
    autoscaler.update()
    autoscaler.update()
    self.waitForNodes(10)
    self.provider.finish_starting_nodes()
    autoscaler.update()
    self.waitForNodes(
        10, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE})
    autoscaler.update()

    for reused in range(3):
        reused_ip = "172.0.0.{}".format(reused)
        runner.assert_not_has_call(reused_ip, "setup_cmd")
        runner.assert_has_call(reused_ip, "start_ray_worker")
    for fresh in range(3, 10):
        fresh_ip = "172.0.0.{}".format(fresh)
        runner.assert_has_call(fresh_ip, "setup_cmd")
        runner.assert_has_call(fresh_ip, "start_ray_worker")
if __name__ == "__main__":
unittest.main(verbosity=2)