From 239196fffca6d587c33ee955d9e1f499c85cb89a Mon Sep 17 00:00:00 2001 From: Alex Wu Date: Thu, 23 Jul 2020 21:13:29 -0700 Subject: [PATCH] [Core] WorkerInterface refactor (#9655) * . * . * refactor WorkerInterface * . * Basic unit test structure complete? * . * . * . * . * Fixed tests * Fixed tests * . --- src/ray/raylet/node_manager.cc | 61 +++++++------ src/ray/raylet/node_manager.h | 22 ++--- .../raylet/scheduling/cluster_task_manager.cc | 10 +-- .../raylet/scheduling/cluster_task_manager.h | 15 ++-- src/ray/raylet/worker.h | 87 +++++++++++++++++-- src/ray/raylet/worker_pool.cc | 45 +++++----- src/ray/raylet/worker_pool.h | 57 ++++++++---- src/ray/raylet/worker_pool_test.cc | 14 +-- 8 files changed, 213 insertions(+), 98 deletions(-) diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 06eaf5ad3..061529bb0 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -105,7 +105,8 @@ namespace raylet { // A helper function to print the leased workers. std::string LeasedWorkersSring( - const std::unordered_map> &leased_workers) { + const std::unordered_map> + &leased_workers) { std::stringstream buffer; buffer << " @leased_workers: ("; for (const auto &pair : leased_workers) { @@ -117,7 +118,8 @@ std::string LeasedWorkersSring( } // A helper function to print the workers in worker_pool_. -std::string WorkerPoolString(const std::vector> &worker_pool) { +std::string WorkerPoolString( + const std::vector> &worker_pool) { std::stringstream buffer; buffer << " @worker_pool: ("; for (const auto &worker : worker_pool) { @@ -128,7 +130,7 @@ std::string WorkerPoolString(const std::vector> &worker_ } // Helper function to print the worker's owner worker and and node owner. 
-std::string WorkerOwnerString(std::shared_ptr &worker) { +std::string WorkerOwnerString(std::shared_ptr &worker) { std::stringstream buffer; const auto owner_worker_id = WorkerID::FromBinary(worker->GetOwnerAddress().worker_id()); @@ -320,7 +322,7 @@ ray::Status NodeManager::RegisterGcs() { return ray::Status::OK(); } -void NodeManager::KillWorker(std::shared_ptr worker) { +void NodeManager::KillWorker(std::shared_ptr worker) { #ifdef _WIN32 // TODO(mehrdadn): implement graceful process termination mechanism #else @@ -1072,7 +1074,7 @@ void NodeManager::DispatchTasks( // Try to get an idle worker to execute this task. If nullptr, there // aren't any available workers so we can't assign the task. - std::shared_ptr worker = + std::shared_ptr worker = worker_pool_.PopWorker(task.GetTaskSpecification()); if (worker != nullptr) { AssignTask(worker, task, &post_assign_callbacks); @@ -1145,11 +1147,11 @@ void NodeManager::ProcessClientMessage(const std::shared_ptr & ProcessFetchOrReconstructMessage(client, message_data); } break; case protocol::MessageType::NotifyDirectCallTaskBlocked: { - std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); + std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); HandleDirectCallTaskBlocked(worker); } break; case protocol::MessageType::NotifyDirectCallTaskUnblocked: { - std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); + std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); HandleDirectCallTaskUnblocked(worker); } break; case protocol::MessageType::NotifyUnblocked: { @@ -1214,8 +1216,8 @@ void NodeManager::ProcessRegisterClientRequestMessage( WorkerID worker_id = from_flatbuf(*message->worker_id()); pid_t pid = message->worker_pid(); std::string worker_ip_address = string_from_flatbuf(*message->ip_address()); - auto worker = std::make_shared(worker_id, language, worker_ip_address, client, - client_call_manager_); + auto worker = std::dynamic_pointer_cast(std::make_shared( + 
worker_id, language, worker_ip_address, client, client_call_manager_)); int assigned_port; if (message->is_worker()) { @@ -1269,7 +1271,7 @@ void NodeManager::ProcessRegisterClientRequestMessage( void NodeManager::ProcessAnnounceWorkerPortMessage( const std::shared_ptr &client, const uint8_t *message_data) { bool is_worker = true; - std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); + std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); if (worker == nullptr) { is_worker = false; worker = worker_pool_.GetRegisteredDriver(client); @@ -1345,11 +1347,11 @@ void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_loca } void NodeManager::HandleWorkerAvailable(const std::shared_ptr &client) { - std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); + std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); HandleWorkerAvailable(worker); } -void NodeManager::HandleWorkerAvailable(const std::shared_ptr &worker) { +void NodeManager::HandleWorkerAvailable(const std::shared_ptr &worker) { RAY_CHECK(worker); bool worker_idle = true; @@ -1376,7 +1378,7 @@ void NodeManager::HandleWorkerAvailable(const std::shared_ptr &worker) { void NodeManager::ProcessDisconnectClientMessage( const std::shared_ptr &client, bool intentional_disconnect) { - std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); + std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); bool is_worker = false, is_driver = false; if (worker) { // The client is a worker. 
@@ -1617,7 +1619,8 @@ void NodeManager::ProcessWaitForDirectActorCallArgsRequestMessage( object_ids, -1, object_ids.size(), false, [this, client, tag](std::vector found, std::vector remaining) { RAY_CHECK(remaining.empty()); - std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); + std::shared_ptr worker = + worker_pool_.GetRegisteredWorker(client); if (!worker) { RAY_LOG(ERROR) << "Lost worker for wait request " << client; } else { @@ -1647,7 +1650,7 @@ void NodeManager::ProcessPrepareActorCheckpointRequest( const auto &actor_entry = actor_registry_.find(actor_id); RAY_CHECK(actor_entry != actor_registry_.end()); - std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); + std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); RAY_CHECK(worker && worker->GetActorId() == actor_id); std::shared_ptr checkpoint_data = @@ -1822,7 +1825,7 @@ void NodeManager::HandleReturnWorker(const rpc::ReturnWorkerRequest &request, rpc::SendReplyCallback send_reply_callback) { // Read the resource spec submitted by the client. 
auto worker_id = WorkerID::FromBinary(request.worker_id()); - std::shared_ptr worker = leased_workers_[worker_id]; + std::shared_ptr worker = leased_workers_[worker_id]; Status status; leased_workers_.erase(worker_id); @@ -2320,7 +2323,8 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag } } -void NodeManager::HandleDirectCallTaskBlocked(const std::shared_ptr &worker) { +void NodeManager::HandleDirectCallTaskBlocked( + const std::shared_ptr &worker) { if (new_scheduler_enabled_) { if (!worker) { return; @@ -2349,7 +2353,8 @@ void NodeManager::HandleDirectCallTaskBlocked(const std::shared_ptr &wor DispatchTasks(local_queues_.GetReadyTasksByClass()); } -void NodeManager::HandleDirectCallTaskUnblocked(const std::shared_ptr &worker) { +void NodeManager::HandleDirectCallTaskUnblocked( + const std::shared_ptr &worker) { if (new_scheduler_enabled_) { if (!worker) { return; @@ -2406,7 +2411,7 @@ void NodeManager::AsyncResolveObjects( const std::shared_ptr &client, const std::vector &required_object_refs, const TaskID &current_task_id, bool ray_get, bool mark_worker_blocked) { - std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); + std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); if (worker) { // The client is a worker. 
If the worker is not already blocked and the // blocked task matches the one assigned to the worker, then mark the @@ -2460,7 +2465,7 @@ void NodeManager::AsyncResolveObjects( void NodeManager::AsyncResolveObjectsFinish( const std::shared_ptr &client, const TaskID &current_task_id, bool was_blocked) { - std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); + std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); // TODO(swang): Because the object dependencies are tracked in the task // dependency manager, we could actually remove this message entirely and @@ -2540,7 +2545,8 @@ void NodeManager::EnqueuePlaceableTask(const Task &task) { task_dependency_manager_.TaskPending(task); } -void NodeManager::AssignTask(const std::shared_ptr &worker, const Task &task, +void NodeManager::AssignTask(const std::shared_ptr &worker, + const Task &task, std::vector> *post_assign_callbacks) { const TaskSpecification &spec = task.GetTaskSpecification(); RAY_CHECK(post_assign_callbacks); @@ -2626,7 +2632,7 @@ void NodeManager::AssignTask(const std::shared_ptr &worker, const Task & } } -bool NodeManager::FinishAssignedTask(Worker &worker) { +bool NodeManager::FinishAssignedTask(WorkerInterface &worker) { TaskID task_id = worker.GetAssignedTaskId(); RAY_LOG(DEBUG) << "Finished task " << task_id; @@ -2735,7 +2741,7 @@ std::shared_ptr NodeManager::CreateActorTableDataFromCreationTas return actor_info_ptr; } -void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) { +void NodeManager::FinishAssignedActorTask(WorkerInterface &worker, const Task &task) { RAY_LOG(DEBUG) << "Finishing assigned actor task"; ActorID actor_id; TaskID caller_id; @@ -3303,7 +3309,7 @@ void NodeManager::ForwardTask( }); } -void NodeManager::FinishAssignTask(const std::shared_ptr &worker, +void NodeManager::FinishAssignTask(const std::shared_ptr &worker, const TaskID &task_id, bool success) { RAY_LOG(DEBUG) << "FinishAssignTask: " << task_id; // Remove the ASSIGNED task from the 
READY queue. @@ -3348,7 +3354,8 @@ void NodeManager::FinishAssignTask(const std::shared_ptr &worker, void NodeManager::ProcessSubscribePlasmaReady( const std::shared_ptr &client, const uint8_t *message_data) { - std::shared_ptr associated_worker = worker_pool_.GetRegisteredWorker(client); + std::shared_ptr associated_worker = + worker_pool_.GetRegisteredWorker(client); if (associated_worker == nullptr) { associated_worker = worker_pool_.GetRegisteredDriver(client); } @@ -3361,7 +3368,7 @@ void NodeManager::ProcessSubscribePlasmaReady( absl::MutexLock guard(&plasma_object_notification_lock_); if (!async_plasma_objects_notification_.contains(id)) { async_plasma_objects_notification_.emplace( - id, absl::flat_hash_set>()); + id, absl::flat_hash_set>()); } // Only insert a worker once @@ -3375,7 +3382,7 @@ ray::Status NodeManager::SetupPlasmaSubscription() { return object_manager_.SubscribeObjAdded( [this](const object_manager::protocol::ObjectInfoT &object_info) { ObjectID object_id = ObjectID::FromBinary(object_info.object_id); - auto waiting_workers = absl::flat_hash_set>(); + auto waiting_workers = absl::flat_hash_set>(); { absl::MutexLock guard(&plasma_object_notification_lock_); auto waiting = this->async_plasma_objects_notification_.extract(object_id); diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 643723b00..9fa038505 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -256,7 +256,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \param[in] task The task in question. /// \param[out] post_assign_callbacks Vector of callbacks that will be appended /// to with any logic that should run after the DispatchTasks loop runs. - void AssignTask(const std::shared_ptr &worker, const Task &task, + void AssignTask(const std::shared_ptr &worker, const Task &task, std::vector> *post_assign_callbacks); /// Handle a worker finishing its assigned task. 
/// @@ -264,7 +264,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \return Whether the worker should be returned to the idle pool. This is /// only false for direct actor creation calls, which should never be /// returned to idle. - bool FinishAssignedTask(Worker &worker); + bool FinishAssignedTask(WorkerInterface &worker); /// Helper function to produce actor table data for a newly created actor. /// /// \param task_spec Task specification of the actor creation task that created the @@ -276,7 +276,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \param worker The worker that finished the task. /// \param task The actor task or actor creation task. /// \return Void. - void FinishAssignedActorTask(Worker &worker, const Task &task); + void FinishAssignedActorTask(WorkerInterface &worker, const Task &task); /// Helper function for handling worker to finish its assigned actor task /// or actor creation task. Gets invoked when tasks's parent actor is known. /// @@ -395,20 +395,20 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// arrive after the worker lease has been returned to the node manager. /// /// \param worker Shared ptr to the worker, or nullptr if lost. - void HandleDirectCallTaskBlocked(const std::shared_ptr &worker); + void HandleDirectCallTaskBlocked(const std::shared_ptr &worker); /// Handle a direct call task that is unblocked. Note that this callback may /// arrive after the worker lease has been returned to the node manager. /// However, it is guaranteed to arrive after DirectCallTaskBlocked. /// /// \param worker Shared ptr to the worker, or nullptr if lost. - void HandleDirectCallTaskUnblocked(const std::shared_ptr &worker); + void HandleDirectCallTaskUnblocked(const std::shared_ptr &worker); /// Kill a worker. /// /// \param worker The worker to kill. /// \return Void. 
- void KillWorker(std::shared_ptr worker); + void KillWorker(std::shared_ptr worker); /// The callback for handling an actor state transition (e.g., from ALIVE to /// DEAD), whether as a notification from the actor table or as a handler for @@ -495,7 +495,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// /// \param worker The pointer to the worker /// \return Void. - void HandleWorkerAvailable(const std::shared_ptr &worker); + void HandleWorkerAvailable(const std::shared_ptr &worker); /// Handle a client that has disconnected. This can be called multiple times /// on the same client because this is triggered both when a client @@ -582,8 +582,8 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \param task_id Id of the task. /// \param success Whether or not assigning the task was successful. /// \return void. - void FinishAssignTask(const std::shared_ptr &worker, const TaskID &task_id, - bool success); + void FinishAssignTask(const std::shared_ptr &worker, + const TaskID &task_id, bool success); /// Process worker subscribing to plasma. /// @@ -762,7 +762,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { remote_node_manager_clients_; /// Map of workers leased out to direct call clients. - std::unordered_map> leased_workers_; + std::unordered_map> leased_workers_; /// Map from owner worker ID to a list of worker IDs that the owner has a /// lease on. 
@@ -805,7 +805,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { mutable absl::Mutex plasma_object_notification_lock_; /// Keeps track of workers waiting for objects - absl::flat_hash_map>> + absl::flat_hash_map>> async_plasma_objects_notification_ GUARDED_BY(plasma_object_notification_lock_); /// Objects that are out of scope in the application and that should be freed diff --git a/src/ray/raylet/scheduling/cluster_task_manager.cc b/src/ray/raylet/scheduling/cluster_task_manager.cc index 1ce8bb689..d1e85ea7b 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager.cc +++ b/src/ray/raylet/scheduling/cluster_task_manager.cc @@ -81,8 +81,8 @@ bool ClusterTaskManager::WaitForTaskArgsRequests(Work work) { } void ClusterTaskManager::DispatchScheduledTasksToWorkers( - WorkerPool &worker_pool, - std::unordered_map> &leased_workers) { + WorkerPoolInterface &worker_pool, + std::unordered_map> &leased_workers) { // Check every task in task_to_dispatch queue to see // whether it can be dispatched and ran. This avoids head-of-line // blocking where a task which cannot be dispatched because @@ -94,7 +94,7 @@ void ClusterTaskManager::DispatchScheduledTasksToWorkers( auto spec = task.GetTaskSpecification(); tasks_to_dispatch_.pop_front(); - std::shared_ptr worker = worker_pool.PopWorker(spec); + std::shared_ptr worker = worker_pool.PopWorker(spec); if (!worker) { // No worker available to schedule this task. // Put the task back in the dispatch queue. 
@@ -148,8 +148,8 @@ void ClusterTaskManager::TasksUnblocked(const std::vector ready_ids) { } void ClusterTaskManager::Dispatch( - std::shared_ptr worker, - std::unordered_map> &leased_workers_, + std::shared_ptr worker, + std::unordered_map> &leased_workers_, const TaskSpecification &task_spec, rpc::RequestWorkerLeaseReply *reply, rpc::SendReplyCallback send_reply_callback) { reply->mutable_worker_address()->set_ip_address(worker->IpAddress()); diff --git a/src/ray/raylet/scheduling/cluster_task_manager.h b/src/ray/raylet/scheduling/cluster_task_manager.h index 4fe289d33..85ccb0bc2 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager.h +++ b/src/ray/raylet/scheduling/cluster_task_manager.h @@ -61,14 +61,14 @@ class ClusterTaskManager { /// `worker_pool` state will be modified (idle workers will be popped) during /// dispatching. void DispatchScheduledTasksToWorkers( - WorkerPool &worker_pool, - std::unordered_map> &leased_workers); + WorkerPoolInterface &worker_pool, + std::unordered_map> &leased_workers); /// (Step 1) Queue tasks for scheduling. /// \param fn: The function used during dispatching. /// \param task: The incoming task to schedule. void QueueTask(const Task &task, rpc::RequestWorkerLeaseReply *reply, - rpc::SendReplyCallback send_reply_callback); + rpc::SendReplyCallback); /// Move tasks from waiting to ready for dispatch. Called when a task's /// dependencies are resolved. @@ -96,10 +96,11 @@ class ClusterTaskManager { /// \return True if the work can be immediately dispatched. 
bool WaitForTaskArgsRequests(Work work); - void Dispatch(std::shared_ptr worker, - std::unordered_map> &leased_workers_, - const TaskSpecification &task_spec, rpc::RequestWorkerLeaseReply *reply, - rpc::SendReplyCallback send_reply_callback); + void Dispatch( + std::shared_ptr worker, + std::unordered_map> &leased_workers_, + const TaskSpecification &task_spec, rpc::RequestWorkerLeaseReply *reply, + rpc::SendReplyCallback send_reply_callback); void Spillback(ClientID spillback_to, std::string address, int port, rpc::RequestWorkerLeaseReply *reply, diff --git a/src/ray/raylet/worker.h b/src/ray/raylet/worker.h index e03c6abc1..c89ece745 100644 --- a/src/ray/raylet/worker.h +++ b/src/ray/raylet/worker.h @@ -30,10 +30,91 @@ namespace ray { namespace raylet { +/// \class WorkerInterface +/// +/// Used for new scheduler unit tests. +class WorkerInterface { + public: + /// A destructor responsible for freeing all worker state. + virtual ~WorkerInterface() {} + virtual void MarkDead() = 0; + virtual bool IsDead() const = 0; + virtual void MarkBlocked() = 0; + virtual void MarkUnblocked() = 0; + virtual bool IsBlocked() const = 0; + /// Return the worker's ID. + virtual WorkerID WorkerId() const = 0; + /// Return the worker process. + virtual Process GetProcess() const = 0; + virtual void SetProcess(Process proc) = 0; + virtual Language GetLanguage() const = 0; + virtual const std::string IpAddress() const = 0; + /// Connect this worker's gRPC client. 
+ virtual void Connect(int port) = 0; + virtual int Port() const = 0; + virtual int AssignedPort() const = 0; + virtual void SetAssignedPort(int port) = 0; + virtual void AssignTaskId(const TaskID &task_id) = 0; + virtual const TaskID &GetAssignedTaskId() const = 0; + virtual bool AddBlockedTaskId(const TaskID &task_id) = 0; + virtual bool RemoveBlockedTaskId(const TaskID &task_id) = 0; + virtual const std::unordered_set &GetBlockedTaskIds() const = 0; + virtual void AssignJobId(const JobID &job_id) = 0; + virtual const JobID &GetAssignedJobId() const = 0; + virtual void AssignActorId(const ActorID &actor_id) = 0; + virtual const ActorID &GetActorId() const = 0; + virtual void MarkDetachedActor() = 0; + virtual bool IsDetachedActor() const = 0; + virtual const std::shared_ptr Connection() const = 0; + virtual void SetOwnerAddress(const rpc::Address &address) = 0; + virtual const rpc::Address &GetOwnerAddress() const = 0; + + virtual const ResourceIdSet &GetLifetimeResourceIds() const = 0; + virtual void SetLifetimeResourceIds(ResourceIdSet &resource_ids) = 0; + virtual void ResetLifetimeResourceIds() = 0; + + virtual const ResourceIdSet &GetTaskResourceIds() const = 0; + virtual void SetTaskResourceIds(ResourceIdSet &resource_ids) = 0; + virtual void ResetTaskResourceIds() = 0; + virtual ResourceIdSet ReleaseTaskCpuResources() = 0; + virtual void AcquireTaskCpuResources(const ResourceIdSet &cpu_resources) = 0; + + virtual Status AssignTask(const Task &task, const ResourceIdSet &resource_id_set) = 0; + virtual void DirectActorCallArgWaitComplete(int64_t tag) = 0; + + // Setter, getter, and clear methods for allocated_instances_. 
+ virtual void SetAllocatedInstances( + std::shared_ptr &allocated_instances) = 0; + + virtual std::shared_ptr GetAllocatedInstances() = 0; + + virtual void ClearAllocatedInstances() = 0; + + virtual void SetLifetimeAllocatedInstances( + std::shared_ptr &allocated_instances) = 0; + virtual std::shared_ptr GetLifetimeAllocatedInstances() = 0; + + virtual void ClearLifetimeAllocatedInstances() = 0; + + virtual void SetBorrowedCPUInstances(std::vector &cpu_instances) = 0; + + virtual std::vector &GetBorrowedCPUInstances() = 0; + + virtual void ClearBorrowedCPUInstances() = 0; + + virtual Task &GetAssignedTask() = 0; + + virtual void SetAssignedTask(Task &assigned_task) = 0; + + virtual bool IsRegistered() = 0; + + virtual rpc::CoreWorkerClient *rpc_client() = 0; +}; + /// Worker class encapsulates the implementation details of a worker. A worker /// is the execution container around a unit of Ray work, such as a task or an /// actor. Ray units of work execute in the context of a Worker. -class Worker { +class Worker : public WorkerInterface { public: /// A constructor that initializes a worker object. /// NOTE: You MUST manually set the worker process. @@ -84,12 +165,8 @@ class Worker { ResourceIdSet ReleaseTaskCpuResources(); void AcquireTaskCpuResources(const ResourceIdSet &cpu_resources); - const std::unordered_set &GetActiveObjectIds() const; - void SetActiveObjectIds(const std::unordered_set &&object_ids); - Status AssignTask(const Task &task, const ResourceIdSet &resource_id_set); void DirectActorCallArgWaitComplete(int64_t tag); - void WorkerLeaseGranted(const std::string &address, int port); // Setter, geter, and clear methods for allocated_instances_. void SetAllocatedInstances( diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index 19693114d..ed08bc336 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -29,8 +29,8 @@ namespace { // A helper function to get a worker from a list. 
-std::shared_ptr GetWorker( - const std::unordered_set> &worker_pool, +std::shared_ptr GetWorker( + const std::unordered_set> &worker_pool, const std::shared_ptr &connection) { for (auto it = worker_pool.begin(); it != worker_pool.end(); it++) { if ((*it)->Connection() == connection) { @@ -42,8 +42,9 @@ std::shared_ptr GetWorker( // A helper function to remove a worker from a list. Returns true if the worker // was found and removed. -bool RemoveWorker(std::unordered_set> &worker_pool, - const std::shared_ptr &worker) { +bool RemoveWorker( + std::unordered_set> &worker_pool, + const std::shared_ptr &worker) { return worker_pool.erase(worker) > 0; } @@ -326,8 +327,8 @@ void WorkerPool::MarkPortAsFree(int port) { } } -Status WorkerPool::RegisterWorker(const std::shared_ptr &worker, pid_t pid, - int *port) { +Status WorkerPool::RegisterWorker(const std::shared_ptr &worker, + pid_t pid, int *port) { auto &state = GetStateForLanguage(worker->GetLanguage()); auto it = state.starting_worker_processes.find(Process::FromPid(pid)); if (it == state.starting_worker_processes.end()) { @@ -347,7 +348,8 @@ Status WorkerPool::RegisterWorker(const std::shared_ptr &worker, pid_t p return Status::OK(); } -Status WorkerPool::RegisterDriver(const std::shared_ptr &driver, int *port) { +Status WorkerPool::RegisterDriver(const std::shared_ptr &driver, + int *port) { RAY_CHECK(!driver->GetAssignedTaskId().IsNil()); RAY_RETURN_NOT_OK(GetNextFreePort(port)); driver->SetAssignedPort(*port); @@ -356,7 +358,7 @@ Status WorkerPool::RegisterDriver(const std::shared_ptr &driver, int *po return Status::OK(); } -std::shared_ptr WorkerPool::GetRegisteredWorker( +std::shared_ptr WorkerPool::GetRegisteredWorker( const std::shared_ptr &connection) const { for (const auto &entry : states_by_lang_) { auto worker = GetWorker(entry.second.registered_workers, connection); @@ -367,7 +369,7 @@ std::shared_ptr WorkerPool::GetRegisteredWorker( return nullptr; } -std::shared_ptr WorkerPool::GetRegisteredDriver( 
+std::shared_ptr WorkerPool::GetRegisteredDriver( const std::shared_ptr &connection) const { for (const auto &entry : states_by_lang_) { auto driver = GetWorker(entry.second.registered_drivers, connection); @@ -378,7 +380,7 @@ std::shared_ptr WorkerPool::GetRegisteredDriver( return nullptr; } -void WorkerPool::PushWorker(const std::shared_ptr &worker) { +void WorkerPool::PushWorker(const std::shared_ptr &worker) { // Since the worker is now idle, unset its assigned task ID. RAY_CHECK(worker->GetAssignedTaskId().IsNil()) << "Idle workers cannot have an assigned task ID"; @@ -401,10 +403,11 @@ void WorkerPool::PushWorker(const std::shared_ptr &worker) { } } -std::shared_ptr WorkerPool::PopWorker(const TaskSpecification &task_spec) { +std::shared_ptr WorkerPool::PopWorker( + const TaskSpecification &task_spec) { auto &state = GetStateForLanguage(task_spec.GetLanguage()); - std::shared_ptr worker = nullptr; + std::shared_ptr worker = nullptr; Process proc; if (task_spec.IsActorCreationTask() && !task_spec.DynamicWorkerOptions().empty()) { // Code path of actor creation task with dynamic worker options. 
@@ -455,7 +458,7 @@ std::shared_ptr WorkerPool::PopWorker(const TaskSpecification &task_spec return worker; } -bool WorkerPool::DisconnectWorker(const std::shared_ptr &worker) { +bool WorkerPool::DisconnectWorker(const std::shared_ptr &worker) { auto &state = GetStateForLanguage(worker->GetLanguage()); RAY_CHECK(RemoveWorker(state.registered_workers, worker)); @@ -467,7 +470,7 @@ bool WorkerPool::DisconnectWorker(const std::shared_ptr &worker) { return RemoveWorker(state.idle, worker); } -void WorkerPool::DisconnectDriver(const std::shared_ptr &driver) { +void WorkerPool::DisconnectDriver(const std::shared_ptr &driver) { auto &state = GetStateForLanguage(driver->GetLanguage()); RAY_CHECK(RemoveWorker(state.registered_drivers, driver)); stats::CurrentDriver().Record( @@ -482,9 +485,9 @@ inline WorkerPool::State &WorkerPool::GetStateForLanguage(const Language &langua return state->second; } -std::vector> WorkerPool::GetWorkersRunningTasksForJob( +std::vector> WorkerPool::GetWorkersRunningTasksForJob( const JobID &job_id) const { - std::vector> workers; + std::vector> workers; for (const auto &entry : states_by_lang_) { for (const auto &worker : entry.second.registered_workers) { @@ -497,8 +500,9 @@ std::vector> WorkerPool::GetWorkersRunningTasksForJob( return workers; } -const std::vector> WorkerPool::GetAllRegisteredWorkers() const { - std::vector> workers; +const std::vector> WorkerPool::GetAllRegisteredWorkers() + const { + std::vector> workers; for (const auto &entry : states_by_lang_) { for (const auto &worker : entry.second.registered_workers) { @@ -511,8 +515,9 @@ const std::vector> WorkerPool::GetAllRegisteredWorkers() return workers; } -const std::vector> WorkerPool::GetAllRegisteredDrivers() const { - std::vector> drivers; +const std::vector> WorkerPool::GetAllRegisteredDrivers() + const { + std::vector> drivers; for (const auto &entry : states_by_lang_) { for (const auto &driver : entry.second.registered_drivers) { diff --git a/src/ray/raylet/worker_pool.h 
b/src/ray/raylet/worker_pool.h index 322d8b230..ee25afe70 100644 --- a/src/ray/raylet/worker_pool.h +++ b/src/ray/raylet/worker_pool.h @@ -36,13 +36,35 @@ namespace raylet { using WorkerCommandMap = std::unordered_map, std::hash>; +/// \class WorkerPoolInterface +/// +/// Used for new scheduler unit tests. +class WorkerPoolInterface { + public: + /// Pop an idle worker from the pool. The caller is responsible for pushing + /// the worker back onto the pool once the worker has completed its work. + /// + /// \param task_spec The returned worker must be able to execute this task. + /// \return An idle worker with the requested task spec. Returns nullptr if no + /// such worker exists. + virtual std::shared_ptr PopWorker( + const TaskSpecification &task_spec) = 0; + /// Add an idle worker to the pool. + /// + /// \param The idle worker to add. + virtual void PushWorker(const std::shared_ptr &worker) = 0; + + virtual ~WorkerPoolInterface(){}; +}; + +class WorkerInterface; class Worker; /// \class WorkerPool /// /// The WorkerPool is responsible for managing a pool of Workers. Each Worker /// is a container for a unit of work. -class WorkerPool { +class WorkerPool : public WorkerPoolInterface { public: /// Create a pool and asynchronously start at least the specified number of workers per /// language. @@ -81,7 +103,8 @@ class WorkerPool { /// \param[out] port The port that this worker's gRPC server should listen on. /// Returns 0 if the worker should bind on a random port. /// \return If the registration is successful. - Status RegisterWorker(const std::shared_ptr &worker, pid_t pid, int *port); + Status RegisterWorker(const std::shared_ptr &worker, pid_t pid, + int *port); /// Register a new driver. /// @@ -89,14 +112,14 @@ class WorkerPool { /// \param[out] port The port that this driver's gRPC server should listen on. /// Returns 0 if the driver should bind on a random port. /// \return If the registration is successful. 
- Status RegisterDriver(const std::shared_ptr &worker, int *port); + Status RegisterDriver(const std::shared_ptr &worker, int *port); /// Get the client connection's registered worker. /// /// \param The client connection owned by a registered worker. /// \return The Worker that owns the given client connection. Returns nullptr /// if the client has not registered a worker yet. - std::shared_ptr GetRegisteredWorker( + std::shared_ptr GetRegisteredWorker( const std::shared_ptr &connection) const; /// Get the client connection's registered driver. @@ -104,24 +127,24 @@ class WorkerPool { /// \param The client connection owned by a registered driver. /// \return The Worker that owns the given client connection. Returns nullptr /// if the client has not registered a driver. - std::shared_ptr GetRegisteredDriver( + std::shared_ptr GetRegisteredDriver( const std::shared_ptr &connection) const; /// Disconnect a registered worker. /// /// \param The worker to disconnect. The worker must be registered. /// \return Whether the given worker was in the pool of idle workers. - bool DisconnectWorker(const std::shared_ptr &worker); + bool DisconnectWorker(const std::shared_ptr &worker); /// Disconnect a registered driver. /// /// \param The driver to disconnect. The driver must be registered. - void DisconnectDriver(const std::shared_ptr &driver); + void DisconnectDriver(const std::shared_ptr &driver); /// Add an idle worker to the pool. /// /// \param The idle worker to add. - void PushWorker(const std::shared_ptr &worker); + void PushWorker(const std::shared_ptr &worker); /// Pop an idle worker from the pool. The caller is responsible for pushing /// the worker back onto the pool once the worker has completed its work. @@ -129,7 +152,7 @@ class WorkerPool { /// \param task_spec The returned worker must be able to execute this task. /// \return An idle worker with the requested task spec. Returns nullptr if no /// such worker exists. 
- std::shared_ptr PopWorker(const TaskSpecification &task_spec); + std::shared_ptr PopWorker(const TaskSpecification &task_spec); /// Return the current size of the worker pool for the requested language. Counts only /// idle workers. @@ -142,18 +165,18 @@ class WorkerPool { /// /// \param job_id The job ID. /// \return A list containing all the workers which are running tasks for the job. - std::vector> GetWorkersRunningTasksForJob( + std::vector> GetWorkersRunningTasksForJob( const JobID &job_id) const; /// Get all the registered workers. /// /// \return A list containing all the workers. - const std::vector> GetAllRegisteredWorkers() const; + const std::vector> GetAllRegisteredWorkers() const; /// Get all the registered drivers. /// /// \return A list containing all the drivers. - const std::vector> GetAllRegisteredDrivers() const; + const std::vector> GetAllRegisteredDrivers() const; /// Whether there is a pending worker for the given task. /// Note that, this is only used for actor creation task with dynamic options. @@ -210,16 +233,16 @@ class WorkerPool { int num_workers_per_process; /// The pool of dedicated workers for actor creation tasks /// with prefix or suffix worker command. - std::unordered_map> idle_dedicated_workers; + std::unordered_map> idle_dedicated_workers; /// The pool of idle non-actor workers. - std::unordered_set> idle; + std::unordered_set> idle; /// The pool of idle actor workers. - std::unordered_map> idle_actor; + std::unordered_map> idle_actor; /// All workers that have registered and are still connected, including both /// idle and executing. - std::unordered_set> registered_workers; + std::unordered_set> registered_workers; /// All drivers that have registered and are still connected. - std::unordered_set> registered_drivers; + std::unordered_set> registered_drivers; /// A map from the pids of starting worker processes /// to the number of their unregistered workers. 
std::unordered_map starting_worker_processes; diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index ad90ef80b..9027c6dcb 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -100,8 +100,8 @@ class WorkerPoolTest : public ::testing::Test { worker_pool_ = std::unique_ptr(new WorkerPoolMock(io_service_)); } - std::shared_ptr CreateWorker(Process proc, - const Language &language = Language::PYTHON) { + std::shared_ptr CreateWorker( + Process proc, const Language &language = Language::PYTHON) { std::function client_handler = [this](ClientConnection &client) { HandleNewClient(client); }; std::function, int64_t, @@ -115,8 +115,10 @@ class WorkerPoolTest : public ::testing::Test { auto client = ClientConnection::Create(client_handler, message_handler, std::move(socket), "worker", {}, error_message_type_); - std::shared_ptr worker = std::make_shared( + std::shared_ptr worker_ = std::make_shared( WorkerID::FromRandom(), language, "127.0.0.1", client, client_call_manager_); + std::shared_ptr worker = + std::dynamic_pointer_cast(worker_); if (!proc.IsNull()) { worker->SetProcess(proc); } @@ -205,7 +207,7 @@ TEST_F(WorkerPoolTest, CompareWorkerProcessObjects) { TEST_F(WorkerPoolTest, HandleWorkerRegistration) { Process proc = worker_pool_->StartWorkerProcess(Language::JAVA); - std::vector> workers; + std::vector> workers; for (int i = 0; i < NUM_WORKERS_PER_PROCESS_JAVA; i++) { workers.push_back(CreateWorker(Process(), Language::JAVA)); } @@ -254,13 +256,13 @@ TEST_F(WorkerPoolTest, InitialWorkerProcessCount) { TEST_F(WorkerPoolTest, HandleWorkerPushPop) { // Try to pop a worker from the empty pool and make sure we don't get one. - std::shared_ptr popped_worker; + std::shared_ptr popped_worker; const auto task_spec = ExampleTaskSpec(); popped_worker = worker_pool_->PopWorker(task_spec); ASSERT_EQ(popped_worker, nullptr); // Create some workers. 
- std::unordered_set> workers; + std::unordered_set> workers; workers.insert(CreateWorker(Process::CreateNewDummy())); workers.insert(CreateWorker(Process::CreateNewDummy())); // Add the workers to the pool.