Implement actor checkpointing (#3839)

* Implement Actor checkpointing

* docs

* fix

* fix

* fix

* move restore-from-checkpoint to HandleActorStateTransition

* Revert "move restore-from-checkpoint to HandleActorStateTransition"

This reverts commit 9aa4447c1e3e321f42a1d895d72f17098b72de12.

* resubmit waiting tasks when actor frontier restored

* add doc about num_actor_checkpoints_to_keep=1

* add num_actor_checkpoints_to_keep to Cython

* add checkpoint_expired api

* check if actor class is abstract

* change checkpoint_ids to long string

* implement java

* Refactor to delay actor creation publish until checkpoint is resumed

* debug, lint

* Erase from checkpoints to restore if task fails

* fix lint

* update comments

* avoid duplicated actor notification log

* fix unintended change

* add actor_id to checkpoint_expired

* small java updates

* make checkpoint info per actor

* lint

* Remove logging

* Remove old actor checkpointing Python code, move new checkpointing code to FunctionActionManager

* Replace old actor checkpointing tests

* Fix test and lint

* address comments

* consolidate kill_actor

* Remove __ray_checkpoint__

* fix non-ascii char

* Loosen test checks

* fix java

* fix sphinx-build
This commit is contained in:
Hao Chen
2019-02-13 19:39:02 +08:00
committed by GitHub
parent 57dcd3033e
commit f31a79f3f7
41 changed files with 1708 additions and 490 deletions
@@ -1,16 +1,26 @@
package org.ray.runtime;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import java.lang.reflect.Field;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.commons.lang3.ArrayUtils;
import org.ray.api.Checkpointable.Checkpoint;
import org.ray.api.id.UniqueId;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.config.WorkerMode;
import org.ray.runtime.gcs.RedisClient;
import org.ray.runtime.generated.ActorCheckpointIdData;
import org.ray.runtime.generated.TablePrefix;
import org.ray.runtime.objectstore.ObjectStoreProxy;
import org.ray.runtime.raylet.RayletClientImpl;
import org.ray.runtime.runner.RunManager;
import org.ray.runtime.util.UniqueIdUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -21,7 +31,14 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
private static final Logger LOGGER = LoggerFactory.getLogger(RayNativeRuntime.class);
private RedisClient redisClient = null;
/**
* Redis client of the primary shard.
*/
private RedisClient redisClient;
/**
* Redis clients of all shards.
*/
private List<RedisClient> redisClients;
private RunManager manager = null;
public RayNativeRuntime(RayConfig rayConfig) {
@@ -69,7 +86,8 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
manager = new RunManager(rayConfig);
manager.startRayProcesses(true);
}
redisClient = new RedisClient(rayConfig.getRedisAddress(), rayConfig.redisPassword);
initRedisClients();
// TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis.
objectStoreProxy = new ObjectStoreProxy(this, rayConfig.objectStoreSocketName);
@@ -88,6 +106,16 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
rayConfig.objectStoreSocketName, rayConfig.rayletSocketName);
}
/**
 * Connects to the primary Redis shard and to every shard address it advertises.
 *
 * <p>The expected shard count is read from the "NumRedisShards" key and the shard addresses
 * from the "RedisShards" list, both on the primary shard. Afterwards {@code redisClients}
 * holds one client per advertised shard, followed by the primary client itself.
 *
 * @throws IllegalStateException if the number of advertised addresses differs from the
 *     value stored under "NumRedisShards".
 */
private void initRedisClients() {
  redisClient = new RedisClient(rayConfig.getRedisAddress(), rayConfig.redisPassword);
  int numRedisShards = Integer.parseInt(redisClient.get("NumRedisShards", null));
  List<String> addresses = redisClient.lrange("RedisShards", 0, -1);
  Preconditions.checkState(numRedisShards == addresses.size(),
      "Expected %s Redis shards, but found %s shard addresses.",
      numRedisShards, addresses.size());
  // Collect into an explicit ArrayList: Collectors.toList() gives no guarantee that the
  // returned list is mutable, and the primary client is appended to it below.
  // The shard clients are created with the same password as the primary client, so that
  // password-protected deployments can reach every shard.
  redisClients = addresses.stream()
      .map(address -> new RedisClient(address, rayConfig.redisPassword))
      .collect(Collectors.toCollection(ArrayList::new));
  redisClients.add(redisClient);
}
@Override
public void shutdown() {
if (null != manager) {
@@ -116,4 +144,33 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
}
}
/**
 * Looks up the checkpoints recorded for the given actor.
 *
 * <p>The clients in {@code redisClients} are queried in order; the first one that returns a
 * non-null entry for the actor supplies the result and the remaining clients are not
 * consulted.
 *
 * @param actorId ID of the actor whose checkpoints are requested.
 * @return The available checkpoints, sorted by checkpoint timestamp in descending order.
 *     The list is empty if no client returned an entry.
 */
List<Checkpoint> getCheckpointsForActor(UniqueId actorId) {
  // TODO(hchen): implement the equivalent of Python's `GlobalState`, to avoid looping over
  // all redis shards.
  byte[] key = ArrayUtils.addAll(
      TablePrefix.name(TablePrefix.ACTOR_CHECKPOINT_ID).getBytes(), actorId.getBytes());
  List<Checkpoint> found = new ArrayList<>();
  for (RedisClient shard : redisClients) {
    byte[] raw = shard.get(key, null);
    if (raw != null) {
      ActorCheckpointIdData entry =
          ActorCheckpointIdData.getRootAsActorCheckpointIdData(ByteBuffer.wrap(raw));
      UniqueId[] ids =
          UniqueIdUtil.getUniqueIdsFromByteBuffer(entry.checkpointIdsAsByteBuffer());
      // The i-th timestamp belongs to the i-th checkpoint ID.
      int index = 0;
      for (UniqueId id : ids) {
        found.add(new Checkpoint(id, entry.timestamps(index)));
        index++;
      }
      // Stop at the first client that has an entry for this actor.
      break;
    }
  }
  // Newest checkpoint first.
  found.sort((left, right) -> Long.compare(right.timestamp, left.timestamp));
  return found;
}
}
@@ -1,8 +1,14 @@
package org.ray.runtime;
import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.List;
import org.ray.api.Checkpointable;
import org.ray.api.Checkpointable.Checkpoint;
import org.ray.api.Checkpointable.CheckpointContext;
import org.ray.api.exception.RayException;
import org.ray.api.id.UniqueId;
import org.ray.runtime.config.RunMode;
import org.ray.runtime.functionmanager.RayFunction;
import org.ray.runtime.task.ArgumentsBuilder;
import org.ray.runtime.task.TaskSpec;
@@ -17,6 +23,9 @@ public class Worker {
private static final Logger LOGGER = LoggerFactory.getLogger(Worker.class);
// TODO(hchen): Use the C++ config.
private static final int NUM_ACTOR_CHECKPOINTS_TO_KEEP = 20;
private final AbstractRayRuntime runtime;
/**
@@ -34,6 +43,22 @@ public class Worker {
*/
private Exception actorCreationException = null;
/**
 * Number of tasks executed since last actor checkpoint. Incremented by
 * {@code maybeSaveCheckpoint} and reset to 0 when a checkpoint is taken or when
 * {@code maybeLoadCheckpoint} initializes the checkpoint state.
 */
private int numTasksSinceLastCheckpoint = 0;
/**
 * IDs of this actor's previous checkpoints, in the order they were taken. The list is
 * allocated by {@code maybeLoadCheckpoint} and has no value before that;
 * {@code maybeSaveCheckpoint} appends to it and drops the first entry once it holds more
 * than {@code NUM_ACTOR_CHECKPOINTS_TO_KEEP} IDs.
 */
private List<UniqueId> checkpointIds;
/**
 * Timestamp of the last actor checkpoint, in milliseconds as returned by
 * {@code System.currentTimeMillis()}. Also set to the current time when
 * {@code maybeLoadCheckpoint} initializes the checkpoint state.
 */
private long lastCheckpointTimestamp = 0;
/**
 * Creates a worker that uses the given runtime.
 *
 * @param runtime The runtime stored in this worker's {@code runtime} field.
 */
public Worker(AbstractRayRuntime runtime) {
this.runtime = runtime;
}
@@ -80,8 +105,12 @@ public class Worker {
}
// Set result
if (!spec.isActorCreationTask()) {
if (spec.isActorTask()) {
maybeSaveCheckpoint(actor, spec.actorId);
}
runtime.put(returnId, result);
} else {
maybeLoadCheckpoint(result, returnId);
currentActor = result;
currentActorId = returnId;
}
@@ -98,4 +127,61 @@ public class Worker {
Thread.currentThread().setContextClassLoader(oldLoader);
}
}
/**
 * Gives a {@link Checkpointable} actor the opportunity to save a checkpoint.
 *
 * <p>Does nothing if the actor is not {@link Checkpointable} or if the runtime runs in
 * SINGLE_PROCESS mode. Otherwise the actor is asked whether it wants to checkpoint; if so, a
 * new checkpoint ID is obtained from the raylet client, the oldest checkpoint is reported as
 * expired once more than {@code NUM_ACTOR_CHECKPOINTS_TO_KEEP} IDs are held, and finally the
 * actor is told to save the new checkpoint.
 *
 * @param actor The actor instance that just executed a task.
 * @param actorId ID of that actor.
 */
private void maybeSaveCheckpoint(Object actor, UniqueId actorId) {
  // Actor checkpointing isn't implemented for SINGLE_PROCESS mode yet.
  if (!(actor instanceof Checkpointable)
      || runtime.getRayConfig().runMode == RunMode.SINGLE_PROCESS) {
    return;
  }
  Checkpointable target = (Checkpointable) actor;
  numTasksSinceLastCheckpoint += 1;
  long elapsedMs = System.currentTimeMillis() - lastCheckpointTimestamp;
  CheckpointContext context =
      new CheckpointContext(actorId, numTasksSinceLastCheckpoint, elapsedMs);
  if (target.shouldCheckpoint(context)) {
    // Restart the counters before the checkpoint itself is prepared.
    numTasksSinceLastCheckpoint = 0;
    lastCheckpointTimestamp = System.currentTimeMillis();
    UniqueId newCheckpointId = runtime.rayletClient.prepareCheckpoint(actorId);
    checkpointIds.add(newCheckpointId);
    if (checkpointIds.size() > NUM_ACTOR_CHECKPOINTS_TO_KEEP) {
      // Notify the actor about the oldest checkpoint first, then forget it.
      target.checkpointExpired(actorId, checkpointIds.get(0));
      checkpointIds.remove(0);
    }
    target.saveCheckpoint(actorId, newCheckpointId);
  }
}
/**
 * Lets a newly created {@link Checkpointable} actor resume from one of its checkpoints.
 *
 * <p>Does nothing if the actor is not {@link Checkpointable} or if the runtime runs in
 * SINGLE_PROCESS mode. Otherwise the checkpoint bookkeeping of this worker is reset, the
 * available checkpoints are fetched, and the actor chooses which one (if any) to load. A
 * non-null choice is validated and then reported to the raylet client.
 *
 * @param actor The actor instance that was just created.
 * @param actorId ID of that actor.
 * @throws IllegalArgumentException if the actor returns a checkpoint ID that is not among
 *     the available checkpoints.
 */
private void maybeLoadCheckpoint(Object actor, UniqueId actorId) {
  // Actor checkpointing isn't implemented for SINGLE_PROCESS mode yet.
  if (!(actor instanceof Checkpointable)
      || runtime.getRayConfig().runMode == RunMode.SINGLE_PROCESS) {
    return;
  }
  checkpointIds = new ArrayList<>();
  numTasksSinceLastCheckpoint = 0;
  lastCheckpointTimestamp = System.currentTimeMillis();

  RayNativeRuntime nativeRuntime = (RayNativeRuntime) runtime;
  List<Checkpoint> candidates = nativeRuntime.getCheckpointsForActor(actorId);
  if (candidates.isEmpty()) {
    return;
  }
  UniqueId chosenId = ((Checkpointable) actor).loadCheckpoint(actorId, candidates);
  if (chosenId == null) {
    // The actor decided not to resume from any checkpoint.
    return;
  }
  boolean chosenIdIsKnown = candidates.stream()
      .anyMatch(candidate -> candidate.checkpointId.equals(chosenId));
  Preconditions.checkArgument(chosenIdIsKnown,
      "'loadCheckpoint' must return a checkpoint ID that exists in the "
          + "'availableCheckpoints' list, or null.");
  runtime.rayletClient.notifyActorResumedFromCheckpoint(actorId, chosenId);
}
}
@@ -1,5 +1,6 @@
package org.ray.runtime.gcs;
import java.util.List;
import java.util.Map;
import org.ray.runtime.util.StringUtil;
@@ -77,7 +78,11 @@ public class RedisClient {
return jedis.hget(key, field);
}
}
}
/**
 * Runs {@code lrange} on a connection borrowed from the pool and returns its result.
 *
 * @param key Key of the Redis list.
 * @param start Start index passed to {@code lrange}.
 * @param end End index passed to {@code lrange}.
 * @return The elements returned by Jedis for the requested range.
 */
public List<String> lrange(String key, long start, long end) {
  // The connection goes back to the pool when the try block exits.
  try (Jedis connection = jedisPool.getResource()) {
    List<String> elements = connection.lrange(key, start, end);
    return elements;
  }
}
}
@@ -4,6 +4,7 @@ import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.lang3.NotImplementedException;
import org.ray.api.RayObject;
import org.ray.api.WaitResult;
import org.ray.api.id.UniqueId;
@@ -94,4 +95,14 @@ public class MockRayletClient implements RayletClient {
public void freePlasmaObjects(List<UniqueId> objectIds, boolean localOnly) {
// No-op: this mock client performs no action for this call.
return;
}
/**
 * Not supported by this mock client.
 *
 * @throws NotImplementedException always.
 */
@Override
public UniqueId prepareCheckpoint(UniqueId actorId) {
throw new NotImplementedException("Not implemented.");
}
/**
 * Not supported by this mock client.
 *
 * @throws NotImplementedException always.
 */
@Override
public void notifyActorResumedFromCheckpoint(UniqueId actorId, UniqueId checkpointId) {
throw new NotImplementedException("Not implemented.");
}
}
@@ -25,4 +25,8 @@ public interface RayletClient {
timeoutMs, UniqueId currentTaskId);
void freePlasmaObjects(List<UniqueId> objectIds, boolean localOnly);
/**
 * Prepares a checkpoint for the given actor.
 *
 * @param actorId ID of the actor to checkpoint.
 * @return The ID of the prepared checkpoint.
 */
UniqueId prepareCheckpoint(UniqueId actorId);
/**
 * Reports that an actor has resumed from one of its checkpoints.
 *
 * @param actorId ID of the actor that resumed.
 * @param checkpointId ID of the checkpoint the actor resumed from.
 */
void notifyActorResumedFromCheckpoint(UniqueId actorId, UniqueId checkpointId);
}
@@ -127,6 +127,17 @@ public class RayletClientImpl implements RayletClient {
nativeFreePlasmaObjects(client, objectIdsArray, localOnly);
}
/**
 * Asks the native raylet client to prepare a checkpoint for the given actor.
 *
 * @param actorId ID of the actor to checkpoint.
 * @return The checkpoint ID built from the bytes returned by the native call.
 */
@Override
public UniqueId prepareCheckpoint(UniqueId actorId) {
  byte[] checkpointIdBytes = nativePrepareCheckpoint(client, actorId.getBytes());
  return new UniqueId(checkpointIdBytes);
}
/**
 * Forwards an "actor resumed from checkpoint" notification to the native raylet client.
 *
 * @param actorId ID of the actor that resumed.
 * @param checkpointId ID of the checkpoint the actor resumed from; must not be null.
 */
@Override
public void notifyActorResumedFromCheckpoint(UniqueId actorId, UniqueId checkpointId) {
  byte[] actorIdBytes = actorId.getBytes();
  byte[] checkpointIdBytes = checkpointId.getBytes();
  nativeNotifyActorResumedFromCheckpoint(client, actorIdBytes, checkpointIdBytes);
}
private static TaskSpec parseTaskSpecFromFlatbuffer(ByteBuffer bb) {
bb.order(ByteOrder.LITTLE_ENDIAN);
TaskInfo info = TaskInfo.getRootAsTaskInfo(bb);
@@ -142,7 +153,7 @@ public class RayletClientImpl implements RayletClient {
// Deserialize new actor handles
UniqueId[] newActorHandles = UniqueIdUtil.getUniqueIdsFromByteBuffer(
info.newActorHandlesAsByteBuffer());
info.newActorHandlesAsByteBuffer());
// Deserialize args
FunctionArg[] args = new FunctionArg[info.argsLength()];
@@ -208,7 +219,7 @@ public class RayletClientImpl implements RayletClient {
int dataOffset = 0;
if (task.args[i].id != null) {
objectIdOffset = fbb.createString(
UniqueIdUtil.concatUniqueIds(new UniqueId[] {task.args[i].id}));
UniqueIdUtil.concatUniqueIds(new UniqueId[]{task.args[i].id}));
} else {
objectIdOffset = fbb.createString("");
}
@@ -258,7 +269,6 @@ public class RayletClientImpl implements RayletClient {
actorIdOffset,
actorHandleIdOffset,
actorCounter,
false,
newActorHandlesOffset,
argsOffset,
returnsOffset,
@@ -271,8 +281,8 @@ public class RayletClientImpl implements RayletClient {
if (buffer.remaining() > TASK_SPEC_BUFFER_SIZE) {
LOGGER.error(
"Allocated buffer is not enough to transfer the task specification: {}vs {}",
TASK_SPEC_BUFFER_SIZE, buffer.remaining());
"Allocated buffer is not enough to transfer the task specification: {} vs {}",
TASK_SPEC_BUFFER_SIZE, buffer.remaining());
throw new RuntimeException("Allocated buffer is not enough to transfer to task.");
}
return buffer;
@@ -323,4 +333,8 @@ public class RayletClientImpl implements RayletClient {
private static native void nativeFreePlasmaObjects(long conn, byte[][] objectIds,
boolean localOnly) throws RayException;
// Native counterpart of prepareCheckpoint: takes the connection handle and the raw actor ID
// bytes; the returned bytes are wrapped into a UniqueId by the caller.
private static native byte[] nativePrepareCheckpoint(long conn, byte[] actorId);
// Native counterpart of notifyActorResumedFromCheckpoint: takes the connection handle and
// the raw bytes of the actor ID and of the checkpoint ID.
private static native void nativeNotifyActorResumedFromCheckpoint(long conn, byte[] actorId,
byte[] checkpointId);
}