mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 22:53:49 +08:00
[rllib] Distributed exec workflow for impala (#8321)
This commit is contained in:
@@ -9,6 +9,14 @@ from ray.util.iter import from_items, from_iterators, from_range, \
|
||||
from ray.test_utils import Semaphore
|
||||
|
||||
|
||||
def test_select_shards(ray_start_regular_shared):
|
||||
it = from_items([1, 2, 3, 4], num_shards=4)
|
||||
it1 = it.select_shards([0, 2])
|
||||
it2 = it.select_shards([1, 3])
|
||||
assert it1.take(4) == [1, 3]
|
||||
assert it2.take(4) == [2, 4]
|
||||
|
||||
|
||||
def test_metrics(ray_start_regular_shared):
|
||||
it = from_items([1, 2, 3, 4], num_shards=1)
|
||||
it2 = from_items([1, 2, 3, 4], num_shards=1)
|
||||
|
||||
@@ -535,6 +535,29 @@ class ParallelIterator(Generic[T]):
|
||||
"ParallelUnion[{}, {}]".format(self, other),
|
||||
parent_iterators=self.parent_iterators + other.parent_iterators)
|
||||
|
||||
def select_shards(self,
|
||||
shards_to_keep: List[int]) -> "ParallelIterator[T]":
|
||||
"""Return a child iterator that only iterates over given shards.
|
||||
|
||||
It is the user's responsibility to ensure child iterators are operating
|
||||
over disjoint sub-sets of this iterator's shards.
|
||||
"""
|
||||
if len(self.actor_sets) > 1:
|
||||
raise ValueError("select_shards() is not allowed after union()")
|
||||
if len(shards_to_keep) == 0:
|
||||
raise ValueError("at least one shard must be selected")
|
||||
old_actor_set = self.actor_sets[0]
|
||||
new_actors = [
|
||||
a for (i, a) in enumerate(old_actor_set.actors)
|
||||
if i in shards_to_keep
|
||||
]
|
||||
assert len(new_actors) == len(shards_to_keep), "Invalid actor index"
|
||||
new_actor_set = _ActorSet(new_actors, old_actor_set.transforms)
|
||||
return ParallelIterator(
|
||||
[new_actor_set],
|
||||
"{}.select_shards({} total)".format(self, len(shards_to_keep)),
|
||||
parent_iterators=self.parent_iterators)
|
||||
|
||||
def num_shards(self) -> int:
|
||||
"""Return the number of worker actors backing this iterator."""
|
||||
return sum(len(a.actors) for a in self.actor_sets)
|
||||
|
||||
Reference in New Issue
Block a user