diff --git a/java/api/src/main/java/org/ray/api/Ray.java b/java/api/src/main/java/org/ray/api/Ray.java index 4cf4ffac6..375f43786 100644 --- a/java/api/src/main/java/org/ray/api/Ray.java +++ b/java/api/src/main/java/org/ray/api/Ray.java @@ -119,7 +119,25 @@ public final class Ray extends RayCall { } /** - * If users want to use Ray API in there own threads, they should wrap their {@link Runnable} + * If users want to use Ray API in their own threads, call this method to get the async context + * and then call {@link #setAsyncContext} at the beginning of the new thread. + * + * @return The async context. + */ + public static Object getAsyncContext() { + return runtime.getAsyncContext(); + } + + /** + * Set the async context for the current thread. + * @param asyncContext The async context to set. + */ + public static void setAsyncContext(Object asyncContext) { + runtime.setAsyncContext(asyncContext); + } + + /** + * If users want to use Ray API in their own threads, they should wrap their {@link Runnable} * objects with this method. * * @param runnable The runnable to wrap. @@ -130,7 +148,7 @@ public final class Ray extends RayCall { } /** - * If users want to use Ray API in there own threads, they should wrap their {@link Callable} + * If users want to use Ray API in their own threads, they should wrap their {@link Callable} * objects with this method. * * @param callable The callable to wrap. diff --git a/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java b/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java index 46d5ca842..059e85aea 100644 --- a/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java +++ b/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java @@ -143,6 +143,10 @@ public interface RayRuntime { RayPyActor createPyActor(String moduleName, String className, Object[] args, ActorCreationOptions options); + Object getAsyncContext(); + + void setAsyncContext(Object asyncContext); + /** * Wrap a {@link Runnable} with necessary context capture. * @param runnable The runnable to wrap. diff --git a/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java index d34f6fc68..8582238f0 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java @@ -47,6 +47,15 @@ public class RayDevRuntime extends AbstractRayRuntime { LOGGER.error("Not implemented under SINGLE_PROCESS mode."); } + @Override + public Object getAsyncContext() { + return null; + } + + @Override + public void setAsyncContext(Object asyncContext) { + } + private JobId nextJobId() { return JobId.fromInt(jobCounter.getAndIncrement()); } diff --git a/java/runtime/src/main/java/org/ray/runtime/RayMultiWorkerNativeRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayMultiWorkerNativeRuntime.java index 47e69cb0e..335ed032a 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayMultiWorkerNativeRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayMultiWorkerNativeRuntime.java @@ -178,20 +178,30 @@ public class RayMultiWorkerNativeRuntime implements RayRuntime { return getCurrentRuntime().createPyActor(moduleName, className, args, options); } + @Override + public Object getAsyncContext() { + return getCurrentRuntime(); + } + + @Override + public void setAsyncContext(Object asyncContext) { + currentThreadRuntime.set((RayNativeRuntime)asyncContext); + } + @Override public Runnable wrapRunnable(Runnable runnable) { - RayNativeRuntime runtime = getCurrentRuntime(); + Object asyncContext = getAsyncContext(); return () -> { - currentThreadRuntime.set(runtime); + setAsyncContext(asyncContext); runnable.run(); }; } @Override public Callable wrapCallable(Callable callable) { - RayNativeRuntime runtime = getCurrentRuntime(); + Object asyncContext = getAsyncContext(); return () -> { - currentThreadRuntime.set(runtime); + setAsyncContext(asyncContext); return callable.call(); }; } diff --git a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java index ebbcebded..fb05e987a 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java @@ -150,6 +150,15 @@ public final class RayNativeRuntime extends AbstractRayRuntime { nativeSetResource(nativeCoreWorkerPointer, resourceName, capacity, nodeId.getBytes()); } + @Override + public Object getAsyncContext() { + return null; + } + + @Override + public void setAsyncContext(Object asyncContext) { + } + public void run() { nativeRunTaskExecutor(nativeCoreWorkerPointer, taskExecutor); } diff --git a/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java b/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java index bc5a5f3e3..1ba045f51 100644 --- a/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java +++ b/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java @@ -168,4 +168,32 @@ public class MultiThreadingTest extends BaseTest { } } + private static boolean testGetAsyncContextAndSetAsyncContext() throws Exception { + final Object asyncContext = Ray.getAsyncContext(); + final Object[] result = new Object[1]; + Thread thread = new Thread(() -> { + try { + Ray.setAsyncContext(asyncContext); + Ray.put(0); + } catch (Exception e) { + result[0] = e; + } + }); + thread.start(); + thread.join(); + if (result[0] instanceof Exception) { + throw (Exception) result[0]; + } + return true; + } + + public void testGetAsyncContextAndSetAsyncContextInDriver() throws Exception { + Assert.assertTrue(testGetAsyncContextAndSetAsyncContext()); + } + + public void testGetAsyncContextAndSetAsyncContextInWorker() { + RayObject obj = Ray.call(MultiThreadingTest::testGetAsyncContextAndSetAsyncContext); + Assert.assertTrue(obj.get()); + } + }