diff --git a/python/ray/test_utils.py b/python/ray/test_utils.py
index 12995b174..3e53bec7d 100644
--- a/python/ray/test_utils.py
+++ b/python/ray/test_utils.py
@@ -223,6 +223,25 @@ class SignalActor:
         await self.ready_event.wait()
 
 
+@ray.remote(num_cpus=0)
+class Semaphore:
+    """Actor wrapping an `asyncio.Semaphore` for cross-task signaling."""
+
+    def __init__(self, value=1):
+        self._sema = asyncio.Semaphore(value=value)
+
+    async def acquire(self):
+        # `asyncio.Semaphore.acquire` is a coroutine; it must be awaited or
+        # the call never runs and the caller does not block.
+        await self._sema.acquire()
+
+    async def release(self):
+        self._sema.release()
+
+    async def locked(self):
+        return self._sema.locked()
+
+
 @ray.remote
 def _put(obj):
     return obj
diff --git a/python/ray/tests/test_iter.py b/python/ray/tests/test_iter.py
index 0242eb320..931ec2283 100644
--- a/python/ray/tests/test_iter.py
+++ b/python/ray/tests/test_iter.py
@@ -6,6 +6,7 @@ import pytest
 import ray
 from ray.util.iter import from_items, from_iterators, from_range, \
     from_actors, ParallelIteratorWorker, LocalIterator
+from ray.test_utils import Semaphore
 
 
 def test_metrics(ray_start_regular_shared):
@@ -158,6 +159,49 @@ def test_for_each(ray_start_regular_shared):
     assert list(it.gather_sync()) == [0, 4, 2, 6]
 
 
+def test_for_each_concur(ray_start_regular_shared):
+    main_wait = Semaphore.remote(value=0)
+    test_wait = Semaphore.remote(value=0)
+
+    def task(x):
+        i, main_wait, test_wait = x
+        ray.get(main_wait.release.remote())
+        ray.get(test_wait.acquire.remote())
+        return i + 10
+
+    @ray.remote(num_cpus=0.1)
+    def to_list(it):
+        return list(it)
+
+    it = from_items(
+        [(i, main_wait, test_wait) for i in range(8)], num_shards=2)
+    it = it.for_each(task, max_concurrency=2, resources={"num_cpus": 0.1})
+
+    # Iterators are lazy, so start consuming in the background; otherwise no
+    # task would run and the `acquire` calls below would block forever.
+    result = to_list.remote(it.gather_sync())
+
+    for _ in range(4):
+        ray.get(main_wait.acquire.remote())
+
+    # There should be exactly 4 tasks executing at this point.
+    assert ray.get(main_wait.locked.remote()) is True, "Too much parallelism"
+
+    # When we finish one task, exactly one more should start.
+    ray.get(test_wait.release.remote())
+    ray.get(main_wait.acquire.remote())
+    assert ray.get(main_wait.locked.remote()) is True, "Too much parallelism"
+
+    # Finish everything and make sure the output matches a regular iterator.
+    # Each of the 8 tasks acquires `test_wait` once; 1 was released above.
+    for _ in range(7):
+        ray.get(test_wait.release.remote())
+
+    assert repr(
+        it) == "ParallelIterator[from_items[tuple, 8, shards=2].for_each()]"
+    assert ray.get(result) == list(range(10, 18))
+
+
 def test_combine(ray_start_regular_shared):
     it = from_range(4, 1).combine(lambda x: [x, x])
     assert repr(it) == "ParallelIterator[from_range[4, shards=1].combine()]"
diff --git a/python/ray/util/iter.py b/python/ray/util/iter.py
index bcd19736b..8d97ff31b 100644
--- a/python/ray/util/iter.py
+++ b/python/ray/util/iter.py
@@ -185,18 +185,50 @@ class ParallelIterator(Generic[T]):
             name=self.name + name,
             parent_iterators=self.parent_iterators)
 
-    def for_each(self, fn: Callable[[T], U]) -> "ParallelIterator[U]":
-        """Remotely apply fn to each item in this iterator.
+    def for_each(self, fn: Callable[[T], U], max_concurrency=1,
+                 resources=None) -> "ParallelIterator[U]":
+        """Remotely apply fn to each item in this iterator, at most
+        `max_concurrency` at a time.
+
+        If `max_concurrency` == 1 then `fn` will be executed serially by the
+        shards.
+
+        `max_concurrency` should be used to achieve a high degree of
+        parallelism without the overhead of increasing the number of shards
+        (which are actor based). This provides the semantic guarantee that
+        `fn(x_i)` will _begin_ executing before `fn(x_{i+1})` (but not
+        necessarily finish first).
+
+        A performance note: When executing concurrently, this function
+        maintains its own internal buffer. If `async_queue_depth` is `n` and
+        `max_concurrency` is `k` then the total number of buffered objects
+        could be up to `n + k - 1`.
 
         Args:
             fn (func): function to apply to each item.
+            max_concurrency (int): max number of concurrent calls to fn per
+                shard. If 0, then apply all operations concurrently.
+            resources (dict): resources that the function requires to execute.
+                This has the same default as `ray.remote` and is only used
+                when `max_concurrency != 1`.
+
+        Returns:
+            ParallelIterator[U]: a parallel iterator whose elements have `fn`
+            applied.
+
 
         Examples:
-            >>> next(from_range(4).for_each(lambda x: x * 2).gather_sync())
+            >>> next(from_range(4).for_each(
+                lambda x: x * 2,
+                max_concurrency=2,
+                resources={"num_cpus": 0.1}).gather_sync()
+            )
             ... [0, 2, 4, 8]
+
         """
-        return self._with_transform(lambda local_it: local_it.for_each(fn),
-                                    ".for_each()")
+        return self._with_transform(
+            lambda local_it: local_it.for_each(fn, max_concurrency, resources),
+            ".for_each()")
 
     def filter(self, fn: Callable[[T], bool]) -> "ParallelIterator[T]":
         """Remotely filter items from this iterator.
@@ -639,20 +671,68 @@ class LocalIterator(Generic[T]):
     def __repr__(self):
         return "LocalIterator[{}]".format(self.name)
 
-    def for_each(self, fn: Callable[[T], U]) -> "LocalIterator[U]":
-        def apply_foreach(it):
-            for item in it:
-                if isinstance(item, _NextValueNotReady):
-                    yield item
-                else:
-                    # Keep retrying the function until it returns a valid
-                    # value. This allows for non-blocking functions.
-                    while True:
-                        with self._metrics_context():
-                            result = fn(item)
-                        yield result
-                        if not isinstance(result, _NextValueNotReady):
-                            break
+    def for_each(self, fn: Callable[[T], U], max_concurrency=1,
+                 resources=None) -> "LocalIterator[U]":
+        """Apply fn to each item of this iterator.
+
+        Args:
+            fn (func): function to apply to each item.
+            max_concurrency (int): if 1, call fn inline, one item at a
+                time. Otherwise run fn as Ray tasks with at most this many
+                unfinished at once (no limit if 0); results are still
+                yielded in input order.
+            resources (dict): options passed to the Ray tasks running fn.
+                Only used when `max_concurrency != 1`.
+        """
+        if max_concurrency == 1:
+
+            def apply_foreach(it):
+                for item in it:
+                    if isinstance(item, _NextValueNotReady):
+                        yield item
+                    else:
+                        # Keep retrying the function until it returns a valid
+                        # value. This allows for non-blocking functions.
+                        while True:
+                            with self._metrics_context():
+                                result = fn(item)
+                            yield result
+                            if not isinstance(result, _NextValueNotReady):
+                                break
+        else:
+            if resources is None:
+                resources = {}
+
+            def apply_foreach(it):
+                cur = []
+                remote = ray.remote(fn).options(**resources)
+                remote_fn = remote.remote
+                for item in it:
+                    if isinstance(item, _NextValueNotReady):
+                        yield item
+                    else:
+                        if max_concurrency and cur:
+                            # Request every ref so that `finished` counts all
+                            # completed tasks; the default `num_returns=1`
+                            # reports at most one and breaks the limit.
+                            finished, remaining = ray.wait(
+                                cur, num_returns=len(cur), timeout=0)
+                            if len(remaining) >= max_concurrency:
+                                ray.wait(cur, num_returns=(len(finished) + 1))
+                        cur.append(remote_fn(item))
+
+                        # Yield in submission order: only pop from the head.
+                        while len(cur) > 0:
+                            to_yield = cur[0]
+                            finished, remaining = ray.wait(
+                                [to_yield], timeout=0)
+                            if finished:
+                                cur.pop(0)
+                                yield ray.get(to_yield)
+                            else:
+                                break
+
+                yield from ray.get(cur)
 
         if hasattr(fn, LocalIterator.ON_FETCH_START_HOOK_NAME):
             unwrapped = apply_foreach