Implement actor checkpointing (#3839)

* Implement Actor checkpointing

* docs

* fix

* fix

* fix

* move restore-from-checkpoint to HandleActorStateTransition

* Revert "move restore-from-checkpoint to HandleActorStateTransition"

This reverts commit 9aa4447c1e3e321f42a1d895d72f17098b72de12.

* resubmit waiting tasks when actor frontier restored

* add doc about num_actor_checkpoints_to_keep=1

* add num_actor_checkpoints_to_keep to Cython

* add checkpoint_expired api

* check if actor class is abstract

* change checkpoint_ids to long string

* implement java

* Refactor to delay actor creation publish until checkpoint is resumed

* debug, lint

* Erase from checkpoints to restore if task fails

* fix lint

* update comments

* avoid duplicated actor notification log

* fix unintended change

* add actor_id to checkpoint_expired

* small java updates

* make checkpoint info per actor

* lint

* Remove logging

* Remove old actor checkpointing Python code, move new checkpointing code to FunctionActionManager

* Replace old actor checkpointing tests

* Fix test and lint

* address comments

* consolidate kill_actor

* Remove __ray_checkpoint__

* fix non-ascii char

* Loosen test checks

* fix java

* fix sphinx-build
This commit is contained in:
Hao Chen
2019-02-13 19:39:02 +08:00
committed by GitHub
parent 57dcd3033e
commit f31a79f3f7
41 changed files with 1708 additions and 490 deletions
+144 -167
View File
@@ -6,11 +6,13 @@ import copy
import hashlib
import inspect
import logging
import six
import sys
import threading
import traceback
import ray.cloudpickle as pickle
from abc import ABCMeta, abstractmethod
from collections import namedtuple
from ray.function_manager import FunctionDescriptor
import ray.ray_constants as ray_constants
import ray.signature as signature
@@ -75,90 +77,6 @@ def compute_actor_handle_id_non_forked(actor_handle_id, current_task_id):
return ActorHandleID(handle_id)
def set_actor_checkpoint(worker, actor_id, checkpoint_index, checkpoint,
frontier):
"""Set the most recent checkpoint associated with a given actor ID.
Args:
worker: The worker to use to get the checkpoint.
actor_id: The actor ID of the actor to get the checkpoint for.
checkpoint_index: The number of tasks included in the checkpoint.
checkpoint: The state object to save.
frontier: The task frontier at the time of the checkpoint.
"""
assert isinstance(actor_id, ActorID)
actor_key = b"Actor:" + actor_id.binary()
worker.redis_client.hmset(
actor_key, {
"checkpoint_index": checkpoint_index,
"checkpoint": checkpoint,
"frontier": frontier,
})
def save_and_log_checkpoint(worker, actor):
"""Save a checkpoint on the actor and log any errors.
Args:
worker: The worker to use to log errors.
actor: The actor to checkpoint.
"""
try:
actor.__ray_checkpoint__()
except Exception:
traceback_str = ray.utils.format_error_message(traceback.format_exc())
# Log the error message.
ray.utils.push_error_to_driver(
worker,
ray_constants.CHECKPOINT_PUSH_ERROR,
traceback_str,
driver_id=worker.task_driver_id)
def restore_and_log_checkpoint(worker, actor):
"""Restore an actor from a checkpoint and log any errors.
Args:
worker: The worker to use to log errors.
actor: The actor to restore.
"""
checkpoint_resumed = False
try:
checkpoint_resumed = actor.__ray_checkpoint_restore__()
except Exception:
traceback_str = ray.utils.format_error_message(traceback.format_exc())
# Log the error message.
ray.utils.push_error_to_driver(
worker,
ray_constants.CHECKPOINT_PUSH_ERROR,
traceback_str,
driver_id=worker.task_driver_id)
return checkpoint_resumed
def get_actor_checkpoint(worker, actor_id):
"""Get the most recent checkpoint associated with a given actor ID.
Args:
worker: The worker to use to get the checkpoint.
actor_id: The actor ID of the actor to get the checkpoint for.
Returns:
If a checkpoint exists, this returns a tuple of the number of tasks
included in the checkpoint, the saved checkpoint state, and the
task frontier at the time of the checkpoint. If no checkpoint
exists, all objects are set to None. The checkpoint index is the .
executed on the actor before the checkpoint was made.
"""
assert isinstance(actor_id, ActorID)
actor_key = b"Actor:" + actor_id.binary()
checkpoint_index, checkpoint, frontier = worker.redis_client.hmget(
actor_key, ["checkpoint_index", "checkpoint", "frontier"])
if checkpoint_index is not None:
checkpoint_index = int(checkpoint_index)
return checkpoint_index, checkpoint, frontier
def method(*args, **kwargs):
"""Annotate an actor method.
@@ -234,7 +152,6 @@ class ActorClass(object):
additional methods added like __ray_terminate__).
_class_id: The ID of this actor class.
_class_name: The name of this class.
_checkpoint_interval: The interval at which to checkpoint actor state.
_num_cpus: The default number of CPUs required by the actor creation
task.
_num_gpus: The default number of GPUs required by the actor creation
@@ -250,13 +167,11 @@ class ActorClass(object):
each actor method.
"""
def __init__(self, modified_class, class_id, checkpoint_interval,
max_reconstructions, num_cpus, num_gpus, resources,
actor_method_cpus):
def __init__(self, modified_class, class_id, max_reconstructions, num_cpus,
num_gpus, resources, actor_method_cpus):
self._modified_class = modified_class
self._class_id = class_id
self._class_name = modified_class.__name__
self._checkpoint_interval = checkpoint_interval
self._max_reconstructions = max_reconstructions
self._num_cpus = num_cpus
self._num_gpus = num_gpus
@@ -383,8 +298,7 @@ class ActorClass(object):
# Export the actor.
if not self._exported:
worker.function_actor_manager.export_actor_class(
self._modified_class, self._actor_method_names,
self._checkpoint_interval)
self._modified_class, self._actor_method_names)
self._exported = True
resources = ray.utils.resources_from_resource_arguments(
@@ -564,8 +478,6 @@ class ActorHandle(object):
return getattr(worker.actors[self._ray_actor_id],
method_name)(*copy.deepcopy(args))
is_actor_checkpoint_method = (method_name == "__ray_checkpoint__")
function_descriptor = FunctionDescriptor(
self._ray_module_name, method_name, self._ray_class_name)
with self._ray_actor_lock:
@@ -575,7 +487,6 @@ class ActorHandle(object):
actor_id=self._ray_actor_id,
actor_handle_id=self._ray_actor_handle_id,
actor_counter=self._ray_actor_counter,
is_actor_checkpoint_method=is_actor_checkpoint_method,
actor_creation_dummy_object_id=(
self._ray_actor_creation_dummy_object_id),
execution_dependencies=[self._ray_actor_cursor],
@@ -770,7 +681,7 @@ class ActorHandle(object):
def make_actor(cls, num_cpus, num_gpus, resources, actor_method_cpus,
checkpoint_interval, max_reconstructions):
max_reconstructions):
# Give an error if cls is an old-style class.
if not issubclass(cls, object):
raise TypeError(
@@ -778,13 +689,14 @@ def make_actor(cls, num_cpus, num_gpus, resources, actor_method_cpus,
"classes. In Python 2, you must declare the class with "
"'class ClassName(object):' instead of 'class ClassName:'.")
if checkpoint_interval is None:
checkpoint_interval = -1
if issubclass(cls, Checkpointable) and inspect.isabstract(cls):
raise TypeError(
"A checkpointable actor class should implement all abstract "
"methods in the `Checkpointable` interface.")
if max_reconstructions is None:
max_reconstructions = 0
if checkpoint_interval == 0:
raise Exception("checkpoint_interval must be greater than 0.")
if not (ray_constants.NO_RECONSTRUCTION <= max_reconstructions <=
ray_constants.INFINITE_RECONSTRUCTION):
raise Exception("max_reconstructions must be in range [%d, %d]." %
@@ -804,26 +716,6 @@ def make_actor(cls, num_cpus, num_gpus, resources, actor_method_cpus,
sys.exit(0)
assert False, "This process should have terminated."
def __ray_save_checkpoint__(self):
if hasattr(self, "__ray_save__"):
object_to_serialize = self.__ray_save__()
else:
object_to_serialize = self
return pickle.dumps(object_to_serialize)
@classmethod
def __ray_restore_from_checkpoint__(cls, pickled_checkpoint):
checkpoint = pickle.loads(pickled_checkpoint)
if hasattr(cls, "__ray_restore__"):
actor_object = cls.__new__(cls)
actor_object.__ray_restore__(checkpoint)
else:
# TODO(rkn): It's possible that this will cause problems. When
# you unpickle the same object twice, the two objects will not
# have the same class.
actor_object = checkpoint
return actor_object
def __ray_checkpoint__(self):
"""Save a checkpoint.
@@ -832,58 +724,143 @@ def make_actor(cls, num_cpus, num_gpus, resources, actor_method_cpus,
(number of tasks executed so far).
"""
worker = ray.worker.global_worker
checkpoint_index = worker.actor_task_counter
# Get the state to save.
checkpoint = self.__ray_save_checkpoint__()
# Get the current task frontier, per actor handle.
# NOTE(swang): This only includes actor handles that the local
# scheduler has seen. Handle IDs for which no task has yet reached
# the local scheduler will not be included, and may not be runnable
# on checkpoint resumption.
actor_id = worker.actor_id
frontier = worker.raylet_client.get_actor_frontier(actor_id)
# Save the checkpoint in Redis. TODO(rkn): Checkpoints
# should not be stored in Redis. Fix this.
set_actor_checkpoint(worker, worker.actor_id, checkpoint_index,
checkpoint, frontier)
def __ray_checkpoint_restore__(self):
"""Restore a checkpoint.
This task looks for a saved checkpoint and if found, restores the
state of the actor, the task frontier in the local scheduler, and
the checkpoint index (number of tasks executed so far).
Returns:
A bool indicating whether a checkpoint was resumed.
"""
worker = ray.worker.global_worker
# Get the most recent checkpoint stored, if any.
checkpoint_index, checkpoint, frontier = get_actor_checkpoint(
worker, worker.actor_id)
# Try to resume from the checkpoint.
checkpoint_resumed = False
if checkpoint_index is not None:
# Load the actor state from the checkpoint.
worker.actors[worker.actor_id] = (
worker.actor_class.__ray_restore_from_checkpoint__(
checkpoint))
# Set the number of tasks executed so far.
worker.actor_task_counter = checkpoint_index
# Set the actor frontier in the local scheduler.
worker.raylet_client.set_actor_frontier(frontier)
checkpoint_resumed = True
return checkpoint_resumed
if not isinstance(self, ray.actor.Checkpointable):
raise Exception(
"__ray_checkpoint__.remote() may only be called on actors "
"that implement ray.actor.Checkpointable")
return worker._save_actor_checkpoint()
Class.__module__ = cls.__module__
Class.__name__ = cls.__name__
class_id = ActorClassID(_random_string())
return ActorClass(Class, class_id, checkpoint_interval,
max_reconstructions, num_cpus, num_gpus, resources,
actor_method_cpus)
return ActorClass(Class, class_id, max_reconstructions, num_cpus, num_gpus,
resources, actor_method_cpus)
ray.worker.global_worker.make_actor = make_actor
CheckpointContext = namedtuple(
'CheckpointContext',
[
# Actor's ID.
'actor_id',
# Number of tasks executed since last checkpoint.
'num_tasks_since_last_checkpoint',
# Time elapsed since last checkpoint, in milliseconds.
'time_elapsed_ms_since_last_checkpoint',
],
)
"""A namedtuple that contains information about actor's last checkpoint."""
Checkpoint = namedtuple(
'Checkpoint',
[
# ID of this checkpoint.
'checkpoint_id',
# The timestamp at which this checkpoint was saved,
# represented as milliseconds elapsed since Unix epoch.
'timestamp',
],
)
"""A namedtuple that represents a checkpoint."""
class Checkpointable(six.with_metaclass(ABCMeta, object)):
"""An interface that indicates an actor can be checkpointed."""
@abstractmethod
def should_checkpoint(self, checkpoint_context):
"""Whether this actor needs to be checkpointed.
This method will be called after every task. You should implement this
callback to decide whether this actor needs to be checkpointed at this
time, based on the checkpoint context, or any other factors.
Args:
checkpoint_context: A namedtuple that contains info about last
checkpoint.
Returns:
A boolean value that indicates whether this actor needs to be
checkpointed.
"""
pass
@abstractmethod
def save_checkpoint(self, actor_id, checkpoint_id):
"""Save a checkpoint to persistent storage.
If `should_checkpoint` returns true, this method will be called. You
should implement this callback to save actor's checkpoint and the given
checkpoint id to persistent storage.
Args:
actor_id: Actor's ID.
checkpoint_id: ID of this checkpoint. You should save it together
with actor's checkpoint data. And it will be used by the
`load_checkpoint` method.
Returns:
None.
"""
pass
@abstractmethod
def load_checkpoint(self, actor_id, available_checkpoints):
"""Load actor's previous checkpoint, and restore actor's state.
This method will be called when an actor is reconstructed, after
actor's constructor.
If the actor needs to restore from previous checkpoint, this function
should restore actor's state and return the checkpoint ID. Otherwise,
it should do nothing and return None.
Note, this method must return one of the checkpoint IDs in the
`available_checkpoints` list, or None. Otherwise, an exception will be
raised.
Args:
actor_id: Actor's ID.
available_checkpoints: A list of `Checkpoint` namedtuples that
contains all available checkpoint IDs and their timestamps,
sorted by timestamp in descending order.
Returns:
The ID of the checkpoint from which the actor was resumed, or None
if the actor should restart from the beginning.
"""
pass
@abstractmethod
def checkpoint_expired(self, actor_id, checkpoint_id):
"""Delete an expired checkpoint.
This method will be called when an checkpoint is expired. You should
implement this method to delete your application checkpoint data.
Note, the maximum number of checkpoints kept in the backend can be
configured at `RayConfig.num_actor_checkpoints_to_keep`.
Args:
actor_id: ID of the actor.
checkpoint_id: ID of the checkpoint that has expired.
Returns:
None.
"""
pass
def get_checkpoints_for_actor(actor_id):
"""Get the available checkpoints for the given actor ID, return a list
sorted by checkpoint timestamp in descending order.
"""
checkpoint_info = ray.worker.global_state.actor_checkpoint_info(actor_id)
if checkpoint_info is None:
return []
checkpoints = [
Checkpoint(checkpoint_id, timestamp) for checkpoint_id, timestamp in
zip(checkpoint_info['CheckpointIds'], checkpoint_info['Timestamps'])
]
return sorted(
checkpoints,
key=lambda checkpoint: checkpoint.timestamp,
reverse=True,
)