diff --git a/java/BUILD.bazel b/java/BUILD.bazel index 50888d48c..7cebc02d4 100644 --- a/java/BUILD.bazel +++ b/java/BUILD.bazel @@ -76,6 +76,7 @@ define_java_module( "@maven//:de_ruedigermoeller_fst", "@maven//:net_java_dev_jna_jna", "@maven//:org_apache_commons_commons_lang3", + "@maven//:org_msgpack_msgpack_core", "@maven//:org_ow2_asm_asm", "@maven//:org_slf4j_slf4j_api", "@maven//:org_slf4j_slf4j_log4j12", diff --git a/java/api/src/main/java/org/ray/api/Ray.java b/java/api/src/main/java/org/ray/api/Ray.java index cf673b13a..0f758f7b0 100644 --- a/java/api/src/main/java/org/ray/api/Ray.java +++ b/java/api/src/main/java/org/ray/api/Ray.java @@ -1,5 +1,6 @@ package org.ray.api; +import java.util.ArrayList; import java.util.List; import java.util.concurrent.Callable; import org.ray.api.id.ObjectId; @@ -62,23 +63,41 @@ public final class Ray extends RayCall { } /** - * Get an object from the object store. + * Get an object by id from the object store. * * @param objectId The ID of the object to get. + * @param objectType The type of the object to get. * @return The Java object. */ - public static T get(ObjectId objectId) { - return runtime.get(objectId); + public static T get(ObjectId objectId, Class objectType) { + return runtime.get(objectId, objectType); } /** - * Get a list of objects from the object store. + * Get a list of objects by ids from the object store. * * @param objectIds The list of object IDs. + * @param objectType The type of object. * @return A list of Java objects. */ - public static List get(List objectIds) { - return runtime.get(objectIds); + public static List get(List objectIds, Class objectType) { + return runtime.get(objectIds, objectType); + } + + /** + * Get a list of objects by RayObjects from the object store. + * + * @param objectList A list of RayObject to get. + * @return A list of Java objects. 
+ */ + public static List get(List> objectList) { + List objectIds = new ArrayList<>(); + Class objectType = null; + for (RayObject o : objectList) { + objectIds.add(o.getId()); + objectType = o.getType(); + } + return runtime.get(objectIds, objectType); } /** diff --git a/java/api/src/main/java/org/ray/api/RayObject.java b/java/api/src/main/java/org/ray/api/RayObject.java index faf42f826..e5d67e063 100644 --- a/java/api/src/main/java/org/ray/api/RayObject.java +++ b/java/api/src/main/java/org/ray/api/RayObject.java @@ -19,5 +19,10 @@ public interface RayObject { */ ObjectId getId(); + /** + * Get the Object type. + */ + Class getType(); + } diff --git a/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java b/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java index 9f8552e74..3c94c7d2c 100644 --- a/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java +++ b/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java @@ -39,17 +39,19 @@ public interface RayRuntime { * Get an object from the object store. * * @param objectId The ID of the object to get. + * @param objectType The type of the object to get. * @return The Java object. */ - T get(ObjectId objectId); + T get(ObjectId objectId, Class objectType); /** * Get a list of objects from the object store. * * @param objectIds The list of object IDs. + * @param objectType The type of object. * @return A list of Java objects. 
*/ - List get(List objectIds); + List get(List objectIds, Class objectType); /** * Wait for a list of RayObjects to be locally available, until specified number of objects are diff --git a/java/dependencies.bzl b/java/dependencies.bzl index 791b24b12..d1af092ef 100644 --- a/java/dependencies.bzl +++ b/java/dependencies.bzl @@ -15,11 +15,12 @@ def gen_java_deps(): "de.ruedigermoeller:fst:2.57", "javax.xml.bind:jaxb-api:2.3.0", "org.apache.commons:commons-lang3:3.4", + "org.msgpack:msgpack-core:0.8.20", "org.ow2.asm:asm:6.0", "org.slf4j:slf4j-log4j12:1.7.25", "org.testng:testng:6.9.10", "redis.clients:jedis:2.8.0", - "net.java.dev.jna:jna:5.5.0" + "net.java.dev.jna:jna:5.5.0", ], repositories = [ "https://repo1.maven.org/maven2/", diff --git a/java/runtime/pom.xml b/java/runtime/pom.xml index b4aa54508..513f304d1 100644 --- a/java/runtime/pom.xml +++ b/java/runtime/pom.xml @@ -62,6 +62,11 @@ commons-lang3 3.4 + + org.msgpack + msgpack-core + 0.8.20 + org.ow2.asm asm diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index 44c8bd6dd..b51f85cb3 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -4,6 +4,7 @@ import com.google.common.base.Preconditions; import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; import java.util.List; +import java.util.Optional; import java.util.concurrent.Callable; import org.ray.api.BaseActor; import org.ray.api.RayActor; @@ -15,7 +16,6 @@ import org.ray.api.function.PyActorClass; import org.ray.api.function.PyActorMethod; import org.ray.api.function.PyRemoteFunction; import org.ray.api.function.RayFunc; -import org.ray.api.function.RayFuncVoid; import org.ray.api.id.ObjectId; import org.ray.api.options.ActorCreationOptions; import org.ray.api.options.CallOptions; @@ -26,6 +26,7 @@ import 
org.ray.runtime.context.WorkerContext; import org.ray.runtime.functionmanager.FunctionDescriptor; import org.ray.runtime.functionmanager.FunctionManager; import org.ray.runtime.functionmanager.PyFunctionDescriptor; +import org.ray.runtime.functionmanager.RayFunction; import org.ray.runtime.gcs.GcsClient; import org.ray.runtime.generated.Common.Language; import org.ray.runtime.generated.Common.WorkerType; @@ -73,18 +74,18 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal { @Override public RayObject put(T obj) { ObjectId objectId = objectStore.put(obj); - return new RayObjectImpl<>(objectId); + return new RayObjectImpl(objectId, (Class)(obj == null ? Object.class : obj.getClass())); } @Override - public T get(ObjectId objectId) throws RayException { - List ret = get(ImmutableList.of(objectId)); + public T get(ObjectId objectId, Class objectType) throws RayException { + List ret = get(ImmutableList.of(objectId), objectType); return ret.get(0); } @Override - public List get(List objectIds) { - return objectStore.get(objectIds); + public List get(List objectIds, Class objectType) { + return objectStore.get(objectIds, objectType); } @Override @@ -99,41 +100,39 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal { @Override public RayObject call(RayFunc func, Object[] args, CallOptions options) { - FunctionDescriptor functionDescriptor = - functionManager.getFunction(workerContext.getCurrentJobId(), func) - .functionDescriptor; - int numReturns = func instanceof RayFuncVoid ? 
0 : 1; - return callNormalFunction(functionDescriptor, args, numReturns, options); + RayFunction rayFunction = functionManager.getFunction(workerContext.getCurrentJobId(), func); + FunctionDescriptor functionDescriptor = rayFunction.functionDescriptor; + Optional> returnType = rayFunction.getReturnType(); + return callNormalFunction(functionDescriptor, args, returnType, options); } @Override public RayObject call(PyRemoteFunction pyRemoteFunction, Object[] args, - CallOptions options) { - checkPyArguments(args); + CallOptions options) { PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor( pyRemoteFunction.moduleName, "", pyRemoteFunction.functionName); // Python functions always have a return value, even if it's `None`. - return callNormalFunction(functionDescriptor, args, /*numReturns=*/1, options); + return callNormalFunction(functionDescriptor, args, + /*returnType=*/Optional.of(pyRemoteFunction.returnType), options); } @Override public RayObject callActor(RayActor actor, RayFunc func, Object[] args) { - FunctionDescriptor functionDescriptor = - functionManager.getFunction(workerContext.getCurrentJobId(), func) - .functionDescriptor; - int numReturns = func instanceof RayFuncVoid ? 0 : 1; - return callActorFunction(actor, functionDescriptor, args, numReturns); + RayFunction rayFunction = functionManager.getFunction(workerContext.getCurrentJobId(), func); + FunctionDescriptor functionDescriptor = rayFunction.functionDescriptor; + Optional> returnType = rayFunction.getReturnType(); + return callActorFunction(actor, functionDescriptor, args, returnType); } @Override public RayObject callActor(RayPyActor pyActor, PyActorMethod pyActorMethod, Object... args) { - checkPyArguments(args); PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(pyActor.getModuleName(), pyActor.getClassName(), pyActorMethod.methodName); // Python functions always have a return value, even if it's `None`. 
- return callActorFunction(pyActor, functionDescriptor, args, /*numReturns=*/1); + return callActorFunction(pyActor, functionDescriptor, args, + /*returnType=*/Optional.of(pyActorMethod.returnType)); } @Override @@ -148,8 +147,7 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal { @Override public RayPyActor createActor(PyActorClass pyActorClass, Object[] args, - ActorCreationOptions options) { - checkPyArguments(args); + ActorCreationOptions options) { PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor( pyActorClass.moduleName, pyActorClass.className, @@ -157,14 +155,6 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal { return (RayPyActor) createActorImpl(functionDescriptor, args, options); } - private void checkPyArguments(Object[] args) { - for (Object arg : args) { - Preconditions.checkArgument( - (arg instanceof RayPyActor) || (arg instanceof byte[]), - "Python argument can only be a RayPyActor or a byte array, not {}.", - arg.getClass().getName()); - } - } @Override public void setAsyncContext(Object asyncContext) { @@ -218,30 +208,32 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal { } private RayObject callNormalFunction(FunctionDescriptor functionDescriptor, - Object[] args, int numReturns, CallOptions options) { + Object[] args, Optional> returnType, CallOptions options) { + int numReturns = returnType.isPresent() ? 
1 : 0; List functionArgs = ArgumentsBuilder .wrap(args, functionDescriptor.getLanguage()); List returnIds = taskSubmitter.submitTask(functionDescriptor, functionArgs, numReturns, options); - Preconditions.checkState(returnIds.size() == numReturns && returnIds.size() <= 1); + Preconditions.checkState(returnIds.size() == numReturns); if (returnIds.isEmpty()) { return null; } else { - return new RayObjectImpl(returnIds.get(0)); + return new RayObjectImpl(returnIds.get(0), returnType.get()); } } private RayObject callActorFunction(BaseActor rayActor, - FunctionDescriptor functionDescriptor, Object[] args, int numReturns) { + FunctionDescriptor functionDescriptor, Object[] args, Optional> returnType) { + int numReturns = returnType.isPresent() ? 1 : 0; List functionArgs = ArgumentsBuilder .wrap(args, functionDescriptor.getLanguage()); List returnIds = taskSubmitter.submitActorTask(rayActor, functionDescriptor, functionArgs, numReturns, null); - Preconditions.checkState(returnIds.size() == numReturns && returnIds.size() <= 1); + Preconditions.checkState(returnIds.size() == numReturns); if (returnIds.isEmpty()) { return null; } else { - return new RayObjectImpl(returnIds.get(0)); + return new RayObjectImpl(returnIds.get(0), returnType.get()); } } diff --git a/java/runtime/src/main/java/org/ray/runtime/context/WorkerContext.java b/java/runtime/src/main/java/org/ray/runtime/context/WorkerContext.java index 4a526c85e..253030502 100644 --- a/java/runtime/src/main/java/org/ray/runtime/context/WorkerContext.java +++ b/java/runtime/src/main/java/org/ray/runtime/context/WorkerContext.java @@ -5,6 +5,7 @@ import org.ray.api.id.JobId; import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; import org.ray.runtime.generated.Common.TaskType; +import org.ray.runtime.serializer.Serializer; /** * The context of worker. @@ -28,7 +29,7 @@ public interface WorkerContext { /** * The class loader that is associated with the current job. 
It's used for locating classes when - * dealing with serialization and deserialization in {@link org.ray.runtime.util.Serializer}. + * dealing with serialization and deserialization in {@link Serializer}. */ ClassLoader getCurrentClassLoader(); diff --git a/java/runtime/src/main/java/org/ray/runtime/functionmanager/RayFunction.java b/java/runtime/src/main/java/org/ray/runtime/functionmanager/RayFunction.java index 9ebe6cf0a..01741e5e8 100644 --- a/java/runtime/src/main/java/org/ray/runtime/functionmanager/RayFunction.java +++ b/java/runtime/src/main/java/org/ray/runtime/functionmanager/RayFunction.java @@ -3,6 +3,7 @@ package org.ray.runtime.functionmanager; import java.lang.reflect.Constructor; import java.lang.reflect.Executable; import java.lang.reflect.Method; +import java.util.Optional; /** * Represents a Ray function (either a Method or a Constructor in Java) and its metadata. @@ -67,6 +68,17 @@ public class RayFunction { } } + /** + * @return Return type. + */ + public Optional> getReturnType() { + if (hasReturn()) { + return Optional.of(((Method) executable).getReturnType()); + } else { + return Optional.empty(); + } + } + @Override public String toString() { return executable.toString(); diff --git a/java/runtime/src/main/java/org/ray/runtime/object/ObjectSerializer.java b/java/runtime/src/main/java/org/ray/runtime/object/ObjectSerializer.java index fe89ba428..528bf1099 100644 --- a/java/runtime/src/main/java/org/ray/runtime/object/ObjectSerializer.java +++ b/java/runtime/src/main/java/org/ray/runtime/object/ObjectSerializer.java @@ -1,13 +1,15 @@ package org.ray.runtime.object; import java.util.Arrays; + +import org.apache.commons.lang3.tuple.Pair; import org.ray.api.exception.RayActorException; import org.ray.api.exception.RayTaskException; import org.ray.api.exception.RayWorkerException; import org.ray.api.exception.UnreconstructableException; import org.ray.api.id.ObjectId; import org.ray.runtime.generated.Gcs.ErrorType; -import 
org.ray.runtime.util.Serializer; +import org.ray.runtime.serializer.Serializer; /** * Serialize to and deserialize from {@link NativeRayObject}. Metadata is generated during @@ -21,29 +23,33 @@ public class ObjectSerializer { .valueOf(ErrorType.ACTOR_DIED.getNumber()).getBytes(); private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String .valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE.getNumber()).getBytes(); - private static final byte[] TASK_EXECUTION_EXCEPTION_META = String .valueOf(ErrorType.TASK_EXECUTION_EXCEPTION.getNumber()).getBytes(); - private static final byte[] RAW_TYPE_META = "RAW".getBytes(); + public static final byte[] OBJECT_METADATA_TYPE_CROSS_LANGUAGE = "XLANG".getBytes(); + public static final byte[] OBJECT_METADATA_TYPE_JAVA = "JAVA".getBytes(); + public static final byte[] OBJECT_METADATA_TYPE_PYTHON = "PYTHON".getBytes(); + public static final byte[] OBJECT_METADATA_TYPE_RAW = "RAW".getBytes(); /** * Deserialize an object from an {@link NativeRayObject} instance. * * @param nativeRayObject The object to deserialize. * @param objectId The associated object ID of the object. - * @param classLoader The classLoader of the object. * @return The deserialized object. */ public static Object deserialize(NativeRayObject nativeRayObject, ObjectId objectId, - ClassLoader classLoader) { + Class objectType) { byte[] meta = nativeRayObject.metadata; byte[] data = nativeRayObject.data; if (meta != null && meta.length > 0) { // If meta is not null, deserialize the object from meta. 
- if (Arrays.equals(meta, RAW_TYPE_META)) { + if (Arrays.equals(meta, OBJECT_METADATA_TYPE_RAW)) { return data; + } else if (Arrays.equals(meta, OBJECT_METADATA_TYPE_CROSS_LANGUAGE) || + Arrays.equals(meta, OBJECT_METADATA_TYPE_JAVA)) { + return Serializer.decode(data, objectType); } else if (Arrays.equals(meta, WORKER_EXCEPTION_META)) { return new RayWorkerException(); } else if (Arrays.equals(meta, ACTOR_EXCEPTION_META)) { @@ -51,12 +57,15 @@ public class ObjectSerializer { } else if (Arrays.equals(meta, UNRECONSTRUCTABLE_EXCEPTION_META)) { return new UnreconstructableException(objectId); } else if (Arrays.equals(meta, TASK_EXECUTION_EXCEPTION_META)) { - return Serializer.decode(data, classLoader); + return Serializer.decode(data, objectType); + } else if (Arrays.equals(meta, OBJECT_METADATA_TYPE_PYTHON)) { + throw new IllegalArgumentException("Can't deserialize Python object: " + objectId + .toString()); } throw new IllegalArgumentException("Unrecognized metadata " + Arrays.toString(meta)); } else { // If data is not null, deserialize the Java object. - return Serializer.decode(data, classLoader); + return Serializer.decode(data, objectType); } } @@ -72,12 +81,14 @@ public class ObjectSerializer { } else if (object instanceof byte[]) { // If the object is a byte array, skip serializing it and use a special metadata to // indicate it's raw binary. So that this object can also be read by Python. 
- return new NativeRayObject((byte[]) object, RAW_TYPE_META); + return new NativeRayObject((byte[]) object, OBJECT_METADATA_TYPE_RAW); } else if (object instanceof RayTaskException) { - return new NativeRayObject(Serializer.encode(object), - TASK_EXECUTION_EXCEPTION_META); + byte[] serializedBytes = Serializer.encode(object).getLeft(); + return new NativeRayObject(serializedBytes, TASK_EXECUTION_EXCEPTION_META); } else { - return new NativeRayObject(Serializer.encode(object), null); + Pair serialized = Serializer.encode(object); + return new NativeRayObject(serialized.getLeft(), serialized.getRight() ? + OBJECT_METADATA_TYPE_CROSS_LANGUAGE : OBJECT_METADATA_TYPE_JAVA); } } } diff --git a/java/runtime/src/main/java/org/ray/runtime/object/ObjectStore.java b/java/runtime/src/main/java/org/ray/runtime/object/ObjectStore.java index 223c49b27..d44cac969 100644 --- a/java/runtime/src/main/java/org/ray/runtime/object/ObjectStore.java +++ b/java/runtime/src/main/java/org/ray/runtime/object/ObjectStore.java @@ -86,7 +86,7 @@ public abstract class ObjectStore { * @return A list of GetResult objects. */ @SuppressWarnings("unchecked") - public List get(List ids) { + public List get(List ids, Class elementType) { // Pass -1 as timeout to wait until all objects are available in object store. 
List dataAndMetaList = getRaw(ids, -1); @@ -96,7 +96,7 @@ public abstract class ObjectStore { Object object = null; if (dataAndMeta != null) { object = ObjectSerializer - .deserialize(dataAndMeta, ids.get(i), workerContext.getCurrentClassLoader()); + .deserialize(dataAndMeta, ids.get(i), elementType); } if (object instanceof RayException) { // If the object is a `RayException`, it means that an error occurred during task diff --git a/java/runtime/src/main/java/org/ray/runtime/object/RayObjectImpl.java b/java/runtime/src/main/java/org/ray/runtime/object/RayObjectImpl.java index c34407691..b9d08ece9 100644 --- a/java/runtime/src/main/java/org/ray/runtime/object/RayObjectImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/object/RayObjectImpl.java @@ -1,6 +1,7 @@ package org.ray.runtime.object; import java.io.Serializable; + import org.ray.api.Ray; import org.ray.api.RayObject; import org.ray.api.id.ObjectId; @@ -20,13 +21,16 @@ public final class RayObjectImpl implements RayObject, Serializable { */ private transient T object; + private Class type; + /** * Whether the object is already gotten from the object store. 
*/ private transient boolean objectGotten; - public RayObjectImpl(ObjectId id) { + public RayObjectImpl(ObjectId id, Class type) { this.id = id; + this.type = type; object = null; objectGotten = false; } @@ -34,7 +38,7 @@ public final class RayObjectImpl implements RayObject, Serializable { @Override public synchronized T get() { if (!objectGotten) { - object = Ray.get(id); + object = Ray.get(id, type); objectGotten = true; } return object; @@ -45,4 +49,9 @@ public final class RayObjectImpl implements RayObject, Serializable { return id; } + @Override + public Class getType() { + return type; + } + } diff --git a/java/runtime/src/main/java/org/ray/runtime/serializer/FstSerializer.java b/java/runtime/src/main/java/org/ray/runtime/serializer/FstSerializer.java new file mode 100644 index 000000000..3a93c1f95 --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/serializer/FstSerializer.java @@ -0,0 +1,32 @@ +package org.ray.runtime.serializer; + +import org.nustaq.serialization.FSTConfiguration; +import org.ray.runtime.actor.NativeRayActor; +import org.ray.runtime.actor.NativeRayActorSerializer; + +/** + * Java object serialization TODO: use others (e.g. 
Arrow) for higher performance + */ +public class FstSerializer { + + private static final ThreadLocal conf = ThreadLocal.withInitial(() -> { + FSTConfiguration conf = FSTConfiguration.createDefaultConfiguration(); + conf.registerSerializer(NativeRayActor.class, new NativeRayActorSerializer(), true); + return conf; + }); + + + public static byte[] encode(Object obj) { + FSTConfiguration current = conf.get(); + current.setClassLoader(Thread.currentThread().getContextClassLoader()); + return current.asByteArray(obj); + } + + + @SuppressWarnings("unchecked") + public static T decode(byte[] bs) { + FSTConfiguration current = conf.get(); + current.setClassLoader(Thread.currentThread().getContextClassLoader()); + return (T) current.asObject(bs); + } +} diff --git a/java/runtime/src/main/java/org/ray/runtime/serializer/MessagePackSerializer.java b/java/runtime/src/main/java/org/ray/runtime/serializer/MessagePackSerializer.java new file mode 100644 index 000000000..a0856a5a4 --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/serializer/MessagePackSerializer.java @@ -0,0 +1,270 @@ +package org.ray.runtime.serializer; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import java.io.IOException; +import java.lang.reflect.Array; +import java.math.BigInteger; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.commons.lang3.mutable.MutableBoolean; +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; +import org.msgpack.core.MessageBufferPacker; +import org.msgpack.core.MessagePack; +import org.msgpack.core.MessagePacker; +import org.msgpack.core.MessageUnpacker; +import org.msgpack.value.ArrayValue; +import org.msgpack.value.ExtensionValue; +import org.msgpack.value.IntegerValue; +import org.msgpack.value.Value; +import org.msgpack.value.ValueType; + +// We can't pack List / Map by MessagePack, because we don't know the type class 
when unpacking. +public class MessagePackSerializer { + + private static final byte LANGUAGE_SPECIFIC_TYPE_EXTENSION_ID = 101; + // MessagePack length is an int that takes up to 9 bytes. + // https://github.com/msgpack/msgpack/blob/master/spec.md#int-format-family + private static final int MESSAGE_PACK_OFFSET = 9; + + // Packers indexed by their corresponding Java class object. + private static Map, TypePacker> packers = new HashMap<>(); + // Unpackers indexed by their corresponding MessagePack ValueType. + private static Map unpackers = new HashMap<>(); + // Null and array don't have a corresponding class, so define them separately. + private static final TypePacker NULL_PACKER; + private static final TypePacker ARRAY_PACKER; + private static final TypePacker EXTENSION_PACKER; + + static { + // ===== Initialize packers ===== + // Null packer. + NULL_PACKER = (object, packer, javaSerializer) -> packer.packNil(); + + // Array packer. + ARRAY_PACKER = ((object, packer, javaSerializer) -> { + int length = Array.getLength(object); + packer.packArrayHeader(length); + for (int i = 0; i < length; ++i) { + pack(Array.get(object, i), packer, javaSerializer); + } + }); + + // Extension packer. 
+ EXTENSION_PACKER = ((object, packer, javaSerializer) -> { + javaSerializer.serialize(object, packer); + }); + + packers.put(Boolean.class, + ((object, packer, javaSerializer) -> packer.packBoolean((Boolean) object))); + packers.put(Byte.class, + ((object, packer, javaSerializer) -> packer.packByte((Byte) object))); + packers.put(Short.class, + ((object, packer, javaSerializer) -> packer.packShort((Short) object))); + packers.put(Integer.class, + ((object, packer, javaSerializer) -> packer.packInt((Integer) object))); + packers.put(Long.class, + ((object, packer, javaSerializer) -> packer.packLong((Long) object))); + packers.put(BigInteger.class, + ((object, packer, javaSerializer) -> packer.packBigInteger((BigInteger) object))); + packers.put(Float.class, + ((object, packer, javaSerializer) -> packer.packFloat((Float) object))); + packers.put(Double.class, + ((object, packer, javaSerializer) -> packer.packDouble((Double) object))); + packers.put(String.class, + ((object, packer, javaSerializer) -> packer.packString((String) object))); + packers.put(byte[].class, + ((object, packer, javaSerializer) -> { + byte[] bytes = (byte[]) object; + packer.packBinaryHeader(bytes.length); + packer.writePayload(bytes); + })); + + // ===== Initialize unpackers ===== + List> booleanClasses = ImmutableList.of(Boolean.class, boolean.class); + List> byteClasses = ImmutableList.of(Byte.class, byte.class); + List> shortClasses = ImmutableList.of(Short.class, short.class); + List> intClasses = ImmutableList.of(Integer.class, int.class); + List> longClasses = ImmutableList.of(Long.class, long.class); + List> bigIntClasses = ImmutableList.of(BigInteger.class); + List> floatClasses = ImmutableList.of(Float.class, float.class); + List> doubleClasses = ImmutableList.of(Double.class, double.class); + List> stringClasses = ImmutableList.of(String.class); + List> binaryClasses = ImmutableList.of(byte[].class); + + // Null unpacker. 
+ unpackers.put(ValueType.NIL, (value, targetClass, javaDeserializer) -> null); + // Boolean unpacker. + unpackers.put(ValueType.BOOLEAN, (value, targetClass, javaDeserializer) -> { + Preconditions.checkArgument(checkTypeCompatible(booleanClasses, targetClass), + "Boolean can't be deserialized as {}.", targetClass); + return value.asBooleanValue().getBoolean(); + }); + // Integer unpacker. + unpackers.put(ValueType.INTEGER, ((value, targetClass, javaDeserializer) -> { + IntegerValue iv = value.asIntegerValue(); + if (iv.isInByteRange() && checkTypeCompatible(byteClasses, targetClass)) { + return iv.asByte(); + } else if (iv.isInShortRange() && checkTypeCompatible(shortClasses, targetClass)) { + return iv.asShort(); + } else if (iv.isInIntRange() && checkTypeCompatible(intClasses, targetClass)) { + return iv.asInt(); + } else if (iv.isInLongRange() && checkTypeCompatible(longClasses, targetClass)) { + return iv.asLong(); + } else if (checkTypeCompatible(bigIntClasses, targetClass)) { + return iv.asBigInteger(); + } + throw new IllegalArgumentException("Integer can't be deserialized as " + targetClass + "."); + })); + // Float unpacker. + unpackers.put(ValueType.FLOAT, ((value, targetClass, javaDeserializer) -> { + if (checkTypeCompatible(doubleClasses, targetClass)) { + return value.asFloatValue().toDouble(); + } else if (checkTypeCompatible(floatClasses, targetClass)) { + return value.asFloatValue().toFloat(); + } + throw new IllegalArgumentException("Float can't be deserialized as " + targetClass + "."); + })); + // String unpacker. + unpackers.put(ValueType.STRING, ((value, targetClass, javaDeserializer) -> { + Preconditions.checkArgument(checkTypeCompatible(stringClasses, targetClass), + "String can't be deserialized as {}.", targetClass); + return value.asStringValue().asString(); + })); + // Binary unpacker. 
+ unpackers.put(ValueType.BINARY, ((value, targetClass, javaDeserializer) -> { + Preconditions.checkArgument(checkTypeCompatible(binaryClasses, targetClass), + "Binary can't be deserialized as {}.", targetClass); + return value.asBinaryValue().asByteArray(); + })); + // Array unpacker. + unpackers.put(ValueType.ARRAY, ((value, targetClass, javaDeserializer) -> { + ArrayValue av = value.asArrayValue(); + Class componentType = + targetClass.isArray() ? targetClass.getComponentType() : Object.class; + Object array = Array.newInstance(componentType, av.size()); + for (int i = 0; i < av.size(); ++i) { + Array.set(array, i, unpack(av.get(i), componentType, javaDeserializer)); + } + return array; + })); + // Extension unpacker. + unpackers.put(ValueType.EXTENSION, ((value, targetClass, javaDeserializer) -> { + ExtensionValue ev = value.asExtensionValue(); + byte extType = ev.getType(); + if (extType == LANGUAGE_SPECIFIC_TYPE_EXTENSION_ID) { + return javaDeserializer.deserialize(ev); + } + throw new IllegalArgumentException("Unknown extension type id " + ev.getType() + "."); + })); + } + + interface JavaSerializer { + + void serialize(Object object, MessagePacker packer) throws IOException; + } + + interface JavaDeserializer { + + Object deserialize(ExtensionValue v); + } + + interface TypePacker { + + void pack(Object object, MessagePacker packer, + JavaSerializer javaSerializer) throws IOException; + } + + interface TypeUnpacker { + + Object unpack(Value value, Class targetClass, + JavaDeserializer javaDeserializer); + } + + private static boolean checkTypeCompatible(List> expected, Class actual) { + for (Class expectedClass : expected) { + if (actual.isAssignableFrom(expectedClass)) { + return true; + } + } + return false; + } + + private static void pack(Object object, MessagePacker packer, JavaSerializer javaSerializer) + throws IOException { + TypePacker typePacker; + if (object == null) { + typePacker = NULL_PACKER; + } else { + Class type = object.getClass(); + 
typePacker = packers.get(type); + if (typePacker == null) { + if (type.isArray()) { + typePacker = ARRAY_PACKER; + } else { + typePacker = EXTENSION_PACKER; + } + } + } + typePacker.pack(object, packer, javaSerializer); + } + + private static Object unpack(Value v, Class type, JavaDeserializer javaDeserializer) { + return unpackers.get(v.getValueType()).unpack(v, type, javaDeserializer); + } + + public static Pair encode(Object obj) { + MessageBufferPacker packer = MessagePack.newDefaultBufferPacker(); + try { + // Reserve MESSAGE_PACK_OFFSET bytes for MessagePack bytes length. + packer.writePayload(new byte[MESSAGE_PACK_OFFSET]); + // Serialize input object by MessagePack. + MutableBoolean isCrossLanguage = new MutableBoolean(true); + pack(obj, packer, ((object, packer1) -> { + byte[] payload = FstSerializer.encode(object); + packer1.packExtensionTypeHeader(LANGUAGE_SPECIFIC_TYPE_EXTENSION_ID, payload.length); + packer1.addPayload(payload); + isCrossLanguage.setFalse(); + })); + byte[] msgpackBytes = packer.toByteArray(); + // Serialize MessagePack bytes length. + MessageBufferPacker headerPacker = MessagePack.newDefaultBufferPacker(); + Preconditions.checkState(msgpackBytes.length >= MESSAGE_PACK_OFFSET); + headerPacker.packLong(msgpackBytes.length - MESSAGE_PACK_OFFSET); + byte[] msgpackBytesLength = headerPacker.toByteArray(); + // Check serialized MessagePack bytes length is valid. + Preconditions.checkState(msgpackBytesLength.length <= MESSAGE_PACK_OFFSET); + // Write MessagePack bytes length to reserved buffer. + System.arraycopy(msgpackBytesLength, 0, msgpackBytes, 0, msgpackBytesLength.length); + return ImmutablePair.of(msgpackBytes, isCrossLanguage.getValue()); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + + @SuppressWarnings("unchecked") + public static T decode(byte[] bs, Class type) { + try { + // Read MessagePack bytes length. 
+ MessageUnpacker headerUnpacker = MessagePack.newDefaultUnpacker(bs, 0, MESSAGE_PACK_OFFSET); + long msgpackBytesLength = headerUnpacker.unpackLong(); + headerUnpacker.close(); + // Check MessagePack bytes length is valid. + Preconditions.checkState(MESSAGE_PACK_OFFSET + msgpackBytesLength <= bs.length); + // Deserialize MessagePack bytes from MESSAGE_PACK_OFFSET. + MessageUnpacker unpacker = MessagePack.newDefaultUnpacker(bs, MESSAGE_PACK_OFFSET, + (int) msgpackBytesLength); + Value v = unpacker.unpackValue(); + if (type == null) { + type = Object.class; + } + return (T) unpack(v, type, + ((ExtensionValue ev) -> FstSerializer.decode(ev.getData()))); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} diff --git a/java/runtime/src/main/java/org/ray/runtime/serializer/Serializer.java b/java/runtime/src/main/java/org/ray/runtime/serializer/Serializer.java new file mode 100644 index 000000000..e42d9d4d1 --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/serializer/Serializer.java @@ -0,0 +1,15 @@ +package org.ray.runtime.serializer; + +import org.apache.commons.lang3.tuple.Pair; + +public class Serializer { + + public static Pair encode(Object obj) { + return MessagePackSerializer.encode(obj); + } + + @SuppressWarnings("unchecked") + public static T decode(byte[] bs, Class type) { + return MessagePackSerializer.decode(bs, type); + } +} diff --git a/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java b/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java index 3dae7f5f2..4760fb1c2 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java @@ -1,8 +1,8 @@ package org.ray.runtime.task; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; - import org.ray.api.Ray; import org.ray.api.RayObject; import org.ray.api.id.ObjectId; @@ -40,9 +40,18 @@ public class ArgumentsBuilder { id = 
((RayObject) arg).getId(); } else { value = ObjectSerializer.serialize(arg); + if (language != Language.JAVA) { + boolean isCrossData = + Arrays.equals(value.metadata, ObjectSerializer.OBJECT_METADATA_TYPE_CROSS_LANGUAGE) || + Arrays.equals(value.metadata, ObjectSerializer.OBJECT_METADATA_TYPE_RAW); + if (!isCrossData) { + throw new IllegalArgumentException(String.format("Can't transfer %s data to %s", + Arrays.toString(value.metadata), language.getValueDescriptor().getName())); + } + } if (value.data.length > LARGEST_SIZE_PASS_BY_VALUE) { id = ((RayRuntimeInternal) Ray.internal()).getObjectStore() - .putRaw(value); + .putRaw(value); value = null; } } @@ -61,10 +70,10 @@ public class ArgumentsBuilder { /** * Convert list of NativeRayObject to real function arguments. */ - public static Object[] unwrap(List args, ClassLoader classLoader) { + public static Object[] unwrap(List args, Class[] types) { Object[] realArgs = new Object[args.size()]; for (int i = 0; i < args.size(); i++) { - realArgs[i] = ObjectSerializer.deserialize(args.get(i), null, classLoader); + realArgs[i] = ObjectSerializer.deserialize(args.get(i), null, types[i]); } return realArgs; } diff --git a/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java b/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java index e044bf677..f89417869 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java @@ -5,6 +5,7 @@ import java.lang.reflect.InvocationTargetException; import java.util.ArrayList; import java.util.List; import java.util.concurrent.ConcurrentHashMap; + import org.ray.api.exception.RayTaskException; import org.ray.api.id.ActorId; import org.ray.api.id.JobId; @@ -97,7 +98,8 @@ public abstract class TaskExecutor { } actor = actorContext.currentActor; } - Object[] args = ArgumentsBuilder.unwrap(argsBytes, rayFunction.classLoader); + Object[] args = ArgumentsBuilder + .unwrap(argsBytes, 
rayFunction.executable.getParameterTypes()); // Execute the task. Object result; try { diff --git a/java/runtime/src/main/java/org/ray/runtime/util/Serializer.java b/java/runtime/src/main/java/org/ray/runtime/util/Serializer.java deleted file mode 100644 index e29140411..000000000 --- a/java/runtime/src/main/java/org/ray/runtime/util/Serializer.java +++ /dev/null @@ -1,60 +0,0 @@ -package org.ray.runtime.util; - -import org.nustaq.serialization.FSTConfiguration; -import org.ray.runtime.actor.NativeRayActor; -import org.ray.runtime.actor.NativeRayActorSerializer; - -/** - * Java object serialization TODO: use others (e.g. Arrow) for higher performance - */ -public class Serializer { - - private static final ThreadLocal conf = ThreadLocal.withInitial(() -> { - FSTConfiguration conf = FSTConfiguration.createDefaultConfiguration(); - conf.registerSerializer(NativeRayActor.class, new NativeRayActorSerializer(), true); - return conf; - }); - - public static byte[] encode(Object obj) { - return conf.get().asByteArray(obj); - } - - public static byte[] encode(Object obj, ClassLoader classLoader) { - byte[] result; - FSTConfiguration current = conf.get(); - if (classLoader != null && classLoader != current.getClassLoader()) { - ClassLoader old = current.getClassLoader(); - current.setClassLoader(classLoader); - result = current.asByteArray(obj); - current.setClassLoader(old); - } else { - result = current.asByteArray(obj); - } - - return result; - } - - @SuppressWarnings("unchecked") - public static T decode(byte[] bs) { - return (T) conf.get().asObject(bs); - } - - @SuppressWarnings("unchecked") - public static T decode(byte[] bs, ClassLoader classLoader) { - Object object; - FSTConfiguration current = conf.get(); - if (classLoader != null && classLoader != current.getClassLoader()) { - ClassLoader old = current.getClassLoader(); - current.setClassLoader(classLoader); - object = current.asObject(bs); - current.setClassLoader(old); - } else { - object = 
current.asObject(bs); - } - return (T) object; - } - - public static void setClassloader(ClassLoader classLoader) { - conf.get().setClassLoader(classLoader); - } -} diff --git a/java/runtime/src/test/java/org/ray/runtime/util/SerializerTest.java b/java/runtime/src/test/java/org/ray/runtime/util/SerializerTest.java new file mode 100644 index 000000000..5621e424e --- /dev/null +++ b/java/runtime/src/test/java/org/ray/runtime/util/SerializerTest.java @@ -0,0 +1,52 @@ +package org.ray.runtime.util; + +import org.apache.commons.lang3.mutable.MutableBoolean; +import org.apache.commons.lang3.tuple.Pair; +import org.ray.runtime.serializer.Serializer; +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.util.ArrayList; + +public class SerializerTest { + + @Test + public void testBasicSerialization() { + // Test serialize / deserialize primitive types with type conversion. + { + Object[] foo = new Object[]{"hello", (byte) 1, 2.0, (short) 3, 4, 5L, + new String[]{"hello", "world"}}; + Pair serialized = Serializer.encode(foo); + Object[] bar = Serializer.decode(serialized.getLeft(), Object[].class); + Assert.assertTrue(serialized.getRight()); + Assert.assertEquals(foo[0], bar[0]); + Assert.assertEquals(((Number) foo[1]).byteValue(), ((Number) bar[1]).byteValue()); + Assert.assertEquals(foo[2], bar[2]); + Assert.assertEquals(((Number) foo[3]).intValue(), ((Number) bar[3]).intValue()); + Assert.assertEquals(((Number) foo[4]).intValue(), ((Number) bar[4]).intValue()); + Assert.assertEquals(((Number) foo[5]).intValue(), ((Number) bar[5]).intValue()); + } + // Test multidimensional array. 
+ { + Object[][] foo = new Object[][]{{1, 2}, {"3", 4}}; + Assert.expectThrows(RuntimeException.class, () -> { + Object[][] bar = Serializer.decode(Serializer.encode(foo).getLeft(), Integer[][].class); + }); + Pair serialized = Serializer.encode(foo); + Object[][] bar = Serializer.decode(serialized.getLeft(), Object[][].class); + Assert.assertTrue(serialized.getRight()); + Assert.assertEquals(((Number) foo[0][1]).intValue(), ((Number) bar[0][1]).intValue()); + Assert.assertEquals(foo[1][0], bar[1][0]); + } + // Test List. + { + ArrayList foo = new ArrayList<>(); + foo.add("1"); + foo.add("2"); + Pair serialized = Serializer.encode(foo); + ArrayList bar = Serializer.decode(serialized.getLeft(), String[].class); + Assert.assertFalse(serialized.getRight()); + Assert.assertEquals(foo.get(0), bar.get(0)); + } + } +} diff --git a/java/test/src/main/java/org/ray/api/test/ActorTest.java b/java/test/src/main/java/org/ray/api/test/ActorTest.java index baaffdf01..f7be4daa9 100644 --- a/java/test/src/main/java/org/ray/api/test/ActorTest.java +++ b/java/test/src/main/java/org/ray/api/test/ActorTest.java @@ -143,7 +143,7 @@ public class ActorTest extends BaseTest { try { // Try getting the object again, this should throw an UnreconstructableException. // Use `Ray.get()` to bypass the cache in `RayObjectImpl`. 
- Ray.get(value.getId()); + Ray.get(value.getId(), value.getType()); Assert.fail("This line should not be reachable."); } catch (UnreconstructableException e) { Assert.assertEquals(value.getId(), e.objectId); diff --git a/java/test/src/main/java/org/ray/api/test/ClassLoaderTest.java b/java/test/src/main/java/org/ray/api/test/ClassLoaderTest.java index 73bd9965d..86060814d 100644 --- a/java/test/src/main/java/org/ray/api/test/ClassLoaderTest.java +++ b/java/test/src/main/java/org/ray/api/test/ClassLoaderTest.java @@ -4,6 +4,7 @@ import java.io.File; import java.lang.reflect.Method; import java.nio.file.Files; import java.nio.file.Paths; +import java.util.Optional; import javax.tools.JavaCompiler; import javax.tools.ToolProvider; import org.apache.commons.io.FileUtils; @@ -101,12 +102,14 @@ public class ClassLoaderTest extends BaseTest { "()V"); RayActor actor1 = createActor(constructor); FunctionDescriptor getPid = new JavaFunctionDescriptor("ClassLoaderTester", "getPid", "()I"); - int pid = this.callActorFunction(actor1, getPid, new Object[0], 1).get(); + int pid = this.callActorFunction(actor1, getPid, new Object[0], + Optional.of(Integer.class)).get(); RayActor actor2; while (true) { // Create another actor which share the same process of actor 1. 
actor2 = createActor(constructor); - int actor2Pid = this.callActorFunction(actor2, getPid, new Object[0], 1).get(); + int actor2Pid = this.callActorFunction(actor2, getPid, new Object[0], + Optional.of(Integer.class)).get(); if (actor2Pid == pid) { break; } @@ -116,15 +119,17 @@ public class ClassLoaderTest extends BaseTest { "getClassLoaderHashCode", "()I"); RayObject hashCode1 = callActorFunction(actor1, getClassLoaderHashCode, new Object[0], - 1); + Optional.of(Integer.class)); RayObject hashCode2 = callActorFunction(actor2, getClassLoaderHashCode, new Object[0], - 1); + Optional.of(Integer.class)); Assert.assertEquals(hashCode1.get(), hashCode2.get()); FunctionDescriptor increase = new JavaFunctionDescriptor("ClassLoaderTester", "increase", "()I"); - RayObject value1 = callActorFunction(actor1, increase, new Object[0], 1); - RayObject value2 = callActorFunction(actor2, increase, new Object[0], 1); + RayObject value1 = callActorFunction(actor1, increase, new Object[0], + Optional.of(Integer.class)); + RayObject value2 = callActorFunction(actor2, increase, new Object[0], + Optional.of(Integer.class)); Assert.assertNotEquals(value1.get(), value2.get()); } @@ -138,11 +143,12 @@ public class ClassLoaderTest extends BaseTest { } private RayObject callActorFunction(RayActor rayActor, - FunctionDescriptor functionDescriptor, Object[] args, int numReturns) throws Exception { + FunctionDescriptor functionDescriptor, Object[] args, Optional> returnType) + throws Exception { Method callActorFunctionMethod = AbstractRayRuntime.class.getDeclaredMethod("callActorFunction", - BaseActor.class, FunctionDescriptor.class, Object[].class, int.class); + BaseActor.class, FunctionDescriptor.class, Object[].class, Optional.class); callActorFunctionMethod.setAccessible(true); return (RayObject) callActorFunctionMethod - .invoke(TestUtils.getUnderlyingRuntime(), rayActor, functionDescriptor, args, numReturns); + .invoke(TestUtils.getUnderlyingRuntime(), rayActor, functionDescriptor, 
args, returnType); } } diff --git a/java/test/src/main/java/org/ray/api/test/ClientExceptionTest.java b/java/test/src/main/java/org/ray/api/test/ClientExceptionTest.java index 4058b7e52..6ce153ed5 100644 --- a/java/test/src/main/java/org/ray/api/test/ClientExceptionTest.java +++ b/java/test/src/main/java/org/ray/api/test/ClientExceptionTest.java @@ -23,7 +23,7 @@ public class ClientExceptionTest extends BaseTest { public void testWaitAndCrash() { TestUtils.skipTestUnderSingleProcess(); ObjectId randomId = ObjectId.fromRandom(); - RayObject notExisting = new RayObjectImpl(randomId); + RayObject notExisting = new RayObjectImpl(randomId, String.class); Thread thread = new Thread(() -> { try { diff --git a/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java b/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java index 8715e581c..d200f0572 100644 --- a/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java +++ b/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java @@ -5,6 +5,9 @@ import com.google.common.collect.ImmutableMap; import java.io.File; import java.io.IOException; import java.io.InputStream; +import java.math.BigInteger; +import java.util.Arrays; +import java.util.List; import java.util.Map; import org.apache.commons.io.FileUtils; import org.ray.api.Ray; @@ -51,18 +54,85 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest { @Test public void testCallingPythonFunction() { - RayObject res = Ray.call( - new PyRemoteFunction<>(PYTHON_MODULE, "py_func", byte[].class), - "hello".getBytes()); - Assert.assertEquals(res.get(), "Response from Python: hello".getBytes()); + Object[] inputs = new Object[]{ + true, // Boolean + Byte.MAX_VALUE, // Byte + Short.MAX_VALUE, // Short + Integer.MAX_VALUE, // Integer + Long.MAX_VALUE, // Long + // BigInteger can support max value of 2^64-1, please refer to: + // 
https://github.com/msgpack/msgpack/blob/master/spec.md#int-format-family + // If BigInteger larger than 2^64-1, the value can only be transferred among Java workers. + BigInteger.valueOf(Long.MAX_VALUE), // BigInteger + "Hello World!", // String + 1.234f, // Float + 1.234, // Double + "example binary".getBytes()}; // byte[] + for (Object o : inputs) { + RayObject res = Ray.call( + new PyRemoteFunction<>(PYTHON_MODULE, "py_return_input", o.getClass()), + o); + Assert.assertEquals(res.get(), o); + } + // null + { + Object input = null; + RayObject res = Ray.call( + new PyRemoteFunction<>(PYTHON_MODULE, "py_return_input", Object.class), input); + Object r = res.get(); + Assert.assertEquals(r, input); + } + // array + { + int[] input = new int[]{1, 2}; + RayObject res = Ray.call( + new PyRemoteFunction<>(PYTHON_MODULE, "py_return_input", int[].class), input); + int[] r = res.get(); + Assert.assertEquals(r, input); + } + // array of Object + { + Object[] input = new Object[]{1, 2.3f, 4.56, "789", "10".getBytes(), null, true, + new int[]{1, 2}}; + RayObject res = Ray.call( + new PyRemoteFunction<>(PYTHON_MODULE, "py_return_input", Object[].class), input); + Object[] r = res.get(); + // If we tell the value type is Object, then all numbers will be Number type. 
+ Assert.assertEquals(((Number) r[0]).intValue(), input[0]); + Assert.assertEquals(((Number) r[1]).floatValue(), input[1]); + Assert.assertEquals(((Number) r[2]).doubleValue(), input[2]); + // String cast + Assert.assertEquals((String) r[3], input[3]); + // binary cast + Assert.assertEquals((byte[]) r[4], input[4]); + // null + Assert.assertEquals(r[5], input[5]); + // Boolean cast + Assert.assertEquals((Boolean) r[6], input[6]); + // array cast + Object[] r7array = (Object[]) r[7]; + int[] input7array = (int[]) input[7]; + Assert.assertEquals(((Number) r7array[0]).intValue(), input7array[0]); + Assert.assertEquals(((Number) r7array[1]).intValue(), input7array[1]); + } + // Unsupported types, all Java specific types, e.g. List / Map... + { + Assert.expectThrows(Exception.class, () -> { + List input = Arrays.asList(1, 2); + RayObject> res = Ray.call( + new PyRemoteFunction<>(PYTHON_MODULE, "py_return_input", + (Class>) input.getClass()), input); + List r = res.get(); + Assert.assertEquals(r, input); + }); + } } @Test public void testPythonCallJavaFunction() { - RayObject res = Ray.call( - new PyRemoteFunction<>(PYTHON_MODULE, "py_func_call_java_function", byte[].class), - "hello".getBytes()); - Assert.assertEquals(res.get(), "[Python]py_func -> [Java]bytesEcho -> hello".getBytes()); + RayObject res = Ray.call( + new PyRemoteFunction<>(PYTHON_MODULE, "py_func_call_java_function", String.class)); + Assert.assertEquals(res.get(), "success"); } @Test @@ -117,11 +187,33 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest { Assert.assertEquals(res.get(), "3".getBytes()); } - public static byte[] bytesEcho(byte[] value) { + public static Object[] pack(int i, String s, double f, Object[] o) { // This function will be called from test_cross_language_invocation.py - String valueStr = new String(value); - LOGGER.debug(String.format("bytesEcho called with: %s", valueStr)); - return ("[Java]bytesEcho -> " + valueStr).getBytes(); + return new Object[]{i, s, 
f, o}; + } + + public static Object returnInput(Object o) { + return o; + } + + public static boolean returnInputBoolean(boolean b) { + return b; + } + + public static int returnInputInt(int i) { + return i; + } + + public static double returnInputDouble(double d) { + return d; + } + + public static String returnInputString(String s) { + return s; + } + + public static int[] returnInputIntArray(int[] l) { + return l; } public static byte[] callPythonActorHandle(byte[] value) { @@ -135,6 +227,7 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest { } public static class TestActor { + public TestActor(byte[] v) { value = v; } diff --git a/java/test/src/main/java/org/ray/api/test/DynamicResourceTest.java b/java/test/src/main/java/org/ray/api/test/DynamicResourceTest.java index 4388dcf4e..269501fab 100644 --- a/java/test/src/main/java/org/ray/api/test/DynamicResourceTest.java +++ b/java/test/src/main/java/org/ray/api/test/DynamicResourceTest.java @@ -45,7 +45,7 @@ public class DynamicResourceTest extends BaseTest { // Assert ray call result. 
result = Ray.wait(ImmutableList.of(obj), 1, 1000); Assert.assertEquals(result.getReady().size(), 1); - Assert.assertEquals(Ray.get(obj.getId()), "hi"); + Assert.assertEquals(obj.get(), "hi"); } diff --git a/java/test/src/main/java/org/ray/api/test/FailureTest.java b/java/test/src/main/java/org/ray/api/test/FailureTest.java index 5ec9c2e7c..a7a202eb0 100644 --- a/java/test/src/main/java/org/ray/api/test/FailureTest.java +++ b/java/test/src/main/java/org/ray/api/test/FailureTest.java @@ -148,7 +148,7 @@ public class FailureTest extends BaseTest { RayObject obj2 = Ray.call(FailureTest::slowFunc); Instant start = Instant.now(); try { - Ray.get(Arrays.asList(obj1.getId(), obj2.getId())); + Ray.get(Arrays.asList(obj1, obj2)); Assert.fail("Should throw RayException."); } catch (RayException e) { Instant end = Instant.now(); diff --git a/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java b/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java index f1ed2a65a..505d79c85 100644 --- a/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java +++ b/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java @@ -104,7 +104,7 @@ public class MultiThreadingTest extends BaseTest { runTestCaseInMultipleThreads(() -> { int arg = random.nextInt(); RayObject obj = Ray.put(arg); - Assert.assertEquals(arg, (int) Ray.get(obj.getId())); + Assert.assertEquals(arg, (int) obj.get()); }, LOOP_COUNTER); TestUtils.warmUpCluster(); @@ -141,7 +141,7 @@ public class MultiThreadingTest extends BaseTest { final RayActor fooActor = Ray.createActor(Echo::new); final Runnable[] runnables = new Runnable[]{ () -> Ray.put(1), - () -> Ray.get(fooObject.getId()), + () -> Ray.get(fooObject.getId(), fooObject.getType()), fooObject::get, () -> Ray.wait(ImmutableList.of(fooObject)), Ray::getRuntimeContext, diff --git a/java/test/src/main/java/org/ray/api/test/ObjectStoreTest.java b/java/test/src/main/java/org/ray/api/test/ObjectStoreTest.java index be584ba6d..883ce8b7d 
100644 --- a/java/test/src/main/java/org/ray/api/test/ObjectStoreTest.java +++ b/java/test/src/main/java/org/ray/api/test/ObjectStoreTest.java @@ -16,8 +16,22 @@ public class ObjectStoreTest extends BaseTest { @Test public void testPutAndGet() { - RayObject obj = Ray.put(1); - Assert.assertEquals(1, (int) obj.get()); + { + RayObject obj = Ray.put(1); + Assert.assertEquals(1, (int) obj.get()); + } + + { + String s = null; + RayObject obj = Ray.put(s); + Assert.assertNull(obj.get()); + } + + { + List> l = ImmutableList.of(ImmutableList.of("abc")); + RayObject>> obj = Ray.put(l); + Assert.assertEquals(obj.get(), l); + } } @Test @@ -25,6 +39,6 @@ public class ObjectStoreTest extends BaseTest { List ints = ImmutableList.of(1, 2, 3, 4, 5); List ids = ints.stream().map(obj -> Ray.put(obj).getId()) .collect(Collectors.toList()); - Assert.assertEquals(ints, Ray.get(ids)); + Assert.assertEquals(ints, Ray.get(ids, Integer.class)); } } diff --git a/java/test/src/main/java/org/ray/api/test/PlasmaStoreTest.java b/java/test/src/main/java/org/ray/api/test/PlasmaStoreTest.java index e0655e052..99a2f4b71 100644 --- a/java/test/src/main/java/org/ray/api/test/PlasmaStoreTest.java +++ b/java/test/src/main/java/org/ray/api/test/PlasmaStoreTest.java @@ -15,9 +15,9 @@ public class PlasmaStoreTest extends BaseTest { ObjectId objectId = ObjectId.fromRandom(); ObjectStore objectStore = TestUtils.getRuntime().getObjectStore(); objectStore.put("1", objectId); - Assert.assertEquals(Ray.get(objectId), "1"); + Assert.assertEquals(Ray.get(objectId, String.class), "1"); objectStore.put("2", objectId); // Putting the second object with duplicate ID should fail but ignored. 
- Assert.assertEquals(Ray.get(objectId), "1"); + Assert.assertEquals(Ray.get(objectId, String.class), "1"); } } diff --git a/java/test/src/main/java/org/ray/api/test/RayCallTest.java b/java/test/src/main/java/org/ray/api/test/RayCallTest.java index fb1e94ef2..cbe918021 100644 --- a/java/test/src/main/java/org/ray/api/test/RayCallTest.java +++ b/java/test/src/main/java/org/ray/api/test/RayCallTest.java @@ -87,7 +87,7 @@ public class RayCallTest extends BaseTest { ObjectId randomObjectId = ObjectId.fromRandom(); Ray.call(RayCallTest::testNoReturn, randomObjectId); - Assert.assertEquals(((int) Ray.get(randomObjectId)), 1); + Assert.assertEquals(((int) Ray.get(randomObjectId, Integer.class)), 1); } private static int testNoParam() { diff --git a/java/test/src/main/java/org/ray/api/test/RaySerializerTest.java b/java/test/src/main/java/org/ray/api/test/RaySerializerTest.java index 398d1277c..0d79c1f6c 100644 --- a/java/test/src/main/java/org/ray/api/test/RaySerializerTest.java +++ b/java/test/src/main/java/org/ray/api/test/RaySerializerTest.java @@ -2,9 +2,7 @@ package org.ray.api.test; import org.ray.api.Ray; import org.ray.api.RayPyActor; -import org.ray.api.TestUtils; import org.ray.api.function.PyActorClass; -import org.ray.runtime.context.WorkerContext; import org.ray.runtime.object.NativeRayObject; import org.ray.runtime.object.ObjectSerializer; import org.testng.Assert; @@ -15,10 +13,9 @@ public class RaySerializerTest extends BaseMultiLanguageTest { @Test public void testSerializePyActor() { RayPyActor pyActor = Ray.createActor(new PyActorClass("test", "RaySerializerTest")); - WorkerContext workerContext = TestUtils.getRuntime().getWorkerContext(); NativeRayObject nativeRayObject = ObjectSerializer.serialize(pyActor); RayPyActor result = (RayPyActor) ObjectSerializer - .deserialize(nativeRayObject, null, workerContext.getCurrentClassLoader()); + .deserialize(nativeRayObject, null, Object.class); Assert.assertEquals(result.getId(), pyActor.getId()); 
Assert.assertEquals(result.getModuleName(), "test"); Assert.assertEquals(result.getClassName(), "RaySerializerTest"); diff --git a/java/test/src/main/java/org/ray/api/test/StressTest.java b/java/test/src/main/java/org/ray/api/test/StressTest.java index 2a2a7f06d..c6d338963 100644 --- a/java/test/src/main/java/org/ray/api/test/StressTest.java +++ b/java/test/src/main/java/org/ray/api/test/StressTest.java @@ -28,7 +28,7 @@ public class StressTest extends BaseTest { resultIds.add(Ray.call(StressTest::echo, 1).getId()); } - for (Integer result : Ray.get(resultIds)) { + for (Integer result : Ray.get(resultIds, Integer.class)) { Assert.assertEquals(result, Integer.valueOf(1)); } } @@ -67,7 +67,7 @@ public class StressTest extends BaseTest { objectIds.add(actor.call(Actor::ping).getId()); } int sum = 0; - for (Integer result : Ray.get(objectIds)) { + for (Integer result : Ray.get(objectIds, Integer.class)) { sum += result; } return sum; @@ -84,7 +84,7 @@ public class StressTest extends BaseTest { objectIds.add(worker.call(Worker::ping, 100).getId()); } - for (Integer result : Ray.get(objectIds)) { + for (Integer result : Ray.get(objectIds, Integer.class)) { Assert.assertEquals(result, Integer.valueOf(100)); } } diff --git a/java/test/src/main/resources/test_cross_language_invocation.py b/java/test/src/main/resources/test_cross_language_invocation.py index 8b06aa661..753d87480 100644 --- a/java/test/src/main/resources/test_cross_language_invocation.py +++ b/java/test/src/main/resources/test_cross_language_invocation.py @@ -5,18 +5,47 @@ import ray @ray.remote -def py_func(value): - assert isinstance(value, bytes) - return b"Response from Python: " + value +def py_return_input(v): + return v @ray.remote -def py_func_call_java_function(value): - assert isinstance(value, bytes) - f = ray.java_function("org.ray.api.test.CrossLanguageInvocationTest", - "bytesEcho") - r = f.remote(value) - return b"[Python]py_func -> " + ray.get(r) +def py_func_call_java_function(): + try: + # 
None + r = ray.java_function("org.ray.api.test.CrossLanguageInvocationTest", + "returnInput").remote(None) + assert ray.get(r) is None + # bool + r = ray.java_function("org.ray.api.test.CrossLanguageInvocationTest", + "returnInputBoolean").remote(True) + assert ray.get(r) is True + # int + r = ray.java_function("org.ray.api.test.CrossLanguageInvocationTest", + "returnInputInt").remote(100) + assert ray.get(r) == 100 + # double + r = ray.java_function("org.ray.api.test.CrossLanguageInvocationTest", + "returnInputDouble").remote(1.23) + assert ray.get(r) == 1.23 + # string + r = ray.java_function("org.ray.api.test.CrossLanguageInvocationTest", + "returnInputString").remote("Hello World!") + assert ray.get(r) == "Hello World!" + # list (tuple will be packed by pickle, + # so only list can be transferred across language) + r = ray.java_function("org.ray.api.test.CrossLanguageInvocationTest", + "returnInputIntArray").remote([1, 2, 3]) + assert ray.get(r) == [1, 2, 3] + # pack + f = ray.java_function("org.ray.api.test.CrossLanguageInvocationTest", + "pack") + input = [100, "hello", 1.23, [1, "2", 3.0]] + r = f.remote(*input) + assert ray.get(r) == input + return "success" + except Exception as ex: + return str(ex) @ray.remote diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 76cf05c4e..45b7a42c2 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -92,6 +92,8 @@ from ray.exceptions import ( RayTimeoutError, ) from ray.utils import decode +import gc +import msgpack cimport cpython @@ -106,8 +108,6 @@ include "includes/libcoreworker.pxi" logger = logging.getLogger(__name__) -MEMCOPY_THREADS = 6 - def set_internal_config(dict options): cdef: @@ -257,8 +257,9 @@ cdef int prepare_resources( return 0 -cdef void prepare_args( - CoreWorker core_worker, args, c_vector[CTaskArg] *args_vector): +cdef prepare_args( + CoreWorker core_worker, + Language language, args, c_vector[CTaskArg] *args_vector): cdef: size_t size int64_t put_threshold @@ -274,6 
+275,13 @@ cdef void prepare_args( else: serialized_arg = worker.get_serialization_context().serialize(arg) + metadata = serialized_arg.metadata + if language != Language.PYTHON: + if metadata not in [ + ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE, + ray_constants.OBJECT_METADATA_TYPE_RAW]: + raise Exception("Can't transfer {} data to {}".format( + metadata, language)) size = serialized_arg.total_bytes # TODO(edoakes): any objects containing ObjectIDs are spilled to @@ -283,12 +291,14 @@ cdef void prepare_args( if size <= put_threshold: arg_data = dynamic_pointer_cast[CBuffer, LocalMemoryBuffer]( make_shared[LocalMemoryBuffer](size)) - write_serialized_object(serialized_arg, arg_data) + if size > 0: + (serialized_arg).write_to( + Buffer.make(arg_data)) for object_id in serialized_arg.contained_object_ids: inlined_ids.push_back((object_id).native()) args_vector.push_back( CTaskArg.PassByValue(make_shared[CRayObject]( - arg_data, string_to_buffer(serialized_arg.metadata), + arg_data, string_to_buffer(metadata), inlined_ids))) inlined_ids.clear() else: @@ -616,29 +626,6 @@ cdef shared_ptr[CBuffer] string_to_buffer(c_string& c_str): (c_str.data()), c_str.size(), True)) -cdef write_serialized_object( - serialized_object, const shared_ptr[CBuffer]& buf): - from ray.serialization import Pickle5SerializedObject, RawSerializedObject - - if isinstance(serialized_object, RawSerializedObject): - if buf.get() != NULL and buf.get().Size() > 0: - size = serialized_object.total_bytes - if MEMCOPY_THREADS > 1 and size > kMemcopyDefaultThreshold: - parallel_memcopy(buf.get().Data(), - serialized_object.value, - size, kMemcopyDefaultBlocksize, - MEMCOPY_THREADS) - else: - memcpy(buf.get().Data(), - serialized_object.value, size) - - elif isinstance(serialized_object, Pickle5SerializedObject): - (serialized_object.writer).write_to( - serialized_object.inband, buf, MEMCOPY_THREADS) - else: - raise TypeError("Unsupported serialization type.") - - cdef class CoreWorker: def 
__cinit__(self, is_driver, store_socket, raylet_socket, @@ -780,7 +767,9 @@ cdef class CoreWorker: &c_object_id, &data) if not object_already_exists: - write_serialized_object(serialized_object, data) + if total_bytes > 0: + (serialized_object).write_to( + Buffer.make(data)) if self.is_local_mode: c_object_id_vector.push_back(c_object_id) check_status(CCoreWorkerProcess.GetCoreWorker().Put( @@ -875,7 +864,7 @@ cdef class CoreWorker: num_return_vals, c_resources) ray_function = CRayFunction( language.lang, function_descriptor.descriptor) - prepare_args(self, args, &args_vector) + prepare_args(self, language, args, &args_vector) with nogil: check_status(CCoreWorkerProcess.GetCoreWorker().SubmitTask( @@ -908,7 +897,7 @@ cdef class CoreWorker: prepare_resources(placement_resources, &c_placement_resources) ray_function = CRayFunction( language.lang, function_descriptor.descriptor) - prepare_args(self, args, &args_vector) + prepare_args(self, language, args, &args_vector) with nogil: check_status(CCoreWorkerProcess.GetCoreWorker().CreateActor( @@ -944,7 +933,7 @@ cdef class CoreWorker: task_options = CTaskOptions(num_return_vals, c_resources) ray_function = CRayFunction( language.lang, function_descriptor.descriptor) - prepare_args(self, args, &args_vector) + prepare_args(self, language, args, &args_vector) with nogil: check_status( @@ -1133,8 +1122,9 @@ cdef class CoreWorker: for i, serialized_object in enumerate(serialized_objects): # A nullptr is returned if the object already exists. 
if returns[0][i].get() != NULL: - write_serialized_object( - serialized_object, returns[0][i].get().GetData()) + if returns[0][i].get().HasData(): + (serialized_object).write_to( + Buffer.make(returns[0][i].get().GetData())) if self.is_local_mode: return_ids_vector.push_back(return_ids[i]) check_status( diff --git a/python/ray/includes/buffer.pxi b/python/ray/includes/buffer.pxi index 34e07412e..467e6f066 100644 --- a/python/ray/includes/buffer.pxi +++ b/python/ray/includes/buffer.pxi @@ -44,7 +44,7 @@ cdef class Buffer: def __getbuffer__(self, Py_buffer* buffer, int flags): buffer.readonly = 0 buffer.buf = self.buffer.get().Data() - buffer.format = 'b' + buffer.format = 'B' buffer.internal = NULL buffer.itemsize = 1 buffer.len = self.size diff --git a/python/ray/includes/serialization.pxi b/python/ray/includes/serialization.pxi index a5e08d8f7..90a634ac9 100644 --- a/python/ray/includes/serialization.pxi +++ b/python/ray/includes/serialization.pxi @@ -1,5 +1,9 @@ from libc.string cimport memcpy from libc.stdint cimport uintptr_t, uint64_t, INT32_MAX +from libcpp cimport nullptr +import cython + +DEF MEMCOPY_THREADS = 6 # This is the default alignment value for len(buffer) < 2048. DEF kMinorBufferAlign = 8 @@ -9,6 +13,8 @@ DEF kMajorBufferAlign = 64 DEF kMajorBufferSize = 2048 DEF kMemcopyDefaultBlocksize = 64 DEF kMemcopyDefaultThreshold = 1024 * 1024 +DEF kLanguageSpecificTypeExtensionId = 101 +DEF kMessagePackOffset = 9 cdef extern from "ray/util/memory.h" namespace "ray" nogil: void parallel_memcopy(uint8_t* dst, const uint8_t* src, int64_t nbytes, @@ -82,7 +88,7 @@ cdef class SubBuffer: void *internal object buffer - def __cinit__(self, Buffer buffer): + def __cinit__(self, object buffer): # Increase ref count. self.buffer = buffer self.suboffsets = NULL @@ -142,15 +148,68 @@ cdef class SubBuffer: return self.size -# See 'serialization.proto' for the memory layout in the Plasma buffer. 
-def unpack_pickle5_buffers(Buffer buf): +cdef class MessagePackSerializer(object): + @staticmethod + def dumps(o, python_serializer=None): + def _default(obj): + if python_serializer is not None: + return msgpack.ExtType(kLanguageSpecificTypeExtensionId, + msgpack.dumps(python_serializer(obj))) + return obj + try: + # If we let strict_types is False, then whether list or tuple will + # be packed to a message pack array. So, they can't be + # distinguished when unpacking. + return msgpack.dumps(o, default=_default, + use_bin_type=True, strict_types=True) + except ValueError as ex: + # msgpack can't handle recursive objects, so we serialize them by + # python serializer, e.g. pickle. + return msgpack.dumps(_default(o), default=_default, + use_bin_type=True, strict_types=True) + + @classmethod + def loads(cls, s, python_deserializer=None): + def _ext_hook(code, data): + if code == kLanguageSpecificTypeExtensionId: + if python_deserializer is not None: + return python_deserializer(msgpack.loads(data)) + raise Exception('Unrecognized ext type id: {}'.format(code)) + try: + gc.disable() # Performance optimization for msgpack. 
+ return msgpack.loads(s, ext_hook=_ext_hook, raw=False) + finally: + gc.enable() + + +@cython.boundscheck(False) +@cython.wraparound(False) +def split_buffer(Buffer buf): cdef: - shared_ptr[CBuffer] _buffer = buf.buffer const uint8_t *data = buf.buffer.get().Data() - size_t size = _buffer.get().Size() + size_t size = buf.buffer.get().Size() + uint8_t[:] bufferview = buf + int64_t msgpack_bytes_length + + assert kMessagePackOffset <= size + header_unpacker = msgpack.Unpacker() + header_unpacker.feed(bufferview[:kMessagePackOffset]) + msgpack_bytes_length = header_unpacker.unpack() + assert kMessagePackOffset + msgpack_bytes_length <= size + return (bufferview[kMessagePackOffset: + kMessagePackOffset + msgpack_bytes_length], + bufferview[kMessagePackOffset + msgpack_bytes_length:]) + + +# See 'serialization.proto' for the memory layout in the Plasma buffer. +@cython.boundscheck(False) +@cython.wraparound(False) +def unpack_pickle5_buffers(uint8_t[:] bufferview): + cdef: + const uint8_t *data = &bufferview[0] + size_t size = len(bufferview) CPythonObject python_object CPythonBuffer *buffer_meta - c_string inband_data int64_t protobuf_offset int64_t protobuf_size int32_t i @@ -167,14 +226,16 @@ def unpack_pickle5_buffers(Buffer buf): if not python_object.ParseFromArray( data + protobuf_offset, protobuf_size): raise ValueError("Protobuf object is corrupted.") - inband_data.append((data + python_object.inband_data_offset()), - python_object.inband_data_size()) + inband_data_offset = python_object.inband_data_offset() + inband_data = bufferview[ + inband_data_offset: + inband_data_offset + python_object.inband_data_size()] buffers_segment = data + python_object.raw_buffers_offset() pickled_buffers = [] # Now read buffer meta for i in range(python_object.buffer_size()): buffer_meta = &python_object.buffer(i) - buffer = SubBuffer(buf) + buffer = SubBuffer(bufferview) buffer.buf = (buffers_segment + buffer_meta.address()) buffer.len = buffer_meta.length() buffer.itemsize = 
buffer_meta.itemsize() @@ -207,6 +268,11 @@ cdef class Pickle5Writer: self._curr_buffer_addr = 0 self._total_bytes = -1 + def __dealloc__(self): + # We must release the buffer, or we could experience memory leaks. + for i in range(self.buffers.size()): + cpython.PyBuffer_Release(&self.buffers[i]) + def buffer_callback(self, pickle_buffer): cdef: Py_buffer view @@ -240,14 +306,14 @@ cdef class Pickle5Writer: self._curr_buffer_addr += view.len self.buffers.push_back(view) - def get_total_bytes(self, const c_string &inband): + def get_total_bytes(self, const uint8_t[:] inband): cdef: size_t protobuf_bytes = 0 uint64_t inband_data_offset = sizeof(int64_t) * 2 uint64_t raw_buffers_offset = padded_length_u64( - inband_data_offset + inband.length(), kMajorBufferAlign) + inband_data_offset + len(inband), kMajorBufferAlign) self.python_object.set_inband_data_offset(inband_data_offset) - self.python_object.set_inband_data_size(inband.length()) + self.python_object.set_inband_data_size(len(inband)) self.python_object.set_raw_buffers_offset(raw_buffers_offset) self.python_object.set_raw_buffers_size(self._curr_buffer_addr) # Since calculating the output size is expensive, we will @@ -265,9 +331,11 @@ cdef class Pickle5Writer: self._total_bytes = self._protobuf_offset + protobuf_bytes return self._total_bytes - cdef void write_to(self, const c_string &inband, shared_ptr[CBuffer] data, - int memcopy_threads): - cdef uint8_t *ptr = data.get().Data() + @cython.boundscheck(False) + @cython.wraparound(False) + cdef void write_to(self, const uint8_t[:] inband, uint8_t[:] data, + int memcopy_threads) nogil: + cdef uint8_t *ptr = &data[0] cdef int32_t protobuf_size cdef uint64_t buffer_addr cdef uint64_t buffer_len @@ -284,7 +352,7 @@ cdef class Pickle5Writer: ptr + self._protobuf_offset) # Write inband data. memcpy(ptr + self.python_object.inband_data_offset(), - inband.data(), inband.length()) + &inband[0], len(inband)) # Write buffer data. 
ptr += self.python_object.raw_buffers_offset() for i in range(self.python_object.buffer_size()): @@ -298,5 +366,141 @@ cdef class Pickle5Writer: kMemcopyDefaultBlocksize, memcopy_threads) else: memcpy(ptr + buffer_addr, self.buffers[i].buf, buffer_len) - # We must release the buffer, or we could experience memory leaks. - cpython.PyBuffer_Release(&self.buffers[i]) + + +cdef class SerializedObject(object): + cdef: + object _metadata + object _contained_object_ids + + def __init__(self, metadata, contained_object_ids=None): + self._metadata = metadata + self._contained_object_ids = contained_object_ids or [] + + @property + def total_bytes(self): + raise NotImplementedError("{}.total_bytes not implemented.".format( + type(self).__name__)) + + @property + def metadata(self): + return self._metadata + + @property + def contained_object_ids(self): + return self._contained_object_ids + + @cython.boundscheck(False) + @cython.wraparound(False) + cdef void write_to(self, uint8_t[:] buffer) nogil: + raise NotImplementedError("{}.write_to not implemented.".format( + type(self).__name__)) + + +cdef class Pickle5SerializedObject(SerializedObject): + cdef: + const uint8_t[:] inband + Pickle5Writer writer + object _total_bytes + + def __init__(self, metadata, inband, Pickle5Writer writer, + contained_object_ids): + super(Pickle5SerializedObject, self).__init__(metadata, + contained_object_ids) + self.inband = inband + self.writer = writer + # cached total bytes + self._total_bytes = None + + @property + def total_bytes(self): + if self._total_bytes is None: + self._total_bytes = self.writer.get_total_bytes(self.inband) + return self._total_bytes + + @cython.boundscheck(False) + @cython.wraparound(False) + cdef void write_to(self, uint8_t[:] buffer) nogil: + self.writer.write_to(self.inband, buffer, MEMCOPY_THREADS) + + +cdef class MessagePackSerializedObject(SerializedObject): + cdef: + SerializedObject nest_serialized_object + object msgpack_header + object msgpack_data + 
int64_t _msgpack_header_bytes + int64_t _msgpack_data_bytes + int64_t _total_bytes + const uint8_t *msgpack_header_ptr + const uint8_t *msgpack_data_ptr + + def __init__(self, metadata, msgpack_data, + SerializedObject nest_serialized_object=None): + if nest_serialized_object: + contained_object_ids = nest_serialized_object.contained_object_ids + total_bytes = nest_serialized_object.total_bytes + else: + contained_object_ids = [] + total_bytes = 0 + super(MessagePackSerializedObject, self).__init__(metadata, + contained_object_ids) + self.nest_serialized_object = nest_serialized_object + self.msgpack_header = msgpack_header = msgpack.dumps(len(msgpack_data)) + self.msgpack_data = msgpack_data + self._msgpack_header_bytes = len(msgpack_header) + self._msgpack_data_bytes = len(msgpack_data) + self._total_bytes = (kMessagePackOffset + + self._msgpack_data_bytes + + total_bytes) + self.msgpack_header_ptr = msgpack_header + self.msgpack_data_ptr = msgpack_data + assert self._msgpack_header_bytes <= kMessagePackOffset + + @property + def total_bytes(self): + return self._total_bytes + + @cython.boundscheck(False) + @cython.wraparound(False) + cdef void write_to(self, uint8_t[:] buffer) nogil: + cdef uint8_t *ptr = &buffer[0] + + # Write msgpack data first. 
+ memcpy(ptr, self.msgpack_header_ptr, self._msgpack_header_bytes) + memcpy(ptr + kMessagePackOffset, + self.msgpack_data_ptr, self._msgpack_data_bytes) + + if self.nest_serialized_object is not None: + self.nest_serialized_object.write_to( + buffer[kMessagePackOffset + self._msgpack_data_bytes:]) + + +cdef class RawSerializedObject(SerializedObject): + cdef: + object value + const uint8_t *value_ptr + int64_t _total_bytes + + def __init__(self, value): + super(RawSerializedObject, + self).__init__(ray_constants.OBJECT_METADATA_TYPE_RAW) + self.value = value + self.value_ptr = value + self._total_bytes = len(value) + + @property + def total_bytes(self): + return self._total_bytes + + @cython.boundscheck(False) + @cython.wraparound(False) + cdef void write_to(self, uint8_t[:] buffer) nogil: + if (MEMCOPY_THREADS > 1 and + self._total_bytes > kMemcopyDefaultThreshold): + parallel_memcopy(&buffer[0], + self.value_ptr, + self._total_bytes, kMemcopyDefaultBlocksize, + MEMCOPY_THREADS) + else: + memcpy(&buffer[0], self.value_ptr, self._total_bytes) diff --git a/python/ray/ray_constants.py b/python/ray/ray_constants.py index 953a568fd..d5569b6b3 100644 --- a/python/ray/ray_constants.py +++ b/python/ray/ray_constants.py @@ -180,13 +180,12 @@ PROCESS_TYPE_GCS_SERVER = "gcs_server" LOG_MONITOR_MAX_OPEN_FILES = 200 -# A constant used as object metadata to indicate the object is raw binary. -RAW_BUFFER_METADATA = b"RAW" -# A constant used as object metadata to indicate the object is pickled. This -# format is only ever used for Python inline task argument values. -PICKLE_BUFFER_METADATA = b"PICKLE" -# A constant used as object metadata to indicate the object is pickle5 format. -PICKLE5_BUFFER_METADATA = b"PICKLE5" +# A constant used as object metadata to indicate the object is cross language. +OBJECT_METADATA_TYPE_CROSS_LANGUAGE = b"XLANG" +# A constant used as object metadata to indicate the object is python specific. 
+OBJECT_METADATA_TYPE_PYTHON = b"PYTHON" +# A constant used as object metadata to indicate the object is raw bytes. +OBJECT_METADATA_TYPE_RAW = b"RAW" AUTOSCALER_RESOURCE_REQUEST_CHANNEL = b"autoscaler_resource_request" diff --git a/python/ray/serialization.py b/python/ray/serialization.py index 03ded8540..9b6589910 100644 --- a/python/ray/serialization.py +++ b/python/ray/serialization.py @@ -15,7 +15,15 @@ from ray.exceptions import ( RayWorkerError, UnreconstructableError, ) -from ray._raylet import Pickle5Writer, unpack_pickle5_buffers +from ray._raylet import ( + split_buffer, + unpack_pickle5_buffers, + Pickle5Writer, + Pickle5SerializedObject, + MessagePackSerializer, + MessagePackSerializedObject, + RawSerializedObject, +) logger = logging.getLogger(__name__) @@ -34,51 +42,6 @@ class DeserializationError(Exception): pass -class SerializedObject: - def __init__(self, metadata, contained_object_ids=None): - self._metadata = metadata - self._contained_object_ids = contained_object_ids or [] - - @property - def total_bytes(self): - raise NotImplementedError - - @property - def metadata(self): - return self._metadata - - @property - def contained_object_ids(self): - return self._contained_object_ids - - -class Pickle5SerializedObject(SerializedObject): - def __init__(self, metadata, inband, writer, contained_object_ids): - super(Pickle5SerializedObject, self).__init__(metadata, - contained_object_ids) - self.inband = inband - self.writer = writer - # cached total bytes - self._total_bytes = None - - @property - def total_bytes(self): - if self._total_bytes is None: - self._total_bytes = self.writer.get_total_bytes(self.inband) - return self._total_bytes - - -class RawSerializedObject(SerializedObject): - def __init__(self, value): - super(RawSerializedObject, - self).__init__(ray_constants.RAW_BUFFER_METADATA) - self.value = value - - @property - def total_bytes(self): - return len(self.value) - - def _try_to_compute_deterministic_class_id(cls, depth=5): 
"""Attempt to produce a deterministic class ID for a given class. @@ -265,23 +228,51 @@ class SerializationContext: raise DeserializationError() return obj + def _deserialize_msgpack_data(self, data, metadata): + msgpack_data, pickle5_data = split_buffer(data) + + if metadata == ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE: + python_objects = [] + else: + python_objects = self._deserialize_pickle5_data(pickle5_data) + + try: + + def _python_deserializer(index): + return python_objects[index] + + obj = MessagePackSerializer.loads(msgpack_data, + _python_deserializer) + except Exception: + raise DeserializationError() + return obj + def _deserialize_object(self, data, metadata, object_id): if metadata: - if metadata == ray_constants.PICKLE5_BUFFER_METADATA: - return self._deserialize_pickle5_data(data) + if metadata in [ + ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE, + ray_constants.OBJECT_METADATA_TYPE_PYTHON + ]: + return self._deserialize_msgpack_data(data, metadata) # Check if the object should be returned as raw bytes. - if metadata == ray_constants.RAW_BUFFER_METADATA: + if metadata == ray_constants.OBJECT_METADATA_TYPE_RAW: if data is None: return b"" return data.to_pybytes() # Otherwise, return an exception object based on # the error type. - error_type = int(metadata) + try: + error_type = int(metadata) + except Exception: + raise Exception( + "Can't deserialize object: {}, metadata: {}".format( + object_id, metadata)) + # RayTaskError is serialized with pickle5 in the data field. # TODO (kfstorm): exception serialization should be language # independent. 
if error_type == ErrorType.Value("TASK_EXECUTION_EXCEPTION"): - obj = self._deserialize_pickle5_data(data) + obj = self._deserialize_msgpack_data(data, metadata) assert isinstance(obj, RayTaskError) return obj elif error_type == ErrorType.Value("WORKER_DIED"): @@ -347,6 +338,43 @@ class SerializationContext: return results + def _serialize_to_pickle5(self, metadata, value): + writer = Pickle5Writer() + # TODO(swang): Check that contained_object_ids is empty. + try: + self.set_in_band_serialization() + inband = pickle.dumps( + value, protocol=5, buffer_callback=writer.buffer_callback) + except Exception as e: + self.get_and_clear_contained_object_ids() + raise e + finally: + self.set_out_of_band_serialization() + + return Pickle5SerializedObject( + metadata, inband, writer, + self.get_and_clear_contained_object_ids()) + + def _serialize_to_msgpack(self, metadata, value): + python_objects = [] + + def _python_serializer(o): + index = len(python_objects) + python_objects.append(o) + return index + + msgpack_data = MessagePackSerializer.dumps(value, _python_serializer) + + if python_objects: + pickle5_serialized_object = \ + self._serialize_to_pickle5(metadata, python_objects) + else: + metadata = ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE + pickle5_serialized_object = None + + return MessagePackSerializedObject(metadata, msgpack_data, + pickle5_serialized_object) + def serialize(self, value): """Serialize an object. @@ -365,23 +393,9 @@ class SerializationContext: metadata = str(ErrorType.Value( "TASK_EXECUTION_EXCEPTION")).encode("ascii") else: - metadata = ray_constants.PICKLE5_BUFFER_METADATA + metadata = ray_constants.OBJECT_METADATA_TYPE_PYTHON - writer = Pickle5Writer() - # TODO(swang): Check that contained_object_ids is empty. 
- try: - self.set_in_band_serialization() - inband = pickle.dumps( - value, protocol=5, buffer_callback=writer.buffer_callback) - except Exception as e: - self.get_and_clear_contained_object_ids() - raise e - finally: - self.set_out_of_band_serialization() - - return Pickle5SerializedObject( - metadata, inband, writer, - self.get_and_clear_contained_object_ids()) + return self._serialize_to_msgpack(metadata, value) def register_custom_serializer(self, cls, diff --git a/python/ray/tests/test_cross_language.py b/python/ray/tests/test_cross_language.py index 13323c9bd..cf655eae7 100644 --- a/python/ray/tests/test_cross_language.py +++ b/python/ray/tests/test_cross_language.py @@ -13,3 +13,13 @@ def test_cross_language_raise_kwargs(shutdown_only): with pytest.raises(Exception, match="kwargs"): ray.java_actor_class("a").remote(x="arg1") + + +def test_cross_language_raise_exception(shutdown_only): + ray.init(load_code_from_local=True, include_java=True) + + class PythonObject(object): + pass + + with pytest.raises(Exception, match="transfer"): + ray.java_function("a", "b").remote(PythonObject()) diff --git a/python/setup.py b/python/setup.py index 80408cee6..02b28517e 100644 --- a/python/setup.py +++ b/python/setup.py @@ -172,9 +172,19 @@ def find_version(*filepath): requires = [ - "numpy >= 1.16", "filelock", "jsonschema", "click", "colorama", "pyyaml", - "redis >= 3.3.2", "protobuf >= 3.8.0", "py-spy >= 0.2.0", "aiohttp", - "google", "grpcio" + "aiohttp", + "click", + "colorama", + "filelock", + "google", + "grpcio", + "jsonschema", + "msgpack >= 0.6.0, < 1.0.0", + "numpy >= 1.16", + "protobuf >= 3.8.0", + "py-spy >= 0.2.0", + "pyyaml", + "redis >= 3.3.2", ] setup( diff --git a/streaming/java/dependencies.bzl b/streaming/java/dependencies.bzl index cfcdab4c4..def6d4d5e 100644 --- a/streaming/java/dependencies.bzl +++ b/streaming/java/dependencies.bzl @@ -13,8 +13,8 @@ def gen_streaming_java_deps(): "org.slf4j:slf4j-api:1.7.12", "org.slf4j:slf4j-log4j12:1.7.25", 
"org.apache.logging.log4j:log4j-core:2.8.2", - "org.testng:testng:6.9.10", "org.msgpack:msgpack-core:0.8.20", + "org.testng:testng:6.9.10", ], repositories = [ "https://repo1.maven.org/maven2/", diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/collector/OutputCollector.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/collector/OutputCollector.java index 64f92cc13..900518362 100644 --- a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/collector/OutputCollector.java +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/collector/OutputCollector.java @@ -2,7 +2,7 @@ package org.ray.streaming.runtime.core.collector; import java.nio.ByteBuffer; import java.util.Collection; -import org.ray.runtime.util.Serializer; +import org.ray.runtime.serializer.Serializer; import org.ray.streaming.api.collector.Collector; import org.ray.streaming.api.partition.Partition; import org.ray.streaming.message.Record; @@ -31,7 +31,7 @@ public class OutputCollector implements Collector { @Override public void collect(Record record) { int[] partitions = this.partition.partition(record, outputQueues.length); - ByteBuffer msgBuffer = ByteBuffer.wrap(Serializer.encode(record)); + ByteBuffer msgBuffer = ByteBuffer.wrap(Serializer.encode(record).getLeft()); for (int partition : partitions) { writer.write(outputQueues[partition], msgBuffer); } diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/tasks/InputStreamTask.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/tasks/InputStreamTask.java index eed12f705..a4ba75eb4 100644 --- a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/tasks/InputStreamTask.java +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/tasks/InputStreamTask.java @@ -1,6 +1,6 @@ package 
org.ray.streaming.runtime.worker.tasks; -import org.ray.runtime.util.Serializer; +import org.ray.runtime.serializer.Serializer; import org.ray.streaming.runtime.core.processor.Processor; import org.ray.streaming.runtime.transfer.Message; import org.ray.streaming.runtime.worker.JobWorker; @@ -28,7 +28,7 @@ public abstract class InputStreamTask extends StreamTask { if (item != null) { byte[] bytes = new byte[item.body().remaining()]; item.body().get(bytes); - Object obj = Serializer.decode(bytes); + Object obj = Serializer.decode(bytes, Object.class); processor.process(obj); } }