mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 04:07:01 +08:00
Cross language serialization for primitive types (#7711)
* Cross language serialization for Java and Python * Use strict types when Python serializing * Handle recursive objects in Python; Pin msgpack >= 0.6.0, < 1.0.0 * Disable gc for optimizing msgpack loads * Fix merge bug * Java call Python use returnType; Fix ClassLoaderTest * Fix RayMethodsTest * Fix checkstyle * Fix lint * prepare_args raises exception if try to transfer a non-deserializable object to another language * Fix CrossLanguageInvocationTest.java, Python msgpack treat float as double * Minor fixes * Fix compile error on linux * Fix lint in java/BUILD.bazel * Fix test_failure * Fix lint * Class<?> to Class<T>; Refine metadata bytes. * Rename FST to Fst; sort java dependencies * Change Class<?>[] to Optional<Class<?>>; sort requirements in setup.py * Improve CrossLanguageInvocationTest * Refactor MessagePackSerializer.java * Refactor MessagePackSerializer.java; Refine CrossLanguageInvocationTest.java * Remove unnecessary dependencies for Java; Add getReturnType() for RayFunction in Java * Fix bug * Remove custom cross language type support * Replace Serializer.Meta with MutableBoolean * Remove @SuppressWarnings support from checkstyle.xml; Add null test in CrossLanguageInvocationTest.java * Refine MessagePackSerializer.pack * Ray.get support RayObject as input * Improve comments and error info * Remove classLoader argument from serializer * Separate msgpack from pickle5 in Python * Pair<byte[], MutableBoolean> to Pair<byte[], Boolean> * Remove public static <T> T get(RayObject<T> object), use RayObject.get() instead * Refine test * small fixes Co-authored-by: 刘宝 <po.lb@antfin.com> Co-authored-by: Hao Chen <chenh1024@gmail.com>
This commit is contained in:
@@ -4,6 +4,7 @@ import com.google.common.base.Preconditions;
|
||||
import com.google.common.base.Strings;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.Callable;
|
||||
import org.ray.api.BaseActor;
|
||||
import org.ray.api.RayActor;
|
||||
@@ -15,7 +16,6 @@ import org.ray.api.function.PyActorClass;
|
||||
import org.ray.api.function.PyActorMethod;
|
||||
import org.ray.api.function.PyRemoteFunction;
|
||||
import org.ray.api.function.RayFunc;
|
||||
import org.ray.api.function.RayFuncVoid;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.api.options.ActorCreationOptions;
|
||||
import org.ray.api.options.CallOptions;
|
||||
@@ -26,6 +26,7 @@ import org.ray.runtime.context.WorkerContext;
|
||||
import org.ray.runtime.functionmanager.FunctionDescriptor;
|
||||
import org.ray.runtime.functionmanager.FunctionManager;
|
||||
import org.ray.runtime.functionmanager.PyFunctionDescriptor;
|
||||
import org.ray.runtime.functionmanager.RayFunction;
|
||||
import org.ray.runtime.gcs.GcsClient;
|
||||
import org.ray.runtime.generated.Common.Language;
|
||||
import org.ray.runtime.generated.Common.WorkerType;
|
||||
@@ -73,18 +74,18 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
|
||||
@Override
|
||||
public <T> RayObject<T> put(T obj) {
|
||||
ObjectId objectId = objectStore.put(obj);
|
||||
return new RayObjectImpl<>(objectId);
|
||||
return new RayObjectImpl<T>(objectId, (Class<T>)(obj == null ? Object.class : obj.getClass()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> T get(ObjectId objectId) throws RayException {
|
||||
List<T> ret = get(ImmutableList.of(objectId));
|
||||
public <T> T get(ObjectId objectId, Class<T> objectType) throws RayException {
|
||||
List<T> ret = get(ImmutableList.of(objectId), objectType);
|
||||
return ret.get(0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> List<T> get(List<ObjectId> objectIds) {
|
||||
return objectStore.get(objectIds);
|
||||
public <T> List<T> get(List<ObjectId> objectIds, Class<T> objectType) {
|
||||
return objectStore.get(objectIds, objectType);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -99,41 +100,39 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
|
||||
|
||||
@Override
|
||||
public RayObject call(RayFunc func, Object[] args, CallOptions options) {
|
||||
FunctionDescriptor functionDescriptor =
|
||||
functionManager.getFunction(workerContext.getCurrentJobId(), func)
|
||||
.functionDescriptor;
|
||||
int numReturns = func instanceof RayFuncVoid ? 0 : 1;
|
||||
return callNormalFunction(functionDescriptor, args, numReturns, options);
|
||||
RayFunction rayFunction = functionManager.getFunction(workerContext.getCurrentJobId(), func);
|
||||
FunctionDescriptor functionDescriptor = rayFunction.functionDescriptor;
|
||||
Optional<Class<?>> returnType = rayFunction.getReturnType();
|
||||
return callNormalFunction(functionDescriptor, args, returnType, options);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObject call(PyRemoteFunction pyRemoteFunction, Object[] args,
|
||||
CallOptions options) {
|
||||
checkPyArguments(args);
|
||||
CallOptions options) {
|
||||
PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(
|
||||
pyRemoteFunction.moduleName,
|
||||
"",
|
||||
pyRemoteFunction.functionName);
|
||||
// Python functions always have a return value, even if it's `None`.
|
||||
return callNormalFunction(functionDescriptor, args, /*numReturns=*/1, options);
|
||||
return callNormalFunction(functionDescriptor, args,
|
||||
/*returnType=*/Optional.of(pyRemoteFunction.returnType), options);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObject callActor(RayActor<?> actor, RayFunc func, Object[] args) {
|
||||
FunctionDescriptor functionDescriptor =
|
||||
functionManager.getFunction(workerContext.getCurrentJobId(), func)
|
||||
.functionDescriptor;
|
||||
int numReturns = func instanceof RayFuncVoid ? 0 : 1;
|
||||
return callActorFunction(actor, functionDescriptor, args, numReturns);
|
||||
RayFunction rayFunction = functionManager.getFunction(workerContext.getCurrentJobId(), func);
|
||||
FunctionDescriptor functionDescriptor = rayFunction.functionDescriptor;
|
||||
Optional<Class<?>> returnType = rayFunction.getReturnType();
|
||||
return callActorFunction(actor, functionDescriptor, args, returnType);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObject callActor(RayPyActor pyActor, PyActorMethod pyActorMethod, Object... args) {
|
||||
checkPyArguments(args);
|
||||
PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(pyActor.getModuleName(),
|
||||
pyActor.getClassName(), pyActorMethod.methodName);
|
||||
// Python functions always have a return value, even if it's `None`.
|
||||
return callActorFunction(pyActor, functionDescriptor, args, /*numReturns=*/1);
|
||||
return callActorFunction(pyActor, functionDescriptor, args,
|
||||
/*returnType=*/Optional.of(pyActorMethod.returnType));
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -148,8 +147,7 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
|
||||
|
||||
@Override
|
||||
public RayPyActor createActor(PyActorClass pyActorClass, Object[] args,
|
||||
ActorCreationOptions options) {
|
||||
checkPyArguments(args);
|
||||
ActorCreationOptions options) {
|
||||
PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(
|
||||
pyActorClass.moduleName,
|
||||
pyActorClass.className,
|
||||
@@ -157,14 +155,6 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
|
||||
return (RayPyActor) createActorImpl(functionDescriptor, args, options);
|
||||
}
|
||||
|
||||
private void checkPyArguments(Object[] args) {
|
||||
for (Object arg : args) {
|
||||
Preconditions.checkArgument(
|
||||
(arg instanceof RayPyActor) || (arg instanceof byte[]),
|
||||
"Python argument can only be a RayPyActor or a byte array, not {}.",
|
||||
arg.getClass().getName());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setAsyncContext(Object asyncContext) {
|
||||
@@ -218,30 +208,32 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
|
||||
}
|
||||
|
||||
private RayObject callNormalFunction(FunctionDescriptor functionDescriptor,
|
||||
Object[] args, int numReturns, CallOptions options) {
|
||||
Object[] args, Optional<Class<?>> returnType, CallOptions options) {
|
||||
int numReturns = returnType.isPresent() ? 1 : 0;
|
||||
List<FunctionArg> functionArgs = ArgumentsBuilder
|
||||
.wrap(args, functionDescriptor.getLanguage());
|
||||
List<ObjectId> returnIds = taskSubmitter.submitTask(functionDescriptor,
|
||||
functionArgs, numReturns, options);
|
||||
Preconditions.checkState(returnIds.size() == numReturns && returnIds.size() <= 1);
|
||||
Preconditions.checkState(returnIds.size() == numReturns);
|
||||
if (returnIds.isEmpty()) {
|
||||
return null;
|
||||
} else {
|
||||
return new RayObjectImpl(returnIds.get(0));
|
||||
return new RayObjectImpl(returnIds.get(0), returnType.get());
|
||||
}
|
||||
}
|
||||
|
||||
private RayObject callActorFunction(BaseActor rayActor,
|
||||
FunctionDescriptor functionDescriptor, Object[] args, int numReturns) {
|
||||
FunctionDescriptor functionDescriptor, Object[] args, Optional<Class<?>> returnType) {
|
||||
int numReturns = returnType.isPresent() ? 1 : 0;
|
||||
List<FunctionArg> functionArgs = ArgumentsBuilder
|
||||
.wrap(args, functionDescriptor.getLanguage());
|
||||
List<ObjectId> returnIds = taskSubmitter.submitActorTask(rayActor,
|
||||
functionDescriptor, functionArgs, numReturns, null);
|
||||
Preconditions.checkState(returnIds.size() == numReturns && returnIds.size() <= 1);
|
||||
Preconditions.checkState(returnIds.size() == numReturns);
|
||||
if (returnIds.isEmpty()) {
|
||||
return null;
|
||||
} else {
|
||||
return new RayObjectImpl(returnIds.get(0));
|
||||
return new RayObjectImpl(returnIds.get(0), returnType.get());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import org.ray.api.id.JobId;
|
||||
import org.ray.api.id.TaskId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.generated.Common.TaskType;
|
||||
import org.ray.runtime.serializer.Serializer;
|
||||
|
||||
/**
|
||||
* The context of worker.
|
||||
@@ -28,7 +29,7 @@ public interface WorkerContext {
|
||||
|
||||
/**
|
||||
* The class loader that is associated with the current job. It's used for locating classes when
|
||||
* dealing with serialization and deserialization in {@link org.ray.runtime.util.Serializer}.
|
||||
* dealing with serialization and deserialization in {@link Serializer}.
|
||||
*/
|
||||
ClassLoader getCurrentClassLoader();
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package org.ray.runtime.functionmanager;
|
||||
import java.lang.reflect.Constructor;
|
||||
import java.lang.reflect.Executable;
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.Optional;
|
||||
|
||||
/**
|
||||
* Represents a Ray function (either a Method or a Constructor in Java) and its metadata.
|
||||
@@ -67,6 +68,17 @@ public class RayFunction {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @return Return type.
|
||||
*/
|
||||
public Optional<Class<?>> getReturnType() {
|
||||
if (hasReturn()) {
|
||||
return Optional.of(((Method) executable).getReturnType());
|
||||
} else {
|
||||
return Optional.empty();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return executable.toString();
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
package org.ray.runtime.object;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.ray.api.exception.RayActorException;
|
||||
import org.ray.api.exception.RayTaskException;
|
||||
import org.ray.api.exception.RayWorkerException;
|
||||
import org.ray.api.exception.UnreconstructableException;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.runtime.generated.Gcs.ErrorType;
|
||||
import org.ray.runtime.util.Serializer;
|
||||
import org.ray.runtime.serializer.Serializer;
|
||||
|
||||
/**
|
||||
* Serialize to and deserialize from {@link NativeRayObject}. Metadata is generated during
|
||||
@@ -21,29 +23,33 @@ public class ObjectSerializer {
|
||||
.valueOf(ErrorType.ACTOR_DIED.getNumber()).getBytes();
|
||||
private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String
|
||||
.valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE.getNumber()).getBytes();
|
||||
|
||||
private static final byte[] TASK_EXECUTION_EXCEPTION_META = String
|
||||
.valueOf(ErrorType.TASK_EXECUTION_EXCEPTION.getNumber()).getBytes();
|
||||
|
||||
private static final byte[] RAW_TYPE_META = "RAW".getBytes();
|
||||
public static final byte[] OBJECT_METADATA_TYPE_CROSS_LANGUAGE = "XLANG".getBytes();
|
||||
public static final byte[] OBJECT_METADATA_TYPE_JAVA = "JAVA".getBytes();
|
||||
public static final byte[] OBJECT_METADATA_TYPE_PYTHON = "PYTHON".getBytes();
|
||||
public static final byte[] OBJECT_METADATA_TYPE_RAW = "RAW".getBytes();
|
||||
|
||||
/**
|
||||
* Deserialize an object from an {@link NativeRayObject} instance.
|
||||
*
|
||||
* @param nativeRayObject The object to deserialize.
|
||||
* @param objectId The associated object ID of the object.
|
||||
* @param classLoader The classLoader of the object.
|
||||
* @return The deserialized object.
|
||||
*/
|
||||
public static Object deserialize(NativeRayObject nativeRayObject, ObjectId objectId,
|
||||
ClassLoader classLoader) {
|
||||
Class<?> objectType) {
|
||||
byte[] meta = nativeRayObject.metadata;
|
||||
byte[] data = nativeRayObject.data;
|
||||
|
||||
if (meta != null && meta.length > 0) {
|
||||
// If meta is not null, deserialize the object from meta.
|
||||
if (Arrays.equals(meta, RAW_TYPE_META)) {
|
||||
if (Arrays.equals(meta, OBJECT_METADATA_TYPE_RAW)) {
|
||||
return data;
|
||||
} else if (Arrays.equals(meta, OBJECT_METADATA_TYPE_CROSS_LANGUAGE) ||
|
||||
Arrays.equals(meta, OBJECT_METADATA_TYPE_JAVA)) {
|
||||
return Serializer.decode(data, objectType);
|
||||
} else if (Arrays.equals(meta, WORKER_EXCEPTION_META)) {
|
||||
return new RayWorkerException();
|
||||
} else if (Arrays.equals(meta, ACTOR_EXCEPTION_META)) {
|
||||
@@ -51,12 +57,15 @@ public class ObjectSerializer {
|
||||
} else if (Arrays.equals(meta, UNRECONSTRUCTABLE_EXCEPTION_META)) {
|
||||
return new UnreconstructableException(objectId);
|
||||
} else if (Arrays.equals(meta, TASK_EXECUTION_EXCEPTION_META)) {
|
||||
return Serializer.decode(data, classLoader);
|
||||
return Serializer.decode(data, objectType);
|
||||
} else if (Arrays.equals(meta, OBJECT_METADATA_TYPE_PYTHON)) {
|
||||
throw new IllegalArgumentException("Can't deserialize Python object: " + objectId
|
||||
.toString());
|
||||
}
|
||||
throw new IllegalArgumentException("Unrecognized metadata " + Arrays.toString(meta));
|
||||
} else {
|
||||
// If data is not null, deserialize the Java object.
|
||||
return Serializer.decode(data, classLoader);
|
||||
return Serializer.decode(data, objectType);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,12 +81,14 @@ public class ObjectSerializer {
|
||||
} else if (object instanceof byte[]) {
|
||||
// If the object is a byte array, skip serializing it and use a special metadata to
|
||||
// indicate it's raw binary. So that this object can also be read by Python.
|
||||
return new NativeRayObject((byte[]) object, RAW_TYPE_META);
|
||||
return new NativeRayObject((byte[]) object, OBJECT_METADATA_TYPE_RAW);
|
||||
} else if (object instanceof RayTaskException) {
|
||||
return new NativeRayObject(Serializer.encode(object),
|
||||
TASK_EXECUTION_EXCEPTION_META);
|
||||
byte[] serializedBytes = Serializer.encode(object).getLeft();
|
||||
return new NativeRayObject(serializedBytes, TASK_EXECUTION_EXCEPTION_META);
|
||||
} else {
|
||||
return new NativeRayObject(Serializer.encode(object), null);
|
||||
Pair<byte[], Boolean> serialized = Serializer.encode(object);
|
||||
return new NativeRayObject(serialized.getLeft(), serialized.getRight() ?
|
||||
OBJECT_METADATA_TYPE_CROSS_LANGUAGE : OBJECT_METADATA_TYPE_JAVA);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,7 +86,7 @@ public abstract class ObjectStore {
|
||||
* @return A list of GetResult objects.
|
||||
*/
|
||||
@SuppressWarnings("unchecked")
|
||||
public <T> List<T> get(List<ObjectId> ids) {
|
||||
public <T> List<T> get(List<ObjectId> ids, Class<?> elementType) {
|
||||
// Pass -1 as timeout to wait until all objects are available in object store.
|
||||
List<NativeRayObject> dataAndMetaList = getRaw(ids, -1);
|
||||
|
||||
@@ -96,7 +96,7 @@ public abstract class ObjectStore {
|
||||
Object object = null;
|
||||
if (dataAndMeta != null) {
|
||||
object = ObjectSerializer
|
||||
.deserialize(dataAndMeta, ids.get(i), workerContext.getCurrentClassLoader());
|
||||
.deserialize(dataAndMeta, ids.get(i), elementType);
|
||||
}
|
||||
if (object instanceof RayException) {
|
||||
// If the object is a `RayException`, it means that an error occurred during task
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package org.ray.runtime.object;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.id.ObjectId;
|
||||
@@ -20,13 +21,16 @@ public final class RayObjectImpl<T> implements RayObject<T>, Serializable {
|
||||
*/
|
||||
private transient T object;
|
||||
|
||||
private Class<T> type;
|
||||
|
||||
/**
|
||||
* Whether the object is already gotten from the object store.
|
||||
*/
|
||||
private transient boolean objectGotten;
|
||||
|
||||
public RayObjectImpl(ObjectId id) {
|
||||
public RayObjectImpl(ObjectId id, Class<T> type) {
|
||||
this.id = id;
|
||||
this.type = type;
|
||||
object = null;
|
||||
objectGotten = false;
|
||||
}
|
||||
@@ -34,7 +38,7 @@ public final class RayObjectImpl<T> implements RayObject<T>, Serializable {
|
||||
@Override
|
||||
public synchronized T get() {
|
||||
if (!objectGotten) {
|
||||
object = Ray.get(id);
|
||||
object = Ray.get(id, type);
|
||||
objectGotten = true;
|
||||
}
|
||||
return object;
|
||||
@@ -45,4 +49,9 @@ public final class RayObjectImpl<T> implements RayObject<T>, Serializable {
|
||||
return id;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Class<T> getType() {
|
||||
return type;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
package org.ray.runtime.serializer;
|
||||
|
||||
import org.nustaq.serialization.FSTConfiguration;
|
||||
import org.ray.runtime.actor.NativeRayActor;
|
||||
import org.ray.runtime.actor.NativeRayActorSerializer;
|
||||
|
||||
/**
|
||||
* Java object serialization TODO: use others (e.g. Arrow) for higher performance
|
||||
*/
|
||||
public class FstSerializer {
|
||||
|
||||
private static final ThreadLocal<FSTConfiguration> conf = ThreadLocal.withInitial(() -> {
|
||||
FSTConfiguration conf = FSTConfiguration.createDefaultConfiguration();
|
||||
conf.registerSerializer(NativeRayActor.class, new NativeRayActorSerializer(), true);
|
||||
return conf;
|
||||
});
|
||||
|
||||
|
||||
public static byte[] encode(Object obj) {
|
||||
FSTConfiguration current = conf.get();
|
||||
current.setClassLoader(Thread.currentThread().getContextClassLoader());
|
||||
return current.asByteArray(obj);
|
||||
}
|
||||
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static <T> T decode(byte[] bs) {
|
||||
FSTConfiguration current = conf.get();
|
||||
current.setClassLoader(Thread.currentThread().getContextClassLoader());
|
||||
return (T) current.asObject(bs);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,270 @@
|
||||
package org.ray.runtime.serializer;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import java.io.IOException;
|
||||
import java.lang.reflect.Array;
|
||||
import java.math.BigInteger;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.apache.commons.lang3.mutable.MutableBoolean;
|
||||
import org.apache.commons.lang3.tuple.ImmutablePair;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.msgpack.core.MessageBufferPacker;
|
||||
import org.msgpack.core.MessagePack;
|
||||
import org.msgpack.core.MessagePacker;
|
||||
import org.msgpack.core.MessageUnpacker;
|
||||
import org.msgpack.value.ArrayValue;
|
||||
import org.msgpack.value.ExtensionValue;
|
||||
import org.msgpack.value.IntegerValue;
|
||||
import org.msgpack.value.Value;
|
||||
import org.msgpack.value.ValueType;
|
||||
|
||||
// We can't pack List / Map by MessagePack, because we don't know the type class when unpacking.
|
||||
public class MessagePackSerializer {
|
||||
|
||||
private static final byte LANGUAGE_SPECIFIC_TYPE_EXTENSION_ID = 101;
|
||||
// MessagePack length is an int takes up to 9 bytes.
|
||||
// https://github.com/msgpack/msgpack/blob/master/spec.md#int-format-family
|
||||
private static final int MESSAGE_PACK_OFFSET = 9;
|
||||
|
||||
// Pakcers indexed by its corresponding Java class object.
|
||||
private static Map<Class<?>, TypePacker> packers = new HashMap<>();
|
||||
// Unpackers indexed by its corresponding MessagePack ValueType.
|
||||
private static Map<ValueType, TypeUnpacker> unpackers = new HashMap<>();
|
||||
// Null and array don't have a corresponding class, so define them separately.
|
||||
private static final TypePacker NULL_PACKER;
|
||||
private static final TypePacker ARRAY_PACKER;
|
||||
private static final TypePacker EXTENSION_PACKER;
|
||||
|
||||
static {
|
||||
// ===== Initialize packers =====
|
||||
// Null packer.
|
||||
NULL_PACKER = (object, packer, javaSerializer) -> packer.packNil();
|
||||
|
||||
// Array packer.
|
||||
ARRAY_PACKER = ((object, packer, javaSerializer) -> {
|
||||
int length = Array.getLength(object);
|
||||
packer.packArrayHeader(length);
|
||||
for (int i = 0; i < length; ++i) {
|
||||
pack(Array.get(object, i), packer, javaSerializer);
|
||||
}
|
||||
});
|
||||
|
||||
// Extension packer.
|
||||
EXTENSION_PACKER = ((object, packer, javaSerializer) -> {
|
||||
javaSerializer.serialize(object, packer);
|
||||
});
|
||||
|
||||
packers.put(Boolean.class,
|
||||
((object, packer, javaSerializer) -> packer.packBoolean((Boolean) object)));
|
||||
packers.put(Byte.class,
|
||||
((object, packer, javaSerializer) -> packer.packByte((Byte) object)));
|
||||
packers.put(Short.class,
|
||||
((object, packer, javaSerializer) -> packer.packShort((Short) object)));
|
||||
packers.put(Integer.class,
|
||||
((object, packer, javaSerializer) -> packer.packInt((Integer) object)));
|
||||
packers.put(Long.class,
|
||||
((object, packer, javaSerializer) -> packer.packLong((Long) object)));
|
||||
packers.put(BigInteger.class,
|
||||
((object, packer, javaSerializer) -> packer.packBigInteger((BigInteger) object)));
|
||||
packers.put(Float.class,
|
||||
((object, packer, javaSerializer) -> packer.packFloat((Float) object)));
|
||||
packers.put(Double.class,
|
||||
((object, packer, javaSerializer) -> packer.packDouble((Double) object)));
|
||||
packers.put(String.class,
|
||||
((object, packer, javaSerializer) -> packer.packString((String) object)));
|
||||
packers.put(byte[].class,
|
||||
((object, packer, javaSerializer) -> {
|
||||
byte[] bytes = (byte[]) object;
|
||||
packer.packBinaryHeader(bytes.length);
|
||||
packer.writePayload(bytes);
|
||||
}));
|
||||
|
||||
// ===== Initialize unpackers =====
|
||||
List<Class<?>> booleanClasses = ImmutableList.of(Boolean.class, boolean.class);
|
||||
List<Class<?>> byteClasses = ImmutableList.of(Byte.class, byte.class);
|
||||
List<Class<?>> shortClasses = ImmutableList.of(Short.class, short.class);
|
||||
List<Class<?>> intClasses = ImmutableList.of(Integer.class, int.class);
|
||||
List<Class<?>> longClasses = ImmutableList.of(Long.class, long.class);
|
||||
List<Class<?>> bigIntClasses = ImmutableList.of(BigInteger.class);
|
||||
List<Class<?>> floatClasses = ImmutableList.of(Float.class, float.class);
|
||||
List<Class<?>> doubleClasses = ImmutableList.of(Double.class, double.class);
|
||||
List<Class<?>> stringClasses = ImmutableList.of(String.class);
|
||||
List<Class<?>> binaryClasses = ImmutableList.of(byte[].class);
|
||||
|
||||
// Null unpacker.
|
||||
unpackers.put(ValueType.NIL, (value, targetClass, javaDeserializer) -> null);
|
||||
// Boolean unpacker.
|
||||
unpackers.put(ValueType.BOOLEAN, (value, targetClass, javaDeserializer) -> {
|
||||
Preconditions.checkArgument(checkTypeCompatible(booleanClasses, targetClass),
|
||||
"Boolean can't be deserialized as {}.", targetClass);
|
||||
return value.asBooleanValue().getBoolean();
|
||||
});
|
||||
// Integer unpacker.
|
||||
unpackers.put(ValueType.INTEGER, ((value, targetClass, javaDeserializer) -> {
|
||||
IntegerValue iv = value.asIntegerValue();
|
||||
if (iv.isInByteRange() && checkTypeCompatible(byteClasses, targetClass)) {
|
||||
return iv.asByte();
|
||||
} else if (iv.isInShortRange() && checkTypeCompatible(shortClasses, targetClass)) {
|
||||
return iv.asShort();
|
||||
} else if (iv.isInIntRange() && checkTypeCompatible(intClasses, targetClass)) {
|
||||
return iv.asInt();
|
||||
} else if (iv.isInLongRange() && checkTypeCompatible(longClasses, targetClass)) {
|
||||
return iv.asLong();
|
||||
} else if (checkTypeCompatible(bigIntClasses, targetClass)) {
|
||||
return iv.asBigInteger();
|
||||
}
|
||||
throw new IllegalArgumentException("Integer can't be deserialized as " + targetClass + ".");
|
||||
}));
|
||||
// Float unpacker.
|
||||
unpackers.put(ValueType.FLOAT, ((value, targetClass, javaDeserializer) -> {
|
||||
if (checkTypeCompatible(doubleClasses, targetClass)) {
|
||||
return value.asFloatValue().toDouble();
|
||||
} else if (checkTypeCompatible(floatClasses, targetClass)) {
|
||||
return value.asFloatValue().toFloat();
|
||||
}
|
||||
throw new IllegalArgumentException("Float can't be deserialized as " + targetClass + ".");
|
||||
}));
|
||||
// String unpacker.
|
||||
unpackers.put(ValueType.STRING, ((value, targetClass, javaDeserializer) -> {
|
||||
Preconditions.checkArgument(checkTypeCompatible(stringClasses, targetClass),
|
||||
"String can't be deserialized as {}.", targetClass);
|
||||
return value.asStringValue().asString();
|
||||
}));
|
||||
// Binary unpacker.
|
||||
unpackers.put(ValueType.BINARY, ((value, targetClass, javaDeserializer) -> {
|
||||
Preconditions.checkArgument(checkTypeCompatible(binaryClasses, targetClass),
|
||||
"Binary can't be deserialized as {}.", targetClass);
|
||||
return value.asBinaryValue().asByteArray();
|
||||
}));
|
||||
// Array unpacker.
|
||||
unpackers.put(ValueType.ARRAY, ((value, targetClass, javaDeserializer) -> {
|
||||
ArrayValue av = value.asArrayValue();
|
||||
Class<?> componentType =
|
||||
targetClass.isArray() ? targetClass.getComponentType() : Object.class;
|
||||
Object array = Array.newInstance(componentType, av.size());
|
||||
for (int i = 0; i < av.size(); ++i) {
|
||||
Array.set(array, i, unpack(av.get(i), componentType, javaDeserializer));
|
||||
}
|
||||
return array;
|
||||
}));
|
||||
// Extension unpacker.
|
||||
unpackers.put(ValueType.EXTENSION, ((value, targetClass, javaDeserializer) -> {
|
||||
ExtensionValue ev = value.asExtensionValue();
|
||||
byte extType = ev.getType();
|
||||
if (extType == LANGUAGE_SPECIFIC_TYPE_EXTENSION_ID) {
|
||||
return javaDeserializer.deserialize(ev);
|
||||
}
|
||||
throw new IllegalArgumentException("Unknown extension type id " + ev.getType() + ".");
|
||||
}));
|
||||
}
|
||||
|
||||
interface JavaSerializer {
|
||||
|
||||
void serialize(Object object, MessagePacker packer) throws IOException;
|
||||
}
|
||||
|
||||
interface JavaDeserializer {
|
||||
|
||||
Object deserialize(ExtensionValue v);
|
||||
}
|
||||
|
||||
interface TypePacker {
|
||||
|
||||
void pack(Object object, MessagePacker packer,
|
||||
JavaSerializer javaSerializer) throws IOException;
|
||||
}
|
||||
|
||||
interface TypeUnpacker {
|
||||
|
||||
Object unpack(Value value, Class<?> targetClass,
|
||||
JavaDeserializer javaDeserializer);
|
||||
}
|
||||
|
||||
private static boolean checkTypeCompatible(List<Class<?>> expected, Class<?> actual) {
|
||||
for (Class<?> expectedClass : expected) {
|
||||
if (actual.isAssignableFrom(expectedClass)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private static void pack(Object object, MessagePacker packer, JavaSerializer javaSerializer)
|
||||
throws IOException {
|
||||
TypePacker typePacker;
|
||||
if (object == null) {
|
||||
typePacker = NULL_PACKER;
|
||||
} else {
|
||||
Class<?> type = object.getClass();
|
||||
typePacker = packers.get(type);
|
||||
if (typePacker == null) {
|
||||
if (type.isArray()) {
|
||||
typePacker = ARRAY_PACKER;
|
||||
} else {
|
||||
typePacker = EXTENSION_PACKER;
|
||||
}
|
||||
}
|
||||
}
|
||||
typePacker.pack(object, packer, javaSerializer);
|
||||
}
|
||||
|
||||
private static Object unpack(Value v, Class<?> type, JavaDeserializer javaDeserializer) {
|
||||
return unpackers.get(v.getValueType()).unpack(v, type, javaDeserializer);
|
||||
}
|
||||
|
||||
public static Pair<byte[], Boolean> encode(Object obj) {
|
||||
MessageBufferPacker packer = MessagePack.newDefaultBufferPacker();
|
||||
try {
|
||||
// Reserve MESSAGE_PACK_OFFSET bytes for MessagePack bytes length.
|
||||
packer.writePayload(new byte[MESSAGE_PACK_OFFSET]);
|
||||
// Serialize input object by MessagePack.
|
||||
MutableBoolean isCrossLanguage = new MutableBoolean(true);
|
||||
pack(obj, packer, ((object, packer1) -> {
|
||||
byte[] payload = FstSerializer.encode(object);
|
||||
packer1.packExtensionTypeHeader(LANGUAGE_SPECIFIC_TYPE_EXTENSION_ID, payload.length);
|
||||
packer1.addPayload(payload);
|
||||
isCrossLanguage.setFalse();
|
||||
}));
|
||||
byte[] msgpackBytes = packer.toByteArray();
|
||||
// Serialize MessagePack bytes length.
|
||||
MessageBufferPacker headerPacker = MessagePack.newDefaultBufferPacker();
|
||||
Preconditions.checkState(msgpackBytes.length >= MESSAGE_PACK_OFFSET);
|
||||
headerPacker.packLong(msgpackBytes.length - MESSAGE_PACK_OFFSET);
|
||||
byte[] msgpackBytesLength = headerPacker.toByteArray();
|
||||
// Check serialized MessagePack bytes length is valid.
|
||||
Preconditions.checkState(msgpackBytesLength.length <= MESSAGE_PACK_OFFSET);
|
||||
// Write MessagePack bytes length to reserved buffer.
|
||||
System.arraycopy(msgpackBytesLength, 0, msgpackBytes, 0, msgpackBytesLength.length);
|
||||
return ImmutablePair.of(msgpackBytes, isCrossLanguage.getValue());
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static <T> T decode(byte[] bs, Class<?> type) {
|
||||
try {
|
||||
// Read MessagePack bytes length.
|
||||
MessageUnpacker headerUnpacker = MessagePack.newDefaultUnpacker(bs, 0, MESSAGE_PACK_OFFSET);
|
||||
long msgpackBytesLength = headerUnpacker.unpackLong();
|
||||
headerUnpacker.close();
|
||||
// Check MessagePack bytes length is valid.
|
||||
Preconditions.checkState(MESSAGE_PACK_OFFSET + msgpackBytesLength <= bs.length);
|
||||
// Deserialize MessagePack bytes from MESSAGE_PACK_OFFSET.
|
||||
MessageUnpacker unpacker = MessagePack.newDefaultUnpacker(bs, MESSAGE_PACK_OFFSET,
|
||||
(int) msgpackBytesLength);
|
||||
Value v = unpacker.unpackValue();
|
||||
if (type == null) {
|
||||
type = Object.class;
|
||||
}
|
||||
return (T) unpack(v, type,
|
||||
((ExtensionValue ev) -> FstSerializer.decode(ev.getData())));
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package org.ray.runtime.serializer;
|
||||
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
|
||||
public class Serializer {
|
||||
|
||||
public static Pair<byte[], Boolean> encode(Object obj) {
|
||||
return MessagePackSerializer.encode(obj);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static <T> T decode(byte[] bs, Class<?> type) {
|
||||
return MessagePackSerializer.decode(bs, type);
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,8 @@
|
||||
package org.ray.runtime.task;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.id.ObjectId;
|
||||
@@ -40,9 +40,18 @@ public class ArgumentsBuilder {
|
||||
id = ((RayObject) arg).getId();
|
||||
} else {
|
||||
value = ObjectSerializer.serialize(arg);
|
||||
if (language != Language.JAVA) {
|
||||
boolean isCrossData =
|
||||
Arrays.equals(value.metadata, ObjectSerializer.OBJECT_METADATA_TYPE_CROSS_LANGUAGE) ||
|
||||
Arrays.equals(value.metadata, ObjectSerializer.OBJECT_METADATA_TYPE_RAW);
|
||||
if (!isCrossData) {
|
||||
throw new IllegalArgumentException(String.format("Can't transfer %s data to %s",
|
||||
Arrays.toString(value.metadata), language.getValueDescriptor().getName()));
|
||||
}
|
||||
}
|
||||
if (value.data.length > LARGEST_SIZE_PASS_BY_VALUE) {
|
||||
id = ((RayRuntimeInternal) Ray.internal()).getObjectStore()
|
||||
.putRaw(value);
|
||||
.putRaw(value);
|
||||
value = null;
|
||||
}
|
||||
}
|
||||
@@ -61,10 +70,10 @@ public class ArgumentsBuilder {
|
||||
/**
|
||||
* Convert list of NativeRayObject to real function arguments.
|
||||
*/
|
||||
public static Object[] unwrap(List<NativeRayObject> args, ClassLoader classLoader) {
|
||||
public static Object[] unwrap(List<NativeRayObject> args, Class<?>[] types) {
|
||||
Object[] realArgs = new Object[args.size()];
|
||||
for (int i = 0; i < args.size(); i++) {
|
||||
realArgs[i] = ObjectSerializer.deserialize(args.get(i), null, classLoader);
|
||||
realArgs[i] = ObjectSerializer.deserialize(args.get(i), null, types[i]);
|
||||
}
|
||||
return realArgs;
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import java.lang.reflect.InvocationTargetException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
import org.ray.api.exception.RayTaskException;
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.JobId;
|
||||
@@ -97,7 +98,8 @@ public abstract class TaskExecutor<T extends ActorContext> {
|
||||
}
|
||||
actor = actorContext.currentActor;
|
||||
}
|
||||
Object[] args = ArgumentsBuilder.unwrap(argsBytes, rayFunction.classLoader);
|
||||
Object[] args = ArgumentsBuilder
|
||||
.unwrap(argsBytes, rayFunction.executable.getParameterTypes());
|
||||
// Execute the task.
|
||||
Object result;
|
||||
try {
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
package org.ray.runtime.util;
|
||||
|
||||
import org.nustaq.serialization.FSTConfiguration;
|
||||
import org.ray.runtime.actor.NativeRayActor;
|
||||
import org.ray.runtime.actor.NativeRayActorSerializer;
|
||||
|
||||
/**
|
||||
* Java object serialization TODO: use others (e.g. Arrow) for higher performance
|
||||
*/
|
||||
public class Serializer {
|
||||
|
||||
private static final ThreadLocal<FSTConfiguration> conf = ThreadLocal.withInitial(() -> {
|
||||
FSTConfiguration conf = FSTConfiguration.createDefaultConfiguration();
|
||||
conf.registerSerializer(NativeRayActor.class, new NativeRayActorSerializer(), true);
|
||||
return conf;
|
||||
});
|
||||
|
||||
public static byte[] encode(Object obj) {
|
||||
return conf.get().asByteArray(obj);
|
||||
}
|
||||
|
||||
public static byte[] encode(Object obj, ClassLoader classLoader) {
|
||||
byte[] result;
|
||||
FSTConfiguration current = conf.get();
|
||||
if (classLoader != null && classLoader != current.getClassLoader()) {
|
||||
ClassLoader old = current.getClassLoader();
|
||||
current.setClassLoader(classLoader);
|
||||
result = current.asByteArray(obj);
|
||||
current.setClassLoader(old);
|
||||
} else {
|
||||
result = current.asByteArray(obj);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static <T> T decode(byte[] bs) {
|
||||
return (T) conf.get().asObject(bs);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static <T> T decode(byte[] bs, ClassLoader classLoader) {
|
||||
Object object;
|
||||
FSTConfiguration current = conf.get();
|
||||
if (classLoader != null && classLoader != current.getClassLoader()) {
|
||||
ClassLoader old = current.getClassLoader();
|
||||
current.setClassLoader(classLoader);
|
||||
object = current.asObject(bs);
|
||||
current.setClassLoader(old);
|
||||
} else {
|
||||
object = current.asObject(bs);
|
||||
}
|
||||
return (T) object;
|
||||
}
|
||||
|
||||
public static void setClassloader(ClassLoader classLoader) {
|
||||
conf.get().setClassLoader(classLoader);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package org.ray.runtime.util;
|
||||
|
||||
import org.apache.commons.lang3.mutable.MutableBoolean;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.ray.runtime.serializer.Serializer;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
||||
public class SerializerTest {
|
||||
|
||||
@Test
|
||||
public void testBasicSerialization() {
|
||||
// Test serialize / deserialize primitive types with type conversion.
|
||||
{
|
||||
Object[] foo = new Object[]{"hello", (byte) 1, 2.0, (short) 3, 4, 5L,
|
||||
new String[]{"hello", "world"}};
|
||||
Pair<byte[], Boolean> serialized = Serializer.encode(foo);
|
||||
Object[] bar = Serializer.decode(serialized.getLeft(), Object[].class);
|
||||
Assert.assertTrue(serialized.getRight());
|
||||
Assert.assertEquals(foo[0], bar[0]);
|
||||
Assert.assertEquals(((Number) foo[1]).byteValue(), ((Number) bar[1]).byteValue());
|
||||
Assert.assertEquals(foo[2], bar[2]);
|
||||
Assert.assertEquals(((Number) foo[3]).intValue(), ((Number) bar[3]).intValue());
|
||||
Assert.assertEquals(((Number) foo[4]).intValue(), ((Number) bar[4]).intValue());
|
||||
Assert.assertEquals(((Number) foo[5]).intValue(), ((Number) bar[5]).intValue());
|
||||
}
|
||||
// Test multidimensional array.
|
||||
{
|
||||
Object[][] foo = new Object[][]{{1, 2}, {"3", 4}};
|
||||
Assert.expectThrows(RuntimeException.class, () -> {
|
||||
Object[][] bar = Serializer.decode(Serializer.encode(foo).getLeft(), Integer[][].class);
|
||||
});
|
||||
Pair<byte[], Boolean> serialized = Serializer.encode(foo);
|
||||
Object[][] bar = Serializer.decode(serialized.getLeft(), Object[][].class);
|
||||
Assert.assertTrue(serialized.getRight());
|
||||
Assert.assertEquals(((Number) foo[0][1]).intValue(), ((Number) bar[0][1]).intValue());
|
||||
Assert.assertEquals(foo[1][0], bar[1][0]);
|
||||
}
|
||||
// Test List.
|
||||
{
|
||||
ArrayList<String> foo = new ArrayList<>();
|
||||
foo.add("1");
|
||||
foo.add("2");
|
||||
Pair<byte[], Boolean> serialized = Serializer.encode(foo);
|
||||
ArrayList<String> bar = Serializer.decode(serialized.getLeft(), String[].class);
|
||||
Assert.assertFalse(serialized.getRight());
|
||||
Assert.assertEquals(foo.get(0), bar.get(0));
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user