diff --git a/python/ray/node.py b/python/ray/node.py index 443a10b4a..10a129f9f 100644 --- a/python/ray/node.py +++ b/python/ray/node.py @@ -6,9 +6,9 @@ import os import logging import signal import socket +import subprocess import sys import tempfile -import threading import time import ray @@ -185,7 +185,7 @@ class Node: self.kill_all_processes(check_alive=False, allow_graceful=True) sys.exit(1) - signal.signal(signal.SIGTERM, sigterm_handler) + ray.utils.set_sigterm_handler(sigterm_handler) def _init_temp(self, redis_client): # Create an dictionary to store temp file index. @@ -718,25 +718,21 @@ class Node: time.sleep(0.1) if allow_graceful: - # Allow the process one second to exit gracefully. process.terminate() - timer = threading.Timer(1, lambda process: process.kill(), - [process]) + # Allow the process one second to exit gracefully. + timeout_seconds = 1 try: - timer.start() + process.wait(timeout_seconds) + except subprocess.TimeoutExpired: + pass + + # If the process did not exit, force kill it. + if process.poll() is None: + process.kill() + # The reason we usually don't call process.wait() here is that + # there's some chance we'd end up waiting a really long time. + if wait: process.wait() - finally: - timer.cancel() - - if process.poll() is not None: - continue - - # If the process did not exit within one second, force kill it. - process.kill() - # The reason we usually don't call process.wait() here is that - # there's some chance we'd end up waiting a really long time. - if wait: - process.wait() del self.all_processes[process_type] diff --git a/python/ray/services.py b/python/ray/services.py index 6d3632542..eb51bdb0d 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -1,10 +1,12 @@ import collections import errno +import io import json import logging import multiprocessing import os import random +import signal import socket import subprocess import sys @@ -76,6 +78,32 @@ ProcessInfo = collections.namedtuple("ProcessInfo", [ ]) +class ConsolePopen(subprocess.Popen): + if sys.platform == "win32": + + def terminate(self): + if isinstance(self.stdin, io.IOBase): + self.stdin.close() + if self._use_signals: + self.send_signal(signal.CTRL_BREAK_EVENT) + else: + super(ConsolePopen, self).terminate() + + def __init__(self, *args, **kwargs): + # CREATE_NEW_PROCESS_GROUP is used to send Ctrl+C on Windows: + # https://docs.python.org/3/library/subprocess.html#subprocess.Popen.send_signal + new_pgroup = subprocess.CREATE_NEW_PROCESS_GROUP + flags = 0 + if ray.utils.detect_fate_sharing_support(): + # If we don't have kernel-mode fate-sharing, then don't do this + # because our children need to be in out process group for + # the process reaper to properly terminate them. + flags = new_pgroup + kwargs.setdefault("creationflags", flags) + self._use_signals = (kwargs["creationflags"] & new_pgroup) + super(ConsolePopen, self).__init__(*args, **kwargs) + + def address(ip_address, port): return ip_address + ":" + str(port) @@ -464,7 +492,7 @@ def start_ray_process(command, if fate_share and sys.platform.startswith("linux"): ray.utils.set_kill_on_parent_death_linux() - process = subprocess.Popen( + process = ConsolePopen( command, env=modified_env, cwd=cwd, diff --git a/python/ray/utils.py b/python/ray/utils.py index 8ae693bbf..af1936276 100644 --- a/python/ray/utils.py +++ b/python/ray/utils.py @@ -5,6 +5,7 @@ import inspect import logging import numpy as np import os +import signal import subprocess import sys import tempfile @@ -561,10 +562,12 @@ def detect_fate_sharing_support_win32(): # https://docs.microsoft.com/en-us/windows/win32/api/jobapi2/nf-jobapi2-setinformationjobobject JobObjectExtendedLimitInformation = 9 JOB_OBJECT_LIMIT_BREAKAWAY_OK = 0x00000800 + JOB_OBJECT_LIMIT_DIE_ON_UNHANDLED_EXCEPTION = 0x00000400 JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE = 0x00002000 buf = JOBOBJECT_EXTENDED_LIMIT_INFORMATION() buf.BasicLimitInformation.LimitFlags = ( JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE + | JOB_OBJECT_LIMIT_DIE_ON_UNHANDLED_EXCEPTION | JOB_OBJECT_LIMIT_BREAKAWAY_OK) infoclass = JobObjectExtendedLimitInformation if not kernel32.SetInformationJobObject( @@ -636,6 +639,17 @@ def set_kill_child_on_death_win32(child_proc): assert False, "AssignProcessToJobObject used despite being unavailable" +def set_sigterm_handler(sigterm_handler): + """Registers a handler for SIGTERM in a platform-compatible manner.""" + if sys.platform == "win32": + # Note that these signal handlers only work for console applications. + # TODO(mehrdadn): implement graceful process termination mechanism + # SIGINT is Ctrl+C, SIGBREAK is Ctrl+Break. + signal.signal(signal.SIGBREAK, sigterm_handler) + else: + signal.signal(signal.SIGTERM, sigterm_handler) + + def try_make_directory_shared(directory_path): try: os.chmod(directory_path, 0o0777) diff --git a/python/ray/worker.py b/python/ray/worker.py index 273db9be7..afc1e4bd5 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -422,7 +422,7 @@ class Worker: shutdown(True) sys.exit(1) - signal.signal(signal.SIGTERM, sigterm_handler) + ray.utils.set_sigterm_handler(sigterm_handler) self.core_worker.run_task_loop() sys.exit(0) @@ -882,7 +882,7 @@ def sigterm_handler(signum, frame): try: - signal.signal(signal.SIGTERM, sigterm_handler) + ray.utils.set_sigterm_handler(sigterm_handler) except ValueError: logger.warning("Failed to set SIGTERM handler, processes might" "not be cleaned up properly on exit.") @@ -1222,14 +1222,18 @@ def connect(node, # Redirect stdout/stderr at the file descriptor level. If we simply # set sys.stdout and sys.stderr, then logging from C++ can fail to # be redirected. - os.dup2(log_stdout_file.fileno(), sys.stdout.fileno()) - os.dup2(log_stderr_file.fileno(), sys.stderr.fileno()) + if log_stdout_file is not None: + os.dup2(log_stdout_file.fileno(), sys.stdout.fileno()) + if log_stderr_file is not None: + os.dup2(log_stderr_file.fileno(), sys.stderr.fileno()) # We also manually set sys.stdout and sys.stderr because that seems # to have an affect on the output buffering. Without doing this, # stdout and stderr are heavily buffered resulting in seemingly # lost logging statements. - sys.stdout = log_stdout_file - sys.stderr = log_stderr_file + if log_stdout_file is not None: + sys.stdout = log_stdout_file + if log_stderr_file is not None: + sys.stderr = log_stderr_file # This should always be the first message to appear in the worker's # stdout and stderr log files. The string "Ray worker pid:" is # parsed in the log monitor process. @@ -1238,8 +1242,12 @@ def connect(node, sys.stdout.flush() sys.stderr.flush() - worker_dict["stdout_file"] = os.path.abspath(log_stdout_file.name) - worker_dict["stderr_file"] = os.path.abspath(log_stderr_file.name) + worker_dict["stdout_file"] = os.path.abspath( + (log_stdout_file + if log_stdout_file is not None else sys.stdout).name) + worker_dict["stderr_file"] = os.path.abspath( + (log_stderr_file + if log_stderr_file is not None else sys.stderr).name) worker.redis_client.hmset(b"Workers:" + worker.worker_id, worker_dict) else: raise ValueError("Invalid worker mode. Expected DRIVER or WORKER.") diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index d227c77d8..de13080ea 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -375,14 +375,10 @@ void CoreWorker::SetCurrentTaskId(const TaskID &task_id) { void CoreWorker::CheckForRayletFailure() { // If the raylet fails, we will be reassigned to init (PID=1). -#ifdef _WIN32 -// TODO(mehrdadn): need a different solution for Windows. -#else if (getppid() == 1) { RAY_LOG(ERROR) << "Raylet failed. Shutting down."; Shutdown(); } -#endif // Reset the timer from the previous expiration time to avoid drift. death_check_timer_.expires_at( diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc index 5369d1042..42db174b2 100644 --- a/src/ray/raylet/main.cc +++ b/src/ray/raylet/main.cc @@ -201,7 +201,12 @@ int main(int argc, char *argv[]) { main_service.stop(); remove(raylet_socket_name.c_str()); }; - boost::asio::signal_set signals(main_service, SIGTERM); + boost::asio::signal_set signals(main_service); +#ifdef _WIN32 + signals.add(SIGBREAK); +#else + signals.add(SIGTERM); +#endif signals.async_wait(handler); main_service.run(); diff --git a/src/ray/raylet/monitor_main.cc b/src/ray/raylet/monitor_main.cc index 0f294ccf2..308db3dcd 100644 --- a/src/ray/raylet/monitor_main.cc +++ b/src/ray/raylet/monitor_main.cc @@ -65,7 +65,12 @@ int main(int argc, char *argv[]) { // // instead of returning immediately. // auto handler = [&io_service](const boost::system::error_code &error, // int signal_number) { io_service.stop(); }; - // boost::asio::signal_set signals(io_service, SIGTERM); + // boost::asio::signal_set signals(io_service); + // #ifdef _WIN32 + // signals.add(SIGBREAK); + // #else + // signals.add(SIGTERM); + // #endif // signals.async_wait(handler); // Initialize the monitor. diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index 658e04c34..e1b002037 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -68,7 +68,12 @@ WorkerPool::WorkerPool(boost::asio::io_service &io_service, int num_workers, raylet_config_(raylet_config), starting_worker_timeout_callback_(starting_worker_timeout_callback) { RAY_CHECK(maximum_startup_concurrency > 0); -#ifndef _WIN32 +#ifdef _WIN32 + // If worker processes fail to initialize, don't display an error window. + SetErrorMode(GetErrorMode() | SEM_FAILCRITICALERRORS); + // If worker processes crash, don't display an error window. + SetErrorMode(GetErrorMode() | SEM_NOGPFAULTERRORBOX); +#else // Ignore SIGCHLD signals. If we don't do this, then worker processes will // become zombies instead of dying gracefully. signal(SIGCHLD, SIG_IGN); diff --git a/src/shims/windows/unistd.cc b/src/shims/windows/unistd.cc index d69655791..c7cc44e67 100644 --- a/src/shims/windows/unistd.cc +++ b/src/shims/windows/unistd.cc @@ -1,9 +1,66 @@ #include +#include + #ifndef WIN32_LEAN_AND_MEAN #define WIN32_LEAN_AND_MEAN 1 #endif #include +#include + +#ifndef STATUS_BUFFER_OVERFLOW +#define STATUS_BUFFER_OVERFLOW ((NTSTATUS)0x80000005L) +#endif + +typedef LONG NTSTATUS; +typedef NTSTATUS WINAPI NtQueryInformationProcess_t(HANDLE ProcessHandle, + ULONG ProcessInformationClass, + PVOID ProcessInformation, + ULONG ProcessInformationLength, + ULONG *ReturnLength); + +static std::atomic NtQueryInformationProcess_ = + ATOMIC_VAR_INIT(NULL); + +pid_t getppid() { + NtQueryInformationProcess_t *NtQueryInformationProcess = ::NtQueryInformationProcess_; + if (!NtQueryInformationProcess) { + NtQueryInformationProcess = reinterpret_cast( + GetProcAddress(GetModuleHandle(TEXT("ntdll.dll")), + _CRT_STRINGIZE(NtQueryInformationProcess))); + ::NtQueryInformationProcess_ = NtQueryInformationProcess; + } + DWORD ppid = 0; + PROCESS_BASIC_INFORMATION info; + ULONG cb = sizeof(info); + NTSTATUS status = NtQueryInformationProcess(GetCurrentProcess(), 0, &info, cb, &cb); + if ((status >= 0 || status == STATUS_BUFFER_OVERFLOW) && cb >= sizeof(info)) { + ppid = reinterpret_cast(info.Reserved3); + } + pid_t result = 0; + if (ppid > 0) { + // For now, assume PPID = 1 (simulating the reassignment to "init" on Linux) + result = 1; + if (HANDLE parent = OpenProcess(PROCESS_QUERY_INFORMATION, FALSE, ppid)) { + long long me_created, parent_created; + FILETIME unused; + if (GetProcessTimes(GetCurrentProcess(), reinterpret_cast(&me_created), + &unused, &unused, &unused) && + GetProcessTimes(parent, reinterpret_cast(&parent_created), &unused, + &unused, &unused)) { + if (me_created >= parent_created) { + // We verified the child is younger than the parent, so we know the parent + // is still alive. + // (Note that the parent can still die by the time this function returns, + // but that race condition exists on POSIX too, which we're emulating here.) + result = static_cast(ppid); + } + } + CloseHandle(parent); + } + } + return result; +} int usleep(useconds_t usec) { Sleep((usec + (1000 - 1)) / 1000); diff --git a/src/shims/windows/unistd.h b/src/shims/windows/unistd.h index d0a04284d..cbc30c3d2 100644 --- a/src/shims/windows/unistd.h +++ b/src/shims/windows/unistd.h @@ -42,6 +42,8 @@ typedef unsigned int useconds_t; int usleep(useconds_t usec); unsigned sleep(unsigned seconds); +pid_t getppid(); + __declspec( deprecated("Killing a process by ID has an inherent race condition on Windows" " and is HIGHLY discouraged. "