diff --git a/src/ray/core_worker/context.cc b/src/ray/core_worker/context.cc index 59dda5c11..0eba2ebd8 100644 --- a/src/ray/core_worker/context.cc +++ b/src/ray/core_worker/context.cc @@ -55,7 +55,8 @@ WorkerContext::WorkerContext(WorkerType worker_type, const JobID &job_id) worker_id_(worker_type_ == WorkerType::DRIVER ? ComputeDriverIdFromJob(job_id) : WorkerID::FromRandom()), current_job_id_(worker_type_ == WorkerType::DRIVER ? job_id : JobID::Nil()), - current_actor_id_(ActorID::Nil()) { + current_actor_id_(ActorID::Nil()), + main_thread_id_(boost::this_thread::get_id()) { // For worker main thread which initializes the WorkerContext, // set task_id according to whether current worker is a driver. // (For other threads it's set to random ID via GetThreadContext). @@ -118,6 +119,10 @@ std::shared_ptr WorkerContext::GetCurrentTask() const { const ActorID &WorkerContext::GetCurrentActorID() const { return current_actor_id_; } +bool WorkerContext::CurrentThreadIsMain() const { + return boost::this_thread::get_id() == main_thread_id_; +} + bool WorkerContext::CurrentActorIsDirectCall() const { return current_actor_is_direct_call_; } diff --git a/src/ray/core_worker/context.h b/src/ray/core_worker/context.h index 3ced2ced1..08f79a9fd 100644 --- a/src/ray/core_worker/context.h +++ b/src/ray/core_worker/context.h @@ -1,6 +1,8 @@ #ifndef RAY_CORE_WORKER_CONTEXT_H #define RAY_CORE_WORKER_CONTEXT_H +#include + #include "ray/common/task/task_spec.h" #include "ray/core_worker/common.h" @@ -34,6 +36,9 @@ class WorkerContext { const ActorID &GetCurrentActorID() const; + /// Returns whether the current thread is the main worker thread. + bool CurrentThreadIsMain() const; + /// Returns whether we are in a direct call actor. bool CurrentActorIsDirectCall() const; @@ -56,6 +61,9 @@ class WorkerContext { bool current_task_is_direct_call_ = false; int current_actor_max_concurrency_ = 1; + /// The id of the (main) thread that constructed this worker context. + boost::thread::id main_thread_id_; + private: static WorkerThreadContext &GetThreadContext(bool for_main_thread = false); diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 36f58150e..842926078 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -165,7 +165,7 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, [this](const RayObject &obj, const ObjectID &obj_id) { RAY_CHECK_OK(plasma_store_provider_->Put(obj, obj_id)); }, - ref_counting_enabled ? reference_counter_ : nullptr)); + ref_counting_enabled ? reference_counter_ : nullptr, raylet_client_)); memory_store_provider_.reset(new CoreWorkerMemoryStoreProvider(memory_store_)); // Create an entry for the driver task in the task table. This task is @@ -263,8 +263,7 @@ void CoreWorker::ReportActiveObjectIDs() { reference_counter_->GetAllInScopeObjectIDs(); RAY_LOG(DEBUG) << "Sending " << active_object_ids.size() << " object IDs to raylet."; if (active_object_ids.size() > RayConfig::instance().raylet_max_active_object_ids()) { - RAY_LOG(WARNING) << active_object_ids.size() << "object IDs are currently in scope. " - << "This may lead to required objects being garbage collected."; + RAY_LOG(WARNING) << active_object_ids.size() << " object IDs are currently in scope."; } if (!raylet_client_->ReportActiveObjectIDs(active_object_ids).ok()) { @@ -347,8 +346,8 @@ Status CoreWorker::Get(const std::vector &ids, const int64_t timeout_m timeout_ms - (current_time_ms() - start_time)); } RAY_RETURN_NOT_OK(memory_store_provider_->Get(memory_object_ids, local_timeout_ms, - worker_context_.GetCurrentTaskID(), - &result_map, &got_exception)); + worker_context_, &result_map, + &got_exception)); } // If any of the objects have been promoted to plasma, then we retry their @@ -454,7 +453,7 @@ Status CoreWorker::Wait(const std::vector &ids, int num_objects, // consider waiting on them in plasma as well to ensure they are local. RAY_RETURN_NOT_OK(memory_store_provider_->Wait( memory_object_ids, num_objects - static_cast(ready.size()), - /*timeout_ms=*/0, worker_context_.GetCurrentTaskID(), &ready)); + /*timeout_ms=*/0, worker_context_, &ready)); } RAY_CHECK(static_cast(ready.size()) <= num_objects); @@ -477,7 +476,7 @@ Status CoreWorker::Wait(const std::vector &ids, int num_objects, if (static_cast(ready.size()) < num_objects && memory_object_ids.size() > 0) { RAY_RETURN_NOT_OK(memory_store_provider_->Wait( memory_object_ids, num_objects - static_cast(ready.size()), timeout_ms, - worker_context_.GetCurrentTaskID(), &ready)); + worker_context_, &ready)); } RAY_CHECK(static_cast(ready.size()) <= num_objects); } diff --git a/src/ray/core_worker/reference_count_test.cc b/src/ray/core_worker/reference_count_test.cc index 60d6c6191..c904f22db 100644 --- a/src/ray/core_worker/reference_count_test.cc +++ b/src/ray/core_worker/reference_count_test.cc @@ -161,7 +161,8 @@ TEST(MemoryStoreIntegrationTest, TestSimple) { RAY_CHECK_OK(store.Put(id1, buffer)); ASSERT_EQ(store.Size(), 1); std::vector> results; - RAY_CHECK_OK(store.Get({id1}, /*num_objects*/ 1, /*timeout_ms*/ -1, + WorkerContext ctx(WorkerType::WORKER, JobID::Nil()); + RAY_CHECK_OK(store.Get({id1}, /*num_objects*/ 1, /*timeout_ms*/ -1, ctx, /*remove_after_get*/ true, &results)); ASSERT_EQ(results.size(), 1); ASSERT_EQ(store.Size(), 1); diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.cc b/src/ray/core_worker/store_provider/memory_store/memory_store.cc index 1fa82e495..1154645bc 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.cc +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.cc @@ -109,8 +109,11 @@ std::shared_ptr GetRequest::Get(const ObjectID &object_id) const { CoreWorkerMemoryStore::CoreWorkerMemoryStore( std::function store_in_plasma, - std::shared_ptr counter) - : store_in_plasma_(store_in_plasma), ref_counter_(counter) {} + std::shared_ptr counter, + std::shared_ptr raylet_client) + : store_in_plasma_(store_in_plasma), + ref_counter_(counter), + raylet_client_(raylet_client) {} void CoreWorkerMemoryStore::GetAsync( const ObjectID &object_id, std::function)> callback) { @@ -208,7 +211,7 @@ Status CoreWorkerMemoryStore::Put(const ObjectID &object_id, const RayObject &ob Status CoreWorkerMemoryStore::Get(const std::vector &object_ids, int num_objects, int64_t timeout_ms, - bool remove_after_get, + const WorkerContext &ctx, bool remove_after_get, std::vector> *results) { (*results).resize(object_ids.size(), nullptr); @@ -260,8 +263,20 @@ Status CoreWorkerMemoryStore::Get(const std::vector &object_ids, } } + // Only send block/unblock IPCs for non-actor tasks on the main thread. + // TODO(ekl) support non-lifetime resources for direct actor calls. + bool should_notify_raylet = + (raylet_client_ != nullptr && !ctx.CurrentActorIsDirectCall() && + ctx.CurrentThreadIsMain()); + // Wait for remaining objects (or timeout). + if (should_notify_raylet) { + RAY_CHECK_OK(raylet_client_->NotifyDirectCallTaskBlocked()); + } bool done = get_request->Wait(timeout_ms); + if (should_notify_raylet) { + RAY_CHECK_OK(raylet_client_->NotifyDirectCallTaskUnblocked()); + } { absl::MutexLock lock(&mu_); diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.h b/src/ray/core_worker/store_provider/memory_store/memory_store.h index 504404a35..ef94e3373 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.h +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.h @@ -7,6 +7,7 @@ #include "ray/common/id.h" #include "ray/common/status.h" #include "ray/core_worker/common.h" +#include "ray/core_worker/context.h" #include "ray/core_worker/reference_count.h" namespace ray { @@ -24,9 +25,11 @@ class CoreWorkerMemoryStore { /// \param[in] store_in_plasma If not null, this is used to spill to plasma. /// \param[in] counter If not null, this enables ref counting for local objects, /// and the `remove_after_get` flag for Get() will be ignored. + /// \param[in] raylet_client If not null, used to notify tasks blocked / unblocked. CoreWorkerMemoryStore( std::function store_in_plasma = nullptr, - std::shared_ptr counter = nullptr); + std::shared_ptr counter = nullptr, + std::shared_ptr raylet_client = nullptr); ~CoreWorkerMemoryStore(){}; /// Put an object with specified ID into object store. @@ -41,12 +44,14 @@ class CoreWorkerMemoryStore { /// \param[in] object_ids IDs of the objects to get. Duplicates are not allowed. /// \param[in] num_objects Number of objects that should appear. /// \param[in] timeout_ms Timeout in milliseconds, wait infinitely if it's negative. + /// \param[in] ctx The current worker context. /// \param[in] remove_after_get When to remove the objects from store after `Get` /// finishes. This has no effect if ref counting is enabled. /// \param[out] results Result list of objects data. /// \return Status. Status Get(const std::vector &object_ids, int num_objects, int64_t timeout_ms, - bool remove_after_get, std::vector> *results); + const WorkerContext &ctx, bool remove_after_get, + std::vector> *results); /// Asynchronously get an object from the object store. The object will not be removed /// from storage after GetAsync (TODO(ekl): integrate this with object GC). @@ -93,6 +98,9 @@ class CoreWorkerMemoryStore { /// mandatory once Java is supported. std::shared_ptr ref_counter_ = nullptr; + // If set, this will be used to notify worker blocked / unblocked on get calls. + std::shared_ptr raylet_client_ = nullptr; + /// Protects the data structures below. absl::Mutex mu_; diff --git a/src/ray/core_worker/store_provider/memory_store_provider.cc b/src/ray/core_worker/store_provider/memory_store_provider.cc index 3568fa923..773c8c2dd 100644 --- a/src/ray/core_worker/store_provider/memory_store_provider.cc +++ b/src/ray/core_worker/store_provider/memory_store_provider.cc @@ -25,13 +25,13 @@ Status CoreWorkerMemoryStoreProvider::Put(const RayObject &object, Status CoreWorkerMemoryStoreProvider::Get( const absl::flat_hash_set &object_ids, int64_t timeout_ms, - const TaskID &task_id, + const WorkerContext &ctx, absl::flat_hash_map> *results, bool *got_exception) { const std::vector id_vector(object_ids.begin(), object_ids.end()); std::vector> result_objects; RAY_RETURN_NOT_OK( - store_->Get(id_vector, id_vector.size(), timeout_ms, true, &result_objects)); + store_->Get(id_vector, id_vector.size(), timeout_ms, ctx, true, &result_objects)); for (size_t i = 0; i < id_vector.size(); i++) { if (result_objects[i] != nullptr) { @@ -52,11 +52,12 @@ Status CoreWorkerMemoryStoreProvider::Contains(const ObjectID &object_id, Status CoreWorkerMemoryStoreProvider::Wait( const absl::flat_hash_set &object_ids, int num_objects, int64_t timeout_ms, - const TaskID &task_id, absl::flat_hash_set *ready) { + const WorkerContext &ctx, absl::flat_hash_set *ready) { std::vector id_vector(object_ids.begin(), object_ids.end()); std::vector> result_objects; RAY_CHECK(object_ids.size() == id_vector.size()); - auto status = store_->Get(id_vector, num_objects, timeout_ms, false, &result_objects); + auto status = + store_->Get(id_vector, num_objects, timeout_ms, ctx, false, &result_objects); // Ignore TimedOut statuses since we return ready objects explicitly. if (!status.IsTimedOut()) { RAY_RETURN_NOT_OK(status); diff --git a/src/ray/core_worker/store_provider/memory_store_provider.h b/src/ray/core_worker/store_provider/memory_store_provider.h index 76ac204dd..ce07f6633 100644 --- a/src/ray/core_worker/store_provider/memory_store_provider.h +++ b/src/ray/core_worker/store_provider/memory_store_provider.h @@ -7,6 +7,7 @@ #include "ray/common/id.h" #include "ray/common/status.h" #include "ray/core_worker/common.h" +#include "ray/core_worker/context.h" #include "ray/core_worker/store_provider/memory_store/memory_store.h" namespace ray { @@ -27,7 +28,7 @@ class CoreWorkerMemoryStoreProvider { Status Put(const RayObject &object, const ObjectID &object_id); Status Get(const absl::flat_hash_set &object_ids, int64_t timeout_ms, - const TaskID &task_id, + const WorkerContext &ctx, absl::flat_hash_map> *results, bool *got_exception); @@ -35,7 +36,7 @@ class CoreWorkerMemoryStoreProvider { /// Note that `num_objects` must equal to number of items in `object_ids`. Status Wait(const absl::flat_hash_set &object_ids, int num_objects, - int64_t timeout_ms, const TaskID &task_id, + int64_t timeout_ms, const WorkerContext &ctx, absl::flat_hash_set *ready); /// Note that `local_only` must be true, and `delete_creating_tasks` must be false here. diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index e3a8a7af6..fe16a624f 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -646,15 +646,15 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) { absl::flat_hash_set wait_results; ObjectID nonexistent_id = ObjectID::FromRandom().WithDirectTransportType(); + WorkerContext ctx(WorkerType::WORKER, JobID::Nil()); wait_ids.insert(nonexistent_id); - RAY_CHECK_OK( - provider.Wait(wait_ids, ids.size() + 1, 100, RandomTaskId(), &wait_results)); + RAY_CHECK_OK(provider.Wait(wait_ids, ids.size() + 1, 100, ctx, &wait_results)); ASSERT_EQ(wait_results.size(), ids.size()); ASSERT_TRUE(wait_results.count(nonexistent_id) == 0); // Test Wait() where the required `num_objects` is less than size of `wait_ids`. wait_results.clear(); - RAY_CHECK_OK(provider.Wait(wait_ids, ids.size(), -1, RandomTaskId(), &wait_results)); + RAY_CHECK_OK(provider.Wait(wait_ids, ids.size(), -1, ctx, &wait_results)); ASSERT_EQ(wait_results.size(), ids.size()); ASSERT_TRUE(wait_results.count(nonexistent_id) == 0); @@ -662,7 +662,7 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) { bool got_exception = false; absl::flat_hash_map> results; absl::flat_hash_set ids_set(ids.begin(), ids.end()); - RAY_CHECK_OK(provider.Get(ids_set, -1, RandomTaskId(), &results, &got_exception)); + RAY_CHECK_OK(provider.Get(ids_set, -1, ctx, &results, &got_exception)); ASSERT_TRUE(!got_exception); ASSERT_EQ(results.size(), ids.size()); @@ -685,8 +685,7 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) { RAY_CHECK_OK(provider.Delete(ids_set)); usleep(200 * 1000); - ASSERT_TRUE( - provider.Get(ids_set, 0, RandomTaskId(), &results, &got_exception).IsTimedOut()); + ASSERT_TRUE(provider.Get(ids_set, 0, ctx, &results, &got_exception).IsTimedOut()); ASSERT_TRUE(!got_exception); ASSERT_EQ(results.size(), 0); @@ -715,8 +714,7 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) { wait_results.clear(); // Check that only the ready ids are returned when timeout ends before thread runs. - RAY_CHECK_OK( - provider.Wait(wait_ids, ready_ids.size() + 1, 100, RandomTaskId(), &wait_results)); + RAY_CHECK_OK(provider.Wait(wait_ids, ready_ids.size() + 1, 100, ctx, &wait_results)); ASSERT_EQ(ready_ids.size(), wait_results.size()); for (const auto &ready_id : ready_ids) { ASSERT_TRUE(wait_results.find(ready_id) != wait_results.end()); @@ -727,8 +725,7 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) { wait_results.clear(); // Check that enough objects are returned after the thread inserts at least one object. - RAY_CHECK_OK( - provider.Wait(wait_ids, ready_ids.size() + 1, 5000, RandomTaskId(), &wait_results)); + RAY_CHECK_OK(provider.Wait(wait_ids, ready_ids.size() + 1, 5000, ctx, &wait_results)); ASSERT_TRUE(wait_results.size() >= ready_ids.size() + 1); for (const auto &ready_id : ready_ids) { ASSERT_TRUE(wait_results.find(ready_id) != wait_results.end()); @@ -737,8 +734,7 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) { wait_results.clear(); // Check that all objects are returned after the thread completes. async_thread.join(); - RAY_CHECK_OK( - provider.Wait(wait_ids, wait_ids.size(), -1, RandomTaskId(), &wait_results)); + RAY_CHECK_OK(provider.Wait(wait_ids, wait_ids.size(), -1, ctx, &wait_results)); ASSERT_EQ(wait_results.size(), ready_ids.size() + unready_ids.size()); for (const auto &ready_id : ready_ids) { ASSERT_TRUE(wait_results.find(ready_id) != wait_results.end()); diff --git a/src/ray/core_worker/transport/raylet_transport.h b/src/ray/core_worker/transport/raylet_transport.h index b4de32b24..c2137f754 100644 --- a/src/ray/core_worker/transport/raylet_transport.h +++ b/src/ray/core_worker/transport/raylet_transport.h @@ -31,7 +31,8 @@ class CoreWorkerRayletTaskReceiver { rpc::SendReplyCallback send_reply_callback); private: - /// Raylet client. + /// Reference to the core worker's raylet client. This is a pointer ref so that it + /// can be initialized by core worker after this class is constructed. std::shared_ptr &raylet_client_; /// The callback function to process a task. TaskHandler task_handler_; diff --git a/src/ray/raylet/format/node_manager.fbs b/src/ray/raylet/format/node_manager.fbs index 91694b7e6..23b7406e5 100644 --- a/src/ray/raylet/format/node_manager.fbs +++ b/src/ray/raylet/format/node_manager.fbs @@ -31,6 +31,11 @@ enum MessageType:int { // For a worker that was blocked on some object(s), tell the raylet // that the worker is now unblocked. This is sent from a worker to a raylet. NotifyUnblocked, + // Notify the current worker is blocked. This is only used by direct task calls; + // otherwise the block command is piggybacked on other messages. + NotifyDirectCallTaskBlocked, + // Notify the current worker is unblocked. This is only used by direct task calls. + NotifyDirectCallTaskUnblocked, // A request to get the task frontier for an actor, called by the actor when // saving a checkpoint. GetActorFrontierRequest, @@ -161,6 +166,12 @@ table NotifyUnblocked { task_id: string; } +table NotifyDirectCallTaskBlocked { +} + +table NotifyDirectCallTaskUnblocked { +} + table WaitRequest { // List of object ids we'll be waiting on. object_ids: [string]; diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 6c73e1bd8..6bedea3ea 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -833,6 +833,7 @@ void NodeManager::DispatchTasks( return local_queues_.NumRunning(a->first) < local_queues_.NumRunning(b->first); }); } + std::vector> post_assign_callbacks; // Approximate fair round robin between classes. for (const auto &it : fair_order) { const auto &task_resources = @@ -845,7 +846,7 @@ void NodeManager::DispatchTasks( // once the first task is not feasible, we can break out of this loop break; } - if (AssignTask(task)) { + if (AssignTask(task, &post_assign_callbacks)) { removed_task_ids.insert(task_id); } } @@ -854,6 +855,9 @@ void NodeManager::DispatchTasks( // it queued locally. Once the GetTaskReply has been sent, the task will get // re-queued, depending on whether the message succeeded or not. local_queues_.MoveTasks(removed_task_ids, TaskState::READY, TaskState::SWAP); + for (auto func : post_assign_callbacks) { + func(); + } } void NodeManager::ProcessClientMessage( @@ -902,6 +906,14 @@ void NodeManager::ProcessClientMessage( case protocol::MessageType::FetchOrReconstruct: { ProcessFetchOrReconstructMessage(client, message_data); } break; + case protocol::MessageType::NotifyDirectCallTaskBlocked: { + std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); + HandleDirectCallTaskBlocked(worker); + } break; + case protocol::MessageType::NotifyDirectCallTaskUnblocked: { + std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); + HandleDirectCallTaskUnblocked(worker); + } break; case protocol::MessageType::NotifyUnblocked: { auto message = flatbuffers::GetRoot(message_data); HandleTaskUnblocked(client, from_flatbuf(*message->task_id())); @@ -1103,6 +1115,8 @@ void NodeManager::ProcessDisconnectClientMessage( // Clean up any open ray.wait calls that the worker made. task_dependency_manager_.UnsubscribeWaitDependencies(worker->WorkerId()); } + // Erase any lease metadata. + leased_workers_.erase(worker->Port()); } if (is_worker) { @@ -1435,6 +1449,7 @@ void NodeManager::HandleWorkerLeaseRequest(const rpc::WorkerLeaseRequest &reques // TODO(swang): Kill worker if other end hangs up. // TODO(swang): Implement a lease term by which the owner needs to return the // worker. + RAY_CHECK(leased_workers_.find(port) == leased_workers_.end()); leased_workers_[port] = std::static_pointer_cast(granted); }); task.OnSpillbackInstead( @@ -1454,14 +1469,18 @@ void NodeManager::HandleReturnWorker(const rpc::ReturnWorkerRequest &request, rpc::SendReplyCallback send_reply_callback) { // Read the resource spec submitted by the client. auto worker_port = request.worker_port(); - RAY_LOG(DEBUG) << "Return worker " << worker_port; std::shared_ptr worker = std::move(leased_workers_[worker_port]); leased_workers_.erase(worker_port); Status status; if (worker) { + // Handle the edge case where the worker was returned before we got the + // unblock RPC by unblocking it immediately (unblock is idempotent). + if (worker->IsBlocked()) { + HandleDirectCallTaskUnblocked(worker); + } HandleWorkerAvailable(worker); } else { - status = Status::Invalid("Returned worker does not exist"); + status = Status::Invalid("Returned worker does not exist any more"); } send_reply_callback(status, nullptr, nullptr); } @@ -1844,6 +1863,48 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag } } +void NodeManager::HandleDirectCallTaskBlocked(const std::shared_ptr &worker) { + if (!worker || worker->GetAssignedTaskId().IsNil() || worker->IsBlocked()) { + return; // The worker may have died or is no longer processing the task. + } + auto const cpu_resource_ids = worker->ReleaseTaskCpuResources(); + local_available_resources_.Release(cpu_resource_ids); + cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()].Release( + cpu_resource_ids.ToResourceSet()); + worker->MarkBlocked(); + DispatchTasks(local_queues_.GetReadyTasksByClass()); +} + +void NodeManager::HandleDirectCallTaskUnblocked(const std::shared_ptr &worker) { + if (!worker || worker->GetAssignedTaskId().IsNil() || !worker->IsBlocked()) { + return; // The worker may have died or is no longer processing the task. + } + TaskID task_id = worker->GetAssignedTaskId(); + Task task = local_queues_.GetTaskOfState(task_id, TaskState::RUNNING); + const auto required_resources = task.GetTaskSpecification().GetRequiredResources(); + const ResourceSet cpu_resources = required_resources.GetNumCpus(); + bool oversubscribed = !local_available_resources_.Contains(cpu_resources); + if (!oversubscribed) { + // Reacquire the CPU resources for the worker. Note that care needs to be + // taken if the user is using the specific CPU IDs since the IDs that we + // reacquire here may be different from the ones that the task started with. + auto const resource_ids = local_available_resources_.Acquire(cpu_resources); + worker->AcquireTaskCpuResources(resource_ids); + cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()].Acquire( + cpu_resources); + } else { + // In this case, we simply don't reacquire the CPU resources for the worker. + // The worker can keep running and when the task finishes, it will simply + // not have any CPU resources to release. + RAY_LOG(WARNING) + << "Resources oversubscribed: " + << cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()] + .GetAvailableResources() + .ToString(); + } + worker->MarkUnblocked(); +} + void NodeManager::HandleTaskBlocked(const std::shared_ptr &client, const std::vector &required_object_ids, const TaskID ¤t_task_id, bool ray_get) { @@ -1884,12 +1945,14 @@ void NodeManager::HandleTaskBlocked(const std::shared_ptr // Subscribe to the objects required by the task. These objects will be // fetched and/or reconstructed as necessary, until the objects become local // or are unsubscribed. - if (ray_get) { - task_dependency_manager_.SubscribeGetDependencies(current_task_id, - required_object_ids); - } else { - task_dependency_manager_.SubscribeWaitDependencies(worker->WorkerId(), - required_object_ids); + if (!required_object_ids.empty()) { + if (ray_get) { + task_dependency_manager_.SubscribeGetDependencies(current_task_id, + required_object_ids); + } else { + task_dependency_manager_.SubscribeWaitDependencies(worker->WorkerId(), + required_object_ids); + } } } @@ -1974,7 +2037,8 @@ void NodeManager::EnqueuePlaceableTask(const Task &task) { task_dependency_manager_.TaskPending(task); } -bool NodeManager::AssignTask(const Task &task) { +bool NodeManager::AssignTask(const Task &task, + std::vector> *post_assign_callbacks) { const TaskSpecification &spec = task.GetTaskSpecification(); // If this is an actor task, check that the new task has the correct counter. @@ -2036,7 +2100,12 @@ bool NodeManager::AssignTask(const Task &task) { if (task.OnDispatch() != nullptr) { task.OnDispatch()(worker, initial_config_.node_manager_address, worker->Port()); - finish_assign_task_callback(Status::OK()); + if (post_assign_callbacks != nullptr) { + // Moves the tasks from SWAP to RUNNING state atomically. This avoids race + // conditions with ReturnLease requests. + post_assign_callbacks->push_back( + [this, worker, task_id]() { FinishAssignTask(task_id, *worker, true); }); + } } else { worker->AssignTask(task, resource_id_set, finish_assign_task_callback); } diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index efea4246d..f49ea2f20 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -211,8 +211,10 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// Assign a task. The task is assumed to not be queued in local_queues_. /// /// \param task The task in question. + /// \param post_assign_callbacks Set of functions to run after assignments finish. /// \return true, if tasks was assigned to a worker, false otherwise. - bool AssignTask(const Task &task); + bool AssignTask(const Task &task, + std::vector> *post_assign_callbacks); /// Handle a worker finishing its assigned task. /// /// \param worker The worker that finished the task. @@ -328,6 +330,19 @@ class NodeManager : public rpc::NodeManagerServiceHandler { void HandleTaskUnblocked(const std::shared_ptr &client, const TaskID ¤t_task_id); + /// Handle a direct call task that is blocked. Note that this callback may + /// arrive after the worker lease has been returned to the node manager. + /// + /// \param worker Shared ptr to the worker, or nullptr if lost. + void HandleDirectCallTaskBlocked(const std::shared_ptr &worker); + + /// Handle a direct call task that is unblocked. Note that this callback may + /// arrive after the worker lease has been returned to the node manager. + /// However, it is guaranteed to arrive after DirectCallTaskBlocked. + /// + /// \param worker Shared ptr to the worker, or nullptr if lost. + void HandleDirectCallTaskUnblocked(const std::shared_ptr &worker); + /// Kill a worker. /// /// \param worker The worker to kill. diff --git a/src/ray/raylet/raylet_client.cc b/src/ray/raylet/raylet_client.cc index 9803a44ae..f9e839bf4 100644 --- a/src/ray/raylet/raylet_client.cc +++ b/src/ray/raylet/raylet_client.cc @@ -258,6 +258,20 @@ ray::Status RayletClient::NotifyUnblocked(const TaskID ¤t_task_id) { return conn_->WriteMessage(MessageType::NotifyUnblocked, &fbb); } +ray::Status RayletClient::NotifyDirectCallTaskBlocked() { + flatbuffers::FlatBufferBuilder fbb; + auto message = ray::protocol::CreateNotifyDirectCallTaskBlocked(fbb); + fbb.Finish(message); + return conn_->WriteMessage(MessageType::NotifyDirectCallTaskBlocked, &fbb); +} + +ray::Status RayletClient::NotifyDirectCallTaskUnblocked() { + flatbuffers::FlatBufferBuilder fbb; + auto message = ray::protocol::CreateNotifyDirectCallTaskUnblocked(fbb); + fbb.Finish(message); + return conn_->WriteMessage(MessageType::NotifyDirectCallTaskUnblocked, &fbb); +} + ray::Status RayletClient::Wait(const std::vector &object_ids, int num_returns, int64_t timeout_milliseconds, bool wait_local, const TaskID ¤t_task_id, WaitResultPair *result) { @@ -392,6 +406,8 @@ ray::Status RayletClient::ReturnWorker(int worker_port) { request.set_worker_port(worker_port); return grpc_client_->ReturnWorker( request, [](const ray::Status &status, const ray::rpc::ReturnWorkerReply &reply) { - RAY_CHECK_OK(status); + if (!status.ok()) { + RAY_LOG(ERROR) << "Error returning worker: " << status; + } }); } diff --git a/src/ray/raylet/raylet_client.h b/src/ray/raylet/raylet_client.h index f796307ce..a28555fc0 100644 --- a/src/ray/raylet/raylet_client.h +++ b/src/ray/raylet/raylet_client.h @@ -126,12 +126,25 @@ class RayletClient : public WorkerLeaseInterface { /// \return int 0 means correct, other numbers mean error. ray::Status FetchOrReconstruct(const std::vector &object_ids, bool fetch_only, const TaskID ¤t_task_id); + /// Notify the raylet that this client (worker) is no longer blocked. /// /// \param current_task_id The task that is no longer blocked. /// \return ray::Status. ray::Status NotifyUnblocked(const TaskID ¤t_task_id); + /// Notify the raylet that this client is blocked. This is only used for direct task + /// calls. Note that ordering of this with respect to Unblock calls is important. + /// + /// \return ray::Status. + ray::Status NotifyDirectCallTaskBlocked(); + + /// Notify the raylet that this client is unblocked. This is only used for direct task + /// calls. Note that ordering of this with respect to Block calls is important. + /// + /// \return ray::Status. + ray::Status NotifyDirectCallTaskUnblocked(); + /// Wait for the given objects until timeout expires or num_return objects are /// found. ///