mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 17:58:35 +08:00
Removed the implicit sync barrier at the end of each training iteration (#5217)
* removed sync barrier at the end of each training iteration * formatted * modify the comment according to current semantics * lint check * Update trainer.py
This commit is contained in:
@@ -174,7 +174,8 @@ COMMON_CONFIG = {
|
||||
},
|
||||
# Whether to LZ4 compress individual observations
|
||||
"compress_observations": False,
|
||||
# Drop metric batches from unresponsive workers after this many seconds
|
||||
# Wait for metric batches for at most this many seconds. Those that
|
||||
# have not returned in time will be collected in the next iteration.
|
||||
"collect_metrics_timeout": 180,
|
||||
# Smooth metrics over this many episodes.
|
||||
"metrics_smoothing_episodes": 100,
|
||||
|
||||
@@ -40,58 +40,59 @@ def get_learner_stats(grad_info):
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def collect_metrics(local_worker=None, remote_workers=[], timeout_seconds=180):
|
||||
def collect_metrics(local_worker=None,
|
||||
remote_workers=[],
|
||||
to_be_collected=[],
|
||||
timeout_seconds=180):
|
||||
"""Gathers episode metrics from RolloutWorker instances."""
|
||||
|
||||
episodes, num_dropped = collect_episodes(
|
||||
local_worker, remote_workers, timeout_seconds=timeout_seconds)
|
||||
metrics = summarize_episodes(episodes, episodes, num_dropped)
|
||||
episodes, to_be_collected = collect_episodes(
|
||||
local_worker,
|
||||
remote_workers,
|
||||
to_be_collected,
|
||||
timeout_seconds=timeout_seconds)
|
||||
metrics = summarize_episodes(episodes, episodes)
|
||||
return metrics
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def collect_episodes(local_worker=None, remote_workers=[],
|
||||
def collect_episodes(local_worker=None,
|
||||
remote_workers=[],
|
||||
to_be_collected=[],
|
||||
timeout_seconds=180):
|
||||
"""Gathers new episodes metrics tuples from the given evaluators."""
|
||||
|
||||
if remote_workers:
|
||||
pending = [
|
||||
a.apply.remote(lambda ev: ev.get_metrics()) for a in remote_workers
|
||||
]
|
||||
collected, _ = ray.wait(
|
||||
] + to_be_collected
|
||||
collected, to_be_collected = ray.wait(
|
||||
pending, num_returns=len(pending), timeout=timeout_seconds * 1.0)
|
||||
num_metric_batches_dropped = len(pending) - len(collected)
|
||||
if pending and len(collected) == 0:
|
||||
raise ValueError(
|
||||
"Timed out waiting for metrics from workers. You can "
|
||||
"configure this timeout with `collect_metrics_timeout`.")
|
||||
logger.warning(
|
||||
"WARNING: collected no metrics in {} seconds".format(
|
||||
timeout_seconds))
|
||||
metric_lists = ray_get_and_free(collected)
|
||||
else:
|
||||
metric_lists = []
|
||||
num_metric_batches_dropped = 0
|
||||
|
||||
if local_worker:
|
||||
metric_lists.append(local_worker.get_metrics())
|
||||
episodes = []
|
||||
for metrics in metric_lists:
|
||||
episodes.extend(metrics)
|
||||
return episodes, num_metric_batches_dropped
|
||||
return episodes, to_be_collected
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def summarize_episodes(episodes, new_episodes, num_dropped):
|
||||
def summarize_episodes(episodes, new_episodes):
|
||||
"""Summarizes a set of episode metrics tuples.
|
||||
|
||||
Arguments:
|
||||
episodes: smoothed set of episodes including historical ones
|
||||
new_episodes: just the new episodes in this iteration
|
||||
num_dropped: number of workers haven't returned their metrics
|
||||
"""
|
||||
|
||||
if num_dropped > 0:
|
||||
logger.warning("WARNING: {} workers have NOT returned metrics".format(
|
||||
num_dropped))
|
||||
|
||||
episodes, estimates = _partition(episodes)
|
||||
new_episodes, _ = _partition(new_episodes)
|
||||
|
||||
@@ -155,8 +156,7 @@ def summarize_episodes(episodes, new_episodes, num_dropped):
|
||||
policy_reward_mean=dict(policy_rewards),
|
||||
custom_metrics=dict(custom_metrics),
|
||||
sampler_perf=dict(perf_stats),
|
||||
off_policy_estimator=dict(estimators),
|
||||
num_metric_batches_dropped=num_dropped)
|
||||
off_policy_estimator=dict(estimators))
|
||||
|
||||
|
||||
def _partition(episodes):
|
||||
|
||||
@@ -36,6 +36,7 @@ class PolicyOptimizer(object):
|
||||
"""
|
||||
self.workers = workers
|
||||
self.episode_history = []
|
||||
self.to_be_collected = []
|
||||
|
||||
# Counters that should be updated by sub-classes
|
||||
self.num_steps_trained = 0
|
||||
@@ -100,9 +101,10 @@ class PolicyOptimizer(object):
|
||||
res (dict): A training result dict from worker metrics with
|
||||
`info` replaced with stats from self.
|
||||
"""
|
||||
episodes, num_dropped = collect_episodes(
|
||||
episodes, self.to_be_collected = collect_episodes(
|
||||
self.workers.local_worker(),
|
||||
selected_workers or self.workers.remote_workers(),
|
||||
self.to_be_collected,
|
||||
timeout_seconds=timeout_seconds)
|
||||
orig_episodes = list(episodes)
|
||||
missing = min_history - len(episodes)
|
||||
@@ -111,7 +113,7 @@ class PolicyOptimizer(object):
|
||||
assert len(episodes) <= min_history
|
||||
self.episode_history.extend(orig_episodes)
|
||||
self.episode_history = self.episode_history[-min_history:]
|
||||
res = summarize_episodes(episodes, orig_episodes, num_dropped)
|
||||
res = summarize_episodes(episodes, orig_episodes)
|
||||
res.update(info=self.stats())
|
||||
return res
|
||||
|
||||
|
||||
Reference in New Issue
Block a user