mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 10:01:50 +08:00
[flaky test] Fix flaky checkpointing tests (#5791)
* Fix flaky checkpointing tests
* Fix checkpoint test logic
* Fix exception matching
* timeout exception
* Fix import
* fix build
This commit is contained in:
@@ -866,7 +866,7 @@ class FunctionActorManager(object):
|
||||
# `available_checkpoints` list.
|
||||
msg = (
|
||||
"`load_checkpoint` must return a checkpoint id that " +
|
||||
"exists in the `available_checkpoints` list, or eone.")
|
||||
"exists in the `available_checkpoints` list, or None.")
|
||||
assert any(checkpoint_id == checkpoint.checkpoint_id
|
||||
for checkpoint in checkpoints), msg
|
||||
# Notify raylet that this actor has been resumed from
|
||||
|
||||
@@ -75,16 +75,19 @@ def ray_checkpointable_actor_cls(request):
|
||||
if not os.path.isfile(filename):
|
||||
return None
|
||||
|
||||
available_checkpoint_ids = [
|
||||
c.checkpoint_id for c in available_checkpoints
|
||||
]
|
||||
with open(filename, "r") as f:
|
||||
lines = f.readlines()
|
||||
checkpoint_id, value = lines[-1].split(" ")
|
||||
self.value = int(value)
|
||||
self.resumed_from_checkpoint = True
|
||||
checkpoint_id = ray.ActorCheckpointID(
|
||||
ray.utils.hex_to_binary(checkpoint_id))
|
||||
assert any(checkpoint_id == checkpoint.checkpoint_id
|
||||
for checkpoint in available_checkpoints)
|
||||
return checkpoint_id
|
||||
for line in f:
|
||||
checkpoint_id, value = line.strip().split(" ")
|
||||
checkpoint_id = ray.ActorCheckpointID(
|
||||
ray.utils.hex_to_binary(checkpoint_id))
|
||||
if checkpoint_id in available_checkpoint_ids:
|
||||
self.value = int(value)
|
||||
self.resumed_from_checkpoint = True
|
||||
return checkpoint_id
|
||||
return None
|
||||
|
||||
def checkpoint_expired(self, actor_id, checkpoint_id):
|
||||
pass
|
||||
@@ -2405,7 +2408,7 @@ def test_checkpointing(ray_start_regular, ray_checkpointable_actor_cls):
|
||||
"""Test actor checkpointing and restoring from a checkpoint."""
|
||||
actor = ray.remote(
|
||||
max_reconstructions=2)(ray_checkpointable_actor_cls).remote()
|
||||
# Call increase 3 times.
|
||||
# Call increase 3 times, triggering a checkpoint.
|
||||
expected = 0
|
||||
for _ in range(3):
|
||||
ray.get(actor.increase.remote())
|
||||
@@ -2511,10 +2514,10 @@ def test_checkpointing_save_exception(ray_start_regular,
|
||||
@ray.remote(max_reconstructions=2)
|
||||
class RemoteCheckpointableActor(ray_checkpointable_actor_cls):
|
||||
def save_checkpoint(self, actor_id, checkpoint_context):
|
||||
raise Exception("Error during save")
|
||||
raise Exception("Intentional error saving checkpoint.")
|
||||
|
||||
actor = RemoteCheckpointableActor.remote()
|
||||
# Call increase 3 times.
|
||||
# Call increase 3 times, triggering a checkpoint that will fail.
|
||||
expected = 0
|
||||
for _ in range(3):
|
||||
ray.get(actor.increase.remote())
|
||||
@@ -2539,13 +2542,8 @@ def test_checkpointing_save_exception(ray_start_regular,
|
||||
assert ray.get(actor.get.remote()) == expected
|
||||
assert ray.get(actor.was_resumed_from_checkpoint.remote()) is False
|
||||
|
||||
# Check that checkpointing errors were pushed to the driver.
|
||||
errors = ray.errors()
|
||||
assert len(errors) > 0
|
||||
for error in errors:
|
||||
# An error for the actor process dying may also get pushed.
|
||||
assert (error["type"] == ray_constants.CHECKPOINT_PUSH_ERROR
|
||||
or error["type"] == ray_constants.WORKER_DIED_PUSH_ERROR)
|
||||
# Check that the checkpoint error was pushed to the driver.
|
||||
wait_for_errors(ray_constants.CHECKPOINT_PUSH_ERROR, 1)
|
||||
|
||||
|
||||
def test_checkpointing_load_exception(ray_start_regular,
|
||||
@@ -2555,15 +2553,16 @@ def test_checkpointing_load_exception(ray_start_regular,
|
||||
@ray.remote(max_reconstructions=2)
|
||||
class RemoteCheckpointableActor(ray_checkpointable_actor_cls):
|
||||
def load_checkpoint(self, actor_id, checkpoints):
|
||||
raise Exception("Error during load")
|
||||
raise Exception("Intentional error loading checkpoint.")
|
||||
|
||||
actor = RemoteCheckpointableActor.remote()
|
||||
# Call increase 3 times.
|
||||
# Call increase 3 times, triggering a checkpoint that will succeed.
|
||||
expected = 0
|
||||
for _ in range(3):
|
||||
ray.get(actor.increase.remote())
|
||||
expected += 1
|
||||
# Assert that the actor wasn't resumed from a checkpoint.
|
||||
# Assert that the actor wasn't resumed from a checkpoint because loading
|
||||
# it failed.
|
||||
assert ray.get(actor.was_resumed_from_checkpoint.remote()) is False
|
||||
# Kill actor process.
|
||||
kill_actor(actor)
|
||||
@@ -2583,13 +2582,8 @@ def test_checkpointing_load_exception(ray_start_regular,
|
||||
assert ray.get(actor.get.remote()) == expected
|
||||
assert ray.get(actor.was_resumed_from_checkpoint.remote()) is False
|
||||
|
||||
# Check that checkpointing errors were pushed to the driver.
|
||||
errors = ray.errors()
|
||||
assert len(errors) > 0
|
||||
for error in errors:
|
||||
# An error for the actor process dying may also get pushed.
|
||||
assert (error["type"] == ray_constants.CHECKPOINT_PUSH_ERROR
|
||||
or error["type"] == ray_constants.WORKER_DIED_PUSH_ERROR)
|
||||
# Check that the checkpoint error was pushed to the driver.
|
||||
wait_for_errors(ray_constants.CHECKPOINT_PUSH_ERROR, 1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -17,6 +17,7 @@ from ray.autoscaler.autoscaler import StandardAutoscaler, LoadMetrics, \
|
||||
fillout_defaults, validate_config
|
||||
from ray.autoscaler.tags import TAG_RAY_NODE_TYPE, TAG_RAY_NODE_STATUS
|
||||
from ray.autoscaler.node_provider import NODE_PROVIDERS, NodeProvider
|
||||
from ray.tests.utils import RayTestTimeoutException
|
||||
import pytest
|
||||
|
||||
|
||||
@@ -238,7 +239,8 @@ class AutoscalingTest(unittest.TestCase):
|
||||
if condition():
|
||||
return
|
||||
time.sleep(.1)
|
||||
raise Exception("Timed out waiting for {}".format(condition))
|
||||
raise RayTestTimeoutException(
|
||||
"Timed out waiting for {}".format(condition))
|
||||
|
||||
def waitForNodes(self, expected, comparison=None, tag_filters={}):
|
||||
MAX_ITER = 50
|
||||
|
||||
@@ -33,6 +33,8 @@ import ray.ray_constants as ray_constants
|
||||
import ray.tests.cluster_utils
|
||||
import ray.tests.utils
|
||||
|
||||
from ray.tests.utils import RayTestTimeoutException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1219,8 +1221,9 @@ def test_profiling_api(ray_start_2_cpus):
|
||||
start_time = time.time()
|
||||
while True:
|
||||
if time.time() - start_time > timeout_seconds:
|
||||
raise Exception("Timed out while waiting for information in "
|
||||
"profile table.")
|
||||
raise RayTestTimeoutException(
|
||||
"Timed out while waiting for information in "
|
||||
"profile table.")
|
||||
profile_data = ray.timeline()
|
||||
event_types = {event["cat"] for event in profile_data}
|
||||
expected_types = [
|
||||
@@ -1921,8 +1924,9 @@ def test_gpu_ids(shutdown_only):
|
||||
if len(set(ray.get([f.remote() for _ in range(10)]))) == 10:
|
||||
break
|
||||
if time.time() > start_time + 10:
|
||||
raise Exception("Timed out while waiting for workers to start "
|
||||
"up.")
|
||||
raise RayTestTimeoutException(
|
||||
"Timed out while waiting for workers to start "
|
||||
"up.")
|
||||
|
||||
list_of_ids = ray.get([f0.remote() for _ in range(10)])
|
||||
assert list_of_ids == 10 * [[]]
|
||||
@@ -2519,7 +2523,7 @@ def wait_for_num_tasks(num_tasks, timeout=10):
|
||||
if len(ray.tasks()) >= num_tasks:
|
||||
return
|
||||
time.sleep(0.1)
|
||||
raise Exception("Timed out while waiting for global state.")
|
||||
raise RayTestTimeoutException("Timed out while waiting for global state.")
|
||||
|
||||
|
||||
def wait_for_num_objects(num_objects, timeout=10):
|
||||
@@ -2528,7 +2532,7 @@ def wait_for_num_objects(num_objects, timeout=10):
|
||||
if len(ray.objects()) >= num_objects:
|
||||
return
|
||||
time.sleep(0.1)
|
||||
raise Exception("Timed out while waiting for global state.")
|
||||
raise RayTestTimeoutException("Timed out while waiting for global state.")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
@@ -2621,8 +2625,9 @@ def test_global_state_api(shutdown_only):
|
||||
if tables_ready:
|
||||
return
|
||||
time.sleep(0.1)
|
||||
raise Exception("Timed out while waiting for object table to "
|
||||
"update.")
|
||||
raise RayTestTimeoutException(
|
||||
"Timed out while waiting for object table to "
|
||||
"update.")
|
||||
|
||||
object_table = ray.objects()
|
||||
assert len(object_table) == 2
|
||||
|
||||
@@ -14,7 +14,8 @@ import pytest
|
||||
import ray
|
||||
import ray.ray_constants as ray_constants
|
||||
from ray.tests.cluster_utils import Cluster
|
||||
from ray.tests.utils import run_string_as_driver_nonblocking
|
||||
from ray.tests.utils import (run_string_as_driver_nonblocking,
|
||||
RayTestTimeoutException)
|
||||
|
||||
|
||||
# This test checks that when a worker dies in the middle of a get, the plasma
|
||||
@@ -224,7 +225,8 @@ def test_worker_failed(ray_start_workers_separate_multinode):
|
||||
for pid in new_pids:
|
||||
pids.add(pid)
|
||||
if time.time() - start_time > 60:
|
||||
raise Exception("Timed out while waiting to get worker PIDs.")
|
||||
raise RayTestTimeoutException(
|
||||
"Timed out while waiting to get worker PIDs.")
|
||||
|
||||
@ray.remote
|
||||
def f(x):
|
||||
|
||||
@@ -19,6 +19,7 @@ from ray.tests.cluster_utils import Cluster
|
||||
from ray.tests.utils import (
|
||||
relevant_errors,
|
||||
wait_for_errors,
|
||||
RayTestTimeoutException,
|
||||
)
|
||||
|
||||
|
||||
@@ -733,7 +734,7 @@ def test_connect_with_disconnected_node(shutdown_only):
|
||||
# This node is killed by SIGTERM, ray_monitor will not mark it again.
|
||||
removing_node = cluster.add_node(num_cpus=0, _internal_config=config)
|
||||
cluster.remove_node(removing_node, allow_graceful=True)
|
||||
with pytest.raises(Exception, match=("Timing out of wait.")):
|
||||
with pytest.raises(RayTestTimeoutException):
|
||||
wait_for_errors(ray_constants.REMOVED_NODE_ERROR, 3, timeout=2)
|
||||
# There is no connection error to a dead node.
|
||||
info = relevant_errors(ray_constants.RAYLET_CONNECTION_ERROR)
|
||||
|
||||
@@ -10,6 +10,7 @@ import time
|
||||
import ray
|
||||
from ray.utils import _random_string
|
||||
from ray.tests.utils import (
|
||||
RayTestTimeoutException,
|
||||
run_string_as_driver,
|
||||
run_string_as_driver_nonblocking,
|
||||
wait_for_children_of_pid,
|
||||
@@ -256,7 +257,8 @@ print("success")
|
||||
print(output_line)
|
||||
if output_line == "success":
|
||||
return
|
||||
raise Exception("Timed out waiting for process to print success.")
|
||||
raise RayTestTimeoutException(
|
||||
"Timed out waiting for process to print success.")
|
||||
|
||||
# Make sure we can run this driver repeatedly, which means that resources
|
||||
# are getting released in between.
|
||||
|
||||
@@ -14,6 +14,11 @@ import psutil
|
||||
import ray
|
||||
|
||||
|
||||
class RayTestTimeoutException(Exception):
|
||||
"""Exception used to identify timeouts from test utilities."""
|
||||
pass
|
||||
|
||||
|
||||
def _pid_alive(pid):
|
||||
"""Check if the process with this PID is alive or not.
|
||||
|
||||
@@ -36,7 +41,8 @@ def wait_for_pid_to_exit(pid, timeout=20):
|
||||
if not _pid_alive(pid):
|
||||
return
|
||||
time.sleep(0.1)
|
||||
raise Exception("Timed out while waiting for process to exit.")
|
||||
raise RayTestTimeoutException(
|
||||
"Timed out while waiting for process to exit.")
|
||||
|
||||
|
||||
def wait_for_children_of_pid(pid, num_children=1, timeout=20):
|
||||
@@ -47,8 +53,9 @@ def wait_for_children_of_pid(pid, num_children=1, timeout=20):
|
||||
if num_alive >= num_children:
|
||||
return
|
||||
time.sleep(0.1)
|
||||
raise Exception("Timed out while waiting for process children to start "
|
||||
"({}/{} started).".format(num_alive, num_children))
|
||||
raise RayTestTimeoutException(
|
||||
"Timed out while waiting for process children to start "
|
||||
"({}/{} started).".format(num_alive, num_children))
|
||||
|
||||
|
||||
def wait_for_children_of_pid_to_exit(pid, timeout=20):
|
||||
@@ -58,9 +65,9 @@ def wait_for_children_of_pid_to_exit(pid, timeout=20):
|
||||
|
||||
_, alive = psutil.wait_procs(children, timeout=timeout)
|
||||
if len(alive) > 0:
|
||||
raise Exception("Timed out while waiting for process children to exit."
|
||||
" Children still alive: {}.".format(
|
||||
[p.name() for p in alive]))
|
||||
raise RayTestTimeoutException(
|
||||
"Timed out while waiting for process children to exit."
|
||||
" Children still alive: {}.".format([p.name() for p in alive]))
|
||||
|
||||
|
||||
def kill_process_by_name(name, SIGKILL=False):
|
||||
@@ -121,13 +128,14 @@ def relevant_errors(error_type):
|
||||
return [error for error in flat_errors() if error["type"] == error_type]
|
||||
|
||||
|
||||
def wait_for_errors(error_type, num_errors, timeout=10):
|
||||
def wait_for_errors(error_type, num_errors, timeout=20):
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
if len(relevant_errors(error_type)) >= num_errors:
|
||||
return
|
||||
time.sleep(0.1)
|
||||
raise Exception("Timing out of wait.")
|
||||
raise RayTestTimeoutException("Timed out waiting for {} {} errors.".format(
|
||||
num_errors, error_type))
|
||||
|
||||
|
||||
def wait_for_condition(condition_predictor,
|
||||
|
||||
Reference in New Issue
Block a user