From d206fbbc991dff60d152151ef19af5412f6774b1 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Fri, 28 Aug 2020 00:57:09 -0700 Subject: [PATCH] [Placement group] Scheduler map refactoring part 1. (#10381) * In Progress * done. * Address code review. --- .../gcs_placement_group_scheduler.cc | 155 +++++++++++++----- .../gcs_placement_group_scheduler.h | 92 +++++++++-- .../gcs_placement_group_scheduler_test.cc | 76 +++++++++ 3 files changed, 265 insertions(+), 58 deletions(-) diff --git a/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc b/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc index 3d764460f..20ed16cac 100644 --- a/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc +++ b/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc @@ -259,31 +259,22 @@ void GcsPlacementGroupScheduler::ScheduleUnplacedBundles( void GcsPlacementGroupScheduler::DestroyPlacementGroupBundleResourcesIfExists( const PlacementGroupID &placement_group_id) { - auto it = placement_group_to_bundle_locations_.find(placement_group_id); + const auto &maybe_bundle_locations = + bundle_location_index_.GetBundleLocations(placement_group_id); // If bundle location has been already removed, it means bundles // are already destroyed. Do nothing. - if (it == placement_group_to_bundle_locations_.end()) { + if (!maybe_bundle_locations.has_value()) { return; } - std::shared_ptr bundle_locations = it->second; - for (const auto &iter : *bundle_locations) { + const auto &bundle_locations = maybe_bundle_locations.value(); + // Cancel all resource reservation. + for (const auto &iter : *(bundle_locations)) { auto &bundle_spec = iter.second.second; auto &node_id = iter.second.first; CancelResourceReserve(bundle_spec, gcs_node_manager_.GetNode(node_id)); } - placement_group_to_bundle_locations_.erase(it); - - // Remove bundles from node_to_leased_bundles_ because bundles are removed now. 
- for (const auto &bundle_location : *bundle_locations) { - const auto &bundle_id = bundle_location.first; - const auto &node_id = bundle_location.second.first; - const auto &leased_bundles_it = node_to_leased_bundles_.find(node_id); - // node could've been already dead at this point. - if (leased_bundles_it != node_to_leased_bundles_.end()) { - leased_bundles_it->second.erase(bundle_id); - } - } + bundle_location_index_.Erase(placement_group_id); } void GcsPlacementGroupScheduler::MarkScheduleCancelled( @@ -367,7 +358,7 @@ void GcsPlacementGroupScheduler::OnAllBundleSchedulingRequestReturned( const std::function)> &schedule_success_handler) { const auto &placement_group_id = placement_group->GetPlacementGroupID(); - placement_group_to_bundle_locations_.emplace(placement_group_id, bundle_locations); + bundle_location_index_.AddBundleLocations(placement_group_id, bundle_locations); if (placement_group_leasing_in_progress_.find(placement_group_id) == placement_group_leasing_in_progress_.end() || @@ -393,9 +384,6 @@ void GcsPlacementGroupScheduler::OnAllBundleSchedulingRequestReturned( for (const auto &iter : *bundle_locations) { const auto &location = iter.second; - const auto &bundle_sepc = location.second; - node_to_leased_bundles_[location.first].emplace(bundle_sepc->BundleId(), - bundle_sepc); placement_group->GetMutableBundle(location.second->Index()) ->set_node_id(location.first.Binary()); } @@ -411,24 +399,20 @@ void GcsPlacementGroupScheduler::OnAllBundleSchedulingRequestReturned( std::unique_ptr GcsPlacementGroupScheduler::GetScheduleContext( const PlacementGroupID &placement_group_id) { auto &alive_nodes = gcs_node_manager_.GetAllAliveNodes(); - for (const auto &iter : alive_nodes) { - if (!node_to_leased_bundles_.contains(iter.first)) { - node_to_leased_bundles_.emplace( - iter.first, - absl::flat_hash_map>()); - } - } + bundle_location_index_.AddNodes(alive_nodes); auto node_to_bundles = std::make_shared>(); - for (const auto &iter : 
node_to_leased_bundles_) { - node_to_bundles->emplace(iter.first, iter.second.size()); + for (const auto &node_it : alive_nodes) { + const auto &node_id = node_it.first; + const auto &bundle_locations_on_node = + bundle_location_index_.GetBundleLocationsOnNode(node_id); + RAY_CHECK(bundle_locations_on_node) + << "Bundle locations haven't been registered for node id " << node_id; + const int bundles_size = bundle_locations_on_node.value()->size(); + node_to_bundles->emplace(node_id, bundles_size); } - std::shared_ptr bundle_locations = nullptr; - auto iter = placement_group_to_bundle_locations_.find(placement_group_id); - if (iter != placement_group_to_bundle_locations_.end()) { - bundle_locations = iter->second; - } + auto &bundle_locations = bundle_location_index_.GetBundleLocations(placement_group_id); return std::unique_ptr(new ScheduleContext( std::move(node_to_bundles), bundle_locations, gcs_node_manager_)); } @@ -436,16 +420,107 @@ std::unique_ptr GcsPlacementGroupScheduler::GetScheduleContext( absl::flat_hash_map> GcsPlacementGroupScheduler::GetBundlesOnNode(const ClientID &node_id) { absl::flat_hash_map> bundles_on_node; - const auto node_iter = node_to_leased_bundles_.find(node_id); - if (node_iter != node_to_leased_bundles_.end()) { - const auto &bundles = node_iter->second; - for (auto &bundle : bundles) { - bundles_on_node[bundle.first.first].push_back(bundle.second->BundleId().second); + const auto &maybe_bundle_locations = + bundle_location_index_.GetBundleLocationsOnNode(node_id); + if (maybe_bundle_locations.has_value()) { + const auto &bundle_locations = maybe_bundle_locations.value(); + for (auto &bundle : *bundle_locations) { + const auto &bundle_placement_group_id = bundle.first.first; + const auto &bundle_index = bundle.first.second; + bundles_on_node[bundle_placement_group_id].push_back(bundle_index); } - node_to_leased_bundles_.erase(node_iter); + bundle_location_index_.Erase(node_id); } return bundles_on_node; } +void 
BundleLocationIndex::AddBundleLocations( + const PlacementGroupID &placement_group_id, + std::shared_ptr bundle_locations) { + placement_group_to_bundle_locations_.emplace(placement_group_id, bundle_locations); + for (auto iter : *bundle_locations) { + const auto &node_id = iter.second.first; + if (!node_to_leased_bundles_.contains(node_id)) { + node_to_leased_bundles_[node_id] = std::make_shared(); + } + node_to_leased_bundles_[node_id]->emplace(iter.first, iter.second); + } +} + +bool BundleLocationIndex::Erase(const ClientID &node_id) { + const auto leased_bundles_it = node_to_leased_bundles_.find(node_id); + if (leased_bundles_it == node_to_leased_bundles_.end()) { + return false; + } + + const auto &bundle_locations = leased_bundles_it->second; + for (const auto &bundle_location : *bundle_locations) { + // Remove corresponding placement group id. + const auto &bundle_id = bundle_location.first; + const auto &bundle_spec = bundle_location.second.second; + const auto placement_group_id = bundle_spec->PlacementGroupId(); + auto placement_group_it = + placement_group_to_bundle_locations_.find(placement_group_id); + if (placement_group_it != placement_group_to_bundle_locations_.end()) { + auto &pg_bundle_locations = placement_group_it->second; + auto pg_bundle_it = pg_bundle_locations->find(bundle_id); + if (pg_bundle_it != pg_bundle_locations->end()) { + pg_bundle_locations->erase(pg_bundle_it); + } + } + } + node_to_leased_bundles_.erase(leased_bundles_it); + return true; +} + +bool BundleLocationIndex::Erase(const PlacementGroupID &placement_group_id) { + auto it = placement_group_to_bundle_locations_.find(placement_group_id); + if (it == placement_group_to_bundle_locations_.end()) { + return false; + } + + const auto &bundle_locations = it->second; + // Remove bundles from node_to_leased_bundles_ because bundles are removed now. 
+ for (const auto &bundle_location : *bundle_locations) { + const auto &bundle_id = bundle_location.first; + const auto &node_id = bundle_location.second.first; + const auto leased_bundles_it = node_to_leased_bundles_.find(node_id); + // node could've been already dead at this point. + if (leased_bundles_it != node_to_leased_bundles_.end()) { + leased_bundles_it->second->erase(bundle_id); + } + } + placement_group_to_bundle_locations_.erase(it); + + return true; +} + +const absl::optional const> +BundleLocationIndex::GetBundleLocations(const PlacementGroupID &placement_group_id) { + auto it = placement_group_to_bundle_locations_.find(placement_group_id); + if (it == placement_group_to_bundle_locations_.end()) { + return {}; + } + return it->second; +} + +const absl::optional const> +BundleLocationIndex::GetBundleLocationsOnNode(const ClientID &node_id) { + auto it = node_to_leased_bundles_.find(node_id); + if (it == node_to_leased_bundles_.end()) { + return {}; + } + return it->second; +} + +void BundleLocationIndex::AddNodes( + const absl::flat_hash_map> &nodes) { + for (const auto &iter : nodes) { + if (!node_to_leased_bundles_.contains(iter.first)) { + node_to_leased_bundles_[iter.first] = std::make_shared(); + } + } +} + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.h b/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.h index db54eea73..4a80dbeee 100644 --- a/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.h +++ b/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.h @@ -38,7 +38,7 @@ struct pair_hash { } }; using ScheduleMap = std::unordered_map; -using BundleLocations = std::unordered_map< +using BundleLocations = absl::flat_hash_map< BundleID, std::pair>, pair_hash>; class GcsPlacementGroup; @@ -72,10 +72,11 @@ class GcsPlacementGroupSchedulerInterface { virtual ~GcsPlacementGroupSchedulerInterface() {} }; +/// ScheduleContext provides information that is needed for bundle scheduling
decisions. class ScheduleContext { public: ScheduleContext(std::shared_ptr> node_to_bundles, - const std::shared_ptr &bundle_locations, + const absl::optional> bundle_locations, const GcsNodeManager &node_manager) : node_to_bundles_(std::move(node_to_bundles)), bundle_locations_(bundle_locations), @@ -84,7 +85,7 @@ class ScheduleContext { // Key is node id, value is the number of bundles on the node. const std::shared_ptr> node_to_bundles_; // The locations of existing bundles for this placement group. - const std::shared_ptr &bundle_locations_; + const absl::optional> bundle_locations_; const GcsNodeManager &node_manager_; }; @@ -130,6 +131,71 @@ class GcsStrictSpreadStrategy : public GcsScheduleStrategy { const std::unique_ptr &context) override; }; +/// A data structure that encapsulates information regarding bundle resource leasing +/// status. +class LeasingContext { + // TODO(sang): Implement in the next PR. +}; + +/// A data structure that enables fast bundle location lookup. +class BundleLocationIndex { + public: + BundleLocationIndex() {} + ~BundleLocationIndex() {} + + /// Add bundle locations to the index. + /// + /// \param placement_group_id + /// \param bundle_locations Bundle locations that will be associated with the placement + /// group id. + void AddBundleLocations(const PlacementGroupID &placement_group_id, + std::shared_ptr bundle_locations); + + /// Erase bundle locations associated with a given node id. + /// + /// \param node_id The id of the node. + /// \return True if succeeded. False otherwise. + bool Erase(const ClientID &node_id); + + /// Erase bundle locations associated with a given placement group id. + /// + /// \param placement_group_id Placement group id + /// \return True if succeeded. False otherwise. + bool Erase(const PlacementGroupID &placement_group_id); + + /// Get BundleLocation of placement group id. + /// + /// \param placement_group_id Placement group id of these bundle locations.
+ /// \return Bundle locations that are associated with a given placement group id. + const absl::optional const> GetBundleLocations( + const PlacementGroupID &placement_group_id); + + /// Get BundleLocation of node id. + /// + /// \param node_id Node id of these bundle locations. + /// \return Bundle locations that are associated with a given node id. + const absl::optional const> GetBundleLocationsOnNode( + const ClientID &node_id); + + /// Update the index to contain new node information. Should be used only when new node + /// is added to the cluster. + /// + /// \param nodes Map of alive nodes. + void AddNodes( + const absl::flat_hash_map> &nodes); + + private: + /// Map from node ID to the set of bundles. This is used to lookup bundles at each node + /// when a node is dead. + absl::flat_hash_map> node_to_leased_bundles_; + + /// A map from placement group id to bundle locations. + /// It is used to destroy bundles for the placement group. + /// NOTE: It is a reverse index of `node_to_leased_bundles_`. + absl::flat_hash_map> + placement_group_to_bundle_locations_; +}; + /// GcsPlacementGroupScheduler is responsible for scheduling placement_groups registered /// to GcsPlacementGroupManager. This class is not thread-safe. class GcsPlacementGroupScheduler : public GcsPlacementGroupSchedulerInterface { @@ -227,6 +293,9 @@ class GcsPlacementGroupScheduler : public GcsPlacementGroupSchedulerInterface { /// Factory for producing new clients to request leases from remote nodes. ReserveResourceClientFactoryFn lease_client_factory_; + /// A vector to store all the schedule strategies. + std::vector> scheduler_strategies_; + /// Map from node ID to the set of bundles for whom we are trying to acquire a lease /// from that node. This is needed so that we can retry lease requests from the node /// until we receive a reply or the node is removed.
@@ -234,25 +303,12 @@ class GcsPlacementGroupScheduler : public GcsPlacementGroupSchedulerInterface { absl::flat_hash_map> node_to_bundles_when_leasing_; - /// Map from node ID to the set of bundles. This is needed so that we can reschedule - /// bundles when a node is dead. - absl::flat_hash_map>> - node_to_leased_bundles_; - - /// A vector to store all the schedule strategy. - std::vector> scheduler_strategies_; - /// Set of placement group that have lease requests in flight to nodes. /// It is required to know if placement group has been removed or not. absl::flat_hash_set placement_group_leasing_in_progress_; - /// A map from placement group id to bundle locations. - /// It is used to destroy bundles for the placement group. When we reschedule bundles, - /// we can get the location of other bundles from here. - /// NOTE: It is a reverse index of `node_to_leased_bundles`. - absl::flat_hash_map> - placement_group_to_bundle_locations_; + /// Index to lookup bundle locations of node or placement group. + BundleLocationIndex bundle_location_index_; }; } // namespace gcs diff --git a/src/ray/gcs/gcs_server/test/gcs_placement_group_scheduler_test.cc b/src/ray/gcs/gcs_server/test/gcs_placement_group_scheduler_test.cc index e66d3d3fd..a42c7ea42 100644 --- a/src/ray/gcs/gcs_server/test/gcs_placement_group_scheduler_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_placement_group_scheduler_test.cc @@ -549,6 +549,82 @@ TEST_F(GcsPlacementGroupSchedulerTest, TestStrictSpreadStrategyResourceCheck) { WaitPendingDone(success_placement_groups_, 1); } +TEST_F(GcsPlacementGroupSchedulerTest, TestBundleLocationIndex) { + gcs::BundleLocationIndex bundle_location_index; + /// Generate data. 
+ const auto node1 = ClientID::FromRandom(); + const auto node2 = ClientID::FromRandom(); + rpc::CreatePlacementGroupRequest request_pg1 = + Mocker::GenCreatePlacementGroupRequest("pg1"); + const auto pg1_id = PlacementGroupID::FromBinary( + request_pg1.placement_group_spec().placement_group_id()); + const std::shared_ptr bundle_node1_pg1 = + std::make_shared( + BundleSpecification(request_pg1.placement_group_spec().bundles(0))); + const std::shared_ptr bundle_node2_pg1 = + std::make_shared( + BundleSpecification(request_pg1.placement_group_spec().bundles(1))); + std::shared_ptr bundle_locations_pg1 = + std::make_shared(); + (*bundle_locations_pg1) + .emplace(bundle_node1_pg1->BundleId(), std::make_pair(node1, bundle_node1_pg1)); + (*bundle_locations_pg1) + .emplace(bundle_node2_pg1->BundleId(), std::make_pair(node2, bundle_node2_pg1)); + + rpc::CreatePlacementGroupRequest request_pg2 = + Mocker::GenCreatePlacementGroupRequest("pg2"); + const auto pg2_id = PlacementGroupID::FromBinary( + request_pg2.placement_group_spec().placement_group_id()); + const std::shared_ptr bundle_node1_pg2 = + std::make_shared( + BundleSpecification(request_pg2.placement_group_spec().bundles(0))); + const std::shared_ptr bundle_node2_pg2 = + std::make_shared( + BundleSpecification(request_pg2.placement_group_spec().bundles(1))); + std::shared_ptr bundle_locations_pg2 = + std::make_shared(); + (*bundle_locations_pg2)[bundle_node1_pg2->BundleId()] = + std::make_pair(node1, bundle_node1_pg2); + (*bundle_locations_pg2)[bundle_node2_pg2->BundleId()] = + std::make_pair(node2, bundle_node2_pg2); + + // Test Addition. 
+ bundle_location_index.AddBundleLocations(pg1_id, bundle_locations_pg1); + bundle_location_index.AddBundleLocations(pg2_id, bundle_locations_pg2); + + /// Test Get works + auto bundle_locations = bundle_location_index.GetBundleLocations(pg1_id).value(); + ASSERT_TRUE((*bundle_locations).size() == 2); + ASSERT_TRUE((*bundle_locations).contains(bundle_node1_pg1->BundleId())); + ASSERT_TRUE((*bundle_locations).contains(bundle_node2_pg1->BundleId())); + // Make sure pg2 is not in the bundle locations + ASSERT_FALSE((*bundle_locations).contains(bundle_node2_pg2->BundleId())); + + auto bundle_locations2 = bundle_location_index.GetBundleLocations(pg2_id).value(); + ASSERT_TRUE((*bundle_locations2).size() == 2); + ASSERT_TRUE((*bundle_locations2).contains(bundle_node1_pg2->BundleId())); + ASSERT_TRUE((*bundle_locations2).contains(bundle_node2_pg2->BundleId())); + + auto bundle_on_node1 = bundle_location_index.GetBundleLocationsOnNode(node1).value(); + ASSERT_TRUE((*bundle_on_node1).size() == 2); + ASSERT_TRUE((*bundle_on_node1).contains(bundle_node1_pg1->BundleId())); + ASSERT_TRUE((*bundle_on_node1).contains(bundle_node1_pg2->BundleId())); + + auto bundle_on_node2 = bundle_location_index.GetBundleLocationsOnNode(node2).value(); + ASSERT_TRUE((*bundle_on_node2).size() == 2); + ASSERT_TRUE((*bundle_on_node2).contains(bundle_node2_pg1->BundleId())); + ASSERT_TRUE((*bundle_on_node2).contains(bundle_node2_pg2->BundleId())); + + /// Test Erase works + bundle_location_index.Erase(pg1_id); + ASSERT_FALSE(bundle_location_index.GetBundleLocations(pg1_id).has_value()); + ASSERT_TRUE(bundle_location_index.GetBundleLocations(pg2_id).value()->size() == 2); + bundle_location_index.Erase(node1); + ASSERT_FALSE(bundle_location_index.GetBundleLocationsOnNode(node1).has_value()); + ASSERT_TRUE(bundle_location_index.GetBundleLocations(pg2_id).value()->size() == 1); + ASSERT_TRUE(bundle_location_index.GetBundleLocationsOnNode(node2).value()->size() == 1); +} + } // namespace ray int main(int 
argc, char **argv) {