diff --git a/python/ray/node.py b/python/ray/node.py index a63a0a8a8..cd2dc2250 100644 --- a/python/ray/node.py +++ b/python/ray/node.py @@ -11,6 +11,7 @@ import socket import subprocess import sys import tempfile +import threading import time from typing import Optional, Dict @@ -91,6 +92,7 @@ class Node: self.kernel_fate_share = bool( spawn_reaper and ray.utils.detect_fate_sharing_support()) self.all_processes = {} + self.removal_lock = threading.Lock() # Try to get node IP address with the parameters. if ray_params.node_ip_address: @@ -923,6 +925,23 @@ class Node: 2. The process had been started in valgrind and had a non-zero exit code. """ + + # Ensure thread safety + with self.removal_lock: + self._kill_process_impl( + process_type, + allow_graceful=allow_graceful, + check_alive=check_alive, + wait=wait) + + def _kill_process_impl(self, + process_type, + allow_graceful=False, + check_alive=True, + wait=False): + """See `_kill_process_type`.""" + if process_type not in self.all_processes: + return process_infos = self.all_processes[process_type] if process_type != ray_constants.PROCESS_TYPE_REDIS_SERVER: assert len(process_infos) == 1 diff --git a/python/ray/tests/test_client_metadata.py b/python/ray/tests/test_client_metadata.py index ffec75a77..1a6c4e2a5 100644 --- a/python/ray/tests/test_client_metadata.py +++ b/python/ray/tests/test_client_metadata.py @@ -38,3 +38,8 @@ def test_get_runtime_context(ray_start_regular_shared): with pytest.raises(Exception): _ = rtc.task_id + + +if __name__ == "__main__": + import sys + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tests/test_client_references.py b/python/ray/tests/test_client_references.py index 8a4458e14..54bfa7f42 100644 --- a/python/ray/tests/test_client_references.py +++ b/python/ray/tests/test_client_references.py @@ -1,5 +1,7 @@ +import pytest from ray.util.client.ray_client_helpers import ray_start_client_server -from ray.util.client.ray_client_helpers import ray_start_client_server_pair +from ray.util.client.ray_client_helpers import ( + ray_start_client_server_pair, ray_start_cluster_client_server_pair) from ray.test_utils import wait_for_condition import ray as real_ray from ray.core.generated.gcs_pb2 import ActorTableData @@ -30,8 +32,14 @@ def server_actor_ref_count(server, n): return test_cond -def test_delete_refs_on_disconnect(ray_start_regular): - with ray_start_client_server_pair() as pair: +@pytest.mark.parametrize( + "ray_start_cluster", [{ + "num_nodes": 1, + "do_init": False + }], indirect=True) +def test_delete_refs_on_disconnect(ray_start_cluster): + cluster = ray_start_cluster + with ray_start_cluster_client_server_pair(cluster.address) as pair: ray, server = pair @ray.remote @@ -49,11 +57,15 @@ def test_delete_refs_on_disconnect(ray_start_regular): # And can get the data assert ray.get(thing1) == 8 - # Close the client + # Close the client. ray.close() wait_for_condition(server_object_ref_count(server, 0), timeout=5) + # Connect to the real ray again, since we disconnected + # upon num_clients = 0. + real_ray.init(address=cluster.address) + def test_cond(): return len(real_ray.objects()) == 0 @@ -73,8 +85,14 @@ def test_delete_ref_on_object_deletion(ray_start_regular): wait_for_condition(server_object_ref_count(server, 1), timeout=5) -def test_delete_actor_on_disconnect(ray_start_regular): - with ray_start_client_server_pair() as pair: +@pytest.mark.parametrize( + "ray_start_cluster", [{ + "num_nodes": 1, + "do_init": False + }], indirect=True) +def test_delete_actor_on_disconnect(ray_start_cluster): + cluster = ray_start_cluster + with ray_start_cluster_client_server_pair(cluster.address) as pair: ray, server = pair @ray.remote @@ -106,6 +124,10 @@ def test_delete_actor_on_disconnect(ray_start_regular): ] return len(alive_actors) == 0 + # Connect to the real ray again, since we disconnected + # upon num_clients = 0. + real_ray.init(address=cluster.address) + wait_for_condition(test_cond, timeout=10) @@ -152,3 +174,9 @@ def test_simple_multiple_references(ray_start_regular): del ref1 assert ray.get(ref2) == "hi" del ref2 + + +if __name__ == "__main__": + import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tests/test_client_terminate.py b/python/ray/tests/test_client_terminate.py index 9016c627a..6f7af830f 100644 --- a/python/ray/tests/test_client_terminate.py +++ b/python/ray/tests/test_client_terminate.py @@ -83,3 +83,9 @@ def test_cancel_chain(ray_start_regular, use_force): signaler2.send.remote() ray.get(obj1) + + +if __name__ == "__main__": + import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tests/test_cross_language.py b/python/ray/tests/test_cross_language.py index 10766b18b..4ffd6db3e 100644 --- a/python/ray/tests/test_cross_language.py +++ b/python/ray/tests/test_cross_language.py @@ -24,3 +24,7 @@ def test_cross_language_raise_exception(shutdown_only): with pytest.raises(Exception, match="transfer"): ray.java_function("a", "b").remote(PythonObject()) + + +if __name__ == "__main__": + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tests/test_mini.py b/python/ray/tests/test_mini.py index dae1e11bd..724deb542 100644 --- a/python/ray/tests/test_mini.py +++ b/python/ray/tests/test_mini.py @@ -59,3 +59,9 @@ def test_actor_api(ray_start_regular): x = 1 f = Foo.remote(x) assert (ray.get(f.get.remote()) == x) + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tests/test_monitor.py b/python/ray/tests/test_monitor.py index ac67ddcf2..e4b14166d 100644 --- a/python/ray/tests/test_monitor.py +++ b/python/ray/tests/test_monitor.py @@ -37,3 +37,9 @@ def test_parse_resource_demands(): # counted as infeasible or waiting, as long as it's accounted for and # doesn't cause an error. assert len(waiting + infeasible) == 10 + + +if __name__ == "__main__": + import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/util/client/ray_client_helpers.py b/python/ray/util/client/ray_client_helpers.py index 77f09346d..a7f16c246 100644 --- a/python/ray/util/client/ray_client_helpers.py +++ b/python/ray/util/client/ray_client_helpers.py @@ -1,5 +1,6 @@ from contextlib import contextmanager +import ray as real_ray import ray.util.client.server.server as ray_client_server from ray.util.client import ray @@ -22,3 +23,21 @@ def ray_start_client_server_pair(): ray._inside_client_test = False ray.disconnect() server.stop(0) + + +@contextmanager +def ray_start_cluster_client_server_pair(address): + ray._inside_client_test = True + + def ray_connect_handler(): + real_ray.init(address=address) + + server = ray_client_server.serve( + "localhost:50051", ray_connect_handler=ray_connect_handler) + ray.connect("localhost:50051") + try: + yield ray, server + finally: + ray._inside_client_test = False + ray.disconnect() + server.stop(0)