mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 04:19:23 +08:00
[autoscaler] Add unit tests for stopped node caching, fix flaky tests (#5793)
This commit is contained in:
@@ -22,7 +22,8 @@ from ray.autoscaler.node_provider import get_node_provider, \
|
||||
get_default_config
|
||||
from ray.autoscaler.tags import (TAG_RAY_LAUNCH_CONFIG, TAG_RAY_RUNTIME_CONFIG,
|
||||
TAG_RAY_NODE_STATUS, TAG_RAY_NODE_TYPE,
|
||||
TAG_RAY_NODE_NAME)
|
||||
TAG_RAY_NODE_NAME, STATUS_UP_TO_DATE,
|
||||
STATUS_UNINITIALIZED, NODE_TYPE_WORKER)
|
||||
from ray.autoscaler.updater import NodeUpdaterThread
|
||||
from ray.ray_constants import AUTOSCALER_MAX_NUM_FAILURES, \
|
||||
AUTOSCALER_MAX_LAUNCH_BATCH, AUTOSCALER_MAX_CONCURRENT_LAUNCHES, \
|
||||
@@ -315,7 +316,7 @@ class NodeLauncher(threading.Thread):
|
||||
super(NodeLauncher, self).__init__(*args, **kwargs)
|
||||
|
||||
def _launch_node(self, config, count):
|
||||
worker_filter = {TAG_RAY_NODE_TYPE: "worker"}
|
||||
worker_filter = {TAG_RAY_NODE_TYPE: NODE_TYPE_WORKER}
|
||||
before = self.provider.non_terminated_nodes(tag_filters=worker_filter)
|
||||
launch_hash = hash_launch_conf(config["worker_nodes"], config["auth"])
|
||||
self.log("Launching {} nodes.".format(count))
|
||||
@@ -323,8 +324,8 @@ class NodeLauncher(threading.Thread):
|
||||
config["worker_nodes"], {
|
||||
TAG_RAY_NODE_NAME: "ray-{}-worker".format(
|
||||
config["cluster_name"]),
|
||||
TAG_RAY_NODE_TYPE: "worker",
|
||||
TAG_RAY_NODE_STATUS: "uninitialized",
|
||||
TAG_RAY_NODE_TYPE: NODE_TYPE_WORKER,
|
||||
TAG_RAY_NODE_STATUS: STATUS_UNINITIALIZED,
|
||||
TAG_RAY_LAUNCH_CONFIG: launch_hash,
|
||||
}, count)
|
||||
after = self.provider.non_terminated_nodes(tag_filters=worker_filter)
|
||||
@@ -547,18 +548,10 @@ class StandardAutoscaler(object):
|
||||
self.log_info_string(nodes, target_workers)
|
||||
|
||||
# Update nodes with out-of-date files
|
||||
T = [
|
||||
threading.Thread(
|
||||
target=self.spawn_updater,
|
||||
args=(node_id, commands, ray_start),
|
||||
) for node_id, commands, ray_start in (self.should_update(node_id)
|
||||
for node_id in nodes)
|
||||
if node_id is not None
|
||||
]
|
||||
for t in T:
|
||||
t.start()
|
||||
for t in T:
|
||||
t.join()
|
||||
for node_id, commands, ray_start in (self.should_update(node_id)
|
||||
for node_id in nodes):
|
||||
if node_id is not None:
|
||||
self.spawn_updater(node_id, commands, ray_start)
|
||||
|
||||
# Attempt to recover unhealthy nodes
|
||||
for node_id in nodes:
|
||||
@@ -664,10 +657,11 @@ class StandardAutoscaler(object):
|
||||
|
||||
def should_update(self, node_id):
|
||||
if not self.can_update(node_id):
|
||||
return (None, None, None)
|
||||
return None, None, None # no update
|
||||
|
||||
if self.files_up_to_date(node_id):
|
||||
return (None, None, None)
|
||||
status = self.provider.node_tags(node_id).get(TAG_RAY_NODE_STATUS)
|
||||
if status == STATUS_UP_TO_DATE and self.files_up_to_date(node_id):
|
||||
return None, None, None # no update
|
||||
|
||||
successful_updated = self.num_successful_updates.get(node_id, 0) > 0
|
||||
if successful_updated and self.config.get("restart_only", False):
|
||||
@@ -718,7 +712,7 @@ class StandardAutoscaler(object):
|
||||
|
||||
def workers(self):
|
||||
return self.provider.non_terminated_nodes(
|
||||
tag_filters={TAG_RAY_NODE_TYPE: "worker"})
|
||||
tag_filters={TAG_RAY_NODE_TYPE: NODE_TYPE_WORKER})
|
||||
|
||||
def log_info_string(self, nodes, target):
|
||||
logger.info("StandardAutoscaler: {}".format(
|
||||
|
||||
@@ -23,7 +23,7 @@ from ray.autoscaler.autoscaler import validate_config, hash_runtime_conf, \
|
||||
hash_launch_conf, fillout_defaults
|
||||
from ray.autoscaler.node_provider import get_node_provider, NODE_PROVIDERS
|
||||
from ray.autoscaler.tags import TAG_RAY_NODE_TYPE, TAG_RAY_LAUNCH_CONFIG, \
|
||||
TAG_RAY_NODE_NAME
|
||||
TAG_RAY_NODE_NAME, NODE_TYPE_WORKER, NODE_TYPE_HEAD
|
||||
from ray.autoscaler.updater import NodeUpdaterThread
|
||||
from ray.autoscaler.log_timer import LogTimer
|
||||
from ray.autoscaler.docker import with_docker_exec
|
||||
@@ -91,13 +91,13 @@ def teardown_cluster(config_file, yes, workers_only, override_cluster_name):
|
||||
else:
|
||||
A = [
|
||||
node_id for node_id in provider.non_terminated_nodes({
|
||||
TAG_RAY_NODE_TYPE: "head"
|
||||
TAG_RAY_NODE_TYPE: NODE_TYPE_HEAD
|
||||
})
|
||||
]
|
||||
|
||||
A += [
|
||||
node_id for node_id in provider.non_terminated_nodes({
|
||||
TAG_RAY_NODE_TYPE: "worker"
|
||||
TAG_RAY_NODE_TYPE: NODE_TYPE_WORKER
|
||||
})
|
||||
]
|
||||
return A
|
||||
@@ -128,7 +128,9 @@ def kill_node(config_file, yes, hard, override_cluster_name):
|
||||
|
||||
provider = get_node_provider(config["provider"], config["cluster_name"])
|
||||
try:
|
||||
nodes = provider.non_terminated_nodes({TAG_RAY_NODE_TYPE: "worker"})
|
||||
nodes = provider.non_terminated_nodes({
|
||||
TAG_RAY_NODE_TYPE: NODE_TYPE_WORKER
|
||||
})
|
||||
node = random.choice(nodes)
|
||||
logger.info("kill_node: Shutdown worker {}".format(node))
|
||||
if hard:
|
||||
@@ -174,7 +176,7 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
|
||||
config_file = os.path.abspath(config_file)
|
||||
try:
|
||||
head_node_tags = {
|
||||
TAG_RAY_NODE_TYPE: "head",
|
||||
TAG_RAY_NODE_TYPE: NODE_TYPE_HEAD,
|
||||
}
|
||||
nodes = provider.non_terminated_nodes(head_node_tags)
|
||||
if len(nodes) > 0:
|
||||
@@ -506,7 +508,9 @@ def get_worker_node_ips(config_file, override_cluster_name):
|
||||
|
||||
provider = get_node_provider(config["provider"], config["cluster_name"])
|
||||
try:
|
||||
nodes = provider.non_terminated_nodes({TAG_RAY_NODE_TYPE: "worker"})
|
||||
nodes = provider.non_terminated_nodes({
|
||||
TAG_RAY_NODE_TYPE: NODE_TYPE_WORKER
|
||||
})
|
||||
|
||||
if config.get("provider", {}).get("use_internal_ips", False) is True:
|
||||
return [provider.internal_ip(node) for node in nodes]
|
||||
@@ -523,7 +527,7 @@ def _get_head_node(config,
|
||||
provider = get_node_provider(config["provider"], config["cluster_name"])
|
||||
try:
|
||||
head_node_tags = {
|
||||
TAG_RAY_NODE_TYPE: "head",
|
||||
TAG_RAY_NODE_TYPE: NODE_TYPE_HEAD,
|
||||
}
|
||||
nodes = provider.non_terminated_nodes(head_node_tags)
|
||||
finally:
|
||||
|
||||
@@ -10,7 +10,8 @@ import socket
|
||||
import logging
|
||||
|
||||
from ray.autoscaler.node_provider import NodeProvider
|
||||
from ray.autoscaler.tags import TAG_RAY_NODE_TYPE
|
||||
from ray.autoscaler.tags import TAG_RAY_NODE_TYPE, NODE_TYPE_WORKER, \
|
||||
NODE_TYPE_HEAD
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -29,8 +30,9 @@ class ClusterState(object):
|
||||
if os.path.exists(self.save_path):
|
||||
workers = json.loads(open(self.save_path).read())
|
||||
head_config = workers.get(provider_config["head_ip"])
|
||||
if not head_config or head_config.get(
|
||||
"tags", {}).get(TAG_RAY_NODE_TYPE) != "head":
|
||||
if (not head_config or
|
||||
head_config.get("tags", {}).get(TAG_RAY_NODE_TYPE)
|
||||
!= NODE_TYPE_HEAD):
|
||||
workers = {}
|
||||
logger.info("Head IP changed - recreating cluster.")
|
||||
else:
|
||||
@@ -41,23 +43,23 @@ class ClusterState(object):
|
||||
if worker_ip not in workers:
|
||||
workers[worker_ip] = {
|
||||
"tags": {
|
||||
TAG_RAY_NODE_TYPE: "worker"
|
||||
TAG_RAY_NODE_TYPE: NODE_TYPE_WORKER
|
||||
},
|
||||
"state": "terminated",
|
||||
}
|
||||
else:
|
||||
assert workers[worker_ip]["tags"][
|
||||
TAG_RAY_NODE_TYPE] == "worker"
|
||||
TAG_RAY_NODE_TYPE] == NODE_TYPE_WORKER
|
||||
if provider_config["head_ip"] not in workers:
|
||||
workers[provider_config["head_ip"]] = {
|
||||
"tags": {
|
||||
TAG_RAY_NODE_TYPE: "head"
|
||||
TAG_RAY_NODE_TYPE: NODE_TYPE_HEAD
|
||||
},
|
||||
"state": "terminated",
|
||||
}
|
||||
else:
|
||||
assert workers[provider_config["head_ip"]]["tags"][
|
||||
TAG_RAY_NODE_TYPE] == "head"
|
||||
TAG_RAY_NODE_TYPE] == NODE_TYPE_HEAD
|
||||
assert len(workers) == len(provider_config["worker_ips"]) + 1
|
||||
with open(self.save_path, "w") as f:
|
||||
logger.debug("ClusterState: "
|
||||
|
||||
@@ -9,9 +9,17 @@ TAG_RAY_NODE_NAME = "ray-node-name"
|
||||
|
||||
# Tag for the type of node (e.g. Head, Worker)
TAG_RAY_NODE_TYPE = "ray-node-type"
# Values stored under TAG_RAY_NODE_TYPE.
NODE_TYPE_HEAD = "head"
NODE_TYPE_WORKER = "worker"

# Tag that reports the current state of the node (e.g. Updating, Up-to-date)
TAG_RAY_NODE_STATUS = "ray-node-status"
# Values stored under TAG_RAY_NODE_STATUS.
STATUS_UNINITIALIZED = "uninitialized"
STATUS_WAITING_FOR_SSH = "waiting-for-ssh"
STATUS_SYNCING_FILES = "syncing-files"
STATUS_SETTING_UP = "setting-up"
STATUS_UPDATE_FAILED = "update-failed"
STATUS_UP_TO_DATE = "up-to-date"

# Tag uniquely identifying all nodes of a cluster
TAG_RAY_CLUSTER_NAME = "ray-cluster-name"
|
||||
|
||||
@@ -16,7 +16,9 @@ import time
|
||||
from threading import Thread
|
||||
from getpass import getuser
|
||||
|
||||
from ray.autoscaler.tags import TAG_RAY_NODE_STATUS, TAG_RAY_RUNTIME_CONFIG
|
||||
from ray.autoscaler.tags import TAG_RAY_NODE_STATUS, TAG_RAY_RUNTIME_CONFIG, \
|
||||
STATUS_UP_TO_DATE, STATUS_UPDATE_FAILED, STATUS_WAITING_FOR_SSH, \
|
||||
STATUS_SETTING_UP, STATUS_SYNCING_FILES
|
||||
from ray.autoscaler.log_timer import LogTimer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -56,7 +58,6 @@ class NodeUpdater(object):
|
||||
ray_start_commands,
|
||||
runtime_hash,
|
||||
process_runner=subprocess,
|
||||
exit_on_update_fail=False,
|
||||
use_internal_ip=False):
|
||||
|
||||
ssh_control_hash = hashlib.md5(cluster_name.encode()).hexdigest()
|
||||
@@ -82,7 +83,6 @@ class NodeUpdater(object):
|
||||
self.initialization_commands = initialization_commands
|
||||
self.setup_commands = setup_commands
|
||||
self.ray_start_commands = ray_start_commands
|
||||
self.exit_on_update_fail = exit_on_update_fail
|
||||
self.runtime_hash = runtime_hash
|
||||
|
||||
def get_node_ip(self):
|
||||
@@ -152,13 +152,13 @@ class NodeUpdater(object):
|
||||
logger.error("NodeUpdater: "
|
||||
"{}: Error updating {}".format(
|
||||
self.node_id, error_str))
|
||||
self.provider.set_node_tags(self.node_id,
|
||||
{TAG_RAY_NODE_STATUS: "update-failed"})
|
||||
self.provider.set_node_tags(
|
||||
self.node_id, {TAG_RAY_NODE_STATUS: STATUS_UPDATE_FAILED})
|
||||
raise e
|
||||
|
||||
self.provider.set_node_tags(
|
||||
self.node_id, {
|
||||
TAG_RAY_NODE_STATUS: "up-to-date",
|
||||
TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE,
|
||||
TAG_RAY_RUNTIME_CONFIG: self.runtime_hash
|
||||
})
|
||||
|
||||
@@ -213,8 +213,8 @@ class NodeUpdater(object):
|
||||
sync_cmd(local_path, remote_path, redirect=None)
|
||||
|
||||
def do_update(self):
|
||||
self.provider.set_node_tags(self.node_id,
|
||||
{TAG_RAY_NODE_STATUS: "waiting-for-ssh"})
|
||||
self.provider.set_node_tags(
|
||||
self.node_id, {TAG_RAY_NODE_STATUS: STATUS_WAITING_FOR_SSH})
|
||||
|
||||
deadline = time.time() + NODE_START_WAIT_S
|
||||
self.set_ssh_ip_if_required()
|
||||
@@ -230,27 +230,27 @@ class NodeUpdater(object):
|
||||
"NodeUpdater: {} already up-to-date, skip to ray start".format(
|
||||
self.node_id))
|
||||
else:
|
||||
self.provider.set_node_tags(self.node_id,
|
||||
{TAG_RAY_NODE_STATUS: "syncing-files"})
|
||||
self.provider.set_node_tags(
|
||||
self.node_id, {TAG_RAY_NODE_STATUS: STATUS_SYNCING_FILES})
|
||||
self.sync_file_mounts(self.rsync_up)
|
||||
|
||||
# Run init commands
|
||||
self.provider.set_node_tags(self.node_id,
|
||||
{TAG_RAY_NODE_STATUS: "setting-up"})
|
||||
self.provider.set_node_tags(
|
||||
self.node_id, {TAG_RAY_NODE_STATUS: STATUS_SETTING_UP})
|
||||
m = "{}: Initialization commands completed".format(self.node_id)
|
||||
with LogTimer("NodeUpdater: {}".format(m)):
|
||||
for cmd in self.initialization_commands:
|
||||
self.ssh_cmd(cmd, exit_on_fail=self.exit_on_update_fail)
|
||||
self.ssh_cmd(cmd)
|
||||
|
||||
m = "{}: Setup commands completed".format(self.node_id)
|
||||
with LogTimer("NodeUpdater: {}".format(m)):
|
||||
for cmd in self.setup_commands:
|
||||
self.ssh_cmd(cmd, exit_on_fail=self.exit_on_update_fail)
|
||||
self.ssh_cmd(cmd)
|
||||
|
||||
m = "{}: Ray start commands completed".format(self.node_id)
|
||||
with LogTimer("NodeUpdater: {}".format(m)):
|
||||
for cmd in self.ray_start_commands:
|
||||
self.ssh_cmd(cmd, exit_on_fail=self.exit_on_update_fail)
|
||||
self.ssh_cmd(cmd)
|
||||
|
||||
def rsync_up(self, source, target, redirect=None):
|
||||
logger.info("NodeUpdater: "
|
||||
@@ -321,6 +321,8 @@ class NodeUpdater(object):
|
||||
stderr=redirect or sys.stderr)
|
||||
except subprocess.CalledProcessError:
|
||||
if exit_on_fail:
|
||||
# Only reason we need this exit flag here is because here we
|
||||
# know the final command and can print it nicely before exit()
|
||||
logger.error("Command failed: \n\n {}\n".format(
|
||||
" ".join(final_cmd)))
|
||||
sys.exit(1)
|
||||
|
||||
@@ -2,7 +2,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from flaky import flaky
|
||||
import shutil
|
||||
import tempfile
|
||||
import threading
|
||||
@@ -15,7 +14,8 @@ import ray
|
||||
import ray.services as services
|
||||
from ray.autoscaler.autoscaler import StandardAutoscaler, LoadMetrics, \
|
||||
fillout_defaults, validate_config
|
||||
from ray.autoscaler.tags import TAG_RAY_NODE_TYPE, TAG_RAY_NODE_STATUS
|
||||
from ray.autoscaler.tags import TAG_RAY_NODE_TYPE, TAG_RAY_NODE_STATUS, \
|
||||
STATUS_UP_TO_DATE, STATUS_UPDATE_FAILED
|
||||
from ray.autoscaler.node_provider import NODE_PROVIDERS, NodeProvider
|
||||
from ray.tests.utils import RayTestTimeoutException
|
||||
import pytest
|
||||
@@ -47,22 +47,53 @@ class MockProcessRunner(object):
|
||||
raise Exception("Failing command on purpose")
|
||||
self.calls.append(cmd)
|
||||
|
||||
def assert_has_call(self, ip, pattern):
    """Assert that some recorded command for ``ip`` contains ``pattern``.

    Every entry of ``self.calls`` is joined with spaces; the commands that
    mention ``ip`` are concatenated one per line and ``pattern`` is
    searched for as a plain substring of that text.

    Args:
        ip: Node IP address to select commands by.
        pattern: Substring expected in the selected commands.

    Returns:
        True when the pattern is found.

    Raises:
        Exception: if no selected command contains ``pattern``.
    """
    import re  # local import keeps this helper self-contained

    # Match the address only as a whole token: a bare substring test would
    # make "172.0.0.1" also select the commands issued for "172.0.0.10".
    ip_re = re.compile(r"(?<![\d.])" + re.escape(ip) + r"(?!\d)")
    messages = (" ".join(cmd) for cmd in self.calls)
    out = "".join(msg + "\n" for msg in messages if ip_re.search(msg))
    if pattern in out:
        return True
    raise Exception("Did not find [{}] in [{}] for {}".format(
        pattern, out, ip))


def assert_not_has_call(self, ip, pattern):
    """Assert that no recorded command for ``ip`` contains ``pattern``.

    Commands are selected and concatenated exactly as in
    ``assert_has_call``.

    Args:
        ip: Node IP address to select commands by.
        pattern: Substring that must be absent from the selected commands.

    Returns:
        True when the pattern is absent.

    Raises:
        Exception: if a selected command contains ``pattern``.
    """
    import re  # local import keeps this helper self-contained

    # Whole-token match, so "172.0.0.1" does not pick up "172.0.0.10".
    ip_re = re.compile(r"(?<![\d.])" + re.escape(ip) + r"(?!\d)")
    messages = (" ".join(cmd) for cmd in self.calls)
    out = "".join(msg + "\n" for msg in messages if ip_re.search(msg))
    if pattern in out:
        raise Exception("Found [{}] in [{}] for {}".format(
            pattern, out, ip))
    return True
|
||||
|
||||
def clear_history(self):
    """Forget all recorded commands by rebinding ``self.calls`` to a new list."""
    self.calls = []
|
||||
|
||||
|
||||
class MockProvider(NodeProvider):
|
||||
def __init__(self):
|
||||
def __init__(self, cache_stopped=False):
|
||||
self.mock_nodes = {}
|
||||
self.next_id = 0
|
||||
self.throw = False
|
||||
self.fail_creates = False
|
||||
self.ready_to_create = threading.Event()
|
||||
self.ready_to_create.set()
|
||||
self.cache_stopped = cache_stopped
|
||||
|
||||
def non_terminated_nodes(self, tag_filters):
|
||||
if self.throw:
|
||||
raise Exception("oops")
|
||||
return [
|
||||
n.node_id for n in self.mock_nodes.values()
|
||||
if n.matches(tag_filters) and n.state != "terminated"
|
||||
if n.matches(tag_filters)
|
||||
and n.state not in ["stopped", "terminated"]
|
||||
]
|
||||
|
||||
def non_terminated_node_ips(self, tag_filters):
|
||||
@@ -70,14 +101,15 @@ class MockProvider(NodeProvider):
|
||||
raise Exception("oops")
|
||||
return [
|
||||
n.internal_ip for n in self.mock_nodes.values()
|
||||
if n.matches(tag_filters) and n.state != "terminated"
|
||||
if n.matches(tag_filters)
|
||||
and n.state not in ["stopped", "terminated"]
|
||||
]
|
||||
|
||||
def is_running(self, node_id):
    """Return True if the mock node ``node_id`` is in the "running" state."""
    return self.mock_nodes[node_id].state == "running"
|
||||
|
||||
def is_terminated(self, node_id):
|
||||
return self.mock_nodes[node_id].state == "terminated"
|
||||
return self.mock_nodes[node_id].state in ["stopped", "terminated"]
|
||||
|
||||
def node_tags(self, node_id):
|
||||
return self.mock_nodes[node_id].tags
|
||||
@@ -92,15 +124,29 @@ class MockProvider(NodeProvider):
|
||||
self.ready_to_create.wait()
|
||||
if self.fail_creates:
|
||||
return
|
||||
if self.cache_stopped:
|
||||
for node in self.mock_nodes.values():
|
||||
if node.state == "stopped" and count > 0:
|
||||
count -= 1
|
||||
node.state = "pending"
|
||||
node.tags.update(tags)
|
||||
for _ in range(count):
|
||||
self.mock_nodes[self.next_id] = MockNode(self.next_id, tags)
|
||||
self.mock_nodes[self.next_id] = MockNode(self.next_id, tags.copy())
|
||||
self.next_id += 1
|
||||
|
||||
def set_node_tags(self, node_id, tags):
    """Merge ``tags`` into the tag dict of mock node ``node_id``.

    Existing keys are overwritten; keys not in ``tags`` are left alone.
    """
    self.mock_nodes[node_id].tags.update(tags)
|
||||
|
||||
def terminate_node(self, node_id):
|
||||
self.mock_nodes[node_id].state = "terminated"
|
||||
if self.cache_stopped:
|
||||
self.mock_nodes[node_id].state = "stopped"
|
||||
else:
|
||||
self.mock_nodes[node_id].state = "terminated"
|
||||
|
||||
def finish_starting_nodes(self):
    """Move every mock node that is "pending" into the "running" state.

    Nodes in any other state are left untouched.
    """
    for node in self.mock_nodes.values():
        if node.state == "pending":
            node.state = "running"
|
||||
|
||||
|
||||
SMALL_CLUSTER = {
|
||||
@@ -131,10 +177,10 @@ SMALL_CLUSTER = {
|
||||
"TestProp": 2,
|
||||
},
|
||||
"file_mounts": {},
|
||||
"initialization_commands": ["cmd0"],
|
||||
"setup_commands": ["cmd1"],
|
||||
"head_setup_commands": ["cmd2"],
|
||||
"worker_setup_commands": ["cmd3"],
|
||||
"initialization_commands": ["init_cmd"],
|
||||
"setup_commands": ["setup_cmd"],
|
||||
"head_setup_commands": ["head_setup_cmd"],
|
||||
"worker_setup_commands": ["worker_setup_cmd"],
|
||||
"head_start_ray_commands": ["start_ray_head"],
|
||||
"worker_start_ray_commands": ["start_ray_worker"],
|
||||
}
|
||||
@@ -344,11 +390,13 @@ class AutoscalingTest(unittest.TestCase):
|
||||
autoscaler.update()
|
||||
self.waitForNodes(0)
|
||||
autoscaler.request_resources({"CPU": cores_per_node * 10})
|
||||
for _ in range(3): # Maximum launch batch is 5
|
||||
for _ in range(5): # Maximum launch batch is 5
|
||||
time.sleep(0.01)
|
||||
autoscaler.update()
|
||||
self.waitForNodes(10)
|
||||
autoscaler.request_resources({"CPU": cores_per_node * 30})
|
||||
for _ in range(4): # Maximum launch batch is 5
|
||||
time.sleep(0.01)
|
||||
autoscaler.update()
|
||||
self.waitForNodes(30)
|
||||
|
||||
@@ -700,10 +748,10 @@ class AutoscalingTest(unittest.TestCase):
|
||||
autoscaler.update()
|
||||
autoscaler.update()
|
||||
self.waitForNodes(2)
|
||||
for node in self.provider.mock_nodes.values():
|
||||
node.state = "running"
|
||||
self.provider.finish_starting_nodes()
|
||||
autoscaler.update()
|
||||
self.waitForNodes(2, tag_filters={TAG_RAY_NODE_STATUS: "up-to-date"})
|
||||
self.waitForNodes(
|
||||
2, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE})
|
||||
|
||||
def testReportsConfigFailures(self):
|
||||
config = copy.deepcopy(SMALL_CLUSTER)
|
||||
@@ -712,7 +760,7 @@ class AutoscalingTest(unittest.TestCase):
|
||||
config["provider"]["type"] = "mock"
|
||||
config_path = self.write_config(config)
|
||||
self.provider = MockProvider()
|
||||
runner = MockProcessRunner(fail_cmds=["cmd1"])
|
||||
runner = MockProcessRunner(fail_cmds=["setup_cmd"])
|
||||
autoscaler = StandardAutoscaler(
|
||||
config_path,
|
||||
LoadMetrics(),
|
||||
@@ -722,11 +770,10 @@ class AutoscalingTest(unittest.TestCase):
|
||||
autoscaler.update()
|
||||
autoscaler.update()
|
||||
self.waitForNodes(2)
|
||||
for node in self.provider.mock_nodes.values():
|
||||
node.state = "running"
|
||||
self.provider.finish_starting_nodes()
|
||||
autoscaler.update()
|
||||
self.waitForNodes(
|
||||
2, tag_filters={TAG_RAY_NODE_STATUS: "update-failed"})
|
||||
2, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UPDATE_FAILED})
|
||||
|
||||
def testConfiguresOutdatedNodes(self):
|
||||
config_path = self.write_config(SMALL_CLUSTER)
|
||||
@@ -741,10 +788,10 @@ class AutoscalingTest(unittest.TestCase):
|
||||
autoscaler.update()
|
||||
autoscaler.update()
|
||||
self.waitForNodes(2)
|
||||
for node in self.provider.mock_nodes.values():
|
||||
node.state = "running"
|
||||
self.provider.finish_starting_nodes()
|
||||
autoscaler.update()
|
||||
self.waitForNodes(2, tag_filters={TAG_RAY_NODE_STATUS: "up-to-date"})
|
||||
self.waitForNodes(
|
||||
2, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE})
|
||||
runner.calls = []
|
||||
new_config = SMALL_CLUSTER.copy()
|
||||
new_config["worker_setup_commands"] = ["cmdX", "cmdY"]
|
||||
@@ -843,7 +890,6 @@ class AutoscalingTest(unittest.TestCase):
|
||||
autoscaler.update()
|
||||
assert len(self.provider.non_terminated_nodes({})) == 0
|
||||
|
||||
@flaky(max_runs=4)
|
||||
def testRecoverUnhealthyWorkers(self):
|
||||
config_path = self.write_config(SMALL_CLUSTER)
|
||||
self.provider = MockProvider()
|
||||
@@ -857,14 +903,19 @@ class AutoscalingTest(unittest.TestCase):
|
||||
update_interval_s=0)
|
||||
autoscaler.update()
|
||||
self.waitForNodes(2)
|
||||
for node in self.provider.mock_nodes.values():
|
||||
node.state = "running"
|
||||
self.provider.finish_starting_nodes()
|
||||
autoscaler.update()
|
||||
self.waitForNodes(2, tag_filters={TAG_RAY_NODE_STATUS: "up-to-date"})
|
||||
self.waitForNodes(
|
||||
2, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE})
|
||||
|
||||
# Mark a node as unhealthy
|
||||
lm.last_heartbeat_time_by_ip["172.0.0.0"] = 0
|
||||
for _ in range(5):
|
||||
if autoscaler.updaters:
|
||||
time.sleep(0.05)
|
||||
autoscaler.update()
|
||||
assert not autoscaler.updaters
|
||||
num_calls = len(runner.calls)
|
||||
lm.last_heartbeat_time_by_ip["172.0.0.0"] = 0
|
||||
autoscaler.update()
|
||||
self.waitFor(lambda: len(runner.calls) > num_calls, num_retries=150)
|
||||
|
||||
@@ -901,6 +952,136 @@ class AutoscalingTest(unittest.TestCase):
|
||||
StandardAutoscaler(
|
||||
invalid_provider, LoadMetrics(), update_interval_s=0)
|
||||
|
||||
def testSetupCommandsWithNoNodeCaching(self):
    """Without stopped-node caching, a replacement node gets full setup.

    Brings one worker up to date, terminates it, and checks that the init,
    setup, worker-setup and ray-start commands are all issued for the
    second address (172.0.0.1).
    """
    config = SMALL_CLUSTER.copy()
    config["min_workers"] = 1
    config["max_workers"] = 1
    config_path = self.write_config(config)
    self.provider = MockProvider(cache_stopped=False)
    runner = MockProcessRunner()
    lm = LoadMetrics()
    autoscaler = StandardAutoscaler(
        config_path,
        lm,
        max_failures=0,
        process_runner=runner,
        update_interval_s=0)
    all_commands = ("init_cmd", "setup_cmd", "worker_setup_cmd",
                    "start_ray_worker")

    autoscaler.update()
    self.waitForNodes(1)
    self.provider.finish_starting_nodes()
    autoscaler.update()
    self.waitForNodes(
        1, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE})
    for command in all_commands:
        runner.assert_has_call("172.0.0.0", command)

    # Check the node was not reused
    self.provider.terminate_node(0)
    autoscaler.update()
    self.waitForNodes(1)
    runner.clear_history()
    self.provider.finish_starting_nodes()
    autoscaler.update()
    self.waitForNodes(
        1, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE})
    for command in all_commands:
        runner.assert_has_call("172.0.0.1", command)
|
||||
|
||||
def testSetupCommandsWithStoppedNodeCaching(self):
    """With stopped-node caching, a reused node only restarts Ray.

    Brings one worker up to date, terminates it (the mock provider then
    marks it "stopped"), and checks that after the autoscaler brings a
    worker back only the ray-start command is issued for 172.0.0.0, and
    that nothing at all is issued for 172.0.0.1.
    """
    config = SMALL_CLUSTER.copy()
    config["min_workers"] = 1
    config["max_workers"] = 1
    config_path = self.write_config(config)
    self.provider = MockProvider(cache_stopped=True)
    runner = MockProcessRunner()
    lm = LoadMetrics()
    autoscaler = StandardAutoscaler(
        config_path,
        lm,
        max_failures=0,
        process_runner=runner,
        update_interval_s=0)
    setup_only = ("init_cmd", "setup_cmd", "worker_setup_cmd")

    autoscaler.update()
    self.waitForNodes(1)
    self.provider.finish_starting_nodes()
    autoscaler.update()
    self.waitForNodes(
        1, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE})
    for command in setup_only:
        runner.assert_has_call("172.0.0.0", command)
    runner.assert_has_call("172.0.0.0", "start_ray_worker")

    # Check the node was indeed reused
    self.provider.terminate_node(0)
    autoscaler.update()
    self.waitForNodes(1)
    runner.clear_history()
    self.provider.finish_starting_nodes()
    autoscaler.update()
    self.waitForNodes(
        1, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE})
    for command in setup_only:
        runner.assert_not_has_call("172.0.0.0", command)
    runner.assert_has_call("172.0.0.0", "start_ray_worker")

    runner.clear_history()
    autoscaler.update()
    runner.assert_not_has_call("172.0.0.0", "setup_cmd")

    # We did not start any other nodes
    runner.assert_not_has_call("172.0.0.1", " ")
|
||||
|
||||
def testMultiNodeReuse(self):
    """Scaling up after termination reuses cached nodes before adding new.

    Starts three workers with stopped-node caching, terminates all three,
    raises the target to ten workers, and checks that addresses 0-2 only
    get the ray-start command while addresses 3-9 also get setup commands.
    """
    config = SMALL_CLUSTER.copy()
    config["min_workers"] = 3
    config["max_workers"] = 3
    config_path = self.write_config(config)
    self.provider = MockProvider(cache_stopped=True)
    runner = MockProcessRunner()
    lm = LoadMetrics()
    autoscaler = StandardAutoscaler(
        config_path,
        lm,
        max_failures=0,
        process_runner=runner,
        update_interval_s=0)
    autoscaler.update()
    self.waitForNodes(3)
    self.provider.finish_starting_nodes()
    autoscaler.update()
    self.waitForNodes(
        3, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE})

    for node_id in (0, 1, 2):
        self.provider.terminate_node(node_id)
    runner.clear_history()

    # Scale up to 10 nodes, check we reuse the first 3 and add 7 more.
    config["min_workers"] = 10
    config["max_workers"] = 10
    self.write_config(config)
    autoscaler.update()
    autoscaler.update()
    self.waitForNodes(10)
    self.provider.finish_starting_nodes()
    autoscaler.update()
    self.waitForNodes(
        10, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE})
    autoscaler.update()
    for i in [0, 1, 2]:
        reused_ip = "172.0.0.{}".format(i)
        runner.assert_not_has_call(reused_ip, "setup_cmd")
        runner.assert_has_call(reused_ip, "start_ray_worker")
    for i in [3, 4, 5, 6, 7, 8, 9]:
        fresh_ip = "172.0.0.{}".format(i)
        runner.assert_has_call(fresh_ip, "setup_cmd")
        runner.assert_has_call(fresh_ip, "start_ray_worker")
|
||||
|
||||
|
||||
# Run the test suite with verbose output when executed as a script.
if __name__ == "__main__":
    unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user