[Java] Avoid data copy from C++ to Java for ByteBuffer type (#9033)

This commit is contained in:
Kai Yang
2020-07-22 16:25:32 +08:00
committed by GitHub
parent 6346c70792
commit bfa0605282
9 changed files with 125 additions and 16 deletions
@@ -7,6 +7,7 @@ import io.ray.api.exception.UnreconstructableException;
import io.ray.api.id.ObjectId;
import io.ray.runtime.generated.Gcs.ErrorType;
import io.ray.runtime.serializer.Serializer;
import java.nio.ByteBuffer;
import java.util.Arrays;
import org.apache.commons.lang3.tuple.Pair;
@@ -45,6 +46,9 @@ public class ObjectSerializer {
if (meta != null && meta.length > 0) {
// If meta is not null, deserialize the object from meta.
if (Arrays.equals(meta, OBJECT_METADATA_TYPE_RAW)) {
if (objectType == ByteBuffer.class) {
return ByteBuffer.wrap(data);
}
return data;
} else if (Arrays.equals(meta, OBJECT_METADATA_TYPE_CROSS_LANGUAGE) ||
Arrays.equals(meta, OBJECT_METADATA_TYPE_JAVA)) {
@@ -81,6 +85,17 @@ public class ObjectSerializer {
// If the object is a byte array, skip serializing it and use a special metadata to
// indicate it's raw binary. So that this object can also be read by Python.
return new NativeRayObject((byte[]) object, OBJECT_METADATA_TYPE_RAW);
} else if (object instanceof ByteBuffer) {
  // Serialize the readable window (position..limit) of the ByteBuffer to raw bytes.
  // A duplicate is used so that the caller's position/limit are never modified.
  ByteBuffer buffer = ((ByteBuffer) object).duplicate();
  byte[] bytes;
  if (buffer.hasArray() && buffer.arrayOffset() == 0 && buffer.position() == 0
      && buffer.remaining() == buffer.array().length) {
    // The readable window is exactly the whole backing array, so it is safe to use the
    // array directly without copying.
    bytes = buffer.array();
  } else {
    // Sliced, partially consumed, or array-less (e.g. direct/read-only) buffer: array()
    // would expose bytes outside the window or is unavailable, so copy only what remains.
    bytes = new byte[buffer.remaining()];
    buffer.get(bytes);
  }
  return new NativeRayObject(bytes, OBJECT_METADATA_TYPE_RAW);
} else if (object instanceof RayTaskException) {
byte[] serializedBytes = Serializer.encode(object).getLeft();
return new NativeRayObject(serializedBytes, TASK_EXECUTION_EXCEPTION_META);
@@ -1,5 +1,6 @@
package io.ray.runtime.task;
import com.google.common.base.Preconditions;
import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import io.ray.api.id.ObjectId;
@@ -7,6 +8,7 @@ import io.ray.runtime.RayRuntimeInternal;
import io.ray.runtime.generated.Common.Language;
import io.ray.runtime.object.NativeRayObject;
import io.ray.runtime.object.ObjectSerializer;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
@@ -68,12 +70,19 @@ public class ArgumentsBuilder {
}
/**
* Convert list of NativeRayObject to real function arguments.
* Convert list of NativeRayObject/ByteBuffer to real function arguments.
*/
public static Object[] unwrap(List<NativeRayObject> args, Class<?>[] types) {
public static Object[] unwrap(List<Object> args, Class<?>[] types) {
Object[] realArgs = new Object[args.size()];
for (int i = 0; i < args.size(); i++) {
// NOTE(review): the next statement looks like the pre-change line kept by the diff rendering;
// the ByteBuffer/NativeRayObject branch below appears to supersede it -- confirm against the
// actual file.
realArgs[i] = ObjectSerializer.deserialize(args.get(i), null, types[i]);
Object arg = args.get(i);
// Every argument must be either a ByteBuffer or a NativeRayObject.
Preconditions.checkState(arg instanceof ByteBuffer || arg instanceof NativeRayObject);
if (arg instanceof ByteBuffer) {
// A ByteBuffer argument is handed over as-is, so the declared parameter type must be
// exactly ByteBuffer.
Preconditions.checkState(types[i] == ByteBuffer.class);
realArgs[i] = arg;
} else {
// Any other argument is deserialized from its NativeRayObject form into the declared type.
realArgs[i] = ObjectSerializer.deserialize((NativeRayObject) arg, null, types[i]);
}
}
return realArgs;
}
@@ -311,8 +311,9 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
? ((LocalModeTaskExecutor.LocalActorContext) actorContext).getWorkerId()
: UniqueId.randomId();
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentWorkerId(workerId);
List<NativeRayObject> returnObjects = taskExecutor
.execute(getJavaFunctionDescriptor(taskSpec).toList(), args);
List<String> rayFunctionInfo = getJavaFunctionDescriptor(taskSpec).toList();
taskExecutor.checkByteBufferArguments(rayFunctionInfo);
List<NativeRayObject> returnObjects = taskExecutor.execute(rayFunctionInfo, args);
if (taskSpec.getType() == TaskType.ACTOR_CREATION_TASK) {
// Update actor context map ASAP in case objectStore.putRaw triggered the next actor task
// on this actor.
@@ -13,6 +13,7 @@ import io.ray.runtime.generated.Common.TaskType;
import io.ray.runtime.object.NativeRayObject;
import io.ray.runtime.object.ObjectSerializer;
import java.lang.reflect.InvocationTargetException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
@@ -30,6 +31,8 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
private final ConcurrentHashMap<UniqueId, T> actorContextMap = new ConcurrentHashMap<>();
private final ThreadLocal<RayFunction> localRayFunction = new ThreadLocal<>();
static class ActorContext {
/**
@@ -61,10 +64,34 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
this.actorContextMap.put(runtime.getWorkerContext().getCurrentWorkerId(), actorContext);
}
protected List<NativeRayObject> execute(List<String> rayFunctionInfo,
List<NativeRayObject> argsBytes) {
runtime.setIsContextSet(true);
/**
 * Looks up the {@link RayFunction} identified by {@code rayFunctionInfo}, using the job id of
 * the current worker context.
 */
private RayFunction getRayFunction(List<String> rayFunctionInfo) {
  final JobId currentJobId = runtime.getWorkerContext().getCurrentJobId();
  final JavaFunctionDescriptor descriptor = parseFunctionDescriptor(rayFunctionInfo);
  return runtime.getFunctionManager().getFunction(currentJobId, descriptor);
}
/**
 * Resolves the function identified by {@code rayFunctionInfo}, stores it in the thread-local
 * function slot, and reports which of its parameters are declared as {@link ByteBuffer}.
 *
 * @param rayFunctionInfo the list handed to {@code parseFunctionDescriptor} to identify the
 *     function
 * @return one flag per declared parameter, {@code true} where the parameter type is exactly
 *     {@code ByteBuffer}; {@code null} if resolving the function threw
 */
protected boolean[] checkByteBufferArguments(List<String> rayFunctionInfo) {
  // Reset the slot first so a function cached by an earlier call is never left behind.
  localRayFunction.set(null);
  final RayFunction resolved;
  try {
    resolved = getRayFunction(rayFunctionInfo);
    localRayFunction.set(resolved);
  } catch (Throwable ignored) {
    // Deliberately swallowed: the slot stays null and the caller receives null.
    return null;
  }
  final Class<?>[] parameterTypes = resolved.executable.getParameterTypes();
  final boolean[] byteBufferFlags = new boolean[parameterTypes.length];
  int index = 0;
  for (Class<?> parameterType : parameterTypes) {
    byteBufferFlags[index++] = parameterType == ByteBuffer.class;
  }
  return byteBufferFlags;
}
protected List<NativeRayObject> execute(List<String> rayFunctionInfo,
List<Object> argsBytes) {
runtime.setIsContextSet(true);
TaskType taskType = runtime.getWorkerContext().getCurrentTaskType();
TaskId taskId = runtime.getWorkerContext().getCurrentTaskId();
LOGGER.debug("Executing task {}", taskId);
@@ -80,11 +107,14 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
List<NativeRayObject> returnObjects = new ArrayList<>();
ClassLoader oldLoader = Thread.currentThread().getContextClassLoader();
JavaFunctionDescriptor functionDescriptor = parseFunctionDescriptor(rayFunctionInfo);
RayFunction rayFunction = null;
RayFunction rayFunction = localRayFunction.get();
try {
// Find the executable object.
rayFunction = runtime.getFunctionManager().getFunction(jobId, functionDescriptor);
if (rayFunction == null) {
// Failed to get RayFunction in checkByteBufferArguments. Redo here to throw
// the exception again.
rayFunction = getRayFunction(rayFunctionInfo);
}
Thread.currentThread().setContextClassLoader(rayFunction.classLoader);
runtime.getWorkerContext().setCurrentClassLoader(rayFunction.classLoader);
@@ -132,7 +162,7 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
LOGGER.error("Error executing task " + taskId, e);
if (taskType != TaskType.ACTOR_CREATION_TASK) {
boolean hasReturn = rayFunction != null && rayFunction.hasReturn();
boolean isCrossLanguage = functionDescriptor.signature.equals("");
boolean isCrossLanguage = parseFunctionDescriptor(rayFunctionInfo).signature.equals("");
if (hasReturn || isCrossLanguage) {
returnObjects.add(ObjectSerializer
.serialize(new RayTaskException("Error executing task " + taskId, e)));
@@ -4,6 +4,8 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.ray.api.Ray;
import io.ray.api.id.ObjectId;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import org.testng.Assert;
@@ -63,6 +65,10 @@ public class RayCallTest extends BaseTest {
TestUtils.getRuntime().getObjectStore().put(1, objectId);
}
/**
 * Returns the given buffer unchanged. Used as the remote task target for the ByteBuffer
 * round-trip check in RayCallTest.
 */
private static ByteBuffer testByteBuffer(ByteBuffer buffer) {
return buffer;
}
/**
* Test calling and returning different types.
*/
@@ -82,6 +88,11 @@ public class RayCallTest extends BaseTest {
Assert.assertEquals(map, Ray.task(RayCallTest::testMap, map).remote().get());
TestUtils.LargeObject largeObject = new TestUtils.LargeObject();
Assert.assertNotNull(Ray.task(RayCallTest::testLargeObject, largeObject).remote().get());
ByteBuffer buffer1 = ByteBuffer.wrap("foo".getBytes(StandardCharsets.UTF_8));
ByteBuffer buffer2 = Ray.task(RayCallTest::testByteBuffer, buffer1).remote().get();
byte[] bytes = new byte[buffer2.remaining()];
buffer2.get(bytes);
Assert.assertEquals("foo", new String(bytes, StandardCharsets.UTF_8));
// TODO(edoakes): this test doesn't work now that we've switched to direct call
// mode. To make it work, we need to implement the same protocol for resolving