diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java
index b3adaa11c..e667a3229 100644
--- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java
+++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java
@@ -108,9 +108,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
   @Override
   public <T> List<T> get(List<UniqueId> objectIds) {
     boolean wasBlocked = false;
-    // TODO(swang): If we are not on the main thread, then we should generate a
-    // random task ID to pass to the backend.
-    UniqueId taskId = workerContext.getCurrentTask().taskId;
+    UniqueId taskId = workerContext.getCurrentThreadTaskId();
 
     try {
       int numObjectIds = objectIds.size();
@@ -218,10 +216,8 @@ public abstract class AbstractRayRuntime implements RayRuntime {
 
   @Override
   public <T> WaitResult<T> wait(List<RayObject<T>> waitList, int numReturns, int timeoutMs) {
-    // TODO(swang): If we are not on the main thread, then we should generate a
-    // random task ID to pass to the backend.
-    return rayletClient.wait(waitList, numReturns, timeoutMs,
-        workerContext.getCurrentTask().taskId);
+    return rayletClient.wait(waitList, numReturns,
+        timeoutMs, workerContext.getCurrentThreadTaskId());
   }
 
   @Override
@@ -237,9 +233,13 @@ public abstract class AbstractRayRuntime implements RayRuntime {
       throw new IllegalArgumentException("Unsupported actor type: " + actor.getClass().getName());
     }
     RayActorImpl actorImpl = (RayActorImpl)actor;
-    TaskSpec spec = createTaskSpec(func, actorImpl, args, false, null);
-    spec.getExecutionDependencies().add(((RayActorImpl) actor).getTaskCursor());
-    actorImpl.setTaskCursor(spec.returnIds[1]);
+    TaskSpec spec;
+    // Make reading and advancing the actor's task cursor atomic across caller threads.
+    synchronized (actor) {
+      spec = createTaskSpec(func, actorImpl, args, false, null);
+      spec.getExecutionDependencies().add(actorImpl.getTaskCursor());
+      actorImpl.setTaskCursor(spec.returnIds[1]);
+    }
     rayletClient.submitTask(spec);
     return new RayObjectImpl(spec.returnIds[0]);
   }
@@ -342,4 +342,11 @@ public abstract class AbstractRayRuntime implements RayRuntime {
   public FunctionManager getFunctionManager() {
     return functionManager;
   }
+
+  /**
+   * Returns the {@code rayConfig} held by this runtime.
+   */
+  public RayConfig getRayConfig() {
+    return rayConfig;
+  }
 }
diff --git a/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java
index 4799baa94..2b93b17b2 100644
--- a/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java
+++ b/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java
@@ -11,10 +11,14 @@ public class RayDevRuntime extends AbstractRayRuntime {
     super(rayConfig);
   }
 
+  /** Mock object store; assigned in {@link #start()} and {@code null} before that. */
+  private MockObjectStore store;
+
   @Override
   public void start() {
-    MockObjectStore store = new MockObjectStore(this);
-    objectStoreProxy = new ObjectStoreProxy(this, store);
+    store = new MockObjectStore(this);
+    // No socket name is needed: outside CLUSTER mode the proxy uses getObjectStore() instead.
+    objectStoreProxy = new ObjectStoreProxy(this, null);
     rayletClient = new MockRayletClient(this, store);
   }
 
@@ -22,4 +26,9 @@ public class RayDevRuntime extends AbstractRayRuntime {
   public void shutdown() {
     // nothing to do
   }
+
+  /** Returns the mock object store; {@code null} until {@link #start()} has been called. */
+  public MockObjectStore getObjectStore() {
+    return store;
+  }
 }
diff --git a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java
index fd88bde35..f47993246 100644
--- a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java
+++ b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java
@@ -74,8 +74,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
     }
     kvStore = new RedisClient(rayConfig.getRedisAddress());
 
-    ObjectStoreLink store = new PlasmaClient(rayConfig.objectStoreSocketName, "", 0);
-    objectStoreProxy = new ObjectStoreProxy(this, store);
+    objectStoreProxy = new ObjectStoreProxy(this, rayConfig.objectStoreSocketName);
 
     rayletClient = new RayletClientImpl(
         rayConfig.rayletSocketName,
diff --git a/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java b/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java
index fdb507689..785581086 100644
--- a/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java
+++ b/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java
@@ -1,11 +1,17 @@
 package org.ray.runtime;
 
+import com.google.common.base.Preconditions;
 import java.util.HashMap;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
 import org.ray.api.id.UniqueId;
 import org.ray.runtime.config.WorkerMode;
 import org.ray.runtime.task.TaskSpec;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 public class WorkerContext {
+  private static final Logger LOGGER = LoggerFactory.getLogger(WorkerContext.class);
 
   /**
    * Worker id.
@@ -25,19 +31,53 @@ public class WorkerContext {
   /**
    * How many puts have been done by current task.
    */
-  private int currentTaskPutCount;
+  private final AtomicInteger currentTaskPutCount;
 
   /**
    * How many calls have been done by current task.
    */
-  private int currentTaskCallCount;
+  private final AtomicInteger currentTaskCallCount;
+
+  /**
+   * The ID of the main thread, i.e. the thread that created this worker context.
+   */
+  private final long mainThreadId;
+  /**
+   * Whether the multi-threading warning message has already been logged.
+   */
+  private final AtomicBoolean multiThreadingWarned;
 
   public WorkerContext(WorkerMode workerMode, UniqueId driverId) {
     workerId = workerMode == WorkerMode.DRIVER ? driverId : UniqueId.randomId();
-    currentTaskPutCount = 0;
-    currentTaskCallCount = 0;
+    currentTaskPutCount = new AtomicInteger(0);
+    currentTaskCallCount = new AtomicInteger(0);
     currentClassLoader = null;
     currentTask = createDummyTask(workerMode, driverId);
+    mainThreadId = Thread.currentThread().getId();
+    multiThreadingWarned = new AtomicBoolean(false);
+  }
+
+  /**
+   * Get the current thread's task ID.
+   * This returns the assigned task ID if called on the main thread, else a
+   * random task ID. A warning is logged the first time a non-main thread calls this.
+   */
+  public UniqueId getCurrentThreadTaskId() {
+    UniqueId taskId;
+    if (Thread.currentThread().getId() == mainThreadId) {
+      taskId = currentTask.taskId;
+    } else {
+      taskId = UniqueId.randomId();
+      if (multiThreadingWarned.compareAndSet(false, true)) {
+        LOGGER.warn("Calling Ray.get or Ray.wait in a separate thread "
+            + "may lead to deadlock if the main thread blocks on this "
+            + "thread and there are not enough resources to execute "
+            + "more tasks");
+      }
+    }
+
+    Preconditions.checkState(!taskId.isNil());
+    return taskId;
   }
 
   public void setWorkerId(UniqueId workerId) {
@@ -49,11 +89,11 @@ public class WorkerContext {
   }
 
   public int nextPutIndex() {
-    return ++currentTaskPutCount;
+    return currentTaskPutCount.incrementAndGet();
   }
 
   public int nextCallIndex() {
-    return ++currentTaskCallCount;
+    return currentTaskCallCount.incrementAndGet();
   }
 
   public UniqueId getCurrentWorkerId() {
@@ -66,6 +106,8 @@ public class WorkerContext {
 
   public void setCurrentTask(TaskSpec currentTask) {
     this.currentTask = currentTask;
+    currentTaskCallCount.set(0);
+    currentTaskPutCount.set(0);
   }
 
   public void setCurrentClassLoader(ClassLoader currentClassLoader) {
diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java
index be33150c7..2bbb457dd 100644
--- a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java
+++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java
@@ -3,10 +3,13 @@ package org.ray.runtime.objectstore;
 import java.util.ArrayList;
 import java.util.List;
 import org.apache.arrow.plasma.ObjectStoreLink;
+import org.apache.arrow.plasma.PlasmaClient;
 import org.apache.commons.lang3.tuple.Pair;
 import org.ray.api.exception.RayException;
 import org.ray.api.id.UniqueId;
 import org.ray.runtime.AbstractRayRuntime;
+import org.ray.runtime.RayDevRuntime;
+import org.ray.runtime.config.RunMode;
 import org.ray.runtime.util.Serializer;
 import org.ray.runtime.util.UniqueIdUtil;
 
@@ -19,11 +22,19 @@ public class ObjectStoreProxy {
   private static final int GET_TIMEOUT_MS = 1000;
 
   private final AbstractRayRuntime runtime;
-  private final ObjectStoreLink store;
 
-  public ObjectStoreProxy(AbstractRayRuntime runtime, ObjectStoreLink store) {
+  /** Per-thread object store client, created lazily on a thread's first access. */
+  private final ThreadLocal<ObjectStoreLink> objectStore;
+
+  public ObjectStoreProxy(AbstractRayRuntime runtime, String storeSocketName) {
     this.runtime = runtime;
-    this.store = store;
+    objectStore = ThreadLocal.withInitial(() -> {
+      if (runtime.getRayConfig().runMode == RunMode.CLUSTER) {
+        return new PlasmaClient(storeSocketName, "", 0);
+      } else {
+        return ((RayDevRuntime) runtime).getObjectStore();
+      }
+    });
   }
 
   public <T> Pair<T, GetStatus> get(UniqueId objectId, boolean isMetadata)
@@ -33,10 +44,10 @@ public class ObjectStoreProxy {
 
   public <T> Pair<T, GetStatus> get(UniqueId id, int timeoutMs, boolean isMetadata)
       throws RayException {
-    byte[] obj = store.get(id.getBytes(), timeoutMs, isMetadata);
+    byte[] obj = objectStore.get().get(id.getBytes(), timeoutMs, isMetadata);
     if (obj != null) {
       T t = Serializer.decode(obj, runtime.getWorkerContext().getCurrentClassLoader());
-      store.release(id.getBytes());
+      objectStore.get().release(id.getBytes());
       if (t instanceof RayException) {
         throw (RayException) t;
       }
@@ -53,13 +64,13 @@ public class ObjectStoreProxy {
 
   public <T> List<Pair<T, GetStatus>> get(List<UniqueId> ids, int timeoutMs, boolean isMetadata)
       throws RayException {
-    List<byte[]> objs = store.get(UniqueIdUtil.getIdBytes(ids), timeoutMs, isMetadata);
+    List<byte[]> objs = objectStore.get().get(UniqueIdUtil.getIdBytes(ids), timeoutMs, isMetadata);
     List<Pair<T, GetStatus>> ret = new ArrayList<>();
     for (int i = 0; i < objs.size(); i++) {
       byte[] obj = objs.get(i);
       if (obj != null) {
         T t = Serializer.decode(obj, runtime.getWorkerContext().getCurrentClassLoader());
-        store.release(ids.get(i).getBytes());
+        objectStore.get().release(ids.get(i).getBytes());
         if (t instanceof RayException) {
           throw (RayException) t;
         }
@@ -72,11 +83,11 @@ public class ObjectStoreProxy {
   }
 
   public void put(UniqueId id, Object obj, Object metadata) {
-    store.put(id.getBytes(), Serializer.encode(obj), Serializer.encode(metadata));
+    objectStore.get().put(id.getBytes(), Serializer.encode(obj), Serializer.encode(metadata));
   }
 
   public void putSerialized(UniqueId id, byte[] obj, byte[] metadata) {
-    store.put(id.getBytes(), obj, metadata);
+    objectStore.get().put(id.getBytes(), obj, metadata);
   }
 
   public enum GetStatus {
diff --git a/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java b/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java
new file mode 100644
index 000000000..c95e7093c
--- /dev/null
+++ b/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java
@@ -0,0 +1,118 @@
+package org.ray.api.test;
+
+import com.google.common.collect.ImmutableList;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import org.junit.Assert;
+import org.junit.Test;
+import org.ray.api.Ray;
+import org.ray.api.RayActor;
+import org.ray.api.RayObject;
+import org.ray.api.WaitResult;
+import org.ray.api.annotation.RayRemote;
+
+/**
+ * Exercises remote calls, actor calls, put/get and wait concurrently from multiple threads.
+ */
+public class MultiThreadingTest extends BaseTest {
+
+  private static final int LOOP_COUNTER = 1000;
+  private static final int NUM_THREADS = 20;
+
+  @RayRemote
+  public static Integer echo(int num) {
+    return num;
+  }
+
+  @RayRemote
+  public static class Echo {
+
+    @RayRemote
+    public Integer echo(int num) {
+      return num;
+    }
+  }
+
+  public static String testMultiThreading() {
+    Random random = new Random();
+    // Test calling normal functions.
+    runTestCaseInMultipleThreads(() -> {
+      int arg = random.nextInt();
+      RayObject<Integer> obj = Ray.call(MultiThreadingTest::echo, arg);
+      Assert.assertEquals(arg, (int) obj.get());
+    }, LOOP_COUNTER);
+
+    // Test calling actors.
+    RayActor<Echo> echoActor = Ray.createActor(Echo::new);
+    runTestCaseInMultipleThreads(() -> {
+      int arg = random.nextInt();
+      RayObject<Integer> obj = Ray.call(Echo::echo, echoActor, arg);
+      Assert.assertEquals(arg, (int) obj.get());
+    }, LOOP_COUNTER);
+
+    // Test put and get.
+    runTestCaseInMultipleThreads(() -> {
+      int arg = random.nextInt();
+      RayObject<Integer> obj = Ray.put(arg);
+      Assert.assertEquals(arg, (int) Ray.get(obj.getId()));
+    }, LOOP_COUNTER);
+
+    // Test wait for one object in multi threads.
+    RayObject<Integer> obj = Ray.call(MultiThreadingTest::echo, 100);
+    runTestCaseInMultipleThreads(() -> {
+      WaitResult<Integer> result = Ray.wait(ImmutableList.of(obj), 1, 1000);
+      Assert.assertEquals(1, result.getReady().size());
+    }, 1);
+
+    return "ok";
+  }
+
+  @Test
+  public void testInDriver() {
+    testMultiThreading();
+  }
+
+  @Test
+  public void testInWorker() {
+    RayObject<String> obj = Ray.call(MultiThreadingTest::testMultiThreading);
+    Assert.assertEquals("ok", obj.get());
+  }
+
+  /**
+   * Runs {@code testCase} {@code numRepeats} times on each of {@code NUM_THREADS} pool threads
+   * and throws a {@link RuntimeException} if any thread's run failed.
+   */
+  private static void runTestCaseInMultipleThreads(Runnable testCase, int numRepeats) {
+    ExecutorService service = Executors.newFixedThreadPool(NUM_THREADS);
+
+    try {
+      List<Future<String>> futures = new ArrayList<>();
+      for (int i = 0; i < NUM_THREADS; i++) {
+        Callable<String> task = () -> {
+          for (int j = 0; j < numRepeats; j++) {
+            TimeUnit.MILLISECONDS.sleep(1);
+            testCase.run();
+          }
+          return "ok";
+        };
+        futures.add(service.submit(task));
+      }
+      for (Future<String> future : futures) {
+        try {
+          Assert.assertEquals("ok", future.get());
+        } catch (Exception e) {
+          throw new RuntimeException("Test case failed.", e);
+        }
+      }
+    } finally {
+      service.shutdown();
+    }
+  }
+
+}