[Parallel Iterators] Allow for operator chaining after repartition (#7268)

* bug fix repartition

* change add_transform from private to inner

* formatting

* addressing comments

* formatting
This commit is contained in:
Eric Liang
2020-03-04 14:42:52 -08:00
committed by GitHub
parent c7f0b303f3
commit 476b5c6196
3 changed files with 48 additions and 45 deletions
+11 -4
View File
@@ -161,15 +161,16 @@ def test_local_shuffle(ray_start_regular_shared):
def test_repartition_less(ray_start_regular_shared):
it = from_range(9, num_shards=3)
it1 = it.repartition(2)
# chaining operations after a repartition should work
it1 = it.repartition(2).for_each(lambda x: 2 * x)
assert repr(it1) == ("ParallelIterator[from_range[9, " +
"shards=3].repartition[num_partitions=2]]")
"shards=3].repartition[num_partitions=2].for_each()]")
assert it1.num_shards() == 2
shard_0_set = set(it1.get_shard(0))
shard_1_set = set(it1.get_shard(1))
assert shard_0_set == {0, 2, 3, 5, 6, 8}
assert shard_1_set == {1, 4, 7}
assert shard_0_set == {0, 4, 6, 10, 12, 16}
assert shard_1_set == {2, 8, 14}
def test_repartition_more(ray_start_regular_shared):
@@ -187,11 +188,17 @@ def test_repartition_consistent(ray_start_regular_shared):
# repartition should be deterministic
it1 = from_range(9, num_shards=1).repartition(2)
it2 = from_range(9, num_shards=1).repartition(2)
# union should work after repartition
it3 = it1.union(it2)
assert it1.num_shards() == 2
assert it2.num_shards() == 2
assert set(it1.get_shard(0)) == set(it2.get_shard(0))
assert set(it1.get_shard(1)) == set(it2.get_shard(1))
assert it3.num_shards() == 4
assert set(it3.gather_async()) == set(it1.gather_async()) | set(
it2.gather_async())
def test_batch(ray_start_regular_shared):
it = from_range(4, 1).batch(2)
+36 -39
View File
@@ -53,7 +53,11 @@ def from_range(n: int, num_shards: int = 2,
generators.append(range(start, end))
name = "from_range[{}, shards={}{}]".format(
n, num_shards, ", repeat=True" if repeat else "")
return from_iterators(generators, repeat=repeat, name=name)
return from_iterators(
generators,
repeat=repeat,
name=name,
)
def from_iterators(generators: List[Iterable[T]],
@@ -99,7 +103,7 @@ def from_actors(actors: List["ray.actor.ActorHandle"],
"""
if not name:
name = "from_actors[shards={}]".format(len(actors))
return ParallelIterator([_ActorSet(actors, [])], name)
return ParallelIterator([_ActorSet(actors, [])], name, parent_iterators=[])
class ParallelIterator(Generic[T]):
@@ -151,13 +155,17 @@ class ParallelIterator(Generic[T]):
... [worker_1_result_2, worker_2_result_2]
"""
def __init__(self, actor_sets: List["_ActorSet"], name: str):
def __init__(self, actor_sets: List["_ActorSet"], name: str,
parent_iterators: List["ParallelIterator[Any]"]):
"""Create a parallel iterator (this is an internal function)."""
# We track multiple sets of actors to support parallel .union().
self.actor_sets = actor_sets
self.name = name
# Keep explicit references to the parent iterators (e.g. the iterator a
# repartition was created from) so that the actors they hold do not die.
self.parent_iterators = parent_iterators
def __iter__(self):
raise TypeError(
"You must use it.gather_sync() or it.gather_async() to "
@@ -169,6 +177,13 @@ class ParallelIterator(Generic[T]):
def __repr__(self):
return "ParallelIterator[{}]".format(self.name)
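def _with_transform(self, local_it_fn, name):
"""Helper that returns a new ParallelIterator with `local_it_fn` applied to every actor set, `name` appended to this iterator's name, and the parent iterators carried over."""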
return ParallelIterator(
[a.with_transform(local_it_fn) for a in self.actor_sets],
name=self.name + name,
parent_iterators=self.parent_iterators)
def for_each(self, fn: Callable[[T], U]) -> "ParallelIterator[U]":
"""Remotely apply fn to each item in this iterator.
@@ -179,12 +194,8 @@ class ParallelIterator(Generic[T]):
>>> next(from_range(4).for_each(lambda x: x * 2).gather_sync())
... [0, 2, 4, 6]
"""
return ParallelIterator(
[
a.with_transform(lambda local_it: local_it.for_each(fn))
for a in self.actor_sets
],
name=self.name + ".for_each()")
return self._with_transform(lambda local_it: local_it.for_each(fn),
".for_each()")
def filter(self, fn: Callable[[T], bool]) -> "ParallelIterator[T]":
"""Remotely filter items from this iterator.
@@ -197,12 +208,8 @@ class ParallelIterator(Generic[T]):
>>> next(it.gather_sync())
... [1, 2]
"""
return ParallelIterator(
[
a.with_transform(lambda local_it: local_it.filter(fn))
for a in self.actor_sets
],
name=self.name + ".filter()")
return self._with_transform(lambda local_it: local_it.filter(fn),
".filter()")
def batch(self, n: int) -> "ParallelIterator[List[T]]":
"""Remotely batch together items in this iterator.
@@ -214,12 +221,8 @@ class ParallelIterator(Generic[T]):
>>> next(from_range(10, 1).batch(4).gather_sync())
... [0, 1, 2, 3]
"""
return ParallelIterator(
[
a.with_transform(lambda local_it: local_it.batch(n))
for a in self.actor_sets
],
name=self.name + ".batch({})".format(n))
return self._with_transform(lambda local_it: local_it.batch(n),
".batch({})".format(n))
def flatten(self) -> "ParallelIterator[T[0]]":
"""Flatten batches of items into individual items.
@@ -228,12 +231,8 @@ class ParallelIterator(Generic[T]):
>>> next(from_range(10, 1).batch(4).flatten().gather_sync())
... 0
"""
return ParallelIterator(
[
a.with_transform(lambda local_it: local_it.flatten())
for a in self.actor_sets
],
name=self.name + ".flatten()")
return self._with_transform(lambda local_it: local_it.flatten(),
".flatten()")
def combine(self, fn: Callable[[T], List[U]]) -> "ParallelIterator[U]":
"""Transform and then combine items horizontally.
@@ -273,13 +272,8 @@ class ParallelIterator(Generic[T]):
>>> next(it)
1
"""
return ParallelIterator(
[
a.with_transform(
lambda localit: localit.shuffle(shuffle_buffer_size, seed))
for a in self.actor_sets
],
name=self.name +
return self._with_transform(
lambda local_it: local_it.shuffle(shuffle_buffer_size, seed),
".local_shuffle(shuffle_buffer_size={}, seed={})".format(
shuffle_buffer_size,
str(seed) if seed is not None else "None"))
@@ -356,10 +350,9 @@ class ParallelIterator(Generic[T]):
generators = [make_gen_i(s) for s in range(num_partitions)]
worker_cls = ray.remote(ParallelIteratorWorker)
actors = [worker_cls.remote(g, repeat=False) for g in generators]
x = ParallelIterator([_ActorSet(actors, [])], name)
# need explicit reference to self so actors in this instance do not die
x.parent_iterator = self
return x
return ParallelIterator(
[_ActorSet(actors, [])], name, parent_iterators=[self])
def gather_sync(self) -> "LocalIterator[T]":
"""Returns a local iterable for synchronous iteration.
@@ -491,8 +484,12 @@ class ParallelIterator(Generic[T]):
actor_sets = []
actor_sets.extend(self.actor_sets)
actor_sets.extend(other.actor_sets)
return ParallelIterator(actor_sets, "ParallelUnion[{}, {}]".format(
self, other))
# if one of these iterators is a result of a repartition, we need to
# keep an explicit reference to its parent iterator
return ParallelIterator(
actor_sets,
"ParallelUnion[{}, {}]".format(self, other),
parent_iterators=self.parent_iterators + other.parent_iterators)
def num_shards(self) -> int:
"""Return the number of worker actors backing this iterator."""