mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 08:12:53 +08:00
[Placement Group] Support named placement group (#13755)
This commit is contained in:
@@ -32,4 +32,6 @@ cdef extern from "ray/gcs/gcs_client/global_state_accessor.h" nogil:
|
||||
c_bool AddWorkerInfo(const c_string &serialized_string)
|
||||
unique_ptr[c_string] GetPlacementGroupInfo(
|
||||
const CPlacementGroupID &placement_group_id)
|
||||
unique_ptr[c_string] GetPlacementGroupByName(
|
||||
const c_string &placement_group_name)
|
||||
c_vector[c_string] GetAllPlacementGroupInfo()
|
||||
|
||||
@@ -147,3 +147,13 @@ cdef class GlobalStateAccessor:
|
||||
if result:
|
||||
return c_string(result.get().data(), result.get().size())
|
||||
return None
|
||||
|
||||
def get_placement_group_by_name(self, placement_group_name):
    """Look up a placement group by name through the wrapped accessor.

    Args:
        placement_group_name: Name to look up; it is assigned to a C++
            string before the call.

    Returns:
        The string held by the pointer that GetPlacementGroupByName
        returns, or None when that pointer is empty.
    """
    cdef c_string name_arg = placement_group_name
    cdef unique_ptr[c_string] lookup
    # The accessor call runs with the GIL released.
    with nogil:
        lookup = self.inner.get().GetPlacementGroupByName(name_arg)
    if lookup:
        return c_string(lookup.get().data(), lookup.get().size())
    return None
|
||||
|
||||
@@ -388,6 +388,20 @@ class GlobalState:
|
||||
|
||||
return dict(result)
|
||||
|
||||
def get_placement_group_by_name(self, placement_group_name):
    """Fetch the info of the placement group registered under a name.

    Args:
        placement_group_name: Name of the placement group to look up.

    Returns:
        None when the accessor finds nothing for the name; otherwise the
        result of ``_gen_placement_group_info`` applied to the parsed
        ``PlacementGroupTableData``.
    """
    self._check_connected()

    serialized = self.global_state_accessor.get_placement_group_by_name(
        placement_group_name)
    # Nothing is registered under this name.
    if serialized is None:
        return None

    table_data = gcs_utils.PlacementGroupTableData.FromString(serialized)
    return self._gen_placement_group_info(table_data)
|
||||
|
||||
def placement_group_table(self, placement_group_id=None):
|
||||
self._check_connected()
|
||||
|
||||
|
||||
@@ -375,6 +375,7 @@ def test_remove_pending_placement_group(ray_start_cluster):
|
||||
# Create a placement group that cannot be scheduled now.
|
||||
placement_group = ray.util.placement_group([{"GPU": 2}, {"CPU": 2}])
|
||||
ray.util.remove_placement_group(placement_group)
|
||||
|
||||
# TODO(sang): Add state check here.
|
||||
@ray.remote(num_cpus=4)
|
||||
def f():
|
||||
@@ -797,10 +798,10 @@ def test_mini_integration(ray_start_cluster):
|
||||
pg_tasks = []
|
||||
# total bundle gpu usage = bundles_per_pg * total_num_pg * per_bundle_gpus
|
||||
# Note this is half of total
|
||||
for _ in range(total_num_pg):
|
||||
for index in range(total_num_pg):
|
||||
pgs.append(
|
||||
ray.util.placement_group(
|
||||
name="name",
|
||||
name=f"name{index}",
|
||||
strategy="PACK",
|
||||
bundles=[{
|
||||
"GPU": per_bundle_gpus
|
||||
@@ -1423,5 +1424,86 @@ ray.shutdown()
|
||||
assert assert_alive_num_actor(4)
|
||||
|
||||
|
||||
def test_named_placement_group(ray_start_cluster):
    """Check lookup, name uniqueness and name reuse of named placement groups.

    A separate driver creates a detached, named placement group and exits.
    This test then looks the group up by name, schedules an actor in it,
    verifies that a second group with the same name does not become ready,
    verifies that the name can be reused once the first group is removed,
    and verifies that looking up an unknown name raises ValueError.
    """
    cluster = ray_start_cluster
    for _ in range(2):
        cluster.add_node(num_cpus=3)
    cluster.wait_for_nodes()
    info = ray.init(address=cluster.address)
    global_placement_group_name = "named_placement_group"

    # Create a detached placement group with name.
    driver_code = f"""
import ray

ray.init(address="{info["redis_address"]}")

pg = ray.util.placement_group(
    [{{"CPU": 1}} for _ in range(2)],
    strategy="STRICT_SPREAD",
    name="{global_placement_group_name}",
    lifetime="detached")
ray.get(pg.ready())

ray.shutdown()
"""

    run_string_as_driver(driver_code)

    # Wait until the driver is reported as dead by GCS.
    def is_job_done():
        return any("StopTime" in job for job in ray.jobs())

    wait_for_condition(is_job_done)

    @ray.remote(num_cpus=1)
    class Actor:
        def ping(self):
            return "pong"

    # Get the named placement group and schedule an actor.
    placement_group = ray.util.get_placement_group(global_placement_group_name)
    assert placement_group is not None
    assert placement_group.wait(5)
    actor = Actor.options(
        placement_group=placement_group,
        placement_group_bundle_index=0).remote()

    ray.get(actor.ping.remote())

    # Create another placement group with the same name and make sure its
    # creation fails.
    same_name_pg = ray.util.placement_group(
        [{
            "CPU": 1
        } for _ in range(2)],
        strategy="STRICT_SPREAD",
        name=global_placement_group_name)
    assert not same_name_pg.wait(10)

    # Remove the named placement group and make sure the second creation
    # succeeds.
    ray.util.remove_placement_group(placement_group)
    same_name_pg = ray.util.placement_group(
        [{
            "CPU": 1
        } for _ in range(2)],
        strategy="STRICT_SPREAD",
        name=global_placement_group_name)
    assert same_name_pg.wait(10)

    # Looking up a name that doesn't exist must raise ValueError.
    with pytest.raises(ValueError):
        ray.util.get_placement_group("inexistent_pg")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
|
||||
@@ -4,7 +4,8 @@ from ray.util.check_serialize import inspect_serializability
|
||||
from ray.util.debug import log_once, disable_log_once_globally, \
|
||||
enable_periodic_logging
|
||||
from ray.util.placement_group import (placement_group, placement_group_table,
|
||||
remove_placement_group)
|
||||
remove_placement_group,
|
||||
get_placement_group)
|
||||
from ray.util import rpdb as pdb
|
||||
from ray.util.serialization import register_serializer, deregister_serializer
|
||||
|
||||
@@ -19,6 +20,7 @@ __all__ = [
|
||||
"pdb",
|
||||
"placement_group",
|
||||
"placement_group_table",
|
||||
"get_placement_group",
|
||||
"remove_placement_group",
|
||||
"inspect_serializability",
|
||||
"collective",
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import (List, Dict, Optional, Union)
|
||||
|
||||
import ray
|
||||
from ray._raylet import PlacementGroupID, ObjectRef
|
||||
from ray.utils import hex_to_binary
|
||||
|
||||
bundle_reservation_check = None
|
||||
|
||||
@@ -145,7 +146,7 @@ class PlacementGroup:
|
||||
|
||||
def placement_group(bundles: List[Dict[str, float]],
|
||||
strategy: str = "PACK",
|
||||
name: str = "unnamed_group",
|
||||
name: str = "",
|
||||
lifetime=None) -> PlacementGroup:
|
||||
"""Asynchronously creates a PlacementGroup.
|
||||
|
||||
@@ -211,6 +212,29 @@ def remove_placement_group(placement_group: PlacementGroup):
|
||||
worker.core_worker.remove_placement_group(placement_group.id)
|
||||
|
||||
|
||||
def get_placement_group(placement_group_name: str) -> "PlacementGroup":
    """Get a placement group object with a global name.

    Args:
        placement_group_name: Non-empty name of the placement group.

    Returns:
        The placement group object built from the ID stored under the
        given name.

    Raises:
        ValueError: If the name is empty, or if no placement group with
            the given name can be found.
    """
    # Reject empty names up front, before touching the worker or state.
    if not placement_group_name:
        raise ValueError(
            "Please supply a non-empty value to get_placement_group")
    worker = ray.worker.global_worker
    worker.check_connected()
    placement_group_info = ray.state.state.get_placement_group_by_name(
        placement_group_name)
    if placement_group_info is None:
        raise ValueError(
            "Failed to look up placement group with name: "
            f"{placement_group_name}")
    # The state table stores the ID as hex; convert it back to binary.
    return PlacementGroup(
        PlacementGroupID(
            hex_to_binary(placement_group_info["placement_group_id"])))
|
||||
|
||||
|
||||
def placement_group_table(placement_group: PlacementGroup = None) -> list:
|
||||
"""Get the state of the placement group from GCS.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user