From f5412c0417b32fe483499409bfdb734272c0642c Mon Sep 17 00:00:00 2001 From: Qing Wang Date: Thu, 31 Dec 2020 19:47:35 +0800 Subject: [PATCH] [Java] Avoid failure of serializing a user-defined unserializable exception. (#13119) --- .../runtime/exception/RayTaskException.java | 4 ++ .../io/ray/runtime/task/TaskExecutor.java | 24 +++++++-- .../java/io/ray/test/TaskExceptionTest.java | 49 +++++++++++++++++++ 3 files changed, 74 insertions(+), 3 deletions(-) create mode 100644 java/test/src/main/java/io/ray/test/TaskExceptionTest.java diff --git a/java/runtime/src/main/java/io/ray/runtime/exception/RayTaskException.java b/java/runtime/src/main/java/io/ray/runtime/exception/RayTaskException.java index 3d382d98b..e32bf314f 100644 --- a/java/runtime/src/main/java/io/ray/runtime/exception/RayTaskException.java +++ b/java/runtime/src/main/java/io/ray/runtime/exception/RayTaskException.java @@ -5,6 +5,10 @@ import io.ray.runtime.util.SystemUtil; public class RayTaskException extends RayException { + public RayTaskException(String message) { + super(message); + } + public RayTaskException(String message, Throwable cause) { super( String.format( diff --git a/java/runtime/src/main/java/io/ray/runtime/task/TaskExecutor.java b/java/runtime/src/main/java/io/ray/runtime/task/TaskExecutor.java index 99a5cf7c8..b7b707c63 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/TaskExecutor.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/TaskExecutor.java @@ -17,6 +17,7 @@ import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.List; import java.util.concurrent.ConcurrentHashMap; +import org.apache.commons.lang3.exception.ExceptionUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -159,9 +160,26 @@ public abstract class TaskExecutor { boolean hasReturn = rayFunction != null && rayFunction.hasReturn(); boolean isCrossLanguage = parseFunctionDescriptor(rayFunctionInfo).signature.equals(""); if (hasReturn || isCrossLanguage) { - returnObjects.add( - ObjectSerializer.serialize( - new RayTaskException("Error executing task " + taskId, e))); + NativeRayObject serializedException; + try { + serializedException = + ObjectSerializer.serialize( + new RayTaskException("Error executing task " + taskId, e)); + } catch (Exception unserializable) { + // We should try-catch `ObjectSerializer.serialize` here. Because otherwise if the + // application-level exception is not serializable. `ObjectSerializer.serialize` + // will throw an exception and crash the worker. + // Refer to the case `TaskExceptionTest.java` for more details. + LOGGER.warn("Failed to serialize the exception to a RayObject.", unserializable); + serializedException = + ObjectSerializer.serialize( + new RayTaskException( + String.format( + "Error executing task %s with the exception: %s", + taskId, ExceptionUtils.getStackTrace(e)))); + } + Preconditions.checkNotNull(serializedException); + returnObjects.add(serializedException); } } else { actorContext.actorCreationException = e; diff --git a/java/test/src/main/java/io/ray/test/TaskExceptionTest.java b/java/test/src/main/java/io/ray/test/TaskExceptionTest.java new file mode 100644 index 000000000..3a3d84754 --- /dev/null +++ b/java/test/src/main/java/io/ray/test/TaskExceptionTest.java @@ -0,0 +1,49 @@ +package io.ray.test; + +import io.ray.api.ActorHandle; +import io.ray.api.Ray; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class TaskExceptionTest extends BaseTest { + + private static class UnserializableClass {} + + private static class UnserializableException extends RuntimeException { + + public UnserializableException() { + super(); + } + + private UnserializableClass unSerializableClass = new UnserializableClass(); + } + + private static class MyActor { + + public String sayHi() { + return "Hi"; + } + + public String throwUnserializableException() { + throw new UnserializableException(); + } + } + + private static String throwUnserializableException() { + throw new UnserializableException(); + } + + @Test + public void testThrowUnserializableExceptionInNormalTask() { + // Test that if a task throws an unserializable exception, the worker won't crash. + Assert.assertThrows( + (() -> Ray.task(TaskExceptionTest::throwUnserializableException).remote().get())); + } + + @Test + public void testThrowUnserializableExceptionInActorTask() { + ActorHandle myActor = Ray.actor(MyActor::new).remote(); + Assert.assertEquals("Hi", myActor.task(MyActor::sayHi).remote().get()); + Assert.assertThrows((() -> myActor.task(MyActor::throwUnserializableException).remote().get())); + } +}