diff --git a/python/ray/actor.py b/python/ray/actor.py index 06cdc3f94..1501b5b11 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -15,7 +15,8 @@ import ray.local_scheduler import ray.signature as signature import ray.worker from ray.utils import (binary_to_hex, FunctionProperties, random_string, - release_gpus_in_use, select_local_scheduler, is_cython) + release_gpus_in_use, select_local_scheduler, is_cython, + push_error_to_driver) def random_actor_id(): @@ -252,9 +253,9 @@ def fetch_and_register_actor(actor_class_key, worker): # traceback and notify the scheduler of the failure. traceback_str = ray.worker.format_error_message(traceback.format_exc()) # Log the error message. - worker.push_error_to_driver(driver_id, "register_actor_signatures", - traceback_str, - data={"actor_id": actor_id_str}) + push_error_to_driver(worker.redis_client, "register_actor_signatures", + traceback_str, driver_id, + data={"actor_id": actor_id_str}) # TODO(rkn): In the future, it might make sense to have the worker exit # here. However, currently that would lead to hanging if someone calls # ray.get on a method invoked on the actor. diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index 26913ef0c..5dd484174 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -3,16 +3,12 @@ from __future__ import division from __future__ import print_function import click -import redis import subprocess import ray.services as services -def check_no_existing_redis_clients(node_ip_address, redis_address): - redis_ip_address, redis_port = redis_address.split(":") - redis_client = redis.StrictRedis(host=redis_ip_address, - port=int(redis_port)) +def check_no_existing_redis_clients(node_ip_address, redis_client): # The client table prefix must be kept in sync with the file # "src/common/redis_module/ray_redis_module.cc" where it is defined. REDIS_CLIENT_TABLE_PREFIX = "CL:" @@ -158,9 +154,18 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, raise Exception("If --head is not passed in, the --no-ui flag is " "not relevant.") redis_ip_address, redis_port = redis_address.split(":") + # Wait for the Redis server to be started. And throw an exception if we # can't connect to it. services.wait_for_redis_to_start(redis_ip_address, int(redis_port)) + + # Create a Redis client. + redis_client = services.create_redis_client(redis_address) + + # Check that the verion information on this node matches the version + # information that the cluster was started with. + services.check_version_info(redis_client) + # Get the node IP address if one is not provided. if node_ip_address is None: node_ip_address = services.get_node_ip_address(redis_address) @@ -168,7 +173,7 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, # Check that there aren't already Redis clients with the same IP # address connected with this Redis instance. This raises an exception # if the Redis server already has clients on this node. - check_no_existing_redis_clients(node_ip_address, redis_address) + check_no_existing_redis_clients(node_ip_address, redis_client) address_info = services.start_ray_node( node_ip_address=node_ip_address, redis_address=redis_address, diff --git a/python/ray/services.py b/python/ray/services.py index 38e19d472..8871ab549 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -226,6 +226,21 @@ def record_log_files_in_redis(redis_address, node_ip_address, log_files): redis_client.rpush(log_file_list_key, log_file.name) +def create_redis_client(redis_address): + """Create a Redis client. + + Args: + The IP address and port of the Redis server. + + Returns: + A Redis client. + """ + redis_ip_address, redis_port = redis_address.split(":") + # For this command to work, some other client (on the same machine + # as Redis) must have run "CONFIG SET protected-mode no". + return redis.StrictRedis(host=redis_ip_address, port=int(redis_port)) + + def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5): """Wait for a Redis server to be available. diff --git a/python/ray/utils.py b/python/ray/utils.py index 446407fa9..7e91b8c88 100644 --- a/python/ray/utils.py +++ b/python/ray/utils.py @@ -11,6 +11,37 @@ import sys import ray.local_scheduler +ERROR_KEY_PREFIX = b"Error:" +DRIVER_ID_LENGTH = 20 + + +def _random_string(): + return np.random.bytes(20) + + +def push_error_to_driver(redis_client, error_type, message, driver_id=None, + data=None): + """Push an error message to the driver to be printed in the background. + + Args: + redis_client: The redis client to use. + error_type (str): The type of the error. + message (str): The message that will be printed in the background + on the driver. + driver_id: The ID of the driver to push the error message to. If this + is None, then the message will be pushed to all drivers. + data: This should be a dictionary mapping strings to strings. It + will be serialized with json and stored in Redis. + """ + if driver_id is None: + driver_id = DRIVER_ID_LENGTH * b"\x00" + error_key = ERROR_KEY_PREFIX + driver_id + b":" + _random_string() + data = {} if data is None else data + redis_client.hmset(error_key, {"type": error_type, + "message": message, + "data": data}) + redis_client.rpush("ErrorKeys", error_key) + def is_cython(obj): """Check if an object is a Cython function or method""" diff --git a/python/ray/worker.py b/python/ray/worker.py index 223b351a5..dff99df55 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -407,9 +407,10 @@ class Worker(object): "object store. This may be fine, or it " "may be a bug.") if not warning_sent: - self.push_error_to_driver(self.task_driver_id.id(), - "wait_for_class", - warning_message) + ray.utils.push_error_to_driver( + self.redis_client, "wait_for_class", + warning_message, + driver_id=self.task_driver_id.id()) warning_sent = True def get_object(self, object_ids): @@ -599,24 +600,6 @@ class Worker(object): # operations into a transaction (or by implementing a custom # command that does all three things). - def push_error_to_driver(self, driver_id, error_type, message, data=None): - """Push an error message to the driver to be printed in the background. - - Args: - driver_id: The ID of the driver to push the error message to. - error_type (str): The type of the error. - message (str): The message that will be printed in the background - on the driver. - data: This should be a dictionary mapping strings to strings. It - will be serialized with json and stored in Redis. - """ - error_key = ERROR_KEY_PREFIX + driver_id + b":" + random_string() - data = {} if data is None else data - self.redis_client.hmset(error_key, {"type": error_type, - "message": message, - "data": data}) - self.redis_client.rpush("ErrorKeys", error_key) - def _wait_for_function(self, function_id, driver_id, timeout=10): """Wait until the function to be executed is present on this worker. @@ -651,9 +634,10 @@ class Worker(object): "registered. You may have to restart " "Ray.") if not warning_sent: - self.push_error_to_driver(driver_id, - "wait_for_function", - warning_message) + ray.utils.push_error_to_driver(self.redis_client, + "wait_for_function", + warning_message, + driver_id=driver_id) warning_sent = True time.sleep(0.001) @@ -808,10 +792,12 @@ class Worker(object): range(len(return_object_ids))] self._store_outputs_in_objstore(return_object_ids, failure_objects) # Log the error message. - self.push_error_to_driver(self.task_driver_id.id(), "task", - str(failure_object), - data={"function_id": function_id.id(), - "function_name": function_name}) + ray.utils.push_error_to_driver(self.redis_client, + "task", + str(failure_object), + driver_id=self.task_driver_id.id(), + data={"function_id": function_id.id(), + "function_name": function_name}) def _wait_for_and_process_task(self, task): """Wait for a task to be ready and process the task. @@ -1552,10 +1538,12 @@ def fetch_and_register_remote_function(key, worker=global_worker): # record the traceback and notify the scheduler of the failure. traceback_str = format_error_message(traceback.format_exc()) # Log the error message. - worker.push_error_to_driver(driver_id, "register_remote_function", - traceback_str, - data={"function_id": function_id.id(), - "function_name": function_name}) + ray.utils.push_error_to_driver(worker.redis_client, + "register_remote_function", + traceback_str, + driver_id=driver_id, + data={"function_id": function_id.id(), + "function_name": function_name}) else: # TODO(rkn): Why is the below line necessary? function.__module__ = module @@ -1582,8 +1570,11 @@ def fetch_and_execute_function_to_run(key, worker=global_worker): # Log the error message. name = function.__name__ if ("function" in locals() and hasattr(function, "__name__")) else "" - worker.push_error_to_driver(driver_id, "function_to_run", - traceback_str, data={"name": name}) + ray.utils.push_error_to_driver(worker.redis_client, + "function_to_run", + traceback_str, + driver_id=driver_id, + data={"name": name}) def import_thread(worker, mode): @@ -1714,9 +1705,19 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker, worker.redis_client = redis.StrictRedis(host=redis_ip_address, port=int(redis_port)) - # Check that the version information matches the version information that - # the Ray cluster was started with. - ray.services.check_version_info(worker.redis_client) + # For driver's check that the version information matches the version + # information that the Ray cluster was started with. + try: + ray.services.check_version_info(worker.redis_client) + except Exception as e: + if mode in [SCRIPT_MODE, SILENT_MODE]: + raise e + elif mode == WORKER_MODE: + traceback_str = traceback.format_exc() + ray.utils.push_error_to_driver(worker.redis_client, + "version_mismatch", + traceback_str, + driver_id=None) worker.lock = threading.Lock() diff --git a/python/ray/workers/default_worker.py b/python/ray/workers/default_worker.py index 6e480ade5..01a45391d 100644 --- a/python/ray/workers/default_worker.py +++ b/python/ray/workers/default_worker.py @@ -4,8 +4,6 @@ from __future__ import print_function import argparse import binascii -import numpy as np -import redis import traceback import ray @@ -30,36 +28,6 @@ parser.add_argument("--reconstruct", action="store_true", "mode")) -def random_string(): - return np.random.bytes(20) - - -def create_redis_client(redis_address): - redis_ip_address, redis_port = redis_address.split(":") - # For this command to work, some other client (on the same machine - # as Redis) must have run "CONFIG SET protected-mode no". - return redis.StrictRedis(host=redis_ip_address, port=int(redis_port)) - - -def push_error_to_all_drivers(redis_client, message, error_type): - """Push an error message to all drivers. - - Args: - redis_client: The redis client to use. - message: The error message to push. - error_type: The type of the error. - """ - DRIVER_ID_LENGTH = 20 - # We use a driver ID of all zeros to push an error message to all - # drivers. - driver_id = DRIVER_ID_LENGTH * b"\x00" - error_key = b"Error:" + driver_id + b":" + random_string() - # Create a Redis client. - redis_client.hmset(error_key, {"type": error_type, - "message": message}) - redis_client.rpush("ErrorKeys", error_key) - - if __name__ == "__main__": args = parser.parse_args() @@ -80,13 +48,6 @@ if __name__ == "__main__": ray.worker.connect(info, mode=ray.WORKER_MODE, actor_id=actor_id) - try: - ray.services.check_version_info(ray.worker.global_worker.redis_client) - except Exception as e: - traceback_str = traceback.format_exc() - push_error_to_all_drivers(ray.worker.global_worker.redis_client, - traceback_str, "version_mismatch") - error_explanation = """ This error is unexpected and should not have happened. Somehow a worker crashed in an unanticipated way causing the main_loop to throw an exception, @@ -103,8 +64,9 @@ if __name__ == "__main__": except Exception as e: traceback_str = traceback.format_exc() + error_explanation # Create a Redis client. - redis_client = create_redis_client(args.redis_address) - push_error_to_all_drivers(redis_client, traceback_str, "worker_crash") + redis_client = ray.services.create_redis_client(args.redis_address) + ray.utils.push_error_to_driver(redis_client, "worker_crash", + traceback_str, driver_id=None) # TODO(rkn): Note that if the worker was in the middle of executing # a task, then any worker or driver that is blocking in a get call # and waiting for the output of that task will hang. We need to