diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index 117153423..049e4a9c0 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -48,11 +48,6 @@ public abstract class AbstractRayRuntime implements RayRuntime { protected ObjectStoreProxy objectStoreProxy; protected FunctionManager functionManager; - /** - * Actor ID -> local actor instance. - */ - Map localActors = new HashMap<>(); - public AbstractRayRuntime(RayConfig rayConfig) { this.rayConfig = rayConfig; functionManager = new FunctionManager(rayConfig.driverResourcePath); diff --git a/java/runtime/src/main/java/org/ray/runtime/Worker.java b/java/runtime/src/main/java/org/ray/runtime/Worker.java index 929f343be..d97f4d2a6 100644 --- a/java/runtime/src/main/java/org/ray/runtime/Worker.java +++ b/java/runtime/src/main/java/org/ray/runtime/Worker.java @@ -1,5 +1,6 @@ package org.ray.runtime; +import com.google.common.base.Preconditions; import org.ray.api.exception.RayException; import org.ray.api.id.UniqueId; import org.ray.runtime.functionmanager.RayFunction; @@ -18,6 +19,21 @@ public class Worker { private final AbstractRayRuntime runtime; + /** + * The current actor object, if this worker is an actor, otherwise null. + */ + private Object currentActor = null; + + /** + * Id of the current actor object, if the worker is an actor, otherwise NIL. + */ + private UniqueId currentActorId = UniqueId.NIL; + + /** + * The exception that failed the actor creation task, if any. + */ + private Exception actorCreationException = null; + public Worker(AbstractRayRuntime runtime) { this.runtime = runtime; } @@ -46,7 +62,14 @@ public class Worker { runtime.getWorkerContext().setCurrentTask(spec, rayFunction.classLoader); Thread.currentThread().setContextClassLoader(rayFunction.classLoader); // Get local actor object and arguments. - Object actor = spec.isActorTask() ? runtime.localActors.get(spec.actorId) : null; + Object actor = null; + if (spec.isActorTask()) { + Preconditions.checkState(spec.actorId.equals(currentActorId)); + if (actorCreationException != null) { + throw actorCreationException; + } + actor = currentActor; + } Object[] args = ArgumentsBuilder.unwrap(spec, rayFunction.classLoader); // Execute the task. Object result; @@ -59,12 +82,18 @@ public class Worker { if (!spec.isActorCreationTask()) { runtime.put(returnId, result); } else { - runtime.localActors.put(returnId, result); + currentActor = result; + currentActorId = returnId; } LOGGER.info("Finished executing task {}", spec.taskId); } catch (Exception e) { LOGGER.error("Error executing task " + spec, e); - runtime.put(returnId, new RayException("Error executing task " + spec, e)); + if (!spec.isActorCreationTask()) { + runtime.put(returnId, new RayException("Error executing task " + spec, e)); + } else { + actorCreationException = e; + currentActorId = returnId; + } } finally { runtime.getWorkerContext().setCurrentTask(null, null); Thread.currentThread().setContextClassLoader(oldLoader); diff --git a/java/test/src/main/java/org/ray/api/test/FailureTest.java b/java/test/src/main/java/org/ray/api/test/FailureTest.java new file mode 100644 index 000000000..0e9d6da91 --- /dev/null +++ b/java/test/src/main/java/org/ray/api/test/FailureTest.java @@ -0,0 +1,63 @@ +package org.ray.api.test; + +import org.junit.Assert; +import org.junit.Test; +import org.ray.api.Ray; +import org.ray.api.RayActor; +import org.ray.api.RayObject; +import org.ray.api.exception.RayException; + +public class FailureTest extends BaseTest { + + private static final String EXCEPTION_MESSAGE = "Oops"; + + public static int badFunc() { + throw new RuntimeException(EXCEPTION_MESSAGE); + } + + public static class BadActor { + + public BadActor(boolean failOnCreation) { + if (failOnCreation) { + throw new RuntimeException(EXCEPTION_MESSAGE); + } + } + + public int func() { + throw new RuntimeException(EXCEPTION_MESSAGE); + } + } + + private static void assertTaskFail(RayObject rayObject) { + try { + rayObject.get(); + Assert.fail("Task didn't fail."); + } catch (RayException e) { + e.printStackTrace(); + Throwable rootCause = e.getCause(); + while (rootCause.getCause() != null) { + rootCause = rootCause.getCause(); + } + Assert.assertTrue(rootCause instanceof RuntimeException); + Assert.assertEquals(rootCause.getMessage(), EXCEPTION_MESSAGE); + } + } + + @Test + public void testNormalTaskFailure() { + assertTaskFail(Ray.call(FailureTest::badFunc)); + } + + @Test + public void testActorCreationFailure() { + RayActor actor = Ray.createActor(BadActor::new, true); + assertTaskFail(Ray.call(BadActor::func, actor)); + } + + @Test + public void testActorTaskFailure() { + RayActor actor = Ray.createActor(BadActor::new, false); + assertTaskFail(Ray.call(BadActor::func, actor)); + } +} +