[Placement Group] Support named placement group (#13755)

This commit is contained in:
DK.Pino
2021-02-05 11:04:51 +08:00
committed by GitHub
parent 40bad86c7a
commit fb89f9c2c8
18 changed files with 346 additions and 17 deletions
@@ -32,4 +32,6 @@ cdef extern from "ray/gcs/gcs_client/global_state_accessor.h" nogil:
c_bool AddWorkerInfo(const c_string &serialized_string)
unique_ptr[c_string] GetPlacementGroupInfo(
const CPlacementGroupID &placement_group_id)
unique_ptr[c_string] GetPlacementGroupByName(
const c_string &placement_group_name)
c_vector[c_string] GetAllPlacementGroupInfo()
@@ -147,3 +147,13 @@ cdef class GlobalStateAccessor:
if result:
return c_string(result.get().data(), result.get().size())
return None
def get_placement_group_by_name(self, placement_group_name):
    """Fetch the serialized info of the placement group with the given name.

    The lookup is delegated to the wrapped C++ accessor's
    ``GetPlacementGroupByName`` and runs with the GIL released.

    Args:
        placement_group_name: Name of the placement group; it is coerced
            to a C++ string before the GIL is dropped.

    Returns:
        A copy of the string the accessor returned, or None when the
        accessor returned a null pointer.
    """
    # Convert the Python argument while the GIL is still held; only pure
    # C++ values may be touched inside the nogil section below.
    cdef c_string name_cstr = placement_group_name
    cdef unique_ptr[c_string] serialized
    with nogil:
        serialized = self.inner.get().GetPlacementGroupByName(name_cstr)
    if serialized:
        # Copy by (data, size) so embedded NUL bytes are preserved.
        return c_string(serialized.get().data(), serialized.get().size())
    return None
+14
View File
@@ -388,6 +388,20 @@ class GlobalState:
return dict(result)
def get_placement_group_by_name(self, placement_group_name):
    """Look up a placement group by name and return its info.

    Args:
        placement_group_name: Name handed to the global state accessor.

    Returns:
        None when the accessor has no entry for the name; otherwise the
        result of ``_gen_placement_group_info`` applied to the parsed
        ``PlacementGroupTableData`` message.
    """
    self._check_connected()

    serialized = self.global_state_accessor.get_placement_group_by_name(
        placement_group_name)
    if serialized is None:
        return None

    # The accessor hands back a serialized protobuf; decode it before
    # converting it to the user-facing representation.
    table_data = gcs_utils.PlacementGroupTableData.FromString(serialized)
    return self._gen_placement_group_info(table_data)
def placement_group_table(self, placement_group_id=None):
self._check_connected()
+84 -2
View File
@@ -375,6 +375,7 @@ def test_remove_pending_placement_group(ray_start_cluster):
# Create a placement group that cannot be scheduled now.
placement_group = ray.util.placement_group([{"GPU": 2}, {"CPU": 2}])
ray.util.remove_placement_group(placement_group)
# TODO(sang): Add state check here.
@ray.remote(num_cpus=4)
def f():
@@ -797,10 +798,10 @@ def test_mini_integration(ray_start_cluster):
pg_tasks = []
# total bundle gpu usage = bundles_per_pg * total_num_pg * per_bundle_gpus
# Note this is half of total
for _ in range(total_num_pg):
for index in range(total_num_pg):
pgs.append(
ray.util.placement_group(
name="name",
name=f"name{index}",
strategy="PACK",
bundles=[{
"GPU": per_bundle_gpus
@@ -1423,5 +1424,86 @@ ray.shutdown()
assert assert_alive_num_actor(4)
def test_named_placement_group(ray_start_cluster):
    """Exercise lookup, name collision and removal of a named placement group.

    Steps:
      1. A separate driver creates a detached placement group with a name
         and then shuts down.
      2. This driver looks the group up by name and schedules an actor in it.
      3. Creating a second group with the same name must not become ready.
      4. After removing the original group, the name can be used again.
      5. Looking up a name that does not exist raises ValueError.
    """
    cluster = ray_start_cluster
    for _ in range(2):
        cluster.add_node(num_cpus=3)
    cluster.wait_for_nodes()
    info = ray.init(address=cluster.address)
    global_placement_group_name = "named_placement_group"

    # Create a detached placement group with a name from another driver.
    driver_code = f"""
import ray
ray.init(address="{info["redis_address"]}")
pg = ray.util.placement_group(
[{{"CPU": 1}} for _ in range(2)],
strategy="STRICT_SPREAD",
name="{global_placement_group_name}",
lifetime="detached")
ray.get(pg.ready())
ray.shutdown()
"""
    run_string_as_driver(driver_code)

    # Wait until the driver is reported as dead by GCS.
    def is_job_done():
        # True as soon as any job entry carries a "StopTime" field.
        jobs = ray.jobs()
        for job in jobs:
            if "StopTime" in job:
                return True
        return False

    wait_for_condition(is_job_done)

    @ray.remote(num_cpus=1)
    class Actor:
        def ping(self):
            return "pong"

    # Get the named placement group and schedule an actor in it.
    placement_group = ray.util.get_placement_group(global_placement_group_name)
    assert placement_group is not None
    assert placement_group.wait(5)
    actor = Actor.options(
        placement_group=placement_group,
        placement_group_bundle_index=0).remote()
    ray.get(actor.ping.remote())

    # Create another placement group with the same name and make sure its
    # creation fails.
    same_name_pg = ray.util.placement_group(
        [{
            "CPU": 1
        } for _ in range(2)],
        strategy="STRICT_SPREAD",
        name=global_placement_group_name)
    assert not same_name_pg.wait(10)

    # Remove the named placement group and make sure a second creation
    # with the same name now succeeds.
    ray.util.remove_placement_group(placement_group)
    same_name_pg = ray.util.placement_group(
        [{
            "CPU": 1
        } for _ in range(2)],
        strategy="STRICT_SPREAD",
        name=global_placement_group_name)
    assert same_name_pg.wait(10)

    # Looking up a name that doesn't exist must raise ValueError.
    with pytest.raises(ValueError):
        ray.util.get_placement_group("inexistent_pg")
if __name__ == "__main__":
sys.exit(pytest.main(["-v", __file__]))
+3 -1
View File
@@ -4,7 +4,8 @@ from ray.util.check_serialize import inspect_serializability
from ray.util.debug import log_once, disable_log_once_globally, \
enable_periodic_logging
from ray.util.placement_group import (placement_group, placement_group_table,
remove_placement_group)
remove_placement_group,
get_placement_group)
from ray.util import rpdb as pdb
from ray.util.serialization import register_serializer, deregister_serializer
@@ -19,6 +20,7 @@ __all__ = [
"pdb",
"placement_group",
"placement_group_table",
"get_placement_group",
"remove_placement_group",
"inspect_serializability",
"collective",
+25 -1
View File
@@ -4,6 +4,7 @@ from typing import (List, Dict, Optional, Union)
import ray
from ray._raylet import PlacementGroupID, ObjectRef
from ray.utils import hex_to_binary
bundle_reservation_check = None
@@ -145,7 +146,7 @@ class PlacementGroup:
def placement_group(bundles: List[Dict[str, float]],
strategy: str = "PACK",
name: str = "unnamed_group",
name: str = "",
lifetime=None) -> PlacementGroup:
"""Asynchronously creates a PlacementGroup.
@@ -211,6 +212,29 @@ def remove_placement_group(placement_group: PlacementGroup):
worker.core_worker.remove_placement_group(placement_group.id)
def get_placement_group(placement_group_name: str):
    """Get a placement group object with a global name.

    Args:
        placement_group_name: Name of the placement group to look up.
            Must be non-empty.

    Returns:
        The PlacementGroup object whose ID is registered under the given
        name.

    Raises:
        ValueError: If the name is empty, or if no placement group with
            the given name can be found.
    """
    # Validate before touching the worker so a bad argument is reported
    # as such rather than as a connection problem.
    if not placement_group_name:
        raise ValueError(
            "Please supply a non-empty value to get_placement_group")
    worker = ray.worker.global_worker
    worker.check_connected()
    placement_group_info = ray.state.state.get_placement_group_by_name(
        placement_group_name)
    if placement_group_info is None:
        raise ValueError(
            "Failed to look up placement group with name: "
            f"{placement_group_name}")
    # The state API reports the ID as a hex string; the handle needs the
    # binary form.
    return PlacementGroup(
        PlacementGroupID(
            hex_to_binary(placement_group_info["placement_group_id"])))
def placement_group_table(placement_group: PlacementGroup = None) -> list:
"""Get the state of the placement group from GCS.