mirror of
https://github.com/wassname/ray.git
synced 2026-07-04 21:04:35 +08:00
[Java] Release actor instance reference when Ray.exitActor() is invoked (#11324)
This commit is contained in:
@@ -282,6 +282,11 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
|
||||
return objectStore;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TaskExecutor getTaskExecutor() {
|
||||
return taskExecutor;
|
||||
}
|
||||
|
||||
@Override
|
||||
public FunctionManager getFunctionManager() {
|
||||
return functionManager;
|
||||
|
||||
@@ -6,6 +6,7 @@ import io.ray.runtime.context.WorkerContext;
|
||||
import io.ray.runtime.functionmanager.FunctionManager;
|
||||
import io.ray.runtime.gcs.GcsClient;
|
||||
import io.ray.runtime.object.ObjectStore;
|
||||
import io.ray.runtime.task.TaskExecutor;
|
||||
|
||||
/**
|
||||
* This interface is required to make {@link RayRuntimeProxy} work.
|
||||
@@ -21,6 +22,8 @@ public interface RayRuntimeInternal extends RayRuntime {
|
||||
|
||||
ObjectStore getObjectStore();
|
||||
|
||||
TaskExecutor getTaskExecutor();
|
||||
|
||||
FunctionManager getFunctionManager();
|
||||
|
||||
RayConfig getRayConfig();
|
||||
|
||||
@@ -45,6 +45,11 @@ public class NativeTaskExecutor extends TaskExecutor<NativeTaskExecutor.NativeAc
|
||||
return new NativeActorContext();
|
||||
}
|
||||
|
||||
public void onWorkerShutdown(byte[] workerIdBytes) {
|
||||
// This is to make sure no memory leak when `Ray.exitActor()` is called.
|
||||
removeActorContext(new UniqueId(workerIdBytes));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void maybeSaveCheckpoint(Object actor, ActorId actorId) {
|
||||
if (!(actor instanceof Checkpointable)) {
|
||||
|
||||
@@ -65,6 +65,10 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
|
||||
this.actorContextMap.put(runtime.getWorkerContext().getCurrentWorkerId(), actorContext);
|
||||
}
|
||||
|
||||
protected void removeActorContext(UniqueId workerId) {
|
||||
this.actorContextMap.remove(workerId);
|
||||
}
|
||||
|
||||
private RayFunction getRayFunction(List<String> rayFunctionInfo) {
|
||||
JobId jobId = runtime.getWorkerContext().getCurrentJobId();
|
||||
JavaFunctionDescriptor functionDescriptor = parseFunctionDescriptor(rayFunctionInfo);
|
||||
|
||||
@@ -9,9 +9,12 @@ import io.ray.api.Ray;
|
||||
import io.ray.api.id.ActorId;
|
||||
import io.ray.api.id.UniqueId;
|
||||
import io.ray.runtime.exception.RayActorException;
|
||||
import io.ray.runtime.task.TaskExecutor;
|
||||
import io.ray.runtime.util.SystemUtil;
|
||||
import java.io.IOException;
|
||||
import java.lang.reflect.Field;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
@@ -31,6 +34,17 @@ public class ExitActorTest extends BaseTest {
|
||||
return pid();
|
||||
}
|
||||
|
||||
public int getSizeOfActorContextMap() {
|
||||
TaskExecutor taskExecutor = TestUtils.getRuntime().getTaskExecutor();
|
||||
try {
|
||||
Field field = TaskExecutor.class.getDeclaredField("actorContextMap");
|
||||
field.setAccessible(true);
|
||||
return ((Map<?, ?>)field.get(taskExecutor)).size();
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean shouldCheckpoint(CheckpointContext checkpointContext) {
|
||||
return true;
|
||||
@@ -77,6 +91,8 @@ public class ExitActorTest extends BaseTest {
|
||||
ActorHandle<ExitingActor> actor1 = Ray.actor(ExitingActor::new)
|
||||
.setMaxRestarts(10000).remote();
|
||||
int pid = actor1.task(ExitingActor::getPid).remote().get();
|
||||
Assert.assertEquals(
|
||||
1, (int) actor1.task(ExitingActor::getSizeOfActorContextMap).remote().get());
|
||||
ActorHandle<ExitingActor> actor2;
|
||||
while (true) {
|
||||
// Create another actor which share the same process of actor 1.
|
||||
@@ -86,11 +102,17 @@ public class ExitActorTest extends BaseTest {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Assert.assertEquals(
|
||||
2, (int) actor1.task(ExitingActor::getSizeOfActorContextMap).remote().get());
|
||||
Assert.assertEquals(
|
||||
2, (int) actor2.task(ExitingActor::getSizeOfActorContextMap).remote().get());
|
||||
ObjectRef<Boolean> obj1 = actor1.task(ExitingActor::exit).remote();
|
||||
Assert.assertThrows(RayActorException.class, obj1::get);
|
||||
Assert.assertTrue(SystemUtil.isProcessAlive(pid));
|
||||
// Actor 2 shouldn't exit or be reconstructed.
|
||||
Assert.assertEquals(1, (int) actor2.task(ExitingActor::incr).remote().get());
|
||||
Assert.assertEquals(
|
||||
1, (int) actor2.task(ExitingActor::getSizeOfActorContextMap).remote().get());
|
||||
Assert.assertEquals(pid, (int) actor2.task(ExitingActor::getPid).remote().get());
|
||||
Assert.assertTrue(SystemUtil.isProcessAlive(pid));
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ from ray.includes.unique_ids cimport (
|
||||
CTaskID,
|
||||
CObjectID,
|
||||
CPlacementGroupID,
|
||||
CWorkerID,
|
||||
)
|
||||
from ray.includes.common cimport (
|
||||
CAddress,
|
||||
@@ -227,6 +228,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
|
||||
const c_vector[CObjectID] &return_ids,
|
||||
c_vector[shared_ptr[CRayObject]] *returns) nogil
|
||||
) task_execution_callback
|
||||
(void(const CWorkerID &) nogil) on_worker_shutdown
|
||||
(CRayStatus() nogil) check_signals
|
||||
(void() nogil) gc_collect
|
||||
(c_vector[c_string](const c_vector[CObjectID]&) nogil) spill_objects
|
||||
|
||||
@@ -541,6 +541,9 @@ void CoreWorker::Shutdown() {
|
||||
if (options_.worker_type == WorkerType::WORKER) {
|
||||
task_execution_service_.stop();
|
||||
}
|
||||
if (options_.on_worker_shutdown) {
|
||||
options_.on_worker_shutdown(GetWorkerID());
|
||||
}
|
||||
}
|
||||
|
||||
void CoreWorker::Disconnect() {
|
||||
|
||||
@@ -123,6 +123,8 @@ struct CoreWorkerOptions {
|
||||
std::string stderr_file;
|
||||
/// Language worker callback to execute tasks.
|
||||
TaskExecutionCallback task_execution_callback;
|
||||
/// The callback to be called when shutting down a `CoreWorker` instance.
|
||||
std::function<void(const WorkerID &)> on_worker_shutdown;
|
||||
/// Application-language callback to check for signals that have been received
|
||||
/// since calling into C++. This will be called periodically (at least every
|
||||
/// 1s) during long-running operations. If the function returns anything but StatusOK,
|
||||
|
||||
@@ -150,6 +150,7 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
|
||||
if (throwable &&
|
||||
env->IsInstanceOf(throwable,
|
||||
java_ray_intentional_system_exit_exception_class)) {
|
||||
env->ExceptionClear();
|
||||
return ray::Status::IntentionalSystemExit();
|
||||
}
|
||||
RAY_CHECK_JAVA_EXCEPTION(env);
|
||||
@@ -211,6 +212,16 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
|
||||
}
|
||||
};
|
||||
|
||||
auto on_worker_shutdown = [](const ray::WorkerID &worker_id) {
|
||||
JNIEnv *env = GetJNIEnv();
|
||||
auto worker_id_bytes = IdToJavaByteArray<ray::WorkerID>(env, worker_id);
|
||||
if (java_task_executor) {
|
||||
env->CallVoidMethod(java_task_executor,
|
||||
java_native_task_executor_on_worker_shutdown, worker_id_bytes);
|
||||
RAY_CHECK_JAVA_EXCEPTION(env);
|
||||
}
|
||||
};
|
||||
|
||||
std::string serialized_job_config =
|
||||
(jobConfig == nullptr ? "" : JavaByteArrayToNativeString(env, jobConfig));
|
||||
ray::CoreWorkerOptions options;
|
||||
@@ -229,6 +240,7 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
|
||||
options.raylet_ip_address = JavaStringToNativeString(env, nodeIpAddress);
|
||||
options.driver_name = JavaStringToNativeString(env, driverName);
|
||||
options.task_execution_callback = task_execution_callback;
|
||||
options.on_worker_shutdown = on_worker_shutdown;
|
||||
options.gc_collect = gc_collect;
|
||||
options.ref_counting_enabled = true;
|
||||
options.num_workers = static_cast<int>(numWorkersPerProcess);
|
||||
|
||||
@@ -107,6 +107,9 @@ jclass java_task_executor_class;
|
||||
jmethodID java_task_executor_parse_function_arguments;
|
||||
jmethodID java_task_executor_execute;
|
||||
|
||||
jclass java_native_task_executor_class;
|
||||
jmethodID java_native_task_executor_on_worker_shutdown;
|
||||
|
||||
jclass java_placement_group_class;
|
||||
jfieldID java_placement_group_id;
|
||||
|
||||
@@ -267,6 +270,10 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) {
|
||||
java_task_executor_execute =
|
||||
env->GetMethodID(java_task_executor_class, "execute",
|
||||
"(Ljava/util/List;Ljava/util/List;)Ljava/util/List;");
|
||||
java_native_task_executor_class =
|
||||
LoadClass(env, "io/ray/runtime/task/NativeTaskExecutor");
|
||||
java_native_task_executor_on_worker_shutdown =
|
||||
env->GetMethodID(java_native_task_executor_class, "onWorkerShutdown", "([B)V");
|
||||
return CURRENT_JNI_VERSION;
|
||||
}
|
||||
|
||||
@@ -298,4 +305,5 @@ void JNI_OnUnload(JavaVM *vm, void *reserved) {
|
||||
env->DeleteGlobalRef(java_actor_creation_options_class);
|
||||
env->DeleteGlobalRef(java_native_ray_object_class);
|
||||
env->DeleteGlobalRef(java_task_executor_class);
|
||||
env->DeleteGlobalRef(java_native_task_executor_class);
|
||||
}
|
||||
|
||||
@@ -187,6 +187,11 @@ extern jmethodID java_task_executor_parse_function_arguments;
|
||||
/// execute method of TaskExecutor class
|
||||
extern jmethodID java_task_executor_execute;
|
||||
|
||||
/// NativeTaskExecutor class
|
||||
extern jclass java_native_task_executor_class;
|
||||
/// onWorkerShutdown method of NativeTaskExecutor class
|
||||
extern jmethodID java_native_task_executor_on_worker_shutdown;
|
||||
|
||||
/// PlacementGroup class
|
||||
extern jclass java_placement_group_class;
|
||||
/// id field of PlacementGroup class
|
||||
|
||||
Reference in New Issue
Block a user