mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:53:14 +08:00
[Java] Avoid data copy from C++ to Java for ByteBuffer type (#9033)
This commit is contained in:
@@ -7,6 +7,7 @@ import io.ray.api.exception.UnreconstructableException;
|
||||
import io.ray.api.id.ObjectId;
|
||||
import io.ray.runtime.generated.Gcs.ErrorType;
|
||||
import io.ray.runtime.serializer.Serializer;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.Arrays;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
|
||||
@@ -45,6 +46,9 @@ public class ObjectSerializer {
|
||||
if (meta != null && meta.length > 0) {
|
||||
// If meta is not null, deserialize the object from meta.
|
||||
if (Arrays.equals(meta, OBJECT_METADATA_TYPE_RAW)) {
|
||||
if (objectType == ByteBuffer.class) {
|
||||
return ByteBuffer.wrap(data);
|
||||
}
|
||||
return data;
|
||||
} else if (Arrays.equals(meta, OBJECT_METADATA_TYPE_CROSS_LANGUAGE) ||
|
||||
Arrays.equals(meta, OBJECT_METADATA_TYPE_JAVA)) {
|
||||
@@ -81,6 +85,17 @@ public class ObjectSerializer {
|
||||
// If the object is a byte array, skip serializing it and use a special metadata to
|
||||
// indicate it's raw binary. So that this object can also be read by Python.
|
||||
return new NativeRayObject((byte[]) object, OBJECT_METADATA_TYPE_RAW);
|
||||
} else if (object instanceof ByteBuffer) {
|
||||
// Serialize ByteBuffer to raw bytes.
|
||||
ByteBuffer buffer = (ByteBuffer) object;
|
||||
byte[] bytes;
|
||||
if (buffer.hasArray()) {
|
||||
bytes = buffer.array();
|
||||
} else {
|
||||
bytes = new byte[buffer.remaining()];
|
||||
buffer.get(bytes);
|
||||
}
|
||||
return new NativeRayObject(bytes, OBJECT_METADATA_TYPE_RAW);
|
||||
} else if (object instanceof RayTaskException) {
|
||||
byte[] serializedBytes = Serializer.encode(object).getLeft();
|
||||
return new NativeRayObject(serializedBytes, TASK_EXECUTION_EXCEPTION_META);
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package io.ray.runtime.task;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import io.ray.api.ObjectRef;
|
||||
import io.ray.api.Ray;
|
||||
import io.ray.api.id.ObjectId;
|
||||
@@ -7,6 +8,7 @@ import io.ray.runtime.RayRuntimeInternal;
|
||||
import io.ray.runtime.generated.Common.Language;
|
||||
import io.ray.runtime.object.NativeRayObject;
|
||||
import io.ray.runtime.object.ObjectSerializer;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
@@ -68,12 +70,19 @@ public class ArgumentsBuilder {
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert list of NativeRayObject to real function arguments.
|
||||
* Convert list of NativeRayObject/ByteBuffer to real function arguments.
|
||||
*/
|
||||
public static Object[] unwrap(List<NativeRayObject> args, Class<?>[] types) {
|
||||
public static Object[] unwrap(List<Object> args, Class<?>[] types) {
|
||||
Object[] realArgs = new Object[args.size()];
|
||||
for (int i = 0; i < args.size(); i++) {
|
||||
realArgs[i] = ObjectSerializer.deserialize(args.get(i), null, types[i]);
|
||||
Object arg = args.get(i);
|
||||
Preconditions.checkState(arg instanceof ByteBuffer || arg instanceof NativeRayObject);
|
||||
if (arg instanceof ByteBuffer) {
|
||||
Preconditions.checkState(types[i] == ByteBuffer.class);
|
||||
realArgs[i] = arg;
|
||||
} else {
|
||||
realArgs[i] = ObjectSerializer.deserialize((NativeRayObject) arg, null, types[i]);
|
||||
}
|
||||
}
|
||||
return realArgs;
|
||||
}
|
||||
|
||||
@@ -311,8 +311,9 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
|
||||
? ((LocalModeTaskExecutor.LocalActorContext) actorContext).getWorkerId()
|
||||
: UniqueId.randomId();
|
||||
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentWorkerId(workerId);
|
||||
List<NativeRayObject> returnObjects = taskExecutor
|
||||
.execute(getJavaFunctionDescriptor(taskSpec).toList(), args);
|
||||
List<String> rayFunctionInfo = getJavaFunctionDescriptor(taskSpec).toList();
|
||||
taskExecutor.checkByteBufferArguments(rayFunctionInfo);
|
||||
List<NativeRayObject> returnObjects = taskExecutor.execute(rayFunctionInfo, args);
|
||||
if (taskSpec.getType() == TaskType.ACTOR_CREATION_TASK) {
|
||||
// Update actor context map ASAP in case objectStore.putRaw triggered the next actor task
|
||||
// on this actor.
|
||||
|
||||
@@ -13,6 +13,7 @@ import io.ray.runtime.generated.Common.TaskType;
|
||||
import io.ray.runtime.object.NativeRayObject;
|
||||
import io.ray.runtime.object.ObjectSerializer;
|
||||
import java.lang.reflect.InvocationTargetException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
@@ -30,6 +31,8 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
|
||||
|
||||
private final ConcurrentHashMap<UniqueId, T> actorContextMap = new ConcurrentHashMap<>();
|
||||
|
||||
private final ThreadLocal<RayFunction> localRayFunction = new ThreadLocal<>();
|
||||
|
||||
static class ActorContext {
|
||||
|
||||
/**
|
||||
@@ -61,10 +64,34 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
|
||||
this.actorContextMap.put(runtime.getWorkerContext().getCurrentWorkerId(), actorContext);
|
||||
}
|
||||
|
||||
protected List<NativeRayObject> execute(List<String> rayFunctionInfo,
|
||||
List<NativeRayObject> argsBytes) {
|
||||
runtime.setIsContextSet(true);
|
||||
private RayFunction getRayFunction(List<String> rayFunctionInfo) {
|
||||
JobId jobId = runtime.getWorkerContext().getCurrentJobId();
|
||||
JavaFunctionDescriptor functionDescriptor = parseFunctionDescriptor(rayFunctionInfo);
|
||||
return runtime.getFunctionManager().getFunction(jobId, functionDescriptor);
|
||||
}
|
||||
|
||||
/**
|
||||
* The return value indicates which parameters are ByteBuffer.
|
||||
*/
|
||||
protected boolean[] checkByteBufferArguments(List<String> rayFunctionInfo) {
|
||||
localRayFunction.set(null);
|
||||
try {
|
||||
localRayFunction.set(getRayFunction(rayFunctionInfo));
|
||||
} catch (Throwable e) {
|
||||
// Ignore the exception.
|
||||
return null;
|
||||
}
|
||||
Class<?>[] types = localRayFunction.get().executable.getParameterTypes();
|
||||
boolean[] results = new boolean[types.length];
|
||||
for (int i = 0; i < types.length; i++) {
|
||||
results[i] = types[i] == ByteBuffer.class;
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
protected List<NativeRayObject> execute(List<String> rayFunctionInfo,
|
||||
List<Object> argsBytes) {
|
||||
runtime.setIsContextSet(true);
|
||||
TaskType taskType = runtime.getWorkerContext().getCurrentTaskType();
|
||||
TaskId taskId = runtime.getWorkerContext().getCurrentTaskId();
|
||||
LOGGER.debug("Executing task {}", taskId);
|
||||
@@ -80,11 +107,14 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
|
||||
|
||||
List<NativeRayObject> returnObjects = new ArrayList<>();
|
||||
ClassLoader oldLoader = Thread.currentThread().getContextClassLoader();
|
||||
JavaFunctionDescriptor functionDescriptor = parseFunctionDescriptor(rayFunctionInfo);
|
||||
RayFunction rayFunction = null;
|
||||
RayFunction rayFunction = localRayFunction.get();
|
||||
try {
|
||||
// Find the executable object.
|
||||
rayFunction = runtime.getFunctionManager().getFunction(jobId, functionDescriptor);
|
||||
if (rayFunction == null) {
|
||||
// Failed to get RayFunction in checkByteBufferArguments. Redo here to throw
|
||||
// the exception again.
|
||||
rayFunction = getRayFunction(rayFunctionInfo);
|
||||
}
|
||||
Thread.currentThread().setContextClassLoader(rayFunction.classLoader);
|
||||
runtime.getWorkerContext().setCurrentClassLoader(rayFunction.classLoader);
|
||||
|
||||
@@ -132,7 +162,7 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
|
||||
LOGGER.error("Error executing task " + taskId, e);
|
||||
if (taskType != TaskType.ACTOR_CREATION_TASK) {
|
||||
boolean hasReturn = rayFunction != null && rayFunction.hasReturn();
|
||||
boolean isCrossLanguage = functionDescriptor.signature.equals("");
|
||||
boolean isCrossLanguage = parseFunctionDescriptor(rayFunctionInfo).signature.equals("");
|
||||
if (hasReturn || isCrossLanguage) {
|
||||
returnObjects.add(ObjectSerializer
|
||||
.serialize(new RayTaskException("Error executing task " + taskId, e)));
|
||||
|
||||
@@ -4,6 +4,8 @@ import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import io.ray.api.Ray;
|
||||
import io.ray.api.id.ObjectId;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.testng.Assert;
|
||||
@@ -63,6 +65,10 @@ public class RayCallTest extends BaseTest {
|
||||
TestUtils.getRuntime().getObjectStore().put(1, objectId);
|
||||
}
|
||||
|
||||
private static ByteBuffer testByteBuffer(ByteBuffer buffer) {
|
||||
return buffer;
|
||||
}
|
||||
|
||||
/**
|
||||
* Test calling and returning different types.
|
||||
*/
|
||||
@@ -82,6 +88,11 @@ public class RayCallTest extends BaseTest {
|
||||
Assert.assertEquals(map, Ray.task(RayCallTest::testMap, map).remote().get());
|
||||
TestUtils.LargeObject largeObject = new TestUtils.LargeObject();
|
||||
Assert.assertNotNull(Ray.task(RayCallTest::testLargeObject, largeObject).remote().get());
|
||||
ByteBuffer buffer1 = ByteBuffer.wrap("foo".getBytes(StandardCharsets.UTF_8));
|
||||
ByteBuffer buffer2 = Ray.task(RayCallTest::testByteBuffer, buffer1).remote().get();
|
||||
byte[] bytes = new byte[buffer2.remaining()];
|
||||
buffer2.get(bytes);
|
||||
Assert.assertEquals("foo", new String(bytes, StandardCharsets.UTF_8));
|
||||
|
||||
// TODO(edoakes): this test doesn't work now that we've switched to direct call
|
||||
// mode. To make it work, we need to implement the same protocol for resolving
|
||||
|
||||
Reference in New Issue
Block a user