diff --git a/doc/source/conf.py b/doc/source/conf.py index 69e173768..a4f9aebd4 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -36,6 +36,7 @@ MOCK_MODULES = [ "tensorflow.python.client", "tensorflow.python.util", "ray.core.generated", + "ray.core.generated.ActorCheckpointIdData", "ray.core.generated.ClientTableData", "ray.core.generated.GcsTableEntry", "ray.core.generated.HeartbeatTableData", diff --git a/java/api/src/main/java/org/ray/api/Checkpointable.java b/java/api/src/main/java/org/ray/api/Checkpointable.java new file mode 100644 index 000000000..df3404ddb --- /dev/null +++ b/java/api/src/main/java/org/ray/api/Checkpointable.java @@ -0,0 +1,99 @@ +package org.ray.api; + +import java.util.List; +import org.ray.api.id.UniqueId; + +public interface Checkpointable { + + class CheckpointContext { + + /** + * Actor's ID. + */ + public final UniqueId actorId; + /** + * Number of tasks executed since last checkpoint. + */ + public final int numTasksSinceLastCheckpoint; + /** + * Time elapsed since last checkpoint, in milliseconds. + */ + public final long timeElapsedMsSinceLastCheckpoint; + + public CheckpointContext(UniqueId actorId, int numTasksSinceLastCheckpoint, + long timeElapsedMsSinceLastCheckpoint) { + this.actorId = actorId; + this.numTasksSinceLastCheckpoint = numTasksSinceLastCheckpoint; + this.timeElapsedMsSinceLastCheckpoint = timeElapsedMsSinceLastCheckpoint; + } + } + + class Checkpoint { + + /** + * Checkpoint's ID. + */ + public final UniqueId checkpointId; + /** + * Checkpoint's timestamp. + */ + public final long timestamp; + + public Checkpoint(UniqueId checkpointId, long timestamp) { + this.checkpointId = checkpointId; + this.timestamp = timestamp; + } + } + + /** + * Whether this actor needs to be checkpointed. + * + * This method will be called after every task. You should implement this callback to decide + * whether this actor needs to be checkpointed at this time, based on the checkpoint context, or + * any other factors. 
+ * + * @param checkpointContext An object that contains info about last checkpoint. + * @return A boolean value that indicates whether this actor needs to be checkpointed. + */ + boolean shouldCheckpoint(CheckpointContext checkpointContext); + + /** + * Save a checkpoint to persistent storage. + * + * If `shouldCheckpoint` returns true, this method will be called. You should implement this + * callback to save actor's checkpoint and the given checkpoint id to persistent storage. + * + * @param actorId Actor's ID. + * @param checkpointId An ID that represents this actor's current state in GCS. You should + * save this checkpoint ID together with actor's checkpoint data. + */ + void saveCheckpoint(UniqueId actorId, UniqueId checkpointId); + + /** + * Load actor's previous checkpoint, and restore actor's state. + * + * This method will be called when an actor is reconstructed, after actor's constructor. If the + * actor needs to restore from previous checkpoint, this function should restore actor's state and + * return the checkpoint ID. Otherwise, it should do nothing and return null. + * + * @param actorId Actor's ID. + * @param availableCheckpoints A list of available checkpoint IDs and their timestamps, sorted + * by timestamp in descending order. Note, this method must return the ID of one checkpoint in + * this list, or null. Otherwise, an exception will be thrown. + * @return The ID of the checkpoint from which the actor was resumed, or null if the actor should + * restart from the beginning. + */ + UniqueId loadCheckpoint(UniqueId actorId, List availableCheckpoints); + + /** + * Delete an expired checkpoint. + * + * This method will be called when a checkpoint is expired. You should implement this method to + * delete your application checkpoint data. Note, the maximum number of checkpoints kept in the + * backend can be configured at `RayConfig.num_actor_checkpoints_to_keep`. + * + * @param actorId ID of the actor. 
+ * @param checkpointId ID of the checkpoint that has expired. + */ + void checkpointExpired(UniqueId actorId, UniqueId checkpointId); +} diff --git a/java/checkstyle-suppressions.xml b/java/checkstyle-suppressions.xml index d7c3b7755..80d392ffb 100644 --- a/java/checkstyle-suppressions.xml +++ b/java/checkstyle-suppressions.xml @@ -4,6 +4,7 @@ + diff --git a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java index e8e32fbb3..70ad03af4 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java @@ -1,16 +1,26 @@ package org.ray.runtime; +import com.google.common.base.Preconditions; import com.google.common.base.Strings; import java.lang.reflect.Field; +import java.nio.ByteBuffer; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import org.apache.commons.lang3.ArrayUtils; +import org.ray.api.Checkpointable.Checkpoint; +import org.ray.api.id.UniqueId; import org.ray.runtime.config.RayConfig; import org.ray.runtime.config.WorkerMode; import org.ray.runtime.gcs.RedisClient; +import org.ray.runtime.generated.ActorCheckpointIdData; +import org.ray.runtime.generated.TablePrefix; import org.ray.runtime.objectstore.ObjectStoreProxy; import org.ray.runtime.raylet.RayletClientImpl; import org.ray.runtime.runner.RunManager; +import org.ray.runtime.util.UniqueIdUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -21,7 +31,14 @@ public final class RayNativeRuntime extends AbstractRayRuntime { private static final Logger LOGGER = LoggerFactory.getLogger(RayNativeRuntime.class); - private RedisClient redisClient = null; + /** + * Redis client of the primary shard. + */ + private RedisClient redisClient; + /** + * Redis clients of all shards. 
+ */ + private List redisClients; private RunManager manager = null; public RayNativeRuntime(RayConfig rayConfig) { @@ -69,7 +86,8 @@ public final class RayNativeRuntime extends AbstractRayRuntime { manager = new RunManager(rayConfig); manager.startRayProcesses(true); } - redisClient = new RedisClient(rayConfig.getRedisAddress(), rayConfig.redisPassword); + + initRedisClients(); // TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis. objectStoreProxy = new ObjectStoreProxy(this, rayConfig.objectStoreSocketName); @@ -88,6 +106,16 @@ public final class RayNativeRuntime extends AbstractRayRuntime { rayConfig.objectStoreSocketName, rayConfig.rayletSocketName); } + private void initRedisClients() { + redisClient = new RedisClient(rayConfig.getRedisAddress(), rayConfig.redisPassword); + int numRedisShards = Integer.valueOf(redisClient.get("NumRedisShards", null)); + List addresses = redisClient.lrange("RedisShards", 0, -1); + Preconditions.checkState(numRedisShards == addresses.size()); + redisClients = addresses.stream().map(RedisClient::new) + .collect(Collectors.toList()); + redisClients.add(redisClient); + } + @Override public void shutdown() { if (null != manager) { @@ -116,4 +144,33 @@ public final class RayNativeRuntime extends AbstractRayRuntime { } } + /** + * Get the available checkpoints for the given actor ID, return a list sorted by checkpoint + * timestamp in descending order. + */ + List getCheckpointsForActor(UniqueId actorId) { + List checkpoints = new ArrayList<>(); + // TODO(hchen): implement the equivalent of Python's `GlobalState`, to avoid looping over + // all redis shards. 
+ String prefix = TablePrefix.name(TablePrefix.ACTOR_CHECKPOINT_ID); + byte[] key = ArrayUtils.addAll(prefix.getBytes(), actorId.getBytes()); + for (RedisClient client : redisClients) { + byte[] result = client.get(key, null); + if (result == null) { + continue; + } + ActorCheckpointIdData data = ActorCheckpointIdData + .getRootAsActorCheckpointIdData(ByteBuffer.wrap(result)); + + UniqueId[] checkpointIds + = UniqueIdUtil.getUniqueIdsFromByteBuffer(data.checkpointIdsAsByteBuffer()); + + for (int i = 0; i < checkpointIds.length; i++) { + checkpoints.add(new Checkpoint(checkpointIds[i], data.timestamps(i))); + } + break; + } + checkpoints.sort((x, y) -> Long.compare(y.timestamp, x.timestamp)); + return checkpoints; + } } diff --git a/java/runtime/src/main/java/org/ray/runtime/Worker.java b/java/runtime/src/main/java/org/ray/runtime/Worker.java index 6e403c80a..e18e5f5e8 100644 --- a/java/runtime/src/main/java/org/ray/runtime/Worker.java +++ b/java/runtime/src/main/java/org/ray/runtime/Worker.java @@ -1,8 +1,14 @@ package org.ray.runtime; import com.google.common.base.Preconditions; +import java.util.ArrayList; +import java.util.List; +import org.ray.api.Checkpointable; +import org.ray.api.Checkpointable.Checkpoint; +import org.ray.api.Checkpointable.CheckpointContext; import org.ray.api.exception.RayException; import org.ray.api.id.UniqueId; +import org.ray.runtime.config.RunMode; import org.ray.runtime.functionmanager.RayFunction; import org.ray.runtime.task.ArgumentsBuilder; import org.ray.runtime.task.TaskSpec; @@ -17,6 +23,9 @@ public class Worker { private static final Logger LOGGER = LoggerFactory.getLogger(Worker.class); + // TODO(hchen): Use the C++ config. + private static final int NUM_ACTOR_CHECKPOINTS_TO_KEEP = 20; + private final AbstractRayRuntime runtime; /** @@ -34,6 +43,22 @@ public class Worker { */ private Exception actorCreationException = null; + /** + * Number of tasks executed since last actor checkpoint. 
+ */ + private int numTasksSinceLastCheckpoint = 0; + + /** + * IDs of this actor's previous checkpoints. + */ + private List checkpointIds; + + /** + * Timestamp of the last actor checkpoint. + */ + private long lastCheckpointTimestamp = 0; + + public Worker(AbstractRayRuntime runtime) { this.runtime = runtime; } @@ -80,8 +105,12 @@ public class Worker { } // Set result if (!spec.isActorCreationTask()) { + if (spec.isActorTask()) { + maybeSaveCheckpoint(actor, spec.actorId); + } runtime.put(returnId, result); } else { + maybeLoadCheckpoint(result, returnId); currentActor = result; currentActorId = returnId; } @@ -98,4 +127,61 @@ public class Worker { Thread.currentThread().setContextClassLoader(oldLoader); } } + + private void maybeSaveCheckpoint(Object actor, UniqueId actorId) { + if (!(actor instanceof Checkpointable)) { + return; + } + if (runtime.getRayConfig().runMode == RunMode.SINGLE_PROCESS) { + // Actor checkpointing isn't implemented for SINGLE_PROCESS mode yet. + return; + } + CheckpointContext checkpointContext = new CheckpointContext(actorId, + ++numTasksSinceLastCheckpoint, System.currentTimeMillis() - lastCheckpointTimestamp); + Checkpointable checkpointable = (Checkpointable) actor; + if (!checkpointable.shouldCheckpoint(checkpointContext)) { + return; + } + numTasksSinceLastCheckpoint = 0; + lastCheckpointTimestamp = System.currentTimeMillis(); + UniqueId checkpointId = runtime.rayletClient.prepareCheckpoint(actorId); + checkpointIds.add(checkpointId); + if (checkpointIds.size() > NUM_ACTOR_CHECKPOINTS_TO_KEEP) { + ((Checkpointable) actor).checkpointExpired(actorId, checkpointIds.get(0)); + checkpointIds.remove(0); + } + checkpointable.saveCheckpoint(actorId, checkpointId); + } + + private void maybeLoadCheckpoint(Object actor, UniqueId actorId) { + if (!(actor instanceof Checkpointable)) { + return; + } + if (runtime.getRayConfig().runMode == RunMode.SINGLE_PROCESS) { + // Actor checkpointing isn't implemented for SINGLE_PROCESS mode yet. 
+ return; + } + numTasksSinceLastCheckpoint = 0; + lastCheckpointTimestamp = System.currentTimeMillis(); + checkpointIds = new ArrayList<>(); + List availableCheckpoints = ((RayNativeRuntime) runtime) + .getCheckpointsForActor(actorId); + if (availableCheckpoints.isEmpty()) { + return; + } + UniqueId checkpointId = ((Checkpointable) actor).loadCheckpoint(actorId, availableCheckpoints); + if (checkpointId != null) { + boolean checkpointValid = false; + for (Checkpoint checkpoint : availableCheckpoints) { + if (checkpoint.checkpointId.equals(checkpointId)) { + checkpointValid = true; + break; + } + } + Preconditions.checkArgument(checkpointValid, + "'loadCheckpoint' must return a checkpoint ID that exists in the " + + "'availableCheckpoints' list, or null."); + runtime.rayletClient.notifyActorResumedFromCheckpoint(actorId, checkpointId); + } + } } diff --git a/java/runtime/src/main/java/org/ray/runtime/gcs/RedisClient.java b/java/runtime/src/main/java/org/ray/runtime/gcs/RedisClient.java index 4aa2c9607..94f189785 100644 --- a/java/runtime/src/main/java/org/ray/runtime/gcs/RedisClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/gcs/RedisClient.java @@ -1,5 +1,6 @@ package org.ray.runtime.gcs; +import java.util.List; import java.util.Map; import org.ray.runtime.util.StringUtil; @@ -77,7 +78,11 @@ public class RedisClient { return jedis.hget(key, field); } } - } + public List lrange(String key, long start, long end) { + try (Jedis jedis = jedisPool.getResource()) { + return jedis.lrange(key, start, end); + } + } } diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java index 0da3dbe80..8f93e55c1 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java @@ -4,6 +4,7 @@ import com.google.common.collect.ImmutableList; import java.util.List; import 
java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import org.apache.commons.lang3.NotImplementedException; import org.ray.api.RayObject; import org.ray.api.WaitResult; import org.ray.api.id.UniqueId; @@ -94,4 +95,14 @@ public class MockRayletClient implements RayletClient { public void freePlasmaObjects(List objectIds, boolean localOnly) { return; } + + @Override + public UniqueId prepareCheckpoint(UniqueId actorId) { + throw new NotImplementedException("Not implemented."); + } + + @Override + public void notifyActorResumedFromCheckpoint(UniqueId actorId, UniqueId checkpointId) { + throw new NotImplementedException("Not implemented."); + } } diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java index b68fe0182..7ecbb80c6 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java @@ -25,4 +25,8 @@ public interface RayletClient { timeoutMs, UniqueId currentTaskId); void freePlasmaObjects(List objectIds, boolean localOnly); + + UniqueId prepareCheckpoint(UniqueId actorId); + + void notifyActorResumedFromCheckpoint(UniqueId actorId, UniqueId checkpointId); } diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java index 96b7657db..60eaf2d23 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java @@ -127,6 +127,17 @@ public class RayletClientImpl implements RayletClient { nativeFreePlasmaObjects(client, objectIdsArray, localOnly); } + @Override + public UniqueId prepareCheckpoint(UniqueId actorId) { + return new UniqueId(nativePrepareCheckpoint(client, actorId.getBytes())); + } + + @Override + public void notifyActorResumedFromCheckpoint(UniqueId actorId, 
UniqueId checkpointId) { + nativeNotifyActorResumedFromCheckpoint(client, actorId.getBytes(), checkpointId.getBytes()); + } + + private static TaskSpec parseTaskSpecFromFlatbuffer(ByteBuffer bb) { bb.order(ByteOrder.LITTLE_ENDIAN); TaskInfo info = TaskInfo.getRootAsTaskInfo(bb); @@ -142,7 +153,7 @@ public class RayletClientImpl implements RayletClient { // Deserialize new actor handles UniqueId[] newActorHandles = UniqueIdUtil.getUniqueIdsFromByteBuffer( - info.newActorHandlesAsByteBuffer()); + info.newActorHandlesAsByteBuffer()); // Deserialize args FunctionArg[] args = new FunctionArg[info.argsLength()]; @@ -208,7 +219,7 @@ public class RayletClientImpl implements RayletClient { int dataOffset = 0; if (task.args[i].id != null) { objectIdOffset = fbb.createString( - UniqueIdUtil.concatUniqueIds(new UniqueId[] {task.args[i].id})); + UniqueIdUtil.concatUniqueIds(new UniqueId[]{task.args[i].id})); } else { objectIdOffset = fbb.createString(""); } @@ -258,7 +269,6 @@ public class RayletClientImpl implements RayletClient { actorIdOffset, actorHandleIdOffset, actorCounter, - false, newActorHandlesOffset, argsOffset, returnsOffset, @@ -271,8 +281,8 @@ public class RayletClientImpl implements RayletClient { if (buffer.remaining() > TASK_SPEC_BUFFER_SIZE) { LOGGER.error( - "Allocated buffer is not enough to transfer the task specification: {}vs {}", - TASK_SPEC_BUFFER_SIZE, buffer.remaining()); + "Allocated buffer is not enough to transfer the task specification: {} vs {}", + TASK_SPEC_BUFFER_SIZE, buffer.remaining()); throw new RuntimeException("Allocated buffer is not enough to transfer to task."); } return buffer; @@ -323,4 +333,8 @@ public class RayletClientImpl implements RayletClient { private static native void nativeFreePlasmaObjects(long conn, byte[][] objectIds, boolean localOnly) throws RayException; + private static native byte[] nativePrepareCheckpoint(long conn, byte[] actorId); + + private static native void nativeNotifyActorResumedFromCheckpoint(long conn, 
byte[] actorId, + byte[] checkpointId); } diff --git a/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java b/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java index a0fdbfbf0..fd82cf4cf 100644 --- a/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java +++ b/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java @@ -4,10 +4,13 @@ import static org.ray.runtime.util.SystemUtil.pid; import java.io.IOException; import java.util.HashMap; +import java.util.List; import java.util.concurrent.TimeUnit; +import org.ray.api.Checkpointable; import org.ray.api.Ray; import org.ray.api.RayActor; import org.ray.api.annotation.RayRemote; +import org.ray.api.id.UniqueId; import org.ray.api.options.ActorCreationOptions; import org.testng.Assert; import org.testng.annotations.Test; @@ -17,10 +20,10 @@ public class ActorReconstructionTest extends BaseTest { @RayRemote() public static class Counter { - private int value = 0; + protected int value = 0; - public int increase(int delta) { - value += delta; + public int increase() { + value += 1; return value; } @@ -35,7 +38,7 @@ public class ActorReconstructionTest extends BaseTest { RayActor actor = Ray.createActor(Counter::new, options); // Call increase 3 times. for (int i = 0; i < 3; i++) { - Ray.call(Counter::increase, actor, 1).get(); + Ray.call(Counter::increase, actor).get(); } // Kill the actor process. @@ -45,7 +48,7 @@ public class ActorReconstructionTest extends BaseTest { TimeUnit.SECONDS.sleep(1); // Try calling increase on this actor again and check the value is now 4. - int value = Ray.call(Counter::increase, actor, 1).get(); + int value = Ray.call(Counter::increase, actor).get(); Assert.assertEquals(value, 4); // Kill the actor process again. @@ -55,7 +58,7 @@ public class ActorReconstructionTest extends BaseTest { // Try calling increase on this actor again and this should fail. 
try { - Ray.call(Counter::increase, actor, 1).get(); + Ray.call(Counter::increase, actor).get(); Assert.fail("The above task didn't fail."); } catch (StringIndexOutOfBoundsException e) { // Raylet backend will put invalid data in task's result to indicate the task has failed. @@ -64,4 +67,71 @@ public class ActorReconstructionTest extends BaseTest { // instead of throwing this exception. } } + + public static class CheckpointableCounter extends Counter implements Checkpointable { + + private boolean resumedFromCheckpoint = false; + private boolean increaseCalled = false; + + @Override + public int increase() { + increaseCalled = true; + return super.increase(); + } + + public boolean wasResumedFromCheckpoint() { + return resumedFromCheckpoint; + } + + @Override + public boolean shouldCheckpoint(CheckpointContext checkpointContext) { + // Checkpoint the actor when value is increased to 3. + boolean shouldCheckpoint = increaseCalled && value == 3; + increaseCalled = false; + return shouldCheckpoint; + } + + @Override + public void saveCheckpoint(UniqueId actorId, UniqueId checkpointId) { + // In practice, user should save the checkpoint id and data to a persistent store. + // But for simplicity, we don't do that in this unit test. + } + + @Override + public UniqueId loadCheckpoint(UniqueId actorId, List availableCheckpoints) { + // Restore previous value and return checkpoint id. + this.value = 3; + this.resumedFromCheckpoint = true; + return availableCheckpoints.get(availableCheckpoints.size() - 1).checkpointId; + } + + @Override + public void checkpointExpired(UniqueId actorId, UniqueId checkpointId) { + } + } + + @Test + public void testActorCheckpointing() throws IOException, InterruptedException { + ActorCreationOptions options = new ActorCreationOptions(new HashMap<>(), 1); + RayActor actor = Ray.createActor(CheckpointableCounter::new, options); + // Call increase 3 times. 
+ for (int i = 0; i < 3; i++) { + Ray.call(CheckpointableCounter::increase, actor).get(); + } + // Assert that the actor wasn't resumed from a checkpoint. + Assert.assertFalse(Ray.call(CheckpointableCounter::wasResumedFromCheckpoint, actor).get()); + + // Kill the actor process. + int pid = Ray.call(CheckpointableCounter::getPid, actor).get(); + Runtime.getRuntime().exec("kill -9 " + pid); + // Wait for the actor to be killed. + TimeUnit.SECONDS.sleep(1); + + // Try calling increase on this actor again and check the value is now 4. + int value = Ray.call(CheckpointableCounter::increase, actor).get(); + Assert.assertEquals(value, 4); + // Assert that the actor was resumed from a checkpoint. + Assert.assertTrue(Ray.call(CheckpointableCounter::wasResumedFromCheckpoint, actor).get()); + } } + diff --git a/python/ray/__init__.py b/python/ray/__init__.py index cbfb99e2e..81bc7e80b 100644 --- a/python/ray/__init__.py +++ b/python/ray/__init__.py @@ -49,9 +49,20 @@ except ImportError as e: modin_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "modin") sys.path.append(modin_path) -from ray._raylet import (UniqueID, ObjectID, DriverID, ClientID, ActorID, - ActorHandleID, FunctionID, ActorClassID, TaskID, - _ID_TYPES, Config as _Config) # noqa: E402 +from ray._raylet import ( + ActorCheckpointID, + ActorClassID, + ActorHandleID, + ActorID, + ClientID, + Config as _Config, + DriverID, + FunctionID, + ObjectID, + TaskID, + UniqueID, + _ID_TYPES, +) # noqa: E402 _config = _Config() @@ -82,8 +93,16 @@ __all__ = [ ] __all__ += [ - "UniqueID", "ObjectID", "DriverID", "ClientID", "ActorID", "ActorHandleID", - "FunctionID", "ActorClassID", "TaskID" + "ActorCheckpointID", + "ActorClassID", + "ActorHandleID", + "ActorID", + "ClientID", + "DriverID", + "FunctionID", + "ObjectID", + "TaskID", + "UniqueID", ] import ctypes # noqa: E402 diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 818cb6820..b37ace64f 100644 --- a/python/ray/_raylet.pyx +++ 
b/python/ray/_raylet.pyx @@ -19,12 +19,31 @@ include "includes/ray_config.pxi" include "includes/task.pxi" from ray.includes.common cimport ( - CUniqueID, CTaskID, CObjectID, CFunctionID, CActorClassID, CActorID, - CActorHandleID, CWorkerID, CDriverID, CConfigID, CClientID, - CLanguage, CRayStatus, LANGUAGE_CPP, LANGUAGE_JAVA, LANGUAGE_PYTHON) + CActorCheckpointID, + CActorClassID, + CActorHandleID, + CActorID, + CClientID, + CConfigID, + CDriverID, + CFunctionID, + CLanguage, + CObjectID, + CRayStatus, + CTaskID, + CUniqueID, + CWorkerID, + LANGUAGE_CPP, + LANGUAGE_JAVA, + LANGUAGE_PYTHON, +) from ray.includes.libraylet cimport ( - CRayletClient, GCSProfileTableDataT, GCSProfileEventT, - ResourceMappingType, WaitResultPair) + CRayletClient, + GCSProfileEventT, + GCSProfileTableDataT, + ResourceMappingType, + WaitResultPair, +) from ray.includes.task cimport CTaskSpecification from ray.includes.ray_config cimport RayConfig from ray.utils import decode @@ -303,6 +322,14 @@ cdef class RayletClient: cdef c_vector[CObjectID] free_ids = ObjectIDsToVector(object_ids) check_status(self.client.get().FreeObjects(free_ids, local_only)) + def prepare_actor_checkpoint(self, ActorID actor_id): + cdef CActorCheckpointID checkpoint_id + check_status(self.client.get().PrepareActorCheckpoint(actor_id.data, checkpoint_id)) + return ObjectID.from_native(checkpoint_id); + + def notify_actor_resumed_from_checkpoint(self, ActorID actor_id, ActorCheckpointID checkpoint_id): + check_status(self.client.get().NotifyActorResumedFromCheckpoint(actor_id.data, checkpoint_id.data)) + @property def language(self): return Language.from_native(self.client.get().GetLanguage()) diff --git a/python/ray/actor.py b/python/ray/actor.py index 8d4e55e88..ecad3e44a 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -6,11 +6,13 @@ import copy import hashlib import inspect import logging +import six import sys import threading -import traceback -import ray.cloudpickle as pickle +from abc import 
ABCMeta, abstractmethod +from collections import namedtuple + from ray.function_manager import FunctionDescriptor import ray.ray_constants as ray_constants import ray.signature as signature @@ -75,90 +77,6 @@ def compute_actor_handle_id_non_forked(actor_handle_id, current_task_id): return ActorHandleID(handle_id) -def set_actor_checkpoint(worker, actor_id, checkpoint_index, checkpoint, - frontier): - """Set the most recent checkpoint associated with a given actor ID. - - Args: - worker: The worker to use to get the checkpoint. - actor_id: The actor ID of the actor to get the checkpoint for. - checkpoint_index: The number of tasks included in the checkpoint. - checkpoint: The state object to save. - frontier: The task frontier at the time of the checkpoint. - """ - assert isinstance(actor_id, ActorID) - actor_key = b"Actor:" + actor_id.binary() - worker.redis_client.hmset( - actor_key, { - "checkpoint_index": checkpoint_index, - "checkpoint": checkpoint, - "frontier": frontier, - }) - - -def save_and_log_checkpoint(worker, actor): - """Save a checkpoint on the actor and log any errors. - - Args: - worker: The worker to use to log errors. - actor: The actor to checkpoint. - """ - try: - actor.__ray_checkpoint__() - except Exception: - traceback_str = ray.utils.format_error_message(traceback.format_exc()) - # Log the error message. - ray.utils.push_error_to_driver( - worker, - ray_constants.CHECKPOINT_PUSH_ERROR, - traceback_str, - driver_id=worker.task_driver_id) - - -def restore_and_log_checkpoint(worker, actor): - """Restore an actor from a checkpoint and log any errors. - - Args: - worker: The worker to use to log errors. - actor: The actor to restore. - """ - checkpoint_resumed = False - try: - checkpoint_resumed = actor.__ray_checkpoint_restore__() - except Exception: - traceback_str = ray.utils.format_error_message(traceback.format_exc()) - # Log the error message. 
- ray.utils.push_error_to_driver( - worker, - ray_constants.CHECKPOINT_PUSH_ERROR, - traceback_str, - driver_id=worker.task_driver_id) - return checkpoint_resumed - - -def get_actor_checkpoint(worker, actor_id): - """Get the most recent checkpoint associated with a given actor ID. - - Args: - worker: The worker to use to get the checkpoint. - actor_id: The actor ID of the actor to get the checkpoint for. - - Returns: - If a checkpoint exists, this returns a tuple of the number of tasks - included in the checkpoint, the saved checkpoint state, and the - task frontier at the time of the checkpoint. If no checkpoint - exists, all objects are set to None. The checkpoint index is the . - executed on the actor before the checkpoint was made. - """ - assert isinstance(actor_id, ActorID) - actor_key = b"Actor:" + actor_id.binary() - checkpoint_index, checkpoint, frontier = worker.redis_client.hmget( - actor_key, ["checkpoint_index", "checkpoint", "frontier"]) - if checkpoint_index is not None: - checkpoint_index = int(checkpoint_index) - return checkpoint_index, checkpoint, frontier - - def method(*args, **kwargs): """Annotate an actor method. @@ -234,7 +152,6 @@ class ActorClass(object): additional methods added like __ray_terminate__). _class_id: The ID of this actor class. _class_name: The name of this class. - _checkpoint_interval: The interval at which to checkpoint actor state. _num_cpus: The default number of CPUs required by the actor creation task. _num_gpus: The default number of GPUs required by the actor creation @@ -250,13 +167,11 @@ class ActorClass(object): each actor method. 
""" - def __init__(self, modified_class, class_id, checkpoint_interval, - max_reconstructions, num_cpus, num_gpus, resources, - actor_method_cpus): + def __init__(self, modified_class, class_id, max_reconstructions, num_cpus, + num_gpus, resources, actor_method_cpus): self._modified_class = modified_class self._class_id = class_id self._class_name = modified_class.__name__ - self._checkpoint_interval = checkpoint_interval self._max_reconstructions = max_reconstructions self._num_cpus = num_cpus self._num_gpus = num_gpus @@ -383,8 +298,7 @@ class ActorClass(object): # Export the actor. if not self._exported: worker.function_actor_manager.export_actor_class( - self._modified_class, self._actor_method_names, - self._checkpoint_interval) + self._modified_class, self._actor_method_names) self._exported = True resources = ray.utils.resources_from_resource_arguments( @@ -564,8 +478,6 @@ class ActorHandle(object): return getattr(worker.actors[self._ray_actor_id], method_name)(*copy.deepcopy(args)) - is_actor_checkpoint_method = (method_name == "__ray_checkpoint__") - function_descriptor = FunctionDescriptor( self._ray_module_name, method_name, self._ray_class_name) with self._ray_actor_lock: @@ -575,7 +487,6 @@ class ActorHandle(object): actor_id=self._ray_actor_id, actor_handle_id=self._ray_actor_handle_id, actor_counter=self._ray_actor_counter, - is_actor_checkpoint_method=is_actor_checkpoint_method, actor_creation_dummy_object_id=( self._ray_actor_creation_dummy_object_id), execution_dependencies=[self._ray_actor_cursor], @@ -770,7 +681,7 @@ class ActorHandle(object): def make_actor(cls, num_cpus, num_gpus, resources, actor_method_cpus, - checkpoint_interval, max_reconstructions): + max_reconstructions): # Give an error if cls is an old-style class. if not issubclass(cls, object): raise TypeError( @@ -778,13 +689,14 @@ def make_actor(cls, num_cpus, num_gpus, resources, actor_method_cpus, "classes. 
In Python 2, you must declare the class with " "'class ClassName(object):' instead of 'class ClassName:'.") - if checkpoint_interval is None: - checkpoint_interval = -1 + if issubclass(cls, Checkpointable) and inspect.isabstract(cls): + raise TypeError( + "A checkpointable actor class should implement all abstract " + "methods in the `Checkpointable` interface.") + if max_reconstructions is None: max_reconstructions = 0 - if checkpoint_interval == 0: - raise Exception("checkpoint_interval must be greater than 0.") if not (ray_constants.NO_RECONSTRUCTION <= max_reconstructions <= ray_constants.INFINITE_RECONSTRUCTION): raise Exception("max_reconstructions must be in range [%d, %d]." % @@ -804,26 +716,6 @@ def make_actor(cls, num_cpus, num_gpus, resources, actor_method_cpus, sys.exit(0) assert False, "This process should have terminated." - def __ray_save_checkpoint__(self): - if hasattr(self, "__ray_save__"): - object_to_serialize = self.__ray_save__() - else: - object_to_serialize = self - return pickle.dumps(object_to_serialize) - - @classmethod - def __ray_restore_from_checkpoint__(cls, pickled_checkpoint): - checkpoint = pickle.loads(pickled_checkpoint) - if hasattr(cls, "__ray_restore__"): - actor_object = cls.__new__(cls) - actor_object.__ray_restore__(checkpoint) - else: - # TODO(rkn): It's possible that this will cause problems. When - # you unpickle the same object twice, the two objects will not - # have the same class. - actor_object = checkpoint - return actor_object - def __ray_checkpoint__(self): """Save a checkpoint. @@ -832,58 +724,143 @@ def make_actor(cls, num_cpus, num_gpus, resources, actor_method_cpus, (number of tasks executed so far). """ worker = ray.worker.global_worker - checkpoint_index = worker.actor_task_counter - # Get the state to save. - checkpoint = self.__ray_save_checkpoint__() - # Get the current task frontier, per actor handle. - # NOTE(swang): This only includes actor handles that the local - # scheduler has seen. 
Handle IDs for which no task has yet reached - # the local scheduler will not be included, and may not be runnable - # on checkpoint resumption. - actor_id = worker.actor_id - frontier = worker.raylet_client.get_actor_frontier(actor_id) - # Save the checkpoint in Redis. TODO(rkn): Checkpoints - # should not be stored in Redis. Fix this. - set_actor_checkpoint(worker, worker.actor_id, checkpoint_index, - checkpoint, frontier) - - def __ray_checkpoint_restore__(self): - """Restore a checkpoint. - - This task looks for a saved checkpoint and if found, restores the - state of the actor, the task frontier in the local scheduler, and - the checkpoint index (number of tasks executed so far). - - Returns: - A bool indicating whether a checkpoint was resumed. - """ - worker = ray.worker.global_worker - # Get the most recent checkpoint stored, if any. - checkpoint_index, checkpoint, frontier = get_actor_checkpoint( - worker, worker.actor_id) - # Try to resume from the checkpoint. - checkpoint_resumed = False - if checkpoint_index is not None: - # Load the actor state from the checkpoint. - worker.actors[worker.actor_id] = ( - worker.actor_class.__ray_restore_from_checkpoint__( - checkpoint)) - # Set the number of tasks executed so far. - worker.actor_task_counter = checkpoint_index - # Set the actor frontier in the local scheduler. 
- worker.raylet_client.set_actor_frontier(frontier) - checkpoint_resumed = True - - return checkpoint_resumed + if not isinstance(self, ray.actor.Checkpointable): + raise Exception( + "__ray_checkpoint__.remote() may only be called on actors " + "that implement ray.actor.Checkpointable") + return worker._save_actor_checkpoint() Class.__module__ = cls.__module__ Class.__name__ = cls.__name__ class_id = ActorClassID(_random_string()) - return ActorClass(Class, class_id, checkpoint_interval, - max_reconstructions, num_cpus, num_gpus, resources, - actor_method_cpus) + return ActorClass(Class, class_id, max_reconstructions, num_cpus, num_gpus, + resources, actor_method_cpus) ray.worker.global_worker.make_actor = make_actor + +CheckpointContext = namedtuple( + 'CheckpointContext', + [ + # Actor's ID. + 'actor_id', + # Number of tasks executed since last checkpoint. + 'num_tasks_since_last_checkpoint', + # Time elapsed since last checkpoint, in milliseconds. + 'time_elapsed_ms_since_last_checkpoint', + ], +) +"""A namedtuple that contains information about actor's last checkpoint.""" + +Checkpoint = namedtuple( + 'Checkpoint', + [ + # ID of this checkpoint. + 'checkpoint_id', + # The timestamp at which this checkpoint was saved, + # represented as milliseconds elapsed since Unix epoch. + 'timestamp', + ], +) +"""A namedtuple that represents a checkpoint.""" + + +class Checkpointable(six.with_metaclass(ABCMeta, object)): + """An interface that indicates an actor can be checkpointed.""" + + @abstractmethod + def should_checkpoint(self, checkpoint_context): + """Whether this actor needs to be checkpointed. + + This method will be called after every task. You should implement this + callback to decide whether this actor needs to be checkpointed at this + time, based on the checkpoint context, or any other factors. + + Args: + checkpoint_context: A namedtuple that contains info about last + checkpoint. 
+ + Returns: + A boolean value that indicates whether this actor needs to be + checkpointed. + """ + pass + + @abstractmethod + def save_checkpoint(self, actor_id, checkpoint_id): + """Save a checkpoint to persistent storage. + + If `should_checkpoint` returns true, this method will be called. You + should implement this callback to save actor's checkpoint and the given + checkpoint id to persistent storage. + + Args: + actor_id: Actor's ID. + checkpoint_id: ID of this checkpoint. You should save it together + with actor's checkpoint data. And it will be used by the + `load_checkpoint` method. + Returns: + None. + """ + pass + + @abstractmethod + def load_checkpoint(self, actor_id, available_checkpoints): + """Load actor's previous checkpoint, and restore actor's state. + + This method will be called when an actor is reconstructed, after + actor's constructor. + If the actor needs to restore from previous checkpoint, this function + should restore actor's state and return the checkpoint ID. Otherwise, + it should do nothing and return None. + Note, this method must return one of the checkpoint IDs in the + `available_checkpoints` list, or None. Otherwise, an exception will be + raised. + + Args: + actor_id: Actor's ID. + available_checkpoints: A list of `Checkpoint` namedtuples that + contains all available checkpoint IDs and their timestamps, + sorted by timestamp in descending order. + Returns: + The ID of the checkpoint from which the actor was resumed, or None + if the actor should restart from the beginning. + """ + pass + + @abstractmethod + def checkpoint_expired(self, actor_id, checkpoint_id): + """Delete an expired checkpoint. + + This method will be called when a checkpoint is expired. You should + implement this method to delete your application checkpoint data. + Note, the maximum number of checkpoints kept in the backend can be + configured at `RayConfig.num_actor_checkpoints_to_keep`. + + Args: + actor_id: ID of the actor.
+ checkpoint_id: ID of the checkpoint that has expired. + Returns: + None. + """ + pass + + +def get_checkpoints_for_actor(actor_id): + """Get the available checkpoints for the given actor ID, return a list + sorted by checkpoint timestamp in descending order. + """ + checkpoint_info = ray.worker.global_state.actor_checkpoint_info(actor_id) + if checkpoint_info is None: + return [] + checkpoints = [ + Checkpoint(checkpoint_id, timestamp) for checkpoint_id, timestamp in + zip(checkpoint_info['CheckpointIds'], checkpoint_info['Timestamps']) + ] + return sorted( + checkpoints, + key=lambda checkpoint: checkpoint.timestamp, + reverse=True, + ) diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py index 2d1e4bee0..4936213ed 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/experimental/state.py @@ -11,7 +11,8 @@ import time import ray from ray.function_manager import FunctionDescriptor import ray.gcs_utils -import ray.ray_constants as ray_constants + +from ray.ray_constants import ID_SIZE from ray.utils import (decode, binary_to_object_id, binary_to_hex, hex_to_binary) @@ -720,7 +721,7 @@ class GlobalState(object): for key in actor_keys: info = self.redis_client.hgetall(key) actor_id = key[len("Actor:"):] - assert len(actor_id) == ray_constants.ID_SIZE + assert len(actor_id) == ID_SIZE actor_info[binary_to_hex(actor_id)] = { "class_id": binary_to_hex(info[b"class_id"]), "driver_id": binary_to_hex(info[b"driver_id"]), @@ -906,3 +907,42 @@ class GlobalState(object): binary_to_hex(job_id): self._error_messages(ray.DriverID(job_id)) for job_id in job_ids } + + def actor_checkpoint_info(self, actor_id): + """Get checkpoint info for the given actor id. + Args: + actor_id: Actor's ID. + Returns: + A dictionary with information about the actor's checkpoint IDs and + their timestamps. 
+ """ + self._check_connected() + message = self._execute_command( + actor_id, + "RAY.TABLE_LOOKUP", + ray.gcs_utils.TablePrefix.ACTOR_CHECKPOINT_ID, + "", + actor_id.binary(), + ) + if message is None: + return None + gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( + message, 0) + entry = ( + ray.gcs_utils.ActorCheckpointIdData.GetRootAsActorCheckpointIdData( + gcs_entry.Entries(0), 0)) + checkpoint_ids_str = entry.CheckpointIds() + num_checkpoints = len(checkpoint_ids_str) // ID_SIZE + assert len(checkpoint_ids_str) % ID_SIZE == 0 + checkpoint_ids = [ + ray.ActorCheckpointID( + checkpoint_ids_str[(i * ID_SIZE):((i + 1) * ID_SIZE)]) + for i in range(num_checkpoints) + ] + return { + "ActorID": ray.utils.binary_to_hex(entry.ActorId()), + "CheckpointIds": checkpoint_ids, + "Timestamps": [ + entry.Timestamps(i) for i in range(num_checkpoints) + ], + } diff --git a/python/ray/function_manager.py b/python/ray/function_manager.py index 06ec3e49e..3abed7ea6 100644 --- a/python/ray/function_manager.py +++ b/python/ray/function_manager.py @@ -510,8 +510,7 @@ class FunctionActorManager(object): self._worker.redis_client.hmset(key, actor_class_info) self._worker.redis_client.rpush("Exports", key) - def export_actor_class(self, Class, actor_method_names, - checkpoint_interval): + def export_actor_class(self, Class, actor_method_names): function_descriptor = FunctionDescriptor.from_class(Class) # `task_driver_id` shouldn't be NIL, unless: # 1) This worker isn't an actor; @@ -528,7 +527,6 @@ class FunctionActorManager(object): "class_name": Class.__name__, "module": Class.__module__, "class": pickle.dumps(Class), - "checkpoint_interval": checkpoint_interval, "driver_id": driver_id.binary(), "actor_method_names": json.dumps(list(actor_method_names)) } @@ -576,17 +574,16 @@ class FunctionActorManager(object): actor_class_key: The key in Redis to use to fetch the actor. 
""" actor_id = self._worker.actor_id - (driver_id_str, class_name, module, pickled_class, checkpoint_interval, + (driver_id_str, class_name, module, pickled_class, actor_method_names) = self._worker.redis_client.hmget( actor_class_key, [ "driver_id", "class_name", "module", "class", - "checkpoint_interval", "actor_method_names" + "actor_method_names" ]) class_name = decode(class_name) module = decode(module) driver_id = ray.DriverID(driver_id_str) - checkpoint_interval = int(checkpoint_interval) actor_method_names = json.loads(decode(actor_method_names)) # In Python 2, json loads strings as unicode, so convert them back to @@ -605,7 +602,6 @@ class FunctionActorManager(object): pass self._worker.actors[actor_id] = TemporaryActor() - self._worker.actor_checkpoint_interval = checkpoint_interval def temporary_actor_method(*xs): raise Exception( @@ -694,48 +690,116 @@ class FunctionActorManager(object): # to execute. self._worker.actor_task_counter += 1 - # If this is the first task to execute on the actor, try to resume - # from a checkpoint. - # Current __init__ will be called by default. So the real function - # call will start from 2. - if actor_imported and self._worker.actor_task_counter == 2: - checkpoint_resumed = ray.actor.restore_and_log_checkpoint( - self._worker, actor) - if checkpoint_resumed: - # NOTE(swang): Since we did not actually execute the - # __init__ method, this will put None as the return value. - # If the __init__ method is supposed to return multiple - # values, an exception will be logged. - return - - # Determine whether we should checkpoint the actor. - checkpointing_on = (actor_imported - and self._worker.actor_checkpoint_interval > 0) - # We should checkpoint the actor if user checkpointing is on, we've - # executed checkpoint_interval tasks since the last checkpoint, and - # the method we're about to execute is not a checkpoint. 
- save_checkpoint = (checkpointing_on - and (self._worker.actor_task_counter % - self._worker.actor_checkpoint_interval == 0 - and method_name != "__ray_checkpoint__")) - # Execute the assigned method and save a checkpoint if necessary. try: if is_class_method(method): method_returns = method(*args) else: method_returns = method(actor, *args) - except Exception: + except Exception as e: # Save the checkpoint before allowing the method exception # to be thrown. - if save_checkpoint: - ray.actor.save_and_log_checkpoint(self._worker, actor) - raise + if isinstance(actor, ray.actor.Checkpointable): + self._save_and_log_checkpoint(actor) + raise e else: - # Save the checkpoint before returning the method's return - # values. - if save_checkpoint: - ray.actor.save_and_log_checkpoint(self._worker, actor) + # Handle any checkpointing operations before storing the + # method's return values. + # NOTE(swang): If method_returns is a pointer to the actor's + # state, then the checkpointing operations can modify the return + # values if they mutate the actor's state. Is this okay? + if isinstance(actor, ray.actor.Checkpointable): + # If this is the first task to execute on the actor, try to + # resume from a checkpoint. + if self._worker.actor_task_counter == 1: + if actor_imported: + self._restore_and_log_checkpoint(actor) + else: + # Save the checkpoint before returning the method's + # return values. + self._save_and_log_checkpoint(actor) return method_returns return actor_method_executor + + def _save_and_log_checkpoint(self, actor): + """Save an actor checkpoint if necessary and log any errors. + + Args: + actor: The actor to checkpoint. + + Returns: + None.
+ """ + actor_id = self._worker.actor_id + checkpoint_info = self._worker.actor_checkpoint_info[actor_id] + checkpoint_info.num_tasks_since_last_checkpoint += 1 + now = int(1000 * time.time()) + checkpoint_context = ray.actor.CheckpointContext( + actor_id, checkpoint_info.num_tasks_since_last_checkpoint, + now - checkpoint_info.last_checkpoint_timestamp) + # If we should take a checkpoint, notify raylet to prepare a checkpoint + # and then call `save_checkpoint`. + if actor.should_checkpoint(checkpoint_context): + try: + now = int(1000 * time.time()) + checkpoint_id = (self._worker.raylet_client. + prepare_actor_checkpoint(actor_id)) + checkpoint_info.checkpoint_ids.append(checkpoint_id) + actor.save_checkpoint(actor_id, checkpoint_id) + if (len(checkpoint_info.checkpoint_ids) > + ray._config.num_actor_checkpoints_to_keep()): + actor.checkpoint_expired( + actor_id, + checkpoint_info.checkpoint_ids.pop(0), + ) + checkpoint_info.num_tasks_since_last_checkpoint = 0 + checkpoint_info.last_checkpoint_timestamp = now + except Exception: + # Checkpoint save or reload failed. Notify the driver. + traceback_str = ray.utils.format_error_message( + traceback.format_exc()) + ray.utils.push_error_to_driver( + self._worker, + ray_constants.CHECKPOINT_PUSH_ERROR, + traceback_str, + driver_id=self._worker.task_driver_id) + + def _restore_and_log_checkpoint(self, actor): + """Restore an actor from a checkpoint if available and log any errors. + + This should only be called on workers that have just executed an actor + creation task. + + Args: + actor: The actor to restore from a checkpoint. + """ + actor_id = self._worker.actor_id + try: + checkpoints = ray.actor.get_checkpoints_for_actor(actor_id) + if len(checkpoints) > 0: + # If we found previously saved checkpoints for this actor, + # call the `load_checkpoint` callback. 
+ checkpoint_id = actor.load_checkpoint(actor_id, checkpoints) + if checkpoint_id is not None: + # Check that the returned checkpoint id is in the + # `available_checkpoints` list. + msg = ( + "`load_checkpoint` must return a checkpoint id that " + + "exists in the `available_checkpoints` list, or None.") + assert any(checkpoint_id == checkpoint.checkpoint_id + for checkpoint in checkpoints), msg + # Notify raylet that this actor has been resumed from + # a checkpoint. + (self._worker.raylet_client. + notify_actor_resumed_from_checkpoint( + actor_id, checkpoint_id)) + except Exception: + # Checkpoint save or reload failed. Notify the driver. + traceback_str = ray.utils.format_error_message( + traceback.format_exc()) + ray.utils.push_error_to_driver( + self._worker, + ray_constants.CHECKPOINT_PUSH_ERROR, + traceback_str, + driver_id=self._worker.task_driver_id) diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py index b46f4170a..bc6afbeea 100644 --- a/python/ray/gcs_utils.py +++ b/python/ray/gcs_utils.py @@ -5,6 +5,7 @@ from __future__ import print_function import flatbuffers import ray.core.generated.ErrorTableData +from ray.core.generated.ActorCheckpointIdData import ActorCheckpointIdData from ray.core.generated.ClientTableData import ClientTableData from ray.core.generated.DriverTableData import DriverTableData from ray.core.generated.ErrorTableData import ErrorTableData @@ -20,10 +21,20 @@ from ray.core.generated.TablePubsub import TablePubsub from ray.core.generated.ray.protocol.Task import Task __all__ = [ - "GcsTableEntry", "ClientTableData", "ErrorTableData", "HeartbeatTableData", - "HeartbeatBatchTableData", "DriverTableData", "ProfileTableData", - "ObjectTableData", "Task", "TablePrefix", "TablePubsub", "Language", - "construct_error_message" + "ActorCheckpointIdData", + "ClientTableData", + "DriverTableData", + "ErrorTableData", + "GcsTableEntry", + "HeartbeatBatchTableData", + "HeartbeatTableData", + "Language", + "ObjectTableData", +
"ProfileTableData", + "TablePrefix", + "TablePubsub", + "Task", + "construct_error_message", ] FUNCTION_PREFIX = "RemoteFunction:" diff --git a/python/ray/includes/common.pxd b/python/ray/includes/common.pxd index 5601ab280..bdabd87e4 100644 --- a/python/ray/includes/common.pxd +++ b/python/ray/includes/common.pxd @@ -6,10 +6,19 @@ from libcpp.unordered_map cimport unordered_map from libcpp.vector cimport vector as c_vector from ray.includes.unique_ids cimport ( - CUniqueID, TaskID as CTaskID, ObjectID as CObjectID, - FunctionID as CFunctionID, ActorClassID as CActorClassID, ActorID as CActorID, - ActorHandleID as CActorHandleID, WorkerID as CWorkerID, - DriverID as CDriverID, ConfigID as CConfigID, ClientID as CClientID) + ActorCheckpointID as CActorCheckpointID, + ActorClassID as CActorClassID, + ActorHandleID as CActorHandleID, + ActorID as CActorID, + CUniqueID, + ClientID as CClientID, + ConfigID as CConfigID, + DriverID as CDriverID, + FunctionID as CFunctionID, + ObjectID as CObjectID, + TaskID as CTaskID, + WorkerID as CWorkerID, +) cdef extern from "ray/status.h" namespace "ray" nogil: diff --git a/python/ray/includes/libraylet.pxd b/python/ray/includes/libraylet.pxd index 63c15f659..dd5e6c0ca 100644 --- a/python/ray/includes/libraylet.pxd +++ b/python/ray/includes/libraylet.pxd @@ -8,9 +8,21 @@ from libcpp.vector cimport vector as c_vector from ray.includes.common cimport ( - CUniqueID, CTaskID, CObjectID, CFunctionID, CActorClassID, CActorID, - CActorHandleID, CWorkerID, CDriverID, CConfigID, CClientID, - CLanguage, CRayStatus) + CActorCheckpointID, + CActorClassID, + CActorHandleID, + CActorID, + CClientID, + CConfigID, + CDriverID, + CFunctionID, + CLanguage, + CObjectID, + CRayStatus, + CTaskID, + CUniqueID, + CWorkerID, +) from ray.includes.task cimport CTaskSpecification @@ -57,6 +69,10 @@ cdef extern from "ray/raylet/raylet_client.h" nogil: CRayStatus PushProfileEvents(const GCSProfileTableDataT &profile_events) CRayStatus FreeObjects(const 
c_vector[CObjectID] &object_ids, c_bool local_only) + CRayStatus PrepareActorCheckpoint(const CActorID &actor_id, + CActorCheckpointID &checkpoint_id) + CRayStatus NotifyActorResumedFromCheckpoint( + const CActorID &actor_id, const CActorCheckpointID &checkpoint_id) CLanguage GetLanguage() const CClientID GetClientID() const CDriverID GetDriverID() const diff --git a/python/ray/includes/ray_config.pxd b/python/ray/includes/ray_config.pxd index 809313479..101d39c71 100644 --- a/python/ray/includes/ray_config.pxd +++ b/python/ray/includes/ray_config.pxd @@ -1,4 +1,4 @@ -from libc.stdint cimport int64_t, uint64_t +from libc.stdint cimport int64_t, uint64_t, uint32_t from libcpp.string cimport string as c_string from libcpp.unordered_map cimport unordered_map @@ -80,4 +80,6 @@ cdef extern from "ray/ray_config.h" nogil: int64_t max_task_lease_timeout_ms() const + uint32_t num_actor_checkpoints_to_keep() const + void initialize(const unordered_map[c_string, int] &config_map) diff --git a/python/ray/includes/ray_config.pxi b/python/ray/includes/ray_config.pxi index cb7fa53c3..3c7ad1e5f 100644 --- a/python/ray/includes/ray_config.pxi +++ b/python/ray/includes/ray_config.pxi @@ -144,3 +144,7 @@ cdef class Config: @staticmethod def max_task_lease_timeout_ms(): return RayConfig.instance().max_task_lease_timeout_ms() + + @staticmethod + def num_actor_checkpoints_to_keep(): + return RayConfig.instance().num_actor_checkpoints_to_keep() diff --git a/python/ray/includes/unique_ids.pxd b/python/ray/includes/unique_ids.pxd index 5e0df0bf2..d29e5d398 100644 --- a/python/ray/includes/unique_ids.pxd +++ b/python/ray/includes/unique_ids.pxd @@ -28,6 +28,7 @@ cdef extern from "ray/id.h" namespace "ray" nogil: ctypedef CUniqueID ActorID ctypedef CUniqueID ActorClassID ctypedef CUniqueID ActorHandleID + ctypedef CUniqueID ActorCheckpointID ctypedef CUniqueID WorkerID ctypedef CUniqueID DriverID ctypedef CUniqueID ConfigID diff --git a/python/ray/includes/unique_ids.pxi 
b/python/ray/includes/unique_ids.pxi index 1f31179e4..47d31ab09 100644 --- a/python/ray/includes/unique_ids.pxi +++ b/python/ray/includes/unique_ids.pxi @@ -7,9 +7,21 @@ See https://github.com/ray-project/ray/issues/3721. # WARNING: Any additional ID types defined in this file must be added to the # _ID_TYPES list at the bottom of this file. from ray.includes.common cimport ( - CUniqueID, CTaskID, CObjectID, CFunctionID, CActorClassID, CActorID, - CActorHandleID, CWorkerID, CDriverID, CConfigID, CClientID, - ComputePutId, ComputeTaskId) + CActorCheckpointID, + CActorClassID, + CActorHandleID, + CActorID, + CClientID, + CConfigID, + CDriverID, + CFunctionID, + CObjectID, + CTaskID, + CUniqueID, + CWorkerID, + ComputePutId, + ComputeTaskId, +) from ray.utils import decode @@ -236,6 +248,29 @@ cdef class ActorHandleID(UniqueID): return "ActorHandleID(" + self.hex() + ")" +cdef class ActorCheckpointID(UniqueID): + + def __init__(self, id): + if not id: + self.data = CUniqueID() + else: + check_id(id) + self.data = CUniqueID.from_binary(id) + + @staticmethod + cdef from_native(const CActorCheckpointID& cpp_id): + cdef ActorCheckpointID self = ActorCheckpointID.__new__(ActorCheckpointID) + self.data = cpp_id + return self + + @staticmethod + def nil(): + return ActorCheckpointID.from_native(CActorCheckpointID.nil()) + + def __repr__(self): + return "ActorCheckpointID(" + self.hex() + ")" + + cdef class FunctionID(UniqueID): def __init__(self, id): diff --git a/python/ray/worker.py b/python/ray/worker.py index 5c936bfa2..05764444d 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -117,6 +117,25 @@ class RayTaskError(Exception): return "\n".join(out) +class ActorCheckpointInfo(object): + """Information used to maintain actor checkpoints.""" + + __slots__ = [ + # Number of tasks executed since last checkpoint. + "num_tasks_since_last_checkpoint", + # Timestamp of the last checkpoint, in milliseconds.
+ "last_checkpoint_timestamp", + # IDs of the previous checkpoints. + "checkpoint_ids", + ] + + def __init__(self, num_tasks_since_last_checkpoint, + last_checkpoint_timestamp, checkpoint_ids): + self.num_tasks_since_last_checkpoint = num_tasks_since_last_checkpoint + self.last_checkpoint_timestamp = last_checkpoint_timestamp + self.checkpoint_ids = checkpoint_ids + + class Worker(object): """A class used to define the control flow of a worker process. @@ -141,6 +160,8 @@ class Worker(object): self.actor_init_error = None self.make_actor = None self.actors = {} + # Information used to maintain actor checkpoints. + self.actor_checkpoint_info = {} self.actor_task_counter = 0 # The number of threads Plasma should use when putting an object in the # object store. @@ -515,7 +536,6 @@ class Worker(object): actor_id=None, actor_handle_id=None, actor_counter=0, - is_actor_checkpoint_method=False, actor_creation_id=None, actor_creation_dummy_object_id=None, max_actor_reconstructions=0, @@ -538,8 +558,6 @@ class Worker(object): be serializable objects. actor_id: The ID of the actor that this task is for. actor_counter: The counter of the actor task. - is_actor_checkpoint_method: True if this is an actor checkpoint - task and false otherwise. actor_creation_id: The ID of the actor to create, if this is an actor creation task. 
actor_creation_dummy_object_id: If this task is an actor method, @@ -900,6 +918,11 @@ class Worker(object): self.actor_creation_task_id = task.task_id() self.function_actor_manager.load_actor(driver_id, function_descriptor) + self.actor_checkpoint_info[self.actor_id] = ActorCheckpointInfo( + num_tasks_since_last_checkpoint=0, + last_checkpoint_timestamp=int(1000 * time.time()), + checkpoint_ids=[], + ) execution_info = self.function_actor_manager.get_execution_info( driver_id, function_descriptor) @@ -2395,16 +2418,12 @@ def make_decorator(num_return_vals=None, num_gpus=None, resources=None, max_calls=None, - checkpoint_interval=None, max_reconstructions=None, worker=None): def decorator(function_or_class): if (inspect.isfunction(function_or_class) or is_cython(function_or_class)): # Set the remote function default resources. - if checkpoint_interval is not None: - raise Exception("The keyword 'checkpoint_interval' is not " - "allowed for remote functions.") if max_reconstructions is not None: raise Exception("The keyword 'max_reconstructions' is not " "allowed for remote functions.") @@ -2437,7 +2456,7 @@ def make_decorator(num_return_vals=None, return worker.make_actor(function_or_class, cpus_to_use, num_gpus, resources, actor_method_cpus, - checkpoint_interval, max_reconstructions) + max_reconstructions) raise Exception("The @ray.remote decorator must be applied to " "either a function or to a class.") @@ -2509,7 +2528,7 @@ def remote(*args, **kwargs): "with no arguments and no parentheses, for example " "'@ray.remote', or it must be applied using some of " "the arguments 'num_return_vals', 'num_cpus', 'num_gpus', " - "'resources', 'max_calls', 'checkpoint_interval'," + "'resources', 'max_calls', " "or 'max_reconstructions', like " "'@ray.remote(num_return_vals=2, " "resources={\"CustomResource\": 1})'.") @@ -2517,7 +2536,7 @@ def remote(*args, **kwargs): for key in kwargs: assert key in [ "num_return_vals", "num_cpus", "num_gpus", "resources", - "max_calls", 
"checkpoint_interval", "max_reconstructions" + "max_calls", "max_reconstructions" ], error_string num_cpus = kwargs["num_cpus"] if "num_cpus" in kwargs else None @@ -2534,7 +2553,6 @@ def remote(*args, **kwargs): # Handle other arguments. num_return_vals = kwargs.get("num_return_vals") max_calls = kwargs.get("max_calls") - checkpoint_interval = kwargs.get("checkpoint_interval") max_reconstructions = kwargs.get("max_reconstructions") return make_decorator( @@ -2543,6 +2561,5 @@ def remote(*args, **kwargs): num_gpus=num_gpus, resources=resources, max_calls=max_calls, - checkpoint_interval=checkpoint_interval, max_reconstructions=max_reconstructions, worker=worker) diff --git a/src/ray/gcs/client.cc b/src/ray/gcs/client.cc index c729856e4..4ce6a07b4 100644 --- a/src/ray/gcs/client.cc +++ b/src/ray/gcs/client.cc @@ -118,6 +118,8 @@ AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, task_lease_table_.reset(new TaskLeaseTable(shard_contexts_, this)); heartbeat_table_.reset(new HeartbeatTable(shard_contexts_, this)); profile_table_.reset(new ProfileTable(shard_contexts_, this)); + actor_checkpoint_table_.reset(new ActorCheckpointTable(shard_contexts_, this)); + actor_checkpoint_id_table_.reset(new ActorCheckpointIdTable(shard_contexts_, this)); command_type_ = command_type; // TODO(swang): Call the client table's Connect() method here. 
To do this, @@ -219,6 +221,14 @@ DriverTable &AsyncGcsClient::driver_table() { return *driver_table_; } ProfileTable &AsyncGcsClient::profile_table() { return *profile_table_; } +ActorCheckpointTable &AsyncGcsClient::actor_checkpoint_table() { + return *actor_checkpoint_table_; +} + +ActorCheckpointIdTable &AsyncGcsClient::actor_checkpoint_id_table() { + return *actor_checkpoint_id_table_; +} + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/client.h b/src/ray/gcs/client.h index 6145d0aa2..062af0dc4 100644 --- a/src/ray/gcs/client.h +++ b/src/ray/gcs/client.h @@ -60,6 +60,8 @@ class RAY_EXPORT AsyncGcsClient { ErrorTable &error_table(); DriverTable &driver_table(); ProfileTable &profile_table(); + ActorCheckpointTable &actor_checkpoint_table(); + ActorCheckpointIdTable &actor_checkpoint_id_table(); // We also need something to export generic code to run on workers from the // driver (to set the PYTHONPATH) @@ -90,6 +92,8 @@ class RAY_EXPORT AsyncGcsClient { std::unique_ptr error_table_; std::unique_ptr profile_table_; std::unique_ptr client_table_; + std::unique_ptr actor_checkpoint_table_; + std::unique_ptr actor_checkpoint_id_table_; // The following contexts write to the data shard std::vector> shard_contexts_; std::vector> shard_asio_async_clients_; diff --git a/src/ray/gcs/format/gcs.fbs b/src/ray/gcs/format/gcs.fbs index 7619349f9..c850ddd22 100644 --- a/src/ray/gcs/format/gcs.fbs +++ b/src/ray/gcs/format/gcs.fbs @@ -20,6 +20,8 @@ enum TablePrefix:int { DRIVER, PROFILE, TASK_LEASE, + ACTOR_CHECKPOINT, + ACTOR_CHECKPOINT_ID, } // The channel that Add operations to the Table should be published on, if any. @@ -72,8 +74,6 @@ table TaskInfo { actor_handle_id: string; // Number of tasks that have been submitted to this actor so far. actor_counter: int; - // True if this task is an actor checkpoint task and false otherwise. 
- is_actor_checkpoint_method: bool; // If this is an actor task, then this will be populated with all of the new // actor handles that were forked from this handle since the last task on // this handle was submitted. @@ -318,3 +318,34 @@ table DriverTableData { // Whether it's dead. is_dead: bool; } + +// This table stores the actor checkpoint data. An actor checkpoint +// is the snapshot of an actor's state in the actor registration. +// See `actor_registration.h` for more detailed explanation of these fields. +table ActorCheckpointData { + // ID of this actor. + actor_id: string; + // The dummy object ID of actor's most recently executed task. + execution_dependency: string; + // A list of IDs of this actor's handles. + handle_ids: [string]; + // The task counters of the above handles. + task_counters: [long]; + // The frontier dependencies of the above handles. + frontier_dependencies: [string]; + // A list of unreleased dummy objects from this actor. + unreleased_dummy_objects: [string]; + // The numbers of dependencies for the above unreleased dummy objects. + num_dummy_object_dependencies: [int]; +} + +// This table stores the actor-to-available-checkpoint-ids mapping. +table ActorCheckpointIdData { + // ID of this actor. + actor_id: string; + // IDs of this actor's available checkpoints. + // Note, this is a long string that concatenates all the IDs. + checkpoint_ids: string; + // A list of the timestamps for each of the above `checkpoint_ids`. 
+ timestamps: [long]; +} diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index 00ed64d7c..5e1d828d2 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -2,6 +2,8 @@ #include "ray/common/common_protocol.h" #include "ray/gcs/client.h" +#include "ray/ray_config.h" +#include "ray/util/util.h" namespace { @@ -438,6 +440,41 @@ std::string ClientTable::DebugString() const { return result.str(); } +Status ActorCheckpointIdTable::AddCheckpointId(const JobID &job_id, + const ActorID &actor_id, + const UniqueID &checkpoint_id) { + auto lookup_callback = [this, checkpoint_id, job_id, actor_id]( + ray::gcs::AsyncGcsClient *client, const UniqueID &id, + const ActorCheckpointIdDataT &data) { + std::shared_ptr copy = + std::make_shared(data); + copy->timestamps.push_back(current_sys_time_ms()); + copy->checkpoint_ids += checkpoint_id.binary(); + auto num_to_keep = RayConfig::instance().num_actor_checkpoints_to_keep(); + while (copy->timestamps.size() > num_to_keep) { + // Delete the checkpoint from actor checkpoint table. + const auto &checkpoint_id = + UniqueID::from_binary(copy->checkpoint_ids.substr(0, kUniqueIDSize)); + RAY_LOG(DEBUG) << "Deleting checkpoint " << checkpoint_id << " for actor " + << actor_id; + copy->timestamps.erase(copy->timestamps.begin()); + copy->checkpoint_ids.erase(0, kUniqueIDSize); + // TODO(hchen): also delete checkpoint data from GCS. 
+ } + RAY_CHECK_OK(Add(job_id, actor_id, copy, nullptr)); + }; + auto failure_callback = [this, checkpoint_id, job_id, actor_id]( + ray::gcs::AsyncGcsClient *client, const UniqueID &id) { + std::shared_ptr data = + std::make_shared(); + data->actor_id = id.binary(); + data->timestamps.push_back(current_sys_time_ms()); + data->checkpoint_ids = checkpoint_id.binary(); + RAY_CHECK_OK(Add(job_id, actor_id, data, nullptr)); + }; + return Lookup(job_id, actor_id, lookup_callback, failure_callback); +} + template class Log; template class Log; template class Table; @@ -451,6 +488,8 @@ template class Log; template class Log; template class Log; template class Log; +template class Table; +template class Table; } // namespace gcs diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index 1eab64960..49b19114d 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -443,6 +443,34 @@ class TaskLeaseTable : public Table { } }; +class ActorCheckpointTable : public Table { + public: + ActorCheckpointTable(const std::vector> &contexts, + AsyncGcsClient *client) + : Table(contexts, client) { + prefix_ = TablePrefix::ACTOR_CHECKPOINT; + }; +}; + +class ActorCheckpointIdTable : public Table { + public: + ActorCheckpointIdTable(const std::vector> &contexts, + AsyncGcsClient *client) + : Table(contexts, client) { + prefix_ = TablePrefix::ACTOR_CHECKPOINT_ID; + }; + + /// Add a checkpoint id to an actor, and remove a previous checkpoint if the + /// total number of checkpoints in GCS exceeds the max allowed value. + /// + /// \param job_id The ID of the job (= driver). + /// \param actor_id ID of the actor. + /// \param checkpoint_id ID of the checkpoint. + /// \return Status. 
+ Status AddCheckpointId(const JobID &job_id, const ActorID &actor_id, + const UniqueID &checkpoint_id); +}; + namespace raylet { class TaskTable : public Table { diff --git a/src/ray/id.h b/src/ray/id.h index ed60f5b82..562365951 100644 --- a/src/ray/id.h +++ b/src/ray/id.h @@ -45,6 +45,7 @@ typedef UniqueID FunctionID; typedef UniqueID ActorClassID; typedef UniqueID ActorID; typedef UniqueID ActorHandleID; +typedef UniqueID ActorCheckpointID; typedef UniqueID WorkerID; typedef UniqueID DriverID; typedef UniqueID ConfigID; diff --git a/src/ray/ray_config_def.h b/src/ray/ray_config_def.h index bfe09b5a5..895668332 100644 --- a/src/ray/ray_config_def.h +++ b/src/ray/ray_config_def.h @@ -144,3 +144,9 @@ RAY_CONFIG(int, num_workers_per_process, 1); /// Maximum timeout in milliseconds within which a task lease must be renewed. RAY_CONFIG(int64_t, max_task_lease_timeout_ms, 60000); + +/// Maximum number of checkpoints to keep in GCS for an actor. +/// Note: this number should be set to at least 2. Because saving an application +/// checkpoint isn't atomic with saving the backend checkpoint, and it will break +/// if this number is set to 1 and users save application checkpoints in place. +RAY_CONFIG(uint32_t, num_actor_checkpoints_to_keep, 20); diff --git a/src/ray/raylet/actor_registration.cc b/src/ray/raylet/actor_registration.cc index 0e95e200a..1cc55367e 100644 --- a/src/ray/raylet/actor_registration.cc +++ b/src/ray/raylet/actor_registration.cc @@ -11,6 +11,25 @@ namespace raylet { ActorRegistration::ActorRegistration(const ActorTableDataT &actor_table_data) : actor_table_data_(actor_table_data) {} +ActorRegistration::ActorRegistration(const ActorTableDataT &actor_table_data, + const ActorCheckpointDataT &checkpoint_data) + : actor_table_data_(actor_table_data), + execution_dependency_(ObjectID::from_binary(checkpoint_data.execution_dependency)) { + // Restore `frontier_`. 
+ for (size_t i = 0; i < checkpoint_data.handle_ids.size(); i++) { + auto handle_id = ActorHandleID::from_binary(checkpoint_data.handle_ids[i]); + auto &frontier_entry = frontier_[handle_id]; + frontier_entry.task_counter = checkpoint_data.task_counters[i]; + frontier_entry.execution_dependency = + ObjectID::from_binary(checkpoint_data.frontier_dependencies[i]); + } + // Restore `dummy_objects_`. + for (size_t i = 0; i < checkpoint_data.unreleased_dummy_objects.size(); i++) { + auto dummy = ObjectID::from_binary(checkpoint_data.unreleased_dummy_objects[i]); + dummy_objects_[dummy] = checkpoint_data.num_dummy_object_dependencies[i]; + } +} + const ClientID ActorRegistration::GetNodeManagerId() const { return ClientID::from_binary(actor_table_data_.node_manager_id); } @@ -77,6 +96,35 @@ void ActorRegistration::AddHandle(const ActorHandleID &handle_id, int ActorRegistration::NumHandles() const { return frontier_.size(); } +std::shared_ptr ActorRegistration::GenerateCheckpointData( + const ActorID &actor_id, const Task &task) { + const auto actor_handle_id = task.GetTaskSpecification().ActorHandleId(); + const auto dummy_object = task.GetTaskSpecification().ActorDummyObject(); + // Make a copy of the actor registration, and extend its frontier to include + // the most recent task. + // Note(hchen): this is needed because this method is called before + // `FinishAssignedTask`, which will be called when the worker tries to fetch + // the next task. + ActorRegistration copy = *this; + copy.ExtendFrontier(actor_handle_id, dummy_object); + + // Use actor's current state to generate checkpoint data. 
+ auto checkpoint_data = std::make_shared(); + checkpoint_data->actor_id = actor_id.binary(); + checkpoint_data->execution_dependency = copy.GetExecutionDependency().binary(); + for (const auto &frontier : copy.GetFrontier()) { + checkpoint_data->handle_ids.push_back(frontier.first.binary()); + checkpoint_data->task_counters.push_back(frontier.second.task_counter); + checkpoint_data->frontier_dependencies.push_back( + frontier.second.execution_dependency.binary()); + } + for (const auto &entry : copy.GetDummyObjects()) { + checkpoint_data->unreleased_dummy_objects.push_back(entry.first.binary()); + checkpoint_data->num_dummy_object_dependencies.push_back(entry.second); + } + return checkpoint_data; +} + } // namespace raylet } // namespace ray diff --git a/src/ray/raylet/actor_registration.h b/src/ray/raylet/actor_registration.h index d56090103..38533e7db 100644 --- a/src/ray/raylet/actor_registration.h +++ b/src/ray/raylet/actor_registration.h @@ -5,6 +5,7 @@ #include "ray/gcs/format/gcs_generated.h" #include "ray/id.h" +#include "ray/raylet/task.h" namespace ray { @@ -24,6 +25,12 @@ class ActorRegistration { /// this actor. This includes the actor's node manager location. ActorRegistration(const ActorTableDataT &actor_table_data); + /// Recreate an actor's registration from a checkpoint. + /// + /// \param checkpoint_data The checkpoint used to restore the actor. + ActorRegistration(const ActorTableDataT &actor_table_data, + const ActorCheckpointDataT &checkpoint_data); + /// Each actor may have multiple callers, or "handles". A frontier leaf /// represents the execution state of the actor with respect to a single /// handle. @@ -119,6 +126,14 @@ class ActorRegistration { /// \return int. int NumHandles() const; + /// Generate checkpoint data based on actor's current state. + /// + /// \param actor_id ID of this actor. + /// \param task The task that just finished on the actor. + /// \return A shared pointer to the generated checkpoint data. 
+ std::shared_ptr GenerateCheckpointData(const ActorID &actor_id, + const Task &task); + private: /// Information from the global actor table about this actor, including the /// node manager location. diff --git a/src/ray/raylet/format/node_manager.fbs b/src/ray/raylet/format/node_manager.fbs index 289d02170..710928cdb 100644 --- a/src/ray/raylet/format/node_manager.fbs +++ b/src/ray/raylet/format/node_manager.fbs @@ -71,6 +71,12 @@ enum MessageType:int { PushProfileEventsRequest, // Free the objects in objects store. FreeObjectsInObjectStoreRequest, + // Request raylet backend to prepare a checkpoint for an actor. + PrepareActorCheckpointRequest, + // Reply of `PrepareActorCheckpointRequest`. + PrepareActorCheckpointReply, + // Notify raylet backend that an actor was resumed from a checkpoint. + NotifyActorResumedFromCheckpoint, // A node manager requests to connect to another node manager. ConnectClient, } @@ -207,6 +213,23 @@ table FreeObjectsRequest { object_ids: [string]; } +table PrepareActorCheckpointRequest { + // ID of the actor. + actor_id: string; +} + +table PrepareActorCheckpointReply { + // ID of the checkpoint. + checkpoint_id: string; +} + +table NotifyActorResumedFromCheckpoint { + // ID of the actor. + actor_id: string; + // ID of the checkpoint from which the actor was resumed. + checkpoint_id: string; +} + table ConnectClient { // ID of the connecting client. 
client_id: string; diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc index 634d04a9a..c3875c6d2 100644 --- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc +++ b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc @@ -268,6 +268,40 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeFreePlasmaObjects( ThrowRayExceptionIfNotOK(env, status, "[RayletClient] Failed to free objects."); } +/* + * Class: org_ray_runtime_raylet_RayletClientImpl + * Method: nativePrepareCheckpoint + * Signature: (J[B)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_org_ray_runtime_raylet_RayletClientImpl_nativePrepareCheckpoint(JNIEnv *env, jclass, + jlong client, + jbyteArray actorId) { + auto raylet_client = reinterpret_cast(client); + UniqueIdFromJByteArray actor_id(env, actorId); + ActorCheckpointID checkpoint_id; + RAY_CHECK_OK(raylet_client->PrepareActorCheckpoint(*actor_id.PID, checkpoint_id)); + jbyteArray result = env->NewByteArray(sizeof(ActorCheckpointID)); + env->SetByteArrayRegion(result, 0, sizeof(ActorCheckpointID), + reinterpret_cast(&checkpoint_id)); + return result; +} + +/* + * Class: org_ray_runtime_raylet_RayletClientImpl + * Method: nativeNotifyActorResumedFromCheckpoint + * Signature: (J[B[B)V + */ +JNIEXPORT void JNICALL +Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyActorResumedFromCheckpoint( + JNIEnv *env, jclass, jlong client, jbyteArray actorId, jbyteArray checkpointId) { + auto raylet_client = reinterpret_cast(client); + UniqueIdFromJByteArray actor_id(env, actorId); + UniqueIdFromJByteArray checkpoint_id(env, checkpointId); + RAY_CHECK_OK( + raylet_client->NotifyActorResumedFromCheckpoint(*actor_id.PID, *checkpoint_id.PID)); +} + #ifdef __cplusplus } #endif diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h index 
fff804a04..8bf64e98c 100644 --- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h +++ b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h @@ -7,6 +7,8 @@ #ifdef __cplusplus extern "C" { #endif +#undef org_ray_runtime_raylet_RayletClientImpl_TASK_SPEC_BUFFER_SIZE +#define org_ray_runtime_raylet_RayletClientImpl_TASK_SPEC_BUFFER_SIZE 2097152L /* * Class: org_ray_runtime_raylet_RayletClientImpl * Method: nativeInit @@ -58,6 +60,14 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeFetchOrReconstruct(JNIEnv *, JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyUnblocked( JNIEnv *, jclass, jlong, jbyteArray); +/* + * Class: org_ray_runtime_raylet_RayletClientImpl + * Method: nativePutObject + * Signature: (J[B[B)V + */ +JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativePutObject( + JNIEnv *, jclass, jlong, jbyteArray, jbyteArray); + /* * Class: org_ray_runtime_raylet_RayletClientImpl * Method: nativeWaitObject @@ -88,6 +98,24 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeFreePlasmaObjects(JNIEnv *, j jlong, jobjectArray, jboolean); +/* + * Class: org_ray_runtime_raylet_RayletClientImpl + * Method: nativePrepareCheckpoint + * Signature: (J[B)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_org_ray_runtime_raylet_RayletClientImpl_nativePrepareCheckpoint(JNIEnv *, jclass, + jlong, jbyteArray); + +/* + * Class: org_ray_runtime_raylet_RayletClientImpl + * Method: nativeNotifyActorResumedFromCheckpoint + * Signature: (J[B[B)V + */ +JNIEXPORT void JNICALL +Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyActorResumedFromCheckpoint( + JNIEnv *, jclass, jlong, jbyteArray, jbyteArray); + #ifdef __cplusplus } #endif diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index adef01209..deb099bb0 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -140,7 +140,7 @@ ray::Status NodeManager::RegisterGcs() { if (!data.empty()) { // 
We only need the last entry, because it represents the latest state of // this actor. - HandleActorStateTransition(actor_id, data.back()); + HandleActorStateTransition(actor_id, ActorRegistration(data.back())); } }; @@ -507,13 +507,7 @@ void NodeManager::PublishActorStateTransition( } void NodeManager::HandleActorStateTransition(const ActorID &actor_id, - const ActorTableDataT &data) { - ActorRegistration actor_registration(data); - RAY_LOG(DEBUG) << "Actor notification received: actor_id = " << actor_id - << ", node_manager_id = " << actor_registration.GetNodeManagerId() - << ", state = " << EnumNameActorState(actor_registration.GetState()) - << ", remaining_reconstructions = " - << actor_registration.GetRemainingReconstructions(); + ActorRegistration &&actor_registration) { // Update local registry. auto it = actor_registry_.find(actor_id); if (it == actor_registry_.end()) { @@ -536,6 +530,11 @@ void NodeManager::HandleActorStateTransition(const ActorID &actor_id, return; } } + RAY_LOG(DEBUG) << "Actor notification received: actor_id = " << actor_id + << ", node_manager_id = " << actor_registration.GetNodeManagerId() + << ", state = " << EnumNameActorState(actor_registration.GetState()) + << ", remaining_reconstructions = " + << actor_registration.GetRemainingReconstructions(); if (actor_registration.GetState() == ActorState::ALIVE) { // The actor's location is now known. 
Dequeue any methods that were @@ -700,6 +699,12 @@ void NodeManager::ProcessClientMessage( std::vector object_ids = from_flatbuf(*message->object_ids()); object_manager_.FreeObjects(object_ids, message->local_only()); } break; + case protocol::MessageType::PrepareActorCheckpointRequest: { + ProcessPrepareActorCheckpointRequest(client, message_data); + } break; + case protocol::MessageType::NotifyActorResumedFromCheckpoint: { + ProcessNotifyActorResumedFromCheckpoint(message_data); + } break; default: RAY_LOG(FATAL) << "Received unexpected message type " << message_type; @@ -762,7 +767,7 @@ void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_loca // So if we receive any actor tasks before we receive GCS notification, // these tasks can be correctly routed to the `MethodsWaitingForActorCreation` queue, // instead of being assigned to the dead actor. - HandleActorStateTransition(actor_id, new_actor_data); + HandleActorStateTransition(actor_id, ActorRegistration(new_actor_data)); } ray::gcs::ActorTable::WriteCallback failure_callback = nullptr; if (was_local) { @@ -1014,6 +1019,64 @@ void NodeManager::ProcessPushErrorRequestMessage(const uint8_t *message_data) { timestamp)); } +void NodeManager::ProcessPrepareActorCheckpointRequest( + const std::shared_ptr &client, const uint8_t *message_data) { + auto message = + flatbuffers::GetRoot(message_data); + ActorID actor_id = from_flatbuf(*message->actor_id()); + RAY_LOG(DEBUG) << "Preparing checkpoint for actor " << actor_id; + const auto &actor_entry = actor_registry_.find(actor_id); + RAY_CHECK(actor_entry != actor_registry_.end()); + + std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); + RAY_CHECK(worker && worker->GetActorId() == actor_id); + + // Find the task that is running on this actor. + const auto task_id = worker->GetAssignedTaskId(); + const Task &task = local_queues_.GetTaskOfState(task_id, TaskState::RUNNING); + // Generate checkpoint id and data. 
+ ActorCheckpointID checkpoint_id = UniqueID::from_random(); + auto checkpoint_data = + actor_entry->second.GenerateCheckpointData(actor_entry->first, task); + + // Write checkpoint data to GCS. + RAY_CHECK_OK(gcs_client_->actor_checkpoint_table().Add( + UniqueID::nil(), checkpoint_id, checkpoint_data, + [worker, actor_id, this](ray::gcs::AsyncGcsClient *client, + const UniqueID &checkpoint_id, + const ActorCheckpointDataT &data) { + RAY_LOG(DEBUG) << "Checkpoint " << checkpoint_id << " saved for actor " + << worker->GetActorId(); + // Save this actor-to-checkpoint mapping, and remove old checkpoints associated + // with this actor. + RAY_CHECK_OK(gcs_client_->actor_checkpoint_id_table().AddCheckpointId( + JobID::nil(), actor_id, checkpoint_id)); + // Send reply to worker. + flatbuffers::FlatBufferBuilder fbb; + auto reply = ray::protocol::CreatePrepareActorCheckpointReply( + fbb, to_flatbuf(fbb, checkpoint_id)); + fbb.Finish(reply); + worker->Connection()->WriteMessageAsync( + static_cast(protocol::MessageType::PrepareActorCheckpointReply), + fbb.GetSize(), fbb.GetBufferPointer(), [](const ray::Status &status) { + if (!status.ok()) { + RAY_LOG(WARNING) + << "Failed to send PrepareActorCheckpointReply to client"; + } + }); + })); +} + +void NodeManager::ProcessNotifyActorResumedFromCheckpoint(const uint8_t *message_data) { + auto message = + flatbuffers::GetRoot(message_data); + ActorID actor_id = from_flatbuf(*message->actor_id()); + ActorCheckpointID checkpoint_id = from_flatbuf(*message->checkpoint_id()); + RAY_LOG(DEBUG) << "Actor " << actor_id << " was resumed from checkpoint " + << checkpoint_id; + checkpoint_id_to_restore_.emplace(actor_id, checkpoint_id); +} + void NodeManager::ProcessNewNodeManager(TcpClientConnection &node_manager_client) { node_manager_client.ProcessMessages(); } @@ -1154,6 +1217,12 @@ bool NodeManager::CheckDependencyManagerInvariant() const { void NodeManager::TreatTaskAsFailed(const Task &task) { const TaskSpecification &spec = 
task.GetTaskSpecification(); RAY_LOG(DEBUG) << "Treating task " << spec.TaskId() << " as failed."; + // If this was an actor creation task that tried to resume from a checkpoint, + // then erase it here since the task did not finish. + if (spec.IsActorCreationTask()) { + ActorID actor_id = spec.ActorCreationId(); + checkpoint_id_to_restore_.erase(actor_id); + } // Loop over the return IDs (except the dummy ID) and store a fake object in // the object store. int64_t num_returns = spec.NumReturns(); @@ -1320,7 +1389,7 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag if (!data.empty()) { // The actor has been created. We only need the last entry, because // it represents the latest state of this actor. - HandleActorStateTransition(actor_id, data.back()); + HandleActorStateTransition(actor_id, ActorRegistration(data.back())); } }; RAY_CHECK_OK(gcs_client_->actor_table().Lookup(JobID::nil(), spec.ActorId(), @@ -1672,86 +1741,141 @@ void NodeManager::FinishAssignedTask(Worker &worker) { } } -void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) { - // If this was an actor creation task, then convert the worker to an actor - // and notify the other node managers. - if (task.GetTaskSpecification().IsActorCreationTask()) { - // Convert the worker to an actor. - auto actor_id = task.GetTaskSpecification().ActorCreationId(); - worker.AssignActorId(actor_id); - // Publish the actor creation event to all other nodes so that methods for - // the actor will be forwarded directly to this node. - auto actor_entry = actor_registry_.find(actor_id); - ActorTableDataT new_actor_data; - if (actor_entry == actor_registry_.end()) { - // Set all of the static fields for the actor. These fields will not - // change even if the actor fails or is reconstructed. 
- new_actor_data.actor_id = actor_id.binary(); - new_actor_data.actor_creation_dummy_object_id = - task.GetTaskSpecification().ActorDummyObject().binary(); - new_actor_data.driver_id = task.GetTaskSpecification().DriverId().binary(); - new_actor_data.max_reconstructions = - task.GetTaskSpecification().MaxActorReconstructions(); - // This is the first time that the actor has been created, so the number - // of remaining reconstructions is the max. - new_actor_data.remaining_reconstructions = - task.GetTaskSpecification().MaxActorReconstructions(); - } else { - // If we've already seen this actor, it means that this actor was reconstructed. - // Thus, its previous state must be RECONSTRUCTING. - RAY_CHECK(actor_entry->second.GetState() == ActorState::RECONSTRUCTING); - // Copy the static fields from the current actor entry. - new_actor_data = actor_entry->second.GetTableData(); - // We are reconstructing the actor, so subtract its - // remaining_reconstructions by 1. - new_actor_data.remaining_reconstructions--; - } - - // Set the new fields for the actor's state to indicate that the actor is - // now alive on this node manager. - new_actor_data.node_manager_id = - gcs_client_->client_table().GetLocalClientId().binary(); - new_actor_data.state = ActorState::ALIVE; - HandleActorStateTransition(actor_id, new_actor_data); - PublishActorStateTransition( - actor_id, new_actor_data, - /*failure_callback=*/ - [](gcs::AsyncGcsClient *client, const ActorID &id, const ActorTableDataT &data) { - // Only one node at a time should succeed at creating the actor. 
- RAY_LOG(FATAL) << "Failed to update state to ALIVE for actor " << id; - }); +ActorTableDataT NodeManager::CreateActorTableDataFromCreationTask(const Task &task) { + RAY_CHECK(task.GetTaskSpecification().IsActorCreationTask()); + auto actor_id = task.GetTaskSpecification().ActorCreationId(); + auto actor_entry = actor_registry_.find(actor_id); + ActorTableDataT new_actor_data; + // TODO(swang): If this is an actor that was reconstructed, and previous + // actor notifications were delayed, then this node may not have an entry for + // the actor in actor_registry_. Then, the fields for the number of + // reconstructions will be wrong. + if (actor_entry == actor_registry_.end()) { + // Set all of the static fields for the actor. These fields will not + // change even if the actor fails or is reconstructed. + new_actor_data.actor_id = actor_id.binary(); + new_actor_data.actor_creation_dummy_object_id = + task.GetTaskSpecification().ActorDummyObject().binary(); + new_actor_data.driver_id = task.GetTaskSpecification().DriverId().binary(); + new_actor_data.max_reconstructions = + task.GetTaskSpecification().MaxActorReconstructions(); + // This is the first time that the actor has been created, so the number + // of remaining reconstructions is the max. + new_actor_data.remaining_reconstructions = + task.GetTaskSpecification().MaxActorReconstructions(); + } else { + // If we've already seen this actor, it means that this actor was reconstructed. + // Thus, its previous state must be RECONSTRUCTING. + RAY_CHECK(actor_entry->second.GetState() == ActorState::RECONSTRUCTING); + // Copy the static fields from the current actor entry. + new_actor_data = actor_entry->second.GetTableData(); + // We are reconstructing the actor, so subtract its + // remaining_reconstructions by 1. + new_actor_data.remaining_reconstructions--; } - // Update the actor's frontier. + // Set the new fields for the actor's state to indicate that the actor is + // now alive on this node manager. 
+ new_actor_data.node_manager_id = + gcs_client_->client_table().GetLocalClientId().binary(); + new_actor_data.state = ActorState::ALIVE; + return new_actor_data; +} + +void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) { ActorID actor_id; ActorHandleID actor_handle_id; + bool resumed_from_checkpoint = false; if (task.GetTaskSpecification().IsActorCreationTask()) { actor_id = task.GetTaskSpecification().ActorCreationId(); actor_handle_id = ActorHandleID::nil(); + if (checkpoint_id_to_restore_.count(actor_id) > 0) { + resumed_from_checkpoint = true; + } } else { actor_id = task.GetTaskSpecification().ActorId(); actor_handle_id = task.GetTaskSpecification().ActorHandleId(); } - auto actor_entry = actor_registry_.find(actor_id); - RAY_CHECK(actor_entry != actor_registry_.end()); - // Extend the actor's frontier to include the executed task. - const auto dummy_object = task.GetTaskSpecification().ActorDummyObject(); - const ObjectID object_to_release = - actor_entry->second.ExtendFrontier(actor_handle_id, dummy_object); - if (!object_to_release.is_nil()) { - // If there were no new actor handles created, then no other actor task - // will depend on this execution dependency, so it safe to release. - HandleObjectMissing(object_to_release); + + if (task.GetTaskSpecification().IsActorCreationTask()) { + // This was an actor creation task. Convert the worker to an actor. + worker.AssignActorId(actor_id); + // Notify the other node managers that the actor has been created. + const auto new_actor_data = CreateActorTableDataFromCreationTask(task); + if (resumed_from_checkpoint) { + // This actor was resumed from a checkpoint. In this case, we first look + // up the checkpoint in GCS and use it to restore the actor registration + // and frontier. 
+ const auto checkpoint_id = checkpoint_id_to_restore_[actor_id]; + checkpoint_id_to_restore_.erase(actor_id); + RAY_LOG(DEBUG) << "Looking up checkpoint " << checkpoint_id << " for actor " + << actor_id; + RAY_CHECK_OK(gcs_client_->actor_checkpoint_table().Lookup( + JobID::nil(), checkpoint_id, + [this, actor_id, new_actor_data](ray::gcs::AsyncGcsClient *client, + const UniqueID &checkpoint_id, + const ActorCheckpointDataT &checkpoint_data) { + RAY_LOG(INFO) << "Restoring registration for actor " << actor_id + << " from checkpoint " << checkpoint_id; + ActorRegistration actor_registration = + ActorRegistration(new_actor_data, checkpoint_data); + // Mark the unreleased dummy objects in the checkpoint frontier as local. + for (const auto &entry : actor_registration.GetDummyObjects()) { + HandleObjectLocal(entry.first); + } + HandleActorStateTransition(actor_id, std::move(actor_registration)); + PublishActorStateTransition( + actor_id, new_actor_data, + /*failure_callback=*/ + [](gcs::AsyncGcsClient *client, const ActorID &id, + const ActorTableDataT &data) { + // Only one node at a time should succeed at creating the actor. + RAY_LOG(FATAL) << "Failed to update state to ALIVE for actor " << id; + }); + }, + [actor_id](ray::gcs::AsyncGcsClient *client, const UniqueID &checkpoint_id) { + RAY_LOG(FATAL) << "Couldn't find checkpoint " << checkpoint_id + << " for actor " << actor_id << " in GCS."; + })); + } else { + // The actor did not resume from a checkpoint. Immediately notify the + // other node managers that the actor has been created. + HandleActorStateTransition(actor_id, ActorRegistration(new_actor_data)); + PublishActorStateTransition( + actor_id, new_actor_data, + /*failure_callback=*/ + [](gcs::AsyncGcsClient *client, const ActorID &id, + const ActorTableDataT &data) { + // Only one node at a time should succeed at creating the actor. 
+ RAY_LOG(FATAL) << "Failed to update state to ALIVE for actor " << id; + }); + } + } + + if (!resumed_from_checkpoint) { + // The actor was not resumed from a checkpoint. We extend the actor's + // frontier as usual since there is no frontier to restore. + auto actor_entry = actor_registry_.find(actor_id); + RAY_CHECK(actor_entry != actor_registry_.end()); + // Extend the actor's frontier to include the executed task. + const auto dummy_object = task.GetTaskSpecification().ActorDummyObject(); + const ObjectID object_to_release = + actor_entry->second.ExtendFrontier(actor_handle_id, dummy_object); + if (!object_to_release.is_nil()) { + // If there were no new actor handles created, then no other actor task + // will depend on this execution dependency, so it is safe to release. + HandleObjectMissing(object_to_release); + } + // Mark the dummy object as locally available to indicate that the actor's + // state has changed and the next method can run. This is not added to the + // object table, so the update will be invisible to both the local object + // manager and the other nodes. + // NOTE(swang): The dummy objects must be marked as local whenever + // ExtendFrontier is called, and vice versa, so that we can clean up the + // dummy objects properly in case the actor fails and needs to be + // reconstructed. 
- HandleObjectLocal(dummy_object); } void NodeManager::HandleTaskReconstruction(const TaskID &task_id) { diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 1b4daade2..dd9ac71bd 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -186,6 +186,10 @@ class NodeManager { /// \param worker The worker that finished the task. /// \return Void. void FinishAssignedTask(Worker &worker); + /// Helper function to produce actor table data for a newly created actor. + /// + /// \param task The actor creation task that created the actor. + ActorTableDataT CreateActorTableDataFromCreationTask(const Task &task); /// Handle a worker finishing an assigned actor task or actor creation task. /// \param worker The worker that finished the task. /// \param task The actor task or actor creationt ask. @@ -282,9 +286,11 @@ class NodeManager { /// old state transition. /// /// \param actor_id The actor ID of the actor whose state was updated. - /// \param data Data associated with this notification. + /// \param actor_registration The ActorRegistration object that represents actor's + /// new state. /// \return Void. - void HandleActorStateTransition(const ActorID &actor_id, const ActorTableDataT &data); + void HandleActorStateTransition(const ActorID &actor_id, + ActorRegistration &&actor_registration); /// Publish an actor's state transition to all other nodes. /// @@ -385,6 +391,25 @@ class NodeManager { /// \return Void. void ProcessPushErrorRequestMessage(const uint8_t *message_data); + /// Process client message of PrepareActorCheckpointRequest. + /// + /// \param client The client that sent the message. + /// \param message_data A pointer to the message data. + void ProcessPrepareActorCheckpointRequest( + const std::shared_ptr &client, const uint8_t *message_data); + + /// Process client message of NotifyActorResumedFromCheckpoint. + /// + /// \param message_data A pointer to the message data. 
+ void ProcessNotifyActorResumedFromCheckpoint(const uint8_t *message_data); + + /// Update actor frontier when a task finishes. + /// If the task is an actor creation task and the actor was resumed from a checkpoint, + /// restore the frontier from the checkpoint. Otherwise, just extend actor frontier. + /// + /// \param task The task that just finished. + void UpdateActorFrontier(const Task &task); + /// Handle the case where an actor is disconnected, determine whether this /// actor needs to be reconstructed and then update actor table. /// This function needs to be called either when actor process dies or when @@ -458,6 +483,10 @@ class NodeManager { /// A mapping from actor ID to registration information about that actor /// (including which node manager owns it). std::unordered_map actor_registry_; + + /// This map stores actor ID to the ID of the checkpoint that will be used to + /// restore the actor. + std::unordered_map checkpoint_id_to_restore_; }; } // namespace raylet diff --git a/src/ray/raylet/raylet_client.cc b/src/ray/raylet/raylet_client.cc index d931051bc..13e92d0c4 100644 --- a/src/ray/raylet/raylet_client.cc +++ b/src/ray/raylet/raylet_client.cc @@ -358,3 +358,31 @@ ray::Status RayletClient::FreeObjects(const std::vector &object_i auto status = conn_->WriteMessage(MessageType::FreeObjectsInObjectStoreRequest, &fbb); return status; } + +ray::Status RayletClient::PrepareActorCheckpoint(const ActorID &actor_id, + ActorCheckpointID &checkpoint_id) { + flatbuffers::FlatBufferBuilder fbb; + auto message = + ray::protocol::CreatePrepareActorCheckpointRequest(fbb, to_flatbuf(fbb, actor_id)); + fbb.Finish(message); + + std::unique_ptr reply; + auto status = + conn_->AtomicRequestReply(MessageType::PrepareActorCheckpointRequest, + MessageType::PrepareActorCheckpointReply, reply, &fbb); + if (!status.ok()) return status; + auto reply_message = + flatbuffers::GetRoot(reply.get()); + checkpoint_id = 
ObjectID::from_binary(reply_message->checkpoint_id()->str()); + return ray::Status::OK(); +} + +ray::Status RayletClient::NotifyActorResumedFromCheckpoint( + const ActorID &actor_id, const ActorCheckpointID &checkpoint_id) { + flatbuffers::FlatBufferBuilder fbb; + auto message = ray::protocol::CreateNotifyActorResumedFromCheckpoint( + fbb, to_flatbuf(fbb, actor_id), to_flatbuf(fbb, checkpoint_id)); + fbb.Finish(message); + + return conn_->WriteMessage(MessageType::NotifyActorResumedFromCheckpoint, &fbb); +} diff --git a/src/ray/raylet/raylet_client.h b/src/ray/raylet/raylet_client.h index 8aeaa5fa7..d3ea765df 100644 --- a/src/ray/raylet/raylet_client.h +++ b/src/ray/raylet/raylet_client.h @@ -10,6 +10,7 @@ #include "ray/status.h" using ray::ActorID; +using ray::ActorCheckpointID; using ray::JobID; using ray::ObjectID; using ray::TaskID; @@ -146,6 +147,22 @@ class RayletClient { /// \return ray::Status. ray::Status FreeObjects(const std::vector &object_ids, bool local_only); + /// Request raylet backend to prepare a checkpoint for an actor. + /// + /// \param actor_id ID of the actor. + /// \param checkpoint_id ID of the new checkpoint (output parameter). + /// \return ray::Status. + ray::Status PrepareActorCheckpoint(const ActorID &actor_id, + ActorCheckpointID &checkpoint_id); + + /// Notify raylet backend that an actor was resumed from a checkpoint. + /// + /// \param actor_id ID of the actor. + /// \param checkpoint_id ID of the checkpoint from which the actor was resumed. + /// \return ray::Status. 
+ ray::Status NotifyActorResumedFromCheckpoint(const ActorID &actor_id, + const ActorCheckpointID &checkpoint_id); + Language GetLanguage() const { return language_; } ClientID GetClientID() const { return client_id_; } diff --git a/src/ray/raylet/task_spec.cc b/src/ray/raylet/task_spec.cc index 5f183c7be..a8c0f40fe 100644 --- a/src/ray/raylet/task_spec.cc +++ b/src/ray/raylet/task_spec.cc @@ -99,7 +99,7 @@ TaskSpecification::TaskSpecification( fbb, to_flatbuf(fbb, driver_id), to_flatbuf(fbb, task_id), to_flatbuf(fbb, parent_task_id), parent_counter, to_flatbuf(fbb, actor_creation_id), to_flatbuf(fbb, actor_creation_dummy_object_id), max_actor_reconstructions, - to_flatbuf(fbb, actor_id), to_flatbuf(fbb, actor_handle_id), actor_counter, false, + to_flatbuf(fbb, actor_id), to_flatbuf(fbb, actor_handle_id), actor_counter, object_ids_to_flatbuf(fbb, new_actor_handles), fbb.CreateVector(arguments), object_ids_to_flatbuf(fbb, returns), map_to_flatbuf(fbb, required_resources), map_to_flatbuf(fbb, required_placement_resources), language, diff --git a/test/actor_test.py b/test/actor_test.py index 230a5db30..4a3a0ad58 100644 --- a/test/actor_test.py +++ b/test/actor_test.py @@ -19,14 +19,17 @@ import ray.test.cluster_utils @pytest.fixture -def ray_start_regular(): +def ray_start_regular(request): + internal_config = { + "initial_reconstruction_timeout_milliseconds": 200, + "num_heartbeats_timeout": 10, + } + internal_config.update(getattr(request, "param", {})) # Start the Ray processes. ray.init( num_cpus=1, - _internal_config=json.dumps({ - "initial_reconstruction_timeout_milliseconds": 200, - "num_heartbeats_timeout": 10, - })) + _internal_config=json.dumps(internal_config), + ) yield None # The code after the yield will run as teardown code. 
ray.shutdown() @@ -51,11 +54,16 @@ def ray_start_cluster(): @pytest.fixture() def two_node_cluster(): + internal_config = json.dumps({ + "initial_reconstruction_timeout_milliseconds": 200, + "num_heartbeats_timeout": 10, + }) cluster = ray.test.cluster_utils.Cluster() for _ in range(2): - cluster.add_node(num_cpus=1) + remote_node = cluster.add_node( + num_cpus=1, _internal_config=internal_config) ray.init(redis_address=cluster.redis_address) - yield cluster + yield cluster, remote_node # The code after the yield will run as teardown code. ray.shutdown() @@ -80,6 +88,69 @@ def head_node_cluster(request): cluster.shutdown() +@pytest.fixture +def ray_checkpointable_actor_cls(request): + checkpoint_dir = "/tmp/ray_temp_checkpoint_dir/" + if not os.path.isdir(checkpoint_dir): + os.mkdir(checkpoint_dir) + + class CheckpointableActor(ray.actor.Checkpointable): + def __init__(self): + self.value = 0 + self.resumed_from_checkpoint = False + self.checkpoint_dir = checkpoint_dir + + def local_plasma(self): + return ray.worker.global_worker.plasma_client.store_socket_name + + def increase(self): + self.value += 1 + return self.value + + def get(self): + return self.value + + def was_resumed_from_checkpoint(self): + return self.resumed_from_checkpoint + + def get_pid(self): + return os.getpid() + + def should_checkpoint(self, checkpoint_context): + # Checkpoint the actor when value is increased to 3. + should_checkpoint = self.value == 3 + return should_checkpoint + + def save_checkpoint(self, actor_id, checkpoint_id): + actor_id, checkpoint_id = actor_id.hex(), checkpoint_id.hex() + # Save checkpoint into a file. + with open(self.checkpoint_dir + actor_id, "a+") as f: + print(checkpoint_id, self.value, file=f) + + def load_checkpoint(self, actor_id, available_checkpoints): + actor_id = actor_id.hex() + filename = self.checkpoint_dir + actor_id + # Load checkpoint from the file. 
+ if not os.path.isfile(filename): + return None + + with open(filename, "r") as f: + lines = f.readlines() + checkpoint_id, value = lines[-1].split(" ") + self.value = int(value) + self.resumed_from_checkpoint = True + checkpoint_id = ray.ActorCheckpointID( + ray.utils.hex_to_binary(checkpoint_id)) + assert any(checkpoint_id == checkpoint.checkpoint_id + for checkpoint in available_checkpoints) + return checkpoint_id + + def checkpoint_expired(self, actor_id, checkpoint_id): + pass + + return CheckpointableActor + + def test_actor_init_error_propagated(ray_start_regular): @ray.remote class Actor(object): @@ -1461,146 +1532,6 @@ def setup_counter_actor(test_checkpoint=False, return actor, ids -@pytest.mark.skip("This test does not work yet.") -@pytest.mark.skipif( - os.environ.get("RAY_USE_NEW_GCS") == "on", - reason="Hanging with new GCS API.") -def test_checkpointing(two_node_cluster): - cluster = two_node_cluster - actor, ids = setup_counter_actor(test_checkpoint=True) - # Wait for the last task to finish running. - ray.get(ids[-1]) - - # Kill the corresponding plasma store to get rid of the cached objects. - cluster.list_all_nodes()[1].kill_plasma_store(wait=True) - - # Check that the actor restored from a checkpoint. - assert ray.get(actor.test_restore.remote()) - # Check that we can submit another call on the actor and get the - # correct counter result. - x = ray.get(actor.inc.remote()) - assert x == 101 - # Check that the number of inc calls since actor initialization is less - # than the counter value, since the actor initialized from a - # checkpoint. 
- num_inc_calls = ray.get(actor.get_num_inc_calls.remote()) - assert num_inc_calls < x - - -@pytest.mark.skip("This test does not work yet.") -@pytest.mark.skipif( - os.environ.get("RAY_USE_NEW_GCS") == "on", - reason="Hanging with new GCS API.") -def test_remote_checkpoint(two_node_cluster): - cluster = two_node_cluster - actor, ids = setup_counter_actor(test_checkpoint=True) - - # Do a remote checkpoint call and wait for it to finish. - ray.get(actor.__ray_checkpoint__.remote()) - - # Kill the corresponding plasma store to get rid of the cached objects. - cluster.list_all_nodes()[1].kill_plasma_store(wait=True) - - # Check that the actor restored from a checkpoint. - assert ray.get(actor.test_restore.remote()) - # Check that the number of inc calls since actor initialization is - # exactly zero, since there could not have been another inc call since - # the remote checkpoint. - num_inc_calls = ray.get(actor.get_num_inc_calls.remote()) - assert num_inc_calls == 0 - # Check that we can submit another call on the actor and get the - # correct counter result. - x = ray.get(actor.inc.remote()) - assert x == 101 - - -@pytest.mark.skip("This test does not work yet.") -@pytest.mark.skipif( - os.environ.get("RAY_USE_NEW_GCS") == "on", - reason="Hanging with new GCS API.") -def test_lost_checkpoint(two_node_cluster): - cluster = two_node_cluster - actor, ids = setup_counter_actor(test_checkpoint=True) - # Wait for the first fraction of tasks to finish running. - ray.get(ids[len(ids) // 10]) - - # Kill the corresponding plasma store to get rid of the cached objects. - cluster.list_all_nodes()[1].kill_plasma_store(wait=True) - - # Check that the actor restored from a checkpoint. - assert ray.get(actor.test_restore.remote()) - # Check that we can submit another call on the actor and get the - # correct counter result. 
- x = ray.get(actor.inc.remote()) - assert x == 101 - # Check that the number of inc calls since actor initialization is less - # than the counter value, since the actor initialized from a - # checkpoint. - num_inc_calls = ray.get(actor.get_num_inc_calls.remote()) - assert num_inc_calls < x - assert 5 < num_inc_calls - - -@pytest.mark.skip("This test does not work yet.") -@pytest.mark.skipif( - os.environ.get("RAY_USE_NEW_GCS") == "on", - reason="Hanging with new GCS API.") -def test_checkpoint_exception(two_node_cluster): - cluster = two_node_cluster - actor, ids = setup_counter_actor(test_checkpoint=True, save_exception=True) - # Wait for the last task to finish running. - ray.get(ids[-1]) - - # Kill the corresponding plasma store to get rid of the cached objects. - cluster.list_all_nodes()[1].kill_plasma_store(wait=True) - - # Check that we can submit another call on the actor and get the - # correct counter result. - x = ray.get(actor.inc.remote()) - assert x == 101 - # Check that the number of inc calls since actor initialization is - # equal to the counter value, since the actor did not initialize from a - # checkpoint. - num_inc_calls = ray.get(actor.get_num_inc_calls.remote()) - assert num_inc_calls == x - # Check that errors were raised when trying to save the checkpoint. - errors = ray.error_info() - assert 0 < len(errors) - for error in errors: - assert error["type"] == ray_constants.CHECKPOINT_PUSH_ERROR - - -@pytest.mark.skip("This test does not work yet.") -@pytest.mark.skipif( - os.environ.get("RAY_USE_NEW_GCS") == "on", - reason="Hanging with new GCS API.") -def test_checkpoint_resume_exception(two_node_cluster): - cluster = two_node_cluster - actor, ids = setup_counter_actor( - test_checkpoint=True, resume_exception=True) - # Wait for the last task to finish running. - ray.get(ids[-1]) - - # Kill the corresponding plasma store to get rid of the cached objects. 
- cluster.list_all_nodes()[1].kill_plasma_store(wait=True) - - # Check that we can submit another call on the actor and get the - # correct counter result. - x = ray.get(actor.inc.remote()) - assert x == 101 - # Check that the number of inc calls since actor initialization is - # equal to the counter value, since the actor did not initialize from a - # checkpoint. - num_inc_calls = ray.get(actor.get_num_inc_calls.remote()) - assert num_inc_calls == x - # Check that an error was raised when trying to resume from the - # checkpoint. - errors = ray.error_info() - assert len(errors) == 1 - for error in errors: - assert error["type"] == ray_constants.CHECKPOINT_PUSH_ERROR - - @pytest.mark.skip("Fork/join consistency not yet implemented.") def test_distributed_handle(two_node_cluster): cluster = two_node_cluster @@ -2380,3 +2311,255 @@ def test_multiple_actor_reconstruction(head_node_cluster): for _, result_id_list in result_ids.items(): results = list(range(1, len(result_id_list) + 1)) assert ray.get(result_id_list) == results + + +def kill_actor(actor): + """A helper function that kills an actor process.""" + pid = ray.get(actor.get_pid.remote()) + os.kill(pid, signal.SIGKILL) + time.sleep(1) + + +def test_checkpointing(ray_start_regular, ray_checkpointable_actor_cls): + """Test actor checkpointing and restoring from a checkpoint.""" + actor = ray.remote( + max_reconstructions=2)(ray_checkpointable_actor_cls).remote() + # Call increase 3 times. + expected = 0 + for _ in range(3): + ray.get(actor.increase.remote()) + expected += 1 + # Assert that the actor wasn't resumed from a checkpoint. + assert ray.get(actor.was_resumed_from_checkpoint.remote()) is False + # Kill actor process. + kill_actor(actor) + # Assert that the actor was resumed from a checkpoint and its value is + # still correct. + assert ray.get(actor.get.remote()) == expected + assert ray.get(actor.was_resumed_from_checkpoint.remote()) is True + + # Submit some more tasks. 
These should get replayed since they happen after + # the checkpoint. + for _ in range(3): + ray.get(actor.increase.remote()) + expected += 1 + # Kill actor again and check that reconstruction still works after the + # actor resuming from a checkpoint. + kill_actor(actor) + assert ray.get(actor.get.remote()) == expected + assert ray.get(actor.was_resumed_from_checkpoint.remote()) is True + + +def test_remote_checkpointing(ray_start_regular, ray_checkpointable_actor_cls): + """Test checkpointing of a remote actor through method invocation.""" + + # Define a class that exposes a method to save checkpoints. + class RemoteCheckpointableActor(ray_checkpointable_actor_cls): + def __init__(self): + super(RemoteCheckpointableActor, self).__init__() + self._should_checkpoint = False + + def checkpoint(self): + self._should_checkpoint = True + + def should_checkpoint(self, checkpoint_context): + should_checkpoint = self._should_checkpoint + self._should_checkpoint = False + return should_checkpoint + + cls = ray.remote(max_reconstructions=2)(RemoteCheckpointableActor) + actor = cls.remote() + # Call increase 3 times. + expected = 0 + for _ in range(3): + ray.get(actor.increase.remote()) + expected += 1 + # Call a checkpoint task. + actor.checkpoint.remote() + # Assert that the actor wasn't resumed from a checkpoint. + assert ray.get(actor.was_resumed_from_checkpoint.remote()) is False + # Kill actor process. + kill_actor(actor) + # Assert that the actor was resumed from a checkpoint and its value is + # still correct. + assert ray.get(actor.get.remote()) == expected + assert ray.get(actor.was_resumed_from_checkpoint.remote()) is True + + # Submit some more tasks. These should get replayed since they happen after + # the checkpoint. + for _ in range(3): + ray.get(actor.increase.remote()) + expected += 1 + # Kill actor again and check that reconstruction still works after the + # actor resuming from a checkpoint. 
+ kill_actor(actor) + assert ray.get(actor.get.remote()) == expected + assert ray.get(actor.was_resumed_from_checkpoint.remote()) is True + + +def test_checkpointing_on_node_failure(two_node_cluster, + ray_checkpointable_actor_cls): + """Test actor checkpointing on a remote node.""" + # Place the actor on the remote node. + cluster, remote_node = two_node_cluster + actor_cls = ray.remote(max_reconstructions=1)(ray_checkpointable_actor_cls) + actor = actor_cls.remote() + while (ray.get(actor.local_plasma.remote()) != + remote_node.plasma_store_socket_name): + actor = actor_cls.remote() + + # Call increase several times. + expected = 0 + for _ in range(6): + ray.get(actor.increase.remote()) + expected += 1 + # Assert that the actor wasn't resumed from a checkpoint. + assert ray.get(actor.was_resumed_from_checkpoint.remote()) is False + # Kill actor process. + cluster.remove_node(remote_node) + # Assert that the actor was resumed from a checkpoint and its value is + # still correct. + assert ray.get(actor.get.remote()) == expected + assert ray.get(actor.was_resumed_from_checkpoint.remote()) is True + + +def test_checkpointing_save_exception(ray_start_regular, + ray_checkpointable_actor_cls): + """Test actor can still be recovered if checkpoints fail to complete.""" + + @ray.remote(max_reconstructions=2) + class RemoteCheckpointableActor(ray_checkpointable_actor_cls): + def save_checkpoint(self, actor_id, checkpoint_context): + raise Exception("Error during save") + + actor = RemoteCheckpointableActor.remote() + # Call increase 3 times. + expected = 0 + for _ in range(3): + ray.get(actor.increase.remote()) + expected += 1 + # Assert that the actor wasn't resumed from a checkpoint. + assert ray.get(actor.was_resumed_from_checkpoint.remote()) is False + # Kill actor process. + kill_actor(actor) + # Assert that the actor still wasn't resumed from a checkpoint and its + # value is still correct. 
+ assert ray.get(actor.get.remote()) == expected + assert ray.get(actor.was_resumed_from_checkpoint.remote()) is False + + # Submit some more tasks. These should get replayed since they happen after + # the checkpoint. + for _ in range(3): + ray.get(actor.increase.remote()) + expected += 1 + # Kill actor again, and check that reconstruction still works and the actor + # wasn't resumed from a checkpoint. + kill_actor(actor) + assert ray.get(actor.get.remote()) == expected + assert ray.get(actor.was_resumed_from_checkpoint.remote()) is False + + # Check that checkpointing errors were pushed to the driver. + errors = ray.error_info() + assert len(errors) > 0 + for error in errors: + # An error for the actor process dying may also get pushed. + assert (error["type"] == ray_constants.CHECKPOINT_PUSH_ERROR + or error["type"] == ray_constants.WORKER_DIED_PUSH_ERROR) + + +def test_checkpointing_load_exception(ray_start_regular, + ray_checkpointable_actor_cls): + """Test actor can still be recovered if checkpoints fail to load.""" + + @ray.remote(max_reconstructions=2) + class RemoteCheckpointableActor(ray_checkpointable_actor_cls): + def load_checkpoint(self, actor_id, checkpoints): + raise Exception("Error during load") + + actor = RemoteCheckpointableActor.remote() + # Call increase 3 times. + expected = 0 + for _ in range(3): + ray.get(actor.increase.remote()) + expected += 1 + # Assert that the actor wasn't resumed from a checkpoint. + assert ray.get(actor.was_resumed_from_checkpoint.remote()) is False + # Kill actor process. + kill_actor(actor) + # Assert that the actor still wasn't resumed from a checkpoint and its + # value is still correct. + assert ray.get(actor.get.remote()) == expected + assert ray.get(actor.was_resumed_from_checkpoint.remote()) is False + + # Submit some more tasks. These should get replayed since they happen after + # the checkpoint. 
+ for _ in range(3): + ray.get(actor.increase.remote()) + expected += 1 + # Kill actor again, and check that reconstruction still works and the actor + # wasn't resumed from a checkpoint. + kill_actor(actor) + assert ray.get(actor.get.remote()) == expected + assert ray.get(actor.was_resumed_from_checkpoint.remote()) is False + + # Check that checkpointing errors were pushed to the driver. + errors = ray.error_info() + assert len(errors) > 0 + for error in errors: + # An error for the actor process dying may also get pushed. + assert (error["type"] == ray_constants.CHECKPOINT_PUSH_ERROR + or error["type"] == ray_constants.WORKER_DIED_PUSH_ERROR) + + +@pytest.mark.parametrize( + "ray_start_regular", + # This overwrite currently isn't effective, + # see https://github.com/ray-project/ray/issues/3926. + [{ + "num_actor_checkpoints_to_keep": 20 + }], + indirect=True, +) +def test_deleting_actor_checkpoint(ray_start_regular): + """Test deleting old actor checkpoints.""" + + @ray.remote + class CheckpointableActor(ray.actor.Checkpointable): + def __init__(self): + self.checkpoint_ids = [] + + def get_checkpoint_ids(self): + return self.checkpoint_ids + + def should_checkpoint(self, checkpoint_context): + # Save checkpoints after every task + return True + + def save_checkpoint(self, actor_id, checkpoint_id): + self.checkpoint_ids.append(checkpoint_id) + pass + + def load_checkpoint(self, actor_id, available_checkpoints): + pass + + def checkpoint_expired(self, actor_id, checkpoint_id): + assert checkpoint_id == self.checkpoint_ids[0] + del self.checkpoint_ids[0] + + actor = CheckpointableActor.remote() + for i in range(19): + assert len(ray.get(actor.get_checkpoint_ids.remote())) == i + 1 + for _ in range(20): + assert len(ray.get(actor.get_checkpoint_ids.remote())) == 20 + + +def test_bad_checkpointable_actor_class(): + """Test error raised if an actor class doesn't implement all abstract + methods in the Checkpointable interface.""" + + with pytest.raises(TypeError): + 
+ @ray.remote + class BadCheckpointableActor(ray.actor.Checkpointable): + def should_checkpoint(self, checkpoint_context): + return True