diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index a5554d239..44ccfe8dd 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -278,6 +278,8 @@ cdef void prepare_args( size_t size int64_t put_threshold shared_ptr[CBuffer] arg_data + c_vector[CObjectID] inlined_ids + ObjectID obj_id worker = ray.worker.global_worker put_threshold = RayConfig.instance().max_direct_call_object_size() @@ -294,14 +296,17 @@ cdef void prepare_args( # plasma here. This is inefficient for small objects, but inlined # arguments aren't associated ObjectIDs right now so this is a # simple fix for reference counting purposes. - if (size <= put_threshold and - len(serialized_arg.contained_object_ids) == 0): + if size <= put_threshold: arg_data = dynamic_pointer_cast[CBuffer, LocalMemoryBuffer]( make_shared[LocalMemoryBuffer](size)) write_serialized_object(serialized_arg, arg_data) + for obj_id in serialized_arg.contained_object_ids: + inlined_ids.push_back(obj_id.native()) args_vector.push_back( CTaskArg.PassByValue(make_shared[CRayObject]( - arg_data, string_to_buffer(serialized_arg.metadata)))) + arg_data, string_to_buffer(serialized_arg.metadata), + inlined_ids))) + inlined_ids.clear() else: args_vector.push_back( CTaskArg.PassByReference((CObjectID.FromBinary( @@ -664,7 +669,7 @@ cdef class CoreWorker: c_object_id[0] = object_id.native() with nogil: check_status(self.core_worker.get().Create( - metadata, data_size, contained_ids, + metadata, data_size, c_object_id[0], data)) break except ObjectStoreFullError as e: @@ -979,15 +984,18 @@ cdef class CoreWorker: c_owner_address.SerializeAsString()) def deserialize_and_register_object_id( - self, const c_string &object_id_binary, const c_string - &owner_id_binary, const c_string &serialized_owner_address): + self, const c_string &object_id_binary, ObjectID outer_object_id, + const c_string &owner_id_binary, + const c_string &serialized_owner_address): cdef: CObjectID c_object_id = CObjectID.FromBinary(object_id_binary) + CObjectID c_outer_object_id = outer_object_id.native() CTaskID c_owner_id = CTaskID.FromBinary(owner_id_binary) CAddress c_owner_address = CAddress() c_owner_address.ParseFromString(serialized_owner_address) self.core_worker.get().RegisterOwnershipInfoAndResolveFuture( c_object_id, + c_outer_object_id, c_owner_id, c_owner_address) diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index 5cd5b134e..26ceb9396 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -142,8 +142,10 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: CTaskID *owner_id, CAddress *owner_address) void RegisterOwnershipInfoAndResolveFuture( - const CObjectID &object_id, const CTaskID &owner_id, const - CAddress &owner_address) + const CObjectID &object_id, + const CObjectID &outer_object_id, + const CTaskID &owner_id, + const CAddress &owner_address) void AddContainedObjectIDs( const CObjectID &object_id, const c_vector[CObjectID] &contained_object_ids) @@ -161,7 +163,6 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: CObjectID *object_id, shared_ptr[CBuffer] *data) CRayStatus Create(const shared_ptr[CBuffer] &metadata, const size_t data_size, - const c_vector[CObjectID] &contained_object_ids, const CObjectID &object_id, shared_ptr[CBuffer] *data) CRayStatus Seal(const CObjectID &object_id, c_bool pin_object) diff --git a/python/ray/serialization.py b/python/ray/serialization.py index b11b63b64..b6660f3da 100644 --- a/python/ray/serialization.py +++ b/python/ray/serialization.py @@ -156,6 +156,8 @@ class SerializationContext: self.add_contained_object_id(obj) owner_id = "" owner_address = "" + # TODO(swang): Remove this check. Otherwise, we will not be able to + # handle serialized plasma IDs correctly. if obj.is_direct_call_type(): worker = ray.worker.get_global_worker() worker.check_connected() @@ -176,14 +178,14 @@ class SerializationContext: # to 'self' here instead, but this function is itself pickled # somewhere, which causes an error. context = ray.worker.global_worker.get_serialization_context() - context.add_contained_object_id(deserialized_object_id) if owner_id: worker = ray.worker.get_global_worker() worker.check_connected() # UniqueIDs are serialized as # (class name, (unique bytes,)). + outer_id = context.get_outer_object_id() worker.core_worker.deserialize_and_register_object_id( - obj_id[1][0], owner_id[1][0], owner_address) + obj_id[1][0], outer_id, owner_id[1][0], owner_address) return deserialized_object_id for id_type in ray._raylet._ID_TYPES: @@ -204,6 +206,12 @@ class SerializationContext: # construct a reducer pickle.CloudPickler.dispatch[cls] = _CloudPicklerReducer + def set_outer_object_id(self, outer_object_id): + self._thread_local.outer_object_id = outer_object_id + + def get_outer_object_id(self): + return getattr(self._thread_local, "outer_object_id", None) + def get_and_clear_contained_object_ids(self): if not hasattr(self._thread_local, "object_ids"): self._thread_local.object_ids = set() @@ -235,18 +243,8 @@ class SerializationContext: # cloudpickle does not provide error types except pickle.pickle.PicklingError: raise DeserializationError() - - # Check that there are no ObjectIDs serialized in arguments - # that are inlined. - if object_id.is_nil(): - assert len(self.get_and_clear_contained_object_ids()) == 0 - else: - worker = ray.worker.global_worker - worker.core_worker.add_contained_object_ids( - object_id, - self.get_and_clear_contained_object_ids(), - ) return obj + # Check if the object should be returned as raw bytes. if metadata == ray_constants.RAW_BUFFER_METADATA: if data is None: @@ -287,6 +285,8 @@ class SerializationContext: while i < len(object_ids): object_id = object_ids[i] data, metadata = data_metadata_pairs[i] + assert self.get_outer_object_id() is None + self.set_outer_object_id(object_id) try: results.append( self._deserialize_object(data, metadata, object_id)) @@ -310,6 +310,9 @@ class SerializationContext: warning_message, job_id=self.worker.current_job_id) warning_sent = True + finally: + # Must clear ObjectID to not hold a reference. + self.set_outer_object_id(None) return results @@ -328,6 +331,7 @@ class SerializationContext: assert self.worker.use_pickle assert ray.cloudpickle.FAST_CLOUDPICKLE_USED writer = Pickle5Writer() + # TODO(swang): Check that contained_object_ids is empty. inband = pickle.dumps( value, protocol=5, buffer_callback=writer.buffer_callback) return Pickle5SerializedObject( diff --git a/python/ray/tests/test_reference_counting.py b/python/ray/tests/test_reference_counting.py index c815a9458..6331a00e8 100644 --- a/python/ray/tests/test_reference_counting.py +++ b/python/ray/tests/test_reference_counting.py @@ -8,6 +8,7 @@ import time import pytest import logging import uuid +import gc import ray import ray.cluster_utils @@ -18,7 +19,13 @@ logger = logging.getLogger(__name__) @pytest.fixture def one_worker_100MiB(request): - yield ray.init(num_cpus=1, object_store_memory=100 * 1024 * 1024) + config = json.dumps({ + "distributed_ref_counting_enabled": 1, + }) + yield ray.init( + num_cpus=1, + object_store_memory=100 * 1024 * 1024, + _internal_config=config) ray.shutdown() @@ -266,7 +273,6 @@ def test_feature_flag(shutdown_only): # Remote function takes serialized reference and doesn't hold onto it after # finishing. Referenced object shouldn't be evicted while the task is pending # and should be evicted after it returns. -@pytest.mark.skip("Serialized ObjectID reference counting not implemented.") def test_basic_serialized_reference(one_worker_100MiB): @ray.remote def pending(ref, dep): @@ -286,6 +292,9 @@ def test_basic_serialized_reference(one_worker_100MiB): # Remove the local reference. array_oid_bytes = array_oid.binary() del array_oid + # Needed due to Python GC issue in cloudpickle. + # https://github.com/cloudpipe/cloudpickle/issues/343 + gc.collect() # Check that the remote reference pins the object. _fill_object_store_and_get(array_oid_bytes) @@ -301,7 +310,8 @@ def test_basic_serialized_reference(one_worker_100MiB): # Call a recursive chain of tasks that pass a serialized reference to the end # of the chain. The reference should still exist while the final task in the # chain is running and should be removed once it finishes. -@pytest.mark.skip("Serialized ObjectID reference counting not implemented.") +@pytest.mark.skip("Memory not freed due to Python GC issue in cloudpickle " + "(https://github.com/cloudpipe/cloudpickle/issues/343).") def test_recursive_serialized_reference(one_worker_100MiB): @ray.remote def recursive(ref, dep, max_depth, depth=0): @@ -325,7 +335,7 @@ def test_recursive_serialized_reference(one_worker_100MiB): del array_oid tail_oid = head_oid - for _ in range(max_depth - 1): + for _ in range(max_depth): tail_oid = ray.get(tail_oid) # Check that the remote reference pins the object. @@ -333,7 +343,7 @@ def test_recursive_serialized_reference(one_worker_100MiB): # Fulfill the dependency, causing the tail task to finish. ray.worker.global_worker.put_object(None, object_id=random_oid) - ray.get(tail_oid) + assert ray.get(tail_oid) is None # Reference should be gone, check that array gets evicted. _fill_object_store_and_get(array_oid_bytes, succeed=False) @@ -342,7 +352,6 @@ def test_recursive_serialized_reference(one_worker_100MiB): # Test that a passed reference held by an actor after the method finishes # is kept until the reference is removed from the actor. Also tests giving # the actor a duplicate reference to the same object ID. -@pytest.mark.skip("Serialized ObjectID reference counting not implemented.") def test_actor_holding_serialized_reference(one_worker_100MiB): @ray.remote class GreedyActor(object): @@ -376,6 +385,9 @@ def test_actor_holding_serialized_reference(one_worker_100MiB): # Remove the local reference. array_oid_bytes = array_oid.binary() del array_oid + # Needed due to Python GC issue in cloudpickle. + # https://github.com/cloudpipe/cloudpickle/issues/343 + gc.collect() # Test that the remote references still pin the object. _fill_object_store_and_get(array_oid_bytes) @@ -392,7 +404,8 @@ def test_actor_holding_serialized_reference(one_worker_100MiB): # Test that a passed reference held by an actor after a task finishes # is kept until the reference is removed from the worker. Also tests giving # the worker a duplicate reference to the same object ID. -@pytest.mark.skip("Serialized ObjectID reference counting not implemented.") +@pytest.mark.skip("Memory not freed due to Python GC issue in cloudpickle " + "(https://github.com/cloudpipe/cloudpickle/issues/343).") def test_worker_holding_serialized_reference(one_worker_100MiB): @ray.remote def child(dep1, dep2): @@ -428,7 +441,6 @@ def test_worker_holding_serialized_reference(one_worker_100MiB): # Test that an object containing object IDs within it pins the inner IDs. -@pytest.mark.skip("Serialized ObjectID reference counting not implemented.") def test_basic_nested_ids(one_worker_100MiB): inner_oid = ray.put(np.zeros(40 * 1024 * 1024, dtype=np.uint8)) outer_oid = ray.put([inner_oid]) @@ -436,6 +448,9 @@ def test_basic_nested_ids(one_worker_100MiB): # Remove the local reference to the inner object. inner_oid_bytes = inner_oid.binary() del inner_oid + # Needed due to Python GC issue in cloudpickle. + # https://github.com/cloudpipe/cloudpickle/issues/343 + gc.collect() # Check that the outer reference pins the inner object. _fill_object_store_and_get(inner_oid_bytes) @@ -447,7 +462,8 @@ def test_basic_nested_ids(one_worker_100MiB): # Test that an object containing object IDs within it pins the inner IDs # recursively and for submitted tasks. -@pytest.mark.skip("Serialized ObjectID reference counting not implemented.") +@pytest.mark.skip("Memory not freed due to Python GC issue in cloudpickle " + "(https://github.com/cloudpipe/cloudpickle/issues/343).") def test_recursively_nest_ids(one_worker_100MiB): @ray.remote def recursive(ref, dep, max_depth, depth=0): @@ -474,7 +490,7 @@ def test_recursively_nest_ids(one_worker_100MiB): del array_oid, nested_oid tail_oid = head_oid - for _ in range(max_depth - 1): + for _ in range(max_depth): tail_oid = ray.get(tail_oid) # Check that the remote reference pins the object. @@ -490,7 +506,8 @@ def test_recursively_nest_ids(one_worker_100MiB): # Test that serialized objectIDs returned from remote tasks are pinned until # they go out of scope on the caller side. -@pytest.mark.skip("Serialized ObjectID reference counting not implemented.") +@pytest.mark.skip("Memory not freed due to Python GC issue in cloudpickle " + "(https://github.com/cloudpipe/cloudpickle/issues/343).") def test_return_object_id(one_worker_100MiB): @ray.remote def put(): @@ -519,7 +536,8 @@ def test_return_object_id(one_worker_100MiB): # Test that serialized objectIDs returned from remote tasks are pinned if # passed into another remote task by the caller. -@pytest.mark.skip("Serialized ObjectID reference counting not implemented.") +@pytest.mark.skip("Memory not freed due to Python GC issue in cloudpickle " + "(https://github.com/cloudpipe/cloudpickle/issues/343).") def test_pass_returned_object_id(one_worker_100MiB): @ray.remote def put(): @@ -555,7 +573,8 @@ def test_pass_returned_object_id(one_worker_100MiB): # returned by another task to the end of the chain. The reference should still # exist while the final task in the chain is running and should be removed once # it finishes. -@pytest.mark.skip("Serialized ObjectID reference counting not implemented.") +@pytest.mark.skip("Memory not freed due to Python GC issue in cloudpickle " + "(https://github.com/cloudpipe/cloudpickle/issues/343).") def test_recursively_pass_returned_object_id(one_worker_100MiB): @ray.remote def put(): @@ -583,7 +602,7 @@ def test_recursively_pass_returned_object_id(one_worker_100MiB): del outer_oid tail_oid = head_oid - for _ in range(max_depth - 1): + for _ in range(max_depth): tail_oid = ray.get(tail_oid) # Check that the remote reference pins the object. diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index 26b3a05c9..eac5fcf9a 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -44,6 +44,22 @@ RAY_CONFIG(bool, fair_queueing_enabled, true) /// enabled, objects in scope in the cluster will not be LRU evicted. RAY_CONFIG(bool, object_pinning_enabled, true) +/// Whether to enable distributed reference counting for objects. When this is +/// enabled, an object's ref count will include any references held by other +/// processes, such as when an ObjectID is serialized and passed as an argument +/// to another task. It will also include any references due to nesting, i.e. +/// if the object ID is nested inside another object that is still in scope. +/// When this is disabled, an object's ref count will include only local +/// information: +/// 1. Local Python references to the ObjectID. +/// 2. Pending tasks submitted by the local process that depend on the object. +/// If both this flag and object_pinning_enabled are turned on, then an object +/// will not be LRU evicted until it is out of scope in ALL processes in the +/// cluster and all objects that contain it are also out of scope. If this flag +/// is off and object_pinning_enabled is turned on, then an object will not be +/// LRU evicted until it is out of scope on the CREATOR of the ObjectID. +RAY_CONFIG(bool, distributed_ref_counting_enabled, false) + /// Whether to enable the new scheduler. The new scheduler is designed /// only to work with direct calls. Once direct calls afre becoming /// the default, this scheduler will also become the default. diff --git a/src/ray/common/ray_object.cc b/src/ray/common/ray_object.cc index ecc652b8f..142c81710 100644 --- a/src/ray/common/ray_object.cc +++ b/src/ray/common/ray_object.cc @@ -11,7 +11,7 @@ std::shared_ptr MakeErrorMetadataBuffer(rpc::ErrorType error_ } RayObject::RayObject(rpc::ErrorType error_type) - : RayObject(nullptr, MakeErrorMetadataBuffer(error_type)) {} + : RayObject(nullptr, MakeErrorMetadataBuffer(error_type), {}) {} bool RayObject::IsException(rpc::ErrorType *error_type) const { if (metadata_ == nullptr) { diff --git a/src/ray/common/ray_object.h b/src/ray/common/ray_object.h index 4f8a5e003..395d369ed 100644 --- a/src/ray/common/ray_object.h +++ b/src/ray/common/ray_object.h @@ -2,6 +2,7 @@ #define RAY_COMMON_RAY_OBJECT_H #include "ray/common/buffer.h" +#include "ray/common/id.h" #include "ray/protobuf/gcs.pb.h" #include "ray/util/logging.h" @@ -21,10 +22,14 @@ class RayObject { /// /// \param[in] data Data of the ray object. /// \param[in] metadata Metadata of the ray object. + /// \param[in] inlined_ids ObjectIDs that were serialized in data. /// \param[in] copy_data Whether this class should hold a copy of data. RayObject(const std::shared_ptr &data, const std::shared_ptr &metadata, - bool copy_data = false) - : data_(data), metadata_(metadata), has_data_copy_(copy_data) { + const std::vector &inlined_ids, bool copy_data = false) + : data_(data), + metadata_(metadata), + inlined_ids_(inlined_ids), + has_data_copy_(copy_data) { if (has_data_copy_) { // If this object is required to hold a copy of the data, // make a copy if the passed in buffers don't already have a copy. @@ -45,10 +50,13 @@ class RayObject { RayObject(rpc::ErrorType error_type); /// Return the data of the ray object. - const std::shared_ptr &GetData() const { return data_; }; + const std::shared_ptr &GetData() const { return data_; } /// Return the metadata of the ray object. - const std::shared_ptr &GetMetadata() const { return metadata_; }; + const std::shared_ptr &GetMetadata() const { return metadata_; } + + /// Return the object IDs that were serialized in data. + const std::vector &GetInlinedIds() const { return inlined_ids_; } uint64_t GetSize() const { uint64_t size = 0; @@ -73,6 +81,7 @@ class RayObject { private: std::shared_ptr data_; std::shared_ptr metadata_; + const std::vector inlined_ids_; /// Whether this class holds a data copy. bool has_data_copy_; }; diff --git a/src/ray/common/task/task_spec.cc b/src/ray/common/task/task_spec.cc index 8ad6f11ad..ab369b7ce 100644 --- a/src/ray/common/task/task_spec.cc +++ b/src/ray/common/task/task_spec.cc @@ -122,6 +122,10 @@ size_t TaskSpecification::ArgMetadataSize(size_t arg_index) const { return message_->args(arg_index).metadata().size(); } +const std::vector TaskSpecification::ArgInlinedIds(size_t arg_index) const { + return IdVectorFromProtobuf(message_->args(arg_index).nested_inlined_ids()); +} + const ResourceSet &TaskSpecification::GetRequiredResources() const { return *required_resources_; } diff --git a/src/ray/common/task/task_spec.h b/src/ray/common/task/task_spec.h index 9065255e4..23c229a4b 100644 --- a/src/ray/common/task/task_spec.h +++ b/src/ray/common/task/task_spec.h @@ -88,6 +88,9 @@ class TaskSpecification : public MessageWrapper { size_t ArgMetadataSize(size_t arg_index) const; + /// Return the ObjectIDs that were inlined in this task argument. + const std::vector ArgInlinedIds(size_t arg_index) const; + /// Return the scheduling class of the task. The scheduler makes a best effort /// attempt to fairly dispatch tasks of different classes, preventing /// starvation of any single class of task. diff --git a/src/ray/common/task/task_util.h b/src/ray/common/task/task_util.h index c7ab30952..b978de6b4 100644 --- a/src/ray/common/task/task_util.h +++ b/src/ray/common/task/task_util.h @@ -78,18 +78,6 @@ class TaskSpecBuilder { return *this; } - /// Add a by-value argument to the task. - /// - /// \param data String object that contains the data. - /// \param metadata String object that contains the metadata. - /// \return Reference to the builder object itself. - TaskSpecBuilder &AddByValueArg(const std::string &data, const std::string &metadata) { - auto arg = message_->add_args(); - arg->set_data(data); - arg->set_metadata(metadata); - return *this; - } - /// Add a by-value argument to the task. /// /// \param value the RayObject instance that contains the data and the metadata. @@ -104,6 +92,9 @@ class TaskSpecBuilder { const auto &metadata = value.GetMetadata(); arg->set_metadata(metadata->Data(), metadata->Size()); } + for (const auto &inlined_id : value.GetInlinedIds()) { + arg->add_nested_inlined_ids(inlined_id.Binary()); + } return *this; } diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index cc17249db..e2bf6d5e6 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -80,7 +80,13 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, death_check_timer_(io_service_), internal_timer_(io_service_), core_worker_server_(WorkerTypeString(worker_type), 0 /* let grpc choose a port */), - reference_counter_(std::make_shared()), + reference_counter_(std::make_shared( + /*distributed_ref_counting_enabled=*/RayConfig::instance() + .distributed_ref_counting_enabled(), + [this](const rpc::Address &addr) { + return std::shared_ptr( + new rpc::CoreWorkerClient(addr, *client_call_manager_)); + })), task_queue_length_(0), num_executed_tasks_(0), task_execution_service_work_(task_execution_service_), @@ -111,8 +117,9 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, // Initialize task receivers. if (worker_type_ == WorkerType::WORKER) { RAY_CHECK(task_execution_callback_ != nullptr); - auto execute_task = std::bind(&CoreWorker::ExecuteTask, this, std::placeholders::_1, - std::placeholders::_2, std::placeholders::_3); + auto execute_task = + std::bind(&CoreWorker::ExecuteTask, this, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3, std::placeholders::_4); auto exit = [this](bool intentional) { // Release the resources early in case draining takes a long time. RAY_CHECK_OK(local_raylet_client_->NotifyDirectCallTaskBlocked()); @@ -176,7 +183,12 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, store_socket, local_raylet_client_, check_signals_)); memory_store_.reset(new CoreWorkerMemoryStore( [this](const RayObject &obj, const ObjectID &obj_id) { - RAY_CHECK_OK(plasma_store_provider_->Put(obj, obj_id)); + bool object_exists; + RAY_CHECK_OK(plasma_store_provider_->Put(obj, obj_id, &object_exists)); + if (!object_exists) { + RAY_LOG(DEBUG) << "Pinning object promoted to plasma " << obj_id; + RAY_CHECK_OK(local_raylet_client_->PinObjectIDs(rpc_address_, {obj_id})); + } }, ref_counting_enabled ? reference_counter_ : nullptr, local_raylet_client_, check_signals_)); @@ -211,9 +223,9 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, SetCurrentTaskId(task_id); } - auto client_factory = [this](const std::string ip_address, int port) { + auto client_factory = [this](const rpc::Address &addr) { return std::shared_ptr( - new rpc::CoreWorkerClient(ip_address, port, *client_call_manager_)); + new rpc::CoreWorkerClient(addr, *client_call_manager_)); }; direct_actor_submitter_ = std::unique_ptr( new CoreWorkerDirectActorTaskSubmitter(rpc_address_, client_factory, memory_store_, @@ -335,7 +347,14 @@ void CoreWorker::PromoteToPlasmaAndGetOwnershipInfo(const ObjectID &object_id, RAY_CHECK(object_id.IsDirectCallType()); auto value = memory_store_->GetOrPromoteToPlasma(object_id); if (value) { - RAY_CHECK_OK(plasma_store_provider_->Put(*value, object_id)); + RAY_LOG(DEBUG) << "Storing object promoted to plasma " << object_id; + bool object_exists; + RAY_CHECK_OK(plasma_store_provider_->Put(*value, object_id, &object_exists)); + if (!object_exists) { + RAY_LOG(DEBUG) << "PromoteToPlasma: Pinning object promoted to plasma " + << object_id; + RAY_CHECK_OK(local_raylet_client_->PinObjectIDs(rpc_address_, {object_id})); + } } auto has_owner = reference_counter_->GetOwner(object_id, owner_id, owner_address); @@ -345,14 +364,17 @@ void CoreWorker::PromoteToPlasmaAndGetOwnershipInfo(const ObjectID &object_id, "which task will create them. " "If this was not how your object ID was generated, please file an issue " "at https://github.com/ray-project/ray/issues/"; + RAY_LOG(DEBUG) << "Promoted object to plasma " << object_id << " owned by " + << *owner_id; } void CoreWorker::RegisterOwnershipInfoAndResolveFuture( - const ObjectID &object_id, const TaskID &owner_id, + const ObjectID &object_id, const ObjectID &outer_object_id, const TaskID &owner_id, const rpc::Address &owner_address) { // Add the object's owner to the local metadata in case it gets serialized // again. - reference_counter_->AddBorrowedObject(object_id, owner_id, owner_address); + reference_counter_->AddBorrowedObject(object_id, outer_object_id, owner_id, + owner_address); RAY_CHECK(!owner_id.IsNil()); // We will ask the owner about the object until the object is @@ -376,7 +398,8 @@ Status CoreWorker::Put(const RayObject &object, *object_id = ObjectID::ForPut(worker_context_.GetCurrentTaskID(), worker_context_.GetNextPutIndex(), static_cast(TaskTransportType::RAYLET)); - reference_counter_->AddOwnedObject(*object_id, GetCallerId(), rpc_address_); + reference_counter_->AddOwnedObject(*object_id, contained_object_ids, GetCallerId(), + rpc_address_); RAY_RETURN_NOT_OK(Put(object, contained_object_ids, *object_id)); // Tell the raylet to pin the object **after** it is created. RAY_CHECK_OK(local_raylet_client_->PinObjectIDs(rpc_address_, {*object_id})); @@ -390,7 +413,7 @@ Status CoreWorker::Put(const RayObject &object, static_cast(TaskTransportType::RAYLET)) << "Invalid transport type flag in object ID: " << object_id.GetTransportType(); // TODO(edoakes,swang): add contained object IDs to the reference counter. - return plasma_store_provider_->Put(object, object_id); + return plasma_store_provider_->Put(object, object_id, nullptr); } Status CoreWorker::Create(const std::shared_ptr &metadata, const size_t data_size, @@ -399,18 +422,18 @@ Status CoreWorker::Create(const std::shared_ptr &metadata, const size_t *object_id = ObjectID::ForPut(worker_context_.GetCurrentTaskID(), worker_context_.GetNextPutIndex(), static_cast(TaskTransportType::RAYLET)); - RAY_RETURN_NOT_OK(Create(metadata, data_size, contained_object_ids, *object_id, data)); + RAY_RETURN_NOT_OK( + plasma_store_provider_->Create(metadata, data_size, *object_id, data)); // Only add the object to the reference counter if it didn't already exist. if (data) { - reference_counter_->AddOwnedObject(*object_id, GetCallerId(), rpc_address_); + reference_counter_->AddOwnedObject(*object_id, contained_object_ids, GetCallerId(), + rpc_address_); } return Status::OK(); } Status CoreWorker::Create(const std::shared_ptr &metadata, const size_t data_size, - const std::vector &contained_object_ids, const ObjectID &object_id, std::shared_ptr *data) { - // TODO(edoakes,swang): add contained object IDs to the reference counter. return plasma_store_provider_->Create(metadata, data_size, object_id, data); } @@ -418,6 +441,7 @@ Status CoreWorker::Seal(const ObjectID &object_id, bool pin_object) { RAY_RETURN_NOT_OK(plasma_store_provider_->Seal(object_id)); if (pin_object) { // Tell the raylet to pin the object **after** it is created. + RAY_LOG(DEBUG) << "Pinning created object " << object_id; RAY_CHECK_OK(local_raylet_client_->PinObjectIDs(rpc_address_, {object_id})); } return Status::OK(); @@ -446,7 +470,7 @@ Status CoreWorker::Get(const std::vector &ids, const int64_t timeout_m // the transport type again and return them for the original direct call ids. for (const auto &pair : result_map) { if (pair.second->IsInPlasmaError()) { - RAY_LOG(INFO) << pair.first << " in plasma, doing fetch-and-get"; + RAY_LOG(DEBUG) << pair.first << " in plasma, doing fetch-and-get"; plasma_object_ids.insert(pair.first); } } @@ -892,18 +916,33 @@ Status CoreWorker::AllocateReturnObjects( RAY_CHECK(object_ids.size() == data_sizes.size()); return_objects->resize(object_ids.size(), nullptr); + absl::optional owner_address( + worker_context_.GetCurrentTask()->CallerAddress()); + bool owned_by_us = owner_address->worker_id() == rpc_address_.worker_id(); + if (owned_by_us) { + owner_address.reset(); + } + for (size_t i = 0; i < object_ids.size(); i++) { bool object_already_exists = false; std::shared_ptr data_buffer; if (data_sizes[i] > 0) { + RAY_LOG(DEBUG) << "Creating return object " << object_ids[i]; + // Mark this object as containing other object IDs. The ref counter will + // keep the inner IDs in scope until the outer one is out of scope. + if (!contained_object_ids[i].empty()) { + reference_counter_->WrapObjectIds(object_ids[i], contained_object_ids[i], + owner_address); + } + + // Allocate a buffer for the return object. if (worker_context_.CurrentTaskIsDirectCall() && static_cast(data_sizes[i]) < - RayConfig::instance().max_direct_call_object_size() && - contained_object_ids[i].empty()) { + RayConfig::instance().max_direct_call_object_size()) { data_buffer = std::make_shared(data_sizes[i]); } else { - RAY_RETURN_NOT_OK(Create(metadatas[i], data_sizes[i], contained_object_ids[i], - object_ids[i], &data_buffer)); + RAY_RETURN_NOT_OK( + Create(metadatas[i], data_sizes[i], object_ids[i], &data_buffer)); object_already_exists = !data_buffer; } } @@ -911,7 +950,8 @@ Status CoreWorker::AllocateReturnObjects( // This allows the caller to prevent the core worker from storing an output // (e.g., to support ray.experimental.no_return.NoReturn). if (!object_already_exists && (data_buffer || metadatas[i])) { - return_objects->at(i) = std::make_shared(data_buffer, metadatas[i]); + return_objects->at(i) = + std::make_shared(data_buffer, metadatas[i], contained_object_ids[i]); } } @@ -920,7 +960,9 @@ Status CoreWorker::AllocateReturnObjects( Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec, const std::shared_ptr &resource_ids, - std::vector> *return_objects) { + std::vector> *return_objects, + ReferenceCounter::ReferenceTableProto *borrowed_refs) { + RAY_LOG(DEBUG) << "Executing task " << task_spec.TaskId(); task_queue_length_ -= 1; num_executed_tasks_ += 1; @@ -939,7 +981,17 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec, std::vector> args; std::vector arg_reference_ids; - RAY_CHECK_OK(BuildArgsForExecutor(task_spec, &args, &arg_reference_ids)); + // This includes all IDs that were passed by reference and any IDs that were + // inlined in the task spec. These references will be pinned during the task + // execution and unpinned once the task completes. We will notify the caller + // about any IDs that we are still borrowing by the time the task completes. + std::vector borrowed_ids; + RAY_CHECK_OK(BuildArgsForExecutor(task_spec, &args, &arg_reference_ids, &borrowed_ids)); + // Pin the borrowed IDs for the duration of the task. + for (const auto &borrowed_id : borrowed_ids) { + RAY_LOG(DEBUG) << "Incrementing ref for borrowed ID " << borrowed_id; + reference_counter_->AddLocalReference(borrowed_id); + } const auto transport_type = worker_context_.CurrentTaskIsDirectCall() ? TaskTransportType::DIRECT @@ -986,6 +1038,21 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec, } } + // Get the reference counts for any IDs that we borrowed during this task and + // return them to the caller. This will notify the caller of any IDs that we + // (or a nested task) are still borrowing. It will also any new IDs that were + // contained in a borrowed ID that we (or a nested task) are now borrowing. + reference_counter_->GetAndClearLocalBorrowers(borrowed_ids, borrowed_refs); + // Unpin the borrowed IDs. + std::vector deleted; + for (const auto &borrowed_id : borrowed_ids) { + RAY_LOG(DEBUG) << "Decrementing ref for borrowed ID " << borrowed_id; + reference_counter_->RemoveLocalReference(borrowed_id, &deleted); + } + if (ref_counting_enabled_) { + memory_store_->Delete(deleted); + } + if (task_spec.IsNormalTask() && reference_counter_->NumObjectIDsInScope() != 0) { RAY_LOG(DEBUG) << "There were " << reference_counter_->NumObjectIDsInScope() @@ -1001,12 +1068,14 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec, absl::MutexLock lock(&mutex_); current_task_ = TaskSpecification(); } + RAY_LOG(DEBUG) << "Finished executing task " << task_spec.TaskId(); return status; } Status CoreWorker::BuildArgsForExecutor(const TaskSpecification &task, std::vector> *args, - std::vector *arg_reference_ids) { + std::vector *arg_reference_ids, + std::vector *borrowed_ids) { auto num_args = task.NumArgs(); args->resize(num_args); arg_reference_ids->resize(num_args); @@ -1015,10 +1084,9 @@ Status CoreWorker::BuildArgsForExecutor(const TaskSpecification &task, absl::flat_hash_map by_ref_indices; for (size_t i = 0; i < task.NumArgs(); ++i) { - int count = task.ArgIdCount(i); - if (count > 0) { + if (task.ArgByRef(i)) { // pass by reference. - RAY_CHECK(count == 1); + RAY_CHECK(task.ArgIdCount(i) == 1); // Direct call type objects that weren't inlined have been promoted to plasma. // We need to put an OBJECT_IN_PLASMA error here so the subsequent call to Get() // properly redirects to the plasma store. @@ -1026,9 +1094,15 @@ Status CoreWorker::BuildArgsForExecutor(const TaskSpecification &task, RAY_CHECK_OK(memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), task.ArgId(i, 0))); } - by_ref_ids.insert(task.ArgId(i, 0)); - by_ref_indices.emplace(task.ArgId(i, 0), i); - arg_reference_ids->at(i) = task.ArgId(i, 0); + const auto &arg_id = task.ArgId(i, 0); + by_ref_ids.insert(arg_id); + by_ref_indices.emplace(arg_id, i); + arg_reference_ids->at(i) = arg_id; + // The task borrows all args passed by reference. Because the task does + // not have a reference to the argument ID in the frontend, it is not + // possible for the task to still be borrowing the argument by the time + // it finishes. + borrowed_ids->push_back(arg_id); } else { // pass by value. std::shared_ptr data = nullptr; @@ -1041,8 +1115,16 @@ Status CoreWorker::BuildArgsForExecutor(const TaskSpecification &task, metadata = std::make_shared( const_cast(task.ArgMetadata(i)), task.ArgMetadataSize(i)); } - args->at(i) = std::make_shared(data, metadata, /*copy_data*/ true); + args->at(i) = std::make_shared(data, metadata, task.ArgInlinedIds(i), + /*copy_data*/ true); arg_reference_ids->at(i) = ObjectID::Nil(); + // The task borrows all ObjectIDs that were serialized in the inlined + // arguments. The task will receive references to these IDs, so it is + // possible for the task to continue borrowing these arguments by the + // time it finishes. + for (const auto &inlined_id : task.ArgInlinedIds(i)) { + borrowed_ids->push_back(inlined_id); + } } } @@ -1111,6 +1193,7 @@ void CoreWorker::HandleGetObjectStatus(const rpc::GetObjectStatusRequest &reques rpc::GetObjectStatusReply *reply, rpc::SendReplyCallback send_reply_callback) { ObjectID object_id = ObjectID::FromBinary(request.object_id()); + RAY_LOG(DEBUG) << "Received GetObjectStatus " << object_id; TaskID owner_id = TaskID::FromBinary(request.owner_id()); if (owner_id != GetCallerId()) { RAY_LOG(INFO) << "Handling GetObjectStatus for object produced by previous task " @@ -1167,6 +1250,26 @@ void CoreWorker::HandleWaitForObjectEviction( } } +void CoreWorker::HandleWaitForRefRemoved(const rpc::WaitForRefRemovedRequest &request, + rpc::WaitForRefRemovedReply *reply, + rpc::SendReplyCallback send_reply_callback) { + if (HandleWrongRecipient(WorkerID::FromBinary(request.intended_worker_id()), + send_reply_callback)) { + return; + } + const ObjectID &object_id = ObjectID::FromBinary(request.reference().object_id()); + ObjectID contained_in_id = ObjectID::FromBinary(request.contained_in_id()); + const auto owner_id = TaskID::FromBinary(request.reference().owner_id()); + const auto owner_address = request.reference().owner_address(); + auto ref_removed_callback = + boost::bind(&ReferenceCounter::HandleRefRemoved, reference_counter_, object_id, + reply, send_reply_callback); + // Set a callback to send the reply when the requested object ID's ref count + // goes to 0. + reference_counter_->SetRefRemovedCallback(object_id, contained_in_id, owner_id, + owner_address, ref_removed_callback); +} + void CoreWorker::HandleKillActor(const rpc::KillActorRequest &request, rpc::KillActorReply *reply, rpc::SendReplyCallback send_reply_callback) { diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 9610a8dc9..de89c40d6 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -157,9 +157,14 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// object. /// /// \param[in] object_id The object ID to deserialize. + /// \param[in] outer_object_id The object ID that contained object_id, if + /// any. This may be nil if the object ID was inlined directly in a task spec + /// or if it was passed out-of-band by the application (deserialized from a + /// byte string). /// \param[out] owner_id The ID of the object's owner. /// \param[out] owner_address The address of the object's owner. void RegisterOwnershipInfoAndResolveFuture(const ObjectID &object_id, + const ObjectID &outer_object_id, const TaskID &owner_id, const rpc::Address &owner_address); @@ -222,12 +227,10 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// /// \param[in] metadata Metadata of the object to be written. /// \param[in] data_size Size of the object to be written. - /// \param[in] contained_object_ids The IDs serialized in this object. /// \param[in] object_id Object ID specified by the user. /// \param[out] data Buffer for the user to write the object into. /// \return Status. Status Create(const std::shared_ptr &metadata, const size_t data_size, - const std::vector &contained_object_ids, const ObjectID &object_id, std::shared_ptr *data); /// Finalize placing an object into the object store. This should be called after @@ -468,6 +471,11 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { rpc::WaitForObjectEvictionReply *reply, rpc::SendReplyCallback send_reply_callback) override; + /// Implements gRPC server handler. + void HandleWaitForRefRemoved(const rpc::WaitForRefRemovedRequest &request, + rpc::WaitForRefRemovedReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + /// Implements gRPC server handler. void HandleKillActor(const rpc::KillActorRequest &request, rpc::KillActorReply *reply, rpc::SendReplyCallback send_reply_callback) override; @@ -535,33 +543,47 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// Execute a task. /// - /// \param spec[in] Task specification. - /// \param spec[in] Resource IDs of resources assigned to this worker. If nullptr, - /// reuse the previously assigned resources. - /// \param results[out] Result objects that should be returned by value (not via - /// plasma). + /// \param spec[in] task_spec Task specification. + /// \param spec[in] resource_ids Resource IDs of resources assigned to this + /// worker. If nullptr, reuse the previously assigned + /// resources. + /// \param results[out] return_objects Result objects that should be returned + /// by value (not via plasma). + /// \param results[out] borrowed_refs Refs that this task (or a nested task) + /// was or is still borrowing. This includes all + /// objects whose IDs we passed to the task in its + /// arguments and recursively, any object IDs that were + /// contained in those objects. /// \return Status. Status ExecuteTask(const TaskSpecification &task_spec, const std::shared_ptr &resource_ids, - std::vector> *return_objects); + std::vector> *return_objects, + ReferenceCounter::ReferenceTableProto *borrowed_refs); /// Build arguments for task executor. This would loop through all the arguments /// in task spec, and for each of them that's passed by reference (ObjectID), /// fetch its content from store and; for arguments that are passed by value, /// just copy their content. /// - /// \param spec[in] Task specification. - /// \param args[out] Argument data as RayObjects. - /// \param args[out] ObjectIDs corresponding to each by reference argument. The length - /// of this vector will be the same as args, and by value arguments - /// will have ObjectID::Nil(). + /// \param spec[in] task Task specification. + /// \param args[out] args Argument data as RayObjects. + /// \param args[out] arg_reference_ids ObjectIDs corresponding to each by + /// reference argument. The length of this vector will be + /// the same as args, and by value arguments will have + /// ObjectID::Nil(). /// // TODO(edoakes): this is a bit of a hack that's necessary because /// we have separate serialization paths for by-value and by-reference /// arguments in Python. This should ideally be handled better there. + /// \param args[out] borrowed_ids ObjectIDs that we are borrowing from the + /// task caller for the duration of the task execution. This + /// vector will be populated with all argument IDs that were + /// passed by reference and any ObjectIDs that were included + /// in the task spec's inlined arguments. /// \return The arguments for passing to task executor. Status BuildArgsForExecutor(const TaskSpecification &task, std::vector> *args, - std::vector *arg_reference_ids); + std::vector *arg_reference_ids, + std::vector *borrowed_ids); /// Returns whether the message was sent to the wrong worker. The right error reply /// is sent automatically. Messages end up on the wrong worker when a worker dies diff --git a/src/ray/core_worker/future_resolver.cc b/src/ray/core_worker/future_resolver.cc index 8b7a2cec8..0afb1c657 100644 --- a/src/ray/core_worker/future_resolver.cc +++ b/src/ray/core_worker/future_resolver.cc @@ -8,8 +8,8 @@ void FutureResolver::ResolveFutureAsync(const ObjectID &object_id, const TaskID absl::MutexLock lock(&mu_); auto it = owner_clients_.find(owner_id); if (it == owner_clients_.end()) { - auto client = std::shared_ptr( - client_factory_(owner_address.ip_address(), owner_address.port())); + auto client = + std::shared_ptr(client_factory_(owner_address)); it = owner_clients_.emplace(owner_id, std::move(client)).first; } diff --git a/src/ray/core_worker/lib/java/jni_utils.h b/src/ray/core_worker/lib/java/jni_utils.h index f10dbf95f..d2b39f1dd 100644 --- a/src/ray/core_worker/lib/java/jni_utils.h +++ b/src/ray/core_worker/lib/java/jni_utils.h @@ -332,7 +332,9 @@ inline std::shared_ptr JavaNativeRayObjectToNativeRayObject( if (metadata_buffer && metadata_buffer->Size() == 0) { metadata_buffer = nullptr; } - return std::make_shared(data_buffer, metadata_buffer); + // TODO: Support nested IDs for Java. + return std::make_shared(data_buffer, metadata_buffer, + std::vector()); } /// Convert a C++ ray::RayObject to a Java NativeRayObject. diff --git a/src/ray/core_worker/reference_count.cc b/src/ray/core_worker/reference_count.cc index 0e81a48ac..0daf0edad 100644 --- a/src/ray/core_worker/reference_count.cc +++ b/src/ray/core_worker/reference_count.cc @@ -1,21 +1,82 @@ #include "ray/core_worker/reference_count.h" +#define PRINT_REF_COUNT(it) \ + RAY_LOG(DEBUG) << "REF " << it->first << " borrowers: " << it->second.borrowers.size() \ + << " local_ref_count: " << it->second.local_ref_count \ + << " submitted_count: " << it->second.submitted_task_ref_count \ + << " contained_in_owned: " << it->second.contained_in_owned.size() \ + << " contained_in_borrowed: " \ + << (it->second.contained_in_borrowed_id.has_value() \ + ? *it->second.contained_in_borrowed_id \ + : ObjectID::Nil()) \ + << " contains: " << it->second.contains.size(); + +namespace {} // namespace + namespace ray { -void ReferenceCounter::AddBorrowedObject(const ObjectID &object_id, - const TaskID &owner_id, - const rpc::Address &owner_address) { - absl::MutexLock lock(&mutex_); - auto it = object_id_refs_.find(object_id); - RAY_CHECK(it != object_id_refs_.end()); +ReferenceCounter::ReferenceTable ReferenceCounter::ReferenceTableFromProto( + const ReferenceTableProto &proto) { + ReferenceTable refs; + for (const auto &ref : proto) { + refs[ray::ObjectID::FromBinary(ref.reference().object_id())] = + Reference::FromProto(ref); + } + return refs; +} - if (!it->second.owner.has_value()) { - it->second.owner = {owner_id, owner_address}; +void ReferenceCounter::ReferenceTableToProto(const ReferenceTable &table, + ReferenceTableProto *proto) { + for (const auto &id_ref : table) { + auto ref = proto->Add(); + id_ref.second.ToProto(ref); + ref->mutable_reference()->set_object_id(id_ref.first.Binary()); } } -void ReferenceCounter::AddOwnedObject(const ObjectID &object_id, const TaskID &owner_id, +bool ReferenceCounter::AddBorrowedObject(const ObjectID &object_id, + const ObjectID &outer_id, const TaskID &owner_id, + const rpc::Address &owner_address) { + absl::MutexLock lock(&mutex_); + return AddBorrowedObjectInternal(object_id, outer_id, owner_id, owner_address); +} + +bool ReferenceCounter::AddBorrowedObjectInternal(const ObjectID &object_id, + const ObjectID &outer_id, + const TaskID &owner_id, + const rpc::Address &owner_address) { + auto it = object_id_refs_.find(object_id); + RAY_CHECK(it != object_id_refs_.end()); + + RAY_LOG(DEBUG) << "Adding borrowed object " << object_id; + // Skip adding this object as a borrower if we already have ownership info. + // If we already have ownership info, then either we are the owner or someone + // else already knows that we are a borrower. + if (it->second.owner.has_value()) { + RAY_LOG(DEBUG) << "Skipping add borrowed object " << object_id; + return false; + } + + it->second.owner = {owner_id, owner_address}; + + if (!outer_id.IsNil()) { + auto outer_it = object_id_refs_.find(outer_id); + if (outer_it != object_id_refs_.end() && !outer_it->second.owned_by_us) { + RAY_LOG(DEBUG) << "Setting borrowed inner ID " << object_id + << " contained_in_borrowed: " << outer_id; + RAY_CHECK(!it->second.contained_in_borrowed_id.has_value()); + it->second.contained_in_borrowed_id = outer_id; + outer_it->second.contains.insert(object_id); + } + } + return true; +} + +void ReferenceCounter::AddOwnedObject(const ObjectID &object_id, + const std::vector &inner_ids, + const TaskID &owner_id, const rpc::Address &owner_address) { + RAY_LOG(DEBUG) << "Adding owned object " << object_id; absl::MutexLock lock(&mutex_); RAY_CHECK(object_id_refs_.count(object_id) == 0) << "Tried to create an owned object that already exists: " << object_id; @@ -23,6 +84,9 @@ void ReferenceCounter::AddOwnedObject(const ObjectID &object_id, const TaskID &o // because this corresponds to a submitted task whose return ObjectID will be created // in the frontend language, incrementing the reference count. object_id_refs_.emplace(object_id, Reference(owner_id, owner_address)); + // Mark that this object ID contains other inner IDs. Then, we will not remove + // the inner objects until the outer object ID goes out of scope. + WrapObjectIdsInternal(object_id, inner_ids, absl::optional()); } void ReferenceCounter::AddLocalReference(const ObjectID &object_id) { @@ -33,6 +97,8 @@ void ReferenceCounter::AddLocalReference(const ObjectID &object_id) { it = object_id_refs_.emplace(object_id, Reference()).first; } it->second.local_ref_count++; + RAY_LOG(DEBUG) << "Add local reference " << object_id; + PRINT_REF_COUNT(it); } void ReferenceCounter::RemoveLocalReference(const ObjectID &object_id, @@ -44,7 +110,16 @@ void ReferenceCounter::RemoveLocalReference(const ObjectID &object_id, << object_id; return; } - if (--it->second.local_ref_count == 0 && it->second.submitted_task_ref_count == 0) { + if (it->second.local_ref_count == 0) { + RAY_LOG(WARNING) + << "Tried to decrease ref count for object ID that has count 0 " << object_id + << ". This should only happen if ray.internal.free was called earlier."; + return; + } + it->second.local_ref_count--; + RAY_LOG(DEBUG) << "Remove local reference " << object_id; + PRINT_REF_COUNT(it); + if (it->second.RefCount() == 0) { DeleteReferenceInternal(it, deleted); } } @@ -63,9 +138,22 @@ void ReferenceCounter::AddSubmittedTaskReferences( } } -void ReferenceCounter::RemoveSubmittedTaskReferences( - const std::vector &object_ids, std::vector *deleted) { +void ReferenceCounter::UpdateSubmittedTaskReferences( + const std::vector &object_ids, const rpc::Address &worker_addr, + const ReferenceTableProto &borrowed_refs, std::vector *deleted) { absl::MutexLock lock(&mutex_); + // Must merge the borrower refs before decrementing any ref counts. This is + // to make sure that for serialized IDs, we increment the borrower count for + // the inner ID before decrementing the submitted_task_ref_count for the + // outer ID. + const auto refs = ReferenceTableFromProto(borrowed_refs); + if (!refs.empty()) { + RAY_CHECK(!WorkerID::FromBinary(worker_addr.worker_id()).IsNil()); + } + for (const ObjectID &object_id : object_ids) { + MergeRemoteBorrowers(object_id, worker_addr, refs); + } + for (const ObjectID &object_id : object_ids) { auto it = object_id_refs_.find(object_id); if (it == object_id_refs_.end()) { @@ -73,7 +161,8 @@ void ReferenceCounter::RemoveSubmittedTaskReferences( << object_id; return; } - if (--it->second.submitted_task_ref_count == 0 && it->second.local_ref_count == 0) { + it->second.submitted_task_ref_count--; + if (it->second.RefCount() == 0) { DeleteReferenceInternal(it, deleted); } } @@ -103,20 +192,83 @@ void ReferenceCounter::DeleteReferences(const std::vector &object_ids) if (it == object_id_refs_.end()) { return; } + it->second.local_ref_count = 0; + it->second.submitted_task_ref_count = 0; + if (distributed_ref_counting_enabled_ && !it->second.CanDelete()) { + RAY_LOG(ERROR) + << "ray.internal.free does not currently work for objects that are still in " + "scope when distributed reference " + "counting is enabled. Try disabling ref counting by passing " + "distributed_ref_counting_enabled: 0 in the ray.init internal config."; + } DeleteReferenceInternal(it, nullptr); } } -void ReferenceCounter::DeleteReferenceInternal( - absl::flat_hash_map::iterator it, - std::vector *deleted) { - if (it->second.on_delete) { - it->second.on_delete(it->first); +void ReferenceCounter::DeleteReferenceInternal(ReferenceTable::iterator it, + std::vector *deleted) { + const ObjectID id = it->first; + RAY_LOG(DEBUG) << "Attempting to delete object " << id; + if (distributed_ref_counting_enabled_ && it->second.RefCount() == 0 && + it->second.on_ref_removed) { + RAY_LOG(DEBUG) << "Calling on_ref_removed for object " << id; + it->second.on_ref_removed(id); + it->second.on_ref_removed = nullptr; } - if (deleted) { - deleted->push_back(it->first); + PRINT_REF_COUNT(it); + + // Whether it is safe to unpin the value. + bool should_delete_value = false; + // Whether it is safe to delete the Reference. + bool should_delete_reference = false; + + // If distributed ref counting is not enabled, then delete the object as soon + // as its local ref count goes to 0. + size_t local_ref_count = + it->second.local_ref_count + it->second.submitted_task_ref_count; + if (!distributed_ref_counting_enabled_ && local_ref_count == 0) { + should_delete_value = true; + } + + if (it->second.CanDelete()) { + // If distributed ref counting is enabled, then delete the object once its + // ref count across all processes is 0. + should_delete_value = true; + should_delete_reference = true; + for (const auto &inner_id : it->second.contains) { + auto inner_it = object_id_refs_.find(inner_id); + if (inner_it != object_id_refs_.end()) { + RAY_LOG(DEBUG) << "Try to delete inner object " << inner_id; + if (it->second.owned_by_us) { + // If this object ID was nested in an owned object, make sure that + // the outer object counted towards the ref count for the inner + // object. + RAY_CHECK(inner_it->second.contained_in_owned.erase(id)); + } else { + // If this object ID was nested in a borrowed object, make sure that + // we have already returned this information through a previous + // GetAndClearLocalBorrowers call. + RAY_CHECK(!inner_it->second.contained_in_borrowed_id.has_value()); + } + DeleteReferenceInternal(inner_it, deleted); + } + } + } + + // Perform the deletion. + if (should_delete_value) { + RAY_LOG(DEBUG) << "Deleting object " << id; + if (it->second.on_delete) { + it->second.on_delete(id); + it->second.on_delete = nullptr; + } + if (deleted) { + deleted->push_back(id); + } + } + if (should_delete_reference) { + object_id_refs_.erase(it); } - object_id_refs_.erase(it); } bool ReferenceCounter::SetDeleteCallback( @@ -126,7 +278,7 @@ bool ReferenceCounter::SetDeleteCallback( if (it == object_id_refs_.end()) { return false; } - RAY_CHECK(!it->second.on_delete); + RAY_CHECK(!it->second.on_delete) << object_id; it->second.on_delete = callback; return true; } @@ -164,4 +316,337 @@ ReferenceCounter::GetAllReferenceCounts() const { return all_ref_counts; } +void ReferenceCounter::GetAndClearLocalBorrowers( + const std::vector &borrowed_ids, + ReferenceCounter::ReferenceTableProto *proto) { + absl::MutexLock lock(&mutex_); + ReferenceTable borrowed_refs; + for (const auto &borrowed_id : borrowed_ids) { + RAY_CHECK(GetAndClearLocalBorrowersInternal(borrowed_id, &borrowed_refs)) + << borrowed_id; + // Decrease the ref count for each of the borrowed IDs. This is because we + // artificially increment each borrowed ID to keep it pinned during task + // execution. However, this should not count towards the final ref count + // returned to the task's caller. + auto it = borrowed_refs.find(borrowed_id); + if (it != borrowed_refs.end()) { + it->second.local_ref_count--; + } + } + ReferenceTableToProto(borrowed_refs, proto); +} + +bool ReferenceCounter::GetAndClearLocalBorrowersInternal(const ObjectID &object_id, + ReferenceTable *borrowed_refs) { + RAY_LOG(DEBUG) << "Pop " << object_id; + auto it = object_id_refs_.find(object_id); + if (it == object_id_refs_.end()) { + return false; + } + + // We only borrow objects that we do not own. This is not an assertion + // because it is possible to receive a reference to an object that we already + // own, e.g., if we execute a task that has an object ID in its arguments + // that we created in an earlier task. + if (it->second.owned_by_us) { + // Return true because we have the ref, but there is no need to return it + // since we own the object. + return true; + } + + borrowed_refs->emplace(object_id, it->second); + // Clear the local list of borrowers that we have accumulated. The receiver + // of the returned borrowed_refs must merge this list into their own list + // until all active borrowers are merged into the owner. + it->second.borrowers.clear(); + + if (it->second.contained_in_borrowed_id.has_value()) { + /// This ID was nested in another ID that we (or a nested task) borrowed. + /// Make sure that we also returned the ID that contained it. + RAY_CHECK(borrowed_refs->count(it->second.contained_in_borrowed_id.value()) > 0); + /// Clear the fact that this ID was nested because we are including it in + /// the returned borrowed_refs. If the nested ID is not being borrowed by + /// us, then it will be deleted recursively when deleting the outer ID. + it->second.contained_in_borrowed_id.reset(); + } + + // Attempt to pop children. + for (const auto &contained_id : it->second.contains) { + GetAndClearLocalBorrowersInternal(contained_id, borrowed_refs); + } + + return true; +} + +void ReferenceCounter::MergeRemoteBorrowers(const ObjectID &object_id, + const rpc::WorkerAddress &worker_addr, + const ReferenceTable &borrowed_refs) { + RAY_LOG(DEBUG) << "Merging ref " << object_id; + auto borrower_it = borrowed_refs.find(object_id); + if (borrower_it == borrowed_refs.end()) { + return; + } + const auto &borrower_ref = borrower_it->second; + RAY_LOG(DEBUG) << "Borrower ref " << object_id << " has " + << borrower_ref.borrowers.size() << " borrowers " + << ", has local: " << borrower_ref.local_ref_count + << " submitted: " << borrower_ref.submitted_task_ref_count + << " contained_in_owned " << borrower_ref.contained_in_owned.size(); + + auto it = object_id_refs_.find(object_id); + if (it == object_id_refs_.end()) { + it = object_id_refs_.emplace(object_id, Reference()).first; + } + if (!it->second.owner.has_value() && + borrower_ref.contained_in_borrowed_id.has_value()) { + // We don't have owner information about this object ID yet and the worker + // received it because it was nested in another ID that the worker was + // borrowing. Copy this information to our local table. + RAY_CHECK(borrower_ref.owner.has_value()); + AddBorrowedObjectInternal(object_id, *borrower_it->second.contained_in_borrowed_id, + borrower_ref.owner->first, borrower_ref.owner->second); + } + std::vector new_borrowers; + + // The worker is still using the reference, so it is still a borrower. + if (borrower_ref.RefCount() > 0) { + auto inserted = it->second.borrowers.insert(worker_addr).second; + // If we are the owner of id, then send WaitForRefRemoved to borrower. + if (inserted) { + RAY_LOG(DEBUG) << "Adding borrower " << worker_addr.ip_address << " to id " + << object_id; + new_borrowers.push_back(worker_addr); + } + } + + // Add any other workers that this worker passed the ID to as new borrowers. + for (const auto &nested_borrower : borrower_ref.borrowers) { + auto inserted = it->second.borrowers.insert(nested_borrower).second; + if (inserted) { + RAY_LOG(DEBUG) << "Adding borrower " << nested_borrower.ip_address << " to id " + << object_id; + new_borrowers.push_back(nested_borrower); + } + } + + // If we own this ID, then wait for all new borrowers to reach a ref count + // of 0 before GCing the object value. + if (it->second.owned_by_us) { + for (const auto &addr : new_borrowers) { + WaitForRefRemoved(it, addr); + } + } + + // Recursively merge any references that were contained in this object, to + // handle any borrowers of nested objects. + for (const auto &inner_id : borrower_ref.contains) { + MergeRemoteBorrowers(inner_id, worker_addr, borrowed_refs); + } +} + +void ReferenceCounter::WaitForRefRemoved(const ReferenceTable::iterator &ref_it, + const rpc::WorkerAddress &addr, + const ObjectID &contained_in_id) { + const ObjectID &object_id = ref_it->first; + rpc::WaitForRefRemovedRequest request; + // Only the owner should send requests to borrowers. + RAY_CHECK(ref_it->second.owned_by_us); + request.mutable_reference()->set_object_id(object_id.Binary()); + request.mutable_reference()->set_owner_id(ref_it->second.owner->first.Binary()); + request.mutable_reference()->mutable_owner_address()->CopyFrom( + ref_it->second.owner->second); + request.set_contained_in_id(contained_in_id.Binary()); + request.set_intended_worker_id(addr.worker_id.Binary()); + + auto it = borrower_cache_.find(addr); + if (it == borrower_cache_.end()) { + RAY_CHECK(client_factory_ != nullptr); + it = borrower_cache_.emplace(addr, client_factory_(addr.ToProto())).first; + RAY_LOG(DEBUG) << "Connected to borrower " << addr.ip_address << ":" << addr.port + << " for object " << object_id; + } + + // Send the borrower a message about this object. The borrower responds once + // it is no longer using the object ID. + RAY_CHECK_OK(it->second->WaitForRefRemoved( + request, [this, object_id, addr](const Status &status, + const rpc::WaitForRefRemovedReply &reply) { + RAY_LOG(DEBUG) << "Received reply from borrower " << addr.ip_address << ":" + << addr.port << " of object " << object_id; + absl::MutexLock lock(&mutex_); + auto it = object_id_refs_.find(object_id); + RAY_CHECK(it != object_id_refs_.end()); + RAY_CHECK(it->second.borrowers.erase(addr)); + + const ReferenceTable new_borrower_refs = + ReferenceTableFromProto(reply.borrowed_refs()); + + MergeRemoteBorrowers(object_id, addr, new_borrower_refs); + DeleteReferenceInternal(it, nullptr); + })); +} + +void ReferenceCounter::WrapObjectIds( + const ObjectID &object_id, const std::vector &inner_ids, + const absl::optional &owner_address) { + absl::MutexLock lock(&mutex_); + WrapObjectIdsInternal(object_id, inner_ids, owner_address); +} + +void ReferenceCounter::WrapObjectIdsInternal( + const ObjectID &object_id, const std::vector &inner_ids, + const absl::optional &owner_address) { + auto it = object_id_refs_.find(object_id); + if (!owner_address.has_value()) { + // `ray.put()` case OR returning an object ID from a task and the task's + // caller executed in the same process as us. + if (it != object_id_refs_.end()) { + RAY_CHECK(it->second.owned_by_us); + // The outer object is still in scope. Mark the inner ones as being + // contained in the outer object ID so we do not GC the inner objects + // until the outer object goes out of scope. + for (const auto &inner_id : inner_ids) { + it->second.contains.insert(inner_id); + auto inner_it = object_id_refs_.find(inner_id); + RAY_CHECK(inner_it != object_id_refs_.end()); + RAY_LOG(DEBUG) << "Setting inner ID " << inner_id + << " contained_in_owned: " << object_id; + inner_it->second.contained_in_owned.insert(object_id); + } + } + } else { + // Returning an object ID from a task, and the task's caller executed in a + // remote process. + for (const auto &inner_id : inner_ids) { + auto inner_it = object_id_refs_.find(inner_id); + RAY_CHECK(inner_it != object_id_refs_.end()); + if (!inner_it->second.owned_by_us) { + RAY_LOG(WARNING) + << "Ref counting currently does not support returning an object ID that was " + "not created by the worker executing the task. The object may be evicted " + "before all references are out of scope."; + // TODO: Do not return. Handle the case where we return a BORROWED id. + continue; + } + // Add the task's caller as a borrower. + auto inserted = inner_it->second.borrowers.insert(*owner_address).second; + if (inserted) { + RAY_LOG(DEBUG) << "Adding borrower " << owner_address->ip_address << " to id " + << object_id << ", borrower owns outer ID " << object_id; + // Wait for it to remove its + // reference. + WaitForRefRemoved(inner_it, *owner_address, object_id); + } + } + } +} + +void ReferenceCounter::HandleRefRemoved(const ObjectID &object_id, + rpc::WaitForRefRemovedReply *reply, + rpc::SendReplyCallback send_reply_callback) { + ReferenceTable borrowed_refs; + RAY_UNUSED(GetAndClearLocalBorrowersInternal(object_id, &borrowed_refs)); + for (const auto &pair : borrowed_refs) { + RAY_LOG(DEBUG) << pair.first << " has " << pair.second.borrowers.size() + << " borrowers"; + } + auto it = object_id_refs_.find(object_id); + if (it != object_id_refs_.end()) { + // We should only have called this callback once our local ref count for + // the object was zero. Also, we should have stripped all distributed ref + // count information and returned it to the owner. Therefore, it should be + // okay to delete the object, if it wasn't already deleted. + RAY_CHECK(it->second.CanDelete()); + } + // Send the owner information about any new borrowers. + ReferenceTableToProto(borrowed_refs, reply->mutable_borrowed_refs()); + + RAY_LOG(DEBUG) << "Replying to WaitForRefRemoved, reply has " + << reply->borrowed_refs().size(); + send_reply_callback(Status::OK(), nullptr, nullptr); +} + +void ReferenceCounter::SetRefRemovedCallback( + const ObjectID &object_id, const ObjectID &contained_in_id, const TaskID &owner_id, + const rpc::Address &owner_address, + const ReferenceCounter::ReferenceRemovedCallback &ref_removed_callback) { + absl::MutexLock lock(&mutex_); + RAY_LOG(DEBUG) << "Received WaitForRefRemoved " << object_id << " contained in " + << contained_in_id; + + auto it = object_id_refs_.find(object_id); + if (it == object_id_refs_.end()) { + it = object_id_refs_.emplace(object_id, Reference()).first; + } + + // If we are borrowing the ID because we own an object that contains it, then + // add the outer object to the inner ID's ref count. We will not respond to + // the owner of the inner ID until the outer object ID goes out of scope. + if (!contained_in_id.IsNil()) { + AddBorrowedObjectInternal(object_id, contained_in_id, owner_id, owner_address); + WrapObjectIdsInternal(contained_in_id, {object_id}, + absl::optional()); + } + + if (it->second.RefCount() == 0) { + // We already stopped borrowing the object ID. Respond to the owner + // immediately. + ref_removed_callback(object_id); + DeleteReferenceInternal(it, nullptr); + } else { + // We are still borrowing the object ID. Respond to the owner once we have + // stopped borrowing it. + if (it->second.on_ref_removed != nullptr) { + // TODO(swang): If the owner of an object dies and and is re-executed, it + // is possible that we will receive a duplicate request to set + // on_ref_removed. If messages are delayed and we overwrite the + // callback here, it's possible we will drop the request that was sent by + // the more recent owner. We should fix this by setting multiple + // callbacks or by versioning the owner requests. + RAY_LOG(WARNING) << "on_ref_removed already set for " << object_id + << ". The owner task must have died and been re-executed."; + } + it->second.on_ref_removed = ref_removed_callback; + } +} + +ReferenceCounter::Reference ReferenceCounter::Reference::FromProto( + const rpc::ObjectReferenceCount &ref_count) { + Reference ref; + ref.owner = {TaskID::FromBinary(ref_count.reference().owner_id()), + ref_count.reference().owner_address()}; + ref.local_ref_count = ref_count.has_local_ref() ? 1 : 0; + + for (const auto &borrower : ref_count.borrowers()) { + ref.borrowers.insert(rpc::WorkerAddress(borrower)); + } + for (const auto &id : ref_count.contains()) { + ref.contains.insert(ObjectID::FromBinary(id)); + } + const auto contained_in_borrowed_id = + ObjectID::FromBinary(ref_count.contained_in_borrowed_id()); + if (!contained_in_borrowed_id.IsNil()) { + ref.contained_in_borrowed_id = contained_in_borrowed_id; + } + return ref; +} + +void ReferenceCounter::Reference::ToProto(rpc::ObjectReferenceCount *ref) const { + if (owner.has_value()) { + ref->mutable_reference()->set_owner_id(owner->first.Binary()); + ref->mutable_reference()->mutable_owner_address()->CopyFrom(owner->second); + } + bool has_local_ref = RefCount() > 0; + ref->set_has_local_ref(has_local_ref); + for (const auto &borrower : borrowers) { + ref->add_borrowers()->CopyFrom(borrower.ToProto()); + } + if (contained_in_borrowed_id.has_value()) { + ref->set_contained_in_borrowed_id(contained_in_borrowed_id->Binary()); + } + for (const auto &contains_id : contains) { + ref->add_contains(contains_id.Binary()); + } +} + } // namespace ray diff --git a/src/ray/core_worker/reference_count.h b/src/ray/core_worker/reference_count.h index f17199718..919d5692b 100644 --- a/src/ray/core_worker/reference_count.h +++ b/src/ray/core_worker/reference_count.h @@ -3,18 +3,30 @@ #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/synchronization/mutex.h" #include "ray/common/id.h" #include "ray/protobuf/common.pb.h" +#include "ray/rpc/grpc_server.h" +#include "ray/rpc/worker/core_worker_client.h" #include "ray/util/logging.h" +#include + namespace ray { /// Class used by the core worker to keep track of ObjectID reference counts for garbage /// collection. This class is thread safe. class ReferenceCounter { public: - ReferenceCounter() {} + using ReferenceTableProto = + ::google::protobuf::RepeatedPtrField; + using ReferenceRemovedCallback = std::function; + + ReferenceCounter(bool distributed_ref_counting_enabled = true, + rpc::ClientFactoryFn client_factory = nullptr) + : distributed_ref_counting_enabled_(distributed_ref_counting_enabled), + client_factory_(client_factory) {} ~ReferenceCounter() {} @@ -39,13 +51,22 @@ class ReferenceCounter { void AddSubmittedTaskReferences(const std::vector &object_ids) LOCKS_EXCLUDED(mutex_); - /// Remove references for the provided object IDs that correspond to them being - /// dependencies to a submitted task. This should be called when inlined - /// dependencies are inlined or when the task finishes for plasma dependencies. + /// Update object references that were given to a submitted task. The task + /// may still be borrowing any object IDs that were contained in its + /// arguments. This should be called when inlined dependencies are inlined or + /// when the task finishes for plasma dependencies. /// /// \param[in] object_ids The object IDs to remove references for. + /// \param[in] worker_addr The address of the worker that executed the task. + /// \param[in] borrowed_refs The references that the worker borrowed during + /// the task. This table includes all task arguments that were passed by + /// reference and any object IDs that were transitively nested in the + /// arguments. Some references in this table may still be borrowed by the + /// worker and/or a task that the worker submitted. /// \param[out] deleted The object IDs whos reference counts reached zero. - void RemoveSubmittedTaskReferences(const std::vector &object_ids, + void UpdateSubmittedTaskReferences(const std::vector &object_ids, + const rpc::Address &worker_addr, + const ReferenceTableProto &borrowed_refs, std::vector *deleted) LOCKS_EXCLUDED(mutex_); @@ -60,20 +81,28 @@ class ReferenceCounter { /// possible to have leftover references after a task has finished. /// /// \param[in] object_id The ID of the object that we own. + /// \param[in] inner_ids ObjectIDs that are contained in the object's value. + /// As long as the object_id is in scope, the inner objects should not be GC'ed. /// \param[in] owner_id The ID of the object's owner. /// \param[in] owner_address The address of the object's owner. /// \param[in] dependencies The objects that the object depends on. - void AddOwnedObject(const ObjectID &object_id, const TaskID &owner_id, + void AddOwnedObject(const ObjectID &object_id, + const std::vector &contained_ids, const TaskID &owner_id, const rpc::Address &owner_address) LOCKS_EXCLUDED(mutex_); /// Add an object that we are borrowing. /// /// \param[in] object_id The ID of the object that we are borrowing. + /// \param[in] outer_id The ID of the object that contained this object ID, + /// if one exists. An outer_id may not exist if object_id was inlined + /// directly in a task spec, or if it was passed in the application + /// out-of-band. /// \param[in] owner_id The ID of the owner of the object. This is either the /// task ID (for non-actors) or the actor ID of the owner. /// \param[in] owner_address The owner's address. - void AddBorrowedObject(const ObjectID &object_id, const TaskID &owner_id, - const rpc::Address &owner_address) LOCKS_EXCLUDED(mutex_); + bool AddBorrowedObject(const ObjectID &object_id, const ObjectID &outer_id, + const TaskID &owner_id, const rpc::Address &owner_address) + LOCKS_EXCLUDED(mutex_); /// Get the owner ID and address of the given object. /// @@ -92,12 +121,41 @@ class ReferenceCounter { const std::function callback) LOCKS_EXCLUDED(mutex_); + /// Set a callback for when we are no longer borrowing this object (when our + /// ref count goes to 0). + /// + /// \param[in] object_id The object ID to set the callback for. + /// \param[in] contained_in_id The object ID that contains object_id, if any. + /// This is used for cases when object_id was returned from a task that we + /// submitted. Then, as long as we have contained_in_id in scope, we are + /// borrowing object_id. + /// \param[in] owner_id The ID of the owner of object_id. This is either the + /// task ID (for non-actors) or the actor ID of the owner. + /// \param[in] owner_address The owner of object_id's address. + /// \param[in] ref_removed_callback The callback to call when we are no + /// longer borrowing the object. + void SetRefRemovedCallback(const ObjectID &object_id, const ObjectID &contained_in_id, + const TaskID &owner_id, const rpc::Address &owner_address, + const ReferenceRemovedCallback &ref_removed_callback) + LOCKS_EXCLUDED(mutex_); + + /// Respond to the object's owner once we are no longer borrowing it. The + /// sender is the owner of the object ID. We will send the reply when our + /// RefCount() for the object ID goes to 0. + /// + /// \param[in] object_id The object that we were borrowing. + /// \param[in] reply A reply sent to the owner when we are no longer + /// borrowing the object ID. This reply also includes any new borrowers and + /// any object IDs that were nested inside the object that we or others are + /// now borrowing. + /// \param[in] send_reply_callback The callback to send the reply. + void HandleRefRemoved(const ObjectID &object_id, rpc::WaitForRefRemovedReply *reply, + rpc::SendReplyCallback send_reply_callback) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + /// Returns the total number of ObjectIDs currently in scope. size_t NumObjectIDsInScope() const LOCKS_EXCLUDED(mutex_); - /// Returns whether this object has an active reference. - bool HasReference(const ObjectID &object_id) const LOCKS_EXCLUDED(mutex_); - /// Returns a set of all ObjectIDs currently in scope (i.e., nonzero reference count). std::unordered_set GetAllInScopeObjectIDs() const LOCKS_EXCLUDED(mutex_); @@ -106,18 +164,79 @@ class ReferenceCounter { std::unordered_map> GetAllReferenceCounts() const LOCKS_EXCLUDED(mutex_); + /// Populate a table with ObjectIDs that we were or are still borrowing. + /// This should be called when a task returns, and the argument should be any + /// IDs that were passed by reference in the task spec or that were + /// serialized in inlined arguments. + /// + /// See GetAndClearLocalBorrowersInternal for the spec of the returned table + /// and how this mutates the local reference count. + /// + /// \param[in] borrowed_ids The object IDs that we or another worker were or + /// are still borrowing. These are the IDs that were given to us via task + /// submission and includes: (1) any IDs that were passed by reference in the + /// task spec, and (2) any IDs that were serialized in the task's inlined + /// arguments. + /// \param[out] proto The protobuf table to populate with the borrowed + /// references. + void GetAndClearLocalBorrowers(const std::vector &borrowed_ids, + ReferenceTableProto *proto) LOCKS_EXCLUDED(mutex_); + + /// Wrap ObjectIDs inside another object ID. + /// + /// \param[in] object_id The object ID whose value we are storing. + /// \param[in] inner_ids The object IDs that we are storing in object_id. + /// \param[in] owner_address The owner address of the outer object_id. If + /// this is not provided, then the outer object ID must be owned by us. the + /// outer object ID is not owned by us, then this is used to contact the + /// outer object's owner, since it is considered a borrower for the inner + /// IDs. + void WrapObjectIds(const ObjectID &object_id, const std::vector &inner_ids, + const absl::optional &owner_address) + LOCKS_EXCLUDED(mutex_); + + /// Whether we have a reference to a particular ObjectID. + /// + /// \param[in] object_id The object ID to check for. + /// \return Whether we have a reference to the object ID. + bool HasReference(const ObjectID &object_id) const LOCKS_EXCLUDED(mutex_); + private: - /// Metadata for an ObjectID reference in the language frontend. struct Reference { /// Constructor for a reference whose origin is unknown. Reference() : owned_by_us(false) {} /// Constructor for a reference that we created. Reference(const TaskID &owner_id, const rpc::Address &owner_address) : owned_by_us(true), owner({owner_id, owner_address}) {} - /// The local ref count for the ObjectID in the language frontend. - size_t local_ref_count = 0; - /// The ref count for submitted tasks that depend on the ObjectID. - size_t submitted_task_ref_count = 0; + + /// Constructor from a protobuf. This is assumed to be a message from + /// another process, so the object defaults to not being owned by us. + static Reference FromProto(const rpc::ObjectReferenceCount &ref_count); + /// Serialize to a protobuf. + void ToProto(rpc::ObjectReferenceCount *ref) const; + + /// The reference count. This number includes: + /// - Python references to the ObjectID. + /// - Pending submitted tasks that depend on the object. + /// - ObjectIDs that we own, that contain this ObjectID, and that are still + /// in scope. + size_t RefCount() const { + return local_ref_count + submitted_task_ref_count + contained_in_owned.size(); + } + + /// Whether we can delete this reference. A reference can NOT be deleted if + /// any of the following are true: + /// - The reference is still being used by this process. + /// - The reference was contained in another ID that we were borrowing, and + /// we haven't told the process that gave us that ID yet. + /// - We gave the reference to at least one other process. + bool CanDelete() const { + bool in_scope = RefCount() > 0; + bool was_contained_in_borrowed_id = contained_in_borrowed_id.has_value(); + bool has_borrowers = borrowers.size() > 0; + return !(in_scope || was_contained_in_borrowed_id || has_borrowers); + } + /// Whether we own the object. If we own the object, then we are /// responsible for tracking the state of the task that creates the object /// (see task_manager.h). @@ -126,22 +245,174 @@ class ReferenceCounter { /// if we do not know the object's owner (because distributed ref counting /// is not yet implemented). absl::optional> owner; - /// Callback that will be called when this ObjectID no longer has references. + + /// The local ref count for the ObjectID in the language frontend. + size_t local_ref_count = 0; + /// The ref count for submitted tasks that depend on the ObjectID. + size_t submitted_task_ref_count = 0; + /// Object IDs that we own and that contain this object ID. + /// ObjectIDs are added to this field when we discover that this object + /// contains other IDs. This can happen in 2 cases: + /// 1. We call ray.put() and store the inner ID(s) in the outer object. + /// 2. A task that we submitted returned an ID(s). + /// ObjectIDs are erased from this field when their Reference is deleted. + absl::flat_hash_set contained_in_owned; + /// An Object ID that we (or one of our children) borrowed that contains + /// this object ID, which is also borrowed. This is used in cases where an + /// ObjectID is nested. We need to notify the owner of the outer ID of any + /// borrowers of this object, so we keep this field around until + /// GetAndClearLocalBorrowersInternal is called on the outer ID. This field + /// is updated in 2 cases: + /// 1. We deserialize an ID that we do not own and that was stored in + /// another object that we do not own. + /// 2. Case (1) occurred for a task that we submitted and we also do not + /// own the inner or outer object. Then, we need to notify our caller + /// that the task we submitted is a borrower for the inner ID. + /// This field is reset to null once GetAndClearLocalBorrowersInternal is + /// called on contained_in_borrowed_id. For each borrower, this field is + /// set at most once during the reference's lifetime. If the object ID is + /// later found to be nested in a second object, we do not need to remember + /// the second ID because we will already have notified the owner of the + /// first outer object about our reference. + absl::optional contained_in_borrowed_id; + /// The object IDs contained in this object. These could be objects that we + /// own or are borrowing. This field is updated in 2 cases: + /// 1. We call ray.put() on this ID and store the contained IDs. + /// 2. We call ray.get() on an ID whose contents we do not know and we + /// discover that it contains these IDs. + absl::flat_hash_set contains; + /// A list of processes that are we gave a reference to that are still + /// borrowing the ID. This field is updated in 2 cases: + /// 1. If we are a borrower of the ID, then we add a process to this list + /// if we passed that process a copy of the ID via task submission and + /// the process is still using the ID by the time it finishes its task. + /// Borrowers are removed from the list when we recursively merge our + /// list into the owner. + /// 2. If we are the owner of the ID, then either the above case, or when + /// we hear from a borrower that it has passed the ID to other + /// borrowers. A borrower is removed from the list when it responds + /// that it is no longer using the reference. + absl::flat_hash_set borrowers; + + /// Callback that will be called when this ObjectID no longer has + /// references. std::function on_delete; + /// Callback that is called when this process is no longer a borrower + /// (RefCount() == 0). + std::function on_ref_removed; }; + using ReferenceTable = absl::flat_hash_map; + + /// Deserialize a ReferenceTable. + static ReferenceTable ReferenceTableFromProto(const ReferenceTableProto &proto); + + /// Serialize a ReferenceTable. + static void ReferenceTableToProto(const ReferenceTable &table, + ReferenceTableProto *proto); + + /// Helper method to wrap an ObjectID(s) inside another object ID. + /// + /// \param[in] object_id The object ID whose value we are storing. + /// \param[in] inner_ids The object IDs that we are storing in object_id. + /// \param[in] owner_address The owner address of the outer object_id. If + /// this is not provided, then the outer object ID must be owned by us. the + /// outer object ID is not owned by us, then this is used to contact the + /// outer object's owner, since it is considered a borrower for the inner + /// IDs. + void WrapObjectIdsInternal(const ObjectID &object_id, + const std::vector &inner_ids, + const absl::optional &owner_address) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + /// Populates the table with the ObjectID that we were or are still + /// borrowing. The table also includes any IDs that we discovered were + /// contained in the ID. For each borrowed ID, we will return: + /// - The borrowed ID's owner's address. + /// - Whether we are still using the ID or not (RefCount() > 0). + /// - Addresses of new borrowers that we passed the ID to. + /// - Whether the borrowed ID was contained in another ID that we borrowed. + /// + /// We will also attempt to clear the information put into the returned table + /// that we no longer need in our local table. Each reference in the local + /// table is modified in the following way: + /// - For each borrowed ID, remove the addresses of any new borrowers. We + /// don't need these anymore because the receiver of the borrowed_refs is + /// either the owner or another borrow who will eventually return the list + /// to the owner. + /// - For each ID that was contained in a borrowed ID, forget that the ID + /// that contained it. We don't need this anymore because we already marked + /// that the borrowed ID contained another ID in the returned + /// borrowed_refs. + bool GetAndClearLocalBorrowersInternal(const ObjectID &object_id, + ReferenceTable *borrowed_refs) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + /// Merge remote borrowers into our local ref count. This will add any + /// workers that are still borrowing the given object ID to the local ref + /// counts, and recursively any workers that are borrowing object IDs that + /// were nested inside. This is the converse of GetAndClearLocalBorrowers. + /// For each borrowed object ID, we will: + /// - Add the worker to our list of borrowers if it is still using the + /// reference. + /// - Add the worker's accumulated borrowers to our list of borrowers. + /// - If the borrowed ID was nested in another borrowed ID, then mark it as + /// such so that we can later merge the inner ID's reference into its + /// owner. + /// - If we are the owner of the ID, then also contact any new borrowers and + /// wait for them to stop using the reference. + void MergeRemoteBorrowers(const ObjectID &object_id, + const rpc::WorkerAddress &worker_addr, + const ReferenceTable &borrowed_refs) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + /// Wait for a borrower to stop using its reference. This should only be + /// called by the owner of the ID. + /// \param[in] reference_it Iterator pointing to the reference that we own. + /// \param[in] addr The address of the borrower. + /// \param[in] contained_in_id Whether the owned ID was contained in another + /// ID. This is used in cases where we return an object ID that we own inside + /// an object that we do not own. Then, we must notify the owner of the outer + /// object that they are borrowing the inner. + void WaitForRefRemoved(const ReferenceTable::iterator &reference_it, + const rpc::WorkerAddress &addr, + const ObjectID &contained_in_id = ObjectID::Nil()) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + /// Helper method to add an object that we are borrowing. This is used when + /// deserializing IDs from a task's arguments, or when deserializing an ID + /// during ray.get(). + bool AddBorrowedObjectInternal(const ObjectID &object_id, const ObjectID &outer_id, + const TaskID &owner_id, + const rpc::Address &owner_address) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + /// Helper method to delete an entry from the reference map and run any necessary /// callbacks. Assumes that the entry is in object_id_refs_ and invalidates the /// iterator. - void DeleteReferenceInternal(absl::flat_hash_map::iterator entry, + void DeleteReferenceInternal(ReferenceTable::iterator entry, std::vector *deleted) EXCLUSIVE_LOCKS_REQUIRED(mutex_); + /// Feature flag for distributed ref counting. If this is false, then we will + /// keep the distributed ref count, but only the local ref count will be used + /// to decide when objects can be evicted. + bool distributed_ref_counting_enabled_; + + /// Factory for producing new core worker clients. + rpc::ClientFactoryFn client_factory_; + + /// Map from worker address to core worker client. The owner of an object + /// uses this client to request a notification from borrowers once the + /// borrower's ref count for the ID goes to 0. + absl::flat_hash_map> + borrower_cache_ GUARDED_BY(mutex_); + /// Protects access to the reference counting state. mutable absl::Mutex mutex_; /// Holds all reference counts and dependency information for tracked ObjectIDs. - absl::flat_hash_map object_id_refs_ GUARDED_BY(mutex_); + ReferenceTable object_id_refs_ GUARDED_BY(mutex_); }; } // namespace ray diff --git a/src/ray/core_worker/reference_count_test.cc b/src/ray/core_worker/reference_count_test.cc index c86c2a30e..05f2729b9 100644 --- a/src/ray/core_worker/reference_count_test.cc +++ b/src/ray/core_worker/reference_count_test.cc @@ -8,6 +8,9 @@ namespace ray { +static const rpc::Address empty_borrower; +static const ReferenceCounter::ReferenceTableProto empty_refs; + class ReferenceCountTest : public ::testing::Test { protected: std::unique_ptr rc; @@ -16,6 +19,133 @@ class ReferenceCountTest : public ::testing::Test { virtual void TearDown() {} }; +class MockWorkerClient : public rpc::CoreWorkerClientInterface { + public: + MockWorkerClient(ReferenceCounter &rc, const std::string &addr) + : rc_(rc), task_id_(TaskID::ForFakeTask()) { + address_.set_ip_address(addr); + address_.set_raylet_id(ClientID::FromRandom().Binary()); + address_.set_worker_id(WorkerID::FromRandom().Binary()); + } + + ray::Status WaitForRefRemoved( + const rpc::WaitForRefRemovedRequest &request, + const rpc::ClientCallback &callback) override { + auto r = num_requests_; + requests_[r] = { + std::make_shared(), + callback, + }; + + auto send_reply_callback = [this, r](Status status, std::function success, + std::function failure) { + requests_[r].second(status, *requests_[r].first); + }; + auto borrower_callback = [=]() { + const ObjectID &object_id = ObjectID::FromBinary(request.reference().object_id()); + ObjectID contained_in_id = ObjectID::FromBinary(request.contained_in_id()); + const auto owner_id = TaskID::FromBinary(request.reference().owner_id()); + const auto owner_address = request.reference().owner_address(); + auto ref_removed_callback = + boost::bind(&ReferenceCounter::HandleRefRemoved, &rc_, _1, + requests_[r].first.get(), send_reply_callback); + rc_.SetRefRemovedCallback(object_id, contained_in_id, owner_id, owner_address, + ref_removed_callback); + }; + borrower_callbacks_[r] = borrower_callback; + + num_requests_++; + return Status::OK(); + } + + bool FlushBorrowerCallbacks() { + if (borrower_callbacks_.empty()) { + return false; + } else { + for (auto &callback : borrower_callbacks_) { + callback.second(); + } + borrower_callbacks_.clear(); + return true; + } + } + + // The below methods mirror a core worker's operations, e.g., `Put` simulates + // a ray.put(). + void Put(const ObjectID &object_id) { + rc_.AddOwnedObject(object_id, {}, task_id_, address_); + rc_.AddLocalReference(object_id); + } + + void PutWrappedId(const ObjectID outer_id, const ObjectID &inner_id) { + rc_.AddOwnedObject(outer_id, {inner_id}, task_id_, address_); + rc_.AddLocalReference(outer_id); + } + + void GetSerializedObjectId(const ObjectID outer_id, const ObjectID &inner_id, + const TaskID &owner_id, const rpc::Address &owner_address) { + rc_.AddLocalReference(inner_id); + rc_.AddBorrowedObject(inner_id, outer_id, owner_id, owner_address); + } + + void ExecuteTaskWithArg(const ObjectID &arg_id, const ObjectID &inner_id, + const TaskID &owner_id, const rpc::Address &owner_address) { + // Add a sentinel reference to keep the argument ID in scope even though + // the frontend won't have a reference. + rc_.AddLocalReference(arg_id); + GetSerializedObjectId(arg_id, inner_id, owner_id, owner_address); + } + + ObjectID SubmitTaskWithArg(const ObjectID &arg_id) { + rc_.AddSubmittedTaskReferences({arg_id}); + ObjectID return_id = ObjectID::FromRandom(); + rc_.AddOwnedObject(return_id, {}, task_id_, address_); + // Add a sentinel reference to keep all nested object IDs in scope. + rc_.AddLocalReference(return_id); + return return_id; + } + + ReferenceCounter::ReferenceTableProto FinishExecutingTask( + const ObjectID &arg_id, const ObjectID &return_id, + const ObjectID *return_wrapped_id = nullptr, + const rpc::WorkerAddress *owner_address = nullptr) { + if (return_wrapped_id) { + rc_.WrapObjectIds(return_id, {*return_wrapped_id}, *owner_address); + } + + ReferenceCounter::ReferenceTableProto refs; + if (!arg_id.IsNil()) { + rc_.GetAndClearLocalBorrowers({arg_id}, &refs); + // Remove the sentinel reference. + rc_.RemoveLocalReference(arg_id, nullptr); + } + return refs; + } + + void HandleSubmittedTaskFinished( + const ObjectID &arg_id, const rpc::Address &borrower_address = empty_borrower, + const ReferenceCounter::ReferenceTableProto &borrower_refs = empty_refs) { + if (!arg_id.IsNil()) { + rc_.UpdateSubmittedTaskReferences({arg_id}, borrower_address, borrower_refs, + nullptr); + } + } + + // Global map from Worker ID -> MockWorkerClient. + // Global map from Object ID -> owner worker ID, list of objects that it depends on, + // worker address that it's scheduled on. Worker map of pending return IDs. + + // The ReferenceCounter at the "client". + ReferenceCounter &rc_; + TaskID task_id_; + rpc::Address address_; + std::unordered_map> borrower_callbacks_; + std::unordered_map, + rpc::ClientCallback>> + requests_; + int num_requests_ = 0; +}; + // Tests basic incrementing/decrementing of direct/submitted task reference counts. An // entry should only be removed once both of its reference counts reach zero. TEST_F(ReferenceCountTest, TestBasic) { @@ -44,13 +174,13 @@ TEST_F(ReferenceCountTest, TestBasic) { rc->AddSubmittedTaskReferences({id1}); rc->AddSubmittedTaskReferences({id1, id2}); ASSERT_EQ(rc->NumObjectIDsInScope(), 2); - rc->RemoveSubmittedTaskReferences({id1}, &out); + rc->UpdateSubmittedTaskReferences({id1}, empty_borrower, empty_refs, &out); ASSERT_EQ(rc->NumObjectIDsInScope(), 2); ASSERT_EQ(out.size(), 0); - rc->RemoveSubmittedTaskReferences({id2}, &out); + rc->UpdateSubmittedTaskReferences({id2}, empty_borrower, empty_refs, &out); ASSERT_EQ(rc->NumObjectIDsInScope(), 1); ASSERT_EQ(out.size(), 1); - rc->RemoveSubmittedTaskReferences({id1}, &out); + rc->UpdateSubmittedTaskReferences({id1}, empty_borrower, empty_refs, &out); ASSERT_EQ(rc->NumObjectIDsInScope(), 0); ASSERT_EQ(out.size(), 2); out.clear(); @@ -63,10 +193,10 @@ TEST_F(ReferenceCountTest, TestBasic) { rc->RemoveLocalReference(id1, &out); ASSERT_EQ(rc->NumObjectIDsInScope(), 2); ASSERT_EQ(out.size(), 0); - rc->RemoveSubmittedTaskReferences({id2}, &out); + rc->UpdateSubmittedTaskReferences({id2}, empty_borrower, empty_refs, &out); ASSERT_EQ(rc->NumObjectIDsInScope(), 2); ASSERT_EQ(out.size(), 0); - rc->RemoveSubmittedTaskReferences({id1}, &out); + rc->UpdateSubmittedTaskReferences({id1}, empty_borrower, empty_refs, &out); ASSERT_EQ(rc->NumObjectIDsInScope(), 1); ASSERT_EQ(out.size(), 1); rc->RemoveLocalReference(id2, &out); @@ -83,7 +213,7 @@ TEST_F(ReferenceCountTest, TestOwnerAddress) { TaskID task_id = TaskID::ForFakeTask(); rpc::Address address; address.set_ip_address("1234"); - rc->AddOwnedObject(object_id, task_id, address); + rc->AddOwnedObject(object_id, {}, task_id, address); TaskID added_id; rpc::Address added_address; @@ -94,7 +224,7 @@ TEST_F(ReferenceCountTest, TestOwnerAddress) { auto object_id2 = ObjectID::FromRandom(); task_id = TaskID::ForFakeTask(); address.set_ip_address("5678"); - rc->AddOwnedObject(object_id2, task_id, address); + rc->AddOwnedObject(object_id2, {}, task_id, address); ASSERT_TRUE(rc->GetOwner(object_id2, &added_id, &added_address)); ASSERT_EQ(task_id, added_id); ASSERT_EQ(address.ip_address(), added_address.ip_address()); @@ -111,7 +241,7 @@ TEST(MemoryStoreIntegrationTest, TestSimple) { ObjectID id1 = ObjectID::FromRandom().WithDirectTransportType(); ObjectID id2 = ObjectID::FromRandom().WithDirectTransportType(); uint8_t data[] = {1, 2, 3, 4, 5, 6, 7, 8}; - RayObject buffer(std::make_shared(data, sizeof(data)), nullptr); + RayObject buffer(std::make_shared(data, sizeof(data)), nullptr, {}); auto rc = std::shared_ptr(new ReferenceCounter()); CoreWorkerMemoryStore store(nullptr, rc); @@ -132,6 +262,1037 @@ TEST(MemoryStoreIntegrationTest, TestSimple) { ASSERT_EQ(store.Size(), 1); } +// A borrower is given a reference to an object ID, submits a task, waits for +// it to finish, then returns. +// +// @ray.remote +// def borrower(inner_ids): +// inner_id = inner_ids[0] +// ray.get(foo.remote(inner_id)) +// +// inner_id = ray.put(1) +// outer_id = ray.put([inner_id]) +// res = borrower.remote(outer_id) +TEST(DistributedReferenceCountTest, TestNoBorrow) { + ReferenceCounter borrower_rc; + auto borrower = std::make_shared(borrower_rc, "1"); + ReferenceCounter owner_rc(true, [&](const rpc::Address &addr) { return borrower; }); + auto owner = std::make_shared(owner_rc, "2"); + + // The owner creates an inner object and wraps it. + auto inner_id = ObjectID::FromRandom(); + auto outer_id = ObjectID::FromRandom(); + owner->Put(inner_id); + owner->PutWrappedId(outer_id, inner_id); + + // The owner submits a task that depends on the outer object. The task will + // be given a reference to inner_id. + owner->SubmitTaskWithArg(outer_id); + // The owner's references go out of scope. + owner_rc.RemoveLocalReference(outer_id, nullptr); + owner_rc.RemoveLocalReference(inner_id, nullptr); + // The owner's ref count > 0 for both objects. + ASSERT_TRUE(owner_rc.HasReference(outer_id)); + ASSERT_TRUE(owner_rc.HasReference(inner_id)); + + // The borrower is given a reference to the inner object. + borrower->ExecuteTaskWithArg(outer_id, inner_id, owner->task_id_, owner->address_); + // The borrower submits a task that depends on the inner object. + borrower->SubmitTaskWithArg(inner_id); + borrower_rc.RemoveLocalReference(inner_id, nullptr); + ASSERT_TRUE(borrower_rc.HasReference(inner_id)); + + // The borrower waits for the task to finish before returning to the owner. + borrower->HandleSubmittedTaskFinished(inner_id); + auto borrower_refs = borrower->FinishExecutingTask(outer_id, ObjectID::Nil()); + // Check that the borrower's ref count is now 0 for all objects. + ASSERT_FALSE(borrower_rc.HasReference(inner_id)); + ASSERT_FALSE(borrower_rc.HasReference(outer_id)); + + // The owner receives the borrower's reply and merges the borrower's ref + // count into its own. + owner->HandleSubmittedTaskFinished(outer_id, borrower->address_, borrower_refs); + borrower->FlushBorrowerCallbacks(); + // Check that owner's ref count is now 0 for all objects. + ASSERT_FALSE(owner_rc.HasReference(inner_id)); + ASSERT_FALSE(owner_rc.HasReference(outer_id)); +} + +// A borrower is given a reference to an object ID, submits a task, does not +// wait for it to finish. +// +// @ray.remote +// def borrower(inner_ids): +// inner_id = inner_ids[0] +// foo.remote(inner_id) +// +// inner_id = ray.put(1) +// outer_id = ray.put([inner_id]) +// res = borrower.remote(outer_id) +TEST(DistributedReferenceCountTest, TestSimpleBorrower) { + ReferenceCounter borrower_rc; + auto borrower = std::make_shared(borrower_rc, "1"); + ReferenceCounter owner_rc(true, [&](const rpc::Address &addr) { return borrower; }); + auto owner = std::make_shared(owner_rc, "2"); + + // The owner creates an inner object and wraps it. + auto inner_id = ObjectID::FromRandom(); + auto outer_id = ObjectID::FromRandom(); + owner->Put(inner_id); + owner->PutWrappedId(outer_id, inner_id); + + // The owner submits a task that depends on the outer object. The task will + // be given a reference to inner_id. + owner->SubmitTaskWithArg(outer_id); + // The owner's references go out of scope. + owner_rc.RemoveLocalReference(outer_id, nullptr); + owner_rc.RemoveLocalReference(inner_id, nullptr); + // The owner's ref count > 0 for both objects. + ASSERT_TRUE(owner_rc.HasReference(outer_id)); + ASSERT_TRUE(owner_rc.HasReference(inner_id)); + + // The borrower is given a reference to the inner object. + borrower->ExecuteTaskWithArg(outer_id, inner_id, owner->task_id_, owner->address_); + // The borrower submits a task that depends on the inner object. + borrower->SubmitTaskWithArg(inner_id); + borrower_rc.RemoveLocalReference(inner_id, nullptr); + ASSERT_TRUE(borrower_rc.HasReference(inner_id)); + + // The borrower task returns to the owner without waiting for its submitted + // task to finish. + auto borrower_refs = borrower->FinishExecutingTask(outer_id, ObjectID::Nil()); + // ASSERT_FALSE(borrower_rc.HasReference(outer_id)); + // Check that the borrower's ref count for inner_id > 0 because of the + // pending task. + ASSERT_TRUE(borrower_rc.HasReference(inner_id)); + + // The owner receives the borrower's reply and merges the borrower's ref + // count into its own. + owner->HandleSubmittedTaskFinished(outer_id, borrower->address_, borrower_refs); + borrower->FlushBorrowerCallbacks(); + // Check that owner now has borrower in inner's borrowers list. + ASSERT_TRUE(owner_rc.HasReference(inner_id)); + // Check that owner's ref count for outer == 0 since the borrower task + // returned and there were no local references to outer_id. + ASSERT_FALSE(owner_rc.HasReference(outer_id)); + + // The task submitted by the borrower returns. Everyone's ref count should go + // to 0. + borrower->HandleSubmittedTaskFinished(inner_id); + ASSERT_FALSE(borrower_rc.HasReference(inner_id)); + ASSERT_FALSE(borrower_rc.HasReference(outer_id)); + ASSERT_FALSE(owner_rc.HasReference(inner_id)); + ASSERT_FALSE(owner_rc.HasReference(outer_id)); +} + +// A borrower is given a reference to an object ID, keeps the reference past +// the task's lifetime, then deletes the reference before it hears from the +// owner. +// +// @ray.remote +// class Borrower: +// def __init__(self, inner_ids): +// self.inner_id = inner_ids[0] +// +// inner_id = ray.put(1) +// outer_id = ray.put([inner_id]) +// res = Borrower.remote(outer_id) +TEST(DistributedReferenceCountTest, TestSimpleBorrowerReferenceRemoved) { + ReferenceCounter borrower_rc; + auto borrower = std::make_shared(borrower_rc, "1"); + ReferenceCounter owner_rc(true, [&](const rpc::Address &addr) { return borrower; }); + auto owner = std::make_shared(owner_rc, "2"); + + // The owner creates an inner object and wraps it. + auto inner_id = ObjectID::FromRandom(); + auto outer_id = ObjectID::FromRandom(); + owner->Put(inner_id); + owner->PutWrappedId(outer_id, inner_id); + + // The owner submits a task that depends on the outer object. The task will + // be given a reference to inner_id. + owner->SubmitTaskWithArg(outer_id); + // The owner's references go out of scope. + owner_rc.RemoveLocalReference(outer_id, nullptr); + owner_rc.RemoveLocalReference(inner_id, nullptr); + // The owner's ref count > 0 for both objects. + ASSERT_TRUE(owner_rc.HasReference(outer_id)); + ASSERT_TRUE(owner_rc.HasReference(inner_id)); + + // The borrower is given a reference to the inner object. + borrower->ExecuteTaskWithArg(outer_id, inner_id, owner->task_id_, owner->address_); + ASSERT_TRUE(borrower_rc.HasReference(inner_id)); + + // The borrower task returns to the owner while still using inner_id. + auto borrower_refs = borrower->FinishExecutingTask(outer_id, ObjectID::Nil()); + ASSERT_FALSE(borrower_rc.HasReference(outer_id)); + ASSERT_TRUE(borrower_rc.HasReference(inner_id)); + + // The owner receives the borrower's reply and merges the borrower's ref + // count into its own. + owner->HandleSubmittedTaskFinished(outer_id, borrower->address_, borrower_refs); + // Check that owner now has borrower in inner's borrowers list. + ASSERT_TRUE(owner_rc.HasReference(inner_id)); + // Check that owner's ref count for outer == 0 since the borrower task + // returned and there were no local references to outer_id. + ASSERT_FALSE(owner_rc.HasReference(outer_id)); + + // The borrower is no longer using inner_id, but it hasn't received the + // message from the owner yet. + borrower_rc.RemoveLocalReference(inner_id, nullptr); + ASSERT_FALSE(borrower_rc.HasReference(inner_id)); + ASSERT_TRUE(owner_rc.HasReference(inner_id)); + + // The borrower receives the owner's wait message. It should return a reply + // to the owner immediately saying that it is no longer using inner_id. + borrower->FlushBorrowerCallbacks(); + ASSERT_FALSE(borrower_rc.HasReference(inner_id)); + ASSERT_FALSE(owner_rc.HasReference(inner_id)); +} + +// A borrower is given a reference to an object ID, passes the reference to +// another borrower by submitting a task, and does not wait for it to finish. +// +// @ray.remote +// def borrower2(inner_ids): +// pass +// +// @ray.remote +// def borrower(inner_ids): +// borrower2.remote(inner_ids) +// +// inner_id = ray.put(1) +// outer_id = ray.put([inner_id]) +// res = borrower.remote(outer_id) +TEST(DistributedReferenceCountTest, TestBorrowerTree) { + ReferenceCounter borrower_rc1; + auto borrower1 = std::make_shared(borrower_rc1, "1"); + ReferenceCounter borrower_rc2; + auto borrower2 = std::make_shared(borrower_rc2, "2"); + ReferenceCounter owner_rc(true, [&](const rpc::Address &addr) { + if (addr.ip_address() == borrower1->address_.ip_address()) { + return borrower1; + } else { + return borrower2; + } + }); + auto owner = std::make_shared(owner_rc, "3"); + + // The owner creates an inner object and wraps it. + auto inner_id = ObjectID::FromRandom(); + auto outer_id = ObjectID::FromRandom(); + owner->Put(inner_id); + owner->PutWrappedId(outer_id, inner_id); + + // The owner submits a task that depends on the outer object. The task will + // be given a reference to inner_id. + owner->SubmitTaskWithArg(outer_id); + // The owner's references go out of scope. + owner_rc.RemoveLocalReference(outer_id, nullptr); + owner_rc.RemoveLocalReference(inner_id, nullptr); + // The owner's ref count > 0 for both objects. + ASSERT_TRUE(owner_rc.HasReference(outer_id)); + ASSERT_TRUE(owner_rc.HasReference(inner_id)); + + // Borrower 1 is given a reference to the inner object. + borrower1->ExecuteTaskWithArg(outer_id, inner_id, owner->task_id_, owner->address_); + // The borrower submits a task that depends on the inner object. + auto outer_id2 = ObjectID::FromRandom(); + borrower1->PutWrappedId(outer_id2, inner_id); + borrower1->SubmitTaskWithArg(outer_id2); + borrower_rc1.RemoveLocalReference(inner_id, nullptr); + borrower_rc1.RemoveLocalReference(outer_id2, nullptr); + ASSERT_TRUE(borrower_rc1.HasReference(inner_id)); + ASSERT_TRUE(borrower_rc1.HasReference(outer_id2)); + + // The borrower task returns to the owner without waiting for its submitted + // task to finish. + auto borrower_refs = borrower1->FinishExecutingTask(outer_id, ObjectID::Nil()); + ASSERT_TRUE(borrower_rc1.HasReference(inner_id)); + ASSERT_TRUE(borrower_rc1.HasReference(outer_id2)); + ASSERT_FALSE(borrower_rc1.HasReference(outer_id)); + + // The owner receives the borrower's reply and merges the borrower's ref + // count into its own. + owner->HandleSubmittedTaskFinished(outer_id, borrower1->address_, borrower_refs); + borrower1->FlushBorrowerCallbacks(); + // Check that owner now has borrower in inner's borrowers list. + ASSERT_TRUE(owner_rc.HasReference(inner_id)); + // Check that owner's ref count for outer == 0 since the borrower task + // returned and there were no local references to outer_id. + ASSERT_FALSE(owner_rc.HasReference(outer_id)); + + // Borrower 2 starts executing. It is given a reference to the inner object + // when it gets outer_id2 as an argument. + borrower2->ExecuteTaskWithArg(outer_id2, inner_id, owner->task_id_, owner->address_); + ASSERT_TRUE(borrower_rc2.HasReference(inner_id)); + // Borrower 2 finishes but it is still using inner_id. + borrower_refs = borrower2->FinishExecutingTask(outer_id2, ObjectID::Nil()); + ASSERT_TRUE(borrower_rc2.HasReference(inner_id)); + ASSERT_FALSE(borrower_rc2.HasReference(outer_id2)); + ASSERT_FALSE(borrower_rc2.HasReference(outer_id)); + + borrower1->HandleSubmittedTaskFinished(outer_id2, borrower2->address_, borrower_refs); + borrower2->FlushBorrowerCallbacks(); + // Borrower 1 no longer has a reference to any objects. + ASSERT_FALSE(borrower_rc1.HasReference(inner_id)); + ASSERT_FALSE(borrower_rc1.HasReference(outer_id2)); + // The owner should now have borrower 2 in its count. + ASSERT_TRUE(owner_rc.HasReference(inner_id)); + + borrower_rc2.RemoveLocalReference(inner_id, nullptr); + ASSERT_FALSE(borrower_rc2.HasReference(inner_id)); + ASSERT_FALSE(owner_rc.HasReference(inner_id)); +} + +// A task is given a reference to an object ID, whose value contains another +// object ID. The task gets a reference to the innermost object ID, but deletes +// it by the time the task finishes. +// +// @ray.remote +// def borrower(mid_ids): +// inner_id = ray.get(mid_ids[0]) +// del inner_id +// +// inner_id = ray.put(1) +// mid_id = ray.put([inner_id]) +// outer_id = ray.put([mid_id]) +// res = borrower.remote(outer_id) +TEST(DistributedReferenceCountTest, TestNestedObjectNoBorrow) { + ReferenceCounter borrower_rc; + auto borrower = std::make_shared(borrower_rc, "1"); + ReferenceCounter owner_rc(true, [&](const rpc::Address &addr) { return borrower; }); + auto owner = std::make_shared(owner_rc, "2"); + + // The owner creates an inner object and wraps it. + auto inner_id = ObjectID::FromRandom(); + auto mid_id = ObjectID::FromRandom(); + auto outer_id = ObjectID::FromRandom(); + owner->Put(inner_id); + owner->PutWrappedId(mid_id, inner_id); + owner->PutWrappedId(outer_id, mid_id); + + // The owner submits a task that depends on the outer object. The task will + // be given a reference to mid_id. + owner->SubmitTaskWithArg(outer_id); + // The owner's references go out of scope. + owner_rc.RemoveLocalReference(outer_id, nullptr); + owner_rc.RemoveLocalReference(mid_id, nullptr); + owner_rc.RemoveLocalReference(inner_id, nullptr); + // The owner's ref count > 0 for all objects. + ASSERT_TRUE(owner_rc.HasReference(outer_id)); + ASSERT_TRUE(owner_rc.HasReference(mid_id)); + ASSERT_TRUE(owner_rc.HasReference(inner_id)); + + // The borrower is given a reference to the middle object. + borrower->ExecuteTaskWithArg(outer_id, mid_id, owner->task_id_, owner->address_); + ASSERT_TRUE(borrower_rc.HasReference(mid_id)); + ASSERT_FALSE(borrower_rc.HasReference(inner_id)); + + // The borrower unwraps the inner object with ray.get. + borrower->GetSerializedObjectId(mid_id, inner_id, owner->task_id_, owner->address_); + borrower_rc.RemoveLocalReference(mid_id, nullptr); + ASSERT_TRUE(borrower_rc.HasReference(inner_id)); + // The borrower's reference to inner_id goes out of scope. + borrower_rc.RemoveLocalReference(inner_id, nullptr); + + // The borrower task returns to the owner. + auto borrower_refs = borrower->FinishExecutingTask(outer_id, ObjectID::Nil()); + ASSERT_FALSE(borrower_rc.HasReference(outer_id)); + ASSERT_FALSE(borrower_rc.HasReference(mid_id)); + ASSERT_FALSE(borrower_rc.HasReference(inner_id)); + + // The owner receives the borrower's reply and merges the borrower's ref + // count into its own. + owner->HandleSubmittedTaskFinished(outer_id, borrower->address_, borrower_refs); + // Check that owner now has nothing in scope. + ASSERT_FALSE(owner_rc.HasReference(outer_id)); + ASSERT_FALSE(owner_rc.HasReference(mid_id)); + ASSERT_FALSE(owner_rc.HasReference(inner_id)); +} + +// A task is given a reference to an object ID, whose value contains another +// object ID. The task gets a reference to the innermost object ID, and is +// still borrowing it by the time the task finishes. +// +// @ray.remote +// def borrower(mid_ids): +// inner_id = ray.get(mid_ids[0]) +// foo.remote(inner_id) +// +// inner_id = ray.put(1) +// mid_id = ray.put([inner_id]) +// outer_id = ray.put([mid_id]) +// res = borrower.remote(outer_id) +TEST(DistributedReferenceCountTest, TestNestedObject) { + ReferenceCounter borrower_rc; + auto borrower = std::make_shared(borrower_rc, "1"); + ReferenceCounter owner_rc(true, [&](const rpc::Address &addr) { return borrower; }); + auto owner = std::make_shared(owner_rc, "2"); + + // The owner creates an inner object and wraps it. + auto inner_id = ObjectID::FromRandom(); + auto mid_id = ObjectID::FromRandom(); + auto outer_id = ObjectID::FromRandom(); + owner->Put(inner_id); + owner->PutWrappedId(mid_id, inner_id); + owner->PutWrappedId(outer_id, mid_id); + + // The owner submits a task that depends on the outer object. The task will + // be given a reference to mid_id. + owner->SubmitTaskWithArg(outer_id); + // The owner's references go out of scope. + owner_rc.RemoveLocalReference(outer_id, nullptr); + owner_rc.RemoveLocalReference(mid_id, nullptr); + owner_rc.RemoveLocalReference(inner_id, nullptr); + // The owner's ref count > 0 for all objects. + ASSERT_TRUE(owner_rc.HasReference(outer_id)); + ASSERT_TRUE(owner_rc.HasReference(mid_id)); + ASSERT_TRUE(owner_rc.HasReference(inner_id)); + + // The borrower is given a reference to the middle object. + borrower->ExecuteTaskWithArg(outer_id, mid_id, owner->task_id_, owner->address_); + ASSERT_TRUE(borrower_rc.HasReference(mid_id)); + ASSERT_FALSE(borrower_rc.HasReference(inner_id)); + + // The borrower unwraps the inner object with ray.get. + borrower->GetSerializedObjectId(mid_id, inner_id, owner->task_id_, owner->address_); + borrower_rc.RemoveLocalReference(mid_id, nullptr); + ASSERT_TRUE(borrower_rc.HasReference(inner_id)); + + // The borrower task returns to the owner while still using inner_id. + auto borrower_refs = borrower->FinishExecutingTask(outer_id, ObjectID::Nil()); + ASSERT_FALSE(borrower_rc.HasReference(outer_id)); + ASSERT_FALSE(borrower_rc.HasReference(mid_id)); + ASSERT_TRUE(borrower_rc.HasReference(inner_id)); + + // The owner receives the borrower's reply and merges the borrower's ref + // count into its own. + owner->HandleSubmittedTaskFinished(outer_id, borrower->address_, borrower_refs); + // Check that owner now has borrower in inner's borrowers list. + ASSERT_TRUE(owner_rc.HasReference(inner_id)); + // Check that owner's ref count for outer and mid are 0 since the borrower + // task returned and there were no local references to outer_id. + ASSERT_FALSE(owner_rc.HasReference(outer_id)); + ASSERT_FALSE(owner_rc.HasReference(mid_id)); + + // The borrower receives the owner's wait message. It should return a reply + // to the owner immediately saying that it is no longer using inner_id. + borrower->FlushBorrowerCallbacks(); + ASSERT_TRUE(borrower_rc.HasReference(inner_id)); + ASSERT_TRUE(owner_rc.HasReference(inner_id)); + + // The borrower is no longer using inner_id, but it hasn't received the + // message from the owner yet. + borrower_rc.RemoveLocalReference(inner_id, nullptr); + ASSERT_FALSE(borrower_rc.HasReference(inner_id)); + ASSERT_FALSE(owner_rc.HasReference(inner_id)); +} + +// A borrower is given a reference to an object ID, whose value contains +// another object ID. The borrower passes the reference again to another +// borrower and waits for it to finish. The nested borrower unwraps the outer +// object and gets a reference to the innermost ID. +// +// @ray.remote +// def borrower2(owner_id2): +// owner_id1 = ray.get(owner_id2[0])[0] +// foo.remote(owner_id1) +// +// @ray.remote +// def borrower1(owner_id2): +// ray.get(borrower2.remote(owner_id2)) +// +// owner_id1 = ray.put(1) +// owner_id2 = ray.put([owner_id1]) +// owner_id3 = ray.put([owner_id2]) +// res = borrower1.remote(owner_id3) +TEST(DistributedReferenceCountTest, TestNestedObjectDifferentOwners) { + ReferenceCounter borrower_rc1; + auto borrower1 = std::make_shared(borrower_rc1, "1"); + ReferenceCounter borrower_rc2; + auto borrower2 = std::make_shared(borrower_rc2, "2"); + ReferenceCounter owner_rc(true, [&](const rpc::Address &addr) { + if (addr.ip_address() == borrower1->address_.ip_address()) { + return borrower1; + } else { + return borrower2; + } + }); + auto owner = std::make_shared(owner_rc, "3"); + + // The owner creates an inner object and wraps it. + auto owner_id1 = ObjectID::FromRandom(); + auto owner_id2 = ObjectID::FromRandom(); + auto owner_id3 = ObjectID::FromRandom(); + owner->Put(owner_id1); + owner->PutWrappedId(owner_id2, owner_id1); + owner->PutWrappedId(owner_id3, owner_id2); + + // The owner submits a task that depends on the outer object. The task will + // be given a reference to owner_id2. + owner->SubmitTaskWithArg(owner_id3); + // The owner's references go out of scope. + owner_rc.RemoveLocalReference(owner_id1, nullptr); + owner_rc.RemoveLocalReference(owner_id2, nullptr); + owner_rc.RemoveLocalReference(owner_id3, nullptr); + + // The borrower is given a reference to the middle object. + borrower1->ExecuteTaskWithArg(owner_id3, owner_id2, owner->task_id_, owner->address_); + ASSERT_TRUE(borrower_rc1.HasReference(owner_id2)); + ASSERT_FALSE(borrower_rc1.HasReference(owner_id1)); + + // The borrower wraps the object ID again. + auto borrower_id = ObjectID::FromRandom(); + borrower1->PutWrappedId(borrower_id, owner_id2); + borrower_rc1.RemoveLocalReference(owner_id2, nullptr); + + // Borrower 1 submits a task that depends on the wrapped object. The task + // will be given a reference to owner_id2. + borrower1->SubmitTaskWithArg(borrower_id); + borrower_rc1.RemoveLocalReference(borrower_id, nullptr); + borrower2->ExecuteTaskWithArg(borrower_id, owner_id2, owner->task_id_, owner->address_); + + // The nested task returns while still using owner_id1. + borrower2->GetSerializedObjectId(owner_id2, owner_id1, owner->task_id_, + owner->address_); + borrower_rc2.RemoveLocalReference(owner_id2, nullptr); + auto borrower_refs = borrower2->FinishExecutingTask(borrower_id, ObjectID::Nil()); + ASSERT_TRUE(borrower_rc2.HasReference(owner_id1)); + ASSERT_FALSE(borrower_rc2.HasReference(owner_id2)); + + // Borrower 1 should now know that borrower 2 is borrowing the inner object + // ID. + borrower1->HandleSubmittedTaskFinished(borrower_id, borrower2->address_, borrower_refs); + ASSERT_TRUE(borrower_rc1.HasReference(owner_id1)); + + // Borrower 1 finishes. It should not have any references now because all + // state has been merged into the owner. + borrower_refs = borrower1->FinishExecutingTask(owner_id3, ObjectID::Nil()); + ASSERT_FALSE(borrower_rc1.HasReference(owner_id1)); + ASSERT_FALSE(borrower_rc1.HasReference(owner_id2)); + ASSERT_FALSE(borrower_rc1.HasReference(owner_id3)); + ASSERT_FALSE(borrower_rc1.HasReference(borrower_id)); + + // The owner receives the borrower's reply and merges the borrower's ref + // count into its own. + owner->HandleSubmittedTaskFinished(owner_id3, borrower1->address_, borrower_refs); + // Check that owner now has borrower2 in inner's borrowers list. + ASSERT_TRUE(owner_rc.HasReference(owner_id1)); + ASSERT_FALSE(owner_rc.HasReference(owner_id2)); + ASSERT_FALSE(owner_rc.HasReference(owner_id3)); + + // The borrower receives the owner's wait message. + borrower2->FlushBorrowerCallbacks(); + ASSERT_TRUE(owner_rc.HasReference(owner_id1)); + borrower_rc2.RemoveLocalReference(owner_id1, nullptr); + ASSERT_FALSE(borrower_rc2.HasReference(owner_id1)); + ASSERT_FALSE(owner_rc.HasReference(owner_id1)); +} + +// A borrower is given a reference to an object ID, whose value contains +// another object ID. The borrower passes the reference again to another +// borrower but does not wait for it to finish. The nested borrower unwraps the +// outer object and gets a reference to the innermost ID. +// +// @ray.remote +// def borrower2(owner_id2): +// owner_id1 = ray.get(owner_id2[0])[0] +// foo.remote(owner_id1) +// +// @ray.remote +// def borrower1(owner_id2): +// borrower2.remote(owner_id2) +// +// owner_id1 = ray.put(1) +// owner_id2 = ray.put([owner_id1]) +// owner_id3 = ray.put([owner_id2]) +// res = borrower1.remote(owner_id3) +TEST(DistributedReferenceCountTest, TestNestedObjectDifferentOwners2) { + ReferenceCounter borrower_rc1; + auto borrower1 = std::make_shared(borrower_rc1, "1"); + ReferenceCounter borrower_rc2; + auto borrower2 = std::make_shared(borrower_rc2, "2"); + ReferenceCounter owner_rc(true, [&](const rpc::Address &addr) { + if (addr.ip_address() == borrower1->address_.ip_address()) { + return borrower1; + } else { + return borrower2; + } + }); + auto owner = std::make_shared(owner_rc, "3"); + + // The owner creates an inner object and wraps it. + auto owner_id1 = ObjectID::FromRandom(); + auto owner_id2 = ObjectID::FromRandom(); + auto owner_id3 = ObjectID::FromRandom(); + owner->Put(owner_id1); + owner->PutWrappedId(owner_id2, owner_id1); + owner->PutWrappedId(owner_id3, owner_id2); + + // The owner submits a task that depends on the outer object. The task will + // be given a reference to owner_id2. + owner->SubmitTaskWithArg(owner_id3); + // The owner's references go out of scope. + owner_rc.RemoveLocalReference(owner_id1, nullptr); + owner_rc.RemoveLocalReference(owner_id2, nullptr); + owner_rc.RemoveLocalReference(owner_id3, nullptr); + + // The borrower is given a reference to the middle object. + borrower1->ExecuteTaskWithArg(owner_id3, owner_id2, owner->task_id_, owner->address_); + ASSERT_TRUE(borrower_rc1.HasReference(owner_id2)); + ASSERT_FALSE(borrower_rc1.HasReference(owner_id1)); + + // The borrower wraps the object ID again. + auto borrower_id = ObjectID::FromRandom(); + borrower1->PutWrappedId(borrower_id, owner_id2); + borrower_rc1.RemoveLocalReference(owner_id2, nullptr); + + // Borrower 1 submits a task that depends on the wrapped object. The task + // will be given a reference to owner_id2. + borrower1->SubmitTaskWithArg(borrower_id); + borrower2->ExecuteTaskWithArg(borrower_id, owner_id2, owner->task_id_, owner->address_); + + // The nested task returns while still using owner_id1. + borrower2->GetSerializedObjectId(owner_id2, owner_id1, owner->task_id_, + owner->address_); + borrower_rc2.RemoveLocalReference(owner_id2, nullptr); + auto borrower_refs = borrower2->FinishExecutingTask(borrower_id, ObjectID::Nil()); + ASSERT_TRUE(borrower_rc2.HasReference(owner_id1)); + ASSERT_FALSE(borrower_rc2.HasReference(owner_id2)); + + // Borrower 1 should now know that borrower 2 is borrowing the inner object + // ID. + borrower1->HandleSubmittedTaskFinished(borrower_id, borrower2->address_, borrower_refs); + ASSERT_TRUE(borrower_rc1.HasReference(owner_id1)); + ASSERT_TRUE(borrower_rc1.HasReference(owner_id2)); + + // Borrower 1 finishes. It should only have its reference to owner_id2 now. + borrower_refs = borrower1->FinishExecutingTask(owner_id3, ObjectID::Nil()); + ASSERT_TRUE(borrower_rc1.HasReference(owner_id2)); + ASSERT_FALSE(borrower_rc1.HasReference(owner_id3)); + + // The owner receives the borrower's reply and merges the borrower's ref + // count into its own. + owner->HandleSubmittedTaskFinished(owner_id3, borrower1->address_, borrower_refs); + // Check that owner now has borrower2 in inner's borrowers list. + ASSERT_TRUE(owner_rc.HasReference(owner_id1)); + ASSERT_TRUE(owner_rc.HasReference(owner_id2)); + ASSERT_FALSE(owner_rc.HasReference(owner_id3)); + + // The borrower receives the owner's wait message. + borrower2->FlushBorrowerCallbacks(); + ASSERT_TRUE(owner_rc.HasReference(owner_id1)); + borrower_rc2.RemoveLocalReference(owner_id1, nullptr); + ASSERT_FALSE(borrower_rc2.HasReference(owner_id1)); + ASSERT_TRUE(owner_rc.HasReference(owner_id1)); + + // The borrower receives the owner's wait message. + borrower1->FlushBorrowerCallbacks(); + ASSERT_TRUE(owner_rc.HasReference(owner_id2)); + borrower_rc1.RemoveLocalReference(borrower_id, nullptr); + ASSERT_FALSE(borrower_rc1.HasReference(owner_id2)); + ASSERT_FALSE(borrower_rc1.HasReference(owner_id1)); + ASSERT_FALSE(owner_rc.HasReference(owner_id2)); +} + +// A borrower is given a reference to an object ID and passes the reference to +// another task. The nested task executes on the object's owner. +// +// @ray.remote +// def executes_on_owner(inner_ids): +// inner_id = inner_ids[0] +// +// @ray.remote +// def borrower(inner_ids): +// outer_id2 = ray.put(inner_ids) +// executes_on_owner.remote(outer_id2) +// +// inner_id = ray.put(1) +// outer_id = ray.put([inner_id]) +// res = borrower.remote(outer_id) +TEST(DistributedReferenceCountTest, TestBorrowerPingPong) { + ReferenceCounter borrower_rc; + auto borrower = std::make_shared(borrower_rc, "1"); + ReferenceCounter owner_rc(true, [&](const rpc::Address &addr) { + RAY_CHECK(addr.ip_address() == borrower->address_.ip_address()); + return borrower; + }); + auto owner = std::make_shared(owner_rc, "2"); + + // The owner creates an inner object and wraps it. + auto inner_id = ObjectID::FromRandom(); + auto outer_id = ObjectID::FromRandom(); + owner->Put(inner_id); + owner->PutWrappedId(outer_id, inner_id); + + // The owner submits a task that depends on the outer object. The task will + // be given a reference to inner_id. + owner->SubmitTaskWithArg(outer_id); + // The owner's references go out of scope. + owner_rc.RemoveLocalReference(outer_id, nullptr); + owner_rc.RemoveLocalReference(inner_id, nullptr); + + // Borrower 1 is given a reference to the inner object. + borrower->ExecuteTaskWithArg(outer_id, inner_id, owner->task_id_, owner->address_); + // The borrower submits a task that depends on the inner object. + auto outer_id2 = ObjectID::FromRandom(); + borrower->PutWrappedId(outer_id2, inner_id); + borrower->SubmitTaskWithArg(outer_id2); + borrower_rc.RemoveLocalReference(inner_id, nullptr); + borrower_rc.RemoveLocalReference(outer_id2, nullptr); + ASSERT_TRUE(borrower_rc.HasReference(inner_id)); + ASSERT_TRUE(borrower_rc.HasReference(outer_id2)); + + // The borrower task returns to the owner without waiting for its submitted + // task to finish. + auto borrower_refs = borrower->FinishExecutingTask(outer_id, ObjectID::Nil()); + ASSERT_TRUE(borrower_rc.HasReference(inner_id)); + ASSERT_TRUE(borrower_rc.HasReference(outer_id2)); + ASSERT_FALSE(borrower_rc.HasReference(outer_id)); + + // The owner receives the borrower's reply and merges the borrower's ref + // count into its own. + owner->HandleSubmittedTaskFinished(outer_id, borrower->address_, borrower_refs); + borrower->FlushBorrowerCallbacks(); + // Check that owner now has a borrower for inner. + ASSERT_TRUE(owner_rc.HasReference(inner_id)); + // Check that owner's ref count for outer == 0 since the borrower task + // returned and there were no local references to outer_id. + ASSERT_FALSE(owner_rc.HasReference(outer_id)); + + // Owner starts executing the submitted task. It is given a second reference + // to the inner object when it gets outer_id2 as an argument. + owner->ExecuteTaskWithArg(outer_id2, inner_id, owner->task_id_, owner->address_); + ASSERT_TRUE(owner_rc.HasReference(inner_id)); + // Owner finishes but it is still using inner_id. + borrower_refs = owner->FinishExecutingTask(outer_id2, ObjectID::Nil()); + ASSERT_TRUE(owner_rc.HasReference(inner_id)); + + borrower->HandleSubmittedTaskFinished(outer_id2, owner->address_, borrower_refs); + borrower->FlushBorrowerCallbacks(); + // Borrower no longer has a reference to any objects. + ASSERT_FALSE(borrower_rc.HasReference(inner_id)); + ASSERT_FALSE(borrower_rc.HasReference(outer_id2)); + // The owner should now have borrower 2 in its count. + ASSERT_TRUE(owner_rc.HasReference(inner_id)); + owner_rc.RemoveLocalReference(inner_id, nullptr); + ASSERT_FALSE(owner_rc.HasReference(inner_id)); +} + +// A borrower is given two references to the same object ID. `task` and `Actor` +// execute on the same process. +// +// @ray.remote +// def task(inner_ids): +// foo.remote(inner_ids[0]) +// +// @ray.remote +// class Actor: +// def __init__(self, inner_ids): +// self.inner_id = inner_ids[0] +// +// inner_id = ray.put(1) +// outer_id = ray.put([inner_id]) +// res = task.remote(outer_id) +// Actor.remote(outer_id) +TEST(DistributedReferenceCountTest, TestDuplicateBorrower) { + ReferenceCounter borrower_rc; + auto borrower = std::make_shared(borrower_rc, "1"); + ReferenceCounter owner_rc(true, [&](const rpc::Address &addr) { return borrower; }); + auto owner = std::make_shared(owner_rc, "2"); + + // The owner creates an inner object and wraps it. + auto inner_id = ObjectID::FromRandom(); + auto outer_id = ObjectID::FromRandom(); + owner->Put(inner_id); + owner->PutWrappedId(outer_id, inner_id); + + // The owner submits a task that depends on the outer object. The task will + // be given a reference to inner_id. + owner->SubmitTaskWithArg(outer_id); + // The owner's references go out of scope. + owner_rc.RemoveLocalReference(inner_id, nullptr); + ASSERT_TRUE(owner_rc.HasReference(inner_id)); + + // The borrower is given a reference to the inner object. + borrower->ExecuteTaskWithArg(outer_id, inner_id, owner->task_id_, owner->address_); + // The borrower submits a task that depends on the inner object. + borrower->SubmitTaskWithArg(inner_id); + borrower_rc.RemoveLocalReference(inner_id, nullptr); + ASSERT_TRUE(borrower_rc.HasReference(inner_id)); + + // The borrower task returns to the owner without waiting for its submitted + // task to finish. + auto borrower_refs1 = borrower->FinishExecutingTask(outer_id, ObjectID::Nil()); + // Check that the borrower's ref count for inner_id > 0 because of the + // pending task. + ASSERT_TRUE(borrower_rc.HasReference(inner_id)); + + // The borrower is given a 2nd reference to the inner object. + owner->SubmitTaskWithArg(outer_id); + owner_rc.RemoveLocalReference(outer_id, nullptr); + borrower->ExecuteTaskWithArg(outer_id, inner_id, owner->task_id_, owner->address_); + auto borrower_refs2 = borrower->FinishExecutingTask(outer_id, ObjectID::Nil()); + + // The owner receives the borrower's replies and merges the borrower's ref + // count into its own. + owner->HandleSubmittedTaskFinished(outer_id, borrower->address_, borrower_refs1); + owner->HandleSubmittedTaskFinished(outer_id, borrower->address_, borrower_refs2); + borrower->FlushBorrowerCallbacks(); + // Check that owner now has borrower in inner's borrowers list. + ASSERT_TRUE(owner_rc.HasReference(inner_id)); + // Check that owner's ref count for outer == 0 since the borrower task + // returned and there were no local references to outer_id. + ASSERT_FALSE(owner_rc.HasReference(outer_id)); + + // The task submitted by the borrower returns and its second reference goes + // out of scope. Everyone's ref count should go to 0. + borrower->HandleSubmittedTaskFinished(inner_id); + ASSERT_TRUE(owner_rc.HasReference(inner_id)); + borrower_rc.RemoveLocalReference(inner_id, nullptr); + ASSERT_FALSE(owner_rc.HasReference(inner_id)); + ASSERT_FALSE(borrower_rc.HasReference(inner_id)); + ASSERT_FALSE(borrower_rc.HasReference(outer_id)); + ASSERT_FALSE(owner_rc.HasReference(outer_id)); +} + +// A borrower is given references to 2 different objects, which each contain a +// reference to an object ID. The borrower unwraps both objects and receives a +// duplicate reference to the inner ID. +TEST(DistributedReferenceCountTest, TestDuplicateNestedObject) { + ReferenceCounter borrower_rc1; + auto borrower1 = std::make_shared(borrower_rc1, "1"); + ReferenceCounter borrower_rc2; + auto borrower2 = std::make_shared(borrower_rc2, "2"); + ReferenceCounter owner_rc(true, [&](const rpc::Address &addr) { + if (addr.ip_address() == borrower1->address_.ip_address()) { + return borrower1; + } else { + return borrower2; + } + }); + auto owner = std::make_shared(owner_rc, "3"); + + // The owner creates an inner object and wraps it. + auto owner_id1 = ObjectID::FromRandom(); + auto owner_id2 = ObjectID::FromRandom(); + auto owner_id3 = ObjectID::FromRandom(); + owner->Put(owner_id1); + owner->PutWrappedId(owner_id2, owner_id1); + owner->PutWrappedId(owner_id3, owner_id2); + + owner->SubmitTaskWithArg(owner_id3); + owner->SubmitTaskWithArg(owner_id2); + owner_rc.RemoveLocalReference(owner_id1, nullptr); + owner_rc.RemoveLocalReference(owner_id2, nullptr); + owner_rc.RemoveLocalReference(owner_id3, nullptr); + + borrower2->ExecuteTaskWithArg(owner_id3, owner_id2, owner->task_id_, owner->address_); + borrower2->GetSerializedObjectId(owner_id2, owner_id1, owner->task_id_, + owner->address_); + borrower_rc2.RemoveLocalReference(owner_id2, nullptr); + // The nested task returns while still using owner_id1. + auto borrower_refs = borrower2->FinishExecutingTask(owner_id3, ObjectID::Nil()); + owner->HandleSubmittedTaskFinished(owner_id3, borrower2->address_, borrower_refs); + ASSERT_TRUE(borrower2->FlushBorrowerCallbacks()); + + // The owner submits a task that is given a reference to owner_id1. + borrower1->ExecuteTaskWithArg(owner_id2, owner_id1, owner->task_id_, owner->address_); + // The borrower wraps the object ID again. + auto borrower_id = ObjectID::FromRandom(); + borrower1->PutWrappedId(borrower_id, owner_id1); + borrower_rc1.RemoveLocalReference(owner_id1, nullptr); + // Borrower 1 submits a task that depends on the wrapped object. The task + // will be given a reference to owner_id1. + borrower1->SubmitTaskWithArg(borrower_id); + borrower_rc1.RemoveLocalReference(borrower_id, nullptr); + borrower2->ExecuteTaskWithArg(borrower_id, owner_id1, owner->task_id_, owner->address_); + // The nested task returns while still using owner_id1. + // It should now have 2 local references to owner_id1, one from the owner and + // one from the borrower. + borrower_refs = borrower2->FinishExecutingTask(borrower_id, ObjectID::Nil()); + borrower1->HandleSubmittedTaskFinished(borrower_id, borrower2->address_, borrower_refs); + + // Borrower 1 finishes. It should not have any references now because all + // state has been merged into the owner. + borrower_refs = borrower1->FinishExecutingTask(owner_id2, ObjectID::Nil()); + ASSERT_FALSE(borrower_rc1.HasReference(owner_id1)); + ASSERT_FALSE(borrower_rc1.HasReference(owner_id2)); + ASSERT_FALSE(borrower_rc1.HasReference(owner_id3)); + ASSERT_FALSE(borrower_rc1.HasReference(borrower_id)); + // Borrower 1 should not have merge any refs into the owner because borrower 2's ref was + // already merged into the owner. + owner->HandleSubmittedTaskFinished(owner_id2, borrower1->address_, borrower_refs); + + // The borrower receives the owner's wait message. + borrower2->FlushBorrowerCallbacks(); + ASSERT_TRUE(owner_rc.HasReference(owner_id1)); + borrower_rc2.RemoveLocalReference(owner_id1, nullptr); + ASSERT_TRUE(owner_rc.HasReference(owner_id1)); + borrower_rc2.RemoveLocalReference(owner_id1, nullptr); + ASSERT_FALSE(borrower_rc2.HasReference(owner_id1)); + ASSERT_FALSE(owner_rc.HasReference(owner_id1)); +} + +// We submit a task and immediately delete the reference to the return ID. The +// submitted task returns an object ID. +// +// @ray.remote +// def returns_id(): +// inner_id = ray.put() +// return inner_id +// +// returns_id.remote() +TEST(DistributedReferenceCountTest, TestReturnObjectIdNoBorrow) { + ReferenceCounter caller_rc; + auto caller = std::make_shared(caller_rc, "1"); + ReferenceCounter owner_rc(true, [&](const rpc::Address &addr) { + RAY_CHECK(addr.ip_address() == caller->address_.ip_address()); + return caller; + }); + auto owner = std::make_shared(owner_rc, "3"); + + // Caller submits a task. + auto return_id = caller->SubmitTaskWithArg(ObjectID::Nil()); + + // Task returns inner_id as its return value. + auto inner_id = ObjectID::FromRandom(); + owner->Put(inner_id); + rpc::WorkerAddress addr(caller->address_); + auto refs = owner->FinishExecutingTask(ObjectID::Nil(), return_id, &inner_id, &addr); + owner_rc.RemoveLocalReference(inner_id, nullptr); + ASSERT_TRUE(refs.empty()); + ASSERT_TRUE(owner_rc.HasReference(inner_id)); + + // Caller's ref to the task's return ID goes out of scope before it hears + // from the owner of inner_id. + caller->HandleSubmittedTaskFinished(ObjectID::Nil()); + caller_rc.RemoveLocalReference(return_id, nullptr); + ASSERT_FALSE(caller_rc.HasReference(return_id)); + ASSERT_FALSE(caller_rc.HasReference(inner_id)); + + // Caller should respond to the owner's message immediately. + ASSERT_TRUE(caller->FlushBorrowerCallbacks()); + ASSERT_FALSE(owner_rc.HasReference(inner_id)); +} + +// We submit a task and keep the reference to the return ID. The submitted task +// returns an object ID. +// +// @ray.remote +// def returns_id(): +// inner_id = ray.put() +// return inner_id +// +// return_id = returns_id.remote() +TEST(DistributedReferenceCountTest, TestReturnObjectIdBorrow) { + ReferenceCounter caller_rc; + auto caller = std::make_shared(caller_rc, "1"); + ReferenceCounter owner_rc(true, [&](const rpc::Address &addr) { + RAY_CHECK(addr.ip_address() == caller->address_.ip_address()); + return caller; + }); + auto owner = std::make_shared(owner_rc, "3"); + + // Caller submits a task. + auto return_id = caller->SubmitTaskWithArg(ObjectID::Nil()); + + // Task returns inner_id as its return value. + auto inner_id = ObjectID::FromRandom(); + owner->Put(inner_id); + rpc::WorkerAddress addr(caller->address_); + auto refs = owner->FinishExecutingTask(ObjectID::Nil(), return_id, &inner_id, &addr); + owner_rc.RemoveLocalReference(inner_id, nullptr); + ASSERT_TRUE(refs.empty()); + ASSERT_TRUE(owner_rc.HasReference(inner_id)); + + // Caller receives the owner's message, but inner_id is still in scope + // because caller has a reference to return_id. + caller->HandleSubmittedTaskFinished(ObjectID::Nil()); + ASSERT_TRUE(caller->FlushBorrowerCallbacks()); + ASSERT_TRUE(owner_rc.HasReference(inner_id)); + + // Caller's reference to return_id goes out of scope. The caller should + // respond to the owner of inner_id so that inner_id can be deleted. + caller_rc.RemoveLocalReference(return_id, nullptr); + ASSERT_FALSE(caller_rc.HasReference(return_id)); + ASSERT_FALSE(caller_rc.HasReference(inner_id)); + ASSERT_FALSE(owner_rc.HasReference(inner_id)); +} + +// We submit a task and submit another task that depends on the return ID. The +// submitted task returns an object ID, which will get borrowed by the second +// task. +// +// @ray.remote +// def returns_id(): +// inner_id = ray.put() +// return inner_id +// +// return_id = returns_id.remote() +// borrow.remote(return_id) +TEST(DistributedReferenceCountTest, TestReturnObjectIdBorrowChain) { + ReferenceCounter caller_rc; + auto caller = std::make_shared(caller_rc, "1"); + ReferenceCounter borrower_rc; + auto borrower = std::make_shared(borrower_rc, "2"); + ReferenceCounter owner_rc(true, [&](const rpc::Address &addr) { + if (addr.ip_address() == caller->address_.ip_address()) { + return caller; + } else { + return borrower; + } + }); + auto owner = std::make_shared(owner_rc, "3"); + + // Caller submits a task. + auto return_id = caller->SubmitTaskWithArg(ObjectID::Nil()); + + // Task returns inner_id as its return value. + auto inner_id = ObjectID::FromRandom(); + owner->Put(inner_id); + rpc::WorkerAddress addr(caller->address_); + auto refs = owner->FinishExecutingTask(ObjectID::Nil(), return_id, &inner_id, &addr); + owner_rc.RemoveLocalReference(inner_id, nullptr); + ASSERT_TRUE(refs.empty()); + ASSERT_TRUE(owner_rc.HasReference(inner_id)); + + // Caller receives the owner's message, but inner_id is still in scope + // because caller has a reference to return_id. + caller->HandleSubmittedTaskFinished(ObjectID::Nil()); + caller->SubmitTaskWithArg(return_id); + caller_rc.RemoveLocalReference(return_id, nullptr); + ASSERT_TRUE(caller->FlushBorrowerCallbacks()); + ASSERT_TRUE(owner_rc.HasReference(inner_id)); + + // Borrower receives a reference to inner_id. It still has a reference when + // the task returns. + borrower->ExecuteTaskWithArg(return_id, inner_id, owner->task_id_, owner->address_); + ASSERT_TRUE(borrower_rc.HasReference(inner_id)); + auto borrower_refs = borrower->FinishExecutingTask(return_id, return_id); + ASSERT_TRUE(borrower_rc.HasReference(inner_id)); + + // Borrower merges ref count into the caller. + caller->HandleSubmittedTaskFinished(return_id, borrower->address_, borrower_refs); + // The caller should not have a ref count anymore because it was merged into + // the owner. + ASSERT_FALSE(caller_rc.HasReference(return_id)); + ASSERT_FALSE(caller_rc.HasReference(inner_id)); + ASSERT_TRUE(owner_rc.HasReference(inner_id)); + + // The borrower's receives the owner's message and its reference goes out of + // scope. + ASSERT_TRUE(borrower->FlushBorrowerCallbacks()); + borrower_rc.RemoveLocalReference(inner_id, nullptr); + ASSERT_FALSE(borrower_rc.HasReference(return_id)); + ASSERT_FALSE(borrower_rc.HasReference(inner_id)); + ASSERT_FALSE(owner_rc.HasReference(inner_id)); +} + +// TODO: Test returning an Object ID. +// TODO: Test Pop and Merge individually. + } // namespace ray int main(int argc, char **argv) { diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.cc b/src/ray/core_worker/store_provider/memory_store/memory_store.cc index a6e19e08c..210cda9c5 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.cc +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.cc @@ -155,8 +155,8 @@ std::shared_ptr CoreWorkerMemoryStore::GetOrPromoteToPlasma( Status CoreWorkerMemoryStore::Put(const RayObject &object, const ObjectID &object_id) { RAY_CHECK(object_id.IsDirectCallType()); std::vector)>> async_callbacks; - auto object_entry = - std::make_shared(object.GetData(), object.GetMetadata(), true); + auto object_entry = std::make_shared(object.GetData(), object.GetMetadata(), + object.GetInlinedIds(), true); { absl::MutexLock lock(&mu_); @@ -236,6 +236,7 @@ Status CoreWorkerMemoryStore::Get(const std::vector &object_ids, if (iter != objects_.end()) { (*results)[i] = iter->second; if (remove_after_get) { + RAY_LOG(ERROR) << "REMOVE_AFTER_GET"; // Note that we cannot remove the object_id from `objects_` now, // because `object_ids` might have duplicate ids. ids_to_remove.insert(object_id); @@ -353,8 +354,8 @@ Status CoreWorkerMemoryStore::Get( bool *got_exception) { const std::vector id_vector(object_ids.begin(), object_ids.end()); std::vector> result_objects; - RAY_RETURN_NOT_OK( - Get(id_vector, id_vector.size(), timeout_ms, ctx, true, &result_objects)); + RAY_RETURN_NOT_OK(Get(id_vector, id_vector.size(), timeout_ms, ctx, + /*remove_after_get=*/false, &result_objects)); for (size_t i = 0; i < id_vector.size(); i++) { if (result_objects[i] != nullptr) { diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.cc b/src/ray/core_worker/store_provider/plasma_store_provider.cc index c812a88e2..2b5624d2b 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.cc +++ b/src/ray/core_worker/store_provider/plasma_store_provider.cc @@ -28,7 +28,8 @@ Status CoreWorkerPlasmaStoreProvider::SetClientOptions(std::string name, } Status CoreWorkerPlasmaStoreProvider::Put(const RayObject &object, - const ObjectID &object_id) { + const ObjectID &object_id, + bool *object_exists) { RAY_CHECK(!object.IsInPlasmaError()) << object_id; std::shared_ptr data; RAY_RETURN_NOT_OK(Create(object.GetMetadata(), @@ -41,6 +42,11 @@ Status CoreWorkerPlasmaStoreProvider::Put(const RayObject &object, memcpy(data->Data(), object.GetData()->Data(), object.GetData()->Size()); } RAY_RETURN_NOT_OK(Seal(object_id)); + if (object_exists) { + *object_exists = false; + } + } else if (object_exists) { + *object_exists = true; } return Status::OK(); } @@ -116,7 +122,8 @@ Status CoreWorkerPlasmaStoreProvider::FetchAndGetFromPlasmaStore( if (plasma_results[i].metadata && plasma_results[i].metadata->size()) { metadata = std::make_shared(plasma_results[i].metadata); } - const auto result_object = std::make_shared(data, metadata); + const auto result_object = + std::make_shared(data, metadata, std::vector()); (*results)[object_id] = result_object; remaining.erase(object_id); if (result_object->IsException()) { diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.h b/src/ray/core_worker/store_provider/plasma_store_provider.h index 1b545ac7d..992b339bd 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.h +++ b/src/ray/core_worker/store_provider/plasma_store_provider.h @@ -27,7 +27,15 @@ class CoreWorkerPlasmaStoreProvider { Status SetClientOptions(std::string name, int64_t limit_bytes); - Status Put(const RayObject &object, const ObjectID &object_id); + /// Create and seal an object. + /// + /// \param[in] object The object to create. + /// \param[in] object_id The ID of the object. This can be used as an + /// argument to Get to retrieve the object data. + /// \param[out] object_exists Optional. Returns whether an object with the + /// same ID already exists. If this is true, then the Put does not write any + /// object data. + Status Put(const RayObject &object, const ObjectID &object_id, bool *object_exists); Status Create(const std::shared_ptr &metadata, const size_t data_size, const ObjectID &object_id, std::shared_ptr *data); diff --git a/src/ray/core_worker/task_manager.cc b/src/ray/core_worker/task_manager.cc index 18a723406..9fbf30f1f 100644 --- a/src/ray/core_worker/task_manager.cc +++ b/src/ray/core_worker/task_manager.cc @@ -24,6 +24,13 @@ void TaskManager::AddPendingTask(const TaskID &caller_id, if (spec.ArgByRef(i)) { for (size_t j = 0; j < spec.ArgIdCount(i); j++) { task_deps.push_back(spec.ArgId(i, j)); + RAY_LOG(DEBUG) << "Adding arg ID " << spec.ArgId(i, j); + } + } else { + const auto &inlined_ids = spec.ArgInlinedIds(i); + for (const auto &inlined_id : inlined_ids) { + task_deps.push_back(inlined_id); + RAY_LOG(DEBUG) << "Adding inlined ID " << inlined_id; } } } @@ -35,8 +42,13 @@ void TaskManager::AddPendingTask(const TaskID &caller_id, num_returns--; } for (size_t i = 0; i < num_returns; i++) { + // We pass an empty vector for inner IDs because we do not know the return + // value of the task yet. If the task returns an ID(s), the worker will + // notify us via the WaitForRefRemoved RPC that we are now a borrower for + // the inner IDs. Note that this RPC can be received *before* the + // PushTaskReply. reference_counter_->AddOwnedObject(spec.ReturnId(i, TaskTransportType::DIRECT), - caller_id, caller_address); + /*inner_ids=*/{}, caller_id, caller_address); } } @@ -59,7 +71,7 @@ bool TaskManager::IsTaskPending(const TaskID &task_id) const { void TaskManager::CompletePendingTask(const TaskID &task_id, const rpc::PushTaskReply &reply, - const rpc::Address *actor_addr) { + const rpc::Address &worker_addr) { RAY_LOG(DEBUG) << "Completing task " << task_id; TaskSpecification spec; { @@ -71,7 +83,7 @@ void TaskManager::CompletePendingTask(const TaskID &task_id, pending_tasks_.erase(it); } - RemovePlasmaSubmittedTaskReferences(spec); + RemoveFinishedTaskReferences(spec, worker_addr, reply.borrowed_refs()); for (int i = 0; i < reply.return_objects_size(); i++) { const auto &return_object = reply.return_objects(i); @@ -96,8 +108,10 @@ void TaskManager::CompletePendingTask(const TaskID &task_id, reinterpret_cast(return_object.metadata().data())), return_object.metadata().size()); } - RAY_CHECK_OK( - in_memory_store_->Put(RayObject(data_buffer, metadata_buffer), object_id)); + RAY_CHECK_OK(in_memory_store_->Put( + RayObject(data_buffer, metadata_buffer, + IdVectorFromProtobuf(return_object.inlined_ids())), + object_id)); } } @@ -154,7 +168,10 @@ void TaskManager::PendingTaskFailed(const TaskID &task_id, rpc::ErrorType error_ } } } - RemovePlasmaSubmittedTaskReferences(spec); + // The worker failed to execute the task, so it cannot be borrowing any + // objects. + RemoveFinishedTaskReferences(spec, rpc::Address(), + ReferenceCounter::ReferenceTableProto()); MarkPendingTaskFailed(task_id, spec, error_type); } @@ -169,26 +186,38 @@ void TaskManager::ShutdownIfNeeded() { } } -void TaskManager::RemoveSubmittedTaskReferences(const std::vector &object_ids) { +void TaskManager::RemoveSubmittedTaskReferences( + const std::vector &object_ids, const rpc::Address &worker_addr, + const ReferenceCounter::ReferenceTableProto &borrowed_refs) { std::vector deleted; - reference_counter_->RemoveSubmittedTaskReferences(object_ids, &deleted); + reference_counter_->UpdateSubmittedTaskReferences(object_ids, worker_addr, + borrowed_refs, &deleted); in_memory_store_->Delete(deleted); } -void TaskManager::OnTaskDependenciesInlined(const std::vector &object_ids) { - RemoveSubmittedTaskReferences(object_ids); +void TaskManager::OnTaskDependenciesInlined( + const std::vector &inlined_dependency_ids, + const std::vector &contained_ids) { + reference_counter_->AddSubmittedTaskReferences(contained_ids); + RemoveSubmittedTaskReferences(inlined_dependency_ids); } -void TaskManager::RemovePlasmaSubmittedTaskReferences(TaskSpecification &spec) { +void TaskManager::RemoveFinishedTaskReferences( + TaskSpecification &spec, const rpc::Address &borrower_addr, + const ReferenceCounter::ReferenceTableProto &borrowed_refs) { std::vector plasma_dependencies; for (size_t i = 0; i < spec.NumArgs(); i++) { - auto count = spec.ArgIdCount(i); - if (count > 0) { - const auto &id = spec.ArgId(i, 0); - plasma_dependencies.push_back(id); + if (spec.ArgByRef(i)) { + for (size_t j = 0; j < spec.ArgIdCount(i); j++) { + plasma_dependencies.push_back(spec.ArgId(i, j)); + } + } else { + const auto &inlined_ids = spec.ArgInlinedIds(i); + plasma_dependencies.insert(plasma_dependencies.end(), inlined_ids.begin(), + inlined_ids.end()); } } - RemoveSubmittedTaskReferences(plasma_dependencies); + RemoveSubmittedTaskReferences(plasma_dependencies, borrower_addr, borrowed_refs); } void TaskManager::MarkPendingTaskFailed(const TaskID &task_id, diff --git a/src/ray/core_worker/task_manager.h b/src/ray/core_worker/task_manager.h index c23568e25..52f2c0333 100644 --- a/src/ray/core_worker/task_manager.h +++ b/src/ray/core_worker/task_manager.h @@ -16,12 +16,14 @@ namespace ray { class TaskFinisherInterface { public: virtual void CompletePendingTask(const TaskID &task_id, const rpc::PushTaskReply &reply, - const rpc::Address *actor_addr) = 0; + const rpc::Address &actor_addr) = 0; virtual void PendingTaskFailed(const TaskID &task_id, rpc::ErrorType error_type, Status *status = nullptr) = 0; - virtual void OnTaskDependenciesInlined(const std::vector &object_ids) = 0; + virtual void OnTaskDependenciesInlined( + const std::vector &inlined_dependency_ids, + const std::vector &contained_ids) = 0; virtual ~TaskFinisherInterface() {} }; @@ -65,10 +67,10 @@ class TaskManager : public TaskFinisherInterface { /// /// \param[in] task_id ID of the pending task. /// \param[in] reply Proto response to a direct actor or task call. - /// \param[in] actor_addr Address of the created actor, or nullptr. + /// \param[in] worker_addr Address of the worker that executed the task. /// \return Void. void CompletePendingTask(const TaskID &task_id, const rpc::PushTaskReply &reply, - const rpc::Address *actor_addr) override; + const rpc::Address &worker_addr) override; /// A pending task failed. This will either retry the task or mark the task /// as failed if there are no retries left. @@ -79,7 +81,17 @@ class TaskManager : public TaskFinisherInterface { void PendingTaskFailed(const TaskID &task_id, rpc::ErrorType error_type, Status *status = nullptr) override; - void OnTaskDependenciesInlined(const std::vector &object_id) override; + /// A task's dependencies were inlined in the task spec. This will decrement + /// the ref count for the dependency IDs. If the dependencies contained other + /// ObjectIDs, then the ref count for these object IDs will be incremented. + /// + /// \param[in] inlined_dependency_ids The args that were originally passed by + /// reference into the task, but have now been inlined. + /// \param[in] contained_ids Any ObjectIDs that were newly inlined in the + /// task spec, because a serialized copy of the ID was contained in one of + /// the inlined dependencies. + void OnTaskDependenciesInlined(const std::vector &inlined_dependency_ids, + const std::vector &contained_ids) override; /// Return the spec for a pending task. TaskSpecification GetTaskSpec(const TaskID &task_id) const; @@ -93,13 +105,24 @@ class TaskManager : public TaskFinisherInterface { void MarkPendingTaskFailed(const TaskID &task_id, const TaskSpecification &spec, rpc::ErrorType error_type) LOCKS_EXCLUDED(mu_); - /// Remove submittted task references in the reference counter for the object IDs. - /// If their reference counts reach zero, they are deleted from the in-memory store. - void RemoveSubmittedTaskReferences(const std::vector &object_ids); + /// Remove submitted task references in the reference counter for the object IDs. + /// If the references were borrowed by a worker while executing a task, then + /// merge in the ref counts for any references that the task (or a nested + /// task) is still borrowing. If any reference counts for the borrowed + /// objects reach zero, they are deleted from the in-memory store. + void RemoveSubmittedTaskReferences( + const std::vector &object_ids, + const rpc::Address &worker_addr = rpc::Address(), + const ReferenceCounter::ReferenceTableProto &borrowed_refs = + ReferenceCounter::ReferenceTableProto()); - /// Helper function to call RemoveSubmittedTaskReferences on the plasma dependencies - /// of the given task spec. - void RemovePlasmaSubmittedTaskReferences(TaskSpecification &spec); + /// Helper function to call RemoveSubmittedTaskReferences on the remaining + /// dependencies of the given task spec after the task has finished or + /// failed. The remaining dependencies are plasma objects and any ObjectIDs + /// that were inlined in the task spec. + void RemoveFinishedTaskReferences( + TaskSpecification &spec, const rpc::Address &worker_addr, + const ReferenceCounter::ReferenceTableProto &borrowed_refs); /// Shutdown if all tasks are finished and shutdown is scheduled. void ShutdownIfNeeded() LOCKS_EXCLUDED(mu_); diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index 9e38ead3f..11e196584 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -54,7 +54,8 @@ ActorID CreateActorHelper(CoreWorker &worker, RayFunction func(ray::Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython( "actor creation task", "", "", "")); std::vector args; - args.emplace_back(TaskArg::PassByValue(std::make_shared(buffer, nullptr))); + args.emplace_back(TaskArg::PassByValue( + std::make_shared(buffer, nullptr, std::vector()))); ActorCreationOptions actor_options{max_reconstructions, is_direct_call, /*max_concurrency*/ 1, resources, resources, {}, @@ -320,11 +321,12 @@ void CoreWorkerTest::TestNormalTask(std::unordered_map &res auto buffer2 = GenerateRandomBuffer(); ObjectID object_id; - RAY_CHECK_OK(driver.Put(RayObject(buffer2, nullptr), {}, &object_id)); + RAY_CHECK_OK(driver.Put(RayObject(buffer2, nullptr, std::vector()), {}, + &object_id)); std::vector args; - args.emplace_back( - TaskArg::PassByValue(std::make_shared(buffer1, nullptr))); + args.emplace_back(TaskArg::PassByValue( + std::make_shared(buffer1, nullptr, std::vector()))); args.emplace_back(TaskArg::PassByReference(object_id)); RayFunction func(ray::Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython( @@ -369,10 +371,10 @@ void CoreWorkerTest::TestActorTask(std::unordered_map &reso // Create arguments with PassByRef and PassByValue. std::vector args; - args.emplace_back( - TaskArg::PassByValue(std::make_shared(buffer1, nullptr))); - args.emplace_back( - TaskArg::PassByValue(std::make_shared(buffer2, nullptr))); + args.emplace_back(TaskArg::PassByValue( + std::make_shared(buffer1, nullptr, std::vector()))); + args.emplace_back(TaskArg::PassByValue( + std::make_shared(buffer2, nullptr, std::vector()))); TaskOptions options{1, false, resources}; std::vector return_ids; @@ -410,13 +412,14 @@ void CoreWorkerTest::TestActorTask(std::unordered_map &reso auto buffer2 = std::make_shared(array2, sizeof(array2)); ObjectID object_id; - RAY_CHECK_OK(driver.Put(RayObject(buffer1, nullptr), {}, &object_id)); + RAY_CHECK_OK( + driver.Put(RayObject(buffer1, nullptr, std::vector()), {}, &object_id)); // Create arguments with PassByRef and PassByValue. std::vector args; args.emplace_back(TaskArg::PassByReference(object_id)); - args.emplace_back( - TaskArg::PassByValue(std::make_shared(buffer2, nullptr))); + args.emplace_back(TaskArg::PassByValue( + std::make_shared(buffer2, nullptr, std::vector()))); TaskOptions options{1, false, resources}; std::vector return_ids; @@ -481,8 +484,8 @@ void CoreWorkerTest::TestActorReconstruction( // Create arguments with PassByValue. std::vector args; - args.emplace_back( - TaskArg::PassByValue(std::make_shared(buffer1, nullptr))); + args.emplace_back(TaskArg::PassByValue( + std::make_shared(buffer1, nullptr, std::vector()))); TaskOptions options{1, false, resources}; std::vector return_ids; @@ -527,8 +530,8 @@ void CoreWorkerTest::TestActorFailure(std::unordered_map &r // Create arguments with PassByRef and PassByValue. std::vector args; - args.emplace_back( - TaskArg::PassByValue(std::make_shared(buffer1, nullptr))); + args.emplace_back(TaskArg::PassByValue( + std::make_shared(buffer1, nullptr, std::vector()))); TaskOptions options{1, false, resources}; std::vector return_ids; @@ -585,7 +588,8 @@ TEST_F(ZeroNodeTest, TestTaskArg) { ASSERT_EQ(by_ref.GetReference(), id); // Test by-value argument. auto buffer = GenerateRandomBuffer(); - TaskArg by_value = TaskArg::PassByValue(std::make_shared(buffer, nullptr)); + TaskArg by_value = TaskArg::PassByValue( + std::make_shared(buffer, nullptr, std::vector())); ASSERT_FALSE(by_value.IsPassedByReference()); auto data = by_value.GetValue().GetData(); ASSERT_TRUE(data != nullptr); @@ -601,7 +605,8 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) { RayFunction function(ray::Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython("", "", "", "")); std::vector args; - args.emplace_back(TaskArg::PassByValue(std::make_shared(buffer, nullptr))); + args.emplace_back(TaskArg::PassByValue( + std::make_shared(buffer, nullptr, std::vector()))); std::unordered_map resources; ActorCreationOptions actor_options{0, @@ -678,7 +683,8 @@ TEST_F(SingleNodeTest, TestDirectActorTaskSubmissionPerf) { int64_t array[] = {SHOULD_CHECK_MESSAGE_ORDER, i}; auto buffer = std::make_shared(reinterpret_cast(array), sizeof(array)); - args.emplace_back(TaskArg::PassByValue(std::make_shared(buffer, nullptr))); + args.emplace_back(TaskArg::PassByValue( + std::make_shared(buffer, nullptr, std::vector()))); TaskOptions options{1, false, resources}; std::vector return_ids; @@ -753,9 +759,11 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) { std::vector buffers; buffers.emplace_back(std::make_shared(array1, sizeof(array1)), - std::make_shared(array1, sizeof(array1) / 2)); + std::make_shared(array1, sizeof(array1) / 2), + std::vector()); buffers.emplace_back(std::make_shared(array2, sizeof(array2)), - std::make_shared(array2, sizeof(array2) / 2)); + std::make_shared(array2, sizeof(array2) / 2), + std::vector()); std::vector ids(buffers.size()); for (size_t i = 0; i < ids.size(); i++) { @@ -878,9 +886,11 @@ TEST_F(SingleNodeTest, TestObjectInterface) { std::vector buffers; buffers.emplace_back(std::make_shared(array1, sizeof(array1)), - std::make_shared(array1, sizeof(array1) / 2)); + std::make_shared(array1, sizeof(array1) / 2), + std::vector()); buffers.emplace_back(std::make_shared(array2, sizeof(array2)), - std::make_shared(array2, sizeof(array2) / 2)); + std::make_shared(array2, sizeof(array2) / 2), + std::vector()); std::vector ids(buffers.size()); for (size_t i = 0; i < ids.size(); i++) { @@ -904,8 +914,9 @@ TEST_F(SingleNodeTest, TestObjectInterface) { char error_buffer[error_string.size()]; size_t len = error_string.copy(error_buffer, error_string.size(), 0); buffers_with_exception.emplace_back( - nullptr, std::make_shared( - reinterpret_cast(error_buffer), len)); + nullptr, + std::make_shared(reinterpret_cast(error_buffer), len), + std::vector()); RAY_CHECK_OK( core_worker.Put(buffers_with_exception.back(), {}, ids_with_exception.back())); @@ -958,7 +969,8 @@ TEST_F(TwoNodeTest, TestObjectInterfaceCrossNodes) { std::vector ids(buffers.size()); for (size_t i = 0; i < ids.size(); i++) { - RAY_CHECK_OK(worker1.Put(RayObject(buffers[i], nullptr), {}, &ids[i])); + RAY_CHECK_OK(worker1.Put(RayObject(buffers[i], nullptr, std::vector()), {}, + &ids[i])); } // Test Get() from remote node. diff --git a/src/ray/core_worker/test/direct_actor_transport_test.cc b/src/ray/core_worker/test/direct_actor_transport_test.cc index 38bff92bc..0d1b9b38d 100644 --- a/src/ray/core_worker/test/direct_actor_transport_test.cc +++ b/src/ray/core_worker/test/direct_actor_transport_test.cc @@ -13,6 +13,8 @@ using ::testing::_; class MockWorkerClient : public rpc::CoreWorkerClientInterface { public: + const rpc::Address &Addr() const override { return addr; } + ray::Status PushActorTask( std::unique_ptr request, const rpc::ClientCallback &callback) override { @@ -32,6 +34,7 @@ class MockWorkerClient : public rpc::CoreWorkerClientInterface { return true; } + rpc::Address addr; std::list> callbacks; uint64_t counter = 0; }; @@ -41,11 +44,12 @@ class MockTaskFinisher : public TaskFinisherInterface { MockTaskFinisher() {} MOCK_METHOD3(CompletePendingTask, void(const TaskID &, const rpc::PushTaskReply &, - const rpc::Address *addr)); + const rpc::Address &addr)); MOCK_METHOD3(PendingTaskFailed, void(const TaskID &task_id, rpc::ErrorType error_type, Status *status)); - MOCK_METHOD1(OnTaskDependenciesInlined, void(const std::vector &object_ids)); + MOCK_METHOD2(OnTaskDependenciesInlined, + void(const std::vector &, const std::vector &)); }; TaskSpecification CreateActorTaskHelper(ActorID actor_id, int64_t counter) { @@ -63,9 +67,8 @@ class DirectActorTransportTest : public ::testing::Test { : worker_client_(std::shared_ptr(new MockWorkerClient())), store_(std::shared_ptr(new CoreWorkerMemoryStore())), task_finisher_(std::make_shared()), - submitter_(address_, - [&](const std::string ip, int port) { return worker_client_; }, store_, - task_finisher_) {} + submitter_(address_, [&](const rpc::Address &addr) { return worker_client_; }, + store_, task_finisher_) {} rpc::Address address_; std::shared_ptr worker_client_; diff --git a/src/ray/core_worker/test/direct_task_transport_test.cc b/src/ray/core_worker/test/direct_task_transport_test.cc index 4f189bb0f..082fdfc47 100644 --- a/src/ray/core_worker/test/direct_task_transport_test.cc +++ b/src/ray/core_worker/test/direct_task_transport_test.cc @@ -46,7 +46,7 @@ class MockTaskFinisher : public TaskFinisherInterface { MockTaskFinisher() {} void CompletePendingTask(const TaskID &, const rpc::PushTaskReply &, - const rpc::Address *actor_addr) override { + const rpc::Address &actor_addr) override { num_tasks_complete++; } @@ -55,13 +55,16 @@ class MockTaskFinisher : public TaskFinisherInterface { num_tasks_failed++; } - void OnTaskDependenciesInlined(const std::vector &object_ids) override { - num_inlined += object_ids.size(); + void OnTaskDependenciesInlined(const std::vector &inlined_dependency_ids, + const std::vector &contained_ids) override { + num_inlined_dependencies += inlined_dependency_ids.size(); + num_contained_ids += contained_ids.size(); } int num_tasks_complete = 0; int num_tasks_failed = 0; - int num_inlined = 0; + int num_inlined_dependencies = 0; + int num_contained_ids = 0; }; class MockRayletClient : public WorkerLeaseInterface { @@ -147,7 +150,7 @@ TEST(LocalDependencyResolverTest, TestNoDependencies) { bool ok = false; resolver.ResolveDependencies(task, [&ok]() { ok = true; }); ASSERT_TRUE(ok); - ASSERT_EQ(task_finisher->num_inlined, 0); + ASSERT_EQ(task_finisher->num_inlined_dependencies, 0); } TEST(LocalDependencyResolverTest, TestIgnorePlasmaDependencies) { @@ -162,7 +165,7 @@ TEST(LocalDependencyResolverTest, TestIgnorePlasmaDependencies) { // We ignore and don't block on plasma dependencies. ASSERT_TRUE(ok); ASSERT_EQ(resolver.NumPendingTasks(), 0); - ASSERT_EQ(task_finisher->num_inlined, 0); + ASSERT_EQ(task_finisher->num_inlined_dependencies, 0); } TEST(LocalDependencyResolverTest, TestHandlePlasmaPromotion) { @@ -173,7 +176,7 @@ TEST(LocalDependencyResolverTest, TestHandlePlasmaPromotion) { std::string meta = std::to_string(static_cast(rpc::ErrorType::OBJECT_IN_PLASMA)); auto metadata = const_cast(reinterpret_cast(meta.data())); auto meta_buffer = std::make_shared(metadata, meta.size()); - auto data = RayObject(nullptr, meta_buffer); + auto data = RayObject(nullptr, meta_buffer, std::vector()); ASSERT_TRUE(store->Put(data, obj1).ok()); TaskSpecification task; task.GetMutableMessage().add_args()->add_object_ids(obj1.Binary()); @@ -185,7 +188,7 @@ TEST(LocalDependencyResolverTest, TestHandlePlasmaPromotion) { // Checks that the object id is still a direct call id. ASSERT_TRUE(task.ArgId(0, 0).IsDirectCallType()); ASSERT_EQ(resolver.NumPendingTasks(), 0); - ASSERT_EQ(task_finisher->num_inlined, 0); + ASSERT_EQ(task_finisher->num_inlined_dependencies, 0); } TEST(LocalDependencyResolverTest, TestInlineLocalDependencies) { @@ -210,7 +213,7 @@ TEST(LocalDependencyResolverTest, TestInlineLocalDependencies) { ASSERT_NE(task.ArgData(0), nullptr); ASSERT_NE(task.ArgData(1), nullptr); ASSERT_EQ(resolver.NumPendingTasks(), 0); - ASSERT_EQ(task_finisher->num_inlined, 2); + ASSERT_EQ(task_finisher->num_inlined_dependencies, 2); } TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) { @@ -237,7 +240,37 @@ TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) { ASSERT_NE(task.ArgData(0), nullptr); ASSERT_NE(task.ArgData(1), nullptr); ASSERT_EQ(resolver.NumPendingTasks(), 0); - ASSERT_EQ(task_finisher->num_inlined, 2); + ASSERT_EQ(task_finisher->num_inlined_dependencies, 2); + ASSERT_EQ(task_finisher->num_contained_ids, 0); +} + +TEST(LocalDependencyResolverTest, TestInlinedObjectIds) { + auto store = std::make_shared(); + auto task_finisher = std::make_shared(); + LocalDependencyResolver resolver(store, task_finisher); + ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); + ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); + ObjectID obj3 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); + auto data = GenerateRandomObject({obj3}); + TaskSpecification task; + task.GetMutableMessage().add_args()->add_object_ids(obj1.Binary()); + task.GetMutableMessage().add_args()->add_object_ids(obj2.Binary()); + bool ok = false; + resolver.ResolveDependencies(task, [&ok]() { ok = true; }); + ASSERT_EQ(resolver.NumPendingTasks(), 1); + ASSERT_TRUE(!ok); + ASSERT_TRUE(store->Put(*data, obj1).ok()); + ASSERT_TRUE(store->Put(*data, obj2).ok()); + // Tests that the task proto was rewritten to have inline argument values after + // resolution completes. + ASSERT_TRUE(ok); + ASSERT_FALSE(task.ArgByRef(0)); + ASSERT_FALSE(task.ArgByRef(1)); + ASSERT_NE(task.ArgData(0), nullptr); + ASSERT_NE(task.ArgData(1), nullptr); + ASSERT_EQ(resolver.NumPendingTasks(), 0); + ASSERT_EQ(task_finisher->num_inlined_dependencies, 2); + ASSERT_EQ(task_finisher->num_contained_ids, 2); } TaskSpecification BuildTaskSpec(const std::unordered_map &resources, @@ -255,7 +288,7 @@ TEST(DirectTaskTransportTest, TestSubmitOneTask) { auto raylet_client = std::make_shared(); auto worker_client = std::make_shared(); auto store = std::make_shared(); - auto factory = [&](const std::string &addr, int port) { return worker_client; }; + auto factory = [&](const rpc::Address &addr) { return worker_client; }; auto task_finisher = std::make_shared(); CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, factory, nullptr, store, task_finisher, ClientID::Nil(), kLongTimeout); @@ -287,7 +320,7 @@ TEST(DirectTaskTransportTest, TestHandleTaskFailure) { auto raylet_client = std::make_shared(); auto worker_client = std::make_shared(); auto store = std::make_shared(); - auto factory = [&](const std::string &addr, int port) { return worker_client; }; + auto factory = [&](const rpc::Address &addr) { return worker_client; }; auto task_finisher = std::make_shared(); CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, factory, nullptr, store, task_finisher, ClientID::Nil(), kLongTimeout); @@ -312,7 +345,7 @@ TEST(DirectTaskTransportTest, TestConcurrentWorkerLeases) { auto raylet_client = std::make_shared(); auto worker_client = std::make_shared(); auto store = std::make_shared(); - auto factory = [&](const std::string &addr, int port) { return worker_client; }; + auto factory = [&](const rpc::Address &addr) { return worker_client; }; auto task_finisher = std::make_shared(); CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, factory, nullptr, store, task_finisher, ClientID::Nil(), kLongTimeout); @@ -358,7 +391,7 @@ TEST(DirectTaskTransportTest, TestReuseWorkerLease) { auto raylet_client = std::make_shared(); auto worker_client = std::make_shared(); auto store = std::make_shared(); - auto factory = [&](const std::string &addr, int port) { return worker_client; }; + auto factory = [&](const rpc::Address &addr) { return worker_client; }; auto task_finisher = std::make_shared(); CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, factory, nullptr, store, task_finisher, ClientID::Nil(), kLongTimeout); @@ -407,7 +440,7 @@ TEST(DirectTaskTransportTest, TestWorkerNotReusedOnError) { auto raylet_client = std::make_shared(); auto worker_client = std::make_shared(); auto store = std::make_shared(); - auto factory = [&](const std::string &addr, int port) { return worker_client; }; + auto factory = [&](const rpc::Address &addr) { return worker_client; }; auto task_finisher = std::make_shared(); CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, factory, nullptr, store, task_finisher, ClientID::Nil(), kLongTimeout); @@ -446,7 +479,7 @@ TEST(DirectTaskTransportTest, TestWorkerNotReturnedOnExit) { auto raylet_client = std::make_shared(); auto worker_client = std::make_shared(); auto store = std::make_shared(); - auto factory = [&](const std::string &addr, int port) { return worker_client; }; + auto factory = [&](const rpc::Address &addr) { return worker_client; }; auto task_finisher = std::make_shared(); CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, factory, nullptr, store, task_finisher, ClientID::Nil(), kLongTimeout); @@ -475,7 +508,7 @@ TEST(DirectTaskTransportTest, TestSpillback) { auto raylet_client = std::make_shared(); auto worker_client = std::make_shared(); auto store = std::make_shared(); - auto factory = [&](const std::string &addr, int port) { return worker_client; }; + auto factory = [&](const rpc::Address &addr) { return worker_client; }; std::unordered_map> remote_lease_clients; auto lease_client_factory = [&](const std::string &ip, int port) { @@ -525,7 +558,7 @@ TEST(DirectTaskTransportTest, TestSpillbackRoundTrip) { auto raylet_client = std::make_shared(); auto worker_client = std::make_shared(); auto store = std::make_shared(); - auto factory = [&](const std::string &addr, int port) { return worker_client; }; + auto factory = [&](const rpc::Address &addr) { return worker_client; }; std::unordered_map> remote_lease_clients; auto lease_client_factory = [&](const std::string &ip, int port) { @@ -584,7 +617,7 @@ void TestSchedulingKey(const std::shared_ptr store, rpc::Address address; auto raylet_client = std::make_shared(); auto worker_client = std::make_shared(); - auto factory = [&](const std::string &addr, int port) { return worker_client; }; + auto factory = [&](const rpc::Address &addr) { return worker_client; }; auto task_finisher = std::make_shared(); CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, factory, nullptr, store, task_finisher, ClientID::Nil(), kLongTimeout); @@ -658,7 +691,7 @@ TEST(DirectTaskTransportTest, TestSchedulingKeys) { std::string meta = std::to_string(static_cast(rpc::ErrorType::OBJECT_IN_PLASMA)); auto metadata = const_cast(reinterpret_cast(meta.data())); auto meta_buffer = std::make_shared(metadata, meta.size()); - auto plasma_data = RayObject(nullptr, meta_buffer); + auto plasma_data = RayObject(nullptr, meta_buffer, std::vector()); ASSERT_TRUE(store->Put(plasma_data, plasma1).ok()); ASSERT_TRUE(store->Put(plasma_data, plasma2).ok()); @@ -686,7 +719,7 @@ TEST(DirectTaskTransportTest, TestWorkerLeaseTimeout) { auto raylet_client = std::make_shared(); auto worker_client = std::make_shared(); auto store = std::make_shared(); - auto factory = [&](const std::string &addr, int port) { return worker_client; }; + auto factory = [&](const rpc::Address &addr) { return worker_client; }; auto task_finisher = std::make_shared(); CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, factory, nullptr, store, task_finisher, ClientID::Nil(), diff --git a/src/ray/core_worker/test/mock_worker.cc b/src/ray/core_worker/test/mock_worker.cc index 059705429..20cef6e26 100644 --- a/src/ray/core_worker/test/mock_worker.cc +++ b/src/ray/core_worker/test/mock_worker.cc @@ -62,7 +62,8 @@ class MockWorker { const_cast(reinterpret_cast(pid_string.data())); auto memory_buffer = std::make_shared(data, pid_string.size(), true); - results->push_back(std::make_shared(memory_buffer, nullptr)); + results->push_back( + std::make_shared(memory_buffer, nullptr, std::vector())); return Status::OK(); } @@ -90,7 +91,8 @@ class MockWorker { // Write the merged content to each of return ids. for (size_t i = 0; i < return_ids.size(); i++) { - results->push_back(std::make_shared(memory_buffer, nullptr)); + results->push_back( + std::make_shared(memory_buffer, nullptr, std::vector())); } return Status::OK(); diff --git a/src/ray/core_worker/test/task_manager_test.cc b/src/ray/core_worker/test/task_manager_test.cc index 1acaa2dd1..16468a4ca 100644 --- a/src/ray/core_worker/test/task_manager_test.cc +++ b/src/ray/core_worker/test/task_manager_test.cc @@ -65,7 +65,7 @@ TEST_F(TaskManagerTest, TestTaskSuccess) { return_object->set_object_id(return_id.Binary()); auto data = GenerateRandomBuffer(); return_object->set_data(data->Data(), data->Size()); - manager_.CompletePendingTask(spec.TaskId(), reply, nullptr); + manager_.CompletePendingTask(spec.TaskId(), reply, rpc::Address()); ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId())); // Only the return object reference should remain. ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 1); diff --git a/src/ray/core_worker/transport/dependency_resolver.cc b/src/ray/core_worker/transport/dependency_resolver.cc index 3f3b6586a..392e0067b 100644 --- a/src/ray/core_worker/transport/dependency_resolver.cc +++ b/src/ray/core_worker/transport/dependency_resolver.cc @@ -18,7 +18,8 @@ struct TaskState { void InlineDependencies( absl::flat_hash_map> dependencies, - TaskSpecification &task, std::vector *inlined) { + TaskSpecification &task, std::vector *inlined_dependency_ids, + std::vector *contained_ids) { auto &msg = task.GetMutableMessage(); size_t found = 0; for (size_t i = 0; i < task.NumArgs(); i++) { @@ -43,7 +44,11 @@ void InlineDependencies( const auto &metadata = it->second->GetMetadata(); mutable_arg->set_metadata(metadata->Data(), metadata->Size()); } - inlined->push_back(id); + for (const auto &inlined_id : it->second->GetInlinedIds()) { + mutable_arg->add_nested_inlined_ids(inlined_id.Binary()); + contained_ids->push_back(inlined_id); + } + inlined_dependency_ids->push_back(id); } found++; } else { @@ -80,27 +85,29 @@ void LocalDependencyResolver::ResolveDependencies(TaskSpecification &task, for (const auto &it : state->local_dependencies) { const ObjectID &obj_id = it.first; - in_memory_store_->GetAsync( - obj_id, [this, state, obj_id, on_complete](std::shared_ptr obj) { - RAY_CHECK(obj != nullptr); - bool complete = false; - std::vector inlined; - { - absl::MutexLock lock(&mu_); - state->local_dependencies[obj_id] = std::move(obj); - if (--state->dependencies_remaining == 0) { - InlineDependencies(state->local_dependencies, state->task, &inlined); - complete = true; - num_pending_ -= 1; - } - } - if (inlined.size() > 0) { - task_finisher_->OnTaskDependenciesInlined(inlined); - } - if (complete) { - on_complete(); - } - }); + in_memory_store_->GetAsync(obj_id, [this, state, obj_id, + on_complete](std::shared_ptr obj) { + RAY_CHECK(obj != nullptr); + bool complete = false; + std::vector inlined_dependency_ids; + std::vector contained_ids; + { + absl::MutexLock lock(&mu_); + state->local_dependencies[obj_id] = std::move(obj); + if (--state->dependencies_remaining == 0) { + InlineDependencies(state->local_dependencies, state->task, + &inlined_dependency_ids, &contained_ids); + complete = true; + num_pending_ -= 1; + } + } + if (inlined_dependency_ids.size() > 0) { + task_finisher_->OnTaskDependenciesInlined(inlined_dependency_ids, contained_ids); + } + if (complete) { + on_complete(); + } + }); } } diff --git a/src/ray/core_worker/transport/direct_actor_transport.cc b/src/ray/core_worker/transport/direct_actor_transport.cc index eee1a7270..c26ad9826 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.cc +++ b/src/ray/core_worker/transport/direct_actor_transport.cc @@ -80,8 +80,8 @@ void CoreWorkerDirectActorTaskSubmitter::ConnectActor(const ActorID &actor_id, // Create a new connection to the actor. // TODO(edoakes): are these clients cleaned up properly? if (rpc_clients_.count(actor_id) == 0) { - rpc_clients_[actor_id] = std::shared_ptr( - client_factory_(address.ip_address(), address.port())); + rpc_clients_[actor_id] = + std::shared_ptr(client_factory_(address)); } if (pending_requests_.count(actor_id) > 0) { SendPendingTasks(actor_id); @@ -152,13 +152,14 @@ void CoreWorkerDirectActorTaskSubmitter::PushActorTask( auto it = worker_ids_.find(actor_id); RAY_CHECK(it != worker_ids_.end()) << "Actor worker id not found " << actor_id.Hex(); request->set_intended_worker_id(it->second); + rpc::Address addr(client.Addr()); RAY_CHECK_OK(client.PushActorTask( std::move(request), - [this, task_id](Status status, const rpc::PushTaskReply &reply) { + [this, addr, task_id](Status status, const rpc::PushTaskReply &reply) { if (!status.ok()) { task_finisher_->PendingTaskFailed(task_id, rpc::ErrorType::ACTOR_DIED, &status); } else { - task_finisher_->CompletePendingTask(task_id, reply, nullptr); + task_finisher_->CompletePendingTask(task_id, reply, addr); } })); } @@ -252,7 +253,9 @@ void CoreWorkerDirectTaskReceiver::HandlePushTask( RAY_CHECK(num_returns >= 0); std::vector> return_objects; - auto status = task_handler_(task_spec, resource_ids, &return_objects); + auto status = task_handler_(task_spec, resource_ids, &return_objects, + reply->mutable_borrowed_refs()); + bool objects_valid = return_objects.size() == num_returns; if (objects_valid) { std::vector plasma_return_ids; @@ -276,6 +279,9 @@ void CoreWorkerDirectTaskReceiver::HandlePushTask( return_object->set_metadata(result->GetMetadata()->Data(), result->GetMetadata()->Size()); } + for (const auto &inlined_id : result->GetInlinedIds()) { + return_object->add_inlined_ids(inlined_id.Binary()); + } } } // If we spilled any return objects to plasma, notify the raylet to pin them. diff --git a/src/ray/core_worker/transport/direct_actor_transport.h b/src/ray/core_worker/transport/direct_actor_transport.h index 99da8e5d2..8b889604a 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.h +++ b/src/ray/core_worker/transport/direct_actor_transport.h @@ -368,7 +368,8 @@ class CoreWorkerDirectTaskReceiver { using TaskHandler = std::function resource_ids, - std::vector> *return_objects)>; + std::vector> *return_objects, + ReferenceCounter::ReferenceTableProto *borrower_refs)>; CoreWorkerDirectTaskReceiver(WorkerContext &worker_context, std::shared_ptr &local_raylet_client, diff --git a/src/ray/core_worker/transport/direct_task_transport.cc b/src/ray/core_worker/transport/direct_task_transport.cc index 7fc563cac..ce1c0a5b1 100644 --- a/src/ray/core_worker/transport/direct_task_transport.cc +++ b/src/ray/core_worker/transport/direct_task_transport.cc @@ -29,8 +29,8 @@ void CoreWorkerDirectTaskSubmitter::AddWorkerLeaseClient( const rpc::WorkerAddress &addr, std::shared_ptr lease_client) { auto it = client_cache_.find(addr); if (it == client_cache_.end()) { - client_cache_[addr] = std::shared_ptr( - client_factory_(addr.ip_address, addr.port)); + client_cache_[addr] = + std::shared_ptr(client_factory_(addr.ToProto())); RAY_LOG(INFO) << "Connected to " << addr.ip_address << ":" << addr.port; } int64_t expiration = current_time_ms() + lease_timeout_ms_; @@ -117,10 +117,7 @@ void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded( // We got a lease for a worker. Add the lease client state and try to // assign work to the worker. RAY_LOG(DEBUG) << "Lease granted " << task_id; - rpc::WorkerAddress addr = { - reply.worker_address().ip_address(), reply.worker_address().port(), - WorkerID::FromBinary(reply.worker_address().worker_id()), - ClientID::FromBinary(reply.worker_address().raylet_id())}; + rpc::WorkerAddress addr(reply.worker_address()); AddWorkerLeaseClient(addr, std::move(lease_client)); auto resources_copy = reply.resource_mapping(); OnWorkerIdle(addr, scheduling_key, @@ -199,8 +196,7 @@ void CoreWorkerDirectTaskSubmitter::PushNormalTask( is_actor ? rpc::ErrorType::ACTOR_DIED : rpc::ErrorType::WORKER_DIED, &status); } else { - rpc::Address proto = addr.ToProto(); - task_finisher_->CompletePendingTask(task_id, reply, &proto); + task_finisher_->CompletePendingTask(task_id, reply, addr.ToProto()); } }); if (!status.ok()) { diff --git a/src/ray/core_worker/transport/raylet_transport.cc b/src/ray/core_worker/transport/raylet_transport.cc index c9b4b68e6..a7129b7c4 100644 --- a/src/ray/core_worker/transport/raylet_transport.cc +++ b/src/ray/core_worker/transport/raylet_transport.cc @@ -48,7 +48,10 @@ void CoreWorkerRayletTaskReceiver::HandleAssignTask( } std::vector> results; - auto status = task_handler_(task_spec, resource_ids, &results); + ReferenceCounter::ReferenceTableProto borrower_refs; + // NOTE(swang): Distributed ref counting does not work for the raylet + // transport. + auto status = task_handler_(task_spec, resource_ids, &results, &borrower_refs); if (status.IsSystemExit()) { exit_handler_(status.IsIntentionalSystemExit()); return; diff --git a/src/ray/core_worker/transport/raylet_transport.h b/src/ray/core_worker/transport/raylet_transport.h index 0bcf7a761..7118e72b7 100644 --- a/src/ray/core_worker/transport/raylet_transport.h +++ b/src/ray/core_worker/transport/raylet_transport.h @@ -4,6 +4,7 @@ #include #include "ray/common/ray_object.h" +#include "ray/core_worker/reference_count.h" #include "ray/raylet/raylet_client.h" #include "ray/rpc/worker/core_worker_server.h" @@ -14,7 +15,8 @@ class CoreWorkerRayletTaskReceiver { using TaskHandler = std::function &resource_ids, - std::vector> *return_objects)>; + std::vector> *return_objects, + ReferenceCounter::ReferenceTableProto *borrower_refs)>; CoreWorkerRayletTaskReceiver(const WorkerID &worker_id, std::shared_ptr &raylet_client, diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto index c37c805c1..c6d05eb03 100644 --- a/src/ray/protobuf/common.proto +++ b/src/ray/protobuf/common.proto @@ -116,6 +116,8 @@ message TaskArg { bytes data = 2; // Metadata for pass-by-value arguments. bytes metadata = 3; + // ObjectIDs that were nested in the inlined arguments of the data field. + repeated bytes nested_inlined_ids = 4; } // Task spec of an actor creation task. diff --git a/src/ray/protobuf/core_worker.proto b/src/ray/protobuf/core_worker.proto index 0f8873f86..2698574c0 100644 --- a/src/ray/protobuf/core_worker.proto +++ b/src/ray/protobuf/core_worker.proto @@ -63,6 +63,8 @@ message ReturnObject { bytes data = 3; // Metadata of the object. bytes metadata = 4; + // ObjectIDs that were inlined in the data field. + repeated bytes inlined_ids = 5; } message PushTaskRequest { @@ -93,6 +95,19 @@ message PushTaskReply { repeated ReturnObject return_objects = 1; // Set to true if the worker will be exiting. bool worker_exiting = 2; + // The references that the worker borrowed during the task execution. A + // borrower is a process that is currently using the object ID, in one of 3 ways: + // 1. Has an ObjectID copy in Python. + // 2. Has submitted a task that depends on the object and that is still pending. + // 3. Owns another object that is in scope and whose value contains the + // ObjectID. + // This list includes the reference counts for any IDs that were passed to + // the worker in the task spec as an argument by reference, or an ObjectID + // that was serialized in an inlined argument. It also includes reference + // counts for any IDs that were nested inside these objects that the worker + // may now be borrowing. The reference counts also include any new borrowers + // that the worker created by passing a borrowed ID into a nested task. + repeated ObjectReferenceCount borrowed_refs = 3; } message DirectActorCallArgWaitCompleteRequest { @@ -150,6 +165,52 @@ message GetCoreWorkerStatsReply { CoreWorkerStats core_worker_stats = 1; } +message ObjectReference { + // ObjectID that the worker has a reference to. + bytes object_id = 1; + // The task or actor ID of the object's owner. + bytes owner_id = 2; + // The address of the object's owner. + Address owner_address = 3; +} + +message WaitForRefRemovedRequest { + // The ID of the worker this message is intended for. + bytes intended_worker_id = 1; + // Object whose removal we are waiting for. + ObjectReference reference = 2; + // ObjectID that contains object_id. This is used when an ObjectID is stored + // inside another object ID that we do not own. Then, we must notify the + // outer ID's owner that the ID contains object_id. + bytes contained_in_id = 3; +} + +message ObjectReferenceCount { + // The reference that the worker has or had a reference to. + ObjectReference reference = 1; + // Whether the worker is still using the ObjectID locally. This means that + // it has a copy of the ObjectID in the language frontend, has a pending task + // that depends on the object, and/or owns an ObjectID that is in scope and + // that contains the ObjectID. + bool has_local_ref = 2; + // Any other borrowers that the worker created (by passing the ID on to them). + repeated Address borrowers = 3; + // The borrowed object ID that contained this object, if any. This is used + // for nested object IDs. + bytes contained_in_borrowed_id = 4; + // The object IDs that this object contains, if any. This is used for nested + // object IDs. + repeated bytes contains = 5; +} + +message WaitForRefRemovedReply { + // The reference counts for the object that the worker was borrowing and + // any objects nested inside. The worker should no longer be using the object + // ID by the time it replies, but may have accumulated other borrowers or may + // still be borrowing an object ID that was nested inside. + repeated ObjectReferenceCount borrowed_refs = 1; +} + service CoreWorkerService { // Push a task to a worker from the raylet. rpc AssignTask(AssignTaskRequest) returns (AssignTaskReply); @@ -168,4 +229,6 @@ service CoreWorkerService { rpc KillActor(KillActorRequest) returns (KillActorReply); // Get metrics from core workers. rpc GetCoreWorkerStats(GetCoreWorkerStatsRequest) returns (GetCoreWorkerStatsReply); + // Wait for a borrower to finish using an object. Sent by the object's owner. + rpc WaitForRefRemoved(WaitForRefRemovedRequest) returns (WaitForRefRemovedReply); } diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index c25f0a250..0870fdc31 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -3074,8 +3074,7 @@ void NodeManager::HandlePinObjectIDs(const rpc::PinObjectIDsRequest &request, auto it = worker_rpc_clients_.find(worker_id); if (it == worker_rpc_clients_.end()) { auto client = std::unique_ptr( - new rpc::CoreWorkerClient(request.owner_address().ip_address(), - request.owner_address().port(), client_call_manager_)); + new rpc::CoreWorkerClient(request.owner_address(), client_call_manager_)); it = worker_rpc_clients_ .emplace(worker_id, std::make_pair, size_t>( @@ -3107,11 +3106,17 @@ void NodeManager::HandlePinObjectIDs(const rpc::PinObjectIDsRequest &request, for (const auto &object_id_binary : request.object_ids()) { ObjectID object_id = ObjectID::FromBinary(object_id_binary); + if (plasma_results[i].data == nullptr) { + RAY_LOG(ERROR) << "Plasma object " << object_id + << " was evicted before the raylet could pin it."; + continue; + } + RAY_LOG(DEBUG) << "Pinning object " << object_id; pinned_objects_.emplace( object_id, std::unique_ptr(new RayObject( std::make_shared(plasma_results[i].data), - std::make_shared(plasma_results[i].metadata)))); + std::make_shared(plasma_results[i].metadata), {}))); i++; // Send a long-running RPC request to the owner for each object. When we get a diff --git a/src/ray/raylet/worker.cc b/src/ray/raylet/worker.cc index f3d2622e6..688e7154d 100644 --- a/src/ray/raylet/worker.cc +++ b/src/ray/raylet/worker.cc @@ -24,8 +24,11 @@ Worker::Worker(const WorkerID &worker_id, const Language &language, int port, client_call_manager_(client_call_manager), is_detached_actor_(false) { if (port_ > 0) { + rpc::Address addr; + addr.set_ip_address("127.0.0.1"); + addr.set_port(port_); rpc_client_ = std::unique_ptr( - new rpc::CoreWorkerClient("127.0.0.1", port_, client_call_manager_)); + new rpc::CoreWorkerClient(addr, client_call_manager_)); } } diff --git a/src/ray/rpc/worker/core_worker_client.h b/src/ray/rpc/worker/core_worker_client.h index d52b10584..abfc0ef5d 100644 --- a/src/ray/rpc/worker/core_worker_client.h +++ b/src/ray/rpc/worker/core_worker_client.h @@ -40,6 +40,11 @@ class CoreWorkerClientInterface; // TODO(swang): Remove and replace with rpc::Address. class WorkerAddress { public: + WorkerAddress(const rpc::Address &address) + : ip_address(address.ip_address()), + port(address.port()), + worker_id(WorkerID::FromBinary(address.worker_id())), + raylet_id(ClientID::FromBinary(address.raylet_id())) {} template friend H AbslHashValue(H h, const WorkerAddress &w) { return H::combine(std::move(h), w.ip_address, w.port, w.worker_id, w.raylet_id); @@ -69,13 +74,17 @@ class WorkerAddress { const ClientID raylet_id; }; -typedef std::function(const std::string &, - int)> +typedef std::function(const rpc::Address &)> ClientFactoryFn; /// Abstract client interface for testing. class CoreWorkerClientInterface { public: + virtual const rpc::Address &Addr() const { + static const rpc::Address empty_addr_; + return empty_addr_; + } + /// This is called by the Raylet to assign a task to the worker. /// /// \param[in] request The request message. @@ -140,6 +149,12 @@ class CoreWorkerClientInterface { return Status::NotImplemented(""); } + virtual ray::Status WaitForRefRemoved( + const WaitForRefRemovedRequest &request, + const ClientCallback &callback) { + return Status::NotImplemented(""); + } + virtual ~CoreWorkerClientInterface(){}; }; @@ -152,13 +167,15 @@ class CoreWorkerClient : public std::enable_shared_from_this, /// \param[in] address Address of the worker server. /// \param[in] port Port of the worker server. /// \param[in] client_call_manager The `ClientCallManager` used for managing requests. - CoreWorkerClient(const std::string &address, const int port, - ClientCallManager &client_call_manager) - : client_call_manager_(client_call_manager) { - grpc_client_ = std::unique_ptr>( - new GrpcClient(address, port, client_call_manager)); + CoreWorkerClient(const rpc::Address &address, ClientCallManager &client_call_manager) + : addr_(address), client_call_manager_(client_call_manager) { + grpc_client_ = + std::unique_ptr>(new GrpcClient( + addr_.ip_address(), addr_.port(), client_call_manager)); }; + const rpc::Address &Addr() const override { return addr_; } + RPC_CLIENT_METHOD(CoreWorkerService, AssignTask, grpc_client_, override) RPC_CLIENT_METHOD(CoreWorkerService, DirectActorCallArgWaitComplete, grpc_client_, @@ -172,6 +189,8 @@ class CoreWorkerClient : public std::enable_shared_from_this, RPC_CLIENT_METHOD(CoreWorkerService, GetCoreWorkerStats, grpc_client_, override) + RPC_CLIENT_METHOD(CoreWorkerService, WaitForRefRemoved, grpc_client_, override) + ray::Status PushActorTask(std::unique_ptr request, const ClientCallback &callback) override { request->set_sequence_number(request->task_spec().actor_task_spec().actor_counter()); @@ -241,6 +260,9 @@ class CoreWorkerClient : public std::enable_shared_from_this, /// Protects against unsafe concurrent access from the callback thread. std::mutex mutex_; + /// Address of the remote worker. + rpc::Address addr_; + /// The RPC client. std::unique_ptr> grpc_client_; diff --git a/src/ray/rpc/worker/core_worker_server.h b/src/ray/rpc/worker/core_worker_server.h index 10bd3441f..e32660b63 100644 --- a/src/ray/rpc/worker/core_worker_server.h +++ b/src/ray/rpc/worker/core_worker_server.h @@ -20,6 +20,7 @@ namespace rpc { RPC_SERVICE_HANDLER(CoreWorkerService, DirectActorCallArgWaitComplete, 100) \ RPC_SERVICE_HANDLER(CoreWorkerService, GetObjectStatus, 9999) \ RPC_SERVICE_HANDLER(CoreWorkerService, WaitForObjectEviction, 9999) \ + RPC_SERVICE_HANDLER(CoreWorkerService, WaitForRefRemoved, 9999) \ RPC_SERVICE_HANDLER(CoreWorkerService, KillActor, 9999) \ RPC_SERVICE_HANDLER(CoreWorkerService, GetCoreWorkerStats, 100) @@ -29,6 +30,7 @@ namespace rpc { DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(DirectActorCallArgWaitComplete) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(GetObjectStatus) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(WaitForObjectEviction) \ + DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(WaitForRefRemoved) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(KillActor) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(GetCoreWorkerStats) diff --git a/src/ray/util/test_util.h b/src/ray/util/test_util.h index ee3236bb2..23c7bd023 100644 --- a/src/ray/util/test_util.h +++ b/src/ray/util/test_util.h @@ -56,8 +56,10 @@ std::shared_ptr GenerateRandomBuffer() { return std::make_shared(arg1.data(), arg1.size(), true); } -std::shared_ptr GenerateRandomObject() { - return std::shared_ptr(new RayObject(GenerateRandomBuffer(), nullptr)); +std::shared_ptr GenerateRandomObject( + const std::vector &inlined_ids = {}) { + return std::shared_ptr( + new RayObject(GenerateRandomBuffer(), nullptr, inlined_ids)); } /// Path to redis server executable binary. diff --git a/streaming/src/queue/transport.cc b/streaming/src/queue/transport.cc index d3a55f162..d764a5f09 100644 --- a/streaming/src/queue/transport.cc +++ b/streaming/src/queue/transport.cc @@ -22,11 +22,11 @@ void Transport::SendInternal(std::shared_ptr buffer, auto dummy = "__RAY_DUMMY__"; std::shared_ptr dummyBuffer = std::make_shared((uint8_t *)dummy, 13, true); - args.emplace_back(TaskArg::PassByValue( - std::make_shared(std::move(dummyBuffer), meta, true))); + args.emplace_back(TaskArg::PassByValue(std::make_shared( + std::move(dummyBuffer), meta, std::vector(), true))); } - args.emplace_back( - TaskArg::PassByValue(std::make_shared(std::move(buffer), meta, true))); + args.emplace_back(TaskArg::PassByValue(std::make_shared( + std::move(buffer), meta, std::vector(), true))); STREAMING_CHECK(core_worker_ != nullptr); std::vector> results; diff --git a/streaming/src/test/mock_actor.cc b/streaming/src/test/mock_actor.cc index 1bfce6276..c97860b6d 100644 --- a/streaming/src/test/mock_actor.cc +++ b/streaming/src/test/mock_actor.cc @@ -339,8 +339,8 @@ class StreamingWorker { STREAMING_LOG(INFO) << "Test name: " << typed_descriptor->ClassName(); test_suite_->ExecuteTest(typed_descriptor->ClassName()); } else if (func_name == "check_current_test_status") { - results->push_back( - std::make_shared(test_suite_->CheckCurTestStatus(), nullptr)); + results->push_back(std::make_shared(test_suite_->CheckCurTestStatus(), + nullptr, std::vector())); } else if (func_name == "reader_sync_call_func") { if (test_suite_->TestDone()) { STREAMING_LOG(WARNING) << "Test has done!!"; @@ -350,7 +350,8 @@ class StreamingWorker { std::make_shared(args[1]->GetData()->Data(), args[1]->GetData()->Size(), true); auto result_buffer = reader_client_->OnReaderMessageSync(local_buffer); - results->push_back(std::make_shared(result_buffer, nullptr)); + results->push_back( + std::make_shared(result_buffer, nullptr, std::vector())); } else if (func_name == "reader_async_call_func") { if (test_suite_->TestDone()) { STREAMING_LOG(WARNING) << "Test has done!!"; @@ -369,7 +370,8 @@ class StreamingWorker { std::make_shared(args[1]->GetData()->Data(), args[1]->GetData()->Size(), true); auto result_buffer = writer_client_->OnWriterMessageSync(local_buffer); - results->push_back(std::make_shared(result_buffer, nullptr)); + results->push_back( + std::make_shared(result_buffer, nullptr, std::vector())); } else if (func_name == "writer_async_call_func") { if (test_suite_->TestDone()) { STREAMING_LOG(WARNING) << "Test has done!!"; diff --git a/streaming/src/test/queue_tests_base.h b/streaming/src/test/queue_tests_base.h index fe5a8a109..7d62f2c52 100644 --- a/streaming/src/test/queue_tests_base.h +++ b/streaming/src/test/queue_tests_base.h @@ -161,8 +161,8 @@ class StreamingQueueTestBase : public ::testing::TestWithParam { queue_ids, rescale_queue_ids, suite_name, test_name, param); std::vector args; - args.emplace_back( - TaskArg::PassByValue(std::make_shared(msg.ToBytes(), nullptr, true))); + args.emplace_back(TaskArg::PassByValue(std::make_shared( + msg.ToBytes(), nullptr, std::vector(), true))); std::unordered_map resources; TaskOptions options{0, true, resources}; std::vector return_ids; @@ -176,8 +176,8 @@ class StreamingQueueTestBase : public ::testing::TestWithParam { uint8_t data[8]; auto buffer = std::make_shared(data, 8, true); std::vector args; - args.emplace_back( - TaskArg::PassByValue(std::make_shared(buffer, nullptr, true))); + args.emplace_back(TaskArg::PassByValue( + std::make_shared(buffer, nullptr, std::vector(), true))); std::unordered_map resources; TaskOptions options{0, true, resources}; std::vector return_ids; @@ -191,8 +191,8 @@ class StreamingQueueTestBase : public ::testing::TestWithParam { uint8_t data[8]; auto buffer = std::make_shared(data, 8, true); std::vector args; - args.emplace_back( - TaskArg::PassByValue(std::make_shared(buffer, nullptr, true))); + args.emplace_back(TaskArg::PassByValue( + std::make_shared(buffer, nullptr, std::vector(), true))); std::unordered_map resources; TaskOptions options{1, true, resources}; std::vector return_ids; @@ -260,7 +260,8 @@ class StreamingQueueTestBase : public ::testing::TestWithParam { RayFunction func{ray::Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython( "actor creation task", "", "", "")}; std::vector args; - args.emplace_back(TaskArg::PassByValue(std::make_shared(buffer, nullptr))); + args.emplace_back(TaskArg::PassByValue( + std::make_shared(buffer, nullptr, std::vector()))); ActorCreationOptions actor_options{ max_reconstructions, is_direct_call,