Files
ray/python/ray/function_manager.py
T
2018-10-24 16:30:00 -07:00

496 lines
21 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import hashlib
import inspect
import json
import sys
import time
import traceback
from collections import (
namedtuple,
defaultdict,
)
import ray
from ray import profiling
from ray import ray_constants
from ray import cloudpickle as pickle
from ray.utils import (
is_cython,
is_function_or_method,
is_class_method,
check_oversized_pickle,
decode,
format_error_message,
push_error_to_driver,
)
# FunctionExecutionInfo: a named tuple storing remote function information,
# namely the callable to run, its name, and its max_calls value.
FunctionExecutionInfo = namedtuple(
    "FunctionExecutionInfo", "function function_name max_calls")
class FunctionActorManager(object):
    """A class used to export/load remote functions and actors.

    Attributes:
        _worker: The worker that this manager is associated with.
        _functions_to_export: The remote functions to export when the
            worker gets connected. Set to None by `export_cached` once they
            have been exported; `reset_cache` restores an empty list.
        _actors_to_export: The (key, actor_class_info) pairs to export when
            the worker gets connected. Set to None by `export_cached` once
            they have been exported; `reset_cache` restores an empty list.
        _function_execution_info: A map from driver ID to a dict mapping
            function ID to its FunctionExecutionInfo.
        _num_task_executions: A map from driver ID to a dict mapping
            function ID to the number of times that function was executed.
    """

    def __init__(self, worker):
        self._worker = worker
        self._functions_to_export = []
        self._actors_to_export = []
        # This field is a dictionary that maps a driver ID to a dictionary of
        # functions (and information about those functions) that have been
        # registered for that driver (this inner dictionary maps function IDs
        # to a FunctionExecutionInfo object). This should only be used on
        # workers that execute remote functions.
        self._function_execution_info = defaultdict(lambda: {})
        # Maps a driver ID to a dictionary from function ID to the number of
        # times that function has been executed on this worker.
        self._num_task_executions = defaultdict(lambda: {})

    def increase_task_counter(self, driver_id, function_id):
        """Increment the execution counter of a registered function.

        The counter must already have been initialized when the function or
        actor method was registered; otherwise a KeyError is raised.

        Args:
            driver_id: ID of the driver that the function belongs to.
            function_id: The function ID key (the value of `ObjectID.id()`)
                used when the function was registered.
        """
        self._num_task_executions[driver_id][function_id] += 1

    def get_task_counter(self, driver_id, function_id):
        """Return how many times a registered function has been executed.

        Args:
            driver_id: ID of the driver that the function belongs to.
            function_id: The function ID key (the value of `ObjectID.id()`)
                used when the function was registered.

        Returns:
            The number of recorded executions.
        """
        return self._num_task_executions[driver_id][function_id]

    def export_cached(self):
        """Export the remote functions and actors cached before connecting.

        Note: this should be called only once when worker is connected.
        Afterwards both caches are set to None; `reset_cache` re-creates
        them.
        """
        for remote_function in self._functions_to_export:
            self._do_export(remote_function)
        self._functions_to_export = None
        for key, actor_class_info in self._actors_to_export:
            self._publish_actor_class_to_key(key, actor_class_info)
        # Clear the actor cache too, so a repeated call cannot re-publish
        # the same actor classes (mirrors the function cache above).
        self._actors_to_export = None

    def reset_cache(self):
        """Re-create empty caches of functions and actors awaiting export."""
        self._functions_to_export = []
        self._actors_to_export = []

    def export(self, remote_function):
        """Export a remote function.

        If the worker is not connected yet, the function is cached and
        exported later by `export_cached`. Only a driver (SCRIPT_MODE)
        exports functions; on any other connected worker this is a no-op.

        Args:
            remote_function: the RemoteFunction object.
        """
        if self._worker.mode is None:
            # If the worker isn't connected, cache the function
            # and export it later.
            self._functions_to_export.append(remote_function)
            return
        if self._worker.mode != ray.worker.SCRIPT_MODE:
            # Don't need to export if the worker is not a driver.
            return
        self._do_export(remote_function)

    def _do_export(self, remote_function):
        """Pickle a remote function and export it to redis.

        The pickled function is written to the Redis hash
        "RemoteFunction:<driver_id>:<function_id>" and that key is appended
        to the "Exports" list.

        Args:
            remote_function: the RemoteFunction object.
        """
        function = remote_function._function
        if is_cython(function):
            # Cython functions are pickled as-is. Their globals are never
            # modified, so there is nothing to restore afterwards.
            pickled_function = pickle.dumps(function)
        else:
            # Work around limitations of Python pickling: temporarily bind
            # the function's own name in its globals to the RemoteFunction so
            # that the function can reference itself as a global variable.
            function_globals = function.__globals__
            global_name = function.__name__
            name_was_global = global_name in function_globals
            previous_value = function_globals.get(global_name)
            function_globals[global_name] = remote_function
            try:
                pickled_function = pickle.dumps(function)
            finally:
                # Undo our changes.
                if name_was_global:
                    function_globals[global_name] = previous_value
                else:
                    del function_globals[global_name]

        check_oversized_pickle(pickled_function,
                               remote_function._function_name,
                               "remote function", self._worker)
        driver_id = self._worker.task_driver_id.id()
        key = (b"RemoteFunction:" + driver_id + b":" +
               remote_function._function_id)
        self._worker.redis_client.hmset(
            key, {
                "driver_id": driver_id,
                "function_id": remote_function._function_id,
                "name": remote_function._function_name,
                "module": function.__module__,
                "function": pickled_function,
                "max_calls": remote_function._max_calls
            })
        self._worker.redis_client.rpush("Exports", key)

    def fetch_and_register_remote_function(self, key):
        """Import a remote function from Redis and register it.

        A placeholder that raises is registered first, so the worker has an
        entry for the function even if unpickling fails. On failure, the
        traceback is pushed to the driver. On success, the real function
        replaces the placeholder and this worker is added to the function
        table.

        Args:
            key: The Redis key under which the remote function was exported.
        """
        # num_return_vals and resources are fetched but not used here.
        (driver_id, function_id_str, function_name, serialized_function,
         _num_return_vals, module, _resources,
         max_calls) = self._worker.redis_client.hmget(key, [
             "driver_id", "function_id", "name", "function", "num_return_vals",
             "module", "resources", "max_calls"
         ])
        function_id = ray.ObjectID(function_id_str)
        function_id_key = function_id.id()
        function_name = decode(function_name)
        max_calls = int(max_calls)
        module = decode(module)

        # This is a placeholder in case the function can't be unpickled. This
        # will be overwritten if the function is successfully registered.
        def f():
            raise Exception("This function was not imported properly.")

        self._function_execution_info[driver_id][function_id_key] = (
            FunctionExecutionInfo(
                function=f, function_name=function_name, max_calls=max_calls))
        self._num_task_executions[driver_id][function_id_key] = 0

        try:
            # NOTE(review): this unpickles whatever is stored in Redis, so
            # it trusts the contents of the Redis server.
            function = pickle.loads(serialized_function)
        except Exception:
            # If an exception was thrown when the remote function was imported,
            # we record the traceback and notify the scheduler of the failure.
            traceback_str = format_error_message(traceback.format_exc())
            # Log the error message.
            push_error_to_driver(
                self._worker,
                ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR,
                traceback_str,
                driver_id=driver_id,
                data={
                    "function_id": function_id_key,
                    "function_name": function_name
                })
        else:
            # The below line is necessary. Because in the driver process,
            # if the function is defined in the file where the python script
            # was started from, its module is `__main__`.
            # However in the worker process, the `__main__` module is a
            # different module, which is `default_worker.py`
            function.__module__ = module
            self._function_execution_info[driver_id][function_id_key] = (
                FunctionExecutionInfo(
                    function=function,
                    function_name=function_name,
                    max_calls=max_calls))
            # Add the function to the function table.
            self._worker.redis_client.rpush(
                b"FunctionTable:" + function_id_key, self._worker.worker_id)

    def get_execution_info(self, driver_id, function_id):
        """Get the FunctionExecutionInfo of a remote function.

        Blocks until the function has been registered on this worker.

        Args:
            driver_id: ID of the driver that the function belongs to.
            function_id (ray.ObjectID): ID of the function to get.

        Returns:
            A FunctionExecutionInfo object.
        """
        # Wait until the function to be executed has actually been registered
        # on this worker. We will push warnings to the user if we spend too
        # long in this loop.
        with profiling.profile("wait_for_function", worker=self._worker):
            self._wait_for_function(function_id, driver_id)
        return self._function_execution_info[driver_id][function_id.id()]

    def _wait_for_function(self, function_id, driver_id, timeout=10):
        """Wait until the function to be executed is present on this worker.

        This method will simply loop until the import thread has imported the
        relevant function. If we spend too long in this loop, that may indicate
        a problem somewhere and we will push an error message to the user.

        If this worker is an actor, then this will wait until the actor has
        been defined.

        Args:
            function_id (ray.ObjectID): The ID of the function that we want
                to execute.
            driver_id: The ID of the driver to push the error message to
                if this times out.
            timeout (float): Seconds to wait before pushing a (one-time)
                warning to the driver. Waiting continues afterwards.
        """
        start_time = time.time()
        # Only send the warning once.
        warning_sent = False
        while True:
            with self._worker.lock:
                if self._worker.actor_id == ray.worker.NIL_ACTOR_ID:
                    # A regular worker waits for the function itself.
                    if (function_id.id() in
                            self._function_execution_info[driver_id]):
                        break
                elif self._worker.actor_id in self._worker.actors:
                    # An actor worker waits for the actor to be defined.
                    break
                if not warning_sent and time.time() - start_time > timeout:
                    warning_message = ("This worker was asked to execute a "
                                       "function that it does not have "
                                       "registered. You may have to restart "
                                       "Ray.")
                    push_error_to_driver(
                        self._worker,
                        ray_constants.WAIT_FOR_FUNCTION_PUSH_ERROR,
                        warning_message,
                        driver_id=driver_id)
                    warning_sent = True
            time.sleep(0.001)

    @classmethod
    def compute_actor_method_function_id(cls, class_name, attr):
        """Get the function ID corresponding to an actor method.

        The ID is the SHA-1 digest of the class name followed by the
        attribute name.

        Args:
            class_name (str): The class name of the actor.
            attr (str): The attribute name of the method.

        Returns:
            Function ID corresponding to the method.
        """
        function_id_hash = hashlib.sha1()
        function_id_hash.update(class_name.encode("ascii"))
        function_id_hash.update(attr.encode("ascii"))
        function_id = function_id_hash.digest()
        assert len(function_id) == ray_constants.ID_SIZE
        return ray.ObjectID(function_id)

    def _publish_actor_class_to_key(self, key, actor_class_info):
        """Push an actor class definition to Redis.

        This is factored out as a separate function because it is also called
        on cached actor class definitions when a worker connects for the first
        time.

        Args:
            key: The key to store the actor class info at.
            actor_class_info: Information about the actor class. Its
                "driver_id" entry is overwritten with the current driver ID.
        """
        # We set the driver ID here because it may not have been available when
        # the actor class was defined.
        actor_class_info["driver_id"] = self._worker.task_driver_id.id()
        self._worker.redis_client.hmset(key, actor_class_info)
        self._worker.redis_client.rpush("Exports", key)

    def export_actor_class(self, class_id, Class, actor_method_names,
                           checkpoint_interval):
        """Pickle an actor class and publish it under "ActorClass:<class_id>".

        Args:
            class_id (bytes): The ID of the actor class.
            Class: The actor class to export.
            actor_method_names: The names of the actor's methods.
            checkpoint_interval: The actor's checkpoint interval.
        """
        key = b"ActorClass:" + class_id
        actor_class_info = {
            "class_name": Class.__name__,
            "module": Class.__module__,
            "class": pickle.dumps(Class),
            "checkpoint_interval": checkpoint_interval,
            "actor_method_names": json.dumps(list(actor_method_names))
        }

        check_oversized_pickle(actor_class_info["class"],
                               actor_class_info["class_name"], "actor",
                               self._worker)

        if self._worker.mode is None:
            # This means that 'ray.init()' has not been called yet and so we
            # must cache the actor class definition and export it when
            # 'ray.init()' is called.
            assert self._actors_to_export is not None
            self._actors_to_export.append((key, actor_class_info))
            # This caching code path is currently not used because we only
            # export actor class definitions lazily when we instantiate the
            # actor for the first time.
            assert False, "This should be unreachable."
        else:
            self._publish_actor_class_to_key(key, actor_class_info)
        # TODO(rkn): Currently we allow actor classes to be defined
        # within tasks. I tried to disable this, but it may be necessary
        # because of https://github.com/ray-project/ray/issues/1146.

    def fetch_and_register_actor(self, actor_class_key):
        """Import an actor.

        This will be called by the worker's import thread when the worker
        receives the actor_class export, assuming that the worker is an actor
        for that class.

        Args:
            actor_class_key: The key in Redis to use to fetch the actor.
        """
        actor_id_str = self._worker.actor_id
        (driver_id, _class_id, class_name, module, pickled_class,
         checkpoint_interval,
         actor_method_names) = self._worker.redis_client.hmget(
             actor_class_key, [
                 "driver_id", "class_id", "class_name", "module", "class",
                 "checkpoint_interval", "actor_method_names"
             ])

        class_name = decode(class_name)
        module = decode(module)
        checkpoint_interval = int(checkpoint_interval)
        actor_method_names = json.loads(decode(actor_method_names))

        # In Python 2, json loads strings as unicode, so convert them back to
        # strings.
        if sys.version_info < (3, 0):
            actor_method_names = [
                method_name.encode("ascii")
                for method_name in actor_method_names
            ]

        # Create a temporary actor with some temporary methods so that if
        # the actor fails to be unpickled, the temporary actor can be used
        # (just to produce error messages and to prevent the driver from
        # hanging).
        class TemporaryActor(object):
            pass

        self._worker.actors[actor_id_str] = TemporaryActor()
        self._worker.actor_checkpoint_interval = checkpoint_interval

        def temporary_actor_method(*xs):
            raise Exception(
                "The actor with name {} failed to be imported, "
                "and so cannot execute this method".format(class_name))

        # Register the actor method executors.
        for actor_method_name in actor_method_names:
            function_id = (
                FunctionActorManager.compute_actor_method_function_id(
                    class_name, actor_method_name).id())
            temporary_executor = self._make_actor_method_executor(
                actor_method_name,
                temporary_actor_method,
                actor_imported=False)
            self._function_execution_info[driver_id][function_id] = (
                FunctionExecutionInfo(
                    function=temporary_executor,
                    function_name=actor_method_name,
                    max_calls=0))
            self._num_task_executions[driver_id][function_id] = 0

        try:
            unpickled_class = pickle.loads(pickled_class)
            self._worker.actor_class = unpickled_class
        except Exception:
            # If an exception was thrown when the actor was imported, we record
            # the traceback and notify the scheduler of the failure.
            traceback_str = format_error_message(traceback.format_exc())
            # Log the error message.
            push_error_to_driver(
                self._worker,
                ray_constants.REGISTER_ACTOR_PUSH_ERROR,
                traceback_str,
                driver_id=driver_id,
                data={"actor_id": actor_id_str})
            # TODO(rkn): In the future, it might make sense to have the worker
            # exit here. However, currently that would lead to hanging if
            # someone calls ray.get on a method invoked on the actor.
        else:
            # TODO(pcm): Why is the below line necessary?
            # NOTE(review): presumably for the same `__main__` module reason
            # as in fetch_and_register_remote_function; confirm.
            unpickled_class.__module__ = module
            self._worker.actors[actor_id_str] = unpickled_class.__new__(
                unpickled_class)
            actor_methods = inspect.getmembers(
                unpickled_class, predicate=is_function_or_method)
            for actor_method_name, actor_method in actor_methods:
                function_id = (
                    FunctionActorManager.compute_actor_method_function_id(
                        class_name, actor_method_name).id())
                executor = self._make_actor_method_executor(
                    actor_method_name, actor_method, actor_imported=True)
                self._function_execution_info[driver_id][function_id] = (
                    FunctionExecutionInfo(
                        function=executor,
                        function_name=actor_method_name,
                        max_calls=0))
                # Methods listed in actor_method_names already have a counter
                # (initialized above). Initialize any other method here so
                # that increase_task_counter cannot raise a KeyError.
                self._num_task_executions[driver_id].setdefault(
                    function_id, 0)

    def _make_actor_method_executor(self, method_name, method, actor_imported):
        """Make an executor that wraps a user-defined actor method.

        The wrapped method updates the worker's internal state and performs any
        necessary checkpointing operations.

        Args:
            method_name (str): The name of the actor method.
            method (instancemethod): The actor method to wrap. This should be a
                method defined on the actor class and should therefore take an
                instance of the actor as the first argument.
            actor_imported (bool): Whether the actor has been imported.
                Checkpointing operations will not be run if this is set to
                False.

        Returns:
            A function that executes the given actor method on the worker's
                stored instance of the actor. The function also updates the
                worker's internal state to record the executed method.
        """

        def actor_method_executor(dummy_return_id, actor, *args):
            # Update the actor's task counter to reflect the task we're about
            # to execute.
            self._worker.actor_task_counter += 1

            # If this is the first task to execute on the actor, try to resume
            # from a checkpoint.
            if actor_imported and self._worker.actor_task_counter == 1:
                checkpoint_resumed = ray.actor.restore_and_log_checkpoint(
                    self._worker, actor)
                if checkpoint_resumed:
                    # NOTE(swang): Since we did not actually execute the
                    # __init__ method, this will put None as the return value.
                    # If the __init__ method is supposed to return multiple
                    # values, an exception will be logged.
                    return

            # Determine whether we should checkpoint the actor.
            checkpointing_on = (actor_imported
                                and self._worker.actor_checkpoint_interval > 0)
            # We should checkpoint the actor if user checkpointing is on, we've
            # executed checkpoint_interval tasks since the last checkpoint, and
            # the method we're about to execute is not a checkpoint.
            save_checkpoint = (checkpointing_on
                               and (self._worker.actor_task_counter %
                                    self._worker.actor_checkpoint_interval == 0
                                    and method_name != "__ray_checkpoint__"))

            # Execute the assigned method and save a checkpoint if necessary.
            try:
                if is_class_method(method):
                    method_returns = method(*args)
                else:
                    method_returns = method(actor, *args)
            except Exception:
                # Save the checkpoint before allowing the method exception
                # to be thrown.
                if save_checkpoint:
                    ray.actor.save_and_log_checkpoint(self._worker, actor)
                raise
            else:
                # Save the checkpoint before returning the method's return
                # values.
                if save_checkpoint:
                    ray.actor.save_and_log_checkpoint(self._worker, actor)
                return method_returns

        return actor_method_executor