diff --git a/python/ray/node.py b/python/ray/node.py
index cf0733a93..28e01f586 100644
--- a/python/ray/node.py
+++ b/python/ray/node.py
@@ -61,9 +61,12 @@ class Node(object):
             connect_only (bool): If true, connect to the node without starting
                 new processes.
         """
-        if shutdown_at_exit and connect_only:
-            raise ValueError("'shutdown_at_exit' and 'connect_only' cannot "
-                             "be both true.")
+        if shutdown_at_exit:
+            if connect_only:
+                raise ValueError("'shutdown_at_exit' and 'connect_only' "
+                                 "cannot both be true.")
+            self._register_shutdown_hooks()
+
         self.head = head
         self.all_processes = {}
 
@@ -152,9 +155,52 @@ class Node(object):
         if not connect_only:
             self.start_ray_processes()
 
-        if shutdown_at_exit:
-            atexit.register(lambda: self.kill_all_processes(
-                check_alive=False, allow_graceful=True))
+    def _register_shutdown_hooks(self):
+        """Register atexit and SIGTERM hooks that clean up child processes."""
+        # Make ourselves a process group session leader to ensure we can clean
+        # up child processes later without killing a process that started us.
+        try:
+            os.setpgrp()
+        except OSError as e:
+            logger.warning("setpgrp failed, processes may not be "
+                           "cleaned up properly: {}.".format(e))
+
+        # Clean up child process by first going through the normal
+        # kill_all_processes procedure (which should clean them all up
+        # under normal circumstances), then sending a SIGTERM to our
+        # process group to take care of any children that may have been
+        # spawned but not yet added to the list.
+        def clean_up_children(sigterm_handler):
+            self.kill_all_processes(check_alive=False, allow_graceful=True)
+            signal.signal(signal.SIGTERM, sigterm_handler)
+            try:
+                # SIGTERM our process group as a last resort in case there
+                # were processes that we spawned but didn't add to the list
+                # (could happen if interrupted just after spawning them).
+                # We could send SIGKILL here to be sure, but we're also
+                # sending it to ourselves.
+                os.killpg(0, signal.SIGTERM)
+            except OSError as e:
+                print("killpg failed, processes may not have "
+                      "been cleaned up properly: {}.".format(e))
+
+        # Register a handler to be called during the normal python
+        # shutdown process. We pass an empty lambda to clean_up_children
+        # because after cleaning up the child processes, it should do
+        # nothing and return so that the shutdown process can continue.
+        def atexit_handler():
+            return clean_up_children(lambda *args, **kwargs: None)
+
+        atexit.register(atexit_handler)
+
+        # Register a handler to be called if we get a SIGTERM.
+        # In this case, we want to exit with an error code (1) after
+        # cleaning up child processes. Python invokes signal handlers as
+        # handler(signum, frame), so both arguments must be accepted.
+        def sigterm_handler(signum, frame):
+            return clean_up_children(lambda *args, **kwargs: sys.exit(1))
+
+        signal.signal(signal.SIGTERM, sigterm_handler)
 
     def _init_temp(self, redis_client):
         # Create an dictionary to store temp file index.
diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py
index da297a7f8..f66fd8019 100644
--- a/python/ray/scripts/scripts.py
+++ b/python/ray/scripts/scripts.py
@@ -9,6 +9,7 @@ import logging
 import os
 import subprocess
 import sys
+import time
 
 import ray.services as services
 from ray.autoscaler.commands import (
@@ -314,7 +315,7 @@ def start(node_ip_address, redis_address, address, redis_port,
             include_java=False,
         )
 
-        node = ray.node.Node(ray_params, head=True, shutdown_at_exit=False)
+        node = ray.node.Node(ray_params, head=True, shutdown_at_exit=block)
         redis_address = node.redis_address
 
         logger.info(
@@ -384,13 +385,12 @@ def start(node_ip_address, redis_address, address, redis_port,
         check_no_existing_redis_clients(ray_params.node_ip_address,
                                         redis_client)
         ray_params.update(redis_address=redis_address)
-        node = ray.node.Node(ray_params, head=False, shutdown_at_exit=False)
+        node = ray.node.Node(ray_params, head=False, shutdown_at_exit=block)
         logger.info("\nStarted Ray on this node. If you wish to terminate the "
                     "processes that have been started, run\n\n"
                     "    ray stop")
 
     if block:
-        import time
         while True:
             time.sleep(1)
             deceased = node.dead_processes()
@@ -399,8 +399,8 @@ def start(node_ip_address, redis_address, address, redis_port,
                 for process_type, process in deceased:
                     logger.error("\t{} died with exit code {}".format(
                         process_type, process.returncode))
+                # shutdown_at_exit will handle cleanup.
                 logger.error("Killing remaining processes and exiting...")
-                node.kill_all_processes(check_alive=False, allow_graceful=True)
                 sys.exit(1)
 
 
diff --git a/python/ray/tests/test_multi_node.py b/python/ray/tests/test_multi_node.py
index 411c7669c..b75e5163a 100644
--- a/python/ray/tests/test_multi_node.py
+++ b/python/ray/tests/test_multi_node.py
@@ -9,8 +9,13 @@ import time
 
 import ray
 from ray.utils import _random_string
-from ray.tests.utils import (run_string_as_driver,
-                             run_string_as_driver_nonblocking)
+from ray.tests.utils import (
+    run_string_as_driver,
+    run_string_as_driver_nonblocking,
+    wait_for_children_of_pid,
+    wait_for_children_of_pid_to_exit,
+    kill_process_by_name,
+)
 
 
 def test_error_isolation(call_ray_start):
@@ -267,7 +272,7 @@ print("success")
 
 
 def test_calling_start_ray_head():
-    # Test that we can call start-ray.sh with various command line
+    # Test that we can call ray start with various command line
     # parameters. TODO(rkn): This test only tests the --head code path. We
     # should also test the non-head node code path.
 
@@ -327,62 +332,30 @@ def test_calling_start_ray_head():
             ["ray", "start", "--head", "--redis-address", "127.0.0.1:6379"])
     subprocess.check_output(["ray", "stop"])
 
-    # Test --block. Killing any child process should cause the command to exit.
+    # Test --block. Killing a child process should cause the command to exit.
     blocked = subprocess.Popen(["ray", "start", "--head", "--block"])
 
-    blocked.poll()
-    # Wait for up to 10s for the ray command to spawn a child process.
-    for _ in range(10):
-        try:
-            subprocess.check_output(["pgrep", "-P", str(blocked.pid)])
-            break
-        except subprocess.CalledProcessError:
-            time.sleep(1)
-    else:
-        assert False, "ray start didn't spawn children within 10s of starting"
+    wait_for_children_of_pid(blocked.pid, num_children=7, timeout=30)
 
     blocked.poll()
     assert blocked.returncode is None
 
-    # Kill all child processes of the ray command and check that it exits.
-    subprocess.check_output(["pkill", "-P", str(blocked.pid)])
-    for _ in range(10):
-        time.sleep(1)
-        blocked.poll()
-        if blocked.returncode is not None:
-            break
-    else:
-        assert False, "ray start didn't exit within 10s of child process dying"
-
-    assert blocked.returncode != 0
+    kill_process_by_name("raylet")
+    wait_for_children_of_pid_to_exit(blocked.pid, timeout=120)
+    blocked.wait()
+    assert blocked.returncode != 0, "ray start shouldn't return 0 on bad exit"
 
     # Test --block. Killing the command should clean up all child processes.
     blocked = subprocess.Popen(["ray", "start", "--head", "--block"])
     blocked.poll()
     assert blocked.returncode is None
 
-    # Wait for up to 10s for the ray command to spawn a child process.
-    for _ in range(10):
-        try:
-            subprocess.check_output(["pgrep", "-P", str(blocked.pid)])
-            break
-        except subprocess.CalledProcessError:
-            time.sleep(1)
-    else:
-        assert False, "ray start didn't spawn children within 10s of starting"
+    wait_for_children_of_pid(blocked.pid, num_children=7, timeout=30)
 
     blocked.terminate()
-
-    # Check that the child processes are cleaned up within 10s.
-    for _ in range(10):
-        try:
-            subprocess.check_output(
-                ["pgrep", "-P", str(blocked.pid), "raylet"])
-        except subprocess.CalledProcessError:
-            # pgrep didn't find anything, so the child processes are dead.
-            break
-    else:
-        assert False, "ray start didn't kill children within 10s of exiting."
+    wait_for_children_of_pid_to_exit(blocked.pid, timeout=120)
+    blocked.wait()
+    assert blocked.returncode != 0, "ray start shouldn't return 0 on bad exit"
 
 
 @pytest.mark.parametrize(
diff --git a/python/ray/tests/utils.py b/python/ray/tests/utils.py
index 5203a7d8e..ea14adfb8 100644
--- a/python/ray/tests/utils.py
+++ b/python/ray/tests/utils.py
@@ -4,6 +4,7 @@ from __future__ import print_function
 
 import fnmatch
 import os
+import psutil
 import subprocess
 import sys
 import tempfile
@@ -37,6 +38,53 @@ def wait_for_pid_to_exit(pid, timeout=20):
     raise Exception("Timed out while waiting for process to exit.")
 
 
+def wait_for_children_of_pid(pid, num_children=1, timeout=20):
+    """Wait until `pid` has at least `num_children` direct children.
+
+    Raises:
+        Exception: If fewer children are alive once `timeout` seconds pass.
+    """
+    p = psutil.Process(pid)
+    start_time = time.time()
+    # Initialized up front so the timeout message below is well defined even
+    # if the loop body never runs (e.g. timeout <= 0).
+    num_alive = 0
+    while time.time() - start_time < timeout:
+        num_alive = len(p.children(recursive=False))
+        if num_alive >= num_children:
+            return
+        time.sleep(0.1)
+    raise Exception("Timed out while waiting for process children to start "
+                    "({}/{} started).".format(num_alive, num_children))
+
+
+def wait_for_children_of_pid_to_exit(pid, timeout=20):
+    """Wait for every current direct child of `pid` to exit.
+
+    Raises:
+        Exception: If any child is still alive after `timeout` seconds.
+    """
+    children = psutil.Process(pid).children()
+    if not children:
+        return
+
+    _, alive = psutil.wait_procs(children, timeout=timeout)
+    if alive:
+        raise Exception("Timed out while waiting for process children to exit."
+                        " Children still alive: {}.".format(
+                            [p.name() for p in alive]))
+
+
+def kill_process_by_name(name, SIGKILL=False):
+    """Send SIGTERM (or SIGKILL) to every process named `name`."""
+    for p in psutil.process_iter(attrs=["name"]):
+        if p.info["name"] == name:
+            if SIGKILL:
+                p.kill()
+            else:
+                p.terminate()
+
+
 def run_string_as_driver(driver_script):
     """Run a driver as a separate process.
 