diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py
index fd4dd371b..38f14dc78 100644
--- a/python/ray/tests/test_failure.py
+++ b/python/ray/tests/test_failure.py
@@ -19,6 +19,7 @@ from ray.test_utils import (
     SignalActor,
     init_error_pubsub,
     get_error_message,
+    Semaphore,
 )
 
 
@@ -727,24 +728,44 @@ def test_warning_for_too_many_nested_tasks(shutdown_only):
     ray.init(num_cpus=num_cpus)
     p = init_error_pubsub()
 
+    remote_wait = Semaphore.remote(value=0)
+    nested_wait = Semaphore.remote(value=0)
+
+    ray.get([
+        remote_wait.locked.remote(),
+        nested_wait.locked.remote(),
+    ])
+
     @ray.remote
     def f():
         time.sleep(1000)
         return 1
 
     @ray.remote
-    def h():
-        time.sleep(1)
+    def h(nested_waits):
+        nested_wait.release.remote()
+        ray.get(nested_waits)
         ray.get(f.remote())
 
     @ray.remote
-    def g():
+    def g(remote_waits, nested_waits):
         # Sleep so that the f tasks all get submitted to the scheduler after
         # the g tasks.
-        time.sleep(1)
-        ray.get(h.remote())
+        remote_wait.release.remote()
+        # wait until every lock is released.
+        ray.get(remote_waits)
+        ray.get(h.remote(nested_waits))
+
+    num_root_tasks = num_cpus * 4
+    # Lock remote task until everything is scheduled.
+    remote_waits = []
+    nested_waits = []
+    for _ in range(num_root_tasks):
+        remote_waits.append(remote_wait.acquire.remote())
+        nested_waits.append(nested_wait.acquire.remote())
+
+    [g.remote(remote_waits, nested_waits) for _ in range(num_root_tasks)]
 
-    [g.remote() for _ in range(num_cpus * 6)]
     errors = get_error_message(p, 1, ray_constants.WORKER_POOL_LARGE_ERROR)
     assert len(errors) == 1
     assert errors[0].type == ray_constants.WORKER_POOL_LARGE_ERROR