Actor checkpointing for distributed actor handles (#1498)

* Expose calls to get and set the actor frontier

* Remove fields used for old checkpointing prototype, change actor_checkpoint_failed -> succeeded

* Prototype for actor checkpointing

* Filter out duplicate tasks on the local scheduler

* Clean up some of the Python checkpointing code

* More cleanups

* Documentation

* cleanup and fix unit test

* Allow remote checkpoint calls through actor handle

* Check whether object is local before reconstructing

* Enable checkpointing for distributed actor handles, refactor tests

* Fix local scheduler tests

* lint

* Address comments

* lint

* Skip tests that fail on new GCS

* style

* Don't put same object twice when setting the actor frontier

* Address Philipp's comments, cleaner fbs naming
This commit is contained in:
Stephanie Wang
2018-02-07 11:19:32 -08:00
committed by GitHub
parent 0a9dbc84b5
commit ff8e7f8259
13 changed files with 1006 additions and 590 deletions
+171 -176
View File
@@ -8,7 +8,6 @@ import inspect
import json
import traceback
import pyarrow.plasma as plasma
import ray.cloudpickle as pickle
import ray.local_scheduler
import ray.signature as signature
@@ -66,23 +65,24 @@ def compute_actor_method_function_id(class_name, attr):
return ray.local_scheduler.ObjectID(function_id)
def get_checkpoint_indices(worker, actor_id):
"""Get the checkpoint indices associated with a given actor ID.
def set_actor_checkpoint(worker, actor_id, checkpoint_index, checkpoint,
frontier):
"""Set the most recent checkpoint associated with a given actor ID.
Args:
worker: The worker to use to get the checkpoint indices.
actor_id: The actor ID of the actor to get the checkpoint indices for.
Returns:
The indices of existing checkpoints as a list of integers.
worker: The worker to use to set the checkpoint.
actor_id: The actor ID of the actor to set the checkpoint for.
checkpoint_index: The number of tasks included in the checkpoint.
checkpoint: The state object to save.
frontier: The task frontier at the time of the checkpoint.
"""
actor_key = b"Actor:" + actor_id
checkpoint_indices = []
for key in worker.redis_client.hkeys(actor_key):
if key.startswith(b"checkpoint_"):
index = int(key[len(b"checkpoint_"):])
checkpoint_indices.append(index)
return checkpoint_indices
worker.redis_client.hmset(
actor_key, {
"checkpoint_index": checkpoint_index,
"checkpoint": checkpoint,
"frontier": frontier,
})
def get_actor_checkpoint(worker, actor_id):
@@ -93,30 +93,74 @@ def get_actor_checkpoint(worker, actor_id):
actor_id: The actor ID of the actor to get the checkpoint for.
Returns:
If a checkpoint exists, this returns a tuple of the checkpoint index
and the checkpoint. Otherwise it returns (-1, None). The checkpoint
index is the actor counter of the last task that was executed on
the actor before the checkpoint was made.
If a checkpoint exists, this returns a tuple of the number of tasks
included in the checkpoint, the saved checkpoint state, and the
task frontier at the time of the checkpoint. If no checkpoint
exists, all objects are set to None. The checkpoint index is the
number of tasks executed on the actor before the checkpoint was made.
"""
checkpoint_indices = get_checkpoint_indices(worker, actor_id)
if len(checkpoint_indices) == 0:
return -1, None
else:
actor_key = b"Actor:" + actor_id
checkpoint_index = max(checkpoint_indices)
checkpoint = worker.redis_client.hget(
actor_key, "checkpoint_{}".format(checkpoint_index))
return checkpoint_index, checkpoint
actor_key = b"Actor:" + actor_id
checkpoint_index, checkpoint, frontier = worker.redis_client.hmget(
actor_key, ["checkpoint_index", "checkpoint", "frontier"])
if checkpoint_index is not None:
checkpoint_index = int(checkpoint_index)
return checkpoint_index, checkpoint, frontier
def make_actor_method_executor(worker, method_name, method):
def save_and_log_checkpoint(worker, actor):
    """Save a checkpoint on the actor and log any errors.

    Any exception raised by the actor's `__ray_checkpoint__` method is caught
    here and pushed to the driver as a "checkpoint" error rather than being
    propagated to the caller.

    Args:
        worker: The worker to use to log errors.
        actor: The actor to checkpoint.
    """
    try:
        actor.__ray_checkpoint__()
    except Exception:
        traceback_str = ray.utils.format_error_message(
            traceback.format_exc())
        # Log the error message.
        ray.utils.push_error_to_driver(
            worker.redis_client,
            "checkpoint",
            traceback_str,
            driver_id=worker.task_driver_id.id(),
            data={"actor_class": actor.__class__.__name__,
                  "function_name": actor.__ray_checkpoint__.__name__})
def restore_and_log_checkpoint(worker, actor):
    """Restore an actor from a checkpoint and log any errors.

    Any exception raised by the actor's `__ray_checkpoint_restore__` method is
    caught here and pushed to the driver as a "checkpoint" error rather than
    being propagated to the caller.

    Args:
        worker: The worker to use to log errors.
        actor: The actor to restore.

    Returns:
        The value returned by `actor.__ray_checkpoint_restore__()`, or False
        if that call raised an exception.
    """
    # Default to "not resumed" so that a failed restore is reported as such.
    checkpoint_resumed = False
    try:
        checkpoint_resumed = actor.__ray_checkpoint_restore__()
    except Exception:
        traceback_str = ray.utils.format_error_message(
            traceback.format_exc())
        # Log the error message.
        ray.utils.push_error_to_driver(
            worker.redis_client,
            "checkpoint",
            traceback_str,
            driver_id=worker.task_driver_id.id(),
            data={
                "actor_class": actor.__class__.__name__,
                "function_name":
                    actor.__ray_checkpoint_restore__.__name__})
    return checkpoint_resumed
def make_actor_method_executor(worker, method_name, method, actor_imported):
"""Make an executor that wraps a user-defined actor method.
The executor wraps the method to update the worker's internal state. If the
task is a success, the dummy object returned is added to the object store,
to signal that the following task can run, and the worker's task counter is
updated to match the executed task. Else, the executor reports failure to
the local scheduler so that the task counter does not get updated.
The wrapped method updates the worker's internal state and performs any
necessary checkpointing operations.
Args:
worker (Worker): The worker that is executing the actor.
@@ -124,6 +168,8 @@ def make_actor_method_executor(worker, method_name, method):
method (instancemethod): The actor method to wrap. This should be a
method defined on the actor class and should therefore take an
instance of the actor as the first argument.
actor_imported (bool): Whether the actor has been imported.
Checkpointing operations will not be run if this is set to False.
Returns:
A function that executes the given actor method on the worker's stored
@@ -131,35 +177,48 @@ def make_actor_method_executor(worker, method_name, method):
internal state to record the executed method.
"""
def actor_method_executor(dummy_return_id, task_counter, actor,
*args):
if method_name == "__ray_checkpoint__":
# Execute the checkpoint task.
actor_checkpoint_failed, error = method(actor, *args)
# If the checkpoint was successfully loaded, update the actor's
# task counter and set a flag to notify the local scheduler, so
# that the task following the checkpoint can run.
if not actor_checkpoint_failed:
worker.actor_task_counter = task_counter + 1
# Once the actor has resumed from a checkpoint, it counts as
# loaded.
worker.actor_loaded = True
# Report to the local scheduler whether this task succeeded in
# loading the checkpoint.
worker.actor_checkpoint_failed = actor_checkpoint_failed
# If there was an exception during the checkpoint method, re-raise
# it after updating the actor's internal state.
if error is not None:
raise error
return None
def actor_method_executor(dummy_return_id, actor, *args):
# Update the actor's task counter to reflect the task we're about to
# execute.
worker.actor_task_counter += 1
# If this is the first task to execute on the actor, try to resume from
# a checkpoint.
if actor_imported and worker.actor_task_counter == 1:
checkpoint_resumed = restore_and_log_checkpoint(worker, actor)
if checkpoint_resumed:
# NOTE(swang): Since we did not actually execute the __init__
# method, this will put None as the return value. If the
# __init__ method is supposed to return multiple values, an
# exception will be logged.
return
# Determine whether we should checkpoint the actor.
checkpointing_on = (actor_imported and
worker.actor_checkpoint_interval > 0)
# We should checkpoint the actor if user checkpointing is on, we've
# executed checkpoint_interval tasks since the last checkpoint, and the
# method we're about to execute is not a checkpoint.
save_checkpoint = (checkpointing_on and
(worker.actor_task_counter %
worker.actor_checkpoint_interval == 0 and
method_name != "__ray_checkpoint__"))
# Execute the assigned method and save a checkpoint if necessary.
try:
method_returns = method(actor, *args)
except Exception:
# Save the checkpoint before allowing the method exception to be
# thrown.
if save_checkpoint:
save_and_log_checkpoint(worker, actor)
raise
else:
# Update the worker's internal state before executing the method in
# case the method throws an exception.
worker.actor_task_counter = task_counter + 1
# Once the actor executes a task, it counts as loaded.
worker.actor_loaded = True
# Execute the actor method.
return method(actor, *args)
# Save the checkpoint before returning the method's return values.
if save_checkpoint:
save_and_log_checkpoint(worker, actor)
return method_returns
return actor_method_executor
@@ -207,7 +266,8 @@ def fetch_and_register_actor(actor_class_key, worker):
actor_method_name).id()
temporary_executor = make_actor_method_executor(worker,
actor_method_name,
temporary_actor_method)
temporary_actor_method,
actor_imported=False)
worker.functions[driver_id][function_id] = (actor_method_name,
temporary_executor)
worker.num_task_executions[driver_id][function_id] = 0
@@ -218,7 +278,7 @@ def fetch_and_register_actor(actor_class_key, worker):
except Exception:
# If an exception was thrown when the actor was imported, we record the
# traceback and notify the scheduler of the failure.
traceback_str = ray.worker.format_error_message(traceback.format_exc())
traceback_str = ray.utils.format_error_message(traceback.format_exc())
# Log the error message.
push_error_to_driver(worker.redis_client, "register_actor_signatures",
traceback_str, driver_id,
@@ -238,7 +298,8 @@ def fetch_and_register_actor(actor_class_key, worker):
function_id = compute_actor_method_function_id(
class_name, actor_method_name).id()
executor = make_actor_method_executor(worker, actor_method_name,
actor_method)
actor_method,
actor_imported=True)
worker.functions[driver_id][function_id] = (actor_method_name,
executor)
# We do not set worker.function_properties[driver_id][function_id]
@@ -412,18 +473,6 @@ class ActorMethod(object):
dependency=self._actor._ray_actor_cursor)
# Checkpoint methods do not take in the state of the previous actor method
# as an explicit data dependency.
class CheckpointMethod(ActorMethod):
def remote(self):
# A checkpoint's arguments are the current task counter and the
# object ID of the preceding task. The latter is an implicit data
# dependency, since the checkpoint method can run at any time.
args = [self._actor._ray_actor_counter,
[self._actor._ray_actor_cursor]]
return self._actor._actor_method_call(self._method_name, args=args)
class ActorHandleWrapper(object):
"""A wrapper for the contents of an ActorHandle.
@@ -455,9 +504,6 @@ def wrap_actor_handle(actor_handle):
Returns:
An ActorHandleWrapper instance that stores the ActorHandle's fields.
"""
if actor_handle._ray_checkpoint_interval > 0:
raise Exception("Checkpointing not yet supported for distributed "
"actor handles.")
wrapper = ActorHandleWrapper(
actor_handle._ray_actor_id,
compute_actor_handle_id(actor_handle._ray_actor_handle_id,
@@ -600,12 +646,6 @@ def make_actor_handle_class(class_name):
self._ray_actor_counter += 1
self._ray_actor_cursor = object_ids.pop()
# Submit a checkpoint task if it is time to do so.
if (self._ray_checkpoint_interval > 1 and
self._ray_actor_counter % self._ray_checkpoint_interval ==
0):
self.__ray_checkpoint__.remote()
# The last object returned is the dummy object that should be
# passed in to the next actor method. Do not return it to the user.
if len(object_ids) == 1:
@@ -629,10 +669,7 @@ def make_actor_handle_class(class_name):
# this was causing cyclic references which were preventing
# object deallocation from behaving in a predictable
# manner.
if attr == "__ray_checkpoint__":
actor_method_cls = CheckpointMethod
else:
actor_method_cls = ActorMethod
actor_method_cls = ActorMethod
return actor_method_cls(self, attr)
except AttributeError:
pass
@@ -755,9 +792,6 @@ def make_actor(cls, resources, checkpoint_interval):
"actor placement.")
if checkpoint_interval == 0:
raise Exception("checkpoint_interval must be greater than 0.")
# Add one to the checkpoint interval since we will insert a mock task for
# every checkpoint.
checkpoint_interval += 1
# Modify the class to have an additional method that will be used for
# terminating the worker.
@@ -802,97 +836,58 @@ def make_actor(cls, resources, checkpoint_interval):
actor_object = checkpoint
return actor_object
def __ray_checkpoint__(self, task_counter, previous_object_id):
"""Save or resume a stored checkpoint.
def __ray_checkpoint__(self):
"""Save a checkpoint.
This task checkpoints the current state of the actor. If the actor
has not yet executed to `task_counter`, then the task instead
attempts to resume from a saved checkpoint that matches
`task_counter`. If the most recently saved checkpoint is earlier
than `task_counter`, the task requests reconstruction of the tasks
that executed since the previous checkpoint and before
`task_counter`.
Args:
self: An instance of the actor class.
task_counter: The index assigned to this checkpoint method.
previous_object_id: The dummy object returned by the task that
immediately precedes this checkpoint.
Returns:
A bool representing whether the checkpoint was successfully
loaded (whether the actor can safely execute the next task)
and an Exception instance, if one was thrown.
This task saves the current state of the actor, the current task
frontier according to the local scheduler, and the checkpoint index
(number of tasks executed so far).
"""
worker = ray.worker.global_worker
previous_object_id = previous_object_id[0]
plasma_id = plasma.ObjectID(previous_object_id.id())
checkpoint_index = worker.actor_task_counter
# Get the state to save.
checkpoint = self.__ray_save_checkpoint__()
# Get the current task frontier, per actor handle.
# NOTE(swang): This only includes actor handles that the local
# scheduler has seen. Handle IDs for which no task has yet reached
# the local scheduler will not be included, and may not be runnable
# on checkpoint resumption.
actor_id = ray.local_scheduler.ObjectID(worker.actor_id)
frontier = worker.local_scheduler_client.get_actor_frontier(
actor_id)
# Save the checkpoint in Redis. TODO(rkn): Checkpoints
# should not be stored in Redis. Fix this.
set_actor_checkpoint(worker, worker.actor_id, checkpoint_index,
checkpoint, frontier)
# Initialize the return values. `actor_checkpoint_failed` will be
# set to True if we fail to load the checkpoint. `error` will be
# set to the Exception, if one is thrown.
actor_checkpoint_failed = False
error_to_return = None
def __ray_checkpoint_restore__(self):
"""Restore a checkpoint.
# Save or resume the checkpoint.
if worker.actor_loaded:
# The actor has loaded, so we are running the normal execution.
# Save the checkpoint.
print("Saving actor checkpoint. actor_counter = {}."
.format(task_counter))
actor_key = b"Actor:" + worker.actor_id
This task looks for a saved checkpoint and if found, restores the
state of the actor, the task frontier in the local scheduler, and
the checkpoint index (number of tasks executed so far).
try:
checkpoint = worker.actors[
worker.actor_id].__ray_save_checkpoint__()
# Save the checkpoint in Redis. TODO(rkn): Checkpoints
# should not be stored in Redis. Fix this.
worker.redis_client.hset(
actor_key,
"checkpoint_{}".format(task_counter),
checkpoint)
# Remove the previous checkpoints if there is one.
checkpoint_indices = get_checkpoint_indices(
worker, worker.actor_id)
for index in checkpoint_indices:
if index < task_counter:
worker.redis_client.hdel(
actor_key, "checkpoint_{}".format(index))
# An exception was thrown. Save the error.
except Exception as error:
# Checkpoint saves should not block execution on the actor,
# so we still consider the task successful.
error_to_return = error
else:
# The actor has not yet loaded. Try loading it from the most
# recent checkpoint.
checkpoint_index, checkpoint = get_actor_checkpoint(
worker, worker.actor_id)
if checkpoint_index == task_counter:
# The checkpoint matches ours. Resume the actor instance.
try:
actor = (worker.actor_class.
__ray_restore_from_checkpoint__(checkpoint))
worker.actors[worker.actor_id] = actor
# An exception was thrown. Save the error.
except Exception as error:
# We could not resume the checkpoint, so count the task
# as failed.
actor_checkpoint_failed = True
error_to_return = error
else:
# We cannot resume a mismatching checkpoint, so count the
# task as failed.
actor_checkpoint_failed = True
Returns:
A bool indicating whether a checkpoint was resumed.
"""
worker = ray.worker.global_worker
# Get the most recent checkpoint stored, if any.
checkpoint_index, checkpoint, frontier = get_actor_checkpoint(
worker, worker.actor_id)
# Try to resume from the checkpoint.
checkpoint_resumed = False
if checkpoint_index is not None:
# Load the actor state from the checkpoint.
worker.actors[worker.actor_id] = (
worker.actor_class.__ray_restore_from_checkpoint__(
checkpoint))
# Set the number of tasks executed so far.
worker.actor_task_counter = checkpoint_index
# Set the actor frontier in the local scheduler.
worker.local_scheduler_client.set_actor_frontier(frontier)
checkpoint_resumed = True
# Fall back to lineage reconstruction if we were unable to load the
# checkpoint.
if actor_checkpoint_failed:
worker.local_scheduler_client.reconstruct_object(
plasma_id.binary())
worker.local_scheduler_client.notify_unblocked()
return actor_checkpoint_failed, error_to_return
return checkpoint_resumed
Class.__module__ = cls.__module__
Class.__name__ = cls.__name__
+22
View File
@@ -20,6 +20,28 @@ def _random_string():
return np.random.bytes(20)
def format_error_message(exception_message, task_exception=False):
    """Improve the formatting of an exception thrown by a remote function.

    This method takes a traceback from an exception and, for exceptions
    raised inside of tasks, makes it nicer by removing a few uninformative
    lines. Otherwise the message is returned unchanged.

    Args:
        exception_message (str): A message generated by traceback.format_exc().
        task_exception (bool): If True, drop lines 1 through 4 of the message
            and keep the first line and everything from line 5 onwards.

    Returns:
        A string of the formatted exception message.
    """
    lines = exception_message.split("\n")
    if task_exception:
        # For errors that occur inside of tasks, remove lines 1, 2, 3, and 4,
        # which are always the same, they just contain information about the
        # main loop.
        lines = lines[0:1] + lines[5:]
    return "\n".join(lines)
def push_error_to_driver(redis_client, error_type, message, driver_id=None,
data=None):
"""Push an error message to the driver to be printed in the background.
+7 -42
View File
@@ -222,14 +222,6 @@ class Worker(object):
self.make_actor = None
self.actors = {}
self.actor_task_counter = 0
# Whether an actor instance has been loaded yet. The actor counts as
# loaded once it has either executed its first task or successfully
# resumed from a checkpoint.
self.actor_loaded = False
# This field is used to report actor checkpoint failure for the last
# task assigned. Workers are not assigned a task on startup, so we
# initialize to False.
self.actor_checkpoint_failed = False
# The number of threads Plasma should use when putting an object in the
# object store.
self.memcopy_threads = 12
@@ -755,7 +747,7 @@ class Worker(object):
except Exception as e:
self._handle_process_task_failure(
function_id, return_object_ids, e,
format_error_message(traceback.format_exc()))
ray.utils.format_error_message(traceback.format_exc()))
return
# Execute the task.
@@ -765,15 +757,15 @@ class Worker(object):
outputs = function_executor.executor(arguments)
else:
outputs = function_executor(
dummy_return_id, task.actor_counter(),
dummy_return_id,
self.actors[task.actor_id().id()],
*arguments)
except Exception as e:
# Determine whether the exception occurred during a task, not an
# actor method.
task_exception = task.actor_id().id() == NIL_ACTOR_ID
traceback_str = format_error_message(traceback.format_exc(),
task_exception=task_exception)
traceback_str = ray.utils.format_error_message(
traceback.format_exc(), task_exception=task_exception)
self._handle_process_task_failure(function_id, return_object_ids,
e, traceback_str)
return
@@ -791,7 +783,7 @@ class Worker(object):
except Exception as e:
self._handle_process_task_failure(
function_id, return_object_ids, e,
format_error_message(traceback.format_exc()))
ray.utils.format_error_message(traceback.format_exc()))
def _handle_process_task_failure(self, function_id, return_object_ids,
error, backtrace):
@@ -863,12 +855,7 @@ class Worker(object):
A task from the local scheduler.
"""
with log_span("ray:get_task", worker=self):
task = self.local_scheduler_client.get_task(
self.actor_checkpoint_failed)
# We assume that the task is not a checkpoint, or that if it is,
# that the task will succeed. The checkpoint task executor is
# responsible for reporting task failure to the local scheduler.
self.actor_checkpoint_failed = False
task = self.local_scheduler_client.get_task()
# Automatically restrict the GPUs available to this task.
ray.utils.set_cuda_visible_devices(ray.get_gpu_ids())
@@ -1613,7 +1600,7 @@ def fetch_and_register_remote_function(key, worker=global_worker):
except Exception:
# If an exception was thrown when the remote function was imported, we
# record the traceback and notify the scheduler of the failure.
traceback_str = format_error_message(traceback.format_exc())
traceback_str = ray.utils.format_error_message(traceback.format_exc())
# Log the error message.
ray.utils.push_error_to_driver(worker.redis_client,
"register_remote_function",
@@ -2351,28 +2338,6 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
return ready_ids, remaining_ids
def format_error_message(exception_message, task_exception=False):
"""Improve the formatting of an exception thrown by a remote function.
This method takes a traceback from an exception and makes it nicer by
removing a few uninformative lines and adding some space to indent the
remaining lines nicely.
Args:
exception_message (str): A message generated by traceback.format_exc().
Returns:
A string of the formatted exception message.
"""
lines = exception_message.split("\n")
if task_exception:
# For errors that occur inside of tasks, remove lines 1, 2, 3, and 4,
# which are always the same, they just contain information about the
# main loop.
lines = lines[0:1] + lines[5:]
return "\n".join(lines)
def _submit_task(function_id, args, worker=global_worker):
"""This is a wrapper around worker.submit_task.