From c8bf0715a662bd7d5e4486225da3f2e19dc792ad Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Thu, 30 Jan 2020 12:27:38 -0800 Subject: [PATCH] [Parallel Iterator] Local Shuffle (#6921) * adding local shuffle and corresponding tests * fix quotes * addressing comments and adding seed argument * formatting * fix formatting issues * change test size from small to medium * addressing comments --- python/ray/experimental/iter.py | 99 ++++++++++++++++++++++++++++++--- python/ray/tests/BUILD | 2 +- python/ray/tests/test_iter.py | 30 ++++++++++ 3 files changed, 121 insertions(+), 10 deletions(-) diff --git a/python/ray/experimental/iter.py b/python/ray/experimental/iter.py index d4aaf64ff..89c6e6262 100644 --- a/python/ray/experimental/iter.py +++ b/python/ray/experimental/iter.py @@ -1,4 +1,5 @@ from typing import TypeVar, Generic, Iterable, List, Callable, Any +import random import ray @@ -13,7 +14,7 @@ def from_items(items: List[T], num_shards: int = 2, The objects will be divided round-robin among the number of shards. - Arguments: + Args: items (list): The list of items to iterate over. num_shards (int): The number of worker actors to create. repeat (bool): Whether to cycle over the items forever. @@ -33,7 +34,7 @@ def from_range(n: int, num_shards: int = 2, The range will be partitioned sequentially among the number of shards. - Arguments: + Args: n (int): The max end of the range of numbers. num_shards (int): The number of worker actors to create. repeat (bool): Whether to cycle over the range forever. @@ -66,7 +67,7 @@ def from_iterators(generators: List[Iterable[T]], >>> # Equivalent to the above. >>> from_iterators([lambda: range(100), lambda: range(100)]) - Arguments: + Args: generators (list): A list of Python generator objects or lambda functions that produced a generator when called. We allow lambda functions since the generator itself might not be serializable, @@ -88,7 +89,7 @@ def from_actors(actors: List["ray.actor.ActorHandle"], Each actor must subclass the ParallelIteratorWorker interface. - Arguments: + Args: actors (list): List of actors that each implement ParallelIteratorWorker. name (str): Optional name to give the iterator. @@ -166,7 +167,7 @@ class ParallelIterator(Generic[T]): def for_each(self, fn: Callable[[T], U]) -> "ParallelIterator[U]": """Remotely apply fn to each item in this iterator. - Arguments: + Args: fn (func): function to apply to each item. Examples: @@ -183,7 +184,7 @@ class ParallelIterator(Generic[T]): def filter(self, fn: Callable[[T], bool]) -> "ParallelIterator[T]": """Remotely filter items from this iterator. - Arguments: + Args: fn (func): returns False for items to drop from the iterator. Examples: @@ -201,7 +202,7 @@ class ParallelIterator(Generic[T]): def batch(self, n: int) -> "ParallelIterator[List[T]]": """Remotely batch together items in this iterator. - Arguments: + Args: n (int): Number of items to batch together. Examples: @@ -238,6 +239,46 @@ class ParallelIterator(Generic[T]): it.name = self.name + ".combine()" return it + def local_shuffle(self, shuffle_buffer_size: int, + seed: int = None) -> "ParallelIterator[T]": + """Remotely shuffle items of each shard independently + + Args: + shuffle_buffer_size (int): The algorithm fills a buffer with + shuffle_buffer_size elements and randomly samples elements from + this buffer, replacing the selected elements with new elements. + For perfect shuffling, this argument should be greater than or + equal to the largest iterator size. + seed (int): Seed to use for + randomness. Default value is None. + + Returns: + Returns a ParallelIterator with a local shuffle applied on the + base iterator + + Examples: + >>> it = from_range(10, 1).local_shuffle(shuffle_buffer_size=2) + >>> it = it.gather_sync() + >>> next(it) + 0 + >>> next(it) + 2 + >>> next(it) + 3 + >>> next(it) + 1 + """ + return ParallelIterator( + [ + a.with_transform( + lambda localit: localit.shuffle(shuffle_buffer_size, seed)) + for a in self.actor_sets + ], + name=self.name + + ".local_shuffle(shuffle_buffer_size={}, seed={})".format( + shuffle_buffer_size, + str(seed) if seed is not None else "None")) + def gather_sync(self) -> "LocalIterator[T]": """Returns a local iterable for synchronous iteration. @@ -429,7 +470,7 @@ class LocalIterator(Generic[T]): name=None): """Create a local iterator (this is an internal function). - Arguments: + Args: base_iterator (func): A function that produces the base iterator. This is a function so that we can ensure LocalIterator is serializable. @@ -527,6 +568,46 @@ class LocalIterator(Generic[T]): self.local_transforms + [apply_flatten], name=self.name + ".flatten()") + def shuffle(self, shuffle_buffer_size: int, + seed: int = None) -> "LocalIterator[T]": + """Shuffle items of this iterator + + Args: + shuffle_buffer_size (int): The algorithm fills a buffer with + shuffle_buffer_size elements and randomly samples elements from + this buffer, replacing the selected elements with new elements. + For perfect shuffling, this argument should be greater than or + equal to the largest iterator size. + seed (int): Seed to use for + randomness. Default value is None. + + Returns: + A new LocalIterator with shuffling applied + """ + shuffle_random = random.Random(seed) + + def apply_shuffle(it): + buffer = [] + for item in it: + if isinstance(item, _NextValueNotReady): + yield item + else: + buffer.append(item) + if len(buffer) >= shuffle_buffer_size: + yield buffer.pop( + shuffle_random.randint(0, + len(buffer) - 1)) + while len(buffer) > 0: + yield buffer.pop(shuffle_random.randint(0, len(buffer) - 1)) + + return LocalIterator( + self.base_iterator, + self.local_transforms + [apply_shuffle], + name=self.name + + ".shuffle(shuffle_buffer_size={}, seed={})".format( + shuffle_buffer_size, + str(seed) if seed is not None else "None")) + def combine(self, fn: Callable[[T], List[U]]) -> "LocalIterator[U]": it = self.for_each(fn).flatten() it.name = self.name + ".combine()" @@ -601,7 +682,7 @@ class ParallelIteratorWorker(object): Subclasses must call this init function. - Arguments: + Args: item_generator (obj): A Python generator objects or lambda function that produces a generator when called. We allow lambda functions since the generator itself might not be serializable, diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 9a1bbde6d..27d346c87 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -16,7 +16,7 @@ py_test( py_test( name = "test_iter", - size = "small", + size = "medium", srcs = ["test_iter.py"], tags = ["exclusive"], deps = ["//:ray_lib"], diff --git a/python/ray/tests/test_iter.py b/python/ray/tests/test_iter.py index 2e38011c1..a811aaa12 100644 --- a/python/ray/tests/test_iter.py +++ b/python/ray/tests/test_iter.py @@ -1,4 +1,5 @@ import time +from collections import Counter import ray from ray.experimental.iter import from_items, from_iterators, from_range, \ @@ -70,6 +71,35 @@ def test_filter(ray_start_regular_shared): assert list(it.gather_sync()) == [0, 2, 1] +def test_local_shuffle(ray_start_regular_shared): + # confirm that no data disappears, and they all stay within the same shard + it = from_range(8, num_shards=2).local_shuffle(shuffle_buffer_size=2) + assert repr(it) == ("ParallelIterator[from_range[8, shards=2]" + + ".local_shuffle(shuffle_buffer_size=2, seed=None)]") + shard_0 = it.get_shard(0) + shard_1 = it.get_shard(1) + assert set(shard_0) == {0, 1, 2, 3} + assert set(shard_1) == {4, 5, 6, 7} + + # check that shuffling results in different orders + it1 = from_range(100, num_shards=10).local_shuffle(shuffle_buffer_size=5) + it2 = from_range(100, num_shards=10).local_shuffle(shuffle_buffer_size=5) + assert list(it1.gather_sync()) != list(it2.gather_sync()) + + # buffer size of 1 should not result in any shuffling + it3 = from_range(10, num_shards=1).local_shuffle(shuffle_buffer_size=1) + assert list(it3.gather_sync()) == list(range(10)) + + # statistical test + it4 = from_items( + [0, 1] * 10000, num_shards=1).local_shuffle(shuffle_buffer_size=100) + result = "".join(it4.gather_sync().for_each(str)) + freq_counter = Counter(zip(result[:-1], result[1:])) + assert len(freq_counter) == 4 + for key, value in freq_counter.items(): + assert value / len(freq_counter) > 0.2 + + def test_batch(ray_start_regular_shared): it = from_range(4, 1).batch(2) assert repr(it) == "ParallelIterator[from_range[4, shards=1].batch(2)]"