diff --git a/.gitignore b/.gitignore index b26c1d6b4..9b6947880 100644 --- a/.gitignore +++ b/.gitignore @@ -4,8 +4,6 @@ /python/ray/pyarrow_files/ /python/build /python/dist -/python/flatbuffers-1.7.1/ -/flatbuffers-1.7.1/ /thirdparty/pkg/ # Files generated by flatc should be ignored diff --git a/BUILD.bazel b/BUILD.bazel index f2d458e31..c9ecaad87 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -11,10 +11,27 @@ COPTS = ["-DRAY_USE_GLOG"] # === Begin of protobuf definitions === +proto_library( + name = "common_proto", + srcs = ["src/ray/protobuf/common.proto"], + visibility = ["//java:__subpackages__"], +) + +cc_proto_library( + name = "common_cc_proto", + deps = [":common_proto"], +) + +python_proto_compile( + name = "common_py_proto", + deps = [":common_proto"], +) + proto_library( name = "gcs_proto", srcs = ["src/ray/protobuf/gcs.proto"], visibility = ["//java:__subpackages__"], + deps = [":common_proto"], ) cc_proto_library( @@ -30,6 +47,7 @@ python_proto_compile( proto_library( name = "node_manager_proto", srcs = ["src/ray/protobuf/node_manager.proto"], + deps = [":common_proto"], ) cc_proto_library( @@ -50,6 +68,7 @@ cc_proto_library( proto_library( name = "worker_proto", srcs = ["src/ray/protobuf/worker.proto"], + deps = [":common_proto"], ) cc_proto_library( @@ -60,6 +79,7 @@ cc_proto_library( proto_library( name = "core_worker_proto", srcs = ["src/ray/protobuf/core_worker.proto"], + deps = [":common_proto"], ) cc_proto_library( @@ -242,8 +262,8 @@ cc_library( copts = COPTS, linkopts = ["-pthread"], deps = [ + ":common_cc_proto", ":gcs", - ":gcs_fbs", ":node_manager_fbs", ":node_manager_rpc", ":object_manager", @@ -491,11 +511,7 @@ cc_library( ], ), copts = COPTS, - includes = [ - "src/ray/gcs/format", - ], deps = [ - ":gcs_fbs", ":ray_util", "@boost//:asio", "@plasma//:plasma_client", @@ -550,15 +566,10 @@ cc_library( ), hdrs = glob([ "src/ray/gcs/*.h", - "src/ray/gcs/format/*.h", ]), copts = COPTS, - includes = [ - "src/ray/gcs/format", - ], deps = [ 
":gcs_cc_proto", - ":gcs_fbs", ":hiredis", ":node_manager_fbs", ":node_manager_rpc", @@ -598,13 +609,6 @@ FLATC_ARGS = [ "--scoped-enums", ] -flatbuffer_cc_library( - name = "gcs_fbs", - srcs = ["src/ray/gcs/format/gcs.fbs"], - flatc_args = FLATC_ARGS, - out_prefix = "src/ray/gcs/format/", -) - flatbuffer_cc_library( name = "common_fbs", srcs = ["@plasma//:cpp/src/plasma/format/common.fbs"], @@ -616,8 +620,6 @@ flatbuffer_cc_library( name = "node_manager_fbs", srcs = ["src/ray/raylet/format/node_manager.fbs"], flatc_args = FLATC_ARGS, - include_paths = ["src/ray/gcs/format"], - includes = [":gcs_fbs_includes"], out_prefix = "src/ray/raylet/format/", ) @@ -686,43 +688,6 @@ filegroup( visibility = ["//java:__subpackages__"], ) -filegroup( - name = "gcs_fbs_file", - srcs = ["src/ray/gcs/format/gcs.fbs"], - visibility = ["//java:__subpackages__"], -) - -flatbuffer_py_library( - name = "python_node_manager_fbs", - srcs = [ - "src/ray/raylet/format/node_manager.fbs", - ], - outs = [ - "ray/protocol/DisconnectClient.py", - "ray/protocol/FetchOrReconstruct.py", - "ray/protocol/ForwardTaskRequest.py", - "ray/protocol/FreeObjectsRequest.py", - "ray/protocol/GetTaskReply.py", - "ray/protocol/MessageType.py", - "ray/protocol/NotifyUnblocked.py", - "ray/protocol/PushErrorRequest.py", - "ray/protocol/RegisterClientReply.py", - "ray/protocol/RegisterClientRequest.py", - "ray/protocol/RegisterNodeManagerRequest.py", - "ray/protocol/ResourceIdSetInfo.py", - "ray/protocol/SubmitTaskRequest.py", - "ray/protocol/Task.py", - "ray/protocol/TaskExecutionSpecification.py", - "ray/protocol/WaitReply.py", - "ray/protocol/WaitRequest.py", - ], - include_paths = [ - "src/ray/gcs/format/", - ], - includes = ["src/ray/gcs/format/gcs.fbs"], - out_prefix = "python/ray/core/generated/", -) - filegroup( name = "python_sources", srcs = glob([ @@ -781,13 +746,20 @@ cc_binary( ], ) +filegroup( + name = "all_py_proto", + srcs = [ + "common_py_proto", + "gcs_py_proto", + ], +) + genrule( name = 
"ray_pkg", srcs = [ "python/ray/_raylet.so", "//:python_sources", - "//:gcs_py_proto", - "//:python_node_manager_fbs", + "//:all_py_proto", "//:redis-server", "//:redis-cli", "//:libray_redis_module.so", @@ -809,12 +781,12 @@ genrule( cp -f $(location @plasma//:plasma_store_server) $$WORK_DIR/python/ray/core/src/plasma/ && cp -f $(location //:raylet) $$WORK_DIR/python/ray/core/src/ray/raylet/ && mkdir -p $$WORK_DIR/python/ray/core/generated/ray/protocol/ && - for f in $(locations //:python_node_manager_fbs); do - cp -f $$f $$WORK_DIR/python/ray/core/generated/ray/protocol/; - done && - for f in $(locations //:gcs_py_proto); do + for f in $(locations //:all_py_proto); do cp -f $$f $$WORK_DIR/python/ray/core/generated/; done && + # NOTE(hchen): Protobuf doesn't allow specifying Python package name. So we use this `sed` + # command to change the import path in the generated file. + sed -i -E 's/from src.ray.protobuf/from ./' $$WORK_DIR/python/ray/core/generated/gcs_pb2.py && echo $$WORK_DIR > $@ """, local = 1, diff --git a/bazel/ray.bzl b/bazel/ray.bzl index 750b90a21..8f41e05aa 100644 --- a/bazel/ray.bzl +++ b/bazel/ray.bzl @@ -13,23 +13,22 @@ def flatbuffer_py_library(name, srcs, outs, out_prefix, includes = [], include_p includes = includes, ) -def flatbuffer_java_library(name, srcs, outs, out_prefix, includes = [], include_paths = []): - flatbuffer_library_public( - name = name, - srcs = srcs, - outs = outs, - language_flag = "-j", - out_prefix = out_prefix, - include_paths = include_paths, - includes = includes, - ) - -def define_java_module(name, additional_srcs = [], additional_resources = [], define_test_lib = False, test_deps = [], **kwargs): +def define_java_module( + name, + additional_srcs = [], + exclude_srcs = [], + additional_resources = [], + define_test_lib = False, + test_deps = [], + **kwargs): lib_name = "org_ray_ray_" + name pom_file_targets = [lib_name] native.java_library( name = lib_name, - srcs = additional_srcs + native.glob([name + 
"/src/main/java/**/*.java"]), + srcs = additional_srcs + native.glob( + [name + "/src/main/java/**/*.java"], + exclude = exclude_srcs, + ), resources = native.glob([name + "/src/main/resources/**"]) + additional_resources, **kwargs ) diff --git a/java/BUILD.bazel b/java/BUILD.bazel index 4960434af..b9c43424f 100644 --- a/java/BUILD.bazel +++ b/java/BUILD.bazel @@ -1,4 +1,4 @@ -load("//bazel:ray.bzl", "flatbuffer_java_library", "define_java_module") +load("//bazel:ray.bzl", "define_java_module") load("@build_stack_rules_proto//java:java_proto_compile.bzl", "java_proto_compile") exports_files([ @@ -50,8 +50,10 @@ define_java_module( define_java_module( name = "runtime", additional_srcs = [ - ":generate_java_gcs_fbs", - ":gcs_java_proto", + ":all_java_proto", + ], + exclude_srcs = [ + "runtime/src/main/java/org/ray/runtime/generated/*.java", ], additional_resources = [ ":java_native_deps", @@ -68,7 +70,6 @@ define_java_module( deps = [ ":org_ray_ray_api", "@plasma//:org_apache_arrow_arrow_plasma", - "@maven//:com_github_davidmoten_flatbuffers_java", "@maven//:com_google_guava_guava", "@maven//:com_google_protobuf_protobuf_java", "@maven//:com_typesafe_config", @@ -151,39 +152,22 @@ java_binary( ], ) +java_proto_compile( + name = "common_java_proto", + deps = ["@//:common_proto"], +) + java_proto_compile( name = "gcs_java_proto", deps = ["@//:gcs_proto"], ) -flatbuffers_generated_files = [ - "Arg.java", - "Language.java", - "TaskInfo.java", - "ResourcePair.java", -] - -flatbuffer_java_library( - name = "java_gcs_fbs", - srcs = ["//:gcs_fbs_file"], - outs = flatbuffers_generated_files, - out_prefix = "", -) - -genrule( - name = "generate_java_gcs_fbs", - srcs = [":java_gcs_fbs"], - outs = [ - "runtime/src/main/java/org/ray/runtime/generated/" + file for file in flatbuffers_generated_files +filegroup( + name = "all_java_proto", + srcs = [ + ":common_java_proto", + ":gcs_java_proto", ], - cmd = """ - for f in $(locations //java:java_gcs_fbs); do - chmod +w $$f - mv -f $$f 
$(@D)/runtime/src/main/java/org/ray/runtime/generated - done - python $$(pwd)/java/modify_generated_java_flatbuffers_files.py $(@D)/.. - """, - local = 1, ) filegroup( @@ -202,8 +186,7 @@ filegroup( genrule( name = "gen_maven_deps", srcs = [ - ":gcs_java_proto", - ":generate_java_gcs_fbs", + ":all_java_proto", ":java_native_deps", ":copy_pom_file", "@plasma//:org_apache_arrow_arrow_plasma", @@ -212,6 +195,11 @@ genrule( cmd = """ set -x WORK_DIR=$$(pwd) + # Copy protobuf-generated files. + rm -rf $$WORK_DIR/java/runtime/src/main/java/org/ray/runtime/generated + for f in $(locations //java:all_java_proto); do + unzip $$f -x META-INF/MANIFEST.MF -d $$WORK_DIR/java/runtime/src/main/java + done # Copy native dependecies. NATIVE_DEPS_DIR=$$WORK_DIR/java/runtime/native_dependencies/ rm -rf $$NATIVE_DEPS_DIR @@ -220,18 +208,6 @@ genrule( chmod +w $$f cp $$f $$NATIVE_DEPS_DIR done - # Copy protobuf-generated files. - GENERATED_DIR=$$WORK_DIR/java/runtime/src/main/java/org/ray/runtime/generated - rm -rf $$GENERATED_DIR - mkdir -p $$GENERATED_DIR - for f in $(locations //java:gcs_java_proto); do - unzip $$f - mv org/ray/runtime/generated/* $$GENERATED_DIR - done - # Copy flatbuffers-generated files - for f in $(locations //java:generate_java_gcs_fbs); do - cp $$f $$GENERATED_DIR - done # Install plasma jar to local maven repo. 
mvn install:install-file -Dfile=$(locations @plasma//:org_apache_arrow_arrow_plasma) -Dpackaging=jar \ -DgroupId=org.apache.arrow -DartifactId=arrow-plasma -Dversion=0.13.0-SNAPSHOT diff --git a/java/dependencies.bzl b/java/dependencies.bzl index ef6671375..26e36dff5 100644 --- a/java/dependencies.bzl +++ b/java/dependencies.bzl @@ -4,7 +4,6 @@ def gen_java_deps(): maven_install( artifacts = [ "com.beust:jcommander:1.72", - "com.github.davidmoten:flatbuffers-java:1.9.0.1", "com.google.guava:guava:27.0.1-jre", "com.google.protobuf:protobuf-java:3.8.0", "com.puppycrawl.tools:checkstyle:8.15", diff --git a/java/runtime/pom.xml b/java/runtime/pom.xml index e13dd95f9..aba612b36 100644 --- a/java/runtime/pom.xml +++ b/java/runtime/pom.xml @@ -31,11 +31,6 @@ jcommander 1.72 - - com.github.davidmoten - flatbuffers-java - 1.9.0.1 - com.google.guava guava diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java index 00b114460..8c00f718c 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java @@ -1,13 +1,17 @@ package org.ray.runtime.raylet; import com.google.common.base.Preconditions; -import com.google.flatbuffers.FlatBufferBuilder; +import com.google.common.collect.ImmutableList; +import com.google.protobuf.ByteString; +import com.google.protobuf.InvalidProtocolBufferException; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.nio.charset.Charset; import java.util.ArrayList; -import java.util.HashMap; +import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; import org.ray.api.RayObject; import org.ray.api.WaitResult; import org.ray.api.exception.RayException; @@ -15,10 +19,8 @@ import org.ray.api.id.ObjectId; import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; import 
org.ray.runtime.functionmanager.JavaFunctionDescriptor; -import org.ray.runtime.generated.Arg; -import org.ray.runtime.generated.Language; -import org.ray.runtime.generated.ResourcePair; -import org.ray.runtime.generated.TaskInfo; +import org.ray.runtime.generated.Common; +import org.ray.runtime.generated.Common.TaskType; import org.ray.runtime.task.FunctionArg; import org.ray.runtime.task.TaskLanguage; import org.ray.runtime.task.TaskSpec; @@ -30,15 +32,6 @@ public class RayletClientImpl implements RayletClient { private static final Logger LOGGER = LoggerFactory.getLogger(RayletClientImpl.class); - private static final int TASK_SPEC_BUFFER_SIZE = 2 * 1024 * 1024; - - /** - * Direct buffers that are used to hold flatbuffer-serialized task specs. - */ - private static ThreadLocal taskSpecBuffer = ThreadLocal.withInitial(() -> - ByteBuffer.allocateDirect(TASK_SPEC_BUFFER_SIZE).order(ByteOrder.LITTLE_ENDIAN) - ); - /** * The pointer to c++'s raylet client. */ @@ -86,21 +79,20 @@ public class RayletClientImpl implements RayletClient { Preconditions.checkState(!spec.parentTaskId.isNil()); Preconditions.checkState(!spec.jobId.isNil()); - ByteBuffer info = convertTaskSpecToFlatbuffer(spec); + byte[] taskSpec = convertTaskSpecToProtobuf(spec); byte[] cursorId = null; if (!spec.getExecutionDependencies().isEmpty()) { //TODO(hchen): handle more than one dependencies. 
cursorId = spec.getExecutionDependencies().get(0).getBytes(); } - nativeSubmitTask(client, cursorId, info, info.position(), info.remaining()); + nativeSubmitTask(client, cursorId, taskSpec); } @Override public TaskSpec getTask() { byte[] bytes = nativeGetTask(client); assert (null != bytes); - ByteBuffer bb = ByteBuffer.wrap(bytes); - return parseTaskSpecFromFlatbuffer(bb); + return parseTaskSpecFromProtobuf(bytes); } @Override @@ -127,7 +119,7 @@ public class RayletClientImpl implements RayletClient { @Override public void freePlasmaObjects(List objectIds, boolean localOnly, - boolean deleteCreatingTasks) { + boolean deleteCreatingTasks) { byte[][] objectIdsArray = IdUtil.getIdBytes(objectIds); nativeFreePlasmaObjects(client, objectIdsArray, localOnly, deleteCreatingTasks); } @@ -142,59 +134,76 @@ public class RayletClientImpl implements RayletClient { nativeNotifyActorResumedFromCheckpoint(client, actorId.getBytes(), checkpointId.getBytes()); } - private static TaskSpec parseTaskSpecFromFlatbuffer(ByteBuffer bb) { - bb.order(ByteOrder.LITTLE_ENDIAN); - TaskInfo info = TaskInfo.getRootAsTaskInfo(bb); - UniqueId jobId = UniqueId.fromByteBuffer(info.jobIdAsByteBuffer()); - TaskId taskId = TaskId.fromByteBuffer(info.taskIdAsByteBuffer()); - TaskId parentTaskId = TaskId.fromByteBuffer(info.parentTaskIdAsByteBuffer()); - int parentCounter = info.parentCounter(); - UniqueId actorCreationId = UniqueId.fromByteBuffer(info.actorCreationIdAsByteBuffer()); - int maxActorReconstructions = info.maxActorReconstructions(); - UniqueId actorId = UniqueId.fromByteBuffer(info.actorIdAsByteBuffer()); - UniqueId actorHandleId = UniqueId.fromByteBuffer(info.actorHandleIdAsByteBuffer()); - int actorCounter = info.actorCounter(); - int numReturns = info.numReturns(); + /** + * Parse `TaskSpec` protobuf bytes. 
+ */ + private static TaskSpec parseTaskSpecFromProtobuf(byte[] bytes) { + Common.TaskSpec taskSpec; + try { + taskSpec = Common.TaskSpec.parseFrom(bytes); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException("Invalid protobuf data."); + } - // Deserialize new actor handles - UniqueId[] newActorHandles = IdUtil.getUniqueIdsFromByteBuffer( - info.newActorHandlesAsByteBuffer()); + // Parse common fields. + UniqueId jobId = UniqueId.fromByteBuffer(taskSpec.getJobId().asReadOnlyByteBuffer()); + TaskId taskId = TaskId.fromByteBuffer(taskSpec.getTaskId().asReadOnlyByteBuffer()); + TaskId parentTaskId = TaskId.fromByteBuffer(taskSpec.getParentTaskId().asReadOnlyByteBuffer()); + int parentCounter = (int) taskSpec.getParentCounter(); + int numReturns = (int) taskSpec.getNumReturns(); + Map resources = taskSpec.getRequiredResourcesMap(); - // Deserialize args - FunctionArg[] args = new FunctionArg[info.argsLength()]; - for (int i = 0; i < info.argsLength(); i++) { - Arg arg = info.args(i); - - int objectIdsLength = arg.objectIdsAsByteBuffer().remaining() / UniqueId.LENGTH; + // Parse args. 
+ FunctionArg[] args = new FunctionArg[taskSpec.getArgsCount()]; + for (int i = 0; i < args.length; i++) { + Common.TaskArg arg = taskSpec.getArgs(i); + int objectIdsLength = arg.getObjectIdsCount(); if (objectIdsLength > 0) { Preconditions.checkArgument(objectIdsLength == 1, "This arg has more than one id: {}", objectIdsLength); - args[i] = FunctionArg.passByReference(ObjectId.fromByteBuffer(arg.objectIdsAsByteBuffer())); + args[i] = FunctionArg + .passByReference(ObjectId.fromByteBuffer(arg.getObjectIds(0).asReadOnlyByteBuffer())); } else { - ByteBuffer lbb = arg.dataAsByteBuffer(); - Preconditions.checkState(lbb != null && lbb.remaining() > 0); - byte[] data = new byte[lbb.remaining()]; - lbb.get(data); - args[i] = FunctionArg.passByValue(data); + args[i] = FunctionArg.passByValue(arg.getData().toByteArray()); } } - // Deserialize required resources; - Map resources = new HashMap<>(); - for (int i = 0; i < info.requiredResourcesLength(); i++) { - resources.put(info.requiredResources(i).key(), info.requiredResources(i).value()); - } - // Deserialize function descriptor - Preconditions.checkArgument(info.language() == Language.JAVA); - Preconditions.checkArgument(info.functionDescriptorLength() == 3); + // Parse function descriptor + Preconditions.checkArgument(taskSpec.getLanguage() == Common.Language.JAVA); + Preconditions.checkArgument(taskSpec.getFunctionDescriptorCount() == 3); JavaFunctionDescriptor functionDescriptor = new JavaFunctionDescriptor( - info.functionDescriptor(0), info.functionDescriptor(1), info.functionDescriptor(2) + taskSpec.getFunctionDescriptor(0).toString(Charset.defaultCharset()), + taskSpec.getFunctionDescriptor(1).toString(Charset.defaultCharset()), + taskSpec.getFunctionDescriptor(2).toString(Charset.defaultCharset()) ); - // Deserialize dynamic worker options. + // Parse ActorCreationTaskSpec. 
+ UniqueId actorCreationId = UniqueId.NIL; + int maxActorReconstructions = 0; + UniqueId[] newActorHandles = new UniqueId[0]; List dynamicWorkerOptions = new ArrayList<>(); - for (int i = 0; i < info.dynamicWorkerOptionsLength(); ++i) { - dynamicWorkerOptions.add(info.dynamicWorkerOptions(i)); + if (taskSpec.getType() == Common.TaskType.ACTOR_CREATION_TASK) { + Common.ActorCreationTaskSpec actorCreationTaskSpec = taskSpec.getActorCreationTaskSpec(); + actorCreationId = UniqueId + .fromByteBuffer(actorCreationTaskSpec.getActorId().asReadOnlyByteBuffer()); + maxActorReconstructions = (int) actorCreationTaskSpec.getMaxActorReconstructions(); + dynamicWorkerOptions = ImmutableList + .copyOf(actorCreationTaskSpec.getDynamicWorkerOptionsList()); + } + + // Parse ActorTaskSpec. + UniqueId actorId = UniqueId.NIL; + UniqueId actorHandleId = UniqueId.NIL; + int actorCounter = 0; + if (taskSpec.getType() == Common.TaskType.ACTOR_TASK) { + Common.ActorTaskSpec actorTaskSpec = taskSpec.getActorTaskSpec(); + actorId = UniqueId.fromByteBuffer(actorTaskSpec.getActorId().asReadOnlyByteBuffer()); + actorHandleId = UniqueId + .fromByteBuffer(actorTaskSpec.getActorHandleId().asReadOnlyByteBuffer()); + actorCounter = (int) actorTaskSpec.getActorCounter(); + newActorHandles = actorTaskSpec.getNewActorHandlesList().stream() + .map(byteString -> UniqueId.fromByteBuffer(byteString.asReadOnlyByteBuffer())) + .toArray(UniqueId[]::new); } return new TaskSpec(jobId, taskId, parentTaskId, parentCounter, actorCreationId, @@ -202,122 +211,78 @@ public class RayletClientImpl implements RayletClient { args, numReturns, resources, TaskLanguage.JAVA, functionDescriptor, dynamicWorkerOptions); } - private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { - ByteBuffer bb = taskSpecBuffer.get(); - bb.clear(); + /** + * Convert a `TaskSpec` to protobuf-serialized bytes. + */ + private static byte[] convertTaskSpecToProtobuf(TaskSpec task) { + // Set common fields. 
+ Common.TaskSpec.Builder builder = Common.TaskSpec.newBuilder() + .setJobId(ByteString.copyFrom(task.jobId.getBytes())) + .setTaskId(ByteString.copyFrom(task.taskId.getBytes())) + .setParentTaskId(ByteString.copyFrom(task.parentTaskId.getBytes())) + .setParentCounter(task.parentCounter) + .setNumReturns(task.numReturns) + .putAllRequiredResources(task.resources); - FlatBufferBuilder fbb = new FlatBufferBuilder(bb); - final int jobIdOffset = fbb.createString(task.jobId.toByteBuffer()); - final int taskIdOffset = fbb.createString(task.taskId.toByteBuffer()); - final int parentTaskIdOffset = fbb.createString(task.parentTaskId.toByteBuffer()); - final int parentCounter = task.parentCounter; - final int actorCreateIdOffset = fbb.createString(task.actorCreationId.toByteBuffer()); - final int actorCreateDummyIdOffset = fbb.createString(task.actorId.toByteBuffer()); - final int maxActorReconstructions = task.maxActorReconstructions; - final int actorIdOffset = fbb.createString(task.actorId.toByteBuffer()); - final int actorHandleIdOffset = fbb.createString(task.actorHandleId.toByteBuffer()); - final int actorCounter = task.actorCounter; - final int numReturnsOffset = task.numReturns; - - // Serialize the new actor handles. 
- int newActorHandlesOffset - = fbb.createString(IdUtil.concatIds(task.newActorHandles)); - - // Serialize args - int[] argsOffsets = new int[task.args.length]; - for (int i = 0; i < argsOffsets.length; i++) { - int objectIdOffset = 0; - int dataOffset = 0; - if (task.args[i].id != null) { - objectIdOffset = fbb.createString( - IdUtil.concatIds(new ObjectId[]{task.args[i].id})); - } else { - objectIdOffset = fbb.createString(""); - } - if (task.args[i].data != null) { - dataOffset = fbb.createString(ByteBuffer.wrap(task.args[i].data)); - } - argsOffsets[i] = Arg.createArg(fbb, objectIdOffset, dataOffset); - } - int argsOffset = fbb.createVectorOfTables(argsOffsets); - - // Serialize required resources - // The required_resources vector indicates the quantities of the different - // resources required by this task. The index in this vector corresponds to - // the resource type defined in the ResourceIndex enum. For example, - int[] requiredResourcesOffsets = new int[task.resources.size()]; - int i = 0; - for (Map.Entry entry : task.resources.entrySet()) { - int keyOffset = fbb.createString(ByteBuffer.wrap(entry.getKey().getBytes())); - requiredResourcesOffsets[i++] = - ResourcePair.createResourcePair(fbb, keyOffset, entry.getValue()); - } - int requiredResourcesOffset = fbb.createVectorOfTables(requiredResourcesOffsets); - - int[] requiredPlacementResourcesOffsets = new int[0]; - int requiredPlacementResourcesOffset = - fbb.createVectorOfTables(requiredPlacementResourcesOffsets); - - int language; - int functionDescriptorOffset; + // Set args + builder.addAllArgs( + Arrays.stream(task.args).map(arg -> { + Common.TaskArg.Builder argBuilder = Common.TaskArg.newBuilder(); + if (arg.id != null) { + argBuilder.addObjectIds(ByteString.copyFrom(arg.id.getBytes())).build(); + } else { + argBuilder.setData(ByteString.copyFrom(arg.data)).build(); + } + return argBuilder.build(); + }).collect(Collectors.toList()) + ); + // Set function descriptor and language. 
if (task.language == TaskLanguage.JAVA) { - // This is a Java task. - language = Language.JAVA; - int[] functionDescriptorOffsets = new int[]{ - fbb.createString(task.getJavaFunctionDescriptor().className), - fbb.createString(task.getJavaFunctionDescriptor().name), - fbb.createString(task.getJavaFunctionDescriptor().typeDescriptor) - }; - functionDescriptorOffset = fbb.createVectorOfTables(functionDescriptorOffsets); + builder.setLanguage(Common.Language.JAVA); + builder.addAllFunctionDescriptor(ImmutableList.of( + ByteString.copyFrom(task.getJavaFunctionDescriptor().className.getBytes()), + ByteString.copyFrom(task.getJavaFunctionDescriptor().name.getBytes()), + ByteString.copyFrom(task.getJavaFunctionDescriptor().typeDescriptor.getBytes()) + )); } else { - // This is a Python task. - language = Language.PYTHON; - int[] functionDescriptorOffsets = new int[]{ - fbb.createString(task.getPyFunctionDescriptor().moduleName), - fbb.createString(task.getPyFunctionDescriptor().className), - fbb.createString(task.getPyFunctionDescriptor().functionName), - fbb.createString("") - }; - functionDescriptorOffset = fbb.createVectorOfTables(functionDescriptorOffsets); + builder.setLanguage(Common.Language.PYTHON); + builder.addAllFunctionDescriptor(ImmutableList.of( + ByteString.copyFrom(task.getPyFunctionDescriptor().moduleName.getBytes()), + ByteString.copyFrom(task.getPyFunctionDescriptor().className.getBytes()), + ByteString.copyFrom(task.getPyFunctionDescriptor().functionName.getBytes()), + ByteString.EMPTY + )); } - int [] dynamicWorkerOptionsOffsets = new int[task.dynamicWorkerOptions.size()]; - for (int index = 0; index < task.dynamicWorkerOptions.size(); ++index) { - dynamicWorkerOptionsOffsets[index] = fbb.createString(task.dynamicWorkerOptions.get(index)); + if (!task.actorCreationId.isNil()) { + // Actor creation task. 
+ builder.setType(TaskType.ACTOR_CREATION_TASK); + builder.setActorCreationTaskSpec( + Common.ActorCreationTaskSpec.newBuilder() + .setActorId(ByteString.copyFrom(task.actorCreationId.getBytes())) + .setMaxActorReconstructions(task.maxActorReconstructions) + .addAllDynamicWorkerOptions(task.dynamicWorkerOptions) + ); + } else if (!task.actorId.isNil()) { + // Actor task. + builder.setType(TaskType.ACTOR_TASK); + List newHandles = Arrays.stream(task.newActorHandles) + .map(id -> ByteString.copyFrom(id.getBytes())).collect(Collectors.toList()); + builder.setActorTaskSpec( + Common.ActorTaskSpec.newBuilder() + .setActorId(ByteString.copyFrom(task.actorId.getBytes())) + .setActorHandleId(ByteString.copyFrom(task.actorHandleId.getBytes())) + .setActorCreationDummyObjectId(ByteString.copyFrom(task.actorId.getBytes())) + .setActorCounter(task.actorCounter) + .addAllNewActorHandles(newHandles) + ); + } else { + // Normal task. + builder.setType(TaskType.NORMAL_TASK); } - int dynamicWorkerOptionsOffset = fbb.createVectorOfTables(dynamicWorkerOptionsOffsets); - int root = TaskInfo.createTaskInfo( - fbb, - jobIdOffset, - taskIdOffset, - parentTaskIdOffset, - parentCounter, - actorCreateIdOffset, - actorCreateDummyIdOffset, - maxActorReconstructions, - actorIdOffset, - actorHandleIdOffset, - actorCounter, - newActorHandlesOffset, - argsOffset, - numReturnsOffset, - requiredResourcesOffset, - requiredPlacementResourcesOffset, - language, - functionDescriptorOffset, - dynamicWorkerOptionsOffset); - fbb.finish(root); - ByteBuffer buffer = fbb.dataBuffer(); - - if (buffer.remaining() > TASK_SPEC_BUFFER_SIZE) { - LOGGER.error( - "Allocated buffer is not enough to transfer the task specification: {} vs {}", - TASK_SPEC_BUFFER_SIZE, buffer.remaining()); - throw new RuntimeException("Allocated buffer is not enough to transfer to task."); - } - return buffer; + return builder.build().toByteArray(); } public void setResource(String resourceName, double capacity, UniqueId nodeId) { @@ 
-344,10 +309,9 @@ public class RayletClientImpl implements RayletClient { private static native long nativeInit(String localSchedulerSocket, byte[] workerId, boolean isWorker, byte[] driverTaskId); - private static native void nativeSubmitTask(long client, byte[] cursorId, ByteBuffer taskBuff, - int pos, int taskSize) throws RayException; + private static native void nativeSubmitTask(long client, byte[] cursorId, byte[] taskSpec) + throws RayException; - // return TaskInfo (in FlatBuffer) private static native byte[] nativeGetTask(long client) throws RayException; private static native void nativeDestroy(long client) throws RayException; @@ -375,5 +339,5 @@ public class RayletClientImpl implements RayletClient { byte[] checkpointId); private static native void nativeSetResource(long conn, String resourceName, double capacity, - byte[] nodeId) throws RayException; + byte[] nodeId) throws RayException; } diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index f3968577d..3e5586069 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -24,8 +24,8 @@ from ray.includes.common cimport ( ) from ray.includes.libraylet cimport ( CRayletClient, - GCSProfileEventT, - GCSProfileTableDataT, + GCSProfileEvent, + GCSProfileTableData, ResourceMappingType, WaitResultPair, ) @@ -34,7 +34,7 @@ from ray.includes.unique_ids cimport ( CObjectID, CClientID, ) -from ray.includes.task cimport CTaskSpecification +from ray.includes.task cimport CTaskSpec from ray.includes.ray_config cimport RayConfig from ray.utils import decode @@ -232,18 +232,22 @@ cdef class RayletClient: def disconnect(self): check_status(self.client.get().Disconnect()) - def submit_task(self, Task task_spec): + def submit_task(self, TaskSpec task_spec, execution_dependencies): + cdef: + CObjectID c_id + c_vector[CObjectID] c_dependencies + for dep in execution_dependencies: + c_dependencies.push_back((dep).native()) check_status(self.client.get().SubmitTask( - 
task_spec.execution_dependencies.get()[0], - task_spec.task_spec.get()[0])) + c_dependencies, task_spec.task_spec.get()[0])) def get_task(self): cdef: - unique_ptr[CTaskSpecification] task_spec + unique_ptr[CTaskSpec] task_spec with nogil: check_status(self.client.get().GetTask(&task_spec)) - return Task.make(task_spec) + return TaskSpec.make(task_spec) def task_done(self): check_status(self.client.get().TaskDone()) @@ -303,19 +307,19 @@ cdef class RayletClient: def push_profile_events(self, component_type, UniqueID component_id, node_ip_address, profile_data): cdef: - GCSProfileTableDataT profile_info - GCSProfileEventT *profile_event + GCSProfileTableData profile_info + GCSProfileEvent *profile_event c_string event_type if len(profile_data) == 0: return # Short circuit if there are no profile events. - profile_info.component_type = component_type.encode("ascii") - profile_info.component_id = component_id.binary() - profile_info.node_ip_address = node_ip_address.encode("ascii") + profile_info.set_component_type(component_type.encode("ascii")) + profile_info.set_component_id(component_id.binary()) + profile_info.set_node_ip_address(node_ip_address.encode("ascii")) for py_profile_event in profile_data: - profile_event = new GCSProfileEventT() + profile_event = profile_info.add_profile_events() if not isinstance(py_profile_event, dict): raise TypeError( "Incorrect type for a profile event. Expected dict " @@ -325,28 +329,22 @@ cdef class RayletClient: # that will cause segfaults in the node manager. 
for key_string, event_data in py_profile_event.items(): if key_string == "event_type": - profile_event.event_type = event_data.encode("ascii") - if profile_event.event_type.length() == 0: + if len(event_data) == 0: raise ValueError( "'event_type' should not be a null string.") + profile_event.set_event_type(event_data.encode("ascii")) elif key_string == "start_time": - profile_event.start_time = float(event_data) + profile_event.set_start_time(float(event_data)) elif key_string == "end_time": - profile_event.end_time = float(event_data) + profile_event.set_end_time(float(event_data)) elif key_string == "extra_data": - profile_event.extra_data = event_data.encode("ascii") - if profile_event.extra_data.length() == 0: + if len(event_data) == 0: raise ValueError( "'extra_data' should not be a null string.") + profile_event.set_extra_data(event_data.encode("ascii")) else: raise ValueError( "Unknown profile event key '%s'" % key_string) - # Note that profile_info.profile_events is a vector of unique - # pointers, so profile_event will be deallocated when profile_info - # goes out of scope. 
"emplace_back" of vector has not been - # supported by Cython - profile_info.profile_events.push_back( - unique_ptr[GCSProfileEventT](profile_event)) check_status(self.client.get().PushProfileEvents(profile_info)) diff --git a/python/ray/core/generated/ray/__init__.py b/python/ray/core/generated/ray/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/python/ray/core/generated/ray/protocol/__init__.py b/python/ray/core/generated/ray/protocol/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py index 1e5441485..43ce42d91 100644 --- a/python/ray/gcs_utils.py +++ b/python/ray/gcs_utils.py @@ -2,8 +2,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from ray.core.generated.ray.protocol.Task import Task - from ray.core.generated.gcs_pb2 import ( ActorCheckpointIdData, ClientTableData, @@ -33,7 +31,6 @@ __all__ = [ "ProfileTableData", "TablePrefix", "TablePubsub", - "Task", "TaskTableData", "construct_error_message", ] diff --git a/python/ray/includes/common.pxd b/python/ray/includes/common.pxd index 5c716b673..c9b95aaec 100644 --- a/python/ray/includes/common.pxd +++ b/python/ray/includes/common.pxd @@ -87,17 +87,14 @@ cdef extern from "ray/common/id.h" namespace "ray" nogil: int parent_task_counter) -cdef extern from "ray/gcs/format/gcs_generated.h" nogil: - cdef cppclass GCSArg "Arg": - pass - +cdef extern from "ray/protobuf/common.pb.h" nogil: cdef cppclass CLanguage "Language": pass # This is a workaround for C++ enum class since Cython has no corresponding # representation. 
-cdef extern from "ray/gcs/format/gcs_generated.h" namespace "Language" nogil: +cdef extern from "ray/protobuf/common.pb.h" namespace "Language" nogil: cdef CLanguage LANGUAGE_PYTHON "Language::PYTHON" cdef CLanguage LANGUAGE_CPP "Language::CPP" cdef CLanguage LANGUAGE_JAVA "Language::JAVA" diff --git a/python/ray/includes/libraylet.pxd b/python/ray/includes/libraylet.pxd index 3bc6eddd0..7f6d03a5c 100644 --- a/python/ray/includes/libraylet.pxd +++ b/python/ray/includes/libraylet.pxd @@ -19,23 +19,23 @@ from ray.includes.unique_ids cimport ( CObjectID, CTaskID, ) -from ray.includes.task cimport CTaskSpecification +from ray.includes.task cimport CTaskSpec -cdef extern from "ray/gcs/format/gcs_generated.h" nogil: - cdef cppclass GCSProfileEventT "ProfileEventT": - c_string event_type - double start_time - double end_time - c_string extra_data - GCSProfileEventT() +cdef extern from "ray/protobuf/gcs.pb.h" nogil: + cdef cppclass GCSProfileEvent "ProfileTableData::ProfileEvent": + void set_event_type(const c_string &value) + void set_start_time(double value) + void set_end_time(double value) + c_string set_extra_data(const c_string &value) + GCSProfileEvent() - cdef cppclass GCSProfileTableDataT "ProfileTableDataT": - c_string component_type - c_string component_id - c_string node_ip_address - c_vector[unique_ptr[GCSProfileEventT]] profile_events - GCSProfileTableDataT() + cdef cppclass GCSProfileTableData "ProfileTableData": + void set_component_type(const c_string &value) + void set_component_id(const c_string &value) + void set_node_ip_address(const c_string &value) + GCSProfileEvent *add_profile_events() + GCSProfileTableData() ctypedef unordered_map[c_string, c_vector[pair[int64_t, double]]] \ @@ -52,8 +52,8 @@ cdef extern from "ray/raylet/raylet_client.h" nogil: CRayStatus Disconnect() CRayStatus SubmitTask( const c_vector[CObjectID] &execution_dependencies, - const CTaskSpecification &task_spec) - CRayStatus GetTask(unique_ptr[CTaskSpecification] *task_spec) + 
const CTaskSpec &task_spec) + CRayStatus GetTask(unique_ptr[CTaskSpec] *task_spec) CRayStatus TaskDone() CRayStatus FetchOrReconstruct(c_vector[CObjectID] &object_ids, c_bool fetch_only, @@ -66,7 +66,7 @@ cdef extern from "ray/raylet/raylet_client.h" nogil: CRayStatus PushError(const CJobID &job_id, const c_string &type, const c_string &error_message, double timestamp) CRayStatus PushProfileEvents( - const GCSProfileTableDataT &profile_events) + const GCSProfileTableData &profile_events) CRayStatus FreeObjects(const c_vector[CObjectID] &object_ids, c_bool local_only, c_bool delete_creating_tasks) CRayStatus PrepareActorCheckpoint(const CActorID &actor_id, diff --git a/python/ray/includes/task.pxd b/python/ray/includes/task.pxd index 5d6511c32..b648c325d 100644 --- a/python/ray/includes/task.pxd +++ b/python/ray/includes/task.pxd @@ -1,4 +1,4 @@ -from libc.stdint cimport int64_t, uint8_t +from libc.stdint cimport uint8_t, uint64_t from libcpp cimport bool as c_bool from libcpp.memory cimport unique_ptr, shared_ptr from libcpp.string cimport string as c_string @@ -17,72 +17,45 @@ from ray.includes.unique_ids cimport ( CTaskID, ) +cdef extern from "ray/protobuf/common.pb.h" namespace "ray::rpc" nogil: + cdef cppclass RpcTaskSpec "ray::rpc::TaskSpec": + void CopyFrom(const RpcTaskSpec &value) -cdef extern from "ray/raylet/task_execution_spec.h" \ - namespace "ray::raylet" nogil: - cdef cppclass CTaskExecutionSpecification \ - "ray::raylet::TaskExecutionSpecification": - CTaskExecutionSpecification(const c_vector[CObjectID] &&dependencies) - CTaskExecutionSpecification( - const c_vector[CObjectID] &&dependencies, int num_forwards) - c_vector[CObjectID] ExecutionDependencies() const - void SetExecutionDependencies(const c_vector[CObjectID] &dependencies) - int NumForwards() const - void IncrementNumForwards() - int64_t LastTimestamp() const - void SetLastTimestamp(int64_t new_timestamp) + cdef cppclass RpcTaskExecutionSpec "ray::rpc::TaskExecutionSpec": + void 
CopyFrom(const RpcTaskExecutionSpec &value) + void add_dependencies(const c_string &value) + + cdef cppclass RpcTask "ray::rpc::Task": + RpcTaskSpec *mutable_task_spec() + + +cdef extern from "ray/protobuf/gcs.pb.h" namespace "ray::rpc" nogil: + cdef cppclass TaskTableData "ray::rpc::TaskTableData": + RpcTask *mutable_task() + const c_string &SerializeAsString() cdef extern from "ray/raylet/task_spec.h" namespace "ray::raylet" nogil: - cdef cppclass CTaskArgument "ray::raylet::TaskArgument": - pass - - cdef cppclass CTaskArgumentByReference \ - "ray::raylet::TaskArgumentByReference": - CTaskArgumentByReference(const c_vector[CObjectID] &references) - - cdef cppclass CTaskArgumentByValue "ray::raylet::TaskArgumentByValue": - CTaskArgumentByValue(const uint8_t *value, size_t length) - - cdef cppclass CTaskSpecification "ray::raylet::TaskSpecification": - CTaskSpecification( - const CJobID &job_id, const CTaskID &parent_task_id, - int64_t parent_counter, - const c_vector[shared_ptr[CTaskArgument]] &task_arguments, - int64_t num_returns, - const unordered_map[c_string, double] &required_resources, - const CLanguage &language, - const c_vector[c_string] &function_descriptor) - CTaskSpecification( - const CJobID &job_id, const CTaskID &parent_task_id, - int64_t parent_counter, const CActorID &actor_creation_id, - const CObjectID &actor_creation_dummy_object_id, - int64_t max_actor_reconstructions, const CActorID &actor_id, - const CActorHandleID &actor_handle_id, int64_t actor_counter, - const c_vector[CActorHandleID] &new_actor_handles, - const c_vector[shared_ptr[CTaskArgument]] &task_arguments, - int64_t num_returns, - const unordered_map[c_string, double] &required_resources, - const unordered_map[c_string, double] &required_placement_res, - const CLanguage &language, - const c_vector[c_string] &function_descriptor) - CTaskSpecification(const c_string &string) - c_string SerializeAsString() const + cdef cppclass CTaskSpec "ray::raylet::TaskSpecification": + 
CTaskSpec(const RpcTaskSpec message) + CTaskSpec(const c_string &serialized_binary) + const RpcTaskSpec &GetMessage() + c_string Serialize() const CTaskID TaskId() const CJobID JobId() const CTaskID ParentTaskId() const - int64_t ParentCounter() const + uint64_t ParentCounter() const c_vector[c_string] FunctionDescriptor() const c_string FunctionDescriptorString() const - int64_t NumArgs() const - int64_t NumReturns() const - c_bool ArgByRef(int64_t arg_index) const - int ArgIdCount(int64_t arg_index) const - CObjectID ArgId(int64_t arg_index, int64_t id_index) const - CObjectID ReturnId(int64_t return_index) const - const uint8_t *ArgVal(int64_t arg_index) const - size_t ArgValLength(int64_t arg_index) const + uint64_t NumArgs() const + uint64_t NumReturns() const + c_bool ArgByRef(uint64_t arg_index) const + int ArgIdCount(uint64_t arg_index) const + CObjectID ArgId(uint64_t arg_index, uint64_t id_index) const + CObjectID ReturnId(uint64_t return_index) const + const uint8_t *ArgVal(uint64_t arg_index) const + size_t ArgValLength(uint64_t arg_index) const double GetRequiredResource(const c_string &resource_name) const const ResourceSet GetRequiredResources() const const ResourceSet GetRequiredPlacementResources() const @@ -93,25 +66,46 @@ cdef extern from "ray/raylet/task_spec.h" namespace "ray::raylet" nogil: c_bool IsActorTask() const CActorID ActorCreationId() const CObjectID ActorCreationDummyObjectId() const - int64_t MaxActorReconstructions() const + uint64_t MaxActorReconstructions() const CActorID ActorId() const CActorHandleID ActorHandleId() const - int64_t ActorCounter() const + uint64_t ActorCounter() const CObjectID ActorDummyObject() const c_vector[CActorHandleID] NewActorHandles() const +cdef extern from "ray/raylet/task_util.h" namespace "ray::raylet" nogil: + cdef cppclass TaskSpecBuilder "ray::raylet::TaskSpecBuilder": + TaskSpecBuilder &SetCommonTaskSpec( + const CLanguage &language, const c_vector[c_string] &function_descriptor, + const CJobID 
&job_id, const CTaskID &parent_task_id, uint64_t parent_counter, + uint64_t num_returns, const unordered_map[c_string, double] &required_resources, + const unordered_map[c_string, double] &required_placement_resources); + + TaskSpecBuilder &AddByRefArg(const CObjectID &arg_id); + + TaskSpecBuilder &AddByValueArg(const c_string &data); + + TaskSpecBuilder &SetActorCreationTaskSpec( + const CActorID &actor_id, uint64_t max_reconstructions, + const c_vector[c_string] &dynamic_worker_options); + + TaskSpecBuilder &SetActorTaskSpec( + const CActorID &actor_id, const CActorHandleID &actor_handle_id, + const CObjectID &actor_creation_dummy_object_id, uint64_t actor_counter, + const c_vector[CActorHandleID] &new_handle_ids); + + RpcTaskSpec GetMessage(); + + +cdef extern from "ray/raylet/task_execution_spec.h" namespace "ray::raylet" nogil: + cdef cppclass CTaskExecutionSpec "ray::raylet::TaskExecutionSpecification": + CTaskExecutionSpec(RpcTaskExecutionSpec message) + CTaskExecutionSpec(const c_string &serialized_binary) + const RpcTaskExecutionSpec &GetMessage() + c_vector[CObjectID] ExecutionDependencies() + uint64_t NumForwards() + cdef extern from "ray/raylet/task.h" namespace "ray::raylet" nogil: cdef cppclass CTask "ray::raylet::Task": - CTask(const CTaskExecutionSpecification &execution_spec, - const CTaskSpecification &task_spec) - const CTaskExecutionSpecification &GetTaskExecutionSpec() const - const CTaskSpecification &GetTaskSpecification() const - void SetExecutionDependencies(const c_vector[CObjectID] &dependencies) - void IncrementNumForwards() - const c_vector[CObjectID] &GetDependencies() const - void CopyTaskExecutionSpec(const CTask &task) - - cdef c_string SerializeTaskAsString( - const c_vector[CObjectID] *dependencies, - const CTaskSpecification *task_spec) + CTask(CTaskSpec task_spec, CTaskExecutionSpec task_execution_spec) diff --git a/python/ray/includes/task.pxi b/python/ray/includes/task.pxi index c9e64d83b..42838b08a 100644 --- 
a/python/ray/includes/task.pxi +++ b/python/ray/includes/task.pxi @@ -5,18 +5,19 @@ from libcpp.memory cimport ( static_pointer_cast, ) from ray.includes.task cimport ( - CTaskArgument, - CTaskArgumentByReference, - CTaskArgumentByValue, - CTaskSpecification, - SerializeTaskAsString, + CTask, + CTaskExecutionSpec, + CTaskSpec, + RpcTaskExecutionSpec, + TaskSpecBuilder, + TaskTableData, ) -cdef class Task: +cdef class TaskSpec: + """Cython wrapper class of C++ `ray::raylet::TaskSpecification`.""" cdef: - unique_ptr[CTaskSpecification] task_spec - unique_ptr[c_vector[CObjectID]] execution_dependencies + unique_ptr[CTaskSpec] task_spec def __init__(self, JobID job_id, function_descriptor, arguments, int num_returns, TaskID parent_task_id, int parent_counter, @@ -24,73 +25,78 @@ cdef class Task: ObjectID actor_creation_dummy_object_id, int32_t max_actor_reconstructions, ActorID actor_id, ActorHandleID actor_handle_id, int actor_counter, - new_actor_handles, execution_arguments, resource_map, - placement_resource_map): + new_actor_handles, resource_map, placement_resource_map): cdef: + TaskSpecBuilder builder unordered_map[c_string, double] required_resources unordered_map[c_string, double] required_placement_resources - c_vector[shared_ptr[CTaskArgument]] task_args - c_vector[CActorHandleID] task_new_actor_handles c_vector[c_string] c_function_descriptor c_string pickled_str - c_vector[CObjectID] references + c_vector[CActorHandleID] c_new_actor_handles + # Convert function descriptor to C++ vector. for item in function_descriptor: if not isinstance(item, bytes): raise TypeError( "'function_descriptor' takes a list of byte strings.") c_function_descriptor.push_back(item) - # Parse the resource map. + # Convert resource map to C++ unordered_map. 
if resource_map is not None: required_resources = resource_map_from_dict(resource_map) if placement_resource_map is not None: required_placement_resources = ( resource_map_from_dict(placement_resource_map)) - # Parse the arguments from the list. + # Build common task spec. + builder.SetCommonTaskSpec( + LANGUAGE_PYTHON, + c_function_descriptor, + job_id.native(), + parent_task_id.native(), + parent_counter, + num_returns, + required_resources, + required_placement_resources, + ) + + # Build arguments. for arg in arguments: if isinstance(arg, ObjectID): - references = c_vector[CObjectID]() - references.push_back((arg).native()) - task_args.push_back( - static_pointer_cast[CTaskArgument, - CTaskArgumentByReference]( - make_shared[CTaskArgumentByReference](references))) + builder.AddByRefArg((arg).native()) else: pickled_str = pickle.dumps( arg, protocol=pickle.HIGHEST_PROTOCOL) - task_args.push_back( - static_pointer_cast[CTaskArgument, - CTaskArgumentByValue]( - make_shared[CTaskArgumentByValue]( - pickled_str.c_str(), - pickled_str.size()))) + builder.AddByValueArg(pickled_str) - for new_actor_handle in new_actor_handles: - task_new_actor_handles.push_back( - (new_actor_handle).native()) - - self.task_spec.reset(new CTaskSpecification( - job_id.native(), parent_task_id.native(), parent_counter, actor_creation_id.native(), - actor_creation_dummy_object_id.native(), max_actor_reconstructions, actor_id.native(), - actor_handle_id.native(), actor_counter, task_new_actor_handles, task_args, num_returns, - required_resources, required_placement_resources, LANGUAGE_PYTHON, - c_function_descriptor)) - - # Set the task's execution dependencies. - self.execution_dependencies.reset(new c_vector[CObjectID]()) - if execution_arguments is not None: - for execution_arg in execution_arguments: - self.execution_dependencies.get().push_back( - (execution_arg).native()) + if not actor_creation_id.is_nil(): + # Actor creation task. 
+ builder.SetActorCreationTaskSpec( + actor_creation_id.native(), + max_actor_reconstructions, + [], + ) + elif not actor_id.is_nil(): + # Actor task. + for new_actor_handle in new_actor_handles: + c_new_actor_handles.push_back( + (new_actor_handle).native()) + builder.SetActorTaskSpec( + actor_id.native(), + actor_handle_id.native(), + actor_creation_dummy_object_id.native(), + actor_counter, + c_new_actor_handles, + ) + else: + # Normal task. + pass + self.task_spec.reset(new CTaskSpec(builder.GetMessage())) @staticmethod - cdef make(unique_ptr[CTaskSpecification]& task_spec): - cdef Task self = Task.__new__(Task) + cdef make(unique_ptr[CTaskSpec]& task_spec): + cdef TaskSpec self = TaskSpec.__new__(TaskSpec) self.task_spec.reset(task_spec.release()) - # The created task does not include any execution dependencies. - self.execution_dependencies.reset(new c_vector[CObjectID]()) return self @staticmethod @@ -103,11 +109,8 @@ cdef class Task: Returns: Python task specification object. """ - cdef Task self = Task.__new__(Task) - # TODO(pcm): Use flatbuffers validation here. - self.task_spec.reset(new CTaskSpecification(task_spec_str)) - # The created task does not include any execution dependencies. - self.execution_dependencies.reset(new c_vector[CObjectID]()) + cdef TaskSpec self = TaskSpec.__new__(TaskSpec) + self.task_spec.reset(new CTaskSpec(task_spec_str)) return self def to_string(self): @@ -116,11 +119,7 @@ cdef class Task: Returns: String representing the task specification. 
""" - return self.task_spec.get().SerializeAsString() - - def _serialized_raylet_task(self): - return SerializeTaskAsString( - self.execution_dependencies.get(), self.task_spec.get()) + return self.task_spec.get().Serialize() def job_id(self): """Return the job ID for this task.""" @@ -150,7 +149,7 @@ cdef class Task: def arguments(self): """Return the arguments for the task.""" cdef: - CTaskSpecification *task_spec = self.task_spec.get() + CTaskSpec*task_spec = self.task_spec.get() int64_t num_args = task_spec.NumArgs() int32_t lang = task_spec.GetLanguage() int count @@ -175,7 +174,7 @@ cdef class Task: def returns(self): """Return the object IDs for the return values of the task.""" - cdef CTaskSpecification *task_spec = self.task_spec.get() + cdef CTaskSpec *task_spec = self.task_spec.get() return_id_list = [] for i in range(task_spec.NumReturns()): return_id_list.append(ObjectID(task_spec.ReturnId(i).Binary())) @@ -221,3 +220,59 @@ cdef class Task: def actor_counter(self): """Return the actor counter for this task.""" return self.task_spec.get().ActorCounter() + + +cdef class TaskExecutionSpec: + """Cython wrapper class of C++ `ray::raylet::TaskExecutionSpecification`.""" + cdef: + unique_ptr[CTaskExecutionSpec] c_spec + + def __init__(self, execution_dependencies): + cdef: + RpcTaskExecutionSpec message; + + for dependency in execution_dependencies: + message.add_dependencies( + (dependency).binary()) + self.c_spec.reset(new CTaskExecutionSpec(message)) + + @staticmethod + def from_string(const c_string& string): + """Convert a string to a Ray `TaskExecutionSpec` Python object. 
+ """ + cdef TaskExecutionSpec self = TaskExecutionSpec.__new__(TaskExecutionSpec) + self.c_spec.reset(new CTaskExecutionSpec(string)) + return self + + def dependencies(self): + cdef: + CObjectID c_id + c_vector[CObjectID] dependencies = ( + self.c_spec.get().ExecutionDependencies()) + ret = [] + for c_id in dependencies: + ret.append(ObjectID(c_id.Binary())) + return ret + + def num_forwards(self): + return self.c_spec.get().NumForwards() + + +cdef class Task: + """Cython wrapper class of C++ `ray::raylet::Task`.""" + cdef: + unique_ptr[CTask] c_task + + def __init__(self, TaskSpec task_spec, TaskExecutionSpec task_execution_spec): + self.c_task.reset(new CTask(task_spec.task_spec.get()[0], + task_execution_spec.c_spec.get()[0])) + + +def generate_gcs_task_table_data(TaskSpec task_spec): + """Converts a Python `TaskSpec` object to serialized GCS `TaskTableData`. + """ + cdef: + TaskTableData task_table_data + task_table_data.mutable_task().mutable_task_spec().CopyFrom( + task_spec.task_spec.get().GetMessage()) + return task_table_data.SerializeAsString() diff --git a/python/ray/state.py b/python/ray/state.py index 75fd8a014..c9f718fde 100644 --- a/python/ray/state.py +++ b/python/ray/state.py @@ -305,12 +305,9 @@ class GlobalState(object): assert len(gcs_entries.entries) == 1 task_table_data = gcs_utils.TaskTableData.FromString( gcs_entries.entries[0]) - task_table_message = gcs_utils.Task.GetRootAsTask( - task_table_data.task, 0) - execution_spec = task_table_message.TaskExecutionSpec() - task_spec = task_table_message.TaskSpecification() - task = ray._raylet.Task.from_string(task_spec) + task = ray._raylet.TaskSpec.from_string( + task_table_data.task.task_spec.SerializeToString()) function_descriptor_list = task.function_descriptor_list() function_descriptor = FunctionDescriptor.from_bytes_list( function_descriptor_list) @@ -335,14 +332,12 @@ class GlobalState(object): "FunctionName": function_descriptor.function_name, } + execution_spec = 
ray._raylet.TaskExecutionSpec.from_string( + task_table_data.task.task_execution_spec.SerializeToString()) return { "ExecutionSpec": { - "Dependencies": [ - execution_spec.Dependencies(i) - for i in range(execution_spec.DependenciesLength()) - ], - "LastTimestamp": execution_spec.LastTimestamp(), - "NumForwards": execution_spec.NumForwards() + "Dependencies": execution_spec.dependencies(), + "NumForwards": execution_spec.num_forwards(), }, "TaskSpec": task_spec_info } diff --git a/python/ray/worker.py b/python/ray/worker.py index 092061e84..5a3146255 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -688,7 +688,7 @@ class Worker(object): function_descriptor_list = ( function_descriptor.get_function_descriptor_list()) assert isinstance(job_id, JobID) - task = ray._raylet.Task( + task = ray._raylet.TaskSpec( job_id, function_descriptor_list, args_for_raylet, @@ -702,11 +702,10 @@ class Worker(object): actor_handle_id, actor_counter, new_actor_handles, - execution_dependencies, resources, placement_resources, ) - self.raylet_client.submit_task(task) + self.raylet_client.submit_task(task, execution_dependencies) return task.returns() @@ -1864,7 +1863,7 @@ def connect(node, nil_actor_counter = 0 function_descriptor = FunctionDescriptor.for_driver_task() - driver_task = ray._raylet.Task( + driver_task_spec = ray._raylet.TaskSpec( worker.current_job_id, function_descriptor.get_function_descriptor_list(), [], # arguments. @@ -1878,24 +1877,25 @@ def connect(node, ActorHandleID.nil(), # actor_handle_id. nil_actor_counter, # actor_counter. [], # new_actor_handles. - [], # execution_dependencies. {}, # resource_map. {}, # placement_resource_map. ) - task_table_data = ray.gcs_utils.TaskTableData() - task_table_data.task = driver_task._serialized_raylet_task() + task_table_data = ray._raylet.generate_gcs_task_table_data( + driver_task_spec) # Add the driver task to the task table. 
ray.state.state._execute_command( - driver_task.task_id(), "RAY.TABLE_ADD", + driver_task_spec.task_id(), + "RAY.TABLE_ADD", ray.gcs_utils.TablePrefix.Value("RAYLET_TASK"), ray.gcs_utils.TablePubsub.Value("RAYLET_TASK_PUBSUB"), - driver_task.task_id().binary(), - task_table_data.SerializeToString()) + driver_task_spec.task_id().binary(), + task_table_data, + ) # Set the driver's current task ID to the task ID assigned to the # driver task. - worker.task_context.current_task_id = driver_task.task_id() + worker.task_context.current_task_id = driver_task_spec.task_id() worker.raylet_client = ray._raylet.RayletClient( node.raylet_socket_name, diff --git a/python/setup.py b/python/setup.py index a0b5376ca..64e366157 100644 --- a/python/setup.py +++ b/python/setup.py @@ -28,11 +28,10 @@ ray_files = [ "ray/dashboard/res/main.css", "ray/dashboard/res/main.js" ] -# These are the directories where automatically generated Python flatbuffer +# These are the directories where automatically generated Python protobuf # bindings are created. generated_python_directories = [ - "ray/core/generated", "ray/core/generated/ray", - "ray/core/generated/ray/protocol" + "ray/core/generated", ] optional_ray_files = [] @@ -88,7 +87,7 @@ class build_ext(_build_ext.build_ext): files_to_include = ray_files + pyarrow_files + modin_files - # Copy over the autogenerated flatbuffer Python bindings. + # Copy over the autogenerated protobuf Python bindings. for directory in generated_python_directories: for filename in os.listdir(directory): if filename[-3:] == ".py": @@ -148,7 +147,6 @@ requires = [ # NOTE: Don't upgrade the version of six! Doing so causes installation # problems. See https://github.com/ray-project/ray/issues/4169. 
"six >= 1.0.0", - "flatbuffers", "faulthandler;python_version<'3.3'", "protobuf >= 3.8.0", ] diff --git a/src/ray/common/common_protocol.cc b/src/ray/common/common_protocol.cc index adce684fc..a58808a36 100644 --- a/src/ray/common/common_protocol.cc +++ b/src/ray/common/common_protocol.cc @@ -6,28 +6,6 @@ std::string string_from_flatbuf(const flatbuffers::String &string) { return std::string(string.data(), string.size()); } -const std::unordered_map map_from_flatbuf( - const flatbuffers::Vector> &resource_vector) { - std::unordered_map required_resources; - for (int64_t i = 0; i < resource_vector.size(); i++) { - const ResourcePair *resource_pair = resource_vector.Get(i); - required_resources[string_from_flatbuf(*resource_pair->key())] = - resource_pair->value(); - } - return required_resources; -} - -flatbuffers::Offset>> -map_to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, - const std::unordered_map &resource_map) { - std::vector> resource_vector; - for (auto const &resource_pair : resource_map) { - resource_vector.push_back(CreateResourcePair( - fbb, fbb.CreateString(resource_pair.first), resource_pair.second)); - } - return fbb.CreateVector(resource_vector); -} - std::vector string_vec_from_flatbuf( const flatbuffers::Vector> &flatbuf_vec) { std::vector string_vector; diff --git a/src/ray/common/common_protocol.h b/src/ray/common/common_protocol.h index 1dbd6bbc2..3eff4cd8e 100644 --- a/src/ray/common/common_protocol.h +++ b/src/ray/common/common_protocol.h @@ -1,8 +1,7 @@ #ifndef COMMON_PROTOCOL_H #define COMMON_PROTOCOL_H -#include "ray/gcs/format/gcs_generated.h" - +#include #include #include "ray/common/id.h" @@ -76,24 +75,6 @@ to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, const std::vector &ids); /// @return The std::string version of the flatbuffer string. std::string string_from_flatbuf(const flatbuffers::String &string); -/// Convert a std::unordered_map to a flatbuffer vector of pairs. -/// -/// @param fbb Reference to the flatbuffer builder. 
-/// @param resource_map A mapping from resource name to resource quantity. -/// @return A flatbuffer vector of ResourcePair objects. -flatbuffers::Offset>> -map_to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, - const std::unordered_map &resource_map); - -/// Convert a flatbuffer vector of ResourcePair objects to a std::unordered map -/// from resource name to resource quantity. -/// -/// @param fbb Reference to the flatbuffer builder. -/// @param resource_vector The flatbuffer object. -/// @return A map from resource name to resource quantity. -const std::unordered_map map_from_flatbuf( - const flatbuffers::Vector> &resource_vector); - std::vector string_vec_from_flatbuf( const flatbuffers::Vector> &flatbuf_vec); diff --git a/src/ray/core_worker/common.h b/src/ray/core_worker/common.h index ee20e88f4..d15e7c8d8 100644 --- a/src/ray/core_worker/common.h +++ b/src/ray/core_worker/common.h @@ -5,12 +5,14 @@ #include "ray/common/buffer.h" #include "ray/common/id.h" -#include "ray/gcs/format/gcs_generated.h" #include "ray/raylet/raylet_client.h" #include "ray/raylet/task_spec.h" namespace ray { +using rpc::Language; +using rpc::TaskType; + /// Type of this worker. enum class WorkerType { WORKER, DRIVER }; @@ -66,8 +68,6 @@ class TaskArg { const std::shared_ptr data_; }; -enum class TaskType { NORMAL_TASK, ACTOR_CREATION_TASK, ACTOR_TASK }; - /// Information of a task struct TaskInfo { /// The ID of task. 
diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index fe6ed290d..661996630 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -3,7 +3,7 @@ namespace ray { -CoreWorker::CoreWorker(const enum WorkerType worker_type, const ::Language language, +CoreWorker::CoreWorker(const enum WorkerType worker_type, const enum Language language, const std::string &store_socket, const std::string &raylet_socket, const JobID &job_id) : worker_type_(worker_type), diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 8a3347b2a..c61b50c78 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -7,7 +7,6 @@ #include "ray/core_worker/object_interface.h" #include "ray/core_worker/task_execution.h" #include "ray/core_worker/task_interface.h" -#include "ray/gcs/format/gcs_generated.h" #include "ray/raylet/raylet_client.h" namespace ray { @@ -23,7 +22,7 @@ class CoreWorker { /// \param[in] langauge Language of this worker. /// /// NOTE(zhijunfu): the constructor would throw if a failure happens. - CoreWorker(const WorkerType worker_type, const ::Language language, + CoreWorker(const WorkerType worker_type, const Language language, const std::string &store_socket, const std::string &raylet_socket, const JobID &job_id = JobID::Nil()); @@ -31,7 +30,7 @@ class CoreWorker { enum WorkerType WorkerType() const { return worker_type_; } /// Language of this worker. - ::Language Language() const { return language_; } + enum Language Language() const { return language_; } /// Return the `CoreWorkerTaskInterface` that contains the methods related to task /// submisson. @@ -53,7 +52,7 @@ class CoreWorker { const enum WorkerType worker_type_; /// Language of this worker. - const ::Language language_; + const enum Language language_; /// raylet socket name. 
const std::string raylet_socket_; diff --git a/src/ray/core_worker/core_worker_test.cc b/src/ray/core_worker/core_worker_test.cc index cf40e175f..3eb58595a 100644 --- a/src/ray/core_worker/core_worker_test.cc +++ b/src/ray/core_worker/core_worker_test.cc @@ -301,8 +301,7 @@ TEST_F(ZeroNodeTest, TestWorkerContext) { } TEST_F(ZeroNodeTest, TestActorHandle) { - ActorHandle handle1(ActorID::FromRandom(), ActorHandleID::FromRandom(), - ::Language::JAVA, + ActorHandle handle1(ActorID::FromRandom(), ActorHandleID::FromRandom(), Language::JAVA, {"org.ray.exampleClass", "exampleMethod", "exampleSignature"}); auto forkedHandle1 = handle1.Fork(); diff --git a/src/ray/core_worker/task_interface.cc b/src/ray/core_worker/task_interface.cc index e85ce28c9..da6bf25fa 100644 --- a/src/ray/core_worker/task_interface.cc +++ b/src/ray/core_worker/task_interface.cc @@ -8,11 +8,11 @@ namespace ray { ActorHandle::ActorHandle( const class ActorID &actor_id, const class ActorHandleID &actor_handle_id, - const ::Language actor_language, + const Language actor_language, const std::vector &actor_creation_task_function_descriptor) { inner_.set_actor_id(actor_id.Data(), actor_id.Size()); inner_.set_actor_handle_id(actor_handle_id.Data(), actor_handle_id.Size()); - inner_.set_actor_language(static_cast(actor_language)); + inner_.set_actor_language(actor_language); *inner_.mutable_actor_creation_task_function_descriptor() = { actor_creation_task_function_descriptor.begin(), actor_creation_task_function_descriptor.end()}; @@ -30,9 +30,7 @@ ray::ActorHandleID ActorHandle::ActorHandleID() const { return ActorHandleID::FromBinary(inner_.actor_handle_id()); }; -::Language ActorHandle::ActorLanguage() const { - return (::Language)inner_.actor_language(); -}; +Language ActorHandle::ActorLanguage() const { return inner_.actor_language(); }; std::vector ActorHandle::ActorCreationTaskFunctionDescriptor() const { return ray::rpc::VectorFromProtobuf(inner_.actor_creation_task_function_descriptor()); @@ 
-100,30 +98,43 @@ CoreWorkerTaskInterface::CoreWorkerTaskInterface( new CoreWorkerRayletTaskSubmitter(raylet_client))); } -Status CoreWorkerTaskInterface::SubmitTask(const RayFunction &function, - const std::vector &args, - const TaskOptions &task_options, - std::vector *return_ids) { - auto &context = worker_context_; - auto next_task_index = context.GetNextTaskIndex(); - const auto task_id = GenerateTaskId(context.GetCurrentJobID(), - context.GetCurrentTaskID(), next_task_index); +raylet::TaskSpecBuilder CoreWorkerTaskInterface::BuildCommonTaskSpec( + const RayFunction &function, const std::vector &args, uint64_t num_returns, + const std::unordered_map &required_resources, + const std::unordered_map &required_placement_resources, + std::vector *return_ids) { + raylet::TaskSpecBuilder builder; + auto next_task_index = worker_context_.GetNextTaskIndex(); + // Build common task spec. + builder.SetCommonTaskSpec( + function.language, function.function_descriptor, worker_context_.GetCurrentJobID(), + worker_context_.GetCurrentTaskID(), next_task_index, num_returns, + required_resources, required_placement_resources); + // Set task arguments. + for (const auto &arg : args) { + if (arg.IsPassedByReference()) { + builder.AddByRefArg(arg.GetReference()); + } else { + builder.AddByValueArg(arg.GetValue()->Data(), arg.GetValue()->Size()); + } + } - auto num_returns = task_options.num_returns; + // Compute return IDs. 
+ const auto task_id = TaskID::FromBinary(builder.GetMessage().task_id()); (*return_ids).resize(num_returns); for (int i = 0; i < num_returns; i++) { (*return_ids)[i] = ObjectID::ForTaskReturn(task_id, i + 1); } + return builder; +} - auto task_arguments = BuildTaskArguments(args); - - ray::raylet::TaskSpecification spec(context.GetCurrentJobID(), - context.GetCurrentTaskID(), next_task_index, - task_arguments, num_returns, task_options.resources, - function.language, function.function_descriptor); - - std::vector execution_dependencies; - TaskSpec task(std::move(spec), execution_dependencies); +Status CoreWorkerTaskInterface::SubmitTask(const RayFunction &function, + const std::vector &args, + const TaskOptions &task_options, + std::vector *return_ids) { + auto builder = BuildCommonTaskSpec(function, args, task_options.num_returns, + task_options.resources, {}, return_ids); + TaskSpec task(builder.Build(), {}); return task_submitters_[static_cast(TaskTransportType::RAYLET)]->SubmitTask(task); } @@ -131,33 +142,20 @@ Status CoreWorkerTaskInterface::CreateActor( const RayFunction &function, const std::vector &args, const ActorCreationOptions &actor_creation_options, std::unique_ptr *actor_handle) { - auto &context = worker_context_; - auto next_task_index = context.GetNextTaskIndex(); - const auto task_id = GenerateTaskId(context.GetCurrentJobID(), - context.GetCurrentTaskID(), next_task_index); - std::vector return_ids; - return_ids.push_back(ObjectID::ForTaskReturn(task_id, 1)); - ActorID actor_creation_id = ActorID::FromBinary(return_ids[0].Binary()); - *actor_handle = std::unique_ptr( - new ActorHandle(actor_creation_id, ActorHandleID::Nil(), function.language, - function.function_descriptor)); + auto builder = BuildCommonTaskSpec(function, args, 1, actor_creation_options.resources, + actor_creation_options.resources, &return_ids); + + const ActorID actor_id = ActorID::FromBinary(return_ids[0].Binary()); + builder.SetActorCreationTaskSpec(actor_id, 
actor_creation_options.max_reconstructions, + {}); + + *actor_handle = std::unique_ptr(new ActorHandle( + actor_id, ActorHandleID::Nil(), function.language, function.function_descriptor)); (*actor_handle)->IncreaseTaskCounter(); (*actor_handle)->SetActorCursor(return_ids[0]); - auto task_arguments = BuildTaskArguments(args); - - // Note that the caller is supposed to specify required placement resources - // correctly via actor_creation_options.resources. - ray::raylet::TaskSpecification spec( - context.GetCurrentJobID(), context.GetCurrentTaskID(), next_task_index, - actor_creation_id, ObjectID::Nil(), actor_creation_options.max_reconstructions, - ActorID::Nil(), ActorHandleID::Nil(), 0, {}, task_arguments, 1, - actor_creation_options.resources, actor_creation_options.resources, - function.language, function.function_descriptor); - - std::vector execution_dependencies; - TaskSpec task(std::move(spec), execution_dependencies); + const TaskSpec task(builder.Build(), {}); return task_submitters_[static_cast(TaskTransportType::RAYLET)]->SubmitTask(task); } @@ -166,65 +164,37 @@ Status CoreWorkerTaskInterface::SubmitActorTask(ActorHandle &actor_handle, const std::vector &args, const TaskOptions &task_options, std::vector *return_ids) { - auto &context = worker_context_; - auto next_task_index = context.GetNextTaskIndex(); - const auto task_id = GenerateTaskId(context.GetCurrentJobID(), - context.GetCurrentTaskID(), next_task_index); + // Add one for actor cursor object id. + const auto num_returns = task_options.num_returns + 1; - // add one for actor cursor object id. - auto num_returns = task_options.num_returns + 1; - (*return_ids).resize(num_returns); - for (int i = 0; i < num_returns; i++) { - (*return_ids)[i] = ObjectID::ForTaskReturn(task_id, i + 1); - } - - auto actor_creation_dummy_object_id = - ObjectID::FromBinary(actor_handle.ActorID().Binary()); - - auto task_arguments = BuildTaskArguments(args); + // Build common task spec. 
+ auto builder = BuildCommonTaskSpec(function, args, num_returns, task_options.resources, + {}, return_ids); std::unique_lock guard(actor_handle.mutex_); + // Build actor task spec. + const auto actor_creation_dummy_object_id = + ObjectID::FromBinary(actor_handle.ActorID().Binary()); + builder.SetActorTaskSpec(actor_handle.ActorID(), actor_handle.ActorHandleID(), + actor_creation_dummy_object_id, + actor_handle.IncreaseTaskCounter(), + actor_handle.NewActorHandles()); - ray::raylet::TaskSpecification spec( - context.GetCurrentJobID(), context.GetCurrentTaskID(), next_task_index, - ActorID::Nil(), actor_creation_dummy_object_id, 0, actor_handle.ActorID(), - actor_handle.ActorHandleID(), actor_handle.IncreaseTaskCounter(), - actor_handle.NewActorHandles(), task_arguments, num_returns, task_options.resources, - task_options.resources, function.language, function.function_descriptor); - - std::vector execution_dependencies; - execution_dependencies.push_back(actor_handle.ActorCursor()); + const TaskSpec task(builder.Build(), {actor_handle.ActorCursor()}); + // Manipulate actor handle state. auto actor_cursor = (*return_ids).back(); actor_handle.SetActorCursor(actor_cursor); actor_handle.ClearNewActorHandles(); - guard.unlock(); - TaskSpec task(std::move(spec), execution_dependencies); + // Submit task. auto status = task_submitters_[static_cast(TaskTransportType::RAYLET)]->SubmitTask(task); - // remove cursor from return ids. + // Remove cursor from return ids. 
(*return_ids).pop_back(); return status; } -std::vector> -CoreWorkerTaskInterface::BuildTaskArguments(const std::vector &args) { - std::vector> task_arguments; - for (const auto &arg : args) { - if (arg.IsPassedByReference()) { - std::vector references{arg.GetReference()}; - task_arguments.push_back( - std::make_shared(references)); - } else { - auto data = arg.GetValue(); - task_arguments.push_back( - std::make_shared(data->Data(), data->Size())); - } - } - return task_arguments; -} - } // namespace ray diff --git a/src/ray/core_worker/task_interface.h b/src/ray/core_worker/task_interface.h index aa5876c00..85c6d8ddb 100644 --- a/src/ray/core_worker/task_interface.h +++ b/src/ray/core_worker/task_interface.h @@ -9,10 +9,14 @@ #include "ray/core_worker/transport/transport.h" #include "ray/protobuf/core_worker.pb.h" #include "ray/raylet/task.h" +#include "ray/raylet/task_spec.h" +#include "ray/raylet/task_util.h" #include "ray/rpc/util.h" namespace ray { +using rpc::Language; + class CoreWorker; /// Options of a non-actor-creation task. @@ -45,7 +49,7 @@ struct ActorCreationOptions { class ActorHandle { public: ActorHandle(const ActorID &actor_id, const ActorHandleID &actor_handle_id, - const ::Language actor_language, + const Language actor_language, const std::vector &actor_creation_task_function_descriptor); ActorHandle(const ActorHandle &other); @@ -57,7 +61,7 @@ class ActorHandle { ray::ActorHandleID ActorHandleID() const; /// Language of the actor. - ::Language ActorLanguage() const; + Language ActorLanguage() const; // Function descriptor of actor creation task. std::vector ActorCreationTaskFunctionDescriptor() const; @@ -149,12 +153,21 @@ class CoreWorkerTaskInterface { std::vector *return_ids); private: - /// Build the arguments for a task spec. + /// Build common attributes of the task spec, and compute return ids. /// - /// \param[in] args Arguments of a task. - /// \return Arguments as required by task spec. 
- std::vector> BuildTaskArguments( - const std::vector &args); + /// \param[in] function The remote function to execute. + /// \param[in] args Arguments of this task. + /// \param[in] num_returns Number of returns. + /// \param[in] required_resources Resources required by this task. + /// \param[in] required_placement_resources Resources required by placing this task on a + /// node. + /// \param[out] return_ids Return IDs. + /// \return A `TaskSpecBuilder`. + raylet::TaskSpecBuilder BuildCommonTaskSpec( + const RayFunction &function, const std::vector &args, uint64_t num_returns, + const std::unordered_map &required_resources, + const std::unordered_map &required_placement_resources, + std::vector *return_ids); /// Reference to the parent CoreWorker's context. WorkerContext &worker_context_; diff --git a/src/ray/core_worker/transport/raylet_transport.cc b/src/ray/core_worker/transport/raylet_transport.cc index b63a5ef89..df9d2fbe9 100644 --- a/src/ray/core_worker/transport/raylet_transport.cc +++ b/src/ray/core_worker/transport/raylet_transport.cc @@ -38,11 +38,8 @@ CoreWorkerRayletTaskReceiver::CoreWorkerRayletTaskReceiver( void CoreWorkerRayletTaskReceiver::HandleAssignTask( const rpc::AssignTaskRequest &request, rpc::AssignTaskReply *reply, rpc::RequestDoneCallback done_callback) { - const std::string &task_message = request.task_spec(); - const raylet::Task task(*flatbuffers::GetRoot( - reinterpret_cast(task_message.data()))); + const raylet::Task task(request.task()); const auto &spec = task.GetTaskSpecification(); - auto status = task_handler_(spec); done_callback(status); } diff --git a/src/ray/gcs/client_test.cc b/src/ray/gcs/client_test.cc index fb4fdedb3..bf4bd5550 100644 --- a/src/ray/gcs/client_test.cc +++ b/src/ray/gcs/client_test.cc @@ -13,10 +13,6 @@ namespace ray { namespace gcs { -namespace { -constexpr char kRandomId[] = "abcdefghijklmnopqrst"; -} // namespace - /* Flush redis. 
*/ static inline void flushall_redis(void) { redisContext *context = redisConnect("127.0.0.1", 6379); @@ -82,23 +78,40 @@ class TestGcsWithChainAsio : public TestGcsWithAsio { TestGcsWithChainAsio() : TestGcsWithAsio(gcs::CommandType::kChain){}; }; -void TestTableLookup(const JobID &job_id, std::shared_ptr client) { - TaskID task_id = TaskID::FromRandom(); +/// A helper function that creates a GCS `TaskTableData` object. +std::shared_ptr CreateTaskTableData(const TaskID &task_id, + uint64_t num_returns = 0) { auto data = std::make_shared(); - data->set_task("123"); + data->mutable_task()->mutable_task_spec()->set_task_id(task_id.Binary()); + data->mutable_task()->mutable_task_spec()->set_num_returns(num_returns); + return data; +} + +/// A helper function that compares whether two `TaskTableData` objects are equal. +/// Note, this function only compares fields set by `CreateTaskTableData`. +bool TaskTableDataEqual(const TaskTableData &data1, const TaskTableData &data2) { + const auto &spec1 = data1.task().task_spec(); + const auto &spec2 = data2.task().task_spec(); + return (spec1.task_id() == spec2.task_id() && + spec1.num_returns() == spec2.num_returns()); +} + +void TestTableLookup(const JobID &job_id, std::shared_ptr client) { + const auto task_id = TaskID::FromRandom(); + const auto data = CreateTaskTableData(task_id); // Check that we added the correct task. auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, const TaskTableData &d) { ASSERT_EQ(id, task_id); - ASSERT_EQ(data->task(), d.task()); + ASSERT_TRUE(TaskTableDataEqual(*data, d)); }; // Check that the lookup returns the added task. 
auto lookup_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, const TaskTableData &d) { ASSERT_EQ(id, task_id); - ASSERT_EQ(data->task(), d.task()); + ASSERT_TRUE(TaskTableDataEqual(*data, d)); test->Stop(); }; @@ -386,7 +399,7 @@ void TestDeleteKeysFromTable(const JobID &job_id, auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, const TaskTableData &d) { ASSERT_EQ(id, task_id); - ASSERT_EQ(data->task(), d.task()); + ASSERT_TRUE(TaskTableDataEqual(*data, d)); test->IncrementNumCallbacks(); }; RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, add_callback)); @@ -501,9 +514,7 @@ void TestDeleteKeys(const JobID &job_id, std::shared_ptr cl std::vector> task_vector; auto AppendTaskData = [&task_vector](size_t add_count) { for (size_t i = 0; i < add_count; ++i) { - auto task_data = std::make_shared(); - task_data->set_task(ObjectID::FromRandom().Hex()); - task_vector.push_back(task_data); + task_vector.push_back(CreateTaskTableData(TaskID::FromRandom())); } }; AppendTaskData(1); @@ -682,25 +693,26 @@ TEST_F(TestGcsWithAsio, TestSetSubscribeAll) { void TestTableSubscribeId(const JobID &job_id, std::shared_ptr client) { + int num_modifications = 3; + // Add a table entry. TaskID task_id1 = TaskID::FromRandom(); - std::vector task_specs1 = {"abc", "def", "ghi"}; // Add a table entry at a second key. TaskID task_id2 = TaskID::FromRandom(); - std::vector task_specs2 = {"jkl", "mno", "pqr"}; // The callback for a notification from the table. This should only be // received for keys that we requested notifications for. - auto notification_callback = [task_id2, task_specs2](gcs::AsyncGcsClient *client, - const TaskID &id, - const TaskTableData &data) { + auto notification_callback = [task_id2, num_modifications](gcs::AsyncGcsClient *client, + const TaskID &id, + const TaskTableData &data) { // Check that we only get notifications for the requested key. 
ASSERT_EQ(id, task_id2); // Check that we get notifications in the same order as the writes. - ASSERT_EQ(data.task(), task_specs2[test->NumCallbacks()]); + ASSERT_TRUE( + TaskTableDataEqual(data, *CreateTaskTableData(task_id2, test->NumCallbacks()))); test->IncrementNumCallbacks(); - if (test->NumCallbacks() == task_specs2.size()) { + if (test->NumCallbacks() == num_modifications) { test->Stop(); } }; @@ -717,21 +729,19 @@ void TestTableSubscribeId(const JobID &job_id, // The callback for subscription success. Once we've subscribed, request // notifications for only one of the keys, then write to both keys. - auto subscribe_callback = [job_id, task_id1, task_id2, task_specs1, - task_specs2](gcs::AsyncGcsClient *client) { + auto subscribe_callback = [job_id, task_id1, task_id2, + num_modifications](gcs::AsyncGcsClient *client) { // Request notifications for one of the keys. RAY_CHECK_OK(client->raylet_task_table().RequestNotifications( job_id, task_id2, client->client_table().GetLocalClientId())); // Write both keys. We should only receive notifications for the key that // we requested them for. - for (const auto &task_spec : task_specs1) { - auto data = std::make_shared(); - data->set_task(task_spec); + for (uint64_t i = 0; i < num_modifications; i++) { + auto data = CreateTaskTableData(task_id1, i); RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id1, data, nullptr)); } - for (const auto &task_spec : task_specs2) { - auto data = std::make_shared(); - data->set_task(task_spec); + for (uint64_t i = 0; i < num_modifications; i++) { + auto data = CreateTaskTableData(task_id2, i); RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id2, data, nullptr)); } }; @@ -749,7 +759,7 @@ void TestTableSubscribeId(const JobID &job_id, ASSERT_TRUE(failure_notification_received); // Check that we received one notification callback for each write to the // requested key. 
- ASSERT_EQ(test->NumCallbacks(), task_specs2.size()); + ASSERT_EQ(test->NumCallbacks(), num_modifications); } TEST_MACRO(TestGcsWithAsio, TestTableSubscribeId); @@ -910,10 +920,9 @@ TEST_F(TestGcsWithAsio, TestSetSubscribeId) { void TestTableSubscribeCancel(const JobID &job_id, std::shared_ptr client) { // Add a table entry. - TaskID task_id = TaskID::FromRandom(); - std::vector task_specs = {"jkl", "mno", "pqr"}; - auto data = std::make_shared(); - data->set_task(task_specs[0]); + const auto task_id = TaskID::FromRandom(); + const int num_modifications = 3; + const auto data = CreateTaskTableData(task_id, 0); RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, nullptr)); // The failure callback should not be called since all keys are non-empty @@ -924,26 +933,26 @@ void TestTableSubscribeCancel(const JobID &job_id, // The callback for a notification from the table. This should only be // received for keys that we requested notifications for. - auto notification_callback = [task_id, task_specs](gcs::AsyncGcsClient *client, - const TaskID &id, - const TaskTableData &data) { + auto notification_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id, + const TaskTableData &data) { ASSERT_EQ(id, task_id); // Check that we only get notifications for the first and last writes, // since notifications are canceled in between. if (test->NumCallbacks() == 0) { - ASSERT_EQ(data.task(), task_specs.front()); + ASSERT_TRUE(TaskTableDataEqual(data, *CreateTaskTableData(task_id, 0))); } else { - ASSERT_EQ(data.task(), task_specs.back()); + ASSERT_TRUE( + TaskTableDataEqual(data, *CreateTaskTableData(task_id, num_modifications - 1))); } test->IncrementNumCallbacks(); - if (test->NumCallbacks() == 2) { + if (test->NumCallbacks() == num_modifications - 1) { test->Stop(); } }; // The callback for a notification from the table. This should only be // received for keys that we requested notifications for. 
- auto subscribe_callback = [job_id, task_id, task_specs](gcs::AsyncGcsClient *client) { + auto subscribe_callback = [job_id, task_id](gcs::AsyncGcsClient *client) { // Request notifications, then cancel immediately. We should receive a // notification for the current value at the key. RAY_CHECK_OK(client->raylet_task_table().RequestNotifications( @@ -952,10 +961,8 @@ void TestTableSubscribeCancel(const JobID &job_id, job_id, task_id, client->client_table().GetLocalClientId())); // Write to the key. Since we canceled notifications, we should not receive // a notification for these writes. - auto remaining = std::vector(++task_specs.begin(), task_specs.end()); - for (const auto &task_spec : remaining) { - auto data = std::make_shared(); - data->set_task(task_spec); + for (uint64_t i = 1; i < num_modifications; i++) { + auto data = CreateTaskTableData(task_id, i); RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, nullptr)); } // Request notifications again. We should receive a notification for the diff --git a/src/ray/gcs/format/gcs.fbs b/src/ray/gcs/format/gcs.fbs deleted file mode 100644 index 4ab9f385a..000000000 --- a/src/ray/gcs/format/gcs.fbs +++ /dev/null @@ -1,105 +0,0 @@ -// TODO(hchen): Migrate data structures in this file to protobuf (`gcs.proto`). - -enum Language:int { - PYTHON=0, - JAVA=1, - CPP=2, -} - -table Arg { - // Object ID for pass-by-reference arguments. Normally there is only one - // object ID in this list which represents the object that is being passed. - // However to support reducers in a MapReduce workload, we also support - // passing multiple object IDs for each argument. - // Note that this is a long string that concatenate all of the object IDs. - object_ids: string; - // Data for pass-by-value arguments. - data: string; -} - -table TaskInfo { - // ID of the job that created this task. - job_id: string; - // Task ID of the task. - task_id: string; - // Task ID of the parent task. 
- parent_task_id: string; - // A count of the number of tasks submitted by the parent task before this one. - parent_counter: int; - // The ID of the actor to create if this is an actor creation task. - actor_creation_id: string; - // The dummy object ID of the actor creation task if this is an actor method. - actor_creation_dummy_object_id: string; - // The max number of times this actor should be recontructed. - // If this number of 0 or negative, the actor won't be reconstructed on failure. - max_actor_reconstructions: int; - // Actor ID of the task. This is the actor that this task is executed on - // or NIL_ACTOR_ID if the task is just a normal task. - actor_id: string; - // The ID of the handle that was used to submit the task. This should be - // unique across handles with the same actor_id. - actor_handle_id: string; - // Number of tasks that have been submitted to this actor so far. - actor_counter: int; - // If this is an actor task, then this will be populated with all of the new - // actor handles that were forked from this handle since the last task on - // this handle was submitted. - // Note that this is a long string that concatenate all of the new_actor_handle IDs. - new_actor_handles: string; - // Task arguments. - args: [Arg]; - // Number of return objects. - num_returns: int; - // The required_resources vector indicates the quantities of the different - // resources required by this task. - required_resources: [ResourcePair]; - // The resources required for placing this task on a node. If this is empty, - // then the placement resources are equal to the required_resources. - required_placement_resources: [ResourcePair]; - // The language that this task belongs to. - language: Language; - // Function descriptor, which is a list of strings that can - // uniquely describe a function. 
- // For a Python function, it should be: [module_name, class_name, function_name] - // For a Java function, it should be: [class_name, method_name, type_descriptor] - function_descriptor: [string]; - // The dynamic options used in the worker command when starting the worker process for - // an actor creation task. If the list isn't empty, the options will be used to replace - // the placeholder strings (`RAY_WORKER_OPTION_0`, `RAY_WORKER_OPTION_1`, etc) in the - // worker command. - dynamic_worker_options: [string]; -} - -table ResourcePair { - // The name of the resource. - key: string; - // The quantity of the resource. - value: double; -} - -table ProfileEvent { - // The type of the event. - event_type: string; - // The start time of the event. - start_time: double; - // The end time of the event. If the event is a point event, then this should - // be the same as the start time. - end_time: double; - // Additional data associated with the event. This data must be serialized - // using JSON. - extra_data: string; -} - -table ProfileTableData { - // The type of the component that generated the event, e.g., worker or - // object_manager, or node_manager. - component_type: string; - // An identifier for the component that generated the event. - component_id: string; - // An identifier for the node that generated the event. - node_ip_address: string; - // This is a batch of profiling events. We batch these together for - // performance reasons because a single task may generate many events, and - // we don't want each event to require a GCS command. 
- profile_events: [ProfileEvent]; -} diff --git a/src/ray/gcs/format/util.h b/src/ray/gcs/format/util.h deleted file mode 100644 index 85bc0511d..000000000 --- a/src/ray/gcs/format/util.h +++ /dev/null @@ -1,24 +0,0 @@ -#ifndef RAY_RAYLET_GCS_FORMAT_UTIL_H -#define RAY_RAYLET_GCS_FORMAT_UTIL_H - -#include "ray/gcs/format/gcs_generated.h" - -namespace std { - -template <> -struct hash { - size_t operator()(const Language &language) const { - return std::hash()(static_cast(language)); - } -}; - -template <> -struct hash { - size_t operator()(const Language &language) const { - return std::hash()(static_cast(language)); - } -}; - -} // namespace std - -#endif // RAY_RAYLET_GCS_FORMAT_UTIL_H diff --git a/src/ray/gcs/redis_module/ray_redis_module.cc b/src/ray/gcs/redis_module/ray_redis_module.cc index c3a82c320..197f36179 100644 --- a/src/ray/gcs/redis_module/ray_redis_module.cc +++ b/src/ray/gcs/redis_module/ray_redis_module.cc @@ -4,7 +4,6 @@ #include "ray/common/common_protocol.h" #include "ray/common/id.h" #include "ray/common/status.h" -#include "ray/gcs/format/gcs_generated.h" #include "ray/protobuf/gcs.pb.h" #include "ray/util/logging.h" #include "redis_string.h" diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index d3135288e..587ab937e 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -40,7 +40,8 @@ namespace gcs { template Status Log::Append(const JobID &job_id, const ID &id, - std::shared_ptr &data, const WriteCallback &done) { + const std::shared_ptr &data, + const WriteCallback &done) { num_appends_++; auto callback = [this, id, data, done](const CallbackReply &reply) { const auto status = reply.ReadAsStatus(); @@ -59,8 +60,9 @@ Status Log::Append(const JobID &job_id, const ID &id, template Status Log::AppendAt(const JobID &job_id, const ID &id, - std::shared_ptr &data, const WriteCallback &done, - const WriteCallback &failure, int log_length) { + const std::shared_ptr &data, + const WriteCallback &done, const WriteCallback 
&failure, + int log_length) { num_appends_++; auto callback = [this, id, data, done, failure](const CallbackReply &reply) { const auto status = reply.ReadAsStatus(); @@ -226,7 +228,8 @@ std::string Log::DebugString() const { template Status Table::Add(const JobID &job_id, const ID &id, - std::shared_ptr &data, const WriteCallback &done) { + const std::shared_ptr &data, + const WriteCallback &done) { num_adds_++; auto callback = [this, id, data, done](const CallbackReply &reply) { if (done != nullptr) { @@ -288,8 +291,8 @@ std::string Table::DebugString() const { } template -Status Set::Add(const JobID &job_id, const ID &id, std::shared_ptr &data, - const WriteCallback &done) { +Status Set::Add(const JobID &job_id, const ID &id, + const std::shared_ptr &data, const WriteCallback &done) { num_adds_++; auto callback = [this, id, data, done](const CallbackReply &reply) { if (done != nullptr) { @@ -303,7 +306,8 @@ Status Set::Add(const JobID &job_id, const ID &id, std::shared_ptr Status Set::Remove(const JobID &job_id, const ID &id, - std::shared_ptr &data, const WriteCallback &done) { + const std::shared_ptr &data, + const WriteCallback &done) { num_removes_++; auto callback = [this, id, data, done](const CallbackReply &reply) { if (done != nullptr) { diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index f8f5ab3dc..364db09dc 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -67,10 +67,10 @@ class LogInterface { public: using WriteCallback = std::function; - virtual Status Append(const JobID &job_id, const ID &id, std::shared_ptr &data, - const WriteCallback &done) = 0; + virtual Status Append(const JobID &job_id, const ID &id, + const std::shared_ptr &data, const WriteCallback &done) = 0; virtual Status AppendAt(const JobID &job_id, const ID &task_id, - std::shared_ptr &data, const WriteCallback &done, + const std::shared_ptr &data, const WriteCallback &done, const WriteCallback &failure, int log_length) = 0; virtual ~LogInterface(){}; }; @@ 
-126,7 +126,7 @@ class Log : public LogInterface, virtual public PubsubInterface { /// \param done Callback that is called once the data has been written to the /// GCS. /// \return Status - Status Append(const JobID &job_id, const ID &id, std::shared_ptr &data, + Status Append(const JobID &job_id, const ID &id, const std::shared_ptr &data, const WriteCallback &done); /// Append a log entry to a key if and only if the log has the given number @@ -141,7 +141,7 @@ class Log : public LogInterface, virtual public PubsubInterface { /// \param log_length The number of entries that the log must have for the /// append to succeed. /// \return Status - Status AppendAt(const JobID &job_id, const ID &id, std::shared_ptr &data, + Status AppendAt(const JobID &job_id, const ID &id, const std::shared_ptr &data, const WriteCallback &done, const WriteCallback &failure, int log_length); @@ -272,8 +272,8 @@ template class TableInterface { public: using WriteCallback = typename Log::WriteCallback; - virtual Status Add(const JobID &job_id, const ID &task_id, std::shared_ptr &data, - const WriteCallback &done) = 0; + virtual Status Add(const JobID &job_id, const ID &task_id, + const std::shared_ptr &data, const WriteCallback &done) = 0; virtual ~TableInterface(){}; }; @@ -315,7 +315,7 @@ class Table : private Log, /// \param done Callback that is called once the data has been written to the /// GCS. /// \return Status - Status Add(const JobID &job_id, const ID &id, std::shared_ptr &data, + Status Add(const JobID &job_id, const ID &id, const std::shared_ptr &data, const WriteCallback &done); /// Lookup an entry asynchronously. 
@@ -378,10 +378,10 @@ template class SetInterface { public: using WriteCallback = typename Log::WriteCallback; - virtual Status Add(const JobID &job_id, const ID &id, std::shared_ptr &data, + virtual Status Add(const JobID &job_id, const ID &id, const std::shared_ptr &data, const WriteCallback &done) = 0; - virtual Status Remove(const JobID &job_id, const ID &id, std::shared_ptr &data, - const WriteCallback &done) = 0; + virtual Status Remove(const JobID &job_id, const ID &id, + const std::shared_ptr &data, const WriteCallback &done) = 0; virtual ~SetInterface(){}; }; @@ -420,7 +420,7 @@ class Set : private Log, /// \param done Callback that is called once the data has been written to the /// GCS. /// \return Status - Status Add(const JobID &job_id, const ID &id, std::shared_ptr &data, + Status Add(const JobID &job_id, const ID &id, const std::shared_ptr &data, const WriteCallback &done); /// Remove an entry from the set. @@ -431,7 +431,7 @@ class Set : private Log, /// \param done Callback that is called once the data has been written to the /// GCS. /// \return Status - Status Remove(const JobID &job_id, const ID &id, std::shared_ptr &data, + Status Remove(const JobID &job_id, const ID &id, const std::shared_ptr &data, const WriteCallback &done); Status Subscribe(const JobID &job_id, const ClientID &client_id, @@ -695,7 +695,8 @@ class TaskLeaseTable : public Table { prefix_ = TablePrefix::TASK_LEASE; } - Status Add(const JobID &job_id, const TaskID &id, std::shared_ptr &data, + Status Add(const JobID &job_id, const TaskID &id, + const std::shared_ptr &data, const WriteCallback &done) override { RAY_RETURN_NOT_OK((Table::Add(job_id, id, data, done))); // Mark the entry for expiration in Redis. 
It's okay if this command fails diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto new file mode 100644 index 000000000..4bee921c0 --- /dev/null +++ b/src/ray/protobuf/common.proto @@ -0,0 +1,123 @@ +syntax = "proto3"; + +package ray.rpc; + +option java_package = "org.ray.runtime.generated"; + +// Language of a task or worker. +enum Language { + PYTHON = 0; + JAVA = 1; + CPP = 2; +} + +// Type of a task. +enum TaskType { + // Normal task. + NORMAL_TASK = 0; + // Actor creation task. + ACTOR_CREATION_TASK = 1; + // Actor task. + ACTOR_TASK = 2; +} + +/// The task specification encapsulates all immutable information about the +/// task. These fields are determined at submission time, in contrast to the +/// `TaskExecutionSpec` fields, which may change at execution time. +message TaskSpec { + // Type of this task. + TaskType type = 1; + // Language of this task. + Language language = 2; + // Function descriptor of this task, which is a list of strings that can + // uniquely describe the function to execute. + // For a Python function, it should be: [module_name, class_name, function_name] + // For a Java function, it should be: [class_name, method_name, type_descriptor] + repeated bytes function_descriptor = 3; + // ID of the job that this task belongs to. + bytes job_id = 4; + // Task ID of the task. + bytes task_id = 5; + // Task ID of the parent task. + bytes parent_task_id = 6; + // A count of the number of tasks submitted by the parent task before this one. + uint64 parent_counter = 7; + // Task arguments. + repeated TaskArg args = 8; + // Number of return objects. + uint64 num_returns = 9; + // Quantities of the different resources required by this task. + map required_resources = 10; + // The resources required for placing this task on a node. If this is empty, + // then the placement resources are equal to the required_resources. + map required_placement_resources = 11; + // Task specification for an actor creation task. 
+ // This field is only valid when `type == ACTOR_CREATION_TASK`. + ActorCreationTaskSpec actor_creation_task_spec = 14; + // Task specification for an actor task. + // This field is only valid when `type == ACTOR_TASK`. + ActorTaskSpec actor_task_spec = 15; +} + +// Argument in the task. +message TaskArg { + // Object IDs for pass-by-reference arguments. Normally there is only one + // object ID in this list which represents the object that is being passed. + // However to support reducers in a MapReduce workload, we also support + // passing multiple object IDs for each argument. + repeated bytes object_ids = 1; + // Data for pass-by-value arguments. + bytes data = 2; +} + +// Task spec of an actor creation task. +message ActorCreationTaskSpec { + // ID of the actor that will be created by this task. + bytes actor_id = 2; + // The max number of times this actor should be reconstructed. + // If this number is 0, the actor won't be reconstructed on failure. + uint64 max_actor_reconstructions = 3; + // The dynamic options used in the worker command when starting the worker process for + // an actor creation task. If the list isn't empty, the options will be used to replace + // the placeholder strings (`RAY_WORKER_OPTION_0`, `RAY_WORKER_OPTION_1`, etc) in the + // worker command. + repeated string dynamic_worker_options = 4; +} + +// Task spec of an actor task. +message ActorTaskSpec { + // Actor ID of the task. This is the actor that this task is executed on + // or NIL_ACTOR_ID if the task is just a normal task. + bytes actor_id = 2; + // The ID of the handle that was used to submit the task. This should be + // unique across handles with the same actor_id. + bytes actor_handle_id = 3; + // The dummy object ID of the actor creation task if this is an actor method. + bytes actor_creation_dummy_object_id = 4; + // Number of tasks that have been submitted to this actor so far. 
+ uint64 actor_counter = 5; + // If this is an actor task, then this will be populated with all of the new + // actor handles that were forked from this handle since the last task on + // this handle was submitted. + // Note that this is a long string that concatenates all of the new_actor_handle IDs. + repeated bytes new_actor_handles = 6; +} + +// The task execution specification encapsulates all mutable information about +// the task. These fields may change at execution time, in contrast to the +// `TaskSpec`, which is determined at submission time. +message TaskExecutionSpec { + // A list of object IDs representing the dependencies of this task that may + // change at execution time. + repeated bytes dependencies = 1; + // The last time this task was received for scheduling. + double last_timestamp = 2; + // The number of times this task was spilled back by raylets. + uint64 num_forwards = 3; +} + +// Represents a task, including task spec, and task execution spec. +message Task { + TaskSpec task_spec = 1; + TaskExecutionSpec task_execution_spec = 2; +} diff --git a/src/ray/protobuf/core_worker.proto b/src/ray/protobuf/core_worker.proto index e698b4602..0b238f456 100644 --- a/src/ray/protobuf/core_worker.proto +++ b/src/ray/protobuf/core_worker.proto @@ -2,6 +2,8 @@ syntax = "proto3"; package ray.rpc; +import "src/ray/protobuf/common.proto"; + message ActorHandle { // ID of the actor. bytes actor_id = 1; @@ -10,7 +12,7 @@ message ActorHandle { bytes actor_handle_id = 2; // Language of the actor. - int32 actor_language = 3; + Language actor_language = 3; // Function descriptor of actor creation task. 
repeated string actor_creation_task_function_descriptor = 4; diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto index b33e4afb8..1dfb1a39a 100644 --- a/src/ray/protobuf/gcs.proto +++ b/src/ray/protobuf/gcs.proto @@ -2,14 +2,9 @@ syntax = "proto3"; package ray.rpc; -option java_package = "org.ray.runtime.generated"; +import "src/ray/protobuf/common.proto"; -// Language of a worker or task. -enum Language { - PYTHON = 0; - CPP = 1; - JAVA = 2; -} +option java_package = "org.ray.runtime.generated"; // These indexes are mapped to strings in ray_redis_module.cc. enum TablePrefix { @@ -77,12 +72,8 @@ message TaskReconstructionData { bytes node_manager_id = 2; } -// TODO(hchen): Task table currently still uses flatbuffers-defined data structure -// (`Task` in `node_manager.fbs`), because a lot of code depends on that. This should -// be migrated to protobuf very soon. message TaskTableData { - // Flatbuffers-serialized content of the task, see `src/ray/raylet/task.h`. - bytes task = 1; + Task task = 1; } message ActorTableData { diff --git a/src/ray/protobuf/node_manager.proto b/src/ray/protobuf/node_manager.proto index 8a82da1c7..59d2ce18f 100644 --- a/src/ray/protobuf/node_manager.proto +++ b/src/ray/protobuf/node_manager.proto @@ -2,16 +2,14 @@ syntax = "proto3"; package ray.rpc; +import "src/ray/protobuf/common.proto"; + message ForwardTaskRequest { // The ID of the task to be forwarded. bytes task_id = 1; // The tasks in the uncommitted lineage of the forwarded task. This // should include task_id. - // TODO(hchen): Currently, `uncommitted_tasks` are represented as - // flatbutters-serialized bytes. This is because the flatbuffers-defined Task data - // structure is being used in many places. We should move Task and all related data - // strucutres to protobuf. 
- repeated bytes uncommitted_tasks = 2; + repeated Task uncommitted_tasks = 2; } message ForwardTaskReply { diff --git a/src/ray/protobuf/worker.proto b/src/ray/protobuf/worker.proto index b63782925..c2affa1ac 100644 --- a/src/ray/protobuf/worker.proto +++ b/src/ray/protobuf/worker.proto @@ -2,15 +2,11 @@ syntax = "proto3"; package ray.rpc; +import "src/ray/protobuf/common.proto"; + message AssignTaskRequest { - // The ID of the task to be pushed. - bytes task_id = 1; - // The task to be pushed. This should include task_id. - // TODO(hchen): Currently, `task_spec` are represented as - // flatbutters-serialized bytes. This is because the flatbuffers-defined Task data - // structure is being used in many places. We should move Task and all related data - // structures to protobuf. - bytes task_spec = 2; + // The task to be pushed. + Task task = 1; } message AssignTaskReply { diff --git a/src/ray/raylet/format/node_manager.fbs b/src/ray/raylet/format/node_manager.fbs index be2a6ead5..dad8c8248 100644 --- a/src/ray/raylet/format/node_manager.fbs +++ b/src/ray/raylet/format/node_manager.fbs @@ -1,8 +1,5 @@ // raylet protocol specification -include "gcs.fbs"; - - // TODO(swang): We put the flatbuffer types in a separate namespace for now to // avoid conflicts with legacy Ray types. namespace ray.protocol; @@ -137,7 +134,8 @@ table RegisterClientRequest { // The job ID if the client is a driver, otherwise it should be NIL. job_id: string; // Language of this worker. - language: Language; + // TODO(hchen): Use `Language` in `common.proto`. + language: int; // Port that this worker is listening on. // If port > 0, then worker will listen to this port and wait for // raylet to push tasks, instead of invoking GetTask(). 
diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc index 8e7750aae..a37254222 100644 --- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc +++ b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc @@ -59,8 +59,7 @@ JNIEXPORT jlong JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeInit( * Signature: (J[BLjava/nio/ByteBuffer;II)V */ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmitTask( - JNIEnv *env, jclass, jlong client, jbyteArray cursorId, jobject taskBuff, jint pos, - jint taskSize) { + JNIEnv *env, jclass, jlong client, jbyteArray cursorId, jbyteArray taskSpec) { auto raylet_client = reinterpret_cast(client); std::vector execution_dependencies; @@ -69,8 +68,13 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmit execution_dependencies.push_back(cursor_id.GetId()); } - auto data = reinterpret_cast(env->GetDirectBufferAddress(taskBuff)) + pos; - ray::raylet::TaskSpecification task_spec(data, taskSize); + jbyte *data = env->GetByteArrayElements(taskSpec, NULL); + jsize size = env->GetArrayLength(taskSpec); + ray::rpc::TaskSpec task_spec_message; + task_spec_message.ParseFromArray(data, size); + env->ReleaseByteArrayElements(taskSpec, data, JNI_ABORT); + + ray::raylet::TaskSpecification task_spec(task_spec_message); auto status = raylet_client->SubmitTask(execution_dependencies, task_spec); ThrowRayExceptionIfNotOK(env, status); } @@ -90,24 +94,16 @@ JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_native return nullptr; } - // We serialize the task specification using flatbuffers and then parse the - // resulting string. This awkwardness is due to the fact that the Java - // implementation does not use the underlying C++ TaskSpecification class. 
- flatbuffers::FlatBufferBuilder fbb; - auto message = spec->ToFlatbuffer(fbb); - fbb.Finish(message); - auto task_message = flatbuffers::GetRoot(fbb.GetBufferPointer()); + // Serialize the task spec and copy to Java byte array. + auto task_data = spec->Serialize(); - jbyteArray result; - result = env->NewByteArray(task_message->size()); + jbyteArray result = env->NewByteArray(task_data.size()); if (result == nullptr) { return nullptr; /* out of memory error thrown */ } - // move from task spec structure to the java structure - env->SetByteArrayRegion( - result, 0, task_message->size(), - reinterpret_cast(const_cast(task_message->data()))); + env->SetByteArrayRegion(result, 0, task_data.size(), + reinterpret_cast(task_data.data())); return result; } diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h index 91338a12e..414f916d1 100644 --- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h +++ b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h @@ -20,10 +20,10 @@ JNIEXPORT jlong JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeInit( /* * Class: org_ray_runtime_raylet_RayletClientImpl * Method: nativeSubmitTask - * Signature: (J[BLjava/nio/ByteBuffer;II)V + * Signature: (J[B[B)V */ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmitTask( - JNIEnv *, jclass, jlong, jbyteArray, jobject, jint, jint); + JNIEnv *, jclass, jlong, jbyteArray, jbyteArray); /* * Class: org_ray_runtime_raylet_RayletClientImpl diff --git a/src/ray/raylet/lineage_cache.cc b/src/ray/raylet/lineage_cache.cc index 6f6d2cb69..986f735f7 100644 --- a/src/ray/raylet/lineage_cache.cc +++ b/src/ray/raylet/lineage_cache.cc @@ -272,9 +272,11 @@ void LineageCache::FlushTask(const TaskID &task_id) { [this](ray::gcs::AsyncGcsClient *client, const TaskID &id, const TaskTableData &data) { HandleEntryCommitted(id); }; auto task = 
lineage_.GetEntry(task_id); - // TODO(swang): Make this better... auto task_data = std::make_shared(); - task_data->set_task(task->TaskData().Serialize()); + task_data->mutable_task()->mutable_task_spec()->CopyFrom( + task->TaskData().GetTaskSpecification().GetMessage()); + task_data->mutable_task()->mutable_task_execution_spec()->CopyFrom( + task->TaskData().GetTaskExecutionSpec().GetMessage()); RAY_CHECK_OK(task_storage_.Add(JobID(task->TaskData().GetTaskSpecification().JobId()), task_id, task_data, task_callback)); diff --git a/src/ray/raylet/lineage_cache_test.cc b/src/ray/raylet/lineage_cache_test.cc index 1ecc13599..ad3f1d7a2 100644 --- a/src/ray/raylet/lineage_cache_test.cc +++ b/src/ray/raylet/lineage_cache_test.cc @@ -8,6 +8,7 @@ #include "ray/raylet/task.h" #include "ray/raylet/task_execution_spec.h" #include "ray/raylet/task_spec.h" +#include "ray/raylet/task_util.h" namespace ray { @@ -23,7 +24,7 @@ class MockGcs : public gcs::TableInterface, } Status Add(const JobID &job_id, const TaskID &task_id, - std::shared_ptr &task_data, + const std::shared_ptr &task_data, const gcs::TableInterface::WriteCallback &done) { task_table_[task_id] = task_data; auto callback = done; @@ -125,21 +126,16 @@ class LineageCacheTest : public ::testing::Test { }; static inline Task ExampleTask(const std::vector &arguments, - int64_t num_returns) { - std::unordered_map required_resources; - std::vector> task_arguments; - for (auto &argument : arguments) { - std::vector references = {argument}; - task_arguments.emplace_back(std::make_shared(references)); + uint64_t num_returns) { + TaskSpecBuilder builder; + builder.SetCommonTaskSpec(Language::PYTHON, {"", "", ""}, JobID::Nil(), + TaskID::FromRandom(), 0, num_returns, {}, {}); + for (const auto &arg : arguments) { + builder.AddByRefArg(arg); } - std::vector function_descriptor(3); - auto spec = TaskSpecification(JobID::Nil(), TaskID::FromRandom(), 0, task_arguments, - num_returns, required_resources, Language::PYTHON, - 
function_descriptor); - auto execution_spec = TaskExecutionSpecification(std::vector()); - execution_spec.IncrementNumForwards(); - Task task = Task(execution_spec, spec); - return task; + rpc::TaskExecutionSpec execution_spec_message; + execution_spec_message.set_num_forwards(1); + return Task(builder.Build(), TaskExecutionSpecification(execution_spec_message)); } /// Helper method to create a Lineage object with a single task. diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc index 8da446d7a..664b2a513 100644 --- a/src/ray/raylet/main.cc +++ b/src/ray/raylet/main.cc @@ -2,11 +2,14 @@ #include "ray/common/ray_config.h" #include "ray/common/status.h" +#include "ray/protobuf/common.pb.h" #include "ray/raylet/raylet.h" #include "ray/stats/stats.h" #include "gflags/gflags.h" +using ray::rpc::Language; + DEFINE_string(raylet_socket_name, "", "The socket name of raylet."); DEFINE_string(store_socket_name, "", "The socket name of object store."); DEFINE_int32(object_manager_port, -1, "The port of object manager."); diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index ebc07717e..f9c53d99a 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -797,19 +797,10 @@ void NodeManager::ProcessClientMessage( ProcessPushErrorRequestMessage(message_data); } break; case protocol::MessageType::PushProfileEventsRequest: { - ProfileTableDataT fbs_message; - flatbuffers::GetRoot(message_data)->UnPackTo(&fbs_message); + auto fbs_message = flatbuffers::GetRoot(message_data); rpc::ProfileTableData profile_table_data; - profile_table_data.set_component_type(fbs_message.component_type); - profile_table_data.set_component_id(fbs_message.component_id); - for (const auto &fbs_event : fbs_message.profile_events) { - rpc::ProfileTableData::ProfileEvent *event = - profile_table_data.add_profile_events(); - event->set_event_type(fbs_event->event_type); - event->set_start_time(fbs_event->start_time); - 
event->set_end_time(fbs_event->end_time); - event->set_extra_data(fbs_event->extra_data); - } + RAY_CHECK( + profile_table_data.ParseFromArray(fbs_message->data(), fbs_message->size())); RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(profile_table_data)); } break; case protocol::MessageType::FreeObjectsInObjectStoreRequest: { @@ -845,8 +836,9 @@ void NodeManager::ProcessRegisterClientRequestMessage( const std::shared_ptr &client, const uint8_t *message_data) { auto message = flatbuffers::GetRoot(message_data); client->SetClientID(from_flatbuf(*message->worker_id())); - auto worker = std::make_shared(message->worker_pid(), message->language(), - message->port(), client); + Language language = static_cast(message->language()); + auto worker = + std::make_shared(message->worker_pid(), language, message->port(), client); if (message->is_worker()) { // Register the new worker. worker_pool_.RegisterWorker(std::move(worker)); @@ -1050,14 +1042,18 @@ void NodeManager::ProcessDisconnectClientMessage( void NodeManager::ProcessSubmitTaskMessage(const uint8_t *message_data) { // Read the task submitted by the client. - auto message = flatbuffers::GetRoot(message_data); - TaskExecutionSpecification task_execution_spec( - from_flatbuf(*message->execution_dependencies())); - TaskSpecification task_spec(*message->task_spec()); - Task task(task_execution_spec, task_spec); + auto fbs_message = flatbuffers::GetRoot(message_data); + rpc::Task task_message; + RAY_CHECK(task_message.mutable_task_spec()->ParseFromArray( + fbs_message->task_spec()->data(), fbs_message->task_spec()->size())); + for (const auto &dependency : + string_vec_from_flatbuf(*fbs_message->execution_dependencies())) { + task_message.mutable_task_execution_spec()->add_dependencies(dependency); + } + // Submit the task to the raylet. Since the task was submitted // locally, there is no uncommitted lineage. 
- SubmitTask(task, Lineage()); + SubmitTask(Task(task_message), Lineage()); } void NodeManager::ProcessFetchOrReconstructMessage( @@ -1224,10 +1220,8 @@ void NodeManager::HandleForwardTask(const rpc::ForwardTaskRequest &request, TaskID task_id = TaskID::FromBinary(request.task_id()); Lineage uncommitted_lineage; for (int i = 0; i < request.uncommitted_tasks_size(); i++) { - const std::string &task_message = request.uncommitted_tasks(i); - const Task task(*flatbuffers::GetRoot( - reinterpret_cast(task_message.data()))); - RAY_CHECK(uncommitted_lineage.SetEntry(std::move(task), GcsStatus::UNCOMMITTED)); + Task task(request.uncommitted_tasks(i)); + RAY_CHECK(uncommitted_lineage.SetEntry(task, GcsStatus::UNCOMMITTED)); } const Task &task = uncommitted_lineage.GetEntry(task_id)->TaskData(); RAY_LOG(DEBUG) << "Received forwarded task " << task.GetTaskSpecification().TaskId() @@ -1769,7 +1763,7 @@ bool NodeManager::AssignTask(const Task &task) { worker->GetTaskResourceIds().Plus(worker->GetLifetimeResourceIds()); auto resource_id_set_flatbuf = resource_id_set.ToFlatbuf(fbb); - auto message = protocol::CreateGetTaskReply(fbb, spec.ToFlatbuffer(fbb), + auto message = protocol::CreateGetTaskReply(fbb, fbb.CreateString(spec.Serialize()), fbb.CreateVector(resource_id_set_flatbuf)); fbb.Finish(message); const auto &task_id = spec.TaskId(); @@ -2025,9 +2019,7 @@ void NodeManager::HandleTaskReconstruction(const TaskID &task_id) { const TaskTableData &task_data) { // The task was in the GCS task table. Use the stored task spec to // re-execute the task. - auto message = flatbuffers::GetRoot(task_data.task().data()); - const Task task(*message); - ResubmitTask(task); + ResubmitTask(Task(task_data.task())); }, /*failure_callback=*/ [this](ray::gcs::AsyncGcsClient *client, const TaskID &task_id) { @@ -2238,8 +2230,12 @@ void NodeManager::ForwardTask( // Prepare the request message. 
rpc::ForwardTaskRequest request; request.set_task_id(task_id.Binary()); - for (auto &entry : uncommitted_lineage.GetEntries()) { - request.add_uncommitted_tasks(entry.second.TaskData().Serialize()); + for (auto &task_entry : uncommitted_lineage.GetEntries()) { + auto task = request.add_uncommitted_tasks(); + task->mutable_task_spec()->CopyFrom( + task_entry.second.TaskData().GetTaskSpecification().GetMessage()); + task->mutable_task_execution_spec()->CopyFrom( + task_entry.second.TaskData().GetTaskExecutionSpec().GetMessage()); } // Move the FORWARDING task to the SWAP queue so that we remember that we diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index d485fff8a..582f6bd77 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -10,6 +10,7 @@ #include "ray/raylet/task.h" #include "ray/object_manager/object_manager.h" #include "ray/common/client_connection.h" +#include "ray/protobuf/common.pb.h" #include "ray/raylet/actor_registration.h" #include "ray/raylet/lineage_cache.h" #include "ray/raylet/scheduling_policy.h" @@ -31,6 +32,7 @@ using rpc::ErrorType; using rpc::HeartbeatBatchTableData; using rpc::HeartbeatTableData; using rpc::JobTableData; +using rpc::Language; struct NodeManagerConfig { /// The node's resource configuration. @@ -48,7 +50,7 @@ struct NodeManagerConfig { /// worker pool. int maximum_startup_concurrency; /// The commands used to start the worker process, grouped by language. - std::unordered_map> worker_commands; + WorkerCommandMap worker_commands; /// The time between heartbeats in milliseconds. uint64_t heartbeat_period_ms; /// The time between debug dumps in milliseconds, or -1 to disable. 
diff --git a/src/ray/raylet/raylet_client.cc b/src/ray/raylet/raylet_client.cc index da1b09e27..e6bb6b740 100644 --- a/src/ray/raylet/raylet_client.cc +++ b/src/ray/raylet/raylet_client.cc @@ -228,7 +228,7 @@ ray::Status RayletClient::SubmitTask(const std::vector &execution_depe flatbuffers::FlatBufferBuilder fbb; auto execution_dependencies_message = to_flatbuf(fbb, execution_dependencies); auto message = ray::protocol::CreateSubmitTaskRequest( - fbb, execution_dependencies_message, task_spec.ToFlatbuffer(fbb)); + fbb, execution_dependencies_message, fbb.CreateString(task_spec.Serialize())); fbb.Finish(message); return conn_->WriteMessage(MessageType::SubmitTask, &fbb); } @@ -335,9 +335,9 @@ ray::Status RayletClient::PushError(const ray::JobID &job_id, const std::string return conn_->WriteMessage(MessageType::PushErrorRequest, &fbb); } -ray::Status RayletClient::PushProfileEvents(const ProfileTableDataT &profile_events) { +ray::Status RayletClient::PushProfileEvents(const ProfileTableData &profile_events) { flatbuffers::FlatBufferBuilder fbb; - auto message = CreateProfileTableData(fbb, &profile_events); + auto message = fbb.CreateString(profile_events.SerializeAsString()); fbb.Finish(message); auto status = conn_->WriteMessage(MessageType::PushProfileEventsRequest, &fbb); diff --git a/src/ray/raylet/raylet_client.h b/src/ray/raylet/raylet_client.h index 5f887cc93..3620c9d9c 100644 --- a/src/ray/raylet/raylet_client.h +++ b/src/ray/raylet/raylet_client.h @@ -1,6 +1,7 @@ #ifndef RAYLET_CLIENT_H #define RAYLET_CLIENT_H +#include #include #include #include @@ -17,6 +18,9 @@ using ray::ObjectID; using ray::TaskID; using ray::UniqueID; +using ray::rpc::Language; +using ray::rpc::ProfileTableData; + using MessageType = ray::protocol::MessageType; using ResourceMappingType = std::unordered_map>>; @@ -138,7 +142,7 @@ class RayletClient { /// /// \param profile_events A batch of profiling event information. /// \return ray::Status. 
- ray::Status PushProfileEvents(const ProfileTableDataT &profile_events); + ray::Status PushProfileEvents(const ProfileTableData &profile_events); /// Free a list of objects from object stores. /// diff --git a/src/ray/raylet/reconstruction_policy_test.cc b/src/ray/raylet/reconstruction_policy_test.cc index 033da0561..94155b442 100644 --- a/src/ray/raylet/reconstruction_policy_test.cc +++ b/src/ray/raylet/reconstruction_policy_test.cc @@ -85,7 +85,7 @@ class MockGcs : public gcs::PubsubInterface, } void Add(const JobID &job_id, const TaskID &task_id, - std::shared_ptr &task_lease_data) { + const std::shared_ptr &task_lease_data) { task_lease_table_[task_id] = task_lease_data; if (subscribed_tasks_.count(task_id) == 1) { notification_callback_(nullptr, task_id, *task_lease_data); @@ -112,7 +112,7 @@ class MockGcs : public gcs::PubsubInterface, Status AppendAt( const JobID &job_id, const TaskID &task_id, - std::shared_ptr &task_data, + const std::shared_ptr &task_data, const ray::gcs::LogInterface::WriteCallback &success_callback, const ray::gcs::LogInterface::WriteCallback @@ -134,7 +134,7 @@ class MockGcs : public gcs::PubsubInterface, MOCK_METHOD4( Append, ray::Status( - const JobID &, const TaskID &, std::shared_ptr &, + const JobID &, const TaskID &, const std::shared_ptr &, const ray::gcs::LogInterface::WriteCallback &)); private: diff --git a/src/ray/raylet/task.cc b/src/ray/raylet/task.cc index 9d8036411..59f45d97e 100644 --- a/src/ray/raylet/task.cc +++ b/src/ray/raylet/task.cc @@ -4,24 +4,12 @@ namespace ray { namespace raylet { -flatbuffers::Offset Task::ToFlatbuffer( - flatbuffers::FlatBufferBuilder &fbb) const { - auto task = CreateTask(fbb, task_spec_.ToFlatbuffer(fbb), - task_execution_spec_.ToFlatbuffer(fbb)); - return task; -} - const TaskExecutionSpecification &Task::GetTaskExecutionSpec() const { return task_execution_spec_; } const TaskSpecification &Task::GetTaskSpecification() const { return task_spec_; } -void 
Task::SetExecutionDependencies(const std::vector &dependencies) { - task_execution_spec_.SetExecutionDependencies(dependencies); - ComputeDependencies(); -} - void Task::IncrementNumForwards() { task_execution_spec_.IncrementNumForwards(); } const std::vector &Task::GetDependencies() const { return dependencies_; } @@ -42,24 +30,10 @@ void Task::ComputeDependencies() { } void Task::CopyTaskExecutionSpec(const Task &task) { - task_execution_spec_ = task.GetTaskExecutionSpec(); + task_execution_spec_ = task.task_execution_spec_; ComputeDependencies(); } -const std::string Task::Serialize() const { - flatbuffers::FlatBufferBuilder fbb; - fbb.Finish(ToFlatbuffer(fbb)); - return std::string(fbb.GetBufferPointer(), fbb.GetBufferPointer() + fbb.GetSize()); -} - -std::string SerializeTaskAsString(const std::vector *dependencies, - const TaskSpecification *task_spec) { - std::vector execution_dependencies(*dependencies); - TaskExecutionSpecification execution_spec(std::move(execution_dependencies)); - Task task(execution_spec, *task_spec); - return task.Serialize(); -} - } // namespace raylet } // namespace ray diff --git a/src/ray/raylet/task.h b/src/ray/raylet/task.h index 10cdfe511..e1a8a878b 100644 --- a/src/ray/raylet/task.h +++ b/src/ray/raylet/task.h @@ -3,9 +3,11 @@ #include +#include "ray/protobuf/common.pb.h" #include "ray/raylet/format/node_manager_generated.h" #include "ray/raylet/task_execution_spec.h" #include "ray/raylet/task_spec.h" +#include "ray/rpc/message_wrapper.h" namespace ray { @@ -19,41 +21,22 @@ namespace raylet { /// time. class Task { public: - /// Create a task. + /// Construct a `Task` object from a protobuf message. /// - /// \param execution_spec The execution specification for the task. These are - /// the mutable fields in the task specification that may change at task - /// execution time. - /// \param task_spec The immutable specification for the task. These fields - /// are determined at task submission time. 
- Task(const TaskExecutionSpecification &execution_spec, - const TaskSpecification &task_spec) - : task_execution_spec_(execution_spec), task_spec_(task_spec) { + /// \param message The protobuf message. + explicit Task(const rpc::Task &message) + : task_spec_(message.task_spec()), + task_execution_spec_(message.task_execution_spec()) { ComputeDependencies(); } - /// Create a task from a serialized flatbuffer. - /// - /// \param task_flatbuffer The serialized task. - Task(const protocol::Task &task_flatbuffer) - : Task(*task_flatbuffer.task_execution_spec(), - *task_flatbuffer.task_specification()) {} - - /// Create a task from a flatbuffer object. - /// - /// \param task_data The task flatbuffer object. - Task(const protocol::TaskT &task_data) - : Task(*task_data.task_execution_spec, task_data.task_specification) {} - - /// Destroy the task. - virtual ~Task() {} - - /// Serialize a task to a flatbuffer. - /// - /// \param fbb The flatbuffer builder. - /// \return An offset to the serialized task. - flatbuffers::Offset ToFlatbuffer( - flatbuffers::FlatBufferBuilder &fbb) const; + /// Construct a `Task` object from a `TaskSpecification` and a + /// `TaskExecutionSpecification`. + Task(TaskSpecification task_spec, TaskExecutionSpecification task_execution_spec) + : task_spec_(std::move(task_spec)), + task_execution_spec_(std::move(task_execution_spec)) { + ComputeDependencies(); + } /// Get the mutable specification for the task. This specification may be /// updated at runtime. @@ -66,11 +49,6 @@ class Task { /// \return The immutable specification for the task. const TaskSpecification &GetTaskSpecification() const; - /// Set the task's execution dependencies. - /// - /// \param dependencies The value to set the execution dependencies to. - void SetExecutionDependencies(const std::vector &dependencies); - /// Increment the number of times this task has been forwarded. 
void IncrementNumForwards(); @@ -84,28 +62,22 @@ class Task { /// \param task Task structure with updated dynamic information. void CopyTaskExecutionSpec(const Task &task); - /// Serialize this task as a string. - const std::string Serialize() const; - private: void ComputeDependencies(); - /// Task execution specification, consisting of all dynamic/mutable - /// information about this task determined at execution time.. - TaskExecutionSpecification task_execution_spec_; /// Task specification object, consisting of immutable information about this /// task determined at submission time. Includes resource demand, object /// dependencies, etc. TaskSpecification task_spec_; + /// Task execution specification, consisting of all dynamic/mutable + /// information about this task determined at execution time. + TaskExecutionSpecification task_execution_spec_; /// A cached copy of the task's object dependencies, including arguments from /// the TaskSpecification and execution dependencies from the /// TaskExecutionSpecification. 
std::vector dependencies_; }; -std::string SerializeTaskAsString(const std::vector *dependencies, - const TaskSpecification *task_spec); - } // namespace raylet } // namespace ray diff --git a/src/ray/raylet/task_dependency_manager_test.cc b/src/ray/raylet/task_dependency_manager_test.cc index f16c17b6b..1a469c596 100644 --- a/src/ray/raylet/task_dependency_manager_test.cc +++ b/src/ray/raylet/task_dependency_manager_test.cc @@ -6,6 +6,7 @@ #include #include "ray/raylet/task_dependency_manager.h" +#include "ray/raylet/task_util.h" namespace ray { @@ -30,7 +31,7 @@ class MockGcs : public gcs::TableInterface { MOCK_METHOD4( Add, ray::Status(const JobID &job_id, const TaskID &task_id, - std::shared_ptr &task_data, + const std::shared_ptr &task_data, const gcs::TableInterface::WriteCallback &done)); }; @@ -67,21 +68,16 @@ class TaskDependencyManagerTest : public ::testing::Test { }; static inline Task ExampleTask(const std::vector &arguments, - int64_t num_returns) { - std::unordered_map required_resources; - std::vector> task_arguments; - for (auto &argument : arguments) { - std::vector references = {argument}; - task_arguments.emplace_back(std::make_shared(references)); + uint64_t num_returns) { + TaskSpecBuilder builder; + builder.SetCommonTaskSpec(Language::PYTHON, {"", "", ""}, JobID::Nil(), + TaskID::FromRandom(), 0, num_returns, {}, {}); + for (const auto &arg : arguments) { + builder.AddByRefArg(arg); } - std::vector function_descriptor(3); - auto spec = TaskSpecification(JobID::Nil(), TaskID::FromRandom(), 0, task_arguments, - num_returns, required_resources, Language::PYTHON, - function_descriptor); - auto execution_spec = TaskExecutionSpecification(std::vector()); - execution_spec.IncrementNumForwards(); - Task task = Task(execution_spec, spec); - return task; + rpc::TaskExecutionSpec execution_spec_message; + execution_spec_message.set_num_forwards(1); + return Task(builder.Build(), TaskExecutionSpecification(execution_spec_message)); } std::vector 
MakeTaskChain(int chain_size, diff --git a/src/ray/raylet/task_execution_spec.cc b/src/ray/raylet/task_execution_spec.cc index dc7bf30b8..ed7e60e2c 100644 --- a/src/ray/raylet/task_execution_spec.cc +++ b/src/ray/raylet/task_execution_spec.cc @@ -4,54 +4,16 @@ namespace ray { namespace raylet { -TaskExecutionSpecification::TaskExecutionSpecification( - const std::vector &&dependencies) { - SetExecutionDependencies(dependencies); +using rpc::IdVectorFromProtobuf; + +const std::vector TaskExecutionSpecification::ExecutionDependencies() const { + return IdVectorFromProtobuf(message_.dependencies()); } -TaskExecutionSpecification::TaskExecutionSpecification( - const std::vector &&dependencies, int num_forwards) { - // TaskExecutionSpecification(std::move(dependencies)); - SetExecutionDependencies(dependencies); - execution_spec_.num_forwards = num_forwards; -} - -flatbuffers::Offset -TaskExecutionSpecification::ToFlatbuffer(flatbuffers::FlatBufferBuilder &fbb) const { - fbb.ForceDefaults(true); - return protocol::TaskExecutionSpecification::Pack(fbb, &execution_spec_); -} - -std::vector TaskExecutionSpecification::ExecutionDependencies() const { - std::vector dependencies; - for (const auto &dependency : execution_spec_.dependencies) { - dependencies.push_back(ObjectID::FromBinary(dependency)); - } - return dependencies; -} - -void TaskExecutionSpecification::SetExecutionDependencies( - const std::vector &dependencies) { - execution_spec_.dependencies.clear(); - for (const auto &dependency : dependencies) { - execution_spec_.dependencies.push_back(dependency.Binary()); - } -} - -int TaskExecutionSpecification::NumForwards() const { - return execution_spec_.num_forwards; -} +size_t TaskExecutionSpecification::NumForwards() const { return message_.num_forwards(); } void TaskExecutionSpecification::IncrementNumForwards() { - execution_spec_.num_forwards += 1; -} - -int64_t TaskExecutionSpecification::LastTimestamp() const { - return execution_spec_.last_timestamp; -} - 
-void TaskExecutionSpecification::SetLastTimestamp(int64_t new_timestamp) { - execution_spec_.last_timestamp = new_timestamp; + message_.set_num_forwards(message_.num_forwards() + 1); } } // namespace raylet diff --git a/src/ray/raylet/task_execution_spec.h b/src/ray/raylet/task_execution_spec.h index 6fc3b833a..bdc16c8b4 100644 --- a/src/ray/raylet/task_execution_spec.h +++ b/src/ray/raylet/task_execution_spec.h @@ -4,84 +4,45 @@ #include #include "ray/common/id.h" -#include "ray/raylet/format/node_manager_generated.h" +#include "ray/protobuf/common.pb.h" +#include "ray/rpc/message_wrapper.h" +#include "ray/rpc/util.h" namespace ray { namespace raylet { -/// \class TaskExecutionSpecification -/// -/// The task execution specification encapsulates all mutable information about -/// the task. These fields may change at execution time, converse to the -/// TaskSpecification that is determined at submission time. -class TaskExecutionSpecification { +using rpc::MessageWrapper; + +/// Wrapper class of protobuf `TaskExecutionSpec`, see `common.proto` for details. +class TaskExecutionSpecification : public MessageWrapper { public: - TaskExecutionSpecification(const protocol::TaskExecutionSpecificationT &execution_spec) - : execution_spec_(execution_spec) {} - - /// Create a task execution specification. + /// Construct from a protobuf message object. + /// The input message will be **copied** into this object. /// - /// \param dependencies The task's dependencies, determined at execution - /// time. - TaskExecutionSpecification(const std::vector &&dependencies); + /// \param message The protobuf message. + explicit TaskExecutionSpecification(rpc::TaskExecutionSpec message) + : MessageWrapper(std::move(message)) {} - /// Create a task execution specification. + /// Construct from protobuf-serialized binary. /// - /// \param dependencies The task's dependencies, determined at execution - /// time. 
- /// \param num_forwards The number of times this task has been forwarded by a - /// node manager. - TaskExecutionSpecification(const std::vector &&dependencies, - int num_forwards); - - /// Create a task execution specification from a serialized flatbuffer. - /// - /// \param spec_flatbuffer The serialized specification. - TaskExecutionSpecification( - const protocol::TaskExecutionSpecification &spec_flatbuffer) { - spec_flatbuffer.UnPackTo(&execution_spec_); - } - - /// Serialize a task execution specification to a flatbuffer. - /// - /// \param fbb The flatbuffer builder. - /// \return An offset to the serialized task execution specification. - flatbuffers::Offset ToFlatbuffer( - flatbuffers::FlatBufferBuilder &fbb) const; + /// \param serialized_binary Protobuf-serialized binary. + explicit TaskExecutionSpecification(const std::string &serialized_binary) + : MessageWrapper(serialized_binary) {} /// Get the task's execution dependencies. /// /// \return A vector of object IDs representing this task's execution /// dependencies. - std::vector ExecutionDependencies() const; - - /// Set the task's execution dependencies. - /// - /// \param dependencies The value to set the execution dependencies to. - void SetExecutionDependencies(const std::vector &dependencies); + const std::vector ExecutionDependencies() const; /// Get the number of times this task has been forwarded. /// /// \return The number of times this task has been forwarded. - int NumForwards() const; + size_t NumForwards() const; /// Increment the number of times this task has been forwarded. void IncrementNumForwards(); - - /// Get the task's last timestamp. - /// - /// \return The timestamp when this task was last received for scheduling. - int64_t LastTimestamp() const; - - /// Set the task's last timestamp to the specified value. - /// - /// \param new_timestamp The new timestamp in millisecond to set the task's - /// time stamp to. Tracks the last time this task entered a raylet. 
- void SetLastTimestamp(int64_t new_timestamp); - - private: - protocol::TaskExecutionSpecificationT execution_spec_; }; } // namespace raylet diff --git a/src/ray/raylet/task_spec.cc b/src/ray/raylet/task_spec.cc index e401e5a2b..ccee1d122 100644 --- a/src/ray/raylet/task_spec.cc +++ b/src/ray/raylet/task_spec.cc @@ -1,151 +1,50 @@ -#include "task_spec.h" #include -#include "ray/common/common_protocol.h" -#include "ray/gcs/format/gcs_generated.h" +#include "ray/raylet/task_spec.h" +#include "ray/rpc/util.h" #include "ray/util/logging.h" namespace ray { namespace raylet { -TaskArgument::~TaskArgument() {} +using rpc::MapFromProtobuf; +using rpc::VectorFromProtobuf; -TaskArgumentByReference::TaskArgumentByReference(const std::vector &references) - : references_(references) {} - -flatbuffers::Offset TaskArgumentByReference::ToFlatbuffer( - flatbuffers::FlatBufferBuilder &fbb) const { - return CreateArg(fbb, ids_to_flatbuf(fbb, references_)); -} - -TaskArgumentByValue::TaskArgumentByValue(const uint8_t *value, size_t length) { - value_.assign(value, value + length); -} - -flatbuffers::Offset TaskArgumentByValue::ToFlatbuffer( - flatbuffers::FlatBufferBuilder &fbb) const { - auto arg = - fbb.CreateString(reinterpret_cast(value_.data()), value_.size()); - const auto &empty_ids = fbb.CreateString(""); - return CreateArg(fbb, empty_ids, arg); -} - -void TaskSpecification::AssignSpecification(const uint8_t *spec, size_t spec_size) { - spec_.assign(spec, spec + spec_size); - // Initialize required_resources_ and required_placement_resources_ - auto message = flatbuffers::GetRoot(spec_.data()); - auto required_resources = map_from_flatbuf(*message->required_resources()); +void TaskSpecification::ComputeResources() { + auto required_resources = MapFromProtobuf(message_.required_resources()); auto required_placement_resources = - map_from_flatbuf(*message->required_placement_resources()); - // If the required_placement_resources field is empty, then the placement - // 
resources default to the required resources. - if (required_placement_resources.size() == 0) { + MapFromProtobuf(message_.required_placement_resources()); + if (required_placement_resources.empty()) { required_placement_resources = required_resources; } required_resources_ = ResourceSet(required_resources); required_placement_resources_ = ResourceSet(required_placement_resources); } -TaskSpecification::TaskSpecification(const flatbuffers::String &string) { - AssignSpecification(reinterpret_cast(string.data()), string.size()); -} - -TaskSpecification::TaskSpecification(const std::string &string) { - AssignSpecification(reinterpret_cast(string.data()), string.size()); -} - -TaskSpecification::TaskSpecification(const uint8_t *spec, size_t spec_size) { - AssignSpecification(spec, spec_size); -} - -TaskSpecification::TaskSpecification( - const JobID &job_id, const TaskID &parent_task_id, int64_t parent_counter, - const std::vector> &task_arguments, int64_t num_returns, - const std::unordered_map &required_resources, - const Language &language, const std::vector &function_descriptor) - : TaskSpecification(job_id, parent_task_id, parent_counter, ActorID::Nil(), - ObjectID::Nil(), 0, ActorID::Nil(), ActorHandleID::Nil(), -1, {}, - task_arguments, num_returns, required_resources, - std::unordered_map(), language, - function_descriptor) {} - -TaskSpecification::TaskSpecification( - const JobID &job_id, const TaskID &parent_task_id, int64_t parent_counter, - const ActorID &actor_creation_id, const ObjectID &actor_creation_dummy_object_id, - const int64_t max_actor_reconstructions, const ActorID &actor_id, - const ActorHandleID &actor_handle_id, int64_t actor_counter, - const std::vector &new_actor_handles, - const std::vector> &task_arguments, int64_t num_returns, - const std::unordered_map &required_resources, - const std::unordered_map &required_placement_resources, - const Language &language, const std::vector &function_descriptor, - const std::vector 
&dynamic_worker_options) - : spec_() { - flatbuffers::FlatBufferBuilder fbb; - - TaskID task_id = GenerateTaskId(job_id, parent_task_id, parent_counter); - // Add argument object IDs. - std::vector> arguments; - for (auto &argument : task_arguments) { - arguments.push_back(argument->ToFlatbuffer(fbb)); - } - - // Serialize the TaskSpecification. - auto spec = CreateTaskInfo( - fbb, to_flatbuf(fbb, job_id), to_flatbuf(fbb, task_id), - to_flatbuf(fbb, parent_task_id), parent_counter, to_flatbuf(fbb, actor_creation_id), - to_flatbuf(fbb, actor_creation_dummy_object_id), max_actor_reconstructions, - to_flatbuf(fbb, actor_id), to_flatbuf(fbb, actor_handle_id), actor_counter, - ids_to_flatbuf(fbb, new_actor_handles), fbb.CreateVector(arguments), num_returns, - map_to_flatbuf(fbb, required_resources), - map_to_flatbuf(fbb, required_placement_resources), language, - string_vec_to_flatbuf(fbb, function_descriptor), - string_vec_to_flatbuf(fbb, dynamic_worker_options)); - fbb.Finish(spec); - AssignSpecification(fbb.GetBufferPointer(), fbb.GetSize()); -} - -flatbuffers::Offset TaskSpecification::ToFlatbuffer( - flatbuffers::FlatBufferBuilder &fbb) const { - return fbb.CreateString(reinterpret_cast(data()), size()); -} - -// TODO(atumanov): copy/paste most TaskSpec_* methods from task.h and make them -// methods of this class. -const uint8_t *TaskSpecification::data() const { return spec_.data(); } - -size_t TaskSpecification::size() const { return spec_.size(); } - // Task specification getter methods. 
TaskID TaskSpecification::TaskId() const { - auto message = flatbuffers::GetRoot(spec_.data()); - return from_flatbuf(*message->task_id()); -} -JobID TaskSpecification::JobId() const { - auto message = flatbuffers::GetRoot(spec_.data()); - return from_flatbuf(*message->job_id()); + return TaskID::FromBinary(message_.task_id()); } + +JobID TaskSpecification::JobId() const { return JobID::FromBinary(message_.job_id()); } + TaskID TaskSpecification::ParentTaskId() const { - auto message = flatbuffers::GetRoot(spec_.data()); - return from_flatbuf(*message->parent_task_id()); -} -int64_t TaskSpecification::ParentCounter() const { - auto message = flatbuffers::GetRoot(spec_.data()); - return message->parent_counter(); + return TaskID::FromBinary(message_.parent_task_id()); } + +size_t TaskSpecification::ParentCounter() const { return message_.parent_counter(); } + std::vector TaskSpecification::FunctionDescriptor() const { - auto message = flatbuffers::GetRoot(spec_.data()); - return string_vec_from_flatbuf(*message->function_descriptor()); + return VectorFromProtobuf(message_.function_descriptor()); } std::string TaskSpecification::FunctionDescriptorString() const { - auto message = flatbuffers::GetRoot(spec_.data()); - auto list = string_vec_from_flatbuf(*message->function_descriptor()); + auto list = VectorFromProtobuf(message_.function_descriptor()); std::ostringstream stream; // The 4th is the code hash which is binary bits. No need to output it. 
- int size = std::min(static_cast(3), list.size()); + size_t size = std::min(static_cast(3), list.size()); for (int i = 0; i < size; ++i) { if (i != 0) { stream << ","; @@ -155,46 +54,32 @@ std::string TaskSpecification::FunctionDescriptorString() const { return stream.str(); } -int64_t TaskSpecification::NumArgs() const { - auto message = flatbuffers::GetRoot(spec_.data()); - return message->args()->size(); -} +size_t TaskSpecification::NumArgs() const { return message_.args_size(); } -int64_t TaskSpecification::NumReturns() const { - auto message = flatbuffers::GetRoot(spec_.data()); - return message->num_returns(); -} +size_t TaskSpecification::NumReturns() const { return message_.num_returns(); } -ObjectID TaskSpecification::ReturnId(int64_t return_index) const { +ObjectID TaskSpecification::ReturnId(size_t return_index) const { return ObjectID::ForTaskReturn(TaskId(), return_index + 1); } -bool TaskSpecification::ArgByRef(int64_t arg_index) const { +bool TaskSpecification::ArgByRef(size_t arg_index) const { return (ArgIdCount(arg_index) != 0); } -int TaskSpecification::ArgIdCount(int64_t arg_index) const { - auto message = flatbuffers::GetRoot(spec_.data()); - auto ids = message->args()->Get(arg_index)->object_ids(); - return (ids->size() / kUniqueIDSize); +size_t TaskSpecification::ArgIdCount(size_t arg_index) const { + return message_.args(arg_index).object_ids_size(); } -ObjectID TaskSpecification::ArgId(int64_t arg_index, int64_t id_index) const { - auto message = flatbuffers::GetRoot(spec_.data()); - const auto &object_ids = - ids_from_flatbuf(*message->args()->Get(arg_index)->object_ids()); - return object_ids[id_index]; +ObjectID TaskSpecification::ArgId(size_t arg_index, size_t id_index) const { + return ObjectID::FromBinary(message_.args(arg_index).object_ids(id_index)); } -const uint8_t *TaskSpecification::ArgVal(int64_t arg_index) const { - auto message = flatbuffers::GetRoot(spec_.data()); - return reinterpret_cast( - 
message->args()->Get(arg_index)->data()->c_str()); +const uint8_t *TaskSpecification::ArgVal(size_t arg_index) const { + return reinterpret_cast(message_.args(arg_index).data().data()); } -size_t TaskSpecification::ArgValLength(int64_t arg_index) const { - auto message = flatbuffers::GetRoot(spec_.data()); - return message->args()->Get(arg_index)->data()->size(); +size_t TaskSpecification::ArgValLength(size_t arg_index) const { + return message_.args(arg_index).data().size(); } const ResourceSet TaskSpecification::GetRequiredResources() const { @@ -210,43 +95,59 @@ bool TaskSpecification::IsDriverTask() const { return FunctionDescriptor().empty(); } -Language TaskSpecification::GetLanguage() const { - auto message = flatbuffers::GetRoot(spec_.data()); - return message->language(); +rpc::Language TaskSpecification::GetLanguage() const { return message_.language(); } + +bool TaskSpecification::IsActorCreationTask() const { + return message_.type() == rpc::TaskType::ACTOR_CREATION_TASK; } -bool TaskSpecification::IsActorCreationTask() const { return !ActorCreationId().IsNil(); } - -bool TaskSpecification::IsActorTask() const { return !ActorId().IsNil(); } +bool TaskSpecification::IsActorTask() const { + return message_.type() == rpc::TaskType::ACTOR_TASK; +} ActorID TaskSpecification::ActorCreationId() const { - auto message = flatbuffers::GetRoot(spec_.data()); - return from_flatbuf(*message->actor_creation_id()); + // TODO(hchen) Add a check to make sure this function can only be called if + // task is an actor creation task. 
+ if (!IsActorCreationTask()) { + return ActorID::Nil(); + } + return ActorID::FromBinary(message_.actor_creation_task_spec().actor_id()); } ObjectID TaskSpecification::ActorCreationDummyObjectId() const { - auto message = flatbuffers::GetRoot(spec_.data()); - return from_flatbuf(*message->actor_creation_dummy_object_id()); + if (!IsActorTask()) { + return ObjectID::Nil(); + } + return ObjectID::FromBinary( + message_.actor_task_spec().actor_creation_dummy_object_id()); } -int64_t TaskSpecification::MaxActorReconstructions() const { - auto message = flatbuffers::GetRoot(spec_.data()); - return message->max_actor_reconstructions(); +uint64_t TaskSpecification::MaxActorReconstructions() const { + if (!IsActorCreationTask()) { + return 0; + } + return message_.actor_creation_task_spec().max_actor_reconstructions(); } ActorID TaskSpecification::ActorId() const { - auto message = flatbuffers::GetRoot(spec_.data()); - return from_flatbuf(*message->actor_id()); + if (!IsActorTask()) { + return ActorID::Nil(); + } + return ActorID::FromBinary(message_.actor_task_spec().actor_id()); } ActorHandleID TaskSpecification::ActorHandleId() const { - auto message = flatbuffers::GetRoot(spec_.data()); - return from_flatbuf(*message->actor_handle_id()); + if (!IsActorTask()) { + return ActorHandleID::Nil(); + } + return ActorHandleID::FromBinary(message_.actor_task_spec().actor_handle_id()); } -int64_t TaskSpecification::ActorCounter() const { - auto message = flatbuffers::GetRoot(spec_.data()); - return message->actor_counter(); +uint64_t TaskSpecification::ActorCounter() const { + if (!IsActorTask()) { + return 0; + } + return message_.actor_task_spec().actor_counter(); } ObjectID TaskSpecification::ActorDummyObject() const { @@ -255,13 +156,16 @@ ObjectID TaskSpecification::ActorDummyObject() const { } std::vector TaskSpecification::NewActorHandles() const { - auto message = flatbuffers::GetRoot(spec_.data()); - return ids_from_flatbuf(*message->new_actor_handles()); + if 
(!IsActorTask()) { + return {}; + } + return rpc::IdVectorFromProtobuf( + message_.actor_task_spec().new_actor_handles()); } std::vector TaskSpecification::DynamicWorkerOptions() const { - auto message = flatbuffers::GetRoot(spec_.data()); - return string_vec_from_flatbuf(*message->dynamic_worker_options()); + return rpc::VectorFromProtobuf( + message_.actor_creation_task_spec().dynamic_worker_options()); } } // namespace raylet diff --git a/src/ray/raylet/task_spec.h b/src/ray/raylet/task_spec.h index 4339a1a4c..ce87d2580 100644 --- a/src/ray/raylet/task_spec.h +++ b/src/ray/raylet/task_spec.h @@ -7,8 +7,9 @@ #include #include "ray/common/id.h" -#include "ray/gcs/format/gcs_generated.h" +#include "ray/protobuf/common.pb.h" #include "ray/raylet/scheduling_resources.h" +#include "ray/rpc/message_wrapper.h" extern "C" { #include "ray/thirdparty/sha256.h" @@ -18,179 +19,66 @@ namespace ray { namespace raylet { -/// \class TaskArgument -/// -/// A virtual class that represents an argument to a task. -class TaskArgument { +using rpc::Language; +using rpc::MessageWrapper; +using rpc::TaskType; + +/// Wrapper class of protobuf `TaskSpec`, see `common.proto` for details. +class TaskSpecification : public MessageWrapper { public: - /// Serialize the task argument to a flatbuffer. + /// Construct from a protobuf message object. + /// The input message will be **copied** into this object. /// - /// \param fbb The flatbuffer builder to serialize with. - /// \return An offset to the serialized task argument. - virtual flatbuffers::Offset ToFlatbuffer( - flatbuffers::FlatBufferBuilder &fbb) const = 0; + /// \param message The protobuf message. + explicit TaskSpecification(rpc::TaskSpec message) : MessageWrapper(std::move(message)) { + ComputeResources(); + } - virtual ~TaskArgument() = 0; -}; - -/// \class TaskArgumentByReference -/// -/// A task argument consisting of a list of object ID references. 
-class TaskArgumentByReference : virtual public TaskArgument { - public: - /// Create a task argument by reference from a list of object IDs. + /// Construct from protobuf-serialized binary. /// - /// \param references A list of object ID references. - TaskArgumentByReference(const std::vector &references); - - ~TaskArgumentByReference(){}; - - flatbuffers::Offset ToFlatbuffer(flatbuffers::FlatBufferBuilder &fbb) const; - - private: - /// The object IDs. - const std::vector references_; -}; - -/// \class TaskArgumentByValue -/// -/// A task argument containing the raw value. -class TaskArgumentByValue : public TaskArgument { - public: - /// Create a task argument from a raw value. - /// - /// \param value A pointer to the raw value. - /// \param length The size of the raw value. - TaskArgumentByValue(const uint8_t *value, size_t length); - - flatbuffers::Offset ToFlatbuffer(flatbuffers::FlatBufferBuilder &fbb) const; - - private: - /// The raw value. - std::vector value_; -}; - -/// \class TaskSpecification -/// -/// The task specification encapsulates all immutable information about the -/// task. These fields are determined at submission time, converse to the -/// TaskExecutionSpecification that may change at execution time. -class TaskSpecification { - public: - /// Deserialize a task specification from a flatbuffer. - /// - /// \param string A serialized task specification flatbuffer. - TaskSpecification(const flatbuffers::String &string); - - // TODO(swang): Define an actor task constructor. - /// Create a task specification from the raw fields. This constructor omits - /// some values and sets them to sensible defaults. - /// - /// \param job_id The driver ID, representing the job that this task is a - /// part of. - /// \param parent_task_id The task ID of the task that spawned this task. - /// \param parent_counter The number of tasks that this task's parent spawned - /// before this task. - /// \param function_descriptor The function descriptor. 
- /// \param task_arguments The list of task arguments. - /// \param num_returns The number of values returned by the task. - /// \param required_resources The task's resource demands. - /// \param language The language of the worker that must execute the function. - TaskSpecification(const JobID &job_id, const TaskID &parent_task_id, - int64_t parent_counter, - const std::vector> &task_arguments, - int64_t num_returns, - const std::unordered_map &required_resources, - const Language &language, - const std::vector &function_descriptor); - - // TODO(swang): Define an actor task constructor. - /// Create a task specification from the raw fields. - /// - /// \param job_id The driver ID, representing the job that this task is a - /// part of. - /// \param parent_task_id The task ID of the task that spawned this task. - /// \param parent_counter The number of tasks that this task's parent spawned - /// before this task. - /// \param actor_creation_id If this is an actor task, then this is the ID of - /// the corresponding actor creation task. Otherwise, this is nil. - /// \param actor_id The ID of the actor for the task. If this is not an actor - /// task, then this is nil. - /// \param actor_handle_id The ID of the actor handle that submitted this - /// task. If this is not an actor task, then this is nil. - /// \param actor_counter The number of tasks submitted before this task from - /// the same actor handle. If this is not an actor task, then this is 0. - /// \param task_arguments The list of task arguments. - /// \param num_returns The number of values returned by the task. - /// \param required_resources The task's resource demands. - /// \param required_placement_resources The resources required to place this - /// task on a node. Typically, this should be an empty map in which case it - /// will default to be equal to the required_resources argument. - /// \param language The language of the worker that must execute the function. 
- /// \param function_descriptor The function descriptor. - /// \param dynamic_worker_options The dynamic options for starting an actor worker. - TaskSpecification( - const JobID &job_id, const TaskID &parent_task_id, int64_t parent_counter, - const ActorID &actor_creation_id, const ObjectID &actor_creation_dummy_object_id, - int64_t max_actor_reconstructions, const ActorID &actor_id, - const ActorHandleID &actor_handle_id, int64_t actor_counter, - const std::vector &new_actor_handles, - const std::vector> &task_arguments, - int64_t num_returns, - const std::unordered_map &required_resources, - const std::unordered_map &required_placement_resources, - const Language &language, const std::vector &function_descriptor, - const std::vector &dynamic_worker_options = {}); - - /// Deserialize a task specification from a string. - /// - /// \param string The string data for a serialized task specification flatbuffers. - TaskSpecification(const std::string &string); - - /// Deserialize a task specification from raw byte array. - /// - /// \param spec Raw byte array for a serialized task specification flatbuffer. - /// \param spec_size Size of the byte array. - TaskSpecification(const uint8_t *spec, size_t spec_size); - - ~TaskSpecification() {} - - /// Serialize the TaskSpecification to a flatbuffer. - /// - /// \param fbb The flatbuffer builder to serialize with. - /// \return An offset to the serialized task specification. - flatbuffers::Offset ToFlatbuffer( - flatbuffers::FlatBufferBuilder &fbb) const; - - std::string SerializeAsString() const { - flatbuffers::FlatBufferBuilder fbb; - auto string = ToFlatbuffer(fbb); - fbb.Finish(string); - return std::string(fbb.GetBufferPointer(), fbb.GetBufferPointer() + fbb.GetSize()); + /// \param serialized_binary Protobuf-serialized binary. + explicit TaskSpecification(const std::string &serialized_binary) + : MessageWrapper(serialized_binary) { + ComputeResources(); } // TODO(swang): Finalize and document these methods. 
TaskID TaskId() const; + JobID JobId() const; + TaskID ParentTaskId() const; - int64_t ParentCounter() const; + + size_t ParentCounter() const; + std::vector FunctionDescriptor() const; + // Output the function descriptor as a string for log purpose. std::string FunctionDescriptorString() const; - int64_t NumArgs() const; - int64_t NumReturns() const; - bool ArgByRef(int64_t arg_index) const; - int ArgIdCount(int64_t arg_index) const; - ObjectID ArgId(int64_t arg_index, int64_t id_index) const; - ObjectID ReturnId(int64_t return_index) const; - const uint8_t *ArgVal(int64_t arg_index) const; - size_t ArgValLength(int64_t arg_index) const; + + size_t NumArgs() const; + + size_t NumReturns() const; + + bool ArgByRef(size_t arg_index) const; + + size_t ArgIdCount(size_t arg_index) const; + + ObjectID ArgId(size_t arg_index, size_t id_index) const; + + ObjectID ReturnId(size_t return_index) const; + + const uint8_t *ArgVal(size_t arg_index) const; + + size_t ArgValLength(size_t arg_index) const; + /// Return the resources that are to be acquired during the execution of this /// task. /// /// \return The resources that will be acquired during the execution of this /// task. const ResourceSet GetRequiredResources() const; + /// Return the resources that are required for a task to be placed on a node. /// This will typically be the same as the resources acquired during execution /// and will always be a superset of those resources. However, they may @@ -201,36 +89,40 @@ class TaskSpecification { /// /// \return The resources that are required to place a task on a node. const ResourceSet GetRequiredPlacementResources() const; + bool IsDriverTask() const; + Language GetLanguage() const; // Methods specific to actor tasks. 
bool IsActorCreationTask() const; + bool IsActorTask() const; + ActorID ActorCreationId() const; + ObjectID ActorCreationDummyObjectId() const; - int64_t MaxActorReconstructions() const; + + uint64_t MaxActorReconstructions() const; + ActorID ActorId() const; + ActorHandleID ActorHandleId() const; - int64_t ActorCounter() const; + + uint64_t ActorCounter() const; + ObjectID ActorDummyObject() const; + std::vector NewActorHandles() const; std::vector DynamicWorkerOptions() const; private: - /// Assign the specification data from a pointer. - void AssignSpecification(const uint8_t *spec, size_t spec_size); - /// Get a pointer to the byte data. - const uint8_t *data() const; - /// Get the size in bytes of the task specification. - size_t size() const; + void ComputeResources(); /// Field storing required resources. Initalized in constructor. ResourceSet required_resources_; /// Field storing required placement resources. Initalized in constructor. ResourceSet required_placement_resources_; - /// The task specification data. - std::vector spec_; }; } // namespace raylet diff --git a/src/ray/raylet/task_test.cc b/src/ray/raylet/task_test.cc index 72864785a..524c9aaca 100644 --- a/src/ray/raylet/task_test.cc +++ b/src/ray/raylet/task_test.cc @@ -48,56 +48,6 @@ TEST(IdPropertyTest, TestIdProperty) { ASSERT_TRUE(ObjectID::Nil().IsNil()); } -TEST(TaskSpecTest, TaskInfoSize) { - std::vector references = {ObjectID::FromRandom(), ObjectID::FromRandom()}; - auto arguments_1 = std::make_shared(references); - std::string one_arg("This is an value argument."); - auto arguments_2 = std::make_shared( - reinterpret_cast(one_arg.c_str()), one_arg.size()); - std::vector> task_arguments({arguments_1, arguments_2}); - auto task_id = TaskID::FromRandom(); - { - flatbuffers::FlatBufferBuilder fbb; - std::vector> arguments; - for (auto &argument : task_arguments) { - arguments.push_back(argument->ToFlatbuffer(fbb)); - } - // General task. 
- auto spec = CreateTaskInfo( - fbb, to_flatbuf(fbb, JobID::FromRandom()), to_flatbuf(fbb, task_id), - to_flatbuf(fbb, TaskID::FromRandom()), 0, to_flatbuf(fbb, ActorID::Nil()), - to_flatbuf(fbb, ObjectID::Nil()), 0, to_flatbuf(fbb, ActorID::Nil()), - to_flatbuf(fbb, ActorHandleID::Nil()), 0, - ids_to_flatbuf(fbb, std::vector()), fbb.CreateVector(arguments), 1, - map_to_flatbuf(fbb, {}), map_to_flatbuf(fbb, {}), Language::PYTHON, - string_vec_to_flatbuf(fbb, {"PackageName", "ClassName", "FunctionName"})); - fbb.Finish(spec); - RAY_LOG(ERROR) << "Ordinary task info size: " << fbb.GetSize(); - } - - { - flatbuffers::FlatBufferBuilder fbb; - std::vector> arguments; - for (auto &argument : task_arguments) { - arguments.push_back(argument->ToFlatbuffer(fbb)); - } - // General task. - auto spec = CreateTaskInfo( - fbb, to_flatbuf(fbb, JobID::FromRandom()), to_flatbuf(fbb, task_id), - to_flatbuf(fbb, TaskID::FromRandom()), 10, to_flatbuf(fbb, ActorID::FromRandom()), - to_flatbuf(fbb, ObjectID::FromRandom()), 10000000, - to_flatbuf(fbb, ActorID::FromRandom()), - to_flatbuf(fbb, ActorHandleID::FromRandom()), 20, - ids_to_flatbuf( - fbb, std::vector({ObjectID::FromRandom(), ObjectID::FromRandom()})), - fbb.CreateVector(arguments), 2, map_to_flatbuf(fbb, {}), map_to_flatbuf(fbb, {}), - Language::PYTHON, - string_vec_to_flatbuf(fbb, {"PackageName", "ClassName", "FunctionName"})); - fbb.Finish(spec); - RAY_LOG(ERROR) << "Actor task info size: " << fbb.GetSize(); - } -} - } // namespace raylet } // namespace ray diff --git a/src/ray/raylet/task_util.h b/src/ray/raylet/task_util.h new file mode 100644 index 000000000..63c0ed890 --- /dev/null +++ b/src/ray/raylet/task_util.h @@ -0,0 +1,120 @@ +#ifndef RAY_RAYLET_TASK_UTIL_H +#define RAY_RAYLET_TASK_UTIL_H + +#include "ray/protobuf/common.pb.h" +#include "ray/raylet/task_spec.h" + +namespace ray { + +namespace raylet { + +/// Helper class for building a `TaskSpecification` object. 
+class TaskSpecBuilder { + public: + /// Build the `TaskSpecification` object. + TaskSpecification Build() { return TaskSpecification(message_); } + + /// Get a reference to the internal protobuf message object. + const rpc::TaskSpec &GetMessage() const { return message_; } + + /// Set the common attributes of the task spec. + /// See `common.proto` for meaning of the arguments. + /// + /// \return Reference to the builder object itself. + TaskSpecBuilder &SetCommonTaskSpec( + const Language &language, const std::vector &function_descriptor, + const JobID &job_id, const TaskID &parent_task_id, uint64_t parent_counter, + uint64_t num_returns, + const std::unordered_map &required_resources, + const std::unordered_map &required_placement_resources) { + message_.set_type(rpc::TaskType::NORMAL_TASK); + message_.set_language(language); + for (const auto &fd : function_descriptor) { + message_.add_function_descriptor(fd); + } + message_.set_job_id(job_id.Binary()); + message_.set_task_id(GenerateTaskId(job_id, parent_task_id, parent_counter).Binary()); + message_.set_parent_task_id(parent_task_id.Binary()); + message_.set_parent_counter(parent_counter); + message_.set_num_returns(num_returns); + message_.mutable_required_resources()->insert(required_resources.begin(), + required_resources.end()); + message_.mutable_required_placement_resources()->insert( + required_placement_resources.begin(), required_placement_resources.end()); + return *this; + } + + /// Add a by-reference argument to the task. + /// + /// \param arg_id Id of the argument. + /// \return Reference to the builder object itself. + TaskSpecBuilder &AddByRefArg(const ObjectID &arg_id) { + message_.add_args()->add_object_ids(arg_id.Binary()); + return *this; + } + + /// Add a by-value argument to the task. + /// + /// \param data String object that contains the data. + /// \return Reference to the builder object itself. 
+ TaskSpecBuilder &AddByValueArg(const std::string &data) { + message_.add_args()->set_data(data); + return *this; + } + + /// Add a by-value argument to the task. + /// + /// \param data Pointer to the data. + /// \param size Size of the data. + /// \return Reference to the builder object itself. + TaskSpecBuilder &AddByValueArg(const void *data, size_t size) { + message_.add_args()->set_data(data, size); + return *this; + } + + /// Set the `ActorCreationTaskSpec` of the task spec. + /// See `common.proto` for meaning of the arguments. + /// + /// \return Reference to the builder object itself. + TaskSpecBuilder &SetActorCreationTaskSpec( + const ActorID &actor_id, uint64_t max_reconstructions = 0, + const std::vector &dynamic_worker_options = {}) { + message_.set_type(TaskType::ACTOR_CREATION_TASK); + auto actor_creation_spec = message_.mutable_actor_creation_task_spec(); + actor_creation_spec->set_actor_id(actor_id.Binary()); + actor_creation_spec->set_max_actor_reconstructions(max_reconstructions); + for (const auto &option : dynamic_worker_options) { + actor_creation_spec->add_dynamic_worker_options(option); + } + return *this; + } + + /// Set the `ActorTaskSpec` of the task spec. + /// See `common.proto` for meaning of the arguments. + /// + /// \return Reference to the builder object itself. 
+ TaskSpecBuilder &SetActorTaskSpec( + const ActorID &actor_id, const ActorHandleID &actor_handle_id, + const ObjectID &actor_creation_dummy_object_id, uint64_t actor_counter, + const std::vector &new_handle_ids = {}) { + message_.set_type(TaskType::ACTOR_TASK); + auto actor_spec = message_.mutable_actor_task_spec(); + actor_spec->set_actor_id(actor_id.Binary()); + actor_spec->set_actor_handle_id(actor_handle_id.Binary()); + actor_spec->set_actor_creation_dummy_object_id( + actor_creation_dummy_object_id.Binary()); + actor_spec->set_actor_counter(actor_counter); + for (const auto &id : new_handle_ids) { + actor_spec->add_new_actor_handles(id.Binary()); + } + return *this; + } + + private: + rpc::TaskSpec message_; +}; + +} // namespace raylet +} // namespace ray + +#endif // RAY_RAYLET_TASK_UTIL_H diff --git a/src/ray/raylet/worker.h b/src/ray/raylet/worker.h index 6720c8ce3..67deb04d0 100644 --- a/src/ray/raylet/worker.h +++ b/src/ray/raylet/worker.h @@ -5,12 +5,15 @@ #include "ray/common/client_connection.h" #include "ray/common/id.h" +#include "ray/protobuf/common.pb.h" #include "ray/raylet/scheduling_resources.h" namespace ray { namespace raylet { +using rpc::Language; + /// Worker class encapsulates the implementation details of a worker. A worker /// is the execution container around a unit of Ray work, such as a task or an /// actor. Ray units of work execute in the context of a Worker. diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index 3afe78b18..91552287a 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -41,10 +41,10 @@ namespace raylet { /// A constructor that initializes a worker pool with /// (num_worker_processes * num_workers_per_process) workers for each language. 
-WorkerPool::WorkerPool( - int num_worker_processes, int num_workers_per_process, - int maximum_startup_concurrency, std::shared_ptr gcs_client, - const std::unordered_map> &worker_commands) +WorkerPool::WorkerPool(int num_worker_processes, int num_workers_per_process, + int maximum_startup_concurrency, + std::shared_ptr gcs_client, + const WorkerCommandMap &worker_commands) : num_workers_per_process_(num_workers_per_process), multiple_for_warning_(std::max(num_worker_processes, maximum_startup_concurrency)), maximum_startup_concurrency_(maximum_startup_concurrency), @@ -300,7 +300,7 @@ bool WorkerPool::DisconnectWorker(const std::shared_ptr &worker) { RAY_CHECK(RemoveWorker(state.registered_workers, worker)); stats::CurrentWorker().Record( - 0, {{stats::LanguageKey, EnumNameLanguage(worker->GetLanguage())}, + 0, {{stats::LanguageKey, Language_Name(worker->GetLanguage())}, {stats::WorkerPidKey, std::to_string(worker->Pid())}}); return RemoveWorker(state.idle, worker); @@ -310,7 +310,7 @@ void WorkerPool::DisconnectDriver(const std::shared_ptr &driver) { auto &state = GetStateForLanguage(driver->GetLanguage()); RAY_CHECK(RemoveWorker(state.registered_drivers, driver)); stats::CurrentDriver().Record( - 0, {{stats::LanguageKey, EnumNameLanguage(driver->GetLanguage())}, + 0, {{stats::LanguageKey, Language_Name(driver->GetLanguage())}, {stats::WorkerPidKey, std::to_string(driver->Pid())}}); } @@ -382,14 +382,14 @@ void WorkerPool::RecordMetrics() const { // Record worker. for (auto worker : entry.second.registered_workers) { stats::CurrentWorker().Record( - worker->Pid(), {{stats::LanguageKey, EnumNameLanguage(worker->GetLanguage())}, + worker->Pid(), {{stats::LanguageKey, Language_Name(worker->GetLanguage())}, {stats::WorkerPidKey, std::to_string(worker->Pid())}}); } // Record driver. 
for (auto driver : entry.second.registered_drivers) { stats::CurrentDriver().Record( - driver->Pid(), {{stats::LanguageKey, EnumNameLanguage(driver->GetLanguage())}, + driver->Pid(), {{stats::LanguageKey, Language_Name(driver->GetLanguage())}, {stats::WorkerPidKey, std::to_string(driver->Pid())}}); } } diff --git a/src/ray/raylet/worker_pool.h b/src/ray/raylet/worker_pool.h index 4ea2648d7..32f6cb042 100644 --- a/src/ray/raylet/worker_pool.h +++ b/src/ray/raylet/worker_pool.h @@ -8,7 +8,7 @@ #include "ray/common/client_connection.h" #include "ray/gcs/client.h" -#include "ray/gcs/format/util.h" +#include "ray/protobuf/common.pb.h" #include "ray/raylet/task.h" #include "ray/raylet/worker.h" @@ -16,6 +16,11 @@ namespace ray { namespace raylet { +using rpc::Language; + +using WorkerCommandMap = + std::unordered_map, std::hash>; + class Worker; /// \class WorkerPool @@ -36,10 +41,10 @@ class WorkerPool { /// resources on the machine). /// \param worker_commands The commands used to start the worker process, grouped by /// language. - WorkerPool( - int num_worker_processes, int num_workers_per_process, - int maximum_startup_concurrency, std::shared_ptr gcs_client, - const std::unordered_map> &worker_commands); + WorkerPool(int num_worker_processes, int num_workers_per_process, + int maximum_startup_concurrency, + std::shared_ptr gcs_client, + const WorkerCommandMap &worker_commands); /// Destructor responsible for freeing a set of workers owned by this class. virtual ~WorkerPool(); @@ -179,7 +184,7 @@ class WorkerPool { /// The number of workers per process. int num_workers_per_process_; /// Pool states per language. 
- std::unordered_map states_by_lang_; + std::unordered_map> states_by_lang_; private: /// A helper function that returns the reference of the pool state diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index 715e89417..f80cf7376 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -18,8 +18,7 @@ class WorkerPoolMock : public WorkerPool { : WorkerPoolMock({{Language::PYTHON, {"dummy_py_worker_command"}}, {Language::JAVA, {"dummy_java_worker_command"}}}) {} - explicit WorkerPoolMock( - const std::unordered_map> &worker_commands) + explicit WorkerPoolMock(const WorkerCommandMap &worker_commands) : WorkerPool(0, NUM_WORKERS_PER_PROCESS, MAXIMUM_STARTUP_CONCURRENCY, nullptr, worker_commands), last_worker_pid_(0) {} @@ -89,8 +88,7 @@ class WorkerPoolTest : public ::testing::Test { return std::shared_ptr(new Worker(pid, language, -1, client)); } - void SetWorkerCommands( - const std::unordered_map> &worker_commands) { + void SetWorkerCommands(const WorkerCommandMap &worker_commands) { WorkerPoolMock worker_pool(worker_commands); this->worker_pool_ = std::move(worker_pool); } @@ -107,11 +105,23 @@ class WorkerPoolTest : public ::testing::Test { static inline TaskSpecification ExampleTaskSpec( const ActorID actor_id = ActorID::Nil(), const Language &language = Language::PYTHON, - const ActorID actor_creation_id = ActorID::Nil()) { - std::vector function_descriptor(3); - return TaskSpecification(JobID::Nil(), TaskID::Nil(), 0, actor_creation_id, - ObjectID::Nil(), 0, actor_id, ActorHandleID::Nil(), 0, {}, {}, - 0, {}, {}, language, function_descriptor); + const ActorID actor_creation_id = ActorID::Nil(), + const std::vector &dynamic_worker_options = {}) { + rpc::TaskSpec message; + message.set_language(language); + if (!actor_id.IsNil()) { + message.set_type(rpc::TaskType::ACTOR_TASK); + message.mutable_actor_task_spec()->set_actor_id(actor_id.Binary()); + } else if (!actor_creation_id.IsNil()) { + 
message.set_type(rpc::TaskType::ACTOR_CREATION_TASK); + message.mutable_actor_creation_task_spec()->set_actor_id(actor_creation_id.Binary()); + for (const auto &option : dynamic_worker_options) { + message.mutable_actor_creation_task_spec()->add_dynamic_worker_options(option); + } + } else { + message.set_type(rpc::TaskType::NORMAL_TASK); + } + return TaskSpecification(std::move(message)); } TEST_F(WorkerPoolTest, HandleWorkerRegistration) { @@ -226,10 +236,8 @@ TEST_F(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) { SetWorkerCommands({{Language::PYTHON, {"dummy_py_worker_command"}}, {Language::JAVA, java_worker_command}}); - TaskSpecification task_spec(JobID::Nil(), TaskID::Nil(), 0, ActorID::FromRandom(), - ObjectID::Nil(), 0, ActorID::Nil(), ActorHandleID::Nil(), 0, - {}, {}, 0, {}, {}, Language::JAVA, {"", "", ""}, - {"test_op_0", "test_op_1"}); + TaskSpecification task_spec = ExampleTaskSpec( + ActorID::Nil(), Language::JAVA, ActorID::FromRandom(), {"test_op_0", "test_op_1"}); worker_pool_.StartWorkerProcess(Language::JAVA, task_spec.DynamicWorkerOptions()); const auto real_command = worker_pool_.GetWorkerCommand(worker_pool_.LastStartedWorkerProcess()); diff --git a/src/ray/rpc/message_wrapper.h b/src/ray/rpc/message_wrapper.h new file mode 100644 index 000000000..452ad24e5 --- /dev/null +++ b/src/ray/rpc/message_wrapper.h @@ -0,0 +1,42 @@ +#ifndef RAY_RPC_WRAPPER_H +#define RAY_RPC_WRAPPER_H + +#include + +namespace ray { + +namespace rpc { + +/// Wrap a protobuf message. +template +class MessageWrapper { + public: + /// Construct from a protobuf message object. + /// The input message will be **copied** into this object. + /// + /// \param message The protobuf message. + explicit MessageWrapper(const Message message) : message_(std::move(message)) {} + + /// Construct from protobuf-serialized binary. + /// + /// \param serialized_binary Protobuf-serialized binary. 
+ explicit MessageWrapper(const std::string &serialized_binary) { + message_.ParseFromString(serialized_binary); + } + + /// Get reference of the protobuf message. + const Message &GetMessage() const { return message_; } + + /// Serialize the message to a string. + const std::string Serialize() const { return message_.SerializeAsString(); } + + protected: + /// The wrapped message. + Message message_; +}; + +} // namespace rpc + +} // namespace ray + +#endif // RAY_RPC_WRAPPER_H diff --git a/src/ray/rpc/util.h b/src/ray/rpc/util.h index 59ae75ae3..97bfd0e9d 100644 --- a/src/ray/rpc/util.h +++ b/src/ray/rpc/util.h @@ -1,6 +1,7 @@ #ifndef RAY_RPC_UTIL_H #define RAY_RPC_UTIL_H +#include #include #include @@ -28,18 +29,37 @@ inline Status GrpcStatusToRayStatus(const grpc::Status &grpc_status) { } } +/// Converts a Protobuf `RepeatedPtrField` to a vector. template inline std::vector VectorFromProtobuf( const ::google::protobuf::RepeatedPtrField &pb_repeated) { return std::vector(pb_repeated.begin(), pb_repeated.end()); } +/// Converts a Protobuf `RepeatedField` to a vector. template inline std::vector VectorFromProtobuf( const ::google::protobuf::RepeatedField &pb_repeated) { return std::vector(pb_repeated.begin(), pb_repeated.end()); } +/// Converts a Protobuf `RepeatedField` to a vector of IDs. +template +inline std::vector IdVectorFromProtobuf( + const ::google::protobuf::RepeatedPtrField<::std::string> &pb_repeated) { + auto str_vec = VectorFromProtobuf(pb_repeated); + std::vector ret; + std::transform(str_vec.begin(), str_vec.end(), std::back_inserter(ret), + &ID::FromBinary); + return ret; +} + +/// Converts a Protobuf map to a `unordered_map`. +template +inline std::unordered_map MapFromProtobuf(::google::protobuf::Map pb_map) { + return std::unordered_map(pb_map.begin(), pb_map.end()); +} + } // namespace rpc } // namespace ray