mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 02:42:52 +08:00
Add experimental parallel iterators API (#6644)
This commit is contained in:
@@ -8,6 +8,7 @@ from .named_actors import get_actor, register_actor
|
||||
from .api import get, wait
|
||||
from .actor_pool import ActorPool
|
||||
from .dynamic_resources import set_resource
|
||||
from . import iter
|
||||
|
||||
|
||||
def TensorFlowVariables(*args, **kwargs):
|
||||
@@ -17,7 +18,15 @@ def TensorFlowVariables(*args, **kwargs):
|
||||
|
||||
|
||||
__all__ = [
|
||||
"TensorFlowVariables", "get_actor", "register_actor", "get", "wait",
|
||||
"set_flushing_policy", "GcsFlushPolicy", "SimpleGcsFlushPolicy",
|
||||
"set_resource", "ActorPool"
|
||||
"TensorFlowVariables",
|
||||
"get_actor",
|
||||
"register_actor",
|
||||
"get",
|
||||
"wait",
|
||||
"set_flushing_policy",
|
||||
"GcsFlushPolicy",
|
||||
"SimpleGcsFlushPolicy",
|
||||
"set_resource",
|
||||
"ActorPool",
|
||||
"iter",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,668 @@
|
||||
from typing import TypeVar, Generic, Iterable, List, Callable, Any
|
||||
|
||||
import ray
|
||||
|
||||
# The type of an iterator element.
|
||||
T = TypeVar("T")
|
||||
U = TypeVar("U")
|
||||
|
||||
|
||||
def from_items(items: List[T], num_shards: int = 2,
|
||||
repeat: bool = False) -> "ParallelIterator[T]":
|
||||
"""Create a parallel iterator from an existing set of objects.
|
||||
|
||||
The objects will be divided round-robin among the number of shards.
|
||||
|
||||
Arguments:
|
||||
items (list): The list of items to iterate over.
|
||||
num_shards (int): The number of worker actors to create.
|
||||
repeat (bool): Whether to cycle over the items forever.
|
||||
"""
|
||||
shards = [[] for _ in range(num_shards)]
|
||||
for i, item in enumerate(items):
|
||||
shards[i % num_shards].append(item)
|
||||
name = "from_items[{}, {}, shards={}{}]".format(
|
||||
items and type(items[0]).__name__ or "None", len(items), num_shards,
|
||||
", repeat=True" if repeat else "")
|
||||
return from_iterators(shards, repeat=repeat, name=name)
|
||||
|
||||
|
||||
def from_range(n: int, num_shards: int = 2,
|
||||
repeat: bool = False) -> "ParallelIterator[int]":
|
||||
"""Create a parallel iterator over the range 0..n.
|
||||
|
||||
The range will be partitioned sequentially among the number of shards.
|
||||
|
||||
Arguments:
|
||||
n (int): The max end of the range of numbers.
|
||||
num_shards (int): The number of worker actors to create.
|
||||
repeat (bool): Whether to cycle over the range forever.
|
||||
"""
|
||||
generators = []
|
||||
shard_size = n // num_shards
|
||||
for i in range(num_shards):
|
||||
start = i * shard_size
|
||||
if i == num_shards - 1:
|
||||
end = n
|
||||
else:
|
||||
end = (i + 1) * shard_size
|
||||
generators.append(range(start, end))
|
||||
name = "from_range[{}, shards={}{}]".format(
|
||||
n, num_shards, ", repeat=True" if repeat else "")
|
||||
return from_iterators(generators, repeat=repeat, name=name)
|
||||
|
||||
|
||||
def from_iterators(generators: List[Iterable[T]],
|
||||
repeat: bool = False,
|
||||
name=None) -> "ParallelIterator[T]":
|
||||
"""Create a parallel iterator from a set of iterators.
|
||||
|
||||
An actor will be created for each iterator.
|
||||
|
||||
Examples:
|
||||
>>> # Create using a list of generators.
|
||||
>>> from_iterators([range(100), range(100)])
|
||||
|
||||
>>> # Equivalent to the above.
|
||||
>>> from_iterators([lambda: range(100), lambda: range(100)])
|
||||
|
||||
Arguments:
|
||||
generators (list): A list of Python generator objects or lambda
|
||||
functions that produced a generator when called. We allow lambda
|
||||
functions since the generator itself might not be serializable,
|
||||
but a lambda that returns it can be.
|
||||
repeat (bool): Whether to cycle over the iterators forever.
|
||||
name (str): Optional name to give the iterator.
|
||||
"""
|
||||
worker_cls = ray.remote(ParallelIteratorWorker)
|
||||
actors = [worker_cls.remote(g, repeat) for g in generators]
|
||||
if not name:
|
||||
name = "from_iterators[shards={}{}]".format(
|
||||
len(generators), ", repeat=True" if repeat else "")
|
||||
return from_actors(actors, name=name)
|
||||
|
||||
|
||||
def from_actors(actors: List["ray.actor.ActorHandle"],
|
||||
name=None) -> "ParallelIterator[T]":
|
||||
"""Create a parallel iterator from an existing set of actors.
|
||||
|
||||
Each actor must subclass the ParallelIteratorWorker interface.
|
||||
|
||||
Arguments:
|
||||
actors (list): List of actors that each implement
|
||||
ParallelIteratorWorker.
|
||||
name (str): Optional name to give the iterator.
|
||||
"""
|
||||
if not name:
|
||||
name = "from_actors[shards={}]".format(len(actors))
|
||||
return ParallelIterator([_ActorSet(actors, [])], name)
|
||||
|
||||
|
||||
class ParallelIterator(Generic[T]):
|
||||
"""A parallel iterator over a set of remote actors.
|
||||
|
||||
This can be used to iterate over a fixed set of task results
|
||||
(like an actor pool), or a stream of data (e.g., a fixed range of numbers,
|
||||
an infinite stream of RLlib rollout results).
|
||||
|
||||
This class is **serializable** and can be passed to other remote
|
||||
tasks and actors. However, each shard should be read from at most one
|
||||
process at a time.
|
||||
|
||||
Examples:
|
||||
>>> # Applying a function over items in parallel.
|
||||
>>> it = ray.experimental.iter.from_items([1, 2, 3], num_shards=2)
|
||||
... <__main__.ParallelIterator object>
|
||||
>>> it = it.for_each(lambda x: x * 2).gather_sync()
|
||||
... <__main__.LocalIterator object>
|
||||
>>> print(list(it))
|
||||
... [2, 4, 6]
|
||||
|
||||
>>> # Creating from generators.
|
||||
>>> it = ray.experimental.iter.from_iterators([range(3), range(3)])
|
||||
... <__main__.ParallelIterator object>
|
||||
>>> print(list(it.gather_sync()))
|
||||
... [0, 0, 1, 1, 2, 2]
|
||||
|
||||
>>> # Accessing the individual shards of an iterator.
|
||||
>>> it = ray.experimental.iter.from_range(10, num_shards=2)
|
||||
... <__main__.ParallelIterator object>
|
||||
>>> it0 = it.get_shard(0)
|
||||
... <__main__.LocalIterator object>
|
||||
>>> print(list(it0))
|
||||
... [0, 1, 2, 3, 4]
|
||||
>>> it1 = it.get_shard(1)
|
||||
... <__main__.LocalIterator object>
|
||||
>>> print(list(it1))
|
||||
... [5, 6, 7, 8, 9]
|
||||
|
||||
>>> # Gathering results from actors synchronously in parallel.
|
||||
>>> it = ray.experimental.iter.from_actors(workers)
|
||||
... <__main__.ParallelIterator object>
|
||||
>>> it = it.batch_across_shards()
|
||||
... <__main__.LocalIterator object>
|
||||
>>> print(next(it))
|
||||
... [worker_1_result_1, worker_2_result_1]
|
||||
>>> print(next(it))
|
||||
... [worker_1_result_2, worker_2_result_2]
|
||||
"""
|
||||
|
||||
def __init__(self, actor_sets: List["_ActorSet"], name: str):
|
||||
# We track multiple sets of actors to support parallel .union().
|
||||
self.actor_sets = actor_sets
|
||||
self.name = name
|
||||
|
||||
def __iter__(self):
|
||||
raise TypeError(
|
||||
"You must use it.gather_sync() or it.gather_async() to "
|
||||
"iterate over the results of a ParallelIterator.")
|
||||
|
||||
def __str__(self):
|
||||
return repr(self)
|
||||
|
||||
def __repr__(self):
|
||||
return "ParallelIterator[{}]".format(self.name)
|
||||
|
||||
def for_each(self, fn: Callable[[T], U]) -> "ParallelIterator[U]":
|
||||
"""Remotely apply fn to each item in this iterator.
|
||||
|
||||
Arguments:
|
||||
fn (func): function to apply to each item.
|
||||
|
||||
Examples:
|
||||
>>> next(from_range(4).for_each(lambda x: x * 2).gather_sync())
|
||||
... [0, 2, 4, 8]
|
||||
"""
|
||||
return ParallelIterator(
|
||||
[
|
||||
a.with_transform(lambda local_it: local_it.for_each(fn))
|
||||
for a in self.actor_sets
|
||||
],
|
||||
name=self.name + ".for_each()")
|
||||
|
||||
def filter(self, fn: Callable[[T], bool]) -> "ParallelIterator[T]":
|
||||
"""Remotely filter items from this iterator.
|
||||
|
||||
Arguments:
|
||||
fn (func): returns False for items to drop from the iterator.
|
||||
|
||||
Examples:
|
||||
>>> it = from_items([0, 1, 2]).filter(lambda x: x > 0)
|
||||
>>> next(it.gather_sync())
|
||||
... [1, 2]
|
||||
"""
|
||||
return ParallelIterator(
|
||||
[
|
||||
a.with_transform(lambda local_it: local_it.filter(fn))
|
||||
for a in self.actor_sets
|
||||
],
|
||||
name=self.name + ".filter()")
|
||||
|
||||
def batch(self, n: int) -> "ParallelIterator[List[T]]":
|
||||
"""Remotely batch together items in this iterator.
|
||||
|
||||
Arguments:
|
||||
n (int): Number of items to batch together.
|
||||
|
||||
Examples:
|
||||
>>> next(from_range(10, 1).batch(4).gather_sync())
|
||||
... [0, 1, 2, 3]
|
||||
"""
|
||||
return ParallelIterator(
|
||||
[
|
||||
a.with_transform(lambda local_it: local_it.batch(n))
|
||||
for a in self.actor_sets
|
||||
],
|
||||
name=self.name + ".batch({})".format(n))
|
||||
|
||||
def flatten(self) -> "ParallelIterator[T[0]]":
|
||||
"""Flatten batches of items into individual items.
|
||||
|
||||
Examples:
|
||||
>>> next(from_range(10, 1).batch(4).flatten())
|
||||
... 0
|
||||
"""
|
||||
return ParallelIterator(
|
||||
[
|
||||
a.with_transform(lambda local_it: local_it.flatten())
|
||||
for a in self.actor_sets
|
||||
],
|
||||
name=self.name + ".flatten()")
|
||||
|
||||
def combine(self, fn: Callable[[T], List[U]]) -> "ParallelIterator[U]":
|
||||
"""Transform and then combine items horizontally.
|
||||
|
||||
This is the equivalent of for_each(fn).flatten() (flat map).
|
||||
"""
|
||||
it = self.for_each(fn).flatten()
|
||||
it.name = self.name + ".combine()"
|
||||
return it
|
||||
|
||||
def gather_sync(self) -> "LocalIterator[T]":
|
||||
"""Returns a local iterable for synchronous iteration.
|
||||
|
||||
New items will be fetched from the shards on-demand as the iterator
|
||||
is stepped through.
|
||||
|
||||
This is the equivalent of batch_across_shards().flatten().
|
||||
|
||||
Examples:
|
||||
>>> it = from_range(100, 1).gather_sync()
|
||||
>>> next(it)
|
||||
... 0
|
||||
>>> next(it)
|
||||
... 1
|
||||
>>> next(it)
|
||||
... 2
|
||||
"""
|
||||
it = self.batch_across_shards().flatten()
|
||||
it.name = "{}.gather_sync()".format(self)
|
||||
return it
|
||||
|
||||
def batch_across_shards(self) -> "LocalIterator[List[T]]":
|
||||
"""Iterate over the results of multiple shards in parallel.
|
||||
|
||||
Examples:
|
||||
>>> it = from_iterators([range(3), range(3)])
|
||||
>>> next(it.batch_across_shards())
|
||||
... [0, 0]
|
||||
"""
|
||||
|
||||
def base_iterator(timeout=None):
|
||||
active = []
|
||||
for actor_set in self.actor_sets:
|
||||
actor_set.init_actors()
|
||||
active.extend(actor_set.actors)
|
||||
futures = [a.par_iter_next.remote() for a in active]
|
||||
while active:
|
||||
try:
|
||||
yield ray.get(futures, timeout=timeout)
|
||||
futures = [a.par_iter_next.remote() for a in active]
|
||||
# Always yield after each round of gets with timeout.
|
||||
if timeout is not None:
|
||||
yield _NextValueNotReady()
|
||||
except TimeoutError:
|
||||
yield _NextValueNotReady()
|
||||
except StopIteration:
|
||||
# Find and remove the actor that produced StopIteration.
|
||||
results = []
|
||||
for a, f in zip(list(active), futures):
|
||||
try:
|
||||
results.append(ray.get(f))
|
||||
except StopIteration:
|
||||
active.remove(a)
|
||||
if results:
|
||||
yield results
|
||||
futures = [a.par_iter_next.remote() for a in active]
|
||||
|
||||
name = "{}.batch_across_shards()".format(self)
|
||||
return LocalIterator(base_iterator, name=name)
|
||||
|
||||
def gather_async(self) -> "LocalIterator[T]":
|
||||
"""Returns a local iterable for asynchronous iteration.
|
||||
|
||||
New items will be fetched from the shards asynchronously as soon as
|
||||
the previous one is computed. Items arrive in non-deterministic order.
|
||||
|
||||
Examples:
|
||||
>>> it = from_range(100, 1).gather_async()
|
||||
>>> next(it)
|
||||
... 3
|
||||
>>> next(it)
|
||||
... 0
|
||||
>>> next(it)
|
||||
... 1
|
||||
"""
|
||||
|
||||
def base_iterator(timeout=None):
|
||||
all_actors = []
|
||||
for actor_set in self.actor_sets:
|
||||
actor_set.init_actors()
|
||||
all_actors.extend(actor_set.actors)
|
||||
futures = {}
|
||||
for a in all_actors:
|
||||
futures[a.par_iter_next.remote()] = a
|
||||
while futures:
|
||||
pending = list(futures)
|
||||
if timeout is None:
|
||||
# First try to do a batch wait for efficiency.
|
||||
ready, _ = ray.wait(
|
||||
pending, num_returns=len(pending), timeout=0)
|
||||
# Fall back to a blocking wait.
|
||||
if not ready:
|
||||
ready, _ = ray.wait(pending, num_returns=1)
|
||||
else:
|
||||
ready, _ = ray.wait(
|
||||
pending, num_returns=len(pending), timeout=timeout)
|
||||
for obj_id in ready:
|
||||
actor = futures.pop(obj_id)
|
||||
try:
|
||||
yield ray.get(obj_id)
|
||||
futures[actor.par_iter_next.remote()] = actor
|
||||
except StopIteration:
|
||||
pass
|
||||
# Always yield after each round of wait with timeout.
|
||||
if timeout is not None:
|
||||
yield _NextValueNotReady()
|
||||
|
||||
name = "{}.gather_async()".format(self)
|
||||
return LocalIterator(base_iterator, name=name)
|
||||
|
||||
def take(self, n: int) -> List[T]:
|
||||
"""Return up to the first n items from this iterator."""
|
||||
return self.gather_sync().take(n)
|
||||
|
||||
def show(self, n: int = 20):
|
||||
"""Print up to the first n items from this iterator."""
|
||||
return self.gather_sync().show(n)
|
||||
|
||||
def union(self, other: "ParallelIterator[T]") -> "ParallelIterator[T]":
|
||||
"""Return an iterator that is the union of this and the other."""
|
||||
if not isinstance(other, ParallelIterator):
|
||||
raise ValueError(
|
||||
"other must be of type ParallelIterator, got {}".format(
|
||||
type(other)))
|
||||
actor_sets = []
|
||||
actor_sets.extend(self.actor_sets)
|
||||
actor_sets.extend(other.actor_sets)
|
||||
return ParallelIterator(actor_sets, "ParallelUnion[{}, {}]".format(
|
||||
self, other))
|
||||
|
||||
def num_shards(self) -> int:
|
||||
"""Return the number of worker actors backing this iterator."""
|
||||
return sum(len(a.actors) for a in self.actor_sets)
|
||||
|
||||
def shards(self) -> List["LocalIterator[T]"]:
|
||||
"""Return the list of all shards."""
|
||||
return [self.get_shard(i) for i in range(self.num_shards())]
|
||||
|
||||
def get_shard(self, shard_index: int) -> "LocalIterator[T]":
|
||||
"""Return a local iterator for the given shard.
|
||||
|
||||
The iterator is guaranteed to be serializable and can be passed to
|
||||
remote tasks or actors.
|
||||
"""
|
||||
a, t = None, None
|
||||
i = shard_index
|
||||
for actor_set in self.actor_sets:
|
||||
if i < len(actor_set.actors):
|
||||
a = actor_set.actors[i]
|
||||
t = actor_set.transforms
|
||||
break
|
||||
else:
|
||||
i -= len(actor_set.actors)
|
||||
if a is None:
|
||||
raise ValueError("Shard index out of range", shard_index,
|
||||
self.num_shards())
|
||||
|
||||
def base_iterator(timeout=None):
|
||||
ray.get(a.par_iter_init.remote(t))
|
||||
while True:
|
||||
try:
|
||||
yield ray.get(a.par_iter_next.remote(), timeout=timeout)
|
||||
# Always yield after each round of gets with timeout.
|
||||
if timeout is not None:
|
||||
yield _NextValueNotReady()
|
||||
except TimeoutError:
|
||||
yield _NextValueNotReady()
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
name = self.name + ".shard[{}]".format(shard_index)
|
||||
return LocalIterator(base_iterator, name=name)
|
||||
|
||||
|
||||
class LocalIterator(Generic[T]):
|
||||
"""An iterator over a single shard of data.
|
||||
|
||||
It implements similar transformations as ParallelIterator[T], but the
|
||||
transforms will be applied locally and not remotely in parallel.
|
||||
|
||||
This class is **serializable** and can be passed to other remote
|
||||
tasks and actors. However, it should be read from at most one process at
|
||||
a time."""
|
||||
|
||||
def __init__(self,
|
||||
base_iterator: Callable[[], Iterable[T]],
|
||||
local_transforms: List[Callable[[Iterable], Any]] = None,
|
||||
timeout: int = None,
|
||||
name=None):
|
||||
"""Create a local iterator (this is an internal function).
|
||||
|
||||
Arguments:
|
||||
base_iterator (func): A function that produces the base iterator.
|
||||
This is a function so that we can ensure LocalIterator is
|
||||
serializable.
|
||||
local_transforms (list): A list of transformation functions to be
|
||||
applied on top of the base iterator. When iteration begins, we
|
||||
create the base iterator and apply these functions. This lazy
|
||||
creation ensures LocalIterator is serializable until you start
|
||||
iterating over it.
|
||||
timeout (int): Optional timeout in seconds for this iterator, after
|
||||
which _NextValueNotReady will be returned. This avoids
|
||||
blocking.
|
||||
name (str): Optional name for this iterator.
|
||||
"""
|
||||
self.base_iterator = base_iterator
|
||||
self.built_iterator = None
|
||||
self.local_transforms = local_transforms or []
|
||||
self.timeout = timeout
|
||||
self.name = name or "unknown"
|
||||
|
||||
def _build_once(self):
|
||||
if self.built_iterator is None:
|
||||
it = iter(self.base_iterator(self.timeout))
|
||||
for fn in self.local_transforms:
|
||||
it = fn(it)
|
||||
self.built_iterator = it
|
||||
|
||||
def __iter__(self):
|
||||
self._build_once()
|
||||
return self.built_iterator
|
||||
|
||||
def __next__(self):
|
||||
self._build_once()
|
||||
return next(self.built_iterator)
|
||||
|
||||
def __str__(self):
|
||||
return repr(self)
|
||||
|
||||
def __repr__(self):
|
||||
return "LocalIterator[{}]".format(self.name)
|
||||
|
||||
def for_each(self, fn: Callable[[T], U]) -> "LocalIterator[U]":
|
||||
def apply_foreach(it):
|
||||
for item in it:
|
||||
if isinstance(item, _NextValueNotReady):
|
||||
yield item
|
||||
else:
|
||||
yield fn(item)
|
||||
|
||||
return LocalIterator(
|
||||
self.base_iterator,
|
||||
self.local_transforms + [apply_foreach],
|
||||
name=self.name + ".for_each()")
|
||||
|
||||
def filter(self, fn: Callable[[T], bool]) -> "LocalIterator[T]":
|
||||
def apply_filter(it):
|
||||
for item in it:
|
||||
if isinstance(item, _NextValueNotReady) or fn(item):
|
||||
yield item
|
||||
|
||||
return LocalIterator(
|
||||
self.base_iterator,
|
||||
self.local_transforms + [apply_filter],
|
||||
name=self.name + ".filter()")
|
||||
|
||||
def batch(self, n: int) -> "LocalIterator[List[T]]":
|
||||
def apply_batch(it):
|
||||
batch = []
|
||||
for item in it:
|
||||
if isinstance(item, _NextValueNotReady):
|
||||
yield item
|
||||
else:
|
||||
batch.append(item)
|
||||
if len(batch) >= n:
|
||||
yield batch
|
||||
batch = []
|
||||
if batch:
|
||||
yield batch
|
||||
|
||||
return LocalIterator(
|
||||
self.base_iterator,
|
||||
self.local_transforms + [apply_batch],
|
||||
name=self.name + ".batch({})".format(n))
|
||||
|
||||
def flatten(self) -> "LocalIterator[T[0]]":
|
||||
def apply_flatten(it):
|
||||
for item in it:
|
||||
if isinstance(item, _NextValueNotReady):
|
||||
yield item
|
||||
else:
|
||||
for subitem in item:
|
||||
yield subitem
|
||||
|
||||
return LocalIterator(
|
||||
self.base_iterator,
|
||||
self.local_transforms + [apply_flatten],
|
||||
name=self.name + ".flatten()")
|
||||
|
||||
def combine(self, fn: Callable[[T], List[U]]) -> "LocalIterator[U]":
|
||||
it = self.for_each(fn).flatten()
|
||||
it.name = self.name + ".combine()"
|
||||
return it
|
||||
|
||||
def take(self, n: int) -> List[T]:
|
||||
"""Return up to the first n items from this iterator."""
|
||||
out = []
|
||||
for item in self:
|
||||
out.append(item)
|
||||
if len(out) >= n:
|
||||
break
|
||||
return out
|
||||
|
||||
def show(self, n: int = 20):
|
||||
"""Print up to the first n items from this iterator."""
|
||||
i = 0
|
||||
for item in self:
|
||||
print(item)
|
||||
i += 1
|
||||
if i >= n:
|
||||
break
|
||||
|
||||
def union(self, other: "LocalIterator[T]") -> "LocalIterator[T]":
|
||||
"""Return an iterator that is the union of this and the other.
|
||||
|
||||
There are no ordering guarantees between the two iterators. We make a
|
||||
best-effort attempt to return items from both as they become ready,
|
||||
preventing starvation of any particular iterator.
|
||||
"""
|
||||
|
||||
if not isinstance(other, LocalIterator):
|
||||
raise ValueError(
|
||||
"other must be of type LocalIterator, got {}".format(
|
||||
type(other)))
|
||||
|
||||
it1 = LocalIterator(
|
||||
self.base_iterator, self.local_transforms, timeout=0)
|
||||
it2 = LocalIterator(
|
||||
other.base_iterator, other.local_transforms, timeout=0)
|
||||
active = [it1, it2]
|
||||
|
||||
def build_union(timeout=None):
|
||||
while True:
|
||||
for it in list(active):
|
||||
# Yield items from the iterator until _NextValueNotReady is
|
||||
# found, then switch to the next iterator.
|
||||
try:
|
||||
while True:
|
||||
item = next(it)
|
||||
if isinstance(item, _NextValueNotReady):
|
||||
break
|
||||
else:
|
||||
yield item
|
||||
except StopIteration:
|
||||
active.remove(it)
|
||||
if not active:
|
||||
break
|
||||
|
||||
return LocalIterator(
|
||||
build_union, [], name="LocalUnion[{}, {}]".format(self, other))
|
||||
|
||||
|
||||
class ParallelIteratorWorker(object):
|
||||
"""Worker actor for a ParallelIterator.
|
||||
|
||||
Actors that are passed to iter.from_actors() must subclass this interface.
|
||||
"""
|
||||
|
||||
def __init__(self, item_generator: Any, repeat: bool):
|
||||
"""Create an iterator worker.
|
||||
|
||||
Subclasses must call this init function.
|
||||
|
||||
Arguments:
|
||||
item_generator (obj): A Python generator objects or lambda function
|
||||
that produces a generator when called. We allow lambda
|
||||
functions since the generator itself might not be serializable,
|
||||
but a lambda that returns it can be.
|
||||
repeat (bool): Whether to loop over the iterator forever.
|
||||
"""
|
||||
|
||||
def make_iterator():
|
||||
if callable(item_generator):
|
||||
return item_generator()
|
||||
else:
|
||||
return item_generator
|
||||
|
||||
if repeat:
|
||||
|
||||
def cycle():
|
||||
while True:
|
||||
it = make_iterator()
|
||||
for item in it:
|
||||
yield item
|
||||
|
||||
self.item_generator = cycle()
|
||||
else:
|
||||
self.item_generator = make_iterator()
|
||||
|
||||
self.transforms = []
|
||||
self.local_it = None
|
||||
|
||||
def par_iter_init(self, transforms):
|
||||
"""Implements ParallelIterator worker init."""
|
||||
it = LocalIterator(lambda timeout: self.item_generator)
|
||||
for fn in transforms:
|
||||
it = fn(it)
|
||||
assert it is not None, fn
|
||||
self.local_it = iter(it)
|
||||
|
||||
def par_iter_next(self):
|
||||
"""Implements ParallelIterator worker item fetch."""
|
||||
assert self.local_it is not None, "must call par_iter_init()"
|
||||
return next(self.local_it)
|
||||
|
||||
|
||||
class _NextValueNotReady(Exception):
|
||||
"""Indicates that a local iterator has no value currently available.
|
||||
|
||||
This is used internally to implement the union() of multiple blocking
|
||||
local generators."""
|
||||
pass
|
||||
|
||||
|
||||
class _ActorSet(object):
|
||||
"""Helper class that represents a set of actors and transforms."""
|
||||
|
||||
def __init__(
|
||||
self, actors: List["ray.actor.ActorHandle"],
|
||||
transforms: List[Callable[["LocalIterator"], "LocalIterator"]]):
|
||||
self.actors = actors
|
||||
self.transforms = transforms
|
||||
|
||||
def init_actors(self):
|
||||
ray.get([a.par_iter_init.remote(self.transforms) for a in self.actors])
|
||||
|
||||
def with_transform(self, fn):
|
||||
return _ActorSet(self.actors, self.transforms + [fn])
|
||||
@@ -14,6 +14,14 @@ py_test(
|
||||
deps = ["//:ray_lib"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_iter",
|
||||
size = "small",
|
||||
srcs = ["test_iter.py"],
|
||||
tags = ["exclusive"],
|
||||
deps = ["//:ray_lib"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_actor_resources",
|
||||
size = "medium",
|
||||
|
||||
@@ -0,0 +1,216 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import time
|
||||
|
||||
import ray
|
||||
from ray.experimental.iter import from_items, from_iterators, from_range, \
|
||||
from_actors, ParallelIteratorWorker
|
||||
|
||||
|
||||
def test_from_items(ray_start_regular_shared):
|
||||
it = from_items([1, 2, 3, 4])
|
||||
assert repr(it) == "ParallelIterator[from_items[int, 4, shards=2]]"
|
||||
assert list(it.gather_sync()) == [1, 2, 3, 4]
|
||||
assert next(it.gather_sync()) == 1
|
||||
|
||||
|
||||
def test_from_items_repeat(ray_start_regular_shared):
|
||||
it = from_items([1, 2, 3, 4], repeat=True)
|
||||
assert repr(
|
||||
it) == "ParallelIterator[from_items[int, 4, shards=2, repeat=True]]"
|
||||
assert it.take(8) == [1, 2, 3, 4, 1, 2, 3, 4]
|
||||
|
||||
|
||||
def test_from_iterators(ray_start_regular_shared):
|
||||
it = from_iterators([range(2), range(2)])
|
||||
assert repr(it) == "ParallelIterator[from_iterators[shards=2]]"
|
||||
assert list(it.gather_sync()) == [0, 0, 1, 1]
|
||||
|
||||
|
||||
def test_from_range(ray_start_regular_shared):
|
||||
it = from_range(4)
|
||||
assert repr(it) == "ParallelIterator[from_range[4, shards=2]]"
|
||||
assert list(it.gather_sync()) == [0, 2, 1, 3]
|
||||
|
||||
|
||||
def test_from_actors(ray_start_regular_shared):
|
||||
@ray.remote
|
||||
class CustomWorker(ParallelIteratorWorker):
|
||||
def __init__(self, data):
|
||||
ParallelIteratorWorker.__init__(self, data, False)
|
||||
|
||||
a = CustomWorker.remote([1, 2])
|
||||
b = CustomWorker.remote([3, 4])
|
||||
it = from_actors([a, b])
|
||||
assert repr(it) == "ParallelIterator[from_actors[shards=2]]"
|
||||
assert list(it.gather_sync()) == [1, 3, 2, 4]
|
||||
|
||||
|
||||
def test_for_each(ray_start_regular_shared):
|
||||
it = from_range(4).for_each(lambda x: x * 2)
|
||||
assert repr(it) == "ParallelIterator[from_range[4, shards=2].for_each()]"
|
||||
assert list(it.gather_sync()) == [0, 4, 2, 6]
|
||||
|
||||
|
||||
def test_combine(ray_start_regular_shared):
|
||||
it = from_range(4, 1).combine(lambda x: [x, x])
|
||||
assert repr(it) == "ParallelIterator[from_range[4, shards=1].combine()]"
|
||||
assert list(it.gather_sync()) == [0, 0, 1, 1, 2, 2, 3, 3]
|
||||
|
||||
|
||||
def test_chain(ray_start_regular_shared):
|
||||
it = from_range(4).for_each(lambda x: x * 2).for_each(lambda x: x * 2)
|
||||
assert repr(
|
||||
it
|
||||
) == "ParallelIterator[from_range[4, shards=2].for_each().for_each()]"
|
||||
assert list(it.gather_sync()) == [0, 8, 4, 12]
|
||||
|
||||
|
||||
def test_filter(ray_start_regular_shared):
|
||||
it = from_range(4).filter(lambda x: x < 3)
|
||||
assert repr(it) == "ParallelIterator[from_range[4, shards=2].filter()]"
|
||||
assert list(it.gather_sync()) == [0, 2, 1]
|
||||
|
||||
|
||||
def test_batch(ray_start_regular_shared):
|
||||
it = from_range(4, 1).batch(2)
|
||||
assert repr(it) == "ParallelIterator[from_range[4, shards=1].batch(2)]"
|
||||
assert list(it.gather_sync()) == [[0, 1], [2, 3]]
|
||||
|
||||
|
||||
def test_flatten(ray_start_regular_shared):
|
||||
it = from_items([[1, 2], [3, 4]], 1).flatten()
|
||||
assert repr(
|
||||
it) == "ParallelIterator[from_items[list, 2, shards=1].flatten()]"
|
||||
assert list(it.gather_sync()) == [1, 2, 3, 4]
|
||||
|
||||
|
||||
def test_gather_sync(ray_start_regular_shared):
|
||||
it = from_range(4)
|
||||
it = it.gather_sync()
|
||||
assert (
|
||||
repr(it) == "LocalIterator[ParallelIterator[from_range[4, shards=2]]"
|
||||
".gather_sync()]")
|
||||
assert sorted(it) == [0, 1, 2, 3]
|
||||
|
||||
|
||||
def test_gather_async(ray_start_regular_shared):
|
||||
it = from_range(4)
|
||||
it = it.gather_async()
|
||||
assert (
|
||||
repr(it) == "LocalIterator[ParallelIterator[from_range[4, shards=2]]"
|
||||
".gather_async()]")
|
||||
assert sorted(it) == [0, 1, 2, 3]
|
||||
|
||||
|
||||
def test_batch_across_shards(ray_start_regular_shared):
|
||||
it = from_iterators([[0, 1], [2, 3]])
|
||||
it = it.batch_across_shards()
|
||||
assert (
|
||||
repr(it) == "LocalIterator[ParallelIterator[from_iterators[shards=2]]"
|
||||
".batch_across_shards()]")
|
||||
assert sorted(it) == [[0, 2], [1, 3]]
|
||||
|
||||
|
||||
def test_remote(ray_start_regular_shared):
|
||||
it = from_iterators([[0, 1], [3, 4], [5, 6, 7]])
|
||||
assert it.num_shards() == 3
|
||||
|
||||
@ray.remote
|
||||
def get_shard(it, i):
|
||||
return list(it.get_shard(i))
|
||||
|
||||
assert ray.get(get_shard.remote(it, 0)) == [0, 1]
|
||||
assert ray.get(get_shard.remote(it, 1)) == [3, 4]
|
||||
assert ray.get(get_shard.remote(it, 2)) == [5, 6, 7]
|
||||
|
||||
@ray.remote
|
||||
def check_remote(it):
|
||||
assert ray.get(get_shard.remote(it, 0)) == [0, 1]
|
||||
assert ray.get(get_shard.remote(it, 1)) == [3, 4]
|
||||
assert ray.get(get_shard.remote(it, 2)) == [5, 6, 7]
|
||||
|
||||
ray.get(check_remote.remote(it))
|
||||
|
||||
|
||||
def test_union(ray_start_regular_shared):
|
||||
it1 = from_items(["a", "b", "c"], 1)
|
||||
it2 = from_items(["x", "y", "z"], 1)
|
||||
it = it1.union(it2)
|
||||
assert (repr(it) == "ParallelIterator[ParallelUnion[ParallelIterator["
|
||||
"from_items[str, 3, shards=1]], ParallelIterator["
|
||||
"from_items[str, 3, shards=1]]]]")
|
||||
assert list(it.gather_sync()) == ["a", "x", "b", "y", "c", "z"]
|
||||
|
||||
|
||||
def test_union_local(ray_start_regular_shared):
|
||||
it1 = from_items(["a", "b", "c"], 1).gather_async()
|
||||
it2 = from_range(5, 2).for_each(str).gather_async()
|
||||
it = it1.union(it2)
|
||||
assert sorted(it) == ["0", "1", "2", "3", "4", "a", "b", "c"]
|
||||
|
||||
|
||||
def test_union_async(ray_start_regular_shared):
|
||||
def gen_fast():
|
||||
for i in range(10):
|
||||
time.sleep(0.05)
|
||||
print("PRODUCE FAST", i)
|
||||
yield i
|
||||
|
||||
def gen_slow():
|
||||
for i in range(10):
|
||||
time.sleep(0.3)
|
||||
print("PRODUCE SLOW", i)
|
||||
yield i
|
||||
|
||||
it1 = from_iterators([gen_fast]).for_each(lambda x: ("fast", x))
|
||||
it2 = from_iterators([gen_slow]).for_each(lambda x: ("slow", x))
|
||||
it = it1.union(it2)
|
||||
results = list(it.gather_async())
|
||||
assert all(x[0] == "slow" for x in results[-3:]), results
|
||||
|
||||
|
||||
def test_union_local_async(ray_start_regular_shared):
|
||||
def gen_fast():
|
||||
for i in range(10):
|
||||
time.sleep(0.05)
|
||||
print("PRODUCE FAST", i)
|
||||
yield i
|
||||
|
||||
def gen_slow():
|
||||
for i in range(10):
|
||||
time.sleep(0.3)
|
||||
print("PRODUCE SLOW", i)
|
||||
yield i
|
||||
|
||||
it1 = from_iterators([gen_fast]).for_each(lambda x: ("fast", x))
|
||||
it2 = from_iterators([gen_slow]).for_each(lambda x: ("slow", x))
|
||||
it = it1.gather_async().union(it2.gather_async())
|
||||
assert (repr(it) == "LocalIterator[LocalUnion[LocalIterator["
|
||||
"ParallelIterator[from_iterators[shards=1].for_each()]"
|
||||
".gather_async()], LocalIterator[ParallelIterator["
|
||||
"from_iterators[shards=1].for_each()].gather_async()]]]")
|
||||
results = list(it)
|
||||
assert all(x[0] == "slow" for x in results[-3:]), results
|
||||
|
||||
|
||||
def test_serialization(ray_start_regular_shared):
|
||||
it = (from_items([1, 2, 3, 4]).gather_sync().for_each(lambda x: x)
|
||||
.filter(lambda x: True).batch(2).flatten())
|
||||
assert (repr(it) == "LocalIterator[ParallelIterator["
|
||||
"from_items[int, 4, shards=2]].gather_sync()."
|
||||
"for_each().filter().batch(2).flatten()]")
|
||||
|
||||
@ray.remote
|
||||
def get(it):
|
||||
return list(it)
|
||||
|
||||
assert ray.get(get.remote(it)) == [1, 2, 3, 4]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
Reference in New Issue
Block a user