diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java index 4f12cd477..5fa9ced27 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java @@ -50,6 +50,11 @@ public class RayletClientImpl implements RayletClient { @Override public WaitResult wait(List> waitFor, int numReturns, int timeoutMs, UniqueId currentTaskId) { + Preconditions.checkNotNull(waitFor); + if (waitFor.isEmpty()) { + return new WaitResult<>(new ArrayList<>(), new ArrayList<>()); + } + List ids = new ArrayList<>(); for (RayObject element : waitFor) { ids.add(element.getId()); diff --git a/java/test/src/main/java/org/ray/api/test/WaitTest.java b/java/test/src/main/java/org/ray/api/test/WaitTest.java index e3ab0e476..49b5f8365 100644 --- a/java/test/src/main/java/org/ray/api/test/WaitTest.java +++ b/java/test/src/main/java/org/ray/api/test/WaitTest.java @@ -1,6 +1,7 @@ package org.ray.api.test; import com.google.common.collect.ImmutableList; +import java.util.ArrayList; import java.util.List; import org.junit.Assert; import org.junit.Test; @@ -56,4 +57,18 @@ public class WaitTest extends BaseTest { RayObject res = Ray.call(WaitTest::waitInWorker); res.get(); } + + @Test + public void testWaitForEmpty() { + WaitResult result = Ray.wait(new ArrayList<>()); + Assert.assertTrue(result.getReady().isEmpty()); + Assert.assertTrue(result.getUnready().isEmpty()); + + try { + Ray.wait(null); + Assert.fail(); + } catch (NullPointerException e) { + Assert.assertTrue(true); + } + } }