diff --git a/python/ray/tests/test_reference_counting.py b/python/ray/tests/test_reference_counting.py index c995f6aea..4155e68e5 100644 --- a/python/ray/tests/test_reference_counting.py +++ b/python/ray/tests/test_reference_counting.py @@ -38,10 +38,12 @@ def _fill_object_store_and_get(oid, succeed=True, object_MiB=40, oid = ray.ObjectID(oid) if succeed: - ray.get(oid) + wait_for_condition( + lambda: ray.worker.global_worker.core_worker.object_exists(oid)) else: - with pytest.raises(ray.exceptions.RayTimeoutError): - ray.get(oid, timeout=0.1) + wait_for_condition( + lambda: not ray.worker.global_worker.core_worker.object_exists(oid) + ) def _check_refcounts(expected): diff --git a/python/ray/tests/test_reference_counting_2.py b/python/ray/tests/test_reference_counting_2.py index 87a1db5ed..6f4fcde9b 100644 --- a/python/ray/tests/test_reference_counting_2.py +++ b/python/ray/tests/test_reference_counting_2.py @@ -41,10 +41,12 @@ def _fill_object_store_and_get(oid, succeed=True, object_MiB=40, oid = ray.ObjectID(oid) if succeed: - ray.get(oid) + wait_for_condition( + lambda: ray.worker.global_worker.core_worker.object_exists(oid)) else: - with pytest.raises(ray.exceptions.RayTimeoutError): - ray.get(oid, timeout=0.1) + wait_for_condition( + lambda: not ray.worker.global_worker.core_worker.object_exists(oid) + ) # Test that an object containing object IDs within it pins the inner IDs diff --git a/src/ray/common/task/task.cc b/src/ray/common/task/task.cc index 827408cf4..280c8e1d9 100644 --- a/src/ray/common/task/task.cc +++ b/src/ray/common/task/task.cc @@ -12,7 +12,9 @@ const TaskSpecification &Task::GetTaskSpecification() const { return task_spec_; void Task::IncrementNumForwards() { task_execution_spec_.IncrementNumForwards(); } -const std::vector &Task::GetDependencies() const { return dependencies_; } +const std::vector &Task::GetDependencies() const { + return dependencies_; +} void Task::ComputeDependencies() { dependencies_ = task_spec_.GetDependencies(); } diff --git a/src/ray/common/task/task.h b/src/ray/common/task/task.h index 57f1a2aa2..282a641f1 100644 --- a/src/ray/common/task/task.h +++ b/src/ray/common/task/task.h @@ -80,7 +80,7 @@ class Task { /// arguments and the mutable execution dependencies. /// /// \return The object dependencies. - const std::vector &GetDependencies() const; + const std::vector &GetDependencies() const; /// Update the dynamic/mutable information for this task. /// \param task Task structure with updated dynamic information. @@ -110,7 +110,7 @@ class Task { /// A cached copy of the task's object dependencies, including arguments from /// the TaskSpecification and execution dependencies from the /// TaskExecutionSpecification. - std::vector dependencies_; + std::vector dependencies_; /// For direct task calls, overrides the dispatch behaviour to send an RPC /// back to the submitting worker. diff --git a/src/ray/common/task/task_spec.cc b/src/ray/common/task/task_spec.cc index f96c75858..f77851bca 100644 --- a/src/ray/common/task/task_spec.cc +++ b/src/ray/common/task/task_spec.cc @@ -117,6 +117,11 @@ ObjectID TaskSpecification::ArgId(size_t arg_index) const { return ObjectID::FromBinary(message_->args(arg_index).object_ref().object_id()); } +rpc::ObjectReference TaskSpecification::ArgRef(size_t arg_index) const { + RAY_CHECK(ArgByRef(arg_index)); + return message_->args(arg_index).object_ref(); +} + const uint8_t *TaskSpecification::ArgData(size_t arg_index) const { return reinterpret_cast(message_->args(arg_index).data().data()); } @@ -141,7 +146,7 @@ const ResourceSet &TaskSpecification::GetRequiredResources() const { return *required_resources_; } -std::vector TaskSpecification::GetDependencies() const { +std::vector TaskSpecification::GetDependencyIds() const { std::vector dependencies; for (size_t i = 0; i < NumArgs(); ++i) { if (ArgByRef(i)) { @@ -154,6 +159,21 @@ std::vector TaskSpecification::GetDependencies() const { return dependencies; } +std::vector TaskSpecification::GetDependencies() const { + std::vector dependencies; + for (size_t i = 0; i < NumArgs(); ++i) { + if (ArgByRef(i)) { + dependencies.push_back(message_->args(i).object_ref()); + } + } + if (IsActorTask()) { + const auto &dummy_ref = + GetReferenceForActorDummyObject(PreviousActorTaskDummyObjectId()); + dependencies.push_back(dummy_ref); + } + return dependencies; +} + const ResourceSet &TaskSpecification::GetRequiredPlacementResources() const { return *required_placement_resources_; } diff --git a/src/ray/common/task/task_spec.h b/src/ray/common/task/task_spec.h index cb58b44e1..ccd7db733 100644 --- a/src/ray/common/task/task_spec.h +++ b/src/ray/common/task/task_spec.h @@ -20,6 +20,13 @@ namespace ray { typedef ResourceSet SchedulingClassDescriptor; typedef int SchedulingClass; +static inline rpc::ObjectReference GetReferenceForActorDummyObject( + const ObjectID &object_id) { + rpc::ObjectReference ref; + ref.set_object_id(object_id.Binary()); + return ref; +}; + /// Wrapper class of protobuf `TaskSpec`, see `common.proto` for details. /// TODO(ekl) we should consider passing around std::unique_ptrs /// instead `const TaskSpecification`, since this class is actually mutable. @@ -71,6 +78,8 @@ class TaskSpecification : public MessageWrapper { ObjectID ArgId(size_t arg_index) const; + rpc::ObjectReference ArgRef(size_t arg_index) const; + ObjectID ReturnId(size_t return_index) const; const uint8_t *ArgData(size_t arg_index) const; @@ -109,11 +118,18 @@ class TaskSpecification : public MessageWrapper { /// \return The resources that are required to place a task on a node. const ResourceSet &GetRequiredPlacementResources() const; + /// Return the ObjectIDs of any dependencies passed by reference to this + /// task. This is recomputed each time, so it can be used if the task spec is + /// mutated. + /// + /// \return The recomputed IDs of the dependencies for the task. + std::vector GetDependencyIds() const; + /// Return the dependencies of this task. This is recomputed each time, so it can /// be used if the task spec is mutated. /// /// \return The recomputed dependencies for the task. - std::vector GetDependencies() const; + std::vector GetDependencies() const; bool IsDriverTask() const; diff --git a/src/ray/common/test_util.h b/src/ray/common/test_util.h index 3516bb337..529e94a4a 100644 --- a/src/ray/common/test_util.h +++ b/src/ray/common/test_util.h @@ -23,10 +23,22 @@ #include "gtest/gtest.h" #include "ray/common/id.h" +#include "ray/protobuf/common.pb.h" #include "ray/util/util.h" namespace ray { +static inline std::vector ObjectIdsToRefs( + std::vector object_ids) { + std::vector refs; + for (const auto &object_id : object_ids) { + rpc::ObjectReference ref; + ref.set_object_id(object_id.Binary()); + refs.push_back(ref); + } + return refs; +} + class Buffer; class RayObject; diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index d640b228d..a9f6e3023 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -354,7 +354,8 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ actor_reporter_ = std::unique_ptr(new ActorReporter(gcs_client_)); plasma_store_provider_.reset(new CoreWorkerPlasmaStoreProvider( - options_.store_socket, local_raylet_client_, options_.check_signals, + options_.store_socket, local_raylet_client_, reference_counter_, + options_.check_signals, /*evict_if_full=*/RayConfig::instance().object_pinning_enabled(), boost::bind(&CoreWorker::TriggerGlobalGC, this), boost::bind(&CoreWorker::CurrentCallSite, this))); @@ -368,7 +369,9 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ auto check_node_alive_fn = [this](const ClientID &node_id) { auto node = gcs_client_->Nodes().Get(node_id); - RAY_CHECK(node.has_value()); + if (!node) { + return false; + } return node->state() == rpc::GcsNodeInfo::ALIVE; }; auto reconstruct_object_callback = [this](const ObjectID &object_id) { @@ -447,7 +450,8 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ future_resolver_.reset(new FutureResolver(memory_store_, client_factory, rpc_address_)); // Unfortunately the raylet client has to be constructed after the receivers. if (direct_task_receiver_ != nullptr) { - direct_task_receiver_->Init(client_factory, rpc_address_, local_raylet_client_); + task_argument_waiter_.reset(new DependencyWaiterImpl(*local_raylet_client_)); + direct_task_receiver_->Init(client_factory, rpc_address_, task_argument_waiter_); } actor_manager_ = std::unique_ptr( @@ -1415,12 +1419,8 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec, // execution and unpinned once the task completes. We will notify the caller // about any IDs that we are still borrowing by the time the task completes. std::vector borrowed_ids; - RAY_CHECK_OK(BuildArgsForExecutor(task_spec, &args, &arg_reference_ids, &borrowed_ids)); - // Pin the borrowed IDs for the duration of the task. - for (const auto &borrowed_id : borrowed_ids) { - RAY_LOG(DEBUG) << "Incrementing ref for borrowed ID " << borrowed_id; - reference_counter_->AddLocalReference(borrowed_id, task_spec.CallSiteString()); - } + RAY_CHECK_OK( + GetAndPinArgsForExecutor(task_spec, &args, &arg_reference_ids, &borrowed_ids)); std::vector return_ids; for (size_t i = 0; i < task_spec.NumReturns(); i++) { @@ -1468,8 +1468,9 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec, // Get the reference counts for any IDs that we borrowed during this task and // return them to the caller. This will notify the caller of any IDs that we - // (or a nested task) are still borrowing. It will also any new IDs that were - // contained in a borrowed ID that we (or a nested task) are now borrowing. + // (or a nested task) are still borrowing. It will also notify the caller of + // any new IDs that were contained in a borrowed ID that we (or a nested + // task) are now borrowing. if (!borrowed_ids.empty()) { reference_counter_->GetAndClearLocalBorrowers(borrowed_ids, borrowed_refs); } @@ -1532,10 +1533,10 @@ void CoreWorker::ExecuteTaskLocalMode(const TaskSpecification &task_spec, SetActorId(old_id); } -Status CoreWorker::BuildArgsForExecutor(const TaskSpecification &task, - std::vector> *args, - std::vector *arg_reference_ids, - std::vector *borrowed_ids) { +Status CoreWorker::GetAndPinArgsForExecutor(const TaskSpecification &task, + std::vector> *args, + std::vector *arg_reference_ids, + std::vector *borrowed_ids) { auto num_args = task.NumArgs(); args->resize(num_args); arg_reference_ids->resize(num_args); @@ -1560,10 +1561,15 @@ Status CoreWorker::BuildArgsForExecutor(const TaskSpecification &task, it->second.push_back(i); } arg_reference_ids->at(i) = arg_id; - // The task borrows all args passed by reference. Because the task does - // not have a reference to the argument ID in the frontend, it is not - // possible for the task to still be borrowing the argument by the time - // it finishes. + // Pin all args passed by reference for the duration of the task. This + // ensures that when the task completes, we can retrieve metadata about + // any borrowed ObjectIDs that were serialized in the argument's value. + RAY_LOG(DEBUG) << "Incrementing ref for argument ID " << arg_id; + reference_counter_->AddLocalReference(arg_id, task.CallSiteString()); + // Attach the argument's owner's address. This is needed to retrieve the + // value from plasma. + reference_counter_->AddBorrowedObject(arg_id, ObjectID::Nil(), + task.ArgRef(i).owner_address()); borrowed_ids->push_back(arg_id); } else { // A pass-by-value argument. @@ -1585,6 +1591,11 @@ Status CoreWorker::BuildArgsForExecutor(const TaskSpecification &task, // possible for the task to continue borrowing these arguments by the // time it finishes. for (const auto &inlined_id : task.ArgInlinedIds(i)) { + RAY_LOG(DEBUG) << "Incrementing ref for borrowed ID " << inlined_id; + // We do not need to add the ownership information here because it will + // get added once the language frontend deserializes the value, before + // the ObjectID can be used. + reference_counter_->AddLocalReference(inlined_id, task.CallSiteString()); borrowed_ids->push_back(inlined_id); } } @@ -1649,10 +1660,14 @@ void CoreWorker::HandleDirectActorCallArgWaitComplete( return; } + // Post on the task execution event loop since this may trigger the + // execution of a task that is now ready to run. task_execution_service_.post([=] { - direct_task_receiver_->HandleDirectActorCallArgWaitComplete(request, reply, - send_reply_callback); + RAY_LOG(DEBUG) << "Arg wait complete for tag " << request.tag(); + task_argument_waiter_->OnWaitComplete(request.tag()); }); + + send_reply_callback(Status::OK(), nullptr, nullptr); } void CoreWorker::HandleGetObjectStatus(const rpc::GetObjectStatusRequest &request, diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 5fa20af72..c5d9f51f9 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -849,10 +849,17 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { void ExecuteTaskLocalMode(const TaskSpecification &task_spec, const ActorID &actor_id = ActorID::Nil()); - /// Build arguments for task executor. This would loop through all the arguments - /// in task spec, and for each of them that's passed by reference (ObjectID), - /// fetch its content from store and; for arguments that are passed by value, - /// just copy their content. + /// Get the values of the task arguments for the executor. Values are + /// retrieved from the local plasma store or, if the value is inlined, from + /// the task spec. + /// + /// This also pins all plasma arguments and ObjectIDs that were contained in + /// an inlined argument by adding a local reference in the reference counter. + /// This is to ensure that we have the address of the object's owner, which + /// is needed to retrieve the value. It also ensures that when the task + /// completes, we can retrieve any metadata about objects that are still + /// being borrowed by this process. The IDs should be unpinned once the task + /// completes. /// /// \param spec[in] task Task specification. /// \param args[out] args Argument data as RayObjects. @@ -863,16 +870,16 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// // TODO(edoakes): this is a bit of a hack that's necessary because /// we have separate serialization paths for by-value and by-reference /// arguments in Python. This should ideally be handled better there. - /// \param args[out] borrowed_ids ObjectIDs that we are borrowing from the - /// task caller for the duration of the task execution. This - /// vector will be populated with all argument IDs that were - /// passed by reference and any ObjectIDs that were included - /// in the task spec's inlined arguments. - /// \return The arguments for passing to task executor. - Status BuildArgsForExecutor(const TaskSpecification &task, - std::vector> *args, - std::vector *arg_reference_ids, - std::vector *borrowed_ids); + /// \param args[out] pinned_ids ObjectIDs that should be unpinned once the + /// task completes execution. This vector will be populated + /// with all argument IDs that were passed by reference and + /// any ObjectIDs that were included in the task spec's + /// inlined arguments. + /// \return Error if the values could not be retrieved. + Status GetAndPinArgsForExecutor(const TaskSpecification &task, + std::vector> *args, + std::vector *arg_reference_ids, + std::vector *pinned_ids); /// Returns whether the message was sent to the wrong worker. The right error reply /// is sent automatically. Messages end up on the wrong worker when a worker dies @@ -1048,6 +1055,10 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// Common rpc service for all worker modules. rpc::CoreWorkerGrpcService grpc_service_; + /// Used to notify the task receiver when the arguments of a queued + /// actor task are ready. + std::shared_ptr task_argument_waiter_; + // Interface that receives tasks from direct actor calls. std::unique_ptr direct_task_receiver_; diff --git a/src/ray/core_worker/reference_count.cc b/src/ray/core_worker/reference_count.cc index 1b2f4951b..5113f47cd 100644 --- a/src/ray/core_worker/reference_count.cc +++ b/src/ray/core_worker/reference_count.cc @@ -325,6 +325,11 @@ void ReferenceCounter::RemoveSubmittedTaskReferences( bool ReferenceCounter::GetOwner(const ObjectID &object_id, rpc::Address *owner_address) const { absl::MutexLock lock(&mutex_); + return GetOwnerInternal(object_id, owner_address); +} + +bool ReferenceCounter::GetOwnerInternal(const ObjectID &object_id, + rpc::Address *owner_address) const { auto it = object_id_refs_.find(object_id); if (it == object_id_refs_.end()) { return false; @@ -338,6 +343,31 @@ bool ReferenceCounter::GetOwner(const ObjectID &object_id, } } +std::vector ReferenceCounter::GetOwnerAddresses( + const std::vector object_ids) const { + absl::MutexLock lock(&mutex_); + std::vector owner_addresses; + for (const auto &object_id : object_ids) { + rpc::Address owner_addr; + bool has_owner = GetOwnerInternal(object_id, &owner_addr); + if (!has_owner) { + RAY_LOG(WARNING) + << " Object IDs generated randomly (ObjectID.from_random()) or out-of-band " + "(ObjectID.from_binary(...)) cannot be passed to ray.get(), ray.wait(), or " + "as " + "a task argument because Ray does not know which task will create them. " + "If this was not how your object ID was generated, please file an issue " + "at https://github.com/ray-project/ray/issues/"; + // TODO(swang): Java does not seem to keep the ref count properly, so the + // entry may get deleted. + owner_addresses.push_back(rpc::Address()); + } else { + owner_addresses.push_back(owner_addr); + } + } + return owner_addresses; +} + void ReferenceCounter::FreePlasmaObjects(const std::vector &object_ids) { absl::MutexLock lock(&mutex_); for (const ObjectID &object_id : object_ids) { diff --git a/src/ray/core_worker/reference_count.h b/src/ray/core_worker/reference_count.h index 915572472..155090eb9 100644 --- a/src/ray/core_worker/reference_count.h +++ b/src/ray/core_worker/reference_count.h @@ -190,13 +190,24 @@ class ReferenceCounter : public ReferenceCounterInterface { bool AddBorrowedObject(const ObjectID &object_id, const ObjectID &outer_id, const rpc::Address &owner_address) LOCKS_EXCLUDED(mutex_); - /// Get the owner ID and address of the given object. + /// Get the owner address of the given object. /// /// \param[in] object_id The ID of the object to look up. /// \param[out] owner_address The address of the object owner. + /// \return false if the object is out of scope or we do not yet have + /// ownership information. The latter can happen when object IDs are pasesd + /// out of band. bool GetOwner(const ObjectID &object_id, rpc::Address *owner_address = nullptr) const LOCKS_EXCLUDED(mutex_); + /// Get the owner addresses of the given objects. The owner address + /// must be registered for these objects. + /// + /// \param[in] object_ids The IDs of the object to look up. + /// \return The addresses of the objects' owners. + std::vector GetOwnerAddresses( + const std::vector object_ids) const; + /// Release the underlying value from plasma (if any) for these objects. /// /// \param[in] object_ids The IDs whose values to free. @@ -498,6 +509,10 @@ class ReferenceCounter : public ReferenceCounterInterface { using ReferenceTable = absl::flat_hash_map; + bool GetOwnerInternal(const ObjectID &object_id, + rpc::Address *owner_address = nullptr) const + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + /// Release the pinned plasma object, if any. Also unsets the raylet address /// that the object was pinned at, if the address was set. void ReleasePlasmaObject(ReferenceTable::iterator it); diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.cc b/src/ray/core_worker/store_provider/plasma_store_provider.cc index fb7e4d967..616f8ed54 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.cc +++ b/src/ray/core_worker/store_provider/plasma_store_provider.cc @@ -24,10 +24,12 @@ namespace ray { CoreWorkerPlasmaStoreProvider::CoreWorkerPlasmaStoreProvider( const std::string &store_socket, const std::shared_ptr raylet_client, + const std::shared_ptr reference_counter, std::function check_signals, bool evict_if_full, std::function on_store_full, std::function get_current_call_site) : raylet_client_(raylet_client), + reference_counter_(reference_counter), check_signals_(check_signals), evict_if_full_(evict_if_full), on_store_full_(on_store_full) { @@ -156,8 +158,10 @@ Status CoreWorkerPlasmaStoreProvider::FetchAndGetFromPlasmaStore( int64_t timeout_ms, bool fetch_only, bool in_direct_call, const TaskID &task_id, absl::flat_hash_map> *results, bool *got_exception) { + const auto owner_addresses = reference_counter_->GetOwnerAddresses(batch_ids); RAY_RETURN_NOT_OK(raylet_client_->FetchOrReconstruct( - batch_ids, fetch_only, /*mark_worker_blocked*/ !in_direct_call, task_id)); + batch_ids, owner_addresses, fetch_only, /*mark_worker_blocked*/ !in_direct_call, + task_id)); std::vector plasma_results; { @@ -335,10 +339,11 @@ Status CoreWorkerPlasmaStoreProvider::Wait( if (ctx.CurrentTaskIsDirectCall() && ctx.ShouldReleaseResourcesOnBlockingCalls()) { RAY_RETURN_NOT_OK(raylet_client_->NotifyDirectCallTaskBlocked()); } - RAY_RETURN_NOT_OK( - raylet_client_->Wait(id_vector, num_objects, call_timeout, /*wait_local*/ true, - /*mark_worker_blocked*/ !ctx.CurrentTaskIsDirectCall(), - ctx.GetCurrentTaskID(), &result_pair)); + const auto owner_addresses = reference_counter_->GetOwnerAddresses(id_vector); + RAY_RETURN_NOT_OK(raylet_client_->Wait( + id_vector, owner_addresses, num_objects, call_timeout, /*wait_local*/ true, + /*mark_worker_blocked*/ !ctx.CurrentTaskIsDirectCall(), ctx.GetCurrentTaskID(), + &result_pair)); if (result_pair.first.size() >= static_cast(num_objects)) { should_break = true; diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.h b/src/ray/core_worker/store_provider/plasma_store_provider.h index 34626e058..1ebc1f5b3 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.h +++ b/src/ray/core_worker/store_provider/plasma_store_provider.h @@ -21,6 +21,7 @@ #include "ray/common/status.h" #include "ray/core_worker/common.h" #include "ray/core_worker/context.h" +#include "ray/core_worker/reference_count.h" #include "ray/object_manager/plasma/client.h" #include "ray/raylet/raylet_client.h" @@ -35,6 +36,7 @@ class CoreWorkerPlasmaStoreProvider { CoreWorkerPlasmaStoreProvider( const std::string &store_socket, const std::shared_ptr raylet_client, + const std::shared_ptr reference_counter, std::function check_signals, bool evict_if_full, std::function on_store_full = nullptr, std::function get_current_call_site = nullptr); @@ -138,6 +140,8 @@ class CoreWorkerPlasmaStoreProvider { const std::shared_ptr raylet_client_; plasma::PlasmaClient store_client_; + /// Used to look up a plasma object's owner. + const std::shared_ptr reference_counter_; std::mutex store_client_mutex_; std::function check_signals_; const bool evict_if_full_; diff --git a/src/ray/core_worker/test/direct_actor_transport_test.cc b/src/ray/core_worker/test/direct_actor_transport_test.cc index 34266d07e..655b94e1b 100644 --- a/src/ray/core_worker/test/direct_actor_transport_test.cc +++ b/src/ray/core_worker/test/direct_actor_transport_test.cc @@ -332,12 +332,10 @@ TEST_F(DirectActorSubmitterTest, TestActorRestartRetry) { ASSERT_THAT(worker_client_->received_seq_nos, ElementsAre(0, 1, 2, 2, 0, 1)); } -class MockDependencyWaiterInterface : public DependencyWaiterInterface { +class MockDependencyWaiter : public DependencyWaiter { public: - virtual Status WaitForDirectActorCallArgs(const std::vector &object_ids, - int64_t tag) override { - return Status::OK(); - } + MOCK_METHOD2(Wait, void(const std::vector &dependencies, + std::function on_dependencies_available)); }; class MockWorkerContext : public WorkerContext { @@ -353,7 +351,7 @@ class DirectActorReceiverTest : public ::testing::Test { DirectActorReceiverTest() : worker_context_(WorkerType::WORKER, JobID::FromInt(0)), worker_client_(std::shared_ptr(new MockWorkerClient())), - dependency_client_(std::make_shared()) { + dependency_waiter_(std::make_shared()) { auto execute_task = std::bind(&DirectActorReceiverTest::MockExecuteTask, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4); @@ -361,7 +359,7 @@ class DirectActorReceiverTest : public ::testing::Test { new CoreWorkerDirectTaskReceiver(worker_context_, main_io_service_, execute_task, [] { return Status::OK(); })); receiver_->Init([&](const rpc::Address &addr) { return worker_client_; }, - rpc_address_, dependency_client_); + rpc_address_, dependency_waiter_); } Status MockExecuteTask(const TaskSpecification &task_spec, @@ -387,7 +385,7 @@ class DirectActorReceiverTest : public ::testing::Test { MockWorkerContext worker_context_; boost::asio::io_service main_io_service_; std::shared_ptr worker_client_; - std::shared_ptr dependency_client_; + std::shared_ptr dependency_waiter_; }; TEST_F(DirectActorReceiverTest, TestNewTaskFromDifferentWorker) { diff --git a/src/ray/core_worker/test/scheduling_queue_test.cc b/src/ray/core_worker/test/scheduling_queue_test.cc index 322bc79d2..98d38d3af 100644 --- a/src/ray/core_worker/test/scheduling_queue_test.cc +++ b/src/ray/core_worker/test/scheduling_queue_test.cc @@ -15,6 +15,7 @@ #include #include "gtest/gtest.h" +#include "ray/common/test_util.h" #include "ray/core_worker/transport/direct_actor_transport.h" namespace ray { @@ -23,7 +24,7 @@ class MockWaiter : public DependencyWaiter { public: MockWaiter() {} - void Wait(const std::vector &dependencies, + void Wait(const std::vector &dependencies, std::function on_dependencies_available) override { callbacks_.push_back([on_dependencies_available]() { on_dependencies_available(); }); } @@ -65,9 +66,9 @@ TEST(SchedulingQueueTest, TestWaitForObjects) { auto fn_ok = [&n_ok]() { n_ok++; }; auto fn_rej = [&n_rej]() { n_rej++; }; queue.Add(0, -1, fn_ok, fn_rej); - queue.Add(1, -1, fn_ok, fn_rej, {obj1}); - queue.Add(2, -1, fn_ok, fn_rej, {obj2}); - queue.Add(3, -1, fn_ok, fn_rej, {obj3}); + queue.Add(1, -1, fn_ok, fn_rej, ObjectIdsToRefs({obj1})); + queue.Add(2, -1, fn_ok, fn_rej, ObjectIdsToRefs({obj2})); + queue.Add(3, -1, fn_ok, fn_rej, ObjectIdsToRefs({obj3})); ASSERT_EQ(n_ok, 1); waiter.Complete(0); @@ -91,7 +92,7 @@ TEST(SchedulingQueueTest, TestWaitForObjectsNotSubjectToSeqTimeout) { auto fn_ok = [&n_ok]() { n_ok++; }; auto fn_rej = [&n_rej]() { n_rej++; }; queue.Add(0, -1, fn_ok, fn_rej); - queue.Add(1, -1, fn_ok, fn_rej, {obj1}); + queue.Add(1, -1, fn_ok, fn_rej, ObjectIdsToRefs({obj1})); ASSERT_EQ(n_ok, 1); io_service.run(); ASSERT_EQ(n_rej, 0); diff --git a/src/ray/core_worker/test/task_manager_test.cc b/src/ray/core_worker/test/task_manager_test.cc index 227ad5a37..219c59fa6 100644 --- a/src/ray/core_worker/test/task_manager_test.cc +++ b/src/ray/core_worker/test/task_manager_test.cc @@ -505,7 +505,7 @@ TEST_F(TaskManagerLineageTest, TestResubmitTask) { // The task finished, its return ID is still in scope, and the return object // was stored in plasma. It is okay to resubmit it now. ASSERT_TRUE(manager_.ResubmitTask(spec.TaskId(), &resubmitted_task_deps).ok()); - ASSERT_EQ(resubmitted_task_deps, spec.GetDependencies()); + ASSERT_EQ(resubmitted_task_deps, spec.GetDependencyIds()); ASSERT_EQ(num_retries_, 1); resubmitted_task_deps.clear(); diff --git a/src/ray/core_worker/transport/direct_actor_transport.cc b/src/ray/core_worker/transport/direct_actor_transport.cc index 10716939a..726de288c 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.cc +++ b/src/ray/core_worker/transport/direct_actor_transport.cc @@ -269,8 +269,8 @@ bool CoreWorkerDirectActorTaskSubmitter::IsActorAlive(const ActorID &actor_id) c void CoreWorkerDirectTaskReceiver::Init( rpc::ClientFactoryFn client_factory, rpc::Address rpc_address, - std::shared_ptr dependency_client) { - waiter_.reset(new DependencyWaiterImpl(*dependency_client)); + std::shared_ptr dependency_waiter) { + waiter_ = std::move(dependency_waiter); rpc_address_ = rpc_address; client_factory_ = client_factory; } @@ -293,13 +293,6 @@ void CoreWorkerDirectTaskReceiver::HandlePushTask( return; } - std::vector dependencies; - for (size_t i = 0; i < task_spec.NumArgs(); ++i) { - if (task_spec.ArgByRef(i)) { - dependencies.push_back(task_spec.ArgId(i)); - } - } - // Only assign resources for non-actor tasks. Actor tasks inherit the resources // assigned at initial actor creation time. std::shared_ptr resource_ids; @@ -394,17 +387,14 @@ void CoreWorkerDirectTaskReceiver::HandlePushTask( SchedulingQueue(task_main_io_service_, *waiter_, worker_context_)); it = result.first; } + auto dependencies = task_spec.GetDependencies(); + // Pop the dummy actor dependency. + if (task_spec.IsActorTask()) { + // TODO(swang): Remove this with legacy raylet code. + dependencies.pop_back(); + } it->second.Add(request.sequence_number(), request.client_processed_up_to(), accept_callback, reject_callback, dependencies); } -void CoreWorkerDirectTaskReceiver::HandleDirectActorCallArgWaitComplete( - const rpc::DirectActorCallArgWaitCompleteRequest &request, - rpc::DirectActorCallArgWaitCompleteReply *reply, - rpc::SendReplyCallback send_reply_callback) { - RAY_LOG(DEBUG) << "Arg wait complete for tag " << request.tag(); - waiter_->OnWaitComplete(request.tag()); - send_reply_callback(Status::OK(), nullptr, nullptr); -} - } // namespace ray diff --git a/src/ray/core_worker/transport/direct_actor_transport.h b/src/ray/core_worker/transport/direct_actor_transport.h index f4ff3d34e..9494d1144 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.h +++ b/src/ray/core_worker/transport/direct_actor_transport.h @@ -241,7 +241,7 @@ class InboundRequest { class DependencyWaiter { public: /// Calls `callback` once the specified objects become available. - virtual void Wait(const std::vector &dependencies, + virtual void Wait(const std::vector &dependencies, std::function on_dependencies_available) = 0; }; @@ -250,7 +250,7 @@ class DependencyWaiterImpl : public DependencyWaiter { DependencyWaiterImpl(DependencyWaiterInterface &dependency_client) : dependency_client_(dependency_client) {} - void Wait(const std::vector &dependencies, + void Wait(const std::vector &dependencies, std::function on_dependencies_available) override { auto tag = next_request_id_++; requests_[tag] = on_dependencies_available; @@ -320,7 +320,7 @@ class SchedulingQueue { void Add(int64_t seq_no, int64_t client_processed_up_to, std::function accept_request, std::function reject_request, - const std::vector &dependencies = {}) { + const std::vector &dependencies = {}) { if (seq_no == -1) { accept_request(); // A seq_no of -1 means no ordering constraint. return; @@ -474,7 +474,7 @@ class CoreWorkerDirectTaskReceiver { /// Initialize this receiver. This must be called prior to use. void Init(rpc::ClientFactoryFn client_factory, rpc::Address rpc_address, - std::shared_ptr dependency_client); + std::shared_ptr dependency_waiter); /// Handle a `PushTask` request. /// @@ -484,16 +484,6 @@ class CoreWorkerDirectTaskReceiver { void HandlePushTask(const rpc::PushTaskRequest &request, rpc::PushTaskReply *reply, rpc::SendReplyCallback send_reply_callback); - /// Handle a `DirectActorCallArgWaitComplete` request. - /// - /// \param[in] request The request message. - /// \param[out] reply The reply message. - /// \param[in] send_reply_callback The callback to be called when the request is done. - void HandleDirectActorCallArgWaitComplete( - const rpc::DirectActorCallArgWaitCompleteRequest &request, - rpc::DirectActorCallArgWaitCompleteReply *reply, - rpc::SendReplyCallback send_reply_callback); - private: // Worker context. WorkerContext &worker_context_; @@ -508,7 +498,7 @@ class CoreWorkerDirectTaskReceiver { /// Address of our RPC server. rpc::Address rpc_address_; /// Shared waiter for dependencies required by incoming tasks. - std::unique_ptr waiter_; + std::shared_ptr waiter_; /// Queue of pending requests per actor handle. /// TODO(ekl) GC these queues once the handle is no longer active. std::unordered_map scheduling_queue_; diff --git a/src/ray/core_worker/transport/direct_task_transport.cc b/src/ray/core_worker/transport/direct_task_transport.cc index 133639f62..14a3f34c5 100644 --- a/src/ray/core_worker/transport/direct_task_transport.cc +++ b/src/ray/core_worker/transport/direct_task_transport.cc @@ -57,7 +57,7 @@ Status CoreWorkerDirectTaskSubmitter::SubmitTask(TaskSpecification task_spec) { // Note that the dependencies in the task spec are mutated to only contain // plasma dependencies after ResolveDependencies finishes. const SchedulingKey scheduling_key( - task_spec.GetSchedulingClass(), task_spec.GetDependencies(), + task_spec.GetSchedulingClass(), task_spec.GetDependencyIds(), task_spec.IsActorCreationTask() ? task_spec.ActorCreationId() : ActorID::Nil()); auto it = task_queues_.find(scheduling_key); @@ -300,7 +300,7 @@ Status CoreWorkerDirectTaskSubmitter::CancelTask(TaskSpecification task_spec, bool force_kill) { RAY_LOG(INFO) << "Killing task: " << task_spec.TaskId(); const SchedulingKey scheduling_key( - task_spec.GetSchedulingClass(), task_spec.GetDependencies(), + task_spec.GetSchedulingClass(), task_spec.GetDependencyIds(), task_spec.IsActorCreationTask() ? task_spec.ActorCreationId() : ActorID::Nil()); std::shared_ptr client = nullptr; { diff --git a/src/ray/raylet/format/node_manager.fbs b/src/ray/raylet/format/node_manager.fbs index 266d9d141..7c7df6d56 100644 --- a/src/ray/raylet/format/node_manager.fbs +++ b/src/ray/raylet/format/node_manager.fbs @@ -188,9 +188,20 @@ table ForwardTaskRequest { uncommitted_tasks: [Task]; } +// Mimics the Address protobuf. +table Address { + raylet_id: string; + ip_address: string; + port: int; + // Optional unique id for the worker. + worker_id: string; +} + table FetchOrReconstruct { // List of object IDs of the objects that we want to reconstruct or fetch. object_ids: [string]; + // The RPC addresses of the workers that own the objects in object_ids. + owner_addresses: [Address]; // Do we only want to fetch the objects or also reconstruct them? fetch_only: bool; // False for direct call tasks. Blocking for those tasks is handled via the @@ -214,6 +225,8 @@ table NotifyDirectCallTaskUnblocked { table WaitRequest { // List of object ids we'll be waiting on. object_ids: [string]; + // The RPC addresses of the workers that own the objects in object_ids. + owner_addresses: [Address]; // Number of objects expected to be returned, if available. num_ready_objects: int; // timeout @@ -237,6 +250,8 @@ table WaitReply { table WaitForDirectActorCallArgsRequest { // List of object ids we'll be waiting on. object_ids: [string]; + // The RPC addresses of the workers that own the objects in object_ids. + owner_addresses: [Address]; // Id used to uniquely identify this request. This is sent back to the core // worker to notify the wait has completed. tag: int; diff --git a/src/ray/raylet/lineage_cache.cc b/src/ray/raylet/lineage_cache.cc index 44c585ede..6730d9884 100644 --- a/src/ray/raylet/lineage_cache.cc +++ b/src/ray/raylet/lineage_cache.cc @@ -62,7 +62,7 @@ void LineageEntry::ComputeParentTaskIds() { parent_task_ids_.clear(); // A task's parents are the tasks that created its arguments. for (const auto &dependency : task_.GetDependencies()) { - parent_task_ids_.insert(dependency.TaskId()); + parent_task_ids_.insert(ObjectID::FromBinary(dependency.object_id()).TaskId()); } } diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index e9e8e5bd5..1e13f6a4c 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -78,6 +78,25 @@ ActorStats GetActorStatisticalData( return item; } +std::vector FlatbufferToObjectReference( + const flatbuffers::Vector> &object_ids, + const flatbuffers::Vector> + &owner_addresses) { + RAY_CHECK(object_ids.size() == owner_addresses.size()); + std::vector refs; + for (int64_t i = 0; i < object_ids.size(); i++) { + ray::rpc::ObjectReference ref; + ref.set_object_id(object_ids.Get(i)->str()); + const auto &addr = owner_addresses.Get(i); + ref.mutable_owner_address()->set_raylet_id(addr->raylet_id()->str()); + ref.mutable_owner_address()->set_ip_address(addr->ip_address()->str()); + ref.mutable_owner_address()->set_port(addr->port()); + ref.mutable_owner_address()->set_worker_id(addr->worker_id()->str()); + refs.emplace_back(std::move(ref)); + } + return refs; +} + } // namespace namespace ray { @@ -1461,28 +1480,25 @@ void NodeManager::ProcessDisconnectClientMessage( void NodeManager::ProcessFetchOrReconstructMessage( const std::shared_ptr &client, const uint8_t *message_data) { auto message = flatbuffers::GetRoot(message_data); - std::vector required_object_ids; - for (int64_t i = 0; i < message->object_ids()->size(); ++i) { - ObjectID object_id = from_flatbuf(*message->object_ids()->Get(i)); - if (message->fetch_only()) { + const auto refs = + FlatbufferToObjectReference(*message->object_ids(), *message->owner_addresses()); + if (message->fetch_only()) { + for (const auto &ref : refs) { + ObjectID object_id = ObjectID::FromBinary(ref.object_id()); // If only a fetch is required, then do not subscribe to the // dependencies to the task dependency manager. if (!task_dependency_manager_.CheckObjectLocal(object_id)) { // Fetch the object if it's not already local. RAY_CHECK_OK(object_manager_.Pull(object_id)); } - } else { - // If reconstruction is also required, then add any requested objects to - // the list to subscribe to in the task dependency manager. These objects - // will be pulled from remote node managers and restarted if - // necessary. - required_object_ids.push_back(object_id); } - } - - if (!required_object_ids.empty()) { + } else { + // The values are needed. Add all requested objects to the list to + // subscribe to in the task dependency manager. These objects will be + // pulled from remote node managers. If an object's owner dies, an error + // will be stored as the object's value. const TaskID task_id = from_flatbuf(*message->task_id()); - AsyncResolveObjects(client, required_object_ids, task_id, /*ray_get=*/true, + AsyncResolveObjects(client, refs, task_id, /*ray_get=*/true, /*mark_worker_blocked*/ message->mark_worker_blocked()); } } @@ -1496,21 +1512,24 @@ void NodeManager::ProcessWaitRequestMessage( uint64_t num_required_objects = static_cast(message->num_ready_objects()); bool wait_local = message->wait_local(); - std::vector required_object_ids; + bool resolve_objects = false; for (auto const &object_id : object_ids) { if (!task_dependency_manager_.CheckObjectLocal(object_id)) { - // Add any missing objects to the list to subscribe to in the task - // dependency manager. These objects will be pulled from remote node - // managers and restarted if necessary. - required_object_ids.push_back(object_id); + // At least one object requires resolution. + resolve_objects = true; } } const TaskID ¤t_task_id = from_flatbuf(*message->task_id()); - bool resolve_objects = !required_object_ids.empty(); bool was_blocked = message->mark_worker_blocked(); if (resolve_objects) { - AsyncResolveObjects(client, required_object_ids, current_task_id, /*ray_get=*/false, + // Resolve any missing objects. This is a no-op for any objects that are + // already local. Missing objects will be pulled from remote node managers. + // If an object's owner dies, an error will be stored as the object's + // value. + const auto refs = + FlatbufferToObjectReference(*message->object_ids(), *message->owner_addresses()); + AsyncResolveObjects(client, refs, current_task_id, /*ray_get=*/false, /*mark_worker_blocked*/ was_blocked); } @@ -1545,18 +1564,18 @@ void NodeManager::ProcessWaitForDirectActorCallArgsRequestMessage( // Read the data. auto message = flatbuffers::GetRoot(message_data); - int64_t tag = message->tag(); std::vector object_ids = from_flatbuf(*message->object_ids()); - std::vector required_object_ids; - for (auto const &object_id : object_ids) { - if (!task_dependency_manager_.CheckObjectLocal(object_id)) { - // Add any missing objects to the list to subscribe to in the task - // dependency manager. These objects will be pulled from remote node - // managers and restarted if necessary. - required_object_ids.push_back(object_id); - } - } - + int64_t tag = message->tag(); + // Resolve any missing objects. This will pull the objects from remote node + // managers or store an error if the objects have failed. + const auto refs = + FlatbufferToObjectReference(*message->object_ids(), *message->owner_addresses()); + AsyncResolveObjects(client, refs, TaskID::Nil(), /*ray_get=*/false, + /*mark_worker_blocked*/ false); + // Reply to the client once a location has been found for all arguments. + // NOTE(swang): ObjectManager::Wait currently returns as soon as any location + // has been found, so the object may still be on a remote node when the + // client receives the reply. ray::Status status = object_manager_.Wait( object_ids, -1, object_ids.size(), false, [this, client, tag](std::vector found, std::vector remaining) { @@ -1746,11 +1765,11 @@ void NodeManager::NewSchedulerSchedulePendingTasks() { void NodeManager::WaitForTaskArgsRequests(std::pair &work) { RAY_CHECK(new_scheduler_enabled_); const Task &task = work.second; - std::vector object_ids = task.GetTaskSpecification().GetDependencies(); + const auto &object_refs = task.GetDependencies(); - if (object_ids.size() > 0) { + if (object_refs.size() > 0) { bool args_ready = task_dependency_manager_.SubscribeGetDependencies( - task.GetTaskSpecification().TaskId(), task.GetDependencies()); + task.GetTaskSpecification().TaskId(), object_refs); if (args_ready) { task_dependency_manager_.UnsubscribeGetDependencies( task.GetTaskSpecification().TaskId()); @@ -2321,8 +2340,8 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag // Once the actor has been created and this method removed from the // waiting queue, the caller must make the corresponding call to // UnsubscribeGetDependencies. - task_dependency_manager_.SubscribeGetDependencies(spec.TaskId(), - {actor_creation_dummy_object}); + task_dependency_manager_.SubscribeGetDependencies( + spec.TaskId(), {GetReferenceForActorDummyObject(actor_creation_dummy_object)}); // Mark the task as pending. It will be canceled once we discover the // actor's location and either execute the task ourselves or forward it // to another node. @@ -2426,10 +2445,10 @@ void NodeManager::HandleDirectCallTaskUnblocked(const std::shared_ptr &w worker->MarkUnblocked(); } -void NodeManager::AsyncResolveObjects(const std::shared_ptr &client, - const std::vector &required_object_ids, - const TaskID ¤t_task_id, bool ray_get, - bool mark_worker_blocked) { +void NodeManager::AsyncResolveObjects( + const std::shared_ptr &client, + const std::vector &required_object_refs, + const TaskID ¤t_task_id, bool ray_get, bool mark_worker_blocked) { std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); if (worker) { // The client is a worker. If the worker is not already blocked and the @@ -2473,11 +2492,11 @@ void NodeManager::AsyncResolveObjects(const std::shared_ptr &c // HandleDirectCallUnblocked. auto &task_id = mark_worker_blocked ? current_task_id : worker->GetAssignedTaskId(); if (!task_id.IsNil()) { - task_dependency_manager_.SubscribeGetDependencies(task_id, required_object_ids); + task_dependency_manager_.SubscribeGetDependencies(task_id, required_object_refs); } } else { task_dependency_manager_.SubscribeWaitDependencies(worker->WorkerId(), - required_object_ids); + required_object_refs); } } diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index ffd4253ec..2b3d3f73b 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -353,14 +353,14 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// wait call. /// /// \param client The client that is executing the blocked task. - /// \param required_object_ids The IDs that the client is blocked waiting for. + /// \param required_object_refs The objects that the client is blocked waiting for. /// \param current_task_id The task that is blocked. /// \param ray_get Whether the task is blocked in a `ray.get` call. /// \param mark_worker_blocked Whether to mark the worker as blocked. This /// should be False for direct calls. /// \return Void. void AsyncResolveObjects(const std::shared_ptr &client, - const std::vector &required_object_ids, + const std::vector &required_object_refs, const TaskID ¤t_task_id, bool ray_get, bool mark_worker_blocked); diff --git a/src/ray/raylet/raylet_client.cc b/src/ray/raylet/raylet_client.cc index 054e12dfe..fbde68211 100644 --- a/src/ray/raylet/raylet_client.cc +++ b/src/ray/raylet/raylet_client.cc @@ -34,6 +34,24 @@ using MessageType = ray::protocol::MessageType; +namespace { + +flatbuffers::Offset>> +AddressesToFlatbuffer(flatbuffers::FlatBufferBuilder &fbb, + const std::vector &addresses) { + std::vector> address_vec; + address_vec.reserve(addresses.size()); + for (const auto &addr : addresses) { + auto fbb_addr = ray::protocol::CreateAddress( + fbb, fbb.CreateString(addr.raylet_id()), fbb.CreateString(addr.ip_address()), + addr.port(), fbb.CreateString(addr.worker_id())); + address_vec.push_back(fbb_addr); + } + return fbb.CreateVector(address_vec); +} + +} // namespace + namespace ray { static int read_bytes(local_stream_socket &conn, void *cursor, size_t length) { @@ -216,14 +234,16 @@ Status raylet::RayletClient::TaskDone() { return conn_->WriteMessage(MessageType::TaskDone); } -Status raylet::RayletClient::FetchOrReconstruct(const std::vector &object_ids, - bool fetch_only, bool mark_worker_blocked, - const TaskID ¤t_task_id) { +Status raylet::RayletClient::FetchOrReconstruct( + const std::vector &object_ids, + const std::vector &owner_addresses, bool fetch_only, + bool mark_worker_blocked, const TaskID ¤t_task_id) { + RAY_CHECK(object_ids.size() == owner_addresses.size()); flatbuffers::FlatBufferBuilder fbb; auto object_ids_message = to_flatbuf(fbb, object_ids); - auto message = protocol::CreateFetchOrReconstruct(fbb, object_ids_message, fetch_only, - mark_worker_blocked, - to_flatbuf(fbb, current_task_id)); + auto message = protocol::CreateFetchOrReconstruct( + fbb, object_ids_message, AddressesToFlatbuffer(fbb, owner_addresses), fetch_only, + mark_worker_blocked, to_flatbuf(fbb, current_task_id)); fbb.Finish(message); auto status = conn_->WriteMessage(MessageType::FetchOrReconstruct, &fbb); return status; @@ -251,14 +271,16 @@ Status raylet::RayletClient::NotifyDirectCallTaskUnblocked() { } Status raylet::RayletClient::Wait(const std::vector &object_ids, + const std::vector &owner_addresses, int num_returns, int64_t timeout_milliseconds, bool wait_local, bool mark_worker_blocked, const TaskID ¤t_task_id, WaitResultPair *result) { // Write request. flatbuffers::FlatBufferBuilder fbb; auto message = protocol::CreateWaitRequest( - fbb, to_flatbuf(fbb, object_ids), num_returns, timeout_milliseconds, wait_local, - mark_worker_blocked, to_flatbuf(fbb, current_task_id)); + fbb, to_flatbuf(fbb, object_ids), AddressesToFlatbuffer(fbb, owner_addresses), + num_returns, timeout_milliseconds, wait_local, mark_worker_blocked, + to_flatbuf(fbb, current_task_id)); fbb.Finish(message); std::unique_ptr reply; auto status = conn_->AtomicRequestReply(MessageType::WaitRequest, @@ -280,10 +302,16 @@ Status raylet::RayletClient::Wait(const std::vector &object_ids, } Status raylet::RayletClient::WaitForDirectActorCallArgs( - const std::vector &object_ids, int64_t tag) { + const std::vector &references, int64_t tag) { flatbuffers::FlatBufferBuilder fbb; + std::vector object_ids; + std::vector owner_addresses; + for (const auto &ref : references) { + object_ids.push_back(ObjectID::FromBinary(ref.object_id())); + owner_addresses.push_back(ref.owner_address()); + } auto message = protocol::CreateWaitForDirectActorCallArgsRequest( - fbb, to_flatbuf(fbb, object_ids), tag); + fbb, to_flatbuf(fbb, object_ids), AddressesToFlatbuffer(fbb, owner_addresses), tag); fbb.Finish(message); return conn_->WriteMessage(MessageType::WaitForDirectActorCallArgsRequest, &fbb); } diff --git a/src/ray/raylet/raylet_client.h b/src/ray/raylet/raylet_client.h index 8678582e8..37409bf5d 100644 --- a/src/ray/raylet/raylet_client.h +++ b/src/ray/raylet/raylet_client.h @@ -89,11 +89,11 @@ class DependencyWaiterInterface { /// Wait for the given objects, asynchronously. The core worker is notified when /// the wait completes. /// - /// \param object_ids The objects to wait for. + /// \param references The objects to wait for. /// \param tag Value that will be sent to the core worker via gRPC on completion. /// \return ray::Status. - virtual ray::Status WaitForDirectActorCallArgs(const std::vector &object_ids, - int64_t tag) = 0; + virtual ray::Status WaitForDirectActorCallArgs( + const std::vector &references, int64_t tag) = 0; virtual ~DependencyWaiterInterface(){}; }; @@ -191,13 +191,16 @@ class RayletClient : public PinObjectsInterface, /// Tell the raylet to reconstruct or fetch objects. /// - /// \param object_ids The IDs of the objects to reconstruct. + /// \param object_ids The IDs of the objects to fetch. + /// \param owner_addresses The addresses of the workers that own the objects. /// \param fetch_only Only fetch objects, do not reconstruct them. /// \param mark_worker_blocked Set to false if current task is a direct call task. /// \param current_task_id The task that needs the objects. /// \return int 0 means correct, other numbers mean error. - ray::Status FetchOrReconstruct(const std::vector &object_ids, bool fetch_only, - bool mark_worker_blocked, const TaskID ¤t_task_id); + ray::Status FetchOrReconstruct(const std::vector &object_ids, + const std::vector &owner_addresses, + bool fetch_only, bool mark_worker_blocked, + const TaskID ¤t_task_id); /// Notify the raylet that this client (worker) is no longer blocked. /// @@ -221,6 +224,7 @@ class RayletClient : public PinObjectsInterface, /// found. /// /// \param object_ids The objects to wait for. + /// \param owner_addresses The addresses of the workers that own the objects. /// \param num_returns The number of objects to wait for. /// \param timeout_milliseconds Duration, in milliseconds, to wait before returning. /// \param wait_local Whether to wait for objects to appear on this node. @@ -229,7 +233,8 @@ class RayletClient : public PinObjectsInterface, /// \param result A pair with the first element containing the object ids that were /// found, and the second element the objects that were not found. /// \return ray::Status. - ray::Status Wait(const std::vector &object_ids, int num_returns, + ray::Status Wait(const std::vector &object_ids, + const std::vector &owner_addresses, int num_returns, int64_t timeout_milliseconds, bool wait_local, bool mark_worker_blocked, const TaskID ¤t_task_id, WaitResultPair *result); @@ -237,11 +242,11 @@ class RayletClient : public PinObjectsInterface, /// Wait for the given objects, asynchronously. The core worker is notified when /// the wait completes. /// - /// \param object_ids The objects to wait for. + /// \param references The objects to wait for. /// \param tag Value that will be sent to the core worker via gRPC on completion. /// \return ray::Status. - ray::Status WaitForDirectActorCallArgs(const std::vector &object_ids, - int64_t tag) override; + ray::Status WaitForDirectActorCallArgs( + const std::vector &references, int64_t tag) override; /// Push an error to the relevant driver. /// diff --git a/src/ray/raylet/task_dependency_manager.cc b/src/ray/raylet/task_dependency_manager.cc index 60a84f5fb..ab6c1f26e 100644 --- a/src/ray/raylet/task_dependency_manager.cc +++ b/src/ray/raylet/task_dependency_manager.cc @@ -173,11 +173,12 @@ std::vector TaskDependencyManager::HandleObjectMissing( } bool TaskDependencyManager::SubscribeGetDependencies( - const TaskID &task_id, const std::vector &required_objects) { + const TaskID &task_id, const std::vector &required_objects) { auto &task_entry = task_dependencies_[task_id]; // Record the task's dependencies. - for (const auto &object_id : required_objects) { + for (const auto &object : required_objects) { + const auto &object_id = ObjectID::FromBinary(object.object_id()); auto inserted = task_entry.get_dependencies.insert(object_id); if (inserted.second) { RAY_LOG(DEBUG) << "Task " << task_id << " blocked on object " << object_id; @@ -188,15 +189,23 @@ bool TaskDependencyManager::SubscribeGetDependencies( // The object is not local. task_entry.num_missing_get_dependencies++; } + + auto it = required_tasks_[creating_task_id].find(object_id); + if (it == required_tasks_[creating_task_id].end()) { + it = required_tasks_[creating_task_id] + .emplace(object_id, ObjectDependencies(object)) + .first; + } // Add the subscribed task to the mapping from object ID to list of // dependent tasks. - required_tasks_[creating_task_id][object_id].dependent_tasks.insert(task_id); + it->second.dependent_tasks.insert(task_id); } } // These dependencies are required by the given task. Try to make them local // if necessary. - for (const auto &object_id : required_objects) { + for (const auto &object : required_objects) { + const auto &object_id = ObjectID::FromBinary(object.object_id()); HandleRemoteDependencyRequired(object_id); } @@ -205,11 +214,13 @@ bool TaskDependencyManager::SubscribeGetDependencies( } void TaskDependencyManager::SubscribeWaitDependencies( - const WorkerID &worker_id, const std::vector &required_objects) { + const WorkerID &worker_id, + const std::vector &required_objects) { auto &worker_entry = worker_dependencies_[worker_id]; // Record the worker's dependencies. - for (const auto &object_id : required_objects) { + for (const auto &object : required_objects) { + const auto &object_id = ObjectID::FromBinary(object.object_id()); if (local_objects_.count(object_id) == 0) { RAY_LOG(DEBUG) << "Worker " << worker_id << " called ray.wait on remote object " << object_id; @@ -218,19 +229,24 @@ void TaskDependencyManager::SubscribeWaitDependencies( auto inserted = worker_entry.insert(object_id); if (inserted.second) { // Get the ID of the task that creates the dependency. - // TODO(qwang): Refine here to: - // if (object_id.CreatedByTask()) {// ...} TaskID creating_task_id = object_id.TaskId(); + auto it = required_tasks_[creating_task_id].find(object_id); + if (it == required_tasks_[creating_task_id].end()) { + it = required_tasks_[creating_task_id] + .emplace(object_id, ObjectDependencies(object)) + .first; + } // Add the subscribed worker to the mapping from object ID to list of // dependent workers. - required_tasks_[creating_task_id][object_id].dependent_workers.insert(worker_id); + it->second.dependent_workers.insert(worker_id); } } } // These dependencies are required by the given worker. Try to make them // local if necessary. - for (const auto &object_id : required_objects) { + for (const auto &object : required_objects) { + const auto &object_id = ObjectID::FromBinary(object.object_id()); HandleRemoteDependencyRequired(object_id); } } @@ -252,11 +268,12 @@ bool TaskDependencyManager::UnsubscribeGetDependencies(const TaskID &task_id) { auto creating_task_entry = required_tasks_.find(creating_task_id); // Remove the task from the list of tasks that are dependent on this // object. - auto &dependent_tasks = creating_task_entry->second[object_id].dependent_tasks; - RAY_CHECK(dependent_tasks.erase(task_id) > 0); + auto it = creating_task_entry->second.find(object_id); + RAY_CHECK(it != creating_task_entry->second.end()); + RAY_CHECK(it->second.dependent_tasks.erase(task_id) > 0); // If nothing else depends on the object, then erase the object entry. - if (creating_task_entry->second[object_id].Empty()) { - creating_task_entry->second.erase(object_id); + if (it->second.Empty()) { + creating_task_entry->second.erase(it); // Remove the task that creates this object if there are no more object // dependencies created by the task. if (creating_task_entry->second.empty()) { @@ -291,11 +308,12 @@ void TaskDependencyManager::UnsubscribeWaitDependencies(const WorkerID &worker_i auto creating_task_entry = required_tasks_.find(creating_task_id); // Remove the worker from the list of workers that are dependent on this // object. - auto &dependent_workers = creating_task_entry->second[object_id].dependent_workers; - RAY_CHECK(dependent_workers.erase(worker_id) > 0); + auto it = creating_task_entry->second.find(object_id); + RAY_CHECK(it != creating_task_entry->second.end()); + RAY_CHECK(it->second.dependent_workers.erase(worker_id) > 0); // If nothing else depends on the object, then erase the object entry. - if (creating_task_entry->second[object_id].Empty()) { - creating_task_entry->second.erase(object_id); + if (it->second.Empty()) { + creating_task_entry->second.erase(it); // Remove the task that creates this object if there are no more object // dependencies created by the task. if (creating_task_entry->second.empty()) { diff --git a/src/ray/raylet/task_dependency_manager.h b/src/ray/raylet/task_dependency_manager.h index dbddf4c8b..85cc7e56b 100644 --- a/src/ray/raylet/task_dependency_manager.h +++ b/src/ray/raylet/task_dependency_manager.h @@ -71,8 +71,8 @@ class TaskDependencyManager { /// \param required_objects The objects required by the task. /// \return Whether all of the given dependencies for the given task are /// local. - bool SubscribeGetDependencies(const TaskID &task_id, - const std::vector &required_objects); + bool SubscribeGetDependencies( + const TaskID &task_id, const std::vector &required_objects); /// Subscribe to object depedencies required by the worker. This should be called for /// ray.wait calls during task execution. @@ -86,8 +86,9 @@ class TaskDependencyManager { /// \param worker_id The ID of the worker that called `ray.wait`. /// \param required_objects The objects required by the worker. /// \return Void. - void SubscribeWaitDependencies(const WorkerID &worker_id, - const std::vector &required_objects); + void SubscribeWaitDependencies( + const WorkerID &worker_id, + const std::vector &required_objects); /// Unsubscribe from the object dependencies required by this task through the task /// arguments or `ray.get`. If the objects were remote and are no longer required by any @@ -165,12 +166,16 @@ class TaskDependencyManager { private: struct ObjectDependencies { + ObjectDependencies(const rpc::ObjectReference &ref) + : owner_address(ref.owner_address()) {} /// The tasks that depend on this object, either because the object is a task argument /// or because the task called `ray.get` on the object. std::unordered_set dependent_tasks; /// The workers that depend on this object because they called `ray.wait` on the /// object. std::unordered_set dependent_workers; + /// The address of the worker that owns this object. + rpc::Address owner_address; bool Empty() const { return dependent_tasks.empty() && dependent_workers.empty(); } }; diff --git a/src/ray/raylet/task_dependency_manager_test.cc b/src/ray/raylet/task_dependency_manager_test.cc index b9c447f89..7220108f1 100644 --- a/src/ray/raylet/task_dependency_manager_test.cc +++ b/src/ray/raylet/task_dependency_manager_test.cc @@ -151,7 +151,8 @@ TEST_F(TaskDependencyManagerTest, TestSimpleTask) { EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(argument_id)); } // Subscribe to the task's dependencies. - bool ready = task_dependency_manager_.SubscribeGetDependencies(task_id, arguments); + bool ready = task_dependency_manager_.SubscribeGetDependencies( + task_id, ObjectIdsToRefs(arguments)); ASSERT_FALSE(ready); // All arguments should be canceled as they become available locally. @@ -187,7 +188,8 @@ TEST_F(TaskDependencyManagerTest, TestDuplicateSubscribeGetDependencies) { // requested from the node manager once. EXPECT_CALL(object_manager_mock_, Pull(argument_id)); EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(argument_id)); - bool ready = task_dependency_manager_.SubscribeGetDependencies(task_id, arguments); + bool ready = task_dependency_manager_.SubscribeGetDependencies( + task_id, ObjectIdsToRefs(arguments)); ASSERT_FALSE(ready); } @@ -223,8 +225,8 @@ TEST_F(TaskDependencyManagerTest, TestMultipleTasks) { TaskID task_id = RandomTaskId(); dependent_tasks.push_back(task_id); // Subscribe to each of the task's dependencies. - bool ready = - task_dependency_manager_.SubscribeGetDependencies(task_id, {argument_id}); + bool ready = task_dependency_manager_.SubscribeGetDependencies( + task_id, ObjectIdsToRefs({argument_id})); ASSERT_FALSE(ready); } @@ -312,7 +314,7 @@ TEST_F(TaskDependencyManagerTest, TestDependentPut) { EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(put_id)); // Subscribe to the task's dependencies. bool ready = task_dependency_manager_.SubscribeGetDependencies( - task2.GetTaskSpecification().TaskId(), {put_id}); + task2.GetTaskSpecification().TaskId(), ObjectIdsToRefs({put_id})); ASSERT_FALSE(ready); // The put object should be considered local as soon as the task that creates @@ -373,7 +375,8 @@ TEST_F(TaskDependencyManagerTest, TestEviction) { EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(argument_id)); } // Subscribe to the task's dependencies. - bool ready = task_dependency_manager_.SubscribeGetDependencies(task_id, arguments); + bool ready = task_dependency_manager_.SubscribeGetDependencies( + task_id, ObjectIdsToRefs(arguments)); ASSERT_FALSE(ready); // Tell the task dependency manager that each of the arguments is now @@ -528,10 +531,12 @@ TEST_F(TaskDependencyManagerTest, TestWaitDependencies) { EXPECT_CALL(object_manager_mock_, Pull(_)).Times(num_objects); EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(_)) .Times(num_objects); - task_dependency_manager_.SubscribeWaitDependencies(worker_id, wait_object_ids); + task_dependency_manager_.SubscribeWaitDependencies(worker_id, + ObjectIdsToRefs(wait_object_ids)); // Check that it's okay to call `ray.wait` on the same objects again. No new // calls should be made to try and make the objects local. - task_dependency_manager_.SubscribeWaitDependencies(worker_id, wait_object_ids); + task_dependency_manager_.SubscribeWaitDependencies(worker_id, + ObjectIdsToRefs(wait_object_ids)); // Cancel the worker's `ray.wait`. calls. EXPECT_CALL(object_manager_mock_, CancelPull(_)).Times(num_objects); EXPECT_CALL(reconstruction_policy_mock_, Cancel(_)).Times(num_objects); @@ -563,7 +568,8 @@ TEST_F(TaskDependencyManagerTest, TestWaitDependenciesObjectLocal) { EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(object_id)); } } - task_dependency_manager_.SubscribeWaitDependencies(worker_id, wait_object_ids); + task_dependency_manager_.SubscribeWaitDependencies(worker_id, + ObjectIdsToRefs(wait_object_ids)); // Simulate the local object getting evicted. The `ray.wait` call should not // be reactivated. auto waiting_task_ids = task_dependency_manager_.HandleObjectMissing(local_object_id); @@ -593,7 +599,8 @@ TEST_F(TaskDependencyManagerTest, TestWaitDependenciesHandleObjectLocal) { EXPECT_CALL(object_manager_mock_, Pull(_)).Times(num_objects); EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(_)) .Times(num_objects); - task_dependency_manager_.SubscribeWaitDependencies(worker_id, wait_object_ids); + task_dependency_manager_.SubscribeWaitDependencies(worker_id, + ObjectIdsToRefs(wait_object_ids)); // Simulate one of the objects becoming local while the `ray.wait` calls is // active. The `ray.wait` call should be canceled. const ObjectID local_object_id = std::move(wait_object_ids.back());