Define common data structures with protobuf. (#5121)

This commit is contained in:
Hao Chen
2019-07-08 22:41:37 +08:00
committed by GitHub
parent b4e51c8aa1
commit 8a30b93e42
64 changed files with 1233 additions and 1561 deletions
+21 -45
View File
@@ -1,4 +1,4 @@
load("//bazel:ray.bzl", "flatbuffer_java_library", "define_java_module")
load("//bazel:ray.bzl", "define_java_module")
load("@build_stack_rules_proto//java:java_proto_compile.bzl", "java_proto_compile")
exports_files([
@@ -50,8 +50,10 @@ define_java_module(
define_java_module(
name = "runtime",
additional_srcs = [
":generate_java_gcs_fbs",
":gcs_java_proto",
":all_java_proto",
],
exclude_srcs = [
"runtime/src/main/java/org/ray/runtime/generated/*.java",
],
additional_resources = [
":java_native_deps",
@@ -68,7 +70,6 @@ define_java_module(
deps = [
":org_ray_ray_api",
"@plasma//:org_apache_arrow_arrow_plasma",
"@maven//:com_github_davidmoten_flatbuffers_java",
"@maven//:com_google_guava_guava",
"@maven//:com_google_protobuf_protobuf_java",
"@maven//:com_typesafe_config",
@@ -151,39 +152,22 @@ java_binary(
],
)
java_proto_compile(
name = "common_java_proto",
deps = ["@//:common_proto"],
)
java_proto_compile(
name = "gcs_java_proto",
deps = ["@//:gcs_proto"],
)
flatbuffers_generated_files = [
"Arg.java",
"Language.java",
"TaskInfo.java",
"ResourcePair.java",
]
flatbuffer_java_library(
name = "java_gcs_fbs",
srcs = ["//:gcs_fbs_file"],
outs = flatbuffers_generated_files,
out_prefix = "",
)
genrule(
name = "generate_java_gcs_fbs",
srcs = [":java_gcs_fbs"],
outs = [
"runtime/src/main/java/org/ray/runtime/generated/" + file for file in flatbuffers_generated_files
filegroup(
name = "all_java_proto",
srcs = [
":common_java_proto",
":gcs_java_proto",
],
cmd = """
for f in $(locations //java:java_gcs_fbs); do
chmod +w $$f
mv -f $$f $(@D)/runtime/src/main/java/org/ray/runtime/generated
done
python $$(pwd)/java/modify_generated_java_flatbuffers_files.py $(@D)/..
""",
local = 1,
)
filegroup(
@@ -202,8 +186,7 @@ filegroup(
genrule(
name = "gen_maven_deps",
srcs = [
":gcs_java_proto",
":generate_java_gcs_fbs",
":all_java_proto",
":java_native_deps",
":copy_pom_file",
"@plasma//:org_apache_arrow_arrow_plasma",
@@ -212,6 +195,11 @@ genrule(
cmd = """
set -x
WORK_DIR=$$(pwd)
# Copy protobuf-generated files.
rm -rf $$WORK_DIR/java/runtime/src/main/java/org/ray/runtime/generated
for f in $(locations //java:all_java_proto); do
unzip $$f -x META-INF/MANIFEST.MF -d $$WORK_DIR/java/runtime/src/main/java
done
# Copy native dependecies.
NATIVE_DEPS_DIR=$$WORK_DIR/java/runtime/native_dependencies/
rm -rf $$NATIVE_DEPS_DIR
@@ -220,18 +208,6 @@ genrule(
chmod +w $$f
cp $$f $$NATIVE_DEPS_DIR
done
# Copy protobuf-generated files.
GENERATED_DIR=$$WORK_DIR/java/runtime/src/main/java/org/ray/runtime/generated
rm -rf $$GENERATED_DIR
mkdir -p $$GENERATED_DIR
for f in $(locations //java:gcs_java_proto); do
unzip $$f
mv org/ray/runtime/generated/* $$GENERATED_DIR
done
# Copy flatbuffers-generated files
for f in $(locations //java:generate_java_gcs_fbs); do
cp $$f $$GENERATED_DIR
done
# Install plasma jar to local maven repo.
mvn install:install-file -Dfile=$(locations @plasma//:org_apache_arrow_arrow_plasma) -Dpackaging=jar \
-DgroupId=org.apache.arrow -DartifactId=arrow-plasma -Dversion=0.13.0-SNAPSHOT
-1
View File
@@ -4,7 +4,6 @@ def gen_java_deps():
maven_install(
artifacts = [
"com.beust:jcommander:1.72",
"com.github.davidmoten:flatbuffers-java:1.9.0.1",
"com.google.guava:guava:27.0.1-jre",
"com.google.protobuf:protobuf-java:3.8.0",
"com.puppycrawl.tools:checkstyle:8.15",
-5
View File
@@ -31,11 +31,6 @@
<artifactId>jcommander</artifactId>
<version>1.72</version>
</dependency>
<dependency>
<groupId>com.github.davidmoten</groupId>
<artifactId>flatbuffers-java</artifactId>
<version>1.9.0.1</version>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
@@ -1,13 +1,17 @@
package org.ray.runtime.raylet;
import com.google.common.base.Preconditions;
import com.google.flatbuffers.FlatBufferBuilder;
import com.google.common.collect.ImmutableList;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.ray.api.RayObject;
import org.ray.api.WaitResult;
import org.ray.api.exception.RayException;
@@ -15,10 +19,8 @@ import org.ray.api.id.ObjectId;
import org.ray.api.id.TaskId;
import org.ray.api.id.UniqueId;
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
import org.ray.runtime.generated.Arg;
import org.ray.runtime.generated.Language;
import org.ray.runtime.generated.ResourcePair;
import org.ray.runtime.generated.TaskInfo;
import org.ray.runtime.generated.Common;
import org.ray.runtime.generated.Common.TaskType;
import org.ray.runtime.task.FunctionArg;
import org.ray.runtime.task.TaskLanguage;
import org.ray.runtime.task.TaskSpec;
@@ -30,15 +32,6 @@ public class RayletClientImpl implements RayletClient {
private static final Logger LOGGER = LoggerFactory.getLogger(RayletClientImpl.class);
private static final int TASK_SPEC_BUFFER_SIZE = 2 * 1024 * 1024;
/**
* Direct buffers that are used to hold flatbuffer-serialized task specs.
*/
private static ThreadLocal<ByteBuffer> taskSpecBuffer = ThreadLocal.withInitial(() ->
ByteBuffer.allocateDirect(TASK_SPEC_BUFFER_SIZE).order(ByteOrder.LITTLE_ENDIAN)
);
/**
* The pointer to c++'s raylet client.
*/
@@ -86,21 +79,20 @@ public class RayletClientImpl implements RayletClient {
Preconditions.checkState(!spec.parentTaskId.isNil());
Preconditions.checkState(!spec.jobId.isNil());
ByteBuffer info = convertTaskSpecToFlatbuffer(spec);
byte[] taskSpec = convertTaskSpecToProtobuf(spec);
byte[] cursorId = null;
if (!spec.getExecutionDependencies().isEmpty()) {
//TODO(hchen): handle more than one dependencies.
cursorId = spec.getExecutionDependencies().get(0).getBytes();
}
nativeSubmitTask(client, cursorId, info, info.position(), info.remaining());
nativeSubmitTask(client, cursorId, taskSpec);
}
@Override
public TaskSpec getTask() {
byte[] bytes = nativeGetTask(client);
assert (null != bytes);
ByteBuffer bb = ByteBuffer.wrap(bytes);
return parseTaskSpecFromFlatbuffer(bb);
return parseTaskSpecFromProtobuf(bytes);
}
@Override
@@ -127,7 +119,7 @@ public class RayletClientImpl implements RayletClient {
@Override
public void freePlasmaObjects(List<ObjectId> objectIds, boolean localOnly,
boolean deleteCreatingTasks) {
boolean deleteCreatingTasks) {
byte[][] objectIdsArray = IdUtil.getIdBytes(objectIds);
nativeFreePlasmaObjects(client, objectIdsArray, localOnly, deleteCreatingTasks);
}
@@ -142,59 +134,76 @@ public class RayletClientImpl implements RayletClient {
nativeNotifyActorResumedFromCheckpoint(client, actorId.getBytes(), checkpointId.getBytes());
}
private static TaskSpec parseTaskSpecFromFlatbuffer(ByteBuffer bb) {
bb.order(ByteOrder.LITTLE_ENDIAN);
TaskInfo info = TaskInfo.getRootAsTaskInfo(bb);
UniqueId jobId = UniqueId.fromByteBuffer(info.jobIdAsByteBuffer());
TaskId taskId = TaskId.fromByteBuffer(info.taskIdAsByteBuffer());
TaskId parentTaskId = TaskId.fromByteBuffer(info.parentTaskIdAsByteBuffer());
int parentCounter = info.parentCounter();
UniqueId actorCreationId = UniqueId.fromByteBuffer(info.actorCreationIdAsByteBuffer());
int maxActorReconstructions = info.maxActorReconstructions();
UniqueId actorId = UniqueId.fromByteBuffer(info.actorIdAsByteBuffer());
UniqueId actorHandleId = UniqueId.fromByteBuffer(info.actorHandleIdAsByteBuffer());
int actorCounter = info.actorCounter();
int numReturns = info.numReturns();
/**
* Parse `TaskSpec` protobuf bytes.
*/
private static TaskSpec parseTaskSpecFromProtobuf(byte[] bytes) {
Common.TaskSpec taskSpec;
try {
taskSpec = Common.TaskSpec.parseFrom(bytes);
} catch (InvalidProtocolBufferException e) {
throw new RuntimeException("Invalid protobuf data.");
}
// Deserialize new actor handles
UniqueId[] newActorHandles = IdUtil.getUniqueIdsFromByteBuffer(
info.newActorHandlesAsByteBuffer());
// Parse common fields.
UniqueId jobId = UniqueId.fromByteBuffer(taskSpec.getJobId().asReadOnlyByteBuffer());
TaskId taskId = TaskId.fromByteBuffer(taskSpec.getTaskId().asReadOnlyByteBuffer());
TaskId parentTaskId = TaskId.fromByteBuffer(taskSpec.getParentTaskId().asReadOnlyByteBuffer());
int parentCounter = (int) taskSpec.getParentCounter();
int numReturns = (int) taskSpec.getNumReturns();
Map<String, Double> resources = taskSpec.getRequiredResourcesMap();
// Deserialize args
FunctionArg[] args = new FunctionArg[info.argsLength()];
for (int i = 0; i < info.argsLength(); i++) {
Arg arg = info.args(i);
int objectIdsLength = arg.objectIdsAsByteBuffer().remaining() / UniqueId.LENGTH;
// Parse args.
FunctionArg[] args = new FunctionArg[taskSpec.getArgsCount()];
for (int i = 0; i < args.length; i++) {
Common.TaskArg arg = taskSpec.getArgs(i);
int objectIdsLength = arg.getObjectIdsCount();
if (objectIdsLength > 0) {
Preconditions.checkArgument(objectIdsLength == 1,
"This arg has more than one id: {}", objectIdsLength);
args[i] = FunctionArg.passByReference(ObjectId.fromByteBuffer(arg.objectIdsAsByteBuffer()));
args[i] = FunctionArg
.passByReference(ObjectId.fromByteBuffer(arg.getObjectIds(0).asReadOnlyByteBuffer()));
} else {
ByteBuffer lbb = arg.dataAsByteBuffer();
Preconditions.checkState(lbb != null && lbb.remaining() > 0);
byte[] data = new byte[lbb.remaining()];
lbb.get(data);
args[i] = FunctionArg.passByValue(data);
args[i] = FunctionArg.passByValue(arg.getData().toByteArray());
}
}
// Deserialize required resources;
Map<String, Double> resources = new HashMap<>();
for (int i = 0; i < info.requiredResourcesLength(); i++) {
resources.put(info.requiredResources(i).key(), info.requiredResources(i).value());
}
// Deserialize function descriptor
Preconditions.checkArgument(info.language() == Language.JAVA);
Preconditions.checkArgument(info.functionDescriptorLength() == 3);
// Parse function descriptor
Preconditions.checkArgument(taskSpec.getLanguage() == Common.Language.JAVA);
Preconditions.checkArgument(taskSpec.getFunctionDescriptorCount() == 3);
JavaFunctionDescriptor functionDescriptor = new JavaFunctionDescriptor(
info.functionDescriptor(0), info.functionDescriptor(1), info.functionDescriptor(2)
taskSpec.getFunctionDescriptor(0).toString(Charset.defaultCharset()),
taskSpec.getFunctionDescriptor(1).toString(Charset.defaultCharset()),
taskSpec.getFunctionDescriptor(2).toString(Charset.defaultCharset())
);
// Deserialize dynamic worker options.
// Parse ActorCreationTaskSpec.
UniqueId actorCreationId = UniqueId.NIL;
int maxActorReconstructions = 0;
UniqueId[] newActorHandles = new UniqueId[0];
List<String> dynamicWorkerOptions = new ArrayList<>();
for (int i = 0; i < info.dynamicWorkerOptionsLength(); ++i) {
dynamicWorkerOptions.add(info.dynamicWorkerOptions(i));
if (taskSpec.getType() == Common.TaskType.ACTOR_CREATION_TASK) {
Common.ActorCreationTaskSpec actorCreationTaskSpec = taskSpec.getActorCreationTaskSpec();
actorCreationId = UniqueId
.fromByteBuffer(actorCreationTaskSpec.getActorId().asReadOnlyByteBuffer());
maxActorReconstructions = (int) actorCreationTaskSpec.getMaxActorReconstructions();
dynamicWorkerOptions = ImmutableList
.copyOf(actorCreationTaskSpec.getDynamicWorkerOptionsList());
}
// Parse ActorTaskSpec.
UniqueId actorId = UniqueId.NIL;
UniqueId actorHandleId = UniqueId.NIL;
int actorCounter = 0;
if (taskSpec.getType() == Common.TaskType.ACTOR_TASK) {
Common.ActorTaskSpec actorTaskSpec = taskSpec.getActorTaskSpec();
actorId = UniqueId.fromByteBuffer(actorTaskSpec.getActorId().asReadOnlyByteBuffer());
actorHandleId = UniqueId
.fromByteBuffer(actorTaskSpec.getActorHandleId().asReadOnlyByteBuffer());
actorCounter = (int) actorTaskSpec.getActorCounter();
newActorHandles = actorTaskSpec.getNewActorHandlesList().stream()
.map(byteString -> UniqueId.fromByteBuffer(byteString.asReadOnlyByteBuffer()))
.toArray(UniqueId[]::new);
}
return new TaskSpec(jobId, taskId, parentTaskId, parentCounter, actorCreationId,
@@ -202,122 +211,78 @@ public class RayletClientImpl implements RayletClient {
args, numReturns, resources, TaskLanguage.JAVA, functionDescriptor, dynamicWorkerOptions);
}
private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) {
ByteBuffer bb = taskSpecBuffer.get();
bb.clear();
/**
* Convert a `TaskSpec` to protobuf-serialized bytes.
*/
private static byte[] convertTaskSpecToProtobuf(TaskSpec task) {
// Set common fields.
Common.TaskSpec.Builder builder = Common.TaskSpec.newBuilder()
.setJobId(ByteString.copyFrom(task.jobId.getBytes()))
.setTaskId(ByteString.copyFrom(task.taskId.getBytes()))
.setParentTaskId(ByteString.copyFrom(task.parentTaskId.getBytes()))
.setParentCounter(task.parentCounter)
.setNumReturns(task.numReturns)
.putAllRequiredResources(task.resources);
FlatBufferBuilder fbb = new FlatBufferBuilder(bb);
final int jobIdOffset = fbb.createString(task.jobId.toByteBuffer());
final int taskIdOffset = fbb.createString(task.taskId.toByteBuffer());
final int parentTaskIdOffset = fbb.createString(task.parentTaskId.toByteBuffer());
final int parentCounter = task.parentCounter;
final int actorCreateIdOffset = fbb.createString(task.actorCreationId.toByteBuffer());
final int actorCreateDummyIdOffset = fbb.createString(task.actorId.toByteBuffer());
final int maxActorReconstructions = task.maxActorReconstructions;
final int actorIdOffset = fbb.createString(task.actorId.toByteBuffer());
final int actorHandleIdOffset = fbb.createString(task.actorHandleId.toByteBuffer());
final int actorCounter = task.actorCounter;
final int numReturnsOffset = task.numReturns;
// Serialize the new actor handles.
int newActorHandlesOffset
= fbb.createString(IdUtil.concatIds(task.newActorHandles));
// Serialize args
int[] argsOffsets = new int[task.args.length];
for (int i = 0; i < argsOffsets.length; i++) {
int objectIdOffset = 0;
int dataOffset = 0;
if (task.args[i].id != null) {
objectIdOffset = fbb.createString(
IdUtil.concatIds(new ObjectId[]{task.args[i].id}));
} else {
objectIdOffset = fbb.createString("");
}
if (task.args[i].data != null) {
dataOffset = fbb.createString(ByteBuffer.wrap(task.args[i].data));
}
argsOffsets[i] = Arg.createArg(fbb, objectIdOffset, dataOffset);
}
int argsOffset = fbb.createVectorOfTables(argsOffsets);
// Serialize required resources
// The required_resources vector indicates the quantities of the different
// resources required by this task. The index in this vector corresponds to
// the resource type defined in the ResourceIndex enum. For example,
int[] requiredResourcesOffsets = new int[task.resources.size()];
int i = 0;
for (Map.Entry<String, Double> entry : task.resources.entrySet()) {
int keyOffset = fbb.createString(ByteBuffer.wrap(entry.getKey().getBytes()));
requiredResourcesOffsets[i++] =
ResourcePair.createResourcePair(fbb, keyOffset, entry.getValue());
}
int requiredResourcesOffset = fbb.createVectorOfTables(requiredResourcesOffsets);
int[] requiredPlacementResourcesOffsets = new int[0];
int requiredPlacementResourcesOffset =
fbb.createVectorOfTables(requiredPlacementResourcesOffsets);
int language;
int functionDescriptorOffset;
// Set args
builder.addAllArgs(
Arrays.stream(task.args).map(arg -> {
Common.TaskArg.Builder argBuilder = Common.TaskArg.newBuilder();
if (arg.id != null) {
argBuilder.addObjectIds(ByteString.copyFrom(arg.id.getBytes())).build();
} else {
argBuilder.setData(ByteString.copyFrom(arg.data)).build();
}
return argBuilder.build();
}).collect(Collectors.toList())
);
// Set function descriptor and language.
if (task.language == TaskLanguage.JAVA) {
// This is a Java task.
language = Language.JAVA;
int[] functionDescriptorOffsets = new int[]{
fbb.createString(task.getJavaFunctionDescriptor().className),
fbb.createString(task.getJavaFunctionDescriptor().name),
fbb.createString(task.getJavaFunctionDescriptor().typeDescriptor)
};
functionDescriptorOffset = fbb.createVectorOfTables(functionDescriptorOffsets);
builder.setLanguage(Common.Language.JAVA);
builder.addAllFunctionDescriptor(ImmutableList.of(
ByteString.copyFrom(task.getJavaFunctionDescriptor().className.getBytes()),
ByteString.copyFrom(task.getJavaFunctionDescriptor().name.getBytes()),
ByteString.copyFrom(task.getJavaFunctionDescriptor().typeDescriptor.getBytes())
));
} else {
// This is a Python task.
language = Language.PYTHON;
int[] functionDescriptorOffsets = new int[]{
fbb.createString(task.getPyFunctionDescriptor().moduleName),
fbb.createString(task.getPyFunctionDescriptor().className),
fbb.createString(task.getPyFunctionDescriptor().functionName),
fbb.createString("")
};
functionDescriptorOffset = fbb.createVectorOfTables(functionDescriptorOffsets);
builder.setLanguage(Common.Language.PYTHON);
builder.addAllFunctionDescriptor(ImmutableList.of(
ByteString.copyFrom(task.getPyFunctionDescriptor().moduleName.getBytes()),
ByteString.copyFrom(task.getPyFunctionDescriptor().className.getBytes()),
ByteString.copyFrom(task.getPyFunctionDescriptor().functionName.getBytes()),
ByteString.EMPTY
));
}
int [] dynamicWorkerOptionsOffsets = new int[task.dynamicWorkerOptions.size()];
for (int index = 0; index < task.dynamicWorkerOptions.size(); ++index) {
dynamicWorkerOptionsOffsets[index] = fbb.createString(task.dynamicWorkerOptions.get(index));
if (!task.actorCreationId.isNil()) {
// Actor creation task.
builder.setType(TaskType.ACTOR_CREATION_TASK);
builder.setActorCreationTaskSpec(
Common.ActorCreationTaskSpec.newBuilder()
.setActorId(ByteString.copyFrom(task.actorCreationId.getBytes()))
.setMaxActorReconstructions(task.maxActorReconstructions)
.addAllDynamicWorkerOptions(task.dynamicWorkerOptions)
);
} else if (!task.actorId.isNil()) {
// Actor task.
builder.setType(TaskType.ACTOR_TASK);
List<ByteString> newHandles = Arrays.stream(task.newActorHandles)
.map(id -> ByteString.copyFrom(id.getBytes())).collect(Collectors.toList());
builder.setActorTaskSpec(
Common.ActorTaskSpec.newBuilder()
.setActorId(ByteString.copyFrom(task.actorId.getBytes()))
.setActorHandleId(ByteString.copyFrom(task.actorHandleId.getBytes()))
.setActorCreationDummyObjectId(ByteString.copyFrom(task.actorId.getBytes()))
.setActorCounter(task.actorCounter)
.addAllNewActorHandles(newHandles)
);
} else {
// Normal task.
builder.setType(TaskType.NORMAL_TASK);
}
int dynamicWorkerOptionsOffset = fbb.createVectorOfTables(dynamicWorkerOptionsOffsets);
int root = TaskInfo.createTaskInfo(
fbb,
jobIdOffset,
taskIdOffset,
parentTaskIdOffset,
parentCounter,
actorCreateIdOffset,
actorCreateDummyIdOffset,
maxActorReconstructions,
actorIdOffset,
actorHandleIdOffset,
actorCounter,
newActorHandlesOffset,
argsOffset,
numReturnsOffset,
requiredResourcesOffset,
requiredPlacementResourcesOffset,
language,
functionDescriptorOffset,
dynamicWorkerOptionsOffset);
fbb.finish(root);
ByteBuffer buffer = fbb.dataBuffer();
if (buffer.remaining() > TASK_SPEC_BUFFER_SIZE) {
LOGGER.error(
"Allocated buffer is not enough to transfer the task specification: {} vs {}",
TASK_SPEC_BUFFER_SIZE, buffer.remaining());
throw new RuntimeException("Allocated buffer is not enough to transfer to task.");
}
return buffer;
return builder.build().toByteArray();
}
public void setResource(String resourceName, double capacity, UniqueId nodeId) {
@@ -344,10 +309,9 @@ public class RayletClientImpl implements RayletClient {
private static native long nativeInit(String localSchedulerSocket, byte[] workerId,
boolean isWorker, byte[] driverTaskId);
private static native void nativeSubmitTask(long client, byte[] cursorId, ByteBuffer taskBuff,
int pos, int taskSize) throws RayException;
private static native void nativeSubmitTask(long client, byte[] cursorId, byte[] taskSpec)
throws RayException;
// return TaskInfo (in FlatBuffer)
private static native byte[] nativeGetTask(long client) throws RayException;
private static native void nativeDestroy(long client) throws RayException;
@@ -375,5 +339,5 @@ public class RayletClientImpl implements RayletClient {
byte[] checkpointId);
private static native void nativeSetResource(long conn, String resourceName, double capacity,
byte[] nodeId) throws RayException;
byte[] nodeId) throws RayException;
}