diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java index e44fd1014..b3a77f0dc 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java @@ -6,11 +6,11 @@ import java.util.ArrayList; import java.util.Deque; import java.util.HashMap; import java.util.HashSet; -import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedDeque; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -47,7 +47,7 @@ public class MockRayletClient implements RayletClient { store.addObjectPutCallback(this::onObjectPut); // The thread pool that executes tasks in parallel. exec = Executors.newFixedThreadPool(numberThreads); - idleWorkers = new LinkedList<>(); + idleWorkers = new ConcurrentLinkedDeque<>(); actorWorkers = new HashMap<>(); currentWorker = new ThreadLocal<>(); } @@ -69,19 +69,19 @@ public class MockRayletClient implements RayletClient { /** * Get a worker from the worker pool to run the given task. */ - private Worker getWorker(TaskSpec task) { + private synchronized Worker getWorker(TaskSpec task) { Worker worker; if (task.isActorTask()) { worker = actorWorkers.get(task.actorId); } else { - if (idleWorkers.size() > 0) { + if (task.isActorCreationTask()) { + worker = new Worker(runtime); + actorWorkers.put(task.actorCreationId, worker); + } else if (idleWorkers.size() > 0) { worker = idleWorkers.pop(); } else { worker = new Worker(runtime); } - if (task.isActorCreationTask()) { - actorWorkers.put(task.actorCreationId, worker); - } } currentWorker.set(worker); return worker; diff --git a/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java b/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java index 6289d1cd7..d90b20a7b 100644 --- a/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java +++ b/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java @@ -15,13 +15,17 @@ import org.ray.api.RayObject; import org.ray.api.TestUtils; import org.ray.api.WaitResult; import org.ray.api.annotation.RayRemote; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.testng.Assert; import org.testng.annotations.Test; public class MultiThreadingTest extends BaseTest { - private static final int LOOP_COUNTER = 1000; + private static final Logger LOGGER = LoggerFactory.getLogger(MultiThreadingTest.class); + + private static final int LOOP_COUNTER = 100; private static final int NUM_THREADS = 20; @RayRemote @@ -55,6 +59,21 @@ public class MultiThreadingTest extends BaseTest { Assert.assertEquals(arg, (int) obj.get()); }, LOOP_COUNTER); + // Test creating multi actors + runTestCaseInMultipleThreads(() -> { + int arg = random.nextInt(); + RayActor echoActor1 = Ray.createActor(Echo::new); + try { + // Sleep a while to test the case that another actor is created before submitting + // tasks to this actor. + TimeUnit.MILLISECONDS.sleep(10); + } catch (InterruptedException e) { + LOGGER.warn("Got exception while sleeping.", e); + } + RayObject obj = Ray.call(Echo::echo, echoActor1, arg); + Assert.assertEquals(arg, (int) obj.get()); + }, 1); + // Test put and get. runTestCaseInMultipleThreads(() -> { int arg = random.nextInt(); @@ -74,8 +93,6 @@ public class MultiThreadingTest extends BaseTest { @Test public void testInDriver() { - // TODO(hchen): Fix this test under single-process mode. - TestUtils.skipTestUnderSingleProcess(); testMultiThreading(); }