diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py index 66c6c2e96..8fb91d3a3 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/experimental/state.py @@ -38,6 +38,8 @@ def parse_client_table(redis_client): node_info = {} gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(message, 0) + ordered_client_ids = [] + # Since GCS entries are append-only, we override so that # only the latest entries are kept. for i in range(gcs_entry.EntriesLength()): @@ -58,6 +60,8 @@ def parse_client_table(redis_client): assert client_id in node_info, "Client removed not found!" assert node_info[client_id]["IsInsertion"], ( "Unexpected duplicate removal of client.") + else: + ordered_client_ids.append(client_id) node_info[client_id] = { "ClientID": client_id, @@ -72,7 +76,13 @@ def parse_client_table(redis_client): client.RayletSocketName(), allow_none=True), "Resources": resources } - return list(node_info.values()) + # NOTE: We return the list comprehension below instead of simply doing + # 'list(node_info.values())' in order to have the nodes appear in the order + # that they joined the cluster. Python dictionaries do not preserve + # insertion order. We could use an OrderedDict, but then we'd have to be + # sure to only insert a given node a single time (clients that die appear + # twice in the GCS log). 
+ return [node_info[client_id] for client_id in ordered_client_ids] class GlobalState(object): diff --git a/python/ray/node.py b/python/ray/node.py new file mode 100644 index 000000000..85b29c30c --- /dev/null +++ b/python/ray/node.py @@ -0,0 +1,503 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import atexit +import collections +import json +import os +import logging +import signal +import threading +import time + +import ray +from ray.tempfile_services import ( + get_logs_dir_path, get_object_store_socket_name, get_raylet_socket_name, + new_log_monitor_log_file, new_monitor_log_file, + new_raylet_monitor_log_file, new_plasma_store_log_file, + new_raylet_log_file, new_webui_log_file, set_temp_root) + +ProcessInfo = collections.namedtuple( + "ProcessInfo", ["process", "use_valgrind", "use_profiler"]) + +PROCESS_TYPE_MONITOR = "monitor" +PROCESS_TYPE_RAYLET_MONITOR = "raylet_monitor" +PROCESS_TYPE_LOG_MONITOR = "log_monitor" +PROCESS_TYPE_WORKER = "worker" +PROCESS_TYPE_RAYLET = "raylet" +PROCESS_TYPE_PLASMA_STORE = "plasma_store" +PROCESS_TYPE_REDIS_SERVER = "redis_server" +PROCESS_TYPE_WEB_UI = "web_ui" + +# Logger for this module. It should be configured at the entry point +# into the program using Ray. Ray configures it by default automatically +# using logging.basicConfig in its entry/init points. +logger = logging.getLogger(__name__) + + +class Node(object): + """An encapsulation of the Ray processes on a single node. + + This class is responsible for starting Ray processes and killing them. + + Attributes: + all_processes (dict): A mapping from process type (str) to a list of + ProcessInfo objects. All lists have length one except for the Redis + server list, which has multiple. + """ + + def __init__(self, ray_params, head=False, shutdown_at_exit=True): + """Start a node. + + Args: + ray_params (ray.params.RayParams): The parameters to use to + configure the node. 
+ head (bool): True if this is the head node, which means it will + start additional processes like the Redis servers, monitor + processes, and web UI. + shutdown_at_exit (bool): If true, a handler will be registered to + shutdown the processes started here when the Python interpreter + exits. + """ + self.all_processes = {} + + ray_params.update_if_absent( + node_ip_address=ray.services.get_node_ip_address(), + include_log_monitor=True, + resources={}, + include_webui=False, + worker_path=os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "workers/default_worker.py")) + + if head: + ray_params.update_if_absent(num_redis_shards=1, include_webui=True) + + self._ray_params = ray_params + self._config = (json.loads(ray_params._internal_config) + if ray_params._internal_config else None) + self._node_ip_address = ray_params.node_ip_address + self._redis_address = ray_params.redis_address + self._plasma_store_socket_name = None + self._raylet_socket_name = None + self._webui_url = None + + self.start_ray_processes() + + if shutdown_at_exit: + atexit.register(lambda: self.kill_all_processes( + check_alive=False, allow_graceful=True)) + + @property + def node_ip_address(self): + """Get the IP address of this node.""" + return self._node_ip_address + + @property + def redis_address(self): + """Get the cluster Redis address.""" + return self._redis_address + + @property + def plasma_store_socket_name(self): + """Get the node's plasma store socket name.""" + return self._plasma_store_socket_name + + @property + def webui_url(self): + """Get the cluster's web UI url.""" + return self._webui_url + + @property + def raylet_socket_name(self): + """Get the node's raylet socket name.""" + return self._raylet_socket_name + + def start_redis(self): + """Start the Redis servers.""" + assert self._redis_address is None + (self._redis_address, redis_shards, + processes) = ray.services.start_redis( + self._node_ip_address, + port=self._ray_params.redis_port, 
redis_shard_ports=self._ray_params.redis_shard_ports, + num_redis_shards=self._ray_params.num_redis_shards, + redis_max_clients=self._ray_params.redis_max_clients, + redirect_output=self._ray_params.redirect_output, + redirect_worker_output=self._ray_params.redirect_worker_output, + password=self._ray_params.redis_password, + redis_max_memory=self._ray_params.redis_max_memory) + assert PROCESS_TYPE_REDIS_SERVER not in self.all_processes + self.all_processes[PROCESS_TYPE_REDIS_SERVER] = [] + for process in processes: + process_info = ProcessInfo( + process=process, use_valgrind=False, use_profiler=False) + self.all_processes[PROCESS_TYPE_REDIS_SERVER].append(process_info) + + def start_log_monitor(self): + """Start the log monitor.""" + stdout_file, stderr_file = new_log_monitor_log_file() + process = ray.services.start_log_monitor( + self.redis_address, + self._node_ip_address, + stdout_file=stdout_file, + stderr_file=stderr_file, + redis_password=self._ray_params.redis_password) + assert PROCESS_TYPE_LOG_MONITOR not in self.all_processes + self.all_processes[PROCESS_TYPE_LOG_MONITOR] = [ + ProcessInfo( + process=process, use_valgrind=False, use_profiler=False) + ] + + def start_ui(self): + """Start the web UI.""" + stdout_file, stderr_file = new_webui_log_file() + self._webui_url, process = ray.services.start_ui( + self._redis_address, + stdout_file=stdout_file, + stderr_file=stderr_file) + assert PROCESS_TYPE_WEB_UI not in self.all_processes + if process is not None: + self.all_processes[PROCESS_TYPE_WEB_UI] = [ + ProcessInfo( + process=process, use_valgrind=False, use_profiler=False) + ] + + def start_plasma_store(self): + """Start the plasma store.""" + assert self._plasma_store_socket_name is None + # If the user specified a socket name, use it. 
+ self._plasma_store_socket_name = ( + self._ray_params.plasma_store_socket_name + or get_object_store_socket_name()) + stdout_file, stderr_file = (new_plasma_store_log_file( + self._ray_params.redirect_output)) + process = ray.services.start_plasma_store( + self._node_ip_address, + self._redis_address, + stdout_file=stdout_file, + stderr_file=stderr_file, + object_store_memory=self._ray_params.object_store_memory, + plasma_directory=self._ray_params.plasma_directory, + huge_pages=self._ray_params.huge_pages, + plasma_store_socket_name=self._plasma_store_socket_name, + redis_password=self._ray_params.redis_password) + assert PROCESS_TYPE_PLASMA_STORE not in self.all_processes + self.all_processes[PROCESS_TYPE_PLASMA_STORE] = [ + ProcessInfo( + process=process, use_valgrind=False, use_profiler=False) + ] + + def start_raylet(self, use_valgrind=False, use_profiler=False): + """Start the raylet. + + Args: + use_valgrind (bool): True if we should start the process in + valgrind. + use_profiler (bool): True if we should start the process in the + valgrind profiler. + """ + assert self._raylet_socket_name is None + # If the user specified a socket name, use it. 
+ self._raylet_socket_name = (self._ray_params.raylet_socket_name + or get_raylet_socket_name()) + stdout_file, stderr_file = new_raylet_log_file( + redirect_output=self._ray_params.redirect_worker_output) + process = ray.services.start_raylet( + self._redis_address, + self._node_ip_address, + self._raylet_socket_name, + self._plasma_store_socket_name, + self._ray_params.worker_path, + self._ray_params.num_cpus, + self._ray_params.num_gpus, + self._ray_params.resources, + self._ray_params.object_manager_port, + self._ray_params.node_manager_port, + self._ray_params.redis_password, + use_valgrind=use_valgrind, + use_profiler=use_profiler, + stdout_file=stdout_file, + stderr_file=stderr_file, + config=self._config) + assert PROCESS_TYPE_RAYLET not in self.all_processes + self.all_processes[PROCESS_TYPE_RAYLET] = [ + ProcessInfo( + process=process, + use_valgrind=use_valgrind, + use_profiler=use_profiler) + ] + + def start_worker(self): + """Start a worker process.""" + raise NotImplementedError + + def start_monitor(self): + """Start the monitor.""" + stdout_file, stderr_file = new_monitor_log_file( + self._ray_params.redirect_output) + process = ray.services.start_monitor( + self._redis_address, + self._node_ip_address, + stdout_file=stdout_file, + stderr_file=stderr_file, + autoscaling_config=self._ray_params.autoscaling_config, + redis_password=self._ray_params.redis_password) + assert PROCESS_TYPE_MONITOR not in self.all_processes + self.all_processes[PROCESS_TYPE_MONITOR] = [ + ProcessInfo( + process=process, use_valgrind=False, use_profiler=False) + ] + + def start_raylet_monitor(self): + """Start the raylet monitor.""" + stdout_file, stderr_file = new_raylet_monitor_log_file( + self._ray_params.redirect_output) + process = ray.services.start_raylet_monitor( + self._redis_address, + stdout_file=stdout_file, + stderr_file=stderr_file, + redis_password=self._ray_params.redis_password, + config=self._config) + assert PROCESS_TYPE_RAYLET_MONITOR not in 
self.all_processes + self.all_processes[PROCESS_TYPE_RAYLET_MONITOR] = [ + ProcessInfo( + process=process, use_valgrind=False, use_profiler=False) + ] + + def start_ray_processes(self): + """Start all of the processes on the node.""" + set_temp_root(self._ray_params.temp_dir) + logger.info( + "Process STDOUT and STDERR is being redirected to {}.".format( + get_logs_dir_path())) + + # If this is the head node, start the relevant head node processes. + if self._redis_address is None: + self.start_redis() + self.start_monitor() + self.start_raylet_monitor() + + self.start_plasma_store() + self.start_raylet() + + if self._ray_params.include_log_monitor: + self.start_log_monitor() + if self._ray_params.include_webui: + self.start_ui() + + def _kill_process_type(self, + process_type, + allow_graceful=False, + check_alive=True, + wait=False): + """Kill a process of a given type. + + If the process type is PROCESS_TYPE_REDIS_SERVER, then we will kill all + of the Redis servers. + + If the process was started in valgrind, then we will raise an exception + if the process has a non-zero exit code. + + Args: + process_type: The type of the process to kill. + allow_graceful (bool): Send a SIGTERM first and give the process + time to exit gracefully. If that doesn't work, then use + SIGKILL. We usually want to do this outside of tests. + check_alive (bool): If true, then we expect the process to be alive + and will raise an exception if the process is already dead. + wait (bool): If true, then this method will not return until the + process in question has exited. + + Raises: + This process raises an exception in the following cases: + 1. The process had already died and check_alive is true. + 2. The process had been started in valgrind and had a non-zero + exit code. 
+ """ + process_infos = self.all_processes[process_type] + if process_type != PROCESS_TYPE_REDIS_SERVER: + assert len(process_infos) == 1 + for process_info in process_infos: + process = process_info.process + # Handle the case where the process has already exited. + if process.poll() is not None: + if check_alive: + raise Exception("Attempting to kill a process of type " + "'{}', but this process is already dead." + .format(process_type)) + else: + continue + + if process_info.use_valgrind: + process.terminate() + process.wait() + if process.returncode != 0: + raise Exception("Valgrind detected some errors.") + continue + + if process_info.use_profiler: + # Give process signal to write profiler data. + os.kill(process.pid, signal.SIGINT) + # Wait for profiling data to be written. + time.sleep(0.1) + + if allow_graceful: + # Allow the process one second to exit gracefully. + process.terminate() + timer = threading.Timer(1, lambda process: process.kill(), + [process]) + try: + timer.start() + process.wait() + finally: + timer.cancel() + + if process.poll() is not None: + continue + + # If the process did not exit within one second, force kill it. + process.kill() + # The reason we usually don't call process.wait() here is that + # there's some chance we'd end up waiting a really long time. + if wait: + process.wait() + + del self.all_processes[process_type] + + def kill_redis(self, check_alive=True): + """Kill the Redis servers. + + Args: + check_alive (bool): Raise an exception if any of the processes + were already dead. + """ + self._kill_process_type( + PROCESS_TYPE_REDIS_SERVER, check_alive=check_alive) + + def kill_plasma_store(self, check_alive=True): + """Kill the plasma store. + + Args: + check_alive (bool): Raise an exception if the process was already + dead. + """ + self._kill_process_type( + PROCESS_TYPE_PLASMA_STORE, check_alive=check_alive) + + def kill_raylet(self, check_alive=True): + """Kill the raylet. 
+ + Args: + check_alive (bool): Raise an exception if the process was already + dead. + """ + self._kill_process_type(PROCESS_TYPE_RAYLET, check_alive=check_alive) + + def kill_log_monitor(self, check_alive=True): + """Kill the log monitor. + + Args: + check_alive (bool): Raise an exception if the process was already + dead. + """ + self._kill_process_type( + PROCESS_TYPE_LOG_MONITOR, check_alive=check_alive) + + def kill_monitor(self, check_alive=True): + """Kill the monitor. + + Args: + check_alive (bool): Raise an exception if the process was already + dead. + """ + self._kill_process_type(PROCESS_TYPE_MONITOR, check_alive=check_alive) + + def kill_raylet_monitor(self, check_alive=True): + """Kill the raylet monitor. + + Args: + check_alive (bool): Raise an exception if the process was already + dead. + """ + self._kill_process_type( + PROCESS_TYPE_RAYLET_MONITOR, check_alive=check_alive) + + def kill_all_processes(self, check_alive=True, allow_graceful=False): + """Kill all of the processes. + + Note that this is slower than necessary because it calls kill, wait, + kill, wait, ... instead of kill, kill, ..., wait, wait, ... + + Args: + check_alive (bool): Raise an exception if any of the processes were + already dead. + """ + # Kill the raylet first. This is important for suppressing errors at + # shutdown because we give the raylet a chance to exit gracefully and + # clean up its child worker processes. If we were to kill the plasma + # store (or Redis) first, that could cause the raylet to exit + # ungracefully, leading to more verbose output from the workers. + if PROCESS_TYPE_RAYLET in self.all_processes: + self._kill_process_type( + PROCESS_TYPE_RAYLET, + check_alive=check_alive, + allow_graceful=allow_graceful) + + # We call "list" to copy the keys because we are modifying the + # dictionary while iterating over it. 
+ for process_type in list(self.all_processes.keys()): + self._kill_process_type( + process_type, + check_alive=check_alive, + allow_graceful=allow_graceful) + + def live_processes(self): + """Return a list of the live processes. + + Returns: + A list of the live processes. + """ + result = [] + for process_type, process_infos in self.all_processes.items(): + for process_info in process_infos: + if process_info.process.poll() is None: + result.append((process_type, process_info.process)) + return result + + def dead_processes(self): + """Return a list of the dead processes. + + Note that this ignores processes that have been explicitly killed, + e.g., via a command like node.kill_raylet(). + + Returns: + A list of the dead processes ignoring the ones that have been + explicitly killed. + """ + result = [] + for process_type, process_infos in self.all_processes.items(): + for process_info in process_infos: + if process_info.process.poll() is not None: + result.append((process_type, process_info.process)) + return result + + def any_processes_alive(self): + """Return true if any processes are still alive. + + Returns: + True if any process is still alive. + """ + return any(self.live_processes()) + + def remaining_processes_alive(self): + """Return true if all remaining processes are still alive. + + Note that this ignores processes that have been explicitly killed, + e.g., via a command like node.kill_raylet(). + + Returns: + True if every process not explicitly killed is still alive. + """ + return not any(self.dead_processes()) diff --git a/python/ray/parameter.py b/python/ray/parameter.py index d8b9b77f8..11086604f 100644 --- a/python/ray/parameter.py +++ b/python/ray/parameter.py @@ -11,25 +11,13 @@ class RayParams(object): """A class used to store the parameters used by Ray. Attributes: - address_info (dict): A dictionary with address information for - processes in a partially-started Ray cluster. 
If - start_ray_local=True, any processes not in this dictionary will be - started. If provided, an updated address_info dictionary will be - returned to include processes that are newly started. - start_ray_local (bool): If True then this will start any processes not - already in address_info, including Redis, a global scheduler, local - scheduler(s), object store(s), and worker(s). It will also kill - these processes when Python exits. If False, this will attach to an - existing Ray cluster. redis_address (str): The address of the Redis server to connect to. If this address is not provided, then this command will start Redis, a global scheduler, a local scheduler, a plasma store, a plasma manager, and some workers. It will also kill these processes when Python exits. redis_port (int): The port that the primary Redis shard should listen - to. If None, then a random port will be chosen. If the key - "redis_address" is in address_info, then this argument will be - ignored. + to. If None, then a random port will be chosen. redis_shard_ports: A list of the ports to use for the non-primary Redis shards. num_cpus (int): Number of CPUs to configure the raylet with. @@ -84,8 +72,6 @@ class RayParams(object): """ def __init__(self, - address_info=None, - start_ray_local=False, redis_address=None, num_cpus=None, num_gpus=None, @@ -118,8 +104,6 @@ class RayParams(object): include_log_monitor=None, autoscaling_config=None, _internal_config=None): - self.address_info = address_info - self.start_ray_local = start_ray_local self.object_id_seed = object_id_seed self.redis_address = redis_address self.num_cpus = num_cpus @@ -191,3 +175,8 @@ class RayParams(object): assert "GPU" not in self.resources, ( "'GPU' should not be included in the resource dictionary. Use " "num_gpus instead.") + + if self.num_workers is not None: + raise Exception( + "The 'num_workers' argument is deprecated. 
Please use " + "'num_cpus' instead.") diff --git a/python/ray/plasma/plasma.py b/python/ray/plasma/plasma.py index 53b243426..27af92cd0 100644 --- a/python/ray/plasma/plasma.py +++ b/python/ray/plasma/plasma.py @@ -7,8 +7,6 @@ import subprocess import sys import time -from ray.tempfile_services import get_object_store_socket_name - __all__ = ["start_plasma_store", "DEFAULT_PLASMA_STORE_MEMORY"] PLASMA_WAIT_TIMEOUT = 2**30 @@ -64,7 +62,7 @@ def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY, plasma_store_executable = os.path.join( os.path.abspath(os.path.dirname(__file__)), "../core/src/plasma/plasma_store_server") - plasma_store_name = socket_name or get_object_store_socket_name() + plasma_store_name = socket_name command = [ plasma_store_executable, "-s", plasma_store_name, "-m", str(plasma_store_memory) diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index 207659fd1..4bf1249c0 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -15,8 +15,6 @@ from ray.autoscaler.commands import ( import ray.ray_constants as ray_constants import ray.utils -from ray.parameter import RayParams - logger = logging.getLogger(__name__) @@ -231,7 +229,7 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, " --resources='{\"CustomResource1\": 3, " "\"CustomReseource2\": 2}'") - ray_params = RayParams( + ray_params = ray.parameter.RayParams( node_ip_address=node_ip_address, object_manager_port=object_manager_port, node_manager_port=node_manager_port, @@ -285,8 +283,9 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, include_webui=(not no_ui), autoscaling_config=autoscaling_config) - address_info = services.start_ray_head(ray_params, cleanup=False) - logger.info(address_info) + node = ray.node.Node(ray_params, head=True, shutdown_at_exit=False) + redis_address = node.redis_address + logger.info( "\nStarted Ray on this node. 
You can add additional nodes to " "the cluster by calling\n\n" @@ -299,9 +298,9 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, "that your firewall is configured properly. If you wish to " "terminate the processes that have been started, run\n\n" " ray stop".format( - address_info["redis_address"], " --redis-password " + redis_address, " --redis-password " if redis_password else "", redis_password if redis_password - else "", address_info["redis_address"], "\", redis_password=\"" + else "", redis_address, "\", redis_password=\"" if redis_password else "", redis_password if redis_password else "")) else: @@ -349,9 +348,9 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, # if the Redis server already has clients on this node. check_no_existing_redis_clients(ray_params.node_ip_address, redis_client) - ray_params.redis_address = redis_address - address_info = services.start_ray_node(ray_params, cleanup=False) - logger.info(address_info) + ray_params.update(redis_address=redis_address) + + node = ray.node.Node(ray_params, head=False, shutdown_at_exit=False) logger.info("\nStarted Ray on this node. 
If you wish to terminate the " "processes that have been started, run\n\n" " ray stop") diff --git a/python/ray/services.py b/python/ray/services.py index 3740a7368..476f2bc6b 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -8,13 +8,10 @@ import multiprocessing import os import random import resource -import signal import socket import subprocess import sys -import threading import time -from collections import OrderedDict import redis import pyarrow @@ -22,30 +19,8 @@ import pyarrow import ray.ray_constants as ray_constants import ray.plasma -from ray.tempfile_services import ( - get_ipython_notebook_path, get_logs_dir_path, get_raylet_socket_name, - get_temp_root, new_log_monitor_log_file, new_monitor_log_file, - new_plasma_store_log_file, new_raylet_log_file, new_redis_log_file, - new_webui_log_file, set_temp_root) - -PROCESS_TYPE_MONITOR = "monitor" -PROCESS_TYPE_LOG_MONITOR = "log_monitor" -PROCESS_TYPE_WORKER = "worker" -PROCESS_TYPE_RAYLET = "raylet" -PROCESS_TYPE_PLASMA_STORE = "plasma_store" -PROCESS_TYPE_REDIS_SERVER = "redis_server" -PROCESS_TYPE_WEB_UI = "web_ui" - -# This is a dictionary tracking all of the processes of different types that -# have been started by this services module. Note that the order of the keys is -# important because it determines the order in which these processes will be -# terminated when Ray exits, and certain orders will cause errors to be logged -# to the screen. -all_processes = OrderedDict( - [(PROCESS_TYPE_MONITOR, []), (PROCESS_TYPE_LOG_MONITOR, []), - (PROCESS_TYPE_WORKER, []), (PROCESS_TYPE_RAYLET, []), - (PROCESS_TYPE_PLASMA_STORE, []), (PROCESS_TYPE_REDIS_SERVER, []), - (PROCESS_TYPE_WEB_UI, [])], ) +from ray.tempfile_services import (get_ipython_notebook_path, get_temp_root, + new_redis_log_file) # True if processes are run in the valgrind profiler. RUN_RAYLET_PROFILER = False @@ -106,85 +81,24 @@ def new_port(): return random.randint(10000, 65535) -def kill_process(p): - """Kill a process. 
+def remaining_processes_alive(exclude=None): + """See if the remaining processes are alive or not. - Args: - p: The process to kill. + Note that this ignores processes that have been explicitly killed, + e.g., via a command like node.kill_raylet(). Returns: - True if the process was killed successfully and false otherwise. + True if the remaining processes started by ray.init() are alive and + False otherwise. + + Raises: + Exception: An exception is raised if the processes were not started by + ray.init(). """ - if p.poll() is not None: - # The process has already terminated. - return True - if any([RUN_RAYLET_PROFILER, RUN_PLASMA_STORE_PROFILER]): - # Give process signal to write profiler data. - os.kill(p.pid, signal.SIGINT) - # Wait for profiling data to be written. - time.sleep(0.1) - - # Allow the process one second to exit gracefully. - p.terminate() - timer = threading.Timer(1, lambda p: p.kill(), [p]) - try: - timer.start() - p.wait() - finally: - timer.cancel() - - if p.poll() is not None: - return True - - # If the process did not exit within one second, force kill it. - p.kill() - if p.poll() is not None: - return True - - # The process was not killed for some reason. - return False - - -def cleanup(): - """When running in local mode, shutdown the Ray processes. - - This method is used to shutdown processes that were started with - services.start_ray_head(). It kills all scheduler, object store, and worker - processes that were started by this services module. Driver processes are - started and disconnected by worker.py. - """ - successfully_shut_down = True - # Terminate the processes in reverse order. - for process_type in all_processes.keys(): - # Kill all of the processes of a certain type. - for p in all_processes[process_type]: - success = kill_process(p) - successfully_shut_down = successfully_shut_down and success - # Reset the list of processes of this type. 
- all_processes[process_type] = [] - if not successfully_shut_down: - logger.warning("Ray did not shut down properly.") - - -def all_processes_alive(exclude=None): - """Check if all of the processes are still alive. - - Args: - exclude: Don't check the processes whose types are in this list. - """ - - if exclude is None: - exclude = [] - for process_type, processes in all_processes.items(): - # Note that p.poll() returns the exit code that the process exited - # with, so an exit code of None indicates that the process is still - # alive. - processes_alive = [p.poll() is None for p in processes] - if not all(processes_alive) and process_type not in exclude: - logger.warning( - "A process of type {} has died.".format(process_type)) - return False - return True + if ray.worker._global_node is None: + raise Exception("This process is not in a position to determine " + "whether all processes are alive or not.") + return ray.worker._global_node.remaining_processes_alive() def address_to_ip(address): @@ -411,7 +325,6 @@ def start_redis(node_ip_address, redis_max_clients=None, redirect_output=False, redirect_worker_output=False, - cleanup=True, password=None, use_credis=None, redis_max_memory=None): @@ -434,10 +347,6 @@ def start_redis(node_ip_address, redirect_worker_output (bool): True if worker output should be redirected to a file and false otherwise. Workers will have access to this value when they start up. - cleanup (bool): True if using Ray in local mode. If cleanup is true, - then all Redis processes started by this method will be killed by - services.cleanup() when the Python process that imported services - exits. password (str): Prevents external clients without the password from connecting to Redis if provided. use_credis: If True, additionally load the chain-replicated libraries @@ -450,8 +359,9 @@ def start_redis(node_ip_address, capped at 10GB but can be set higher. 
Returns: - A tuple of the address for the primary Redis shard and a list of - addresses for the remaining shards. + A tuple of the address for the primary Redis shard, a list of + addresses for the remaining shards, and the processes that were + started. """ redis_stdout_file, redis_stderr_file = new_redis_log_file(redirect_output) @@ -461,6 +371,8 @@ def start_redis(node_ip_address, raise Exception("The number of Redis shard ports does not match the " "number of Redis shards.") + processes = [] + if use_credis is None: use_credis = ("RAY_USE_NEW_GCS" in os.environ) if use_credis and password is not None: @@ -471,25 +383,24 @@ def start_redis(node_ip_address, "password-protected Redis ports, ensure that " "the environment variable `RAY_USE_NEW_GCS=off`.") if not use_credis: - assigned_port, _ = _start_redis_instance( + assigned_port, p = _start_redis_instance( node_ip_address=node_ip_address, port=port, redis_max_clients=redis_max_clients, stdout_file=redis_stdout_file, stderr_file=redis_stderr_file, - cleanup=cleanup, password=password, # Below we use None to indicate no limit on the memory of the # primary Redis shard. redis_max_memory=None) + processes.append(p) else: - assigned_port, _ = _start_redis_instance( + assigned_port, p = _start_redis_instance( node_ip_address=node_ip_address, port=port, redis_max_clients=redis_max_clients, stdout_file=redis_stdout_file, stderr_file=redis_stderr_file, - cleanup=cleanup, executable=CREDIS_EXECUTABLE, # It is important to load the credis module BEFORE the ray module, # as the latter contains an extern declaration that the former @@ -499,6 +410,7 @@ def start_redis(node_ip_address, # Below we use None to indicate no limit on the memory of the # primary Redis shard. 
redis_max_memory=None) + processes.append(p) if port is not None: assert assigned_port == port port = assigned_port @@ -534,26 +446,25 @@ def start_redis(node_ip_address, redis_stdout_file, redis_stderr_file = new_redis_log_file( redirect_output, shard_number=i) if not use_credis: - redis_shard_port, _ = _start_redis_instance( + redis_shard_port, p = _start_redis_instance( node_ip_address=node_ip_address, port=redis_shard_ports[i], redis_max_clients=redis_max_clients, stdout_file=redis_stdout_file, stderr_file=redis_stderr_file, - cleanup=cleanup, password=password, redis_max_memory=redis_max_memory) + processes.append(p) else: assert num_redis_shards == 1, \ "For now, RAY_USE_NEW_GCS supports 1 shard, and credis "\ "supports 1-node chain for that shard only." - redis_shard_port, _ = _start_redis_instance( + redis_shard_port, p = _start_redis_instance( node_ip_address=node_ip_address, port=redis_shard_ports[i], redis_max_clients=redis_max_clients, stdout_file=redis_stdout_file, stderr_file=redis_stderr_file, - cleanup=cleanup, password=password, executable=CREDIS_EXECUTABLE, # It is important to load the credis module BEFORE the ray @@ -561,6 +472,7 @@ def start_redis(node_ip_address, # former supplies. 
modules=[CREDIS_MEMBER_MODULE, REDIS_MODULE], redis_max_memory=redis_max_memory) + processes.append(p) if redis_shard_ports[i] is not None: assert redis_shard_port == redis_shard_ports[i] @@ -578,7 +490,7 @@ def start_redis(node_ip_address, shard_client.execute_command("MEMBER.CONNECT_TO_MASTER", node_ip_address, port) - return redis_address, redis_shards + return redis_address, redis_shards, processes def _start_redis_instance(node_ip_address="127.0.0.1", @@ -587,7 +499,6 @@ def _start_redis_instance(node_ip_address="127.0.0.1", num_retries=20, stdout_file=None, stderr_file=None, - cleanup=True, password=None, executable=REDIS_EXECUTABLE, modules=None, @@ -606,9 +517,6 @@ def _start_redis_instance(node_ip_address="127.0.0.1", no redirection should happen, then this should be None. stderr_file: A file handle opened for writing to redirect stderr to. If no redirection should happen, then this should be None. - cleanup (bool): True if using Ray in local mode. If cleanup is true, - then this process will be killed by serices.cleanup() when the - Python process that imported services exits. password (str): Prevents external clients without the password from connecting to Redis if provided. executable (str): Full path tho the redis-server executable. @@ -659,8 +567,6 @@ def _start_redis_instance(node_ip_address="127.0.0.1", # Check if Redis successfully started (or at least if it the executable # did not exit within 0.1 seconds). if p.poll() is None: - if cleanup: - all_processes[PROCESS_TYPE_REDIS_SERVER].append(p) break port = new_port() counter += 1 @@ -734,7 +640,6 @@ def start_log_monitor(redis_address, node_ip_address, stdout_file=None, stderr_file=None, - cleanup=cleanup, redis_password=None): """Start a log monitor process. @@ -746,10 +651,10 @@ def start_log_monitor(redis_address, no redirection should happen, then this should be None. stderr_file: A file handle opened for writing to redirect stderr to. If no redirection should happen, then this should be None. 
- cleanup (bool): True if using Ray in local mode. If cleanup is true, - then this process will be killed by services.cleanup() when the - Python process that imported services exits. redis_password (str): The password of the redis server. + + Returns: + The process that was started. """ log_monitor_filepath = os.path.join( os.path.dirname(os.path.abspath(__file__)), "log_monitor.py") @@ -760,15 +665,14 @@ def start_log_monitor(redis_address, if redis_password: command += ["--redis-password", redis_password] p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) - if cleanup: - all_processes[PROCESS_TYPE_LOG_MONITOR].append(p) record_log_files_in_redis( redis_address, node_ip_address, [stdout_file, stderr_file], password=redis_password) + return p -def start_ui(redis_address, stdout_file=None, stderr_file=None, cleanup=True): +def start_ui(redis_address, stdout_file=None, stderr_file=None): """Start a UI process. Args: @@ -777,9 +681,9 @@ def start_ui(redis_address, stdout_file=None, stderr_file=None, cleanup=True): no redirection should happen, then this should be None. stderr_file: A file handle opened for writing to redirect stderr to. If no redirection should happen, then this should be None. - cleanup (bool): True if using Ray in local mode. If cleanup is true, - then this process will be killed by services.cleanup() when the - Python process that imported services exits. + + Returns: + A tuple of the web UI url and the process that was started. 
""" port = 8888 @@ -820,12 +724,11 @@ def start_ui(redis_address, stdout_file=None, stderr_file=None, cleanup=True): logger.warning("Failed to start the UI, you may need to run " "'pip install jupyter'.") else: - if cleanup: - all_processes[PROCESS_TYPE_WEB_UI].append(ui_process) logger.info("\n" + "=" * 70) logger.info("View the web UI at {}".format(webui_url)) logger.info("=" * 70 + "\n") - return webui_url + return webui_url, ui_process + return None, None def check_and_update_resources(num_cpus, num_gpus, resources): @@ -887,28 +790,40 @@ def check_and_update_resources(num_cpus, num_gpus, resources): return resources -def start_raylet(ray_params, +def start_raylet(redis_address, + node_ip_address, raylet_name, plasma_store_name, - num_initial_workers=0, + worker_path, + num_cpus=None, + num_gpus=None, + resources=None, + object_manager_port=None, + node_manager_port=None, + redis_password=None, use_valgrind=False, use_profiler=False, stdout_file=None, stderr_file=None, - cleanup=True, config=None): """Start a raylet, which is a combined local scheduler and object manager. Args: - ray_params (ray.params.RayParams): The RayParams instance. The - following parameters could be checked: redis_address, - node_ip_address, worker_path, resources, num_cpus, num_gpus, - object_manager_port, node_manager_port, redis_password. - resources, object_manager_port, node_manager_port. + redis_address (str): The address of the primary Redis server. + node_ip_address (str): The IP address of this node. raylet_name (str): The name of the raylet socket to create. plasma_store_name (str): The name of the plasma store socket to connect to. - num_initial_workers (int): The number of workers to start initially. + worker_path (str): The path of the Python file that new worker + processes will execute. + num_cpus: The CPUs allocated for this raylet. + num_gpus: The GPUs allocated for this raylet. + resources: The custom resources allocated for this raylet. 
+ object_manager_port: The port to use for the object manager. If this is + None, then the object manager will choose its own port. + node_manager_port: The port to use for the node manager. If this is + None, then the node manager will choose its own port. + redis_password: The password to use when connecting to Redis. use_valgrind (bool): True if the raylet should be started inside of valgrind. If this is True, use_profiler must be False. use_profiler (bool): True if the raylet should be started inside @@ -917,14 +832,11 @@ def start_raylet(ray_params, no redirection should happen, then this should be None. stderr_file: A file handle opened for writing to redirect stderr to. If no redirection should happen, then this should be None. - cleanup (bool): True if using Ray in local mode. If cleanup is true, - then this process will be killed by serices.cleanup() when the - Python process that imported services exits. config (dict|None): Optional Raylet configuration that will override defaults in RayConfig. Returns: - The raylet socket name. + The process that was started. """ config = config or {} config_str = ",".join(["{},{}".format(*kv) for kv in config.items()]) @@ -932,8 +844,11 @@ def start_raylet(ray_params, if use_valgrind and use_profiler: raise Exception("Cannot use valgrind and profiler at the same time.") - static_resources = check_and_update_resources( - ray_params.num_cpus, ray_params.num_gpus, ray_params.resources) + num_initial_workers = (num_cpus if num_cpus is not None else + multiprocessing.cpu_count()) + + static_resources = check_and_update_resources(num_cpus, num_gpus, + resources) # Limit the number of workers that can be started in parallel by the # raylet. However, make sure it is at least 1. 
@@ -944,7 +859,7 @@ def start_raylet(ray_params, resource_argument = ",".join( ["{},{}".format(*kv) for kv in static_resources.items()]) - gcs_ip_address, gcs_port = ray_params.redis_address.split(":") + gcs_ip_address, gcs_port = redis_address.split(":") # Create the command that the Raylet will use to start workers. start_worker_command = ("{} {} " @@ -953,30 +868,28 @@ def start_raylet(ray_params, "--raylet-name={} " "--redis-address={} " "--temp-dir={}".format( - sys.executable, ray_params.worker_path, - ray_params.node_ip_address, plasma_store_name, - raylet_name, ray_params.redis_address, + sys.executable, worker_path, node_ip_address, + plasma_store_name, raylet_name, redis_address, get_temp_root())) - if ray_params.redis_password: - start_worker_command += " --redis-password {}".format( - ray_params.redis_password) + if redis_password: + start_worker_command += " --redis-password {}".format(redis_password) # If the object manager port is None, then use 0 to cause the object # manager to choose its own port. - if ray_params.object_manager_port is None: - ray_params.object_manager_port = 0 + if object_manager_port is None: + object_manager_port = 0 # If the node manager port is None, then use 0 to cause the node manager # to choose its own port. - if ray_params.node_manager_port is None: - ray_params.node_manager_port = 0 + if node_manager_port is None: + node_manager_port = 0 command = [ RAYLET_EXECUTABLE, raylet_name, plasma_store_name, - str(ray_params.object_manager_port), - str(ray_params.node_manager_port), - ray_params.node_ip_address, + str(object_manager_port), + str(node_manager_port), + node_ip_address, gcs_ip_address, gcs_port, str(num_initial_workers), @@ -985,12 +898,12 @@ def start_raylet(ray_params, config_str, start_worker_command, "", # Worker command for Java, not needed for Python. 
- ray_params.redis_password or "", + redis_password or "", get_temp_root(), ] if use_valgrind: - pid = subprocess.Popen( + p = subprocess.Popen( [ "valgrind", "--track-origins=yes", "--leak-check=full", "--show-leak-kinds=all", "--leak-check-heuristics=stdstring", @@ -999,7 +912,7 @@ def start_raylet(ray_params, stdout=stdout_file, stderr=stderr_file) elif use_profiler: - pid = subprocess.Popen( + p = subprocess.Popen( ["valgrind", "--tool=callgrind"] + command, stdout=stdout_file, stderr=stderr_file) @@ -1007,19 +920,17 @@ def start_raylet(ray_params, modified_env = os.environ.copy() modified_env["LD_PRELOAD"] = os.environ["RAYLET_PERFTOOLS_PATH"] modified_env["CPUPROFILE"] = os.environ["RAYLET_PERFTOOLS_LOGFILE"] - pid = subprocess.Popen( + p = subprocess.Popen( command, stdout=stdout_file, stderr=stderr_file, env=modified_env) else: - pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) + p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) - if cleanup: - all_processes[PROCESS_TYPE_RAYLET].append(pid) record_log_files_in_redis( - ray_params.redis_address, - ray_params.node_ip_address, [stdout_file, stderr_file], - password=ray_params.redis_password) + redis_address, + node_ip_address, [stdout_file, stderr_file], + password=redis_password) - return raylet_name + return p def determine_plasma_store_config(object_store_memory=None, @@ -1104,11 +1015,9 @@ def determine_plasma_store_config(object_store_memory=None, def start_plasma_store(node_ip_address, redis_address, - object_manager_port=None, - store_stdout_file=None, - store_stderr_file=None, + stdout_file=None, + stderr_file=None, object_store_memory=None, - cleanup=True, plasma_directory=None, huge_pages=False, plasma_store_socket_name=None, @@ -1119,25 +1028,20 @@ def start_plasma_store(node_ip_address, node_ip_address (str): The IP address of the node running the object store. redis_address (str): The address of the Redis instance to connect to. 
- object_manager_port (int): The port to use for the object manager. If - this is not provided, one will be generated randomly. - store_stdout_file: A file handle opened for writing to redirect stdout + stdout_file: A file handle opened for writing to redirect stdout to. If no redirection should happen, then this should be None. - store_stderr_file: A file handle opened for writing to redirect stderr + stderr_file: A file handle opened for writing to redirect stderr to. If no redirection should happen, then this should be None. object_store_memory: The amount of memory (in bytes) to start the object store with. - cleanup (bool): True if using Ray in local mode. If cleanup is true, - then this process will be killed by serices.cleanup() when the - Python process that imported services exits. plasma_directory: A directory where the Plasma memory mapped files will be created. huge_pages: Boolean flag indicating whether to start the Object Store with hugetlbfs support. Requires plasma_directory. redis_password (str): The password of the redis server. - Return: - The Plasma store socket name. + Returns: + The process that was started. """ object_store_memory, plasma_directory = determine_plasma_store_config( object_store_memory, plasma_directory, huge_pages) @@ -1153,23 +1057,21 @@ def start_plasma_store(node_ip_address, logger.info("Starting the Plasma object store with {} GB memory " "using {}.".format(object_store_memory_str, plasma_directory)) # Start the Plasma store. 
- plasma_store_name, p1 = ray.plasma.start_plasma_store( + plasma_store_name, p = ray.plasma.start_plasma_store( plasma_store_memory=object_store_memory, use_profiler=RUN_PLASMA_STORE_PROFILER, - stdout_file=store_stdout_file, - stderr_file=store_stderr_file, + stdout_file=stdout_file, + stderr_file=stderr_file, plasma_directory=plasma_directory, huge_pages=huge_pages, socket_name=plasma_store_socket_name) - if cleanup: - all_processes[PROCESS_TYPE_PLASMA_STORE].append(p1) record_log_files_in_redis( redis_address, - node_ip_address, [store_stdout_file, store_stderr_file], + node_ip_address, [stdout_file, stderr_file], password=redis_password) - return plasma_store_name + return p def start_worker(node_ip_address, @@ -1178,8 +1080,7 @@ def start_worker(node_ip_address, redis_address, worker_path, stdout_file=None, - stderr_file=None, - cleanup=True): + stderr_file=None): """This method starts a worker process. Args: @@ -1194,10 +1095,9 @@ def start_worker(node_ip_address, no redirection should happen, then this should be None. stderr_file: A file handle opened for writing to redirect stderr to. If no redirection should happen, then this should be None. - cleanup (bool): True if using Ray in local mode. If cleanup is true, - then this process will be killed by services.cleanup() when the - Python process that imported services exits. This is True by - default. + + Returns: + The process that was started. """ command = [ sys.executable, "-u", worker_path, @@ -1207,17 +1107,15 @@ def start_worker(node_ip_address, "--temp-dir=" + get_temp_root() ] p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) - if cleanup: - all_processes[PROCESS_TYPE_WORKER].append(p) record_log_files_in_redis(redis_address, node_ip_address, [stdout_file, stderr_file]) + return p def start_monitor(redis_address, node_ip_address, stdout_file=None, stderr_file=None, - cleanup=True, autoscaling_config=None, redis_password=None): """Run a process to monitor the other processes. 
@@ -1230,12 +1128,11 @@ def start_monitor(redis_address, no redirection should happen, then this should be None. stderr_file: A file handle opened for writing to redirect stderr to. If no redirection should happen, then this should be None. - cleanup (bool): True if using Ray in local mode. If cleanup is true, - then this process will be killed by services.cleanup() when the - Python process that imported services exits. This is True by - default. autoscaling_config: path to autoscaling config file. redis_password (str): The password of the redis server. + + Returns: + The process that was started. """ monitor_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), "monitor.py") @@ -1248,18 +1145,16 @@ def start_monitor(redis_address, if redis_password: command.append("--redis-password=" + redis_password) p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) - if cleanup: - all_processes[PROCESS_TYPE_MONITOR].append(p) record_log_files_in_redis( redis_address, node_ip_address, [stdout_file, stderr_file], password=redis_password) + return p def start_raylet_monitor(redis_address, stdout_file=None, stderr_file=None, - cleanup=True, redis_password=None, config=None): """Run a process to monitor the other processes. @@ -1270,13 +1165,12 @@ def start_raylet_monitor(redis_address, no redirection should happen, then this should be None. stderr_file: A file handle opened for writing to redirect stderr to. If no redirection should happen, then this should be None. - cleanup (bool): True if using Ray in local mode. If cleanup is true, - then this process will be killed by services.cleanup() when the - Python process that imported services exits. This is True by - default. redis_password (str): The password of the redis server. config (dict|None): Optional configuration that will override defaults in RayConfig. + + Returns: + The process that was started. 
""" gcs_ip_address, gcs_port = redis_address.split(":") redis_password = redis_password or "" @@ -1286,222 +1180,4 @@ def start_raylet_monitor(redis_address, if redis_password: command += [redis_password] p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) - if cleanup: - all_processes[PROCESS_TYPE_MONITOR].append(p) - - -def start_ray_processes(ray_params, cleanup=True): - """Helper method to start Ray processes. - - Args: - ray_params (ray.params.RayParams): The RayParams instance. The - following parameters will be set to default values if it's None: - node_ip_address("127.0.0.1"), include_webui(False), - worker_path(path of default_worker.py), include_log_monitor(False) - cleanup (bool): If cleanup is true, then the processes started here - will be killed by services.cleanup() when the Python process that - called this method exits. - - Returns: - A dictionary of the address information for the processes that were - started. - """ - - set_temp_root(ray_params.temp_dir) - - logger.info("Process STDOUT and STDERR is being redirected to {}.".format( - get_logs_dir_path())) - - config = json.loads( - ray_params._internal_config) if ray_params._internal_config else None - - ray_params.update_if_absent( - include_log_monitor=False, - resources={}, - include_webui=False, - node_ip_address="127.0.0.1") - - if ray_params.num_workers is not None: - raise Exception("The 'num_workers' argument is deprecated. Please use " - "'num_cpus' instead.") - else: - num_initial_workers = (ray_params.num_cpus - if ray_params.num_cpus is not None else - multiprocessing.cpu_count()) - - ray_params.update_if_absent( - address_info={}, - worker_path=os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "workers/default_worker.py")) - ray_params.address_info["node_ip_address"] = ray_params.node_ip_address - - # Start Redis if there isn't already an instance running. 
TODO(rkn): We are - # suppressing the output of Redis because on Linux it prints a bunch of - # warning messages when it starts up. Instead of suppressing the output, we - # should address the warnings. - ray_params.redis_address = ray_params.address_info.get("redis_address") - ray_params.redis_shards = ray_params.address_info.get("redis_shards", []) - if ray_params.redis_address is None: - ray_params.redis_address, ray_params.redis_shards = start_redis( - ray_params.node_ip_address, - port=ray_params.redis_port, - redis_shard_ports=ray_params.redis_shard_ports, - num_redis_shards=ray_params.num_redis_shards, - redis_max_clients=ray_params.redis_max_clients, - redirect_output=True, - redirect_worker_output=ray_params.redirect_worker_output, - cleanup=cleanup, - password=ray_params.redis_password, - redis_max_memory=ray_params.redis_max_memory) - ray_params.address_info["redis_address"] = ray_params.redis_address - time.sleep(0.1) - - # Start monitoring the processes. - monitor_stdout_file, monitor_stderr_file = new_monitor_log_file( - ray_params.redirect_output) - start_monitor( - ray_params.redis_address, - ray_params.node_ip_address, - stdout_file=monitor_stdout_file, - stderr_file=monitor_stderr_file, - cleanup=cleanup, - autoscaling_config=ray_params.autoscaling_config, - redis_password=ray_params.redis_password) - start_raylet_monitor( - ray_params.redis_address, - stdout_file=monitor_stdout_file, - stderr_file=monitor_stderr_file, - cleanup=cleanup, - redis_password=ray_params.redis_password, - config=config) - if ray_params.redis_shards == []: - # Get redis shards from primary redis instance. 
- redis_ip_address, redis_port = ray_params.redis_address.split(":") - redis_client = redis.StrictRedis( - host=redis_ip_address, - port=redis_port, - password=ray_params.redis_password) - redis_shards = redis_client.lrange("RedisShards", start=0, end=-1) - ray_params.redis_shards = [ - ray.utils.decode(shard) for shard in redis_shards - ] - ray_params.address_info["redis_shards"] = ray_params.redis_shards - - # Start the log monitor, if necessary. - if ray_params.include_log_monitor: - log_monitor_stdout_file, log_monitor_stderr_file = ( - new_log_monitor_log_file()) - start_log_monitor( - ray_params.redis_address, - ray_params.node_ip_address, - stdout_file=log_monitor_stdout_file, - stderr_file=log_monitor_stderr_file, - cleanup=cleanup, - redis_password=ray_params.redis_password) - - # Initialize with existing services. - object_store_address = ray_params.address_info.get("object_store_address") - raylet_socket_name = ray_params.address_info.get("raylet_socket_name") - - # Start the object store. - assert object_store_address is None - # Start Plasma. - plasma_store_stdout_file, plasma_store_stderr_file = ( - new_plasma_store_log_file(ray_params.redirect_output)) - - ray_params.address_info["object_store_address"] = start_plasma_store( - ray_params.node_ip_address, - ray_params.redis_address, - store_stdout_file=plasma_store_stdout_file, - store_stderr_file=plasma_store_stderr_file, - object_store_memory=ray_params.object_store_memory, - cleanup=cleanup, - plasma_directory=ray_params.plasma_directory, - huge_pages=ray_params.huge_pages, - plasma_store_socket_name=ray_params.plasma_store_socket_name, - redis_password=ray_params.redis_password) - time.sleep(0.1) - - # Start the raylet. 
- assert raylet_socket_name is None - raylet_stdout_file, raylet_stderr_file = new_raylet_log_file( - redirect_output=ray_params.redirect_worker_output) - ray_params.address_info["raylet_socket_name"] = start_raylet( - ray_params, - ray_params.raylet_socket_name or get_raylet_socket_name(), - ray_params.address_info["object_store_address"], - num_initial_workers=num_initial_workers, - stdout_file=raylet_stdout_file, - stderr_file=raylet_stderr_file, - cleanup=cleanup, - config=config) - - # Try to start the web UI. - if ray_params.include_webui: - ui_stdout_file, ui_stderr_file = new_webui_log_file() - ray_params.address_info["webui_url"] = start_ui( - ray_params.redis_address, - stdout_file=ui_stdout_file, - stderr_file=ui_stderr_file, - cleanup=cleanup) - else: - ray_params.address_info["webui_url"] = "" - # Return the addresses of the relevant processes. - return ray_params.address_info - - -def start_ray_node(ray_params, cleanup=True): - """Start the Ray processes for a single node. - - This assumes that the Ray processes on some master node have already been - started. - - Args: - ray_params (ray.params.RayParams): The RayParams instance. The - following parameters could be checked: node_ip_address, - redis_address, object_manager_port, node_manager_port, - num_workers, object_store_memory, redis_password, worker_path, - cleanup, redirect_worker_output, redirect_output, resources, - plasma_directory, huge_pages, plasma_store_socket_name, - raylet_socket_name, temp_dir, _internal_config. - cleanup (bool): If cleanup is true, then the processes started here - will be killed by services.cleanup() when the Python process that - called this method exits. - - Returns: - A dictionary of the address information for the processes that were - started. 
- """ - ray_params.address_info = { - "redis_address": ray_params.redis_address, - } - ray_params.update(include_log_monitor=True) - return start_ray_processes(ray_params, cleanup=cleanup) - - -def start_ray_head(ray_params, cleanup=True): - """Start Ray in local mode. - - Args: - ray_params (ray.params.RayParams): The RayParams instance. The - following parameters could be checked: address_info, - object_manager_port, node_manager_port, node_ip_address, - redis_port, redis_shard_ports, num_workers, object_store_memory, - redis_max_memory, worker_path, cleanup, redirect_worker_output, - redirect_output, start_workers_from_local_scheduler, resources, - num_redis_shards, redis_max_clients, redis_password, include_webui, - huge_pages, plasma_directory, autoscaling_config, - plasma_store_socket_name, raylet_socket_name, temp_dir, - _internal_config. - cleanup (bool): If cleanup is true, then the processes started here - will be killed by services.cleanup() when the Python process that - called this method exits. - - Returns: - A dictionary of the address information for the processes that were - started. 
- """ - ray_params.update_if_absent(num_redis_shards=1, include_webui=True) - ray_params.update(include_log_monitor=True) - return start_ray_processes(ray_params, cleanup=cleanup) + return p diff --git a/python/ray/tempfile_services.py b/python/ray/tempfile_services.py index 0ffced3fa..6e47c3991 100644 --- a/python/ray/tempfile_services.py +++ b/python/ray/tempfile_services.py @@ -230,3 +230,10 @@ def new_monitor_log_file(redirect_output): monitor_stdout_file, monitor_stderr_file = new_log_files( "monitor", redirect_output) return monitor_stdout_file, monitor_stderr_file + + +def new_raylet_monitor_log_file(redirect_output): + """Create new logging files for the raylet monitor.""" + raylet_monitor_stdout_file, raylet_monitor_stderr_file = new_log_files( + "raylet_monitor", redirect_output) + return raylet_monitor_stdout_file, raylet_monitor_stderr_file diff --git a/python/ray/test/cluster_utils.py b/python/ray/test/cluster_utils.py index f346184b0..06c4b9f38 100644 --- a/python/ray/test/cluster_utils.py +++ b/python/ray/test/cluster_utils.py @@ -2,15 +2,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import atexit import logging import time import redis import ray -from ray.parameter import RayParams -import ray.services as services logger = logging.getLogger(__name__) @@ -35,10 +32,10 @@ class Cluster(object): for shutting down all started processes. 
""" self.head_node = None - self.worker_nodes = {} + self.worker_nodes = set() self.redis_address = None - self.redis_password = None self.connected = False + self._shutdown_at_exit = shutdown_at_exit if not initialize_head and connect: raise RuntimeError("Cannot connect to uninitialized cluster.") @@ -46,14 +43,12 @@ class Cluster(object): head_node_args = head_node_args or {} self.add_node(**head_node_args) if connect: - self.connect(head_node_args) - if shutdown_at_exit: - atexit.register(self.shutdown) + self.connect() - def connect(self, head_node_args): + def connect(self): + """Connect the driver to the cluster.""" assert self.redis_address is not None assert not self.connected - self.redis_password = head_node_args.get("redis_password") output_info = ray.init( ignore_reinit_error=True, redis_address=self.redis_address, @@ -61,7 +56,7 @@ class Cluster(object): logger.info(output_info) self.connected = True - def add_node(self, **override_kwargs): + def add_node(self, **node_args): """Adds a node to the local Ray Cluster. All nodes are by default started with the following settings: @@ -70,41 +65,39 @@ class Cluster(object): object_store_memory=100 * (2**20) # 100 MB Args: - override_kwargs: Keyword arguments used in `start_ray_head` - and `start_ray_node`. Overrides defaults. + node_args: Keyword arguments used in `start_ray_head` and + `start_ray_node`. Overrides defaults. Returns: Node object of the added Ray node. 
""" - node_kwargs = { + default_kwargs = { "num_cpus": 1, - "object_store_memory": 100 * (2**20) # 100 MB + "object_store_memory": 100 * (2**20), # 100 MB } - node_kwargs.update(override_kwargs) - ray_params = RayParams( - node_ip_address=services.get_node_ip_address(), **node_kwargs) - + ray_params = ray.parameter.RayParams(**node_args) + ray_params.update_if_absent(**default_kwargs) if self.head_node is None: - ray_params.update(include_webui=False) - address_info = services.start_ray_head(ray_params, cleanup=True) - self.redis_address = address_info["redis_address"] - # TODO(rliaw): Find a more stable way than modifying global state. - process_dict_copy = services.all_processes.copy() - for key in services.all_processes: - services.all_processes[key] = [] - node = Node(address_info, process_dict_copy) + node = ray.node.Node( + ray_params, head=True, shutdown_at_exit=self._shutdown_at_exit) self.head_node = node + self.redis_address = self.head_node.redis_address + self.redis_password = node_args.get("redis_password") + self.webui_url = self.head_node.webui_url else: - ray_params.update(redis_address=self.redis_address) - address_info = services.start_ray_node(ray_params, cleanup=True) - # TODO(rliaw): Find a more stable way than modifying global state. - process_dict_copy = services.all_processes.copy() - for key in services.all_processes: - services.all_processes[key] = [] - node = Node(address_info, process_dict_copy) - self.worker_nodes[node] = address_info - logger.info("Starting Node with raylet socket {}".format( - address_info["raylet_socket_name"])) + ray_params.update_if_absent(redis_address=self.redis_address) + node = ray.node.Node( + ray_params, + head=False, + shutdown_at_exit=self._shutdown_at_exit) + self.worker_nodes.add(node) + + # Wait for the node to appear in the client table. We do this so that + # the nodes appears in the client table in the order that the + # corresponding calls to add_node were made. 
We do this because in the + # tests we assume that the driver is connected to the first node that + # is added. + self._wait_for_node(node) return node @@ -116,16 +109,44 @@ class Cluster(object): will be removed. """ if self.head_node == node: - self.head_node.kill_all_processes() + self.head_node.kill_all_processes(check_alive=False) self.head_node = None # TODO(rliaw): Do we need to kill all worker processes? else: - node.kill_all_processes() - self.worker_nodes.pop(node) + node.kill_all_processes(check_alive=False) + self.worker_nodes.remove(node) assert not node.any_processes_alive(), ( "There are zombie processes left over after killing.") + def _wait_for_node(self, node, timeout=30): + """Wait until this node has appeared in the client table. + + Args: + node (ray.node.Node): The node to wait for. + timeout: The amount of time in seconds to wait before raising an + exception. + + Raises: + Exception: An exception is raised if the timeout expires before the + node appears in the client table. + """ + ip_address, port = self.redis_address.split(":") + redis_client = redis.StrictRedis( + host=ip_address, port=int(port), password=self.redis_password) + + start_time = time.time() + while time.time() - start_time < timeout: + clients = ray.experimental.state.parse_client_table(redis_client) + object_store_socket_names = [ + client["ObjectStoreSocketName"] for client in clients + ] + if node.plasma_store_socket_name in object_store_socket_names: + return + else: + time.sleep(0.1) + raise Exception("Timed out while waiting for nodes to join.") + def wait_for_nodes(self, timeout=30): """Waits for correct number of nodes to be registered. @@ -179,6 +200,18 @@ class Cluster(object): nodes = [self.head_node] + nodes return nodes + def remaining_processes_alive(self): + """Returns a bool indicating whether all processes are alive or not. + + Note that this ignores processes that have been explicitly killed, + e.g., via a command like node.kill_raylet(). 
+ + Returns: + True if all processes are alive and false otherwise. + """ + return all( + node.remaining_processes_alive() for node in self.list_all_nodes()) + def shutdown(self): """Removes all nodes.""" @@ -188,63 +221,5 @@ class Cluster(object): for node in all_nodes: self.remove_node(node) - if self.head_node: + if self.head_node is not None: self.remove_node(self.head_node) - else: - logger.warning("No headnode exists!") - - -class Node(object): - """Abstraction for a Ray node.""" - - def __init__(self, address_info, process_dict): - # TODO(rliaw): Is there a unique identifier for a node? - self.address_info = address_info - self.process_dict = process_dict - - def kill_plasma_store(self): - self.process_dict[services.PROCESS_TYPE_PLASMA_STORE][0].kill() - self.process_dict[services.PROCESS_TYPE_PLASMA_STORE][0].wait() - - def kill_raylet(self): - self.process_dict[services.PROCESS_TYPE_RAYLET][0].kill() - self.process_dict[services.PROCESS_TYPE_RAYLET][0].wait() - - def kill_log_monitor(self): - self.process_dict["log_monitor"][0].kill() - self.process_dict["log_monitor"][0].wait() - - def kill_all_processes(self): - for process_name, process_list in self.process_dict.items(): - logger.info("Killing all {}(s)".format(process_name)) - for process in process_list: - # Kill the process if it is still alive. 
- if process.poll() is None: - process.kill() - - for process_name, process_list in self.process_dict.items(): - logger.info("Waiting all {}(s)".format(process_name)) - for process in process_list: - process.wait() - - def live_processes(self): - return [(p_name, proc) for p_name, p_list in self.process_dict.items() - for proc in p_list if proc.poll() is None] - - def dead_processes(self): - return [(p_name, proc) for p_name, p_list in self.process_dict.items() - for proc in p_list if proc.poll() is not None] - - def any_processes_alive(self): - return any(self.live_processes()) - - def all_processes_alive(self): - return not any(self.dead_processes()) - - def get_plasma_store_name(self): - """Return the plasma store name. - - Assuming one plasma store per raylet, this may be used as a unique - identifier for a node. - """ - return self.address_info['object_store_address'] diff --git a/python/ray/worker.py b/python/ray/worker.py index 1576265eb..4c8dde100 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -26,6 +26,7 @@ import ray.cloudpickle as pickle import ray.experimental.state as state import ray.gcs_utils import ray.memory_monitor as memory_monitor +import ray.node import ray.remote_function import ray.serialization as serialization import ray.services as services @@ -37,7 +38,7 @@ import ray.ray_constants as ray_constants from ray import import_thread from ray import profiling from ray.function_manager import (FunctionActorManager, FunctionDescriptor) -from ray.parameter import RayParams +import ray.parameter from ray.utils import ( check_oversized_pickle, is_cython, @@ -1071,6 +1072,9 @@ per worker process. 
global_state = state.GlobalState() +_global_node = None +"""ray.node.Node: The global node object that is created by ray.init().""" + class RayConnectionError(Exception): pass @@ -1250,148 +1254,6 @@ def get_address_info_from_redis(redis_address, counter += 1 -def _init(ray_params, driver_id=None): - """Helper method to connect to an existing Ray cluster or start a new one. - - This method handles two cases. Either a Ray cluster already exists and we - just attach this driver to it, or we start all of the processes associated - with a Ray cluster and attach to the newly started cluster. - - Args: - ray_params (ray.params.RayParams): The RayParams instance. The - following parameters could be checked: address_info, - start_ray_local, object_id_seed, num_workers, - object_store_memory, redis_max_memory, local_mode, - redirect_worker_output, driver_mode, redirect_output, - start_workers_from_local_scheduler, num_cpus, num_gpus, resources, - num_redis_shards, redis_max_clients, redis_password, - plasma_directory, huge_pages, include_webui, driver_id, - plasma_store_socket_name, temp_dir, raylet_socket_name, - _internal_config - driver_id: The ID of driver. - - Returns: - Address information about the started processes. - - Raises: - Exception: An exception is raised if an inappropriate combination of - arguments is passed in. - """ - if ray_params.driver_mode is not None: - raise Exception("The 'driver_mode' argument has been deprecated. " - "To run Ray in local mode, pass in local_mode=True.") - if ray_params.local_mode: - ray_params.driver_mode = LOCAL_MODE - else: - ray_params.driver_mode = SCRIPT_MODE - - # Get addresses of existing services. - if ray_params.address_info is None: - ray_params.address_info = {} - else: - assert isinstance(ray_params.address_info, dict) - ray_params.node_ip_address = ray_params.address_info.get("node_ip_address") - ray_params.redis_address = ray_params.address_info.get("redis_address") - - # Start any services that do not yet exist. 
- if ray_params.driver_mode == LOCAL_MODE: - # If starting Ray in LOCAL_MODE, don't start any other processes. - pass - elif ray_params.start_ray_local: - # In this case, we launch a scheduler, a new object store, and some - # workers, and we connect to them. We do not launch any processes that - # are already registered in address_info. - ray_params.update_if_absent( - node_ip_address=ray.services.get_node_ip_address()) - # Use 1 additional redis shard if num_redis_shards is not provided. - ray_params.update_if_absent(num_redis_shards=1) - - # Start the scheduler, object store, and some workers. These will be - # killed by the call to shutdown(), which happens when the Python - # script exits. - ray_params.address_info = services.start_ray_head(ray_params) - else: - if ray_params.redis_address is None: - raise Exception("When connecting to an existing cluster, " - "redis_address must be provided.") - if ray_params.num_workers is not None: - raise Exception("When connecting to an existing cluster, " - "num_workers must not be provided.") - if ray_params.num_cpus is not None or ray_params.num_gpus is not None: - raise Exception("When connecting to an existing cluster, num_cpus " - "and num_gpus must not be provided.") - if ray_params.resources is not None: - raise Exception("When connecting to an existing cluster, " - "resources must not be provided.") - if ray_params.num_redis_shards is not None: - raise Exception("When connecting to an existing cluster, " - "num_redis_shards must not be provided.") - if ray_params.redis_max_clients is not None: - raise Exception("When connecting to an existing cluster, " - "redis_max_clients must not be provided.") - if ray_params.object_store_memory is not None: - raise Exception("When connecting to an existing cluster, " - "object_store_memory must not be provided.") - if ray_params.redis_max_memory is not None: - raise Exception("When connecting to an existing cluster, " - "redis_max_memory must not be provided.") - if 
ray_params.plasma_directory is not None: - raise Exception("When connecting to an existing cluster, " - "plasma_directory must not be provided.") - if ray_params.huge_pages: - raise Exception("When connecting to an existing cluster, " - "huge_pages must not be provided.") - if ray_params.temp_dir is not None: - raise Exception("When connecting to an existing cluster, " - "temp_dir must not be provided.") - if ray_params.plasma_store_socket_name is not None: - raise Exception("When connecting to an existing cluster, " - "plasma_store_socket_name must not be provided.") - if ray_params.raylet_socket_name is not None: - raise Exception("When connecting to an existing cluster, " - "raylet_socket_name must not be provided.") - if ray_params._internal_config is not None: - raise Exception("When connecting to an existing cluster, " - "_internal_config must not be provided.") - - # Get the node IP address if one is not provided. - ray_params.update_if_absent( - node_ip_address=services.get_node_ip_address( - ray_params.redis_address)) - # Get the address info of the processes to connect to from Redis. - ray_params.address_info = get_address_info_from_redis( - ray_params.redis_address, - ray_params.node_ip_address, - redis_password=ray_params.redis_password) - - # Connect this driver to Redis, the object store, and the local scheduler. - # Choose the first object store and local scheduler if there are multiple. - # The corresponding call to disconnect will happen in the call to - # shutdown() when the Python script exits. 
- if ray_params.driver_mode == LOCAL_MODE: - driver_address_info = {} - else: - driver_address_info = { - "node_ip_address": ray_params.node_ip_address, - "redis_address": ray_params.address_info["redis_address"], - "store_socket_name": ray_params.address_info[ - "object_store_address"], - "webui_url": ray_params.address_info["webui_url"], - } - driver_address_info["raylet_socket_name"] = ( - ray_params.address_info["raylet_socket_name"]) - - # We only pass `temp_dir` to a worker (WORKER_MODE). - # It can't be a worker here. - connect( - ray_params, - driver_address_info, - mode=ray_params.driver_mode, - worker=global_worker, - driver_id=driver_id) - return ray_params.address_info - - def init(redis_address=None, num_cpus=None, num_gpus=None, @@ -1520,6 +1382,14 @@ def init(redis_address=None, raise DeprecationWarning("The use_raylet argument is deprecated. " "Please remove it.") + if driver_mode is not None: + raise Exception("The 'driver_mode' argument has been deprecated. " + "To run Ray in local mode, pass in local_mode=True.") + if local_mode: + driver_mode = LOCAL_MODE + else: + driver_mode = SCRIPT_MODE + if setproctitle is None: logger.warning( "WARNING: Not updating worker name since `setproctitle` is not " @@ -1532,7 +1402,10 @@ def init(redis_address=None, "called.") return else: - raise Exception("Perhaps you called ray.init twice by accident?") + raise Exception("Perhaps you called ray.init twice by accident? " + "This error can be suppressed by passing in " + "'ignore_reinit_error=True' or by calling " + "'ray.shutdown()' prior to 'ray.init()'.") # Convert hostnames to numerical IP address. 
if node_ip_address is not None: @@ -1540,36 +1413,132 @@ def init(redis_address=None, if redis_address is not None: redis_address = services.address_to_ip(redis_address) - info = {"node_ip_address": node_ip_address, "redis_address": redis_address} - ray_params = RayParams( - address_info=info, - start_ray_local=(redis_address is None), - num_workers=num_workers, - object_id_seed=object_id_seed, - local_mode=local_mode, - driver_mode=driver_mode, - redirect_worker_output=redirect_worker_output, - redirect_output=redirect_output, - num_cpus=num_cpus, - num_gpus=num_gpus, - resources=resources, - num_redis_shards=num_redis_shards, - redis_max_clients=redis_max_clients, + address_info = { + "node_ip_address": node_ip_address, + "redis_address": redis_address + } + + if driver_mode == LOCAL_MODE: + # If starting Ray in LOCAL_MODE, don't start any other processes. + pass + elif redis_address is None: + if node_ip_address is None: + node_ip_address = ray.services.get_node_ip_address() + if num_redis_shards is None: + num_redis_shards = 1 + # In this case, we need to start a new cluster. + ray_params = ray.parameter.RayParams( + redis_address=redis_address, + node_ip_address=node_ip_address, + num_workers=num_workers, + object_id_seed=object_id_seed, + local_mode=local_mode, + driver_mode=driver_mode, + redirect_worker_output=redirect_worker_output, + redirect_output=redirect_output, + num_cpus=num_cpus, + num_gpus=num_gpus, + resources=resources, + num_redis_shards=num_redis_shards, + redis_max_clients=redis_max_clients, + redis_password=redis_password, + plasma_directory=plasma_directory, + huge_pages=huge_pages, + include_webui=include_webui, + object_store_memory=object_store_memory, + redis_max_memory=redis_max_memory, + plasma_store_socket_name=plasma_store_socket_name, + raylet_socket_name=raylet_socket_name, + temp_dir=temp_dir, + _internal_config=_internal_config, + ) + # Start the Ray processes. 
We set shutdown_at_exit=False because we + # shutdown the node in the ray.shutdown call that happens in the atexit + # handler. + global _global_node + _global_node = ray.node.Node( + head=True, shutdown_at_exit=False, ray_params=ray_params) + address_info["redis_address"] = _global_node.redis_address + address_info[ + "object_store_address"] = _global_node.plasma_store_socket_name + address_info["webui_url"] = _global_node.webui_url + address_info["raylet_socket_name"] = _global_node.raylet_socket_name + else: + # In this case, we are connecting to an existing cluster. + if num_workers is not None: + raise Exception("When connecting to an existing cluster, " + "num_workers must not be provided.") + if num_cpus is not None or num_gpus is not None: + raise Exception("When connecting to an existing cluster, num_cpus " + "and num_gpus must not be provided.") + if resources is not None: + raise Exception("When connecting to an existing cluster, " + "resources must not be provided.") + if num_redis_shards is not None: + raise Exception("When connecting to an existing cluster, " + "num_redis_shards must not be provided.") + if redis_max_clients is not None: + raise Exception("When connecting to an existing cluster, " + "redis_max_clients must not be provided.") + if object_store_memory is not None: + raise Exception("When connecting to an existing cluster, " + "object_store_memory must not be provided.") + if redis_max_memory is not None: + raise Exception("When connecting to an existing cluster, " + "redis_max_memory must not be provided.") + if plasma_directory is not None: + raise Exception("When connecting to an existing cluster, " + "plasma_directory must not be provided.") + if huge_pages: + raise Exception("When connecting to an existing cluster, " + "huge_pages must not be provided.") + if temp_dir is not None: + raise Exception("When connecting to an existing cluster, " + "temp_dir must not be provided.") + if plasma_store_socket_name is not None: + raise 
Exception("When connecting to an existing cluster, " + "plasma_store_socket_name must not be provided.") + if raylet_socket_name is not None: + raise Exception("When connecting to an existing cluster, " + "raylet_socket_name must not be provided.") + if _internal_config is not None: + raise Exception("When connecting to an existing cluster, " + "_internal_config must not be provided.") + + # Get the node IP address if one is not provided. + + if node_ip_address is None: + node_ip_address = services.get_node_ip_address(redis_address) + # Get the address info of the processes to connect to from Redis. + address_info = get_address_info_from_redis( + redis_address, node_ip_address, redis_password=redis_password) + + if driver_mode == LOCAL_MODE: + driver_address_info = {} + else: + driver_address_info = { + "node_ip_address": node_ip_address, + "redis_address": address_info["redis_address"], + "store_socket_name": address_info["object_store_address"], + "webui_url": address_info["webui_url"], + } + driver_address_info["raylet_socket_name"] = ( + address_info["raylet_socket_name"]) + + # We only pass `temp_dir` to a worker (WORKER_MODE). + # It can't be a worker here. 
+ connect( + driver_address_info, redis_password=redis_password, - plasma_directory=plasma_directory, - huge_pages=huge_pages, - include_webui=include_webui, - object_store_memory=object_store_memory, - redis_max_memory=redis_max_memory, - plasma_store_socket_name=plasma_store_socket_name, - raylet_socket_name=raylet_socket_name, - temp_dir=temp_dir, - _internal_config=_internal_config, - ) - ret = _init(ray_params, driver_id=driver_id) + object_id_seed=object_id_seed, + mode=driver_mode, + worker=global_worker, + driver_id=driver_id) + for hook in _post_init_hooks: hook() - return ret + + return address_info # Functions to run as callback after a successful ray init @@ -1601,16 +1570,11 @@ def shutdown(worker=global_worker): if hasattr(worker, "plasma_client"): worker.plasma_client.disconnect() - if worker.mode == SCRIPT_MODE: - services.cleanup() - else: - # If this is not a driver, make sure there are no orphan processes, - # besides possibly the worker itself. - for process_type, processes in services.all_processes.items(): - if process_type == services.PROCESS_TYPE_WORKER: - assert len(processes) <= 1 - else: - assert len(processes) == 0 + # Shut down the Ray processes. + global _global_node + if _global_node is not None: + _global_node.kill_all_processes(check_alive=False, allow_graceful=True) + _global_node = None worker.set_mode(None) @@ -1767,19 +1731,23 @@ def print_error_messages(worker): pass -def connect(ray_params, - info, +def connect(info, + redis_password=None, + object_id_seed=None, mode=WORKER_MODE, worker=global_worker, driver_id=None): """Connect this worker to the local scheduler, to Plasma, and to Redis. Args: - ray_params (ray.params.RayParams): The RayParams instance. The - following parameters could be checked: object_id_seed, - redis_password info (dict): A dictionary with address of the Redis server and the sockets of the plasma store and raylet. 
+ redis_password (str): Prevents external clients without the password + from connecting to Redis if provided. + object_id_seed (int): Used to seed the deterministic generation of + object IDs. The same value can be used across multiple runs of the + same job in order to generate the object IDs in a consistent + manner. However, the same ID should not be used for different jobs. mode: The mode of the worker. One of SCRIPT_MODE, WORKER_MODE, and LOCAL_MODE. worker: The ray.Worker instance. @@ -1839,7 +1807,7 @@ def connect(ray_params, redis.StrictRedis( host=redis_ip_address, port=int(redis_port), - password=ray_params.redis_password)) + password=redis_password)) # For driver's check that the version information matches the version # information that the Ray cluster was started with. @@ -1877,13 +1845,11 @@ def connect(ray_params, services.record_log_files_in_redis( info["redis_address"], info["node_ip_address"], [log_stdout_file, log_stderr_file], - password=ray_params.redis_password) + password=redis_password) # Create an object for interfacing with the global state. global_state._initialize_global_state( - redis_ip_address, - int(redis_port), - redis_password=ray_params.redis_password) + redis_ip_address, int(redis_port), redis_password=redis_password) # Register the worker with Redis. if mode == SCRIPT_MODE: @@ -1932,8 +1898,8 @@ def connect(ray_params, # the user's random number generator). Otherwise, set the current task # ID randomly to avoid object ID collisions. numpy_state = np.random.get_state() - if ray_params.object_id_seed is not None: - np.random.seed(ray_params.object_id_seed) + if object_id_seed is not None: + np.random.seed(object_id_seed) else: # Try to use true randomness. 
np.random.seed(None) diff --git a/python/ray/workers/default_worker.py b/python/ray/workers/default_worker.py index 335b228ec..47c2532dc 100644 --- a/python/ray/workers/default_worker.py +++ b/python/ray/workers/default_worker.py @@ -8,7 +8,6 @@ import traceback import ray import ray.actor -from ray.parameter import RayParams import ray.ray_constants as ray_constants import ray.tempfile_services as tempfile_services @@ -76,15 +75,8 @@ if __name__ == "__main__": # Override the temporary directory. tempfile_services.set_temp_root(args.temp_dir) - ray_params = RayParams( - node_ip_address=args.node_ip_address, - redis_address=args.redis_address, - redis_password=args.redis_password, - plasma_store_socket_name=args.object_store_name, - raylet_socket_name=args.raylet_name, - temp_dir=args.temp_dir) - - ray.worker.connect(ray_params, info, mode=ray.WORKER_MODE) + ray.worker.connect( + info, redis_password=args.redis_password, mode=ray.WORKER_MODE) error_explanation = """ This error is unexpected and should not have happened. Somehow a worker diff --git a/test/actor_test.py b/test/actor_test.py index c9bc9c3af..d92506f90 100644 --- a/test/actor_test.py +++ b/test/actor_test.py @@ -1312,7 +1312,7 @@ def test_exception_raised_when_actor_node_dies(head_node_cluster): # Create an actor that is not on the local scheduler. actor = Counter.remote() while (ray.get(actor.local_plasma.remote()) != - remote_node.get_plasma_store_name()): + remote_node.plasma_store_socket_name): actor = Counter.remote() # Kill the second node. @@ -1456,15 +1456,13 @@ def setup_counter_actor(test_checkpoint=False, os.environ.get("RAY_USE_NEW_GCS") == "on", reason="Hanging with new GCS API.") def test_checkpointing(two_node_cluster): + cluster = two_node_cluster actor, ids = setup_counter_actor(test_checkpoint=True) # Wait for the last task to finish running. ray.get(ids[-1]) # Kill the corresponding plasma store to get rid of the cached objects. 
- process = ray.services.all_processes[ - ray.services.PROCESS_TYPE_PLASMA_STORE][1] - process.kill() - process.wait() + cluster.list_all_nodes()[1].kill_plasma_store(wait=True) # Check that the actor restored from a checkpoint. assert ray.get(actor.test_restore.remote()) @@ -1484,16 +1482,14 @@ def test_checkpointing(two_node_cluster): os.environ.get("RAY_USE_NEW_GCS") == "on", reason="Hanging with new GCS API.") def test_remote_checkpoint(two_node_cluster): + cluster = two_node_cluster actor, ids = setup_counter_actor(test_checkpoint=True) # Do a remote checkpoint call and wait for it to finish. ray.get(actor.__ray_checkpoint__.remote()) # Kill the corresponding plasma store to get rid of the cached objects. - process = ray.services.all_processes[ - ray.services.PROCESS_TYPE_PLASMA_STORE][1] - process.kill() - process.wait() + cluster.list_all_nodes()[1].kill_plasma_store(wait=True) # Check that the actor restored from a checkpoint. assert ray.get(actor.test_restore.remote()) @@ -1513,15 +1509,13 @@ def test_remote_checkpoint(two_node_cluster): os.environ.get("RAY_USE_NEW_GCS") == "on", reason="Hanging with new GCS API.") def test_lost_checkpoint(two_node_cluster): + cluster = two_node_cluster actor, ids = setup_counter_actor(test_checkpoint=True) # Wait for the first fraction of tasks to finish running. ray.get(ids[len(ids) // 10]) # Kill the corresponding plasma store to get rid of the cached objects. - process = ray.services.all_processes[ - ray.services.PROCESS_TYPE_PLASMA_STORE][1] - process.kill() - process.wait() + cluster.list_all_nodes()[1].kill_plasma_store(wait=True) # Check that the actor restored from a checkpoint. 
assert ray.get(actor.test_restore.remote()) @@ -1542,15 +1536,13 @@ def test_lost_checkpoint(two_node_cluster): os.environ.get("RAY_USE_NEW_GCS") == "on", reason="Hanging with new GCS API.") def test_checkpoint_exception(two_node_cluster): + cluster = two_node_cluster actor, ids = setup_counter_actor(test_checkpoint=True, save_exception=True) # Wait for the last task to finish running. ray.get(ids[-1]) # Kill the corresponding plasma store to get rid of the cached objects. - process = ray.services.all_processes[ - ray.services.PROCESS_TYPE_PLASMA_STORE][1] - process.kill() - process.wait() + cluster.list_all_nodes()[1].kill_plasma_store(wait=True) # Check that we can submit another call on the actor and get the # correct counter result. @@ -1573,16 +1565,14 @@ def test_checkpoint_exception(two_node_cluster): os.environ.get("RAY_USE_NEW_GCS") == "on", reason="Hanging with new GCS API.") def test_checkpoint_resume_exception(two_node_cluster): + cluster = two_node_cluster actor, ids = setup_counter_actor( test_checkpoint=True, resume_exception=True) # Wait for the last task to finish running. ray.get(ids[-1]) # Kill the corresponding plasma store to get rid of the cached objects. - process = ray.services.all_processes[ - ray.services.PROCESS_TYPE_PLASMA_STORE][1] - process.kill() - process.wait() + cluster.list_all_nodes()[1].kill_plasma_store(wait=True) # Check that we can submit another call on the actor and get the # correct counter result. @@ -1603,6 +1593,7 @@ def test_checkpoint_resume_exception(two_node_cluster): @pytest.mark.skip("Fork/join consistency not yet implemented.") def test_distributed_handle(two_node_cluster): + cluster = two_node_cluster counter, ids = setup_counter_actor(test_checkpoint=False) @ray.remote @@ -1625,10 +1616,7 @@ def test_distributed_handle(two_node_cluster): # Kill the second plasma store to get rid of the cached objects and # trigger the corresponding local scheduler to exit. 
- process = ray.services.all_processes[ - ray.services.PROCESS_TYPE_PLASMA_STORE][1] - process.kill() - process.wait() + cluster.list_all_nodes()[1].kill_plasma_store(wait=True) # Check that the actor did not restore from a checkpoint. assert not ray.get(counter.test_restore.remote()) @@ -1643,6 +1631,7 @@ def test_distributed_handle(two_node_cluster): os.environ.get("RAY_USE_NEW_GCS") == "on", reason="Hanging with new GCS API.") def test_remote_checkpoint_distributed_handle(two_node_cluster): + cluster = two_node_cluster counter, ids = setup_counter_actor(test_checkpoint=True) @ray.remote @@ -1666,10 +1655,7 @@ def test_remote_checkpoint_distributed_handle(two_node_cluster): # Kill the second plasma store to get rid of the cached objects and # trigger the corresponding local scheduler to exit. - process = ray.services.all_processes[ - ray.services.PROCESS_TYPE_PLASMA_STORE][1] - process.kill() - process.wait() + cluster.list_all_nodes()[1].kill_plasma_store(wait=True) # Check that the actor restored from a checkpoint. assert ray.get(counter.test_restore.remote()) @@ -1686,6 +1672,7 @@ def test_remote_checkpoint_distributed_handle(two_node_cluster): @pytest.mark.skip("Fork/join consistency not yet implemented.") def test_checkpoint_distributed_handle(two_node_cluster): + cluster = two_node_cluster counter, ids = setup_counter_actor(test_checkpoint=True) @ray.remote @@ -1708,10 +1695,7 @@ def test_checkpoint_distributed_handle(two_node_cluster): # Kill the second plasma store to get rid of the cached objects and # trigger the corresponding local scheduler to exit. - process = ray.services.all_processes[ - ray.services.PROCESS_TYPE_PLASMA_STORE][1] - process.kill() - process.wait() + cluster.list_all_nodes()[1].kill_plasma_store(wait=True) # Check that the actor restored from a checkpoint. 
assert ray.get(counter.test_restore.remote()) @@ -1721,8 +1705,8 @@ def test_checkpoint_distributed_handle(two_node_cluster): assert x == count + 1 -def _test_nondeterministic_reconstruction(num_forks, num_items_per_fork, - num_forks_to_wait): +def _test_nondeterministic_reconstruction( + cluster, num_forks, num_items_per_fork, num_forks_to_wait): # Make a shared queue. @ray.remote class Queue(object): @@ -1774,10 +1758,7 @@ def _test_nondeterministic_reconstruction(num_forks, num_items_per_fork, # Kill the second plasma store to get rid of the cached objects and # trigger the corresponding local scheduler to exit. - process = ray.services.all_processes[ - ray.services.PROCESS_TYPE_PLASMA_STORE][1] - process.kill() - process.wait() + cluster.list_all_nodes()[1].kill_plasma_store(wait=True) # Read the queue again and check for deterministic reconstruction. ray.get(enqueue_tasks) @@ -1794,14 +1775,16 @@ def _test_nondeterministic_reconstruction(num_forks, num_items_per_fork, os.environ.get("RAY_USE_NEW_GCS") == "on", reason="Currently doesn't work with the new GCS.") def test_nondeterministic_reconstruction(two_node_cluster): - _test_nondeterministic_reconstruction(10, 100, 10) + cluster = two_node_cluster + _test_nondeterministic_reconstruction(cluster, 10, 100, 10) @pytest.mark.skip("Nondeterministic reconstruction currently not supported " "when there are concurrent forks that didn't finish " "initial execution.") def test_nondeterministic_reconstruction_concurrent_forks(two_node_cluster): - _test_nondeterministic_reconstruction(10, 100, 1) + cluster = two_node_cluster + _test_nondeterministic_reconstruction(cluster, 10, 100, 1) @pytest.fixture @@ -2278,7 +2261,7 @@ def test_actor_reconstruction_on_node_failure(head_node_cluster): def kill_node(object_store_socket): node_to_remove = None for node in cluster.worker_nodes: - if object_store_socket == node.get_plasma_store_name(): + if object_store_socket == node.plasma_store_socket_name: node_to_remove = node 
cluster.remove_node(node_to_remove) diff --git a/test/component_failures_test.py b/test/component_failures_test.py index 36ee0a498..eb488f6c8 100644 --- a/test/component_failures_test.py +++ b/test/component_failures_test.py @@ -91,7 +91,7 @@ def test_dying_worker_get(shutdown_only): time.sleep(0.1) # Make sure that nothing has died. - assert ray.services.all_processes_alive() + assert ray.services.remaining_processes_alive() # This test checks that when a driver dies in the middle of a get, the plasma @@ -134,7 +134,7 @@ ray.get(ray.ObjectID(ray.utils.hex_to_binary("{}"))) time.sleep(0.1) # Make sure that nothing has died. - assert ray.services.all_processes_alive() + assert ray.services.remaining_processes_alive() # This test checks that when a worker dies in the middle of a wait, the plasma @@ -176,7 +176,7 @@ def test_dying_worker_wait(shutdown_only): time.sleep(0.1) # Make sure that nothing has died. - assert ray.services.all_processes_alive() + assert ray.services.remaining_processes_alive() # This test checks that when a driver dies in the middle of a wait, the plasma @@ -219,7 +219,7 @@ ray.wait([ray.ObjectID(ray.utils.hex_to_binary("{}"))]) time.sleep(0.1) # Make sure that nothing has died. 
- assert ray.services.all_processes_alive() + assert ray.services.remaining_processes_alive() @pytest.fixture(params=[(1, 4), (4, 4)]) @@ -241,6 +241,23 @@ def ray_start_workers_separate_multinode(request): def test_worker_failed(ray_start_workers_separate_multinode): num_nodes, num_initial_workers = (ray_start_workers_separate_multinode) + @ray.remote + def get_pids(): + time.sleep(0.25) + return os.getpid() + + start_time = time.time() + pids = set() + while len(pids) < num_nodes * num_initial_workers: + new_pids = ray.get([ + get_pids.remote() + for _ in range(2 * num_nodes * num_initial_workers) + ]) + for pid in new_pids: + pids.add(pid) + if time.time() - start_time > 60: + raise Exception("Timed out while waiting to get worker PIDs.") + @ray.remote def f(x): time.sleep(0.5) @@ -253,13 +270,16 @@ def test_worker_failed(ray_start_workers_separate_multinode): # Allow the tasks some time to begin executing. time.sleep(0.1) # Kill the workers as the tasks execute. - for worker in ( - ray.services.all_processes[ray.services.PROCESS_TYPE_WORKER]): - worker.terminate() + for pid in pids: + os.kill(pid, signal.SIGKILL) time.sleep(0.1) - # Make sure that we can still get the objects after the executing tasks - # died. - ray.get(object_ids) + # Make sure that we either get the object or we get an appropriate + # exception. + for object_id in object_ids: + try: + ray.get(object_id) + except ray.worker.RayTaskError: + pass @pytest.fixture @@ -278,13 +298,13 @@ def ray_initialize_cluster(): })) ray.init(redis_address=cluster.redis_address) - yield None + yield cluster ray.shutdown() cluster.shutdown() -def _test_component_failed(component_type): +def _test_component_failed(cluster, component_type): """Kill a component on all worker nodes and check workload succeeds.""" # Submit many tasks with many dependencies. @ray.remote @@ -299,8 +319,10 @@ def _test_component_failed(component_type): # execute. Do this in a loop while submitting tasks between each # component failure. 
time.sleep(0.1) - components = ray.services.all_processes[component_type] - for process in components[1:]: + worker_nodes = cluster.list_all_nodes()[1:] + assert len(worker_nodes) > 0 + for node in worker_nodes: + process = node.all_processes[component_type][0].process # Submit a round of tasks with many dependencies. x = 1 for _ in range(1000): @@ -324,40 +346,43 @@ def _test_component_failed(component_type): ray.get(xs) -def check_components_alive(component_type, check_component_alive): - """Check that a given component type is alive on all worker nodes. - """ - components = ray.services.all_processes[component_type][1:] - for component in components: +def check_components_alive(cluster, component_type, check_component_alive): + """Check that a given component type is alive on all worker nodes.""" + worker_nodes = cluster.list_all_nodes()[1:] + assert len(worker_nodes) > 0 + for node in worker_nodes: + process = node.all_processes[component_type][0].process if check_component_alive: - assert component.poll() is None + assert process.poll() is None else: print("waiting for " + component_type + " with PID " + - str(component.pid) + "to terminate") - component.wait() + str(process.pid) + "to terminate") + process.wait() print("done waiting for " + component_type + " with PID " + - str(component.pid) + "to terminate") - assert not component.poll() is None + str(process.pid) + "to terminate") + assert not process.poll() is None def test_raylet_failed(ray_initialize_cluster): + cluster = ray_initialize_cluster # Kill all local schedulers on worker nodes. - _test_component_failed(ray.services.PROCESS_TYPE_RAYLET) + _test_component_failed(cluster, ray.node.PROCESS_TYPE_RAYLET) # The plasma stores should still be alive on the worker nodes. 
- check_components_alive(ray.services.PROCESS_TYPE_PLASMA_STORE, True) + check_components_alive(cluster, ray.node.PROCESS_TYPE_PLASMA_STORE, True) @pytest.mark.skipif( os.environ.get("RAY_USE_NEW_GCS") == "on", reason="Hanging with new GCS API.") def test_plasma_store_failed(ray_initialize_cluster): + cluster = ray_initialize_cluster # Kill all plasma stores on worker nodes. - _test_component_failed(ray.services.PROCESS_TYPE_PLASMA_STORE) + _test_component_failed(cluster, ray.node.PROCESS_TYPE_PLASMA_STORE) # No processes should be left alive on the worker nodes. - check_components_alive(ray.services.PROCESS_TYPE_PLASMA_STORE, False) - check_components_alive(ray.services.PROCESS_TYPE_RAYLET, False) + check_components_alive(cluster, ray.node.PROCESS_TYPE_PLASMA_STORE, False) + check_components_alive(cluster, ray.node.PROCESS_TYPE_RAYLET, False) def test_actor_creation_node_failure(ray_start_cluster): @@ -406,42 +431,39 @@ def test_actor_creation_node_failure(ray_start_cluster): @pytest.mark.skipif( os.environ.get("RAY_USE_NEW_GCS") == "on", reason="Hanging with new GCS API.") -def test_driver_lives_sequential(): - ray.worker.init() - all_processes = ray.services.all_processes - processes = (all_processes[ray.services.PROCESS_TYPE_PLASMA_STORE] + - all_processes[ray.services.PROCESS_TYPE_RAYLET]) +def test_driver_lives_sequential(shutdown_only): + ray.init(num_cpus=1) + ray.worker._global_node.kill_raylet() + ray.worker._global_node.kill_plasma_store() + ray.worker._global_node.kill_log_monitor() + ray.worker._global_node.kill_monitor() + ray.worker._global_node.kill_raylet_monitor() - # Kill all the components sequentially. - for process in processes: - process.terminate() - time.sleep(0.1) - process.kill() - process.wait() - - ray.shutdown() # If the driver can reach the tearDown method, then it is still alive. 
@pytest.mark.skipif( os.environ.get("RAY_USE_NEW_GCS") == "on", reason="Hanging with new GCS API.") -def test_driver_lives_parallel(): - ray.worker.init() - all_processes = ray.services.all_processes - processes = (all_processes[ray.services.PROCESS_TYPE_PLASMA_STORE] + - all_processes[ray.services.PROCESS_TYPE_RAYLET]) +def test_driver_lives_parallel(shutdown_only): + ray.init(num_cpus=1) + all_processes = ray.worker._global_node.all_processes + process_infos = (all_processes[ray.node.PROCESS_TYPE_PLASMA_STORE] + + all_processes[ray.node.PROCESS_TYPE_RAYLET] + + all_processes[ray.node.PROCESS_TYPE_LOG_MONITOR] + + all_processes[ray.node.PROCESS_TYPE_MONITOR] + + all_processes[ray.node.PROCESS_TYPE_RAYLET_MONITOR]) + assert len(process_infos) == 5 # Kill all the components in parallel. - for process in processes: - process.terminate() + for process_info in process_infos: + process_info.process.terminate() time.sleep(0.1) - for process in processes: - process.kill() + for process_info in process_infos: + process_info.process.kill() - for process in processes: - process.wait() + for process_info in process_infos: + process_info.process.wait() # If the driver can reach the tearDown method, then it is still alive. - ray.shutdown() diff --git a/test/failure_test.py b/test/failure_test.py index a38055786..93e6808bd 100644 --- a/test/failure_test.py +++ b/test/failure_test.py @@ -670,7 +670,7 @@ def test_raylet_crash_when_get(ray_start_regular): def sleep_to_kill_raylet(): # Don't kill raylet before default workers get connected. 
time.sleep(2) - ray.services.all_processes[ray.services.PROCESS_TYPE_RAYLET][0].kill() + ray.worker._global_node.kill_raylet() thread = threading.Thread(target=sleep_to_kill_raylet) thread.start() diff --git a/test/multi_node_test.py b/test/multi_node_test.py index e323751b2..3986aa02b 100644 --- a/test/multi_node_test.py +++ b/test/multi_node_test.py @@ -181,7 +181,6 @@ print("success") out = run_string_as_driver(driver_script2) # Make sure the first driver ran to completion. assert "success" in out - assert ray.services.all_processes_alive() @pytest.fixture diff --git a/test/multi_node_test_2.py b/test/multi_node_test_2.py index 85de1fe31..f0334f66c 100644 --- a/test/multi_node_test_2.py +++ b/test/multi_node_test_2.py @@ -8,7 +8,6 @@ import pytest import time import ray -import ray.services as services from ray.test.cluster_utils import Cluster logger = logging.getLogger(__name__) @@ -55,8 +54,8 @@ def test_cluster(): g = Cluster(initialize_head=False) node = g.add_node() node2 = g.add_node() - assert node.all_processes_alive() - assert node2.all_processes_alive() + assert node.remaining_processes_alive() + assert node2.remaining_processes_alive() g.remove_node(node2) g.remove_node(node) assert not any(n.any_processes_alive() for n in [node, node2]) @@ -117,5 +116,5 @@ def test_worker_plasma_store_failure(start_connected_cluster): # Log monitor doesn't die for some reason worker.kill_log_monitor() worker.kill_plasma_store() - worker.process_dict[services.PROCESS_TYPE_RAYLET][0].wait() + worker.all_processes[ray.node.PROCESS_TYPE_RAYLET][0].process.wait() assert not worker.any_processes_alive(), worker.live_processes() diff --git a/test/runtest.py b/test/runtest.py index 080fef4b1..16a68f98e 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -1242,7 +1242,7 @@ def test_multithreading(shutdown_only): ready, _ = ray.wait( objects, num_returns=len(objects), - timeout=1000, + timeout=1000.0, ) assert len(ready) == num_wait_objects assert ray.get(ready) == 
list(range(num_wait_objects)) @@ -1273,7 +1273,7 @@ def test_multithreading(shutdown_only): ready, _ = ray.wait( wait_objects, num_returns=len(wait_objects), - timeout=1000, + timeout=1000.0, ) assert len(ready) == len(wait_objects) for _ in range(50): diff --git a/test/stress_tests.py b/test/stress_tests.py index fe40074dd..8c4f20d98 100644 --- a/test/stress_tests.py +++ b/test/stress_tests.py @@ -48,13 +48,15 @@ def ray_start_combination(request): cluster.add_node(num_cpus=10) ray.init(redis_address=cluster.redis_address) - yield num_nodes, num_workers_per_scheduler + yield num_nodes, num_workers_per_scheduler, cluster # The code after the yield will run as teardown code. ray.shutdown() cluster.shutdown() def test_submitting_tasks(ray_start_combination): + _, _, cluster = ray_start_combination + @ray.remote def f(x): return x @@ -71,10 +73,12 @@ def test_submitting_tasks(ray_start_combination): for _ in range(1000): ray.get([f.remote(1) for _ in range(1)]) - assert ray.services.all_processes_alive() + assert cluster.remaining_processes_alive() def test_dependencies(ray_start_combination): + _, _, cluster = ray_start_combination + @ray.remote def f(x): return x @@ -94,7 +98,7 @@ def test_dependencies(ray_start_combination): xs.append(g.remote(1)) ray.get(xs) - assert ray.services.all_processes_alive() + assert cluster.remaining_processes_alive() def test_submitting_many_tasks(ray_start_sharded): @@ -109,7 +113,7 @@ def test_submitting_many_tasks(ray_start_sharded): return x ray.get([g(1000) for _ in range(100)]) - assert ray.services.all_processes_alive() + assert ray.services.remaining_processes_alive() def test_submitting_many_actors_to_one(ray_start_sharded): @@ -147,7 +151,7 @@ def test_getting_and_putting(ray_start_sharded): for _ in range(1000): ray.get(x_id) - assert ray.services.all_processes_alive() + assert ray.services.remaining_processes_alive() def test_getting_many_objects(ray_start_sharded): @@ -159,11 +163,11 @@ def 
test_getting_many_objects(ray_start_sharded): lst = ray.get([f.remote() for _ in range(n)]) assert lst == n * [1] - assert ray.services.all_processes_alive() + assert ray.services.remaining_processes_alive() def test_wait(ray_start_combination): - num_nodes, num_workers_per_scheduler = ray_start_combination + num_nodes, num_workers_per_scheduler, cluster = ray_start_combination num_workers = num_nodes * num_workers_per_scheduler @ray.remote @@ -186,7 +190,7 @@ def test_wait(ray_start_combination): ] ray.wait(x_ids, num_returns=len(x_ids)) - assert ray.services.all_processes_alive() + assert cluster.remaining_processes_alive() @pytest.fixture(params=[1, 4]) @@ -263,8 +267,7 @@ def test_simple(ray_start_reconstruction): values = ray.get(args[i * chunk:(i + 1) * chunk]) del values - for node in cluster.list_all_nodes(): - assert node.all_processes_alive() + assert cluster.remaining_processes_alive() def sorted_random_indexes(total, output_num): @@ -328,8 +331,7 @@ def test_recursive(ray_start_reconstruction): values = ray.get(args[i * chunk:(i + 1) * chunk]) del values - for node in cluster.list_all_nodes(): - assert node.all_processes_alive() + assert cluster.remaining_processes_alive() @pytest.mark.skip(reason="This test often hangs or fails in CI.") @@ -386,8 +388,7 @@ def test_multiple_recursive(ray_start_reconstruction): value = ray.get(args[i]) assert value[0] == i - for node in cluster.list_all_nodes(): - assert node.all_processes_alive() + assert cluster.remaining_processes_alive() def wait_for_errors(error_check): @@ -472,8 +473,7 @@ def test_nondeterministic_task(ray_start_reconstruction): assert all(error["type"] == ray_constants.HASH_MISMATCH_PUSH_ERROR for error in errors) - for node in cluster.list_all_nodes(): - assert node.all_processes_alive() + assert cluster.remaining_processes_alive() @pytest.fixture diff --git a/test/tempfile_test.py b/test/tempfile_test.py index d931a9fe7..17d710403 100644 --- a/test/tempfile_test.py +++ b/test/tempfile_test.py 
@@ -72,8 +72,8 @@ def test_raylet_tempfiles(): assert log_files == { "log_monitor.out", "log_monitor.err", "plasma_store.out", "plasma_store.err", "webui.out", "webui.err", "monitor.out", - "monitor.err", "redis-shard_0.out", "redis-shard_0.err", "redis.out", - "redis.err" + "monitor.err", "raylet_monitor.out", "raylet_monitor.err", + "redis-shard_0.out", "redis-shard_0.err", "redis.out", "redis.err" } # without raylet logs socket_files = set(os.listdir(tempfile_services.get_sockets_dir_path())) assert socket_files == {"plasma_store", "raylet"} @@ -86,8 +86,9 @@ def test_raylet_tempfiles(): assert log_files == { "log_monitor.out", "log_monitor.err", "plasma_store.out", "plasma_store.err", "webui.out", "webui.err", "monitor.out", - "monitor.err", "redis-shard_0.out", "redis-shard_0.err", "redis.out", - "redis.err", "raylet.out", "raylet.err" + "monitor.err", "raylet_monitor.out", "raylet_monitor.err", + "redis-shard_0.out", "redis-shard_0.err", "redis.out", "redis.err", + "raylet.out", "raylet.err" } # with raylet logs socket_files = set(os.listdir(tempfile_services.get_sockets_dir_path())) assert socket_files == {"plasma_store", "raylet"} @@ -101,8 +102,9 @@ def test_raylet_tempfiles(): assert log_files.issuperset({ "log_monitor.out", "log_monitor.err", "plasma_store.out", "plasma_store.err", "webui.out", "webui.err", "monitor.out", - "monitor.err", "redis-shard_0.out", "redis-shard_0.err", "redis.out", - "redis.err", "raylet.out", "raylet.err" + "monitor.err", "raylet_monitor.out", "raylet_monitor.err", + "redis-shard_0.out", "redis-shard_0.err", "redis.out", "redis.err", + "raylet.out", "raylet.err" }) # with raylet logs # Check numbers of worker log file.