diff --git a/doc/source/index.rst b/doc/source/index.rst index e19e8194a..8f0779108 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -283,6 +283,7 @@ Getting Involved signals.rst async_api.rst serve.rst + iter.rst multiprocessing.rst .. toctree:: diff --git a/doc/source/iter.rst b/doc/source/iter.rst new file mode 100644 index 000000000..ecff24480 --- /dev/null +++ b/doc/source/iter.rst @@ -0,0 +1,165 @@ +Parallel Iterator API (Experimental) +==================================== + +``ray.experimental.iter`` provides a parallel iterator API for simple data ingest +and processing. It can be thought of as syntactic sugar around Ray actors and ``ray.wait`` loops. + +Parallel iterators are lazy and can operate over infinite sequences of items. Iterator +transformations are only executed when the user calls ``next()`` to fetch the next output +item from the iterator. + +Concepts +-------- + +**Parallel Iterators**: You can create a ``ParallelIterator`` object from an existing +set of items, range of numbers, set of iterators, or set of worker actors. Ray will +create a worker actor that produces the data for each shard of the iterator: + +.. code-block:: python + + # Create an iterator with 2 worker actors over the list [1, 2, 3, 4]. + >>> it = ray.experimental.iter.from_items([1, 2, 3, 4], num_shards=2) + ParallelIterator[from_items[int, 4, shards=2]] + + # Create an iterator with 32 worker actors over range(1000000). + >>> it = ray.experimental.iter.from_range(1000000, num_shards=32) + ParallelIterator[from_range[1000000, shards=32]] + + # Create an iterator over two range(10) generators. + >>> it = ray.experimental.iter.from_iterators([range(10), range(10)]) + ParallelIterator[from_iterators[shards=2]] + + # Create an iterator from existing worker actors. These actors must + # implement the ParallelIteratorWorker interface. 
+ >>> it = ray.experimental.iter.from_actors([a1, a2, a3, a4]) + ParallelIterator[from_actors[shards=4]] + +Simple transformations can be chained on the iterator, such as mapping, +filtering, and batching. These will be executed in parallel on the workers: + +.. code-block:: python + + # Apply a transformation to each element of the iterator. + >>> it = it.for_each(lambda x: x ** 2) + ParallelIterator[...].for_each() + + # Batch together items into lists of 32 elements. + >>> it = it.batch(32) + ParallelIterator[...].for_each().batch(32) + + # Filter out items with odd values. + >>> it = it.filter(lambda x: x % 2 == 0) + ParallelIterator[...].for_each().batch(32).filter() + +**Local Iterators**: To read elements from a parallel iterator, it has to be converted +to a ``LocalIterator`` by calling ``gather_sync()`` or ``gather_async()``. These +correspond to ``ray.get`` and ``ray.wait`` loops over the actors respectively: + +.. code-block:: python + + # Gather items synchronously (deterministic round robin across shards): + >>> it = ray.experimental.iter.from_range(1000000, 1) + >>> it = it.gather_sync() + LocalIterator[ParallelIterator[from_range[1000000, shards=1]].gather_sync()] + + # Local iterators can be used as any other Python iterator. + >>> it.take(5) + [0, 1, 2, 3, 4] + + # They also support chaining of transformations. Unlike transformations + # applied on a ParallelIterator, they will be executed in the current process. + >>> it.filter(lambda x: x % 2 == 0).take(5) + [0, 2, 4, 6, 8] + + # Async gather can be used for better performance, but it is non-deterministic. + >>> it = ray.experimental.iter.from_range(1000, 4).gather_async() + >>> it.take(5) + [0, 250, 500, 750, 1] + +**Passing iterators to remote functions**: Both ``ParallelIterator`` and ``LocalIterator`` +are serializable. They can be passed to any Ray remote function. However, note that +each shard should only be read by one process at a time: + +.. 
code-block:: python + + # Get local iterators representing the shards of this ParallelIterator: + >>> it = ray.experimental.iter.from_range(10000, 3) + >>> [s0, s1, s2] = it.shards() + [LocalIterator[from_range[10000, shards=3].shard[0]], + LocalIterator[from_range[10000, shards=3].shard[1]], + LocalIterator[from_range[10000, shards=3].shard[2]]] + + # Iterator shards can be passed to remote functions. + >>> @ray.remote + ... def do_sum(it): + ... return sum(it) + ... + >>> ray.get([do_sum.remote(s) for s in it.shards()]) + [5552778, 16661667, 27780555] + +Example: Streaming word frequency count +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Parallel iterators can be used for simple data processing use cases such as +streaming grep: + +.. code-block:: python + + import ray + import glob + import gzip + import numpy as np + + ray.init() + + file_list = glob.glob("/var/log/syslog*.gz") + it = ( + ray.experimental.iter.from_items(file_list, num_shards=4) + .for_each(lambda f: gzip.open(f).readlines()) + .flatten() + .for_each(lambda line: line.decode("utf-8")) + .for_each(lambda line: 1 if "cron" in line else 0) + .batch(1024) + .for_each(np.mean) + ) + + # Show the fraction of log lines containing "cron" in each + # consecutive batch of 1024 lines. + for freq in it.gather_async(): + print(freq) + +Example: Passing iterator shards to remote functions +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Both parallel iterators and local iterators are fully serializable, so once +created you can pass them to Ray tasks and actors. This can be useful for +distributed training: + +.. 
code-block:: python + + import ray + import numpy as np + + ray.init() + + @ray.remote + def train(data_shard): + for batch in data_shard: + print("train on", batch) # perform model update with batch + + it = ( + ray.experimental.iter.from_range(1000000, num_shards=4, repeat=True) + .batch(1024) + .for_each(np.array) + ) + + work = [train.remote(shard) for shard in it.shards()] + ray.get(work) + +API Reference +------------- + +.. automodule:: ray.experimental.iter + :members: + :show-inheritance: + :special-members: diff --git a/python/ray/experimental/__init__.py b/python/ray/experimental/__init__.py index 676962707..f31b58249 100644 --- a/python/ray/experimental/__init__.py +++ b/python/ray/experimental/__init__.py @@ -8,6 +8,7 @@ from .named_actors import get_actor, register_actor from .api import get, wait from .actor_pool import ActorPool from .dynamic_resources import set_resource +from . import iter def TensorFlowVariables(*args, **kwargs): @@ -17,7 +18,15 @@ def TensorFlowVariables(*args, **kwargs): __all__ = [ - "TensorFlowVariables", "get_actor", "register_actor", "get", "wait", - "set_flushing_policy", "GcsFlushPolicy", "SimpleGcsFlushPolicy", - "set_resource", "ActorPool" + "TensorFlowVariables", + "get_actor", + "register_actor", + "get", + "wait", + "set_flushing_policy", + "GcsFlushPolicy", + "SimpleGcsFlushPolicy", + "set_resource", + "ActorPool", + "iter", ] diff --git a/python/ray/experimental/iter.py b/python/ray/experimental/iter.py new file mode 100644 index 000000000..d4aaf64ff --- /dev/null +++ b/python/ray/experimental/iter.py @@ -0,0 +1,668 @@ +from typing import TypeVar, Generic, Iterable, List, Callable, Any + +import ray + +# The type of an iterator element. +T = TypeVar("T") +U = TypeVar("U") + + +def from_items(items: List[T], num_shards: int = 2, + repeat: bool = False) -> "ParallelIterator[T]": + """Create a parallel iterator from an existing set of objects. + + The objects will be divided round-robin among the number of shards. 
+ + Arguments: + items (list): The list of items to iterate over. + num_shards (int): The number of worker actors to create. + repeat (bool): Whether to cycle over the items forever. + """ + shards = [[] for _ in range(num_shards)] + for i, item in enumerate(items): + shards[i % num_shards].append(item) + name = "from_items[{}, {}, shards={}{}]".format( + items and type(items[0]).__name__ or "None", len(items), num_shards, + ", repeat=True" if repeat else "") + return from_iterators(shards, repeat=repeat, name=name) + + +def from_range(n: int, num_shards: int = 2, + repeat: bool = False) -> "ParallelIterator[int]": + """Create a parallel iterator over the range 0..n. + + The range will be partitioned sequentially among the number of shards. + + Arguments: + n (int): The max end of the range of numbers. + num_shards (int): The number of worker actors to create. + repeat (bool): Whether to cycle over the range forever. + """ + generators = [] + shard_size = n // num_shards + for i in range(num_shards): + start = i * shard_size + if i == num_shards - 1: + end = n + else: + end = (i + 1) * shard_size + generators.append(range(start, end)) + name = "from_range[{}, shards={}{}]".format( + n, num_shards, ", repeat=True" if repeat else "") + return from_iterators(generators, repeat=repeat, name=name) + + +def from_iterators(generators: List[Iterable[T]], + repeat: bool = False, + name=None) -> "ParallelIterator[T]": + """Create a parallel iterator from a set of iterators. + + An actor will be created for each iterator. + + Examples: + >>> # Create using a list of generators. + >>> from_iterators([range(100), range(100)]) + + >>> # Equivalent to the above. + >>> from_iterators([lambda: range(100), lambda: range(100)]) + + Arguments: + generators (list): A list of Python generator objects or lambda + functions that produced a generator when called. We allow lambda + functions since the generator itself might not be serializable, + but a lambda that returns it can be. 
+ repeat (bool): Whether to cycle over the iterators forever. + name (str): Optional name to give the iterator. + """ + worker_cls = ray.remote(ParallelIteratorWorker) + actors = [worker_cls.remote(g, repeat) for g in generators] + if not name: + name = "from_iterators[shards={}{}]".format( + len(generators), ", repeat=True" if repeat else "") + return from_actors(actors, name=name) + + +def from_actors(actors: List["ray.actor.ActorHandle"], + name=None) -> "ParallelIterator[T]": + """Create a parallel iterator from an existing set of actors. + + Each actor must subclass the ParallelIteratorWorker interface. + + Arguments: + actors (list): List of actors that each implement + ParallelIteratorWorker. + name (str): Optional name to give the iterator. + """ + if not name: + name = "from_actors[shards={}]".format(len(actors)) + return ParallelIterator([_ActorSet(actors, [])], name) + + +class ParallelIterator(Generic[T]): + """A parallel iterator over a set of remote actors. + + This can be used to iterate over a fixed set of task results + (like an actor pool), or a stream of data (e.g., a fixed range of numbers, + an infinite stream of RLlib rollout results). + + This class is **serializable** and can be passed to other remote + tasks and actors. However, each shard should be read from at most one + process at a time. + + Examples: + >>> # Applying a function over items in parallel. + >>> it = ray.experimental.iter.from_items([1, 2, 3], num_shards=2) + ... <__main__.ParallelIterator object> + >>> it = it.for_each(lambda x: x * 2).gather_sync() + ... <__main__.LocalIterator object> + >>> print(list(it)) + ... [2, 4, 6] + + >>> # Creating from generators. + >>> it = ray.experimental.iter.from_iterators([range(3), range(3)]) + ... <__main__.ParallelIterator object> + >>> print(list(it.gather_sync())) + ... [0, 0, 1, 1, 2, 2] + + >>> # Accessing the individual shards of an iterator. + >>> it = ray.experimental.iter.from_range(10, num_shards=2) + ... 
<__main__.ParallelIterator object> + >>> it0 = it.get_shard(0) + ... <__main__.LocalIterator object> + >>> print(list(it0)) + ... [0, 1, 2, 3, 4] + >>> it1 = it.get_shard(1) + ... <__main__.LocalIterator object> + >>> print(list(it1)) + ... [5, 6, 7, 8, 9] + + >>> # Gathering results from actors synchronously in parallel. + >>> it = ray.experimental.iter.from_actors(workers) + ... <__main__.ParallelIterator object> + >>> it = it.batch_across_shards() + ... <__main__.LocalIterator object> + >>> print(next(it)) + ... [worker_1_result_1, worker_2_result_1] + >>> print(next(it)) + ... [worker_1_result_2, worker_2_result_2] + """ + + def __init__(self, actor_sets: List["_ActorSet"], name: str): + # We track multiple sets of actors to support parallel .union(). + self.actor_sets = actor_sets + self.name = name + + def __iter__(self): + raise TypeError( + "You must use it.gather_sync() or it.gather_async() to " + "iterate over the results of a ParallelIterator.") + + def __str__(self): + return repr(self) + + def __repr__(self): + return "ParallelIterator[{}]".format(self.name) + + def for_each(self, fn: Callable[[T], U]) -> "ParallelIterator[U]": + """Remotely apply fn to each item in this iterator. + + Arguments: + fn (func): function to apply to each item. + + Examples: + >>> next(from_range(4).for_each(lambda x: x * 2).gather_sync()) + ... [0, 2, 4, 8] + """ + return ParallelIterator( + [ + a.with_transform(lambda local_it: local_it.for_each(fn)) + for a in self.actor_sets + ], + name=self.name + ".for_each()") + + def filter(self, fn: Callable[[T], bool]) -> "ParallelIterator[T]": + """Remotely filter items from this iterator. + + Arguments: + fn (func): returns False for items to drop from the iterator. + + Examples: + >>> it = from_items([0, 1, 2]).filter(lambda x: x > 0) + >>> next(it.gather_sync()) + ... 
[1, 2] + """ + return ParallelIterator( + [ + a.with_transform(lambda local_it: local_it.filter(fn)) + for a in self.actor_sets + ], + name=self.name + ".filter()") + + def batch(self, n: int) -> "ParallelIterator[List[T]]": + """Remotely batch together items in this iterator. + + Arguments: + n (int): Number of items to batch together. + + Examples: + >>> next(from_range(10, 1).batch(4).gather_sync()) + ... [0, 1, 2, 3] + """ + return ParallelIterator( + [ + a.with_transform(lambda local_it: local_it.batch(n)) + for a in self.actor_sets + ], + name=self.name + ".batch({})".format(n)) + + def flatten(self) -> "ParallelIterator[T[0]]": + """Flatten batches of items into individual items. + + Examples: + >>> next(from_range(10, 1).batch(4).flatten()) + ... 0 + """ + return ParallelIterator( + [ + a.with_transform(lambda local_it: local_it.flatten()) + for a in self.actor_sets + ], + name=self.name + ".flatten()") + + def combine(self, fn: Callable[[T], List[U]]) -> "ParallelIterator[U]": + """Transform and then combine items horizontally. + + This is the equivalent of for_each(fn).flatten() (flat map). + """ + it = self.for_each(fn).flatten() + it.name = self.name + ".combine()" + return it + + def gather_sync(self) -> "LocalIterator[T]": + """Returns a local iterable for synchronous iteration. + + New items will be fetched from the shards on-demand as the iterator + is stepped through. + + This is the equivalent of batch_across_shards().flatten(). + + Examples: + >>> it = from_range(100, 1).gather_sync() + >>> next(it) + ... 0 + >>> next(it) + ... 1 + >>> next(it) + ... 2 + """ + it = self.batch_across_shards().flatten() + it.name = "{}.gather_sync()".format(self) + return it + + def batch_across_shards(self) -> "LocalIterator[List[T]]": + """Iterate over the results of multiple shards in parallel. + + Examples: + >>> it = from_iterators([range(3), range(3)]) + >>> next(it.batch_across_shards()) + ... 
[0, 0] + """ + + def base_iterator(timeout=None): + active = [] + for actor_set in self.actor_sets: + actor_set.init_actors() + active.extend(actor_set.actors) + futures = [a.par_iter_next.remote() for a in active] + while active: + try: + yield ray.get(futures, timeout=timeout) + futures = [a.par_iter_next.remote() for a in active] + # Always yield after each round of gets with timeout. + if timeout is not None: + yield _NextValueNotReady() + except TimeoutError: + yield _NextValueNotReady() + except StopIteration: + # Find and remove the actor that produced StopIteration. + results = [] + for a, f in zip(list(active), futures): + try: + results.append(ray.get(f)) + except StopIteration: + active.remove(a) + if results: + yield results + futures = [a.par_iter_next.remote() for a in active] + + name = "{}.batch_across_shards()".format(self) + return LocalIterator(base_iterator, name=name) + + def gather_async(self) -> "LocalIterator[T]": + """Returns a local iterable for asynchronous iteration. + + New items will be fetched from the shards asynchronously as soon as + the previous one is computed. Items arrive in non-deterministic order. + + Examples: + >>> it = from_range(100, 1).gather_async() + >>> next(it) + ... 3 + >>> next(it) + ... 0 + >>> next(it) + ... 1 + """ + + def base_iterator(timeout=None): + all_actors = [] + for actor_set in self.actor_sets: + actor_set.init_actors() + all_actors.extend(actor_set.actors) + futures = {} + for a in all_actors: + futures[a.par_iter_next.remote()] = a + while futures: + pending = list(futures) + if timeout is None: + # First try to do a batch wait for efficiency. + ready, _ = ray.wait( + pending, num_returns=len(pending), timeout=0) + # Fall back to a blocking wait. 
+ if not ready: + ready, _ = ray.wait(pending, num_returns=1) + else: + ready, _ = ray.wait( + pending, num_returns=len(pending), timeout=timeout) + for obj_id in ready: + actor = futures.pop(obj_id) + try: + yield ray.get(obj_id) + futures[actor.par_iter_next.remote()] = actor + except StopIteration: + pass + # Always yield after each round of wait with timeout. + if timeout is not None: + yield _NextValueNotReady() + + name = "{}.gather_async()".format(self) + return LocalIterator(base_iterator, name=name) + + def take(self, n: int) -> List[T]: + """Return up to the first n items from this iterator.""" + return self.gather_sync().take(n) + + def show(self, n: int = 20): + """Print up to the first n items from this iterator.""" + return self.gather_sync().show(n) + + def union(self, other: "ParallelIterator[T]") -> "ParallelIterator[T]": + """Return an iterator that is the union of this and the other.""" + if not isinstance(other, ParallelIterator): + raise ValueError( + "other must be of type ParallelIterator, got {}".format( + type(other))) + actor_sets = [] + actor_sets.extend(self.actor_sets) + actor_sets.extend(other.actor_sets) + return ParallelIterator(actor_sets, "ParallelUnion[{}, {}]".format( + self, other)) + + def num_shards(self) -> int: + """Return the number of worker actors backing this iterator.""" + return sum(len(a.actors) for a in self.actor_sets) + + def shards(self) -> List["LocalIterator[T]"]: + """Return the list of all shards.""" + return [self.get_shard(i) for i in range(self.num_shards())] + + def get_shard(self, shard_index: int) -> "LocalIterator[T]": + """Return a local iterator for the given shard. + + The iterator is guaranteed to be serializable and can be passed to + remote tasks or actors. 
+ """ + a, t = None, None + i = shard_index + for actor_set in self.actor_sets: + if i < len(actor_set.actors): + a = actor_set.actors[i] + t = actor_set.transforms + break + else: + i -= len(actor_set.actors) + if a is None: + raise ValueError("Shard index out of range", shard_index, + self.num_shards()) + + def base_iterator(timeout=None): + ray.get(a.par_iter_init.remote(t)) + while True: + try: + yield ray.get(a.par_iter_next.remote(), timeout=timeout) + # Always yield after each round of gets with timeout. + if timeout is not None: + yield _NextValueNotReady() + except TimeoutError: + yield _NextValueNotReady() + except StopIteration: + break + + name = self.name + ".shard[{}]".format(shard_index) + return LocalIterator(base_iterator, name=name) + + +class LocalIterator(Generic[T]): + """An iterator over a single shard of data. + + It implements similar transformations as ParallelIterator[T], but the + transforms will be applied locally and not remotely in parallel. + + This class is **serializable** and can be passed to other remote + tasks and actors. However, it should be read from at most one process at + a time.""" + + def __init__(self, + base_iterator: Callable[[], Iterable[T]], + local_transforms: List[Callable[[Iterable], Any]] = None, + timeout: int = None, + name=None): + """Create a local iterator (this is an internal function). + + Arguments: + base_iterator (func): A function that produces the base iterator. + This is a function so that we can ensure LocalIterator is + serializable. + local_transforms (list): A list of transformation functions to be + applied on top of the base iterator. When iteration begins, we + create the base iterator and apply these functions. This lazy + creation ensures LocalIterator is serializable until you start + iterating over it. + timeout (int): Optional timeout in seconds for this iterator, after + which _NextValueNotReady will be returned. This avoids + blocking. + name (str): Optional name for this iterator. 
+ """ + self.base_iterator = base_iterator + self.built_iterator = None + self.local_transforms = local_transforms or [] + self.timeout = timeout + self.name = name or "unknown" + + def _build_once(self): + if self.built_iterator is None: + it = iter(self.base_iterator(self.timeout)) + for fn in self.local_transforms: + it = fn(it) + self.built_iterator = it + + def __iter__(self): + self._build_once() + return self.built_iterator + + def __next__(self): + self._build_once() + return next(self.built_iterator) + + def __str__(self): + return repr(self) + + def __repr__(self): + return "LocalIterator[{}]".format(self.name) + + def for_each(self, fn: Callable[[T], U]) -> "LocalIterator[U]": + def apply_foreach(it): + for item in it: + if isinstance(item, _NextValueNotReady): + yield item + else: + yield fn(item) + + return LocalIterator( + self.base_iterator, + self.local_transforms + [apply_foreach], + name=self.name + ".for_each()") + + def filter(self, fn: Callable[[T], bool]) -> "LocalIterator[T]": + def apply_filter(it): + for item in it: + if isinstance(item, _NextValueNotReady) or fn(item): + yield item + + return LocalIterator( + self.base_iterator, + self.local_transforms + [apply_filter], + name=self.name + ".filter()") + + def batch(self, n: int) -> "LocalIterator[List[T]]": + def apply_batch(it): + batch = [] + for item in it: + if isinstance(item, _NextValueNotReady): + yield item + else: + batch.append(item) + if len(batch) >= n: + yield batch + batch = [] + if batch: + yield batch + + return LocalIterator( + self.base_iterator, + self.local_transforms + [apply_batch], + name=self.name + ".batch({})".format(n)) + + def flatten(self) -> "LocalIterator[T[0]]": + def apply_flatten(it): + for item in it: + if isinstance(item, _NextValueNotReady): + yield item + else: + for subitem in item: + yield subitem + + return LocalIterator( + self.base_iterator, + self.local_transforms + [apply_flatten], + name=self.name + ".flatten()") + + def combine(self, fn: 
Callable[[T], List[U]]) -> "LocalIterator[U]": + it = self.for_each(fn).flatten() + it.name = self.name + ".combine()" + return it + + def take(self, n: int) -> List[T]: + """Return up to the first n items from this iterator.""" + out = [] + for item in self: + out.append(item) + if len(out) >= n: + break + return out + + def show(self, n: int = 20): + """Print up to the first n items from this iterator.""" + i = 0 + for item in self: + print(item) + i += 1 + if i >= n: + break + + def union(self, other: "LocalIterator[T]") -> "LocalIterator[T]": + """Return an iterator that is the union of this and the other. + + There are no ordering guarantees between the two iterators. We make a + best-effort attempt to return items from both as they become ready, + preventing starvation of any particular iterator. + """ + + if not isinstance(other, LocalIterator): + raise ValueError( + "other must be of type LocalIterator, got {}".format( + type(other))) + + it1 = LocalIterator( + self.base_iterator, self.local_transforms, timeout=0) + it2 = LocalIterator( + other.base_iterator, other.local_transforms, timeout=0) + active = [it1, it2] + + def build_union(timeout=None): + while True: + for it in list(active): + # Yield items from the iterator until _NextValueNotReady is + # found, then switch to the next iterator. + try: + while True: + item = next(it) + if isinstance(item, _NextValueNotReady): + break + else: + yield item + except StopIteration: + active.remove(it) + if not active: + break + + return LocalIterator( + build_union, [], name="LocalUnion[{}, {}]".format(self, other)) + + +class ParallelIteratorWorker(object): + """Worker actor for a ParallelIterator. + + Actors that are passed to iter.from_actors() must subclass this interface. + """ + + def __init__(self, item_generator: Any, repeat: bool): + """Create an iterator worker. + + Subclasses must call this init function. 
+ + Arguments: + item_generator (obj): A Python generator objects or lambda function + that produces a generator when called. We allow lambda + functions since the generator itself might not be serializable, + but a lambda that returns it can be. + repeat (bool): Whether to loop over the iterator forever. + """ + + def make_iterator(): + if callable(item_generator): + return item_generator() + else: + return item_generator + + if repeat: + + def cycle(): + while True: + it = make_iterator() + for item in it: + yield item + + self.item_generator = cycle() + else: + self.item_generator = make_iterator() + + self.transforms = [] + self.local_it = None + + def par_iter_init(self, transforms): + """Implements ParallelIterator worker init.""" + it = LocalIterator(lambda timeout: self.item_generator) + for fn in transforms: + it = fn(it) + assert it is not None, fn + self.local_it = iter(it) + + def par_iter_next(self): + """Implements ParallelIterator worker item fetch.""" + assert self.local_it is not None, "must call par_iter_init()" + return next(self.local_it) + + +class _NextValueNotReady(Exception): + """Indicates that a local iterator has no value currently available. 
+ + This is used internally to implement the union() of multiple blocking + local generators.""" + pass + + +class _ActorSet(object): + """Helper class that represents a set of actors and transforms.""" + + def __init__( + self, actors: List["ray.actor.ActorHandle"], + transforms: List[Callable[["LocalIterator"], "LocalIterator"]]): + self.actors = actors + self.transforms = transforms + + def init_actors(self): + ray.get([a.par_iter_init.remote(self.transforms) for a in self.actors]) + + def with_transform(self, fn): + return _ActorSet(self.actors, self.transforms + [fn]) diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 439d63ad0..c38ed1138 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -14,6 +14,14 @@ py_test( deps = ["//:ray_lib"], ) +py_test( + name = "test_iter", + size = "small", + srcs = ["test_iter.py"], + tags = ["exclusive"], + deps = ["//:ray_lib"], +) + py_test( name = "test_actor_resources", size = "medium", diff --git a/python/ray/tests/test_iter.py b/python/ray/tests/test_iter.py new file mode 100644 index 000000000..9b206fc8e --- /dev/null +++ b/python/ray/tests/test_iter.py @@ -0,0 +1,216 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +import ray +from ray.experimental.iter import from_items, from_iterators, from_range, \ + from_actors, ParallelIteratorWorker + + +def test_from_items(ray_start_regular_shared): + it = from_items([1, 2, 3, 4]) + assert repr(it) == "ParallelIterator[from_items[int, 4, shards=2]]" + assert list(it.gather_sync()) == [1, 2, 3, 4] + assert next(it.gather_sync()) == 1 + + +def test_from_items_repeat(ray_start_regular_shared): + it = from_items([1, 2, 3, 4], repeat=True) + assert repr( + it) == "ParallelIterator[from_items[int, 4, shards=2, repeat=True]]" + assert it.take(8) == [1, 2, 3, 4, 1, 2, 3, 4] + + +def test_from_iterators(ray_start_regular_shared): + it = from_iterators([range(2), range(2)]) + 
assert repr(it) == "ParallelIterator[from_iterators[shards=2]]" + assert list(it.gather_sync()) == [0, 0, 1, 1] + + +def test_from_range(ray_start_regular_shared): + it = from_range(4) + assert repr(it) == "ParallelIterator[from_range[4, shards=2]]" + assert list(it.gather_sync()) == [0, 2, 1, 3] + + +def test_from_actors(ray_start_regular_shared): + @ray.remote + class CustomWorker(ParallelIteratorWorker): + def __init__(self, data): + ParallelIteratorWorker.__init__(self, data, False) + + a = CustomWorker.remote([1, 2]) + b = CustomWorker.remote([3, 4]) + it = from_actors([a, b]) + assert repr(it) == "ParallelIterator[from_actors[shards=2]]" + assert list(it.gather_sync()) == [1, 3, 2, 4] + + +def test_for_each(ray_start_regular_shared): + it = from_range(4).for_each(lambda x: x * 2) + assert repr(it) == "ParallelIterator[from_range[4, shards=2].for_each()]" + assert list(it.gather_sync()) == [0, 4, 2, 6] + + +def test_combine(ray_start_regular_shared): + it = from_range(4, 1).combine(lambda x: [x, x]) + assert repr(it) == "ParallelIterator[from_range[4, shards=1].combine()]" + assert list(it.gather_sync()) == [0, 0, 1, 1, 2, 2, 3, 3] + + +def test_chain(ray_start_regular_shared): + it = from_range(4).for_each(lambda x: x * 2).for_each(lambda x: x * 2) + assert repr( + it + ) == "ParallelIterator[from_range[4, shards=2].for_each().for_each()]" + assert list(it.gather_sync()) == [0, 8, 4, 12] + + +def test_filter(ray_start_regular_shared): + it = from_range(4).filter(lambda x: x < 3) + assert repr(it) == "ParallelIterator[from_range[4, shards=2].filter()]" + assert list(it.gather_sync()) == [0, 2, 1] + + +def test_batch(ray_start_regular_shared): + it = from_range(4, 1).batch(2) + assert repr(it) == "ParallelIterator[from_range[4, shards=1].batch(2)]" + assert list(it.gather_sync()) == [[0, 1], [2, 3]] + + +def test_flatten(ray_start_regular_shared): + it = from_items([[1, 2], [3, 4]], 1).flatten() + assert repr( + it) == "ParallelIterator[from_items[list, 2, 
shards=1].flatten()]" + assert list(it.gather_sync()) == [1, 2, 3, 4] + + +def test_gather_sync(ray_start_regular_shared): + it = from_range(4) + it = it.gather_sync() + assert ( + repr(it) == "LocalIterator[ParallelIterator[from_range[4, shards=2]]" + ".gather_sync()]") + assert sorted(it) == [0, 1, 2, 3] + + +def test_gather_async(ray_start_regular_shared): + it = from_range(4) + it = it.gather_async() + assert ( + repr(it) == "LocalIterator[ParallelIterator[from_range[4, shards=2]]" + ".gather_async()]") + assert sorted(it) == [0, 1, 2, 3] + + +def test_batch_across_shards(ray_start_regular_shared): + it = from_iterators([[0, 1], [2, 3]]) + it = it.batch_across_shards() + assert ( + repr(it) == "LocalIterator[ParallelIterator[from_iterators[shards=2]]" + ".batch_across_shards()]") + assert sorted(it) == [[0, 2], [1, 3]] + + +def test_remote(ray_start_regular_shared): + it = from_iterators([[0, 1], [3, 4], [5, 6, 7]]) + assert it.num_shards() == 3 + + @ray.remote + def get_shard(it, i): + return list(it.get_shard(i)) + + assert ray.get(get_shard.remote(it, 0)) == [0, 1] + assert ray.get(get_shard.remote(it, 1)) == [3, 4] + assert ray.get(get_shard.remote(it, 2)) == [5, 6, 7] + + @ray.remote + def check_remote(it): + assert ray.get(get_shard.remote(it, 0)) == [0, 1] + assert ray.get(get_shard.remote(it, 1)) == [3, 4] + assert ray.get(get_shard.remote(it, 2)) == [5, 6, 7] + + ray.get(check_remote.remote(it)) + + +def test_union(ray_start_regular_shared): + it1 = from_items(["a", "b", "c"], 1) + it2 = from_items(["x", "y", "z"], 1) + it = it1.union(it2) + assert (repr(it) == "ParallelIterator[ParallelUnion[ParallelIterator[" + "from_items[str, 3, shards=1]], ParallelIterator[" + "from_items[str, 3, shards=1]]]]") + assert list(it.gather_sync()) == ["a", "x", "b", "y", "c", "z"] + + +def test_union_local(ray_start_regular_shared): + it1 = from_items(["a", "b", "c"], 1).gather_async() + it2 = from_range(5, 2).for_each(str).gather_async() + it = it1.union(it2) + assert 
sorted(it) == ["0", "1", "2", "3", "4", "a", "b", "c"] + + +def test_union_async(ray_start_regular_shared): + def gen_fast(): + for i in range(10): + time.sleep(0.05) + print("PRODUCE FAST", i) + yield i + + def gen_slow(): + for i in range(10): + time.sleep(0.3) + print("PRODUCE SLOW", i) + yield i + + it1 = from_iterators([gen_fast]).for_each(lambda x: ("fast", x)) + it2 = from_iterators([gen_slow]).for_each(lambda x: ("slow", x)) + it = it1.union(it2) + results = list(it.gather_async()) + assert all(x[0] == "slow" for x in results[-3:]), results + + +def test_union_local_async(ray_start_regular_shared): + def gen_fast(): + for i in range(10): + time.sleep(0.05) + print("PRODUCE FAST", i) + yield i + + def gen_slow(): + for i in range(10): + time.sleep(0.3) + print("PRODUCE SLOW", i) + yield i + + it1 = from_iterators([gen_fast]).for_each(lambda x: ("fast", x)) + it2 = from_iterators([gen_slow]).for_each(lambda x: ("slow", x)) + it = it1.gather_async().union(it2.gather_async()) + assert (repr(it) == "LocalIterator[LocalUnion[LocalIterator[" + "ParallelIterator[from_iterators[shards=1].for_each()]" + ".gather_async()], LocalIterator[ParallelIterator[" + "from_iterators[shards=1].for_each()].gather_async()]]]") + results = list(it) + assert all(x[0] == "slow" for x in results[-3:]), results + + +def test_serialization(ray_start_regular_shared): + it = (from_items([1, 2, 3, 4]).gather_sync().for_each(lambda x: x) + .filter(lambda x: True).batch(2).flatten()) + assert (repr(it) == "LocalIterator[ParallelIterator[" + "from_items[int, 4, shards=2]].gather_sync()." + "for_each().filter().batch(2).flatten()]") + + @ray.remote + def get(it): + return list(it) + + assert ray.get(get.remote(it)) == [1, 2, 3, 4] + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", __file__]))