From de200ff9120fdd313d94cdc5e600c4747e0eff38 Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Wed, 3 Aug 2016 16:47:13 -0700 Subject: [PATCH] better error messages when composing remote functions (#339) Better error messages when composing remote functions --- lib/python/ray/worker.py | 232 +++++++++++++++++++++++++++++---------- 1 file changed, 175 insertions(+), 57 deletions(-) diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index 593f91581..c750327bd 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -28,61 +28,168 @@ WORKER_MODE = 1 PYTHON_MODE = 2 SILENT_MODE = 3 # This is only used during testing. -class RayFailedObject(object): +class RayTaskError(Exception): """An object used internally to represent a task that threw an exception. - If a task throws an exception during execution, a RayFailedObject is stored in - the object store for each of the tasks outputs. When an object is retrieved - from the object store, the Python method that retrieved it should check to see - if the object is a RayFailedObject and if it is then an exception should be - thrown containing the error message. + If a task throws an exception during execution, a RayTaskError is stored in + the object store for each of the task's outputs. When an object is retrieved + from the object store, the Python method that retrieved it checks to see if + the object is a RayTaskError and if it is then an exceptionis thrown + propagating the error message. - Attributes - error_message (str): The error message raised by the task that failed. + Currently, we either use the exception attribute or the traceback attribute + but not both. + + Attributes: + function_name (str): The name of the function that failed and produced the + RayTaskError. + exception (Exception): The exception object thrown by the failed task. + traceback_str (str): The traceback from the exception. """ - def __init__(self, error_message): - """Initialize a RayFailedObject. - - Args: - error_message (str): The error message raised by the task for which a - RayFailedObject is being created. - """ - self.error_message = error_message + def __init__(self, function_name, exception, traceback_str): + """Initialize a RayTaskError.""" + self.function_name = function_name + if isinstance(exception, RayGetError) or isinstance(exception, RayGetArgumentError) or isinstance(exception, RayGetArgumentTypeError): + self.exception = exception + else: + self.exception = None + self.traceback_str = traceback_str @staticmethod def deserialize(primitives): - """Create a RayFailedObject from a primitive object. - - This initializes a RayFailedObject from a primitive object created by the - serialize method. This method is required in order for Ray to serialize - custom Python classes. - - Note: - This method should not be called by users. - - Args: - primitives (str): The object's error message. - """ - return RayFailedObject(primitives) + """Create a RayTaskError from a primitive object.""" + function_name, exception, traceback_str = primitives + if exception[0] == "RayGetError": + exception = RayGetError.deserialize(exception[1]) + elif exception[0] == "RayGetArgumentError": + exception = RayGetArgumentError.deserialize(exception[1]) + elif exception[0] == "RayGetArgumentTypeError": + exception = RayGetArgumentTypeError.deserialize(exception[1]) + elif exception[0] == "None": + exception = None + else: + assert False, "This code should be unreachable." + return RayTaskError(function_name, exception, traceback_str) def serialize(self): - """Turn a RayFailedObject into a primitive object. + """Turn a RayTaskError into a primitive object.""" + if isinstance(self.exception, RayGetError): + serialized_exception = ("RayGetError", self.exception.serialize()) + elif isinstance(self.exception, RayGetArgumentError): + serialized_exception = ("RayGetArgumentError", self.exception.serialize()) + elif isinstance(self.exception, RayGetArgumentTypeError): + serialized_exception = ("RayGetArgumentTypeError", self.exception.serialize()) + elif self.exception is None: + serialized_exception = ("None",) + else: + assert False, "This code should be unreachable." + return (self.function_name, serialized_exception, self.traceback_str) - This method is required in order for Ray to serialize - custom Python classes. + def __str__(self): + """Format a RayTaskError as a string.""" + if self.traceback_str is None: + # This path is taken if getting the task arguments failed. + return "Remote function {}{}{} failed with:\n\n{}".format(colorama.Fore.RED, self.function_name, colorama.Fore.RESET, self.exception) + else: + # This path is taken if the task execution failed. + return "Remote function {}{}{} failed with:\n\n{}".format(colorama.Fore.RED, self.function_name, colorama.Fore.RESET, self.traceback_str) - Note: - The output of this method should only be used by the deserialize method. - This method should not be called by users. +class RayGetError(Exception): + """An exception used when get is called on an output of a failed task. - Args: - primitives (str): The object's error message. + Attributes: + objectid (lib.ObjectID): The ObjectID that get was called on. + task_error (RayTaskError): The RayTaskError object created by the failed + task. + """ - Returns: - A primitive representation of a RayFailedObject. - """ - return self.error_message + def __init__(self, objectid, task_error): + """Initialize a RayGetError object.""" + self.objectid = objectid + self.task_error = task_error + + @staticmethod + def deserialize(primitives): + """Create a RayGetError from a primitive object.""" + objectid, task_error = primitives + return RayGetError(objectid, RayTaskError.deserialize(task_error)) + + def serialize(self): + """Turn a RayGetError into a primitive object.""" + return (self.objectid, self.task_error.serialize()) + + def __str__(self): + """Format a RayGetError as a string.""" + return "Could not get objectid {}. It was created by remote function {}{}{} which failed with:\n\n{}".format(self.objectid, colorama.Fore.RED, self.task_error.function_name, colorama.Fore.RESET, self.task_error) + +class RayGetArgumentError(Exception): + """An exception used when a task's argument was produced by a failed task. + + Attributes: + argument_index (int): The index (zero indexed) of the failed argument in + present task's remote function call. + function_name (str): The name of the function for the current task. + objectid (lib.ObjectID): The ObjectID that was passed in as the argument. + task_error (RayTaskError): The RayTaskError object created by the failed + task. + """ + + def __init__(self, function_name, argument_index, objectid, task_error): + """Initialize a RayGetArgumentError object.""" + self.argument_index = argument_index + self.function_name = function_name + self.objectid = objectid + self.task_error = task_error + + @staticmethod + def deserialize(primitives): + """Create a RayGetArgumentError from a primitive object.""" + function_name, argument_index, objectid, task_error = primitives + return RayGetArgumentError(function_name, argument_index, objectid, RayTaskError.deserialize(task_error)) + + def serialize(self): + """Turn a RayGetArgumentError into a primitive object.""" + return (self.function_name, self.argument_index, self.objectid, self.task_error.serialize()) + + def __str__(self): + """Format a RayGetArgumentError as a string.""" + return "Failed to get objectid {} as argument {} for remote function {}{}{}. It was created by remote function {}{}{} which failed with:\n{}".format(self.objectid, self.argument_index, colorama.Fore.RED, self.function_name, colorama.Fore.RESET, colorama.Fore.RED, self.task_error.function_name, colorama.Fore.RESET, self.task_error) + +class RayGetArgumentTypeError(Exception): + """An exception used when a task's argument doesn't type check. + + Attributes: + function_name (str): The name of the function for the current task. + argument_index (int): The index (zero indexed) of the argument in the + present task's remote function call. + received_type: The type of the argument that was passed in. + expected_type: The type that was expected. This is determined by the remote + decorator. + """ + + def __init__(self, function_name, argument_index, received_type, expected_type): + """Initialize a RayGetArgumentTypeError object.""" + self.function_name = function_name + self.argument_index = argument_index + # TODO(rkn): when we support the serialization of types, then we should + # remove the string conversions below. + self.received_type = str(received_type) + self.expected_type = str(expected_type) + + @staticmethod + def deserialize(primitives): + """Create a RayGetArgumentTypeError from a primitive object.""" + function_name, argument_index, received_type, expected_type = primitives + return RayGetArgumentTypeError(function_name, argument_index, received_type, expected_type) + + def serialize(self): + """Turn a RayGetArgumentTypeError into a primitive object.""" + return (self.function_name, self.argument_index, self.received_type, self.expected_type) + + def __str__(self): + """Format a RayGetArgumentTypeError as a string.""" + return "Argument {} for remote function {}{}{} has type {} but an argument of type {} was expected.".format(self.argument_index, colorama.Fore.RED, self.function_name, colorama.Fore.RESET, self.received_type, self.expected_type) class RayDealloc(object): """An object used internally to properly implement reference counting. @@ -689,8 +796,10 @@ def get(objectid, worker=global_worker): if worker.mode == SCRIPT_MODE: worker.print_new_failures() value = worker.get_object(objectid) - if isinstance(value, RayFailedObject): - raise Exception("The task that created this object ID failed with error message:\n{}".format(value.error_message)) + if isinstance(value, RayTaskError): + # If the result is a RayTaskError, then the task that created this object + # failed, and we should propagate the error message here. + raise RayGetError(objectid, value) return value def put(value, worker=global_worker): @@ -749,7 +858,7 @@ def restart_workers_local(num_workers, worker_path, worker=global_worker): def format_error_message(exception_message): """Improve the formatting of an exception thrown by a remote function. - This method takes an backtrace from an exception and makes it nicer by + This method takes a traceback from an exception and makes it nicer by removing a few uninformative lines and adding some space to indent the remaining lines nicely. @@ -763,7 +872,6 @@ def format_error_message(exception_message): # Remove lines 1, 2, 3, and 4, which are always the same, they just contain # information about the main loop. lines = lines[0:1] + lines[5:] - lines = [10 * " " + line for line in lines] return "\n".join(lines) def main_loop(worker=global_worker): @@ -778,7 +886,7 @@ def main_loop(worker=global_worker): If the process of getting the arguments for execution (which does some type checking) or the process of executing the task fail, then the main loop will - catch the exception and store RayFailedObject objects containing the relevant + catch the exception and store RayTaskError objects containing the relevant error messages in the object store in place of the actual outputs. These objects are used to propagate the error messages. """ @@ -792,14 +900,16 @@ def main_loop(worker=global_worker): outputs = worker.functions[func_name].executor(arguments) # execute the function if len(return_objectids) == 1: outputs = (outputs,) - except Exception: - exception_message = format_error_message(traceback.format_exc()) - # Here we are storing RayFailedObjects in the object store to indicate - # failure (this is only interpreted by the worker). - failure_objects = [RayFailedObject(exception_message) for _ in range(len(return_objectids))] + except Exception as e: + # If the task threw an exception, then record the traceback. We determine + # whether the exception was thrown in the task execution by whether the + # variable "arguments" is defined. + traceback_str = format_error_message(traceback.format_exc()) if "arguments" in locals() else None + failure_object = RayTaskError(func_name, e, traceback_str) + failure_objects = [failure_object for _ in range(len(return_objectids))] store_outputs_in_objstore(return_objectids, failure_objects, worker) - raylib.notify_task_completed(worker.handle, False, exception_message) # notify the scheduler that the task threw an exception - _logger().info("Worker threw exception with message: \n\n{}\n, while running function {}.".format(exception_message, func_name)) + raylib.notify_task_completed(worker.handle, False, str(failure_object)) + _logger().info("Worker threw exception with message: \n\n{}\n, while running function {}.".format(str(failure_object), func_name)) else: store_outputs_in_objstore(return_objectids, outputs, worker) # store output in local object store raylib.notify_task_completed(worker.handle, True, "") # notify the scheduler that the task completed successfully @@ -1013,7 +1123,8 @@ def typecheck_arg(arg, expected_type, i, name): name (str): The name of the function. Raises: - Exception: An exception is raised if arg does not have the expected type. + RayGetArgumentTypeError: An exception is raised if arg does not have the + expected type. """ if issubclass(type(arg), expected_type): # Passed the type-checck @@ -1023,7 +1134,7 @@ def typecheck_arg(arg, expected_type, i, name): # TODO(mehrdadn): Should long really be convertible to int? pass else: - raise Exception("Argument {} for function {} has type {} but an argument of type {} was expected.".format(i, name, type(arg), expected_type)) + raise RayGetArgumentTypeError(name, i, type(arg), expected_type) def check_arguments(arg_types, has_vararg_param, name, args): """Check that the arguments to the remote function have the right types. @@ -1080,7 +1191,10 @@ def get_arguments_for_execution(function, args, worker=global_worker): value. Raises: - Exception: An exception is raised the args do not all have the right types. + RayGetArgumentError: This exception is raised if a task that created one of + the arguments failed. + RayGetArgumentTypeError: This exception is raised (via typecheck_arg) if one + of the arguments does not have the expected type. """ # TODO(rkn): Eventually, all of the type checking can be put in `check_arguments` above so that the error will happen immediately when calling a remote function. arguments = [] @@ -1102,12 +1216,16 @@ def get_arguments_for_execution(function, args, worker=global_worker): # get the object from the local object store _logger().info("Getting argument {} for function {}.".format(i, function.__name__)) argument = worker.get_object(arg) + if isinstance(argument, RayTaskError): + # If the result is a RayTaskError, then the task that created this + # object failed, and we should propagate the error message here. + raise RayGetArgumentError(function.__name__, i, arg, argument) _logger().info("Successfully retrieved argument {} for function {}.".format(i, function.__name__)) else: # pass the argument by value argument = arg - typecheck_arg(argument, expected_type, i, function) + typecheck_arg(argument, expected_type, i, function.__name__) arguments.append(argument) return arguments