[rllib] Support training intensity for dqn / apex (#8396)

This commit is contained in:
Eric Liang
2020-05-20 11:22:30 -07:00
committed by GitHub
parent f56b3be916
commit aa7a58e92f
8 changed files with 134 additions and 25 deletions
+50 -16
View File
@@ -941,13 +941,22 @@ class LocalIterator(Generic[T]):
return iterators
def union(self, *others: "LocalIterator[T]",
deterministic: bool = False) -> "LocalIterator[T]":
def union(self,
*others: "LocalIterator[T]",
deterministic: bool = False,
round_robin_weights: List[float] = None) -> "LocalIterator[T]":
"""Return an iterator that is the union of this and the others.
If deterministic=True, we alternate between reading from one iterator
and the others. Otherwise we return items from iterators as they
become ready.
Args:
deterministic (bool): If deterministic=True, we alternate between
reading from one iterator and the others. Otherwise we return
items from iterators as they become ready.
round_robin_weights (list): List of weights to use for round robin
mode. For example, [2, 1] will cause the iterator to pull twice
as many items from the first iterator as the second.
[2, 1, "*"] will cause as many items to be pulled as possible
from the third iterator without blocking. This overrides the
deterministic flag.
"""
for it in others:
@@ -956,32 +965,49 @@ class LocalIterator(Generic[T]):
"other must be of type LocalIterator, got {}".format(
type(it)))
timeout = None if deterministic else 0
active = []
parent_iters = [self] + list(others)
shared_metrics = SharedMetrics(
parents=[p.shared_metrics for p in parent_iters])
for it in parent_iters:
timeout = None if deterministic else 0
if round_robin_weights:
if len(round_robin_weights) != len(parent_iters):
raise ValueError(
"Length of round robin weights must equal number of "
"iterators total.")
timeouts = [0 if w == "*" else None for w in round_robin_weights]
else:
timeouts = [timeout] * len(parent_iters)
round_robin_weights = [1] * len(parent_iters)
for i, it in enumerate(parent_iters):
active.append(
LocalIterator(
it.base_iterator,
shared_metrics,
it.local_transforms,
timeout=timeout))
timeout=timeouts[i]))
active = list(zip(round_robin_weights, active))
def build_union(timeout=None):
while True:
for it in list(active):
for weight, it in list(active):
if weight == "*":
max_pull = 100 # TODO(ekl) how to best bound this?
else:
max_pull = _randomized_int_cast(weight)
try:
item = next(it)
if isinstance(item, _NextValueNotReady):
if timeout is not None:
for _ in range(max_pull):
item = next(it)
if isinstance(item, _NextValueNotReady):
if timeout is not None:
yield item
break
else:
yield item
else:
yield item
except StopIteration:
active.remove(it)
active.remove((weight, it))
if not active:
break
@@ -1071,6 +1097,14 @@ class ParallelIteratorWorker(object):
return self.next_ith_buffer[start].pop(0)
def _randomized_int_cast(float_value):
    """Stochastically round ``float_value`` to an integer.

    The value is first truncated with ``int()``; the result is then bumped
    up by one with probability equal to the part that truncation removed.
    When that leftover part is zero or negative (integer inputs, or negative
    non-integers), the truncated value is always returned.

    Args:
        float_value: Number to convert.

    Returns:
        int: The truncated value, or the truncated value plus one.
    """
    whole = int(float_value)
    fraction = float_value - whole
    # Exactly one random draw per call, regardless of the fraction.
    bump = 1 if random.random() < fraction else 0
    return whole + bump
class _NextValueNotReady(Exception):
"""Indicates that a local iterator has no value currently available.