diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java index dbe2cd3b6..0da3dbe80 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java @@ -6,7 +6,6 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import org.ray.api.RayObject; import org.ray.api.WaitResult; -import org.ray.api.exception.RayException; import org.ray.api.id.UniqueId; import org.ray.runtime.RayDevRuntime; import org.ray.runtime.objectstore.MockObjectStore; @@ -68,7 +67,7 @@ public class MockRayletClient implements RayletClient { @Override public void fetchOrReconstruct(List objectIds, boolean fetchOnly, - UniqueId currentTaskId) throws RayException { + UniqueId currentTaskId) { } diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java index 3e3f4f1e7..b68fe0182 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java @@ -3,7 +3,6 @@ package org.ray.runtime.raylet; import java.util.List; import org.ray.api.RayObject; import org.ray.api.WaitResult; -import org.ray.api.exception.RayException; import org.ray.api.id.UniqueId; import org.ray.runtime.task.TaskSpec; @@ -16,8 +15,7 @@ public interface RayletClient { TaskSpec getTask(); - void fetchOrReconstruct(List objectIds, boolean fetchOnly, UniqueId currentTaskId) - throws RayException; + void fetchOrReconstruct(List objectIds, boolean fetchOnly, UniqueId currentTaskId); void notifyUnblocked(UniqueId currentTaskId); diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java index b1932c08a..9757b4f07 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java @@ -90,7 +90,7 @@ public class RayletClientImpl implements RayletClient { @Override public void fetchOrReconstruct(List objectIds, boolean fetchOnly, - UniqueId currentTaskId) throws RayException { + UniqueId currentTaskId) { if (RayLog.core.isDebugEnabled()) { RayLog.core.debug("Blocked on objects for task {}, object IDs are {}", UniqueIdUtil.computeTaskId(objectIds.get(0)), objectIds); diff --git a/src/ray/raylet/lib/python/raylet_extension.cc b/src/ray/raylet/lib/python/raylet_extension.cc index 1a6fc9b29..c5f7eafd5 100644 --- a/src/ray/raylet/lib/python/raylet_extension.cc +++ b/src/ray/raylet/lib/python/raylet_extension.cc @@ -98,7 +98,7 @@ static PyObject *PyRayletClient_FetchOrReconstruct(PyRayletClient *self, PyObjec << "raylet client may be closed, check raylet status. error message: " << status.ToString(); PyErr_SetString(CommonError, stream.str().c_str()); - Py_RETURN_NONE; + return NULL; } } diff --git a/test/failure_test.py b/test/failure_test.py index 2c4a92bd4..3efb9bc69 100644 --- a/test/failure_test.py +++ b/test/failure_test.py @@ -8,9 +8,11 @@ import os import ray import sys import tempfile +import threading import time import ray.ray_constants as ray_constants +from ray.utils import _random_string import pytest @@ -611,3 +613,18 @@ def test_warning_for_dead_node(ray_start_two_nodes): } assert client_ids == warning_client_ids + + +def test_raylet_crash_when_get(ray_start_regular): + nonexistent_id = ray.ObjectID(_random_string()) + + def sleep_to_kill_raylet(): + # Don't kill raylet before default workers get connected. + time.sleep(2) + ray.services.all_processes[ray.services.PROCESS_TYPE_RAYLET][0].kill() + + thread = threading.Thread(target=sleep_to_kill_raylet) + thread.start() + with pytest.raises(Exception, match=r".*raylet client may be closed.*"): + ray.get(nonexistent_id) + thread.join()