diff --git a/java/runtime/src/main/java/io/ray/runtime/context/LocalModeWorkerContext.java b/java/runtime/src/main/java/io/ray/runtime/context/LocalModeWorkerContext.java index 28825c55a..406c3d651 100644 --- a/java/runtime/src/main/java/io/ray/runtime/context/LocalModeWorkerContext.java +++ b/java/runtime/src/main/java/io/ray/runtime/context/LocalModeWorkerContext.java @@ -5,6 +5,7 @@ import io.ray.api.id.ActorId; import io.ray.api.id.JobId; import io.ray.api.id.TaskId; import io.ray.api.id.UniqueId; +import io.ray.runtime.generated.Common.Address; import io.ray.runtime.generated.Common.TaskSpec; import io.ray.runtime.generated.Common.TaskType; import io.ray.runtime.task.LocalModeTaskSubmitter; @@ -68,6 +69,11 @@ public class LocalModeWorkerContext implements WorkerContext { return TaskId.fromBytes(taskSpec.getTaskId().toByteArray()); } + @Override + public Address getRpcAddress() { + return Address.getDefaultInstance(); + } + public void setCurrentTask(TaskSpec taskSpec) { currentTask.set(taskSpec); } diff --git a/java/runtime/src/main/java/io/ray/runtime/context/NativeWorkerContext.java b/java/runtime/src/main/java/io/ray/runtime/context/NativeWorkerContext.java index defb42173..76978d0f5 100644 --- a/java/runtime/src/main/java/io/ray/runtime/context/NativeWorkerContext.java +++ b/java/runtime/src/main/java/io/ray/runtime/context/NativeWorkerContext.java @@ -1,9 +1,11 @@ package io.ray.runtime.context; +import com.google.protobuf.InvalidProtocolBufferException; import io.ray.api.id.ActorId; import io.ray.api.id.JobId; import io.ray.api.id.TaskId; import io.ray.api.id.UniqueId; +import io.ray.runtime.generated.Common.Address; import io.ray.runtime.generated.Common.TaskType; import java.nio.ByteBuffer; @@ -51,6 +53,15 @@ public class NativeWorkerContext implements WorkerContext { return TaskId.fromByteBuffer(nativeGetCurrentTaskId()); } + @Override + public Address getRpcAddress() { + try { + return Address.parseFrom(nativeGetRpcAddress()); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException(e); + } + } + private static native int nativeGetCurrentTaskType(); private static native ByteBuffer nativeGetCurrentTaskId(); @@ -60,4 +71,6 @@ public class NativeWorkerContext implements WorkerContext { private static native ByteBuffer nativeGetCurrentWorkerId(); private static native ByteBuffer nativeGetCurrentActorId(); + + private static native byte[] nativeGetRpcAddress(); } diff --git a/java/runtime/src/main/java/io/ray/runtime/context/WorkerContext.java b/java/runtime/src/main/java/io/ray/runtime/context/WorkerContext.java index 40b647f69..50a5c5ce5 100644 --- a/java/runtime/src/main/java/io/ray/runtime/context/WorkerContext.java +++ b/java/runtime/src/main/java/io/ray/runtime/context/WorkerContext.java @@ -4,6 +4,7 @@ import io.ray.api.id.ActorId; import io.ray.api.id.JobId; import io.ray.api.id.TaskId; import io.ray.api.id.UniqueId; +import io.ray.runtime.generated.Common.Address; import io.ray.runtime.generated.Common.TaskType; /** @@ -46,4 +47,6 @@ public interface WorkerContext { * ID of the current task. */ TaskId getCurrentTaskId(); + + Address getRpcAddress(); } diff --git a/java/runtime/src/main/java/io/ray/runtime/object/LocalModeObjectStore.java b/java/runtime/src/main/java/io/ray/runtime/object/LocalModeObjectStore.java index 5349f590a..87f0adc00 100644 --- a/java/runtime/src/main/java/io/ray/runtime/object/LocalModeObjectStore.java +++ b/java/runtime/src/main/java/io/ray/runtime/object/LocalModeObjectStore.java @@ -4,6 +4,7 @@ import com.google.common.base.Preconditions; import io.ray.api.id.ObjectId; import io.ray.api.id.UniqueId; import io.ray.runtime.context.WorkerContext; +import io.ray.runtime.generated.Common.Address; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -106,6 +107,11 @@ public class LocalModeObjectStore extends ObjectStore { public void removeLocalReference(UniqueId workerId, ObjectId objectId) { } + @Override + public Address getOwnerAddress(ObjectId id) { + return Address.getDefaultInstance(); + } + @Override public byte[] promoteAndGetOwnershipInfo(ObjectId objectId) { return new byte[0]; diff --git a/java/runtime/src/main/java/io/ray/runtime/object/NativeObjectStore.java b/java/runtime/src/main/java/io/ray/runtime/object/NativeObjectStore.java index fc198ed73..ef85cf62c 100644 --- a/java/runtime/src/main/java/io/ray/runtime/object/NativeObjectStore.java +++ b/java/runtime/src/main/java/io/ray/runtime/object/NativeObjectStore.java @@ -1,9 +1,11 @@ package io.ray.runtime.object; +import com.google.protobuf.InvalidProtocolBufferException; import io.ray.api.id.BaseId; import io.ray.api.id.ObjectId; import io.ray.api.id.UniqueId; import io.ray.runtime.context.WorkerContext; +import io.ray.runtime.generated.Common.Address; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -92,6 +94,15 @@ public class NativeObjectStore extends ObjectStore { return referenceCounts; } + @Override + public Address getOwnerAddress(ObjectId id) { + try { + return Address.parseFrom(nativeGetOwnerAddress(id.getBytes())); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException(e); + } + } + private static List toBinaryList(List ids) { return ids.stream().map(BaseId::getBytes).collect(Collectors.toList()); } @@ -114,6 +125,8 @@ public class NativeObjectStore extends ObjectStore { private static native Map nativeGetAllReferenceCounts(); + private static native byte[] nativeGetOwnerAddress(byte[] objectId); + private static native byte[] nativePromoteAndGetOwnershipInfo(byte[] objectId); private static native void nativeRegisterOwnershipInfoAndResolveFuture(byte[] objectId, diff --git a/java/runtime/src/main/java/io/ray/runtime/object/ObjectStore.java b/java/runtime/src/main/java/io/ray/runtime/object/ObjectStore.java index a69147fd8..e72bed802 100644 --- a/java/runtime/src/main/java/io/ray/runtime/object/ObjectStore.java +++ b/java/runtime/src/main/java/io/ray/runtime/object/ObjectStore.java @@ -7,6 +7,7 @@ import io.ray.api.id.ObjectId; import io.ray.api.id.UniqueId; import io.ray.runtime.context.WorkerContext; import io.ray.runtime.exception.RayException; +import io.ray.runtime.generated.Common.Address; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -187,6 +188,8 @@ public abstract class ObjectStore { */ public abstract void removeLocalReference(UniqueId workerId, ObjectId objectId); + public abstract Address getOwnerAddress(ObjectId id); + /** * Promote the given object to the underlying object store, and get the ownership info. * diff --git a/java/runtime/src/main/java/io/ray/runtime/task/ArgumentsBuilder.java b/java/runtime/src/main/java/io/ray/runtime/task/ArgumentsBuilder.java index 53af11da9..6c90b552e 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/ArgumentsBuilder.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/ArgumentsBuilder.java @@ -5,6 +5,7 @@ import io.ray.api.ObjectRef; import io.ray.api.Ray; import io.ray.api.id.ObjectId; import io.ray.runtime.RayRuntimeInternal; +import io.ray.runtime.generated.Common.Address; import io.ray.runtime.generated.Common.Language; import io.ray.runtime.object.NativeRayObject; import io.ray.runtime.object.ObjectRefImpl; @@ -39,10 +40,12 @@ public class ArgumentsBuilder { List ret = new ArrayList<>(); for (Object arg : args) { ObjectId id = null; + Address address = null; NativeRayObject value = null; if (arg instanceof ObjectRef) { Preconditions.checkState(arg instanceof ObjectRefImpl); id = ((ObjectRefImpl) arg).getId(); + address = ((RayRuntimeInternal) Ray.internal()).getObjectStore().getOwnerAddress(id); } else { value = ObjectSerializer.serialize(arg); if (language != Language.JAVA) { @@ -58,6 +61,7 @@ public class ArgumentsBuilder { if (value.data.length > LARGEST_SIZE_PASS_BY_VALUE) { id = ((RayRuntimeInternal) Ray.internal()).getObjectStore() .putRaw(value); + address = ((RayRuntimeInternal) Ray.internal()).getWorkerContext().getRpcAddress(); value = null; } } @@ -65,7 +69,7 @@ public class ArgumentsBuilder { ret.add(FunctionArg.passByValue(PYTHON_DUMMY_TYPE)); } if (id != null) { - ret.add(FunctionArg.passByReference(id)); + ret.add(FunctionArg.passByReference(id, address)); } else { ret.add(FunctionArg.passByValue(value)); } diff --git a/java/runtime/src/main/java/io/ray/runtime/task/FunctionArg.java b/java/runtime/src/main/java/io/ray/runtime/task/FunctionArg.java index 941863082..d61c0c525 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/FunctionArg.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/FunctionArg.java @@ -2,6 +2,7 @@ package io.ray.runtime.task; import com.google.common.base.Preconditions; import io.ray.api.id.ObjectId; +import io.ray.runtime.generated.Common.Address; import io.ray.runtime.object.NativeRayObject; /** @@ -15,29 +16,44 @@ public class FunctionArg { * The id of this argument (passed by reference). */ public final ObjectId id; + + /** + * The owner address of this argument (passed by reference). + */ + public final Address ownerAddress; + /** * Serialized data of this argument (passed by value). */ public final NativeRayObject value; - private FunctionArg(ObjectId id, NativeRayObject value) { - Preconditions.checkState((id == null) != (value == null)); + private FunctionArg(ObjectId id, Address ownerAddress) { + Preconditions.checkNotNull(id); + Preconditions.checkNotNull(ownerAddress); this.id = id; - this.value = value; + this.ownerAddress = ownerAddress; + this.value = null; + } + + private FunctionArg(NativeRayObject nativeRayObject) { + Preconditions.checkNotNull(nativeRayObject); + this.id = null; + this.ownerAddress = null; + this.value = nativeRayObject; } /** * Create a FunctionArg that will be passed by reference. */ - public static FunctionArg passByReference(ObjectId id) { - return new FunctionArg(id, null); + public static FunctionArg passByReference(ObjectId id, Address ownerAddress) { + return new FunctionArg(id, ownerAddress); } /** * Create a FunctionArg that will be passed by value. */ public static FunctionArg passByValue(NativeRayObject value) { - return new FunctionArg(null, value); + return new FunctionArg(value); } @Override diff --git a/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java b/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java index c9cab4dce..276cb48b8 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java @@ -22,6 +22,7 @@ import io.ray.runtime.functionmanager.JavaFunctionDescriptor; import io.ray.runtime.generated.Common; import io.ray.runtime.generated.Common.ActorCreationTaskSpec; import io.ray.runtime.generated.Common.ActorTaskSpec; +import io.ray.runtime.generated.Common.Address; import io.ray.runtime.generated.Common.Language; import io.ray.runtime.generated.Common.ObjectReference; import io.ray.runtime.generated.Common.TaskArg; @@ -381,7 +382,8 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { TaskArg arg = taskSpec.getArgs(i); if (arg.getObjectRef().getObjectId() != ByteString.EMPTY) { functionArgs.add(FunctionArg - .passByReference(new ObjectId(arg.getObjectRef().getObjectId().toByteArray()))); + .passByReference(new ObjectId(arg.getObjectRef().getObjectId().toByteArray()), + Address.getDefaultInstance())); } else { functionArgs.add(FunctionArg.passByValue( new NativeRayObject(arg.getData().toByteArray(), arg.getMetadata().toByteArray()))); diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc index bf5050c50..035809662 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc @@ -276,10 +276,7 @@ Java_io_ray_runtime_RayNativeRuntime_nativeGetActorIdOfNamedActor(JNIEnv *env, j } else { actor_id = ray::ActorID::Nil(); } - jbyteArray bytes = env->NewByteArray(actor_id.Size()); - env->SetByteArrayRegion(bytes, 0, actor_id.Size(), - reinterpret_cast(actor_id.Data())); - return bytes; + return IdToJavaByteArray(env, actor_id); } JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeKillActor( diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_actor_NativeActorHandle.cc b/src/ray/core_worker/lib/java/io_ray_runtime_actor_NativeActorHandle.cc index d149e4313..191df0b0b 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_actor_NativeActorHandle.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_actor_NativeActorHandle.cc @@ -51,10 +51,8 @@ JNIEXPORT jbyteArray JNICALL Java_io_ray_runtime_actor_NativeActorHandle_nativeS ObjectID actor_handle_id; ray::Status status = ray::CoreWorkerProcess::GetCoreWorker().SerializeActorHandle( actor_id, &output, &actor_handle_id); - jbyteArray bytes = env->NewByteArray(output.size()); - env->SetByteArrayRegion(bytes, 0, output.size(), - reinterpret_cast(output.c_str())); - return bytes; + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); + return NativeStringToJavaByteArray(env, output); } JNIEXPORT jbyteArray JNICALL diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_context_NativeWorkerContext.cc b/src/ray/core_worker/lib/java/io_ray_runtime_context_NativeWorkerContext.cc index c0833817c..cd094d29a 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_context_NativeWorkerContext.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_context_NativeWorkerContext.cc @@ -66,6 +66,12 @@ Java_io_ray_runtime_context_NativeWorkerContext_nativeGetCurrentActorId(JNIEnv * return IdToJavaByteBuffer(env, actor_id); } +JNIEXPORT jbyteArray JNICALL +Java_io_ray_runtime_context_NativeWorkerContext_nativeGetRpcAddress(JNIEnv *env, jclass) { + const auto &rpc_address = ray::CoreWorkerProcess::GetCoreWorker().GetRpcAddress(); + return NativeStringToJavaByteArray(env, rpc_address.SerializeAsString()); +} + #ifdef __cplusplus } #endif diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_context_NativeWorkerContext.h b/src/ray/core_worker/lib/java/io_ray_runtime_context_NativeWorkerContext.h index cb6005533..cd836c40a 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_context_NativeWorkerContext.h +++ b/src/ray/core_worker/lib/java/io_ray_runtime_context_NativeWorkerContext.h @@ -63,6 +63,14 @@ Java_io_ray_runtime_context_NativeWorkerContext_nativeGetCurrentWorkerId(JNIEnv JNIEXPORT jobject JNICALL Java_io_ray_runtime_context_NativeWorkerContext_nativeGetCurrentActorId(JNIEnv *, jclass); +/* + * Class: io_ray_runtime_context_NativeWorkerContext + * Method: nativeGetRpcAddress + * Signature: ()[B + */ +JNIEXPORT jbyteArray JNICALL +Java_io_ray_runtime_context_NativeWorkerContext_nativeGetRpcAddress(JNIEnv *, jclass); + #ifdef __cplusplus } #endif diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_metric_NativeMetric.h b/src/ray/core_worker/lib/java/io_ray_runtime_metric_NativeMetric.h index 54159f6fa..29ff46b39 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_metric_NativeMetric.h +++ b/src/ray/core_worker/lib/java/io_ray_runtime_metric_NativeMetric.h @@ -1,3 +1,17 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + /* DO NOT EDIT THIS FILE - it is machine generated */ #include /* Header for class io_ray_runtime_metric_NativeMetric */ diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc index c0974f543..f14853002 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc @@ -174,6 +174,15 @@ Java_io_ray_runtime_object_NativeObjectStore_nativeGetAllReferenceCounts(JNIEnv }); } +JNIEXPORT jbyteArray JNICALL +Java_io_ray_runtime_object_NativeObjectStore_nativeGetOwnerAddress(JNIEnv *env, jclass, + jbyteArray objectId) { + auto object_id = JavaByteArrayToId(env, objectId); + const auto &rpc_address = + ray::CoreWorkerProcess::GetCoreWorker().GetOwnerAddress(object_id); + return NativeStringToJavaByteArray(env, rpc_address.SerializeAsString()); +} + JNIEXPORT jbyteArray JNICALL Java_io_ray_runtime_object_NativeObjectStore_nativePromoteAndGetOwnershipInfo( JNIEnv *env, jclass, jbyteArray objectId) { diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h index 162c62b4e..0da1aba92 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h +++ b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h @@ -94,6 +94,15 @@ JNIEXPORT jobject JNICALL Java_io_ray_runtime_object_NativeObjectStore_nativeGetAllReferenceCounts(JNIEnv *, jclass); +/* + * Class: io_ray_runtime_object_NativeObjectStore + * Method: nativeGetOwnerAddress + * Signature: ([B)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_io_ray_runtime_object_NativeObjectStore_nativeGetOwnerAddress(JNIEnv *, jclass, + jbyteArray); + /* * Class: io_ray_runtime_object_NativeObjectStore * Method: nativePromoteAndGetOwnershipInfo diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskExecutor.cc b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskExecutor.cc index b1db4f451..2fc4e33ba 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskExecutor.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskExecutor.cc @@ -37,10 +37,7 @@ Java_io_ray_runtime_task_NativeTaskExecutor_nativePrepareCheckpoint(JNIEnv *env, ActorCheckpointID checkpoint_id; auto status = core_worker.PrepareActorCheckpoint(actor_id, &checkpoint_id); THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); - jbyteArray result = env->NewByteArray(checkpoint_id.Size()); - env->SetByteArrayRegion(result, 0, checkpoint_id.Size(), - reinterpret_cast(checkpoint_id.Data())); - return result; + return IdToJavaByteArray(env, checkpoint_id); } JNIEXPORT void JNICALL diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc index 118340e39..2a280a871 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc @@ -62,8 +62,14 @@ inline std::vector> ToTaskArgs(JNIEnv *env, jobjec env->CallObjectMethod(java_id, java_base_id_get_bytes)); RAY_CHECK_JAVA_EXCEPTION(env); auto id = JavaByteArrayToId(env, java_id_bytes); - return std::unique_ptr(new ray::TaskArgByReference( - id, ray::CoreWorkerProcess::GetCoreWorker().GetOwnerAddress(id))); + auto java_owner_address = + env->GetObjectField(arg, java_function_arg_owner_address); + RAY_CHECK(java_owner_address); + auto owner_address = + JavaProtobufObjectToNativeProtobufObject( + env, java_owner_address); + return std::unique_ptr( + new ray::TaskArgByReference(id, owner_address)); } auto java_value = static_cast(env->GetObjectField(arg, java_function_arg_value)); diff --git a/src/ray/core_worker/lib/java/jni_init.cc b/src/ray/core_worker/lib/java/jni_init.cc index 596973020..4f2b05c6b 100644 --- a/src/ray/core_worker/lib/java/jni_init.cc +++ b/src/ray/core_worker/lib/java/jni_init.cc @@ -62,6 +62,9 @@ jmethodID java_jni_exception_util_get_stack_trace; jclass java_base_id_class; jmethodID java_base_id_get_bytes; +jclass java_abstract_message_lite_class; +jmethodID java_abstract_message_lite_to_byte_array; + jclass java_function_descriptor_class; jmethodID java_function_descriptor_get_language; jmethodID java_function_descriptor_to_list; @@ -71,6 +74,7 @@ jmethodID java_language_get_number; jclass java_function_arg_class; jfieldID java_function_arg_id; +jfieldID java_function_arg_owner_address; jfieldID java_function_arg_value; jclass java_base_task_options_class; @@ -183,6 +187,11 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { java_base_id_class = LoadClass(env, "io/ray/api/id/BaseId"); java_base_id_get_bytes = env->GetMethodID(java_base_id_class, "getBytes", "()[B"); + java_abstract_message_lite_class = + LoadClass(env, "com/google/protobuf/AbstractMessage"); + java_abstract_message_lite_to_byte_array = + env->GetMethodID(java_abstract_message_lite_class, "toByteArray", "()[B"); + java_function_descriptor_class = LoadClass(env, "io/ray/runtime/functionmanager/FunctionDescriptor"); java_function_descriptor_get_language = @@ -197,6 +206,9 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { java_function_arg_class = LoadClass(env, "io/ray/runtime/task/FunctionArg"); java_function_arg_id = env->GetFieldID(java_function_arg_class, "id", "Lio/ray/api/id/ObjectId;"); + java_function_arg_owner_address = + env->GetFieldID(java_function_arg_class, "ownerAddress", + "Lio/ray/runtime/generated/Common$Address;"); java_function_arg_value = env->GetFieldID(java_function_arg_class, "value", "Lio/ray/runtime/object/NativeRayObject;"); @@ -278,6 +290,7 @@ void JNI_OnUnload(JavaVM *vm, void *reserved) { env->DeleteGlobalRef(java_ray_intentional_system_exit_exception_class); env->DeleteGlobalRef(java_jni_exception_util_class); env->DeleteGlobalRef(java_base_id_class); + env->DeleteGlobalRef(java_abstract_message_lite_class); env->DeleteGlobalRef(java_function_descriptor_class); env->DeleteGlobalRef(java_language_class); env->DeleteGlobalRef(java_function_arg_class); diff --git a/src/ray/core_worker/lib/java/jni_utils.h b/src/ray/core_worker/lib/java/jni_utils.h index 406cd71d4..5b849d885 100644 --- a/src/ray/core_worker/lib/java/jni_utils.h +++ b/src/ray/core_worker/lib/java/jni_utils.h @@ -107,6 +107,11 @@ extern jclass java_base_id_class; /// getBytes method of BaseId class extern jmethodID java_base_id_get_bytes; +/// AbstractMessageLite class +extern jclass java_abstract_message_lite_class; +/// toByteArray method of AbstractMessageLite class +extern jmethodID java_abstract_message_lite_to_byte_array; + /// FunctionDescriptor interface extern jclass java_function_descriptor_class; /// getLanguage method of FunctionDescriptor interface @@ -123,6 +128,8 @@ extern jmethodID java_language_get_number; extern jclass java_function_arg_class; /// id field of FunctionArg class extern jfieldID java_function_arg_id; +/// ownerAddress field of FunctionArg class +extern jfieldID java_function_arg_owner_address; /// value field of FunctionArg class extern jfieldID java_function_arg_value; @@ -528,6 +535,27 @@ inline jobject NativeRayFunctionDescriptorToJavaStringList( return NativeStringVectorToJavaStringList(env, std::vector()); } +/// Convert a Java protobuf object to a C++ protobuf object +template +inline NativeT JavaProtobufObjectToNativeProtobufObject(JNIEnv *env, jobject java_obj) { + NativeT native_obj; + if (java_obj) { + jbyteArray bytes = static_cast( + env->CallObjectMethod(java_obj, java_abstract_message_lite_to_byte_array)); + RAY_CHECK_JAVA_EXCEPTION(env); + RAY_CHECK(bytes != nullptr); + auto buffer = JavaByteArrayToNativeBuffer(env, bytes); + RAY_CHECK(buffer); + native_obj.ParseFromArray(buffer->Data(), buffer->Size()); + // Destroy the buffer before deleting the local ref of `bytes`. We need to make sure + // that `bytes` is still available when invoking the destructor of + // `JavaByteArrayBuffer`. + buffer.reset(); + env->DeleteLocalRef(bytes); + } + return native_obj; +} + // Return an actor fullname with job id prepended if this tis a global actor. inline std::string GetActorFullName(bool global, std::string name) { if (name.empty()) {