From fb89f9c2c856ab143661f4116e645b47587d6db4 Mon Sep 17 00:00:00 2001 From: "DK.Pino" Date: Fri, 5 Feb 2021 11:04:51 +0800 Subject: [PATCH] [Placement Group] Support named placement group (#13755) --- doc/source/placement-group.rst | 35 ++++++++ python/ray/includes/global_state_accessor.pxd | 2 + python/ray/includes/global_state_accessor.pxi | 10 +++ python/ray/state.py | 14 +++ python/ray/tests/test_placement_group.py | 86 ++++++++++++++++++- python/ray/util/__init__.py | 4 +- python/ray/util/placement_group.py | 26 +++++- src/ray/gcs/accessor.h | 10 ++- .../gcs/gcs_client/global_state_accessor.cc | 12 +++ .../gcs/gcs_client/global_state_accessor.h | 13 ++- .../gcs/gcs_client/service_based_accessor.cc | 20 +++++ .../gcs/gcs_client/service_based_accessor.h | 4 + .../gcs_server/gcs_placement_group_manager.cc | 66 ++++++++++++-- .../gcs_server/gcs_placement_group_manager.h | 13 ++- .../test/gcs_placement_group_manager_test.cc | 25 ++++++ src/ray/protobuf/gcs_service.proto | 14 +++ src/ray/rpc/gcs_server/gcs_rpc_client.h | 4 + src/ray/rpc/gcs_server/gcs_rpc_server.h | 5 ++ 18 files changed, 346 insertions(+), 17 deletions(-) diff --git a/doc/source/placement-group.rst b/doc/source/placement-group.rst index 1424b850c..7db38fd84 100644 --- a/doc/source/placement-group.rst +++ b/doc/source/placement-group.rst @@ -252,6 +252,41 @@ Note that you can anytime remove the placement group to clean up resources. ray.shutdown() +Named Placement Groups +---------------------- + +A placement group can be given a globally unique name. +This allows you to retrieve the placement group from any job in the Ray cluster. +This can be useful if you cannot directly pass the placement group handle to +the actor or task that needs it, or if you are trying to +access a placement group launched by another driver. +Note that the placement group will still be destroyed if it's lifetime isn't `detached`. +See :ref:`placement-group-lifetimes` for more details. + +.. tabs:: + .. group-tab:: Python + + .. code-block:: python + + # first_driver.py + # Create a placement group with a global name. + pg = placement_group([{"CPU": 2}, {"CPU": 2}], strategy="STRICT_SPREAD", lifetime="detached", name="global_name") + ray.get(pg.ready()) + + Then, we can retrieve the actor later somewhere. + + .. code-block:: python + + # second_driver.py + # Retrieve a placement group with a global name. + pg = ray.util.get_placement_group("global_name") + + .. group-tab:: Java + + The named placement group is not implemented for Java APIs yet. + +.. _placement-group-lifetimes: + Placement Group Lifetimes ------------------------- diff --git a/python/ray/includes/global_state_accessor.pxd b/python/ray/includes/global_state_accessor.pxd index 31418f10c..e27aa0547 100644 --- a/python/ray/includes/global_state_accessor.pxd +++ b/python/ray/includes/global_state_accessor.pxd @@ -32,4 +32,6 @@ cdef extern from "ray/gcs/gcs_client/global_state_accessor.h" nogil: c_bool AddWorkerInfo(const c_string &serialized_string) unique_ptr[c_string] GetPlacementGroupInfo( const CPlacementGroupID &placement_group_id) + unique_ptr[c_string] GetPlacementGroupByName( + const c_string &placement_group_name) c_vector[c_string] GetAllPlacementGroupInfo() diff --git a/python/ray/includes/global_state_accessor.pxi b/python/ray/includes/global_state_accessor.pxi index cbb1bac0a..5690d3bab 100644 --- a/python/ray/includes/global_state_accessor.pxi +++ b/python/ray/includes/global_state_accessor.pxi @@ -147,3 +147,13 @@ cdef class GlobalStateAccessor: if result: return c_string(result.get().data(), result.get().size()) return None + + def get_placement_group_by_name(self, placement_group_name): + cdef unique_ptr[c_string] result + cdef c_string cplacement_group_name = placement_group_name + with nogil: + result = self.inner.get().GetPlacementGroupByName( + cplacement_group_name) + if result: + return c_string(result.get().data(), result.get().size()) + return None diff --git a/python/ray/state.py b/python/ray/state.py index aa3488e20..7524ea124 100644 --- a/python/ray/state.py +++ b/python/ray/state.py @@ -388,6 +388,20 @@ class GlobalState: return dict(result) + def get_placement_group_by_name(self, placement_group_name): + self._check_connected() + + placement_group_info = ( + self.global_state_accessor.get_placement_group_by_name( + placement_group_name)) + if placement_group_info is None: + return None + else: + placement_group_table_data = \ + gcs_utils.PlacementGroupTableData.FromString( + placement_group_info) + return self._gen_placement_group_info(placement_group_table_data) + def placement_group_table(self, placement_group_id=None): self._check_connected() diff --git a/python/ray/tests/test_placement_group.py b/python/ray/tests/test_placement_group.py index 87273a499..024ff6c55 100644 --- a/python/ray/tests/test_placement_group.py +++ b/python/ray/tests/test_placement_group.py @@ -375,6 +375,7 @@ def test_remove_pending_placement_group(ray_start_cluster): # Create a placement group that cannot be scheduled now. placement_group = ray.util.placement_group([{"GPU": 2}, {"CPU": 2}]) ray.util.remove_placement_group(placement_group) + # TODO(sang): Add state check here. @ray.remote(num_cpus=4) def f(): @@ -797,10 +798,10 @@ def test_mini_integration(ray_start_cluster): pg_tasks = [] # total bundle gpu usage = bundles_per_pg * total_num_pg * per_bundle_gpus # Note this is half of total - for _ in range(total_num_pg): + for index in range(total_num_pg): pgs.append( ray.util.placement_group( - name="name", + name=f"name{index}", strategy="PACK", bundles=[{ "GPU": per_bundle_gpus @@ -1423,5 +1424,86 @@ ray.shutdown() assert assert_alive_num_actor(4) +def test_named_placement_group(ray_start_cluster): + cluster = ray_start_cluster + for _ in range(2): + cluster.add_node(num_cpus=3) + cluster.wait_for_nodes() + info = ray.init(address=cluster.address) + global_placement_group_name = "named_placement_group" + + # Create a detached placement group with name. + driver_code = f""" +import ray + +ray.init(address="{info["redis_address"]}") + +pg = ray.util.placement_group( + [{{"CPU": 1}} for _ in range(2)], + strategy="STRICT_SPREAD", + name="{global_placement_group_name}", + lifetime="detached") +ray.get(pg.ready()) + +ray.shutdown() + """ + + run_string_as_driver(driver_code) + + # Wait until the driver is reported as dead by GCS. + def is_job_done(): + jobs = ray.jobs() + for job in jobs: + if "StopTime" in job: + return True + return False + + wait_for_condition(is_job_done) + + @ray.remote(num_cpus=1) + class Actor: + def ping(self): + return "pong" + + # Get the named placement group and schedule a actor. + placement_group = ray.util.get_placement_group(global_placement_group_name) + assert placement_group is not None + assert placement_group.wait(5) + actor = Actor.options( + placement_group=placement_group, + placement_group_bundle_index=0).remote() + + ray.get(actor.ping.remote()) + + # Create another placement group and make sure its creation will failed. + same_name_pg = ray.util.placement_group( + [{ + "CPU": 1 + } for _ in range(2)], + strategy="STRICT_SPREAD", + name=global_placement_group_name) + assert not same_name_pg.wait(10) + + # Remove a named placement group and make sure the second creation + # will successful. + ray.util.remove_placement_group(placement_group) + same_name_pg = ray.util.placement_group( + [{ + "CPU": 1 + } for _ in range(2)], + strategy="STRICT_SPREAD", + name=global_placement_group_name) + assert same_name_pg.wait(10) + + # Get a named placement group with a name that doesn't exist + # and make sure it will raise ValueError correctly. + error_count = 0 + try: + ray.util.get_placement_group("inexistent_pg") + except ValueError: + error_count = error_count + 1 + assert error_count == 1 + + if __name__ == "__main__": sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/util/__init__.py b/python/ray/util/__init__.py index b682f15dc..d20bac2a3 100644 --- a/python/ray/util/__init__.py +++ b/python/ray/util/__init__.py @@ -4,7 +4,8 @@ from ray.util.check_serialize import inspect_serializability from ray.util.debug import log_once, disable_log_once_globally, \ enable_periodic_logging from ray.util.placement_group import (placement_group, placement_group_table, - remove_placement_group) + remove_placement_group, + get_placement_group) from ray.util import rpdb as pdb from ray.util.serialization import register_serializer, deregister_serializer @@ -19,6 +20,7 @@ __all__ = [ "pdb", "placement_group", "placement_group_table", + "get_placement_group", "remove_placement_group", "inspect_serializability", "collective", diff --git a/python/ray/util/placement_group.py b/python/ray/util/placement_group.py index 6d15f607f..c723f77d3 100644 --- a/python/ray/util/placement_group.py +++ b/python/ray/util/placement_group.py @@ -4,6 +4,7 @@ from typing import (List, Dict, Optional, Union) import ray from ray._raylet import PlacementGroupID, ObjectRef +from ray.utils import hex_to_binary bundle_reservation_check = None @@ -145,7 +146,7 @@ class PlacementGroup: def placement_group(bundles: List[Dict[str, float]], strategy: str = "PACK", - name: str = "unnamed_group", + name: str = "", lifetime=None) -> PlacementGroup: """Asynchronously creates a PlacementGroup. @@ -211,6 +212,29 @@ def remove_placement_group(placement_group: PlacementGroup): worker.core_worker.remove_placement_group(placement_group.id) +def get_placement_group(placement_group_name: str): + """Get a placement group object with a global name. + + Returns: + None if can't find a placement group with the given name. + The placement group object otherwise. + """ + if not placement_group_name: + raise ValueError( + "Please supply a non-empty value to get_placement_group") + worker = ray.worker.global_worker + worker.check_connected() + placement_group_info = ray.state.state.get_placement_group_by_name( + placement_group_name) + if placement_group_info is None: + raise ValueError( + f"Failed to look up actor with name: {placement_group_name}") + else: + return PlacementGroup( + PlacementGroupID( + hex_to_binary(placement_group_info["placement_group_id"]))) + + def placement_group_table(placement_group: PlacementGroup = None) -> list: """Get the state of the placement group from GCS. diff --git a/src/ray/gcs/accessor.h b/src/ray/gcs/accessor.h index e7ddb765b..034e91082 100644 --- a/src/ray/gcs/accessor.h +++ b/src/ray/gcs/accessor.h @@ -727,7 +727,7 @@ class PlacementGroupInfoAccessor { virtual Status AsyncCreatePlacementGroup( const PlacementGroupSpecification &placement_group_spec) = 0; - /// Get a placement group data from GCS asynchronously. + /// Get a placement group data from GCS asynchronously by id. /// /// \param placement_group_id The id of a placement group to obtain from GCS. /// \return Status. @@ -735,6 +735,14 @@ class PlacementGroupInfoAccessor { const PlacementGroupID &placement_group_id, const OptionalItemCallback &callback) = 0; + /// Get a placement group data from GCS asynchronously by name. + /// + /// \param placement_group_name The name of a placement group to obtain from GCS. + /// \return Status. + virtual Status AsyncGetByName( + const std::string &placement_group_name, + const OptionalItemCallback &callback) = 0; + /// Get all placement group info from GCS asynchronously. /// /// \param callback Callback that will be called after lookup finished. diff --git a/src/ray/gcs/gcs_client/global_state_accessor.cc b/src/ray/gcs/gcs_client/global_state_accessor.cc index 4e9a6fa18..669b16e2b 100644 --- a/src/ray/gcs/gcs_client/global_state_accessor.cc +++ b/src/ray/gcs/gcs_client/global_state_accessor.cc @@ -259,5 +259,17 @@ std::unique_ptr GlobalStateAccessor::GetPlacementGroupInfo( return placement_group_table_data; } +std::unique_ptr GlobalStateAccessor::GetPlacementGroupByName( + const std::string &placement_group_name) { + std::unique_ptr placement_group_table_data; + std::promise promise; + RAY_CHECK_OK(gcs_client_->PlacementGroups().AsyncGetByName( + placement_group_name, + TransformForOptionalItemCallback( + placement_group_table_data, promise))); + promise.get_future().get(); + return placement_group_table_data; +} + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/gcs_client/global_state_accessor.h b/src/ray/gcs/gcs_client/global_state_accessor.h index 0c5695780..c15963587 100644 --- a/src/ray/gcs/gcs_client/global_state_accessor.h +++ b/src/ray/gcs/gcs_client/global_state_accessor.h @@ -151,15 +151,24 @@ class GlobalStateAccessor { /// deserialized with protobuf function. std::vector GetAllPlacementGroupInfo(); - /// Get information of a placement group from GCS Service. + /// Get information of a placement group from GCS Service by ID. /// - /// \param placement_group The ID of placement group to look up in the GCS Service. + /// \param placement_group_id The ID of placement group to look up in the GCS Service. /// \return Placement group info. To support multi-language, we serialize each /// PlacementGroupTableData and return the serialized string. Where used, it needs to be /// deserialized with protobuf function. std::unique_ptr GetPlacementGroupInfo( const PlacementGroupID &placement_group_id); + /// Get information of a placement group from GCS Service by name. + /// + /// \param placement_group_name The name of placement group to look up in the GCS + /// Service. \return Placement group info. To support multi-language, we serialize each + /// PlacementGroupTableData and return the serialized string. Where used, it needs to be + /// deserialized with protobuf function. + std::unique_ptr GetPlacementGroupByName( + const std::string &placement_group_name); + private: /// MultiItem transformation helper in template style. /// diff --git a/src/ray/gcs/gcs_client/service_based_accessor.cc b/src/ray/gcs/gcs_client/service_based_accessor.cc index c4f550e50..015da29f3 100644 --- a/src/ray/gcs/gcs_client/service_based_accessor.cc +++ b/src/ray/gcs/gcs_client/service_based_accessor.cc @@ -1466,6 +1466,26 @@ Status ServiceBasedPlacementGroupInfoAccessor::AsyncGet( return Status::OK(); } +Status ServiceBasedPlacementGroupInfoAccessor::AsyncGetByName( + const std::string &name, + const OptionalItemCallback &callback) { + RAY_LOG(DEBUG) << "Getting named placement group info, name = " << name; + rpc::GetNamedPlacementGroupRequest request; + request.set_name(name); + client_impl_->GetGcsRpcClient().GetNamedPlacementGroup( + request, [name, callback](const Status &status, + const rpc::GetNamedPlacementGroupReply &reply) { + if (reply.has_placement_group_table_data()) { + callback(status, reply.placement_group_table_data()); + } else { + callback(status, boost::none); + } + RAY_LOG(DEBUG) << "Finished getting named placement group info, status = " + << status << ", name = " << name; + }); + return Status::OK(); +} + Status ServiceBasedPlacementGroupInfoAccessor::AsyncGetAll( const MultiItemCallback &callback) { RAY_LOG(DEBUG) << "Getting all placement group info."; diff --git a/src/ray/gcs/gcs_client/service_based_accessor.h b/src/ray/gcs/gcs_client/service_based_accessor.h index 79deb2a6c..c883e7b62 100644 --- a/src/ray/gcs/gcs_client/service_based_accessor.h +++ b/src/ray/gcs/gcs_client/service_based_accessor.h @@ -453,6 +453,10 @@ class ServiceBasedPlacementGroupInfoAccessor : public PlacementGroupInfoAccessor const PlacementGroupID &placement_group_id, const OptionalItemCallback &callback) override; + Status AsyncGetByName( + const std::string &name, + const OptionalItemCallback &callback) override; + Status AsyncGetAll( const MultiItemCallback &callback) override; diff --git a/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc b/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc index a856002b6..12260d867 100644 --- a/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc @@ -65,7 +65,8 @@ rpc::PlacementStrategy GcsPlacementGroup::GetStrategy() const { return placement_group_table_data_.strategy(); } -const rpc::PlacementGroupTableData &GcsPlacementGroup::GetPlacementGroupTableData() { +const rpc::PlacementGroupTableData &GcsPlacementGroup::GetPlacementGroupTableData() + const { return placement_group_table_data_; } @@ -147,6 +148,21 @@ void GcsPlacementGroupManager::RegisterPlacementGroup( } return; } + if (!placement_group->GetName().empty()) { + auto it = named_placement_groups_.find(placement_group->GetName()); + if (it == named_placement_groups_.end()) { + named_placement_groups_.emplace(placement_group->GetName(), + placement_group->GetPlacementGroupID()); + } else { + std::stringstream stream; + stream << "Failed to create placement group '" + << placement_group->GetPlacementGroupID() << "' because name '" + << placement_group->GetName() << "' already exists."; + RAY_LOG(WARNING) << stream.str(); + callback(Status::Invalid(stream.str())); + return; + } + } // Mark the callback as pending and invoke it after the placement_group has been // successfully created. @@ -178,11 +194,9 @@ void GcsPlacementGroupManager::RegisterPlacementGroup( PlacementGroupID GcsPlacementGroupManager::GetPlacementGroupIDByName( const std::string &name) { PlacementGroupID placement_group_id = PlacementGroupID::Nil(); - for (const auto &iter : registered_placement_groups_) { - if (iter.second->GetName() == name) { - placement_group_id = iter.first; - break; - } + auto it = named_placement_groups_.find(name); + if (it != named_placement_groups_.end()) { + placement_group_id = it->second; } return placement_group_id; } @@ -315,10 +329,19 @@ void GcsPlacementGroupManager::RemovePlacementGroup( on_placement_group_removed(Status::OK()); return; } - auto placement_group = placement_group_it->second; + auto placement_group = std::move(placement_group_it->second); registered_placement_groups_.erase(placement_group_it); placement_group_to_create_callbacks_.erase(placement_group_id); + // Remove placement group from `named_placement_groups_` if its name is not empty. + if (!placement_group->GetName().empty()) { + auto it = named_placement_groups_.find(placement_group->GetName()); + if (it != named_placement_groups_.end() && + it->second == placement_group->GetPlacementGroupID()) { + named_placement_groups_.erase(it); + } + } + // Destroy all bundles. gcs_placement_group_scheduler_->DestroyPlacementGroupBundleResourcesIfExists( placement_group_id); @@ -385,6 +408,30 @@ void GcsPlacementGroupManager::HandleGetPlacementGroup( ++counts_[CountType::GET_PLACEMENT_GROUP_REQUEST]; } +void GcsPlacementGroupManager::HandleGetNamedPlacementGroup( + const rpc::GetNamedPlacementGroupRequest &request, + rpc::GetNamedPlacementGroupReply *reply, rpc::SendReplyCallback send_reply_callback) { + const std::string &name = request.name(); + RAY_LOG(DEBUG) << "Getting named placement group info, name = " << name; + + // Try to look up the placement Group ID for the named placement group. + auto placement_group_id = GetPlacementGroupIDByName(name); + + if (placement_group_id.IsNil()) { + // The placement group was not found. + RAY_LOG(DEBUG) << "Placement Group with name '" << name << "' was not found"; + } else { + const auto &iter = registered_placement_groups_.find(placement_group_id); + RAY_CHECK(iter != registered_placement_groups_.end()); + reply->mutable_placement_group_table_data()->CopyFrom( + iter->second->GetPlacementGroupTableData()); + RAY_LOG(DEBUG) << "Finished get named placement group info, placement group id = " + << placement_group_id; + } + GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); + ++counts_[CountType::GET_NAMED_PLACEMENT_GROUP_REQUEST]; +} + void GcsPlacementGroupManager::HandleGetAllPlacementGroup( const rpc::GetAllPlacementGroupRequest &request, rpc::GetAllPlacementGroupReply *reply, rpc::SendReplyCallback send_reply_callback) { @@ -550,6 +597,10 @@ void GcsPlacementGroupManager::Initialize(const GcsInitData &gcs_init_data) { auto placement_group = std::make_shared(item.second); if (item.second.state() != rpc::PlacementGroupTableData::REMOVED) { registered_placement_groups_.emplace(item.first, placement_group); + if (!placement_group->GetName().empty()) { + named_placement_groups_.emplace(placement_group->GetName(), + placement_group->GetPlacementGroupID()); + } if (item.second.state() == rpc::PlacementGroupTableData::PENDING || item.second.state() == rpc::PlacementGroupTableData::RESCHEDULING) { @@ -587,6 +638,7 @@ std::string GcsPlacementGroupManager::DebugString() const { << ", WaitPlacementGroupUntilReady request count: " << counts_[CountType::WAIT_PLACEMENT_GROUP_UNTIL_READY_REQUEST] << ", Registered placement groups count: " << registered_placement_groups_.size() + << ", Named placement group count: " << named_placement_groups_.size() << ", Pending placement groups count: " << pending_placement_groups_.size() << "}"; return stream.str(); diff --git a/src/ray/gcs/gcs_server/gcs_placement_group_manager.h b/src/ray/gcs/gcs_server/gcs_placement_group_manager.h index 28ce82090..49a7634df 100644 --- a/src/ray/gcs/gcs_server/gcs_placement_group_manager.h +++ b/src/ray/gcs/gcs_server/gcs_placement_group_manager.h @@ -65,7 +65,7 @@ class GcsPlacementGroup { } /// Get the immutable PlacementGroupTableData of this placement group. - const rpc::PlacementGroupTableData &GetPlacementGroupTableData(); + const rpc::PlacementGroupTableData &GetPlacementGroupTableData() const; /// Get the mutable bundle of this placement group. rpc::Bundle *GetMutableBundle(int bundle_index); @@ -155,10 +155,13 @@ class GcsPlacementGroupManager : public rpc::PlacementGroupInfoHandler { rpc::GetPlacementGroupReply *reply, rpc::SendReplyCallback send_reply_callback) override; + void HandleGetNamedPlacementGroup(const rpc::GetNamedPlacementGroupRequest &request, + rpc::GetNamedPlacementGroupReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + void HandleGetAllPlacementGroup(const rpc::GetAllPlacementGroupRequest &request, rpc::GetAllPlacementGroupReply *reply, rpc::SendReplyCallback send_reply_callback) override; - void HandleWaitPlacementGroupUntilReady( const rpc::WaitPlacementGroupUntilReadyRequest &request, rpc::WaitPlacementGroupUntilReadyReply *reply, @@ -315,6 +318,9 @@ class GcsPlacementGroupManager : public rpc::PlacementGroupInfoHandler { /// Reference of GcsResourceManager. GcsResourceManager &gcs_resource_manager_; + /// Maps placement group names to their placement group ID for lookups by name. + absl::flat_hash_map named_placement_groups_; + // Debug info. enum CountType { CREATE_PLACEMENT_GROUP_REQUEST = 0, @@ -322,7 +328,8 @@ class GcsPlacementGroupManager : public rpc::PlacementGroupInfoHandler { GET_PLACEMENT_GROUP_REQUEST = 2, GET_ALL_PLACEMENT_GROUP_REQUEST = 3, WAIT_PLACEMENT_GROUP_UNTIL_READY_REQUEST = 4, - CountType_MAX = 5, + GET_NAMED_PLACEMENT_GROUP_REQUEST = 5, + CountType_MAX = 6, }; uint64_t counts_[CountType::CountType_MAX] = {0}; }; diff --git a/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_test.cc index fec3f2540..77784e44b 100644 --- a/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_test.cc @@ -174,6 +174,31 @@ TEST_F(GcsPlacementGroupManagerTest, TestGetPlacementGroupIDByName) { PlacementGroupID::FromBinary(request.placement_group_spec().placement_group_id())); } +TEST_F(GcsPlacementGroupManagerTest, TestRemoveNamedPlacementGroup) { + auto request = Mocker::GenCreatePlacementGroupRequest("test_name"); + std::atomic finished_placement_group_count(0); + gcs_placement_group_manager_->RegisterPlacementGroup( + std::make_shared(request), + [&finished_placement_group_count](const Status &status) { + ++finished_placement_group_count; + }); + + ASSERT_EQ(finished_placement_group_count, 0); + WaitForExpectedPgCount(1); + auto placement_group = mock_placement_group_scheduler_->placement_groups_.back(); + mock_placement_group_scheduler_->placement_groups_.pop_back(); + + gcs_placement_group_manager_->OnPlacementGroupCreationSuccess(placement_group); + WaitForExpectedCount(finished_placement_group_count, 1); + ASSERT_EQ(placement_group->GetState(), rpc::PlacementGroupTableData::CREATED); + // Remove the named placement group. + gcs_placement_group_manager_->RemovePlacementGroup( + placement_group->GetPlacementGroupID(), + [](const Status &status) { ASSERT_TRUE(status.ok()); }); + ASSERT_EQ(gcs_placement_group_manager_->GetPlacementGroupIDByName("test_name"), + PlacementGroupID::Nil()); +} + TEST_F(GcsPlacementGroupManagerTest, TestRescheduleWhenNodeAdd) { auto request = Mocker::GenCreatePlacementGroupRequest(); std::atomic finished_placement_group_count(0); diff --git a/src/ray/protobuf/gcs_service.proto b/src/ray/protobuf/gcs_service.proto index 8922ce6f4..ed5ca92e2 100644 --- a/src/ray/protobuf/gcs_service.proto +++ b/src/ray/protobuf/gcs_service.proto @@ -504,6 +504,17 @@ message WaitPlacementGroupUntilReadyReply { GcsStatus status = 1; } +message GetNamedPlacementGroupRequest { + // Name of the placement group. + string name = 1; +} + +message GetNamedPlacementGroupReply { + GcsStatus status = 1; + // Data of placement group. + PlacementGroupTableData placement_group_table_data = 2; +} + // Service for placement group info access. service PlacementGroupInfoGcsService { // Create placement group via gcs service. @@ -514,6 +525,9 @@ service PlacementGroupInfoGcsService { returns (RemovePlacementGroupReply); // Get placement group information via gcs service. rpc GetPlacementGroup(GetPlacementGroupRequest) returns (GetPlacementGroupReply); + // Get named placement group information via gcs service. + rpc GetNamedPlacementGroup(GetNamedPlacementGroupRequest) + returns (GetNamedPlacementGroupReply); // Get information of all placement group from GCS Service. rpc GetAllPlacementGroup(GetAllPlacementGroupRequest) returns (GetAllPlacementGroupReply); diff --git a/src/ray/rpc/gcs_server/gcs_rpc_client.h b/src/ray/rpc/gcs_server/gcs_rpc_client.h index fa77fddd2..bf9a72bed 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_client.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_client.h @@ -254,6 +254,10 @@ class GcsRpcClient { VOID_GCS_RPC_CLIENT_METHOD(PlacementGroupInfoGcsService, GetPlacementGroup, placement_group_info_grpc_client_, ) + /// Get placement group data from GCS Service by name. + VOID_GCS_RPC_CLIENT_METHOD(PlacementGroupInfoGcsService, GetNamedPlacementGroup, + placement_group_info_grpc_client_, ) + /// Get information of all placement group from GCS Service. VOID_GCS_RPC_CLIENT_METHOD(PlacementGroupInfoGcsService, GetAllPlacementGroup, placement_group_info_grpc_client_, ) diff --git a/src/ray/rpc/gcs_server/gcs_rpc_server.h b/src/ray/rpc/gcs_server/gcs_rpc_server.h index 0add85c0e..328aa5f73 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_server.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_server.h @@ -522,6 +522,10 @@ class PlacementGroupInfoGcsServiceHandler { const WaitPlacementGroupUntilReadyRequest &request, WaitPlacementGroupUntilReadyReply *reply, SendReplyCallback send_reply_callback) = 0; + + virtual void HandleGetNamedPlacementGroup(const GetNamedPlacementGroupRequest &request, + GetNamedPlacementGroupReply *reply, + SendReplyCallback send_reply_callback) = 0; }; /// The `GrpcService` for `PlacementGroupInfoGcsService`. @@ -543,6 +547,7 @@ class PlacementGroupInfoGrpcService : public GrpcService { PLACEMENT_GROUP_INFO_SERVICE_RPC_HANDLER(CreatePlacementGroup); PLACEMENT_GROUP_INFO_SERVICE_RPC_HANDLER(RemovePlacementGroup); PLACEMENT_GROUP_INFO_SERVICE_RPC_HANDLER(GetPlacementGroup); + PLACEMENT_GROUP_INFO_SERVICE_RPC_HANDLER(GetNamedPlacementGroup); PLACEMENT_GROUP_INFO_SERVICE_RPC_HANDLER(GetAllPlacementGroup); PLACEMENT_GROUP_INFO_SERVICE_RPC_HANDLER(WaitPlacementGroupUntilReady); }