Move function/actor exporting & loading code to function_manager.py (#3003)

Move the function/actor exporting and loading code into function_manager.py, in preparation for the upcoming function descriptor change for Python.
This commit is contained in:
Yuhong Guo
2018-10-04 07:21:04 +08:00
committed by Robert Nishihara
parent d73ee36e60
commit 9948e8c11b
6 changed files with 559 additions and 502 deletions
+24 -275
View File
@@ -5,31 +5,19 @@ from __future__ import print_function
import copy
import hashlib
import inspect
import json
import traceback
import ray.cloudpickle as pickle
from ray.function_manager import FunctionActorManager
import ray.local_scheduler
import ray.ray_constants as ray_constants
import ray.signature as signature
import ray.worker
from ray.utils import (
decode,
_random_string,
check_oversized_pickle,
is_cython,
push_error_to_driver,
)
from ray.utils import _random_string
DEFAULT_ACTOR_METHOD_NUM_RETURN_VALS = 1
def is_classmethod(f):
    """Report whether the given method is bound to an owner object.

    Args:
        f: The callable to inspect.

    Returns:
        True if ``f`` has a ``__self__`` attribute that is not None, and
        False otherwise.
    """
    owner = getattr(f, "__self__", None)
    return owner is not None
def compute_actor_handle_id(actor_handle_id, num_forks):
"""Deterministically compute an actor handle ID.
@@ -96,24 +84,6 @@ def compute_actor_creation_function_id(class_id):
return ray.ObjectID(class_id)
def compute_actor_method_function_id(class_name, attr):
    """Derive the function ID of an actor method from its names.

    The ID is the SHA-1 digest of the ASCII-encoded class name followed by
    the ASCII-encoded attribute name.

    Args:
        class_name (str): The class name of the actor.
        attr (str): The attribute name of the method.

    Returns:
        Function ID corresponding to the method.
    """
    # Hashing the concatenation is equivalent to two successive updates.
    digest = hashlib.sha1(
        class_name.encode("ascii") + attr.encode("ascii")).digest()
    assert len(digest) == ray_constants.ID_SIZE
    return ray.ObjectID(digest)
def set_actor_checkpoint(worker, actor_id, checkpoint_index, checkpoint,
frontier):
"""Set the most recent checkpoint associated with a given actor ID.
@@ -134,28 +104,6 @@ def set_actor_checkpoint(worker, actor_id, checkpoint_index, checkpoint,
})
def get_actor_checkpoint(worker, actor_id):
    """Look up the most recent checkpoint stored for an actor.

    Args:
        worker: The worker whose Redis client is used for the lookup.
        actor_id: The actor ID of the actor to get the checkpoint for.

    Returns:
        A tuple (checkpoint_index, checkpoint, frontier) read from the
        "Actor:<actor_id>" hash in Redis. The checkpoint index is converted
        to an int when it is present; fields that are missing from the hash
        are returned as None.
    """
    fields = ["checkpoint_index", "checkpoint", "frontier"]
    stored = worker.redis_client.hmget(b"Actor:" + actor_id, fields)
    checkpoint_index, checkpoint, frontier = stored
    if checkpoint_index is None:
        return None, checkpoint, frontier
    return int(checkpoint_index), checkpoint, frontier
def save_and_log_checkpoint(worker, actor):
"""Save a checkpoint on the actor and log any errors.
@@ -205,219 +153,26 @@ def restore_and_log_checkpoint(worker, actor):
return checkpoint_resumed
def make_actor_method_executor(worker, method_name, method, actor_imported):
"""Make an executor that wraps a user-defined actor method.
The wrapped method updates the worker's internal state and performs any
necessary checkpointing operations.
def get_actor_checkpoint(worker, actor_id):
"""Get the most recent checkpoint associated with a given actor ID.
Args:
worker (Worker): The worker that is executing the actor.
method_name (str): The name of the actor method.
method (instancemethod): The actor method to wrap. This should be a
method defined on the actor class and should therefore take an
instance of the actor as the first argument.
actor_imported (bool): Whether the actor has been imported.
Checkpointing operations will not be run if this is set to False.
worker: The worker to use to get the checkpoint.
actor_id: The actor ID of the actor to get the checkpoint for.
Returns:
A function that executes the given actor method on the worker's stored
instance of the actor. The function also updates the worker's
internal state to record the executed method.
If a checkpoint exists, this returns a tuple of the number of tasks
included in the checkpoint, the saved checkpoint state, and the
task frontier at the time of the checkpoint. If no checkpoint
exists, all objects are set to None. The checkpoint index is the .
executed on the actor before the checkpoint was made.
"""
def actor_method_executor(dummy_return_id, actor, *args):
# Update the actor's task counter to reflect the task we're about to
# execute.
worker.actor_task_counter += 1
# If this is the first task to execute on the actor, try to resume from
# a checkpoint.
if actor_imported and worker.actor_task_counter == 1:
checkpoint_resumed = restore_and_log_checkpoint(worker, actor)
if checkpoint_resumed:
# NOTE(swang): Since we did not actually execute the __init__
# method, this will put None as the return value. If the
# __init__ method is supposed to return multiple values, an
# exception will be logged.
return
# Determine whether we should checkpoint the actor.
checkpointing_on = (actor_imported
and worker.actor_checkpoint_interval > 0)
# We should checkpoint the actor if user checkpointing is on, we've
# executed checkpoint_interval tasks since the last checkpoint, and the
# method we're about to execute is not a checkpoint.
save_checkpoint = (
checkpointing_on and
(worker.actor_task_counter % worker.actor_checkpoint_interval == 0
and method_name != "__ray_checkpoint__"))
# Execute the assigned method and save a checkpoint if necessary.
try:
if is_classmethod(method):
method_returns = method(*args)
else:
method_returns = method(actor, *args)
except Exception:
# Save the checkpoint before allowing the method exception to be
# thrown.
if save_checkpoint:
save_and_log_checkpoint(worker, actor)
raise
else:
# Save the checkpoint before returning the method's return values.
if save_checkpoint:
save_and_log_checkpoint(worker, actor)
return method_returns
return actor_method_executor
def fetch_and_register_actor(actor_class_key, worker):
    """Import an actor class and register executors for its methods.

    The exported class is read from Redis. Placeholder executors that raise
    are registered first for every exported method name, so that a failure
    to unpickle the class produces error messages instead of leaving the
    driver hanging. If unpickling succeeds, the placeholders are replaced.

    Args:
        actor_class_key: The key in Redis to use to fetch the actor.
        worker: The worker to use.
    """
    actor_id_str = worker.actor_id
    fields = [
        "driver_id", "class_id", "class_name", "module", "class",
        "checkpoint_interval", "actor_method_names"
    ]
    (driver_id, class_id, class_name, module, pickled_class,
     checkpoint_interval, actor_method_names) = worker.redis_client.hmget(
         actor_class_key, fields)

    class_name = decode(class_name)
    module = decode(module)
    checkpoint_interval = int(checkpoint_interval)
    actor_method_names = json.loads(decode(actor_method_names))

    # Stand-in actor used until (and unless) the real class is unpickled.
    class TemporaryActor(object):
        pass

    worker.actors[actor_id_str] = TemporaryActor()
    worker.actor_checkpoint_interval = checkpoint_interval

    def temporary_actor_method(*xs):
        raise Exception("The actor with name {} failed to be imported, and so "
                        "cannot execute this method".format(class_name))

    def register(name, method, imported):
        # Wrap the method in an executor and record it under its function ID.
        function_id = compute_actor_method_function_id(class_name, name).id()
        wrapped = make_actor_method_executor(
            worker, name, method, actor_imported=imported)
        worker.function_execution_info[driver_id][function_id] = (
            ray.worker.FunctionExecutionInfo(
                function=wrapped, function_name=name, max_calls=0))
        return function_id

    # Register the placeholder executors and reset their task counters.
    for exported_name in actor_method_names:
        placeholder_id = register(exported_name, temporary_actor_method,
                                  False)
        worker.num_task_executions[driver_id][placeholder_id] = 0

    try:
        unpickled_class = pickle.loads(pickled_class)
        worker.actor_class = unpickled_class
    except Exception:
        # The import failed: record the traceback and report it to the
        # driver.
        traceback_str = ray.utils.format_error_message(traceback.format_exc())
        push_error_to_driver(
            worker,
            ray_constants.REGISTER_ACTOR_PUSH_ERROR,
            traceback_str,
            driver_id,
            data={"actor_id": actor_id_str})
        # TODO(rkn): In the future, it might make sense to have the worker exit
        # here. However, currently that would lead to hanging if someone calls
        # ray.get on a method invoked on the actor.
    else:
        # TODO(pcm): Why is the below line necessary?
        unpickled_class.__module__ = module
        worker.actors[actor_id_str] = unpickled_class.__new__(unpickled_class)

        def is_actor_method(member):
            return (inspect.isfunction(member) or inspect.ismethod(member)
                    or is_cython(member))

        for real_name, real_method in inspect.getmembers(
                unpickled_class, predicate=is_actor_method):
            register(real_name, real_method, True)
            # We do not set worker.function_properties[driver_id][function_id]
            # because we currently do need the actor worker to submit new tasks
            # for the actor.
def publish_actor_class_to_key(key, actor_class_info, worker):
    """Write an actor class definition to Redis and announce the export.

    This is a separate function because it is also called on cached actor
    class definitions when a worker connects for the first time.

    Args:
        key: The key to store the actor class info at.
        actor_class_info: Information about the actor class. The
            "driver_id" entry is set on this dictionary in place.
        worker: The worker to use to connect to Redis.
    """
    # The driver ID is filled in here because it may not have been available
    # when the actor class was defined.
    driver_id = worker.task_driver_id.id()
    actor_class_info["driver_id"] = driver_id
    redis_client = worker.redis_client
    redis_client.hmset(key, actor_class_info)
    redis_client.rpush("Exports", key)
def export_actor_class(class_id, Class, actor_method_names,
                       checkpoint_interval, worker):
    """Pickle an actor class and publish it under "ActorClass:<class_id>".

    Args:
        class_id: The ID of the actor class, appended to the Redis key.
        Class: The actor class to pickle and export.
        actor_method_names: The names of the actor's methods; stored as a
            JSON list.
        checkpoint_interval: The checkpoint interval stored with the class.
        worker: The worker used for the export.
    """
    pickled_class = pickle.dumps(Class)
    class_name = Class.__name__
    actor_class_info = {
        "class_name": class_name,
        "module": Class.__module__,
        "class": pickled_class,
        "checkpoint_interval": checkpoint_interval,
        "actor_method_names": json.dumps(list(actor_method_names))
    }
    key = b"ActorClass:" + class_id

    check_oversized_pickle(pickled_class, class_name, "actor", worker)

    if worker.mode is not None:
        publish_actor_class_to_key(key, actor_class_info, worker)
        # TODO(rkn): Currently we allow actor classes to be defined within
        # tasks. I tried to disable this, but it may be necessary because of
        # https://github.com/ray-project/ray/issues/1146.
        return

    # 'ray.init()' has not been called yet, so the actor class definition is
    # cached for export when 'ray.init()' is called.
    assert worker.cached_remote_functions_and_actors is not None
    worker.cached_remote_functions_and_actors.append(
        ("actor", (key, actor_class_info)))
    # This caching code path is currently not used because we only export
    # actor class definitions lazily when we instantiate the actor for the
    # first time.
    assert False, "This should be unreachable."
actor_key = b"Actor:" + actor_id
checkpoint_index, checkpoint, frontier = worker.redis_client.hmget(
actor_key, ["checkpoint_index", "checkpoint", "frontier"])
if checkpoint_index is not None:
checkpoint_index = int(checkpoint_index)
return checkpoint_index, checkpoint, frontier
def method(*args, **kwargs):
@@ -518,13 +273,8 @@ class ActorClass(object):
self._actor_method_cpus = actor_method_cpus
self._exported = False
# Get the actor methods of the given class.
def pred(x):
return (inspect.isfunction(x) or inspect.ismethod(x)
or is_cython(x))
self._actor_methods = inspect.getmembers(
self._modified_class, predicate=pred)
self._modified_class, ray.utils.is_function_or_method)
# Extract the signatures of each of the methods. This will be used
# to catch some errors if the methods are called with inappropriate
# arguments.
@@ -537,7 +287,7 @@ class ActorClass(object):
# don't support, there may not be much the user can do about it.
signature.check_signature_supported(method, warn=True)
self._method_signatures[method_name] = signature.extract_signature(
method, ignore_first=not is_classmethod(method))
method, ignore_first=not ray.utils.is_class_method(method))
# Set the default number of return values for this method.
if hasattr(method, "__ray_num_return_vals__"):
@@ -614,9 +364,9 @@ class ActorClass(object):
else:
# Export the actor.
if not self._exported:
export_actor_class(self._class_id, self._modified_class,
self._actor_method_names,
self._checkpoint_interval, worker)
worker.function_actor_manager.export_actor_class(
self._class_id, self._modified_class,
self._actor_method_names, self._checkpoint_interval)
self._exported = True
resources = ray.utils.resources_from_resource_arguments(
@@ -801,8 +551,8 @@ class ActorHandle(object):
else:
actor_handle_id = self._ray_actor_handle_id
function_id = compute_actor_method_function_id(self._ray_class_name,
method_name)
function_id = FunctionActorManager.compute_actor_method_function_id(
self._ray_class_name, method_name)
object_ids = worker.submit_task(
function_id,
args,
@@ -1068,5 +818,4 @@ def make_actor(cls, num_cpus, num_gpus, resources, actor_method_cpus,
resources, actor_method_cpus)
ray.worker.global_worker.fetch_and_register_actor = fetch_and_register_actor
ray.worker.global_worker.make_actor = make_actor
+486
View File
@@ -0,0 +1,486 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import hashlib
import inspect
import json
import time
import traceback
from collections import (
namedtuple,
defaultdict,
)
import ray
from ray import profiling
from ray import ray_constants
from ray import cloudpickle as pickle
from ray.utils import (
is_cython,
is_function_or_method,
is_class_method,
check_oversized_pickle,
decode,
format_error_message,
push_error_to_driver,
)
_FUNCTION_EXECUTION_INFO_FIELDS = ("function", "function_name", "max_calls")
FunctionExecutionInfo = namedtuple("FunctionExecutionInfo",
                                   _FUNCTION_EXECUTION_INFO_FIELDS)
"""FunctionExecutionInfo: A named tuple storing remote function information."""
class FunctionActorManager(object):
    """A class used to export/load remote functions and actors.

    Attributes:
        _worker: The worker that this manager is associated with.
        _functions_to_export: Remote functions cached by export() while the
            worker is not connected. Set to None by export_cached().
        _actors_to_export: (key, actor_class_info) pairs cached while the
            worker is not connected. Set to None by export_cached().
        _function_execution_info: Maps a driver ID to a dictionary that maps
            function IDs to FunctionExecutionInfo objects.
        _num_task_executions: Maps a driver ID to a dictionary that maps
            function IDs to a task counter.
    """

    def __init__(self, worker):
        """Create a manager bound to the given worker.

        Args:
            worker: The worker whose Redis client, mode, and actor state
                this manager reads and updates.
        """
        self._worker = worker
        self._functions_to_export = []
        self._actors_to_export = []
        # This field is a dictionary that maps a driver ID to a dictionary of
        # functions (and information about those functions) that have been
        # registered for that driver (this inner dictionary maps function IDs
        # to a FunctionExecutionInfo object). This should only be used on
        # workers that execute remote functions.
        self._function_execution_info = defaultdict(dict)
        self._num_task_executions = defaultdict(dict)

    def increase_task_counter(self, driver_id, function_id):
        """Increment the task counter of a registered function by one.

        Args:
            driver_id: ID of the driver that the function belongs to.
            function_id: Key of the function in the counter table.

        Raises:
            KeyError: If no counter exists for the function.
        """
        self._num_task_executions[driver_id][function_id] += 1

    def get_task_counter(self, driver_id, function_id):
        """Return the task counter of a registered function.

        Args:
            driver_id: ID of the driver that the function belongs to.
            function_id: Key of the function in the counter table.

        Raises:
            KeyError: If no counter exists for the function.
        """
        return self._num_task_executions[driver_id][function_id]

    def export_cached(self):
        """Export cached remote functions and actor classes.

        Note: this should be called only once when worker is connected.
        Both caches are set to None afterwards; reset_cache() makes them
        usable again.
        """
        for remote_function in self._functions_to_export:
            self._do_export(remote_function)
        self._functions_to_export = None
        for key, actor_class_info in self._actors_to_export:
            self._publish_actor_class_to_key(key, actor_class_info)
        # Mark the actor cache as consumed too, so both caches are in the
        # same state after the export.
        self._actors_to_export = None

    def reset_cache(self):
        """Replace both export caches with fresh empty lists."""
        self._functions_to_export = []
        self._actors_to_export = []

    def export(self, remote_function):
        """Export a remote function.

        The function is cached if the worker is not connected yet, ignored
        if the worker is connected but not in SCRIPT_MODE, and exported
        immediately otherwise.

        Args:
            remote_function: the RemoteFunction object.
        """
        if self._worker.mode is None:
            # If the worker isn't connected, cache the function
            # and export it later.
            self._functions_to_export.append(remote_function)
            return
        if self._worker.mode != ray.worker.SCRIPT_MODE:
            # Don't need to export if the worker is not a driver.
            return
        self._do_export(remote_function)

    def _do_export(self, remote_function):
        """Pickle a remote function and export it to redis.

        Args:
            remote_function: the RemoteFunction object.
        """
        # Work around limitations of Python pickling.
        function = remote_function._function
        name = function.__name__
        function_globals = function.__globals__
        had_global = name in function_globals
        previous_value = function_globals.get(name)
        # Allow the function to reference itself as a global variable. This
        # is skipped for Cython functions, so the undo step below must be
        # skipped for them as well.
        injected = not is_cython(function)
        if injected:
            function_globals[name] = remote_function
        try:
            pickled_function = pickle.dumps(function)
        finally:
            # Undo our changes, but only if we actually made any. Deleting
            # unconditionally raised KeyError for Cython functions whose
            # name was not already a global.
            if injected:
                if had_global:
                    function_globals[name] = previous_value
                else:
                    del function_globals[name]

        check_oversized_pickle(pickled_function,
                               remote_function._function_name,
                               "remote function", self._worker)

        driver_id = self._worker.task_driver_id.id()
        key = (b"RemoteFunction:" + driver_id + b":" +
               remote_function._function_id)
        self._worker.redis_client.hmset(
            key, {
                "driver_id": driver_id,
                "function_id": remote_function._function_id,
                "name": remote_function._function_name,
                "module": function.__module__,
                "function": pickled_function,
                "max_calls": remote_function._max_calls
            })
        self._worker.redis_client.rpush("Exports", key)

    def fetch_and_register_remote_function(self, key):
        """Import a remote function.

        A placeholder that raises is registered first, so the function ID
        is known even if unpickling fails. On success the placeholder is
        replaced and this worker is appended to the function table.

        Args:
            key: The Redis key of the exported remote function.
        """
        # "num_return_vals" and "resources" are fetched but not used here.
        (driver_id, function_id_str, function_name, serialized_function,
         _num_return_vals, module, _resources,
         max_calls) = self._worker.redis_client.hmget(key, [
             "driver_id", "function_id", "name", "function",
             "num_return_vals", "module", "resources", "max_calls"
         ])
        function_id = ray.ObjectID(function_id_str)
        function_name = decode(function_name)
        max_calls = int(max_calls)
        module = decode(module)

        # This is a placeholder in case the function can't be unpickled. This
        # will be overwritten if the function is successfully registered.
        def f():
            raise Exception("This function was not imported properly.")

        self._function_execution_info[driver_id][function_id.id()] = (
            FunctionExecutionInfo(
                function=f, function_name=function_name, max_calls=max_calls))
        self._num_task_executions[driver_id][function_id.id()] = 0

        try:
            function = pickle.loads(serialized_function)
        except Exception:
            # If an exception was thrown when the remote function was
            # imported, we record the traceback and notify the scheduler of
            # the failure.
            traceback_str = format_error_message(traceback.format_exc())
            # Log the error message.
            push_error_to_driver(
                self._worker,
                ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR,
                traceback_str,
                driver_id=driver_id,
                data={
                    "function_id": function_id.id(),
                    "function_name": function_name
                })
        else:
            # The below line is necessary. Because in the driver process,
            # if the function is defined in the file where the python script
            # was started from, its module is `__main__`.
            # However in the worker process, the `__main__` module is a
            # different module, which is `default_worker.py`
            function.__module__ = module
            self._function_execution_info[driver_id][function_id.id()] = (
                FunctionExecutionInfo(
                    function=function,
                    function_name=function_name,
                    max_calls=max_calls))
            # Add the function to the function table.
            self._worker.redis_client.rpush(
                b"FunctionTable:" + function_id.id(), self._worker.worker_id)

    def get_execution_info(self, driver_id, function_id):
        """Get the FunctionExecutionInfo of a remote function.

        Args:
            driver_id: ID of the driver that the function belongs to.
            function_id: ID of the function to get.

        Returns:
            A FunctionExecutionInfo object.
        """
        # Wait until the function to be executed has actually been registered
        # on this worker. We will push warnings to the user if we spend too
        # long in this loop.
        with profiling.profile("wait_for_function", worker=self._worker):
            self._wait_for_function(function_id, driver_id)
        return self._function_execution_info[driver_id][function_id.id()]

    def _wait_for_function(self, function_id, driver_id, timeout=10):
        """Wait until the function to be executed is present on this worker.

        This method will simply loop until the import thread has imported the
        relevant function. If we spend too long in this loop, that may
        indicate a problem somewhere and we will push an error message to the
        user.

        If this worker is an actor, then this will wait until the actor has
        been defined.

        Args:
            function_id: The ID of the function that we want to execute. Its
                ``id()`` value is looked up in the execution info table.
            driver_id: The ID of the driver to push the error message to
                if this times out.
            timeout: Number of seconds to wait before the warning is pushed.
                The loop keeps waiting after the warning has been sent.
        """
        warning_message = ("This worker was asked to execute a "
                           "function that it does not have "
                           "registered. You may have to restart "
                           "Ray.")
        start_time = time.time()
        # Only send the warning once.
        warning_sent = False
        while True:
            with self._worker.lock:
                if (self._worker.actor_id == ray.worker.NIL_ACTOR_ID
                        and (function_id.id() in
                             self._function_execution_info[driver_id])):
                    break
                elif self._worker.actor_id != ray.worker.NIL_ACTOR_ID and (
                        self._worker.actor_id in self._worker.actors):
                    break
                if time.time() - start_time > timeout:
                    if not warning_sent:
                        push_error_to_driver(
                            self._worker,
                            ray_constants.WAIT_FOR_FUNCTION_PUSH_ERROR,
                            warning_message,
                            driver_id=driver_id)
                    warning_sent = True
            # Sleep outside the lock so the import thread can make progress.
            time.sleep(0.001)

    @classmethod
    def compute_actor_method_function_id(cls, class_name, attr):
        """Get the function ID corresponding to an actor method.

        Args:
            class_name (str): The class name of the actor.
            attr (str): The attribute name of the method.

        Returns:
            Function ID corresponding to the method.
        """
        function_id_hash = hashlib.sha1()
        function_id_hash.update(class_name.encode("ascii"))
        function_id_hash.update(attr.encode("ascii"))
        function_id = function_id_hash.digest()
        assert len(function_id) == ray_constants.ID_SIZE
        return ray.ObjectID(function_id)

    def _publish_actor_class_to_key(self, key, actor_class_info):
        """Push an actor class definition to Redis.

        This is factored out as a separate function because it is also
        called on cached actor class definitions when a worker connects for
        the first time.

        Args:
            key: The key to store the actor class info at.
            actor_class_info: Information about the actor class. The
                "driver_id" entry is set on this dictionary in place.
        """
        # We set the driver ID here because it may not have been available
        # when the actor class was defined.
        actor_class_info["driver_id"] = self._worker.task_driver_id.id()
        self._worker.redis_client.hmset(key, actor_class_info)
        self._worker.redis_client.rpush("Exports", key)

    def export_actor_class(self, class_id, Class, actor_method_names,
                           checkpoint_interval):
        """Pickle an actor class and publish it to Redis.

        Args:
            class_id: The ID of the actor class, appended to the Redis key.
            Class: The actor class to pickle and export.
            actor_method_names: The names of the actor's methods; stored as
                a JSON list.
            checkpoint_interval: The checkpoint interval stored with the
                class.
        """
        key = b"ActorClass:" + class_id
        actor_class_info = {
            "class_name": Class.__name__,
            "module": Class.__module__,
            "class": pickle.dumps(Class),
            "checkpoint_interval": checkpoint_interval,
            "actor_method_names": json.dumps(list(actor_method_names))
        }

        check_oversized_pickle(actor_class_info["class"],
                               actor_class_info["class_name"], "actor",
                               self._worker)

        if self._worker.mode is None:
            # This means that 'ray.init()' has not been called yet and so we
            # must cache the actor class definition and export it when
            # 'ray.init()' is called.
            assert self._actors_to_export is not None
            self._actors_to_export.append((key, actor_class_info))
            # This caching code path is currently not used because we only
            # export actor class definitions lazily when we instantiate the
            # actor for the first time.
            assert False, "This should be unreachable."
        else:
            self._publish_actor_class_to_key(key, actor_class_info)
            # TODO(rkn): Currently we allow actor classes to be defined
            # within tasks. I tried to disable this, but it may be necessary
            # because of https://github.com/ray-project/ray/issues/1146.

    def _register_actor_method(self, driver_id, class_name, method_name,
                               method, actor_imported):
        """Wrap one actor method in an executor and record it.

        Returns the raw function ID under which the executor was stored.
        """
        function_id = (
            FunctionActorManager.compute_actor_method_function_id(
                class_name, method_name).id())
        executor = self._make_actor_method_executor(
            method_name, method, actor_imported=actor_imported)
        self._function_execution_info[driver_id][function_id] = (
            FunctionExecutionInfo(
                function=executor,
                function_name=method_name,
                max_calls=0))
        return function_id

    def fetch_and_register_actor(self, actor_class_key):
        """Import an actor.

        This will be called by the worker's import thread when the worker
        receives the actor_class export, assuming that the worker is an actor
        for that class.

        Args:
            actor_class_key: The key in Redis to use to fetch the actor.
        """
        actor_id_str = self._worker.actor_id
        # "class_id" is fetched but not used here.
        (driver_id, _class_id, class_name, module, pickled_class,
         checkpoint_interval,
         actor_method_names) = self._worker.redis_client.hmget(
             actor_class_key, [
                 "driver_id", "class_id", "class_name", "module", "class",
                 "checkpoint_interval", "actor_method_names"
             ])

        class_name = decode(class_name)
        module = decode(module)
        checkpoint_interval = int(checkpoint_interval)
        actor_method_names = json.loads(decode(actor_method_names))

        # Create a temporary actor with some temporary methods so that if
        # the actor fails to be unpickled, the temporary actor can be used
        # (just to produce error messages and to prevent the driver from
        # hanging).
        class TemporaryActor(object):
            pass

        self._worker.actors[actor_id_str] = TemporaryActor()
        self._worker.actor_checkpoint_interval = checkpoint_interval

        def temporary_actor_method(*xs):
            raise Exception(
                "The actor with name {} failed to be imported, "
                "and so cannot execute this method".format(class_name))

        # Register the actor method executors.
        for actor_method_name in actor_method_names:
            function_id = self._register_actor_method(
                driver_id, class_name, actor_method_name,
                temporary_actor_method, False)
            self._num_task_executions[driver_id][function_id] = 0

        try:
            unpickled_class = pickle.loads(pickled_class)
            self._worker.actor_class = unpickled_class
        except Exception:
            # If an exception was thrown when the actor was imported, we
            # record the traceback and notify the scheduler of the failure.
            traceback_str = format_error_message(traceback.format_exc())
            # Log the error message.
            push_error_to_driver(
                self._worker,
                ray_constants.REGISTER_ACTOR_PUSH_ERROR,
                traceback_str,
                driver_id,
                data={"actor_id": actor_id_str})
            # TODO(rkn): In the future, it might make sense to have the worker
            # exit here. However, currently that would lead to hanging if
            # someone calls ray.get on a method invoked on the actor.
        else:
            # TODO(pcm): Why is the below line necessary?
            unpickled_class.__module__ = module
            self._worker.actors[actor_id_str] = unpickled_class.__new__(
                unpickled_class)

            actor_methods = inspect.getmembers(
                unpickled_class, predicate=is_function_or_method)
            for actor_method_name, actor_method in actor_methods:
                self._register_actor_method(
                    driver_id, class_name, actor_method_name, actor_method,
                    True)
                # We do not set function_properties[driver_id][function_id]
                # because we currently do need the actor worker to submit new
                # tasks for the actor.

    def _make_actor_method_executor(self, method_name, method, actor_imported):
        """Make an executor that wraps a user-defined actor method.

        The wrapped method updates the worker's internal state and performs any
        necessary checkpointing operations.

        Args:
            method_name (str): The name of the actor method.
            method (instancemethod): The actor method to wrap. This should be a
                method defined on the actor class and should therefore take an
                instance of the actor as the first argument.
            actor_imported (bool): Whether the actor has been imported.
                Checkpointing operations will not be run if this is set to
                False.

        Returns:
            A function that executes the given actor method on the worker's
                stored instance of the actor. The function also updates the
                worker's internal state to record the executed method.
        """

        def actor_method_executor(dummy_return_id, actor, *args):
            # Update the actor's task counter to reflect the task we're about
            # to execute.
            self._worker.actor_task_counter += 1

            # If this is the first task to execute on the actor, try to resume
            # from a checkpoint.
            if actor_imported and self._worker.actor_task_counter == 1:
                checkpoint_resumed = ray.actor.restore_and_log_checkpoint(
                    self._worker, actor)
                if checkpoint_resumed:
                    # NOTE(swang): Since we did not actually execute the
                    # __init__ method, this will put None as the return value.
                    # If the __init__ method is supposed to return multiple
                    # values, an exception will be logged.
                    return

            # Determine whether we should checkpoint the actor.
            checkpointing_on = (actor_imported
                                and self._worker.actor_checkpoint_interval > 0)
            # We should checkpoint the actor if user checkpointing is on, we've
            # executed checkpoint_interval tasks since the last checkpoint, and
            # the method we're about to execute is not a checkpoint.
            save_checkpoint = (checkpointing_on
                               and (self._worker.actor_task_counter %
                                    self._worker.actor_checkpoint_interval == 0
                                    and method_name != "__ray_checkpoint__"))

            # Execute the assigned method and save a checkpoint if necessary.
            try:
                if is_class_method(method):
                    method_returns = method(*args)
                else:
                    method_returns = method(actor, *args)
            except Exception:
                # Save the checkpoint before allowing the method exception
                # to be thrown.
                if save_checkpoint:
                    ray.actor.save_and_log_checkpoint(self._worker, actor)
                raise
            else:
                # Save the checkpoint before returning the method's return
                # values.
                if save_checkpoint:
                    ray.actor.save_and_log_checkpoint(self._worker, actor)
                return method_returns

        return actor_method_executor
+2 -53
View File
@@ -88,7 +88,8 @@ class ImportThread(object):
if key.startswith(b"RemoteFunction"):
with profiling.profile(
"register_remote_function", worker=self.worker):
self.fetch_and_register_remote_function(key)
(self.worker.function_actor_manager.
fetch_and_register_remote_function(key))
elif key.startswith(b"FunctionsToRun"):
with profiling.profile(
"fetch_and_run_function", worker=self.worker):
@@ -103,58 +104,6 @@ class ImportThread(object):
else:
raise Exception("This code should be unreachable.")
def fetch_and_register_remote_function(self, key):
    """Import a remote function.

    A placeholder that raises is registered first; it is replaced by the
    real function only if unpickling succeeds, in which case this worker
    is also appended to the function table.

    Args:
        key: The Redis key of the exported remote function.
    """
    from ray.worker import FunctionExecutionInfo

    fields = [
        "driver_id", "function_id", "name", "function", "num_return_vals",
        "module", "resources", "max_calls"
    ]
    (driver_id, function_id_str, function_name, serialized_function,
     num_return_vals, module, resources,
     max_calls) = self.redis_client.hmget(key, fields)
    function_id = ray.ObjectID(function_id_str)
    function_name = utils.decode(function_name)
    max_calls = int(max_calls)
    module = utils.decode(module)

    # Placeholder used if the function cannot be unpickled; overwritten
    # below when the import succeeds.
    def f():
        raise Exception("This function was not imported properly.")

    execution_info = self.worker.function_execution_info
    execution_info[driver_id][function_id.id()] = FunctionExecutionInfo(
        function=f, function_name=function_name, max_calls=max_calls)
    self.worker.num_task_executions[driver_id][function_id.id()] = 0

    try:
        function = pickle.loads(serialized_function)
    except Exception:
        # The import failed: record the traceback and report it to the
        # driver.
        traceback_str = utils.format_error_message(traceback.format_exc())
        error_data = {
            "function_id": function_id.id(),
            "function_name": function_name
        }
        utils.push_error_to_driver(
            self.worker,
            ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR,
            traceback_str,
            driver_id=driver_id,
            data=error_data)
    else:
        # TODO(rkn): Why is the below line necessary?
        function.__module__ = module
        execution_info[driver_id][function_id.id()] = FunctionExecutionInfo(
            function=function,
            function_name=function_name,
            max_calls=max_calls)
        # Add the function to the function table.
        table_key = b"FunctionTable:" + function_id.id()
        self.redis_client.rpush(table_key, self.worker.worker_id)
def fetch_and_execute_function_to_run(self, key):
"""Run on arbitrary function on the worker."""
(driver_id, serialized_function,
+3 -15
View File
@@ -22,7 +22,7 @@ def compute_function_id(function):
func: The actual function.
Returns:
This returns the function ID.
Raw bytes of the function id
"""
function_id_hash = hashlib.sha1()
# Include the function module and name in the hash.
@@ -39,8 +39,6 @@ def compute_function_id(function):
# Compute the function ID.
function_id = function_id_hash.digest()
assert len(function_id) == ray_constants.ID_SIZE
function_id = ray.ObjectID(function_id)
return function_id
@@ -72,7 +70,7 @@ class RemoteFunction(object):
# TODO(rkn): We store the function ID as a string, so that
# RemoteFunction objects can be pickled. We should undo this when
# we allow ObjectIDs to be pickled.
self._function_id = compute_function_id(self._function).id()
self._function_id = compute_function_id(function)
self._function_name = (
self._function.__module__ + '.' + self._function.__name__)
self._num_cpus = (DEFAULT_REMOTE_FUNCTION_CPUS
@@ -90,11 +88,7 @@ class RemoteFunction(object):
# # Export the function.
worker = ray.worker.get_global_worker()
if worker.mode == ray.worker.SCRIPT_MODE:
self._export()
elif worker.mode is None:
worker.cached_remote_functions_and_actors.append(
("remote_function", self))
worker.function_actor_manager.export(self)
def __call__(self, *args, **kwargs):
raise Exception("Remote functions cannot be called directly. Instead "
@@ -141,9 +135,3 @@ class RemoteFunction(object):
return object_ids[0]
elif len(object_ids) > 1:
return object_ids
def _export(self):
    """Export this remote function through the global worker.

    The function ID is stored on this object in raw form, so it is wrapped
    in a ray.ObjectID before being handed to the worker.
    """
    global_worker = ray.worker.get_global_worker()
    function_id = ray.ObjectID(self._function_id)
    global_worker.export_remote_function(
        function_id,
        self._function_name,
        self._function,
        self._max_calls,
        self,
    )
+18
View File
@@ -5,6 +5,7 @@ from __future__ import print_function
import binascii
import functools
import hashlib
import inspect
import numpy as np
import os
import subprocess
@@ -144,6 +145,23 @@ def is_cython(obj):
(hasattr(obj, "__func__") and check_cython(obj.__func__))
def is_function_or_method(obj):
    """Check whether an object is a function or a method.

    Args:
        obj: The Python object in question.

    Returns:
        True if the object is a Python function or a bound method;
        otherwise whatever is_cython(obj) returns.
    """
    # Plain Python functions and bound methods are decided by inspect
    # alone; is_cython is consulted only for everything else.
    if inspect.isfunction(obj) or inspect.ismethod(obj):
        return True
    return is_cython(obj)
def is_class_method(f):
    """Return whether the given callable is bound to an object.

    The check is purely structural: it is True when ``f`` has a ``__self__``
    attribute that is not None. Class methods satisfy this, and so does any
    other bound method.

    Args:
        f: The object to inspect.

    Returns:
        True if ``f.__self__`` exists and is not None, False otherwise.
    """
    if not hasattr(f, "__self__"):
        return False
    return f.__self__ is not None
def random_string():
"""Generate a random string to use as an ID.
+26 -159
View File
@@ -3,7 +3,6 @@ from __future__ import division
from __future__ import print_function
import atexit
import collections
import colorama
import hashlib
import inspect
@@ -33,6 +32,7 @@ import ray.plasma
import ray.ray_constants as ray_constants
from ray import import_thread
from ray import profiling
from ray.function_manager import FunctionActorManager
from ray.utils import (
binary_to_hex,
check_oversized_pickle,
@@ -176,11 +176,6 @@ class RayGetArgumentError(Exception):
self.task_error))
# Record describing one remote function registered on a worker.
# Fields:
#   function: the callable to execute.
#   function_name: the name of the function.
#   max_calls: the maximum number of times a worker may execute the
#       function before exiting.
FunctionExecutionInfo = collections.namedtuple(
"FunctionExecutionInfo", ["function", "function_name", "max_calls"])
"""FunctionExecutionInfo: A named tuple storing remote function information."""
class Worker(object):
"""A class used to define the control flow of a worker process.
@@ -189,19 +184,9 @@ class Worker(object):
functions outside of this class are considered exposed.
Attributes:
function_execution_info (Dict[str, FunctionExecutionInfo]): A
dictionary mapping the name of a remote function to the remote
function itself. This is the set of remote functions that can be
executed by this worker.
connected (bool): True if Ray has been started and False otherwise.
mode: The mode of the worker. One of SCRIPT_MODE, LOCAL_MODE, and
WORKER_MODE.
cached_remote_functions_and_actors: A list of information for exporting
remote functions and actor classes definitions that were defined
before the worker called connect. When the worker eventually does
call connect, if it is a driver, it will export these functions and
actors. If cached_remote_functions_and_actors is None, that means
that connect has been called already.
cached_functions_to_run (List): A list of functions to run on all of
the workers that should be exported as soon as connect is called.
profiler: the profiler used to aggregate profiling information.
@@ -216,24 +201,15 @@ class Worker(object):
def __init__(self):
"""Initialize a Worker object."""
# This field is a dictionary that maps a driver ID to a dictionary of
# functions (and information about those functions) that have been
# registered for that driver (this inner dictionary maps function IDs
# to a FunctionExecutionInfo object. This should only be used on
# workers that execute remote functions.
self.function_execution_info = collections.defaultdict(lambda: {})
# This is a dictionary mapping driver ID to a dictionary that maps
# remote function IDs for that driver to a counter of the number of
# times that remote function has been executed on this worker. The
# counter is incremented every time the function is executed on this
# worker. When the counter reaches the maximum number of executions
# allowed for a particular function, the worker is killed.
self.num_task_executions = collections.defaultdict(lambda: {})
self.connected = False
self.mode = None
self.cached_remote_functions_and_actors = []
self.cached_functions_to_run = []
self.fetch_and_register_actor = None
self.actor_init_error = None
self.make_actor = None
self.actors = {}
@@ -255,6 +231,7 @@ class Worker(object):
self.serialization_context_map = {}
# Identity of the driver that this worker is processing.
self.task_driver_id = None
self.function_actor_manager = FunctionActorManager(self)
def mark_actor_init_failed(self, error):
"""Called to mark this actor as failed during initialization."""
@@ -674,57 +651,6 @@ class Worker(object):
return task.returns()
def export_remote_function(self, function_id, function_name, function,
                           max_calls, decorated_function):
    """Export a remote function.

    The function is pickled and written to a Redis hash, and the hash key
    is appended to the "Exports" list.

    Args:
        function_id: The ID of the function.
        function_name: The name of the function.
        function: The raw undecorated function to export.
        max_calls: The maximum number of times a given worker can execute
            this function before exiting.
        decorated_function: The decorated function (this is used to enable
            the remote function to recursively call itself).

    Raises:
        Exception: If this worker is not in SCRIPT_MODE (i.e. is not a
            driver).
    """
    if self.mode != SCRIPT_MODE:
        raise Exception("export_remote_function can only be called on a "
                        "driver.")

    key = (b"RemoteFunction:" + self.task_driver_id.id() + b":" +
           function_id.id())

    # Work around limitations of Python pickling: temporarily bind the
    # function's own name in its globals to the decorated version so that
    # the function can reference itself as a global variable.
    function_globals = function.__globals__
    name = function.__name__
    name_was_global = name in function_globals
    previous_global_value = function_globals.get(name)
    # Cython functions are left untouched.
    patch_globals = not is_cython(function)
    if patch_globals:
        function_globals[name] = decorated_function
    try:
        pickled_function = pickle.dumps(function)
    finally:
        # Undo our changes, but only if we actually made any. Deleting
        # unconditionally would raise KeyError for a Cython function whose
        # name was never present in its globals.
        if patch_globals:
            if name_was_global:
                function_globals[name] = previous_global_value
            else:
                function_globals.pop(name, None)

    check_oversized_pickle(pickled_function, function_name,
                           "remote function", self)

    self.redis_client.hmset(
        key, {
            "driver_id": self.task_driver_id.id(),
            "function_id": function_id.id(),
            "name": function_name,
            "module": function.__module__,
            "function": pickled_function,
            "max_calls": max_calls
        })
    self.redis_client.rpush("Exports", key)
def run_function_on_all_workers(self, function,
run_on_other_drivers=False):
"""Run arbitrary code on all of the workers.
@@ -783,47 +709,6 @@ class Worker(object):
# operations into a transaction (or by implementing a custom
# command that does all three things).
def _wait_for_function(self, function_id, driver_id, timeout=10):
    """Block until the function to be executed is present on this worker.

    The method polls until the function ID appears in this worker's
    function_execution_info for the given driver. If this worker is an
    actor, it polls until the actor has been defined instead. It never
    gives up: once the timeout has elapsed, a single warning is pushed to
    the driver and polling continues.

    Args:
        function_id: The ID of the function that we want to execute; its
            id() value is looked up in the registered functions.
        driver_id: The ID of the driver to push the error message to if
            this times out.
        timeout: Seconds to wait before pushing the warning.
    """
    began_at = time.time()
    # Only send the warning once.
    already_warned = False
    while True:
        with self.lock:
            if self.actor_id == NIL_ACTOR_ID:
                ready = (function_id.id() in
                         self.function_execution_info[driver_id])
            else:
                ready = self.actor_id in self.actors
            if ready:
                break

            if time.time() - began_at > timeout:
                if not already_warned:
                    ray.utils.push_error_to_driver(
                        self,
                        ray_constants.WAIT_FOR_FUNCTION_PUSH_ERROR,
                        "This worker was asked to execute a function "
                        "that it does not have registered. You may "
                        "have to restart Ray.",
                        driver_id=driver_id)
                already_warned = True
        # Sleep outside the lock between polls.
        time.sleep(0.001)
def _get_arguments_for_execution(self, function_name, serialized_args):
"""Retrieve the arguments for the remote function.
@@ -891,7 +776,7 @@ class Worker(object):
self.put_object(object_ids[i], outputs[i])
def _process_task(self, task):
def _process_task(self, task, function_execution_info):
"""Execute a task assigned to this worker.
This method deserializes a task from the scheduler, and attempts to
@@ -913,10 +798,8 @@ class Worker(object):
return_object_ids = task.returns()
if task.actor_id().id() != NIL_ACTOR_ID:
dummy_return_id = return_object_ids.pop()
function_executor = self.function_execution_info[
self.task_driver_id.id()][function_id.id()].function
function_name = self.function_execution_info[self.task_driver_id.id()][
function_id.id()].function_name
function_executor = function_execution_info.function
function_name = function_execution_info.function_name
# Get task arguments from the object store.
try:
@@ -926,12 +809,12 @@ class Worker(object):
arguments = self._get_arguments_for_execution(
function_name, args)
except (RayGetError, RayGetArgumentError) as e:
self._handle_process_task_failure(function_id, return_object_ids,
e, None)
self._handle_process_task_failure(function_id, function_name,
return_object_ids, e, None)
return
except Exception as e:
self._handle_process_task_failure(
function_id, return_object_ids, e,
function_id, function_name, return_object_ids, e,
ray.utils.format_error_message(traceback.format_exc()))
return
@@ -950,8 +833,9 @@ class Worker(object):
task_exception = task.actor_id().id() == NIL_ACTOR_ID
traceback_str = ray.utils.format_error_message(
traceback.format_exc(), task_exception=task_exception)
self._handle_process_task_failure(function_id, return_object_ids,
e, traceback_str)
self._handle_process_task_failure(function_id, function_name,
return_object_ids, e,
traceback_str)
return
# Store the outputs in the local object store.
@@ -966,13 +850,11 @@ class Worker(object):
self._store_outputs_in_objstore(return_object_ids, outputs)
except Exception as e:
self._handle_process_task_failure(
function_id, return_object_ids, e,
function_id, function_name, return_object_ids, e,
ray.utils.format_error_message(traceback.format_exc()))
def _handle_process_task_failure(self, function_id, return_object_ids,
error, backtrace):
function_name = self.function_execution_info[self.task_driver_id.id()][
function_id.id()].function_name
def _handle_process_task_failure(self, function_id, function_name,
return_object_ids, error, backtrace):
failure_object = RayTaskError(function_name, error, backtrace)
failure_objects = [
failure_object for _ in range(len(return_object_ids))
@@ -1014,7 +896,7 @@ class Worker(object):
time.sleep(0.001)
with self.lock:
self.fetch_and_register_actor(key, self)
self.function_actor_manager.fetch_and_register_actor(key)
def _wait_for_and_process_task(self, task):
"""Wait for a task to be ready and process the task.
@@ -1031,11 +913,8 @@ class Worker(object):
self._become_actor(task)
return
# Wait until the function to be executed has actually been registered
# on this worker. We will push warnings to the user if we spend too
# long in this loop.
with profiling.profile("wait_for_function", worker=self):
self._wait_for_function(function_id, driver_id)
execution_info = self.function_actor_manager.get_execution_info(
driver_id, function_id)
# Execute the task.
# TODO(rkn): Consider acquiring this lock with a timeout and pushing a
@@ -1043,9 +922,7 @@ class Worker(object):
# because that may indicate that the system is hanging, and it'd be
# good to know where the system is hanging.
with self.lock:
function_name = (self.function_execution_info[driver_id][
function_id.id()]).function_name
function_name = execution_info.function_name
if not self.use_raylet:
extra_data = {
"function_name": function_name,
@@ -1058,7 +935,7 @@ class Worker(object):
"task_id": task.task_id().hex()
}
with profiling.profile("task", extra_data=extra_data, worker=self):
self._process_task(task)
self._process_task(task, execution_info)
# In the non-raylet code path, push all of the log events to the global
# state store. In the raylet code path, this is done periodically in a
@@ -1067,11 +944,11 @@ class Worker(object):
self.profiler.flush_profile_data()
# Increase the task execution counter.
self.num_task_executions[driver_id][function_id.id()] += 1
self.function_actor_manager.increase_task_counter(
driver_id, function_id.id())
reached_max_executions = (
self.num_task_executions[driver_id][function_id.id()] == self.
function_execution_info[driver_id][function_id.id()].max_calls)
reached_max_executions = (self.function_actor_manager.get_task_counter(
driver_id, function_id.id()) == execution_info.max_calls)
if reached_max_executions:
self.local_scheduler_client.disconnect()
os._exit(0)
@@ -2112,7 +1989,6 @@ def connect(info,
error_message = "Perhaps you called ray.init twice by accident?"
assert not worker.connected, error_message
assert worker.cached_functions_to_run is not None, error_message
assert worker.cached_remote_functions_and_actors is not None, error_message
# Initialize some fields.
worker.worker_id = random_string()
@@ -2350,18 +2226,9 @@ def connect(info,
# Export cached functions_to_run.
for function in worker.cached_functions_to_run:
worker.run_function_on_all_workers(function)
# Export cached remote functions to the workers.
for cached_type, info in worker.cached_remote_functions_and_actors:
if cached_type == "remote_function":
info._export()
elif cached_type == "actor":
(key, actor_class_info) = info
ray.actor.publish_actor_class_to_key(key, actor_class_info,
worker)
else:
assert False, "This code should be unreachable."
# Export cached remote functions and actors to the workers.
worker.function_actor_manager.export_cached()
worker.cached_functions_to_run = None
worker.cached_remote_functions_and_actors = None
def disconnect(worker=global_worker):
@@ -2372,7 +2239,7 @@ def disconnect(worker=global_worker):
# tests.
worker.connected = False
worker.cached_functions_to_run = []
worker.cached_remote_functions_and_actors = []
worker.function_actor_manager.reset_cache()
worker.serialization_context_map.clear()