From 9f804ade5fb477cdea7a5c9a7d48f05547ebbcd1 Mon Sep 17 00:00:00 2001 From: "DK.Pino" Date: Sat, 24 Oct 2020 02:46:48 +0800 Subject: [PATCH] [Placement Group]Add get all placement group api (#11460) * add get all interface for placement group * add get all interface for placement group * make it work * fix lint * fix lint * fix comment * add cpp test * fix python lint --- python/ray/includes/global_state_accessor.pxd | 1 + python/ray/includes/global_state_accessor.pxi | 6 +++ python/ray/state.py | 14 ++++++- python/ray/tests/test_placement_group.py | 23 +++++++++++ python/ray/util/placement_group.py | 7 ++-- src/ray/gcs/accessor.h | 7 ++++ .../gcs/gcs_client/global_state_accessor.cc | 10 +++++ .../gcs/gcs_client/global_state_accessor.h | 7 ++++ .../gcs/gcs_client/service_based_accessor.cc | 14 +++++++ .../gcs/gcs_client/service_based_accessor.h | 3 ++ .../test/global_state_accessor_test.cc | 3 ++ .../gcs_server/gcs_placement_group_manager.cc | 19 +++++++++ .../gcs_server/gcs_placement_group_manager.h | 4 ++ src/ray/gcs/redis_accessor.cc | 5 +++ src/ray/gcs/redis_accessor.h | 3 ++ src/ray/protobuf/gcs_service.proto | 41 ++++++++++++------- src/ray/rpc/gcs_server/gcs_rpc_client.h | 4 ++ src/ray/rpc/gcs_server/gcs_rpc_server.h | 5 +++ 18 files changed, 157 insertions(+), 19 deletions(-) diff --git a/python/ray/includes/global_state_accessor.pxd b/python/ray/includes/global_state_accessor.pxd index 058aa8514..e574554c7 100644 --- a/python/ray/includes/global_state_accessor.pxd +++ b/python/ray/includes/global_state_accessor.pxd @@ -31,3 +31,4 @@ cdef extern from "ray/gcs/gcs_client/global_state_accessor.h" nogil: c_bool AddWorkerInfo(const c_string &serialized_string) unique_ptr[c_string] GetPlacementGroupInfo( const CPlacementGroupID &placement_group_id) + c_vector[c_string] GetAllPlacementGroupInfo() diff --git a/python/ray/includes/global_state_accessor.pxi b/python/ray/includes/global_state_accessor.pxi index 216d3dd82..0d213f01f 100644 --- 
a/python/ray/includes/global_state_accessor.pxi +++ b/python/ray/includes/global_state_accessor.pxi @@ -122,6 +122,12 @@ cdef class GlobalStateAccessor: result = self.inner.get().AddWorkerInfo(cserialized_string) return result + def get_placement_group_table(self): + cdef c_vector[c_string] result + with nogil: + result = self.inner.get().GetAllPlacementGroupInfo() + return result + def get_placement_group_info(self, placement_group_id): cdef unique_ptr[c_string] result cdef CPlacementGroupID cplacement_group_id = ( diff --git a/python/ray/state.py b/python/ray/state.py index 64e749870..086d7567b 100644 --- a/python/ray/state.py +++ b/python/ray/state.py @@ -388,8 +388,18 @@ class GlobalState: FromString(placement_group_info)) return self._gen_placement_group_info(placement_group_info) else: - raise NotImplementedError( - "Get all placement group is not implemented yet.") + placement_group_table = self.global_state_accessor.\ + get_placement_group_table() + results = {} + for placement_group_info in placement_group_table: + placement_group_table_data = gcs_utils.\ + PlacementGroupTableData.FromString(placement_group_info) + placement_group_id = binary_to_hex( + placement_group_table_data.placement_group_id) + results[placement_group_id] = \ + self._gen_placement_group_info(placement_group_table_data) + + return results def _gen_placement_group_info(self, placement_group_info): # This should be imported here, otherwise, it will error doc build. diff --git a/python/ray/tests/test_placement_group.py b/python/ray/tests/test_placement_group.py index a5b75636a..e7779a43b 100644 --- a/python/ray/tests/test_placement_group.py +++ b/python/ray/tests/test_placement_group.py @@ -406,6 +406,7 @@ def test_placement_group_table(ray_start_cluster): # Now the placement group should be scheduled. 
cluster.add_node(num_cpus=5, num_gpus=1) + cluster.wait_for_nodes() actor_1 = Actor.options( placement_group=placement_group, @@ -415,6 +416,28 @@ def test_placement_group_table(ray_start_cluster): result = ray.util.placement_group_table(placement_group) assert result["state"] == "CREATED" + # Add two more placement groups for the placement group table test. + second_strategy = "SPREAD" + ray.util.placement_group( + name="second_placement_group", + strategy=second_strategy, + bundles=bundles) + ray.util.placement_group( + name="third_placement_group", + strategy=second_strategy, + bundles=bundles) + + placement_group_table = ray.util.placement_group_table() + assert len(placement_group_table) == 3 + + true_name_set = {"name", "second_placement_group", "third_placement_group"} + get_name_set = set() + + for _, placement_group_data in placement_group_table.items(): + get_name_set.add(placement_group_data["name"]) + + assert true_name_set == get_name_set + def test_cuda_visible_devices(ray_start_cluster): @ray.remote(num_gpus=1) diff --git a/python/ray/util/placement_group.py b/python/ray/util/placement_group.py index 6e3d49189..8a6b6c3d1 100644 --- a/python/ray/util/placement_group.py +++ b/python/ray/util/placement_group.py @@ -185,17 +185,18 @@ def remove_placement_group(placement_group: PlacementGroup): worker.core_worker.remove_placement_group(placement_group.id) -def placement_group_table(placement_group: PlacementGroup) -> dict: +def placement_group_table(placement_group: PlacementGroup = None) -> list: """Get the state of the placement group from GCS. Args: placement_group (PlacementGroup): placement group to see states. 
""" - assert placement_group is not None worker = ray.worker.global_worker worker.check_connected() - return ray.state.state.placement_group_table(placement_group.id) + placement_group_id = placement_group.id if (placement_group is + not None) else None + return ray.state.state.placement_group_table(placement_group_id) def get_current_placement_group() -> Optional[PlacementGroup]: diff --git a/src/ray/gcs/accessor.h b/src/ray/gcs/accessor.h index a360ff078..72453c092 100644 --- a/src/ray/gcs/accessor.h +++ b/src/ray/gcs/accessor.h @@ -769,6 +769,13 @@ class PlacementGroupInfoAccessor { const PlacementGroupID &placement_group_id, const OptionalItemCallback &callback) = 0; + /// Get all placement group info from GCS asynchronously. + /// + /// \param callback Callback that will be called after lookup finished. + /// \return Status + virtual Status AsyncGetAll( + const MultiItemCallback &callback) = 0; + /// Remove a placement group to GCS synchronously. /// /// \param placement_group_id The id for the placement group to remove. 
diff --git a/src/ray/gcs/gcs_client/global_state_accessor.cc b/src/ray/gcs/gcs_client/global_state_accessor.cc index b556f7479..96a408251 100644 --- a/src/ray/gcs/gcs_client/global_state_accessor.cc +++ b/src/ray/gcs/gcs_client/global_state_accessor.cc @@ -237,6 +237,16 @@ bool GlobalStateAccessor::AddWorkerInfo(const std::string &serialized_string) { return true; } +std::vector GlobalStateAccessor::GetAllPlacementGroupInfo() { + std::vector placement_group_table_data; + std::promise promise; + RAY_CHECK_OK(gcs_client_->PlacementGroups().AsyncGetAll( + TransformForMultiItemCallback( + placement_group_table_data, promise))); + promise.get_future().get(); + return placement_group_table_data; +} + std::unique_ptr GlobalStateAccessor::GetPlacementGroupInfo( const PlacementGroupID &placement_group_id) { std::unique_ptr placement_group_table_data; diff --git a/src/ray/gcs/gcs_client/global_state_accessor.h b/src/ray/gcs/gcs_client/global_state_accessor.h index 8ce72b156..f75b39e16 100644 --- a/src/ray/gcs/gcs_client/global_state_accessor.h +++ b/src/ray/gcs/gcs_client/global_state_accessor.h @@ -144,6 +144,13 @@ class GlobalStateAccessor { /// \return Is operation success. bool AddWorkerInfo(const std::string &serialized_string); + /// Get information of all placement group from GCS Service. + /// + /// \return All placement group info. To support multi-language, we serialize each + /// PlacementGroupTableData and return the serialized string. Where used, it needs to be + /// deserialized with protobuf function. + std::vector GetAllPlacementGroupInfo(); + /// Get information of a placement group from GCS Service. /// /// \param placement_group The ID of placement group to look up in the GCS Service. 
diff --git a/src/ray/gcs/gcs_client/service_based_accessor.cc b/src/ray/gcs/gcs_client/service_based_accessor.cc index a931acfa3..35cce0bcd 100644 --- a/src/ray/gcs/gcs_client/service_based_accessor.cc +++ b/src/ray/gcs/gcs_client/service_based_accessor.cc @@ -1545,5 +1545,19 @@ Status ServiceBasedPlacementGroupInfoAccessor::AsyncGet( return Status::OK(); } +Status ServiceBasedPlacementGroupInfoAccessor::AsyncGetAll( + const MultiItemCallback &callback) { + RAY_LOG(DEBUG) << "Getting all placement group info."; + rpc::GetAllPlacementGroupRequest request; + client_impl_->GetGcsRpcClient().GetAllPlacementGroup( + request, + [callback](const Status &status, const rpc::GetAllPlacementGroupReply &reply) { + callback(status, VectorFromProtobuf(reply.placement_group_table_data())); + RAY_LOG(DEBUG) << "Finished getting all placement group info, status = " + << status; + }); + return Status::OK(); +} + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/gcs_client/service_based_accessor.h b/src/ray/gcs/gcs_client/service_based_accessor.h index 0d53da1ae..56e740479 100644 --- a/src/ray/gcs/gcs_client/service_based_accessor.h +++ b/src/ray/gcs/gcs_client/service_based_accessor.h @@ -460,6 +460,9 @@ class ServiceBasedPlacementGroupInfoAccessor : public PlacementGroupInfoAccessor const PlacementGroupID &placement_group_id, const OptionalItemCallback &callback) override; + Status AsyncGetAll( + const MultiItemCallback &callback) override; + private: ServiceBasedGcsClient *client_impl_; }; diff --git a/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc b/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc index 2f7dc171a..1f310c881 100644 --- a/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc +++ b/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc @@ -250,6 +250,9 @@ TEST_F(GlobalStateAccessorTest, TestWorkerTable) { } // TODO(sang): Add tests after adding asyncAdd +TEST_F(GlobalStateAccessorTest, TestPlacementGroupTable) { + 
ASSERT_EQ(global_state_->GetAllPlacementGroupInfo().size(), 0); +} } // namespace ray diff --git a/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc b/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc index 997d0e50d..a593bd3e5 100644 --- a/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc @@ -328,6 +328,25 @@ void GcsPlacementGroupManager::HandleGetPlacementGroup( } } +void GcsPlacementGroupManager::HandleGetAllPlacementGroup( + const rpc::GetAllPlacementGroupRequest &request, + rpc::GetAllPlacementGroupReply *reply, rpc::SendReplyCallback send_reply_callback) { + RAY_LOG(DEBUG) << "Getting all placement group info."; + auto on_done = + [reply, send_reply_callback]( + const std::unordered_map &result) { + for (auto &data : result) { + reply->add_placement_group_table_data()->CopyFrom(data.second); + } + RAY_LOG(DEBUG) << "Finished getting all placement group info."; + GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); + }; + Status status = gcs_table_storage_->PlacementGroupTable().GetAll(on_done); + if (!status.ok()) { + on_done(std::unordered_map()); + } +} + void GcsPlacementGroupManager::RetryCreatingPlacementGroup() { execute_after(io_context_, [this] { SchedulePendingPlacementGroups(); }, RayConfig::instance().gcs_create_placement_group_retry_interval_ms()); diff --git a/src/ray/gcs/gcs_server/gcs_placement_group_manager.h b/src/ray/gcs/gcs_server/gcs_placement_group_manager.h index a6fb807ce..a1bbf2d36 100644 --- a/src/ray/gcs/gcs_server/gcs_placement_group_manager.h +++ b/src/ray/gcs/gcs_server/gcs_placement_group_manager.h @@ -129,6 +129,10 @@ class GcsPlacementGroupManager : public rpc::PlacementGroupInfoHandler { rpc::GetPlacementGroupReply *reply, rpc::SendReplyCallback send_reply_callback) override; + void HandleGetAllPlacementGroup(const rpc::GetAllPlacementGroupRequest &request, + rpc::GetAllPlacementGroupReply *reply, + rpc::SendReplyCallback send_reply_callback) 
override; + /// Register placement_group asynchronously. /// /// \param placement_group The placement group to be created. diff --git a/src/ray/gcs/redis_accessor.cc b/src/ray/gcs/redis_accessor.cc index 1df21cae8..f37be4d6a 100644 --- a/src/ray/gcs/redis_accessor.cc +++ b/src/ray/gcs/redis_accessor.cc @@ -781,6 +781,11 @@ Status RedisPlacementGroupInfoAccessor::AsyncGet( return Status::Invalid("Not implemented"); } +Status RedisPlacementGroupInfoAccessor::AsyncGetAll( + const MultiItemCallback &callback) { + return Status::Invalid("Not implemented"); +} + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/redis_accessor.h b/src/ray/gcs/redis_accessor.h index 491cf89ce..e1e0da9ba 100644 --- a/src/ray/gcs/redis_accessor.h +++ b/src/ray/gcs/redis_accessor.h @@ -490,6 +490,9 @@ class RedisPlacementGroupInfoAccessor : public PlacementGroupInfoAccessor { Status AsyncGet( const PlacementGroupID &placement_group_id, const OptionalItemCallback &callback) override; + + Status AsyncGetAll( + const MultiItemCallback &callback) override; }; } // namespace gcs diff --git a/src/ray/protobuf/gcs_service.proto b/src/ray/protobuf/gcs_service.proto index e43d90440..d5a40c2db 100644 --- a/src/ray/protobuf/gcs_service.proto +++ b/src/ray/protobuf/gcs_service.proto @@ -164,20 +164,6 @@ service ActorInfoGcsService { returns (GetActorCheckpointIDReply); } -// Service for placement group info access. -service PlacementGroupInfoGcsService { - // Create placement group via gcs service. - rpc CreatePlacementGroup(CreatePlacementGroupRequest) - returns (CreatePlacementGroupReply); - - // Remove placement group via gcs service. - rpc RemovePlacementGroup(RemovePlacementGroupRequest) - returns (RemovePlacementGroupReply); - - // Get placement group information via gcs service. - rpc GetPlacementGroup(GetPlacementGroupRequest) returns (GetPlacementGroupReply); -} - message RegisterNodeRequest { // Info of node. 
GcsNodeInfo node_info = 1; @@ -543,3 +529,30 @@ enum GcsServiceFailureType { RPC_DISCONNECT = 0; GCS_SERVER_RESTART = 1; } + +message GetAllPlacementGroupRequest { +} + +message GetAllPlacementGroupReply { + GcsStatus status = 1; + // Data of placement group + repeated PlacementGroupTableData placement_group_table_data = 2; +} + +// Service for placement group info access. +service PlacementGroupInfoGcsService { + // Create placement group via gcs service. + rpc CreatePlacementGroup(CreatePlacementGroupRequest) + returns (CreatePlacementGroupReply); + + // Remove placement group via gcs service. + rpc RemovePlacementGroup(RemovePlacementGroupRequest) + returns (RemovePlacementGroupReply); + + // Get placement group information via gcs service. + rpc GetPlacementGroup(GetPlacementGroupRequest) returns (GetPlacementGroupReply); + + // Get information of all placement group from GCS Service. + rpc GetAllPlacementGroup(GetAllPlacementGroupRequest) + returns (GetAllPlacementGroupReply); +} \ No newline at end of file diff --git a/src/ray/rpc/gcs_server/gcs_rpc_client.h b/src/ray/rpc/gcs_server/gcs_rpc_client.h index 79e964b89..f592c2799 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_client.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_client.h @@ -262,6 +262,10 @@ class GcsRpcClient { VOID_GCS_RPC_CLIENT_METHOD(PlacementGroupInfoGcsService, GetPlacementGroup, placement_group_info_grpc_client_, ) + /// Get information of all placement group from GCS Service. 
+ VOID_GCS_RPC_CLIENT_METHOD(PlacementGroupInfoGcsService, GetAllPlacementGroup, + placement_group_info_grpc_client_, ) + private: std::function gcs_service_failure_detected_; diff --git a/src/ray/rpc/gcs_server/gcs_rpc_server.h b/src/ray/rpc/gcs_server/gcs_rpc_server.h index 3a601a44a..ba9225a58 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_server.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_server.h @@ -470,6 +470,10 @@ class PlacementGroupInfoGcsServiceHandler { virtual void HandleGetPlacementGroup(const GetPlacementGroupRequest &request, GetPlacementGroupReply *reply, SendReplyCallback send_reply_callback) = 0; + + virtual void HandleGetAllPlacementGroup(const GetAllPlacementGroupRequest &request, + GetAllPlacementGroupReply *reply, + SendReplyCallback send_reply_callback) = 0; }; /// The `GrpcService` for `PlacementGroupInfoGcsService`. @@ -491,6 +495,7 @@ class PlacementGroupInfoGrpcService : public GrpcService { PLACEMENT_GROUP_INFO_SERVICE_RPC_HANDLER(CreatePlacementGroup); PLACEMENT_GROUP_INFO_SERVICE_RPC_HANDLER(RemovePlacementGroup); PLACEMENT_GROUP_INFO_SERVICE_RPC_HANDLER(GetPlacementGroup); + PLACEMENT_GROUP_INFO_SERVICE_RPC_HANDLER(GetAllPlacementGroup); } private: