mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 03:18:59 +08:00
678 lines
30 KiB
Python
678 lines
30 KiB
Python
import dis
|
|
import hashlib
|
|
import importlib
|
|
import inspect
|
|
import json
|
|
import logging
|
|
import sys
|
|
import time
|
|
import threading
|
|
import traceback
|
|
from collections import (
|
|
namedtuple,
|
|
defaultdict,
|
|
)
|
|
|
|
import ray
|
|
from ray import profiling
|
|
from ray import ray_constants
|
|
from ray import cloudpickle as pickle
|
|
from ray._raylet import PythonFunctionDescriptor
|
|
from ray.utils import (
|
|
is_function_or_method,
|
|
is_class_method,
|
|
is_static_method,
|
|
check_oversized_pickle,
|
|
decode,
|
|
ensure_str,
|
|
format_error_message,
|
|
push_error_to_driver,
|
|
)
|
|
|
|
# Record kept for every function registered on a worker: the callable itself,
# the name it was registered under, and its `max_calls` setting.
FunctionExecutionInfo = namedtuple(
    "FunctionExecutionInfo",
    ("function", "function_name", "max_calls"),
)
"""FunctionExecutionInfo: A named tuple storing remote function information."""

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class FunctionActorManager:
|
|
"""A class used to export/load remote functions and actors.
|
|
|
|
Attributes:
|
|
_worker: The associated worker that this manager related.
|
|
_functions_to_export: The remote functions to export when
|
|
the worker gets connected.
|
|
_actors_to_export: The actors to export when the worker gets
|
|
connected.
|
|
_function_execution_info: The map from job_id to function_id
|
|
and execution_info.
|
|
_num_task_executions: The map from job_id to function
|
|
execution times.
|
|
imported_actor_classes: The set of actor classes keys (format:
|
|
ActorClass:function_id) that are already in GCS.
|
|
"""
|
|
|
|
def __init__(self, worker):
|
|
self._worker = worker
|
|
self._functions_to_export = []
|
|
self._actors_to_export = []
|
|
# This field is a dictionary that maps a driver ID to a dictionary of
|
|
# functions (and information about those functions) that have been
|
|
# registered for that driver (this inner dictionary maps function IDs
|
|
# to a FunctionExecutionInfo object. This should only be used on
|
|
# workers that execute remote functions.
|
|
self._function_execution_info = defaultdict(lambda: {})
|
|
self._num_task_executions = defaultdict(lambda: {})
|
|
# A set of all of the actor class keys that have been imported by the
|
|
# import thread. It is safe to convert this worker into an actor of
|
|
# these types.
|
|
self.imported_actor_classes = set()
|
|
self._loaded_actor_classes = {}
|
|
# Deserialize an ActorHandle will call load_actor_class(). If a
|
|
# function closure captured an ActorHandle, the deserialization of the
|
|
# function will be:
|
|
# import_thread.py
|
|
# -> fetch_and_register_remote_function (acquire lock)
|
|
# -> _load_actor_class_from_gcs (acquire lock, too)
|
|
# So, the lock should be a reentrant lock.
|
|
self.lock = threading.RLock()
|
|
self.execution_infos = {}
|
|
|
|
def increase_task_counter(self, job_id, function_descriptor):
|
|
function_id = function_descriptor.function_id
|
|
if self._worker.load_code_from_local:
|
|
job_id = ray.JobID.nil()
|
|
self._num_task_executions[job_id][function_id] += 1
|
|
|
|
def get_task_counter(self, job_id, function_descriptor):
|
|
function_id = function_descriptor.function_id
|
|
if self._worker.load_code_from_local:
|
|
job_id = ray.JobID.nil()
|
|
return self._num_task_executions[job_id][function_id]
|
|
|
|
def compute_collision_identifier(self, function_or_class):
|
|
"""The identifier is used to detect excessive duplicate exports.
|
|
|
|
The identifier is used to determine when the same function or class is
|
|
exported many times. This can yield false positives.
|
|
|
|
Args:
|
|
function_or_class: The function or class to compute an identifier
|
|
for.
|
|
|
|
Returns:
|
|
The identifier. Note that different functions or classes can give
|
|
rise to same identifier. However, the same function should
|
|
hopefully always give rise to the same identifier. TODO(rkn):
|
|
verify if this is actually the case. Note that if the
|
|
identifier is incorrect in any way, then we may give warnings
|
|
unnecessarily or fail to give warnings, but the application's
|
|
behavior won't change.
|
|
"""
|
|
import io
|
|
string_file = io.StringIO()
|
|
if sys.version_info[1] >= 7:
|
|
dis.dis(function_or_class, file=string_file, depth=2)
|
|
else:
|
|
dis.dis(function_or_class, file=string_file)
|
|
collision_identifier = (
|
|
function_or_class.__name__ + ":" + string_file.getvalue())
|
|
|
|
# Return a hash of the identifier in case it is too large.
|
|
return hashlib.sha1(collision_identifier.encode("ascii")).digest()
|
|
|
|
def export(self, remote_function):
    """Pickle a remote function and export it to redis.

    Nothing is exported when the worker runs in local mode or loads its
    code from local files. Otherwise the pickled function and its metadata
    are written as a hash under the key
    `RemoteFunction:<job id>:<function id>`, and that key is appended to
    the "Exports" list.

    Args:
        remote_function: the RemoteFunction object.
    """
    worker = self._worker
    if worker.mode == ray.worker.LOCAL_MODE or worker.load_code_from_local:
        return

    function = remote_function._function
    pickled_function = pickle.dumps(function)

    check_oversized_pickle(pickled_function,
                           remote_function._function_name,
                           "remote function", worker)

    descriptor = remote_function._function_descriptor
    key = (b"RemoteFunction:" + worker.current_job_id.binary() + b":" +
           descriptor.function_id.binary())
    export_fields = {
        "job_id": worker.current_job_id.binary(),
        "function_id": descriptor.function_id.binary(),
        "function_name": remote_function._function_name,
        "module": function.__module__,
        "function": pickled_function,
        "collision_identifier": self.compute_collision_identifier(
            function),
        "max_calls": remote_function._max_calls
    }
    worker.redis_client.hmset(key, export_fields)
    worker.redis_client.rpush("Exports", key)
|
|
|
|
def fetch_and_register_remote_function(self, key):
    """Import a remote function.

    Reads the exported hash stored at `key`, registers a placeholder for
    the function, and then tries to unpickle the real function. On success
    the placeholder is replaced and this worker's ID is appended to the
    `FunctionTable:<function id>` list; on failure an error is pushed to
    the driver and the placeholder stays registered.

    Args:
        key: The redis key of the exported remote function.
    """
    requested_fields = [
        "job_id", "function_id", "function_name", "function",
        "num_return_vals", "module", "resources", "max_calls"
    ]
    # "num_return_vals" and "resources" are fetched but not used here.
    (job_id_str, function_id_str, function_name, serialized_function,
     _num_return_vals, module, _resources,
     max_calls) = self._worker.redis_client.hmget(key, requested_fields)
    function_id = ray.FunctionID(function_id_str)
    job_id = ray.JobID(job_id_str)
    function_name = decode(function_name)
    max_calls = int(max_calls)
    module = decode(module)

    # Stand-in used when the function can't be unpickled. It is replaced
    # below if the real function is registered successfully.
    def f(*args, **kwargs):
        raise Exception("This function was not imported properly.")

    def register(target):
        # Store `target` as the executable for this job/function pair.
        self._function_execution_info[job_id][function_id] = (
            FunctionExecutionInfo(
                function=target,
                function_name=function_name,
                max_calls=max_calls))

    # This method is called by ImportThread. The whole registration has to
    # be atomic; otherwise another thread could pick up the temporary
    # function above before the real one is ready.
    with self.lock:
        register(f)
        self._num_task_executions[job_id][function_id] = 0

        try:
            function = pickle.loads(serialized_function)
        except Exception:
            # Unpickling failed: record the traceback and report the
            # failure to the driver.
            traceback_str = format_error_message(traceback.format_exc())
            push_error_to_driver(
                self._worker,
                ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR,
                "Failed to unpickle the remote function '{}' with "
                "function ID {}. Traceback:\n{}".format(
                    function_name, function_id.hex(), traceback_str),
                job_id=job_id)
        else:
            # Restoring the module name is necessary: in the driver
            # process a function defined in the start-up script has
            # module `__main__`, but in the worker process `__main__` is
            # a different module, namely `default_worker.py`.
            function.__module__ = module
            register(function)
            # Add the function to the function table.
            self._worker.redis_client.rpush(
                b"FunctionTable:" + function_id.binary(),
                self._worker.worker_id)
|
|
|
|
def get_execution_info(self, job_id, function_descriptor):
    """Get the FunctionExecutionInfo of a remote function.

    Args:
        job_id: ID of the job that the function belongs to.
        function_descriptor: The FunctionDescriptor of the function to get.

    Returns:
        A FunctionExecutionInfo object.

    Raises:
        KeyError: If no execution info is registered for the function
            once loading / waiting has finished.
    """
    if not self._worker.load_code_from_local:
        # Load function from GCS.
        # Block until the function has actually been registered on this
        # worker; warnings are pushed to the user if that takes too long.
        # The driver function may not be found in sys.path, so it has to
        # come from GCS.
        with profiling.profile("wait_for_function"):
            self._wait_for_function(function_descriptor, job_id)
    else:
        # Load function from local code.
        # Code is currently not isolated per job, so the job ID is always
        # NIL in this mode.
        job_id = ray.JobID.nil()
        if not function_descriptor.is_actor_method():
            self._load_function_from_local(job_id, function_descriptor)

    try:
        function_id = function_descriptor.function_id
        return self._function_execution_info[job_id][function_id]
    except KeyError as e:
        message = ("Error occurs in get_execution_info: "
                   "job_id: %s, function_descriptor: %s. Message: %s" %
                   (job_id, function_descriptor, e))
        raise KeyError(message)
|
|
|
|
def _load_function_from_local(self, job_id, function_descriptor):
|
|
assert not function_descriptor.is_actor_method()
|
|
function_id = function_descriptor.function_id
|
|
if (job_id in self._function_execution_info
|
|
and function_id in self._function_execution_info[function_id]):
|
|
return
|
|
module_name, function_name = (
|
|
function_descriptor.module_name,
|
|
function_descriptor.function_name,
|
|
)
|
|
try:
|
|
module = importlib.import_module(module_name)
|
|
function = getattr(module, function_name)._function
|
|
self._function_execution_info[job_id][function_id] = (
|
|
FunctionExecutionInfo(
|
|
function=function,
|
|
function_name=function_name,
|
|
max_calls=0,
|
|
))
|
|
self._num_task_executions[job_id][function_id] = 0
|
|
except Exception:
|
|
logger.exception("Failed to load function %s.", function_name)
|
|
raise Exception(
|
|
"Function {} failed to be loaded from local code.".format(
|
|
function_descriptor))
|
|
|
|
def _wait_for_function(self, function_descriptor, job_id, timeout=10):
|
|
"""Wait until the function to be executed is present on this worker.
|
|
|
|
This method will simply loop until the import thread has imported the
|
|
relevant function. If we spend too long in this loop, that may indicate
|
|
a problem somewhere and we will push an error message to the user.
|
|
|
|
If this worker is an actor, then this will wait until the actor has
|
|
been defined.
|
|
|
|
Args:
|
|
function_descriptor : The FunctionDescriptor of the function that
|
|
we want to execute.
|
|
job_id (str): The ID of the job to push the error message to
|
|
if this times out.
|
|
"""
|
|
start_time = time.time()
|
|
# Only send the warning once.
|
|
warning_sent = False
|
|
while True:
|
|
with self.lock:
|
|
if (self._worker.actor_id.is_nil()
|
|
and (function_descriptor.function_id in
|
|
self._function_execution_info[job_id])):
|
|
break
|
|
elif not self._worker.actor_id.is_nil() and (
|
|
self._worker.actor_id in self._worker.actors):
|
|
break
|
|
if time.time() - start_time > timeout:
|
|
warning_message = ("This worker was asked to execute a "
|
|
"function that it does not have "
|
|
"registered. You may have to restart "
|
|
"Ray.")
|
|
if not warning_sent:
|
|
ray.utils.push_error_to_driver(
|
|
self._worker,
|
|
ray_constants.WAIT_FOR_FUNCTION_PUSH_ERROR,
|
|
warning_message,
|
|
job_id=job_id)
|
|
warning_sent = True
|
|
time.sleep(0.001)
|
|
|
|
def _publish_actor_class_to_key(self, key, actor_class_info):
    """Push an actor class definition to Redis.

    This lives in its own method because it is also used for cached actor
    class definitions when a worker connects for the first time.

    Args:
        key: The key to store the actor class info at.
        actor_class_info: Information about the actor class.
    """
    redis_client = self._worker.redis_client
    # Store the definition, then announce the key on the "Exports" list.
    redis_client.hmset(key, actor_class_info)
    redis_client.rpush("Exports", key)
|
|
|
|
def export_actor_class(self, Class, actor_creation_function_descriptor,
                       actor_method_names):
    """Pickle an actor class and publish its definition.

    Nothing is exported when the worker loads its code from local files.
    Otherwise the class is published under the key
    `ActorClass:<job id>:<function id>`.

    Args:
        Class: The actor class to export.
        actor_creation_function_descriptor: Function descriptor of the
            actor constructor; supplies the class name, module name and
            function ID.
        actor_method_names: The names of the actor's methods; stored as a
            JSON list.
    """
    if self._worker.load_code_from_local:
        return
    # `current_job_id` shouldn't be NIL, unless:
    # 1) This worker isn't an actor;
    # 2) And a previous task started a background thread, which didn't
    #    finish before the task finished, and still uses Ray API
    #    after that.
    assert not self._worker.current_job_id.is_nil(), (
        "You might have started a background thread in a non-actor task, "
        "please make sure the thread finishes before the task finishes.")
    job_id = self._worker.current_job_id
    descriptor = actor_creation_function_descriptor
    key = (b"ActorClass:" + job_id.binary() + b":" +
           descriptor.function_id.binary())
    class_name = descriptor.class_name
    actor_class_info = {
        "class_name": class_name,
        "module": descriptor.module_name,
        "class": pickle.dumps(Class),
        "job_id": job_id.binary(),
        "collision_identifier": self.compute_collision_identifier(Class),
        "actor_method_names": json.dumps(list(actor_method_names))
    }

    check_oversized_pickle(actor_class_info["class"], class_name, "actor",
                           self._worker)

    self._publish_actor_class_to_key(key, actor_class_info)
    # TODO(rkn): Currently we allow actor classes to be defined
    # within tasks. I tried to disable this, but it may be necessary
    # because of https://github.com/ray-project/ray/issues/1146.
|
|
|
|
def load_actor_class(self, job_id, actor_creation_function_descriptor):
|
|
"""Load the actor class.
|
|
|
|
Args:
|
|
job_id: job ID of the actor.
|
|
actor_creation_function_descriptor: Function descriptor of
|
|
the actor constructor.
|
|
|
|
Returns:
|
|
The actor class.
|
|
"""
|
|
function_id = actor_creation_function_descriptor.function_id
|
|
# Check if the actor class already exists in the cache.
|
|
actor_class = self._loaded_actor_classes.get(function_id, None)
|
|
if actor_class is None:
|
|
# Load actor class.
|
|
if self._worker.load_code_from_local:
|
|
job_id = ray.JobID.nil()
|
|
# Load actor class from local code.
|
|
actor_class = self._load_actor_class_from_local(
|
|
job_id, actor_creation_function_descriptor)
|
|
else:
|
|
# Load actor class from GCS.
|
|
actor_class = self._load_actor_class_from_gcs(
|
|
job_id, actor_creation_function_descriptor)
|
|
# Save the loaded actor class in cache.
|
|
self._loaded_actor_classes[function_id] = actor_class
|
|
|
|
# Generate execution info for the methods of this actor class.
|
|
module_name = actor_creation_function_descriptor.module_name
|
|
actor_class_name = actor_creation_function_descriptor.class_name
|
|
actor_methods = inspect.getmembers(
|
|
actor_class, predicate=is_function_or_method)
|
|
for actor_method_name, actor_method in actor_methods:
|
|
# Actor creation function descriptor use a unique function
|
|
# hash to solve actor name conflict. When constructing an
|
|
# actor, the actor creation function descriptor will be the
|
|
# key to find __init__ method execution info. So, here we
|
|
# use actor creation function descriptor as method descriptor
|
|
# for generating __init__ method execution info.
|
|
if actor_method_name == "__init__":
|
|
method_descriptor = actor_creation_function_descriptor
|
|
else:
|
|
method_descriptor = PythonFunctionDescriptor(
|
|
module_name, actor_method_name, actor_class_name)
|
|
method_id = method_descriptor.function_id
|
|
executor = self._make_actor_method_executor(
|
|
actor_method_name,
|
|
actor_method,
|
|
actor_imported=True,
|
|
)
|
|
self._function_execution_info[job_id][method_id] = (
|
|
FunctionExecutionInfo(
|
|
function=executor,
|
|
function_name=actor_method_name,
|
|
max_calls=0,
|
|
))
|
|
self._num_task_executions[job_id][method_id] = 0
|
|
self._num_task_executions[job_id][function_id] = 0
|
|
return actor_class
|
|
|
|
def _load_actor_class_from_local(self, job_id,
                                 actor_creation_function_descriptor):
    """Load actor class from local code.

    Args:
        job_id: Must be a `ray.JobID`; it is only type-checked here.
        actor_creation_function_descriptor: Descriptor supplying the
            module name and class name to import.

    Returns:
        The attribute found on the imported module, or its
        `__ray_metadata__.modified_class` when that attribute is a
        `ray.actor.ActorClass`.

    Raises:
        Exception: If the class cannot be imported.
    """
    assert isinstance(job_id, ray.JobID)
    module_name = actor_creation_function_descriptor.module_name
    class_name = actor_creation_function_descriptor.class_name
    try:
        module = importlib.import_module(module_name)
        actor_class = getattr(module, class_name)
        if not isinstance(actor_class, ray.actor.ActorClass):
            return actor_class
        # Unwrap the ActorClass wrapper to get the underlying class.
        return actor_class.__ray_metadata__.modified_class
    except Exception:
        logger.exception("Failed to load actor_class %s.", class_name)
        raise Exception(
            "Actor {} failed to be imported from local code.".format(
                class_name))
|
|
|
|
def _create_fake_actor_class(self, actor_class_name, actor_method_names):
|
|
class TemporaryActor:
|
|
pass
|
|
|
|
def temporary_actor_method(*args, **kwargs):
|
|
raise Exception(
|
|
"The actor with name {} failed to be imported, "
|
|
"and so cannot execute this method.".format(actor_class_name))
|
|
|
|
for method in actor_method_names:
|
|
setattr(TemporaryActor, method, temporary_actor_method)
|
|
|
|
return TemporaryActor
|
|
|
|
def _load_actor_class_from_gcs(self, job_id,
                               actor_creation_function_descriptor):
    """Load actor class from GCS.

    Waits until the import thread has recorded the class key, fetches the
    exported definition and unpickles it. If unpickling fails, a fake
    actor class whose methods raise is returned instead and an error is
    pushed to the driver.

    Args:
        job_id: Job ID used to build the lookup key.
        actor_creation_function_descriptor: Descriptor whose function ID
            is used to build the lookup key.

    Returns:
        The unpickled actor class (or the fake replacement), with its
        `__module__` set to the exported module name.
    """
    function_id = actor_creation_function_descriptor.function_id
    key = b"ActorClass:" + job_id.binary() + b":" + function_id.binary()
    # Poll until the import thread has imported the actor class key.
    # TODO(rkn): It shouldn't be possible to end up in an infinite loop
    # here, but we should push an error to the driver if too much time is
    # spent here.
    while key not in self.imported_actor_classes:
        time.sleep(0.001)

    # Fetch raw data from GCS.
    wanted = ["job_id", "class_name", "module", "class",
              "actor_method_names"]
    (job_id_str, class_name, module, pickled_class,
     actor_method_names) = self._worker.redis_client.hmget(key, wanted)

    class_name = ensure_str(class_name)
    module_name = ensure_str(module)
    job_id = ray.JobID(job_id_str)
    actor_method_names = json.loads(ensure_str(actor_method_names))

    actor_class = None
    try:
        with self.lock:
            actor_class = pickle.loads(pickled_class)
    except Exception:
        logger.exception("Failed to load actor class %s.", class_name)
        # Unpickling failed. Substitute a fake actor class, just to
        # produce error messages and to keep the driver from hanging.
        actor_class = self._create_fake_actor_class(
            class_name, actor_method_names)
        # Record the traceback and report the failure to the driver.
        traceback_str = ray.utils.format_error_message(
            traceback.format_exc())
        push_error_to_driver(
            self._worker,
            ray_constants.REGISTER_ACTOR_PUSH_ERROR,
            "Failed to unpickle actor class '{}' for actor ID {}. "
            "Traceback:\n{}".format(
                class_name, self._worker.actor_id.hex(), traceback_str),
            job_id=job_id)
        # TODO(rkn): In the future, it might make sense to have the worker
        # exit here. However, currently that would lead to hanging if
        # someone calls ray.get on a method invoked on the actor.

    # Restoring the module name is necessary: in the driver process a
    # class defined in the start-up script has module `__main__`, but in
    # the worker process `__main__` is a different module, namely
    # `default_worker.py`.
    actor_class.__module__ = module_name
    return actor_class
|
|
|
|
def _make_actor_method_executor(self, method_name, method, actor_imported):
|
|
"""Make an executor that wraps a user-defined actor method.
|
|
|
|
The wrapped method updates the worker's internal state and performs any
|
|
necessary checkpointing operations.
|
|
|
|
Args:
|
|
method_name (str): The name of the actor method.
|
|
method (instancemethod): The actor method to wrap. This should be a
|
|
method defined on the actor class and should therefore take an
|
|
instance of the actor as the first argument.
|
|
actor_imported (bool): Whether the actor has been imported.
|
|
Checkpointing operations will not be run if this is set to
|
|
False.
|
|
|
|
Returns:
|
|
A function that executes the given actor method on the worker's
|
|
stored instance of the actor. The function also updates the
|
|
worker's internal state to record the executed method.
|
|
"""
|
|
|
|
def actor_method_executor(actor, *args, **kwargs):
|
|
# Update the actor's task counter to reflect the task we're about
|
|
# to execute.
|
|
self._worker.actor_task_counter += 1
|
|
|
|
# Execute the assigned method and save a checkpoint if necessary.
|
|
try:
|
|
is_bound = (is_class_method(method)
|
|
or is_static_method(type(actor), method_name))
|
|
if is_bound:
|
|
method_returns = method(*args, **kwargs)
|
|
else:
|
|
method_returns = method(actor, *args, **kwargs)
|
|
except Exception as e:
|
|
# Save the checkpoint before allowing the method exception
|
|
# to be thrown, but don't save the checkpoint for actor
|
|
# creation task.
|
|
if (isinstance(actor, ray.actor.Checkpointable)
|
|
and self._worker.actor_task_counter != 1):
|
|
self._save_and_log_checkpoint(actor)
|
|
raise e
|
|
else:
|
|
# Handle any checkpointing operations before storing the
|
|
# method's return values.
|
|
# NOTE(swang): If method_returns is a pointer to the actor's
|
|
# state and the checkpointing operations can modify the return
|
|
# values if they mutate the actor's state. Is this okay?
|
|
if isinstance(actor, ray.actor.Checkpointable):
|
|
# If this is the first task to execute on the actor, try to
|
|
# resume from a checkpoint.
|
|
if self._worker.actor_task_counter == 1:
|
|
if actor_imported:
|
|
self._restore_and_log_checkpoint(actor)
|
|
else:
|
|
# Save the checkpoint before returning the method's
|
|
# return values.
|
|
self._save_and_log_checkpoint(actor)
|
|
return method_returns
|
|
|
|
# Set method_name and method as attributes to the executor clusore
|
|
# so we can make decision based on these attributes in task executor.
|
|
# Precisely, asyncio support requires to know whether:
|
|
# - the method is a ray internal method: starts with __ray
|
|
# - the method is a coroutine function: defined by async def
|
|
actor_method_executor.name = method_name
|
|
actor_method_executor.method = method
|
|
|
|
return actor_method_executor
|
|
|
|
def _save_and_log_checkpoint(self, actor):
    """Save an actor checkpoint if necessary and log any errors.

    The actor's `should_checkpoint` decides whether a checkpoint is taken.
    Errors raised while taking the checkpoint are not propagated; their
    traceback is pushed to the driver instead. This method returns
    nothing.

    Args:
        actor: The actor to checkpoint.
    """
    actor_id = self._worker.actor_id
    info = self._worker.actor_checkpoint_info[actor_id]
    info.num_tasks_since_last_checkpoint += 1
    # Timestamps are in milliseconds.
    now = int(1000 * time.time())
    checkpoint_context = ray.actor.CheckpointContext(
        actor_id, info.num_tasks_since_last_checkpoint,
        now - info.last_checkpoint_timestamp)
    if not actor.should_checkpoint(checkpoint_context):
        return

    # A checkpoint is due: have raylet prepare one, then call the actor's
    # `save_checkpoint`.
    try:
        now = int(1000 * time.time())
        checkpoint_id = (
            self._worker.core_worker.prepare_actor_checkpoint(actor_id))
        info.checkpoint_ids.append(checkpoint_id)
        actor.save_checkpoint(actor_id, checkpoint_id)
        # Expire the oldest checkpoint once more than the configured
        # number are being kept.
        if (len(info.checkpoint_ids) >
                ray._config.num_actor_checkpoints_to_keep()):
            actor.checkpoint_expired(
                actor_id,
                info.checkpoint_ids.pop(0),
            )
        info.num_tasks_since_last_checkpoint = 0
        info.last_checkpoint_timestamp = now
    except Exception:
        # Checkpoint save or reload failed. Notify the driver.
        traceback_str = ray.utils.format_error_message(
            traceback.format_exc())
        ray.utils.push_error_to_driver(
            self._worker,
            ray_constants.CHECKPOINT_PUSH_ERROR,
            traceback_str,
            job_id=self._worker.current_job_id)
|
|
|
|
def _restore_and_log_checkpoint(self, actor):
    """Restore an actor from a checkpoint if available and log any errors.

    This should only be called on workers that have just executed an actor
    creation task. Errors raised during the restore are not propagated;
    their traceback is pushed to the driver instead.

    Args:
        actor: The actor to restore from a checkpoint.
    """
    actor_id = self._worker.actor_id
    try:
        checkpoints = ray.actor.get_checkpoints_for_actor(actor_id)
        if len(checkpoints) > 0:
            # Previously saved checkpoints exist for this actor, so let
            # its `load_checkpoint` callback pick one.
            checkpoint_id = actor.load_checkpoint(actor_id, checkpoints)
            if checkpoint_id is not None:
                # The chosen id has to be one of the checkpoints that
                # were offered.
                msg = (
                    "`load_checkpoint` must return a checkpoint id that " +
                    "exists in the `available_checkpoints` list, or None.")
                known = any(checkpoint_id == checkpoint.checkpoint_id
                            for checkpoint in checkpoints)
                assert known, msg
                # Notify raylet that this actor has been resumed from
                # a checkpoint.
                core_worker = self._worker.core_worker
                core_worker.notify_actor_resumed_from_checkpoint(
                    actor_id, checkpoint_id)
    except Exception:
        # Checkpoint save or reload failed. Notify the driver.
        traceback_str = ray.utils.format_error_message(
            traceback.format_exc())
        ray.utils.push_error_to_driver(
            self._worker,
            ray_constants.CHECKPOINT_PUSH_ERROR,
            traceback_str,
            job_id=self._worker.current_job_id)
|