diff --git a/.gitignore b/.gitignore
index b26c1d6b4..9b6947880 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,8 +4,6 @@
/python/ray/pyarrow_files/
/python/build
/python/dist
-/python/flatbuffers-1.7.1/
-/flatbuffers-1.7.1/
/thirdparty/pkg/
# Files generated by flatc should be ignored
diff --git a/BUILD.bazel b/BUILD.bazel
index f2d458e31..c9ecaad87 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -11,10 +11,27 @@ COPTS = ["-DRAY_USE_GLOG"]
# === Begin of protobuf definitions ===
+proto_library(
+ name = "common_proto",
+ srcs = ["src/ray/protobuf/common.proto"],
+ visibility = ["//java:__subpackages__"],
+)
+
+cc_proto_library(
+ name = "common_cc_proto",
+ deps = [":common_proto"],
+)
+
+python_proto_compile(
+ name = "common_py_proto",
+ deps = [":common_proto"],
+)
+
proto_library(
name = "gcs_proto",
srcs = ["src/ray/protobuf/gcs.proto"],
visibility = ["//java:__subpackages__"],
+ deps = [":common_proto"],
)
cc_proto_library(
@@ -30,6 +47,7 @@ python_proto_compile(
proto_library(
name = "node_manager_proto",
srcs = ["src/ray/protobuf/node_manager.proto"],
+ deps = [":common_proto"],
)
cc_proto_library(
@@ -50,6 +68,7 @@ cc_proto_library(
proto_library(
name = "worker_proto",
srcs = ["src/ray/protobuf/worker.proto"],
+ deps = [":common_proto"],
)
cc_proto_library(
@@ -60,6 +79,7 @@ cc_proto_library(
proto_library(
name = "core_worker_proto",
srcs = ["src/ray/protobuf/core_worker.proto"],
+ deps = [":common_proto"],
)
cc_proto_library(
@@ -242,8 +262,8 @@ cc_library(
copts = COPTS,
linkopts = ["-pthread"],
deps = [
+ ":common_cc_proto",
":gcs",
- ":gcs_fbs",
":node_manager_fbs",
":node_manager_rpc",
":object_manager",
@@ -491,11 +511,7 @@ cc_library(
],
),
copts = COPTS,
- includes = [
- "src/ray/gcs/format",
- ],
deps = [
- ":gcs_fbs",
":ray_util",
"@boost//:asio",
"@plasma//:plasma_client",
@@ -550,15 +566,10 @@ cc_library(
),
hdrs = glob([
"src/ray/gcs/*.h",
- "src/ray/gcs/format/*.h",
]),
copts = COPTS,
- includes = [
- "src/ray/gcs/format",
- ],
deps = [
":gcs_cc_proto",
- ":gcs_fbs",
":hiredis",
":node_manager_fbs",
":node_manager_rpc",
@@ -598,13 +609,6 @@ FLATC_ARGS = [
"--scoped-enums",
]
-flatbuffer_cc_library(
- name = "gcs_fbs",
- srcs = ["src/ray/gcs/format/gcs.fbs"],
- flatc_args = FLATC_ARGS,
- out_prefix = "src/ray/gcs/format/",
-)
-
flatbuffer_cc_library(
name = "common_fbs",
srcs = ["@plasma//:cpp/src/plasma/format/common.fbs"],
@@ -616,8 +620,6 @@ flatbuffer_cc_library(
name = "node_manager_fbs",
srcs = ["src/ray/raylet/format/node_manager.fbs"],
flatc_args = FLATC_ARGS,
- include_paths = ["src/ray/gcs/format"],
- includes = [":gcs_fbs_includes"],
out_prefix = "src/ray/raylet/format/",
)
@@ -686,43 +688,6 @@ filegroup(
visibility = ["//java:__subpackages__"],
)
-filegroup(
- name = "gcs_fbs_file",
- srcs = ["src/ray/gcs/format/gcs.fbs"],
- visibility = ["//java:__subpackages__"],
-)
-
-flatbuffer_py_library(
- name = "python_node_manager_fbs",
- srcs = [
- "src/ray/raylet/format/node_manager.fbs",
- ],
- outs = [
- "ray/protocol/DisconnectClient.py",
- "ray/protocol/FetchOrReconstruct.py",
- "ray/protocol/ForwardTaskRequest.py",
- "ray/protocol/FreeObjectsRequest.py",
- "ray/protocol/GetTaskReply.py",
- "ray/protocol/MessageType.py",
- "ray/protocol/NotifyUnblocked.py",
- "ray/protocol/PushErrorRequest.py",
- "ray/protocol/RegisterClientReply.py",
- "ray/protocol/RegisterClientRequest.py",
- "ray/protocol/RegisterNodeManagerRequest.py",
- "ray/protocol/ResourceIdSetInfo.py",
- "ray/protocol/SubmitTaskRequest.py",
- "ray/protocol/Task.py",
- "ray/protocol/TaskExecutionSpecification.py",
- "ray/protocol/WaitReply.py",
- "ray/protocol/WaitRequest.py",
- ],
- include_paths = [
- "src/ray/gcs/format/",
- ],
- includes = ["src/ray/gcs/format/gcs.fbs"],
- out_prefix = "python/ray/core/generated/",
-)
-
filegroup(
name = "python_sources",
srcs = glob([
@@ -781,13 +746,20 @@ cc_binary(
],
)
+filegroup(
+ name = "all_py_proto",
+ srcs = [
+ "common_py_proto",
+ "gcs_py_proto",
+ ],
+)
+
genrule(
name = "ray_pkg",
srcs = [
"python/ray/_raylet.so",
"//:python_sources",
- "//:gcs_py_proto",
- "//:python_node_manager_fbs",
+ "//:all_py_proto",
"//:redis-server",
"//:redis-cli",
"//:libray_redis_module.so",
@@ -809,12 +781,12 @@ genrule(
cp -f $(location @plasma//:plasma_store_server) $$WORK_DIR/python/ray/core/src/plasma/ &&
cp -f $(location //:raylet) $$WORK_DIR/python/ray/core/src/ray/raylet/ &&
mkdir -p $$WORK_DIR/python/ray/core/generated/ray/protocol/ &&
- for f in $(locations //:python_node_manager_fbs); do
- cp -f $$f $$WORK_DIR/python/ray/core/generated/ray/protocol/;
- done &&
- for f in $(locations //:gcs_py_proto); do
+ for f in $(locations //:all_py_proto); do
cp -f $$f $$WORK_DIR/python/ray/core/generated/;
done &&
+ # NOTE(hchen): Protobuf doesn't allow specifying Python package name. So we use this `sed`
+ # command to change the import path in the generated file.
+ sed -i -E 's/from src.ray.protobuf/from ./' $$WORK_DIR/python/ray/core/generated/gcs_pb2.py &&
echo $$WORK_DIR > $@
""",
local = 1,
diff --git a/bazel/ray.bzl b/bazel/ray.bzl
index 750b90a21..8f41e05aa 100644
--- a/bazel/ray.bzl
+++ b/bazel/ray.bzl
@@ -13,23 +13,22 @@ def flatbuffer_py_library(name, srcs, outs, out_prefix, includes = [], include_p
includes = includes,
)
-def flatbuffer_java_library(name, srcs, outs, out_prefix, includes = [], include_paths = []):
- flatbuffer_library_public(
- name = name,
- srcs = srcs,
- outs = outs,
- language_flag = "-j",
- out_prefix = out_prefix,
- include_paths = include_paths,
- includes = includes,
- )
-
-def define_java_module(name, additional_srcs = [], additional_resources = [], define_test_lib = False, test_deps = [], **kwargs):
+def define_java_module(
+ name,
+ additional_srcs = [],
+ exclude_srcs = [],
+ additional_resources = [],
+ define_test_lib = False,
+ test_deps = [],
+ **kwargs):
lib_name = "org_ray_ray_" + name
pom_file_targets = [lib_name]
native.java_library(
name = lib_name,
- srcs = additional_srcs + native.glob([name + "/src/main/java/**/*.java"]),
+ srcs = additional_srcs + native.glob(
+ [name + "/src/main/java/**/*.java"],
+ exclude = exclude_srcs,
+ ),
resources = native.glob([name + "/src/main/resources/**"]) + additional_resources,
**kwargs
)
diff --git a/java/BUILD.bazel b/java/BUILD.bazel
index 4960434af..b9c43424f 100644
--- a/java/BUILD.bazel
+++ b/java/BUILD.bazel
@@ -1,4 +1,4 @@
-load("//bazel:ray.bzl", "flatbuffer_java_library", "define_java_module")
+load("//bazel:ray.bzl", "define_java_module")
load("@build_stack_rules_proto//java:java_proto_compile.bzl", "java_proto_compile")
exports_files([
@@ -50,8 +50,10 @@ define_java_module(
define_java_module(
name = "runtime",
additional_srcs = [
- ":generate_java_gcs_fbs",
- ":gcs_java_proto",
+ ":all_java_proto",
+ ],
+ exclude_srcs = [
+ "runtime/src/main/java/org/ray/runtime/generated/*.java",
],
additional_resources = [
":java_native_deps",
@@ -68,7 +70,6 @@ define_java_module(
deps = [
":org_ray_ray_api",
"@plasma//:org_apache_arrow_arrow_plasma",
- "@maven//:com_github_davidmoten_flatbuffers_java",
"@maven//:com_google_guava_guava",
"@maven//:com_google_protobuf_protobuf_java",
"@maven//:com_typesafe_config",
@@ -151,39 +152,22 @@ java_binary(
],
)
+java_proto_compile(
+ name = "common_java_proto",
+ deps = ["@//:common_proto"],
+)
+
java_proto_compile(
name = "gcs_java_proto",
deps = ["@//:gcs_proto"],
)
-flatbuffers_generated_files = [
- "Arg.java",
- "Language.java",
- "TaskInfo.java",
- "ResourcePair.java",
-]
-
-flatbuffer_java_library(
- name = "java_gcs_fbs",
- srcs = ["//:gcs_fbs_file"],
- outs = flatbuffers_generated_files,
- out_prefix = "",
-)
-
-genrule(
- name = "generate_java_gcs_fbs",
- srcs = [":java_gcs_fbs"],
- outs = [
- "runtime/src/main/java/org/ray/runtime/generated/" + file for file in flatbuffers_generated_files
+filegroup(
+ name = "all_java_proto",
+ srcs = [
+ ":common_java_proto",
+ ":gcs_java_proto",
],
- cmd = """
- for f in $(locations //java:java_gcs_fbs); do
- chmod +w $$f
- mv -f $$f $(@D)/runtime/src/main/java/org/ray/runtime/generated
- done
- python $$(pwd)/java/modify_generated_java_flatbuffers_files.py $(@D)/..
- """,
- local = 1,
)
filegroup(
@@ -202,8 +186,7 @@ filegroup(
genrule(
name = "gen_maven_deps",
srcs = [
- ":gcs_java_proto",
- ":generate_java_gcs_fbs",
+ ":all_java_proto",
":java_native_deps",
":copy_pom_file",
"@plasma//:org_apache_arrow_arrow_plasma",
@@ -212,6 +195,11 @@ genrule(
cmd = """
set -x
WORK_DIR=$$(pwd)
+ # Copy protobuf-generated files.
+ rm -rf $$WORK_DIR/java/runtime/src/main/java/org/ray/runtime/generated
+ for f in $(locations //java:all_java_proto); do
+ unzip $$f -x META-INF/MANIFEST.MF -d $$WORK_DIR/java/runtime/src/main/java
+ done
# Copy native dependecies.
NATIVE_DEPS_DIR=$$WORK_DIR/java/runtime/native_dependencies/
rm -rf $$NATIVE_DEPS_DIR
@@ -220,18 +208,6 @@ genrule(
chmod +w $$f
cp $$f $$NATIVE_DEPS_DIR
done
- # Copy protobuf-generated files.
- GENERATED_DIR=$$WORK_DIR/java/runtime/src/main/java/org/ray/runtime/generated
- rm -rf $$GENERATED_DIR
- mkdir -p $$GENERATED_DIR
- for f in $(locations //java:gcs_java_proto); do
- unzip $$f
- mv org/ray/runtime/generated/* $$GENERATED_DIR
- done
- # Copy flatbuffers-generated files
- for f in $(locations //java:generate_java_gcs_fbs); do
- cp $$f $$GENERATED_DIR
- done
# Install plasma jar to local maven repo.
mvn install:install-file -Dfile=$(locations @plasma//:org_apache_arrow_arrow_plasma) -Dpackaging=jar \
-DgroupId=org.apache.arrow -DartifactId=arrow-plasma -Dversion=0.13.0-SNAPSHOT
diff --git a/java/dependencies.bzl b/java/dependencies.bzl
index ef6671375..26e36dff5 100644
--- a/java/dependencies.bzl
+++ b/java/dependencies.bzl
@@ -4,7 +4,6 @@ def gen_java_deps():
maven_install(
artifacts = [
"com.beust:jcommander:1.72",
- "com.github.davidmoten:flatbuffers-java:1.9.0.1",
"com.google.guava:guava:27.0.1-jre",
"com.google.protobuf:protobuf-java:3.8.0",
"com.puppycrawl.tools:checkstyle:8.15",
diff --git a/java/runtime/pom.xml b/java/runtime/pom.xml
index e13dd95f9..aba612b36 100644
--- a/java/runtime/pom.xml
+++ b/java/runtime/pom.xml
@@ -31,11 +31,6 @@
jcommander
1.72
-
- com.github.davidmoten
- flatbuffers-java
- 1.9.0.1
-
com.google.guava
guava
diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java
index 00b114460..8c00f718c 100644
--- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java
+++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java
@@ -1,13 +1,17 @@
package org.ray.runtime.raylet;
import com.google.common.base.Preconditions;
-import com.google.flatbuffers.FlatBufferBuilder;
+import com.google.common.collect.ImmutableList;
+import com.google.protobuf.ByteString;
+import com.google.protobuf.InvalidProtocolBufferException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
+import java.nio.charset.Charset;
import java.util.ArrayList;
-import java.util.HashMap;
+import java.util.Arrays;
import java.util.List;
import java.util.Map;
+import java.util.stream.Collectors;
import org.ray.api.RayObject;
import org.ray.api.WaitResult;
import org.ray.api.exception.RayException;
@@ -15,10 +19,8 @@ import org.ray.api.id.ObjectId;
import org.ray.api.id.TaskId;
import org.ray.api.id.UniqueId;
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
-import org.ray.runtime.generated.Arg;
-import org.ray.runtime.generated.Language;
-import org.ray.runtime.generated.ResourcePair;
-import org.ray.runtime.generated.TaskInfo;
+import org.ray.runtime.generated.Common;
+import org.ray.runtime.generated.Common.TaskType;
import org.ray.runtime.task.FunctionArg;
import org.ray.runtime.task.TaskLanguage;
import org.ray.runtime.task.TaskSpec;
@@ -30,15 +32,6 @@ public class RayletClientImpl implements RayletClient {
private static final Logger LOGGER = LoggerFactory.getLogger(RayletClientImpl.class);
- private static final int TASK_SPEC_BUFFER_SIZE = 2 * 1024 * 1024;
-
- /**
- * Direct buffers that are used to hold flatbuffer-serialized task specs.
- */
- private static ThreadLocal taskSpecBuffer = ThreadLocal.withInitial(() ->
- ByteBuffer.allocateDirect(TASK_SPEC_BUFFER_SIZE).order(ByteOrder.LITTLE_ENDIAN)
- );
-
/**
* The pointer to c++'s raylet client.
*/
@@ -86,21 +79,20 @@ public class RayletClientImpl implements RayletClient {
Preconditions.checkState(!spec.parentTaskId.isNil());
Preconditions.checkState(!spec.jobId.isNil());
- ByteBuffer info = convertTaskSpecToFlatbuffer(spec);
+ byte[] taskSpec = convertTaskSpecToProtobuf(spec);
byte[] cursorId = null;
if (!spec.getExecutionDependencies().isEmpty()) {
//TODO(hchen): handle more than one dependencies.
cursorId = spec.getExecutionDependencies().get(0).getBytes();
}
- nativeSubmitTask(client, cursorId, info, info.position(), info.remaining());
+ nativeSubmitTask(client, cursorId, taskSpec);
}
@Override
public TaskSpec getTask() {
byte[] bytes = nativeGetTask(client);
assert (null != bytes);
- ByteBuffer bb = ByteBuffer.wrap(bytes);
- return parseTaskSpecFromFlatbuffer(bb);
+ return parseTaskSpecFromProtobuf(bytes);
}
@Override
@@ -127,7 +119,7 @@ public class RayletClientImpl implements RayletClient {
@Override
public void freePlasmaObjects(List objectIds, boolean localOnly,
- boolean deleteCreatingTasks) {
+ boolean deleteCreatingTasks) {
byte[][] objectIdsArray = IdUtil.getIdBytes(objectIds);
nativeFreePlasmaObjects(client, objectIdsArray, localOnly, deleteCreatingTasks);
}
@@ -142,59 +134,76 @@ public class RayletClientImpl implements RayletClient {
nativeNotifyActorResumedFromCheckpoint(client, actorId.getBytes(), checkpointId.getBytes());
}
- private static TaskSpec parseTaskSpecFromFlatbuffer(ByteBuffer bb) {
- bb.order(ByteOrder.LITTLE_ENDIAN);
- TaskInfo info = TaskInfo.getRootAsTaskInfo(bb);
- UniqueId jobId = UniqueId.fromByteBuffer(info.jobIdAsByteBuffer());
- TaskId taskId = TaskId.fromByteBuffer(info.taskIdAsByteBuffer());
- TaskId parentTaskId = TaskId.fromByteBuffer(info.parentTaskIdAsByteBuffer());
- int parentCounter = info.parentCounter();
- UniqueId actorCreationId = UniqueId.fromByteBuffer(info.actorCreationIdAsByteBuffer());
- int maxActorReconstructions = info.maxActorReconstructions();
- UniqueId actorId = UniqueId.fromByteBuffer(info.actorIdAsByteBuffer());
- UniqueId actorHandleId = UniqueId.fromByteBuffer(info.actorHandleIdAsByteBuffer());
- int actorCounter = info.actorCounter();
- int numReturns = info.numReturns();
+ /**
+ * Parse `TaskSpec` protobuf bytes.
+ */
+ private static TaskSpec parseTaskSpecFromProtobuf(byte[] bytes) {
+ Common.TaskSpec taskSpec;
+ try {
+ taskSpec = Common.TaskSpec.parseFrom(bytes);
+ } catch (InvalidProtocolBufferException e) {
+ throw new RuntimeException("Invalid protobuf data.");
+ }
- // Deserialize new actor handles
- UniqueId[] newActorHandles = IdUtil.getUniqueIdsFromByteBuffer(
- info.newActorHandlesAsByteBuffer());
+ // Parse common fields.
+ UniqueId jobId = UniqueId.fromByteBuffer(taskSpec.getJobId().asReadOnlyByteBuffer());
+ TaskId taskId = TaskId.fromByteBuffer(taskSpec.getTaskId().asReadOnlyByteBuffer());
+ TaskId parentTaskId = TaskId.fromByteBuffer(taskSpec.getParentTaskId().asReadOnlyByteBuffer());
+ int parentCounter = (int) taskSpec.getParentCounter();
+ int numReturns = (int) taskSpec.getNumReturns();
+ Map resources = taskSpec.getRequiredResourcesMap();
- // Deserialize args
- FunctionArg[] args = new FunctionArg[info.argsLength()];
- for (int i = 0; i < info.argsLength(); i++) {
- Arg arg = info.args(i);
-
- int objectIdsLength = arg.objectIdsAsByteBuffer().remaining() / UniqueId.LENGTH;
+ // Parse args.
+ FunctionArg[] args = new FunctionArg[taskSpec.getArgsCount()];
+ for (int i = 0; i < args.length; i++) {
+ Common.TaskArg arg = taskSpec.getArgs(i);
+ int objectIdsLength = arg.getObjectIdsCount();
if (objectIdsLength > 0) {
Preconditions.checkArgument(objectIdsLength == 1,
"This arg has more than one id: {}", objectIdsLength);
- args[i] = FunctionArg.passByReference(ObjectId.fromByteBuffer(arg.objectIdsAsByteBuffer()));
+ args[i] = FunctionArg
+ .passByReference(ObjectId.fromByteBuffer(arg.getObjectIds(0).asReadOnlyByteBuffer()));
} else {
- ByteBuffer lbb = arg.dataAsByteBuffer();
- Preconditions.checkState(lbb != null && lbb.remaining() > 0);
- byte[] data = new byte[lbb.remaining()];
- lbb.get(data);
- args[i] = FunctionArg.passByValue(data);
+ args[i] = FunctionArg.passByValue(arg.getData().toByteArray());
}
}
- // Deserialize required resources;
- Map resources = new HashMap<>();
- for (int i = 0; i < info.requiredResourcesLength(); i++) {
- resources.put(info.requiredResources(i).key(), info.requiredResources(i).value());
- }
- // Deserialize function descriptor
- Preconditions.checkArgument(info.language() == Language.JAVA);
- Preconditions.checkArgument(info.functionDescriptorLength() == 3);
+ // Parse function descriptor
+ Preconditions.checkArgument(taskSpec.getLanguage() == Common.Language.JAVA);
+ Preconditions.checkArgument(taskSpec.getFunctionDescriptorCount() == 3);
JavaFunctionDescriptor functionDescriptor = new JavaFunctionDescriptor(
- info.functionDescriptor(0), info.functionDescriptor(1), info.functionDescriptor(2)
+ taskSpec.getFunctionDescriptor(0).toString(Charset.defaultCharset()),
+ taskSpec.getFunctionDescriptor(1).toString(Charset.defaultCharset()),
+ taskSpec.getFunctionDescriptor(2).toString(Charset.defaultCharset())
);
- // Deserialize dynamic worker options.
+ // Parse ActorCreationTaskSpec.
+ UniqueId actorCreationId = UniqueId.NIL;
+ int maxActorReconstructions = 0;
+ UniqueId[] newActorHandles = new UniqueId[0];
List dynamicWorkerOptions = new ArrayList<>();
- for (int i = 0; i < info.dynamicWorkerOptionsLength(); ++i) {
- dynamicWorkerOptions.add(info.dynamicWorkerOptions(i));
+ if (taskSpec.getType() == Common.TaskType.ACTOR_CREATION_TASK) {
+ Common.ActorCreationTaskSpec actorCreationTaskSpec = taskSpec.getActorCreationTaskSpec();
+ actorCreationId = UniqueId
+ .fromByteBuffer(actorCreationTaskSpec.getActorId().asReadOnlyByteBuffer());
+ maxActorReconstructions = (int) actorCreationTaskSpec.getMaxActorReconstructions();
+ dynamicWorkerOptions = ImmutableList
+ .copyOf(actorCreationTaskSpec.getDynamicWorkerOptionsList());
+ }
+
+ // Parse ActorTaskSpec.
+ UniqueId actorId = UniqueId.NIL;
+ UniqueId actorHandleId = UniqueId.NIL;
+ int actorCounter = 0;
+ if (taskSpec.getType() == Common.TaskType.ACTOR_TASK) {
+ Common.ActorTaskSpec actorTaskSpec = taskSpec.getActorTaskSpec();
+ actorId = UniqueId.fromByteBuffer(actorTaskSpec.getActorId().asReadOnlyByteBuffer());
+ actorHandleId = UniqueId
+ .fromByteBuffer(actorTaskSpec.getActorHandleId().asReadOnlyByteBuffer());
+ actorCounter = (int) actorTaskSpec.getActorCounter();
+ newActorHandles = actorTaskSpec.getNewActorHandlesList().stream()
+ .map(byteString -> UniqueId.fromByteBuffer(byteString.asReadOnlyByteBuffer()))
+ .toArray(UniqueId[]::new);
}
return new TaskSpec(jobId, taskId, parentTaskId, parentCounter, actorCreationId,
@@ -202,122 +211,78 @@ public class RayletClientImpl implements RayletClient {
args, numReturns, resources, TaskLanguage.JAVA, functionDescriptor, dynamicWorkerOptions);
}
- private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) {
- ByteBuffer bb = taskSpecBuffer.get();
- bb.clear();
+ /**
+ * Convert a `TaskSpec` to protobuf-serialized bytes.
+ */
+ private static byte[] convertTaskSpecToProtobuf(TaskSpec task) {
+ // Set common fields.
+ Common.TaskSpec.Builder builder = Common.TaskSpec.newBuilder()
+ .setJobId(ByteString.copyFrom(task.jobId.getBytes()))
+ .setTaskId(ByteString.copyFrom(task.taskId.getBytes()))
+ .setParentTaskId(ByteString.copyFrom(task.parentTaskId.getBytes()))
+ .setParentCounter(task.parentCounter)
+ .setNumReturns(task.numReturns)
+ .putAllRequiredResources(task.resources);
- FlatBufferBuilder fbb = new FlatBufferBuilder(bb);
- final int jobIdOffset = fbb.createString(task.jobId.toByteBuffer());
- final int taskIdOffset = fbb.createString(task.taskId.toByteBuffer());
- final int parentTaskIdOffset = fbb.createString(task.parentTaskId.toByteBuffer());
- final int parentCounter = task.parentCounter;
- final int actorCreateIdOffset = fbb.createString(task.actorCreationId.toByteBuffer());
- final int actorCreateDummyIdOffset = fbb.createString(task.actorId.toByteBuffer());
- final int maxActorReconstructions = task.maxActorReconstructions;
- final int actorIdOffset = fbb.createString(task.actorId.toByteBuffer());
- final int actorHandleIdOffset = fbb.createString(task.actorHandleId.toByteBuffer());
- final int actorCounter = task.actorCounter;
- final int numReturnsOffset = task.numReturns;
-
- // Serialize the new actor handles.
- int newActorHandlesOffset
- = fbb.createString(IdUtil.concatIds(task.newActorHandles));
-
- // Serialize args
- int[] argsOffsets = new int[task.args.length];
- for (int i = 0; i < argsOffsets.length; i++) {
- int objectIdOffset = 0;
- int dataOffset = 0;
- if (task.args[i].id != null) {
- objectIdOffset = fbb.createString(
- IdUtil.concatIds(new ObjectId[]{task.args[i].id}));
- } else {
- objectIdOffset = fbb.createString("");
- }
- if (task.args[i].data != null) {
- dataOffset = fbb.createString(ByteBuffer.wrap(task.args[i].data));
- }
- argsOffsets[i] = Arg.createArg(fbb, objectIdOffset, dataOffset);
- }
- int argsOffset = fbb.createVectorOfTables(argsOffsets);
-
- // Serialize required resources
- // The required_resources vector indicates the quantities of the different
- // resources required by this task. The index in this vector corresponds to
- // the resource type defined in the ResourceIndex enum. For example,
- int[] requiredResourcesOffsets = new int[task.resources.size()];
- int i = 0;
- for (Map.Entry entry : task.resources.entrySet()) {
- int keyOffset = fbb.createString(ByteBuffer.wrap(entry.getKey().getBytes()));
- requiredResourcesOffsets[i++] =
- ResourcePair.createResourcePair(fbb, keyOffset, entry.getValue());
- }
- int requiredResourcesOffset = fbb.createVectorOfTables(requiredResourcesOffsets);
-
- int[] requiredPlacementResourcesOffsets = new int[0];
- int requiredPlacementResourcesOffset =
- fbb.createVectorOfTables(requiredPlacementResourcesOffsets);
-
- int language;
- int functionDescriptorOffset;
+ // Set args
+ builder.addAllArgs(
+ Arrays.stream(task.args).map(arg -> {
+ Common.TaskArg.Builder argBuilder = Common.TaskArg.newBuilder();
+ if (arg.id != null) {
+ argBuilder.addObjectIds(ByteString.copyFrom(arg.id.getBytes())).build();
+ } else {
+ argBuilder.setData(ByteString.copyFrom(arg.data)).build();
+ }
+ return argBuilder.build();
+ }).collect(Collectors.toList())
+ );
+ // Set function descriptor and language.
if (task.language == TaskLanguage.JAVA) {
- // This is a Java task.
- language = Language.JAVA;
- int[] functionDescriptorOffsets = new int[]{
- fbb.createString(task.getJavaFunctionDescriptor().className),
- fbb.createString(task.getJavaFunctionDescriptor().name),
- fbb.createString(task.getJavaFunctionDescriptor().typeDescriptor)
- };
- functionDescriptorOffset = fbb.createVectorOfTables(functionDescriptorOffsets);
+ builder.setLanguage(Common.Language.JAVA);
+ builder.addAllFunctionDescriptor(ImmutableList.of(
+ ByteString.copyFrom(task.getJavaFunctionDescriptor().className.getBytes()),
+ ByteString.copyFrom(task.getJavaFunctionDescriptor().name.getBytes()),
+ ByteString.copyFrom(task.getJavaFunctionDescriptor().typeDescriptor.getBytes())
+ ));
} else {
- // This is a Python task.
- language = Language.PYTHON;
- int[] functionDescriptorOffsets = new int[]{
- fbb.createString(task.getPyFunctionDescriptor().moduleName),
- fbb.createString(task.getPyFunctionDescriptor().className),
- fbb.createString(task.getPyFunctionDescriptor().functionName),
- fbb.createString("")
- };
- functionDescriptorOffset = fbb.createVectorOfTables(functionDescriptorOffsets);
+ builder.setLanguage(Common.Language.PYTHON);
+ builder.addAllFunctionDescriptor(ImmutableList.of(
+ ByteString.copyFrom(task.getPyFunctionDescriptor().moduleName.getBytes()),
+ ByteString.copyFrom(task.getPyFunctionDescriptor().className.getBytes()),
+ ByteString.copyFrom(task.getPyFunctionDescriptor().functionName.getBytes()),
+ ByteString.EMPTY
+ ));
}
- int [] dynamicWorkerOptionsOffsets = new int[task.dynamicWorkerOptions.size()];
- for (int index = 0; index < task.dynamicWorkerOptions.size(); ++index) {
- dynamicWorkerOptionsOffsets[index] = fbb.createString(task.dynamicWorkerOptions.get(index));
+ if (!task.actorCreationId.isNil()) {
+ // Actor creation task.
+ builder.setType(TaskType.ACTOR_CREATION_TASK);
+ builder.setActorCreationTaskSpec(
+ Common.ActorCreationTaskSpec.newBuilder()
+ .setActorId(ByteString.copyFrom(task.actorCreationId.getBytes()))
+ .setMaxActorReconstructions(task.maxActorReconstructions)
+ .addAllDynamicWorkerOptions(task.dynamicWorkerOptions)
+ );
+ } else if (!task.actorId.isNil()) {
+ // Actor task.
+ builder.setType(TaskType.ACTOR_TASK);
+ List newHandles = Arrays.stream(task.newActorHandles)
+ .map(id -> ByteString.copyFrom(id.getBytes())).collect(Collectors.toList());
+ builder.setActorTaskSpec(
+ Common.ActorTaskSpec.newBuilder()
+ .setActorId(ByteString.copyFrom(task.actorId.getBytes()))
+ .setActorHandleId(ByteString.copyFrom(task.actorHandleId.getBytes()))
+ .setActorCreationDummyObjectId(ByteString.copyFrom(task.actorId.getBytes()))
+ .setActorCounter(task.actorCounter)
+ .addAllNewActorHandles(newHandles)
+ );
+ } else {
+ // Normal task.
+ builder.setType(TaskType.NORMAL_TASK);
}
- int dynamicWorkerOptionsOffset = fbb.createVectorOfTables(dynamicWorkerOptionsOffsets);
- int root = TaskInfo.createTaskInfo(
- fbb,
- jobIdOffset,
- taskIdOffset,
- parentTaskIdOffset,
- parentCounter,
- actorCreateIdOffset,
- actorCreateDummyIdOffset,
- maxActorReconstructions,
- actorIdOffset,
- actorHandleIdOffset,
- actorCounter,
- newActorHandlesOffset,
- argsOffset,
- numReturnsOffset,
- requiredResourcesOffset,
- requiredPlacementResourcesOffset,
- language,
- functionDescriptorOffset,
- dynamicWorkerOptionsOffset);
- fbb.finish(root);
- ByteBuffer buffer = fbb.dataBuffer();
-
- if (buffer.remaining() > TASK_SPEC_BUFFER_SIZE) {
- LOGGER.error(
- "Allocated buffer is not enough to transfer the task specification: {} vs {}",
- TASK_SPEC_BUFFER_SIZE, buffer.remaining());
- throw new RuntimeException("Allocated buffer is not enough to transfer to task.");
- }
- return buffer;
+ return builder.build().toByteArray();
}
public void setResource(String resourceName, double capacity, UniqueId nodeId) {
@@ -344,10 +309,9 @@ public class RayletClientImpl implements RayletClient {
private static native long nativeInit(String localSchedulerSocket, byte[] workerId,
boolean isWorker, byte[] driverTaskId);
- private static native void nativeSubmitTask(long client, byte[] cursorId, ByteBuffer taskBuff,
- int pos, int taskSize) throws RayException;
+ private static native void nativeSubmitTask(long client, byte[] cursorId, byte[] taskSpec)
+ throws RayException;
- // return TaskInfo (in FlatBuffer)
private static native byte[] nativeGetTask(long client) throws RayException;
private static native void nativeDestroy(long client) throws RayException;
@@ -375,5 +339,5 @@ public class RayletClientImpl implements RayletClient {
byte[] checkpointId);
private static native void nativeSetResource(long conn, String resourceName, double capacity,
- byte[] nodeId) throws RayException;
+ byte[] nodeId) throws RayException;
}
diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx
index f3968577d..3e5586069 100644
--- a/python/ray/_raylet.pyx
+++ b/python/ray/_raylet.pyx
@@ -24,8 +24,8 @@ from ray.includes.common cimport (
)
from ray.includes.libraylet cimport (
CRayletClient,
- GCSProfileEventT,
- GCSProfileTableDataT,
+ GCSProfileEvent,
+ GCSProfileTableData,
ResourceMappingType,
WaitResultPair,
)
@@ -34,7 +34,7 @@ from ray.includes.unique_ids cimport (
CObjectID,
CClientID,
)
-from ray.includes.task cimport CTaskSpecification
+from ray.includes.task cimport CTaskSpec
from ray.includes.ray_config cimport RayConfig
from ray.utils import decode
@@ -232,18 +232,22 @@ cdef class RayletClient:
def disconnect(self):
check_status(self.client.get().Disconnect())
- def submit_task(self, Task task_spec):
+ def submit_task(self, TaskSpec task_spec, execution_dependencies):
+ cdef:
+ CObjectID c_id
+ c_vector[CObjectID] c_dependencies
+ for dep in execution_dependencies:
+ c_dependencies.push_back((dep).native())
check_status(self.client.get().SubmitTask(
- task_spec.execution_dependencies.get()[0],
- task_spec.task_spec.get()[0]))
+ c_dependencies, task_spec.task_spec.get()[0]))
def get_task(self):
cdef:
- unique_ptr[CTaskSpecification] task_spec
+ unique_ptr[CTaskSpec] task_spec
with nogil:
check_status(self.client.get().GetTask(&task_spec))
- return Task.make(task_spec)
+ return TaskSpec.make(task_spec)
def task_done(self):
check_status(self.client.get().TaskDone())
@@ -303,19 +307,19 @@ cdef class RayletClient:
def push_profile_events(self, component_type, UniqueID component_id,
node_ip_address, profile_data):
cdef:
- GCSProfileTableDataT profile_info
- GCSProfileEventT *profile_event
+ GCSProfileTableData profile_info
+ GCSProfileEvent *profile_event
c_string event_type
if len(profile_data) == 0:
return # Short circuit if there are no profile events.
- profile_info.component_type = component_type.encode("ascii")
- profile_info.component_id = component_id.binary()
- profile_info.node_ip_address = node_ip_address.encode("ascii")
+ profile_info.set_component_type(component_type.encode("ascii"))
+ profile_info.set_component_id(component_id.binary())
+ profile_info.set_node_ip_address(node_ip_address.encode("ascii"))
for py_profile_event in profile_data:
- profile_event = new GCSProfileEventT()
+ profile_event = profile_info.add_profile_events()
if not isinstance(py_profile_event, dict):
raise TypeError(
"Incorrect type for a profile event. Expected dict "
@@ -325,28 +329,22 @@ cdef class RayletClient:
# that will cause segfaults in the node manager.
for key_string, event_data in py_profile_event.items():
if key_string == "event_type":
- profile_event.event_type = event_data.encode("ascii")
- if profile_event.event_type.length() == 0:
+ if len(event_data) == 0:
raise ValueError(
"'event_type' should not be a null string.")
+ profile_event.set_event_type(event_data.encode("ascii"))
elif key_string == "start_time":
- profile_event.start_time = float(event_data)
+ profile_event.set_start_time(float(event_data))
elif key_string == "end_time":
- profile_event.end_time = float(event_data)
+ profile_event.set_end_time(float(event_data))
elif key_string == "extra_data":
- profile_event.extra_data = event_data.encode("ascii")
- if profile_event.extra_data.length() == 0:
+ if len(event_data) == 0:
raise ValueError(
"'extra_data' should not be a null string.")
+ profile_event.set_extra_data(event_data.encode("ascii"))
else:
raise ValueError(
"Unknown profile event key '%s'" % key_string)
- # Note that profile_info.profile_events is a vector of unique
- # pointers, so profile_event will be deallocated when profile_info
- # goes out of scope. "emplace_back" of vector has not been
- # supported by Cython
- profile_info.profile_events.push_back(
- unique_ptr[GCSProfileEventT](profile_event))
check_status(self.client.get().PushProfileEvents(profile_info))
diff --git a/python/ray/core/generated/ray/__init__.py b/python/ray/core/generated/ray/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/python/ray/core/generated/ray/protocol/__init__.py b/python/ray/core/generated/ray/protocol/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py
index 1e5441485..43ce42d91 100644
--- a/python/ray/gcs_utils.py
+++ b/python/ray/gcs_utils.py
@@ -2,8 +2,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from ray.core.generated.ray.protocol.Task import Task
-
from ray.core.generated.gcs_pb2 import (
ActorCheckpointIdData,
ClientTableData,
@@ -33,7 +31,6 @@ __all__ = [
"ProfileTableData",
"TablePrefix",
"TablePubsub",
- "Task",
"TaskTableData",
"construct_error_message",
]
diff --git a/python/ray/includes/common.pxd b/python/ray/includes/common.pxd
index 5c716b673..c9b95aaec 100644
--- a/python/ray/includes/common.pxd
+++ b/python/ray/includes/common.pxd
@@ -87,17 +87,14 @@ cdef extern from "ray/common/id.h" namespace "ray" nogil:
int parent_task_counter)
-cdef extern from "ray/gcs/format/gcs_generated.h" nogil:
- cdef cppclass GCSArg "Arg":
- pass
-
+cdef extern from "ray/protobuf/common.pb.h" nogil:
cdef cppclass CLanguage "Language":
pass
# This is a workaround for C++ enum class since Cython has no corresponding
# representation.
-cdef extern from "ray/gcs/format/gcs_generated.h" namespace "Language" nogil:
+cdef extern from "ray/protobuf/common.pb.h" namespace "Language" nogil:
cdef CLanguage LANGUAGE_PYTHON "Language::PYTHON"
cdef CLanguage LANGUAGE_CPP "Language::CPP"
cdef CLanguage LANGUAGE_JAVA "Language::JAVA"
diff --git a/python/ray/includes/libraylet.pxd b/python/ray/includes/libraylet.pxd
index 3bc6eddd0..7f6d03a5c 100644
--- a/python/ray/includes/libraylet.pxd
+++ b/python/ray/includes/libraylet.pxd
@@ -19,23 +19,23 @@ from ray.includes.unique_ids cimport (
CObjectID,
CTaskID,
)
-from ray.includes.task cimport CTaskSpecification
+from ray.includes.task cimport CTaskSpec
-cdef extern from "ray/gcs/format/gcs_generated.h" nogil:
- cdef cppclass GCSProfileEventT "ProfileEventT":
- c_string event_type
- double start_time
- double end_time
- c_string extra_data
- GCSProfileEventT()
+cdef extern from "ray/protobuf/gcs.pb.h" nogil:
+ cdef cppclass GCSProfileEvent "ProfileTableData::ProfileEvent":
+ void set_event_type(const c_string &value)
+ void set_start_time(double value)
+ void set_end_time(double value)
+ c_string set_extra_data(const c_string &value)
+ GCSProfileEvent()
- cdef cppclass GCSProfileTableDataT "ProfileTableDataT":
- c_string component_type
- c_string component_id
- c_string node_ip_address
- c_vector[unique_ptr[GCSProfileEventT]] profile_events
- GCSProfileTableDataT()
+ cdef cppclass GCSProfileTableData "ProfileTableData":
+ void set_component_type(const c_string &value)
+ void set_component_id(const c_string &value)
+ void set_node_ip_address(const c_string &value)
+ GCSProfileEvent *add_profile_events()
+ GCSProfileTableData()
ctypedef unordered_map[c_string, c_vector[pair[int64_t, double]]] \
@@ -52,8 +52,8 @@ cdef extern from "ray/raylet/raylet_client.h" nogil:
CRayStatus Disconnect()
CRayStatus SubmitTask(
const c_vector[CObjectID] &execution_dependencies,
- const CTaskSpecification &task_spec)
- CRayStatus GetTask(unique_ptr[CTaskSpecification] *task_spec)
+ const CTaskSpec &task_spec)
+ CRayStatus GetTask(unique_ptr[CTaskSpec] *task_spec)
CRayStatus TaskDone()
CRayStatus FetchOrReconstruct(c_vector[CObjectID] &object_ids,
c_bool fetch_only,
@@ -66,7 +66,7 @@ cdef extern from "ray/raylet/raylet_client.h" nogil:
CRayStatus PushError(const CJobID &job_id, const c_string &type,
const c_string &error_message, double timestamp)
CRayStatus PushProfileEvents(
- const GCSProfileTableDataT &profile_events)
+ const GCSProfileTableData &profile_events)
CRayStatus FreeObjects(const c_vector[CObjectID] &object_ids,
c_bool local_only, c_bool delete_creating_tasks)
CRayStatus PrepareActorCheckpoint(const CActorID &actor_id,
diff --git a/python/ray/includes/task.pxd b/python/ray/includes/task.pxd
index 5d6511c32..b648c325d 100644
--- a/python/ray/includes/task.pxd
+++ b/python/ray/includes/task.pxd
@@ -1,4 +1,4 @@
-from libc.stdint cimport int64_t, uint8_t
+from libc.stdint cimport uint8_t, uint64_t
from libcpp cimport bool as c_bool
from libcpp.memory cimport unique_ptr, shared_ptr
from libcpp.string cimport string as c_string
@@ -17,72 +17,45 @@ from ray.includes.unique_ids cimport (
CTaskID,
)
+cdef extern from "ray/protobuf/common.pb.h" namespace "ray::rpc" nogil:
+ cdef cppclass RpcTaskSpec "ray::rpc::TaskSpec":
+ void CopyFrom(const RpcTaskSpec &value)
-cdef extern from "ray/raylet/task_execution_spec.h" \
- namespace "ray::raylet" nogil:
- cdef cppclass CTaskExecutionSpecification \
- "ray::raylet::TaskExecutionSpecification":
- CTaskExecutionSpecification(const c_vector[CObjectID] &&dependencies)
- CTaskExecutionSpecification(
- const c_vector[CObjectID] &&dependencies, int num_forwards)
- c_vector[CObjectID] ExecutionDependencies() const
- void SetExecutionDependencies(const c_vector[CObjectID] &dependencies)
- int NumForwards() const
- void IncrementNumForwards()
- int64_t LastTimestamp() const
- void SetLastTimestamp(int64_t new_timestamp)
+ cdef cppclass RpcTaskExecutionSpec "ray::rpc::TaskExecutionSpec":
+ void CopyFrom(const RpcTaskExecutionSpec &value)
+ void add_dependencies(const c_string &value)
+
+ cdef cppclass RpcTask "ray::rpc::Task":
+ RpcTaskSpec *mutable_task_spec()
+
+
+cdef extern from "ray/protobuf/gcs.pb.h" namespace "ray::rpc" nogil:
+ cdef cppclass TaskTableData "ray::rpc::TaskTableData":
+ RpcTask *mutable_task()
+ const c_string &SerializeAsString()
cdef extern from "ray/raylet/task_spec.h" namespace "ray::raylet" nogil:
- cdef cppclass CTaskArgument "ray::raylet::TaskArgument":
- pass
-
- cdef cppclass CTaskArgumentByReference \
- "ray::raylet::TaskArgumentByReference":
- CTaskArgumentByReference(const c_vector[CObjectID] &references)
-
- cdef cppclass CTaskArgumentByValue "ray::raylet::TaskArgumentByValue":
- CTaskArgumentByValue(const uint8_t *value, size_t length)
-
- cdef cppclass CTaskSpecification "ray::raylet::TaskSpecification":
- CTaskSpecification(
- const CJobID &job_id, const CTaskID &parent_task_id,
- int64_t parent_counter,
- const c_vector[shared_ptr[CTaskArgument]] &task_arguments,
- int64_t num_returns,
- const unordered_map[c_string, double] &required_resources,
- const CLanguage &language,
- const c_vector[c_string] &function_descriptor)
- CTaskSpecification(
- const CJobID &job_id, const CTaskID &parent_task_id,
- int64_t parent_counter, const CActorID &actor_creation_id,
- const CObjectID &actor_creation_dummy_object_id,
- int64_t max_actor_reconstructions, const CActorID &actor_id,
- const CActorHandleID &actor_handle_id, int64_t actor_counter,
- const c_vector[CActorHandleID] &new_actor_handles,
- const c_vector[shared_ptr[CTaskArgument]] &task_arguments,
- int64_t num_returns,
- const unordered_map[c_string, double] &required_resources,
- const unordered_map[c_string, double] &required_placement_res,
- const CLanguage &language,
- const c_vector[c_string] &function_descriptor)
- CTaskSpecification(const c_string &string)
- c_string SerializeAsString() const
+ cdef cppclass CTaskSpec "ray::raylet::TaskSpecification":
+ CTaskSpec(const RpcTaskSpec message)
+ CTaskSpec(const c_string &serialized_binary)
+ const RpcTaskSpec &GetMessage()
+ c_string Serialize() const
CTaskID TaskId() const
CJobID JobId() const
CTaskID ParentTaskId() const
- int64_t ParentCounter() const
+ uint64_t ParentCounter() const
c_vector[c_string] FunctionDescriptor() const
c_string FunctionDescriptorString() const
- int64_t NumArgs() const
- int64_t NumReturns() const
- c_bool ArgByRef(int64_t arg_index) const
- int ArgIdCount(int64_t arg_index) const
- CObjectID ArgId(int64_t arg_index, int64_t id_index) const
- CObjectID ReturnId(int64_t return_index) const
- const uint8_t *ArgVal(int64_t arg_index) const
- size_t ArgValLength(int64_t arg_index) const
+ uint64_t NumArgs() const
+ uint64_t NumReturns() const
+ c_bool ArgByRef(uint64_t arg_index) const
+ int ArgIdCount(uint64_t arg_index) const
+ CObjectID ArgId(uint64_t arg_index, uint64_t id_index) const
+ CObjectID ReturnId(uint64_t return_index) const
+ const uint8_t *ArgVal(uint64_t arg_index) const
+ size_t ArgValLength(uint64_t arg_index) const
double GetRequiredResource(const c_string &resource_name) const
const ResourceSet GetRequiredResources() const
const ResourceSet GetRequiredPlacementResources() const
@@ -93,25 +66,46 @@ cdef extern from "ray/raylet/task_spec.h" namespace "ray::raylet" nogil:
c_bool IsActorTask() const
CActorID ActorCreationId() const
CObjectID ActorCreationDummyObjectId() const
- int64_t MaxActorReconstructions() const
+ uint64_t MaxActorReconstructions() const
CActorID ActorId() const
CActorHandleID ActorHandleId() const
- int64_t ActorCounter() const
+ uint64_t ActorCounter() const
CObjectID ActorDummyObject() const
c_vector[CActorHandleID] NewActorHandles() const
+cdef extern from "ray/raylet/task_util.h" namespace "ray::raylet" nogil:
+ cdef cppclass TaskSpecBuilder "ray::raylet::TaskSpecBuilder":
+ TaskSpecBuilder &SetCommonTaskSpec(
+ const CLanguage &language, const c_vector[c_string] &function_descriptor,
+ const CJobID &job_id, const CTaskID &parent_task_id, uint64_t parent_counter,
+ uint64_t num_returns, const unordered_map[c_string, double] &required_resources,
+ const unordered_map[c_string, double] &required_placement_resources);
+
+ TaskSpecBuilder &AddByRefArg(const CObjectID &arg_id);
+
+ TaskSpecBuilder &AddByValueArg(const c_string &data);
+
+ TaskSpecBuilder &SetActorCreationTaskSpec(
+ const CActorID &actor_id, uint64_t max_reconstructions,
+ const c_vector[c_string] &dynamic_worker_options);
+
+ TaskSpecBuilder &SetActorTaskSpec(
+ const CActorID &actor_id, const CActorHandleID &actor_handle_id,
+ const CObjectID &actor_creation_dummy_object_id, uint64_t actor_counter,
+ const c_vector[CActorHandleID] &new_handle_ids);
+
+ RpcTaskSpec GetMessage();
+
+
+cdef extern from "ray/raylet/task_execution_spec.h" namespace "ray::raylet" nogil:
+ cdef cppclass CTaskExecutionSpec "ray::raylet::TaskExecutionSpecification":
+ CTaskExecutionSpec(RpcTaskExecutionSpec message)
+ CTaskExecutionSpec(const c_string &serialized_binary)
+ const RpcTaskExecutionSpec &GetMessage()
+ c_vector[CObjectID] ExecutionDependencies()
+ uint64_t NumForwards()
+
cdef extern from "ray/raylet/task.h" namespace "ray::raylet" nogil:
cdef cppclass CTask "ray::raylet::Task":
- CTask(const CTaskExecutionSpecification &execution_spec,
- const CTaskSpecification &task_spec)
- const CTaskExecutionSpecification &GetTaskExecutionSpec() const
- const CTaskSpecification &GetTaskSpecification() const
- void SetExecutionDependencies(const c_vector[CObjectID] &dependencies)
- void IncrementNumForwards()
- const c_vector[CObjectID] &GetDependencies() const
- void CopyTaskExecutionSpec(const CTask &task)
-
- cdef c_string SerializeTaskAsString(
- const c_vector[CObjectID] *dependencies,
- const CTaskSpecification *task_spec)
+ CTask(CTaskSpec task_spec, CTaskExecutionSpec task_execution_spec)
diff --git a/python/ray/includes/task.pxi b/python/ray/includes/task.pxi
index c9e64d83b..42838b08a 100644
--- a/python/ray/includes/task.pxi
+++ b/python/ray/includes/task.pxi
@@ -5,18 +5,19 @@ from libcpp.memory cimport (
static_pointer_cast,
)
from ray.includes.task cimport (
- CTaskArgument,
- CTaskArgumentByReference,
- CTaskArgumentByValue,
- CTaskSpecification,
- SerializeTaskAsString,
+ CTask,
+ CTaskExecutionSpec,
+ CTaskSpec,
+ RpcTaskExecutionSpec,
+ TaskSpecBuilder,
+ TaskTableData,
)
-cdef class Task:
+cdef class TaskSpec:
+ """Cython wrapper class of C++ `ray::raylet::TaskSpecification`."""
cdef:
- unique_ptr[CTaskSpecification] task_spec
- unique_ptr[c_vector[CObjectID]] execution_dependencies
+ unique_ptr[CTaskSpec] task_spec
def __init__(self, JobID job_id, function_descriptor, arguments,
int num_returns, TaskID parent_task_id, int parent_counter,
@@ -24,73 +25,78 @@ cdef class Task:
ObjectID actor_creation_dummy_object_id,
int32_t max_actor_reconstructions, ActorID actor_id,
ActorHandleID actor_handle_id, int actor_counter,
- new_actor_handles, execution_arguments, resource_map,
- placement_resource_map):
+ new_actor_handles, resource_map, placement_resource_map):
cdef:
+ TaskSpecBuilder builder
unordered_map[c_string, double] required_resources
unordered_map[c_string, double] required_placement_resources
- c_vector[shared_ptr[CTaskArgument]] task_args
- c_vector[CActorHandleID] task_new_actor_handles
c_vector[c_string] c_function_descriptor
c_string pickled_str
- c_vector[CObjectID] references
+ c_vector[CActorHandleID] c_new_actor_handles
+ # Convert function descriptor to C++ vector.
for item in function_descriptor:
if not isinstance(item, bytes):
raise TypeError(
"'function_descriptor' takes a list of byte strings.")
c_function_descriptor.push_back(item)
- # Parse the resource map.
+ # Convert resource map to C++ unordered_map.
if resource_map is not None:
required_resources = resource_map_from_dict(resource_map)
if placement_resource_map is not None:
required_placement_resources = (
resource_map_from_dict(placement_resource_map))
- # Parse the arguments from the list.
+ # Build common task spec.
+ builder.SetCommonTaskSpec(
+ LANGUAGE_PYTHON,
+ c_function_descriptor,
+ job_id.native(),
+ parent_task_id.native(),
+ parent_counter,
+ num_returns,
+ required_resources,
+ required_placement_resources,
+ )
+
+ # Build arguments.
for arg in arguments:
if isinstance(arg, ObjectID):
- references = c_vector[CObjectID]()
- references.push_back((arg).native())
- task_args.push_back(
- static_pointer_cast[CTaskArgument,
- CTaskArgumentByReference](
- make_shared[CTaskArgumentByReference](references)))
+ builder.AddByRefArg((arg).native())
else:
pickled_str = pickle.dumps(
arg, protocol=pickle.HIGHEST_PROTOCOL)
- task_args.push_back(
- static_pointer_cast[CTaskArgument,
- CTaskArgumentByValue](
- make_shared[CTaskArgumentByValue](
- pickled_str.c_str(),
- pickled_str.size())))
+ builder.AddByValueArg(pickled_str)
- for new_actor_handle in new_actor_handles:
- task_new_actor_handles.push_back(
- (new_actor_handle).native())
-
- self.task_spec.reset(new CTaskSpecification(
- job_id.native(), parent_task_id.native(), parent_counter, actor_creation_id.native(),
- actor_creation_dummy_object_id.native(), max_actor_reconstructions, actor_id.native(),
- actor_handle_id.native(), actor_counter, task_new_actor_handles, task_args, num_returns,
- required_resources, required_placement_resources, LANGUAGE_PYTHON,
- c_function_descriptor))
-
- # Set the task's execution dependencies.
- self.execution_dependencies.reset(new c_vector[CObjectID]())
- if execution_arguments is not None:
- for execution_arg in execution_arguments:
- self.execution_dependencies.get().push_back(
- (execution_arg).native())
+ if not actor_creation_id.is_nil():
+ # Actor creation task.
+ builder.SetActorCreationTaskSpec(
+ actor_creation_id.native(),
+ max_actor_reconstructions,
+ [],
+ )
+ elif not actor_id.is_nil():
+ # Actor task.
+ for new_actor_handle in new_actor_handles:
+ c_new_actor_handles.push_back(
+ (new_actor_handle).native())
+ builder.SetActorTaskSpec(
+ actor_id.native(),
+ actor_handle_id.native(),
+ actor_creation_dummy_object_id.native(),
+ actor_counter,
+ c_new_actor_handles,
+ )
+ else:
+ # Normal task.
+ pass
+ self.task_spec.reset(new CTaskSpec(builder.GetMessage()))
@staticmethod
- cdef make(unique_ptr[CTaskSpecification]& task_spec):
- cdef Task self = Task.__new__(Task)
+ cdef make(unique_ptr[CTaskSpec]& task_spec):
+ cdef TaskSpec self = TaskSpec.__new__(TaskSpec)
self.task_spec.reset(task_spec.release())
- # The created task does not include any execution dependencies.
- self.execution_dependencies.reset(new c_vector[CObjectID]())
return self
@staticmethod
@@ -103,11 +109,8 @@ cdef class Task:
Returns:
Python task specification object.
"""
- cdef Task self = Task.__new__(Task)
- # TODO(pcm): Use flatbuffers validation here.
- self.task_spec.reset(new CTaskSpecification(task_spec_str))
- # The created task does not include any execution dependencies.
- self.execution_dependencies.reset(new c_vector[CObjectID]())
+ cdef TaskSpec self = TaskSpec.__new__(TaskSpec)
+ self.task_spec.reset(new CTaskSpec(task_spec_str))
return self
def to_string(self):
@@ -116,11 +119,7 @@ cdef class Task:
Returns:
String representing the task specification.
"""
- return self.task_spec.get().SerializeAsString()
-
- def _serialized_raylet_task(self):
- return SerializeTaskAsString(
- self.execution_dependencies.get(), self.task_spec.get())
+ return self.task_spec.get().Serialize()
def job_id(self):
"""Return the job ID for this task."""
@@ -150,7 +149,7 @@ cdef class Task:
def arguments(self):
"""Return the arguments for the task."""
cdef:
- CTaskSpecification *task_spec = self.task_spec.get()
+ CTaskSpec*task_spec = self.task_spec.get()
int64_t num_args = task_spec.NumArgs()
int32_t lang = task_spec.GetLanguage()
int count
@@ -175,7 +174,7 @@ cdef class Task:
def returns(self):
"""Return the object IDs for the return values of the task."""
- cdef CTaskSpecification *task_spec = self.task_spec.get()
+ cdef CTaskSpec *task_spec = self.task_spec.get()
return_id_list = []
for i in range(task_spec.NumReturns()):
return_id_list.append(ObjectID(task_spec.ReturnId(i).Binary()))
@@ -221,3 +220,59 @@ cdef class Task:
def actor_counter(self):
"""Return the actor counter for this task."""
return self.task_spec.get().ActorCounter()
+
+
+cdef class TaskExecutionSpec:
+ """Cython wrapper class of C++ `ray::raylet::TaskExecutionSpecification`."""
+ cdef:
+ unique_ptr[CTaskExecutionSpec] c_spec
+
+ def __init__(self, execution_dependencies):
+ cdef:
+ RpcTaskExecutionSpec message;
+
+ for dependency in execution_dependencies:
+ message.add_dependencies(
+ (dependency).binary())
+ self.c_spec.reset(new CTaskExecutionSpec(message))
+
+ @staticmethod
+ def from_string(const c_string& string):
+ """Convert a string to a Ray `TaskExecutionSpec` Python object.
+ """
+ cdef TaskExecutionSpec self = TaskExecutionSpec.__new__(TaskExecutionSpec)
+ self.c_spec.reset(new CTaskExecutionSpec(string))
+ return self
+
+ def dependencies(self):
+ cdef:
+ CObjectID c_id
+ c_vector[CObjectID] dependencies = (
+ self.c_spec.get().ExecutionDependencies())
+ ret = []
+ for c_id in dependencies:
+ ret.append(ObjectID(c_id.Binary()))
+ return ret
+
+ def num_forwards(self):
+ return self.c_spec.get().NumForwards()
+
+
+cdef class Task:
+ """Cython wrapper class of C++ `ray::raylet::Task`."""
+ cdef:
+ unique_ptr[CTask] c_task
+
+ def __init__(self, TaskSpec task_spec, TaskExecutionSpec task_execution_spec):
+ self.c_task.reset(new CTask(task_spec.task_spec.get()[0],
+ task_execution_spec.c_spec.get()[0]))
+
+
+def generate_gcs_task_table_data(TaskSpec task_spec):
+ """Converts a Python `TaskSpec` object to serialized GCS `TaskTableData`.
+ """
+ cdef:
+ TaskTableData task_table_data
+ task_table_data.mutable_task().mutable_task_spec().CopyFrom(
+ task_spec.task_spec.get().GetMessage())
+ return task_table_data.SerializeAsString()
diff --git a/python/ray/state.py b/python/ray/state.py
index 75fd8a014..c9f718fde 100644
--- a/python/ray/state.py
+++ b/python/ray/state.py
@@ -305,12 +305,9 @@ class GlobalState(object):
assert len(gcs_entries.entries) == 1
task_table_data = gcs_utils.TaskTableData.FromString(
gcs_entries.entries[0])
- task_table_message = gcs_utils.Task.GetRootAsTask(
- task_table_data.task, 0)
- execution_spec = task_table_message.TaskExecutionSpec()
- task_spec = task_table_message.TaskSpecification()
- task = ray._raylet.Task.from_string(task_spec)
+ task = ray._raylet.TaskSpec.from_string(
+ task_table_data.task.task_spec.SerializeToString())
function_descriptor_list = task.function_descriptor_list()
function_descriptor = FunctionDescriptor.from_bytes_list(
function_descriptor_list)
@@ -335,14 +332,12 @@ class GlobalState(object):
"FunctionName": function_descriptor.function_name,
}
+ execution_spec = ray._raylet.TaskExecutionSpec.from_string(
+ task_table_data.task.task_execution_spec.SerializeToString())
return {
"ExecutionSpec": {
- "Dependencies": [
- execution_spec.Dependencies(i)
- for i in range(execution_spec.DependenciesLength())
- ],
- "LastTimestamp": execution_spec.LastTimestamp(),
- "NumForwards": execution_spec.NumForwards()
+ "Dependencies": execution_spec.dependencies(),
+ "NumForwards": execution_spec.num_forwards(),
},
"TaskSpec": task_spec_info
}
diff --git a/python/ray/worker.py b/python/ray/worker.py
index 092061e84..5a3146255 100644
--- a/python/ray/worker.py
+++ b/python/ray/worker.py
@@ -688,7 +688,7 @@ class Worker(object):
function_descriptor_list = (
function_descriptor.get_function_descriptor_list())
assert isinstance(job_id, JobID)
- task = ray._raylet.Task(
+ task = ray._raylet.TaskSpec(
job_id,
function_descriptor_list,
args_for_raylet,
@@ -702,11 +702,10 @@ class Worker(object):
actor_handle_id,
actor_counter,
new_actor_handles,
- execution_dependencies,
resources,
placement_resources,
)
- self.raylet_client.submit_task(task)
+ self.raylet_client.submit_task(task, execution_dependencies)
return task.returns()
@@ -1864,7 +1863,7 @@ def connect(node,
nil_actor_counter = 0
function_descriptor = FunctionDescriptor.for_driver_task()
- driver_task = ray._raylet.Task(
+ driver_task_spec = ray._raylet.TaskSpec(
worker.current_job_id,
function_descriptor.get_function_descriptor_list(),
[], # arguments.
@@ -1878,24 +1877,25 @@ def connect(node,
ActorHandleID.nil(), # actor_handle_id.
nil_actor_counter, # actor_counter.
[], # new_actor_handles.
- [], # execution_dependencies.
{}, # resource_map.
{}, # placement_resource_map.
)
- task_table_data = ray.gcs_utils.TaskTableData()
- task_table_data.task = driver_task._serialized_raylet_task()
+ task_table_data = ray._raylet.generate_gcs_task_table_data(
+ driver_task_spec)
# Add the driver task to the task table.
ray.state.state._execute_command(
- driver_task.task_id(), "RAY.TABLE_ADD",
+ driver_task_spec.task_id(),
+ "RAY.TABLE_ADD",
ray.gcs_utils.TablePrefix.Value("RAYLET_TASK"),
ray.gcs_utils.TablePubsub.Value("RAYLET_TASK_PUBSUB"),
- driver_task.task_id().binary(),
- task_table_data.SerializeToString())
+ driver_task_spec.task_id().binary(),
+ task_table_data,
+ )
# Set the driver's current task ID to the task ID assigned to the
# driver task.
- worker.task_context.current_task_id = driver_task.task_id()
+ worker.task_context.current_task_id = driver_task_spec.task_id()
worker.raylet_client = ray._raylet.RayletClient(
node.raylet_socket_name,
diff --git a/python/setup.py b/python/setup.py
index a0b5376ca..64e366157 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -28,11 +28,10 @@ ray_files = [
"ray/dashboard/res/main.css", "ray/dashboard/res/main.js"
]
-# These are the directories where automatically generated Python flatbuffer
+# These are the directories where automatically generated Python protobuf
# bindings are created.
generated_python_directories = [
- "ray/core/generated", "ray/core/generated/ray",
- "ray/core/generated/ray/protocol"
+ "ray/core/generated",
]
optional_ray_files = []
@@ -88,7 +87,7 @@ class build_ext(_build_ext.build_ext):
files_to_include = ray_files + pyarrow_files + modin_files
- # Copy over the autogenerated flatbuffer Python bindings.
+ # Copy over the autogenerated protobuf Python bindings.
for directory in generated_python_directories:
for filename in os.listdir(directory):
if filename[-3:] == ".py":
@@ -148,7 +147,6 @@ requires = [
# NOTE: Don't upgrade the version of six! Doing so causes installation
# problems. See https://github.com/ray-project/ray/issues/4169.
"six >= 1.0.0",
- "flatbuffers",
"faulthandler;python_version<'3.3'",
"protobuf >= 3.8.0",
]
diff --git a/src/ray/common/common_protocol.cc b/src/ray/common/common_protocol.cc
index adce684fc..a58808a36 100644
--- a/src/ray/common/common_protocol.cc
+++ b/src/ray/common/common_protocol.cc
@@ -6,28 +6,6 @@ std::string string_from_flatbuf(const flatbuffers::String &string) {
return std::string(string.data(), string.size());
}
-const std::unordered_map map_from_flatbuf(
- const flatbuffers::Vector> &resource_vector) {
- std::unordered_map required_resources;
- for (int64_t i = 0; i < resource_vector.size(); i++) {
- const ResourcePair *resource_pair = resource_vector.Get(i);
- required_resources[string_from_flatbuf(*resource_pair->key())] =
- resource_pair->value();
- }
- return required_resources;
-}
-
-flatbuffers::Offset>>
-map_to_flatbuf(flatbuffers::FlatBufferBuilder &fbb,
- const std::unordered_map &resource_map) {
- std::vector> resource_vector;
- for (auto const &resource_pair : resource_map) {
- resource_vector.push_back(CreateResourcePair(
- fbb, fbb.CreateString(resource_pair.first), resource_pair.second));
- }
- return fbb.CreateVector(resource_vector);
-}
-
std::vector string_vec_from_flatbuf(
const flatbuffers::Vector> &flatbuf_vec) {
std::vector string_vector;
diff --git a/src/ray/common/common_protocol.h b/src/ray/common/common_protocol.h
index 1dbd6bbc2..3eff4cd8e 100644
--- a/src/ray/common/common_protocol.h
+++ b/src/ray/common/common_protocol.h
@@ -1,8 +1,7 @@
#ifndef COMMON_PROTOCOL_H
#define COMMON_PROTOCOL_H
-#include "ray/gcs/format/gcs_generated.h"
-
+#include
#include
#include "ray/common/id.h"
@@ -76,24 +75,6 @@ to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, const std::vector &ids);
/// @return The std::string version of the flatbuffer string.
std::string string_from_flatbuf(const flatbuffers::String &string);
-/// Convert a std::unordered_map to a flatbuffer vector of pairs.
-///
-/// @param fbb Reference to the flatbuffer builder.
-/// @param resource_map A mapping from resource name to resource quantity.
-/// @return A flatbuffer vector of ResourcePair objects.
-flatbuffers::Offset>>
-map_to_flatbuf(flatbuffers::FlatBufferBuilder &fbb,
- const std::unordered_map &resource_map);
-
-/// Convert a flatbuffer vector of ResourcePair objects to a std::unordered map
-/// from resource name to resource quantity.
-///
-/// @param fbb Reference to the flatbuffer builder.
-/// @param resource_vector The flatbuffer object.
-/// @return A map from resource name to resource quantity.
-const std::unordered_map map_from_flatbuf(
- const flatbuffers::Vector> &resource_vector);
-
std::vector string_vec_from_flatbuf(
const flatbuffers::Vector> &flatbuf_vec);
diff --git a/src/ray/core_worker/common.h b/src/ray/core_worker/common.h
index ee20e88f4..d15e7c8d8 100644
--- a/src/ray/core_worker/common.h
+++ b/src/ray/core_worker/common.h
@@ -5,12 +5,14 @@
#include "ray/common/buffer.h"
#include "ray/common/id.h"
-#include "ray/gcs/format/gcs_generated.h"
#include "ray/raylet/raylet_client.h"
#include "ray/raylet/task_spec.h"
namespace ray {
+using rpc::Language;
+using rpc::TaskType;
+
/// Type of this worker.
enum class WorkerType { WORKER, DRIVER };
@@ -66,8 +68,6 @@ class TaskArg {
const std::shared_ptr data_;
};
-enum class TaskType { NORMAL_TASK, ACTOR_CREATION_TASK, ACTOR_TASK };
-
/// Information of a task
struct TaskInfo {
/// The ID of task.
diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc
index fe6ed290d..661996630 100644
--- a/src/ray/core_worker/core_worker.cc
+++ b/src/ray/core_worker/core_worker.cc
@@ -3,7 +3,7 @@
namespace ray {
-CoreWorker::CoreWorker(const enum WorkerType worker_type, const ::Language language,
+CoreWorker::CoreWorker(const enum WorkerType worker_type, const enum Language language,
const std::string &store_socket, const std::string &raylet_socket,
const JobID &job_id)
: worker_type_(worker_type),
diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h
index 8a3347b2a..c61b50c78 100644
--- a/src/ray/core_worker/core_worker.h
+++ b/src/ray/core_worker/core_worker.h
@@ -7,7 +7,6 @@
#include "ray/core_worker/object_interface.h"
#include "ray/core_worker/task_execution.h"
#include "ray/core_worker/task_interface.h"
-#include "ray/gcs/format/gcs_generated.h"
#include "ray/raylet/raylet_client.h"
namespace ray {
@@ -23,7 +22,7 @@ class CoreWorker {
/// \param[in] langauge Language of this worker.
///
/// NOTE(zhijunfu): the constructor would throw if a failure happens.
- CoreWorker(const WorkerType worker_type, const ::Language language,
+ CoreWorker(const WorkerType worker_type, const Language language,
const std::string &store_socket, const std::string &raylet_socket,
const JobID &job_id = JobID::Nil());
@@ -31,7 +30,7 @@ class CoreWorker {
enum WorkerType WorkerType() const { return worker_type_; }
/// Language of this worker.
- ::Language Language() const { return language_; }
+ enum Language Language() const { return language_; }
/// Return the `CoreWorkerTaskInterface` that contains the methods related to task
/// submisson.
@@ -53,7 +52,7 @@ class CoreWorker {
const enum WorkerType worker_type_;
/// Language of this worker.
- const ::Language language_;
+ const enum Language language_;
/// raylet socket name.
const std::string raylet_socket_;
diff --git a/src/ray/core_worker/core_worker_test.cc b/src/ray/core_worker/core_worker_test.cc
index cf40e175f..3eb58595a 100644
--- a/src/ray/core_worker/core_worker_test.cc
+++ b/src/ray/core_worker/core_worker_test.cc
@@ -301,8 +301,7 @@ TEST_F(ZeroNodeTest, TestWorkerContext) {
}
TEST_F(ZeroNodeTest, TestActorHandle) {
- ActorHandle handle1(ActorID::FromRandom(), ActorHandleID::FromRandom(),
- ::Language::JAVA,
+ ActorHandle handle1(ActorID::FromRandom(), ActorHandleID::FromRandom(), Language::JAVA,
{"org.ray.exampleClass", "exampleMethod", "exampleSignature"});
auto forkedHandle1 = handle1.Fork();
diff --git a/src/ray/core_worker/task_interface.cc b/src/ray/core_worker/task_interface.cc
index e85ce28c9..da6bf25fa 100644
--- a/src/ray/core_worker/task_interface.cc
+++ b/src/ray/core_worker/task_interface.cc
@@ -8,11 +8,11 @@ namespace ray {
ActorHandle::ActorHandle(
const class ActorID &actor_id, const class ActorHandleID &actor_handle_id,
- const ::Language actor_language,
+ const Language actor_language,
const std::vector &actor_creation_task_function_descriptor) {
inner_.set_actor_id(actor_id.Data(), actor_id.Size());
inner_.set_actor_handle_id(actor_handle_id.Data(), actor_handle_id.Size());
- inner_.set_actor_language(static_cast(actor_language));
+ inner_.set_actor_language(actor_language);
*inner_.mutable_actor_creation_task_function_descriptor() = {
actor_creation_task_function_descriptor.begin(),
actor_creation_task_function_descriptor.end()};
@@ -30,9 +30,7 @@ ray::ActorHandleID ActorHandle::ActorHandleID() const {
return ActorHandleID::FromBinary(inner_.actor_handle_id());
};
-::Language ActorHandle::ActorLanguage() const {
- return (::Language)inner_.actor_language();
-};
+Language ActorHandle::ActorLanguage() const { return inner_.actor_language(); };
std::vector ActorHandle::ActorCreationTaskFunctionDescriptor() const {
return ray::rpc::VectorFromProtobuf(inner_.actor_creation_task_function_descriptor());
@@ -100,30 +98,43 @@ CoreWorkerTaskInterface::CoreWorkerTaskInterface(
new CoreWorkerRayletTaskSubmitter(raylet_client)));
}
-Status CoreWorkerTaskInterface::SubmitTask(const RayFunction &function,
- const std::vector &args,
- const TaskOptions &task_options,
- std::vector *return_ids) {
- auto &context = worker_context_;
- auto next_task_index = context.GetNextTaskIndex();
- const auto task_id = GenerateTaskId(context.GetCurrentJobID(),
- context.GetCurrentTaskID(), next_task_index);
+raylet::TaskSpecBuilder CoreWorkerTaskInterface::BuildCommonTaskSpec(
+ const RayFunction &function, const std::vector &args, uint64_t num_returns,
+ const std::unordered_map &required_resources,
+ const std::unordered_map &required_placement_resources,
+ std::vector *return_ids) {
+ raylet::TaskSpecBuilder builder;
+ auto next_task_index = worker_context_.GetNextTaskIndex();
+ // Build common task spec.
+ builder.SetCommonTaskSpec(
+ function.language, function.function_descriptor, worker_context_.GetCurrentJobID(),
+ worker_context_.GetCurrentTaskID(), next_task_index, num_returns,
+ required_resources, required_placement_resources);
+ // Set task arguments.
+ for (const auto &arg : args) {
+ if (arg.IsPassedByReference()) {
+ builder.AddByRefArg(arg.GetReference());
+ } else {
+ builder.AddByValueArg(arg.GetValue()->Data(), arg.GetValue()->Size());
+ }
+ }
- auto num_returns = task_options.num_returns;
+ // Compute return IDs.
+ const auto task_id = TaskID::FromBinary(builder.GetMessage().task_id());
(*return_ids).resize(num_returns);
for (int i = 0; i < num_returns; i++) {
(*return_ids)[i] = ObjectID::ForTaskReturn(task_id, i + 1);
}
+ return builder;
+}
- auto task_arguments = BuildTaskArguments(args);
-
- ray::raylet::TaskSpecification spec(context.GetCurrentJobID(),
- context.GetCurrentTaskID(), next_task_index,
- task_arguments, num_returns, task_options.resources,
- function.language, function.function_descriptor);
-
- std::vector execution_dependencies;
- TaskSpec task(std::move(spec), execution_dependencies);
+Status CoreWorkerTaskInterface::SubmitTask(const RayFunction &function,
+ const std::vector &args,
+ const TaskOptions &task_options,
+ std::vector *return_ids) {
+ auto builder = BuildCommonTaskSpec(function, args, task_options.num_returns,
+ task_options.resources, {}, return_ids);
+ TaskSpec task(builder.Build(), {});
return task_submitters_[static_cast(TaskTransportType::RAYLET)]->SubmitTask(task);
}
@@ -131,33 +142,20 @@ Status CoreWorkerTaskInterface::CreateActor(
const RayFunction &function, const std::vector &args,
const ActorCreationOptions &actor_creation_options,
std::unique_ptr *actor_handle) {
- auto &context = worker_context_;
- auto next_task_index = context.GetNextTaskIndex();
- const auto task_id = GenerateTaskId(context.GetCurrentJobID(),
- context.GetCurrentTaskID(), next_task_index);
-
std::vector return_ids;
- return_ids.push_back(ObjectID::ForTaskReturn(task_id, 1));
- ActorID actor_creation_id = ActorID::FromBinary(return_ids[0].Binary());
- *actor_handle = std::unique_ptr(
- new ActorHandle(actor_creation_id, ActorHandleID::Nil(), function.language,
- function.function_descriptor));
+ auto builder = BuildCommonTaskSpec(function, args, 1, actor_creation_options.resources,
+ actor_creation_options.resources, &return_ids);
+
+ const ActorID actor_id = ActorID::FromBinary(return_ids[0].Binary());
+ builder.SetActorCreationTaskSpec(actor_id, actor_creation_options.max_reconstructions,
+ {});
+
+ *actor_handle = std::unique_ptr(new ActorHandle(
+ actor_id, ActorHandleID::Nil(), function.language, function.function_descriptor));
(*actor_handle)->IncreaseTaskCounter();
(*actor_handle)->SetActorCursor(return_ids[0]);
- auto task_arguments = BuildTaskArguments(args);
-
- // Note that the caller is supposed to specify required placement resources
- // correctly via actor_creation_options.resources.
- ray::raylet::TaskSpecification spec(
- context.GetCurrentJobID(), context.GetCurrentTaskID(), next_task_index,
- actor_creation_id, ObjectID::Nil(), actor_creation_options.max_reconstructions,
- ActorID::Nil(), ActorHandleID::Nil(), 0, {}, task_arguments, 1,
- actor_creation_options.resources, actor_creation_options.resources,
- function.language, function.function_descriptor);
-
- std::vector execution_dependencies;
- TaskSpec task(std::move(spec), execution_dependencies);
+ const TaskSpec task(builder.Build(), {});
return task_submitters_[static_cast(TaskTransportType::RAYLET)]->SubmitTask(task);
}
@@ -166,65 +164,37 @@ Status CoreWorkerTaskInterface::SubmitActorTask(ActorHandle &actor_handle,
const std::vector &args,
const TaskOptions &task_options,
std::vector *return_ids) {
- auto &context = worker_context_;
- auto next_task_index = context.GetNextTaskIndex();
- const auto task_id = GenerateTaskId(context.GetCurrentJobID(),
- context.GetCurrentTaskID(), next_task_index);
+ // Add one for actor cursor object id.
+ const auto num_returns = task_options.num_returns + 1;
- // add one for actor cursor object id.
- auto num_returns = task_options.num_returns + 1;
- (*return_ids).resize(num_returns);
- for (int i = 0; i < num_returns; i++) {
- (*return_ids)[i] = ObjectID::ForTaskReturn(task_id, i + 1);
- }
-
- auto actor_creation_dummy_object_id =
- ObjectID::FromBinary(actor_handle.ActorID().Binary());
-
- auto task_arguments = BuildTaskArguments(args);
+ // Build common task spec.
+ auto builder = BuildCommonTaskSpec(function, args, num_returns, task_options.resources,
+ {}, return_ids);
std::unique_lock guard(actor_handle.mutex_);
+ // Build actor task spec.
+ const auto actor_creation_dummy_object_id =
+ ObjectID::FromBinary(actor_handle.ActorID().Binary());
+ builder.SetActorTaskSpec(actor_handle.ActorID(), actor_handle.ActorHandleID(),
+ actor_creation_dummy_object_id,
+ actor_handle.IncreaseTaskCounter(),
+ actor_handle.NewActorHandles());
- ray::raylet::TaskSpecification spec(
- context.GetCurrentJobID(), context.GetCurrentTaskID(), next_task_index,
- ActorID::Nil(), actor_creation_dummy_object_id, 0, actor_handle.ActorID(),
- actor_handle.ActorHandleID(), actor_handle.IncreaseTaskCounter(),
- actor_handle.NewActorHandles(), task_arguments, num_returns, task_options.resources,
- task_options.resources, function.language, function.function_descriptor);
-
- std::vector execution_dependencies;
- execution_dependencies.push_back(actor_handle.ActorCursor());
+ const TaskSpec task(builder.Build(), {actor_handle.ActorCursor()});
+ // Manipulate actor handle state.
auto actor_cursor = (*return_ids).back();
actor_handle.SetActorCursor(actor_cursor);
actor_handle.ClearNewActorHandles();
-
guard.unlock();
- TaskSpec task(std::move(spec), execution_dependencies);
+ // Submit task.
auto status =
task_submitters_[static_cast(TaskTransportType::RAYLET)]->SubmitTask(task);
- // remove cursor from return ids.
+ // Remove cursor from return ids.
(*return_ids).pop_back();
return status;
}
-std::vector>
-CoreWorkerTaskInterface::BuildTaskArguments(const std::vector &args) {
- std::vector> task_arguments;
- for (const auto &arg : args) {
- if (arg.IsPassedByReference()) {
- std::vector references{arg.GetReference()};
- task_arguments.push_back(
- std::make_shared(references));
- } else {
- auto data = arg.GetValue();
- task_arguments.push_back(
- std::make_shared(data->Data(), data->Size()));
- }
- }
- return task_arguments;
-}
-
} // namespace ray
diff --git a/src/ray/core_worker/task_interface.h b/src/ray/core_worker/task_interface.h
index aa5876c00..85c6d8ddb 100644
--- a/src/ray/core_worker/task_interface.h
+++ b/src/ray/core_worker/task_interface.h
@@ -9,10 +9,14 @@
#include "ray/core_worker/transport/transport.h"
#include "ray/protobuf/core_worker.pb.h"
#include "ray/raylet/task.h"
+#include "ray/raylet/task_spec.h"
+#include "ray/raylet/task_util.h"
#include "ray/rpc/util.h"
namespace ray {
+using rpc::Language;
+
class CoreWorker;
/// Options of a non-actor-creation task.
@@ -45,7 +49,7 @@ struct ActorCreationOptions {
class ActorHandle {
public:
ActorHandle(const ActorID &actor_id, const ActorHandleID &actor_handle_id,
- const ::Language actor_language,
+ const Language actor_language,
const std::vector &actor_creation_task_function_descriptor);
ActorHandle(const ActorHandle &other);
@@ -57,7 +61,7 @@ class ActorHandle {
ray::ActorHandleID ActorHandleID() const;
/// Language of the actor.
- ::Language ActorLanguage() const;
+ Language ActorLanguage() const;
// Function descriptor of actor creation task.
std::vector ActorCreationTaskFunctionDescriptor() const;
@@ -149,12 +153,21 @@ class CoreWorkerTaskInterface {
std::vector *return_ids);
private:
- /// Build the arguments for a task spec.
+ /// Build common attributes of the task spec, and compute return ids.
///
- /// \param[in] args Arguments of a task.
- /// \return Arguments as required by task spec.
- std::vector> BuildTaskArguments(
- const std::vector &args);
+ /// \param[in] function The remote function to execute.
+ /// \param[in] args Arguments of this task.
+ /// \param[in] num_returns Number of returns.
+ /// \param[in] required_resources Resources required by this task.
+ /// \param[in] required_placement_resources Resources required by placing this task on a
+ /// node.
+ /// \param[out] return_ids Return IDs.
+ /// \return A `TaskSpecBuilder`.
+ raylet::TaskSpecBuilder BuildCommonTaskSpec(
+ const RayFunction &function, const std::vector &args, uint64_t num_returns,
+ const std::unordered_map &required_resources,
+ const std::unordered_map &required_placement_resources,
+ std::vector *return_ids);
/// Reference to the parent CoreWorker's context.
WorkerContext &worker_context_;
diff --git a/src/ray/core_worker/transport/raylet_transport.cc b/src/ray/core_worker/transport/raylet_transport.cc
index b63a5ef89..df9d2fbe9 100644
--- a/src/ray/core_worker/transport/raylet_transport.cc
+++ b/src/ray/core_worker/transport/raylet_transport.cc
@@ -38,11 +38,8 @@ CoreWorkerRayletTaskReceiver::CoreWorkerRayletTaskReceiver(
void CoreWorkerRayletTaskReceiver::HandleAssignTask(
const rpc::AssignTaskRequest &request, rpc::AssignTaskReply *reply,
rpc::RequestDoneCallback done_callback) {
- const std::string &task_message = request.task_spec();
- const raylet::Task task(*flatbuffers::GetRoot(
- reinterpret_cast(task_message.data())));
+ const raylet::Task task(request.task());
const auto &spec = task.GetTaskSpecification();
-
auto status = task_handler_(spec);
done_callback(status);
}
diff --git a/src/ray/gcs/client_test.cc b/src/ray/gcs/client_test.cc
index fb4fdedb3..bf4bd5550 100644
--- a/src/ray/gcs/client_test.cc
+++ b/src/ray/gcs/client_test.cc
@@ -13,10 +13,6 @@ namespace ray {
namespace gcs {
-namespace {
-constexpr char kRandomId[] = "abcdefghijklmnopqrst";
-} // namespace
-
/* Flush redis. */
static inline void flushall_redis(void) {
redisContext *context = redisConnect("127.0.0.1", 6379);
@@ -82,23 +78,40 @@ class TestGcsWithChainAsio : public TestGcsWithAsio {
TestGcsWithChainAsio() : TestGcsWithAsio(gcs::CommandType::kChain){};
};
-void TestTableLookup(const JobID &job_id, std::shared_ptr client) {
- TaskID task_id = TaskID::FromRandom();
+/// A helper function that creates a GCS `TaskTableData` object.
+std::shared_ptr CreateTaskTableData(const TaskID &task_id,
+ uint64_t num_returns = 0) {
auto data = std::make_shared();
- data->set_task("123");
+ data->mutable_task()->mutable_task_spec()->set_task_id(task_id.Binary());
+ data->mutable_task()->mutable_task_spec()->set_num_returns(num_returns);
+ return data;
+}
+
+/// A helper function that compare wether 2 `TaskTableData` objects are equal.
+/// Note, this function only compares fields set by `CreateTaskTableData`.
+bool TaskTableDataEqual(const TaskTableData &data1, const TaskTableData &data2) {
+ const auto &spec1 = data1.task().task_spec();
+ const auto &spec2 = data2.task().task_spec();
+ return (spec1.task_id() == spec2.task_id() &&
+ spec1.num_returns() == spec2.num_returns());
+}
+
+void TestTableLookup(const JobID &job_id, std::shared_ptr client) {
+ const auto task_id = TaskID::FromRandom();
+ const auto data = CreateTaskTableData(task_id);
// Check that we added the correct task.
auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id,
const TaskTableData &d) {
ASSERT_EQ(id, task_id);
- ASSERT_EQ(data->task(), d.task());
+ ASSERT_TRUE(TaskTableDataEqual(*data, d));
};
// Check that the lookup returns the added task.
auto lookup_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id,
const TaskTableData &d) {
ASSERT_EQ(id, task_id);
- ASSERT_EQ(data->task(), d.task());
+ ASSERT_TRUE(TaskTableDataEqual(*data, d));
test->Stop();
};
@@ -386,7 +399,7 @@ void TestDeleteKeysFromTable(const JobID &job_id,
auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id,
const TaskTableData &d) {
ASSERT_EQ(id, task_id);
- ASSERT_EQ(data->task(), d.task());
+ ASSERT_TRUE(TaskTableDataEqual(*data, d));
test->IncrementNumCallbacks();
};
RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, add_callback));
@@ -501,9 +514,7 @@ void TestDeleteKeys(const JobID &job_id, std::shared_ptr cl
std::vector> task_vector;
auto AppendTaskData = [&task_vector](size_t add_count) {
for (size_t i = 0; i < add_count; ++i) {
- auto task_data = std::make_shared();
- task_data->set_task(ObjectID::FromRandom().Hex());
- task_vector.push_back(task_data);
+ task_vector.push_back(CreateTaskTableData(TaskID::FromRandom()));
}
};
AppendTaskData(1);
@@ -682,25 +693,26 @@ TEST_F(TestGcsWithAsio, TestSetSubscribeAll) {
void TestTableSubscribeId(const JobID &job_id,
std::shared_ptr client) {
+ int num_modifications = 3;
+
// Add a table entry.
TaskID task_id1 = TaskID::FromRandom();
- std::vector task_specs1 = {"abc", "def", "ghi"};
// Add a table entry at a second key.
TaskID task_id2 = TaskID::FromRandom();
- std::vector task_specs2 = {"jkl", "mno", "pqr"};
// The callback for a notification from the table. This should only be
// received for keys that we requested notifications for.
- auto notification_callback = [task_id2, task_specs2](gcs::AsyncGcsClient *client,
- const TaskID &id,
- const TaskTableData &data) {
+ auto notification_callback = [task_id2, num_modifications](gcs::AsyncGcsClient *client,
+ const TaskID &id,
+ const TaskTableData &data) {
// Check that we only get notifications for the requested key.
ASSERT_EQ(id, task_id2);
// Check that we get notifications in the same order as the writes.
- ASSERT_EQ(data.task(), task_specs2[test->NumCallbacks()]);
+ ASSERT_TRUE(
+ TaskTableDataEqual(data, *CreateTaskTableData(task_id2, test->NumCallbacks())));
test->IncrementNumCallbacks();
- if (test->NumCallbacks() == task_specs2.size()) {
+ if (test->NumCallbacks() == num_modifications) {
test->Stop();
}
};
@@ -717,21 +729,19 @@ void TestTableSubscribeId(const JobID &job_id,
// The callback for subscription success. Once we've subscribed, request
// notifications for only one of the keys, then write to both keys.
- auto subscribe_callback = [job_id, task_id1, task_id2, task_specs1,
- task_specs2](gcs::AsyncGcsClient *client) {
+ auto subscribe_callback = [job_id, task_id1, task_id2,
+ num_modifications](gcs::AsyncGcsClient *client) {
// Request notifications for one of the keys.
RAY_CHECK_OK(client->raylet_task_table().RequestNotifications(
job_id, task_id2, client->client_table().GetLocalClientId()));
// Write both keys. We should only receive notifications for the key that
// we requested them for.
- for (const auto &task_spec : task_specs1) {
- auto data = std::make_shared();
- data->set_task(task_spec);
+ for (uint64_t i = 0; i < num_modifications; i++) {
+ auto data = CreateTaskTableData(task_id1, i);
RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id1, data, nullptr));
}
- for (const auto &task_spec : task_specs2) {
- auto data = std::make_shared();
- data->set_task(task_spec);
+ for (uint64_t i = 0; i < num_modifications; i++) {
+ auto data = CreateTaskTableData(task_id2, i);
RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id2, data, nullptr));
}
};
@@ -749,7 +759,7 @@ void TestTableSubscribeId(const JobID &job_id,
ASSERT_TRUE(failure_notification_received);
// Check that we received one notification callback for each write to the
// requested key.
- ASSERT_EQ(test->NumCallbacks(), task_specs2.size());
+ ASSERT_EQ(test->NumCallbacks(), num_modifications);
}
TEST_MACRO(TestGcsWithAsio, TestTableSubscribeId);
@@ -910,10 +920,9 @@ TEST_F(TestGcsWithAsio, TestSetSubscribeId) {
void TestTableSubscribeCancel(const JobID &job_id,
std::shared_ptr client) {
// Add a table entry.
- TaskID task_id = TaskID::FromRandom();
- std::vector task_specs = {"jkl", "mno", "pqr"};
- auto data = std::make_shared();
- data->set_task(task_specs[0]);
+ const auto task_id = TaskID::FromRandom();
+ const int num_modifications = 3;
+ const auto data = CreateTaskTableData(task_id, 0);
RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, nullptr));
// The failure callback should not be called since all keys are non-empty
@@ -924,26 +933,26 @@ void TestTableSubscribeCancel(const JobID &job_id,
// The callback for a notification from the table. This should only be
// received for keys that we requested notifications for.
- auto notification_callback = [task_id, task_specs](gcs::AsyncGcsClient *client,
- const TaskID &id,
- const TaskTableData &data) {
+ auto notification_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id,
+ const TaskTableData &data) {
ASSERT_EQ(id, task_id);
// Check that we only get notifications for the first and last writes,
// since notifications are canceled in between.
if (test->NumCallbacks() == 0) {
- ASSERT_EQ(data.task(), task_specs.front());
+ ASSERT_TRUE(TaskTableDataEqual(data, *CreateTaskTableData(task_id, 0)));
} else {
- ASSERT_EQ(data.task(), task_specs.back());
+ ASSERT_TRUE(
+ TaskTableDataEqual(data, *CreateTaskTableData(task_id, num_modifications - 1)));
}
test->IncrementNumCallbacks();
- if (test->NumCallbacks() == 2) {
+ if (test->NumCallbacks() == num_modifications - 1) {
test->Stop();
}
};
// The callback for a notification from the table. This should only be
// received for keys that we requested notifications for.
- auto subscribe_callback = [job_id, task_id, task_specs](gcs::AsyncGcsClient *client) {
+ auto subscribe_callback = [job_id, task_id](gcs::AsyncGcsClient *client) {
// Request notifications, then cancel immediately. We should receive a
// notification for the current value at the key.
RAY_CHECK_OK(client->raylet_task_table().RequestNotifications(
@@ -952,10 +961,8 @@ void TestTableSubscribeCancel(const JobID &job_id,
job_id, task_id, client->client_table().GetLocalClientId()));
// Write to the key. Since we canceled notifications, we should not receive
// a notification for these writes.
- auto remaining = std::vector(++task_specs.begin(), task_specs.end());
- for (const auto &task_spec : remaining) {
- auto data = std::make_shared();
- data->set_task(task_spec);
+ for (uint64_t i = 1; i < num_modifications; i++) {
+ auto data = CreateTaskTableData(task_id, i);
RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, nullptr));
}
// Request notifications again. We should receive a notification for the
diff --git a/src/ray/gcs/format/gcs.fbs b/src/ray/gcs/format/gcs.fbs
deleted file mode 100644
index 4ab9f385a..000000000
--- a/src/ray/gcs/format/gcs.fbs
+++ /dev/null
@@ -1,105 +0,0 @@
-// TODO(hchen): Migrate data structures in this file to protobuf (`gcs.proto`).
-
-enum Language:int {
- PYTHON=0,
- JAVA=1,
- CPP=2,
-}
-
-table Arg {
- // Object ID for pass-by-reference arguments. Normally there is only one
- // object ID in this list which represents the object that is being passed.
- // However to support reducers in a MapReduce workload, we also support
- // passing multiple object IDs for each argument.
- // Note that this is a long string that concatenate all of the object IDs.
- object_ids: string;
- // Data for pass-by-value arguments.
- data: string;
-}
-
-table TaskInfo {
- // ID of the job that created this task.
- job_id: string;
- // Task ID of the task.
- task_id: string;
- // Task ID of the parent task.
- parent_task_id: string;
- // A count of the number of tasks submitted by the parent task before this one.
- parent_counter: int;
- // The ID of the actor to create if this is an actor creation task.
- actor_creation_id: string;
- // The dummy object ID of the actor creation task if this is an actor method.
- actor_creation_dummy_object_id: string;
- // The max number of times this actor should be recontructed.
- // If this number of 0 or negative, the actor won't be reconstructed on failure.
- max_actor_reconstructions: int;
- // Actor ID of the task. This is the actor that this task is executed on
- // or NIL_ACTOR_ID if the task is just a normal task.
- actor_id: string;
- // The ID of the handle that was used to submit the task. This should be
- // unique across handles with the same actor_id.
- actor_handle_id: string;
- // Number of tasks that have been submitted to this actor so far.
- actor_counter: int;
- // If this is an actor task, then this will be populated with all of the new
- // actor handles that were forked from this handle since the last task on
- // this handle was submitted.
- // Note that this is a long string that concatenate all of the new_actor_handle IDs.
- new_actor_handles: string;
- // Task arguments.
- args: [Arg];
- // Number of return objects.
- num_returns: int;
- // The required_resources vector indicates the quantities of the different
- // resources required by this task.
- required_resources: [ResourcePair];
- // The resources required for placing this task on a node. If this is empty,
- // then the placement resources are equal to the required_resources.
- required_placement_resources: [ResourcePair];
- // The language that this task belongs to.
- language: Language;
- // Function descriptor, which is a list of strings that can
- // uniquely describe a function.
- // For a Python function, it should be: [module_name, class_name, function_name]
- // For a Java function, it should be: [class_name, method_name, type_descriptor]
- function_descriptor: [string];
- // The dynamic options used in the worker command when starting the worker process for
- // an actor creation task. If the list isn't empty, the options will be used to replace
- // the placeholder strings (`RAY_WORKER_OPTION_0`, `RAY_WORKER_OPTION_1`, etc) in the
- // worker command.
- dynamic_worker_options: [string];
-}
-
-table ResourcePair {
- // The name of the resource.
- key: string;
- // The quantity of the resource.
- value: double;
-}
-
-table ProfileEvent {
- // The type of the event.
- event_type: string;
- // The start time of the event.
- start_time: double;
- // The end time of the event. If the event is a point event, then this should
- // be the same as the start time.
- end_time: double;
- // Additional data associated with the event. This data must be serialized
- // using JSON.
- extra_data: string;
-}
-
-table ProfileTableData {
- // The type of the component that generated the event, e.g., worker or
- // object_manager, or node_manager.
- component_type: string;
- // An identifier for the component that generated the event.
- component_id: string;
- // An identifier for the node that generated the event.
- node_ip_address: string;
- // This is a batch of profiling events. We batch these together for
- // performance reasons because a single task may generate many events, and
- // we don't want each event to require a GCS command.
- profile_events: [ProfileEvent];
-}
diff --git a/src/ray/gcs/format/util.h b/src/ray/gcs/format/util.h
deleted file mode 100644
index 85bc0511d..000000000
--- a/src/ray/gcs/format/util.h
+++ /dev/null
@@ -1,24 +0,0 @@
-#ifndef RAY_RAYLET_GCS_FORMAT_UTIL_H
-#define RAY_RAYLET_GCS_FORMAT_UTIL_H
-
-#include "ray/gcs/format/gcs_generated.h"
-
-namespace std {
-
-template <>
-struct hash {
- size_t operator()(const Language &language) const {
- return std::hash()(static_cast(language));
- }
-};
-
-template <>
-struct hash {
- size_t operator()(const Language &language) const {
- return std::hash()(static_cast(language));
- }
-};
-
-} // namespace std
-
-#endif // RAY_RAYLET_GCS_FORMAT_UTIL_H
diff --git a/src/ray/gcs/redis_module/ray_redis_module.cc b/src/ray/gcs/redis_module/ray_redis_module.cc
index c3a82c320..197f36179 100644
--- a/src/ray/gcs/redis_module/ray_redis_module.cc
+++ b/src/ray/gcs/redis_module/ray_redis_module.cc
@@ -4,7 +4,6 @@
#include "ray/common/common_protocol.h"
#include "ray/common/id.h"
#include "ray/common/status.h"
-#include "ray/gcs/format/gcs_generated.h"
#include "ray/protobuf/gcs.pb.h"
#include "ray/util/logging.h"
#include "redis_string.h"
diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc
index d3135288e..587ab937e 100644
--- a/src/ray/gcs/tables.cc
+++ b/src/ray/gcs/tables.cc
@@ -40,7 +40,8 @@ namespace gcs {
template
Status Log::Append(const JobID &job_id, const ID &id,
- std::shared_ptr &data, const WriteCallback &done) {
+ const std::shared_ptr &data,
+ const WriteCallback &done) {
num_appends_++;
auto callback = [this, id, data, done](const CallbackReply &reply) {
const auto status = reply.ReadAsStatus();
@@ -59,8 +60,9 @@ Status Log::Append(const JobID &job_id, const ID &id,
template
Status Log::AppendAt(const JobID &job_id, const ID &id,
- std::shared_ptr &data, const WriteCallback &done,
- const WriteCallback &failure, int log_length) {
+ const std::shared_ptr &data,
+ const WriteCallback &done, const WriteCallback &failure,
+ int log_length) {
num_appends_++;
auto callback = [this, id, data, done, failure](const CallbackReply &reply) {
const auto status = reply.ReadAsStatus();
@@ -226,7 +228,8 @@ std::string Log::DebugString() const {
template
Status Table::Add(const JobID &job_id, const ID &id,
- std::shared_ptr &data, const WriteCallback &done) {
+ const std::shared_ptr &data,
+ const WriteCallback &done) {
num_adds_++;
auto callback = [this, id, data, done](const CallbackReply &reply) {
if (done != nullptr) {
@@ -288,8 +291,8 @@ std::string Table::DebugString() const {
}
template
-Status Set::Add(const JobID &job_id, const ID &id, std::shared_ptr &data,
- const WriteCallback &done) {
+Status Set::Add(const JobID &job_id, const ID &id,
+ const std::shared_ptr &data, const WriteCallback &done) {
num_adds_++;
auto callback = [this, id, data, done](const CallbackReply &reply) {
if (done != nullptr) {
@@ -303,7 +306,8 @@ Status Set::Add(const JobID &job_id, const ID &id, std::shared_ptr
Status Set::Remove(const JobID &job_id, const ID &id,
- std::shared_ptr &data, const WriteCallback &done) {
+ const std::shared_ptr &data,
+ const WriteCallback &done) {
num_removes_++;
auto callback = [this, id, data, done](const CallbackReply &reply) {
if (done != nullptr) {
diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h
index f8f5ab3dc..364db09dc 100644
--- a/src/ray/gcs/tables.h
+++ b/src/ray/gcs/tables.h
@@ -67,10 +67,10 @@ class LogInterface {
public:
using WriteCallback =
std::function;
- virtual Status Append(const JobID &job_id, const ID &id, std::shared_ptr &data,
- const WriteCallback &done) = 0;
+ virtual Status Append(const JobID &job_id, const ID &id,
+ const std::shared_ptr &data, const WriteCallback &done) = 0;
virtual Status AppendAt(const JobID &job_id, const ID &task_id,
- std::shared_ptr &data, const WriteCallback &done,
+ const std::shared_ptr &data, const WriteCallback &done,
const WriteCallback &failure, int log_length) = 0;
virtual ~LogInterface(){};
};
@@ -126,7 +126,7 @@ class Log : public LogInterface, virtual public PubsubInterface {
/// \param done Callback that is called once the data has been written to the
/// GCS.
/// \return Status
- Status Append(const JobID &job_id, const ID &id, std::shared_ptr &data,
+ Status Append(const JobID &job_id, const ID &id, const std::shared_ptr &data,
const WriteCallback &done);
/// Append a log entry to a key if and only if the log has the given number
@@ -141,7 +141,7 @@ class Log : public LogInterface, virtual public PubsubInterface {
/// \param log_length The number of entries that the log must have for the
/// append to succeed.
/// \return Status
- Status AppendAt(const JobID &job_id, const ID &id, std::shared_ptr &data,
+ Status AppendAt(const JobID &job_id, const ID &id, const std::shared_ptr &data,
const WriteCallback &done, const WriteCallback &failure,
int log_length);
@@ -272,8 +272,8 @@ template
class TableInterface {
public:
using WriteCallback = typename Log::WriteCallback;
- virtual Status Add(const JobID &job_id, const ID &task_id, std::shared_ptr &data,
- const WriteCallback &done) = 0;
+ virtual Status Add(const JobID &job_id, const ID &task_id,
+ const std::shared_ptr &data, const WriteCallback &done) = 0;
virtual ~TableInterface(){};
};
@@ -315,7 +315,7 @@ class Table : private Log,
/// \param done Callback that is called once the data has been written to the
/// GCS.
/// \return Status
- Status Add(const JobID &job_id, const ID &id, std::shared_ptr &data,
+ Status Add(const JobID &job_id, const ID &id, const std::shared_ptr &data,
const WriteCallback &done);
/// Lookup an entry asynchronously.
@@ -378,10 +378,10 @@ template
class SetInterface {
public:
using WriteCallback = typename Log::WriteCallback;
- virtual Status Add(const JobID &job_id, const ID &id, std::shared_ptr &data,
+ virtual Status Add(const JobID &job_id, const ID &id, const std::shared_ptr &data,
const WriteCallback &done) = 0;
- virtual Status Remove(const JobID &job_id, const ID &id, std::shared_ptr &data,
- const WriteCallback &done) = 0;
+ virtual Status Remove(const JobID &job_id, const ID &id,
+ const std::shared_ptr &data, const WriteCallback &done) = 0;
virtual ~SetInterface(){};
};
@@ -420,7 +420,7 @@ class Set : private Log,
/// \param done Callback that is called once the data has been written to the
/// GCS.
/// \return Status
- Status Add(const JobID &job_id, const ID &id, std::shared_ptr &data,
+ Status Add(const JobID &job_id, const ID &id, const std::shared_ptr &data,
const WriteCallback &done);
/// Remove an entry from the set.
@@ -431,7 +431,7 @@ class Set : private Log,
/// \param done Callback that is called once the data has been written to the
/// GCS.
/// \return Status
- Status Remove(const JobID &job_id, const ID &id, std::shared_ptr &data,
+ Status Remove(const JobID &job_id, const ID &id, const std::shared_ptr &data,
const WriteCallback &done);
Status Subscribe(const JobID &job_id, const ClientID &client_id,
@@ -695,7 +695,8 @@ class TaskLeaseTable : public Table {
prefix_ = TablePrefix::TASK_LEASE;
}
- Status Add(const JobID &job_id, const TaskID &id, std::shared_ptr &data,
+ Status Add(const JobID &job_id, const TaskID &id,
+ const std::shared_ptr &data,
const WriteCallback &done) override {
RAY_RETURN_NOT_OK((Table::Add(job_id, id, data, done)));
// Mark the entry for expiration in Redis. It's okay if this command fails
diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto
new file mode 100644
index 000000000..4bee921c0
--- /dev/null
+++ b/src/ray/protobuf/common.proto
@@ -0,0 +1,123 @@
+syntax = "proto3";
+
+package ray.rpc;
+
+option java_package = "org.ray.runtime.generated";
+
+// Language of a task or worker.
+enum Language {
+ PYTHON = 0;
+ JAVA = 1;
+ CPP = 2;
+}
+
+// Type of a task.
+enum TaskType {
+ // Normal task.
+ NORMAL_TASK = 0;
+ // Actor creation task.
+ ACTOR_CREATION_TASK = 1;
+ // Actor task.
+ ACTOR_TASK = 2;
+}
+
+/// The task specification encapsulates all immutable information about the
+/// task. These fields are determined at submission time, converse to the
+/// `TaskExecutionSpec` may change at execution time.
+message TaskSpec {
+ // Type of this task.
+ TaskType type = 1;
+ // Language of this task.
+ Language language = 2;
+ // Function descriptor of this task, which is a list of strings that can
+ // uniquely describe the function to execute.
+ // For a Python function, it should be: [module_name, class_name, function_name]
+ // For a Java function, it should be: [class_name, method_name, type_descriptor]
+ repeated bytes function_descriptor = 3;
+ // ID of the job that this task belongs to.
+ bytes job_id = 4;
+ // Task ID of the task.
+ bytes task_id = 5;
+ // Task ID of the parent task.
+ bytes parent_task_id = 6;
+ // A count of the number of tasks submitted by the parent task before this one.
+ uint64 parent_counter = 7;
+ // Task arguments.
+ repeated TaskArg args = 8;
+ // Number of return objects.
+ uint64 num_returns = 9;
+ // Quantities of the different resources required by this task.
+ map required_resources = 10;
+ // The resources required for placing this task on a node. If this is empty,
+ // then the placement resources are equal to the required_resources.
+ map required_placement_resources = 11;
+ // Task specification for an actor creation task.
+ // This field is only valid when `type == ACTOR_CREATION_TASK`.
+ ActorCreationTaskSpec actor_creation_task_spec = 14;
+ // Task specification for an actor task.
+ // This field is only valid when `type == ACTOR_TASK`.
+ ActorTaskSpec actor_task_spec = 15;
+}
+
+// Argument in the task.
+message TaskArg {
+ // Object IDs for pass-by-reference arguments. Normally there is only one
+ // object ID in this list which represents the object that is being passed.
+ // However to support reducers in a MapReduce workload, we also support
+ // passing multiple object IDs for each argument.
+ repeated bytes object_ids = 1;
+ // Data for pass-by-value arguments.
+ bytes data = 2;
+}
+
+// Task spec of an actor creation task.
+message ActorCreationTaskSpec {
+ // ID of the actor that will be created by this task.
+ bytes actor_id = 2;
+ // The max number of times this actor should be recontructed.
+ // If this number of 0 or negative, the actor won't be reconstructed on failure.
+ uint64 max_actor_reconstructions = 3;
+ // The dynamic options used in the worker command when starting the worker process for
+ // an actor creation task. If the list isn't empty, the options will be used to replace
+ // the placeholder strings (`RAY_WORKER_OPTION_0`, `RAY_WORKER_OPTION_1`, etc) in the
+ // worker command.
+ repeated string dynamic_worker_options = 4;
+}
+
+// Task spec of an actor task.
+message ActorTaskSpec {
+ // Actor ID of the task. This is the actor that this task is executed on
+ // or NIL_ACTOR_ID if the task is just a normal task.
+ bytes actor_id = 2;
+ // The ID of the handle that was used to submit the task. This should be
+ // unique across handles with the same actor_id.
+ bytes actor_handle_id = 3;
+ // The dummy object ID of the actor creation task if this is an actor method.
+ bytes actor_creation_dummy_object_id = 4;
+ // Number of tasks that have been submitted to this actor so far.
+ uint64 actor_counter = 5;
+ // If this is an actor task, then this will be populated with all of the new
+ // actor handles that were forked from this handle since the last task on
+ // this handle was submitted.
+ // Note that this is a long string that concatenate all of the new_actor_handle IDs.
+ repeated bytes new_actor_handles = 6;
+}
+
+// The task execution specification encapsulates all mutable information about
+// the task. These fields may change at execution time, converse to the
+// `TaskSpec` is determined at submission time.
+message TaskExecutionSpec {
+ // A list of object IDs representing the dependencies of this task that may
+ // change at execution time.
+ repeated bytes dependencies = 1;
+ // The last time this task was received for scheduling.
+ double last_timestamp = 2;
+ // The number of times this task was spilled back by raylets.
+ uint64 num_forwards = 3;
+}
+
+// Represents a task, including task spec, and task execution spec.
+message Task {
+ TaskSpec task_spec = 1;
+ TaskExecutionSpec task_execution_spec = 2;
+}
diff --git a/src/ray/protobuf/core_worker.proto b/src/ray/protobuf/core_worker.proto
index e698b4602..0b238f456 100644
--- a/src/ray/protobuf/core_worker.proto
+++ b/src/ray/protobuf/core_worker.proto
@@ -2,6 +2,8 @@ syntax = "proto3";
package ray.rpc;
+import "src/ray/protobuf/common.proto";
+
message ActorHandle {
// ID of the actor.
bytes actor_id = 1;
@@ -10,7 +12,7 @@ message ActorHandle {
bytes actor_handle_id = 2;
// Language of the actor.
- int32 actor_language = 3;
+ Language actor_language = 3;
// Function descriptor of actor creation task.
repeated string actor_creation_task_function_descriptor = 4;
diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto
index b33e4afb8..1dfb1a39a 100644
--- a/src/ray/protobuf/gcs.proto
+++ b/src/ray/protobuf/gcs.proto
@@ -2,14 +2,9 @@ syntax = "proto3";
package ray.rpc;
-option java_package = "org.ray.runtime.generated";
+import "src/ray/protobuf/common.proto";
-// Language of a worker or task.
-enum Language {
- PYTHON = 0;
- CPP = 1;
- JAVA = 2;
-}
+option java_package = "org.ray.runtime.generated";
// These indexes are mapped to strings in ray_redis_module.cc.
enum TablePrefix {
@@ -77,12 +72,8 @@ message TaskReconstructionData {
bytes node_manager_id = 2;
}
-// TODO(hchen): Task table currently still uses flatbuffers-defined data structure
-// (`Task` in `node_manager.fbs`), because a lot of code depends on that. This should
-// be migrated to protobuf very soon.
message TaskTableData {
- // Flatbuffers-serialized content of the task, see `src/ray/raylet/task.h`.
- bytes task = 1;
+ Task task = 1;
}
message ActorTableData {
diff --git a/src/ray/protobuf/node_manager.proto b/src/ray/protobuf/node_manager.proto
index 8a82da1c7..59d2ce18f 100644
--- a/src/ray/protobuf/node_manager.proto
+++ b/src/ray/protobuf/node_manager.proto
@@ -2,16 +2,14 @@ syntax = "proto3";
package ray.rpc;
+import "src/ray/protobuf/common.proto";
+
message ForwardTaskRequest {
// The ID of the task to be forwarded.
bytes task_id = 1;
// The tasks in the uncommitted lineage of the forwarded task. This
// should include task_id.
- // TODO(hchen): Currently, `uncommitted_tasks` are represented as
- // flatbutters-serialized bytes. This is because the flatbuffers-defined Task data
- // structure is being used in many places. We should move Task and all related data
- // strucutres to protobuf.
- repeated bytes uncommitted_tasks = 2;
+ repeated Task uncommitted_tasks = 2;
}
message ForwardTaskReply {
diff --git a/src/ray/protobuf/worker.proto b/src/ray/protobuf/worker.proto
index b63782925..c2affa1ac 100644
--- a/src/ray/protobuf/worker.proto
+++ b/src/ray/protobuf/worker.proto
@@ -2,15 +2,11 @@ syntax = "proto3";
package ray.rpc;
+import "src/ray/protobuf/common.proto";
+
message AssignTaskRequest {
- // The ID of the task to be pushed.
- bytes task_id = 1;
- // The task to be pushed. This should include task_id.
- // TODO(hchen): Currently, `task_spec` are represented as
- // flatbutters-serialized bytes. This is because the flatbuffers-defined Task data
- // structure is being used in many places. We should move Task and all related data
- // structures to protobuf.
- bytes task_spec = 2;
+ // The task to be pushed.
+ Task task = 1;
}
message AssignTaskReply {
diff --git a/src/ray/raylet/format/node_manager.fbs b/src/ray/raylet/format/node_manager.fbs
index be2a6ead5..dad8c8248 100644
--- a/src/ray/raylet/format/node_manager.fbs
+++ b/src/ray/raylet/format/node_manager.fbs
@@ -1,8 +1,5 @@
// raylet protocol specification
-include "gcs.fbs";
-
-
// TODO(swang): We put the flatbuffer types in a separate namespace for now to
// avoid conflicts with legacy Ray types.
namespace ray.protocol;
@@ -137,7 +134,8 @@ table RegisterClientRequest {
// The job ID if the client is a driver, otherwise it should be NIL.
job_id: string;
// Language of this worker.
- language: Language;
+ // TODO(hchen): Use `Language` in `common.proto`.
+ language: int;
// Port that this worker is listening on.
// If port > 0, then worker will listen to this port and wait for
// raylet to push tasks, instead of invoking GetTask().
diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc
index 8e7750aae..a37254222 100644
--- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc
+++ b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc
@@ -59,8 +59,7 @@ JNIEXPORT jlong JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeInit(
* Signature: (J[BLjava/nio/ByteBuffer;II)V
*/
JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmitTask(
- JNIEnv *env, jclass, jlong client, jbyteArray cursorId, jobject taskBuff, jint pos,
- jint taskSize) {
+ JNIEnv *env, jclass, jlong client, jbyteArray cursorId, jbyteArray taskSpec) {
auto raylet_client = reinterpret_cast(client);
std::vector execution_dependencies;
@@ -69,8 +68,13 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmit
execution_dependencies.push_back(cursor_id.GetId());
}
- auto data = reinterpret_cast(env->GetDirectBufferAddress(taskBuff)) + pos;
- ray::raylet::TaskSpecification task_spec(data, taskSize);
+ jbyte *data = env->GetByteArrayElements(taskSpec, NULL);
+ jsize size = env->GetArrayLength(taskSpec);
+ ray::rpc::TaskSpec task_spec_message;
+ task_spec_message.ParseFromArray(data, size);
+ env->ReleaseByteArrayElements(taskSpec, data, JNI_ABORT);
+
+ ray::raylet::TaskSpecification task_spec(task_spec_message);
auto status = raylet_client->SubmitTask(execution_dependencies, task_spec);
ThrowRayExceptionIfNotOK(env, status);
}
@@ -90,24 +94,16 @@ JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_native
return nullptr;
}
- // We serialize the task specification using flatbuffers and then parse the
- // resulting string. This awkwardness is due to the fact that the Java
- // implementation does not use the underlying C++ TaskSpecification class.
- flatbuffers::FlatBufferBuilder fbb;
- auto message = spec->ToFlatbuffer(fbb);
- fbb.Finish(message);
- auto task_message = flatbuffers::GetRoot(fbb.GetBufferPointer());
+ // Serialize the task spec and copy to Java byte array.
+ auto task_data = spec->Serialize();
- jbyteArray result;
- result = env->NewByteArray(task_message->size());
+ jbyteArray result = env->NewByteArray(task_data.size());
if (result == nullptr) {
return nullptr; /* out of memory error thrown */
}
- // move from task spec structure to the java structure
- env->SetByteArrayRegion(
- result, 0, task_message->size(),
- reinterpret_cast(const_cast(task_message->data())));
+ env->SetByteArrayRegion(result, 0, task_data.size(),
+ reinterpret_cast(task_data.data()));
return result;
}
diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h
index 91338a12e..414f916d1 100644
--- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h
+++ b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h
@@ -20,10 +20,10 @@ JNIEXPORT jlong JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeInit(
/*
* Class: org_ray_runtime_raylet_RayletClientImpl
* Method: nativeSubmitTask
- * Signature: (J[BLjava/nio/ByteBuffer;II)V
+ * Signature: (J[B[B)V
*/
JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmitTask(
- JNIEnv *, jclass, jlong, jbyteArray, jobject, jint, jint);
+ JNIEnv *, jclass, jlong, jbyteArray, jbyteArray);
/*
* Class: org_ray_runtime_raylet_RayletClientImpl
diff --git a/src/ray/raylet/lineage_cache.cc b/src/ray/raylet/lineage_cache.cc
index 6f6d2cb69..986f735f7 100644
--- a/src/ray/raylet/lineage_cache.cc
+++ b/src/ray/raylet/lineage_cache.cc
@@ -272,9 +272,11 @@ void LineageCache::FlushTask(const TaskID &task_id) {
[this](ray::gcs::AsyncGcsClient *client, const TaskID &id,
const TaskTableData &data) { HandleEntryCommitted(id); };
auto task = lineage_.GetEntry(task_id);
- // TODO(swang): Make this better...
auto task_data = std::make_shared();
- task_data->set_task(task->TaskData().Serialize());
+ task_data->mutable_task()->mutable_task_spec()->CopyFrom(
+ task->TaskData().GetTaskSpecification().GetMessage());
+ task_data->mutable_task()->mutable_task_execution_spec()->CopyFrom(
+ task->TaskData().GetTaskExecutionSpec().GetMessage());
RAY_CHECK_OK(task_storage_.Add(JobID(task->TaskData().GetTaskSpecification().JobId()),
task_id, task_data, task_callback));
diff --git a/src/ray/raylet/lineage_cache_test.cc b/src/ray/raylet/lineage_cache_test.cc
index 1ecc13599..ad3f1d7a2 100644
--- a/src/ray/raylet/lineage_cache_test.cc
+++ b/src/ray/raylet/lineage_cache_test.cc
@@ -8,6 +8,7 @@
#include "ray/raylet/task.h"
#include "ray/raylet/task_execution_spec.h"
#include "ray/raylet/task_spec.h"
+#include "ray/raylet/task_util.h"
namespace ray {
@@ -23,7 +24,7 @@ class MockGcs : public gcs::TableInterface,
}
Status Add(const JobID &job_id, const TaskID &task_id,
- std::shared_ptr &task_data,
+ const std::shared_ptr &task_data,
const gcs::TableInterface::WriteCallback &done) {
task_table_[task_id] = task_data;
auto callback = done;
@@ -125,21 +126,16 @@ class LineageCacheTest : public ::testing::Test {
};
static inline Task ExampleTask(const std::vector &arguments,
- int64_t num_returns) {
- std::unordered_map required_resources;
- std::vector> task_arguments;
- for (auto &argument : arguments) {
- std::vector references = {argument};
- task_arguments.emplace_back(std::make_shared(references));
+ uint64_t num_returns) {
+ TaskSpecBuilder builder;
+ builder.SetCommonTaskSpec(Language::PYTHON, {"", "", ""}, JobID::Nil(),
+ TaskID::FromRandom(), 0, num_returns, {}, {});
+ for (const auto &arg : arguments) {
+ builder.AddByRefArg(arg);
}
- std::vector function_descriptor(3);
- auto spec = TaskSpecification(JobID::Nil(), TaskID::FromRandom(), 0, task_arguments,
- num_returns, required_resources, Language::PYTHON,
- function_descriptor);
- auto execution_spec = TaskExecutionSpecification(std::vector());
- execution_spec.IncrementNumForwards();
- Task task = Task(execution_spec, spec);
- return task;
+ rpc::TaskExecutionSpec execution_spec_message;
+ execution_spec_message.set_num_forwards(1);
+ return Task(builder.Build(), TaskExecutionSpecification(execution_spec_message));
}
/// Helper method to create a Lineage object with a single task.
diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc
index 8da446d7a..664b2a513 100644
--- a/src/ray/raylet/main.cc
+++ b/src/ray/raylet/main.cc
@@ -2,11 +2,14 @@
#include "ray/common/ray_config.h"
#include "ray/common/status.h"
+#include "ray/protobuf/common.pb.h"
#include "ray/raylet/raylet.h"
#include "ray/stats/stats.h"
#include "gflags/gflags.h"
+using ray::rpc::Language;
+
DEFINE_string(raylet_socket_name, "", "The socket name of raylet.");
DEFINE_string(store_socket_name, "", "The socket name of object store.");
DEFINE_int32(object_manager_port, -1, "The port of object manager.");
diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc
index ebc07717e..f9c53d99a 100644
--- a/src/ray/raylet/node_manager.cc
+++ b/src/ray/raylet/node_manager.cc
@@ -797,19 +797,10 @@ void NodeManager::ProcessClientMessage(
ProcessPushErrorRequestMessage(message_data);
} break;
case protocol::MessageType::PushProfileEventsRequest: {
- ProfileTableDataT fbs_message;
- flatbuffers::GetRoot(message_data)->UnPackTo(&fbs_message);
+ auto fbs_message = flatbuffers::GetRoot(message_data);
rpc::ProfileTableData profile_table_data;
- profile_table_data.set_component_type(fbs_message.component_type);
- profile_table_data.set_component_id(fbs_message.component_id);
- for (const auto &fbs_event : fbs_message.profile_events) {
- rpc::ProfileTableData::ProfileEvent *event =
- profile_table_data.add_profile_events();
- event->set_event_type(fbs_event->event_type);
- event->set_start_time(fbs_event->start_time);
- event->set_end_time(fbs_event->end_time);
- event->set_extra_data(fbs_event->extra_data);
- }
+ RAY_CHECK(
+ profile_table_data.ParseFromArray(fbs_message->data(), fbs_message->size()));
RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(profile_table_data));
} break;
case protocol::MessageType::FreeObjectsInObjectStoreRequest: {
@@ -845,8 +836,9 @@ void NodeManager::ProcessRegisterClientRequestMessage(
const std::shared_ptr &client, const uint8_t *message_data) {
auto message = flatbuffers::GetRoot(message_data);
client->SetClientID(from_flatbuf(*message->worker_id()));
- auto worker = std::make_shared(message->worker_pid(), message->language(),
- message->port(), client);
+ Language language = static_cast(message->language());
+ auto worker =
+ std::make_shared(message->worker_pid(), language, message->port(), client);
if (message->is_worker()) {
// Register the new worker.
worker_pool_.RegisterWorker(std::move(worker));
@@ -1050,14 +1042,18 @@ void NodeManager::ProcessDisconnectClientMessage(
void NodeManager::ProcessSubmitTaskMessage(const uint8_t *message_data) {
// Read the task submitted by the client.
- auto message = flatbuffers::GetRoot(message_data);
- TaskExecutionSpecification task_execution_spec(
- from_flatbuf(*message->execution_dependencies()));
- TaskSpecification task_spec(*message->task_spec());
- Task task(task_execution_spec, task_spec);
+ auto fbs_message = flatbuffers::GetRoot(message_data);
+ rpc::Task task_message;
+ RAY_CHECK(task_message.mutable_task_spec()->ParseFromArray(
+ fbs_message->task_spec()->data(), fbs_message->task_spec()->size()));
+ for (const auto &dependency :
+ string_vec_from_flatbuf(*fbs_message->execution_dependencies())) {
+ task_message.mutable_task_execution_spec()->add_dependencies(dependency);
+ }
+
// Submit the task to the raylet. Since the task was submitted
// locally, there is no uncommitted lineage.
- SubmitTask(task, Lineage());
+ SubmitTask(Task(task_message), Lineage());
}
void NodeManager::ProcessFetchOrReconstructMessage(
@@ -1224,10 +1220,8 @@ void NodeManager::HandleForwardTask(const rpc::ForwardTaskRequest &request,
TaskID task_id = TaskID::FromBinary(request.task_id());
Lineage uncommitted_lineage;
for (int i = 0; i < request.uncommitted_tasks_size(); i++) {
- const std::string &task_message = request.uncommitted_tasks(i);
- const Task task(*flatbuffers::GetRoot(
- reinterpret_cast(task_message.data())));
- RAY_CHECK(uncommitted_lineage.SetEntry(std::move(task), GcsStatus::UNCOMMITTED));
+ Task task(request.uncommitted_tasks(i));
+ RAY_CHECK(uncommitted_lineage.SetEntry(task, GcsStatus::UNCOMMITTED));
}
const Task &task = uncommitted_lineage.GetEntry(task_id)->TaskData();
RAY_LOG(DEBUG) << "Received forwarded task " << task.GetTaskSpecification().TaskId()
@@ -1769,7 +1763,7 @@ bool NodeManager::AssignTask(const Task &task) {
worker->GetTaskResourceIds().Plus(worker->GetLifetimeResourceIds());
auto resource_id_set_flatbuf = resource_id_set.ToFlatbuf(fbb);
- auto message = protocol::CreateGetTaskReply(fbb, spec.ToFlatbuffer(fbb),
+ auto message = protocol::CreateGetTaskReply(fbb, fbb.CreateString(spec.Serialize()),
fbb.CreateVector(resource_id_set_flatbuf));
fbb.Finish(message);
const auto &task_id = spec.TaskId();
@@ -2025,9 +2019,7 @@ void NodeManager::HandleTaskReconstruction(const TaskID &task_id) {
const TaskTableData &task_data) {
// The task was in the GCS task table. Use the stored task spec to
// re-execute the task.
- auto message = flatbuffers::GetRoot(task_data.task().data());
- const Task task(*message);
- ResubmitTask(task);
+ ResubmitTask(Task(task_data.task()));
},
/*failure_callback=*/
[this](ray::gcs::AsyncGcsClient *client, const TaskID &task_id) {
@@ -2238,8 +2230,12 @@ void NodeManager::ForwardTask(
// Prepare the request message.
rpc::ForwardTaskRequest request;
request.set_task_id(task_id.Binary());
- for (auto &entry : uncommitted_lineage.GetEntries()) {
- request.add_uncommitted_tasks(entry.second.TaskData().Serialize());
+ for (auto &task_entry : uncommitted_lineage.GetEntries()) {
+ auto task = request.add_uncommitted_tasks();
+ task->mutable_task_spec()->CopyFrom(
+ task_entry.second.TaskData().GetTaskSpecification().GetMessage());
+ task->mutable_task_execution_spec()->CopyFrom(
+ task_entry.second.TaskData().GetTaskExecutionSpec().GetMessage());
}
// Move the FORWARDING task to the SWAP queue so that we remember that we
diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h
index d485fff8a..582f6bd77 100644
--- a/src/ray/raylet/node_manager.h
+++ b/src/ray/raylet/node_manager.h
@@ -10,6 +10,7 @@
#include "ray/raylet/task.h"
#include "ray/object_manager/object_manager.h"
#include "ray/common/client_connection.h"
+#include "ray/protobuf/common.pb.h"
#include "ray/raylet/actor_registration.h"
#include "ray/raylet/lineage_cache.h"
#include "ray/raylet/scheduling_policy.h"
@@ -31,6 +32,7 @@ using rpc::ErrorType;
using rpc::HeartbeatBatchTableData;
using rpc::HeartbeatTableData;
using rpc::JobTableData;
+using rpc::Language;
struct NodeManagerConfig {
/// The node's resource configuration.
@@ -48,7 +50,7 @@ struct NodeManagerConfig {
/// worker pool.
int maximum_startup_concurrency;
/// The commands used to start the worker process, grouped by language.
- std::unordered_map> worker_commands;
+ WorkerCommandMap worker_commands;
/// The time between heartbeats in milliseconds.
uint64_t heartbeat_period_ms;
/// The time between debug dumps in milliseconds, or -1 to disable.
diff --git a/src/ray/raylet/raylet_client.cc b/src/ray/raylet/raylet_client.cc
index da1b09e27..e6bb6b740 100644
--- a/src/ray/raylet/raylet_client.cc
+++ b/src/ray/raylet/raylet_client.cc
@@ -228,7 +228,7 @@ ray::Status RayletClient::SubmitTask(const std::vector &execution_depe
flatbuffers::FlatBufferBuilder fbb;
auto execution_dependencies_message = to_flatbuf(fbb, execution_dependencies);
auto message = ray::protocol::CreateSubmitTaskRequest(
- fbb, execution_dependencies_message, task_spec.ToFlatbuffer(fbb));
+ fbb, execution_dependencies_message, fbb.CreateString(task_spec.Serialize()));
fbb.Finish(message);
return conn_->WriteMessage(MessageType::SubmitTask, &fbb);
}
@@ -335,9 +335,9 @@ ray::Status RayletClient::PushError(const ray::JobID &job_id, const std::string
return conn_->WriteMessage(MessageType::PushErrorRequest, &fbb);
}
-ray::Status RayletClient::PushProfileEvents(const ProfileTableDataT &profile_events) {
+ray::Status RayletClient::PushProfileEvents(const ProfileTableData &profile_events) {
flatbuffers::FlatBufferBuilder fbb;
- auto message = CreateProfileTableData(fbb, &profile_events);
+ auto message = fbb.CreateString(profile_events.SerializeAsString());
fbb.Finish(message);
auto status = conn_->WriteMessage(MessageType::PushProfileEventsRequest, &fbb);
diff --git a/src/ray/raylet/raylet_client.h b/src/ray/raylet/raylet_client.h
index 5f887cc93..3620c9d9c 100644
--- a/src/ray/raylet/raylet_client.h
+++ b/src/ray/raylet/raylet_client.h
@@ -1,6 +1,7 @@
#ifndef RAYLET_CLIENT_H
#define RAYLET_CLIENT_H
+#include
#include
#include
#include
@@ -17,6 +18,9 @@ using ray::ObjectID;
using ray::TaskID;
using ray::UniqueID;
+using ray::rpc::Language;
+using ray::rpc::ProfileTableData;
+
using MessageType = ray::protocol::MessageType;
using ResourceMappingType =
std::unordered_map>>;
@@ -138,7 +142,7 @@ class RayletClient {
///
/// \param profile_events A batch of profiling event information.
/// \return ray::Status.
- ray::Status PushProfileEvents(const ProfileTableDataT &profile_events);
+ ray::Status PushProfileEvents(const ProfileTableData &profile_events);
/// Free a list of objects from object stores.
///
diff --git a/src/ray/raylet/reconstruction_policy_test.cc b/src/ray/raylet/reconstruction_policy_test.cc
index 033da0561..94155b442 100644
--- a/src/ray/raylet/reconstruction_policy_test.cc
+++ b/src/ray/raylet/reconstruction_policy_test.cc
@@ -85,7 +85,7 @@ class MockGcs : public gcs::PubsubInterface,
}
void Add(const JobID &job_id, const TaskID &task_id,
- std::shared_ptr &task_lease_data) {
+ const std::shared_ptr &task_lease_data) {
task_lease_table_[task_id] = task_lease_data;
if (subscribed_tasks_.count(task_id) == 1) {
notification_callback_(nullptr, task_id, *task_lease_data);
@@ -112,7 +112,7 @@ class MockGcs : public gcs::PubsubInterface,
Status AppendAt(
const JobID &job_id, const TaskID &task_id,
- std::shared_ptr &task_data,
+ const std::shared_ptr &task_data,
const ray::gcs::LogInterface::WriteCallback
&success_callback,
const ray::gcs::LogInterface::WriteCallback
@@ -134,7 +134,7 @@ class MockGcs : public gcs::PubsubInterface,
MOCK_METHOD4(
Append,
ray::Status(
- const JobID &, const TaskID &, std::shared_ptr &,
+ const JobID &, const TaskID &, const std::shared_ptr &,
const ray::gcs::LogInterface::WriteCallback &));
private:
diff --git a/src/ray/raylet/task.cc b/src/ray/raylet/task.cc
index 9d8036411..59f45d97e 100644
--- a/src/ray/raylet/task.cc
+++ b/src/ray/raylet/task.cc
@@ -4,24 +4,12 @@ namespace ray {
namespace raylet {
-flatbuffers::Offset Task::ToFlatbuffer(
- flatbuffers::FlatBufferBuilder &fbb) const {
- auto task = CreateTask(fbb, task_spec_.ToFlatbuffer(fbb),
- task_execution_spec_.ToFlatbuffer(fbb));
- return task;
-}
-
const TaskExecutionSpecification &Task::GetTaskExecutionSpec() const {
return task_execution_spec_;
}
const TaskSpecification &Task::GetTaskSpecification() const { return task_spec_; }
-void Task::SetExecutionDependencies(const std::vector &dependencies) {
- task_execution_spec_.SetExecutionDependencies(dependencies);
- ComputeDependencies();
-}
-
void Task::IncrementNumForwards() { task_execution_spec_.IncrementNumForwards(); }
const std::vector &Task::GetDependencies() const { return dependencies_; }
@@ -42,24 +30,10 @@ void Task::ComputeDependencies() {
}
void Task::CopyTaskExecutionSpec(const Task &task) {
- task_execution_spec_ = task.GetTaskExecutionSpec();
+ task_execution_spec_ = task.task_execution_spec_;
ComputeDependencies();
}
-const std::string Task::Serialize() const {
- flatbuffers::FlatBufferBuilder fbb;
- fbb.Finish(ToFlatbuffer(fbb));
- return std::string(fbb.GetBufferPointer(), fbb.GetBufferPointer() + fbb.GetSize());
-}
-
-std::string SerializeTaskAsString(const std::vector *dependencies,
- const TaskSpecification *task_spec) {
- std::vector execution_dependencies(*dependencies);
- TaskExecutionSpecification execution_spec(std::move(execution_dependencies));
- Task task(execution_spec, *task_spec);
- return task.Serialize();
-}
-
} // namespace raylet
} // namespace ray
diff --git a/src/ray/raylet/task.h b/src/ray/raylet/task.h
index 10cdfe511..e1a8a878b 100644
--- a/src/ray/raylet/task.h
+++ b/src/ray/raylet/task.h
@@ -3,9 +3,11 @@
#include
+#include "ray/protobuf/common.pb.h"
#include "ray/raylet/format/node_manager_generated.h"
#include "ray/raylet/task_execution_spec.h"
#include "ray/raylet/task_spec.h"
+#include "ray/rpc/message_wrapper.h"
namespace ray {
@@ -19,41 +21,22 @@ namespace raylet {
/// time.
class Task {
public:
- /// Create a task.
+ /// Construct a `Task` object from a protobuf message.
///
- /// \param execution_spec The execution specification for the task. These are
- /// the mutable fields in the task specification that may change at task
- /// execution time.
- /// \param task_spec The immutable specification for the task. These fields
- /// are determined at task submission time.
- Task(const TaskExecutionSpecification &execution_spec,
- const TaskSpecification &task_spec)
- : task_execution_spec_(execution_spec), task_spec_(task_spec) {
+ /// \param message The protobuf message.
+ explicit Task(const rpc::Task &message)
+ : task_spec_(message.task_spec()),
+ task_execution_spec_(message.task_execution_spec()) {
ComputeDependencies();
}
- /// Create a task from a serialized flatbuffer.
- ///
- /// \param task_flatbuffer The serialized task.
- Task(const protocol::Task &task_flatbuffer)
- : Task(*task_flatbuffer.task_execution_spec(),
- *task_flatbuffer.task_specification()) {}
-
- /// Create a task from a flatbuffer object.
- ///
- /// \param task_data The task flatbuffer object.
- Task(const protocol::TaskT &task_data)
- : Task(*task_data.task_execution_spec, task_data.task_specification) {}
-
- /// Destroy the task.
- virtual ~Task() {}
-
- /// Serialize a task to a flatbuffer.
- ///
- /// \param fbb The flatbuffer builder.
- /// \return An offset to the serialized task.
- flatbuffers::Offset ToFlatbuffer(
- flatbuffers::FlatBufferBuilder &fbb) const;
+ /// Construct a `Task` object from a `TaskSpecification` and a
+ /// `TaskExecutionSpecification`.
+ Task(TaskSpecification task_spec, TaskExecutionSpecification task_execution_spec)
+ : task_spec_(std::move(task_spec)),
+ task_execution_spec_(std::move(task_execution_spec)) {
+ ComputeDependencies();
+ }
/// Get the mutable specification for the task. This specification may be
/// updated at runtime.
@@ -66,11 +49,6 @@ class Task {
/// \return The immutable specification for the task.
const TaskSpecification &GetTaskSpecification() const;
- /// Set the task's execution dependencies.
- ///
- /// \param dependencies The value to set the execution dependencies to.
- void SetExecutionDependencies(const std::vector &dependencies);
-
/// Increment the number of times this task has been forwarded.
void IncrementNumForwards();
@@ -84,28 +62,22 @@ class Task {
/// \param task Task structure with updated dynamic information.
void CopyTaskExecutionSpec(const Task &task);
- /// Serialize this task as a string.
- const std::string Serialize() const;
-
private:
void ComputeDependencies();
- /// Task execution specification, consisting of all dynamic/mutable
- /// information about this task determined at execution time..
- TaskExecutionSpecification task_execution_spec_;
/// Task specification object, consisting of immutable information about this
/// task determined at submission time. Includes resource demand, object
/// dependencies, etc.
TaskSpecification task_spec_;
+ /// Task execution specification, consisting of all dynamic/mutable
+ /// information about this task determined at execution time..
+ TaskExecutionSpecification task_execution_spec_;
/// A cached copy of the task's object dependencies, including arguments from
/// the TaskSpecification and execution dependencies from the
/// TaskExecutionSpecification.
std::vector dependencies_;
};
-std::string SerializeTaskAsString(const std::vector *dependencies,
- const TaskSpecification *task_spec);
-
} // namespace raylet
} // namespace ray
diff --git a/src/ray/raylet/task_dependency_manager_test.cc b/src/ray/raylet/task_dependency_manager_test.cc
index f16c17b6b..1a469c596 100644
--- a/src/ray/raylet/task_dependency_manager_test.cc
+++ b/src/ray/raylet/task_dependency_manager_test.cc
@@ -6,6 +6,7 @@
#include
#include "ray/raylet/task_dependency_manager.h"
+#include "ray/raylet/task_util.h"
namespace ray {
@@ -30,7 +31,7 @@ class MockGcs : public gcs::TableInterface {
MOCK_METHOD4(
Add,
ray::Status(const JobID &job_id, const TaskID &task_id,
- std::shared_ptr &task_data,
+ const std::shared_ptr &task_data,
const gcs::TableInterface::WriteCallback &done));
};
@@ -67,21 +68,16 @@ class TaskDependencyManagerTest : public ::testing::Test {
};
static inline Task ExampleTask(const std::vector &arguments,
- int64_t num_returns) {
- std::unordered_map required_resources;
- std::vector> task_arguments;
- for (auto &argument : arguments) {
- std::vector references = {argument};
- task_arguments.emplace_back(std::make_shared(references));
+ uint64_t num_returns) {
+ TaskSpecBuilder builder;
+ builder.SetCommonTaskSpec(Language::PYTHON, {"", "", ""}, JobID::Nil(),
+ TaskID::FromRandom(), 0, num_returns, {}, {});
+ for (const auto &arg : arguments) {
+ builder.AddByRefArg(arg);
}
- std::vector function_descriptor(3);
- auto spec = TaskSpecification(JobID::Nil(), TaskID::FromRandom(), 0, task_arguments,
- num_returns, required_resources, Language::PYTHON,
- function_descriptor);
- auto execution_spec = TaskExecutionSpecification(std::vector());
- execution_spec.IncrementNumForwards();
- Task task = Task(execution_spec, spec);
- return task;
+ rpc::TaskExecutionSpec execution_spec_message;
+ execution_spec_message.set_num_forwards(1);
+ return Task(builder.Build(), TaskExecutionSpecification(execution_spec_message));
}
std::vector MakeTaskChain(int chain_size,
diff --git a/src/ray/raylet/task_execution_spec.cc b/src/ray/raylet/task_execution_spec.cc
index dc7bf30b8..ed7e60e2c 100644
--- a/src/ray/raylet/task_execution_spec.cc
+++ b/src/ray/raylet/task_execution_spec.cc
@@ -4,54 +4,16 @@ namespace ray {
namespace raylet {
-TaskExecutionSpecification::TaskExecutionSpecification(
- const std::vector &&dependencies) {
- SetExecutionDependencies(dependencies);
+using rpc::IdVectorFromProtobuf;
+
+const std::vector TaskExecutionSpecification::ExecutionDependencies() const {
+ return IdVectorFromProtobuf(message_.dependencies());
}
-TaskExecutionSpecification::TaskExecutionSpecification(
- const std::vector &&dependencies, int num_forwards) {
- // TaskExecutionSpecification(std::move(dependencies));
- SetExecutionDependencies(dependencies);
- execution_spec_.num_forwards = num_forwards;
-}
-
-flatbuffers::Offset
-TaskExecutionSpecification::ToFlatbuffer(flatbuffers::FlatBufferBuilder &fbb) const {
- fbb.ForceDefaults(true);
- return protocol::TaskExecutionSpecification::Pack(fbb, &execution_spec_);
-}
-
-std::vector TaskExecutionSpecification::ExecutionDependencies() const {
- std::vector dependencies;
- for (const auto &dependency : execution_spec_.dependencies) {
- dependencies.push_back(ObjectID::FromBinary(dependency));
- }
- return dependencies;
-}
-
-void TaskExecutionSpecification::SetExecutionDependencies(
- const std::vector &dependencies) {
- execution_spec_.dependencies.clear();
- for (const auto &dependency : dependencies) {
- execution_spec_.dependencies.push_back(dependency.Binary());
- }
-}
-
-int TaskExecutionSpecification::NumForwards() const {
- return execution_spec_.num_forwards;
-}
+size_t TaskExecutionSpecification::NumForwards() const { return message_.num_forwards(); }
void TaskExecutionSpecification::IncrementNumForwards() {
- execution_spec_.num_forwards += 1;
-}
-
-int64_t TaskExecutionSpecification::LastTimestamp() const {
- return execution_spec_.last_timestamp;
-}
-
-void TaskExecutionSpecification::SetLastTimestamp(int64_t new_timestamp) {
- execution_spec_.last_timestamp = new_timestamp;
+ message_.set_num_forwards(message_.num_forwards() + 1);
}
} // namespace raylet
diff --git a/src/ray/raylet/task_execution_spec.h b/src/ray/raylet/task_execution_spec.h
index 6fc3b833a..bdc16c8b4 100644
--- a/src/ray/raylet/task_execution_spec.h
+++ b/src/ray/raylet/task_execution_spec.h
@@ -4,84 +4,45 @@
#include
#include "ray/common/id.h"
-#include "ray/raylet/format/node_manager_generated.h"
+#include "ray/protobuf/common.pb.h"
+#include "ray/rpc/message_wrapper.h"
+#include "ray/rpc/util.h"
namespace ray {
namespace raylet {
-/// \class TaskExecutionSpecification
-///
-/// The task execution specification encapsulates all mutable information about
-/// the task. These fields may change at execution time, converse to the
-/// TaskSpecification that is determined at submission time.
-class TaskExecutionSpecification {
+using rpc::MessageWrapper;
+
+/// Wrapper class of protobuf `TaskExecutionSpec`, see `common.proto` for details.
+class TaskExecutionSpecification : public MessageWrapper {
public:
- TaskExecutionSpecification(const protocol::TaskExecutionSpecificationT &execution_spec)
- : execution_spec_(execution_spec) {}
-
- /// Create a task execution specification.
+ /// Construct from a protobuf message object.
+ /// The input message will be **copied** into this object.
///
- /// \param dependencies The task's dependencies, determined at execution
- /// time.
- TaskExecutionSpecification(const std::vector &&dependencies);
+ /// \param message The protobuf message.
+ explicit TaskExecutionSpecification(rpc::TaskExecutionSpec message)
+ : MessageWrapper(std::move(message)) {}
- /// Create a task execution specification.
+ /// Construct from protobuf-serialized binary.
///
- /// \param dependencies The task's dependencies, determined at execution
- /// time.
- /// \param num_forwards The number of times this task has been forwarded by a
- /// node manager.
- TaskExecutionSpecification(const std::vector &&dependencies,
- int num_forwards);
-
- /// Create a task execution specification from a serialized flatbuffer.
- ///
- /// \param spec_flatbuffer The serialized specification.
- TaskExecutionSpecification(
- const protocol::TaskExecutionSpecification &spec_flatbuffer) {
- spec_flatbuffer.UnPackTo(&execution_spec_);
- }
-
- /// Serialize a task execution specification to a flatbuffer.
- ///
- /// \param fbb The flatbuffer builder.
- /// \return An offset to the serialized task execution specification.
- flatbuffers::Offset ToFlatbuffer(
- flatbuffers::FlatBufferBuilder &fbb) const;
+ /// \param serialized_binary Protobuf-serialized binary.
+ explicit TaskExecutionSpecification(const std::string &serialized_binary)
+ : MessageWrapper(serialized_binary) {}
/// Get the task's execution dependencies.
///
/// \return A vector of object IDs representing this task's execution
/// dependencies.
- std::vector ExecutionDependencies() const;
-
- /// Set the task's execution dependencies.
- ///
- /// \param dependencies The value to set the execution dependencies to.
- void SetExecutionDependencies(const std::vector &dependencies);
+ const std::vector