[iter] Add .transform() function for arbitrary generator transforms (#8978)

This commit is contained in:
Eric Liang
2020-06-25 11:04:14 -07:00
committed by GitHub
parent 0f1d99befc
commit 4522038259
3 changed files with 62 additions and 4 deletions
+18
View File
@@ -17,6 +17,24 @@ def test_select_shards(ray_start_regular_shared):
assert it2.take(4) == [2, 4]
def test_transform(ray_start_regular_shared):
def f(it):
for item in it:
yield item * 2
def g(it):
for item in it:
if item >= 2:
yield item
it = from_range(4).transform(f)
assert repr(it) == "ParallelIterator[from_range[4, shards=2].transform()]"
assert list(it.gather_sync()) == [0, 4, 2, 6]
it = from_range(4)
assert list(it.gather_sync().transform(g)) == [2, 3]
def test_metrics(ray_start_regular_shared):
it = from_items([1, 2, 3, 4], num_shards=1)
it2 = from_items([1, 2, 3, 4], num_shards=1)
+43 -3
View File
@@ -193,10 +193,37 @@ class ParallelIterator(Generic[T]):
name=self.name + name,
parent_iterators=self.parent_iterators)
def transform(self, fn: Callable[[Iterable[T]], Iterable[U]]
) -> "ParallelIterator[U]":
"""Remotely transform the iterator.
This is advanced version of for_each that allows you to apply arbitrary
generator transformations over the iterator. Prefer to use .for_each()
when possible for simplicity.
Args:
fn (func): function to use to transform the iterator. The function
should pass through instances of _NextValueNotReady that appear
in its input iterator. Note that this function is only called
**once** over the input iterator.
Returns:
ParallelIterator[U]: a parallel iterator.
Examples:
>>> def f(it):
... for x in it:
... if x % 2 == 0:
... yield x
>>> from_range(10, 1).transform(f).gather_sync().take(5)
... [0, 2, 4, 6, 8]
"""
return self._with_transform(lambda local_it: local_it.transform(fn),
".transform()")
def for_each(self, fn: Callable[[T], U], max_concurrency=1,
resources=None) -> "ParallelIterator[U]":
"""Remotely apply fn to each item in this iterator, at most `max_concurrency`
at a time per shard.
"""Remotely apply fn to each item in this iterator.
If `max_concurrency` == 1 then `fn` will be executed serially by each
shards
@@ -224,7 +251,6 @@ class ParallelIterator(Generic[T]):
ParallelIterator[U]: a parallel iterator whose elements have `fn`
applied.
Examples:
>>> next(from_range(4).for_each(
lambda x: x * 2,
@@ -737,6 +763,20 @@ class LocalIterator(Generic[T]):
def __repr__(self):
return "LocalIterator[{}]".format(self.name)
def transform(self, fn: Callable[[Iterable[T]], Iterable[U]]
) -> "LocalIterator[U]":
# TODO(ekl) can we automatically handle NextValueNotReady here?
def apply_transform(it):
for item in fn(it):
yield item
return LocalIterator(
self.base_iterator,
self.shared_metrics,
self.local_transforms + [apply_transform],
name=self.name + ".transform()")
def for_each(self, fn: Callable[[T], U], max_concurrency=1,
resources=None) -> "LocalIterator[U]":
if max_concurrency == 1: