Preserve the original exception type when converting to RayTaskError (#5799)

This commit is contained in:
Eric Liang
2019-09-28 17:03:15 -07:00
committed by GitHub
parent 493364d3bd
commit 81ee887f91
5 changed files with 57 additions and 10 deletions
+34 -3
View File
@@ -28,18 +28,49 @@ class RayTaskError(RayError):
traceback_str (str): The traceback from the exception.
"""
def __init__(self, function_name, traceback_str):
def __init__(self,
function_name,
traceback_str,
cause_cls,
pid=None,
host=None):
"""Initialize a RayTaskError."""
if setproctitle:
self.proctitle = setproctitle.getproctitle()
else:
self.proctitle = "ray_worker"
self.pid = os.getpid()
self.host = os.uname()[1]
self.pid = pid or os.getpid()
self.host = host or os.uname()[1]
self.function_name = function_name
self.traceback_str = traceback_str
self.cause_cls = cause_cls
assert traceback_str is not None
def as_instanceof_cause(self):
"""Returns copy that is an instance of the cause's Python class.
The returned exception will inherit from both RayTaskError and the
cause class.
"""
if issubclass(RayTaskError, self.cause_cls):
return self # already satisfied
class cls(self.cause_cls, RayTaskError):
def __init__(self, function_name, traceback_str, cause_cls, pid,
host):
RayTaskError.__init__(self, function_name, traceback_str,
cause_cls, pid, host)
name = "RayTaskError({})".format(self.cause_cls.__name__)
cls.__name__ = name
cls.__qualname__ = name
return cls(self.function_name, self.traceback_str, self.cause_cls,
self.pid, self.host)
cls.original = self
return cls
def __str__(self):
"""Format a RayTaskError as a string."""
lines = self.traceback_str.split("\n")
+3 -2
View File
@@ -21,9 +21,10 @@ def wrap_to_ray_error(callable_obj, *args):
"""Utility method that catch and seal exceptions in execution"""
try:
return callable_obj(*args)
except Exception:
except Exception as e:
traceback_str = ray.utils.format_error_message(traceback.format_exc())
return ray.exceptions.RayTaskError(str(callable_obj), traceback_str)
return ray.exceptions.RayTaskError(
str(callable_obj), traceback_str, e.__class__)
class RayServeMixin:
+2 -2
View File
@@ -58,10 +58,10 @@ class LocalModeManager(object):
else:
for object_id, result in zip(object_ids, results):
object_id.value = result
except Exception:
except Exception as e:
function_name = function_descriptor.function_name
backtrace = format_error_message(traceback.format_exc())
task_error = RayTaskError(function_name, backtrace)
task_error = RayTaskError(function_name, backtrace, e.__class__)
for object_id in object_ids:
object_id.value = task_error
+7 -1
View File
@@ -63,14 +63,20 @@ def test_failed_task(ray_start_regular):
# ray.get should throw an exception.
assert False
class CustomException(ValueError):
pass
@ray.remote
def f():
raise Exception("This function failed.")
raise CustomException("This function failed.")
try:
ray.get(f.remote())
except Exception as e:
assert "This function failed." in str(e)
assert isinstance(e, CustomException)
assert isinstance(e, ray.exceptions.RayTaskError)
assert "RayTaskError(CustomException)" in repr(e)
else:
# ray.get should throw an exception.
assert False
+11 -2
View File
@@ -1002,7 +1002,13 @@ class Worker(object):
def _handle_process_task_failure(self, function_descriptor,
return_object_ids, error, backtrace):
function_name = function_descriptor.function_name
failure_object = RayTaskError(function_name, backtrace)
if isinstance(error, RayTaskError):
# avoid recursively nesting of RayTaskError
failure_object = RayTaskError(function_name, backtrace,
error.cause_cls)
else:
failure_object = RayTaskError(function_name, backtrace,
error.__class__)
failure_objects = [
failure_object for _ in range(len(return_object_ids))
]
@@ -2290,7 +2296,10 @@ def get(object_ids):
last_task_error_raise_time = time.time()
if isinstance(value, ray.exceptions.UnreconstructableError):
worker.dump_object_store_memory_usage()
raise value
if isinstance(value, RayTaskError):
raise value.as_instanceof_cause()
else:
raise value
# Run post processors.
for post_processor in worker._post_get_hooks: