diff --git a/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java index db2e5e841..c7b0df21b 100644 --- a/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java @@ -282,6 +282,11 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal { return objectStore; } + @Override + public TaskExecutor getTaskExecutor() { + return taskExecutor; + } + @Override public FunctionManager getFunctionManager() { return functionManager; diff --git a/java/runtime/src/main/java/io/ray/runtime/RayRuntimeInternal.java b/java/runtime/src/main/java/io/ray/runtime/RayRuntimeInternal.java index f155fce4d..d70e4ed2d 100644 --- a/java/runtime/src/main/java/io/ray/runtime/RayRuntimeInternal.java +++ b/java/runtime/src/main/java/io/ray/runtime/RayRuntimeInternal.java @@ -6,6 +6,7 @@ import io.ray.runtime.context.WorkerContext; import io.ray.runtime.functionmanager.FunctionManager; import io.ray.runtime.gcs.GcsClient; import io.ray.runtime.object.ObjectStore; +import io.ray.runtime.task.TaskExecutor; /** * This interface is required to make {@link RayRuntimeProxy} work. 
@@ -21,6 +22,8 @@ public interface RayRuntimeInternal extends RayRuntime { ObjectStore getObjectStore(); + TaskExecutor getTaskExecutor(); + FunctionManager getFunctionManager(); RayConfig getRayConfig(); diff --git a/java/runtime/src/main/java/io/ray/runtime/task/NativeTaskExecutor.java b/java/runtime/src/main/java/io/ray/runtime/task/NativeTaskExecutor.java index 808465a80..a41a14b60 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/NativeTaskExecutor.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/NativeTaskExecutor.java @@ -45,6 +45,11 @@ public class NativeTaskExecutor extends TaskExecutor { this.actorContextMap.put(runtime.getWorkerContext().getCurrentWorkerId(), actorContext); } + protected void removeActorContext(UniqueId workerId) { + this.actorContextMap.remove(workerId); + } + private RayFunction getRayFunction(List rayFunctionInfo) { JobId jobId = runtime.getWorkerContext().getCurrentJobId(); JavaFunctionDescriptor functionDescriptor = parseFunctionDescriptor(rayFunctionInfo); diff --git a/java/test/src/main/java/io/ray/test/ExitActorTest.java b/java/test/src/main/java/io/ray/test/ExitActorTest.java index 9c95cb960..a0cdc8336 100644 --- a/java/test/src/main/java/io/ray/test/ExitActorTest.java +++ b/java/test/src/main/java/io/ray/test/ExitActorTest.java @@ -9,9 +9,12 @@ import io.ray.api.Ray; import io.ray.api.id.ActorId; import io.ray.api.id.UniqueId; import io.ray.runtime.exception.RayActorException; +import io.ray.runtime.task.TaskExecutor; import io.ray.runtime.util.SystemUtil; import java.io.IOException; +import java.lang.reflect.Field; import java.util.List; +import java.util.Map; import java.util.concurrent.TimeUnit; import org.testng.Assert; import org.testng.annotations.Test; @@ -31,6 +34,17 @@ public class ExitActorTest extends BaseTest { return pid(); } + public int getSizeOfActorContextMap() { + TaskExecutor taskExecutor = TestUtils.getRuntime().getTaskExecutor(); + try { + Field field = 
TaskExecutor.class.getDeclaredField("actorContextMap"); + field.setAccessible(true); + return ((Map)field.get(taskExecutor)).size(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + @Override public boolean shouldCheckpoint(CheckpointContext checkpointContext) { return true; @@ -77,6 +91,8 @@ public class ExitActorTest extends BaseTest { ActorHandle<ExitingActor> actor1 = Ray.actor(ExitingActor::new) .setMaxRestarts(10000).remote(); int pid = actor1.task(ExitingActor::getPid).remote().get(); + Assert.assertEquals( + 1, (int) actor1.task(ExitingActor::getSizeOfActorContextMap).remote().get()); ActorHandle<ExitingActor> actor2; while (true) { // Create another actor which share the same process of actor 1. @@ -86,11 +102,17 @@ public class ExitActorTest extends BaseTest { break; } } + Assert.assertEquals( + 2, (int) actor1.task(ExitingActor::getSizeOfActorContextMap).remote().get()); + Assert.assertEquals( + 2, (int) actor2.task(ExitingActor::getSizeOfActorContextMap).remote().get()); ObjectRef obj1 = actor1.task(ExitingActor::exit).remote(); Assert.assertThrows(RayActorException.class, obj1::get); Assert.assertTrue(SystemUtil.isProcessAlive(pid)); // Actor 2 shouldn't exit or be reconstructed. 
Assert.assertEquals(1, (int) actor2.task(ExitingActor::incr).remote().get()); + Assert.assertEquals( + 1, (int) actor2.task(ExitingActor::getSizeOfActorContextMap).remote().get()); Assert.assertEquals(pid, (int) actor2.task(ExitingActor::getPid).remote().get()); Assert.assertTrue(SystemUtil.isProcessAlive(pid)); } diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index cd5915733..351385c06 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -19,6 +19,7 @@ from ray.includes.unique_ids cimport ( CTaskID, CObjectID, CPlacementGroupID, + CWorkerID, ) from ray.includes.common cimport ( CAddress, @@ -227,6 +228,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: const c_vector[CObjectID] &return_ids, c_vector[shared_ptr[CRayObject]] *returns) nogil ) task_execution_callback + (void(const CWorkerID &) nogil) on_worker_shutdown (CRayStatus() nogil) check_signals (void() nogil) gc_collect (c_vector[c_string](const c_vector[CObjectID]&) nogil) spill_objects diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 0b838e5a2..1150b5d77 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -541,6 +541,9 @@ void CoreWorker::Shutdown() { if (options_.worker_type == WorkerType::WORKER) { task_execution_service_.stop(); } + if (options_.on_worker_shutdown) { + options_.on_worker_shutdown(GetWorkerID()); + } } void CoreWorker::Disconnect() { diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 64fd72d35..c7f419732 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -123,6 +123,8 @@ struct CoreWorkerOptions { std::string stderr_file; /// Language worker callback to execute tasks. TaskExecutionCallback task_execution_callback; + /// The callback to be called when shutting down a `CoreWorker` instance. 
+ std::function<void(const WorkerID &)> on_worker_shutdown; /// Application-language callback to check for signals that have been received /// since calling into C++. This will be called periodically (at least every /// 1s) during long-running operations. If the function returns anything but StatusOK, diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc index ad9d5b170..ee8c76a29 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc @@ -150,6 +150,7 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize( if (throwable && env->IsInstanceOf(throwable, java_ray_intentional_system_exit_exception_class)) { + env->ExceptionClear(); return ray::Status::IntentionalSystemExit(); } RAY_CHECK_JAVA_EXCEPTION(env); @@ -211,6 +212,16 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize( } }; + auto on_worker_shutdown = [](const ray::WorkerID &worker_id) { + JNIEnv *env = GetJNIEnv(); + auto worker_id_bytes = IdToJavaByteArray(env, worker_id); + if (java_task_executor) { + env->CallVoidMethod(java_task_executor, + java_native_task_executor_on_worker_shutdown, worker_id_bytes); + RAY_CHECK_JAVA_EXCEPTION(env); + } + }; + std::string serialized_job_config = (jobConfig == nullptr ? 
"" : JavaByteArrayToNativeString(env, jobConfig)); ray::CoreWorkerOptions options; @@ -229,6 +240,7 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize( options.raylet_ip_address = JavaStringToNativeString(env, nodeIpAddress); options.driver_name = JavaStringToNativeString(env, driverName); options.task_execution_callback = task_execution_callback; + options.on_worker_shutdown = on_worker_shutdown; options.gc_collect = gc_collect; options.ref_counting_enabled = true; options.num_workers = static_cast(numWorkersPerProcess); diff --git a/src/ray/core_worker/lib/java/jni_init.cc b/src/ray/core_worker/lib/java/jni_init.cc index 4f2b05c6b..522c12d3d 100644 --- a/src/ray/core_worker/lib/java/jni_init.cc +++ b/src/ray/core_worker/lib/java/jni_init.cc @@ -107,6 +107,9 @@ jclass java_task_executor_class; jmethodID java_task_executor_parse_function_arguments; jmethodID java_task_executor_execute; +jclass java_native_task_executor_class; +jmethodID java_native_task_executor_on_worker_shutdown; + jclass java_placement_group_class; jfieldID java_placement_group_id; @@ -267,6 +270,10 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { java_task_executor_execute = env->GetMethodID(java_task_executor_class, "execute", "(Ljava/util/List;Ljava/util/List;)Ljava/util/List;"); + java_native_task_executor_class = + LoadClass(env, "io/ray/runtime/task/NativeTaskExecutor"); + java_native_task_executor_on_worker_shutdown = + env->GetMethodID(java_native_task_executor_class, "onWorkerShutdown", "([B)V"); return CURRENT_JNI_VERSION; } @@ -298,4 +305,5 @@ void JNI_OnUnload(JavaVM *vm, void *reserved) { env->DeleteGlobalRef(java_actor_creation_options_class); env->DeleteGlobalRef(java_native_ray_object_class); env->DeleteGlobalRef(java_task_executor_class); + env->DeleteGlobalRef(java_native_task_executor_class); } diff --git a/src/ray/core_worker/lib/java/jni_utils.h b/src/ray/core_worker/lib/java/jni_utils.h index 5b849d885..81aea813e 100644 --- 
a/src/ray/core_worker/lib/java/jni_utils.h +++ b/src/ray/core_worker/lib/java/jni_utils.h @@ -187,6 +187,11 @@ extern jmethodID java_task_executor_parse_function_arguments; /// execute method of TaskExecutor class extern jmethodID java_task_executor_execute; +/// NativeTaskExecutor class +extern jclass java_native_task_executor_class; +/// onWorkerShutdown method of NativeTaskExecutor class +extern jmethodID java_native_task_executor_on_worker_shutdown; + /// PlacementGroup class extern jclass java_placement_group_class; /// id field of PlacementGroup class