from __future__ import absolute_import from __future__ import division from __future__ import print_function import atexit import collections import datetime import errno import json import os import logging import signal import sys import tempfile import threading import time import ray import ray.ray_constants as ray_constants import ray.services from ray.utils import try_to_create_directory # Logger for this module. It should be configured at the entry point # into the program using Ray. Ray configures it by default automatically # using logging.basicConfig in its entry/init points. logger = logging.getLogger(__name__) PY3 = sys.version_info.major >= 3 class Node(object): """An encapsulation of the Ray processes on a single node. This class is responsible for starting Ray processes and killing them. Attributes: all_processes (dict): A mapping from process type (str) to a list of ProcessInfo objects. All lists have length one except for the Redis server list, which has multiple. """ def __init__(self, ray_params, head=False, shutdown_at_exit=True, connect_only=False): """Start a node. Args: ray_params (ray.params.RayParams): The parameters to use to configure the node. head (bool): True if this is the head node, which means it will start additional processes like the Redis servers, monitor processes, and web UI. shutdown_at_exit (bool): If true, a handler will be registered to shutdown the processes started here when the Python interpreter exits. connect_only (bool): If true, connect to the node without starting new processes. """ if shutdown_at_exit and connect_only: raise ValueError("'shutdown_at_exit' and 'connect_only' cannot " "be both true.") self.all_processes = {} ray_params.update_if_absent( node_ip_address=ray.services.get_node_ip_address(), include_log_monitor=True, resources={}, include_webui=False, worker_path=os.path.join( os.path.dirname(os.path.abspath(__file__)), "workers/default_worker.py")) self._ray_params = ray_params self._node_ip_address = ray_params.node_ip_address self._redis_address = ray_params.redis_address self._config = (json.loads(ray_params._internal_config) if ray_params._internal_config else None) if head: ray_params.update_if_absent(num_redis_shards=1, include_webui=True) self._plasma_store_socket_name = None self._raylet_socket_name = None self._webui_url = None self._dashboard_url = None else: self._plasma_store_socket_name = ( ray_params.plasma_store_socket_name) self._raylet_socket_name = ray_params.raylet_socket_name redis_client = self.create_redis_client() # TODO(suquark): Replace _webui_url_helper in worker.py in # another PR. _webui_url = redis_client.hmget("webui", "url")[0] self._webui_url = (ray.utils.decode(_webui_url) if _webui_url is not None else None) ray_params.include_java = ( ray.services.include_java_from_redis(redis_client)) self._init_temp() if not connect_only: self.start_ray_processes() if shutdown_at_exit: atexit.register(lambda: self.kill_all_processes( check_alive=False, allow_graceful=True)) def _init_temp(self): # Create an dictionary to store temp file index. self._incremental_dict = collections.defaultdict(lambda: 0) self._temp_dir = self._ray_params.temp_dir if self._temp_dir is None: date_str = datetime.datetime.today().strftime("%Y-%m-%d_%H-%M-%S") self._temp_dir = self._make_inc_temp( prefix="session_{date_str}_{pid}".format( pid=os.getpid(), date_str=date_str), directory_name="/tmp/ray") try_to_create_directory(self._temp_dir) # Create a directory to be used for socket files. self._sockets_dir = os.path.join(self._temp_dir, "sockets") try_to_create_directory(self._sockets_dir) # Create a directory to be used for process log files. self._logs_dir = os.path.join(self._temp_dir, "logs") try_to_create_directory(self._logs_dir) @property def node_ip_address(self): """Get the cluster Redis address.""" return self._node_ip_address @property def redis_address(self): """Get the cluster Redis address.""" return self._redis_address @property def plasma_store_socket_name(self): """Get the node's plasma store socket name.""" return self._plasma_store_socket_name @property def webui_url(self): """Get the cluster's web UI url.""" return self._webui_url @property def raylet_socket_name(self): """Get the node's raylet socket name.""" return self._raylet_socket_name def create_redis_client(self): """Create a redis client.""" return ray.services.create_redis_client( self._redis_address, self._ray_params.redis_password) def get_temp_dir_path(self): """Get the path of the temporary directory.""" return self._temp_dir def get_logs_dir_path(self): """Get the path of the log files directory.""" return self._logs_dir def get_sockets_dir_path(self): """Get the path of the sockets directory.""" return self._sockets_dir def _make_inc_temp(self, suffix="", prefix="", directory_name="/tmp/ray"): """Return a incremental temporary file name. The file is not created. Args: suffix (str): The suffix of the temp file. prefix (str): The prefix of the temp file. directory_name (str) : The base directory of the temp file. Returns: A string of file name. If there existing a file having the same name, the returned name will look like "{directory_name}/{prefix}.{unique_index}{suffix}" """ directory_name = os.path.expanduser(directory_name) index = self._incremental_dict[suffix, prefix, directory_name] # `tempfile.TMP_MAX` could be extremely large, # so using `range` in Python2.x should be avoided. while index < tempfile.TMP_MAX: if index == 0: filename = os.path.join(directory_name, prefix + suffix) else: filename = os.path.join(directory_name, prefix + "." + str(index) + suffix) index += 1 if not os.path.exists(filename): # Save the index. self._incremental_dict[suffix, prefix, directory_name] = index return filename raise FileExistsError(errno.EEXIST, "No usable temporary filename found") def new_log_files(self, name, redirect_output=True): """Generate partially randomized filenames for log files. Args: name (str): descriptive string for this log file. redirect_output (bool): True if files should be generated for logging stdout and stderr and false if stdout and stderr should not be redirected. If it is None, it will use the "redirect_output" Ray parameter. Returns: If redirect_output is true, this will return a tuple of two file handles. The first is for redirecting stdout and the second is for redirecting stderr. If redirect_output is false, this will return a tuple of two None objects. """ if redirect_output is None: redirect_output = self._ray_params.redirect_output if not redirect_output: return None, None log_stdout = self._make_inc_temp( suffix=".out", prefix=name, directory_name=self._logs_dir) log_stderr = self._make_inc_temp( suffix=".err", prefix=name, directory_name=self._logs_dir) # Line-buffer the output (mode 1). log_stdout_file = open(log_stdout, "a", buffering=1) log_stderr_file = open(log_stderr, "a", buffering=1) return log_stdout_file, log_stderr_file def _prepare_socket_file(self, socket_path, default_prefix): """Prepare the socket file for raylet and plasma. This method helps to prepare a socket file. 1. Make the directory if the directory does not exist. 2. If the socket file exists, raise exception. Args: socket_path (string): the socket file to prepare. """ if socket_path is not None: if os.path.exists(socket_path): raise Exception("Socket file {} exists!".format(socket_path)) socket_dir = os.path.dirname(socket_path) try_to_create_directory(socket_dir) return socket_path return self._make_inc_temp( prefix=default_prefix, directory_name=self._sockets_dir) def start_redis(self): """Start the Redis servers.""" assert self._redis_address is None redis_log_files = [self.new_log_files("redis")] for i in range(self._ray_params.num_redis_shards): redis_log_files.append(self.new_log_files("redis-shard_" + str(i))) (self._redis_address, redis_shards, process_infos) = ray.services.start_redis( self._node_ip_address, redis_log_files, port=self._ray_params.redis_port, redis_shard_ports=self._ray_params.redis_shard_ports, num_redis_shards=self._ray_params.num_redis_shards, redis_max_clients=self._ray_params.redis_max_clients, redirect_worker_output=True, password=self._ray_params.redis_password, include_java=self._ray_params.include_java, redis_max_memory=self._ray_params.redis_max_memory) assert ( ray_constants.PROCESS_TYPE_REDIS_SERVER not in self.all_processes) self.all_processes[ray_constants.PROCESS_TYPE_REDIS_SERVER] = ( process_infos) def start_log_monitor(self): """Start the log monitor.""" stdout_file, stderr_file = self.new_log_files("log_monitor") process_info = ray.services.start_log_monitor( self.redis_address, self._logs_dir, stdout_file=stdout_file, stderr_file=stderr_file, redis_password=self._ray_params.redis_password) assert ray_constants.PROCESS_TYPE_LOG_MONITOR not in self.all_processes self.all_processes[ray_constants.PROCESS_TYPE_LOG_MONITOR] = [ process_info ] def start_reporter(self): """Start the reporter.""" stdout_file, stderr_file = self.new_log_files("reporter", True) process_info = ray.services.start_reporter( self.redis_address, stdout_file=stdout_file, stderr_file=stderr_file, redis_password=self._ray_params.redis_password) assert ray_constants.PROCESS_TYPE_REPORTER not in self.all_processes if process_info is not None: self.all_processes[ray_constants.PROCESS_TYPE_REPORTER] = [ process_info ] def start_dashboard(self): """Start the dashboard.""" stdout_file, stderr_file = self.new_log_files("dashboard", True) self._dashboard_url, process_info = ray.services.start_dashboard( self.redis_address, self._temp_dir, stdout_file=stdout_file, stderr_file=stderr_file, redis_password=self._ray_params.redis_password) assert ray_constants.PROCESS_TYPE_DASHBOARD not in self.all_processes if process_info is not None: self.all_processes[ray_constants.PROCESS_TYPE_DASHBOARD] = [ process_info ] def start_ui(self): """Start the web UI.""" stdout_file, stderr_file = self.new_log_files("webui") notebook_name = self._make_inc_temp( suffix=".ipynb", prefix="ray_ui", directory_name=self._temp_dir) self._webui_url, process_info = ray.services.start_ui( self._redis_address, notebook_name, stdout_file=stdout_file, stderr_file=stderr_file) assert ray_constants.PROCESS_TYPE_WEB_UI not in self.all_processes if process_info is not None: self.all_processes[ray_constants.PROCESS_TYPE_WEB_UI] = [ process_info ] def start_plasma_store(self): """Start the plasma store.""" assert self._plasma_store_socket_name is None # If the user specified a socket name, use it. self._plasma_store_socket_name = self._prepare_socket_file( self._ray_params.plasma_store_socket_name, default_prefix="plasma_store") stdout_file, stderr_file = self.new_log_files("plasma_store") process_info = ray.services.start_plasma_store( stdout_file=stdout_file, stderr_file=stderr_file, object_store_memory=self._ray_params.object_store_memory, plasma_directory=self._ray_params.plasma_directory, huge_pages=self._ray_params.huge_pages, plasma_store_socket_name=self._plasma_store_socket_name) assert ( ray_constants.PROCESS_TYPE_PLASMA_STORE not in self.all_processes) self.all_processes[ray_constants.PROCESS_TYPE_PLASMA_STORE] = [ process_info ] def start_raylet(self, use_valgrind=False, use_profiler=False): """Start the raylet. Args: use_valgrind (bool): True if we should start the process in valgrind. use_profiler (bool): True if we should start the process in the valgrind profiler. """ assert self._raylet_socket_name is None # If the user specified a socket name, use it. self._raylet_socket_name = self._prepare_socket_file( self._ray_params.raylet_socket_name, default_prefix="raylet") stdout_file, stderr_file = self.new_log_files("raylet") process_info = ray.services.start_raylet( self._redis_address, self._node_ip_address, self._raylet_socket_name, self._plasma_store_socket_name, self._ray_params.worker_path, self._temp_dir, self._ray_params.num_cpus, self._ray_params.num_gpus, self._ray_params.resources, self._ray_params.object_manager_port, self._ray_params.node_manager_port, self._ray_params.redis_password, use_valgrind=use_valgrind, use_profiler=use_profiler, stdout_file=stdout_file, stderr_file=stderr_file, config=self._config, include_java=self._ray_params.include_java, java_worker_options=self._ray_params.java_worker_options, load_code_from_local=self._ray_params.load_code_from_local, ) assert ray_constants.PROCESS_TYPE_RAYLET not in self.all_processes self.all_processes[ray_constants.PROCESS_TYPE_RAYLET] = [process_info] def new_worker_redirected_log_file(self, worker_id): """Create new logging files for workers to redirect its output.""" worker_stdout_file, worker_stderr_file = (self.new_log_files( "worker-" + ray.utils.binary_to_hex(worker_id), True)) return worker_stdout_file, worker_stderr_file def start_worker(self): """Start a worker process.""" raise NotImplementedError def start_monitor(self): """Start the monitor.""" stdout_file, stderr_file = self.new_log_files("monitor") process_info = ray.services.start_monitor( self._redis_address, stdout_file=stdout_file, stderr_file=stderr_file, autoscaling_config=self._ray_params.autoscaling_config, redis_password=self._ray_params.redis_password) assert ray_constants.PROCESS_TYPE_MONITOR not in self.all_processes self.all_processes[ray_constants.PROCESS_TYPE_MONITOR] = [process_info] def start_raylet_monitor(self): """Start the raylet monitor.""" stdout_file, stderr_file = self.new_log_files("raylet_monitor") process_info = ray.services.start_raylet_monitor( self._redis_address, stdout_file=stdout_file, stderr_file=stderr_file, redis_password=self._ray_params.redis_password, config=self._config) assert (ray_constants.PROCESS_TYPE_RAYLET_MONITOR not in self.all_processes) self.all_processes[ray_constants.PROCESS_TYPE_RAYLET_MONITOR] = [ process_info ] def start_ray_processes(self): """Start all of the processes on the node.""" logger.info( "Process STDOUT and STDERR is being redirected to {}.".format( self._logs_dir)) # If this is the head node, start the relevant head node processes. if self._redis_address is None: self.start_redis() self.start_monitor() self.start_raylet_monitor() if PY3 and self._ray_params.include_webui: self.start_dashboard() self.start_plasma_store() self.start_raylet() if PY3 and self._ray_params.include_webui: self.start_reporter() if self._ray_params.include_log_monitor: self.start_log_monitor() def _kill_process_type(self, process_type, allow_graceful=False, check_alive=True, wait=False): """Kill a process of a given type. If the process type is PROCESS_TYPE_REDIS_SERVER, then we will kill all of the Redis servers. If the process was started in valgrind, then we will raise an exception if the process has a non-zero exit code. Args: process_type: The type of the process to kill. allow_graceful (bool): Send a SIGTERM first and give the process time to exit gracefully. If that doesn't work, then use SIGKILL. We usually want to do this outside of tests. check_alive (bool): If true, then we expect the process to be alive and will raise an exception if the process is already dead. wait (bool): If true, then this method will not return until the process in question has exited. Raises: This process raises an exception in the following cases: 1. The process had already died and check_alive is true. 2. The process had been started in valgrind and had a non-zero exit code. """ process_infos = self.all_processes[process_type] if process_type != ray_constants.PROCESS_TYPE_REDIS_SERVER: assert len(process_infos) == 1 for process_info in process_infos: process = process_info.process # Handle the case where the process has already exited. if process.poll() is not None: if check_alive: raise Exception("Attempting to kill a process of type " "'{}', but this process is already dead." .format(process_type)) else: continue if process_info.use_valgrind: process.terminate() process.wait() if process.returncode != 0: message = ("Valgrind detected some errors in process of " "type {}. Error code {}.".format( process_type, process.returncode)) if process_info.stdout_file is not None: with open(process_info.stdout_file, "r") as f: message += "\nPROCESS STDOUT:\n" + f.read() if process_info.stderr_file is not None: with open(process_info.stderr_file, "r") as f: message += "\nPROCESS STDERR:\n" + f.read() raise Exception(message) continue if process_info.use_valgrind_profiler: # Give process signal to write profiler data. os.kill(process.pid, signal.SIGINT) # Wait for profiling data to be written. time.sleep(0.1) if allow_graceful: # Allow the process one second to exit gracefully. process.terminate() timer = threading.Timer(1, lambda process: process.kill(), [process]) try: timer.start() process.wait() finally: timer.cancel() if process.poll() is not None: continue # If the process did not exit within one second, force kill it. process.kill() # The reason we usually don't call process.wait() here is that # there's some chance we'd end up waiting a really long time. if wait: process.wait() del self.all_processes[process_type] def kill_redis(self, check_alive=True): """Kill the Redis servers. Args: check_alive (bool): Raise an exception if any of the processes were already dead. """ self._kill_process_type( ray_constants.PROCESS_TYPE_REDIS_SERVER, check_alive=check_alive) def kill_plasma_store(self, check_alive=True): """Kill the plasma store. Args: check_alive (bool): Raise an exception if the process was already dead. """ self._kill_process_type( ray_constants.PROCESS_TYPE_PLASMA_STORE, check_alive=check_alive) def kill_raylet(self, check_alive=True): """Kill the raylet. Args: check_alive (bool): Raise an exception if the process was already dead. """ self._kill_process_type( ray_constants.PROCESS_TYPE_RAYLET, check_alive=check_alive) def kill_log_monitor(self, check_alive=True): """Kill the log monitor. Args: check_alive (bool): Raise an exception if the process was already dead. """ self._kill_process_type( ray_constants.PROCESS_TYPE_LOG_MONITOR, check_alive=check_alive) def kill_reporter(self, check_alive=True): """Kill the reporter. Args: check_alive (bool): Raise an exception if the process was already dead. """ self._kill_process_type( ray_constants.PROCESS_TYPE_REPORTER, check_alive=check_alive) def kill_dashboard(self, check_alive=True): """Kill the dashboard. Args: check_alive (bool): Raise an exception if the process was already dead. """ self._kill_process_type( ray_constants.PROCESS_TYPE_DASHBOARD, check_alive=check_alive) def kill_monitor(self, check_alive=True): """Kill the monitor. Args: check_alive (bool): Raise an exception if the process was already dead. """ self._kill_process_type( ray_constants.PROCESS_TYPE_MONITOR, check_alive=check_alive) def kill_raylet_monitor(self, check_alive=True): """Kill the raylet monitor. Args: check_alive (bool): Raise an exception if the process was already dead. """ self._kill_process_type( ray_constants.PROCESS_TYPE_RAYLET_MONITOR, check_alive=check_alive) def kill_all_processes(self, check_alive=True, allow_graceful=False): """Kill all of the processes. Note that This is slower than necessary because it calls kill, wait, kill, wait, ... instead of kill, kill, ..., wait, wait, ... Args: check_alive (bool): Raise an exception if any of the processes were already dead. """ # Kill the raylet first. This is important for suppressing errors at # shutdown because we give the raylet a chance to exit gracefully and # clean up its child worker processes. If we were to kill the plasma # store (or Redis) first, that could cause the raylet to exit # ungracefully, leading to more verbose output from the workers. if ray_constants.PROCESS_TYPE_RAYLET in self.all_processes: self._kill_process_type( ray_constants.PROCESS_TYPE_RAYLET, check_alive=check_alive, allow_graceful=allow_graceful) # We call "list" to copy the keys because we are modifying the # dictionary while iterating over it. for process_type in list(self.all_processes.keys()): self._kill_process_type( process_type, check_alive=check_alive, allow_graceful=allow_graceful) def live_processes(self): """Return a list of the live processes. Returns: A list of the live processes. """ result = [] for process_type, process_infos in self.all_processes.items(): for process_info in process_infos: if process_info.process.poll() is None: result.append((process_type, process_info.process)) return result def dead_processes(self): """Return a list of the dead processes. Note that this ignores processes that have been explicitly killed, e.g., via a command like node.kill_raylet(). Returns: A list of the dead processes ignoring the ones that have been explicitly killed. """ result = [] for process_type, process_infos in self.all_processes.items(): for process_info in process_infos: if process_info.process.poll() is not None: result.append((process_type, process_info.process)) return result def any_processes_alive(self): """Return true if any processes are still alive. Returns: True if any process is still alive. """ return any(self.live_processes()) def remaining_processes_alive(self): """Return true if all remaining processes are still alive. Note that this ignores processes that have been explicitly killed, e.g., via a command like node.kill_raylet(). Returns: True if any process that wasn't explicitly killed is still alive. """ return not any(self.dead_processes())