diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index c8dcae7f9..7f207c4fb 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -205,16 +205,6 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, const NodeID &self std::shared_ptr(new ClusterResourceScheduler( self_node_id_.Binary(), local_resources.GetTotalResources().GetResourceMap())); - std::function fulfills_dependencies_func = - [this](const Task &task) { - bool args_ready = task_dependency_manager_.SubscribeGetDependencies( - task.GetTaskSpecification().TaskId(), task.GetDependencies()); - if (args_ready) { - task_dependency_manager_.UnsubscribeGetDependencies( - task.GetTaskSpecification().TaskId()); - } - return args_ready; - }; auto get_node_info_func = [this](const NodeID &node_id) { return gcs_client_->Nodes().Get(node_id); @@ -228,8 +218,8 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, const NodeID &self PublishInfeasibleTaskError(task); }; cluster_task_manager_ = std::shared_ptr(new ClusterTaskManager( - self_node_id_, new_resource_scheduler_, fulfills_dependencies_func, - is_owner_alive, get_node_info_func, announce_infeasible_task)); + self_node_id_, new_resource_scheduler_, task_dependency_manager_, is_owner_alive, + get_node_info_func, announce_infeasible_task)); placement_group_resource_manager_ = std::make_shared(new_resource_scheduler_); } else { @@ -2644,44 +2634,49 @@ void NodeManager::HandleObjectMissing(const ObjectID &object_id) { } RAY_LOG(DEBUG) << result.str(); - // Transition any tasks that were in the runnable state and are dependent on - // this object to the waiting state. - if (!waiting_task_ids.empty()) { - std::unordered_set waiting_task_id_set(waiting_task_ids.begin(), - waiting_task_ids.end()); + // We don't need to do anything if the new scheduler is enabled because tasks + // will get moved back to waiting once they reach the front of the dispatch + // queue. + if (!new_scheduler_enabled_) { + // Transition any tasks that were in the runnable state and are dependent on + // this object to the waiting state. + if (!waiting_task_ids.empty()) { + std::unordered_set waiting_task_id_set(waiting_task_ids.begin(), + waiting_task_ids.end()); - // NOTE(zhijunfu): For direct actors, the worker is initially assigned actor - // creation task ID, which will not be reset after the task finishes. And later tasks - // of this actor will reuse this task ID to require objects from plasma with - // FetchOrReconstruct, since direct actor task IDs are not known to raylet. - // To support actor reconstruction for direct actor, raylet marks actor creation task - // as completed and removes it from `local_queues_` when it receives `TaskDone` - // message from worker. This is necessary because the actor creation task will be - // re-submitted during reconstruction, if the task is not removed previously, the new - // submitted task will be marked as duplicate and thus ignored. - // So here we check for direct actor creation task explicitly to allow this case. - auto iter = waiting_task_id_set.begin(); - while (iter != waiting_task_id_set.end()) { - if (IsActorCreationTask(*iter)) { - RAY_LOG(DEBUG) << "Ignoring direct actor creation task " << *iter - << " when handling object missing for " << object_id; - iter = waiting_task_id_set.erase(iter); - } else { - ++iter; + // NOTE(zhijunfu): For direct actors, the worker is initially assigned actor + // creation task ID, which will not be reset after the task finishes. And later + // tasks of this actor will reuse this task ID to require objects from plasma with + // FetchOrReconstruct, since direct actor task IDs are not known to raylet. + // To support actor reconstruction for direct actor, raylet marks actor creation + // task as completed and removes it from `local_queues_` when it receives `TaskDone` + // message from worker. This is necessary because the actor creation task will be + // re-submitted during reconstruction, if the task is not removed previously, the + // new submitted task will be marked as duplicate and thus ignored. So here we check + // for direct actor creation task explicitly to allow this case. + auto iter = waiting_task_id_set.begin(); + while (iter != waiting_task_id_set.end()) { + if (IsActorCreationTask(*iter)) { + RAY_LOG(DEBUG) << "Ignoring direct actor creation task " << *iter + << " when handling object missing for " << object_id; + iter = waiting_task_id_set.erase(iter); + } else { + ++iter; + } } - } - // First filter out any tasks that can't be transitioned to READY. These - // are running workers or drivers, now blocked in a get. - local_queues_.FilterState(waiting_task_id_set, TaskState::RUNNING); - local_queues_.FilterState(waiting_task_id_set, TaskState::DRIVER); - // Transition the tasks back to the waiting state. They will be made - // runnable once the deleted object becomes available again. - local_queues_.MoveTasks(waiting_task_id_set, TaskState::READY, TaskState::WAITING); - RAY_CHECK(waiting_task_id_set.empty()); - // Moving ready tasks to waiting may have changed the load, making space for placing - // new tasks locally. - ScheduleTasks(cluster_resource_map_); + // First filter out any tasks that can't be transitioned to READY. These + // are running workers or drivers, now blocked in a get. + local_queues_.FilterState(waiting_task_id_set, TaskState::RUNNING); + local_queues_.FilterState(waiting_task_id_set, TaskState::DRIVER); + // Transition the tasks back to the waiting state. They will be made + // runnable once the deleted object becomes available again. + local_queues_.MoveTasks(waiting_task_id_set, TaskState::READY, TaskState::WAITING); + RAY_CHECK(waiting_task_id_set.empty()); + // Moving ready tasks to waiting may have changed the load, making space for placing + // new tasks locally. + ScheduleTasks(cluster_resource_map_); + } } } diff --git a/src/ray/raylet/scheduling/cluster_task_manager.cc b/src/ray/raylet/scheduling/cluster_task_manager.cc index 12715430e..ab3b46ea7 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager.cc +++ b/src/ray/raylet/scheduling/cluster_task_manager.cc @@ -14,13 +14,13 @@ const int kMaxPendingActorsToReport = 20; ClusterTaskManager::ClusterTaskManager( const NodeID &self_node_id, std::shared_ptr cluster_resource_scheduler, - std::function fulfills_dependencies_func, + TaskDependencyManagerInterface &task_dependency_manager, std::function is_owner_alive, NodeInfoGetter get_node_info, std::function announce_infeasible_task) : self_node_id_(self_node_id), cluster_resource_scheduler_(cluster_resource_scheduler), - fulfills_dependencies_func_(fulfills_dependencies_func), + task_dependency_manager_(task_dependency_manager), is_owner_alive_(is_owner_alive), get_node_info_(get_node_info), announce_infeasible_task_(announce_infeasible_task), @@ -102,7 +102,8 @@ bool ClusterTaskManager::WaitForTaskArgsRequests(Work work) { auto object_ids = task.GetTaskSpecification().GetDependencies(); bool can_dispatch = true; if (object_ids.size() > 0) { - bool args_ready = fulfills_dependencies_func_(task); + bool args_ready = task_dependency_manager_.SubscribeGetDependencies( + task.GetTaskSpecification().TaskId(), task.GetDependencies()); if (args_ready) { RAY_LOG(DEBUG) << "Args already ready, task can be dispatched " << task.GetTaskSpecification().TaskId(); @@ -138,6 +139,16 @@ void ClusterTaskManager::DispatchScheduledTasksToWorkers( auto &task = std::get<0>(work); auto &spec = task.GetTaskSpecification(); + // An argument was evicted since this task was added to the dispatch + // queue. Move it back to the waiting queue. The caller is responsible + // for notifying us when the task is unblocked again. + if (!spec.GetDependencies().empty() && + !task_dependency_manager_.IsTaskReady(spec.TaskId())) { + waiting_tasks_[spec.TaskId()] = std::move(*work_it); + work_it = dispatch_queue.erase(work_it); + continue; + } + std::shared_ptr worker = worker_pool.PopWorker(spec); if (!worker) { // No worker available, we won't be able to schedule any kind of task. @@ -152,6 +163,9 @@ void ClusterTaskManager::DispatchScheduledTasksToWorkers( RAY_LOG(WARNING) << "Task: " << task.GetTaskSpecification().TaskId() << "'s caller is no longer running. Cancelling task."; worker_pool.PushWorker(worker); + if (!spec.GetDependencies().empty()) { + RAY_CHECK(task_dependency_manager_.UnsubscribeGetDependencies(spec.TaskId())); + } work_it = dispatch_queue.erase(work_it); } else { bool worker_leased; @@ -164,6 +178,9 @@ void ClusterTaskManager::DispatchScheduledTasksToWorkers( worker_pool.PushWorker(worker); } if (remove) { + if (!spec.GetDependencies().empty()) { + RAY_CHECK(task_dependency_manager_.UnsubscribeGetDependencies(spec.TaskId())); + } work_it = dispatch_queue.erase(work_it); } else { break; @@ -295,6 +312,9 @@ bool ClusterTaskManager::CancelTask(const TaskID &task_id) { if (task.GetTaskSpecification().TaskId() == task_id) { RemoveFromBacklogTracker(task); ReplyCancelled(*work_it); + if (!task.GetTaskSpecification().GetDependencies().empty()) { + RAY_CHECK(task_dependency_manager_.UnsubscribeGetDependencies(task_id)); + } work_queue.erase(work_it); if (work_queue.empty()) { tasks_to_dispatch_.erase(shapes_it); @@ -326,6 +346,9 @@ bool ClusterTaskManager::CancelTask(const TaskID &task_id) { const auto &task = std::get<0>(iter->second); RemoveFromBacklogTracker(task); ReplyCancelled(iter->second); + if (!task.GetTaskSpecification().GetDependencies().empty()) { + task_dependency_manager_.UnsubscribeGetDependencies(task_id); + } waiting_tasks_.erase(iter); return true; } diff --git a/src/ray/raylet/scheduling/cluster_task_manager.h b/src/ray/raylet/scheduling/cluster_task_manager.h index 3e3ff2e44..61cfce031 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager.h +++ b/src/ray/raylet/scheduling/cluster_task_manager.h @@ -5,6 +5,7 @@ #include "ray/common/task/task.h" #include "ray/common/task/task_common.h" #include "ray/raylet/scheduling/cluster_resource_scheduler.h" +#include "ray/raylet/task_dependency_manager.h" #include "ray/raylet/worker.h" #include "ray/raylet/worker_pool.h" #include "ray/rpc/grpc_client.h" @@ -46,14 +47,13 @@ class ClusterTaskManager { /// \param self_node_id: ID of local node. /// \param cluster_resource_scheduler: The resource scheduler which contains /// the state of the cluster. - /// \param fulfills_dependencies_func: Returns true if all of a task's - /// dependencies are fulfilled. + /// \param task_dependency_manager_ Used to fetch task's dependencies. /// \param is_owner_alive: A callback which returns if the owner process is alive /// (according to our ownership model). /// \param gcs_client: A gcs client. ClusterTaskManager(const NodeID &self_node_id, std::shared_ptr cluster_resource_scheduler, - std::function fulfills_dependencies_func, + TaskDependencyManagerInterface &task_dependency_manager_, std::function is_owner_alive, NodeInfoGetter get_node_info, std::function announce_infeasible_task); @@ -145,8 +145,8 @@ class ClusterTaskManager { const NodeID &self_node_id_; std::shared_ptr cluster_resource_scheduler_; - /// Function to make task dependencies to be local. - std::function fulfills_dependencies_func_; + /// Class to make task dependencies to be local. + TaskDependencyManagerInterface &task_dependency_manager_; /// Function to check if the owner is alive on a given node. std::function is_owner_alive_; /// Function to get the node information of a given node id. @@ -163,10 +163,20 @@ class ClusterTaskManager { /// Queue of lease requests that should be scheduled onto workers. /// Tasks move from scheduled | waiting -> dispatch. + /// Tasks can also move from dispatch -> waiting if one of their arguments is + /// evicted. + /// All tasks in this map that have dependencies should be registered with + /// the dependency manager, in case a dependency gets evicted while the task + /// is still queued. std::unordered_map> tasks_to_dispatch_; /// Tasks waiting for arguments to be transferred locally. /// Tasks move from waiting -> dispatch. + /// Tasks can also move from dispatch -> waiting if one of their arguments is + /// evicted. + /// All tasks in this map that have dependencies should be registered with + /// the dependency manager, so that they can be moved to dispatch once their + /// dependencies are local. absl::flat_hash_map waiting_tasks_; /// Queue of lease requests that are infeasible. @@ -192,6 +202,8 @@ class ClusterTaskManager { void AddToBacklogTracker(const Task &task); void RemoveFromBacklogTracker(const Task &task); + + friend class ClusterTaskManagerTest; }; } // namespace raylet } // namespace ray diff --git a/src/ray/raylet/scheduling/cluster_task_manager_test.cc b/src/ray/raylet/scheduling/cluster_task_manager_test.cc index 24018dbc8..3f33cbf06 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager_test.cc +++ b/src/ray/raylet/scheduling/cluster_task_manager_test.cc @@ -39,6 +39,8 @@ namespace ray { namespace raylet { +using ::testing::_; + class MockWorkerPool : public WorkerPoolInterface { public: std::shared_ptr PopWorker(const TaskSpecification &task_spec) { @@ -92,21 +94,34 @@ Task CreateTask(const std::unordered_map &required_resource return Task(spec_builder.Build(), TaskExecutionSpecification(execution_spec_message)); } +class MockTaskDependencyManager : public TaskDependencyManagerInterface { + public: + bool SubscribeGetDependencies( + const TaskID &task_id, const std::vector &required_objects) { + RAY_CHECK(subscribed_tasks.insert(task_id).second); + return task_ready_; + } + + bool UnsubscribeGetDependencies(const TaskID &task_id) { + return subscribed_tasks.erase(task_id); + } + + bool IsTaskReady(const TaskID &task_id) const { return task_ready_; } + + bool task_ready_ = true; + + std::unordered_set subscribed_tasks; +}; + class ClusterTaskManagerTest : public ::testing::Test { public: ClusterTaskManagerTest() : id_(NodeID::FromRandom()), scheduler_(CreateSingleNodeScheduler(id_.Binary())), - fulfills_dependencies_calls_(0), - dependencies_fulfilled_(true), is_owner_alive_(true), node_info_calls_(0), announce_infeasible_task_calls_(0), - task_manager_(id_, scheduler_, - [this](const Task &_task) { - fulfills_dependencies_calls_++; - return dependencies_fulfilled_; - }, + task_manager_(id_, scheduler_, dependency_manager_, [this](const WorkerID &worker_id, const NodeID &node_id) { return is_owner_alive_; }, @@ -132,20 +147,26 @@ class ClusterTaskManagerTest : public ::testing::Test { node_info_[id] = info; } + void AssertNoLeaks() { + ASSERT_TRUE(task_manager_.tasks_to_schedule_.empty()); + ASSERT_TRUE(task_manager_.tasks_to_dispatch_.empty()); + ASSERT_TRUE(task_manager_.waiting_tasks_.empty()); + ASSERT_TRUE(task_manager_.infeasible_tasks_.empty()); + ASSERT_TRUE(dependency_manager_.subscribed_tasks.empty()); + } + NodeID id_; std::shared_ptr scheduler_; MockWorkerPool pool_; std::unordered_map> leased_workers_; - int fulfills_dependencies_calls_; - bool dependencies_fulfilled_; - bool is_owner_alive_; int node_info_calls_; int announce_infeasible_task_calls_; std::unordered_map> node_info_; + MockTaskDependencyManager dependency_manager_; ClusterTaskManager task_manager_; }; @@ -178,8 +199,9 @@ TEST_F(ClusterTaskManagerTest, BasicTest) { ASSERT_TRUE(callback_occurred); ASSERT_EQ(leased_workers_.size(), 1); ASSERT_EQ(pool_.workers.size(), 0); - ASSERT_EQ(fulfills_dependencies_calls_, 0); ASSERT_EQ(node_info_calls_, 0); + + AssertNoLeaks(); } TEST_F(ClusterTaskManagerTest, NoFeasibleNodeTest) { @@ -202,7 +224,6 @@ TEST_F(ClusterTaskManagerTest, NoFeasibleNodeTest) { ASSERT_EQ(leased_workers_.size(), 0); // Worker is unused. ASSERT_EQ(pool_.workers.size(), 1); - ASSERT_EQ(fulfills_dependencies_calls_, 0); ASSERT_EQ(node_info_calls_, 0); } @@ -227,11 +248,14 @@ TEST_F(ClusterTaskManagerTest, ResourceTakenWhileResolving) { }; /* Blocked on dependencies */ + dependency_manager_.task_ready_ = false; auto task = CreateTask({{ray::kCPU_ResourceLabel, 5}}, 1); - dependencies_fulfilled_ = false; + std::unordered_set expected_subscribed_tasks = { + task.GetTaskSpecification().TaskId()}; task_manager_.QueueTask(task, &reply, callback); task_manager_.SchedulePendingTasks(); task_manager_.DispatchScheduledTasksToWorkers(pool_, leased_workers_); + ASSERT_EQ(dependency_manager_.subscribed_tasks, expected_subscribed_tasks); ASSERT_EQ(num_callbacks, 0); ASSERT_EQ(leased_workers_.size(), 0); @@ -242,18 +266,20 @@ TEST_F(ClusterTaskManagerTest, ResourceTakenWhileResolving) { task_manager_.QueueTask(task2, &reply, callback); task_manager_.SchedulePendingTasks(); task_manager_.DispatchScheduledTasksToWorkers(pool_, leased_workers_); + ASSERT_EQ(dependency_manager_.subscribed_tasks, expected_subscribed_tasks); ASSERT_EQ(num_callbacks, 1); ASSERT_EQ(leased_workers_.size(), 1); ASSERT_EQ(pool_.workers.size(), 1); /* First task is unblocked now, but resources are no longer available */ + dependency_manager_.task_ready_ = true; auto id = task.GetTaskSpecification().TaskId(); std::vector unblocked = {id}; - dependencies_fulfilled_ = true; task_manager_.TasksUnblocked(unblocked); task_manager_.SchedulePendingTasks(); task_manager_.DispatchScheduledTasksToWorkers(pool_, leased_workers_); + ASSERT_EQ(dependency_manager_.subscribed_tasks, expected_subscribed_tasks); ASSERT_EQ(num_callbacks, 1); ASSERT_EQ(leased_workers_.size(), 1); @@ -265,11 +291,13 @@ TEST_F(ClusterTaskManagerTest, ResourceTakenWhileResolving) { task_manager_.SchedulePendingTasks(); task_manager_.DispatchScheduledTasksToWorkers(pool_, leased_workers_); + ASSERT_TRUE(dependency_manager_.subscribed_tasks.empty()); // Task2 is now done so task can run. ASSERT_EQ(num_callbacks, 2); ASSERT_EQ(leased_workers_.size(), 1); ASSERT_EQ(pool_.workers.size(), 0); + AssertNoLeaks(); } TEST_F(ClusterTaskManagerTest, TestSpillAfterAssigned) { @@ -319,6 +347,7 @@ TEST_F(ClusterTaskManagerTest, TestSpillAfterAssigned) { // The second task was spilled. ASSERT_EQ(spillback_reply.retry_at_raylet_address().raylet_id(), remote_node_id.Binary()); + AssertNoLeaks(); } TEST_F(ClusterTaskManagerTest, TaskCancellationTest) { @@ -375,6 +404,7 @@ TEST_F(ClusterTaskManagerTest, TaskCancellationTest) { ASSERT_FALSE(callback_called); ASSERT_EQ(pool_.workers.size(), 0); ASSERT_EQ(leased_workers_.size(), 1); + AssertNoLeaks(); } TEST_F(ClusterTaskManagerTest, TaskCancelInfeasibleTask) { @@ -412,6 +442,7 @@ TEST_F(ClusterTaskManagerTest, TaskCancelInfeasibleTask) { ASSERT_TRUE(reply.canceled()); ASSERT_EQ(leased_workers_.size(), 0); ASSERT_EQ(pool_.workers.size(), 1); + AssertNoLeaks(); } TEST_F(ClusterTaskManagerTest, HeartbeatTest) { @@ -552,7 +583,6 @@ TEST_F(ClusterTaskManagerTest, BacklogReportTest) { ASSERT_FALSE(callback_occurred); ASSERT_EQ(leased_workers_.size(), 0); ASSERT_EQ(pool_.workers.size(), 1); - ASSERT_EQ(fulfills_dependencies_calls_, 0); ASSERT_EQ(node_info_calls_, 0); auto data = std::make_shared(); @@ -578,6 +608,7 @@ TEST_F(ClusterTaskManagerTest, BacklogReportTest) { ASSERT_EQ(shape1.backlog_size(), 0); ASSERT_EQ(shape1.num_infeasible_requests_queued(), 0); ASSERT_EQ(shape1.num_ready_requests_queued(), 0); + AssertNoLeaks(); } TEST_F(ClusterTaskManagerTest, OwnerDeadTest) { @@ -611,6 +642,7 @@ TEST_F(ClusterTaskManagerTest, OwnerDeadTest) { ASSERT_FALSE(callback_occurred); ASSERT_EQ(leased_workers_.size(), 0); ASSERT_EQ(pool_.workers.size(), 1); + AssertNoLeaks(); } TEST_F(ClusterTaskManagerTest, TestInfeasibleTaskWarning) { @@ -653,6 +685,7 @@ TEST_F(ClusterTaskManagerTest, TestInfeasibleTaskWarning) { ASSERT_EQ(pool_.workers.size(), 1); // Make sure the spillback callback is called. ASSERT_EQ(reply.retry_at_raylet_address().raylet_id(), remote_node_id.Binary()); + AssertNoLeaks(); } TEST_F(ClusterTaskManagerTest, TestMultipleInfeasibleTasksWarnOnce) { @@ -719,6 +752,64 @@ TEST_F(ClusterTaskManagerTest, TestAnyPendingTasks) { &pending_actor_creations, &pending_tasks)); } +TEST_F(ClusterTaskManagerTest, ArgumentEvicted) { + /* + Test the task's dependencies becoming local, then one of the arguments is + evicted. The task should go from waiting -> dispatch -> waiting. + */ + std::shared_ptr worker = + std::make_shared(WorkerID::FromRandom(), 1234); + pool_.PushWorker(std::dynamic_pointer_cast(worker)); + + rpc::RequestWorkerLeaseReply reply; + int num_callbacks = 0; + int *num_callbacks_ptr = &num_callbacks; + auto callback = [num_callbacks_ptr]() { + (*num_callbacks_ptr) = *num_callbacks_ptr + 1; + }; + + /* Blocked on dependencies */ + dependency_manager_.task_ready_ = false; + auto task = CreateTask({{ray::kCPU_ResourceLabel, 5}}, 2); + std::unordered_set expected_subscribed_tasks = { + task.GetTaskSpecification().TaskId()}; + task_manager_.QueueTask(task, &reply, callback); + task_manager_.SchedulePendingTasks(); + task_manager_.DispatchScheduledTasksToWorkers(pool_, leased_workers_); + ASSERT_EQ(dependency_manager_.subscribed_tasks, expected_subscribed_tasks); + ASSERT_EQ(num_callbacks, 0); + ASSERT_EQ(leased_workers_.size(), 0); + + /* Task is unblocked now */ + dependency_manager_.task_ready_ = true; + pool_.workers.clear(); + auto id = task.GetTaskSpecification().TaskId(); + task_manager_.TasksUnblocked({id}); + task_manager_.SchedulePendingTasks(); + task_manager_.DispatchScheduledTasksToWorkers(pool_, leased_workers_); + ASSERT_EQ(dependency_manager_.subscribed_tasks, expected_subscribed_tasks); + ASSERT_EQ(num_callbacks, 0); + ASSERT_EQ(leased_workers_.size(), 0); + + /* Task argument gets evicted */ + dependency_manager_.task_ready_ = false; + pool_.PushWorker(std::dynamic_pointer_cast(worker)); + task_manager_.SchedulePendingTasks(); + task_manager_.DispatchScheduledTasksToWorkers(pool_, leased_workers_); + ASSERT_EQ(dependency_manager_.subscribed_tasks, expected_subscribed_tasks); + ASSERT_EQ(num_callbacks, 0); + ASSERT_EQ(leased_workers_.size(), 0); + + /* Worker available and arguments available */ + task_manager_.TasksUnblocked({id}); + dependency_manager_.task_ready_ = true; + task_manager_.SchedulePendingTasks(); + task_manager_.DispatchScheduledTasksToWorkers(pool_, leased_workers_); + ASSERT_EQ(num_callbacks, 1); + ASSERT_EQ(leased_workers_.size(), 1); + AssertNoLeaks(); +} + int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); diff --git a/src/ray/raylet/task_dependency_manager.cc b/src/ray/raylet/task_dependency_manager.cc index f2b0ab959..74c3d8c7a 100644 --- a/src/ray/raylet/task_dependency_manager.cc +++ b/src/ray/raylet/task_dependency_manager.cc @@ -211,6 +211,12 @@ bool TaskDependencyManager::SubscribeGetDependencies( return (task_entry.num_missing_get_dependencies == 0); } +bool TaskDependencyManager::IsTaskReady(const TaskID &task_id) const { + auto task_entry = task_dependencies_.find(task_id); + RAY_CHECK(task_entry != task_dependencies_.end()); + return task_entry->second.num_missing_get_dependencies == 0; +} + void TaskDependencyManager::SubscribeWaitDependencies( const WorkerID &worker_id, const std::vector &required_objects) { diff --git a/src/ray/raylet/task_dependency_manager.h b/src/ray/raylet/task_dependency_manager.h index 75654698f..eb2c53ee9 100644 --- a/src/ray/raylet/task_dependency_manager.h +++ b/src/ray/raylet/task_dependency_manager.h @@ -29,6 +29,18 @@ using rpc::TaskLeaseData; class ReconstructionPolicy; +/// Used for unit-testing the ClusterTaskManager, which calls these methods for +/// locally queued tasks that have dependencies. +class TaskDependencyManagerInterface { + public: + virtual bool SubscribeGetDependencies( + const TaskID &task_id, + const std::vector &required_objects) = 0; + virtual bool IsTaskReady(const TaskID &task_id) const = 0; + virtual bool UnsubscribeGetDependencies(const TaskID &task_id) = 0; + virtual ~TaskDependencyManagerInterface() {} +}; + /// \class TaskDependencyManager /// /// Responsible for managing object dependencies for tasks. The caller can @@ -39,7 +51,7 @@ class ReconstructionPolicy; /// made available locally, either by object transfer from a remote node or /// reconstruction. The task manager will also cancel these objects if they are /// no longer needed by any task. -class TaskDependencyManager { +class TaskDependencyManager : public TaskDependencyManagerInterface { public: /// Create a task dependency manager. TaskDependencyManager(ObjectManagerInterface &object_manager, @@ -70,6 +82,14 @@ class TaskDependencyManager { bool SubscribeGetDependencies( const TaskID &task_id, const std::vector &required_objects); + /// Check whether a task is ready to run. The task ID must + /// have been previously subscribed by the caller. + /// + /// \param task_id The ID of the task to check. + /// \return Whether all of the dependencies for the task are + /// local. + bool IsTaskReady(const TaskID &task_id) const; + /// Subscribe to object depedencies required by the worker. This should be called for /// ray.wait calls during task execution. ///