diff --git a/java/runtime/src/main/java/org/ray/runtime/util/JniExceptionUtil.java b/java/runtime/src/main/java/org/ray/runtime/util/JniExceptionUtil.java new file mode 100644 index 000000000..bd58af6cc --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/util/JniExceptionUtil.java @@ -0,0 +1,19 @@ +package org.ray.runtime.util; + +import org.apache.commons.lang3.exception.ExceptionUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +// Required by JNI macro RAY_CHECK_JAVA_EXCEPTION +public final class JniExceptionUtil { + + private static final Logger LOGGER = LoggerFactory.getLogger(JniExceptionUtil.class); + + public static String getStackTrace(String fileName, int lineNumber, String function, + Throwable throwable) { + LOGGER.error("An unexpected exception occurred while executing Java code from JNI ({}:{} {}).", + fileName, lineNumber, function, throwable); + // Return the exception in string form to JNI. + return ExceptionUtils.getStackTrace(throwable); + } +} diff --git a/src/ray/core_worker/lib/java/jni_init.cc b/src/ray/core_worker/lib/java/jni_init.cc index ec9511239..3ee763694 100644 --- a/src/ray/core_worker/lib/java/jni_init.cc +++ b/src/ray/core_worker/lib/java/jni_init.cc @@ -31,6 +31,9 @@ jmethodID java_map_entry_get_value; jclass java_ray_exception_class; +jclass java_jni_exception_util_class; +jmethodID java_jni_exception_util_get_stack_trace; + jclass java_base_id_class; jmethodID java_base_id_get_bytes; @@ -122,6 +125,11 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { java_ray_exception_class = LoadClass(env, "org/ray/api/exception/RayException"); + java_jni_exception_util_class = LoadClass(env, "org/ray/runtime/util/JniExceptionUtil"); + java_jni_exception_util_get_stack_trace = env->GetStaticMethodID( + java_jni_exception_util_class, "getStackTrace", + "(Ljava/lang/String;ILjava/lang/String;Ljava/lang/Throwable;)Ljava/lang/String;"); + java_base_id_class = LoadClass(env, "org/ray/api/id/BaseId"); java_base_id_get_bytes = env->GetMethodID(java_base_id_class, "getBytes", "()[B"); @@ -195,6 +203,7 @@ void JNI_OnUnload(JavaVM *vm, void *reserved) { env->DeleteGlobalRef(java_iterator_class); env->DeleteGlobalRef(java_map_entry_class); env->DeleteGlobalRef(java_ray_exception_class); + env->DeleteGlobalRef(java_jni_exception_util_class); env->DeleteGlobalRef(java_base_id_class); env->DeleteGlobalRef(java_function_descriptor_class); env->DeleteGlobalRef(java_language_class); diff --git a/src/ray/core_worker/lib/java/jni_utils.h b/src/ray/core_worker/lib/java/jni_utils.h index 75e29a785..2024c9b3d 100644 --- a/src/ray/core_worker/lib/java/jni_utils.h +++ b/src/ray/core_worker/lib/java/jni_utils.h @@ -60,6 +60,11 @@ extern jmethodID java_map_entry_get_value; /// RayException class extern jclass java_ray_exception_class; +/// JniExceptionUtil class +extern jclass java_jni_exception_util_class; +/// getStackTrace method of JniExceptionUtil class +extern jmethodID java_jni_exception_util_get_stack_trace; + /// BaseId class extern jclass java_base_id_class; /// getBytes method of BaseId class @@ -136,6 +141,29 @@ extern JavaVM *jvm; } \ } +#define RAY_CHECK_JAVA_EXCEPTION(env) \ + { \ + jthrowable throwable = env->ExceptionOccurred(); \ + if (throwable) { \ + jstring java_file_name = env->NewStringUTF(__FILE__); \ + jstring java_function = env->NewStringUTF(__func__); \ + jobject java_error_message = env->CallStaticObjectMethod( \ + java_jni_exception_util_class, java_jni_exception_util_get_stack_trace, \ + java_file_name, __LINE__, java_function, throwable); \ + std::string error_message = \ + JavaStringToNativeString(env, static_cast(java_error_message)); \ + env->DeleteLocalRef(throwable); \ + env->DeleteLocalRef(java_file_name); \ + env->DeleteLocalRef(java_function); \ + env->DeleteLocalRef(java_error_message); \ + RAY_LOG(FATAL) << "An unexpected exception occurred while executing Java code " \ + "from JNI (" \ + << __FILE__ << ":" << __LINE__ << " " << __func__ << ")." \ + << "\n" \ + << error_message; \ + } \ + } + /// Represents a byte buffer of Java byte array. /// The destructor will automatically call ReleaseByteArrayElements. /// NOTE: Instances of this class cannot be used across threads. @@ -204,9 +232,11 @@ inline void JavaListToNativeVector( JNIEnv *env, jobject java_list, std::vector *native_vector, std::function element_converter) { int size = env->CallIntMethod(java_list, java_list_size); + RAY_CHECK_JAVA_EXCEPTION(env); native_vector->clear(); for (int i = 0; i < size; i++) { auto element = env->CallObjectMethod(java_list, java_list_get, (jint)i); + RAY_CHECK_JAVA_EXCEPTION(env); native_vector->emplace_back(element_converter(env, element)); env->DeleteLocalRef(element); } @@ -232,6 +262,7 @@ inline jobject NativeVectorToJavaList( for (const auto &item : native_vector) { auto element = element_converter(env, item); env->CallVoidMethod(java_list, java_list_add, element); + RAY_CHECK_JAVA_EXCEPTION(env); env->DeleteLocalRef(element); } return java_list; diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc b/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc index 52b0653f7..901c84660 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc +++ b/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc @@ -55,6 +55,7 @@ JNIEXPORT jlong JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeInitCoreWork jobject java_return_objects = env->CallObjectMethod(local_java_task_executor, java_task_executor_execute, ray_function_array_list, args_array_list); + RAY_CHECK_JAVA_EXCEPTION(env); std::vector> return_objects; JavaListToNativeVector>( env, java_return_objects, &return_objects, diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc index 632f8db96..a4141306a 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc +++ b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc @@ -11,12 +11,15 @@ inline ray::CoreWorker &GetCoreWorker(jlong nativeCoreWorkerPointer) { inline ray::RayFunction ToRayFunction(JNIEnv *env, jobject functionDescriptor) { std::vector function_descriptor; - JavaStringListToNativeStringVector( - env, env->CallObjectMethod(functionDescriptor, java_function_descriptor_to_list), - &function_descriptor); + jobject list = + env->CallObjectMethod(functionDescriptor, java_function_descriptor_to_list); + RAY_CHECK_JAVA_EXCEPTION(env); + JavaStringListToNativeStringVector(env, list, &function_descriptor); jobject java_language = env->CallObjectMethod(functionDescriptor, java_function_descriptor_get_language); + RAY_CHECK_JAVA_EXCEPTION(env); int language = env->CallIntMethod(java_language, java_language_get_number); + RAY_CHECK_JAVA_EXCEPTION(env); ray::RayFunction ray_function{static_cast<::Language>(language), function_descriptor}; return ray_function; } @@ -29,6 +32,7 @@ inline std::vector ToTaskArgs(JNIEnv *env, jobject args) { if (java_id) { auto java_id_bytes = static_cast( env->CallObjectMethod(java_id, java_base_id_get_bytes)); + RAY_CHECK_JAVA_EXCEPTION(env); return ray::TaskArg::PassByReference( JavaByteArrayToId(env, java_id_bytes)); } @@ -46,16 +50,23 @@ inline std::unordered_map ToResources(JNIEnv *env, std::unordered_map resources; if (java_resources) { jobject entry_set = env->CallObjectMethod(java_resources, java_map_entry_set); + RAY_CHECK_JAVA_EXCEPTION(env); jobject iterator = env->CallObjectMethod(entry_set, java_set_iterator); + RAY_CHECK_JAVA_EXCEPTION(env); while (env->CallBooleanMethod(iterator, java_iterator_has_next)) { + RAY_CHECK_JAVA_EXCEPTION(env); jobject map_entry = env->CallObjectMethod(iterator, java_iterator_next); - std::string key = JavaStringToNativeString( - env, (jstring)env->CallObjectMethod(map_entry, java_map_entry_get_key)); - double value = env->CallDoubleMethod( - env->CallObjectMethod(map_entry, java_map_entry_get_value), - java_double_double_value); + RAY_CHECK_JAVA_EXCEPTION(env); + auto java_key = (jstring)env->CallObjectMethod(map_entry, java_map_entry_get_key); + RAY_CHECK_JAVA_EXCEPTION(env); + std::string key = JavaStringToNativeString(env, java_key); + auto java_value = env->CallObjectMethod(map_entry, java_map_entry_get_value); + RAY_CHECK_JAVA_EXCEPTION(env); + double value = env->CallDoubleMethod(java_value, java_double_double_value); + RAY_CHECK_JAVA_EXCEPTION(env); resources.emplace(key, value); } + RAY_CHECK_JAVA_EXCEPTION(env); } return resources; }