mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 10:33:16 +08:00
better error messages when composing remote functions (#339)
Better error messages when composing remote functions
This commit is contained in:
committed by
Philipp Moritz
parent
07baf44f26
commit
de200ff912
+175
-57
@@ -28,61 +28,168 @@ WORKER_MODE = 1
|
||||
PYTHON_MODE = 2
|
||||
SILENT_MODE = 3 # This is only used during testing.
|
||||
|
||||
class RayFailedObject(object):
|
||||
class RayTaskError(Exception):
|
||||
"""An object used internally to represent a task that threw an exception.
|
||||
|
||||
If a task throws an exception during execution, a RayFailedObject is stored in
|
||||
the object store for each of the tasks outputs. When an object is retrieved
|
||||
from the object store, the Python method that retrieved it should check to see
|
||||
if the object is a RayFailedObject and if it is then an exception should be
|
||||
thrown containing the error message.
|
||||
If a task throws an exception during execution, a RayTaskError is stored in
|
||||
the object store for each of the task's outputs. When an object is retrieved
|
||||
from the object store, the Python method that retrieved it checks to see if
|
||||
the object is a RayTaskError and if it is then an exception is thrown
|
||||
propagating the error message.
|
||||
|
||||
Attributes
|
||||
error_message (str): The error message raised by the task that failed.
|
||||
Currently, we either use the exception attribute or the traceback attribute
|
||||
but not both.
|
||||
|
||||
Attributes:
|
||||
function_name (str): The name of the function that failed and produced the
|
||||
RayTaskError.
|
||||
exception (Exception): The exception object thrown by the failed task.
|
||||
traceback_str (str): The traceback from the exception.
|
||||
"""
|
||||
|
||||
def __init__(self, error_message):
|
||||
"""Initialize a RayFailedObject.
|
||||
|
||||
Args:
|
||||
error_message (str): The error message raised by the task for which a
|
||||
RayFailedObject is being created.
|
||||
"""
|
||||
self.error_message = error_message
|
||||
def __init__(self, function_name, exception, traceback_str):
|
||||
"""Initialize a RayTaskError."""
|
||||
self.function_name = function_name
|
||||
if isinstance(exception, RayGetError) or isinstance(exception, RayGetArgumentError) or isinstance(exception, RayGetArgumentTypeError):
|
||||
self.exception = exception
|
||||
else:
|
||||
self.exception = None
|
||||
self.traceback_str = traceback_str
|
||||
|
||||
@staticmethod
|
||||
def deserialize(primitives):
|
||||
"""Create a RayFailedObject from a primitive object.
|
||||
|
||||
This initializes a RayFailedObject from a primitive object created by the
|
||||
serialize method. This method is required in order for Ray to serialize
|
||||
custom Python classes.
|
||||
|
||||
Note:
|
||||
This method should not be called by users.
|
||||
|
||||
Args:
|
||||
primitives (str): The object's error message.
|
||||
"""
|
||||
return RayFailedObject(primitives)
|
||||
"""Create a RayTaskError from a primitive object."""
|
||||
function_name, exception, traceback_str = primitives
|
||||
if exception[0] == "RayGetError":
|
||||
exception = RayGetError.deserialize(exception[1])
|
||||
elif exception[0] == "RayGetArgumentError":
|
||||
exception = RayGetArgumentError.deserialize(exception[1])
|
||||
elif exception[0] == "RayGetArgumentTypeError":
|
||||
exception = RayGetArgumentTypeError.deserialize(exception[1])
|
||||
elif exception[0] == "None":
|
||||
exception = None
|
||||
else:
|
||||
assert False, "This code should be unreachable."
|
||||
return RayTaskError(function_name, exception, traceback_str)
|
||||
|
||||
def serialize(self):
|
||||
"""Turn a RayFailedObject into a primitive object.
|
||||
"""Turn a RayTaskError into a primitive object."""
|
||||
if isinstance(self.exception, RayGetError):
|
||||
serialized_exception = ("RayGetError", self.exception.serialize())
|
||||
elif isinstance(self.exception, RayGetArgumentError):
|
||||
serialized_exception = ("RayGetArgumentError", self.exception.serialize())
|
||||
elif isinstance(self.exception, RayGetArgumentTypeError):
|
||||
serialized_exception = ("RayGetArgumentTypeError", self.exception.serialize())
|
||||
elif self.exception is None:
|
||||
serialized_exception = ("None",)
|
||||
else:
|
||||
assert False, "This code should be unreachable."
|
||||
return (self.function_name, serialized_exception, self.traceback_str)
|
||||
|
||||
This method is required in order for Ray to serialize
|
||||
custom Python classes.
|
||||
def __str__(self):
|
||||
"""Format a RayTaskError as a string."""
|
||||
if self.traceback_str is None:
|
||||
# This path is taken if getting the task arguments failed.
|
||||
return "Remote function {}{}{} failed with:\n\n{}".format(colorama.Fore.RED, self.function_name, colorama.Fore.RESET, self.exception)
|
||||
else:
|
||||
# This path is taken if the task execution failed.
|
||||
return "Remote function {}{}{} failed with:\n\n{}".format(colorama.Fore.RED, self.function_name, colorama.Fore.RESET, self.traceback_str)
|
||||
|
||||
Note:
|
||||
The output of this method should only be used by the deserialize method.
|
||||
This method should not be called by users.
|
||||
class RayGetError(Exception):
|
||||
"""An exception used when get is called on an output of a failed task.
|
||||
|
||||
Args:
|
||||
primitives (str): The object's error message.
|
||||
Attributes:
|
||||
objectid (lib.ObjectID): The ObjectID that get was called on.
|
||||
task_error (RayTaskError): The RayTaskError object created by the failed
|
||||
task.
|
||||
"""
|
||||
|
||||
Returns:
|
||||
A primitive representation of a RayFailedObject.
|
||||
"""
|
||||
return self.error_message
|
||||
def __init__(self, objectid, task_error):
|
||||
"""Initialize a RayGetError object."""
|
||||
self.objectid = objectid
|
||||
self.task_error = task_error
|
||||
|
||||
@staticmethod
|
||||
def deserialize(primitives):
|
||||
"""Create a RayGetError from a primitive object."""
|
||||
objectid, task_error = primitives
|
||||
return RayGetError(objectid, RayTaskError.deserialize(task_error))
|
||||
|
||||
def serialize(self):
|
||||
"""Turn a RayGetError into a primitive object."""
|
||||
return (self.objectid, self.task_error.serialize())
|
||||
|
||||
def __str__(self):
|
||||
"""Format a RayGetError as a string."""
|
||||
return "Could not get objectid {}. It was created by remote function {}{}{} which failed with:\n\n{}".format(self.objectid, colorama.Fore.RED, self.task_error.function_name, colorama.Fore.RESET, self.task_error)
|
||||
|
||||
class RayGetArgumentError(Exception):
|
||||
"""An exception used when a task's argument was produced by a failed task.
|
||||
|
||||
Attributes:
|
||||
argument_index (int): The index (zero indexed) of the failed argument in the
|
||||
present task's remote function call.
|
||||
function_name (str): The name of the function for the current task.
|
||||
objectid (lib.ObjectID): The ObjectID that was passed in as the argument.
|
||||
task_error (RayTaskError): The RayTaskError object created by the failed
|
||||
task.
|
||||
"""
|
||||
|
||||
def __init__(self, function_name, argument_index, objectid, task_error):
|
||||
"""Initialize a RayGetArgumentError object."""
|
||||
self.argument_index = argument_index
|
||||
self.function_name = function_name
|
||||
self.objectid = objectid
|
||||
self.task_error = task_error
|
||||
|
||||
@staticmethod
|
||||
def deserialize(primitives):
|
||||
"""Create a RayGetArgumentError from a primitive object."""
|
||||
function_name, argument_index, objectid, task_error = primitives
|
||||
return RayGetArgumentError(function_name, argument_index, objectid, RayTaskError.deserialize(task_error))
|
||||
|
||||
def serialize(self):
|
||||
"""Turn a RayGetArgumentError into a primitive object."""
|
||||
return (self.function_name, self.argument_index, self.objectid, self.task_error.serialize())
|
||||
|
||||
def __str__(self):
|
||||
"""Format a RayGetArgumentError as a string."""
|
||||
return "Failed to get objectid {} as argument {} for remote function {}{}{}. It was created by remote function {}{}{} which failed with:\n{}".format(self.objectid, self.argument_index, colorama.Fore.RED, self.function_name, colorama.Fore.RESET, colorama.Fore.RED, self.task_error.function_name, colorama.Fore.RESET, self.task_error)
|
||||
|
||||
class RayGetArgumentTypeError(Exception):
|
||||
"""An exception used when a task's argument doesn't type check.
|
||||
|
||||
Attributes:
|
||||
function_name (str): The name of the function for the current task.
|
||||
argument_index (int): The index (zero indexed) of the argument in the
|
||||
present task's remote function call.
|
||||
received_type: The type of the argument that was passed in.
|
||||
expected_type: The type that was expected. This is determined by the remote
|
||||
decorator.
|
||||
"""
|
||||
|
||||
def __init__(self, function_name, argument_index, received_type, expected_type):
|
||||
"""Initialize a RayGetArgumentTypeError object."""
|
||||
self.function_name = function_name
|
||||
self.argument_index = argument_index
|
||||
# TODO(rkn): when we support the serialization of types, then we should
|
||||
# remove the string conversions below.
|
||||
self.received_type = str(received_type)
|
||||
self.expected_type = str(expected_type)
|
||||
|
||||
@staticmethod
|
||||
def deserialize(primitives):
|
||||
"""Create a RayGetArgumentTypeError from a primitive object."""
|
||||
function_name, argument_index, received_type, expected_type = primitives
|
||||
return RayGetArgumentTypeError(function_name, argument_index, received_type, expected_type)
|
||||
|
||||
def serialize(self):
|
||||
"""Turn a RayGetArgumentTypeError into a primitive object."""
|
||||
return (self.function_name, self.argument_index, self.received_type, self.expected_type)
|
||||
|
||||
def __str__(self):
|
||||
"""Format a RayGetArgumentTypeError as a string."""
|
||||
return "Argument {} for remote function {}{}{} has type {} but an argument of type {} was expected.".format(self.argument_index, colorama.Fore.RED, self.function_name, colorama.Fore.RESET, self.received_type, self.expected_type)
|
||||
|
||||
class RayDealloc(object):
|
||||
"""An object used internally to properly implement reference counting.
|
||||
@@ -689,8 +796,10 @@ def get(objectid, worker=global_worker):
|
||||
if worker.mode == SCRIPT_MODE:
|
||||
worker.print_new_failures()
|
||||
value = worker.get_object(objectid)
|
||||
if isinstance(value, RayFailedObject):
|
||||
raise Exception("The task that created this object ID failed with error message:\n{}".format(value.error_message))
|
||||
if isinstance(value, RayTaskError):
|
||||
# If the result is a RayTaskError, then the task that created this object
|
||||
# failed, and we should propagate the error message here.
|
||||
raise RayGetError(objectid, value)
|
||||
return value
|
||||
|
||||
def put(value, worker=global_worker):
|
||||
@@ -749,7 +858,7 @@ def restart_workers_local(num_workers, worker_path, worker=global_worker):
|
||||
def format_error_message(exception_message):
|
||||
"""Improve the formatting of an exception thrown by a remote function.
|
||||
|
||||
This method takes an backtrace from an exception and makes it nicer by
|
||||
This method takes a traceback from an exception and makes it nicer by
|
||||
removing a few uninformative lines and adding some space to indent the
|
||||
remaining lines nicely.
|
||||
|
||||
@@ -763,7 +872,6 @@ def format_error_message(exception_message):
|
||||
# Remove lines 1, 2, 3, and 4, which are always the same, they just contain
|
||||
# information about the main loop.
|
||||
lines = lines[0:1] + lines[5:]
|
||||
lines = [10 * " " + line for line in lines]
|
||||
return "\n".join(lines)
|
||||
|
||||
def main_loop(worker=global_worker):
|
||||
@@ -778,7 +886,7 @@ def main_loop(worker=global_worker):
|
||||
|
||||
If the process of getting the arguments for execution (which does some type
|
||||
checking) or the process of executing the task fail, then the main loop will
|
||||
catch the exception and store RayFailedObject objects containing the relevant
|
||||
catch the exception and store RayTaskError objects containing the relevant
|
||||
error messages in the object store in place of the actual outputs. These
|
||||
objects are used to propagate the error messages.
|
||||
"""
|
||||
@@ -792,14 +900,16 @@ def main_loop(worker=global_worker):
|
||||
outputs = worker.functions[func_name].executor(arguments) # execute the function
|
||||
if len(return_objectids) == 1:
|
||||
outputs = (outputs,)
|
||||
except Exception:
|
||||
exception_message = format_error_message(traceback.format_exc())
|
||||
# Here we are storing RayFailedObjects in the object store to indicate
|
||||
# failure (this is only interpreted by the worker).
|
||||
failure_objects = [RayFailedObject(exception_message) for _ in range(len(return_objectids))]
|
||||
except Exception as e:
|
||||
# If the task threw an exception, then record the traceback. We determine
|
||||
# whether the exception was thrown in the task execution by whether the
|
||||
# variable "arguments" is defined.
|
||||
traceback_str = format_error_message(traceback.format_exc()) if "arguments" in locals() else None
|
||||
failure_object = RayTaskError(func_name, e, traceback_str)
|
||||
failure_objects = [failure_object for _ in range(len(return_objectids))]
|
||||
store_outputs_in_objstore(return_objectids, failure_objects, worker)
|
||||
raylib.notify_task_completed(worker.handle, False, exception_message) # notify the scheduler that the task threw an exception
|
||||
_logger().info("Worker threw exception with message: \n\n{}\n, while running function {}.".format(exception_message, func_name))
|
||||
raylib.notify_task_completed(worker.handle, False, str(failure_object))
|
||||
_logger().info("Worker threw exception with message: \n\n{}\n, while running function {}.".format(str(failure_object), func_name))
|
||||
else:
|
||||
store_outputs_in_objstore(return_objectids, outputs, worker) # store output in local object store
|
||||
raylib.notify_task_completed(worker.handle, True, "") # notify the scheduler that the task completed successfully
|
||||
@@ -1013,7 +1123,8 @@ def typecheck_arg(arg, expected_type, i, name):
|
||||
name (str): The name of the function.
|
||||
|
||||
Raises:
|
||||
Exception: An exception is raised if arg does not have the expected type.
|
||||
RayGetArgumentTypeError: An exception is raised if arg does not have the
|
||||
expected type.
|
||||
"""
|
||||
if issubclass(type(arg), expected_type):
|
||||
# Passed the type-check
|
||||
@@ -1023,7 +1134,7 @@ def typecheck_arg(arg, expected_type, i, name):
|
||||
# TODO(mehrdadn): Should long really be convertible to int?
|
||||
pass
|
||||
else:
|
||||
raise Exception("Argument {} for function {} has type {} but an argument of type {} was expected.".format(i, name, type(arg), expected_type))
|
||||
raise RayGetArgumentTypeError(name, i, type(arg), expected_type)
|
||||
|
||||
def check_arguments(arg_types, has_vararg_param, name, args):
|
||||
"""Check that the arguments to the remote function have the right types.
|
||||
@@ -1080,7 +1191,10 @@ def get_arguments_for_execution(function, args, worker=global_worker):
|
||||
value.
|
||||
|
||||
Raises:
|
||||
Exception: An exception is raised the args do not all have the right types.
|
||||
RayGetArgumentError: This exception is raised if a task that created one of
|
||||
the arguments failed.
|
||||
RayGetArgumentTypeError: This exception is raised (via typecheck_arg) if one
|
||||
of the arguments does not have the expected type.
|
||||
"""
|
||||
# TODO(rkn): Eventually, all of the type checking can be put in `check_arguments` above so that the error will happen immediately when calling a remote function.
|
||||
arguments = []
|
||||
@@ -1102,12 +1216,16 @@ def get_arguments_for_execution(function, args, worker=global_worker):
|
||||
# get the object from the local object store
|
||||
_logger().info("Getting argument {} for function {}.".format(i, function.__name__))
|
||||
argument = worker.get_object(arg)
|
||||
if isinstance(argument, RayTaskError):
|
||||
# If the result is a RayTaskError, then the task that created this
|
||||
# object failed, and we should propagate the error message here.
|
||||
raise RayGetArgumentError(function.__name__, i, arg, argument)
|
||||
_logger().info("Successfully retrieved argument {} for function {}.".format(i, function.__name__))
|
||||
else:
|
||||
# pass the argument by value
|
||||
argument = arg
|
||||
|
||||
typecheck_arg(argument, expected_type, i, function)
|
||||
typecheck_arg(argument, expected_type, i, function.__name__)
|
||||
arguments.append(argument)
|
||||
return arguments
|
||||
|
||||
|
||||
Reference in New Issue
Block a user