[rllib] Port Ape-X to distributed execution API (#7497)

This commit is contained in:
Eric Liang
2020-03-12 00:54:08 -07:00
committed by GitHub
parent 4c834b9d68
commit f5d12a958b
17 changed files with 401 additions and 116 deletions
+19 -1
View File
@@ -1,4 +1,5 @@
import time
import collections
from collections import Counter
import pytest
@@ -32,6 +33,16 @@ def test_metrics(ray_start_regular_shared):
LocalIterator.get_metrics()
def test_zip_with_source_actor(ray_start_regular_shared):
    """Check that zip_with_source_actor() tags each item with its actor.

    Four items are split over two shards, so two distinct actors are
    expected, each credited with exactly two items.
    """
    zipped = (
        from_items([1, 2, 3, 4], num_shards=2)
        .gather_async()
        .zip_with_source_actor()
    )

    # Tally how many items were attributed to each source actor.
    per_actor = collections.defaultdict(int)
    for actor, _ in zipped:
        per_actor[actor] += 1

    assert len(per_actor) == 2
    assert all(seen == 2 for seen in per_actor.values())
def test_metrics_union(ray_start_regular_shared):
it1 = from_items([1, 2, 3, 4], num_shards=1)
it2 = from_items([1, 2, 3, 4], num_shards=1)
@@ -49,7 +60,8 @@ def test_metrics_union(ray_start_regular_shared):
def verify_metrics(x):
metrics = LocalIterator.get_metrics()
metrics.counters["n"] += 1
if metrics.counters["n"] > 2:
# Check the metrics context is shared.
if metrics.counters["n"] >= 2:
assert "foo" in metrics.counters
assert "bar" in metrics.counters
return x
@@ -238,6 +250,12 @@ def test_gather_async(ray_start_regular_shared):
assert sorted(it) == [0, 1, 2, 3]
def test_gather_async_queue(ray_start_regular_shared):
    """Check that gather_async(async_queue_depth=4) still yields every item.

    Arrival order is not asserted; the results are sorted before comparing.
    """
    expected = list(range(100))
    gathered = from_range(100).gather_async(async_queue_depth=4)
    assert sorted(gathered) == expected
def test_batch_across_shards(ray_start_regular_shared):
it = from_iterators([[0, 1], [2, 3]])
it = it.batch_across_shards()
+32 -7
View File
@@ -414,12 +414,17 @@ class ParallelIterator(Generic[T]):
name = "{}.batch_across_shards()".format(self)
return LocalIterator(base_iterator, MetricsContext(), name=name)
def gather_async(self) -> "LocalIterator[T]":
def gather_async(self, async_queue_depth=1) -> "LocalIterator[T]":
"""Returns a local iterable for asynchronous iteration.
New items will be fetched from the shards asynchronously as soon as
the previous one is computed. Items arrive in non-deterministic order.
Arguments:
async_queue_depth (int): The max number of async requests in flight
per actor. Increasing this improves the amount of pipeline
parallelism in the iterator.
Examples:
>>> it = from_range(100, 1).gather_async()
>>> next(it)
@@ -430,16 +435,19 @@ class ParallelIterator(Generic[T]):
... 1
"""
metrics = MetricsContext()
if async_queue_depth < 1:
raise ValueError("queue depth must be positive")
def base_iterator(timeout=None):
metrics = LocalIterator.get_metrics()
all_actors = []
for actor_set in self.actor_sets:
actor_set.init_actors()
all_actors.extend(actor_set.actors)
futures = {}
for a in all_actors:
futures[a.par_iter_next.remote()] = a
for _ in range(async_queue_depth):
for a in all_actors:
futures[a.par_iter_next.remote()] = a
while futures:
pending = list(futures)
if timeout is None:
@@ -455,7 +463,7 @@ class ParallelIterator(Generic[T]):
for obj_id in ready:
actor = futures.pop(obj_id)
try:
metrics.cur_actor = actor
metrics.current_actor = actor
yield ray.get(obj_id)
futures[actor.par_iter_next.remote()] = actor
except StopIteration:
@@ -465,7 +473,7 @@ class ParallelIterator(Generic[T]):
yield _NextValueNotReady()
name = "{}.gather_async()".format(self)
return LocalIterator(base_iterator, metrics, name=name)
return LocalIterator(base_iterator, MetricsContext(), name=name)
def take(self, n: int) -> List[T]:
"""Return up to the first n items from this iterator."""
@@ -638,7 +646,13 @@ class LocalIterator(Generic[T]):
if isinstance(item, _NextValueNotReady):
yield item
else:
yield fn(item)
# Keep retrying the function until it returns a valid
# value. This allows for non-blocking functions.
while True:
result = fn(item)
yield result
if not isinstance(result, _NextValueNotReady):
break
if hasattr(fn, LocalIterator.ON_FETCH_START_HOOK_NAME):
unwrapped = apply_foreach
@@ -758,6 +772,17 @@ class LocalIterator(Generic[T]):
it.name = self.name + ".combine()"
return it
def zip_with_source_actor(self):
    """Pair every item of this iterator with the actor recorded as its source.

    The source is read from ``current_actor`` on the metrics context
    returned by ``LocalIterator.get_metrics()`` at the moment each item is
    processed.

    Returns:
        The iterator produced by ``self.for_each(...)``, renamed with a
        ``.zip_with_source_actor()`` suffix, whose items are
        ``(actor, item)`` tuples.

    Raises:
        ValueError: Raised by the mapped function (i.e. while iterating,
            not when this method is called) if ``current_actor`` is None.
    """

    def attach_actor(item):
        source = LocalIterator.get_metrics().current_actor
        if source is None:
            raise ValueError("Could not identify source actor of item")
        return source, item

    zipped = self.for_each(attach_actor)
    zipped.name = self.name + ".zip_with_source_actor()"
    return zipped
def take(self, n: int) -> List[T]:
"""Return up to the first n items from this iterator."""
out = []