From 90e23a5c43724b5d8a1ce0fc075f3939643c4ec4 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 9 Mar 2020 17:18:52 -0700 Subject: [PATCH] [iterators] Add duplicate() call and fix broken test case (#7510) --- python/ray/tests/test_iter.py | 19 ++++++++----- python/ray/util/iter.py | 50 ++++++++++++++++++++++++++++++++--- 2 files changed, 60 insertions(+), 9 deletions(-) diff --git a/python/ray/tests/test_iter.py b/python/ray/tests/test_iter.py index fdac5167b..86f17d7e9 100644 --- a/python/ray/tests/test_iter.py +++ b/python/ray/tests/test_iter.py @@ -49,13 +49,9 @@ def test_metrics_union(ray_start_regular_shared): def verify_metrics(x): metrics = LocalIterator.get_metrics() metrics.counters["n"] += 1 - # Check the unioned iterator gets a new metric context. - assert "foo" not in metrics.counters - assert "bar" not in metrics.counters - # Check parent metrics are accessible. if metrics.counters["n"] > 2: - assert "foo" in metrics.parent_metrics[0].counters - assert "bar" in metrics.parent_metrics[1].counters + assert "foo" in metrics.counters + assert "bar" in metrics.counters return x it1 = it1.gather_async().for_each(foo_metrics) @@ -116,6 +112,17 @@ def test_combine(ray_start_regular_shared): assert list(it.gather_sync()) == [0, 0, 1, 1, 2, 2, 3, 3] +def test_duplicate(ray_start_regular_shared): + it = from_range(5, num_shards=1) + + it1, it2 = it.gather_sync().duplicate(2) + it1 = it1.batch(2) + + it3 = it1.union(it2, deterministic=False) + results = it3.take(20) + assert results == [0, [0, 1], 1, 2, [2, 3], 3, 4, [4]] + + def test_chain(ray_start_regular_shared): it = from_range(4).for_each(lambda x: x * 2).for_each(lambda x: x * 2) assert repr( diff --git a/python/ray/util/iter.py b/python/ray/util/iter.py index 6595c4326..514dfbc04 100644 --- a/python/ray/util/iter.py +++ b/python/ray/util/iter.py @@ -776,6 +776,46 @@ class LocalIterator(Generic[T]): if i >= n: break + def duplicate(self, n) -> List["LocalIterator[T]"]: + """Copy this iterator `n` times, duplicating the data. + + Returns: + List[LocalIterator[T]]: multiple iterators that each have a copy + of the data of this iterator. + """ + + if n < 2: + raise ValueError("Number of copies must be >= 2") + + queues = [] + for _ in range(n): + queues.append(collections.deque()) + + def fill_next(timeout): + self.timeout = timeout + item = next(self) + for q in queues: + q.append(item) + + def make_next(i): + def gen(timeout): + while True: + if len(queues[i]) == 0: + fill_next(timeout) + yield queues[i].popleft() + + return gen + + iterators = [] + for i in range(n): + iterators.append( + LocalIterator( + make_next(i), + self.metrics, [], + name=self.name + ".duplicate[{}]".format(i))) + + return iterators + def union(self, *others: "LocalIterator[T]", deterministic: bool = False) -> "LocalIterator[T]": """Return an iterator that is the union of this and the others. @@ -811,15 +851,19 @@ class LocalIterator(Generic[T]): for it in list(active): # Yield items from the iterator until _NextValueNotReady is # found, then switch to the next iterator. + # To avoid starvation, we yield at most max_yield items per + # iterator before switching. + if deterministic: + max_yield = 1 # Forces round robin. + else: + max_yield = 20 try: - while True: + for _ in range(max_yield): item = next(it) if isinstance(item, _NextValueNotReady): break else: yield item - if deterministic: - break except StopIteration: active.remove(it) if not active: