[Java] Improve JNI performance when submitting and executing tasks (#9032)

This commit is contained in:
Kai Yang
2020-07-10 17:51:07 +08:00
committed by GitHub
parent d49dadf891
commit a98cd0670e
9 changed files with 213 additions and 51 deletions
@@ -1,5 +1,6 @@
package io.ray.runtime.functionmanager;
import com.google.common.base.Objects;
import io.ray.runtime.generated.Common.Language;
import java.util.Arrays;
import java.util.List;
@@ -26,6 +27,25 @@ public class PyFunctionDescriptor implements FunctionDescriptor {
return moduleName + "." + className + "." + functionName;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
PyFunctionDescriptor that = (PyFunctionDescriptor) o;
return Objects.equal(moduleName, that.moduleName) &&
Objects.equal(className, that.className) &&
Objects.equal(functionName, that.functionName);
}
@Override
public int hashCode() {
return Objects.hashCode(moduleName, className, functionName);
}
@Override
public List<String> toList() {
return Arrays.asList(moduleName, className, functionName, "" /* function hash */);
@@ -1,6 +1,7 @@
package io.ray.runtime.task;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import io.ray.api.BaseActorHandle;
import io.ray.api.id.ObjectId;
import io.ray.api.options.ActorCreationOptions;
@@ -18,14 +19,19 @@ public class NativeTaskSubmitter implements TaskSubmitter {
@Override
public List<ObjectId> submitTask(FunctionDescriptor functionDescriptor, List<FunctionArg> args,
int numReturns, CallOptions options) {
List<byte[]> returnIds = nativeSubmitTask(functionDescriptor, args, numReturns, options);
List<byte[]> returnIds = nativeSubmitTask(functionDescriptor, functionDescriptor.hashCode(),
args, numReturns, options);
if (returnIds == null) {
return ImmutableList.of();
}
return returnIds.stream().map(ObjectId::new).collect(Collectors.toList());
}
@Override
public BaseActorHandle createActor(FunctionDescriptor functionDescriptor, List<FunctionArg> args,
ActorCreationOptions options) {
byte[] actorId = nativeCreateActor(functionDescriptor, args, options);
byte[] actorId = nativeCreateActor(functionDescriptor, functionDescriptor.hashCode(), args,
options);
return NativeActorHandle.create(actorId, functionDescriptor.getLanguage());
}
@@ -35,17 +41,21 @@ public class NativeTaskSubmitter implements TaskSubmitter {
List<FunctionArg> args, int numReturns, CallOptions options) {
Preconditions.checkState(actor instanceof NativeActorHandle);
List<byte[]> returnIds = nativeSubmitActorTask(actor.getId().getBytes(),
functionDescriptor, args, numReturns, options);
functionDescriptor, functionDescriptor.hashCode(), args, numReturns, options);
if (returnIds == null) {
return ImmutableList.of();
}
return returnIds.stream().map(ObjectId::new).collect(Collectors.toList());
}
private static native List<byte[]> nativeSubmitTask(FunctionDescriptor functionDescriptor,
List<FunctionArg> args, int numReturns, CallOptions callOptions);
int functionDescriptorHash, List<FunctionArg> args, int numReturns, CallOptions callOptions);
private static native byte[] nativeCreateActor(FunctionDescriptor functionDescriptor,
List<FunctionArg> args, ActorCreationOptions actorCreationOptions);
int functionDescriptorHash, List<FunctionArg> args,
ActorCreationOptions actorCreationOptions);
private static native List<byte[]> nativeSubmitActorTask(byte[] actorId,
FunctionDescriptor functionDescriptor, List<FunctionArg> args, int numReturns,
CallOptions callOptions);
FunctionDescriptor functionDescriptor, int functionDescriptorHash, List<FunctionArg> args,
int numReturns, CallOptions callOptions);
}
+83 -13
View File
@@ -41,6 +41,11 @@ class FunctionDescriptorInterface : public MessageWrapper<rpc::FunctionDescripto
virtual size_t Hash() const = 0;
// DO NOT define operator==() or operator!=() in the base class.
// Let the derived classes define and implement.
// This is to avoid unexpected behaviors when comparing function descriptors of
// different declard types, as in this case, the base class version is invoked.
virtual std::string ToString() const = 0;
// A one-word summary of the function call site (e.g., __main__.foo).
@@ -67,6 +72,10 @@ class EmptyFunctionDescriptor : public FunctionDescriptorInterface {
return std::hash<int>()(ray::FunctionDescriptorType::FUNCTION_DESCRIPTOR_NOT_SET);
}
inline bool operator==(const EmptyFunctionDescriptor &other) const { return true; }
inline bool operator!=(const EmptyFunctionDescriptor &other) const { return false; }
virtual std::string ToString() const { return "{type=EmptyFunctionDescriptor}"; }
};
@@ -90,17 +99,30 @@ class JavaFunctionDescriptor : public FunctionDescriptorInterface {
std::hash<std::string>()(typed_message_->signature());
}
inline bool operator==(const JavaFunctionDescriptor &other) const {
if (this == &other) {
return true;
}
return this->ClassName() == other.ClassName() &&
this->FunctionName() == other.FunctionName() &&
this->Signature() == other.Signature();
}
inline bool operator!=(const JavaFunctionDescriptor &other) const {
return !(*this == other);
}
virtual std::string ToString() const {
return "{type=JavaFunctionDescriptor, class_name=" + typed_message_->class_name() +
", function_name=" + typed_message_->function_name() +
", signature=" + typed_message_->signature() + "}";
}
std::string ClassName() const { return typed_message_->class_name(); }
const std::string &ClassName() const { return typed_message_->class_name(); }
std::string FunctionName() const { return typed_message_->function_name(); }
const std::string &FunctionName() const { return typed_message_->function_name(); }
std::string Signature() const { return typed_message_->signature(); }
const std::string &Signature() const { return typed_message_->signature(); }
private:
const rpc::JavaFunctionDescriptor *typed_message_;
@@ -127,6 +149,20 @@ class PythonFunctionDescriptor : public FunctionDescriptorInterface {
std::hash<std::string>()(typed_message_->function_hash());
}
inline bool operator==(const PythonFunctionDescriptor &other) const {
if (this == &other) {
return true;
}
return this->ModuleName() == other.ModuleName() &&
this->ClassName() == other.ClassName() &&
this->FunctionName() == other.FunctionName() &&
this->FunctionHash() == other.FunctionHash();
}
inline bool operator!=(const PythonFunctionDescriptor &other) const {
return !(*this == other);
}
virtual std::string ToString() const {
return "{type=PythonFunctionDescriptor, module_name=" +
typed_message_->module_name() +
@@ -140,13 +176,13 @@ class PythonFunctionDescriptor : public FunctionDescriptorInterface {
typed_message_->function_name();
}
std::string ModuleName() const { return typed_message_->module_name(); }
const std::string &ModuleName() const { return typed_message_->module_name(); }
std::string ClassName() const { return typed_message_->class_name(); }
const std::string &ClassName() const { return typed_message_->class_name(); }
std::string FunctionName() const { return typed_message_->function_name(); }
const std::string &FunctionName() const { return typed_message_->function_name(); }
std::string FunctionHash() const { return typed_message_->function_hash(); }
const std::string &FunctionHash() const { return typed_message_->function_hash(); }
private:
const rpc::PythonFunctionDescriptor *typed_message_;
@@ -172,17 +208,30 @@ class CppFunctionDescriptor : public FunctionDescriptorInterface {
std::hash<std::string>()(typed_message_->exec_function_offset());
}
inline bool operator==(const CppFunctionDescriptor &other) const {
if (this == &other) {
return true;
}
return this->LibName() == other.LibName() &&
this->FunctionOffset() == other.FunctionOffset() &&
this->ExecFunctionOffset() == other.ExecFunctionOffset();
}
inline bool operator!=(const CppFunctionDescriptor &other) const {
return !(*this == other);
}
virtual std::string ToString() const {
return "{type=CppFunctionDescriptor, lib_name=" + typed_message_->lib_name() +
", function_offset=" + typed_message_->function_offset() +
", exec_function_offset=" + typed_message_->exec_function_offset() + "}";
}
std::string LibName() const { return typed_message_->lib_name(); }
const std::string &LibName() const { return typed_message_->lib_name(); }
std::string FunctionOffset() const { return typed_message_->function_offset(); }
const std::string &FunctionOffset() const { return typed_message_->function_offset(); }
std::string ExecFunctionOffset() const {
const std::string &ExecFunctionOffset() const {
return typed_message_->exec_function_offset();
}
@@ -193,11 +242,32 @@ class CppFunctionDescriptor : public FunctionDescriptorInterface {
typedef std::shared_ptr<FunctionDescriptorInterface> FunctionDescriptor;
inline bool operator==(const FunctionDescriptor &left, const FunctionDescriptor &right) {
if (left.get() != nullptr && right.get() != nullptr && left->Type() == right->Type() &&
left->ToString() == right->ToString()) {
if (left.get() == right.get()) {
return true;
}
return left.get() == right.get();
if (left.get() == nullptr || right.get() == nullptr) {
return false;
}
if (left->Type() != right->Type()) {
return false;
}
switch (left->Type()) {
case ray::FunctionDescriptorType::FUNCTION_DESCRIPTOR_NOT_SET:
return static_cast<const EmptyFunctionDescriptor &>(*left) ==
static_cast<const EmptyFunctionDescriptor &>(*right);
case ray::FunctionDescriptorType::kJavaFunctionDescriptor:
return static_cast<const JavaFunctionDescriptor &>(*left) ==
static_cast<const JavaFunctionDescriptor &>(*right);
case ray::FunctionDescriptorType::kPythonFunctionDescriptor:
return static_cast<const PythonFunctionDescriptor &>(*left) ==
static_cast<const PythonFunctionDescriptor &>(*right);
case ray::FunctionDescriptorType::kCppFunctionDescriptor:
return static_cast<const CppFunctionDescriptor &>(*left) ==
static_cast<const CppFunctionDescriptor &>(*right);
default:
RAY_LOG(FATAL) << "Unknown function descriptor type: " << left->Type();
return false;
}
}
inline bool operator!=(const FunctionDescriptor &left, const FunctionDescriptor &right) {
@@ -25,6 +25,12 @@
thread_local JNIEnv *local_env = nullptr;
jobject java_task_executor = nullptr;
/// Store Java instances of function descriptor in the cache to avoid unnessesary JNI
/// operations.
thread_local std::unordered_map<size_t,
std::vector<std::pair<ray::FunctionDescriptor, jobject>>>
executor_function_descriptor_cache;
inline ray::gcs::GcsClientOptions ToGcsClientOptions(JNIEnv *env,
jobject gcs_client_options) {
std::string ip = JavaStringToNativeString(
@@ -73,9 +79,24 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
RAY_CHECK(env);
RAY_CHECK(java_task_executor);
// convert RayFunction
jobject ray_function_array_list = NativeRayFunctionDescriptorToJavaStringList(
env, ray_function.GetFunctionDescriptor());
auto function_descriptor = ray_function.GetFunctionDescriptor();
size_t fd_hash = function_descriptor->Hash();
auto &fd_vector = executor_function_descriptor_cache[fd_hash];
jobject ray_function_array_list = nullptr;
for (auto &pair : fd_vector) {
if (pair.first == function_descriptor) {
ray_function_array_list = pair.second;
break;
}
}
if (!ray_function_array_list) {
ray_function_array_list =
NativeRayFunctionDescriptorToJavaStringList(env, function_descriptor);
fd_vector.emplace_back(function_descriptor, ray_function_array_list);
}
// convert args
// TODO (kfstorm): Avoid copying binary data from Java to C++
jobject args_array_list = NativeVectorToJavaList<std::shared_ptr<ray::RayObject>>(
@@ -86,19 +107,20 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
env->CallObjectMethod(java_task_executor, java_task_executor_execute,
ray_function_array_list, args_array_list);
RAY_CHECK_JAVA_EXCEPTION(env);
std::vector<std::shared_ptr<ray::RayObject>> return_objects;
JavaListToNativeVector<std::shared_ptr<ray::RayObject>>(
env, java_return_objects, &return_objects,
[](JNIEnv *env, jobject java_native_ray_object) {
return JavaNativeRayObjectToNativeRayObject(env, java_native_ray_object);
});
for (auto &obj : return_objects) {
results->push_back(obj);
if (!return_ids.empty()) {
std::vector<std::shared_ptr<ray::RayObject>> return_objects;
JavaListToNativeVector<std::shared_ptr<ray::RayObject>>(
env, java_return_objects, &return_objects,
[](JNIEnv *env, jobject java_native_ray_object) {
return JavaNativeRayObjectToNativeRayObject(env, java_native_ray_object);
});
for (auto &obj : return_objects) {
results->push_back(obj);
}
}
env->DeleteLocalRef(java_return_objects);
env->DeleteLocalRef(args_array_list);
env->DeleteLocalRef(ray_function_array_list);
return ray::Status::OK();
};
@@ -25,7 +25,7 @@ extern "C" {
* Class: io_ray_runtime_RayNativeRuntime
* Method: nativeInitialize
* Signature:
* (ILjava/lang/String;ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;[BLio/ray/runtime/gcs/GcsClientOptions;ILjava/lang/String;Ljava/util/Map;)V
* (ILjava/lang/String;ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;[BLio/ray/runtime/gcs/GcsClientOptions;Ljava/lang/String;ILjava/lang/String;Ljava/util/Map;[B)V
*/
JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
JNIEnv *, jclass, jint, jstring, jint, jstring, jstring, jstring, jbyteArray, jobject,
@@ -42,7 +42,7 @@ Java_io_ray_runtime_RayNativeRuntime_nativeRunTaskExecutor(JNIEnv *, jclass, job
/*
* Class: io_ray_runtime_RayNativeRuntime
* Method: nativeShutdown
* Signature: ()V
* Signature: (Z)V
*/
JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeShutdown(JNIEnv *,
jclass);
@@ -21,7 +21,19 @@
#include "ray/core_worker/core_worker.h"
#include "ray/core_worker/lib/java/jni_utils.h"
inline ray::RayFunction ToRayFunction(JNIEnv *env, jobject functionDescriptor) {
/// Store C++ instances of ray function in the cache to avoid unnessesary JNI operations.
thread_local std::unordered_map<jint, std::vector<std::pair<jobject, ray::RayFunction>>>
submitter_function_descriptor_cache;
inline const ray::RayFunction &ToRayFunction(JNIEnv *env, jobject functionDescriptor,
jint hash) {
auto &fd_vector = submitter_function_descriptor_cache[hash];
for (auto &pair : fd_vector) {
if (env->CallBooleanMethod(pair.first, java_object_equals, functionDescriptor)) {
return pair.second;
}
}
std::vector<std::string> function_descriptor_list;
jobject list =
env->CallObjectMethod(functionDescriptor, java_function_descriptor_to_list);
@@ -35,8 +47,9 @@ inline ray::RayFunction ToRayFunction(JNIEnv *env, jobject functionDescriptor) {
RAY_CHECK_JAVA_EXCEPTION(env);
ray::FunctionDescriptor function_descriptor =
ray::FunctionDescriptorBuilder::FromVector(language, function_descriptor_list);
ray::RayFunction ray_function{language, function_descriptor};
return ray_function;
fd_vector.emplace_back(env->NewGlobalRef(functionDescriptor),
ray::RayFunction(language, function_descriptor));
return fd_vector.back().second;
}
inline std::vector<std::unique_ptr<ray::TaskArg>> ToTaskArgs(JNIEnv *env, jobject args) {
@@ -129,9 +142,10 @@ extern "C" {
#endif
JNIEXPORT jobject JNICALL Java_io_ray_runtime_task_NativeTaskSubmitter_nativeSubmitTask(
JNIEnv *env, jclass p, jobject functionDescriptor, jobject args, jint numReturns,
jobject callOptions) {
auto ray_function = ToRayFunction(env, functionDescriptor);
JNIEnv *env, jclass p, jobject functionDescriptor, jint functionDescriptorHash,
jobject args, jint numReturns, jobject callOptions) {
const auto &ray_function =
ToRayFunction(env, functionDescriptor, functionDescriptorHash);
auto task_args = ToTaskArgs(env, args);
auto task_options = ToTaskOptions(env, numReturns, callOptions);
@@ -141,14 +155,20 @@ JNIEXPORT jobject JNICALL Java_io_ray_runtime_task_NativeTaskSubmitter_nativeSub
task_options, &return_ids,
/*max_retries=*/0);
// This is to avoid creating an empty java list and boost performance.
if (return_ids.empty()) {
return nullptr;
}
return NativeIdVectorToJavaByteArrayList(env, return_ids);
}
JNIEXPORT jbyteArray JNICALL
Java_io_ray_runtime_task_NativeTaskSubmitter_nativeCreateActor(
JNIEnv *env, jclass p, jobject functionDescriptor, jobject args,
jobject actorCreationOptions) {
auto ray_function = ToRayFunction(env, functionDescriptor);
JNIEnv *env, jclass p, jobject functionDescriptor, jint functionDescriptorHash,
jobject args, jobject actorCreationOptions) {
const auto &ray_function =
ToRayFunction(env, functionDescriptor, functionDescriptorHash);
auto task_args = ToTaskArgs(env, args);
auto actor_creation_options = ToActorCreationOptions(env, actorCreationOptions);
@@ -163,10 +183,11 @@ Java_io_ray_runtime_task_NativeTaskSubmitter_nativeCreateActor(
JNIEXPORT jobject JNICALL
Java_io_ray_runtime_task_NativeTaskSubmitter_nativeSubmitActorTask(
JNIEnv *env, jclass p, jbyteArray actorId, jobject functionDescriptor, jobject args,
jint numReturns, jobject callOptions) {
JNIEnv *env, jclass p, jbyteArray actorId, jobject functionDescriptor,
jint functionDescriptorHash, jobject args, jint numReturns, jobject callOptions) {
auto actor_id = JavaByteArrayToId<ray::ActorID>(env, actorId);
auto ray_function = ToRayFunction(env, functionDescriptor);
const auto &ray_function =
ToRayFunction(env, functionDescriptor, functionDescriptorHash);
auto task_args = ToTaskArgs(env, args);
auto task_options = ToTaskOptions(env, numReturns, callOptions);
@@ -174,6 +195,12 @@ Java_io_ray_runtime_task_NativeTaskSubmitter_nativeSubmitActorTask(
ray::CoreWorkerProcess::GetCoreWorker().SubmitActorTask(
actor_id, ray_function, task_args, task_options, &return_ids);
// This is to avoid creating an empty java list and boost performance.
if (return_ids.empty()) {
return nullptr;
}
return NativeIdVectorToJavaByteArrayList(env, return_ids);
}
@@ -25,31 +25,31 @@ extern "C" {
* Class: io_ray_runtime_task_NativeTaskSubmitter
* Method: nativeSubmitTask
* Signature:
* (Lio/ray/runtime/functionmanager/FunctionDescriptor;Ljava/util/List;ILio/ray/api/options/CallOptions;)Ljava/util/List;
* (Lio/ray/runtime/functionmanager/FunctionDescriptor;ILjava/util/List;ILio/ray/api/options/CallOptions;)Ljava/util/List;
*/
JNIEXPORT jobject JNICALL Java_io_ray_runtime_task_NativeTaskSubmitter_nativeSubmitTask(
JNIEnv *, jclass, jobject, jobject, jint, jobject);
JNIEnv *, jclass, jobject, jint, jobject, jint, jobject);
/*
* Class: io_ray_runtime_task_NativeTaskSubmitter
* Method: nativeCreateActor
* Signature:
* (Lio/ray/runtime/functionmanager/FunctionDescriptor;Ljava/util/List;Lio/ray/api/options/ActorCreationOptions;)[B
* (Lio/ray/runtime/functionmanager/FunctionDescriptor;ILjava/util/List;Lio/ray/api/options/ActorCreationOptions;)[B
*/
JNIEXPORT jbyteArray JNICALL
Java_io_ray_runtime_task_NativeTaskSubmitter_nativeCreateActor(JNIEnv *, jclass, jobject,
jobject, jobject);
jint, jobject, jobject);
/*
* Class: io_ray_runtime_task_NativeTaskSubmitter
* Method: nativeSubmitActorTask
* Signature:
* ([BLio/ray/runtime/functionmanager/FunctionDescriptor;Ljava/util/List;ILio/ray/api/options/CallOptions;)Ljava/util/List;
* ([BLio/ray/runtime/functionmanager/FunctionDescriptor;ILjava/util/List;ILio/ray/api/options/CallOptions;)Ljava/util/List;
*/
JNIEXPORT jobject JNICALL
Java_io_ray_runtime_task_NativeTaskSubmitter_nativeSubmitActorTask(JNIEnv *, jclass,
jbyteArray, jobject,
jobject, jint,
jint, jobject, jint,
jobject);
#ifdef __cplusplus
+8
View File
@@ -20,6 +20,9 @@ jmethodID java_boolean_init;
jclass java_double_class;
jmethodID java_double_double_value;
jclass java_object_class;
jmethodID java_object_equals;
jclass java_list_class;
jmethodID java_list_size;
jmethodID java_list_get;
@@ -108,6 +111,10 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) {
java_double_class = LoadClass(env, "java/lang/Double");
java_double_double_value = env->GetMethodID(java_double_class, "doubleValue", "()D");
java_object_class = LoadClass(env, "java/lang/Object");
java_object_equals =
env->GetMethodID(java_object_class, "equals", "(Ljava/lang/Object;)Z");
java_list_class = LoadClass(env, "java/util/List");
java_list_size = env->GetMethodID(java_list_class, "size", "()I");
java_list_get = env->GetMethodID(java_list_class, "get", "(I)Ljava/lang/Object;");
@@ -205,6 +212,7 @@ void JNI_OnUnload(JavaVM *vm, void *reserved) {
env->DeleteGlobalRef(java_boolean_class);
env->DeleteGlobalRef(java_double_class);
env->DeleteGlobalRef(java_object_class);
env->DeleteGlobalRef(java_list_class);
env->DeleteGlobalRef(java_array_list_class);
env->DeleteGlobalRef(java_map_class);
+5
View File
@@ -32,6 +32,11 @@ extern jclass java_double_class;
/// doubleValue method of Double class
extern jmethodID java_double_double_value;
/// Object class
extern jclass java_object_class;
/// equals method of Object class
extern jmethodID java_object_equals;
/// List class
extern jclass java_list_class;
/// size method of List class