Implement actor checkpointing (#3839)

* Implement Actor checkpointing

* docs

* fix

* fix

* fix

* move restore-from-checkpoint to HandleActorStateTransition

* Revert "move restore-from-checkpoint to HandleActorStateTransition"

This reverts commit 9aa4447c1e3e321f42a1d895d72f17098b72de12.

* resubmit waiting tasks when actor frontier restored

* add doc about num_actor_checkpoints_to_keep=1

* add num_actor_checkpoints_to_keep to Cython

* add checkpoint_expired api

* check if actor class is abstract

* change checkpoint_ids to long string

* implement java

* Refactor to delay actor creation publish until checkpoint is resumed

* debug, lint

* Erase from checkpoints to restore if task fails

* fix lint

* update comments

* avoid duplicated actor notification log

* fix unintended change

* add actor_id to checkpoint_expired

* small java updates

* make checkpoint info per actor

* lint

* Remove logging

* Remove old actor checkpointing Python code, move new checkpointing code to FunctionActionManager

* Replace old actor checkpointing tests

* Fix test and lint

* address comments

* consolidate kill_actor

* Remove __ray_checkpoint__

* fix non-ascii char

* Loosen test checks

* fix java

* fix sphinx-build
This commit is contained in:
Hao Chen
2019-02-13 19:39:02 +08:00
committed by GitHub
parent 57dcd3033e
commit f31a79f3f7
41 changed files with 1708 additions and 490 deletions
+104 -40
View File
@@ -510,8 +510,7 @@ class FunctionActorManager(object):
self._worker.redis_client.hmset(key, actor_class_info)
self._worker.redis_client.rpush("Exports", key)
def export_actor_class(self, Class, actor_method_names,
checkpoint_interval):
def export_actor_class(self, Class, actor_method_names):
function_descriptor = FunctionDescriptor.from_class(Class)
# `task_driver_id` shouldn't be NIL, unless:
# 1) This worker isn't an actor;
@@ -528,7 +527,6 @@ class FunctionActorManager(object):
"class_name": Class.__name__,
"module": Class.__module__,
"class": pickle.dumps(Class),
"checkpoint_interval": checkpoint_interval,
"driver_id": driver_id.binary(),
"actor_method_names": json.dumps(list(actor_method_names))
}
@@ -576,17 +574,16 @@ class FunctionActorManager(object):
actor_class_key: The key in Redis to use to fetch the actor.
"""
actor_id = self._worker.actor_id
(driver_id_str, class_name, module, pickled_class, checkpoint_interval,
(driver_id_str, class_name, module, pickled_class,
actor_method_names) = self._worker.redis_client.hmget(
actor_class_key, [
"driver_id", "class_name", "module", "class",
"checkpoint_interval", "actor_method_names"
"actor_method_names"
])
class_name = decode(class_name)
module = decode(module)
driver_id = ray.DriverID(driver_id_str)
checkpoint_interval = int(checkpoint_interval)
actor_method_names = json.loads(decode(actor_method_names))
# In Python 2, json loads strings as unicode, so convert them back to
@@ -605,7 +602,6 @@ class FunctionActorManager(object):
pass
self._worker.actors[actor_id] = TemporaryActor()
self._worker.actor_checkpoint_interval = checkpoint_interval
def temporary_actor_method(*xs):
raise Exception(
@@ -694,48 +690,116 @@ class FunctionActorManager(object):
# to execute.
self._worker.actor_task_counter += 1
# If this is the first task to execute on the actor, try to resume
# from a checkpoint.
# Current __init__ will be called by default. So the real function
# call will start from 2.
if actor_imported and self._worker.actor_task_counter == 2:
checkpoint_resumed = ray.actor.restore_and_log_checkpoint(
self._worker, actor)
if checkpoint_resumed:
# NOTE(swang): Since we did not actually execute the
# __init__ method, this will put None as the return value.
# If the __init__ method is supposed to return multiple
# values, an exception will be logged.
return
# Determine whether we should checkpoint the actor.
checkpointing_on = (actor_imported
and self._worker.actor_checkpoint_interval > 0)
# We should checkpoint the actor if user checkpointing is on, we've
# executed checkpoint_interval tasks since the last checkpoint, and
# the method we're about to execute is not a checkpoint.
save_checkpoint = (checkpointing_on
and (self._worker.actor_task_counter %
self._worker.actor_checkpoint_interval == 0
and method_name != "__ray_checkpoint__"))
# Execute the assigned method and save a checkpoint if necessary.
try:
if is_class_method(method):
method_returns = method(*args)
else:
method_returns = method(actor, *args)
except Exception:
except Exception as e:
# Save the checkpoint before allowing the method exception
# to be thrown.
if save_checkpoint:
ray.actor.save_and_log_checkpoint(self._worker, actor)
raise
if isinstance(actor, ray.actor.Checkpointable):
self._save_and_log_checkpoint(actor)
raise e
else:
# Save the checkpoint before returning the method's return
# values.
if save_checkpoint:
ray.actor.save_and_log_checkpoint(self._worker, actor)
# Handle any checkpointing operations before storing the
# method's return values.
# NOTE(swang): If method_returns is a pointer to the actor's
# state and the checkpointing operations can modify the return
# values if they mutate the actor's state. Is this okay?
if isinstance(actor, ray.actor.Checkpointable):
# If this is the first task to execute on the actor, try to
# resume from a checkpoint.
if self._worker.actor_task_counter == 1:
if actor_imported:
self._restore_and_log_checkpoint(actor)
else:
# Save the checkpoint before returning the method's
# return values.
self._save_and_log_checkpoint(actor)
return method_returns
return actor_method_executor
def _save_and_log_checkpoint(self, actor):
"""Save an actor checkpoint if necessary and log any errors.
Args:
actor: The actor to checkpoint.
Returns:
The result of the actor's user-defined `save_checkpoint` method.
"""
actor_id = self._worker.actor_id
checkpoint_info = self._worker.actor_checkpoint_info[actor_id]
checkpoint_info.num_tasks_since_last_checkpoint += 1
now = int(1000 * time.time())
checkpoint_context = ray.actor.CheckpointContext(
actor_id, checkpoint_info.num_tasks_since_last_checkpoint,
now - checkpoint_info.last_checkpoint_timestamp)
# If we should take a checkpoint, notify raylet to prepare a checkpoint
# and then call `save_checkpoint`.
if actor.should_checkpoint(checkpoint_context):
try:
now = int(1000 * time.time())
checkpoint_id = (self._worker.raylet_client.
prepare_actor_checkpoint(actor_id))
checkpoint_info.checkpoint_ids.append(checkpoint_id)
actor.save_checkpoint(actor_id, checkpoint_id)
if (len(checkpoint_info.checkpoint_ids) >
ray._config.num_actor_checkpoints_to_keep()):
actor.checkpoint_expired(
actor_id,
checkpoint_info.checkpoint_ids.pop(0),
)
checkpoint_info.num_tasks_since_last_checkpoint = 0
checkpoint_info.last_checkpoint_timestamp = now
except Exception:
# Checkpoint save or reload failed. Notify the driver.
traceback_str = ray.utils.format_error_message(
traceback.format_exc())
ray.utils.push_error_to_driver(
self._worker,
ray_constants.CHECKPOINT_PUSH_ERROR,
traceback_str,
driver_id=self._worker.task_driver_id)
def _restore_and_log_checkpoint(self, actor):
"""Restore an actor from a checkpoint if available and log any errors.
This should only be called on workers that have just executed an actor
creation task.
Args:
actor: The actor to restore from a checkpoint.
"""
actor_id = self._worker.actor_id
try:
checkpoints = ray.actor.get_checkpoints_for_actor(actor_id)
if len(checkpoints) > 0:
# If we found previously saved checkpoints for this actor,
# call the `load_checkpoint` callback.
checkpoint_id = actor.load_checkpoint(actor_id, checkpoints)
if checkpoint_id is not None:
# Check that the returned checkpoint id is in the
# `available_checkpoints` list.
msg = (
"`load_checkpoint` must return a checkpoint id that " +
"exists in the `available_checkpoints` list, or eone.")
assert any(checkpoint_id == checkpoint.checkpoint_id
for checkpoint in checkpoints), msg
# Notify raylet that this actor has been resumed from
# a checkpoint.
(self._worker.raylet_client.
notify_actor_resumed_from_checkpoint(
actor_id, checkpoint_id))
except Exception:
# Checkpoint save or reload failed. Notify the driver.
traceback_str = ray.utils.format_error_message(
traceback.format_exc())
ray.utils.push_error_to_driver(
self._worker,
ray_constants.CHECKPOINT_PUSH_ERROR,
traceback_str,
driver_id=self._worker.task_driver_id)