Removed the implicit sync barrier at the end of each training iteration (#5217)

*  removed sync barrier at the end of each training iteration

*  formatted

*  modify the comment according to current semantics

*  lint check

* Update trainer.py
This commit is contained in:
Jones Wong
2019-07-19 13:59:52 +08:00
committed by Eric Liang
parent 28e5c5555d
commit da7676c925
3 changed files with 27 additions and 24 deletions
+2 -1
View File
@@ -174,7 +174,8 @@ COMMON_CONFIG = {
},
# Whether to LZ4 compress individual observations
"compress_observations": False,
# Drop metric batches from unresponsive workers after this many seconds
# Wait for metric batches for at most this many seconds. Those that
# have not returned in time will be collected in the next iteration.
"collect_metrics_timeout": 180,
# Smooth metrics over this many episodes.
"metrics_smoothing_episodes": 100,
+21 -21
View File
@@ -40,58 +40,59 @@ def get_learner_stats(grad_info):
@DeveloperAPI
def collect_metrics(local_worker=None, remote_workers=[], timeout_seconds=180):
def collect_metrics(local_worker=None,
remote_workers=[],
to_be_collected=[],
timeout_seconds=180):
"""Gathers episode metrics from RolloutWorker instances."""
episodes, num_dropped = collect_episodes(
local_worker, remote_workers, timeout_seconds=timeout_seconds)
metrics = summarize_episodes(episodes, episodes, num_dropped)
episodes, to_be_collected = collect_episodes(
local_worker,
remote_workers,
to_be_collected,
timeout_seconds=timeout_seconds)
metrics = summarize_episodes(episodes, episodes)
return metrics
@DeveloperAPI
def collect_episodes(local_worker=None, remote_workers=[],
def collect_episodes(local_worker=None,
remote_workers=[],
to_be_collected=[],
timeout_seconds=180):
"""Gathers new episode metrics tuples from the given workers."""
if remote_workers:
pending = [
a.apply.remote(lambda ev: ev.get_metrics()) for a in remote_workers
]
collected, _ = ray.wait(
] + to_be_collected
collected, to_be_collected = ray.wait(
pending, num_returns=len(pending), timeout=timeout_seconds * 1.0)
num_metric_batches_dropped = len(pending) - len(collected)
if pending and len(collected) == 0:
raise ValueError(
"Timed out waiting for metrics from workers. You can "
"configure this timeout with `collect_metrics_timeout`.")
logger.warning(
"WARNING: collected no metrics in {} seconds".format(
timeout_seconds))
metric_lists = ray_get_and_free(collected)
else:
metric_lists = []
num_metric_batches_dropped = 0
if local_worker:
metric_lists.append(local_worker.get_metrics())
episodes = []
for metrics in metric_lists:
episodes.extend(metrics)
return episodes, num_metric_batches_dropped
return episodes, to_be_collected
@DeveloperAPI
def summarize_episodes(episodes, new_episodes, num_dropped):
def summarize_episodes(episodes, new_episodes):
"""Summarizes a set of episode metrics tuples.
Arguments:
episodes: smoothed set of episodes including historical ones
new_episodes: just the new episodes in this iteration
num_dropped: number of workers that haven't returned their metrics
"""
if num_dropped > 0:
logger.warning("WARNING: {} workers have NOT returned metrics".format(
num_dropped))
episodes, estimates = _partition(episodes)
new_episodes, _ = _partition(new_episodes)
@@ -155,8 +156,7 @@ def summarize_episodes(episodes, new_episodes, num_dropped):
policy_reward_mean=dict(policy_rewards),
custom_metrics=dict(custom_metrics),
sampler_perf=dict(perf_stats),
off_policy_estimator=dict(estimators),
num_metric_batches_dropped=num_dropped)
off_policy_estimator=dict(estimators))
def _partition(episodes):
@@ -36,6 +36,7 @@ class PolicyOptimizer(object):
"""
self.workers = workers
self.episode_history = []
self.to_be_collected = []
# Counters that should be updated by sub-classes
self.num_steps_trained = 0
@@ -100,9 +101,10 @@ class PolicyOptimizer(object):
res (dict): A training result dict from worker metrics with
`info` replaced with stats from self.
"""
episodes, num_dropped = collect_episodes(
episodes, self.to_be_collected = collect_episodes(
self.workers.local_worker(),
selected_workers or self.workers.remote_workers(),
self.to_be_collected,
timeout_seconds=timeout_seconds)
orig_episodes = list(episodes)
missing = min_history - len(episodes)
@@ -111,7 +113,7 @@ class PolicyOptimizer(object):
assert len(episodes) <= min_history
self.episode_history.extend(orig_episodes)
self.episode_history = self.episode_history[-min_history:]
res = summarize_episodes(episodes, orig_episodes, num_dropped)
res = summarize_episodes(episodes, orig_episodes)
res.update(info=self.stats())
return res