Define common data structures with protobuf. (#5121)

This commit is contained in:
Hao Chen
2019-07-08 22:41:37 +08:00
committed by GitHub
parent b4e51c8aa1
commit 8a30b93e42
64 changed files with 1233 additions and 1561 deletions
-2
View File
@@ -4,8 +4,6 @@
/python/ray/pyarrow_files/
/python/build
/python/dist
/python/flatbuffers-1.7.1/
/flatbuffers-1.7.1/
/thirdparty/pkg/
# Files generated by flatc should be ignored
+34 -62
View File
@@ -11,10 +11,27 @@ COPTS = ["-DRAY_USE_GLOG"]
# === Begin of protobuf definitions ===
# Proto definitions from common.proto. Other proto_library targets in this file list
# ":common_proto" in their deps, and the target is visible to packages under //java.
proto_library(
name = "common_proto",
srcs = ["src/ray/protobuf/common.proto"],
visibility = ["//java:__subpackages__"],
)
cc_proto_library(
name = "common_cc_proto",
deps = [":common_proto"],
)
python_proto_compile(
name = "common_py_proto",
deps = [":common_proto"],
)
proto_library(
name = "gcs_proto",
srcs = ["src/ray/protobuf/gcs.proto"],
visibility = ["//java:__subpackages__"],
deps = [":common_proto"],
)
cc_proto_library(
@@ -30,6 +47,7 @@ python_proto_compile(
proto_library(
name = "node_manager_proto",
srcs = ["src/ray/protobuf/node_manager.proto"],
deps = [":common_proto"],
)
cc_proto_library(
@@ -50,6 +68,7 @@ cc_proto_library(
proto_library(
name = "worker_proto",
srcs = ["src/ray/protobuf/worker.proto"],
deps = [":common_proto"],
)
cc_proto_library(
@@ -60,6 +79,7 @@ cc_proto_library(
proto_library(
name = "core_worker_proto",
srcs = ["src/ray/protobuf/core_worker.proto"],
deps = [":common_proto"],
)
cc_proto_library(
@@ -242,8 +262,8 @@ cc_library(
copts = COPTS,
linkopts = ["-pthread"],
deps = [
":common_cc_proto",
":gcs",
":gcs_fbs",
":node_manager_fbs",
":node_manager_rpc",
":object_manager",
@@ -491,11 +511,7 @@ cc_library(
],
),
copts = COPTS,
includes = [
"src/ray/gcs/format",
],
deps = [
":gcs_fbs",
":ray_util",
"@boost//:asio",
"@plasma//:plasma_client",
@@ -550,15 +566,10 @@ cc_library(
),
hdrs = glob([
"src/ray/gcs/*.h",
"src/ray/gcs/format/*.h",
]),
copts = COPTS,
includes = [
"src/ray/gcs/format",
],
deps = [
":gcs_cc_proto",
":gcs_fbs",
":hiredis",
":node_manager_fbs",
":node_manager_rpc",
@@ -598,13 +609,6 @@ FLATC_ARGS = [
"--scoped-enums",
]
flatbuffer_cc_library(
name = "gcs_fbs",
srcs = ["src/ray/gcs/format/gcs.fbs"],
flatc_args = FLATC_ARGS,
out_prefix = "src/ray/gcs/format/",
)
flatbuffer_cc_library(
name = "common_fbs",
srcs = ["@plasma//:cpp/src/plasma/format/common.fbs"],
@@ -616,8 +620,6 @@ flatbuffer_cc_library(
name = "node_manager_fbs",
srcs = ["src/ray/raylet/format/node_manager.fbs"],
flatc_args = FLATC_ARGS,
include_paths = ["src/ray/gcs/format"],
includes = [":gcs_fbs_includes"],
out_prefix = "src/ray/raylet/format/",
)
@@ -686,43 +688,6 @@ filegroup(
visibility = ["//java:__subpackages__"],
)
filegroup(
name = "gcs_fbs_file",
srcs = ["src/ray/gcs/format/gcs.fbs"],
visibility = ["//java:__subpackages__"],
)
flatbuffer_py_library(
name = "python_node_manager_fbs",
srcs = [
"src/ray/raylet/format/node_manager.fbs",
],
outs = [
"ray/protocol/DisconnectClient.py",
"ray/protocol/FetchOrReconstruct.py",
"ray/protocol/ForwardTaskRequest.py",
"ray/protocol/FreeObjectsRequest.py",
"ray/protocol/GetTaskReply.py",
"ray/protocol/MessageType.py",
"ray/protocol/NotifyUnblocked.py",
"ray/protocol/PushErrorRequest.py",
"ray/protocol/RegisterClientReply.py",
"ray/protocol/RegisterClientRequest.py",
"ray/protocol/RegisterNodeManagerRequest.py",
"ray/protocol/ResourceIdSetInfo.py",
"ray/protocol/SubmitTaskRequest.py",
"ray/protocol/Task.py",
"ray/protocol/TaskExecutionSpecification.py",
"ray/protocol/WaitReply.py",
"ray/protocol/WaitRequest.py",
],
include_paths = [
"src/ray/gcs/format/",
],
includes = ["src/ray/gcs/format/gcs.fbs"],
out_prefix = "python/ray/core/generated/",
)
filegroup(
name = "python_sources",
srcs = glob([
@@ -781,13 +746,20 @@ cc_binary(
],
)
# Bundles the Python protobuf outputs (common + gcs) under one label so that the
# `ray_pkg` genrule can reference them with a single $(locations //:all_py_proto).
filegroup(
name = "all_py_proto",
srcs = [
"common_py_proto",
"gcs_py_proto",
],
)
genrule(
name = "ray_pkg",
srcs = [
"python/ray/_raylet.so",
"//:python_sources",
"//:gcs_py_proto",
"//:python_node_manager_fbs",
"//:all_py_proto",
"//:redis-server",
"//:redis-cli",
"//:libray_redis_module.so",
@@ -809,12 +781,12 @@ genrule(
cp -f $(location @plasma//:plasma_store_server) $$WORK_DIR/python/ray/core/src/plasma/ &&
cp -f $(location //:raylet) $$WORK_DIR/python/ray/core/src/ray/raylet/ &&
mkdir -p $$WORK_DIR/python/ray/core/generated/ray/protocol/ &&
for f in $(locations //:python_node_manager_fbs); do
cp -f $$f $$WORK_DIR/python/ray/core/generated/ray/protocol/;
done &&
for f in $(locations //:gcs_py_proto); do
for f in $(locations //:all_py_proto); do
cp -f $$f $$WORK_DIR/python/ray/core/generated/;
done &&
# NOTE(hchen): Protobuf doesn't allow specifying Python package name. So we use this `sed`
# command to change the import path in the generated file.
sed -i -E 's/from src.ray.protobuf/from ./' $$WORK_DIR/python/ray/core/generated/gcs_pb2.py &&
echo $$WORK_DIR > $@
""",
local = 1,
+12 -13
View File
@@ -13,23 +13,22 @@ def flatbuffer_py_library(name, srcs, outs, out_prefix, includes = [], include_p
includes = includes,
)
def flatbuffer_java_library(name, srcs, outs, out_prefix, includes = [], include_paths = []):
flatbuffer_library_public(
name = name,
srcs = srcs,
outs = outs,
language_flag = "-j",
out_prefix = out_prefix,
include_paths = include_paths,
includes = includes,
)
def define_java_module(name, additional_srcs = [], additional_resources = [], define_test_lib = False, test_deps = [], **kwargs):
def define_java_module(
name,
additional_srcs = [],
exclude_srcs = [],
additional_resources = [],
define_test_lib = False,
test_deps = [],
**kwargs):
lib_name = "org_ray_ray_" + name
pom_file_targets = [lib_name]
native.java_library(
name = lib_name,
srcs = additional_srcs + native.glob([name + "/src/main/java/**/*.java"]),
srcs = additional_srcs + native.glob(
[name + "/src/main/java/**/*.java"],
exclude = exclude_srcs,
),
resources = native.glob([name + "/src/main/resources/**"]) + additional_resources,
**kwargs
)
+21 -45
View File
@@ -1,4 +1,4 @@
load("//bazel:ray.bzl", "flatbuffer_java_library", "define_java_module")
load("//bazel:ray.bzl", "define_java_module")
load("@build_stack_rules_proto//java:java_proto_compile.bzl", "java_proto_compile")
exports_files([
@@ -50,8 +50,10 @@ define_java_module(
define_java_module(
name = "runtime",
additional_srcs = [
":generate_java_gcs_fbs",
":gcs_java_proto",
":all_java_proto",
],
exclude_srcs = [
"runtime/src/main/java/org/ray/runtime/generated/*.java",
],
additional_resources = [
":java_native_deps",
@@ -68,7 +70,6 @@ define_java_module(
deps = [
":org_ray_ray_api",
"@plasma//:org_apache_arrow_arrow_plasma",
"@maven//:com_github_davidmoten_flatbuffers_java",
"@maven//:com_google_guava_guava",
"@maven//:com_google_protobuf_protobuf_java",
"@maven//:com_typesafe_config",
@@ -151,39 +152,22 @@ java_binary(
],
)
java_proto_compile(
name = "common_java_proto",
deps = ["@//:common_proto"],
)
java_proto_compile(
name = "gcs_java_proto",
deps = ["@//:gcs_proto"],
)
flatbuffers_generated_files = [
"Arg.java",
"Language.java",
"TaskInfo.java",
"ResourcePair.java",
]
flatbuffer_java_library(
name = "java_gcs_fbs",
srcs = ["//:gcs_fbs_file"],
outs = flatbuffers_generated_files,
out_prefix = "",
)
genrule(
name = "generate_java_gcs_fbs",
srcs = [":java_gcs_fbs"],
outs = [
"runtime/src/main/java/org/ray/runtime/generated/" + file for file in flatbuffers_generated_files
filegroup(
name = "all_java_proto",
srcs = [
":common_java_proto",
":gcs_java_proto",
],
cmd = """
for f in $(locations //java:java_gcs_fbs); do
chmod +w $$f
mv -f $$f $(@D)/runtime/src/main/java/org/ray/runtime/generated
done
python $$(pwd)/java/modify_generated_java_flatbuffers_files.py $(@D)/..
""",
local = 1,
)
filegroup(
@@ -202,8 +186,7 @@ filegroup(
genrule(
name = "gen_maven_deps",
srcs = [
":gcs_java_proto",
":generate_java_gcs_fbs",
":all_java_proto",
":java_native_deps",
":copy_pom_file",
"@plasma//:org_apache_arrow_arrow_plasma",
@@ -212,6 +195,11 @@ genrule(
cmd = """
set -x
WORK_DIR=$$(pwd)
# Copy protobuf-generated files.
rm -rf $$WORK_DIR/java/runtime/src/main/java/org/ray/runtime/generated
for f in $(locations //java:all_java_proto); do
unzip $$f -x META-INF/MANIFEST.MF -d $$WORK_DIR/java/runtime/src/main/java
done
# Copy native dependencies.
NATIVE_DEPS_DIR=$$WORK_DIR/java/runtime/native_dependencies/
rm -rf $$NATIVE_DEPS_DIR
@@ -220,18 +208,6 @@ genrule(
chmod +w $$f
cp $$f $$NATIVE_DEPS_DIR
done
# Copy protobuf-generated files.
GENERATED_DIR=$$WORK_DIR/java/runtime/src/main/java/org/ray/runtime/generated
rm -rf $$GENERATED_DIR
mkdir -p $$GENERATED_DIR
for f in $(locations //java:gcs_java_proto); do
unzip $$f
mv org/ray/runtime/generated/* $$GENERATED_DIR
done
# Copy flatbuffers-generated files
for f in $(locations //java:generate_java_gcs_fbs); do
cp $$f $$GENERATED_DIR
done
# Install plasma jar to local maven repo.
mvn install:install-file -Dfile=$(locations @plasma//:org_apache_arrow_arrow_plasma) -Dpackaging=jar \
-DgroupId=org.apache.arrow -DartifactId=arrow-plasma -Dversion=0.13.0-SNAPSHOT
-1
View File
@@ -4,7 +4,6 @@ def gen_java_deps():
maven_install(
artifacts = [
"com.beust:jcommander:1.72",
"com.github.davidmoten:flatbuffers-java:1.9.0.1",
"com.google.guava:guava:27.0.1-jre",
"com.google.protobuf:protobuf-java:3.8.0",
"com.puppycrawl.tools:checkstyle:8.15",
-5
View File
@@ -31,11 +31,6 @@
<artifactId>jcommander</artifactId>
<version>1.72</version>
</dependency>
<dependency>
<groupId>com.github.davidmoten</groupId>
<artifactId>flatbuffers-java</artifactId>
<version>1.9.0.1</version>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
@@ -1,13 +1,17 @@
package org.ray.runtime.raylet;
import com.google.common.base.Preconditions;
import com.google.flatbuffers.FlatBufferBuilder;
import com.google.common.collect.ImmutableList;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.ray.api.RayObject;
import org.ray.api.WaitResult;
import org.ray.api.exception.RayException;
@@ -15,10 +19,8 @@ import org.ray.api.id.ObjectId;
import org.ray.api.id.TaskId;
import org.ray.api.id.UniqueId;
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
import org.ray.runtime.generated.Arg;
import org.ray.runtime.generated.Language;
import org.ray.runtime.generated.ResourcePair;
import org.ray.runtime.generated.TaskInfo;
import org.ray.runtime.generated.Common;
import org.ray.runtime.generated.Common.TaskType;
import org.ray.runtime.task.FunctionArg;
import org.ray.runtime.task.TaskLanguage;
import org.ray.runtime.task.TaskSpec;
@@ -30,15 +32,6 @@ public class RayletClientImpl implements RayletClient {
private static final Logger LOGGER = LoggerFactory.getLogger(RayletClientImpl.class);
private static final int TASK_SPEC_BUFFER_SIZE = 2 * 1024 * 1024;
/**
* Direct buffers that are used to hold flatbuffer-serialized task specs.
*/
private static ThreadLocal<ByteBuffer> taskSpecBuffer = ThreadLocal.withInitial(() ->
ByteBuffer.allocateDirect(TASK_SPEC_BUFFER_SIZE).order(ByteOrder.LITTLE_ENDIAN)
);
/**
* The pointer to c++'s raylet client.
*/
@@ -86,21 +79,20 @@ public class RayletClientImpl implements RayletClient {
Preconditions.checkState(!spec.parentTaskId.isNil());
Preconditions.checkState(!spec.jobId.isNil());
ByteBuffer info = convertTaskSpecToFlatbuffer(spec);
byte[] taskSpec = convertTaskSpecToProtobuf(spec);
byte[] cursorId = null;
if (!spec.getExecutionDependencies().isEmpty()) {
//TODO(hchen): handle more than one dependencies.
cursorId = spec.getExecutionDependencies().get(0).getBytes();
}
nativeSubmitTask(client, cursorId, info, info.position(), info.remaining());
nativeSubmitTask(client, cursorId, taskSpec);
}
@Override
public TaskSpec getTask() {
byte[] bytes = nativeGetTask(client);
assert (null != bytes);
ByteBuffer bb = ByteBuffer.wrap(bytes);
return parseTaskSpecFromFlatbuffer(bb);
return parseTaskSpecFromProtobuf(bytes);
}
@Override
@@ -127,7 +119,7 @@ public class RayletClientImpl implements RayletClient {
@Override
public void freePlasmaObjects(List<ObjectId> objectIds, boolean localOnly,
boolean deleteCreatingTasks) {
boolean deleteCreatingTasks) {
byte[][] objectIdsArray = IdUtil.getIdBytes(objectIds);
nativeFreePlasmaObjects(client, objectIdsArray, localOnly, deleteCreatingTasks);
}
@@ -142,59 +134,76 @@ public class RayletClientImpl implements RayletClient {
nativeNotifyActorResumedFromCheckpoint(client, actorId.getBytes(), checkpointId.getBytes());
}
private static TaskSpec parseTaskSpecFromFlatbuffer(ByteBuffer bb) {
bb.order(ByteOrder.LITTLE_ENDIAN);
TaskInfo info = TaskInfo.getRootAsTaskInfo(bb);
UniqueId jobId = UniqueId.fromByteBuffer(info.jobIdAsByteBuffer());
TaskId taskId = TaskId.fromByteBuffer(info.taskIdAsByteBuffer());
TaskId parentTaskId = TaskId.fromByteBuffer(info.parentTaskIdAsByteBuffer());
int parentCounter = info.parentCounter();
UniqueId actorCreationId = UniqueId.fromByteBuffer(info.actorCreationIdAsByteBuffer());
int maxActorReconstructions = info.maxActorReconstructions();
UniqueId actorId = UniqueId.fromByteBuffer(info.actorIdAsByteBuffer());
UniqueId actorHandleId = UniqueId.fromByteBuffer(info.actorHandleIdAsByteBuffer());
int actorCounter = info.actorCounter();
int numReturns = info.numReturns();
/**
* Parse `TaskSpec` protobuf bytes.
*/
private static TaskSpec parseTaskSpecFromProtobuf(byte[] bytes) {
Common.TaskSpec taskSpec;
try {
  taskSpec = Common.TaskSpec.parseFrom(bytes);
} catch (InvalidProtocolBufferException e) {
  // Keep the parser's exception as the cause so the reason the bytes were
  // rejected is still visible in the stack trace.
  throw new RuntimeException("Invalid protobuf data.", e);
}
// Deserialize new actor handles
UniqueId[] newActorHandles = IdUtil.getUniqueIdsFromByteBuffer(
info.newActorHandlesAsByteBuffer());
// Parse common fields.
UniqueId jobId = UniqueId.fromByteBuffer(taskSpec.getJobId().asReadOnlyByteBuffer());
TaskId taskId = TaskId.fromByteBuffer(taskSpec.getTaskId().asReadOnlyByteBuffer());
TaskId parentTaskId = TaskId.fromByteBuffer(taskSpec.getParentTaskId().asReadOnlyByteBuffer());
int parentCounter = (int) taskSpec.getParentCounter();
int numReturns = (int) taskSpec.getNumReturns();
Map<String, Double> resources = taskSpec.getRequiredResourcesMap();
// Deserialize args
FunctionArg[] args = new FunctionArg[info.argsLength()];
for (int i = 0; i < info.argsLength(); i++) {
Arg arg = info.args(i);
int objectIdsLength = arg.objectIdsAsByteBuffer().remaining() / UniqueId.LENGTH;
// Parse args.
FunctionArg[] args = new FunctionArg[taskSpec.getArgsCount()];
for (int i = 0; i < args.length; i++) {
Common.TaskArg arg = taskSpec.getArgs(i);
int objectIdsLength = arg.getObjectIdsCount();
if (objectIdsLength > 0) {
// Guava's Preconditions only substitutes "%s" placeholders; an SLF4J-style "{}"
// is left verbatim and the argument is merely appended in brackets.
Preconditions.checkArgument(objectIdsLength == 1,
    "This arg has more than one id: %s", objectIdsLength);
args[i] = FunctionArg.passByReference(ObjectId.fromByteBuffer(arg.objectIdsAsByteBuffer()));
args[i] = FunctionArg
.passByReference(ObjectId.fromByteBuffer(arg.getObjectIds(0).asReadOnlyByteBuffer()));
} else {
ByteBuffer lbb = arg.dataAsByteBuffer();
Preconditions.checkState(lbb != null && lbb.remaining() > 0);
byte[] data = new byte[lbb.remaining()];
lbb.get(data);
args[i] = FunctionArg.passByValue(data);
args[i] = FunctionArg.passByValue(arg.getData().toByteArray());
}
}
// Deserialize required resources;
Map<String, Double> resources = new HashMap<>();
for (int i = 0; i < info.requiredResourcesLength(); i++) {
resources.put(info.requiredResources(i).key(), info.requiredResources(i).value());
}
// Deserialize function descriptor
Preconditions.checkArgument(info.language() == Language.JAVA);
Preconditions.checkArgument(info.functionDescriptorLength() == 3);
// Parse function descriptor
Preconditions.checkArgument(taskSpec.getLanguage() == Common.Language.JAVA);
Preconditions.checkArgument(taskSpec.getFunctionDescriptorCount() == 3);
JavaFunctionDescriptor functionDescriptor = new JavaFunctionDescriptor(
info.functionDescriptor(0), info.functionDescriptor(1), info.functionDescriptor(2)
taskSpec.getFunctionDescriptor(0).toString(Charset.defaultCharset()),
taskSpec.getFunctionDescriptor(1).toString(Charset.defaultCharset()),
taskSpec.getFunctionDescriptor(2).toString(Charset.defaultCharset())
);
// Deserialize dynamic worker options.
// Parse ActorCreationTaskSpec.
UniqueId actorCreationId = UniqueId.NIL;
int maxActorReconstructions = 0;
UniqueId[] newActorHandles = new UniqueId[0];
List<String> dynamicWorkerOptions = new ArrayList<>();
for (int i = 0; i < info.dynamicWorkerOptionsLength(); ++i) {
dynamicWorkerOptions.add(info.dynamicWorkerOptions(i));
if (taskSpec.getType() == Common.TaskType.ACTOR_CREATION_TASK) {
Common.ActorCreationTaskSpec actorCreationTaskSpec = taskSpec.getActorCreationTaskSpec();
actorCreationId = UniqueId
.fromByteBuffer(actorCreationTaskSpec.getActorId().asReadOnlyByteBuffer());
maxActorReconstructions = (int) actorCreationTaskSpec.getMaxActorReconstructions();
dynamicWorkerOptions = ImmutableList
.copyOf(actorCreationTaskSpec.getDynamicWorkerOptionsList());
}
// Parse ActorTaskSpec.
UniqueId actorId = UniqueId.NIL;
UniqueId actorHandleId = UniqueId.NIL;
int actorCounter = 0;
if (taskSpec.getType() == Common.TaskType.ACTOR_TASK) {
Common.ActorTaskSpec actorTaskSpec = taskSpec.getActorTaskSpec();
actorId = UniqueId.fromByteBuffer(actorTaskSpec.getActorId().asReadOnlyByteBuffer());
actorHandleId = UniqueId
.fromByteBuffer(actorTaskSpec.getActorHandleId().asReadOnlyByteBuffer());
actorCounter = (int) actorTaskSpec.getActorCounter();
newActorHandles = actorTaskSpec.getNewActorHandlesList().stream()
.map(byteString -> UniqueId.fromByteBuffer(byteString.asReadOnlyByteBuffer()))
.toArray(UniqueId[]::new);
}
return new TaskSpec(jobId, taskId, parentTaskId, parentCounter, actorCreationId,
@@ -202,122 +211,78 @@ public class RayletClientImpl implements RayletClient {
args, numReturns, resources, TaskLanguage.JAVA, functionDescriptor, dynamicWorkerOptions);
}
private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) {
ByteBuffer bb = taskSpecBuffer.get();
bb.clear();
/**
* Convert a `TaskSpec` to protobuf-serialized bytes.
*/
private static byte[] convertTaskSpecToProtobuf(TaskSpec task) {
// Set common fields.
Common.TaskSpec.Builder builder = Common.TaskSpec.newBuilder()
.setJobId(ByteString.copyFrom(task.jobId.getBytes()))
.setTaskId(ByteString.copyFrom(task.taskId.getBytes()))
.setParentTaskId(ByteString.copyFrom(task.parentTaskId.getBytes()))
.setParentCounter(task.parentCounter)
.setNumReturns(task.numReturns)
.putAllRequiredResources(task.resources);
FlatBufferBuilder fbb = new FlatBufferBuilder(bb);
final int jobIdOffset = fbb.createString(task.jobId.toByteBuffer());
final int taskIdOffset = fbb.createString(task.taskId.toByteBuffer());
final int parentTaskIdOffset = fbb.createString(task.parentTaskId.toByteBuffer());
final int parentCounter = task.parentCounter;
final int actorCreateIdOffset = fbb.createString(task.actorCreationId.toByteBuffer());
final int actorCreateDummyIdOffset = fbb.createString(task.actorId.toByteBuffer());
final int maxActorReconstructions = task.maxActorReconstructions;
final int actorIdOffset = fbb.createString(task.actorId.toByteBuffer());
final int actorHandleIdOffset = fbb.createString(task.actorHandleId.toByteBuffer());
final int actorCounter = task.actorCounter;
final int numReturnsOffset = task.numReturns;
// Serialize the new actor handles.
int newActorHandlesOffset
= fbb.createString(IdUtil.concatIds(task.newActorHandles));
// Serialize args
int[] argsOffsets = new int[task.args.length];
for (int i = 0; i < argsOffsets.length; i++) {
int objectIdOffset = 0;
int dataOffset = 0;
if (task.args[i].id != null) {
objectIdOffset = fbb.createString(
IdUtil.concatIds(new ObjectId[]{task.args[i].id}));
} else {
objectIdOffset = fbb.createString("");
}
if (task.args[i].data != null) {
dataOffset = fbb.createString(ByteBuffer.wrap(task.args[i].data));
}
argsOffsets[i] = Arg.createArg(fbb, objectIdOffset, dataOffset);
}
int argsOffset = fbb.createVectorOfTables(argsOffsets);
// Serialize required resources
// The required_resources vector indicates the quantities of the different
// resources required by this task. The index in this vector corresponds to
// the resource type defined in the ResourceIndex enum. For example,
int[] requiredResourcesOffsets = new int[task.resources.size()];
int i = 0;
for (Map.Entry<String, Double> entry : task.resources.entrySet()) {
int keyOffset = fbb.createString(ByteBuffer.wrap(entry.getKey().getBytes()));
requiredResourcesOffsets[i++] =
ResourcePair.createResourcePair(fbb, keyOffset, entry.getValue());
}
int requiredResourcesOffset = fbb.createVectorOfTables(requiredResourcesOffsets);
int[] requiredPlacementResourcesOffsets = new int[0];
int requiredPlacementResourcesOffset =
fbb.createVectorOfTables(requiredPlacementResourcesOffsets);
int language;
int functionDescriptorOffset;
// Set args. An arg with a non-null id is encoded as a single object id;
// otherwise its raw data bytes are stored.
builder.addAllArgs(
    Arrays.stream(task.args).map(arg -> {
      Common.TaskArg.Builder argBuilder = Common.TaskArg.newBuilder();
      if (arg.id != null) {
        argBuilder.addObjectIds(ByteString.copyFrom(arg.id.getBytes()));
      } else {
        argBuilder.setData(ByteString.copyFrom(arg.data));
      }
      // Build exactly once, after the branch has populated the builder.
      return argBuilder.build();
    }).collect(Collectors.toList())
);
// Set function descriptor and language.
if (task.language == TaskLanguage.JAVA) {
// This is a Java task.
language = Language.JAVA;
int[] functionDescriptorOffsets = new int[]{
fbb.createString(task.getJavaFunctionDescriptor().className),
fbb.createString(task.getJavaFunctionDescriptor().name),
fbb.createString(task.getJavaFunctionDescriptor().typeDescriptor)
};
functionDescriptorOffset = fbb.createVectorOfTables(functionDescriptorOffsets);
builder.setLanguage(Common.Language.JAVA);
builder.addAllFunctionDescriptor(ImmutableList.of(
ByteString.copyFrom(task.getJavaFunctionDescriptor().className.getBytes()),
ByteString.copyFrom(task.getJavaFunctionDescriptor().name.getBytes()),
ByteString.copyFrom(task.getJavaFunctionDescriptor().typeDescriptor.getBytes())
));
} else {
// This is a Python task.
language = Language.PYTHON;
int[] functionDescriptorOffsets = new int[]{
fbb.createString(task.getPyFunctionDescriptor().moduleName),
fbb.createString(task.getPyFunctionDescriptor().className),
fbb.createString(task.getPyFunctionDescriptor().functionName),
fbb.createString("")
};
functionDescriptorOffset = fbb.createVectorOfTables(functionDescriptorOffsets);
builder.setLanguage(Common.Language.PYTHON);
builder.addAllFunctionDescriptor(ImmutableList.of(
ByteString.copyFrom(task.getPyFunctionDescriptor().moduleName.getBytes()),
ByteString.copyFrom(task.getPyFunctionDescriptor().className.getBytes()),
ByteString.copyFrom(task.getPyFunctionDescriptor().functionName.getBytes()),
ByteString.EMPTY
));
}
int [] dynamicWorkerOptionsOffsets = new int[task.dynamicWorkerOptions.size()];
for (int index = 0; index < task.dynamicWorkerOptions.size(); ++index) {
dynamicWorkerOptionsOffsets[index] = fbb.createString(task.dynamicWorkerOptions.get(index));
if (!task.actorCreationId.isNil()) {
// Actor creation task.
builder.setType(TaskType.ACTOR_CREATION_TASK);
builder.setActorCreationTaskSpec(
Common.ActorCreationTaskSpec.newBuilder()
.setActorId(ByteString.copyFrom(task.actorCreationId.getBytes()))
.setMaxActorReconstructions(task.maxActorReconstructions)
.addAllDynamicWorkerOptions(task.dynamicWorkerOptions)
);
} else if (!task.actorId.isNil()) {
// Actor task.
builder.setType(TaskType.ACTOR_TASK);
List<ByteString> newHandles = Arrays.stream(task.newActorHandles)
.map(id -> ByteString.copyFrom(id.getBytes())).collect(Collectors.toList());
builder.setActorTaskSpec(
Common.ActorTaskSpec.newBuilder()
.setActorId(ByteString.copyFrom(task.actorId.getBytes()))
.setActorHandleId(ByteString.copyFrom(task.actorHandleId.getBytes()))
.setActorCreationDummyObjectId(ByteString.copyFrom(task.actorId.getBytes()))
.setActorCounter(task.actorCounter)
.addAllNewActorHandles(newHandles)
);
} else {
// Normal task.
builder.setType(TaskType.NORMAL_TASK);
}
int dynamicWorkerOptionsOffset = fbb.createVectorOfTables(dynamicWorkerOptionsOffsets);
int root = TaskInfo.createTaskInfo(
fbb,
jobIdOffset,
taskIdOffset,
parentTaskIdOffset,
parentCounter,
actorCreateIdOffset,
actorCreateDummyIdOffset,
maxActorReconstructions,
actorIdOffset,
actorHandleIdOffset,
actorCounter,
newActorHandlesOffset,
argsOffset,
numReturnsOffset,
requiredResourcesOffset,
requiredPlacementResourcesOffset,
language,
functionDescriptorOffset,
dynamicWorkerOptionsOffset);
fbb.finish(root);
ByteBuffer buffer = fbb.dataBuffer();
if (buffer.remaining() > TASK_SPEC_BUFFER_SIZE) {
LOGGER.error(
"Allocated buffer is not enough to transfer the task specification: {} vs {}",
TASK_SPEC_BUFFER_SIZE, buffer.remaining());
throw new RuntimeException("Allocated buffer is not enough to transfer to task.");
}
return buffer;
return builder.build().toByteArray();
}
public void setResource(String resourceName, double capacity, UniqueId nodeId) {
@@ -344,10 +309,9 @@ public class RayletClientImpl implements RayletClient {
private static native long nativeInit(String localSchedulerSocket, byte[] workerId,
boolean isWorker, byte[] driverTaskId);
private static native void nativeSubmitTask(long client, byte[] cursorId, ByteBuffer taskBuff,
int pos, int taskSize) throws RayException;
private static native void nativeSubmitTask(long client, byte[] cursorId, byte[] taskSpec)
throws RayException;
// return TaskInfo (in FlatBuffer)
private static native byte[] nativeGetTask(long client) throws RayException;
private static native void nativeDestroy(long client) throws RayException;
@@ -375,5 +339,5 @@ public class RayletClientImpl implements RayletClient {
byte[] checkpointId);
private static native void nativeSetResource(long conn, String resourceName, double capacity,
byte[] nodeId) throws RayException;
byte[] nodeId) throws RayException;
}
+24 -26
View File
@@ -24,8 +24,8 @@ from ray.includes.common cimport (
)
from ray.includes.libraylet cimport (
CRayletClient,
GCSProfileEventT,
GCSProfileTableDataT,
GCSProfileEvent,
GCSProfileTableData,
ResourceMappingType,
WaitResultPair,
)
@@ -34,7 +34,7 @@ from ray.includes.unique_ids cimport (
CObjectID,
CClientID,
)
from ray.includes.task cimport CTaskSpecification
from ray.includes.task cimport CTaskSpec
from ray.includes.ray_config cimport RayConfig
from ray.utils import decode
@@ -232,18 +232,22 @@ cdef class RayletClient:
def disconnect(self):
check_status(self.client.get().Disconnect())
def submit_task(self, Task task_spec):
def submit_task(self, TaskSpec task_spec, execution_dependencies):
cdef:
CObjectID c_id
c_vector[CObjectID] c_dependencies
for dep in execution_dependencies:
c_dependencies.push_back((<ObjectID>dep).native())
check_status(self.client.get().SubmitTask(
task_spec.execution_dependencies.get()[0],
task_spec.task_spec.get()[0]))
c_dependencies, task_spec.task_spec.get()[0]))
def get_task(self):
cdef:
unique_ptr[CTaskSpecification] task_spec
unique_ptr[CTaskSpec] task_spec
with nogil:
check_status(self.client.get().GetTask(&task_spec))
return Task.make(task_spec)
return TaskSpec.make(task_spec)
def task_done(self):
check_status(self.client.get().TaskDone())
@@ -303,19 +307,19 @@ cdef class RayletClient:
def push_profile_events(self, component_type, UniqueID component_id,
node_ip_address, profile_data):
cdef:
GCSProfileTableDataT profile_info
GCSProfileEventT *profile_event
GCSProfileTableData profile_info
GCSProfileEvent *profile_event
c_string event_type
if len(profile_data) == 0:
return # Short circuit if there are no profile events.
profile_info.component_type = component_type.encode("ascii")
profile_info.component_id = component_id.binary()
profile_info.node_ip_address = node_ip_address.encode("ascii")
profile_info.set_component_type(component_type.encode("ascii"))
profile_info.set_component_id(component_id.binary())
profile_info.set_node_ip_address(node_ip_address.encode("ascii"))
for py_profile_event in profile_data:
profile_event = new GCSProfileEventT()
profile_event = profile_info.add_profile_events()
if not isinstance(py_profile_event, dict):
raise TypeError(
"Incorrect type for a profile event. Expected dict "
@@ -325,28 +329,22 @@ cdef class RayletClient:
# that will cause segfaults in the node manager.
for key_string, event_data in py_profile_event.items():
if key_string == "event_type":
profile_event.event_type = event_data.encode("ascii")
if profile_event.event_type.length() == 0:
if len(event_data) == 0:
raise ValueError(
"'event_type' should not be a null string.")
profile_event.set_event_type(event_data.encode("ascii"))
elif key_string == "start_time":
profile_event.start_time = float(event_data)
profile_event.set_start_time(float(event_data))
elif key_string == "end_time":
profile_event.end_time = float(event_data)
profile_event.set_end_time(float(event_data))
elif key_string == "extra_data":
profile_event.extra_data = event_data.encode("ascii")
if profile_event.extra_data.length() == 0:
if len(event_data) == 0:
raise ValueError(
"'extra_data' should not be a null string.")
profile_event.set_extra_data(event_data.encode("ascii"))
else:
raise ValueError(
"Unknown profile event key '%s'" % key_string)
# Note that profile_info.profile_events is a vector of unique
# pointers, so profile_event will be deallocated when profile_info
# goes out of scope. "emplace_back" of vector has not been
# supported by Cython
profile_info.profile_events.push_back(
unique_ptr[GCSProfileEventT](profile_event))
check_status(self.client.get().PushProfileEvents(profile_info))
-3
View File
@@ -2,8 +2,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.core.generated.ray.protocol.Task import Task
from ray.core.generated.gcs_pb2 import (
ActorCheckpointIdData,
ClientTableData,
@@ -33,7 +31,6 @@ __all__ = [
"ProfileTableData",
"TablePrefix",
"TablePubsub",
"Task",
"TaskTableData",
"construct_error_message",
]
+2 -5
View File
@@ -87,17 +87,14 @@ cdef extern from "ray/common/id.h" namespace "ray" nogil:
int parent_task_counter)
cdef extern from "ray/gcs/format/gcs_generated.h" nogil:
cdef cppclass GCSArg "Arg":
pass
cdef extern from "ray/protobuf/common.pb.h" nogil:
cdef cppclass CLanguage "Language":
pass
# This is a workaround for C++ enum class since Cython has no corresponding
# representation.
cdef extern from "ray/gcs/format/gcs_generated.h" namespace "Language" nogil:
cdef extern from "ray/protobuf/common.pb.h" namespace "Language" nogil:
cdef CLanguage LANGUAGE_PYTHON "Language::PYTHON"
cdef CLanguage LANGUAGE_CPP "Language::CPP"
cdef CLanguage LANGUAGE_JAVA "Language::JAVA"
+17 -17
View File
@@ -19,23 +19,23 @@ from ray.includes.unique_ids cimport (
CObjectID,
CTaskID,
)
from ray.includes.task cimport CTaskSpecification
from ray.includes.task cimport CTaskSpec
cdef extern from "ray/gcs/format/gcs_generated.h" nogil:
cdef cppclass GCSProfileEventT "ProfileEventT":
c_string event_type
double start_time
double end_time
c_string extra_data
GCSProfileEventT()
cdef extern from "ray/protobuf/gcs.pb.h" nogil:
cdef cppclass GCSProfileEvent "ProfileTableData::ProfileEvent":
    # Cython view of the protobuf-generated C++ message. Only the setters used
    # when building profile events are declared; all of them return void.
    void set_event_type(const c_string &value)
    void set_start_time(double value)
    void set_end_time(double value)
    void set_extra_data(const c_string &value)
    GCSProfileEvent()
cdef cppclass GCSProfileTableDataT "ProfileTableDataT":
c_string component_type
c_string component_id
c_string node_ip_address
c_vector[unique_ptr[GCSProfileEventT]] profile_events
GCSProfileTableDataT()
cdef cppclass GCSProfileTableData "ProfileTableData":
void set_component_type(const c_string &value)
void set_component_id(const c_string &value)
void set_node_ip_address(const c_string &value)
GCSProfileEvent *add_profile_events()
GCSProfileTableData()
ctypedef unordered_map[c_string, c_vector[pair[int64_t, double]]] \
@@ -52,8 +52,8 @@ cdef extern from "ray/raylet/raylet_client.h" nogil:
CRayStatus Disconnect()
CRayStatus SubmitTask(
const c_vector[CObjectID] &execution_dependencies,
const CTaskSpecification &task_spec)
CRayStatus GetTask(unique_ptr[CTaskSpecification] *task_spec)
const CTaskSpec &task_spec)
CRayStatus GetTask(unique_ptr[CTaskSpec] *task_spec)
CRayStatus TaskDone()
CRayStatus FetchOrReconstruct(c_vector[CObjectID] &object_ids,
c_bool fetch_only,
@@ -66,7 +66,7 @@ cdef extern from "ray/raylet/raylet_client.h" nogil:
CRayStatus PushError(const CJobID &job_id, const c_string &type,
const c_string &error_message, double timestamp)
CRayStatus PushProfileEvents(
const GCSProfileTableDataT &profile_events)
const GCSProfileTableData &profile_events)
CRayStatus FreeObjects(const c_vector[CObjectID] &object_ids,
c_bool local_only, c_bool delete_creating_tasks)
CRayStatus PrepareActorCheckpoint(const CActorID &actor_id,
+65 -71
View File
@@ -1,4 +1,4 @@
from libc.stdint cimport int64_t, uint8_t
from libc.stdint cimport uint8_t, uint64_t
from libcpp cimport bool as c_bool
from libcpp.memory cimport unique_ptr, shared_ptr
from libcpp.string cimport string as c_string
@@ -17,72 +17,45 @@ from ray.includes.unique_ids cimport (
CTaskID,
)
cdef extern from "ray/protobuf/common.pb.h" namespace "ray::rpc" nogil:
cdef cppclass RpcTaskSpec "ray::rpc::TaskSpec":
void CopyFrom(const RpcTaskSpec &value)
cdef extern from "ray/raylet/task_execution_spec.h" \
namespace "ray::raylet" nogil:
cdef cppclass CTaskExecutionSpecification \
"ray::raylet::TaskExecutionSpecification":
CTaskExecutionSpecification(const c_vector[CObjectID] &&dependencies)
CTaskExecutionSpecification(
const c_vector[CObjectID] &&dependencies, int num_forwards)
c_vector[CObjectID] ExecutionDependencies() const
void SetExecutionDependencies(const c_vector[CObjectID] &dependencies)
int NumForwards() const
void IncrementNumForwards()
int64_t LastTimestamp() const
void SetLastTimestamp(int64_t new_timestamp)
cdef cppclass RpcTaskExecutionSpec "ray::rpc::TaskExecutionSpec":
void CopyFrom(const RpcTaskExecutionSpec &value)
void add_dependencies(const c_string &value)
cdef cppclass RpcTask "ray::rpc::Task":
RpcTaskSpec *mutable_task_spec()
cdef extern from "ray/protobuf/gcs.pb.h" namespace "ray::rpc" nogil:
    # Binding for the `ray::rpc::TaskTableData` message.
    cdef cppclass TaskTableData "ray::rpc::TaskTableData":
        # Mutable accessor for the embedded `task` sub-message.
        RpcTask *mutable_task()
        # protobuf's `SerializeAsString()` is a const method that returns the
        # encoded bytes by value, so it is declared that way here instead of
        # as a `const c_string &`.
        c_string SerializeAsString() const
cdef extern from "ray/raylet/task_spec.h" namespace "ray::raylet" nogil:
cdef cppclass CTaskArgument "ray::raylet::TaskArgument":
pass
cdef cppclass CTaskArgumentByReference \
"ray::raylet::TaskArgumentByReference":
CTaskArgumentByReference(const c_vector[CObjectID] &references)
cdef cppclass CTaskArgumentByValue "ray::raylet::TaskArgumentByValue":
CTaskArgumentByValue(const uint8_t *value, size_t length)
cdef cppclass CTaskSpecification "ray::raylet::TaskSpecification":
CTaskSpecification(
const CJobID &job_id, const CTaskID &parent_task_id,
int64_t parent_counter,
const c_vector[shared_ptr[CTaskArgument]] &task_arguments,
int64_t num_returns,
const unordered_map[c_string, double] &required_resources,
const CLanguage &language,
const c_vector[c_string] &function_descriptor)
CTaskSpecification(
const CJobID &job_id, const CTaskID &parent_task_id,
int64_t parent_counter, const CActorID &actor_creation_id,
const CObjectID &actor_creation_dummy_object_id,
int64_t max_actor_reconstructions, const CActorID &actor_id,
const CActorHandleID &actor_handle_id, int64_t actor_counter,
const c_vector[CActorHandleID] &new_actor_handles,
const c_vector[shared_ptr[CTaskArgument]] &task_arguments,
int64_t num_returns,
const unordered_map[c_string, double] &required_resources,
const unordered_map[c_string, double] &required_placement_res,
const CLanguage &language,
const c_vector[c_string] &function_descriptor)
CTaskSpecification(const c_string &string)
c_string SerializeAsString() const
cdef cppclass CTaskSpec "ray::raylet::TaskSpecification":
CTaskSpec(const RpcTaskSpec message)
CTaskSpec(const c_string &serialized_binary)
const RpcTaskSpec &GetMessage()
c_string Serialize() const
CTaskID TaskId() const
CJobID JobId() const
CTaskID ParentTaskId() const
int64_t ParentCounter() const
uint64_t ParentCounter() const
c_vector[c_string] FunctionDescriptor() const
c_string FunctionDescriptorString() const
int64_t NumArgs() const
int64_t NumReturns() const
c_bool ArgByRef(int64_t arg_index) const
int ArgIdCount(int64_t arg_index) const
CObjectID ArgId(int64_t arg_index, int64_t id_index) const
CObjectID ReturnId(int64_t return_index) const
const uint8_t *ArgVal(int64_t arg_index) const
size_t ArgValLength(int64_t arg_index) const
uint64_t NumArgs() const
uint64_t NumReturns() const
c_bool ArgByRef(uint64_t arg_index) const
int ArgIdCount(uint64_t arg_index) const
CObjectID ArgId(uint64_t arg_index, uint64_t id_index) const
CObjectID ReturnId(uint64_t return_index) const
const uint8_t *ArgVal(uint64_t arg_index) const
size_t ArgValLength(uint64_t arg_index) const
double GetRequiredResource(const c_string &resource_name) const
const ResourceSet GetRequiredResources() const
const ResourceSet GetRequiredPlacementResources() const
@@ -93,25 +66,46 @@ cdef extern from "ray/raylet/task_spec.h" namespace "ray::raylet" nogil:
c_bool IsActorTask() const
CActorID ActorCreationId() const
CObjectID ActorCreationDummyObjectId() const
int64_t MaxActorReconstructions() const
uint64_t MaxActorReconstructions() const
CActorID ActorId() const
CActorHandleID ActorHandleId() const
int64_t ActorCounter() const
uint64_t ActorCounter() const
CObjectID ActorDummyObject() const
c_vector[CActorHandleID] NewActorHandles() const
cdef extern from "ray/raylet/task_util.h" namespace "ray::raylet" nogil:
    # Cython view of the C++ builder that assembles a task spec message.
    # Every setter returns a reference to the builder itself.
    cdef cppclass TaskSpecBuilder "ray::raylet::TaskSpecBuilder":
        # Fields that apply to every kind of task.
        TaskSpecBuilder &SetCommonTaskSpec(
            const CLanguage &language,
            const c_vector[c_string] &function_descriptor,
            const CJobID &job_id,
            const CTaskID &parent_task_id,
            uint64_t parent_counter,
            uint64_t num_returns,
            const unordered_map[c_string, double] &required_resources,
            const unordered_map[c_string, double] &required_placement_resources)

        # Task arguments, either passed by object ID or as inline data.
        TaskSpecBuilder &AddByRefArg(const CObjectID &arg_id)

        TaskSpecBuilder &AddByValueArg(const c_string &data)

        # Extra fields for an actor creation task.
        TaskSpecBuilder &SetActorCreationTaskSpec(
            const CActorID &actor_id,
            uint64_t max_reconstructions,
            const c_vector[c_string] &dynamic_worker_options)

        # Extra fields for a task submitted to an existing actor.
        TaskSpecBuilder &SetActorTaskSpec(
            const CActorID &actor_id,
            const CActorHandleID &actor_handle_id,
            const CObjectID &actor_creation_dummy_object_id,
            uint64_t actor_counter,
            const c_vector[CActorHandleID] &new_handle_ids)

        # The protobuf message built so far.
        RpcTaskSpec GetMessage()
# Declarations for the C++ task execution spec wrapper from
# ray/raylet/task_execution_spec.h.
cdef extern from "ray/raylet/task_execution_spec.h" namespace "ray::raylet" nogil:
cdef cppclass CTaskExecutionSpec "ray::raylet::TaskExecutionSpecification":
# Construct from an in-memory protobuf message.
CTaskExecutionSpec(RpcTaskExecutionSpec message)
# Construct from a serialized string (presumably a serialized
# `TaskExecutionSpec` message; confirm against the C++ header).
CTaskExecutionSpec(const c_string &serialized_binary)
# Access to the wrapped protobuf message.
const RpcTaskExecutionSpec &GetMessage()
c_vector[CObjectID] ExecutionDependencies()
uint64_t NumForwards()
cdef extern from "ray/raylet/task.h" namespace "ray::raylet" nogil:
cdef cppclass CTask "ray::raylet::Task":
CTask(const CTaskExecutionSpecification &execution_spec,
const CTaskSpecification &task_spec)
const CTaskExecutionSpecification &GetTaskExecutionSpec() const
const CTaskSpecification &GetTaskSpecification() const
void SetExecutionDependencies(const c_vector[CObjectID] &dependencies)
void IncrementNumForwards()
const c_vector[CObjectID] &GetDependencies() const
void CopyTaskExecutionSpec(const CTask &task)
cdef c_string SerializeTaskAsString(
const c_vector[CObjectID] *dependencies,
const CTaskSpecification *task_spec)
CTask(CTaskSpec task_spec, CTaskExecutionSpec task_execution_spec)
+115 -60
View File
@@ -5,18 +5,19 @@ from libcpp.memory cimport (
static_pointer_cast,
)
from ray.includes.task cimport (
CTaskArgument,
CTaskArgumentByReference,
CTaskArgumentByValue,
CTaskSpecification,
SerializeTaskAsString,
CTask,
CTaskExecutionSpec,
CTaskSpec,
RpcTaskExecutionSpec,
TaskSpecBuilder,
TaskTableData,
)
cdef class Task:
cdef class TaskSpec:
"""Cython wrapper class of C++ `ray::raylet::TaskSpecification`."""
cdef:
unique_ptr[CTaskSpecification] task_spec
unique_ptr[c_vector[CObjectID]] execution_dependencies
unique_ptr[CTaskSpec] task_spec
def __init__(self, JobID job_id, function_descriptor, arguments,
int num_returns, TaskID parent_task_id, int parent_counter,
@@ -24,73 +25,78 @@ cdef class Task:
ObjectID actor_creation_dummy_object_id,
int32_t max_actor_reconstructions, ActorID actor_id,
ActorHandleID actor_handle_id, int actor_counter,
new_actor_handles, execution_arguments, resource_map,
placement_resource_map):
new_actor_handles, resource_map, placement_resource_map):
cdef:
TaskSpecBuilder builder
unordered_map[c_string, double] required_resources
unordered_map[c_string, double] required_placement_resources
c_vector[shared_ptr[CTaskArgument]] task_args
c_vector[CActorHandleID] task_new_actor_handles
c_vector[c_string] c_function_descriptor
c_string pickled_str
c_vector[CObjectID] references
c_vector[CActorHandleID] c_new_actor_handles
# Convert function descriptor to C++ vector.
for item in function_descriptor:
if not isinstance(item, bytes):
raise TypeError(
"'function_descriptor' takes a list of byte strings.")
c_function_descriptor.push_back(item)
# Parse the resource map.
# Convert resource map to C++ unordered_map.
if resource_map is not None:
required_resources = resource_map_from_dict(resource_map)
if placement_resource_map is not None:
required_placement_resources = (
resource_map_from_dict(placement_resource_map))
# Parse the arguments from the list.
# Build common task spec.
builder.SetCommonTaskSpec(
LANGUAGE_PYTHON,
c_function_descriptor,
job_id.native(),
parent_task_id.native(),
parent_counter,
num_returns,
required_resources,
required_placement_resources,
)
# Build arguments.
for arg in arguments:
if isinstance(arg, ObjectID):
references = c_vector[CObjectID]()
references.push_back((<ObjectID>arg).native())
task_args.push_back(
static_pointer_cast[CTaskArgument,
CTaskArgumentByReference](
make_shared[CTaskArgumentByReference](references)))
builder.AddByRefArg((<ObjectID>arg).native())
else:
pickled_str = pickle.dumps(
arg, protocol=pickle.HIGHEST_PROTOCOL)
task_args.push_back(
static_pointer_cast[CTaskArgument,
CTaskArgumentByValue](
make_shared[CTaskArgumentByValue](
<uint8_t *>pickled_str.c_str(),
pickled_str.size())))
builder.AddByValueArg(pickled_str)
for new_actor_handle in new_actor_handles:
task_new_actor_handles.push_back(
(<ActorHandleID?>new_actor_handle).native())
self.task_spec.reset(new CTaskSpecification(
job_id.native(), parent_task_id.native(), parent_counter, actor_creation_id.native(),
actor_creation_dummy_object_id.native(), max_actor_reconstructions, actor_id.native(),
actor_handle_id.native(), actor_counter, task_new_actor_handles, task_args, num_returns,
required_resources, required_placement_resources, LANGUAGE_PYTHON,
c_function_descriptor))
# Set the task's execution dependencies.
self.execution_dependencies.reset(new c_vector[CObjectID]())
if execution_arguments is not None:
for execution_arg in execution_arguments:
self.execution_dependencies.get().push_back(
(<ObjectID?>execution_arg).native())
if not actor_creation_id.is_nil():
# Actor creation task.
builder.SetActorCreationTaskSpec(
actor_creation_id.native(),
max_actor_reconstructions,
[],
)
elif not actor_id.is_nil():
# Actor task.
for new_actor_handle in new_actor_handles:
c_new_actor_handles.push_back(
(<ActorHandleID?>new_actor_handle).native())
builder.SetActorTaskSpec(
actor_id.native(),
actor_handle_id.native(),
actor_creation_dummy_object_id.native(),
actor_counter,
c_new_actor_handles,
)
else:
# Normal task.
pass
self.task_spec.reset(new CTaskSpec(builder.GetMessage()))
@staticmethod
cdef make(unique_ptr[CTaskSpecification]& task_spec):
cdef Task self = Task.__new__(Task)
cdef make(unique_ptr[CTaskSpec]& task_spec):
cdef TaskSpec self = TaskSpec.__new__(TaskSpec)
self.task_spec.reset(task_spec.release())
# The created task does not include any execution dependencies.
self.execution_dependencies.reset(new c_vector[CObjectID]())
return self
@staticmethod
@@ -103,11 +109,8 @@ cdef class Task:
Returns:
Python task specification object.
"""
cdef Task self = Task.__new__(Task)
# TODO(pcm): Use flatbuffers validation here.
self.task_spec.reset(new CTaskSpecification(task_spec_str))
# The created task does not include any execution dependencies.
self.execution_dependencies.reset(new c_vector[CObjectID]())
cdef TaskSpec self = TaskSpec.__new__(TaskSpec)
self.task_spec.reset(new CTaskSpec(task_spec_str))
return self
def to_string(self):
@@ -116,11 +119,7 @@ cdef class Task:
Returns:
String representing the task specification.
"""
return self.task_spec.get().SerializeAsString()
def _serialized_raylet_task(self):
return SerializeTaskAsString(
self.execution_dependencies.get(), self.task_spec.get())
return self.task_spec.get().Serialize()
def job_id(self):
"""Return the job ID for this task."""
@@ -150,7 +149,7 @@ cdef class Task:
def arguments(self):
"""Return the arguments for the task."""
cdef:
CTaskSpecification *task_spec = self.task_spec.get()
CTaskSpec*task_spec = self.task_spec.get()
int64_t num_args = task_spec.NumArgs()
int32_t lang = <int32_t>task_spec.GetLanguage()
int count
@@ -175,7 +174,7 @@ cdef class Task:
def returns(self):
"""Return the object IDs for the return values of the task."""
cdef CTaskSpecification *task_spec = self.task_spec.get()
cdef CTaskSpec *task_spec = self.task_spec.get()
return_id_list = []
for i in range(task_spec.NumReturns()):
return_id_list.append(ObjectID(task_spec.ReturnId(i).Binary()))
@@ -221,3 +220,59 @@ cdef class Task:
def actor_counter(self):
"""Return the actor counter for this task."""
return self.task_spec.get().ActorCounter()
cdef class TaskExecutionSpec:
    """Cython wrapper class of C++ `ray::raylet::TaskExecutionSpecification`."""
    cdef:
        unique_ptr[CTaskExecutionSpec] c_spec

    def __init__(self, execution_dependencies):
        """Build an execution spec from a collection of object IDs.

        Args:
            execution_dependencies: Iterable of `ObjectID`s that the task
                depends on. `None` is treated as "no dependencies".

        Raises:
            TypeError: If an element is not an `ObjectID`.
        """
        cdef:
            RpcTaskExecutionSpec message

        # Accept None as an empty dependency list instead of failing on
        # iteration.
        if execution_dependencies is not None:
            for dependency in execution_dependencies:
                # `<ObjectID?>` is a checked cast, so a wrong element type
                # raises TypeError rather than being reinterpreted.
                message.add_dependencies(
                    (<ObjectID?>dependency).binary())
        self.c_spec.reset(new CTaskExecutionSpec(message))

    @staticmethod
    def from_string(const c_string& string):
        """Convert a string to a Ray `TaskExecutionSpec` Python object.

        Args:
            string: Serialized bytes handed to the C++ constructor that
                takes a serialized binary.

        Returns:
            A new `TaskExecutionSpec` wrapping the resulting C++ object.
        """
        # `__new__` skips `__init__`, so the C++ object is installed here.
        cdef TaskExecutionSpec self = TaskExecutionSpec.__new__(TaskExecutionSpec)
        self.c_spec.reset(new CTaskExecutionSpec(string))
        return self

    def dependencies(self):
        """Return the execution dependencies as a list of `ObjectID`s."""
        cdef:
            CObjectID c_id
            c_vector[CObjectID] dependencies = (
                self.c_spec.get().ExecutionDependencies())
        ret = []
        for c_id in dependencies:
            ret.append(ObjectID(c_id.Binary()))
        return ret

    def num_forwards(self):
        """Return the value of the C++ spec's `NumForwards()`."""
        return self.c_spec.get().NumForwards()
cdef class Task:
    """Cython wrapper class of C++ `ray::raylet::Task`."""
    cdef:
        unique_ptr[CTask] c_task

    def __init__(self, TaskSpec task_spec not None,
                 TaskExecutionSpec task_execution_spec not None):
        """Create a task from its specification and execution specification.

        Args:
            task_spec: The `TaskSpec` wrapper to copy the C++ spec from.
            task_execution_spec: The `TaskExecutionSpec` wrapper to copy the
                C++ execution spec from.

        Raises:
            TypeError: If either argument is `None` or of the wrong type.
        """
        # Typed extension-type arguments accept None unless declared
        # `not None`; reading a cdef attribute on None is undefined, so the
        # signature rejects it before the pointers are dereferenced.
        self.c_task.reset(new CTask(task_spec.task_spec.get()[0],
                                    task_execution_spec.c_spec.get()[0]))
def generate_gcs_task_table_data(TaskSpec task_spec not None):
    """Converts a Python `TaskSpec` object to serialized GCS `TaskTableData`.

    Args:
        task_spec: The `TaskSpec` whose protobuf message is copied into the
            `task.task_spec` field of a new `TaskTableData`.

    Returns:
        The serialized `TaskTableData` message.

    Raises:
        TypeError: If `task_spec` is `None` or not a `TaskSpec`.
    """
    cdef:
        TaskTableData task_table_data
    # `not None` in the signature guards the cdef attribute access below,
    # which would be undefined on None.
    task_table_data.mutable_task().mutable_task_spec().CopyFrom(
        task_spec.task_spec.get().GetMessage())
    return task_table_data.SerializeAsString()
+6 -11
View File
@@ -305,12 +305,9 @@ class GlobalState(object):
assert len(gcs_entries.entries) == 1
task_table_data = gcs_utils.TaskTableData.FromString(
gcs_entries.entries[0])
task_table_message = gcs_utils.Task.GetRootAsTask(
task_table_data.task, 0)
execution_spec = task_table_message.TaskExecutionSpec()
task_spec = task_table_message.TaskSpecification()
task = ray._raylet.Task.from_string(task_spec)
task = ray._raylet.TaskSpec.from_string(
task_table_data.task.task_spec.SerializeToString())
function_descriptor_list = task.function_descriptor_list()
function_descriptor = FunctionDescriptor.from_bytes_list(
function_descriptor_list)
@@ -335,14 +332,12 @@ class GlobalState(object):
"FunctionName": function_descriptor.function_name,
}
execution_spec = ray._raylet.TaskExecutionSpec.from_string(
task_table_data.task.task_execution_spec.SerializeToString())
return {
"ExecutionSpec": {
"Dependencies": [
execution_spec.Dependencies(i)
for i in range(execution_spec.DependenciesLength())
],
"LastTimestamp": execution_spec.LastTimestamp(),
"NumForwards": execution_spec.NumForwards()
"Dependencies": execution_spec.dependencies(),
"NumForwards": execution_spec.num_forwards(),
},
"TaskSpec": task_spec_info
}
+11 -11
View File
@@ -688,7 +688,7 @@ class Worker(object):
function_descriptor_list = (
function_descriptor.get_function_descriptor_list())
assert isinstance(job_id, JobID)
task = ray._raylet.Task(
task = ray._raylet.TaskSpec(
job_id,
function_descriptor_list,
args_for_raylet,
@@ -702,11 +702,10 @@ class Worker(object):
actor_handle_id,
actor_counter,
new_actor_handles,
execution_dependencies,
resources,
placement_resources,
)
self.raylet_client.submit_task(task)
self.raylet_client.submit_task(task, execution_dependencies)
return task.returns()
@@ -1864,7 +1863,7 @@ def connect(node,
nil_actor_counter = 0
function_descriptor = FunctionDescriptor.for_driver_task()
driver_task = ray._raylet.Task(
driver_task_spec = ray._raylet.TaskSpec(
worker.current_job_id,
function_descriptor.get_function_descriptor_list(),
[], # arguments.
@@ -1878,24 +1877,25 @@ def connect(node,
ActorHandleID.nil(), # actor_handle_id.
nil_actor_counter, # actor_counter.
[], # new_actor_handles.
[], # execution_dependencies.
{}, # resource_map.
{}, # placement_resource_map.
)
task_table_data = ray.gcs_utils.TaskTableData()
task_table_data.task = driver_task._serialized_raylet_task()
task_table_data = ray._raylet.generate_gcs_task_table_data(
driver_task_spec)
# Add the driver task to the task table.
ray.state.state._execute_command(
driver_task.task_id(), "RAY.TABLE_ADD",
driver_task_spec.task_id(),
"RAY.TABLE_ADD",
ray.gcs_utils.TablePrefix.Value("RAYLET_TASK"),
ray.gcs_utils.TablePubsub.Value("RAYLET_TASK_PUBSUB"),
driver_task.task_id().binary(),
task_table_data.SerializeToString())
driver_task_spec.task_id().binary(),
task_table_data,
)
# Set the driver's current task ID to the task ID assigned to the
# driver task.
worker.task_context.current_task_id = driver_task.task_id()
worker.task_context.current_task_id = driver_task_spec.task_id()
worker.raylet_client = ray._raylet.RayletClient(
node.raylet_socket_name,
+3 -5
View File
@@ -28,11 +28,10 @@ ray_files = [
"ray/dashboard/res/main.css", "ray/dashboard/res/main.js"
]
# These are the directories where automatically generated Python flatbuffer
# These are the directories where automatically generated Python protobuf
# bindings are created.
generated_python_directories = [
"ray/core/generated", "ray/core/generated/ray",
"ray/core/generated/ray/protocol"
"ray/core/generated",
]
optional_ray_files = []
@@ -88,7 +87,7 @@ class build_ext(_build_ext.build_ext):
files_to_include = ray_files + pyarrow_files + modin_files
# Copy over the autogenerated flatbuffer Python bindings.
# Copy over the autogenerated protobuf Python bindings.
for directory in generated_python_directories:
for filename in os.listdir(directory):
if filename[-3:] == ".py":
@@ -148,7 +147,6 @@ requires = [
# NOTE: Don't upgrade the version of six! Doing so causes installation
# problems. See https://github.com/ray-project/ray/issues/4169.
"six >= 1.0.0",
"flatbuffers",
"faulthandler;python_version<'3.3'",
"protobuf >= 3.8.0",
]
-22
View File
@@ -6,28 +6,6 @@ std::string string_from_flatbuf(const flatbuffers::String &string) {
return std::string(string.data(), string.size());
}
const std::unordered_map<std::string, double> map_from_flatbuf(
const flatbuffers::Vector<flatbuffers::Offset<ResourcePair>> &resource_vector) {
std::unordered_map<std::string, double> required_resources;
for (int64_t i = 0; i < resource_vector.size(); i++) {
const ResourcePair *resource_pair = resource_vector.Get(i);
required_resources[string_from_flatbuf(*resource_pair->key())] =
resource_pair->value();
}
return required_resources;
}
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<ResourcePair>>>
map_to_flatbuf(flatbuffers::FlatBufferBuilder &fbb,
const std::unordered_map<std::string, double> &resource_map) {
std::vector<flatbuffers::Offset<ResourcePair>> resource_vector;
for (auto const &resource_pair : resource_map) {
resource_vector.push_back(CreateResourcePair(
fbb, fbb.CreateString(resource_pair.first), resource_pair.second));
}
return fbb.CreateVector(resource_vector);
}
std::vector<std::string> string_vec_from_flatbuf(
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> &flatbuf_vec) {
std::vector<std::string> string_vector;
+1 -20
View File
@@ -1,8 +1,7 @@
#ifndef COMMON_PROTOCOL_H
#define COMMON_PROTOCOL_H
#include "ray/gcs/format/gcs_generated.h"
#include <flatbuffers/flatbuffers.h>
#include <unordered_map>
#include "ray/common/id.h"
@@ -76,24 +75,6 @@ to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, const std::vector<ID> &ids);
/// @return The std::string version of the flatbuffer string.
std::string string_from_flatbuf(const flatbuffers::String &string);
/// Convert a std::unordered_map to a flatbuffer vector of pairs.
///
/// @param fbb Reference to the flatbuffer builder.
/// @param resource_map A mapping from resource name to resource quantity.
/// @return A flatbuffer vector of ResourcePair objects.
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<ResourcePair>>>
map_to_flatbuf(flatbuffers::FlatBufferBuilder &fbb,
const std::unordered_map<std::string, double> &resource_map);
/// Convert a flatbuffer vector of ResourcePair objects to a std::unordered map
/// from resource name to resource quantity.
///
/// @param fbb Reference to the flatbuffer builder.
/// @param resource_vector The flatbuffer object.
/// @return A map from resource name to resource quantity.
const std::unordered_map<std::string, double> map_from_flatbuf(
const flatbuffers::Vector<flatbuffers::Offset<ResourcePair>> &resource_vector);
std::vector<std::string> string_vec_from_flatbuf(
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> &flatbuf_vec);
+3 -3
View File
@@ -5,12 +5,14 @@
#include "ray/common/buffer.h"
#include "ray/common/id.h"
#include "ray/gcs/format/gcs_generated.h"
#include "ray/raylet/raylet_client.h"
#include "ray/raylet/task_spec.h"
namespace ray {
using rpc::Language;
using rpc::TaskType;
/// Type of this worker.
enum class WorkerType { WORKER, DRIVER };
@@ -66,8 +68,6 @@ class TaskArg {
const std::shared_ptr<Buffer> data_;
};
enum class TaskType { NORMAL_TASK, ACTOR_CREATION_TASK, ACTOR_TASK };
/// Information of a task
struct TaskInfo {
/// The ID of task.
+1 -1
View File
@@ -3,7 +3,7 @@
namespace ray {
CoreWorker::CoreWorker(const enum WorkerType worker_type, const ::Language language,
CoreWorker::CoreWorker(const enum WorkerType worker_type, const enum Language language,
const std::string &store_socket, const std::string &raylet_socket,
const JobID &job_id)
: worker_type_(worker_type),
+3 -4
View File
@@ -7,7 +7,6 @@
#include "ray/core_worker/object_interface.h"
#include "ray/core_worker/task_execution.h"
#include "ray/core_worker/task_interface.h"
#include "ray/gcs/format/gcs_generated.h"
#include "ray/raylet/raylet_client.h"
namespace ray {
@@ -23,7 +22,7 @@ class CoreWorker {
/// \param[in] language Language of this worker.
///
/// NOTE(zhijunfu): the constructor would throw if a failure happens.
CoreWorker(const WorkerType worker_type, const ::Language language,
CoreWorker(const WorkerType worker_type, const Language language,
const std::string &store_socket, const std::string &raylet_socket,
const JobID &job_id = JobID::Nil());
@@ -31,7 +30,7 @@ class CoreWorker {
enum WorkerType WorkerType() const { return worker_type_; }
/// Language of this worker.
::Language Language() const { return language_; }
enum Language Language() const { return language_; }
/// Return the `CoreWorkerTaskInterface` that contains the methods related to task
/// submission.
@@ -53,7 +52,7 @@ class CoreWorker {
const enum WorkerType worker_type_;
/// Language of this worker.
const ::Language language_;
const enum Language language_;
/// raylet socket name.
const std::string raylet_socket_;
+1 -2
View File
@@ -301,8 +301,7 @@ TEST_F(ZeroNodeTest, TestWorkerContext) {
}
TEST_F(ZeroNodeTest, TestActorHandle) {
ActorHandle handle1(ActorID::FromRandom(), ActorHandleID::FromRandom(),
::Language::JAVA,
ActorHandle handle1(ActorID::FromRandom(), ActorHandleID::FromRandom(), Language::JAVA,
{"org.ray.exampleClass", "exampleMethod", "exampleSignature"});
auto forkedHandle1 = handle1.Fork();
+60 -90
View File
@@ -8,11 +8,11 @@ namespace ray {
ActorHandle::ActorHandle(
const class ActorID &actor_id, const class ActorHandleID &actor_handle_id,
const ::Language actor_language,
const Language actor_language,
const std::vector<std::string> &actor_creation_task_function_descriptor) {
inner_.set_actor_id(actor_id.Data(), actor_id.Size());
inner_.set_actor_handle_id(actor_handle_id.Data(), actor_handle_id.Size());
inner_.set_actor_language(static_cast<int>(actor_language));
inner_.set_actor_language(actor_language);
*inner_.mutable_actor_creation_task_function_descriptor() = {
actor_creation_task_function_descriptor.begin(),
actor_creation_task_function_descriptor.end()};
@@ -30,9 +30,7 @@ ray::ActorHandleID ActorHandle::ActorHandleID() const {
return ActorHandleID::FromBinary(inner_.actor_handle_id());
};
::Language ActorHandle::ActorLanguage() const {
return (::Language)inner_.actor_language();
};
Language ActorHandle::ActorLanguage() const { return inner_.actor_language(); };
std::vector<std::string> ActorHandle::ActorCreationTaskFunctionDescriptor() const {
return ray::rpc::VectorFromProtobuf(inner_.actor_creation_task_function_descriptor());
@@ -100,30 +98,43 @@ CoreWorkerTaskInterface::CoreWorkerTaskInterface(
new CoreWorkerRayletTaskSubmitter(raylet_client)));
}
Status CoreWorkerTaskInterface::SubmitTask(const RayFunction &function,
const std::vector<TaskArg> &args,
const TaskOptions &task_options,
std::vector<ObjectID> *return_ids) {
auto &context = worker_context_;
auto next_task_index = context.GetNextTaskIndex();
const auto task_id = GenerateTaskId(context.GetCurrentJobID(),
context.GetCurrentTaskID(), next_task_index);
raylet::TaskSpecBuilder CoreWorkerTaskInterface::BuildCommonTaskSpec(
const RayFunction &function, const std::vector<TaskArg> &args, uint64_t num_returns,
const std::unordered_map<std::string, double> &required_resources,
const std::unordered_map<std::string, double> &required_placement_resources,
std::vector<ObjectID> *return_ids) {
raylet::TaskSpecBuilder builder;
auto next_task_index = worker_context_.GetNextTaskIndex();
// Build common task spec.
builder.SetCommonTaskSpec(
function.language, function.function_descriptor, worker_context_.GetCurrentJobID(),
worker_context_.GetCurrentTaskID(), next_task_index, num_returns,
required_resources, required_placement_resources);
// Set task arguments.
for (const auto &arg : args) {
if (arg.IsPassedByReference()) {
builder.AddByRefArg(arg.GetReference());
} else {
builder.AddByValueArg(arg.GetValue()->Data(), arg.GetValue()->Size());
}
}
auto num_returns = task_options.num_returns;
// Compute return IDs.
const auto task_id = TaskID::FromBinary(builder.GetMessage().task_id());
(*return_ids).resize(num_returns);
for (int i = 0; i < num_returns; i++) {
(*return_ids)[i] = ObjectID::ForTaskReturn(task_id, i + 1);
}
return builder;
}
auto task_arguments = BuildTaskArguments(args);
ray::raylet::TaskSpecification spec(context.GetCurrentJobID(),
context.GetCurrentTaskID(), next_task_index,
task_arguments, num_returns, task_options.resources,
function.language, function.function_descriptor);
std::vector<ObjectID> execution_dependencies;
TaskSpec task(std::move(spec), execution_dependencies);
Status CoreWorkerTaskInterface::SubmitTask(const RayFunction &function,
const std::vector<TaskArg> &args,
const TaskOptions &task_options,
std::vector<ObjectID> *return_ids) {
auto builder = BuildCommonTaskSpec(function, args, task_options.num_returns,
task_options.resources, {}, return_ids);
TaskSpec task(builder.Build(), {});
return task_submitters_[static_cast<int>(TaskTransportType::RAYLET)]->SubmitTask(task);
}
@@ -131,33 +142,20 @@ Status CoreWorkerTaskInterface::CreateActor(
const RayFunction &function, const std::vector<TaskArg> &args,
const ActorCreationOptions &actor_creation_options,
std::unique_ptr<ActorHandle> *actor_handle) {
auto &context = worker_context_;
auto next_task_index = context.GetNextTaskIndex();
const auto task_id = GenerateTaskId(context.GetCurrentJobID(),
context.GetCurrentTaskID(), next_task_index);
std::vector<ObjectID> return_ids;
return_ids.push_back(ObjectID::ForTaskReturn(task_id, 1));
ActorID actor_creation_id = ActorID::FromBinary(return_ids[0].Binary());
*actor_handle = std::unique_ptr<ActorHandle>(
new ActorHandle(actor_creation_id, ActorHandleID::Nil(), function.language,
function.function_descriptor));
auto builder = BuildCommonTaskSpec(function, args, 1, actor_creation_options.resources,
actor_creation_options.resources, &return_ids);
const ActorID actor_id = ActorID::FromBinary(return_ids[0].Binary());
builder.SetActorCreationTaskSpec(actor_id, actor_creation_options.max_reconstructions,
{});
*actor_handle = std::unique_ptr<ActorHandle>(new ActorHandle(
actor_id, ActorHandleID::Nil(), function.language, function.function_descriptor));
(*actor_handle)->IncreaseTaskCounter();
(*actor_handle)->SetActorCursor(return_ids[0]);
auto task_arguments = BuildTaskArguments(args);
// Note that the caller is supposed to specify required placement resources
// correctly via actor_creation_options.resources.
ray::raylet::TaskSpecification spec(
context.GetCurrentJobID(), context.GetCurrentTaskID(), next_task_index,
actor_creation_id, ObjectID::Nil(), actor_creation_options.max_reconstructions,
ActorID::Nil(), ActorHandleID::Nil(), 0, {}, task_arguments, 1,
actor_creation_options.resources, actor_creation_options.resources,
function.language, function.function_descriptor);
std::vector<ObjectID> execution_dependencies;
TaskSpec task(std::move(spec), execution_dependencies);
const TaskSpec task(builder.Build(), {});
return task_submitters_[static_cast<int>(TaskTransportType::RAYLET)]->SubmitTask(task);
}
@@ -166,65 +164,37 @@ Status CoreWorkerTaskInterface::SubmitActorTask(ActorHandle &actor_handle,
const std::vector<TaskArg> &args,
const TaskOptions &task_options,
std::vector<ObjectID> *return_ids) {
auto &context = worker_context_;
auto next_task_index = context.GetNextTaskIndex();
const auto task_id = GenerateTaskId(context.GetCurrentJobID(),
context.GetCurrentTaskID(), next_task_index);
// Add one for actor cursor object id.
const auto num_returns = task_options.num_returns + 1;
// add one for actor cursor object id.
auto num_returns = task_options.num_returns + 1;
(*return_ids).resize(num_returns);
for (int i = 0; i < num_returns; i++) {
(*return_ids)[i] = ObjectID::ForTaskReturn(task_id, i + 1);
}
auto actor_creation_dummy_object_id =
ObjectID::FromBinary(actor_handle.ActorID().Binary());
auto task_arguments = BuildTaskArguments(args);
// Build common task spec.
auto builder = BuildCommonTaskSpec(function, args, num_returns, task_options.resources,
{}, return_ids);
std::unique_lock<std::mutex> guard(actor_handle.mutex_);
// Build actor task spec.
const auto actor_creation_dummy_object_id =
ObjectID::FromBinary(actor_handle.ActorID().Binary());
builder.SetActorTaskSpec(actor_handle.ActorID(), actor_handle.ActorHandleID(),
actor_creation_dummy_object_id,
actor_handle.IncreaseTaskCounter(),
actor_handle.NewActorHandles());
ray::raylet::TaskSpecification spec(
context.GetCurrentJobID(), context.GetCurrentTaskID(), next_task_index,
ActorID::Nil(), actor_creation_dummy_object_id, 0, actor_handle.ActorID(),
actor_handle.ActorHandleID(), actor_handle.IncreaseTaskCounter(),
actor_handle.NewActorHandles(), task_arguments, num_returns, task_options.resources,
task_options.resources, function.language, function.function_descriptor);
std::vector<ObjectID> execution_dependencies;
execution_dependencies.push_back(actor_handle.ActorCursor());
const TaskSpec task(builder.Build(), {actor_handle.ActorCursor()});
// Manipulate actor handle state.
auto actor_cursor = (*return_ids).back();
actor_handle.SetActorCursor(actor_cursor);
actor_handle.ClearNewActorHandles();
guard.unlock();
TaskSpec task(std::move(spec), execution_dependencies);
// Submit task.
auto status =
task_submitters_[static_cast<int>(TaskTransportType::RAYLET)]->SubmitTask(task);
// remove cursor from return ids.
// Remove cursor from return ids.
(*return_ids).pop_back();
return status;
}
/// Convert user-facing task arguments into the argument objects required by a
/// task spec.
///
/// Arguments passed by reference become a `raylet::TaskArgumentByReference`
/// wrapping a single-element list of object IDs; all other arguments become a
/// `raylet::TaskArgumentByValue` built from the value's data pointer and size.
///
/// \param[in] args Arguments of a task.
/// \return One task-spec argument per input argument, in the same order.
std::vector<std::shared_ptr<raylet::TaskArgument>>
CoreWorkerTaskInterface::BuildTaskArguments(const std::vector<TaskArg> &args) {
  std::vector<std::shared_ptr<raylet::TaskArgument>> task_arguments;
  // Exactly one output element is produced per input argument, so allocate the
  // storage once instead of growing the vector inside the loop.
  task_arguments.reserve(args.size());
  for (const auto &arg : args) {
    if (arg.IsPassedByReference()) {
      std::vector<ObjectID> references{arg.GetReference()};
      task_arguments.push_back(
          std::make_shared<raylet::TaskArgumentByReference>(references));
    } else {
      auto data = arg.GetValue();
      task_arguments.push_back(
          std::make_shared<raylet::TaskArgumentByValue>(data->Data(), data->Size()));
    }
  }
  return task_arguments;
}
} // namespace ray
+20 -7
View File
@@ -9,10 +9,14 @@
#include "ray/core_worker/transport/transport.h"
#include "ray/protobuf/core_worker.pb.h"
#include "ray/raylet/task.h"
#include "ray/raylet/task_spec.h"
#include "ray/raylet/task_util.h"
#include "ray/rpc/util.h"
namespace ray {
using rpc::Language;
class CoreWorker;
/// Options of a non-actor-creation task.
@@ -45,7 +49,7 @@ struct ActorCreationOptions {
class ActorHandle {
public:
ActorHandle(const ActorID &actor_id, const ActorHandleID &actor_handle_id,
const ::Language actor_language,
const Language actor_language,
const std::vector<std::string> &actor_creation_task_function_descriptor);
ActorHandle(const ActorHandle &other);
@@ -57,7 +61,7 @@ class ActorHandle {
ray::ActorHandleID ActorHandleID() const;
/// Language of the actor.
::Language ActorLanguage() const;
Language ActorLanguage() const;
// Function descriptor of actor creation task.
std::vector<std::string> ActorCreationTaskFunctionDescriptor() const;
@@ -149,12 +153,21 @@ class CoreWorkerTaskInterface {
std::vector<ObjectID> *return_ids);
private:
/// Build the arguments for a task spec.
/// Build common attributes of the task spec, and compute return ids.
///
/// \param[in] args Arguments of a task.
/// \return Arguments as required by task spec.
std::vector<std::shared_ptr<raylet::TaskArgument>> BuildTaskArguments(
const std::vector<TaskArg> &args);
/// \param[in] function The remote function to execute.
/// \param[in] args Arguments of this task.
/// \param[in] num_returns Number of returns.
/// \param[in] required_resources Resources required by this task.
/// \param[in] required_placement_resources Resources required by placing this task on a
/// node.
/// \param[out] return_ids Return IDs.
/// \return A `TaskSpecBuilder`.
raylet::TaskSpecBuilder BuildCommonTaskSpec(
const RayFunction &function, const std::vector<TaskArg> &args, uint64_t num_returns,
const std::unordered_map<std::string, double> &required_resources,
const std::unordered_map<std::string, double> &required_placement_resources,
std::vector<ObjectID> *return_ids);
/// Reference to the parent CoreWorker's context.
WorkerContext &worker_context_;
@@ -38,11 +38,8 @@ CoreWorkerRayletTaskReceiver::CoreWorkerRayletTaskReceiver(
void CoreWorkerRayletTaskReceiver::HandleAssignTask(
const rpc::AssignTaskRequest &request, rpc::AssignTaskReply *reply,
rpc::RequestDoneCallback done_callback) {
const std::string &task_message = request.task_spec();
const raylet::Task task(*flatbuffers::GetRoot<protocol::Task>(
reinterpret_cast<const uint8_t *>(task_message.data())));
const raylet::Task task(request.task());
const auto &spec = task.GetTaskSpecification();
auto status = task_handler_(spec);
done_callback(status);
}
+51 -44
View File
@@ -13,10 +13,6 @@ namespace ray {
namespace gcs {
namespace {
constexpr char kRandomId[] = "abcdefghijklmnopqrst";
} // namespace
/* Flush redis. */
static inline void flushall_redis(void) {
redisContext *context = redisConnect("127.0.0.1", 6379);
@@ -82,23 +78,40 @@ class TestGcsWithChainAsio : public TestGcsWithAsio {
TestGcsWithChainAsio() : TestGcsWithAsio(gcs::CommandType::kChain){};
};
void TestTableLookup(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> client) {
TaskID task_id = TaskID::FromRandom();
/// A helper function that creates a GCS `TaskTableData` object.
std::shared_ptr<TaskTableData> CreateTaskTableData(const TaskID &task_id,
uint64_t num_returns = 0) {
auto data = std::make_shared<TaskTableData>();
data->set_task("123");
data->mutable_task()->mutable_task_spec()->set_task_id(task_id.Binary());
data->mutable_task()->mutable_task_spec()->set_num_returns(num_returns);
return data;
}
/// A helper function that checks whether two `TaskTableData` objects are equal.
/// Note, only the fields set by `CreateTaskTableData` are compared, i.e. the
/// task ID and the number of returns stored in the task spec.
bool TaskTableDataEqual(const TaskTableData &data1, const TaskTableData &data2) {
  const auto &lhs_spec = data1.task().task_spec();
  const auto &rhs_spec = data2.task().task_spec();
  if (!(lhs_spec.task_id() == rhs_spec.task_id())) {
    return false;
  }
  return lhs_spec.num_returns() == rhs_spec.num_returns();
}
void TestTableLookup(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> client) {
const auto task_id = TaskID::FromRandom();
const auto data = CreateTaskTableData(task_id);
// Check that we added the correct task.
auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id,
const TaskTableData &d) {
ASSERT_EQ(id, task_id);
ASSERT_EQ(data->task(), d.task());
ASSERT_TRUE(TaskTableDataEqual(*data, d));
};
// Check that the lookup returns the added task.
auto lookup_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id,
const TaskTableData &d) {
ASSERT_EQ(id, task_id);
ASSERT_EQ(data->task(), d.task());
ASSERT_TRUE(TaskTableDataEqual(*data, d));
test->Stop();
};
@@ -386,7 +399,7 @@ void TestDeleteKeysFromTable(const JobID &job_id,
auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id,
const TaskTableData &d) {
ASSERT_EQ(id, task_id);
ASSERT_EQ(data->task(), d.task());
ASSERT_TRUE(TaskTableDataEqual(*data, d));
test->IncrementNumCallbacks();
};
RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, add_callback));
@@ -501,9 +514,7 @@ void TestDeleteKeys(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> cl
std::vector<std::shared_ptr<TaskTableData>> task_vector;
auto AppendTaskData = [&task_vector](size_t add_count) {
for (size_t i = 0; i < add_count; ++i) {
auto task_data = std::make_shared<TaskTableData>();
task_data->set_task(ObjectID::FromRandom().Hex());
task_vector.push_back(task_data);
task_vector.push_back(CreateTaskTableData(TaskID::FromRandom()));
}
};
AppendTaskData(1);
@@ -682,25 +693,26 @@ TEST_F(TestGcsWithAsio, TestSetSubscribeAll) {
void TestTableSubscribeId(const JobID &job_id,
std::shared_ptr<gcs::AsyncGcsClient> client) {
int num_modifications = 3;
// Add a table entry.
TaskID task_id1 = TaskID::FromRandom();
std::vector<std::string> task_specs1 = {"abc", "def", "ghi"};
// Add a table entry at a second key.
TaskID task_id2 = TaskID::FromRandom();
std::vector<std::string> task_specs2 = {"jkl", "mno", "pqr"};
// The callback for a notification from the table. This should only be
// received for keys that we requested notifications for.
auto notification_callback = [task_id2, task_specs2](gcs::AsyncGcsClient *client,
const TaskID &id,
const TaskTableData &data) {
auto notification_callback = [task_id2, num_modifications](gcs::AsyncGcsClient *client,
const TaskID &id,
const TaskTableData &data) {
// Check that we only get notifications for the requested key.
ASSERT_EQ(id, task_id2);
// Check that we get notifications in the same order as the writes.
ASSERT_EQ(data.task(), task_specs2[test->NumCallbacks()]);
ASSERT_TRUE(
TaskTableDataEqual(data, *CreateTaskTableData(task_id2, test->NumCallbacks())));
test->IncrementNumCallbacks();
if (test->NumCallbacks() == task_specs2.size()) {
if (test->NumCallbacks() == num_modifications) {
test->Stop();
}
};
@@ -717,21 +729,19 @@ void TestTableSubscribeId(const JobID &job_id,
// The callback for subscription success. Once we've subscribed, request
// notifications for only one of the keys, then write to both keys.
auto subscribe_callback = [job_id, task_id1, task_id2, task_specs1,
task_specs2](gcs::AsyncGcsClient *client) {
auto subscribe_callback = [job_id, task_id1, task_id2,
num_modifications](gcs::AsyncGcsClient *client) {
// Request notifications for one of the keys.
RAY_CHECK_OK(client->raylet_task_table().RequestNotifications(
job_id, task_id2, client->client_table().GetLocalClientId()));
// Write both keys. We should only receive notifications for the key that
// we requested them for.
for (const auto &task_spec : task_specs1) {
auto data = std::make_shared<TaskTableData>();
data->set_task(task_spec);
for (uint64_t i = 0; i < num_modifications; i++) {
auto data = CreateTaskTableData(task_id1, i);
RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id1, data, nullptr));
}
for (const auto &task_spec : task_specs2) {
auto data = std::make_shared<TaskTableData>();
data->set_task(task_spec);
for (uint64_t i = 0; i < num_modifications; i++) {
auto data = CreateTaskTableData(task_id2, i);
RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id2, data, nullptr));
}
};
@@ -749,7 +759,7 @@ void TestTableSubscribeId(const JobID &job_id,
ASSERT_TRUE(failure_notification_received);
// Check that we received one notification callback for each write to the
// requested key.
ASSERT_EQ(test->NumCallbacks(), task_specs2.size());
ASSERT_EQ(test->NumCallbacks(), num_modifications);
}
TEST_MACRO(TestGcsWithAsio, TestTableSubscribeId);
@@ -910,10 +920,9 @@ TEST_F(TestGcsWithAsio, TestSetSubscribeId) {
void TestTableSubscribeCancel(const JobID &job_id,
std::shared_ptr<gcs::AsyncGcsClient> client) {
// Add a table entry.
TaskID task_id = TaskID::FromRandom();
std::vector<std::string> task_specs = {"jkl", "mno", "pqr"};
auto data = std::make_shared<TaskTableData>();
data->set_task(task_specs[0]);
const auto task_id = TaskID::FromRandom();
const int num_modifications = 3;
const auto data = CreateTaskTableData(task_id, 0);
RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, nullptr));
// The failure callback should not be called since all keys are non-empty
@@ -924,26 +933,26 @@ void TestTableSubscribeCancel(const JobID &job_id,
// The callback for a notification from the table. This should only be
// received for keys that we requested notifications for.
auto notification_callback = [task_id, task_specs](gcs::AsyncGcsClient *client,
const TaskID &id,
const TaskTableData &data) {
auto notification_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id,
const TaskTableData &data) {
ASSERT_EQ(id, task_id);
// Check that we only get notifications for the first and last writes,
// since notifications are canceled in between.
if (test->NumCallbacks() == 0) {
ASSERT_EQ(data.task(), task_specs.front());
ASSERT_TRUE(TaskTableDataEqual(data, *CreateTaskTableData(task_id, 0)));
} else {
ASSERT_EQ(data.task(), task_specs.back());
ASSERT_TRUE(
TaskTableDataEqual(data, *CreateTaskTableData(task_id, num_modifications - 1)));
}
test->IncrementNumCallbacks();
if (test->NumCallbacks() == 2) {
if (test->NumCallbacks() == num_modifications - 1) {
test->Stop();
}
};
// The callback for a notification from the table. This should only be
// received for keys that we requested notifications for.
auto subscribe_callback = [job_id, task_id, task_specs](gcs::AsyncGcsClient *client) {
auto subscribe_callback = [job_id, task_id](gcs::AsyncGcsClient *client) {
// Request notifications, then cancel immediately. We should receive a
// notification for the current value at the key.
RAY_CHECK_OK(client->raylet_task_table().RequestNotifications(
@@ -952,10 +961,8 @@ void TestTableSubscribeCancel(const JobID &job_id,
job_id, task_id, client->client_table().GetLocalClientId()));
// Write to the key. Since we canceled notifications, we should not receive
// a notification for these writes.
auto remaining = std::vector<std::string>(++task_specs.begin(), task_specs.end());
for (const auto &task_spec : remaining) {
auto data = std::make_shared<TaskTableData>();
data->set_task(task_spec);
for (uint64_t i = 1; i < num_modifications; i++) {
auto data = CreateTaskTableData(task_id, i);
RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, nullptr));
}
// Request notifications again. We should receive a notification for the
-105
View File
@@ -1,105 +0,0 @@
// TODO(hchen): Migrate data structures in this file to protobuf (`gcs.proto`).
// The language that a task belongs to (see `TaskInfo.language`).
// Stored as an `int`.
enum Language:int {
PYTHON=0,
JAVA=1,
CPP=2,
}
// A single task argument, passed either by reference (`object_ids`) or by
// value (`data`).
table Arg {
// Object ID for pass-by-reference arguments. Normally there is only one
// object ID in this list which represents the object that is being passed.
// However to support reducers in a MapReduce workload, we also support
// passing multiple object IDs for each argument.
// Note that this is one long string that concatenates all of the object IDs.
object_ids: string;
// Data for pass-by-value arguments.
data: string;
}
// Describes a task: its identity and parent, actor-related fields, arguments,
// resource requirements, and the function to execute.
table TaskInfo {
// ID of the job that created this task.
job_id: string;
// Task ID of the task.
task_id: string;
// Task ID of the parent task.
parent_task_id: string;
// A count of the number of tasks submitted by the parent task before this one.
parent_counter: int;
// The ID of the actor to create if this is an actor creation task.
actor_creation_id: string;
// The dummy object ID of the actor creation task if this is an actor method.
actor_creation_dummy_object_id: string;
// The max number of times this actor should be reconstructed.
// If this number is 0 or negative, the actor won't be reconstructed on failure.
max_actor_reconstructions: int;
// Actor ID of the task. This is the actor that this task is executed on
// or NIL_ACTOR_ID if the task is just a normal task.
actor_id: string;
// The ID of the handle that was used to submit the task. This should be
// unique across handles with the same actor_id.
actor_handle_id: string;
// Number of tasks that have been submitted to this actor so far.
actor_counter: int;
// If this is an actor task, then this will be populated with all of the new
// actor handles that were forked from this handle since the last task on
// this handle was submitted.
// Note that this is one long string that concatenates all of the
// new_actor_handle IDs.
new_actor_handles: string;
// Task arguments.
args: [Arg];
// Number of return objects.
num_returns: int;
// The required_resources vector indicates the quantities of the different
// resources required by this task.
required_resources: [ResourcePair];
// The resources required for placing this task on a node. If this is empty,
// then the placement resources are equal to the required_resources.
required_placement_resources: [ResourcePair];
// The language that this task belongs to.
language: Language;
// Function descriptor, which is a list of strings that can
// uniquely describe a function.
// For a Python function, it should be: [module_name, class_name, function_name]
// For a Java function, it should be: [class_name, method_name, type_descriptor]
function_descriptor: [string];
// The dynamic options used in the worker command when starting the worker process for
// an actor creation task. If the list isn't empty, the options will be used to replace
// the placeholder strings (`RAY_WORKER_OPTION_0`, `RAY_WORKER_OPTION_1`, etc) in the
// worker command.
dynamic_worker_options: [string];
}
// A named resource quantity. Used as the element type of
// `TaskInfo.required_resources` and `TaskInfo.required_placement_resources`.
table ResourcePair {
// The name of the resource.
key: string;
// The quantity of the resource.
value: double;
}
// A single profiling event. Events are grouped into batches in
// `ProfileTableData.profile_events`.
table ProfileEvent {
// The type of the event.
event_type: string;
// The start time of the event.
start_time: double;
// The end time of the event. If the event is a point event, then this should
// be the same as the start time.
end_time: double;
// Additional data associated with the event. This data must be serialized
// using JSON.
extra_data: string;
}
// A batch of profiling events generated by a single component.
table ProfileTableData {
// The type of the component that generated the event, e.g., worker,
// object_manager, or node_manager.
component_type: string;
// An identifier for the component that generated the event.
component_id: string;
// An identifier for the node that generated the event.
node_ip_address: string;
// This is a batch of profiling events. We batch these together for
// performance reasons because a single task may generate many events, and
// we don't want each event to require a GCS command.
profile_events: [ProfileEvent];
}
-24
View File
@@ -1,24 +0,0 @@
#ifndef RAY_RAYLET_GCS_FORMAT_UTIL_H
#define RAY_RAYLET_GCS_FORMAT_UTIL_H
#include "ray/gcs/format/gcs_generated.h"
namespace std {

/// Hash support for `Language`: the enumerator is converted to `int32_t` and
/// hashed with the standard integer hash.
template <>
struct hash<Language> {
  size_t operator()(const Language &language) const {
    const auto raw_value = static_cast<int32_t>(language);
    return std::hash<int32_t>{}(raw_value);
  }
};

/// Same hash as above, provided for the `const`-qualified type.
template <>
struct hash<const Language> {
  size_t operator()(const Language &language) const {
    const auto raw_value = static_cast<int32_t>(language);
    return std::hash<int32_t>{}(raw_value);
  }
};

}  // namespace std
#endif // RAY_RAYLET_GCS_FORMAT_UTIL_H
@@ -4,7 +4,6 @@
#include "ray/common/common_protocol.h"
#include "ray/common/id.h"
#include "ray/common/status.h"
#include "ray/gcs/format/gcs_generated.h"
#include "ray/protobuf/gcs.pb.h"
#include "ray/util/logging.h"
#include "redis_string.h"
+11 -7
View File
@@ -40,7 +40,8 @@ namespace gcs {
template <typename ID, typename Data>
Status Log<ID, Data>::Append(const JobID &job_id, const ID &id,
std::shared_ptr<Data> &data, const WriteCallback &done) {
const std::shared_ptr<Data> &data,
const WriteCallback &done) {
num_appends_++;
auto callback = [this, id, data, done](const CallbackReply &reply) {
const auto status = reply.ReadAsStatus();
@@ -59,8 +60,9 @@ Status Log<ID, Data>::Append(const JobID &job_id, const ID &id,
template <typename ID, typename Data>
Status Log<ID, Data>::AppendAt(const JobID &job_id, const ID &id,
std::shared_ptr<Data> &data, const WriteCallback &done,
const WriteCallback &failure, int log_length) {
const std::shared_ptr<Data> &data,
const WriteCallback &done, const WriteCallback &failure,
int log_length) {
num_appends_++;
auto callback = [this, id, data, done, failure](const CallbackReply &reply) {
const auto status = reply.ReadAsStatus();
@@ -226,7 +228,8 @@ std::string Log<ID, Data>::DebugString() const {
template <typename ID, typename Data>
Status Table<ID, Data>::Add(const JobID &job_id, const ID &id,
std::shared_ptr<Data> &data, const WriteCallback &done) {
const std::shared_ptr<Data> &data,
const WriteCallback &done) {
num_adds_++;
auto callback = [this, id, data, done](const CallbackReply &reply) {
if (done != nullptr) {
@@ -288,8 +291,8 @@ std::string Table<ID, Data>::DebugString() const {
}
template <typename ID, typename Data>
Status Set<ID, Data>::Add(const JobID &job_id, const ID &id, std::shared_ptr<Data> &data,
const WriteCallback &done) {
Status Set<ID, Data>::Add(const JobID &job_id, const ID &id,
const std::shared_ptr<Data> &data, const WriteCallback &done) {
num_adds_++;
auto callback = [this, id, data, done](const CallbackReply &reply) {
if (done != nullptr) {
@@ -303,7 +306,8 @@ Status Set<ID, Data>::Add(const JobID &job_id, const ID &id, std::shared_ptr<Dat
template <typename ID, typename Data>
Status Set<ID, Data>::Remove(const JobID &job_id, const ID &id,
std::shared_ptr<Data> &data, const WriteCallback &done) {
const std::shared_ptr<Data> &data,
const WriteCallback &done) {
num_removes_++;
auto callback = [this, id, data, done](const CallbackReply &reply) {
if (done != nullptr) {
+15 -14
View File
@@ -67,10 +67,10 @@ class LogInterface {
public:
using WriteCallback =
std::function<void(AsyncGcsClient *client, const ID &id, const Data &data)>;
virtual Status Append(const JobID &job_id, const ID &id, std::shared_ptr<Data> &data,
const WriteCallback &done) = 0;
virtual Status Append(const JobID &job_id, const ID &id,
const std::shared_ptr<Data> &data, const WriteCallback &done) = 0;
virtual Status AppendAt(const JobID &job_id, const ID &task_id,
std::shared_ptr<Data> &data, const WriteCallback &done,
const std::shared_ptr<Data> &data, const WriteCallback &done,
const WriteCallback &failure, int log_length) = 0;
virtual ~LogInterface(){};
};
@@ -126,7 +126,7 @@ class Log : public LogInterface<ID, Data>, virtual public PubsubInterface<ID> {
/// \param done Callback that is called once the data has been written to the
/// GCS.
/// \return Status
Status Append(const JobID &job_id, const ID &id, std::shared_ptr<Data> &data,
Status Append(const JobID &job_id, const ID &id, const std::shared_ptr<Data> &data,
const WriteCallback &done);
/// Append a log entry to a key if and only if the log has the given number
@@ -141,7 +141,7 @@ class Log : public LogInterface<ID, Data>, virtual public PubsubInterface<ID> {
/// \param log_length The number of entries that the log must have for the
/// append to succeed.
/// \return Status
Status AppendAt(const JobID &job_id, const ID &id, std::shared_ptr<Data> &data,
Status AppendAt(const JobID &job_id, const ID &id, const std::shared_ptr<Data> &data,
const WriteCallback &done, const WriteCallback &failure,
int log_length);
@@ -272,8 +272,8 @@ template <typename ID, typename Data>
class TableInterface {
public:
using WriteCallback = typename Log<ID, Data>::WriteCallback;
virtual Status Add(const JobID &job_id, const ID &task_id, std::shared_ptr<Data> &data,
const WriteCallback &done) = 0;
virtual Status Add(const JobID &job_id, const ID &task_id,
const std::shared_ptr<Data> &data, const WriteCallback &done) = 0;
virtual ~TableInterface(){};
};
@@ -315,7 +315,7 @@ class Table : private Log<ID, Data>,
/// \param done Callback that is called once the data has been written to the
/// GCS.
/// \return Status
Status Add(const JobID &job_id, const ID &id, std::shared_ptr<Data> &data,
Status Add(const JobID &job_id, const ID &id, const std::shared_ptr<Data> &data,
const WriteCallback &done);
/// Lookup an entry asynchronously.
@@ -378,10 +378,10 @@ template <typename ID, typename Data>
class SetInterface {
public:
using WriteCallback = typename Log<ID, Data>::WriteCallback;
virtual Status Add(const JobID &job_id, const ID &id, std::shared_ptr<Data> &data,
virtual Status Add(const JobID &job_id, const ID &id, const std::shared_ptr<Data> &data,
const WriteCallback &done) = 0;
virtual Status Remove(const JobID &job_id, const ID &id, std::shared_ptr<Data> &data,
const WriteCallback &done) = 0;
virtual Status Remove(const JobID &job_id, const ID &id,
const std::shared_ptr<Data> &data, const WriteCallback &done) = 0;
virtual ~SetInterface(){};
};
@@ -420,7 +420,7 @@ class Set : private Log<ID, Data>,
/// \param done Callback that is called once the data has been written to the
/// GCS.
/// \return Status
Status Add(const JobID &job_id, const ID &id, std::shared_ptr<Data> &data,
Status Add(const JobID &job_id, const ID &id, const std::shared_ptr<Data> &data,
const WriteCallback &done);
/// Remove an entry from the set.
@@ -431,7 +431,7 @@ class Set : private Log<ID, Data>,
/// \param done Callback that is called once the data has been written to the
/// GCS.
/// \return Status
Status Remove(const JobID &job_id, const ID &id, std::shared_ptr<Data> &data,
Status Remove(const JobID &job_id, const ID &id, const std::shared_ptr<Data> &data,
const WriteCallback &done);
Status Subscribe(const JobID &job_id, const ClientID &client_id,
@@ -695,7 +695,8 @@ class TaskLeaseTable : public Table<TaskID, TaskLeaseData> {
prefix_ = TablePrefix::TASK_LEASE;
}
Status Add(const JobID &job_id, const TaskID &id, std::shared_ptr<TaskLeaseData> &data,
Status Add(const JobID &job_id, const TaskID &id,
const std::shared_ptr<TaskLeaseData> &data,
const WriteCallback &done) override {
RAY_RETURN_NOT_OK((Table<TaskID, TaskLeaseData>::Add(job_id, id, data, done)));
// Mark the entry for expiration in Redis. It's okay if this command fails
+123
View File
@@ -0,0 +1,123 @@
syntax = "proto3";
package ray.rpc;
option java_package = "org.ray.runtime.generated";
// Language of a task or worker.
// The numeric values are what gets serialized, so existing values must not be
// renumbered.
enum Language {
PYTHON = 0;
JAVA = 1;
CPP = 2;
}
// Type of a task. `TaskSpec.type` uses this to indicate which of the
// type-specific sub-messages (`actor_creation_task_spec`, `actor_task_spec`)
// is valid.
enum TaskType {
// Normal task.
NORMAL_TASK = 0;
// Actor creation task.
ACTOR_CREATION_TASK = 1;
// Actor task.
ACTOR_TASK = 2;
}
/// The task specification encapsulates all immutable information about the
/// task. These fields are determined at submission time, in contrast to
/// `TaskExecutionSpec`, whose fields may change at execution time.
message TaskSpec {
// Type of this task.
TaskType type = 1;
// Language of this task.
Language language = 2;
// Function descriptor of this task, which is a list of strings that can
// uniquely describe the function to execute.
// For a Python function, it should be: [module_name, class_name, function_name]
// For a Java function, it should be: [class_name, method_name, type_descriptor]
repeated bytes function_descriptor = 3;
// ID of the job that this task belongs to.
bytes job_id = 4;
// Task ID of the task.
bytes task_id = 5;
// Task ID of the parent task.
bytes parent_task_id = 6;
// A count of the number of tasks submitted by the parent task before this one.
uint64 parent_counter = 7;
// Task arguments.
repeated TaskArg args = 8;
// Number of return objects.
uint64 num_returns = 9;
// Quantities of the different resources required by this task.
map<string, double> required_resources = 10;
// The resources required for placing this task on a node. If this is empty,
// then the placement resources are equal to the required_resources.
map<string, double> required_placement_resources = 11;
// Task specification for an actor creation task.
// This field is only valid when `type == ACTOR_CREATION_TASK`.
ActorCreationTaskSpec actor_creation_task_spec = 14;
// Task specification for an actor task.
// This field is only valid when `type == ACTOR_TASK`.
ActorTaskSpec actor_task_spec = 15;
}
// Argument in the task. An argument is passed either by reference
// (`object_ids`) or by value (`data`).
message TaskArg {
// Object IDs for pass-by-reference arguments. Normally there is only one
// object ID in this list which represents the object that is being passed.
// However to support reducers in a MapReduce workload, we also support
// passing multiple object IDs for each argument.
repeated bytes object_ids = 1;
// Data for pass-by-value arguments.
bytes data = 2;
}
// Task spec of an actor creation task.
message ActorCreationTaskSpec {
// ID of the actor that will be created by this task.
bytes actor_id = 2;
// The max number of times this actor should be reconstructed.
// If this number is 0, the actor won't be reconstructed on failure.
// NOTE(review): the field is unsigned, so it cannot hold a negative value.
uint64 max_actor_reconstructions = 3;
// The dynamic options used in the worker command when starting the worker process for
// an actor creation task. If the list isn't empty, the options will be used to replace
// the placeholder strings (`RAY_WORKER_OPTION_0`, `RAY_WORKER_OPTION_1`, etc) in the
// worker command.
repeated string dynamic_worker_options = 4;
}
// Task spec of an actor task.
message ActorTaskSpec {
// Actor ID of the task. This is the actor that this task is executed on.
// NOTE(review): the earlier wording also allowed NIL_ACTOR_ID for normal
// tasks; confirm whether this message is ever populated for a non-actor task.
bytes actor_id = 2;
// The ID of the handle that was used to submit the task. This should be
// unique across handles with the same actor_id.
bytes actor_handle_id = 3;
// The dummy object ID of the actor creation task if this is an actor method.
bytes actor_creation_dummy_object_id = 4;
// Number of tasks that have been submitted to this actor so far.
uint64 actor_counter = 5;
// If this is an actor task, then this will be populated with all of the new
// actor handles that were forked from this handle since the last task on
// this handle was submitted.
// This is a repeated field rather than one concatenated string; presumably
// each element holds one new actor handle ID (verify against the writer).
repeated bytes new_actor_handles = 6;
}
// The task execution specification encapsulates all mutable information about
// the task. These fields may change at execution time, in contrast to
// `TaskSpec`, which is determined at submission time.
message TaskExecutionSpec {
// A list of object IDs representing the dependencies of this task that may
// change at execution time.
repeated bytes dependencies = 1;
// The last time this task was received for scheduling.
double last_timestamp = 2;
// The number of times this task was spilled back by raylets.
uint64 num_forwards = 3;
}
// Represents a task, including task spec, and task execution spec.
message Task {
// The immutable part of the task, determined at submission time.
TaskSpec task_spec = 1;
// The mutable part of the task, which may change at execution time.
TaskExecutionSpec task_execution_spec = 2;
}
+3 -1
View File
@@ -2,6 +2,8 @@ syntax = "proto3";
package ray.rpc;
import "src/ray/protobuf/common.proto";
message ActorHandle {
// ID of the actor.
bytes actor_id = 1;
@@ -10,7 +12,7 @@ message ActorHandle {
bytes actor_handle_id = 2;
// Language of the actor.
int32 actor_language = 3;
Language actor_language = 3;
// Function descriptor of actor creation task.
repeated string actor_creation_task_function_descriptor = 4;
+3 -12
View File
@@ -2,14 +2,9 @@ syntax = "proto3";
package ray.rpc;
option java_package = "org.ray.runtime.generated";
import "src/ray/protobuf/common.proto";
// Language of a worker or task.
enum Language {
PYTHON = 0;
CPP = 1;
JAVA = 2;
}
option java_package = "org.ray.runtime.generated";
// These indexes are mapped to strings in ray_redis_module.cc.
enum TablePrefix {
@@ -77,12 +72,8 @@ message TaskReconstructionData {
bytes node_manager_id = 2;
}
// TODO(hchen): Task table currently still uses flatbuffers-defined data structure
// (`Task` in `node_manager.fbs`), because a lot of code depends on that. This should
// be migrated to protobuf very soon.
message TaskTableData {
// Flatbuffers-serialized content of the task, see `src/ray/raylet/task.h`.
bytes task = 1;
Task task = 1;
}
message ActorTableData {
+3 -5
View File
@@ -2,16 +2,14 @@ syntax = "proto3";
package ray.rpc;
import "src/ray/protobuf/common.proto";
message ForwardTaskRequest {
// The ID of the task to be forwarded.
bytes task_id = 1;
// The tasks in the uncommitted lineage of the forwarded task. This
// should include task_id.
// TODO(hchen): Currently, `uncommitted_tasks` are represented as
// flatbuffers-serialized bytes. This is because the flatbuffers-defined Task data
// structure is being used in many places. We should move Task and all related data
// structures to protobuf.
repeated bytes uncommitted_tasks = 2;
repeated Task uncommitted_tasks = 2;
}
message ForwardTaskReply {
+4 -8
View File
@@ -2,15 +2,11 @@ syntax = "proto3";
package ray.rpc;
import "src/ray/protobuf/common.proto";
message AssignTaskRequest {
// The ID of the task to be pushed.
bytes task_id = 1;
// The task to be pushed. This should include task_id.
// TODO(hchen): Currently, `task_spec` is represented as
// flatbuffers-serialized bytes. This is because the flatbuffers-defined Task data
// structure is being used in many places. We should move Task and all related data
// structures to protobuf.
bytes task_spec = 2;
// The task to be pushed.
Task task = 1;
}
message AssignTaskReply {
+2 -4
View File
@@ -1,8 +1,5 @@
// raylet protocol specification
include "gcs.fbs";
// TODO(swang): We put the flatbuffer types in a separate namespace for now to
// avoid conflicts with legacy Ray types.
namespace ray.protocol;
@@ -137,7 +134,8 @@ table RegisterClientRequest {
// The job ID if the client is a driver, otherwise it should be NIL.
job_id: string;
// Language of this worker.
language: Language;
// TODO(hchen): Use `Language` in `common.proto`.
language: int;
// Port that this worker is listening on.
// If port > 0, then worker will listen to this port and wait for
// raylet to push tasks, instead of invoking GetTask().
@@ -59,8 +59,7 @@ JNIEXPORT jlong JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeInit(
* Signature: (J[BLjava/nio/ByteBuffer;II)V
*/
JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmitTask(
JNIEnv *env, jclass, jlong client, jbyteArray cursorId, jobject taskBuff, jint pos,
jint taskSize) {
JNIEnv *env, jclass, jlong client, jbyteArray cursorId, jbyteArray taskSpec) {
auto raylet_client = reinterpret_cast<RayletClient *>(client);
std::vector<ObjectID> execution_dependencies;
@@ -69,8 +68,13 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmit
execution_dependencies.push_back(cursor_id.GetId());
}
auto data = reinterpret_cast<uint8_t *>(env->GetDirectBufferAddress(taskBuff)) + pos;
ray::raylet::TaskSpecification task_spec(data, taskSize);
jbyte *data = env->GetByteArrayElements(taskSpec, NULL);
jsize size = env->GetArrayLength(taskSpec);
ray::rpc::TaskSpec task_spec_message;
task_spec_message.ParseFromArray(data, size);
env->ReleaseByteArrayElements(taskSpec, data, JNI_ABORT);
ray::raylet::TaskSpecification task_spec(task_spec_message);
auto status = raylet_client->SubmitTask(execution_dependencies, task_spec);
ThrowRayExceptionIfNotOK(env, status);
}
@@ -90,24 +94,16 @@ JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_native
return nullptr;
}
// We serialize the task specification using flatbuffers and then parse the
// resulting string. This awkwardness is due to the fact that the Java
// implementation does not use the underlying C++ TaskSpecification class.
flatbuffers::FlatBufferBuilder fbb;
auto message = spec->ToFlatbuffer(fbb);
fbb.Finish(message);
auto task_message = flatbuffers::GetRoot<flatbuffers::String>(fbb.GetBufferPointer());
// Serialize the task spec and copy to Java byte array.
auto task_data = spec->Serialize();
jbyteArray result;
result = env->NewByteArray(task_message->size());
jbyteArray result = env->NewByteArray(task_data.size());
if (result == nullptr) {
return nullptr; /* out of memory error thrown */
}
// move from task spec structure to the java structure
env->SetByteArrayRegion(
result, 0, task_message->size(),
reinterpret_cast<jbyte *>(const_cast<char *>(task_message->data())));
env->SetByteArrayRegion(result, 0, task_data.size(),
reinterpret_cast<const jbyte *>(task_data.data()));
return result;
}
@@ -20,10 +20,10 @@ JNIEXPORT jlong JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeInit(
/*
* Class: org_ray_runtime_raylet_RayletClientImpl
* Method: nativeSubmitTask
* Signature: (J[BLjava/nio/ByteBuffer;II)V
* Signature: (J[B[B)V
*/
JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmitTask(
JNIEnv *, jclass, jlong, jbyteArray, jobject, jint, jint);
JNIEnv *, jclass, jlong, jbyteArray, jbyteArray);
/*
* Class: org_ray_runtime_raylet_RayletClientImpl
+4 -2
View File
@@ -272,9 +272,11 @@ void LineageCache::FlushTask(const TaskID &task_id) {
[this](ray::gcs::AsyncGcsClient *client, const TaskID &id,
const TaskTableData &data) { HandleEntryCommitted(id); };
auto task = lineage_.GetEntry(task_id);
// TODO(swang): Make this better...
auto task_data = std::make_shared<TaskTableData>();
task_data->set_task(task->TaskData().Serialize());
task_data->mutable_task()->mutable_task_spec()->CopyFrom(
task->TaskData().GetTaskSpecification().GetMessage());
task_data->mutable_task()->mutable_task_execution_spec()->CopyFrom(
task->TaskData().GetTaskExecutionSpec().GetMessage());
RAY_CHECK_OK(task_storage_.Add(JobID(task->TaskData().GetTaskSpecification().JobId()),
task_id, task_data, task_callback));
+11 -15
View File
@@ -8,6 +8,7 @@
#include "ray/raylet/task.h"
#include "ray/raylet/task_execution_spec.h"
#include "ray/raylet/task_spec.h"
#include "ray/raylet/task_util.h"
namespace ray {
@@ -23,7 +24,7 @@ class MockGcs : public gcs::TableInterface<TaskID, TaskTableData>,
}
Status Add(const JobID &job_id, const TaskID &task_id,
std::shared_ptr<TaskTableData> &task_data,
const std::shared_ptr<TaskTableData> &task_data,
const gcs::TableInterface<TaskID, TaskTableData>::WriteCallback &done) {
task_table_[task_id] = task_data;
auto callback = done;
@@ -125,21 +126,16 @@ class LineageCacheTest : public ::testing::Test {
};
static inline Task ExampleTask(const std::vector<ObjectID> &arguments,
int64_t num_returns) {
std::unordered_map<std::string, double> required_resources;
std::vector<std::shared_ptr<TaskArgument>> task_arguments;
for (auto &argument : arguments) {
std::vector<ObjectID> references = {argument};
task_arguments.emplace_back(std::make_shared<TaskArgumentByReference>(references));
uint64_t num_returns) {
TaskSpecBuilder builder;
builder.SetCommonTaskSpec(Language::PYTHON, {"", "", ""}, JobID::Nil(),
TaskID::FromRandom(), 0, num_returns, {}, {});
for (const auto &arg : arguments) {
builder.AddByRefArg(arg);
}
std::vector<std::string> function_descriptor(3);
auto spec = TaskSpecification(JobID::Nil(), TaskID::FromRandom(), 0, task_arguments,
num_returns, required_resources, Language::PYTHON,
function_descriptor);
auto execution_spec = TaskExecutionSpecification(std::vector<ObjectID>());
execution_spec.IncrementNumForwards();
Task task = Task(execution_spec, spec);
return task;
rpc::TaskExecutionSpec execution_spec_message;
execution_spec_message.set_num_forwards(1);
return Task(builder.Build(), TaskExecutionSpecification(execution_spec_message));
}
/// Helper method to create a Lineage object with a single task.
+3
View File
@@ -2,11 +2,14 @@
#include "ray/common/ray_config.h"
#include "ray/common/status.h"
#include "ray/protobuf/common.pb.h"
#include "ray/raylet/raylet.h"
#include "ray/stats/stats.h"
#include "gflags/gflags.h"
using ray::rpc::Language;
DEFINE_string(raylet_socket_name, "", "The socket name of raylet.");
DEFINE_string(store_socket_name, "", "The socket name of object store.");
DEFINE_int32(object_manager_port, -1, "The port of object manager.");
+26 -30
View File
@@ -797,19 +797,10 @@ void NodeManager::ProcessClientMessage(
ProcessPushErrorRequestMessage(message_data);
} break;
case protocol::MessageType::PushProfileEventsRequest: {
ProfileTableDataT fbs_message;
flatbuffers::GetRoot<ProfileTableData>(message_data)->UnPackTo(&fbs_message);
auto fbs_message = flatbuffers::GetRoot<flatbuffers::String>(message_data);
rpc::ProfileTableData profile_table_data;
profile_table_data.set_component_type(fbs_message.component_type);
profile_table_data.set_component_id(fbs_message.component_id);
for (const auto &fbs_event : fbs_message.profile_events) {
rpc::ProfileTableData::ProfileEvent *event =
profile_table_data.add_profile_events();
event->set_event_type(fbs_event->event_type);
event->set_start_time(fbs_event->start_time);
event->set_end_time(fbs_event->end_time);
event->set_extra_data(fbs_event->extra_data);
}
RAY_CHECK(
profile_table_data.ParseFromArray(fbs_message->data(), fbs_message->size()));
RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(profile_table_data));
} break;
case protocol::MessageType::FreeObjectsInObjectStoreRequest: {
@@ -845,8 +836,9 @@ void NodeManager::ProcessRegisterClientRequestMessage(
const std::shared_ptr<LocalClientConnection> &client, const uint8_t *message_data) {
auto message = flatbuffers::GetRoot<protocol::RegisterClientRequest>(message_data);
client->SetClientID(from_flatbuf<ClientID>(*message->worker_id()));
auto worker = std::make_shared<Worker>(message->worker_pid(), message->language(),
message->port(), client);
Language language = static_cast<Language>(message->language());
auto worker =
std::make_shared<Worker>(message->worker_pid(), language, message->port(), client);
if (message->is_worker()) {
// Register the new worker.
worker_pool_.RegisterWorker(std::move(worker));
@@ -1050,14 +1042,18 @@ void NodeManager::ProcessDisconnectClientMessage(
void NodeManager::ProcessSubmitTaskMessage(const uint8_t *message_data) {
// Read the task submitted by the client.
auto message = flatbuffers::GetRoot<protocol::SubmitTaskRequest>(message_data);
TaskExecutionSpecification task_execution_spec(
from_flatbuf<ObjectID>(*message->execution_dependencies()));
TaskSpecification task_spec(*message->task_spec());
Task task(task_execution_spec, task_spec);
auto fbs_message = flatbuffers::GetRoot<protocol::SubmitTaskRequest>(message_data);
rpc::Task task_message;
RAY_CHECK(task_message.mutable_task_spec()->ParseFromArray(
fbs_message->task_spec()->data(), fbs_message->task_spec()->size()));
for (const auto &dependency :
string_vec_from_flatbuf(*fbs_message->execution_dependencies())) {
task_message.mutable_task_execution_spec()->add_dependencies(dependency);
}
// Submit the task to the raylet. Since the task was submitted
// locally, there is no uncommitted lineage.
SubmitTask(task, Lineage());
SubmitTask(Task(task_message), Lineage());
}
void NodeManager::ProcessFetchOrReconstructMessage(
@@ -1224,10 +1220,8 @@ void NodeManager::HandleForwardTask(const rpc::ForwardTaskRequest &request,
TaskID task_id = TaskID::FromBinary(request.task_id());
Lineage uncommitted_lineage;
for (int i = 0; i < request.uncommitted_tasks_size(); i++) {
const std::string &task_message = request.uncommitted_tasks(i);
const Task task(*flatbuffers::GetRoot<protocol::Task>(
reinterpret_cast<const uint8_t *>(task_message.data())));
RAY_CHECK(uncommitted_lineage.SetEntry(std::move(task), GcsStatus::UNCOMMITTED));
Task task(request.uncommitted_tasks(i));
RAY_CHECK(uncommitted_lineage.SetEntry(task, GcsStatus::UNCOMMITTED));
}
const Task &task = uncommitted_lineage.GetEntry(task_id)->TaskData();
RAY_LOG(DEBUG) << "Received forwarded task " << task.GetTaskSpecification().TaskId()
@@ -1769,7 +1763,7 @@ bool NodeManager::AssignTask(const Task &task) {
worker->GetTaskResourceIds().Plus(worker->GetLifetimeResourceIds());
auto resource_id_set_flatbuf = resource_id_set.ToFlatbuf(fbb);
auto message = protocol::CreateGetTaskReply(fbb, spec.ToFlatbuffer(fbb),
auto message = protocol::CreateGetTaskReply(fbb, fbb.CreateString(spec.Serialize()),
fbb.CreateVector(resource_id_set_flatbuf));
fbb.Finish(message);
const auto &task_id = spec.TaskId();
@@ -2025,9 +2019,7 @@ void NodeManager::HandleTaskReconstruction(const TaskID &task_id) {
const TaskTableData &task_data) {
// The task was in the GCS task table. Use the stored task spec to
// re-execute the task.
auto message = flatbuffers::GetRoot<protocol::Task>(task_data.task().data());
const Task task(*message);
ResubmitTask(task);
ResubmitTask(Task(task_data.task()));
},
/*failure_callback=*/
[this](ray::gcs::AsyncGcsClient *client, const TaskID &task_id) {
@@ -2238,8 +2230,12 @@ void NodeManager::ForwardTask(
// Prepare the request message.
rpc::ForwardTaskRequest request;
request.set_task_id(task_id.Binary());
for (auto &entry : uncommitted_lineage.GetEntries()) {
request.add_uncommitted_tasks(entry.second.TaskData().Serialize());
for (auto &task_entry : uncommitted_lineage.GetEntries()) {
auto task = request.add_uncommitted_tasks();
task->mutable_task_spec()->CopyFrom(
task_entry.second.TaskData().GetTaskSpecification().GetMessage());
task->mutable_task_execution_spec()->CopyFrom(
task_entry.second.TaskData().GetTaskExecutionSpec().GetMessage());
}
// Move the FORWARDING task to the SWAP queue so that we remember that we
+3 -1
View File
@@ -10,6 +10,7 @@
#include "ray/raylet/task.h"
#include "ray/object_manager/object_manager.h"
#include "ray/common/client_connection.h"
#include "ray/protobuf/common.pb.h"
#include "ray/raylet/actor_registration.h"
#include "ray/raylet/lineage_cache.h"
#include "ray/raylet/scheduling_policy.h"
@@ -31,6 +32,7 @@ using rpc::ErrorType;
using rpc::HeartbeatBatchTableData;
using rpc::HeartbeatTableData;
using rpc::JobTableData;
using rpc::Language;
struct NodeManagerConfig {
/// The node's resource configuration.
@@ -48,7 +50,7 @@ struct NodeManagerConfig {
/// worker pool.
int maximum_startup_concurrency;
/// The commands used to start the worker process, grouped by language.
std::unordered_map<Language, std::vector<std::string>> worker_commands;
WorkerCommandMap worker_commands;
/// The time between heartbeats in milliseconds.
uint64_t heartbeat_period_ms;
/// The time between debug dumps in milliseconds, or -1 to disable.
+3 -3
View File
@@ -228,7 +228,7 @@ ray::Status RayletClient::SubmitTask(const std::vector<ObjectID> &execution_depe
flatbuffers::FlatBufferBuilder fbb;
auto execution_dependencies_message = to_flatbuf(fbb, execution_dependencies);
auto message = ray::protocol::CreateSubmitTaskRequest(
fbb, execution_dependencies_message, task_spec.ToFlatbuffer(fbb));
fbb, execution_dependencies_message, fbb.CreateString(task_spec.Serialize()));
fbb.Finish(message);
return conn_->WriteMessage(MessageType::SubmitTask, &fbb);
}
@@ -335,9 +335,9 @@ ray::Status RayletClient::PushError(const ray::JobID &job_id, const std::string
return conn_->WriteMessage(MessageType::PushErrorRequest, &fbb);
}
ray::Status RayletClient::PushProfileEvents(const ProfileTableDataT &profile_events) {
ray::Status RayletClient::PushProfileEvents(const ProfileTableData &profile_events) {
flatbuffers::FlatBufferBuilder fbb;
auto message = CreateProfileTableData(fbb, &profile_events);
auto message = fbb.CreateString(profile_events.SerializeAsString());
fbb.Finish(message);
auto status = conn_->WriteMessage(MessageType::PushProfileEventsRequest, &fbb);
+5 -1
View File
@@ -1,6 +1,7 @@
#ifndef RAYLET_CLIENT_H
#define RAYLET_CLIENT_H
#include <ray/protobuf/gcs.pb.h>
#include <unistd.h>
#include <mutex>
#include <unordered_map>
@@ -17,6 +18,9 @@ using ray::ObjectID;
using ray::TaskID;
using ray::UniqueID;
using ray::rpc::Language;
using ray::rpc::ProfileTableData;
using MessageType = ray::protocol::MessageType;
using ResourceMappingType =
std::unordered_map<std::string, std::vector<std::pair<int64_t, double>>>;
@@ -138,7 +142,7 @@ class RayletClient {
///
/// \param profile_events A batch of profiling event information.
/// \return ray::Status.
ray::Status PushProfileEvents(const ProfileTableDataT &profile_events);
ray::Status PushProfileEvents(const ProfileTableData &profile_events);
/// Free a list of objects from object stores.
///
+3 -3
View File
@@ -85,7 +85,7 @@ class MockGcs : public gcs::PubsubInterface<TaskID>,
}
void Add(const JobID &job_id, const TaskID &task_id,
std::shared_ptr<TaskLeaseData> &task_lease_data) {
const std::shared_ptr<TaskLeaseData> &task_lease_data) {
task_lease_table_[task_id] = task_lease_data;
if (subscribed_tasks_.count(task_id) == 1) {
notification_callback_(nullptr, task_id, *task_lease_data);
@@ -112,7 +112,7 @@ class MockGcs : public gcs::PubsubInterface<TaskID>,
Status AppendAt(
const JobID &job_id, const TaskID &task_id,
std::shared_ptr<TaskReconstructionData> &task_data,
const std::shared_ptr<TaskReconstructionData> &task_data,
const ray::gcs::LogInterface<TaskID, TaskReconstructionData>::WriteCallback
&success_callback,
const ray::gcs::LogInterface<TaskID, TaskReconstructionData>::WriteCallback
@@ -134,7 +134,7 @@ class MockGcs : public gcs::PubsubInterface<TaskID>,
MOCK_METHOD4(
Append,
ray::Status(
const JobID &, const TaskID &, std::shared_ptr<TaskReconstructionData> &,
const JobID &, const TaskID &, const std::shared_ptr<TaskReconstructionData> &,
const ray::gcs::LogInterface<TaskID, TaskReconstructionData>::WriteCallback &));
private:
+1 -27
View File
@@ -4,24 +4,12 @@ namespace ray {
namespace raylet {
flatbuffers::Offset<protocol::Task> Task::ToFlatbuffer(
flatbuffers::FlatBufferBuilder &fbb) const {
auto task = CreateTask(fbb, task_spec_.ToFlatbuffer(fbb),
task_execution_spec_.ToFlatbuffer(fbb));
return task;
}
const TaskExecutionSpecification &Task::GetTaskExecutionSpec() const {
return task_execution_spec_;
}
const TaskSpecification &Task::GetTaskSpecification() const { return task_spec_; }
void Task::SetExecutionDependencies(const std::vector<ObjectID> &dependencies) {
task_execution_spec_.SetExecutionDependencies(dependencies);
ComputeDependencies();
}
void Task::IncrementNumForwards() { task_execution_spec_.IncrementNumForwards(); }
const std::vector<ObjectID> &Task::GetDependencies() const { return dependencies_; }
@@ -42,24 +30,10 @@ void Task::ComputeDependencies() {
}
void Task::CopyTaskExecutionSpec(const Task &task) {
task_execution_spec_ = task.GetTaskExecutionSpec();
task_execution_spec_ = task.task_execution_spec_;
ComputeDependencies();
}
const std::string Task::Serialize() const {
flatbuffers::FlatBufferBuilder fbb;
fbb.Finish(ToFlatbuffer(fbb));
return std::string(fbb.GetBufferPointer(), fbb.GetBufferPointer() + fbb.GetSize());
}
std::string SerializeTaskAsString(const std::vector<ObjectID> *dependencies,
const TaskSpecification *task_spec) {
std::vector<ObjectID> execution_dependencies(*dependencies);
TaskExecutionSpecification execution_spec(std::move(execution_dependencies));
Task task(execution_spec, *task_spec);
return task.Serialize();
}
} // namespace raylet
} // namespace ray
+17 -45
View File
@@ -3,9 +3,11 @@
#include <inttypes.h>
#include "ray/protobuf/common.pb.h"
#include "ray/raylet/format/node_manager_generated.h"
#include "ray/raylet/task_execution_spec.h"
#include "ray/raylet/task_spec.h"
#include "ray/rpc/message_wrapper.h"
namespace ray {
@@ -19,41 +21,22 @@ namespace raylet {
/// time.
class Task {
public:
/// Create a task.
/// Construct a `Task` object from a protobuf message.
///
/// \param execution_spec The execution specification for the task. These are
/// the mutable fields in the task specification that may change at task
/// execution time.
/// \param task_spec The immutable specification for the task. These fields
/// are determined at task submission time.
Task(const TaskExecutionSpecification &execution_spec,
const TaskSpecification &task_spec)
: task_execution_spec_(execution_spec), task_spec_(task_spec) {
/// \param message The protobuf message.
explicit Task(const rpc::Task &message)
: task_spec_(message.task_spec()),
task_execution_spec_(message.task_execution_spec()) {
ComputeDependencies();
}
/// Create a task from a serialized flatbuffer.
///
/// \param task_flatbuffer The serialized task.
Task(const protocol::Task &task_flatbuffer)
: Task(*task_flatbuffer.task_execution_spec(),
*task_flatbuffer.task_specification()) {}
/// Create a task from a flatbuffer object.
///
/// \param task_data The task flatbuffer object.
Task(const protocol::TaskT &task_data)
: Task(*task_data.task_execution_spec, task_data.task_specification) {}
/// Destroy the task.
virtual ~Task() {}
/// Serialize a task to a flatbuffer.
///
/// \param fbb The flatbuffer builder.
/// \return An offset to the serialized task.
flatbuffers::Offset<protocol::Task> ToFlatbuffer(
flatbuffers::FlatBufferBuilder &fbb) const;
/// Construct a `Task` object from a `TaskSpecification` and a
/// `TaskExecutionSpecification`.
Task(TaskSpecification task_spec, TaskExecutionSpecification task_execution_spec)
: task_spec_(std::move(task_spec)),
task_execution_spec_(std::move(task_execution_spec)) {
ComputeDependencies();
}
/// Get the mutable specification for the task. This specification may be
/// updated at runtime.
@@ -66,11 +49,6 @@ class Task {
/// \return The immutable specification for the task.
const TaskSpecification &GetTaskSpecification() const;
/// Set the task's execution dependencies.
///
/// \param dependencies The value to set the execution dependencies to.
void SetExecutionDependencies(const std::vector<ObjectID> &dependencies);
/// Increment the number of times this task has been forwarded.
void IncrementNumForwards();
@@ -84,28 +62,22 @@ class Task {
/// \param task Task structure with updated dynamic information.
void CopyTaskExecutionSpec(const Task &task);
/// Serialize this task as a string.
const std::string Serialize() const;
private:
void ComputeDependencies();
/// Task execution specification, consisting of all dynamic/mutable
/// information about this task determined at execution time.
TaskExecutionSpecification task_execution_spec_;
/// Task specification object, consisting of immutable information about this
/// task determined at submission time. Includes resource demand, object
/// dependencies, etc.
TaskSpecification task_spec_;
/// Task execution specification, consisting of all dynamic/mutable
/// information about this task determined at execution time.
TaskExecutionSpecification task_execution_spec_;
/// A cached copy of the task's object dependencies, including arguments from
/// the TaskSpecification and execution dependencies from the
/// TaskExecutionSpecification.
std::vector<ObjectID> dependencies_;
};
std::string SerializeTaskAsString(const std::vector<ObjectID> *dependencies,
const TaskSpecification *task_spec);
} // namespace raylet
} // namespace ray
+11 -15
View File
@@ -6,6 +6,7 @@
#include <boost/asio.hpp>
#include "ray/raylet/task_dependency_manager.h"
#include "ray/raylet/task_util.h"
namespace ray {
@@ -30,7 +31,7 @@ class MockGcs : public gcs::TableInterface<TaskID, TaskLeaseData> {
MOCK_METHOD4(
Add,
ray::Status(const JobID &job_id, const TaskID &task_id,
std::shared_ptr<TaskLeaseData> &task_data,
const std::shared_ptr<TaskLeaseData> &task_data,
const gcs::TableInterface<TaskID, TaskLeaseData>::WriteCallback &done));
};
@@ -67,21 +68,16 @@ class TaskDependencyManagerTest : public ::testing::Test {
};
static inline Task ExampleTask(const std::vector<ObjectID> &arguments,
int64_t num_returns) {
std::unordered_map<std::string, double> required_resources;
std::vector<std::shared_ptr<TaskArgument>> task_arguments;
for (auto &argument : arguments) {
std::vector<ObjectID> references = {argument};
task_arguments.emplace_back(std::make_shared<TaskArgumentByReference>(references));
uint64_t num_returns) {
TaskSpecBuilder builder;
builder.SetCommonTaskSpec(Language::PYTHON, {"", "", ""}, JobID::Nil(),
TaskID::FromRandom(), 0, num_returns, {}, {});
for (const auto &arg : arguments) {
builder.AddByRefArg(arg);
}
std::vector<std::string> function_descriptor(3);
auto spec = TaskSpecification(JobID::Nil(), TaskID::FromRandom(), 0, task_arguments,
num_returns, required_resources, Language::PYTHON,
function_descriptor);
auto execution_spec = TaskExecutionSpecification(std::vector<ObjectID>());
execution_spec.IncrementNumForwards();
Task task = Task(execution_spec, spec);
return task;
rpc::TaskExecutionSpec execution_spec_message;
execution_spec_message.set_num_forwards(1);
return Task(builder.Build(), TaskExecutionSpecification(execution_spec_message));
}
std::vector<Task> MakeTaskChain(int chain_size,
+6 -44
View File
@@ -4,54 +4,16 @@ namespace ray {
namespace raylet {
TaskExecutionSpecification::TaskExecutionSpecification(
const std::vector<ObjectID> &&dependencies) {
SetExecutionDependencies(dependencies);
using rpc::IdVectorFromProtobuf;
const std::vector<ObjectID> TaskExecutionSpecification::ExecutionDependencies() const {
return IdVectorFromProtobuf<ObjectID>(message_.dependencies());
}
TaskExecutionSpecification::TaskExecutionSpecification(
const std::vector<ObjectID> &&dependencies, int num_forwards) {
// TaskExecutionSpecification(std::move(dependencies));
SetExecutionDependencies(dependencies);
execution_spec_.num_forwards = num_forwards;
}
flatbuffers::Offset<protocol::TaskExecutionSpecification>
TaskExecutionSpecification::ToFlatbuffer(flatbuffers::FlatBufferBuilder &fbb) const {
fbb.ForceDefaults(true);
return protocol::TaskExecutionSpecification::Pack(fbb, &execution_spec_);
}
std::vector<ObjectID> TaskExecutionSpecification::ExecutionDependencies() const {
std::vector<ObjectID> dependencies;
for (const auto &dependency : execution_spec_.dependencies) {
dependencies.push_back(ObjectID::FromBinary(dependency));
}
return dependencies;
}
void TaskExecutionSpecification::SetExecutionDependencies(
const std::vector<ObjectID> &dependencies) {
execution_spec_.dependencies.clear();
for (const auto &dependency : dependencies) {
execution_spec_.dependencies.push_back(dependency.Binary());
}
}
int TaskExecutionSpecification::NumForwards() const {
return execution_spec_.num_forwards;
}
size_t TaskExecutionSpecification::NumForwards() const { return message_.num_forwards(); }
void TaskExecutionSpecification::IncrementNumForwards() {
execution_spec_.num_forwards += 1;
}
int64_t TaskExecutionSpecification::LastTimestamp() const {
return execution_spec_.last_timestamp;
}
void TaskExecutionSpecification::SetLastTimestamp(int64_t new_timestamp) {
execution_spec_.last_timestamp = new_timestamp;
message_.set_num_forwards(message_.num_forwards() + 1);
}
} // namespace raylet
+18 -57
View File
@@ -4,84 +4,45 @@
#include <vector>
#include "ray/common/id.h"
#include "ray/raylet/format/node_manager_generated.h"
#include "ray/protobuf/common.pb.h"
#include "ray/rpc/message_wrapper.h"
#include "ray/rpc/util.h"
namespace ray {
namespace raylet {
/// \class TaskExecutionSpecification
///
/// The task execution specification encapsulates all mutable information about
/// the task. These fields may change at execution time, in contrast to the
/// TaskSpecification, which is determined at submission time.
class TaskExecutionSpecification {
using rpc::MessageWrapper;
/// Wrapper class of protobuf `TaskExecutionSpec`, see `common.proto` for details.
class TaskExecutionSpecification : public MessageWrapper<rpc::TaskExecutionSpec> {
public:
TaskExecutionSpecification(const protocol::TaskExecutionSpecificationT &execution_spec)
: execution_spec_(execution_spec) {}
/// Create a task execution specification.
/// Construct from a protobuf message object.
/// The input message will be **copied** into this object.
///
/// \param dependencies The task's dependencies, determined at execution
/// time.
TaskExecutionSpecification(const std::vector<ObjectID> &&dependencies);
/// \param message The protobuf message.
explicit TaskExecutionSpecification(rpc::TaskExecutionSpec message)
: MessageWrapper(std::move(message)) {}
/// Create a task execution specification.
/// Construct from protobuf-serialized binary.
///
/// \param dependencies The task's dependencies, determined at execution
/// time.
/// \param num_forwards The number of times this task has been forwarded by a
/// node manager.
TaskExecutionSpecification(const std::vector<ObjectID> &&dependencies,
int num_forwards);
/// Create a task execution specification from a serialized flatbuffer.
///
/// \param spec_flatbuffer The serialized specification.
TaskExecutionSpecification(
const protocol::TaskExecutionSpecification &spec_flatbuffer) {
spec_flatbuffer.UnPackTo(&execution_spec_);
}
/// Serialize a task execution specification to a flatbuffer.
///
/// \param fbb The flatbuffer builder.
/// \return An offset to the serialized task execution specification.
flatbuffers::Offset<protocol::TaskExecutionSpecification> ToFlatbuffer(
flatbuffers::FlatBufferBuilder &fbb) const;
/// \param serialized_binary Protobuf-serialized binary.
explicit TaskExecutionSpecification(const std::string &serialized_binary)
: MessageWrapper(serialized_binary) {}
/// Get the task's execution dependencies.
///
/// \return A vector of object IDs representing this task's execution
/// dependencies.
std::vector<ObjectID> ExecutionDependencies() const;
/// Set the task's execution dependencies.
///
/// \param dependencies The value to set the execution dependencies to.
void SetExecutionDependencies(const std::vector<ObjectID> &dependencies);
const std::vector<ObjectID> ExecutionDependencies() const;
/// Get the number of times this task has been forwarded.
///
/// \return The number of times this task has been forwarded.
int NumForwards() const;
size_t NumForwards() const;
/// Increment the number of times this task has been forwarded.
void IncrementNumForwards();
/// Get the task's last timestamp.
///
/// \return The timestamp when this task was last received for scheduling.
int64_t LastTimestamp() const;
/// Set the task's last timestamp to the specified value.
///
/// \param new_timestamp The new timestamp in millisecond to set the task's
/// time stamp to. Tracks the last time this task entered a raylet.
void SetLastTimestamp(int64_t new_timestamp);
private:
protocol::TaskExecutionSpecificationT execution_spec_;
};
} // namespace raylet
+74 -170
View File
@@ -1,151 +1,50 @@
#include "task_spec.h"
#include <sstream>
#include "ray/common/common_protocol.h"
#include "ray/gcs/format/gcs_generated.h"
#include "ray/raylet/task_spec.h"
#include "ray/rpc/util.h"
#include "ray/util/logging.h"
namespace ray {
namespace raylet {
TaskArgument::~TaskArgument() {}
using rpc::MapFromProtobuf;
using rpc::VectorFromProtobuf;
TaskArgumentByReference::TaskArgumentByReference(const std::vector<ObjectID> &references)
: references_(references) {}
flatbuffers::Offset<Arg> TaskArgumentByReference::ToFlatbuffer(
flatbuffers::FlatBufferBuilder &fbb) const {
return CreateArg(fbb, ids_to_flatbuf(fbb, references_));
}
TaskArgumentByValue::TaskArgumentByValue(const uint8_t *value, size_t length) {
value_.assign(value, value + length);
}
flatbuffers::Offset<Arg> TaskArgumentByValue::ToFlatbuffer(
flatbuffers::FlatBufferBuilder &fbb) const {
auto arg =
fbb.CreateString(reinterpret_cast<const char *>(value_.data()), value_.size());
const auto &empty_ids = fbb.CreateString("");
return CreateArg(fbb, empty_ids, arg);
}
void TaskSpecification::AssignSpecification(const uint8_t *spec, size_t spec_size) {
spec_.assign(spec, spec + spec_size);
// Initialize required_resources_ and required_placement_resources_
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
auto required_resources = map_from_flatbuf(*message->required_resources());
void TaskSpecification::ComputeResources() {
auto required_resources = MapFromProtobuf(message_.required_resources());
auto required_placement_resources =
map_from_flatbuf(*message->required_placement_resources());
// If the required_placement_resources field is empty, then the placement
// resources default to the required resources.
if (required_placement_resources.size() == 0) {
MapFromProtobuf(message_.required_placement_resources());
if (required_placement_resources.empty()) {
required_placement_resources = required_resources;
}
required_resources_ = ResourceSet(required_resources);
required_placement_resources_ = ResourceSet(required_placement_resources);
}
TaskSpecification::TaskSpecification(const flatbuffers::String &string) {
AssignSpecification(reinterpret_cast<const uint8_t *>(string.data()), string.size());
}
TaskSpecification::TaskSpecification(const std::string &string) {
AssignSpecification(reinterpret_cast<const uint8_t *>(string.data()), string.size());
}
TaskSpecification::TaskSpecification(const uint8_t *spec, size_t spec_size) {
AssignSpecification(spec, spec_size);
}
TaskSpecification::TaskSpecification(
const JobID &job_id, const TaskID &parent_task_id, int64_t parent_counter,
const std::vector<std::shared_ptr<TaskArgument>> &task_arguments, int64_t num_returns,
const std::unordered_map<std::string, double> &required_resources,
const Language &language, const std::vector<std::string> &function_descriptor)
: TaskSpecification(job_id, parent_task_id, parent_counter, ActorID::Nil(),
ObjectID::Nil(), 0, ActorID::Nil(), ActorHandleID::Nil(), -1, {},
task_arguments, num_returns, required_resources,
std::unordered_map<std::string, double>(), language,
function_descriptor) {}
/// Build a task specification from its raw fields by serializing them into a
/// `TaskInfo` flatbuffer and assigning the finished buffer to this object.
///
/// The task ID is not supplied by the caller; it is generated from `job_id`,
/// `parent_task_id` and `parent_counter`.
///
/// \param job_id The ID of the job that this task is a part of.
/// \param parent_task_id The task ID of the task that spawned this task.
/// \param parent_counter The number of tasks that this task's parent spawned
/// before this task.
/// \param actor_creation_id Actor creation ID field of the serialized task.
/// \param actor_creation_dummy_object_id Actor creation dummy object ID field.
/// \param max_actor_reconstructions Maximum actor reconstructions field.
/// \param actor_id Actor ID field of the serialized task.
/// \param actor_handle_id Actor handle ID field of the serialized task.
/// \param actor_counter Actor counter field of the serialized task.
/// \param new_actor_handles New actor handle IDs to serialize.
/// \param task_arguments The list of task arguments.
/// \param num_returns The number of values returned by the task.
/// \param required_resources The task's resource demands.
/// \param required_placement_resources The resources required to place this
/// task on a node.
/// \param language The language of the worker that must execute the function.
/// \param function_descriptor The function descriptor.
/// \param dynamic_worker_options The dynamic options for starting an actor worker.
TaskSpecification::TaskSpecification(
const JobID &job_id, const TaskID &parent_task_id, int64_t parent_counter,
const ActorID &actor_creation_id, const ObjectID &actor_creation_dummy_object_id,
const int64_t max_actor_reconstructions, const ActorID &actor_id,
const ActorHandleID &actor_handle_id, int64_t actor_counter,
const std::vector<ActorHandleID> &new_actor_handles,
const std::vector<std::shared_ptr<TaskArgument>> &task_arguments, int64_t num_returns,
const std::unordered_map<std::string, double> &required_resources,
const std::unordered_map<std::string, double> &required_placement_resources,
const Language &language, const std::vector<std::string> &function_descriptor,
const std::vector<std::string> &dynamic_worker_options)
: spec_() {
flatbuffers::FlatBufferBuilder fbb;
// Derive this task's ID from the job, the parent task and the parent's counter.
TaskID task_id = GenerateTaskId(job_id, parent_task_id, parent_counter);
// Add argument object IDs.
// Each argument serializes itself into the shared builder and yields an offset.
std::vector<flatbuffers::Offset<Arg>> arguments;
for (auto &argument : task_arguments) {
arguments.push_back(argument->ToFlatbuffer(fbb));
}
// Serialize the TaskSpecification.
// The argument order below is positional and must match CreateTaskInfo's
// parameter order.
auto spec = CreateTaskInfo(
fbb, to_flatbuf(fbb, job_id), to_flatbuf(fbb, task_id),
to_flatbuf(fbb, parent_task_id), parent_counter, to_flatbuf(fbb, actor_creation_id),
to_flatbuf(fbb, actor_creation_dummy_object_id), max_actor_reconstructions,
to_flatbuf(fbb, actor_id), to_flatbuf(fbb, actor_handle_id), actor_counter,
ids_to_flatbuf(fbb, new_actor_handles), fbb.CreateVector(arguments), num_returns,
map_to_flatbuf(fbb, required_resources),
map_to_flatbuf(fbb, required_placement_resources), language,
string_vec_to_flatbuf(fbb, function_descriptor),
string_vec_to_flatbuf(fbb, dynamic_worker_options));
fbb.Finish(spec);
// Hand the finished buffer to the same routine the deserializing constructors use.
AssignSpecification(fbb.GetBufferPointer(), fbb.GetSize());
}
/// Serialize this task specification into the given builder as a flatbuffer
/// string made of the specification's raw bytes.
///
/// \param fbb The flatbuffer builder to serialize with.
/// \return An offset to the serialized task specification.
flatbuffers::Offset<flatbuffers::String> TaskSpecification::ToFlatbuffer(
    flatbuffers::FlatBufferBuilder &fbb) const {
  // View the stored bytes as chars; the explicit length is passed alongside so
  // the string is created from exactly size() bytes.
  const auto *raw_chars = reinterpret_cast<const char *>(data());
  const size_t num_bytes = size();
  return fbb.CreateString(raw_chars, num_bytes);
}
// TODO(atumanov): copy/paste most TaskSpec_* methods from task.h and make them
// methods of this class.
/// Get a pointer to the first byte of the stored task specification (`spec_`).
const uint8_t *TaskSpecification::data() const { return spec_.data(); }
/// Get the size in bytes of the stored task specification (`spec_`).
size_t TaskSpecification::size() const { return spec_.size(); }
// Task specification getter methods.
TaskID TaskSpecification::TaskId() const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return from_flatbuf<TaskID>(*message->task_id());
}
JobID TaskSpecification::JobId() const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return from_flatbuf<JobID>(*message->job_id());
return TaskID::FromBinary(message_.task_id());
}
JobID TaskSpecification::JobId() const { return JobID::FromBinary(message_.job_id()); }
TaskID TaskSpecification::ParentTaskId() const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return from_flatbuf<TaskID>(*message->parent_task_id());
}
int64_t TaskSpecification::ParentCounter() const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return message->parent_counter();
return TaskID::FromBinary(message_.parent_task_id());
}
size_t TaskSpecification::ParentCounter() const { return message_.parent_counter(); }
std::vector<std::string> TaskSpecification::FunctionDescriptor() const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return string_vec_from_flatbuf(*message->function_descriptor());
return VectorFromProtobuf(message_.function_descriptor());
}
std::string TaskSpecification::FunctionDescriptorString() const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
auto list = string_vec_from_flatbuf(*message->function_descriptor());
auto list = VectorFromProtobuf(message_.function_descriptor());
std::ostringstream stream;
// The 4th is the code hash which is binary bits. No need to output it.
int size = std::min(static_cast<size_t>(3), list.size());
size_t size = std::min(static_cast<size_t>(3), list.size());
for (int i = 0; i < size; ++i) {
if (i != 0) {
stream << ",";
@@ -155,46 +54,32 @@ std::string TaskSpecification::FunctionDescriptorString() const {
return stream.str();
}
int64_t TaskSpecification::NumArgs() const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return message->args()->size();
}
size_t TaskSpecification::NumArgs() const { return message_.args_size(); }
int64_t TaskSpecification::NumReturns() const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return message->num_returns();
}
size_t TaskSpecification::NumReturns() const { return message_.num_returns(); }
ObjectID TaskSpecification::ReturnId(int64_t return_index) const {
ObjectID TaskSpecification::ReturnId(size_t return_index) const {
return ObjectID::ForTaskReturn(TaskId(), return_index + 1);
}
bool TaskSpecification::ArgByRef(int64_t arg_index) const {
bool TaskSpecification::ArgByRef(size_t arg_index) const {
return (ArgIdCount(arg_index) != 0);
}
int TaskSpecification::ArgIdCount(int64_t arg_index) const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
auto ids = message->args()->Get(arg_index)->object_ids();
return (ids->size() / kUniqueIDSize);
size_t TaskSpecification::ArgIdCount(size_t arg_index) const {
return message_.args(arg_index).object_ids_size();
}
ObjectID TaskSpecification::ArgId(int64_t arg_index, int64_t id_index) const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
const auto &object_ids =
ids_from_flatbuf<ObjectID>(*message->args()->Get(arg_index)->object_ids());
return object_ids[id_index];
ObjectID TaskSpecification::ArgId(size_t arg_index, size_t id_index) const {
return ObjectID::FromBinary(message_.args(arg_index).object_ids(id_index));
}
const uint8_t *TaskSpecification::ArgVal(int64_t arg_index) const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return reinterpret_cast<const uint8_t *>(
message->args()->Get(arg_index)->data()->c_str());
const uint8_t *TaskSpecification::ArgVal(size_t arg_index) const {
return reinterpret_cast<const uint8_t *>(message_.args(arg_index).data().data());
}
size_t TaskSpecification::ArgValLength(int64_t arg_index) const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return message->args()->Get(arg_index)->data()->size();
size_t TaskSpecification::ArgValLength(size_t arg_index) const {
return message_.args(arg_index).data().size();
}
const ResourceSet TaskSpecification::GetRequiredResources() const {
@@ -210,43 +95,59 @@ bool TaskSpecification::IsDriverTask() const {
return FunctionDescriptor().empty();
}
Language TaskSpecification::GetLanguage() const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return message->language();
rpc::Language TaskSpecification::GetLanguage() const { return message_.language(); }
bool TaskSpecification::IsActorCreationTask() const {
return message_.type() == rpc::TaskType::ACTOR_CREATION_TASK;
}
bool TaskSpecification::IsActorCreationTask() const { return !ActorCreationId().IsNil(); }
bool TaskSpecification::IsActorTask() const { return !ActorId().IsNil(); }
bool TaskSpecification::IsActorTask() const {
return message_.type() == rpc::TaskType::ACTOR_TASK;
}
ActorID TaskSpecification::ActorCreationId() const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return from_flatbuf<ActorID>(*message->actor_creation_id());
// TODO(hchen) Add a check to make sure this function can only be called if
// task is an actor creation task.
if (!IsActorCreationTask()) {
return ActorID::Nil();
}
return ActorID::FromBinary(message_.actor_creation_task_spec().actor_id());
}
ObjectID TaskSpecification::ActorCreationDummyObjectId() const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return from_flatbuf<ObjectID>(*message->actor_creation_dummy_object_id());
if (!IsActorTask()) {
return ObjectID::Nil();
}
return ObjectID::FromBinary(
message_.actor_task_spec().actor_creation_dummy_object_id());
}
int64_t TaskSpecification::MaxActorReconstructions() const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return message->max_actor_reconstructions();
uint64_t TaskSpecification::MaxActorReconstructions() const {
if (!IsActorCreationTask()) {
return 0;
}
return message_.actor_creation_task_spec().max_actor_reconstructions();
}
ActorID TaskSpecification::ActorId() const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return from_flatbuf<ActorID>(*message->actor_id());
if (!IsActorTask()) {
return ActorID::Nil();
}
return ActorID::FromBinary(message_.actor_task_spec().actor_id());
}
ActorHandleID TaskSpecification::ActorHandleId() const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return from_flatbuf<ActorHandleID>(*message->actor_handle_id());
if (!IsActorTask()) {
return ActorHandleID::Nil();
}
return ActorHandleID::FromBinary(message_.actor_task_spec().actor_handle_id());
}
int64_t TaskSpecification::ActorCounter() const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return message->actor_counter();
uint64_t TaskSpecification::ActorCounter() const {
if (!IsActorTask()) {
return 0;
}
return message_.actor_task_spec().actor_counter();
}
ObjectID TaskSpecification::ActorDummyObject() const {
@@ -255,13 +156,16 @@ ObjectID TaskSpecification::ActorDummyObject() const {
}
std::vector<ActorHandleID> TaskSpecification::NewActorHandles() const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return ids_from_flatbuf<ActorHandleID>(*message->new_actor_handles());
if (!IsActorTask()) {
return {};
}
return rpc::IdVectorFromProtobuf<ActorHandleID>(
message_.actor_task_spec().new_actor_handles());
}
std::vector<std::string> TaskSpecification::DynamicWorkerOptions() const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return string_vec_from_flatbuf(*message->dynamic_worker_options());
return rpc::VectorFromProtobuf(
message_.actor_creation_task_spec().dynamic_worker_options());
}
} // namespace raylet
+57 -165
View File
@@ -7,8 +7,9 @@
#include <vector>
#include "ray/common/id.h"
#include "ray/gcs/format/gcs_generated.h"
#include "ray/protobuf/common.pb.h"
#include "ray/raylet/scheduling_resources.h"
#include "ray/rpc/message_wrapper.h"
extern "C" {
#include "ray/thirdparty/sha256.h"
@@ -18,179 +19,66 @@ namespace ray {
namespace raylet {
/// \class TaskArgument
///
/// A virtual class that represents an argument to a task.
class TaskArgument {
using rpc::Language;
using rpc::MessageWrapper;
using rpc::TaskType;
/// Wrapper class of protobuf `TaskSpec`, see `common.proto` for details.
class TaskSpecification : public MessageWrapper<rpc::TaskSpec> {
public:
/// Serialize the task argument to a flatbuffer.
/// Construct from a protobuf message object.
/// The input message will be **copied** into this object.
///
/// \param fbb The flatbuffer builder to serialize with.
/// \return An offset to the serialized task argument.
virtual flatbuffers::Offset<Arg> ToFlatbuffer(
flatbuffers::FlatBufferBuilder &fbb) const = 0;
/// \param message The protobuf message.
explicit TaskSpecification(rpc::TaskSpec message) : MessageWrapper(std::move(message)) {
ComputeResources();
}
virtual ~TaskArgument() = 0;
};
/// \class TaskArgumentByReference
///
/// A task argument consisting of a list of object ID references.
class TaskArgumentByReference : virtual public TaskArgument {
public:
/// Create a task argument by reference from a list of object IDs.
/// Construct from protobuf-serialized binary.
///
/// \param references A list of object ID references.
TaskArgumentByReference(const std::vector<ObjectID> &references);
~TaskArgumentByReference(){};
flatbuffers::Offset<Arg> ToFlatbuffer(flatbuffers::FlatBufferBuilder &fbb) const;
private:
/// The object IDs.
const std::vector<ObjectID> references_;
};
/// \class TaskArgumentByValue
///
/// A task argument containing the raw value.
class TaskArgumentByValue : public TaskArgument {
public:
/// Create a task argument from a raw value.
///
/// \param value A pointer to the raw value.
/// \param length The size of the raw value.
TaskArgumentByValue(const uint8_t *value, size_t length);
flatbuffers::Offset<Arg> ToFlatbuffer(flatbuffers::FlatBufferBuilder &fbb) const;
private:
/// The raw value.
std::vector<uint8_t> value_;
};
/// \class TaskSpecification
///
/// The task specification encapsulates all immutable information about the
/// task. These fields are determined at submission time, converse to the
/// TaskExecutionSpecification that may change at execution time.
class TaskSpecification {
public:
/// Deserialize a task specification from a flatbuffer.
///
/// \param string A serialized task specification flatbuffer.
TaskSpecification(const flatbuffers::String &string);
// TODO(swang): Define an actor task constructor.
/// Create a task specification from the raw fields. This constructor omits
/// some values and sets them to sensible defaults.
///
/// \param job_id The driver ID, representing the job that this task is a
/// part of.
/// \param parent_task_id The task ID of the task that spawned this task.
/// \param parent_counter The number of tasks that this task's parent spawned
/// before this task.
/// \param function_descriptor The function descriptor.
/// \param task_arguments The list of task arguments.
/// \param num_returns The number of values returned by the task.
/// \param required_resources The task's resource demands.
/// \param language The language of the worker that must execute the function.
TaskSpecification(const JobID &job_id, const TaskID &parent_task_id,
int64_t parent_counter,
const std::vector<std::shared_ptr<TaskArgument>> &task_arguments,
int64_t num_returns,
const std::unordered_map<std::string, double> &required_resources,
const Language &language,
const std::vector<std::string> &function_descriptor);
// TODO(swang): Define an actor task constructor.
/// Create a task specification from the raw fields.
///
/// \param job_id The driver ID, representing the job that this task is a
/// part of.
/// \param parent_task_id The task ID of the task that spawned this task.
/// \param parent_counter The number of tasks that this task's parent spawned
/// before this task.
/// \param actor_creation_id If this is an actor task, then this is the ID of
/// the corresponding actor creation task. Otherwise, this is nil.
/// \param actor_id The ID of the actor for the task. If this is not an actor
/// task, then this is nil.
/// \param actor_handle_id The ID of the actor handle that submitted this
/// task. If this is not an actor task, then this is nil.
/// \param actor_counter The number of tasks submitted before this task from
/// the same actor handle. If this is not an actor task, then this is 0.
/// \param task_arguments The list of task arguments.
/// \param num_returns The number of values returned by the task.
/// \param required_resources The task's resource demands.
/// \param required_placement_resources The resources required to place this
/// task on a node. Typically, this should be an empty map in which case it
/// will default to be equal to the required_resources argument.
/// \param language The language of the worker that must execute the function.
/// \param function_descriptor The function descriptor.
/// \param dynamic_worker_options The dynamic options for starting an actor worker.
TaskSpecification(
const JobID &job_id, const TaskID &parent_task_id, int64_t parent_counter,
const ActorID &actor_creation_id, const ObjectID &actor_creation_dummy_object_id,
int64_t max_actor_reconstructions, const ActorID &actor_id,
const ActorHandleID &actor_handle_id, int64_t actor_counter,
const std::vector<ActorHandleID> &new_actor_handles,
const std::vector<std::shared_ptr<TaskArgument>> &task_arguments,
int64_t num_returns,
const std::unordered_map<std::string, double> &required_resources,
const std::unordered_map<std::string, double> &required_placement_resources,
const Language &language, const std::vector<std::string> &function_descriptor,
const std::vector<std::string> &dynamic_worker_options = {});
/// Deserialize a task specification from a string.
///
/// \param string The string data for a serialized task specification flatbuffers.
TaskSpecification(const std::string &string);
/// Deserialize a task specification from raw byte array.
///
/// \param spec Raw byte array for a serialized task specification flatbuffer.
/// \param spec_size Size of the byte array.
TaskSpecification(const uint8_t *spec, size_t spec_size);
~TaskSpecification() {}
/// Serialize the TaskSpecification to a flatbuffer.
///
/// \param fbb The flatbuffer builder to serialize with.
/// \return An offset to the serialized task specification.
flatbuffers::Offset<flatbuffers::String> ToFlatbuffer(
flatbuffers::FlatBufferBuilder &fbb) const;
std::string SerializeAsString() const {
flatbuffers::FlatBufferBuilder fbb;
auto string = ToFlatbuffer(fbb);
fbb.Finish(string);
return std::string(fbb.GetBufferPointer(), fbb.GetBufferPointer() + fbb.GetSize());
/// \param serialized_binary Protobuf-serialized binary.
explicit TaskSpecification(const std::string &serialized_binary)
: MessageWrapper(serialized_binary) {
ComputeResources();
}
// TODO(swang): Finalize and document these methods.
TaskID TaskId() const;
JobID JobId() const;
TaskID ParentTaskId() const;
int64_t ParentCounter() const;
size_t ParentCounter() const;
std::vector<std::string> FunctionDescriptor() const;
// Output the function descriptor as a string for log purpose.
std::string FunctionDescriptorString() const;
int64_t NumArgs() const;
int64_t NumReturns() const;
bool ArgByRef(int64_t arg_index) const;
int ArgIdCount(int64_t arg_index) const;
ObjectID ArgId(int64_t arg_index, int64_t id_index) const;
ObjectID ReturnId(int64_t return_index) const;
const uint8_t *ArgVal(int64_t arg_index) const;
size_t ArgValLength(int64_t arg_index) const;
size_t NumArgs() const;
size_t NumReturns() const;
bool ArgByRef(size_t arg_index) const;
size_t ArgIdCount(size_t arg_index) const;
ObjectID ArgId(size_t arg_index, size_t id_index) const;
ObjectID ReturnId(size_t return_index) const;
const uint8_t *ArgVal(size_t arg_index) const;
size_t ArgValLength(size_t arg_index) const;
/// Return the resources that are to be acquired during the execution of this
/// task.
///
/// \return The resources that will be acquired during the execution of this
/// task.
const ResourceSet GetRequiredResources() const;
/// Return the resources that are required for a task to be placed on a node.
/// This will typically be the same as the resources acquired during execution
/// and will always be a superset of those resources. However, they may
@@ -201,36 +89,40 @@ class TaskSpecification {
///
/// \return The resources that are required to place a task on a node.
const ResourceSet GetRequiredPlacementResources() const;
bool IsDriverTask() const;
Language GetLanguage() const;
// Methods specific to actor tasks.
bool IsActorCreationTask() const;
bool IsActorTask() const;
ActorID ActorCreationId() const;
ObjectID ActorCreationDummyObjectId() const;
int64_t MaxActorReconstructions() const;
uint64_t MaxActorReconstructions() const;
ActorID ActorId() const;
ActorHandleID ActorHandleId() const;
int64_t ActorCounter() const;
uint64_t ActorCounter() const;
ObjectID ActorDummyObject() const;
std::vector<ActorHandleID> NewActorHandles() const;
std::vector<std::string> DynamicWorkerOptions() const;
private:
/// Assign the specification data from a pointer.
void AssignSpecification(const uint8_t *spec, size_t spec_size);
/// Get a pointer to the byte data.
const uint8_t *data() const;
/// Get the size in bytes of the task specification.
size_t size() const;
void ComputeResources();
/// Field storing required resources. Initialized in constructor.
ResourceSet required_resources_;
/// Field storing required placement resources. Initialized in constructor.
ResourceSet required_placement_resources_;
/// The task specification data.
std::vector<uint8_t> spec_;
};
} // namespace raylet
-50
View File
@@ -48,56 +48,6 @@ TEST(IdPropertyTest, TestIdProperty) {
ASSERT_TRUE(ObjectID::Nil().IsNil());
}
TEST(TaskSpecTest, TaskInfoSize) {
std::vector<ObjectID> references = {ObjectID::FromRandom(), ObjectID::FromRandom()};
auto arguments_1 = std::make_shared<TaskArgumentByReference>(references);
std::string one_arg("This is an value argument.");
auto arguments_2 = std::make_shared<TaskArgumentByValue>(
reinterpret_cast<const uint8_t *>(one_arg.c_str()), one_arg.size());
std::vector<std::shared_ptr<TaskArgument>> task_arguments({arguments_1, arguments_2});
auto task_id = TaskID::FromRandom();
{
flatbuffers::FlatBufferBuilder fbb;
std::vector<flatbuffers::Offset<Arg>> arguments;
for (auto &argument : task_arguments) {
arguments.push_back(argument->ToFlatbuffer(fbb));
}
// General task.
auto spec = CreateTaskInfo(
fbb, to_flatbuf(fbb, JobID::FromRandom()), to_flatbuf(fbb, task_id),
to_flatbuf(fbb, TaskID::FromRandom()), 0, to_flatbuf(fbb, ActorID::Nil()),
to_flatbuf(fbb, ObjectID::Nil()), 0, to_flatbuf(fbb, ActorID::Nil()),
to_flatbuf(fbb, ActorHandleID::Nil()), 0,
ids_to_flatbuf(fbb, std::vector<ObjectID>()), fbb.CreateVector(arguments), 1,
map_to_flatbuf(fbb, {}), map_to_flatbuf(fbb, {}), Language::PYTHON,
string_vec_to_flatbuf(fbb, {"PackageName", "ClassName", "FunctionName"}));
fbb.Finish(spec);
RAY_LOG(ERROR) << "Ordinary task info size: " << fbb.GetSize();
}
{
flatbuffers::FlatBufferBuilder fbb;
std::vector<flatbuffers::Offset<Arg>> arguments;
for (auto &argument : task_arguments) {
arguments.push_back(argument->ToFlatbuffer(fbb));
}
// Actor task.
auto spec = CreateTaskInfo(
fbb, to_flatbuf(fbb, JobID::FromRandom()), to_flatbuf(fbb, task_id),
to_flatbuf(fbb, TaskID::FromRandom()), 10, to_flatbuf(fbb, ActorID::FromRandom()),
to_flatbuf(fbb, ObjectID::FromRandom()), 10000000,
to_flatbuf(fbb, ActorID::FromRandom()),
to_flatbuf(fbb, ActorHandleID::FromRandom()), 20,
ids_to_flatbuf(
fbb, std::vector<ObjectID>({ObjectID::FromRandom(), ObjectID::FromRandom()})),
fbb.CreateVector(arguments), 2, map_to_flatbuf(fbb, {}), map_to_flatbuf(fbb, {}),
Language::PYTHON,
string_vec_to_flatbuf(fbb, {"PackageName", "ClassName", "FunctionName"}));
fbb.Finish(spec);
RAY_LOG(ERROR) << "Actor task info size: " << fbb.GetSize();
}
}
} // namespace raylet
} // namespace ray
+120
View File
@@ -0,0 +1,120 @@
#ifndef RAY_RAYLET_TASK_UTIL_H
#define RAY_RAYLET_TASK_UTIL_H
#include "ray/protobuf/common.pb.h"
#include "ray/raylet/task_spec.h"
namespace ray {
namespace raylet {
/// Helper class for building a `TaskSpecification` object.
class TaskSpecBuilder {
public:
/// Build the `TaskSpecification` object.
TaskSpecification Build() { return TaskSpecification(message_); }
/// Get a reference to the internal protobuf message object.
const rpc::TaskSpec &GetMessage() const { return message_; }
/// Set the common attributes of the task spec.
/// See `common.proto` for meaning of the arguments.
///
/// \return Reference to the builder object itself.
TaskSpecBuilder &SetCommonTaskSpec(
const Language &language, const std::vector<std::string> &function_descriptor,
const JobID &job_id, const TaskID &parent_task_id, uint64_t parent_counter,
uint64_t num_returns,
const std::unordered_map<std::string, double> &required_resources,
const std::unordered_map<std::string, double> &required_placement_resources) {
message_.set_type(rpc::TaskType::NORMAL_TASK);
message_.set_language(language);
for (const auto &fd : function_descriptor) {
message_.add_function_descriptor(fd);
}
message_.set_job_id(job_id.Binary());
message_.set_task_id(GenerateTaskId(job_id, parent_task_id, parent_counter).Binary());
message_.set_parent_task_id(parent_task_id.Binary());
message_.set_parent_counter(parent_counter);
message_.set_num_returns(num_returns);
message_.mutable_required_resources()->insert(required_resources.begin(),
required_resources.end());
message_.mutable_required_placement_resources()->insert(
required_placement_resources.begin(), required_placement_resources.end());
return *this;
}
/// Add a by-reference argument to the task.
///
/// \param arg_id Id of the argument.
/// \return Reference to the builder object itself.
TaskSpecBuilder &AddByRefArg(const ObjectID &arg_id) {
message_.add_args()->add_object_ids(arg_id.Binary());
return *this;
}
/// Add a by-value argument to the task.
///
/// \param data String object that contains the data.
/// \return Reference to the builder object itself.
TaskSpecBuilder &AddByValueArg(const std::string &data) {
message_.add_args()->set_data(data);
return *this;
}
/// Add a by-value argument to the task.
///
/// \param data Pointer to the data.
/// \param size Size of the data.
/// \return Reference to the builder object itself.
TaskSpecBuilder &AddByValueArg(const void *data, size_t size) {
message_.add_args()->set_data(data, size);
return *this;
}
/// Set the `ActorCreationTaskSpec` of the task spec.
/// See `common.proto` for meaning of the arguments.
///
/// \return Reference to the builder object itself.
TaskSpecBuilder &SetActorCreationTaskSpec(
const ActorID &actor_id, uint64_t max_reconstructions = 0,
const std::vector<std::string> &dynamic_worker_options = {}) {
message_.set_type(TaskType::ACTOR_CREATION_TASK);
auto actor_creation_spec = message_.mutable_actor_creation_task_spec();
actor_creation_spec->set_actor_id(actor_id.Binary());
actor_creation_spec->set_max_actor_reconstructions(max_reconstructions);
for (const auto &option : dynamic_worker_options) {
actor_creation_spec->add_dynamic_worker_options(option);
}
return *this;
}
/// Set the `ActorTaskSpec` of the task spec.
/// See `common.proto` for meaning of the arguments.
///
/// \return Reference to the builder object itself.
TaskSpecBuilder &SetActorTaskSpec(
const ActorID &actor_id, const ActorHandleID &actor_handle_id,
const ObjectID &actor_creation_dummy_object_id, uint64_t actor_counter,
const std::vector<ActorHandleID> &new_handle_ids = {}) {
message_.set_type(TaskType::ACTOR_TASK);
auto actor_spec = message_.mutable_actor_task_spec();
actor_spec->set_actor_id(actor_id.Binary());
actor_spec->set_actor_handle_id(actor_handle_id.Binary());
actor_spec->set_actor_creation_dummy_object_id(
actor_creation_dummy_object_id.Binary());
actor_spec->set_actor_counter(actor_counter);
for (const auto &id : new_handle_ids) {
actor_spec->add_new_actor_handles(id.Binary());
}
return *this;
}
private:
rpc::TaskSpec message_;
};
} // namespace raylet
} // namespace ray
#endif // RAY_RAYLET_TASK_UTIL_H
+3
View File
@@ -5,12 +5,15 @@
#include "ray/common/client_connection.h"
#include "ray/common/id.h"
#include "ray/protobuf/common.pb.h"
#include "ray/raylet/scheduling_resources.h"
namespace ray {
namespace raylet {
using rpc::Language;
/// Worker class encapsulates the implementation details of a worker. A worker
/// is the execution container around a unit of Ray work, such as a task or an
/// actor. Ray units of work execute in the context of a Worker.
+8 -8
View File
@@ -41,10 +41,10 @@ namespace raylet {
/// A constructor that initializes a worker pool with
/// (num_worker_processes * num_workers_per_process) workers for each language.
WorkerPool::WorkerPool(
int num_worker_processes, int num_workers_per_process,
int maximum_startup_concurrency, std::shared_ptr<gcs::AsyncGcsClient> gcs_client,
const std::unordered_map<Language, std::vector<std::string>> &worker_commands)
WorkerPool::WorkerPool(int num_worker_processes, int num_workers_per_process,
int maximum_startup_concurrency,
std::shared_ptr<gcs::AsyncGcsClient> gcs_client,
const WorkerCommandMap &worker_commands)
: num_workers_per_process_(num_workers_per_process),
multiple_for_warning_(std::max(num_worker_processes, maximum_startup_concurrency)),
maximum_startup_concurrency_(maximum_startup_concurrency),
@@ -300,7 +300,7 @@ bool WorkerPool::DisconnectWorker(const std::shared_ptr<Worker> &worker) {
RAY_CHECK(RemoveWorker(state.registered_workers, worker));
stats::CurrentWorker().Record(
0, {{stats::LanguageKey, EnumNameLanguage(worker->GetLanguage())},
0, {{stats::LanguageKey, Language_Name(worker->GetLanguage())},
{stats::WorkerPidKey, std::to_string(worker->Pid())}});
return RemoveWorker(state.idle, worker);
@@ -310,7 +310,7 @@ void WorkerPool::DisconnectDriver(const std::shared_ptr<Worker> &driver) {
auto &state = GetStateForLanguage(driver->GetLanguage());
RAY_CHECK(RemoveWorker(state.registered_drivers, driver));
stats::CurrentDriver().Record(
0, {{stats::LanguageKey, EnumNameLanguage(driver->GetLanguage())},
0, {{stats::LanguageKey, Language_Name(driver->GetLanguage())},
{stats::WorkerPidKey, std::to_string(driver->Pid())}});
}
@@ -382,14 +382,14 @@ void WorkerPool::RecordMetrics() const {
// Record worker.
for (auto worker : entry.second.registered_workers) {
stats::CurrentWorker().Record(
worker->Pid(), {{stats::LanguageKey, EnumNameLanguage(worker->GetLanguage())},
worker->Pid(), {{stats::LanguageKey, Language_Name(worker->GetLanguage())},
{stats::WorkerPidKey, std::to_string(worker->Pid())}});
}
// Record driver.
for (auto driver : entry.second.registered_drivers) {
stats::CurrentDriver().Record(
driver->Pid(), {{stats::LanguageKey, EnumNameLanguage(driver->GetLanguage())},
driver->Pid(), {{stats::LanguageKey, Language_Name(driver->GetLanguage())},
{stats::WorkerPidKey, std::to_string(driver->Pid())}});
}
}
+11 -6
View File
@@ -8,7 +8,7 @@
#include "ray/common/client_connection.h"
#include "ray/gcs/client.h"
#include "ray/gcs/format/util.h"
#include "ray/protobuf/common.pb.h"
#include "ray/raylet/task.h"
#include "ray/raylet/worker.h"
@@ -16,6 +16,11 @@ namespace ray {
namespace raylet {
using rpc::Language;
using WorkerCommandMap =
std::unordered_map<Language, std::vector<std::string>, std::hash<int>>;
class Worker;
/// \class WorkerPool
@@ -36,10 +41,10 @@ class WorkerPool {
/// resources on the machine).
/// \param worker_commands The commands used to start the worker process, grouped by
/// language.
WorkerPool(
int num_worker_processes, int num_workers_per_process,
int maximum_startup_concurrency, std::shared_ptr<gcs::AsyncGcsClient> gcs_client,
const std::unordered_map<Language, std::vector<std::string>> &worker_commands);
WorkerPool(int num_worker_processes, int num_workers_per_process,
int maximum_startup_concurrency,
std::shared_ptr<gcs::AsyncGcsClient> gcs_client,
const WorkerCommandMap &worker_commands);
/// Destructor responsible for freeing a set of workers owned by this class.
virtual ~WorkerPool();
@@ -179,7 +184,7 @@ class WorkerPool {
/// The number of workers per process.
int num_workers_per_process_;
/// Pool states per language.
std::unordered_map<Language, State> states_by_lang_;
std::unordered_map<Language, State, std::hash<int>> states_by_lang_;
private:
/// A helper function that returns the reference of the pool state
+21 -13
View File
@@ -18,8 +18,7 @@ class WorkerPoolMock : public WorkerPool {
: WorkerPoolMock({{Language::PYTHON, {"dummy_py_worker_command"}},
{Language::JAVA, {"dummy_java_worker_command"}}}) {}
explicit WorkerPoolMock(
const std::unordered_map<Language, std::vector<std::string>> &worker_commands)
explicit WorkerPoolMock(const WorkerCommandMap &worker_commands)
: WorkerPool(0, NUM_WORKERS_PER_PROCESS, MAXIMUM_STARTUP_CONCURRENCY, nullptr,
worker_commands),
last_worker_pid_(0) {}
@@ -89,8 +88,7 @@ class WorkerPoolTest : public ::testing::Test {
return std::shared_ptr<Worker>(new Worker(pid, language, -1, client));
}
void SetWorkerCommands(
const std::unordered_map<Language, std::vector<std::string>> &worker_commands) {
void SetWorkerCommands(const WorkerCommandMap &worker_commands) {
WorkerPoolMock worker_pool(worker_commands);
this->worker_pool_ = std::move(worker_pool);
}
@@ -107,11 +105,23 @@ class WorkerPoolTest : public ::testing::Test {
static inline TaskSpecification ExampleTaskSpec(
const ActorID actor_id = ActorID::Nil(), const Language &language = Language::PYTHON,
const ActorID actor_creation_id = ActorID::Nil()) {
std::vector<std::string> function_descriptor(3);
return TaskSpecification(JobID::Nil(), TaskID::Nil(), 0, actor_creation_id,
ObjectID::Nil(), 0, actor_id, ActorHandleID::Nil(), 0, {}, {},
0, {}, {}, language, function_descriptor);
const ActorID actor_creation_id = ActorID::Nil(),
const std::vector<std::string> &dynamic_worker_options = {}) {
rpc::TaskSpec message;
message.set_language(language);
if (!actor_id.IsNil()) {
message.set_type(rpc::TaskType::ACTOR_TASK);
message.mutable_actor_task_spec()->set_actor_id(actor_id.Binary());
} else if (!actor_creation_id.IsNil()) {
message.set_type(rpc::TaskType::ACTOR_CREATION_TASK);
message.mutable_actor_creation_task_spec()->set_actor_id(actor_creation_id.Binary());
for (const auto &option : dynamic_worker_options) {
message.mutable_actor_creation_task_spec()->add_dynamic_worker_options(option);
}
} else {
message.set_type(rpc::TaskType::NORMAL_TASK);
}
return TaskSpecification(std::move(message));
}
TEST_F(WorkerPoolTest, HandleWorkerRegistration) {
@@ -226,10 +236,8 @@ TEST_F(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) {
SetWorkerCommands({{Language::PYTHON, {"dummy_py_worker_command"}},
{Language::JAVA, java_worker_command}});
TaskSpecification task_spec(JobID::Nil(), TaskID::Nil(), 0, ActorID::FromRandom(),
ObjectID::Nil(), 0, ActorID::Nil(), ActorHandleID::Nil(), 0,
{}, {}, 0, {}, {}, Language::JAVA, {"", "", ""},
{"test_op_0", "test_op_1"});
TaskSpecification task_spec = ExampleTaskSpec(
ActorID::Nil(), Language::JAVA, ActorID::FromRandom(), {"test_op_0", "test_op_1"});
worker_pool_.StartWorkerProcess(Language::JAVA, task_spec.DynamicWorkerOptions());
const auto real_command =
worker_pool_.GetWorkerCommand(worker_pool_.LastStartedWorkerProcess());
+42
View File
@@ -0,0 +1,42 @@
#ifndef RAY_RPC_WRAPPER_H
#define RAY_RPC_WRAPPER_H
#include <memory>
#include <string>
#include <utility>
namespace ray {
namespace rpc {
/// Wrap a protobuf message.
///
/// \tparam Message The protobuf message type being wrapped. It must be
/// default-constructible and provide `ParseFromString` and `SerializeAsString`.
template <class Message>
class MessageWrapper {
 public:
  /// Construct from a protobuf message object.
  /// The message is taken by value and moved into this object: an lvalue
  /// argument is **copied** once, while an rvalue argument is moved without
  /// any copy.
  ///
  /// \param message The protobuf message.
  explicit MessageWrapper(Message message) : message_(std::move(message)) {}

  /// Construct from protobuf-serialized binary.
  ///
  /// NOTE: the result of `ParseFromString` is not checked here, so a malformed
  /// input is not reported to the caller.
  ///
  /// \param serialized_binary Protobuf-serialized binary.
  explicit MessageWrapper(const std::string &serialized_binary) {
    message_.ParseFromString(serialized_binary);
  }

  /// Get reference of the protobuf message.
  ///
  /// \return A const reference to the wrapped message, valid as long as this
  /// wrapper is alive.
  const Message &GetMessage() const { return message_; }

  /// Serialize the message to a string.
  ///
  /// \return The result of `SerializeAsString` on the wrapped message.
  std::string Serialize() const { return message_.SerializeAsString(); }

 protected:
  /// The wrapped message.
  Message message_;
};
} // namespace rpc
} // namespace ray
#endif // RAY_RPC_WRAPPER_H
+20
View File
@@ -1,6 +1,7 @@
#ifndef RAY_RPC_UTIL_H
#define RAY_RPC_UTIL_H
#include <google/protobuf/map.h>
#include <google/protobuf/repeated_field.h>
#include <grpcpp/grpcpp.h>
@@ -28,18 +29,37 @@ inline Status GrpcStatusToRayStatus(const grpc::Status &grpc_status) {
}
}
/// Copies the elements of a Protobuf `RepeatedPtrField` into a `std::vector`.
///
/// \param pb_repeated The repeated field to copy from.
/// \return A vector holding a copy of each element, in the same order.
template <class T>
inline std::vector<T> VectorFromProtobuf(
    const ::google::protobuf::RepeatedPtrField<T> &pb_repeated) {
  std::vector<T> elements;
  elements.assign(pb_repeated.begin(), pb_repeated.end());
  return elements;
}
/// Copies the elements of a Protobuf `RepeatedField` into a `std::vector`.
///
/// \param pb_repeated The repeated field to copy from.
/// \return A vector holding a copy of each element, in the same order.
template <class T>
inline std::vector<T> VectorFromProtobuf(
    const ::google::protobuf::RepeatedField<T> &pb_repeated) {
  std::vector<T> values;
  values.insert(values.end(), pb_repeated.begin(), pb_repeated.end());
  return values;
}
/// Converts a Protobuf `RepeatedPtrField` of binary strings to a vector of IDs.
///
/// Each string is converted with `ID::FromBinary`. The repeated field is
/// iterated directly, so no intermediate vector of strings is built.
///
/// \tparam ID The ID type to produce; must provide a static `FromBinary`.
/// \param pb_repeated The repeated field holding the binary form of each ID.
/// \return The converted IDs, in the same order as `pb_repeated`.
template <class ID>
inline std::vector<ID> IdVectorFromProtobuf(
    const ::google::protobuf::RepeatedPtrField<::std::string> &pb_repeated) {
  std::vector<ID> ret;
  // The final size is known up front, so allocate once.
  ret.reserve(static_cast<size_t>(pb_repeated.size()));
  for (const auto &binary : pb_repeated) {
    ret.push_back(ID::FromBinary(binary));
  }
  return ret;
}
/// Converts a Protobuf map to a `unordered_map`.
template <class K, class V>
inline std::unordered_map<K, V> MapFromProtobuf(::google::protobuf::Map<K, V> pb_map) {
return std::unordered_map<K, V>(pb_map.begin(), pb_map.end());
}
} // namespace rpc
} // namespace ray