mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 15:22:56 +08:00
[rllib] First pass at pipeline implementation of DQN (#7433)
* wip iters * add test * speed up * update docs * document it * support serial sampling * add test * spacing * annotate it * update * rename to pipeline * comment * iter2 wip * update * update * context test * update * fix * fix * a3c pipeline * doc * update * move timer * comment * add pipeline test * fix * clean up * document * iter s * wip dqn * wip * wip * metrics * metrics rename * metrics ctx * wip * constants * add todo * support .union * wip * support union * remove prints * add todo * remove auto timer * fix up * fix pipeline test * typing * fix breakage * remove bad assert * wip * fix multiagent example * fix apply * update a3c * remove a2c pl * 0 workers * wip * wip * share metrics * wip * wip * doc * fix weight sync and global var updates * mode * fix * fix * doc * fix
This commit is contained in:
@@ -47,14 +47,14 @@ class TuneReporterBase(ProgressReporter):
|
||||
"""Abstract base class for the default Tune reporters."""
|
||||
|
||||
# Truncated representations of column names (to accommodate small screens).
|
||||
DEFAULT_COLUMNS = {
|
||||
EPISODE_REWARD_MEAN: "reward",
|
||||
DEFAULT_COLUMNS = collections.OrderedDict({
|
||||
MEAN_ACCURACY: "acc",
|
||||
MEAN_LOSS: "loss",
|
||||
TRAINING_ITERATION: "iter",
|
||||
TIME_TOTAL_S: "total time (s)",
|
||||
TIMESTEPS_TOTAL: "ts",
|
||||
TRAINING_ITERATION: "iter",
|
||||
}
|
||||
EPISODE_REWARD_MEAN: "reward",
|
||||
})
|
||||
|
||||
def __init__(self,
|
||||
metric_columns=None,
|
||||
@@ -301,7 +301,6 @@ def trial_progress_str(trials, metric_columns, fmt="psql", max_rows=None):
|
||||
k for k in keys if any(
|
||||
t.last_result.get(k) is not None for t in trials)
|
||||
]
|
||||
keys = sorted(keys)
|
||||
# Build trial rows.
|
||||
params = sorted(set().union(*[t.evaluated_params for t in trials]))
|
||||
trial_table = [_get_trial_info(trial, params, keys) for trial in trials]
|
||||
|
||||
+20
-25
@@ -776,36 +776,35 @@ class LocalIterator(Generic[T]):
|
||||
if i >= n:
|
||||
break
|
||||
|
||||
def union(self, other: "LocalIterator[T]",
|
||||
def union(self, *others: "LocalIterator[T]",
|
||||
deterministic: bool = False) -> "LocalIterator[T]":
|
||||
"""Return an iterator that is the union of this and the other.
|
||||
"""Return an iterator that is the union of this and the others.
|
||||
|
||||
If deterministic=True, we alternate between reading from one iterator
|
||||
and the other. Otherwise we return items from iterators as they
|
||||
and the others. Otherwise we return items from iterators as they
|
||||
become ready.
|
||||
"""
|
||||
|
||||
if not isinstance(other, LocalIterator):
|
||||
raise ValueError(
|
||||
"other must be of type LocalIterator, got {}".format(
|
||||
type(other)))
|
||||
for it in others:
|
||||
if not isinstance(it, LocalIterator):
|
||||
raise ValueError(
|
||||
"other must be of type LocalIterator, got {}".format(
|
||||
type(it)))
|
||||
|
||||
if deterministic:
|
||||
timeout = None
|
||||
else:
|
||||
timeout = 0
|
||||
|
||||
it1 = LocalIterator(
|
||||
self.base_iterator,
|
||||
self.metrics,
|
||||
self.local_transforms,
|
||||
timeout=timeout)
|
||||
it2 = LocalIterator(
|
||||
other.base_iterator,
|
||||
other.metrics,
|
||||
other.local_transforms,
|
||||
timeout=timeout)
|
||||
active = [it1, it2]
|
||||
active = []
|
||||
shared_metrics = MetricsContext()
|
||||
for it in [self] + list(others):
|
||||
active.append(
|
||||
LocalIterator(
|
||||
it.base_iterator,
|
||||
shared_metrics,
|
||||
it.local_transforms,
|
||||
timeout=timeout))
|
||||
|
||||
def build_union(timeout=None):
|
||||
while True:
|
||||
@@ -826,15 +825,11 @@ class LocalIterator(Generic[T]):
|
||||
if not active:
|
||||
break
|
||||
|
||||
# TODO(ekl) is this the best way to represent union() of metrics?
|
||||
new_ctx = MetricsContext()
|
||||
new_ctx.parent_metrics.append(self.metrics)
|
||||
new_ctx.parent_metrics.append(other.metrics)
|
||||
|
||||
return LocalIterator(
|
||||
build_union,
|
||||
new_ctx, [],
|
||||
name="LocalUnion[{}, {}]".format(self, other))
|
||||
shared_metrics, [],
|
||||
name="LocalUnion[{}, {}]".format(self, ", ".join(map(str,
|
||||
others))))
|
||||
|
||||
|
||||
class ParallelIteratorWorker(object):
|
||||
|
||||
Reference in New Issue
Block a user