From f3703bafa3b16c80d03f0ce18860e8e5eb468bed Mon Sep 17 00:00:00 2001 From: Qing Wang Date: Fri, 14 Feb 2020 13:02:39 +0800 Subject: [PATCH] [Java] Support concurrent actor calls API. (#7022) * WIP Temp change Attach native thread to jvm * Fix run mode * Address comments. --- .../ray/api/options/ActorCreationOptions.java | 23 ++++++++- .../org/ray/runtime/RayNativeRuntime.java | 11 ++-- .../org/ray/runtime/task/TaskExecutor.java | 15 ++++++ .../ray/api/test/ActorConcurrentCallTest.java | 51 +++++++++++++++++++ python/ray/_raylet.pyx | 3 +- python/ray/includes/libcoreworker.pxd | 4 +- src/ray/core_worker/core_worker.cc | 6 +-- src/ray/core_worker/core_worker.h | 2 +- src/ray/core_worker/lib/java/jni_init.cc | 10 +++- src/ray/core_worker/lib/java/jni_utils.h | 5 ++ .../java/org_ray_runtime_RayNativeRuntime.cc | 21 ++++++-- .../java/org_ray_runtime_RayNativeRuntime.h | 4 +- ...rg_ray_runtime_task_NativeTaskSubmitter.cc | 5 +- 13 files changed, 139 insertions(+), 21 deletions(-) create mode 100644 java/test/src/main/java/org/ray/api/test/ActorConcurrentCallTest.java diff --git a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java index d4131eb1b..5577a35d4 100644 --- a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java +++ b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java @@ -21,12 +21,15 @@ public class ActorCreationOptions extends BaseTaskOptions { public final String jvmOptions; + public final int maxConcurrency; + private ActorCreationOptions(Map resources, int maxReconstructions, - boolean useDirectCall, String jvmOptions) { + boolean useDirectCall, String jvmOptions, int maxConcurrency) { super(resources); this.maxReconstructions = maxReconstructions; this.useDirectCall = useDirectCall; this.jvmOptions = jvmOptions; + this.maxConcurrency = maxConcurrency; } /** @@ -38,6 +41,7 @@ public class ActorCreationOptions extends BaseTaskOptions { private int maxReconstructions = NO_RECONSTRUCTION; private boolean useDirectCall = DEFAULT_USE_DIRECT_CALL; private String jvmOptions = null; + private int maxConcurrency = 1; public Builder setResources(Map resources) { this.resources = resources; @@ -62,8 +66,23 @@ public class ActorCreationOptions extends BaseTaskOptions { return this; } + // The max number of concurrent calls to allow for this actor. + // + // This only works with direct actor calls. The max concurrency defaults to 1 + // for threaded execution. Note that the execution order is not guaranteed + // when max_concurrency > 1. + public Builder setMaxConcurrency(int maxConcurrency) { + if (maxConcurrency <= 0) { + throw new IllegalArgumentException("maxConcurrency must be greater than 0."); + } + + this.maxConcurrency = maxConcurrency; + return this; + } + public ActorCreationOptions createActorCreationOptions() { - return new ActorCreationOptions(resources, maxReconstructions, useDirectCall, jvmOptions); + return new ActorCreationOptions( + resources, maxReconstructions, useDirectCall, jvmOptions, maxConcurrency); } } diff --git a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java index 7e315c5ee..5e1be7bb6 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java @@ -95,8 +95,8 @@ public final class RayNativeRuntime extends AbstractRayRuntime { new GcsClientOptions(rayConfig)); Preconditions.checkState(nativeCoreWorkerPointer != 0); - taskExecutor = new NativeTaskExecutor(nativeCoreWorkerPointer, this); workerContext = new NativeWorkerContext(nativeCoreWorkerPointer); + taskExecutor = new NativeTaskExecutor(nativeCoreWorkerPointer, this); objectStore = new NativeObjectStore(workerContext, nativeCoreWorkerPointer); taskSubmitter = new NativeTaskSubmitter(nativeCoreWorkerPointer); @@ -153,13 +153,17 @@ public final class RayNativeRuntime extends AbstractRayRuntime { } public void run() { - nativeRunTaskExecutor(nativeCoreWorkerPointer, taskExecutor); + nativeRunTaskExecutor(nativeCoreWorkerPointer); } public long getNativeCoreWorkerPointer() { return nativeCoreWorkerPointer; } + public TaskExecutor getTaskExecutor() { + return taskExecutor; + } + /** * Register this worker or driver to GCS. */ @@ -189,8 +193,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime { String rayletSocket, String nodeIpAddress, int nodeManagerPort, byte[] jobId, GcsClientOptions gcsClientOptions); - private static native void nativeRunTaskExecutor(long nativeCoreWorkerPointer, - TaskExecutor taskExecutor); + private static native void nativeRunTaskExecutor(long nativeCoreWorkerPointer); private static native void nativeDestroyCoreWorker(long nativeCoreWorkerPointer); diff --git a/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java b/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java index 230faf7e2..3ca133af7 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java @@ -3,11 +3,15 @@ package org.ray.runtime.task; import com.google.common.base.Preconditions; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.ConcurrentHashMap; import org.ray.api.exception.RayTaskException; import org.ray.api.id.ActorId; import org.ray.api.id.JobId; import org.ray.api.id.TaskId; +import org.ray.api.id.UniqueId; import org.ray.runtime.AbstractRayRuntime; +import org.ray.runtime.config.RayConfig; +import org.ray.runtime.config.RunMode; import org.ray.runtime.functionmanager.JavaFunctionDescriptor; import org.ray.runtime.functionmanager.RayFunction; import org.ray.runtime.generated.Common.TaskType; @@ -23,6 +27,10 @@ public abstract class TaskExecutor { private static final Logger LOGGER = LoggerFactory.getLogger(TaskExecutor.class); + // A helper map to help we get the corresponding executor for the given worker in JNI. + private static ConcurrentHashMap taskExecutors + = new ConcurrentHashMap<>(); + protected final AbstractRayRuntime runtime; /** @@ -37,6 +45,13 @@ public abstract class TaskExecutor { protected TaskExecutor(AbstractRayRuntime runtime) { this.runtime = runtime; + if (RayConfig.getInstance().runMode == RunMode.CLUSTER) { + taskExecutors.put(runtime.getWorkerContext().getCurrentWorkerId(), this); + } + } + + public static TaskExecutor get(byte[] workerId) { + return taskExecutors.get(new UniqueId(workerId)); } protected List execute(List rayFunctionInfo, diff --git a/java/test/src/main/java/org/ray/api/test/ActorConcurrentCallTest.java b/java/test/src/main/java/org/ray/api/test/ActorConcurrentCallTest.java new file mode 100644 index 000000000..850bae9dd --- /dev/null +++ b/java/test/src/main/java/org/ray/api/test/ActorConcurrentCallTest.java @@ -0,0 +1,51 @@ +package org.ray.api.test; + +import com.google.common.collect.ImmutableList; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import org.ray.api.Ray; +import org.ray.api.RayActor; +import org.ray.api.RayObject; +import org.ray.api.TestUtils; +import org.ray.api.annotation.RayRemote; +import org.ray.api.options.ActorCreationOptions; +import org.testng.Assert; +import org.testng.annotations.Test; + + +@Test(groups = {"directCall"}) +public class ActorConcurrentCallTest extends BaseTest { + + @RayRemote + public static class ConcurrentActor { + private final CountDownLatch countDownLatch = new CountDownLatch(3); + + public String countDown() { + countDownLatch.countDown(); + try { + countDownLatch.await(); + return "ok"; + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + } + + public void testConcurrentCall() { + TestUtils.skipTestIfDirectActorCallDisabled(); + + ActorCreationOptions op = new ActorCreationOptions.Builder() + .setMaxConcurrency(3) + .createActorCreationOptions(); + RayActor actor = Ray.createActor(ConcurrentActor::new, op); + RayObject obj1 = Ray.call(ConcurrentActor::countDown, actor); + RayObject obj2 = Ray.call(ConcurrentActor::countDown, actor); + RayObject obj3 = Ray.call(ConcurrentActor::countDown, actor); + + List expectedResult = ImmutableList.of(1, 2, 3); + Assert.assertEquals(obj1.get(), "ok"); + Assert.assertEquals(obj2.get(), "ok"); + Assert.assertEquals(obj3.get(), "ok"); + } + +} diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 49bef9442..1339cdd93 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -515,7 +515,8 @@ cdef CRayStatus task_execution_handler( const c_vector[shared_ptr[CRayObject]] &c_args, const c_vector[CObjectID] &c_arg_reference_ids, const c_vector[CObjectID] &c_return_ids, - c_vector[shared_ptr[CRayObject]] *returns) nogil: + c_vector[shared_ptr[CRayObject]] *returns, + const CWorkerID &c_worker_id) nogil: with gil: try: diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index 360552fb0..c24081423 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -17,6 +17,7 @@ from ray.includes.unique_ids cimport ( CJobID, CTaskID, CObjectID, + CWorkerID, ) from ray.includes.common cimport ( CAddress, @@ -79,7 +80,8 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: const c_vector[shared_ptr[CRayObject]] &args, const c_vector[CObjectID] &arg_reference_ids, const c_vector[CObjectID] &return_ids, - c_vector[shared_ptr[CRayObject]] *returns) nogil, + c_vector[shared_ptr[CRayObject]] *returns, + const CWorkerID &worker_id) nogil, CRayStatus() nogil, c_bool ref_counting_enabled) CWorkerType &GetWorkerType() diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 554612ac2..bbc7debfc 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -961,9 +961,9 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec, task_type = TaskType::ACTOR_TASK; } - status = task_execution_callback_(task_type, func, - task_spec.GetRequiredResources().GetResourceMap(), - args, arg_reference_ids, return_ids, return_objects); + status = task_execution_callback_( + task_type, func, task_spec.GetRequiredResources().GetResourceMap(), args, + arg_reference_ids, return_ids, return_objects, worker_context_.GetWorkerID()); for (size_t i = 0; i < return_objects->size(); i++) { // The object is nullptr if it already existed in the object store. diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index e9042f6c1..d34e9c9c3 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -47,7 +47,7 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { const std::vector> &args, const std::vector &arg_reference_ids, const std::vector &return_ids, - std::vector> *results)>; + std::vector> *results, const ray::WorkerID &worker_id)>; public: /// Construct a CoreWorker instance. diff --git a/src/ray/core_worker/lib/java/jni_init.cc b/src/ray/core_worker/lib/java/jni_init.cc index 3ee763694..9b3fe5ee1 100644 --- a/src/ray/core_worker/lib/java/jni_init.cc +++ b/src/ray/core_worker/lib/java/jni_init.cc @@ -56,6 +56,7 @@ jfieldID java_actor_creation_options_default_use_direct_call; jfieldID java_actor_creation_options_max_reconstructions; jfieldID java_actor_creation_options_use_direct_call; jfieldID java_actor_creation_options_jvm_options; +jfieldID java_actor_creation_options_max_concurrency; jclass java_gcs_client_options_class; jfieldID java_gcs_client_options_ip; @@ -69,6 +70,7 @@ jfieldID java_native_ray_object_metadata; jclass java_task_executor_class; jmethodID java_task_executor_execute; +jmethodID java_task_executor_get; JavaVM *jvm; @@ -164,7 +166,8 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { env->GetFieldID(java_actor_creation_options_class, "useDirectCall", "Z"); java_actor_creation_options_jvm_options = env->GetFieldID( java_actor_creation_options_class, "jvmOptions", "Ljava/lang/String;"); - + java_actor_creation_options_max_concurrency = + env->GetFieldID(java_actor_creation_options_class, "maxConcurrency", "I"); java_gcs_client_options_class = LoadClass(env, "org/ray/runtime/gcs/GcsClientOptions"); java_gcs_client_options_ip = env->GetFieldID(java_gcs_client_options_class, "ip", "Ljava/lang/String;"); @@ -186,6 +189,11 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { env->GetMethodID(java_task_executor_class, "execute", "(Ljava/util/List;Ljava/util/List;)Ljava/util/List;"); + java_task_executor_get = env->GetStaticMethodID( + java_task_executor_class, + "get", + "([B)Lorg/ray/runtime/task/TaskExecutor;"); + return CURRENT_JNI_VERSION; } diff --git a/src/ray/core_worker/lib/java/jni_utils.h b/src/ray/core_worker/lib/java/jni_utils.h index abbad3236..f10dbf95f 100644 --- a/src/ray/core_worker/lib/java/jni_utils.h +++ b/src/ray/core_worker/lib/java/jni_utils.h @@ -105,6 +105,8 @@ extern jfieldID java_actor_creation_options_max_reconstructions; extern jfieldID java_actor_creation_options_use_direct_call; /// jvmOptions field of ActorCreationOptions class extern jfieldID java_actor_creation_options_jvm_options; +/// maxConcurrency field of ActorCreationOptions class +extern jfieldID java_actor_creation_options_max_concurrency; /// GcsClientOptions class extern jclass java_gcs_client_options_class; @@ -129,6 +131,9 @@ extern jclass java_task_executor_class; /// execute method of TaskExecutor class extern jmethodID java_task_executor_execute; +/// The `get` method in TaskExecutor class +extern jmethodID java_task_executor_get; + #define CURRENT_JNI_VERSION JNI_VERSION_1_8 extern JavaVM *jvm; diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc b/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc index 8a98a9cd3..bbb179c4a 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc +++ b/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc @@ -6,7 +6,6 @@ #include "ray/core_worker/lib/java/jni_utils.h" thread_local JNIEnv *local_env = nullptr; -thread_local jobject local_java_task_executor = nullptr; inline ray::gcs::GcsClientOptions ToGcsClientOptions(JNIEnv *env, jobject gcs_client_options) { @@ -39,9 +38,23 @@ JNIEXPORT jlong JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeInitCoreWork const std::vector> &args, const std::vector &arg_reference_ids, const std::vector &return_ids, - std::vector> *results) { + std::vector> *results, + const ray::WorkerID &worker_id) { JNIEnv *env = local_env; + if (!env) { + // Attach the native thread to JVM. + auto status = + jvm->AttachCurrentThreadAsDaemon(reinterpret_cast(&env), nullptr); + RAY_CHECK(status == JNI_OK) << "Failed to get JNIEnv. Return code: " << status; + local_env = env; + } + RAY_CHECK(env); + + auto worker_id_bytes = IdToJavaByteArray(env, worker_id); + jobject local_java_task_executor = env->CallStaticObjectMethod( + java_task_executor_class, java_task_executor_get, worker_id_bytes); + RAY_CHECK(local_java_task_executor); // convert RayFunction jobject ray_function_array_list = NativeRayFunctionDescriptorToJavaStringList( @@ -87,13 +100,11 @@ JNIEXPORT jlong JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeInitCoreWork } JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeRunTaskExecutor( - JNIEnv *env, jclass o, jlong nativeCoreWorkerPointer, jobject javaTaskExecutor) { + JNIEnv *env, jclass o, jlong nativeCoreWorkerPointer) { local_env = env; - local_java_task_executor = javaTaskExecutor; auto core_worker = reinterpret_cast(nativeCoreWorkerPointer); core_worker->StartExecutingTasks(); local_env = nullptr; - local_java_task_executor = nullptr; } JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeDestroyCoreWorker( diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.h b/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.h index 55f9ca298..faef1dd52 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.h +++ b/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.h @@ -19,10 +19,10 @@ JNIEXPORT jlong JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeInitCoreWork /* * Class: org_ray_runtime_RayNativeRuntime * Method: nativeRunTaskExecutor - * Signature: (JLorg/ray/runtime/task/TaskExecutor;)V + * Signature: (J)V */ JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeRunTaskExecutor( - JNIEnv *, jclass, jlong, jobject); + JNIEnv *, jclass, jlong); /* * Class: org_ray_runtime_RayNativeRuntime diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc index 87c801c6f..236a42f1f 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc +++ b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc @@ -92,6 +92,7 @@ inline ray::ActorCreationOptions ToActorCreationOptions(JNIEnv *env, bool use_direct_call; std::unordered_map resources; std::vector dynamic_worker_options; + uint64_t max_concurrency = 1; if (actorCreationOptions) { max_reconstructions = static_cast(env->GetIntField( actorCreationOptions, java_actor_creation_options_max_reconstructions)); @@ -106,6 +107,8 @@ inline ray::ActorCreationOptions ToActorCreationOptions(JNIEnv *env, std::string jvm_options = JavaStringToNativeString(env, java_jvm_options); dynamic_worker_options.emplace_back(jvm_options); } + max_concurrency = static_cast(env->GetIntField( + actorCreationOptions, java_actor_creation_options_max_concurrency)); } else { use_direct_call = env->GetStaticBooleanField(java_actor_creation_options_class, @@ -115,7 +118,7 @@ inline ray::ActorCreationOptions ToActorCreationOptions(JNIEnv *env, ray::ActorCreationOptions actor_creation_options{ static_cast(max_reconstructions), use_direct_call, - /*max_concurrency=*/1, + static_cast(max_concurrency), resources, resources, dynamic_worker_options,