mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 06:15:23 +08:00
Implement actor checkpointing (#3839)
* Implement Actor checkpointing * docs * fix * fix * fix * move restore-from-checkpoint to HandleActorStateTransition * Revert "move restore-from-checkpoint to HandleActorStateTransition" This reverts commit 9aa4447c1e3e321f42a1d895d72f17098b72de12. * resubmit waiting tasks when actor frontier restored * add doc about num_actor_checkpoints_to_keep=1 * add num_actor_checkpoints_to_keep to Cython * add checkpoint_expired api * check if actor class is abstract * change checkpoint_ids to long string * implement java * Refactor to delay actor creation publish until checkpoint is resumed * debug, lint * Erase from checkpoints to restore if task fails * fix lint * update comments * avoid duplicated actor notification log * fix unintended change * add actor_id to checkpoint_expired * small java updates * make checkpoint info per actor * lint * Remove logging * Remove old actor checkpointing Python code, move new checkpointing code to FunctionActionManager * Replace old actor checkpointing tests * Fix test and lint * address comments * consolidate kill_actor * Remove __ray_checkpoint__ * fix non-ascii char * Loosen test checks * fix java * fix sphinx-build
This commit is contained in:
+104
-40
@@ -510,8 +510,7 @@ class FunctionActorManager(object):
|
||||
self._worker.redis_client.hmset(key, actor_class_info)
|
||||
self._worker.redis_client.rpush("Exports", key)
|
||||
|
||||
def export_actor_class(self, Class, actor_method_names,
|
||||
checkpoint_interval):
|
||||
def export_actor_class(self, Class, actor_method_names):
|
||||
function_descriptor = FunctionDescriptor.from_class(Class)
|
||||
# `task_driver_id` shouldn't be NIL, unless:
|
||||
# 1) This worker isn't an actor;
|
||||
@@ -528,7 +527,6 @@ class FunctionActorManager(object):
|
||||
"class_name": Class.__name__,
|
||||
"module": Class.__module__,
|
||||
"class": pickle.dumps(Class),
|
||||
"checkpoint_interval": checkpoint_interval,
|
||||
"driver_id": driver_id.binary(),
|
||||
"actor_method_names": json.dumps(list(actor_method_names))
|
||||
}
|
||||
@@ -576,17 +574,16 @@ class FunctionActorManager(object):
|
||||
actor_class_key: The key in Redis to use to fetch the actor.
|
||||
"""
|
||||
actor_id = self._worker.actor_id
|
||||
(driver_id_str, class_name, module, pickled_class, checkpoint_interval,
|
||||
(driver_id_str, class_name, module, pickled_class,
|
||||
actor_method_names) = self._worker.redis_client.hmget(
|
||||
actor_class_key, [
|
||||
"driver_id", "class_name", "module", "class",
|
||||
"checkpoint_interval", "actor_method_names"
|
||||
"actor_method_names"
|
||||
])
|
||||
|
||||
class_name = decode(class_name)
|
||||
module = decode(module)
|
||||
driver_id = ray.DriverID(driver_id_str)
|
||||
checkpoint_interval = int(checkpoint_interval)
|
||||
actor_method_names = json.loads(decode(actor_method_names))
|
||||
|
||||
# In Python 2, json loads strings as unicode, so convert them back to
|
||||
@@ -605,7 +602,6 @@ class FunctionActorManager(object):
|
||||
pass
|
||||
|
||||
self._worker.actors[actor_id] = TemporaryActor()
|
||||
self._worker.actor_checkpoint_interval = checkpoint_interval
|
||||
|
||||
def temporary_actor_method(*xs):
|
||||
raise Exception(
|
||||
@@ -694,48 +690,116 @@ class FunctionActorManager(object):
|
||||
# to execute.
|
||||
self._worker.actor_task_counter += 1
|
||||
|
||||
# If this is the first task to execute on the actor, try to resume
|
||||
# from a checkpoint.
|
||||
# Current __init__ will be called by default. So the real function
|
||||
# call will start from 2.
|
||||
if actor_imported and self._worker.actor_task_counter == 2:
|
||||
checkpoint_resumed = ray.actor.restore_and_log_checkpoint(
|
||||
self._worker, actor)
|
||||
if checkpoint_resumed:
|
||||
# NOTE(swang): Since we did not actually execute the
|
||||
# __init__ method, this will put None as the return value.
|
||||
# If the __init__ method is supposed to return multiple
|
||||
# values, an exception will be logged.
|
||||
return
|
||||
|
||||
# Determine whether we should checkpoint the actor.
|
||||
checkpointing_on = (actor_imported
|
||||
and self._worker.actor_checkpoint_interval > 0)
|
||||
# We should checkpoint the actor if user checkpointing is on, we've
|
||||
# executed checkpoint_interval tasks since the last checkpoint, and
|
||||
# the method we're about to execute is not a checkpoint.
|
||||
save_checkpoint = (checkpointing_on
|
||||
and (self._worker.actor_task_counter %
|
||||
self._worker.actor_checkpoint_interval == 0
|
||||
and method_name != "__ray_checkpoint__"))
|
||||
|
||||
# Execute the assigned method and save a checkpoint if necessary.
|
||||
try:
|
||||
if is_class_method(method):
|
||||
method_returns = method(*args)
|
||||
else:
|
||||
method_returns = method(actor, *args)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
# Save the checkpoint before allowing the method exception
|
||||
# to be thrown.
|
||||
if save_checkpoint:
|
||||
ray.actor.save_and_log_checkpoint(self._worker, actor)
|
||||
raise
|
||||
if isinstance(actor, ray.actor.Checkpointable):
|
||||
self._save_and_log_checkpoint(actor)
|
||||
raise e
|
||||
else:
|
||||
# Save the checkpoint before returning the method's return
|
||||
# values.
|
||||
if save_checkpoint:
|
||||
ray.actor.save_and_log_checkpoint(self._worker, actor)
|
||||
# Handle any checkpointing operations before storing the
|
||||
# method's return values.
|
||||
# NOTE(swang): If method_returns is a pointer to the actor's
|
||||
# state and the checkpointing operations can modify the return
|
||||
# values if they mutate the actor's state. Is this okay?
|
||||
if isinstance(actor, ray.actor.Checkpointable):
|
||||
# If this is the first task to execute on the actor, try to
|
||||
# resume from a checkpoint.
|
||||
if self._worker.actor_task_counter == 1:
|
||||
if actor_imported:
|
||||
self._restore_and_log_checkpoint(actor)
|
||||
else:
|
||||
# Save the checkpoint before returning the method's
|
||||
# return values.
|
||||
self._save_and_log_checkpoint(actor)
|
||||
return method_returns
|
||||
|
||||
return actor_method_executor
|
||||
|
||||
def _save_and_log_checkpoint(self, actor):
|
||||
"""Save an actor checkpoint if necessary and log any errors.
|
||||
|
||||
Args:
|
||||
actor: The actor to checkpoint.
|
||||
|
||||
Returns:
|
||||
The result of the actor's user-defined `save_checkpoint` method.
|
||||
"""
|
||||
actor_id = self._worker.actor_id
|
||||
checkpoint_info = self._worker.actor_checkpoint_info[actor_id]
|
||||
checkpoint_info.num_tasks_since_last_checkpoint += 1
|
||||
now = int(1000 * time.time())
|
||||
checkpoint_context = ray.actor.CheckpointContext(
|
||||
actor_id, checkpoint_info.num_tasks_since_last_checkpoint,
|
||||
now - checkpoint_info.last_checkpoint_timestamp)
|
||||
# If we should take a checkpoint, notify raylet to prepare a checkpoint
|
||||
# and then call `save_checkpoint`.
|
||||
if actor.should_checkpoint(checkpoint_context):
|
||||
try:
|
||||
now = int(1000 * time.time())
|
||||
checkpoint_id = (self._worker.raylet_client.
|
||||
prepare_actor_checkpoint(actor_id))
|
||||
checkpoint_info.checkpoint_ids.append(checkpoint_id)
|
||||
actor.save_checkpoint(actor_id, checkpoint_id)
|
||||
if (len(checkpoint_info.checkpoint_ids) >
|
||||
ray._config.num_actor_checkpoints_to_keep()):
|
||||
actor.checkpoint_expired(
|
||||
actor_id,
|
||||
checkpoint_info.checkpoint_ids.pop(0),
|
||||
)
|
||||
checkpoint_info.num_tasks_since_last_checkpoint = 0
|
||||
checkpoint_info.last_checkpoint_timestamp = now
|
||||
except Exception:
|
||||
# Checkpoint save or reload failed. Notify the driver.
|
||||
traceback_str = ray.utils.format_error_message(
|
||||
traceback.format_exc())
|
||||
ray.utils.push_error_to_driver(
|
||||
self._worker,
|
||||
ray_constants.CHECKPOINT_PUSH_ERROR,
|
||||
traceback_str,
|
||||
driver_id=self._worker.task_driver_id)
|
||||
|
||||
def _restore_and_log_checkpoint(self, actor):
|
||||
"""Restore an actor from a checkpoint if available and log any errors.
|
||||
|
||||
This should only be called on workers that have just executed an actor
|
||||
creation task.
|
||||
|
||||
Args:
|
||||
actor: The actor to restore from a checkpoint.
|
||||
"""
|
||||
actor_id = self._worker.actor_id
|
||||
try:
|
||||
checkpoints = ray.actor.get_checkpoints_for_actor(actor_id)
|
||||
if len(checkpoints) > 0:
|
||||
# If we found previously saved checkpoints for this actor,
|
||||
# call the `load_checkpoint` callback.
|
||||
checkpoint_id = actor.load_checkpoint(actor_id, checkpoints)
|
||||
if checkpoint_id is not None:
|
||||
# Check that the returned checkpoint id is in the
|
||||
# `available_checkpoints` list.
|
||||
msg = (
|
||||
"`load_checkpoint` must return a checkpoint id that " +
|
||||
"exists in the `available_checkpoints` list, or eone.")
|
||||
assert any(checkpoint_id == checkpoint.checkpoint_id
|
||||
for checkpoint in checkpoints), msg
|
||||
# Notify raylet that this actor has been resumed from
|
||||
# a checkpoint.
|
||||
(self._worker.raylet_client.
|
||||
notify_actor_resumed_from_checkpoint(
|
||||
actor_id, checkpoint_id))
|
||||
except Exception:
|
||||
# Checkpoint save or reload failed. Notify the driver.
|
||||
traceback_str = ray.utils.format_error_message(
|
||||
traceback.format_exc())
|
||||
ray.utils.push_error_to_driver(
|
||||
self._worker,
|
||||
ray_constants.CHECKPOINT_PUSH_ERROR,
|
||||
traceback_str,
|
||||
driver_id=self._worker.task_driver_id)
|
||||
|
||||
Reference in New Issue
Block a user