diff --git a/java/runtime/src/main/java/io/ray/runtime/functionmanager/PyFunctionDescriptor.java b/java/runtime/src/main/java/io/ray/runtime/functionmanager/PyFunctionDescriptor.java index f2a6bd315..bbf26c57f 100644 --- a/java/runtime/src/main/java/io/ray/runtime/functionmanager/PyFunctionDescriptor.java +++ b/java/runtime/src/main/java/io/ray/runtime/functionmanager/PyFunctionDescriptor.java @@ -1,5 +1,6 @@ package io.ray.runtime.functionmanager; +import com.google.common.base.Objects; import io.ray.runtime.generated.Common.Language; import java.util.Arrays; import java.util.List; @@ -26,6 +27,25 @@ public class PyFunctionDescriptor implements FunctionDescriptor { return moduleName + "." + className + "." + functionName; } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + PyFunctionDescriptor that = (PyFunctionDescriptor) o; + return Objects.equal(moduleName, that.moduleName) && + Objects.equal(className, that.className) && + Objects.equal(functionName, that.functionName); + } + + @Override + public int hashCode() { + return Objects.hashCode(moduleName, className, functionName); + } + @Override public List toList() { return Arrays.asList(moduleName, className, functionName, "" /* function hash */); diff --git a/java/runtime/src/main/java/io/ray/runtime/task/NativeTaskSubmitter.java b/java/runtime/src/main/java/io/ray/runtime/task/NativeTaskSubmitter.java index 753ef2588..5b0d79efe 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/NativeTaskSubmitter.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/NativeTaskSubmitter.java @@ -1,6 +1,7 @@ package io.ray.runtime.task; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; import io.ray.api.BaseActorHandle; import io.ray.api.id.ObjectId; import io.ray.api.options.ActorCreationOptions; @@ -18,14 +19,19 @@ public class NativeTaskSubmitter implements TaskSubmitter { @Override public List submitTask(FunctionDescriptor functionDescriptor, List args, int numReturns, CallOptions options) { - List returnIds = nativeSubmitTask(functionDescriptor, args, numReturns, options); + List returnIds = nativeSubmitTask(functionDescriptor, functionDescriptor.hashCode(), + args, numReturns, options); + if (returnIds == null) { + return ImmutableList.of(); + } return returnIds.stream().map(ObjectId::new).collect(Collectors.toList()); } @Override public BaseActorHandle createActor(FunctionDescriptor functionDescriptor, List args, ActorCreationOptions options) { - byte[] actorId = nativeCreateActor(functionDescriptor, args, options); + byte[] actorId = nativeCreateActor(functionDescriptor, functionDescriptor.hashCode(), args, + options); return NativeActorHandle.create(actorId, functionDescriptor.getLanguage()); } @@ -35,17 +41,21 @@ public class NativeTaskSubmitter implements TaskSubmitter { List args, int numReturns, CallOptions options) { Preconditions.checkState(actor instanceof NativeActorHandle); List returnIds = nativeSubmitActorTask(actor.getId().getBytes(), - functionDescriptor, args, numReturns, options); + functionDescriptor, functionDescriptor.hashCode(), args, numReturns, options); + if (returnIds == null) { + return ImmutableList.of(); + } return returnIds.stream().map(ObjectId::new).collect(Collectors.toList()); } private static native List nativeSubmitTask(FunctionDescriptor functionDescriptor, - List args, int numReturns, CallOptions callOptions); + int functionDescriptorHash, List args, int numReturns, CallOptions callOptions); private static native byte[] nativeCreateActor(FunctionDescriptor functionDescriptor, - List args, ActorCreationOptions actorCreationOptions); + int functionDescriptorHash, List args, + ActorCreationOptions actorCreationOptions); private static native List nativeSubmitActorTask(byte[] actorId, - FunctionDescriptor functionDescriptor, List args, int numReturns, - CallOptions callOptions); + FunctionDescriptor functionDescriptor, int functionDescriptorHash, List args, + int numReturns, CallOptions callOptions); } diff --git a/src/ray/common/function_descriptor.h b/src/ray/common/function_descriptor.h index b6c7e8f7b..4514e1863 100644 --- a/src/ray/common/function_descriptor.h +++ b/src/ray/common/function_descriptor.h @@ -41,6 +41,11 @@ class FunctionDescriptorInterface : public MessageWrapper()(ray::FunctionDescriptorType::FUNCTION_DESCRIPTOR_NOT_SET); } + inline bool operator==(const EmptyFunctionDescriptor &other) const { return true; } + + inline bool operator!=(const EmptyFunctionDescriptor &other) const { return false; } + virtual std::string ToString() const { return "{type=EmptyFunctionDescriptor}"; } }; @@ -90,17 +99,30 @@ class JavaFunctionDescriptor : public FunctionDescriptorInterface { std::hash()(typed_message_->signature()); } + inline bool operator==(const JavaFunctionDescriptor &other) const { + if (this == &other) { + return true; + } + return this->ClassName() == other.ClassName() && + this->FunctionName() == other.FunctionName() && + this->Signature() == other.Signature(); + } + + inline bool operator!=(const JavaFunctionDescriptor &other) const { + return !(*this == other); + } + virtual std::string ToString() const { return "{type=JavaFunctionDescriptor, class_name=" + typed_message_->class_name() + ", function_name=" + typed_message_->function_name() + ", signature=" + typed_message_->signature() + "}"; } - std::string ClassName() const { return typed_message_->class_name(); } + const std::string &ClassName() const { return typed_message_->class_name(); } - std::string FunctionName() const { return typed_message_->function_name(); } + const std::string &FunctionName() const { return typed_message_->function_name(); } - std::string Signature() const { return typed_message_->signature(); } + const std::string &Signature() const { return typed_message_->signature(); } private: const rpc::JavaFunctionDescriptor *typed_message_; @@ -127,6 +149,20 @@ class PythonFunctionDescriptor : public FunctionDescriptorInterface { std::hash()(typed_message_->function_hash()); } + inline bool operator==(const PythonFunctionDescriptor &other) const { + if (this == &other) { + return true; + } + return this->ModuleName() == other.ModuleName() && + this->ClassName() == other.ClassName() && + this->FunctionName() == other.FunctionName() && + this->FunctionHash() == other.FunctionHash(); + } + + inline bool operator!=(const PythonFunctionDescriptor &other) const { + return !(*this == other); + } + virtual std::string ToString() const { return "{type=PythonFunctionDescriptor, module_name=" + typed_message_->module_name() + @@ -140,13 +176,13 @@ class PythonFunctionDescriptor : public FunctionDescriptorInterface { typed_message_->function_name(); } - std::string ModuleName() const { return typed_message_->module_name(); } + const std::string &ModuleName() const { return typed_message_->module_name(); } - std::string ClassName() const { return typed_message_->class_name(); } + const std::string &ClassName() const { return typed_message_->class_name(); } - std::string FunctionName() const { return typed_message_->function_name(); } + const std::string &FunctionName() const { return typed_message_->function_name(); } - std::string FunctionHash() const { return typed_message_->function_hash(); } + const std::string &FunctionHash() const { return typed_message_->function_hash(); } private: const rpc::PythonFunctionDescriptor *typed_message_; @@ -172,17 +208,30 @@ class CppFunctionDescriptor : public FunctionDescriptorInterface { std::hash()(typed_message_->exec_function_offset()); } + inline bool operator==(const CppFunctionDescriptor &other) const { + if (this == &other) { + return true; + } + return this->LibName() == other.LibName() && + this->FunctionOffset() == other.FunctionOffset() && + this->ExecFunctionOffset() == other.ExecFunctionOffset(); + } + + inline bool operator!=(const CppFunctionDescriptor &other) const { + return !(*this == other); + } + virtual std::string ToString() const { return "{type=CppFunctionDescriptor, lib_name=" + typed_message_->lib_name() + ", function_offset=" + typed_message_->function_offset() + ", exec_function_offset=" + typed_message_->exec_function_offset() + "}"; } - std::string LibName() const { return typed_message_->lib_name(); } + const std::string &LibName() const { return typed_message_->lib_name(); } - std::string FunctionOffset() const { return typed_message_->function_offset(); } + const std::string &FunctionOffset() const { return typed_message_->function_offset(); } - std::string ExecFunctionOffset() const { + const std::string &ExecFunctionOffset() const { return typed_message_->exec_function_offset(); } @@ -193,11 +242,32 @@ class CppFunctionDescriptor : public FunctionDescriptorInterface { typedef std::shared_ptr FunctionDescriptor; inline bool operator==(const FunctionDescriptor &left, const FunctionDescriptor &right) { - if (left.get() != nullptr && right.get() != nullptr && left->Type() == right->Type() && - left->ToString() == right->ToString()) { + if (left.get() == right.get()) { return true; } - return left.get() == right.get(); + if (left.get() == nullptr || right.get() == nullptr) { + return false; + } + if (left->Type() != right->Type()) { + return false; + } + switch (left->Type()) { + case ray::FunctionDescriptorType::FUNCTION_DESCRIPTOR_NOT_SET: + return static_cast(*left) == + static_cast(*right); + case ray::FunctionDescriptorType::kJavaFunctionDescriptor: + return static_cast(*left) == + static_cast(*right); + case ray::FunctionDescriptorType::kPythonFunctionDescriptor: + return static_cast(*left) == + static_cast(*right); + case ray::FunctionDescriptorType::kCppFunctionDescriptor: + return static_cast(*left) == + static_cast(*right); + default: + RAY_LOG(FATAL) << "Unknown function descriptor type: " << left->Type(); + return false; + } } inline bool operator!=(const FunctionDescriptor &left, const FunctionDescriptor &right) { diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc index 6b4293795..6c8455411 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc @@ -25,6 +25,12 @@ thread_local JNIEnv *local_env = nullptr; jobject java_task_executor = nullptr; +/// Store Java instances of function descriptor in the cache to avoid unnessesary JNI +/// operations. +thread_local std::unordered_map>> + executor_function_descriptor_cache; + inline ray::gcs::GcsClientOptions ToGcsClientOptions(JNIEnv *env, jobject gcs_client_options) { std::string ip = JavaStringToNativeString( @@ -73,9 +79,24 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize( RAY_CHECK(env); RAY_CHECK(java_task_executor); + // convert RayFunction - jobject ray_function_array_list = NativeRayFunctionDescriptorToJavaStringList( - env, ray_function.GetFunctionDescriptor()); + auto function_descriptor = ray_function.GetFunctionDescriptor(); + size_t fd_hash = function_descriptor->Hash(); + auto &fd_vector = executor_function_descriptor_cache[fd_hash]; + jobject ray_function_array_list = nullptr; + for (auto &pair : fd_vector) { + if (pair.first == function_descriptor) { + ray_function_array_list = pair.second; + break; + } + } + if (!ray_function_array_list) { + ray_function_array_list = + NativeRayFunctionDescriptorToJavaStringList(env, function_descriptor); + fd_vector.emplace_back(function_descriptor, ray_function_array_list); + } + // convert args // TODO (kfstorm): Avoid copying binary data from Java to C++ jobject args_array_list = NativeVectorToJavaList>( @@ -86,19 +107,20 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize( env->CallObjectMethod(java_task_executor, java_task_executor_execute, ray_function_array_list, args_array_list); RAY_CHECK_JAVA_EXCEPTION(env); - std::vector> return_objects; - JavaListToNativeVector>( - env, java_return_objects, &return_objects, - [](JNIEnv *env, jobject java_native_ray_object) { - return JavaNativeRayObjectToNativeRayObject(env, java_native_ray_object); - }); - for (auto &obj : return_objects) { - results->push_back(obj); + if (!return_ids.empty()) { + std::vector> return_objects; + JavaListToNativeVector>( + env, java_return_objects, &return_objects, + [](JNIEnv *env, jobject java_native_ray_object) { + return JavaNativeRayObjectToNativeRayObject(env, java_native_ray_object); + }); + for (auto &obj : return_objects) { + results->push_back(obj); + } } env->DeleteLocalRef(java_return_objects); env->DeleteLocalRef(args_array_list); - env->DeleteLocalRef(ray_function_array_list); return ray::Status::OK(); }; diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h index 1b9e7d006..338a7dfd8 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h +++ b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h @@ -25,7 +25,7 @@ extern "C" { * Class: io_ray_runtime_RayNativeRuntime * Method: nativeInitialize * Signature: - * (ILjava/lang/String;ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;[BLio/ray/runtime/gcs/GcsClientOptions;ILjava/lang/String;Ljava/util/Map;)V + * (ILjava/lang/String;ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;[BLio/ray/runtime/gcs/GcsClientOptions;Ljava/lang/String;ILjava/lang/String;Ljava/util/Map;[B)V */ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize( JNIEnv *, jclass, jint, jstring, jint, jstring, jstring, jstring, jbyteArray, jobject, @@ -42,7 +42,7 @@ Java_io_ray_runtime_RayNativeRuntime_nativeRunTaskExecutor(JNIEnv *, jclass, job /* * Class: io_ray_runtime_RayNativeRuntime * Method: nativeShutdown - * Signature: ()V + * Signature: (Z)V */ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeShutdown(JNIEnv *, jclass); diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc index cefa121c5..af0e97d99 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc @@ -21,7 +21,19 @@ #include "ray/core_worker/core_worker.h" #include "ray/core_worker/lib/java/jni_utils.h" -inline ray::RayFunction ToRayFunction(JNIEnv *env, jobject functionDescriptor) { +/// Store C++ instances of ray function in the cache to avoid unnessesary JNI operations. +thread_local std::unordered_map>> + submitter_function_descriptor_cache; + +inline const ray::RayFunction &ToRayFunction(JNIEnv *env, jobject functionDescriptor, + jint hash) { + auto &fd_vector = submitter_function_descriptor_cache[hash]; + for (auto &pair : fd_vector) { + if (env->CallBooleanMethod(pair.first, java_object_equals, functionDescriptor)) { + return pair.second; + } + } + std::vector function_descriptor_list; jobject list = env->CallObjectMethod(functionDescriptor, java_function_descriptor_to_list); @@ -35,8 +47,9 @@ inline ray::RayFunction ToRayFunction(JNIEnv *env, jobject functionDescriptor) { RAY_CHECK_JAVA_EXCEPTION(env); ray::FunctionDescriptor function_descriptor = ray::FunctionDescriptorBuilder::FromVector(language, function_descriptor_list); - ray::RayFunction ray_function{language, function_descriptor}; - return ray_function; + fd_vector.emplace_back(env->NewGlobalRef(functionDescriptor), + ray::RayFunction(language, function_descriptor)); + return fd_vector.back().second; } inline std::vector> ToTaskArgs(JNIEnv *env, jobject args) { @@ -129,9 +142,10 @@ extern "C" { #endif JNIEXPORT jobject JNICALL Java_io_ray_runtime_task_NativeTaskSubmitter_nativeSubmitTask( - JNIEnv *env, jclass p, jobject functionDescriptor, jobject args, jint numReturns, - jobject callOptions) { - auto ray_function = ToRayFunction(env, functionDescriptor); + JNIEnv *env, jclass p, jobject functionDescriptor, jint functionDescriptorHash, + jobject args, jint numReturns, jobject callOptions) { + const auto &ray_function = + ToRayFunction(env, functionDescriptor, functionDescriptorHash); auto task_args = ToTaskArgs(env, args); auto task_options = ToTaskOptions(env, numReturns, callOptions); @@ -141,14 +155,20 @@ JNIEXPORT jobject JNICALL Java_io_ray_runtime_task_NativeTaskSubmitter_nativeSub task_options, &return_ids, /*max_retries=*/0); + // This is to avoid creating an empty java list and boost performance. + if (return_ids.empty()) { + return nullptr; + } + return NativeIdVectorToJavaByteArrayList(env, return_ids); } JNIEXPORT jbyteArray JNICALL Java_io_ray_runtime_task_NativeTaskSubmitter_nativeCreateActor( - JNIEnv *env, jclass p, jobject functionDescriptor, jobject args, - jobject actorCreationOptions) { - auto ray_function = ToRayFunction(env, functionDescriptor); + JNIEnv *env, jclass p, jobject functionDescriptor, jint functionDescriptorHash, + jobject args, jobject actorCreationOptions) { + const auto &ray_function = + ToRayFunction(env, functionDescriptor, functionDescriptorHash); auto task_args = ToTaskArgs(env, args); auto actor_creation_options = ToActorCreationOptions(env, actorCreationOptions); @@ -163,10 +183,11 @@ Java_io_ray_runtime_task_NativeTaskSubmitter_nativeCreateActor( JNIEXPORT jobject JNICALL Java_io_ray_runtime_task_NativeTaskSubmitter_nativeSubmitActorTask( - JNIEnv *env, jclass p, jbyteArray actorId, jobject functionDescriptor, jobject args, - jint numReturns, jobject callOptions) { + JNIEnv *env, jclass p, jbyteArray actorId, jobject functionDescriptor, + jint functionDescriptorHash, jobject args, jint numReturns, jobject callOptions) { auto actor_id = JavaByteArrayToId(env, actorId); - auto ray_function = ToRayFunction(env, functionDescriptor); + const auto &ray_function = + ToRayFunction(env, functionDescriptor, functionDescriptorHash); auto task_args = ToTaskArgs(env, args); auto task_options = ToTaskOptions(env, numReturns, callOptions); @@ -174,6 +195,12 @@ Java_io_ray_runtime_task_NativeTaskSubmitter_nativeSubmitActorTask( ray::CoreWorkerProcess::GetCoreWorker().SubmitActorTask( actor_id, ray_function, task_args, task_options, &return_ids); + + // This is to avoid creating an empty java list and boost performance. + if (return_ids.empty()) { + return nullptr; + } + return NativeIdVectorToJavaByteArrayList(env, return_ids); } diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.h b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.h index c657ae323..1863fe311 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.h +++ b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.h @@ -25,31 +25,31 @@ extern "C" { * Class: io_ray_runtime_task_NativeTaskSubmitter * Method: nativeSubmitTask * Signature: - * (Lio/ray/runtime/functionmanager/FunctionDescriptor;Ljava/util/List;ILio/ray/api/options/CallOptions;)Ljava/util/List; + * (Lio/ray/runtime/functionmanager/FunctionDescriptor;ILjava/util/List;ILio/ray/api/options/CallOptions;)Ljava/util/List; */ JNIEXPORT jobject JNICALL Java_io_ray_runtime_task_NativeTaskSubmitter_nativeSubmitTask( - JNIEnv *, jclass, jobject, jobject, jint, jobject); + JNIEnv *, jclass, jobject, jint, jobject, jint, jobject); /* * Class: io_ray_runtime_task_NativeTaskSubmitter * Method: nativeCreateActor * Signature: - * (Lio/ray/runtime/functionmanager/FunctionDescriptor;Ljava/util/List;Lio/ray/api/options/ActorCreationOptions;)[B + * (Lio/ray/runtime/functionmanager/FunctionDescriptor;ILjava/util/List;Lio/ray/api/options/ActorCreationOptions;)[B */ JNIEXPORT jbyteArray JNICALL Java_io_ray_runtime_task_NativeTaskSubmitter_nativeCreateActor(JNIEnv *, jclass, jobject, - jobject, jobject); + jint, jobject, jobject); /* * Class: io_ray_runtime_task_NativeTaskSubmitter * Method: nativeSubmitActorTask * Signature: - * ([BLio/ray/runtime/functionmanager/FunctionDescriptor;Ljava/util/List;ILio/ray/api/options/CallOptions;)Ljava/util/List; + * ([BLio/ray/runtime/functionmanager/FunctionDescriptor;ILjava/util/List;ILio/ray/api/options/CallOptions;)Ljava/util/List; */ JNIEXPORT jobject JNICALL Java_io_ray_runtime_task_NativeTaskSubmitter_nativeSubmitActorTask(JNIEnv *, jclass, jbyteArray, jobject, - jobject, jint, + jint, jobject, jint, jobject); #ifdef __cplusplus diff --git a/src/ray/core_worker/lib/java/jni_init.cc b/src/ray/core_worker/lib/java/jni_init.cc index 0b0a668db..1805112f6 100644 --- a/src/ray/core_worker/lib/java/jni_init.cc +++ b/src/ray/core_worker/lib/java/jni_init.cc @@ -20,6 +20,9 @@ jmethodID java_boolean_init; jclass java_double_class; jmethodID java_double_double_value; +jclass java_object_class; +jmethodID java_object_equals; + jclass java_list_class; jmethodID java_list_size; jmethodID java_list_get; @@ -108,6 +111,10 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { java_double_class = LoadClass(env, "java/lang/Double"); java_double_double_value = env->GetMethodID(java_double_class, "doubleValue", "()D"); + java_object_class = LoadClass(env, "java/lang/Object"); + java_object_equals = + env->GetMethodID(java_object_class, "equals", "(Ljava/lang/Object;)Z"); + java_list_class = LoadClass(env, "java/util/List"); java_list_size = env->GetMethodID(java_list_class, "size", "()I"); java_list_get = env->GetMethodID(java_list_class, "get", "(I)Ljava/lang/Object;"); @@ -205,6 +212,7 @@ void JNI_OnUnload(JavaVM *vm, void *reserved) { env->DeleteGlobalRef(java_boolean_class); env->DeleteGlobalRef(java_double_class); + env->DeleteGlobalRef(java_object_class); env->DeleteGlobalRef(java_list_class); env->DeleteGlobalRef(java_array_list_class); env->DeleteGlobalRef(java_map_class); diff --git a/src/ray/core_worker/lib/java/jni_utils.h b/src/ray/core_worker/lib/java/jni_utils.h index 15280cc8a..f13d6f7b4 100644 --- a/src/ray/core_worker/lib/java/jni_utils.h +++ b/src/ray/core_worker/lib/java/jni_utils.h @@ -32,6 +32,11 @@ extern jclass java_double_class; /// doubleValue method of Double class extern jmethodID java_double_double_value; +/// Object class +extern jclass java_object_class; +/// equals method of Object class +extern jmethodID java_object_equals; + /// List class extern jclass java_list_class; /// size method of List class