mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 12:20:16 +08:00
[Java] Add getAsyncContext and setAsyncContext API (#6439)
* Add getAsyncContext and setAsyncContext API * address comment * fix bug * Add test case
This commit is contained in:
@@ -119,7 +119,25 @@ public final class Ray extends RayCall {
|
||||
}
|
||||
|
||||
/**
|
||||
* If users want to use Ray API in there own threads, they should wrap their {@link Runnable}
|
||||
* If users want to use Ray API in their own threads, call this method to get the async context
|
||||
* and then call {@link #setAsyncContext} at the beginning of the new thread.
|
||||
*
|
||||
* @return The async context.
|
||||
*/
|
||||
public static Object getAsyncContext() {
|
||||
return runtime.getAsyncContext();
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the async context for the current thread.
|
||||
* @param asyncContext The async context to set.
|
||||
*/
|
||||
public static void setAsyncContext(Object asyncContext) {
|
||||
runtime.setAsyncContext(asyncContext);
|
||||
}
|
||||
|
||||
/**
|
||||
* If users want to use Ray API in their own threads, they should wrap their {@link Runnable}
|
||||
* objects with this method.
|
||||
*
|
||||
* @param runnable The runnable to wrap.
|
||||
@@ -130,7 +148,7 @@ public final class Ray extends RayCall {
|
||||
}
|
||||
|
||||
/**
|
||||
* If users want to use Ray API in there own threads, they should wrap their {@link Callable}
|
||||
* If users want to use Ray API in their own threads, they should wrap their {@link Callable}
|
||||
* objects with this method.
|
||||
*
|
||||
* @param callable The callable to wrap.
|
||||
|
||||
@@ -143,6 +143,10 @@ public interface RayRuntime {
|
||||
RayPyActor createPyActor(String moduleName, String className, Object[] args,
|
||||
ActorCreationOptions options);
|
||||
|
||||
Object getAsyncContext();
|
||||
|
||||
void setAsyncContext(Object asyncContext);
|
||||
|
||||
/**
|
||||
* Wrap a {@link Runnable} with necessary context capture.
|
||||
* @param runnable The runnable to wrap.
|
||||
|
||||
@@ -47,6 +47,15 @@ public class RayDevRuntime extends AbstractRayRuntime {
|
||||
LOGGER.error("Not implemented under SINGLE_PROCESS mode.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object getAsyncContext() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setAsyncContext(Object asyncContext) {
|
||||
}
|
||||
|
||||
private JobId nextJobId() {
|
||||
return JobId.fromInt(jobCounter.getAndIncrement());
|
||||
}
|
||||
|
||||
@@ -178,20 +178,30 @@ public class RayMultiWorkerNativeRuntime implements RayRuntime {
|
||||
return getCurrentRuntime().createPyActor(moduleName, className, args, options);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object getAsyncContext() {
|
||||
return getCurrentRuntime();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setAsyncContext(Object asyncContext) {
|
||||
currentThreadRuntime.set((RayNativeRuntime)asyncContext);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Runnable wrapRunnable(Runnable runnable) {
|
||||
RayNativeRuntime runtime = getCurrentRuntime();
|
||||
Object asyncContext = getAsyncContext();
|
||||
return () -> {
|
||||
currentThreadRuntime.set(runtime);
|
||||
setAsyncContext(asyncContext);
|
||||
runnable.run();
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public Callable wrapCallable(Callable callable) {
|
||||
RayNativeRuntime runtime = getCurrentRuntime();
|
||||
Object asyncContext = getAsyncContext();
|
||||
return () -> {
|
||||
currentThreadRuntime.set(runtime);
|
||||
setAsyncContext(asyncContext);
|
||||
return callable.call();
|
||||
};
|
||||
}
|
||||
|
||||
@@ -150,6 +150,15 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
nativeSetResource(nativeCoreWorkerPointer, resourceName, capacity, nodeId.getBytes());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object getAsyncContext() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setAsyncContext(Object asyncContext) {
|
||||
}
|
||||
|
||||
public void run() {
|
||||
nativeRunTaskExecutor(nativeCoreWorkerPointer, taskExecutor);
|
||||
}
|
||||
|
||||
@@ -168,4 +168,32 @@ public class MultiThreadingTest extends BaseTest {
|
||||
}
|
||||
}
|
||||
|
||||
private static boolean testGetAsyncContextAndSetAsyncContext() throws Exception {
|
||||
final Object asyncContext = Ray.getAsyncContext();
|
||||
final Object[] result = new Object[1];
|
||||
Thread thread = new Thread(() -> {
|
||||
try {
|
||||
Ray.setAsyncContext(asyncContext);
|
||||
Ray.put(0);
|
||||
} catch (Exception e) {
|
||||
result[0] = e;
|
||||
}
|
||||
});
|
||||
thread.start();
|
||||
thread.join();
|
||||
if (result[0] instanceof Exception) {
|
||||
throw (Exception) result[0];
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
public void testGetAsyncContextAndSetAsyncContextInDriver() throws Exception {
|
||||
Assert.assertTrue(testGetAsyncContextAndSetAsyncContext());
|
||||
}
|
||||
|
||||
public void testGetAsyncContextAndSetAsyncContextInWorker() {
|
||||
RayObject<Boolean> obj = Ray.call(MultiThreadingTest::testGetAsyncContextAndSetAsyncContext);
|
||||
Assert.assertTrue(obj.get());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user