mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 13:49:45 +08:00
[Parallel Iterators] Batching + Pipelining optimizations (#7931)
* batching + get_shard pipelining * duplicate fix * formatting * adding performance benchmark * minor changes * turn batching off by default
This commit is contained in:
@@ -329,12 +329,58 @@ def test_gather_async(ray_start_regular_shared):
|
||||
assert sorted(it) == [0, 1, 2, 3]
|
||||
|
||||
|
||||
def test_gather_async_queue(ray_start_regular_shared):
|
||||
def test_gather_async_optimized(ray_start_regular_shared):
|
||||
it = from_range(100)
|
||||
it = it.gather_async(num_async=4)
|
||||
it = it.gather_async(batch_ms=100, num_async=4)
|
||||
assert sorted(it) == list(range(100))
|
||||
|
||||
|
||||
def test_get_shard_optimized(ray_start_regular_shared):
|
||||
it = from_range(6, num_shards=3)
|
||||
shard1 = it.get_shard(shard_index=0, batch_ms=25, num_async=2)
|
||||
shard2 = it.get_shard(shard_index=1, batch_ms=15, num_async=3)
|
||||
shard3 = it.get_shard(shard_index=2, batch_ms=5, num_async=4)
|
||||
assert list(shard1) == [0, 1]
|
||||
assert list(shard2) == [2, 3]
|
||||
assert list(shard3) == [4, 5]
|
||||
|
||||
|
||||
# Tested on 5/13/20
|
||||
# Run on 2019 Macbook Pro with 8 cores, 16 threads
|
||||
# 14.52 sec
|
||||
# 14.64 sec
|
||||
# 0.935 sec
|
||||
# 0.515 sec
|
||||
"""
|
||||
def test_gather_async_optimized_benchmark(ray_start_regular_shared):
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
train, _ = tf.keras.datasets.fashion_mnist.load_data()
|
||||
images, labels = train
|
||||
num_bytes = images.nbytes / 1e6
|
||||
items = list(images)
|
||||
it = from_items(items, num_shards=4)
|
||||
it = it.for_each(lambda img: img/255)
|
||||
#local_it = it.gather_async(batch_ms=0, num_async=1)
|
||||
#local_it = it.gather_async(batch_ms=0, num_async=3)
|
||||
#local_it = it.gather_async(batch_ms=10, num_async=1)
|
||||
#local_it = it.gather_async(batch_ms=10, num_async=3)
|
||||
|
||||
# dummy iterations
|
||||
for i in range(20):
|
||||
record = next(local_it)
|
||||
|
||||
start_time = time.time()
|
||||
#print(start_time)
|
||||
count = 0
|
||||
for record in local_it:
|
||||
count += 1
|
||||
assert count == len(items) - 20
|
||||
end_time = time.time() - start_time
|
||||
print(end_time)
|
||||
"""
|
||||
|
||||
|
||||
def test_batch_across_shards(ray_start_regular_shared):
|
||||
it = from_iterators([[0, 1], [2, 3]])
|
||||
it = it.batch_across_shards()
|
||||
|
||||
+118
-26
@@ -2,6 +2,7 @@ from contextlib import contextmanager
|
||||
import collections
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
from typing import TypeVar, Generic, Iterable, List, Callable, Any
|
||||
|
||||
import ray
|
||||
@@ -64,21 +65,28 @@ def from_range(n: int, num_shards: int = 2,
|
||||
def from_iterators(generators: List[Iterable[T]],
|
||||
repeat: bool = False,
|
||||
name=None) -> "ParallelIterator[T]":
|
||||
"""Create a parallel iterator from a set of iterators.
|
||||
"""Create a parallel iterator from a list of iterables.
|
||||
An iterable can be a conatiner (list, str, tuple, set, etc.),
|
||||
a generator, or a custom class that implements __iter__ or __getitem__.
|
||||
|
||||
An actor will be created for each iterator.
|
||||
An actor will be created for each iterable.
|
||||
|
||||
Examples:
|
||||
>>> # Create using a list of generators.
|
||||
>>> from_iterators([range(100), range(100)])
|
||||
|
||||
>>> # Equivalent to the above.
|
||||
>>> from_iterators([lambda: range(100), lambda: range(100)])
|
||||
>>> # Certain generators are not serializable.
|
||||
>>> from_iterators([(x for x in range(100))])
|
||||
... TypeError: can't pickle generator objects
|
||||
|
||||
>>> # So use lambda functions instead.
|
||||
>>> # Lambda functions are serializable.
|
||||
>>> from_iterators([lambda: (x for x in range(100))])
|
||||
|
||||
Args:
|
||||
generators (list): A list of Python generator objects or lambda
|
||||
functions that produced a generator when called. We allow lambda
|
||||
functions since the generator itself might not be serializable,
|
||||
generators (list): A list of Python iterables or lambda
|
||||
functions that produce an iterable when called. We allow lambda
|
||||
functions since certain generators might not be serializable,
|
||||
but a lambda that returns it can be.
|
||||
repeat (bool): Whether to cycle over the iterators forever.
|
||||
name (str): Optional name to give the iterator.
|
||||
@@ -188,9 +196,9 @@ class ParallelIterator(Generic[T]):
|
||||
def for_each(self, fn: Callable[[T], U], max_concurrency=1,
|
||||
resources=None) -> "ParallelIterator[U]":
|
||||
"""Remotely apply fn to each item in this iterator, at most `max_concurrency`
|
||||
at a time.
|
||||
at a time per shard.
|
||||
|
||||
If `max_concurrency` == 1 then `fn` will be executed serially by the
|
||||
If `max_concurrency` == 1 then `fn` will be executed serially by each
|
||||
shards
|
||||
|
||||
`max_concurrency` should be used to achieve a high degree of
|
||||
@@ -200,7 +208,7 @@ class ParallelIterator(Generic[T]):
|
||||
necessarily finish first)
|
||||
|
||||
A performance note: When executing concurrently, this function
|
||||
maintains its own internal buffer. If `async_queue_depth` is `n` and
|
||||
maintains its own internal buffer. If `num_async` is `n` and
|
||||
max_concur is `k` then the total number of buffered objects could be up
|
||||
to `n + k - 1`
|
||||
|
||||
@@ -311,7 +319,8 @@ class ParallelIterator(Generic[T]):
|
||||
shuffle_buffer_size,
|
||||
str(seed) if seed is not None else "None"))
|
||||
|
||||
def repartition(self, num_partitions: int) -> "ParallelIterator[T]":
|
||||
def repartition(self, num_partitions: int,
|
||||
batch_ms: int = 0) -> "ParallelIterator[T]":
|
||||
"""Returns a new ParallelIterator instance with num_partitions shards.
|
||||
|
||||
The new iterator contains the same data in this instance except with
|
||||
@@ -321,6 +330,9 @@ class ParallelIterator(Generic[T]):
|
||||
Args:
|
||||
num_partitions (int): The number of shards to use for the new
|
||||
ParallelIterator
|
||||
batch_ms (int): Batches items for batch_ms milliseconds
|
||||
on each shard before retrieving it.
|
||||
Increasing batch_ms increases latency but improves throughput.
|
||||
|
||||
Returns:
|
||||
A ParallelIterator with num_partitions number of shards and the
|
||||
@@ -347,8 +359,10 @@ class ParallelIterator(Generic[T]):
|
||||
def base_iterator(num_partitions, partition_index, timeout=None):
|
||||
futures = {}
|
||||
for a in all_actors:
|
||||
futures[a.par_iter_slice.remote(
|
||||
step=num_partitions, start=partition_index)] = a
|
||||
futures[a.par_iter_slice_batch.remote(
|
||||
step=num_partitions,
|
||||
start=partition_index,
|
||||
batch_ms=batch_ms)] = a
|
||||
while futures:
|
||||
pending = list(futures)
|
||||
if timeout is None:
|
||||
@@ -364,10 +378,13 @@ class ParallelIterator(Generic[T]):
|
||||
for obj_id in ready:
|
||||
actor = futures.pop(obj_id)
|
||||
try:
|
||||
yield ray.get(obj_id)
|
||||
futures[actor.par_iter_slice.remote(
|
||||
batch = ray.get(obj_id)
|
||||
futures[actor.par_iter_slice_batch.remote(
|
||||
step=num_partitions,
|
||||
start=partition_index)] = actor
|
||||
start=partition_index,
|
||||
batch_ms=batch_ms)] = actor
|
||||
for item in batch:
|
||||
yield item
|
||||
except StopIteration:
|
||||
pass
|
||||
# Always yield after each round of wait with timeout.
|
||||
@@ -447,13 +464,17 @@ class ParallelIterator(Generic[T]):
|
||||
name = "{}.batch_across_shards()".format(self)
|
||||
return LocalIterator(base_iterator, SharedMetrics(), name=name)
|
||||
|
||||
def gather_async(self, num_async=1) -> "LocalIterator[T]":
|
||||
def gather_async(self, batch_ms=0, num_async=1) -> "LocalIterator[T]":
|
||||
"""Returns a local iterable for asynchronous iteration.
|
||||
|
||||
New items will be fetched from the shards asynchronously as soon as
|
||||
the previous one is computed. Items arrive in non-deterministic order.
|
||||
|
||||
Arguments:
|
||||
batch_ms (int): Batches items for batch_ms milliseconds
|
||||
on each shard before retrieving it.
|
||||
Increasing batch_ms increases latency but improves throughput.
|
||||
If this value is 0, then items are returned immediately.
|
||||
num_async (int): The max number of async requests in flight
|
||||
per actor. Increasing this improves the amount of pipeline
|
||||
parallelism in the iterator.
|
||||
@@ -470,6 +491,8 @@ class ParallelIterator(Generic[T]):
|
||||
|
||||
if num_async < 1:
|
||||
raise ValueError("queue depth must be positive")
|
||||
if batch_ms < 0:
|
||||
raise ValueError("batch time must be positive")
|
||||
|
||||
# Forward reference to the returned iterator.
|
||||
local_iter = None
|
||||
@@ -482,7 +505,7 @@ class ParallelIterator(Generic[T]):
|
||||
futures = {}
|
||||
for _ in range(num_async):
|
||||
for a in all_actors:
|
||||
futures[a.par_iter_next.remote()] = a
|
||||
futures[a.par_iter_next_batch.remote(batch_ms)] = a
|
||||
while futures:
|
||||
pending = list(futures)
|
||||
if timeout is None:
|
||||
@@ -499,8 +522,11 @@ class ParallelIterator(Generic[T]):
|
||||
actor = futures.pop(obj_id)
|
||||
try:
|
||||
local_iter.shared_metrics.get().current_actor = actor
|
||||
yield ray.get(obj_id)
|
||||
futures[actor.par_iter_next.remote()] = actor
|
||||
batch = ray.get(obj_id)
|
||||
futures[actor.par_iter_next_batch.remote(
|
||||
batch_ms)] = actor
|
||||
for item in batch:
|
||||
yield item
|
||||
except StopIteration:
|
||||
pass
|
||||
# Always yield after each round of wait with timeout.
|
||||
@@ -522,7 +548,7 @@ class ParallelIterator(Generic[T]):
|
||||
def union(self, other: "ParallelIterator[T]") -> "ParallelIterator[T]":
|
||||
"""Return an iterator that is the union of this and the other."""
|
||||
if not isinstance(other, ParallelIterator):
|
||||
raise ValueError(
|
||||
raise TypeError(
|
||||
"other must be of type ParallelIterator, got {}".format(
|
||||
type(other)))
|
||||
actor_sets = []
|
||||
@@ -566,12 +592,29 @@ class ParallelIterator(Generic[T]):
|
||||
"""Return the list of all shards."""
|
||||
return [self.get_shard(i) for i in range(self.num_shards())]
|
||||
|
||||
def get_shard(self, shard_index: int) -> "LocalIterator[T]":
|
||||
def get_shard(self,
|
||||
shard_index: int,
|
||||
batch_ms: int = 0,
|
||||
num_async: int = 1) -> "LocalIterator[T]":
|
||||
"""Return a local iterator for the given shard.
|
||||
|
||||
The iterator is guaranteed to be serializable and can be passed to
|
||||
remote tasks or actors.
|
||||
|
||||
Arguments:
|
||||
shard_index (int): Index of the shard to gather.
|
||||
batch_ms (int): Batches items for batch_ms milliseconds
|
||||
before retrieving it.
|
||||
Increasing batch_ms increases latency but improves throughput.
|
||||
If this value is 0, then items are returned immediately.
|
||||
num_async (int): The max number of requests in flight.
|
||||
Increasing this improves the amount of pipeline
|
||||
parallelism in the iterator.
|
||||
"""
|
||||
if num_async < 1:
|
||||
raise ValueError("num async must be positive")
|
||||
if batch_ms < 0:
|
||||
raise ValueError("batch time must be positive")
|
||||
a, t = None, None
|
||||
i = shard_index
|
||||
for actor_set in self.actor_sets:
|
||||
@@ -586,10 +629,16 @@ class ParallelIterator(Generic[T]):
|
||||
self.num_shards())
|
||||
|
||||
def base_iterator(timeout=None):
|
||||
queue = collections.deque()
|
||||
ray.get(a.par_iter_init.remote(t))
|
||||
for _ in range(num_async):
|
||||
queue.append(a.par_iter_next_batch.remote(batch_ms))
|
||||
while True:
|
||||
try:
|
||||
yield ray.get(a.par_iter_next.remote(), timeout=timeout)
|
||||
batch = ray.get(queue.popleft(), timeout=timeout)
|
||||
queue.append(a.par_iter_next_batch.remote(batch_ms))
|
||||
for item in batch:
|
||||
yield item
|
||||
# Always yield after each round of gets with timeout.
|
||||
if timeout is not None:
|
||||
yield _NextValueNotReady()
|
||||
@@ -919,7 +968,10 @@ class LocalIterator(Generic[T]):
|
||||
yield _NextValueNotReady()
|
||||
else:
|
||||
if len(queues[i]) == 0:
|
||||
fill_next(timeout)
|
||||
try:
|
||||
fill_next(timeout)
|
||||
except StopIteration:
|
||||
return
|
||||
yield queues[i].popleft()
|
||||
|
||||
return gen
|
||||
@@ -1023,7 +1075,7 @@ class ParallelIteratorWorker(object):
|
||||
Subclasses must call this init function.
|
||||
|
||||
Args:
|
||||
item_generator (obj): A Python generator objects or lambda function
|
||||
item_generator (obj): A Python iterable or lambda function
|
||||
that produces a generator when called. We allow lambda
|
||||
functions since the generator itself might not be serializable,
|
||||
but a lambda that returns it can be.
|
||||
@@ -1040,7 +1092,13 @@ class ParallelIteratorWorker(object):
|
||||
|
||||
def cycle():
|
||||
while True:
|
||||
it = make_iterator()
|
||||
it = iter(make_iterator())
|
||||
if it is item_generator:
|
||||
raise ValueError(
|
||||
"Cannot iterate over {} multiple times." +
|
||||
"Please pass in the base iterable or" +
|
||||
"lambda: {} instead.".format(
|
||||
item_generator, item_generator))
|
||||
for item in it:
|
||||
yield item
|
||||
|
||||
@@ -1066,6 +1124,23 @@ class ParallelIteratorWorker(object):
|
||||
assert self.local_it is not None, "must call par_iter_init()"
|
||||
return next(self.local_it)
|
||||
|
||||
def par_iter_next_batch(self, batch_ms: int):
|
||||
"""Batches par_iter_next."""
|
||||
batch = []
|
||||
if batch_ms == 0:
|
||||
batch.append(self.par_iter_next())
|
||||
return batch
|
||||
t_end = time.time() + (0.001 * batch_ms)
|
||||
while time.time() < t_end:
|
||||
try:
|
||||
batch.append(self.par_iter_next())
|
||||
except StopIteration:
|
||||
if len(batch) == 0:
|
||||
raise StopIteration
|
||||
else:
|
||||
pass
|
||||
return batch
|
||||
|
||||
def par_iter_slice(self, step: int, start: int):
|
||||
"""Iterates in increments of step starting from start."""
|
||||
assert self.local_it is not None, "must call par_iter_init()"
|
||||
@@ -1089,6 +1164,23 @@ class ParallelIteratorWorker(object):
|
||||
|
||||
return self.next_ith_buffer[start].pop(0)
|
||||
|
||||
def par_iter_slice_batch(self, step: int, start: int, batch_ms: int):
|
||||
"""Batches par_iter_slice."""
|
||||
batch = []
|
||||
if batch_ms == 0:
|
||||
batch.append(self.par_iter_slice(step, start))
|
||||
return batch
|
||||
t_end = time.time() + (0.001 * batch_ms)
|
||||
while time.time() < t_end:
|
||||
try:
|
||||
batch.append(self.par_iter_slice(step, start))
|
||||
except StopIteration:
|
||||
if len(batch) == 0:
|
||||
raise StopIteration
|
||||
else:
|
||||
pass
|
||||
return batch
|
||||
|
||||
|
||||
def _randomized_int_cast(float_value):
|
||||
base = int(float_value)
|
||||
|
||||
Reference in New Issue
Block a user