mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 23:11:40 +08:00
95 lines
3.2 KiB
Python
95 lines
3.2 KiB
Python
from typing import List
|
|
import queue
|
|
|
|
from ray.util.iter import LocalIterator, _NextValueNotReady
|
|
from ray.util.iter_metrics import SharedMetrics
|
|
|
|
|
|
def Concurrently(ops: List[LocalIterator], *, mode="round_robin"):
    """Run several parent iterators concurrently and merge their outputs.

    Arguments:
        ops (list): the parent iterators to combine; at least two are
            required.
        mode (str): how items are pulled from the parents, one of
            {'round_robin', 'async'}.
            - 'round_robin': alternate between the parents in order, so the
              resulting sequence is deterministic.
            - 'async': take items from the parents as fast as they are
              produced; the resulting order is non-deterministic.

    Returns:
        The result of ``ops[0].union(*ops[1:], deterministic=...)``, where
        ``deterministic`` is True for 'round_robin' and False for 'async'.

    Raises:
        ValueError: if fewer than two ops are given or ``mode`` is unknown.

    Examples:
        >>> sim_op = ParallelRollouts(...).for_each(...)
        >>> replay_op = LocalReplay(...).for_each(...)
        >>> combined_op = Concurrently([sim_op, replay_op], mode="async")
    """
    if len(ops) < 2:
        raise ValueError("Should specify at least 2 ops.")

    if mode not in ("round_robin", "async"):
        raise ValueError("Unknown mode {}".format(mode))

    # Only round-robin interleaving yields a reproducible ordering.
    is_deterministic = mode == "round_robin"

    head, tail = ops[0], ops[1:]
    return head.union(*tail, deterministic=is_deterministic)
|
|
|
|
|
|
class Enqueue:
    """Enqueue data items into a queue.Queue instance.

    Each call returns the input item unchanged once it has been enqueued, so
    this operator can sit in a ``for_each`` chain without replacing the
    stream's items with ``None``.

    The enqueue is non-blocking, so Enqueue operations can be executed
    together with Dequeue via the Concurrently() operator.

    Examples:
        >>> q = queue.Queue(100)
        >>> write_op = ParallelRollouts(...).for_each(Enqueue(q))
        >>> read_op = Dequeue(q)
        >>> combined_op = Concurrently([write_op, read_op], mode="async")
        >>> next(combined_op)
        SampleBatch(...)
    """

    def __init__(self, output_queue: queue.Queue):
        """Bind the operator to the queue that items are pushed onto.

        Arguments:
            output_queue (queue.Queue): destination queue.

        Raises:
            ValueError: if ``output_queue`` is not a ``queue.Queue``.
        """
        if not isinstance(output_queue, queue.Queue):
            raise ValueError("Expected queue.Queue, got {}".format(
                type(output_queue)))
        self.queue = output_queue

    def __call__(self, x):
        """Try to push ``x`` onto the queue without blocking.

        Returns:
            ``x`` itself if it was enqueued. If the queue is full, ``x`` is
            NOT enqueued and a ``_NextValueNotReady`` marker is returned
            instead. The caller is presumably expected to retry with the
            same item (confirm against ``LocalIterator.for_each``).
        """
        try:
            self.queue.put_nowait(x)
        except queue.Full:
            return _NextValueNotReady()
        # Pass the item through so downstream ops see real data, not None.
        return x
|
|
|
|
|
|
def Dequeue(input_queue: queue.Queue, check=lambda: True):
    """Dequeue data items from a queue.Queue instance.

    The dequeue is non-blocking, so Dequeue operations can be executed
    together with Enqueue via the Concurrently() operator.

    Arguments:
        input_queue (Queue): queue to pull items from.
        check (fn): liveness check, evaluated before every pull. When this
            function returns false, the iterator raises RuntimeError to
            halt execution.

    Returns:
        A LocalIterator that yields queued items, or a ``_NextValueNotReady``
        marker whenever the queue is currently empty.

    Raises:
        ValueError: if ``input_queue`` is not a ``queue.Queue``.

    Examples:
        >>> q = queue.Queue(100)
        >>> write_op = ParallelRollouts(...).for_each(Enqueue(q))
        >>> read_op = Dequeue(q)
        >>> combined_op = Concurrently([write_op, read_op], mode="async")
        >>> next(combined_op)
        SampleBatch(...)
    """
    if not isinstance(input_queue, queue.Queue):
        raise ValueError("Expected queue.Queue, got {}".format(
            type(input_queue)))

    def base_iterator(timeout=None):
        # `timeout` is unused here. It is kept in the signature presumably
        # because LocalIterator passes one when building the iterator;
        # confirm before removing.
        while check():
            try:
                item = input_queue.get_nowait()
            except queue.Empty:
                # Nothing queued yet: signal "not ready" instead of blocking.
                yield _NextValueNotReady()
            else:
                # Yield outside the try so an exception thrown into the
                # generator here is not mistaken for queue.Empty.
                yield item
        # The loop only exits when the liveness check fails.
        raise RuntimeError(
            "Dequeue check() returned False; halting the Dequeue iterator.")

    return LocalIterator(base_iterator, SharedMetrics())
|