Cross language exception (#10023)

This commit is contained in:
fyrestone
2020-08-26 10:46:05 +08:00
committed by GitHub
parent 1e99b814f0
commit 08adbb371f
30 changed files with 441 additions and 137 deletions
+1
View File
@@ -105,6 +105,7 @@ define_java_module(
":io_ray_ray_runtime",
"@maven//:com_google_code_gson_gson",
"@maven//:com_google_guava_guava",
"@maven//:com_google_protobuf_protobuf_java",
"@maven//:com_sun_xml_bind_jaxb_core",
"@maven//:com_sun_xml_bind_jaxb_impl",
"@maven//:commons_io_commons_io",
@@ -1,15 +0,0 @@
package io.ray.api.exception;
/**
* Base class of all ray exceptions.
*/
public class RayException extends RuntimeException {
public RayException(String message) {
super(message);
}
public RayException(String message, Throwable cause) {
super(message, cause);
}
}
@@ -1,15 +0,0 @@
package io.ray.api.exception;
/**
* Indicates that a task threw an exception during execution.
*
* If a task throws an exception during execution, a RayTaskException is stored in the object store
* as the task's output. Then when the object is retrieved from the object store, this exception
* will be thrown and propagate the error message.
*/
public class RayTaskException extends RayException {
public RayTaskException(String message, Throwable cause) {
super(message, cause);
}
}
@@ -8,7 +8,6 @@ import io.ray.api.BaseActorHandle;
import io.ray.api.ObjectRef;
import io.ray.api.PyActorHandle;
import io.ray.api.WaitResult;
import io.ray.api.exception.RayException;
import io.ray.api.function.PyActorClass;
import io.ray.api.function.PyActorMethod;
import io.ray.api.function.PyFunction;
@@ -81,7 +80,7 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
}
@Override
public <T> T get(ObjectRef<T> objectRef) throws RayException {
public <T> T get(ObjectRef<T> objectRef) throws RuntimeException {
List<T> ret = get(ImmutableList.of(objectRef));
return ret.get(0);
}
@@ -1,8 +1,8 @@
package io.ray.runtime;
import io.ray.api.exception.RayException;
import io.ray.api.runtime.RayRuntime;
import io.ray.runtime.config.RunMode;
import io.ray.runtime.exception.RayException;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
@@ -0,0 +1,18 @@
package io.ray.runtime.exception;
import io.ray.runtime.generated.Common.Language;
public class CrossLanguageException extends RayException {
private Language language;
public CrossLanguageException(io.ray.runtime.generated.Common.RayException exception) {
super(String.format("An exception raised from %s:\n%s", exception.getLanguage().name(),
exception.getFormattedExceptionString()));
this.language = exception.getLanguage();
}
public Language getLanguage() {
return this.language;
}
}
@@ -1,4 +1,4 @@
package io.ray.api.exception;
package io.ray.runtime.exception;
/**
* Indicates that the actor died unexpectedly before finishing a task.
@@ -0,0 +1,40 @@
package io.ray.runtime.exception;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import io.ray.runtime.generated.Common.Language;
import io.ray.runtime.serializer.Serializer;
public class RayException extends RuntimeException {
public RayException(String message) {
super(message);
}
public RayException(String message, Throwable cause) {
super(message, cause);
}
public byte[] toBytes() {
String formattedException = org.apache.commons.lang3.exception.ExceptionUtils
.getStackTrace(this);
io.ray.runtime.generated.Common.RayException.Builder builder =
io.ray.runtime.generated.Common.RayException.newBuilder();
builder.setLanguage(Language.JAVA);
builder.setFormattedExceptionString(formattedException);
builder.setSerializedException(ByteString.copyFrom(Serializer.encode(this).getLeft()));
return builder.build().toByteArray();
}
public static RayException fromBytes(byte[] serialized)
throws InvalidProtocolBufferException {
io.ray.runtime.generated.Common.RayException exception =
io.ray.runtime.generated.Common.RayException.parseFrom(serialized);
if (exception.getLanguage() == Language.JAVA) {
return Serializer
.decode(exception.getSerializedException().toByteArray(), RayException.class);
} else {
return new CrossLanguageException(exception);
}
}
}
@@ -0,0 +1,13 @@
package io.ray.runtime.exception;
import io.ray.runtime.util.NetworkUtil;
import io.ray.runtime.util.SystemUtil;
public class RayTaskException extends RayException {
public RayTaskException(String message, Throwable cause) {
super(String.format("(pid=%d, ip=%s) %s",
SystemUtil.pid(), NetworkUtil.getIpAddress(null), message), cause);
}
}
@@ -1,4 +1,4 @@
package io.ray.api.exception;
package io.ray.runtime.exception;
/**
* Indicates that the worker died unexpectedly while executing a task.
@@ -1,4 +1,4 @@
package io.ray.api.exception;
package io.ray.runtime.exception;
import io.ray.api.id.ObjectId;
@@ -1,11 +1,12 @@
package io.ray.runtime.object;
import io.ray.api.exception.RayActorException;
import io.ray.api.exception.RayTaskException;
import io.ray.api.exception.RayWorkerException;
import io.ray.api.exception.UnreconstructableException;
import com.google.protobuf.InvalidProtocolBufferException;
import io.ray.api.id.ObjectId;
import io.ray.runtime.generated.Gcs.ErrorType;
import io.ray.runtime.exception.RayActorException;
import io.ray.runtime.exception.RayTaskException;
import io.ray.runtime.exception.RayWorkerException;
import io.ray.runtime.exception.UnreconstructableException;
import io.ray.runtime.generated.Common.ErrorType;
import io.ray.runtime.serializer.Serializer;
import java.nio.ByteBuffer;
import java.util.ArrayList;
@@ -70,7 +71,21 @@ public class ObjectSerializer {
} else if (Arrays.equals(meta, UNRECONSTRUCTABLE_EXCEPTION_META)) {
return new UnreconstructableException(objectId);
} else if (Arrays.equals(meta, TASK_EXECUTION_EXCEPTION_META)) {
return Serializer.decode(data, objectType);
// Serialization logic of task execution exception: an instance of
// `io.ray.runtime.exception.RayTaskException`
// -> a `RayException` protobuf message
// -> protobuf-serialized bytes
// -> MessagePack-serialized bytes.
// So here the `data` variable is MessagePack-serialized bytes, and the `serialized`
// variable is protobuf-serialized bytes. They are not the same.
byte[] serialized = Serializer.decode(data, byte[].class);
try {
return RayTaskException.fromBytes(serialized);
} catch (InvalidProtocolBufferException e) {
throw new IllegalArgumentException(
"Can't deserialize RayTaskException object: " + objectId
.toString());
}
} else if (Arrays.equals(meta, OBJECT_METADATA_TYPE_PYTHON)) {
throw new IllegalArgumentException("Can't deserialize Python object: " + objectId
.toString());
@@ -107,7 +122,12 @@ public class ObjectSerializer {
}
return new NativeRayObject(bytes, OBJECT_METADATA_TYPE_RAW);
} else if (object instanceof RayTaskException) {
byte[] serializedBytes = Serializer.encode(object).getLeft();
RayTaskException taskException = (RayTaskException) object;
byte[] serializedBytes = Serializer.encode(taskException.toBytes()).getLeft();
// serializedBytes is MessagePack serialized bytes
// taskException.toBytes() is protobuf serialized bytes
// Only OBJECT_METADATA_TYPE_RAW is raw bytes,
// any other type should be the MessagePack serialized bytes.
return new NativeRayObject(serializedBytes, TASK_EXECUTION_EXCEPTION_META);
} else {
try {
@@ -3,10 +3,10 @@ package io.ray.runtime.object;
import com.google.common.base.Preconditions;
import io.ray.api.ObjectRef;
import io.ray.api.WaitResult;
import io.ray.api.exception.RayException;
import io.ray.api.id.ObjectId;
import io.ray.api.id.UniqueId;
import io.ray.runtime.context.WorkerContext;
import io.ray.runtime.exception.RayException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
@@ -160,8 +160,7 @@ public abstract class ObjectStore {
* Delete a list of objects from the object store.
*
* @param objectIds IDs of the objects to delete.
* @param localOnly Whether only delete the objects in local node, or all nodes in the
* cluster.
* @param localOnly Whether only delete the objects in local node, or all nodes in the cluster.
* @param deleteCreatingTasks Whether also delete the tasks that created these objects.
*/
public abstract void delete(List<ObjectId> objectIds, boolean localOnly,
@@ -169,6 +168,7 @@ public abstract class ObjectStore {
/**
* Increase the local reference count for this object ID.
*
* @param workerId The ID of the worker to increase on.
* @param objectId The object ID to increase the reference count for.
*/
@@ -176,6 +176,7 @@ public abstract class ObjectStore {
/**
* Decrease the reference count for this object ID.
*
* @param workerId The ID of the worker to decrease on.
* @param objectId The object ID to decrease the reference count for.
*/
@@ -1,12 +1,12 @@
package io.ray.runtime.task;
import com.google.common.base.Preconditions;
import io.ray.api.exception.RayTaskException;
import io.ray.api.id.ActorId;
import io.ray.api.id.JobId;
import io.ray.api.id.TaskId;
import io.ray.api.id.UniqueId;
import io.ray.runtime.RayRuntimeInternal;
import io.ray.runtime.exception.RayTaskException;
import io.ray.runtime.functionmanager.JavaFunctionDescriptor;
import io.ray.runtime.functionmanager.RayFunction;
import io.ray.runtime.generated.Common.TaskType;
@@ -3,7 +3,7 @@ package io.ray.test;
import io.ray.api.ActorHandle;
import io.ray.api.Checkpointable;
import io.ray.api.Ray;
import io.ray.api.exception.RayActorException;
import io.ray.runtime.exception.RayActorException;
import io.ray.api.id.ActorId;
import io.ray.api.id.UniqueId;
import io.ray.runtime.util.SystemUtil;
@@ -5,7 +5,7 @@ import io.ray.api.ActorHandle;
import io.ray.api.ObjectRef;
import io.ray.api.PyActorHandle;
import io.ray.api.Ray;
import io.ray.api.exception.UnreconstructableException;
import io.ray.runtime.exception.UnreconstructableException;
import io.ray.api.id.ActorId;
import io.ray.api.id.UniqueId;
import java.util.Collections;
@@ -11,6 +11,9 @@ import io.ray.api.function.PyActorMethod;
import io.ray.api.function.PyFunction;
import io.ray.runtime.actor.NativeActorHandle;
import io.ray.runtime.actor.NativePyActorHandle;
import io.ray.runtime.exception.CrossLanguageException;
import io.ray.runtime.exception.RayException;
import io.ray.runtime.generated.Common.Language;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
@@ -19,14 +22,11 @@ import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.apache.commons.io.FileUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.Assert;
import org.testng.annotations.Test;
public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
private static final Logger LOGGER = LoggerFactory.getLogger(CrossLanguageInvocationTest.class);
private static final String PYTHON_MODULE = "test_cross_language_invocation";
@Override
@@ -151,7 +151,6 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
PyFunction.of(PYTHON_MODULE, "py_func_call_java_actor", byte[].class),
"1".getBytes()).remote();
Assert.assertEquals(res.get(), "Counter1".getBytes());
}
@Test
@@ -188,6 +187,91 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
Assert.assertEquals(res.get(), "3".getBytes());
}
@Test
public void testExceptionSerialization() throws IOException {
try {
throw new RayException("Test Exception");
} catch (RayException e) {
String formattedException = org.apache.commons.lang3.exception.ExceptionUtils
.getStackTrace(e);
io.ray.runtime.generated.Common.RayException exception = io.ray.runtime.generated.Common.RayException
.parseFrom(e.toBytes());
Assert.assertEquals(exception.getFormattedExceptionString(), formattedException);
}
}
@Test
public void testRaiseExceptionFromPython() {
ObjectRef<Object> res = Ray.task(PyFunction.of(
PYTHON_MODULE, "py_func_python_raise_exception", Object.class)).remote();
try {
res.get();
} catch (RuntimeException ex) {
// ex is a Python exception(py_func_python_raise_exception) with no cause.
Assert.assertTrue(ex instanceof CrossLanguageException);
CrossLanguageException e = (CrossLanguageException) ex;
Assert.assertEquals(e.getLanguage(), Language.PYTHON);
// ex.cause is null.
Assert.assertNull(ex.getCause());
Assert.assertTrue(ex.getMessage().contains("ZeroDivisionError: division by zero"),
ex.getMessage());
return;
}
Assert.fail();
}
@Test
public void testThrowExceptionFromJava() {
ObjectRef<Object> res = Ray.task(PyFunction.of(
PYTHON_MODULE, "py_func_java_throw_exception", Object.class)).remote();
try {
res.get();
} catch (RuntimeException ex) {
final String message = ex.getMessage();
Assert.assertTrue(message.contains("py_func_java_throw_exception"), message);
Assert.assertTrue(message.contains("io.ray.test.CrossLanguageInvocationTest.throwException"),
message);
Assert.assertTrue(message.contains("java.lang.ArithmeticException: / by zero"), message);
return;
}
Assert.fail();
}
@Test
public void testRaiseExceptionFromNestPython() {
ObjectRef<Object> res = Ray.task(
PyFunction.of(PYTHON_MODULE, "py_func_nest_python_raise_exception", Object.class)).remote();
try {
res.get();
} catch (RuntimeException ex) {
final String message = ex.getMessage();
Assert.assertTrue(message.contains("py_func_nest_python_raise_exception"), message);
Assert.assertTrue(message.contains("io.ray.runtime.task.TaskExecutor.execute"), message);
Assert.assertTrue(message.contains("py_func_python_raise_exception"), message);
Assert.assertTrue(message.contains("ZeroDivisionError: division by zero"), message);
return;
}
Assert.fail();
}
@Test
public void testThrowExceptionFromNestJava() {
ObjectRef<Object> res = Ray.task(
PyFunction.of(PYTHON_MODULE, "py_func_nest_java_throw_exception", Object.class)).remote();
try {
res.get();
} catch (RuntimeException ex) {
final String message = ex.getMessage();
Assert.assertTrue(message.contains("py_func_nest_java_throw_exception"), message);
Assert.assertEquals(org.apache.commons.lang3.StringUtils
.countMatches(message, "io.ray.runtime.exception.RayTaskException"), 2);
Assert.assertTrue(message.contains("py_func_java_throw_exception"), message);
Assert.assertTrue(message.contains("java.lang.ArithmeticException: / by zero"), message);
return;
}
Assert.fail();
}
public static Object[] pack(int i, String s, double f, Object[] o) {
// This function will be called from test_cross_language_invocation.py
return new Object[]{i, s, f, o};
@@ -227,6 +311,23 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
return (byte[]) res.get();
}
@SuppressWarnings("ConstantOverflow")
public static Object throwException() {
return 1 / 0;
}
public static Object throwJavaException() {
ObjectRef<Object> res = Ray.task(
PyFunction.of(PYTHON_MODULE, "py_func_java_throw_exception", Object.class)).remote();
return res.get();
}
public static Object raisePythonException() {
ObjectRef<Object> res = Ray.task(
PyFunction.of(PYTHON_MODULE, "py_func_python_raise_exception", Object.class)).remote();
return res.get();
}
public static class TestActor {
public TestActor(byte[] v) {
@@ -3,11 +3,10 @@ package io.ray.test;
import io.ray.api.ActorHandle;
import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import io.ray.api.exception.RayActorException;
import io.ray.api.exception.RayException;
import io.ray.api.exception.RayTaskException;
import io.ray.api.exception.RayWorkerException;
import io.ray.runtime.exception.RayActorException;
import io.ray.runtime.exception.RayWorkerException;
import io.ray.api.function.RayFunc0;
import io.ray.runtime.exception.RayTaskException;
import java.time.Duration;
import java.time.Instant;
import java.util.Arrays;
@@ -138,7 +137,7 @@ public class FailureTest extends BaseTest {
try {
Ray.get(Arrays.asList(obj1, obj2));
Assert.fail("Should throw RayException.");
} catch (RayException e) {
} catch (RuntimeException e) {
Instant end = Instant.now();
long duration = Duration.between(start, end).toMillis();
Assert.assertTrue(duration < 5000, "Should fail quickly. " +
@@ -4,7 +4,7 @@ import com.google.common.collect.ImmutableList;
import io.ray.api.ActorHandle;
import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import io.ray.api.exception.RayActorException;
import io.ray.runtime.exception.RayActorException;
import java.util.function.BiConsumer;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
@@ -5,7 +5,6 @@ import io.ray.api.ActorHandle;
import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import io.ray.api.WaitResult;
import io.ray.api.exception.RayException;
import io.ray.api.id.ActorId;
import java.util.ArrayList;
import java.util.List;
@@ -189,7 +188,7 @@ public class MultiThreadingTest extends BaseTest {
try {
// It wouldn't be OK to run them in another thread if not wrapped the runnable.
for (Runnable runnable : runnables) {
Assert.expectThrows(RayException.class, runnable::run);
Assert.expectThrows(RuntimeException.class, runnable::run);
}
} catch (Throwable ex) {
throwable[0] = ex;
@@ -83,6 +83,35 @@ def py_func_pass_python_actor_handle():
return ray.get(r)
@ray.remote
def py_func_python_raise_exception():
1 / 0
@ray.remote
def py_func_java_throw_exception():
f = ray.java_function("io.ray.test.CrossLanguageInvocationTest",
"throwException")
r = f.remote()
return ray.get(r)
@ray.remote
def py_func_nest_python_raise_exception():
f = ray.java_function("io.ray.test.CrossLanguageInvocationTest",
"raisePythonException")
r = f.remote()
return ray.get(r)
@ray.remote
def py_func_nest_java_throw_exception():
f = ray.java_function("io.ray.test.CrossLanguageInvocationTest",
"throwJavaException")
r = f.remote()
return ray.get(r)
@ray.remote
class Counter(object):
def __init__(self, value):