mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:06:31 +08:00
Move function/actor exporting & loading code to function_manager.py (#3003)
Move function/actor exporting & loading code to function_manager.py to prepare the code change for function descriptor for python.
This commit is contained in:
committed by
Robert Nishihara
parent
d73ee36e60
commit
9948e8c11b
+24
-275
@@ -5,31 +5,19 @@ from __future__ import print_function
|
||||
import copy
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
import traceback
|
||||
|
||||
import ray.cloudpickle as pickle
|
||||
from ray.function_manager import FunctionActorManager
|
||||
import ray.local_scheduler
|
||||
import ray.ray_constants as ray_constants
|
||||
import ray.signature as signature
|
||||
import ray.worker
|
||||
from ray.utils import (
|
||||
decode,
|
||||
_random_string,
|
||||
check_oversized_pickle,
|
||||
is_cython,
|
||||
push_error_to_driver,
|
||||
)
|
||||
from ray.utils import _random_string
|
||||
|
||||
DEFAULT_ACTOR_METHOD_NUM_RETURN_VALS = 1
|
||||
|
||||
|
||||
def is_classmethod(f):
|
||||
"""Returns whether the given method is a classmethod."""
|
||||
|
||||
return hasattr(f, "__self__") and f.__self__ is not None
|
||||
|
||||
|
||||
def compute_actor_handle_id(actor_handle_id, num_forks):
|
||||
"""Deterministically compute an actor handle ID.
|
||||
|
||||
@@ -96,24 +84,6 @@ def compute_actor_creation_function_id(class_id):
|
||||
return ray.ObjectID(class_id)
|
||||
|
||||
|
||||
def compute_actor_method_function_id(class_name, attr):
|
||||
"""Get the function ID corresponding to an actor method.
|
||||
|
||||
Args:
|
||||
class_name (str): The class name of the actor.
|
||||
attr (str): The attribute name of the method.
|
||||
|
||||
Returns:
|
||||
Function ID corresponding to the method.
|
||||
"""
|
||||
function_id_hash = hashlib.sha1()
|
||||
function_id_hash.update(class_name.encode("ascii"))
|
||||
function_id_hash.update(attr.encode("ascii"))
|
||||
function_id = function_id_hash.digest()
|
||||
assert len(function_id) == ray_constants.ID_SIZE
|
||||
return ray.ObjectID(function_id)
|
||||
|
||||
|
||||
def set_actor_checkpoint(worker, actor_id, checkpoint_index, checkpoint,
|
||||
frontier):
|
||||
"""Set the most recent checkpoint associated with a given actor ID.
|
||||
@@ -134,28 +104,6 @@ def set_actor_checkpoint(worker, actor_id, checkpoint_index, checkpoint,
|
||||
})
|
||||
|
||||
|
||||
def get_actor_checkpoint(worker, actor_id):
|
||||
"""Get the most recent checkpoint associated with a given actor ID.
|
||||
|
||||
Args:
|
||||
worker: The worker to use to get the checkpoint.
|
||||
actor_id: The actor ID of the actor to get the checkpoint for.
|
||||
|
||||
Returns:
|
||||
If a checkpoint exists, this returns a tuple of the number of tasks
|
||||
included in the checkpoint, the saved checkpoint state, and the
|
||||
task frontier at the time of the checkpoint. If no checkpoint
|
||||
exists, all objects are set to None. The checkpoint index is the .
|
||||
executed on the actor before the checkpoint was made.
|
||||
"""
|
||||
actor_key = b"Actor:" + actor_id
|
||||
checkpoint_index, checkpoint, frontier = worker.redis_client.hmget(
|
||||
actor_key, ["checkpoint_index", "checkpoint", "frontier"])
|
||||
if checkpoint_index is not None:
|
||||
checkpoint_index = int(checkpoint_index)
|
||||
return checkpoint_index, checkpoint, frontier
|
||||
|
||||
|
||||
def save_and_log_checkpoint(worker, actor):
|
||||
"""Save a checkpoint on the actor and log any errors.
|
||||
|
||||
@@ -205,219 +153,26 @@ def restore_and_log_checkpoint(worker, actor):
|
||||
return checkpoint_resumed
|
||||
|
||||
|
||||
def make_actor_method_executor(worker, method_name, method, actor_imported):
|
||||
"""Make an executor that wraps a user-defined actor method.
|
||||
|
||||
The wrapped method updates the worker's internal state and performs any
|
||||
necessary checkpointing operations.
|
||||
def get_actor_checkpoint(worker, actor_id):
|
||||
"""Get the most recent checkpoint associated with a given actor ID.
|
||||
|
||||
Args:
|
||||
worker (Worker): The worker that is executing the actor.
|
||||
method_name (str): The name of the actor method.
|
||||
method (instancemethod): The actor method to wrap. This should be a
|
||||
method defined on the actor class and should therefore take an
|
||||
instance of the actor as the first argument.
|
||||
actor_imported (bool): Whether the actor has been imported.
|
||||
Checkpointing operations will not be run if this is set to False.
|
||||
worker: The worker to use to get the checkpoint.
|
||||
actor_id: The actor ID of the actor to get the checkpoint for.
|
||||
|
||||
Returns:
|
||||
A function that executes the given actor method on the worker's stored
|
||||
instance of the actor. The function also updates the worker's
|
||||
internal state to record the executed method.
|
||||
If a checkpoint exists, this returns a tuple of the number of tasks
|
||||
included in the checkpoint, the saved checkpoint state, and the
|
||||
task frontier at the time of the checkpoint. If no checkpoint
|
||||
exists, all objects are set to None. The checkpoint index is the .
|
||||
executed on the actor before the checkpoint was made.
|
||||
"""
|
||||
|
||||
def actor_method_executor(dummy_return_id, actor, *args):
|
||||
# Update the actor's task counter to reflect the task we're about to
|
||||
# execute.
|
||||
worker.actor_task_counter += 1
|
||||
|
||||
# If this is the first task to execute on the actor, try to resume from
|
||||
# a checkpoint.
|
||||
if actor_imported and worker.actor_task_counter == 1:
|
||||
checkpoint_resumed = restore_and_log_checkpoint(worker, actor)
|
||||
if checkpoint_resumed:
|
||||
# NOTE(swang): Since we did not actually execute the __init__
|
||||
# method, this will put None as the return value. If the
|
||||
# __init__ method is supposed to return multiple values, an
|
||||
# exception will be logged.
|
||||
return
|
||||
|
||||
# Determine whether we should checkpoint the actor.
|
||||
checkpointing_on = (actor_imported
|
||||
and worker.actor_checkpoint_interval > 0)
|
||||
# We should checkpoint the actor if user checkpointing is on, we've
|
||||
# executed checkpoint_interval tasks since the last checkpoint, and the
|
||||
# method we're about to execute is not a checkpoint.
|
||||
save_checkpoint = (
|
||||
checkpointing_on and
|
||||
(worker.actor_task_counter % worker.actor_checkpoint_interval == 0
|
||||
and method_name != "__ray_checkpoint__"))
|
||||
|
||||
# Execute the assigned method and save a checkpoint if necessary.
|
||||
try:
|
||||
if is_classmethod(method):
|
||||
method_returns = method(*args)
|
||||
else:
|
||||
method_returns = method(actor, *args)
|
||||
except Exception:
|
||||
# Save the checkpoint before allowing the method exception to be
|
||||
# thrown.
|
||||
if save_checkpoint:
|
||||
save_and_log_checkpoint(worker, actor)
|
||||
raise
|
||||
else:
|
||||
# Save the checkpoint before returning the method's return values.
|
||||
if save_checkpoint:
|
||||
save_and_log_checkpoint(worker, actor)
|
||||
return method_returns
|
||||
|
||||
return actor_method_executor
|
||||
|
||||
|
||||
def fetch_and_register_actor(actor_class_key, worker):
|
||||
"""Import an actor.
|
||||
|
||||
This will be called by the worker's import thread when the worker receives
|
||||
the actor_class export, assuming that the worker is an actor for that
|
||||
class.
|
||||
|
||||
Args:
|
||||
actor_class_key: The key in Redis to use to fetch the actor.
|
||||
worker: The worker to use.
|
||||
"""
|
||||
actor_id_str = worker.actor_id
|
||||
(driver_id, class_id, class_name, module, pickled_class,
|
||||
checkpoint_interval, actor_method_names) = worker.redis_client.hmget(
|
||||
actor_class_key, [
|
||||
"driver_id", "class_id", "class_name", "module", "class",
|
||||
"checkpoint_interval", "actor_method_names"
|
||||
])
|
||||
|
||||
class_name = decode(class_name)
|
||||
module = decode(module)
|
||||
checkpoint_interval = int(checkpoint_interval)
|
||||
actor_method_names = json.loads(decode(actor_method_names))
|
||||
|
||||
# Create a temporary actor with some temporary methods so that if the actor
|
||||
# fails to be unpickled, the temporary actor can be used (just to produce
|
||||
# error messages and to prevent the driver from hanging).
|
||||
class TemporaryActor(object):
|
||||
pass
|
||||
|
||||
worker.actors[actor_id_str] = TemporaryActor()
|
||||
worker.actor_checkpoint_interval = checkpoint_interval
|
||||
|
||||
def temporary_actor_method(*xs):
|
||||
raise Exception("The actor with name {} failed to be imported, and so "
|
||||
"cannot execute this method".format(class_name))
|
||||
|
||||
# Register the actor method executors.
|
||||
for actor_method_name in actor_method_names:
|
||||
function_id = compute_actor_method_function_id(class_name,
|
||||
actor_method_name).id()
|
||||
temporary_executor = make_actor_method_executor(
|
||||
worker,
|
||||
actor_method_name,
|
||||
temporary_actor_method,
|
||||
actor_imported=False)
|
||||
worker.function_execution_info[driver_id][function_id] = (
|
||||
ray.worker.FunctionExecutionInfo(
|
||||
function=temporary_executor,
|
||||
function_name=actor_method_name,
|
||||
max_calls=0))
|
||||
worker.num_task_executions[driver_id][function_id] = 0
|
||||
|
||||
try:
|
||||
unpickled_class = pickle.loads(pickled_class)
|
||||
worker.actor_class = unpickled_class
|
||||
except Exception:
|
||||
# If an exception was thrown when the actor was imported, we record the
|
||||
# traceback and notify the scheduler of the failure.
|
||||
traceback_str = ray.utils.format_error_message(traceback.format_exc())
|
||||
# Log the error message.
|
||||
push_error_to_driver(
|
||||
worker,
|
||||
ray_constants.REGISTER_ACTOR_PUSH_ERROR,
|
||||
traceback_str,
|
||||
driver_id,
|
||||
data={"actor_id": actor_id_str})
|
||||
# TODO(rkn): In the future, it might make sense to have the worker exit
|
||||
# here. However, currently that would lead to hanging if someone calls
|
||||
# ray.get on a method invoked on the actor.
|
||||
else:
|
||||
# TODO(pcm): Why is the below line necessary?
|
||||
unpickled_class.__module__ = module
|
||||
worker.actors[actor_id_str] = unpickled_class.__new__(unpickled_class)
|
||||
|
||||
def pred(x):
|
||||
return (inspect.isfunction(x) or inspect.ismethod(x)
|
||||
or is_cython(x))
|
||||
|
||||
actor_methods = inspect.getmembers(unpickled_class, predicate=pred)
|
||||
for actor_method_name, actor_method in actor_methods:
|
||||
function_id = compute_actor_method_function_id(
|
||||
class_name, actor_method_name).id()
|
||||
executor = make_actor_method_executor(
|
||||
worker, actor_method_name, actor_method, actor_imported=True)
|
||||
worker.function_execution_info[driver_id][function_id] = (
|
||||
ray.worker.FunctionExecutionInfo(
|
||||
function=executor,
|
||||
function_name=actor_method_name,
|
||||
max_calls=0))
|
||||
# We do not set worker.function_properties[driver_id][function_id]
|
||||
# because we currently do need the actor worker to submit new tasks
|
||||
# for the actor.
|
||||
|
||||
|
||||
def publish_actor_class_to_key(key, actor_class_info, worker):
|
||||
"""Push an actor class definition to Redis.
|
||||
|
||||
The is factored out as a separate function because it is also called
|
||||
on cached actor class definitions when a worker connects for the first
|
||||
time.
|
||||
|
||||
Args:
|
||||
key: The key to store the actor class info at.
|
||||
actor_class_info: Information about the actor class.
|
||||
worker: The worker to use to connect to Redis.
|
||||
"""
|
||||
# We set the driver ID here because it may not have been available when the
|
||||
# actor class was defined.
|
||||
actor_class_info["driver_id"] = worker.task_driver_id.id()
|
||||
worker.redis_client.hmset(key, actor_class_info)
|
||||
worker.redis_client.rpush("Exports", key)
|
||||
|
||||
|
||||
def export_actor_class(class_id, Class, actor_method_names,
|
||||
checkpoint_interval, worker):
|
||||
key = b"ActorClass:" + class_id
|
||||
actor_class_info = {
|
||||
"class_name": Class.__name__,
|
||||
"module": Class.__module__,
|
||||
"class": pickle.dumps(Class),
|
||||
"checkpoint_interval": checkpoint_interval,
|
||||
"actor_method_names": json.dumps(list(actor_method_names))
|
||||
}
|
||||
|
||||
check_oversized_pickle(actor_class_info["class"],
|
||||
actor_class_info["class_name"], "actor", worker)
|
||||
|
||||
if worker.mode is None:
|
||||
# This means that 'ray.init()' has not been called yet and so we must
|
||||
# cache the actor class definition and export it when 'ray.init()' is
|
||||
# called.
|
||||
assert worker.cached_remote_functions_and_actors is not None
|
||||
worker.cached_remote_functions_and_actors.append(
|
||||
("actor", (key, actor_class_info)))
|
||||
# This caching code path is currently not used because we only export
|
||||
# actor class definitions lazily when we instantiate the actor for the
|
||||
# first time.
|
||||
assert False, "This should be unreachable."
|
||||
else:
|
||||
publish_actor_class_to_key(key, actor_class_info, worker)
|
||||
# TODO(rkn): Currently we allow actor classes to be defined within tasks.
|
||||
# I tried to disable this, but it may be necessary because of
|
||||
# https://github.com/ray-project/ray/issues/1146.
|
||||
actor_key = b"Actor:" + actor_id
|
||||
checkpoint_index, checkpoint, frontier = worker.redis_client.hmget(
|
||||
actor_key, ["checkpoint_index", "checkpoint", "frontier"])
|
||||
if checkpoint_index is not None:
|
||||
checkpoint_index = int(checkpoint_index)
|
||||
return checkpoint_index, checkpoint, frontier
|
||||
|
||||
|
||||
def method(*args, **kwargs):
|
||||
@@ -518,13 +273,8 @@ class ActorClass(object):
|
||||
self._actor_method_cpus = actor_method_cpus
|
||||
self._exported = False
|
||||
|
||||
# Get the actor methods of the given class.
|
||||
def pred(x):
|
||||
return (inspect.isfunction(x) or inspect.ismethod(x)
|
||||
or is_cython(x))
|
||||
|
||||
self._actor_methods = inspect.getmembers(
|
||||
self._modified_class, predicate=pred)
|
||||
self._modified_class, ray.utils.is_function_or_method)
|
||||
# Extract the signatures of each of the methods. This will be used
|
||||
# to catch some errors if the methods are called with inappropriate
|
||||
# arguments.
|
||||
@@ -537,7 +287,7 @@ class ActorClass(object):
|
||||
# don't support, there may not be much the user can do about it.
|
||||
signature.check_signature_supported(method, warn=True)
|
||||
self._method_signatures[method_name] = signature.extract_signature(
|
||||
method, ignore_first=not is_classmethod(method))
|
||||
method, ignore_first=not ray.utils.is_class_method(method))
|
||||
|
||||
# Set the default number of return values for this method.
|
||||
if hasattr(method, "__ray_num_return_vals__"):
|
||||
@@ -614,9 +364,9 @@ class ActorClass(object):
|
||||
else:
|
||||
# Export the actor.
|
||||
if not self._exported:
|
||||
export_actor_class(self._class_id, self._modified_class,
|
||||
self._actor_method_names,
|
||||
self._checkpoint_interval, worker)
|
||||
worker.function_actor_manager.export_actor_class(
|
||||
self._class_id, self._modified_class,
|
||||
self._actor_method_names, self._checkpoint_interval)
|
||||
self._exported = True
|
||||
|
||||
resources = ray.utils.resources_from_resource_arguments(
|
||||
@@ -801,8 +551,8 @@ class ActorHandle(object):
|
||||
else:
|
||||
actor_handle_id = self._ray_actor_handle_id
|
||||
|
||||
function_id = compute_actor_method_function_id(self._ray_class_name,
|
||||
method_name)
|
||||
function_id = FunctionActorManager.compute_actor_method_function_id(
|
||||
self._ray_class_name, method_name)
|
||||
object_ids = worker.submit_task(
|
||||
function_id,
|
||||
args,
|
||||
@@ -1068,5 +818,4 @@ def make_actor(cls, num_cpus, num_gpus, resources, actor_method_cpus,
|
||||
resources, actor_method_cpus)
|
||||
|
||||
|
||||
ray.worker.global_worker.fetch_and_register_actor = fetch_and_register_actor
|
||||
ray.worker.global_worker.make_actor = make_actor
|
||||
|
||||
@@ -0,0 +1,486 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
import time
|
||||
import traceback
|
||||
from collections import (
|
||||
namedtuple,
|
||||
defaultdict,
|
||||
)
|
||||
|
||||
import ray
|
||||
from ray import profiling
|
||||
from ray import ray_constants
|
||||
from ray import cloudpickle as pickle
|
||||
from ray.utils import (
|
||||
is_cython,
|
||||
is_function_or_method,
|
||||
is_class_method,
|
||||
check_oversized_pickle,
|
||||
decode,
|
||||
format_error_message,
|
||||
push_error_to_driver,
|
||||
)
|
||||
|
||||
FunctionExecutionInfo = namedtuple("FunctionExecutionInfo",
|
||||
["function", "function_name", "max_calls"])
|
||||
"""FunctionExecutionInfo: A named tuple storing remote function information."""
|
||||
|
||||
|
||||
class FunctionActorManager(object):
|
||||
"""A class used to export/load remote functions and actors.
|
||||
|
||||
Attributes:
|
||||
_worker: The associated worker that this manager related.
|
||||
_functions_to_export: The remote functions to export when
|
||||
the worker gets connected.
|
||||
_actors_to_export: The actors to export when the worker gets
|
||||
connected.
|
||||
_function_execution_info: The map from driver_id to finction_id
|
||||
and execution_info.
|
||||
_num_task_executions: The map from driver_id to function
|
||||
execution times.
|
||||
"""
|
||||
|
||||
def __init__(self, worker):
|
||||
self._worker = worker
|
||||
self._functions_to_export = []
|
||||
self._actors_to_export = []
|
||||
# This field is a dictionary that maps a driver ID to a dictionary of
|
||||
# functions (and information about those functions) that have been
|
||||
# registered for that driver (this inner dictionary maps function IDs
|
||||
# to a FunctionExecutionInfo object. This should only be used on
|
||||
# workers that execute remote functions.
|
||||
self._function_execution_info = defaultdict(lambda: {})
|
||||
self._num_task_executions = defaultdict(lambda: {})
|
||||
|
||||
def increase_task_counter(self, driver_id, function_id):
|
||||
self._num_task_executions[driver_id][function_id] += 1
|
||||
|
||||
def get_task_counter(self, driver_id, function_id):
|
||||
return self._num_task_executions[driver_id][function_id]
|
||||
|
||||
def export_cached(self):
|
||||
"""Export cached remote functions
|
||||
|
||||
Note: this should be called only once when worker is connected.
|
||||
"""
|
||||
for remote_function in self._functions_to_export:
|
||||
self._do_export(remote_function)
|
||||
self._functions_to_export = None
|
||||
for info in self._actors_to_export:
|
||||
(key, actor_class_info) = info
|
||||
self._publish_actor_class_to_key(key, actor_class_info)
|
||||
|
||||
def reset_cache(self):
|
||||
self._functions_to_export = []
|
||||
self._actors_to_export = []
|
||||
|
||||
def export(self, remote_function):
|
||||
"""Export a remote function.
|
||||
|
||||
Args:
|
||||
remote_function: the RemoteFunction object.
|
||||
"""
|
||||
if self._worker.mode is None:
|
||||
# If the worker isn't connected, cache the function
|
||||
# and export it later.
|
||||
self._functions_to_export.append(remote_function)
|
||||
return
|
||||
if self._worker.mode != ray.worker.SCRIPT_MODE:
|
||||
# Don't need to export if the worker is not a driver.
|
||||
return
|
||||
self._do_export(remote_function)
|
||||
|
||||
def _do_export(self, remote_function):
|
||||
"""Pickle a remote function and export it to redis.
|
||||
|
||||
Args:
|
||||
remote_function: the RemoteFunction object.
|
||||
"""
|
||||
# Work around limitations of Python pickling.
|
||||
function = remote_function._function
|
||||
function_name_global_valid = function.__name__ in function.__globals__
|
||||
function_name_global_value = function.__globals__.get(
|
||||
function.__name__)
|
||||
# Allow the function to reference itself as a global variable
|
||||
if not is_cython(function):
|
||||
function.__globals__[function.__name__] = remote_function
|
||||
try:
|
||||
pickled_function = pickle.dumps(function)
|
||||
finally:
|
||||
# Undo our changes
|
||||
if function_name_global_valid:
|
||||
function.__globals__[function.__name__] = (
|
||||
function_name_global_value)
|
||||
else:
|
||||
del function.__globals__[function.__name__]
|
||||
|
||||
check_oversized_pickle(pickled_function,
|
||||
remote_function._function_name,
|
||||
"remote function", self._worker)
|
||||
|
||||
key = (b"RemoteFunction:" + self._worker.task_driver_id.id() + b":" +
|
||||
remote_function._function_id)
|
||||
self._worker.redis_client.hmset(
|
||||
key, {
|
||||
"driver_id": self._worker.task_driver_id.id(),
|
||||
"function_id": remote_function._function_id,
|
||||
"name": remote_function._function_name,
|
||||
"module": function.__module__,
|
||||
"function": pickled_function,
|
||||
"max_calls": remote_function._max_calls
|
||||
})
|
||||
self._worker.redis_client.rpush("Exports", key)
|
||||
|
||||
def fetch_and_register_remote_function(self, key):
|
||||
"""Import a remote function."""
|
||||
(driver_id, function_id_str, function_name, serialized_function,
|
||||
num_return_vals, module, resources,
|
||||
max_calls) = self._worker.redis_client.hmget(key, [
|
||||
"driver_id", "function_id", "name", "function", "num_return_vals",
|
||||
"module", "resources", "max_calls"
|
||||
])
|
||||
function_id = ray.ObjectID(function_id_str)
|
||||
function_name = decode(function_name)
|
||||
max_calls = int(max_calls)
|
||||
module = decode(module)
|
||||
|
||||
# This is a placeholder in case the function can't be unpickled. This
|
||||
# will be overwritten if the function is successfully registered.
|
||||
def f():
|
||||
raise Exception("This function was not imported properly.")
|
||||
|
||||
self._function_execution_info[driver_id][function_id.id()] = (
|
||||
FunctionExecutionInfo(
|
||||
function=f, function_name=function_name, max_calls=max_calls))
|
||||
self._num_task_executions[driver_id][function_id.id()] = 0
|
||||
|
||||
try:
|
||||
function = pickle.loads(serialized_function)
|
||||
except Exception as e:
|
||||
# If an exception was thrown when the remote function was imported,
|
||||
# we record the traceback and notify the scheduler of the failure.
|
||||
traceback_str = format_error_message(traceback.format_exc())
|
||||
# Log the error message.
|
||||
push_error_to_driver(
|
||||
self._worker,
|
||||
ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR,
|
||||
traceback_str,
|
||||
driver_id=driver_id,
|
||||
data={
|
||||
"function_id": function_id.id(),
|
||||
"function_name": function_name
|
||||
})
|
||||
else:
|
||||
# The below line is necessary. Because in the driver process,
|
||||
# if the function is defined in the file where the python script
|
||||
# was started from, its module is `__main__`.
|
||||
# However in the worker process, the `__main__` module is a
|
||||
# different module, which is `default_worker.py`
|
||||
function.__module__ = module
|
||||
self._function_execution_info[driver_id][function_id.id()] = (
|
||||
FunctionExecutionInfo(
|
||||
function=function,
|
||||
function_name=function_name,
|
||||
max_calls=max_calls))
|
||||
# Add the function to the function table.
|
||||
self._worker.redis_client.rpush(
|
||||
b"FunctionTable:" + function_id.id(), self._worker.worker_id)
|
||||
|
||||
def get_execution_info(self, driver_id, function_id):
|
||||
"""Get the FunctionExecutionInfo of a remote function.
|
||||
|
||||
Args:
|
||||
driver_id: ID of the driver that the function belongs to.
|
||||
function_id: ID of the function to get.
|
||||
|
||||
Returns:
|
||||
A FunctionExecutionInfo object.
|
||||
"""
|
||||
# Wait until the function to be executed has actually been registered
|
||||
# on this worker. We will push warnings to the user if we spend too
|
||||
# long in this loop.
|
||||
with profiling.profile("wait_for_function", worker=self._worker):
|
||||
self._wait_for_function(function_id, driver_id)
|
||||
return self._function_execution_info[driver_id][function_id.id()]
|
||||
|
||||
def _wait_for_function(self, function_id, driver_id, timeout=10):
|
||||
"""Wait until the function to be executed is present on this worker.
|
||||
|
||||
This method will simply loop until the import thread has imported the
|
||||
relevant function. If we spend too long in this loop, that may indicate
|
||||
a problem somewhere and we will push an error message to the user.
|
||||
|
||||
If this worker is an actor, then this will wait until the actor has
|
||||
been defined.
|
||||
|
||||
Args:
|
||||
function_id (str): The ID of the function that we want to execute.
|
||||
driver_id (str): The ID of the driver to push the error message to
|
||||
if this times out.
|
||||
"""
|
||||
start_time = time.time()
|
||||
# Only send the warning once.
|
||||
warning_sent = False
|
||||
while True:
|
||||
with self._worker.lock:
|
||||
if (self._worker.actor_id == ray.worker.NIL_ACTOR_ID
|
||||
and (function_id.id() in
|
||||
self._function_execution_info[driver_id])):
|
||||
break
|
||||
elif self._worker.actor_id != ray.worker.NIL_ACTOR_ID and (
|
||||
self._worker.actor_id in self._worker.actors):
|
||||
break
|
||||
if time.time() - start_time > timeout:
|
||||
warning_message = ("This worker was asked to execute a "
|
||||
"function that it does not have "
|
||||
"registered. You may have to restart "
|
||||
"Ray.")
|
||||
if not warning_sent:
|
||||
ray.utils.push_error_to_driver(
|
||||
self._worker,
|
||||
ray_constants.WAIT_FOR_FUNCTION_PUSH_ERROR,
|
||||
warning_message,
|
||||
driver_id=driver_id)
|
||||
warning_sent = True
|
||||
time.sleep(0.001)
|
||||
|
||||
@classmethod
|
||||
def compute_actor_method_function_id(cls, class_name, attr):
|
||||
"""Get the function ID corresponding to an actor method.
|
||||
|
||||
Args:
|
||||
class_name (str): The class name of the actor.
|
||||
attr (str): The attribute name of the method.
|
||||
|
||||
Returns:
|
||||
Function ID corresponding to the method.
|
||||
"""
|
||||
function_id_hash = hashlib.sha1()
|
||||
function_id_hash.update(class_name.encode("ascii"))
|
||||
function_id_hash.update(attr.encode("ascii"))
|
||||
function_id = function_id_hash.digest()
|
||||
assert len(function_id) == ray_constants.ID_SIZE
|
||||
return ray.ObjectID(function_id)
|
||||
|
||||
def _publish_actor_class_to_key(self, key, actor_class_info):
|
||||
"""Push an actor class definition to Redis.
|
||||
|
||||
The is factored out as a separate function because it is also called
|
||||
on cached actor class definitions when a worker connects for the first
|
||||
time.
|
||||
|
||||
Args:
|
||||
key: The key to store the actor class info at.
|
||||
actor_class_info: Information about the actor class.
|
||||
worker: The worker to use to connect to Redis.
|
||||
"""
|
||||
# We set the driver ID here because it may not have been available when
|
||||
# the actor class was defined.
|
||||
actor_class_info["driver_id"] = self._worker.task_driver_id.id()
|
||||
self._worker.redis_client.hmset(key, actor_class_info)
|
||||
self._worker.redis_client.rpush("Exports", key)
|
||||
|
||||
def export_actor_class(self, class_id, Class, actor_method_names,
|
||||
checkpoint_interval):
|
||||
key = b"ActorClass:" + class_id
|
||||
actor_class_info = {
|
||||
"class_name": Class.__name__,
|
||||
"module": Class.__module__,
|
||||
"class": pickle.dumps(Class),
|
||||
"checkpoint_interval": checkpoint_interval,
|
||||
"actor_method_names": json.dumps(list(actor_method_names))
|
||||
}
|
||||
|
||||
check_oversized_pickle(actor_class_info["class"],
|
||||
actor_class_info["class_name"], "actor",
|
||||
self._worker)
|
||||
|
||||
if self._worker.mode is None:
|
||||
# This means that 'ray.init()' has not been called yet and so we
|
||||
# must cache the actor class definition and export it when
|
||||
# 'ray.init()' is called.
|
||||
assert self._actors_to_export is not None
|
||||
self._actors_to_export.append((key, actor_class_info))
|
||||
# This caching code path is currently not used because we only
|
||||
# export actor class definitions lazily when we instantiate the
|
||||
# actor for the first time.
|
||||
assert False, "This should be unreachable."
|
||||
else:
|
||||
self._publish_actor_class_to_key(key, actor_class_info)
|
||||
# TODO(rkn): Currently we allow actor classes to be defined
|
||||
# within tasks. I tried to disable this, but it may be necessary
|
||||
# because of https://github.com/ray-project/ray/issues/1146.
|
||||
|
||||
def fetch_and_register_actor(self, actor_class_key):
|
||||
"""Import an actor.
|
||||
|
||||
This will be called by the worker's import thread when the worker
|
||||
receives the actor_class export, assuming that the worker is an actor
|
||||
for that class.
|
||||
|
||||
Args:
|
||||
actor_class_key: The key in Redis to use to fetch the actor.
|
||||
worker: The worker to use.
|
||||
"""
|
||||
actor_id_str = self._worker.actor_id
|
||||
(driver_id, class_id, class_name, module, pickled_class,
|
||||
checkpoint_interval,
|
||||
actor_method_names) = self._worker.redis_client.hmget(
|
||||
actor_class_key, [
|
||||
"driver_id", "class_id", "class_name", "module", "class",
|
||||
"checkpoint_interval", "actor_method_names"
|
||||
])
|
||||
|
||||
class_name = decode(class_name)
|
||||
module = decode(module)
|
||||
checkpoint_interval = int(checkpoint_interval)
|
||||
actor_method_names = json.loads(decode(actor_method_names))
|
||||
|
||||
# Create a temporary actor with some temporary methods so that if
|
||||
# the actor fails to be unpickled, the temporary actor can be used
|
||||
# (just to produce error messages and to prevent the driver from
|
||||
# hanging).
|
||||
class TemporaryActor(object):
|
||||
pass
|
||||
|
||||
self._worker.actors[actor_id_str] = TemporaryActor()
|
||||
self._worker.actor_checkpoint_interval = checkpoint_interval
|
||||
|
||||
def temporary_actor_method(*xs):
|
||||
raise Exception(
|
||||
"The actor with name {} failed to be imported, "
|
||||
"and so cannot execute this method".format(class_name))
|
||||
|
||||
# Register the actor method executors.
|
||||
for actor_method_name in actor_method_names:
|
||||
function_id = (
|
||||
FunctionActorManager.compute_actor_method_function_id(
|
||||
class_name, actor_method_name).id())
|
||||
temporary_executor = self._make_actor_method_executor(
|
||||
actor_method_name,
|
||||
temporary_actor_method,
|
||||
actor_imported=False)
|
||||
self._function_execution_info[driver_id][function_id] = (
|
||||
FunctionExecutionInfo(
|
||||
function=temporary_executor,
|
||||
function_name=actor_method_name,
|
||||
max_calls=0))
|
||||
self._num_task_executions[driver_id][function_id] = 0
|
||||
|
||||
try:
|
||||
unpickled_class = pickle.loads(pickled_class)
|
||||
self._worker.actor_class = unpickled_class
|
||||
except Exception:
|
||||
# If an exception was thrown when the actor was imported, we record
|
||||
# the traceback and notify the scheduler of the failure.
|
||||
traceback_str = ray.utils.format_error_message(
|
||||
traceback.format_exc())
|
||||
# Log the error message.
|
||||
push_error_to_driver(
|
||||
self._worker,
|
||||
ray_constants.REGISTER_ACTOR_PUSH_ERROR,
|
||||
traceback_str,
|
||||
driver_id,
|
||||
data={"actor_id": actor_id_str})
|
||||
# TODO(rkn): In the future, it might make sense to have the worker
|
||||
# exit here. However, currently that would lead to hanging if
|
||||
# someone calls ray.get on a method invoked on the actor.
|
||||
else:
|
||||
# TODO(pcm): Why is the below line necessary?
|
||||
unpickled_class.__module__ = module
|
||||
self._worker.actors[actor_id_str] = unpickled_class.__new__(
|
||||
unpickled_class)
|
||||
|
||||
actor_methods = inspect.getmembers(
|
||||
unpickled_class, predicate=is_function_or_method)
|
||||
for actor_method_name, actor_method in actor_methods:
|
||||
function_id = (
|
||||
FunctionActorManager.compute_actor_method_function_id(
|
||||
class_name, actor_method_name).id())
|
||||
executor = self._make_actor_method_executor(
|
||||
actor_method_name, actor_method, actor_imported=True)
|
||||
self._function_execution_info[driver_id][function_id] = (
|
||||
FunctionExecutionInfo(
|
||||
function=executor,
|
||||
function_name=actor_method_name,
|
||||
max_calls=0))
|
||||
# We do not set function_properties[driver_id][function_id]
|
||||
# because we currently do need the actor worker to submit new
|
||||
# tasks for the actor.
|
||||
|
||||
def _make_actor_method_executor(self, method_name, method, actor_imported):
|
||||
"""Make an executor that wraps a user-defined actor method.
|
||||
|
||||
The wrapped method updates the worker's internal state and performs any
|
||||
necessary checkpointing operations.
|
||||
|
||||
Args:
|
||||
worker (Worker): The worker that is executing the actor.
|
||||
method_name (str): The name of the actor method.
|
||||
method (instancemethod): The actor method to wrap. This should be a
|
||||
method defined on the actor class and should therefore take an
|
||||
instance of the actor as the first argument.
|
||||
actor_imported (bool): Whether the actor has been imported.
|
||||
Checkpointing operations will not be run if this is set to
|
||||
False.
|
||||
|
||||
Returns:
|
||||
A function that executes the given actor method on the worker's
|
||||
stored instance of the actor. The function also updates the
|
||||
worker's internal state to record the executed method.
|
||||
"""
|
||||
|
||||
def actor_method_executor(dummy_return_id, actor, *args):
|
||||
# Update the actor's task counter to reflect the task we're about
|
||||
# to execute.
|
||||
self._worker.actor_task_counter += 1
|
||||
|
||||
# If this is the first task to execute on the actor, try to resume
|
||||
# from a checkpoint.
|
||||
if actor_imported and self._worker.actor_task_counter == 1:
|
||||
checkpoint_resumed = ray.actor.restore_and_log_checkpoint(
|
||||
self._worker, actor)
|
||||
if checkpoint_resumed:
|
||||
# NOTE(swang): Since we did not actually execute the
|
||||
# __init__ method, this will put None as the return value.
|
||||
# If the __init__ method is supposed to return multiple
|
||||
# values, an exception will be logged.
|
||||
return
|
||||
|
||||
# Determine whether we should checkpoint the actor.
|
||||
checkpointing_on = (actor_imported
|
||||
and self._worker.actor_checkpoint_interval > 0)
|
||||
# We should checkpoint the actor if user checkpointing is on, we've
|
||||
# executed checkpoint_interval tasks since the last checkpoint, and
|
||||
# the method we're about to execute is not a checkpoint.
|
||||
save_checkpoint = (checkpointing_on
|
||||
and (self._worker.actor_task_counter %
|
||||
self._worker.actor_checkpoint_interval == 0
|
||||
and method_name != "__ray_checkpoint__"))
|
||||
|
||||
# Execute the assigned method and save a checkpoint if necessary.
|
||||
try:
|
||||
if is_class_method(method):
|
||||
method_returns = method(*args)
|
||||
else:
|
||||
method_returns = method(actor, *args)
|
||||
except Exception:
|
||||
# Save the checkpoint before allowing the method exception
|
||||
# to be thrown.
|
||||
if save_checkpoint:
|
||||
ray.actor.save_and_log_checkpoint(self._worker, actor)
|
||||
raise
|
||||
else:
|
||||
# Save the checkpoint before returning the method's return
|
||||
# values.
|
||||
if save_checkpoint:
|
||||
ray.actor.save_and_log_checkpoint(self._worker, actor)
|
||||
return method_returns
|
||||
|
||||
return actor_method_executor
|
||||
@@ -88,7 +88,8 @@ class ImportThread(object):
|
||||
if key.startswith(b"RemoteFunction"):
|
||||
with profiling.profile(
|
||||
"register_remote_function", worker=self.worker):
|
||||
self.fetch_and_register_remote_function(key)
|
||||
(self.worker.function_actor_manager.
|
||||
fetch_and_register_remote_function(key))
|
||||
elif key.startswith(b"FunctionsToRun"):
|
||||
with profiling.profile(
|
||||
"fetch_and_run_function", worker=self.worker):
|
||||
@@ -103,58 +104,6 @@ class ImportThread(object):
|
||||
else:
|
||||
raise Exception("This code should be unreachable.")
|
||||
|
||||
def fetch_and_register_remote_function(self, key):
|
||||
"""Import a remote function."""
|
||||
from ray.worker import FunctionExecutionInfo
|
||||
(driver_id, function_id_str, function_name, serialized_function,
|
||||
num_return_vals, module, resources,
|
||||
max_calls) = self.redis_client.hmget(key, [
|
||||
"driver_id", "function_id", "name", "function", "num_return_vals",
|
||||
"module", "resources", "max_calls"
|
||||
])
|
||||
function_id = ray.ObjectID(function_id_str)
|
||||
function_name = utils.decode(function_name)
|
||||
max_calls = int(max_calls)
|
||||
module = utils.decode(module)
|
||||
|
||||
# This is a placeholder in case the function can't be unpickled. This
|
||||
# will be overwritten if the function is successfully registered.
|
||||
def f():
|
||||
raise Exception("This function was not imported properly.")
|
||||
|
||||
self.worker.function_execution_info[driver_id][function_id.id()] = (
|
||||
FunctionExecutionInfo(
|
||||
function=f, function_name=function_name, max_calls=max_calls))
|
||||
self.worker.num_task_executions[driver_id][function_id.id()] = 0
|
||||
|
||||
try:
|
||||
function = pickle.loads(serialized_function)
|
||||
except Exception:
|
||||
# If an exception was thrown when the remote function was imported,
|
||||
# we record the traceback and notify the scheduler of the failure.
|
||||
traceback_str = utils.format_error_message(traceback.format_exc())
|
||||
# Log the error message.
|
||||
utils.push_error_to_driver(
|
||||
self.worker,
|
||||
ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR,
|
||||
traceback_str,
|
||||
driver_id=driver_id,
|
||||
data={
|
||||
"function_id": function_id.id(),
|
||||
"function_name": function_name
|
||||
})
|
||||
else:
|
||||
# TODO(rkn): Why is the below line necessary?
|
||||
function.__module__ = module
|
||||
self.worker.function_execution_info[driver_id][
|
||||
function_id.id()] = (FunctionExecutionInfo(
|
||||
function=function,
|
||||
function_name=function_name,
|
||||
max_calls=max_calls))
|
||||
# Add the function to the function table.
|
||||
self.redis_client.rpush(b"FunctionTable:" + function_id.id(),
|
||||
self.worker.worker_id)
|
||||
|
||||
def fetch_and_execute_function_to_run(self, key):
|
||||
"""Run on arbitrary function on the worker."""
|
||||
(driver_id, serialized_function,
|
||||
|
||||
@@ -22,7 +22,7 @@ def compute_function_id(function):
|
||||
func: The actual function.
|
||||
|
||||
Returns:
|
||||
This returns the function ID.
|
||||
Raw bytes of the function id
|
||||
"""
|
||||
function_id_hash = hashlib.sha1()
|
||||
# Include the function module and name in the hash.
|
||||
@@ -39,8 +39,6 @@ def compute_function_id(function):
|
||||
# Compute the function ID.
|
||||
function_id = function_id_hash.digest()
|
||||
assert len(function_id) == ray_constants.ID_SIZE
|
||||
function_id = ray.ObjectID(function_id)
|
||||
|
||||
return function_id
|
||||
|
||||
|
||||
@@ -72,7 +70,7 @@ class RemoteFunction(object):
|
||||
# TODO(rkn): We store the function ID as a string, so that
|
||||
# RemoteFunction objects can be pickled. We should undo this when
|
||||
# we allow ObjectIDs to be pickled.
|
||||
self._function_id = compute_function_id(self._function).id()
|
||||
self._function_id = compute_function_id(function)
|
||||
self._function_name = (
|
||||
self._function.__module__ + '.' + self._function.__name__)
|
||||
self._num_cpus = (DEFAULT_REMOTE_FUNCTION_CPUS
|
||||
@@ -90,11 +88,7 @@ class RemoteFunction(object):
|
||||
|
||||
# # Export the function.
|
||||
worker = ray.worker.get_global_worker()
|
||||
if worker.mode == ray.worker.SCRIPT_MODE:
|
||||
self._export()
|
||||
elif worker.mode is None:
|
||||
worker.cached_remote_functions_and_actors.append(
|
||||
("remote_function", self))
|
||||
worker.function_actor_manager.export(self)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise Exception("Remote functions cannot be called directly. Instead "
|
||||
@@ -141,9 +135,3 @@ class RemoteFunction(object):
|
||||
return object_ids[0]
|
||||
elif len(object_ids) > 1:
|
||||
return object_ids
|
||||
|
||||
def _export(self):
|
||||
worker = ray.worker.get_global_worker()
|
||||
worker.export_remote_function(
|
||||
ray.ObjectID(self._function_id), self._function_name,
|
||||
self._function, self._max_calls, self)
|
||||
|
||||
@@ -5,6 +5,7 @@ from __future__ import print_function
|
||||
import binascii
|
||||
import functools
|
||||
import hashlib
|
||||
import inspect
|
||||
import numpy as np
|
||||
import os
|
||||
import subprocess
|
||||
@@ -144,6 +145,23 @@ def is_cython(obj):
|
||||
(hasattr(obj, "__func__") and check_cython(obj.__func__))
|
||||
|
||||
|
||||
def is_function_or_method(obj):
|
||||
"""Check if an object is a function or method.
|
||||
|
||||
Args:
|
||||
obj: The Python object in question.
|
||||
|
||||
Returns:
|
||||
True if the object is an function or method.
|
||||
"""
|
||||
return (inspect.isfunction(obj) or inspect.ismethod(obj) or is_cython(obj))
|
||||
|
||||
|
||||
def is_class_method(f):
|
||||
"""Returns whether the given method is a class_method."""
|
||||
return hasattr(f, "__self__") and f.__self__ is not None
|
||||
|
||||
|
||||
def random_string():
|
||||
"""Generate a random string to use as an ID.
|
||||
|
||||
|
||||
+26
-159
@@ -3,7 +3,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import atexit
|
||||
import collections
|
||||
import colorama
|
||||
import hashlib
|
||||
import inspect
|
||||
@@ -33,6 +32,7 @@ import ray.plasma
|
||||
import ray.ray_constants as ray_constants
|
||||
from ray import import_thread
|
||||
from ray import profiling
|
||||
from ray.function_manager import FunctionActorManager
|
||||
from ray.utils import (
|
||||
binary_to_hex,
|
||||
check_oversized_pickle,
|
||||
@@ -176,11 +176,6 @@ class RayGetArgumentError(Exception):
|
||||
self.task_error))
|
||||
|
||||
|
||||
FunctionExecutionInfo = collections.namedtuple(
|
||||
"FunctionExecutionInfo", ["function", "function_name", "max_calls"])
|
||||
"""FunctionExecutionInfo: A named tuple storing remote function information."""
|
||||
|
||||
|
||||
class Worker(object):
|
||||
"""A class used to define the control flow of a worker process.
|
||||
|
||||
@@ -189,19 +184,9 @@ class Worker(object):
|
||||
functions outside of this class are considered exposed.
|
||||
|
||||
Attributes:
|
||||
function_execution_info (Dict[str, FunctionExecutionInfo]): A
|
||||
dictionary mapping the name of a remote function to the remote
|
||||
function itself. This is the set of remote functions that can be
|
||||
executed by this worker.
|
||||
connected (bool): True if Ray has been started and False otherwise.
|
||||
mode: The mode of the worker. One of SCRIPT_MODE, LOCAL_MODE, and
|
||||
WORKER_MODE.
|
||||
cached_remote_functions_and_actors: A list of information for exporting
|
||||
remote functions and actor classes definitions that were defined
|
||||
before the worker called connect. When the worker eventually does
|
||||
call connect, if it is a driver, it will export these functions and
|
||||
actors. If cached_remote_functions_and_actors is None, that means
|
||||
that connect has been called already.
|
||||
cached_functions_to_run (List): A list of functions to run on all of
|
||||
the workers that should be exported as soon as connect is called.
|
||||
profiler: the profiler used to aggregate profiling information.
|
||||
@@ -216,24 +201,15 @@ class Worker(object):
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize a Worker object."""
|
||||
# This field is a dictionary that maps a driver ID to a dictionary of
|
||||
# functions (and information about those functions) that have been
|
||||
# registered for that driver (this inner dictionary maps function IDs
|
||||
# to a FunctionExecutionInfo object. This should only be used on
|
||||
# workers that execute remote functions.
|
||||
self.function_execution_info = collections.defaultdict(lambda: {})
|
||||
# This is a dictionary mapping driver ID to a dictionary that maps
|
||||
# remote function IDs for that driver to a counter of the number of
|
||||
# times that remote function has been executed on this worker. The
|
||||
# counter is incremented every time the function is executed on this
|
||||
# worker. When the counter reaches the maximum number of executions
|
||||
# allowed for a particular function, the worker is killed.
|
||||
self.num_task_executions = collections.defaultdict(lambda: {})
|
||||
self.connected = False
|
||||
self.mode = None
|
||||
self.cached_remote_functions_and_actors = []
|
||||
self.cached_functions_to_run = []
|
||||
self.fetch_and_register_actor = None
|
||||
self.actor_init_error = None
|
||||
self.make_actor = None
|
||||
self.actors = {}
|
||||
@@ -255,6 +231,7 @@ class Worker(object):
|
||||
self.serialization_context_map = {}
|
||||
# Identity of the driver that this worker is processing.
|
||||
self.task_driver_id = None
|
||||
self.function_actor_manager = FunctionActorManager(self)
|
||||
|
||||
def mark_actor_init_failed(self, error):
|
||||
"""Called to mark this actor as failed during initialization."""
|
||||
@@ -674,57 +651,6 @@ class Worker(object):
|
||||
|
||||
return task.returns()
|
||||
|
||||
def export_remote_function(self, function_id, function_name, function,
|
||||
max_calls, decorated_function):
|
||||
"""Export a remote function.
|
||||
|
||||
Args:
|
||||
function_id: The ID of the function.
|
||||
function_name: The name of the function.
|
||||
function: The raw undecorated function to export.
|
||||
max_calls: The maximum number of times a given worker can execute
|
||||
this function before exiting.
|
||||
decorated_function: The decorated function (this is used to enable
|
||||
the remote function to recursively call itself).
|
||||
"""
|
||||
if self.mode != SCRIPT_MODE:
|
||||
raise Exception("export_remote_function can only be called on a "
|
||||
"driver.")
|
||||
|
||||
key = (b"RemoteFunction:" + self.task_driver_id.id() + b":" +
|
||||
function_id.id())
|
||||
|
||||
# Work around limitations of Python pickling.
|
||||
function_name_global_valid = function.__name__ in function.__globals__
|
||||
function_name_global_value = function.__globals__.get(
|
||||
function.__name__)
|
||||
# Allow the function to reference itself as a global variable
|
||||
if not is_cython(function):
|
||||
function.__globals__[function.__name__] = decorated_function
|
||||
try:
|
||||
pickled_function = pickle.dumps(function)
|
||||
finally:
|
||||
# Undo our changes
|
||||
if function_name_global_valid:
|
||||
function.__globals__[function.__name__] = (
|
||||
function_name_global_value)
|
||||
else:
|
||||
del function.__globals__[function.__name__]
|
||||
|
||||
check_oversized_pickle(pickled_function, function_name,
|
||||
"remote function", self)
|
||||
|
||||
self.redis_client.hmset(
|
||||
key, {
|
||||
"driver_id": self.task_driver_id.id(),
|
||||
"function_id": function_id.id(),
|
||||
"name": function_name,
|
||||
"module": function.__module__,
|
||||
"function": pickled_function,
|
||||
"max_calls": max_calls
|
||||
})
|
||||
self.redis_client.rpush("Exports", key)
|
||||
|
||||
def run_function_on_all_workers(self, function,
|
||||
run_on_other_drivers=False):
|
||||
"""Run arbitrary code on all of the workers.
|
||||
@@ -783,47 +709,6 @@ class Worker(object):
|
||||
# operations into a transaction (or by implementing a custom
|
||||
# command that does all three things).
|
||||
|
||||
def _wait_for_function(self, function_id, driver_id, timeout=10):
|
||||
"""Wait until the function to be executed is present on this worker.
|
||||
|
||||
This method will simply loop until the import thread has imported the
|
||||
relevant function. If we spend too long in this loop, that may indicate
|
||||
a problem somewhere and we will push an error message to the user.
|
||||
|
||||
If this worker is an actor, then this will wait until the actor has
|
||||
been defined.
|
||||
|
||||
Args:
|
||||
function_id (str): The ID of the function that we want to execute.
|
||||
driver_id (str): The ID of the driver to push the error message to
|
||||
if this times out.
|
||||
"""
|
||||
start_time = time.time()
|
||||
# Only send the warning once.
|
||||
warning_sent = False
|
||||
while True:
|
||||
with self.lock:
|
||||
if (self.actor_id == NIL_ACTOR_ID
|
||||
and (function_id.id() in
|
||||
self.function_execution_info[driver_id])):
|
||||
break
|
||||
elif self.actor_id != NIL_ACTOR_ID and (
|
||||
self.actor_id in self.actors):
|
||||
break
|
||||
if time.time() - start_time > timeout:
|
||||
warning_message = ("This worker was asked to execute a "
|
||||
"function that it does not have "
|
||||
"registered. You may have to restart "
|
||||
"Ray.")
|
||||
if not warning_sent:
|
||||
ray.utils.push_error_to_driver(
|
||||
self,
|
||||
ray_constants.WAIT_FOR_FUNCTION_PUSH_ERROR,
|
||||
warning_message,
|
||||
driver_id=driver_id)
|
||||
warning_sent = True
|
||||
time.sleep(0.001)
|
||||
|
||||
def _get_arguments_for_execution(self, function_name, serialized_args):
|
||||
"""Retrieve the arguments for the remote function.
|
||||
|
||||
@@ -891,7 +776,7 @@ class Worker(object):
|
||||
|
||||
self.put_object(object_ids[i], outputs[i])
|
||||
|
||||
def _process_task(self, task):
|
||||
def _process_task(self, task, function_execution_info):
|
||||
"""Execute a task assigned to this worker.
|
||||
|
||||
This method deserializes a task from the scheduler, and attempts to
|
||||
@@ -913,10 +798,8 @@ class Worker(object):
|
||||
return_object_ids = task.returns()
|
||||
if task.actor_id().id() != NIL_ACTOR_ID:
|
||||
dummy_return_id = return_object_ids.pop()
|
||||
function_executor = self.function_execution_info[
|
||||
self.task_driver_id.id()][function_id.id()].function
|
||||
function_name = self.function_execution_info[self.task_driver_id.id()][
|
||||
function_id.id()].function_name
|
||||
function_executor = function_execution_info.function
|
||||
function_name = function_execution_info.function_name
|
||||
|
||||
# Get task arguments from the object store.
|
||||
try:
|
||||
@@ -926,12 +809,12 @@ class Worker(object):
|
||||
arguments = self._get_arguments_for_execution(
|
||||
function_name, args)
|
||||
except (RayGetError, RayGetArgumentError) as e:
|
||||
self._handle_process_task_failure(function_id, return_object_ids,
|
||||
e, None)
|
||||
self._handle_process_task_failure(function_id, function_name,
|
||||
return_object_ids, e, None)
|
||||
return
|
||||
except Exception as e:
|
||||
self._handle_process_task_failure(
|
||||
function_id, return_object_ids, e,
|
||||
function_id, function_name, return_object_ids, e,
|
||||
ray.utils.format_error_message(traceback.format_exc()))
|
||||
return
|
||||
|
||||
@@ -950,8 +833,9 @@ class Worker(object):
|
||||
task_exception = task.actor_id().id() == NIL_ACTOR_ID
|
||||
traceback_str = ray.utils.format_error_message(
|
||||
traceback.format_exc(), task_exception=task_exception)
|
||||
self._handle_process_task_failure(function_id, return_object_ids,
|
||||
e, traceback_str)
|
||||
self._handle_process_task_failure(function_id, function_name,
|
||||
return_object_ids, e,
|
||||
traceback_str)
|
||||
return
|
||||
|
||||
# Store the outputs in the local object store.
|
||||
@@ -966,13 +850,11 @@ class Worker(object):
|
||||
self._store_outputs_in_objstore(return_object_ids, outputs)
|
||||
except Exception as e:
|
||||
self._handle_process_task_failure(
|
||||
function_id, return_object_ids, e,
|
||||
function_id, function_name, return_object_ids, e,
|
||||
ray.utils.format_error_message(traceback.format_exc()))
|
||||
|
||||
def _handle_process_task_failure(self, function_id, return_object_ids,
|
||||
error, backtrace):
|
||||
function_name = self.function_execution_info[self.task_driver_id.id()][
|
||||
function_id.id()].function_name
|
||||
def _handle_process_task_failure(self, function_id, function_name,
|
||||
return_object_ids, error, backtrace):
|
||||
failure_object = RayTaskError(function_name, error, backtrace)
|
||||
failure_objects = [
|
||||
failure_object for _ in range(len(return_object_ids))
|
||||
@@ -1014,7 +896,7 @@ class Worker(object):
|
||||
time.sleep(0.001)
|
||||
|
||||
with self.lock:
|
||||
self.fetch_and_register_actor(key, self)
|
||||
self.function_actor_manager.fetch_and_register_actor(key)
|
||||
|
||||
def _wait_for_and_process_task(self, task):
|
||||
"""Wait for a task to be ready and process the task.
|
||||
@@ -1031,11 +913,8 @@ class Worker(object):
|
||||
self._become_actor(task)
|
||||
return
|
||||
|
||||
# Wait until the function to be executed has actually been registered
|
||||
# on this worker. We will push warnings to the user if we spend too
|
||||
# long in this loop.
|
||||
with profiling.profile("wait_for_function", worker=self):
|
||||
self._wait_for_function(function_id, driver_id)
|
||||
execution_info = self.function_actor_manager.get_execution_info(
|
||||
driver_id, function_id)
|
||||
|
||||
# Execute the task.
|
||||
# TODO(rkn): Consider acquiring this lock with a timeout and pushing a
|
||||
@@ -1043,9 +922,7 @@ class Worker(object):
|
||||
# because that may indicate that the system is hanging, and it'd be
|
||||
# good to know where the system is hanging.
|
||||
with self.lock:
|
||||
|
||||
function_name = (self.function_execution_info[driver_id][
|
||||
function_id.id()]).function_name
|
||||
function_name = execution_info.function_name
|
||||
if not self.use_raylet:
|
||||
extra_data = {
|
||||
"function_name": function_name,
|
||||
@@ -1058,7 +935,7 @@ class Worker(object):
|
||||
"task_id": task.task_id().hex()
|
||||
}
|
||||
with profiling.profile("task", extra_data=extra_data, worker=self):
|
||||
self._process_task(task)
|
||||
self._process_task(task, execution_info)
|
||||
|
||||
# In the non-raylet code path, push all of the log events to the global
|
||||
# state store. In the raylet code path, this is done periodically in a
|
||||
@@ -1067,11 +944,11 @@ class Worker(object):
|
||||
self.profiler.flush_profile_data()
|
||||
|
||||
# Increase the task execution counter.
|
||||
self.num_task_executions[driver_id][function_id.id()] += 1
|
||||
self.function_actor_manager.increase_task_counter(
|
||||
driver_id, function_id.id())
|
||||
|
||||
reached_max_executions = (
|
||||
self.num_task_executions[driver_id][function_id.id()] == self.
|
||||
function_execution_info[driver_id][function_id.id()].max_calls)
|
||||
reached_max_executions = (self.function_actor_manager.get_task_counter(
|
||||
driver_id, function_id.id()) == execution_info.max_calls)
|
||||
if reached_max_executions:
|
||||
self.local_scheduler_client.disconnect()
|
||||
os._exit(0)
|
||||
@@ -2112,7 +1989,6 @@ def connect(info,
|
||||
error_message = "Perhaps you called ray.init twice by accident?"
|
||||
assert not worker.connected, error_message
|
||||
assert worker.cached_functions_to_run is not None, error_message
|
||||
assert worker.cached_remote_functions_and_actors is not None, error_message
|
||||
# Initialize some fields.
|
||||
worker.worker_id = random_string()
|
||||
|
||||
@@ -2350,18 +2226,9 @@ def connect(info,
|
||||
# Export cached functions_to_run.
|
||||
for function in worker.cached_functions_to_run:
|
||||
worker.run_function_on_all_workers(function)
|
||||
# Export cached remote functions to the workers.
|
||||
for cached_type, info in worker.cached_remote_functions_and_actors:
|
||||
if cached_type == "remote_function":
|
||||
info._export()
|
||||
elif cached_type == "actor":
|
||||
(key, actor_class_info) = info
|
||||
ray.actor.publish_actor_class_to_key(key, actor_class_info,
|
||||
worker)
|
||||
else:
|
||||
assert False, "This code should be unreachable."
|
||||
# Export cached remote functions and actors to the workers.
|
||||
worker.function_actor_manager.export_cached()
|
||||
worker.cached_functions_to_run = None
|
||||
worker.cached_remote_functions_and_actors = None
|
||||
|
||||
|
||||
def disconnect(worker=global_worker):
|
||||
@@ -2372,7 +2239,7 @@ def disconnect(worker=global_worker):
|
||||
# tests.
|
||||
worker.connected = False
|
||||
worker.cached_functions_to_run = []
|
||||
worker.cached_remote_functions_and_actors = []
|
||||
worker.function_actor_manager.reset_cache()
|
||||
worker.serialization_context_map.clear()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user