Cross language serialization for primitive types (#7711)

* Cross language serialization for Java and Python

* Use strict types when Python serializing

* Handle recursive objects in Python; Pin msgpack >= 0.6.0, < 1.0.0

* Disable gc for optimizing msgpack loads

* Fix merge bug

* Java call Python use returnType; Fix ClassLoaderTest

* Fix RayMethodsTest

* Fix checkstyle

* Fix lint

* prepare_args raises exception if try to transfer a non-deserializable object to another language

* Fix CrossLanguageInvocationTest.java, Python msgpack treat float as double

* Minor fixes

* Fix compile error on linux

* Fix lint in java/BUILD.bazel

* Fix test_failure

* Fix lint

* Class<?> to Class<T>; Refine metadata bytes.

* Rename FST to Fst; sort java dependencies

* Change Class<?>[] to Optional<Class<?>>; sort requirements in setup.py

* Improve CrossLanguageInvocationTest

* Refactor MessagePackSerializer.java

* Refactor MessagePackSerializer.java; Refine CrossLanguageInvocationTest.java

* Remove unnecessary dependencies for Java; Add getReturnType() for RayFunction in Java

* Fix bug

* Remove custom cross language type support

* Replace Serializer.Meta with MutableBoolean

* Remove @SuppressWarnings support from checkstyle.xml; Add null test in CrossLanguageInvocationTest.java

* Refine MessagePackSerializer.pack

* Ray.get support RayObject as input

* Improve comments and error info

* Remove classLoader argument from serializer

* Separate msgpack from pickle5 in Python

* Pair<byte[], MutableBoolean> to Pair<byte[], Boolean>

* Remove public static <T> T get(RayObject<T> object), use RayObject.get() instead

* Refine test

* small fixes

Co-authored-by: 刘宝 <po.lb@antfin.com>
Co-authored-by: Hao Chen <chenh1024@gmail.com>
This commit is contained in:
fyrestone
2020-04-08 21:10:57 +08:00
committed by GitHub
parent e8c19aba41
commit fc6259a656
42 changed files with 1057 additions and 313 deletions
+1
View File
@@ -76,6 +76,7 @@ define_java_module(
"@maven//:de_ruedigermoeller_fst",
"@maven//:net_java_dev_jna_jna",
"@maven//:org_apache_commons_commons_lang3",
"@maven//:org_msgpack_msgpack_core",
"@maven//:org_ow2_asm_asm",
"@maven//:org_slf4j_slf4j_api",
"@maven//:org_slf4j_slf4j_log4j12",
+25 -6
View File
@@ -1,5 +1,6 @@
package org.ray.api;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import org.ray.api.id.ObjectId;
@@ -62,23 +63,41 @@ public final class Ray extends RayCall {
}
/**
* Get an object from the object store.
* Get an object by id from the object store.
*
* @param objectId The ID of the object to get.
* @param objectType The type of the object to get.
* @return The Java object.
*/
public static <T> T get(ObjectId objectId) {
return runtime.get(objectId);
public static <T> T get(ObjectId objectId, Class<T> objectType) {
return runtime.get(objectId, objectType);
}
/**
* Get a list of objects from the object store.
* Get a list of objects by ids from the object store.
*
* @param objectIds The list of object IDs.
* @param objectType The type of object.
* @return A list of Java objects.
*/
public static <T> List<T> get(List<ObjectId> objectIds) {
return runtime.get(objectIds);
public static <T> List<T> get(List<ObjectId> objectIds, Class<T> objectType) {
return runtime.get(objectIds, objectType);
}
/**
* Get a list of objects by RayObjects from the object store.
*
* @param objectList A list of RayObject to get.
* @return A list of Java objects.
*/
public static <T> List<T> get(List<RayObject<T>> objectList) {
List<ObjectId> objectIds = new ArrayList<>();
Class<T> objectType = null;
for (RayObject<T> o : objectList) {
objectIds.add(o.getId());
objectType = o.getType();
}
return runtime.get(objectIds, objectType);
}
/**
@@ -19,5 +19,10 @@ public interface RayObject<T> {
*/
ObjectId getId();
/**
* Get the Object type.
*/
Class<T> getType();
}
@@ -39,17 +39,19 @@ public interface RayRuntime {
* Get an object from the object store.
*
* @param objectId The ID of the object to get.
* @param objectType The type of the object to get.
* @return The Java object.
*/
<T> T get(ObjectId objectId);
<T> T get(ObjectId objectId, Class<T> objectType);
/**
* Get a list of objects from the object store.
*
* @param objectIds The list of object IDs.
* @param objectType The type of object.
* @return A list of Java objects.
*/
<T> List<T> get(List<ObjectId> objectIds);
<T> List<T> get(List<ObjectId> objectIds, Class<T> objectType);
/**
* Wait for a list of RayObjects to be locally available, until specified number of objects are
+2 -1
View File
@@ -15,11 +15,12 @@ def gen_java_deps():
"de.ruedigermoeller:fst:2.57",
"javax.xml.bind:jaxb-api:2.3.0",
"org.apache.commons:commons-lang3:3.4",
"org.msgpack:msgpack-core:0.8.20",
"org.ow2.asm:asm:6.0",
"org.slf4j:slf4j-log4j12:1.7.25",
"org.testng:testng:6.9.10",
"redis.clients:jedis:2.8.0",
"net.java.dev.jna:jna:5.5.0"
"net.java.dev.jna:jna:5.5.0",
],
repositories = [
"https://repo1.maven.org/maven2/",
+5
View File
@@ -62,6 +62,11 @@
<artifactId>commons-lang3</artifactId>
<version>3.4</version>
</dependency>
<dependency>
<groupId>org.msgpack</groupId>
<artifactId>msgpack-core</artifactId>
<version>0.8.20</version>
</dependency>
<dependency>
<groupId>org.ow2.asm</groupId>
<artifactId>asm</artifactId>
@@ -4,6 +4,7 @@ import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.Callable;
import org.ray.api.BaseActor;
import org.ray.api.RayActor;
@@ -15,7 +16,6 @@ import org.ray.api.function.PyActorClass;
import org.ray.api.function.PyActorMethod;
import org.ray.api.function.PyRemoteFunction;
import org.ray.api.function.RayFunc;
import org.ray.api.function.RayFuncVoid;
import org.ray.api.id.ObjectId;
import org.ray.api.options.ActorCreationOptions;
import org.ray.api.options.CallOptions;
@@ -26,6 +26,7 @@ import org.ray.runtime.context.WorkerContext;
import org.ray.runtime.functionmanager.FunctionDescriptor;
import org.ray.runtime.functionmanager.FunctionManager;
import org.ray.runtime.functionmanager.PyFunctionDescriptor;
import org.ray.runtime.functionmanager.RayFunction;
import org.ray.runtime.gcs.GcsClient;
import org.ray.runtime.generated.Common.Language;
import org.ray.runtime.generated.Common.WorkerType;
@@ -73,18 +74,18 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
@Override
public <T> RayObject<T> put(T obj) {
ObjectId objectId = objectStore.put(obj);
return new RayObjectImpl<>(objectId);
return new RayObjectImpl<T>(objectId, (Class<T>)(obj == null ? Object.class : obj.getClass()));
}
@Override
public <T> T get(ObjectId objectId) throws RayException {
List<T> ret = get(ImmutableList.of(objectId));
public <T> T get(ObjectId objectId, Class<T> objectType) throws RayException {
List<T> ret = get(ImmutableList.of(objectId), objectType);
return ret.get(0);
}
@Override
public <T> List<T> get(List<ObjectId> objectIds) {
return objectStore.get(objectIds);
public <T> List<T> get(List<ObjectId> objectIds, Class<T> objectType) {
return objectStore.get(objectIds, objectType);
}
@Override
@@ -99,41 +100,39 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
@Override
public RayObject call(RayFunc func, Object[] args, CallOptions options) {
FunctionDescriptor functionDescriptor =
functionManager.getFunction(workerContext.getCurrentJobId(), func)
.functionDescriptor;
int numReturns = func instanceof RayFuncVoid ? 0 : 1;
return callNormalFunction(functionDescriptor, args, numReturns, options);
RayFunction rayFunction = functionManager.getFunction(workerContext.getCurrentJobId(), func);
FunctionDescriptor functionDescriptor = rayFunction.functionDescriptor;
Optional<Class<?>> returnType = rayFunction.getReturnType();
return callNormalFunction(functionDescriptor, args, returnType, options);
}
@Override
public RayObject call(PyRemoteFunction pyRemoteFunction, Object[] args,
CallOptions options) {
checkPyArguments(args);
CallOptions options) {
PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(
pyRemoteFunction.moduleName,
"",
pyRemoteFunction.functionName);
// Python functions always have a return value, even if it's `None`.
return callNormalFunction(functionDescriptor, args, /*numReturns=*/1, options);
return callNormalFunction(functionDescriptor, args,
/*returnType=*/Optional.of(pyRemoteFunction.returnType), options);
}
@Override
public RayObject callActor(RayActor<?> actor, RayFunc func, Object[] args) {
FunctionDescriptor functionDescriptor =
functionManager.getFunction(workerContext.getCurrentJobId(), func)
.functionDescriptor;
int numReturns = func instanceof RayFuncVoid ? 0 : 1;
return callActorFunction(actor, functionDescriptor, args, numReturns);
RayFunction rayFunction = functionManager.getFunction(workerContext.getCurrentJobId(), func);
FunctionDescriptor functionDescriptor = rayFunction.functionDescriptor;
Optional<Class<?>> returnType = rayFunction.getReturnType();
return callActorFunction(actor, functionDescriptor, args, returnType);
}
@Override
public RayObject callActor(RayPyActor pyActor, PyActorMethod pyActorMethod, Object... args) {
checkPyArguments(args);
PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(pyActor.getModuleName(),
pyActor.getClassName(), pyActorMethod.methodName);
// Python functions always have a return value, even if it's `None`.
return callActorFunction(pyActor, functionDescriptor, args, /*numReturns=*/1);
return callActorFunction(pyActor, functionDescriptor, args,
/*returnType=*/Optional.of(pyActorMethod.returnType));
}
@Override
@@ -148,8 +147,7 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
@Override
public RayPyActor createActor(PyActorClass pyActorClass, Object[] args,
ActorCreationOptions options) {
checkPyArguments(args);
ActorCreationOptions options) {
PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(
pyActorClass.moduleName,
pyActorClass.className,
@@ -157,14 +155,6 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
return (RayPyActor) createActorImpl(functionDescriptor, args, options);
}
private void checkPyArguments(Object[] args) {
for (Object arg : args) {
Preconditions.checkArgument(
(arg instanceof RayPyActor) || (arg instanceof byte[]),
"Python argument can only be a RayPyActor or a byte array, not {}.",
arg.getClass().getName());
}
}
@Override
public void setAsyncContext(Object asyncContext) {
@@ -218,30 +208,32 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
}
private RayObject callNormalFunction(FunctionDescriptor functionDescriptor,
Object[] args, int numReturns, CallOptions options) {
Object[] args, Optional<Class<?>> returnType, CallOptions options) {
int numReturns = returnType.isPresent() ? 1 : 0;
List<FunctionArg> functionArgs = ArgumentsBuilder
.wrap(args, functionDescriptor.getLanguage());
List<ObjectId> returnIds = taskSubmitter.submitTask(functionDescriptor,
functionArgs, numReturns, options);
Preconditions.checkState(returnIds.size() == numReturns && returnIds.size() <= 1);
Preconditions.checkState(returnIds.size() == numReturns);
if (returnIds.isEmpty()) {
return null;
} else {
return new RayObjectImpl(returnIds.get(0));
return new RayObjectImpl(returnIds.get(0), returnType.get());
}
}
private RayObject callActorFunction(BaseActor rayActor,
FunctionDescriptor functionDescriptor, Object[] args, int numReturns) {
FunctionDescriptor functionDescriptor, Object[] args, Optional<Class<?>> returnType) {
int numReturns = returnType.isPresent() ? 1 : 0;
List<FunctionArg> functionArgs = ArgumentsBuilder
.wrap(args, functionDescriptor.getLanguage());
List<ObjectId> returnIds = taskSubmitter.submitActorTask(rayActor,
functionDescriptor, functionArgs, numReturns, null);
Preconditions.checkState(returnIds.size() == numReturns && returnIds.size() <= 1);
Preconditions.checkState(returnIds.size() == numReturns);
if (returnIds.isEmpty()) {
return null;
} else {
return new RayObjectImpl(returnIds.get(0));
return new RayObjectImpl(returnIds.get(0), returnType.get());
}
}
@@ -5,6 +5,7 @@ import org.ray.api.id.JobId;
import org.ray.api.id.TaskId;
import org.ray.api.id.UniqueId;
import org.ray.runtime.generated.Common.TaskType;
import org.ray.runtime.serializer.Serializer;
/**
* The context of worker.
@@ -28,7 +29,7 @@ public interface WorkerContext {
/**
* The class loader that is associated with the current job. It's used for locating classes when
* dealing with serialization and deserialization in {@link org.ray.runtime.util.Serializer}.
* dealing with serialization and deserialization in {@link Serializer}.
*/
ClassLoader getCurrentClassLoader();
@@ -3,6 +3,7 @@ package org.ray.runtime.functionmanager;
import java.lang.reflect.Constructor;
import java.lang.reflect.Executable;
import java.lang.reflect.Method;
import java.util.Optional;
/**
* Represents a Ray function (either a Method or a Constructor in Java) and its metadata.
@@ -67,6 +68,17 @@ public class RayFunction {
}
}
/**
* @return Return type.
*/
public Optional<Class<?>> getReturnType() {
if (hasReturn()) {
return Optional.of(((Method) executable).getReturnType());
} else {
return Optional.empty();
}
}
@Override
public String toString() {
return executable.toString();
@@ -1,13 +1,15 @@
package org.ray.runtime.object;
import java.util.Arrays;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.exception.RayActorException;
import org.ray.api.exception.RayTaskException;
import org.ray.api.exception.RayWorkerException;
import org.ray.api.exception.UnreconstructableException;
import org.ray.api.id.ObjectId;
import org.ray.runtime.generated.Gcs.ErrorType;
import org.ray.runtime.util.Serializer;
import org.ray.runtime.serializer.Serializer;
/**
* Serialize to and deserialize from {@link NativeRayObject}. Metadata is generated during
@@ -21,29 +23,33 @@ public class ObjectSerializer {
.valueOf(ErrorType.ACTOR_DIED.getNumber()).getBytes();
private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String
.valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE.getNumber()).getBytes();
private static final byte[] TASK_EXECUTION_EXCEPTION_META = String
.valueOf(ErrorType.TASK_EXECUTION_EXCEPTION.getNumber()).getBytes();
private static final byte[] RAW_TYPE_META = "RAW".getBytes();
public static final byte[] OBJECT_METADATA_TYPE_CROSS_LANGUAGE = "XLANG".getBytes();
public static final byte[] OBJECT_METADATA_TYPE_JAVA = "JAVA".getBytes();
public static final byte[] OBJECT_METADATA_TYPE_PYTHON = "PYTHON".getBytes();
public static final byte[] OBJECT_METADATA_TYPE_RAW = "RAW".getBytes();
/**
* Deserialize an object from an {@link NativeRayObject} instance.
*
* @param nativeRayObject The object to deserialize.
* @param objectId The associated object ID of the object.
* @param classLoader The classLoader of the object.
* @return The deserialized object.
*/
public static Object deserialize(NativeRayObject nativeRayObject, ObjectId objectId,
ClassLoader classLoader) {
Class<?> objectType) {
byte[] meta = nativeRayObject.metadata;
byte[] data = nativeRayObject.data;
if (meta != null && meta.length > 0) {
// If meta is not null, deserialize the object from meta.
if (Arrays.equals(meta, RAW_TYPE_META)) {
if (Arrays.equals(meta, OBJECT_METADATA_TYPE_RAW)) {
return data;
} else if (Arrays.equals(meta, OBJECT_METADATA_TYPE_CROSS_LANGUAGE) ||
Arrays.equals(meta, OBJECT_METADATA_TYPE_JAVA)) {
return Serializer.decode(data, objectType);
} else if (Arrays.equals(meta, WORKER_EXCEPTION_META)) {
return new RayWorkerException();
} else if (Arrays.equals(meta, ACTOR_EXCEPTION_META)) {
@@ -51,12 +57,15 @@ public class ObjectSerializer {
} else if (Arrays.equals(meta, UNRECONSTRUCTABLE_EXCEPTION_META)) {
return new UnreconstructableException(objectId);
} else if (Arrays.equals(meta, TASK_EXECUTION_EXCEPTION_META)) {
return Serializer.decode(data, classLoader);
return Serializer.decode(data, objectType);
} else if (Arrays.equals(meta, OBJECT_METADATA_TYPE_PYTHON)) {
throw new IllegalArgumentException("Can't deserialize Python object: " + objectId
.toString());
}
throw new IllegalArgumentException("Unrecognized metadata " + Arrays.toString(meta));
} else {
// If data is not null, deserialize the Java object.
return Serializer.decode(data, classLoader);
return Serializer.decode(data, objectType);
}
}
@@ -72,12 +81,14 @@ public class ObjectSerializer {
} else if (object instanceof byte[]) {
// If the object is a byte array, skip serializing it and use a special metadata to
// indicate it's raw binary. So that this object can also be read by Python.
return new NativeRayObject((byte[]) object, RAW_TYPE_META);
return new NativeRayObject((byte[]) object, OBJECT_METADATA_TYPE_RAW);
} else if (object instanceof RayTaskException) {
return new NativeRayObject(Serializer.encode(object),
TASK_EXECUTION_EXCEPTION_META);
byte[] serializedBytes = Serializer.encode(object).getLeft();
return new NativeRayObject(serializedBytes, TASK_EXECUTION_EXCEPTION_META);
} else {
return new NativeRayObject(Serializer.encode(object), null);
Pair<byte[], Boolean> serialized = Serializer.encode(object);
return new NativeRayObject(serialized.getLeft(), serialized.getRight() ?
OBJECT_METADATA_TYPE_CROSS_LANGUAGE : OBJECT_METADATA_TYPE_JAVA);
}
}
}
@@ -86,7 +86,7 @@ public abstract class ObjectStore {
* @return A list of GetResult objects.
*/
@SuppressWarnings("unchecked")
public <T> List<T> get(List<ObjectId> ids) {
public <T> List<T> get(List<ObjectId> ids, Class<?> elementType) {
// Pass -1 as timeout to wait until all objects are available in object store.
List<NativeRayObject> dataAndMetaList = getRaw(ids, -1);
@@ -96,7 +96,7 @@ public abstract class ObjectStore {
Object object = null;
if (dataAndMeta != null) {
object = ObjectSerializer
.deserialize(dataAndMeta, ids.get(i), workerContext.getCurrentClassLoader());
.deserialize(dataAndMeta, ids.get(i), elementType);
}
if (object instanceof RayException) {
// If the object is a `RayException`, it means that an error occurred during task
@@ -1,6 +1,7 @@
package org.ray.runtime.object;
import java.io.Serializable;
import org.ray.api.Ray;
import org.ray.api.RayObject;
import org.ray.api.id.ObjectId;
@@ -20,13 +21,16 @@ public final class RayObjectImpl<T> implements RayObject<T>, Serializable {
*/
private transient T object;
private Class<T> type;
/**
* Whether the object is already gotten from the object store.
*/
private transient boolean objectGotten;
public RayObjectImpl(ObjectId id) {
public RayObjectImpl(ObjectId id, Class<T> type) {
this.id = id;
this.type = type;
object = null;
objectGotten = false;
}
@@ -34,7 +38,7 @@ public final class RayObjectImpl<T> implements RayObject<T>, Serializable {
@Override
public synchronized T get() {
if (!objectGotten) {
object = Ray.get(id);
object = Ray.get(id, type);
objectGotten = true;
}
return object;
@@ -45,4 +49,9 @@ public final class RayObjectImpl<T> implements RayObject<T>, Serializable {
return id;
}
@Override
public Class<T> getType() {
return type;
}
}
@@ -0,0 +1,32 @@
package org.ray.runtime.serializer;
import org.nustaq.serialization.FSTConfiguration;
import org.ray.runtime.actor.NativeRayActor;
import org.ray.runtime.actor.NativeRayActorSerializer;
/**
* Java object serialization TODO: use others (e.g. Arrow) for higher performance
*/
public class FstSerializer {
private static final ThreadLocal<FSTConfiguration> conf = ThreadLocal.withInitial(() -> {
FSTConfiguration conf = FSTConfiguration.createDefaultConfiguration();
conf.registerSerializer(NativeRayActor.class, new NativeRayActorSerializer(), true);
return conf;
});
public static byte[] encode(Object obj) {
FSTConfiguration current = conf.get();
current.setClassLoader(Thread.currentThread().getContextClassLoader());
return current.asByteArray(obj);
}
@SuppressWarnings("unchecked")
public static <T> T decode(byte[] bs) {
FSTConfiguration current = conf.get();
current.setClassLoader(Thread.currentThread().getContextClassLoader());
return (T) current.asObject(bs);
}
}
@@ -0,0 +1,270 @@
package org.ray.runtime.serializer;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.io.IOException;
import java.lang.reflect.Array;
import java.math.BigInteger;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.mutable.MutableBoolean;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.msgpack.core.MessageBufferPacker;
import org.msgpack.core.MessagePack;
import org.msgpack.core.MessagePacker;
import org.msgpack.core.MessageUnpacker;
import org.msgpack.value.ArrayValue;
import org.msgpack.value.ExtensionValue;
import org.msgpack.value.IntegerValue;
import org.msgpack.value.Value;
import org.msgpack.value.ValueType;
// We can't pack List / Map by MessagePack, because we don't know the type class when unpacking.
public class MessagePackSerializer {
private static final byte LANGUAGE_SPECIFIC_TYPE_EXTENSION_ID = 101;
// MessagePack length is an int takes up to 9 bytes.
// https://github.com/msgpack/msgpack/blob/master/spec.md#int-format-family
private static final int MESSAGE_PACK_OFFSET = 9;
// Pakcers indexed by its corresponding Java class object.
private static Map<Class<?>, TypePacker> packers = new HashMap<>();
// Unpackers indexed by its corresponding MessagePack ValueType.
private static Map<ValueType, TypeUnpacker> unpackers = new HashMap<>();
// Null and array don't have a corresponding class, so define them separately.
private static final TypePacker NULL_PACKER;
private static final TypePacker ARRAY_PACKER;
private static final TypePacker EXTENSION_PACKER;
static {
// ===== Initialize packers =====
// Null packer.
NULL_PACKER = (object, packer, javaSerializer) -> packer.packNil();
// Array packer.
ARRAY_PACKER = ((object, packer, javaSerializer) -> {
int length = Array.getLength(object);
packer.packArrayHeader(length);
for (int i = 0; i < length; ++i) {
pack(Array.get(object, i), packer, javaSerializer);
}
});
// Extension packer.
EXTENSION_PACKER = ((object, packer, javaSerializer) -> {
javaSerializer.serialize(object, packer);
});
packers.put(Boolean.class,
((object, packer, javaSerializer) -> packer.packBoolean((Boolean) object)));
packers.put(Byte.class,
((object, packer, javaSerializer) -> packer.packByte((Byte) object)));
packers.put(Short.class,
((object, packer, javaSerializer) -> packer.packShort((Short) object)));
packers.put(Integer.class,
((object, packer, javaSerializer) -> packer.packInt((Integer) object)));
packers.put(Long.class,
((object, packer, javaSerializer) -> packer.packLong((Long) object)));
packers.put(BigInteger.class,
((object, packer, javaSerializer) -> packer.packBigInteger((BigInteger) object)));
packers.put(Float.class,
((object, packer, javaSerializer) -> packer.packFloat((Float) object)));
packers.put(Double.class,
((object, packer, javaSerializer) -> packer.packDouble((Double) object)));
packers.put(String.class,
((object, packer, javaSerializer) -> packer.packString((String) object)));
packers.put(byte[].class,
((object, packer, javaSerializer) -> {
byte[] bytes = (byte[]) object;
packer.packBinaryHeader(bytes.length);
packer.writePayload(bytes);
}));
// ===== Initialize unpackers =====
List<Class<?>> booleanClasses = ImmutableList.of(Boolean.class, boolean.class);
List<Class<?>> byteClasses = ImmutableList.of(Byte.class, byte.class);
List<Class<?>> shortClasses = ImmutableList.of(Short.class, short.class);
List<Class<?>> intClasses = ImmutableList.of(Integer.class, int.class);
List<Class<?>> longClasses = ImmutableList.of(Long.class, long.class);
List<Class<?>> bigIntClasses = ImmutableList.of(BigInteger.class);
List<Class<?>> floatClasses = ImmutableList.of(Float.class, float.class);
List<Class<?>> doubleClasses = ImmutableList.of(Double.class, double.class);
List<Class<?>> stringClasses = ImmutableList.of(String.class);
List<Class<?>> binaryClasses = ImmutableList.of(byte[].class);
// Null unpacker.
unpackers.put(ValueType.NIL, (value, targetClass, javaDeserializer) -> null);
// Boolean unpacker.
unpackers.put(ValueType.BOOLEAN, (value, targetClass, javaDeserializer) -> {
Preconditions.checkArgument(checkTypeCompatible(booleanClasses, targetClass),
"Boolean can't be deserialized as {}.", targetClass);
return value.asBooleanValue().getBoolean();
});
// Integer unpacker.
unpackers.put(ValueType.INTEGER, ((value, targetClass, javaDeserializer) -> {
IntegerValue iv = value.asIntegerValue();
if (iv.isInByteRange() && checkTypeCompatible(byteClasses, targetClass)) {
return iv.asByte();
} else if (iv.isInShortRange() && checkTypeCompatible(shortClasses, targetClass)) {
return iv.asShort();
} else if (iv.isInIntRange() && checkTypeCompatible(intClasses, targetClass)) {
return iv.asInt();
} else if (iv.isInLongRange() && checkTypeCompatible(longClasses, targetClass)) {
return iv.asLong();
} else if (checkTypeCompatible(bigIntClasses, targetClass)) {
return iv.asBigInteger();
}
throw new IllegalArgumentException("Integer can't be deserialized as " + targetClass + ".");
}));
// Float unpacker.
unpackers.put(ValueType.FLOAT, ((value, targetClass, javaDeserializer) -> {
if (checkTypeCompatible(doubleClasses, targetClass)) {
return value.asFloatValue().toDouble();
} else if (checkTypeCompatible(floatClasses, targetClass)) {
return value.asFloatValue().toFloat();
}
throw new IllegalArgumentException("Float can't be deserialized as " + targetClass + ".");
}));
// String unpacker.
unpackers.put(ValueType.STRING, ((value, targetClass, javaDeserializer) -> {
Preconditions.checkArgument(checkTypeCompatible(stringClasses, targetClass),
"String can't be deserialized as {}.", targetClass);
return value.asStringValue().asString();
}));
// Binary unpacker.
unpackers.put(ValueType.BINARY, ((value, targetClass, javaDeserializer) -> {
Preconditions.checkArgument(checkTypeCompatible(binaryClasses, targetClass),
"Binary can't be deserialized as {}.", targetClass);
return value.asBinaryValue().asByteArray();
}));
// Array unpacker.
unpackers.put(ValueType.ARRAY, ((value, targetClass, javaDeserializer) -> {
ArrayValue av = value.asArrayValue();
Class<?> componentType =
targetClass.isArray() ? targetClass.getComponentType() : Object.class;
Object array = Array.newInstance(componentType, av.size());
for (int i = 0; i < av.size(); ++i) {
Array.set(array, i, unpack(av.get(i), componentType, javaDeserializer));
}
return array;
}));
// Extension unpacker.
unpackers.put(ValueType.EXTENSION, ((value, targetClass, javaDeserializer) -> {
ExtensionValue ev = value.asExtensionValue();
byte extType = ev.getType();
if (extType == LANGUAGE_SPECIFIC_TYPE_EXTENSION_ID) {
return javaDeserializer.deserialize(ev);
}
throw new IllegalArgumentException("Unknown extension type id " + ev.getType() + ".");
}));
}
interface JavaSerializer {
void serialize(Object object, MessagePacker packer) throws IOException;
}
interface JavaDeserializer {
Object deserialize(ExtensionValue v);
}
interface TypePacker {
void pack(Object object, MessagePacker packer,
JavaSerializer javaSerializer) throws IOException;
}
interface TypeUnpacker {
Object unpack(Value value, Class<?> targetClass,
JavaDeserializer javaDeserializer);
}
private static boolean checkTypeCompatible(List<Class<?>> expected, Class<?> actual) {
for (Class<?> expectedClass : expected) {
if (actual.isAssignableFrom(expectedClass)) {
return true;
}
}
return false;
}
private static void pack(Object object, MessagePacker packer, JavaSerializer javaSerializer)
throws IOException {
TypePacker typePacker;
if (object == null) {
typePacker = NULL_PACKER;
} else {
Class<?> type = object.getClass();
typePacker = packers.get(type);
if (typePacker == null) {
if (type.isArray()) {
typePacker = ARRAY_PACKER;
} else {
typePacker = EXTENSION_PACKER;
}
}
}
typePacker.pack(object, packer, javaSerializer);
}
private static Object unpack(Value v, Class<?> type, JavaDeserializer javaDeserializer) {
return unpackers.get(v.getValueType()).unpack(v, type, javaDeserializer);
}
public static Pair<byte[], Boolean> encode(Object obj) {
MessageBufferPacker packer = MessagePack.newDefaultBufferPacker();
try {
// Reserve MESSAGE_PACK_OFFSET bytes for MessagePack bytes length.
packer.writePayload(new byte[MESSAGE_PACK_OFFSET]);
// Serialize input object by MessagePack.
MutableBoolean isCrossLanguage = new MutableBoolean(true);
pack(obj, packer, ((object, packer1) -> {
byte[] payload = FstSerializer.encode(object);
packer1.packExtensionTypeHeader(LANGUAGE_SPECIFIC_TYPE_EXTENSION_ID, payload.length);
packer1.addPayload(payload);
isCrossLanguage.setFalse();
}));
byte[] msgpackBytes = packer.toByteArray();
// Serialize MessagePack bytes length.
MessageBufferPacker headerPacker = MessagePack.newDefaultBufferPacker();
Preconditions.checkState(msgpackBytes.length >= MESSAGE_PACK_OFFSET);
headerPacker.packLong(msgpackBytes.length - MESSAGE_PACK_OFFSET);
byte[] msgpackBytesLength = headerPacker.toByteArray();
// Check serialized MessagePack bytes length is valid.
Preconditions.checkState(msgpackBytesLength.length <= MESSAGE_PACK_OFFSET);
// Write MessagePack bytes length to reserved buffer.
System.arraycopy(msgpackBytesLength, 0, msgpackBytes, 0, msgpackBytesLength.length);
return ImmutablePair.of(msgpackBytes, isCrossLanguage.getValue());
} catch (Exception e) {
throw new RuntimeException(e);
}
}
@SuppressWarnings("unchecked")
public static <T> T decode(byte[] bs, Class<?> type) {
try {
// Read MessagePack bytes length.
MessageUnpacker headerUnpacker = MessagePack.newDefaultUnpacker(bs, 0, MESSAGE_PACK_OFFSET);
long msgpackBytesLength = headerUnpacker.unpackLong();
headerUnpacker.close();
// Check MessagePack bytes length is valid.
Preconditions.checkState(MESSAGE_PACK_OFFSET + msgpackBytesLength <= bs.length);
// Deserialize MessagePack bytes from MESSAGE_PACK_OFFSET.
MessageUnpacker unpacker = MessagePack.newDefaultUnpacker(bs, MESSAGE_PACK_OFFSET,
(int) msgpackBytesLength);
Value v = unpacker.unpackValue();
if (type == null) {
type = Object.class;
}
return (T) unpack(v, type,
((ExtensionValue ev) -> FstSerializer.decode(ev.getData())));
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}
@@ -0,0 +1,15 @@
package org.ray.runtime.serializer;
import org.apache.commons.lang3.tuple.Pair;
public class Serializer {
public static Pair<byte[], Boolean> encode(Object obj) {
return MessagePackSerializer.encode(obj);
}
@SuppressWarnings("unchecked")
public static <T> T decode(byte[] bs, Class<?> type) {
return MessagePackSerializer.decode(bs, type);
}
}
@@ -1,8 +1,8 @@
package org.ray.runtime.task;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.ray.api.Ray;
import org.ray.api.RayObject;
import org.ray.api.id.ObjectId;
@@ -40,9 +40,18 @@ public class ArgumentsBuilder {
id = ((RayObject) arg).getId();
} else {
value = ObjectSerializer.serialize(arg);
if (language != Language.JAVA) {
boolean isCrossData =
Arrays.equals(value.metadata, ObjectSerializer.OBJECT_METADATA_TYPE_CROSS_LANGUAGE) ||
Arrays.equals(value.metadata, ObjectSerializer.OBJECT_METADATA_TYPE_RAW);
if (!isCrossData) {
throw new IllegalArgumentException(String.format("Can't transfer %s data to %s",
Arrays.toString(value.metadata), language.getValueDescriptor().getName()));
}
}
if (value.data.length > LARGEST_SIZE_PASS_BY_VALUE) {
id = ((RayRuntimeInternal) Ray.internal()).getObjectStore()
.putRaw(value);
.putRaw(value);
value = null;
}
}
@@ -61,10 +70,10 @@ public class ArgumentsBuilder {
/**
* Convert list of NativeRayObject to real function arguments.
*/
public static Object[] unwrap(List<NativeRayObject> args, ClassLoader classLoader) {
public static Object[] unwrap(List<NativeRayObject> args, Class<?>[] types) {
Object[] realArgs = new Object[args.size()];
for (int i = 0; i < args.size(); i++) {
realArgs[i] = ObjectSerializer.deserialize(args.get(i), null, classLoader);
realArgs[i] = ObjectSerializer.deserialize(args.get(i), null, types[i]);
}
return realArgs;
}
@@ -5,6 +5,7 @@ import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import org.ray.api.exception.RayTaskException;
import org.ray.api.id.ActorId;
import org.ray.api.id.JobId;
@@ -97,7 +98,8 @@ public abstract class TaskExecutor<T extends ActorContext> {
}
actor = actorContext.currentActor;
}
Object[] args = ArgumentsBuilder.unwrap(argsBytes, rayFunction.classLoader);
Object[] args = ArgumentsBuilder
.unwrap(argsBytes, rayFunction.executable.getParameterTypes());
// Execute the task.
Object result;
try {
@@ -1,60 +0,0 @@
package org.ray.runtime.util;
import org.nustaq.serialization.FSTConfiguration;
import org.ray.runtime.actor.NativeRayActor;
import org.ray.runtime.actor.NativeRayActorSerializer;
/**
* Java object serialization TODO: use others (e.g. Arrow) for higher performance
*/
public class Serializer {
private static final ThreadLocal<FSTConfiguration> conf = ThreadLocal.withInitial(() -> {
FSTConfiguration conf = FSTConfiguration.createDefaultConfiguration();
conf.registerSerializer(NativeRayActor.class, new NativeRayActorSerializer(), true);
return conf;
});
public static byte[] encode(Object obj) {
return conf.get().asByteArray(obj);
}
public static byte[] encode(Object obj, ClassLoader classLoader) {
byte[] result;
FSTConfiguration current = conf.get();
if (classLoader != null && classLoader != current.getClassLoader()) {
ClassLoader old = current.getClassLoader();
current.setClassLoader(classLoader);
result = current.asByteArray(obj);
current.setClassLoader(old);
} else {
result = current.asByteArray(obj);
}
return result;
}
@SuppressWarnings("unchecked")
public static <T> T decode(byte[] bs) {
return (T) conf.get().asObject(bs);
}
@SuppressWarnings("unchecked")
public static <T> T decode(byte[] bs, ClassLoader classLoader) {
Object object;
FSTConfiguration current = conf.get();
if (classLoader != null && classLoader != current.getClassLoader()) {
ClassLoader old = current.getClassLoader();
current.setClassLoader(classLoader);
object = current.asObject(bs);
current.setClassLoader(old);
} else {
object = current.asObject(bs);
}
return (T) object;
}
public static void setClassloader(ClassLoader classLoader) {
conf.get().setClassLoader(classLoader);
}
}
@@ -0,0 +1,52 @@
package org.ray.runtime.util;
import org.apache.commons.lang3.mutable.MutableBoolean;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.runtime.serializer.Serializer;
import org.testng.Assert;
import org.testng.annotations.Test;
import java.util.ArrayList;
public class SerializerTest {
@Test
public void testBasicSerialization() {
// Test serialize / deserialize primitive types with type conversion.
{
Object[] foo = new Object[]{"hello", (byte) 1, 2.0, (short) 3, 4, 5L,
new String[]{"hello", "world"}};
Pair<byte[], Boolean> serialized = Serializer.encode(foo);
Object[] bar = Serializer.decode(serialized.getLeft(), Object[].class);
Assert.assertTrue(serialized.getRight());
Assert.assertEquals(foo[0], bar[0]);
Assert.assertEquals(((Number) foo[1]).byteValue(), ((Number) bar[1]).byteValue());
Assert.assertEquals(foo[2], bar[2]);
Assert.assertEquals(((Number) foo[3]).intValue(), ((Number) bar[3]).intValue());
Assert.assertEquals(((Number) foo[4]).intValue(), ((Number) bar[4]).intValue());
Assert.assertEquals(((Number) foo[5]).intValue(), ((Number) bar[5]).intValue());
}
// Test multidimensional array.
{
Object[][] foo = new Object[][]{{1, 2}, {"3", 4}};
Assert.expectThrows(RuntimeException.class, () -> {
Object[][] bar = Serializer.decode(Serializer.encode(foo).getLeft(), Integer[][].class);
});
Pair<byte[], Boolean> serialized = Serializer.encode(foo);
Object[][] bar = Serializer.decode(serialized.getLeft(), Object[][].class);
Assert.assertTrue(serialized.getRight());
Assert.assertEquals(((Number) foo[0][1]).intValue(), ((Number) bar[0][1]).intValue());
Assert.assertEquals(foo[1][0], bar[1][0]);
}
// Test List.
{
ArrayList<String> foo = new ArrayList<>();
foo.add("1");
foo.add("2");
Pair<byte[], Boolean> serialized = Serializer.encode(foo);
ArrayList<String> bar = Serializer.decode(serialized.getLeft(), String[].class);
Assert.assertFalse(serialized.getRight());
Assert.assertEquals(foo.get(0), bar.get(0));
}
}
}
@@ -143,7 +143,7 @@ public class ActorTest extends BaseTest {
try {
// Try getting the object again, this should throw an UnreconstructableException.
// Use `Ray.get()` to bypass the cache in `RayObjectImpl`.
Ray.get(value.getId());
Ray.get(value.getId(), value.getType());
Assert.fail("This line should not be reachable.");
} catch (UnreconstructableException e) {
Assert.assertEquals(value.getId(), e.objectId);
@@ -4,6 +4,7 @@ import java.io.File;
import java.lang.reflect.Method;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Optional;
import javax.tools.JavaCompiler;
import javax.tools.ToolProvider;
import org.apache.commons.io.FileUtils;
@@ -101,12 +102,14 @@ public class ClassLoaderTest extends BaseTest {
"()V");
RayActor<?> actor1 = createActor(constructor);
FunctionDescriptor getPid = new JavaFunctionDescriptor("ClassLoaderTester", "getPid", "()I");
int pid = this.<Integer>callActorFunction(actor1, getPid, new Object[0], 1).get();
int pid = this.<Integer>callActorFunction(actor1, getPid, new Object[0],
Optional.of(Integer.class)).get();
RayActor<?> actor2;
while (true) {
// Create another actor which share the same process of actor 1.
actor2 = createActor(constructor);
int actor2Pid = this.<Integer>callActorFunction(actor2, getPid, new Object[0], 1).get();
int actor2Pid = this.<Integer>callActorFunction(actor2, getPid, new Object[0],
Optional.of(Integer.class)).get();
if (actor2Pid == pid) {
break;
}
@@ -116,15 +119,17 @@ public class ClassLoaderTest extends BaseTest {
"getClassLoaderHashCode",
"()I");
RayObject<Integer> hashCode1 = callActorFunction(actor1, getClassLoaderHashCode, new Object[0],
1);
Optional.of(Integer.class));
RayObject<Integer> hashCode2 = callActorFunction(actor2, getClassLoaderHashCode, new Object[0],
1);
Optional.of(Integer.class));
Assert.assertEquals(hashCode1.get(), hashCode2.get());
FunctionDescriptor increase = new JavaFunctionDescriptor("ClassLoaderTester", "increase",
"()I");
RayObject<Integer> value1 = callActorFunction(actor1, increase, new Object[0], 1);
RayObject<Integer> value2 = callActorFunction(actor2, increase, new Object[0], 1);
RayObject<Integer> value1 = callActorFunction(actor1, increase, new Object[0],
Optional.of(Integer.class));
RayObject<Integer> value2 = callActorFunction(actor2, increase, new Object[0],
Optional.of(Integer.class));
Assert.assertNotEquals(value1.get(), value2.get());
}
@@ -138,11 +143,12 @@ public class ClassLoaderTest extends BaseTest {
}
private <T> RayObject<T> callActorFunction(RayActor<?> rayActor,
FunctionDescriptor functionDescriptor, Object[] args, int numReturns) throws Exception {
FunctionDescriptor functionDescriptor, Object[] args, Optional<Class<?>> returnType)
throws Exception {
Method callActorFunctionMethod = AbstractRayRuntime.class.getDeclaredMethod("callActorFunction",
BaseActor.class, FunctionDescriptor.class, Object[].class, int.class);
BaseActor.class, FunctionDescriptor.class, Object[].class, Optional.class);
callActorFunctionMethod.setAccessible(true);
return (RayObject<T>) callActorFunctionMethod
.invoke(TestUtils.getUnderlyingRuntime(), rayActor, functionDescriptor, args, numReturns);
.invoke(TestUtils.getUnderlyingRuntime(), rayActor, functionDescriptor, args, returnType);
}
}
@@ -23,7 +23,7 @@ public class ClientExceptionTest extends BaseTest {
public void testWaitAndCrash() {
TestUtils.skipTestUnderSingleProcess();
ObjectId randomId = ObjectId.fromRandom();
RayObject<String> notExisting = new RayObjectImpl(randomId);
RayObject<String> notExisting = new RayObjectImpl(randomId, String.class);
Thread thread = new Thread(() -> {
try {
@@ -5,6 +5,9 @@ import com.google.common.collect.ImmutableMap;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.math.BigInteger;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.apache.commons.io.FileUtils;
import org.ray.api.Ray;
@@ -51,18 +54,85 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
@Test
public void testCallingPythonFunction() {
RayObject<byte[]> res = Ray.call(
new PyRemoteFunction<>(PYTHON_MODULE, "py_func", byte[].class),
"hello".getBytes());
Assert.assertEquals(res.get(), "Response from Python: hello".getBytes());
Object[] inputs = new Object[]{
true, // Boolean
Byte.MAX_VALUE, // Byte
Short.MAX_VALUE, // Short
Integer.MAX_VALUE, // Integer
Long.MAX_VALUE, // Long
// BigInteger can support max value of 2^64-1, please refer to:
// https://github.com/msgpack/msgpack/blob/master/spec.md#int-format-family
// If BigInteger larger than 2^64-1, the value can only be transferred among Java workers.
BigInteger.valueOf(Long.MAX_VALUE), // BigInteger
"Hello World!", // String
1.234f, // Float
1.234, // Double
"example binary".getBytes()}; // byte[]
for (Object o : inputs) {
RayObject res = Ray.call(
new PyRemoteFunction<>(PYTHON_MODULE, "py_return_input", o.getClass()),
o);
Assert.assertEquals(res.get(), o);
}
// null
{
Object input = null;
RayObject<Object> res = Ray.call(
new PyRemoteFunction<>(PYTHON_MODULE, "py_return_input", Object.class), input);
Object r = res.get();
Assert.assertEquals(r, input);
}
// array
{
int[] input = new int[]{1, 2};
RayObject<int[]> res = Ray.call(
new PyRemoteFunction<>(PYTHON_MODULE, "py_return_input", int[].class), input);
int[] r = res.get();
Assert.assertEquals(r, input);
}
// array of Object
{
Object[] input = new Object[]{1, 2.3f, 4.56, "789", "10".getBytes(), null, true,
new int[]{1, 2}};
RayObject<Object[]> res = Ray.call(
new PyRemoteFunction<>(PYTHON_MODULE, "py_return_input", Object[].class), input);
Object[] r = res.get();
// If we tell the value type is Object, then all numbers will be Number type.
Assert.assertEquals(((Number) r[0]).intValue(), input[0]);
Assert.assertEquals(((Number) r[1]).floatValue(), input[1]);
Assert.assertEquals(((Number) r[2]).doubleValue(), input[2]);
// String cast
Assert.assertEquals((String) r[3], input[3]);
// binary cast
Assert.assertEquals((byte[]) r[4], input[4]);
// null
Assert.assertEquals(r[5], input[5]);
// Boolean cast
Assert.assertEquals((Boolean) r[6], input[6]);
// array cast
Object[] r7array = (Object[]) r[7];
int[] input7array = (int[]) input[7];
Assert.assertEquals(((Number) r7array[0]).intValue(), input7array[0]);
Assert.assertEquals(((Number) r7array[1]).intValue(), input7array[1]);
}
// Unsupported types, all Java specific types, e.g. List / Map...
{
Assert.expectThrows(Exception.class, () -> {
List<Integer> input = Arrays.asList(1, 2);
RayObject<List<Integer>> res = Ray.call(
new PyRemoteFunction<>(PYTHON_MODULE, "py_return_input",
(Class<List<Integer>>) input.getClass()), input);
List<Integer> r = res.get();
Assert.assertEquals(r, input);
});
}
}
@Test
public void testPythonCallJavaFunction() {
RayObject<byte[]> res = Ray.call(
new PyRemoteFunction<>(PYTHON_MODULE, "py_func_call_java_function", byte[].class),
"hello".getBytes());
Assert.assertEquals(res.get(), "[Python]py_func -> [Java]bytesEcho -> hello".getBytes());
RayObject<String> res = Ray.call(
new PyRemoteFunction<>(PYTHON_MODULE, "py_func_call_java_function", String.class));
Assert.assertEquals(res.get(), "success");
}
@Test
@@ -117,11 +187,33 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
Assert.assertEquals(res.get(), "3".getBytes());
}
public static byte[] bytesEcho(byte[] value) {
public static Object[] pack(int i, String s, double f, Object[] o) {
// This function will be called from test_cross_language_invocation.py
String valueStr = new String(value);
LOGGER.debug(String.format("bytesEcho called with: %s", valueStr));
return ("[Java]bytesEcho -> " + valueStr).getBytes();
return new Object[]{i, s, f, o};
}
public static Object returnInput(Object o) {
return o;
}
public static boolean returnInputBoolean(boolean b) {
return b;
}
public static int returnInputInt(int i) {
return i;
}
public static double returnInputDouble(double d) {
return d;
}
public static String returnInputString(String s) {
return s;
}
public static int[] returnInputIntArray(int[] l) {
return l;
}
public static byte[] callPythonActorHandle(byte[] value) {
@@ -135,6 +227,7 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
}
public static class TestActor {
public TestActor(byte[] v) {
value = v;
}
@@ -45,7 +45,7 @@ public class DynamicResourceTest extends BaseTest {
// Assert ray call result.
result = Ray.wait(ImmutableList.of(obj), 1, 1000);
Assert.assertEquals(result.getReady().size(), 1);
Assert.assertEquals(Ray.get(obj.getId()), "hi");
Assert.assertEquals(obj.get(), "hi");
}
@@ -148,7 +148,7 @@ public class FailureTest extends BaseTest {
RayObject<Integer> obj2 = Ray.call(FailureTest::slowFunc);
Instant start = Instant.now();
try {
Ray.get(Arrays.asList(obj1.getId(), obj2.getId()));
Ray.get(Arrays.asList(obj1, obj2));
Assert.fail("Should throw RayException.");
} catch (RayException e) {
Instant end = Instant.now();
@@ -104,7 +104,7 @@ public class MultiThreadingTest extends BaseTest {
runTestCaseInMultipleThreads(() -> {
int arg = random.nextInt();
RayObject<Integer> obj = Ray.put(arg);
Assert.assertEquals(arg, (int) Ray.get(obj.getId()));
Assert.assertEquals(arg, (int) obj.get());
}, LOOP_COUNTER);
TestUtils.warmUpCluster();
@@ -141,7 +141,7 @@ public class MultiThreadingTest extends BaseTest {
final RayActor<Echo> fooActor = Ray.createActor(Echo::new);
final Runnable[] runnables = new Runnable[]{
() -> Ray.put(1),
() -> Ray.get(fooObject.getId()),
() -> Ray.get(fooObject.getId(), fooObject.getType()),
fooObject::get,
() -> Ray.wait(ImmutableList.of(fooObject)),
Ray::getRuntimeContext,
@@ -16,8 +16,22 @@ public class ObjectStoreTest extends BaseTest {
@Test
public void testPutAndGet() {
RayObject<Integer> obj = Ray.put(1);
Assert.assertEquals(1, (int) obj.get());
{
RayObject<Integer> obj = Ray.put(1);
Assert.assertEquals(1, (int) obj.get());
}
{
String s = null;
RayObject<String> obj = Ray.put(s);
Assert.assertNull(obj.get());
}
{
List<List<String>> l = ImmutableList.of(ImmutableList.of("abc"));
RayObject<List<List<String>>> obj = Ray.put(l);
Assert.assertEquals(obj.get(), l);
}
}
@Test
@@ -25,6 +39,6 @@ public class ObjectStoreTest extends BaseTest {
List<Integer> ints = ImmutableList.of(1, 2, 3, 4, 5);
List<ObjectId> ids = ints.stream().map(obj -> Ray.put(obj).getId())
.collect(Collectors.toList());
Assert.assertEquals(ints, Ray.get(ids));
Assert.assertEquals(ints, Ray.get(ids, Integer.class));
}
}
@@ -15,9 +15,9 @@ public class PlasmaStoreTest extends BaseTest {
ObjectId objectId = ObjectId.fromRandom();
ObjectStore objectStore = TestUtils.getRuntime().getObjectStore();
objectStore.put("1", objectId);
Assert.assertEquals(Ray.get(objectId), "1");
Assert.assertEquals(Ray.get(objectId, String.class), "1");
objectStore.put("2", objectId);
// Putting the second object with duplicate ID should fail but ignored.
Assert.assertEquals(Ray.get(objectId), "1");
Assert.assertEquals(Ray.get(objectId, String.class), "1");
}
}
@@ -87,7 +87,7 @@ public class RayCallTest extends BaseTest {
ObjectId randomObjectId = ObjectId.fromRandom();
Ray.call(RayCallTest::testNoReturn, randomObjectId);
Assert.assertEquals(((int) Ray.get(randomObjectId)), 1);
Assert.assertEquals(((int) Ray.get(randomObjectId, Integer.class)), 1);
}
private static int testNoParam() {
@@ -2,9 +2,7 @@ package org.ray.api.test;
import org.ray.api.Ray;
import org.ray.api.RayPyActor;
import org.ray.api.TestUtils;
import org.ray.api.function.PyActorClass;
import org.ray.runtime.context.WorkerContext;
import org.ray.runtime.object.NativeRayObject;
import org.ray.runtime.object.ObjectSerializer;
import org.testng.Assert;
@@ -15,10 +13,9 @@ public class RaySerializerTest extends BaseMultiLanguageTest {
@Test
public void testSerializePyActor() {
RayPyActor pyActor = Ray.createActor(new PyActorClass("test", "RaySerializerTest"));
WorkerContext workerContext = TestUtils.getRuntime().getWorkerContext();
NativeRayObject nativeRayObject = ObjectSerializer.serialize(pyActor);
RayPyActor result = (RayPyActor) ObjectSerializer
.deserialize(nativeRayObject, null, workerContext.getCurrentClassLoader());
.deserialize(nativeRayObject, null, Object.class);
Assert.assertEquals(result.getId(), pyActor.getId());
Assert.assertEquals(result.getModuleName(), "test");
Assert.assertEquals(result.getClassName(), "RaySerializerTest");
@@ -28,7 +28,7 @@ public class StressTest extends BaseTest {
resultIds.add(Ray.call(StressTest::echo, 1).getId());
}
for (Integer result : Ray.<Integer>get(resultIds)) {
for (Integer result : Ray.<Integer>get(resultIds, Integer.class)) {
Assert.assertEquals(result, Integer.valueOf(1));
}
}
@@ -67,7 +67,7 @@ public class StressTest extends BaseTest {
objectIds.add(actor.call(Actor::ping).getId());
}
int sum = 0;
for (Integer result : Ray.<Integer>get(objectIds)) {
for (Integer result : Ray.<Integer>get(objectIds, Integer.class)) {
sum += result;
}
return sum;
@@ -84,7 +84,7 @@ public class StressTest extends BaseTest {
objectIds.add(worker.call(Worker::ping, 100).getId());
}
for (Integer result : Ray.<Integer>get(objectIds)) {
for (Integer result : Ray.<Integer>get(objectIds, Integer.class)) {
Assert.assertEquals(result, Integer.valueOf(100));
}
}
@@ -5,18 +5,47 @@ import ray
@ray.remote
def py_func(value):
assert isinstance(value, bytes)
return b"Response from Python: " + value
def py_return_input(v):
return v
@ray.remote
def py_func_call_java_function(value):
assert isinstance(value, bytes)
f = ray.java_function("org.ray.api.test.CrossLanguageInvocationTest",
"bytesEcho")
r = f.remote(value)
return b"[Python]py_func -> " + ray.get(r)
def py_func_call_java_function():
try:
# None
r = ray.java_function("org.ray.api.test.CrossLanguageInvocationTest",
"returnInput").remote(None)
assert ray.get(r) is None
# bool
r = ray.java_function("org.ray.api.test.CrossLanguageInvocationTest",
"returnInputBoolean").remote(True)
assert ray.get(r) is True
# int
r = ray.java_function("org.ray.api.test.CrossLanguageInvocationTest",
"returnInputInt").remote(100)
assert ray.get(r) == 100
# double
r = ray.java_function("org.ray.api.test.CrossLanguageInvocationTest",
"returnInputDouble").remote(1.23)
assert ray.get(r) == 1.23
# string
r = ray.java_function("org.ray.api.test.CrossLanguageInvocationTest",
"returnInputString").remote("Hello World!")
assert ray.get(r) == "Hello World!"
# list (tuple will be packed by pickle,
# so only list can be transferred across language)
r = ray.java_function("org.ray.api.test.CrossLanguageInvocationTest",
"returnInputIntArray").remote([1, 2, 3])
assert ray.get(r) == [1, 2, 3]
# pack
f = ray.java_function("org.ray.api.test.CrossLanguageInvocationTest",
"pack")
input = [100, "hello", 1.23, [1, "2", 3.0]]
r = f.remote(*input)
assert ray.get(r) == input
return "success"
except Exception as ex:
return str(ex)
@ray.remote