[Parallel Iterator] Local Shuffle (#6921)

* adding local shuffle and corresponding tests

* fix quotes

* addressing comments and adding seed argument

* formatting

* fix formatting issues

* change test size from small to medium

* addressing comments
This commit is contained in:
Amog Kamsetty
2020-01-30 12:27:38 -08:00
committed by GitHub
parent 136ada5fb9
commit c8bf0715a6
3 changed files with 121 additions and 10 deletions
+90 -9
View File
@@ -1,4 +1,5 @@
from typing import TypeVar, Generic, Iterable, List, Callable, Any
import random
import ray
@@ -13,7 +14,7 @@ def from_items(items: List[T], num_shards: int = 2,
The objects will be divided round-robin among the number of shards.
Arguments:
Args:
items (list): The list of items to iterate over.
num_shards (int): The number of worker actors to create.
repeat (bool): Whether to cycle over the items forever.
@@ -33,7 +34,7 @@ def from_range(n: int, num_shards: int = 2,
The range will be partitioned sequentially among the number of shards.
Arguments:
Args:
n (int): The max end of the range of numbers.
num_shards (int): The number of worker actors to create.
repeat (bool): Whether to cycle over the range forever.
@@ -66,7 +67,7 @@ def from_iterators(generators: List[Iterable[T]],
>>> # Equivalent to the above.
>>> from_iterators([lambda: range(100), lambda: range(100)])
Arguments:
Args:
generators (list): A list of Python generator objects or lambda
functions that produced a generator when called. We allow lambda
functions since the generator itself might not be serializable,
@@ -88,7 +89,7 @@ def from_actors(actors: List["ray.actor.ActorHandle"],
Each actor must subclass the ParallelIteratorWorker interface.
Arguments:
Args:
actors (list): List of actors that each implement
ParallelIteratorWorker.
name (str): Optional name to give the iterator.
@@ -166,7 +167,7 @@ class ParallelIterator(Generic[T]):
def for_each(self, fn: Callable[[T], U]) -> "ParallelIterator[U]":
"""Remotely apply fn to each item in this iterator.
Arguments:
Args:
fn (func): function to apply to each item.
Examples:
@@ -183,7 +184,7 @@ class ParallelIterator(Generic[T]):
def filter(self, fn: Callable[[T], bool]) -> "ParallelIterator[T]":
"""Remotely filter items from this iterator.
Arguments:
Args:
fn (func): returns False for items to drop from the iterator.
Examples:
@@ -201,7 +202,7 @@ class ParallelIterator(Generic[T]):
def batch(self, n: int) -> "ParallelIterator[List[T]]":
"""Remotely batch together items in this iterator.
Arguments:
Args:
n (int): Number of items to batch together.
Examples:
@@ -238,6 +239,46 @@ class ParallelIterator(Generic[T]):
it.name = self.name + ".combine()"
return it
def local_shuffle(self, shuffle_buffer_size: int,
                  seed: int = None) -> "ParallelIterator[T]":
    """Remotely shuffle the items of every shard, one shard at a time.

    Each actor set receives a transform that calls ``shuffle`` on its
    local iterator with the same ``shuffle_buffer_size`` and ``seed``.

    Args:
        shuffle_buffer_size (int): The algorithm fills a buffer with
            shuffle_buffer_size elements and randomly samples elements
            from this buffer, replacing the selected elements with new
            elements. For perfect shuffling, this argument should be
            greater than or equal to the largest iterator size.
        seed (int): Seed to use for randomness. Default value is None.

    Returns:
        A new ParallelIterator with a local shuffle applied on top of
        this iterator.

    Examples:
        The order below is only one possible outcome; without a seed
        the result differs from run to run.

        >>> it = from_range(10, 1).local_shuffle(shuffle_buffer_size=2)
        >>> it = it.gather_sync()
        >>> next(it)
        0
        >>> next(it)
        2
        >>> next(it)
        3
        >>> next(it)
        1
    """

    def shuffle_shard(localit):
        # Runs against the shard's local iterator once the transform is
        # applied.
        return localit.shuffle(shuffle_buffer_size, seed)

    shuffled_actor_sets = [
        actor_set.with_transform(shuffle_shard)
        for actor_set in self.actor_sets
    ]
    # "{}".format(None) already renders as "None", so the seed needs no
    # special-casing here.
    name_suffix = (
        ".local_shuffle(shuffle_buffer_size={}, seed={})".format(
            shuffle_buffer_size, seed))
    return ParallelIterator(
        shuffled_actor_sets, name=self.name + name_suffix)
def gather_sync(self) -> "LocalIterator[T]":
"""Returns a local iterable for synchronous iteration.
@@ -429,7 +470,7 @@ class LocalIterator(Generic[T]):
name=None):
"""Create a local iterator (this is an internal function).
Arguments:
Args:
base_iterator (func): A function that produces the base iterator.
This is a function so that we can ensure LocalIterator is
serializable.
@@ -527,6 +568,46 @@ class LocalIterator(Generic[T]):
self.local_transforms + [apply_flatten],
name=self.name + ".flatten()")
def shuffle(self, shuffle_buffer_size: int,
            seed: int = None) -> "LocalIterator[T]":
    """Return a new iterator that yields this iterator's items shuffled.

    Items are collected into a buffer; once the buffer holds
    ``shuffle_buffer_size`` items, every newly arriving item causes one
    randomly chosen buffered item to be emitted. When the source is
    exhausted the remaining buffered items are emitted in random order.
    ``_NextValueNotReady`` markers are forwarded immediately and are
    never buffered.

    Args:
        shuffle_buffer_size (int): The algorithm fills a buffer with
            shuffle_buffer_size elements and randomly samples elements
            from this buffer, replacing the selected elements with new
            elements. For perfect shuffling, this argument should be
            greater than or equal to the largest iterator size.
        seed (int): Seed to use for randomness. Default value is None.

    Returns:
        A new LocalIterator with shuffling applied.
    """
    # Created once per shuffle() call; the transform below closes over it.
    rng = random.Random(seed)

    def draw(pool):
        # Remove and return one uniformly chosen element of ``pool``.
        return pool.pop(rng.randrange(len(pool)))

    def apply_shuffle(it):
        pool = []
        for item in it:
            if isinstance(item, _NextValueNotReady):
                yield item
                continue
            pool.append(item)
            if len(pool) >= shuffle_buffer_size:
                yield draw(pool)
        # Source exhausted: drain whatever is still buffered.
        while pool:
            yield draw(pool)

    # "{}".format(None) already renders as "None", so the seed needs no
    # special-casing here.
    name_suffix = ".shuffle(shuffle_buffer_size={}, seed={})".format(
        shuffle_buffer_size, seed)
    return LocalIterator(
        self.base_iterator,
        self.local_transforms + [apply_shuffle],
        name=self.name + name_suffix)
def combine(self, fn: Callable[[T], List[U]]) -> "LocalIterator[U]":
it = self.for_each(fn).flatten()
it.name = self.name + ".combine()"
@@ -601,7 +682,7 @@ class ParallelIteratorWorker(object):
Subclasses must call this init function.
Arguments:
Args:
item_generator (obj): A Python generator objects or lambda function
that produces a generator when called. We allow lambda
functions since the generator itself might not be serializable,
+1 -1
View File
@@ -16,7 +16,7 @@ py_test(
py_test(
name = "test_iter",
size = "small",
size = "medium",
srcs = ["test_iter.py"],
tags = ["exclusive"],
deps = ["//:ray_lib"],
+30
View File
@@ -1,4 +1,5 @@
import time
from collections import Counter
import ray
from ray.experimental.iter import from_items, from_iterators, from_range, \
@@ -70,6 +71,35 @@ def test_filter(ray_start_regular_shared):
assert list(it.gather_sync()) == [0, 2, 1]
def test_local_shuffle(ray_start_regular_shared):
    """Check that local_shuffle keeps data in its shard and mixes order."""
    # Confirm that no data disappears, and that items stay within the
    # same shard.
    it = from_range(8, num_shards=2).local_shuffle(shuffle_buffer_size=2)
    assert repr(it) == ("ParallelIterator[from_range[8, shards=2]" +
                        ".local_shuffle(shuffle_buffer_size=2, seed=None)]")
    shard_0 = it.get_shard(0)
    shard_1 = it.get_shard(1)
    assert set(shard_0) == {0, 1, 2, 3}
    assert set(shard_1) == {4, 5, 6, 7}

    # Check that shuffling results in different orders.
    it1 = from_range(100, num_shards=10).local_shuffle(shuffle_buffer_size=5)
    it2 = from_range(100, num_shards=10).local_shuffle(shuffle_buffer_size=5)
    assert list(it1.gather_sync()) != list(it2.gather_sync())

    # A buffer size of 1 should not result in any shuffling.
    it3 = from_range(10, num_shards=1).local_shuffle(shuffle_buffer_size=1)
    assert list(it3.gather_sync()) == list(range(10))

    # Statistical test: after shuffling an alternating 0/1 stream, each of
    # the four adjacent pairs (00, 01, 10, 11) should make up roughly a
    # quarter of all adjacent pairs.
    it4 = from_items(
        [0, 1] * 10000, num_shards=1).local_shuffle(shuffle_buffer_size=100)
    result = "".join(it4.gather_sync().for_each(str))
    freq_counter = Counter(zip(result[:-1], result[1:]))
    assert len(freq_counter) == 4
    # Normalise by the number of adjacent pairs. Dividing by
    # len(freq_counter), which is always 4 here, compared a raw count in
    # the thousands against 0.2 and could never fail.
    total_pairs = len(result) - 1
    for pair, count in freq_counter.items():
        assert count / total_pairs > 0.2
def test_batch(ray_start_regular_shared):
it = from_range(4, 1).batch(2)
assert repr(it) == "ParallelIterator[from_range[4, shards=1].batch(2)]"