diff --git a/java/api/src/main/java/io/ray/api/runtimecontext/RuntimeContext.java b/java/api/src/main/java/io/ray/api/runtimecontext/RuntimeContext.java index 705497ca0..b5fa486aa 100644 --- a/java/api/src/main/java/io/ray/api/runtimecontext/RuntimeContext.java +++ b/java/api/src/main/java/io/ray/api/runtimecontext/RuntimeContext.java @@ -17,11 +17,7 @@ public interface RuntimeContext { */ ActorId getCurrentActorId(); - /** - * Returns true if the current actor was restarted, false if it's created for the first time. - * - *

Note, this method should only be called from an actor creation task. - */ + /** Returns true if the current actor was restarted, otherwise false. */ boolean wasCurrentActorRestarted(); /** diff --git a/java/runtime/src/main/java/io/ray/runtime/context/RuntimeContextImpl.java b/java/runtime/src/main/java/io/ray/runtime/context/RuntimeContextImpl.java index 2fec99d49..913586f77 100644 --- a/java/runtime/src/main/java/io/ray/runtime/context/RuntimeContextImpl.java +++ b/java/runtime/src/main/java/io/ray/runtime/context/RuntimeContextImpl.java @@ -7,7 +7,6 @@ import io.ray.api.runtimecontext.NodeInfo; import io.ray.api.runtimecontext.RuntimeContext; import io.ray.runtime.RayRuntimeInternal; import io.ray.runtime.config.RunMode; -import io.ray.runtime.generated.Common.TaskType; import java.util.List; public class RuntimeContextImpl implements RuntimeContext { @@ -33,14 +32,9 @@ public class RuntimeContextImpl implements RuntimeContext { @Override public boolean wasCurrentActorRestarted() { - TaskType currentTaskType = runtime.getWorkerContext().getCurrentTaskType(); - Preconditions.checkState( - currentTaskType == TaskType.ACTOR_CREATION_TASK, - "This method can only be called from an actor creation task."); if (isSingleProcess()) { return false; } - return runtime.getGcsClient().wasCurrentActorRestarted(getCurrentActorId()); } diff --git a/java/test/src/main/java/io/ray/test/ActorRestartTest.java b/java/test/src/main/java/io/ray/test/ActorRestartTest.java index 577487e6e..fe70e0867 100644 --- a/java/test/src/main/java/io/ray/test/ActorRestartTest.java +++ b/java/test/src/main/java/io/ray/test/ActorRestartTest.java @@ -22,7 +22,7 @@ public class ActorRestartTest extends BaseTest { wasCurrentActorRestarted = Ray.getRuntimeContext().wasCurrentActorRestarted(); } - public boolean wasCurrentActorRestarted() { + public boolean checkWasCurrentActorRestartedInActorCreationTask() { return wasCurrentActorRestarted; } @@ -31,6 +31,10 @@ public class ActorRestartTest extends BaseTest { return value; } + public boolean checkWasCurrentActorRestartedInActorTask() { + return Ray.getRuntimeContext().wasCurrentActorRestarted(); + } + public int getPid() { return SystemUtil.pid(); } @@ -43,30 +47,38 @@ public class ActorRestartTest extends BaseTest { actor.task(Counter::increase).remote().get(); } - Assert.assertFalse(actor.task(Counter::wasCurrentActorRestarted).remote().get()); + // Check if actor was restarted. + Assert.assertFalse( + actor.task(Counter::checkWasCurrentActorRestartedInActorCreationTask).remote().get()); + Assert.assertFalse( + actor.task(Counter::checkWasCurrentActorRestartedInActorTask).remote().get()); // Kill the actor process. - int pid = actor.task(Counter::getPid).remote().get(); - Runtime.getRuntime().exec("kill -9 " + pid); - // Wait for the actor to be killed. - TimeUnit.SECONDS.sleep(1); + killActorProcess(actor); int value = actor.task(Counter::increase).remote().get(); Assert.assertEquals(value, 1); - Assert.assertTrue(actor.task(Counter::wasCurrentActorRestarted).remote().get()); + // Check if actor was restarted again. + Assert.assertTrue( + actor.task(Counter::checkWasCurrentActorRestartedInActorCreationTask).remote().get()); + Assert.assertTrue(actor.task(Counter::checkWasCurrentActorRestartedInActorTask).remote().get()); // Kill the actor process again. - pid = actor.task(Counter::getPid).remote().get(); - Runtime.getRuntime().exec("kill -9 " + pid); - TimeUnit.SECONDS.sleep(1); + killActorProcess(actor); // Try calling increase on this actor again and this should fail. - try { - actor.task(Counter::increase).remote().get(); - Assert.fail("The above task didn't fail."); - } catch (RayActorException e) { - // We should receive a RayActorException because the actor is dead. - } + Assert.assertThrows( + RayActorException.class, () -> actor.task(Counter::increase).remote().get()); + } + + /** The helper to kill a counter actor. */ + private static void killActorProcess(ActorHandle actor) + throws IOException, InterruptedException { + // Kill the actor process. + int pid = actor.task(Counter::getPid).remote().get(); + Process p = Runtime.getRuntime().exec("kill -9 " + pid); + // Wait for the actor to be killed. + TimeUnit.SECONDS.sleep(1); } }