diff --git a/python/ray/actor.py b/python/ray/actor.py index 3886e1927..d61fac7d7 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -5,31 +5,19 @@ from __future__ import print_function import copy import hashlib import inspect -import json import traceback import ray.cloudpickle as pickle +from ray.function_manager import FunctionActorManager import ray.local_scheduler import ray.ray_constants as ray_constants import ray.signature as signature import ray.worker -from ray.utils import ( - decode, - _random_string, - check_oversized_pickle, - is_cython, - push_error_to_driver, -) +from ray.utils import _random_string DEFAULT_ACTOR_METHOD_NUM_RETURN_VALS = 1 -def is_classmethod(f): - """Returns whether the given method is a classmethod.""" - - return hasattr(f, "__self__") and f.__self__ is not None - - def compute_actor_handle_id(actor_handle_id, num_forks): """Deterministically compute an actor handle ID. @@ -96,24 +84,6 @@ def compute_actor_creation_function_id(class_id): return ray.ObjectID(class_id) -def compute_actor_method_function_id(class_name, attr): - """Get the function ID corresponding to an actor method. - - Args: - class_name (str): The class name of the actor. - attr (str): The attribute name of the method. - - Returns: - Function ID corresponding to the method. - """ - function_id_hash = hashlib.sha1() - function_id_hash.update(class_name.encode("ascii")) - function_id_hash.update(attr.encode("ascii")) - function_id = function_id_hash.digest() - assert len(function_id) == ray_constants.ID_SIZE - return ray.ObjectID(function_id) - - def set_actor_checkpoint(worker, actor_id, checkpoint_index, checkpoint, frontier): """Set the most recent checkpoint associated with a given actor ID. @@ -134,28 +104,6 @@ def set_actor_checkpoint(worker, actor_id, checkpoint_index, checkpoint, }) -def get_actor_checkpoint(worker, actor_id): - """Get the most recent checkpoint associated with a given actor ID. 
- - Args: - worker: The worker to use to get the checkpoint. - actor_id: The actor ID of the actor to get the checkpoint for. - - Returns: - If a checkpoint exists, this returns a tuple of the number of tasks - included in the checkpoint, the saved checkpoint state, and the - task frontier at the time of the checkpoint. If no checkpoint - exists, all objects are set to None. The checkpoint index is the . - executed on the actor before the checkpoint was made. - """ - actor_key = b"Actor:" + actor_id - checkpoint_index, checkpoint, frontier = worker.redis_client.hmget( - actor_key, ["checkpoint_index", "checkpoint", "frontier"]) - if checkpoint_index is not None: - checkpoint_index = int(checkpoint_index) - return checkpoint_index, checkpoint, frontier - - def save_and_log_checkpoint(worker, actor): """Save a checkpoint on the actor and log any errors. @@ -205,219 +153,26 @@ def restore_and_log_checkpoint(worker, actor): return checkpoint_resumed -def make_actor_method_executor(worker, method_name, method, actor_imported): - """Make an executor that wraps a user-defined actor method. - - The wrapped method updates the worker's internal state and performs any - necessary checkpointing operations. +def get_actor_checkpoint(worker, actor_id): + """Get the most recent checkpoint associated with a given actor ID. Args: - worker (Worker): The worker that is executing the actor. - method_name (str): The name of the actor method. - method (instancemethod): The actor method to wrap. This should be a - method defined on the actor class and should therefore take an - instance of the actor as the first argument. - actor_imported (bool): Whether the actor has been imported. - Checkpointing operations will not be run if this is set to False. + worker: The worker to use to get the checkpoint. + actor_id: The actor ID of the actor to get the checkpoint for. Returns: - A function that executes the given actor method on the worker's stored - instance of the actor. 
The function also updates the worker's - internal state to record the executed method. + If a checkpoint exists, this returns a tuple of the number of tasks + included in the checkpoint, the saved checkpoint state, and the + task frontier at the time of the checkpoint. If no checkpoint + exists, all objects are set to None. The checkpoint index is the + number of tasks executed on the actor before the checkpoint was made. """ - - def actor_method_executor(dummy_return_id, actor, *args): - # Update the actor's task counter to reflect the task we're about to - # execute. - worker.actor_task_counter += 1 - - # If this is the first task to execute on the actor, try to resume from - # a checkpoint. - if actor_imported and worker.actor_task_counter == 1: - checkpoint_resumed = restore_and_log_checkpoint(worker, actor) - if checkpoint_resumed: - # NOTE(swang): Since we did not actually execute the __init__ - # method, this will put None as the return value. If the - # __init__ method is supposed to return multiple values, an - # exception will be logged. - return - - # Determine whether we should checkpoint the actor. - checkpointing_on = (actor_imported - and worker.actor_checkpoint_interval > 0) - # We should checkpoint the actor if user checkpointing is on, we've - # executed checkpoint_interval tasks since the last checkpoint, and the - # method we're about to execute is not a checkpoint. - save_checkpoint = ( - checkpointing_on and - (worker.actor_task_counter % worker.actor_checkpoint_interval == 0 - and method_name != "__ray_checkpoint__")) - - # Execute the assigned method and save a checkpoint if necessary. - try: - if is_classmethod(method): - method_returns = method(*args) - else: - method_returns = method(actor, *args) - except Exception: - # Save the checkpoint before allowing the method exception to be - # thrown. - if save_checkpoint: - save_and_log_checkpoint(worker, actor) - raise - else: - # Save the checkpoint before returning the method's return values. 
- if save_checkpoint: - save_and_log_checkpoint(worker, actor) - return method_returns - - return actor_method_executor - - -def fetch_and_register_actor(actor_class_key, worker): - """Import an actor. - - This will be called by the worker's import thread when the worker receives - the actor_class export, assuming that the worker is an actor for that - class. - - Args: - actor_class_key: The key in Redis to use to fetch the actor. - worker: The worker to use. - """ - actor_id_str = worker.actor_id - (driver_id, class_id, class_name, module, pickled_class, - checkpoint_interval, actor_method_names) = worker.redis_client.hmget( - actor_class_key, [ - "driver_id", "class_id", "class_name", "module", "class", - "checkpoint_interval", "actor_method_names" - ]) - - class_name = decode(class_name) - module = decode(module) - checkpoint_interval = int(checkpoint_interval) - actor_method_names = json.loads(decode(actor_method_names)) - - # Create a temporary actor with some temporary methods so that if the actor - # fails to be unpickled, the temporary actor can be used (just to produce - # error messages and to prevent the driver from hanging). - class TemporaryActor(object): - pass - - worker.actors[actor_id_str] = TemporaryActor() - worker.actor_checkpoint_interval = checkpoint_interval - - def temporary_actor_method(*xs): - raise Exception("The actor with name {} failed to be imported, and so " - "cannot execute this method".format(class_name)) - - # Register the actor method executors. 
- for actor_method_name in actor_method_names: - function_id = compute_actor_method_function_id(class_name, - actor_method_name).id() - temporary_executor = make_actor_method_executor( - worker, - actor_method_name, - temporary_actor_method, - actor_imported=False) - worker.function_execution_info[driver_id][function_id] = ( - ray.worker.FunctionExecutionInfo( - function=temporary_executor, - function_name=actor_method_name, - max_calls=0)) - worker.num_task_executions[driver_id][function_id] = 0 - - try: - unpickled_class = pickle.loads(pickled_class) - worker.actor_class = unpickled_class - except Exception: - # If an exception was thrown when the actor was imported, we record the - # traceback and notify the scheduler of the failure. - traceback_str = ray.utils.format_error_message(traceback.format_exc()) - # Log the error message. - push_error_to_driver( - worker, - ray_constants.REGISTER_ACTOR_PUSH_ERROR, - traceback_str, - driver_id, - data={"actor_id": actor_id_str}) - # TODO(rkn): In the future, it might make sense to have the worker exit - # here. However, currently that would lead to hanging if someone calls - # ray.get on a method invoked on the actor. - else: - # TODO(pcm): Why is the below line necessary? 
- unpickled_class.__module__ = module - worker.actors[actor_id_str] = unpickled_class.__new__(unpickled_class) - - def pred(x): - return (inspect.isfunction(x) or inspect.ismethod(x) - or is_cython(x)) - - actor_methods = inspect.getmembers(unpickled_class, predicate=pred) - for actor_method_name, actor_method in actor_methods: - function_id = compute_actor_method_function_id( - class_name, actor_method_name).id() - executor = make_actor_method_executor( - worker, actor_method_name, actor_method, actor_imported=True) - worker.function_execution_info[driver_id][function_id] = ( - ray.worker.FunctionExecutionInfo( - function=executor, - function_name=actor_method_name, - max_calls=0)) - # We do not set worker.function_properties[driver_id][function_id] - # because we currently do need the actor worker to submit new tasks - # for the actor. - - -def publish_actor_class_to_key(key, actor_class_info, worker): - """Push an actor class definition to Redis. - - The is factored out as a separate function because it is also called - on cached actor class definitions when a worker connects for the first - time. - - Args: - key: The key to store the actor class info at. - actor_class_info: Information about the actor class. - worker: The worker to use to connect to Redis. - """ - # We set the driver ID here because it may not have been available when the - # actor class was defined. 
- actor_class_info["driver_id"] = worker.task_driver_id.id() - worker.redis_client.hmset(key, actor_class_info) - worker.redis_client.rpush("Exports", key) - - -def export_actor_class(class_id, Class, actor_method_names, - checkpoint_interval, worker): - key = b"ActorClass:" + class_id - actor_class_info = { - "class_name": Class.__name__, - "module": Class.__module__, - "class": pickle.dumps(Class), - "checkpoint_interval": checkpoint_interval, - "actor_method_names": json.dumps(list(actor_method_names)) - } - - check_oversized_pickle(actor_class_info["class"], - actor_class_info["class_name"], "actor", worker) - - if worker.mode is None: - # This means that 'ray.init()' has not been called yet and so we must - # cache the actor class definition and export it when 'ray.init()' is - # called. - assert worker.cached_remote_functions_and_actors is not None - worker.cached_remote_functions_and_actors.append( - ("actor", (key, actor_class_info))) - # This caching code path is currently not used because we only export - # actor class definitions lazily when we instantiate the actor for the - # first time. - assert False, "This should be unreachable." - else: - publish_actor_class_to_key(key, actor_class_info, worker) - # TODO(rkn): Currently we allow actor classes to be defined within tasks. - # I tried to disable this, but it may be necessary because of - # https://github.com/ray-project/ray/issues/1146. + actor_key = b"Actor:" + actor_id + checkpoint_index, checkpoint, frontier = worker.redis_client.hmget( + actor_key, ["checkpoint_index", "checkpoint", "frontier"]) + if checkpoint_index is not None: + checkpoint_index = int(checkpoint_index) + return checkpoint_index, checkpoint, frontier def method(*args, **kwargs): @@ -518,13 +273,8 @@ class ActorClass(object): self._actor_method_cpus = actor_method_cpus self._exported = False - # Get the actor methods of the given class. 
- def pred(x): - return (inspect.isfunction(x) or inspect.ismethod(x) - or is_cython(x)) - self._actor_methods = inspect.getmembers( - self._modified_class, predicate=pred) + self._modified_class, ray.utils.is_function_or_method) # Extract the signatures of each of the methods. This will be used # to catch some errors if the methods are called with inappropriate # arguments. @@ -537,7 +287,7 @@ class ActorClass(object): # don't support, there may not be much the user can do about it. signature.check_signature_supported(method, warn=True) self._method_signatures[method_name] = signature.extract_signature( - method, ignore_first=not is_classmethod(method)) + method, ignore_first=not ray.utils.is_class_method(method)) # Set the default number of return values for this method. if hasattr(method, "__ray_num_return_vals__"): @@ -614,9 +364,9 @@ class ActorClass(object): else: # Export the actor. if not self._exported: - export_actor_class(self._class_id, self._modified_class, - self._actor_method_names, - self._checkpoint_interval, worker) + worker.function_actor_manager.export_actor_class( + self._class_id, self._modified_class, + self._actor_method_names, self._checkpoint_interval) self._exported = True resources = ray.utils.resources_from_resource_arguments( @@ -801,8 +551,8 @@ class ActorHandle(object): else: actor_handle_id = self._ray_actor_handle_id - function_id = compute_actor_method_function_id(self._ray_class_name, - method_name) + function_id = FunctionActorManager.compute_actor_method_function_id( + self._ray_class_name, method_name) object_ids = worker.submit_task( function_id, args, @@ -1068,5 +818,4 @@ def make_actor(cls, num_cpus, num_gpus, resources, actor_method_cpus, resources, actor_method_cpus) -ray.worker.global_worker.fetch_and_register_actor = fetch_and_register_actor ray.worker.global_worker.make_actor = make_actor diff --git a/python/ray/function_manager.py b/python/ray/function_manager.py new file mode 100644 index 000000000..0e123bd67 --- 
/dev/null +++ b/python/ray/function_manager.py @@ -0,0 +1,486 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import hashlib +import inspect +import json +import time +import traceback +from collections import ( + namedtuple, + defaultdict, +) + +import ray +from ray import profiling +from ray import ray_constants +from ray import cloudpickle as pickle +from ray.utils import ( + is_cython, + is_function_or_method, + is_class_method, + check_oversized_pickle, + decode, + format_error_message, + push_error_to_driver, +) + +FunctionExecutionInfo = namedtuple("FunctionExecutionInfo", + ["function", "function_name", "max_calls"]) +"""FunctionExecutionInfo: A named tuple storing remote function information.""" + + +class FunctionActorManager(object): + """A class used to export/load remote functions and actors. + + Attributes: + _worker: The associated worker that this manager is related to. + _functions_to_export: The remote functions to export when + the worker gets connected. + _actors_to_export: The actors to export when the worker gets + connected. + _function_execution_info: The map from driver_id to function_id + and execution_info. + _num_task_executions: The map from driver_id to function + execution times. + """ + + def __init__(self, worker): + self._worker = worker + self._functions_to_export = [] + self._actors_to_export = [] + # This field is a dictionary that maps a driver ID to a dictionary of + # functions (and information about those functions) that have been + # registered for that driver (this inner dictionary maps function IDs + # to a FunctionExecutionInfo object). This should only be used on + # workers that execute remote functions. 
+ self._function_execution_info = defaultdict(lambda: {}) + self._num_task_executions = defaultdict(lambda: {}) + + def increase_task_counter(self, driver_id, function_id): + self._num_task_executions[driver_id][function_id] += 1 + + def get_task_counter(self, driver_id, function_id): + return self._num_task_executions[driver_id][function_id] + + def export_cached(self): + """Export cached remote functions + + Note: this should be called only once when worker is connected. + """ + for remote_function in self._functions_to_export: + self._do_export(remote_function) + self._functions_to_export = None + for info in self._actors_to_export: + (key, actor_class_info) = info + self._publish_actor_class_to_key(key, actor_class_info) + + def reset_cache(self): + self._functions_to_export = [] + self._actors_to_export = [] + + def export(self, remote_function): + """Export a remote function. + + Args: + remote_function: the RemoteFunction object. + """ + if self._worker.mode is None: + # If the worker isn't connected, cache the function + # and export it later. + self._functions_to_export.append(remote_function) + return + if self._worker.mode != ray.worker.SCRIPT_MODE: + # Don't need to export if the worker is not a driver. + return + self._do_export(remote_function) + + def _do_export(self, remote_function): + """Pickle a remote function and export it to redis. + + Args: + remote_function: the RemoteFunction object. + """ + # Work around limitations of Python pickling. 
+ function = remote_function._function + function_name_global_valid = function.__name__ in function.__globals__ + function_name_global_value = function.__globals__.get( + function.__name__) + # Allow the function to reference itself as a global variable + if not is_cython(function): + function.__globals__[function.__name__] = remote_function + try: + pickled_function = pickle.dumps(function) + finally: + # Undo our changes + if function_name_global_valid: + function.__globals__[function.__name__] = ( + function_name_global_value) + else: + del function.__globals__[function.__name__] + + check_oversized_pickle(pickled_function, + remote_function._function_name, + "remote function", self._worker) + + key = (b"RemoteFunction:" + self._worker.task_driver_id.id() + b":" + + remote_function._function_id) + self._worker.redis_client.hmset( + key, { + "driver_id": self._worker.task_driver_id.id(), + "function_id": remote_function._function_id, + "name": remote_function._function_name, + "module": function.__module__, + "function": pickled_function, + "max_calls": remote_function._max_calls + }) + self._worker.redis_client.rpush("Exports", key) + + def fetch_and_register_remote_function(self, key): + """Import a remote function.""" + (driver_id, function_id_str, function_name, serialized_function, + num_return_vals, module, resources, + max_calls) = self._worker.redis_client.hmget(key, [ + "driver_id", "function_id", "name", "function", "num_return_vals", + "module", "resources", "max_calls" + ]) + function_id = ray.ObjectID(function_id_str) + function_name = decode(function_name) + max_calls = int(max_calls) + module = decode(module) + + # This is a placeholder in case the function can't be unpickled. This + # will be overwritten if the function is successfully registered. 
+ def f(): + raise Exception("This function was not imported properly.") + + self._function_execution_info[driver_id][function_id.id()] = ( + FunctionExecutionInfo( + function=f, function_name=function_name, max_calls=max_calls)) + self._num_task_executions[driver_id][function_id.id()] = 0 + + try: + function = pickle.loads(serialized_function) + except Exception as e: + # If an exception was thrown when the remote function was imported, + # we record the traceback and notify the scheduler of the failure. + traceback_str = format_error_message(traceback.format_exc()) + # Log the error message. + push_error_to_driver( + self._worker, + ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR, + traceback_str, + driver_id=driver_id, + data={ + "function_id": function_id.id(), + "function_name": function_name + }) + else: + # The below line is necessary. Because in the driver process, + # if the function is defined in the file where the python script + # was started from, its module is `__main__`. + # However in the worker process, the `__main__` module is a + # different module, which is `default_worker.py` + function.__module__ = module + self._function_execution_info[driver_id][function_id.id()] = ( + FunctionExecutionInfo( + function=function, + function_name=function_name, + max_calls=max_calls)) + # Add the function to the function table. + self._worker.redis_client.rpush( + b"FunctionTable:" + function_id.id(), self._worker.worker_id) + + def get_execution_info(self, driver_id, function_id): + """Get the FunctionExecutionInfo of a remote function. + + Args: + driver_id: ID of the driver that the function belongs to. + function_id: ID of the function to get. + + Returns: + A FunctionExecutionInfo object. + """ + # Wait until the function to be executed has actually been registered + # on this worker. We will push warnings to the user if we spend too + # long in this loop. 
+ with profiling.profile("wait_for_function", worker=self._worker): + self._wait_for_function(function_id, driver_id) + return self._function_execution_info[driver_id][function_id.id()] + + def _wait_for_function(self, function_id, driver_id, timeout=10): + """Wait until the function to be executed is present on this worker. + + This method will simply loop until the import thread has imported the + relevant function. If we spend too long in this loop, that may indicate + a problem somewhere and we will push an error message to the user. + + If this worker is an actor, then this will wait until the actor has + been defined. + + Args: + function_id (str): The ID of the function that we want to execute. + driver_id (str): The ID of the driver to push the error message to + if this times out. + """ + start_time = time.time() + # Only send the warning once. + warning_sent = False + while True: + with self._worker.lock: + if (self._worker.actor_id == ray.worker.NIL_ACTOR_ID + and (function_id.id() in + self._function_execution_info[driver_id])): + break + elif self._worker.actor_id != ray.worker.NIL_ACTOR_ID and ( + self._worker.actor_id in self._worker.actors): + break + if time.time() - start_time > timeout: + warning_message = ("This worker was asked to execute a " + "function that it does not have " + "registered. You may have to restart " + "Ray.") + if not warning_sent: + ray.utils.push_error_to_driver( + self._worker, + ray_constants.WAIT_FOR_FUNCTION_PUSH_ERROR, + warning_message, + driver_id=driver_id) + warning_sent = True + time.sleep(0.001) + + @classmethod + def compute_actor_method_function_id(cls, class_name, attr): + """Get the function ID corresponding to an actor method. + + Args: + class_name (str): The class name of the actor. + attr (str): The attribute name of the method. + + Returns: + Function ID corresponding to the method. 
+ """ + function_id_hash = hashlib.sha1() + function_id_hash.update(class_name.encode("ascii")) + function_id_hash.update(attr.encode("ascii")) + function_id = function_id_hash.digest() + assert len(function_id) == ray_constants.ID_SIZE + return ray.ObjectID(function_id) + + def _publish_actor_class_to_key(self, key, actor_class_info): + """Push an actor class definition to Redis. + + This is factored out as a separate function because it is also called + on cached actor class definitions when a worker connects for the first + time. + + Args: + key: The key to store the actor class info at. + actor_class_info: Information about the actor class. + worker: The worker to use to connect to Redis. + """ + # We set the driver ID here because it may not have been available when + # the actor class was defined. + actor_class_info["driver_id"] = self._worker.task_driver_id.id() + self._worker.redis_client.hmset(key, actor_class_info) + self._worker.redis_client.rpush("Exports", key) + + def export_actor_class(self, class_id, Class, actor_method_names, + checkpoint_interval): + key = b"ActorClass:" + class_id + actor_class_info = { + "class_name": Class.__name__, + "module": Class.__module__, + "class": pickle.dumps(Class), + "checkpoint_interval": checkpoint_interval, + "actor_method_names": json.dumps(list(actor_method_names)) + } + + check_oversized_pickle(actor_class_info["class"], + actor_class_info["class_name"], "actor", + self._worker) + + if self._worker.mode is None: + # This means that 'ray.init()' has not been called yet and so we + # must cache the actor class definition and export it when + # 'ray.init()' is called. + assert self._actors_to_export is not None + self._actors_to_export.append((key, actor_class_info)) + # This caching code path is currently not used because we only + # export actor class definitions lazily when we instantiate the + # actor for the first time. + assert False, "This should be unreachable." 
+ else: + self._publish_actor_class_to_key(key, actor_class_info) + # TODO(rkn): Currently we allow actor classes to be defined + # within tasks. I tried to disable this, but it may be necessary + # because of https://github.com/ray-project/ray/issues/1146. + + def fetch_and_register_actor(self, actor_class_key): + """Import an actor. + + This will be called by the worker's import thread when the worker + receives the actor_class export, assuming that the worker is an actor + for that class. + + Args: + actor_class_key: The key in Redis to use to fetch the actor. + worker: The worker to use. + """ + actor_id_str = self._worker.actor_id + (driver_id, class_id, class_name, module, pickled_class, + checkpoint_interval, + actor_method_names) = self._worker.redis_client.hmget( + actor_class_key, [ + "driver_id", "class_id", "class_name", "module", "class", + "checkpoint_interval", "actor_method_names" + ]) + + class_name = decode(class_name) + module = decode(module) + checkpoint_interval = int(checkpoint_interval) + actor_method_names = json.loads(decode(actor_method_names)) + + # Create a temporary actor with some temporary methods so that if + # the actor fails to be unpickled, the temporary actor can be used + # (just to produce error messages and to prevent the driver from + # hanging). + class TemporaryActor(object): + pass + + self._worker.actors[actor_id_str] = TemporaryActor() + self._worker.actor_checkpoint_interval = checkpoint_interval + + def temporary_actor_method(*xs): + raise Exception( + "The actor with name {} failed to be imported, " + "and so cannot execute this method".format(class_name)) + + # Register the actor method executors. 
+ for actor_method_name in actor_method_names: + function_id = ( + FunctionActorManager.compute_actor_method_function_id( + class_name, actor_method_name).id()) + temporary_executor = self._make_actor_method_executor( + actor_method_name, + temporary_actor_method, + actor_imported=False) + self._function_execution_info[driver_id][function_id] = ( + FunctionExecutionInfo( + function=temporary_executor, + function_name=actor_method_name, + max_calls=0)) + self._num_task_executions[driver_id][function_id] = 0 + + try: + unpickled_class = pickle.loads(pickled_class) + self._worker.actor_class = unpickled_class + except Exception: + # If an exception was thrown when the actor was imported, we record + # the traceback and notify the scheduler of the failure. + traceback_str = ray.utils.format_error_message( + traceback.format_exc()) + # Log the error message. + push_error_to_driver( + self._worker, + ray_constants.REGISTER_ACTOR_PUSH_ERROR, + traceback_str, + driver_id, + data={"actor_id": actor_id_str}) + # TODO(rkn): In the future, it might make sense to have the worker + # exit here. However, currently that would lead to hanging if + # someone calls ray.get on a method invoked on the actor. + else: + # TODO(pcm): Why is the below line necessary? 
+ unpickled_class.__module__ = module + self._worker.actors[actor_id_str] = unpickled_class.__new__( + unpickled_class) + + actor_methods = inspect.getmembers( + unpickled_class, predicate=is_function_or_method) + for actor_method_name, actor_method in actor_methods: + function_id = ( + FunctionActorManager.compute_actor_method_function_id( + class_name, actor_method_name).id()) + executor = self._make_actor_method_executor( + actor_method_name, actor_method, actor_imported=True) + self._function_execution_info[driver_id][function_id] = ( + FunctionExecutionInfo( + function=executor, + function_name=actor_method_name, + max_calls=0)) + # We do not set function_properties[driver_id][function_id] + # because we currently do need the actor worker to submit new + # tasks for the actor. + + def _make_actor_method_executor(self, method_name, method, actor_imported): + """Make an executor that wraps a user-defined actor method. + + The wrapped method updates the worker's internal state and performs any + necessary checkpointing operations. + + Args: + worker (Worker): The worker that is executing the actor. + method_name (str): The name of the actor method. + method (instancemethod): The actor method to wrap. This should be a + method defined on the actor class and should therefore take an + instance of the actor as the first argument. + actor_imported (bool): Whether the actor has been imported. + Checkpointing operations will not be run if this is set to + False. + + Returns: + A function that executes the given actor method on the worker's + stored instance of the actor. The function also updates the + worker's internal state to record the executed method. + """ + + def actor_method_executor(dummy_return_id, actor, *args): + # Update the actor's task counter to reflect the task we're about + # to execute. + self._worker.actor_task_counter += 1 + + # If this is the first task to execute on the actor, try to resume + # from a checkpoint. 
+ if actor_imported and self._worker.actor_task_counter == 1: + checkpoint_resumed = ray.actor.restore_and_log_checkpoint( + self._worker, actor) + if checkpoint_resumed: + # NOTE(swang): Since we did not actually execute the + # __init__ method, this will put None as the return value. + # If the __init__ method is supposed to return multiple + # values, an exception will be logged. + return + + # Determine whether we should checkpoint the actor. + checkpointing_on = (actor_imported + and self._worker.actor_checkpoint_interval > 0) + # We should checkpoint the actor if user checkpointing is on, we've + # executed checkpoint_interval tasks since the last checkpoint, and + # the method we're about to execute is not a checkpoint. + save_checkpoint = (checkpointing_on + and (self._worker.actor_task_counter % + self._worker.actor_checkpoint_interval == 0 + and method_name != "__ray_checkpoint__")) + + # Execute the assigned method and save a checkpoint if necessary. + try: + if is_class_method(method): + method_returns = method(*args) + else: + method_returns = method(actor, *args) + except Exception: + # Save the checkpoint before allowing the method exception + # to be thrown. + if save_checkpoint: + ray.actor.save_and_log_checkpoint(self._worker, actor) + raise + else: + # Save the checkpoint before returning the method's return + # values. + if save_checkpoint: + ray.actor.save_and_log_checkpoint(self._worker, actor) + return method_returns + + return actor_method_executor diff --git a/python/ray/import_thread.py b/python/ray/import_thread.py index 659cdf1ce..85fe0b89d 100644 --- a/python/ray/import_thread.py +++ b/python/ray/import_thread.py @@ -88,7 +88,8 @@ class ImportThread(object): if key.startswith(b"RemoteFunction"): with profiling.profile( "register_remote_function", worker=self.worker): - self.fetch_and_register_remote_function(key) + (self.worker.function_actor_manager. 
+ fetch_and_register_remote_function(key)) elif key.startswith(b"FunctionsToRun"): with profiling.profile( "fetch_and_run_function", worker=self.worker): @@ -103,58 +104,6 @@ class ImportThread(object): else: raise Exception("This code should be unreachable.") - def fetch_and_register_remote_function(self, key): - """Import a remote function.""" - from ray.worker import FunctionExecutionInfo - (driver_id, function_id_str, function_name, serialized_function, - num_return_vals, module, resources, - max_calls) = self.redis_client.hmget(key, [ - "driver_id", "function_id", "name", "function", "num_return_vals", - "module", "resources", "max_calls" - ]) - function_id = ray.ObjectID(function_id_str) - function_name = utils.decode(function_name) - max_calls = int(max_calls) - module = utils.decode(module) - - # This is a placeholder in case the function can't be unpickled. This - # will be overwritten if the function is successfully registered. - def f(): - raise Exception("This function was not imported properly.") - - self.worker.function_execution_info[driver_id][function_id.id()] = ( - FunctionExecutionInfo( - function=f, function_name=function_name, max_calls=max_calls)) - self.worker.num_task_executions[driver_id][function_id.id()] = 0 - - try: - function = pickle.loads(serialized_function) - except Exception: - # If an exception was thrown when the remote function was imported, - # we record the traceback and notify the scheduler of the failure. - traceback_str = utils.format_error_message(traceback.format_exc()) - # Log the error message. - utils.push_error_to_driver( - self.worker, - ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR, - traceback_str, - driver_id=driver_id, - data={ - "function_id": function_id.id(), - "function_name": function_name - }) - else: - # TODO(rkn): Why is the below line necessary? 
- function.__module__ = module - self.worker.function_execution_info[driver_id][ - function_id.id()] = (FunctionExecutionInfo( - function=function, - function_name=function_name, - max_calls=max_calls)) - # Add the function to the function table. - self.redis_client.rpush(b"FunctionTable:" + function_id.id(), - self.worker.worker_id) - def fetch_and_execute_function_to_run(self, key): """Run on arbitrary function on the worker.""" (driver_id, serialized_function, diff --git a/python/ray/remote_function.py b/python/ray/remote_function.py index 287d3d045..b96f5d7e7 100644 --- a/python/ray/remote_function.py +++ b/python/ray/remote_function.py @@ -22,7 +22,7 @@ def compute_function_id(function): func: The actual function. Returns: - This returns the function ID. + Raw bytes of the function id """ function_id_hash = hashlib.sha1() # Include the function module and name in the hash. @@ -39,8 +39,6 @@ def compute_function_id(function): # Compute the function ID. function_id = function_id_hash.digest() assert len(function_id) == ray_constants.ID_SIZE - function_id = ray.ObjectID(function_id) - return function_id @@ -72,7 +70,7 @@ class RemoteFunction(object): # TODO(rkn): We store the function ID as a string, so that # RemoteFunction objects can be pickled. We should undo this when # we allow ObjectIDs to be pickled. - self._function_id = compute_function_id(self._function).id() + self._function_id = compute_function_id(function) self._function_name = ( self._function.__module__ + '.' + self._function.__name__) self._num_cpus = (DEFAULT_REMOTE_FUNCTION_CPUS @@ -90,11 +88,7 @@ class RemoteFunction(object): # # Export the function. 
worker = ray.worker.get_global_worker() - if worker.mode == ray.worker.SCRIPT_MODE: - self._export() - elif worker.mode is None: - worker.cached_remote_functions_and_actors.append( - ("remote_function", self)) + worker.function_actor_manager.export(self) def __call__(self, *args, **kwargs): raise Exception("Remote functions cannot be called directly. Instead " @@ -141,9 +135,3 @@ class RemoteFunction(object): return object_ids[0] elif len(object_ids) > 1: return object_ids - - def _export(self): - worker = ray.worker.get_global_worker() - worker.export_remote_function( - ray.ObjectID(self._function_id), self._function_name, - self._function, self._max_calls, self) diff --git a/python/ray/utils.py b/python/ray/utils.py index 0f6adaea9..55f85c8ac 100644 --- a/python/ray/utils.py +++ b/python/ray/utils.py @@ -5,6 +5,7 @@ from __future__ import print_function import binascii import functools import hashlib +import inspect import numpy as np import os import subprocess @@ -144,6 +145,23 @@ def is_cython(obj): (hasattr(obj, "__func__") and check_cython(obj.__func__)) +def is_function_or_method(obj): + """Check if an object is a function or method. + + Args: + obj: The Python object in question. + + Returns: + True if the object is an function or method. + """ + return (inspect.isfunction(obj) or inspect.ismethod(obj) or is_cython(obj)) + + +def is_class_method(f): + """Returns whether the given method is a class_method.""" + return hasattr(f, "__self__") and f.__self__ is not None + + def random_string(): """Generate a random string to use as an ID. 
diff --git a/python/ray/worker.py b/python/ray/worker.py index c0714b1fc..5299c2d07 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -3,7 +3,6 @@ from __future__ import division from __future__ import print_function import atexit -import collections import colorama import hashlib import inspect @@ -33,6 +32,7 @@ import ray.plasma import ray.ray_constants as ray_constants from ray import import_thread from ray import profiling +from ray.function_manager import FunctionActorManager from ray.utils import ( binary_to_hex, check_oversized_pickle, @@ -176,11 +176,6 @@ class RayGetArgumentError(Exception): self.task_error)) -FunctionExecutionInfo = collections.namedtuple( - "FunctionExecutionInfo", ["function", "function_name", "max_calls"]) -"""FunctionExecutionInfo: A named tuple storing remote function information.""" - - class Worker(object): """A class used to define the control flow of a worker process. @@ -189,19 +184,9 @@ class Worker(object): functions outside of this class are considered exposed. Attributes: - function_execution_info (Dict[str, FunctionExecutionInfo]): A - dictionary mapping the name of a remote function to the remote - function itself. This is the set of remote functions that can be - executed by this worker. connected (bool): True if Ray has been started and False otherwise. mode: The mode of the worker. One of SCRIPT_MODE, LOCAL_MODE, and WORKER_MODE. - cached_remote_functions_and_actors: A list of information for exporting - remote functions and actor classes definitions that were defined - before the worker called connect. When the worker eventually does - call connect, if it is a driver, it will export these functions and - actors. If cached_remote_functions_and_actors is None, that means - that connect has been called already. cached_functions_to_run (List): A list of functions to run on all of the workers that should be exported as soon as connect is called. profiler: the profiler used to aggregate profiling information. 
@@ -216,24 +201,15 @@ class Worker(object): def __init__(self): """Initialize a Worker object.""" - # This field is a dictionary that maps a driver ID to a dictionary of - # functions (and information about those functions) that have been - # registered for that driver (this inner dictionary maps function IDs - # to a FunctionExecutionInfo object. This should only be used on - # workers that execute remote functions. - self.function_execution_info = collections.defaultdict(lambda: {}) # This is a dictionary mapping driver ID to a dictionary that maps # remote function IDs for that driver to a counter of the number of # times that remote function has been executed on this worker. The # counter is incremented every time the function is executed on this # worker. When the counter reaches the maximum number of executions # allowed for a particular function, the worker is killed. - self.num_task_executions = collections.defaultdict(lambda: {}) self.connected = False self.mode = None - self.cached_remote_functions_and_actors = [] self.cached_functions_to_run = [] - self.fetch_and_register_actor = None self.actor_init_error = None self.make_actor = None self.actors = {} @@ -255,6 +231,7 @@ class Worker(object): self.serialization_context_map = {} # Identity of the driver that this worker is processing. self.task_driver_id = None + self.function_actor_manager = FunctionActorManager(self) def mark_actor_init_failed(self, error): """Called to mark this actor as failed during initialization.""" @@ -674,57 +651,6 @@ class Worker(object): return task.returns() - def export_remote_function(self, function_id, function_name, function, - max_calls, decorated_function): - """Export a remote function. - - Args: - function_id: The ID of the function. - function_name: The name of the function. - function: The raw undecorated function to export. - max_calls: The maximum number of times a given worker can execute - this function before exiting. 
- decorated_function: The decorated function (this is used to enable - the remote function to recursively call itself). - """ - if self.mode != SCRIPT_MODE: - raise Exception("export_remote_function can only be called on a " - "driver.") - - key = (b"RemoteFunction:" + self.task_driver_id.id() + b":" + - function_id.id()) - - # Work around limitations of Python pickling. - function_name_global_valid = function.__name__ in function.__globals__ - function_name_global_value = function.__globals__.get( - function.__name__) - # Allow the function to reference itself as a global variable - if not is_cython(function): - function.__globals__[function.__name__] = decorated_function - try: - pickled_function = pickle.dumps(function) - finally: - # Undo our changes - if function_name_global_valid: - function.__globals__[function.__name__] = ( - function_name_global_value) - else: - del function.__globals__[function.__name__] - - check_oversized_pickle(pickled_function, function_name, - "remote function", self) - - self.redis_client.hmset( - key, { - "driver_id": self.task_driver_id.id(), - "function_id": function_id.id(), - "name": function_name, - "module": function.__module__, - "function": pickled_function, - "max_calls": max_calls - }) - self.redis_client.rpush("Exports", key) - def run_function_on_all_workers(self, function, run_on_other_drivers=False): """Run arbitrary code on all of the workers. @@ -783,47 +709,6 @@ class Worker(object): # operations into a transaction (or by implementing a custom # command that does all three things). - def _wait_for_function(self, function_id, driver_id, timeout=10): - """Wait until the function to be executed is present on this worker. - - This method will simply loop until the import thread has imported the - relevant function. If we spend too long in this loop, that may indicate - a problem somewhere and we will push an error message to the user. 
- - If this worker is an actor, then this will wait until the actor has - been defined. - - Args: - function_id (str): The ID of the function that we want to execute. - driver_id (str): The ID of the driver to push the error message to - if this times out. - """ - start_time = time.time() - # Only send the warning once. - warning_sent = False - while True: - with self.lock: - if (self.actor_id == NIL_ACTOR_ID - and (function_id.id() in - self.function_execution_info[driver_id])): - break - elif self.actor_id != NIL_ACTOR_ID and ( - self.actor_id in self.actors): - break - if time.time() - start_time > timeout: - warning_message = ("This worker was asked to execute a " - "function that it does not have " - "registered. You may have to restart " - "Ray.") - if not warning_sent: - ray.utils.push_error_to_driver( - self, - ray_constants.WAIT_FOR_FUNCTION_PUSH_ERROR, - warning_message, - driver_id=driver_id) - warning_sent = True - time.sleep(0.001) - def _get_arguments_for_execution(self, function_name, serialized_args): """Retrieve the arguments for the remote function. @@ -891,7 +776,7 @@ class Worker(object): self.put_object(object_ids[i], outputs[i]) - def _process_task(self, task): + def _process_task(self, task, function_execution_info): """Execute a task assigned to this worker. This method deserializes a task from the scheduler, and attempts to @@ -913,10 +798,8 @@ class Worker(object): return_object_ids = task.returns() if task.actor_id().id() != NIL_ACTOR_ID: dummy_return_id = return_object_ids.pop() - function_executor = self.function_execution_info[ - self.task_driver_id.id()][function_id.id()].function - function_name = self.function_execution_info[self.task_driver_id.id()][ - function_id.id()].function_name + function_executor = function_execution_info.function + function_name = function_execution_info.function_name # Get task arguments from the object store. 
try: @@ -926,12 +809,12 @@ class Worker(object): arguments = self._get_arguments_for_execution( function_name, args) except (RayGetError, RayGetArgumentError) as e: - self._handle_process_task_failure(function_id, return_object_ids, - e, None) + self._handle_process_task_failure(function_id, function_name, + return_object_ids, e, None) return except Exception as e: self._handle_process_task_failure( - function_id, return_object_ids, e, + function_id, function_name, return_object_ids, e, ray.utils.format_error_message(traceback.format_exc())) return @@ -950,8 +833,9 @@ class Worker(object): task_exception = task.actor_id().id() == NIL_ACTOR_ID traceback_str = ray.utils.format_error_message( traceback.format_exc(), task_exception=task_exception) - self._handle_process_task_failure(function_id, return_object_ids, - e, traceback_str) + self._handle_process_task_failure(function_id, function_name, + return_object_ids, e, + traceback_str) return # Store the outputs in the local object store. @@ -966,13 +850,11 @@ class Worker(object): self._store_outputs_in_objstore(return_object_ids, outputs) except Exception as e: self._handle_process_task_failure( - function_id, return_object_ids, e, + function_id, function_name, return_object_ids, e, ray.utils.format_error_message(traceback.format_exc())) - def _handle_process_task_failure(self, function_id, return_object_ids, - error, backtrace): - function_name = self.function_execution_info[self.task_driver_id.id()][ - function_id.id()].function_name + def _handle_process_task_failure(self, function_id, function_name, + return_object_ids, error, backtrace): failure_object = RayTaskError(function_name, error, backtrace) failure_objects = [ failure_object for _ in range(len(return_object_ids)) @@ -1014,7 +896,7 @@ class Worker(object): time.sleep(0.001) with self.lock: - self.fetch_and_register_actor(key, self) + self.function_actor_manager.fetch_and_register_actor(key) def _wait_for_and_process_task(self, task): """Wait for a task 
to be ready and process the task. @@ -1031,11 +913,8 @@ class Worker(object): self._become_actor(task) return - # Wait until the function to be executed has actually been registered - # on this worker. We will push warnings to the user if we spend too - # long in this loop. - with profiling.profile("wait_for_function", worker=self): - self._wait_for_function(function_id, driver_id) + execution_info = self.function_actor_manager.get_execution_info( + driver_id, function_id) # Execute the task. # TODO(rkn): Consider acquiring this lock with a timeout and pushing a @@ -1043,9 +922,7 @@ class Worker(object): # because that may indicate that the system is hanging, and it'd be # good to know where the system is hanging. with self.lock: - - function_name = (self.function_execution_info[driver_id][ - function_id.id()]).function_name + function_name = execution_info.function_name if not self.use_raylet: extra_data = { "function_name": function_name, @@ -1058,7 +935,7 @@ class Worker(object): "task_id": task.task_id().hex() } with profiling.profile("task", extra_data=extra_data, worker=self): - self._process_task(task) + self._process_task(task, execution_info) # In the non-raylet code path, push all of the log events to the global # state store. In the raylet code path, this is done periodically in a @@ -1067,11 +944,11 @@ class Worker(object): self.profiler.flush_profile_data() # Increase the task execution counter. - self.num_task_executions[driver_id][function_id.id()] += 1 + self.function_actor_manager.increase_task_counter( + driver_id, function_id.id()) - reached_max_executions = ( - self.num_task_executions[driver_id][function_id.id()] == self. 
- function_execution_info[driver_id][function_id.id()].max_calls) + reached_max_executions = (self.function_actor_manager.get_task_counter( + driver_id, function_id.id()) == execution_info.max_calls) if reached_max_executions: self.local_scheduler_client.disconnect() os._exit(0) @@ -2112,7 +1989,6 @@ def connect(info, error_message = "Perhaps you called ray.init twice by accident?" assert not worker.connected, error_message assert worker.cached_functions_to_run is not None, error_message - assert worker.cached_remote_functions_and_actors is not None, error_message # Initialize some fields. worker.worker_id = random_string() @@ -2350,18 +2226,9 @@ def connect(info, # Export cached functions_to_run. for function in worker.cached_functions_to_run: worker.run_function_on_all_workers(function) - # Export cached remote functions to the workers. - for cached_type, info in worker.cached_remote_functions_and_actors: - if cached_type == "remote_function": - info._export() - elif cached_type == "actor": - (key, actor_class_info) = info - ray.actor.publish_actor_class_to_key(key, actor_class_info, - worker) - else: - assert False, "This code should be unreachable." + # Export cached remote functions and actors to the workers. + worker.function_actor_manager.export_cached() worker.cached_functions_to_run = None - worker.cached_remote_functions_and_actors = None def disconnect(worker=global_worker): @@ -2372,7 +2239,7 @@ def disconnect(worker=global_worker): # tests. worker.connected = False worker.cached_functions_to_run = [] - worker.cached_remote_functions_and_actors = [] + worker.function_actor_manager.reset_cache() worker.serialization_context_map.clear()