From e4f9b3b7d951cdec2df11815ecb4ddf92a0a5afa Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Tue, 26 Nov 2019 22:00:08 -0600 Subject: [PATCH] Use process reaper for cleanup (#6253) --- python/ray/cluster_utils.py | 8 ++- python/ray/node.py | 84 +++++++++++++++------------ python/ray/ray_constants.py | 1 + python/ray/ray_process_reaper.py | 49 ++++++++++++++++ python/ray/scripts/scripts.py | 14 ++--- python/ray/services.py | 48 ++++++++++++++- python/ray/tests/test_multi_node_2.py | 1 + python/ray/worker.py | 14 ++++- python/ray/workers/default_worker.py | 6 +- 9 files changed, 172 insertions(+), 53 deletions(-) create mode 100644 python/ray/ray_process_reaper.py diff --git a/python/ray/cluster_utils.py b/python/ray/cluster_utils.py index f42d63114..27f7f8ba1 100644 --- a/python/ray/cluster_utils.py +++ b/python/ray/cluster_utils.py @@ -85,7 +85,10 @@ class Cluster(object): ray_params.update_if_absent(**default_kwargs) if self.head_node is None: node = ray.node.Node( - ray_params, head=True, shutdown_at_exit=self._shutdown_at_exit) + ray_params, + head=True, + shutdown_at_exit=self._shutdown_at_exit, + spawn_reaper=self._shutdown_at_exit) self.head_node = node self.redis_address = self.head_node.redis_address self.redis_password = node_args.get("redis_password") @@ -99,7 +102,8 @@ class Cluster(object): node = ray.node.Node( ray_params, head=False, - shutdown_at_exit=self._shutdown_at_exit) + shutdown_at_exit=self._shutdown_at_exit, + spawn_reaper=self._shutdown_at_exit) self.worker_nodes.add(node) # Wait for the node to appear in the client table. We do this so that diff --git a/python/ray/node.py b/python/ray/node.py index 08eb307dd..53984c53f 100644 --- a/python/ray/node.py +++ b/python/ray/node.py @@ -47,6 +47,7 @@ class Node(object): ray_params, head=False, shutdown_at_exit=True, + spawn_reaper=True, connect_only=False): """Start a node. @@ -56,9 +57,10 @@ class Node(object): head (bool): True if this is the head node, which means it will start additional processes like the Redis servers, monitor processes, and web UI. - shutdown_at_exit (bool): If true, a handler will be registered to - shutdown the processes started here when the Python interpreter - exits. + shutdown_at_exit (bool): If true, spawned processes will be cleaned + up if this process exits normally. + spawn_reaper (bool): If true, spawns a process that will clean up + other spawned processes if this process dies unexpectedly. connect_only (bool): If true, connect to the node without starting new processes. """ @@ -158,6 +160,9 @@ class Node(object): # raylet starts. self._ray_params.node_manager_port = self._get_unused_port() + if not connect_only and spawn_reaper: + self.start_reaper_process() + # Start processes. if head: self.start_head_processes() @@ -170,39 +175,10 @@ class Node(object): self.start_ray_processes() def _register_shutdown_hooks(self): - # Make ourselves a process group session leader to ensure we can clean - # up child processes later without killing a process that started us. - try: - os.setpgrp() - except OSError as e: - logger.warning("setpgrp failed, processes may not be " - "cleaned up properly: {}.".format(e)) - - # Clean up child process by first going through the normal - # kill_all_processes procedure (which should clean them all up - # under normal circumstances), then sending a SIGTERM to our - # process group to take care of any children that may have been - # spawned but not yet added to the list. - def clean_up_children(sigterm_handler): + # Register the atexit handler. In this case, we shouldn't call sys.exit + # as we're already in the exit procedure. + def atexit_handler(*args): self.kill_all_processes(check_alive=False, allow_graceful=True) - signal.signal(signal.SIGTERM, sigterm_handler) - try: - # SIGTERM our process group as a last resort in case there - # were processes that we spawned but didn't add to the list - # (could happen if interrupted just after spawning them). - # We could send SIGKILL here to be sure, but we're also - # sending it to ourselves. - os.killpg(0, signal.SIGTERM) - except OSError as e: - print("killpg failed, processes may not have " - "been cleaned up properly: {}.".format(e)) - - # Register the a handler to be called during the normal python - # shutdown process. We pass an empty lambda to clean_up_children - # because after cleaning up the child processes, it should do - # nothing and return so that the shutdown process can continue. - def atexit_handler(): - return clean_up_children(lambda *args, **kwargs: None) atexit.register(atexit_handler) @@ -210,7 +186,8 @@ class Node(object): # In this case, we want to exit with an error code (1) after # cleaning up child processes. def sigterm_handler(signum, frame): - return clean_up_children(lambda *args, **kwargs: sys.exit(1)) + self.kill_all_processes(check_alive=False, allow_graceful=True) + sys.exit(1) signal.signal(signal.SIGTERM, sigterm_handler) @@ -435,6 +412,20 @@ class Node(object): return self._make_inc_temp( prefix=default_prefix, directory_name=self._sockets_dir) + def start_reaper_process(self): + """ + Start the reaper process. + + This must be the first process spawned and should only be called when + ray processes should be cleaned up if this process dies. + """ + process_info = ray.services.start_reaper() + assert ray_constants.PROCESS_TYPE_REAPER not in self.all_processes + if process_info is not None: + self.all_processes[ray_constants.PROCESS_TYPE_REAPER] = [ + process_info + ] + def start_redis(self): """Start the Redis servers.""" assert self._redis_address is None @@ -790,6 +781,16 @@ class Node(object): self._kill_process_type( ray_constants.PROCESS_TYPE_RAYLET_MONITOR, check_alive=check_alive) + def kill_reaper(self, check_alive=True): + """Kill the reaper process. + + Args: + check_alive (bool): Raise an exception if the process was already + dead. + """ + self._kill_process_type( + ray_constants.PROCESS_TYPE_REAPER, check_alive=check_alive) + def kill_all_processes(self, check_alive=True, allow_graceful=False): """Kill all of the processes. @@ -814,8 +815,17 @@ class Node(object): # We call "list" to copy the keys because we are modifying the # dictionary while iterating over it. for process_type in list(self.all_processes.keys()): + # Need to kill the reaper process last in case we die unexpectedly + # while cleaning up. + if process_type != ray_constants.PROCESS_TYPE_REAPER: + self._kill_process_type( + process_type, + check_alive=check_alive, + allow_graceful=allow_graceful) + + if ray_constants.PROCESS_TYPE_REAPER in self.all_processes: self._kill_process_type( - process_type, + ray_constants.PROCESS_TYPE_REAPER, check_alive=check_alive, allow_graceful=allow_graceful) diff --git a/python/ray/ray_constants.py b/python/ray/ray_constants.py index 87fb53801..6fa335d6d 100644 --- a/python/ray/ray_constants.py +++ b/python/ray/ray_constants.py @@ -168,6 +168,7 @@ NO_RECONSTRUCTION = 0 INFINITE_RECONSTRUCTION = 2**30 # Constants used to define the different process types. +PROCESS_TYPE_REAPER = "reaper" PROCESS_TYPE_MONITOR = "monitor" PROCESS_TYPE_RAYLET_MONITOR = "raylet_monitor" PROCESS_TYPE_LOG_MONITOR = "log_monitor" diff --git a/python/ray/ray_process_reaper.py b/python/ray/ray_process_reaper.py new file mode 100644 index 000000000..f804242c4 --- /dev/null +++ b/python/ray/ray_process_reaper.py @@ -0,0 +1,49 @@ +import os +import signal +import sys +import time +""" +This is a lightweight "reaper" process used to ensure that ray processes are +cleaned up properly when the main ray process dies unexpectedly (e.g., +segfaults or gets SIGKILLed). Note that processes may not be cleaned up +properly if this process is SIGTERMed or SIGKILLed. + +It detects that its parent has died by reading from stdin, which must be +inherited from the parent process so that the OS will deliver an EOF if the +parent dies. When this happens, the reaper process kills the rest of its +process group (first attempting graceful shutdown with SIGTERM, then escalating +to SIGKILL). +""" + +SIGTERM_GRACE_PERIOD_SECONDS = 1 + + +def reap_process_group(*args): + def sigterm_handler(*args): + # Give a one-second grace period for other processes to clean up. + time.sleep(SIGTERM_GRACE_PERIOD_SECONDS) + # SIGKILL the pgroup (including ourselves) as a last-resort. + os.killpg(0, signal.SIGKILL) + + # Set a SIGTERM handler to handle SIGTERMing ourselves with the group. + signal.signal(signal.SIGTERM, sigterm_handler) + + # Our parent must have died, SIGTERM the group (including ourselves). + # TODO(mehrdadn): killpg isn't supported on Windows. + os.killpg(0, signal.SIGTERM) + + +def main(): + # Read from stdout forever. Because stdout is a file descriptor + # inherited from our parent process, we will get an EOF if the parent + # dies, which is signaled by an empty return from read(). + # We intentionally don't set any signal handlers here, so a SIGTERM from + # the parent can be used to kill this process gracefully without it killing + # the rest of the process group. + while len(sys.stdin.read()) != 0: + pass + reap_process_group() + + +if __name__ == "__main__": + main() diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index ff3f7a925..733e868ec 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -327,7 +327,8 @@ def start(node_ip_address, redis_address, address, redis_port, include_java=False, ) - node = ray.node.Node(ray_params, head=True, shutdown_at_exit=block) + node = ray.node.Node( + ray_params, head=True, shutdown_at_exit=block, spawn_reaper=block) redis_address = node.redis_address logger.info( @@ -395,7 +396,8 @@ def start(node_ip_address, redis_address, address, redis_port, check_no_existing_redis_clients(ray_params.node_ip_address, redis_client) ray_params.update(redis_address=redis_address) - node = ray.node.Node(ray_params, head=False, shutdown_at_exit=block) + node = ray.node.Node( + ray_params, head=False, shutdown_at_exit=block, spawn_reaper=block) logger.info("\nStarted Ray on this node. If you wish to terminate the " "processes that have been started, run\n\n" " ray stop") @@ -436,7 +438,6 @@ def stop(force, verbose): # See STANDARD FORMAT SPECIFIERS section of # http://man7.org/linux/man-pages/man1/ps.1.html # about comm and args. This can help avoid killing non-ray processes. - # Format: # Keyword to filter, filter by command (True)/filter by args (False) ["raylet", True], @@ -450,12 +451,9 @@ def stop(force, verbose): ["log_monitor.py", False], ["reporter.py", False], ["dashboard.py", False], + ["ray_process_reaper.py", False], ] - signal_name = "TERM" - if force: - signal_name = "KILL" - for process in processes_to_kill: keyword, filter_by_cmd = process if filter_by_cmd: @@ -475,7 +473,7 @@ def stop(force, verbose): "kill -s {} $(ps ax -o {} | grep {} | grep -v grep {} | grep ray |" "awk '{{ print $1 }}') 2> /dev/null".format( # ^^ This is how you escape braces in python format string. - signal_name, + "KILL" if force else "TERM", ps_format, keyword, debug_operator)) diff --git a/python/ray/services.py b/python/ray/services.py index 2ce8f0bb4..d3852186f 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -330,7 +330,8 @@ def start_ray_process(command, use_perftools_profiler=False, use_tmux=False, stdout_file=None, - stderr_file=None): + stderr_file=None, + pipe_stdin=False): """Start one of the Ray processes. TODO(rkn): We need to figure out how these commands interact. For example, @@ -357,6 +358,8 @@ def start_ray_process(command, no redirection should happen, then this should be None. stderr_file: A file handle opened for writing to redirect stderr to. If no redirection should happen, then this should be None. + pipe_stdin: If true, subprocess.PIPE will be passed to the process as + stdin. Returns: Information about the process that was started including a handle to @@ -438,13 +441,23 @@ def start_ray_process(command, # version, and tmux 2.1) command = ["tmux", "new-session", "-d", "{}".format(" ".join(command))] + # Block sigint for spawned processes so they aren't killed by the SIGINT + # propagated from the shell on Ctrl-C so we can handle KeyboardInterrupts + # in interactive sessions. This is only supported in Python 3.3 and above. + def block_sigint(): + import signal + import sys + if sys.version_info >= (3, 3): + signal.pthread_sigmask(signal.SIG_BLOCK, {signal.SIGINT}) + process = subprocess.Popen( command, env=modified_env, cwd=cwd, stdout=stdout_file, stderr=stderr_file, - preexec_fn=os.setsid) + stdin=subprocess.PIPE if pipe_stdin else None, + preexec_fn=block_sigint) return ProcessInfo( process=process, @@ -563,6 +576,37 @@ def check_version_info(redis_client): logger.warning(error_message) +def start_reaper(): + """Start the reaper process. + + This is a lightweight process that simply + waits for its parent process to die and then terminates its own + process group. This allows us to ensure that ray processes are always + terminated properly so long as that process itself isn't SIGKILLed. + + Returns: + ProcessInfo for the process that was started. + """ + # Make ourselves a process group leader so that the reaper can clean + # up other ray processes without killing the process group of the + # process that started us. + try: + os.setpgrp() + except OSError as e: + logger.warning("setpgrp failed, processes may not be " + "cleaned up properly: {}.".format(e)) + # Don't start the reaper in this case as it could result in killing + # other user processes. + return None + + reaper_filepath = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "ray_process_reaper.py") + command = [sys.executable, "-u", reaper_filepath] + process_info = start_ray_process( + command, ray_constants.PROCESS_TYPE_REAPER, pipe_stdin=True) + return process_info + + def start_redis(node_ip_address, redirect_files, resource_spec, diff --git a/python/ray/tests/test_multi_node_2.py b/python/ray/tests/test_multi_node_2.py index 4812b4383..b83a54060 100644 --- a/python/ray/tests/test_multi_node_2.py +++ b/python/ray/tests/test_multi_node_2.py @@ -217,6 +217,7 @@ def test_worker_plasma_store_failure(ray_start_cluster_head): cluster.wait_for_nodes() worker.kill_reporter() worker.kill_plasma_store() + worker.kill_reaper() worker.all_processes[ray_constants.PROCESS_TYPE_RAYLET][0].process.wait() assert not worker.any_processes_alive(), worker.live_processes() diff --git a/python/ray/worker.py b/python/ray/worker.py index 3a836135d..bb11cd327 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -724,9 +724,13 @@ def init(address=None, ) # Start the Ray processes. We set shutdown_at_exit=False because we # shutdown the node in the ray.shutdown call that happens in the atexit - # handler. + # handler. We still spawn a reaper process in case the atexit handler + # isn't called. _global_node = ray.node.Node( - head=True, shutdown_at_exit=False, ray_params=ray_params) + head=True, + shutdown_at_exit=False, + spawn_reaper=True, + ray_params=ray_params) else: # In this case, we are connecting to an existing cluster. if num_cpus is not None or num_gpus is not None: @@ -779,7 +783,11 @@ def init(address=None, load_code_from_local=load_code_from_local, use_pickle=use_pickle) _global_node = ray.node.Node( - ray_params, head=False, shutdown_at_exit=False, connect_only=True) + ray_params, + head=False, + shutdown_at_exit=False, + spawn_reaper=False, + connect_only=True) connect( _global_node, diff --git a/python/ray/workers/default_worker.py b/python/ray/workers/default_worker.py index 83c9489b5..2f178a3ad 100644 --- a/python/ray/workers/default_worker.py +++ b/python/ray/workers/default_worker.py @@ -89,7 +89,11 @@ if __name__ == "__main__": use_pickle=args.use_pickle) node = ray.node.Node( - ray_params, head=False, shutdown_at_exit=False, connect_only=True) + ray_params, + head=False, + shutdown_at_exit=False, + spawn_reaper=False, + connect_only=True) ray.worker._global_node = node ray.worker.connect(node, mode=ray.WORKER_MODE) ray.worker.global_worker.main_loop()