Enable function_descriptor in backend to replace the function_id (#3028)

This commit is contained in:
Yuhong Guo
2018-12-19 07:53:59 +08:00
committed by Robert Nishihara
parent 3822b20319
commit fb33fa9097
20 changed files with 557 additions and 282 deletions
+1
View File
@@ -50,6 +50,7 @@ MOCK_MODULES = [
"ray.core.generated.ray.protocol.Task",
"ray.core.generated.TablePrefix",
"ray.core.generated.TablePubsub",
"ray.core.generated.Language",
]
for mod_name in MOCK_MODULES:
sys.modules[mod_name] = mock.Mock()
@@ -38,23 +38,20 @@ public final class TaskInfo extends Table {
public ByteBuffer actorHandleIdInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 20, 1); }
public int actorCounter() { int o = __offset(22); return o != 0 ? bb.getInt(o + bb_pos) : 0; }
public boolean isActorCheckpointMethod() { int o = __offset(24); return o != 0 ? 0!=bb.get(o + bb_pos) : false; }
public String functionId() { int o = __offset(26); return o != 0 ? __string(o + bb_pos) : null; }
public ByteBuffer functionIdAsByteBuffer() { return __vector_as_bytebuffer(26, 1); }
public ByteBuffer functionIdInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 26, 1); }
public Arg args(int j) { return args(new Arg(), j); }
public Arg args(Arg obj, int j) { int o = __offset(28); return o != 0 ? obj.__assign(__indirect(__vector(o) + j * 4), bb) : null; }
public int argsLength() { int o = __offset(28); return o != 0 ? __vector_len(o) : 0; }
public String returns(int j) { int o = __offset(30); return o != 0 ? __string(__vector(o) + j * 4) : null; }
public int returnsLength() { int o = __offset(30); return o != 0 ? __vector_len(o) : 0; }
public Arg args(Arg obj, int j) { int o = __offset(26); return o != 0 ? obj.__assign(__indirect(__vector(o) + j * 4), bb) : null; }
public int argsLength() { int o = __offset(26); return o != 0 ? __vector_len(o) : 0; }
public String returns(int j) { int o = __offset(28); return o != 0 ? __string(__vector(o) + j * 4) : null; }
public int returnsLength() { int o = __offset(28); return o != 0 ? __vector_len(o) : 0; }
public ResourcePair requiredResources(int j) { return requiredResources(new ResourcePair(), j); }
public ResourcePair requiredResources(ResourcePair obj, int j) { int o = __offset(32); return o != 0 ? obj.__assign(__indirect(__vector(o) + j * 4), bb) : null; }
public int requiredResourcesLength() { int o = __offset(32); return o != 0 ? __vector_len(o) : 0; }
public ResourcePair requiredResources(ResourcePair obj, int j) { int o = __offset(30); return o != 0 ? obj.__assign(__indirect(__vector(o) + j * 4), bb) : null; }
public int requiredResourcesLength() { int o = __offset(30); return o != 0 ? __vector_len(o) : 0; }
public ResourcePair requiredPlacementResources(int j) { return requiredPlacementResources(new ResourcePair(), j); }
public ResourcePair requiredPlacementResources(ResourcePair obj, int j) { int o = __offset(34); return o != 0 ? obj.__assign(__indirect(__vector(o) + j * 4), bb) : null; }
public int requiredPlacementResourcesLength() { int o = __offset(34); return o != 0 ? __vector_len(o) : 0; }
public int language() { int o = __offset(36); return o != 0 ? bb.getInt(o + bb_pos) : 0; }
public String functionDescriptor(int j) { int o = __offset(38); return o != 0 ? __string(__vector(o) + j * 4) : null; }
public int functionDescriptorLength() { int o = __offset(38); return o != 0 ? __vector_len(o) : 0; }
public ResourcePair requiredPlacementResources(ResourcePair obj, int j) { int o = __offset(32); return o != 0 ? obj.__assign(__indirect(__vector(o) + j * 4), bb) : null; }
public int requiredPlacementResourcesLength() { int o = __offset(32); return o != 0 ? __vector_len(o) : 0; }
public int language() { int o = __offset(34); return o != 0 ? bb.getInt(o + bb_pos) : 0; }
public String functionDescriptor(int j) { int o = __offset(36); return o != 0 ? __string(__vector(o) + j * 4) : null; }
public int functionDescriptorLength() { int o = __offset(36); return o != 0 ? __vector_len(o) : 0; }
public static int createTaskInfo(FlatBufferBuilder builder,
int driver_idOffset,
@@ -68,21 +65,19 @@ public final class TaskInfo extends Table {
int actor_handle_idOffset,
int actor_counter,
boolean is_actor_checkpoint_method,
int function_idOffset,
int argsOffset,
int returnsOffset,
int required_resourcesOffset,
int required_placement_resourcesOffset,
int language,
int function_descriptorOffset) {
builder.startObject(18);
builder.startObject(17);
TaskInfo.addFunctionDescriptor(builder, function_descriptorOffset);
TaskInfo.addLanguage(builder, language);
TaskInfo.addRequiredPlacementResources(builder, required_placement_resourcesOffset);
TaskInfo.addRequiredResources(builder, required_resourcesOffset);
TaskInfo.addReturns(builder, returnsOffset);
TaskInfo.addArgs(builder, argsOffset);
TaskInfo.addFunctionId(builder, function_idOffset);
TaskInfo.addActorCounter(builder, actor_counter);
TaskInfo.addActorHandleId(builder, actor_handle_idOffset);
TaskInfo.addActorId(builder, actor_idOffset);
@@ -97,7 +92,7 @@ public final class TaskInfo extends Table {
return TaskInfo.endTaskInfo(builder);
}
public static void startTaskInfo(FlatBufferBuilder builder) { builder.startObject(18); }
public static void startTaskInfo(FlatBufferBuilder builder) { builder.startObject(17); }
public static void addDriverId(FlatBufferBuilder builder, int driverIdOffset) { builder.addOffset(0, driverIdOffset, 0); }
public static void addTaskId(FlatBufferBuilder builder, int taskIdOffset) { builder.addOffset(1, taskIdOffset, 0); }
public static void addParentTaskId(FlatBufferBuilder builder, int parentTaskIdOffset) { builder.addOffset(2, parentTaskIdOffset, 0); }
@@ -109,21 +104,20 @@ public final class TaskInfo extends Table {
public static void addActorHandleId(FlatBufferBuilder builder, int actorHandleIdOffset) { builder.addOffset(8, actorHandleIdOffset, 0); }
public static void addActorCounter(FlatBufferBuilder builder, int actorCounter) { builder.addInt(9, actorCounter, 0); }
public static void addIsActorCheckpointMethod(FlatBufferBuilder builder, boolean isActorCheckpointMethod) { builder.addBoolean(10, isActorCheckpointMethod, false); }
public static void addFunctionId(FlatBufferBuilder builder, int functionIdOffset) { builder.addOffset(11, functionIdOffset, 0); }
public static void addArgs(FlatBufferBuilder builder, int argsOffset) { builder.addOffset(12, argsOffset, 0); }
public static void addArgs(FlatBufferBuilder builder, int argsOffset) { builder.addOffset(11, argsOffset, 0); }
public static int createArgsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
public static void startArgsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
public static void addReturns(FlatBufferBuilder builder, int returnsOffset) { builder.addOffset(13, returnsOffset, 0); }
public static void addReturns(FlatBufferBuilder builder, int returnsOffset) { builder.addOffset(12, returnsOffset, 0); }
public static int createReturnsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
public static void startReturnsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
public static void addRequiredResources(FlatBufferBuilder builder, int requiredResourcesOffset) { builder.addOffset(14, requiredResourcesOffset, 0); }
public static void addRequiredResources(FlatBufferBuilder builder, int requiredResourcesOffset) { builder.addOffset(13, requiredResourcesOffset, 0); }
public static int createRequiredResourcesVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
public static void startRequiredResourcesVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
public static void addRequiredPlacementResources(FlatBufferBuilder builder, int requiredPlacementResourcesOffset) { builder.addOffset(15, requiredPlacementResourcesOffset, 0); }
public static void addRequiredPlacementResources(FlatBufferBuilder builder, int requiredPlacementResourcesOffset) { builder.addOffset(14, requiredPlacementResourcesOffset, 0); }
public static int createRequiredPlacementResourcesVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
public static void startRequiredPlacementResourcesVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
public static void addLanguage(FlatBufferBuilder builder, int language) { builder.addInt(16, language, 0); }
public static void addFunctionDescriptor(FlatBufferBuilder builder, int functionDescriptorOffset) { builder.addOffset(17, functionDescriptorOffset, 0); }
public static void addLanguage(FlatBufferBuilder builder, int language) { builder.addInt(15, language, 0); }
public static void addFunctionDescriptor(FlatBufferBuilder builder, int functionDescriptorOffset) { builder.addOffset(16, functionDescriptorOffset, 0); }
public static int createFunctionDescriptorVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
public static void startFunctionDescriptorVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
public static int endTaskInfo(FlatBufferBuilder builder) {
@@ -131,10 +125,14 @@ public final class TaskInfo extends Table {
return o;
}
//this is manually added to avoid encoding/decoding cost as our object
//id is a byte array instead of a string
/** This is manually added to avoid encoding/decoding cost as our object
* id is a byte array instead of a string.
* This function is error-prone. If the fields before `returns` changed,
* the offset number should be changed accordingly.
* TODO(yuhguo): fix this error-prone funciton.
*/
public ByteBuffer returnsAsByteBuffer(int j) {
int o = __offset(30);
int o = __offset(28);
if (o == 0) {
return null;
}
@@ -183,7 +183,6 @@ public class RayletClientImpl implements RayletClient {
final int actorIdOffset = fbb.createString(task.actorId.toByteBuffer());
final int actorHandleIdOffset = fbb.createString(task.actorHandleId.toByteBuffer());
final int actorCounter = task.actorCounter;
final int functionIdOffset = fbb.createString(UniqueId.randomId().toByteBuffer());
// Serialize args
int[] argsOffsets = new int[task.args.length];
for (int i = 0; i < argsOffsets.length; i++) {
@@ -245,7 +244,6 @@ public class RayletClientImpl implements RayletClient {
actorHandleIdOffset,
actorCounter,
false,
functionIdOffset,
argsOffset,
returnsOffset,
requiredResourcesOffset,
+46 -41
View File
@@ -10,7 +10,7 @@ import sys
import traceback
import ray.cloudpickle as pickle
from ray.function_manager import FunctionActorManager
from ray.function_manager import FunctionDescriptor
import ray.raylet
import ray.ray_constants as ray_constants
import ray.signature as signature
@@ -76,18 +76,6 @@ def compute_actor_handle_id_non_forked(actor_id, actor_handle_id,
return ray.ObjectID(handle_id)
def compute_actor_creation_function_id(class_id):
"""Compute the function ID for an actor creation task.
Args:
class_id: The ID of the actor class.
Returns:
The function ID of the actor creation event.
"""
return ray.ObjectID(class_id)
def set_actor_checkpoint(worker, actor_id, checkpoint_index, checkpoint,
frontier):
"""Set the most recent checkpoint associated with a given actor ID.
@@ -287,6 +275,23 @@ class ActorClass(object):
self._actor_methods = inspect.getmembers(
self._modified_class, ray.utils.is_function_or_method)
self._actor_method_names = [
method_name for method_name, _ in self._actor_methods
]
constructor_name = "__init__"
if constructor_name not in self._actor_method_names:
# Add __init__ if it does not exist.
# Actor creation will be executed with __init__ together.
# Assign an __init__ function will avoid many checks later on.
def __init__(self):
pass
self._modified_class.__init__ = __init__
self._actor_method_names.append(constructor_name)
self._actor_methods.append((constructor_name, __init__))
# Extract the signatures of each of the methods. This will be used
# to catch some errors if the methods are called with inappropriate
# arguments.
@@ -300,7 +305,6 @@ class ActorClass(object):
signature.check_signature_supported(method, warn=True)
self._method_signatures[method_name] = signature.extract_signature(
method, ignore_first=not ray.utils.is_class_method(method))
# Set the default number of return values for this method.
if hasattr(method, "__ray_num_return_vals__"):
self._actor_method_num_return_vals[method_name] = (
@@ -309,10 +313,6 @@ class ActorClass(object):
self._actor_method_num_return_vals[method_name] = (
DEFAULT_ACTOR_METHOD_NUM_RETURN_VALS)
self._actor_method_names = [
method_name for method_name, _ in self._actor_methods
]
def __call__(self, *args, **kwargs):
raise Exception("Actors methods cannot be instantiated directly. "
"Instead of running '{}()', try '{}.remote()'.".format(
@@ -386,14 +386,14 @@ class ActorClass(object):
# Instead, instantiate the actor locally and add it to the worker's
# dictionary
if worker.mode == ray.LOCAL_MODE:
worker.actors[actor_id] = self._modified_class.__new__(
self._modified_class)
worker.actors[actor_id] = self._modified_class(
*copy.deepcopy(args), **copy.deepcopy(kwargs))
else:
# Export the actor.
if not self._exported:
worker.function_actor_manager.export_actor_class(
self._class_id, self._modified_class,
self._actor_method_names, self._checkpoint_interval)
self._modified_class, self._actor_method_names,
self._checkpoint_interval)
self._exported = True
resources = ray.utils.resources_from_resource_arguments(
@@ -409,10 +409,19 @@ class ActorClass(object):
actor_placement_resources = resources.copy()
actor_placement_resources["CPU"] += 1
creation_args = [self._class_id]
function_id = compute_actor_creation_function_id(self._class_id)
if args is None:
args = []
if kwargs is None:
kwargs = {}
function_name = "__init__"
function_signature = self._method_signatures[function_name]
creation_args = signature.extend_args(function_signature, args,
kwargs)
function_descriptor = FunctionDescriptor(
self._modified_class.__module__, function_name,
self._modified_class.__name__)
[actor_cursor] = worker.submit_task(
function_id,
function_descriptor,
creation_args,
actor_creation_id=actor_id,
max_actor_reconstructions=self._max_reconstructions,
@@ -424,19 +433,10 @@ class ActorClass(object):
# creation task.
actor_counter = 1
actor_handle = ActorHandle(
actor_id, self._class_name, actor_cursor, actor_counter,
self._actor_method_names, self._method_signatures,
self._actor_method_num_return_vals, actor_cursor,
self._actor_method_cpus, worker.task_driver_id)
# Call __init__ as a remote function.
if "__init__" in actor_handle._ray_actor_method_names:
actor_handle.__init__.remote(*args, **kwargs)
else:
if len(args) != 0 or len(kwargs) != 0:
raise Exception("Arguments cannot be passed to the actor "
"constructor because this actor class has no "
"__init__ method.")
actor_id, self._modified_class.__module__, self._class_name,
actor_cursor, actor_counter, self._actor_method_names,
self._method_signatures, self._actor_method_num_return_vals,
actor_cursor, self._actor_method_cpus, worker.task_driver_id)
return actor_handle
@@ -458,6 +458,7 @@ class ActorHandle(object):
Attributes:
_ray_actor_id: The ID of the corresponding actor.
_ray_module_name: The module name of this actor.
_ray_actor_handle_id: The ID of this handle. If this is the "original"
handle for an actor (as opposed to one created by passing another
handle into a task), then this ID must be NIL_ID. If this
@@ -496,6 +497,7 @@ class ActorHandle(object):
def __init__(self,
actor_id,
module_name,
class_name,
actor_cursor,
actor_counter,
@@ -510,6 +512,7 @@ class ActorHandle(object):
# False if this actor handle was created by forking or pickling. True
# if it was created by the _serialization_helper function.
self._ray_original_handle = previous_actor_handle_id is None
self._ray_module_name = module_name
self._ray_actor_id = actor_id
if self._ray_original_handle:
@@ -602,10 +605,10 @@ class ActorHandle(object):
else:
actor_handle_id = self._ray_actor_handle_id
function_id = FunctionActorManager.compute_actor_method_function_id(
self._ray_class_name, method_name)
function_descriptor = FunctionDescriptor(
self._ray_module_name, method_name, self._ray_class_name)
object_ids = worker.submit_task(
function_id,
function_descriptor,
args,
actor_id=self._ray_actor_id,
actor_handle_id=actor_handle_id,
@@ -706,6 +709,7 @@ class ActorHandle(object):
"""
state = {
"actor_id": self._ray_actor_id.id(),
"module_name": self._ray_module_name,
"class_name": self._ray_class_name,
"actor_forks": self._ray_actor_forks,
"actor_cursor": self._ray_actor_cursor.id()
@@ -753,6 +757,7 @@ class ActorHandle(object):
self.__init__(
ray.ObjectID(state["actor_id"]),
state["module_name"],
state["class_name"],
ray.ObjectID(state["actor_cursor"])
if state["actor_cursor"] is not None else None,
+10 -2
View File
@@ -9,6 +9,7 @@ import sys
import time
import ray
from ray.function_manager import FunctionDescriptor
import ray.gcs_utils
import ray.ray_constants as ray_constants
from ray.utils import (decode, binary_to_object_id, binary_to_hex,
@@ -234,6 +235,9 @@ class GlobalState(object):
execution_spec = task_table_message.TaskExecutionSpec()
task_spec = task_table_message.TaskSpecification()
task_spec = ray.raylet.task_from_string(task_spec)
function_descriptor_list = task_spec.function_descriptor_list()
function_descriptor = FunctionDescriptor.from_bytes_list(
function_descriptor_list)
task_spec_info = {
"DriverID": binary_to_hex(task_spec.driver_id().id()),
"TaskID": binary_to_hex(task_spec.task_id().id()),
@@ -245,10 +249,14 @@ class GlobalState(object):
"ActorCreationDummyObjectID": binary_to_hex(
task_spec.actor_creation_dummy_object_id().id()),
"ActorCounter": task_spec.actor_counter(),
"FunctionID": binary_to_hex(task_spec.function_id().id()),
"Args": task_spec.arguments(),
"ReturnObjectIDs": task_spec.returns(),
"RequiredResources": task_spec.required_resources()
"RequiredResources": task_spec.required_resources(),
"FunctionID": binary_to_hex(function_descriptor.function_id.id()),
"FunctionHash": binary_to_hex(function_descriptor.function_hash),
"ModuleName": function_descriptor.module_name,
"ClassName": function_descriptor.class_name,
"FunctionName": function_descriptor.function_name,
}
return {
+278 -46
View File
@@ -5,6 +5,7 @@ from __future__ import print_function
import hashlib
import inspect
import json
import logging
import sys
import time
import traceback
@@ -18,6 +19,7 @@ from ray import profiling
from ray import ray_constants
from ray import cloudpickle as pickle
from ray.utils import (
binary_to_hex,
is_cython,
is_function_or_method,
is_class_method,
@@ -31,6 +33,228 @@ FunctionExecutionInfo = namedtuple("FunctionExecutionInfo",
["function", "function_name", "max_calls"])
"""FunctionExecutionInfo: A named tuple storing remote function information."""
logger = logging.getLogger(__name__)
class FunctionDescriptor(object):
"""A class used to describe a python function.
Attributes:
module_name: the module name that the function belongs to.
class_name: the class name that the function belongs to if exists.
It could be empty is the function is not a class method.
function_name: the function name of the function.
function_hash: the hash code of the function source code if the
function code is available.
function_id: the function id calculated from this descriptor.
is_for_driver_task: whether this descriptor is for driver task.
"""
def __init__(self,
module_name,
function_name,
class_name="",
function_source_hash=b""):
self._module_name = module_name
self._class_name = class_name
self._function_name = function_name
self._function_source_hash = function_source_hash
self._function_id = self._get_function_id()
def __repr__(self):
return ("FunctionDescriptor:" + self._module_name + "." +
self._class_name + "." + self._function_name + "." +
binary_to_hex(self._function_source_hash))
@classmethod
def from_bytes_list(cls, function_descriptor_list):
"""Create a FunctionDescriptor instance from list of bytes.
This function is used to create the function descriptor from
backend data.
Args:
cls: Current class which is required argument for classmethod.
function_descriptor_list: list of bytes to represent the
function descriptor.
Returns:
The FunctionDescriptor instance created from the bytes list.
"""
assert isinstance(function_descriptor_list, list)
if len(function_descriptor_list) == 0:
# This is a function descriptor of driver task.
return FunctionDescriptor.for_driver_task()
elif (len(function_descriptor_list) == 3
or len(function_descriptor_list) == 4):
module_name = function_descriptor_list[0].decode()
class_name = function_descriptor_list[1].decode()
function_name = function_descriptor_list[2].decode()
if len(function_descriptor_list) == 4:
return cls(module_name, function_name, class_name,
function_descriptor_list[3])
else:
return cls(module_name, function_name, class_name)
else:
raise Exception(
"Invalid input for FunctionDescriptor.from_bytes_list")
@classmethod
def from_function(cls, function):
"""Create a FunctionDescriptor from a function instance.
This function is used to create the function descriptor from
a python function. If a function is a class function, it should
not be used by this function.
Args:
cls: Current class which is required argument for classmethod.
function: the python function used to create the function
descriptor.
Returns:
The FunctionDescriptor instance created according to the function.
"""
module_name = function.__module__
function_name = function.__name__
class_name = ""
function_source_hasher = hashlib.sha1()
try:
# If we are running a script or are in IPython, include the source
# code in the hash.
source = inspect.getsource(function).encode("ascii")
function_source_hasher.update(source)
function_source_hash = function_source_hasher.digest()
except (IOError, OSError, TypeError):
# Source code may not be available:
# e.g. Cython or Python interpreter.
function_source_hash = b""
return cls(module_name, function_name, class_name,
function_source_hash)
@classmethod
def from_class(cls, target_class):
"""Create a FunctionDescriptor from a class.
Args:
cls: Current class which is required argument for classmethod.
target_class: the python class used to create the function
descriptor.
Returns:
The FunctionDescriptor instance created according to the class.
"""
module_name = target_class.__module__
class_name = target_class.__name__
return cls(module_name, "__init__", class_name)
@classmethod
def for_driver_task(cls):
"""Create a FunctionDescriptor instance for a driver task."""
return cls("", "", "", b"")
@property
def is_for_driver_task(self):
"""See whether this function descriptor is for a driver or not.
Returns:
True if this function descriptor is for driver tasks.
"""
return all(
len(x) == 0
for x in [self.module_name, self.class_name, self.function_name])
@property
def module_name(self):
"""Get the module name of current function descriptor.
Returns:
The module name of the function descriptor.
"""
return self._module_name
@property
def class_name(self):
"""Get the class name of current function descriptor.
Returns:
The class name of the function descriptor. It could be
empty if the function is not a class method.
"""
return self._class_name
@property
def function_name(self):
"""Get the function name of current function descriptor.
Returns:
The function name of the function descriptor.
"""
return self._function_name
@property
def function_hash(self):
"""Get the hash code of the function source code.
Returns:
The bytes with length of ray_constants.ID_SIZE if the source
code is available. Otherwise, the bytes length will be 0.
"""
return self._function_source_hash
@property
def function_id(self):
"""Get the function id calculated from this descriptor.
Returns:
The value of ray.ObjectID that represents the function id.
"""
return ray.ObjectID(self._function_id)
def _get_function_id(self):
"""Calculate the function id of current function descriptor.
This function id is calculated from all the fields of function
descriptor.
Returns:
bytes with length of ray_constants.ID_SIZE.
"""
if self.is_for_driver_task:
return ray_constants.NIL_FUNCTION_ID.id()
function_id_hash = hashlib.sha1()
# Include the function module and name in the hash.
function_id_hash.update(self.module_name.encode("ascii"))
function_id_hash.update(self.function_name.encode("ascii"))
function_id_hash.update(self.class_name.encode("ascii"))
function_id_hash.update(self._function_source_hash)
# Compute the function ID.
function_id = function_id_hash.digest()
assert len(function_id) == ray_constants.ID_SIZE
return function_id
def get_function_descriptor_list(self):
"""Return a list of bytes representing the function descriptor.
This function is used to pass this function descriptor to backend.
Returns:
A list of bytes.
"""
descriptor_list = []
if self.is_for_driver_task:
# Driver task returns an empty list.
return descriptor_list
else:
descriptor_list.append(self.module_name.encode("ascii"))
descriptor_list.append(self.class_name.encode("ascii"))
descriptor_list.append(self.function_name.encode("ascii"))
if len(self._function_source_hash) != 0:
descriptor_list.append(self._function_source_hash)
return descriptor_list
class FunctionActorManager(object):
"""A class used to export/load remote functions and actors.
@@ -45,6 +269,8 @@ class FunctionActorManager(object):
and execution_info.
_num_task_executions: The map from driver_id to function
execution times.
imported_actor_classes: The set of actor classes keys (format:
ActorClass:function_id) that are already in GCS.
"""
def __init__(self, worker):
@@ -58,11 +284,17 @@ class FunctionActorManager(object):
# workers that execute remote functions.
self._function_execution_info = defaultdict(lambda: {})
self._num_task_executions = defaultdict(lambda: {})
# A set of all of the actor class keys that have been imported by the
# import thread. It is safe to convert this worker into an actor of
# these types.
self.imported_actor_classes = set()
def increase_task_counter(self, driver_id, function_id):
def increase_task_counter(self, driver_id, function_descriptor):
function_id = function_descriptor.function_id.id()
self._num_task_executions[driver_id][function_id] += 1
def get_task_counter(self, driver_id, function_id):
def get_task_counter(self, driver_id, function_descriptor):
function_id = function_descriptor.function_id.id()
return self._num_task_executions[driver_id][function_id]
def export_cached(self):
@@ -124,13 +356,13 @@ class FunctionActorManager(object):
check_oversized_pickle(pickled_function,
remote_function._function_name,
"remote function", self._worker)
key = (b"RemoteFunction:" + self._worker.task_driver_id.id() + b":" +
remote_function._function_id)
remote_function._function_descriptor.function_id.id())
self._worker.redis_client.hmset(
key, {
"driver_id": self._worker.task_driver_id.id(),
"function_id": remote_function._function_id,
"function_id": remote_function._function_descriptor.
function_id.id(),
"name": remote_function._function_name,
"module": function.__module__,
"function": pickled_function,
@@ -193,24 +425,28 @@ class FunctionActorManager(object):
self._worker.redis_client.rpush(
b"FunctionTable:" + function_id.id(), self._worker.worker_id)
def get_execution_info(self, driver_id, function_id):
def get_execution_info(self, driver_id, function_descriptor):
"""Get the FunctionExecutionInfo of a remote function.
Args:
driver_id: ID of the driver that the function belongs to.
function_id: ID of the function to get.
function_descriptor: The FunctionDescriptor of the function to get.
Returns:
A FunctionExecutionInfo object.
"""
# Wait until the function to be executed has actually been registered
# on this worker. We will push warnings to the user if we spend too
# long in this loop.
with profiling.profile("wait_for_function", worker=self._worker):
self._wait_for_function(function_id, driver_id)
return self._function_execution_info[driver_id][function_id.id()]
function_id = function_descriptor.function_id.id()
def _wait_for_function(self, function_id, driver_id, timeout=10):
# Wait until the function to be executed has actually been
# registered on this worker. We will push warnings to the user if
# we spend too long in this loop.
# The driver function may not be found in sys.path. Try to load
# the function from GCS.
with profiling.profile("wait_for_function", worker=self._worker):
self._wait_for_function(function_descriptor, driver_id)
return self._function_execution_info[driver_id][function_id]
def _wait_for_function(self, function_descriptor, driver_id, timeout=10):
"""Wait until the function to be executed is present on this worker.
This method will simply loop until the import thread has imported the
@@ -221,7 +457,8 @@ class FunctionActorManager(object):
been defined.
Args:
function_id (str): The ID of the function that we want to execute.
function_descriptor : The FunctionDescriptor of the function that
we want to execute.
driver_id (str): The ID of the driver to push the error message to
if this times out.
"""
@@ -231,7 +468,7 @@ class FunctionActorManager(object):
while True:
with self._worker.lock:
if (self._worker.actor_id == ray.worker.NIL_ACTOR_ID
and (function_id.id() in
and (function_descriptor.function_id.id() in
self._function_execution_info[driver_id])):
break
elif self._worker.actor_id != ray.worker.NIL_ACTOR_ID and (
@@ -251,24 +488,6 @@ class FunctionActorManager(object):
warning_sent = True
time.sleep(0.001)
@classmethod
def compute_actor_method_function_id(cls, class_name, attr):
"""Get the function ID corresponding to an actor method.
Args:
class_name (str): The class name of the actor.
attr (str): The attribute name of the method.
Returns:
Function ID corresponding to the method.
"""
function_id_hash = hashlib.sha1()
function_id_hash.update(class_name.encode("ascii"))
function_id_hash.update(attr.encode("ascii"))
function_id = function_id_hash.digest()
assert len(function_id) == ray_constants.ID_SIZE
return ray.ObjectID(function_id)
def _publish_actor_class_to_key(self, key, actor_class_info):
"""Push an actor class definition to Redis.
@@ -287,9 +506,10 @@ class FunctionActorManager(object):
self._worker.redis_client.hmset(key, actor_class_info)
self._worker.redis_client.rpush("Exports", key)
def export_actor_class(self, class_id, Class, actor_method_names,
def export_actor_class(self, Class, actor_method_names,
checkpoint_interval):
key = b"ActorClass:" + class_id
function_descriptor = FunctionDescriptor.from_class(Class)
key = b"ActorClass:" + function_descriptor.function_id.id()
actor_class_info = {
"class_name": Class.__name__,
"module": Class.__module__,
@@ -318,6 +538,17 @@ class FunctionActorManager(object):
# within tasks. I tried to disable this, but it may be necessary
# because of https://github.com/ray-project/ray/issues/1146.
def load_actor(self, driver_id, function_descriptor):
key = b"ActorClass:" + function_descriptor.function_id.id()
# Wait for the actor class key to have been imported by the
# import thread. TODO(rkn): It shouldn't be possible to end
# up in an infinite loop here, but we should push an error to
# the driver if too much time is spent here.
while key not in self.imported_actor_classes:
time.sleep(0.001)
with self._worker.lock:
self.fetch_and_register_actor(key)
def fetch_and_register_actor(self, actor_class_key):
"""Import an actor.
@@ -330,11 +561,10 @@ class FunctionActorManager(object):
worker: The worker to use.
"""
actor_id_str = self._worker.actor_id
(driver_id, class_id, class_name, module, pickled_class,
checkpoint_interval,
(driver_id, class_name, module, pickled_class, checkpoint_interval,
actor_method_names) = self._worker.redis_client.hmget(
actor_class_key, [
"driver_id", "class_id", "class_name", "module", "class",
"driver_id", "class_name", "module", "class",
"checkpoint_interval", "actor_method_names"
])
@@ -368,9 +598,9 @@ class FunctionActorManager(object):
# Register the actor method executors.
for actor_method_name in actor_method_names:
function_id = (
FunctionActorManager.compute_actor_method_function_id(
class_name, actor_method_name).id())
function_descriptor = FunctionDescriptor(module, actor_method_name,
class_name)
function_id = function_descriptor.function_id.id()
temporary_executor = self._make_actor_method_executor(
actor_method_name,
temporary_actor_method,
@@ -409,9 +639,9 @@ class FunctionActorManager(object):
actor_methods = inspect.getmembers(
unpickled_class, predicate=is_function_or_method)
for actor_method_name, actor_method in actor_methods:
function_id = (
FunctionActorManager.compute_actor_method_function_id(
class_name, actor_method_name).id())
function_descriptor = FunctionDescriptor(
module, actor_method_name, class_name)
function_id = function_descriptor.function_id.id()
executor = self._make_actor_method_executor(
actor_method_name, actor_method, actor_imported=True)
self._function_execution_info[driver_id][function_id] = (
@@ -452,7 +682,9 @@ class FunctionActorManager(object):
# If this is the first task to execute on the actor, try to resume
# from a checkpoint.
if actor_imported and self._worker.actor_task_counter == 1:
# Current __init__ will be called by default. So the real function
# call will start from 2.
if actor_imported and self._worker.actor_task_counter == 2:
checkpoint_resumed = ray.actor.restore_and_log_checkpoint(
self._worker, actor)
if checkpoint_resumed:
+9 -9
View File
@@ -3,26 +3,26 @@ from __future__ import division
from __future__ import print_function
import flatbuffers
import ray.core.generated.ErrorTableData
from ray.core.generated.GcsTableEntry import GcsTableEntry
from ray.core.generated.ClientTableData import ClientTableData
from ray.core.generated.ErrorTableData import ErrorTableData
from ray.core.generated.ProfileTableData import ProfileTableData
from ray.core.generated.HeartbeatTableData import HeartbeatTableData
from ray.core.generated.HeartbeatBatchTableData import HeartbeatBatchTableData
from ray.core.generated.DriverTableData import DriverTableData
from ray.core.generated.ErrorTableData import ErrorTableData
from ray.core.generated.GcsTableEntry import GcsTableEntry
from ray.core.generated.HeartbeatBatchTableData import HeartbeatBatchTableData
from ray.core.generated.HeartbeatTableData import HeartbeatTableData
from ray.core.generated.Language import Language
from ray.core.generated.ObjectTableData import ObjectTableData
from ray.core.generated.ray.protocol.Task import Task
from ray.core.generated.ProfileTableData import ProfileTableData
from ray.core.generated.TablePrefix import TablePrefix
from ray.core.generated.TablePubsub import TablePubsub
from ray.core.generated.ray.protocol.Task import Task
__all__ = [
"GcsTableEntry", "ClientTableData", "ErrorTableData", "HeartbeatTableData",
"HeartbeatBatchTableData", "DriverTableData", "ProfileTableData",
"ObjectTableData", "Task", "TablePrefix", "TablePubsub",
"ObjectTableData", "Task", "TablePrefix", "TablePubsub", "Language",
"construct_error_message"
]
+1 -1
View File
@@ -98,7 +98,7 @@ class ImportThread(object):
# Keep track of the fact that this actor class has been
# exported so that we know it is safe to turn this worker
# into an actor of that class.
self.worker.imported_actor_classes.add(key)
self.worker.function_actor_manager.imported_actor_classes.add(key)
# TODO(rkn): We may need to bring back the case of
# fetching actor classes here.
else:
+1
View File
@@ -16,6 +16,7 @@ def env_integer(key, default):
ID_SIZE = 20
NIL_JOB_ID = ObjectID(ID_SIZE * b"\xff")
NIL_FUNCTION_ID = NIL_JOB_ID
# If a remote function or actor (or some other export) has serialized size
# greater than this quantity, print an warning.
+4 -36
View File
@@ -3,11 +3,9 @@ from __future__ import division
from __future__ import print_function
import copy
import hashlib
import inspect
import logging
import ray.ray_constants as ray_constants
from ray.function_manager import FunctionDescriptor
import ray.signature
# Default parameters for remote functions.
@@ -18,33 +16,6 @@ DEFAULT_REMOTE_FUNCTION_MAX_CALLS = 0
logger = logging.getLogger(__name__)
def compute_function_id(function):
"""Compute an function ID for a function.
Args:
func: The actual function.
Returns:
Raw bytes of the function id
"""
function_id_hash = hashlib.sha1()
# Include the function module and name in the hash.
function_id_hash.update(function.__module__.encode("ascii"))
function_id_hash.update(function.__name__.encode("ascii"))
try:
# If we are running a script or are in IPython, include the source code
# in the hash.
source = inspect.getsource(function).encode("ascii")
function_id_hash.update(source)
except (IOError, OSError, TypeError):
# Source code may not be available: e.g. Cython or Python interpreter.
pass
# Compute the function ID.
function_id = function_id_hash.digest()
assert len(function_id) == ray_constants.ID_SIZE
return function_id
class RemoteFunction(object):
"""A remote function.
@@ -52,7 +23,7 @@ class RemoteFunction(object):
Attributes:
_function: The original function.
_function_id: The ID of the function.
_function_descriptor: The function descriptor.
_function_name: The module and function name.
_num_cpus: The default number of CPUs to use for invocations of this
remote function.
@@ -70,10 +41,7 @@ class RemoteFunction(object):
def __init__(self, function, num_cpus, num_gpus, resources,
num_return_vals, max_calls):
self._function = function
# TODO(rkn): We store the function ID as a string, so that
# RemoteFunction objects can be pickled. We should undo this when
# we allow ObjectIDs to be pickled.
self._function_id = compute_function_id(function)
self._function_descriptor = FunctionDescriptor.from_function(function)
self._function_name = (
self._function.__module__ + '.' + self._function.__name__)
self._num_cpus = (DEFAULT_REMOTE_FUNCTION_CPUS
@@ -147,7 +115,7 @@ class RemoteFunction(object):
result = self._function(*copy.deepcopy(args))
return result
object_ids = worker.submit_task(
ray.ObjectID(self._function_id),
self._function_descriptor,
args,
num_return_vals=num_return_vals,
resources=resources)
+64 -73
View File
@@ -36,7 +36,7 @@ import ray.plasma
import ray.ray_constants as ray_constants
from ray import import_thread
from ray import profiling
from ray.function_manager import FunctionActorManager
from ray.function_manager import (FunctionActorManager, FunctionDescriptor)
from ray.utils import (
check_oversized_pickle,
is_cython,
@@ -54,7 +54,6 @@ ERROR_KEY_PREFIX = b"Error:"
# This must match the definition of NIL_ACTOR_ID in task.h.
NIL_ID = ray_constants.ID_SIZE * b"\xff"
NIL_LOCAL_SCHEDULER_ID = NIL_ID
NIL_FUNCTION_ID = NIL_ID
NIL_ACTOR_ID = NIL_ID
NIL_ACTOR_HANDLE_ID = NIL_ID
NIL_CLIENT_ID = ray_constants.ID_SIZE * b"\xff"
@@ -161,10 +160,6 @@ class Worker(object):
self.make_actor = None
self.actors = {}
self.actor_task_counter = 0
# A set of all of the actor class keys that have been imported by the
# import thread. It is safe to convert this worker into an actor of
# these types.
self.imported_actor_classes = set()
# The number of threads Plasma should use when putting an object in the
# object store.
self.memcopy_threads = 12
@@ -518,7 +513,7 @@ class Worker(object):
return final_results
def submit_task(self,
function_id,
function_descriptor,
args,
actor_id=None,
actor_handle_id=None,
@@ -531,15 +526,16 @@ class Worker(object):
num_return_vals=None,
resources=None,
placement_resources=None,
driver_id=None):
driver_id=None,
language=ray.gcs_utils.Language.PYTHON):
"""Submit a remote task to the scheduler.
Tell the scheduler to schedule the execution of the function with ID
function_id with arguments args. Retrieve object IDs for the outputs of
the function from the scheduler and immediately return them.
Tell the scheduler to schedule the execution of the function with
function_descriptor with arguments args. Retrieve object IDs for the
outputs of the function from the scheduler and immediately return them.
Args:
function_id: The ID of the function to execute.
function_descriptor: The function descriptor to execute.
args: The arguments to pass into the function. Arguments can be
object IDs or they can be values. If they are values, they must
be serializable objects.
@@ -623,13 +619,15 @@ class Worker(object):
# The parent task must be set for the submitted task.
assert not self.current_task_id.is_nil()
# Submit the task to local scheduler.
function_descriptor_list = (
function_descriptor.get_function_descriptor_list())
task = ray.raylet.Task(
driver_id, ray.ObjectID(function_id.id()),
args_for_local_scheduler, num_return_vals,
self.current_task_id, task_index, actor_creation_id,
actor_creation_dummy_object_id, max_actor_reconstructions,
actor_id, actor_handle_id, actor_counter,
execution_dependencies, resources, placement_resources)
driver_id, function_descriptor_list, args_for_local_scheduler,
num_return_vals, self.current_task_id, task_index,
actor_creation_id, actor_creation_dummy_object_id,
max_actor_reconstructions, actor_id, actor_handle_id,
actor_counter, execution_dependencies, resources,
placement_resources)
self.raylet_client.submit_task(task)
return task.returns()
@@ -778,10 +776,12 @@ class Worker(object):
self.task_driver_id = task.driver_id()
self.current_task_id = task.task_id()
function_id = task.function_id()
function_descriptor = FunctionDescriptor.from_bytes_list(
task.function_descriptor_list())
args = task.arguments()
return_object_ids = task.returns()
if task.actor_id().id() != NIL_ACTOR_ID:
if (task.actor_id().id() != NIL_ACTOR_ID
or task.actor_creation_id().id() != NIL_ACTOR_ID):
dummy_return_id = return_object_ids.pop()
function_executor = function_execution_info.function
function_name = function_execution_info.function_name
@@ -796,33 +796,36 @@ class Worker(object):
function_name, args)
except RayTaskError as e:
self._handle_process_task_failure(
function_id, function_name, return_object_ids, e,
function_descriptor, return_object_ids, e,
ray.utils.format_error_message(traceback.format_exc()))
return
except Exception as e:
self._handle_process_task_failure(
function_id, function_name, return_object_ids, e,
function_descriptor, return_object_ids, e,
ray.utils.format_error_message(traceback.format_exc()))
return
# Execute the task.
try:
with profiling.profile("task:execute", worker=self):
if task.actor_id().id() == NIL_ACTOR_ID:
if (task.actor_id().id() == NIL_ACTOR_ID
and task.actor_creation_id().id() == NIL_ACTOR_ID):
outputs = function_executor(*arguments)
else:
outputs = function_executor(
dummy_return_id, self.actors[task.actor_id().id()],
*arguments)
if task.actor_id().id() != NIL_ACTOR_ID:
key = task.actor_id().id()
else:
key = task.actor_creation_id().id()
outputs = function_executor(dummy_return_id,
self.actors[key], *arguments)
except Exception as e:
# Determine whether the exception occured during a task, not an
# actor method.
task_exception = task.actor_id().id() == NIL_ACTOR_ID
traceback_str = ray.utils.format_error_message(
traceback.format_exc(), task_exception=task_exception)
self._handle_process_task_failure(function_id, function_name,
return_object_ids, e,
traceback_str)
self._handle_process_task_failure(
function_descriptor, return_object_ids, e, traceback_str)
return
# Store the outputs in the local object store.
@@ -837,11 +840,13 @@ class Worker(object):
self._store_outputs_in_object_store(return_object_ids, outputs)
except Exception as e:
self._handle_process_task_failure(
function_id, function_name, return_object_ids, e,
function_descriptor, return_object_ids, e,
ray.utils.format_error_message(traceback.format_exc()))
def _handle_process_task_failure(self, function_id, function_name,
def _handle_process_task_failure(self, function_descriptor,
return_object_ids, error, backtrace):
function_name = function_descriptor.function_name
function_id = function_descriptor.function_id
failure_object = RayTaskError(function_name, backtrace)
failure_objects = [
failure_object for _ in range(len(return_object_ids))
@@ -855,53 +860,34 @@ class Worker(object):
driver_id=self.task_driver_id.id(),
data={
"function_id": function_id.id(),
"function_name": function_name
"function_name": function_name,
"module_name": function_descriptor.module_name,
"class_name": function_descriptor.class_name
})
# Mark the actor init as failed
if self.actor_id != NIL_ACTOR_ID and function_name == "__init__":
self.mark_actor_init_failed(error)
def _become_actor(self, task):
"""Turn this worker into an actor.
Args:
task: The actor creation task.
"""
assert self.actor_id == NIL_ACTOR_ID
arguments = task.arguments()
assert len(arguments) == 1
self.actor_id = task.actor_creation_id().id()
class_id = arguments[0]
key = b"ActorClass:" + class_id
# Wait for the actor class key to have been imported by the import
# thread. TODO(rkn): It shouldn't be possible to end up in an infinite
# loop here, but we should push an error to the driver if too much time
# is spent here.
while key not in self.imported_actor_classes:
time.sleep(0.001)
with self.lock:
self.function_actor_manager.fetch_and_register_actor(key)
def _wait_for_and_process_task(self, task):
"""Wait for a task to be ready and process the task.
Args:
task: The task to execute.
"""
function_id = task.function_id()
function_descriptor = FunctionDescriptor.from_bytes_list(
task.function_descriptor_list())
driver_id = task.driver_id().id()
# TODO(rkn): It would be preferable for actor creation tasks to share
# more of the code path with regular task execution.
if (task.actor_creation_id() != ray.ObjectID(NIL_ACTOR_ID)):
self._become_actor(task)
return
assert self.actor_id == NIL_ACTOR_ID
self.actor_id = task.actor_creation_id().id()
self.function_actor_manager.load_actor(driver_id,
function_descriptor)
execution_info = self.function_actor_manager.get_execution_info(
driver_id, function_id)
driver_id, function_descriptor)
# Execute the task.
# TODO(rkn): Consider acquiring this lock with a timeout and pushing a
@@ -915,8 +901,14 @@ class Worker(object):
"task_id": task.task_id().hex()
}
if task.actor_id().id() == NIL_ACTOR_ID:
title = "ray_worker:{}()".format(function_name)
next_title = "ray_worker"
if (task.actor_creation_id() == ray.ObjectID(NIL_ACTOR_ID)):
title = "ray_worker:{}()".format(function_name)
next_title = "ray_worker"
else:
actor = self.actors[task.actor_creation_id().id()]
title = "ray_{}:{}()".format(actor.__class__.__name__,
function_name)
next_title = "ray_{}".format(actor.__class__.__name__)
else:
actor = self.actors[task.actor_id().id()]
title = "ray_{}:{}()".format(actor.__class__.__name__,
@@ -934,10 +926,10 @@ class Worker(object):
# Increase the task execution counter.
self.function_actor_manager.increase_task_counter(
driver_id, function_id.id())
driver_id, function_descriptor)
reached_max_executions = (self.function_actor_manager.get_task_counter(
driver_id, function_id.id()) == execution_info.max_calls)
driver_id, function_descriptor) == execution_info.max_calls)
if reached_max_executions:
self.raylet_client.disconnect()
sys.exit(0)
@@ -2096,15 +2088,14 @@ def connect(info,
# rerun the driver.
nil_actor_counter = 0
driver_task = ray.raylet.Task(worker.task_driver_id,
ray.ObjectID(NIL_FUNCTION_ID), [], 0,
worker.current_task_id,
worker.task_index,
ray.ObjectID(NIL_ACTOR_ID),
ray.ObjectID(NIL_ACTOR_ID), 0,
ray.ObjectID(NIL_ACTOR_ID),
ray.ObjectID(NIL_ACTOR_ID),
nil_actor_counter, [], {"CPU": 0}, {})
function_descriptor = FunctionDescriptor.for_driver_task()
driver_task = ray.raylet.Task(
worker.task_driver_id,
function_descriptor.get_function_descriptor_list(), [], 0,
worker.current_task_id, worker.task_index,
ray.ObjectID(NIL_ACTOR_ID), ray.ObjectID(NIL_ACTOR_ID), 0,
ray.ObjectID(NIL_ACTOR_ID), ray.ObjectID(NIL_ACTOR_ID),
nil_actor_counter, [], {"CPU": 0}, {})
# Add the driver task to the task table.
global_state._execute_command(driver_task.task_id(), "RAY.TABLE_ADD",
+22
View File
@@ -69,3 +69,25 @@ map_to_flatbuf(flatbuffers::FlatBufferBuilder &fbb,
}
return fbb.CreateVector(resource_vector);
}
std::vector<std::string> string_vec_from_flatbuf(
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> &flatbuf_vec) {
std::vector<std::string> string_vector;
string_vector.reserve(flatbuf_vec.size());
for (int64_t i = 0; i < flatbuf_vec.size(); i++) {
const auto flatbuf_str = flatbuf_vec.Get(i);
string_vector.push_back(string_from_flatbuf(*flatbuf_str));
}
return string_vector;
}
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>>
string_vec_to_flatbuf(flatbuffers::FlatBufferBuilder &fbb,
const std::vector<std::string> &string_vector) {
std::vector<flatbuffers::Offset<flatbuffers::String>> flatbuf_str_vec;
flatbuf_str_vec.reserve(flatbuf_str_vec.size());
for (auto const &str : string_vector) {
flatbuf_str_vec.push_back(fbb.CreateString(str));
}
return fbb.CreateVector(flatbuf_str_vec);
}
+6
View File
@@ -72,4 +72,10 @@ map_to_flatbuf(flatbuffers::FlatBufferBuilder &fbb,
const std::unordered_map<std::string, double> map_from_flatbuf(
const flatbuffers::Vector<flatbuffers::Offset<ResourcePair>> &resource_vector);
std::vector<std::string> string_vec_from_flatbuf(
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> &flatbuf_vec);
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>>
string_vec_to_flatbuf(flatbuffers::FlatBufferBuilder &fbb,
const std::vector<std::string> &string_vector);
#endif
-2
View File
@@ -73,8 +73,6 @@ table TaskInfo {
actor_counter: int;
// True if this task is an actor checkpoint task and false otherwise.
is_actor_checkpoint_method: bool;
// Function ID of the task.
function_id: string;
// Task arguments.
args: [Arg];
// Object IDs of return values.
+61 -17
View File
@@ -84,6 +84,32 @@ int PyObjectToUniqueID(PyObject *object, ObjectID *objectid) {
}
}
int PyListStringToStringVector(PyObject *object,
std::vector<std::string> *function_descriptor) {
if (function_descriptor == nullptr) {
PyErr_SetString(PyExc_TypeError, "function descriptor must be non-empty pointer");
return 0;
}
function_descriptor->clear();
std::vector<std::string> string_vector;
if (PyList_Check(object)) {
Py_ssize_t size = PyList_Size(object);
for (Py_ssize_t i = 0; i < size; ++i) {
PyObject *item = PyList_GetItem(object, i);
if (PyBytes_Check(item) == 0) {
PyErr_SetString(PyExc_TypeError,
"PyListStringToStringVector takes a list of byte strings.");
return 0;
}
function_descriptor->emplace_back(PyBytes_AsString(item), PyBytes_Size(item));
}
return 1;
} else {
PyErr_SetString(PyExc_TypeError, "must be a list of strings");
return 0;
}
}
static int PyObjectID_init(PyObjectID *self, PyObject *args, PyObject *kwds) {
const char *data;
int size;
@@ -363,12 +389,12 @@ static int PyTask_init(PyTask *self, PyObject *args, PyObject *kwds) {
UniqueID actor_handle_id;
// How many tasks have been launched on the actor so far?
int actor_counter = 0;
// ID of the function this task executes.
FunctionID function_id;
// Arguments of the task (can be PyObjectIDs or Python values).
PyObject *arguments;
// Number of return values of this task.
int num_returns;
// Task language type enum number.
int language = static_cast<int>(Language::PYTHON);
// The ID of the task that called this task.
TaskID parent_task_id;
// The number of tasks that the parent task has called prior to this one.
@@ -387,14 +413,17 @@ static int PyTask_init(PyTask *self, PyObject *args, PyObject *kwds) {
PyObject *resource_map = nullptr;
// Dictionary of required placement resources for this task.
PyObject *placement_resource_map = nullptr;
if (!PyArg_ParseTuple(args, "O&O&OiO&i|O&O&iO&O&iOOO", &PyObjectToUniqueID, &driver_id,
&PyObjectToUniqueID, &function_id, &arguments, &num_returns,
&PyObjectToUniqueID, &parent_task_id, &parent_counter,
&PyObjectToUniqueID, &actor_creation_id, &PyObjectToUniqueID,
&actor_creation_dummy_object_id, &max_actor_reconstructions,
&PyObjectToUniqueID, &actor_id, &PyObjectToUniqueID,
&actor_handle_id, &actor_counter, &execution_arguments,
&resource_map, &placement_resource_map)) {
// Function descriptor.
std::vector<std::string> function_descriptor;
if (!PyArg_ParseTuple(
args, "O&O&OiO&i|O&O&iO&O&iOOOi", &PyObjectToUniqueID, &driver_id,
&PyListStringToStringVector, &function_descriptor, &arguments, &num_returns,
&PyObjectToUniqueID, &parent_task_id, &parent_counter, &PyObjectToUniqueID,
&actor_creation_id, &PyObjectToUniqueID, &actor_creation_dummy_object_id,
&max_actor_reconstructions, &PyObjectToUniqueID, &actor_id, &PyObjectToUniqueID,
&actor_handle_id, &actor_counter, &execution_arguments, &resource_map,
&placement_resource_map, &language)) {
return -1;
}
@@ -424,6 +453,7 @@ static int PyTask_init(PyTask *self, PyObject *args, PyObject *kwds) {
self->task_spec = nullptr;
// Create the task spec.
// Parse the arguments from the list.
std::vector<std::shared_ptr<ray::raylet::TaskArgument>> task_args;
for (Py_ssize_t i = 0; i < num_args; ++i) {
@@ -444,8 +474,8 @@ static int PyTask_init(PyTask *self, PyObject *args, PyObject *kwds) {
self->task_spec = new ray::raylet::TaskSpecification(
driver_id, parent_task_id, parent_counter, actor_creation_id,
actor_creation_dummy_object_id, max_actor_reconstructions, actor_id,
actor_handle_id, actor_counter, function_id, task_args, num_returns,
required_resources, required_placement_resources, Language::PYTHON);
actor_handle_id, actor_counter, task_args, num_returns, required_resources,
required_placement_resources, Language::PYTHON, function_descriptor);
/* Set the task's execution dependencies. */
self->execution_dependencies = new std::vector<ObjectID>();
@@ -470,9 +500,23 @@ static void PyTask_dealloc(PyTask *self) {
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject *>(self));
}
static PyObject *PyTask_function_id(PyTask *self) {
FunctionID function_id = self->task_spec->FunctionId();
return PyObjectID_make(function_id);
// Helper function to change a c++ string vector to a Python string list.
static PyObject *VectorStringToPyBytesList(
const std::vector<std::string> &function_descriptor) {
size_t size = function_descriptor.size();
PyObject *return_list = PyList_New(static_cast<Py_ssize_t>(size));
for (size_t i = 0; i < size; ++i) {
auto py_bytes = PyBytes_FromStringAndSize(function_descriptor[i].data(),
function_descriptor[i].size());
PyList_SetItem(return_list, i, py_bytes);
}
return return_list;
}
static PyObject *PyTask_function_descriptor_vector(PyTask *self) {
std::vector<std::string> function_descriptor;
function_descriptor = self->task_spec->FunctionDescriptor();
return VectorStringToPyBytesList(function_descriptor);
}
static PyObject *PyTask_actor_id(PyTask *self) {
@@ -597,8 +641,8 @@ static PyObject *PyTask_to_serialized_flatbuf(PyTask *self) {
}
static PyMethodDef PyTask_methods[] = {
{"function_id", (PyCFunction)PyTask_function_id, METH_NOARGS,
"Return the function ID for this task."},
{"function_descriptor_list", (PyCFunction)PyTask_function_descriptor_vector,
METH_NOARGS, "Return the function descriptor for this task."},
{"parent_task_id", (PyCFunction)PyTask_parent_task_id, METH_NOARGS,
"Return the task ID of the parent task."},
{"parent_counter", (PyCFunction)PyTask_parent_counter, METH_NOARGS,
+3 -2
View File
@@ -112,9 +112,10 @@ static inline Task ExampleTask(const std::vector<ObjectID> &arguments,
std::vector<ObjectID> references = {argument};
task_arguments.emplace_back(std::make_shared<TaskArgumentByReference>(references));
}
std::vector<std::string> function_descriptor(3);
auto spec = TaskSpecification(UniqueID::nil(), UniqueID::from_random(), 0,
UniqueID::from_random(), task_arguments, num_returns,
required_resources, Language::PYTHON);
task_arguments, num_returns, required_resources,
Language::PYTHON, function_descriptor);
auto execution_spec = TaskExecutionSpecification(std::vector<ObjectID>());
execution_spec.IncrementNumForwards();
Task task = Task(execution_spec, spec);
@@ -74,9 +74,10 @@ static inline Task ExampleTask(const std::vector<ObjectID> &arguments,
std::vector<ObjectID> references = {argument};
task_arguments.emplace_back(std::make_shared<TaskArgumentByReference>(references));
}
std::vector<std::string> function_descriptor(3);
auto spec = TaskSpecification(UniqueID::nil(), UniqueID::from_random(), 0,
UniqueID::from_random(), task_arguments, num_returns,
required_resources, Language::PYTHON);
task_arguments, num_returns, required_resources,
Language::PYTHON, function_descriptor);
auto execution_spec = TaskExecutionSpecification(std::vector<ObjectID>());
execution_spec.IncrementNumForwards();
Task task = Task(execution_spec, spec);
+12 -12
View File
@@ -56,25 +56,24 @@ TaskSpecification::TaskSpecification(const std::string &string) {
TaskSpecification::TaskSpecification(
const UniqueID &driver_id, const TaskID &parent_task_id, int64_t parent_counter,
const FunctionID &function_id,
const std::vector<std::shared_ptr<TaskArgument>> &task_arguments, int64_t num_returns,
const std::unordered_map<std::string, double> &required_resources,
const Language &language)
const Language &language, const std::vector<std::string> &function_descriptor)
: TaskSpecification(driver_id, parent_task_id, parent_counter, ActorID::nil(),
ObjectID::nil(), 0, ActorID::nil(), ActorHandleID::nil(), -1,
function_id, task_arguments, num_returns, required_resources,
std::unordered_map<std::string, double>(), language) {}
task_arguments, num_returns, required_resources,
std::unordered_map<std::string, double>(), language,
function_descriptor) {}
TaskSpecification::TaskSpecification(
const UniqueID &driver_id, const TaskID &parent_task_id, int64_t parent_counter,
const ActorID &actor_creation_id, const ObjectID &actor_creation_dummy_object_id,
const int64_t max_actor_reconstructions, const ActorID &actor_id,
const ActorHandleID &actor_handle_id, int64_t actor_counter,
const FunctionID &function_id,
const std::vector<std::shared_ptr<TaskArgument>> &task_arguments, int64_t num_returns,
const std::unordered_map<std::string, double> &required_resources,
const std::unordered_map<std::string, double> &required_placement_resources,
const Language &language)
const Language &language, const std::vector<std::string> &function_descriptor)
: spec_() {
flatbuffers::FlatBufferBuilder fbb;
@@ -99,9 +98,10 @@ TaskSpecification::TaskSpecification(
to_flatbuf(fbb, parent_task_id), parent_counter, to_flatbuf(fbb, actor_creation_id),
to_flatbuf(fbb, actor_creation_dummy_object_id), max_actor_reconstructions,
to_flatbuf(fbb, actor_id), to_flatbuf(fbb, actor_handle_id), actor_counter, false,
to_flatbuf(fbb, function_id), fbb.CreateVector(arguments),
fbb.CreateVector(returns), map_to_flatbuf(fbb, required_resources),
map_to_flatbuf(fbb, required_placement_resources), language);
fbb.CreateVector(arguments), fbb.CreateVector(returns),
map_to_flatbuf(fbb, required_resources),
map_to_flatbuf(fbb, required_placement_resources), language,
string_vec_to_flatbuf(fbb, function_descriptor));
fbb.Finish(spec);
AssignSpecification(fbb.GetBufferPointer(), fbb.GetSize());
}
@@ -134,9 +134,9 @@ int64_t TaskSpecification::ParentCounter() const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return message->parent_counter();
}
FunctionID TaskSpecification::FunctionId() const {
std::vector<std::string> TaskSpecification::FunctionDescriptor() const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return from_flatbuf(*message->function_id());
return string_vec_from_flatbuf(*message->function_descriptor());
}
int64_t TaskSpecification::NumArgs() const {
@@ -197,7 +197,7 @@ const ResourceSet TaskSpecification::GetRequiredPlacementResources() const {
bool TaskSpecification::IsDriverTask() const {
// Driver tasks are empty tasks that have no function ID set.
return FunctionId().is_nil();
return FunctionDescriptor().empty();
}
Language TaskSpecification::GetLanguage() const {
+7 -7
View File
@@ -91,17 +91,18 @@ class TaskSpecification {
/// \param parent_task_id The task ID of the task that spawned this task.
/// \param parent_counter The number of tasks that this task's parent spawned
/// before this task.
/// \param function_id The ID of the function this task should execute.
/// \param function_descriptor The function descriptor.
/// \param task_arguments The list of task arguments.
/// \param num_returns The number of values returned by the task.
/// \param required_resources The task's resource demands.
/// \param language The language of the worker that must execute the function.
TaskSpecification(const UniqueID &driver_id, const TaskID &parent_task_id,
int64_t parent_counter, const FunctionID &function_id,
int64_t parent_counter,
const std::vector<std::shared_ptr<TaskArgument>> &task_arguments,
int64_t num_returns,
const std::unordered_map<std::string, double> &required_resources,
const Language &language);
const Language &language,
const std::vector<std::string> &function_descriptor);
// TODO(swang): Define an actor task constructor.
/// Create a task specification from the raw fields.
@@ -119,7 +120,6 @@ class TaskSpecification {
/// task. If this is not an actor task, then this is nil.
/// \param actor_counter The number of tasks submitted before this task from
/// the same actor handle. If this is not an actor task, then this is 0.
/// \param function_id The ID of the function this task should execute.
/// \param task_arguments The list of task arguments.
/// \param num_returns The number of values returned by the task.
/// \param required_resources The task's resource demands.
@@ -127,17 +127,17 @@ class TaskSpecification {
/// task on a node. Typically, this should be an empty map in which case it
/// will default to be equal to the required_resources argument.
/// \param language The language of the worker that must execute the function.
/// \param function_descriptor The function descriptor.
TaskSpecification(
const UniqueID &driver_id, const TaskID &parent_task_id, int64_t parent_counter,
const ActorID &actor_creation_id, const ObjectID &actor_creation_dummy_object_id,
int64_t max_actor_reconstructions, const ActorID &actor_id,
const ActorHandleID &actor_handle_id, int64_t actor_counter,
const FunctionID &function_id,
const std::vector<std::shared_ptr<TaskArgument>> &task_arguments,
int64_t num_returns,
const std::unordered_map<std::string, double> &required_resources,
const std::unordered_map<std::string, double> &required_placement_resources,
const Language &language);
const Language &language, const std::vector<std::string> &function_descriptor);
/// Deserialize a task specification from a flatbuffer's string data.
///
@@ -159,7 +159,7 @@ class TaskSpecification {
UniqueID DriverId() const;
TaskID ParentTaskId() const;
int64_t ParentCounter() const;
FunctionID FunctionId() const;
std::vector<std::string> FunctionDescriptor() const;
int64_t NumArgs() const;
int64_t NumReturns() const;
bool ArgByRef(int64_t arg_index) const;
+3 -2
View File
@@ -63,9 +63,10 @@ class WorkerPoolTest : public ::testing::Test {
static inline TaskSpecification ExampleTaskSpec(
const ActorID actor_id = ActorID::nil(),
const Language &language = Language::PYTHON) {
std::vector<std::string> function_descriptor(3);
return TaskSpecification(UniqueID::nil(), UniqueID::nil(), 0, ActorID::nil(),
ObjectID::nil(), 0, actor_id, ActorHandleID::nil(), 0,
FunctionID::nil(), {}, 0, {{}}, {{}}, language);
ObjectID::nil(), 0, actor_id, ActorHandleID::nil(), 0, {}, 0,
{{}}, {{}}, language, function_descriptor);
}
TEST_F(WorkerPoolTest, HandleWorkerRegistration) {