mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 16:00:00 +08:00
[rllib] Port Ape-X to distributed execution API (#7497)
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import time
|
||||
import collections
|
||||
from collections import Counter
|
||||
import pytest
|
||||
|
||||
@@ -32,6 +33,16 @@ def test_metrics(ray_start_regular_shared):
|
||||
LocalIterator.get_metrics()
|
||||
|
||||
|
||||
def test_zip_with_source_actor(ray_start_regular_shared):
|
||||
it = from_items([1, 2, 3, 4], num_shards=2)
|
||||
counts = collections.defaultdict(int)
|
||||
for actor, value in it.gather_async().zip_with_source_actor():
|
||||
counts[actor] += 1
|
||||
assert len(counts) == 2
|
||||
for a, count in counts.items():
|
||||
assert count == 2
|
||||
|
||||
|
||||
def test_metrics_union(ray_start_regular_shared):
|
||||
it1 = from_items([1, 2, 3, 4], num_shards=1)
|
||||
it2 = from_items([1, 2, 3, 4], num_shards=1)
|
||||
@@ -49,7 +60,8 @@ def test_metrics_union(ray_start_regular_shared):
|
||||
def verify_metrics(x):
|
||||
metrics = LocalIterator.get_metrics()
|
||||
metrics.counters["n"] += 1
|
||||
if metrics.counters["n"] > 2:
|
||||
# Check the metrics context is shared.
|
||||
if metrics.counters["n"] >= 2:
|
||||
assert "foo" in metrics.counters
|
||||
assert "bar" in metrics.counters
|
||||
return x
|
||||
@@ -238,6 +250,12 @@ def test_gather_async(ray_start_regular_shared):
|
||||
assert sorted(it) == [0, 1, 2, 3]
|
||||
|
||||
|
||||
def test_gather_async_queue(ray_start_regular_shared):
|
||||
it = from_range(100)
|
||||
it = it.gather_async(async_queue_depth=4)
|
||||
assert sorted(it) == list(range(100))
|
||||
|
||||
|
||||
def test_batch_across_shards(ray_start_regular_shared):
|
||||
it = from_iterators([[0, 1], [2, 3]])
|
||||
it = it.batch_across_shards()
|
||||
|
||||
+32
-7
@@ -414,12 +414,17 @@ class ParallelIterator(Generic[T]):
|
||||
name = "{}.batch_across_shards()".format(self)
|
||||
return LocalIterator(base_iterator, MetricsContext(), name=name)
|
||||
|
||||
def gather_async(self) -> "LocalIterator[T]":
|
||||
def gather_async(self, async_queue_depth=1) -> "LocalIterator[T]":
|
||||
"""Returns a local iterable for asynchronous iteration.
|
||||
|
||||
New items will be fetched from the shards asynchronously as soon as
|
||||
the previous one is computed. Items arrive in non-deterministic order.
|
||||
|
||||
Arguments:
|
||||
async_queue_depth (int): The max number of async requests in flight
|
||||
per actor. Increasing this improves the amount of pipeline
|
||||
parallelism in the iterator.
|
||||
|
||||
Examples:
|
||||
>>> it = from_range(100, 1).gather_async()
|
||||
>>> next(it)
|
||||
@@ -430,16 +435,19 @@ class ParallelIterator(Generic[T]):
|
||||
... 1
|
||||
"""
|
||||
|
||||
metrics = MetricsContext()
|
||||
if async_queue_depth < 1:
|
||||
raise ValueError("queue depth must be positive")
|
||||
|
||||
def base_iterator(timeout=None):
|
||||
metrics = LocalIterator.get_metrics()
|
||||
all_actors = []
|
||||
for actor_set in self.actor_sets:
|
||||
actor_set.init_actors()
|
||||
all_actors.extend(actor_set.actors)
|
||||
futures = {}
|
||||
for a in all_actors:
|
||||
futures[a.par_iter_next.remote()] = a
|
||||
for _ in range(async_queue_depth):
|
||||
for a in all_actors:
|
||||
futures[a.par_iter_next.remote()] = a
|
||||
while futures:
|
||||
pending = list(futures)
|
||||
if timeout is None:
|
||||
@@ -455,7 +463,7 @@ class ParallelIterator(Generic[T]):
|
||||
for obj_id in ready:
|
||||
actor = futures.pop(obj_id)
|
||||
try:
|
||||
metrics.cur_actor = actor
|
||||
metrics.current_actor = actor
|
||||
yield ray.get(obj_id)
|
||||
futures[actor.par_iter_next.remote()] = actor
|
||||
except StopIteration:
|
||||
@@ -465,7 +473,7 @@ class ParallelIterator(Generic[T]):
|
||||
yield _NextValueNotReady()
|
||||
|
||||
name = "{}.gather_async()".format(self)
|
||||
return LocalIterator(base_iterator, metrics, name=name)
|
||||
return LocalIterator(base_iterator, MetricsContext(), name=name)
|
||||
|
||||
def take(self, n: int) -> List[T]:
|
||||
"""Return up to the first n items from this iterator."""
|
||||
@@ -638,7 +646,13 @@ class LocalIterator(Generic[T]):
|
||||
if isinstance(item, _NextValueNotReady):
|
||||
yield item
|
||||
else:
|
||||
yield fn(item)
|
||||
# Keep retrying the function until it returns a valid
|
||||
# value. This allows for non-blocking functions.
|
||||
while True:
|
||||
result = fn(item)
|
||||
yield result
|
||||
if not isinstance(result, _NextValueNotReady):
|
||||
break
|
||||
|
||||
if hasattr(fn, LocalIterator.ON_FETCH_START_HOOK_NAME):
|
||||
unwrapped = apply_foreach
|
||||
@@ -758,6 +772,17 @@ class LocalIterator(Generic[T]):
|
||||
it.name = self.name + ".combine()"
|
||||
return it
|
||||
|
||||
def zip_with_source_actor(self):
|
||||
def zip_with_source(item):
|
||||
metrics = LocalIterator.get_metrics()
|
||||
if metrics.current_actor is None:
|
||||
raise ValueError("Could not identify source actor of item")
|
||||
return metrics.current_actor, item
|
||||
|
||||
it = self.for_each(zip_with_source)
|
||||
it.name = self.name + ".zip_with_source_actor()"
|
||||
return it
|
||||
|
||||
def take(self, n: int) -> List[T]:
|
||||
"""Return up to the first n items from this iterator."""
|
||||
out = []
|
||||
|
||||
Reference in New Issue
Block a user