mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 15:44:37 +08:00
Implement actor checkpointing (#3839)
* Implement Actor checkpointing * docs * fix * fix * fix * move restore-from-checkpoint to HandleActorStateTransition * Revert "move restore-from-checkpoint to HandleActorStateTransition" This reverts commit 9aa4447c1e3e321f42a1d895d72f17098b72de12. * resubmit waiting tasks when actor frontier restored * add doc about num_actor_checkpoints_to_keep=1 * add num_actor_checkpoints_to_keep to Cython * add checkpoint_expired api * check if actor class is abstract * change checkpoint_ids to long string * implement java * Refactor to delay actor creation publish until checkpoint is resumed * debug, lint * Erase from checkpoints to restore if task fails * fix lint * update comments * avoid duplicated actor notification log * fix unintended change * add actor_id to checkpoint_expired * small java updates * make checkpoint info per actor * lint * Remove logging * Remove old actor checkpointing Python code, move new checkpointing code to FunctionActionManager * Replace old actor checkpointing tests * Fix test and lint * address comments * consolidate kill_actor * Remove __ray_checkpoint__ * fix non-ascii char * Loosen test checks * fix java * fix sphinx-build
This commit is contained in:
+144
-167
@@ -6,11 +6,13 @@ import copy
|
||||
import hashlib
|
||||
import inspect
|
||||
import logging
|
||||
import six
|
||||
import sys
|
||||
import threading
|
||||
import traceback
|
||||
|
||||
import ray.cloudpickle as pickle
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from collections import namedtuple
|
||||
|
||||
from ray.function_manager import FunctionDescriptor
|
||||
import ray.ray_constants as ray_constants
|
||||
import ray.signature as signature
|
||||
@@ -75,90 +77,6 @@ def compute_actor_handle_id_non_forked(actor_handle_id, current_task_id):
|
||||
return ActorHandleID(handle_id)
|
||||
|
||||
|
||||
def set_actor_checkpoint(worker, actor_id, checkpoint_index, checkpoint,
|
||||
frontier):
|
||||
"""Set the most recent checkpoint associated with a given actor ID.
|
||||
|
||||
Args:
|
||||
worker: The worker to use to get the checkpoint.
|
||||
actor_id: The actor ID of the actor to get the checkpoint for.
|
||||
checkpoint_index: The number of tasks included in the checkpoint.
|
||||
checkpoint: The state object to save.
|
||||
frontier: The task frontier at the time of the checkpoint.
|
||||
"""
|
||||
assert isinstance(actor_id, ActorID)
|
||||
actor_key = b"Actor:" + actor_id.binary()
|
||||
worker.redis_client.hmset(
|
||||
actor_key, {
|
||||
"checkpoint_index": checkpoint_index,
|
||||
"checkpoint": checkpoint,
|
||||
"frontier": frontier,
|
||||
})
|
||||
|
||||
|
||||
def save_and_log_checkpoint(worker, actor):
|
||||
"""Save a checkpoint on the actor and log any errors.
|
||||
|
||||
Args:
|
||||
worker: The worker to use to log errors.
|
||||
actor: The actor to checkpoint.
|
||||
"""
|
||||
try:
|
||||
actor.__ray_checkpoint__()
|
||||
except Exception:
|
||||
traceback_str = ray.utils.format_error_message(traceback.format_exc())
|
||||
# Log the error message.
|
||||
ray.utils.push_error_to_driver(
|
||||
worker,
|
||||
ray_constants.CHECKPOINT_PUSH_ERROR,
|
||||
traceback_str,
|
||||
driver_id=worker.task_driver_id)
|
||||
|
||||
|
||||
def restore_and_log_checkpoint(worker, actor):
|
||||
"""Restore an actor from a checkpoint and log any errors.
|
||||
|
||||
Args:
|
||||
worker: The worker to use to log errors.
|
||||
actor: The actor to restore.
|
||||
"""
|
||||
checkpoint_resumed = False
|
||||
try:
|
||||
checkpoint_resumed = actor.__ray_checkpoint_restore__()
|
||||
except Exception:
|
||||
traceback_str = ray.utils.format_error_message(traceback.format_exc())
|
||||
# Log the error message.
|
||||
ray.utils.push_error_to_driver(
|
||||
worker,
|
||||
ray_constants.CHECKPOINT_PUSH_ERROR,
|
||||
traceback_str,
|
||||
driver_id=worker.task_driver_id)
|
||||
return checkpoint_resumed
|
||||
|
||||
|
||||
def get_actor_checkpoint(worker, actor_id):
|
||||
"""Get the most recent checkpoint associated with a given actor ID.
|
||||
|
||||
Args:
|
||||
worker: The worker to use to get the checkpoint.
|
||||
actor_id: The actor ID of the actor to get the checkpoint for.
|
||||
|
||||
Returns:
|
||||
If a checkpoint exists, this returns a tuple of the number of tasks
|
||||
included in the checkpoint, the saved checkpoint state, and the
|
||||
task frontier at the time of the checkpoint. If no checkpoint
|
||||
exists, all objects are set to None. The checkpoint index is the .
|
||||
executed on the actor before the checkpoint was made.
|
||||
"""
|
||||
assert isinstance(actor_id, ActorID)
|
||||
actor_key = b"Actor:" + actor_id.binary()
|
||||
checkpoint_index, checkpoint, frontier = worker.redis_client.hmget(
|
||||
actor_key, ["checkpoint_index", "checkpoint", "frontier"])
|
||||
if checkpoint_index is not None:
|
||||
checkpoint_index = int(checkpoint_index)
|
||||
return checkpoint_index, checkpoint, frontier
|
||||
|
||||
|
||||
def method(*args, **kwargs):
|
||||
"""Annotate an actor method.
|
||||
|
||||
@@ -234,7 +152,6 @@ class ActorClass(object):
|
||||
additional methods added like __ray_terminate__).
|
||||
_class_id: The ID of this actor class.
|
||||
_class_name: The name of this class.
|
||||
_checkpoint_interval: The interval at which to checkpoint actor state.
|
||||
_num_cpus: The default number of CPUs required by the actor creation
|
||||
task.
|
||||
_num_gpus: The default number of GPUs required by the actor creation
|
||||
@@ -250,13 +167,11 @@ class ActorClass(object):
|
||||
each actor method.
|
||||
"""
|
||||
|
||||
def __init__(self, modified_class, class_id, checkpoint_interval,
|
||||
max_reconstructions, num_cpus, num_gpus, resources,
|
||||
actor_method_cpus):
|
||||
def __init__(self, modified_class, class_id, max_reconstructions, num_cpus,
|
||||
num_gpus, resources, actor_method_cpus):
|
||||
self._modified_class = modified_class
|
||||
self._class_id = class_id
|
||||
self._class_name = modified_class.__name__
|
||||
self._checkpoint_interval = checkpoint_interval
|
||||
self._max_reconstructions = max_reconstructions
|
||||
self._num_cpus = num_cpus
|
||||
self._num_gpus = num_gpus
|
||||
@@ -383,8 +298,7 @@ class ActorClass(object):
|
||||
# Export the actor.
|
||||
if not self._exported:
|
||||
worker.function_actor_manager.export_actor_class(
|
||||
self._modified_class, self._actor_method_names,
|
||||
self._checkpoint_interval)
|
||||
self._modified_class, self._actor_method_names)
|
||||
self._exported = True
|
||||
|
||||
resources = ray.utils.resources_from_resource_arguments(
|
||||
@@ -564,8 +478,6 @@ class ActorHandle(object):
|
||||
return getattr(worker.actors[self._ray_actor_id],
|
||||
method_name)(*copy.deepcopy(args))
|
||||
|
||||
is_actor_checkpoint_method = (method_name == "__ray_checkpoint__")
|
||||
|
||||
function_descriptor = FunctionDescriptor(
|
||||
self._ray_module_name, method_name, self._ray_class_name)
|
||||
with self._ray_actor_lock:
|
||||
@@ -575,7 +487,6 @@ class ActorHandle(object):
|
||||
actor_id=self._ray_actor_id,
|
||||
actor_handle_id=self._ray_actor_handle_id,
|
||||
actor_counter=self._ray_actor_counter,
|
||||
is_actor_checkpoint_method=is_actor_checkpoint_method,
|
||||
actor_creation_dummy_object_id=(
|
||||
self._ray_actor_creation_dummy_object_id),
|
||||
execution_dependencies=[self._ray_actor_cursor],
|
||||
@@ -770,7 +681,7 @@ class ActorHandle(object):
|
||||
|
||||
|
||||
def make_actor(cls, num_cpus, num_gpus, resources, actor_method_cpus,
|
||||
checkpoint_interval, max_reconstructions):
|
||||
max_reconstructions):
|
||||
# Give an error if cls is an old-style class.
|
||||
if not issubclass(cls, object):
|
||||
raise TypeError(
|
||||
@@ -778,13 +689,14 @@ def make_actor(cls, num_cpus, num_gpus, resources, actor_method_cpus,
|
||||
"classes. In Python 2, you must declare the class with "
|
||||
"'class ClassName(object):' instead of 'class ClassName:'.")
|
||||
|
||||
if checkpoint_interval is None:
|
||||
checkpoint_interval = -1
|
||||
if issubclass(cls, Checkpointable) and inspect.isabstract(cls):
|
||||
raise TypeError(
|
||||
"A checkpointable actor class should implement all abstract "
|
||||
"methods in the `Checkpointable` interface.")
|
||||
|
||||
if max_reconstructions is None:
|
||||
max_reconstructions = 0
|
||||
|
||||
if checkpoint_interval == 0:
|
||||
raise Exception("checkpoint_interval must be greater than 0.")
|
||||
if not (ray_constants.NO_RECONSTRUCTION <= max_reconstructions <=
|
||||
ray_constants.INFINITE_RECONSTRUCTION):
|
||||
raise Exception("max_reconstructions must be in range [%d, %d]." %
|
||||
@@ -804,26 +716,6 @@ def make_actor(cls, num_cpus, num_gpus, resources, actor_method_cpus,
|
||||
sys.exit(0)
|
||||
assert False, "This process should have terminated."
|
||||
|
||||
def __ray_save_checkpoint__(self):
|
||||
if hasattr(self, "__ray_save__"):
|
||||
object_to_serialize = self.__ray_save__()
|
||||
else:
|
||||
object_to_serialize = self
|
||||
return pickle.dumps(object_to_serialize)
|
||||
|
||||
@classmethod
|
||||
def __ray_restore_from_checkpoint__(cls, pickled_checkpoint):
|
||||
checkpoint = pickle.loads(pickled_checkpoint)
|
||||
if hasattr(cls, "__ray_restore__"):
|
||||
actor_object = cls.__new__(cls)
|
||||
actor_object.__ray_restore__(checkpoint)
|
||||
else:
|
||||
# TODO(rkn): It's possible that this will cause problems. When
|
||||
# you unpickle the same object twice, the two objects will not
|
||||
# have the same class.
|
||||
actor_object = checkpoint
|
||||
return actor_object
|
||||
|
||||
def __ray_checkpoint__(self):
|
||||
"""Save a checkpoint.
|
||||
|
||||
@@ -832,58 +724,143 @@ def make_actor(cls, num_cpus, num_gpus, resources, actor_method_cpus,
|
||||
(number of tasks executed so far).
|
||||
"""
|
||||
worker = ray.worker.global_worker
|
||||
checkpoint_index = worker.actor_task_counter
|
||||
# Get the state to save.
|
||||
checkpoint = self.__ray_save_checkpoint__()
|
||||
# Get the current task frontier, per actor handle.
|
||||
# NOTE(swang): This only includes actor handles that the local
|
||||
# scheduler has seen. Handle IDs for which no task has yet reached
|
||||
# the local scheduler will not be included, and may not be runnable
|
||||
# on checkpoint resumption.
|
||||
actor_id = worker.actor_id
|
||||
frontier = worker.raylet_client.get_actor_frontier(actor_id)
|
||||
# Save the checkpoint in Redis. TODO(rkn): Checkpoints
|
||||
# should not be stored in Redis. Fix this.
|
||||
set_actor_checkpoint(worker, worker.actor_id, checkpoint_index,
|
||||
checkpoint, frontier)
|
||||
|
||||
def __ray_checkpoint_restore__(self):
|
||||
"""Restore a checkpoint.
|
||||
|
||||
This task looks for a saved checkpoint and if found, restores the
|
||||
state of the actor, the task frontier in the local scheduler, and
|
||||
the checkpoint index (number of tasks executed so far).
|
||||
|
||||
Returns:
|
||||
A bool indicating whether a checkpoint was resumed.
|
||||
"""
|
||||
worker = ray.worker.global_worker
|
||||
# Get the most recent checkpoint stored, if any.
|
||||
checkpoint_index, checkpoint, frontier = get_actor_checkpoint(
|
||||
worker, worker.actor_id)
|
||||
# Try to resume from the checkpoint.
|
||||
checkpoint_resumed = False
|
||||
if checkpoint_index is not None:
|
||||
# Load the actor state from the checkpoint.
|
||||
worker.actors[worker.actor_id] = (
|
||||
worker.actor_class.__ray_restore_from_checkpoint__(
|
||||
checkpoint))
|
||||
# Set the number of tasks executed so far.
|
||||
worker.actor_task_counter = checkpoint_index
|
||||
# Set the actor frontier in the local scheduler.
|
||||
worker.raylet_client.set_actor_frontier(frontier)
|
||||
checkpoint_resumed = True
|
||||
|
||||
return checkpoint_resumed
|
||||
if not isinstance(self, ray.actor.Checkpointable):
|
||||
raise Exception(
|
||||
"__ray_checkpoint__.remote() may only be called on actors "
|
||||
"that implement ray.actor.Checkpointable")
|
||||
return worker._save_actor_checkpoint()
|
||||
|
||||
Class.__module__ = cls.__module__
|
||||
Class.__name__ = cls.__name__
|
||||
|
||||
class_id = ActorClassID(_random_string())
|
||||
|
||||
return ActorClass(Class, class_id, checkpoint_interval,
|
||||
max_reconstructions, num_cpus, num_gpus, resources,
|
||||
actor_method_cpus)
|
||||
return ActorClass(Class, class_id, max_reconstructions, num_cpus, num_gpus,
|
||||
resources, actor_method_cpus)
|
||||
|
||||
|
||||
ray.worker.global_worker.make_actor = make_actor
|
||||
|
||||
CheckpointContext = namedtuple(
|
||||
'CheckpointContext',
|
||||
[
|
||||
# Actor's ID.
|
||||
'actor_id',
|
||||
# Number of tasks executed since last checkpoint.
|
||||
'num_tasks_since_last_checkpoint',
|
||||
# Time elapsed since last checkpoint, in milliseconds.
|
||||
'time_elapsed_ms_since_last_checkpoint',
|
||||
],
|
||||
)
|
||||
"""A namedtuple that contains information about actor's last checkpoint."""
|
||||
|
||||
Checkpoint = namedtuple(
|
||||
'Checkpoint',
|
||||
[
|
||||
# ID of this checkpoint.
|
||||
'checkpoint_id',
|
||||
# The timestamp at which this checkpoint was saved,
|
||||
# represented as milliseconds elapsed since Unix epoch.
|
||||
'timestamp',
|
||||
],
|
||||
)
|
||||
"""A namedtuple that represents a checkpoint."""
|
||||
|
||||
|
||||
class Checkpointable(six.with_metaclass(ABCMeta, object)):
|
||||
"""An interface that indicates an actor can be checkpointed."""
|
||||
|
||||
@abstractmethod
|
||||
def should_checkpoint(self, checkpoint_context):
|
||||
"""Whether this actor needs to be checkpointed.
|
||||
|
||||
This method will be called after every task. You should implement this
|
||||
callback to decide whether this actor needs to be checkpointed at this
|
||||
time, based on the checkpoint context, or any other factors.
|
||||
|
||||
Args:
|
||||
checkpoint_context: A namedtuple that contains info about last
|
||||
checkpoint.
|
||||
|
||||
Returns:
|
||||
A boolean value that indicates whether this actor needs to be
|
||||
checkpointed.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save_checkpoint(self, actor_id, checkpoint_id):
|
||||
"""Save a checkpoint to persistent storage.
|
||||
|
||||
If `should_checkpoint` returns true, this method will be called. You
|
||||
should implement this callback to save actor's checkpoint and the given
|
||||
checkpoint id to persistent storage.
|
||||
|
||||
Args:
|
||||
actor_id: Actor's ID.
|
||||
checkpoint_id: ID of this checkpoint. You should save it together
|
||||
with actor's checkpoint data. And it will be used by the
|
||||
`load_checkpoint` method.
|
||||
Returns:
|
||||
None.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_checkpoint(self, actor_id, available_checkpoints):
|
||||
"""Load actor's previous checkpoint, and restore actor's state.
|
||||
|
||||
This method will be called when an actor is reconstructed, after
|
||||
actor's constructor.
|
||||
If the actor needs to restore from previous checkpoint, this function
|
||||
should restore actor's state and return the checkpoint ID. Otherwise,
|
||||
it should do nothing and return None.
|
||||
Note, this method must return one of the checkpoint IDs in the
|
||||
`available_checkpoints` list, or None. Otherwise, an exception will be
|
||||
raised.
|
||||
|
||||
Args:
|
||||
actor_id: Actor's ID.
|
||||
available_checkpoints: A list of `Checkpoint` namedtuples that
|
||||
contains all available checkpoint IDs and their timestamps,
|
||||
sorted by timestamp in descending order.
|
||||
Returns:
|
||||
The ID of the checkpoint from which the actor was resumed, or None
|
||||
if the actor should restart from the beginning.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def checkpoint_expired(self, actor_id, checkpoint_id):
|
||||
"""Delete an expired checkpoint.
|
||||
|
||||
This method will be called when an checkpoint is expired. You should
|
||||
implement this method to delete your application checkpoint data.
|
||||
Note, the maximum number of checkpoints kept in the backend can be
|
||||
configured at `RayConfig.num_actor_checkpoints_to_keep`.
|
||||
|
||||
Args:
|
||||
actor_id: ID of the actor.
|
||||
checkpoint_id: ID of the checkpoint that has expired.
|
||||
Returns:
|
||||
None.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def get_checkpoints_for_actor(actor_id):
|
||||
"""Get the available checkpoints for the given actor ID, return a list
|
||||
sorted by checkpoint timestamp in descending order.
|
||||
"""
|
||||
checkpoint_info = ray.worker.global_state.actor_checkpoint_info(actor_id)
|
||||
if checkpoint_info is None:
|
||||
return []
|
||||
checkpoints = [
|
||||
Checkpoint(checkpoint_id, timestamp) for checkpoint_id, timestamp in
|
||||
zip(checkpoint_info['CheckpointIds'], checkpoint_info['Timestamps'])
|
||||
]
|
||||
return sorted(
|
||||
checkpoints,
|
||||
key=lambda checkpoint: checkpoint.timestamp,
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user