From b7b655c85112abd20a8c92dac4b6062681ffbd19 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 27 Nov 2019 22:46:15 -0800 Subject: [PATCH] Also use NotifyDirectCallTaskBlock/Unblocked for plasma store accesses (#6249) * wip * fix it * lint * wip * fix * unblock * flaky * use fetch only flag * Revert "use fetch only flag" This reverts commit 56e938a0ee2024f5c99c9ab2d55fd35558fb15e1. * restore error resolution * use worker task id * proto comments * fix if --- python/ray/_raylet.pyx | 2 +- python/ray/includes/libraylet.pxd | 4 +- python/ray/tests/BUILD | 1 + python/ray/tune/BUILD | 1 + src/ray/core_worker/context.cc | 4 ++ src/ray/core_worker/context.h | 4 ++ src/ray/core_worker/core_worker.cc | 19 +++--- .../memory_store/memory_store.cc | 4 +- .../store_provider/plasma_store_provider.cc | 62 ++++++++++++----- .../store_provider/plasma_store_provider.h | 9 ++- src/ray/raylet/format/node_manager.fbs | 11 ++- src/ray/raylet/node_manager.cc | 68 ++++++++++++------- src/ray/raylet/node_manager.h | 24 ++++--- src/ray/raylet/raylet_client.cc | 10 +-- src/ray/raylet/raylet_client.h | 7 +- 15 files changed, 148 insertions(+), 82 deletions(-) diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 382e190b7..399ae824d 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -397,7 +397,7 @@ cdef class RayletClient: TaskID current_task_id=TaskID.nil()): cdef c_vector[CObjectID] fetch_ids = ObjectIDsToVector(object_ids) check_status(self.client.FetchOrReconstruct( - fetch_ids, fetch_only, current_task_id.native())) + fetch_ids, fetch_only, True, current_task_id.native())) def push_error(self, JobID job_id, error_type, error_message, double timestamp): diff --git a/python/ray/includes/libraylet.pxd b/python/ray/includes/libraylet.pxd index 93534eb3a..36643ce28 100644 --- a/python/ray/includes/libraylet.pxd +++ b/python/ray/includes/libraylet.pxd @@ -50,11 +50,13 @@ cdef extern from "ray/raylet/raylet_client.h" nogil: CRayStatus SubmitTask(const CTaskSpec &task_spec) CRayStatus FetchOrReconstruct(c_vector[CObjectID] &object_ids, c_bool fetch_only, + c_bool is_direct_call_task, const CTaskID ¤t_task_id) CRayStatus NotifyUnblocked(const CTaskID ¤t_task_id) CRayStatus Wait(const c_vector[CObjectID] &object_ids, int num_returns, int64_t timeout_milliseconds, - c_bool wait_local, const CTaskID ¤t_task_id, + c_bool wait_local, c_bool is_direct_call_task, + const CTaskID ¤t_task_id, WaitResultPair *result) CRayStatus PushError(const CJobID &job_id, const c_string &type, const c_string &error_message, double timestamp) diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 6a567b62c..bb99cdbbb 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -263,6 +263,7 @@ py_test( size = "small", srcs = ["test_queue.py"], deps = ["//:ray_lib"], + flaky = 1, ) py_test( diff --git a/python/ray/tune/BUILD b/python/ray/tune/BUILD index d28f0e4ea..09fe9f12b 100644 --- a/python/ray/tune/BUILD +++ b/python/ray/tune/BUILD @@ -55,6 +55,7 @@ py_test( size = "small", srcs = ["tests/test_experiment.py"], deps = [":tune_lib"], + flaky = 1, ) py_test( diff --git a/src/ray/core_worker/context.cc b/src/ray/core_worker/context.cc index 37293bddc..325656feb 100644 --- a/src/ray/core_worker/context.cc +++ b/src/ray/core_worker/context.cc @@ -124,6 +124,10 @@ bool WorkerContext::CurrentThreadIsMain() const { return boost::this_thread::get_id() == main_thread_id_; } +bool WorkerContext::ShouldReleaseResourcesOnBlockingCalls() const { + return !CurrentActorIsDirectCall() && CurrentThreadIsMain(); +} + bool WorkerContext::CurrentActorIsDirectCall() const { return current_actor_is_direct_call_; } diff --git a/src/ray/core_worker/context.h b/src/ray/core_worker/context.h index b9cb059b8..4b776ce49 100644 --- a/src/ray/core_worker/context.h +++ b/src/ray/core_worker/context.h @@ -39,6 +39,10 @@ class WorkerContext { /// Returns whether the current thread is the main worker thread. bool CurrentThreadIsMain() const; + /// Returns whether we should Block/Unblock through the raylet on Get/Wait. + /// This only applies to direct task calls. + bool ShouldReleaseResourcesOnBlockingCalls() const; + /// Returns whether we are in a direct call actor. bool CurrentActorIsDirectCall() const; diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 80ac7e608..ee6c4a850 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -364,9 +364,8 @@ Status CoreWorker::Get(const std::vector &ids, const int64_t timeout_m bool got_exception = false; absl::flat_hash_map> result_map; auto start_time = current_time_ms(); - RAY_RETURN_NOT_OK(plasma_store_provider_->Get(plasma_object_ids, timeout_ms, - worker_context_.GetCurrentTaskID(), - &result_map, &got_exception)); + RAY_RETURN_NOT_OK(plasma_store_provider_->Get( + plasma_object_ids, timeout_ms, worker_context_, &result_map, &got_exception)); if (!got_exception) { int64_t local_timeout_ms = timeout_ms; @@ -398,8 +397,8 @@ Status CoreWorker::Get(const std::vector &ids, const int64_t timeout_m } RAY_LOG(DEBUG) << "Plasma GET timeout " << local_timeout_ms; RAY_RETURN_NOT_OK(plasma_store_provider_->Get(promoted_plasma_ids, local_timeout_ms, - worker_context_.GetCurrentTaskID(), - &result_map, &got_exception)); + worker_context_, &result_map, + &got_exception)); for (const auto &id : promoted_plasma_ids) { auto it = result_map.find(id); if (it == result_map.end()) { @@ -489,9 +488,8 @@ Status CoreWorker::Wait(const std::vector &ids, int num_objects, // where we might use up the entire timeout on trying to get objects from one store // provider before even trying another (which might have all of the objects available). if (plasma_object_ids.size() > 0) { - RAY_RETURN_NOT_OK( - plasma_store_provider_->Wait(plasma_object_ids, num_objects, /*timeout_ms=*/0, - worker_context_.GetCurrentTaskID(), &ready)); + RAY_RETURN_NOT_OK(plasma_store_provider_->Wait( + plasma_object_ids, num_objects, /*timeout_ms=*/0, worker_context_, &ready)); } RAY_CHECK(static_cast(ready.size()) <= num_objects); if (static_cast(ready.size()) < num_objects && memory_object_ids.size() > 0) { @@ -510,9 +508,8 @@ Status CoreWorker::Wait(const std::vector &ids, int num_objects, int64_t start_time = current_time_ms(); if (plasma_object_ids.size() > 0) { - RAY_RETURN_NOT_OK( - plasma_store_provider_->Wait(plasma_object_ids, num_objects, timeout_ms, - worker_context_.GetCurrentTaskID(), &ready)); + RAY_RETURN_NOT_OK(plasma_store_provider_->Wait( + plasma_object_ids, num_objects, timeout_ms, worker_context_, &ready)); } RAY_CHECK(static_cast(ready.size()) <= num_objects); if (timeout_ms > 0) { diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.cc b/src/ray/core_worker/store_provider/memory_store/memory_store.cc index ab6dfda9e..dcb92064a 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.cc +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.cc @@ -267,10 +267,8 @@ Status CoreWorkerMemoryStore::Get(const std::vector &object_ids, } // Only send block/unblock IPCs for non-actor tasks on the main thread. - // TODO(ekl) support non-lifetime resources for direct actor calls. bool should_notify_raylet = - (raylet_client_ != nullptr && !ctx.CurrentActorIsDirectCall() && - ctx.CurrentThreadIsMain()); + (raylet_client_ != nullptr && ctx.ShouldReleaseResourcesOnBlockingCalls()); // Wait for remaining objects (or timeout). if (should_notify_raylet) { diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.cc b/src/ray/core_worker/store_provider/plasma_store_provider.cc index c2986ec67..ef3a1dcf1 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.cc +++ b/src/ray/core_worker/store_provider/plasma_store_provider.cc @@ -84,10 +84,11 @@ Status CoreWorkerPlasmaStoreProvider::Seal(const ObjectID &object_id) { Status CoreWorkerPlasmaStoreProvider::FetchAndGetFromPlasmaStore( absl::flat_hash_set &remaining, const std::vector &batch_ids, - int64_t timeout_ms, bool fetch_only, const TaskID &task_id, + int64_t timeout_ms, bool fetch_only, bool in_direct_call, const TaskID &task_id, absl::flat_hash_map> *results, bool *got_exception) { - RAY_RETURN_NOT_OK(raylet_client_->FetchOrReconstruct(batch_ids, fetch_only, task_id)); + RAY_RETURN_NOT_OK(raylet_client_->FetchOrReconstruct( + batch_ids, fetch_only, /*mark_worker_blocked*/ !in_direct_call, task_id)); std::vector plasma_batch_ids; plasma_batch_ids.reserve(batch_ids.size()); @@ -127,9 +128,22 @@ Status CoreWorkerPlasmaStoreProvider::FetchAndGetFromPlasmaStore( return Status::OK(); } +Status UnblockIfNeeded(const std::shared_ptr &client, + const WorkerContext &ctx) { + if (ctx.CurrentTaskIsDirectCall()) { + if (ctx.ShouldReleaseResourcesOnBlockingCalls()) { + return client->NotifyDirectCallTaskUnblocked(); + } else { + return Status::OK(); // We don't need to release resources. + } + } else { + return client->NotifyUnblocked(ctx.GetCurrentTaskID()); + } +} + Status CoreWorkerPlasmaStoreProvider::Get( const absl::flat_hash_set &object_ids, int64_t timeout_ms, - const TaskID &task_id, + const WorkerContext &ctx, absl::flat_hash_map> *results, bool *got_exception) { int64_t batch_size = RayConfig::instance().worker_fetch_request_size(); @@ -144,9 +158,10 @@ Status CoreWorkerPlasmaStoreProvider::Get( for (int64_t i = start; i < batch_size && i < total_size; i++) { batch_ids.push_back(id_vector[start + i]); } - RAY_RETURN_NOT_OK(FetchAndGetFromPlasmaStore(remaining, batch_ids, /*timeout_ms=*/0, - /*fetch_only=*/true, task_id, results, - got_exception)); + RAY_RETURN_NOT_OK( + FetchAndGetFromPlasmaStore(remaining, batch_ids, /*timeout_ms=*/0, + /*fetch_only=*/true, ctx.CurrentTaskIsDirectCall(), + ctx.GetCurrentTaskID(), results, got_exception)); } // If all objects were fetched already, return. @@ -179,12 +194,14 @@ Status CoreWorkerPlasmaStoreProvider::Get( } size_t previous_size = remaining.size(); - // TODO: For direct calls, use NotifyDirectCallTaskBlocked/Unblocked calls - // for missing objects instead of going through the normal fetch-and-get - // codepath. - RAY_RETURN_NOT_OK(FetchAndGetFromPlasmaStore(remaining, batch_ids, batch_timeout, - /*fetch_only=*/false, task_id, results, - got_exception)); + // This is a separate IPC from the FetchAndGet in direct call mode. + if (ctx.CurrentTaskIsDirectCall() && ctx.ShouldReleaseResourcesOnBlockingCalls()) { + RAY_RETURN_NOT_OK(raylet_client_->NotifyDirectCallTaskBlocked()); + } + RAY_RETURN_NOT_OK( + FetchAndGetFromPlasmaStore(remaining, batch_ids, batch_timeout, + /*fetch_only=*/false, ctx.CurrentTaskIsDirectCall(), + ctx.GetCurrentTaskID(), results, got_exception)); should_break = timed_out || *got_exception; if ((previous_size - remaining.size()) < batch_ids.size()) { @@ -195,20 +212,20 @@ Status CoreWorkerPlasmaStoreProvider::Get( Status status = check_signals_(); if (!status.ok()) { // TODO(edoakes): in this case which status should we return? - RAY_RETURN_NOT_OK(raylet_client_->NotifyUnblocked(task_id)); + RAY_RETURN_NOT_OK(UnblockIfNeeded(raylet_client_, ctx)); return status; } } } if (!remaining.empty() && timed_out) { - RAY_RETURN_NOT_OK(raylet_client_->NotifyUnblocked(task_id)); + RAY_RETURN_NOT_OK(UnblockIfNeeded(raylet_client_, ctx)); return Status::TimedOut("Get timed out: some object(s) not ready."); } // Notify unblocked because we blocked when calling FetchOrReconstruct with // fetch_only=false. - return raylet_client_->NotifyUnblocked(task_id); + return UnblockIfNeeded(raylet_client_, ctx); } Status CoreWorkerPlasmaStoreProvider::Contains(const ObjectID &object_id, @@ -220,7 +237,7 @@ Status CoreWorkerPlasmaStoreProvider::Contains(const ObjectID &object_id, Status CoreWorkerPlasmaStoreProvider::Wait( const absl::flat_hash_set &object_ids, int num_objects, int64_t timeout_ms, - const TaskID &task_id, absl::flat_hash_set *ready) { + const WorkerContext &ctx, absl::flat_hash_set *ready) { std::vector id_vector(object_ids.begin(), object_ids.end()); bool should_break = false; @@ -234,8 +251,14 @@ Status CoreWorkerPlasmaStoreProvider::Wait( should_break = remaining_timeout <= 0; } - RAY_RETURN_NOT_OK(raylet_client_->Wait(id_vector, num_objects, call_timeout, false, - task_id, &result_pair)); + // This is a separate IPC from the Wait in direct call mode. + if (ctx.CurrentTaskIsDirectCall() && ctx.ShouldReleaseResourcesOnBlockingCalls()) { + RAY_RETURN_NOT_OK(raylet_client_->NotifyDirectCallTaskBlocked()); + } + RAY_RETURN_NOT_OK( + raylet_client_->Wait(id_vector, num_objects, call_timeout, false, + /*mark_worker_blocked*/ !ctx.CurrentTaskIsDirectCall(), + ctx.GetCurrentTaskID(), &result_pair)); if (result_pair.first.size() >= static_cast(num_objects)) { should_break = true; @@ -247,6 +270,9 @@ Status CoreWorkerPlasmaStoreProvider::Wait( RAY_RETURN_NOT_OK(check_signals_()); } } + if (ctx.CurrentTaskIsDirectCall() && ctx.ShouldReleaseResourcesOnBlockingCalls()) { + RAY_RETURN_NOT_OK(raylet_client_->NotifyDirectCallTaskUnblocked()); + } return Status::OK(); } diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.h b/src/ray/core_worker/store_provider/plasma_store_provider.h index c908da8cc..c6cac4212 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.h +++ b/src/ray/core_worker/store_provider/plasma_store_provider.h @@ -8,6 +8,7 @@ #include "ray/common/id.h" #include "ray/common/status.h" #include "ray/core_worker/common.h" +#include "ray/core_worker/context.h" #include "ray/raylet/raylet_client.h" namespace ray { @@ -34,14 +35,14 @@ class CoreWorkerPlasmaStoreProvider { Status Seal(const ObjectID &object_id); Status Get(const absl::flat_hash_set &object_ids, int64_t timeout_ms, - const TaskID &task_id, + const WorkerContext &ctx, absl::flat_hash_map> *results, bool *got_exception); Status Contains(const ObjectID &object_id, bool *has_object); Status Wait(const absl::flat_hash_set &object_ids, int num_objects, - int64_t timeout_ms, const TaskID &task_id, + int64_t timeout_ms, const WorkerContext &ctx, absl::flat_hash_set *ready); Status Delete(const absl::flat_hash_set &object_ids, bool local_only, @@ -59,6 +60,7 @@ class CoreWorkerPlasmaStoreProvider { /// \param[in] timeout_ms Timeout in milliseconds. /// \param[in] fetch_only Whether the raylet should only fetch or also attempt to /// reconstruct objects. + /// \param[in] in_direct_call_task Whether the current task is direct call. /// \param[in] task_id The current TaskID. /// \param[out] results Map of objects to write results into. This method will only /// add to this map, not clear or remove from it, so the caller can pass in a non-empty @@ -68,7 +70,8 @@ class CoreWorkerPlasmaStoreProvider { /// \return Status. Status FetchAndGetFromPlasmaStore( absl::flat_hash_set &remaining, const std::vector &batch_ids, - int64_t timeout_ms, bool fetch_only, const TaskID &task_id, + int64_t timeout_ms, bool fetch_only, bool in_direct_call_task, + const TaskID &task_id, absl::flat_hash_map> *results, bool *got_exception); diff --git a/src/ray/raylet/format/node_manager.fbs b/src/ray/raylet/format/node_manager.fbs index cccbba91a..16959b730 100644 --- a/src/ray/raylet/format/node_manager.fbs +++ b/src/ray/raylet/format/node_manager.fbs @@ -164,7 +164,10 @@ table FetchOrReconstruct { object_ids: [string]; // Do we only want to fetch the objects or also reconstruct them? fetch_only: bool; - // The current task ID. If fetch_only is false, then this task is blocked. + // False for direct call tasks. Blocking for those tasks is handled via the + // NotifyDirectCallTaskBlocked/Unblocked IPCs. + mark_worker_blocked: bool; + // The current task ID. task_id: string; } @@ -188,8 +191,10 @@ table WaitRequest { timeout: long; // Whether to wait until objects appear locally. wait_local: bool; - // The current task ID. If there are less than num_ready_objects local, then - // this task is blocked. + // False for direct call tasks. Blocking for those tasks is handled via the + // NotifyDirectCallTaskBlocked/Unblocked IPCs. + mark_worker_blocked: bool; + // The current task ID. task_id: string; } diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index f833924c4..0e39e35b9 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -922,7 +922,8 @@ void NodeManager::ProcessClientMessage( } break; case protocol::MessageType::NotifyUnblocked: { auto message = flatbuffers::GetRoot(message_data); - HandleTaskUnblocked(client, from_flatbuf(*message->task_id())); + AsyncResolveObjectsFinish(client, from_flatbuf(*message->task_id()), + /*was_blocked*/ true); } break; case protocol::MessageType::WaitRequest: { ProcessWaitRequestMessage(client, message_data); @@ -1113,10 +1114,10 @@ void NodeManager::ProcessDisconnectClientMessage( } else { // Clean up any open ray.get calls that the worker made. while (!worker->GetBlockedTaskIds().empty()) { - // NOTE(swang): HandleTaskUnblocked will modify the worker, so it is + // NOTE(swang): AsyncResolveObjectsFinish will modify the worker, so it is // not safe to pass in the iterator directly. const TaskID task_id = *worker->GetBlockedTaskIds().begin(); - HandleTaskUnblocked(client, task_id); + AsyncResolveObjectsFinish(client, task_id, true); } // Clean up any open ray.wait calls that the worker made. task_dependency_manager_.UnsubscribeWaitDependencies(worker->WorkerId()); @@ -1235,7 +1236,8 @@ void NodeManager::ProcessFetchOrReconstructMessage( if (!required_object_ids.empty()) { const TaskID task_id = from_flatbuf(*message->task_id()); - HandleTaskBlocked(client, required_object_ids, task_id, /*ray_get=*/true); + AsyncResolveObjects(client, required_object_ids, task_id, /*ray_get=*/true, + /*mark_worker_blocked*/ message->mark_worker_blocked()); } } @@ -1259,15 +1261,17 @@ void NodeManager::ProcessWaitRequestMessage( } const TaskID ¤t_task_id = from_flatbuf(*message->task_id()); - bool client_blocked = !required_object_ids.empty(); - if (client_blocked) { - HandleTaskBlocked(client, required_object_ids, current_task_id, /*ray_get=*/false); + bool resolve_objects = !required_object_ids.empty(); + bool was_blocked = message->mark_worker_blocked(); + if (resolve_objects) { + AsyncResolveObjects(client, required_object_ids, current_task_id, /*ray_get=*/false, + /*mark_worker_blocked*/ was_blocked); } ray::Status status = object_manager_.Wait( object_ids, wait_ms, num_required_objects, wait_local, - [this, client_blocked, client, current_task_id](std::vector found, - std::vector remaining) { + [this, resolve_objects, was_blocked, client, current_task_id]( + std::vector found, std::vector remaining) { // Write the data. flatbuffers::FlatBufferBuilder fbb; flatbuffers::Offset wait_reply = protocol::CreateWaitReply( @@ -1279,8 +1283,8 @@ void NodeManager::ProcessWaitRequestMessage( fbb.GetSize(), fbb.GetBufferPointer()); if (status.ok()) { // The client is unblocked now because the wait call has returned. - if (client_blocked) { - HandleTaskUnblocked(client, current_task_id); + if (resolve_objects) { + AsyncResolveObjectsFinish(client, current_task_id, was_blocked); } } else { // We failed to write to the client, so disconnect the client. @@ -1908,18 +1912,21 @@ void NodeManager::HandleDirectCallTaskUnblocked(const std::shared_ptr &w .ToString(); } worker->MarkUnblocked(); + task_dependency_manager_.UnsubscribeGetDependencies(task_id); } -void NodeManager::HandleTaskBlocked(const std::shared_ptr &client, - const std::vector &required_object_ids, - const TaskID ¤t_task_id, bool ray_get) { +void NodeManager::AsyncResolveObjects( + const std::shared_ptr &client, + const std::vector &required_object_ids, const TaskID ¤t_task_id, + bool ray_get, bool mark_worker_blocked) { std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); if (worker) { // The client is a worker. If the worker is not already blocked and the // blocked task matches the one assigned to the worker, then mark the // worker as blocked. This temporarily releases any resources that the // worker holds while it is blocked. - if (!worker->IsBlocked() && current_task_id == worker->GetAssignedTaskId()) { + if (mark_worker_blocked && !worker->IsBlocked() && + current_task_id == worker->GetAssignedTaskId()) { Task task; RAY_CHECK(local_queues_.RemoveTask(current_task_id, &task)); local_queues_.QueueTasks({task}, TaskState::RUNNING); @@ -1942,25 +1949,31 @@ void NodeManager::HandleTaskBlocked(const std::shared_ptr RAY_CHECK(worker); // Mark the task as blocked. - worker->AddBlockedTaskId(current_task_id); - if (local_queues_.GetBlockedTaskIds().count(current_task_id) == 0) { - local_queues_.AddBlockedTaskId(current_task_id); + if (mark_worker_blocked) { + worker->AddBlockedTaskId(current_task_id); + if (local_queues_.GetBlockedTaskIds().count(current_task_id) == 0) { + local_queues_.AddBlockedTaskId(current_task_id); + } } // Subscribe to the objects required by the task. These objects will be // fetched and/or reconstructed as necessary, until the objects become local // or are unsubscribed. if (ray_get) { - task_dependency_manager_.SubscribeGetDependencies(current_task_id, - required_object_ids); + // TODO(ekl) using the assigned task id is a hack to handle unsubscription for + // HandleDirectCallUnblocked. + task_dependency_manager_.SubscribeGetDependencies( + mark_worker_blocked ? current_task_id : worker->GetAssignedTaskId(), + required_object_ids); } else { task_dependency_manager_.SubscribeWaitDependencies(worker->WorkerId(), required_object_ids); } } -void NodeManager::HandleTaskUnblocked( - const std::shared_ptr &client, const TaskID ¤t_task_id) { +void NodeManager::AsyncResolveObjectsFinish( + const std::shared_ptr &client, const TaskID ¤t_task_id, + bool was_blocked) { std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); // TODO(swang): Because the object dependencies are tracked in the task @@ -1972,8 +1985,8 @@ void NodeManager::HandleTaskUnblocked( // worker as unblocked. This returns the temporarily released resources to // the worker. Workers that have been marked dead have already been cleaned // up. - if (worker->IsBlocked() && current_task_id == worker->GetAssignedTaskId() && - !worker->IsDead()) { + if (was_blocked && worker->IsBlocked() && + current_task_id == worker->GetAssignedTaskId() && !worker->IsDead()) { // (See design_docs/task_states.rst for the state transition diagram.) Task task; RAY_CHECK(local_queues_.RemoveTask(current_task_id, &task)); @@ -2017,8 +2030,10 @@ void NodeManager::HandleTaskUnblocked( task_dependency_manager_.UnsubscribeGetDependencies(current_task_id); // Mark the task as unblocked. RAY_CHECK(worker); - worker->RemoveBlockedTaskId(current_task_id); - local_queues_.RemoveBlockedTaskId(current_task_id); + if (was_blocked) { + worker->RemoveBlockedTaskId(current_task_id); + local_queues_.RemoveBlockedTaskId(current_task_id); + } } void NodeManager::EnqueuePlaceableTask(const Task &task) { @@ -2475,6 +2490,7 @@ void NodeManager::HandleObjectLocal(const ObjectID &object_id) { // First filter out the tasks that should not be moved to READY. local_queues_.FilterState(ready_task_id_set, TaskState::BLOCKED); + local_queues_.FilterState(ready_task_id_set, TaskState::RUNNING); local_queues_.FilterState(ready_task_id_set, TaskState::DRIVER); local_queues_.FilterState(ready_task_id_set, TaskState::WAITING_FOR_ACTOR_CREATION); diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index ad4416c1f..d57169dc3 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -303,7 +303,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { void DispatchTasks( const std::unordered_map> &tasks_by_class); - /// Handle a task that is blocked. This could be a task assigned to a worker, + /// Handle blocking gets of objects. This could be a task assigned to a worker, /// an out-of-band task (e.g., a thread created by the application), or a /// driver task. This can be triggered when a client starts a get call or a /// wait call. @@ -311,24 +311,28 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \param client The client that is executing the blocked task. /// \param required_object_ids The IDs that the client is blocked waiting for. /// \param current_task_id The task that is blocked. - /// \param ray_get Whether the task is blocked in a `ray.get` call, as - /// opposed to a `ray.wait` call. + /// \param ray_get Whether the task is blocked in a `ray.get` call. + /// \param mark_worker_blocked Whether to mark the worker as blocked. This + /// should be False for direct calls. /// \return Void. - void HandleTaskBlocked(const std::shared_ptr &client, - const std::vector &required_object_ids, - const TaskID ¤t_task_id, bool ray_get); + void AsyncResolveObjects(const std::shared_ptr &client, + const std::vector &required_object_ids, + const TaskID ¤t_task_id, bool ray_get, + bool mark_worker_blocked); - /// Handle a task that is unblocked. This could be a task assigned to a + /// Handle end of a blocking object get. This could be a task assigned to a /// worker, an out-of-band task (e.g., a thread created by the application), /// or a driver task. This can be triggered when a client finishes a get call /// or a wait call. The given task must be blocked, via a previous call to - /// HandleTaskBlocked. + /// AsyncResolveObjects. /// /// \param client The client that is executing the unblocked task. /// \param current_task_id The task that is unblocked. + /// \param worker_was_blocked Whether we previously marked the worker as + /// blocked in AsyncResolveObjects(). /// \return Void. - void HandleTaskUnblocked(const std::shared_ptr &client, - const TaskID ¤t_task_id); + void AsyncResolveObjectsFinish(const std::shared_ptr &client, + const TaskID ¤t_task_id, bool was_blocked); /// Handle a direct call task that is blocked. Note that this callback may /// arrive after the worker lease has been returned to the node manager. diff --git a/src/ray/raylet/raylet_client.cc b/src/ray/raylet/raylet_client.cc index 1c12d2ad8..72fc4e6bb 100644 --- a/src/ray/raylet/raylet_client.cc +++ b/src/ray/raylet/raylet_client.cc @@ -249,12 +249,13 @@ ray::Status RayletClient::TaskDone() { } ray::Status RayletClient::FetchOrReconstruct(const std::vector &object_ids, - bool fetch_only, + bool fetch_only, bool mark_worker_blocked, const TaskID ¤t_task_id) { flatbuffers::FlatBufferBuilder fbb; auto object_ids_message = to_flatbuf(fbb, object_ids); auto message = ray::protocol::CreateFetchOrReconstruct( - fbb, object_ids_message, fetch_only, to_flatbuf(fbb, current_task_id)); + fbb, object_ids_message, fetch_only, mark_worker_blocked, + to_flatbuf(fbb, current_task_id)); fbb.Finish(message); auto status = conn_->WriteMessage(MessageType::FetchOrReconstruct, &fbb); return status; @@ -284,12 +285,13 @@ ray::Status RayletClient::NotifyDirectCallTaskUnblocked() { ray::Status RayletClient::Wait(const std::vector &object_ids, int num_returns, int64_t timeout_milliseconds, bool wait_local, - const TaskID ¤t_task_id, WaitResultPair *result) { + bool mark_worker_blocked, const TaskID ¤t_task_id, + WaitResultPair *result) { // Write request. flatbuffers::FlatBufferBuilder fbb; auto message = ray::protocol::CreateWaitRequest( fbb, to_flatbuf(fbb, object_ids), num_returns, timeout_milliseconds, wait_local, - to_flatbuf(fbb, current_task_id)); + mark_worker_blocked, to_flatbuf(fbb, current_task_id)); fbb.Finish(message); std::unique_ptr reply; auto status = conn_->AtomicRequestReply(MessageType::WaitRequest, diff --git a/src/ray/raylet/raylet_client.h b/src/ray/raylet/raylet_client.h index caa85b13c..4671bf24a 100644 --- a/src/ray/raylet/raylet_client.h +++ b/src/ray/raylet/raylet_client.h @@ -123,10 +123,11 @@ class RayletClient : public WorkerLeaseInterface { /// /// \param object_ids The IDs of the objects to reconstruct. /// \param fetch_only Only fetch objects, do not reconstruct them. + /// \param mark_worker_blocked Set to false if current task is a direct call task. /// \param current_task_id The task that needs the objects. /// \return int 0 means correct, other numbers mean error. ray::Status FetchOrReconstruct(const std::vector &object_ids, bool fetch_only, - const TaskID ¤t_task_id); + bool mark_worker_blocked, const TaskID ¤t_task_id); /// Notify the raylet that this client (worker) is no longer blocked. /// @@ -153,13 +154,15 @@ class RayletClient : public WorkerLeaseInterface { /// \param num_returns The number of objects to wait for. /// \param timeout_milliseconds Duration, in milliseconds, to wait before returning. /// \param wait_local Whether to wait for objects to appear on this node. + /// \param mark_worker_blocked Set to false if current task is a direct call task. /// \param current_task_id The task that called wait. /// \param result A pair with the first element containing the object ids that were /// found, and the second element the objects that were not found. /// \return ray::Status. ray::Status Wait(const std::vector &object_ids, int num_returns, int64_t timeout_milliseconds, bool wait_local, - const TaskID ¤t_task_id, WaitResultPair *result); + bool mark_worker_blocked, const TaskID ¤t_task_id, + WaitResultPair *result); /// Wait for the given objects, asynchronously. The core worker is notified when /// the wait completes.