diff --git a/java/runtime/src/main/java/io/ray/runtime/object/ObjectSerializer.java b/java/runtime/src/main/java/io/ray/runtime/object/ObjectSerializer.java index b09a70789..fd229503e 100644 --- a/java/runtime/src/main/java/io/ray/runtime/object/ObjectSerializer.java +++ b/java/runtime/src/main/java/io/ray/runtime/object/ObjectSerializer.java @@ -2,6 +2,7 @@ package io.ray.runtime.object; import com.google.protobuf.InvalidProtocolBufferException; import io.ray.api.id.ObjectId; +import io.ray.runtime.actor.NativeActorHandle; import io.ray.runtime.exception.RayActorException; import io.ray.runtime.exception.RayTaskException; import io.ray.runtime.exception.RayWorkerException; @@ -35,6 +36,10 @@ public class ObjectSerializer { public static final byte[] OBJECT_METADATA_TYPE_JAVA = "JAVA".getBytes(); public static final byte[] OBJECT_METADATA_TYPE_PYTHON = "PYTHON".getBytes(); public static final byte[] OBJECT_METADATA_TYPE_RAW = "RAW".getBytes(); + // A constant used as object metadata to indicate the object is an actor handle. + // This value should be synchronized with the Python definition in ray_constants.py + // TODO(fyrestone): Serialize the ActorHandle via the custom type feature of XLANG. + public static final byte[] OBJECT_METADATA_TYPE_ACTOR_HANDLE = "ACTOR_HANDLE".getBytes(); // When an outer object is being serialized, the nested ObjectRefs are all // serialized and the writeExternal method of the nested ObjectRefs are @@ -86,6 +91,9 @@ public class ObjectSerializer { "Can't deserialize RayTaskException object: " + objectId .toString()); } + } else if (Arrays.equals(meta, OBJECT_METADATA_TYPE_ACTOR_HANDLE)) { + byte[] serialized = Serializer.decode(data, byte[].class); + return NativeActorHandle.fromBytes(serialized); } else if (Arrays.equals(meta, OBJECT_METADATA_TYPE_PYTHON)) { throw new IllegalArgumentException("Can't deserialize Python object: " + objectId .toString()); @@ -129,6 +137,13 @@ public class ObjectSerializer { // Only OBJECT_METADATA_TYPE_RAW is raw bytes, // any other type should be the MessagePack serialized bytes. return new NativeRayObject(serializedBytes, TASK_EXECUTION_EXCEPTION_META); + } else if (object instanceof NativeActorHandle) { + NativeActorHandle actorHandle = (NativeActorHandle)object; + byte[] serializedBytes = Serializer.encode(actorHandle.toBytes()).getLeft(); + // serializedBytes is MessagePack serialized bytes + // Only OBJECT_METADATA_TYPE_RAW is raw bytes, + // any other type should be the MessagePack serialized bytes. + return new NativeRayObject(serializedBytes, OBJECT_METADATA_TYPE_ACTOR_HANDLE); } else { try { Pair serialized = Serializer.encode(object); diff --git a/java/runtime/src/main/java/io/ray/runtime/task/ArgumentsBuilder.java b/java/runtime/src/main/java/io/ray/runtime/task/ArgumentsBuilder.java index 9455fedb4..53af11da9 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/ArgumentsBuilder.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/ArgumentsBuilder.java @@ -48,7 +48,8 @@ public class ArgumentsBuilder { if (language != Language.JAVA) { boolean isCrossData = Arrays.equals(value.metadata, ObjectSerializer.OBJECT_METADATA_TYPE_CROSS_LANGUAGE) || - Arrays.equals(value.metadata, ObjectSerializer.OBJECT_METADATA_TYPE_RAW); + Arrays.equals(value.metadata, ObjectSerializer.OBJECT_METADATA_TYPE_RAW) || + Arrays.equals(value.metadata, ObjectSerializer.OBJECT_METADATA_TYPE_ACTOR_HANDLE); if (!isCrossData) { throw new IllegalArgumentException(String.format("Can't transfer %s data to %s", Arrays.toString(value.metadata), language.getValueDescriptor().getName())); diff --git a/java/test/src/main/java/io/ray/test/CrossLanguageInvocationTest.java b/java/test/src/main/java/io/ray/test/CrossLanguageInvocationTest.java index e630494ee..876b00681 100644 --- a/java/test/src/main/java/io/ray/test/CrossLanguageInvocationTest.java +++ b/java/test/src/main/java/io/ray/test/CrossLanguageInvocationTest.java @@ -167,23 +167,21 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest { // Create a java actor, and pass actor handle to python. ActorHandle javaActor = Ray.actor(TestActor::new, "1".getBytes()).remote(); Preconditions.checkState(javaActor instanceof NativeActorHandle); - byte[] actorHandleBytes = ((NativeActorHandle) javaActor).toBytes(); ObjectRef res = Ray.task( PyFunction.of(PYTHON_MODULE, "py_func_call_java_actor_from_handle", byte[].class), - actorHandleBytes).remote(); + javaActor).remote(); Assert.assertEquals(res.get(), "12".getBytes()); // Create a python actor, and pass actor handle to python. PyActorHandle pyActor = Ray.actor( PyActorClass.of(PYTHON_MODULE, "Counter"), "1".getBytes()).remote(); Preconditions.checkState(pyActor instanceof NativeActorHandle); - actorHandleBytes = ((NativeActorHandle) pyActor).toBytes(); res = Ray.task( PyFunction.of(PYTHON_MODULE, "py_func_call_python_actor_from_handle", byte[].class), - actorHandleBytes).remote(); + pyActor).remote(); Assert.assertEquals(res.get(), "3".getBytes()); } @@ -301,9 +299,8 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest { return l; } - public static byte[] callPythonActorHandle(byte[] value) { + public static byte[] callPythonActorHandle(PyActorHandle actor) { // This function will be called from test_cross_language_invocation.py - NativePyActorHandle actor = (NativePyActorHandle) NativeActorHandle.fromBytes(value); ObjectRef res = actor.task( PyActorMethod.of("increase", byte[].class), "1".getBytes()).remote(); diff --git a/java/test/src/main/resources/test_cross_language_invocation.py b/java/test/src/main/resources/test_cross_language_invocation.py index 372c6f6e8..72a56f649 100644 --- a/java/test/src/main/resources/test_cross_language_invocation.py +++ b/java/test/src/main/resources/test_cross_language_invocation.py @@ -59,17 +59,13 @@ def py_func_call_java_actor(value): @ray.remote -def py_func_call_java_actor_from_handle(value): - assert isinstance(value, bytes) - actor_handle = ray.actor.ActorHandle._deserialization_helper(value) +def py_func_call_java_actor_from_handle(actor_handle): r = actor_handle.concat.remote(b"2") return ray.get(r) @ray.remote -def py_func_call_python_actor_from_handle(value): - assert isinstance(value, bytes) - actor_handle = ray.actor.ActorHandle._deserialization_helper(value) +def py_func_call_python_actor_from_handle(actor_handle): r = actor_handle.increase.remote(2) return ray.get(r) @@ -79,7 +75,7 @@ def py_func_pass_python_actor_handle(): counter = Counter.remote(2) f = ray.java_function("io.ray.test.CrossLanguageInvocationTest", "callPythonActorHandle") - r = f.remote(counter._serialization_helper()[0]) + r = f.remote(counter) return ray.get(r) diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 654a7a538..b5c4722d2 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -297,7 +297,8 @@ cdef prepare_args( if language != Language.PYTHON: if metadata not in [ ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE, - ray_constants.OBJECT_METADATA_TYPE_RAW]: + ray_constants.OBJECT_METADATA_TYPE_RAW, + ray_constants.OBJECT_METADATA_TYPE_ACTOR_HANDLE]: raise Exception("Can't transfer {} data to {}".format( metadata, language)) size = serialized_arg.total_bytes diff --git a/python/ray/includes/serialization.pxi b/python/ray/includes/serialization.pxi index 31a5f1e04..da789ad8c 100644 --- a/python/ray/includes/serialization.pxi +++ b/python/ray/includes/serialization.pxi @@ -436,15 +436,14 @@ cdef class MessagePackSerializedObject(SerializedObject): const uint8_t *msgpack_header_ptr const uint8_t *msgpack_data_ptr - def __init__(self, metadata, msgpack_data, + def __init__(self, metadata, msgpack_data, contained_object_refs, SerializedObject nest_serialized_object=None): if nest_serialized_object: - contained_object_refs = ( + contained_object_refs.extend( nest_serialized_object.contained_object_refs ) total_bytes = nest_serialized_object.total_bytes else: - contained_object_refs = [] total_bytes = 0 super(MessagePackSerializedObject, self).__init__( metadata, diff --git a/python/ray/ray_constants.py b/python/ray/ray_constants.py index 92aeac05f..3fdf50b8e 100644 --- a/python/ray/ray_constants.py +++ b/python/ray/ray_constants.py @@ -193,6 +193,13 @@ OBJECT_METADATA_TYPE_PYTHON = b"PYTHON" # A constant used as object metadata to indicate the object is raw bytes. OBJECT_METADATA_TYPE_RAW = b"RAW" +# A constant used as object metadata to indicate the object is an actor handle. +# This value should be synchronized with the Java definition in +# ObjectSerializer.java +# TODO(fyrestone): Serialize the ActorHandle via the custom type feature +# of XLANG. +OBJECT_METADATA_TYPE_ACTOR_HANDLE = b"ACTOR_HANDLE" + AUTOSCALER_RESOURCE_REQUEST_CHANNEL = b"autoscaler_resource_request" # The default password to prevent redis port scanning attack. diff --git a/python/ray/serialization.py b/python/ray/serialization.py index ae51289c5..f85b07afa 100644 --- a/python/ray/serialization.py +++ b/python/ray/serialization.py @@ -250,6 +250,9 @@ class SerializationContext: if data is None: return b"" return data.to_pybytes() + elif metadata == ray_constants.OBJECT_METADATA_TYPE_ACTOR_HANDLE: + obj = self._deserialize_msgpack_data(data, metadata) + return actor_handle_deserializer(obj) # Otherwise, return an exception object based on # the error type. try: @@ -349,10 +352,20 @@ class SerializationContext: def _serialize_to_msgpack(self, value): # Only RayTaskError is possible to be serialized here. We don't # need to deal with other exception types here. + contained_object_refs = [] + if isinstance(value, RayTaskError): metadata = str( ErrorType.Value("TASK_EXECUTION_EXCEPTION")).encode("ascii") value = value.to_bytes() + elif isinstance(value, ray.actor.ActorHandle): + # TODO(fyresone): ActorHandle should be serialized via the + # custom type feature of cross-language. + serialized, actor_handle_id = value._serialization_helper() + contained_object_refs.append(actor_handle_id) + # Update ref counting for the actor handle + metadata = ray_constants.OBJECT_METADATA_TYPE_ACTOR_HANDLE + value = serialized else: metadata = ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE @@ -373,6 +386,7 @@ class SerializationContext: pickle5_serialized_object = None return MessagePackSerializedObject(metadata, msgpack_data, + contained_object_refs, pickle5_serialized_object) def serialize(self, value):