[java] Fix getWorker and add create multi actors test (#4472)

This commit is contained in:
bibabolynn
2019-03-26 20:26:13 +08:00
committed by Hao Chen
parent 7d70cfba6e
commit 7a9d1546d4
2 changed files with 27 additions and 10 deletions
@@ -6,11 +6,11 @@ import java.util.ArrayList;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@@ -47,7 +47,7 @@ public class MockRayletClient implements RayletClient {
store.addObjectPutCallback(this::onObjectPut);
// The thread pool that executes tasks in parallel.
exec = Executors.newFixedThreadPool(numberThreads);
idleWorkers = new LinkedList<>();
idleWorkers = new ConcurrentLinkedDeque<>();
actorWorkers = new HashMap<>();
currentWorker = new ThreadLocal<>();
}
@@ -69,19 +69,19 @@ public class MockRayletClient implements RayletClient {
/**
* Get a worker from the worker pool to run the given task.
*/
private Worker getWorker(TaskSpec task) {
private synchronized Worker getWorker(TaskSpec task) {
Worker worker;
if (task.isActorTask()) {
worker = actorWorkers.get(task.actorId);
} else {
if (idleWorkers.size() > 0) {
if (task.isActorCreationTask()) {
worker = new Worker(runtime);
actorWorkers.put(task.actorCreationId, worker);
} else if (idleWorkers.size() > 0) {
worker = idleWorkers.pop();
} else {
worker = new Worker(runtime);
}
if (task.isActorCreationTask()) {
actorWorkers.put(task.actorCreationId, worker);
}
}
currentWorker.set(worker);
return worker;
@@ -15,13 +15,17 @@ import org.ray.api.RayObject;
import org.ray.api.TestUtils;
import org.ray.api.WaitResult;
import org.ray.api.annotation.RayRemote;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.Assert;
import org.testng.annotations.Test;
public class MultiThreadingTest extends BaseTest {
private static final int LOOP_COUNTER = 1000;
private static final Logger LOGGER = LoggerFactory.getLogger(MultiThreadingTest.class);
private static final int LOOP_COUNTER = 100;
private static final int NUM_THREADS = 20;
@RayRemote
@@ -55,6 +59,21 @@ public class MultiThreadingTest extends BaseTest {
Assert.assertEquals(arg, (int) obj.get());
}, LOOP_COUNTER);
// Test creating multi actors
runTestCaseInMultipleThreads(() -> {
int arg = random.nextInt();
RayActor<Echo> echoActor1 = Ray.createActor(Echo::new);
try {
// Sleep a while to test the case that another actor is created before submitting
// tasks to this actor.
TimeUnit.MILLISECONDS.sleep(10);
} catch (InterruptedException e) {
LOGGER.warn("Got exception while sleeping.", e);
}
RayObject<Integer> obj = Ray.call(Echo::echo, echoActor1, arg);
Assert.assertEquals(arg, (int) obj.get());
}, 1);
// Test put and get.
runTestCaseInMultipleThreads(() -> {
int arg = random.nextInt();
@@ -74,8 +93,6 @@ public class MultiThreadingTest extends BaseTest {
@Test
public void testInDriver() {
// TODO(hchen): Fix this test under single-process mode.
TestUtils.skipTestUnderSingleProcess();
testMultiThreading();
}