diff --git a/python/ray/tests/test_iter.py b/python/ray/tests/test_iter.py index 441ed6eb1..f1e140bc7 100644 --- a/python/ray/tests/test_iter.py +++ b/python/ray/tests/test_iter.py @@ -329,12 +329,58 @@ def test_gather_async(ray_start_regular_shared): assert sorted(it) == [0, 1, 2, 3] -def test_gather_async_queue(ray_start_regular_shared): +def test_gather_async_optimized(ray_start_regular_shared): it = from_range(100) - it = it.gather_async(num_async=4) + it = it.gather_async(batch_ms=100, num_async=4) assert sorted(it) == list(range(100)) +def test_get_shard_optimized(ray_start_regular_shared): + it = from_range(6, num_shards=3) + shard1 = it.get_shard(shard_index=0, batch_ms=25, num_async=2) + shard2 = it.get_shard(shard_index=1, batch_ms=15, num_async=3) + shard3 = it.get_shard(shard_index=2, batch_ms=5, num_async=4) + assert list(shard1) == [0, 1] + assert list(shard2) == [2, 3] + assert list(shard3) == [4, 5] + + +# Tested on 5/13/20 +# Run on 2019 Macbook Pro with 8 cores, 16 threads +# 14.52 sec +# 14.64 sec +# 0.935 sec +# 0.515 sec +""" +def test_gather_async_optimized_benchmark(ray_start_regular_shared): + import numpy as np + import tensorflow as tf + train, _ = tf.keras.datasets.fashion_mnist.load_data() + images, labels = train + num_bytes = images.nbytes / 1e6 + items = list(images) + it = from_items(items, num_shards=4) + it = it.for_each(lambda img: img/255) + #local_it = it.gather_async(batch_ms=0, num_async=1) + #local_it = it.gather_async(batch_ms=0, num_async=3) + #local_it = it.gather_async(batch_ms=10, num_async=1) + #local_it = it.gather_async(batch_ms=10, num_async=3) + + # dummy iterations + for i in range(20): + record = next(local_it) + + start_time = time.time() + #print(start_time) + count = 0 + for record in local_it: + count += 1 + assert count == len(items) - 20 + end_time = time.time() - start_time + print(end_time) +""" + + def test_batch_across_shards(ray_start_regular_shared): it = from_iterators([[0, 1], [2, 3]]) it = 
it.batch_across_shards() diff --git a/python/ray/util/iter.py b/python/ray/util/iter.py index 07eedaf02..542a53aec 100644 --- a/python/ray/util/iter.py +++ b/python/ray/util/iter.py @@ -2,6 +2,7 @@ from contextlib import contextmanager import collections import random import threading +import time from typing import TypeVar, Generic, Iterable, List, Callable, Any import ray @@ -64,21 +65,28 @@ def from_range(n: int, num_shards: int = 2, def from_iterators(generators: List[Iterable[T]], repeat: bool = False, name=None) -> "ParallelIterator[T]": - """Create a parallel iterator from a set of iterators. + """Create a parallel iterator from a list of iterables. + An iterable can be a container (list, str, tuple, set, etc.), + a generator, or a custom class that implements __iter__ or __getitem__. - An actor will be created for each iterator. + An actor will be created for each iterable. Examples: >>> # Create using a list of generators. >>> from_iterators([range(100), range(100)]) - >>> # Equivalent to the above. - >>> from_iterators([lambda: range(100), lambda: range(100)]) + >>> # Certain generators are not serializable. + >>> from_iterators([(x for x in range(100))]) + ... TypeError: can't pickle generator objects + + >>> # So use lambda functions instead. + >>> # Lambda functions are serializable. + >>> from_iterators([lambda: (x for x in range(100))]) Args: - generators (list): A list of Python generator objects or lambda - functions that produced a generator when called. We allow lambda - functions since the generator itself might not be serializable, + generators (list): A list of Python iterables or lambda + functions that produce an iterable when called. We allow lambda + functions since certain generators might not be serializable, but a lambda that returns it can be. repeat (bool): Whether to cycle over the iterators forever. name (str): Optional name to give the iterator. 
@@ -188,9 +196,9 @@ class ParallelIterator(Generic[T]): def for_each(self, fn: Callable[[T], U], max_concurrency=1, resources=None) -> "ParallelIterator[U]": """Remotely apply fn to each item in this iterator, at most `max_concurrency` - at a time. + at a time per shard. - If `max_concurrency` == 1 then `fn` will be executed serially by the + If `max_concurrency` == 1 then `fn` will be executed serially by each shards `max_concurrency` should be used to achieve a high degree of @@ -200,7 +208,7 @@ class ParallelIterator(Generic[T]): necessarily finish first) A performance note: When executing concurrently, this function - maintains its own internal buffer. If `async_queue_depth` is `n` and + maintains its own internal buffer. If `num_async` is `n` and max_concur is `k` then the total number of buffered objects could be up to `n + k - 1` @@ -311,7 +319,8 @@ class ParallelIterator(Generic[T]): shuffle_buffer_size, str(seed) if seed is not None else "None")) - def repartition(self, num_partitions: int) -> "ParallelIterator[T]": + def repartition(self, num_partitions: int, + batch_ms: int = 0) -> "ParallelIterator[T]": """Returns a new ParallelIterator instance with num_partitions shards. The new iterator contains the same data in this instance except with @@ -321,6 +330,9 @@ class ParallelIterator(Generic[T]): Args: num_partitions (int): The number of shards to use for the new ParallelIterator + batch_ms (int): Batches items for batch_ms milliseconds + on each shard before retrieving it. + Increasing batch_ms increases latency but improves throughput. 
Returns: A ParallelIterator with num_partitions number of shards and the @@ -347,8 +359,10 @@ class ParallelIterator(Generic[T]): def base_iterator(num_partitions, partition_index, timeout=None): futures = {} for a in all_actors: - futures[a.par_iter_slice.remote( - step=num_partitions, start=partition_index)] = a + futures[a.par_iter_slice_batch.remote( + step=num_partitions, + start=partition_index, + batch_ms=batch_ms)] = a while futures: pending = list(futures) if timeout is None: @@ -364,10 +378,13 @@ class ParallelIterator(Generic[T]): for obj_id in ready: actor = futures.pop(obj_id) try: - yield ray.get(obj_id) - futures[actor.par_iter_slice.remote( + batch = ray.get(obj_id) + futures[actor.par_iter_slice_batch.remote( step=num_partitions, - start=partition_index)] = actor + start=partition_index, + batch_ms=batch_ms)] = actor + for item in batch: + yield item except StopIteration: pass # Always yield after each round of wait with timeout. @@ -447,13 +464,17 @@ class ParallelIterator(Generic[T]): name = "{}.batch_across_shards()".format(self) return LocalIterator(base_iterator, SharedMetrics(), name=name) - def gather_async(self, num_async=1) -> "LocalIterator[T]": + def gather_async(self, batch_ms=0, num_async=1) -> "LocalIterator[T]": """Returns a local iterable for asynchronous iteration. New items will be fetched from the shards asynchronously as soon as the previous one is computed. Items arrive in non-deterministic order. Arguments: + batch_ms (int): Batches items for batch_ms milliseconds + on each shard before retrieving it. + Increasing batch_ms increases latency but improves throughput. + If this value is 0, then items are returned immediately. num_async (int): The max number of async requests in flight per actor. Increasing this improves the amount of pipeline parallelism in the iterator. 
@@ -470,6 +491,8 @@ class ParallelIterator(Generic[T]): if num_async < 1: raise ValueError("queue depth must be positive") + if batch_ms < 0: + raise ValueError("batch time must be positive") # Forward reference to the returned iterator. local_iter = None @@ -482,7 +505,7 @@ class ParallelIterator(Generic[T]): futures = {} for _ in range(num_async): for a in all_actors: - futures[a.par_iter_next.remote()] = a + futures[a.par_iter_next_batch.remote(batch_ms)] = a while futures: pending = list(futures) if timeout is None: @@ -499,8 +522,11 @@ class ParallelIterator(Generic[T]): actor = futures.pop(obj_id) try: local_iter.shared_metrics.get().current_actor = actor - yield ray.get(obj_id) - futures[actor.par_iter_next.remote()] = actor + batch = ray.get(obj_id) + futures[actor.par_iter_next_batch.remote( + batch_ms)] = actor + for item in batch: + yield item except StopIteration: pass # Always yield after each round of wait with timeout. @@ -522,7 +548,7 @@ class ParallelIterator(Generic[T]): def union(self, other: "ParallelIterator[T]") -> "ParallelIterator[T]": """Return an iterator that is the union of this and the other.""" if not isinstance(other, ParallelIterator): - raise ValueError( + raise TypeError( "other must be of type ParallelIterator, got {}".format( type(other))) actor_sets = [] @@ -566,12 +592,29 @@ class ParallelIterator(Generic[T]): """Return the list of all shards.""" return [self.get_shard(i) for i in range(self.num_shards())] - def get_shard(self, shard_index: int) -> "LocalIterator[T]": + def get_shard(self, + shard_index: int, + batch_ms: int = 0, + num_async: int = 1) -> "LocalIterator[T]": """Return a local iterator for the given shard. The iterator is guaranteed to be serializable and can be passed to remote tasks or actors. + + Arguments: + shard_index (int): Index of the shard to gather. + batch_ms (int): Batches items for batch_ms milliseconds + before retrieving it. + Increasing batch_ms increases latency but improves throughput. 
+ If this value is 0, then items are returned immediately. + num_async (int): The max number of requests in flight. + Increasing this improves the amount of pipeline + parallelism in the iterator. """ + if num_async < 1: + raise ValueError("num async must be positive") + if batch_ms < 0: + raise ValueError("batch time must be positive") a, t = None, None i = shard_index for actor_set in self.actor_sets: @@ -586,10 +629,16 @@ class ParallelIterator(Generic[T]): self.num_shards()) def base_iterator(timeout=None): + queue = collections.deque() ray.get(a.par_iter_init.remote(t)) + for _ in range(num_async): + queue.append(a.par_iter_next_batch.remote(batch_ms)) while True: try: - yield ray.get(a.par_iter_next.remote(), timeout=timeout) + batch = ray.get(queue.popleft(), timeout=timeout) + queue.append(a.par_iter_next_batch.remote(batch_ms)) + for item in batch: + yield item # Always yield after each round of gets with timeout. if timeout is not None: yield _NextValueNotReady() @@ -919,7 +968,10 @@ class LocalIterator(Generic[T]): yield _NextValueNotReady() else: if len(queues[i]) == 0: - fill_next(timeout) + try: + fill_next(timeout) + except StopIteration: + return yield queues[i].popleft() return gen @@ -1023,7 +1075,7 @@ class ParallelIteratorWorker(object): Subclasses must call this init function. Args: - item_generator (obj): A Python generator objects or lambda function + item_generator (obj): A Python iterable or lambda function that produces a generator when called. We allow lambda functions since the generator itself might not be serializable, but a lambda that returns it can be. @@ -1040,7 +1092,13 @@ class ParallelIteratorWorker(object): def cycle(): while True: - it = make_iterator() + it = iter(make_iterator()) + if it is item_generator: + raise ValueError( + "Cannot iterate over {} multiple times." 
+                            " Please pass in the base iterable or "
+                            "lambda: {} instead.".format(
+                                item_generator, item_generator))
                     for item in it:
                         yield item
 
@@ -1066,6 +1124,37 @@ class ParallelIteratorWorker(object):
         assert self.local_it is not None, "must call par_iter_init()"
         return next(self.local_it)
 
+    def par_iter_next_batch(self, batch_ms: int):
+        """Return a list of items collected from par_iter_next().
+
+        Args:
+            batch_ms (int): Time budget in milliseconds for filling the
+                batch. If it is not positive, exactly one item is fetched.
+
+        Returns:
+            A list of the items fetched within the time budget.
+
+        Raises:
+            StopIteration: If the iterator is exhausted before the first
+                item of the batch could be fetched.
+        """
+        if batch_ms <= 0:
+            return [self.par_iter_next()]
+        batch = []
+        # Monotonic clock: the deadline must not move if the wall clock
+        # is adjusted while the batch is being filled.
+        deadline = time.monotonic() + 0.001 * batch_ms
+        while time.monotonic() < deadline:
+            try:
+                batch.append(self.par_iter_next())
+            except StopIteration:
+                if not batch:
+                    raise
+                # Exhausted mid-batch: return the partial batch now
+                # instead of re-polling the iterator until the deadline.
+                break
+        return batch
+
     def par_iter_slice(self, step: int, start: int):
         """Iterates in increments of step starting from start."""
         assert self.local_it is not None, "must call par_iter_init()"
@@ -1089,6 +1178,39 @@ class ParallelIteratorWorker(object):
 
         return self.next_ith_buffer[start].pop(0)
 
+    def par_iter_slice_batch(self, step: int, start: int, batch_ms: int):
+        """Return a list of items collected from par_iter_slice().
+
+        Args:
+            step (int): Passed through to par_iter_slice().
+            start (int): Passed through to par_iter_slice().
+            batch_ms (int): Time budget in milliseconds for filling the
+                batch. If it is not positive, exactly one item is fetched.
+
+        Returns:
+            A list of the items fetched within the time budget.
+
+        Raises:
+            StopIteration: If the slice is exhausted before the first
+                item of the batch could be fetched.
+        """
+        if batch_ms <= 0:
+            return [self.par_iter_slice(step, start)]
+        batch = []
+        # Monotonic clock: the deadline must not move if the wall clock
+        # is adjusted while the batch is being filled.
+        deadline = time.monotonic() + 0.001 * batch_ms
+        while time.monotonic() < deadline:
+            try:
+                batch.append(self.par_iter_slice(step, start))
+            except StopIteration:
+                if not batch:
+                    raise
+                # Exhausted mid-batch: return the partial batch now
+                # instead of re-polling the slice until the deadline.
+                break
+        return batch
+
 
 def _randomized_int_cast(float_value):
     base = int(float_value)