Support NotifyBlocked/UnBlocked for direct call tasks (#6177)

This commit is contained in:
Eric Liang
2019-11-20 22:07:12 -08:00
committed by GitHub
parent db77595298
commit 425edb5cd9
15 changed files with 205 additions and 46 deletions
+6 -1
View File
@@ -55,7 +55,8 @@ WorkerContext::WorkerContext(WorkerType worker_type, const JobID &job_id)
worker_id_(worker_type_ == WorkerType::DRIVER ? ComputeDriverIdFromJob(job_id)
: WorkerID::FromRandom()),
current_job_id_(worker_type_ == WorkerType::DRIVER ? job_id : JobID::Nil()),
current_actor_id_(ActorID::Nil()) {
current_actor_id_(ActorID::Nil()),
main_thread_id_(boost::this_thread::get_id()) {
// For worker main thread which initializes the WorkerContext,
// set task_id according to whether current worker is a driver.
// (For other threads it's set to random ID via GetThreadContext).
@@ -118,6 +119,10 @@ std::shared_ptr<const TaskSpecification> WorkerContext::GetCurrentTask() const {
const ActorID &WorkerContext::GetCurrentActorID() const { return current_actor_id_; }
bool WorkerContext::CurrentThreadIsMain() const {
return boost::this_thread::get_id() == main_thread_id_;
}
bool WorkerContext::CurrentActorIsDirectCall() const {
return current_actor_is_direct_call_;
}
+8
View File
@@ -1,6 +1,8 @@
#ifndef RAY_CORE_WORKER_CONTEXT_H
#define RAY_CORE_WORKER_CONTEXT_H
#include <boost/thread.hpp>
#include "ray/common/task/task_spec.h"
#include "ray/core_worker/common.h"
@@ -34,6 +36,9 @@ class WorkerContext {
const ActorID &GetCurrentActorID() const;
/// Returns whether the current thread is the main worker thread.
bool CurrentThreadIsMain() const;
/// Returns whether we are in a direct call actor.
bool CurrentActorIsDirectCall() const;
@@ -56,6 +61,9 @@ class WorkerContext {
bool current_task_is_direct_call_ = false;
int current_actor_max_concurrency_ = 1;
/// The id of the (main) thread that constructed this worker context.
boost::thread::id main_thread_id_;
private:
static WorkerThreadContext &GetThreadContext(bool for_main_thread = false);
+6 -7
View File
@@ -165,7 +165,7 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
[this](const RayObject &obj, const ObjectID &obj_id) {
RAY_CHECK_OK(plasma_store_provider_->Put(obj, obj_id));
},
ref_counting_enabled ? reference_counter_ : nullptr));
ref_counting_enabled ? reference_counter_ : nullptr, raylet_client_));
memory_store_provider_.reset(new CoreWorkerMemoryStoreProvider(memory_store_));
// Create an entry for the driver task in the task table. This task is
@@ -263,8 +263,7 @@ void CoreWorker::ReportActiveObjectIDs() {
reference_counter_->GetAllInScopeObjectIDs();
RAY_LOG(DEBUG) << "Sending " << active_object_ids.size() << " object IDs to raylet.";
if (active_object_ids.size() > RayConfig::instance().raylet_max_active_object_ids()) {
RAY_LOG(WARNING) << active_object_ids.size() << "object IDs are currently in scope. "
<< "This may lead to required objects being garbage collected.";
RAY_LOG(WARNING) << active_object_ids.size() << " object IDs are currently in scope.";
}
if (!raylet_client_->ReportActiveObjectIDs(active_object_ids).ok()) {
@@ -347,8 +346,8 @@ Status CoreWorker::Get(const std::vector<ObjectID> &ids, const int64_t timeout_m
timeout_ms - (current_time_ms() - start_time));
}
RAY_RETURN_NOT_OK(memory_store_provider_->Get(memory_object_ids, local_timeout_ms,
worker_context_.GetCurrentTaskID(),
&result_map, &got_exception));
worker_context_, &result_map,
&got_exception));
}
// If any of the objects have been promoted to plasma, then we retry their
@@ -454,7 +453,7 @@ Status CoreWorker::Wait(const std::vector<ObjectID> &ids, int num_objects,
// consider waiting on them in plasma as well to ensure they are local.
RAY_RETURN_NOT_OK(memory_store_provider_->Wait(
memory_object_ids, num_objects - static_cast<int>(ready.size()),
/*timeout_ms=*/0, worker_context_.GetCurrentTaskID(), &ready));
/*timeout_ms=*/0, worker_context_, &ready));
}
RAY_CHECK(static_cast<int>(ready.size()) <= num_objects);
@@ -477,7 +476,7 @@ Status CoreWorker::Wait(const std::vector<ObjectID> &ids, int num_objects,
if (static_cast<int>(ready.size()) < num_objects && memory_object_ids.size() > 0) {
RAY_RETURN_NOT_OK(memory_store_provider_->Wait(
memory_object_ids, num_objects - static_cast<int>(ready.size()), timeout_ms,
worker_context_.GetCurrentTaskID(), &ready));
worker_context_, &ready));
}
RAY_CHECK(static_cast<int>(ready.size()) <= num_objects);
}
+2 -1
View File
@@ -161,7 +161,8 @@ TEST(MemoryStoreIntegrationTest, TestSimple) {
RAY_CHECK_OK(store.Put(id1, buffer));
ASSERT_EQ(store.Size(), 1);
std::vector<std::shared_ptr<RayObject>> results;
RAY_CHECK_OK(store.Get({id1}, /*num_objects*/ 1, /*timeout_ms*/ -1,
WorkerContext ctx(WorkerType::WORKER, JobID::Nil());
RAY_CHECK_OK(store.Get({id1}, /*num_objects*/ 1, /*timeout_ms*/ -1, ctx,
/*remove_after_get*/ true, &results));
ASSERT_EQ(results.size(), 1);
ASSERT_EQ(store.Size(), 1);
@@ -109,8 +109,11 @@ std::shared_ptr<RayObject> GetRequest::Get(const ObjectID &object_id) const {
CoreWorkerMemoryStore::CoreWorkerMemoryStore(
std::function<void(const RayObject &, const ObjectID &)> store_in_plasma,
std::shared_ptr<ReferenceCounter> counter)
: store_in_plasma_(store_in_plasma), ref_counter_(counter) {}
std::shared_ptr<ReferenceCounter> counter,
std::shared_ptr<RayletClient> raylet_client)
: store_in_plasma_(store_in_plasma),
ref_counter_(counter),
raylet_client_(raylet_client) {}
void CoreWorkerMemoryStore::GetAsync(
const ObjectID &object_id, std::function<void(std::shared_ptr<RayObject>)> callback) {
@@ -208,7 +211,7 @@ Status CoreWorkerMemoryStore::Put(const ObjectID &object_id, const RayObject &ob
Status CoreWorkerMemoryStore::Get(const std::vector<ObjectID> &object_ids,
int num_objects, int64_t timeout_ms,
bool remove_after_get,
const WorkerContext &ctx, bool remove_after_get,
std::vector<std::shared_ptr<RayObject>> *results) {
(*results).resize(object_ids.size(), nullptr);
@@ -260,8 +263,20 @@ Status CoreWorkerMemoryStore::Get(const std::vector<ObjectID> &object_ids,
}
}
// Only send block/unblock IPCs for non-actor tasks on the main thread.
// TODO(ekl) support non-lifetime resources for direct actor calls.
bool should_notify_raylet =
(raylet_client_ != nullptr && !ctx.CurrentActorIsDirectCall() &&
ctx.CurrentThreadIsMain());
// Wait for remaining objects (or timeout).
if (should_notify_raylet) {
RAY_CHECK_OK(raylet_client_->NotifyDirectCallTaskBlocked());
}
bool done = get_request->Wait(timeout_ms);
if (should_notify_raylet) {
RAY_CHECK_OK(raylet_client_->NotifyDirectCallTaskUnblocked());
}
{
absl::MutexLock lock(&mu_);
@@ -7,6 +7,7 @@
#include "ray/common/id.h"
#include "ray/common/status.h"
#include "ray/core_worker/common.h"
#include "ray/core_worker/context.h"
#include "ray/core_worker/reference_count.h"
namespace ray {
@@ -24,9 +25,11 @@ class CoreWorkerMemoryStore {
/// \param[in] store_in_plasma If not null, this is used to spill to plasma.
/// \param[in] counter If not null, this enables ref counting for local objects,
/// and the `remove_after_get` flag for Get() will be ignored.
/// \param[in] raylet_client If not null, used to notify tasks blocked / unblocked.
CoreWorkerMemoryStore(
std::function<void(const RayObject &, const ObjectID &)> store_in_plasma = nullptr,
std::shared_ptr<ReferenceCounter> counter = nullptr);
std::shared_ptr<ReferenceCounter> counter = nullptr,
std::shared_ptr<RayletClient> raylet_client = nullptr);
~CoreWorkerMemoryStore(){};
/// Put an object with specified ID into object store.
@@ -41,12 +44,14 @@ class CoreWorkerMemoryStore {
/// \param[in] object_ids IDs of the objects to get. Duplicates are not allowed.
/// \param[in] num_objects Number of objects that should appear.
/// \param[in] timeout_ms Timeout in milliseconds, wait infinitely if it's negative.
/// \param[in] ctx The current worker context.
/// \param[in] remove_after_get When to remove the objects from store after `Get`
/// finishes. This has no effect if ref counting is enabled.
/// \param[out] results Result list of objects data.
/// \return Status.
Status Get(const std::vector<ObjectID> &object_ids, int num_objects, int64_t timeout_ms,
bool remove_after_get, std::vector<std::shared_ptr<RayObject>> *results);
const WorkerContext &ctx, bool remove_after_get,
std::vector<std::shared_ptr<RayObject>> *results);
/// Asynchronously get an object from the object store. The object will not be removed
/// from storage after GetAsync (TODO(ekl): integrate this with object GC).
@@ -93,6 +98,9 @@ class CoreWorkerMemoryStore {
/// mandatory once Java is supported.
std::shared_ptr<ReferenceCounter> ref_counter_ = nullptr;
// If set, this will be used to notify worker blocked / unblocked on get calls.
std::shared_ptr<RayletClient> raylet_client_ = nullptr;
/// Protects the data structures below.
absl::Mutex mu_;
@@ -25,13 +25,13 @@ Status CoreWorkerMemoryStoreProvider::Put(const RayObject &object,
Status CoreWorkerMemoryStoreProvider::Get(
const absl::flat_hash_set<ObjectID> &object_ids, int64_t timeout_ms,
const TaskID &task_id,
const WorkerContext &ctx,
absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> *results,
bool *got_exception) {
const std::vector<ObjectID> id_vector(object_ids.begin(), object_ids.end());
std::vector<std::shared_ptr<RayObject>> result_objects;
RAY_RETURN_NOT_OK(
store_->Get(id_vector, id_vector.size(), timeout_ms, true, &result_objects));
store_->Get(id_vector, id_vector.size(), timeout_ms, ctx, true, &result_objects));
for (size_t i = 0; i < id_vector.size(); i++) {
if (result_objects[i] != nullptr) {
@@ -52,11 +52,12 @@ Status CoreWorkerMemoryStoreProvider::Contains(const ObjectID &object_id,
Status CoreWorkerMemoryStoreProvider::Wait(
const absl::flat_hash_set<ObjectID> &object_ids, int num_objects, int64_t timeout_ms,
const TaskID &task_id, absl::flat_hash_set<ObjectID> *ready) {
const WorkerContext &ctx, absl::flat_hash_set<ObjectID> *ready) {
std::vector<ObjectID> id_vector(object_ids.begin(), object_ids.end());
std::vector<std::shared_ptr<RayObject>> result_objects;
RAY_CHECK(object_ids.size() == id_vector.size());
auto status = store_->Get(id_vector, num_objects, timeout_ms, false, &result_objects);
auto status =
store_->Get(id_vector, num_objects, timeout_ms, ctx, false, &result_objects);
// Ignore TimedOut statuses since we return ready objects explicitly.
if (!status.IsTimedOut()) {
RAY_RETURN_NOT_OK(status);
@@ -7,6 +7,7 @@
#include "ray/common/id.h"
#include "ray/common/status.h"
#include "ray/core_worker/common.h"
#include "ray/core_worker/context.h"
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
namespace ray {
@@ -27,7 +28,7 @@ class CoreWorkerMemoryStoreProvider {
Status Put(const RayObject &object, const ObjectID &object_id);
Status Get(const absl::flat_hash_set<ObjectID> &object_ids, int64_t timeout_ms,
const TaskID &task_id,
const WorkerContext &ctx,
absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> *results,
bool *got_exception);
@@ -35,7 +36,7 @@ class CoreWorkerMemoryStoreProvider {
/// Note that `num_objects` must equal to number of items in `object_ids`.
Status Wait(const absl::flat_hash_set<ObjectID> &object_ids, int num_objects,
int64_t timeout_ms, const TaskID &task_id,
int64_t timeout_ms, const WorkerContext &ctx,
absl::flat_hash_set<ObjectID> *ready);
/// Note that `local_only` must be true, and `delete_creating_tasks` must be false here.
+8 -12
View File
@@ -646,15 +646,15 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) {
absl::flat_hash_set<ObjectID> wait_results;
ObjectID nonexistent_id = ObjectID::FromRandom().WithDirectTransportType();
WorkerContext ctx(WorkerType::WORKER, JobID::Nil());
wait_ids.insert(nonexistent_id);
RAY_CHECK_OK(
provider.Wait(wait_ids, ids.size() + 1, 100, RandomTaskId(), &wait_results));
RAY_CHECK_OK(provider.Wait(wait_ids, ids.size() + 1, 100, ctx, &wait_results));
ASSERT_EQ(wait_results.size(), ids.size());
ASSERT_TRUE(wait_results.count(nonexistent_id) == 0);
// Test Wait() where the required `num_objects` is less than size of `wait_ids`.
wait_results.clear();
RAY_CHECK_OK(provider.Wait(wait_ids, ids.size(), -1, RandomTaskId(), &wait_results));
RAY_CHECK_OK(provider.Wait(wait_ids, ids.size(), -1, ctx, &wait_results));
ASSERT_EQ(wait_results.size(), ids.size());
ASSERT_TRUE(wait_results.count(nonexistent_id) == 0);
@@ -662,7 +662,7 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) {
bool got_exception = false;
absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> results;
absl::flat_hash_set<ObjectID> ids_set(ids.begin(), ids.end());
RAY_CHECK_OK(provider.Get(ids_set, -1, RandomTaskId(), &results, &got_exception));
RAY_CHECK_OK(provider.Get(ids_set, -1, ctx, &results, &got_exception));
ASSERT_TRUE(!got_exception);
ASSERT_EQ(results.size(), ids.size());
@@ -685,8 +685,7 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) {
RAY_CHECK_OK(provider.Delete(ids_set));
usleep(200 * 1000);
ASSERT_TRUE(
provider.Get(ids_set, 0, RandomTaskId(), &results, &got_exception).IsTimedOut());
ASSERT_TRUE(provider.Get(ids_set, 0, ctx, &results, &got_exception).IsTimedOut());
ASSERT_TRUE(!got_exception);
ASSERT_EQ(results.size(), 0);
@@ -715,8 +714,7 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) {
wait_results.clear();
// Check that only the ready ids are returned when timeout ends before thread runs.
RAY_CHECK_OK(
provider.Wait(wait_ids, ready_ids.size() + 1, 100, RandomTaskId(), &wait_results));
RAY_CHECK_OK(provider.Wait(wait_ids, ready_ids.size() + 1, 100, ctx, &wait_results));
ASSERT_EQ(ready_ids.size(), wait_results.size());
for (const auto &ready_id : ready_ids) {
ASSERT_TRUE(wait_results.find(ready_id) != wait_results.end());
@@ -727,8 +725,7 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) {
wait_results.clear();
// Check that enough objects are returned after the thread inserts at least one object.
RAY_CHECK_OK(
provider.Wait(wait_ids, ready_ids.size() + 1, 5000, RandomTaskId(), &wait_results));
RAY_CHECK_OK(provider.Wait(wait_ids, ready_ids.size() + 1, 5000, ctx, &wait_results));
ASSERT_TRUE(wait_results.size() >= ready_ids.size() + 1);
for (const auto &ready_id : ready_ids) {
ASSERT_TRUE(wait_results.find(ready_id) != wait_results.end());
@@ -737,8 +734,7 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) {
wait_results.clear();
// Check that all objects are returned after the thread completes.
async_thread.join();
RAY_CHECK_OK(
provider.Wait(wait_ids, wait_ids.size(), -1, RandomTaskId(), &wait_results));
RAY_CHECK_OK(provider.Wait(wait_ids, wait_ids.size(), -1, ctx, &wait_results));
ASSERT_EQ(wait_results.size(), ready_ids.size() + unready_ids.size());
for (const auto &ready_id : ready_ids) {
ASSERT_TRUE(wait_results.find(ready_id) != wait_results.end());
@@ -31,7 +31,8 @@ class CoreWorkerRayletTaskReceiver {
rpc::SendReplyCallback send_reply_callback);
private:
/// Raylet client.
/// Reference to the core worker's raylet client. This is a pointer ref so that it
/// can be initialized by core worker after this class is constructed.
std::shared_ptr<RayletClient> &raylet_client_;
/// The callback function to process a task.
TaskHandler task_handler_;
+11
View File
@@ -31,6 +31,11 @@ enum MessageType:int {
// For a worker that was blocked on some object(s), tell the raylet
// that the worker is now unblocked. This is sent from a worker to a raylet.
NotifyUnblocked,
// Notify the current worker is blocked. This is only used by direct task calls;
// otherwise the block command is piggybacked on other messages.
NotifyDirectCallTaskBlocked,
// Notify the current worker is unblocked. This is only used by direct task calls.
NotifyDirectCallTaskUnblocked,
// A request to get the task frontier for an actor, called by the actor when
// saving a checkpoint.
GetActorFrontierRequest,
@@ -161,6 +166,12 @@ table NotifyUnblocked {
task_id: string;
}
table NotifyDirectCallTaskBlocked {
}
table NotifyDirectCallTaskUnblocked {
}
table WaitRequest {
// List of object ids we'll be waiting on.
object_ids: [string];
+80 -11
View File
@@ -833,6 +833,7 @@ void NodeManager::DispatchTasks(
return local_queues_.NumRunning(a->first) < local_queues_.NumRunning(b->first);
});
}
std::vector<std::function<void()>> post_assign_callbacks;
// Approximate fair round robin between classes.
for (const auto &it : fair_order) {
const auto &task_resources =
@@ -845,7 +846,7 @@ void NodeManager::DispatchTasks(
// once the first task is not feasible, we can break out of this loop
break;
}
if (AssignTask(task)) {
if (AssignTask(task, &post_assign_callbacks)) {
removed_task_ids.insert(task_id);
}
}
@@ -854,6 +855,9 @@ void NodeManager::DispatchTasks(
// it queued locally. Once the GetTaskReply has been sent, the task will get
// re-queued, depending on whether the message succeeded or not.
local_queues_.MoveTasks(removed_task_ids, TaskState::READY, TaskState::SWAP);
for (auto func : post_assign_callbacks) {
func();
}
}
void NodeManager::ProcessClientMessage(
@@ -902,6 +906,14 @@ void NodeManager::ProcessClientMessage(
case protocol::MessageType::FetchOrReconstruct: {
ProcessFetchOrReconstructMessage(client, message_data);
} break;
case protocol::MessageType::NotifyDirectCallTaskBlocked: {
std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
HandleDirectCallTaskBlocked(worker);
} break;
case protocol::MessageType::NotifyDirectCallTaskUnblocked: {
std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
HandleDirectCallTaskUnblocked(worker);
} break;
case protocol::MessageType::NotifyUnblocked: {
auto message = flatbuffers::GetRoot<protocol::NotifyUnblocked>(message_data);
HandleTaskUnblocked(client, from_flatbuf<TaskID>(*message->task_id()));
@@ -1103,6 +1115,8 @@ void NodeManager::ProcessDisconnectClientMessage(
// Clean up any open ray.wait calls that the worker made.
task_dependency_manager_.UnsubscribeWaitDependencies(worker->WorkerId());
}
// Erase any lease metadata.
leased_workers_.erase(worker->Port());
}
if (is_worker) {
@@ -1435,6 +1449,7 @@ void NodeManager::HandleWorkerLeaseRequest(const rpc::WorkerLeaseRequest &reques
// TODO(swang): Kill worker if other end hangs up.
// TODO(swang): Implement a lease term by which the owner needs to return the
// worker.
RAY_CHECK(leased_workers_.find(port) == leased_workers_.end());
leased_workers_[port] = std::static_pointer_cast<Worker>(granted);
});
task.OnSpillbackInstead(
@@ -1454,14 +1469,18 @@ void NodeManager::HandleReturnWorker(const rpc::ReturnWorkerRequest &request,
rpc::SendReplyCallback send_reply_callback) {
// Read the resource spec submitted by the client.
auto worker_port = request.worker_port();
RAY_LOG(DEBUG) << "Return worker " << worker_port;
std::shared_ptr<Worker> worker = std::move(leased_workers_[worker_port]);
leased_workers_.erase(worker_port);
Status status;
if (worker) {
// Handle the edge case where the worker was returned before we got the
// unblock RPC by unblocking it immediately (unblock is idempotent).
if (worker->IsBlocked()) {
HandleDirectCallTaskUnblocked(worker);
}
HandleWorkerAvailable(worker);
} else {
status = Status::Invalid("Returned worker does not exist");
status = Status::Invalid("Returned worker does not exist any more");
}
send_reply_callback(status, nullptr, nullptr);
}
@@ -1844,6 +1863,48 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag
}
}
void NodeManager::HandleDirectCallTaskBlocked(const std::shared_ptr<Worker> &worker) {
if (!worker || worker->GetAssignedTaskId().IsNil() || worker->IsBlocked()) {
return; // The worker may have died or is no longer processing the task.
}
auto const cpu_resource_ids = worker->ReleaseTaskCpuResources();
local_available_resources_.Release(cpu_resource_ids);
cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()].Release(
cpu_resource_ids.ToResourceSet());
worker->MarkBlocked();
DispatchTasks(local_queues_.GetReadyTasksByClass());
}
void NodeManager::HandleDirectCallTaskUnblocked(const std::shared_ptr<Worker> &worker) {
if (!worker || worker->GetAssignedTaskId().IsNil() || !worker->IsBlocked()) {
return; // The worker may have died or is no longer processing the task.
}
TaskID task_id = worker->GetAssignedTaskId();
Task task = local_queues_.GetTaskOfState(task_id, TaskState::RUNNING);
const auto required_resources = task.GetTaskSpecification().GetRequiredResources();
const ResourceSet cpu_resources = required_resources.GetNumCpus();
bool oversubscribed = !local_available_resources_.Contains(cpu_resources);
if (!oversubscribed) {
// Reacquire the CPU resources for the worker. Note that care needs to be
// taken if the user is using the specific CPU IDs since the IDs that we
// reacquire here may be different from the ones that the task started with.
auto const resource_ids = local_available_resources_.Acquire(cpu_resources);
worker->AcquireTaskCpuResources(resource_ids);
cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()].Acquire(
cpu_resources);
} else {
// In this case, we simply don't reacquire the CPU resources for the worker.
// The worker can keep running and when the task finishes, it will simply
// not have any CPU resources to release.
RAY_LOG(WARNING)
<< "Resources oversubscribed: "
<< cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()]
.GetAvailableResources()
.ToString();
}
worker->MarkUnblocked();
}
void NodeManager::HandleTaskBlocked(const std::shared_ptr<LocalClientConnection> &client,
const std::vector<ObjectID> &required_object_ids,
const TaskID &current_task_id, bool ray_get) {
@@ -1884,12 +1945,14 @@ void NodeManager::HandleTaskBlocked(const std::shared_ptr<LocalClientConnection>
// Subscribe to the objects required by the task. These objects will be
// fetched and/or reconstructed as necessary, until the objects become local
// or are unsubscribed.
if (ray_get) {
task_dependency_manager_.SubscribeGetDependencies(current_task_id,
required_object_ids);
} else {
task_dependency_manager_.SubscribeWaitDependencies(worker->WorkerId(),
required_object_ids);
if (!required_object_ids.empty()) {
if (ray_get) {
task_dependency_manager_.SubscribeGetDependencies(current_task_id,
required_object_ids);
} else {
task_dependency_manager_.SubscribeWaitDependencies(worker->WorkerId(),
required_object_ids);
}
}
}
@@ -1974,7 +2037,8 @@ void NodeManager::EnqueuePlaceableTask(const Task &task) {
task_dependency_manager_.TaskPending(task);
}
bool NodeManager::AssignTask(const Task &task) {
bool NodeManager::AssignTask(const Task &task,
std::vector<std::function<void()>> *post_assign_callbacks) {
const TaskSpecification &spec = task.GetTaskSpecification();
// If this is an actor task, check that the new task has the correct counter.
@@ -2036,7 +2100,12 @@ bool NodeManager::AssignTask(const Task &task) {
if (task.OnDispatch() != nullptr) {
task.OnDispatch()(worker, initial_config_.node_manager_address, worker->Port());
finish_assign_task_callback(Status::OK());
if (post_assign_callbacks != nullptr) {
// Moves the tasks from SWAP to RUNNING state atomically. This avoids race
// conditions with ReturnLease requests.
post_assign_callbacks->push_back(
[this, worker, task_id]() { FinishAssignTask(task_id, *worker, true); });
}
} else {
worker->AssignTask(task, resource_id_set, finish_assign_task_callback);
}
+16 -1
View File
@@ -211,8 +211,10 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
/// Assign a task. The task is assumed to not be queued in local_queues_.
///
/// \param task The task in question.
/// \param post_assign_callbacks Set of functions to run after assignments finish.
/// \return true, if tasks was assigned to a worker, false otherwise.
bool AssignTask(const Task &task);
bool AssignTask(const Task &task,
std::vector<std::function<void()>> *post_assign_callbacks);
/// Handle a worker finishing its assigned task.
///
/// \param worker The worker that finished the task.
@@ -328,6 +330,19 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
void HandleTaskUnblocked(const std::shared_ptr<LocalClientConnection> &client,
const TaskID &current_task_id);
/// Handle a direct call task that is blocked. Note that this callback may
/// arrive after the worker lease has been returned to the node manager.
///
/// \param worker Shared ptr to the worker, or nullptr if lost.
void HandleDirectCallTaskBlocked(const std::shared_ptr<Worker> &worker);
/// Handle a direct call task that is unblocked. Note that this callback may
/// arrive after the worker lease has been returned to the node manager.
/// However, it is guaranteed to arrive after DirectCallTaskBlocked.
///
/// \param worker Shared ptr to the worker, or nullptr if lost.
void HandleDirectCallTaskUnblocked(const std::shared_ptr<Worker> &worker);
/// Kill a worker.
///
/// \param worker The worker to kill.
+17 -1
View File
@@ -258,6 +258,20 @@ ray::Status RayletClient::NotifyUnblocked(const TaskID &current_task_id) {
return conn_->WriteMessage(MessageType::NotifyUnblocked, &fbb);
}
ray::Status RayletClient::NotifyDirectCallTaskBlocked() {
flatbuffers::FlatBufferBuilder fbb;
auto message = ray::protocol::CreateNotifyDirectCallTaskBlocked(fbb);
fbb.Finish(message);
return conn_->WriteMessage(MessageType::NotifyDirectCallTaskBlocked, &fbb);
}
ray::Status RayletClient::NotifyDirectCallTaskUnblocked() {
flatbuffers::FlatBufferBuilder fbb;
auto message = ray::protocol::CreateNotifyDirectCallTaskUnblocked(fbb);
fbb.Finish(message);
return conn_->WriteMessage(MessageType::NotifyDirectCallTaskUnblocked, &fbb);
}
ray::Status RayletClient::Wait(const std::vector<ObjectID> &object_ids, int num_returns,
int64_t timeout_milliseconds, bool wait_local,
const TaskID &current_task_id, WaitResultPair *result) {
@@ -392,6 +406,8 @@ ray::Status RayletClient::ReturnWorker(int worker_port) {
request.set_worker_port(worker_port);
return grpc_client_->ReturnWorker(
request, [](const ray::Status &status, const ray::rpc::ReturnWorkerReply &reply) {
RAY_CHECK_OK(status);
if (!status.ok()) {
RAY_LOG(ERROR) << "Error returning worker: " << status;
}
});
}
+13
View File
@@ -126,12 +126,25 @@ class RayletClient : public WorkerLeaseInterface {
/// \return int 0 means correct, other numbers mean error.
ray::Status FetchOrReconstruct(const std::vector<ObjectID> &object_ids, bool fetch_only,
const TaskID &current_task_id);
/// Notify the raylet that this client (worker) is no longer blocked.
///
/// \param current_task_id The task that is no longer blocked.
/// \return ray::Status.
ray::Status NotifyUnblocked(const TaskID &current_task_id);
/// Notify the raylet that this client is blocked. This is only used for direct task
/// calls. Note that ordering of this with respect to Unblock calls is important.
///
/// \return ray::Status.
ray::Status NotifyDirectCallTaskBlocked();
/// Notify the raylet that this client is unblocked. This is only used for direct task
/// calls. Note that ordering of this with respect to Block calls is important.
///
/// \return ray::Status.
ray::Status NotifyDirectCallTaskUnblocked();
/// Wait for the given objects until timeout expires or num_return objects are
/// found.
///