[iterators] Add duplicate() call and fix broken test case (#7510)

This commit is contained in:
Eric Liang
2020-03-09 17:18:52 -07:00
committed by GitHub
parent 883ee4912d
commit 90e23a5c43
2 changed files with 60 additions and 9 deletions
+13 -6
View File
@@ -49,13 +49,9 @@ def test_metrics_union(ray_start_regular_shared):
def verify_metrics(x):
metrics = LocalIterator.get_metrics()
metrics.counters["n"] += 1
# Check the unioned iterator gets a new metric context.
assert "foo" not in metrics.counters
assert "bar" not in metrics.counters
# Check parent metrics are accessible.
if metrics.counters["n"] > 2:
assert "foo" in metrics.parent_metrics[0].counters
assert "bar" in metrics.parent_metrics[1].counters
assert "foo" in metrics.counters
assert "bar" in metrics.counters
return x
it1 = it1.gather_async().for_each(foo_metrics)
@@ -116,6 +112,17 @@ def test_combine(ray_start_regular_shared):
assert list(it.gather_sync()) == [0, 0, 1, 1, 2, 2, 3, 3]
def test_duplicate(ray_start_regular_shared):
    """Check duplicate() on a gathered iterator.

    One copy is batched in pairs and then unioned with the unbatched copy;
    the merged stream must contain both the raw items and the batches.
    """
    source = from_range(5, num_shards=1).gather_sync()
    batched, plain = source.duplicate(2)
    batched = batched.batch(2)
    merged = batched.union(plain, deterministic=False)
    expected = [0, [0, 1], 1, 2, [2, 3], 3, 4, [4]]
    # Ask for more items than exist so the whole stream is drained.
    assert merged.take(20) == expected
def test_chain(ray_start_regular_shared):
it = from_range(4).for_each(lambda x: x * 2).for_each(lambda x: x * 2)
assert repr(
+47 -3
View File
@@ -776,6 +776,46 @@ class LocalIterator(Generic[T]):
if i >= n:
break
def duplicate(self, n) -> List["LocalIterator[T]"]:
    """Copy this iterator `n` times, duplicating the data.

    Every item pulled from this iterator is appended to one queue per
    copy, so each returned iterator sees every item in the same order.
    Items stay buffered in a copy's queue until that copy consumes them.

    Args:
        n (int): number of copies to create; must be at least 2.

    Returns:
        List[LocalIterator[T]]: multiple iterators that each have a copy
            of the data of this iterator.

    Raises:
        ValueError: if `n` is less than 2.
    """
    if n < 2:
        raise ValueError("Number of copies must be >= 2")

    # One FIFO buffer per copy.
    queues = [collections.deque() for _ in range(n)]

    def fill_next(timeout):
        # Pull a single item from the source and fan it out to all copies.
        self.timeout = timeout
        item = next(self)
        for q in queues:
            q.append(item)

    def make_next(i):
        # Factory so that each generator is bound to its own queue index.
        def gen(timeout):
            while True:
                if not queues[i]:
                    try:
                        fill_next(timeout)
                    except StopIteration:
                        # The source is exhausted, so end this copy
                        # cleanly. A StopIteration escaping a generator
                        # body is converted to RuntimeError (PEP 479).
                        return
                yield queues[i].popleft()

        return gen

    return [
        LocalIterator(
            make_next(i),
            self.metrics, [],
            name=self.name + ".duplicate[{}]".format(i))
        for i in range(n)
    ]
def union(self, *others: "LocalIterator[T]",
deterministic: bool = False) -> "LocalIterator[T]":
"""Return an iterator that is the union of this and the others.
@@ -811,15 +851,19 @@ class LocalIterator(Generic[T]):
for it in list(active):
# Yield items from the iterator until _NextValueNotReady is
# found, then switch to the next iterator.
# To avoid starvation, we yield at most max_yield items per
# iterator before switching.
if deterministic:
max_yield = 1 # Forces round robin.
else:
max_yield = 20
try:
while True:
for _ in range(max_yield):
item = next(it)
if isinstance(item, _NextValueNotReady):
break
else:
yield item
if deterministic:
break
except StopIteration:
active.remove(it)
if not active: