mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 15:40:09 +08:00
@@ -292,7 +292,7 @@ def test_gather_async(ray_start_regular_shared):
|
||||
|
||||
def test_gather_async_queue(ray_start_regular_shared):
|
||||
it = from_range(100)
|
||||
it = it.gather_async(async_queue_depth=4)
|
||||
it = it.gather_async(num_async=4)
|
||||
assert sorted(it) == list(range(100))
|
||||
|
||||
|
||||
|
||||
@@ -415,14 +415,14 @@ class ParallelIterator(Generic[T]):
|
||||
name = "{}.batch_across_shards()".format(self)
|
||||
return LocalIterator(base_iterator, SharedMetrics(), name=name)
|
||||
|
||||
def gather_async(self, async_queue_depth=1) -> "LocalIterator[T]":
|
||||
def gather_async(self, num_async=1) -> "LocalIterator[T]":
|
||||
"""Returns a local iterable for asynchronous iteration.
|
||||
|
||||
New items will be fetched from the shards asynchronously as soon as
|
||||
the previous one is computed. Items arrive in non-deterministic order.
|
||||
|
||||
Arguments:
|
||||
async_queue_depth (int): The max number of async requests in flight
|
||||
num_async (int): The max number of async requests in flight
|
||||
per actor. Increasing this improves the amount of pipeline
|
||||
parallelism in the iterator.
|
||||
|
||||
@@ -436,7 +436,7 @@ class ParallelIterator(Generic[T]):
|
||||
... 1
|
||||
"""
|
||||
|
||||
if async_queue_depth < 1:
|
||||
if num_async < 1:
|
||||
raise ValueError("queue depth must be positive")
|
||||
|
||||
# Forward reference to the returned iterator.
|
||||
@@ -448,7 +448,7 @@ class ParallelIterator(Generic[T]):
|
||||
actor_set.init_actors()
|
||||
all_actors.extend(actor_set.actors)
|
||||
futures = {}
|
||||
for _ in range(async_queue_depth):
|
||||
for _ in range(num_async):
|
||||
for a in all_actors:
|
||||
futures[a.par_iter_next.remote()] = a
|
||||
while futures:
|
||||
|
||||
Reference in New Issue
Block a user