mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 19:17:01 +08:00
[Parallel Iterator] Foreach concur (#8140)
This commit is contained in:
@@ -223,6 +223,21 @@ class SignalActor:
|
||||
await self.ready_event.wait()
|
||||
|
||||
|
||||
@ray.remote(num_cpus=0)
|
||||
class Semaphore:
|
||||
def __init__(self, value=1):
|
||||
self._sema = asyncio.Semaphore(value=value)
|
||||
|
||||
async def acquire(self):
|
||||
self._sema.acquire()
|
||||
|
||||
async def release(self):
|
||||
self._sema.release()
|
||||
|
||||
async def locked(self):
|
||||
return self._sema.locked()
|
||||
|
||||
|
||||
@ray.remote
|
||||
def _put(obj):
|
||||
return obj
|
||||
|
||||
@@ -6,6 +6,7 @@ import pytest
|
||||
import ray
|
||||
from ray.util.iter import from_items, from_iterators, from_range, \
|
||||
from_actors, ParallelIteratorWorker, LocalIterator
|
||||
from ray.test_utils import Semaphore
|
||||
|
||||
|
||||
def test_metrics(ray_start_regular_shared):
|
||||
@@ -158,6 +159,44 @@ def test_for_each(ray_start_regular_shared):
|
||||
assert list(it.gather_sync()) == [0, 4, 2, 6]
|
||||
|
||||
|
||||
def test_for_each_concur(ray_start_regular_shared):
|
||||
main_wait = Semaphore.remote(value=0)
|
||||
test_wait = Semaphore.remote(value=0)
|
||||
|
||||
def task(x):
|
||||
i, main_wait, test_wait = x
|
||||
ray.get(main_wait.release.remote())
|
||||
ray.get(test_wait.acquire.remote())
|
||||
return i + 10
|
||||
|
||||
@ray.remote(num_cpus=0.1)
|
||||
def to_list(it):
|
||||
return list(it)
|
||||
|
||||
it = from_items(
|
||||
[(i, main_wait, test_wait) for i in range(8)], num_shards=2)
|
||||
it = it.for_each(task, max_concurrency=2, resources={"num_cpus": 0.1})
|
||||
|
||||
for i in range(4):
|
||||
ray.get(main_wait.acquire.remote())
|
||||
|
||||
# There should be exactly 4 tasks executing at this point.
|
||||
assert ray.get(main_wait.locked.remote()) is True, "Too much parallelism"
|
||||
|
||||
# When we finish one task, exactly one more should start.
|
||||
ray.get(test_wait.release.remote())
|
||||
ray.get(main_wait.acquire.remote())
|
||||
assert ray.get(main_wait.locked.remote()) is True, "Too much parallelism"
|
||||
|
||||
# Finish everything and make sure the output matches a regular iterator.
|
||||
for i in range(3):
|
||||
ray.get(test_wait.release.remote())
|
||||
|
||||
assert repr(
|
||||
it) == "ParallelIterator[from_items[tuple, 8, shards=2].for_each()]"
|
||||
assert ray.get(to_list.remote(it.gather_sync())) == list(range(10, 18))
|
||||
|
||||
|
||||
def test_combine(ray_start_regular_shared):
|
||||
it = from_range(4, 1).combine(lambda x: [x, x])
|
||||
assert repr(it) == "ParallelIterator[from_range[4, shards=1].combine()]"
|
||||
|
||||
+83
-19
@@ -185,18 +185,50 @@ class ParallelIterator(Generic[T]):
|
||||
name=self.name + name,
|
||||
parent_iterators=self.parent_iterators)
|
||||
|
||||
def for_each(self, fn: Callable[[T], U]) -> "ParallelIterator[U]":
|
||||
"""Remotely apply fn to each item in this iterator.
|
||||
def for_each(self, fn: Callable[[T], U], max_concurrency=1,
|
||||
resources=None) -> "ParallelIterator[U]":
|
||||
"""Remotely apply fn to each item in this iterator, at most `max_concurrency`
|
||||
at a time.
|
||||
|
||||
If `max_concurrency` == 1 then `fn` will be executed serially by the
|
||||
shards
|
||||
|
||||
`max_concurrency` should be used to achieve a high degree of
|
||||
parallelism without the overhead of increasing the number of shards
|
||||
(which are actor based). This provides the semantic guarantee that
|
||||
`fn(x_i)` will _begin_ executing before `fn(x_{i+1})` (but not
|
||||
necessarily finish first)
|
||||
|
||||
A performance note: When executing concurrently, this function
|
||||
maintains its own internal buffer. If `async_queue_depth` is `n` and
|
||||
max_concur is `k` then the total number of buffered objects could be up
|
||||
to `n + k - 1`
|
||||
|
||||
Args:
|
||||
fn (func): function to apply to each item.
|
||||
max_concurrency (int): max number of concurrent calls to fn per
|
||||
shard. If 0, then apply all operations concurrently.
|
||||
resources (dict): resources that the function requires to execute.
|
||||
This has the same default as `ray.remote` and is only used
|
||||
when `max_concurrency > 1`.
|
||||
|
||||
Returns:
|
||||
ParallelIterator[U]: a parallel iterator whose elements have `fn`
|
||||
applied.
|
||||
|
||||
|
||||
Examples:
|
||||
>>> next(from_range(4).for_each(lambda x: x * 2).gather_sync())
|
||||
>>> next(from_range(4).for_each(
|
||||
lambda x: x * 2,
|
||||
max_concur=2,
|
||||
resources={"num_cpus": 0.1}).gather_sync()
|
||||
)
|
||||
... [0, 2, 4, 8]
|
||||
|
||||
"""
|
||||
return self._with_transform(lambda local_it: local_it.for_each(fn),
|
||||
".for_each()")
|
||||
return self._with_transform(
|
||||
lambda local_it: local_it.for_each(fn, max_concurrency, resources),
|
||||
".for_each()")
|
||||
|
||||
def filter(self, fn: Callable[[T], bool]) -> "ParallelIterator[T]":
|
||||
"""Remotely filter items from this iterator.
|
||||
@@ -639,20 +671,52 @@ class LocalIterator(Generic[T]):
|
||||
def __repr__(self):
|
||||
return "LocalIterator[{}]".format(self.name)
|
||||
|
||||
def for_each(self, fn: Callable[[T], U]) -> "LocalIterator[U]":
|
||||
def apply_foreach(it):
|
||||
for item in it:
|
||||
if isinstance(item, _NextValueNotReady):
|
||||
yield item
|
||||
else:
|
||||
# Keep retrying the function until it returns a valid
|
||||
# value. This allows for non-blocking functions.
|
||||
while True:
|
||||
with self._metrics_context():
|
||||
result = fn(item)
|
||||
yield result
|
||||
if not isinstance(result, _NextValueNotReady):
|
||||
break
|
||||
def for_each(self, fn: Callable[[T], U], max_concurrency=1,
|
||||
resources=None) -> "LocalIterator[U]":
|
||||
if max_concurrency == 1:
|
||||
|
||||
def apply_foreach(it):
|
||||
for item in it:
|
||||
if isinstance(item, _NextValueNotReady):
|
||||
yield item
|
||||
else:
|
||||
# Keep retrying the function until it returns a valid
|
||||
# value. This allows for non-blocking functions.
|
||||
while True:
|
||||
with self._metrics_context():
|
||||
result = fn(item)
|
||||
yield result
|
||||
if not isinstance(result, _NextValueNotReady):
|
||||
break
|
||||
else:
|
||||
if resources is None:
|
||||
resources = {}
|
||||
|
||||
def apply_foreach(it):
|
||||
cur = []
|
||||
remote = ray.remote(fn).options(**resources)
|
||||
remote_fn = remote.remote
|
||||
for item in it:
|
||||
if isinstance(item, _NextValueNotReady):
|
||||
yield item
|
||||
else:
|
||||
finished, remaining = ray.wait(cur, timeout=0)
|
||||
if max_concurrency and len(
|
||||
remaining) >= max_concurrency:
|
||||
ray.wait(cur, num_returns=(len(finished) + 1))
|
||||
cur.append(remote_fn(item))
|
||||
|
||||
while len(cur) > 0:
|
||||
to_yield = cur[0]
|
||||
finished, remaining = ray.wait(
|
||||
[to_yield], timeout=0)
|
||||
if finished:
|
||||
cur.pop(0)
|
||||
yield ray.get(to_yield)
|
||||
else:
|
||||
break
|
||||
|
||||
yield from ray.get(cur)
|
||||
|
||||
if hasattr(fn, LocalIterator.ON_FETCH_START_HOOK_NAME):
|
||||
unwrapped = apply_foreach
|
||||
|
||||
Reference in New Issue
Block a user