mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 22:53:20 +08:00
[rllib] Support training intensity for dqn / apex (#8396)
This commit is contained in:
+50
-16
@@ -941,13 +941,22 @@ class LocalIterator(Generic[T]):
|
||||
|
||||
return iterators
|
||||
|
||||
def union(self, *others: "LocalIterator[T]",
|
||||
deterministic: bool = False) -> "LocalIterator[T]":
|
||||
def union(self,
|
||||
*others: "LocalIterator[T]",
|
||||
deterministic: bool = False,
|
||||
round_robin_weights: List[float] = None) -> "LocalIterator[T]":
|
||||
"""Return an iterator that is the union of this and the others.
|
||||
|
||||
If deterministic=True, we alternate between reading from one iterator
|
||||
and the others. Otherwise we return items from iterators as they
|
||||
become ready.
|
||||
Args:
|
||||
deterministic (bool): If deterministic=True, we alternate between
|
||||
reading from one iterator and the others. Otherwise we return
|
||||
items from iterators as they become ready.
|
||||
round_robin_weights (list): List of weights to use for round robin
|
||||
mode. For example, [2, 1] will cause the iterator to pull twice
|
||||
as many items from the first iterator as the second.
|
||||
[2, 1, "*"] will cause as many items to be pulled as possible
|
||||
from the third iterator without blocking. This overrides the
|
||||
deterministic flag.
|
||||
"""
|
||||
|
||||
for it in others:
|
||||
@@ -956,32 +965,49 @@ class LocalIterator(Generic[T]):
|
||||
"other must be of type LocalIterator, got {}".format(
|
||||
type(it)))
|
||||
|
||||
timeout = None if deterministic else 0
|
||||
|
||||
active = []
|
||||
parent_iters = [self] + list(others)
|
||||
shared_metrics = SharedMetrics(
|
||||
parents=[p.shared_metrics for p in parent_iters])
|
||||
for it in parent_iters:
|
||||
|
||||
timeout = None if deterministic else 0
|
||||
if round_robin_weights:
|
||||
if len(round_robin_weights) != len(parent_iters):
|
||||
raise ValueError(
|
||||
"Length of round robin weights must equal number of "
|
||||
"iterators total.")
|
||||
timeouts = [0 if w == "*" else None for w in round_robin_weights]
|
||||
else:
|
||||
timeouts = [timeout] * len(parent_iters)
|
||||
round_robin_weights = [1] * len(parent_iters)
|
||||
|
||||
for i, it in enumerate(parent_iters):
|
||||
active.append(
|
||||
LocalIterator(
|
||||
it.base_iterator,
|
||||
shared_metrics,
|
||||
it.local_transforms,
|
||||
timeout=timeout))
|
||||
timeout=timeouts[i]))
|
||||
active = list(zip(round_robin_weights, active))
|
||||
|
||||
def build_union(timeout=None):
|
||||
while True:
|
||||
for it in list(active):
|
||||
for weight, it in list(active):
|
||||
if weight == "*":
|
||||
max_pull = 100 # TOOD(ekl) how to best bound this?
|
||||
else:
|
||||
max_pull = _randomized_int_cast(weight)
|
||||
try:
|
||||
item = next(it)
|
||||
if isinstance(item, _NextValueNotReady):
|
||||
if timeout is not None:
|
||||
for _ in range(max_pull):
|
||||
item = next(it)
|
||||
if isinstance(item, _NextValueNotReady):
|
||||
if timeout is not None:
|
||||
yield item
|
||||
break
|
||||
else:
|
||||
yield item
|
||||
else:
|
||||
yield item
|
||||
except StopIteration:
|
||||
active.remove(it)
|
||||
active.remove((weight, it))
|
||||
if not active:
|
||||
break
|
||||
|
||||
@@ -1071,6 +1097,14 @@ class ParallelIteratorWorker(object):
|
||||
return self.next_ith_buffer[start].pop(0)
|
||||
|
||||
|
||||
def _randomized_int_cast(float_value):
|
||||
base = int(float_value)
|
||||
remainder = float_value - base
|
||||
if random.random() < remainder:
|
||||
base += 1
|
||||
return base
|
||||
|
||||
|
||||
class _NextValueNotReady(Exception):
|
||||
"""Indicates that a local iterator has no value currently available.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user