mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 15:00:17 +08:00
[iterators] Add duplicate() call and fix broken test case (#7510)
This commit is contained in:
@@ -49,13 +49,9 @@ def test_metrics_union(ray_start_regular_shared):
|
||||
def verify_metrics(x):
|
||||
metrics = LocalIterator.get_metrics()
|
||||
metrics.counters["n"] += 1
|
||||
# Check the unioned iterator gets a new metric context.
|
||||
assert "foo" not in metrics.counters
|
||||
assert "bar" not in metrics.counters
|
||||
# Check parent metrics are accessible.
|
||||
if metrics.counters["n"] > 2:
|
||||
assert "foo" in metrics.parent_metrics[0].counters
|
||||
assert "bar" in metrics.parent_metrics[1].counters
|
||||
assert "foo" in metrics.counters
|
||||
assert "bar" in metrics.counters
|
||||
return x
|
||||
|
||||
it1 = it1.gather_async().for_each(foo_metrics)
|
||||
@@ -116,6 +112,17 @@ def test_combine(ray_start_regular_shared):
|
||||
assert list(it.gather_sync()) == [0, 0, 1, 1, 2, 2, 3, 3]
|
||||
|
||||
|
||||
def test_duplicate(ray_start_regular_shared):
    """Check that duplicated iterators each receive every source item.

    One copy is batched in pairs and then unioned with the unbatched copy;
    the merged stream must contain both the raw items and the batches in
    the pinned order.
    """
    gathered = from_range(5, num_shards=1).gather_sync()

    batched_copy, plain_copy = gathered.duplicate(2)
    batched_copy = batched_copy.batch(2)

    merged = batched_copy.union(plain_copy, deterministic=False)
    assert merged.take(20) == [0, [0, 1], 1, 2, [2, 3], 3, 4, [4]]
|
||||
|
||||
|
||||
def test_chain(ray_start_regular_shared):
|
||||
it = from_range(4).for_each(lambda x: x * 2).for_each(lambda x: x * 2)
|
||||
assert repr(
|
||||
|
||||
+47
-3
@@ -776,6 +776,46 @@ class LocalIterator(Generic[T]):
|
||||
if i >= n:
|
||||
break
|
||||
|
||||
def duplicate(self, n) -> List["LocalIterator[T]"]:
    """Copy this iterator `n` times, duplicating the data.

    Each copy owns a buffer. Whenever a copy finds its buffer empty, one
    item is pulled from this iterator and appended to the buffer of every
    copy, so all copies observe the same items in the same order. The
    buffers are unbounded: a copy that is not consumed keeps accumulating
    the items pulled on behalf of the others.

    Args:
        n (int): number of copies to create; must be at least 2.

    Returns:
        List[LocalIterator[T]]: multiple iterators that each have a copy
            of the data of this iterator.

    Raises:
        ValueError: if `n` is less than 2.
    """
    if n < 2:
        raise ValueError("Number of copies must be >= 2")

    # One FIFO buffer per copy.
    queues = [collections.deque() for _ in range(n)]

    def fill_next(timeout):
        # Pull a single item from the source and fan it out to every copy.
        self.timeout = timeout
        item = next(self)
        for q in queues:
            q.append(item)

    def make_next(i):
        # Factory so that each generator is bound to its own index `i`.
        def gen(timeout):
            while True:
                if not queues[i]:
                    try:
                        fill_next(timeout)
                    except StopIteration:
                        # The source is exhausted. Finish this copy
                        # explicitly: letting StopIteration escape a
                        # generator body is turned into RuntimeError
                        # (PEP 479).
                        return
                yield queues[i].popleft()

        return gen

    return [
        LocalIterator(
            make_next(i),
            self.metrics, [],
            name=self.name + ".duplicate[{}]".format(i))
        for i in range(n)
    ]
|
||||
|
||||
def union(self, *others: "LocalIterator[T]",
|
||||
deterministic: bool = False) -> "LocalIterator[T]":
|
||||
"""Return an iterator that is the union of this and the others.
|
||||
@@ -811,15 +851,19 @@ class LocalIterator(Generic[T]):
|
||||
for it in list(active):
|
||||
# Yield items from the iterator until _NextValueNotReady is
|
||||
# found, then switch to the next iterator.
|
||||
# To avoid starvation, we yield at most max_yield items per
|
||||
# iterator before switching.
|
||||
if deterministic:
|
||||
max_yield = 1 # Forces round robin.
|
||||
else:
|
||||
max_yield = 20
|
||||
try:
|
||||
while True:
|
||||
for _ in range(max_yield):
|
||||
item = next(it)
|
||||
if isinstance(item, _NextValueNotReady):
|
||||
break
|
||||
else:
|
||||
yield item
|
||||
if deterministic:
|
||||
break
|
||||
except StopIteration:
|
||||
active.remove(it)
|
||||
if not active:
|
||||
|
||||
Reference in New Issue
Block a user