mirror of
https://github.com/wassname/ray.git
synced 2026-07-04 14:48:25 +08:00
[rllib] First pass at pipeline implementation of DQN (#7433)
* wip iters * add test * speed up * update docs * document it * support serial sampling * add test * spacing * annotate it * update * rename to pipeline * comment * iter2 wip * update * update * context test * update * fix * fix * a3c pipeline * doc * update * move timer * comment * add pipeline test * fix * clean up * document * iter s * wip dqn * wip * wip * metrics * metrics rename * metrics ctx * wip * constants * add todo * support .union * wip * support union * remove prints * add todo * remove auto timer * fix up * fix pipeline test * typing * fix breakage * remove bad assert * wip * fix multiagent example * fix apply * update a3c * remove a2c pl * 0 workers * wip * wip * share metrics * wip * wip * doc * fix weight sync and global var updates * mode * fix * fix * doc * fix
This commit is contained in:
+20
-25
@@ -776,36 +776,35 @@ class LocalIterator(Generic[T]):
|
||||
if i >= n:
|
||||
break
|
||||
|
||||
def union(self, other: "LocalIterator[T]",
|
||||
def union(self, *others: "LocalIterator[T]",
|
||||
deterministic: bool = False) -> "LocalIterator[T]":
|
||||
"""Return an iterator that is the union of this and the other.
|
||||
"""Return an iterator that is the union of this and the others.
|
||||
|
||||
If deterministic=True, we alternate between reading from one iterator
|
||||
and the other. Otherwise we return items from iterators as they
|
||||
and the others. Otherwise we return items from iterators as they
|
||||
become ready.
|
||||
"""
|
||||
|
||||
if not isinstance(other, LocalIterator):
|
||||
raise ValueError(
|
||||
"other must be of type LocalIterator, got {}".format(
|
||||
type(other)))
|
||||
for it in others:
|
||||
if not isinstance(it, LocalIterator):
|
||||
raise ValueError(
|
||||
"other must be of type LocalIterator, got {}".format(
|
||||
type(it)))
|
||||
|
||||
if deterministic:
|
||||
timeout = None
|
||||
else:
|
||||
timeout = 0
|
||||
|
||||
it1 = LocalIterator(
|
||||
self.base_iterator,
|
||||
self.metrics,
|
||||
self.local_transforms,
|
||||
timeout=timeout)
|
||||
it2 = LocalIterator(
|
||||
other.base_iterator,
|
||||
other.metrics,
|
||||
other.local_transforms,
|
||||
timeout=timeout)
|
||||
active = [it1, it2]
|
||||
active = []
|
||||
shared_metrics = MetricsContext()
|
||||
for it in [self] + list(others):
|
||||
active.append(
|
||||
LocalIterator(
|
||||
it.base_iterator,
|
||||
shared_metrics,
|
||||
it.local_transforms,
|
||||
timeout=timeout))
|
||||
|
||||
def build_union(timeout=None):
|
||||
while True:
|
||||
@@ -826,15 +825,11 @@ class LocalIterator(Generic[T]):
|
||||
if not active:
|
||||
break
|
||||
|
||||
# TODO(ekl) is this the best way to represent union() of metrics?
|
||||
new_ctx = MetricsContext()
|
||||
new_ctx.parent_metrics.append(self.metrics)
|
||||
new_ctx.parent_metrics.append(other.metrics)
|
||||
|
||||
return LocalIterator(
|
||||
build_union,
|
||||
new_ctx, [],
|
||||
name="LocalUnion[{}, {}]".format(self, other))
|
||||
shared_metrics, [],
|
||||
name="LocalUnion[{}, {}]".format(self, ", ".join(map(str,
|
||||
others))))
|
||||
|
||||
|
||||
class ParallelIteratorWorker(object):
|
||||
|
||||
Reference in New Issue
Block a user