mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 22:23:17 +08:00
[iter] Add .transform() function for arbitrary generator transforms (#8978)
This commit is contained in:
@@ -17,6 +17,24 @@ def test_select_shards(ray_start_regular_shared):
|
||||
assert it2.take(4) == [2, 4]
|
||||
|
||||
|
||||
def test_transform(ray_start_regular_shared):
    """Exercise .transform() before gathering and after gather_sync()."""

    def double_each(stream):
        # Generator transform: emit every input element multiplied by two.
        for value in stream:
            yield value * 2

    def keep_at_least_two(stream):
        # Generator transform: drop every element smaller than 2.
        for value in stream:
            if value >= 2:
                yield value

    # Transform applied to the parallel iterator, then gathered.
    doubled = from_range(4).transform(double_each)
    assert repr(doubled) == (
        "ParallelIterator[from_range[4, shards=2].transform()]")
    assert list(doubled.gather_sync()) == [0, 4, 2, 6]

    # Transform applied to the result of gather_sync().
    gathered = from_range(4).gather_sync()
    assert list(gathered.transform(keep_at_least_two)) == [2, 3]
|
||||
|
||||
|
||||
def test_metrics(ray_start_regular_shared):
|
||||
it = from_items([1, 2, 3, 4], num_shards=1)
|
||||
it2 = from_items([1, 2, 3, 4], num_shards=1)
|
||||
|
||||
+43
-3
@@ -193,10 +193,37 @@ class ParallelIterator(Generic[T]):
|
||||
name=self.name + name,
|
||||
parent_iterators=self.parent_iterators)
|
||||
|
||||
def transform(self, fn: Callable[[Iterable[T]], Iterable[U]]
              ) -> "ParallelIterator[U]":
    """Remotely transform the iterator.

    This is an advanced version of for_each that allows you to apply
    arbitrary generator transformations over the iterator. Prefer to use
    .for_each() when possible for simplicity.

    Args:
        fn (func): function to use to transform the iterator. The function
            should pass through instances of _NextValueNotReady that appear
            in its input iterator. Note that this function is only called
            **once** over the input iterator.

    Returns:
        ParallelIterator[U]: a parallel iterator.

    Examples:
        >>> def f(it):
        ...     for x in it:
        ...         if x % 2 == 0:
        ...             yield x
        >>> from_range(10, 1).transform(f).gather_sync().take(5)
        [0, 2, 4, 6, 8]
    """

    def transform_local(local_it):
        # Delegate to the local iterator's own transform() with the same fn.
        return local_it.transform(fn)

    return self._with_transform(transform_local, ".transform()")
|
||||
|
||||
def for_each(self, fn: Callable[[T], U], max_concurrency=1,
|
||||
resources=None) -> "ParallelIterator[U]":
|
||||
"""Remotely apply fn to each item in this iterator, at most `max_concurrency`
|
||||
at a time per shard.
|
||||
"""Remotely apply fn to each item in this iterator.
|
||||
|
||||
If `max_concurrency` == 1 then `fn` will be executed serially by each
|
||||
shards
|
||||
@@ -224,7 +251,6 @@ class ParallelIterator(Generic[T]):
|
||||
ParallelIterator[U]: a parallel iterator whose elements have `fn`
|
||||
applied.
|
||||
|
||||
|
||||
Examples:
|
||||
>>> next(from_range(4).for_each(
|
||||
lambda x: x * 2,
|
||||
@@ -737,6 +763,20 @@ class LocalIterator(Generic[T]):
|
||||
def __repr__(self):
|
||||
return "LocalIterator[{}]".format(self.name)
|
||||
|
||||
def transform(self, fn: Callable[[Iterable[T]], Iterable[U]]
              ) -> "LocalIterator[U]":
    """Apply a generator transformation over this iterator.

    Args:
        fn (func): callable that receives the input iterator and returns an
            iterable of output items.

    Returns:
        LocalIterator[U]: a new iterator built from this one's base iterator
            and shared metrics, with the transformation appended to the
            existing local transforms and ".transform()" appended to the
            name.
    """

    # TODO(ekl) can we automatically handle NextValueNotReady here?
    def run_fn(stream):
        # Re-yield everything fn produces for the incoming stream.
        yield from fn(stream)

    chained_transforms = self.local_transforms + [run_fn]
    return LocalIterator(
        self.base_iterator,
        self.shared_metrics,
        chained_transforms,
        name=self.name + ".transform()")
|
||||
|
||||
def for_each(self, fn: Callable[[T], U], max_concurrency=1,
|
||||
resources=None) -> "LocalIterator[U]":
|
||||
if max_concurrency == 1:
|
||||
|
||||
Reference in New Issue
Block a user