mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:38:18 +08:00
Define common data structures with protobuf. (#5121)
This commit is contained in:
+21
-45
@@ -1,4 +1,4 @@
|
||||
load("//bazel:ray.bzl", "flatbuffer_java_library", "define_java_module")
|
||||
load("//bazel:ray.bzl", "define_java_module")
|
||||
load("@build_stack_rules_proto//java:java_proto_compile.bzl", "java_proto_compile")
|
||||
|
||||
exports_files([
|
||||
@@ -50,8 +50,10 @@ define_java_module(
|
||||
define_java_module(
|
||||
name = "runtime",
|
||||
additional_srcs = [
|
||||
":generate_java_gcs_fbs",
|
||||
":gcs_java_proto",
|
||||
":all_java_proto",
|
||||
],
|
||||
exclude_srcs = [
|
||||
"runtime/src/main/java/org/ray/runtime/generated/*.java",
|
||||
],
|
||||
additional_resources = [
|
||||
":java_native_deps",
|
||||
@@ -68,7 +70,6 @@ define_java_module(
|
||||
deps = [
|
||||
":org_ray_ray_api",
|
||||
"@plasma//:org_apache_arrow_arrow_plasma",
|
||||
"@maven//:com_github_davidmoten_flatbuffers_java",
|
||||
"@maven//:com_google_guava_guava",
|
||||
"@maven//:com_google_protobuf_protobuf_java",
|
||||
"@maven//:com_typesafe_config",
|
||||
@@ -151,39 +152,22 @@ java_binary(
|
||||
],
|
||||
)
|
||||
|
||||
java_proto_compile(
|
||||
name = "common_java_proto",
|
||||
deps = ["@//:common_proto"],
|
||||
)
|
||||
|
||||
java_proto_compile(
|
||||
name = "gcs_java_proto",
|
||||
deps = ["@//:gcs_proto"],
|
||||
)
|
||||
|
||||
flatbuffers_generated_files = [
|
||||
"Arg.java",
|
||||
"Language.java",
|
||||
"TaskInfo.java",
|
||||
"ResourcePair.java",
|
||||
]
|
||||
|
||||
flatbuffer_java_library(
|
||||
name = "java_gcs_fbs",
|
||||
srcs = ["//:gcs_fbs_file"],
|
||||
outs = flatbuffers_generated_files,
|
||||
out_prefix = "",
|
||||
)
|
||||
|
||||
genrule(
|
||||
name = "generate_java_gcs_fbs",
|
||||
srcs = [":java_gcs_fbs"],
|
||||
outs = [
|
||||
"runtime/src/main/java/org/ray/runtime/generated/" + file for file in flatbuffers_generated_files
|
||||
filegroup(
|
||||
name = "all_java_proto",
|
||||
srcs = [
|
||||
":common_java_proto",
|
||||
":gcs_java_proto",
|
||||
],
|
||||
cmd = """
|
||||
for f in $(locations //java:java_gcs_fbs); do
|
||||
chmod +w $$f
|
||||
mv -f $$f $(@D)/runtime/src/main/java/org/ray/runtime/generated
|
||||
done
|
||||
python $$(pwd)/java/modify_generated_java_flatbuffers_files.py $(@D)/..
|
||||
""",
|
||||
local = 1,
|
||||
)
|
||||
|
||||
filegroup(
|
||||
@@ -202,8 +186,7 @@ filegroup(
|
||||
genrule(
|
||||
name = "gen_maven_deps",
|
||||
srcs = [
|
||||
":gcs_java_proto",
|
||||
":generate_java_gcs_fbs",
|
||||
":all_java_proto",
|
||||
":java_native_deps",
|
||||
":copy_pom_file",
|
||||
"@plasma//:org_apache_arrow_arrow_plasma",
|
||||
@@ -212,6 +195,11 @@ genrule(
|
||||
cmd = """
|
||||
set -x
|
||||
WORK_DIR=$$(pwd)
|
||||
# Copy protobuf-generated files.
|
||||
rm -rf $$WORK_DIR/java/runtime/src/main/java/org/ray/runtime/generated
|
||||
for f in $(locations //java:all_java_proto); do
|
||||
unzip $$f -x META-INF/MANIFEST.MF -d $$WORK_DIR/java/runtime/src/main/java
|
||||
done
|
||||
# Copy native dependecies.
|
||||
NATIVE_DEPS_DIR=$$WORK_DIR/java/runtime/native_dependencies/
|
||||
rm -rf $$NATIVE_DEPS_DIR
|
||||
@@ -220,18 +208,6 @@ genrule(
|
||||
chmod +w $$f
|
||||
cp $$f $$NATIVE_DEPS_DIR
|
||||
done
|
||||
# Copy protobuf-generated files.
|
||||
GENERATED_DIR=$$WORK_DIR/java/runtime/src/main/java/org/ray/runtime/generated
|
||||
rm -rf $$GENERATED_DIR
|
||||
mkdir -p $$GENERATED_DIR
|
||||
for f in $(locations //java:gcs_java_proto); do
|
||||
unzip $$f
|
||||
mv org/ray/runtime/generated/* $$GENERATED_DIR
|
||||
done
|
||||
# Copy flatbuffers-generated files
|
||||
for f in $(locations //java:generate_java_gcs_fbs); do
|
||||
cp $$f $$GENERATED_DIR
|
||||
done
|
||||
# Install plasma jar to local maven repo.
|
||||
mvn install:install-file -Dfile=$(locations @plasma//:org_apache_arrow_arrow_plasma) -Dpackaging=jar \
|
||||
-DgroupId=org.apache.arrow -DartifactId=arrow-plasma -Dversion=0.13.0-SNAPSHOT
|
||||
|
||||
@@ -4,7 +4,6 @@ def gen_java_deps():
|
||||
maven_install(
|
||||
artifacts = [
|
||||
"com.beust:jcommander:1.72",
|
||||
"com.github.davidmoten:flatbuffers-java:1.9.0.1",
|
||||
"com.google.guava:guava:27.0.1-jre",
|
||||
"com.google.protobuf:protobuf-java:3.8.0",
|
||||
"com.puppycrawl.tools:checkstyle:8.15",
|
||||
|
||||
@@ -31,11 +31,6 @@
|
||||
<artifactId>jcommander</artifactId>
|
||||
<version>1.72</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.github.davidmoten</groupId>
|
||||
<artifactId>flatbuffers-java</artifactId>
|
||||
<version>1.9.0.1</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.google.guava</groupId>
|
||||
<artifactId>guava</artifactId>
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
package org.ray.runtime.raylet;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.protobuf.ByteString;
|
||||
import com.google.protobuf.InvalidProtocolBufferException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.nio.charset.Charset;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.WaitResult;
|
||||
import org.ray.api.exception.RayException;
|
||||
@@ -15,10 +19,8 @@ import org.ray.api.id.ObjectId;
|
||||
import org.ray.api.id.TaskId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
|
||||
import org.ray.runtime.generated.Arg;
|
||||
import org.ray.runtime.generated.Language;
|
||||
import org.ray.runtime.generated.ResourcePair;
|
||||
import org.ray.runtime.generated.TaskInfo;
|
||||
import org.ray.runtime.generated.Common;
|
||||
import org.ray.runtime.generated.Common.TaskType;
|
||||
import org.ray.runtime.task.FunctionArg;
|
||||
import org.ray.runtime.task.TaskLanguage;
|
||||
import org.ray.runtime.task.TaskSpec;
|
||||
@@ -30,15 +32,6 @@ public class RayletClientImpl implements RayletClient {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(RayletClientImpl.class);
|
||||
|
||||
private static final int TASK_SPEC_BUFFER_SIZE = 2 * 1024 * 1024;
|
||||
|
||||
/**
|
||||
* Direct buffers that are used to hold flatbuffer-serialized task specs.
|
||||
*/
|
||||
private static ThreadLocal<ByteBuffer> taskSpecBuffer = ThreadLocal.withInitial(() ->
|
||||
ByteBuffer.allocateDirect(TASK_SPEC_BUFFER_SIZE).order(ByteOrder.LITTLE_ENDIAN)
|
||||
);
|
||||
|
||||
/**
|
||||
* The pointer to c++'s raylet client.
|
||||
*/
|
||||
@@ -86,21 +79,20 @@ public class RayletClientImpl implements RayletClient {
|
||||
Preconditions.checkState(!spec.parentTaskId.isNil());
|
||||
Preconditions.checkState(!spec.jobId.isNil());
|
||||
|
||||
ByteBuffer info = convertTaskSpecToFlatbuffer(spec);
|
||||
byte[] taskSpec = convertTaskSpecToProtobuf(spec);
|
||||
byte[] cursorId = null;
|
||||
if (!spec.getExecutionDependencies().isEmpty()) {
|
||||
//TODO(hchen): handle more than one dependencies.
|
||||
cursorId = spec.getExecutionDependencies().get(0).getBytes();
|
||||
}
|
||||
nativeSubmitTask(client, cursorId, info, info.position(), info.remaining());
|
||||
nativeSubmitTask(client, cursorId, taskSpec);
|
||||
}
|
||||
|
||||
@Override
|
||||
public TaskSpec getTask() {
|
||||
byte[] bytes = nativeGetTask(client);
|
||||
assert (null != bytes);
|
||||
ByteBuffer bb = ByteBuffer.wrap(bytes);
|
||||
return parseTaskSpecFromFlatbuffer(bb);
|
||||
return parseTaskSpecFromProtobuf(bytes);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -127,7 +119,7 @@ public class RayletClientImpl implements RayletClient {
|
||||
|
||||
@Override
|
||||
public void freePlasmaObjects(List<ObjectId> objectIds, boolean localOnly,
|
||||
boolean deleteCreatingTasks) {
|
||||
boolean deleteCreatingTasks) {
|
||||
byte[][] objectIdsArray = IdUtil.getIdBytes(objectIds);
|
||||
nativeFreePlasmaObjects(client, objectIdsArray, localOnly, deleteCreatingTasks);
|
||||
}
|
||||
@@ -142,59 +134,76 @@ public class RayletClientImpl implements RayletClient {
|
||||
nativeNotifyActorResumedFromCheckpoint(client, actorId.getBytes(), checkpointId.getBytes());
|
||||
}
|
||||
|
||||
private static TaskSpec parseTaskSpecFromFlatbuffer(ByteBuffer bb) {
|
||||
bb.order(ByteOrder.LITTLE_ENDIAN);
|
||||
TaskInfo info = TaskInfo.getRootAsTaskInfo(bb);
|
||||
UniqueId jobId = UniqueId.fromByteBuffer(info.jobIdAsByteBuffer());
|
||||
TaskId taskId = TaskId.fromByteBuffer(info.taskIdAsByteBuffer());
|
||||
TaskId parentTaskId = TaskId.fromByteBuffer(info.parentTaskIdAsByteBuffer());
|
||||
int parentCounter = info.parentCounter();
|
||||
UniqueId actorCreationId = UniqueId.fromByteBuffer(info.actorCreationIdAsByteBuffer());
|
||||
int maxActorReconstructions = info.maxActorReconstructions();
|
||||
UniqueId actorId = UniqueId.fromByteBuffer(info.actorIdAsByteBuffer());
|
||||
UniqueId actorHandleId = UniqueId.fromByteBuffer(info.actorHandleIdAsByteBuffer());
|
||||
int actorCounter = info.actorCounter();
|
||||
int numReturns = info.numReturns();
|
||||
/**
|
||||
* Parse `TaskSpec` protobuf bytes.
|
||||
*/
|
||||
private static TaskSpec parseTaskSpecFromProtobuf(byte[] bytes) {
|
||||
Common.TaskSpec taskSpec;
|
||||
try {
|
||||
taskSpec = Common.TaskSpec.parseFrom(bytes);
|
||||
} catch (InvalidProtocolBufferException e) {
|
||||
throw new RuntimeException("Invalid protobuf data.");
|
||||
}
|
||||
|
||||
// Deserialize new actor handles
|
||||
UniqueId[] newActorHandles = IdUtil.getUniqueIdsFromByteBuffer(
|
||||
info.newActorHandlesAsByteBuffer());
|
||||
// Parse common fields.
|
||||
UniqueId jobId = UniqueId.fromByteBuffer(taskSpec.getJobId().asReadOnlyByteBuffer());
|
||||
TaskId taskId = TaskId.fromByteBuffer(taskSpec.getTaskId().asReadOnlyByteBuffer());
|
||||
TaskId parentTaskId = TaskId.fromByteBuffer(taskSpec.getParentTaskId().asReadOnlyByteBuffer());
|
||||
int parentCounter = (int) taskSpec.getParentCounter();
|
||||
int numReturns = (int) taskSpec.getNumReturns();
|
||||
Map<String, Double> resources = taskSpec.getRequiredResourcesMap();
|
||||
|
||||
// Deserialize args
|
||||
FunctionArg[] args = new FunctionArg[info.argsLength()];
|
||||
for (int i = 0; i < info.argsLength(); i++) {
|
||||
Arg arg = info.args(i);
|
||||
|
||||
int objectIdsLength = arg.objectIdsAsByteBuffer().remaining() / UniqueId.LENGTH;
|
||||
// Parse args.
|
||||
FunctionArg[] args = new FunctionArg[taskSpec.getArgsCount()];
|
||||
for (int i = 0; i < args.length; i++) {
|
||||
Common.TaskArg arg = taskSpec.getArgs(i);
|
||||
int objectIdsLength = arg.getObjectIdsCount();
|
||||
if (objectIdsLength > 0) {
|
||||
Preconditions.checkArgument(objectIdsLength == 1,
|
||||
"This arg has more than one id: {}", objectIdsLength);
|
||||
args[i] = FunctionArg.passByReference(ObjectId.fromByteBuffer(arg.objectIdsAsByteBuffer()));
|
||||
args[i] = FunctionArg
|
||||
.passByReference(ObjectId.fromByteBuffer(arg.getObjectIds(0).asReadOnlyByteBuffer()));
|
||||
} else {
|
||||
ByteBuffer lbb = arg.dataAsByteBuffer();
|
||||
Preconditions.checkState(lbb != null && lbb.remaining() > 0);
|
||||
byte[] data = new byte[lbb.remaining()];
|
||||
lbb.get(data);
|
||||
args[i] = FunctionArg.passByValue(data);
|
||||
args[i] = FunctionArg.passByValue(arg.getData().toByteArray());
|
||||
}
|
||||
}
|
||||
|
||||
// Deserialize required resources;
|
||||
Map<String, Double> resources = new HashMap<>();
|
||||
for (int i = 0; i < info.requiredResourcesLength(); i++) {
|
||||
resources.put(info.requiredResources(i).key(), info.requiredResources(i).value());
|
||||
}
|
||||
// Deserialize function descriptor
|
||||
Preconditions.checkArgument(info.language() == Language.JAVA);
|
||||
Preconditions.checkArgument(info.functionDescriptorLength() == 3);
|
||||
// Parse function descriptor
|
||||
Preconditions.checkArgument(taskSpec.getLanguage() == Common.Language.JAVA);
|
||||
Preconditions.checkArgument(taskSpec.getFunctionDescriptorCount() == 3);
|
||||
JavaFunctionDescriptor functionDescriptor = new JavaFunctionDescriptor(
|
||||
info.functionDescriptor(0), info.functionDescriptor(1), info.functionDescriptor(2)
|
||||
taskSpec.getFunctionDescriptor(0).toString(Charset.defaultCharset()),
|
||||
taskSpec.getFunctionDescriptor(1).toString(Charset.defaultCharset()),
|
||||
taskSpec.getFunctionDescriptor(2).toString(Charset.defaultCharset())
|
||||
);
|
||||
|
||||
// Deserialize dynamic worker options.
|
||||
// Parse ActorCreationTaskSpec.
|
||||
UniqueId actorCreationId = UniqueId.NIL;
|
||||
int maxActorReconstructions = 0;
|
||||
UniqueId[] newActorHandles = new UniqueId[0];
|
||||
List<String> dynamicWorkerOptions = new ArrayList<>();
|
||||
for (int i = 0; i < info.dynamicWorkerOptionsLength(); ++i) {
|
||||
dynamicWorkerOptions.add(info.dynamicWorkerOptions(i));
|
||||
if (taskSpec.getType() == Common.TaskType.ACTOR_CREATION_TASK) {
|
||||
Common.ActorCreationTaskSpec actorCreationTaskSpec = taskSpec.getActorCreationTaskSpec();
|
||||
actorCreationId = UniqueId
|
||||
.fromByteBuffer(actorCreationTaskSpec.getActorId().asReadOnlyByteBuffer());
|
||||
maxActorReconstructions = (int) actorCreationTaskSpec.getMaxActorReconstructions();
|
||||
dynamicWorkerOptions = ImmutableList
|
||||
.copyOf(actorCreationTaskSpec.getDynamicWorkerOptionsList());
|
||||
}
|
||||
|
||||
// Parse ActorTaskSpec.
|
||||
UniqueId actorId = UniqueId.NIL;
|
||||
UniqueId actorHandleId = UniqueId.NIL;
|
||||
int actorCounter = 0;
|
||||
if (taskSpec.getType() == Common.TaskType.ACTOR_TASK) {
|
||||
Common.ActorTaskSpec actorTaskSpec = taskSpec.getActorTaskSpec();
|
||||
actorId = UniqueId.fromByteBuffer(actorTaskSpec.getActorId().asReadOnlyByteBuffer());
|
||||
actorHandleId = UniqueId
|
||||
.fromByteBuffer(actorTaskSpec.getActorHandleId().asReadOnlyByteBuffer());
|
||||
actorCounter = (int) actorTaskSpec.getActorCounter();
|
||||
newActorHandles = actorTaskSpec.getNewActorHandlesList().stream()
|
||||
.map(byteString -> UniqueId.fromByteBuffer(byteString.asReadOnlyByteBuffer()))
|
||||
.toArray(UniqueId[]::new);
|
||||
}
|
||||
|
||||
return new TaskSpec(jobId, taskId, parentTaskId, parentCounter, actorCreationId,
|
||||
@@ -202,122 +211,78 @@ public class RayletClientImpl implements RayletClient {
|
||||
args, numReturns, resources, TaskLanguage.JAVA, functionDescriptor, dynamicWorkerOptions);
|
||||
}
|
||||
|
||||
private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) {
|
||||
ByteBuffer bb = taskSpecBuffer.get();
|
||||
bb.clear();
|
||||
/**
|
||||
* Convert a `TaskSpec` to protobuf-serialized bytes.
|
||||
*/
|
||||
private static byte[] convertTaskSpecToProtobuf(TaskSpec task) {
|
||||
// Set common fields.
|
||||
Common.TaskSpec.Builder builder = Common.TaskSpec.newBuilder()
|
||||
.setJobId(ByteString.copyFrom(task.jobId.getBytes()))
|
||||
.setTaskId(ByteString.copyFrom(task.taskId.getBytes()))
|
||||
.setParentTaskId(ByteString.copyFrom(task.parentTaskId.getBytes()))
|
||||
.setParentCounter(task.parentCounter)
|
||||
.setNumReturns(task.numReturns)
|
||||
.putAllRequiredResources(task.resources);
|
||||
|
||||
FlatBufferBuilder fbb = new FlatBufferBuilder(bb);
|
||||
final int jobIdOffset = fbb.createString(task.jobId.toByteBuffer());
|
||||
final int taskIdOffset = fbb.createString(task.taskId.toByteBuffer());
|
||||
final int parentTaskIdOffset = fbb.createString(task.parentTaskId.toByteBuffer());
|
||||
final int parentCounter = task.parentCounter;
|
||||
final int actorCreateIdOffset = fbb.createString(task.actorCreationId.toByteBuffer());
|
||||
final int actorCreateDummyIdOffset = fbb.createString(task.actorId.toByteBuffer());
|
||||
final int maxActorReconstructions = task.maxActorReconstructions;
|
||||
final int actorIdOffset = fbb.createString(task.actorId.toByteBuffer());
|
||||
final int actorHandleIdOffset = fbb.createString(task.actorHandleId.toByteBuffer());
|
||||
final int actorCounter = task.actorCounter;
|
||||
final int numReturnsOffset = task.numReturns;
|
||||
|
||||
// Serialize the new actor handles.
|
||||
int newActorHandlesOffset
|
||||
= fbb.createString(IdUtil.concatIds(task.newActorHandles));
|
||||
|
||||
// Serialize args
|
||||
int[] argsOffsets = new int[task.args.length];
|
||||
for (int i = 0; i < argsOffsets.length; i++) {
|
||||
int objectIdOffset = 0;
|
||||
int dataOffset = 0;
|
||||
if (task.args[i].id != null) {
|
||||
objectIdOffset = fbb.createString(
|
||||
IdUtil.concatIds(new ObjectId[]{task.args[i].id}));
|
||||
} else {
|
||||
objectIdOffset = fbb.createString("");
|
||||
}
|
||||
if (task.args[i].data != null) {
|
||||
dataOffset = fbb.createString(ByteBuffer.wrap(task.args[i].data));
|
||||
}
|
||||
argsOffsets[i] = Arg.createArg(fbb, objectIdOffset, dataOffset);
|
||||
}
|
||||
int argsOffset = fbb.createVectorOfTables(argsOffsets);
|
||||
|
||||
// Serialize required resources
|
||||
// The required_resources vector indicates the quantities of the different
|
||||
// resources required by this task. The index in this vector corresponds to
|
||||
// the resource type defined in the ResourceIndex enum. For example,
|
||||
int[] requiredResourcesOffsets = new int[task.resources.size()];
|
||||
int i = 0;
|
||||
for (Map.Entry<String, Double> entry : task.resources.entrySet()) {
|
||||
int keyOffset = fbb.createString(ByteBuffer.wrap(entry.getKey().getBytes()));
|
||||
requiredResourcesOffsets[i++] =
|
||||
ResourcePair.createResourcePair(fbb, keyOffset, entry.getValue());
|
||||
}
|
||||
int requiredResourcesOffset = fbb.createVectorOfTables(requiredResourcesOffsets);
|
||||
|
||||
int[] requiredPlacementResourcesOffsets = new int[0];
|
||||
int requiredPlacementResourcesOffset =
|
||||
fbb.createVectorOfTables(requiredPlacementResourcesOffsets);
|
||||
|
||||
int language;
|
||||
int functionDescriptorOffset;
|
||||
// Set args
|
||||
builder.addAllArgs(
|
||||
Arrays.stream(task.args).map(arg -> {
|
||||
Common.TaskArg.Builder argBuilder = Common.TaskArg.newBuilder();
|
||||
if (arg.id != null) {
|
||||
argBuilder.addObjectIds(ByteString.copyFrom(arg.id.getBytes())).build();
|
||||
} else {
|
||||
argBuilder.setData(ByteString.copyFrom(arg.data)).build();
|
||||
}
|
||||
return argBuilder.build();
|
||||
}).collect(Collectors.toList())
|
||||
);
|
||||
|
||||
// Set function descriptor and language.
|
||||
if (task.language == TaskLanguage.JAVA) {
|
||||
// This is a Java task.
|
||||
language = Language.JAVA;
|
||||
int[] functionDescriptorOffsets = new int[]{
|
||||
fbb.createString(task.getJavaFunctionDescriptor().className),
|
||||
fbb.createString(task.getJavaFunctionDescriptor().name),
|
||||
fbb.createString(task.getJavaFunctionDescriptor().typeDescriptor)
|
||||
};
|
||||
functionDescriptorOffset = fbb.createVectorOfTables(functionDescriptorOffsets);
|
||||
builder.setLanguage(Common.Language.JAVA);
|
||||
builder.addAllFunctionDescriptor(ImmutableList.of(
|
||||
ByteString.copyFrom(task.getJavaFunctionDescriptor().className.getBytes()),
|
||||
ByteString.copyFrom(task.getJavaFunctionDescriptor().name.getBytes()),
|
||||
ByteString.copyFrom(task.getJavaFunctionDescriptor().typeDescriptor.getBytes())
|
||||
));
|
||||
} else {
|
||||
// This is a Python task.
|
||||
language = Language.PYTHON;
|
||||
int[] functionDescriptorOffsets = new int[]{
|
||||
fbb.createString(task.getPyFunctionDescriptor().moduleName),
|
||||
fbb.createString(task.getPyFunctionDescriptor().className),
|
||||
fbb.createString(task.getPyFunctionDescriptor().functionName),
|
||||
fbb.createString("")
|
||||
};
|
||||
functionDescriptorOffset = fbb.createVectorOfTables(functionDescriptorOffsets);
|
||||
builder.setLanguage(Common.Language.PYTHON);
|
||||
builder.addAllFunctionDescriptor(ImmutableList.of(
|
||||
ByteString.copyFrom(task.getPyFunctionDescriptor().moduleName.getBytes()),
|
||||
ByteString.copyFrom(task.getPyFunctionDescriptor().className.getBytes()),
|
||||
ByteString.copyFrom(task.getPyFunctionDescriptor().functionName.getBytes()),
|
||||
ByteString.EMPTY
|
||||
));
|
||||
}
|
||||
|
||||
int [] dynamicWorkerOptionsOffsets = new int[task.dynamicWorkerOptions.size()];
|
||||
for (int index = 0; index < task.dynamicWorkerOptions.size(); ++index) {
|
||||
dynamicWorkerOptionsOffsets[index] = fbb.createString(task.dynamicWorkerOptions.get(index));
|
||||
if (!task.actorCreationId.isNil()) {
|
||||
// Actor creation task.
|
||||
builder.setType(TaskType.ACTOR_CREATION_TASK);
|
||||
builder.setActorCreationTaskSpec(
|
||||
Common.ActorCreationTaskSpec.newBuilder()
|
||||
.setActorId(ByteString.copyFrom(task.actorCreationId.getBytes()))
|
||||
.setMaxActorReconstructions(task.maxActorReconstructions)
|
||||
.addAllDynamicWorkerOptions(task.dynamicWorkerOptions)
|
||||
);
|
||||
} else if (!task.actorId.isNil()) {
|
||||
// Actor task.
|
||||
builder.setType(TaskType.ACTOR_TASK);
|
||||
List<ByteString> newHandles = Arrays.stream(task.newActorHandles)
|
||||
.map(id -> ByteString.copyFrom(id.getBytes())).collect(Collectors.toList());
|
||||
builder.setActorTaskSpec(
|
||||
Common.ActorTaskSpec.newBuilder()
|
||||
.setActorId(ByteString.copyFrom(task.actorId.getBytes()))
|
||||
.setActorHandleId(ByteString.copyFrom(task.actorHandleId.getBytes()))
|
||||
.setActorCreationDummyObjectId(ByteString.copyFrom(task.actorId.getBytes()))
|
||||
.setActorCounter(task.actorCounter)
|
||||
.addAllNewActorHandles(newHandles)
|
||||
);
|
||||
} else {
|
||||
// Normal task.
|
||||
builder.setType(TaskType.NORMAL_TASK);
|
||||
}
|
||||
int dynamicWorkerOptionsOffset = fbb.createVectorOfTables(dynamicWorkerOptionsOffsets);
|
||||
|
||||
int root = TaskInfo.createTaskInfo(
|
||||
fbb,
|
||||
jobIdOffset,
|
||||
taskIdOffset,
|
||||
parentTaskIdOffset,
|
||||
parentCounter,
|
||||
actorCreateIdOffset,
|
||||
actorCreateDummyIdOffset,
|
||||
maxActorReconstructions,
|
||||
actorIdOffset,
|
||||
actorHandleIdOffset,
|
||||
actorCounter,
|
||||
newActorHandlesOffset,
|
||||
argsOffset,
|
||||
numReturnsOffset,
|
||||
requiredResourcesOffset,
|
||||
requiredPlacementResourcesOffset,
|
||||
language,
|
||||
functionDescriptorOffset,
|
||||
dynamicWorkerOptionsOffset);
|
||||
fbb.finish(root);
|
||||
ByteBuffer buffer = fbb.dataBuffer();
|
||||
|
||||
if (buffer.remaining() > TASK_SPEC_BUFFER_SIZE) {
|
||||
LOGGER.error(
|
||||
"Allocated buffer is not enough to transfer the task specification: {} vs {}",
|
||||
TASK_SPEC_BUFFER_SIZE, buffer.remaining());
|
||||
throw new RuntimeException("Allocated buffer is not enough to transfer to task.");
|
||||
}
|
||||
return buffer;
|
||||
return builder.build().toByteArray();
|
||||
}
|
||||
|
||||
public void setResource(String resourceName, double capacity, UniqueId nodeId) {
|
||||
@@ -344,10 +309,9 @@ public class RayletClientImpl implements RayletClient {
|
||||
private static native long nativeInit(String localSchedulerSocket, byte[] workerId,
|
||||
boolean isWorker, byte[] driverTaskId);
|
||||
|
||||
private static native void nativeSubmitTask(long client, byte[] cursorId, ByteBuffer taskBuff,
|
||||
int pos, int taskSize) throws RayException;
|
||||
private static native void nativeSubmitTask(long client, byte[] cursorId, byte[] taskSpec)
|
||||
throws RayException;
|
||||
|
||||
// return TaskInfo (in FlatBuffer)
|
||||
private static native byte[] nativeGetTask(long client) throws RayException;
|
||||
|
||||
private static native void nativeDestroy(long client) throws RayException;
|
||||
@@ -375,5 +339,5 @@ public class RayletClientImpl implements RayletClient {
|
||||
byte[] checkpointId);
|
||||
|
||||
private static native void nativeSetResource(long conn, String resourceName, double capacity,
|
||||
byte[] nodeId) throws RayException;
|
||||
byte[] nodeId) throws RayException;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user