mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 22:36:53 +08:00
Support metadata for passing by value task arguments (#5527)
This commit is contained in:
@@ -176,8 +176,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
|
||||
private RayObject callNormalFunction(FunctionDescriptor functionDescriptor,
|
||||
Object[] args, int numReturns, CallOptions options) {
|
||||
List<FunctionArg> functionArgs = ArgumentsBuilder
|
||||
.wrap(args, functionDescriptor.getLanguage() != Language.JAVA);
|
||||
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args);
|
||||
List<ObjectId> returnIds = taskSubmitter.submitTask(functionDescriptor,
|
||||
functionArgs, numReturns, options);
|
||||
Preconditions.checkState(returnIds.size() == numReturns && returnIds.size() <= 1);
|
||||
@@ -190,8 +189,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
|
||||
private RayObject callActorFunction(RayActor rayActor,
|
||||
FunctionDescriptor functionDescriptor, Object[] args, int numReturns) {
|
||||
List<FunctionArg> functionArgs = ArgumentsBuilder
|
||||
.wrap(args, functionDescriptor.getLanguage() != Language.JAVA);
|
||||
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args);
|
||||
List<ObjectId> returnIds = taskSubmitter.submitActorTask(rayActor,
|
||||
functionDescriptor, functionArgs, numReturns, null);
|
||||
Preconditions.checkState(returnIds.size() == numReturns && returnIds.size() <= 1);
|
||||
@@ -204,14 +202,11 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
|
||||
private RayActor createActorImpl(FunctionDescriptor functionDescriptor,
|
||||
Object[] args, ActorCreationOptions options) {
|
||||
List<FunctionArg> functionArgs = ArgumentsBuilder
|
||||
.wrap(args, functionDescriptor.getLanguage() != Language.JAVA);
|
||||
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args);
|
||||
if (functionDescriptor.getLanguage() != Language.JAVA && options != null) {
|
||||
Preconditions.checkState(Strings.isNullOrEmpty(options.jvmOptions));
|
||||
}
|
||||
RayActor actor = taskSubmitter
|
||||
.createActor(functionDescriptor, functionArgs,
|
||||
options);
|
||||
RayActor actor = taskSubmitter.createActor(functionDescriptor, functionArgs, options);
|
||||
return actor;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package org.ray.runtime.object;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
|
||||
/**
|
||||
* Binary representation of ray object.
|
||||
* Binary representation of a ray object. See `RayObject` class in C++ for details.
|
||||
*/
|
||||
public class NativeRayObject {
|
||||
|
||||
@@ -9,8 +11,21 @@ public class NativeRayObject {
|
||||
public byte[] metadata;
|
||||
|
||||
public NativeRayObject(byte[] data, byte[] metadata) {
|
||||
Preconditions.checkState(bufferLength(data) > 0 || bufferLength(metadata) > 0);
|
||||
this.data = data;
|
||||
this.metadata = metadata;
|
||||
}
|
||||
|
||||
private static int bufferLength(byte[] buffer) {
|
||||
if (buffer == null) {
|
||||
return 0;
|
||||
}
|
||||
return buffer.length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "<data>: " + bufferLength(data) + ", <metadata>: " + bufferLength(metadata);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
package org.ray.runtime.object;
|
||||
|
||||
import java.util.Arrays;
|
||||
import org.ray.api.exception.RayActorException;
|
||||
import org.ray.api.exception.RayTaskException;
|
||||
import org.ray.api.exception.RayWorkerException;
|
||||
import org.ray.api.exception.UnreconstructableException;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.runtime.generated.Gcs.ErrorType;
|
||||
import org.ray.runtime.util.Serializer;
|
||||
|
||||
/**
|
||||
* Serialize to and deserialize from {@link NativeRayObject}. Metadata is generated during
|
||||
* serialization and respected during deserialization.
|
||||
*/
|
||||
public class ObjectSerializer {
|
||||
|
||||
private static final byte[] WORKER_EXCEPTION_META = String
|
||||
.valueOf(ErrorType.WORKER_DIED.getNumber()).getBytes();
|
||||
private static final byte[] ACTOR_EXCEPTION_META = String
|
||||
.valueOf(ErrorType.ACTOR_DIED.getNumber()).getBytes();
|
||||
private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String
|
||||
.valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE.getNumber()).getBytes();
|
||||
|
||||
private static final byte[] TASK_EXECUTION_EXCEPTION_META = String
|
||||
.valueOf(ErrorType.TASK_EXECUTION_EXCEPTION.getNumber()).getBytes();
|
||||
|
||||
private static final byte[] RAW_TYPE_META = "RAW".getBytes();
|
||||
|
||||
/**
|
||||
* Deserialize an object from an {@link NativeRayObject} instance.
|
||||
*
|
||||
* @param nativeRayObject The object to deserialize.
|
||||
* @param objectId The associated object ID of the object.
|
||||
* @param classLoader The classLoader of the object.
|
||||
* @return The deserialized object.
|
||||
*/
|
||||
public static Object deserialize(NativeRayObject nativeRayObject, ObjectId objectId,
|
||||
ClassLoader classLoader) {
|
||||
byte[] meta = nativeRayObject.metadata;
|
||||
byte[] data = nativeRayObject.data;
|
||||
|
||||
if (meta != null && meta.length > 0) {
|
||||
// If meta is not null, deserialize the object from meta.
|
||||
if (Arrays.equals(meta, RAW_TYPE_META)) {
|
||||
return data;
|
||||
} else if (Arrays.equals(meta, WORKER_EXCEPTION_META)) {
|
||||
return RayWorkerException.INSTANCE;
|
||||
} else if (Arrays.equals(meta, ACTOR_EXCEPTION_META)) {
|
||||
return RayActorException.INSTANCE;
|
||||
} else if (Arrays.equals(meta, UNRECONSTRUCTABLE_EXCEPTION_META)) {
|
||||
return new UnreconstructableException(objectId);
|
||||
} else if (Arrays.equals(meta, TASK_EXECUTION_EXCEPTION_META)) {
|
||||
return Serializer.decode(data, classLoader);
|
||||
}
|
||||
throw new IllegalArgumentException("Unrecognized metadata " + Arrays.toString(meta));
|
||||
} else {
|
||||
// If data is not null, deserialize the Java object.
|
||||
return Serializer.decode(data, classLoader);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Serialize an Java object to an {@link NativeRayObject} instance.
|
||||
*
|
||||
* @param object The object to serialize.
|
||||
* @return The serialized object.
|
||||
*/
|
||||
public static NativeRayObject serialize(Object object) {
|
||||
if (object instanceof NativeRayObject) {
|
||||
return (NativeRayObject) object;
|
||||
} else if (object instanceof byte[]) {
|
||||
// If the object is a byte array, skip serializing it and use a special metadata to
|
||||
// indicate it's raw binary. So that this object can also be read by Python.
|
||||
return new NativeRayObject((byte[]) object, RAW_TYPE_META);
|
||||
} else if (object instanceof RayTaskException) {
|
||||
return new NativeRayObject(Serializer.encode(object),
|
||||
TASK_EXECUTION_EXCEPTION_META);
|
||||
} else {
|
||||
return new NativeRayObject(Serializer.encode(object), null);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2,40 +2,21 @@ package org.ray.runtime.object;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.WaitResult;
|
||||
import org.ray.api.exception.RayActorException;
|
||||
import org.ray.api.exception.RayException;
|
||||
import org.ray.api.exception.RayTaskException;
|
||||
import org.ray.api.exception.RayWorkerException;
|
||||
import org.ray.api.exception.UnreconstructableException;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.runtime.context.WorkerContext;
|
||||
import org.ray.runtime.generated.Gcs.ErrorType;
|
||||
import org.ray.runtime.util.Serializer;
|
||||
|
||||
/**
|
||||
* A class that is used to put/get objects to/from the object store.
|
||||
*/
|
||||
public abstract class ObjectStore {
|
||||
|
||||
private static final byte[] WORKER_EXCEPTION_META = String
|
||||
.valueOf(ErrorType.WORKER_DIED.getNumber()).getBytes();
|
||||
private static final byte[] ACTOR_EXCEPTION_META = String
|
||||
.valueOf(ErrorType.ACTOR_DIED.getNumber()).getBytes();
|
||||
private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String
|
||||
.valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE.getNumber()).getBytes();
|
||||
|
||||
private static final byte[] TASK_EXECUTION_EXCEPTION_META = String
|
||||
.valueOf(ErrorType.TASK_EXECUTION_EXCEPTION.getNumber()).getBytes();
|
||||
|
||||
private static final byte[] RAW_TYPE_META = "RAW".getBytes();
|
||||
|
||||
private final WorkerContext workerContext;
|
||||
|
||||
public ObjectStore(WorkerContext workerContext) {
|
||||
@@ -65,7 +46,11 @@ public abstract class ObjectStore {
|
||||
* @return Id of the object.
|
||||
*/
|
||||
public ObjectId put(Object object) {
|
||||
return putRaw(serialize(object));
|
||||
if (object instanceof NativeRayObject) {
|
||||
throw new IllegalArgumentException(
|
||||
"Trying to put a NativeRayObject. Please use putRaw instead.");
|
||||
}
|
||||
return putRaw(ObjectSerializer.serialize(object));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -77,7 +62,11 @@ public abstract class ObjectStore {
|
||||
* @param objectId Object id.
|
||||
*/
|
||||
public void put(Object object, ObjectId objectId) {
|
||||
putRaw(serialize(object), objectId);
|
||||
if (object instanceof NativeRayObject) {
|
||||
throw new IllegalArgumentException(
|
||||
"Trying to put a NativeRayObject. Please use putRaw instead.");
|
||||
}
|
||||
putRaw(ObjectSerializer.serialize(object), objectId);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -106,7 +95,8 @@ public abstract class ObjectStore {
|
||||
NativeRayObject dataAndMeta = dataAndMetaList.get(i);
|
||||
Object object = null;
|
||||
if (dataAndMeta != null) {
|
||||
object = deserialize(dataAndMeta, ids.get(i));
|
||||
object = ObjectSerializer
|
||||
.deserialize(dataAndMeta, ids.get(i), workerContext.getCurrentClassLoader());
|
||||
}
|
||||
if (object instanceof RayException) {
|
||||
// If the object is a `RayException`, it means that an error occurred during task
|
||||
@@ -174,57 +164,4 @@ public abstract class ObjectStore {
|
||||
*/
|
||||
public abstract void delete(List<ObjectId> objectIds, boolean localOnly,
|
||||
boolean deleteCreatingTasks);
|
||||
|
||||
/**
|
||||
* Deserialize an object.
|
||||
*
|
||||
* @param nativeRayObject The object to deserialize.
|
||||
* @param objectId The associated object ID of the object.
|
||||
* @return The deserialized object.
|
||||
*/
|
||||
public Object deserialize(NativeRayObject nativeRayObject, ObjectId objectId) {
|
||||
byte[] meta = nativeRayObject.metadata;
|
||||
byte[] data = nativeRayObject.data;
|
||||
|
||||
// If meta is not null, deserialize the object from meta.
|
||||
if (meta != null && meta.length > 0) {
|
||||
// If meta is not null, deserialize the object from meta.
|
||||
if (Arrays.equals(meta, RAW_TYPE_META)) {
|
||||
return data;
|
||||
} else if (Arrays.equals(meta, WORKER_EXCEPTION_META)) {
|
||||
return RayWorkerException.INSTANCE;
|
||||
} else if (Arrays.equals(meta, ACTOR_EXCEPTION_META)) {
|
||||
return RayActorException.INSTANCE;
|
||||
} else if (Arrays.equals(meta, UNRECONSTRUCTABLE_EXCEPTION_META)) {
|
||||
return new UnreconstructableException(objectId);
|
||||
} else if (Arrays.equals(meta, TASK_EXECUTION_EXCEPTION_META)) {
|
||||
return Serializer.decode(data, workerContext.getCurrentClassLoader());
|
||||
}
|
||||
throw new IllegalArgumentException("Unrecognized metadata " + Arrays.toString(meta));
|
||||
} else {
|
||||
// If data is not null, deserialize the Java object.
|
||||
return Serializer.decode(data, workerContext.getCurrentClassLoader());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Serialize an object.
|
||||
*
|
||||
* @param object The object to serialize.
|
||||
* @return The serialized object.
|
||||
*/
|
||||
public NativeRayObject serialize(Object object) {
|
||||
if (object instanceof NativeRayObject) {
|
||||
return (NativeRayObject) object;
|
||||
} else if (object instanceof byte[]) {
|
||||
// If the object is a byte array, skip serializing it and use a special metadata to
|
||||
// indicate it's raw binary. So that this object can also be read by Python.
|
||||
return new NativeRayObject((byte[]) object, RAW_TYPE_META);
|
||||
} else if (object instanceof RayTaskException) {
|
||||
return new NativeRayObject(Serializer.encode(object),
|
||||
TASK_EXECUTION_EXCEPTION_META);
|
||||
} else {
|
||||
return new NativeRayObject(Serializer.encode(object), null);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,8 +9,7 @@ import org.ray.api.runtime.RayRuntime;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.RayMultiWorkerNativeRuntime;
|
||||
import org.ray.runtime.object.NativeRayObject;
|
||||
import org.ray.runtime.object.ObjectStore;
|
||||
import org.ray.runtime.util.Serializer;
|
||||
import org.ray.runtime.object.ObjectSerializer;
|
||||
|
||||
/**
|
||||
* Helper methods to convert arguments from/to objects.
|
||||
@@ -26,37 +25,29 @@ public class ArgumentsBuilder {
|
||||
/**
|
||||
* Convert real function arguments to task spec arguments.
|
||||
*/
|
||||
public static List<FunctionArg> wrap(Object[] args, boolean crossLanguage) {
|
||||
public static List<FunctionArg> wrap(Object[] args) {
|
||||
List<FunctionArg> ret = new ArrayList<>();
|
||||
for (Object arg : args) {
|
||||
ObjectId id = null;
|
||||
byte[] data = null;
|
||||
if (arg == null) {
|
||||
data = Serializer.encode(null);
|
||||
} else if (arg instanceof RayObject) {
|
||||
NativeRayObject value = null;
|
||||
if (arg instanceof RayObject) {
|
||||
id = ((RayObject) arg).getId();
|
||||
} else if (arg instanceof byte[] && crossLanguage) {
|
||||
// If the argument is a byte array and will be used by a different language,
|
||||
// do not inline this argument. Because the other language doesn't know how
|
||||
// to deserialize it.
|
||||
id = Ray.put(arg).getId();
|
||||
} else {
|
||||
byte[] serialized = Serializer.encode(arg);
|
||||
if (serialized.length > LARGEST_SIZE_PASS_BY_VALUE) {
|
||||
value = ObjectSerializer.serialize(arg);
|
||||
if (value.data.length > LARGEST_SIZE_PASS_BY_VALUE) {
|
||||
RayRuntime runtime = Ray.internal();
|
||||
if (runtime instanceof RayMultiWorkerNativeRuntime) {
|
||||
runtime = ((RayMultiWorkerNativeRuntime) runtime).getCurrentRuntime();
|
||||
}
|
||||
id = ((AbstractRayRuntime) runtime).getObjectStore()
|
||||
.put(new NativeRayObject(serialized, null));
|
||||
} else {
|
||||
data = serialized;
|
||||
.putRaw(value);
|
||||
value = null;
|
||||
}
|
||||
}
|
||||
if (id != null) {
|
||||
ret.add(FunctionArg.passByReference(id));
|
||||
} else {
|
||||
ret.add(FunctionArg.passByValue(data));
|
||||
ret.add(FunctionArg.passByValue(value));
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
@@ -65,10 +56,10 @@ public class ArgumentsBuilder {
|
||||
/**
|
||||
* Convert list of NativeRayObject to real function arguments.
|
||||
*/
|
||||
public static Object[] unwrap(ObjectStore objectStore, List<NativeRayObject> args) {
|
||||
public static Object[] unwrap(List<NativeRayObject> args, ClassLoader classLoader) {
|
||||
Object[] realArgs = new Object[args.size()];
|
||||
for (int i = 0; i < args.size(); i++) {
|
||||
realArgs[i] = objectStore.deserialize(args.get(i), null);
|
||||
realArgs[i] = ObjectSerializer.deserialize(args.get(i), null, classLoader);
|
||||
}
|
||||
return realArgs;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package org.ray.runtime.task;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.runtime.object.NativeRayObject;
|
||||
|
||||
/**
|
||||
* Represents a function argument in task spec.
|
||||
@@ -16,11 +18,12 @@ public class FunctionArg {
|
||||
/**
|
||||
* Serialized data of this argument (passed by value).
|
||||
*/
|
||||
public final byte[] data;
|
||||
public final NativeRayObject value;
|
||||
|
||||
private FunctionArg(ObjectId id, byte[] data) {
|
||||
private FunctionArg(ObjectId id, NativeRayObject value) {
|
||||
Preconditions.checkState((id == null) != (value == null));
|
||||
this.id = id;
|
||||
this.data = data;
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -33,8 +36,8 @@ public class FunctionArg {
|
||||
/**
|
||||
* Create a FunctionArg that will be passed by value.
|
||||
*/
|
||||
public static FunctionArg passByValue(byte[] data) {
|
||||
return new FunctionArg(null, data);
|
||||
public static FunctionArg passByValue(NativeRayObject value) {
|
||||
return new FunctionArg(null, value);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -42,7 +45,7 @@ public class FunctionArg {
|
||||
if (id != null) {
|
||||
return "<id>: " + id.toString();
|
||||
} else {
|
||||
return "<data>: " + data.length;
|
||||
return value.toString();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -154,7 +154,9 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
|
||||
.collect(Collectors.toList()))
|
||||
.addAllArgs(args.stream().map(arg -> arg.id != null ? TaskArg.newBuilder()
|
||||
.addObjectIds(ByteString.copyFrom(arg.id.getBytes())).build()
|
||||
: TaskArg.newBuilder().setData(ByteString.copyFrom(arg.data)).build())
|
||||
: TaskArg.newBuilder().setData(ByteString.copyFrom(arg.value.data))
|
||||
.setMetadata(arg.value.metadata != null ? ByteString
|
||||
.copyFrom(arg.value.metadata) : ByteString.EMPTY).build())
|
||||
.collect(Collectors.toList()));
|
||||
}
|
||||
|
||||
@@ -233,7 +235,7 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
|
||||
List<NativeRayObject> args = getFunctionArgs(taskSpec).stream()
|
||||
.map(arg -> arg.id != null ?
|
||||
objectStore.getRaw(Collections.singletonList(arg.id), -1).get(0)
|
||||
: new NativeRayObject(arg.data, null))
|
||||
: arg.value)
|
||||
.collect(Collectors.toList());
|
||||
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(taskSpec);
|
||||
List<NativeRayObject> returnObjects = taskExecutor
|
||||
@@ -246,7 +248,7 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
|
||||
// If the task is an actor task or an actor creation task,
|
||||
// put the dummy object in object store, so those tasks which depends on it
|
||||
// can be executed.
|
||||
putObject = new NativeRayObject(new byte[]{}, new byte[]{});
|
||||
putObject = new NativeRayObject(new byte[]{1}, null);
|
||||
} else {
|
||||
putObject = returnObjects.get(i);
|
||||
}
|
||||
@@ -279,7 +281,8 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
|
||||
functionArgs.add(FunctionArg
|
||||
.passByReference(new ObjectId(arg.getObjectIds(0).toByteArray())));
|
||||
} else {
|
||||
functionArgs.add(FunctionArg.passByValue(arg.getData().toByteArray()));
|
||||
functionArgs.add(FunctionArg.passByValue(
|
||||
new NativeRayObject(arg.getData().toByteArray(), arg.getMetadata().toByteArray())));
|
||||
}
|
||||
}
|
||||
return functionArgs;
|
||||
|
||||
@@ -17,6 +17,7 @@ import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
|
||||
import org.ray.runtime.functionmanager.RayFunction;
|
||||
import org.ray.runtime.generated.Common.TaskType;
|
||||
import org.ray.runtime.object.NativeRayObject;
|
||||
import org.ray.runtime.object.ObjectSerializer;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
@@ -87,7 +88,7 @@ public final class TaskExecutor {
|
||||
actor = currentActor;
|
||||
|
||||
}
|
||||
Object[] args = ArgumentsBuilder.unwrap(runtime.getObjectStore(), argsBytes);
|
||||
Object[] args = ArgumentsBuilder.unwrap(argsBytes, rayFunction.classLoader);
|
||||
// Execute the task.
|
||||
Object result;
|
||||
if (!rayFunction.isConstructor()) {
|
||||
@@ -102,7 +103,7 @@ public final class TaskExecutor {
|
||||
maybeSaveCheckpoint(actor, runtime.getWorkerContext().getCurrentActorId());
|
||||
}
|
||||
if (rayFunction.hasReturn()) {
|
||||
returnObjects.add(runtime.getObjectStore().serialize(result));
|
||||
returnObjects.add(ObjectSerializer.serialize(result));
|
||||
}
|
||||
} else {
|
||||
// TODO (kfstorm): handle checkpoint in core worker.
|
||||
@@ -113,8 +114,8 @@ public final class TaskExecutor {
|
||||
} catch (Exception e) {
|
||||
LOGGER.error("Error executing task " + taskId, e);
|
||||
if (taskType != TaskType.ACTOR_CREATION_TASK) {
|
||||
if(rayFunction.hasReturn()) {
|
||||
returnObjects.add(runtime.getObjectStore()
|
||||
if (rayFunction.hasReturn()) {
|
||||
returnObjects.add(ObjectSerializer
|
||||
.serialize(new RayTaskException("Error executing task " + taskId, e)));
|
||||
}
|
||||
} else {
|
||||
|
||||
@@ -3,9 +3,9 @@ package org.ray.api.test;
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.RayPyActor;
|
||||
import org.ray.api.TestUtils;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.runtime.context.WorkerContext;
|
||||
import org.ray.runtime.object.NativeRayObject;
|
||||
import org.ray.runtime.object.ObjectStore;
|
||||
import org.ray.runtime.object.ObjectSerializer;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
@@ -14,10 +14,10 @@ public class RaySerializerTest extends BaseMultiLanguageTest {
|
||||
@Test
|
||||
public void testSerializePyActor() {
|
||||
RayPyActor pyActor = Ray.createPyActor("test", "RaySerializerTest");
|
||||
ObjectStore objectStore = TestUtils.getRuntime().getObjectStore();
|
||||
NativeRayObject nativeRayObject = objectStore.serialize(pyActor);
|
||||
RayPyActor result = (RayPyActor) objectStore
|
||||
.deserialize(nativeRayObject, ObjectId.fromRandom());
|
||||
WorkerContext workerContext = TestUtils.getRuntime().getWorkerContext();
|
||||
NativeRayObject nativeRayObject = ObjectSerializer.serialize(pyActor);
|
||||
RayPyActor result = (RayPyActor) ObjectSerializer
|
||||
.deserialize(nativeRayObject, null, workerContext.getCurrentClassLoader());
|
||||
Assert.assertEquals(result.getId(), pyActor.getId());
|
||||
Assert.assertEquals(result.getModuleName(), "test");
|
||||
Assert.assertEquals(result.getClassName(), "RaySerializerTest");
|
||||
|
||||
@@ -54,8 +54,10 @@ cdef extern from "ray/common/task/task_spec.h" namespace "ray" nogil:
|
||||
int ArgIdCount(uint64_t arg_index) const
|
||||
CObjectID ArgId(uint64_t arg_index, uint64_t id_index) const
|
||||
CObjectID ReturnId(uint64_t return_index) const
|
||||
const uint8_t *ArgVal(uint64_t arg_index) const
|
||||
size_t ArgValLength(uint64_t arg_index) const
|
||||
const uint8_t *ArgData(uint64_t arg_index) const
|
||||
size_t ArgDataSize(uint64_t arg_index) const
|
||||
const uint8_t *ArgMetadata(uint64_t arg_index) const
|
||||
size_t ArgMetadataSize(uint64_t arg_index) const
|
||||
double GetRequiredResource(const c_string &resource_name) const
|
||||
const ResourceSet GetRequiredResources() const
|
||||
const ResourceSet GetRequiredPlacementResources() const
|
||||
@@ -86,7 +88,7 @@ cdef extern from "ray/common/task/task_util.h" namespace "ray" nogil:
|
||||
|
||||
TaskSpecBuilder &AddByRefArg(const CObjectID &arg_id)
|
||||
|
||||
TaskSpecBuilder &AddByValueArg(const c_string &data)
|
||||
TaskSpecBuilder &AddByValueArg(const c_string &data, const c_string &metadata)
|
||||
|
||||
TaskSpecBuilder &SetActorCreationTaskSpec(
|
||||
const CActorID &actor_id, uint64_t max_reconstructions,
|
||||
|
||||
@@ -12,6 +12,7 @@ from ray.includes.task cimport (
|
||||
TaskSpecBuilder,
|
||||
TaskTableData,
|
||||
)
|
||||
from ray.ray_constants import RAW_BUFFER_METADATA
|
||||
from ray.utils import decode
|
||||
|
||||
|
||||
@@ -68,10 +69,12 @@ cdef class TaskSpec:
|
||||
for arg in arguments:
|
||||
if isinstance(arg, ObjectID):
|
||||
builder.AddByRefArg((<ObjectID>arg).native())
|
||||
elif isinstance(arg, bytes):
|
||||
builder.AddByValueArg(arg, RAW_BUFFER_METADATA)
|
||||
else:
|
||||
pickled_str = pickle.dumps(
|
||||
arg, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
builder.AddByValueArg(pickled_str)
|
||||
builder.AddByValueArg(pickled_str, b'')
|
||||
|
||||
if not actor_creation_id.is_nil():
|
||||
# Actor creation task.
|
||||
@@ -180,9 +183,12 @@ cdef class TaskSpec:
|
||||
arg_list.append(
|
||||
ObjectID(task_spec.ArgId(i, 0).Binary()))
|
||||
else:
|
||||
serialized_str = (
|
||||
task_spec.ArgVal(i)[:task_spec.ArgValLength(i)])
|
||||
obj = pickle.loads(serialized_str)
|
||||
data = (task_spec.ArgData(i)[:task_spec.ArgDataSize(i)])
|
||||
metadata = (task_spec.ArgMetadata(i)[:task_spec.ArgMetadataSize(i)])
|
||||
if metadata == RAW_BUFFER_METADATA:
|
||||
obj = data
|
||||
else:
|
||||
obj = pickle.loads(data)
|
||||
arg_list.append(obj)
|
||||
elif lang == <int32_t>LANGUAGE_JAVA:
|
||||
arg_list = num_args * ["<java-argument>"]
|
||||
|
||||
@@ -39,6 +39,11 @@ class LocalMemoryBuffer : public Buffer {
|
||||
public:
|
||||
/// Constructor.
|
||||
///
|
||||
/// By default when initializing a LocalMemoryBuffer with a data pointer and a length,
|
||||
/// it just assigns the pointer and length without coping the data content. This is
|
||||
/// for performance reasons. In this case the buffer cannot ensure data validity. It
|
||||
/// instead relies on the lifetime passed in data pointer.
|
||||
///
|
||||
/// \param data The data pointer to the passed-in buffer.
|
||||
/// \param size The size of the passed in buffer.
|
||||
/// \param copy_data If true, data will be copied and owned by this buffer,
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
#ifndef RAY_COMMON_RAY_OBJECT_H
|
||||
#define RAY_COMMON_RAY_OBJECT_H
|
||||
|
||||
#include "ray/common/buffer.h"
|
||||
#include "ray/util/logging.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
/// Binary representation of a ray object, consisting of buffer pointers to data and
|
||||
/// metadata. A ray object may have both data and metadata, or only one of them.
|
||||
class RayObject {
|
||||
public:
|
||||
/// Create a ray object instance.
|
||||
///
|
||||
/// Set `copy_data` to `false` is fine for most cases - for example when putting
|
||||
/// an object into store with a temporary RayObject, and we don't want to do an extra
|
||||
/// copy. But in some cases we do want to always hold a valid data - for example, memory
|
||||
/// store uses RayObject to represent objects, in this case we actually want the object
|
||||
/// data to remain valid after user puts it into store.
|
||||
///
|
||||
/// \param[in] data Data of the ray object.
|
||||
/// \param[in] metadata Metadata of the ray object.
|
||||
/// \param[in] copy_data Whether this class should hold a copy of data.
|
||||
RayObject(const std::shared_ptr<Buffer> &data, const std::shared_ptr<Buffer> &metadata,
|
||||
bool copy_data = false)
|
||||
: data_(data), metadata_(metadata), has_data_copy_(copy_data) {
|
||||
RAY_CHECK(!data || data_->Size())
|
||||
<< "Zero-length buffers are not allowed when constructing a RayObject.";
|
||||
RAY_CHECK(!metadata || metadata->Size())
|
||||
<< "Zero-length buffers are not allowed when constructing a RayObject.";
|
||||
|
||||
if (has_data_copy_) {
|
||||
// If this object is required to hold a copy of the data,
|
||||
// make a copy if the passed in buffers don't already have a copy.
|
||||
if (data_ && !data_->OwnsData()) {
|
||||
data_ = std::make_shared<LocalMemoryBuffer>(data_->Data(), data_->Size(),
|
||||
/*copy_data=*/true);
|
||||
}
|
||||
|
||||
if (metadata_ && !metadata_->OwnsData()) {
|
||||
metadata_ = std::make_shared<LocalMemoryBuffer>(
|
||||
metadata_->Data(), metadata_->Size(), /*copy_data=*/true);
|
||||
}
|
||||
}
|
||||
|
||||
RAY_CHECK(data_ || metadata_) << "Data and metadata cannot both be empty.";
|
||||
}
|
||||
|
||||
/// Return the data of the ray object.
|
||||
const std::shared_ptr<Buffer> &GetData() const { return data_; };
|
||||
|
||||
/// Return the metadata of the ray object.
|
||||
const std::shared_ptr<Buffer> &GetMetadata() const { return metadata_; };
|
||||
|
||||
uint64_t GetSize() const {
|
||||
uint64_t size = 0;
|
||||
size += (data_ != nullptr) ? data_->Size() : 0;
|
||||
size += (metadata_ != nullptr) ? metadata_->Size() : 0;
|
||||
return size;
|
||||
}
|
||||
|
||||
/// Whether this object has data.
|
||||
bool HasData() const { return data_ != nullptr; }
|
||||
|
||||
/// Whether this object has metadata.
|
||||
bool HasMetadata() const { return metadata_ != nullptr; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<Buffer> data_;
|
||||
std::shared_ptr<Buffer> metadata_;
|
||||
/// Whether this class holds a data copy.
|
||||
bool has_data_copy_;
|
||||
};
|
||||
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_COMMON_BUFFER_H
|
||||
@@ -53,14 +53,22 @@ ObjectID TaskSpecification::ArgId(size_t arg_index, size_t id_index) const {
|
||||
return ObjectID::FromBinary(message_->args(arg_index).object_ids(id_index));
|
||||
}
|
||||
|
||||
const uint8_t *TaskSpecification::ArgVal(size_t arg_index) const {
|
||||
const uint8_t *TaskSpecification::ArgData(size_t arg_index) const {
|
||||
return reinterpret_cast<const uint8_t *>(message_->args(arg_index).data().data());
|
||||
}
|
||||
|
||||
size_t TaskSpecification::ArgValLength(size_t arg_index) const {
|
||||
size_t TaskSpecification::ArgDataSize(size_t arg_index) const {
|
||||
return message_->args(arg_index).data().size();
|
||||
}
|
||||
|
||||
const uint8_t *TaskSpecification::ArgMetadata(size_t arg_index) const {
|
||||
return reinterpret_cast<const uint8_t *>(message_->args(arg_index).metadata().data());
|
||||
}
|
||||
|
||||
size_t TaskSpecification::ArgMetadataSize(size_t arg_index) const {
|
||||
return message_->args(arg_index).metadata().size();
|
||||
}
|
||||
|
||||
const ResourceSet TaskSpecification::GetRequiredResources() const {
|
||||
return required_resources_;
|
||||
}
|
||||
|
||||
@@ -70,9 +70,13 @@ class TaskSpecification : public MessageWrapper<rpc::TaskSpec> {
|
||||
|
||||
ObjectID ReturnId(size_t return_index) const;
|
||||
|
||||
const uint8_t *ArgVal(size_t arg_index) const;
|
||||
const uint8_t *ArgData(size_t arg_index) const;
|
||||
|
||||
size_t ArgValLength(size_t arg_index) const;
|
||||
size_t ArgDataSize(size_t arg_index) const;
|
||||
|
||||
const uint8_t *ArgMetadata(size_t arg_index) const;
|
||||
|
||||
size_t ArgMetadataSize(size_t arg_index) const;
|
||||
|
||||
/// Return the resources that are to be acquired during the execution of this
|
||||
/// task.
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
#ifndef RAY_COMMON_TASK_TASK_UTIL_H
|
||||
#define RAY_COMMON_TASK_TASK_UTIL_H
|
||||
|
||||
#include "ray/common/buffer.h"
|
||||
#include "ray/common/ray_object.h"
|
||||
#include "ray/common/task/task_spec.h"
|
||||
#include "ray/protobuf/common.pb.h"
|
||||
|
||||
@@ -56,19 +58,29 @@ class TaskSpecBuilder {
|
||||
/// Add a by-value argument to the task.
|
||||
///
|
||||
/// \param data String object that contains the data.
|
||||
/// \param metadata String object that contains the metadata.
|
||||
/// \return Reference to the builder object itself.
|
||||
TaskSpecBuilder &AddByValueArg(const std::string &data) {
|
||||
message_->add_args()->set_data(data);
|
||||
TaskSpecBuilder &AddByValueArg(const std::string &data, const std::string &metadata) {
|
||||
auto arg = message_->add_args();
|
||||
arg->set_data(data);
|
||||
arg->set_metadata(metadata);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Add a by-value argument to the task.
|
||||
///
|
||||
/// \param data Pointer to the data.
|
||||
/// \param size Size of the data.
|
||||
/// \param value the RayObject instance that contains the data and the metadata.
|
||||
/// \return Reference to the builder object itself.
|
||||
TaskSpecBuilder &AddByValueArg(const void *data, size_t size) {
|
||||
message_->add_args()->set_data(data, size);
|
||||
TaskSpecBuilder &AddByValueArg(const RayObject &value) {
|
||||
auto arg = message_->add_args();
|
||||
if (value.HasData()) {
|
||||
const auto &data = value.GetData();
|
||||
arg->set_data(data->Data(), data->Size());
|
||||
}
|
||||
if (value.HasMetadata()) {
|
||||
const auto &metadata = value.GetMetadata();
|
||||
arg->set_metadata(metadata->Data(), metadata->Size());
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "ray/common/buffer.h"
|
||||
#include "ray/common/id.h"
|
||||
#include "ray/common/ray_object.h"
|
||||
#include "ray/common/task/task_spec.h"
|
||||
#include "ray/raylet/raylet_client.h"
|
||||
#include "ray/util/util.h"
|
||||
@@ -31,12 +31,13 @@ class TaskArg {
|
||||
return TaskArg(std::make_shared<ObjectID>(object_id), nullptr);
|
||||
}
|
||||
|
||||
/// Create a pass-by-reference task argument.
|
||||
/// Create a pass-by-value task argument.
|
||||
///
|
||||
/// \param[in] object_id Id of the argument.
|
||||
/// \param[in] value Value of the argument.
|
||||
/// \return The task argument.
|
||||
static TaskArg PassByValue(const std::shared_ptr<Buffer> &data) {
|
||||
return TaskArg(nullptr, data);
|
||||
static TaskArg PassByValue(const std::shared_ptr<RayObject> &value) {
|
||||
RAY_CHECK(value) << "Value can't be null.";
|
||||
return TaskArg(nullptr, value);
|
||||
}
|
||||
|
||||
/// Return true if this argument is passed by reference, false if passed by value.
|
||||
@@ -49,19 +50,19 @@ class TaskArg {
|
||||
}
|
||||
|
||||
/// Get the value.
|
||||
std::shared_ptr<Buffer> GetValue() const {
|
||||
RAY_CHECK(data_ != nullptr) << "This argument isn't passed by value.";
|
||||
return data_;
|
||||
const RayObject &GetValue() const {
|
||||
RAY_CHECK(value_ != nullptr) << "This argument isn't passed by value.";
|
||||
return *value_;
|
||||
}
|
||||
|
||||
private:
|
||||
TaskArg(const std::shared_ptr<ObjectID> id, const std::shared_ptr<Buffer> data)
|
||||
: id_(id), data_(data) {}
|
||||
TaskArg(const std::shared_ptr<ObjectID> id, const std::shared_ptr<RayObject> value)
|
||||
: id_(id), value_(value) {}
|
||||
|
||||
/// Id of the argument, if passed by reference, otherwise nullptr.
|
||||
/// Id of the argument if passed by reference, otherwise nullptr.
|
||||
const std::shared_ptr<ObjectID> id_;
|
||||
/// Data of the argument, if passed by value, otherwise nullptr.
|
||||
const std::shared_ptr<Buffer> data_;
|
||||
/// Value of the argument if passed by value, otherwise nullptr.
|
||||
const std::shared_ptr<RayObject> value_;
|
||||
};
|
||||
|
||||
enum class StoreProviderType { PLASMA, MEMORY };
|
||||
|
||||
@@ -43,7 +43,7 @@ jmethodID java_language_get_number;
|
||||
|
||||
jclass java_function_arg_class;
|
||||
jfieldID java_function_arg_id;
|
||||
jfieldID java_function_arg_data;
|
||||
jfieldID java_function_arg_value;
|
||||
|
||||
jclass java_base_task_options_class;
|
||||
jfieldID java_base_task_options_resources;
|
||||
@@ -137,7 +137,8 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) {
|
||||
java_function_arg_class = LoadClass(env, "org/ray/runtime/task/FunctionArg");
|
||||
java_function_arg_id =
|
||||
env->GetFieldID(java_function_arg_class, "id", "Lorg/ray/api/id/ObjectId;");
|
||||
java_function_arg_data = env->GetFieldID(java_function_arg_class, "data", "[B");
|
||||
java_function_arg_value = env->GetFieldID(java_function_arg_class, "value",
|
||||
"Lorg/ray/runtime/object/NativeRayObject;");
|
||||
|
||||
java_base_task_options_class = LoadClass(env, "org/ray/api/options/BaseTaskOptions");
|
||||
java_base_task_options_resources =
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include <jni.h>
|
||||
#include "ray/common/buffer.h"
|
||||
#include "ray/common/id.h"
|
||||
#include "ray/common/ray_object.h"
|
||||
#include "ray/common/status.h"
|
||||
#include "ray/core_worker/store_provider/store_provider.h"
|
||||
|
||||
@@ -77,12 +78,12 @@ extern jclass java_language_class;
|
||||
/// getNumber of Language class
|
||||
extern jmethodID java_language_get_number;
|
||||
|
||||
/// NativeTaskArg class
|
||||
/// FunctionArg class
|
||||
extern jclass java_function_arg_class;
|
||||
/// id field of NativeTaskArg class
|
||||
/// id field of FunctionArg class
|
||||
extern jfieldID java_function_arg_id;
|
||||
/// data field of NativeTaskArg class
|
||||
extern jfieldID java_function_arg_data;
|
||||
/// value field of FunctionArg class
|
||||
extern jfieldID java_function_arg_value;
|
||||
|
||||
/// BaseTaskOptions class
|
||||
extern jclass java_base_task_options_class;
|
||||
@@ -279,11 +280,11 @@ inline std::shared_ptr<ray::RayObject> JavaNativeRayObjectToNativeRayObject(
|
||||
std::shared_ptr<ray::Buffer> data_buffer = JavaByteArrayToNativeBuffer(env, java_data);
|
||||
std::shared_ptr<ray::Buffer> metadata_buffer =
|
||||
JavaByteArrayToNativeBuffer(env, java_metadata);
|
||||
if (!data_buffer) {
|
||||
data_buffer = std::make_shared<ray::LocalMemoryBuffer>(nullptr, 0);
|
||||
if (data_buffer && data_buffer->Size() == 0) {
|
||||
data_buffer = nullptr;
|
||||
}
|
||||
if (!metadata_buffer) {
|
||||
metadata_buffer = std::make_shared<ray::LocalMemoryBuffer>(nullptr, 0);
|
||||
if (metadata_buffer && metadata_buffer->Size() == 0) {
|
||||
metadata_buffer = nullptr;
|
||||
}
|
||||
return std::make_shared<ray::RayObject>(data_buffer, metadata_buffer);
|
||||
}
|
||||
|
||||
@@ -34,10 +34,11 @@ inline std::vector<ray::TaskArg> ToTaskArgs(JNIEnv *env, jobject args) {
|
||||
return ray::TaskArg::PassByReference(
|
||||
JavaByteArrayToId<ray::ObjectID>(env, java_id_bytes));
|
||||
}
|
||||
auto java_data =
|
||||
static_cast<jbyteArray>(env->GetObjectField(arg, java_function_arg_data));
|
||||
RAY_CHECK(java_data) << "Both id and data of FunctionArg are null.";
|
||||
return ray::TaskArg::PassByValue(JavaByteArrayToNativeBuffer(env, java_data));
|
||||
auto java_value =
|
||||
static_cast<jbyteArray>(env->GetObjectField(arg, java_function_arg_value));
|
||||
RAY_CHECK(java_value) << "Both id and value of FunctionArg are null.";
|
||||
auto value = JavaNativeRayObjectToNativeRayObject(env, java_value);
|
||||
return ray::TaskArg::PassByValue(value);
|
||||
});
|
||||
return task_args;
|
||||
}
|
||||
|
||||
@@ -69,9 +69,15 @@ Status CoreWorkerPlasmaStoreProvider::FetchAndGetFromPlasmaStore(
|
||||
for (size_t i = 0; i < plasma_results.size(); i++) {
|
||||
if (plasma_results[i].data != nullptr || plasma_results[i].metadata != nullptr) {
|
||||
const auto &object_id = batch_ids[i];
|
||||
const auto result_object = std::make_shared<RayObject>(
|
||||
std::make_shared<PlasmaBuffer>(plasma_results[i].data),
|
||||
std::make_shared<PlasmaBuffer>(plasma_results[i].metadata));
|
||||
std::shared_ptr<PlasmaBuffer> data = nullptr;
|
||||
std::shared_ptr<PlasmaBuffer> metadata = nullptr;
|
||||
if (plasma_results[i].data && plasma_results[i].data->size()) {
|
||||
data = std::make_shared<PlasmaBuffer>(plasma_results[i].data);
|
||||
}
|
||||
if (plasma_results[i].metadata && plasma_results[i].metadata->size()) {
|
||||
metadata = std::make_shared<PlasmaBuffer>(plasma_results[i].metadata);
|
||||
}
|
||||
const auto result_object = std::make_shared<RayObject>(data, metadata);
|
||||
(*results)[object_id] = result_object;
|
||||
remaining.erase(object_id);
|
||||
if (IsException(*result_object)) {
|
||||
@@ -174,6 +180,9 @@ Status CoreWorkerPlasmaStoreProvider::Delete(const std::vector<ObjectID> &object
|
||||
|
||||
bool CoreWorkerPlasmaStoreProvider::IsException(const RayObject &object) {
|
||||
// TODO (kfstorm): metadata should be structured.
|
||||
if (!object.HasMetadata()) {
|
||||
return false;
|
||||
}
|
||||
const std::string metadata(reinterpret_cast<const char *>(object.GetMetadata()->Data()),
|
||||
object.GetMetadata()->Size());
|
||||
const auto error_type_descriptor = ray::rpc::ErrorType_descriptor();
|
||||
|
||||
@@ -8,56 +8,6 @@
|
||||
|
||||
namespace ray {
|
||||
|
||||
/// Binary representation of a ray object.
|
||||
class RayObject {
|
||||
public:
|
||||
/// Create a ray object instance.
|
||||
///
|
||||
/// \param[in] data Data of the ray object.
|
||||
/// \param[in] metadata Metadata of the ray object.
|
||||
/// \param[in] copy_data Whether this class should hold a copy of data.
|
||||
RayObject(const std::shared_ptr<Buffer> &data, const std::shared_ptr<Buffer> &metadata,
|
||||
bool copy_data = false)
|
||||
: data_(data), metadata_(metadata), has_data_copy_(copy_data) {
|
||||
if (has_data_copy_) {
|
||||
// If this object is required to hold a copy of the data,
|
||||
// make a copy if the passed in buffers don't already have a copy.
|
||||
if (data_ && !data_->OwnsData()) {
|
||||
data_ = std::make_shared<LocalMemoryBuffer>(data_->Data(), data_->Size(), true);
|
||||
}
|
||||
|
||||
if (metadata_ && !metadata_->OwnsData()) {
|
||||
metadata_ = std::make_shared<LocalMemoryBuffer>(metadata_->Data(),
|
||||
metadata_->Size(), true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the data of the ray object.
|
||||
const std::shared_ptr<Buffer> &GetData() const { return data_; };
|
||||
|
||||
/// Return the metadata of the ray object.
|
||||
const std::shared_ptr<Buffer> &GetMetadata() const { return metadata_; };
|
||||
|
||||
uint64_t GetSize() const {
|
||||
uint64_t size = 0;
|
||||
size += (data_ != nullptr) ? data_->Size() : 0;
|
||||
size += (metadata_ != nullptr) ? metadata_->Size() : 0;
|
||||
return size;
|
||||
}
|
||||
|
||||
/// Whether this object has metadata.
|
||||
bool HasMetadata() const { return metadata_ != nullptr && metadata_->Size() > 0; }
|
||||
|
||||
private:
|
||||
/// Data of the ray object.
|
||||
std::shared_ptr<Buffer> data_;
|
||||
/// Metadata of the ray object.
|
||||
std::shared_ptr<Buffer> metadata_;
|
||||
/// Whether this class holds a data copy.
|
||||
bool has_data_copy_;
|
||||
};
|
||||
|
||||
/// Provider interface for store access. Store provider should inherit from this class and
|
||||
/// provide implementions for the methods. The actual store provider may use a plasma
|
||||
/// store or local memory store in worker process, or possibly other types of storage.
|
||||
|
||||
@@ -89,10 +89,17 @@ Status CoreWorkerTaskExecutionInterface::BuildArgsForExecutor(
|
||||
indices.push_back(i);
|
||||
} else {
|
||||
// pass by value.
|
||||
(*args)[i] = std::make_shared<RayObject>(
|
||||
std::make_shared<LocalMemoryBuffer>(const_cast<uint8_t *>(task.ArgVal(i)),
|
||||
task.ArgValLength(i)),
|
||||
nullptr);
|
||||
std::shared_ptr<LocalMemoryBuffer> data = nullptr;
|
||||
if (task.ArgDataSize(i)) {
|
||||
data = std::make_shared<LocalMemoryBuffer>(const_cast<uint8_t *>(task.ArgData(i)),
|
||||
task.ArgDataSize(i));
|
||||
}
|
||||
std::shared_ptr<LocalMemoryBuffer> metadata = nullptr;
|
||||
if (task.ArgMetadataSize(i)) {
|
||||
metadata = std::make_shared<LocalMemoryBuffer>(
|
||||
const_cast<uint8_t *>(task.ArgMetadata(i)), task.ArgMetadataSize(i));
|
||||
}
|
||||
(*args)[i] = std::make_shared<RayObject>(data, metadata);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -130,7 +130,7 @@ void CoreWorkerTaskInterface::BuildCommonTaskSpec(
|
||||
if (arg.IsPassedByReference()) {
|
||||
builder.AddByRefArg(arg.GetReference());
|
||||
} else {
|
||||
builder.AddByValueArg(arg.GetValue()->Data(), arg.GetValue()->Size());
|
||||
builder.AddByValueArg(arg.GetValue());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ray/common/buffer.h"
|
||||
#include "ray/common/ray_object.h"
|
||||
#include "ray/core_worker/context.h"
|
||||
#include "ray/core_worker/core_worker.h"
|
||||
#include "ray/core_worker/transport/direct_actor_transport.h"
|
||||
@@ -59,7 +60,7 @@ std::unique_ptr<ActorHandle> CreateActorHelper(
|
||||
|
||||
RayFunction func{ray::Language::PYTHON, {"actor creation task"}};
|
||||
std::vector<TaskArg> args;
|
||||
args.emplace_back(TaskArg::PassByValue(buffer));
|
||||
args.emplace_back(TaskArg::PassByValue(std::make_shared<RayObject>(buffer, nullptr)));
|
||||
|
||||
ActorCreationOptions actor_options{max_reconstructions, is_direct_call, resources, {}};
|
||||
|
||||
@@ -232,7 +233,8 @@ void CoreWorkerTest::TestNormalTask(
|
||||
RAY_CHECK_OK(driver.Objects().Put(RayObject(buffer2, nullptr), &object_id));
|
||||
|
||||
std::vector<TaskArg> args;
|
||||
args.emplace_back(TaskArg::PassByValue(buffer1));
|
||||
args.emplace_back(
|
||||
TaskArg::PassByValue(std::make_shared<RayObject>(buffer1, nullptr)));
|
||||
args.emplace_back(TaskArg::PassByReference(object_id));
|
||||
|
||||
RayFunction func{ray::Language::PYTHON, {}};
|
||||
@@ -273,8 +275,10 @@ void CoreWorkerTest::TestActorTask(
|
||||
|
||||
// Create arguments with PassByRef and PassByValue.
|
||||
std::vector<TaskArg> args;
|
||||
args.emplace_back(TaskArg::PassByValue(buffer1));
|
||||
args.emplace_back(TaskArg::PassByValue(buffer2));
|
||||
args.emplace_back(
|
||||
TaskArg::PassByValue(std::make_shared<RayObject>(buffer1, nullptr)));
|
||||
args.emplace_back(
|
||||
TaskArg::PassByValue(std::make_shared<RayObject>(buffer2, nullptr)));
|
||||
|
||||
TaskOptions options{1, resources};
|
||||
std::vector<ObjectID> return_ids;
|
||||
@@ -315,7 +319,8 @@ void CoreWorkerTest::TestActorTask(
|
||||
// Create arguments with PassByRef and PassByValue.
|
||||
std::vector<TaskArg> args;
|
||||
args.emplace_back(TaskArg::PassByReference(object_id));
|
||||
args.emplace_back(TaskArg::PassByValue(buffer2));
|
||||
args.emplace_back(
|
||||
TaskArg::PassByValue(std::make_shared<RayObject>(buffer2, nullptr)));
|
||||
|
||||
TaskOptions options{1, resources};
|
||||
std::vector<ObjectID> return_ids;
|
||||
@@ -380,7 +385,8 @@ void CoreWorkerTest::TestActorReconstruction(
|
||||
|
||||
// Create arguments with PassByValue.
|
||||
std::vector<TaskArg> args;
|
||||
args.emplace_back(TaskArg::PassByValue(buffer1));
|
||||
args.emplace_back(
|
||||
TaskArg::PassByValue(std::make_shared<RayObject>(buffer1, nullptr)));
|
||||
|
||||
TaskOptions options{1, resources};
|
||||
std::vector<ObjectID> return_ids;
|
||||
@@ -425,7 +431,8 @@ void CoreWorkerTest::TestActorFailure(
|
||||
|
||||
// Create arguments with PassByRef and PassByValue.
|
||||
std::vector<TaskArg> args;
|
||||
args.emplace_back(TaskArg::PassByValue(buffer1));
|
||||
args.emplace_back(
|
||||
TaskArg::PassByValue(std::make_shared<RayObject>(buffer1, nullptr)));
|
||||
|
||||
TaskOptions options{1, resources};
|
||||
std::vector<ObjectID> return_ids;
|
||||
@@ -618,11 +625,10 @@ TEST_F(ZeroNodeTest, TestTaskArg) {
|
||||
ASSERT_TRUE(by_ref.IsPassedByReference());
|
||||
ASSERT_EQ(by_ref.GetReference(), id);
|
||||
// Test by-value argument.
|
||||
std::shared_ptr<LocalMemoryBuffer> buffer =
|
||||
std::make_shared<LocalMemoryBuffer>(static_cast<uint8_t *>(0), 0);
|
||||
TaskArg by_value = TaskArg::PassByValue(buffer);
|
||||
auto buffer = GenerateRandomBuffer();
|
||||
TaskArg by_value = TaskArg::PassByValue(std::make_shared<RayObject>(buffer, nullptr));
|
||||
ASSERT_FALSE(by_value.IsPassedByReference());
|
||||
auto data = by_value.GetValue();
|
||||
auto data = by_value.GetValue().GetData();
|
||||
ASSERT_TRUE(data != nullptr);
|
||||
ASSERT_EQ(*data, *buffer);
|
||||
}
|
||||
@@ -635,7 +641,7 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) {
|
||||
auto buffer = std::make_shared<LocalMemoryBuffer>(array, sizeof(array));
|
||||
RayFunction function{ray::Language::PYTHON, {}};
|
||||
std::vector<TaskArg> args;
|
||||
args.emplace_back(TaskArg::PassByValue(buffer));
|
||||
args.emplace_back(TaskArg::PassByValue(std::make_shared<RayObject>(buffer, nullptr)));
|
||||
|
||||
std::unordered_map<std::string, double> resources;
|
||||
ActorCreationOptions actor_options{0, /*is_direct_call*/ true, resources, {}};
|
||||
@@ -664,7 +670,7 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) {
|
||||
if (arg.IsPassedByReference()) {
|
||||
builder.AddByRefArg(arg.GetReference());
|
||||
} else {
|
||||
builder.AddByValueArg(arg.GetValue()->Data(), arg.GetValue()->Size());
|
||||
builder.AddByValueArg(arg.GetValue());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -696,7 +702,7 @@ TEST_F(SingleNodeTest, TestDirectActorTaskSubmissionPerf) {
|
||||
auto buffer = std::make_shared<LocalMemoryBuffer>(array, sizeof(array));
|
||||
RayFunction func{ray::Language::PYTHON, {}};
|
||||
std::vector<TaskArg> args;
|
||||
args.emplace_back(TaskArg::PassByValue(buffer));
|
||||
args.emplace_back(TaskArg::PassByValue(std::make_shared<RayObject>(buffer, nullptr)));
|
||||
|
||||
std::unordered_map<std::string, double> resources;
|
||||
ActorCreationOptions actor_options{0, /*is_direct_call*/ true, resources, {}};
|
||||
@@ -712,7 +718,7 @@ TEST_F(SingleNodeTest, TestDirectActorTaskSubmissionPerf) {
|
||||
for (int i = 0; i < num_tasks; i++) {
|
||||
// Create arguments with PassByValue.
|
||||
std::vector<TaskArg> args;
|
||||
args.emplace_back(TaskArg::PassByValue(buffer));
|
||||
args.emplace_back(TaskArg::PassByValue(std::make_shared<RayObject>(buffer, nullptr)));
|
||||
|
||||
TaskOptions options{1, resources};
|
||||
std::vector<ObjectID> return_ids;
|
||||
|
||||
@@ -224,10 +224,10 @@ void CoreWorkerDirectActorTaskReceiver::HandlePushTask(
|
||||
/*transport_type=*/static_cast<int>(TaskTransportType::DIRECT_ACTOR));
|
||||
return_object->set_object_id(id.Binary());
|
||||
const auto &result = results[i];
|
||||
if (result->GetData() != nullptr) {
|
||||
if (result->HasData()) {
|
||||
return_object->set_data(result->GetData()->Data(), result->GetData()->Size());
|
||||
}
|
||||
if (result->GetMetadata() != nullptr) {
|
||||
if (result->HasMetadata()) {
|
||||
return_object->set_metadata(result->GetMetadata()->Data(),
|
||||
result->GetMetadata()->Size());
|
||||
}
|
||||
|
||||
@@ -74,6 +74,8 @@ message TaskArg {
|
||||
repeated bytes object_ids = 1;
|
||||
// Data for pass-by-value arguments.
|
||||
bytes data = 2;
|
||||
// Metadata for pass-by-value arguments.
|
||||
bytes metadata = 3;
|
||||
}
|
||||
|
||||
// Task spec of an actor creation task.
|
||||
|
||||
Reference in New Issue
Block a user