Implement actor checkpointing (#3839)

* Implement Actor checkpointing

* docs

* fix

* fix

* fix

* move restore-from-checkpoint to HandleActorStateTransition

* Revert "move restore-from-checkpoint to HandleActorStateTransition"

This reverts commit 9aa4447c1e3e321f42a1d895d72f17098b72de12.

* resubmit waiting tasks when actor frontier restored

* add doc about num_actor_checkpoints_to_keep=1

* add num_actor_checkpoints_to_keep to Cython

* add checkpoint_expired api

* check if actor class is abstract

* change checkpoint_ids to long string

* implement java

* Refactor to delay actor creation publish until checkpoint is resumed

* debug, lint

* Erase from checkpoints to restore if task fails

* fix lint

* update comments

* avoid duplicated actor notification log

* fix unintended change

* add actor_id to checkpoint_expired

* small java updates

* make checkpoint info per actor

* lint

* Remove logging

* Remove old actor checkpointing Python code, move new checkpointing code to FunctionActionManager

* Replace old actor checkpointing tests

* Fix test and lint

* address comments

* consolidate kill_actor

* Remove __ray_checkpoint__

* fix non-ascii char

* Loosen test checks

* fix java

* fix sphinx-build
This commit is contained in:
Hao Chen
2019-02-13 19:39:02 +08:00
committed by GitHub
parent 57dcd3033e
commit f31a79f3f7
41 changed files with 1708 additions and 490 deletions
+24 -5
View File
@@ -49,9 +49,20 @@ except ImportError as e:
modin_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "modin")
sys.path.append(modin_path)
from ray._raylet import (UniqueID, ObjectID, DriverID, ClientID, ActorID,
ActorHandleID, FunctionID, ActorClassID, TaskID,
_ID_TYPES, Config as _Config) # noqa: E402
from ray._raylet import (
ActorCheckpointID,
ActorClassID,
ActorHandleID,
ActorID,
ClientID,
Config as _Config,
DriverID,
FunctionID,
ObjectID,
TaskID,
UniqueID,
_ID_TYPES,
) # noqa: E402
_config = _Config()
@@ -82,8 +93,16 @@ __all__ = [
]
__all__ += [
"UniqueID", "ObjectID", "DriverID", "ClientID", "ActorID", "ActorHandleID",
"FunctionID", "ActorClassID", "TaskID"
"ActorCheckpointID",
"ActorClassID",
"ActorHandleID",
"ActorID",
"ClientID",
"DriverID",
"FunctionID",
"ObjectID",
"TaskID",
"UniqueID",
]
import ctypes # noqa: E402
+32 -5
View File
@@ -19,12 +19,31 @@ include "includes/ray_config.pxi"
include "includes/task.pxi"
from ray.includes.common cimport (
CUniqueID, CTaskID, CObjectID, CFunctionID, CActorClassID, CActorID,
CActorHandleID, CWorkerID, CDriverID, CConfigID, CClientID,
CLanguage, CRayStatus, LANGUAGE_CPP, LANGUAGE_JAVA, LANGUAGE_PYTHON)
CActorCheckpointID,
CActorClassID,
CActorHandleID,
CActorID,
CClientID,
CConfigID,
CDriverID,
CFunctionID,
CLanguage,
CObjectID,
CRayStatus,
CTaskID,
CUniqueID,
CWorkerID,
LANGUAGE_CPP,
LANGUAGE_JAVA,
LANGUAGE_PYTHON,
)
from ray.includes.libraylet cimport (
CRayletClient, GCSProfileTableDataT, GCSProfileEventT,
ResourceMappingType, WaitResultPair)
CRayletClient,
GCSProfileEventT,
GCSProfileTableDataT,
ResourceMappingType,
WaitResultPair,
)
from ray.includes.task cimport CTaskSpecification
from ray.includes.ray_config cimport RayConfig
from ray.utils import decode
@@ -303,6 +322,14 @@ cdef class RayletClient:
cdef c_vector[CObjectID] free_ids = ObjectIDsToVector(object_ids)
check_status(self.client.get().FreeObjects(free_ids, local_only))
def prepare_actor_checkpoint(self, ActorID actor_id):
cdef CActorCheckpointID checkpoint_id
check_status(self.client.get().PrepareActorCheckpoint(actor_id.data, checkpoint_id))
return ObjectID.from_native(checkpoint_id);
def notify_actor_resumed_from_checkpoint(self, ActorID actor_id, ActorCheckpointID checkpoint_id):
check_status(self.client.get().NotifyActorResumedFromCheckpoint(actor_id.data, checkpoint_id.data))
@property
def language(self):
return Language.from_native(self.client.get().GetLanguage())
+144 -167
View File
@@ -6,11 +6,13 @@ import copy
import hashlib
import inspect
import logging
import six
import sys
import threading
import traceback
import ray.cloudpickle as pickle
from abc import ABCMeta, abstractmethod
from collections import namedtuple
from ray.function_manager import FunctionDescriptor
import ray.ray_constants as ray_constants
import ray.signature as signature
@@ -75,90 +77,6 @@ def compute_actor_handle_id_non_forked(actor_handle_id, current_task_id):
return ActorHandleID(handle_id)
def set_actor_checkpoint(worker, actor_id, checkpoint_index, checkpoint,
frontier):
"""Set the most recent checkpoint associated with a given actor ID.
Args:
worker: The worker to use to get the checkpoint.
actor_id: The actor ID of the actor to get the checkpoint for.
checkpoint_index: The number of tasks included in the checkpoint.
checkpoint: The state object to save.
frontier: The task frontier at the time of the checkpoint.
"""
assert isinstance(actor_id, ActorID)
actor_key = b"Actor:" + actor_id.binary()
worker.redis_client.hmset(
actor_key, {
"checkpoint_index": checkpoint_index,
"checkpoint": checkpoint,
"frontier": frontier,
})
def save_and_log_checkpoint(worker, actor):
"""Save a checkpoint on the actor and log any errors.
Args:
worker: The worker to use to log errors.
actor: The actor to checkpoint.
"""
try:
actor.__ray_checkpoint__()
except Exception:
traceback_str = ray.utils.format_error_message(traceback.format_exc())
# Log the error message.
ray.utils.push_error_to_driver(
worker,
ray_constants.CHECKPOINT_PUSH_ERROR,
traceback_str,
driver_id=worker.task_driver_id)
def restore_and_log_checkpoint(worker, actor):
"""Restore an actor from a checkpoint and log any errors.
Args:
worker: The worker to use to log errors.
actor: The actor to restore.
"""
checkpoint_resumed = False
try:
checkpoint_resumed = actor.__ray_checkpoint_restore__()
except Exception:
traceback_str = ray.utils.format_error_message(traceback.format_exc())
# Log the error message.
ray.utils.push_error_to_driver(
worker,
ray_constants.CHECKPOINT_PUSH_ERROR,
traceback_str,
driver_id=worker.task_driver_id)
return checkpoint_resumed
def get_actor_checkpoint(worker, actor_id):
"""Get the most recent checkpoint associated with a given actor ID.
Args:
worker: The worker to use to get the checkpoint.
actor_id: The actor ID of the actor to get the checkpoint for.
Returns:
If a checkpoint exists, this returns a tuple of the number of tasks
included in the checkpoint, the saved checkpoint state, and the
task frontier at the time of the checkpoint. If no checkpoint
exists, all objects are set to None. The checkpoint index is the .
executed on the actor before the checkpoint was made.
"""
assert isinstance(actor_id, ActorID)
actor_key = b"Actor:" + actor_id.binary()
checkpoint_index, checkpoint, frontier = worker.redis_client.hmget(
actor_key, ["checkpoint_index", "checkpoint", "frontier"])
if checkpoint_index is not None:
checkpoint_index = int(checkpoint_index)
return checkpoint_index, checkpoint, frontier
def method(*args, **kwargs):
"""Annotate an actor method.
@@ -234,7 +152,6 @@ class ActorClass(object):
additional methods added like __ray_terminate__).
_class_id: The ID of this actor class.
_class_name: The name of this class.
_checkpoint_interval: The interval at which to checkpoint actor state.
_num_cpus: The default number of CPUs required by the actor creation
task.
_num_gpus: The default number of GPUs required by the actor creation
@@ -250,13 +167,11 @@ class ActorClass(object):
each actor method.
"""
def __init__(self, modified_class, class_id, checkpoint_interval,
max_reconstructions, num_cpus, num_gpus, resources,
actor_method_cpus):
def __init__(self, modified_class, class_id, max_reconstructions, num_cpus,
num_gpus, resources, actor_method_cpus):
self._modified_class = modified_class
self._class_id = class_id
self._class_name = modified_class.__name__
self._checkpoint_interval = checkpoint_interval
self._max_reconstructions = max_reconstructions
self._num_cpus = num_cpus
self._num_gpus = num_gpus
@@ -383,8 +298,7 @@ class ActorClass(object):
# Export the actor.
if not self._exported:
worker.function_actor_manager.export_actor_class(
self._modified_class, self._actor_method_names,
self._checkpoint_interval)
self._modified_class, self._actor_method_names)
self._exported = True
resources = ray.utils.resources_from_resource_arguments(
@@ -564,8 +478,6 @@ class ActorHandle(object):
return getattr(worker.actors[self._ray_actor_id],
method_name)(*copy.deepcopy(args))
is_actor_checkpoint_method = (method_name == "__ray_checkpoint__")
function_descriptor = FunctionDescriptor(
self._ray_module_name, method_name, self._ray_class_name)
with self._ray_actor_lock:
@@ -575,7 +487,6 @@ class ActorHandle(object):
actor_id=self._ray_actor_id,
actor_handle_id=self._ray_actor_handle_id,
actor_counter=self._ray_actor_counter,
is_actor_checkpoint_method=is_actor_checkpoint_method,
actor_creation_dummy_object_id=(
self._ray_actor_creation_dummy_object_id),
execution_dependencies=[self._ray_actor_cursor],
@@ -770,7 +681,7 @@ class ActorHandle(object):
def make_actor(cls, num_cpus, num_gpus, resources, actor_method_cpus,
checkpoint_interval, max_reconstructions):
max_reconstructions):
# Give an error if cls is an old-style class.
if not issubclass(cls, object):
raise TypeError(
@@ -778,13 +689,14 @@ def make_actor(cls, num_cpus, num_gpus, resources, actor_method_cpus,
"classes. In Python 2, you must declare the class with "
"'class ClassName(object):' instead of 'class ClassName:'.")
if checkpoint_interval is None:
checkpoint_interval = -1
if issubclass(cls, Checkpointable) and inspect.isabstract(cls):
raise TypeError(
"A checkpointable actor class should implement all abstract "
"methods in the `Checkpointable` interface.")
if max_reconstructions is None:
max_reconstructions = 0
if checkpoint_interval == 0:
raise Exception("checkpoint_interval must be greater than 0.")
if not (ray_constants.NO_RECONSTRUCTION <= max_reconstructions <=
ray_constants.INFINITE_RECONSTRUCTION):
raise Exception("max_reconstructions must be in range [%d, %d]." %
@@ -804,26 +716,6 @@ def make_actor(cls, num_cpus, num_gpus, resources, actor_method_cpus,
sys.exit(0)
assert False, "This process should have terminated."
def __ray_save_checkpoint__(self):
if hasattr(self, "__ray_save__"):
object_to_serialize = self.__ray_save__()
else:
object_to_serialize = self
return pickle.dumps(object_to_serialize)
@classmethod
def __ray_restore_from_checkpoint__(cls, pickled_checkpoint):
checkpoint = pickle.loads(pickled_checkpoint)
if hasattr(cls, "__ray_restore__"):
actor_object = cls.__new__(cls)
actor_object.__ray_restore__(checkpoint)
else:
# TODO(rkn): It's possible that this will cause problems. When
# you unpickle the same object twice, the two objects will not
# have the same class.
actor_object = checkpoint
return actor_object
def __ray_checkpoint__(self):
"""Save a checkpoint.
@@ -832,58 +724,143 @@ def make_actor(cls, num_cpus, num_gpus, resources, actor_method_cpus,
(number of tasks executed so far).
"""
worker = ray.worker.global_worker
checkpoint_index = worker.actor_task_counter
# Get the state to save.
checkpoint = self.__ray_save_checkpoint__()
# Get the current task frontier, per actor handle.
# NOTE(swang): This only includes actor handles that the local
# scheduler has seen. Handle IDs for which no task has yet reached
# the local scheduler will not be included, and may not be runnable
# on checkpoint resumption.
actor_id = worker.actor_id
frontier = worker.raylet_client.get_actor_frontier(actor_id)
# Save the checkpoint in Redis. TODO(rkn): Checkpoints
# should not be stored in Redis. Fix this.
set_actor_checkpoint(worker, worker.actor_id, checkpoint_index,
checkpoint, frontier)
def __ray_checkpoint_restore__(self):
"""Restore a checkpoint.
This task looks for a saved checkpoint and if found, restores the
state of the actor, the task frontier in the local scheduler, and
the checkpoint index (number of tasks executed so far).
Returns:
A bool indicating whether a checkpoint was resumed.
"""
worker = ray.worker.global_worker
# Get the most recent checkpoint stored, if any.
checkpoint_index, checkpoint, frontier = get_actor_checkpoint(
worker, worker.actor_id)
# Try to resume from the checkpoint.
checkpoint_resumed = False
if checkpoint_index is not None:
# Load the actor state from the checkpoint.
worker.actors[worker.actor_id] = (
worker.actor_class.__ray_restore_from_checkpoint__(
checkpoint))
# Set the number of tasks executed so far.
worker.actor_task_counter = checkpoint_index
# Set the actor frontier in the local scheduler.
worker.raylet_client.set_actor_frontier(frontier)
checkpoint_resumed = True
return checkpoint_resumed
if not isinstance(self, ray.actor.Checkpointable):
raise Exception(
"__ray_checkpoint__.remote() may only be called on actors "
"that implement ray.actor.Checkpointable")
return worker._save_actor_checkpoint()
Class.__module__ = cls.__module__
Class.__name__ = cls.__name__
class_id = ActorClassID(_random_string())
return ActorClass(Class, class_id, checkpoint_interval,
max_reconstructions, num_cpus, num_gpus, resources,
actor_method_cpus)
return ActorClass(Class, class_id, max_reconstructions, num_cpus, num_gpus,
resources, actor_method_cpus)
ray.worker.global_worker.make_actor = make_actor
CheckpointContext = namedtuple(
'CheckpointContext',
[
# Actor's ID.
'actor_id',
# Number of tasks executed since last checkpoint.
'num_tasks_since_last_checkpoint',
# Time elapsed since last checkpoint, in milliseconds.
'time_elapsed_ms_since_last_checkpoint',
],
)
"""A namedtuple that contains information about actor's last checkpoint."""
Checkpoint = namedtuple(
'Checkpoint',
[
# ID of this checkpoint.
'checkpoint_id',
# The timestamp at which this checkpoint was saved,
# represented as milliseconds elapsed since Unix epoch.
'timestamp',
],
)
"""A namedtuple that represents a checkpoint."""
class Checkpointable(six.with_metaclass(ABCMeta, object)):
"""An interface that indicates an actor can be checkpointed."""
@abstractmethod
def should_checkpoint(self, checkpoint_context):
"""Whether this actor needs to be checkpointed.
This method will be called after every task. You should implement this
callback to decide whether this actor needs to be checkpointed at this
time, based on the checkpoint context, or any other factors.
Args:
checkpoint_context: A namedtuple that contains info about last
checkpoint.
Returns:
A boolean value that indicates whether this actor needs to be
checkpointed.
"""
pass
@abstractmethod
def save_checkpoint(self, actor_id, checkpoint_id):
"""Save a checkpoint to persistent storage.
If `should_checkpoint` returns true, this method will be called. You
should implement this callback to save actor's checkpoint and the given
checkpoint id to persistent storage.
Args:
actor_id: Actor's ID.
checkpoint_id: ID of this checkpoint. You should save it together
with actor's checkpoint data. And it will be used by the
`load_checkpoint` method.
Returns:
None.
"""
pass
@abstractmethod
def load_checkpoint(self, actor_id, available_checkpoints):
"""Load actor's previous checkpoint, and restore actor's state.
This method will be called when an actor is reconstructed, after
actor's constructor.
If the actor needs to restore from previous checkpoint, this function
should restore actor's state and return the checkpoint ID. Otherwise,
it should do nothing and return None.
Note, this method must return one of the checkpoint IDs in the
`available_checkpoints` list, or None. Otherwise, an exception will be
raised.
Args:
actor_id: Actor's ID.
available_checkpoints: A list of `Checkpoint` namedtuples that
contains all available checkpoint IDs and their timestamps,
sorted by timestamp in descending order.
Returns:
The ID of the checkpoint from which the actor was resumed, or None
if the actor should restart from the beginning.
"""
pass
@abstractmethod
def checkpoint_expired(self, actor_id, checkpoint_id):
"""Delete an expired checkpoint.
This method will be called when an checkpoint is expired. You should
implement this method to delete your application checkpoint data.
Note, the maximum number of checkpoints kept in the backend can be
configured at `RayConfig.num_actor_checkpoints_to_keep`.
Args:
actor_id: ID of the actor.
checkpoint_id: ID of the checkpoint that has expired.
Returns:
None.
"""
pass
def get_checkpoints_for_actor(actor_id):
"""Get the available checkpoints for the given actor ID, return a list
sorted by checkpoint timestamp in descending order.
"""
checkpoint_info = ray.worker.global_state.actor_checkpoint_info(actor_id)
if checkpoint_info is None:
return []
checkpoints = [
Checkpoint(checkpoint_id, timestamp) for checkpoint_id, timestamp in
zip(checkpoint_info['CheckpointIds'], checkpoint_info['Timestamps'])
]
return sorted(
checkpoints,
key=lambda checkpoint: checkpoint.timestamp,
reverse=True,
)
+42 -2
View File
@@ -11,7 +11,8 @@ import time
import ray
from ray.function_manager import FunctionDescriptor
import ray.gcs_utils
import ray.ray_constants as ray_constants
from ray.ray_constants import ID_SIZE
from ray.utils import (decode, binary_to_object_id, binary_to_hex,
hex_to_binary)
@@ -720,7 +721,7 @@ class GlobalState(object):
for key in actor_keys:
info = self.redis_client.hgetall(key)
actor_id = key[len("Actor:"):]
assert len(actor_id) == ray_constants.ID_SIZE
assert len(actor_id) == ID_SIZE
actor_info[binary_to_hex(actor_id)] = {
"class_id": binary_to_hex(info[b"class_id"]),
"driver_id": binary_to_hex(info[b"driver_id"]),
@@ -906,3 +907,42 @@ class GlobalState(object):
binary_to_hex(job_id): self._error_messages(ray.DriverID(job_id))
for job_id in job_ids
}
def actor_checkpoint_info(self, actor_id):
"""Get checkpoint info for the given actor id.
Args:
actor_id: Actor's ID.
Returns:
A dictionary with information about the actor's checkpoint IDs and
their timestamps.
"""
self._check_connected()
message = self._execute_command(
actor_id,
"RAY.TABLE_LOOKUP",
ray.gcs_utils.TablePrefix.ACTOR_CHECKPOINT_ID,
"",
actor_id.binary(),
)
if message is None:
return None
gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
message, 0)
entry = (
ray.gcs_utils.ActorCheckpointIdData.GetRootAsActorCheckpointIdData(
gcs_entry.Entries(0), 0))
checkpoint_ids_str = entry.CheckpointIds()
num_checkpoints = len(checkpoint_ids_str) // ID_SIZE
assert len(checkpoint_ids_str) % ID_SIZE == 0
checkpoint_ids = [
ray.ActorCheckpointID(
checkpoint_ids_str[(i * ID_SIZE):((i + 1) * ID_SIZE)])
for i in range(num_checkpoints)
]
return {
"ActorID": ray.utils.binary_to_hex(entry.ActorId()),
"CheckpointIds": checkpoint_ids,
"Timestamps": [
entry.Timestamps(i) for i in range(num_checkpoints)
],
}
+104 -40
View File
@@ -510,8 +510,7 @@ class FunctionActorManager(object):
self._worker.redis_client.hmset(key, actor_class_info)
self._worker.redis_client.rpush("Exports", key)
def export_actor_class(self, Class, actor_method_names,
checkpoint_interval):
def export_actor_class(self, Class, actor_method_names):
function_descriptor = FunctionDescriptor.from_class(Class)
# `task_driver_id` shouldn't be NIL, unless:
# 1) This worker isn't an actor;
@@ -528,7 +527,6 @@ class FunctionActorManager(object):
"class_name": Class.__name__,
"module": Class.__module__,
"class": pickle.dumps(Class),
"checkpoint_interval": checkpoint_interval,
"driver_id": driver_id.binary(),
"actor_method_names": json.dumps(list(actor_method_names))
}
@@ -576,17 +574,16 @@ class FunctionActorManager(object):
actor_class_key: The key in Redis to use to fetch the actor.
"""
actor_id = self._worker.actor_id
(driver_id_str, class_name, module, pickled_class, checkpoint_interval,
(driver_id_str, class_name, module, pickled_class,
actor_method_names) = self._worker.redis_client.hmget(
actor_class_key, [
"driver_id", "class_name", "module", "class",
"checkpoint_interval", "actor_method_names"
"actor_method_names"
])
class_name = decode(class_name)
module = decode(module)
driver_id = ray.DriverID(driver_id_str)
checkpoint_interval = int(checkpoint_interval)
actor_method_names = json.loads(decode(actor_method_names))
# In Python 2, json loads strings as unicode, so convert them back to
@@ -605,7 +602,6 @@ class FunctionActorManager(object):
pass
self._worker.actors[actor_id] = TemporaryActor()
self._worker.actor_checkpoint_interval = checkpoint_interval
def temporary_actor_method(*xs):
raise Exception(
@@ -694,48 +690,116 @@ class FunctionActorManager(object):
# to execute.
self._worker.actor_task_counter += 1
# If this is the first task to execute on the actor, try to resume
# from a checkpoint.
# Current __init__ will be called by default. So the real function
# call will start from 2.
if actor_imported and self._worker.actor_task_counter == 2:
checkpoint_resumed = ray.actor.restore_and_log_checkpoint(
self._worker, actor)
if checkpoint_resumed:
# NOTE(swang): Since we did not actually execute the
# __init__ method, this will put None as the return value.
# If the __init__ method is supposed to return multiple
# values, an exception will be logged.
return
# Determine whether we should checkpoint the actor.
checkpointing_on = (actor_imported
and self._worker.actor_checkpoint_interval > 0)
# We should checkpoint the actor if user checkpointing is on, we've
# executed checkpoint_interval tasks since the last checkpoint, and
# the method we're about to execute is not a checkpoint.
save_checkpoint = (checkpointing_on
and (self._worker.actor_task_counter %
self._worker.actor_checkpoint_interval == 0
and method_name != "__ray_checkpoint__"))
# Execute the assigned method and save a checkpoint if necessary.
try:
if is_class_method(method):
method_returns = method(*args)
else:
method_returns = method(actor, *args)
except Exception:
except Exception as e:
# Save the checkpoint before allowing the method exception
# to be thrown.
if save_checkpoint:
ray.actor.save_and_log_checkpoint(self._worker, actor)
raise
if isinstance(actor, ray.actor.Checkpointable):
self._save_and_log_checkpoint(actor)
raise e
else:
# Save the checkpoint before returning the method's return
# values.
if save_checkpoint:
ray.actor.save_and_log_checkpoint(self._worker, actor)
# Handle any checkpointing operations before storing the
# method's return values.
# NOTE(swang): If method_returns is a pointer to the actor's
# state and the checkpointing operations can modify the return
# values if they mutate the actor's state. Is this okay?
if isinstance(actor, ray.actor.Checkpointable):
# If this is the first task to execute on the actor, try to
# resume from a checkpoint.
if self._worker.actor_task_counter == 1:
if actor_imported:
self._restore_and_log_checkpoint(actor)
else:
# Save the checkpoint before returning the method's
# return values.
self._save_and_log_checkpoint(actor)
return method_returns
return actor_method_executor
def _save_and_log_checkpoint(self, actor):
"""Save an actor checkpoint if necessary and log any errors.
Args:
actor: The actor to checkpoint.
Returns:
The result of the actor's user-defined `save_checkpoint` method.
"""
actor_id = self._worker.actor_id
checkpoint_info = self._worker.actor_checkpoint_info[actor_id]
checkpoint_info.num_tasks_since_last_checkpoint += 1
now = int(1000 * time.time())
checkpoint_context = ray.actor.CheckpointContext(
actor_id, checkpoint_info.num_tasks_since_last_checkpoint,
now - checkpoint_info.last_checkpoint_timestamp)
# If we should take a checkpoint, notify raylet to prepare a checkpoint
# and then call `save_checkpoint`.
if actor.should_checkpoint(checkpoint_context):
try:
now = int(1000 * time.time())
checkpoint_id = (self._worker.raylet_client.
prepare_actor_checkpoint(actor_id))
checkpoint_info.checkpoint_ids.append(checkpoint_id)
actor.save_checkpoint(actor_id, checkpoint_id)
if (len(checkpoint_info.checkpoint_ids) >
ray._config.num_actor_checkpoints_to_keep()):
actor.checkpoint_expired(
actor_id,
checkpoint_info.checkpoint_ids.pop(0),
)
checkpoint_info.num_tasks_since_last_checkpoint = 0
checkpoint_info.last_checkpoint_timestamp = now
except Exception:
# Checkpoint save or reload failed. Notify the driver.
traceback_str = ray.utils.format_error_message(
traceback.format_exc())
ray.utils.push_error_to_driver(
self._worker,
ray_constants.CHECKPOINT_PUSH_ERROR,
traceback_str,
driver_id=self._worker.task_driver_id)
def _restore_and_log_checkpoint(self, actor):
"""Restore an actor from a checkpoint if available and log any errors.
This should only be called on workers that have just executed an actor
creation task.
Args:
actor: The actor to restore from a checkpoint.
"""
actor_id = self._worker.actor_id
try:
checkpoints = ray.actor.get_checkpoints_for_actor(actor_id)
if len(checkpoints) > 0:
# If we found previously saved checkpoints for this actor,
# call the `load_checkpoint` callback.
checkpoint_id = actor.load_checkpoint(actor_id, checkpoints)
if checkpoint_id is not None:
# Check that the returned checkpoint id is in the
# `available_checkpoints` list.
msg = (
"`load_checkpoint` must return a checkpoint id that " +
"exists in the `available_checkpoints` list, or eone.")
assert any(checkpoint_id == checkpoint.checkpoint_id
for checkpoint in checkpoints), msg
# Notify raylet that this actor has been resumed from
# a checkpoint.
(self._worker.raylet_client.
notify_actor_resumed_from_checkpoint(
actor_id, checkpoint_id))
except Exception:
# Checkpoint save or reload failed. Notify the driver.
traceback_str = ray.utils.format_error_message(
traceback.format_exc())
ray.utils.push_error_to_driver(
self._worker,
ray_constants.CHECKPOINT_PUSH_ERROR,
traceback_str,
driver_id=self._worker.task_driver_id)
+15 -4
View File
@@ -5,6 +5,7 @@ from __future__ import print_function
import flatbuffers
import ray.core.generated.ErrorTableData
from ray.core.generated.ActorCheckpointIdData import ActorCheckpointIdData
from ray.core.generated.ClientTableData import ClientTableData
from ray.core.generated.DriverTableData import DriverTableData
from ray.core.generated.ErrorTableData import ErrorTableData
@@ -20,10 +21,20 @@ from ray.core.generated.TablePubsub import TablePubsub
from ray.core.generated.ray.protocol.Task import Task
__all__ = [
"GcsTableEntry", "ClientTableData", "ErrorTableData", "HeartbeatTableData",
"HeartbeatBatchTableData", "DriverTableData", "ProfileTableData",
"ObjectTableData", "Task", "TablePrefix", "TablePubsub", "Language",
"construct_error_message"
"ActorCheckpointIdData",
"ClientTableData",
"DriverTableData",
"ErrorTableData",
"GcsTableEntry",
"HeartbeatBatchTableData",
"HeartbeatTableData",
"Language",
"ObjectTableData",
"ProfileTableData",
"TablePrefix",
"TablePubsub",
"Task",
"construct_error_message",
]
FUNCTION_PREFIX = "RemoteFunction:"
+13 -4
View File
@@ -6,10 +6,19 @@ from libcpp.unordered_map cimport unordered_map
from libcpp.vector cimport vector as c_vector
from ray.includes.unique_ids cimport (
CUniqueID, TaskID as CTaskID, ObjectID as CObjectID,
FunctionID as CFunctionID, ActorClassID as CActorClassID, ActorID as CActorID,
ActorHandleID as CActorHandleID, WorkerID as CWorkerID,
DriverID as CDriverID, ConfigID as CConfigID, ClientID as CClientID)
ActorCheckpointID as CActorCheckpointID,
ActorClassID as CActorClassID,
ActorHandleID as CActorHandleID,
ActorID as CActorID,
CUniqueID,
ClientID as CClientID,
ConfigID as CConfigID,
DriverID as CDriverID,
FunctionID as CFunctionID,
ObjectID as CObjectID,
TaskID as CTaskID,
WorkerID as CWorkerID,
)
cdef extern from "ray/status.h" namespace "ray" nogil:
+19 -3
View File
@@ -8,9 +8,21 @@ from libcpp.vector cimport vector as c_vector
from ray.includes.common cimport (
CUniqueID, CTaskID, CObjectID, CFunctionID, CActorClassID, CActorID,
CActorHandleID, CWorkerID, CDriverID, CConfigID, CClientID,
CLanguage, CRayStatus)
CActorCheckpointID,
CActorClassID,
CActorHandleID,
CActorID,
CClientID,
CConfigID,
CDriverID,
CFunctionID,
CLanguage,
CObjectID,
CRayStatus,
CTaskID,
CUniqueID,
CWorkerID,
)
from ray.includes.task cimport CTaskSpecification
@@ -57,6 +69,10 @@ cdef extern from "ray/raylet/raylet_client.h" nogil:
CRayStatus PushProfileEvents(const GCSProfileTableDataT &profile_events)
CRayStatus FreeObjects(const c_vector[CObjectID] &object_ids,
c_bool local_only)
CRayStatus PrepareActorCheckpoint(const CActorID &actor_id,
CActorCheckpointID &checkpoint_id)
CRayStatus NotifyActorResumedFromCheckpoint(
const CActorID &actor_id, const CActorCheckpointID &checkpoint_id)
CLanguage GetLanguage() const
CClientID GetClientID() const
CDriverID GetDriverID() const
+3 -1
View File
@@ -1,4 +1,4 @@
from libc.stdint cimport int64_t, uint64_t
from libc.stdint cimport int64_t, uint64_t, uint32_t
from libcpp.string cimport string as c_string
from libcpp.unordered_map cimport unordered_map
@@ -80,4 +80,6 @@ cdef extern from "ray/ray_config.h" nogil:
int64_t max_task_lease_timeout_ms() const
uint32_t num_actor_checkpoints_to_keep() const
void initialize(const unordered_map[c_string, int] &config_map)
+4
View File
@@ -144,3 +144,7 @@ cdef class Config:
@staticmethod
def max_task_lease_timeout_ms():
return RayConfig.instance().max_task_lease_timeout_ms()
@staticmethod
def num_actor_checkpoints_to_keep():
return RayConfig.instance().num_actor_checkpoints_to_keep()
+1
View File
@@ -28,6 +28,7 @@ cdef extern from "ray/id.h" namespace "ray" nogil:
ctypedef CUniqueID ActorID
ctypedef CUniqueID ActorClassID
ctypedef CUniqueID ActorHandleID
ctypedef CUniqueID ActorCheckpointID
ctypedef CUniqueID WorkerID
ctypedef CUniqueID DriverID
ctypedef CUniqueID ConfigID
+38 -3
View File
@@ -7,9 +7,21 @@ See https://github.com/ray-project/ray/issues/3721.
# WARNING: Any additional ID types defined in this file must be added to the
# _ID_TYPES list at the bottom of this file.
from ray.includes.common cimport (
CUniqueID, CTaskID, CObjectID, CFunctionID, CActorClassID, CActorID,
CActorHandleID, CWorkerID, CDriverID, CConfigID, CClientID,
ComputePutId, ComputeTaskId)
CActorCheckpointID,
CActorClassID,
CActorHandleID,
CActorID,
CClientID,
CConfigID,
CDriverID,
CFunctionID,
CObjectID,
CTaskID,
CUniqueID,
CWorkerID,
ComputePutId,
ComputeTaskId,
)
from ray.utils import decode
@@ -236,6 +248,29 @@ cdef class ActorHandleID(UniqueID):
return "ActorHandleID(" + self.hex() + ")"
cdef class ActorCheckpointID(UniqueID):
def __init__(self, id):
if not id:
self.data = CUniqueID()
else:
check_id(id)
self.data = CUniqueID.from_binary(id)
@staticmethod
cdef from_native(const CActorCheckpointID& cpp_id):
cdef ActorCheckpointID self = ActorCheckpointID.__new__(ActorHandleID)
self.data = cpp_id
return self
@staticmethod
def nil():
return ActorCheckpointID.from_native(CActorCheckpointID.nil())
def __repr__(self):
return "ActorCheckpointID(" + self.hex() + ")"
cdef class FunctionID(UniqueID):
def __init__(self, id):
+29 -12
View File
@@ -117,6 +117,25 @@ class RayTaskError(Exception):
return "\n".join(out)
class ActorCheckpointInfo(object):
"""Information used to maintain actor checkpoints."""
__slots__ = [
# Number of tasks executed since last checkpoint.
"num_tasks_since_last_checkpoint",
# Timestamp of the last checkpoint, in milliseconds.
"last_checkpoint_timestamp",
# IDs of the previous checkpoints.
"checkpoint_ids",
]
def __init__(self, num_tasks_since_last_checkpoint,
last_checkpoint_timestamp, checkpoint_ids):
self.num_tasks_since_last_checkpoint = num_tasks_since_last_checkpoint
self.last_checkpoint_timestamp = last_checkpoint_timestamp
self.checkpoint_ids = checkpoint_ids
class Worker(object):
"""A class used to define the control flow of a worker process.
@@ -141,6 +160,8 @@ class Worker(object):
self.actor_init_error = None
self.make_actor = None
self.actors = {}
# Information used to maintain actor checkpoints.
self.actor_checkpoint_info = {}
self.actor_task_counter = 0
# The number of threads Plasma should use when putting an object in the
# object store.
@@ -515,7 +536,6 @@ class Worker(object):
actor_id=None,
actor_handle_id=None,
actor_counter=0,
is_actor_checkpoint_method=False,
actor_creation_id=None,
actor_creation_dummy_object_id=None,
max_actor_reconstructions=0,
@@ -538,8 +558,6 @@ class Worker(object):
be serializable objects.
actor_id: The ID of the actor that this task is for.
actor_counter: The counter of the actor task.
is_actor_checkpoint_method: True if this is an actor checkpoint
task and false otherwise.
actor_creation_id: The ID of the actor to create, if this is an
actor creation task.
actor_creation_dummy_object_id: If this task is an actor method,
@@ -900,6 +918,11 @@ class Worker(object):
self.actor_creation_task_id = task.task_id()
self.function_actor_manager.load_actor(driver_id,
function_descriptor)
self.actor_checkpoint_info[self.actor_id] = ActorCheckpointInfo(
num_tasks_since_last_checkpoint=0,
last_checkpoint_timestamp=int(1000 * time.time()),
checkpoint_ids=[],
)
execution_info = self.function_actor_manager.get_execution_info(
driver_id, function_descriptor)
@@ -2395,16 +2418,12 @@ def make_decorator(num_return_vals=None,
num_gpus=None,
resources=None,
max_calls=None,
checkpoint_interval=None,
max_reconstructions=None,
worker=None):
def decorator(function_or_class):
if (inspect.isfunction(function_or_class)
or is_cython(function_or_class)):
# Set the remote function default resources.
if checkpoint_interval is not None:
raise Exception("The keyword 'checkpoint_interval' is not "
"allowed for remote functions.")
if max_reconstructions is not None:
raise Exception("The keyword 'max_reconstructions' is not "
"allowed for remote functions.")
@@ -2437,7 +2456,7 @@ def make_decorator(num_return_vals=None,
return worker.make_actor(function_or_class, cpus_to_use, num_gpus,
resources, actor_method_cpus,
checkpoint_interval, max_reconstructions)
max_reconstructions)
raise Exception("The @ray.remote decorator must be applied to "
"either a function or to a class.")
@@ -2509,7 +2528,7 @@ def remote(*args, **kwargs):
"with no arguments and no parentheses, for example "
"'@ray.remote', or it must be applied using some of "
"the arguments 'num_return_vals', 'num_cpus', 'num_gpus', "
"'resources', 'max_calls', 'checkpoint_interval',"
"'resources', 'max_calls', "
"or 'max_reconstructions', like "
"'@ray.remote(num_return_vals=2, "
"resources={\"CustomResource\": 1})'.")
@@ -2517,7 +2536,7 @@ def remote(*args, **kwargs):
for key in kwargs:
assert key in [
"num_return_vals", "num_cpus", "num_gpus", "resources",
"max_calls", "checkpoint_interval", "max_reconstructions"
"max_calls", "max_reconstructions"
], error_string
num_cpus = kwargs["num_cpus"] if "num_cpus" in kwargs else None
@@ -2534,7 +2553,6 @@ def remote(*args, **kwargs):
# Handle other arguments.
num_return_vals = kwargs.get("num_return_vals")
max_calls = kwargs.get("max_calls")
checkpoint_interval = kwargs.get("checkpoint_interval")
max_reconstructions = kwargs.get("max_reconstructions")
return make_decorator(
@@ -2543,6 +2561,5 @@ def remote(*args, **kwargs):
num_gpus=num_gpus,
resources=resources,
max_calls=max_calls,
checkpoint_interval=checkpoint_interval,
max_reconstructions=max_reconstructions,
worker=worker)