From d8f580469096b02fc1ac66a2a90cd2b0ef9dafba Mon Sep 17 00:00:00 2001 From: Kai Yang Date: Sun, 8 Sep 2019 11:07:48 +0800 Subject: [PATCH] Support metadata for passing by value task arguments (#5527) --- .../org/ray/runtime/AbstractRayRuntime.java | 13 +-- .../ray/runtime/object/NativeRayObject.java | 17 +++- .../ray/runtime/object/ObjectSerializer.java | 83 ++++++++++++++++++ .../org/ray/runtime/object/ObjectStore.java | 87 +++---------------- .../ray/runtime/task/ArgumentsBuilder.java | 31 +++---- .../org/ray/runtime/task/FunctionArg.java | 15 ++-- .../runtime/task/LocalModeTaskSubmitter.java | 11 ++- .../org/ray/runtime/task/TaskExecutor.java | 9 +- .../org/ray/api/test/RaySerializerTest.java | 12 +-- python/ray/includes/task.pxd | 8 +- python/ray/includes/task.pxi | 14 ++- src/ray/common/buffer.h | 5 ++ src/ray/common/ray_object.h | 77 ++++++++++++++++ src/ray/common/task/task_spec.cc | 12 ++- src/ray/common/task/task_spec.h | 8 +- src/ray/common/task/task_util.h | 24 +++-- src/ray/core_worker/common.h | 27 +++--- src/ray/core_worker/lib/java/jni_init.cc | 5 +- src/ray/core_worker/lib/java/jni_utils.h | 17 ++-- ...rg_ray_runtime_task_NativeTaskSubmitter.cc | 9 +- .../store_provider/plasma_store_provider.cc | 15 +++- .../store_provider/store_provider.h | 50 ----------- src/ray/core_worker/task_execution.cc | 15 +++- src/ray/core_worker/task_interface.cc | 2 +- src/ray/core_worker/test/core_worker_test.cc | 36 ++++---- .../transport/direct_actor_transport.cc | 4 +- src/ray/protobuf/common.proto | 2 + 27 files changed, 364 insertions(+), 244 deletions(-) create mode 100644 java/runtime/src/main/java/org/ray/runtime/object/ObjectSerializer.java create mode 100644 src/ray/common/ray_object.h diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index b3ea54e3f..96f55384e 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ 
b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -176,8 +176,7 @@ public abstract class AbstractRayRuntime implements RayRuntime { private RayObject callNormalFunction(FunctionDescriptor functionDescriptor, Object[] args, int numReturns, CallOptions options) { - List functionArgs = ArgumentsBuilder - .wrap(args, functionDescriptor.getLanguage() != Language.JAVA); + List functionArgs = ArgumentsBuilder.wrap(args); List returnIds = taskSubmitter.submitTask(functionDescriptor, functionArgs, numReturns, options); Preconditions.checkState(returnIds.size() == numReturns && returnIds.size() <= 1); @@ -190,8 +189,7 @@ public abstract class AbstractRayRuntime implements RayRuntime { private RayObject callActorFunction(RayActor rayActor, FunctionDescriptor functionDescriptor, Object[] args, int numReturns) { - List functionArgs = ArgumentsBuilder - .wrap(args, functionDescriptor.getLanguage() != Language.JAVA); + List functionArgs = ArgumentsBuilder.wrap(args); List returnIds = taskSubmitter.submitActorTask(rayActor, functionDescriptor, functionArgs, numReturns, null); Preconditions.checkState(returnIds.size() == numReturns && returnIds.size() <= 1); @@ -204,14 +202,11 @@ public abstract class AbstractRayRuntime implements RayRuntime { private RayActor createActorImpl(FunctionDescriptor functionDescriptor, Object[] args, ActorCreationOptions options) { - List functionArgs = ArgumentsBuilder - .wrap(args, functionDescriptor.getLanguage() != Language.JAVA); + List functionArgs = ArgumentsBuilder.wrap(args); if (functionDescriptor.getLanguage() != Language.JAVA && options != null) { Preconditions.checkState(Strings.isNullOrEmpty(options.jvmOptions)); } - RayActor actor = taskSubmitter - .createActor(functionDescriptor, functionArgs, - options); + RayActor actor = taskSubmitter.createActor(functionDescriptor, functionArgs, options); return actor; } diff --git a/java/runtime/src/main/java/org/ray/runtime/object/NativeRayObject.java 
b/java/runtime/src/main/java/org/ray/runtime/object/NativeRayObject.java index 20111b7a6..d1773c30d 100644 --- a/java/runtime/src/main/java/org/ray/runtime/object/NativeRayObject.java +++ b/java/runtime/src/main/java/org/ray/runtime/object/NativeRayObject.java @@ -1,7 +1,9 @@ package org.ray.runtime.object; +import com.google.common.base.Preconditions; + /** - * Binary representation of ray object. + * Binary representation of a ray object. See `RayObject` class in C++ for details. */ public class NativeRayObject { @@ -9,8 +11,21 @@ public class NativeRayObject { public byte[] metadata; public NativeRayObject(byte[] data, byte[] metadata) { + Preconditions.checkState(bufferLength(data) > 0 || bufferLength(metadata) > 0); this.data = data; this.metadata = metadata; } + + private static int bufferLength(byte[] buffer) { + if (buffer == null) { + return 0; + } + return buffer.length; + } + + @Override + public String toString() { + return ": " + bufferLength(data) + ", : " + bufferLength(metadata); + } } diff --git a/java/runtime/src/main/java/org/ray/runtime/object/ObjectSerializer.java b/java/runtime/src/main/java/org/ray/runtime/object/ObjectSerializer.java new file mode 100644 index 000000000..f386ae5c0 --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/object/ObjectSerializer.java @@ -0,0 +1,83 @@ +package org.ray.runtime.object; + +import java.util.Arrays; +import org.ray.api.exception.RayActorException; +import org.ray.api.exception.RayTaskException; +import org.ray.api.exception.RayWorkerException; +import org.ray.api.exception.UnreconstructableException; +import org.ray.api.id.ObjectId; +import org.ray.runtime.generated.Gcs.ErrorType; +import org.ray.runtime.util.Serializer; + +/** + * Serialize to and deserialize from {@link NativeRayObject}. Metadata is generated during + * serialization and respected during deserialization. 
+ */ +public class ObjectSerializer { + + private static final byte[] WORKER_EXCEPTION_META = String + .valueOf(ErrorType.WORKER_DIED.getNumber()).getBytes(); + private static final byte[] ACTOR_EXCEPTION_META = String + .valueOf(ErrorType.ACTOR_DIED.getNumber()).getBytes(); + private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String + .valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE.getNumber()).getBytes(); + + private static final byte[] TASK_EXECUTION_EXCEPTION_META = String + .valueOf(ErrorType.TASK_EXECUTION_EXCEPTION.getNumber()).getBytes(); + + private static final byte[] RAW_TYPE_META = "RAW".getBytes(); + + /** + * Deserialize an object from a {@link NativeRayObject} instance. + * + * @param nativeRayObject The object to deserialize. + * @param objectId The associated object ID of the object. + * @param classLoader The classLoader of the object. + * @return The deserialized object. + */ + public static Object deserialize(NativeRayObject nativeRayObject, ObjectId objectId, + ClassLoader classLoader) { + byte[] meta = nativeRayObject.metadata; + byte[] data = nativeRayObject.data; + + if (meta != null && meta.length > 0) { + // If meta is present (non-empty), decide how to deserialize the object from meta. + if (Arrays.equals(meta, RAW_TYPE_META)) { + return data; + } else if (Arrays.equals(meta, WORKER_EXCEPTION_META)) { + return RayWorkerException.INSTANCE; + } else if (Arrays.equals(meta, ACTOR_EXCEPTION_META)) { + return RayActorException.INSTANCE; + } else if (Arrays.equals(meta, UNRECONSTRUCTABLE_EXCEPTION_META)) { + return new UnreconstructableException(objectId); + } else if (Arrays.equals(meta, TASK_EXECUTION_EXCEPTION_META)) { + return Serializer.decode(data, classLoader); + } + throw new IllegalArgumentException("Unrecognized metadata " + Arrays.toString(meta)); + } else { + // If there is no metadata, the data is a serialized Java object; decode it. + return Serializer.decode(data, classLoader); + } + } + + /** + * Serialize a Java object to a {@link NativeRayObject} instance. 
+ * + * @param object The object to serialize. + * @return The serialized object. + */ + public static NativeRayObject serialize(Object object) { + if (object instanceof NativeRayObject) { + return (NativeRayObject) object; + } else if (object instanceof byte[]) { + // If the object is a byte array, skip serializing it and use a special metadata to + // indicate it's raw binary. So that this object can also be read by Python. + return new NativeRayObject((byte[]) object, RAW_TYPE_META); + } else if (object instanceof RayTaskException) { + return new NativeRayObject(Serializer.encode(object), + TASK_EXECUTION_EXCEPTION_META); + } else { + return new NativeRayObject(Serializer.encode(object), null); + } + } +} diff --git a/java/runtime/src/main/java/org/ray/runtime/object/ObjectStore.java b/java/runtime/src/main/java/org/ray/runtime/object/ObjectStore.java index e2cefbbb7..223c49b27 100644 --- a/java/runtime/src/main/java/org/ray/runtime/object/ObjectStore.java +++ b/java/runtime/src/main/java/org/ray/runtime/object/ObjectStore.java @@ -2,40 +2,21 @@ package org.ray.runtime.object; import com.google.common.base.Preconditions; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.stream.Collectors; import org.ray.api.RayObject; import org.ray.api.WaitResult; -import org.ray.api.exception.RayActorException; import org.ray.api.exception.RayException; -import org.ray.api.exception.RayTaskException; -import org.ray.api.exception.RayWorkerException; -import org.ray.api.exception.UnreconstructableException; import org.ray.api.id.ObjectId; import org.ray.runtime.context.WorkerContext; -import org.ray.runtime.generated.Gcs.ErrorType; -import org.ray.runtime.util.Serializer; /** * A class that is used to put/get objects to/from the object store. 
*/ public abstract class ObjectStore { - private static final byte[] WORKER_EXCEPTION_META = String - .valueOf(ErrorType.WORKER_DIED.getNumber()).getBytes(); - private static final byte[] ACTOR_EXCEPTION_META = String - .valueOf(ErrorType.ACTOR_DIED.getNumber()).getBytes(); - private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String - .valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE.getNumber()).getBytes(); - - private static final byte[] TASK_EXECUTION_EXCEPTION_META = String - .valueOf(ErrorType.TASK_EXECUTION_EXCEPTION.getNumber()).getBytes(); - - private static final byte[] RAW_TYPE_META = "RAW".getBytes(); - private final WorkerContext workerContext; public ObjectStore(WorkerContext workerContext) { @@ -65,7 +46,11 @@ public abstract class ObjectStore { * @return Id of the object. */ public ObjectId put(Object object) { - return putRaw(serialize(object)); + if (object instanceof NativeRayObject) { + throw new IllegalArgumentException( + "Trying to put a NativeRayObject. Please use putRaw instead."); + } + return putRaw(ObjectSerializer.serialize(object)); } /** @@ -77,7 +62,11 @@ public abstract class ObjectStore { * @param objectId Object id. */ public void put(Object object, ObjectId objectId) { - putRaw(serialize(object), objectId); + if (object instanceof NativeRayObject) { + throw new IllegalArgumentException( + "Trying to put a NativeRayObject. 
Please use putRaw instead."); + } + putRaw(ObjectSerializer.serialize(object), objectId); } /** @@ -106,7 +95,8 @@ public abstract class ObjectStore { NativeRayObject dataAndMeta = dataAndMetaList.get(i); Object object = null; if (dataAndMeta != null) { - object = deserialize(dataAndMeta, ids.get(i)); + object = ObjectSerializer + .deserialize(dataAndMeta, ids.get(i), workerContext.getCurrentClassLoader()); } if (object instanceof RayException) { // If the object is a `RayException`, it means that an error occurred during task @@ -174,57 +164,4 @@ public abstract class ObjectStore { */ public abstract void delete(List objectIds, boolean localOnly, boolean deleteCreatingTasks); - - /** - * Deserialize an object. - * - * @param nativeRayObject The object to deserialize. - * @param objectId The associated object ID of the object. - * @return The deserialized object. - */ - public Object deserialize(NativeRayObject nativeRayObject, ObjectId objectId) { - byte[] meta = nativeRayObject.metadata; - byte[] data = nativeRayObject.data; - - // If meta is not null, deserialize the object from meta. - if (meta != null && meta.length > 0) { - // If meta is not null, deserialize the object from meta. - if (Arrays.equals(meta, RAW_TYPE_META)) { - return data; - } else if (Arrays.equals(meta, WORKER_EXCEPTION_META)) { - return RayWorkerException.INSTANCE; - } else if (Arrays.equals(meta, ACTOR_EXCEPTION_META)) { - return RayActorException.INSTANCE; - } else if (Arrays.equals(meta, UNRECONSTRUCTABLE_EXCEPTION_META)) { - return new UnreconstructableException(objectId); - } else if (Arrays.equals(meta, TASK_EXECUTION_EXCEPTION_META)) { - return Serializer.decode(data, workerContext.getCurrentClassLoader()); - } - throw new IllegalArgumentException("Unrecognized metadata " + Arrays.toString(meta)); - } else { - // If data is not null, deserialize the Java object. - return Serializer.decode(data, workerContext.getCurrentClassLoader()); - } - } - - /** - * Serialize an object. 
- * - * @param object The object to serialize. - * @return The serialized object. - */ - public NativeRayObject serialize(Object object) { - if (object instanceof NativeRayObject) { - return (NativeRayObject) object; - } else if (object instanceof byte[]) { - // If the object is a byte array, skip serializing it and use a special metadata to - // indicate it's raw binary. So that this object can also be read by Python. - return new NativeRayObject((byte[]) object, RAW_TYPE_META); - } else if (object instanceof RayTaskException) { - return new NativeRayObject(Serializer.encode(object), - TASK_EXECUTION_EXCEPTION_META); - } else { - return new NativeRayObject(Serializer.encode(object), null); - } - } } diff --git a/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java b/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java index c74932c8f..11e524619 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java @@ -9,8 +9,7 @@ import org.ray.api.runtime.RayRuntime; import org.ray.runtime.AbstractRayRuntime; import org.ray.runtime.RayMultiWorkerNativeRuntime; import org.ray.runtime.object.NativeRayObject; -import org.ray.runtime.object.ObjectStore; -import org.ray.runtime.util.Serializer; +import org.ray.runtime.object.ObjectSerializer; /** * Helper methods to convert arguments from/to objects. @@ -26,37 +25,29 @@ public class ArgumentsBuilder { /** * Convert real function arguments to task spec arguments. 
*/ - public static List wrap(Object[] args, boolean crossLanguage) { + public static List wrap(Object[] args) { List ret = new ArrayList<>(); for (Object arg : args) { ObjectId id = null; - byte[] data = null; - if (arg == null) { - data = Serializer.encode(null); - } else if (arg instanceof RayObject) { + NativeRayObject value = null; + if (arg instanceof RayObject) { id = ((RayObject) arg).getId(); - } else if (arg instanceof byte[] && crossLanguage) { - // If the argument is a byte array and will be used by a different language, - // do not inline this argument. Because the other language doesn't know how - // to deserialize it. - id = Ray.put(arg).getId(); } else { - byte[] serialized = Serializer.encode(arg); - if (serialized.length > LARGEST_SIZE_PASS_BY_VALUE) { + value = ObjectSerializer.serialize(arg); + if (value.data.length > LARGEST_SIZE_PASS_BY_VALUE) { RayRuntime runtime = Ray.internal(); if (runtime instanceof RayMultiWorkerNativeRuntime) { runtime = ((RayMultiWorkerNativeRuntime) runtime).getCurrentRuntime(); } id = ((AbstractRayRuntime) runtime).getObjectStore() - .put(new NativeRayObject(serialized, null)); - } else { - data = serialized; + .putRaw(value); + value = null; } } if (id != null) { ret.add(FunctionArg.passByReference(id)); } else { - ret.add(FunctionArg.passByValue(data)); + ret.add(FunctionArg.passByValue(value)); } } return ret; @@ -65,10 +56,10 @@ public class ArgumentsBuilder { /** * Convert list of NativeRayObject to real function arguments. 
*/ - public static Object[] unwrap(ObjectStore objectStore, List args) { + public static Object[] unwrap(List args, ClassLoader classLoader) { Object[] realArgs = new Object[args.size()]; for (int i = 0; i < args.size(); i++) { - realArgs[i] = objectStore.deserialize(args.get(i), null); + realArgs[i] = ObjectSerializer.deserialize(args.get(i), null, classLoader); } return realArgs; } diff --git a/java/runtime/src/main/java/org/ray/runtime/task/FunctionArg.java b/java/runtime/src/main/java/org/ray/runtime/task/FunctionArg.java index 95bdcb0da..b397d9362 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/FunctionArg.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/FunctionArg.java @@ -1,6 +1,8 @@ package org.ray.runtime.task; +import com.google.common.base.Preconditions; import org.ray.api.id.ObjectId; +import org.ray.runtime.object.NativeRayObject; /** * Represents a function argument in task spec. @@ -16,11 +18,12 @@ public class FunctionArg { /** * Serialized data of this argument (passed by value). */ - public final byte[] data; + public final NativeRayObject value; - private FunctionArg(ObjectId id, byte[] data) { + private FunctionArg(ObjectId id, NativeRayObject value) { + Preconditions.checkState((id == null) != (value == null)); this.id = id; - this.data = data; + this.value = value; } /** @@ -33,8 +36,8 @@ public class FunctionArg { /** * Create a FunctionArg that will be passed by value. 
*/ - public static FunctionArg passByValue(byte[] data) { - return new FunctionArg(null, data); + public static FunctionArg passByValue(NativeRayObject value) { + return new FunctionArg(null, value); } @Override @@ -42,7 +45,7 @@ public class FunctionArg { if (id != null) { return ": " + id.toString(); } else { - return ": " + data.length; + return value.toString(); } } } diff --git a/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskSubmitter.java b/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskSubmitter.java index b5f22e8d3..0cb23fcb6 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskSubmitter.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskSubmitter.java @@ -154,7 +154,9 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { .collect(Collectors.toList())) .addAllArgs(args.stream().map(arg -> arg.id != null ? TaskArg.newBuilder() .addObjectIds(ByteString.copyFrom(arg.id.getBytes())).build() - : TaskArg.newBuilder().setData(ByteString.copyFrom(arg.data)).build()) + : TaskArg.newBuilder().setData(ByteString.copyFrom(arg.value.data)) + .setMetadata(arg.value.metadata != null ? ByteString + .copyFrom(arg.value.metadata) : ByteString.EMPTY).build()) .collect(Collectors.toList())); } @@ -233,7 +235,7 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { List args = getFunctionArgs(taskSpec).stream() .map(arg -> arg.id != null ? objectStore.getRaw(Collections.singletonList(arg.id), -1).get(0) - : new NativeRayObject(arg.data, null)) + : arg.value) .collect(Collectors.toList()); ((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(taskSpec); List returnObjects = taskExecutor @@ -246,7 +248,7 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { // If the task is an actor task or an actor creation task, // put the dummy object in object store, so those tasks which depends on it // can be executed. 
- putObject = new NativeRayObject(new byte[]{}, new byte[]{}); + putObject = new NativeRayObject(new byte[]{1}, null); } else { putObject = returnObjects.get(i); } @@ -279,7 +281,8 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { functionArgs.add(FunctionArg .passByReference(new ObjectId(arg.getObjectIds(0).toByteArray()))); } else { - functionArgs.add(FunctionArg.passByValue(arg.getData().toByteArray())); + functionArgs.add(FunctionArg.passByValue( + new NativeRayObject(arg.getData().toByteArray(), arg.getMetadata().toByteArray()))); } } return functionArgs; diff --git a/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java b/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java index 76e1116fe..95ff86c67 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java @@ -17,6 +17,7 @@ import org.ray.runtime.functionmanager.JavaFunctionDescriptor; import org.ray.runtime.functionmanager.RayFunction; import org.ray.runtime.generated.Common.TaskType; import org.ray.runtime.object.NativeRayObject; +import org.ray.runtime.object.ObjectSerializer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -87,7 +88,7 @@ public final class TaskExecutor { actor = currentActor; } - Object[] args = ArgumentsBuilder.unwrap(runtime.getObjectStore(), argsBytes); + Object[] args = ArgumentsBuilder.unwrap(argsBytes, rayFunction.classLoader); // Execute the task. Object result; if (!rayFunction.isConstructor()) { @@ -102,7 +103,7 @@ public final class TaskExecutor { maybeSaveCheckpoint(actor, runtime.getWorkerContext().getCurrentActorId()); } if (rayFunction.hasReturn()) { - returnObjects.add(runtime.getObjectStore().serialize(result)); + returnObjects.add(ObjectSerializer.serialize(result)); } } else { // TODO (kfstorm): handle checkpoint in core worker. 
@@ -113,8 +114,8 @@ public final class TaskExecutor { } catch (Exception e) { LOGGER.error("Error executing task " + taskId, e); if (taskType != TaskType.ACTOR_CREATION_TASK) { - if(rayFunction.hasReturn()) { - returnObjects.add(runtime.getObjectStore() + if (rayFunction.hasReturn()) { + returnObjects.add(ObjectSerializer .serialize(new RayTaskException("Error executing task " + taskId, e))); } } else { diff --git a/java/test/src/main/java/org/ray/api/test/RaySerializerTest.java b/java/test/src/main/java/org/ray/api/test/RaySerializerTest.java index 0cb8c6b23..2080a250c 100644 --- a/java/test/src/main/java/org/ray/api/test/RaySerializerTest.java +++ b/java/test/src/main/java/org/ray/api/test/RaySerializerTest.java @@ -3,9 +3,9 @@ package org.ray.api.test; import org.ray.api.Ray; import org.ray.api.RayPyActor; import org.ray.api.TestUtils; -import org.ray.api.id.ObjectId; +import org.ray.runtime.context.WorkerContext; import org.ray.runtime.object.NativeRayObject; -import org.ray.runtime.object.ObjectStore; +import org.ray.runtime.object.ObjectSerializer; import org.testng.Assert; import org.testng.annotations.Test; @@ -14,10 +14,10 @@ public class RaySerializerTest extends BaseMultiLanguageTest { @Test public void testSerializePyActor() { RayPyActor pyActor = Ray.createPyActor("test", "RaySerializerTest"); - ObjectStore objectStore = TestUtils.getRuntime().getObjectStore(); - NativeRayObject nativeRayObject = objectStore.serialize(pyActor); - RayPyActor result = (RayPyActor) objectStore - .deserialize(nativeRayObject, ObjectId.fromRandom()); + WorkerContext workerContext = TestUtils.getRuntime().getWorkerContext(); + NativeRayObject nativeRayObject = ObjectSerializer.serialize(pyActor); + RayPyActor result = (RayPyActor) ObjectSerializer + .deserialize(nativeRayObject, null, workerContext.getCurrentClassLoader()); Assert.assertEquals(result.getId(), pyActor.getId()); Assert.assertEquals(result.getModuleName(), "test"); Assert.assertEquals(result.getClassName(), 
"RaySerializerTest"); diff --git a/python/ray/includes/task.pxd b/python/ray/includes/task.pxd index 00b45d02b..1645ebf85 100644 --- a/python/ray/includes/task.pxd +++ b/python/ray/includes/task.pxd @@ -54,8 +54,10 @@ cdef extern from "ray/common/task/task_spec.h" namespace "ray" nogil: int ArgIdCount(uint64_t arg_index) const CObjectID ArgId(uint64_t arg_index, uint64_t id_index) const CObjectID ReturnId(uint64_t return_index) const - const uint8_t *ArgVal(uint64_t arg_index) const - size_t ArgValLength(uint64_t arg_index) const + const uint8_t *ArgData(uint64_t arg_index) const + size_t ArgDataSize(uint64_t arg_index) const + const uint8_t *ArgMetadata(uint64_t arg_index) const + size_t ArgMetadataSize(uint64_t arg_index) const double GetRequiredResource(const c_string &resource_name) const const ResourceSet GetRequiredResources() const const ResourceSet GetRequiredPlacementResources() const @@ -86,7 +88,7 @@ cdef extern from "ray/common/task/task_util.h" namespace "ray" nogil: TaskSpecBuilder &AddByRefArg(const CObjectID &arg_id) - TaskSpecBuilder &AddByValueArg(const c_string &data) + TaskSpecBuilder &AddByValueArg(const c_string &data, const c_string &metadata) TaskSpecBuilder &SetActorCreationTaskSpec( const CActorID &actor_id, uint64_t max_reconstructions, diff --git a/python/ray/includes/task.pxi b/python/ray/includes/task.pxi index f1290ac7d..943a123ef 100644 --- a/python/ray/includes/task.pxi +++ b/python/ray/includes/task.pxi @@ -12,6 +12,7 @@ from ray.includes.task cimport ( TaskSpecBuilder, TaskTableData, ) +from ray.ray_constants import RAW_BUFFER_METADATA from ray.utils import decode @@ -68,10 +69,12 @@ cdef class TaskSpec: for arg in arguments: if isinstance(arg, ObjectID): builder.AddByRefArg((arg).native()) + elif isinstance(arg, bytes): + builder.AddByValueArg(arg, RAW_BUFFER_METADATA) else: pickled_str = pickle.dumps( arg, protocol=pickle.HIGHEST_PROTOCOL) - builder.AddByValueArg(pickled_str) + builder.AddByValueArg(pickled_str, b'') if not 
actor_creation_id.is_nil(): # Actor creation task. @@ -180,9 +183,12 @@ cdef class TaskSpec: arg_list.append( ObjectID(task_spec.ArgId(i, 0).Binary())) else: - serialized_str = ( - task_spec.ArgVal(i)[:task_spec.ArgValLength(i)]) - obj = pickle.loads(serialized_str) + data = (task_spec.ArgData(i)[:task_spec.ArgDataSize(i)]) + metadata = (task_spec.ArgMetadata(i)[:task_spec.ArgMetadataSize(i)]) + if metadata == RAW_BUFFER_METADATA: + obj = data + else: + obj = pickle.loads(data) arg_list.append(obj) elif lang == LANGUAGE_JAVA: arg_list = num_args * [""] diff --git a/src/ray/common/buffer.h b/src/ray/common/buffer.h index 10c49cf8e..15e928958 100644 --- a/src/ray/common/buffer.h +++ b/src/ray/common/buffer.h @@ -39,6 +39,11 @@ class LocalMemoryBuffer : public Buffer { public: /// Constructor. /// + /// By default when initializing a LocalMemoryBuffer with a data pointer and a length, + /// it just assigns the pointer and length without copying the data content. This is + /// for performance reasons. In this case the buffer cannot ensure data validity. It + /// instead relies on the lifetime of the passed-in data pointer. + /// /// \param data The data pointer to the passed-in buffer. /// \param size The size of the passed in buffer. /// \param copy_data If true, data will be copied and owned by this buffer, diff --git a/src/ray/common/ray_object.h b/src/ray/common/ray_object.h new file mode 100644 index 000000000..a7b3c9b21 --- /dev/null +++ b/src/ray/common/ray_object.h @@ -0,0 +1,77 @@ +#ifndef RAY_COMMON_RAY_OBJECT_H +#define RAY_COMMON_RAY_OBJECT_H + +#include "ray/common/buffer.h" +#include "ray/util/logging.h" + +namespace ray { + +/// Binary representation of a ray object, consisting of buffer pointers to data and +/// metadata. A ray object may have both data and metadata, or only one of them. +class RayObject { + public: + /// Create a ray object instance. 
+ /// + /// Setting `copy_data` to `false` is fine for most cases - for example when putting + /// an object into the store with a temporary RayObject, and we don't want to do an extra + /// copy. But in some cases we do want to always hold valid data - for example, the memory + /// store uses RayObject to represent objects; in this case we actually want the object + /// data to remain valid after the user puts it into the store. + /// + /// \param[in] data Data of the ray object. + /// \param[in] metadata Metadata of the ray object. + /// \param[in] copy_data Whether this class should hold a copy of data. + RayObject(const std::shared_ptr &data, const std::shared_ptr &metadata, + bool copy_data = false) + : data_(data), metadata_(metadata), has_data_copy_(copy_data) { + RAY_CHECK(!data || data_->Size()) + << "Zero-length buffers are not allowed when constructing a RayObject."; + RAY_CHECK(!metadata || metadata->Size()) + << "Zero-length buffers are not allowed when constructing a RayObject."; + + if (has_data_copy_) { + // If this object is required to hold a copy of the data, + // make a copy if the passed in buffers don't already have a copy. + if (data_ && !data_->OwnsData()) { + data_ = std::make_shared(data_->Data(), data_->Size(), + /*copy_data=*/true); + } + + if (metadata_ && !metadata_->OwnsData()) { + metadata_ = std::make_shared( + metadata_->Data(), metadata_->Size(), /*copy_data=*/true); + } + } + + RAY_CHECK(data_ || metadata_) << "Data and metadata cannot both be empty."; + } + + /// Return the data of the ray object. + const std::shared_ptr &GetData() const { return data_; }; + + /// Return the metadata of the ray object. + const std::shared_ptr &GetMetadata() const { return metadata_; }; + + uint64_t GetSize() const { + uint64_t size = 0; + size += (data_ != nullptr) ? data_->Size() : 0; + size += (metadata_ != nullptr) ? metadata_->Size() : 0; + return size; + } + + /// Whether this object has data. 
+ bool HasData() const { return data_ != nullptr; } + + /// Whether this object has metadata. + bool HasMetadata() const { return metadata_ != nullptr; } + + private: + std::shared_ptr data_; + std::shared_ptr metadata_; + /// Whether this class holds a data copy. + bool has_data_copy_; +}; + +} // namespace ray + +#endif // RAY_COMMON_RAY_OBJECT_H \ No newline at end of file diff --git a/src/ray/common/task/task_spec.cc b/src/ray/common/task/task_spec.cc index 1f04d3943..cca5bf048 100644 --- a/src/ray/common/task/task_spec.cc +++ b/src/ray/common/task/task_spec.cc @@ -53,14 +53,22 @@ ObjectID TaskSpecification::ArgId(size_t arg_index, size_t id_index) const { return ObjectID::FromBinary(message_->args(arg_index).object_ids(id_index)); } -const uint8_t *TaskSpecification::ArgVal(size_t arg_index) const { +const uint8_t *TaskSpecification::ArgData(size_t arg_index) const { return reinterpret_cast(message_->args(arg_index).data().data()); } -size_t TaskSpecification::ArgValLength(size_t arg_index) const { +size_t TaskSpecification::ArgDataSize(size_t arg_index) const { return message_->args(arg_index).data().size(); } +const uint8_t *TaskSpecification::ArgMetadata(size_t arg_index) const { + return reinterpret_cast(message_->args(arg_index).metadata().data()); +} + +size_t TaskSpecification::ArgMetadataSize(size_t arg_index) const { + return message_->args(arg_index).metadata().size(); +} + const ResourceSet TaskSpecification::GetRequiredResources() const { return required_resources_; } diff --git a/src/ray/common/task/task_spec.h b/src/ray/common/task/task_spec.h index d1fae90ac..d481a4cbb 100644 --- a/src/ray/common/task/task_spec.h +++ b/src/ray/common/task/task_spec.h @@ -70,9 +70,13 @@ class TaskSpecification : public MessageWrapper { ObjectID ReturnId(size_t return_index) const; - const uint8_t *ArgVal(size_t arg_index) const; + const uint8_t *ArgData(size_t arg_index) const; - size_t ArgValLength(size_t arg_index) const; + size_t ArgDataSize(size_t arg_index) 
const; + + const uint8_t *ArgMetadata(size_t arg_index) const; + + size_t ArgMetadataSize(size_t arg_index) const; /// Return the resources that are to be acquired during the execution of this /// task. diff --git a/src/ray/common/task/task_util.h b/src/ray/common/task/task_util.h index 2bc635cc0..a5967d5c4 100644 --- a/src/ray/common/task/task_util.h +++ b/src/ray/common/task/task_util.h @@ -1,6 +1,8 @@ #ifndef RAY_COMMON_TASK_TASK_UTIL_H #define RAY_COMMON_TASK_TASK_UTIL_H +#include "ray/common/buffer.h" +#include "ray/common/ray_object.h" #include "ray/common/task/task_spec.h" #include "ray/protobuf/common.pb.h" @@ -56,19 +58,29 @@ class TaskSpecBuilder { /// Add a by-value argument to the task. /// /// \param data String object that contains the data. + /// \param metadata String object that contains the metadata. /// \return Reference to the builder object itself. - TaskSpecBuilder &AddByValueArg(const std::string &data) { - message_->add_args()->set_data(data); + TaskSpecBuilder &AddByValueArg(const std::string &data, const std::string &metadata) { + auto arg = message_->add_args(); + arg->set_data(data); + arg->set_metadata(metadata); return *this; } /// Add a by-value argument to the task. /// - /// \param data Pointer to the data. - /// \param size Size of the data. + /// \param value the RayObject instance that contains the data and the metadata. /// \return Reference to the builder object itself. 
- TaskSpecBuilder &AddByValueArg(const void *data, size_t size) { - message_->add_args()->set_data(data, size); + TaskSpecBuilder &AddByValueArg(const RayObject &value) { + auto arg = message_->add_args(); + if (value.HasData()) { + const auto &data = value.GetData(); + arg->set_data(data->Data(), data->Size()); + } + if (value.HasMetadata()) { + const auto &metadata = value.GetMetadata(); + arg->set_metadata(metadata->Data(), metadata->Size()); + } return *this; } diff --git a/src/ray/core_worker/common.h b/src/ray/core_worker/common.h index 631ada2ba..ad540e083 100644 --- a/src/ray/core_worker/common.h +++ b/src/ray/core_worker/common.h @@ -3,8 +3,8 @@ #include -#include "ray/common/buffer.h" #include "ray/common/id.h" +#include "ray/common/ray_object.h" #include "ray/common/task/task_spec.h" #include "ray/raylet/raylet_client.h" #include "ray/util/util.h" @@ -31,12 +31,13 @@ class TaskArg { return TaskArg(std::make_shared(object_id), nullptr); } - /// Create a pass-by-reference task argument. + /// Create a pass-by-value task argument. /// - /// \param[in] object_id Id of the argument. + /// \param[in] value Value of the argument. /// \return The task argument. - static TaskArg PassByValue(const std::shared_ptr &data) { - return TaskArg(nullptr, data); + static TaskArg PassByValue(const std::shared_ptr &value) { + RAY_CHECK(value) << "Value can't be null."; + return TaskArg(nullptr, value); } /// Return true if this argument is passed by reference, false if passed by value. @@ -49,19 +50,19 @@ class TaskArg { } /// Get the value. 
- std::shared_ptr GetValue() const { - RAY_CHECK(data_ != nullptr) << "This argument isn't passed by value."; - return data_; + const RayObject &GetValue() const { + RAY_CHECK(value_ != nullptr) << "This argument isn't passed by value."; + return *value_; } private: - TaskArg(const std::shared_ptr id, const std::shared_ptr data) - : id_(id), data_(data) {} + TaskArg(const std::shared_ptr id, const std::shared_ptr value) + : id_(id), value_(value) {} - /// Id of the argument, if passed by reference, otherwise nullptr. + /// Id of the argument if passed by reference, otherwise nullptr. const std::shared_ptr id_; - /// Data of the argument, if passed by value, otherwise nullptr. - const std::shared_ptr data_; + /// Value of the argument if passed by value, otherwise nullptr. + const std::shared_ptr value_; }; enum class StoreProviderType { PLASMA, MEMORY }; diff --git a/src/ray/core_worker/lib/java/jni_init.cc b/src/ray/core_worker/lib/java/jni_init.cc index a7bd918ac..7339c54c9 100644 --- a/src/ray/core_worker/lib/java/jni_init.cc +++ b/src/ray/core_worker/lib/java/jni_init.cc @@ -43,7 +43,7 @@ jmethodID java_language_get_number; jclass java_function_arg_class; jfieldID java_function_arg_id; -jfieldID java_function_arg_data; +jfieldID java_function_arg_value; jclass java_base_task_options_class; jfieldID java_base_task_options_resources; @@ -137,7 +137,8 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { java_function_arg_class = LoadClass(env, "org/ray/runtime/task/FunctionArg"); java_function_arg_id = env->GetFieldID(java_function_arg_class, "id", "Lorg/ray/api/id/ObjectId;"); - java_function_arg_data = env->GetFieldID(java_function_arg_class, "data", "[B"); + java_function_arg_value = env->GetFieldID(java_function_arg_class, "value", + "Lorg/ray/runtime/object/NativeRayObject;"); java_base_task_options_class = LoadClass(env, "org/ray/api/options/BaseTaskOptions"); java_base_task_options_resources = diff --git a/src/ray/core_worker/lib/java/jni_utils.h 
b/src/ray/core_worker/lib/java/jni_utils.h index 396b5a841..d2f5c71df 100644 --- a/src/ray/core_worker/lib/java/jni_utils.h +++ b/src/ray/core_worker/lib/java/jni_utils.h @@ -4,6 +4,7 @@ #include #include "ray/common/buffer.h" #include "ray/common/id.h" +#include "ray/common/ray_object.h" #include "ray/common/status.h" #include "ray/core_worker/store_provider/store_provider.h" @@ -77,12 +78,12 @@ extern jclass java_language_class; /// getNumber of Language class extern jmethodID java_language_get_number; -/// NativeTaskArg class +/// FunctionArg class extern jclass java_function_arg_class; -/// id field of NativeTaskArg class +/// id field of FunctionArg class extern jfieldID java_function_arg_id; -/// data field of NativeTaskArg class -extern jfieldID java_function_arg_data; +/// value field of FunctionArg class +extern jfieldID java_function_arg_value; /// BaseTaskOptions class extern jclass java_base_task_options_class; @@ -279,11 +280,11 @@ inline std::shared_ptr JavaNativeRayObjectToNativeRayObject( std::shared_ptr data_buffer = JavaByteArrayToNativeBuffer(env, java_data); std::shared_ptr metadata_buffer = JavaByteArrayToNativeBuffer(env, java_metadata); - if (!data_buffer) { - data_buffer = std::make_shared(nullptr, 0); + if (data_buffer && data_buffer->Size() == 0) { + data_buffer = nullptr; } - if (!metadata_buffer) { - metadata_buffer = std::make_shared(nullptr, 0); + if (metadata_buffer && metadata_buffer->Size() == 0) { + metadata_buffer = nullptr; } return std::make_shared(data_buffer, metadata_buffer); } diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc index beaa000b2..b626cf122 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc +++ b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc @@ -34,10 +34,11 @@ inline std::vector ToTaskArgs(JNIEnv *env, jobject args) { return 
ray::TaskArg::PassByReference( JavaByteArrayToId(env, java_id_bytes)); } - auto java_data = - static_cast(env->GetObjectField(arg, java_function_arg_data)); - RAY_CHECK(java_data) << "Both id and data of FunctionArg are null."; - return ray::TaskArg::PassByValue(JavaByteArrayToNativeBuffer(env, java_data)); + auto java_value = + static_cast(env->GetObjectField(arg, java_function_arg_value)); + RAY_CHECK(java_value) << "Both id and value of FunctionArg are null."; + auto value = JavaNativeRayObjectToNativeRayObject(env, java_value); + return ray::TaskArg::PassByValue(value); }); return task_args; } diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.cc b/src/ray/core_worker/store_provider/plasma_store_provider.cc index ab728cde4..807c9755d 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.cc +++ b/src/ray/core_worker/store_provider/plasma_store_provider.cc @@ -69,9 +69,15 @@ Status CoreWorkerPlasmaStoreProvider::FetchAndGetFromPlasmaStore( for (size_t i = 0; i < plasma_results.size(); i++) { if (plasma_results[i].data != nullptr || plasma_results[i].metadata != nullptr) { const auto &object_id = batch_ids[i]; - const auto result_object = std::make_shared( - std::make_shared(plasma_results[i].data), - std::make_shared(plasma_results[i].metadata)); + std::shared_ptr data = nullptr; + std::shared_ptr metadata = nullptr; + if (plasma_results[i].data && plasma_results[i].data->size()) { + data = std::make_shared(plasma_results[i].data); + } + if (plasma_results[i].metadata && plasma_results[i].metadata->size()) { + metadata = std::make_shared(plasma_results[i].metadata); + } + const auto result_object = std::make_shared(data, metadata); (*results)[object_id] = result_object; remaining.erase(object_id); if (IsException(*result_object)) { @@ -174,6 +180,9 @@ Status CoreWorkerPlasmaStoreProvider::Delete(const std::vector &object bool CoreWorkerPlasmaStoreProvider::IsException(const RayObject &object) { // TODO (kfstorm): metadata should 
be structured. + if (!object.HasMetadata()) { + return false; + } const std::string metadata(reinterpret_cast(object.GetMetadata()->Data()), object.GetMetadata()->Size()); const auto error_type_descriptor = ray::rpc::ErrorType_descriptor(); diff --git a/src/ray/core_worker/store_provider/store_provider.h b/src/ray/core_worker/store_provider/store_provider.h index dc7757427..0e3fffc23 100644 --- a/src/ray/core_worker/store_provider/store_provider.h +++ b/src/ray/core_worker/store_provider/store_provider.h @@ -8,56 +8,6 @@ namespace ray { -/// Binary representation of a ray object. -class RayObject { - public: - /// Create a ray object instance. - /// - /// \param[in] data Data of the ray object. - /// \param[in] metadata Metadata of the ray object. - /// \param[in] copy_data Whether this class should hold a copy of data. - RayObject(const std::shared_ptr &data, const std::shared_ptr &metadata, - bool copy_data = false) - : data_(data), metadata_(metadata), has_data_copy_(copy_data) { - if (has_data_copy_) { - // If this object is required to hold a copy of the data, - // make a copy if the passed in buffers don't already have a copy. - if (data_ && !data_->OwnsData()) { - data_ = std::make_shared(data_->Data(), data_->Size(), true); - } - - if (metadata_ && !metadata_->OwnsData()) { - metadata_ = std::make_shared(metadata_->Data(), - metadata_->Size(), true); - } - } - } - - /// Return the data of the ray object. - const std::shared_ptr &GetData() const { return data_; }; - - /// Return the metadata of the ray object. - const std::shared_ptr &GetMetadata() const { return metadata_; }; - - uint64_t GetSize() const { - uint64_t size = 0; - size += (data_ != nullptr) ? data_->Size() : 0; - size += (metadata_ != nullptr) ? metadata_->Size() : 0; - return size; - } - - /// Whether this object has metadata. - bool HasMetadata() const { return metadata_ != nullptr && metadata_->Size() > 0; } - - private: - /// Data of the ray object. 
- std::shared_ptr data_; - /// Metadata of the ray object. - std::shared_ptr metadata_; - /// Whether this class holds a data copy. - bool has_data_copy_; -}; - /// Provider interface for store access. Store provider should inherit from this class and /// provide implementions for the methods. The actual store provider may use a plasma /// store or local memory store in worker process, or possibly other types of storage. diff --git a/src/ray/core_worker/task_execution.cc b/src/ray/core_worker/task_execution.cc index f397ab314..5eea02638 100644 --- a/src/ray/core_worker/task_execution.cc +++ b/src/ray/core_worker/task_execution.cc @@ -89,10 +89,17 @@ Status CoreWorkerTaskExecutionInterface::BuildArgsForExecutor( indices.push_back(i); } else { // pass by value. - (*args)[i] = std::make_shared( - std::make_shared(const_cast(task.ArgVal(i)), - task.ArgValLength(i)), - nullptr); + std::shared_ptr data = nullptr; + if (task.ArgDataSize(i)) { + data = std::make_shared(const_cast(task.ArgData(i)), + task.ArgDataSize(i)); + } + std::shared_ptr metadata = nullptr; + if (task.ArgMetadataSize(i)) { + metadata = std::make_shared( + const_cast(task.ArgMetadata(i)), task.ArgMetadataSize(i)); + } + (*args)[i] = std::make_shared(data, metadata); } } diff --git a/src/ray/core_worker/task_interface.cc b/src/ray/core_worker/task_interface.cc index de880f83e..10a792069 100644 --- a/src/ray/core_worker/task_interface.cc +++ b/src/ray/core_worker/task_interface.cc @@ -130,7 +130,7 @@ void CoreWorkerTaskInterface::BuildCommonTaskSpec( if (arg.IsPassedByReference()) { builder.AddByRefArg(arg.GetReference()); } else { - builder.AddByValueArg(arg.GetValue()->Data(), arg.GetValue()->Size()); + builder.AddByValueArg(arg.GetValue()); } } diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index b973e7c77..649a61642 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -3,6 +3,7 @@ 
#include "gtest/gtest.h" #include "ray/common/buffer.h" +#include "ray/common/ray_object.h" #include "ray/core_worker/context.h" #include "ray/core_worker/core_worker.h" #include "ray/core_worker/transport/direct_actor_transport.h" @@ -59,7 +60,7 @@ std::unique_ptr CreateActorHelper( RayFunction func{ray::Language::PYTHON, {"actor creation task"}}; std::vector args; - args.emplace_back(TaskArg::PassByValue(buffer)); + args.emplace_back(TaskArg::PassByValue(std::make_shared(buffer, nullptr))); ActorCreationOptions actor_options{max_reconstructions, is_direct_call, resources, {}}; @@ -232,7 +233,8 @@ void CoreWorkerTest::TestNormalTask( RAY_CHECK_OK(driver.Objects().Put(RayObject(buffer2, nullptr), &object_id)); std::vector args; - args.emplace_back(TaskArg::PassByValue(buffer1)); + args.emplace_back( + TaskArg::PassByValue(std::make_shared(buffer1, nullptr))); args.emplace_back(TaskArg::PassByReference(object_id)); RayFunction func{ray::Language::PYTHON, {}}; @@ -273,8 +275,10 @@ void CoreWorkerTest::TestActorTask( // Create arguments with PassByRef and PassByValue. std::vector args; - args.emplace_back(TaskArg::PassByValue(buffer1)); - args.emplace_back(TaskArg::PassByValue(buffer2)); + args.emplace_back( + TaskArg::PassByValue(std::make_shared(buffer1, nullptr))); + args.emplace_back( + TaskArg::PassByValue(std::make_shared(buffer2, nullptr))); TaskOptions options{1, resources}; std::vector return_ids; @@ -315,7 +319,8 @@ void CoreWorkerTest::TestActorTask( // Create arguments with PassByRef and PassByValue. std::vector args; args.emplace_back(TaskArg::PassByReference(object_id)); - args.emplace_back(TaskArg::PassByValue(buffer2)); + args.emplace_back( + TaskArg::PassByValue(std::make_shared(buffer2, nullptr))); TaskOptions options{1, resources}; std::vector return_ids; @@ -380,7 +385,8 @@ void CoreWorkerTest::TestActorReconstruction( // Create arguments with PassByValue. 
std::vector args; - args.emplace_back(TaskArg::PassByValue(buffer1)); + args.emplace_back( + TaskArg::PassByValue(std::make_shared(buffer1, nullptr))); TaskOptions options{1, resources}; std::vector return_ids; @@ -425,7 +431,8 @@ void CoreWorkerTest::TestActorFailure( // Create arguments with PassByRef and PassByValue. std::vector args; - args.emplace_back(TaskArg::PassByValue(buffer1)); + args.emplace_back( + TaskArg::PassByValue(std::make_shared(buffer1, nullptr))); TaskOptions options{1, resources}; std::vector return_ids; @@ -618,11 +625,10 @@ TEST_F(ZeroNodeTest, TestTaskArg) { ASSERT_TRUE(by_ref.IsPassedByReference()); ASSERT_EQ(by_ref.GetReference(), id); // Test by-value argument. - std::shared_ptr buffer = - std::make_shared(static_cast(0), 0); - TaskArg by_value = TaskArg::PassByValue(buffer); + auto buffer = GenerateRandomBuffer(); + TaskArg by_value = TaskArg::PassByValue(std::make_shared(buffer, nullptr)); ASSERT_FALSE(by_value.IsPassedByReference()); - auto data = by_value.GetValue(); + auto data = by_value.GetValue().GetData(); ASSERT_TRUE(data != nullptr); ASSERT_EQ(*data, *buffer); } @@ -635,7 +641,7 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) { auto buffer = std::make_shared(array, sizeof(array)); RayFunction function{ray::Language::PYTHON, {}}; std::vector args; - args.emplace_back(TaskArg::PassByValue(buffer)); + args.emplace_back(TaskArg::PassByValue(std::make_shared(buffer, nullptr))); std::unordered_map resources; ActorCreationOptions actor_options{0, /*is_direct_call*/ true, resources, {}}; @@ -664,7 +670,7 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) { if (arg.IsPassedByReference()) { builder.AddByRefArg(arg.GetReference()); } else { - builder.AddByValueArg(arg.GetValue()->Data(), arg.GetValue()->Size()); + builder.AddByValueArg(arg.GetValue()); } } @@ -696,7 +702,7 @@ TEST_F(SingleNodeTest, TestDirectActorTaskSubmissionPerf) { auto buffer = std::make_shared(array, sizeof(array)); RayFunction func{ray::Language::PYTHON, {}}; std::vector args; - 
args.emplace_back(TaskArg::PassByValue(buffer)); + args.emplace_back(TaskArg::PassByValue(std::make_shared(buffer, nullptr))); std::unordered_map resources; ActorCreationOptions actor_options{0, /*is_direct_call*/ true, resources, {}}; @@ -712,7 +718,7 @@ TEST_F(SingleNodeTest, TestDirectActorTaskSubmissionPerf) { for (int i = 0; i < num_tasks; i++) { // Create arguments with PassByValue. std::vector args; - args.emplace_back(TaskArg::PassByValue(buffer)); + args.emplace_back(TaskArg::PassByValue(std::make_shared(buffer, nullptr))); TaskOptions options{1, resources}; std::vector return_ids; diff --git a/src/ray/core_worker/transport/direct_actor_transport.cc b/src/ray/core_worker/transport/direct_actor_transport.cc index 087560df4..513df2b3e 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.cc +++ b/src/ray/core_worker/transport/direct_actor_transport.cc @@ -224,10 +224,10 @@ void CoreWorkerDirectActorTaskReceiver::HandlePushTask( /*transport_type=*/static_cast(TaskTransportType::DIRECT_ACTOR)); return_object->set_object_id(id.Binary()); const auto &result = results[i]; - if (result->GetData() != nullptr) { + if (result->HasData()) { return_object->set_data(result->GetData()->Data(), result->GetData()->Size()); } - if (result->GetMetadata() != nullptr) { + if (result->HasMetadata()) { return_object->set_metadata(result->GetMetadata()->Data(), result->GetMetadata()->Size()); } diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto index 15f17c2f5..000670af8 100644 --- a/src/ray/protobuf/common.proto +++ b/src/ray/protobuf/common.proto @@ -74,6 +74,8 @@ message TaskArg { repeated bytes object_ids = 1; // Data for pass-by-value arguments. bytes data = 2; + // Metadata for pass-by-value arguments. + bytes metadata = 3; } // Task spec of an actor creation task.