[Placement Group] Capture Child Task Part 1 (#10968)

* In progress.

* In progress.

* Done.

* Addressed code review.

* Increase timeout to make a test less flaky.

* Addressed code review.

* Addressed code review.
This commit is contained in:
SangBin Cho
2020-09-24 09:02:03 -07:00
committed by GitHub
parent 46a560e876
commit 5e6b887f2d
11 changed files with 197 additions and 16 deletions
+5
View File
@@ -162,6 +162,11 @@ remove_placement_group
.. autofunction:: ray.util.placement_group.remove_placement_group
get_current_placement_group
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: ray.util.placement_group.get_current_placement_group
Experimental APIs
-----------------
+5
View File
@@ -791,6 +791,11 @@ cdef class CoreWorker:
return ActorID(
CCoreWorkerProcess.GetCoreWorker().GetActorId().Binary())
def get_placement_group_id(self):
    """Return the current placement group id of this core worker.

    Asks the process-wide core worker for its current placement group id
    and wraps the binary form of that id in a Python ``PlacementGroupID``.

    Returns:
        PlacementGroupID: the id built from the core worker's binary id.
    """
    return PlacementGroupID(
        CCoreWorkerProcess.GetCoreWorker()
        .GetCurrentPlacementGroupId().Binary())
def set_webui_display(self, key, message):
    """Forward ``key`` and ``message`` to the core worker's SetWebuiDisplay."""
    CCoreWorkerProcess.GetCoreWorker().SetWebuiDisplay(key, message)
+1
View File
@@ -122,6 +122,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
CJobID GetCurrentJobId()
CTaskID GetCurrentTaskId()
CClientID GetCurrentNodeId()
CPlacementGroupID GetCurrentPlacementGroupId()
const CActorID &GetActorId()
void SetActorTitle(const c_string &title)
void SetWebuiDisplay(const c_string &key, const c_string &message)
+9
View File
@@ -45,6 +45,15 @@ class RuntimeContext(object):
actor_info = ray.state.actors(self.current_actor_id.hex())
return actor_info and actor_info["NumRestarts"] != 0
@property
def current_placement_group_id(self):
    """Placement group ID associated with the worker behind this context.

    Returns:
        Whatever ``self.worker.placement_group_id`` currently reports.
    """
    worker = self.worker
    return worker.placement_group_id
_runtime_context = None
+5 -1
View File
@@ -6,6 +6,7 @@ import time
import ray
from google.protobuf.json_format import MessageToDict
from ray import (
gcs_utils,
services,
@@ -423,7 +424,10 @@ class GlobalState:
placement_group_info.placement_group_id),
"name": placement_group_info.name,
"bundles": {
bundle.bundle_id.bundle_index: bundle.unit_resources
# The value here needs to be converted to a dict;
# otherwise, the payload becomes unserializable.
bundle.bundle_id.bundle_index:
MessageToDict(bundle)["unitResources"]
for bundle in placement_group_info.bundles
},
"strategy": get_strategy(placement_group_info.strategy),
+58 -3
View File
@@ -11,7 +11,8 @@ import ray
from ray.test_utils import get_other_nodes, wait_for_condition
import ray.cluster_utils
from ray._raylet import PlacementGroupID
from ray.util.placement_group import PlacementGroup
from ray.util.placement_group import (PlacementGroup,
get_current_placement_group)
def test_placement_group_pack(ray_start_cluster):
@@ -288,7 +289,7 @@ def test_remove_placement_group(ray_start_cluster):
# First try to remove a placement group that doesn't
# exist. This should not do anything.
random_group_id = PlacementGroupID.from_random()
random_placement_group = PlacementGroup(random_group_id, [{"CPU": 1}])
random_placement_group = PlacementGroup(random_group_id)
for _ in range(3):
ray.util.remove_placement_group(random_placement_group)
@@ -591,7 +592,7 @@ def test_placement_group_wait(ray_start_cluster):
assert table["state"] == "CREATED"
pg = ray.get(placement_group.ready())
assert pg.bundles == placement_group.bundles
assert pg.bundle_specs == placement_group.bundle_specs
assert pg.id.binary() == placement_group.id.binary()
@@ -791,5 +792,59 @@ def test_mini_integration(ray_start_cluster):
assert all(ray.get([a.ping.remote() for a in actors]))
def test_capture_child_tasks(ray_start_cluster):
    """Actors spawned from inside a placement-group actor can join its group.

    A STRICT_PACK group with two 2-CPU bundles is created on a two-node
    cluster. One actor is started in the group and then starts three more
    actors using ``get_current_placement_group()``; all of them are expected
    to be reported on a single node.
    """
    cluster = ray_start_cluster
    total_num_actors = 4
    for _ in range(2):
        cluster.add_node(num_cpus=total_num_actors)
    ray.init(address=cluster.address)

    bundles = [{"CPU": 2}, {"CPU": 2}]
    pg = ray.util.placement_group(bundles, strategy="STRICT_PACK")
    ray.get(pg.ready(), timeout=5)

    # The driver was not started inside any placement group, so the lookup
    # must report None here.
    assert get_current_placement_group() is None

    @ray.remote(num_cpus=1)
    class NestedActor:
        def ready(self):
            return True

    @ray.remote(num_cpus=1)
    class Actor:
        def __init__(self):
            self.actors = []

        def ready(self):
            return True

        def schedule_nested_actor(self):
            # Reuse the placement group this actor itself is running in.
            current_pg = get_current_placement_group()
            actor = NestedActor.options(placement_group=current_pg).remote()
            ray.get(actor.ready.remote())
            self.actors.append(actor)

    a = Actor.options(placement_group=pg).remote()
    ray.get(a.ready.remote())

    # One top-level actor plus three children.
    num_children = total_num_actors - 1
    for _ in range(num_children):
        ray.get(a.schedule_nested_actor.remote())

    # STRICT_PACK means every actor should be reported on the same node,
    # so collecting the node ids must yield exactly one distinct value.
    node_ids = {
        actor_info["Address"]["NodeID"]
        for actor_info in ray.actors().values()
    }
    assert len(node_ids) == 1
if __name__ == "__main__":
sys.exit(pytest.main(["-v", __file__]))
+75 -12
View File
@@ -1,4 +1,6 @@
from typing import (List, Dict)
import time
from typing import (List, Dict, Optional)
import ray
from ray._raylet import PlacementGroupID, ObjectRef
@@ -9,11 +11,11 @@ class PlacementGroup:
@staticmethod
def empty():
return PlacementGroup(PlacementGroupID.nil(), [])
return PlacementGroup(PlacementGroupID.nil())
def __init__(self, id: PlacementGroupID, bundles: List[Dict[str, float]]):
def __init__(self, id: PlacementGroupID):
self.id = id
self.bundles = bundles
self.bundle_cache = None
def ready(self) -> ObjectRef:
"""Returns an ObjectRef to check ready status.
@@ -29,22 +31,22 @@ class PlacementGroup:
>>> pg = placement_group([{"CPU": 1}])
ray.wait([pg.ready()], timeout=0)
"""
worker = ray.worker.global_worker
worker.check_connected()
self._fill_bundle_cache_if_needed()
@ray.remote(num_cpus=0, max_calls=0)
def bundle_reservation_check(placement_group):
return placement_group
assert len(self.bundles) != 0, (
assert len(self.bundle_cache) != 0, (
"ready() cannot be called on placement group object with a "
f"bundle length == 0, current bundle length: {len(self.bundles)}")
"bundle length == 0, current bundle length: "
f"{len(self.bundle_cache)}")
# Select the first bundle to schedule a dummy task.
# Since the placement group creation will be atomic, it is sufficient
# to schedule a single task.
bundle_index = 0
bundle = self.bundles[bundle_index]
bundle = self.bundle_cache[bundle_index]
resource_name, value = self._get_none_zero_resource(bundle)
num_cpus = 0
@@ -67,11 +69,13 @@ class PlacementGroup:
@property
def bundle_specs(self) -> List[Dict]:
"""List[Dict]: Return bundles belonging to this placement group."""
return self.bundles
self._fill_bundle_cache_if_needed()
return self.bundle_cache
@property
def bundle_count(self):
return len(self.bundles)
self._fill_bundle_cache_if_needed()
return len(self.bundle_cache)
def _get_none_zero_resource(self, bundle: List[Dict]):
for key, value in bundle.items():
@@ -80,6 +84,30 @@ class PlacementGroup:
return key, value
assert False, "This code should be unreachable."
def _fill_bundle_cache_if_needed(self):
if not self.bundle_cache:
# Since creating placement group is async, it is
# possible table is not ready yet. To avoid the
# problem, we should keep trying with timeout.
TIMEOUT_SECOND = 30
WAIT_INTERVAL = 0.05
timeout_cnt = 0
worker = ray.worker.global_worker
worker.check_connected()
while timeout_cnt < int(TIMEOUT_SECOND / WAIT_INTERVAL):
pg_info = ray.state.state.placement_group_table(self.id)
if pg_info:
self.bundle_cache = list(pg_info["bundles"].values())
return
time.sleep(WAIT_INTERVAL)
timeout_cnt += 1
raise RuntimeError(
"Couldn't get the bundle information of placement group id "
f"{self.id} in {TIMEOUT_SECOND} seconds. It is likely "
"because GCS server is too busy.")
def placement_group(bundles: List[Dict[str, float]],
strategy: str = "PACK",
@@ -120,7 +148,7 @@ def placement_group(bundles: List[Dict[str, float]],
placement_group_id = worker.core_worker.create_placement_group(
name, bundles, strategy)
return PlacementGroup(placement_group_id, bundles)
return PlacementGroup(placement_group_id)
def remove_placement_group(placement_group: PlacementGroup):
@@ -149,6 +177,41 @@ def placement_group_table(placement_group: PlacementGroup) -> dict:
return ray.state.state.placement_group_table(placement_group.id)
def get_current_placement_group() -> Optional[PlacementGroup]:
    """Look up the placement group the calling task or actor runs in.

    The placement group id is read from the runtime context. A nil id means
    the caller has no placement group, in which case None is returned. For
    example, calling this from a driver returns None (because drivers never
    belong to any placement group).

    Examples:

        >>> @ray.remote
        >>> def f():
        >>>     # Returns the placement group that task f belongs to,
        >>>     # i.e. the same group as the pg created below.
        >>>     pg = get_current_placement_group()
        >>> pg = placement_group([{"CPU": 2}])
        >>> f.options(placement_group=pg).remote()

        >>> # New script.
        >>> ray.init()
        >>> # A new script doesn't belong to any placement group,
        >>> # so None is returned.
        >>> assert get_current_placement_group() is None

    Return:
        PlacementGroup: Placement group object for the current id, or
            None if the current task or actor wasn't created with any
            placement group.
    """
    runtime_context = ray.runtime_context.get_runtime_context()
    pg_id = runtime_context.current_placement_group_id
    return None if pg_id.is_nil() else PlacementGroup(pg_id)
def check_placement_group_index(placement_group: PlacementGroup,
bundle_index: int):
assert placement_group is not None
+4
View File
@@ -155,6 +155,10 @@ class Worker:
def current_task_id(self):
return self.core_worker.get_current_task_id()
@property
def placement_group_id(self):
    """Placement group id reported by this worker's core worker."""
    core_worker = self.core_worker
    return core_worker.get_placement_group_id()
@property
def current_session_and_job(self):
"""Get the current session index and job id as pair."""
+27
View File
@@ -48,10 +48,19 @@ struct WorkerThreadContext {
/// Record `task_id` as the current task id of this thread context.
void SetCurrentTaskId(const TaskID &task_id) { current_task_id_ = task_id; }
/// Return the placement group id stored in this thread context.
const PlacementGroupID &GetCurrentPlacementGroupId() const {
  return current_placement_group_id_;
}
/// Store `placement_group_id` as the placement group id of this thread context.
void SetCurrentPlacementGroupId(const PlacementGroupID &placement_group_id) {
  current_placement_group_id_ = placement_group_id;
}
/// Make `task_spec` the current task of this thread context.
///
/// Copies the task id and the placement group id out of the spec and keeps a
/// shared, const copy of the spec itself in `current_task_`.
void SetCurrentTask(const TaskSpecification &task_spec) {
  // Both per-task counters are checked to be zero before a task is installed.
  RAY_CHECK(task_index_ == 0);
  RAY_CHECK(put_counter_ == 0);
  SetCurrentTaskId(task_spec.TaskId());
  SetCurrentPlacementGroupId(task_spec.PlacementGroupId());
  current_task_ = std::make_shared<const TaskSpecification>(task_spec);
}
@@ -74,6 +83,13 @@ struct WorkerThreadContext {
/// A running counter for the number of object puts carried out in the current task.
/// Used to calculate the object index for put object ObjectIDs.
int put_counter_;
/// Placement group id that the current task belongs to.
/// NOTE: The top level `WorkerContext` also has a placement group id,
/// which is set when actors are created. This is because we'd like to keep track of
/// the thread-local placement group id for tasks, and the process-wide placement
/// group id for actors.
PlacementGroupID current_placement_group_id_;
};
thread_local std::unique_ptr<WorkerThreadContext> WorkerContext::thread_context_ =
@@ -85,6 +101,7 @@ WorkerContext::WorkerContext(WorkerType worker_type, const WorkerID &worker_id,
worker_id_(worker_id),
current_job_id_(worker_type_ == WorkerType::DRIVER ? job_id : JobID::Nil()),
current_actor_id_(ActorID::Nil()),
current_actor_placement_group_id_(PlacementGroupID::Nil()),
main_thread_id_(boost::this_thread::get_id()) {
// For worker main thread which initializes the WorkerContext,
// set task_id according to whether current worker is a driver.
@@ -108,6 +125,15 @@ const TaskID &WorkerContext::GetCurrentTaskID() const {
return GetThreadContext().GetCurrentTaskID();
}
const PlacementGroupID &WorkerContext::GetCurrentPlacementGroupId() const {
  // Without an actor id, the placement group id comes from the
  // thread-local context.
  if (current_actor_id_ == ActorID::Nil()) {
    return GetThreadContext().GetCurrentPlacementGroupId();
  }
  // The worker is an actor: report the placement group id recorded for the actor.
  return current_actor_placement_group_id_;
}
/// Overwrite the job id tracked by this worker context with `job_id`.
void WorkerContext::SetCurrentJobId(const JobID &job_id) { current_job_id_ = job_id; }
void WorkerContext::SetCurrentTaskId(const TaskID &task_id) {
@@ -128,6 +154,7 @@ void WorkerContext::SetCurrentTask(const TaskSpecification &task_spec) {
current_actor_is_direct_call_ = true;
current_actor_max_concurrency_ = task_spec.MaxActorConcurrency();
current_actor_is_asyncio_ = task_spec.IsAsyncioActor();
current_actor_placement_group_id_ = task_spec.PlacementGroupId();
} else if (task_spec.IsActorTask()) {
RAY_CHECK(current_job_id_ == task_spec.JobId());
RAY_CHECK(current_actor_id_ == task_spec.ActorId());
+4
View File
@@ -35,6 +35,8 @@ class WorkerContext {
const TaskID &GetCurrentTaskID() const;
const PlacementGroupID &GetCurrentPlacementGroupId() const;
// TODO(edoakes): remove this once Python core worker uses the task interfaces.
void SetCurrentJobId(const JobID &job_id);
@@ -84,6 +86,8 @@ class WorkerContext {
ActorID current_actor_id_;
int current_actor_max_concurrency_ = 1;
bool current_actor_is_asyncio_ = false;
// The placement group id that the current actor belongs to.
PlacementGroupID current_actor_placement_group_id_;
/// The id of the (main) thread that constructed this worker context.
boost::thread::id main_thread_id_;
+4
View File
@@ -354,6 +354,10 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
return ClientID::FromBinary(rpc_address_.raylet_id());
}
/// Return the current placement group id as reported by this worker's
/// `WorkerContext`.
const PlacementGroupID &GetCurrentPlacementGroupId() const {
  return worker_context_.GetCurrentPlacementGroupId();
}
void SetWebuiDisplay(const std::string &key, const std::string &message);
void SetActorTitle(const std::string &title);