mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 21:08:50 +08:00
Implement actor checkpointing (#3839)
* Implement Actor checkpointing * docs * fix * fix * fix * move restore-from-checkpoint to HandleActorStateTransition * Revert "move restore-from-checkpoint to HandleActorStateTransition" This reverts commit 9aa4447c1e3e321f42a1d895d72f17098b72de12. * resubmit waiting tasks when actor frontier restored * add doc about num_actor_checkpoints_to_keep=1 * add num_actor_checkpoints_to_keep to Cython * add checkpoint_expired api * check if actor class is abstract * change checkpoint_ids to long string * implement java * Refactor to delay actor creation publish until checkpoint is resumed * debug, lint * Erase from checkpoints to restore if task fails * fix lint * update comments * avoid duplicated actor notification log * fix unintended change * add actor_id to checkpoint_expired * small java updates * make checkpoint info per actor * lint * Remove logging * Remove old actor checkpointing Python code, move new checkpointing code to FunctionActionManager * Replace old actor checkpointing tests * Fix test and lint * address comments * consolidate kill_actor * Remove __ray_checkpoint__ * fix non-ascii char * Loosen test checks * fix java * fix sphinx-build
This commit is contained in:
@@ -0,0 +1,99 @@
|
||||
package org.ray.api;
|
||||
|
||||
import java.util.List;
|
||||
import org.ray.api.id.UniqueId;
|
||||
|
||||
public interface Checkpointable {
|
||||
|
||||
class CheckpointContext {
|
||||
|
||||
/**
|
||||
* Actor's ID.
|
||||
*/
|
||||
public final UniqueId actorId;
|
||||
/**
|
||||
* Number of tasks executed since last checkpoint.
|
||||
*/
|
||||
public final int numTasksSinceLastCheckpoint;
|
||||
/**
|
||||
* Time elapsed since last checkpoint, in milliseconds.
|
||||
*/
|
||||
public final long timeElapsedMsSinceLastCheckpoint;
|
||||
|
||||
public CheckpointContext(UniqueId actorId, int numTasksSinceLastCheckpoint,
|
||||
long timeElapsedMsSinceLastCheckpoint) {
|
||||
this.actorId = actorId;
|
||||
this.numTasksSinceLastCheckpoint = numTasksSinceLastCheckpoint;
|
||||
this.timeElapsedMsSinceLastCheckpoint = timeElapsedMsSinceLastCheckpoint;
|
||||
}
|
||||
}
|
||||
|
||||
class Checkpoint {
|
||||
|
||||
/**
|
||||
* Checkpoint's ID.
|
||||
*/
|
||||
public final UniqueId checkpointId;
|
||||
/**
|
||||
* Checkpoint's timestamp.
|
||||
*/
|
||||
public final long timestamp;
|
||||
|
||||
public Checkpoint(UniqueId checkpointId, long timestamp) {
|
||||
this.checkpointId = checkpointId;
|
||||
this.timestamp = timestamp;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Whether this actor needs to be checkpointed.
|
||||
*
|
||||
* This method will be called after every task. You should implement this callback to decide
|
||||
* whether this actor needs to be checkpointed at this time, based on the checkpoint context, or
|
||||
* any other factors.
|
||||
*
|
||||
* @param checkpointContext An object that contains info about last checkpoint.
|
||||
* @return A boolean value that indicates whether this actor needs to be checkpointed.
|
||||
*/
|
||||
boolean shouldCheckpoint(CheckpointContext checkpointContext);
|
||||
|
||||
/**
|
||||
* Save a checkpoint to persistent storage.
|
||||
*
|
||||
* If `shouldCheckpoint` returns true, this method will be called. You should implement this
|
||||
* callback to save actor's checkpoint and the given checkpoint id to persistent storage.
|
||||
*
|
||||
* @param actorId Actor's ID.
|
||||
* @param checkpointId An ID that represents this actor's current state in GCS. You should
|
||||
* save this checkpoint ID together with actor's checkpoint data.
|
||||
*/
|
||||
void saveCheckpoint(UniqueId actorId, UniqueId checkpointId);
|
||||
|
||||
/**
|
||||
* Load actor's previous checkpoint, and restore actor's state.
|
||||
*
|
||||
* This method will be called when an actor is reconstructed, after actor's constructor. If the
|
||||
* actor needs to restore from previous checkpoint, this function should restore actor's state and
|
||||
* return the checkpoint ID. Otherwise, it should do nothing and return null.
|
||||
*
|
||||
* @param actorId Actor's ID.
|
||||
* @param availableCheckpoints A list of available checkpoint IDs and their timestamps, sorted
|
||||
* by timestamp in descending order. Note, this method must return the ID of one checkpoint in
|
||||
* this list, or null. Otherwise, an exception will be thrown.
|
||||
* @return The ID of the checkpoint from which the actor was resumed, or null if the actor should
|
||||
* restart from the beginning.
|
||||
*/
|
||||
UniqueId loadCheckpoint(UniqueId actorId, List<Checkpoint> availableCheckpoints);
|
||||
|
||||
/**
|
||||
* Delete an expired checkpoint;
|
||||
*
|
||||
* This method will be called when an checkpoint is expired. You should implement this method to
|
||||
* delete your application checkpoint data. Note, the maximum number of checkpoints kept in the
|
||||
* backend can be configured at `RayConfig.num_actor_checkpoints_to_keep`.
|
||||
*
|
||||
* @param actorId ID of the actor.
|
||||
* @param checkpointId ID of the checkpoint that has expired.
|
||||
*/
|
||||
void checkpointExpired(UniqueId actorId, UniqueId checkpointId);
|
||||
}
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
<suppressions>
|
||||
<suppress checks="OperatorWrap" files=".*" />
|
||||
<suppress checks="JavadocParagraph" files=".*" />
|
||||
<suppress checks="MemberNameCheck" files="PathConfig.java"/>
|
||||
<suppress checks="MemberNameCheck" files="RayParameters.java"/>
|
||||
<suppress checks="AbbreviationAsWordInNameCheck" files="RayParameters.java"/>
|
||||
|
||||
@@ -1,16 +1,26 @@
|
||||
package org.ray.runtime;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.base.Strings;
|
||||
import java.lang.reflect.Field;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.ray.api.Checkpointable.Checkpoint;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.config.WorkerMode;
|
||||
import org.ray.runtime.gcs.RedisClient;
|
||||
import org.ray.runtime.generated.ActorCheckpointIdData;
|
||||
import org.ray.runtime.generated.TablePrefix;
|
||||
import org.ray.runtime.objectstore.ObjectStoreProxy;
|
||||
import org.ray.runtime.raylet.RayletClientImpl;
|
||||
import org.ray.runtime.runner.RunManager;
|
||||
import org.ray.runtime.util.UniqueIdUtil;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
@@ -21,7 +31,14 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(RayNativeRuntime.class);
|
||||
|
||||
private RedisClient redisClient = null;
|
||||
/**
|
||||
* Redis client of the primary shard.
|
||||
*/
|
||||
private RedisClient redisClient;
|
||||
/**
|
||||
* Redis clients of all shards.
|
||||
*/
|
||||
private List<RedisClient> redisClients;
|
||||
private RunManager manager = null;
|
||||
|
||||
public RayNativeRuntime(RayConfig rayConfig) {
|
||||
@@ -69,7 +86,8 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
manager = new RunManager(rayConfig);
|
||||
manager.startRayProcesses(true);
|
||||
}
|
||||
redisClient = new RedisClient(rayConfig.getRedisAddress(), rayConfig.redisPassword);
|
||||
|
||||
initRedisClients();
|
||||
|
||||
// TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis.
|
||||
objectStoreProxy = new ObjectStoreProxy(this, rayConfig.objectStoreSocketName);
|
||||
@@ -88,6 +106,16 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
rayConfig.objectStoreSocketName, rayConfig.rayletSocketName);
|
||||
}
|
||||
|
||||
private void initRedisClients() {
|
||||
redisClient = new RedisClient(rayConfig.getRedisAddress(), rayConfig.redisPassword);
|
||||
int numRedisShards = Integer.valueOf(redisClient.get("NumRedisShards", null));
|
||||
List<String> addresses = redisClient.lrange("RedisShards", 0, -1);
|
||||
Preconditions.checkState(numRedisShards == addresses.size());
|
||||
redisClients = addresses.stream().map(RedisClient::new)
|
||||
.collect(Collectors.toList());
|
||||
redisClients.add(redisClient);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void shutdown() {
|
||||
if (null != manager) {
|
||||
@@ -116,4 +144,33 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the available checkpoints for the given actor ID, return a list sorted by checkpoint
|
||||
* timestamp in descending order.
|
||||
*/
|
||||
List<Checkpoint> getCheckpointsForActor(UniqueId actorId) {
|
||||
List<Checkpoint> checkpoints = new ArrayList<>();
|
||||
// TODO(hchen): implement the equivalent of Python's `GlobalState`, to avoid looping over
|
||||
// all redis shards..
|
||||
String prefix = TablePrefix.name(TablePrefix.ACTOR_CHECKPOINT_ID);
|
||||
byte[] key = ArrayUtils.addAll(prefix.getBytes(), actorId.getBytes());
|
||||
for (RedisClient client : redisClients) {
|
||||
byte[] result = client.get(key, null);
|
||||
if (result == null) {
|
||||
continue;
|
||||
}
|
||||
ActorCheckpointIdData data = ActorCheckpointIdData
|
||||
.getRootAsActorCheckpointIdData(ByteBuffer.wrap(result));
|
||||
|
||||
UniqueId[] checkpointIds
|
||||
= UniqueIdUtil.getUniqueIdsFromByteBuffer(data.checkpointIdsAsByteBuffer());
|
||||
|
||||
for (int i = 0; i < checkpointIds.length; i++) {
|
||||
checkpoints.add(new Checkpoint(checkpointIds[i], data.timestamps(i)));
|
||||
}
|
||||
break;
|
||||
}
|
||||
checkpoints.sort((x, y) -> Long.compare(y.timestamp, x.timestamp));
|
||||
return checkpoints;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,14 @@
|
||||
package org.ray.runtime;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.ray.api.Checkpointable;
|
||||
import org.ray.api.Checkpointable.Checkpoint;
|
||||
import org.ray.api.Checkpointable.CheckpointContext;
|
||||
import org.ray.api.exception.RayException;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.config.RunMode;
|
||||
import org.ray.runtime.functionmanager.RayFunction;
|
||||
import org.ray.runtime.task.ArgumentsBuilder;
|
||||
import org.ray.runtime.task.TaskSpec;
|
||||
@@ -17,6 +23,9 @@ public class Worker {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(Worker.class);
|
||||
|
||||
// TODO(hchen): Use the C++ config.
|
||||
private static final int NUM_ACTOR_CHECKPOINTS_TO_KEEP = 20;
|
||||
|
||||
private final AbstractRayRuntime runtime;
|
||||
|
||||
/**
|
||||
@@ -34,6 +43,22 @@ public class Worker {
|
||||
*/
|
||||
private Exception actorCreationException = null;
|
||||
|
||||
/**
|
||||
* Number of tasks executed since last actor checkpoint.
|
||||
*/
|
||||
private int numTasksSinceLastCheckpoint = 0;
|
||||
|
||||
/**
|
||||
* IDs of this actor's previous checkpoints.
|
||||
*/
|
||||
private List<UniqueId> checkpointIds;
|
||||
|
||||
/**
|
||||
* Timestamp of the last actor checkpoint.
|
||||
*/
|
||||
private long lastCheckpointTimestamp = 0;
|
||||
|
||||
|
||||
public Worker(AbstractRayRuntime runtime) {
|
||||
this.runtime = runtime;
|
||||
}
|
||||
@@ -80,8 +105,12 @@ public class Worker {
|
||||
}
|
||||
// Set result
|
||||
if (!spec.isActorCreationTask()) {
|
||||
if (spec.isActorTask()) {
|
||||
maybeSaveCheckpoint(actor, spec.actorId);
|
||||
}
|
||||
runtime.put(returnId, result);
|
||||
} else {
|
||||
maybeLoadCheckpoint(result, returnId);
|
||||
currentActor = result;
|
||||
currentActorId = returnId;
|
||||
}
|
||||
@@ -98,4 +127,61 @@ public class Worker {
|
||||
Thread.currentThread().setContextClassLoader(oldLoader);
|
||||
}
|
||||
}
|
||||
|
||||
private void maybeSaveCheckpoint(Object actor, UniqueId actorId) {
|
||||
if (!(actor instanceof Checkpointable)) {
|
||||
return;
|
||||
}
|
||||
if (runtime.getRayConfig().runMode == RunMode.SINGLE_PROCESS) {
|
||||
// Actor checkpointing isn't implemented for SINGLE_PROCESS mode yet.
|
||||
return;
|
||||
}
|
||||
CheckpointContext checkpointContext = new CheckpointContext(actorId,
|
||||
++numTasksSinceLastCheckpoint, System.currentTimeMillis() - lastCheckpointTimestamp);
|
||||
Checkpointable checkpointable = (Checkpointable) actor;
|
||||
if (!checkpointable.shouldCheckpoint(checkpointContext)) {
|
||||
return;
|
||||
}
|
||||
numTasksSinceLastCheckpoint = 0;
|
||||
lastCheckpointTimestamp = System.currentTimeMillis();
|
||||
UniqueId checkpointId = runtime.rayletClient.prepareCheckpoint(actorId);
|
||||
checkpointIds.add(checkpointId);
|
||||
if (checkpointIds.size() > NUM_ACTOR_CHECKPOINTS_TO_KEEP) {
|
||||
((Checkpointable) actor).checkpointExpired(actorId, checkpointIds.get(0));
|
||||
checkpointIds.remove(0);
|
||||
}
|
||||
checkpointable.saveCheckpoint(actorId, checkpointId);
|
||||
}
|
||||
|
||||
private void maybeLoadCheckpoint(Object actor, UniqueId actorId) {
|
||||
if (!(actor instanceof Checkpointable)) {
|
||||
return;
|
||||
}
|
||||
if (runtime.getRayConfig().runMode == RunMode.SINGLE_PROCESS) {
|
||||
// Actor checkpointing isn't implemented for SINGLE_PROCESS mode yet.
|
||||
return;
|
||||
}
|
||||
numTasksSinceLastCheckpoint = 0;
|
||||
lastCheckpointTimestamp = System.currentTimeMillis();
|
||||
checkpointIds = new ArrayList<>();
|
||||
List<Checkpoint> availableCheckpoints = ((RayNativeRuntime) runtime)
|
||||
.getCheckpointsForActor(actorId);
|
||||
if (availableCheckpoints.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
UniqueId checkpointId = ((Checkpointable) actor).loadCheckpoint(actorId, availableCheckpoints);
|
||||
if (checkpointId != null) {
|
||||
boolean checkpointValid = false;
|
||||
for (Checkpoint checkpoint : availableCheckpoints) {
|
||||
if (checkpoint.checkpointId.equals(checkpointId)) {
|
||||
checkpointValid = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
Preconditions.checkArgument(checkpointValid,
|
||||
"'loadCheckpoint' must return a checkpoint ID that exists in the "
|
||||
+ "'availableCheckpoints' list, or null.");
|
||||
runtime.rayletClient.notifyActorResumedFromCheckpoint(actorId, checkpointId);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package org.ray.runtime.gcs;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import org.ray.runtime.util.StringUtil;
|
||||
@@ -77,7 +78,11 @@ public class RedisClient {
|
||||
return jedis.hget(key, field);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public List<String> lrange(String key, long start, long end) {
|
||||
try (Jedis jedis = jedisPool.getResource()) {
|
||||
return jedis.lrange(key, start, end);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import com.google.common.collect.ImmutableList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import org.apache.commons.lang3.NotImplementedException;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.WaitResult;
|
||||
import org.ray.api.id.UniqueId;
|
||||
@@ -94,4 +95,14 @@ public class MockRayletClient implements RayletClient {
|
||||
public void freePlasmaObjects(List<UniqueId> objectIds, boolean localOnly) {
|
||||
return;
|
||||
}
|
||||
|
||||
@Override
|
||||
public UniqueId prepareCheckpoint(UniqueId actorId) {
|
||||
throw new NotImplementedException("Not implemented.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void notifyActorResumedFromCheckpoint(UniqueId actorId, UniqueId checkpointId) {
|
||||
throw new NotImplementedException("Not implemented.");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,4 +25,8 @@ public interface RayletClient {
|
||||
timeoutMs, UniqueId currentTaskId);
|
||||
|
||||
void freePlasmaObjects(List<UniqueId> objectIds, boolean localOnly);
|
||||
|
||||
UniqueId prepareCheckpoint(UniqueId actorId);
|
||||
|
||||
void notifyActorResumedFromCheckpoint(UniqueId actorId, UniqueId checkpointId);
|
||||
}
|
||||
|
||||
@@ -127,6 +127,17 @@ public class RayletClientImpl implements RayletClient {
|
||||
nativeFreePlasmaObjects(client, objectIdsArray, localOnly);
|
||||
}
|
||||
|
||||
@Override
|
||||
public UniqueId prepareCheckpoint(UniqueId actorId) {
|
||||
return new UniqueId(nativePrepareCheckpoint(client, actorId.getBytes()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void notifyActorResumedFromCheckpoint(UniqueId actorId, UniqueId checkpointId) {
|
||||
nativeNotifyActorResumedFromCheckpoint(client, actorId.getBytes(), checkpointId.getBytes());
|
||||
}
|
||||
|
||||
|
||||
private static TaskSpec parseTaskSpecFromFlatbuffer(ByteBuffer bb) {
|
||||
bb.order(ByteOrder.LITTLE_ENDIAN);
|
||||
TaskInfo info = TaskInfo.getRootAsTaskInfo(bb);
|
||||
@@ -142,7 +153,7 @@ public class RayletClientImpl implements RayletClient {
|
||||
|
||||
// Deserialize new actor handles
|
||||
UniqueId[] newActorHandles = UniqueIdUtil.getUniqueIdsFromByteBuffer(
|
||||
info.newActorHandlesAsByteBuffer());
|
||||
info.newActorHandlesAsByteBuffer());
|
||||
|
||||
// Deserialize args
|
||||
FunctionArg[] args = new FunctionArg[info.argsLength()];
|
||||
@@ -208,7 +219,7 @@ public class RayletClientImpl implements RayletClient {
|
||||
int dataOffset = 0;
|
||||
if (task.args[i].id != null) {
|
||||
objectIdOffset = fbb.createString(
|
||||
UniqueIdUtil.concatUniqueIds(new UniqueId[] {task.args[i].id}));
|
||||
UniqueIdUtil.concatUniqueIds(new UniqueId[]{task.args[i].id}));
|
||||
} else {
|
||||
objectIdOffset = fbb.createString("");
|
||||
}
|
||||
@@ -258,7 +269,6 @@ public class RayletClientImpl implements RayletClient {
|
||||
actorIdOffset,
|
||||
actorHandleIdOffset,
|
||||
actorCounter,
|
||||
false,
|
||||
newActorHandlesOffset,
|
||||
argsOffset,
|
||||
returnsOffset,
|
||||
@@ -271,8 +281,8 @@ public class RayletClientImpl implements RayletClient {
|
||||
|
||||
if (buffer.remaining() > TASK_SPEC_BUFFER_SIZE) {
|
||||
LOGGER.error(
|
||||
"Allocated buffer is not enough to transfer the task specification: {}vs {}",
|
||||
TASK_SPEC_BUFFER_SIZE, buffer.remaining());
|
||||
"Allocated buffer is not enough to transfer the task specification: {} vs {}",
|
||||
TASK_SPEC_BUFFER_SIZE, buffer.remaining());
|
||||
throw new RuntimeException("Allocated buffer is not enough to transfer to task.");
|
||||
}
|
||||
return buffer;
|
||||
@@ -323,4 +333,8 @@ public class RayletClientImpl implements RayletClient {
|
||||
private static native void nativeFreePlasmaObjects(long conn, byte[][] objectIds,
|
||||
boolean localOnly) throws RayException;
|
||||
|
||||
private static native byte[] nativePrepareCheckpoint(long conn, byte[] actorId);
|
||||
|
||||
private static native void nativeNotifyActorResumedFromCheckpoint(long conn, byte[] actorId,
|
||||
byte[] checkpointId);
|
||||
}
|
||||
|
||||
@@ -4,10 +4,13 @@ import static org.ray.runtime.util.SystemUtil.pid;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import org.ray.api.Checkpointable;
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.api.annotation.RayRemote;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.api.options.ActorCreationOptions;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
@@ -17,10 +20,10 @@ public class ActorReconstructionTest extends BaseTest {
|
||||
@RayRemote()
|
||||
public static class Counter {
|
||||
|
||||
private int value = 0;
|
||||
protected int value = 0;
|
||||
|
||||
public int increase(int delta) {
|
||||
value += delta;
|
||||
public int increase() {
|
||||
value += 1;
|
||||
return value;
|
||||
}
|
||||
|
||||
@@ -35,7 +38,7 @@ public class ActorReconstructionTest extends BaseTest {
|
||||
RayActor<Counter> actor = Ray.createActor(Counter::new, options);
|
||||
// Call increase 3 times.
|
||||
for (int i = 0; i < 3; i++) {
|
||||
Ray.call(Counter::increase, actor, 1).get();
|
||||
Ray.call(Counter::increase, actor).get();
|
||||
}
|
||||
|
||||
// Kill the actor process.
|
||||
@@ -45,7 +48,7 @@ public class ActorReconstructionTest extends BaseTest {
|
||||
TimeUnit.SECONDS.sleep(1);
|
||||
|
||||
// Try calling increase on this actor again and check the value is now 4.
|
||||
int value = Ray.call(Counter::increase, actor, 1).get();
|
||||
int value = Ray.call(Counter::increase, actor).get();
|
||||
Assert.assertEquals(value, 4);
|
||||
|
||||
// Kill the actor process again.
|
||||
@@ -55,7 +58,7 @@ public class ActorReconstructionTest extends BaseTest {
|
||||
|
||||
// Try calling increase on this actor again and this should fail.
|
||||
try {
|
||||
Ray.call(Counter::increase, actor, 1).get();
|
||||
Ray.call(Counter::increase, actor).get();
|
||||
Assert.fail("The above task didn't fail.");
|
||||
} catch (StringIndexOutOfBoundsException e) {
|
||||
// Raylet backend will put invalid data in task's result to indicate the task has failed.
|
||||
@@ -64,4 +67,71 @@ public class ActorReconstructionTest extends BaseTest {
|
||||
// instead of throwing this exception.
|
||||
}
|
||||
}
|
||||
|
||||
public static class CheckpointableCounter extends Counter implements Checkpointable {
|
||||
|
||||
private boolean resumedFromCheckpoint = false;
|
||||
private boolean increaseCalled = false;
|
||||
|
||||
@Override
|
||||
public int increase() {
|
||||
increaseCalled = true;
|
||||
return super.increase();
|
||||
}
|
||||
|
||||
public boolean wasResumedFromCheckpoint() {
|
||||
return resumedFromCheckpoint;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean shouldCheckpoint(CheckpointContext checkpointContext) {
|
||||
// Checkpoint the actor when value is increased to 3.
|
||||
boolean shouldCheckpoint = increaseCalled && value == 3;
|
||||
increaseCalled = false;
|
||||
return shouldCheckpoint;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void saveCheckpoint(UniqueId actorId, UniqueId checkpointId) {
|
||||
// In practice, user should save the checkpoint id and data to a persistent store.
|
||||
// But for simplicity, we don't do that in this unit test.
|
||||
}
|
||||
|
||||
@Override
|
||||
public UniqueId loadCheckpoint(UniqueId actorId, List<Checkpoint> availableCheckpoints) {
|
||||
// Restore previous value and return checkpoint id.
|
||||
this.value = 3;
|
||||
this.resumedFromCheckpoint = true;
|
||||
return availableCheckpoints.get(availableCheckpoints.size() - 1).checkpointId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void checkpointExpired(UniqueId actorId, UniqueId checkpointId) {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testActorCheckpointing() throws IOException, InterruptedException {
|
||||
ActorCreationOptions options = new ActorCreationOptions(new HashMap<>(), 1);
|
||||
RayActor<CheckpointableCounter> actor = Ray.createActor(CheckpointableCounter::new, options);
|
||||
// Call increase 3 times.
|
||||
for (int i = 0; i < 3; i++) {
|
||||
Ray.call(CheckpointableCounter::increase, actor).get();
|
||||
}
|
||||
// Assert that the actor wasn't resumed from a checkpoint.
|
||||
Assert.assertFalse(Ray.call(CheckpointableCounter::wasResumedFromCheckpoint, actor).get());
|
||||
|
||||
// Kill the actor process.
|
||||
int pid = Ray.call(CheckpointableCounter::getPid, actor).get();
|
||||
Runtime.getRuntime().exec("kill -9 " + pid);
|
||||
// Wait for the actor to be killed.
|
||||
TimeUnit.SECONDS.sleep(1);
|
||||
|
||||
// Try calling increase on this actor again and check the value is now 4.
|
||||
int value = Ray.call(CheckpointableCounter::increase, actor).get();
|
||||
Assert.assertEquals(value, 4);
|
||||
// Assert that the actor was resumed from a checkpoint.
|
||||
Assert.assertTrue(Ray.call(CheckpointableCounter::wasResumedFromCheckpoint, actor).get());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user