diff --git a/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskSubmitter.java b/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskSubmitter.java index c0ecfdeba..5827e1379 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskSubmitter.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskSubmitter.java @@ -14,8 +14,8 @@ import java.util.List; import java.util.Map; import java.util.Random; import java.util.Set; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; +import java.util.concurrent.*; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; import org.ray.api.RayActor; import org.ray.api.id.ActorId; @@ -50,7 +50,13 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { private final Object taskAndObjectLock = new Object(); private final RayDevRuntime runtime; private final LocalModeObjectStore objectStore; - private final ExecutorService exec; + + /// The thread pool to execute actor tasks. + private final Map actorTaskExecutorServices; + + /// The thread pool to execute normal tasks. + private final ExecutorService normalTaskExecutorService; + private final Deque idleTaskExecutors = new ArrayDeque<>(); private final Map actorTaskExecutors = new HashMap<>(); private final Object taskExecutorLock = new Object(); @@ -60,8 +66,10 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { int numberThreads) { this.runtime = runtime; this.objectStore = objectStore; - // The thread pool that executes tasks in parallel. - exec = Executors.newFixedThreadPool(numberThreads); + // The thread pool that executes normal tasks in parallel. + normalTaskExecutorService = Executors.newFixedThreadPool(numberThreads); + // The thread pool that executes actor tasks in parallel. + actorTaskExecutorServices = new HashMap<>(); } public void onObjectPut(ObjectId id) { @@ -211,7 +219,14 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { } public void shutdown() { - exec.shutdown(); + // Shutdown actor task executor service. + synchronized (actorTaskExecutorServices) { + for (Map.Entry item : actorTaskExecutorServices.entrySet()) { + item.getValue().shutdown(); + } + } + // Shutdown normal task executor service. + normalTaskExecutorService.shutdown(); } public static ActorId getActorId(TaskSpec taskSpec) { @@ -231,37 +246,54 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { LOGGER.debug("Submitting task: {}.", taskSpec); synchronized (taskAndObjectLock) { Set unreadyObjects = getUnreadyObjects(taskSpec); + + final Runnable runnable = () -> { + TaskExecutor taskExecutor = getTaskExecutor(taskSpec); + try { + List args = getFunctionArgs(taskSpec).stream() + .map(arg -> arg.id != null ? + objectStore.getRaw(Collections.singletonList(arg.id), -1).get(0) + : arg.value) + .collect(Collectors.toList()); + ((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(taskSpec); + List returnObjects = taskExecutor + .execute(getJavaFunctionDescriptor(taskSpec).toList(), args); + ((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(null); + List returnIds = getReturnIds(taskSpec); + for (int i = 0; i < returnIds.size(); i++) { + NativeRayObject putObject; + if (i >= returnObjects.size()) { + // If the task is an actor task or an actor creation task, + // put the dummy object in object store, so those tasks which depends on it + // can be executed. + putObject = new NativeRayObject(new byte[]{1}, null); + } else { + putObject = returnObjects.get(i); + } + objectStore.putRaw(putObject, returnIds.get(i)); + } + } finally { + returnTaskExecutor(taskExecutor, taskSpec); + } + }; + if (unreadyObjects.isEmpty()) { // If all dependencies are ready, execute this task. - exec.submit(() -> { - TaskExecutor taskExecutor = getTaskExecutor(taskSpec); - try { - List args = getFunctionArgs(taskSpec).stream() - .map(arg -> arg.id != null ? - objectStore.getRaw(Collections.singletonList(arg.id), -1).get(0) - : arg.value) - .collect(Collectors.toList()); - ((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(taskSpec); - List returnObjects = taskExecutor - .execute(getJavaFunctionDescriptor(taskSpec).toList(), args); - ((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(null); - List returnIds = getReturnIds(taskSpec); - for (int i = 0; i < returnIds.size(); i++) { - NativeRayObject putObject; - if (i >= returnObjects.size()) { - // If the task is an actor task or an actor creation task, - // put the dummy object in object store, so those tasks which depends on it - // can be executed. - putObject = new NativeRayObject(new byte[]{1}, null); - } else { - putObject = returnObjects.get(i); - } - objectStore.putRaw(putObject, returnIds.get(i)); - } - } finally { - returnTaskExecutor(taskExecutor, taskSpec); + if (taskSpec.getType() == TaskType.ACTOR_CREATION_TASK) { + ExecutorService actorExecutorService = Executors.newSingleThreadExecutor(); + synchronized (actorTaskExecutorServices) { + actorTaskExecutorServices.put(getActorId(taskSpec), actorExecutorService); } - }); + actorExecutorService.submit(runnable); + } else if (taskSpec.getType() == TaskType.ACTOR_TASK) { + synchronized (actorTaskExecutorServices) { + ExecutorService actorExecutorService = actorTaskExecutorServices.get(getActorId(taskSpec)); + actorExecutorService.submit(runnable); + } + } else { + // Normal task. + normalTaskExecutorService.submit(runnable); + } } else { // If some dependencies aren't ready yet, put this task in waiting list. for (ObjectId id : unreadyObjects) { diff --git a/java/test/src/main/java/org/ray/api/TestUtils.java b/java/test/src/main/java/org/ray/api/TestUtils.java index 486e55fe0..94f3217fd 100644 --- a/java/test/src/main/java/org/ray/api/TestUtils.java +++ b/java/test/src/main/java/org/ray/api/TestUtils.java @@ -27,6 +27,12 @@ public class TestUtils { } } + public static void skipTestUnderClusterMode() { + if (getRuntime().getRayConfig().runMode == RunMode.CLUSTER) { + throw new SkipException("This test doesn't work under cluster mode."); + } + } + public static void skipTestIfDirectActorCallEnabled() { skipTestIfDirectActorCallEnabled(true); } diff --git a/java/test/src/main/java/org/ray/api/test/SingleProcessModeTest.java b/java/test/src/main/java/org/ray/api/test/SingleProcessModeTest.java new file mode 100644 index 000000000..4ccca362c --- /dev/null +++ b/java/test/src/main/java/org/ray/api/test/SingleProcessModeTest.java @@ -0,0 +1,65 @@ +package org.ray.api.test; + +import org.ray.api.Ray; +import org.ray.api.RayActor; +import org.ray.api.RayObject; +import org.ray.api.TestUtils; +import org.ray.api.annotation.RayRemote; +import org.ray.api.id.ActorId; +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class SingleProcessModeTest extends BaseTest { + + private final static int NUM_ACTOR_INSTANCE = 10; + + private final static int TIMES_TO_CALL_PER_ACTOR = 10; + + @RayRemote + static class MyActor { + public MyActor() { + } + + public long getThreadId() { + return Thread.currentThread().getId(); + } + } + + @Test + public void testActorTasksInOneThread() { + TestUtils.skipTestUnderClusterMode(); + + List> actors = new ArrayList<>(); + Map actorThreadIds = new HashMap<>(); + for (int i = 0; i < NUM_ACTOR_INSTANCE; ++i) { + RayActor actor = Ray.createActor(MyActor::new); + actors.add(actor); + actorThreadIds.put(actor.getId(), Ray.call(MyActor::getThreadId, actor).get()); + } + + Map>> allResults = new HashMap<>(); + for (int i = 0; i < NUM_ACTOR_INSTANCE; ++i) { + final RayActor actor = actors.get(i); + List> thisActorResult = new ArrayList<>(); + for (int j = 0; j < TIMES_TO_CALL_PER_ACTOR; ++j) { + thisActorResult.add(Ray.call(MyActor::getThreadId, actor)); + } + allResults.put(actor.getId(), thisActorResult); + } + + // check result. + for (int i = 0; i < NUM_ACTOR_INSTANCE; ++i) { + final RayActor actor = actors.get(i); + final List> thisActorResult = allResults.get(actor.getId()); + // assert + for (RayObject threadId : thisActorResult) { + Assert.assertEquals(threadId.get(), actorThreadIds.get(actor.getId())); + } + } + } +}