diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 806e7b5a3..0953b02a2 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -486,10 +486,10 @@ cdef execute_task( if isinstance(error, RayTaskError): # Avoid recursive nesting of RayTaskError. failure_object = RayTaskError(function_name, backtrace, - error.cause, proctitle=title) + error.cause_cls, proctitle=title) else: failure_object = RayTaskError(function_name, backtrace, - error, proctitle=title) + error.__class__, proctitle=title) errors = [] for _ in range(c_return_ids.size()): errors.append(failure_object) diff --git a/python/ray/exceptions.py b/python/ray/exceptions.py index 75d382fd7..8adc46044 100644 --- a/python/ray/exceptions.py +++ b/python/ray/exceptions.py @@ -41,12 +41,17 @@ class RayTaskError(RayError): retrieved from the object store, the Python method that retrieved it checks to see if the object is a RayTaskError and if it is then an exception is thrown propagating the error message. + + Attributes: + function_name (str): The name of the function that failed and produced + the RayTaskError. + traceback_str (str): The traceback from the exception. """ def __init__(self, function_name, traceback_str, - cause, + cause_cls, proctitle=None, pid=None, ip=None): @@ -59,43 +64,34 @@ class RayTaskError(RayError): self.ip = ip or ray.services.get_node_ip_address() self.function_name = function_name self.traceback_str = traceback_str - # TODO(edoakes): should we handle non-serializable exception objects? - self.cause = cause + self.cause_cls = cause_cls assert traceback_str is not None def as_instanceof_cause(self): - """Returns an exception that is an instance of the cause's class. + """Returns copy that is an instance of the cause's Python class. The returned exception will inherit from both RayTaskError and the - cause class and will contain all of the attributes of the cause - exception. + cause class. """ - cause_cls = self.cause.__class__ - if issubclass(RayTaskError, cause_cls): + if issubclass(RayTaskError, self.cause_cls): return self # already satisfied - if issubclass(cause_cls, RayError): + if issubclass(self.cause_cls, RayError): return self # don't try to wrap ray internal errors - cause = self.cause - error_msg = str(self) + class cls(RayTaskError, self.cause_cls): + def __init__(self, function_name, traceback_str, cause_cls, + proctitle, pid, ip): + RayTaskError.__init__(self, function_name, traceback_str, + cause_cls, proctitle, pid, ip) - class cls(RayTaskError, cause_cls): - def __init__(self): - pass - - def __getattr__(self, name): - return getattr(cause, name) - - def __str__(self): - return error_msg - - name = "RayTaskError({})".format(cause_cls.__name__) + name = "RayTaskError({})".format(self.cause_cls.__name__) cls.__name__ = name cls.__qualname__ = name - return cls() + return cls(self.function_name, self.traceback_str, self.cause_cls, + self.proctitle, self.pid, self.ip) def __str__(self): """Format a RayTaskError as a string.""" diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py index 7ce16a5f0..7813caa55 100644 --- a/python/ray/tests/test_failure.py +++ b/python/ray/tests/test_failure.py @@ -63,12 +63,7 @@ def test_failed_task(ray_start_regular): assert False class CustomException(ValueError): - def __init__(self, msg): - super().__init__(msg) - self.field = 1 - - def f(self): - return 2 + pass @ray.remote def f(): @@ -78,12 +73,9 @@ def test_failed_task(ray_start_regular): ray.get(f.remote()) except Exception as e: assert "This function failed." in str(e) - assert isinstance(e, ValueError) assert isinstance(e, CustomException) assert isinstance(e, ray.exceptions.RayTaskError) assert "RayTaskError(CustomException)" in repr(e) - assert e.field == 1 - assert e.f() == 2 else: # ray.get should throw an exception. assert False