Support metadata for passing by value task arguments (#5527)

This commit is contained in:
Kai Yang
2019-09-08 11:07:48 +08:00
committed by Hao Chen
parent cb7102f31e
commit d8f5804690
27 changed files with 364 additions and 244 deletions
@@ -176,8 +176,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
private RayObject callNormalFunction(FunctionDescriptor functionDescriptor,
Object[] args, int numReturns, CallOptions options) {
List<FunctionArg> functionArgs = ArgumentsBuilder
.wrap(args, functionDescriptor.getLanguage() != Language.JAVA);
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args);
List<ObjectId> returnIds = taskSubmitter.submitTask(functionDescriptor,
functionArgs, numReturns, options);
Preconditions.checkState(returnIds.size() == numReturns && returnIds.size() <= 1);
@@ -190,8 +189,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
private RayObject callActorFunction(RayActor rayActor,
FunctionDescriptor functionDescriptor, Object[] args, int numReturns) {
List<FunctionArg> functionArgs = ArgumentsBuilder
.wrap(args, functionDescriptor.getLanguage() != Language.JAVA);
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args);
List<ObjectId> returnIds = taskSubmitter.submitActorTask(rayActor,
functionDescriptor, functionArgs, numReturns, null);
Preconditions.checkState(returnIds.size() == numReturns && returnIds.size() <= 1);
@@ -204,14 +202,11 @@ public abstract class AbstractRayRuntime implements RayRuntime {
private RayActor createActorImpl(FunctionDescriptor functionDescriptor,
Object[] args, ActorCreationOptions options) {
List<FunctionArg> functionArgs = ArgumentsBuilder
.wrap(args, functionDescriptor.getLanguage() != Language.JAVA);
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args);
if (functionDescriptor.getLanguage() != Language.JAVA && options != null) {
Preconditions.checkState(Strings.isNullOrEmpty(options.jvmOptions));
}
RayActor actor = taskSubmitter
.createActor(functionDescriptor, functionArgs,
options);
RayActor actor = taskSubmitter.createActor(functionDescriptor, functionArgs, options);
return actor;
}
@@ -1,7 +1,9 @@
package org.ray.runtime.object;
import com.google.common.base.Preconditions;
/**
* Binary representation of ray object.
* Binary representation of a ray object. See `RayObject` class in C++ for details.
*/
public class NativeRayObject {
@@ -9,8 +11,21 @@ public class NativeRayObject {
public byte[] metadata;
/**
 * Create a NativeRayObject from a data buffer and a metadata buffer.
 *
 * <p>At least one of the two buffers must be non-null and non-empty; otherwise the
 * {@code Preconditions.checkState} call below fails. The arrays are stored as given
 * (no copy is made here).
 *
 * @param data The data buffer, may be null.
 * @param metadata The metadata buffer, may be null.
 */
public NativeRayObject(byte[] data, byte[] metadata) {
// Reject objects that carry neither data nor metadata.
Preconditions.checkState(bufferLength(data) > 0 || bufferLength(metadata) > 0);
this.data = data;
this.metadata = metadata;
}
/**
 * Return the length of {@code buffer}, treating a null buffer as an empty one.
 *
 * @param buffer The buffer to measure, may be null.
 * @return 0 if {@code buffer} is null, otherwise {@code buffer.length}.
 */
private static int bufferLength(byte[] buffer) {
  return buffer == null ? 0 : buffer.length;
}
/**
 * Describe this object by the lengths of its data and metadata buffers; a null buffer is
 * reported with length 0.
 */
@Override
public String toString() {
  StringBuilder description = new StringBuilder();
  description.append("<data>: ").append(bufferLength(data));
  description.append(", <metadata>: ").append(bufferLength(metadata));
  return description.toString();
}
}
@@ -0,0 +1,83 @@
package org.ray.runtime.object;
import java.util.Arrays;
import org.ray.api.exception.RayActorException;
import org.ray.api.exception.RayTaskException;
import org.ray.api.exception.RayWorkerException;
import org.ray.api.exception.UnreconstructableException;
import org.ray.api.id.ObjectId;
import org.ray.runtime.generated.Gcs.ErrorType;
import org.ray.runtime.util.Serializer;
/**
* Serialize to and deserialize from {@link NativeRayObject}. Metadata is generated during
* serialization and respected during deserialization.
*/
public class ObjectSerializer {

  /**
   * Charset used to encode the metadata markers below. It is pinned explicitly so that the
   * marker bytes do not depend on the JVM's default charset; every marker is plain ASCII
   * (decimal digits or "RAW").
   */
  private static final java.nio.charset.Charset META_CHARSET =
      java.nio.charset.StandardCharsets.US_ASCII;

  private static final byte[] WORKER_EXCEPTION_META = errorTypeMeta(ErrorType.WORKER_DIED);

  private static final byte[] ACTOR_EXCEPTION_META = errorTypeMeta(ErrorType.ACTOR_DIED);

  private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META =
      errorTypeMeta(ErrorType.OBJECT_UNRECONSTRUCTABLE);

  private static final byte[] TASK_EXECUTION_EXCEPTION_META =
      errorTypeMeta(ErrorType.TASK_EXECUTION_EXCEPTION);

  private static final byte[] RAW_TYPE_META = "RAW".getBytes(META_CHARSET);

  /**
   * Build the metadata marker of an error type: the decimal string of its number.
   */
  private static byte[] errorTypeMeta(ErrorType errorType) {
    return String.valueOf(errorType.getNumber()).getBytes(META_CHARSET);
  }

  /**
   * Deserialize an object from a {@link NativeRayObject} instance.
   *
   * @param nativeRayObject The object to deserialize.
   * @param objectId The associated object ID of the object. It is only used to construct an
   *     {@link UnreconstructableException}.
   * @param classLoader The classLoader of the object.
   * @return The deserialized object.
   * @throws IllegalArgumentException If the metadata is non-empty but not recognized.
   */
  public static Object deserialize(NativeRayObject nativeRayObject, ObjectId objectId,
      ClassLoader classLoader) {
    byte[] meta = nativeRayObject.metadata;
    byte[] data = nativeRayObject.data;

    if (meta != null && meta.length > 0) {
      // Metadata is present, so it decides how the object is interpreted.
      if (Arrays.equals(meta, RAW_TYPE_META)) {
        // Raw binary: the data is the object itself.
        return data;
      } else if (Arrays.equals(meta, WORKER_EXCEPTION_META)) {
        return RayWorkerException.INSTANCE;
      } else if (Arrays.equals(meta, ACTOR_EXCEPTION_META)) {
        return RayActorException.INSTANCE;
      } else if (Arrays.equals(meta, UNRECONSTRUCTABLE_EXCEPTION_META)) {
        return new UnreconstructableException(objectId);
      } else if (Arrays.equals(meta, TASK_EXECUTION_EXCEPTION_META)) {
        // The data holds the serialized exception.
        return Serializer.decode(data, classLoader);
      }
      throw new IllegalArgumentException("Unrecognized metadata " + Arrays.toString(meta));
    } else {
      // No metadata: the data is a serialized Java object.
      return Serializer.decode(data, classLoader);
    }
  }

  /**
   * Serialize a Java object to a {@link NativeRayObject} instance.
   *
   * @param object The object to serialize. A {@link NativeRayObject} is returned as is.
   * @return The serialized object.
   */
  public static NativeRayObject serialize(Object object) {
    if (object instanceof NativeRayObject) {
      return (NativeRayObject) object;
    } else if (object instanceof byte[]) {
      // If the object is a byte array, skip serializing it and use a special metadata to
      // indicate it's raw binary. So that this object can also be read by Python.
      return new NativeRayObject((byte[]) object, RAW_TYPE_META);
    } else if (object instanceof RayTaskException) {
      return new NativeRayObject(Serializer.encode(object),
          TASK_EXECUTION_EXCEPTION_META);
    } else {
      return new NativeRayObject(Serializer.encode(object), null);
    }
  }
}
@@ -2,40 +2,21 @@ package org.ray.runtime.object;
import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import org.ray.api.RayObject;
import org.ray.api.WaitResult;
import org.ray.api.exception.RayActorException;
import org.ray.api.exception.RayException;
import org.ray.api.exception.RayTaskException;
import org.ray.api.exception.RayWorkerException;
import org.ray.api.exception.UnreconstructableException;
import org.ray.api.id.ObjectId;
import org.ray.runtime.context.WorkerContext;
import org.ray.runtime.generated.Gcs.ErrorType;
import org.ray.runtime.util.Serializer;
/**
* A class that is used to put/get objects to/from the object store.
*/
public abstract class ObjectStore {
private static final byte[] WORKER_EXCEPTION_META = String
.valueOf(ErrorType.WORKER_DIED.getNumber()).getBytes();
private static final byte[] ACTOR_EXCEPTION_META = String
.valueOf(ErrorType.ACTOR_DIED.getNumber()).getBytes();
private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String
.valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE.getNumber()).getBytes();
private static final byte[] TASK_EXECUTION_EXCEPTION_META = String
.valueOf(ErrorType.TASK_EXECUTION_EXCEPTION.getNumber()).getBytes();
private static final byte[] RAW_TYPE_META = "RAW".getBytes();
private final WorkerContext workerContext;
public ObjectStore(WorkerContext workerContext) {
@@ -65,7 +46,11 @@ public abstract class ObjectStore {
* @return Id of the object.
*/
public ObjectId put(Object object) {
return putRaw(serialize(object));
if (object instanceof NativeRayObject) {
throw new IllegalArgumentException(
"Trying to put a NativeRayObject. Please use putRaw instead.");
}
return putRaw(ObjectSerializer.serialize(object));
}
/**
@@ -77,7 +62,11 @@ public abstract class ObjectStore {
* @param objectId Object id.
*/
public void put(Object object, ObjectId objectId) {
putRaw(serialize(object), objectId);
if (object instanceof NativeRayObject) {
throw new IllegalArgumentException(
"Trying to put a NativeRayObject. Please use putRaw instead.");
}
putRaw(ObjectSerializer.serialize(object), objectId);
}
/**
@@ -106,7 +95,8 @@ public abstract class ObjectStore {
NativeRayObject dataAndMeta = dataAndMetaList.get(i);
Object object = null;
if (dataAndMeta != null) {
object = deserialize(dataAndMeta, ids.get(i));
object = ObjectSerializer
.deserialize(dataAndMeta, ids.get(i), workerContext.getCurrentClassLoader());
}
if (object instanceof RayException) {
// If the object is a `RayException`, it means that an error occurred during task
@@ -174,57 +164,4 @@ public abstract class ObjectStore {
*/
public abstract void delete(List<ObjectId> objectIds, boolean localOnly,
boolean deleteCreatingTasks);
/**
* Deserialize an object.
*
* @param nativeRayObject The object to deserialize.
* @param objectId The associated object ID of the object.
* @return The deserialized object.
*/
public Object deserialize(NativeRayObject nativeRayObject, ObjectId objectId) {
byte[] meta = nativeRayObject.metadata;
byte[] data = nativeRayObject.data;
// If meta is not null, deserialize the object from meta.
if (meta != null && meta.length > 0) {
// If meta is not null, deserialize the object from meta.
if (Arrays.equals(meta, RAW_TYPE_META)) {
return data;
} else if (Arrays.equals(meta, WORKER_EXCEPTION_META)) {
return RayWorkerException.INSTANCE;
} else if (Arrays.equals(meta, ACTOR_EXCEPTION_META)) {
return RayActorException.INSTANCE;
} else if (Arrays.equals(meta, UNRECONSTRUCTABLE_EXCEPTION_META)) {
return new UnreconstructableException(objectId);
} else if (Arrays.equals(meta, TASK_EXECUTION_EXCEPTION_META)) {
return Serializer.decode(data, workerContext.getCurrentClassLoader());
}
throw new IllegalArgumentException("Unrecognized metadata " + Arrays.toString(meta));
} else {
// If data is not null, deserialize the Java object.
return Serializer.decode(data, workerContext.getCurrentClassLoader());
}
}
/**
* Serialize an object.
*
* @param object The object to serialize.
* @return The serialized object.
*/
public NativeRayObject serialize(Object object) {
if (object instanceof NativeRayObject) {
return (NativeRayObject) object;
} else if (object instanceof byte[]) {
// If the object is a byte array, skip serializing it and use a special metadata to
// indicate it's raw binary. So that this object can also be read by Python.
return new NativeRayObject((byte[]) object, RAW_TYPE_META);
} else if (object instanceof RayTaskException) {
return new NativeRayObject(Serializer.encode(object),
TASK_EXECUTION_EXCEPTION_META);
} else {
return new NativeRayObject(Serializer.encode(object), null);
}
}
}
@@ -9,8 +9,7 @@ import org.ray.api.runtime.RayRuntime;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.RayMultiWorkerNativeRuntime;
import org.ray.runtime.object.NativeRayObject;
import org.ray.runtime.object.ObjectStore;
import org.ray.runtime.util.Serializer;
import org.ray.runtime.object.ObjectSerializer;
/**
* Helper methods to convert arguments from/to objects.
@@ -26,37 +25,29 @@ public class ArgumentsBuilder {
/**
* Convert real function arguments to task spec arguments.
*/
public static List<FunctionArg> wrap(Object[] args, boolean crossLanguage) {
public static List<FunctionArg> wrap(Object[] args) {
List<FunctionArg> ret = new ArrayList<>();
for (Object arg : args) {
ObjectId id = null;
byte[] data = null;
if (arg == null) {
data = Serializer.encode(null);
} else if (arg instanceof RayObject) {
NativeRayObject value = null;
if (arg instanceof RayObject) {
id = ((RayObject) arg).getId();
} else if (arg instanceof byte[] && crossLanguage) {
// If the argument is a byte array and will be used by a different language,
// do not inline this argument. Because the other language doesn't know how
// to deserialize it.
id = Ray.put(arg).getId();
} else {
byte[] serialized = Serializer.encode(arg);
if (serialized.length > LARGEST_SIZE_PASS_BY_VALUE) {
value = ObjectSerializer.serialize(arg);
if (value.data.length > LARGEST_SIZE_PASS_BY_VALUE) {
RayRuntime runtime = Ray.internal();
if (runtime instanceof RayMultiWorkerNativeRuntime) {
runtime = ((RayMultiWorkerNativeRuntime) runtime).getCurrentRuntime();
}
id = ((AbstractRayRuntime) runtime).getObjectStore()
.put(new NativeRayObject(serialized, null));
} else {
data = serialized;
.putRaw(value);
value = null;
}
}
if (id != null) {
ret.add(FunctionArg.passByReference(id));
} else {
ret.add(FunctionArg.passByValue(data));
ret.add(FunctionArg.passByValue(value));
}
}
return ret;
@@ -65,10 +56,10 @@ public class ArgumentsBuilder {
/**
* Convert list of NativeRayObject to real function arguments.
*/
public static Object[] unwrap(ObjectStore objectStore, List<NativeRayObject> args) {
public static Object[] unwrap(List<NativeRayObject> args, ClassLoader classLoader) {
Object[] realArgs = new Object[args.size()];
for (int i = 0; i < args.size(); i++) {
realArgs[i] = objectStore.deserialize(args.get(i), null);
realArgs[i] = ObjectSerializer.deserialize(args.get(i), null, classLoader);
}
return realArgs;
}
@@ -1,6 +1,8 @@
package org.ray.runtime.task;
import com.google.common.base.Preconditions;
import org.ray.api.id.ObjectId;
import org.ray.runtime.object.NativeRayObject;
/**
* Represents a function argument in task spec.
@@ -16,11 +18,12 @@ public class FunctionArg {
/**
* Serialized data of this argument (passed by value).
*/
public final byte[] data;
public final NativeRayObject value;
private FunctionArg(ObjectId id, byte[] data) {
private FunctionArg(ObjectId id, NativeRayObject value) {
Preconditions.checkState((id == null) != (value == null));
this.id = id;
this.data = data;
this.value = value;
}
/**
@@ -33,8 +36,8 @@ public class FunctionArg {
/**
* Create a FunctionArg that will be passed by value.
*/
public static FunctionArg passByValue(byte[] data) {
return new FunctionArg(null, data);
public static FunctionArg passByValue(NativeRayObject value) {
return new FunctionArg(null, value);
}
@Override
@@ -42,7 +45,7 @@ public class FunctionArg {
if (id != null) {
return "<id>: " + id.toString();
} else {
return "<data>: " + data.length;
return value.toString();
}
}
}
@@ -154,7 +154,9 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
.collect(Collectors.toList()))
.addAllArgs(args.stream().map(arg -> arg.id != null ? TaskArg.newBuilder()
.addObjectIds(ByteString.copyFrom(arg.id.getBytes())).build()
: TaskArg.newBuilder().setData(ByteString.copyFrom(arg.data)).build())
: TaskArg.newBuilder().setData(ByteString.copyFrom(arg.value.data))
.setMetadata(arg.value.metadata != null ? ByteString
.copyFrom(arg.value.metadata) : ByteString.EMPTY).build())
.collect(Collectors.toList()));
}
@@ -233,7 +235,7 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
List<NativeRayObject> args = getFunctionArgs(taskSpec).stream()
.map(arg -> arg.id != null ?
objectStore.getRaw(Collections.singletonList(arg.id), -1).get(0)
: new NativeRayObject(arg.data, null))
: arg.value)
.collect(Collectors.toList());
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(taskSpec);
List<NativeRayObject> returnObjects = taskExecutor
@@ -246,7 +248,7 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
// If the task is an actor task or an actor creation task,
// put the dummy object in object store, so those tasks which depends on it
// can be executed.
putObject = new NativeRayObject(new byte[]{}, new byte[]{});
putObject = new NativeRayObject(new byte[]{1}, null);
} else {
putObject = returnObjects.get(i);
}
@@ -279,7 +281,8 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
functionArgs.add(FunctionArg
.passByReference(new ObjectId(arg.getObjectIds(0).toByteArray())));
} else {
functionArgs.add(FunctionArg.passByValue(arg.getData().toByteArray()));
functionArgs.add(FunctionArg.passByValue(
new NativeRayObject(arg.getData().toByteArray(), arg.getMetadata().toByteArray())));
}
}
return functionArgs;
@@ -17,6 +17,7 @@ import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
import org.ray.runtime.functionmanager.RayFunction;
import org.ray.runtime.generated.Common.TaskType;
import org.ray.runtime.object.NativeRayObject;
import org.ray.runtime.object.ObjectSerializer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -87,7 +88,7 @@ public final class TaskExecutor {
actor = currentActor;
}
Object[] args = ArgumentsBuilder.unwrap(runtime.getObjectStore(), argsBytes);
Object[] args = ArgumentsBuilder.unwrap(argsBytes, rayFunction.classLoader);
// Execute the task.
Object result;
if (!rayFunction.isConstructor()) {
@@ -102,7 +103,7 @@ public final class TaskExecutor {
maybeSaveCheckpoint(actor, runtime.getWorkerContext().getCurrentActorId());
}
if (rayFunction.hasReturn()) {
returnObjects.add(runtime.getObjectStore().serialize(result));
returnObjects.add(ObjectSerializer.serialize(result));
}
} else {
// TODO (kfstorm): handle checkpoint in core worker.
@@ -113,8 +114,8 @@ public final class TaskExecutor {
} catch (Exception e) {
LOGGER.error("Error executing task " + taskId, e);
if (taskType != TaskType.ACTOR_CREATION_TASK) {
if(rayFunction.hasReturn()) {
returnObjects.add(runtime.getObjectStore()
if (rayFunction.hasReturn()) {
returnObjects.add(ObjectSerializer
.serialize(new RayTaskException("Error executing task " + taskId, e)));
}
} else {
@@ -3,9 +3,9 @@ package org.ray.api.test;
import org.ray.api.Ray;
import org.ray.api.RayPyActor;
import org.ray.api.TestUtils;
import org.ray.api.id.ObjectId;
import org.ray.runtime.context.WorkerContext;
import org.ray.runtime.object.NativeRayObject;
import org.ray.runtime.object.ObjectStore;
import org.ray.runtime.object.ObjectSerializer;
import org.testng.Assert;
import org.testng.annotations.Test;
@@ -14,10 +14,10 @@ public class RaySerializerTest extends BaseMultiLanguageTest {
@Test
public void testSerializePyActor() {
RayPyActor pyActor = Ray.createPyActor("test", "RaySerializerTest");
ObjectStore objectStore = TestUtils.getRuntime().getObjectStore();
NativeRayObject nativeRayObject = objectStore.serialize(pyActor);
RayPyActor result = (RayPyActor) objectStore
.deserialize(nativeRayObject, ObjectId.fromRandom());
WorkerContext workerContext = TestUtils.getRuntime().getWorkerContext();
NativeRayObject nativeRayObject = ObjectSerializer.serialize(pyActor);
RayPyActor result = (RayPyActor) ObjectSerializer
.deserialize(nativeRayObject, null, workerContext.getCurrentClassLoader());
Assert.assertEquals(result.getId(), pyActor.getId());
Assert.assertEquals(result.getModuleName(), "test");
Assert.assertEquals(result.getClassName(), "RaySerializerTest");
+5 -3
View File
@@ -54,8 +54,10 @@ cdef extern from "ray/common/task/task_spec.h" namespace "ray" nogil:
int ArgIdCount(uint64_t arg_index) const
CObjectID ArgId(uint64_t arg_index, uint64_t id_index) const
CObjectID ReturnId(uint64_t return_index) const
const uint8_t *ArgVal(uint64_t arg_index) const
size_t ArgValLength(uint64_t arg_index) const
const uint8_t *ArgData(uint64_t arg_index) const
size_t ArgDataSize(uint64_t arg_index) const
const uint8_t *ArgMetadata(uint64_t arg_index) const
size_t ArgMetadataSize(uint64_t arg_index) const
double GetRequiredResource(const c_string &resource_name) const
const ResourceSet GetRequiredResources() const
const ResourceSet GetRequiredPlacementResources() const
@@ -86,7 +88,7 @@ cdef extern from "ray/common/task/task_util.h" namespace "ray" nogil:
TaskSpecBuilder &AddByRefArg(const CObjectID &arg_id)
TaskSpecBuilder &AddByValueArg(const c_string &data)
TaskSpecBuilder &AddByValueArg(const c_string &data, const c_string &metadata)
TaskSpecBuilder &SetActorCreationTaskSpec(
const CActorID &actor_id, uint64_t max_reconstructions,
+10 -4
View File
@@ -12,6 +12,7 @@ from ray.includes.task cimport (
TaskSpecBuilder,
TaskTableData,
)
from ray.ray_constants import RAW_BUFFER_METADATA
from ray.utils import decode
@@ -68,10 +69,12 @@ cdef class TaskSpec:
for arg in arguments:
if isinstance(arg, ObjectID):
builder.AddByRefArg((<ObjectID>arg).native())
elif isinstance(arg, bytes):
builder.AddByValueArg(arg, RAW_BUFFER_METADATA)
else:
pickled_str = pickle.dumps(
arg, protocol=pickle.HIGHEST_PROTOCOL)
builder.AddByValueArg(pickled_str)
builder.AddByValueArg(pickled_str, b'')
if not actor_creation_id.is_nil():
# Actor creation task.
@@ -180,9 +183,12 @@ cdef class TaskSpec:
arg_list.append(
ObjectID(task_spec.ArgId(i, 0).Binary()))
else:
serialized_str = (
task_spec.ArgVal(i)[:task_spec.ArgValLength(i)])
obj = pickle.loads(serialized_str)
data = (task_spec.ArgData(i)[:task_spec.ArgDataSize(i)])
metadata = (task_spec.ArgMetadata(i)[:task_spec.ArgMetadataSize(i)])
if metadata == RAW_BUFFER_METADATA:
obj = data
else:
obj = pickle.loads(data)
arg_list.append(obj)
elif lang == <int32_t>LANGUAGE_JAVA:
arg_list = num_args * ["<java-argument>"]
+5
View File
@@ -39,6 +39,11 @@ class LocalMemoryBuffer : public Buffer {
public:
/// Constructor.
///
/// By default, when initializing a LocalMemoryBuffer with a data pointer and a length,
/// it just assigns the pointer and length without copying the data content. This is
/// for performance reasons. In this case the buffer cannot ensure data validity; it
/// instead relies on the lifetime of the passed-in data pointer.
///
/// \param data The data pointer to the passed-in buffer.
/// \param size The size of the passed in buffer.
/// \param copy_data If true, data will be copied and owned by this buffer,
+77
View File
@@ -0,0 +1,77 @@
#ifndef RAY_COMMON_RAY_OBJECT_H
#define RAY_COMMON_RAY_OBJECT_H
#include "ray/common/buffer.h"
#include "ray/util/logging.h"
namespace ray {
/// Binary representation of a ray object, consisting of buffer pointers to data and
/// metadata. A ray object may have both data and metadata, or only one of them.
class RayObject {
 public:
  /// Create a ray object instance.
  ///
  /// Setting `copy_data` to `false` is fine for most cases - for example when putting
  /// an object into the store with a temporary RayObject, where we don't want to do an
  /// extra copy. But in some cases we do want to always hold valid data - for example,
  /// the memory store uses RayObject to represent objects, and in that case we want the
  /// object data to remain valid after the user puts it into the store.
  ///
  /// At least one of `data` and `metadata` must be non-null, and a non-null buffer must
  /// not be zero-length. Both conditions are enforced with RAY_CHECK.
  ///
  /// \param[in] data Data of the ray object.
  /// \param[in] metadata Metadata of the ray object.
  /// \param[in] copy_data Whether this class should hold a copy of data.
  RayObject(const std::shared_ptr<Buffer> &data, const std::shared_ptr<Buffer> &metadata,
            bool copy_data = false)
      : data_(data), metadata_(metadata), has_data_copy_(copy_data) {
    // Validate everything up front, before any copy is attempted.
    RAY_CHECK(data_ || metadata_) << "Data and metadata cannot both be empty.";
    RAY_CHECK(!data_ || data_->Size())
        << "Zero-length buffers are not allowed when constructing a RayObject.";
    RAY_CHECK(!metadata_ || metadata_->Size())
        << "Zero-length buffers are not allowed when constructing a RayObject.";

    if (has_data_copy_) {
      // If this object is required to hold a copy of the data,
      // make a copy if the passed in buffers don't already have a copy.
      if (data_ && !data_->OwnsData()) {
        data_ = std::make_shared<LocalMemoryBuffer>(data_->Data(), data_->Size(),
                                                    /*copy_data=*/true);
      }
      if (metadata_ && !metadata_->OwnsData()) {
        metadata_ = std::make_shared<LocalMemoryBuffer>(
            metadata_->Data(), metadata_->Size(), /*copy_data=*/true);
      }
    }
  }

  /// Return the data of the ray object. May be nullptr if the object has no data.
  const std::shared_ptr<Buffer> &GetData() const { return data_; }

  /// Return the metadata of the ray object. May be nullptr if the object has no
  /// metadata.
  const std::shared_ptr<Buffer> &GetMetadata() const { return metadata_; }

  /// Return the sum of the sizes reported by the data and metadata buffers. A missing
  /// buffer contributes 0.
  uint64_t GetSize() const {
    uint64_t size = 0;
    size += (data_ != nullptr) ? data_->Size() : 0;
    size += (metadata_ != nullptr) ? metadata_->Size() : 0;
    return size;
  }

  /// Whether this object has data.
  bool HasData() const { return data_ != nullptr; }

  /// Whether this object has metadata.
  bool HasMetadata() const { return metadata_ != nullptr; }

 private:
  std::shared_ptr<Buffer> data_;
  std::shared_ptr<Buffer> metadata_;
  /// Whether this class holds a data copy.
  bool has_data_copy_;
};
} // namespace ray
#endif // RAY_COMMON_RAY_OBJECT_H
+10 -2
View File
@@ -53,14 +53,22 @@ ObjectID TaskSpecification::ArgId(size_t arg_index, size_t id_index) const {
return ObjectID::FromBinary(message_->args(arg_index).object_ids(id_index));
}
const uint8_t *TaskSpecification::ArgVal(size_t arg_index) const {
const uint8_t *TaskSpecification::ArgData(size_t arg_index) const {
return reinterpret_cast<const uint8_t *>(message_->args(arg_index).data().data());
}
size_t TaskSpecification::ArgValLength(size_t arg_index) const {
size_t TaskSpecification::ArgDataSize(size_t arg_index) const {
return message_->args(arg_index).data().size();
}
/// Return a pointer to the metadata bytes of the argument at `arg_index`, as stored in
/// the wrapped message.
const uint8_t *TaskSpecification::ArgMetadata(size_t arg_index) const {
  const auto &arg_metadata = message_->args(arg_index).metadata();
  return reinterpret_cast<const uint8_t *>(arg_metadata.data());
}
/// Return the size of the metadata of the argument at `arg_index`, as stored in the
/// wrapped message.
size_t TaskSpecification::ArgMetadataSize(size_t arg_index) const {
  const auto &arg_metadata = message_->args(arg_index).metadata();
  return arg_metadata.size();
}
/// Return the resources that are to be acquired during the execution of this task.
/// The `required_resources_` member is returned by value.
const ResourceSet TaskSpecification::GetRequiredResources() const {
return required_resources_;
}
+6 -2
View File
@@ -70,9 +70,13 @@ class TaskSpecification : public MessageWrapper<rpc::TaskSpec> {
ObjectID ReturnId(size_t return_index) const;
const uint8_t *ArgVal(size_t arg_index) const;
const uint8_t *ArgData(size_t arg_index) const;
size_t ArgValLength(size_t arg_index) const;
size_t ArgDataSize(size_t arg_index) const;
const uint8_t *ArgMetadata(size_t arg_index) const;
size_t ArgMetadataSize(size_t arg_index) const;
/// Return the resources that are to be acquired during the execution of this
/// task.
+18 -6
View File
@@ -1,6 +1,8 @@
#ifndef RAY_COMMON_TASK_TASK_UTIL_H
#define RAY_COMMON_TASK_TASK_UTIL_H
#include "ray/common/buffer.h"
#include "ray/common/ray_object.h"
#include "ray/common/task/task_spec.h"
#include "ray/protobuf/common.pb.h"
@@ -56,19 +58,29 @@ class TaskSpecBuilder {
/// Add a by-value argument to the task.
///
/// \param data String object that contains the data.
/// \param metadata String object that contains the metadata.
/// \return Reference to the builder object itself.
TaskSpecBuilder &AddByValueArg(const std::string &data) {
message_->add_args()->set_data(data);
TaskSpecBuilder &AddByValueArg(const std::string &data, const std::string &metadata) {
auto arg = message_->add_args();
arg->set_data(data);
arg->set_metadata(metadata);
return *this;
}
/// Add a by-value argument to the task.
///
/// \param data Pointer to the data.
/// \param size Size of the data.
/// \param value the RayObject instance that contains the data and the metadata.
/// \return Reference to the builder object itself.
TaskSpecBuilder &AddByValueArg(const void *data, size_t size) {
message_->add_args()->set_data(data, size);
TaskSpecBuilder &AddByValueArg(const RayObject &value) {
auto arg = message_->add_args();
if (value.HasData()) {
const auto &data = value.GetData();
arg->set_data(data->Data(), data->Size());
}
if (value.HasMetadata()) {
const auto &metadata = value.GetMetadata();
arg->set_metadata(metadata->Data(), metadata->Size());
}
return *this;
}
+14 -13
View File
@@ -3,8 +3,8 @@
#include <string>
#include "ray/common/buffer.h"
#include "ray/common/id.h"
#include "ray/common/ray_object.h"
#include "ray/common/task/task_spec.h"
#include "ray/raylet/raylet_client.h"
#include "ray/util/util.h"
@@ -31,12 +31,13 @@ class TaskArg {
return TaskArg(std::make_shared<ObjectID>(object_id), nullptr);
}
/// Create a pass-by-reference task argument.
/// Create a pass-by-value task argument.
///
/// \param[in] object_id Id of the argument.
/// \param[in] value Value of the argument.
/// \return The task argument.
static TaskArg PassByValue(const std::shared_ptr<Buffer> &data) {
return TaskArg(nullptr, data);
static TaskArg PassByValue(const std::shared_ptr<RayObject> &value) {
RAY_CHECK(value) << "Value can't be null.";
return TaskArg(nullptr, value);
}
/// Return true if this argument is passed by reference, false if passed by value.
@@ -49,19 +50,19 @@ class TaskArg {
}
/// Get the value.
std::shared_ptr<Buffer> GetValue() const {
RAY_CHECK(data_ != nullptr) << "This argument isn't passed by value.";
return data_;
const RayObject &GetValue() const {
RAY_CHECK(value_ != nullptr) << "This argument isn't passed by value.";
return *value_;
}
private:
TaskArg(const std::shared_ptr<ObjectID> id, const std::shared_ptr<Buffer> data)
: id_(id), data_(data) {}
TaskArg(const std::shared_ptr<ObjectID> id, const std::shared_ptr<RayObject> value)
: id_(id), value_(value) {}
/// Id of the argument, if passed by reference, otherwise nullptr.
/// Id of the argument if passed by reference, otherwise nullptr.
const std::shared_ptr<ObjectID> id_;
/// Data of the argument, if passed by value, otherwise nullptr.
const std::shared_ptr<Buffer> data_;
/// Value of the argument if passed by value, otherwise nullptr.
const std::shared_ptr<RayObject> value_;
};
enum class StoreProviderType { PLASMA, MEMORY };
+3 -2
View File
@@ -43,7 +43,7 @@ jmethodID java_language_get_number;
jclass java_function_arg_class;
jfieldID java_function_arg_id;
jfieldID java_function_arg_data;
jfieldID java_function_arg_value;
jclass java_base_task_options_class;
jfieldID java_base_task_options_resources;
@@ -137,7 +137,8 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) {
java_function_arg_class = LoadClass(env, "org/ray/runtime/task/FunctionArg");
java_function_arg_id =
env->GetFieldID(java_function_arg_class, "id", "Lorg/ray/api/id/ObjectId;");
java_function_arg_data = env->GetFieldID(java_function_arg_class, "data", "[B");
java_function_arg_value = env->GetFieldID(java_function_arg_class, "value",
"Lorg/ray/runtime/object/NativeRayObject;");
java_base_task_options_class = LoadClass(env, "org/ray/api/options/BaseTaskOptions");
java_base_task_options_resources =
+9 -8
View File
@@ -4,6 +4,7 @@
#include <jni.h>
#include "ray/common/buffer.h"
#include "ray/common/id.h"
#include "ray/common/ray_object.h"
#include "ray/common/status.h"
#include "ray/core_worker/store_provider/store_provider.h"
@@ -77,12 +78,12 @@ extern jclass java_language_class;
/// getNumber of Language class
extern jmethodID java_language_get_number;
/// NativeTaskArg class
/// FunctionArg class
extern jclass java_function_arg_class;
/// id field of NativeTaskArg class
/// id field of FunctionArg class
extern jfieldID java_function_arg_id;
/// data field of NativeTaskArg class
extern jfieldID java_function_arg_data;
/// value field of FunctionArg class
extern jfieldID java_function_arg_value;
/// BaseTaskOptions class
extern jclass java_base_task_options_class;
@@ -279,11 +280,11 @@ inline std::shared_ptr<ray::RayObject> JavaNativeRayObjectToNativeRayObject(
std::shared_ptr<ray::Buffer> data_buffer = JavaByteArrayToNativeBuffer(env, java_data);
std::shared_ptr<ray::Buffer> metadata_buffer =
JavaByteArrayToNativeBuffer(env, java_metadata);
if (!data_buffer) {
data_buffer = std::make_shared<ray::LocalMemoryBuffer>(nullptr, 0);
if (data_buffer && data_buffer->Size() == 0) {
data_buffer = nullptr;
}
if (!metadata_buffer) {
metadata_buffer = std::make_shared<ray::LocalMemoryBuffer>(nullptr, 0);
if (metadata_buffer && metadata_buffer->Size() == 0) {
metadata_buffer = nullptr;
}
return std::make_shared<ray::RayObject>(data_buffer, metadata_buffer);
}
@@ -34,10 +34,11 @@ inline std::vector<ray::TaskArg> ToTaskArgs(JNIEnv *env, jobject args) {
return ray::TaskArg::PassByReference(
JavaByteArrayToId<ray::ObjectID>(env, java_id_bytes));
}
auto java_data =
static_cast<jbyteArray>(env->GetObjectField(arg, java_function_arg_data));
RAY_CHECK(java_data) << "Both id and data of FunctionArg are null.";
return ray::TaskArg::PassByValue(JavaByteArrayToNativeBuffer(env, java_data));
auto java_value =
static_cast<jbyteArray>(env->GetObjectField(arg, java_function_arg_value));
RAY_CHECK(java_value) << "Both id and value of FunctionArg are null.";
auto value = JavaNativeRayObjectToNativeRayObject(env, java_value);
return ray::TaskArg::PassByValue(value);
});
return task_args;
}
@@ -69,9 +69,15 @@ Status CoreWorkerPlasmaStoreProvider::FetchAndGetFromPlasmaStore(
for (size_t i = 0; i < plasma_results.size(); i++) {
if (plasma_results[i].data != nullptr || plasma_results[i].metadata != nullptr) {
const auto &object_id = batch_ids[i];
const auto result_object = std::make_shared<RayObject>(
std::make_shared<PlasmaBuffer>(plasma_results[i].data),
std::make_shared<PlasmaBuffer>(plasma_results[i].metadata));
std::shared_ptr<PlasmaBuffer> data = nullptr;
std::shared_ptr<PlasmaBuffer> metadata = nullptr;
if (plasma_results[i].data && plasma_results[i].data->size()) {
data = std::make_shared<PlasmaBuffer>(plasma_results[i].data);
}
if (plasma_results[i].metadata && plasma_results[i].metadata->size()) {
metadata = std::make_shared<PlasmaBuffer>(plasma_results[i].metadata);
}
const auto result_object = std::make_shared<RayObject>(data, metadata);
(*results)[object_id] = result_object;
remaining.erase(object_id);
if (IsException(*result_object)) {
@@ -174,6 +180,9 @@ Status CoreWorkerPlasmaStoreProvider::Delete(const std::vector<ObjectID> &object
bool CoreWorkerPlasmaStoreProvider::IsException(const RayObject &object) {
// TODO (kfstorm): metadata should be structured.
if (!object.HasMetadata()) {
return false;
}
const std::string metadata(reinterpret_cast<const char *>(object.GetMetadata()->Data()),
object.GetMetadata()->Size());
const auto error_type_descriptor = ray::rpc::ErrorType_descriptor();
@@ -8,56 +8,6 @@
namespace ray {
/// Binary representation of a ray object.
class RayObject {
 public:
  /// Create a ray object instance.
  ///
  /// \param[in] data Data of the ray object. May be null.
  /// \param[in] metadata Metadata of the ray object. May be null.
  /// \param[in] copy_data Whether this class should hold a copy of data. When true,
  /// every non-null buffer whose `OwnsData()` is false is replaced by a new
  /// LocalMemoryBuffer built from the same pointer and size with `true` as the
  /// third constructor argument (presumably the copy flag -- confirm against
  /// LocalMemoryBuffer).
  RayObject(const std::shared_ptr<Buffer> &data, const std::shared_ptr<Buffer> &metadata,
            bool copy_data = false)
      : data_(data), metadata_(metadata), has_data_copy_(copy_data) {
    if (has_data_copy_) {
      // If this object is required to hold a copy of the data,
      // make a copy if the passed in buffers don't already have a copy.
      if (data_ && !data_->OwnsData()) {
        data_ = std::make_shared<LocalMemoryBuffer>(data_->Data(), data_->Size(), true);
      }
      if (metadata_ && !metadata_->OwnsData()) {
        metadata_ = std::make_shared<LocalMemoryBuffer>(metadata_->Data(),
                                                        metadata_->Size(), true);
      }
    }
  }

  /// Return the data of the ray object. The returned pointer may be null.
  const std::shared_ptr<Buffer> &GetData() const { return data_; }

  /// Return the metadata of the ray object. The returned pointer may be null.
  const std::shared_ptr<Buffer> &GetMetadata() const { return metadata_; }

  /// Return the combined size of data and metadata. A null buffer counts as 0.
  uint64_t GetSize() const {
    uint64_t size = 0;
    size += (data_ != nullptr) ? data_->Size() : 0;
    size += (metadata_ != nullptr) ? metadata_->Size() : 0;
    return size;
  }

  /// Whether this object has data, i.e. a non-null data buffer of non-zero size.
  /// Defined the same way as HasMetadata() so callers can test both uniformly.
  bool HasData() const { return data_ != nullptr && data_->Size() > 0; }

  /// Whether this object has metadata, i.e. a non-null metadata buffer of
  /// non-zero size.
  bool HasMetadata() const { return metadata_ != nullptr && metadata_->Size() > 0; }

 private:
  /// Data of the ray object.
  std::shared_ptr<Buffer> data_;
  /// Metadata of the ray object.
  std::shared_ptr<Buffer> metadata_;
  /// Whether this class holds a data copy.
  bool has_data_copy_;
};
/// Provider interface for store access. Store provider should inherit from this class and
/// provide implementations for the methods. The actual store provider may use a plasma
/// store or local memory store in worker process, or possibly other types of storage.
+11 -4
View File
@@ -89,10 +89,17 @@ Status CoreWorkerTaskExecutionInterface::BuildArgsForExecutor(
indices.push_back(i);
} else {
// pass by value.
(*args)[i] = std::make_shared<RayObject>(
std::make_shared<LocalMemoryBuffer>(const_cast<uint8_t *>(task.ArgVal(i)),
task.ArgValLength(i)),
nullptr);
std::shared_ptr<LocalMemoryBuffer> data = nullptr;
if (task.ArgDataSize(i)) {
data = std::make_shared<LocalMemoryBuffer>(const_cast<uint8_t *>(task.ArgData(i)),
task.ArgDataSize(i));
}
std::shared_ptr<LocalMemoryBuffer> metadata = nullptr;
if (task.ArgMetadataSize(i)) {
metadata = std::make_shared<LocalMemoryBuffer>(
const_cast<uint8_t *>(task.ArgMetadata(i)), task.ArgMetadataSize(i));
}
(*args)[i] = std::make_shared<RayObject>(data, metadata);
}
}
+1 -1
View File
@@ -130,7 +130,7 @@ void CoreWorkerTaskInterface::BuildCommonTaskSpec(
if (arg.IsPassedByReference()) {
builder.AddByRefArg(arg.GetReference());
} else {
builder.AddByValueArg(arg.GetValue()->Data(), arg.GetValue()->Size());
builder.AddByValueArg(arg.GetValue());
}
}
+21 -15
View File
@@ -3,6 +3,7 @@
#include "gtest/gtest.h"
#include "ray/common/buffer.h"
#include "ray/common/ray_object.h"
#include "ray/core_worker/context.h"
#include "ray/core_worker/core_worker.h"
#include "ray/core_worker/transport/direct_actor_transport.h"
@@ -59,7 +60,7 @@ std::unique_ptr<ActorHandle> CreateActorHelper(
RayFunction func{ray::Language::PYTHON, {"actor creation task"}};
std::vector<TaskArg> args;
args.emplace_back(TaskArg::PassByValue(buffer));
args.emplace_back(TaskArg::PassByValue(std::make_shared<RayObject>(buffer, nullptr)));
ActorCreationOptions actor_options{max_reconstructions, is_direct_call, resources, {}};
@@ -232,7 +233,8 @@ void CoreWorkerTest::TestNormalTask(
RAY_CHECK_OK(driver.Objects().Put(RayObject(buffer2, nullptr), &object_id));
std::vector<TaskArg> args;
args.emplace_back(TaskArg::PassByValue(buffer1));
args.emplace_back(
TaskArg::PassByValue(std::make_shared<RayObject>(buffer1, nullptr)));
args.emplace_back(TaskArg::PassByReference(object_id));
RayFunction func{ray::Language::PYTHON, {}};
@@ -273,8 +275,10 @@ void CoreWorkerTest::TestActorTask(
// Create arguments with PassByValue.
std::vector<TaskArg> args;
args.emplace_back(TaskArg::PassByValue(buffer1));
args.emplace_back(TaskArg::PassByValue(buffer2));
args.emplace_back(
TaskArg::PassByValue(std::make_shared<RayObject>(buffer1, nullptr)));
args.emplace_back(
TaskArg::PassByValue(std::make_shared<RayObject>(buffer2, nullptr)));
TaskOptions options{1, resources};
std::vector<ObjectID> return_ids;
@@ -315,7 +319,8 @@ void CoreWorkerTest::TestActorTask(
// Create arguments with PassByRef and PassByValue.
std::vector<TaskArg> args;
args.emplace_back(TaskArg::PassByReference(object_id));
args.emplace_back(TaskArg::PassByValue(buffer2));
args.emplace_back(
TaskArg::PassByValue(std::make_shared<RayObject>(buffer2, nullptr)));
TaskOptions options{1, resources};
std::vector<ObjectID> return_ids;
@@ -380,7 +385,8 @@ void CoreWorkerTest::TestActorReconstruction(
// Create arguments with PassByValue.
std::vector<TaskArg> args;
args.emplace_back(TaskArg::PassByValue(buffer1));
args.emplace_back(
TaskArg::PassByValue(std::make_shared<RayObject>(buffer1, nullptr)));
TaskOptions options{1, resources};
std::vector<ObjectID> return_ids;
@@ -425,7 +431,8 @@ void CoreWorkerTest::TestActorFailure(
// Create arguments with PassByRef and PassByValue.
std::vector<TaskArg> args;
args.emplace_back(TaskArg::PassByValue(buffer1));
args.emplace_back(
TaskArg::PassByValue(std::make_shared<RayObject>(buffer1, nullptr)));
TaskOptions options{1, resources};
std::vector<ObjectID> return_ids;
@@ -618,11 +625,10 @@ TEST_F(ZeroNodeTest, TestTaskArg) {
ASSERT_TRUE(by_ref.IsPassedByReference());
ASSERT_EQ(by_ref.GetReference(), id);
// Test by-value argument.
std::shared_ptr<LocalMemoryBuffer> buffer =
std::make_shared<LocalMemoryBuffer>(static_cast<uint8_t *>(0), 0);
TaskArg by_value = TaskArg::PassByValue(buffer);
auto buffer = GenerateRandomBuffer();
TaskArg by_value = TaskArg::PassByValue(std::make_shared<RayObject>(buffer, nullptr));
ASSERT_FALSE(by_value.IsPassedByReference());
auto data = by_value.GetValue();
auto data = by_value.GetValue().GetData();
ASSERT_TRUE(data != nullptr);
ASSERT_EQ(*data, *buffer);
}
@@ -635,7 +641,7 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) {
auto buffer = std::make_shared<LocalMemoryBuffer>(array, sizeof(array));
RayFunction function{ray::Language::PYTHON, {}};
std::vector<TaskArg> args;
args.emplace_back(TaskArg::PassByValue(buffer));
args.emplace_back(TaskArg::PassByValue(std::make_shared<RayObject>(buffer, nullptr)));
std::unordered_map<std::string, double> resources;
ActorCreationOptions actor_options{0, /*is_direct_call*/ true, resources, {}};
@@ -664,7 +670,7 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) {
if (arg.IsPassedByReference()) {
builder.AddByRefArg(arg.GetReference());
} else {
builder.AddByValueArg(arg.GetValue()->Data(), arg.GetValue()->Size());
builder.AddByValueArg(arg.GetValue());
}
}
@@ -696,7 +702,7 @@ TEST_F(SingleNodeTest, TestDirectActorTaskSubmissionPerf) {
auto buffer = std::make_shared<LocalMemoryBuffer>(array, sizeof(array));
RayFunction func{ray::Language::PYTHON, {}};
std::vector<TaskArg> args;
args.emplace_back(TaskArg::PassByValue(buffer));
args.emplace_back(TaskArg::PassByValue(std::make_shared<RayObject>(buffer, nullptr)));
std::unordered_map<std::string, double> resources;
ActorCreationOptions actor_options{0, /*is_direct_call*/ true, resources, {}};
@@ -712,7 +718,7 @@ TEST_F(SingleNodeTest, TestDirectActorTaskSubmissionPerf) {
for (int i = 0; i < num_tasks; i++) {
// Create arguments with PassByValue.
std::vector<TaskArg> args;
args.emplace_back(TaskArg::PassByValue(buffer));
args.emplace_back(TaskArg::PassByValue(std::make_shared<RayObject>(buffer, nullptr)));
TaskOptions options{1, resources};
std::vector<ObjectID> return_ids;
@@ -224,10 +224,10 @@ void CoreWorkerDirectActorTaskReceiver::HandlePushTask(
/*transport_type=*/static_cast<int>(TaskTransportType::DIRECT_ACTOR));
return_object->set_object_id(id.Binary());
const auto &result = results[i];
if (result->GetData() != nullptr) {
if (result->HasData()) {
return_object->set_data(result->GetData()->Data(), result->GetData()->Size());
}
if (result->GetMetadata() != nullptr) {
if (result->HasMetadata()) {
return_object->set_metadata(result->GetMetadata()->Data(),
result->GetMetadata()->Size());
}
+2
View File
@@ -74,6 +74,8 @@ message TaskArg {
repeated bytes object_ids = 1;
// Data for pass-by-value arguments.
bytes data = 2;
// Metadata for pass-by-value arguments.
bytes metadata = 3;
}
// Task spec of an actor creation task.