From 3a66d47a3a4b3fc31f82eb26fdff2ea282f85e1b Mon Sep 17 00:00:00 2001
From: Yuhong Guo
Date: Sat, 9 Feb 2019 18:10:22 +0800
Subject: [PATCH] Remove RAY_CHECK from JNI code (#3978)

* Remove RAY_CHECK in JNI

* Try to add mvn test to test the exception.

* Refine

* Address comments

* Return from native code as soon as a Java exception is pending, and pass
  caught exceptions to SLF4J as throwables instead of via String.format.
---
 .../ray/runtime/raylet/RayletClientImpl.java  | 22 ++++-----
 .../org/ray/api/test/ClientExceptionTest.java | 46 +++++++++++++++++++
 ...org_ray_runtime_raylet_RayletClientImpl.cc | 47 ++++++++++++++-----
 .../org_ray_runtime_raylet_RayletClientImpl.h |  8 ++--
 4 files changed, 94 insertions(+), 29 deletions(-)
 create mode 100644 java/test/src/main/java/org/ray/api/test/ClientExceptionTest.java

diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java
index f4194bffb..96b7657db 100644
--- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java
+++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java
@@ -106,11 +106,8 @@ public class RayletClientImpl implements RayletClient {
       LOGGER.debug("Blocked on objects for task {}, object IDs are {}",
           UniqueIdUtil.computeTaskId(objectIds.get(0)), objectIds);
     }
-    int ret = nativeFetchOrReconstruct(client, UniqueIdUtil.getIdBytes(objectIds),
+    nativeFetchOrReconstruct(client, UniqueIdUtil.getIdBytes(objectIds),
         fetchOnly, currentTaskId.getBytes());
-    if (ret != 0) {
-      throw new RayException("Connection closed by Raylet");
-    }
   }
 
   @Override
@@ -302,27 +299,28 @@ public class RayletClientImpl implements RayletClient {
       boolean isWorker, byte[] driverTaskId);
 
   private static native void nativeSubmitTask(long client, byte[] cursorId, ByteBuffer taskBuff,
-      int pos, int taskSize);
+      int pos, int taskSize) throws RayException;
 
   // return TaskInfo (in FlatBuffer)
-  private static native byte[] nativeGetTask(long client);
+  private static native byte[] nativeGetTask(long client) throws RayException;
 
-  private static native void nativeDestroy(long client);
+  private static native void nativeDestroy(long client) throws RayException;
 
-  private static native int nativeFetchOrReconstruct(long client, byte[][] objectIds,
-      boolean fetchOnly, byte[] currentTaskId);
+  private static native void nativeFetchOrReconstruct(long client, byte[][] objectIds,
+      boolean fetchOnly, byte[] currentTaskId) throws RayException;
 
-  private static native void nativeNotifyUnblocked(long client, byte[] currentTaskId);
+  private static native void nativeNotifyUnblocked(long client, byte[] currentTaskId)
+      throws RayException;
 
   private static native void nativePutObject(long client, byte[] taskId, byte[] objectId);
 
   private static native boolean[] nativeWaitObject(long conn, byte[][] objectIds,
-      int numReturns, int timeout, boolean waitLocal, byte[] currentTaskId);
+      int numReturns, int timeout, boolean waitLocal, byte[] currentTaskId) throws RayException;
 
   private static native byte[] nativeGenerateTaskId(byte[] driverId, byte[] parentTaskId,
       int taskIndex);
 
   private static native void nativeFreePlasmaObjects(long conn, byte[][] objectIds,
-      boolean localOnly);
+      boolean localOnly) throws RayException;
 
 }
diff --git a/java/test/src/main/java/org/ray/api/test/ClientExceptionTest.java b/java/test/src/main/java/org/ray/api/test/ClientExceptionTest.java
new file mode 100644
index 000000000..57b176e9a
--- /dev/null
+++ b/java/test/src/main/java/org/ray/api/test/ClientExceptionTest.java
@@ -0,0 +1,46 @@
+package org.ray.api.test;
+
+import com.google.common.collect.ImmutableList;
+import java.util.concurrent.TimeUnit;
+import org.ray.api.Ray;
+import org.ray.api.RayObject;
+import org.ray.api.exception.RayException;
+import org.ray.api.id.UniqueId;
+import org.ray.runtime.RayObjectImpl;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+public class ClientExceptionTest extends BaseTest {
+
+  private static final Logger LOGGER = LoggerFactory.getLogger(ClientExceptionTest.class);
+
+  @Test
+  public void testWaitAndCrash() {
+    UniqueId randomId = UniqueId.randomId();
+    RayObject notExisting = new RayObjectImpl(randomId);
+
+    Thread thread = new Thread(() -> {
+      try {
+        TimeUnit.SECONDS.sleep(1);
+        Ray.shutdown();
+      } catch (InterruptedException e) {
+        LOGGER.error("Got InterruptedException when sleeping, exit right now.");
+        throw new RuntimeException("Got InterruptedException when sleeping.", e);
+      }
+    });
+    thread.start();
+    try {
+      Ray.wait(ImmutableList.of(notExisting), 1, 2000);
+      Assert.fail("Should not reach here");
+    } catch (RayException e) {
+      LOGGER.debug("Expected runtime exception: ", e);
+    }
+    try {
+      thread.join();
+    } catch (Exception e) {
+      LOGGER.error("Exception caught: ", e);
+    }
+  }
+}
diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc
index 212f91a84..634d04a9a 100644
--- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc
+++ b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc
@@ -31,6 +31,17 @@ class UniqueIdFromJByteArray {
   }
 };
 
+// Raises a Java RayException with `message` when `status` is not OK.
+// ThrowNew only marks the exception as pending in the JVM; callers must stop
+// using results of the failed call and return to Java for it to propagate.
+inline void ThrowRayExceptionIfNotOK(JNIEnv *env, const ray::Status &status,
+                                     const std::string &message) {
+  if (!status.ok()) {
+    jclass exception_class = env->FindClass("org/ray/api/exception/RayException");
+    env->ThrowNew(exception_class, message.c_str());
+  }
+}
+
 /*
  * Class:     org_ray_runtime_raylet_RayletClientImpl
  * Method:    nativeInit
@@ -67,7 +78,8 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmit
   auto data = reinterpret_cast<char *>(env->GetDirectBufferAddress(taskBuff)) + pos;
   ray::raylet::TaskSpecification task_spec(std::string(data, taskSize));
   auto status = raylet_client->SubmitTask(execution_dependencies, task_spec);
-  RAY_CHECK_OK_PREPEND(status, "[RayletClient] Failed to submit a task to raylet.");
+  ThrowRayExceptionIfNotOK(env, status,
+                           "[RayletClient] Failed to submit a task to raylet.");
 }
 
 /*
@@ -82,7 +94,12 @@ JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_native
   // TODO: handle actor failure later
   std::unique_ptr<ray::raylet::TaskSpecification> spec;
   auto status = raylet_client->GetTask(&spec);
-  RAY_CHECK_OK_PREPEND(status, "[RayletClient] Failed to get a task from raylet.");
+  ThrowRayExceptionIfNotOK(env, status,
+                           "[RayletClient] Failed to get a task from raylet.");
+  if (!status.ok()) {
+    // A Java exception is pending and `spec` may not have been populated.
+    return nullptr;
+  }
 
   // We serialize the task specification using flatbuffers and then parse the
   // resulting string. This awkwardness is due to the fact that the Java
@@ -112,19 +129,19 @@ JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_native
  * Signature: (J)V
  */
 JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeDestroy(
-    JNIEnv *, jclass, jlong client) {
+    JNIEnv *env, jclass, jlong client) {
   auto raylet_client = reinterpret_cast<RayletClient *>(client);
-  RAY_CHECK_OK_PREPEND(raylet_client->Disconnect(),
-                       "[RayletClient] Failed to disconnect.");
+  ThrowRayExceptionIfNotOK(env, raylet_client->Disconnect(),
+                           "[RayletClient] Failed to disconnect.");
   delete raylet_client;
 }
 
 /*
  * Class:     org_ray_runtime_raylet_RayletClientImpl
  * Method:    nativeFetchOrReconstruct
- * Signature: (J[[BZ[B)I
+ * Signature: (J[[BZ[B)V
  */
-JNIEXPORT jint JNICALL
+JNIEXPORT void JNICALL
 Java_org_ray_runtime_raylet_RayletClientImpl_nativeFetchOrReconstruct(
     JNIEnv *env, jclass, jlong client, jobjectArray objectIds, jboolean fetchOnly,
     jbyteArray currentTaskId) {
@@ -141,26 +158,26 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeFetchOrReconstruct(
   auto raylet_client = reinterpret_cast<RayletClient *>(client);
   auto status =
       raylet_client->FetchOrReconstruct(object_ids, fetchOnly, *current_task_id.PID);
-  return static_cast<jint>(status.code());
+  ThrowRayExceptionIfNotOK(env, status, "[RayletClient] Failed to fetch or reconstruct.");
 }
 
 /*
  * Class:     org_ray_runtime_raylet_RayletClientImpl
  * Method:    nativeNotifyUnblocked
- * Signature: (J)V
+ * Signature: (J[B)V
  */
 JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyUnblocked(
     JNIEnv *env, jclass, jlong client, jbyteArray currentTaskId) {
   UniqueIdFromJByteArray current_task_id(env, currentTaskId);
   auto raylet_client = reinterpret_cast<RayletClient *>(client);
   auto status = raylet_client->NotifyUnblocked(*current_task_id.PID);
-  RAY_CHECK_OK_PREPEND(status, "[RayletClient] Failed to notify unblocked.");
+  ThrowRayExceptionIfNotOK(env, status, "[RayletClient] Failed to notify unblocked.");
 }
 
 /*
  * Class:     org_ray_runtime_raylet_RayletClientImpl
  * Method:    nativeWaitObject
- * Signature: (J[[BIIZ)[Z
+ * Signature: (J[[BIIZ[B)[Z
  */
 JNIEXPORT jbooleanArray JNICALL
 Java_org_ray_runtime_raylet_RayletClientImpl_nativeWaitObject(
@@ -184,7 +201,11 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeWaitObject(
   auto status =
       raylet_client->Wait(object_ids, numReturns, timeoutMillis,
                           static_cast<bool>(isWaitLocal), *current_task_id.PID, &result);
-  RAY_CHECK_OK_PREPEND(status, "[RayletClient] Failed to wait for objects.");
+  ThrowRayExceptionIfNotOK(env, status, "[RayletClient] Failed to wait for objects.");
+  if (!status.ok()) {
+    // A Java exception is pending; skip building the result array from `result`.
+    return nullptr;
+  }
 
   // Convert result to java object.
   jboolean put_value = true;
@@ -255,7 +276,7 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeFreePlasmaObjects(
   }
   auto raylet_client = reinterpret_cast<RayletClient *>(client);
   auto status = raylet_client->FreeObjects(object_ids, localOnly);
-  RAY_CHECK_OK_PREPEND(status, "[RayletClient] Failed to free objects.");
+  ThrowRayExceptionIfNotOK(env, status, "[RayletClient] Failed to free objects.");
 }
 
 #ifdef __cplusplus
diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h
index 0d5d0b9cb..fff804a04 100644
--- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h
+++ b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h
@@ -42,9 +42,9 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeDestroy(JNIEnv *, jclass, jlo
 /*
  * Class:     org_ray_runtime_raylet_RayletClientImpl
  * Method:    nativeFetchOrReconstruct
- * Signature: (J[[BZ[B)I
+ * Signature: (J[[BZ[B)V
  */
-JNIEXPORT jint JNICALL
+JNIEXPORT void JNICALL
 Java_org_ray_runtime_raylet_RayletClientImpl_nativeFetchOrReconstruct(JNIEnv *, jclass,
                                                                       jlong, jobjectArray,
                                                                       jboolean,
@@ -53,7 +53,7 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeFetchOrReconstruct(JNIEnv *,
 /*
  * Class:     org_ray_runtime_raylet_RayletClientImpl
  * Method:    nativeNotifyUnblocked
- * Signature: (J)V
+ * Signature: (J[B)V
  */
 JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyUnblocked(
     JNIEnv *, jclass, jlong, jbyteArray);
@@ -61,7 +61,7 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotify
 /*
  * Class:     org_ray_runtime_raylet_RayletClientImpl
  * Method:    nativeWaitObject
- * Signature: (J[[BIIZ)[Z
+ * Signature: (J[[BIIZ[B)[Z
  */
 JNIEXPORT jbooleanArray JNICALL
 Java_org_ray_runtime_raylet_RayletClientImpl_nativeWaitObject(JNIEnv *, jclass, jlong,