From ed761900f68c7c7aeff4e75c00e448c4fa3b9735 Mon Sep 17 00:00:00 2001 From: Kai Yang Date: Mon, 9 Sep 2019 14:29:20 +0800 Subject: [PATCH] [Java] Support direct actor call in Java worker (#5504) --- .../ray/api/options/ActorCreationOptions.java | 25 +- java/dependencies.bzl | 2 +- java/runtime/pom.xml | 2 +- .../org/ray/runtime/AbstractRayRuntime.java | 30 +- .../java/org/ray/runtime/RayDevRuntime.java | 16 +- .../org/ray/runtime/RayNativeRuntime.java | 18 +- .../org/ray/runtime/actor/NativeRayActor.java | 6 + .../runtime/raylet/LocalModeRayletClient.java | 29 -- .../runtime/raylet/NativeRayletClient.java | 57 ---- .../org/ray/runtime/raylet/RayletClient.java | 16 - .../ray/runtime/task/ArgumentsBuilder.java | 8 +- .../runtime/task/LocalModeTaskExecutor.java | 22 ++ .../runtime/task/LocalModeTaskSubmitter.java | 4 +- .../ray/runtime/task/NativeTaskExecutor.java | 102 ++++++ .../org/ray/runtime/task/TaskExecutor.java | 84 +---- java/streaming/pom.xml | 2 +- java/test.sh | 3 + java/test/pom.xml | 2 +- .../org/ray/api/RayAlterSuiteListener.java | 23 ++ .../src/main/java/org/ray/api/TestUtils.java | 13 + .../ray/api/test/ActorReconstructionTest.java | 5 +- .../main/java/org/ray/api/test/ActorTest.java | 30 +- .../ray/api/test/BaseMultiLanguageTest.java | 4 +- .../main/java/org/ray/api/test/BaseTest.java | 4 +- .../api/test/CrossLanguageInvocationTest.java | 5 +- .../java/org/ray/api/test/FailureTest.java | 9 +- .../org/ray/api/test/MultiThreadingTest.java | 9 +- .../java/org/ray/api/test/RayCallTest.java | 7 +- .../java/org/ray/api/test/StressTest.java | 2 +- java/testng.xml | 15 +- python/ray/includes/task.pxd | 3 +- python/ray/includes/task.pxi | 1 + src/ray/common/task/task_spec.cc | 8 +- src/ray/common/task/task_spec.h | 2 + src/ray/common/task/task_util.h | 4 +- src/ray/core_worker/context.cc | 5 + src/ray/core_worker/context.h | 5 + src/ray/core_worker/lib/java/jni_init.cc | 6 + src/ray/core_worker/lib/java/jni_utils.h | 4 + .../java/org_ray_runtime_RayNativeRuntime.cc | 19 ++ .../java/org_ray_runtime_RayNativeRuntime.h | 8 + .../org_ray_runtime_actor_NativeRayActor.cc | 10 + .../org_ray_runtime_actor_NativeRayActor.h | 8 + ...g_ray_runtime_raylet_NativeRayletClient.cc | 74 ----- ...rg_ray_runtime_raylet_NativeRayletClient.h | 39 --- ...org_ray_runtime_task_NativeTaskExecutor.cc | 55 ++++ .../org_ray_runtime_task_NativeTaskExecutor.h | 33 ++ ...rg_ray_runtime_task_NativeTaskSubmitter.cc | 9 +- src/ray/core_worker/task_execution.cc | 7 +- src/ray/core_worker/task_interface.cc | 3 +- .../transport/direct_actor_transport.cc | 76 ++++- .../transport/direct_actor_transport.h | 31 +- .../core_worker/transport/raylet_transport.cc | 11 +- .../core_worker/transport/raylet_transport.h | 5 +- src/ray/protobuf/common.proto | 2 + src/ray/protobuf/gcs.proto | 2 + src/ray/raylet/actor_registration.cc | 21 +- src/ray/raylet/actor_registration.h | 5 +- ...org_ray_runtime_raylet_RayletClientImpl.cc | 292 ------------------ src/ray/raylet/node_manager.cc | 19 +- src/ray/raylet/worker.cc | 15 +- 61 files changed, 608 insertions(+), 728 deletions(-) delete mode 100644 java/runtime/src/main/java/org/ray/runtime/raylet/LocalModeRayletClient.java delete mode 100644 java/runtime/src/main/java/org/ray/runtime/raylet/NativeRayletClient.java delete mode 100644 java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java create mode 100644 java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskExecutor.java create mode 100644 java/runtime/src/main/java/org/ray/runtime/task/NativeTaskExecutor.java create mode 100644 java/test/src/main/java/org/ray/api/RayAlterSuiteListener.java delete mode 100644 src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.cc delete mode 100644 src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.h create mode 100644 src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskExecutor.cc create mode 100644 src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskExecutor.h delete mode 100644 src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc diff --git a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java index 928df3221..cc8a126eb 100644 --- a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java +++ b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java @@ -10,26 +10,33 @@ public class ActorCreationOptions extends BaseTaskOptions { public static final int NO_RECONSTRUCTION = 0; public static final int INFINITE_RECONSTRUCTIONS = (int) Math.pow(2, 30); + // DO NOT set this environment variable. It's only used for test purposes. + // Please use `setUseDirectCall` instead. + public static final boolean DEFAULT_USE_DIRECT_CALL = "1" + .equals(System.getenv("ACTOR_CREATION_OPTIONS_DEFAULT_USE_DIRECT_CALL")); public final int maxReconstructions; + public final boolean useDirectCall; + public final String jvmOptions; - private ActorCreationOptions(Map resources, - int maxReconstructions, - String jvmOptions) { + private ActorCreationOptions(Map resources, int maxReconstructions, + boolean useDirectCall, String jvmOptions) { super(resources); this.maxReconstructions = maxReconstructions; + this.useDirectCall = useDirectCall; this.jvmOptions = jvmOptions; } /** - * The inner class for building ActorCreationOptions. + * The inner class for building ActorCreationOptions. */ public static class Builder { private Map resources = new HashMap<>(); private int maxReconstructions = NO_RECONSTRUCTION; + private boolean useDirectCall = DEFAULT_USE_DIRECT_CALL; private String jvmOptions = null; public Builder setResources(Map resources) { @@ -42,13 +49,21 @@ public class ActorCreationOptions extends BaseTaskOptions { return this; } + // Since direct call is not fully supported yet (see issue #5559), + // users are not allowed to set the option to true. + // TODO (kfstorm): uncomment when direct call is ready. +// public Builder setUseDirectCall(boolean useDirectCall) { +// this.useDirectCall = useDirectCall; +// return this; +// } + public Builder setJvmOptions(String jvmOptions) { this.jvmOptions = jvmOptions; return this; } public ActorCreationOptions createActorCreationOptions() { - return new ActorCreationOptions(resources, maxReconstructions, jvmOptions); + return new ActorCreationOptions(resources, maxReconstructions, useDirectCall, jvmOptions); } } diff --git a/java/dependencies.bzl b/java/dependencies.bzl index 26e36dff5..c51a181ed 100644 --- a/java/dependencies.bzl +++ b/java/dependencies.bzl @@ -16,7 +16,7 @@ def gen_java_deps(): "org.apache.commons:commons-lang3:3.4", "org.ow2.asm:asm:6.0", "org.slf4j:slf4j-log4j12:1.7.25", - "org.testng:testng:6.9.9", + "org.testng:testng:6.9.10", "redis.clients:jedis:2.8.0", ], repositories = [ diff --git a/java/runtime/pom.xml b/java/runtime/pom.xml index 3c40f7ffc..eb6c268f8 100644 --- a/java/runtime/pom.xml +++ b/java/runtime/pom.xml @@ -75,7 +75,7 @@ org.testng testng - 6.9.9 + 6.9.10 redis.clients diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index 96f55384e..be7c4e471 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -13,11 +13,11 @@ import org.ray.api.exception.RayException; import org.ray.api.function.RayFunc; import org.ray.api.function.RayFuncVoid; import org.ray.api.id.ObjectId; -import org.ray.api.id.UniqueId; import org.ray.api.options.ActorCreationOptions; import org.ray.api.options.CallOptions; import org.ray.api.runtime.RayRuntime; import org.ray.api.runtimecontext.RuntimeContext; +import org.ray.runtime.actor.NativeRayActor; import org.ray.runtime.config.RayConfig; import org.ray.runtime.context.RuntimeContextImpl; import org.ray.runtime.context.WorkerContext; @@ -28,7 +28,6 @@ import org.ray.runtime.gcs.GcsClient; import org.ray.runtime.generated.Common.Language; import org.ray.runtime.object.ObjectStore; import org.ray.runtime.object.RayObjectImpl; -import org.ray.runtime.raylet.RayletClient; import org.ray.runtime.task.ArgumentsBuilder; import org.ray.runtime.task.FunctionArg; import org.ray.runtime.task.TaskExecutor; @@ -51,7 +50,6 @@ public abstract class AbstractRayRuntime implements RayRuntime { protected ObjectStore objectStore; protected TaskSubmitter taskSubmitter; - protected RayletClient rayletClient; protected WorkerContext workerContext; public AbstractRayRuntime(RayConfig rayConfig) { @@ -85,15 +83,6 @@ public abstract class AbstractRayRuntime implements RayRuntime { objectStore.delete(objectIds, localOnly, deleteCreatingTasks); } - @Override - public void setResource(String resourceName, double capacity, UniqueId nodeId) { - Preconditions.checkArgument(Double.compare(capacity, 0) >= 0); - if (nodeId == null) { - nodeId = UniqueId.NIL; - } - rayletClient.setResource(resourceName, capacity, nodeId); - } - @Override public WaitResult wait(List> waitList, int numReturns, int timeoutMs) { return objectStore.wait(waitList, numReturns, timeoutMs); @@ -176,7 +165,7 @@ public abstract class AbstractRayRuntime implements RayRuntime { private RayObject callNormalFunction(FunctionDescriptor functionDescriptor, Object[] args, int numReturns, CallOptions options) { - List functionArgs = ArgumentsBuilder.wrap(args); + List functionArgs = ArgumentsBuilder.wrap(args, /*isDirectCall*/false); List returnIds = taskSubmitter.submitTask(functionDescriptor, functionArgs, numReturns, options); Preconditions.checkState(returnIds.size() == numReturns && returnIds.size() <= 1); @@ -189,7 +178,7 @@ public abstract class AbstractRayRuntime implements RayRuntime { private RayObject callActorFunction(RayActor rayActor, FunctionDescriptor functionDescriptor, Object[] args, int numReturns) { - List functionArgs = ArgumentsBuilder.wrap(args); + List functionArgs = ArgumentsBuilder.wrap(args, isDirectCall(rayActor)); List returnIds = taskSubmitter.submitActorTask(rayActor, functionDescriptor, functionArgs, numReturns, null); Preconditions.checkState(returnIds.size() == numReturns && returnIds.size() <= 1); @@ -202,7 +191,7 @@ public abstract class AbstractRayRuntime implements RayRuntime { private RayActor createActorImpl(FunctionDescriptor functionDescriptor, Object[] args, ActorCreationOptions options) { - List functionArgs = ArgumentsBuilder.wrap(args); + List functionArgs = ArgumentsBuilder.wrap(args, /*isDirectCall*/false); if (functionDescriptor.getLanguage() != Language.JAVA && options != null) { Preconditions.checkState(Strings.isNullOrEmpty(options.jvmOptions)); } @@ -210,6 +199,13 @@ public abstract class AbstractRayRuntime implements RayRuntime { return actor; } + private boolean isDirectCall(RayActor rayActor) { + if (rayActor instanceof NativeRayActor) { + return ((NativeRayActor) rayActor).isDirectCallActor(); + } + return false; + } + public WorkerContext getWorkerContext() { return workerContext; } @@ -218,10 +214,6 @@ public abstract class AbstractRayRuntime implements RayRuntime { return objectStore; } - public RayletClient getRayletClient() { - return rayletClient; - } - public FunctionManager getFunctionManager() { return functionManager; } diff --git a/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java index c96b811ce..e6891a50f 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java @@ -2,15 +2,19 @@ package org.ray.runtime; import java.util.concurrent.atomic.AtomicInteger; import org.ray.api.id.JobId; +import org.ray.api.id.UniqueId; import org.ray.runtime.config.RayConfig; import org.ray.runtime.context.LocalModeWorkerContext; import org.ray.runtime.object.LocalModeObjectStore; -import org.ray.runtime.raylet.LocalModeRayletClient; +import org.ray.runtime.task.LocalModeTaskExecutor; import org.ray.runtime.task.LocalModeTaskSubmitter; -import org.ray.runtime.task.TaskExecutor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class RayDevRuntime extends AbstractRayRuntime { + private static final Logger LOGGER = LoggerFactory.getLogger(RayDevRuntime.class); + private AtomicInteger jobCounter = new AtomicInteger(0); public RayDevRuntime(RayConfig rayConfig) { @@ -18,14 +22,13 @@ public class RayDevRuntime extends AbstractRayRuntime { if (rayConfig.getJobId().isNil()) { rayConfig.setJobId(nextJobId()); } - taskExecutor = new TaskExecutor(this); + taskExecutor = new LocalModeTaskExecutor(this); workerContext = new LocalModeWorkerContext(rayConfig.getJobId()); objectStore = new LocalModeObjectStore(workerContext); taskSubmitter = new LocalModeTaskSubmitter(this, (LocalModeObjectStore) objectStore, rayConfig.numberExecThreadsForDevRuntime); ((LocalModeObjectStore) objectStore).addObjectPutCallback( objectId -> ((LocalModeTaskSubmitter) taskSubmitter).onObjectPut(objectId)); - rayletClient = new LocalModeRayletClient(); } @Override @@ -33,6 +36,11 @@ public class RayDevRuntime extends AbstractRayRuntime { taskExecutor = null; } + @Override + public void setResource(String resourceName, double capacity, UniqueId nodeId) { + LOGGER.error("Not implemented under SINGLE_PROCESS mode."); + } + private JobId nextJobId() { return JobId.fromInt(jobCounter.getAndIncrement()); } diff --git a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java index 99908636d..f9e2364cc 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java @@ -9,6 +9,7 @@ import java.util.HashMap; import java.util.Map; import org.apache.commons.io.FileUtils; import org.ray.api.id.JobId; +import org.ray.api.id.UniqueId; import org.ray.runtime.config.RayConfig; import org.ray.runtime.context.NativeWorkerContext; import org.ray.runtime.gcs.GcsClient; @@ -16,8 +17,8 @@ import org.ray.runtime.gcs.GcsClientOptions; import org.ray.runtime.gcs.RedisClient; import org.ray.runtime.generated.Common.WorkerType; import org.ray.runtime.object.NativeObjectStore; -import org.ray.runtime.raylet.NativeRayletClient; import org.ray.runtime.runner.RunManager; +import org.ray.runtime.task.NativeTaskExecutor; import org.ray.runtime.task.NativeTaskSubmitter; import org.ray.runtime.task.TaskExecutor; import org.ray.runtime.util.FileUtil; @@ -112,11 +113,10 @@ public final class RayNativeRuntime extends AbstractRayRuntime { new GcsClientOptions(rayConfig)); Preconditions.checkState(nativeCoreWorkerPointer != 0); - taskExecutor = new TaskExecutor(this); + taskExecutor = new NativeTaskExecutor(nativeCoreWorkerPointer, this); workerContext = new NativeWorkerContext(nativeCoreWorkerPointer); objectStore = new NativeObjectStore(workerContext, nativeCoreWorkerPointer); taskSubmitter = new NativeTaskSubmitter(nativeCoreWorkerPointer); - rayletClient = new NativeRayletClient(nativeCoreWorkerPointer); // register registerWorker(); @@ -136,6 +136,15 @@ public final class RayNativeRuntime extends AbstractRayRuntime { } } + @Override + public void setResource(String resourceName, double capacity, UniqueId nodeId) { + Preconditions.checkArgument(Double.compare(capacity, 0) >= 0); + if (nodeId == null) { + nodeId = UniqueId.NIL; + } + nativeSetResource(nativeCoreWorkerPointer, resourceName, capacity, nodeId.getBytes()); + } + public void run() { nativeRunTaskExecutor(nativeCoreWorkerPointer, taskExecutor); } @@ -176,4 +185,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime { private static native void nativeSetup(String logDir); private static native void nativeShutdownHook(); + + private static native void nativeSetResource(long conn, String resourceName, double capacity, + byte[] nodeId); } diff --git a/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayActor.java b/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayActor.java index ecdf03053..8dd7ac8c3 100644 --- a/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayActor.java +++ b/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayActor.java @@ -51,6 +51,10 @@ public class NativeRayActor implements RayActor, RayPyActor, Externalizable { return Language.forNumber(nativeGetLanguage(nativeActorHandle)); } + public boolean isDirectCallActor() { + return nativeIsDirectCallActor(nativeActorHandle); + } + @Override public String getModuleName() { Preconditions.checkState(getLanguage() == Language.PYTHON); @@ -90,6 +94,8 @@ public class NativeRayActor implements RayActor, RayPyActor, Externalizable { private static native int nativeGetLanguage(long nativeActorHandle); + private static native boolean nativeIsDirectCallActor(long nativeActorHandle); + private static native List nativeGetActorCreationTaskFunctionDescriptor( long nativeActorHandle); diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/LocalModeRayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/LocalModeRayletClient.java deleted file mode 100644 index 9d43244c3..000000000 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/LocalModeRayletClient.java +++ /dev/null @@ -1,29 +0,0 @@ -package org.ray.runtime.raylet; - -import org.apache.commons.lang3.NotImplementedException; -import org.ray.api.id.ActorId; -import org.ray.api.id.UniqueId; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Raylet client for local mode. - */ -public class LocalModeRayletClient implements RayletClient { - private static final Logger LOGGER = LoggerFactory.getLogger(LocalModeRayletClient.class); - - @Override - public UniqueId prepareCheckpoint(ActorId actorId) { - throw new NotImplementedException("Not implemented."); - } - - @Override - public void notifyActorResumedFromCheckpoint(ActorId actorId, UniqueId checkpointId) { - throw new NotImplementedException("Not implemented."); - } - - @Override - public void setResource(String resourceName, double capacity, UniqueId nodeId) { - LOGGER.error("Not implemented under SINGLE_PROCESS mode."); - } -} diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/NativeRayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/NativeRayletClient.java deleted file mode 100644 index ed5f10f12..000000000 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/NativeRayletClient.java +++ /dev/null @@ -1,57 +0,0 @@ -package org.ray.runtime.raylet; - -import org.ray.api.exception.RayException; -import org.ray.api.id.ActorId; -import org.ray.api.id.UniqueId; - -/** - * Raylet client for cluster mode. This is a wrapper class for C++ RayletClient. - */ -public class NativeRayletClient implements RayletClient { - - /** - * The native pointer of core worker. - */ - private long nativeCoreWorkerPointer = 0; - - public NativeRayletClient(long nativeCoreWorkerPointer) { - this.nativeCoreWorkerPointer = nativeCoreWorkerPointer; - } - - @Override - public UniqueId prepareCheckpoint(ActorId actorId) { - return new UniqueId(nativePrepareCheckpoint(nativeCoreWorkerPointer, actorId.getBytes())); - } - - @Override - public void notifyActorResumedFromCheckpoint(ActorId actorId, UniqueId checkpointId) { - nativeNotifyActorResumedFromCheckpoint(nativeCoreWorkerPointer, actorId.getBytes(), - checkpointId.getBytes()); - } - - - public void setResource(String resourceName, double capacity, UniqueId nodeId) { - nativeSetResource(nativeCoreWorkerPointer, resourceName, capacity, nodeId.getBytes()); - } - - /// Native method declarations. - /// - /// If you change the signature of any native methods, please re-generate - /// the C++ header file and update the C++ implementation accordingly: - /// - /// Suppose that $Dir is your ray root directory. - /// 1) pushd $Dir/java/runtime/target/classes - /// 2) javah -classpath .:$Dir/java/api/target/classes org.ray.runtime.raylet.NativeRayletClient - /// 3) clang-format -i org_ray_runtime_raylet_NativeRayletClient.h - /// 4) cp org_ray_runtime_raylet_NativeRayletClient.h $Dir/src/ray/core_worker/lib/java/ - /// 5) vim $Dir/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.cc - /// 6) popd - - private static native byte[] nativePrepareCheckpoint(long conn, byte[] actorId); - - private static native void nativeNotifyActorResumedFromCheckpoint(long conn, byte[] actorId, - byte[] checkpointId); - - private static native void nativeSetResource(long conn, String resourceName, double capacity, - byte[] nodeId) throws RayException; -} diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java deleted file mode 100644 index 144187b6b..000000000 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java +++ /dev/null @@ -1,16 +0,0 @@ -package org.ray.runtime.raylet; - -import org.ray.api.id.ActorId; -import org.ray.api.id.UniqueId; - -/** - * Client to the Raylet backend. - */ -public interface RayletClient { - - UniqueId prepareCheckpoint(ActorId actorId); - - void notifyActorResumedFromCheckpoint(ActorId actorId, UniqueId checkpointId); - - void setResource(String resourceName, double capacity, UniqueId nodeId); -} diff --git a/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java b/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java index 11e524619..07ae3dfb1 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java @@ -25,16 +25,20 @@ public class ArgumentsBuilder { /** * Convert real function arguments to task spec arguments. */ - public static List wrap(Object[] args) { + public static List wrap(Object[] args, boolean isDirectCall) { List ret = new ArrayList<>(); for (Object arg : args) { ObjectId id = null; NativeRayObject value = null; if (arg instanceof RayObject) { + if (isDirectCall) { + throw new IllegalArgumentException( + "Passing RayObject to a direct call actor is not supported."); + } id = ((RayObject) arg).getId(); } else { value = ObjectSerializer.serialize(arg); - if (value.data.length > LARGEST_SIZE_PASS_BY_VALUE) { + if (!isDirectCall && value.data.length > LARGEST_SIZE_PASS_BY_VALUE) { RayRuntime runtime = Ray.internal(); if (runtime instanceof RayMultiWorkerNativeRuntime) { runtime = ((RayMultiWorkerNativeRuntime) runtime).getCurrentRuntime(); diff --git a/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskExecutor.java b/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskExecutor.java new file mode 100644 index 000000000..24e6f15b9 --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskExecutor.java @@ -0,0 +1,22 @@ +package org.ray.runtime.task; + +import org.ray.api.id.ActorId; +import org.ray.runtime.AbstractRayRuntime; + +/** + * Task executor for local mode. + */ +public class LocalModeTaskExecutor extends TaskExecutor { + + public LocalModeTaskExecutor(AbstractRayRuntime runtime) { + super(runtime); + } + + @Override + protected void maybeSaveCheckpoint(Object actor, ActorId actorId) { + } + + @Override + protected void maybeLoadCheckpoint(Object actor, ActorId actorId) { + } +} diff --git a/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskSubmitter.java b/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskSubmitter.java index 0cb23fcb6..026cb9cc4 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskSubmitter.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskSubmitter.java @@ -95,12 +95,12 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { if (task.getType() == TaskType.ACTOR_TASK) { taskExecutor = actorTaskExecutors.get(getActorId(task)); } else if (task.getType() == TaskType.ACTOR_CREATION_TASK) { - taskExecutor = new TaskExecutor(runtime); + taskExecutor = new LocalModeTaskExecutor(runtime); actorTaskExecutors.put(getActorId(task), taskExecutor); } else if (idleTaskExecutors.size() > 0) { taskExecutor = idleTaskExecutors.pop(); } else { - taskExecutor = new TaskExecutor(runtime); + taskExecutor = new LocalModeTaskExecutor(runtime); } } currentTaskExecutor.set(taskExecutor); diff --git a/java/runtime/src/main/java/org/ray/runtime/task/NativeTaskExecutor.java b/java/runtime/src/main/java/org/ray/runtime/task/NativeTaskExecutor.java new file mode 100644 index 000000000..36e7259a4 --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/task/NativeTaskExecutor.java @@ -0,0 +1,102 @@ +package org.ray.runtime.task; + +import com.google.common.base.Preconditions; +import java.util.ArrayList; +import java.util.List; +import org.ray.api.Checkpointable; +import org.ray.api.Checkpointable.Checkpoint; +import org.ray.api.Checkpointable.CheckpointContext; +import org.ray.api.id.ActorId; +import org.ray.api.id.UniqueId; +import org.ray.runtime.AbstractRayRuntime; + +/** + * Task executor for cluster mode. + */ +public class NativeTaskExecutor extends TaskExecutor { + + // TODO(hchen): Use the C++ config. + private static final int NUM_ACTOR_CHECKPOINTS_TO_KEEP = 20; + + /** + * The native pointer of core worker. + */ + private final long nativeCoreWorkerPointer; + + /** + * Number of tasks executed since last actor checkpoint. + */ + private int numTasksSinceLastCheckpoint = 0; + + /** + * IDs of this actor's previous checkpoints. + */ + private List checkpointIds; + + /** + * Timestamp of the last actor checkpoint. + */ + private long lastCheckpointTimestamp = 0; + + public NativeTaskExecutor(long nativeCoreWorkerPointer, AbstractRayRuntime runtime) { + super(runtime); + this.nativeCoreWorkerPointer = nativeCoreWorkerPointer; + } + + @Override + protected void maybeSaveCheckpoint(Object actor, ActorId actorId) { + if (!(actor instanceof Checkpointable)) { + return; + } + CheckpointContext checkpointContext = new CheckpointContext(actorId, + ++numTasksSinceLastCheckpoint, System.currentTimeMillis() - lastCheckpointTimestamp); + Checkpointable checkpointable = (Checkpointable) actor; + if (!checkpointable.shouldCheckpoint(checkpointContext)) { + return; + } + numTasksSinceLastCheckpoint = 0; + lastCheckpointTimestamp = System.currentTimeMillis(); + UniqueId checkpointId = new UniqueId(nativePrepareCheckpoint(nativeCoreWorkerPointer)); + checkpointIds.add(checkpointId); + if (checkpointIds.size() > NUM_ACTOR_CHECKPOINTS_TO_KEEP) { + ((Checkpointable) actor).checkpointExpired(actorId, checkpointIds.get(0)); + checkpointIds.remove(0); + } + checkpointable.saveCheckpoint(actorId, checkpointId); + } + + @Override + protected void maybeLoadCheckpoint(Object actor, ActorId actorId) { + if (!(actor instanceof Checkpointable)) { + return; + } + numTasksSinceLastCheckpoint = 0; + lastCheckpointTimestamp = System.currentTimeMillis(); + checkpointIds = new ArrayList<>(); + List availableCheckpoints + = runtime.getGcsClient().getCheckpointsForActor(actorId); + if (availableCheckpoints.isEmpty()) { + return; + } + UniqueId checkpointId = ((Checkpointable) actor).loadCheckpoint(actorId, availableCheckpoints); + if (checkpointId != null) { + boolean checkpointValid = false; + for (Checkpoint checkpoint : availableCheckpoints) { + if (checkpoint.checkpointId.equals(checkpointId)) { + checkpointValid = true; + break; + } + } + Preconditions.checkArgument(checkpointValid, + "'loadCheckpoint' must return a checkpoint ID that exists in the " + + "'availableCheckpoints' list, or null."); + + nativeNotifyActorResumedFromCheckpoint(nativeCoreWorkerPointer, checkpointId.getBytes()); + } + } + + private static native byte[] nativePrepareCheckpoint(long nativeCoreWorkerPointer); + + private static native void nativeNotifyActorResumedFromCheckpoint(long nativeCoreWorkerPointer, + byte[] checkpointId); +} diff --git a/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java b/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java index 95ff86c67..dedf693ed 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java @@ -3,16 +3,11 @@ package org.ray.runtime.task; import com.google.common.base.Preconditions; import java.util.ArrayList; import java.util.List; -import org.ray.api.Checkpointable; -import org.ray.api.Checkpointable.Checkpoint; -import org.ray.api.Checkpointable.CheckpointContext; import org.ray.api.exception.RayTaskException; import org.ray.api.id.ActorId; import org.ray.api.id.JobId; import org.ray.api.id.TaskId; -import org.ray.api.id.UniqueId; import org.ray.runtime.AbstractRayRuntime; -import org.ray.runtime.config.RunMode; import org.ray.runtime.functionmanager.JavaFunctionDescriptor; import org.ray.runtime.functionmanager.RayFunction; import org.ray.runtime.generated.Common.TaskType; @@ -24,13 +19,10 @@ import org.slf4j.LoggerFactory; /** * The task executor, which executes tasks assigned by raylet continuously. */ -public final class TaskExecutor { +public abstract class TaskExecutor { private static final Logger LOGGER = LoggerFactory.getLogger(TaskExecutor.class); - // TODO(hchen): Use the C++ config. - private static final int NUM_ACTOR_CHECKPOINTS_TO_KEEP = 20; - protected final AbstractRayRuntime runtime; /** @@ -43,22 +35,7 @@ public final class TaskExecutor { */ private Exception actorCreationException = null; - /** - * Number of tasks executed since last actor checkpoint. - */ - private int numTasksSinceLastCheckpoint = 0; - - /** - * IDs of this actor's previous checkpoints. - */ - private List checkpointIds; - - /** - * Timestamp of the last actor checkpoint. - */ - private long lastCheckpointTimestamp = 0; - - public TaskExecutor(AbstractRayRuntime runtime) { + protected TaskExecutor(AbstractRayRuntime runtime) { this.runtime = runtime; } @@ -134,60 +111,7 @@ public final class TaskExecutor { rayFunctionInfo.get(2)); } - private void maybeSaveCheckpoint(Object actor, ActorId actorId) { - if (!(actor instanceof Checkpointable)) { - return; - } - if (runtime.getRayConfig().runMode == RunMode.SINGLE_PROCESS) { - // Actor checkpointing isn't implemented for SINGLE_PROCESS mode yet. - return; - } - CheckpointContext checkpointContext = new CheckpointContext(actorId, - ++numTasksSinceLastCheckpoint, System.currentTimeMillis() - lastCheckpointTimestamp); - Checkpointable checkpointable = (Checkpointable) actor; - if (!checkpointable.shouldCheckpoint(checkpointContext)) { - return; - } - numTasksSinceLastCheckpoint = 0; - lastCheckpointTimestamp = System.currentTimeMillis(); - UniqueId checkpointId = runtime.getRayletClient().prepareCheckpoint(actorId); - checkpointIds.add(checkpointId); - if (checkpointIds.size() > NUM_ACTOR_CHECKPOINTS_TO_KEEP) { - ((Checkpointable) actor).checkpointExpired(actorId, checkpointIds.get(0)); - checkpointIds.remove(0); - } - checkpointable.saveCheckpoint(actorId, checkpointId); - } + protected abstract void maybeSaveCheckpoint(Object actor, ActorId actorId); - private void maybeLoadCheckpoint(Object actor, ActorId actorId) { - if (!(actor instanceof Checkpointable)) { - return; - } - if (runtime.getRayConfig().runMode == RunMode.SINGLE_PROCESS) { - // Actor checkpointing isn't implemented for SINGLE_PROCESS mode yet. - return; - } - numTasksSinceLastCheckpoint = 0; - lastCheckpointTimestamp = System.currentTimeMillis(); - checkpointIds = new ArrayList<>(); - List availableCheckpoints - = runtime.getGcsClient().getCheckpointsForActor(actorId); - if (availableCheckpoints.isEmpty()) { - return; - } - UniqueId checkpointId = ((Checkpointable) actor).loadCheckpoint(actorId, availableCheckpoints); - if (checkpointId != null) { - boolean checkpointValid = false; - for (Checkpoint checkpoint : availableCheckpoints) { - if (checkpoint.checkpointId.equals(checkpointId)) { - checkpointValid = true; - break; - } - } - Preconditions.checkArgument(checkpointValid, - "'loadCheckpoint' must return a checkpoint ID that exists in the " - + "'availableCheckpoints' list, or null."); - runtime.getRayletClient().notifyActorResumedFromCheckpoint(actorId, checkpointId); - } - } + protected abstract void maybeLoadCheckpoint(Object actor, ActorId actorId); } diff --git a/java/streaming/pom.xml b/java/streaming/pom.xml index 382233fb0..e624bd6e5 100644 --- a/java/streaming/pom.xml +++ b/java/streaming/pom.xml @@ -50,7 +50,7 @@ org.testng testng - 6.9.9 + 6.9.10 diff --git a/java/test.sh b/java/test.sh index ba728f14b..4612bf7e3 100755 --- a/java/test.sh +++ b/java/test.sh @@ -27,6 +27,9 @@ echo "Running tests under cluster mode." # bazel test //java:all_tests --action_env=ENABLE_MULTI_LANGUAGE_TESTS=1 --test_output="errors" || cluster_exit_code=$? ENABLE_MULTI_LANGUAGE_TESTS=1 run_testng java -cp $ROOT_DIR/../bazel-bin/java/all_tests_deploy.jar org.testng.TestNG -d /tmp/ray_java_test_output $ROOT_DIR/testng.xml +echo "Running tests under cluster mode with direct actor call turned on." +ENABLE_MULTI_LANGUAGE_TESTS=1 ACTOR_CREATION_OPTIONS_DEFAULT_USE_DIRECT_CALL=1 run_testng java -cp $ROOT_DIR/../bazel-bin/java/all_tests_deploy.jar org.testng.TestNG -d /tmp/ray_java_test_output $ROOT_DIR/testng.xml + echo "Running tests under single-process mode." # bazel test //java:all_tests --jvmopt="-Dray.run-mode=SINGLE_PROCESS" --test_output="errors" || single_exit_code=$? run_testng java -Dray.run-mode="SINGLE_PROCESS" -cp $ROOT_DIR/../bazel-bin/java/all_tests_deploy.jar org.testng.TestNG -d /tmp/ray_java_test_output $ROOT_DIR/testng.xml diff --git a/java/test/pom.xml b/java/test/pom.xml index 6a3a31d20..3dfbbbae8 100644 --- a/java/test/pom.xml +++ b/java/test/pom.xml @@ -65,7 +65,7 @@ org.testng testng - 6.9.9 + 6.9.10 diff --git a/java/test/src/main/java/org/ray/api/RayAlterSuiteListener.java b/java/test/src/main/java/org/ray/api/RayAlterSuiteListener.java new file mode 100644 index 000000000..d5d042c1d --- /dev/null +++ b/java/test/src/main/java/org/ray/api/RayAlterSuiteListener.java @@ -0,0 +1,23 @@ +package org.ray.api; + +import java.util.List; +import org.ray.api.options.ActorCreationOptions; +import org.testng.IAlterSuiteListener; +import org.testng.xml.XmlGroups; +import org.testng.xml.XmlRun; +import org.testng.xml.XmlSuite; + +public class RayAlterSuiteListener implements IAlterSuiteListener { + + @Override + public void alter(List suites) { + XmlSuite suite = suites.get(0); + if (ActorCreationOptions.DEFAULT_USE_DIRECT_CALL) { + XmlGroups groups = new XmlGroups(); + XmlRun run = new XmlRun(); + run.onInclude("directCall"); + groups.setRun(run); + suite.setGroups(groups); + } + } +} diff --git a/java/test/src/main/java/org/ray/api/TestUtils.java b/java/test/src/main/java/org/ray/api/TestUtils.java index a03bd627f..e036aa7d2 100644 --- a/java/test/src/main/java/org/ray/api/TestUtils.java +++ b/java/test/src/main/java/org/ray/api/TestUtils.java @@ -1,8 +1,10 @@ package org.ray.api; import com.google.common.base.Preconditions; +import java.io.Serializable; import java.util.function.Supplier; import org.ray.api.annotation.RayRemote; +import org.ray.api.options.ActorCreationOptions; import org.ray.api.runtime.RayRuntime; import org.ray.runtime.AbstractRayRuntime; import org.ray.runtime.RayMultiWorkerNativeRuntime; @@ -12,6 +14,11 @@ import org.testng.SkipException; public class TestUtils { + public static class LargeObject implements Serializable { + + public byte[] data = new byte[1024 * 1024]; + } + private static final int WAIT_INTERVAL_MS = 5; public static void skipTestUnderSingleProcess() { @@ -20,6 +27,12 @@ public class TestUtils { } } + public static void skipTestIfDirectActorCallEnabled() { + if (ActorCreationOptions.DEFAULT_USE_DIRECT_CALL) { + throw new SkipException("This test doesn't work when direct actor call is enabled."); + } + } + /** * Wait until the given condition is met. * diff --git a/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java b/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java index 3e50b4d96..43ccfe0ff 100644 --- a/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java +++ b/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java @@ -17,6 +17,7 @@ import org.ray.api.options.ActorCreationOptions; import org.testng.Assert; import org.testng.annotations.Test; +@Test(groups = {"directCall"}) public class ActorReconstructionTest extends BaseTest { @RayRemote() @@ -44,7 +45,6 @@ public class ActorReconstructionTest extends BaseTest { } } - @Test public void testActorReconstruction() throws InterruptedException, IOException { TestUtils.skipTestUnderSingleProcess(); ActorCreationOptions options = @@ -65,7 +65,7 @@ public class ActorReconstructionTest extends BaseTest { // Try calling increase on this actor again and check the value is now 4. int value = Ray.call(Counter::increase, actor).get(); - Assert.assertEquals(value, 4); + Assert.assertEquals(value, options.useDirectCall ? 1 : 4); Assert.assertTrue(Ray.call(Counter::wasCurrentActorReconstructed, actor).get()); @@ -125,7 +125,6 @@ public class ActorReconstructionTest extends BaseTest { } } - @Test public void testActorCheckpointing() throws IOException, InterruptedException { TestUtils.skipTestUnderSingleProcess(); ActorCreationOptions options = diff --git a/java/test/src/main/java/org/ray/api/test/ActorTest.java b/java/test/src/main/java/org/ray/api/test/ActorTest.java index a8870558e..8b1b441e4 100644 --- a/java/test/src/main/java/org/ray/api/test/ActorTest.java +++ b/java/test/src/main/java/org/ray/api/test/ActorTest.java @@ -8,6 +8,7 @@ import org.ray.api.Ray; import org.ray.api.RayActor; import org.ray.api.RayObject; import org.ray.api.TestUtils; +import org.ray.api.TestUtils.LargeObject; import org.ray.api.annotation.RayRemote; import org.ray.api.exception.UnreconstructableException; import org.ray.api.id.UniqueId; @@ -16,6 +17,7 @@ import org.ray.runtime.object.NativeRayObject; import org.testng.Assert; import org.testng.annotations.Test; +@Test(groups = {"directCall"}) public class ActorTest extends BaseTest { @RayRemote @@ -39,9 +41,13 @@ public class ActorTest extends BaseTest { value += delta; return value; } + + public int accessLargeObject(LargeObject largeObject) { + value += largeObject.data.length; + return value; + } } - @Test public void testCreateAndCallActor() { // Test creating an actor from a constructor RayActor actor = Ray.createActor(Counter::new, 1); @@ -52,12 +58,18 @@ public class ActorTest extends BaseTest { Assert.assertEquals(Integer.valueOf(3), Ray.call(Counter::increaseAndGet, actor, 1).get()); } + public void testCallActorWithLargeObject() { + RayActor actor = Ray.createActor(Counter::new, 1); + LargeObject largeObject = new LargeObject(); + Assert.assertEquals(Integer.valueOf(largeObject.data.length + 1), + Ray.call(Counter::accessLargeObject, actor, largeObject).get()); + } + @RayRemote - public static Counter factory(int initValue) { + static Counter factory(int initValue) { return new Counter(initValue); } - @Test public void testCreateActorFromFactory() { // Test creating an actor from a factory method RayActor actor = Ray.createActor(ActorTest::factory, 1); @@ -67,24 +79,23 @@ public class ActorTest extends BaseTest { } @RayRemote - public static int testActorAsFirstParameter(RayActor actor, int delta) { + static int testActorAsFirstParameter(RayActor actor, int delta) { RayObject res = Ray.call(Counter::increaseAndGet, actor, delta); return res.get(); } @RayRemote - public static int testActorAsSecondParameter(int delta, RayActor actor) { + static int testActorAsSecondParameter(int delta, RayActor actor) { RayObject res = Ray.call(Counter::increaseAndGet, actor, delta); return res.get(); } @RayRemote - public static int testActorAsFieldOfParameter(List> actor, int delta) { + static int testActorAsFieldOfParameter(List> actor, int delta) { RayObject res = Ray.call(Counter::increaseAndGet, actor.get(0), delta); return res.get(); } - @Test public void testPassActorAsParameter() { RayActor actor = Ray.createActor(Counter::new, 0); Assert.assertEquals(Integer.valueOf(1), @@ -96,7 +107,6 @@ public class ActorTest extends BaseTest { .get()); } - @Test public void testForkingActorHandle() { TestUtils.skipTestUnderSingleProcess(); RayActor counter = Ray.createActor(Counter::new, 100); @@ -105,9 +115,11 @@ public class ActorTest extends BaseTest { Assert.assertEquals(Integer.valueOf(103), Ray.call(Counter::increaseAndGet, counter2, 2).get()); } - @Test public void testUnreconstructableActorObject() throws InterruptedException { TestUtils.skipTestUnderSingleProcess(); + // The UnreconstructableException is created by raylet. + // TODO (kfstorm): This should be supported by direct actor call. + TestUtils.skipTestIfDirectActorCallEnabled(); RayActor counter = Ray.createActor(Counter::new, 100); // Call an actor method. RayObject value = Ray.call(Counter::getValue, counter); diff --git a/java/test/src/main/java/org/ray/api/test/BaseMultiLanguageTest.java b/java/test/src/main/java/org/ray/api/test/BaseMultiLanguageTest.java index 0603a917e..7095810d6 100644 --- a/java/test/src/main/java/org/ray/api/test/BaseMultiLanguageTest.java +++ b/java/test/src/main/java/org/ray/api/test/BaseMultiLanguageTest.java @@ -45,7 +45,7 @@ public abstract class BaseMultiLanguageTest { } } - @BeforeClass + @BeforeClass(alwaysRun = true) public void setUp() { if (!"1".equals(System.getenv("ENABLE_MULTI_LANGUAGE_TESTS"))) { LOGGER.info("Skip Multi-language tests because environment variable " @@ -100,7 +100,7 @@ public abstract class BaseMultiLanguageTest { return ImmutableMap.of(); } - @AfterClass + @AfterClass(alwaysRun = true) public void tearDown() { // Disconnect to the cluster. Ray.shutdown(); diff --git a/java/test/src/main/java/org/ray/api/test/BaseTest.java b/java/test/src/main/java/org/ray/api/test/BaseTest.java index 4c3973064..fa1d078de 100644 --- a/java/test/src/main/java/org/ray/api/test/BaseTest.java +++ b/java/test/src/main/java/org/ray/api/test/BaseTest.java @@ -16,7 +16,7 @@ public class BaseTest { private List filesToDelete; - @BeforeMethod + @BeforeMethod(alwaysRun = true) public void setUpBase(Method method) { LOGGER.info("===== Running test: " + method.getDeclaringClass().getName() + "." + method.getName()); @@ -34,7 +34,7 @@ public class BaseTest { filesToDelete.forEach(File::deleteOnExit); } - @AfterMethod + @AfterMethod(alwaysRun = true) public void tearDownBase() { Ray.shutdown(); diff --git a/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java b/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java index 2f75c7b54..0cde562b0 100644 --- a/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java +++ b/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java @@ -9,6 +9,7 @@ import org.apache.commons.io.FileUtils; import org.ray.api.Ray; import org.ray.api.RayObject; import org.ray.api.RayPyActor; +import org.ray.api.TestUtils; import org.testng.Assert; import org.testng.annotations.Test; @@ -45,8 +46,10 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest { Assert.assertEquals(res.get(), "Response from Python: hello".getBytes()); } - @Test + @Test(groups = {"directCall"}) public void testCallingPythonActor() { + // Python worker doesn't support direct call yet. + TestUtils.skipTestIfDirectActorCallEnabled(); RayPyActor actor = Ray.createPyActor(PYTHON_MODULE, "Counter", "1".getBytes()); RayObject res = Ray.callPy(actor, "increase", "1".getBytes()); Assert.assertEquals(res.get(), "2".getBytes()); diff --git a/java/test/src/main/java/org/ray/api/test/FailureTest.java b/java/test/src/main/java/org/ray/api/test/FailureTest.java index 4dbcbcf4b..83548d90f 100644 --- a/java/test/src/main/java/org/ray/api/test/FailureTest.java +++ b/java/test/src/main/java/org/ray/api/test/FailureTest.java @@ -76,14 +76,14 @@ public class FailureTest extends BaseTest { assertTaskFailedWithRayTaskException(Ray.call(FailureTest::badFunc)); } - @Test + @Test(groups = {"directCall"}) public void testActorCreationFailure() { TestUtils.skipTestUnderSingleProcess(); RayActor actor = Ray.createActor(BadActor::new, true); assertTaskFailedWithRayTaskException(Ray.call(BadActor::badMethod, actor)); } - @Test + @Test(groups = {"directCall"}) public void testActorTaskFailure() { TestUtils.skipTestUnderSingleProcess(); RayActor actor = Ray.createActor(BadActor::new, false); @@ -102,9 +102,12 @@ public class FailureTest extends BaseTest { } } - @Test + @Test(groups = {"directCall"}) public void testActorProcessDying() { TestUtils.skipTestUnderSingleProcess(); + // This test case hangs if the worker to worker connection is implemented with grpc. + // TODO (kfstorm): Should be fixed. + TestUtils.skipTestIfDirectActorCallEnabled(); RayActor actor = Ray.createActor(BadActor::new, false); try { Ray.call(BadActor::badMethod2, actor).get(); diff --git a/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java b/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java index ce2fb2452..bc5a5f3e3 100644 --- a/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java +++ b/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java @@ -21,7 +21,7 @@ import org.slf4j.LoggerFactory; import org.testng.Assert; import org.testng.annotations.Test; - +@Test(groups = {"directCall"}) public class MultiThreadingTest extends BaseTest { private static final Logger LOGGER = LoggerFactory.getLogger(MultiThreadingTest.class); @@ -30,7 +30,7 @@ public class MultiThreadingTest extends BaseTest { private static final int NUM_THREADS = 20; @RayRemote - public static Integer echo(int num) { + static Integer echo(int num) { return num; } @@ -73,7 +73,7 @@ public class MultiThreadingTest extends BaseTest { } } - public static String testMultiThreading() { + static String testMultiThreading() { Random random = new Random(); // Test calling normal functions. runTestCaseInMultipleThreads(() -> { @@ -123,12 +123,10 @@ public class MultiThreadingTest extends BaseTest { return "ok"; } - @Test public void testInDriver() { testMultiThreading(); } - @Test public void testInWorker() { // Single-process mode doesn't have real workers. TestUtils.skipTestUnderSingleProcess(); @@ -136,7 +134,6 @@ public class MultiThreadingTest extends BaseTest { Assert.assertEquals("ok", obj.get()); } - @Test public void testGetCurrentActorId() { TestUtils.skipTestUnderSingleProcess(); RayActor actorIdTester = Ray.createActor(ActorIdTester::new); diff --git a/java/test/src/main/java/org/ray/api/test/RayCallTest.java b/java/test/src/main/java/org/ray/api/test/RayCallTest.java index db496e2b8..65bc877e8 100644 --- a/java/test/src/main/java/org/ray/api/test/RayCallTest.java +++ b/java/test/src/main/java/org/ray/api/test/RayCallTest.java @@ -2,11 +2,11 @@ package org.ray.api.test; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import java.io.Serializable; import java.util.List; import java.util.Map; import org.ray.api.Ray; import org.ray.api.TestUtils; +import org.ray.api.TestUtils.LargeObject; import org.ray.api.annotation.RayRemote; import org.ray.api.id.ObjectId; import org.testng.Assert; @@ -67,11 +67,6 @@ public class RayCallTest extends BaseTest { return val; } - public static class LargeObject implements Serializable { - - private byte[] data = new byte[1024 * 1024]; - } - @RayRemote private static LargeObject testLargeObject(LargeObject largeObject) { return largeObject; diff --git a/java/test/src/main/java/org/ray/api/test/StressTest.java b/java/test/src/main/java/org/ray/api/test/StressTest.java index e2efecbf2..e9e261463 100644 --- a/java/test/src/main/java/org/ray/api/test/StressTest.java +++ b/java/test/src/main/java/org/ray/api/test/StressTest.java @@ -72,7 +72,7 @@ public class StressTest extends BaseTest { } } - @Test + @Test(groups = {"directCall"}) public void testSubmittingManyTasksToOneActor() { TestUtils.skipTestUnderSingleProcess(); RayActor actor = Ray.createActor(Actor::new); diff --git a/java/testng.xml b/java/testng.xml index c24545572..9448cb30f 100644 --- a/java/testng.xml +++ b/java/testng.xml @@ -1,10 +1,13 @@ - - - - - - + + + + + + + + + diff --git a/python/ray/includes/task.pxd b/python/ray/includes/task.pxd index 1645ebf85..ba306e09f 100644 --- a/python/ray/includes/task.pxd +++ b/python/ray/includes/task.pxd @@ -92,7 +92,8 @@ cdef extern from "ray/common/task/task_util.h" namespace "ray" nogil: TaskSpecBuilder &SetActorCreationTaskSpec( const CActorID &actor_id, uint64_t max_reconstructions, - const c_vector[c_string] &dynamic_worker_options) + const c_vector[c_string] &dynamic_worker_options, + c_bool is_direct_call) TaskSpecBuilder &SetActorTaskSpec( const CActorID &actor_id, const CActorHandleID &actor_handle_id, diff --git a/python/ray/includes/task.pxi b/python/ray/includes/task.pxi index 943a123ef..5d0d1251c 100644 --- a/python/ray/includes/task.pxi +++ b/python/ray/includes/task.pxi @@ -82,6 +82,7 @@ cdef class TaskSpec: actor_creation_id.native(), max_actor_reconstructions, [], + False, ) elif not actor_id.is_nil(): # Actor task. diff --git a/src/ray/common/task/task_spec.cc b/src/ray/common/task/task_spec.cc index cca5bf048..0fdfb5e4e 100644 --- a/src/ray/common/task/task_spec.cc +++ b/src/ray/common/task/task_spec.cc @@ -154,6 +154,11 @@ std::vector TaskSpecification::NewActorHandles() const { message_->actor_task_spec().new_actor_handles()); } +bool TaskSpecification::IsDirectCall() const { + RAY_CHECK(IsActorCreationTask()); + return message_->actor_creation_task_spec().is_direct_call(); +} + std::string TaskSpecification::DebugString() const { std::ostringstream stream; stream << "Type=" << TaskType_Name(message_->type()) @@ -177,7 +182,8 @@ std::string TaskSpecification::DebugString() const { if (IsActorCreationTask()) { // Print actor creation task spec. stream << ", actor_creation_task_spec={actor_id=" << ActorCreationId() - << ", max_reconstructions=" << MaxActorReconstructions() << "}"; + << ", max_reconstructions=" << MaxActorReconstructions() + << ", is_direct_call=" << IsDirectCall() << "}"; } else if (IsActorTask()) { // Print actor task spec. stream << ", actor_task_spec={actor_id=" << ActorId() diff --git a/src/ray/common/task/task_spec.h b/src/ray/common/task/task_spec.h index d481a4cbb..100f28c41 100644 --- a/src/ray/common/task/task_spec.h +++ b/src/ray/common/task/task_spec.h @@ -131,6 +131,8 @@ class TaskSpecification : public MessageWrapper { std::vector NewActorHandles() const; + bool IsDirectCall() const; + ObjectID ActorDummyObject() const; std::string DebugString() const; diff --git a/src/ray/common/task/task_util.h b/src/ray/common/task/task_util.h index a5967d5c4..c51c38de2 100644 --- a/src/ray/common/task/task_util.h +++ b/src/ray/common/task/task_util.h @@ -90,7 +90,8 @@ class TaskSpecBuilder { /// \return Reference to the builder object itself. TaskSpecBuilder &SetActorCreationTaskSpec( const ActorID &actor_id, uint64_t max_reconstructions = 0, - const std::vector &dynamic_worker_options = {}) { + const std::vector &dynamic_worker_options = {}, + bool is_direct_call = false) { message_->set_type(TaskType::ACTOR_CREATION_TASK); auto actor_creation_spec = message_->mutable_actor_creation_task_spec(); actor_creation_spec->set_actor_id(actor_id.Binary()); @@ -98,6 +99,7 @@ class TaskSpecBuilder { for (const auto &option : dynamic_worker_options) { actor_creation_spec->add_dynamic_worker_options(option); } + actor_creation_spec->set_is_direct_call(is_direct_call); return *this; } diff --git a/src/ray/core_worker/context.cc b/src/ray/core_worker/context.cc index 663305e4c..39d4ec3a5 100644 --- a/src/ray/core_worker/context.cc +++ b/src/ray/core_worker/context.cc @@ -80,6 +80,7 @@ void WorkerContext::SetCurrentTask(const TaskSpecification &task_spec) { if (task_spec.IsActorCreationTask()) { RAY_CHECK(current_actor_id_.IsNil()); current_actor_id_ = task_spec.ActorCreationId(); + current_actor_use_direct_call_ = task_spec.IsDirectCall(); } if (task_spec.IsActorTask()) { RAY_CHECK(current_actor_id_ == task_spec.ActorId()); @@ -91,6 +92,10 @@ std::shared_ptr WorkerContext::GetCurrentTask() const { const ActorID &WorkerContext::GetCurrentActorID() const { return current_actor_id_; } +bool WorkerContext::CurrentActorUseDirectCall() const { + return current_actor_use_direct_call_; +} + WorkerThreadContext &WorkerContext::GetThreadContext() { if (thread_context_ == nullptr) { thread_context_ = std::unique_ptr(new WorkerThreadContext()); diff --git a/src/ray/core_worker/context.h b/src/ray/core_worker/context.h index 3a02c6915..3c53e415e 100644 --- a/src/ray/core_worker/context.h +++ b/src/ray/core_worker/context.h @@ -26,6 +26,8 @@ class WorkerContext { const ActorID &GetCurrentActorID() const; + bool CurrentActorUseDirectCall() const; + int GetNextTaskIndex(); int GetNextPutIndex(); @@ -43,6 +45,9 @@ class WorkerContext { /// ID of current actor. ActorID current_actor_id_; + /// Whether current actor accepts direct calls. + bool current_actor_use_direct_call_; + private: static WorkerThreadContext &GetThreadContext(); diff --git a/src/ray/core_worker/lib/java/jni_init.cc b/src/ray/core_worker/lib/java/jni_init.cc index 7339c54c9..ec9511239 100644 --- a/src/ray/core_worker/lib/java/jni_init.cc +++ b/src/ray/core_worker/lib/java/jni_init.cc @@ -49,7 +49,9 @@ jclass java_base_task_options_class; jfieldID java_base_task_options_resources; jclass java_actor_creation_options_class; +jfieldID java_actor_creation_options_default_use_direct_call; jfieldID java_actor_creation_options_max_reconstructions; +jfieldID java_actor_creation_options_use_direct_call; jfieldID java_actor_creation_options_jvm_options; jclass java_gcs_client_options_class; @@ -146,8 +148,12 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { java_actor_creation_options_class = LoadClass(env, "org/ray/api/options/ActorCreationOptions"); + java_actor_creation_options_default_use_direct_call = env->GetStaticFieldID( + java_actor_creation_options_class, "DEFAULT_USE_DIRECT_CALL", "Z"); java_actor_creation_options_max_reconstructions = env->GetFieldID(java_actor_creation_options_class, "maxReconstructions", "I"); + java_actor_creation_options_use_direct_call = + env->GetFieldID(java_actor_creation_options_class, "useDirectCall", "Z"); java_actor_creation_options_jvm_options = env->GetFieldID( java_actor_creation_options_class, "jvmOptions", "Ljava/lang/String;"); diff --git a/src/ray/core_worker/lib/java/jni_utils.h b/src/ray/core_worker/lib/java/jni_utils.h index d2f5c71df..5edf9160d 100644 --- a/src/ray/core_worker/lib/java/jni_utils.h +++ b/src/ray/core_worker/lib/java/jni_utils.h @@ -92,8 +92,12 @@ extern jfieldID java_base_task_options_resources; /// ActorCreationOptions class extern jclass java_actor_creation_options_class; +/// DEFAULT_USE_DIRECT_CALL field of ActorCreationOptions class +extern jfieldID java_actor_creation_options_default_use_direct_call; /// maxReconstructions field of ActorCreationOptions class extern jfieldID java_actor_creation_options_max_reconstructions; +/// useDirectCall field of ActorCreationOptions class +extern jfieldID java_actor_creation_options_use_direct_call; /// jvmOptions field of ActorCreationOptions class extern jfieldID java_actor_creation_options_jvm_options; diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc b/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc index c1c545e8b..abd898f55 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc +++ b/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc @@ -129,6 +129,25 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeShutdownHook( ray::RayLog::ShutDownRayLog(); } +/* + * Class: org_ray_runtime_RayNativeRuntime + * Method: nativeSetResource + * Signature: (JLjava/lang/String;D[B)V + */ +JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeSetResource( + JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jstring resourceName, + jdouble capacity, jbyteArray nodeId) { + const auto node_id = JavaByteArrayToId(env, nodeId); + const char *native_resource_name = env->GetStringUTFChars(resourceName, JNI_FALSE); + + auto &raylet_client = + reinterpret_cast(nativeCoreWorkerPointer)->GetRayletClient(); + auto status = raylet_client.SetResource(native_resource_name, + static_cast(capacity), node_id); + env->ReleaseStringUTFChars(resourceName, native_resource_name); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); +} + #ifdef __cplusplus } #endif diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.h b/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.h index c71fec982..480564640 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.h +++ b/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.h @@ -48,6 +48,14 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeSetup(JNIEnv JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeShutdownHook(JNIEnv *, jclass); +/* + * Class: org_ray_runtime_RayNativeRuntime + * Method: nativeSetResource + * Signature: (JLjava/lang/String;D[B)V + */ +JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeSetResource( + JNIEnv *, jclass, jlong, jstring, jdouble, jbyteArray); + #ifdef __cplusplus } #endif diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.cc b/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.cc index a63e7efa0..1a91e0fb0 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.cc +++ b/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.cc @@ -57,6 +57,16 @@ JNIEXPORT jint JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeGetLangua return (jint)GetActorHandle(nativeActorHandle).ActorLanguage(); } +/* + * Class: org_ray_runtime_actor_NativeRayActor + * Method: nativeIsDirectCallActor + * Signature: (J)Z + */ +JNIEXPORT jboolean JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeIsDirectCallActor( + JNIEnv *env, jclass o, jlong nativeActorHandle) { + return GetActorHandle(nativeActorHandle).IsDirectCallActor(); +} + /* * Class: org_ray_runtime_actor_NativeRayActor * Method: nativeGetActorCreationTaskFunctionDescriptor diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.h b/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.h index 4de114c7a..245064fcf 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.h +++ b/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.h @@ -40,6 +40,14 @@ Java_org_ray_runtime_actor_NativeRayActor_nativeGetActorHandleId(JNIEnv *, jclas JNIEXPORT jint JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeGetLanguage(JNIEnv *, jclass, jlong); +/* + * Class: org_ray_runtime_actor_NativeRayActor + * Method: nativeIsDirectCallActor + * Signature: (J)Z + */ +JNIEXPORT jboolean JNICALL +Java_org_ray_runtime_actor_NativeRayActor_nativeIsDirectCallActor(JNIEnv *, jclass, jlong); + /* * Class: org_ray_runtime_actor_NativeRayActor * Method: nativeGetActorCreationTaskFunctionDescriptor diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.cc b/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.cc deleted file mode 100644 index e84e4c51e..000000000 --- a/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.cc +++ /dev/null @@ -1,74 +0,0 @@ -#include "ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.h" -#include -#include "ray/common/id.h" -#include "ray/core_worker/common.h" -#include "ray/core_worker/core_worker.h" -#include "ray/core_worker/lib/java/jni_utils.h" -#include "ray/raylet/raylet_client.h" - -inline RayletClient &GetRayletClientFromPointer(jlong nativeCoreWorkerPointer) { - return reinterpret_cast(nativeCoreWorkerPointer)->GetRayletClient(); -} - -#ifdef __cplusplus -extern "C" { -#endif - -using ray::ClientID; - -/* - * Class: org_ray_runtime_raylet_NativeRayletClient - * Method: nativePrepareCheckpoint - * Signature: (J[B)[B - */ -JNIEXPORT jbyteArray JNICALL -Java_org_ray_runtime_raylet_NativeRayletClient_nativePrepareCheckpoint( - JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jbyteArray actorId) { - const auto actor_id = JavaByteArrayToId(env, actorId); - ActorCheckpointID checkpoint_id; - auto status = GetRayletClientFromPointer(nativeCoreWorkerPointer) - .PrepareActorCheckpoint(actor_id, checkpoint_id); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); - jbyteArray result = env->NewByteArray(checkpoint_id.Size()); - env->SetByteArrayRegion(result, 0, checkpoint_id.Size(), - reinterpret_cast(checkpoint_id.Data())); - return result; -} - -/* - * Class: org_ray_runtime_raylet_NativeRayletClient - * Method: nativeNotifyActorResumedFromCheckpoint - * Signature: (J[B[B)V - */ -JNIEXPORT void JNICALL -Java_org_ray_runtime_raylet_NativeRayletClient_nativeNotifyActorResumedFromCheckpoint( - JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jbyteArray actorId, - jbyteArray checkpointId) { - const auto actor_id = JavaByteArrayToId(env, actorId); - const auto checkpoint_id = JavaByteArrayToId(env, checkpointId); - auto status = GetRayletClientFromPointer(nativeCoreWorkerPointer) - .NotifyActorResumedFromCheckpoint(actor_id, checkpoint_id); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); -} - -/* - * Class: org_ray_runtime_raylet_NativeRayletClient - * Method: nativeSetResource - * Signature: (JLjava/lang/String;D[B)V - */ -JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_NativeRayletClient_nativeSetResource( - JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jstring resourceName, - jdouble capacity, jbyteArray nodeId) { - const auto node_id = JavaByteArrayToId(env, nodeId); - const char *native_resource_name = env->GetStringUTFChars(resourceName, JNI_FALSE); - - auto status = - GetRayletClientFromPointer(nativeCoreWorkerPointer) - .SetResource(native_resource_name, static_cast(capacity), node_id); - env->ReleaseStringUTFChars(resourceName, native_resource_name); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); -} - -#ifdef __cplusplus -} -#endif diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.h b/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.h deleted file mode 100644 index 0b54300de..000000000 --- a/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.h +++ /dev/null @@ -1,39 +0,0 @@ -/* DO NOT EDIT THIS FILE - it is machine generated */ -#include -/* Header for class org_ray_runtime_raylet_NativeRayletClient */ - -#ifndef _Included_org_ray_runtime_raylet_NativeRayletClient -#define _Included_org_ray_runtime_raylet_NativeRayletClient -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: org_ray_runtime_raylet_NativeRayletClient - * Method: nativePrepareCheckpoint - * Signature: (J[B)[B - */ -JNIEXPORT jbyteArray JNICALL -Java_org_ray_runtime_raylet_NativeRayletClient_nativePrepareCheckpoint(JNIEnv *, jclass, - jlong, jbyteArray); - -/* - * Class: org_ray_runtime_raylet_NativeRayletClient - * Method: nativeNotifyActorResumedFromCheckpoint - * Signature: (J[B[B)V - */ -JNIEXPORT void JNICALL -Java_org_ray_runtime_raylet_NativeRayletClient_nativeNotifyActorResumedFromCheckpoint( - JNIEnv *, jclass, jlong, jbyteArray, jbyteArray); - -/* - * Class: org_ray_runtime_raylet_NativeRayletClient - * Method: nativeSetResource - * Signature: (JLjava/lang/String;D[B)V - */ -JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_NativeRayletClient_nativeSetResource( - JNIEnv *, jclass, jlong, jstring, jdouble, jbyteArray); - -#ifdef __cplusplus -} -#endif -#endif diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskExecutor.cc b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskExecutor.cc new file mode 100644 index 000000000..8658c2f8a --- /dev/null +++ b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskExecutor.cc @@ -0,0 +1,55 @@ +#include "ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskExecutor.h" +#include +#include "ray/common/id.h" +#include "ray/core_worker/common.h" +#include "ray/core_worker/core_worker.h" +#include "ray/core_worker/lib/java/jni_utils.h" +#include "ray/raylet/raylet_client.h" + +#ifdef __cplusplus +extern "C" { +#endif + +using ray::ClientID; + +/* + * Class: org_ray_runtime_task_NativeTaskExecutor + * Method: nativePrepareCheckpoint + * Signature: (J)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_org_ray_runtime_task_NativeTaskExecutor_nativePrepareCheckpoint( + JNIEnv *env, jclass, jlong nativeCoreWorkerPointer) { + auto &core_worker = *reinterpret_cast(nativeCoreWorkerPointer); + const auto &actor_id = core_worker.GetWorkerContext().GetCurrentActorID(); + const auto &task_spec = core_worker.GetWorkerContext().GetCurrentTask(); + RAY_CHECK(task_spec->IsActorTask()); + ActorCheckpointID checkpoint_id; + auto status = core_worker.GetRayletClient().PrepareActorCheckpoint( + actor_id, checkpoint_id); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); + jbyteArray result = env->NewByteArray(checkpoint_id.Size()); + env->SetByteArrayRegion(result, 0, checkpoint_id.Size(), + reinterpret_cast(checkpoint_id.Data())); + return result; +} + +/* + * Class: org_ray_runtime_task_NativeTaskExecutor + * Method: nativeNotifyActorResumedFromCheckpoint + * Signature: (J[B)V + */ +JNIEXPORT void JNICALL +Java_org_ray_runtime_task_NativeTaskExecutor_nativeNotifyActorResumedFromCheckpoint( + JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jbyteArray checkpointId) { + auto &core_worker = *reinterpret_cast(nativeCoreWorkerPointer); + const auto &actor_id = core_worker.GetWorkerContext().GetCurrentActorID(); + const auto checkpoint_id = JavaByteArrayToId(env, checkpointId); + auto status = core_worker.GetRayletClient().NotifyActorResumedFromCheckpoint( + actor_id, checkpoint_id); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); +} + +#ifdef __cplusplus +} +#endif diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskExecutor.h b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskExecutor.h new file mode 100644 index 000000000..c51bd22e1 --- /dev/null +++ b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskExecutor.h @@ -0,0 +1,33 @@ +/* DO NOT EDIT THIS FILE - it is machine generated */ +#include +/* Header for class org_ray_runtime_task_NativeTaskExecutor */ + +#ifndef _Included_org_ray_runtime_task_NativeTaskExecutor +#define _Included_org_ray_runtime_task_NativeTaskExecutor +#ifdef __cplusplus +extern "C" { +#endif +#undef org_ray_runtime_task_NativeTaskExecutor_NUM_ACTOR_CHECKPOINTS_TO_KEEP +#define org_ray_runtime_task_NativeTaskExecutor_NUM_ACTOR_CHECKPOINTS_TO_KEEP 20L +/* + * Class: org_ray_runtime_task_NativeTaskExecutor + * Method: nativePrepareCheckpoint + * Signature: (J)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_org_ray_runtime_task_NativeTaskExecutor_nativePrepareCheckpoint(JNIEnv *, jclass, + jlong); + +/* + * Class: org_ray_runtime_task_NativeTaskExecutor + * Method: nativeNotifyActorResumedFromCheckpoint + * Signature: (J[B)V + */ +JNIEXPORT void JNICALL +Java_org_ray_runtime_task_NativeTaskExecutor_nativeNotifyActorResumedFromCheckpoint( + JNIEnv *, jclass, jlong, jbyteArray); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc index b626cf122..a033bdfbd 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc +++ b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc @@ -77,11 +77,14 @@ inline ray::TaskOptions ToTaskOptions(JNIEnv *env, jint numReturns, jobject call inline ray::ActorCreationOptions ToActorCreationOptions(JNIEnv *env, jobject actorCreationOptions) { uint64_t max_reconstructions = 0; + bool use_direct_call; std::unordered_map resources; std::vector dynamic_worker_options; if (actorCreationOptions) { max_reconstructions = static_cast(env->GetIntField( actorCreationOptions, java_actor_creation_options_max_reconstructions)); + use_direct_call = env->GetBooleanField(actorCreationOptions, + java_actor_creation_options_use_direct_call); jobject java_resources = env->GetObjectField(actorCreationOptions, java_base_task_options_resources); resources = ToResources(env, java_resources); @@ -91,10 +94,14 @@ inline ray::ActorCreationOptions ToActorCreationOptions(JNIEnv *env, std::string jvm_options = JavaStringToNativeString(env, java_jvm_options); dynamic_worker_options.emplace_back(jvm_options); } + } else { + use_direct_call = + env->GetStaticBooleanField(java_actor_creation_options_class, + java_actor_creation_options_default_use_direct_call); } ray::ActorCreationOptions action_creation_options{ - static_cast(max_reconstructions), false, resources, + static_cast(max_reconstructions), use_direct_call, resources, dynamic_worker_options}; return action_creation_options; } diff --git a/src/ray/core_worker/task_execution.cc b/src/ray/core_worker/task_execution.cc index 5eea02638..3d1688d6b 100644 --- a/src/ray/core_worker/task_execution.cc +++ b/src/ray/core_worker/task_execution.cc @@ -22,12 +22,13 @@ CoreWorkerTaskExecutionInterface::CoreWorkerTaskExecutionInterface( task_receivers_.emplace( TaskTransportType::RAYLET, std::unique_ptr(new CoreWorkerRayletTaskReceiver( - raylet_client, object_interface_, *main_service_, worker_server_, func))); + worker_context_, raylet_client, object_interface_, *main_service_, + worker_server_, func))); task_receivers_.emplace( TaskTransportType::DIRECT_ACTOR, std::unique_ptr( - new CoreWorkerDirectActorTaskReceiver(object_interface_, *main_service_, - worker_server_, func))); + new CoreWorkerDirectActorTaskReceiver(worker_context_, object_interface_, + *main_service_, worker_server_, func))); // Start RPC server after all the task receivers are properly initialized. worker_server_.Run(); diff --git a/src/ray/core_worker/task_interface.cc b/src/ray/core_worker/task_interface.cc index 10a792069..09774cd0e 100644 --- a/src/ray/core_worker/task_interface.cc +++ b/src/ray/core_worker/task_interface.cc @@ -173,7 +173,8 @@ Status CoreWorkerTaskInterface::CreateActor( actor_creation_options.resources, actor_creation_options.resources, TaskTransportType::RAYLET, &return_ids); builder.SetActorCreationTaskSpec(actor_id, actor_creation_options.max_reconstructions, - actor_creation_options.dynamic_worker_options); + actor_creation_options.dynamic_worker_options, + actor_creation_options.is_direct_call); *actor_handle = std::unique_ptr(new ActorHandle( actor_id, ActorHandleID::Nil(), function.language, diff --git a/src/ray/core_worker/transport/direct_actor_transport.cc b/src/ray/core_worker/transport/direct_actor_transport.cc index 513df2b3e..617f6ff0b 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.cc +++ b/src/ray/core_worker/transport/direct_actor_transport.cc @@ -21,12 +21,11 @@ CoreWorkerDirectActorTaskSubmitter::CoreWorkerDirectActorTaskSubmitter( : io_service_(io_service), gcs_client_(gcs_client), client_call_manager_(io_service), - store_provider_(std::move(store_provider)) { - RAY_CHECK_OK(SubscribeActorUpdates()); -} + store_provider_(std::move(store_provider)) {} Status CoreWorkerDirectActorTaskSubmitter::SubmitTask( const TaskSpecification &task_spec) { + RAY_LOG(DEBUG) << "Submitting task " << task_spec.TaskId(); if (HasByReferenceArgs(task_spec)) { return Status::Invalid("direct actor call only supports by-value arguments"); } @@ -41,6 +40,12 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask( request->mutable_task_spec()->Swap(&task_spec.GetMutableMessage()); std::unique_lock guard(mutex_); + + if (subscribed_actors_.find(actor_id) == subscribed_actors_.end()) { + RAY_CHECK_OK(SubscribeActorUpdates(actor_id)); + subscribed_actors_.insert(actor_id); + } + auto iter = actor_states_.find(actor_id); if (iter == actor_states_.end() || iter->second.state_ == ActorTableData::RECONSTRUCTING) { @@ -51,6 +56,7 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask( // to have a timeout to mark it as invalid if it doesn't show up in the // specified time. pending_requests_[actor_id].emplace_back(std::move(request)); + RAY_LOG(DEBUG) << "Actor " << actor_id << " is not yet created."; return Status::OK(); } else if (iter->second.state_ == ActorTableData::ALIVE) { // Actor is alive, submit the request. @@ -62,17 +68,19 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask( // Submit request. auto &client = rpc_clients_[actor_id]; - PushTask(*client, *request, task_id, num_returns); + PushTask(*client, *request, actor_id, task_id, num_returns); return Status::OK(); } else { // Actor is dead, treat the task as failure. RAY_CHECK(iter->second.state_ == ActorTableData::DEAD); TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED); - return Status::IOError("Actor is dead."); + // Return OK here so that we can get the error from store with get operation. + return Status::OK(); } } -Status CoreWorkerDirectActorTaskSubmitter::SubscribeActorUpdates() { +Status CoreWorkerDirectActorTaskSubmitter::SubscribeActorUpdates( + const ActorID &actor_id) { // Register a callback to handle actor notifications. auto actor_notification_callback = [this](const ActorID &actor_id, const ActorTableData &actor_data) { @@ -92,6 +100,19 @@ Status CoreWorkerDirectActorTaskSubmitter::SubscribeActorUpdates() { } else { // Remove rpc client if it's dead or being reconstructed. rpc_clients_.erase(actor_id); + + // For tasks that have been sent and are waiting for replies, treat them + // as failed when the destination actor is dead or reconstructing. + auto iter = waiting_reply_tasks_.find(actor_id); + if (iter != waiting_reply_tasks_.end()) { + for (const auto &entry : iter->second) { + const auto &task_id = entry.first; + const auto num_returns = entry.second; + TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED); + } + waiting_reply_tasks_.erase(actor_id); + } + // If this actor is permanently dead and there are pending requests, treat // the pending tasks as failed. if (actor_data.state() == ActorTableData::DEAD && @@ -111,7 +132,8 @@ Status CoreWorkerDirectActorTaskSubmitter::SubscribeActorUpdates() { << ", port: " << actor_data.port(); }; - return gcs_client_.Actors().AsyncSubscribe(actor_notification_callback, nullptr); + return gcs_client_.Actors().AsyncSubscribe(actor_id, actor_notification_callback, + nullptr); } void CoreWorkerDirectActorTaskSubmitter::ConnectAndSendPendingTasks( @@ -125,7 +147,8 @@ void CoreWorkerDirectActorTaskSubmitter::ConnectAndSendPendingTasks( auto &requests = pending_requests_[actor_id]; while (!requests.empty()) { const auto &request = *requests.front(); - PushTask(*client, request, TaskID::FromBinary(request.task_spec().task_id()), + PushTask(*client, request, actor_id, + TaskID::FromBinary(request.task_spec().task_id()), request.task_spec().num_returns()); requests.pop_front(); } @@ -133,11 +156,18 @@ void CoreWorkerDirectActorTaskSubmitter::ConnectAndSendPendingTasks( void CoreWorkerDirectActorTaskSubmitter::PushTask(rpc::DirectActorClient &client, const rpc::PushTaskRequest &request, + const ActorID &actor_id, const TaskID &task_id, int num_returns) { - auto status = client.PushTask( - request, - [this, task_id, num_returns](Status status, const rpc::PushTaskReply &reply) { + RAY_LOG(DEBUG) << "Pushing task " << task_id << " to actor " << actor_id; + waiting_reply_tasks_[actor_id].insert(std::make_pair(task_id, num_returns)); + auto status = + client.PushTask(request, [this, actor_id, task_id, num_returns]( + Status status, const rpc::PushTaskReply &reply) { + { + std::unique_lock guard(mutex_); + waiting_reply_tasks_[actor_id].erase(task_id); + } if (!status.ok()) { TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED); return; @@ -170,6 +200,8 @@ void CoreWorkerDirectActorTaskSubmitter::PushTask(rpc::DirectActorClient &client void CoreWorkerDirectActorTaskSubmitter::TreatTaskAsFailed( const TaskID &task_id, int num_returns, const rpc::ErrorType &error_type) { + RAY_LOG(DEBUG) << "Treat task as failed. task_id: " << task_id + << ", error_type: " << ErrorType_Name(error_type); for (int i = 0; i < num_returns; i++) { const auto object_id = ObjectID::ForTaskReturn( task_id, /*index=*/i + 1, @@ -181,16 +213,24 @@ void CoreWorkerDirectActorTaskSubmitter::TreatTaskAsFailed( } } -bool CoreWorkerDirectActorTaskSubmitter::IsActorAlive(const ActorID &actor_id) const { +bool CoreWorkerDirectActorTaskSubmitter::IsActorAlive(const ActorID &actor_id) { std::unique_lock guard(mutex_); + + if (subscribed_actors_.find(actor_id) == subscribed_actors_.end()) { + RAY_CHECK_OK(SubscribeActorUpdates(actor_id)); + subscribed_actors_.insert(actor_id); + } + auto iter = actor_states_.find(actor_id); return (iter != actor_states_.end() && iter->second.state_ == ActorTableData::ALIVE); } CoreWorkerDirectActorTaskReceiver::CoreWorkerDirectActorTaskReceiver( - CoreWorkerObjectInterface &object_interface, boost::asio::io_service &io_service, - rpc::GrpcServer &server, const TaskHandler &task_handler) - : object_interface_(object_interface), + WorkerContext &worker_context, CoreWorkerObjectInterface &object_interface, + boost::asio::io_service &io_service, rpc::GrpcServer &server, + const TaskHandler &task_handler) + : worker_context_(worker_context), + object_interface_(object_interface), task_service_(io_service, *this), task_handler_(task_handler) { server.RegisterService(task_service_); @@ -200,12 +240,18 @@ void CoreWorkerDirectActorTaskReceiver::HandlePushTask( const rpc::PushTaskRequest &request, rpc::PushTaskReply *reply, rpc::SendReplyCallback send_reply_callback) { const TaskSpecification task_spec(request.task_spec()); + RAY_LOG(DEBUG) << "Received task " << task_spec.TaskId(); if (HasByReferenceArgs(task_spec)) { send_reply_callback( Status::Invalid("direct actor call only supports by value arguments"), nullptr, nullptr); return; } + if (task_spec.IsActorTask() && !worker_context_.CurrentActorUseDirectCall()) { + send_reply_callback(Status::Invalid("This actor doesn't accept direct calls."), + nullptr, nullptr); + return; + } auto num_returns = task_spec.NumReturns(); RAY_CHECK(task_spec.IsActorCreationTask() || task_spec.IsActorTask()); diff --git a/src/ray/core_worker/transport/direct_actor_transport.h b/src/ray/core_worker/transport/direct_actor_transport.h index 056615ff0..effa7b01d 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.h +++ b/src/ray/core_worker/transport/direct_actor_transport.h @@ -2,6 +2,7 @@ #define RAY_CORE_WORKER_DIRECT_ACTOR_TRANSPORT_H #include +#include #include "ray/core_worker/object_interface.h" #include "ray/core_worker/transport/transport.h" @@ -39,8 +40,8 @@ class CoreWorkerDirectActorTaskSubmitter : public CoreWorkerTaskSubmitter { Status SubmitTask(const TaskSpecification &task_spec) override; private: - /// Subscribe to all actor updates. - Status SubscribeActorUpdates(); + /// Subscribe to updates of an actor. + Status SubscribeActorUpdates(const ActorID &actor_id); /// Push a task to a remote actor via the given client. /// Note, this function doesn't return any error status code. If an error occurs while @@ -48,11 +49,12 @@ class CoreWorkerDirectActorTaskSubmitter : public CoreWorkerTaskSubmitter { /// /// \param[in] client The RPC client to send tasks to an actor. /// \param[in] request The request to send. + /// \param[in] actor_id Actor ID. /// \param[in] task_id The ID of a task. /// \param[in] num_returns Number of return objects. /// \return Void. void PushTask(rpc::DirectActorClient &client, const rpc::PushTaskRequest &request, - const TaskID &task_id, int num_returns); + const ActorID &actor_id, const TaskID &task_id, int num_returns); /// Treat a task as failed. /// @@ -78,7 +80,7 @@ class CoreWorkerDirectActorTaskSubmitter : public CoreWorkerTaskSubmitter { /// /// \param[in] actor_id The actor ID. /// \return Whether this actor is alive. - bool IsActorAlive(const ActorID &actor_id) const; + bool IsActorAlive(const ActorID &actor_id); /// The IO event loop. boost::asio::io_service &io_service_; @@ -92,24 +94,22 @@ class CoreWorkerDirectActorTaskSubmitter : public CoreWorkerTaskSubmitter { /// Mutex to proect the various maps below. mutable std::mutex mutex_; - /// Map from actor id to actor state. This currently includes all actors in the system. - /// - /// TODO(zhijunfu): this map currently keeps track of all the actors in the system, - /// like `actor_registry_` in raylet. Later after new GCS client interface supports - /// subscribing updates for a specific actor, this will be updated to only include - /// entries for actors that the transport submits tasks to. + /// Map from actor id to actor state. This only includes actors that we send tasks to. std::unordered_map actor_states_; /// Map from actor id to rpc client. This only includes actors that we send tasks to. - /// - /// TODO(zhijunfu): this will be moved into `actor_states_` later when we can - /// subscribe updates for a specific actor. std::unordered_map> rpc_clients_; /// Map from actor id to the actor's pending requests. std::unordered_map>> pending_requests_; + /// Map from actor id to the tasks that are waiting for reply. + std::unordered_map> waiting_reply_tasks_; + + /// The set of actors which are subscribed for further updates. + std::unordered_set subscribed_actors_; + /// The store provider. std::unique_ptr store_provider_; @@ -119,7 +119,8 @@ class CoreWorkerDirectActorTaskSubmitter : public CoreWorkerTaskSubmitter { class CoreWorkerDirectActorTaskReceiver : public CoreWorkerTaskReceiver, public rpc::DirectActorHandler { public: - CoreWorkerDirectActorTaskReceiver(CoreWorkerObjectInterface &object_interface, + CoreWorkerDirectActorTaskReceiver(WorkerContext &worker_context, + CoreWorkerObjectInterface &object_interface, boost::asio::io_service &io_service, rpc::GrpcServer &server, const TaskHandler &task_handler); @@ -135,6 +136,8 @@ class CoreWorkerDirectActorTaskReceiver : public CoreWorkerTaskReceiver, rpc::SendReplyCallback send_reply_callback) override; private: + // Worker context. + WorkerContext &worker_context_; // Object interface. CoreWorkerObjectInterface &object_interface_; /// The rpc service for `DirectActorService`. diff --git a/src/ray/core_worker/transport/raylet_transport.cc b/src/ray/core_worker/transport/raylet_transport.cc index 004434f9d..f0f7a01cf 100644 --- a/src/ray/core_worker/transport/raylet_transport.cc +++ b/src/ray/core_worker/transport/raylet_transport.cc @@ -14,10 +14,11 @@ Status CoreWorkerRayletTaskSubmitter::SubmitTask(const TaskSpecification &task) } CoreWorkerRayletTaskReceiver::CoreWorkerRayletTaskReceiver( - std::unique_ptr &raylet_client, + WorkerContext &worker_context, std::unique_ptr &raylet_client, CoreWorkerObjectInterface &object_interface, boost::asio::io_service &io_service, rpc::GrpcServer &server, const TaskHandler &task_handler) - : raylet_client_(raylet_client), + : worker_context_(worker_context), + raylet_client_(raylet_client), object_interface_(object_interface), task_service_(io_service, *this), task_handler_(task_handler) { @@ -30,6 +31,12 @@ void CoreWorkerRayletTaskReceiver::HandleAssignTask( const Task task(request.task()); const auto &task_spec = task.GetTaskSpecification(); RAY_LOG(DEBUG) << "Received task " << task_spec.TaskId(); + if (task_spec.IsActorTask() && worker_context_.CurrentActorUseDirectCall()) { + send_reply_callback(Status::Invalid("This actor only accepts direct calls."), nullptr, + nullptr); + return; + } + std::vector> results; auto status = task_handler_(task_spec, &results); diff --git a/src/ray/core_worker/transport/raylet_transport.h b/src/ray/core_worker/transport/raylet_transport.h index 0ba8feb5e..39a529cde 100644 --- a/src/ray/core_worker/transport/raylet_transport.h +++ b/src/ray/core_worker/transport/raylet_transport.h @@ -32,7 +32,8 @@ class CoreWorkerRayletTaskSubmitter : public CoreWorkerTaskSubmitter { class CoreWorkerRayletTaskReceiver : public CoreWorkerTaskReceiver, public rpc::WorkerTaskHandler { public: - CoreWorkerRayletTaskReceiver(std::unique_ptr &raylet_client, + CoreWorkerRayletTaskReceiver(WorkerContext &worker_context, + std::unique_ptr &raylet_client, CoreWorkerObjectInterface &object_interface, boost::asio::io_service &io_service, rpc::GrpcServer &server, const TaskHandler &task_handler); @@ -49,6 +50,8 @@ class CoreWorkerRayletTaskReceiver : public CoreWorkerTaskReceiver, rpc::SendReplyCallback send_reply_callback) override; private: + // Worker context. + WorkerContext &worker_context_; /// Raylet client. std::unique_ptr &raylet_client_; // Object interface. diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto index 000670af8..737e879a7 100644 --- a/src/ray/protobuf/common.proto +++ b/src/ray/protobuf/common.proto @@ -90,6 +90,8 @@ message ActorCreationTaskSpec { // the placeholder strings (`RAY_WORKER_DYNAMIC_OPTION_PLACEHOLDER_0`, // `RAY_WORKER_DYNAMIC_OPTION_PLACEHOLDER_1`, etc) in the worker command. repeated string dynamic_worker_options = 4; + // Whether direct actor call is used. + bool is_direct_call = 5; } // Task spec of an actor task. diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto index 218e1acd5..8041cc40f 100644 --- a/src/ray/protobuf/gcs.proto +++ b/src/ray/protobuf/gcs.proto @@ -109,6 +109,8 @@ message ActorTableData { string ip_address = 9; // The port that the actor is listening on. int32 port = 10; + // Whether direct actor call is used. + bool is_direct_call = 11; } message ErrorTableData { diff --git a/src/ray/raylet/actor_registration.cc b/src/ray/raylet/actor_registration.cc index 7381d8d13..7574a57db 100644 --- a/src/ray/raylet/actor_registration.cc +++ b/src/ray/raylet/actor_registration.cc @@ -98,16 +98,19 @@ void ActorRegistration::AddHandle(const ActorHandleID &handle_id, int ActorRegistration::NumHandles() const { return frontier_.size(); } std::shared_ptr ActorRegistration::GenerateCheckpointData( - const ActorID &actor_id, const Task &task) { - const auto actor_handle_id = task.GetTaskSpecification().ActorHandleId(); - const auto dummy_object = task.GetTaskSpecification().ActorDummyObject(); - // Make a copy of the actor registration, and extend its frontier to include - // the most recent task. - // Note(hchen): this is needed because this method is called before - // `FinishAssignedTask`, which will be called when the worker tries to fetch - // the next task. + const ActorID &actor_id, const Task *task) { + // Make a copy of the actor registration ActorRegistration copy = *this; - copy.ExtendFrontier(actor_handle_id, dummy_object); + if (task) { + const auto actor_handle_id = task->GetTaskSpecification().ActorHandleId(); + const auto dummy_object = task->GetTaskSpecification().ActorDummyObject(); + // Extend its frontier to include the most recent task. + // NOTE(hchen): For non-direct-call actors, this is needed because this method is + // called before `FinishAssignedTask`, which will be called when the worker tries to + // fetch the next task. For direct-call actors, checkpoint data doesn't contain + // frontier info, so we don't need to do `ExtendFrontier` here. + copy.ExtendFrontier(actor_handle_id, dummy_object); + } // Use actor's current state to generate checkpoint data. auto checkpoint_data = std::make_shared(); diff --git a/src/ray/raylet/actor_registration.h b/src/ray/raylet/actor_registration.h index 67bd394e8..8aa40253b 100644 --- a/src/ray/raylet/actor_registration.h +++ b/src/ray/raylet/actor_registration.h @@ -133,10 +133,11 @@ class ActorRegistration { /// Generate checkpoint data based on actor's current state. /// /// \param actor_id ID of this actor. - /// \param task The task that just finished on the actor. + /// \param task The task that just finished on the actor. (nullptr when it's direct + /// call.) /// \return A shared pointer to the generated checkpoint data. std::shared_ptr GenerateCheckpointData(const ActorID &actor_id, - const Task &task); + const Task *task); private: /// Information from the global actor table about this actor, including the diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc deleted file mode 100644 index a9ef670b9..000000000 --- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc +++ /dev/null @@ -1,292 +0,0 @@ -#include "ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h" - -#include - -#include "ray/common/id.h" -#include "ray/core_worker/lib/java/jni_utils.h" -#include "ray/raylet/raylet_client.h" -#include "ray/util/logging.h" - -#ifdef __cplusplus -extern "C" { -#endif - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeInit - * Signature: (Ljava/lang/String;[BZ[B)J - */ -JNIEXPORT jlong JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeInit( - JNIEnv *env, jclass, jstring sockName, jbyteArray workerId, jboolean isWorker, - jbyteArray jobId) { - const auto worker_id = JavaByteArrayToId(env, workerId); - const auto job_id = JavaByteArrayToId(env, jobId); - const char *nativeString = env->GetStringUTFChars(sockName, JNI_FALSE); - auto raylet_client = new std::unique_ptr( - new RayletClient(nativeString, worker_id, isWorker, job_id, Language::JAVA)); - env->ReleaseStringUTFChars(sockName, nativeString); - return reinterpret_cast(raylet_client); -} - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeSubmitTask - * Signature: (J[BLjava/nio/ByteBuffer;II)V - */ -JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmitTask( - JNIEnv *env, jclass, jlong client, jbyteArray taskSpec) { - auto &raylet_client = *reinterpret_cast *>(client); - - jbyte *data = env->GetByteArrayElements(taskSpec, NULL); - jsize size = env->GetArrayLength(taskSpec); - ray::rpc::TaskSpec task_spec_message; - task_spec_message.ParseFromArray(data, size); - env->ReleaseByteArrayElements(taskSpec, data, JNI_ABORT); - - ray::TaskSpecification task_spec(task_spec_message); - auto status = raylet_client->SubmitTask(task_spec); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); -} - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeGetTask - * Signature: (J)[B - */ -JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeGetTask( - JNIEnv *env, jclass, jlong client) { - auto &raylet_client = *reinterpret_cast *>(client); - - std::unique_ptr spec; - auto status = raylet_client->GetTask(&spec); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); - - // Serialize the task spec and copy to Java byte array. - auto task_data = spec->Serialize(); - - jbyteArray result = env->NewByteArray(task_data.size()); - if (result == nullptr) { - return nullptr; /* out of memory error thrown */ - } - - env->SetByteArrayRegion(result, 0, task_data.size(), - reinterpret_cast(task_data.data())); - - return result; -} - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeDestroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeDestroy( - JNIEnv *env, jclass, jlong client) { - auto raylet_client = reinterpret_cast *>(client); - auto status = (*raylet_client)->Disconnect(); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); - delete raylet_client; -} - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeWaitObject - * Signature: (J[[BIIZ[B)[Z - */ -JNIEXPORT jbooleanArray JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeWaitObject( - JNIEnv *env, jclass, jlong client, jobjectArray objectIds, jint numReturns, - jint timeoutMillis, jboolean isWaitLocal, jbyteArray currentTaskId) { - std::vector object_ids; - auto len = env->GetArrayLength(objectIds); - for (int i = 0; i < len; i++) { - jbyteArray object_id_bytes = - static_cast(env->GetObjectArrayElement(objectIds, i)); - const auto object_id = JavaByteArrayToId(env, object_id_bytes); - object_ids.push_back(object_id); - env->DeleteLocalRef(object_id_bytes); - } - const auto current_task_id = JavaByteArrayToId(env, currentTaskId); - - auto &raylet_client = *reinterpret_cast *>(client); - - // Invoke wait. - WaitResultPair result; - auto status = - raylet_client->Wait(object_ids, numReturns, timeoutMillis, - static_cast(isWaitLocal), current_task_id, &result); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); - - // Convert result to java object. - jboolean put_value = true; - jbooleanArray resultArray = env->NewBooleanArray(object_ids.size()); - for (uint i = 0; i < result.first.size(); ++i) { - for (uint j = 0; j < object_ids.size(); ++j) { - if (result.first[i] == object_ids[j]) { - env->SetBooleanArrayRegion(resultArray, j, 1, &put_value); - break; - } - } - } - - put_value = false; - for (uint i = 0; i < result.second.size(); ++i) { - for (uint j = 0; j < object_ids.size(); ++j) { - if (result.second[i] == object_ids[j]) { - env->SetBooleanArrayRegion(resultArray, j, 1, &put_value); - break; - } - } - } - return resultArray; -} - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeGenerateActorCreationTaskId - * Signature: ([B[BI)[B - */ -JNIEXPORT jbyteArray JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateActorCreationTaskId( - JNIEnv *env, jclass, jbyteArray jobId, jbyteArray parentTaskId, - jint parent_task_counter) { - const auto job_id = JavaByteArrayToId(env, jobId); - const auto parent_task_id = JavaByteArrayToId(env, parentTaskId); - - const ActorID actor_id = ray::ActorID::Of(job_id, parent_task_id, parent_task_counter); - const TaskID actor_creation_task_id = ray::TaskID::ForActorCreationTask(actor_id); - jbyteArray result = env->NewByteArray(actor_creation_task_id.Size()); - if (nullptr == result) { - return nullptr; - } - env->SetByteArrayRegion(result, 0, actor_creation_task_id.Size(), - reinterpret_cast(actor_creation_task_id.Data())); - return result; -} - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeGenerateActorTaskId - * Signature: ([B[BI[B)[B - */ -JNIEXPORT jbyteArray JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateActorTaskId( - JNIEnv *env, jclass, jbyteArray jobId, jbyteArray parentTaskId, - jint parent_task_counter, jbyteArray actorId) { - const auto job_id = JavaByteArrayToId(env, jobId); - const auto parent_task_id = JavaByteArrayToId(env, parentTaskId); - const auto actor_id = JavaByteArrayToId(env, actorId); - const TaskID actor_task_id = - ray::TaskID::ForActorTask(job_id, parent_task_id, parent_task_counter, actor_id); - - jbyteArray result = env->NewByteArray(actor_task_id.Size()); - if (nullptr == result) { - return nullptr; - } - env->SetByteArrayRegion(result, 0, actor_task_id.Size(), - reinterpret_cast(actor_task_id.Data())); - return result; -} - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeGenerateNormalTaskId - * Signature: ([B[BI)[B - */ -JNIEXPORT jbyteArray JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateNormalTaskId( - JNIEnv *env, jclass, jbyteArray jobId, jbyteArray parentTaskId, - jint parent_task_counter) { - const auto job_id = JavaByteArrayToId(env, jobId); - const auto parent_task_id = JavaByteArrayToId(env, parentTaskId); - const TaskID task_id = - ray::TaskID::ForNormalTask(job_id, parent_task_id, parent_task_counter); - - jbyteArray result = env->NewByteArray(task_id.Size()); - if (nullptr == result) { - return nullptr; - } - env->SetByteArrayRegion(result, 0, task_id.Size(), - reinterpret_cast(task_id.Data())); - return result; -} - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeFreePlasmaObjects - * Signature: (J[[BZZ)V - */ -JNIEXPORT void JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeFreePlasmaObjects( - JNIEnv *env, jclass, jlong client, jobjectArray objectIds, jboolean localOnly, - jboolean deleteCreatingTasks) { - std::vector object_ids; - auto len = env->GetArrayLength(objectIds); - for (int i = 0; i < len; i++) { - jbyteArray object_id_bytes = - static_cast(env->GetObjectArrayElement(objectIds, i)); - const auto object_id = JavaByteArrayToId(env, object_id_bytes); - object_ids.push_back(object_id); - env->DeleteLocalRef(object_id_bytes); - } - auto &raylet_client = *reinterpret_cast *>(client); - auto status = raylet_client->FreeObjects(object_ids, localOnly, deleteCreatingTasks); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); -} - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativePrepareCheckpoint - * Signature: (J[B)[B - */ -JNIEXPORT jbyteArray JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativePrepareCheckpoint(JNIEnv *env, jclass, - jlong client, - jbyteArray actorId) { - auto &raylet_client = *reinterpret_cast *>(client); - const auto actor_id = JavaByteArrayToId(env, actorId); - ActorCheckpointID checkpoint_id; - auto status = raylet_client->PrepareActorCheckpoint(actor_id, checkpoint_id); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); - jbyteArray result = env->NewByteArray(checkpoint_id.Size()); - env->SetByteArrayRegion(result, 0, checkpoint_id.Size(), - reinterpret_cast(checkpoint_id.Data())); - return result; -} - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeNotifyActorResumedFromCheckpoint - * Signature: (J[B[B)V - */ -JNIEXPORT void JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyActorResumedFromCheckpoint( - JNIEnv *env, jclass, jlong client, jbyteArray actorId, jbyteArray checkpointId) { - auto &raylet_client = *reinterpret_cast *>(client); - const auto actor_id = JavaByteArrayToId(env, actorId); - const auto checkpoint_id = JavaByteArrayToId(env, checkpointId); - auto status = raylet_client->NotifyActorResumedFromCheckpoint(actor_id, checkpoint_id); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); -} - -/* - * Class: org_ray_runtime_raylet_RayletClientImpl - * Method: nativeSetResource - * Signature: (JLjava/lang/String;D[B)V - */ -JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSetResource( - JNIEnv *env, jclass, jlong client, jstring resourceName, jdouble capacity, - jbyteArray nodeId) { - auto &raylet_client = *reinterpret_cast *>(client); - const auto node_id = JavaByteArrayToId(env, nodeId); - const char *native_resource_name = env->GetStringUTFChars(resourceName, JNI_FALSE); - - auto status = raylet_client->SetResource(native_resource_name, - static_cast(capacity), node_id); - env->ReleaseStringUTFChars(resourceName, native_resource_name); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); -} - -#ifdef __cplusplus -} -#endif diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index fa113cce8..7bab0b313 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -1225,13 +1225,19 @@ void NodeManager::ProcessPrepareActorCheckpointRequest( std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); RAY_CHECK(worker && worker->GetActorId() == actor_id); - // Find the task that is running on this actor. - const auto task_id = worker->GetAssignedTaskId(); - const Task &task = local_queues_.GetTaskOfState(task_id, TaskState::RUNNING); - // Generate checkpoint id and data. ActorCheckpointID checkpoint_id = ActorCheckpointID::FromRandom(); - auto checkpoint_data = - actor_entry->second.GenerateCheckpointData(actor_entry->first, task); + std::shared_ptr checkpoint_data; + if (actor_entry->second.GetTableData().is_direct_call()) { + checkpoint_data = + actor_entry->second.GenerateCheckpointData(actor_entry->first, nullptr); + } else { + // Find the task that is running on this actor. + const auto task_id = worker->GetAssignedTaskId(); + const Task &task = local_queues_.GetTaskOfState(task_id, TaskState::RUNNING); + // Generate checkpoint data. + checkpoint_data = + actor_entry->second.GenerateCheckpointData(actor_entry->first, &task); + } // Write checkpoint data to GCS. RAY_CHECK_OK(gcs_client_->actor_checkpoint_table().Add( @@ -1914,6 +1920,7 @@ std::shared_ptr NodeManager::CreateActorTableDataFromCreationTas // This is the first time that the actor has been created, so the number // of remaining reconstructions is the max. actor_info_ptr->set_remaining_reconstructions(task_spec.MaxActorReconstructions()); + actor_info_ptr->set_is_direct_call(task_spec.IsDirectCall()); } else { // If we've already seen this actor, it means that this actor was reconstructed. // Thus, its previous state must be RECONSTRUCTING. diff --git a/src/ray/raylet/worker.cc b/src/ray/raylet/worker.cc index 52814bb70..947775d6d 100644 --- a/src/ray/raylet/worker.cc +++ b/src/ray/raylet/worker.cc @@ -128,12 +128,15 @@ void Worker::AssignTask(const Task &task, const ResourceIdSet &resource_id_set, task.GetTaskExecutionSpec().GetMessage()); request.set_resource_ids(resource_id_set.Serialize()); - auto status = rpc_client_->AssignTask( - request, [](Status status, const rpc::AssignTaskReply &reply) { - // Worker has finished this task. There's nothing to do here - // and assigning new task will be done when raylet receives - // `TaskDone` message. - }); + auto status = rpc_client_->AssignTask(request, [](Status status, + const rpc::AssignTaskReply &reply) { + if (!status.ok()) { + RAY_LOG(ERROR) << "Worker failed to finish executing task: " << status.ToString(); + } + // Worker has finished this task. There's nothing to do here + // and assigning new task will be done when raylet receives + // `TaskDone` message. + }); finish_assign_callback(status); if (!status.ok()) { RAY_LOG(ERROR) << "Failed to assign task " << task.GetTaskSpecification().TaskId()