diff --git a/python/ray/experimental/__init__.py b/python/ray/experimental/__init__.py index cd0878256..3647380b3 100644 --- a/python/ray/experimental/__init__.py +++ b/python/ray/experimental/__init__.py @@ -1,13 +1,8 @@ from .api import get, wait from .dynamic_resources import set_resource from .object_spilling import force_spill_objects, force_restore_spilled_objects -from .placement_group import ( - placement_group, ) +from .placement_group import (placement_group, placement_group_table) __all__ = [ - "get", - "wait", - "set_resource", - "force_spill_objects", - "force_restore_spilled_objects", - "placement_group", + "get", "wait", "set_resource", "force_spill_objects", + "force_restore_spilled_objects", "placement_group", "placement_group_table" ] diff --git a/python/ray/experimental/placement_group.py b/python/ray/experimental/placement_group.py index f1fec5eff..c4cecb458 100644 --- a/python/ray/experimental/placement_group.py +++ b/python/ray/experimental/placement_group.py @@ -30,3 +30,10 @@ def placement_group(bundles: List[Dict[str, float]], name, bundles, strategy) return placement_group_id + + +def placement_group_table(placement_group_id): + assert placement_group_id is not None + worker = ray.worker.global_worker + worker.check_connected() + return ray.state.state.placement_group_table(placement_group_id) diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py index 7147ca97e..1662489c0 100644 --- a/python/ray/gcs_utils.py +++ b/python/ray/gcs_utils.py @@ -1,48 +1,19 @@ from ray.core.generated.gcs_pb2 import ( - ActorCheckpointIdData, - ActorTableData, - GcsNodeInfo, - JobTableData, - JobConfig, - ErrorTableData, - ErrorType, - GcsEntry, - HeartbeatBatchTableData, - HeartbeatTableData, - ObjectTableData, - ProfileTableData, - TablePrefix, - TablePubsub, - TaskTableData, - ResourceMap, - ResourceTableData, - ObjectLocationInfo, - PubSubMessage, - WorkerTableData, -) + ActorCheckpointIdData, ActorTableData, GcsNodeInfo, JobTableData, + JobConfig, 
ErrorTableData, ErrorType, GcsEntry, HeartbeatBatchTableData, + HeartbeatTableData, ObjectTableData, ProfileTableData, TablePrefix, + TablePubsub, TaskTableData, ResourceMap, ResourceTableData, + ObjectLocationInfo, PubSubMessage, WorkerTableData, + PlacementGroupTableData) __all__ = [ - "ActorCheckpointIdData", - "ActorTableData", - "GcsNodeInfo", - "JobTableData", - "JobConfig", - "ErrorTableData", - "ErrorType", - "GcsEntry", - "HeartbeatBatchTableData", - "HeartbeatTableData", - "ObjectTableData", - "ProfileTableData", - "TablePrefix", - "TablePubsub", - "TaskTableData", - "ResourceMap", - "ResourceTableData", - "construct_error_message", - "ObjectLocationInfo", - "PubSubMessage", - "WorkerTableData", + "ActorCheckpointIdData", "ActorTableData", "GcsNodeInfo", "JobTableData", + "JobConfig", "ErrorTableData", "ErrorType", "GcsEntry", + "HeartbeatBatchTableData", "HeartbeatTableData", "ObjectTableData", + "ProfileTableData", "TablePrefix", "TablePubsub", "TaskTableData", + "ResourceMap", "ResourceTableData", "construct_error_message", + "ObjectLocationInfo", "PubSubMessage", "WorkerTableData", + "PlacementGroupTableData" ] FUNCTION_PREFIX = "RemoteFunction:" diff --git a/python/ray/includes/global_state_accessor.pxd b/python/ray/includes/global_state_accessor.pxd index bd6eeeb0f..dce948389 100644 --- a/python/ray/includes/global_state_accessor.pxd +++ b/python/ray/includes/global_state_accessor.pxd @@ -7,6 +7,7 @@ from ray.includes.unique_ids cimport ( CClientID, CObjectID, CWorkerID, + CPlacementGroupID, ) cdef extern from "ray/gcs/gcs_client/global_state_accessor.h" nogil: @@ -27,3 +28,5 @@ cdef extern from "ray/gcs/gcs_client/global_state_accessor.h" nogil: unique_ptr[c_string] GetWorkerInfo(const CWorkerID &worker_id) c_vector[c_string] GetAllWorkerInfo() c_bool AddWorkerInfo(const c_string &serialized_string) + unique_ptr[c_string] GetPlacementGroupInfo( + const CPlacementGroupID &placement_group_id) diff --git a/python/ray/includes/global_state_accessor.pxi 
b/python/ray/includes/global_state_accessor.pxi index ba954ffd3..bda49ac20 100644 --- a/python/ray/includes/global_state_accessor.pxi +++ b/python/ray/includes/global_state_accessor.pxi @@ -3,6 +3,7 @@ from ray.includes.unique_ids cimport ( CClientID, CObjectID, CWorkerID, + CPlacementGroupID ) from ray.includes.global_state_accessor cimport ( @@ -114,3 +115,14 @@ cdef class GlobalStateAccessor: with nogil: result = self.inner.get().AddWorkerInfo(cserialized_string) return result + + def get_placement_group_info(self, placement_group_id): + cdef unique_ptr[c_string] result + cdef CPlacementGroupID cplacement_group_id = ( + CPlacementGroupID.FromBinary(placement_group_id.binary())) + with nogil: + result = self.inner.get().GetPlacementGroupInfo( + cplacement_group_id) + if result: + return c_string(result.get().data(), result.get().size()) + return None diff --git a/python/ray/state.py b/python/ray/state.py index 5ddf8af81..8c1e7b821 100644 --- a/python/ray/state.py +++ b/python/ray/state.py @@ -377,6 +377,57 @@ class GlobalState: return dict(result) + # SANG-TODO Add functions. + def placement_group_table(self, placement_group_id=None): + self._check_connected() + + if placement_group_id is not None: + placement_group_id = ray.PlacementGroupID( + hex_to_binary(placement_group_id.hex())) + placement_group_info = ( + self.global_state_accessor.get_placement_group_info( + placement_group_id)) + if placement_group_info is None: + return {} + else: + placement_group_info = (gcs_utils.PlacementGroupTableData. + FromString(placement_group_info)) + return self._gen_placement_group_info(placement_group_info) + else: + raise NotImplementedError( + "Get all placement group is not implemented yet.") + + def _gen_placement_group_info(self, placement_group_info): + # This must be imported here; otherwise it breaks the doc build.
+ from ray.core.generated.common_pb2 import PlacementStrategy + + def get_state(state): + if state == ray.gcs_utils.PlacementGroupTableData.PENDING: + return "PENDING" + elif state == ray.gcs_utils.PlacementGroupTableData.ALIVE: + return "ALIVE" + else: + return "DEAD" + + def get_strategy(strategy): + if strategy == PlacementStrategy.PACK: + return "PACK" + else: + return "SPREAD" + + assert placement_group_info is not None + return { + "placement_group_id": binary_to_hex( + placement_group_info.placement_group_id), + "name": placement_group_info.name, + "bundles": { + bundle.bundle_id.bundle_index: bundle.unit_resources + for bundle in placement_group_info.bundles + }, + "strategy": get_strategy(placement_group_info.strategy), + "state": get_state(placement_group_info.state), + } + def _seconds_to_microseconds(self, time_in_seconds): """A helper function for converting seconds to microseconds.""" time_in_microseconds = 10**6 * time_in_seconds diff --git a/python/ray/tests/test_placement_group.py b/python/ray/tests/test_placement_group.py index 656523c51..50234da94 100644 --- a/python/ray/tests/test_placement_group.py +++ b/python/ray/tests/test_placement_group.py @@ -219,6 +219,47 @@ def test_placement_group_hang(ray_start_cluster): assert "CPU_group_" in list(resources.keys())[0], resources +def test_placement_group_table(ray_start_cluster): + @ray.remote(num_cpus=2) + class Actor(object): + def __init__(self): + self.n = 0 + + def value(self): + return self.n + + cluster = ray_start_cluster + num_nodes = 2 + for _ in range(num_nodes): + cluster.add_node(num_cpus=4) + ray.init(address=cluster.address) + + # Originally placement group creation should be pending because + # there are no resources. 
+ name = "name" + strategy = "PACK" + bundles = [{"CPU": 2, "GPU": 1}, {"CPU": 2}] + placement_group_id = ray.experimental.placement_group( + name=name, strategy=strategy, bundles=bundles) + result = ray.experimental.placement_group_table(placement_group_id) + assert result["name"] == name + assert result["strategy"] == strategy + for i in range(len(bundles)): + assert bundles[i] == result["bundles"][i] + assert result["state"] == "PENDING" + + # Now the placement group should be scheduled. + cluster.add_node(num_cpus=5, num_gpus=1) + cluster.wait_for_nodes() + actor_1 = Actor.options( + placement_group_id=placement_group_id, + placement_group_bundle_index=0).remote() + ray.get(actor_1.value.remote()) + + result = ray.experimental.placement_group_table(placement_group_id) + assert result["state"] == "ALIVE" + + def test_cuda_visible_devices(ray_start_cluster): @ray.remote(num_gpus=1) def f(): diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 6300cbda9..60dfbf94e 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -1396,7 +1396,6 @@ Status CoreWorker::KillActor(const ActorID &actor_id, bool force_kill, bool no_r return Status::OK(); } -// SANG-TODO Status CoreWorker::KillActorLocalMode(const ActorID &actor_id) { // KillActor doesn't do anything in local mode. We only remove named actor entry if // exists. diff --git a/src/ray/gcs/accessor.h b/src/ray/gcs/accessor.h index daa8a8fa0..b401b1b7f 100644 --- a/src/ray/gcs/accessor.h +++ b/src/ray/gcs/accessor.h @@ -737,10 +737,19 @@ class PlacementGroupInfoAccessor { /// /// \param placement_group_spec The specification for the placement group creation task. /// \param callback Callback that will be called after the placement group info is - /// written to GCS. \return Status + /// written to GCS. + /// \return Status. 
virtual Status AsyncCreatePlacementGroup( const PlacementGroupSpecification &placement_group_spec) = 0; + /// Get a placement group data from GCS asynchronously. + /// + /// \param placement_group_id The id of a placement group to obtain from GCS. + /// \return Status. + virtual Status AsyncGet( + const PlacementGroupID &placement_group_id, + const OptionalItemCallback &callback) = 0; + protected: PlacementGroupInfoAccessor() = default; }; diff --git a/src/ray/gcs/gcs_client/global_state_accessor.cc b/src/ray/gcs/gcs_client/global_state_accessor.cc index dc4549cb3..260140a3e 100644 --- a/src/ray/gcs/gcs_client/global_state_accessor.cc +++ b/src/ray/gcs/gcs_client/global_state_accessor.cc @@ -232,5 +232,16 @@ bool GlobalStateAccessor::AddWorkerInfo(const std::string &serialized_string) { return true; } +std::unique_ptr GlobalStateAccessor::GetPlacementGroupInfo( + const PlacementGroupID &placement_group_id) { + std::unique_ptr placement_group_table_data; + std::promise promise; + RAY_CHECK_OK(gcs_client_->PlacementGroups().AsyncGet( + placement_group_id, TransformForOptionalItemCallback( + placement_group_table_data, promise))); + promise.get_future().get(); + return placement_group_table_data; +} + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/gcs_client/global_state_accessor.h b/src/ray/gcs/gcs_client/global_state_accessor.h index f3d0597f5..002628ccd 100644 --- a/src/ray/gcs/gcs_client/global_state_accessor.h +++ b/src/ray/gcs/gcs_client/global_state_accessor.h @@ -137,6 +137,15 @@ class GlobalStateAccessor { /// \return Is operation success. bool AddWorkerInfo(const std::string &serialized_string); + /// Get information of a placement group from GCS Service. + /// + /// \param placement_group_id The ID of placement group to look up in the GCS Service. + /// \return Placement group info. To support multi-language, we serialize each + /// PlacementGroupTableData and return the serialized string.
Where used, it needs to be + /// deserialized with protobuf function. + std::unique_ptr GetPlacementGroupInfo( + const PlacementGroupID &placement_group_id); + private: /// MultiItem transformation helper in template style. /// diff --git a/src/ray/gcs/gcs_client/service_based_accessor.cc b/src/ray/gcs/gcs_client/service_based_accessor.cc index 273a0446d..914a6724f 100644 --- a/src/ray/gcs/gcs_client/service_based_accessor.cc +++ b/src/ray/gcs/gcs_client/service_based_accessor.cc @@ -1462,5 +1462,26 @@ Status ServiceBasedPlacementGroupInfoAccessor::AsyncCreatePlacementGroup( return Status::OK(); } +Status ServiceBasedPlacementGroupInfoAccessor::AsyncGet( + const PlacementGroupID &placement_group_id, + const OptionalItemCallback &callback) { + RAY_LOG(DEBUG) << "Getting placement group info, placement group id = " + << placement_group_id; + rpc::GetPlacementGroupRequest request; + request.set_placement_group_id(placement_group_id.Binary()); + client_impl_->GetGcsRpcClient().GetPlacementGroup( + request, [placement_group_id, callback](const Status &status, + const rpc::GetPlacementGroupReply &reply) { + if (reply.has_placement_group_table_data()) { + callback(status, reply.placement_group_table_data()); + } else { + callback(status, boost::none); + } + RAY_LOG(DEBUG) << "Finished getting placement group info, placement group id = " + << placement_group_id; + }); + return Status::OK(); +} + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/gcs_client/service_based_accessor.h b/src/ray/gcs/gcs_client/service_based_accessor.h index 18b85a1b6..4b99de3e7 100644 --- a/src/ray/gcs/gcs_client/service_based_accessor.h +++ b/src/ray/gcs/gcs_client/service_based_accessor.h @@ -446,6 +446,10 @@ class ServiceBasedPlacementGroupInfoAccessor : public PlacementGroupInfoAccessor Status AsyncCreatePlacementGroup( const PlacementGroupSpecification &placement_group_spec) override; + Status AsyncGet( + const PlacementGroupID &placement_group_id, + const OptionalItemCallback 
&callback) override; + private: ServiceBasedGcsClient *client_impl_; }; diff --git a/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc b/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc index a25043811..0b9d832f0 100644 --- a/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc +++ b/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc @@ -269,6 +269,8 @@ TEST_F(GlobalStateAccessorTest, TestWorkerTable) { ASSERT_EQ(global_state_->GetAllWorkerInfo().size(), 2); } +// TODO(sang): Add tests after adding asyncAdd + } // namespace ray int main(int argc, char **argv) { diff --git a/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc b/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc index c49044686..f49d98377 100644 --- a/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc +++ b/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc @@ -1173,6 +1173,8 @@ TEST_F(ServiceBasedGcsClientTest, TestMultiThreadSubAndUnsub) { } } +// TODO(sang): Add tests after adding asyncAdd + } // namespace ray int main(int argc, char **argv) { diff --git a/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc b/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc index 8a139ecb5..e56e66e42 100644 --- a/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc @@ -174,6 +174,32 @@ void GcsPlacementGroupManager::HandleCreatePlacementGroup( })); } +void GcsPlacementGroupManager::HandleGetPlacementGroup( + const rpc::GetPlacementGroupRequest &request, rpc::GetPlacementGroupReply *reply, + rpc::SendReplyCallback send_reply_callback) { + PlacementGroupID placement_group_id = + PlacementGroupID::FromBinary(request.placement_group_id()); + RAY_LOG(DEBUG) << "Getting placement group info, placement group id = " + << placement_group_id; + + auto on_done = [placement_group_id, reply, send_reply_callback]( + const Status &status, + const boost::optional &result) { + if (result) { + 
reply->mutable_placement_group_table_data()->CopyFrom(*result); + } + RAY_LOG(DEBUG) << "Finished getting placement group info, placement group id = " + << placement_group_id; + GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); + }; + + Status status = + gcs_table_storage_->PlacementGroupTable().Get(placement_group_id, on_done); + if (!status.ok()) { + on_done(status, boost::none); + } +} + void GcsPlacementGroupManager::RetryCreatingPlacementGroup() { execute_after(io_context_, [this] { SchedulePendingPlacementGroups(); }, RayConfig::instance().gcs_create_placement_group_retry_interval_ms()); diff --git a/src/ray/gcs/gcs_server/gcs_placement_group_manager.h b/src/ray/gcs/gcs_server/gcs_placement_group_manager.h index f279134ca..93e03f775 100644 --- a/src/ray/gcs/gcs_server/gcs_placement_group_manager.h +++ b/src/ray/gcs/gcs_server/gcs_placement_group_manager.h @@ -107,6 +107,10 @@ class GcsPlacementGroupManager : public rpc::PlacementGroupInfoHandler { rpc::CreatePlacementGroupReply *reply, rpc::SendReplyCallback send_reply_callback) override; + void HandleGetPlacementGroup(const rpc::GetPlacementGroupRequest &request, + rpc::GetPlacementGroupReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + /// Register placement_group asynchronously. /// /// \param request Contains the meta info to create the placement_group. 
diff --git a/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc b/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc index 25a46fea8..78ab650f0 100644 --- a/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc @@ -771,6 +771,8 @@ TEST_F(GcsServerTest, TestWorkerInfo) { worker_data->worker_address().worker_id()); } +// TODO(sang): Add tests after adding asyncAdd + } // namespace ray int main(int argc, char **argv) { diff --git a/src/ray/gcs/redis_accessor.cc b/src/ray/gcs/redis_accessor.cc index 342d8e0bf..08ce4c3a5 100644 --- a/src/ray/gcs/redis_accessor.cc +++ b/src/ray/gcs/redis_accessor.cc @@ -833,6 +833,12 @@ Status RedisPlacementGroupInfoAccessor::AsyncCreatePlacementGroup( return Status::Invalid("Not implemented"); } +Status RedisPlacementGroupInfoAccessor::AsyncGet( + const PlacementGroupID &placement_group_id, + const OptionalItemCallback &callback) { + return Status::Invalid("Not implemented"); +} + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/redis_accessor.h b/src/ray/gcs/redis_accessor.h index 0dc787884..da2229dfb 100644 --- a/src/ray/gcs/redis_accessor.h +++ b/src/ray/gcs/redis_accessor.h @@ -478,6 +478,10 @@ class RedisPlacementGroupInfoAccessor : public PlacementGroupInfoAccessor { Status AsyncCreatePlacementGroup( const PlacementGroupSpecification &placement_group_spec) override; + + Status AsyncGet( + const PlacementGroupID &placement_group_id, + const OptionalItemCallback &callback) override; }; } // namespace gcs diff --git a/src/ray/protobuf/gcs_service.proto b/src/ray/protobuf/gcs_service.proto index 1797126ac..24796b495 100644 --- a/src/ray/protobuf/gcs_service.proto +++ b/src/ray/protobuf/gcs_service.proto @@ -166,9 +166,11 @@ service ActorInfoGcsService { // Service for placement group info access. service PlacementGroupInfoGcsService { - // Create placement group via gcs service + // Create placement group via gcs service. 
rpc CreatePlacementGroup(CreatePlacementGroupRequest) returns (CreatePlacementGroupReply); + // Get placement group information via gcs service. + rpc GetPlacementGroup(GetPlacementGroupRequest) returns (GetPlacementGroupReply); } message RegisterNodeRequest { @@ -497,6 +499,15 @@ message CreatePlacementGroupReply { GcsStatus status = 1; } +message GetPlacementGroupRequest { + bytes placement_group_id = 1; +} + +message GetPlacementGroupReply { + GcsStatus status = 1; + PlacementGroupTableData placement_group_table_data = 2; +} + enum GcsServiceFailureType { RPC_DISCONNECT = 0; GCS_SERVER_RESTART = 1; diff --git a/src/ray/rpc/gcs_server/gcs_rpc_client.h b/src/ray/rpc/gcs_server/gcs_rpc_client.h index 35bbecb73..6933101cc 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_client.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_client.h @@ -251,6 +251,10 @@ class GcsRpcClient { VOID_GCS_RPC_CLIENT_METHOD(PlacementGroupInfoGcsService, CreatePlacementGroup, placement_group_info_grpc_client_, ) + /// Get placement group via GCS Service. + VOID_GCS_RPC_CLIENT_METHOD(PlacementGroupInfoGcsService, GetPlacementGroup, + placement_group_info_grpc_client_, ) + private: std::function gcs_service_failure_detected_; diff --git a/src/ray/rpc/gcs_server/gcs_rpc_server.h b/src/ray/rpc/gcs_server/gcs_rpc_server.h index c1981b436..628e74b39 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_server.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_server.h @@ -456,6 +456,10 @@ class PlacementGroupInfoGcsServiceHandler { virtual void HandleCreatePlacementGroup(const CreatePlacementGroupRequest &request, CreatePlacementGroupReply *reply, SendReplyCallback send_reply_callback) = 0; + + virtual void HandleGetPlacementGroup(const GetPlacementGroupRequest &request, + GetPlacementGroupReply *reply, + SendReplyCallback send_reply_callback) = 0; }; /// The `GrpcService` for `PlacementGroupInfoGcsService`. 
@@ -475,6 +479,7 @@ class PlacementGroupInfoGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories) override { PLACEMENT_GROUP_INFO_SERVICE_RPC_HANDLER(CreatePlacementGroup); + PLACEMENT_GROUP_INFO_SERVICE_RPC_HANDLER(GetPlacementGroup); } private: