mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 14:23:38 +08:00
[core] Ref counting for actor handles (#7434)
* tmp * Move Exit handler into CoreWorker, exit once owner's ref count goes to 0 * fix build * Remove __ray_terminate__ and add test case for distributed ref counting * lint * Remove unused * Fixes for detached actor, duplicate actor handles * Remove unused * Remove creation return ID * Remove ObjectIDs from python, set references in CoreWorker * Fix crash * Fix memory crash * Fix tests * fix * fixes * fix tests * fix java build * fix build * fix * check status * check status
This commit is contained in:
+20
-10
@@ -865,7 +865,7 @@ cdef class CoreWorker:
|
||||
|
||||
with nogil:
|
||||
check_status(self.core_worker.get().KillActor(
|
||||
c_actor_id))
|
||||
c_actor_id, True))
|
||||
|
||||
def resource_ids(self):
|
||||
cdef:
|
||||
@@ -894,15 +894,24 @@ cdef class CoreWorker:
|
||||
self.core_worker.get().CreateProfileEvent(event_type),
|
||||
extra_data)
|
||||
|
||||
def deserialize_and_register_actor_handle(self, const c_string &bytes):
|
||||
def remove_actor_handle_reference(self, ActorID actor_id):
|
||||
cdef:
|
||||
CActorID c_actor_id = actor_id.native()
|
||||
self.core_worker.get().RemoveActorHandleReference(c_actor_id)
|
||||
|
||||
def deserialize_and_register_actor_handle(self, const c_string &bytes,
|
||||
ObjectID
|
||||
outer_object_id):
|
||||
cdef:
|
||||
CActorHandle* c_actor_handle
|
||||
|
||||
CObjectID c_outer_object_id = (outer_object_id.native() if
|
||||
outer_object_id else
|
||||
CObjectID.Nil())
|
||||
worker = ray.worker.get_global_worker()
|
||||
worker.check_connected()
|
||||
manager = worker.function_actor_manager
|
||||
c_actor_id = self.core_worker.get().DeserializeAndRegisterActorHandle(
|
||||
bytes)
|
||||
bytes, c_outer_object_id)
|
||||
check_status(self.core_worker.get().GetActorHandle(
|
||||
c_actor_id, &c_actor_handle))
|
||||
actor_id = ActorID(c_actor_id.Binary())
|
||||
@@ -940,14 +949,13 @@ cdef class CoreWorker:
|
||||
actor_creation_function_descriptor,
|
||||
worker.current_session_and_job)
|
||||
|
||||
def serialize_actor_handle(self, actor_handle):
|
||||
assert isinstance(actor_handle, ray.actor.ActorHandle)
|
||||
def serialize_actor_handle(self, ActorID actor_id):
|
||||
cdef:
|
||||
ActorID actor_id = actor_handle._ray_actor_id
|
||||
c_string output
|
||||
CObjectID c_actor_handle_id
|
||||
check_status(self.core_worker.get().SerializeActorHandle(
|
||||
actor_id.native(), &output))
|
||||
return output
|
||||
actor_id.native(), &output, &c_actor_handle_id))
|
||||
return output, ObjectID(c_actor_handle_id.Binary())
|
||||
|
||||
def add_object_id_reference(self, ObjectID object_id):
|
||||
# Note: faster to not release GIL for short-running op.
|
||||
@@ -974,7 +982,9 @@ cdef class CoreWorker:
|
||||
const c_string &serialized_owner_address):
|
||||
cdef:
|
||||
CObjectID c_object_id = CObjectID.FromBinary(object_id_binary)
|
||||
CObjectID c_outer_object_id = outer_object_id.native()
|
||||
CObjectID c_outer_object_id = (outer_object_id.native() if
|
||||
outer_object_id else
|
||||
CObjectID.Nil())
|
||||
CTaskID c_owner_id = CTaskID.FromBinary(owner_id_binary)
|
||||
CAddress c_owner_address = CAddress()
|
||||
|
||||
|
||||
+21
-44
@@ -652,6 +652,14 @@ class ActorHandle:
|
||||
decorator=self._ray_method_decorators.get(method_name))
|
||||
setattr(self, method_name, method)
|
||||
|
||||
def __del__(self):
|
||||
# Mark that this actor handle has gone out of scope. Once all actor
|
||||
# handles are out of scope, the actor will exit.
|
||||
worker = ray.worker.get_global_worker()
|
||||
if worker.connected and hasattr(worker, "core_worker"):
|
||||
worker.core_worker.remove_actor_handle_reference(
|
||||
self._ray_actor_id)
|
||||
|
||||
def _actor_method_call(self,
|
||||
method_name,
|
||||
args=None,
|
||||
@@ -752,36 +760,6 @@ class ActorHandle:
|
||||
self._ray_actor_creation_function_descriptor.class_name,
|
||||
self._actor_id.hex())
|
||||
|
||||
def __del__(self):
|
||||
"""Terminate the worker that is running this actor."""
|
||||
# TODO(swang): Also clean up forked actor handles.
|
||||
# Kill the worker if this is the original actor handle, created
|
||||
# with Class.remote(). TODO(rkn): Even without passing handles around,
|
||||
# this is not the right policy. the actor should be alive as long as
|
||||
# there are ANY handles in scope in the process that created the actor,
|
||||
# not just the first one.
|
||||
worker = ray.worker.get_global_worker()
|
||||
exported_in_current_session_and_job = (
|
||||
self._ray_session_and_job == worker.current_session_and_job)
|
||||
if (worker.mode == ray.worker.SCRIPT_MODE
|
||||
and not exported_in_current_session_and_job):
|
||||
# If the worker is a driver and driver id has changed because
|
||||
# Ray was shut down re-initialized, the actor is already cleaned up
|
||||
# and we don't need to send `__ray_terminate__` again.
|
||||
logger.warning(
|
||||
"Actor is garbage collected in the wrong driver." +
|
||||
" Actor id = %s, class name = %s.", self._ray_actor_id,
|
||||
self._ray_actor_creation_function_descriptor.class_name)
|
||||
return
|
||||
if worker.connected and self._ray_original_handle:
|
||||
# Note: in py2 the weakref is destroyed prior to calling __del__
|
||||
# so we need to set the hardref here briefly
|
||||
try:
|
||||
self.__ray_terminate__._actor_hard_ref = self
|
||||
self.__ray_terminate__.remote()
|
||||
finally:
|
||||
self.__ray_terminate__._actor_hard_ref = None
|
||||
|
||||
def __ray_kill__(self):
|
||||
"""Deprecated - use ray.kill() instead."""
|
||||
logger.warning("actor.__ray_kill__() is deprecated and will be removed"
|
||||
@@ -792,13 +770,9 @@ class ActorHandle:
|
||||
def _actor_id(self):
|
||||
return self._ray_actor_id
|
||||
|
||||
def _serialization_helper(self, ray_forking):
|
||||
def _serialization_helper(self):
|
||||
"""This is defined in order to make pickling work.
|
||||
|
||||
Args:
|
||||
ray_forking: True if this is being called because Ray is forking
|
||||
the actor handle and false if it is being called by pickling.
|
||||
|
||||
Returns:
|
||||
A dictionary of the information needed to reconstruct the object.
|
||||
"""
|
||||
@@ -807,10 +781,11 @@ class ActorHandle:
|
||||
|
||||
if hasattr(worker, "core_worker"):
|
||||
# Non-local mode
|
||||
state = worker.core_worker.serialize_actor_handle(self)
|
||||
state = worker.core_worker.serialize_actor_handle(
|
||||
self._ray_actor_id)
|
||||
else:
|
||||
# Local mode
|
||||
state = {
|
||||
state = ({
|
||||
"actor_language": self._ray_actor_language,
|
||||
"actor_id": self._ray_actor_id,
|
||||
"method_decorators": self._ray_method_decorators,
|
||||
@@ -819,18 +794,20 @@ class ActorHandle:
|
||||
"actor_method_cpus": self._ray_actor_method_cpus,
|
||||
"actor_creation_function_descriptor": self.
|
||||
_ray_actor_creation_function_descriptor,
|
||||
}
|
||||
}, None)
|
||||
|
||||
return state
|
||||
|
||||
@classmethod
|
||||
def _deserialization_helper(cls, state, ray_forking):
|
||||
def _deserialization_helper(cls, state, outer_object_id=None):
|
||||
"""This is defined in order to make pickling work.
|
||||
|
||||
Args:
|
||||
state: The serialized state of the actor handle.
|
||||
ray_forking: True if this is being called because Ray is forking
|
||||
the actor handle and false if it is being called by pickling.
|
||||
outer_object_id: The ObjectID that the serialized actor handle was
|
||||
contained in, if any. This is used for counting references to
|
||||
the actor handle.
|
||||
|
||||
"""
|
||||
worker = ray.worker.get_global_worker()
|
||||
worker.check_connected()
|
||||
@@ -838,7 +815,7 @@ class ActorHandle:
|
||||
if hasattr(worker, "core_worker"):
|
||||
# Non-local mode
|
||||
return worker.core_worker.deserialize_and_register_actor_handle(
|
||||
state)
|
||||
state, outer_object_id)
|
||||
else:
|
||||
# Local mode
|
||||
return cls(
|
||||
@@ -855,8 +832,8 @@ class ActorHandle:
|
||||
|
||||
def __reduce__(self):
|
||||
"""This code path is used by pickling but not by Ray forking."""
|
||||
state = self._serialization_helper(False)
|
||||
return ActorHandle._deserialization_helper, (state, False)
|
||||
state = self._serialization_helper()
|
||||
return ActorHandle._deserialization_helper, (state)
|
||||
|
||||
|
||||
def modify_class(cls):
|
||||
|
||||
@@ -116,7 +116,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
|
||||
const CActorID &actor_id, const CRayFunction &function,
|
||||
const c_vector[CTaskArg] &args, const CTaskOptions &options,
|
||||
c_vector[CObjectID] *return_ids)
|
||||
CRayStatus KillActor(const CActorID &actor_id)
|
||||
CRayStatus KillActor(const CActorID &actor_id, c_bool force_kill)
|
||||
|
||||
unique_ptr[CProfileEvent] CreateProfileEvent(
|
||||
const c_string &event_type)
|
||||
@@ -134,9 +134,12 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
|
||||
void SetWebuiDisplay(const c_string &key, const c_string &message)
|
||||
CTaskID GetCallerId()
|
||||
const ResourceMappingType &GetResourceIDs() const
|
||||
CActorID DeserializeAndRegisterActorHandle(const c_string &bytes)
|
||||
void RemoveActorHandleReference(const CActorID &actor_id)
|
||||
CActorID DeserializeAndRegisterActorHandle(const c_string &bytes, const
|
||||
CObjectID &outer_object_id)
|
||||
CRayStatus SerializeActorHandle(const CActorID &actor_id, c_string
|
||||
*bytes)
|
||||
*bytes,
|
||||
CObjectID *c_actor_handle_id)
|
||||
CRayStatus GetActorHandle(const CActorID &actor_id,
|
||||
CActorHandle **actor_handle) const
|
||||
void AddLocalReference(const CObjectID &object_id)
|
||||
|
||||
+24
-15
@@ -133,11 +133,18 @@ class SerializationContext:
|
||||
self._thread_local = threading.local()
|
||||
|
||||
def actor_handle_serializer(obj):
|
||||
return obj._serialization_helper(True)
|
||||
serialized, actor_handle_id = obj._serialization_helper()
|
||||
# Update ref counting for the actor handle
|
||||
self.add_contained_object_id(actor_handle_id)
|
||||
return serialized
|
||||
|
||||
def actor_handle_deserializer(serialized_obj):
|
||||
# If this actor handle was stored in another object, then tell the
|
||||
# core worker.
|
||||
context = ray.worker.global_worker.get_serialization_context()
|
||||
outer_id = context.get_outer_object_id()
|
||||
return ray.actor.ActorHandle._deserialization_helper(
|
||||
serialized_obj, True)
|
||||
serialized_obj, outer_id)
|
||||
|
||||
self._register_cloudpickle_serializer(
|
||||
ray.actor.ActorHandle,
|
||||
@@ -151,15 +158,7 @@ class SerializationContext:
|
||||
return serialized_obj[0](*serialized_obj[1])
|
||||
|
||||
def object_id_serializer(obj):
|
||||
if self.is_in_band_serialization():
|
||||
self.add_contained_object_id(obj)
|
||||
else:
|
||||
# If this serialization is out-of-band (e.g., from a call to
|
||||
# cloudpickle directly or captured in a remote function/actor),
|
||||
# then pin the object for the lifetime of this worker by adding
|
||||
# a local reference that won't ever be removed.
|
||||
ray.worker.get_global_worker(
|
||||
).core_worker.add_object_id_reference(obj)
|
||||
self.add_contained_object_id(obj)
|
||||
owner_id = ""
|
||||
owner_address = ""
|
||||
# TODO(swang): Remove this check. Otherwise, we will not be able to
|
||||
@@ -239,10 +238,20 @@ class SerializationContext:
|
||||
return object_ids
|
||||
|
||||
def add_contained_object_id(self, object_id):
|
||||
if not hasattr(self._thread_local, "object_ids"):
|
||||
self._thread_local.object_ids = set()
|
||||
|
||||
self._thread_local.object_ids.add(object_id)
|
||||
if self.is_in_band_serialization():
|
||||
# This object ID is being stored in an object. Add the ID to the
|
||||
# list of IDs contained in the object so that we keep the inner
|
||||
# object value alive as long as the outer object is in scope.
|
||||
if not hasattr(self._thread_local, "object_ids"):
|
||||
self._thread_local.object_ids = set()
|
||||
self._thread_local.object_ids.add(object_id)
|
||||
else:
|
||||
# If this serialization is out-of-band (e.g., from a call to
|
||||
# cloudpickle directly or captured in a remote function/actor),
|
||||
# then pin the object for the lifetime of this worker by adding
|
||||
# a local reference that won't ever be removed.
|
||||
ray.worker.get_global_worker().core_worker.add_object_id_reference(
|
||||
object_id)
|
||||
|
||||
def _deserialize_pickle5_data(self, data):
|
||||
try:
|
||||
|
||||
@@ -106,6 +106,7 @@ def test_actor_method_metadata_cache(ray_start_regular):
|
||||
|
||||
# The cache of ActorClassMethodMetadata.
|
||||
cache = ray.actor.ActorClassMethodMetadata._cache
|
||||
cache.clear()
|
||||
|
||||
# Check cache hit during ActorHandle deserialization.
|
||||
A1 = ray.remote(Actor)
|
||||
@@ -532,6 +533,34 @@ def test_actor_method_deletion(ray_start_regular):
|
||||
assert ray.get(Actor.remote().method.remote()) == 1
|
||||
|
||||
|
||||
def test_distributed_actor_handle_deletion(ray_start_regular):
|
||||
@ray.remote
|
||||
class Actor:
|
||||
def method(self):
|
||||
return 1
|
||||
|
||||
def getpid(self):
|
||||
return os.getpid()
|
||||
|
||||
@ray.remote
|
||||
def f(actor, signal):
|
||||
ray.get(signal.wait.remote())
|
||||
return ray.get(actor.method.remote())
|
||||
|
||||
signal = ray.test_utils.SignalActor.remote()
|
||||
a = Actor.remote()
|
||||
pid = ray.get(a.getpid.remote())
|
||||
# Pass the handle to another task that cannot run yet.
|
||||
x_id = f.remote(a, signal)
|
||||
# Delete the original handle. The actor should not get killed yet.
|
||||
del a
|
||||
|
||||
# Once the task finishes, the actor process should get killed.
|
||||
ray.get(signal.send.remote())
|
||||
assert ray.get(x_id) == 1
|
||||
ray.test_utils.wait_for_pid_to_exit(pid)
|
||||
|
||||
|
||||
def test_multiple_actors(ray_start_regular):
|
||||
@ray.remote
|
||||
class Counter:
|
||||
|
||||
@@ -202,7 +202,7 @@ def test_raylet_info_endpoint(shutdown_only):
|
||||
try:
|
||||
assert len(actor_info) == 1
|
||||
_, parent_actor_info = actor_info.popitem()
|
||||
assert parent_actor_info["numObjectIdsInScope"] == 11
|
||||
assert parent_actor_info["numObjectIdsInScope"] == 13
|
||||
assert parent_actor_info["numLocalObjects"] == 10
|
||||
children = parent_actor_info["children"]
|
||||
assert len(children) == 2
|
||||
|
||||
Reference in New Issue
Block a user