mirror of
https://github.com/wassname/ray.git
synced 2026-07-06 00:31:32 +08:00
[rllib] Distributed exec workflow for impala (#8321)
This commit is contained in:
@@ -535,6 +535,29 @@ class ParallelIterator(Generic[T]):
|
||||
"ParallelUnion[{}, {}]".format(self, other),
|
||||
parent_iterators=self.parent_iterators + other.parent_iterators)
|
||||
|
||||
def select_shards(self,
|
||||
shards_to_keep: List[int]) -> "ParallelIterator[T]":
|
||||
"""Return a child iterator that only iterates over given shards.
|
||||
|
||||
It is the user's responsibility to ensure child iterators are operating
|
||||
over disjoint sub-sets of this iterator's shards.
|
||||
"""
|
||||
if len(self.actor_sets) > 1:
|
||||
raise ValueError("select_shards() is not allowed after union()")
|
||||
if len(shards_to_keep) == 0:
|
||||
raise ValueError("at least one shard must be selected")
|
||||
old_actor_set = self.actor_sets[0]
|
||||
new_actors = [
|
||||
a for (i, a) in enumerate(old_actor_set.actors)
|
||||
if i in shards_to_keep
|
||||
]
|
||||
assert len(new_actors) == len(shards_to_keep), "Invalid actor index"
|
||||
new_actor_set = _ActorSet(new_actors, old_actor_set.transforms)
|
||||
return ParallelIterator(
|
||||
[new_actor_set],
|
||||
"{}.select_shards({} total)".format(self, len(shards_to_keep)),
|
||||
parent_iterators=self.parent_iterators)
|
||||
|
||||
def num_shards(self) -> int:
|
||||
"""Return the number of worker actors backing this iterator."""
|
||||
return sum(len(a.actors) for a in self.actor_sets)
|
||||
|
||||
Reference in New Issue
Block a user