diff --git a/python/ray/actor.py b/python/ray/actor.py index 926f15b29..41c0d8205 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -668,6 +668,16 @@ class ActorHandle(object): # there are ANY handles in scope in the process that created the actor, # not just the first one. worker = ray.worker.get_global_worker() + if (worker.mode == ray.worker.SCRIPT_MODE + and self._ray_actor_driver_id.id() != worker.worker_id): + # If the worker is a driver and driver id has changed because + # Ray was shut down re-initialized, the actor is already cleaned up + # and we don't need to send `__ray_terminate__` again. + logger.warn( + "Actor is garbage collected in the wrong driver." + + " Actor id = %s, class name = %s.", self._ray_actor_id, + self._ray_class_name) + return if worker.connected and self._ray_original_handle: # TODO(rkn): Should we be passing in the actor cursor as a # dependency here? diff --git a/python/ray/worker.py b/python/ray/worker.py index de2513780..0393ccc23 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -13,6 +13,7 @@ import numpy as np import os import redis import signal +from six.moves import queue import sys import threading import time @@ -97,82 +98,34 @@ class RayTaskError(Exception): traceback_str (str): The traceback from the exception. """ - def __init__(self, function_name, exception, traceback_str): + def __init__(self, function_name, traceback_str): """Initialize a RayTaskError.""" - self.function_name = function_name - if (isinstance(exception, RayGetError) - or isinstance(exception, RayGetArgumentError)): - self.exception = exception + if setproctitle: + self.proctitle = setproctitle.getproctitle() else: - self.exception = None + self.proctitle = "ray_worker" + self.pid = os.getpid() + self.host = os.uname()[1] + self.function_name = function_name self.traceback_str = traceback_str def __str__(self): """Format a RayTaskError as a string.""" - if self.traceback_str is None: - # This path is taken if getting the task arguments failed. - return ("Remote function {}{}{} failed with:\n\n{}".format( - colorama.Fore.RED, self.function_name, colorama.Fore.RESET, - self.exception)) - else: - # This path is taken if the task execution failed. - return ("Remote function {}{}{} failed with:\n\n{}".format( - colorama.Fore.RED, self.function_name, colorama.Fore.RESET, - self.traceback_str)) - - -class RayGetError(Exception): - """An exception used when get is called on an output of a failed task. - - Attributes: - objectid (lib.ObjectID): The ObjectID that get was called on. - task_error (RayTaskError): The RayTaskError object created by the - failed task. - """ - - def __init__(self, objectid, task_error): - """Initialize a RayGetError object.""" - self.objectid = objectid - self.task_error = task_error - - def __str__(self): - """Format a RayGetError as a string.""" - return ("Could not get objectid {}. It was created by remote function " - "{}{}{} which failed with:\n\n{}".format( - self.objectid, colorama.Fore.RED, - self.task_error.function_name, colorama.Fore.RESET, - self.task_error)) - - -class RayGetArgumentError(Exception): - """An exception used when a task's argument was produced by a failed task. - - Attributes: - argument_index (int): The index (zero indexed) of the failed argument - in present task's remote function call. - function_name (str): The name of the function for the current task. - objectid (lib.ObjectID): The ObjectID that was passed in as the - argument. - task_error (RayTaskError): The RayTaskError object created by the - failed task. - """ - - def __init__(self, function_name, argument_index, objectid, task_error): - """Initialize a RayGetArgumentError object.""" - self.argument_index = argument_index - self.function_name = function_name - self.objectid = objectid - self.task_error = task_error - - def __str__(self): - """Format a RayGetArgumentError as a string.""" - return ("Failed to get objectid {} as argument {} for remote function " - "{}{}{}. It was created by remote function {}{}{} which " - "failed with:\n{}".format( - self.objectid, self.argument_index, colorama.Fore.RED, - self.function_name, colorama.Fore.RESET, colorama.Fore.RED, - self.task_error.function_name, colorama.Fore.RESET, - self.task_error)) + lines = self.traceback_str.split("\n") + out = [] + in_worker = False + for line in lines: + if line.startswith("Traceback "): + out.append("{}{}{} (pid={}, host={})".format( + colorama.Fore.CYAN, self.proctitle, colorama.Fore.RESET, + self.pid, self.host)) + elif in_worker: + in_worker = False + elif "ray/worker.py" in line or "ray/function_manager.py" in line: + in_worker = True + else: + out.append(line) + return "\n".join(out) class Worker(object): @@ -449,7 +402,7 @@ class Worker(object): # TODO(ekl): the local scheduler could include relevant # metadata in the task kill case for a better error message invalid_error = RayTaskError( - "", None, + "", "Invalid return value: likely worker died or was killed " "while executing the task; check previous logs or dmesg " "for errors.") @@ -757,7 +710,7 @@ class Worker(object): passed by value. Raises: - RayGetArgumentError: This exception is raised if a task that + RayTaskError: This exception is raised if a task that created one of the arguments failed. """ arguments = [] @@ -766,10 +719,7 @@ class Worker(object): # get the object from the local object store argument = self.get_object([arg])[0] if isinstance(argument, RayTaskError): - # If the result is a RayTaskError, then the task that - # created this object failed, and we should propagate the - # error message here. - raise RayGetArgumentError(function_name, i, arg, argument) + raise argument else: # pass the argument by value argument = arg @@ -842,7 +792,7 @@ class Worker(object): with profiling.profile("task:deserialize_arguments", worker=self): arguments = self._get_arguments_for_execution( function_name, args) - except (RayGetError, RayGetArgumentError) as e: + except RayTaskError as e: self._handle_process_task_failure(function_id, function_name, return_object_ids, e, None) return @@ -889,7 +839,7 @@ class Worker(object): def _handle_process_task_failure(self, function_id, function_name, return_object_ids, error, backtrace): - failure_object = RayTaskError(function_name, error, backtrace) + failure_object = RayTaskError(function_name, backtrace) failure_objects = [ failure_object for _ in range(len(return_object_ids)) ] @@ -1196,18 +1146,6 @@ def _initialize_serialization(driver_id, worker=global_worker): local=True, driver_id=driver_id, class_id="ray.RayTaskError") - register_custom_serializer( - RayGetError, - use_dict=True, - local=True, - driver_id=driver_id, - class_id="ray.RayGetError") - register_custom_serializer( - RayGetArgumentError, - use_dict=True, - local=True, - driver_id=driver_id, - class_id="ray.RayGetArgumentError") # Tell Ray to serialize lambdas with pickle. register_custom_serializer( type(lambda: 0), @@ -1833,12 +1771,38 @@ def custom_excepthook(type, value, tb): sys.excepthook = custom_excepthook +# The last time we raised a TaskError in this process. We use this value to +# suppress redundant error messages pushed from the workers. +last_task_error_raise_time = 0 -def print_error_messages_raylet(worker): - """Print error messages in the background on the driver. +# The max amount of seconds to wait before printing out an uncaught error. +UNCAUGHT_ERROR_GRACE_PERIOD = 5 - This runs in a separate thread on the driver and prints error messages in - the background. + +def print_error_messages_raylet(task_error_queue): + """Prints message received in the given output queue. + + This checks periodically if any un-raised errors occured in the background. + """ + + while True: + error, t = task_error_queue.get() + # Delay errors a little bit of time to attempt to suppress redundant + # messages originating from the worker. + while t + UNCAUGHT_ERROR_GRACE_PERIOD > time.time(): + time.sleep(1) + if t < last_task_error_raise_time + UNCAUGHT_ERROR_GRACE_PERIOD: + logger.debug("Suppressing error from worker: {}".format(error)) + else: + logger.error( + "Possible unhandled error from worker: {}".format(error)) + + +def listen_error_messages_raylet(worker, task_error_queue): + """Listen to error messages in the background on the driver. + + This runs in a separate thread on the driver and pushes (error, time) + tuples to the output queue. """ worker.error_message_pubsub_client = worker.redis_client.pubsub( ignore_subscribe_messages=True) @@ -1875,7 +1839,12 @@ def print_error_messages_raylet(worker): continue error_message = ray.utils.decode(error_data.ErrorMessage()) - logger.error(error_message) + if (ray.utils.decode( + error_data.Type()) == ray_constants.TASK_PUSH_ERROR): + # Delay it a bit to see if we can suppress it + task_error_queue.put((error_message, time.time())) + else: + logger.error(error_message) except redis.ConnectionError: # When Redis terminates the listen call will throw a ConnectionError, @@ -2164,14 +2133,19 @@ def connect(info, # temporarily using this implementation which constantly queries the # scheduler for new error messages. if mode == SCRIPT_MODE: - t = threading.Thread( + q = queue.Queue() + listener = threading.Thread( + target=listen_error_messages_raylet, + name="ray_listen_error_messages", + args=(worker, q)) + printer = threading.Thread( target=print_error_messages_raylet, name="ray_print_error_messages", - args=(worker, )) - # Making the thread a daemon causes it to exit when the main thread - # exits. - t.daemon = True - t.start() + args=(q, )) + listener.daemon = True + listener.start() + printer.daemon = True + printer.start() # If we are using the raylet code path and we are not in local mode, start # a background thread to periodically flush profiling data to the GCS. @@ -2399,11 +2373,13 @@ def get(object_ids, worker=global_worker): # In LOCAL_MODE, ray.get is the identity operation (the input will # actually be a value not an objectid). return object_ids + global last_task_error_raise_time if isinstance(object_ids, list): values = worker.get_object(object_ids) for i, value in enumerate(values): if isinstance(value, RayTaskError): - raise RayGetError(object_ids[i], value) + last_task_error_raise_time = time.time() + raise value return values else: value = worker.get_object([object_ids])[0] @@ -2411,7 +2387,8 @@ def get(object_ids, worker=global_worker): # If the result is a RayTaskError, then the task that created # this object failed, and we should propagate the error message # here. - raise RayGetError(object_ids, value) + last_task_error_raise_time = time.time() + raise value return value diff --git a/test/actor_test.py b/test/actor_test.py index a9e343563..bf1b98a13 100644 --- a/test/actor_test.py +++ b/test/actor_test.py @@ -1300,7 +1300,7 @@ def test_exception_raised_when_actor_node_dies(shutdown_only): # Submit some new actor tasks. x_ids = [actor.inc.remote() for _ in range(5)] for x_id in x_ids: - with pytest.raises(ray.worker.RayGetError): + with pytest.raises(ray.worker.RayTaskError): # There is some small chance that ray.get will actually # succeed (if the object is transferred before the raylet # dies). diff --git a/test/component_failures_test.py b/test/component_failures_test.py index fd09a1759..30071b3c1 100644 --- a/test/component_failures_test.py +++ b/test/component_failures_test.py @@ -408,7 +408,7 @@ def test_actor_creation_node_failure(ray_start_cluster): for i, out in enumerate(children_out): try: ray.get(out) - except ray.worker.RayGetError: + except ray.worker.RayTaskError: children[i] = Child.remote(death_probability) # Remove a node. Any actor creation tasks that were forwarded to this # node must be reconstructed.