[flaky test] Fix flaky checkpointing tests (#5791)

* Fix flaky checkpointing tests

* Fix checkpoint test logic

* Fix exception matching

* Add RayTestTimeoutException and raise it from test wait helpers on timeout

* Fix import

* Fix build
This commit is contained in:
Edward Oakes
2019-09-27 11:03:07 -07:00
committed by GitHub
parent baf85c6665
commit 86610a30c9
9 changed files with 71 additions and 51 deletions
+1 -1
View File
@@ -866,7 +866,7 @@ class FunctionActorManager(object):
# `available_checkpoints` list.
msg = (
"`load_checkpoint` must return a checkpoint id that " +
"exists in the `available_checkpoints` list, or eone.")
"exists in the `available_checkpoints` list, or None.")
assert any(checkpoint_id == checkpoint.checkpoint_id
for checkpoint in checkpoints), msg
# Notify raylet that this actor has been resumed from
+23 -29
View File
@@ -75,16 +75,19 @@ def ray_checkpointable_actor_cls(request):
if not os.path.isfile(filename):
return None
available_checkpoint_ids = [
c.checkpoint_id for c in available_checkpoints
]
with open(filename, "r") as f:
lines = f.readlines()
checkpoint_id, value = lines[-1].split(" ")
self.value = int(value)
self.resumed_from_checkpoint = True
checkpoint_id = ray.ActorCheckpointID(
ray.utils.hex_to_binary(checkpoint_id))
assert any(checkpoint_id == checkpoint.checkpoint_id
for checkpoint in available_checkpoints)
return checkpoint_id
for line in f:
checkpoint_id, value = line.strip().split(" ")
checkpoint_id = ray.ActorCheckpointID(
ray.utils.hex_to_binary(checkpoint_id))
if checkpoint_id in available_checkpoint_ids:
self.value = int(value)
self.resumed_from_checkpoint = True
return checkpoint_id
return None
def checkpoint_expired(self, actor_id, checkpoint_id):
pass
@@ -2405,7 +2408,7 @@ def test_checkpointing(ray_start_regular, ray_checkpointable_actor_cls):
"""Test actor checkpointing and restoring from a checkpoint."""
actor = ray.remote(
max_reconstructions=2)(ray_checkpointable_actor_cls).remote()
# Call increase 3 times.
# Call increase 3 times, triggering a checkpoint.
expected = 0
for _ in range(3):
ray.get(actor.increase.remote())
@@ -2511,10 +2514,10 @@ def test_checkpointing_save_exception(ray_start_regular,
@ray.remote(max_reconstructions=2)
class RemoteCheckpointableActor(ray_checkpointable_actor_cls):
def save_checkpoint(self, actor_id, checkpoint_context):
raise Exception("Error during save")
raise Exception("Intentional error saving checkpoint.")
actor = RemoteCheckpointableActor.remote()
# Call increase 3 times.
# Call increase 3 times, triggering a checkpoint that will fail.
expected = 0
for _ in range(3):
ray.get(actor.increase.remote())
@@ -2539,13 +2542,8 @@ def test_checkpointing_save_exception(ray_start_regular,
assert ray.get(actor.get.remote()) == expected
assert ray.get(actor.was_resumed_from_checkpoint.remote()) is False
# Check that checkpointing errors were pushed to the driver.
errors = ray.errors()
assert len(errors) > 0
for error in errors:
# An error for the actor process dying may also get pushed.
assert (error["type"] == ray_constants.CHECKPOINT_PUSH_ERROR
or error["type"] == ray_constants.WORKER_DIED_PUSH_ERROR)
# Check that the checkpoint error was pushed to the driver.
wait_for_errors(ray_constants.CHECKPOINT_PUSH_ERROR, 1)
def test_checkpointing_load_exception(ray_start_regular,
@@ -2555,15 +2553,16 @@ def test_checkpointing_load_exception(ray_start_regular,
@ray.remote(max_reconstructions=2)
class RemoteCheckpointableActor(ray_checkpointable_actor_cls):
def load_checkpoint(self, actor_id, checkpoints):
raise Exception("Error during load")
raise Exception("Intentional error loading checkpoint.")
actor = RemoteCheckpointableActor.remote()
# Call increase 3 times.
# Call increase 3 times, triggering a checkpoint that will succeed.
expected = 0
for _ in range(3):
ray.get(actor.increase.remote())
expected += 1
# Assert that the actor wasn't resumed from a checkpoint.
# Assert that the actor wasn't resumed from a checkpoint because loading
# it failed.
assert ray.get(actor.was_resumed_from_checkpoint.remote()) is False
# Kill actor process.
kill_actor(actor)
@@ -2583,13 +2582,8 @@ def test_checkpointing_load_exception(ray_start_regular,
assert ray.get(actor.get.remote()) == expected
assert ray.get(actor.was_resumed_from_checkpoint.remote()) is False
# Check that checkpointing errors were pushed to the driver.
errors = ray.errors()
assert len(errors) > 0
for error in errors:
# An error for the actor process dying may also get pushed.
assert (error["type"] == ray_constants.CHECKPOINT_PUSH_ERROR
or error["type"] == ray_constants.WORKER_DIED_PUSH_ERROR)
# Check that the checkpoint error was pushed to the driver.
wait_for_errors(ray_constants.CHECKPOINT_PUSH_ERROR, 1)
@pytest.mark.parametrize(
+3 -1
View File
@@ -17,6 +17,7 @@ from ray.autoscaler.autoscaler import StandardAutoscaler, LoadMetrics, \
fillout_defaults, validate_config
from ray.autoscaler.tags import TAG_RAY_NODE_TYPE, TAG_RAY_NODE_STATUS
from ray.autoscaler.node_provider import NODE_PROVIDERS, NodeProvider
from ray.tests.utils import RayTestTimeoutException
import pytest
@@ -238,7 +239,8 @@ class AutoscalingTest(unittest.TestCase):
if condition():
return
time.sleep(.1)
raise Exception("Timed out waiting for {}".format(condition))
raise RayTestTimeoutException(
"Timed out waiting for {}".format(condition))
def waitForNodes(self, expected, comparison=None, tag_filters={}):
MAX_ITER = 50
+13 -8
View File
@@ -33,6 +33,8 @@ import ray.ray_constants as ray_constants
import ray.tests.cluster_utils
import ray.tests.utils
from ray.tests.utils import RayTestTimeoutException
logger = logging.getLogger(__name__)
@@ -1219,8 +1221,9 @@ def test_profiling_api(ray_start_2_cpus):
start_time = time.time()
while True:
if time.time() - start_time > timeout_seconds:
raise Exception("Timed out while waiting for information in "
"profile table.")
raise RayTestTimeoutException(
"Timed out while waiting for information in "
"profile table.")
profile_data = ray.timeline()
event_types = {event["cat"] for event in profile_data}
expected_types = [
@@ -1921,8 +1924,9 @@ def test_gpu_ids(shutdown_only):
if len(set(ray.get([f.remote() for _ in range(10)]))) == 10:
break
if time.time() > start_time + 10:
raise Exception("Timed out while waiting for workers to start "
"up.")
raise RayTestTimeoutException(
"Timed out while waiting for workers to start "
"up.")
list_of_ids = ray.get([f0.remote() for _ in range(10)])
assert list_of_ids == 10 * [[]]
@@ -2519,7 +2523,7 @@ def wait_for_num_tasks(num_tasks, timeout=10):
if len(ray.tasks()) >= num_tasks:
return
time.sleep(0.1)
raise Exception("Timed out while waiting for global state.")
raise RayTestTimeoutException("Timed out while waiting for global state.")
def wait_for_num_objects(num_objects, timeout=10):
@@ -2528,7 +2532,7 @@ def wait_for_num_objects(num_objects, timeout=10):
if len(ray.objects()) >= num_objects:
return
time.sleep(0.1)
raise Exception("Timed out while waiting for global state.")
raise RayTestTimeoutException("Timed out while waiting for global state.")
@pytest.mark.skipif(
@@ -2621,8 +2625,9 @@ def test_global_state_api(shutdown_only):
if tables_ready:
return
time.sleep(0.1)
raise Exception("Timed out while waiting for object table to "
"update.")
raise RayTestTimeoutException(
"Timed out while waiting for object table to "
"update.")
object_table = ray.objects()
assert len(object_table) == 2
+4 -2
View File
@@ -14,7 +14,8 @@ import pytest
import ray
import ray.ray_constants as ray_constants
from ray.tests.cluster_utils import Cluster
from ray.tests.utils import run_string_as_driver_nonblocking
from ray.tests.utils import (run_string_as_driver_nonblocking,
RayTestTimeoutException)
# This test checks that when a worker dies in the middle of a get, the plasma
@@ -224,7 +225,8 @@ def test_worker_failed(ray_start_workers_separate_multinode):
for pid in new_pids:
pids.add(pid)
if time.time() - start_time > 60:
raise Exception("Timed out while waiting to get worker PIDs.")
raise RayTestTimeoutException(
"Timed out while waiting to get worker PIDs.")
@ray.remote
def f(x):
+2 -1
View File
@@ -19,6 +19,7 @@ from ray.tests.cluster_utils import Cluster
from ray.tests.utils import (
relevant_errors,
wait_for_errors,
RayTestTimeoutException,
)
@@ -733,7 +734,7 @@ def test_connect_with_disconnected_node(shutdown_only):
# This node is killed by SIGTERM, ray_monitor will not mark it again.
removing_node = cluster.add_node(num_cpus=0, _internal_config=config)
cluster.remove_node(removing_node, allow_graceful=True)
with pytest.raises(Exception, match=("Timing out of wait.")):
with pytest.raises(RayTestTimeoutException):
wait_for_errors(ray_constants.REMOVED_NODE_ERROR, 3, timeout=2)
# There is no connection error to a dead node.
info = relevant_errors(ray_constants.RAYLET_CONNECTION_ERROR)
+3 -1
View File
@@ -10,6 +10,7 @@ import time
import ray
from ray.utils import _random_string
from ray.tests.utils import (
RayTestTimeoutException,
run_string_as_driver,
run_string_as_driver_nonblocking,
wait_for_children_of_pid,
@@ -256,7 +257,8 @@ print("success")
print(output_line)
if output_line == "success":
return
raise Exception("Timed out waiting for process to print success.")
raise RayTestTimeoutException(
"Timed out waiting for process to print success.")
# Make sure we can run this driver repeatedly, which means that resources
# are getting released in between.
+16 -8
View File
@@ -14,6 +14,11 @@ import psutil
import ray
class RayTestTimeoutException(Exception):
"""Exception used to identify timeouts from test utilities."""
pass
def _pid_alive(pid):
"""Check if the process with this PID is alive or not.
@@ -36,7 +41,8 @@ def wait_for_pid_to_exit(pid, timeout=20):
if not _pid_alive(pid):
return
time.sleep(0.1)
raise Exception("Timed out while waiting for process to exit.")
raise RayTestTimeoutException(
"Timed out while waiting for process to exit.")
def wait_for_children_of_pid(pid, num_children=1, timeout=20):
@@ -47,8 +53,9 @@ def wait_for_children_of_pid(pid, num_children=1, timeout=20):
if num_alive >= num_children:
return
time.sleep(0.1)
raise Exception("Timed out while waiting for process children to start "
"({}/{} started).".format(num_alive, num_children))
raise RayTestTimeoutException(
"Timed out while waiting for process children to start "
"({}/{} started).".format(num_alive, num_children))
def wait_for_children_of_pid_to_exit(pid, timeout=20):
@@ -58,9 +65,9 @@ def wait_for_children_of_pid_to_exit(pid, timeout=20):
_, alive = psutil.wait_procs(children, timeout=timeout)
if len(alive) > 0:
raise Exception("Timed out while waiting for process children to exit."
" Children still alive: {}.".format(
[p.name() for p in alive]))
raise RayTestTimeoutException(
"Timed out while waiting for process children to exit."
" Children still alive: {}.".format([p.name() for p in alive]))
def kill_process_by_name(name, SIGKILL=False):
@@ -121,13 +128,14 @@ def relevant_errors(error_type):
return [error for error in flat_errors() if error["type"] == error_type]
def wait_for_errors(error_type, num_errors, timeout=10):
def wait_for_errors(error_type, num_errors, timeout=20):
start_time = time.time()
while time.time() - start_time < timeout:
if len(relevant_errors(error_type)) >= num_errors:
return
time.sleep(0.1)
raise Exception("Timing out of wait.")
raise RayTestTimeoutException("Timed out waiting for {} {} errors.".format(
num_errors, error_type))
def wait_for_condition(condition_predictor,