mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 10:18:53 +08:00
[Parallel Iterator] Local Shuffle (#6921)
* adding local shuffle and corresponding tests * fix quotes * addressing comments and adding seed argument * formatting * fix formatting issues * change test size from small to medium * addressing comments
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from typing import TypeVar, Generic, Iterable, List, Callable, Any
|
||||
import random
|
||||
|
||||
import ray
|
||||
|
||||
@@ -13,7 +14,7 @@ def from_items(items: List[T], num_shards: int = 2,
|
||||
|
||||
The objects will be divided round-robin among the number of shards.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
items (list): The list of items to iterate over.
|
||||
num_shards (int): The number of worker actors to create.
|
||||
repeat (bool): Whether to cycle over the items forever.
|
||||
@@ -33,7 +34,7 @@ def from_range(n: int, num_shards: int = 2,
|
||||
|
||||
The range will be partitioned sequentially among the number of shards.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
n (int): The max end of the range of numbers.
|
||||
num_shards (int): The number of worker actors to create.
|
||||
repeat (bool): Whether to cycle over the range forever.
|
||||
@@ -66,7 +67,7 @@ def from_iterators(generators: List[Iterable[T]],
|
||||
>>> # Equivalent to the above.
|
||||
>>> from_iterators([lambda: range(100), lambda: range(100)])
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
generators (list): A list of Python generator objects or lambda
|
||||
functions that produced a generator when called. We allow lambda
|
||||
functions since the generator itself might not be serializable,
|
||||
@@ -88,7 +89,7 @@ def from_actors(actors: List["ray.actor.ActorHandle"],
|
||||
|
||||
Each actor must subclass the ParallelIteratorWorker interface.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
actors (list): List of actors that each implement
|
||||
ParallelIteratorWorker.
|
||||
name (str): Optional name to give the iterator.
|
||||
@@ -166,7 +167,7 @@ class ParallelIterator(Generic[T]):
|
||||
def for_each(self, fn: Callable[[T], U]) -> "ParallelIterator[U]":
|
||||
"""Remotely apply fn to each item in this iterator.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
fn (func): function to apply to each item.
|
||||
|
||||
Examples:
|
||||
@@ -183,7 +184,7 @@ class ParallelIterator(Generic[T]):
|
||||
def filter(self, fn: Callable[[T], bool]) -> "ParallelIterator[T]":
|
||||
"""Remotely filter items from this iterator.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
fn (func): returns False for items to drop from the iterator.
|
||||
|
||||
Examples:
|
||||
@@ -201,7 +202,7 @@ class ParallelIterator(Generic[T]):
|
||||
def batch(self, n: int) -> "ParallelIterator[List[T]]":
|
||||
"""Remotely batch together items in this iterator.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
n (int): Number of items to batch together.
|
||||
|
||||
Examples:
|
||||
@@ -238,6 +239,46 @@ class ParallelIterator(Generic[T]):
|
||||
it.name = self.name + ".combine()"
|
||||
return it
|
||||
|
||||
def local_shuffle(self, shuffle_buffer_size: int,
|
||||
seed: int = None) -> "ParallelIterator[T]":
|
||||
"""Remotely shuffle items of each shard independently
|
||||
|
||||
Args:
|
||||
shuffle_buffer_size (int): The algorithm fills a buffer with
|
||||
shuffle_buffer_size elements and randomly samples elements from
|
||||
this buffer, replacing the selected elements with new elements.
|
||||
For perfect shuffling, this argument should be greater than or
|
||||
equal to the largest iterator size.
|
||||
seed (int): Seed to use for
|
||||
randomness. Default value is None.
|
||||
|
||||
Returns:
|
||||
Returns a ParallelIterator with a local shuffle applied on the
|
||||
base iterator
|
||||
|
||||
Examples:
|
||||
>>> it = from_range(10, 1).local_shuffle(shuffle_buffer_size=2)
|
||||
>>> it = it.gather_sync()
|
||||
>>> next(it)
|
||||
0
|
||||
>>> next(it)
|
||||
2
|
||||
>>> next(it)
|
||||
3
|
||||
>>> next(it)
|
||||
1
|
||||
"""
|
||||
return ParallelIterator(
|
||||
[
|
||||
a.with_transform(
|
||||
lambda localit: localit.shuffle(shuffle_buffer_size, seed))
|
||||
for a in self.actor_sets
|
||||
],
|
||||
name=self.name +
|
||||
".local_shuffle(shuffle_buffer_size={}, seed={})".format(
|
||||
shuffle_buffer_size,
|
||||
str(seed) if seed is not None else "None"))
|
||||
|
||||
def gather_sync(self) -> "LocalIterator[T]":
|
||||
"""Returns a local iterable for synchronous iteration.
|
||||
|
||||
@@ -429,7 +470,7 @@ class LocalIterator(Generic[T]):
|
||||
name=None):
|
||||
"""Create a local iterator (this is an internal function).
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
base_iterator (func): A function that produces the base iterator.
|
||||
This is a function so that we can ensure LocalIterator is
|
||||
serializable.
|
||||
@@ -527,6 +568,46 @@ class LocalIterator(Generic[T]):
|
||||
self.local_transforms + [apply_flatten],
|
||||
name=self.name + ".flatten()")
|
||||
|
||||
def shuffle(self, shuffle_buffer_size: int,
|
||||
seed: int = None) -> "LocalIterator[T]":
|
||||
"""Shuffle items of this iterator
|
||||
|
||||
Args:
|
||||
shuffle_buffer_size (int): The algorithm fills a buffer with
|
||||
shuffle_buffer_size elements and randomly samples elements from
|
||||
this buffer, replacing the selected elements with new elements.
|
||||
For perfect shuffling, this argument should be greater than or
|
||||
equal to the largest iterator size.
|
||||
seed (int): Seed to use for
|
||||
randomness. Default value is None.
|
||||
|
||||
Returns:
|
||||
A new LocalIterator with shuffling applied
|
||||
"""
|
||||
shuffle_random = random.Random(seed)
|
||||
|
||||
def apply_shuffle(it):
|
||||
buffer = []
|
||||
for item in it:
|
||||
if isinstance(item, _NextValueNotReady):
|
||||
yield item
|
||||
else:
|
||||
buffer.append(item)
|
||||
if len(buffer) >= shuffle_buffer_size:
|
||||
yield buffer.pop(
|
||||
shuffle_random.randint(0,
|
||||
len(buffer) - 1))
|
||||
while len(buffer) > 0:
|
||||
yield buffer.pop(shuffle_random.randint(0, len(buffer) - 1))
|
||||
|
||||
return LocalIterator(
|
||||
self.base_iterator,
|
||||
self.local_transforms + [apply_shuffle],
|
||||
name=self.name +
|
||||
".shuffle(shuffle_buffer_size={}, seed={})".format(
|
||||
shuffle_buffer_size,
|
||||
str(seed) if seed is not None else "None"))
|
||||
|
||||
def combine(self, fn: Callable[[T], List[U]]) -> "LocalIterator[U]":
|
||||
it = self.for_each(fn).flatten()
|
||||
it.name = self.name + ".combine()"
|
||||
@@ -601,7 +682,7 @@ class ParallelIteratorWorker(object):
|
||||
|
||||
Subclasses must call this init function.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
item_generator (obj): A Python generator objects or lambda function
|
||||
that produces a generator when called. We allow lambda
|
||||
functions since the generator itself might not be serializable,
|
||||
|
||||
Reference in New Issue
Block a user