[Placement group] Scheduler map refactoring part 1. (#10381)

* In Progress

* done.

* Address code review.
This commit is contained in:
SangBin Cho
2020-08-28 00:57:09 -07:00
committed by GitHub
parent 7b29eb7949
commit d206fbbc99
3 changed files with 265 additions and 58 deletions
@@ -259,31 +259,22 @@ void GcsPlacementGroupScheduler::ScheduleUnplacedBundles(
/// Cancel the resource reservation of every bundle recorded for the given placement
/// group and drop the group's recorded bundle locations.
/// Returns without doing anything when no bundle locations are recorded for the group.
void GcsPlacementGroupScheduler::DestroyPlacementGroupBundleResourcesIfExists(
const PlacementGroupID &placement_group_id) {
auto it = placement_group_to_bundle_locations_.find(placement_group_id);
const auto &maybe_bundle_locations =
bundle_location_index_.GetBundleLocations(placement_group_id);
// If the bundle locations have already been removed, the bundles
// are already destroyed. Do nothing.
if (it == placement_group_to_bundle_locations_.end()) {
if (!maybe_bundle_locations.has_value()) {
return;
}
std::shared_ptr<BundleLocations> bundle_locations = it->second;
for (const auto &iter : *bundle_locations) {
const auto &bundle_locations = maybe_bundle_locations.value();
// Cancel all resource reservations.
for (const auto &iter : *(bundle_locations)) {
auto &bundle_spec = iter.second.second;
auto &node_id = iter.second.first;
CancelResourceReserve(bundle_spec, gcs_node_manager_.GetNode(node_id));
}
placement_group_to_bundle_locations_.erase(it);
// Remove bundles from node_to_leased_bundles_ because bundles are removed now.
for (const auto &bundle_location : *bundle_locations) {
const auto &bundle_id = bundle_location.first;
const auto &node_id = bundle_location.second.first;
const auto &leased_bundles_it = node_to_leased_bundles_.find(node_id);
// The node could already be dead at this point.
if (leased_bundles_it != node_to_leased_bundles_.end()) {
leased_bundles_it->second.erase(bundle_id);
}
}
bundle_location_index_.Erase(placement_group_id);
}
void GcsPlacementGroupScheduler::MarkScheduleCancelled(
@@ -367,7 +358,7 @@ void GcsPlacementGroupScheduler::OnAllBundleSchedulingRequestReturned(
const std::function<void(std::shared_ptr<GcsPlacementGroup>)>
&schedule_success_handler) {
const auto &placement_group_id = placement_group->GetPlacementGroupID();
placement_group_to_bundle_locations_.emplace(placement_group_id, bundle_locations);
bundle_location_index_.AddBundleLocations(placement_group_id, bundle_locations);
if (placement_group_leasing_in_progress_.find(placement_group_id) ==
placement_group_leasing_in_progress_.end() ||
@@ -393,9 +384,6 @@ void GcsPlacementGroupScheduler::OnAllBundleSchedulingRequestReturned(
for (const auto &iter : *bundle_locations) {
const auto &location = iter.second;
const auto &bundle_sepc = location.second;
node_to_leased_bundles_[location.first].emplace(bundle_sepc->BundleId(),
bundle_sepc);
placement_group->GetMutableBundle(location.second->Index())
->set_node_id(location.first.Binary());
}
@@ -411,24 +399,20 @@ void GcsPlacementGroupScheduler::OnAllBundleSchedulingRequestReturned(
/// Build the ScheduleContext for a placement group: a map from node id to the number
/// of bundles recorded on that node, plus the bundle locations already recorded for
/// `placement_group_id` (if any) and the node manager.
std::unique_ptr<ScheduleContext> GcsPlacementGroupScheduler::GetScheduleContext(
const PlacementGroupID &placement_group_id) {
auto &alive_nodes = gcs_node_manager_.GetAllAliveNodes();
for (const auto &iter : alive_nodes) {
if (!node_to_leased_bundles_.contains(iter.first)) {
node_to_leased_bundles_.emplace(
iter.first,
absl::flat_hash_map<BundleID, std::shared_ptr<BundleSpecification>>());
}
}
bundle_location_index_.AddNodes(alive_nodes);
auto node_to_bundles = std::make_shared<absl::flat_hash_map<ClientID, int64_t>>();
for (const auto &iter : node_to_leased_bundles_) {
node_to_bundles->emplace(iter.first, iter.second.size());
for (const auto &node_it : alive_nodes) {
const auto &node_id = node_it.first;
const auto &bundle_locations_on_node =
bundle_location_index_.GetBundleLocationsOnNode(node_id);
// Every alive node was registered by AddNodes above, so a missing entry is a bug.
RAY_CHECK(bundle_locations_on_node)
<< "Bundle locations haven't been registered for node id " << node_id;
const int bundles_size = bundle_locations_on_node.value()->size();
node_to_bundles->emplace(node_id, bundles_size);
}
std::shared_ptr<BundleLocations> bundle_locations = nullptr;
auto iter = placement_group_to_bundle_locations_.find(placement_group_id);
if (iter != placement_group_to_bundle_locations_.end()) {
bundle_locations = iter->second;
}
auto &bundle_locations = bundle_location_index_.GetBundleLocations(placement_group_id);
return std::unique_ptr<ScheduleContext>(new ScheduleContext(
std::move(node_to_bundles), bundle_locations, gcs_node_manager_));
}
@@ -436,16 +420,107 @@ std::unique_ptr<ScheduleContext> GcsPlacementGroupScheduler::GetScheduleContext(
/// Collect the bundle indexes recorded for `node_id`, grouped by placement group id,
/// and remove the node's recorded bundles afterwards.
/// Returns an empty map when nothing is recorded for the node.
absl::flat_hash_map<PlacementGroupID, std::vector<int64_t>>
GcsPlacementGroupScheduler::GetBundlesOnNode(const ClientID &node_id) {
absl::flat_hash_map<PlacementGroupID, std::vector<int64_t>> bundles_on_node;
const auto node_iter = node_to_leased_bundles_.find(node_id);
if (node_iter != node_to_leased_bundles_.end()) {
const auto &bundles = node_iter->second;
for (auto &bundle : bundles) {
bundles_on_node[bundle.first.first].push_back(bundle.second->BundleId().second);
const auto &maybe_bundle_locations =
bundle_location_index_.GetBundleLocationsOnNode(node_id);
if (maybe_bundle_locations.has_value()) {
const auto &bundle_locations = maybe_bundle_locations.value();
for (auto &bundle : *bundle_locations) {
// A bundle id is a (placement group id, bundle index) pair.
const auto &bundle_placement_group_id = bundle.first.first;
const auto &bundle_index = bundle.first.second;
bundles_on_node[bundle_placement_group_id].push_back(bundle_index);
}
node_to_leased_bundles_.erase(node_iter);
bundle_location_index_.Erase(node_id);
}
return bundles_on_node;
}
/// Register the bundle locations of a placement group in both lookup maps.
///
/// The locations are stored under `placement_group_id`, and every bundle is also
/// recorded under the id of the node it is placed on.
///
/// \param placement_group_id Id of the placement group that owns the bundles.
/// \param bundle_locations Bundle locations to associate with the placement group.
void BundleLocationIndex::AddBundleLocations(
    const PlacementGroupID &placement_group_id,
    std::shared_ptr<BundleLocations> bundle_locations) {
  placement_group_to_bundle_locations_.emplace(placement_group_id, bundle_locations);
  // Iterate by const reference: copying an entry would also copy its shared_ptr.
  for (const auto &iter : *bundle_locations) {
    const auto &node_id = iter.second.first;
    // One lookup per bundle: operator[] default-inserts an empty slot for a node that
    // has not been seen yet, which is then filled with a fresh map.
    auto &bundles_on_node = node_to_leased_bundles_[node_id];
    if (bundles_on_node == nullptr) {
      bundles_on_node = std::make_shared<BundleLocations>();
    }
    bundles_on_node->emplace(iter.first, iter.second);
  }
}
/// Remove a node from the index together with every bundle recorded on it.
///
/// \param node_id The id of the node to remove.
/// \return True if the node was in the index. False otherwise.
bool BundleLocationIndex::Erase(const ClientID &node_id) {
  const auto node_entry = node_to_leased_bundles_.find(node_id);
  if (node_entry == node_to_leased_bundles_.end()) {
    return false;
  }

  // Drop each bundle of this node from the locations of its placement group.
  for (const auto &entry : *node_entry->second) {
    const auto &bundle_id = entry.first;
    const auto placement_group_id = entry.second.second->PlacementGroupId();
    const auto group_entry =
        placement_group_to_bundle_locations_.find(placement_group_id);
    if (group_entry == placement_group_to_bundle_locations_.end()) {
      continue;
    }
    // Erasing by key is a no-op when the bundle is not recorded for the group.
    group_entry->second->erase(bundle_id);
  }

  node_to_leased_bundles_.erase(node_entry);
  return true;
}
/// Remove a placement group from the index together with all of its bundles.
///
/// \param placement_group_id Id of the placement group to remove.
/// \return True if the placement group was in the index. False otherwise.
bool BundleLocationIndex::Erase(const PlacementGroupID &placement_group_id) {
  auto group_entry = placement_group_to_bundle_locations_.find(placement_group_id);
  if (group_entry == placement_group_to_bundle_locations_.end()) {
    return false;
  }

  // The group's bundles are going away, so drop them from the per-node map as well.
  for (const auto &entry : *group_entry->second) {
    const auto &bundle_id = entry.first;
    const auto &node_id = entry.second.first;
    const auto node_entry = node_to_leased_bundles_.find(node_id);
    if (node_entry == node_to_leased_bundles_.end()) {
      // The node could have been removed from the index already.
      continue;
    }
    node_entry->second->erase(bundle_id);
  }

  placement_group_to_bundle_locations_.erase(group_entry);
  return true;
}
/// Look up the bundle locations recorded for a placement group.
///
/// \param placement_group_id Id of the placement group to look up.
/// \return The group's bundle locations, or an empty optional if none are recorded.
const absl::optional<std::shared_ptr<BundleLocations> const>
BundleLocationIndex::GetBundleLocations(const PlacementGroupID &placement_group_id) {
  const auto found = placement_group_to_bundle_locations_.find(placement_group_id);
  if (found != placement_group_to_bundle_locations_.end()) {
    return found->second;
  }
  return {};
}
/// Look up the bundle locations recorded for a node.
///
/// \param node_id Id of the node to look up.
/// \return The node's bundle locations, or an empty optional if the node is not in the
/// index.
const absl::optional<std::shared_ptr<BundleLocations> const>
BundleLocationIndex::GetBundleLocationsOnNode(const ClientID &node_id) {
  const auto found = node_to_leased_bundles_.find(node_id);
  if (found != node_to_leased_bundles_.end()) {
    return found->second;
  }
  return {};
}
/// Make sure every given node has an entry in the per-node map.
///
/// Nodes that are already in the index keep their recorded bundles; new nodes start
/// with an empty set of bundle locations.
///
/// \param nodes Map of nodes to register, keyed by node id.
void BundleLocationIndex::AddNodes(
    const absl::flat_hash_map<ClientID, std::shared_ptr<rpc::GcsNodeInfo>> &nodes) {
  for (const auto &node : nodes) {
    const auto &node_id = node.first;
    if (node_to_leased_bundles_.contains(node_id)) {
      continue;
    }
    node_to_leased_bundles_[node_id] = std::make_shared<BundleLocations>();
  }
}
} // namespace gcs
} // namespace ray
@@ -38,7 +38,7 @@ struct pair_hash {
}
};
using ScheduleMap = std::unordered_map<BundleID, ClientID, pair_hash>;
using BundleLocations = std::unordered_map<
using BundleLocations = absl::flat_hash_map<
BundleID, std::pair<ClientID, std::shared_ptr<BundleSpecification>>, pair_hash>;
class GcsPlacementGroup;
@@ -72,10 +72,11 @@ class GcsPlacementGroupSchedulerInterface {
virtual ~GcsPlacementGroupSchedulerInterface() {}
};
/// ScheduleContext provides the information needed for bundle scheduling decisions.
class ScheduleContext {
public:
ScheduleContext(std::shared_ptr<absl::flat_hash_map<ClientID, int64_t>> node_to_bundles,
const std::shared_ptr<BundleLocations> &bundle_locations,
const absl::optional<std::shared_ptr<BundleLocations>> bundle_locations,
const GcsNodeManager &node_manager)
: node_to_bundles_(std::move(node_to_bundles)),
bundle_locations_(bundle_locations),
@@ -84,7 +85,7 @@ class ScheduleContext {
// Key is node id, value is the number of bundles on the node.
const std::shared_ptr<absl::flat_hash_map<ClientID, int64_t>> node_to_bundles_;
// The locations of existing bundles for this placement group.
const std::shared_ptr<BundleLocations> &bundle_locations_;
const absl::optional<std::shared_ptr<BundleLocations>> bundle_locations_;
const GcsNodeManager &node_manager_;
};
@@ -130,6 +131,71 @@ class GcsStrictSpreadStrategy : public GcsScheduleStrategy {
const std::unique_ptr<ScheduleContext> &context) override;
};
/// A data structure that encapsulates information regarding bundle resource leasing
/// status.
/// NOTE: This is currently an empty placeholder; it declares no members yet.
class LeasingContext {
// TODO(sang): Implement in the next PR.
};
/// A data structure that helps fast bundle location lookup.
///
/// It keeps two maps over the registered bundles: one keyed by node id and one keyed by
/// placement group id.
class BundleLocationIndex {
 public:
  BundleLocationIndex() = default;
  ~BundleLocationIndex() = default;

  /// Add bundle locations to index.
  ///
  /// \param placement_group_id Id of the placement group the bundles belong to.
  /// \param bundle_locations Bundle locations that will be associated with the placement
  /// group id.
  void AddBundleLocations(const PlacementGroupID &placement_group_id,
                          std::shared_ptr<BundleLocations> bundle_locations);

  /// Erase bundle locations associated with a given node id.
  ///
  /// \param node_id The id of node.
  /// \return True if the node was in the index and has been erased. False otherwise.
  bool Erase(const ClientID &node_id);

  /// Erase bundle locations associated with a given placement group id.
  ///
  /// \param placement_group_id Placement group id
  /// \return True if the placement group was in the index and has been erased. False
  /// otherwise.
  bool Erase(const PlacementGroupID &placement_group_id);

  /// Get BundleLocation of placement group id.
  ///
  /// \param placement_group_id Placement group id of this bundle locations.
  /// \return Bundle locations that are associated with a given placement group id, or
  /// an empty optional if the placement group is not in the index.
  const absl::optional<std::shared_ptr<BundleLocations> const> GetBundleLocations(
      const PlacementGroupID &placement_group_id);

  /// Get BundleLocation of node id.
  ///
  /// \param node_id Node id of this bundle locations.
  /// \return Bundle locations that are associated with a given node id, or an empty
  /// optional if the node is not in the index.
  const absl::optional<std::shared_ptr<BundleLocations> const> GetBundleLocationsOnNode(
      const ClientID &node_id);

  /// Update the index to contain new node information. Should be used only when new node
  /// is added to the cluster.
  ///
  /// \param nodes Map of nodes, keyed by node id, to add to the index.
  void AddNodes(
      const absl::flat_hash_map<ClientID, std::shared_ptr<rpc::GcsNodeInfo>> &nodes);

 private:
  /// Map from node ID to the set of bundles. This is used to lookup bundles at each node
  /// when a node is dead.
  absl::flat_hash_map<ClientID, std::shared_ptr<BundleLocations>> node_to_leased_bundles_;

  /// A map from placement group id to bundle locations.
  /// It is used to destroy bundles for the placement group.
  /// NOTE: It is a reverse index of `node_to_leased_bundles_`.
  absl::flat_hash_map<PlacementGroupID, std::shared_ptr<BundleLocations>>
      placement_group_to_bundle_locations_;
};
/// GcsPlacementGroupScheduler is responsible for scheduling placement_groups registered
/// to GcsPlacementGroupManager. This class is not thread-safe.
class GcsPlacementGroupScheduler : public GcsPlacementGroupSchedulerInterface {
@@ -227,6 +293,9 @@ class GcsPlacementGroupScheduler : public GcsPlacementGroupSchedulerInterface {
/// Factory for producing new clients to request leases from remote nodes.
ReserveResourceClientFactoryFn lease_client_factory_;
/// A vector to store all the schedule strategies.
std::vector<std::shared_ptr<GcsScheduleStrategy>> scheduler_strategies_;
/// Map from node ID to the set of bundles for whom we are trying to acquire a lease
/// from that node. This is needed so that we can retry lease requests from the node
/// until we receive a reply or the node is removed.
@@ -234,25 +303,12 @@ class GcsPlacementGroupScheduler : public GcsPlacementGroupSchedulerInterface {
absl::flat_hash_map<ClientID, absl::flat_hash_set<BundleID>>
node_to_bundles_when_leasing_;
/// Map from node ID to the set of bundles. This is needed so that we can reschedule
/// bundles when a node is dead.
absl::flat_hash_map<ClientID,
absl::flat_hash_map<BundleID, std::shared_ptr<BundleSpecification>>>
node_to_leased_bundles_;
/// A vector to store all the schedule strategy.
std::vector<std::shared_ptr<GcsScheduleStrategy>> scheduler_strategies_;
/// Set of placement group that have lease requests in flight to nodes.
/// It is required to know if placement group has been removed or not.
absl::flat_hash_set<PlacementGroupID> placement_group_leasing_in_progress_;
/// A map from placement group id to bundle locations.
/// It is used to destroy bundles for the placement group. When we reschedule bundles,
/// we can get the location of other bundles from here.
/// NOTE: It is a reverse index of `node_to_leased_bundles`.
absl::flat_hash_map<PlacementGroupID, std::shared_ptr<BundleLocations>>
placement_group_to_bundle_locations_;
/// Index to lookup bundle locations of node or placement group.
BundleLocationIndex bundle_location_index_;
};
} // namespace gcs
@@ -549,6 +549,82 @@ TEST_F(GcsPlacementGroupSchedulerTest, TestStrictSpreadStrategyResourceCheck) {
WaitPendingDone(success_placement_groups_, 1);
}
// Exercises BundleLocationIndex directly: registers two placement groups whose bundles
// are spread over two nodes, checks lookups by placement group and by node, then checks
// that erasing by placement group and by node keeps both views consistent.
TEST_F(GcsPlacementGroupSchedulerTest, TestBundleLocationIndex) {
gcs::BundleLocationIndex bundle_location_index;
// Generate data.
const auto node1 = ClientID::FromRandom();
const auto node2 = ClientID::FromRandom();
// Placement group 1: bundle 0 goes to node1 and bundle 1 goes to node2.
// NOTE(review): relies on the mocked request containing at least two bundles.
rpc::CreatePlacementGroupRequest request_pg1 =
Mocker::GenCreatePlacementGroupRequest("pg1");
const auto pg1_id = PlacementGroupID::FromBinary(
request_pg1.placement_group_spec().placement_group_id());
const std::shared_ptr<BundleSpecification> bundle_node1_pg1 =
std::make_shared<BundleSpecification>(
BundleSpecification(request_pg1.placement_group_spec().bundles(0)));
const std::shared_ptr<BundleSpecification> bundle_node2_pg1 =
std::make_shared<BundleSpecification>(
BundleSpecification(request_pg1.placement_group_spec().bundles(1)));
std::shared_ptr<gcs::BundleLocations> bundle_locations_pg1 =
std::make_shared<gcs::BundleLocations>();
(*bundle_locations_pg1)
.emplace(bundle_node1_pg1->BundleId(), std::make_pair(node1, bundle_node1_pg1));
(*bundle_locations_pg1)
.emplace(bundle_node2_pg1->BundleId(), std::make_pair(node2, bundle_node2_pg1));
// Placement group 2: same layout, bundle 0 on node1 and bundle 1 on node2.
rpc::CreatePlacementGroupRequest request_pg2 =
Mocker::GenCreatePlacementGroupRequest("pg2");
const auto pg2_id = PlacementGroupID::FromBinary(
request_pg2.placement_group_spec().placement_group_id());
const std::shared_ptr<BundleSpecification> bundle_node1_pg2 =
std::make_shared<BundleSpecification>(
BundleSpecification(request_pg2.placement_group_spec().bundles(0)));
const std::shared_ptr<BundleSpecification> bundle_node2_pg2 =
std::make_shared<BundleSpecification>(
BundleSpecification(request_pg2.placement_group_spec().bundles(1)));
std::shared_ptr<gcs::BundleLocations> bundle_locations_pg2 =
std::make_shared<gcs::BundleLocations>();
(*bundle_locations_pg2)[bundle_node1_pg2->BundleId()] =
std::make_pair(node1, bundle_node1_pg2);
(*bundle_locations_pg2)[bundle_node2_pg2->BundleId()] =
std::make_pair(node2, bundle_node2_pg2);
// Test addition.
bundle_location_index.AddBundleLocations(pg1_id, bundle_locations_pg1);
bundle_location_index.AddBundleLocations(pg2_id, bundle_locations_pg2);
// Test that lookup by placement group returns exactly that group's bundles.
auto bundle_locations = bundle_location_index.GetBundleLocations(pg1_id).value();
ASSERT_TRUE((*bundle_locations).size() == 2);
ASSERT_TRUE((*bundle_locations).contains(bundle_node1_pg1->BundleId()));
ASSERT_TRUE((*bundle_locations).contains(bundle_node2_pg1->BundleId()));
// Make sure pg2's bundle is not in pg1's bundle locations.
ASSERT_FALSE((*bundle_locations).contains(bundle_node2_pg2->BundleId()));
auto bundle_locations2 = bundle_location_index.GetBundleLocations(pg2_id).value();
ASSERT_TRUE((*bundle_locations2).size() == 2);
ASSERT_TRUE((*bundle_locations2).contains(bundle_node1_pg2->BundleId()));
ASSERT_TRUE((*bundle_locations2).contains(bundle_node2_pg2->BundleId()));
// Test that lookup by node returns one bundle from each placement group.
auto bundle_on_node1 = bundle_location_index.GetBundleLocationsOnNode(node1).value();
ASSERT_TRUE((*bundle_on_node1).size() == 2);
ASSERT_TRUE((*bundle_on_node1).contains(bundle_node1_pg1->BundleId()));
ASSERT_TRUE((*bundle_on_node1).contains(bundle_node1_pg2->BundleId()));
auto bundle_on_node2 = bundle_location_index.GetBundleLocationsOnNode(node2).value();
ASSERT_TRUE((*bundle_on_node2).size() == 2);
ASSERT_TRUE((*bundle_on_node2).contains(bundle_node2_pg1->BundleId()));
ASSERT_TRUE((*bundle_on_node2).contains(bundle_node2_pg2->BundleId()));
// Test that erasing pg1 removes only pg1; pg2 keeps both of its bundles.
bundle_location_index.Erase(pg1_id);
ASSERT_FALSE(bundle_location_index.GetBundleLocations(pg1_id).has_value());
ASSERT_TRUE(bundle_location_index.GetBundleLocations(pg2_id).value()->size() == 2);
// Test that erasing node1 removes the node and pg2's bundle that was placed on it,
// leaving only pg2's bundle on node2.
bundle_location_index.Erase(node1);
ASSERT_FALSE(bundle_location_index.GetBundleLocationsOnNode(node1).has_value());
ASSERT_TRUE(bundle_location_index.GetBundleLocations(pg2_id).value()->size() == 1);
ASSERT_TRUE(bundle_location_index.GetBundleLocationsOnNode(node2).value()->size() == 1);
}
} // namespace ray
int main(int argc, char **argv) {