diff --git a/java/api/src/main/java/io/ray/api/runtimecontext/RuntimeContext.java b/java/api/src/main/java/io/ray/api/runtimecontext/RuntimeContext.java index 705497ca0..b5fa486aa 100644 --- a/java/api/src/main/java/io/ray/api/runtimecontext/RuntimeContext.java +++ b/java/api/src/main/java/io/ray/api/runtimecontext/RuntimeContext.java @@ -17,11 +17,7 @@ public interface RuntimeContext { */ ActorId getCurrentActorId(); - /** - * Returns true if the current actor was restarted, false if it's created for the first time. - * - *
Note, this method should only be called from an actor creation task.
- */
+ /** Returns true if the current actor was restarted, otherwise false. */
boolean wasCurrentActorRestarted();
/**
diff --git a/java/runtime/src/main/java/io/ray/runtime/context/RuntimeContextImpl.java b/java/runtime/src/main/java/io/ray/runtime/context/RuntimeContextImpl.java
index 2fec99d49..913586f77 100644
--- a/java/runtime/src/main/java/io/ray/runtime/context/RuntimeContextImpl.java
+++ b/java/runtime/src/main/java/io/ray/runtime/context/RuntimeContextImpl.java
@@ -7,7 +7,6 @@ import io.ray.api.runtimecontext.NodeInfo;
import io.ray.api.runtimecontext.RuntimeContext;
import io.ray.runtime.RayRuntimeInternal;
import io.ray.runtime.config.RunMode;
-import io.ray.runtime.generated.Common.TaskType;
import java.util.List;
public class RuntimeContextImpl implements RuntimeContext {
@@ -33,14 +32,9 @@ public class RuntimeContextImpl implements RuntimeContext {
@Override
public boolean wasCurrentActorRestarted() {
- TaskType currentTaskType = runtime.getWorkerContext().getCurrentTaskType();
- Preconditions.checkState(
- currentTaskType == TaskType.ACTOR_CREATION_TASK,
- "This method can only be called from an actor creation task.");
if (isSingleProcess()) {
return false;
}
-
return runtime.getGcsClient().wasCurrentActorRestarted(getCurrentActorId());
}
diff --git a/java/test/src/main/java/io/ray/test/ActorRestartTest.java b/java/test/src/main/java/io/ray/test/ActorRestartTest.java
index 577487e6e..fe70e0867 100644
--- a/java/test/src/main/java/io/ray/test/ActorRestartTest.java
+++ b/java/test/src/main/java/io/ray/test/ActorRestartTest.java
@@ -22,7 +22,7 @@ public class ActorRestartTest extends BaseTest {
wasCurrentActorRestarted = Ray.getRuntimeContext().wasCurrentActorRestarted();
}
- public boolean wasCurrentActorRestarted() {
+ public boolean checkWasCurrentActorRestartedInActorCreationTask() {
return wasCurrentActorRestarted;
}
@@ -31,6 +31,10 @@ public class ActorRestartTest extends BaseTest {
return value;
}
+ public boolean checkWasCurrentActorRestartedInActorTask() {
+ return Ray.getRuntimeContext().wasCurrentActorRestarted();
+ }
+
public int getPid() {
return SystemUtil.pid();
}
@@ -43,30 +47,38 @@ public class ActorRestartTest extends BaseTest {
actor.task(Counter::increase).remote().get();
}
- Assert.assertFalse(actor.task(Counter::wasCurrentActorRestarted).remote().get());
+ // Check if actor was restarted.
+ Assert.assertFalse(
+ actor.task(Counter::checkWasCurrentActorRestartedInActorCreationTask).remote().get());
+ Assert.assertFalse(
+ actor.task(Counter::checkWasCurrentActorRestartedInActorTask).remote().get());
// Kill the actor process.
- int pid = actor.task(Counter::getPid).remote().get();
- Runtime.getRuntime().exec("kill -9 " + pid);
- // Wait for the actor to be killed.
- TimeUnit.SECONDS.sleep(1);
+ killActorProcess(actor);
int value = actor.task(Counter::increase).remote().get();
Assert.assertEquals(value, 1);
- Assert.assertTrue(actor.task(Counter::wasCurrentActorRestarted).remote().get());
+ // Check if actor was restarted again.
+ Assert.assertTrue(
+ actor.task(Counter::checkWasCurrentActorRestartedInActorCreationTask).remote().get());
+ Assert.assertTrue(actor.task(Counter::checkWasCurrentActorRestartedInActorTask).remote().get());
// Kill the actor process again.
- pid = actor.task(Counter::getPid).remote().get();
- Runtime.getRuntime().exec("kill -9 " + pid);
- TimeUnit.SECONDS.sleep(1);
+ killActorProcess(actor);
// Try calling increase on this actor again and this should fail.
- try {
- actor.task(Counter::increase).remote().get();
- Assert.fail("The above task didn't fail.");
- } catch (RayActorException e) {
- // We should receive a RayActorException because the actor is dead.
- }
+ Assert.assertThrows(
+ RayActorException.class, () -> actor.task(Counter::increase).remote().get());
+ }
+
+ /** The helper to kill a counter actor. */
+ private static void killActorProcess(ActorHandle