mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 11:27:06 +08:00
Support NotifyBlocked/UnBlocked for direct call tasks (#6177)
This commit is contained in:
@@ -55,7 +55,8 @@ WorkerContext::WorkerContext(WorkerType worker_type, const JobID &job_id)
|
||||
worker_id_(worker_type_ == WorkerType::DRIVER ? ComputeDriverIdFromJob(job_id)
|
||||
: WorkerID::FromRandom()),
|
||||
current_job_id_(worker_type_ == WorkerType::DRIVER ? job_id : JobID::Nil()),
|
||||
current_actor_id_(ActorID::Nil()) {
|
||||
current_actor_id_(ActorID::Nil()),
|
||||
main_thread_id_(boost::this_thread::get_id()) {
|
||||
// For worker main thread which initializes the WorkerContext,
|
||||
// set task_id according to whether current worker is a driver.
|
||||
// (For other threads it's set to random ID via GetThreadContext).
|
||||
@@ -118,6 +119,10 @@ std::shared_ptr<const TaskSpecification> WorkerContext::GetCurrentTask() const {
|
||||
|
||||
const ActorID &WorkerContext::GetCurrentActorID() const { return current_actor_id_; }
|
||||
|
||||
bool WorkerContext::CurrentThreadIsMain() const {
|
||||
return boost::this_thread::get_id() == main_thread_id_;
|
||||
}
|
||||
|
||||
bool WorkerContext::CurrentActorIsDirectCall() const {
|
||||
return current_actor_is_direct_call_;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
#ifndef RAY_CORE_WORKER_CONTEXT_H
|
||||
#define RAY_CORE_WORKER_CONTEXT_H
|
||||
|
||||
#include <boost/thread.hpp>
|
||||
|
||||
#include "ray/common/task/task_spec.h"
|
||||
#include "ray/core_worker/common.h"
|
||||
|
||||
@@ -34,6 +36,9 @@ class WorkerContext {
|
||||
|
||||
const ActorID &GetCurrentActorID() const;
|
||||
|
||||
/// Returns whether the current thread is the main worker thread.
|
||||
bool CurrentThreadIsMain() const;
|
||||
|
||||
/// Returns whether we are in a direct call actor.
|
||||
bool CurrentActorIsDirectCall() const;
|
||||
|
||||
@@ -56,6 +61,9 @@ class WorkerContext {
|
||||
bool current_task_is_direct_call_ = false;
|
||||
int current_actor_max_concurrency_ = 1;
|
||||
|
||||
/// The id of the (main) thread that constructed this worker context.
|
||||
boost::thread::id main_thread_id_;
|
||||
|
||||
private:
|
||||
static WorkerThreadContext &GetThreadContext(bool for_main_thread = false);
|
||||
|
||||
|
||||
@@ -165,7 +165,7 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
|
||||
[this](const RayObject &obj, const ObjectID &obj_id) {
|
||||
RAY_CHECK_OK(plasma_store_provider_->Put(obj, obj_id));
|
||||
},
|
||||
ref_counting_enabled ? reference_counter_ : nullptr));
|
||||
ref_counting_enabled ? reference_counter_ : nullptr, raylet_client_));
|
||||
memory_store_provider_.reset(new CoreWorkerMemoryStoreProvider(memory_store_));
|
||||
|
||||
// Create an entry for the driver task in the task table. This task is
|
||||
@@ -263,8 +263,7 @@ void CoreWorker::ReportActiveObjectIDs() {
|
||||
reference_counter_->GetAllInScopeObjectIDs();
|
||||
RAY_LOG(DEBUG) << "Sending " << active_object_ids.size() << " object IDs to raylet.";
|
||||
if (active_object_ids.size() > RayConfig::instance().raylet_max_active_object_ids()) {
|
||||
RAY_LOG(WARNING) << active_object_ids.size() << "object IDs are currently in scope. "
|
||||
<< "This may lead to required objects being garbage collected.";
|
||||
RAY_LOG(WARNING) << active_object_ids.size() << " object IDs are currently in scope.";
|
||||
}
|
||||
|
||||
if (!raylet_client_->ReportActiveObjectIDs(active_object_ids).ok()) {
|
||||
@@ -347,8 +346,8 @@ Status CoreWorker::Get(const std::vector<ObjectID> &ids, const int64_t timeout_m
|
||||
timeout_ms - (current_time_ms() - start_time));
|
||||
}
|
||||
RAY_RETURN_NOT_OK(memory_store_provider_->Get(memory_object_ids, local_timeout_ms,
|
||||
worker_context_.GetCurrentTaskID(),
|
||||
&result_map, &got_exception));
|
||||
worker_context_, &result_map,
|
||||
&got_exception));
|
||||
}
|
||||
|
||||
// If any of the objects have been promoted to plasma, then we retry their
|
||||
@@ -454,7 +453,7 @@ Status CoreWorker::Wait(const std::vector<ObjectID> &ids, int num_objects,
|
||||
// consider waiting on them in plasma as well to ensure they are local.
|
||||
RAY_RETURN_NOT_OK(memory_store_provider_->Wait(
|
||||
memory_object_ids, num_objects - static_cast<int>(ready.size()),
|
||||
/*timeout_ms=*/0, worker_context_.GetCurrentTaskID(), &ready));
|
||||
/*timeout_ms=*/0, worker_context_, &ready));
|
||||
}
|
||||
RAY_CHECK(static_cast<int>(ready.size()) <= num_objects);
|
||||
|
||||
@@ -477,7 +476,7 @@ Status CoreWorker::Wait(const std::vector<ObjectID> &ids, int num_objects,
|
||||
if (static_cast<int>(ready.size()) < num_objects && memory_object_ids.size() > 0) {
|
||||
RAY_RETURN_NOT_OK(memory_store_provider_->Wait(
|
||||
memory_object_ids, num_objects - static_cast<int>(ready.size()), timeout_ms,
|
||||
worker_context_.GetCurrentTaskID(), &ready));
|
||||
worker_context_, &ready));
|
||||
}
|
||||
RAY_CHECK(static_cast<int>(ready.size()) <= num_objects);
|
||||
}
|
||||
|
||||
@@ -161,7 +161,8 @@ TEST(MemoryStoreIntegrationTest, TestSimple) {
|
||||
RAY_CHECK_OK(store.Put(id1, buffer));
|
||||
ASSERT_EQ(store.Size(), 1);
|
||||
std::vector<std::shared_ptr<RayObject>> results;
|
||||
RAY_CHECK_OK(store.Get({id1}, /*num_objects*/ 1, /*timeout_ms*/ -1,
|
||||
WorkerContext ctx(WorkerType::WORKER, JobID::Nil());
|
||||
RAY_CHECK_OK(store.Get({id1}, /*num_objects*/ 1, /*timeout_ms*/ -1, ctx,
|
||||
/*remove_after_get*/ true, &results));
|
||||
ASSERT_EQ(results.size(), 1);
|
||||
ASSERT_EQ(store.Size(), 1);
|
||||
|
||||
@@ -109,8 +109,11 @@ std::shared_ptr<RayObject> GetRequest::Get(const ObjectID &object_id) const {
|
||||
|
||||
CoreWorkerMemoryStore::CoreWorkerMemoryStore(
|
||||
std::function<void(const RayObject &, const ObjectID &)> store_in_plasma,
|
||||
std::shared_ptr<ReferenceCounter> counter)
|
||||
: store_in_plasma_(store_in_plasma), ref_counter_(counter) {}
|
||||
std::shared_ptr<ReferenceCounter> counter,
|
||||
std::shared_ptr<RayletClient> raylet_client)
|
||||
: store_in_plasma_(store_in_plasma),
|
||||
ref_counter_(counter),
|
||||
raylet_client_(raylet_client) {}
|
||||
|
||||
void CoreWorkerMemoryStore::GetAsync(
|
||||
const ObjectID &object_id, std::function<void(std::shared_ptr<RayObject>)> callback) {
|
||||
@@ -208,7 +211,7 @@ Status CoreWorkerMemoryStore::Put(const ObjectID &object_id, const RayObject &ob
|
||||
|
||||
Status CoreWorkerMemoryStore::Get(const std::vector<ObjectID> &object_ids,
|
||||
int num_objects, int64_t timeout_ms,
|
||||
bool remove_after_get,
|
||||
const WorkerContext &ctx, bool remove_after_get,
|
||||
std::vector<std::shared_ptr<RayObject>> *results) {
|
||||
(*results).resize(object_ids.size(), nullptr);
|
||||
|
||||
@@ -260,8 +263,20 @@ Status CoreWorkerMemoryStore::Get(const std::vector<ObjectID> &object_ids,
|
||||
}
|
||||
}
|
||||
|
||||
// Only send block/unblock IPCs for non-actor tasks on the main thread.
|
||||
// TODO(ekl) support non-lifetime resources for direct actor calls.
|
||||
bool should_notify_raylet =
|
||||
(raylet_client_ != nullptr && !ctx.CurrentActorIsDirectCall() &&
|
||||
ctx.CurrentThreadIsMain());
|
||||
|
||||
// Wait for remaining objects (or timeout).
|
||||
if (should_notify_raylet) {
|
||||
RAY_CHECK_OK(raylet_client_->NotifyDirectCallTaskBlocked());
|
||||
}
|
||||
bool done = get_request->Wait(timeout_ms);
|
||||
if (should_notify_raylet) {
|
||||
RAY_CHECK_OK(raylet_client_->NotifyDirectCallTaskUnblocked());
|
||||
}
|
||||
|
||||
{
|
||||
absl::MutexLock lock(&mu_);
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "ray/common/id.h"
|
||||
#include "ray/common/status.h"
|
||||
#include "ray/core_worker/common.h"
|
||||
#include "ray/core_worker/context.h"
|
||||
#include "ray/core_worker/reference_count.h"
|
||||
|
||||
namespace ray {
|
||||
@@ -24,9 +25,11 @@ class CoreWorkerMemoryStore {
|
||||
/// \param[in] store_in_plasma If not null, this is used to spill to plasma.
|
||||
/// \param[in] counter If not null, this enables ref counting for local objects,
|
||||
/// and the `remove_after_get` flag for Get() will be ignored.
|
||||
/// \param[in] raylet_client If not null, used to notify tasks blocked / unblocked.
|
||||
CoreWorkerMemoryStore(
|
||||
std::function<void(const RayObject &, const ObjectID &)> store_in_plasma = nullptr,
|
||||
std::shared_ptr<ReferenceCounter> counter = nullptr);
|
||||
std::shared_ptr<ReferenceCounter> counter = nullptr,
|
||||
std::shared_ptr<RayletClient> raylet_client = nullptr);
|
||||
~CoreWorkerMemoryStore(){};
|
||||
|
||||
/// Put an object with specified ID into object store.
|
||||
@@ -41,12 +44,14 @@ class CoreWorkerMemoryStore {
|
||||
/// \param[in] object_ids IDs of the objects to get. Duplicates are not allowed.
|
||||
/// \param[in] num_objects Number of objects that should appear.
|
||||
/// \param[in] timeout_ms Timeout in milliseconds, wait infinitely if it's negative.
|
||||
/// \param[in] ctx The current worker context.
|
||||
/// \param[in] remove_after_get When to remove the objects from store after `Get`
|
||||
/// finishes. This has no effect if ref counting is enabled.
|
||||
/// \param[out] results Result list of objects data.
|
||||
/// \return Status.
|
||||
Status Get(const std::vector<ObjectID> &object_ids, int num_objects, int64_t timeout_ms,
|
||||
bool remove_after_get, std::vector<std::shared_ptr<RayObject>> *results);
|
||||
const WorkerContext &ctx, bool remove_after_get,
|
||||
std::vector<std::shared_ptr<RayObject>> *results);
|
||||
|
||||
/// Asynchronously get an object from the object store. The object will not be removed
|
||||
/// from storage after GetAsync (TODO(ekl): integrate this with object GC).
|
||||
@@ -93,6 +98,9 @@ class CoreWorkerMemoryStore {
|
||||
/// mandatory once Java is supported.
|
||||
std::shared_ptr<ReferenceCounter> ref_counter_ = nullptr;
|
||||
|
||||
// If set, this will be used to notify worker blocked / unblocked on get calls.
|
||||
std::shared_ptr<RayletClient> raylet_client_ = nullptr;
|
||||
|
||||
/// Protects the data structures below.
|
||||
absl::Mutex mu_;
|
||||
|
||||
|
||||
@@ -25,13 +25,13 @@ Status CoreWorkerMemoryStoreProvider::Put(const RayObject &object,
|
||||
|
||||
Status CoreWorkerMemoryStoreProvider::Get(
|
||||
const absl::flat_hash_set<ObjectID> &object_ids, int64_t timeout_ms,
|
||||
const TaskID &task_id,
|
||||
const WorkerContext &ctx,
|
||||
absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> *results,
|
||||
bool *got_exception) {
|
||||
const std::vector<ObjectID> id_vector(object_ids.begin(), object_ids.end());
|
||||
std::vector<std::shared_ptr<RayObject>> result_objects;
|
||||
RAY_RETURN_NOT_OK(
|
||||
store_->Get(id_vector, id_vector.size(), timeout_ms, true, &result_objects));
|
||||
store_->Get(id_vector, id_vector.size(), timeout_ms, ctx, true, &result_objects));
|
||||
|
||||
for (size_t i = 0; i < id_vector.size(); i++) {
|
||||
if (result_objects[i] != nullptr) {
|
||||
@@ -52,11 +52,12 @@ Status CoreWorkerMemoryStoreProvider::Contains(const ObjectID &object_id,
|
||||
|
||||
Status CoreWorkerMemoryStoreProvider::Wait(
|
||||
const absl::flat_hash_set<ObjectID> &object_ids, int num_objects, int64_t timeout_ms,
|
||||
const TaskID &task_id, absl::flat_hash_set<ObjectID> *ready) {
|
||||
const WorkerContext &ctx, absl::flat_hash_set<ObjectID> *ready) {
|
||||
std::vector<ObjectID> id_vector(object_ids.begin(), object_ids.end());
|
||||
std::vector<std::shared_ptr<RayObject>> result_objects;
|
||||
RAY_CHECK(object_ids.size() == id_vector.size());
|
||||
auto status = store_->Get(id_vector, num_objects, timeout_ms, false, &result_objects);
|
||||
auto status =
|
||||
store_->Get(id_vector, num_objects, timeout_ms, ctx, false, &result_objects);
|
||||
// Ignore TimedOut statuses since we return ready objects explicitly.
|
||||
if (!status.IsTimedOut()) {
|
||||
RAY_RETURN_NOT_OK(status);
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "ray/common/id.h"
|
||||
#include "ray/common/status.h"
|
||||
#include "ray/core_worker/common.h"
|
||||
#include "ray/core_worker/context.h"
|
||||
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
|
||||
|
||||
namespace ray {
|
||||
@@ -27,7 +28,7 @@ class CoreWorkerMemoryStoreProvider {
|
||||
Status Put(const RayObject &object, const ObjectID &object_id);
|
||||
|
||||
Status Get(const absl::flat_hash_set<ObjectID> &object_ids, int64_t timeout_ms,
|
||||
const TaskID &task_id,
|
||||
const WorkerContext &ctx,
|
||||
absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> *results,
|
||||
bool *got_exception);
|
||||
|
||||
@@ -35,7 +36,7 @@ class CoreWorkerMemoryStoreProvider {
|
||||
|
||||
/// Note that `num_objects` must equal to number of items in `object_ids`.
|
||||
Status Wait(const absl::flat_hash_set<ObjectID> &object_ids, int num_objects,
|
||||
int64_t timeout_ms, const TaskID &task_id,
|
||||
int64_t timeout_ms, const WorkerContext &ctx,
|
||||
absl::flat_hash_set<ObjectID> *ready);
|
||||
|
||||
/// Note that `local_only` must be true, and `delete_creating_tasks` must be false here.
|
||||
|
||||
@@ -646,15 +646,15 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) {
|
||||
absl::flat_hash_set<ObjectID> wait_results;
|
||||
|
||||
ObjectID nonexistent_id = ObjectID::FromRandom().WithDirectTransportType();
|
||||
WorkerContext ctx(WorkerType::WORKER, JobID::Nil());
|
||||
wait_ids.insert(nonexistent_id);
|
||||
RAY_CHECK_OK(
|
||||
provider.Wait(wait_ids, ids.size() + 1, 100, RandomTaskId(), &wait_results));
|
||||
RAY_CHECK_OK(provider.Wait(wait_ids, ids.size() + 1, 100, ctx, &wait_results));
|
||||
ASSERT_EQ(wait_results.size(), ids.size());
|
||||
ASSERT_TRUE(wait_results.count(nonexistent_id) == 0);
|
||||
|
||||
// Test Wait() where the required `num_objects` is less than size of `wait_ids`.
|
||||
wait_results.clear();
|
||||
RAY_CHECK_OK(provider.Wait(wait_ids, ids.size(), -1, RandomTaskId(), &wait_results));
|
||||
RAY_CHECK_OK(provider.Wait(wait_ids, ids.size(), -1, ctx, &wait_results));
|
||||
ASSERT_EQ(wait_results.size(), ids.size());
|
||||
ASSERT_TRUE(wait_results.count(nonexistent_id) == 0);
|
||||
|
||||
@@ -662,7 +662,7 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) {
|
||||
bool got_exception = false;
|
||||
absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> results;
|
||||
absl::flat_hash_set<ObjectID> ids_set(ids.begin(), ids.end());
|
||||
RAY_CHECK_OK(provider.Get(ids_set, -1, RandomTaskId(), &results, &got_exception));
|
||||
RAY_CHECK_OK(provider.Get(ids_set, -1, ctx, &results, &got_exception));
|
||||
|
||||
ASSERT_TRUE(!got_exception);
|
||||
ASSERT_EQ(results.size(), ids.size());
|
||||
@@ -685,8 +685,7 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) {
|
||||
RAY_CHECK_OK(provider.Delete(ids_set));
|
||||
|
||||
usleep(200 * 1000);
|
||||
ASSERT_TRUE(
|
||||
provider.Get(ids_set, 0, RandomTaskId(), &results, &got_exception).IsTimedOut());
|
||||
ASSERT_TRUE(provider.Get(ids_set, 0, ctx, &results, &got_exception).IsTimedOut());
|
||||
ASSERT_TRUE(!got_exception);
|
||||
ASSERT_EQ(results.size(), 0);
|
||||
|
||||
@@ -715,8 +714,7 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) {
|
||||
wait_results.clear();
|
||||
|
||||
// Check that only the ready ids are returned when timeout ends before thread runs.
|
||||
RAY_CHECK_OK(
|
||||
provider.Wait(wait_ids, ready_ids.size() + 1, 100, RandomTaskId(), &wait_results));
|
||||
RAY_CHECK_OK(provider.Wait(wait_ids, ready_ids.size() + 1, 100, ctx, &wait_results));
|
||||
ASSERT_EQ(ready_ids.size(), wait_results.size());
|
||||
for (const auto &ready_id : ready_ids) {
|
||||
ASSERT_TRUE(wait_results.find(ready_id) != wait_results.end());
|
||||
@@ -727,8 +725,7 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) {
|
||||
|
||||
wait_results.clear();
|
||||
// Check that enough objects are returned after the thread inserts at least one object.
|
||||
RAY_CHECK_OK(
|
||||
provider.Wait(wait_ids, ready_ids.size() + 1, 5000, RandomTaskId(), &wait_results));
|
||||
RAY_CHECK_OK(provider.Wait(wait_ids, ready_ids.size() + 1, 5000, ctx, &wait_results));
|
||||
ASSERT_TRUE(wait_results.size() >= ready_ids.size() + 1);
|
||||
for (const auto &ready_id : ready_ids) {
|
||||
ASSERT_TRUE(wait_results.find(ready_id) != wait_results.end());
|
||||
@@ -737,8 +734,7 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) {
|
||||
wait_results.clear();
|
||||
// Check that all objects are returned after the thread completes.
|
||||
async_thread.join();
|
||||
RAY_CHECK_OK(
|
||||
provider.Wait(wait_ids, wait_ids.size(), -1, RandomTaskId(), &wait_results));
|
||||
RAY_CHECK_OK(provider.Wait(wait_ids, wait_ids.size(), -1, ctx, &wait_results));
|
||||
ASSERT_EQ(wait_results.size(), ready_ids.size() + unready_ids.size());
|
||||
for (const auto &ready_id : ready_ids) {
|
||||
ASSERT_TRUE(wait_results.find(ready_id) != wait_results.end());
|
||||
|
||||
@@ -31,7 +31,8 @@ class CoreWorkerRayletTaskReceiver {
|
||||
rpc::SendReplyCallback send_reply_callback);
|
||||
|
||||
private:
|
||||
/// Raylet client.
|
||||
/// Reference to the core worker's raylet client. This is a pointer ref so that it
|
||||
/// can be initialized by core worker after this class is constructed.
|
||||
std::shared_ptr<RayletClient> &raylet_client_;
|
||||
/// The callback function to process a task.
|
||||
TaskHandler task_handler_;
|
||||
|
||||
@@ -31,6 +31,11 @@ enum MessageType:int {
|
||||
// For a worker that was blocked on some object(s), tell the raylet
|
||||
// that the worker is now unblocked. This is sent from a worker to a raylet.
|
||||
NotifyUnblocked,
|
||||
// Notify the current worker is blocked. This is only used by direct task calls;
|
||||
// otherwise the block command is piggybacked on other messages.
|
||||
NotifyDirectCallTaskBlocked,
|
||||
// Notify the current worker is unblocked. This is only used by direct task calls.
|
||||
NotifyDirectCallTaskUnblocked,
|
||||
// A request to get the task frontier for an actor, called by the actor when
|
||||
// saving a checkpoint.
|
||||
GetActorFrontierRequest,
|
||||
@@ -161,6 +166,12 @@ table NotifyUnblocked {
|
||||
task_id: string;
|
||||
}
|
||||
|
||||
table NotifyDirectCallTaskBlocked {
|
||||
}
|
||||
|
||||
table NotifyDirectCallTaskUnblocked {
|
||||
}
|
||||
|
||||
table WaitRequest {
|
||||
// List of object ids we'll be waiting on.
|
||||
object_ids: [string];
|
||||
|
||||
@@ -833,6 +833,7 @@ void NodeManager::DispatchTasks(
|
||||
return local_queues_.NumRunning(a->first) < local_queues_.NumRunning(b->first);
|
||||
});
|
||||
}
|
||||
std::vector<std::function<void()>> post_assign_callbacks;
|
||||
// Approximate fair round robin between classes.
|
||||
for (const auto &it : fair_order) {
|
||||
const auto &task_resources =
|
||||
@@ -845,7 +846,7 @@ void NodeManager::DispatchTasks(
|
||||
// once the first task is not feasible, we can break out of this loop
|
||||
break;
|
||||
}
|
||||
if (AssignTask(task)) {
|
||||
if (AssignTask(task, &post_assign_callbacks)) {
|
||||
removed_task_ids.insert(task_id);
|
||||
}
|
||||
}
|
||||
@@ -854,6 +855,9 @@ void NodeManager::DispatchTasks(
|
||||
// it queued locally. Once the GetTaskReply has been sent, the task will get
|
||||
// re-queued, depending on whether the message succeeded or not.
|
||||
local_queues_.MoveTasks(removed_task_ids, TaskState::READY, TaskState::SWAP);
|
||||
for (auto func : post_assign_callbacks) {
|
||||
func();
|
||||
}
|
||||
}
|
||||
|
||||
void NodeManager::ProcessClientMessage(
|
||||
@@ -902,6 +906,14 @@ void NodeManager::ProcessClientMessage(
|
||||
case protocol::MessageType::FetchOrReconstruct: {
|
||||
ProcessFetchOrReconstructMessage(client, message_data);
|
||||
} break;
|
||||
case protocol::MessageType::NotifyDirectCallTaskBlocked: {
|
||||
std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
|
||||
HandleDirectCallTaskBlocked(worker);
|
||||
} break;
|
||||
case protocol::MessageType::NotifyDirectCallTaskUnblocked: {
|
||||
std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
|
||||
HandleDirectCallTaskUnblocked(worker);
|
||||
} break;
|
||||
case protocol::MessageType::NotifyUnblocked: {
|
||||
auto message = flatbuffers::GetRoot<protocol::NotifyUnblocked>(message_data);
|
||||
HandleTaskUnblocked(client, from_flatbuf<TaskID>(*message->task_id()));
|
||||
@@ -1103,6 +1115,8 @@ void NodeManager::ProcessDisconnectClientMessage(
|
||||
// Clean up any open ray.wait calls that the worker made.
|
||||
task_dependency_manager_.UnsubscribeWaitDependencies(worker->WorkerId());
|
||||
}
|
||||
// Erase any lease metadata.
|
||||
leased_workers_.erase(worker->Port());
|
||||
}
|
||||
|
||||
if (is_worker) {
|
||||
@@ -1435,6 +1449,7 @@ void NodeManager::HandleWorkerLeaseRequest(const rpc::WorkerLeaseRequest &reques
|
||||
// TODO(swang): Kill worker if other end hangs up.
|
||||
// TODO(swang): Implement a lease term by which the owner needs to return the
|
||||
// worker.
|
||||
RAY_CHECK(leased_workers_.find(port) == leased_workers_.end());
|
||||
leased_workers_[port] = std::static_pointer_cast<Worker>(granted);
|
||||
});
|
||||
task.OnSpillbackInstead(
|
||||
@@ -1454,14 +1469,18 @@ void NodeManager::HandleReturnWorker(const rpc::ReturnWorkerRequest &request,
|
||||
rpc::SendReplyCallback send_reply_callback) {
|
||||
// Read the resource spec submitted by the client.
|
||||
auto worker_port = request.worker_port();
|
||||
RAY_LOG(DEBUG) << "Return worker " << worker_port;
|
||||
std::shared_ptr<Worker> worker = std::move(leased_workers_[worker_port]);
|
||||
leased_workers_.erase(worker_port);
|
||||
Status status;
|
||||
if (worker) {
|
||||
// Handle the edge case where the worker was returned before we got the
|
||||
// unblock RPC by unblocking it immediately (unblock is idempotent).
|
||||
if (worker->IsBlocked()) {
|
||||
HandleDirectCallTaskUnblocked(worker);
|
||||
}
|
||||
HandleWorkerAvailable(worker);
|
||||
} else {
|
||||
status = Status::Invalid("Returned worker does not exist");
|
||||
status = Status::Invalid("Returned worker does not exist any more");
|
||||
}
|
||||
send_reply_callback(status, nullptr, nullptr);
|
||||
}
|
||||
@@ -1844,6 +1863,48 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag
|
||||
}
|
||||
}
|
||||
|
||||
void NodeManager::HandleDirectCallTaskBlocked(const std::shared_ptr<Worker> &worker) {
|
||||
if (!worker || worker->GetAssignedTaskId().IsNil() || worker->IsBlocked()) {
|
||||
return; // The worker may have died or is no longer processing the task.
|
||||
}
|
||||
auto const cpu_resource_ids = worker->ReleaseTaskCpuResources();
|
||||
local_available_resources_.Release(cpu_resource_ids);
|
||||
cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()].Release(
|
||||
cpu_resource_ids.ToResourceSet());
|
||||
worker->MarkBlocked();
|
||||
DispatchTasks(local_queues_.GetReadyTasksByClass());
|
||||
}
|
||||
|
||||
void NodeManager::HandleDirectCallTaskUnblocked(const std::shared_ptr<Worker> &worker) {
|
||||
if (!worker || worker->GetAssignedTaskId().IsNil() || !worker->IsBlocked()) {
|
||||
return; // The worker may have died or is no longer processing the task.
|
||||
}
|
||||
TaskID task_id = worker->GetAssignedTaskId();
|
||||
Task task = local_queues_.GetTaskOfState(task_id, TaskState::RUNNING);
|
||||
const auto required_resources = task.GetTaskSpecification().GetRequiredResources();
|
||||
const ResourceSet cpu_resources = required_resources.GetNumCpus();
|
||||
bool oversubscribed = !local_available_resources_.Contains(cpu_resources);
|
||||
if (!oversubscribed) {
|
||||
// Reacquire the CPU resources for the worker. Note that care needs to be
|
||||
// taken if the user is using the specific CPU IDs since the IDs that we
|
||||
// reacquire here may be different from the ones that the task started with.
|
||||
auto const resource_ids = local_available_resources_.Acquire(cpu_resources);
|
||||
worker->AcquireTaskCpuResources(resource_ids);
|
||||
cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()].Acquire(
|
||||
cpu_resources);
|
||||
} else {
|
||||
// In this case, we simply don't reacquire the CPU resources for the worker.
|
||||
// The worker can keep running and when the task finishes, it will simply
|
||||
// not have any CPU resources to release.
|
||||
RAY_LOG(WARNING)
|
||||
<< "Resources oversubscribed: "
|
||||
<< cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()]
|
||||
.GetAvailableResources()
|
||||
.ToString();
|
||||
}
|
||||
worker->MarkUnblocked();
|
||||
}
|
||||
|
||||
void NodeManager::HandleTaskBlocked(const std::shared_ptr<LocalClientConnection> &client,
|
||||
const std::vector<ObjectID> &required_object_ids,
|
||||
const TaskID ¤t_task_id, bool ray_get) {
|
||||
@@ -1884,12 +1945,14 @@ void NodeManager::HandleTaskBlocked(const std::shared_ptr<LocalClientConnection>
|
||||
// Subscribe to the objects required by the task. These objects will be
|
||||
// fetched and/or reconstructed as necessary, until the objects become local
|
||||
// or are unsubscribed.
|
||||
if (ray_get) {
|
||||
task_dependency_manager_.SubscribeGetDependencies(current_task_id,
|
||||
required_object_ids);
|
||||
} else {
|
||||
task_dependency_manager_.SubscribeWaitDependencies(worker->WorkerId(),
|
||||
required_object_ids);
|
||||
if (!required_object_ids.empty()) {
|
||||
if (ray_get) {
|
||||
task_dependency_manager_.SubscribeGetDependencies(current_task_id,
|
||||
required_object_ids);
|
||||
} else {
|
||||
task_dependency_manager_.SubscribeWaitDependencies(worker->WorkerId(),
|
||||
required_object_ids);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1974,7 +2037,8 @@ void NodeManager::EnqueuePlaceableTask(const Task &task) {
|
||||
task_dependency_manager_.TaskPending(task);
|
||||
}
|
||||
|
||||
bool NodeManager::AssignTask(const Task &task) {
|
||||
bool NodeManager::AssignTask(const Task &task,
|
||||
std::vector<std::function<void()>> *post_assign_callbacks) {
|
||||
const TaskSpecification &spec = task.GetTaskSpecification();
|
||||
|
||||
// If this is an actor task, check that the new task has the correct counter.
|
||||
@@ -2036,7 +2100,12 @@ bool NodeManager::AssignTask(const Task &task) {
|
||||
|
||||
if (task.OnDispatch() != nullptr) {
|
||||
task.OnDispatch()(worker, initial_config_.node_manager_address, worker->Port());
|
||||
finish_assign_task_callback(Status::OK());
|
||||
if (post_assign_callbacks != nullptr) {
|
||||
// Moves the tasks from SWAP to RUNNING state atomically. This avoids race
|
||||
// conditions with ReturnLease requests.
|
||||
post_assign_callbacks->push_back(
|
||||
[this, worker, task_id]() { FinishAssignTask(task_id, *worker, true); });
|
||||
}
|
||||
} else {
|
||||
worker->AssignTask(task, resource_id_set, finish_assign_task_callback);
|
||||
}
|
||||
|
||||
@@ -211,8 +211,10 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
|
||||
/// Assign a task. The task is assumed to not be queued in local_queues_.
|
||||
///
|
||||
/// \param task The task in question.
|
||||
/// \param post_assign_callbacks Set of functions to run after assignments finish.
|
||||
/// \return true, if tasks was assigned to a worker, false otherwise.
|
||||
bool AssignTask(const Task &task);
|
||||
bool AssignTask(const Task &task,
|
||||
std::vector<std::function<void()>> *post_assign_callbacks);
|
||||
/// Handle a worker finishing its assigned task.
|
||||
///
|
||||
/// \param worker The worker that finished the task.
|
||||
@@ -328,6 +330,19 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
|
||||
void HandleTaskUnblocked(const std::shared_ptr<LocalClientConnection> &client,
|
||||
const TaskID ¤t_task_id);
|
||||
|
||||
/// Handle a direct call task that is blocked. Note that this callback may
|
||||
/// arrive after the worker lease has been returned to the node manager.
|
||||
///
|
||||
/// \param worker Shared ptr to the worker, or nullptr if lost.
|
||||
void HandleDirectCallTaskBlocked(const std::shared_ptr<Worker> &worker);
|
||||
|
||||
/// Handle a direct call task that is unblocked. Note that this callback may
|
||||
/// arrive after the worker lease has been returned to the node manager.
|
||||
/// However, it is guaranteed to arrive after DirectCallTaskBlocked.
|
||||
///
|
||||
/// \param worker Shared ptr to the worker, or nullptr if lost.
|
||||
void HandleDirectCallTaskUnblocked(const std::shared_ptr<Worker> &worker);
|
||||
|
||||
/// Kill a worker.
|
||||
///
|
||||
/// \param worker The worker to kill.
|
||||
|
||||
@@ -258,6 +258,20 @@ ray::Status RayletClient::NotifyUnblocked(const TaskID ¤t_task_id) {
|
||||
return conn_->WriteMessage(MessageType::NotifyUnblocked, &fbb);
|
||||
}
|
||||
|
||||
ray::Status RayletClient::NotifyDirectCallTaskBlocked() {
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
auto message = ray::protocol::CreateNotifyDirectCallTaskBlocked(fbb);
|
||||
fbb.Finish(message);
|
||||
return conn_->WriteMessage(MessageType::NotifyDirectCallTaskBlocked, &fbb);
|
||||
}
|
||||
|
||||
ray::Status RayletClient::NotifyDirectCallTaskUnblocked() {
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
auto message = ray::protocol::CreateNotifyDirectCallTaskUnblocked(fbb);
|
||||
fbb.Finish(message);
|
||||
return conn_->WriteMessage(MessageType::NotifyDirectCallTaskUnblocked, &fbb);
|
||||
}
|
||||
|
||||
ray::Status RayletClient::Wait(const std::vector<ObjectID> &object_ids, int num_returns,
|
||||
int64_t timeout_milliseconds, bool wait_local,
|
||||
const TaskID ¤t_task_id, WaitResultPair *result) {
|
||||
@@ -392,6 +406,8 @@ ray::Status RayletClient::ReturnWorker(int worker_port) {
|
||||
request.set_worker_port(worker_port);
|
||||
return grpc_client_->ReturnWorker(
|
||||
request, [](const ray::Status &status, const ray::rpc::ReturnWorkerReply &reply) {
|
||||
RAY_CHECK_OK(status);
|
||||
if (!status.ok()) {
|
||||
RAY_LOG(ERROR) << "Error returning worker: " << status;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -126,12 +126,25 @@ class RayletClient : public WorkerLeaseInterface {
|
||||
/// \return int 0 means correct, other numbers mean error.
|
||||
ray::Status FetchOrReconstruct(const std::vector<ObjectID> &object_ids, bool fetch_only,
|
||||
const TaskID ¤t_task_id);
|
||||
|
||||
/// Notify the raylet that this client (worker) is no longer blocked.
|
||||
///
|
||||
/// \param current_task_id The task that is no longer blocked.
|
||||
/// \return ray::Status.
|
||||
ray::Status NotifyUnblocked(const TaskID ¤t_task_id);
|
||||
|
||||
/// Notify the raylet that this client is blocked. This is only used for direct task
|
||||
/// calls. Note that ordering of this with respect to Unblock calls is important.
|
||||
///
|
||||
/// \return ray::Status.
|
||||
ray::Status NotifyDirectCallTaskBlocked();
|
||||
|
||||
/// Notify the raylet that this client is unblocked. This is only used for direct task
|
||||
/// calls. Note that ordering of this with respect to Block calls is important.
|
||||
///
|
||||
/// \return ray::Status.
|
||||
ray::Status NotifyDirectCallTaskUnblocked();
|
||||
|
||||
/// Wait for the given objects until timeout expires or num_return objects are
|
||||
/// found.
|
||||
///
|
||||
|
||||
Reference in New Issue
Block a user