From fb33fa90975934bbbd75fe742f3bf6828c71f0fb Mon Sep 17 00:00:00 2001 From: Yuhong Guo Date: Wed, 19 Dec 2018 07:53:59 +0800 Subject: [PATCH] Enable function_descriptor in backend to replace the function_id (#3028) --- doc/source/conf.py | 1 + .../org/ray/runtime/generated/TaskInfo.java | 54 ++- .../ray/runtime/raylet/RayletClientImpl.java | 2 - python/ray/actor.py | 87 ++--- python/ray/experimental/state.py | 12 +- python/ray/function_manager.py | 324 +++++++++++++++--- python/ray/gcs_utils.py | 18 +- python/ray/import_thread.py | 2 +- python/ray/ray_constants.py | 1 + python/ray/remote_function.py | 40 +-- python/ray/worker.py | 137 ++++---- src/ray/common/common_protocol.cc | 22 ++ src/ray/common/common_protocol.h | 6 + src/ray/gcs/format/gcs.fbs | 2 - src/ray/raylet/lib/python/common_extension.cc | 78 ++++- src/ray/raylet/lineage_cache_test.cc | 5 +- .../raylet/task_dependency_manager_test.cc | 5 +- src/ray/raylet/task_spec.cc | 24 +- src/ray/raylet/task_spec.h | 14 +- src/ray/raylet/worker_pool_test.cc | 5 +- 20 files changed, 557 insertions(+), 282 deletions(-) diff --git a/doc/source/conf.py b/doc/source/conf.py index 2a2b1a37c..8193ccf40 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -50,6 +50,7 @@ MOCK_MODULES = [ "ray.core.generated.ray.protocol.Task", "ray.core.generated.TablePrefix", "ray.core.generated.TablePubsub", + "ray.core.generated.Language", ] for mod_name in MOCK_MODULES: sys.modules[mod_name] = mock.Mock() diff --git a/java/runtime/src/main/java/org/ray/runtime/generated/TaskInfo.java b/java/runtime/src/main/java/org/ray/runtime/generated/TaskInfo.java index f06992ea8..4e17e45a7 100644 --- a/java/runtime/src/main/java/org/ray/runtime/generated/TaskInfo.java +++ b/java/runtime/src/main/java/org/ray/runtime/generated/TaskInfo.java @@ -38,23 +38,20 @@ public final class TaskInfo extends Table { public ByteBuffer actorHandleIdInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 20, 1); } public int actorCounter() { int o = __offset(22); return o != 0 ? bb.getInt(o + bb_pos) : 0; } public boolean isActorCheckpointMethod() { int o = __offset(24); return o != 0 ? 0!=bb.get(o + bb_pos) : false; } - public String functionId() { int o = __offset(26); return o != 0 ? __string(o + bb_pos) : null; } - public ByteBuffer functionIdAsByteBuffer() { return __vector_as_bytebuffer(26, 1); } - public ByteBuffer functionIdInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 26, 1); } public Arg args(int j) { return args(new Arg(), j); } - public Arg args(Arg obj, int j) { int o = __offset(28); return o != 0 ? obj.__assign(__indirect(__vector(o) + j * 4), bb) : null; } - public int argsLength() { int o = __offset(28); return o != 0 ? __vector_len(o) : 0; } - public String returns(int j) { int o = __offset(30); return o != 0 ? __string(__vector(o) + j * 4) : null; } - public int returnsLength() { int o = __offset(30); return o != 0 ? __vector_len(o) : 0; } + public Arg args(Arg obj, int j) { int o = __offset(26); return o != 0 ? obj.__assign(__indirect(__vector(o) + j * 4), bb) : null; } + public int argsLength() { int o = __offset(26); return o != 0 ? __vector_len(o) : 0; } + public String returns(int j) { int o = __offset(28); return o != 0 ? __string(__vector(o) + j * 4) : null; } + public int returnsLength() { int o = __offset(28); return o != 0 ? __vector_len(o) : 0; } public ResourcePair requiredResources(int j) { return requiredResources(new ResourcePair(), j); } - public ResourcePair requiredResources(ResourcePair obj, int j) { int o = __offset(32); return o != 0 ? obj.__assign(__indirect(__vector(o) + j * 4), bb) : null; } - public int requiredResourcesLength() { int o = __offset(32); return o != 0 ? __vector_len(o) : 0; } + public ResourcePair requiredResources(ResourcePair obj, int j) { int o = __offset(30); return o != 0 ? obj.__assign(__indirect(__vector(o) + j * 4), bb) : null; } + public int requiredResourcesLength() { int o = __offset(30); return o != 0 ? __vector_len(o) : 0; } public ResourcePair requiredPlacementResources(int j) { return requiredPlacementResources(new ResourcePair(), j); } - public ResourcePair requiredPlacementResources(ResourcePair obj, int j) { int o = __offset(34); return o != 0 ? obj.__assign(__indirect(__vector(o) + j * 4), bb) : null; } - public int requiredPlacementResourcesLength() { int o = __offset(34); return o != 0 ? __vector_len(o) : 0; } - public int language() { int o = __offset(36); return o != 0 ? bb.getInt(o + bb_pos) : 0; } - public String functionDescriptor(int j) { int o = __offset(38); return o != 0 ? __string(__vector(o) + j * 4) : null; } - public int functionDescriptorLength() { int o = __offset(38); return o != 0 ? __vector_len(o) : 0; } + public ResourcePair requiredPlacementResources(ResourcePair obj, int j) { int o = __offset(32); return o != 0 ? obj.__assign(__indirect(__vector(o) + j * 4), bb) : null; } + public int requiredPlacementResourcesLength() { int o = __offset(32); return o != 0 ? __vector_len(o) : 0; } + public int language() { int o = __offset(34); return o != 0 ? bb.getInt(o + bb_pos) : 0; } + public String functionDescriptor(int j) { int o = __offset(36); return o != 0 ? __string(__vector(o) + j * 4) : null; } + public int functionDescriptorLength() { int o = __offset(36); return o != 0 ? __vector_len(o) : 0; } public static int createTaskInfo(FlatBufferBuilder builder, int driver_idOffset, @@ -68,21 +65,19 @@ public final class TaskInfo extends Table { int actor_handle_idOffset, int actor_counter, boolean is_actor_checkpoint_method, - int function_idOffset, int argsOffset, int returnsOffset, int required_resourcesOffset, int required_placement_resourcesOffset, int language, int function_descriptorOffset) { - builder.startObject(18); + builder.startObject(17); TaskInfo.addFunctionDescriptor(builder, function_descriptorOffset); TaskInfo.addLanguage(builder, language); TaskInfo.addRequiredPlacementResources(builder, required_placement_resourcesOffset); TaskInfo.addRequiredResources(builder, required_resourcesOffset); TaskInfo.addReturns(builder, returnsOffset); TaskInfo.addArgs(builder, argsOffset); - TaskInfo.addFunctionId(builder, function_idOffset); TaskInfo.addActorCounter(builder, actor_counter); TaskInfo.addActorHandleId(builder, actor_handle_idOffset); TaskInfo.addActorId(builder, actor_idOffset); @@ -97,7 +92,7 @@ public final class TaskInfo extends Table { return TaskInfo.endTaskInfo(builder); } - public static void startTaskInfo(FlatBufferBuilder builder) { builder.startObject(18); } + public static void startTaskInfo(FlatBufferBuilder builder) { builder.startObject(17); } public static void addDriverId(FlatBufferBuilder builder, int driverIdOffset) { builder.addOffset(0, driverIdOffset, 0); } public static void addTaskId(FlatBufferBuilder builder, int taskIdOffset) { builder.addOffset(1, taskIdOffset, 0); } public static void addParentTaskId(FlatBufferBuilder builder, int parentTaskIdOffset) { builder.addOffset(2, parentTaskIdOffset, 0); } @@ -109,21 +104,20 @@ public final class TaskInfo extends Table { public static void addActorHandleId(FlatBufferBuilder builder, int actorHandleIdOffset) { builder.addOffset(8, actorHandleIdOffset, 0); } public static void addActorCounter(FlatBufferBuilder builder, int actorCounter) { builder.addInt(9, actorCounter, 0); } public static void addIsActorCheckpointMethod(FlatBufferBuilder builder, boolean isActorCheckpointMethod) { builder.addBoolean(10, isActorCheckpointMethod, false); } - public static void addFunctionId(FlatBufferBuilder builder, int functionIdOffset) { builder.addOffset(11, functionIdOffset, 0); } - public static void addArgs(FlatBufferBuilder builder, int argsOffset) { builder.addOffset(12, argsOffset, 0); } + public static void addArgs(FlatBufferBuilder builder, int argsOffset) { builder.addOffset(11, argsOffset, 0); } public static int createArgsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } public static void startArgsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } - public static void addReturns(FlatBufferBuilder builder, int returnsOffset) { builder.addOffset(13, returnsOffset, 0); } + public static void addReturns(FlatBufferBuilder builder, int returnsOffset) { builder.addOffset(12, returnsOffset, 0); } public static int createReturnsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } public static void startReturnsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } - public static void addRequiredResources(FlatBufferBuilder builder, int requiredResourcesOffset) { builder.addOffset(14, requiredResourcesOffset, 0); } + public static void addRequiredResources(FlatBufferBuilder builder, int requiredResourcesOffset) { builder.addOffset(13, requiredResourcesOffset, 0); } public static int createRequiredResourcesVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } public static void startRequiredResourcesVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } - public static void addRequiredPlacementResources(FlatBufferBuilder builder, int requiredPlacementResourcesOffset) { builder.addOffset(15, requiredPlacementResourcesOffset, 0); } + public static void addRequiredPlacementResources(FlatBufferBuilder builder, int requiredPlacementResourcesOffset) { builder.addOffset(14, requiredPlacementResourcesOffset, 0); } public static int createRequiredPlacementResourcesVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } public static void startRequiredPlacementResourcesVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } - public static void addLanguage(FlatBufferBuilder builder, int language) { builder.addInt(16, language, 0); } - public static void addFunctionDescriptor(FlatBufferBuilder builder, int functionDescriptorOffset) { builder.addOffset(17, functionDescriptorOffset, 0); } + public static void addLanguage(FlatBufferBuilder builder, int language) { builder.addInt(15, language, 0); } + public static void addFunctionDescriptor(FlatBufferBuilder builder, int functionDescriptorOffset) { builder.addOffset(16, functionDescriptorOffset, 0); } public static int createFunctionDescriptorVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } public static void startFunctionDescriptorVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } public static int endTaskInfo(FlatBufferBuilder builder) { @@ -131,10 +125,14 @@ public final class TaskInfo extends Table { return o; } - //this is manually added to avoid encoding/decoding cost as our object - //id is a byte array instead of a string + /** This is manually added to avoid encoding/decoding cost as our object + * id is a byte array instead of a string. + * This function is error-prone. If the fields before `returns` changed, + * the offset number should be changed accordingly. + * TODO(yuhguo): fix this error-prone funciton. + */ public ByteBuffer returnsAsByteBuffer(int j) { - int o = __offset(30); + int o = __offset(28); if (o == 0) { return null; } diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java index 9757b4f07..5b83e6c19 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java @@ -183,7 +183,6 @@ public class RayletClientImpl implements RayletClient { final int actorIdOffset = fbb.createString(task.actorId.toByteBuffer()); final int actorHandleIdOffset = fbb.createString(task.actorHandleId.toByteBuffer()); final int actorCounter = task.actorCounter; - final int functionIdOffset = fbb.createString(UniqueId.randomId().toByteBuffer()); // Serialize args int[] argsOffsets = new int[task.args.length]; for (int i = 0; i < argsOffsets.length; i++) { @@ -245,7 +244,6 @@ public class RayletClientImpl implements RayletClient { actorHandleIdOffset, actorCounter, false, - functionIdOffset, argsOffset, returnsOffset, requiredResourcesOffset, diff --git a/python/ray/actor.py b/python/ray/actor.py index 1167f064a..1bac7a137 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -10,7 +10,7 @@ import sys import traceback import ray.cloudpickle as pickle -from ray.function_manager import FunctionActorManager +from ray.function_manager import FunctionDescriptor import ray.raylet import ray.ray_constants as ray_constants import ray.signature as signature @@ -76,18 +76,6 @@ def compute_actor_handle_id_non_forked(actor_id, actor_handle_id, return ray.ObjectID(handle_id) -def compute_actor_creation_function_id(class_id): - """Compute the function ID for an actor creation task. - - Args: - class_id: The ID of the actor class. - - Returns: - The function ID of the actor creation event. - """ - return ray.ObjectID(class_id) - - def set_actor_checkpoint(worker, actor_id, checkpoint_index, checkpoint, frontier): """Set the most recent checkpoint associated with a given actor ID. @@ -287,6 +275,23 @@ class ActorClass(object): self._actor_methods = inspect.getmembers( self._modified_class, ray.utils.is_function_or_method) + self._actor_method_names = [ + method_name for method_name, _ in self._actor_methods + ] + + constructor_name = "__init__" + if constructor_name not in self._actor_method_names: + # Add __init__ if it does not exist. + # Actor creation will be executed with __init__ together. + + # Assign an __init__ function will avoid many checks later on. + def __init__(self): + pass + + self._modified_class.__init__ = __init__ + self._actor_method_names.append(constructor_name) + self._actor_methods.append((constructor_name, __init__)) + # Extract the signatures of each of the methods. This will be used # to catch some errors if the methods are called with inappropriate # arguments. @@ -300,7 +305,6 @@ class ActorClass(object): signature.check_signature_supported(method, warn=True) self._method_signatures[method_name] = signature.extract_signature( method, ignore_first=not ray.utils.is_class_method(method)) - # Set the default number of return values for this method. if hasattr(method, "__ray_num_return_vals__"): self._actor_method_num_return_vals[method_name] = ( @@ -309,10 +313,6 @@ class ActorClass(object): self._actor_method_num_return_vals[method_name] = ( DEFAULT_ACTOR_METHOD_NUM_RETURN_VALS) - self._actor_method_names = [ - method_name for method_name, _ in self._actor_methods - ] - def __call__(self, *args, **kwargs): raise Exception("Actors methods cannot be instantiated directly. " "Instead of running '{}()', try '{}.remote()'.".format( @@ -386,14 +386,14 @@ class ActorClass(object): # Instead, instantiate the actor locally and add it to the worker's # dictionary if worker.mode == ray.LOCAL_MODE: - worker.actors[actor_id] = self._modified_class.__new__( - self._modified_class) + worker.actors[actor_id] = self._modified_class( + *copy.deepcopy(args), **copy.deepcopy(kwargs)) else: # Export the actor. if not self._exported: worker.function_actor_manager.export_actor_class( - self._class_id, self._modified_class, - self._actor_method_names, self._checkpoint_interval) + self._modified_class, self._actor_method_names, + self._checkpoint_interval) self._exported = True resources = ray.utils.resources_from_resource_arguments( @@ -409,10 +409,19 @@ class ActorClass(object): actor_placement_resources = resources.copy() actor_placement_resources["CPU"] += 1 - creation_args = [self._class_id] - function_id = compute_actor_creation_function_id(self._class_id) + if args is None: + args = [] + if kwargs is None: + kwargs = {} + function_name = "__init__" + function_signature = self._method_signatures[function_name] + creation_args = signature.extend_args(function_signature, args, + kwargs) + function_descriptor = FunctionDescriptor( + self._modified_class.__module__, function_name, + self._modified_class.__name__) [actor_cursor] = worker.submit_task( - function_id, + function_descriptor, creation_args, actor_creation_id=actor_id, max_actor_reconstructions=self._max_reconstructions, @@ -424,19 +433,10 @@ class ActorClass(object): # creation task. actor_counter = 1 actor_handle = ActorHandle( - actor_id, self._class_name, actor_cursor, actor_counter, - self._actor_method_names, self._method_signatures, - self._actor_method_num_return_vals, actor_cursor, - self._actor_method_cpus, worker.task_driver_id) - - # Call __init__ as a remote function. - if "__init__" in actor_handle._ray_actor_method_names: - actor_handle.__init__.remote(*args, **kwargs) - else: - if len(args) != 0 or len(kwargs) != 0: - raise Exception("Arguments cannot be passed to the actor " - "constructor because this actor class has no " - "__init__ method.") + actor_id, self._modified_class.__module__, self._class_name, + actor_cursor, actor_counter, self._actor_method_names, + self._method_signatures, self._actor_method_num_return_vals, + actor_cursor, self._actor_method_cpus, worker.task_driver_id) return actor_handle @@ -458,6 +458,7 @@ class ActorHandle(object): Attributes: _ray_actor_id: The ID of the corresponding actor. + _ray_module_name: The module name of this actor. _ray_actor_handle_id: The ID of this handle. If this is the "original" handle for an actor (as opposed to one created by passing another handle into a task), then this ID must be NIL_ID. If this @@ -496,6 +497,7 @@ class ActorHandle(object): def __init__(self, actor_id, + module_name, class_name, actor_cursor, actor_counter, @@ -510,6 +512,7 @@ class ActorHandle(object): # False if this actor handle was created by forking or pickling. True # if it was created by the _serialization_helper function. self._ray_original_handle = previous_actor_handle_id is None + self._ray_module_name = module_name self._ray_actor_id = actor_id if self._ray_original_handle: @@ -602,10 +605,10 @@ class ActorHandle(object): else: actor_handle_id = self._ray_actor_handle_id - function_id = FunctionActorManager.compute_actor_method_function_id( - self._ray_class_name, method_name) + function_descriptor = FunctionDescriptor( + self._ray_module_name, method_name, self._ray_class_name) object_ids = worker.submit_task( - function_id, + function_descriptor, args, actor_id=self._ray_actor_id, actor_handle_id=actor_handle_id, @@ -706,6 +709,7 @@ class ActorHandle(object): """ state = { "actor_id": self._ray_actor_id.id(), + "module_name": self._ray_module_name, "class_name": self._ray_class_name, "actor_forks": self._ray_actor_forks, "actor_cursor": self._ray_actor_cursor.id() @@ -753,6 +757,7 @@ class ActorHandle(object): self.__init__( ray.ObjectID(state["actor_id"]), + state["module_name"], state["class_name"], ray.ObjectID(state["actor_cursor"]) if state["actor_cursor"] is not None else None, diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py index d97cc274f..bef88f48c 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/experimental/state.py @@ -9,6 +9,7 @@ import sys import time import ray +from ray.function_manager import FunctionDescriptor import ray.gcs_utils import ray.ray_constants as ray_constants from ray.utils import (decode, binary_to_object_id, binary_to_hex, @@ -234,6 +235,9 @@ class GlobalState(object): execution_spec = task_table_message.TaskExecutionSpec() task_spec = task_table_message.TaskSpecification() task_spec = ray.raylet.task_from_string(task_spec) + function_descriptor_list = task_spec.function_descriptor_list() + function_descriptor = FunctionDescriptor.from_bytes_list( + function_descriptor_list) task_spec_info = { "DriverID": binary_to_hex(task_spec.driver_id().id()), "TaskID": binary_to_hex(task_spec.task_id().id()), @@ -245,10 +249,14 @@ class GlobalState(object): "ActorCreationDummyObjectID": binary_to_hex( task_spec.actor_creation_dummy_object_id().id()), "ActorCounter": task_spec.actor_counter(), - "FunctionID": binary_to_hex(task_spec.function_id().id()), "Args": task_spec.arguments(), "ReturnObjectIDs": task_spec.returns(), - "RequiredResources": task_spec.required_resources() + "RequiredResources": task_spec.required_resources(), + "FunctionID": binary_to_hex(function_descriptor.function_id.id()), + "FunctionHash": binary_to_hex(function_descriptor.function_hash), + "ModuleName": function_descriptor.module_name, + "ClassName": function_descriptor.class_name, + "FunctionName": function_descriptor.function_name, } return { diff --git a/python/ray/function_manager.py b/python/ray/function_manager.py index 72ec53651..94ff22baf 100644 --- a/python/ray/function_manager.py +++ b/python/ray/function_manager.py @@ -5,6 +5,7 @@ from __future__ import print_function import hashlib import inspect import json +import logging import sys import time import traceback @@ -18,6 +19,7 @@ from ray import profiling from ray import ray_constants from ray import cloudpickle as pickle from ray.utils import ( + binary_to_hex, is_cython, is_function_or_method, is_class_method, @@ -31,6 +33,228 @@ FunctionExecutionInfo = namedtuple("FunctionExecutionInfo", ["function", "function_name", "max_calls"]) """FunctionExecutionInfo: A named tuple storing remote function information.""" +logger = logging.getLogger(__name__) + + +class FunctionDescriptor(object): + """A class used to describe a python function. + + Attributes: + module_name: the module name that the function belongs to. + class_name: the class name that the function belongs to if exists. + It could be empty is the function is not a class method. + function_name: the function name of the function. + function_hash: the hash code of the function source code if the + function code is available. + function_id: the function id calculated from this descriptor. + is_for_driver_task: whether this descriptor is for driver task. + """ + + def __init__(self, + module_name, + function_name, + class_name="", + function_source_hash=b""): + self._module_name = module_name + self._class_name = class_name + self._function_name = function_name + self._function_source_hash = function_source_hash + self._function_id = self._get_function_id() + + def __repr__(self): + return ("FunctionDescriptor:" + self._module_name + "." + + self._class_name + "." + self._function_name + "." + + binary_to_hex(self._function_source_hash)) + + @classmethod + def from_bytes_list(cls, function_descriptor_list): + """Create a FunctionDescriptor instance from list of bytes. + + This function is used to create the function descriptor from + backend data. + + Args: + cls: Current class which is required argument for classmethod. + function_descriptor_list: list of bytes to represent the + function descriptor. + + Returns: + The FunctionDescriptor instance created from the bytes list. + """ + assert isinstance(function_descriptor_list, list) + if len(function_descriptor_list) == 0: + # This is a function descriptor of driver task. + return FunctionDescriptor.for_driver_task() + elif (len(function_descriptor_list) == 3 + or len(function_descriptor_list) == 4): + module_name = function_descriptor_list[0].decode() + class_name = function_descriptor_list[1].decode() + function_name = function_descriptor_list[2].decode() + if len(function_descriptor_list) == 4: + return cls(module_name, function_name, class_name, + function_descriptor_list[3]) + else: + return cls(module_name, function_name, class_name) + else: + raise Exception( + "Invalid input for FunctionDescriptor.from_bytes_list") + + @classmethod + def from_function(cls, function): + """Create a FunctionDescriptor from a function instance. + + This function is used to create the function descriptor from + a python function. If a function is a class function, it should + not be used by this function. + + Args: + cls: Current class which is required argument for classmethod. + function: the python function used to create the function + descriptor. + + Returns: + The FunctionDescriptor instance created according to the function. + """ + module_name = function.__module__ + function_name = function.__name__ + class_name = "" + + function_source_hasher = hashlib.sha1() + try: + # If we are running a script or are in IPython, include the source + # code in the hash. + source = inspect.getsource(function).encode("ascii") + function_source_hasher.update(source) + function_source_hash = function_source_hasher.digest() + except (IOError, OSError, TypeError): + # Source code may not be available: + # e.g. Cython or Python interpreter. + function_source_hash = b"" + + return cls(module_name, function_name, class_name, + function_source_hash) + + @classmethod + def from_class(cls, target_class): + """Create a FunctionDescriptor from a class. + + Args: + cls: Current class which is required argument for classmethod. + target_class: the python class used to create the function + descriptor. + + Returns: + The FunctionDescriptor instance created according to the class. + """ + module_name = target_class.__module__ + class_name = target_class.__name__ + return cls(module_name, "__init__", class_name) + + @classmethod + def for_driver_task(cls): + """Create a FunctionDescriptor instance for a driver task.""" + return cls("", "", "", b"") + + @property + def is_for_driver_task(self): + """See whether this function descriptor is for a driver or not. + + Returns: + True if this function descriptor is for driver tasks. + """ + return all( + len(x) == 0 + for x in [self.module_name, self.class_name, self.function_name]) + + @property + def module_name(self): + """Get the module name of current function descriptor. + + Returns: + The module name of the function descriptor. + """ + return self._module_name + + @property + def class_name(self): + """Get the class name of current function descriptor. + + Returns: + The class name of the function descriptor. It could be + empty if the function is not a class method. + """ + return self._class_name + + @property + def function_name(self): + """Get the function name of current function descriptor. + + Returns: + The function name of the function descriptor. + """ + return self._function_name + + @property + def function_hash(self): + """Get the hash code of the function source code. + + Returns: + The bytes with length of ray_constants.ID_SIZE if the source + code is available. Otherwise, the bytes length will be 0. + """ + return self._function_source_hash + + @property + def function_id(self): + """Get the function id calculated from this descriptor. + + Returns: + The value of ray.ObjectID that represents the function id. + """ + return ray.ObjectID(self._function_id) + + def _get_function_id(self): + """Calculate the function id of current function descriptor. + + This function id is calculated from all the fields of function + descriptor. + + Returns: + bytes with length of ray_constants.ID_SIZE. + """ + if self.is_for_driver_task: + return ray_constants.NIL_FUNCTION_ID.id() + function_id_hash = hashlib.sha1() + # Include the function module and name in the hash. + function_id_hash.update(self.module_name.encode("ascii")) + function_id_hash.update(self.function_name.encode("ascii")) + function_id_hash.update(self.class_name.encode("ascii")) + function_id_hash.update(self._function_source_hash) + # Compute the function ID. + function_id = function_id_hash.digest() + assert len(function_id) == ray_constants.ID_SIZE + return function_id + + def get_function_descriptor_list(self): + """Return a list of bytes representing the function descriptor. + + This function is used to pass this function descriptor to backend. + + Returns: + A list of bytes. + """ + descriptor_list = [] + if self.is_for_driver_task: + # Driver task returns an empty list. + return descriptor_list + else: + descriptor_list.append(self.module_name.encode("ascii")) + descriptor_list.append(self.class_name.encode("ascii")) + descriptor_list.append(self.function_name.encode("ascii")) + if len(self._function_source_hash) != 0: + descriptor_list.append(self._function_source_hash) + return descriptor_list + class FunctionActorManager(object): """A class used to export/load remote functions and actors. @@ -45,6 +269,8 @@ class FunctionActorManager(object): and execution_info. _num_task_executions: The map from driver_id to function execution times. + imported_actor_classes: The set of actor classes keys (format: + ActorClass:function_id) that are already in GCS. """ def __init__(self, worker): @@ -58,11 +284,17 @@ class FunctionActorManager(object): # workers that execute remote functions. self._function_execution_info = defaultdict(lambda: {}) self._num_task_executions = defaultdict(lambda: {}) + # A set of all of the actor class keys that have been imported by the + # import thread. It is safe to convert this worker into an actor of + # these types. + self.imported_actor_classes = set() - def increase_task_counter(self, driver_id, function_id): + def increase_task_counter(self, driver_id, function_descriptor): + function_id = function_descriptor.function_id.id() self._num_task_executions[driver_id][function_id] += 1 - def get_task_counter(self, driver_id, function_id): + def get_task_counter(self, driver_id, function_descriptor): + function_id = function_descriptor.function_id.id() return self._num_task_executions[driver_id][function_id] def export_cached(self): @@ -124,13 +356,13 @@ class FunctionActorManager(object): check_oversized_pickle(pickled_function, remote_function._function_name, "remote function", self._worker) - key = (b"RemoteFunction:" + self._worker.task_driver_id.id() + b":" + - remote_function._function_id) + remote_function._function_descriptor.function_id.id()) self._worker.redis_client.hmset( key, { "driver_id": self._worker.task_driver_id.id(), - "function_id": remote_function._function_id, + "function_id": remote_function._function_descriptor. + function_id.id(), "name": remote_function._function_name, "module": function.__module__, "function": pickled_function, @@ -193,24 +425,28 @@ class FunctionActorManager(object): self._worker.redis_client.rpush( b"FunctionTable:" + function_id.id(), self._worker.worker_id) - def get_execution_info(self, driver_id, function_id): + def get_execution_info(self, driver_id, function_descriptor): """Get the FunctionExecutionInfo of a remote function. Args: driver_id: ID of the driver that the function belongs to. - function_id: ID of the function to get. + function_descriptor: The FunctionDescriptor of the function to get. Returns: A FunctionExecutionInfo object. """ - # Wait until the function to be executed has actually been registered - # on this worker. We will push warnings to the user if we spend too - # long in this loop. - with profiling.profile("wait_for_function", worker=self._worker): - self._wait_for_function(function_id, driver_id) - return self._function_execution_info[driver_id][function_id.id()] + function_id = function_descriptor.function_id.id() - def _wait_for_function(self, function_id, driver_id, timeout=10): + # Wait until the function to be executed has actually been + # registered on this worker. We will push warnings to the user if + # we spend too long in this loop. + # The driver function may not be found in sys.path. Try to load + # the function from GCS. + with profiling.profile("wait_for_function", worker=self._worker): + self._wait_for_function(function_descriptor, driver_id) + return self._function_execution_info[driver_id][function_id] + + def _wait_for_function(self, function_descriptor, driver_id, timeout=10): """Wait until the function to be executed is present on this worker. This method will simply loop until the import thread has imported the @@ -221,7 +457,8 @@ class FunctionActorManager(object): been defined. Args: - function_id (str): The ID of the function that we want to execute. + function_descriptor : The FunctionDescriptor of the function that + we want to execute. driver_id (str): The ID of the driver to push the error message to if this times out. """ @@ -231,7 +468,7 @@ class FunctionActorManager(object): while True: with self._worker.lock: if (self._worker.actor_id == ray.worker.NIL_ACTOR_ID - and (function_id.id() in + and (function_descriptor.function_id.id() in self._function_execution_info[driver_id])): break elif self._worker.actor_id != ray.worker.NIL_ACTOR_ID and ( @@ -251,24 +488,6 @@ class FunctionActorManager(object): warning_sent = True time.sleep(0.001) - @classmethod - def compute_actor_method_function_id(cls, class_name, attr): - """Get the function ID corresponding to an actor method. - - Args: - class_name (str): The class name of the actor. - attr (str): The attribute name of the method. - - Returns: - Function ID corresponding to the method. - """ - function_id_hash = hashlib.sha1() - function_id_hash.update(class_name.encode("ascii")) - function_id_hash.update(attr.encode("ascii")) - function_id = function_id_hash.digest() - assert len(function_id) == ray_constants.ID_SIZE - return ray.ObjectID(function_id) - def _publish_actor_class_to_key(self, key, actor_class_info): """Push an actor class definition to Redis. @@ -287,9 +506,10 @@ class FunctionActorManager(object): self._worker.redis_client.hmset(key, actor_class_info) self._worker.redis_client.rpush("Exports", key) - def export_actor_class(self, class_id, Class, actor_method_names, + def export_actor_class(self, Class, actor_method_names, checkpoint_interval): - key = b"ActorClass:" + class_id + function_descriptor = FunctionDescriptor.from_class(Class) + key = b"ActorClass:" + function_descriptor.function_id.id() actor_class_info = { "class_name": Class.__name__, "module": Class.__module__, @@ -318,6 +538,17 @@ class FunctionActorManager(object): # within tasks. I tried to disable this, but it may be necessary # because of https://github.com/ray-project/ray/issues/1146. + def load_actor(self, driver_id, function_descriptor): + key = b"ActorClass:" + function_descriptor.function_id.id() + # Wait for the actor class key to have been imported by the + # import thread. TODO(rkn): It shouldn't be possible to end + # up in an infinite loop here, but we should push an error to + # the driver if too much time is spent here. + while key not in self.imported_actor_classes: + time.sleep(0.001) + with self._worker.lock: + self.fetch_and_register_actor(key) + def fetch_and_register_actor(self, actor_class_key): """Import an actor. @@ -330,11 +561,10 @@ class FunctionActorManager(object): worker: The worker to use. """ actor_id_str = self._worker.actor_id - (driver_id, class_id, class_name, module, pickled_class, - checkpoint_interval, + (driver_id, class_name, module, pickled_class, checkpoint_interval, actor_method_names) = self._worker.redis_client.hmget( actor_class_key, [ - "driver_id", "class_id", "class_name", "module", "class", + "driver_id", "class_name", "module", "class", "checkpoint_interval", "actor_method_names" ]) @@ -368,9 +598,9 @@ class FunctionActorManager(object): # Register the actor method executors. for actor_method_name in actor_method_names: - function_id = ( - FunctionActorManager.compute_actor_method_function_id( - class_name, actor_method_name).id()) + function_descriptor = FunctionDescriptor(module, actor_method_name, + class_name) + function_id = function_descriptor.function_id.id() temporary_executor = self._make_actor_method_executor( actor_method_name, temporary_actor_method, @@ -409,9 +639,9 @@ class FunctionActorManager(object): actor_methods = inspect.getmembers( unpickled_class, predicate=is_function_or_method) for actor_method_name, actor_method in actor_methods: - function_id = ( - FunctionActorManager.compute_actor_method_function_id( - class_name, actor_method_name).id()) + function_descriptor = FunctionDescriptor( + module, actor_method_name, class_name) + function_id = function_descriptor.function_id.id() executor = self._make_actor_method_executor( actor_method_name, actor_method, actor_imported=True) self._function_execution_info[driver_id][function_id] = ( @@ -452,7 +682,9 @@ class FunctionActorManager(object): # If this is the first task to execute on the actor, try to resume # from a checkpoint. - if actor_imported and self._worker.actor_task_counter == 1: + # Current __init__ will be called by default. So the real function + # call will start from 2. + if actor_imported and self._worker.actor_task_counter == 2: checkpoint_resumed = ray.actor.restore_and_log_checkpoint( self._worker, actor) if checkpoint_resumed: diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py index 347f7ab9f..c477f4bcc 100644 --- a/python/ray/gcs_utils.py +++ b/python/ray/gcs_utils.py @@ -3,26 +3,26 @@ from __future__ import division from __future__ import print_function import flatbuffers - import ray.core.generated.ErrorTableData -from ray.core.generated.GcsTableEntry import GcsTableEntry from ray.core.generated.ClientTableData import ClientTableData -from ray.core.generated.ErrorTableData import ErrorTableData -from ray.core.generated.ProfileTableData import ProfileTableData -from ray.core.generated.HeartbeatTableData import HeartbeatTableData -from ray.core.generated.HeartbeatBatchTableData import HeartbeatBatchTableData from ray.core.generated.DriverTableData import DriverTableData +from ray.core.generated.ErrorTableData import ErrorTableData +from ray.core.generated.GcsTableEntry import GcsTableEntry +from ray.core.generated.HeartbeatBatchTableData import HeartbeatBatchTableData +from ray.core.generated.HeartbeatTableData import HeartbeatTableData +from ray.core.generated.Language import Language from ray.core.generated.ObjectTableData import ObjectTableData -from ray.core.generated.ray.protocol.Task import Task - +from ray.core.generated.ProfileTableData import ProfileTableData from ray.core.generated.TablePrefix import TablePrefix from ray.core.generated.TablePubsub import TablePubsub +from ray.core.generated.ray.protocol.Task import Task + __all__ = [ "GcsTableEntry", "ClientTableData", "ErrorTableData", "HeartbeatTableData", "HeartbeatBatchTableData", "DriverTableData", "ProfileTableData", - "ObjectTableData", "Task", "TablePrefix", "TablePubsub", + "ObjectTableData", "Task", "TablePrefix", "TablePubsub", "Language", "construct_error_message" ] diff --git a/python/ray/import_thread.py b/python/ray/import_thread.py index 70dba3223..08031c7b6 100644 --- a/python/ray/import_thread.py +++ b/python/ray/import_thread.py @@ -98,7 +98,7 @@ class ImportThread(object): # Keep track of the fact that this actor class has been # exported so that we know it is safe to turn this worker # into an actor of that class. - self.worker.imported_actor_classes.add(key) + self.worker.function_actor_manager.imported_actor_classes.add(key) # TODO(rkn): We may need to bring back the case of # fetching actor classes here. else: diff --git a/python/ray/ray_constants.py b/python/ray/ray_constants.py index 82aa1617d..fc89d48ed 100644 --- a/python/ray/ray_constants.py +++ b/python/ray/ray_constants.py @@ -16,6 +16,7 @@ def env_integer(key, default): ID_SIZE = 20 NIL_JOB_ID = ObjectID(ID_SIZE * b"\xff") +NIL_FUNCTION_ID = NIL_JOB_ID # If a remote function or actor (or some other export) has serialized size # greater than this quantity, print an warning. diff --git a/python/ray/remote_function.py b/python/ray/remote_function.py index fb2a29e45..b634451df 100644 --- a/python/ray/remote_function.py +++ b/python/ray/remote_function.py @@ -3,11 +3,9 @@ from __future__ import division from __future__ import print_function import copy -import hashlib -import inspect import logging -import ray.ray_constants as ray_constants +from ray.function_manager import FunctionDescriptor import ray.signature # Default parameters for remote functions. @@ -18,33 +16,6 @@ DEFAULT_REMOTE_FUNCTION_MAX_CALLS = 0 logger = logging.getLogger(__name__) -def compute_function_id(function): - """Compute an function ID for a function. - - Args: - func: The actual function. - - Returns: - Raw bytes of the function id - """ - function_id_hash = hashlib.sha1() - # Include the function module and name in the hash. - function_id_hash.update(function.__module__.encode("ascii")) - function_id_hash.update(function.__name__.encode("ascii")) - try: - # If we are running a script or are in IPython, include the source code - # in the hash. - source = inspect.getsource(function).encode("ascii") - function_id_hash.update(source) - except (IOError, OSError, TypeError): - # Source code may not be available: e.g. Cython or Python interpreter. - pass - # Compute the function ID. - function_id = function_id_hash.digest() - assert len(function_id) == ray_constants.ID_SIZE - return function_id - - class RemoteFunction(object): """A remote function. @@ -52,7 +23,7 @@ class RemoteFunction(object): Attributes: _function: The original function. - _function_id: The ID of the function. + _function_descriptor: The function descriptor. _function_name: The module and function name. _num_cpus: The default number of CPUs to use for invocations of this remote function. @@ -70,10 +41,7 @@ class RemoteFunction(object): def __init__(self, function, num_cpus, num_gpus, resources, num_return_vals, max_calls): self._function = function - # TODO(rkn): We store the function ID as a string, so that - # RemoteFunction objects can be pickled. We should undo this when - # we allow ObjectIDs to be pickled. - self._function_id = compute_function_id(function) + self._function_descriptor = FunctionDescriptor.from_function(function) self._function_name = ( self._function.__module__ + '.' + self._function.__name__) self._num_cpus = (DEFAULT_REMOTE_FUNCTION_CPUS @@ -147,7 +115,7 @@ class RemoteFunction(object): result = self._function(*copy.deepcopy(args)) return result object_ids = worker.submit_task( - ray.ObjectID(self._function_id), + self._function_descriptor, args, num_return_vals=num_return_vals, resources=resources) diff --git a/python/ray/worker.py b/python/ray/worker.py index 1cee3b714..06f774959 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -36,7 +36,7 @@ import ray.plasma import ray.ray_constants as ray_constants from ray import import_thread from ray import profiling -from ray.function_manager import FunctionActorManager +from ray.function_manager import (FunctionActorManager, FunctionDescriptor) from ray.utils import ( check_oversized_pickle, is_cython, @@ -54,7 +54,6 @@ ERROR_KEY_PREFIX = b"Error:" # This must match the definition of NIL_ACTOR_ID in task.h. NIL_ID = ray_constants.ID_SIZE * b"\xff" NIL_LOCAL_SCHEDULER_ID = NIL_ID -NIL_FUNCTION_ID = NIL_ID NIL_ACTOR_ID = NIL_ID NIL_ACTOR_HANDLE_ID = NIL_ID NIL_CLIENT_ID = ray_constants.ID_SIZE * b"\xff" @@ -161,10 +160,6 @@ class Worker(object): self.make_actor = None self.actors = {} self.actor_task_counter = 0 - # A set of all of the actor class keys that have been imported by the - # import thread. It is safe to convert this worker into an actor of - # these types. - self.imported_actor_classes = set() # The number of threads Plasma should use when putting an object in the # object store. self.memcopy_threads = 12 @@ -518,7 +513,7 @@ class Worker(object): return final_results def submit_task(self, - function_id, + function_descriptor, args, actor_id=None, actor_handle_id=None, @@ -531,15 +526,16 @@ class Worker(object): num_return_vals=None, resources=None, placement_resources=None, - driver_id=None): + driver_id=None, + language=ray.gcs_utils.Language.PYTHON): """Submit a remote task to the scheduler. - Tell the scheduler to schedule the execution of the function with ID - function_id with arguments args. Retrieve object IDs for the outputs of - the function from the scheduler and immediately return them. + Tell the scheduler to schedule the execution of the function with + function_descriptor with arguments args. Retrieve object IDs for the + outputs of the function from the scheduler and immediately return them. Args: - function_id: The ID of the function to execute. + function_descriptor: The function descriptor to execute. args: The arguments to pass into the function. Arguments can be object IDs or they can be values. If they are values, they must be serializable objects. @@ -623,13 +619,15 @@ class Worker(object): # The parent task must be set for the submitted task. assert not self.current_task_id.is_nil() # Submit the task to local scheduler. + function_descriptor_list = ( + function_descriptor.get_function_descriptor_list()) task = ray.raylet.Task( - driver_id, ray.ObjectID(function_id.id()), - args_for_local_scheduler, num_return_vals, - self.current_task_id, task_index, actor_creation_id, - actor_creation_dummy_object_id, max_actor_reconstructions, - actor_id, actor_handle_id, actor_counter, - execution_dependencies, resources, placement_resources) + driver_id, function_descriptor_list, args_for_local_scheduler, + num_return_vals, self.current_task_id, task_index, + actor_creation_id, actor_creation_dummy_object_id, + max_actor_reconstructions, actor_id, actor_handle_id, + actor_counter, execution_dependencies, resources, + placement_resources) self.raylet_client.submit_task(task) return task.returns() @@ -778,10 +776,12 @@ class Worker(object): self.task_driver_id = task.driver_id() self.current_task_id = task.task_id() - function_id = task.function_id() + function_descriptor = FunctionDescriptor.from_bytes_list( + task.function_descriptor_list()) args = task.arguments() return_object_ids = task.returns() - if task.actor_id().id() != NIL_ACTOR_ID: + if (task.actor_id().id() != NIL_ACTOR_ID + or task.actor_creation_id().id() != NIL_ACTOR_ID): dummy_return_id = return_object_ids.pop() function_executor = function_execution_info.function function_name = function_execution_info.function_name @@ -796,33 +796,36 @@ class Worker(object): function_name, args) except RayTaskError as e: self._handle_process_task_failure( - function_id, function_name, return_object_ids, e, + function_descriptor, return_object_ids, e, ray.utils.format_error_message(traceback.format_exc())) return except Exception as e: self._handle_process_task_failure( - function_id, function_name, return_object_ids, e, + function_descriptor, return_object_ids, e, ray.utils.format_error_message(traceback.format_exc())) return # Execute the task. try: with profiling.profile("task:execute", worker=self): - if task.actor_id().id() == NIL_ACTOR_ID: + if (task.actor_id().id() == NIL_ACTOR_ID + and task.actor_creation_id().id() == NIL_ACTOR_ID): outputs = function_executor(*arguments) else: - outputs = function_executor( - dummy_return_id, self.actors[task.actor_id().id()], - *arguments) + if task.actor_id().id() != NIL_ACTOR_ID: + key = task.actor_id().id() + else: + key = task.actor_creation_id().id() + outputs = function_executor(dummy_return_id, + self.actors[key], *arguments) except Exception as e: # Determine whether the exception occured during a task, not an # actor method. task_exception = task.actor_id().id() == NIL_ACTOR_ID traceback_str = ray.utils.format_error_message( traceback.format_exc(), task_exception=task_exception) - self._handle_process_task_failure(function_id, function_name, - return_object_ids, e, - traceback_str) + self._handle_process_task_failure( + function_descriptor, return_object_ids, e, traceback_str) return # Store the outputs in the local object store. @@ -837,11 +840,13 @@ class Worker(object): self._store_outputs_in_object_store(return_object_ids, outputs) except Exception as e: self._handle_process_task_failure( - function_id, function_name, return_object_ids, e, + function_descriptor, return_object_ids, e, ray.utils.format_error_message(traceback.format_exc())) - def _handle_process_task_failure(self, function_id, function_name, + def _handle_process_task_failure(self, function_descriptor, return_object_ids, error, backtrace): + function_name = function_descriptor.function_name + function_id = function_descriptor.function_id failure_object = RayTaskError(function_name, backtrace) failure_objects = [ failure_object for _ in range(len(return_object_ids)) @@ -855,53 +860,34 @@ class Worker(object): driver_id=self.task_driver_id.id(), data={ "function_id": function_id.id(), - "function_name": function_name + "function_name": function_name, + "module_name": function_descriptor.module_name, + "class_name": function_descriptor.class_name }) # Mark the actor init as failed if self.actor_id != NIL_ACTOR_ID and function_name == "__init__": self.mark_actor_init_failed(error) - def _become_actor(self, task): - """Turn this worker into an actor. - - Args: - task: The actor creation task. - """ - assert self.actor_id == NIL_ACTOR_ID - arguments = task.arguments() - assert len(arguments) == 1 - self.actor_id = task.actor_creation_id().id() - class_id = arguments[0] - - key = b"ActorClass:" + class_id - - # Wait for the actor class key to have been imported by the import - # thread. TODO(rkn): It shouldn't be possible to end up in an infinite - # loop here, but we should push an error to the driver if too much time - # is spent here. - while key not in self.imported_actor_classes: - time.sleep(0.001) - - with self.lock: - self.function_actor_manager.fetch_and_register_actor(key) - def _wait_for_and_process_task(self, task): """Wait for a task to be ready and process the task. Args: task: The task to execute. """ - function_id = task.function_id() + function_descriptor = FunctionDescriptor.from_bytes_list( + task.function_descriptor_list()) driver_id = task.driver_id().id() # TODO(rkn): It would be preferable for actor creation tasks to share # more of the code path with regular task execution. if (task.actor_creation_id() != ray.ObjectID(NIL_ACTOR_ID)): - self._become_actor(task) - return + assert self.actor_id == NIL_ACTOR_ID + self.actor_id = task.actor_creation_id().id() + self.function_actor_manager.load_actor(driver_id, + function_descriptor) execution_info = self.function_actor_manager.get_execution_info( - driver_id, function_id) + driver_id, function_descriptor) # Execute the task. # TODO(rkn): Consider acquiring this lock with a timeout and pushing a @@ -915,8 +901,14 @@ class Worker(object): "task_id": task.task_id().hex() } if task.actor_id().id() == NIL_ACTOR_ID: - title = "ray_worker:{}()".format(function_name) - next_title = "ray_worker" + if (task.actor_creation_id() == ray.ObjectID(NIL_ACTOR_ID)): + title = "ray_worker:{}()".format(function_name) + next_title = "ray_worker" + else: + actor = self.actors[task.actor_creation_id().id()] + title = "ray_{}:{}()".format(actor.__class__.__name__, + function_name) + next_title = "ray_{}".format(actor.__class__.__name__) else: actor = self.actors[task.actor_id().id()] title = "ray_{}:{}()".format(actor.__class__.__name__, @@ -934,10 +926,10 @@ class Worker(object): # Increase the task execution counter. self.function_actor_manager.increase_task_counter( - driver_id, function_id.id()) + driver_id, function_descriptor) reached_max_executions = (self.function_actor_manager.get_task_counter( - driver_id, function_id.id()) == execution_info.max_calls) + driver_id, function_descriptor) == execution_info.max_calls) if reached_max_executions: self.raylet_client.disconnect() sys.exit(0) @@ -2096,15 +2088,14 @@ def connect(info, # rerun the driver. nil_actor_counter = 0 - driver_task = ray.raylet.Task(worker.task_driver_id, - ray.ObjectID(NIL_FUNCTION_ID), [], 0, - worker.current_task_id, - worker.task_index, - ray.ObjectID(NIL_ACTOR_ID), - ray.ObjectID(NIL_ACTOR_ID), 0, - ray.ObjectID(NIL_ACTOR_ID), - ray.ObjectID(NIL_ACTOR_ID), - nil_actor_counter, [], {"CPU": 0}, {}) + function_descriptor = FunctionDescriptor.for_driver_task() + driver_task = ray.raylet.Task( + worker.task_driver_id, + function_descriptor.get_function_descriptor_list(), [], 0, + worker.current_task_id, worker.task_index, + ray.ObjectID(NIL_ACTOR_ID), ray.ObjectID(NIL_ACTOR_ID), 0, + ray.ObjectID(NIL_ACTOR_ID), ray.ObjectID(NIL_ACTOR_ID), + nil_actor_counter, [], {"CPU": 0}, {}) # Add the driver task to the task table. global_state._execute_command(driver_task.task_id(), "RAY.TABLE_ADD", diff --git a/src/ray/common/common_protocol.cc b/src/ray/common/common_protocol.cc index 5ce4c89d6..bcbfcc5f0 100644 --- a/src/ray/common/common_protocol.cc +++ b/src/ray/common/common_protocol.cc @@ -69,3 +69,25 @@ map_to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, } return fbb.CreateVector(resource_vector); } + +std::vector string_vec_from_flatbuf( + const flatbuffers::Vector> &flatbuf_vec) { + std::vector string_vector; + string_vector.reserve(flatbuf_vec.size()); + for (int64_t i = 0; i < flatbuf_vec.size(); i++) { + const auto flatbuf_str = flatbuf_vec.Get(i); + string_vector.push_back(string_from_flatbuf(*flatbuf_str)); + } + return string_vector; +} + +flatbuffers::Offset>> +string_vec_to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, + const std::vector &string_vector) { + std::vector> flatbuf_str_vec; + flatbuf_str_vec.reserve(flatbuf_str_vec.size()); + for (auto const &str : string_vector) { + flatbuf_str_vec.push_back(fbb.CreateString(str)); + } + return fbb.CreateVector(flatbuf_str_vec); +} diff --git a/src/ray/common/common_protocol.h b/src/ray/common/common_protocol.h index 3afa6b8e5..de8f27fc4 100644 --- a/src/ray/common/common_protocol.h +++ b/src/ray/common/common_protocol.h @@ -72,4 +72,10 @@ map_to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, const std::unordered_map map_from_flatbuf( const flatbuffers::Vector> &resource_vector); +std::vector string_vec_from_flatbuf( + const flatbuffers::Vector> &flatbuf_vec); + +flatbuffers::Offset>> +string_vec_to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, + const std::vector &string_vector); #endif diff --git a/src/ray/gcs/format/gcs.fbs b/src/ray/gcs/format/gcs.fbs index 99207d9b4..df4077a45 100644 --- a/src/ray/gcs/format/gcs.fbs +++ b/src/ray/gcs/format/gcs.fbs @@ -73,8 +73,6 @@ table TaskInfo { actor_counter: int; // True if this task is an actor checkpoint task and false otherwise. is_actor_checkpoint_method: bool; - // Function ID of the task. - function_id: string; // Task arguments. args: [Arg]; // Object IDs of return values. diff --git a/src/ray/raylet/lib/python/common_extension.cc b/src/ray/raylet/lib/python/common_extension.cc index f5d1c12a2..d986427cd 100644 --- a/src/ray/raylet/lib/python/common_extension.cc +++ b/src/ray/raylet/lib/python/common_extension.cc @@ -84,6 +84,32 @@ int PyObjectToUniqueID(PyObject *object, ObjectID *objectid) { } } +int PyListStringToStringVector(PyObject *object, + std::vector *function_descriptor) { + if (function_descriptor == nullptr) { + PyErr_SetString(PyExc_TypeError, "function descriptor must be non-empty pointer"); + return 0; + } + function_descriptor->clear(); + std::vector string_vector; + if (PyList_Check(object)) { + Py_ssize_t size = PyList_Size(object); + for (Py_ssize_t i = 0; i < size; ++i) { + PyObject *item = PyList_GetItem(object, i); + if (PyBytes_Check(item) == 0) { + PyErr_SetString(PyExc_TypeError, + "PyListStringToStringVector takes a list of byte strings."); + return 0; + } + function_descriptor->emplace_back(PyBytes_AsString(item), PyBytes_Size(item)); + } + return 1; + } else { + PyErr_SetString(PyExc_TypeError, "must be a list of strings"); + return 0; + } +} + static int PyObjectID_init(PyObjectID *self, PyObject *args, PyObject *kwds) { const char *data; int size; @@ -363,12 +389,12 @@ static int PyTask_init(PyTask *self, PyObject *args, PyObject *kwds) { UniqueID actor_handle_id; // How many tasks have been launched on the actor so far? int actor_counter = 0; - // ID of the function this task executes. - FunctionID function_id; // Arguments of the task (can be PyObjectIDs or Python values). PyObject *arguments; // Number of return values of this task. int num_returns; + // Task language type enum number. + int language = static_cast(Language::PYTHON); // The ID of the task that called this task. TaskID parent_task_id; // The number of tasks that the parent task has called prior to this one. @@ -387,14 +413,17 @@ static int PyTask_init(PyTask *self, PyObject *args, PyObject *kwds) { PyObject *resource_map = nullptr; // Dictionary of required placement resources for this task. PyObject *placement_resource_map = nullptr; - if (!PyArg_ParseTuple(args, "O&O&OiO&i|O&O&iO&O&iOOO", &PyObjectToUniqueID, &driver_id, - &PyObjectToUniqueID, &function_id, &arguments, &num_returns, - &PyObjectToUniqueID, &parent_task_id, &parent_counter, - &PyObjectToUniqueID, &actor_creation_id, &PyObjectToUniqueID, - &actor_creation_dummy_object_id, &max_actor_reconstructions, - &PyObjectToUniqueID, &actor_id, &PyObjectToUniqueID, - &actor_handle_id, &actor_counter, &execution_arguments, - &resource_map, &placement_resource_map)) { + + // Function descriptor. + std::vector function_descriptor; + if (!PyArg_ParseTuple( + args, "O&O&OiO&i|O&O&iO&O&iOOOi", &PyObjectToUniqueID, &driver_id, + &PyListStringToStringVector, &function_descriptor, &arguments, &num_returns, + &PyObjectToUniqueID, &parent_task_id, &parent_counter, &PyObjectToUniqueID, + &actor_creation_id, &PyObjectToUniqueID, &actor_creation_dummy_object_id, + &max_actor_reconstructions, &PyObjectToUniqueID, &actor_id, &PyObjectToUniqueID, + &actor_handle_id, &actor_counter, &execution_arguments, &resource_map, + &placement_resource_map, &language)) { return -1; } @@ -424,6 +453,7 @@ static int PyTask_init(PyTask *self, PyObject *args, PyObject *kwds) { self->task_spec = nullptr; // Create the task spec. + // Parse the arguments from the list. std::vector> task_args; for (Py_ssize_t i = 0; i < num_args; ++i) { @@ -444,8 +474,8 @@ static int PyTask_init(PyTask *self, PyObject *args, PyObject *kwds) { self->task_spec = new ray::raylet::TaskSpecification( driver_id, parent_task_id, parent_counter, actor_creation_id, actor_creation_dummy_object_id, max_actor_reconstructions, actor_id, - actor_handle_id, actor_counter, function_id, task_args, num_returns, - required_resources, required_placement_resources, Language::PYTHON); + actor_handle_id, actor_counter, task_args, num_returns, required_resources, + required_placement_resources, Language::PYTHON, function_descriptor); /* Set the task's execution dependencies. */ self->execution_dependencies = new std::vector(); @@ -470,9 +500,23 @@ static void PyTask_dealloc(PyTask *self) { Py_TYPE(self)->tp_free(reinterpret_cast(self)); } -static PyObject *PyTask_function_id(PyTask *self) { - FunctionID function_id = self->task_spec->FunctionId(); - return PyObjectID_make(function_id); +// Helper function to change a c++ string vector to a Python string list. +static PyObject *VectorStringToPyBytesList( + const std::vector &function_descriptor) { + size_t size = function_descriptor.size(); + PyObject *return_list = PyList_New(static_cast(size)); + for (size_t i = 0; i < size; ++i) { + auto py_bytes = PyBytes_FromStringAndSize(function_descriptor[i].data(), + function_descriptor[i].size()); + PyList_SetItem(return_list, i, py_bytes); + } + return return_list; +} + +static PyObject *PyTask_function_descriptor_vector(PyTask *self) { + std::vector function_descriptor; + function_descriptor = self->task_spec->FunctionDescriptor(); + return VectorStringToPyBytesList(function_descriptor); } static PyObject *PyTask_actor_id(PyTask *self) { @@ -597,8 +641,8 @@ static PyObject *PyTask_to_serialized_flatbuf(PyTask *self) { } static PyMethodDef PyTask_methods[] = { - {"function_id", (PyCFunction)PyTask_function_id, METH_NOARGS, - "Return the function ID for this task."}, + {"function_descriptor_list", (PyCFunction)PyTask_function_descriptor_vector, + METH_NOARGS, "Return the function descriptor for this task."}, {"parent_task_id", (PyCFunction)PyTask_parent_task_id, METH_NOARGS, "Return the task ID of the parent task."}, {"parent_counter", (PyCFunction)PyTask_parent_counter, METH_NOARGS, diff --git a/src/ray/raylet/lineage_cache_test.cc b/src/ray/raylet/lineage_cache_test.cc index edfb0db69..32a0e5932 100644 --- a/src/ray/raylet/lineage_cache_test.cc +++ b/src/ray/raylet/lineage_cache_test.cc @@ -112,9 +112,10 @@ static inline Task ExampleTask(const std::vector &arguments, std::vector references = {argument}; task_arguments.emplace_back(std::make_shared(references)); } + std::vector function_descriptor(3); auto spec = TaskSpecification(UniqueID::nil(), UniqueID::from_random(), 0, - UniqueID::from_random(), task_arguments, num_returns, - required_resources, Language::PYTHON); + task_arguments, num_returns, required_resources, + Language::PYTHON, function_descriptor); auto execution_spec = TaskExecutionSpecification(std::vector()); execution_spec.IncrementNumForwards(); Task task = Task(execution_spec, spec); diff --git a/src/ray/raylet/task_dependency_manager_test.cc b/src/ray/raylet/task_dependency_manager_test.cc index 56a97dcb0..1e0528317 100644 --- a/src/ray/raylet/task_dependency_manager_test.cc +++ b/src/ray/raylet/task_dependency_manager_test.cc @@ -74,9 +74,10 @@ static inline Task ExampleTask(const std::vector &arguments, std::vector references = {argument}; task_arguments.emplace_back(std::make_shared(references)); } + std::vector function_descriptor(3); auto spec = TaskSpecification(UniqueID::nil(), UniqueID::from_random(), 0, - UniqueID::from_random(), task_arguments, num_returns, - required_resources, Language::PYTHON); + task_arguments, num_returns, required_resources, + Language::PYTHON, function_descriptor); auto execution_spec = TaskExecutionSpecification(std::vector()); execution_spec.IncrementNumForwards(); Task task = Task(execution_spec, spec); diff --git a/src/ray/raylet/task_spec.cc b/src/ray/raylet/task_spec.cc index 540df9b63..7e33a9acf 100644 --- a/src/ray/raylet/task_spec.cc +++ b/src/ray/raylet/task_spec.cc @@ -56,25 +56,24 @@ TaskSpecification::TaskSpecification(const std::string &string) { TaskSpecification::TaskSpecification( const UniqueID &driver_id, const TaskID &parent_task_id, int64_t parent_counter, - const FunctionID &function_id, const std::vector> &task_arguments, int64_t num_returns, const std::unordered_map &required_resources, - const Language &language) + const Language &language, const std::vector &function_descriptor) : TaskSpecification(driver_id, parent_task_id, parent_counter, ActorID::nil(), ObjectID::nil(), 0, ActorID::nil(), ActorHandleID::nil(), -1, - function_id, task_arguments, num_returns, required_resources, - std::unordered_map(), language) {} + task_arguments, num_returns, required_resources, + std::unordered_map(), language, + function_descriptor) {} TaskSpecification::TaskSpecification( const UniqueID &driver_id, const TaskID &parent_task_id, int64_t parent_counter, const ActorID &actor_creation_id, const ObjectID &actor_creation_dummy_object_id, const int64_t max_actor_reconstructions, const ActorID &actor_id, const ActorHandleID &actor_handle_id, int64_t actor_counter, - const FunctionID &function_id, const std::vector> &task_arguments, int64_t num_returns, const std::unordered_map &required_resources, const std::unordered_map &required_placement_resources, - const Language &language) + const Language &language, const std::vector &function_descriptor) : spec_() { flatbuffers::FlatBufferBuilder fbb; @@ -99,9 +98,10 @@ TaskSpecification::TaskSpecification( to_flatbuf(fbb, parent_task_id), parent_counter, to_flatbuf(fbb, actor_creation_id), to_flatbuf(fbb, actor_creation_dummy_object_id), max_actor_reconstructions, to_flatbuf(fbb, actor_id), to_flatbuf(fbb, actor_handle_id), actor_counter, false, - to_flatbuf(fbb, function_id), fbb.CreateVector(arguments), - fbb.CreateVector(returns), map_to_flatbuf(fbb, required_resources), - map_to_flatbuf(fbb, required_placement_resources), language); + fbb.CreateVector(arguments), fbb.CreateVector(returns), + map_to_flatbuf(fbb, required_resources), + map_to_flatbuf(fbb, required_placement_resources), language, + string_vec_to_flatbuf(fbb, function_descriptor)); fbb.Finish(spec); AssignSpecification(fbb.GetBufferPointer(), fbb.GetSize()); } @@ -134,9 +134,9 @@ int64_t TaskSpecification::ParentCounter() const { auto message = flatbuffers::GetRoot(spec_.data()); return message->parent_counter(); } -FunctionID TaskSpecification::FunctionId() const { +std::vector TaskSpecification::FunctionDescriptor() const { auto message = flatbuffers::GetRoot(spec_.data()); - return from_flatbuf(*message->function_id()); + return string_vec_from_flatbuf(*message->function_descriptor()); } int64_t TaskSpecification::NumArgs() const { @@ -197,7 +197,7 @@ const ResourceSet TaskSpecification::GetRequiredPlacementResources() const { bool TaskSpecification::IsDriverTask() const { // Driver tasks are empty tasks that have no function ID set. - return FunctionId().is_nil(); + return FunctionDescriptor().empty(); } Language TaskSpecification::GetLanguage() const { diff --git a/src/ray/raylet/task_spec.h b/src/ray/raylet/task_spec.h index 2799f0568..da33275e9 100644 --- a/src/ray/raylet/task_spec.h +++ b/src/ray/raylet/task_spec.h @@ -91,17 +91,18 @@ class TaskSpecification { /// \param parent_task_id The task ID of the task that spawned this task. /// \param parent_counter The number of tasks that this task's parent spawned /// before this task. - /// \param function_id The ID of the function this task should execute. + /// \param function_descriptor The function descriptor. /// \param task_arguments The list of task arguments. /// \param num_returns The number of values returned by the task. /// \param required_resources The task's resource demands. /// \param language The language of the worker that must execute the function. TaskSpecification(const UniqueID &driver_id, const TaskID &parent_task_id, - int64_t parent_counter, const FunctionID &function_id, + int64_t parent_counter, const std::vector> &task_arguments, int64_t num_returns, const std::unordered_map &required_resources, - const Language &language); + const Language &language, + const std::vector &function_descriptor); // TODO(swang): Define an actor task constructor. /// Create a task specification from the raw fields. @@ -119,7 +120,6 @@ class TaskSpecification { /// task. If this is not an actor task, then this is nil. /// \param actor_counter The number of tasks submitted before this task from /// the same actor handle. If this is not an actor task, then this is 0. - /// \param function_id The ID of the function this task should execute. /// \param task_arguments The list of task arguments. /// \param num_returns The number of values returned by the task. /// \param required_resources The task's resource demands. @@ -127,17 +127,17 @@ class TaskSpecification { /// task on a node. Typically, this should be an empty map in which case it /// will default to be equal to the required_resources argument. /// \param language The language of the worker that must execute the function. + /// \param function_descriptor The function descriptor. TaskSpecification( const UniqueID &driver_id, const TaskID &parent_task_id, int64_t parent_counter, const ActorID &actor_creation_id, const ObjectID &actor_creation_dummy_object_id, int64_t max_actor_reconstructions, const ActorID &actor_id, const ActorHandleID &actor_handle_id, int64_t actor_counter, - const FunctionID &function_id, const std::vector> &task_arguments, int64_t num_returns, const std::unordered_map &required_resources, const std::unordered_map &required_placement_resources, - const Language &language); + const Language &language, const std::vector &function_descriptor); /// Deserialize a task specification from a flatbuffer's string data. /// @@ -159,7 +159,7 @@ class TaskSpecification { UniqueID DriverId() const; TaskID ParentTaskId() const; int64_t ParentCounter() const; - FunctionID FunctionId() const; + std::vector FunctionDescriptor() const; int64_t NumArgs() const; int64_t NumReturns() const; bool ArgByRef(int64_t arg_index) const; diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index 3933ec76c..9b228457b 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -63,9 +63,10 @@ class WorkerPoolTest : public ::testing::Test { static inline TaskSpecification ExampleTaskSpec( const ActorID actor_id = ActorID::nil(), const Language &language = Language::PYTHON) { + std::vector function_descriptor(3); return TaskSpecification(UniqueID::nil(), UniqueID::nil(), 0, ActorID::nil(), - ObjectID::nil(), 0, actor_id, ActorHandleID::nil(), 0, - FunctionID::nil(), {}, 0, {{}}, {{}}, language); + ObjectID::nil(), 0, actor_id, ActorHandleID::nil(), 0, {}, 0, + {{}}, {{}}, language, function_descriptor); } TEST_F(WorkerPoolTest, HandleWorkerRegistration) {