mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 18:06:25 +08:00
[Placement Group] Support named placement group (#13755)
This commit is contained in:
@@ -252,6 +252,41 @@ Note that you can anytime remove the placement group to clean up resources.
|
||||
|
||||
ray.shutdown()
|
||||
|
||||
Named Placement Groups
|
||||
----------------------
|
||||
|
||||
A placement group can be given a globally unique name.
|
||||
This allows you to retrieve the placement group from any job in the Ray cluster.
|
||||
This can be useful if you cannot directly pass the placement group handle to
|
||||
the actor or task that needs it, or if you are trying to
|
||||
access a placement group launched by another driver.
|
||||
Note that the placement group will still be destroyed if its lifetime isn't `detached`.
|
||||
See :ref:`placement-group-lifetimes` for more details.
|
||||
|
||||
.. tabs::
|
||||
.. group-tab:: Python
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# first_driver.py
|
||||
# Create a placement group with a global name.
|
||||
pg = placement_group([{"CPU": 2}, {"CPU": 2}], strategy="STRICT_SPREAD", lifetime="detached", name="global_name")
|
||||
ray.get(pg.ready())
|
||||
|
||||
Then, we can retrieve the placement group later from somewhere else.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# second_driver.py
|
||||
# Retrieve a placement group with a global name.
|
||||
pg = ray.util.get_placement_group("global_name")
|
||||
|
||||
.. group-tab:: Java
|
||||
|
||||
The named placement group is not implemented for Java APIs yet.
|
||||
|
||||
.. _placement-group-lifetimes:
|
||||
|
||||
Placement Group Lifetimes
|
||||
-------------------------
|
||||
|
||||
|
||||
@@ -32,4 +32,6 @@ cdef extern from "ray/gcs/gcs_client/global_state_accessor.h" nogil:
|
||||
c_bool AddWorkerInfo(const c_string &serialized_string)
|
||||
unique_ptr[c_string] GetPlacementGroupInfo(
|
||||
const CPlacementGroupID &placement_group_id)
|
||||
unique_ptr[c_string] GetPlacementGroupByName(
|
||||
const c_string &placement_group_name)
|
||||
c_vector[c_string] GetAllPlacementGroupInfo()
|
||||
|
||||
@@ -147,3 +147,13 @@ cdef class GlobalStateAccessor:
|
||||
if result:
|
||||
return c_string(result.get().data(), result.get().size())
|
||||
return None
|
||||
|
||||
def get_placement_group_by_name(self, placement_group_name):
    """Look up a placement group in the GCS by its name.

    Args:
        placement_group_name: Name of the placement group; converted to a
            C++ string before the lookup.

    Returns:
        The string produced by the C++ ``GetPlacementGroupByName`` call,
        or None when that call yields a null pointer.
    """
    cdef unique_ptr[c_string] lookup
    cdef c_string c_name = placement_group_name
    # The accessor call is made with the GIL released.
    with nogil:
        lookup = self.inner.get().GetPlacementGroupByName(c_name)
    if lookup:
        return c_string(lookup.get().data(), lookup.get().size())
    return None
|
||||
|
||||
@@ -388,6 +388,20 @@ class GlobalState:
|
||||
|
||||
return dict(result)
|
||||
|
||||
def get_placement_group_by_name(self, placement_group_name):
    """Return the info of the placement group registered under a name.

    Args:
        placement_group_name: Name of the placement group to look up.

    Returns:
        None if the global state accessor finds nothing for the name;
        otherwise the result of ``_gen_placement_group_info`` applied to
        the parsed ``PlacementGroupTableData``.
    """
    self._check_connected()

    serialized = self.global_state_accessor.get_placement_group_by_name(
        placement_group_name)
    # Guard clause: nothing is registered under this name.
    if serialized is None:
        return None

    table_data = gcs_utils.PlacementGroupTableData.FromString(serialized)
    return self._gen_placement_group_info(table_data)
|
||||
|
||||
def placement_group_table(self, placement_group_id=None):
|
||||
self._check_connected()
|
||||
|
||||
|
||||
@@ -375,6 +375,7 @@ def test_remove_pending_placement_group(ray_start_cluster):
|
||||
# Create a placement group that cannot be scheduled now.
|
||||
placement_group = ray.util.placement_group([{"GPU": 2}, {"CPU": 2}])
|
||||
ray.util.remove_placement_group(placement_group)
|
||||
|
||||
# TODO(sang): Add state check here.
|
||||
@ray.remote(num_cpus=4)
|
||||
def f():
|
||||
@@ -797,10 +798,10 @@ def test_mini_integration(ray_start_cluster):
|
||||
pg_tasks = []
|
||||
# total bundle gpu usage = bundles_per_pg * total_num_pg * per_bundle_gpus
|
||||
# Note this is half of total
|
||||
for _ in range(total_num_pg):
|
||||
for index in range(total_num_pg):
|
||||
pgs.append(
|
||||
ray.util.placement_group(
|
||||
name="name",
|
||||
name=f"name{index}",
|
||||
strategy="PACK",
|
||||
bundles=[{
|
||||
"GPU": per_bundle_gpus
|
||||
@@ -1423,5 +1424,86 @@ ray.shutdown()
|
||||
assert assert_alive_num_actor(4)
|
||||
|
||||
|
||||
def test_named_placement_group(ray_start_cluster):
    """Check creation, lookup, name collision and removal of a named group.

    A separate driver creates a detached, named placement group. This
    driver then retrieves it by name, schedules an actor in it, verifies
    that a second group with the same name does not become ready until the
    first is removed, and verifies that looking up an unknown name raises
    ValueError.
    """
    cluster = ray_start_cluster
    for _ in range(2):
        cluster.add_node(num_cpus=3)
    cluster.wait_for_nodes()
    info = ray.init(address=cluster.address)
    global_placement_group_name = "named_placement_group"

    # Create a detached placement group with a name from another driver.
    driver_code = f"""
import ray

ray.init(address="{info["redis_address"]}")

pg = ray.util.placement_group(
    [{{"CPU": 1}} for _ in range(2)],
    strategy="STRICT_SPREAD",
    name="{global_placement_group_name}",
    lifetime="detached")
ray.get(pg.ready())

ray.shutdown()
    """

    run_string_as_driver(driver_code)

    # Wait until the driver is reported as dead by GCS.
    def is_job_done():
        return any("StopTime" in job for job in ray.jobs())

    wait_for_condition(is_job_done)

    @ray.remote(num_cpus=1)
    class Actor:
        def ping(self):
            return "pong"

    # Get the named placement group and schedule an actor in it.
    placement_group = ray.util.get_placement_group(global_placement_group_name)
    assert placement_group is not None
    assert placement_group.wait(5)
    actor = Actor.options(
        placement_group=placement_group,
        placement_group_bundle_index=0).remote()

    ray.get(actor.ping.remote())

    # Create another placement group with the same name and make sure its
    # creation fails.
    same_name_pg = ray.util.placement_group(
        [{
            "CPU": 1
        } for _ in range(2)],
        strategy="STRICT_SPREAD",
        name=global_placement_group_name)
    assert not same_name_pg.wait(10)

    # Remove the named placement group and make sure a second creation
    # with the same name now succeeds.
    ray.util.remove_placement_group(placement_group)
    same_name_pg = ray.util.placement_group(
        [{
            "CPU": 1
        } for _ in range(2)],
        strategy="STRICT_SPREAD",
        name=global_placement_group_name)
    assert same_name_pg.wait(10)

    # Looking up a name that doesn't exist must raise ValueError.
    with pytest.raises(ValueError):
        ray.util.get_placement_group("inexistent_pg")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
|
||||
@@ -4,7 +4,8 @@ from ray.util.check_serialize import inspect_serializability
|
||||
from ray.util.debug import log_once, disable_log_once_globally, \
|
||||
enable_periodic_logging
|
||||
from ray.util.placement_group import (placement_group, placement_group_table,
|
||||
remove_placement_group)
|
||||
remove_placement_group,
|
||||
get_placement_group)
|
||||
from ray.util import rpdb as pdb
|
||||
from ray.util.serialization import register_serializer, deregister_serializer
|
||||
|
||||
@@ -19,6 +20,7 @@ __all__ = [
|
||||
"pdb",
|
||||
"placement_group",
|
||||
"placement_group_table",
|
||||
"get_placement_group",
|
||||
"remove_placement_group",
|
||||
"inspect_serializability",
|
||||
"collective",
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import (List, Dict, Optional, Union)
|
||||
|
||||
import ray
|
||||
from ray._raylet import PlacementGroupID, ObjectRef
|
||||
from ray.utils import hex_to_binary
|
||||
|
||||
bundle_reservation_check = None
|
||||
|
||||
@@ -145,7 +146,7 @@ class PlacementGroup:
|
||||
|
||||
def placement_group(bundles: List[Dict[str, float]],
|
||||
strategy: str = "PACK",
|
||||
name: str = "unnamed_group",
|
||||
name: str = "",
|
||||
lifetime=None) -> PlacementGroup:
|
||||
"""Asynchronously creates a PlacementGroup.
|
||||
|
||||
@@ -211,6 +212,29 @@ def remove_placement_group(placement_group: PlacementGroup):
|
||||
worker.core_worker.remove_placement_group(placement_group.id)
|
||||
|
||||
|
||||
def get_placement_group(placement_group_name: str):
    """Get a placement group object with a global name.

    Args:
        placement_group_name: The name the placement group was created
            with. Must be non-empty.

    Returns:
        The placement group object registered under the given name.

    Raises:
        ValueError: If ``placement_group_name`` is empty, or if no
            placement group with the given name can be found.
    """
    if not placement_group_name:
        raise ValueError(
            "Please supply a non-empty value to get_placement_group")
    worker = ray.worker.global_worker
    worker.check_connected()
    placement_group_info = ray.state.state.get_placement_group_by_name(
        placement_group_name)
    if placement_group_info is None:
        # The lookup is for a placement group, not an actor; say so.
        raise ValueError(
            "Failed to look up placement group with name: "
            f"{placement_group_name}")
    # The info dict stores the ID as hex; rebuild the binary ID handle.
    return PlacementGroup(
        PlacementGroupID(
            hex_to_binary(placement_group_info["placement_group_id"])))
|
||||
|
||||
|
||||
def placement_group_table(placement_group: PlacementGroup = None) -> list:
|
||||
"""Get the state of the placement group from GCS.
|
||||
|
||||
|
||||
@@ -727,7 +727,7 @@ class PlacementGroupInfoAccessor {
|
||||
virtual Status AsyncCreatePlacementGroup(
|
||||
const PlacementGroupSpecification &placement_group_spec) = 0;
|
||||
|
||||
/// Get a placement group data from GCS asynchronously.
|
||||
/// Get a placement group data from GCS asynchronously by id.
|
||||
///
|
||||
/// \param placement_group_id The id of a placement group to obtain from GCS.
|
||||
/// \return Status.
|
||||
@@ -735,6 +735,14 @@ class PlacementGroupInfoAccessor {
|
||||
const PlacementGroupID &placement_group_id,
|
||||
const OptionalItemCallback<rpc::PlacementGroupTableData> &callback) = 0;
|
||||
|
||||
/// Get a placement group data from GCS asynchronously by name.
|
||||
///
|
||||
/// \param placement_group_name The name of a placement group to obtain from GCS.
|
||||
/// \return Status.
|
||||
virtual Status AsyncGetByName(
|
||||
const std::string &placement_group_name,
|
||||
const OptionalItemCallback<rpc::PlacementGroupTableData> &callback) = 0;
|
||||
|
||||
/// Get all placement group info from GCS asynchronously.
|
||||
///
|
||||
/// \param callback Callback that will be called after lookup finished.
|
||||
|
||||
@@ -259,5 +259,17 @@ std::unique_ptr<std::string> GlobalStateAccessor::GetPlacementGroupInfo(
|
||||
return placement_group_table_data;
|
||||
}
|
||||
|
||||
std::unique_ptr<std::string> GlobalStateAccessor::GetPlacementGroupByName(
|
||||
const std::string &placement_group_name) {
|
||||
std::unique_ptr<std::string> placement_group_table_data;
|
||||
std::promise<bool> promise;
|
||||
RAY_CHECK_OK(gcs_client_->PlacementGroups().AsyncGetByName(
|
||||
placement_group_name,
|
||||
TransformForOptionalItemCallback<rpc::PlacementGroupTableData>(
|
||||
placement_group_table_data, promise)));
|
||||
promise.get_future().get();
|
||||
return placement_group_table_data;
|
||||
}
|
||||
|
||||
} // namespace gcs
|
||||
} // namespace ray
|
||||
|
||||
@@ -151,15 +151,24 @@ class GlobalStateAccessor {
|
||||
/// deserialized with protobuf function.
|
||||
std::vector<std::string> GetAllPlacementGroupInfo();
|
||||
|
||||
/// Get information of a placement group from GCS Service.
|
||||
/// Get information of a placement group from GCS Service by ID.
|
||||
///
|
||||
/// \param placement_group The ID of placement group to look up in the GCS Service.
|
||||
/// \param placement_group_id The ID of placement group to look up in the GCS Service.
|
||||
/// \return Placement group info. To support multi-language, we serialize each
|
||||
/// PlacementGroupTableData and return the serialized string. Where used, it needs to be
|
||||
/// deserialized with protobuf function.
|
||||
std::unique_ptr<std::string> GetPlacementGroupInfo(
|
||||
const PlacementGroupID &placement_group_id);
|
||||
|
||||
/// Get information of a placement group from GCS Service by name.
|
||||
///
|
||||
/// \param placement_group_name The name of placement group to look up in the GCS
|
||||
/// Service.
/// \return Placement group info. To support multi-language, we serialize each
|
||||
/// PlacementGroupTableData and return the serialized string. Where used, it needs to be
|
||||
/// deserialized with protobuf function.
|
||||
std::unique_ptr<std::string> GetPlacementGroupByName(
|
||||
const std::string &placement_group_name);
|
||||
|
||||
private:
|
||||
/// MultiItem transformation helper in template style.
|
||||
///
|
||||
|
||||
@@ -1466,6 +1466,26 @@ Status ServiceBasedPlacementGroupInfoAccessor::AsyncGet(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// Send a GetNamedPlacementGroup RPC for `name` and hand the result to
/// `callback`: the reply's table data when present, boost::none otherwise.
/// Always returns Status::OK(); the RPC status is passed to `callback`.
Status ServiceBasedPlacementGroupInfoAccessor::AsyncGetByName(
    const std::string &name,
    const OptionalItemCallback<rpc::PlacementGroupTableData> &callback) {
  RAY_LOG(DEBUG) << "Getting named placement group info, name = " << name;

  rpc::GetNamedPlacementGroupRequest request;
  request.set_name(name);

  // `name` and `callback` are captured by value for use inside the reply handler.
  auto on_reply = [name, callback](const Status &status,
                                   const rpc::GetNamedPlacementGroupReply &reply) {
    if (!reply.has_placement_group_table_data()) {
      callback(status, boost::none);
    } else {
      callback(status, reply.placement_group_table_data());
    }
    RAY_LOG(DEBUG) << "Finished getting named placement group info, status = "
                   << status << ", name = " << name;
  };
  client_impl_->GetGcsRpcClient().GetNamedPlacementGroup(request, on_reply);
  return Status::OK();
}
|
||||
|
||||
Status ServiceBasedPlacementGroupInfoAccessor::AsyncGetAll(
|
||||
const MultiItemCallback<rpc::PlacementGroupTableData> &callback) {
|
||||
RAY_LOG(DEBUG) << "Getting all placement group info.";
|
||||
|
||||
@@ -453,6 +453,10 @@ class ServiceBasedPlacementGroupInfoAccessor : public PlacementGroupInfoAccessor
|
||||
const PlacementGroupID &placement_group_id,
|
||||
const OptionalItemCallback<rpc::PlacementGroupTableData> &callback) override;
|
||||
|
||||
Status AsyncGetByName(
|
||||
const std::string &name,
|
||||
const OptionalItemCallback<rpc::PlacementGroupTableData> &callback) override;
|
||||
|
||||
Status AsyncGetAll(
|
||||
const MultiItemCallback<rpc::PlacementGroupTableData> &callback) override;
|
||||
|
||||
|
||||
@@ -65,7 +65,8 @@ rpc::PlacementStrategy GcsPlacementGroup::GetStrategy() const {
|
||||
return placement_group_table_data_.strategy();
|
||||
}
|
||||
|
||||
const rpc::PlacementGroupTableData &GcsPlacementGroup::GetPlacementGroupTableData() {
|
||||
const rpc::PlacementGroupTableData &GcsPlacementGroup::GetPlacementGroupTableData()
|
||||
const {
|
||||
return placement_group_table_data_;
|
||||
}
|
||||
|
||||
@@ -147,6 +148,21 @@ void GcsPlacementGroupManager::RegisterPlacementGroup(
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (!placement_group->GetName().empty()) {
|
||||
auto it = named_placement_groups_.find(placement_group->GetName());
|
||||
if (it == named_placement_groups_.end()) {
|
||||
named_placement_groups_.emplace(placement_group->GetName(),
|
||||
placement_group->GetPlacementGroupID());
|
||||
} else {
|
||||
std::stringstream stream;
|
||||
stream << "Failed to create placement group '"
|
||||
<< placement_group->GetPlacementGroupID() << "' because name '"
|
||||
<< placement_group->GetName() << "' already exists.";
|
||||
RAY_LOG(WARNING) << stream.str();
|
||||
callback(Status::Invalid(stream.str()));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Mark the callback as pending and invoke it after the placement_group has been
|
||||
// successfully created.
|
||||
@@ -178,11 +194,9 @@ void GcsPlacementGroupManager::RegisterPlacementGroup(
|
||||
PlacementGroupID GcsPlacementGroupManager::GetPlacementGroupIDByName(
|
||||
const std::string &name) {
|
||||
PlacementGroupID placement_group_id = PlacementGroupID::Nil();
|
||||
for (const auto &iter : registered_placement_groups_) {
|
||||
if (iter.second->GetName() == name) {
|
||||
placement_group_id = iter.first;
|
||||
break;
|
||||
}
|
||||
auto it = named_placement_groups_.find(name);
|
||||
if (it != named_placement_groups_.end()) {
|
||||
placement_group_id = it->second;
|
||||
}
|
||||
return placement_group_id;
|
||||
}
|
||||
@@ -315,10 +329,19 @@ void GcsPlacementGroupManager::RemovePlacementGroup(
|
||||
on_placement_group_removed(Status::OK());
|
||||
return;
|
||||
}
|
||||
auto placement_group = placement_group_it->second;
|
||||
auto placement_group = std::move(placement_group_it->second);
|
||||
registered_placement_groups_.erase(placement_group_it);
|
||||
placement_group_to_create_callbacks_.erase(placement_group_id);
|
||||
|
||||
// Remove placement group from `named_placement_groups_` if its name is not empty.
|
||||
if (!placement_group->GetName().empty()) {
|
||||
auto it = named_placement_groups_.find(placement_group->GetName());
|
||||
if (it != named_placement_groups_.end() &&
|
||||
it->second == placement_group->GetPlacementGroupID()) {
|
||||
named_placement_groups_.erase(it);
|
||||
}
|
||||
}
|
||||
|
||||
// Destroy all bundles.
|
||||
gcs_placement_group_scheduler_->DestroyPlacementGroupBundleResourcesIfExists(
|
||||
placement_group_id);
|
||||
@@ -385,6 +408,30 @@ void GcsPlacementGroupManager::HandleGetPlacementGroup(
|
||||
++counts_[CountType::GET_PLACEMENT_GROUP_REQUEST];
|
||||
}
|
||||
|
||||
void GcsPlacementGroupManager::HandleGetNamedPlacementGroup(
|
||||
const rpc::GetNamedPlacementGroupRequest &request,
|
||||
rpc::GetNamedPlacementGroupReply *reply, rpc::SendReplyCallback send_reply_callback) {
|
||||
const std::string &name = request.name();
|
||||
RAY_LOG(DEBUG) << "Getting named placement group info, name = " << name;
|
||||
|
||||
// Try to look up the placement Group ID for the named placement group.
|
||||
auto placement_group_id = GetPlacementGroupIDByName(name);
|
||||
|
||||
if (placement_group_id.IsNil()) {
|
||||
// The placement group was not found.
|
||||
RAY_LOG(DEBUG) << "Placement Group with name '" << name << "' was not found";
|
||||
} else {
|
||||
const auto &iter = registered_placement_groups_.find(placement_group_id);
|
||||
RAY_CHECK(iter != registered_placement_groups_.end());
|
||||
reply->mutable_placement_group_table_data()->CopyFrom(
|
||||
iter->second->GetPlacementGroupTableData());
|
||||
RAY_LOG(DEBUG) << "Finished get named placement group info, placement group id = "
|
||||
<< placement_group_id;
|
||||
}
|
||||
GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK());
|
||||
++counts_[CountType::GET_NAMED_PLACEMENT_GROUP_REQUEST];
|
||||
}
|
||||
|
||||
void GcsPlacementGroupManager::HandleGetAllPlacementGroup(
|
||||
const rpc::GetAllPlacementGroupRequest &request,
|
||||
rpc::GetAllPlacementGroupReply *reply, rpc::SendReplyCallback send_reply_callback) {
|
||||
@@ -550,6 +597,10 @@ void GcsPlacementGroupManager::Initialize(const GcsInitData &gcs_init_data) {
|
||||
auto placement_group = std::make_shared<GcsPlacementGroup>(item.second);
|
||||
if (item.second.state() != rpc::PlacementGroupTableData::REMOVED) {
|
||||
registered_placement_groups_.emplace(item.first, placement_group);
|
||||
if (!placement_group->GetName().empty()) {
|
||||
named_placement_groups_.emplace(placement_group->GetName(),
|
||||
placement_group->GetPlacementGroupID());
|
||||
}
|
||||
|
||||
if (item.second.state() == rpc::PlacementGroupTableData::PENDING ||
|
||||
item.second.state() == rpc::PlacementGroupTableData::RESCHEDULING) {
|
||||
@@ -587,6 +638,7 @@ std::string GcsPlacementGroupManager::DebugString() const {
|
||||
<< ", WaitPlacementGroupUntilReady request count: "
|
||||
<< counts_[CountType::WAIT_PLACEMENT_GROUP_UNTIL_READY_REQUEST]
|
||||
<< ", Registered placement groups count: " << registered_placement_groups_.size()
|
||||
<< ", Named placement group count: " << named_placement_groups_.size()
|
||||
<< ", Pending placement groups count: " << pending_placement_groups_.size()
|
||||
<< "}";
|
||||
return stream.str();
|
||||
|
||||
@@ -65,7 +65,7 @@ class GcsPlacementGroup {
|
||||
}
|
||||
|
||||
/// Get the immutable PlacementGroupTableData of this placement group.
|
||||
const rpc::PlacementGroupTableData &GetPlacementGroupTableData();
|
||||
const rpc::PlacementGroupTableData &GetPlacementGroupTableData() const;
|
||||
|
||||
/// Get the mutable bundle of this placement group.
|
||||
rpc::Bundle *GetMutableBundle(int bundle_index);
|
||||
@@ -155,10 +155,13 @@ class GcsPlacementGroupManager : public rpc::PlacementGroupInfoHandler {
|
||||
rpc::GetPlacementGroupReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) override;
|
||||
|
||||
void HandleGetNamedPlacementGroup(const rpc::GetNamedPlacementGroupRequest &request,
|
||||
rpc::GetNamedPlacementGroupReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) override;
|
||||
|
||||
void HandleGetAllPlacementGroup(const rpc::GetAllPlacementGroupRequest &request,
|
||||
rpc::GetAllPlacementGroupReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) override;
|
||||
|
||||
void HandleWaitPlacementGroupUntilReady(
|
||||
const rpc::WaitPlacementGroupUntilReadyRequest &request,
|
||||
rpc::WaitPlacementGroupUntilReadyReply *reply,
|
||||
@@ -315,6 +318,9 @@ class GcsPlacementGroupManager : public rpc::PlacementGroupInfoHandler {
|
||||
/// Reference of GcsResourceManager.
|
||||
GcsResourceManager &gcs_resource_manager_;
|
||||
|
||||
/// Maps placement group names to their placement group ID for lookups by name.
|
||||
absl::flat_hash_map<std::string, PlacementGroupID> named_placement_groups_;
|
||||
|
||||
// Debug info.
|
||||
enum CountType {
|
||||
CREATE_PLACEMENT_GROUP_REQUEST = 0,
|
||||
@@ -322,7 +328,8 @@ class GcsPlacementGroupManager : public rpc::PlacementGroupInfoHandler {
|
||||
GET_PLACEMENT_GROUP_REQUEST = 2,
|
||||
GET_ALL_PLACEMENT_GROUP_REQUEST = 3,
|
||||
WAIT_PLACEMENT_GROUP_UNTIL_READY_REQUEST = 4,
|
||||
CountType_MAX = 5,
|
||||
GET_NAMED_PLACEMENT_GROUP_REQUEST = 5,
|
||||
CountType_MAX = 6,
|
||||
};
|
||||
uint64_t counts_[CountType::CountType_MAX] = {0};
|
||||
};
|
||||
|
||||
@@ -174,6 +174,31 @@ TEST_F(GcsPlacementGroupManagerTest, TestGetPlacementGroupIDByName) {
|
||||
PlacementGroupID::FromBinary(request.placement_group_spec().placement_group_id()));
|
||||
}
|
||||
|
||||
// Registers a placement group named "test_name", drives it to CREATED, removes
// it, and checks that the name no longer resolves to a placement group ID.
TEST_F(GcsPlacementGroupManagerTest, TestRemoveNamedPlacementGroup) {
  auto request = Mocker::GenCreatePlacementGroupRequest("test_name");
  std::atomic<int> created_count(0);
  auto on_created = [&created_count](const Status &status) { ++created_count; };
  gcs_placement_group_manager_->RegisterPlacementGroup(
      std::make_shared<gcs::GcsPlacementGroup>(request), on_created);

  // The creation callback must not have run before the scheduler reports success.
  ASSERT_EQ(created_count, 0);
  WaitForExpectedPgCount(1);
  auto pending_group = mock_placement_group_scheduler_->placement_groups_.back();
  mock_placement_group_scheduler_->placement_groups_.pop_back();

  gcs_placement_group_manager_->OnPlacementGroupCreationSuccess(pending_group);
  WaitForExpectedCount(created_count, 1);
  ASSERT_EQ(pending_group->GetState(), rpc::PlacementGroupTableData::CREATED);

  // Remove the named placement group.
  auto on_removed = [](const Status &status) { ASSERT_TRUE(status.ok()); };
  gcs_placement_group_manager_->RemovePlacementGroup(
      pending_group->GetPlacementGroupID(), on_removed);

  // After removal the name must map to the Nil ID.
  ASSERT_EQ(gcs_placement_group_manager_->GetPlacementGroupIDByName("test_name"),
            PlacementGroupID::Nil());
}
|
||||
|
||||
TEST_F(GcsPlacementGroupManagerTest, TestRescheduleWhenNodeAdd) {
|
||||
auto request = Mocker::GenCreatePlacementGroupRequest();
|
||||
std::atomic<int> finished_placement_group_count(0);
|
||||
|
||||
@@ -504,6 +504,17 @@ message WaitPlacementGroupUntilReadyReply {
|
||||
GcsStatus status = 1;
|
||||
}
|
||||
|
||||
// Request to look up a placement group by its name.
message GetNamedPlacementGroupRequest {
  // Name of the placement group.
  string name = 1;
}

// Reply carrying the result of a named placement group lookup.
message GetNamedPlacementGroupReply {
  GcsStatus status = 1;
  // Data of placement group. Left unset by the GCS handler when no placement
  // group is registered under the requested name.
  PlacementGroupTableData placement_group_table_data = 2;
}
|
||||
|
||||
// Service for placement group info access.
|
||||
service PlacementGroupInfoGcsService {
|
||||
// Create placement group via gcs service.
|
||||
@@ -514,6 +525,9 @@ service PlacementGroupInfoGcsService {
|
||||
returns (RemovePlacementGroupReply);
|
||||
// Get placement group information via gcs service.
|
||||
rpc GetPlacementGroup(GetPlacementGroupRequest) returns (GetPlacementGroupReply);
|
||||
// Get named placement group information via gcs service.
|
||||
rpc GetNamedPlacementGroup(GetNamedPlacementGroupRequest)
|
||||
returns (GetNamedPlacementGroupReply);
|
||||
// Get information of all placement group from GCS Service.
|
||||
rpc GetAllPlacementGroup(GetAllPlacementGroupRequest)
|
||||
returns (GetAllPlacementGroupReply);
|
||||
|
||||
@@ -254,6 +254,10 @@ class GcsRpcClient {
|
||||
VOID_GCS_RPC_CLIENT_METHOD(PlacementGroupInfoGcsService, GetPlacementGroup,
|
||||
placement_group_info_grpc_client_, )
|
||||
|
||||
/// Get placement group data from GCS Service by name.
|
||||
VOID_GCS_RPC_CLIENT_METHOD(PlacementGroupInfoGcsService, GetNamedPlacementGroup,
|
||||
placement_group_info_grpc_client_, )
|
||||
|
||||
/// Get information of all placement group from GCS Service.
|
||||
VOID_GCS_RPC_CLIENT_METHOD(PlacementGroupInfoGcsService, GetAllPlacementGroup,
|
||||
placement_group_info_grpc_client_, )
|
||||
|
||||
@@ -522,6 +522,10 @@ class PlacementGroupInfoGcsServiceHandler {
|
||||
const WaitPlacementGroupUntilReadyRequest &request,
|
||||
WaitPlacementGroupUntilReadyReply *reply,
|
||||
SendReplyCallback send_reply_callback) = 0;
|
||||
|
||||
virtual void HandleGetNamedPlacementGroup(const GetNamedPlacementGroupRequest &request,
|
||||
GetNamedPlacementGroupReply *reply,
|
||||
SendReplyCallback send_reply_callback) = 0;
|
||||
};
|
||||
|
||||
/// The `GrpcService` for `PlacementGroupInfoGcsService`.
|
||||
@@ -543,6 +547,7 @@ class PlacementGroupInfoGrpcService : public GrpcService {
|
||||
PLACEMENT_GROUP_INFO_SERVICE_RPC_HANDLER(CreatePlacementGroup);
|
||||
PLACEMENT_GROUP_INFO_SERVICE_RPC_HANDLER(RemovePlacementGroup);
|
||||
PLACEMENT_GROUP_INFO_SERVICE_RPC_HANDLER(GetPlacementGroup);
|
||||
PLACEMENT_GROUP_INFO_SERVICE_RPC_HANDLER(GetNamedPlacementGroup);
|
||||
PLACEMENT_GROUP_INFO_SERVICE_RPC_HANDLER(GetAllPlacementGroup);
|
||||
PLACEMENT_GROUP_INFO_SERVICE_RPC_HANDLER(WaitPlacementGroupUntilReady);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user