[Parallel Iterators] Batching + Pipelining optimizations (#7931)

* batching + get_shard pipelining

* duplicate fix

* formatting

* adding performance benchmark

* minor changes

* turn batching off by default
This commit is contained in:
Amog Kamsetty
2020-05-26 00:37:57 -07:00
committed by GitHub
parent 26cffb9c7c
commit ae2e1f0883
2 changed files with 166 additions and 28 deletions
+48 -2
View File
@@ -329,12 +329,58 @@ def test_gather_async(ray_start_regular_shared):
assert sorted(it) == [0, 1, 2, 3]
def test_gather_async_queue(ray_start_regular_shared):
def test_gather_async_optimized(ray_start_regular_shared):
it = from_range(100)
it = it.gather_async(num_async=4)
it = it.gather_async(batch_ms=100, num_async=4)
assert sorted(it) == list(range(100))
def test_get_shard_optimized(ray_start_regular_shared):
it = from_range(6, num_shards=3)
shard1 = it.get_shard(shard_index=0, batch_ms=25, num_async=2)
shard2 = it.get_shard(shard_index=1, batch_ms=15, num_async=3)
shard3 = it.get_shard(shard_index=2, batch_ms=5, num_async=4)
assert list(shard1) == [0, 1]
assert list(shard2) == [2, 3]
assert list(shard3) == [4, 5]
# Tested on 5/13/20
# Run on 2019 Macbook Pro with 8 cores, 16 threads
# 14.52 sec
# 14.64 sec
# 0.935 sec
# 0.515 sec
"""
def test_gather_async_optimized_benchmark(ray_start_regular_shared):
import numpy as np
import tensorflow as tf
train, _ = tf.keras.datasets.fashion_mnist.load_data()
images, labels = train
num_bytes = images.nbytes / 1e6
items = list(images)
it = from_items(items, num_shards=4)
it = it.for_each(lambda img: img/255)
#local_it = it.gather_async(batch_ms=0, num_async=1)
#local_it = it.gather_async(batch_ms=0, num_async=3)
#local_it = it.gather_async(batch_ms=10, num_async=1)
#local_it = it.gather_async(batch_ms=10, num_async=3)
# dummy iterations
for i in range(20):
record = next(local_it)
start_time = time.time()
#print(start_time)
count = 0
for record in local_it:
count += 1
assert count == len(items) - 20
end_time = time.time() - start_time
print(end_time)
"""
def test_batch_across_shards(ray_start_regular_shared):
it = from_iterators([[0, 1], [2, 3]])
it = it.batch_across_shards()
+118 -26
View File
@@ -2,6 +2,7 @@ from contextlib import contextmanager
import collections
import random
import threading
import time
from typing import TypeVar, Generic, Iterable, List, Callable, Any
import ray
@@ -64,21 +65,28 @@ def from_range(n: int, num_shards: int = 2,
def from_iterators(generators: List[Iterable[T]],
repeat: bool = False,
name=None) -> "ParallelIterator[T]":
"""Create a parallel iterator from a set of iterators.
"""Create a parallel iterator from a list of iterables.
An iterable can be a conatiner (list, str, tuple, set, etc.),
a generator, or a custom class that implements __iter__ or __getitem__.
An actor will be created for each iterator.
An actor will be created for each iterable.
Examples:
>>> # Create using a list of generators.
>>> from_iterators([range(100), range(100)])
>>> # Equivalent to the above.
>>> from_iterators([lambda: range(100), lambda: range(100)])
>>> # Certain generators are not serializable.
>>> from_iterators([(x for x in range(100))])
... TypeError: can't pickle generator objects
>>> # So use lambda functions instead.
>>> # Lambda functions are serializable.
>>> from_iterators([lambda: (x for x in range(100))])
Args:
generators (list): A list of Python generator objects or lambda
functions that produced a generator when called. We allow lambda
functions since the generator itself might not be serializable,
generators (list): A list of Python iterables or lambda
functions that produce an iterable when called. We allow lambda
functions since certain generators might not be serializable,
but a lambda that returns it can be.
repeat (bool): Whether to cycle over the iterators forever.
name (str): Optional name to give the iterator.
@@ -188,9 +196,9 @@ class ParallelIterator(Generic[T]):
def for_each(self, fn: Callable[[T], U], max_concurrency=1,
resources=None) -> "ParallelIterator[U]":
"""Remotely apply fn to each item in this iterator, at most `max_concurrency`
at a time.
at a time per shard.
If `max_concurrency` == 1 then `fn` will be executed serially by the
If `max_concurrency` == 1 then `fn` will be executed serially by each
shards
`max_concurrency` should be used to achieve a high degree of
@@ -200,7 +208,7 @@ class ParallelIterator(Generic[T]):
necessarily finish first)
A performance note: When executing concurrently, this function
maintains its own internal buffer. If `async_queue_depth` is `n` and
maintains its own internal buffer. If `num_async` is `n` and
max_concur is `k` then the total number of buffered objects could be up
to `n + k - 1`
@@ -311,7 +319,8 @@ class ParallelIterator(Generic[T]):
shuffle_buffer_size,
str(seed) if seed is not None else "None"))
def repartition(self, num_partitions: int) -> "ParallelIterator[T]":
def repartition(self, num_partitions: int,
batch_ms: int = 0) -> "ParallelIterator[T]":
"""Returns a new ParallelIterator instance with num_partitions shards.
The new iterator contains the same data in this instance except with
@@ -321,6 +330,9 @@ class ParallelIterator(Generic[T]):
Args:
num_partitions (int): The number of shards to use for the new
ParallelIterator
batch_ms (int): Batches items for batch_ms milliseconds
on each shard before retrieving it.
Increasing batch_ms increases latency but improves throughput.
Returns:
A ParallelIterator with num_partitions number of shards and the
@@ -347,8 +359,10 @@ class ParallelIterator(Generic[T]):
def base_iterator(num_partitions, partition_index, timeout=None):
futures = {}
for a in all_actors:
futures[a.par_iter_slice.remote(
step=num_partitions, start=partition_index)] = a
futures[a.par_iter_slice_batch.remote(
step=num_partitions,
start=partition_index,
batch_ms=batch_ms)] = a
while futures:
pending = list(futures)
if timeout is None:
@@ -364,10 +378,13 @@ class ParallelIterator(Generic[T]):
for obj_id in ready:
actor = futures.pop(obj_id)
try:
yield ray.get(obj_id)
futures[actor.par_iter_slice.remote(
batch = ray.get(obj_id)
futures[actor.par_iter_slice_batch.remote(
step=num_partitions,
start=partition_index)] = actor
start=partition_index,
batch_ms=batch_ms)] = actor
for item in batch:
yield item
except StopIteration:
pass
# Always yield after each round of wait with timeout.
@@ -447,13 +464,17 @@ class ParallelIterator(Generic[T]):
name = "{}.batch_across_shards()".format(self)
return LocalIterator(base_iterator, SharedMetrics(), name=name)
def gather_async(self, num_async=1) -> "LocalIterator[T]":
def gather_async(self, batch_ms=0, num_async=1) -> "LocalIterator[T]":
"""Returns a local iterable for asynchronous iteration.
New items will be fetched from the shards asynchronously as soon as
the previous one is computed. Items arrive in non-deterministic order.
Arguments:
batch_ms (int): Batches items for batch_ms milliseconds
on each shard before retrieving it.
Increasing batch_ms increases latency but improves throughput.
If this value is 0, then items are returned immediately.
num_async (int): The max number of async requests in flight
per actor. Increasing this improves the amount of pipeline
parallelism in the iterator.
@@ -470,6 +491,8 @@ class ParallelIterator(Generic[T]):
if num_async < 1:
raise ValueError("queue depth must be positive")
if batch_ms < 0:
raise ValueError("batch time must be positive")
# Forward reference to the returned iterator.
local_iter = None
@@ -482,7 +505,7 @@ class ParallelIterator(Generic[T]):
futures = {}
for _ in range(num_async):
for a in all_actors:
futures[a.par_iter_next.remote()] = a
futures[a.par_iter_next_batch.remote(batch_ms)] = a
while futures:
pending = list(futures)
if timeout is None:
@@ -499,8 +522,11 @@ class ParallelIterator(Generic[T]):
actor = futures.pop(obj_id)
try:
local_iter.shared_metrics.get().current_actor = actor
yield ray.get(obj_id)
futures[actor.par_iter_next.remote()] = actor
batch = ray.get(obj_id)
futures[actor.par_iter_next_batch.remote(
batch_ms)] = actor
for item in batch:
yield item
except StopIteration:
pass
# Always yield after each round of wait with timeout.
@@ -522,7 +548,7 @@ class ParallelIterator(Generic[T]):
def union(self, other: "ParallelIterator[T]") -> "ParallelIterator[T]":
"""Return an iterator that is the union of this and the other."""
if not isinstance(other, ParallelIterator):
raise ValueError(
raise TypeError(
"other must be of type ParallelIterator, got {}".format(
type(other)))
actor_sets = []
@@ -566,12 +592,29 @@ class ParallelIterator(Generic[T]):
"""Return the list of all shards."""
return [self.get_shard(i) for i in range(self.num_shards())]
def get_shard(self, shard_index: int) -> "LocalIterator[T]":
def get_shard(self,
shard_index: int,
batch_ms: int = 0,
num_async: int = 1) -> "LocalIterator[T]":
"""Return a local iterator for the given shard.
The iterator is guaranteed to be serializable and can be passed to
remote tasks or actors.
Arguments:
shard_index (int): Index of the shard to gather.
batch_ms (int): Batches items for batch_ms milliseconds
before retrieving it.
Increasing batch_ms increases latency but improves throughput.
If this value is 0, then items are returned immediately.
num_async (int): The max number of requests in flight.
Increasing this improves the amount of pipeline
parallelism in the iterator.
"""
if num_async < 1:
raise ValueError("num async must be positive")
if batch_ms < 0:
raise ValueError("batch time must be positive")
a, t = None, None
i = shard_index
for actor_set in self.actor_sets:
@@ -586,10 +629,16 @@ class ParallelIterator(Generic[T]):
self.num_shards())
def base_iterator(timeout=None):
queue = collections.deque()
ray.get(a.par_iter_init.remote(t))
for _ in range(num_async):
queue.append(a.par_iter_next_batch.remote(batch_ms))
while True:
try:
yield ray.get(a.par_iter_next.remote(), timeout=timeout)
batch = ray.get(queue.popleft(), timeout=timeout)
queue.append(a.par_iter_next_batch.remote(batch_ms))
for item in batch:
yield item
# Always yield after each round of gets with timeout.
if timeout is not None:
yield _NextValueNotReady()
@@ -919,7 +968,10 @@ class LocalIterator(Generic[T]):
yield _NextValueNotReady()
else:
if len(queues[i]) == 0:
fill_next(timeout)
try:
fill_next(timeout)
except StopIteration:
return
yield queues[i].popleft()
return gen
@@ -1023,7 +1075,7 @@ class ParallelIteratorWorker(object):
Subclasses must call this init function.
Args:
item_generator (obj): A Python generator objects or lambda function
item_generator (obj): A Python iterable or lambda function
that produces a generator when called. We allow lambda
functions since the generator itself might not be serializable,
but a lambda that returns it can be.
@@ -1040,7 +1092,13 @@ class ParallelIteratorWorker(object):
def cycle():
while True:
it = make_iterator()
it = iter(make_iterator())
if it is item_generator:
raise ValueError(
"Cannot iterate over {} multiple times." +
"Please pass in the base iterable or" +
"lambda: {} instead.".format(
item_generator, item_generator))
for item in it:
yield item
@@ -1066,6 +1124,23 @@ class ParallelIteratorWorker(object):
assert self.local_it is not None, "must call par_iter_init()"
return next(self.local_it)
def par_iter_next_batch(self, batch_ms: int):
"""Batches par_iter_next."""
batch = []
if batch_ms == 0:
batch.append(self.par_iter_next())
return batch
t_end = time.time() + (0.001 * batch_ms)
while time.time() < t_end:
try:
batch.append(self.par_iter_next())
except StopIteration:
if len(batch) == 0:
raise StopIteration
else:
pass
return batch
def par_iter_slice(self, step: int, start: int):
"""Iterates in increments of step starting from start."""
assert self.local_it is not None, "must call par_iter_init()"
@@ -1089,6 +1164,23 @@ class ParallelIteratorWorker(object):
return self.next_ith_buffer[start].pop(0)
def par_iter_slice_batch(self, step: int, start: int, batch_ms: int):
"""Batches par_iter_slice."""
batch = []
if batch_ms == 0:
batch.append(self.par_iter_slice(step, start))
return batch
t_end = time.time() + (0.001 * batch_ms)
while time.time() < t_end:
try:
batch.append(self.par_iter_slice(step, start))
except StopIteration:
if len(batch) == 0:
raise StopIteration
else:
pass
return batch
def _randomized_int_cast(float_value):
base = int(float_value)