diff --git a/python/ray/tests/test_multi_node_2.py b/python/ray/tests/test_multi_node_2.py index a65786b32..988fef53a 100644 --- a/python/ray/tests/test_multi_node_2.py +++ b/python/ray/tests/test_multi_node_2.py @@ -6,7 +6,7 @@ import ray import ray.ray_constants as ray_constants from ray.monitor import Monitor from ray.cluster_utils import Cluster -from ray.test_utils import generate_internal_config_map +from ray.test_utils import generate_internal_config_map, SignalActor logger = logging.getLogger(__name__) @@ -64,7 +64,7 @@ def setup_monitor(address): return monitor -def verify_load_metrics(monitor, expected_resource_usage=None, timeout=10): +def verify_load_metrics(monitor, expected_resource_usage=None, timeout=30): while True: monitor.process_messages() resource_usage = monitor.load_metrics.get_resource_usage() @@ -114,32 +114,45 @@ def test_heartbeats_single(ray_start_cluster_head): Test proper metrics. """ cluster = ray_start_cluster_head - timeout = 5 monitor = setup_monitor(cluster.address) total_cpus = ray.state.cluster_resources()["CPU"] verify_load_metrics(monitor, (0.0, {"CPU": 0.0}, {"CPU": total_cpus})) @ray.remote - def work(timeout): - time.sleep(timeout) - return True + def work(signal): + wait_signal = signal.wait.remote() + while True: + ready, not_ready = ray.wait([wait_signal], timeout=0) + if len(ready) == 1: + break + time.sleep(1) - work_handle = work.remote(timeout * 2) + signal = SignalActor.remote() + + work_handle = work.remote(signal) verify_load_metrics(monitor, (1.0 / total_cpus, { "CPU": 1.0 }, { "CPU": total_cpus })) + + ray.get(signal.send.remote()) ray.get(work_handle) @ray.remote class Actor: - def work(self, timeout): - time.sleep(timeout) - return True + def work(self, signal): + wait_signal = signal.wait.remote() + while True: + ready, not_ready = ray.wait([wait_signal], timeout=0) + if len(ready) == 1: + break + time.sleep(1) + + signal = SignalActor.remote() test_actor = Actor.remote() - work_handle = test_actor.work.remote(timeout * 2) + work_handle = test_actor.work.remote(signal) verify_load_metrics(monitor, (1.0 / total_cpus, { "CPU": 1.0 @@ -147,6 +160,7 @@ def test_heartbeats_single(ray_start_cluster_head): "CPU": total_cpus })) + ray.get(signal.send.remote()) ray.get(work_handle)