[Placement Group] Support named placement group (#13755)

This commit is contained in:
DK.Pino
2021-02-05 11:04:51 +08:00
committed by GitHub
parent 40bad86c7a
commit fb89f9c2c8
18 changed files with 346 additions and 17 deletions
+35
View File
@@ -252,6 +252,41 @@ Note that you can anytime remove the placement group to clean up resources.
ray.shutdown()
Named Placement Groups
----------------------
A placement group can be given a globally unique name.
This allows you to retrieve the placement group from any job in the Ray cluster.
This can be useful if you cannot directly pass the placement group handle to
the actor or task that needs it, or if you are trying to
access a placement group launched by another driver.
Note that the placement group will still be destroyed if it's lifetime isn't `detached`.
See :ref:`placement-group-lifetimes` for more details.
.. tabs::
.. group-tab:: Python
.. code-block:: python
# first_driver.py
# Create a placement group with a global name.
pg = placement_group([{"CPU": 2}, {"CPU": 2}], strategy="STRICT_SPREAD", lifetime="detached", name="global_name")
ray.get(pg.ready())
Then, we can retrieve the actor later somewhere.
.. code-block:: python
# second_driver.py
# Retrieve a placement group with a global name.
pg = ray.util.get_placement_group("global_name")
.. group-tab:: Java
The named placement group is not implemented for Java APIs yet.
.. _placement-group-lifetimes:
Placement Group Lifetimes
-------------------------
@@ -32,4 +32,6 @@ cdef extern from "ray/gcs/gcs_client/global_state_accessor.h" nogil:
c_bool AddWorkerInfo(const c_string &serialized_string)
unique_ptr[c_string] GetPlacementGroupInfo(
const CPlacementGroupID &placement_group_id)
unique_ptr[c_string] GetPlacementGroupByName(
const c_string &placement_group_name)
c_vector[c_string] GetAllPlacementGroupInfo()
@@ -147,3 +147,13 @@ cdef class GlobalStateAccessor:
if result:
return c_string(result.get().data(), result.get().size())
return None
def get_placement_group_by_name(self, placement_group_name):
cdef unique_ptr[c_string] result
cdef c_string cplacement_group_name = placement_group_name
with nogil:
result = self.inner.get().GetPlacementGroupByName(
cplacement_group_name)
if result:
return c_string(result.get().data(), result.get().size())
return None
+14
View File
@@ -388,6 +388,20 @@ class GlobalState:
return dict(result)
def get_placement_group_by_name(self, placement_group_name):
self._check_connected()
placement_group_info = (
self.global_state_accessor.get_placement_group_by_name(
placement_group_name))
if placement_group_info is None:
return None
else:
placement_group_table_data = \
gcs_utils.PlacementGroupTableData.FromString(
placement_group_info)
return self._gen_placement_group_info(placement_group_table_data)
def placement_group_table(self, placement_group_id=None):
self._check_connected()
+84 -2
View File
@@ -375,6 +375,7 @@ def test_remove_pending_placement_group(ray_start_cluster):
# Create a placement group that cannot be scheduled now.
placement_group = ray.util.placement_group([{"GPU": 2}, {"CPU": 2}])
ray.util.remove_placement_group(placement_group)
# TODO(sang): Add state check here.
@ray.remote(num_cpus=4)
def f():
@@ -797,10 +798,10 @@ def test_mini_integration(ray_start_cluster):
pg_tasks = []
# total bundle gpu usage = bundles_per_pg * total_num_pg * per_bundle_gpus
# Note this is half of total
for _ in range(total_num_pg):
for index in range(total_num_pg):
pgs.append(
ray.util.placement_group(
name="name",
name=f"name{index}",
strategy="PACK",
bundles=[{
"GPU": per_bundle_gpus
@@ -1423,5 +1424,86 @@ ray.shutdown()
assert assert_alive_num_actor(4)
def test_named_placement_group(ray_start_cluster):
cluster = ray_start_cluster
for _ in range(2):
cluster.add_node(num_cpus=3)
cluster.wait_for_nodes()
info = ray.init(address=cluster.address)
global_placement_group_name = "named_placement_group"
# Create a detached placement group with name.
driver_code = f"""
import ray
ray.init(address="{info["redis_address"]}")
pg = ray.util.placement_group(
[{{"CPU": 1}} for _ in range(2)],
strategy="STRICT_SPREAD",
name="{global_placement_group_name}",
lifetime="detached")
ray.get(pg.ready())
ray.shutdown()
"""
run_string_as_driver(driver_code)
# Wait until the driver is reported as dead by GCS.
def is_job_done():
jobs = ray.jobs()
for job in jobs:
if "StopTime" in job:
return True
return False
wait_for_condition(is_job_done)
@ray.remote(num_cpus=1)
class Actor:
def ping(self):
return "pong"
# Get the named placement group and schedule a actor.
placement_group = ray.util.get_placement_group(global_placement_group_name)
assert placement_group is not None
assert placement_group.wait(5)
actor = Actor.options(
placement_group=placement_group,
placement_group_bundle_index=0).remote()
ray.get(actor.ping.remote())
# Create another placement group and make sure its creation will failed.
same_name_pg = ray.util.placement_group(
[{
"CPU": 1
} for _ in range(2)],
strategy="STRICT_SPREAD",
name=global_placement_group_name)
assert not same_name_pg.wait(10)
# Remove a named placement group and make sure the second creation
# will successful.
ray.util.remove_placement_group(placement_group)
same_name_pg = ray.util.placement_group(
[{
"CPU": 1
} for _ in range(2)],
strategy="STRICT_SPREAD",
name=global_placement_group_name)
assert same_name_pg.wait(10)
# Get a named placement group with a name that doesn't exist
# and make sure it will raise ValueError correctly.
error_count = 0
try:
ray.util.get_placement_group("inexistent_pg")
except ValueError:
error_count = error_count + 1
assert error_count == 1
if __name__ == "__main__":
sys.exit(pytest.main(["-v", __file__]))
+3 -1
View File
@@ -4,7 +4,8 @@ from ray.util.check_serialize import inspect_serializability
from ray.util.debug import log_once, disable_log_once_globally, \
enable_periodic_logging
from ray.util.placement_group import (placement_group, placement_group_table,
remove_placement_group)
remove_placement_group,
get_placement_group)
from ray.util import rpdb as pdb
from ray.util.serialization import register_serializer, deregister_serializer
@@ -19,6 +20,7 @@ __all__ = [
"pdb",
"placement_group",
"placement_group_table",
"get_placement_group",
"remove_placement_group",
"inspect_serializability",
"collective",
+25 -1
View File
@@ -4,6 +4,7 @@ from typing import (List, Dict, Optional, Union)
import ray
from ray._raylet import PlacementGroupID, ObjectRef
from ray.utils import hex_to_binary
bundle_reservation_check = None
@@ -145,7 +146,7 @@ class PlacementGroup:
def placement_group(bundles: List[Dict[str, float]],
strategy: str = "PACK",
name: str = "unnamed_group",
name: str = "",
lifetime=None) -> PlacementGroup:
"""Asynchronously creates a PlacementGroup.
@@ -211,6 +212,29 @@ def remove_placement_group(placement_group: PlacementGroup):
worker.core_worker.remove_placement_group(placement_group.id)
def get_placement_group(placement_group_name: str):
"""Get a placement group object with a global name.
Returns:
None if can't find a placement group with the given name.
The placement group object otherwise.
"""
if not placement_group_name:
raise ValueError(
"Please supply a non-empty value to get_placement_group")
worker = ray.worker.global_worker
worker.check_connected()
placement_group_info = ray.state.state.get_placement_group_by_name(
placement_group_name)
if placement_group_info is None:
raise ValueError(
f"Failed to look up actor with name: {placement_group_name}")
else:
return PlacementGroup(
PlacementGroupID(
hex_to_binary(placement_group_info["placement_group_id"])))
def placement_group_table(placement_group: PlacementGroup = None) -> list:
"""Get the state of the placement group from GCS.
+9 -1
View File
@@ -727,7 +727,7 @@ class PlacementGroupInfoAccessor {
virtual Status AsyncCreatePlacementGroup(
const PlacementGroupSpecification &placement_group_spec) = 0;
/// Get a placement group data from GCS asynchronously.
/// Get a placement group data from GCS asynchronously by id.
///
/// \param placement_group_id The id of a placement group to obtain from GCS.
/// \return Status.
@@ -735,6 +735,14 @@ class PlacementGroupInfoAccessor {
const PlacementGroupID &placement_group_id,
const OptionalItemCallback<rpc::PlacementGroupTableData> &callback) = 0;
/// Get a placement group data from GCS asynchronously by name.
///
/// \param placement_group_name The name of a placement group to obtain from GCS.
/// \return Status.
virtual Status AsyncGetByName(
const std::string &placement_group_name,
const OptionalItemCallback<rpc::PlacementGroupTableData> &callback) = 0;
/// Get all placement group info from GCS asynchronously.
///
/// \param callback Callback that will be called after lookup finished.
@@ -259,5 +259,17 @@ std::unique_ptr<std::string> GlobalStateAccessor::GetPlacementGroupInfo(
return placement_group_table_data;
}
std::unique_ptr<std::string> GlobalStateAccessor::GetPlacementGroupByName(
const std::string &placement_group_name) {
std::unique_ptr<std::string> placement_group_table_data;
std::promise<bool> promise;
RAY_CHECK_OK(gcs_client_->PlacementGroups().AsyncGetByName(
placement_group_name,
TransformForOptionalItemCallback<rpc::PlacementGroupTableData>(
placement_group_table_data, promise)));
promise.get_future().get();
return placement_group_table_data;
}
} // namespace gcs
} // namespace ray
+11 -2
View File
@@ -151,15 +151,24 @@ class GlobalStateAccessor {
/// deserialized with protobuf function.
std::vector<std::string> GetAllPlacementGroupInfo();
/// Get information of a placement group from GCS Service.
/// Get information of a placement group from GCS Service by ID.
///
/// \param placement_group The ID of placement group to look up in the GCS Service.
/// \param placement_group_id The ID of placement group to look up in the GCS Service.
/// \return Placement group info. To support multi-language, we serialize each
/// PlacementGroupTableData and return the serialized string. Where used, it needs to be
/// deserialized with protobuf function.
std::unique_ptr<std::string> GetPlacementGroupInfo(
const PlacementGroupID &placement_group_id);
/// Get information of a placement group from GCS Service by name.
///
/// \param placement_group_name The name of placement group to look up in the GCS
/// Service. \return Placement group info. To support multi-language, we serialize each
/// PlacementGroupTableData and return the serialized string. Where used, it needs to be
/// deserialized with protobuf function.
std::unique_ptr<std::string> GetPlacementGroupByName(
const std::string &placement_group_name);
private:
/// MultiItem transformation helper in template style.
///
@@ -1466,6 +1466,26 @@ Status ServiceBasedPlacementGroupInfoAccessor::AsyncGet(
return Status::OK();
}
Status ServiceBasedPlacementGroupInfoAccessor::AsyncGetByName(
const std::string &name,
const OptionalItemCallback<rpc::PlacementGroupTableData> &callback) {
RAY_LOG(DEBUG) << "Getting named placement group info, name = " << name;
rpc::GetNamedPlacementGroupRequest request;
request.set_name(name);
client_impl_->GetGcsRpcClient().GetNamedPlacementGroup(
request, [name, callback](const Status &status,
const rpc::GetNamedPlacementGroupReply &reply) {
if (reply.has_placement_group_table_data()) {
callback(status, reply.placement_group_table_data());
} else {
callback(status, boost::none);
}
RAY_LOG(DEBUG) << "Finished getting named placement group info, status = "
<< status << ", name = " << name;
});
return Status::OK();
}
Status ServiceBasedPlacementGroupInfoAccessor::AsyncGetAll(
const MultiItemCallback<rpc::PlacementGroupTableData> &callback) {
RAY_LOG(DEBUG) << "Getting all placement group info.";
@@ -453,6 +453,10 @@ class ServiceBasedPlacementGroupInfoAccessor : public PlacementGroupInfoAccessor
const PlacementGroupID &placement_group_id,
const OptionalItemCallback<rpc::PlacementGroupTableData> &callback) override;
Status AsyncGetByName(
const std::string &name,
const OptionalItemCallback<rpc::PlacementGroupTableData> &callback) override;
Status AsyncGetAll(
const MultiItemCallback<rpc::PlacementGroupTableData> &callback) override;
@@ -65,7 +65,8 @@ rpc::PlacementStrategy GcsPlacementGroup::GetStrategy() const {
return placement_group_table_data_.strategy();
}
const rpc::PlacementGroupTableData &GcsPlacementGroup::GetPlacementGroupTableData() {
const rpc::PlacementGroupTableData &GcsPlacementGroup::GetPlacementGroupTableData()
const {
return placement_group_table_data_;
}
@@ -147,6 +148,21 @@ void GcsPlacementGroupManager::RegisterPlacementGroup(
}
return;
}
if (!placement_group->GetName().empty()) {
auto it = named_placement_groups_.find(placement_group->GetName());
if (it == named_placement_groups_.end()) {
named_placement_groups_.emplace(placement_group->GetName(),
placement_group->GetPlacementGroupID());
} else {
std::stringstream stream;
stream << "Failed to create placement group '"
<< placement_group->GetPlacementGroupID() << "' because name '"
<< placement_group->GetName() << "' already exists.";
RAY_LOG(WARNING) << stream.str();
callback(Status::Invalid(stream.str()));
return;
}
}
// Mark the callback as pending and invoke it after the placement_group has been
// successfully created.
@@ -178,11 +194,9 @@ void GcsPlacementGroupManager::RegisterPlacementGroup(
PlacementGroupID GcsPlacementGroupManager::GetPlacementGroupIDByName(
const std::string &name) {
PlacementGroupID placement_group_id = PlacementGroupID::Nil();
for (const auto &iter : registered_placement_groups_) {
if (iter.second->GetName() == name) {
placement_group_id = iter.first;
break;
}
auto it = named_placement_groups_.find(name);
if (it != named_placement_groups_.end()) {
placement_group_id = it->second;
}
return placement_group_id;
}
@@ -315,10 +329,19 @@ void GcsPlacementGroupManager::RemovePlacementGroup(
on_placement_group_removed(Status::OK());
return;
}
auto placement_group = placement_group_it->second;
auto placement_group = std::move(placement_group_it->second);
registered_placement_groups_.erase(placement_group_it);
placement_group_to_create_callbacks_.erase(placement_group_id);
// Remove placement group from `named_placement_groups_` if its name is not empty.
if (!placement_group->GetName().empty()) {
auto it = named_placement_groups_.find(placement_group->GetName());
if (it != named_placement_groups_.end() &&
it->second == placement_group->GetPlacementGroupID()) {
named_placement_groups_.erase(it);
}
}
// Destroy all bundles.
gcs_placement_group_scheduler_->DestroyPlacementGroupBundleResourcesIfExists(
placement_group_id);
@@ -385,6 +408,30 @@ void GcsPlacementGroupManager::HandleGetPlacementGroup(
++counts_[CountType::GET_PLACEMENT_GROUP_REQUEST];
}
void GcsPlacementGroupManager::HandleGetNamedPlacementGroup(
const rpc::GetNamedPlacementGroupRequest &request,
rpc::GetNamedPlacementGroupReply *reply, rpc::SendReplyCallback send_reply_callback) {
const std::string &name = request.name();
RAY_LOG(DEBUG) << "Getting named placement group info, name = " << name;
// Try to look up the placement Group ID for the named placement group.
auto placement_group_id = GetPlacementGroupIDByName(name);
if (placement_group_id.IsNil()) {
// The placement group was not found.
RAY_LOG(DEBUG) << "Placement Group with name '" << name << "' was not found";
} else {
const auto &iter = registered_placement_groups_.find(placement_group_id);
RAY_CHECK(iter != registered_placement_groups_.end());
reply->mutable_placement_group_table_data()->CopyFrom(
iter->second->GetPlacementGroupTableData());
RAY_LOG(DEBUG) << "Finished get named placement group info, placement group id = "
<< placement_group_id;
}
GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK());
++counts_[CountType::GET_NAMED_PLACEMENT_GROUP_REQUEST];
}
void GcsPlacementGroupManager::HandleGetAllPlacementGroup(
const rpc::GetAllPlacementGroupRequest &request,
rpc::GetAllPlacementGroupReply *reply, rpc::SendReplyCallback send_reply_callback) {
@@ -550,6 +597,10 @@ void GcsPlacementGroupManager::Initialize(const GcsInitData &gcs_init_data) {
auto placement_group = std::make_shared<GcsPlacementGroup>(item.second);
if (item.second.state() != rpc::PlacementGroupTableData::REMOVED) {
registered_placement_groups_.emplace(item.first, placement_group);
if (!placement_group->GetName().empty()) {
named_placement_groups_.emplace(placement_group->GetName(),
placement_group->GetPlacementGroupID());
}
if (item.second.state() == rpc::PlacementGroupTableData::PENDING ||
item.second.state() == rpc::PlacementGroupTableData::RESCHEDULING) {
@@ -587,6 +638,7 @@ std::string GcsPlacementGroupManager::DebugString() const {
<< ", WaitPlacementGroupUntilReady request count: "
<< counts_[CountType::WAIT_PLACEMENT_GROUP_UNTIL_READY_REQUEST]
<< ", Registered placement groups count: " << registered_placement_groups_.size()
<< ", Named placement group count: " << named_placement_groups_.size()
<< ", Pending placement groups count: " << pending_placement_groups_.size()
<< "}";
return stream.str();
@@ -65,7 +65,7 @@ class GcsPlacementGroup {
}
/// Get the immutable PlacementGroupTableData of this placement group.
const rpc::PlacementGroupTableData &GetPlacementGroupTableData();
const rpc::PlacementGroupTableData &GetPlacementGroupTableData() const;
/// Get the mutable bundle of this placement group.
rpc::Bundle *GetMutableBundle(int bundle_index);
@@ -155,10 +155,13 @@ class GcsPlacementGroupManager : public rpc::PlacementGroupInfoHandler {
rpc::GetPlacementGroupReply *reply,
rpc::SendReplyCallback send_reply_callback) override;
void HandleGetNamedPlacementGroup(const rpc::GetNamedPlacementGroupRequest &request,
rpc::GetNamedPlacementGroupReply *reply,
rpc::SendReplyCallback send_reply_callback) override;
void HandleGetAllPlacementGroup(const rpc::GetAllPlacementGroupRequest &request,
rpc::GetAllPlacementGroupReply *reply,
rpc::SendReplyCallback send_reply_callback) override;
void HandleWaitPlacementGroupUntilReady(
const rpc::WaitPlacementGroupUntilReadyRequest &request,
rpc::WaitPlacementGroupUntilReadyReply *reply,
@@ -315,6 +318,9 @@ class GcsPlacementGroupManager : public rpc::PlacementGroupInfoHandler {
/// Reference of GcsResourceManager.
GcsResourceManager &gcs_resource_manager_;
/// Maps placement group names to their placement group ID for lookups by name.
absl::flat_hash_map<std::string, PlacementGroupID> named_placement_groups_;
// Debug info.
enum CountType {
CREATE_PLACEMENT_GROUP_REQUEST = 0,
@@ -322,7 +328,8 @@ class GcsPlacementGroupManager : public rpc::PlacementGroupInfoHandler {
GET_PLACEMENT_GROUP_REQUEST = 2,
GET_ALL_PLACEMENT_GROUP_REQUEST = 3,
WAIT_PLACEMENT_GROUP_UNTIL_READY_REQUEST = 4,
CountType_MAX = 5,
GET_NAMED_PLACEMENT_GROUP_REQUEST = 5,
CountType_MAX = 6,
};
uint64_t counts_[CountType::CountType_MAX] = {0};
};
@@ -174,6 +174,31 @@ TEST_F(GcsPlacementGroupManagerTest, TestGetPlacementGroupIDByName) {
PlacementGroupID::FromBinary(request.placement_group_spec().placement_group_id()));
}
TEST_F(GcsPlacementGroupManagerTest, TestRemoveNamedPlacementGroup) {
auto request = Mocker::GenCreatePlacementGroupRequest("test_name");
std::atomic<int> finished_placement_group_count(0);
gcs_placement_group_manager_->RegisterPlacementGroup(
std::make_shared<gcs::GcsPlacementGroup>(request),
[&finished_placement_group_count](const Status &status) {
++finished_placement_group_count;
});
ASSERT_EQ(finished_placement_group_count, 0);
WaitForExpectedPgCount(1);
auto placement_group = mock_placement_group_scheduler_->placement_groups_.back();
mock_placement_group_scheduler_->placement_groups_.pop_back();
gcs_placement_group_manager_->OnPlacementGroupCreationSuccess(placement_group);
WaitForExpectedCount(finished_placement_group_count, 1);
ASSERT_EQ(placement_group->GetState(), rpc::PlacementGroupTableData::CREATED);
// Remove the named placement group.
gcs_placement_group_manager_->RemovePlacementGroup(
placement_group->GetPlacementGroupID(),
[](const Status &status) { ASSERT_TRUE(status.ok()); });
ASSERT_EQ(gcs_placement_group_manager_->GetPlacementGroupIDByName("test_name"),
PlacementGroupID::Nil());
}
TEST_F(GcsPlacementGroupManagerTest, TestRescheduleWhenNodeAdd) {
auto request = Mocker::GenCreatePlacementGroupRequest();
std::atomic<int> finished_placement_group_count(0);
+14
View File
@@ -504,6 +504,17 @@ message WaitPlacementGroupUntilReadyReply {
GcsStatus status = 1;
}
message GetNamedPlacementGroupRequest {
// Name of the placement group.
string name = 1;
}
message GetNamedPlacementGroupReply {
GcsStatus status = 1;
// Data of placement group.
PlacementGroupTableData placement_group_table_data = 2;
}
// Service for placement group info access.
service PlacementGroupInfoGcsService {
// Create placement group via gcs service.
@@ -514,6 +525,9 @@ service PlacementGroupInfoGcsService {
returns (RemovePlacementGroupReply);
// Get placement group information via gcs service.
rpc GetPlacementGroup(GetPlacementGroupRequest) returns (GetPlacementGroupReply);
// Get named placement group information via gcs service.
rpc GetNamedPlacementGroup(GetNamedPlacementGroupRequest)
returns (GetNamedPlacementGroupReply);
// Get information of all placement group from GCS Service.
rpc GetAllPlacementGroup(GetAllPlacementGroupRequest)
returns (GetAllPlacementGroupReply);
+4
View File
@@ -254,6 +254,10 @@ class GcsRpcClient {
VOID_GCS_RPC_CLIENT_METHOD(PlacementGroupInfoGcsService, GetPlacementGroup,
placement_group_info_grpc_client_, )
/// Get placement group data from GCS Service by name.
VOID_GCS_RPC_CLIENT_METHOD(PlacementGroupInfoGcsService, GetNamedPlacementGroup,
placement_group_info_grpc_client_, )
/// Get information of all placement group from GCS Service.
VOID_GCS_RPC_CLIENT_METHOD(PlacementGroupInfoGcsService, GetAllPlacementGroup,
placement_group_info_grpc_client_, )
+5
View File
@@ -522,6 +522,10 @@ class PlacementGroupInfoGcsServiceHandler {
const WaitPlacementGroupUntilReadyRequest &request,
WaitPlacementGroupUntilReadyReply *reply,
SendReplyCallback send_reply_callback) = 0;
virtual void HandleGetNamedPlacementGroup(const GetNamedPlacementGroupRequest &request,
GetNamedPlacementGroupReply *reply,
SendReplyCallback send_reply_callback) = 0;
};
/// The `GrpcService` for `PlacementGroupInfoGcsService`.
@@ -543,6 +547,7 @@ class PlacementGroupInfoGrpcService : public GrpcService {
PLACEMENT_GROUP_INFO_SERVICE_RPC_HANDLER(CreatePlacementGroup);
PLACEMENT_GROUP_INFO_SERVICE_RPC_HANDLER(RemovePlacementGroup);
PLACEMENT_GROUP_INFO_SERVICE_RPC_HANDLER(GetPlacementGroup);
PLACEMENT_GROUP_INFO_SERVICE_RPC_HANDLER(GetNamedPlacementGroup);
PLACEMENT_GROUP_INFO_SERVICE_RPC_HANDLER(GetAllPlacementGroup);
PLACEMENT_GROUP_INFO_SERVICE_RPC_HANDLER(WaitPlacementGroupUntilReady);
}