[Java] Add getAsyncContext and setAsyncContext API (#6439)

* Add getAsyncContext and setAsyncContext API

* address comment

* fix bug

* Add test case
This commit is contained in:
Kai Yang
2019-12-19 18:08:58 +08:00
committed by Hao Chen
parent 7e2addb424
commit 3bb680a719
6 changed files with 84 additions and 6 deletions
+20 -2
View File
@@ -119,7 +119,25 @@ public final class Ray extends RayCall {
}
/**
* If users want to use Ray API in there own threads, they should wrap their {@link Runnable}
* If users want to use Ray API in their own threads, call this method to get the async context
* and then call {@link #setAsyncContext} at the beginning of the new thread.
*
* @return The async context.
*/
public static Object getAsyncContext() {
return runtime.getAsyncContext();
}
/**
* Set the async context for the current thread.
* @param asyncContext The async context to set.
*/
public static void setAsyncContext(Object asyncContext) {
runtime.setAsyncContext(asyncContext);
}
/**
* If users want to use Ray API in their own threads, they should wrap their {@link Runnable}
* objects with this method.
*
* @param runnable The runnable to wrap.
@@ -130,7 +148,7 @@ public final class Ray extends RayCall {
}
/**
* If users want to use Ray API in there own threads, they should wrap their {@link Callable}
* If users want to use Ray API in their own threads, they should wrap their {@link Callable}
* objects with this method.
*
* @param callable The callable to wrap.
@@ -143,6 +143,10 @@ public interface RayRuntime {
RayPyActor createPyActor(String moduleName, String className, Object[] args,
ActorCreationOptions options);
Object getAsyncContext();
void setAsyncContext(Object asyncContext);
/**
* Wrap a {@link Runnable} with necessary context capture.
* @param runnable The runnable to wrap.
@@ -47,6 +47,15 @@ public class RayDevRuntime extends AbstractRayRuntime {
LOGGER.error("Not implemented under SINGLE_PROCESS mode.");
}
@Override
public Object getAsyncContext() {
return null;
}
@Override
public void setAsyncContext(Object asyncContext) {
}
private JobId nextJobId() {
return JobId.fromInt(jobCounter.getAndIncrement());
}
@@ -178,20 +178,30 @@ public class RayMultiWorkerNativeRuntime implements RayRuntime {
return getCurrentRuntime().createPyActor(moduleName, className, args, options);
}
@Override
public Object getAsyncContext() {
return getCurrentRuntime();
}
@Override
public void setAsyncContext(Object asyncContext) {
currentThreadRuntime.set((RayNativeRuntime)asyncContext);
}
@Override
public Runnable wrapRunnable(Runnable runnable) {
RayNativeRuntime runtime = getCurrentRuntime();
Object asyncContext = getAsyncContext();
return () -> {
currentThreadRuntime.set(runtime);
setAsyncContext(asyncContext);
runnable.run();
};
}
@Override
public Callable wrapCallable(Callable callable) {
RayNativeRuntime runtime = getCurrentRuntime();
Object asyncContext = getAsyncContext();
return () -> {
currentThreadRuntime.set(runtime);
setAsyncContext(asyncContext);
return callable.call();
};
}
@@ -150,6 +150,15 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
nativeSetResource(nativeCoreWorkerPointer, resourceName, capacity, nodeId.getBytes());
}
@Override
public Object getAsyncContext() {
return null;
}
@Override
public void setAsyncContext(Object asyncContext) {
}
public void run() {
nativeRunTaskExecutor(nativeCoreWorkerPointer, taskExecutor);
}
@@ -168,4 +168,32 @@ public class MultiThreadingTest extends BaseTest {
}
}
private static boolean testGetAsyncContextAndSetAsyncContext() throws Exception {
final Object asyncContext = Ray.getAsyncContext();
final Object[] result = new Object[1];
Thread thread = new Thread(() -> {
try {
Ray.setAsyncContext(asyncContext);
Ray.put(0);
} catch (Exception e) {
result[0] = e;
}
});
thread.start();
thread.join();
if (result[0] instanceof Exception) {
throw (Exception) result[0];
}
return true;
}
public void testGetAsyncContextAndSetAsyncContextInDriver() throws Exception {
Assert.assertTrue(testGetAsyncContextAndSetAsyncContext());
}
public void testGetAsyncContextAndSetAsyncContextInWorker() {
RayObject<Boolean> obj = Ray.call(MultiThreadingTest::testGetAsyncContextAndSetAsyncContext);
Assert.assertTrue(obj.get());
}
}