[rllib] First pass at pipeline implementation of DQN (#7433)

* wip iters

* add test

* speed up

* update docs

* document it

* support serial sampling

* add test

* spacing

* annotate it

* update

* rename to pipeline

* comment

* iter2 wip

* update

* update

* context test

* update

* fix

* fix

* a3c pipeline

* doc

* update

* move timer

* comment

* add piepline test

* fix

* clean up

* document

* iter s

* wip dqn

* wip

* wip

* metrics

* metrics rename

* metrics ctx

* wip

* constants

* add todo

* suppport .union

* wip

* support union

* remove prints

* add todo

* remove auto timer

* fix up

* fix pipeline test

* typing

* fix breakage

* remove bad assert

* wip

* fix multiagent example

* fixapply

* update a3c

* remove a2c pl

* 0 workers

* wip

* wip

* share metrics

* wip

* wip

* doc

* fix weight sync and global var updates

* mode

* fix

* fix

* doc

* fix
This commit is contained in:
Eric Liang
2020-03-07 14:47:58 -08:00
committed by GitHub
parent beb9b02dbd
commit a644060daa
8 changed files with 258 additions and 74 deletions
+4 -5
View File
@@ -47,14 +47,14 @@ class TuneReporterBase(ProgressReporter):
"""Abstract base class for the default Tune reporters."""
# Truncated representations of column names (to accommodate small screens).
DEFAULT_COLUMNS = {
EPISODE_REWARD_MEAN: "reward",
DEFAULT_COLUMNS = collections.OrderedDict({
MEAN_ACCURACY: "acc",
MEAN_LOSS: "loss",
TRAINING_ITERATION: "iter",
TIME_TOTAL_S: "total time (s)",
TIMESTEPS_TOTAL: "ts",
TRAINING_ITERATION: "iter",
}
EPISODE_REWARD_MEAN: "reward",
})
def __init__(self,
metric_columns=None,
@@ -301,7 +301,6 @@ def trial_progress_str(trials, metric_columns, fmt="psql", max_rows=None):
k for k in keys if any(
t.last_result.get(k) is not None for t in trials)
]
keys = sorted(keys)
# Build trial rows.
params = sorted(set().union(*[t.evaluated_params for t in trials]))
trial_table = [_get_trial_info(trial, params, keys) for trial in trials]
+20 -25
View File
@@ -776,36 +776,35 @@ class LocalIterator(Generic[T]):
if i >= n:
break
def union(self, other: "LocalIterator[T]",
def union(self, *others: "LocalIterator[T]",
deterministic: bool = False) -> "LocalIterator[T]":
"""Return an iterator that is the union of this and the other.
"""Return an iterator that is the union of this and the others.
If deterministic=True, we alternate between reading from one iterator
and the other. Otherwise we return items from iterators as they
and the others. Otherwise we return items from iterators as they
become ready.
"""
if not isinstance(other, LocalIterator):
raise ValueError(
"other must be of type LocalIterator, got {}".format(
type(other)))
for it in others:
if not isinstance(it, LocalIterator):
raise ValueError(
"other must be of type LocalIterator, got {}".format(
type(it)))
if deterministic:
timeout = None
else:
timeout = 0
it1 = LocalIterator(
self.base_iterator,
self.metrics,
self.local_transforms,
timeout=timeout)
it2 = LocalIterator(
other.base_iterator,
other.metrics,
other.local_transforms,
timeout=timeout)
active = [it1, it2]
active = []
shared_metrics = MetricsContext()
for it in [self] + list(others):
active.append(
LocalIterator(
it.base_iterator,
shared_metrics,
it.local_transforms,
timeout=timeout))
def build_union(timeout=None):
while True:
@@ -826,15 +825,11 @@ class LocalIterator(Generic[T]):
if not active:
break
# TODO(ekl) is this the best way to represent union() of metrics?
new_ctx = MetricsContext()
new_ctx.parent_metrics.append(self.metrics)
new_ctx.parent_metrics.append(other.metrics)
return LocalIterator(
build_union,
new_ctx, [],
name="LocalUnion[{}, {}]".format(self, other))
shared_metrics, [],
name="LocalUnion[{}, {}]".format(self, ", ".join(map(str,
others))))
class ParallelIteratorWorker(object):