From 08adbb371f5cd6fd1fc839abb573fbed130b2a91 Mon Sep 17 00:00:00 2001 From: fyrestone Date: Wed, 26 Aug 2020 10:46:05 +0800 Subject: [PATCH] Cross language exception (#10023) --- doc/source/conf.py | 44 +++++-- java/BUILD.bazel | 1 + .../io/ray/api/exception/RayException.java | 15 --- .../ray/api/exception/RayTaskException.java | 15 --- .../io/ray/runtime/AbstractRayRuntime.java | 3 +- .../java/io/ray/runtime/RayRuntimeProxy.java | 2 +- .../exception/CrossLanguageException.java | 18 +++ .../runtime}/exception/RayActorException.java | 2 +- .../ray/runtime/exception/RayException.java | 40 +++++++ .../runtime/exception/RayTaskException.java | 13 +++ .../exception/RayWorkerException.java | 2 +- .../exception/UnreconstructableException.java | 2 +- .../ray/runtime/object/ObjectSerializer.java | 34 ++++-- .../io/ray/runtime/object/ObjectStore.java | 7 +- .../io/ray/runtime/task/TaskExecutor.java | 2 +- .../java/io/ray/test/ActorRestartTest.java | 2 +- .../src/main/java/io/ray/test/ActorTest.java | 2 +- .../ray/test/CrossLanguageInvocationTest.java | 109 +++++++++++++++++- .../main/java/io/ray/test/FailureTest.java | 9 +- .../main/java/io/ray/test/KillActorTest.java | 2 +- .../java/io/ray/test/MultiThreadingTest.java | 3 +- .../test_cross_language_invocation.py | 29 +++++ python/ray/_raylet.pyx | 3 + python/ray/exceptions.py | 32 ++++- python/ray/gcs_utils.py | 57 ++++++--- python/ray/serialization.py | 33 +++--- python/ray/tests/test_failure.py | 17 +++ src/ray/core_worker/lib/java/jni_init.cc | 2 +- src/ray/protobuf/common.proto | 44 +++++++ src/ray/protobuf/gcs.proto | 34 ------ 30 files changed, 441 insertions(+), 137 deletions(-) delete mode 100644 java/api/src/main/java/io/ray/api/exception/RayException.java delete mode 100644 java/api/src/main/java/io/ray/api/exception/RayTaskException.java create mode 100644 java/runtime/src/main/java/io/ray/runtime/exception/CrossLanguageException.java rename java/{api/src/main/java/io/ray/api => 
runtime/src/main/java/io/ray/runtime}/exception/RayActorException.java (91%) create mode 100644 java/runtime/src/main/java/io/ray/runtime/exception/RayException.java create mode 100644 java/runtime/src/main/java/io/ray/runtime/exception/RayTaskException.java rename java/{api/src/main/java/io/ray/api => runtime/src/main/java/io/ray/runtime}/exception/RayWorkerException.java (87%) rename java/{api/src/main/java/io/ray/api => runtime/src/main/java/io/ray/runtime}/exception/UnreconstructableException.java (95%) diff --git a/doc/source/conf.py b/doc/source/conf.py index bf1e7eab6..cdcb293f8 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -23,17 +23,39 @@ from custom_directives import CustomGalleryItemDirective # These lines added to enable Sphinx to work without installing Ray. import mock MOCK_MODULES = [ - "blist", "gym", "gym.spaces", "psutil", "ray._raylet", - "ray.core.generated", "ray.core.generated.gcs_pb2", - "ray.core.generated.ray.protocol.Task", "scipy.signal", "scipy.stats", - "setproctitle", "tensorflow_probability", "tensorflow", - "tensorflow.contrib", "tensorflow.contrib.all_reduce", "tree", - "tensorflow.contrib.all_reduce.python", "tensorflow.contrib.layers", - "tensorflow.contrib.rnn", "tensorflow.contrib.slim", "tensorflow.core", - "tensorflow.core.util", "tensorflow.python", "tensorflow.python.client", - "tensorflow.python.util", "torch", "torch.distributed", "torch.nn", - "torch.nn.parallel", "torch.utils.data", "torch.utils.data.distributed", - "zoopt" + "blist", + "gym", + "gym.spaces", + "psutil", + "ray._raylet", + "ray.core.generated", + "ray.core.generated.common_pb2", + "ray.core.generated.gcs_pb2", + "ray.core.generated.ray.protocol.Task", + "scipy.signal", + "scipy.stats", + "setproctitle", + "tensorflow_probability", + "tensorflow", + "tensorflow.contrib", + "tensorflow.contrib.all_reduce", + "tree", + "tensorflow.contrib.all_reduce.python", + "tensorflow.contrib.layers", + "tensorflow.contrib.rnn", + 
"tensorflow.contrib.slim", + "tensorflow.core", + "tensorflow.core.util", + "tensorflow.python", + "tensorflow.python.client", + "tensorflow.python.util", + "torch", + "torch.distributed", + "torch.nn", + "torch.nn.parallel", + "torch.utils.data", + "torch.utils.data.distributed", + "zoopt", ] import scipy.stats import scipy.linalg diff --git a/java/BUILD.bazel b/java/BUILD.bazel index b8691c1d6..258215738 100644 --- a/java/BUILD.bazel +++ b/java/BUILD.bazel @@ -105,6 +105,7 @@ define_java_module( ":io_ray_ray_runtime", "@maven//:com_google_code_gson_gson", "@maven//:com_google_guava_guava", + "@maven//:com_google_protobuf_protobuf_java", "@maven//:com_sun_xml_bind_jaxb_core", "@maven//:com_sun_xml_bind_jaxb_impl", "@maven//:commons_io_commons_io", diff --git a/java/api/src/main/java/io/ray/api/exception/RayException.java b/java/api/src/main/java/io/ray/api/exception/RayException.java deleted file mode 100644 index 50baab238..000000000 --- a/java/api/src/main/java/io/ray/api/exception/RayException.java +++ /dev/null @@ -1,15 +0,0 @@ -package io.ray.api.exception; - -/** - * Base class of all ray exceptions. - */ -public class RayException extends RuntimeException { - - public RayException(String message) { - super(message); - } - - public RayException(String message, Throwable cause) { - super(message, cause); - } -} diff --git a/java/api/src/main/java/io/ray/api/exception/RayTaskException.java b/java/api/src/main/java/io/ray/api/exception/RayTaskException.java deleted file mode 100644 index f0a5ae2ea..000000000 --- a/java/api/src/main/java/io/ray/api/exception/RayTaskException.java +++ /dev/null @@ -1,15 +0,0 @@ -package io.ray.api.exception; - -/** - * Indicates that a task threw an exception during execution. - * - * If a task throws an exception during execution, a RayTaskException is stored in the object store - * as the task's output. Then when the object is retrieved from the object store, this exception - * will be thrown and propagate the error message. 
- */ -public class RayTaskException extends RayException { - - public RayTaskException(String message, Throwable cause) { - super(message, cause); - } -} diff --git a/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java index a8ed4fa3e..fffc0efd2 100644 --- a/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java @@ -8,7 +8,6 @@ import io.ray.api.BaseActorHandle; import io.ray.api.ObjectRef; import io.ray.api.PyActorHandle; import io.ray.api.WaitResult; -import io.ray.api.exception.RayException; import io.ray.api.function.PyActorClass; import io.ray.api.function.PyActorMethod; import io.ray.api.function.PyFunction; @@ -81,7 +80,7 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal { } @Override - public T get(ObjectRef objectRef) throws RayException { + public T get(ObjectRef objectRef) throws RuntimeException { List ret = get(ImmutableList.of(objectRef)); return ret.get(0); } diff --git a/java/runtime/src/main/java/io/ray/runtime/RayRuntimeProxy.java b/java/runtime/src/main/java/io/ray/runtime/RayRuntimeProxy.java index 1f515f91b..6220e1a8f 100644 --- a/java/runtime/src/main/java/io/ray/runtime/RayRuntimeProxy.java +++ b/java/runtime/src/main/java/io/ray/runtime/RayRuntimeProxy.java @@ -1,8 +1,8 @@ package io.ray.runtime; -import io.ray.api.exception.RayException; import io.ray.api.runtime.RayRuntime; import io.ray.runtime.config.RunMode; +import io.ray.runtime.exception.RayException; import java.lang.reflect.InvocationHandler; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; diff --git a/java/runtime/src/main/java/io/ray/runtime/exception/CrossLanguageException.java b/java/runtime/src/main/java/io/ray/runtime/exception/CrossLanguageException.java new file mode 100644 index 000000000..74b0e2076 --- /dev/null +++ 
b/java/runtime/src/main/java/io/ray/runtime/exception/CrossLanguageException.java @@ -0,0 +1,18 @@ +package io.ray.runtime.exception; + +import io.ray.runtime.generated.Common.Language; + +public class CrossLanguageException extends RayException { + + private Language language; + + public CrossLanguageException(io.ray.runtime.generated.Common.RayException exception) { + super(String.format("An exception raised from %s:\n%s", exception.getLanguage().name(), + exception.getFormattedExceptionString())); + this.language = exception.getLanguage(); + } + + public Language getLanguage() { + return this.language; + } +} diff --git a/java/api/src/main/java/io/ray/api/exception/RayActorException.java b/java/runtime/src/main/java/io/ray/runtime/exception/RayActorException.java similarity index 91% rename from java/api/src/main/java/io/ray/api/exception/RayActorException.java rename to java/runtime/src/main/java/io/ray/runtime/exception/RayActorException.java index 81d9983e0..e73c41845 100644 --- a/java/api/src/main/java/io/ray/api/exception/RayActorException.java +++ b/java/runtime/src/main/java/io/ray/runtime/exception/RayActorException.java @@ -1,4 +1,4 @@ -package io.ray.api.exception; +package io.ray.runtime.exception; /** * Indicates that the actor died unexpectedly before finishing a task. 
diff --git a/java/runtime/src/main/java/io/ray/runtime/exception/RayException.java b/java/runtime/src/main/java/io/ray/runtime/exception/RayException.java new file mode 100644 index 000000000..10c00aa86 --- /dev/null +++ b/java/runtime/src/main/java/io/ray/runtime/exception/RayException.java @@ -0,0 +1,40 @@ +package io.ray.runtime.exception; + +import com.google.protobuf.ByteString; +import com.google.protobuf.InvalidProtocolBufferException; +import io.ray.runtime.generated.Common.Language; +import io.ray.runtime.serializer.Serializer; + +public class RayException extends RuntimeException { + + public RayException(String message) { + super(message); + } + + public RayException(String message, Throwable cause) { + super(message, cause); + } + + public byte[] toBytes() { + String formattedException = org.apache.commons.lang3.exception.ExceptionUtils + .getStackTrace(this); + io.ray.runtime.generated.Common.RayException.Builder builder = + io.ray.runtime.generated.Common.RayException.newBuilder(); + builder.setLanguage(Language.JAVA); + builder.setFormattedExceptionString(formattedException); + builder.setSerializedException(ByteString.copyFrom(Serializer.encode(this).getLeft())); + return builder.build().toByteArray(); + } + + public static RayException fromBytes(byte[] serialized) + throws InvalidProtocolBufferException { + io.ray.runtime.generated.Common.RayException exception = + io.ray.runtime.generated.Common.RayException.parseFrom(serialized); + if (exception.getLanguage() == Language.JAVA) { + return Serializer + .decode(exception.getSerializedException().toByteArray(), RayException.class); + } else { + return new CrossLanguageException(exception); + } + } +} diff --git a/java/runtime/src/main/java/io/ray/runtime/exception/RayTaskException.java b/java/runtime/src/main/java/io/ray/runtime/exception/RayTaskException.java new file mode 100644 index 000000000..3b0d5327d --- /dev/null +++ b/java/runtime/src/main/java/io/ray/runtime/exception/RayTaskException.java 
@@ -0,0 +1,13 @@ +package io.ray.runtime.exception; + +import io.ray.runtime.util.NetworkUtil; +import io.ray.runtime.util.SystemUtil; + +public class RayTaskException extends RayException { + + public RayTaskException(String message, Throwable cause) { + super(String.format("(pid=%d, ip=%s) %s", + SystemUtil.pid(), NetworkUtil.getIpAddress(null), message), cause); + } + +} diff --git a/java/api/src/main/java/io/ray/api/exception/RayWorkerException.java b/java/runtime/src/main/java/io/ray/runtime/exception/RayWorkerException.java similarity index 87% rename from java/api/src/main/java/io/ray/api/exception/RayWorkerException.java rename to java/runtime/src/main/java/io/ray/runtime/exception/RayWorkerException.java index 97826423c..e24b55a61 100644 --- a/java/api/src/main/java/io/ray/api/exception/RayWorkerException.java +++ b/java/runtime/src/main/java/io/ray/runtime/exception/RayWorkerException.java @@ -1,4 +1,4 @@ -package io.ray.api.exception; +package io.ray.runtime.exception; /** * Indicates that the worker died unexpectedly while executing a task. 
diff --git a/java/api/src/main/java/io/ray/api/exception/UnreconstructableException.java b/java/runtime/src/main/java/io/ray/runtime/exception/UnreconstructableException.java similarity index 95% rename from java/api/src/main/java/io/ray/api/exception/UnreconstructableException.java rename to java/runtime/src/main/java/io/ray/runtime/exception/UnreconstructableException.java index 526ec6f0d..a99e9c075 100644 --- a/java/api/src/main/java/io/ray/api/exception/UnreconstructableException.java +++ b/java/runtime/src/main/java/io/ray/runtime/exception/UnreconstructableException.java @@ -1,4 +1,4 @@ -package io.ray.api.exception; +package io.ray.runtime.exception; import io.ray.api.id.ObjectId; diff --git a/java/runtime/src/main/java/io/ray/runtime/object/ObjectSerializer.java b/java/runtime/src/main/java/io/ray/runtime/object/ObjectSerializer.java index 8b2c239aa..b09a70789 100644 --- a/java/runtime/src/main/java/io/ray/runtime/object/ObjectSerializer.java +++ b/java/runtime/src/main/java/io/ray/runtime/object/ObjectSerializer.java @@ -1,11 +1,12 @@ package io.ray.runtime.object; -import io.ray.api.exception.RayActorException; -import io.ray.api.exception.RayTaskException; -import io.ray.api.exception.RayWorkerException; -import io.ray.api.exception.UnreconstructableException; +import com.google.protobuf.InvalidProtocolBufferException; import io.ray.api.id.ObjectId; -import io.ray.runtime.generated.Gcs.ErrorType; +import io.ray.runtime.exception.RayActorException; +import io.ray.runtime.exception.RayTaskException; +import io.ray.runtime.exception.RayWorkerException; +import io.ray.runtime.exception.UnreconstructableException; +import io.ray.runtime.generated.Common.ErrorType; import io.ray.runtime.serializer.Serializer; import java.nio.ByteBuffer; import java.util.ArrayList; @@ -70,7 +71,21 @@ public class ObjectSerializer { } else if (Arrays.equals(meta, UNRECONSTRUCTABLE_EXCEPTION_META)) { return new UnreconstructableException(objectId); } else if (Arrays.equals(meta, 
TASK_EXECUTION_EXCEPTION_META)) { - return Serializer.decode(data, objectType); + // Serialization logic of task execution exception: an instance of + // `io.ray.runtime.exception.RayTaskException` + // -> a `RayException` protobuf message + // -> protobuf-serialized bytes + // -> MessagePack-serialized bytes. + // So here the `data` variable is MessagePack-serialized bytes, and the `serialized` + // variable is protobuf-serialized bytes. They are not the same. + byte[] serialized = Serializer.decode(data, byte[].class); + try { + return RayTaskException.fromBytes(serialized); + } catch (InvalidProtocolBufferException e) { + throw new IllegalArgumentException( + "Can't deserialize RayTaskException object: " + objectId + .toString()); + } } else if (Arrays.equals(meta, OBJECT_METADATA_TYPE_PYTHON)) { throw new IllegalArgumentException("Can't deserialize Python object: " + objectId .toString()); @@ -107,7 +122,12 @@ public class ObjectSerializer { } return new NativeRayObject(bytes, OBJECT_METADATA_TYPE_RAW); } else if (object instanceof RayTaskException) { - byte[] serializedBytes = Serializer.encode(object).getLeft(); + RayTaskException taskException = (RayTaskException) object; + byte[] serializedBytes = Serializer.encode(taskException.toBytes()).getLeft(); + // serializedBytes is MessagePack serialized bytes + // taskException.toBytes() is protobuf serialized bytes + // Only OBJECT_METADATA_TYPE_RAW is raw bytes, + // any other type should be the MessagePack serialized bytes. 
return new NativeRayObject(serializedBytes, TASK_EXECUTION_EXCEPTION_META); } else { try { diff --git a/java/runtime/src/main/java/io/ray/runtime/object/ObjectStore.java b/java/runtime/src/main/java/io/ray/runtime/object/ObjectStore.java index a30e5d8a4..2d227d604 100644 --- a/java/runtime/src/main/java/io/ray/runtime/object/ObjectStore.java +++ b/java/runtime/src/main/java/io/ray/runtime/object/ObjectStore.java @@ -3,10 +3,10 @@ package io.ray.runtime.object; import com.google.common.base.Preconditions; import io.ray.api.ObjectRef; import io.ray.api.WaitResult; -import io.ray.api.exception.RayException; import io.ray.api.id.ObjectId; import io.ray.api.id.UniqueId; import io.ray.runtime.context.WorkerContext; +import io.ray.runtime.exception.RayException; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -160,8 +160,7 @@ public abstract class ObjectStore { * Delete a list of objects from the object store. * * @param objectIds IDs of the objects to delete. - * @param localOnly Whether only delete the objects in local node, or all nodes in the - * cluster. + * @param localOnly Whether only delete the objects in local node, or all nodes in the cluster. * @param deleteCreatingTasks Whether also delete the tasks that created these objects. */ public abstract void delete(List objectIds, boolean localOnly, @@ -169,6 +168,7 @@ public abstract class ObjectStore { /** * Increase the local reference count for this object ID. + * * @param workerId The ID of the worker to increase on. * @param objectId The object ID to increase the reference count for. */ @@ -176,6 +176,7 @@ public abstract class ObjectStore { /** * Decrease the reference count for this object ID. + * * @param workerId The ID of the worker to decrease on. * @param objectId The object ID to decrease the reference count for. 
*/ diff --git a/java/runtime/src/main/java/io/ray/runtime/task/TaskExecutor.java b/java/runtime/src/main/java/io/ray/runtime/task/TaskExecutor.java index 115e422f7..2b5a97ea2 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/TaskExecutor.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/TaskExecutor.java @@ -1,12 +1,12 @@ package io.ray.runtime.task; import com.google.common.base.Preconditions; -import io.ray.api.exception.RayTaskException; import io.ray.api.id.ActorId; import io.ray.api.id.JobId; import io.ray.api.id.TaskId; import io.ray.api.id.UniqueId; import io.ray.runtime.RayRuntimeInternal; +import io.ray.runtime.exception.RayTaskException; import io.ray.runtime.functionmanager.JavaFunctionDescriptor; import io.ray.runtime.functionmanager.RayFunction; import io.ray.runtime.generated.Common.TaskType; diff --git a/java/test/src/main/java/io/ray/test/ActorRestartTest.java b/java/test/src/main/java/io/ray/test/ActorRestartTest.java index 132c8954c..099122394 100644 --- a/java/test/src/main/java/io/ray/test/ActorRestartTest.java +++ b/java/test/src/main/java/io/ray/test/ActorRestartTest.java @@ -3,7 +3,7 @@ package io.ray.test; import io.ray.api.ActorHandle; import io.ray.api.Checkpointable; import io.ray.api.Ray; -import io.ray.api.exception.RayActorException; +import io.ray.runtime.exception.RayActorException; import io.ray.api.id.ActorId; import io.ray.api.id.UniqueId; import io.ray.runtime.util.SystemUtil; diff --git a/java/test/src/main/java/io/ray/test/ActorTest.java b/java/test/src/main/java/io/ray/test/ActorTest.java index 1e7e40b04..dd2bda08a 100644 --- a/java/test/src/main/java/io/ray/test/ActorTest.java +++ b/java/test/src/main/java/io/ray/test/ActorTest.java @@ -5,7 +5,7 @@ import io.ray.api.ActorHandle; import io.ray.api.ObjectRef; import io.ray.api.PyActorHandle; import io.ray.api.Ray; -import io.ray.api.exception.UnreconstructableException; +import io.ray.runtime.exception.UnreconstructableException; import io.ray.api.id.ActorId; 
import io.ray.api.id.UniqueId; import java.util.Collections; diff --git a/java/test/src/main/java/io/ray/test/CrossLanguageInvocationTest.java b/java/test/src/main/java/io/ray/test/CrossLanguageInvocationTest.java index ba9868890..e630494ee 100644 --- a/java/test/src/main/java/io/ray/test/CrossLanguageInvocationTest.java +++ b/java/test/src/main/java/io/ray/test/CrossLanguageInvocationTest.java @@ -11,6 +11,9 @@ import io.ray.api.function.PyActorMethod; import io.ray.api.function.PyFunction; import io.ray.runtime.actor.NativeActorHandle; import io.ray.runtime.actor.NativePyActorHandle; +import io.ray.runtime.exception.CrossLanguageException; +import io.ray.runtime.exception.RayException; +import io.ray.runtime.generated.Common.Language; import java.io.File; import java.io.IOException; import java.io.InputStream; @@ -19,14 +22,11 @@ import java.util.Arrays; import java.util.List; import java.util.Map; import org.apache.commons.io.FileUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import org.testng.Assert; import org.testng.annotations.Test; public class CrossLanguageInvocationTest extends BaseMultiLanguageTest { - private static final Logger LOGGER = LoggerFactory.getLogger(CrossLanguageInvocationTest.class); private static final String PYTHON_MODULE = "test_cross_language_invocation"; @Override @@ -151,7 +151,6 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest { PyFunction.of(PYTHON_MODULE, "py_func_call_java_actor", byte[].class), "1".getBytes()).remote(); Assert.assertEquals(res.get(), "Counter1".getBytes()); - } @Test @@ -188,6 +187,91 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest { Assert.assertEquals(res.get(), "3".getBytes()); } + @Test + public void testExceptionSerialization() throws IOException { + try { + throw new RayException("Test Exception"); + } catch (RayException e) { + String formattedException = org.apache.commons.lang3.exception.ExceptionUtils + .getStackTrace(e); + 
io.ray.runtime.generated.Common.RayException exception = io.ray.runtime.generated.Common.RayException + .parseFrom(e.toBytes()); + Assert.assertEquals(exception.getFormattedExceptionString(), formattedException); + } + } + + @Test + public void testRaiseExceptionFromPython() { + ObjectRef res = Ray.task(PyFunction.of( + PYTHON_MODULE, "py_func_python_raise_exception", Object.class)).remote(); + try { + res.get(); + } catch (RuntimeException ex) { + // ex is a Python exception(py_func_python_raise_exception) with no cause. + Assert.assertTrue(ex instanceof CrossLanguageException); + CrossLanguageException e = (CrossLanguageException) ex; + Assert.assertEquals(e.getLanguage(), Language.PYTHON); + // ex.cause is null. + Assert.assertNull(ex.getCause()); + Assert.assertTrue(ex.getMessage().contains("ZeroDivisionError: division by zero"), + ex.getMessage()); + return; + } + Assert.fail(); + } + + @Test + public void testThrowExceptionFromJava() { + ObjectRef res = Ray.task(PyFunction.of( + PYTHON_MODULE, "py_func_java_throw_exception", Object.class)).remote(); + try { + res.get(); + } catch (RuntimeException ex) { + final String message = ex.getMessage(); + Assert.assertTrue(message.contains("py_func_java_throw_exception"), message); + Assert.assertTrue(message.contains("io.ray.test.CrossLanguageInvocationTest.throwException"), + message); + Assert.assertTrue(message.contains("java.lang.ArithmeticException: / by zero"), message); + return; + } + Assert.fail(); + } + + @Test + public void testRaiseExceptionFromNestPython() { + ObjectRef res = Ray.task( + PyFunction.of(PYTHON_MODULE, "py_func_nest_python_raise_exception", Object.class)).remote(); + try { + res.get(); + } catch (RuntimeException ex) { + final String message = ex.getMessage(); + Assert.assertTrue(message.contains("py_func_nest_python_raise_exception"), message); + Assert.assertTrue(message.contains("io.ray.runtime.task.TaskExecutor.execute"), message); + 
Assert.assertTrue(message.contains("py_func_python_raise_exception"), message); + Assert.assertTrue(message.contains("ZeroDivisionError: division by zero"), message); + return; + } + Assert.fail(); + } + + @Test + public void testThrowExceptionFromNestJava() { + ObjectRef res = Ray.task( + PyFunction.of(PYTHON_MODULE, "py_func_nest_java_throw_exception", Object.class)).remote(); + try { + res.get(); + } catch (RuntimeException ex) { + final String message = ex.getMessage(); + Assert.assertTrue(message.contains("py_func_nest_java_throw_exception"), message); + Assert.assertEquals(org.apache.commons.lang3.StringUtils + .countMatches(message, "io.ray.runtime.exception.RayTaskException"), 2); + Assert.assertTrue(message.contains("py_func_java_throw_exception"), message); + Assert.assertTrue(message.contains("java.lang.ArithmeticException: / by zero"), message); + return; + } + Assert.fail(); + } + public static Object[] pack(int i, String s, double f, Object[] o) { // This function will be called from test_cross_language_invocation.py return new Object[]{i, s, f, o}; @@ -227,6 +311,23 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest { return (byte[]) res.get(); } + @SuppressWarnings("ConstantOverflow") + public static Object throwException() { + return 1 / 0; + } + + public static Object throwJavaException() { + ObjectRef res = Ray.task( + PyFunction.of(PYTHON_MODULE, "py_func_java_throw_exception", Object.class)).remote(); + return res.get(); + } + + public static Object raisePythonException() { + ObjectRef res = Ray.task( + PyFunction.of(PYTHON_MODULE, "py_func_python_raise_exception", Object.class)).remote(); + return res.get(); + } + public static class TestActor { public TestActor(byte[] v) { diff --git a/java/test/src/main/java/io/ray/test/FailureTest.java b/java/test/src/main/java/io/ray/test/FailureTest.java index 803040080..0d25b16f0 100644 --- a/java/test/src/main/java/io/ray/test/FailureTest.java +++ 
b/java/test/src/main/java/io/ray/test/FailureTest.java @@ -3,11 +3,10 @@ package io.ray.test; import io.ray.api.ActorHandle; import io.ray.api.ObjectRef; import io.ray.api.Ray; -import io.ray.api.exception.RayActorException; -import io.ray.api.exception.RayException; -import io.ray.api.exception.RayTaskException; -import io.ray.api.exception.RayWorkerException; +import io.ray.runtime.exception.RayActorException; +import io.ray.runtime.exception.RayWorkerException; import io.ray.api.function.RayFunc0; +import io.ray.runtime.exception.RayTaskException; import java.time.Duration; import java.time.Instant; import java.util.Arrays; @@ -138,7 +137,7 @@ public class FailureTest extends BaseTest { try { Ray.get(Arrays.asList(obj1, obj2)); Assert.fail("Should throw RayException."); - } catch (RayException e) { + } catch (RuntimeException e) { Instant end = Instant.now(); long duration = Duration.between(start, end).toMillis(); Assert.assertTrue(duration < 5000, "Should fail quickly. " + diff --git a/java/test/src/main/java/io/ray/test/KillActorTest.java b/java/test/src/main/java/io/ray/test/KillActorTest.java index c7ac3d673..6e3a64391 100644 --- a/java/test/src/main/java/io/ray/test/KillActorTest.java +++ b/java/test/src/main/java/io/ray/test/KillActorTest.java @@ -4,7 +4,7 @@ import com.google.common.collect.ImmutableList; import io.ray.api.ActorHandle; import io.ray.api.ObjectRef; import io.ray.api.Ray; -import io.ray.api.exception.RayActorException; +import io.ray.runtime.exception.RayActorException; import java.util.function.BiConsumer; import org.testng.Assert; import org.testng.annotations.AfterClass; diff --git a/java/test/src/main/java/io/ray/test/MultiThreadingTest.java b/java/test/src/main/java/io/ray/test/MultiThreadingTest.java index d7d9f1425..bbd360637 100644 --- a/java/test/src/main/java/io/ray/test/MultiThreadingTest.java +++ b/java/test/src/main/java/io/ray/test/MultiThreadingTest.java @@ -5,7 +5,6 @@ import io.ray.api.ActorHandle; import 
io.ray.api.ObjectRef; import io.ray.api.Ray; import io.ray.api.WaitResult; -import io.ray.api.exception.RayException; import io.ray.api.id.ActorId; import java.util.ArrayList; import java.util.List; @@ -189,7 +188,7 @@ public class MultiThreadingTest extends BaseTest { try { // It wouldn't be OK to run them in another thread if not wrapped the runnable. for (Runnable runnable : runnables) { - Assert.expectThrows(RayException.class, runnable::run); + Assert.expectThrows(RuntimeException.class, runnable::run); } } catch (Throwable ex) { throwable[0] = ex; diff --git a/java/test/src/main/resources/test_cross_language_invocation.py b/java/test/src/main/resources/test_cross_language_invocation.py index f14bf1da2..372c6f6e8 100644 --- a/java/test/src/main/resources/test_cross_language_invocation.py +++ b/java/test/src/main/resources/test_cross_language_invocation.py @@ -83,6 +83,35 @@ def py_func_pass_python_actor_handle(): return ray.get(r) +@ray.remote +def py_func_python_raise_exception(): + 1 / 0 + + +@ray.remote +def py_func_java_throw_exception(): + f = ray.java_function("io.ray.test.CrossLanguageInvocationTest", + "throwException") + r = f.remote() + return ray.get(r) + + +@ray.remote +def py_func_nest_python_raise_exception(): + f = ray.java_function("io.ray.test.CrossLanguageInvocationTest", + "raisePythonException") + r = f.remote() + return ray.get(r) + + +@ray.remote +def py_func_nest_java_throw_exception(): + f = ray.java_function("io.ray.test.CrossLanguageInvocationTest", + "throwJavaException") + r = f.remote() + return ray.get(r) + + @ray.remote class Counter(object): def __init__(self, value): diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 1b435d81a..d5f4c1bd1 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -222,6 +222,9 @@ cdef class Language: cdef from_native(const CLanguage& lang): return Language(lang) + def value(self): + return self.lang + def __eq__(self, other): return (isinstance(other, Language) and 
(self.lang) == ((other).lang)) diff --git a/python/ray/exceptions.py b/python/ray/exceptions.py index 98abb9ceb..5b246afda 100644 --- a/python/ray/exceptions.py +++ b/python/ray/exceptions.py @@ -1,14 +1,44 @@ import os +from traceback import format_exception import colorama import ray +import ray.cloudpickle as pickle +from ray.core.generated.common_pb2 import RayException, Language import setproctitle class RayError(Exception): """Super class of all ray exception types.""" - pass + + def to_bytes(self): + # Extract exc_info from exception object. + exc_info = (type(self), self, self.__traceback__) + formatted_exception_string = "\n".join(format_exception(*exc_info)) + return RayException( + language=ray.Language.PYTHON.value(), + serialized_exception=pickle.dumps(self), + formatted_exception_string=formatted_exception_string + ).SerializeToString() + + @staticmethod + def from_bytes(b): + ray_exception = RayException() + ray_exception.ParseFromString(b) + if ray_exception.language == ray.Language.PYTHON.value(): + return pickle.loads(ray_exception.serialized_exception) + else: + return CrossLanguageError(ray_exception) + + +class CrossLanguageError(RayError): + """Raised from another language.""" + + def __init__(self, ray_exception): + super().__init__("An exception raised from {}:\n{}".format( + Language.Name(ray_exception.language), + ray_exception.formatted_exception_string)) class RayConnectionError(RayError): diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py index 1662489c0..cf9cc2e6d 100644 --- a/python/ray/gcs_utils.py +++ b/python/ray/gcs_utils.py @@ -1,19 +1,50 @@ +from ray.core.generated.common_pb2 import ErrorType from ray.core.generated.gcs_pb2 import ( - ActorCheckpointIdData, ActorTableData, GcsNodeInfo, JobTableData, - JobConfig, ErrorTableData, ErrorType, GcsEntry, HeartbeatBatchTableData, - HeartbeatTableData, ObjectTableData, ProfileTableData, TablePrefix, - TablePubsub, TaskTableData, ResourceMap, ResourceTableData, - 
ObjectLocationInfo, PubSubMessage, WorkerTableData, - PlacementGroupTableData) + ActorCheckpointIdData, + ActorTableData, + GcsNodeInfo, + JobTableData, + JobConfig, + ErrorTableData, + GcsEntry, + HeartbeatBatchTableData, + HeartbeatTableData, + ObjectTableData, + ProfileTableData, + TablePrefix, + TablePubsub, + TaskTableData, + ResourceMap, + ResourceTableData, + ObjectLocationInfo, + PubSubMessage, + WorkerTableData, + PlacementGroupTableData, +) __all__ = [ - "ActorCheckpointIdData", "ActorTableData", "GcsNodeInfo", "JobTableData", - "JobConfig", "ErrorTableData", "ErrorType", "GcsEntry", - "HeartbeatBatchTableData", "HeartbeatTableData", "ObjectTableData", - "ProfileTableData", "TablePrefix", "TablePubsub", "TaskTableData", - "ResourceMap", "ResourceTableData", "construct_error_message", - "ObjectLocationInfo", "PubSubMessage", "WorkerTableData", - "PlacementGroupTableData" + "ActorCheckpointIdData", + "ActorTableData", + "GcsNodeInfo", + "JobTableData", + "JobConfig", + "ErrorTableData", + "ErrorType", + "GcsEntry", + "HeartbeatBatchTableData", + "HeartbeatTableData", + "ObjectTableData", + "ProfileTableData", + "TablePrefix", + "TablePubsub", + "TaskTableData", + "ResourceMap", + "ResourceTableData", + "construct_error_message", + "ObjectLocationInfo", + "PubSubMessage", + "WorkerTableData", + "PlacementGroupTableData", ] FUNCTION_PREFIX = "RemoteFunction:" diff --git a/python/ray/serialization.py b/python/ray/serialization.py index cce2fcad3..faee81971 100644 --- a/python/ray/serialization.py +++ b/python/ray/serialization.py @@ -9,6 +9,7 @@ import ray.utils from ray.utils import _random_string from ray.gcs_utils import ErrorType from ray.exceptions import ( + RayError, PlasmaObjectNotAvailable, RayTaskError, RayActorError, @@ -221,10 +222,10 @@ class SerializationContext: def _deserialize_msgpack_data(self, data, metadata): msgpack_data, pickle5_data = split_buffer(data) - if metadata == ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE: - python_objects 
= [] - else: + if metadata == ray_constants.OBJECT_METADATA_TYPE_PYTHON: python_objects = self._deserialize_pickle5_data(pickle5_data) + else: + python_objects = [] try: @@ -262,8 +263,7 @@ class SerializationContext: # independent. if error_type == ErrorType.Value("TASK_EXECUTION_EXCEPTION"): obj = self._deserialize_msgpack_data(data, metadata) - assert isinstance(obj, RayTaskError) - return obj + return RayError.from_bytes(obj) elif error_type == ErrorType.Value("WORKER_DIED"): return RayWorkerError() elif error_type == ErrorType.Value("ACTOR_DIED"): @@ -347,7 +347,16 @@ class SerializationContext: metadata, inband, writer, self.get_and_clear_contained_object_refs()) - def _serialize_to_msgpack(self, metadata, value): + def _serialize_to_msgpack(self, value): + # Only RayTaskError is possible to be serialized here. We don't + # need to deal with other exception types here. + if isinstance(value, RayTaskError): + metadata = str( + ErrorType.Value("TASK_EXECUTION_EXCEPTION")).encode("ascii") + value = value.to_bytes() + else: + metadata = ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE + python_objects = [] def _python_serializer(o): @@ -358,10 +367,10 @@ class SerializationContext: msgpack_data = MessagePackSerializer.dumps(value, _python_serializer) if python_objects: + metadata = ray_constants.OBJECT_METADATA_TYPE_PYTHON pickle5_serialized_object = \ self._serialize_to_pickle5(metadata, python_objects) else: - metadata = ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE pickle5_serialized_object = None return MessagePackSerializedObject(metadata, msgpack_data, @@ -379,15 +388,7 @@ class SerializationContext: # that this object can also be read by Java. return RawSerializedObject(value) else: - # Only RayTaskError is possible to be serialized here. We don't - # need to deal with other exception types here. 
- if isinstance(value, RayTaskError): - metadata = str(ErrorType.Value( - "TASK_EXECUTION_EXCEPTION")).encode("ascii") - else: - metadata = ray_constants.OBJECT_METADATA_TYPE_PYTHON - - return self._serialize_to_msgpack(metadata, value) + return self._serialize_to_msgpack(value) def register_custom_serializer(self, cls, diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py index 6772a5937..742f9167c 100644 --- a/python/ray/tests/test_failure.py +++ b/python/ray/tests/test_failure.py @@ -12,6 +12,7 @@ import redis import ray import ray.ray_constants as ray_constants +from ray.exceptions import RayTaskError from ray.cluster_utils import Cluster from ray.test_utils import ( wait_for_condition, @@ -452,6 +453,22 @@ def test_actor_scope_or_intentionally_killed_message(ray_start_regular, errors) +def test_exception_chain(ray_start_regular): + @ray.remote + def bar(): + return 1 / 0 + + @ray.remote + def foo(): + return ray.get(bar.remote()) + + r = foo.remote() + try: + ray.get(r) + except ZeroDivisionError as ex: + assert isinstance(ex, RayTaskError) + + @pytest.mark.skip("This test does not work yet.") @pytest.mark.parametrize( "ray_start_object_store_memory", [10**6], indirect=True) diff --git a/src/ray/core_worker/lib/java/jni_init.cc b/src/ray/core_worker/lib/java/jni_init.cc index 30bd1b949..1213de4f6 100644 --- a/src/ray/core_worker/lib/java/jni_init.cc +++ b/src/ray/core_worker/lib/java/jni_init.cc @@ -167,7 +167,7 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { java_system_class = LoadClass(env, "java/lang/System"); java_system_gc = env->GetStaticMethodID(java_system_class, "gc", "()V"); - java_ray_exception_class = LoadClass(env, "io/ray/api/exception/RayException"); + java_ray_exception_class = LoadClass(env, "io/ray/runtime/exception/RayException"); java_jni_exception_util_class = LoadClass(env, "io/ray/runtime/util/JniExceptionUtil"); java_jni_exception_util_get_stack_trace = env->GetStaticMethodID( diff --git 
a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto index 6e9952093..1c622b602 100644 --- a/src/ray/protobuf/common.proto +++ b/src/ray/protobuf/common.proto @@ -100,6 +100,50 @@ message FunctionDescriptor { } } +// This enum type is used as object's metadata to indicate the object's +// creating task has failed because of a certain error. +// TODO(hchen): We may want to make these errors more specific. E.g., we may +// want to distinguish between intentional and expected actor failures, and +// between worker process failure and node failure. +enum ErrorType { + // Indicates that a task failed because the worker died unexpectedly while + // executing it. + WORKER_DIED = 0; + // Indicates that a task failed because the actor died unexpectedly before + // finishing it. + ACTOR_DIED = 1; + // Indicates that an object is lost and cannot be restarted. + // Note, this currently only happens to actor objects. When the actor's + // state is already after the object's creating task, the actor cannot + // re-run the task. + // TODO(hchen): we may want to reuse this error type for more cases. E.g., + // 1) An object that was put by the driver. + // 2) The object's creating task is already cleaned up from GCS (this + // currently crashes raylet). + OBJECT_UNRECONSTRUCTABLE = 2; + // Indicates that a task failed due to user code failure. + TASK_EXECUTION_EXCEPTION = 3; + // Indicates that the object has been placed in plasma. This error shouldn't + // ever be exposed to user code; it is only used internally to indicate the + // result of a direct call has been placed in plasma. + OBJECT_IN_PLASMA = 4; + // Indicates that an object has been cancelled. + TASK_CANCELLED = 5; + // Indicates that the GCS service failed to create the actor. + ACTOR_CREATION_FAILED = 6; +} + +/// The task exception encapsulates all information about task +/// execution exceptions. +message RayException { + // Language of this exception.
+ Language language = 1; + // The serialized exception. + bytes serialized_exception = 2; + // The formatted exception string. + string formatted_exception_string = 3; +} + /// The task specification encapsulates all immutable information about the /// task. These fields are determined at submission time, converse to the /// `TaskExecutionSpec` may change at execution time. diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto index 8a4fe4a9d..e8a4e0a0b 100644 --- a/src/ray/protobuf/gcs.proto +++ b/src/ray/protobuf/gcs.proto @@ -422,37 +422,3 @@ message PubSubMessage { bytes id = 1; bytes data = 2; } - -// This enum type is used as object's metadata to indicate the object's -// creating task has failed because of a certain error. -// TODO(hchen): We may want to make these errors more specific. E.g., we may -// want to distinguish between intentional and expected actor failures, and -// between worker process failure and node failure. -enum ErrorType { - // Indicates that a task failed because the worker died unexpectedly while - // executing it. - WORKER_DIED = 0; - // Indicates that a task failed because the actor died unexpectedly before - // finishing it. - ACTOR_DIED = 1; - // Indicates that an object is lost and cannot be restarted. - // Note, this currently only happens to actor objects. When the actor's - // state is already after the object's creating task, the actor cannot - // re-run the task. - // TODO(hchen): we may want to reuse this error type for more cases. E.g., - // 1) A object that was put by the driver. - // 2) The object's creating task is already cleaned up from GCS (this - // currently - // crashes raylet). - OBJECT_UNRECONSTRUCTABLE = 2; - // Indicates that a task failed due to user code failure. - TASK_EXECUTION_EXCEPTION = 3; - // Indicates that the object has been placed in plasma. 
This error shouldn't - // ever be exposed to user code; it is only used internally to indicate the - // result of a direct call has been placed in plasma. - OBJECT_IN_PLASMA = 4; - // Indicates that an object has been cancelled. - TASK_CANCELLED = 5; - // Inidicates that creating the GCS service failed to create the actor. - ACTOR_CREATION_FAILED = 6; -}