From 4d42664b2a3f1875ff9e7f3f3cebe435ca312fb1 Mon Sep 17 00:00:00 2001 From: mehrdadn Date: Tue, 3 Mar 2020 09:45:42 -0800 Subject: [PATCH] Use prctl(PR_SET_PDEATHSIG) on Linux instead of reaper (#7150) --- python/ray/node.py | 35 +++++-- python/ray/services.py | 108 +++++++++++++------ python/ray/tests/test_multi_node_2.py | 3 +- python/ray/utils.py | 144 ++++++++++++++++++++++++++ 4 files changed, 244 insertions(+), 46 deletions(-) diff --git a/python/ray/node.py b/python/ray/node.py index 9a7160719..9758b5894 100644 --- a/python/ray/node.py +++ b/python/ray/node.py @@ -66,6 +66,8 @@ class Node: self._register_shutdown_hooks() self.head = head + self.kernel_fate_share = (spawn_reaper + and ray.utils.detect_fate_sharing_support()) self.all_processes = {} # Try to get node IP address with the parameters. @@ -154,7 +156,7 @@ class Node: # raylet starts. self._ray_params.node_manager_port = self._get_unused_port() - if not connect_only and spawn_reaper: + if not connect_only and spawn_reaper and not self.kernel_fate_share: self.start_reaper_process() # Start processes. @@ -413,7 +415,9 @@ class Node: This must be the first process spawned and should only be called when ray processes should be cleaned up if this process dies. 
""" - process_info = ray.services.start_reaper() + assert not self.kernel_fate_share, ( + "a reaper should not be used with kernel fate-sharing") + process_info = ray.services.start_reaper(fate_share=False) assert ray_constants.PROCESS_TYPE_REAPER not in self.all_processes if process_info is not None: self.all_processes[ray_constants.PROCESS_TYPE_REAPER] = [ @@ -438,7 +442,8 @@ class Node: redis_max_clients=self._ray_params.redis_max_clients, redirect_worker_output=True, password=self._ray_params.redis_password, - include_java=self._ray_params.include_java) + include_java=self._ray_params.include_java, + fate_share=self.kernel_fate_share) assert ( ray_constants.PROCESS_TYPE_REDIS_SERVER not in self.all_processes) self.all_processes[ray_constants.PROCESS_TYPE_REDIS_SERVER] = ( @@ -452,7 +457,8 @@ class Node: self._logs_dir, stdout_file=stdout_file, stderr_file=stderr_file, - redis_password=self._ray_params.redis_password) + redis_password=self._ray_params.redis_password, + fate_share=self.kernel_fate_share) assert ray_constants.PROCESS_TYPE_LOG_MONITOR not in self.all_processes self.all_processes[ray_constants.PROCESS_TYPE_LOG_MONITOR] = [ process_info @@ -465,7 +471,8 @@ class Node: self.redis_address, stdout_file=stdout_file, stderr_file=stderr_file, - redis_password=self._ray_params.redis_password) + redis_password=self._ray_params.redis_password, + fate_share=self.kernel_fate_share) assert ray_constants.PROCESS_TYPE_REPORTER not in self.all_processes if process_info is not None: self.all_processes[ray_constants.PROCESS_TYPE_REPORTER] = [ @@ -488,7 +495,8 @@ class Node: self._temp_dir, stdout_file=stdout_file, stderr_file=stderr_file, - redis_password=self._ray_params.redis_password) + redis_password=self._ray_params.redis_password, + fate_share=self.kernel_fate_share) assert ray_constants.PROCESS_TYPE_DASHBOARD not in self.all_processes if process_info is not None: self.all_processes[ray_constants.PROCESS_TYPE_DASHBOARD] = [ @@ -506,7 +514,8 @@ class Node: 
stderr_file=stderr_file, plasma_directory=self._ray_params.plasma_directory, huge_pages=self._ray_params.huge_pages, - plasma_store_socket_name=self._plasma_store_socket_name) + plasma_store_socket_name=self._plasma_store_socket_name, + fate_share=self.kernel_fate_share) assert ( ray_constants.PROCESS_TYPE_PLASMA_STORE not in self.all_processes) self.all_processes[ray_constants.PROCESS_TYPE_PLASMA_STORE] = [ @@ -522,7 +531,8 @@ class Node: stdout_file=stdout_file, stderr_file=stderr_file, redis_password=self._ray_params.redis_password, - config=self._config) + config=self._config, + fate_share=self.kernel_fate_share) assert ( ray_constants.PROCESS_TYPE_GCS_SERVER not in self.all_processes) self.all_processes[ray_constants.PROCESS_TYPE_GCS_SERVER] = [ @@ -559,7 +569,8 @@ class Node: include_java=self._ray_params.include_java, java_worker_options=self._ray_params.java_worker_options, load_code_from_local=self._ray_params.load_code_from_local, - use_pickle=self._ray_params.use_pickle) + use_pickle=self._ray_params.use_pickle, + fate_share=self.kernel_fate_share) assert ray_constants.PROCESS_TYPE_RAYLET not in self.all_processes self.all_processes[ray_constants.PROCESS_TYPE_RAYLET] = [process_info] @@ -581,7 +592,8 @@ class Node: stdout_file=stdout_file, stderr_file=stderr_file, autoscaling_config=self._ray_params.autoscaling_config, - redis_password=self._ray_params.redis_password) + redis_password=self._ray_params.redis_password, + fate_share=self.kernel_fate_share) assert ray_constants.PROCESS_TYPE_MONITOR not in self.all_processes self.all_processes[ray_constants.PROCESS_TYPE_MONITOR] = [process_info] @@ -593,7 +605,8 @@ class Node: stdout_file=stdout_file, stderr_file=stderr_file, redis_password=self._ray_params.redis_password, - config=self._config) + config=self._config, + fate_share=self.kernel_fate_share) assert (ray_constants.PROCESS_TYPE_RAYLET_MONITOR not in self.all_processes) self.all_processes[ray_constants.PROCESS_TYPE_RAYLET_MONITOR] = [ diff --git 
a/python/ray/services.py b/python/ray/services.py index ea0234149..4a89bfce0 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -329,7 +329,8 @@ def start_ray_process(command, use_tmux=False, stdout_file=None, stderr_file=None, - pipe_stdin=False): + pipe_stdin=False, + fate_share=None): """Start one of the Ray processes. TODO(rkn): We need to figure out how these commands interact. For example, @@ -358,6 +359,8 @@ def start_ray_process(command, no redirection should happen, then this should be None. pipe_stdin: If true, subprocess.PIPE will be passed to the process as stdin. + fate_share: If true, the child will be killed if its parent (us) dies. + Note that this functionality must be supported, or it is an error. Returns: Information about the process that was started including a handle to @@ -439,12 +442,18 @@ def start_ray_process(command, # version, and tmux 2.1) command = ["tmux", "new-session", "-d", "{}".format(" ".join(command))] - # Block sigint for spawned processes so they aren't killed by the SIGINT - # propagated from the shell on Ctrl-C so we can handle KeyboardInterrupts - # in interactive sessions. This is only supported in Python 3.3 and above. 
- def block_sigint(): + if fate_share is None: + logger.warning("fate_share= should be passed to start_ray_process()") + if fate_share: + assert ray.utils.detect_fate_sharing_support(), ( + "kernel-level fate-sharing must only be specified if " + "detect_fate_sharing_support() has returned True") + + def preexec_fn(): import signal signal.pthread_sigmask(signal.SIG_BLOCK, {signal.SIGINT}) + if fate_share and sys.platform.startswith("linux"): + ray.utils.set_kill_on_parent_death_linux() process = subprocess.Popen( command, @@ -453,7 +462,10 @@ def start_ray_process(command, stdout=stdout_file, stderr=stderr_file, stdin=subprocess.PIPE if pipe_stdin else None, - preexec_fn=block_sigint) + preexec_fn=preexec_fn if sys.platform != "win32" else None) + + if fate_share and sys.platform == "win32": + ray.utils.set_kill_child_on_death_win32(process) return ProcessInfo( process=process, @@ -569,7 +581,7 @@ def check_version_info(redis_client): logger.warning(error_message) -def start_reaper(): +def start_reaper(fate_share=None): """Start the reaper process. This is a lightweight process that simply @@ -585,8 +597,9 @@ def start_reaper(): # process that started us. try: os.setpgrp() - except OSError as e: - if e.errno == errno.EPERM and os.getpgrp() == os.getpid(): + except (AttributeError, OSError) as e: + errcode = e.errno if isinstance(e, OSError) else None + if errcode == errno.EPERM and os.getpgrp() == os.getpid(): # Nothing to do; we're already a session leader. 
pass else: @@ -600,7 +613,10 @@ def start_reaper(): os.path.dirname(os.path.abspath(__file__)), "ray_process_reaper.py") command = [sys.executable, "-u", reaper_filepath] process_info = start_ray_process( - command, ray_constants.PROCESS_TYPE_REAPER, pipe_stdin=True) + command, + ray_constants.PROCESS_TYPE_REAPER, + pipe_stdin=True, + fate_share=fate_share) return process_info @@ -614,7 +630,8 @@ def start_redis(node_ip_address, redirect_worker_output=False, password=None, use_credis=None, - include_java=False): + include_java=False, + fate_share=None): """Start the Redis global state store. Args: @@ -698,7 +715,8 @@ def start_redis(node_ip_address, # primary Redis shard. redis_max_memory=None, stdout_file=redis_stdout_file, - stderr_file=redis_stderr_file) + stderr_file=redis_stderr_file, + fate_share=fate_share) processes.append(p) redis_address = address(node_ip_address, port) @@ -803,7 +821,8 @@ def _start_redis_instance(executable, stdout_file=None, stderr_file=None, password=None, - redis_max_memory=None): + redis_max_memory=None, + fate_share=None): """Start a single Redis server. Notes: @@ -869,7 +888,8 @@ def _start_redis_instance(executable, command, ray_constants.PROCESS_TYPE_REDIS_SERVER, stdout_file=stdout_file, - stderr_file=stderr_file) + stderr_file=stderr_file, + fate_share=fate_share) time.sleep(0.1) # Check if Redis successfully started (or at least if it the executable # did not exit within 0.1 seconds). @@ -942,7 +962,8 @@ def start_log_monitor(redis_address, logs_dir, stdout_file=None, stderr_file=None, - redis_password=None): + redis_password=None, + fate_share=None): """Start a log monitor process. 
Args: @@ -970,14 +991,16 @@ def start_log_monitor(redis_address, command, ray_constants.PROCESS_TYPE_LOG_MONITOR, stdout_file=stdout_file, - stderr_file=stderr_file) + stderr_file=stderr_file, + fate_share=fate_share) return process_info def start_reporter(redis_address, stdout_file=None, stderr_file=None, - redis_password=None): + redis_password=None, + fate_share=None): """Start a reporter process. Args: @@ -1004,7 +1027,8 @@ def start_reporter(redis_address, command, ray_constants.PROCESS_TYPE_REPORTER, stdout_file=stdout_file, - stderr_file=stderr_file) + stderr_file=stderr_file, + fate_share=fate_share) return process_info @@ -1014,7 +1038,8 @@ def start_dashboard(require_webui, temp_dir, stdout_file=None, stderr_file=None, - redis_password=None): + redis_password=None, + fate_share=None): """Start a dashboard process. Args: @@ -1077,7 +1102,8 @@ def start_dashboard(require_webui, command, ray_constants.PROCESS_TYPE_DASHBOARD, stdout_file=stdout_file, - stderr_file=stderr_file) + stderr_file=stderr_file, + fate_share=fate_share) dashboard_url = "{}:{}".format( host if host != "0.0.0.0" else get_node_ip_address(), port) @@ -1093,7 +1119,8 @@ def start_gcs_server(redis_address, stdout_file=None, stderr_file=None, redis_password=None, - config=None): + config=None, + fate_share=None): """Start a gcs server. Args: redis_address (str): The address that the Redis server is listening on. @@ -1123,7 +1150,8 @@ def start_gcs_server(redis_address, command, ray_constants.PROCESS_TYPE_GCS_SERVER, stdout_file=stdout_file, - stderr_file=stderr_file) + stderr_file=stderr_file, + fate_share=fate_share) return process_info @@ -1146,7 +1174,8 @@ def start_raylet(redis_address, include_java=False, java_worker_options=None, load_code_from_local=False, - use_pickle=False): + use_pickle=False, + fate_share=None): """Start a raylet, which is a combined local scheduler and object manager. 
Args: @@ -1275,7 +1304,8 @@ def start_raylet(redis_address, use_valgrind_profiler=use_profiler, use_perftools_profiler=("RAYLET_PERFTOOLS_PATH" in os.environ), stdout_file=stdout_file, - stderr_file=stderr_file) + stderr_file=stderr_file, + fate_share=fate_share) return process_info @@ -1437,7 +1467,8 @@ def _start_plasma_store(plasma_store_memory, stderr_file=None, plasma_directory=None, huge_pages=False, - socket_name=None): + socket_name=None, + fate_share=None): """Start a plasma store process. Args: @@ -1491,7 +1522,8 @@ def _start_plasma_store(plasma_store_memory, use_valgrind=use_valgrind, use_valgrind_profiler=use_profiler, stdout_file=stdout_file, - stderr_file=stderr_file) + stderr_file=stderr_file, + fate_share=fate_share) return process_info @@ -1500,7 +1532,8 @@ def start_plasma_store(resource_spec, stderr_file=None, plasma_directory=None, huge_pages=False, - plasma_store_socket_name=None): + plasma_store_socket_name=None, + fate_share=None): """This method starts an object store process. Args: @@ -1541,7 +1574,8 @@ def start_plasma_store(resource_spec, stderr_file=stderr_file, plasma_directory=plasma_directory, huge_pages=huge_pages, - socket_name=plasma_store_socket_name) + socket_name=plasma_store_socket_name, + fate_share=fate_share) return process_info @@ -1553,7 +1587,8 @@ def start_worker(node_ip_address, worker_path, temp_dir, stdout_file=None, - stderr_file=None): + stderr_file=None, + fate_share=None): """This method starts a worker process. Args: @@ -1584,7 +1619,8 @@ def start_worker(node_ip_address, command, ray_constants.PROCESS_TYPE_WORKER, stdout_file=stdout_file, - stderr_file=stderr_file) + stderr_file=stderr_file, + fate_share=fate_share) return process_info @@ -1592,7 +1628,8 @@ def start_monitor(redis_address, stdout_file=None, stderr_file=None, autoscaling_config=None, - redis_password=None): + redis_password=None, + fate_share=None): """Run a process to monitor the other processes. 
Args: @@ -1621,7 +1658,8 @@ def start_monitor(redis_address, command, ray_constants.PROCESS_TYPE_MONITOR, stdout_file=stdout_file, - stderr_file=stderr_file) + stderr_file=stderr_file, + fate_share=fate_share) return process_info @@ -1629,7 +1667,8 @@ def start_raylet_monitor(redis_address, stdout_file=None, stderr_file=None, redis_password=None, - config=None): + config=None, + fate_share=None): """Run a process to monitor the other processes. Args: @@ -1661,5 +1700,6 @@ def start_raylet_monitor(redis_address, command, ray_constants.PROCESS_TYPE_RAYLET_MONITOR, stdout_file=stdout_file, - stderr_file=stderr_file) + stderr_file=stderr_file, + fate_share=fate_share) return process_info diff --git a/python/ray/tests/test_multi_node_2.py b/python/ray/tests/test_multi_node_2.py index 495709d93..a65786b32 100644 --- a/python/ray/tests/test_multi_node_2.py +++ b/python/ray/tests/test_multi_node_2.py @@ -176,7 +176,8 @@ def test_worker_plasma_store_failure(ray_start_cluster_head): cluster.wait_for_nodes() worker.kill_reporter() worker.kill_plasma_store() - worker.kill_reaper() + if ray_constants.PROCESS_TYPE_REAPER in worker.all_processes: + worker.kill_reaper() worker.all_processes[ray_constants.PROCESS_TYPE_RAYLET][0].process.wait() assert not worker.any_processes_alive(), worker.live_processes() diff --git a/python/ray/utils.py b/python/ray/utils.py index 174430e46..b03516abd 100644 --- a/python/ray/utils.py +++ b/python/ray/utils.py @@ -6,6 +6,7 @@ import logging import numpy as np import os import six +import subprocess import sys import threading import time @@ -18,6 +19,15 @@ import psutil logger = logging.getLogger(__name__) +# Linux can bind child processes' lifetimes to that of their parents via prctl. +# prctl support is detected dynamically once, and assumed thereafter. +linux_prctl = None + +# Windows can bind processes' lifetimes to that of kernel-level "job objects". +# We keep a global job object to tie its lifetime to that of our own process. 
+win32_job = None +win32_AssignProcessToJobObject = None + def _random_string(): id_hash = hashlib.sha1() @@ -496,6 +506,140 @@ def is_main_thread(): return threading.current_thread().getName() == "MainThread" +def detect_fate_sharing_support_win32(): + global win32_job, win32_AssignProcessToJobObject + if win32_job is None and sys.platform == "win32": + import ctypes + try: + from ctypes.wintypes import BOOL, DWORD, HANDLE, LPVOID, LPCWSTR + kernel32 = ctypes.WinDLL("kernel32") + kernel32.CreateJobObjectW.argtypes = (LPVOID, LPCWSTR) + kernel32.CreateJobObjectW.restype = HANDLE + sijo_argtypes = (HANDLE, ctypes.c_int, LPVOID, DWORD) + kernel32.SetInformationJobObject.argtypes = sijo_argtypes + kernel32.SetInformationJobObject.restype = BOOL + kernel32.AssignProcessToJobObject.argtypes = (HANDLE, HANDLE) + kernel32.AssignProcessToJobObject.restype = BOOL + except (AttributeError, TypeError, ImportError): + kernel32 = None + job = kernel32.CreateJobObjectW(None, None) if kernel32 else None + job = subprocess.Handle(job) if job else job + if job: + from ctypes.wintypes import DWORD, LARGE_INTEGER, ULARGE_INTEGER + + class JOBOBJECT_BASIC_LIMIT_INFORMATION(ctypes.Structure): + _fields_ = [ + ("PerProcessUserTimeLimit", LARGE_INTEGER), + ("PerJobUserTimeLimit", LARGE_INTEGER), + ("LimitFlags", DWORD), + ("MinimumWorkingSetSize", ctypes.c_size_t), + ("MaximumWorkingSetSize", ctypes.c_size_t), + ("ActiveProcessLimit", DWORD), + ("Affinity", ctypes.c_size_t), + ("PriorityClass", DWORD), + ("SchedulingClass", DWORD), + ] + + class IO_COUNTERS(ctypes.Structure): + _fields_ = [ + ("ReadOperationCount", ULARGE_INTEGER), + ("WriteOperationCount", ULARGE_INTEGER), + ("OtherOperationCount", ULARGE_INTEGER), + ("ReadTransferCount", ULARGE_INTEGER), + ("WriteTransferCount", ULARGE_INTEGER), + ("OtherTransferCount", ULARGE_INTEGER), + ] + + class JOBOBJECT_EXTENDED_LIMIT_INFORMATION(ctypes.Structure): + _fields_ = [ + ("BasicLimitInformation", + JOBOBJECT_BASIC_LIMIT_INFORMATION), 
+ ("IoInfo", IO_COUNTERS), + ("ProcessMemoryLimit", ctypes.c_size_t), + ("JobMemoryLimit", ctypes.c_size_t), + ("PeakProcessMemoryUsed", ctypes.c_size_t), + ("PeakJobMemoryUsed", ctypes.c_size_t), + ] + + # Defined in <WinNT.h>; also available here: + # https://docs.microsoft.com/en-us/windows/win32/api/jobapi2/nf-jobapi2-setinformationjobobject + JobObjectExtendedLimitInformation = 9 + JOB_OBJECT_LIMIT_BREAKAWAY_OK = 0x00000800 + JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE = 0x00002000 + buf = JOBOBJECT_EXTENDED_LIMIT_INFORMATION() + buf.BasicLimitInformation.LimitFlags = ( + JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE + | JOB_OBJECT_LIMIT_BREAKAWAY_OK) + infoclass = JobObjectExtendedLimitInformation + if not kernel32.SetInformationJobObject( + job, infoclass, ctypes.byref(buf), ctypes.sizeof(buf)): + job = None + win32_AssignProcessToJobObject = (kernel32.AssignProcessToJobObject + if kernel32 is not None else False) + win32_job = job if job else False + return bool(win32_job) + + +def detect_fate_sharing_support_linux(): + global linux_prctl + if linux_prctl is None and sys.platform.startswith("linux"): + try: + from ctypes import c_int, c_ulong, CDLL + prctl = CDLL(None).prctl + prctl.restype = c_int + prctl.argtypes = [c_int, c_ulong, c_ulong, c_ulong, c_ulong] + except (AttributeError, TypeError): + prctl = None + linux_prctl = prctl if prctl else False + return bool(linux_prctl) + + +def detect_fate_sharing_support(): + result = None + if sys.platform == "win32": + result = detect_fate_sharing_support_win32() + elif sys.platform.startswith("linux"): + result = detect_fate_sharing_support_linux() + return result + + +def set_kill_on_parent_death_linux(): + """Ensures this process dies if its parent dies (fate-sharing). + + Linux-only. Must be called in preexec_fn (i.e. by the child). 
+ """ + if detect_fate_sharing_support_linux(): + import signal + PR_SET_PDEATHSIG = 1 + if linux_prctl(PR_SET_PDEATHSIG, signal.SIGKILL, 0, 0, 0) != 0: + import ctypes + raise OSError(ctypes.get_errno(), "prctl(PR_SET_PDEATHSIG) failed") + else: + assert False, "PR_SET_PDEATHSIG used despite being unavailable" + + +def set_kill_child_on_death_win32(child_proc): + """Ensures the child process dies if this process dies (fate-sharing). + + Windows-only. Must be called by the parent, after spawning the child. + + Args: + child_proc: The subprocess.Popen or subprocess.Handle object. + """ + + if isinstance(child_proc, subprocess.Popen): + child_proc = child_proc._handle + assert isinstance(child_proc, subprocess.Handle) + + if detect_fate_sharing_support_win32(): + if not win32_AssignProcessToJobObject(win32_job, int(child_proc)): + import ctypes + raise OSError(ctypes.get_last_error(), + "AssignProcessToJobObject() failed") + else: + assert False, "AssignProcessToJobObject used despite being unavailable" + + def try_make_directory_shared(directory_path): try: os.chmod(directory_path, 0o0777)