[Parallel Iterator] Foreach concur (#8140)

This commit is contained in:
Alex Wu
2020-05-06 08:00:01 -07:00
committed by GitHub
parent ec9357b486
commit 04813c2ef5
3 changed files with 137 additions and 19 deletions
+15
View File
@@ -223,6 +223,21 @@ class SignalActor:
await self.ready_event.wait()
@ray.remote(num_cpus=0)
class Semaphore:
def __init__(self, value=1):
self._sema = asyncio.Semaphore(value=value)
async def acquire(self):
self._sema.acquire()
async def release(self):
self._sema.release()
async def locked(self):
return self._sema.locked()
@ray.remote
def _put(obj):
return obj
+39
View File
@@ -6,6 +6,7 @@ import pytest
import ray
from ray.util.iter import from_items, from_iterators, from_range, \
from_actors, ParallelIteratorWorker, LocalIterator
from ray.test_utils import Semaphore
def test_metrics(ray_start_regular_shared):
@@ -158,6 +159,44 @@ def test_for_each(ray_start_regular_shared):
assert list(it.gather_sync()) == [0, 4, 2, 6]
def test_for_each_concur(ray_start_regular_shared):
main_wait = Semaphore.remote(value=0)
test_wait = Semaphore.remote(value=0)
def task(x):
i, main_wait, test_wait = x
ray.get(main_wait.release.remote())
ray.get(test_wait.acquire.remote())
return i + 10
@ray.remote(num_cpus=0.1)
def to_list(it):
return list(it)
it = from_items(
[(i, main_wait, test_wait) for i in range(8)], num_shards=2)
it = it.for_each(task, max_concurrency=2, resources={"num_cpus": 0.1})
for i in range(4):
ray.get(main_wait.acquire.remote())
# There should be exactly 4 tasks executing at this point.
assert ray.get(main_wait.locked.remote()) is True, "Too much parallelism"
# When we finish one task, exactly one more should start.
ray.get(test_wait.release.remote())
ray.get(main_wait.acquire.remote())
assert ray.get(main_wait.locked.remote()) is True, "Too much parallelism"
# Finish everything and make sure the output matches a regular iterator.
for i in range(3):
ray.get(test_wait.release.remote())
assert repr(
it) == "ParallelIterator[from_items[tuple, 8, shards=2].for_each()]"
assert ray.get(to_list.remote(it.gather_sync())) == list(range(10, 18))
def test_combine(ray_start_regular_shared):
it = from_range(4, 1).combine(lambda x: [x, x])
assert repr(it) == "ParallelIterator[from_range[4, shards=1].combine()]"
+83 -19
View File
@@ -185,18 +185,50 @@ class ParallelIterator(Generic[T]):
name=self.name + name,
parent_iterators=self.parent_iterators)
def for_each(self, fn: Callable[[T], U]) -> "ParallelIterator[U]":
"""Remotely apply fn to each item in this iterator.
def for_each(self, fn: Callable[[T], U], max_concurrency=1,
resources=None) -> "ParallelIterator[U]":
"""Remotely apply fn to each item in this iterator, at most `max_concurrency`
at a time.
If `max_concurrency` == 1 then `fn` will be executed serially by the
shards
`max_concurrency` should be used to achieve a high degree of
parallelism without the overhead of increasing the number of shards
(which are actor based). This provides the semantic guarantee that
`fn(x_i)` will _begin_ executing before `fn(x_{i+1})` (but not
necessarily finish first)
A performance note: When executing concurrently, this function
maintains its own internal buffer. If `async_queue_depth` is `n` and
max_concur is `k` then the total number of buffered objects could be up
to `n + k - 1`
Args:
fn (func): function to apply to each item.
max_concurrency (int): max number of concurrent calls to fn per
shard. If 0, then apply all operations concurrently.
resources (dict): resources that the function requires to execute.
This has the same default as `ray.remote` and is only used
when `max_concurrency > 1`.
Returns:
ParallelIterator[U]: a parallel iterator whose elements have `fn`
applied.
Examples:
>>> next(from_range(4).for_each(lambda x: x * 2).gather_sync())
>>> next(from_range(4).for_each(
lambda x: x * 2,
max_concur=2,
resources={"num_cpus": 0.1}).gather_sync()
)
... [0, 2, 4, 8]
"""
return self._with_transform(lambda local_it: local_it.for_each(fn),
".for_each()")
return self._with_transform(
lambda local_it: local_it.for_each(fn, max_concurrency, resources),
".for_each()")
def filter(self, fn: Callable[[T], bool]) -> "ParallelIterator[T]":
"""Remotely filter items from this iterator.
@@ -639,20 +671,52 @@ class LocalIterator(Generic[T]):
def __repr__(self):
return "LocalIterator[{}]".format(self.name)
def for_each(self, fn: Callable[[T], U]) -> "LocalIterator[U]":
def apply_foreach(it):
for item in it:
if isinstance(item, _NextValueNotReady):
yield item
else:
# Keep retrying the function until it returns a valid
# value. This allows for non-blocking functions.
while True:
with self._metrics_context():
result = fn(item)
yield result
if not isinstance(result, _NextValueNotReady):
break
def for_each(self, fn: Callable[[T], U], max_concurrency=1,
resources=None) -> "LocalIterator[U]":
if max_concurrency == 1:
def apply_foreach(it):
for item in it:
if isinstance(item, _NextValueNotReady):
yield item
else:
# Keep retrying the function until it returns a valid
# value. This allows for non-blocking functions.
while True:
with self._metrics_context():
result = fn(item)
yield result
if not isinstance(result, _NextValueNotReady):
break
else:
if resources is None:
resources = {}
def apply_foreach(it):
cur = []
remote = ray.remote(fn).options(**resources)
remote_fn = remote.remote
for item in it:
if isinstance(item, _NextValueNotReady):
yield item
else:
finished, remaining = ray.wait(cur, timeout=0)
if max_concurrency and len(
remaining) >= max_concurrency:
ray.wait(cur, num_returns=(len(finished) + 1))
cur.append(remote_fn(item))
while len(cur) > 0:
to_yield = cur[0]
finished, remaining = ray.wait(
[to_yield], timeout=0)
if finished:
cur.pop(0)
yield ray.get(to_yield)
else:
break
yield from ray.get(cur)
if hasattr(fn, LocalIterator.ON_FETCH_START_HOOK_NAME):
unwrapped = apply_foreach