mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 02:01:24 +08:00
[Parallel Iterator] Local Shuffle (#6921)
* adding local shuffle and corresponding tests * fix quotes * addressing comments and adding seed argument * formatting * fix formatting issues * change test size from small to medium * addressing comments
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from typing import TypeVar, Generic, Iterable, List, Callable, Any
|
||||
import random
|
||||
|
||||
import ray
|
||||
|
||||
@@ -13,7 +14,7 @@ def from_items(items: List[T], num_shards: int = 2,
|
||||
|
||||
The objects will be divided round-robin among the number of shards.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
items (list): The list of items to iterate over.
|
||||
num_shards (int): The number of worker actors to create.
|
||||
repeat (bool): Whether to cycle over the items forever.
|
||||
@@ -33,7 +34,7 @@ def from_range(n: int, num_shards: int = 2,
|
||||
|
||||
The range will be partitioned sequentially among the number of shards.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
n (int): The max end of the range of numbers.
|
||||
num_shards (int): The number of worker actors to create.
|
||||
repeat (bool): Whether to cycle over the range forever.
|
||||
@@ -66,7 +67,7 @@ def from_iterators(generators: List[Iterable[T]],
|
||||
>>> # Equivalent to the above.
|
||||
>>> from_iterators([lambda: range(100), lambda: range(100)])
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
generators (list): A list of Python generator objects or lambda
|
||||
functions that produced a generator when called. We allow lambda
|
||||
functions since the generator itself might not be serializable,
|
||||
@@ -88,7 +89,7 @@ def from_actors(actors: List["ray.actor.ActorHandle"],
|
||||
|
||||
Each actor must subclass the ParallelIteratorWorker interface.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
actors (list): List of actors that each implement
|
||||
ParallelIteratorWorker.
|
||||
name (str): Optional name to give the iterator.
|
||||
@@ -166,7 +167,7 @@ class ParallelIterator(Generic[T]):
|
||||
def for_each(self, fn: Callable[[T], U]) -> "ParallelIterator[U]":
|
||||
"""Remotely apply fn to each item in this iterator.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
fn (func): function to apply to each item.
|
||||
|
||||
Examples:
|
||||
@@ -183,7 +184,7 @@ class ParallelIterator(Generic[T]):
|
||||
def filter(self, fn: Callable[[T], bool]) -> "ParallelIterator[T]":
|
||||
"""Remotely filter items from this iterator.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
fn (func): returns False for items to drop from the iterator.
|
||||
|
||||
Examples:
|
||||
@@ -201,7 +202,7 @@ class ParallelIterator(Generic[T]):
|
||||
def batch(self, n: int) -> "ParallelIterator[List[T]]":
|
||||
"""Remotely batch together items in this iterator.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
n (int): Number of items to batch together.
|
||||
|
||||
Examples:
|
||||
@@ -238,6 +239,46 @@ class ParallelIterator(Generic[T]):
|
||||
it.name = self.name + ".combine()"
|
||||
return it
|
||||
|
||||
def local_shuffle(self, shuffle_buffer_size: int,
|
||||
seed: int = None) -> "ParallelIterator[T]":
|
||||
"""Remotely shuffle items of each shard independently
|
||||
|
||||
Args:
|
||||
shuffle_buffer_size (int): The algorithm fills a buffer with
|
||||
shuffle_buffer_size elements and randomly samples elements from
|
||||
this buffer, replacing the selected elements with new elements.
|
||||
For perfect shuffling, this argument should be greater than or
|
||||
equal to the largest iterator size.
|
||||
seed (int): Seed to use for
|
||||
randomness. Default value is None.
|
||||
|
||||
Returns:
|
||||
Returns a ParallelIterator with a local shuffle applied on the
|
||||
base iterator
|
||||
|
||||
Examples:
|
||||
>>> it = from_range(10, 1).local_shuffle(shuffle_buffer_size=2)
|
||||
>>> it = it.gather_sync()
|
||||
>>> next(it)
|
||||
0
|
||||
>>> next(it)
|
||||
2
|
||||
>>> next(it)
|
||||
3
|
||||
>>> next(it)
|
||||
1
|
||||
"""
|
||||
return ParallelIterator(
|
||||
[
|
||||
a.with_transform(
|
||||
lambda localit: localit.shuffle(shuffle_buffer_size, seed))
|
||||
for a in self.actor_sets
|
||||
],
|
||||
name=self.name +
|
||||
".local_shuffle(shuffle_buffer_size={}, seed={})".format(
|
||||
shuffle_buffer_size,
|
||||
str(seed) if seed is not None else "None"))
|
||||
|
||||
def gather_sync(self) -> "LocalIterator[T]":
|
||||
"""Returns a local iterable for synchronous iteration.
|
||||
|
||||
@@ -429,7 +470,7 @@ class LocalIterator(Generic[T]):
|
||||
name=None):
|
||||
"""Create a local iterator (this is an internal function).
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
base_iterator (func): A function that produces the base iterator.
|
||||
This is a function so that we can ensure LocalIterator is
|
||||
serializable.
|
||||
@@ -527,6 +568,46 @@ class LocalIterator(Generic[T]):
|
||||
self.local_transforms + [apply_flatten],
|
||||
name=self.name + ".flatten()")
|
||||
|
||||
def shuffle(self, shuffle_buffer_size: int,
|
||||
seed: int = None) -> "LocalIterator[T]":
|
||||
"""Shuffle items of this iterator
|
||||
|
||||
Args:
|
||||
shuffle_buffer_size (int): The algorithm fills a buffer with
|
||||
shuffle_buffer_size elements and randomly samples elements from
|
||||
this buffer, replacing the selected elements with new elements.
|
||||
For perfect shuffling, this argument should be greater than or
|
||||
equal to the largest iterator size.
|
||||
seed (int): Seed to use for
|
||||
randomness. Default value is None.
|
||||
|
||||
Returns:
|
||||
A new LocalIterator with shuffling applied
|
||||
"""
|
||||
shuffle_random = random.Random(seed)
|
||||
|
||||
def apply_shuffle(it):
|
||||
buffer = []
|
||||
for item in it:
|
||||
if isinstance(item, _NextValueNotReady):
|
||||
yield item
|
||||
else:
|
||||
buffer.append(item)
|
||||
if len(buffer) >= shuffle_buffer_size:
|
||||
yield buffer.pop(
|
||||
shuffle_random.randint(0,
|
||||
len(buffer) - 1))
|
||||
while len(buffer) > 0:
|
||||
yield buffer.pop(shuffle_random.randint(0, len(buffer) - 1))
|
||||
|
||||
return LocalIterator(
|
||||
self.base_iterator,
|
||||
self.local_transforms + [apply_shuffle],
|
||||
name=self.name +
|
||||
".shuffle(shuffle_buffer_size={}, seed={})".format(
|
||||
shuffle_buffer_size,
|
||||
str(seed) if seed is not None else "None"))
|
||||
|
||||
def combine(self, fn: Callable[[T], List[U]]) -> "LocalIterator[U]":
|
||||
it = self.for_each(fn).flatten()
|
||||
it.name = self.name + ".combine()"
|
||||
@@ -601,7 +682,7 @@ class ParallelIteratorWorker(object):
|
||||
|
||||
Subclasses must call this init function.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
item_generator (obj): A Python generator objects or lambda function
|
||||
that produces a generator when called. We allow lambda
|
||||
functions since the generator itself might not be serializable,
|
||||
|
||||
@@ -16,7 +16,7 @@ py_test(
|
||||
|
||||
py_test(
|
||||
name = "test_iter",
|
||||
size = "small",
|
||||
size = "medium",
|
||||
srcs = ["test_iter.py"],
|
||||
tags = ["exclusive"],
|
||||
deps = ["//:ray_lib"],
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import time
|
||||
from collections import Counter
|
||||
|
||||
import ray
|
||||
from ray.experimental.iter import from_items, from_iterators, from_range, \
|
||||
@@ -70,6 +71,35 @@ def test_filter(ray_start_regular_shared):
|
||||
assert list(it.gather_sync()) == [0, 2, 1]
|
||||
|
||||
|
||||
def test_local_shuffle(ray_start_regular_shared):
|
||||
# confirm that no data disappears, and they all stay within the same shard
|
||||
it = from_range(8, num_shards=2).local_shuffle(shuffle_buffer_size=2)
|
||||
assert repr(it) == ("ParallelIterator[from_range[8, shards=2]" +
|
||||
".local_shuffle(shuffle_buffer_size=2, seed=None)]")
|
||||
shard_0 = it.get_shard(0)
|
||||
shard_1 = it.get_shard(1)
|
||||
assert set(shard_0) == {0, 1, 2, 3}
|
||||
assert set(shard_1) == {4, 5, 6, 7}
|
||||
|
||||
# check that shuffling results in different orders
|
||||
it1 = from_range(100, num_shards=10).local_shuffle(shuffle_buffer_size=5)
|
||||
it2 = from_range(100, num_shards=10).local_shuffle(shuffle_buffer_size=5)
|
||||
assert list(it1.gather_sync()) != list(it2.gather_sync())
|
||||
|
||||
# buffer size of 1 should not result in any shuffling
|
||||
it3 = from_range(10, num_shards=1).local_shuffle(shuffle_buffer_size=1)
|
||||
assert list(it3.gather_sync()) == list(range(10))
|
||||
|
||||
# statistical test
|
||||
it4 = from_items(
|
||||
[0, 1] * 10000, num_shards=1).local_shuffle(shuffle_buffer_size=100)
|
||||
result = "".join(it4.gather_sync().for_each(str))
|
||||
freq_counter = Counter(zip(result[:-1], result[1:]))
|
||||
assert len(freq_counter) == 4
|
||||
for key, value in freq_counter.items():
|
||||
assert value / len(freq_counter) > 0.2
|
||||
|
||||
|
||||
def test_batch(ray_start_regular_shared):
|
||||
it = from_range(4, 1).batch(2)
|
||||
assert repr(it) == "ParallelIterator[from_range[4, shards=1].batch(2)]"
|
||||
|
||||
Reference in New Issue
Block a user