From 476b5c6196fa734794e395a53d2506e7c8485d12 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 4 Mar 2020 14:42:52 -0800 Subject: [PATCH] [Parallel Iterators] Allow for operator chaining after repartition (#7268) * bug fix repartition * change add_transform from private to inner * formatting * addressing comments * formatting --- doc/source/iter.rst | 3 +- python/ray/tests/test_iter.py | 15 +++++-- python/ray/util/iter.py | 75 +++++++++++++++++------------------ 3 files changed, 48 insertions(+), 45 deletions(-) diff --git a/doc/source/iter.rst b/doc/source/iter.rst index 1ea42fbd5..d46d1a096 100644 --- a/doc/source/iter.rst +++ b/doc/source/iter.rst @@ -1,4 +1,4 @@ -Distributed Iterators +Parallel Iterators ===================== .. _`issue on GitHub`: https://github.com/ray-project/ray/issues @@ -204,4 +204,3 @@ API Reference .. automodule:: ray.util.iter :members: :show-inheritance: - :special-members: diff --git a/python/ray/tests/test_iter.py b/python/ray/tests/test_iter.py index 276b26a8c..fdac5167b 100644 --- a/python/ray/tests/test_iter.py +++ b/python/ray/tests/test_iter.py @@ -161,15 +161,16 @@ def test_local_shuffle(ray_start_regular_shared): def test_repartition_less(ray_start_regular_shared): it = from_range(9, num_shards=3) - it1 = it.repartition(2) + # chaining operations after a repartition should work + it1 = it.repartition(2).for_each(lambda x: 2 * x) assert repr(it1) == ("ParallelIterator[from_range[9, " + - "shards=3].repartition[num_partitions=2]]") + "shards=3].repartition[num_partitions=2].for_each()]") assert it1.num_shards() == 2 shard_0_set = set(it1.get_shard(0)) shard_1_set = set(it1.get_shard(1)) - assert shard_0_set == {0, 2, 3, 5, 6, 8} - assert shard_1_set == {1, 4, 7} + assert shard_0_set == {0, 4, 6, 10, 12, 16} + assert shard_1_set == {2, 8, 14} def test_repartition_more(ray_start_regular_shared): @@ -187,11 +188,17 @@ def test_repartition_consistent(ray_start_regular_shared): # repartition should be deterministic it1 = from_range(9, num_shards=1).repartition(2) it2 = from_range(9, num_shards=1).repartition(2) + # union should work after repartition + it3 = it1.union(it2) assert it1.num_shards() == 2 assert it2.num_shards() == 2 assert set(it1.get_shard(0)) == set(it2.get_shard(0)) assert set(it1.get_shard(1)) == set(it2.get_shard(1)) + assert it3.num_shards() == 4 + assert set(it3.gather_async()) == set(it1.gather_async()) | set( + it2.gather_async()) + def test_batch(ray_start_regular_shared): it = from_range(4, 1).batch(2) diff --git a/python/ray/util/iter.py b/python/ray/util/iter.py index c2455f93f..6c9583c88 100644 --- a/python/ray/util/iter.py +++ b/python/ray/util/iter.py @@ -53,7 +53,11 @@ def from_range(n: int, num_shards: int = 2, generators.append(range(start, end)) name = "from_range[{}, shards={}{}]".format( n, num_shards, ", repeat=True" if repeat else "") - return from_iterators(generators, repeat=repeat, name=name) + return from_iterators( + generators, + repeat=repeat, + name=name, + ) def from_iterators(generators: List[Iterable[T]], @@ -99,7 +103,7 @@ def from_actors(actors: List["ray.actor.ActorHandle"], """ if not name: name = "from_actors[shards={}]".format(len(actors)) - return ParallelIterator([_ActorSet(actors, [])], name) + return ParallelIterator([_ActorSet(actors, [])], name, parent_iterators=[]) class ParallelIterator(Generic[T]): @@ -151,13 +155,17 @@ class ParallelIterator(Generic[T]): ... [worker_1_result_2, worker_2_result_2] """ - def __init__(self, actor_sets: List["_ActorSet"], name: str): + def __init__(self, actor_sets: List["_ActorSet"], name: str, + parent_iterators: List["ParallelIterator[Any]"]): """Create a parallel iterator (this is an internal function).""" # We track multiple sets of actors to support parallel .union(). self.actor_sets = actor_sets self.name = name + # keep explicit reference to parent iterator for repartition + self.parent_iterators = parent_iterators + def __iter__(self): raise TypeError( "You must use it.gather_sync() or it.gather_async() to " @@ -169,6 +177,13 @@ class ParallelIterator(Generic[T]): def __repr__(self): return "ParallelIterator[{}]".format(self.name) + def _with_transform(self, local_it_fn, name): + """Helper function to create new Parallel Iterator""" + return ParallelIterator( + [a.with_transform(local_it_fn) for a in self.actor_sets], + name=self.name + name, + parent_iterators=self.parent_iterators) + def for_each(self, fn: Callable[[T], U]) -> "ParallelIterator[U]": """Remotely apply fn to each item in this iterator. @@ -179,12 +194,8 @@ class ParallelIterator(Generic[T]): >>> next(from_range(4).for_each(lambda x: x * 2).gather_sync()) ... [0, 2, 4, 8] """ - return ParallelIterator( - [ - a.with_transform(lambda local_it: local_it.for_each(fn)) - for a in self.actor_sets - ], - name=self.name + ".for_each()") + return self._with_transform(lambda local_it: local_it.for_each(fn), + ".for_each()") def filter(self, fn: Callable[[T], bool]) -> "ParallelIterator[T]": """Remotely filter items from this iterator. @@ -197,12 +208,8 @@ class ParallelIterator(Generic[T]): >>> next(it.gather_sync()) ... [1, 2] """ - return ParallelIterator( - [ - a.with_transform(lambda local_it: local_it.filter(fn)) - for a in self.actor_sets - ], - name=self.name + ".filter()") + return self._with_transform(lambda local_it: local_it.filter(fn), + ".filter()") def batch(self, n: int) -> "ParallelIterator[List[T]]": """Remotely batch together items in this iterator. @@ -214,12 +221,8 @@ class ParallelIterator(Generic[T]): >>> next(from_range(10, 1).batch(4).gather_sync()) ... [0, 1, 2, 3] """ - return ParallelIterator( - [ - a.with_transform(lambda local_it: local_it.batch(n)) - for a in self.actor_sets - ], - name=self.name + ".batch({})".format(n)) + return self._with_transform(lambda local_it: local_it.batch(n), + ".batch({})".format(n)) def flatten(self) -> "ParallelIterator[T[0]]": """Flatten batches of items into individual items. @@ -228,12 +231,8 @@ class ParallelIterator(Generic[T]): >>> next(from_range(10, 1).batch(4).flatten()) ... 0 """ - return ParallelIterator( - [ - a.with_transform(lambda local_it: local_it.flatten()) - for a in self.actor_sets - ], - name=self.name + ".flatten()") + return self._with_transform(lambda local_it: local_it.flatten(), + ".flatten()") def combine(self, fn: Callable[[T], List[U]]) -> "ParallelIterator[U]": """Transform and then combine items horizontally. @@ -273,13 +272,8 @@ class ParallelIterator(Generic[T]): >>> next(it) 1 """ - return ParallelIterator( - [ - a.with_transform( - lambda localit: localit.shuffle(shuffle_buffer_size, seed)) - for a in self.actor_sets - ], - name=self.name + + return self._with_transform( + lambda local_it: local_it.shuffle(shuffle_buffer_size, seed), ".local_shuffle(shuffle_buffer_size={}, seed={})".format( shuffle_buffer_size, str(seed) if seed is not None else "None")) @@ -356,10 +350,9 @@ class ParallelIterator(Generic[T]): generators = [make_gen_i(s) for s in range(num_partitions)] worker_cls = ray.remote(ParallelIteratorWorker) actors = [worker_cls.remote(g, repeat=False) for g in generators] - x = ParallelIterator([_ActorSet(actors, [])], name) # need explicit reference to self so actors in this instance do not die - x.parent_iterator = self - return x + return ParallelIterator( + [_ActorSet(actors, [])], name, parent_iterators=[self]) def gather_sync(self) -> "LocalIterator[T]": """Returns a local iterable for synchronous iteration. @@ -491,8 +484,12 @@ class ParallelIterator(Generic[T]): actor_sets = [] actor_sets.extend(self.actor_sets) actor_sets.extend(other.actor_sets) - return ParallelIterator(actor_sets, "ParallelUnion[{}, {}]".format( - self, other)) + # if one of these iterators is a result of a repartition, we need to + # keep an explicit reference to its parent iterator + return ParallelIterator( + actor_sets, + "ParallelUnion[{}, {}]".format(self, other), + parent_iterators=self.parent_iterators + other.parent_iterators) def num_shards(self) -> int: """Return the number of worker actors backing this iterator."""