From 81ee887f91c70383b2752da7b17d36fcfbafe52f Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sat, 28 Sep 2019 17:03:15 -0700 Subject: [PATCH] Preserve the original exception type when converting to RayTaskError (#5799) --- python/ray/exceptions.py | 37 ++++++++++++++++++-- python/ray/experimental/serve/task_runner.py | 5 +-- python/ray/local_mode_manager.py | 4 +-- python/ray/tests/test_failure.py | 8 ++++- python/ray/worker.py | 13 +++++-- 5 files changed, 57 insertions(+), 10 deletions(-) diff --git a/python/ray/exceptions.py b/python/ray/exceptions.py index c60adae2d..ba0c2e5d6 100644 --- a/python/ray/exceptions.py +++ b/python/ray/exceptions.py @@ -28,18 +28,49 @@ class RayTaskError(RayError): traceback_str (str): The traceback from the exception. """ - def __init__(self, function_name, traceback_str): + def __init__(self, + function_name, + traceback_str, + cause_cls, + pid=None, + host=None): """Initialize a RayTaskError.""" if setproctitle: self.proctitle = setproctitle.getproctitle() else: self.proctitle = "ray_worker" - self.pid = os.getpid() - self.host = os.uname()[1] + self.pid = pid or os.getpid() + self.host = host or os.uname()[1] self.function_name = function_name self.traceback_str = traceback_str + self.cause_cls = cause_cls assert traceback_str is not None + def as_instanceof_cause(self): + """Returns copy that is an instance of the cause's Python class. + + The returned exception will inherit from both RayTaskError and the + cause class. + """ + + if issubclass(RayTaskError, self.cause_cls): + return self # already satisfied + + class cls(self.cause_cls, RayTaskError): + def __init__(self, function_name, traceback_str, cause_cls, pid, + host): + RayTaskError.__init__(self, function_name, traceback_str, + cause_cls, pid, host) + + name = "RayTaskError({})".format(self.cause_cls.__name__) + cls.__name__ = name + cls.__qualname__ = name + + return cls(self.function_name, self.traceback_str, self.cause_cls, + self.pid, self.host) + cls.original = self + return cls + def __str__(self): """Format a RayTaskError as a string.""" lines = self.traceback_str.split("\n") diff --git a/python/ray/experimental/serve/task_runner.py b/python/ray/experimental/serve/task_runner.py index 0533a6c02..ff42cf67c 100644 --- a/python/ray/experimental/serve/task_runner.py +++ b/python/ray/experimental/serve/task_runner.py @@ -21,9 +21,10 @@ def wrap_to_ray_error(callable_obj, *args): """Utility method that catch and seal exceptions in execution""" try: return callable_obj(*args) - except Exception: + except Exception as e: traceback_str = ray.utils.format_error_message(traceback.format_exc()) - return ray.exceptions.RayTaskError(str(callable_obj), traceback_str) + return ray.exceptions.RayTaskError( + str(callable_obj), traceback_str, e.__class__) class RayServeMixin: diff --git a/python/ray/local_mode_manager.py b/python/ray/local_mode_manager.py index 1deb0533b..a85ae68e5 100644 --- a/python/ray/local_mode_manager.py +++ b/python/ray/local_mode_manager.py @@ -58,10 +58,10 @@ class LocalModeManager(object): else: for object_id, result in zip(object_ids, results): object_id.value = result - except Exception: + except Exception as e: function_name = function_descriptor.function_name backtrace = format_error_message(traceback.format_exc()) - task_error = RayTaskError(function_name, backtrace) + task_error = RayTaskError(function_name, backtrace, e.__class__) for object_id in object_ids: object_id.value = task_error diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py index 75e178802..7f747cb3e 100644 --- a/python/ray/tests/test_failure.py +++ b/python/ray/tests/test_failure.py @@ -63,14 +63,20 @@ def test_failed_task(ray_start_regular): # ray.get should throw an exception. assert False + class CustomException(ValueError): + pass + @ray.remote def f(): - raise Exception("This function failed.") + raise CustomException("This function failed.") try: ray.get(f.remote()) except Exception as e: assert "This function failed." in str(e) + assert isinstance(e, CustomException) + assert isinstance(e, ray.exceptions.RayTaskError) + assert "RayTaskError(CustomException)" in repr(e) else: # ray.get should throw an exception. assert False diff --git a/python/ray/worker.py b/python/ray/worker.py index 4a6c36ae0..bf74ebd42 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1002,7 +1002,13 @@ class Worker(object): def _handle_process_task_failure(self, function_descriptor, return_object_ids, error, backtrace): function_name = function_descriptor.function_name - failure_object = RayTaskError(function_name, backtrace) + if isinstance(error, RayTaskError): + # avoid recursively nesting of RayTaskError + failure_object = RayTaskError(function_name, backtrace, + error.cause_cls) + else: + failure_object = RayTaskError(function_name, backtrace, + error.__class__) failure_objects = [ failure_object for _ in range(len(return_object_ids)) ] @@ -2290,7 +2296,10 @@ def get(object_ids): last_task_error_raise_time = time.time() if isinstance(value, ray.exceptions.UnreconstructableError): worker.dump_object_store_memory_usage() - raise value + if isinstance(value, RayTaskError): + raise value.as_instanceof_cause() + else: + raise value # Run post processors. for post_processor in worker._post_get_hooks: