Enable function_descriptor in backend to replace the function_id (#3028)

This commit is contained in:
Yuhong Guo
2018-12-19 07:53:59 +08:00
committed by Robert Nishihara
parent 3822b20319
commit fb33fa9097
20 changed files with 557 additions and 282 deletions
+46 -41
View File
@@ -10,7 +10,7 @@ import sys
import traceback
import ray.cloudpickle as pickle
from ray.function_manager import FunctionActorManager
from ray.function_manager import FunctionDescriptor
import ray.raylet
import ray.ray_constants as ray_constants
import ray.signature as signature
@@ -76,18 +76,6 @@ def compute_actor_handle_id_non_forked(actor_id, actor_handle_id,
return ray.ObjectID(handle_id)
def compute_actor_creation_function_id(class_id):
"""Compute the function ID for an actor creation task.
Args:
class_id: The ID of the actor class.
Returns:
The function ID of the actor creation event.
"""
return ray.ObjectID(class_id)
def set_actor_checkpoint(worker, actor_id, checkpoint_index, checkpoint,
frontier):
"""Set the most recent checkpoint associated with a given actor ID.
@@ -287,6 +275,23 @@ class ActorClass(object):
self._actor_methods = inspect.getmembers(
self._modified_class, ray.utils.is_function_or_method)
self._actor_method_names = [
method_name for method_name, _ in self._actor_methods
]
constructor_name = "__init__"
if constructor_name not in self._actor_method_names:
# Add __init__ if it does not exist.
# Actor creation will be executed with __init__ together.
# Assign an __init__ function will avoid many checks later on.
def __init__(self):
pass
self._modified_class.__init__ = __init__
self._actor_method_names.append(constructor_name)
self._actor_methods.append((constructor_name, __init__))
# Extract the signatures of each of the methods. This will be used
# to catch some errors if the methods are called with inappropriate
# arguments.
@@ -300,7 +305,6 @@ class ActorClass(object):
signature.check_signature_supported(method, warn=True)
self._method_signatures[method_name] = signature.extract_signature(
method, ignore_first=not ray.utils.is_class_method(method))
# Set the default number of return values for this method.
if hasattr(method, "__ray_num_return_vals__"):
self._actor_method_num_return_vals[method_name] = (
@@ -309,10 +313,6 @@ class ActorClass(object):
self._actor_method_num_return_vals[method_name] = (
DEFAULT_ACTOR_METHOD_NUM_RETURN_VALS)
self._actor_method_names = [
method_name for method_name, _ in self._actor_methods
]
def __call__(self, *args, **kwargs):
raise Exception("Actors methods cannot be instantiated directly. "
"Instead of running '{}()', try '{}.remote()'.".format(
@@ -386,14 +386,14 @@ class ActorClass(object):
# Instead, instantiate the actor locally and add it to the worker's
# dictionary
if worker.mode == ray.LOCAL_MODE:
worker.actors[actor_id] = self._modified_class.__new__(
self._modified_class)
worker.actors[actor_id] = self._modified_class(
*copy.deepcopy(args), **copy.deepcopy(kwargs))
else:
# Export the actor.
if not self._exported:
worker.function_actor_manager.export_actor_class(
self._class_id, self._modified_class,
self._actor_method_names, self._checkpoint_interval)
self._modified_class, self._actor_method_names,
self._checkpoint_interval)
self._exported = True
resources = ray.utils.resources_from_resource_arguments(
@@ -409,10 +409,19 @@ class ActorClass(object):
actor_placement_resources = resources.copy()
actor_placement_resources["CPU"] += 1
creation_args = [self._class_id]
function_id = compute_actor_creation_function_id(self._class_id)
if args is None:
args = []
if kwargs is None:
kwargs = {}
function_name = "__init__"
function_signature = self._method_signatures[function_name]
creation_args = signature.extend_args(function_signature, args,
kwargs)
function_descriptor = FunctionDescriptor(
self._modified_class.__module__, function_name,
self._modified_class.__name__)
[actor_cursor] = worker.submit_task(
function_id,
function_descriptor,
creation_args,
actor_creation_id=actor_id,
max_actor_reconstructions=self._max_reconstructions,
@@ -424,19 +433,10 @@ class ActorClass(object):
# creation task.
actor_counter = 1
actor_handle = ActorHandle(
actor_id, self._class_name, actor_cursor, actor_counter,
self._actor_method_names, self._method_signatures,
self._actor_method_num_return_vals, actor_cursor,
self._actor_method_cpus, worker.task_driver_id)
# Call __init__ as a remote function.
if "__init__" in actor_handle._ray_actor_method_names:
actor_handle.__init__.remote(*args, **kwargs)
else:
if len(args) != 0 or len(kwargs) != 0:
raise Exception("Arguments cannot be passed to the actor "
"constructor because this actor class has no "
"__init__ method.")
actor_id, self._modified_class.__module__, self._class_name,
actor_cursor, actor_counter, self._actor_method_names,
self._method_signatures, self._actor_method_num_return_vals,
actor_cursor, self._actor_method_cpus, worker.task_driver_id)
return actor_handle
@@ -458,6 +458,7 @@ class ActorHandle(object):
Attributes:
_ray_actor_id: The ID of the corresponding actor.
_ray_module_name: The module name of this actor.
_ray_actor_handle_id: The ID of this handle. If this is the "original"
handle for an actor (as opposed to one created by passing another
handle into a task), then this ID must be NIL_ID. If this
@@ -496,6 +497,7 @@ class ActorHandle(object):
def __init__(self,
actor_id,
module_name,
class_name,
actor_cursor,
actor_counter,
@@ -510,6 +512,7 @@ class ActorHandle(object):
# False if this actor handle was created by forking or pickling. True
# if it was created by the _serialization_helper function.
self._ray_original_handle = previous_actor_handle_id is None
self._ray_module_name = module_name
self._ray_actor_id = actor_id
if self._ray_original_handle:
@@ -602,10 +605,10 @@ class ActorHandle(object):
else:
actor_handle_id = self._ray_actor_handle_id
function_id = FunctionActorManager.compute_actor_method_function_id(
self._ray_class_name, method_name)
function_descriptor = FunctionDescriptor(
self._ray_module_name, method_name, self._ray_class_name)
object_ids = worker.submit_task(
function_id,
function_descriptor,
args,
actor_id=self._ray_actor_id,
actor_handle_id=actor_handle_id,
@@ -706,6 +709,7 @@ class ActorHandle(object):
"""
state = {
"actor_id": self._ray_actor_id.id(),
"module_name": self._ray_module_name,
"class_name": self._ray_class_name,
"actor_forks": self._ray_actor_forks,
"actor_cursor": self._ray_actor_cursor.id()
@@ -753,6 +757,7 @@ class ActorHandle(object):
self.__init__(
ray.ObjectID(state["actor_id"]),
state["module_name"],
state["class_name"],
ray.ObjectID(state["actor_cursor"])
if state["actor_cursor"] is not None else None,
+10 -2
View File
@@ -9,6 +9,7 @@ import sys
import time
import ray
from ray.function_manager import FunctionDescriptor
import ray.gcs_utils
import ray.ray_constants as ray_constants
from ray.utils import (decode, binary_to_object_id, binary_to_hex,
@@ -234,6 +235,9 @@ class GlobalState(object):
execution_spec = task_table_message.TaskExecutionSpec()
task_spec = task_table_message.TaskSpecification()
task_spec = ray.raylet.task_from_string(task_spec)
function_descriptor_list = task_spec.function_descriptor_list()
function_descriptor = FunctionDescriptor.from_bytes_list(
function_descriptor_list)
task_spec_info = {
"DriverID": binary_to_hex(task_spec.driver_id().id()),
"TaskID": binary_to_hex(task_spec.task_id().id()),
@@ -245,10 +249,14 @@ class GlobalState(object):
"ActorCreationDummyObjectID": binary_to_hex(
task_spec.actor_creation_dummy_object_id().id()),
"ActorCounter": task_spec.actor_counter(),
"FunctionID": binary_to_hex(task_spec.function_id().id()),
"Args": task_spec.arguments(),
"ReturnObjectIDs": task_spec.returns(),
"RequiredResources": task_spec.required_resources()
"RequiredResources": task_spec.required_resources(),
"FunctionID": binary_to_hex(function_descriptor.function_id.id()),
"FunctionHash": binary_to_hex(function_descriptor.function_hash),
"ModuleName": function_descriptor.module_name,
"ClassName": function_descriptor.class_name,
"FunctionName": function_descriptor.function_name,
}
return {
+278 -46
View File
@@ -5,6 +5,7 @@ from __future__ import print_function
import hashlib
import inspect
import json
import logging
import sys
import time
import traceback
@@ -18,6 +19,7 @@ from ray import profiling
from ray import ray_constants
from ray import cloudpickle as pickle
from ray.utils import (
binary_to_hex,
is_cython,
is_function_or_method,
is_class_method,
@@ -31,6 +33,228 @@ FunctionExecutionInfo = namedtuple("FunctionExecutionInfo",
["function", "function_name", "max_calls"])
"""FunctionExecutionInfo: A named tuple storing remote function information."""
logger = logging.getLogger(__name__)
class FunctionDescriptor(object):
"""A class used to describe a python function.
Attributes:
module_name: the module name that the function belongs to.
class_name: the class name that the function belongs to if exists.
It could be empty is the function is not a class method.
function_name: the function name of the function.
function_hash: the hash code of the function source code if the
function code is available.
function_id: the function id calculated from this descriptor.
is_for_driver_task: whether this descriptor is for driver task.
"""
def __init__(self,
module_name,
function_name,
class_name="",
function_source_hash=b""):
self._module_name = module_name
self._class_name = class_name
self._function_name = function_name
self._function_source_hash = function_source_hash
self._function_id = self._get_function_id()
def __repr__(self):
return ("FunctionDescriptor:" + self._module_name + "." +
self._class_name + "." + self._function_name + "." +
binary_to_hex(self._function_source_hash))
@classmethod
def from_bytes_list(cls, function_descriptor_list):
"""Create a FunctionDescriptor instance from list of bytes.
This function is used to create the function descriptor from
backend data.
Args:
cls: Current class which is required argument for classmethod.
function_descriptor_list: list of bytes to represent the
function descriptor.
Returns:
The FunctionDescriptor instance created from the bytes list.
"""
assert isinstance(function_descriptor_list, list)
if len(function_descriptor_list) == 0:
# This is a function descriptor of driver task.
return FunctionDescriptor.for_driver_task()
elif (len(function_descriptor_list) == 3
or len(function_descriptor_list) == 4):
module_name = function_descriptor_list[0].decode()
class_name = function_descriptor_list[1].decode()
function_name = function_descriptor_list[2].decode()
if len(function_descriptor_list) == 4:
return cls(module_name, function_name, class_name,
function_descriptor_list[3])
else:
return cls(module_name, function_name, class_name)
else:
raise Exception(
"Invalid input for FunctionDescriptor.from_bytes_list")
@classmethod
def from_function(cls, function):
"""Create a FunctionDescriptor from a function instance.
This function is used to create the function descriptor from
a python function. If a function is a class function, it should
not be used by this function.
Args:
cls: Current class which is required argument for classmethod.
function: the python function used to create the function
descriptor.
Returns:
The FunctionDescriptor instance created according to the function.
"""
module_name = function.__module__
function_name = function.__name__
class_name = ""
function_source_hasher = hashlib.sha1()
try:
# If we are running a script or are in IPython, include the source
# code in the hash.
source = inspect.getsource(function).encode("ascii")
function_source_hasher.update(source)
function_source_hash = function_source_hasher.digest()
except (IOError, OSError, TypeError):
# Source code may not be available:
# e.g. Cython or Python interpreter.
function_source_hash = b""
return cls(module_name, function_name, class_name,
function_source_hash)
@classmethod
def from_class(cls, target_class):
"""Create a FunctionDescriptor from a class.
Args:
cls: Current class which is required argument for classmethod.
target_class: the python class used to create the function
descriptor.
Returns:
The FunctionDescriptor instance created according to the class.
"""
module_name = target_class.__module__
class_name = target_class.__name__
return cls(module_name, "__init__", class_name)
@classmethod
def for_driver_task(cls):
"""Create a FunctionDescriptor instance for a driver task."""
return cls("", "", "", b"")
@property
def is_for_driver_task(self):
"""See whether this function descriptor is for a driver or not.
Returns:
True if this function descriptor is for driver tasks.
"""
return all(
len(x) == 0
for x in [self.module_name, self.class_name, self.function_name])
@property
def module_name(self):
"""Get the module name of current function descriptor.
Returns:
The module name of the function descriptor.
"""
return self._module_name
@property
def class_name(self):
"""Get the class name of current function descriptor.
Returns:
The class name of the function descriptor. It could be
empty if the function is not a class method.
"""
return self._class_name
@property
def function_name(self):
"""Get the function name of current function descriptor.
Returns:
The function name of the function descriptor.
"""
return self._function_name
@property
def function_hash(self):
"""Get the hash code of the function source code.
Returns:
The bytes with length of ray_constants.ID_SIZE if the source
code is available. Otherwise, the bytes length will be 0.
"""
return self._function_source_hash
@property
def function_id(self):
"""Get the function id calculated from this descriptor.
Returns:
The value of ray.ObjectID that represents the function id.
"""
return ray.ObjectID(self._function_id)
def _get_function_id(self):
"""Calculate the function id of current function descriptor.
This function id is calculated from all the fields of function
descriptor.
Returns:
bytes with length of ray_constants.ID_SIZE.
"""
if self.is_for_driver_task:
return ray_constants.NIL_FUNCTION_ID.id()
function_id_hash = hashlib.sha1()
# Include the function module and name in the hash.
function_id_hash.update(self.module_name.encode("ascii"))
function_id_hash.update(self.function_name.encode("ascii"))
function_id_hash.update(self.class_name.encode("ascii"))
function_id_hash.update(self._function_source_hash)
# Compute the function ID.
function_id = function_id_hash.digest()
assert len(function_id) == ray_constants.ID_SIZE
return function_id
def get_function_descriptor_list(self):
"""Return a list of bytes representing the function descriptor.
This function is used to pass this function descriptor to backend.
Returns:
A list of bytes.
"""
descriptor_list = []
if self.is_for_driver_task:
# Driver task returns an empty list.
return descriptor_list
else:
descriptor_list.append(self.module_name.encode("ascii"))
descriptor_list.append(self.class_name.encode("ascii"))
descriptor_list.append(self.function_name.encode("ascii"))
if len(self._function_source_hash) != 0:
descriptor_list.append(self._function_source_hash)
return descriptor_list
class FunctionActorManager(object):
"""A class used to export/load remote functions and actors.
@@ -45,6 +269,8 @@ class FunctionActorManager(object):
and execution_info.
_num_task_executions: The map from driver_id to function
execution times.
imported_actor_classes: The set of actor classes keys (format:
ActorClass:function_id) that are already in GCS.
"""
def __init__(self, worker):
@@ -58,11 +284,17 @@ class FunctionActorManager(object):
# workers that execute remote functions.
self._function_execution_info = defaultdict(lambda: {})
self._num_task_executions = defaultdict(lambda: {})
# A set of all of the actor class keys that have been imported by the
# import thread. It is safe to convert this worker into an actor of
# these types.
self.imported_actor_classes = set()
def increase_task_counter(self, driver_id, function_id):
def increase_task_counter(self, driver_id, function_descriptor):
function_id = function_descriptor.function_id.id()
self._num_task_executions[driver_id][function_id] += 1
def get_task_counter(self, driver_id, function_id):
def get_task_counter(self, driver_id, function_descriptor):
function_id = function_descriptor.function_id.id()
return self._num_task_executions[driver_id][function_id]
def export_cached(self):
@@ -124,13 +356,13 @@ class FunctionActorManager(object):
check_oversized_pickle(pickled_function,
remote_function._function_name,
"remote function", self._worker)
key = (b"RemoteFunction:" + self._worker.task_driver_id.id() + b":" +
remote_function._function_id)
remote_function._function_descriptor.function_id.id())
self._worker.redis_client.hmset(
key, {
"driver_id": self._worker.task_driver_id.id(),
"function_id": remote_function._function_id,
"function_id": remote_function._function_descriptor.
function_id.id(),
"name": remote_function._function_name,
"module": function.__module__,
"function": pickled_function,
@@ -193,24 +425,28 @@ class FunctionActorManager(object):
self._worker.redis_client.rpush(
b"FunctionTable:" + function_id.id(), self._worker.worker_id)
def get_execution_info(self, driver_id, function_id):
def get_execution_info(self, driver_id, function_descriptor):
"""Get the FunctionExecutionInfo of a remote function.
Args:
driver_id: ID of the driver that the function belongs to.
function_id: ID of the function to get.
function_descriptor: The FunctionDescriptor of the function to get.
Returns:
A FunctionExecutionInfo object.
"""
# Wait until the function to be executed has actually been registered
# on this worker. We will push warnings to the user if we spend too
# long in this loop.
with profiling.profile("wait_for_function", worker=self._worker):
self._wait_for_function(function_id, driver_id)
return self._function_execution_info[driver_id][function_id.id()]
function_id = function_descriptor.function_id.id()
def _wait_for_function(self, function_id, driver_id, timeout=10):
# Wait until the function to be executed has actually been
# registered on this worker. We will push warnings to the user if
# we spend too long in this loop.
# The driver function may not be found in sys.path. Try to load
# the function from GCS.
with profiling.profile("wait_for_function", worker=self._worker):
self._wait_for_function(function_descriptor, driver_id)
return self._function_execution_info[driver_id][function_id]
def _wait_for_function(self, function_descriptor, driver_id, timeout=10):
"""Wait until the function to be executed is present on this worker.
This method will simply loop until the import thread has imported the
@@ -221,7 +457,8 @@ class FunctionActorManager(object):
been defined.
Args:
function_id (str): The ID of the function that we want to execute.
function_descriptor : The FunctionDescriptor of the function that
we want to execute.
driver_id (str): The ID of the driver to push the error message to
if this times out.
"""
@@ -231,7 +468,7 @@ class FunctionActorManager(object):
while True:
with self._worker.lock:
if (self._worker.actor_id == ray.worker.NIL_ACTOR_ID
and (function_id.id() in
and (function_descriptor.function_id.id() in
self._function_execution_info[driver_id])):
break
elif self._worker.actor_id != ray.worker.NIL_ACTOR_ID and (
@@ -251,24 +488,6 @@ class FunctionActorManager(object):
warning_sent = True
time.sleep(0.001)
@classmethod
def compute_actor_method_function_id(cls, class_name, attr):
"""Get the function ID corresponding to an actor method.
Args:
class_name (str): The class name of the actor.
attr (str): The attribute name of the method.
Returns:
Function ID corresponding to the method.
"""
function_id_hash = hashlib.sha1()
function_id_hash.update(class_name.encode("ascii"))
function_id_hash.update(attr.encode("ascii"))
function_id = function_id_hash.digest()
assert len(function_id) == ray_constants.ID_SIZE
return ray.ObjectID(function_id)
def _publish_actor_class_to_key(self, key, actor_class_info):
"""Push an actor class definition to Redis.
@@ -287,9 +506,10 @@ class FunctionActorManager(object):
self._worker.redis_client.hmset(key, actor_class_info)
self._worker.redis_client.rpush("Exports", key)
def export_actor_class(self, class_id, Class, actor_method_names,
def export_actor_class(self, Class, actor_method_names,
checkpoint_interval):
key = b"ActorClass:" + class_id
function_descriptor = FunctionDescriptor.from_class(Class)
key = b"ActorClass:" + function_descriptor.function_id.id()
actor_class_info = {
"class_name": Class.__name__,
"module": Class.__module__,
@@ -318,6 +538,17 @@ class FunctionActorManager(object):
# within tasks. I tried to disable this, but it may be necessary
# because of https://github.com/ray-project/ray/issues/1146.
def load_actor(self, driver_id, function_descriptor):
key = b"ActorClass:" + function_descriptor.function_id.id()
# Wait for the actor class key to have been imported by the
# import thread. TODO(rkn): It shouldn't be possible to end
# up in an infinite loop here, but we should push an error to
# the driver if too much time is spent here.
while key not in self.imported_actor_classes:
time.sleep(0.001)
with self._worker.lock:
self.fetch_and_register_actor(key)
def fetch_and_register_actor(self, actor_class_key):
"""Import an actor.
@@ -330,11 +561,10 @@ class FunctionActorManager(object):
worker: The worker to use.
"""
actor_id_str = self._worker.actor_id
(driver_id, class_id, class_name, module, pickled_class,
checkpoint_interval,
(driver_id, class_name, module, pickled_class, checkpoint_interval,
actor_method_names) = self._worker.redis_client.hmget(
actor_class_key, [
"driver_id", "class_id", "class_name", "module", "class",
"driver_id", "class_name", "module", "class",
"checkpoint_interval", "actor_method_names"
])
@@ -368,9 +598,9 @@ class FunctionActorManager(object):
# Register the actor method executors.
for actor_method_name in actor_method_names:
function_id = (
FunctionActorManager.compute_actor_method_function_id(
class_name, actor_method_name).id())
function_descriptor = FunctionDescriptor(module, actor_method_name,
class_name)
function_id = function_descriptor.function_id.id()
temporary_executor = self._make_actor_method_executor(
actor_method_name,
temporary_actor_method,
@@ -409,9 +639,9 @@ class FunctionActorManager(object):
actor_methods = inspect.getmembers(
unpickled_class, predicate=is_function_or_method)
for actor_method_name, actor_method in actor_methods:
function_id = (
FunctionActorManager.compute_actor_method_function_id(
class_name, actor_method_name).id())
function_descriptor = FunctionDescriptor(
module, actor_method_name, class_name)
function_id = function_descriptor.function_id.id()
executor = self._make_actor_method_executor(
actor_method_name, actor_method, actor_imported=True)
self._function_execution_info[driver_id][function_id] = (
@@ -452,7 +682,9 @@ class FunctionActorManager(object):
# If this is the first task to execute on the actor, try to resume
# from a checkpoint.
if actor_imported and self._worker.actor_task_counter == 1:
# Current __init__ will be called by default. So the real function
# call will start from 2.
if actor_imported and self._worker.actor_task_counter == 2:
checkpoint_resumed = ray.actor.restore_and_log_checkpoint(
self._worker, actor)
if checkpoint_resumed:
+9 -9
View File
@@ -3,26 +3,26 @@ from __future__ import division
from __future__ import print_function
import flatbuffers
import ray.core.generated.ErrorTableData
from ray.core.generated.GcsTableEntry import GcsTableEntry
from ray.core.generated.ClientTableData import ClientTableData
from ray.core.generated.ErrorTableData import ErrorTableData
from ray.core.generated.ProfileTableData import ProfileTableData
from ray.core.generated.HeartbeatTableData import HeartbeatTableData
from ray.core.generated.HeartbeatBatchTableData import HeartbeatBatchTableData
from ray.core.generated.DriverTableData import DriverTableData
from ray.core.generated.ErrorTableData import ErrorTableData
from ray.core.generated.GcsTableEntry import GcsTableEntry
from ray.core.generated.HeartbeatBatchTableData import HeartbeatBatchTableData
from ray.core.generated.HeartbeatTableData import HeartbeatTableData
from ray.core.generated.Language import Language
from ray.core.generated.ObjectTableData import ObjectTableData
from ray.core.generated.ray.protocol.Task import Task
from ray.core.generated.ProfileTableData import ProfileTableData
from ray.core.generated.TablePrefix import TablePrefix
from ray.core.generated.TablePubsub import TablePubsub
from ray.core.generated.ray.protocol.Task import Task
__all__ = [
"GcsTableEntry", "ClientTableData", "ErrorTableData", "HeartbeatTableData",
"HeartbeatBatchTableData", "DriverTableData", "ProfileTableData",
"ObjectTableData", "Task", "TablePrefix", "TablePubsub",
"ObjectTableData", "Task", "TablePrefix", "TablePubsub", "Language",
"construct_error_message"
]
+1 -1
View File
@@ -98,7 +98,7 @@ class ImportThread(object):
# Keep track of the fact that this actor class has been
# exported so that we know it is safe to turn this worker
# into an actor of that class.
self.worker.imported_actor_classes.add(key)
self.worker.function_actor_manager.imported_actor_classes.add(key)
# TODO(rkn): We may need to bring back the case of
# fetching actor classes here.
else:
+1
View File
@@ -16,6 +16,7 @@ def env_integer(key, default):
ID_SIZE = 20
NIL_JOB_ID = ObjectID(ID_SIZE * b"\xff")
NIL_FUNCTION_ID = NIL_JOB_ID
# If a remote function or actor (or some other export) has serialized size
# greater than this quantity, print an warning.
+4 -36
View File
@@ -3,11 +3,9 @@ from __future__ import division
from __future__ import print_function
import copy
import hashlib
import inspect
import logging
import ray.ray_constants as ray_constants
from ray.function_manager import FunctionDescriptor
import ray.signature
# Default parameters for remote functions.
@@ -18,33 +16,6 @@ DEFAULT_REMOTE_FUNCTION_MAX_CALLS = 0
logger = logging.getLogger(__name__)
def compute_function_id(function):
"""Compute an function ID for a function.
Args:
func: The actual function.
Returns:
Raw bytes of the function id
"""
function_id_hash = hashlib.sha1()
# Include the function module and name in the hash.
function_id_hash.update(function.__module__.encode("ascii"))
function_id_hash.update(function.__name__.encode("ascii"))
try:
# If we are running a script or are in IPython, include the source code
# in the hash.
source = inspect.getsource(function).encode("ascii")
function_id_hash.update(source)
except (IOError, OSError, TypeError):
# Source code may not be available: e.g. Cython or Python interpreter.
pass
# Compute the function ID.
function_id = function_id_hash.digest()
assert len(function_id) == ray_constants.ID_SIZE
return function_id
class RemoteFunction(object):
"""A remote function.
@@ -52,7 +23,7 @@ class RemoteFunction(object):
Attributes:
_function: The original function.
_function_id: The ID of the function.
_function_descriptor: The function descriptor.
_function_name: The module and function name.
_num_cpus: The default number of CPUs to use for invocations of this
remote function.
@@ -70,10 +41,7 @@ class RemoteFunction(object):
def __init__(self, function, num_cpus, num_gpus, resources,
num_return_vals, max_calls):
self._function = function
# TODO(rkn): We store the function ID as a string, so that
# RemoteFunction objects can be pickled. We should undo this when
# we allow ObjectIDs to be pickled.
self._function_id = compute_function_id(function)
self._function_descriptor = FunctionDescriptor.from_function(function)
self._function_name = (
self._function.__module__ + '.' + self._function.__name__)
self._num_cpus = (DEFAULT_REMOTE_FUNCTION_CPUS
@@ -147,7 +115,7 @@ class RemoteFunction(object):
result = self._function(*copy.deepcopy(args))
return result
object_ids = worker.submit_task(
ray.ObjectID(self._function_id),
self._function_descriptor,
args,
num_return_vals=num_return_vals,
resources=resources)
+64 -73
View File
@@ -36,7 +36,7 @@ import ray.plasma
import ray.ray_constants as ray_constants
from ray import import_thread
from ray import profiling
from ray.function_manager import FunctionActorManager
from ray.function_manager import (FunctionActorManager, FunctionDescriptor)
from ray.utils import (
check_oversized_pickle,
is_cython,
@@ -54,7 +54,6 @@ ERROR_KEY_PREFIX = b"Error:"
# This must match the definition of NIL_ACTOR_ID in task.h.
NIL_ID = ray_constants.ID_SIZE * b"\xff"
NIL_LOCAL_SCHEDULER_ID = NIL_ID
NIL_FUNCTION_ID = NIL_ID
NIL_ACTOR_ID = NIL_ID
NIL_ACTOR_HANDLE_ID = NIL_ID
NIL_CLIENT_ID = ray_constants.ID_SIZE * b"\xff"
@@ -161,10 +160,6 @@ class Worker(object):
self.make_actor = None
self.actors = {}
self.actor_task_counter = 0
# A set of all of the actor class keys that have been imported by the
# import thread. It is safe to convert this worker into an actor of
# these types.
self.imported_actor_classes = set()
# The number of threads Plasma should use when putting an object in the
# object store.
self.memcopy_threads = 12
@@ -518,7 +513,7 @@ class Worker(object):
return final_results
def submit_task(self,
function_id,
function_descriptor,
args,
actor_id=None,
actor_handle_id=None,
@@ -531,15 +526,16 @@ class Worker(object):
num_return_vals=None,
resources=None,
placement_resources=None,
driver_id=None):
driver_id=None,
language=ray.gcs_utils.Language.PYTHON):
"""Submit a remote task to the scheduler.
Tell the scheduler to schedule the execution of the function with ID
function_id with arguments args. Retrieve object IDs for the outputs of
the function from the scheduler and immediately return them.
Tell the scheduler to schedule the execution of the function with
function_descriptor with arguments args. Retrieve object IDs for the
outputs of the function from the scheduler and immediately return them.
Args:
function_id: The ID of the function to execute.
function_descriptor: The function descriptor to execute.
args: The arguments to pass into the function. Arguments can be
object IDs or they can be values. If they are values, they must
be serializable objects.
@@ -623,13 +619,15 @@ class Worker(object):
# The parent task must be set for the submitted task.
assert not self.current_task_id.is_nil()
# Submit the task to local scheduler.
function_descriptor_list = (
function_descriptor.get_function_descriptor_list())
task = ray.raylet.Task(
driver_id, ray.ObjectID(function_id.id()),
args_for_local_scheduler, num_return_vals,
self.current_task_id, task_index, actor_creation_id,
actor_creation_dummy_object_id, max_actor_reconstructions,
actor_id, actor_handle_id, actor_counter,
execution_dependencies, resources, placement_resources)
driver_id, function_descriptor_list, args_for_local_scheduler,
num_return_vals, self.current_task_id, task_index,
actor_creation_id, actor_creation_dummy_object_id,
max_actor_reconstructions, actor_id, actor_handle_id,
actor_counter, execution_dependencies, resources,
placement_resources)
self.raylet_client.submit_task(task)
return task.returns()
@@ -778,10 +776,12 @@ class Worker(object):
self.task_driver_id = task.driver_id()
self.current_task_id = task.task_id()
function_id = task.function_id()
function_descriptor = FunctionDescriptor.from_bytes_list(
task.function_descriptor_list())
args = task.arguments()
return_object_ids = task.returns()
if task.actor_id().id() != NIL_ACTOR_ID:
if (task.actor_id().id() != NIL_ACTOR_ID
or task.actor_creation_id().id() != NIL_ACTOR_ID):
dummy_return_id = return_object_ids.pop()
function_executor = function_execution_info.function
function_name = function_execution_info.function_name
@@ -796,33 +796,36 @@ class Worker(object):
function_name, args)
except RayTaskError as e:
self._handle_process_task_failure(
function_id, function_name, return_object_ids, e,
function_descriptor, return_object_ids, e,
ray.utils.format_error_message(traceback.format_exc()))
return
except Exception as e:
self._handle_process_task_failure(
function_id, function_name, return_object_ids, e,
function_descriptor, return_object_ids, e,
ray.utils.format_error_message(traceback.format_exc()))
return
# Execute the task.
try:
with profiling.profile("task:execute", worker=self):
if task.actor_id().id() == NIL_ACTOR_ID:
if (task.actor_id().id() == NIL_ACTOR_ID
and task.actor_creation_id().id() == NIL_ACTOR_ID):
outputs = function_executor(*arguments)
else:
outputs = function_executor(
dummy_return_id, self.actors[task.actor_id().id()],
*arguments)
if task.actor_id().id() != NIL_ACTOR_ID:
key = task.actor_id().id()
else:
key = task.actor_creation_id().id()
outputs = function_executor(dummy_return_id,
self.actors[key], *arguments)
except Exception as e:
# Determine whether the exception occured during a task, not an
# actor method.
task_exception = task.actor_id().id() == NIL_ACTOR_ID
traceback_str = ray.utils.format_error_message(
traceback.format_exc(), task_exception=task_exception)
self._handle_process_task_failure(function_id, function_name,
return_object_ids, e,
traceback_str)
self._handle_process_task_failure(
function_descriptor, return_object_ids, e, traceback_str)
return
# Store the outputs in the local object store.
@@ -837,11 +840,13 @@ class Worker(object):
self._store_outputs_in_object_store(return_object_ids, outputs)
except Exception as e:
self._handle_process_task_failure(
function_id, function_name, return_object_ids, e,
function_descriptor, return_object_ids, e,
ray.utils.format_error_message(traceback.format_exc()))
def _handle_process_task_failure(self, function_id, function_name,
def _handle_process_task_failure(self, function_descriptor,
return_object_ids, error, backtrace):
function_name = function_descriptor.function_name
function_id = function_descriptor.function_id
failure_object = RayTaskError(function_name, backtrace)
failure_objects = [
failure_object for _ in range(len(return_object_ids))
@@ -855,53 +860,34 @@ class Worker(object):
driver_id=self.task_driver_id.id(),
data={
"function_id": function_id.id(),
"function_name": function_name
"function_name": function_name,
"module_name": function_descriptor.module_name,
"class_name": function_descriptor.class_name
})
# Mark the actor init as failed
if self.actor_id != NIL_ACTOR_ID and function_name == "__init__":
self.mark_actor_init_failed(error)
def _become_actor(self, task):
"""Turn this worker into an actor.
Args:
task: The actor creation task.
"""
assert self.actor_id == NIL_ACTOR_ID
arguments = task.arguments()
assert len(arguments) == 1
self.actor_id = task.actor_creation_id().id()
class_id = arguments[0]
key = b"ActorClass:" + class_id
# Wait for the actor class key to have been imported by the import
# thread. TODO(rkn): It shouldn't be possible to end up in an infinite
# loop here, but we should push an error to the driver if too much time
# is spent here.
while key not in self.imported_actor_classes:
time.sleep(0.001)
with self.lock:
self.function_actor_manager.fetch_and_register_actor(key)
def _wait_for_and_process_task(self, task):
"""Wait for a task to be ready and process the task.
Args:
task: The task to execute.
"""
function_id = task.function_id()
function_descriptor = FunctionDescriptor.from_bytes_list(
task.function_descriptor_list())
driver_id = task.driver_id().id()
# TODO(rkn): It would be preferable for actor creation tasks to share
# more of the code path with regular task execution.
if (task.actor_creation_id() != ray.ObjectID(NIL_ACTOR_ID)):
self._become_actor(task)
return
assert self.actor_id == NIL_ACTOR_ID
self.actor_id = task.actor_creation_id().id()
self.function_actor_manager.load_actor(driver_id,
function_descriptor)
execution_info = self.function_actor_manager.get_execution_info(
driver_id, function_id)
driver_id, function_descriptor)
# Execute the task.
# TODO(rkn): Consider acquiring this lock with a timeout and pushing a
@@ -915,8 +901,14 @@ class Worker(object):
"task_id": task.task_id().hex()
}
if task.actor_id().id() == NIL_ACTOR_ID:
title = "ray_worker:{}()".format(function_name)
next_title = "ray_worker"
if (task.actor_creation_id() == ray.ObjectID(NIL_ACTOR_ID)):
title = "ray_worker:{}()".format(function_name)
next_title = "ray_worker"
else:
actor = self.actors[task.actor_creation_id().id()]
title = "ray_{}:{}()".format(actor.__class__.__name__,
function_name)
next_title = "ray_{}".format(actor.__class__.__name__)
else:
actor = self.actors[task.actor_id().id()]
title = "ray_{}:{}()".format(actor.__class__.__name__,
@@ -934,10 +926,10 @@ class Worker(object):
# Increase the task execution counter.
self.function_actor_manager.increase_task_counter(
driver_id, function_id.id())
driver_id, function_descriptor)
reached_max_executions = (self.function_actor_manager.get_task_counter(
driver_id, function_id.id()) == execution_info.max_calls)
driver_id, function_descriptor) == execution_info.max_calls)
if reached_max_executions:
self.raylet_client.disconnect()
sys.exit(0)
@@ -2096,15 +2088,14 @@ def connect(info,
# rerun the driver.
nil_actor_counter = 0
driver_task = ray.raylet.Task(worker.task_driver_id,
ray.ObjectID(NIL_FUNCTION_ID), [], 0,
worker.current_task_id,
worker.task_index,
ray.ObjectID(NIL_ACTOR_ID),
ray.ObjectID(NIL_ACTOR_ID), 0,
ray.ObjectID(NIL_ACTOR_ID),
ray.ObjectID(NIL_ACTOR_ID),
nil_actor_counter, [], {"CPU": 0}, {})
function_descriptor = FunctionDescriptor.for_driver_task()
driver_task = ray.raylet.Task(
worker.task_driver_id,
function_descriptor.get_function_descriptor_list(), [], 0,
worker.current_task_id, worker.task_index,
ray.ObjectID(NIL_ACTOR_ID), ray.ObjectID(NIL_ACTOR_ID), 0,
ray.ObjectID(NIL_ACTOR_ID), ray.ObjectID(NIL_ACTOR_ID),
nil_actor_counter, [], {"CPU": 0}, {})
# Add the driver task to the task table.
global_state._execute_command(driver_task.task_id(), "RAY.TABLE_ADD",