diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java index 0da3dbe80..dbe2cd3b6 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java @@ -6,6 +6,7 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import org.ray.api.RayObject; import org.ray.api.WaitResult; +import org.ray.api.exception.RayException; import org.ray.api.id.UniqueId; import org.ray.runtime.RayDevRuntime; import org.ray.runtime.objectstore.MockObjectStore; @@ -67,7 +68,7 @@ public class MockRayletClient implements RayletClient { @Override public void fetchOrReconstruct(List objectIds, boolean fetchOnly, - UniqueId currentTaskId) { + UniqueId currentTaskId) throws RayException { } diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java index b68fe0182..3e3f4f1e7 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java @@ -3,6 +3,7 @@ package org.ray.runtime.raylet; import java.util.List; import org.ray.api.RayObject; import org.ray.api.WaitResult; +import org.ray.api.exception.RayException; import org.ray.api.id.UniqueId; import org.ray.runtime.task.TaskSpec; @@ -15,7 +16,8 @@ public interface RayletClient { TaskSpec getTask(); - void fetchOrReconstruct(List objectIds, boolean fetchOnly, UniqueId currentTaskId); + void fetchOrReconstruct(List objectIds, boolean fetchOnly, UniqueId currentTaskId) + throws RayException; void notifyUnblocked(UniqueId currentTaskId); diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java index 91937ba14..cd4f3fd31 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java @@ -10,6 +10,7 @@ import java.util.List; import java.util.Map; import org.ray.api.RayObject; import org.ray.api.WaitResult; +import org.ray.api.exception.RayException; import org.ray.api.id.UniqueId; import org.ray.runtime.functionmanager.FunctionDescriptor; import org.ray.runtime.generated.Arg; @@ -89,13 +90,16 @@ public class RayletClientImpl implements RayletClient { @Override public void fetchOrReconstruct(List objectIds, boolean fetchOnly, - UniqueId currentTaskId) { + UniqueId currentTaskId) throws RayException { if (RayLog.core.isDebugEnabled()) { RayLog.core.debug("Blocked on objects for task {}, object IDs are {}", UniqueIdUtil.computeTaskId(objectIds.get(0)), objectIds); } - nativeFetchOrReconstruct(client, UniqueIdUtil.getIdBytes(objectIds), + int ret = nativeFetchOrReconstruct(client, UniqueIdUtil.getIdBytes(objectIds), fetchOnly, currentTaskId.getBytes()); + if (ret != 0) { + throw new RayException("Connection closed by Raylet"); + } } @Override @@ -274,7 +278,7 @@ public class RayletClientImpl implements RayletClient { private static native void nativeDestroy(long client); - private static native void nativeFetchOrReconstruct(long client, byte[][] objectIds, + private static native int nativeFetchOrReconstruct(long client, byte[][] objectIds, boolean fetchOnly, byte[] currentTaskId); private static native void nativeNotifyUnblocked(long client, byte[] currentTaskId); diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc index 12388d181..c3b1b4475 100644 --- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc +++ b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc @@ -119,9 +119,9 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeDestro /* * Class: org_ray_runtime_raylet_RayletClientImpl * Method: nativeFetchOrReconstruct - * Signature: (J[[BZ)V + * Signature: (J[[BZ[B)I */ -JNIEXPORT void JNICALL +JNIEXPORT jint JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeFetchOrReconstruct( JNIEnv *env, jclass, jlong client, jobjectArray objectIds, jboolean fetchOnly, jbyteArray currentTaskId) { @@ -136,7 +136,8 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeFetchOrReconstruct( } UniqueIdFromJByteArray current_task_id(env, currentTaskId); auto conn = reinterpret_cast(client); - local_scheduler_fetch_or_reconstruct(conn, object_ids, fetchOnly, *current_task_id.PID); + return local_scheduler_fetch_or_reconstruct(conn, object_ids, fetchOnly, + *current_task_id.PID); } /* diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h index 8940046ce..0d5d0b9cb 100644 --- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h +++ b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h @@ -42,9 +42,9 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeDestroy(JNIEnv *, jclass, jlo /* * Class: org_ray_runtime_raylet_RayletClientImpl * Method: nativeFetchOrReconstruct - * Signature: (J[[BZ)V + * Signature: (J[[BZ[B)I */ -JNIEXPORT void JNICALL +JNIEXPORT jint JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeFetchOrReconstruct(JNIEnv *, jclass, jlong, jobjectArray, jboolean, diff --git a/src/ray/raylet/lib/python/local_scheduler_extension.cc b/src/ray/raylet/lib/python/local_scheduler_extension.cc index 05d24cdb4..8d0480092 100644 --- a/src/ray/raylet/lib/python/local_scheduler_extension.cc +++ b/src/ray/raylet/lib/python/local_scheduler_extension.cc @@ -1,4 +1,5 @@ #include +#include #include "common_extension.h" #include "config_extension.h" @@ -92,10 +93,19 @@ static PyObject *PyLocalSchedulerClient_fetch_or_reconstruct(PyObject *self, } object_ids.push_back(object_id); } - local_scheduler_fetch_or_reconstruct( + int ret = local_scheduler_fetch_or_reconstruct( reinterpret_cast(self)->local_scheduler_connection, object_ids, fetch_only, current_task_id); - Py_RETURN_NONE; + if (ret == 0) { + Py_RETURN_NONE; + } else { + std::ostringstream stream; + stream << "local_scheduler_fetch_or_reconstruct failed: " + << "local scheduler connection may be closed, " + << "check raylet status. return value: " << ret; + PyErr_SetString(CommonError, stream.str().c_str()); + Py_RETURN_NONE; + } } static PyObject *PyLocalSchedulerClient_notify_unblocked(PyObject *self, PyObject *args) { diff --git a/src/ray/raylet/local_scheduler_client.cc b/src/ray/raylet/local_scheduler_client.cc index ec9434cbc..584379090 100644 --- a/src/ray/raylet/local_scheduler_client.cc +++ b/src/ray/raylet/local_scheduler_client.cc @@ -306,18 +306,16 @@ void local_scheduler_task_done(LocalSchedulerConnection *conn) { &conn->write_mutex); } -void local_scheduler_fetch_or_reconstruct(LocalSchedulerConnection *conn, - const std::vector &object_ids, - bool fetch_only, - const TaskID ¤t_task_id) { +int local_scheduler_fetch_or_reconstruct(LocalSchedulerConnection *conn, + const std::vector &object_ids, + bool fetch_only, const TaskID ¤t_task_id) { flatbuffers::FlatBufferBuilder fbb; auto object_ids_message = to_flatbuf(fbb, object_ids); auto message = ray::protocol::CreateFetchOrReconstruct( fbb, object_ids_message, fetch_only, to_flatbuf(fbb, current_task_id)); fbb.Finish(message); - write_message(conn->conn, static_cast(MessageType::FetchOrReconstruct), - fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex); - /* TODO(swang): Propagate the error. */ + return write_message(conn->conn, static_cast(MessageType::FetchOrReconstruct), + fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex); } void local_scheduler_notify_unblocked(LocalSchedulerConnection *conn, diff --git a/src/ray/raylet/local_scheduler_client.h b/src/ray/raylet/local_scheduler_client.h index 66c76f37a..0f975de5c 100644 --- a/src/ray/raylet/local_scheduler_client.h +++ b/src/ray/raylet/local_scheduler_client.h @@ -97,11 +97,11 @@ void local_scheduler_task_done(LocalSchedulerConnection *conn); * @param object_ids The IDs of the objects to reconstruct. * @param fetch_only Only fetch objects, do not reconstruct them. * @param current_task_id The task that needs the objects. - * @return Void. + * @return int 0 means correct, other numbers mean error. */ -void local_scheduler_fetch_or_reconstruct(LocalSchedulerConnection *conn, - const std::vector &object_ids, - bool fetch_only, const TaskID ¤t_task_id); +int local_scheduler_fetch_or_reconstruct(LocalSchedulerConnection *conn, + const std::vector &object_ids, + bool fetch_only, const TaskID ¤t_task_id); /** * Notify the local scheduler that this client (worker) is no longer blocked.