[core] Ref counting for actor handles (#7434)

* tmp

* Move Exit handler into CoreWorker, exit once owner's ref count goes to 0

* fix build

* Remove __ray_terminate__ and add test case for distributed ref counting

* lint

* Remove unused

* Fixes for detached actor, duplicate actor handles

* Remove unused

* Remove creation return ID

* Remove ObjectIDs from python, set references in CoreWorker

* Fix crash

* Fix memory crash

* Fix tests

* fix

* fixes

* fix tests

* fix java build

* fix build

* fix

* check status

* check status
This commit is contained in:
Stephanie Wang
2020-03-10 17:45:07 -07:00
committed by GitHub
parent 119a303ea0
commit fdb528514b
23 changed files with 330 additions and 180 deletions
+20 -10
View File
@@ -865,7 +865,7 @@ cdef class CoreWorker:
with nogil:
check_status(self.core_worker.get().KillActor(
c_actor_id))
c_actor_id, True))
def resource_ids(self):
cdef:
@@ -894,15 +894,24 @@ cdef class CoreWorker:
self.core_worker.get().CreateProfileEvent(event_type),
extra_data)
def deserialize_and_register_actor_handle(self, const c_string &bytes):
def remove_actor_handle_reference(self, ActorID actor_id):
# Tell the core worker that a handle to `actor_id` has gone out of scope.
# ActorHandle.__del__ calls this when a Python actor handle is collected.
cdef:
CActorID c_actor_id = actor_id.native()
# RemoveActorHandleReference is declared as returning void, so there is no
# status to pass to check_status() here.
self.core_worker.get().RemoveActorHandleReference(c_actor_id)
def deserialize_and_register_actor_handle(self, const c_string &bytes,
ObjectID
outer_object_id):
cdef:
CActorHandle* c_actor_handle
CObjectID c_outer_object_id = (outer_object_id.native() if
outer_object_id else
CObjectID.Nil())
worker = ray.worker.get_global_worker()
worker.check_connected()
manager = worker.function_actor_manager
c_actor_id = self.core_worker.get().DeserializeAndRegisterActorHandle(
bytes)
bytes, c_outer_object_id)
check_status(self.core_worker.get().GetActorHandle(
c_actor_id, &c_actor_handle))
actor_id = ActorID(c_actor_id.Binary())
@@ -940,14 +949,13 @@ cdef class CoreWorker:
actor_creation_function_descriptor,
worker.current_session_and_job)
def serialize_actor_handle(self, actor_handle):
assert isinstance(actor_handle, ray.actor.ActorHandle)
def serialize_actor_handle(self, ActorID actor_id):
cdef:
ActorID actor_id = actor_handle._ray_actor_id
c_string output
CObjectID c_actor_handle_id
check_status(self.core_worker.get().SerializeActorHandle(
actor_id.native(), &output))
return output
actor_id.native(), &output, &c_actor_handle_id))
return output, ObjectID(c_actor_handle_id.Binary())
def add_object_id_reference(self, ObjectID object_id):
# Note: faster to not release GIL for short-running op.
@@ -974,7 +982,9 @@ cdef class CoreWorker:
const c_string &serialized_owner_address):
cdef:
CObjectID c_object_id = CObjectID.FromBinary(object_id_binary)
CObjectID c_outer_object_id = outer_object_id.native()
CObjectID c_outer_object_id = (outer_object_id.native() if
outer_object_id else
CObjectID.Nil())
CTaskID c_owner_id = CTaskID.FromBinary(owner_id_binary)
CAddress c_owner_address = CAddress()
+21 -44
View File
@@ -652,6 +652,14 @@ class ActorHandle:
decorator=self._ray_method_decorators.get(method_name))
setattr(self, method_name, method)
def __del__(self):
    """Report to the core worker that this actor handle is out of scope.

    Once all actor handles are out of scope, the actor will exit.
    """
    worker = ray.worker.get_global_worker()
    # Nothing to release if the worker is disconnected or has no core
    # worker (the latter is the local-mode case).
    if not worker.connected or not hasattr(worker, "core_worker"):
        return
    worker.core_worker.remove_actor_handle_reference(self._ray_actor_id)
def _actor_method_call(self,
method_name,
args=None,
@@ -752,36 +760,6 @@ class ActorHandle:
self._ray_actor_creation_function_descriptor.class_name,
self._actor_id.hex())
def __del__(self):
"""Terminate the worker that is running this actor."""
# TODO(swang): Also clean up forked actor handles.
# Kill the worker if this is the original actor handle, created
# with Class.remote(). TODO(rkn): Even without passing handles around,
# this is not the right policy. The actor should be alive as long as
# there are ANY handles in scope in the process that created the actor,
# not just the first one.
worker = ray.worker.get_global_worker()
exported_in_current_session_and_job = (
self._ray_session_and_job == worker.current_session_and_job)
if (worker.mode == ray.worker.SCRIPT_MODE
and not exported_in_current_session_and_job):
# If the worker is a driver and the driver id has changed because
# Ray was shut down and re-initialized, the actor is already cleaned
# up and we don't need to send `__ray_terminate__` again.
logger.warning(
"Actor is garbage collected in the wrong driver." +
" Actor id = %s, class name = %s.", self._ray_actor_id,
self._ray_actor_creation_function_descriptor.class_name)
return
if worker.connected and self._ray_original_handle:
# Note: in py2 the weakref is destroyed prior to calling __del__,
# so we need to set the hard ref here briefly. The `finally` clause
# clears it again whether or not the remote call raises.
try:
self.__ray_terminate__._actor_hard_ref = self
self.__ray_terminate__.remote()
finally:
self.__ray_terminate__._actor_hard_ref = None
def __ray_kill__(self):
"""Deprecated - use ray.kill() instead."""
logger.warning("actor.__ray_kill__() is deprecated and will be removed"
@@ -792,13 +770,9 @@ class ActorHandle:
def _actor_id(self):
return self._ray_actor_id
def _serialization_helper(self, ray_forking):
def _serialization_helper(self):
"""This is defined in order to make pickling work.
Args:
ray_forking: True if this is being called because Ray is forking
the actor handle and false if it is being called by pickling.
Returns:
A dictionary of the information needed to reconstruct the object.
"""
@@ -807,10 +781,11 @@ class ActorHandle:
if hasattr(worker, "core_worker"):
# Non-local mode
state = worker.core_worker.serialize_actor_handle(self)
state = worker.core_worker.serialize_actor_handle(
self._ray_actor_id)
else:
# Local mode
state = {
state = ({
"actor_language": self._ray_actor_language,
"actor_id": self._ray_actor_id,
"method_decorators": self._ray_method_decorators,
@@ -819,18 +794,20 @@ class ActorHandle:
"actor_method_cpus": self._ray_actor_method_cpus,
"actor_creation_function_descriptor": self.
_ray_actor_creation_function_descriptor,
}
}, None)
return state
@classmethod
def _deserialization_helper(cls, state, ray_forking):
def _deserialization_helper(cls, state, outer_object_id=None):
"""This is defined in order to make pickling work.
Args:
state: The serialized state of the actor handle.
ray_forking: True if this is being called because Ray is forking
the actor handle and false if it is being called by pickling.
outer_object_id: The ObjectID that the serialized actor handle was
contained in, if any. This is used for counting references to
the actor handle.
"""
worker = ray.worker.get_global_worker()
worker.check_connected()
@@ -838,7 +815,7 @@ class ActorHandle:
if hasattr(worker, "core_worker"):
# Non-local mode
return worker.core_worker.deserialize_and_register_actor_handle(
state)
state, outer_object_id)
else:
# Local mode
return cls(
@@ -855,8 +832,8 @@ class ActorHandle:
def __reduce__(self):
"""This code path is used by pickling but not by Ray forking."""
state = self._serialization_helper(False)
return ActorHandle._deserialization_helper, (state, False)
state = self._serialization_helper()
return ActorHandle._deserialization_helper, (state)
def modify_class(cls):
+6 -3
View File
@@ -116,7 +116,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
const CActorID &actor_id, const CRayFunction &function,
const c_vector[CTaskArg] &args, const CTaskOptions &options,
c_vector[CObjectID] *return_ids)
CRayStatus KillActor(const CActorID &actor_id)
CRayStatus KillActor(const CActorID &actor_id, c_bool force_kill)
unique_ptr[CProfileEvent] CreateProfileEvent(
const c_string &event_type)
@@ -134,9 +134,12 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
void SetWebuiDisplay(const c_string &key, const c_string &message)
CTaskID GetCallerId()
const ResourceMappingType &GetResourceIDs() const
CActorID DeserializeAndRegisterActorHandle(const c_string &bytes)
void RemoveActorHandleReference(const CActorID &actor_id)
CActorID DeserializeAndRegisterActorHandle(const c_string &bytes, const
CObjectID &outer_object_id)
CRayStatus SerializeActorHandle(const CActorID &actor_id, c_string
*bytes)
*bytes,
CObjectID *c_actor_handle_id)
CRayStatus GetActorHandle(const CActorID &actor_id,
CActorHandle **actor_handle) const
void AddLocalReference(const CObjectID &object_id)
+24 -15
View File
@@ -133,11 +133,18 @@ class SerializationContext:
self._thread_local = threading.local()
def actor_handle_serializer(obj):
return obj._serialization_helper(True)
serialized, actor_handle_id = obj._serialization_helper()
# Update ref counting for the actor handle
self.add_contained_object_id(actor_handle_id)
return serialized
def actor_handle_deserializer(serialized_obj):
# If this actor handle was stored in another object, then tell the
# core worker.
context = ray.worker.global_worker.get_serialization_context()
outer_id = context.get_outer_object_id()
return ray.actor.ActorHandle._deserialization_helper(
serialized_obj, True)
serialized_obj, outer_id)
self._register_cloudpickle_serializer(
ray.actor.ActorHandle,
@@ -151,15 +158,7 @@ class SerializationContext:
return serialized_obj[0](*serialized_obj[1])
def object_id_serializer(obj):
if self.is_in_band_serialization():
self.add_contained_object_id(obj)
else:
# If this serialization is out-of-band (e.g., from a call to
# cloudpickle directly or captured in a remote function/actor),
# then pin the object for the lifetime of this worker by adding
# a local reference that won't ever be removed.
ray.worker.get_global_worker(
).core_worker.add_object_id_reference(obj)
self.add_contained_object_id(obj)
owner_id = ""
owner_address = ""
# TODO(swang): Remove this check. Otherwise, we will not be able to
@@ -239,10 +238,20 @@ class SerializationContext:
return object_ids
def add_contained_object_id(self, object_id):
if not hasattr(self._thread_local, "object_ids"):
self._thread_local.object_ids = set()
self._thread_local.object_ids.add(object_id)
if self.is_in_band_serialization():
# This object ID is being stored in an object. Add the ID to the
# list of IDs contained in the object so that we keep the inner
# object value alive as long as the outer object is in scope.
if not hasattr(self._thread_local, "object_ids"):
self._thread_local.object_ids = set()
self._thread_local.object_ids.add(object_id)
else:
# If this serialization is out-of-band (e.g., from a call to
# cloudpickle directly or captured in a remote function/actor),
# then pin the object for the lifetime of this worker by adding
# a local reference that won't ever be removed.
ray.worker.get_global_worker().core_worker.add_object_id_reference(
object_id)
def _deserialize_pickle5_data(self, data):
try:
+29
View File
@@ -106,6 +106,7 @@ def test_actor_method_metadata_cache(ray_start_regular):
# The cache of ActorClassMethodMetadata.
cache = ray.actor.ActorClassMethodMetadata._cache
cache.clear()
# Check cache hit during ActorHandle deserialization.
A1 = ray.remote(Actor)
@@ -532,6 +533,34 @@ def test_actor_method_deletion(ray_start_regular):
assert ray.get(Actor.remote().method.remote()) == 1
def test_distributed_actor_handle_deletion(ray_start_regular):
    """An actor stays alive while a pending task still holds its handle."""

    @ray.remote
    class Actor:
        def method(self):
            return 1

        def getpid(self):
            return os.getpid()

    @ray.remote
    def f(actor, signal):
        # Block until the driver releases the signal, then call the actor.
        ray.get(signal.wait.remote())
        result = ray.get(actor.method.remote())
        return result

    signal = ray.test_utils.SignalActor.remote()
    handle = Actor.remote()
    actor_pid = ray.get(handle.getpid.remote())

    # Hand the actor handle to a task that is blocked on the signal.
    result_id = f.remote(handle, signal)

    # Drop the original handle. The actor should not get killed yet.
    del handle

    # Unblock the task; once it finishes, the actor process should get
    # killed.
    ray.get(signal.send.remote())
    assert ray.get(result_id) == 1
    ray.test_utils.wait_for_pid_to_exit(actor_pid)
def test_multiple_actors(ray_start_regular):
@ray.remote
class Counter:
+1 -1
View File
@@ -202,7 +202,7 @@ def test_raylet_info_endpoint(shutdown_only):
try:
assert len(actor_info) == 1
_, parent_actor_info = actor_info.popitem()
assert parent_actor_info["numObjectIdsInScope"] == 11
assert parent_actor_info["numObjectIdsInScope"] == 13
assert parent_actor_info["numLocalObjects"] == 10
children = parent_actor_info["children"]
assert len(children) == 2