mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 00:35:15 +08:00
[Parallel Iterators] Allow for operator chaining after repartition (#7268)
* bug fix repartition * change add_transform from private to inner * formatting * addressing comments * formatting
This commit is contained in:
@@ -161,15 +161,16 @@ def test_local_shuffle(ray_start_regular_shared):
|
||||
|
||||
def test_repartition_less(ray_start_regular_shared):
|
||||
it = from_range(9, num_shards=3)
|
||||
it1 = it.repartition(2)
|
||||
# chaining operations after a repartition should work
|
||||
it1 = it.repartition(2).for_each(lambda x: 2 * x)
|
||||
assert repr(it1) == ("ParallelIterator[from_range[9, " +
|
||||
"shards=3].repartition[num_partitions=2]]")
|
||||
"shards=3].repartition[num_partitions=2].for_each()]")
|
||||
|
||||
assert it1.num_shards() == 2
|
||||
shard_0_set = set(it1.get_shard(0))
|
||||
shard_1_set = set(it1.get_shard(1))
|
||||
assert shard_0_set == {0, 2, 3, 5, 6, 8}
|
||||
assert shard_1_set == {1, 4, 7}
|
||||
assert shard_0_set == {0, 4, 6, 10, 12, 16}
|
||||
assert shard_1_set == {2, 8, 14}
|
||||
|
||||
|
||||
def test_repartition_more(ray_start_regular_shared):
|
||||
@@ -187,11 +188,17 @@ def test_repartition_consistent(ray_start_regular_shared):
|
||||
# repartition should be deterministic
|
||||
it1 = from_range(9, num_shards=1).repartition(2)
|
||||
it2 = from_range(9, num_shards=1).repartition(2)
|
||||
# union should work after repartition
|
||||
it3 = it1.union(it2)
|
||||
assert it1.num_shards() == 2
|
||||
assert it2.num_shards() == 2
|
||||
assert set(it1.get_shard(0)) == set(it2.get_shard(0))
|
||||
assert set(it1.get_shard(1)) == set(it2.get_shard(1))
|
||||
|
||||
assert it3.num_shards() == 4
|
||||
assert set(it3.gather_async()) == set(it1.gather_async()) | set(
|
||||
it2.gather_async())
|
||||
|
||||
|
||||
def test_batch(ray_start_regular_shared):
|
||||
it = from_range(4, 1).batch(2)
|
||||
|
||||
+36
-39
@@ -53,7 +53,11 @@ def from_range(n: int, num_shards: int = 2,
|
||||
generators.append(range(start, end))
|
||||
name = "from_range[{}, shards={}{}]".format(
|
||||
n, num_shards, ", repeat=True" if repeat else "")
|
||||
return from_iterators(generators, repeat=repeat, name=name)
|
||||
return from_iterators(
|
||||
generators,
|
||||
repeat=repeat,
|
||||
name=name,
|
||||
)
|
||||
|
||||
|
||||
def from_iterators(generators: List[Iterable[T]],
|
||||
@@ -99,7 +103,7 @@ def from_actors(actors: List["ray.actor.ActorHandle"],
|
||||
"""
|
||||
if not name:
|
||||
name = "from_actors[shards={}]".format(len(actors))
|
||||
return ParallelIterator([_ActorSet(actors, [])], name)
|
||||
return ParallelIterator([_ActorSet(actors, [])], name, parent_iterators=[])
|
||||
|
||||
|
||||
class ParallelIterator(Generic[T]):
|
||||
@@ -151,13 +155,17 @@ class ParallelIterator(Generic[T]):
|
||||
... [worker_1_result_2, worker_2_result_2]
|
||||
"""
|
||||
|
||||
def __init__(self, actor_sets: List["_ActorSet"], name: str):
|
||||
def __init__(self, actor_sets: List["_ActorSet"], name: str,
|
||||
parent_iterators: List["ParallelIterator[Any]"]):
|
||||
"""Create a parallel iterator (this is an internal function)."""
|
||||
|
||||
# We track multiple sets of actors to support parallel .union().
|
||||
self.actor_sets = actor_sets
|
||||
self.name = name
|
||||
|
||||
# keep explicit reference to parent iterator for repartition
|
||||
self.parent_iterators = parent_iterators
|
||||
|
||||
def __iter__(self):
|
||||
raise TypeError(
|
||||
"You must use it.gather_sync() or it.gather_async() to "
|
||||
@@ -169,6 +177,13 @@ class ParallelIterator(Generic[T]):
|
||||
def __repr__(self):
|
||||
return "ParallelIterator[{}]".format(self.name)
|
||||
|
||||
def _with_transform(self, local_it_fn, name):
|
||||
"""Helper function to create new Parallel Iterator"""
|
||||
return ParallelIterator(
|
||||
[a.with_transform(local_it_fn) for a in self.actor_sets],
|
||||
name=self.name + name,
|
||||
parent_iterators=self.parent_iterators)
|
||||
|
||||
def for_each(self, fn: Callable[[T], U]) -> "ParallelIterator[U]":
|
||||
"""Remotely apply fn to each item in this iterator.
|
||||
|
||||
@@ -179,12 +194,8 @@ class ParallelIterator(Generic[T]):
|
||||
>>> next(from_range(4).for_each(lambda x: x * 2).gather_sync())
|
||||
... [0, 2, 4, 8]
|
||||
"""
|
||||
return ParallelIterator(
|
||||
[
|
||||
a.with_transform(lambda local_it: local_it.for_each(fn))
|
||||
for a in self.actor_sets
|
||||
],
|
||||
name=self.name + ".for_each()")
|
||||
return self._with_transform(lambda local_it: local_it.for_each(fn),
|
||||
".for_each()")
|
||||
|
||||
def filter(self, fn: Callable[[T], bool]) -> "ParallelIterator[T]":
|
||||
"""Remotely filter items from this iterator.
|
||||
@@ -197,12 +208,8 @@ class ParallelIterator(Generic[T]):
|
||||
>>> next(it.gather_sync())
|
||||
... [1, 2]
|
||||
"""
|
||||
return ParallelIterator(
|
||||
[
|
||||
a.with_transform(lambda local_it: local_it.filter(fn))
|
||||
for a in self.actor_sets
|
||||
],
|
||||
name=self.name + ".filter()")
|
||||
return self._with_transform(lambda local_it: local_it.filter(fn),
|
||||
".filter()")
|
||||
|
||||
def batch(self, n: int) -> "ParallelIterator[List[T]]":
|
||||
"""Remotely batch together items in this iterator.
|
||||
@@ -214,12 +221,8 @@ class ParallelIterator(Generic[T]):
|
||||
>>> next(from_range(10, 1).batch(4).gather_sync())
|
||||
... [0, 1, 2, 3]
|
||||
"""
|
||||
return ParallelIterator(
|
||||
[
|
||||
a.with_transform(lambda local_it: local_it.batch(n))
|
||||
for a in self.actor_sets
|
||||
],
|
||||
name=self.name + ".batch({})".format(n))
|
||||
return self._with_transform(lambda local_it: local_it.batch(n),
|
||||
".batch({})".format(n))
|
||||
|
||||
def flatten(self) -> "ParallelIterator[T[0]]":
|
||||
"""Flatten batches of items into individual items.
|
||||
@@ -228,12 +231,8 @@ class ParallelIterator(Generic[T]):
|
||||
>>> next(from_range(10, 1).batch(4).flatten())
|
||||
... 0
|
||||
"""
|
||||
return ParallelIterator(
|
||||
[
|
||||
a.with_transform(lambda local_it: local_it.flatten())
|
||||
for a in self.actor_sets
|
||||
],
|
||||
name=self.name + ".flatten()")
|
||||
return self._with_transform(lambda local_it: local_it.flatten(),
|
||||
".flatten()")
|
||||
|
||||
def combine(self, fn: Callable[[T], List[U]]) -> "ParallelIterator[U]":
|
||||
"""Transform and then combine items horizontally.
|
||||
@@ -273,13 +272,8 @@ class ParallelIterator(Generic[T]):
|
||||
>>> next(it)
|
||||
1
|
||||
"""
|
||||
return ParallelIterator(
|
||||
[
|
||||
a.with_transform(
|
||||
lambda localit: localit.shuffle(shuffle_buffer_size, seed))
|
||||
for a in self.actor_sets
|
||||
],
|
||||
name=self.name +
|
||||
return self._with_transform(
|
||||
lambda local_it: local_it.shuffle(shuffle_buffer_size, seed),
|
||||
".local_shuffle(shuffle_buffer_size={}, seed={})".format(
|
||||
shuffle_buffer_size,
|
||||
str(seed) if seed is not None else "None"))
|
||||
@@ -356,10 +350,9 @@ class ParallelIterator(Generic[T]):
|
||||
generators = [make_gen_i(s) for s in range(num_partitions)]
|
||||
worker_cls = ray.remote(ParallelIteratorWorker)
|
||||
actors = [worker_cls.remote(g, repeat=False) for g in generators]
|
||||
x = ParallelIterator([_ActorSet(actors, [])], name)
|
||||
# need explicit reference to self so actors in this instance do not die
|
||||
x.parent_iterator = self
|
||||
return x
|
||||
return ParallelIterator(
|
||||
[_ActorSet(actors, [])], name, parent_iterators=[self])
|
||||
|
||||
def gather_sync(self) -> "LocalIterator[T]":
|
||||
"""Returns a local iterable for synchronous iteration.
|
||||
@@ -491,8 +484,12 @@ class ParallelIterator(Generic[T]):
|
||||
actor_sets = []
|
||||
actor_sets.extend(self.actor_sets)
|
||||
actor_sets.extend(other.actor_sets)
|
||||
return ParallelIterator(actor_sets, "ParallelUnion[{}, {}]".format(
|
||||
self, other))
|
||||
# if one of these iterators is a result of a repartition, we need to
|
||||
# keep an explicit reference to its parent iterator
|
||||
return ParallelIterator(
|
||||
actor_sets,
|
||||
"ParallelUnion[{}, {}]".format(self, other),
|
||||
parent_iterators=self.parent_iterators + other.parent_iterators)
|
||||
|
||||
def num_shards(self) -> int:
|
||||
"""Return the number of worker actors backing this iterator."""
|
||||
|
||||
Reference in New Issue
Block a user