Support metadata for task arguments passed by value (#5527)

This commit is contained in:
Kai Yang
2019-09-08 11:07:48 +08:00
committed by Hao Chen
parent cb7102f31e
commit d8f5804690
27 changed files with 364 additions and 244 deletions
@@ -176,8 +176,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
private RayObject callNormalFunction(FunctionDescriptor functionDescriptor,
Object[] args, int numReturns, CallOptions options) {
List<FunctionArg> functionArgs = ArgumentsBuilder
.wrap(args, functionDescriptor.getLanguage() != Language.JAVA);
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args);
List<ObjectId> returnIds = taskSubmitter.submitTask(functionDescriptor,
functionArgs, numReturns, options);
Preconditions.checkState(returnIds.size() == numReturns && returnIds.size() <= 1);
@@ -190,8 +189,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
private RayObject callActorFunction(RayActor rayActor,
FunctionDescriptor functionDescriptor, Object[] args, int numReturns) {
List<FunctionArg> functionArgs = ArgumentsBuilder
.wrap(args, functionDescriptor.getLanguage() != Language.JAVA);
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args);
List<ObjectId> returnIds = taskSubmitter.submitActorTask(rayActor,
functionDescriptor, functionArgs, numReturns, null);
Preconditions.checkState(returnIds.size() == numReturns && returnIds.size() <= 1);
@@ -204,14 +202,11 @@ public abstract class AbstractRayRuntime implements RayRuntime {
private RayActor createActorImpl(FunctionDescriptor functionDescriptor,
Object[] args, ActorCreationOptions options) {
List<FunctionArg> functionArgs = ArgumentsBuilder
.wrap(args, functionDescriptor.getLanguage() != Language.JAVA);
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args);
if (functionDescriptor.getLanguage() != Language.JAVA && options != null) {
Preconditions.checkState(Strings.isNullOrEmpty(options.jvmOptions));
}
RayActor actor = taskSubmitter
.createActor(functionDescriptor, functionArgs,
options);
RayActor actor = taskSubmitter.createActor(functionDescriptor, functionArgs, options);
return actor;
}
@@ -1,7 +1,9 @@
package org.ray.runtime.object;
import com.google.common.base.Preconditions;
/**
* Binary representation of ray object.
* Binary representation of a ray object. See `RayObject` class in C++ for details.
*/
public class NativeRayObject {
@@ -9,8 +11,21 @@ public class NativeRayObject {
public byte[] metadata;
/**
 * Creates a binary ray object from a data buffer and a metadata buffer.
 *
 * <p>At least one of the two buffers must be non-empty; a null buffer counts as empty.
 *
 * @param data The data buffer, may be null.
 * @param metadata The metadata buffer, may be null.
 * @throws IllegalStateException if both buffers are null or empty.
 */
public NativeRayObject(byte[] data, byte[] metadata) {
  // Attach a message so a violated invariant is diagnosable from the stack trace alone.
  Preconditions.checkState(bufferLength(data) > 0 || bufferLength(metadata) > 0,
      "A NativeRayObject must have non-empty data or non-empty metadata.");
  this.data = data;
  this.metadata = metadata;
}
/**
 * Returns the length of the given buffer, treating a null buffer as empty.
 */
private static int bufferLength(byte[] buffer) {
  return buffer == null ? 0 : buffer.length;
}
/**
 * Describes this object by the lengths of its data and metadata buffers.
 */
@Override
public String toString() {
  StringBuilder description = new StringBuilder("<data>: ");
  description.append(bufferLength(data));
  description.append(", <metadata>: ");
  description.append(bufferLength(metadata));
  return description.toString();
}
}
@@ -0,0 +1,83 @@
package org.ray.runtime.object;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import org.ray.api.exception.RayActorException;
import org.ray.api.exception.RayTaskException;
import org.ray.api.exception.RayWorkerException;
import org.ray.api.exception.UnreconstructableException;
import org.ray.api.id.ObjectId;
import org.ray.runtime.generated.Gcs.ErrorType;
import org.ray.runtime.util.Serializer;
/**
 * Serialize to and deserialize from {@link NativeRayObject}. Metadata is generated during
 * serialization and respected during deserialization.
 */
public class ObjectSerializer {

  // Metadata tags are compared byte-by-byte and are meant to be readable from other
  // languages, so they are encoded with a fixed charset rather than the platform default.
  private static final byte[] WORKER_EXCEPTION_META = errorTypeMeta(ErrorType.WORKER_DIED);
  private static final byte[] ACTOR_EXCEPTION_META = errorTypeMeta(ErrorType.ACTOR_DIED);
  private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META =
      errorTypeMeta(ErrorType.OBJECT_UNRECONSTRUCTABLE);
  private static final byte[] TASK_EXECUTION_EXCEPTION_META =
      errorTypeMeta(ErrorType.TASK_EXECUTION_EXCEPTION);
  private static final byte[] RAW_TYPE_META = "RAW".getBytes(StandardCharsets.UTF_8);

  /**
   * Builds the metadata tag of an error type: the decimal string of its number.
   */
  private static byte[] errorTypeMeta(ErrorType errorType) {
    return String.valueOf(errorType.getNumber()).getBytes(StandardCharsets.UTF_8);
  }

  /**
   * Deserialize an object from a {@link NativeRayObject} instance.
   *
   * @param nativeRayObject The object to deserialize.
   * @param objectId The associated object ID of the object. It is only used to construct the
   *     {@link UnreconstructableException}.
   * @param classLoader The classLoader of the object.
   * @return The deserialized object.
   * @throws IllegalArgumentException if the metadata is non-empty but not recognized.
   */
  public static Object deserialize(NativeRayObject nativeRayObject, ObjectId objectId,
      ClassLoader classLoader) {
    byte[] meta = nativeRayObject.metadata;
    byte[] data = nativeRayObject.data;
    if (meta == null || meta.length == 0) {
      // No metadata: the data buffer holds a Java object written by `Serializer`.
      return Serializer.decode(data, classLoader);
    }
    // Otherwise the metadata decides how the object is reconstructed.
    if (Arrays.equals(meta, RAW_TYPE_META)) {
      // Raw binary: the data buffer itself is the value.
      return data;
    }
    if (Arrays.equals(meta, WORKER_EXCEPTION_META)) {
      return RayWorkerException.INSTANCE;
    }
    if (Arrays.equals(meta, ACTOR_EXCEPTION_META)) {
      return RayActorException.INSTANCE;
    }
    if (Arrays.equals(meta, UNRECONSTRUCTABLE_EXCEPTION_META)) {
      return new UnreconstructableException(objectId);
    }
    if (Arrays.equals(meta, TASK_EXECUTION_EXCEPTION_META)) {
      // The exception object itself was serialized into the data buffer.
      return Serializer.decode(data, classLoader);
    }
    throw new IllegalArgumentException("Unrecognized metadata " + Arrays.toString(meta));
  }

  /**
   * Serialize a Java object to a {@link NativeRayObject} instance.
   *
   * @param object The object to serialize.
   * @return The serialized object.
   */
  public static NativeRayObject serialize(Object object) {
    if (object instanceof NativeRayObject) {
      // Already in binary form; pass it through unchanged.
      return (NativeRayObject) object;
    }
    if (object instanceof byte[]) {
      // If the object is a byte array, skip serializing it and use a special metadata to
      // indicate it's raw binary. So that this object can also be read by Python.
      return new NativeRayObject((byte[]) object, RAW_TYPE_META);
    }
    if (object instanceof RayTaskException) {
      return new NativeRayObject(Serializer.encode(object), TASK_EXECUTION_EXCEPTION_META);
    }
    // Ordinary Java object: serialized data with no metadata.
    return new NativeRayObject(Serializer.encode(object), null);
  }
}
@@ -2,40 +2,21 @@ package org.ray.runtime.object;
import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import org.ray.api.RayObject;
import org.ray.api.WaitResult;
import org.ray.api.exception.RayActorException;
import org.ray.api.exception.RayException;
import org.ray.api.exception.RayTaskException;
import org.ray.api.exception.RayWorkerException;
import org.ray.api.exception.UnreconstructableException;
import org.ray.api.id.ObjectId;
import org.ray.runtime.context.WorkerContext;
import org.ray.runtime.generated.Gcs.ErrorType;
import org.ray.runtime.util.Serializer;
/**
* A class that is used to put/get objects to/from the object store.
*/
public abstract class ObjectStore {
private static final byte[] WORKER_EXCEPTION_META = String
.valueOf(ErrorType.WORKER_DIED.getNumber()).getBytes();
private static final byte[] ACTOR_EXCEPTION_META = String
.valueOf(ErrorType.ACTOR_DIED.getNumber()).getBytes();
private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String
.valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE.getNumber()).getBytes();
private static final byte[] TASK_EXECUTION_EXCEPTION_META = String
.valueOf(ErrorType.TASK_EXECUTION_EXCEPTION.getNumber()).getBytes();
private static final byte[] RAW_TYPE_META = "RAW".getBytes();
private final WorkerContext workerContext;
public ObjectStore(WorkerContext workerContext) {
@@ -65,7 +46,11 @@ public abstract class ObjectStore {
* @return Id of the object.
*/
public ObjectId put(Object object) {
return putRaw(serialize(object));
if (object instanceof NativeRayObject) {
throw new IllegalArgumentException(
"Trying to put a NativeRayObject. Please use putRaw instead.");
}
return putRaw(ObjectSerializer.serialize(object));
}
/**
@@ -77,7 +62,11 @@ public abstract class ObjectStore {
* @param objectId Object id.
*/
public void put(Object object, ObjectId objectId) {
putRaw(serialize(object), objectId);
if (object instanceof NativeRayObject) {
throw new IllegalArgumentException(
"Trying to put a NativeRayObject. Please use putRaw instead.");
}
putRaw(ObjectSerializer.serialize(object), objectId);
}
/**
@@ -106,7 +95,8 @@ public abstract class ObjectStore {
NativeRayObject dataAndMeta = dataAndMetaList.get(i);
Object object = null;
if (dataAndMeta != null) {
object = deserialize(dataAndMeta, ids.get(i));
object = ObjectSerializer
.deserialize(dataAndMeta, ids.get(i), workerContext.getCurrentClassLoader());
}
if (object instanceof RayException) {
// If the object is a `RayException`, it means that an error occurred during task
@@ -174,57 +164,4 @@ public abstract class ObjectStore {
*/
public abstract void delete(List<ObjectId> objectIds, boolean localOnly,
boolean deleteCreatingTasks);
/**
* Deserialize an object from its binary representation, dispatching on its metadata.
*
* @param nativeRayObject The object to deserialize.
* @param objectId The associated object ID of the object. It is only used to construct the
* {@link UnreconstructableException}.
* @return The deserialized object.
* @throws IllegalArgumentException if the metadata is non-empty but not recognized.
*/
public Object deserialize(NativeRayObject nativeRayObject, ObjectId objectId) {
byte[] meta = nativeRayObject.metadata;
byte[] data = nativeRayObject.data;
// If meta is not empty, the metadata decides how the object is reconstructed.
if (meta != null && meta.length > 0) {
if (Arrays.equals(meta, RAW_TYPE_META)) {
// Raw binary: the data buffer itself is the value.
return data;
} else if (Arrays.equals(meta, WORKER_EXCEPTION_META)) {
return RayWorkerException.INSTANCE;
} else if (Arrays.equals(meta, ACTOR_EXCEPTION_META)) {
return RayActorException.INSTANCE;
} else if (Arrays.equals(meta, UNRECONSTRUCTABLE_EXCEPTION_META)) {
return new UnreconstructableException(objectId);
} else if (Arrays.equals(meta, TASK_EXECUTION_EXCEPTION_META)) {
// The exception object itself is decoded from the data buffer.
return Serializer.decode(data, workerContext.getCurrentClassLoader());
}
throw new IllegalArgumentException("Unrecognized metadata " + Arrays.toString(meta));
} else {
// No metadata: decode the data buffer as a Java object, using the current class loader
// of the worker context.
return Serializer.decode(data, workerContext.getCurrentClassLoader());
}
}
/**
* Serialize an object to its binary representation, choosing metadata by the object's type.
*
* @param object The object to serialize.
* @return The serialized object.
*/
public NativeRayObject serialize(Object object) {
if (object instanceof NativeRayObject) {
// Already in binary form; return it unchanged.
return (NativeRayObject) object;
} else if (object instanceof byte[]) {
// If the object is a byte array, skip serializing it and use a special metadata to
// indicate it's raw binary. So that this object can also be read by Python.
return new NativeRayObject((byte[]) object, RAW_TYPE_META);
} else if (object instanceof RayTaskException) {
// Task exceptions are encoded like any other object but tagged with error metadata.
return new NativeRayObject(Serializer.encode(object),
TASK_EXECUTION_EXCEPTION_META);
} else {
// Any other object: encoded data with no metadata.
return new NativeRayObject(Serializer.encode(object), null);
}
}
}
@@ -9,8 +9,7 @@ import org.ray.api.runtime.RayRuntime;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.RayMultiWorkerNativeRuntime;
import org.ray.runtime.object.NativeRayObject;
import org.ray.runtime.object.ObjectStore;
import org.ray.runtime.util.Serializer;
import org.ray.runtime.object.ObjectSerializer;
/**
* Helper methods to convert arguments from/to objects.
@@ -26,37 +25,29 @@ public class ArgumentsBuilder {
/**
* Convert real function arguments to task spec arguments.
*/
public static List<FunctionArg> wrap(Object[] args, boolean crossLanguage) {
public static List<FunctionArg> wrap(Object[] args) {
List<FunctionArg> ret = new ArrayList<>();
for (Object arg : args) {
ObjectId id = null;
byte[] data = null;
if (arg == null) {
data = Serializer.encode(null);
} else if (arg instanceof RayObject) {
NativeRayObject value = null;
if (arg instanceof RayObject) {
id = ((RayObject) arg).getId();
} else if (arg instanceof byte[] && crossLanguage) {
// If the argument is a byte array and will be used by a different language,
// do not inline this argument. Because the other language doesn't know how
// to deserialize it.
id = Ray.put(arg).getId();
} else {
byte[] serialized = Serializer.encode(arg);
if (serialized.length > LARGEST_SIZE_PASS_BY_VALUE) {
value = ObjectSerializer.serialize(arg);
if (value.data.length > LARGEST_SIZE_PASS_BY_VALUE) {
RayRuntime runtime = Ray.internal();
if (runtime instanceof RayMultiWorkerNativeRuntime) {
runtime = ((RayMultiWorkerNativeRuntime) runtime).getCurrentRuntime();
}
id = ((AbstractRayRuntime) runtime).getObjectStore()
.put(new NativeRayObject(serialized, null));
} else {
data = serialized;
.putRaw(value);
value = null;
}
}
if (id != null) {
ret.add(FunctionArg.passByReference(id));
} else {
ret.add(FunctionArg.passByValue(data));
ret.add(FunctionArg.passByValue(value));
}
}
return ret;
@@ -65,10 +56,10 @@ public class ArgumentsBuilder {
/**
* Convert list of NativeRayObject to real function arguments.
*/
public static Object[] unwrap(ObjectStore objectStore, List<NativeRayObject> args) {
public static Object[] unwrap(List<NativeRayObject> args, ClassLoader classLoader) {
Object[] realArgs = new Object[args.size()];
for (int i = 0; i < args.size(); i++) {
realArgs[i] = objectStore.deserialize(args.get(i), null);
realArgs[i] = ObjectSerializer.deserialize(args.get(i), null, classLoader);
}
return realArgs;
}
@@ -1,6 +1,8 @@
package org.ray.runtime.task;
import com.google.common.base.Preconditions;
import org.ray.api.id.ObjectId;
import org.ray.runtime.object.NativeRayObject;
/**
* Represents a function argument in task spec.
@@ -16,11 +18,12 @@ public class FunctionArg {
/**
* Serialized data of this argument (passed by value).
*/
public final byte[] data;
public final NativeRayObject value;
private FunctionArg(ObjectId id, byte[] data) {
private FunctionArg(ObjectId id, NativeRayObject value) {
Preconditions.checkState((id == null) != (value == null));
this.id = id;
this.data = data;
this.value = value;
}
/**
@@ -33,8 +36,8 @@ public class FunctionArg {
/**
* Create a FunctionArg that will be passed by value.
*/
public static FunctionArg passByValue(byte[] data) {
return new FunctionArg(null, data);
public static FunctionArg passByValue(NativeRayObject value) {
return new FunctionArg(null, value);
}
@Override
@@ -42,7 +45,7 @@ public class FunctionArg {
if (id != null) {
return "<id>: " + id.toString();
} else {
return "<data>: " + data.length;
return value.toString();
}
}
}
@@ -154,7 +154,9 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
.collect(Collectors.toList()))
.addAllArgs(args.stream().map(arg -> arg.id != null ? TaskArg.newBuilder()
.addObjectIds(ByteString.copyFrom(arg.id.getBytes())).build()
: TaskArg.newBuilder().setData(ByteString.copyFrom(arg.data)).build())
: TaskArg.newBuilder().setData(ByteString.copyFrom(arg.value.data))
.setMetadata(arg.value.metadata != null ? ByteString
.copyFrom(arg.value.metadata) : ByteString.EMPTY).build())
.collect(Collectors.toList()));
}
@@ -233,7 +235,7 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
List<NativeRayObject> args = getFunctionArgs(taskSpec).stream()
.map(arg -> arg.id != null ?
objectStore.getRaw(Collections.singletonList(arg.id), -1).get(0)
: new NativeRayObject(arg.data, null))
: arg.value)
.collect(Collectors.toList());
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(taskSpec);
List<NativeRayObject> returnObjects = taskExecutor
@@ -246,7 +248,7 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
// If the task is an actor task or an actor creation task,
// put the dummy object in object store, so those tasks which depends on it
// can be executed.
putObject = new NativeRayObject(new byte[]{}, new byte[]{});
putObject = new NativeRayObject(new byte[]{1}, null);
} else {
putObject = returnObjects.get(i);
}
@@ -279,7 +281,8 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
functionArgs.add(FunctionArg
.passByReference(new ObjectId(arg.getObjectIds(0).toByteArray())));
} else {
functionArgs.add(FunctionArg.passByValue(arg.getData().toByteArray()));
functionArgs.add(FunctionArg.passByValue(
new NativeRayObject(arg.getData().toByteArray(), arg.getMetadata().toByteArray())));
}
}
return functionArgs;
@@ -17,6 +17,7 @@ import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
import org.ray.runtime.functionmanager.RayFunction;
import org.ray.runtime.generated.Common.TaskType;
import org.ray.runtime.object.NativeRayObject;
import org.ray.runtime.object.ObjectSerializer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -87,7 +88,7 @@ public final class TaskExecutor {
actor = currentActor;
}
Object[] args = ArgumentsBuilder.unwrap(runtime.getObjectStore(), argsBytes);
Object[] args = ArgumentsBuilder.unwrap(argsBytes, rayFunction.classLoader);
// Execute the task.
Object result;
if (!rayFunction.isConstructor()) {
@@ -102,7 +103,7 @@ public final class TaskExecutor {
maybeSaveCheckpoint(actor, runtime.getWorkerContext().getCurrentActorId());
}
if (rayFunction.hasReturn()) {
returnObjects.add(runtime.getObjectStore().serialize(result));
returnObjects.add(ObjectSerializer.serialize(result));
}
} else {
// TODO (kfstorm): handle checkpoint in core worker.
@@ -113,8 +114,8 @@ public final class TaskExecutor {
} catch (Exception e) {
LOGGER.error("Error executing task " + taskId, e);
if (taskType != TaskType.ACTOR_CREATION_TASK) {
if(rayFunction.hasReturn()) {
returnObjects.add(runtime.getObjectStore()
if (rayFunction.hasReturn()) {
returnObjects.add(ObjectSerializer
.serialize(new RayTaskException("Error executing task " + taskId, e)));
}
} else {
@@ -3,9 +3,9 @@ package org.ray.api.test;
import org.ray.api.Ray;
import org.ray.api.RayPyActor;
import org.ray.api.TestUtils;
import org.ray.api.id.ObjectId;
import org.ray.runtime.context.WorkerContext;
import org.ray.runtime.object.NativeRayObject;
import org.ray.runtime.object.ObjectStore;
import org.ray.runtime.object.ObjectSerializer;
import org.testng.Assert;
import org.testng.annotations.Test;
@@ -14,10 +14,10 @@ public class RaySerializerTest extends BaseMultiLanguageTest {
@Test
public void testSerializePyActor() {
RayPyActor pyActor = Ray.createPyActor("test", "RaySerializerTest");
ObjectStore objectStore = TestUtils.getRuntime().getObjectStore();
NativeRayObject nativeRayObject = objectStore.serialize(pyActor);
RayPyActor result = (RayPyActor) objectStore
.deserialize(nativeRayObject, ObjectId.fromRandom());
WorkerContext workerContext = TestUtils.getRuntime().getWorkerContext();
NativeRayObject nativeRayObject = ObjectSerializer.serialize(pyActor);
RayPyActor result = (RayPyActor) ObjectSerializer
.deserialize(nativeRayObject, null, workerContext.getCurrentClassLoader());
Assert.assertEquals(result.getId(), pyActor.getId());
Assert.assertEquals(result.getModuleName(), "test");
Assert.assertEquals(result.getClassName(), "RaySerializerTest");