diff --git a/python/ray/test/cluster_utils.py b/python/ray/test/cluster_utils.py index afaf5dd21..568048360 100644 --- a/python/ray/test/cluster_utils.py +++ b/python/ray/test/cluster_utils.py @@ -2,6 +2,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import atexit import logging import time @@ -15,7 +16,8 @@ class Cluster(object): def __init__(self, initialize_head=False, connect=False, - head_node_args=None): + head_node_args=None, + shutdown_at_exit=True): """Initializes the cluster. Args: @@ -24,8 +26,10 @@ class Cluster(object): connect (bool): If `initialize_head=True` and `connect=True`, ray.init will be called with the redis address of this cluster passed in. - head_node_args (kwargs): Arguments to be passed into + head_node_args (dict): Arguments to be passed into `start_ray_head` via `self.add_node`. + shutdown_at_exit (bool): If True, registers an exit hook + for shutting down all started processes. """ self.head_node = None self.worker_nodes = {} @@ -41,6 +45,8 @@ class Cluster(object): ray.init( redis_address=self.redis_address, redis_password=redis_password) + if shutdown_at_exit: + atexit.register(self.shutdown) def add_node(self, **override_kwargs): """Adds a node to the local Ray Cluster. @@ -158,12 +164,16 @@ class Cluster(object): return nodes def shutdown(self): + """Removes all nodes.""" + # We create a list here as a copy because `remove_node` # modifies `self.worker_nodes`. all_nodes = list(self.worker_nodes) for node in all_nodes: self.remove_node(node) - self.remove_node(self.head_node) + + if self.head_node is not None: + self.remove_node(self.head_node) class Node(object): diff --git a/test/multi_node_test_2.py b/test/multi_node_test_2.py index 339546be1..04c2c2c55 100644 --- a/test/multi_node_test_2.py +++ b/test/multi_node_test_2.py @@ -58,7 +58,15 @@ def test_cluster(): assert node2.all_processes_alive() g.remove_node(node2) g.remove_node(node) - assert not any(node.any_processes_alive() for node in g.list_all_nodes()) + assert not any(n.any_processes_alive() for n in [node, node2]) + + +def test_shutdown(): + g = Cluster(initialize_head=False) + node = g.add_node() + node2 = g.add_node() + g.shutdown() + assert not any(n.any_processes_alive() for n in [node, node2]) def test_internal_config(start_connected_longer_cluster):