def test_repartition_less(ray_start_regular_shared):
    """Going from 3 shards to 2 deals each source shard out round-robin."""
    source = from_range(9, num_shards=3)
    fewer = source.repartition(2)
    expected_repr = ("ParallelIterator[from_range[9, " +
                     "shards=3].repartition[num_partitions=2]]")
    assert repr(fewer) == expected_repr

    assert fewer.num_shards() == 2
    first_shard = set(fewer.get_shard(0))
    second_shard = set(fewer.get_shard(1))
    # Source shards are [0, 1, 2], [3, 4, 5], [6, 7, 8]; with a step of 2
    # the new shard 0 takes positions 0 and 2 of each, shard 1 position 1.
    assert first_shard == {0, 2, 3, 5, 6, 8}
    assert second_shard == {1, 4, 7}


def test_repartition_more(ray_start_regular_shared):
    """Going from 2 shards to 3 slices each source shard with a step of 3."""
    grown = from_range(100, 2).repartition(3)
    assert grown.num_shards() == 3
    for shard in range(3):
        # The two source shards cover 0..49 and 50..99 respectively.
        from_lower_half = set(range(shard, 50, 3))
        from_upper_half = set(range(50 + shard, 100, 3))
        assert set(grown.get_shard(shard)) == from_lower_half | from_upper_half


def test_repartition_consistent(ray_start_regular_shared):
    """Two identical repartition calls must place the same items per shard."""
    # repartition should be deterministic
    first, second = (
        from_range(9, num_shards=1).repartition(2) for _ in range(2))
    assert first.num_shards() == 2
    assert second.num_shards() == 2
    for shard in range(2):
        assert set(first.get_shard(shard)) == set(second.get_shard(shard))
def repartition(self, num_partitions: int) -> "ParallelIterator[T]":
    """Returns a new ParallelIterator instance with num_partitions shards.

    The new iterator contains the same data in this instance except with
    num_partitions shards. The data is split in round-robin fashion for
    the new ParallelIterator.

    Args:
        num_partitions (int): The number of shards to use for the new
            ParallelIterator. Must be at least 1.

    Returns:
        A ParallelIterator with num_partitions number of shards and the
        data of this ParallelIterator split round-robin among the new
        number of shards.

    Raises:
        ValueError: If num_partitions is less than 1. Without this check
            the result would have zero shards and silently drop all data.

    Examples:
        >>> it = from_range(8, 2)
        >>> it = it.repartition(3)
        >>> list(it.get_shard(0))
        [0, 4, 3, 7]
        >>> list(it.get_shard(1))
        [1, 5]
        >>> list(it.get_shard(2))
        [2, 6]
    """
    if num_partitions < 1:
        raise ValueError(
            "num_partitions must be at least 1, got {}".format(
                num_partitions))

    # initialize the local iterators for all the actors
    all_actors = []
    for actor_set in self.actor_sets:
        actor_set.init_actors()
        all_actors.extend(actor_set.actors)

    def base_iterator(num_partitions, partition_index, timeout=None):
        """Yields the items of one new shard, pulled from every source actor.

        Each source actor is asked for its next item for this partition
        via par_iter_slice(step=num_partitions, start=partition_index).
        """
        # Maps each in-flight object id to the actor that will produce it.
        futures = {}
        for a in all_actors:
            futures[a.par_iter_slice.remote(
                step=num_partitions, start=partition_index)] = a
        while futures:
            pending = list(futures)
            if timeout is None:
                # First try to do a batch wait for efficiency.
                ready, _ = ray.wait(
                    pending, num_returns=len(pending), timeout=0)
                # Fall back to a blocking wait.
                if not ready:
                    ready, _ = ray.wait(pending, num_returns=1)
            else:
                ready, _ = ray.wait(
                    pending, num_returns=len(pending), timeout=timeout)
            for obj_id in ready:
                actor = futures.pop(obj_id)
                # Only the fetch can signal exhaustion, so only the fetch
                # is guarded.
                try:
                    value = ray.get(obj_id)
                except StopIteration:
                    # This source actor has nothing left for this
                    # partition, so it is not queued again.
                    continue
                yield value
                futures[actor.par_iter_slice.remote(
                    step=num_partitions, start=partition_index)] = actor
            # Always yield after each round of wait with timeout.
            if timeout is not None:
                yield _NextValueNotReady()

    def make_gen_i(i):
        # Binding i through a function argument gives every lambda its own
        # partition index instead of the loop's final value.
        return lambda: base_iterator(num_partitions, i)

    name = self.name + ".repartition[num_partitions={}]".format(
        num_partitions)

    generators = [make_gen_i(s) for s in range(num_partitions)]
    worker_cls = ray.remote(ParallelIteratorWorker)
    actors = [worker_cls.remote(g, repeat=False) for g in generators]
    repartitioned = ParallelIterator([_ActorSet(actors, [])], name)
    # need explicit reference to self so actors in this instance do not die
    repartitioned.parent_iterator = self
    return repartitioned
def par_iter_slice(self, step: int, start: int):
    """Iterates in increments of step starting from start.

    Items are pulled from the local iterator one round at a time: a round
    takes up to `step` items and stores the j-th item of the round in
    buffer j. A call returns the oldest buffered item for `start`, pulling
    a new round only when that buffer is empty.

    Args:
        step (int): Number of slots items are dealt into per round.
        start (int): The slot to read from; must satisfy
            0 <= start < step.

    Returns:
        The next item belonging to slot `start`.

    Raises:
        ValueError: If start is outside [0, step). Such a slot is never
            filled, so the call would consume items on behalf of other
            slots and then report exhaustion.
        StopIteration: If the local iterator is exhausted and no item is
            buffered for `start`.
    """
    assert self.local_it is not None, "must call par_iter_init()"
    if not 0 <= start < step:
        raise ValueError(
            "start must satisfy 0 <= start < step, got start={}, "
            "step={}".format(start, step))

    if self.next_ith_buffer is None:
        # deque gives O(1) removal from the front; list.pop(0) is O(n).
        self.next_ith_buffer = collections.defaultdict(collections.deque)

    slot = self.next_ith_buffer[start]
    if not slot:
        # Pull one round from the local iterator, dealing item j to slot j.
        for j in range(step):
            try:
                val = next(self.local_it)
            except StopIteration:
                # Exhausted: stop instead of polling the finished iterator
                # once for every remaining slot.
                break
            self.next_ith_buffer[j].append(val)

        if not slot:
            raise StopIteration

    return slot.popleft()