mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 14:12:00 +08:00
Implement actor checkpointing (#3839)
* Implement Actor checkpointing * docs * fix * fix * fix * move restore-from-checkpoint to HandleActorStateTransition * Revert "move restore-from-checkpoint to HandleActorStateTransition" This reverts commit 9aa4447c1e3e321f42a1d895d72f17098b72de12. * resubmit waiting tasks when actor frontier restored * add doc about num_actor_checkpoints_to_keep=1 * add num_actor_checkpoints_to_keep to Cython * add checkpoint_expired api * check if actor class is abstract * change checkpoint_ids to long string * implement java * Refactor to delay actor creation publish until checkpoint is resumed * debug, lint * Erase from checkpoints to restore if task fails * fix lint * update comments * avoid duplicated actor notification log * fix unintended change * add actor_id to checkpoint_expired * small java updates * make checkpoint info per actor * lint * Remove logging * Remove old actor checkpointing Python code, move new checkpointing code to FunctionActionManager * Replace old actor checkpointing tests * Fix test and lint * address comments * consolidate kill_actor * Remove __ray_checkpoint__ * fix non-ascii char * Loosen test checks * fix java * fix sphinx-build
This commit is contained in:
+24
-5
@@ -49,9 +49,20 @@ except ImportError as e:
|
||||
modin_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "modin")
|
||||
sys.path.append(modin_path)
|
||||
|
||||
from ray._raylet import (UniqueID, ObjectID, DriverID, ClientID, ActorID,
|
||||
ActorHandleID, FunctionID, ActorClassID, TaskID,
|
||||
_ID_TYPES, Config as _Config) # noqa: E402
|
||||
from ray._raylet import (
|
||||
ActorCheckpointID,
|
||||
ActorClassID,
|
||||
ActorHandleID,
|
||||
ActorID,
|
||||
ClientID,
|
||||
Config as _Config,
|
||||
DriverID,
|
||||
FunctionID,
|
||||
ObjectID,
|
||||
TaskID,
|
||||
UniqueID,
|
||||
_ID_TYPES,
|
||||
) # noqa: E402
|
||||
|
||||
_config = _Config()
|
||||
|
||||
@@ -82,8 +93,16 @@ __all__ = [
|
||||
]
|
||||
|
||||
__all__ += [
|
||||
"UniqueID", "ObjectID", "DriverID", "ClientID", "ActorID", "ActorHandleID",
|
||||
"FunctionID", "ActorClassID", "TaskID"
|
||||
"ActorCheckpointID",
|
||||
"ActorClassID",
|
||||
"ActorHandleID",
|
||||
"ActorID",
|
||||
"ClientID",
|
||||
"DriverID",
|
||||
"FunctionID",
|
||||
"ObjectID",
|
||||
"TaskID",
|
||||
"UniqueID",
|
||||
]
|
||||
|
||||
import ctypes # noqa: E402
|
||||
|
||||
+32
-5
@@ -19,12 +19,31 @@ include "includes/ray_config.pxi"
|
||||
include "includes/task.pxi"
|
||||
|
||||
from ray.includes.common cimport (
|
||||
CUniqueID, CTaskID, CObjectID, CFunctionID, CActorClassID, CActorID,
|
||||
CActorHandleID, CWorkerID, CDriverID, CConfigID, CClientID,
|
||||
CLanguage, CRayStatus, LANGUAGE_CPP, LANGUAGE_JAVA, LANGUAGE_PYTHON)
|
||||
CActorCheckpointID,
|
||||
CActorClassID,
|
||||
CActorHandleID,
|
||||
CActorID,
|
||||
CClientID,
|
||||
CConfigID,
|
||||
CDriverID,
|
||||
CFunctionID,
|
||||
CLanguage,
|
||||
CObjectID,
|
||||
CRayStatus,
|
||||
CTaskID,
|
||||
CUniqueID,
|
||||
CWorkerID,
|
||||
LANGUAGE_CPP,
|
||||
LANGUAGE_JAVA,
|
||||
LANGUAGE_PYTHON,
|
||||
)
|
||||
from ray.includes.libraylet cimport (
|
||||
CRayletClient, GCSProfileTableDataT, GCSProfileEventT,
|
||||
ResourceMappingType, WaitResultPair)
|
||||
CRayletClient,
|
||||
GCSProfileEventT,
|
||||
GCSProfileTableDataT,
|
||||
ResourceMappingType,
|
||||
WaitResultPair,
|
||||
)
|
||||
from ray.includes.task cimport CTaskSpecification
|
||||
from ray.includes.ray_config cimport RayConfig
|
||||
from ray.utils import decode
|
||||
@@ -303,6 +322,14 @@ cdef class RayletClient:
|
||||
cdef c_vector[CObjectID] free_ids = ObjectIDsToVector(object_ids)
|
||||
check_status(self.client.get().FreeObjects(free_ids, local_only))
|
||||
|
||||
def prepare_actor_checkpoint(self, ActorID actor_id):
    """Ask the raylet to prepare a checkpoint for the given actor.

    Args:
        actor_id: ID of the actor to checkpoint.

    Returns:
        The ID of the new checkpoint, as filled in by the raylet client.
    """
    cdef CActorCheckpointID checkpoint_id
    check_status(self.client.get().PrepareActorCheckpoint(
        actor_id.data, checkpoint_id))
    # Wrap the native ID in the matching Python ID type (not ObjectID), so
    # the value has the type that notify_actor_resumed_from_checkpoint
    # accepts.
    return ActorCheckpointID.from_native(checkpoint_id)
|
||||
|
||||
def notify_actor_resumed_from_checkpoint(self, ActorID actor_id,
                                         ActorCheckpointID checkpoint_id):
    """Notify the raylet that an actor has resumed from a checkpoint.

    Args:
        actor_id: ID of the actor that was resumed.
        checkpoint_id: ID of the checkpoint the actor resumed from.
    """
    check_status(
        self.client.get().NotifyActorResumedFromCheckpoint(
            actor_id.data, checkpoint_id.data))
|
||||
|
||||
@property
def language(self):
    """Language reported by the raylet client, via `Language.from_native`."""
    return Language.from_native(
        self.client.get().GetLanguage())
|
||||
|
||||
+144
-167
@@ -6,11 +6,13 @@ import copy
|
||||
import hashlib
|
||||
import inspect
|
||||
import logging
|
||||
import six
|
||||
import sys
|
||||
import threading
|
||||
import traceback
|
||||
|
||||
import ray.cloudpickle as pickle
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from collections import namedtuple
|
||||
|
||||
from ray.function_manager import FunctionDescriptor
|
||||
import ray.ray_constants as ray_constants
|
||||
import ray.signature as signature
|
||||
@@ -75,90 +77,6 @@ def compute_actor_handle_id_non_forked(actor_handle_id, current_task_id):
|
||||
return ActorHandleID(handle_id)
|
||||
|
||||
|
||||
def set_actor_checkpoint(worker, actor_id, checkpoint_index, checkpoint,
|
||||
frontier):
|
||||
"""Set the most recent checkpoint associated with a given actor ID.
|
||||
|
||||
Args:
|
||||
worker: The worker to use to get the checkpoint.
|
||||
actor_id: The actor ID of the actor to get the checkpoint for.
|
||||
checkpoint_index: The number of tasks included in the checkpoint.
|
||||
checkpoint: The state object to save.
|
||||
frontier: The task frontier at the time of the checkpoint.
|
||||
"""
|
||||
assert isinstance(actor_id, ActorID)
|
||||
actor_key = b"Actor:" + actor_id.binary()
|
||||
worker.redis_client.hmset(
|
||||
actor_key, {
|
||||
"checkpoint_index": checkpoint_index,
|
||||
"checkpoint": checkpoint,
|
||||
"frontier": frontier,
|
||||
})
|
||||
|
||||
|
||||
def save_and_log_checkpoint(worker, actor):
|
||||
"""Save a checkpoint on the actor and log any errors.
|
||||
|
||||
Args:
|
||||
worker: The worker to use to log errors.
|
||||
actor: The actor to checkpoint.
|
||||
"""
|
||||
try:
|
||||
actor.__ray_checkpoint__()
|
||||
except Exception:
|
||||
traceback_str = ray.utils.format_error_message(traceback.format_exc())
|
||||
# Log the error message.
|
||||
ray.utils.push_error_to_driver(
|
||||
worker,
|
||||
ray_constants.CHECKPOINT_PUSH_ERROR,
|
||||
traceback_str,
|
||||
driver_id=worker.task_driver_id)
|
||||
|
||||
|
||||
def restore_and_log_checkpoint(worker, actor):
|
||||
"""Restore an actor from a checkpoint and log any errors.
|
||||
|
||||
Args:
|
||||
worker: The worker to use to log errors.
|
||||
actor: The actor to restore.
|
||||
"""
|
||||
checkpoint_resumed = False
|
||||
try:
|
||||
checkpoint_resumed = actor.__ray_checkpoint_restore__()
|
||||
except Exception:
|
||||
traceback_str = ray.utils.format_error_message(traceback.format_exc())
|
||||
# Log the error message.
|
||||
ray.utils.push_error_to_driver(
|
||||
worker,
|
||||
ray_constants.CHECKPOINT_PUSH_ERROR,
|
||||
traceback_str,
|
||||
driver_id=worker.task_driver_id)
|
||||
return checkpoint_resumed
|
||||
|
||||
|
||||
def get_actor_checkpoint(worker, actor_id):
|
||||
"""Get the most recent checkpoint associated with a given actor ID.
|
||||
|
||||
Args:
|
||||
worker: The worker to use to get the checkpoint.
|
||||
actor_id: The actor ID of the actor to get the checkpoint for.
|
||||
|
||||
Returns:
|
||||
If a checkpoint exists, this returns a tuple of the number of tasks
|
||||
included in the checkpoint, the saved checkpoint state, and the
|
||||
task frontier at the time of the checkpoint. If no checkpoint
|
||||
exists, all objects are set to None. The checkpoint index is the number of tasks
|
||||
executed on the actor before the checkpoint was made.
|
||||
"""
|
||||
assert isinstance(actor_id, ActorID)
|
||||
actor_key = b"Actor:" + actor_id.binary()
|
||||
checkpoint_index, checkpoint, frontier = worker.redis_client.hmget(
|
||||
actor_key, ["checkpoint_index", "checkpoint", "frontier"])
|
||||
if checkpoint_index is not None:
|
||||
checkpoint_index = int(checkpoint_index)
|
||||
return checkpoint_index, checkpoint, frontier
|
||||
|
||||
|
||||
def method(*args, **kwargs):
|
||||
"""Annotate an actor method.
|
||||
|
||||
@@ -234,7 +152,6 @@ class ActorClass(object):
|
||||
additional methods added like __ray_terminate__).
|
||||
_class_id: The ID of this actor class.
|
||||
_class_name: The name of this class.
|
||||
_checkpoint_interval: The interval at which to checkpoint actor state.
|
||||
_num_cpus: The default number of CPUs required by the actor creation
|
||||
task.
|
||||
_num_gpus: The default number of GPUs required by the actor creation
|
||||
@@ -250,13 +167,11 @@ class ActorClass(object):
|
||||
each actor method.
|
||||
"""
|
||||
|
||||
def __init__(self, modified_class, class_id, checkpoint_interval,
|
||||
max_reconstructions, num_cpus, num_gpus, resources,
|
||||
actor_method_cpus):
|
||||
def __init__(self, modified_class, class_id, max_reconstructions, num_cpus,
|
||||
num_gpus, resources, actor_method_cpus):
|
||||
self._modified_class = modified_class
|
||||
self._class_id = class_id
|
||||
self._class_name = modified_class.__name__
|
||||
self._checkpoint_interval = checkpoint_interval
|
||||
self._max_reconstructions = max_reconstructions
|
||||
self._num_cpus = num_cpus
|
||||
self._num_gpus = num_gpus
|
||||
@@ -383,8 +298,7 @@ class ActorClass(object):
|
||||
# Export the actor.
|
||||
if not self._exported:
|
||||
worker.function_actor_manager.export_actor_class(
|
||||
self._modified_class, self._actor_method_names,
|
||||
self._checkpoint_interval)
|
||||
self._modified_class, self._actor_method_names)
|
||||
self._exported = True
|
||||
|
||||
resources = ray.utils.resources_from_resource_arguments(
|
||||
@@ -564,8 +478,6 @@ class ActorHandle(object):
|
||||
return getattr(worker.actors[self._ray_actor_id],
|
||||
method_name)(*copy.deepcopy(args))
|
||||
|
||||
is_actor_checkpoint_method = (method_name == "__ray_checkpoint__")
|
||||
|
||||
function_descriptor = FunctionDescriptor(
|
||||
self._ray_module_name, method_name, self._ray_class_name)
|
||||
with self._ray_actor_lock:
|
||||
@@ -575,7 +487,6 @@ class ActorHandle(object):
|
||||
actor_id=self._ray_actor_id,
|
||||
actor_handle_id=self._ray_actor_handle_id,
|
||||
actor_counter=self._ray_actor_counter,
|
||||
is_actor_checkpoint_method=is_actor_checkpoint_method,
|
||||
actor_creation_dummy_object_id=(
|
||||
self._ray_actor_creation_dummy_object_id),
|
||||
execution_dependencies=[self._ray_actor_cursor],
|
||||
@@ -770,7 +681,7 @@ class ActorHandle(object):
|
||||
|
||||
|
||||
def make_actor(cls, num_cpus, num_gpus, resources, actor_method_cpus,
|
||||
checkpoint_interval, max_reconstructions):
|
||||
max_reconstructions):
|
||||
# Give an error if cls is an old-style class.
|
||||
if not issubclass(cls, object):
|
||||
raise TypeError(
|
||||
@@ -778,13 +689,14 @@ def make_actor(cls, num_cpus, num_gpus, resources, actor_method_cpus,
|
||||
"classes. In Python 2, you must declare the class with "
|
||||
"'class ClassName(object):' instead of 'class ClassName:'.")
|
||||
|
||||
if checkpoint_interval is None:
|
||||
checkpoint_interval = -1
|
||||
if issubclass(cls, Checkpointable) and inspect.isabstract(cls):
|
||||
raise TypeError(
|
||||
"A checkpointable actor class should implement all abstract "
|
||||
"methods in the `Checkpointable` interface.")
|
||||
|
||||
if max_reconstructions is None:
|
||||
max_reconstructions = 0
|
||||
|
||||
if checkpoint_interval == 0:
|
||||
raise Exception("checkpoint_interval must be greater than 0.")
|
||||
if not (ray_constants.NO_RECONSTRUCTION <= max_reconstructions <=
|
||||
ray_constants.INFINITE_RECONSTRUCTION):
|
||||
raise Exception("max_reconstructions must be in range [%d, %d]." %
|
||||
@@ -804,26 +716,6 @@ def make_actor(cls, num_cpus, num_gpus, resources, actor_method_cpus,
|
||||
sys.exit(0)
|
||||
assert False, "This process should have terminated."
|
||||
|
||||
def __ray_save_checkpoint__(self):
|
||||
if hasattr(self, "__ray_save__"):
|
||||
object_to_serialize = self.__ray_save__()
|
||||
else:
|
||||
object_to_serialize = self
|
||||
return pickle.dumps(object_to_serialize)
|
||||
|
||||
@classmethod
|
||||
def __ray_restore_from_checkpoint__(cls, pickled_checkpoint):
|
||||
checkpoint = pickle.loads(pickled_checkpoint)
|
||||
if hasattr(cls, "__ray_restore__"):
|
||||
actor_object = cls.__new__(cls)
|
||||
actor_object.__ray_restore__(checkpoint)
|
||||
else:
|
||||
# TODO(rkn): It's possible that this will cause problems. When
|
||||
# you unpickle the same object twice, the two objects will not
|
||||
# have the same class.
|
||||
actor_object = checkpoint
|
||||
return actor_object
|
||||
|
||||
def __ray_checkpoint__(self):
|
||||
"""Save a checkpoint.
|
||||
|
||||
@@ -832,58 +724,143 @@ def make_actor(cls, num_cpus, num_gpus, resources, actor_method_cpus,
|
||||
(number of tasks executed so far).
|
||||
"""
|
||||
worker = ray.worker.global_worker
|
||||
checkpoint_index = worker.actor_task_counter
|
||||
# Get the state to save.
|
||||
checkpoint = self.__ray_save_checkpoint__()
|
||||
# Get the current task frontier, per actor handle.
|
||||
# NOTE(swang): This only includes actor handles that the local
|
||||
# scheduler has seen. Handle IDs for which no task has yet reached
|
||||
# the local scheduler will not be included, and may not be runnable
|
||||
# on checkpoint resumption.
|
||||
actor_id = worker.actor_id
|
||||
frontier = worker.raylet_client.get_actor_frontier(actor_id)
|
||||
# Save the checkpoint in Redis. TODO(rkn): Checkpoints
|
||||
# should not be stored in Redis. Fix this.
|
||||
set_actor_checkpoint(worker, worker.actor_id, checkpoint_index,
|
||||
checkpoint, frontier)
|
||||
|
||||
def __ray_checkpoint_restore__(self):
|
||||
"""Restore a checkpoint.
|
||||
|
||||
This task looks for a saved checkpoint and if found, restores the
|
||||
state of the actor, the task frontier in the local scheduler, and
|
||||
the checkpoint index (number of tasks executed so far).
|
||||
|
||||
Returns:
|
||||
A bool indicating whether a checkpoint was resumed.
|
||||
"""
|
||||
worker = ray.worker.global_worker
|
||||
# Get the most recent checkpoint stored, if any.
|
||||
checkpoint_index, checkpoint, frontier = get_actor_checkpoint(
|
||||
worker, worker.actor_id)
|
||||
# Try to resume from the checkpoint.
|
||||
checkpoint_resumed = False
|
||||
if checkpoint_index is not None:
|
||||
# Load the actor state from the checkpoint.
|
||||
worker.actors[worker.actor_id] = (
|
||||
worker.actor_class.__ray_restore_from_checkpoint__(
|
||||
checkpoint))
|
||||
# Set the number of tasks executed so far.
|
||||
worker.actor_task_counter = checkpoint_index
|
||||
# Set the actor frontier in the local scheduler.
|
||||
worker.raylet_client.set_actor_frontier(frontier)
|
||||
checkpoint_resumed = True
|
||||
|
||||
return checkpoint_resumed
|
||||
if not isinstance(self, ray.actor.Checkpointable):
|
||||
raise Exception(
|
||||
"__ray_checkpoint__.remote() may only be called on actors "
|
||||
"that implement ray.actor.Checkpointable")
|
||||
return worker._save_actor_checkpoint()
|
||||
|
||||
Class.__module__ = cls.__module__
|
||||
Class.__name__ = cls.__name__
|
||||
|
||||
class_id = ActorClassID(_random_string())
|
||||
|
||||
return ActorClass(Class, class_id, checkpoint_interval,
|
||||
max_reconstructions, num_cpus, num_gpus, resources,
|
||||
actor_method_cpus)
|
||||
return ActorClass(Class, class_id, max_reconstructions, num_cpus, num_gpus,
|
||||
resources, actor_method_cpus)
|
||||
|
||||
|
||||
ray.worker.global_worker.make_actor = make_actor
|
||||
|
||||
CheckpointContext = namedtuple(
    'CheckpointContext',
    (
        'actor_id',  # Actor's ID.
        # Number of tasks executed since last checkpoint.
        'num_tasks_since_last_checkpoint',
        # Time elapsed since last checkpoint, in milliseconds.
        'time_elapsed_ms_since_last_checkpoint',
    ),
)
"""A namedtuple describing an actor's state relative to its last checkpoint."""
|
||||
|
||||
Checkpoint = namedtuple(
    'Checkpoint',
    (
        'checkpoint_id',  # ID of this checkpoint.
        # The timestamp at which this checkpoint was saved,
        # represented as milliseconds elapsed since Unix epoch.
        'timestamp',
    ),
)
"""A namedtuple pairing a checkpoint ID with the time it was saved."""
|
||||
|
||||
|
||||
class Checkpointable(six.with_metaclass(ABCMeta, object)):
    """Interface implemented by actor classes that can be checkpointed.

    All four callbacks below are abstract, so a concrete actor class has to
    define every one of them.
    """

    @abstractmethod
    def should_checkpoint(self, checkpoint_context):
        """Decide whether the actor should be checkpointed now.

        This callback is invoked after every task. Use the supplied context,
        or any other criteria, to decide whether a checkpoint should be taken
        at this point.

        Args:
            checkpoint_context: A namedtuple with information about the last
                checkpoint.

        Returns:
            A boolean: True if the actor should be checkpointed, False
            otherwise.
        """
        pass

    @abstractmethod
    def save_checkpoint(self, actor_id, checkpoint_id):
        """Persist a checkpoint of the actor.

        Invoked whenever `should_checkpoint` returns true. The implementation
        should write the actor's checkpoint data, together with the given
        checkpoint ID, to persistent storage.

        Args:
            actor_id: Actor's ID.
            checkpoint_id: ID of this checkpoint. Store it alongside the
                checkpoint data; `load_checkpoint` refers to checkpoints by
                this ID.

        Returns:
            None.
        """
        pass

    @abstractmethod
    def load_checkpoint(self, actor_id, available_checkpoints):
        """Restore the actor's state from one of its previous checkpoints.

        Invoked when an actor is reconstructed, after the actor's
        constructor has run. To resume from a checkpoint, restore the
        actor's state and return the ID of the checkpoint that was used.
        To start from scratch, do nothing and return None.

        The returned value must be either None or the ID of one of the
        entries in `available_checkpoints`; any other value is treated as
        an error.

        Args:
            actor_id: Actor's ID.
            available_checkpoints: A list of `Checkpoint` namedtuples holding
                every available checkpoint ID and its timestamp, sorted by
                timestamp in descending order.

        Returns:
            The ID of the checkpoint from which the actor was resumed, or
            None if the actor should restart from the beginning.
        """
        pass

    @abstractmethod
    def checkpoint_expired(self, actor_id, checkpoint_id):
        """Delete an expired checkpoint.

        Invoked when a checkpoint has expired. The implementation should
        delete the application data stored for that checkpoint. The maximum
        number of checkpoints kept in the backend is configured by
        `RayConfig.num_actor_checkpoints_to_keep`.

        Args:
            actor_id: ID of the actor.
            checkpoint_id: ID of the checkpoint that has expired.

        Returns:
            None.
        """
        pass
|
||||
|
||||
|
||||
def get_checkpoints_for_actor(actor_id):
    """Return the available checkpoints of an actor, newest first.

    Args:
        actor_id: ID of the actor whose checkpoints are looked up.

    Returns:
        A list of `Checkpoint` namedtuples sorted by timestamp in descending
        order, or an empty list when no checkpoint info is found.
    """
    info = ray.worker.global_state.actor_checkpoint_info(actor_id)
    if info is None:
        return []
    id_timestamp_pairs = zip(info['CheckpointIds'], info['Timestamps'])
    newest_first = [
        Checkpoint(checkpoint_id, timestamp)
        for checkpoint_id, timestamp in id_timestamp_pairs
    ]
    newest_first.sort(key=lambda item: item.timestamp, reverse=True)
    return newest_first
|
||||
|
||||
@@ -11,7 +11,8 @@ import time
|
||||
import ray
|
||||
from ray.function_manager import FunctionDescriptor
|
||||
import ray.gcs_utils
|
||||
import ray.ray_constants as ray_constants
|
||||
|
||||
from ray.ray_constants import ID_SIZE
|
||||
from ray.utils import (decode, binary_to_object_id, binary_to_hex,
|
||||
hex_to_binary)
|
||||
|
||||
@@ -720,7 +721,7 @@ class GlobalState(object):
|
||||
for key in actor_keys:
|
||||
info = self.redis_client.hgetall(key)
|
||||
actor_id = key[len("Actor:"):]
|
||||
assert len(actor_id) == ray_constants.ID_SIZE
|
||||
assert len(actor_id) == ID_SIZE
|
||||
actor_info[binary_to_hex(actor_id)] = {
|
||||
"class_id": binary_to_hex(info[b"class_id"]),
|
||||
"driver_id": binary_to_hex(info[b"driver_id"]),
|
||||
@@ -906,3 +907,42 @@ class GlobalState(object):
|
||||
binary_to_hex(job_id): self._error_messages(ray.DriverID(job_id))
|
||||
for job_id in job_ids
|
||||
}
|
||||
|
||||
def actor_checkpoint_info(self, actor_id):
    """Get checkpoint info for the given actor id.

    Args:
        actor_id: Actor's ID.

    Returns:
        None if the table lookup yields no message. Otherwise a dictionary
        with the keys "ActorID" (the actor ID passed through
        `binary_to_hex`), "CheckpointIds" (a list of
        `ray.ActorCheckpointID`) and "Timestamps" (one timestamp per
        checkpoint ID).
    """
    self._check_connected()
    reply = self._execute_command(
        actor_id,
        "RAY.TABLE_LOOKUP",
        ray.gcs_utils.TablePrefix.ACTOR_CHECKPOINT_ID,
        "",
        actor_id.binary(),
    )
    if reply is None:
        return None

    table_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
        reply, 0)
    parse_checkpoint_data = (
        ray.gcs_utils.ActorCheckpointIdData.GetRootAsActorCheckpointIdData)
    checkpoint_data = parse_checkpoint_data(table_entry.Entries(0), 0)

    # All checkpoint IDs arrive concatenated in one string; cut it into
    # ID_SIZE-long pieces.
    packed_ids = checkpoint_data.CheckpointIds()
    num_checkpoints = len(packed_ids) // ID_SIZE
    assert len(packed_ids) % ID_SIZE == 0
    checkpoint_ids = []
    for index in range(num_checkpoints):
        start = index * ID_SIZE
        checkpoint_ids.append(
            ray.ActorCheckpointID(packed_ids[start:start + ID_SIZE]))

    actor_id_hex = ray.utils.binary_to_hex(checkpoint_data.ActorId())
    timestamps = [
        checkpoint_data.Timestamps(index) for index in range(num_checkpoints)
    ]
    return {
        "ActorID": actor_id_hex,
        "CheckpointIds": checkpoint_ids,
        "Timestamps": timestamps,
    }
|
||||
|
||||
+104
-40
@@ -510,8 +510,7 @@ class FunctionActorManager(object):
|
||||
self._worker.redis_client.hmset(key, actor_class_info)
|
||||
self._worker.redis_client.rpush("Exports", key)
|
||||
|
||||
def export_actor_class(self, Class, actor_method_names,
|
||||
checkpoint_interval):
|
||||
def export_actor_class(self, Class, actor_method_names):
|
||||
function_descriptor = FunctionDescriptor.from_class(Class)
|
||||
# `task_driver_id` shouldn't be NIL, unless:
|
||||
# 1) This worker isn't an actor;
|
||||
@@ -528,7 +527,6 @@ class FunctionActorManager(object):
|
||||
"class_name": Class.__name__,
|
||||
"module": Class.__module__,
|
||||
"class": pickle.dumps(Class),
|
||||
"checkpoint_interval": checkpoint_interval,
|
||||
"driver_id": driver_id.binary(),
|
||||
"actor_method_names": json.dumps(list(actor_method_names))
|
||||
}
|
||||
@@ -576,17 +574,16 @@ class FunctionActorManager(object):
|
||||
actor_class_key: The key in Redis to use to fetch the actor.
|
||||
"""
|
||||
actor_id = self._worker.actor_id
|
||||
(driver_id_str, class_name, module, pickled_class, checkpoint_interval,
|
||||
(driver_id_str, class_name, module, pickled_class,
|
||||
actor_method_names) = self._worker.redis_client.hmget(
|
||||
actor_class_key, [
|
||||
"driver_id", "class_name", "module", "class",
|
||||
"checkpoint_interval", "actor_method_names"
|
||||
"actor_method_names"
|
||||
])
|
||||
|
||||
class_name = decode(class_name)
|
||||
module = decode(module)
|
||||
driver_id = ray.DriverID(driver_id_str)
|
||||
checkpoint_interval = int(checkpoint_interval)
|
||||
actor_method_names = json.loads(decode(actor_method_names))
|
||||
|
||||
# In Python 2, json loads strings as unicode, so convert them back to
|
||||
@@ -605,7 +602,6 @@ class FunctionActorManager(object):
|
||||
pass
|
||||
|
||||
self._worker.actors[actor_id] = TemporaryActor()
|
||||
self._worker.actor_checkpoint_interval = checkpoint_interval
|
||||
|
||||
def temporary_actor_method(*xs):
|
||||
raise Exception(
|
||||
@@ -694,48 +690,116 @@ class FunctionActorManager(object):
|
||||
# to execute.
|
||||
self._worker.actor_task_counter += 1
|
||||
|
||||
# If this is the first task to execute on the actor, try to resume
|
||||
# from a checkpoint.
|
||||
# Current __init__ will be called by default. So the real function
|
||||
# call will start from 2.
|
||||
if actor_imported and self._worker.actor_task_counter == 2:
|
||||
checkpoint_resumed = ray.actor.restore_and_log_checkpoint(
|
||||
self._worker, actor)
|
||||
if checkpoint_resumed:
|
||||
# NOTE(swang): Since we did not actually execute the
|
||||
# __init__ method, this will put None as the return value.
|
||||
# If the __init__ method is supposed to return multiple
|
||||
# values, an exception will be logged.
|
||||
return
|
||||
|
||||
# Determine whether we should checkpoint the actor.
|
||||
checkpointing_on = (actor_imported
|
||||
and self._worker.actor_checkpoint_interval > 0)
|
||||
# We should checkpoint the actor if user checkpointing is on, we've
|
||||
# executed checkpoint_interval tasks since the last checkpoint, and
|
||||
# the method we're about to execute is not a checkpoint.
|
||||
save_checkpoint = (checkpointing_on
|
||||
and (self._worker.actor_task_counter %
|
||||
self._worker.actor_checkpoint_interval == 0
|
||||
and method_name != "__ray_checkpoint__"))
|
||||
|
||||
# Execute the assigned method and save a checkpoint if necessary.
|
||||
try:
|
||||
if is_class_method(method):
|
||||
method_returns = method(*args)
|
||||
else:
|
||||
method_returns = method(actor, *args)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
# Save the checkpoint before allowing the method exception
|
||||
# to be thrown.
|
||||
if save_checkpoint:
|
||||
ray.actor.save_and_log_checkpoint(self._worker, actor)
|
||||
raise
|
||||
if isinstance(actor, ray.actor.Checkpointable):
|
||||
self._save_and_log_checkpoint(actor)
|
||||
raise e
|
||||
else:
|
||||
# Save the checkpoint before returning the method's return
|
||||
# values.
|
||||
if save_checkpoint:
|
||||
ray.actor.save_and_log_checkpoint(self._worker, actor)
|
||||
# Handle any checkpointing operations before storing the
|
||||
# method's return values.
|
||||
# NOTE(swang): If method_returns is a pointer to the actor's
|
||||
# state, then the checkpointing operations can modify the return
|
||||
# values if they mutate the actor's state. Is this okay?
|
||||
if isinstance(actor, ray.actor.Checkpointable):
|
||||
# If this is the first task to execute on the actor, try to
|
||||
# resume from a checkpoint.
|
||||
if self._worker.actor_task_counter == 1:
|
||||
if actor_imported:
|
||||
self._restore_and_log_checkpoint(actor)
|
||||
else:
|
||||
# Save the checkpoint before returning the method's
|
||||
# return values.
|
||||
self._save_and_log_checkpoint(actor)
|
||||
return method_returns
|
||||
|
||||
return actor_method_executor
|
||||
|
||||
def _save_and_log_checkpoint(self, actor):
    """Save an actor checkpoint if necessary and log any errors.

    Increments the actor's count of tasks since the last checkpoint and
    asks the actor's `should_checkpoint` callback whether to checkpoint
    now. If so, the raylet prepares a checkpoint ID, which is recorded and
    passed to the actor's `save_checkpoint` callback. When more IDs are
    tracked than `num_actor_checkpoints_to_keep` allows, the oldest one is
    dropped and reported through `checkpoint_expired`.

    Exceptions raised by the user-defined callbacks or by the raylet call
    are pushed to the driver as a checkpoint error instead of being raised
    to the caller.

    Args:
        actor: The actor to checkpoint.

    Returns:
        None. The return value of `save_checkpoint` is not propagated.
    """
    actor_id = self._worker.actor_id
    checkpoint_info = self._worker.actor_checkpoint_info[actor_id]
    checkpoint_info.num_tasks_since_last_checkpoint += 1
    now = int(1000 * time.time())
    checkpoint_context = ray.actor.CheckpointContext(
        actor_id, checkpoint_info.num_tasks_since_last_checkpoint,
        now - checkpoint_info.last_checkpoint_timestamp)
    try:
        # `should_checkpoint` is user code too, so it runs inside the
        # try block: a failure there is reported to the driver like any
        # other checkpointing error rather than escaping to the caller.
        if not actor.should_checkpoint(checkpoint_context):
            return
        # Notify the raylet to prepare a checkpoint, then let the actor
        # save its state under the returned ID. The timestamp is taken
        # just before the checkpoint is prepared.
        now = int(1000 * time.time())
        checkpoint_id = (
            self._worker.raylet_client.prepare_actor_checkpoint(actor_id))
        checkpoint_info.checkpoint_ids.append(checkpoint_id)
        actor.save_checkpoint(actor_id, checkpoint_id)
        if (len(checkpoint_info.checkpoint_ids) >
                ray._config.num_actor_checkpoints_to_keep()):
            # Too many checkpoints are tracked; expire the oldest one.
            actor.checkpoint_expired(
                actor_id,
                checkpoint_info.checkpoint_ids.pop(0),
            )
        checkpoint_info.num_tasks_since_last_checkpoint = 0
        checkpoint_info.last_checkpoint_timestamp = now
    except Exception:
        # Checkpoint save failed. Notify the driver.
        traceback_str = ray.utils.format_error_message(
            traceback.format_exc())
        ray.utils.push_error_to_driver(
            self._worker,
            ray_constants.CHECKPOINT_PUSH_ERROR,
            traceback_str,
            driver_id=self._worker.task_driver_id)
|
||||
|
||||
def _restore_and_log_checkpoint(self, actor):
    """Restore an actor from a checkpoint if available and log any errors.

    This should only be called on workers that have just executed an actor
    creation task.

    Looks up the checkpoints recorded for this worker's actor. If there
    are any, the actor's `load_checkpoint` callback is called with them.
    When the callback returns a checkpoint ID, the ID is validated against
    the available checkpoints and the raylet is notified that the actor
    resumed from it. Any exception raised along the way is pushed to the
    driver as a checkpoint error instead of being raised to the caller.

    Args:
        actor: The actor to restore from a checkpoint.
    """
    actor_id = self._worker.actor_id
    try:
        checkpoints = ray.actor.get_checkpoints_for_actor(actor_id)
        if not checkpoints:
            # Nothing was saved for this actor; start from scratch.
            return
        # We found previously saved checkpoints for this actor, so call
        # the `load_checkpoint` callback.
        checkpoint_id = actor.load_checkpoint(actor_id, checkpoints)
        if checkpoint_id is None:
            # The actor chose not to resume from a checkpoint.
            return
        # Check that the returned checkpoint id is in the
        # `available_checkpoints` list. This is an explicit raise rather
        # than an `assert` so the check still runs under `python -O`.
        if not any(checkpoint_id == checkpoint.checkpoint_id
                   for checkpoint in checkpoints):
            raise ValueError(
                "`load_checkpoint` must return a checkpoint id that "
                "exists in the `available_checkpoints` list, or None.")
        # Notify raylet that this actor has been resumed from
        # a checkpoint.
        self._worker.raylet_client.notify_actor_resumed_from_checkpoint(
            actor_id, checkpoint_id)
    except Exception:
        # Checkpoint reload failed. Notify the driver.
        traceback_str = ray.utils.format_error_message(
            traceback.format_exc())
        ray.utils.push_error_to_driver(
            self._worker,
            ray_constants.CHECKPOINT_PUSH_ERROR,
            traceback_str,
            driver_id=self._worker.task_driver_id)
|
||||
|
||||
+15
-4
@@ -5,6 +5,7 @@ from __future__ import print_function
|
||||
import flatbuffers
|
||||
import ray.core.generated.ErrorTableData
|
||||
|
||||
from ray.core.generated.ActorCheckpointIdData import ActorCheckpointIdData
|
||||
from ray.core.generated.ClientTableData import ClientTableData
|
||||
from ray.core.generated.DriverTableData import DriverTableData
|
||||
from ray.core.generated.ErrorTableData import ErrorTableData
|
||||
@@ -20,10 +21,20 @@ from ray.core.generated.TablePubsub import TablePubsub
|
||||
from ray.core.generated.ray.protocol.Task import Task
|
||||
|
||||
__all__ = [
|
||||
"GcsTableEntry", "ClientTableData", "ErrorTableData", "HeartbeatTableData",
|
||||
"HeartbeatBatchTableData", "DriverTableData", "ProfileTableData",
|
||||
"ObjectTableData", "Task", "TablePrefix", "TablePubsub", "Language",
|
||||
"construct_error_message"
|
||||
"ActorCheckpointIdData",
|
||||
"ClientTableData",
|
||||
"DriverTableData",
|
||||
"ErrorTableData",
|
||||
"GcsTableEntry",
|
||||
"HeartbeatBatchTableData",
|
||||
"HeartbeatTableData",
|
||||
"Language",
|
||||
"ObjectTableData",
|
||||
"ProfileTableData",
|
||||
"TablePrefix",
|
||||
"TablePubsub",
|
||||
"Task",
|
||||
"construct_error_message",
|
||||
]
|
||||
|
||||
FUNCTION_PREFIX = "RemoteFunction:"
|
||||
|
||||
@@ -6,10 +6,19 @@ from libcpp.unordered_map cimport unordered_map
|
||||
from libcpp.vector cimport vector as c_vector
|
||||
|
||||
from ray.includes.unique_ids cimport (
|
||||
CUniqueID, TaskID as CTaskID, ObjectID as CObjectID,
|
||||
FunctionID as CFunctionID, ActorClassID as CActorClassID, ActorID as CActorID,
|
||||
ActorHandleID as CActorHandleID, WorkerID as CWorkerID,
|
||||
DriverID as CDriverID, ConfigID as CConfigID, ClientID as CClientID)
|
||||
ActorCheckpointID as CActorCheckpointID,
|
||||
ActorClassID as CActorClassID,
|
||||
ActorHandleID as CActorHandleID,
|
||||
ActorID as CActorID,
|
||||
CUniqueID,
|
||||
ClientID as CClientID,
|
||||
ConfigID as CConfigID,
|
||||
DriverID as CDriverID,
|
||||
FunctionID as CFunctionID,
|
||||
ObjectID as CObjectID,
|
||||
TaskID as CTaskID,
|
||||
WorkerID as CWorkerID,
|
||||
)
|
||||
|
||||
|
||||
cdef extern from "ray/status.h" namespace "ray" nogil:
|
||||
|
||||
@@ -8,9 +8,21 @@ from libcpp.vector cimport vector as c_vector
|
||||
|
||||
|
||||
from ray.includes.common cimport (
|
||||
CUniqueID, CTaskID, CObjectID, CFunctionID, CActorClassID, CActorID,
|
||||
CActorHandleID, CWorkerID, CDriverID, CConfigID, CClientID,
|
||||
CLanguage, CRayStatus)
|
||||
CActorCheckpointID,
|
||||
CActorClassID,
|
||||
CActorHandleID,
|
||||
CActorID,
|
||||
CClientID,
|
||||
CConfigID,
|
||||
CDriverID,
|
||||
CFunctionID,
|
||||
CLanguage,
|
||||
CObjectID,
|
||||
CRayStatus,
|
||||
CTaskID,
|
||||
CUniqueID,
|
||||
CWorkerID,
|
||||
)
|
||||
from ray.includes.task cimport CTaskSpecification
|
||||
|
||||
|
||||
@@ -57,6 +69,10 @@ cdef extern from "ray/raylet/raylet_client.h" nogil:
|
||||
CRayStatus PushProfileEvents(const GCSProfileTableDataT &profile_events)
|
||||
CRayStatus FreeObjects(const c_vector[CObjectID] &object_ids,
|
||||
c_bool local_only)
|
||||
CRayStatus PrepareActorCheckpoint(const CActorID &actor_id,
|
||||
CActorCheckpointID &checkpoint_id)
|
||||
CRayStatus NotifyActorResumedFromCheckpoint(
|
||||
const CActorID &actor_id, const CActorCheckpointID &checkpoint_id)
|
||||
CLanguage GetLanguage() const
|
||||
CClientID GetClientID() const
|
||||
CDriverID GetDriverID() const
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from libc.stdint cimport int64_t, uint64_t
|
||||
from libc.stdint cimport int64_t, uint64_t, uint32_t
|
||||
from libcpp.string cimport string as c_string
|
||||
from libcpp.unordered_map cimport unordered_map
|
||||
|
||||
@@ -80,4 +80,6 @@ cdef extern from "ray/ray_config.h" nogil:
|
||||
|
||||
int64_t max_task_lease_timeout_ms() const
|
||||
|
||||
uint32_t num_actor_checkpoints_to_keep() const
|
||||
|
||||
void initialize(const unordered_map[c_string, int] &config_map)
|
||||
|
||||
@@ -144,3 +144,7 @@ cdef class Config:
|
||||
@staticmethod
|
||||
def max_task_lease_timeout_ms():
|
||||
return RayConfig.instance().max_task_lease_timeout_ms()
|
||||
|
||||
@staticmethod
|
||||
def num_actor_checkpoints_to_keep():
|
||||
return RayConfig.instance().num_actor_checkpoints_to_keep()
|
||||
|
||||
@@ -28,6 +28,7 @@ cdef extern from "ray/id.h" namespace "ray" nogil:
|
||||
ctypedef CUniqueID ActorID
|
||||
ctypedef CUniqueID ActorClassID
|
||||
ctypedef CUniqueID ActorHandleID
|
||||
ctypedef CUniqueID ActorCheckpointID
|
||||
ctypedef CUniqueID WorkerID
|
||||
ctypedef CUniqueID DriverID
|
||||
ctypedef CUniqueID ConfigID
|
||||
|
||||
@@ -7,9 +7,21 @@ See https://github.com/ray-project/ray/issues/3721.
|
||||
# WARNING: Any additional ID types defined in this file must be added to the
|
||||
# _ID_TYPES list at the bottom of this file.
|
||||
from ray.includes.common cimport (
|
||||
CUniqueID, CTaskID, CObjectID, CFunctionID, CActorClassID, CActorID,
|
||||
CActorHandleID, CWorkerID, CDriverID, CConfigID, CClientID,
|
||||
ComputePutId, ComputeTaskId)
|
||||
CActorCheckpointID,
|
||||
CActorClassID,
|
||||
CActorHandleID,
|
||||
CActorID,
|
||||
CClientID,
|
||||
CConfigID,
|
||||
CDriverID,
|
||||
CFunctionID,
|
||||
CObjectID,
|
||||
CTaskID,
|
||||
CUniqueID,
|
||||
CWorkerID,
|
||||
ComputePutId,
|
||||
ComputeTaskId,
|
||||
)
|
||||
|
||||
from ray.utils import decode
|
||||
|
||||
@@ -236,6 +248,29 @@ cdef class ActorHandleID(UniqueID):
|
||||
return "ActorHandleID(" + self.hex() + ")"
|
||||
|
||||
|
||||
cdef class ActorCheckpointID(UniqueID):
    """Python-visible wrapper around a native actor checkpoint ID.

    Mirrors the other UniqueID subclasses in this file: it can be built from
    a binary ID, from a native ``CActorCheckpointID``, or as the nil ID.
    """

    def __init__(self, id):
        # A falsy `id` produces a default-constructed native ID; anything
        # else is passed through check_id() (defined elsewhere in this file)
        # before being parsed from its binary form.
        if not id:
            self.data = CUniqueID()
        else:
            check_id(id)
            self.data = CUniqueID.from_binary(id)

    @staticmethod
    cdef from_native(const CActorCheckpointID& cpp_id):
        # Allocate the wrapper without running __init__, then adopt the
        # native ID. The argument to __new__ must be ActorCheckpointID itself:
        # passing a sibling type such as ActorHandleID would allocate (or
        # attempt to allocate) the wrong extension type.
        cdef ActorCheckpointID self = ActorCheckpointID.__new__(
            ActorCheckpointID)
        self.data = cpp_id
        return self

    @staticmethod
    def nil():
        # Wrap the native nil checkpoint ID.
        return ActorCheckpointID.from_native(CActorCheckpointID.nil())

    def __repr__(self):
        return "ActorCheckpointID(" + self.hex() + ")"
|
||||
|
||||
|
||||
cdef class FunctionID(UniqueID):
|
||||
|
||||
def __init__(self, id):
|
||||
|
||||
+29
-12
@@ -117,6 +117,25 @@ class RayTaskError(Exception):
|
||||
return "\n".join(out)
|
||||
|
||||
|
||||
class ActorCheckpointInfo(object):
    """Information used to maintain actor checkpoints.

    A plain record with a fixed attribute set (declared via ``__slots__``).

    Args:
        num_tasks_since_last_checkpoint: Number of tasks executed since the
            last checkpoint.
        last_checkpoint_timestamp: Timestamp of the last checkpoint, in
            milliseconds.
        checkpoint_ids: IDs of the previous checkpoints. The object passed
            in is stored as-is (not copied).
    """

    __slots__ = [
        # Number of tasks executed since last checkpoint.
        "num_tasks_since_last_checkpoint",
        # Timestamp of the last checkpoint, in milliseconds.
        "last_checkpoint_timestamp",
        # IDs of the previous checkpoints.
        "checkpoint_ids",
    ]

    def __init__(self, num_tasks_since_last_checkpoint,
                 last_checkpoint_timestamp, checkpoint_ids):
        self.num_tasks_since_last_checkpoint = num_tasks_since_last_checkpoint
        self.last_checkpoint_timestamp = last_checkpoint_timestamp
        self.checkpoint_ids = checkpoint_ids
|
||||
|
||||
|
||||
class Worker(object):
|
||||
"""A class used to define the control flow of a worker process.
|
||||
|
||||
@@ -141,6 +160,8 @@ class Worker(object):
|
||||
self.actor_init_error = None
|
||||
self.make_actor = None
|
||||
self.actors = {}
|
||||
# Information used to maintain actor checkpoints.
|
||||
self.actor_checkpoint_info = {}
|
||||
self.actor_task_counter = 0
|
||||
# The number of threads Plasma should use when putting an object in the
|
||||
# object store.
|
||||
@@ -515,7 +536,6 @@ class Worker(object):
|
||||
actor_id=None,
|
||||
actor_handle_id=None,
|
||||
actor_counter=0,
|
||||
is_actor_checkpoint_method=False,
|
||||
actor_creation_id=None,
|
||||
actor_creation_dummy_object_id=None,
|
||||
max_actor_reconstructions=0,
|
||||
@@ -538,8 +558,6 @@ class Worker(object):
|
||||
be serializable objects.
|
||||
actor_id: The ID of the actor that this task is for.
|
||||
actor_counter: The counter of the actor task.
|
||||
is_actor_checkpoint_method: True if this is an actor checkpoint
|
||||
task and false otherwise.
|
||||
actor_creation_id: The ID of the actor to create, if this is an
|
||||
actor creation task.
|
||||
actor_creation_dummy_object_id: If this task is an actor method,
|
||||
@@ -900,6 +918,11 @@ class Worker(object):
|
||||
self.actor_creation_task_id = task.task_id()
|
||||
self.function_actor_manager.load_actor(driver_id,
|
||||
function_descriptor)
|
||||
self.actor_checkpoint_info[self.actor_id] = ActorCheckpointInfo(
|
||||
num_tasks_since_last_checkpoint=0,
|
||||
last_checkpoint_timestamp=int(1000 * time.time()),
|
||||
checkpoint_ids=[],
|
||||
)
|
||||
|
||||
execution_info = self.function_actor_manager.get_execution_info(
|
||||
driver_id, function_descriptor)
|
||||
@@ -2395,16 +2418,12 @@ def make_decorator(num_return_vals=None,
|
||||
num_gpus=None,
|
||||
resources=None,
|
||||
max_calls=None,
|
||||
checkpoint_interval=None,
|
||||
max_reconstructions=None,
|
||||
worker=None):
|
||||
def decorator(function_or_class):
|
||||
if (inspect.isfunction(function_or_class)
|
||||
or is_cython(function_or_class)):
|
||||
# Set the remote function default resources.
|
||||
if checkpoint_interval is not None:
|
||||
raise Exception("The keyword 'checkpoint_interval' is not "
|
||||
"allowed for remote functions.")
|
||||
if max_reconstructions is not None:
|
||||
raise Exception("The keyword 'max_reconstructions' is not "
|
||||
"allowed for remote functions.")
|
||||
@@ -2437,7 +2456,7 @@ def make_decorator(num_return_vals=None,
|
||||
|
||||
return worker.make_actor(function_or_class, cpus_to_use, num_gpus,
|
||||
resources, actor_method_cpus,
|
||||
checkpoint_interval, max_reconstructions)
|
||||
max_reconstructions)
|
||||
|
||||
raise Exception("The @ray.remote decorator must be applied to "
|
||||
"either a function or to a class.")
|
||||
@@ -2509,7 +2528,7 @@ def remote(*args, **kwargs):
|
||||
"with no arguments and no parentheses, for example "
|
||||
"'@ray.remote', or it must be applied using some of "
|
||||
"the arguments 'num_return_vals', 'num_cpus', 'num_gpus', "
|
||||
"'resources', 'max_calls', 'checkpoint_interval',"
|
||||
"'resources', 'max_calls', "
|
||||
"or 'max_reconstructions', like "
|
||||
"'@ray.remote(num_return_vals=2, "
|
||||
"resources={\"CustomResource\": 1})'.")
|
||||
@@ -2517,7 +2536,7 @@ def remote(*args, **kwargs):
|
||||
for key in kwargs:
|
||||
assert key in [
|
||||
"num_return_vals", "num_cpus", "num_gpus", "resources",
|
||||
"max_calls", "checkpoint_interval", "max_reconstructions"
|
||||
"max_calls", "max_reconstructions"
|
||||
], error_string
|
||||
|
||||
num_cpus = kwargs["num_cpus"] if "num_cpus" in kwargs else None
|
||||
@@ -2534,7 +2553,6 @@ def remote(*args, **kwargs):
|
||||
# Handle other arguments.
|
||||
num_return_vals = kwargs.get("num_return_vals")
|
||||
max_calls = kwargs.get("max_calls")
|
||||
checkpoint_interval = kwargs.get("checkpoint_interval")
|
||||
max_reconstructions = kwargs.get("max_reconstructions")
|
||||
|
||||
return make_decorator(
|
||||
@@ -2543,6 +2561,5 @@ def remote(*args, **kwargs):
|
||||
num_gpus=num_gpus,
|
||||
resources=resources,
|
||||
max_calls=max_calls,
|
||||
checkpoint_interval=checkpoint_interval,
|
||||
max_reconstructions=max_reconstructions,
|
||||
worker=worker)
|
||||
|
||||
Reference in New Issue
Block a user