Solve hang caused by ray.get in collect_metrics (#3096)

This commit is contained in:
Jones Wong
2018-10-29 02:52:18 +08:00
committed by Eric Liang
parent af0c1174cd
commit d6bf890648
8 changed files with 56 additions and 26 deletions
+4 -3
View File
@@ -82,7 +82,8 @@ class A3CAgent(Agent):
start = time.time()
while time.time() - start < self.config["min_iter_time_s"]:
self.optimizer.step()
result = self.optimizer.collect_metrics()
result.update(timesteps_this_iter=self.optimizer.num_steps_sampled -
prev_steps)
result = self.optimizer.collect_metrics(
self.config["collect_metrics_timeout"])
result.update(
timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps)
return result
+2
View File
@@ -87,6 +87,8 @@ COMMON_CONFIG = {
"compress_observations": False,
# Allocate a fraction of a GPU instead of one (e.g., 0.3 GPUs)
"gpu_fraction": 1,
# Drop metric batches from unresponsive workers after this timeout (seconds)
"collect_metrics_timeout": 180,
# === Multiagent ===
"multiagent": {
+6 -3
View File
@@ -236,10 +236,13 @@ class DQNAgent(Agent):
# Only collect metrics from the third of workers with lowest eps
result = collect_metrics(
self.local_evaluator,
self.remote_evaluators[-len(self.remote_evaluators) // 3:])
self.remote_evaluators[-len(self.remote_evaluators) // 3:],
timeout_seconds=self.config["collect_metrics_timeout"])
else:
result = collect_metrics(self.local_evaluator,
self.remote_evaluators)
result = collect_metrics(
self.local_evaluator,
self.remote_evaluators,
timeout_seconds=self.config["collect_metrics_timeout"])
result.update(
timesteps_this_iter=self.global_timestep - start_timestep,
+4 -3
View File
@@ -109,7 +109,8 @@ class ImpalaAgent(Agent):
self.optimizer.step()
while time.time() - start < self.config["min_iter_time_s"]:
self.optimizer.step()
result = self.optimizer.collect_metrics()
result.update(timesteps_this_iter=self.optimizer.num_steps_sampled -
prev_steps)
result = self.optimizer.collect_metrics(
self.config["collect_metrics_timeout"])
result.update(
timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps)
return result
+4 -3
View File
@@ -49,7 +49,8 @@ class PGAgent(Agent):
def _train(self):
prev_steps = self.optimizer.num_steps_sampled
self.optimizer.step()
result = self.optimizer.collect_metrics()
result.update(timesteps_this_iter=self.optimizer.num_steps_sampled -
prev_steps)
result = self.optimizer.collect_metrics(
self.config["collect_metrics_timeout"])
result.update(
timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps)
return result
+2 -1
View File
@@ -138,7 +138,8 @@ class PPOAgent(Agent):
# multi-agent
self.local_evaluator.foreach_trainable_policy(
lambda pi, pi_id: pi.update_kl(fetches[pi_id]["kl"]))
res = self.optimizer.collect_metrics()
res = self.optimizer.collect_metrics(
self.config["collect_metrics_timeout"])
res.update(
timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps,
info=dict(fetches, **res.get("info", {})))
+28 -9
View File
@@ -2,42 +2,60 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import numpy as np
import collections
import ray
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
logger = logging.getLogger(__name__)
def collect_metrics(local_evaluator, remote_evaluators=[]):
def collect_metrics(local_evaluator, remote_evaluators=[],
timeout_seconds=180):
"""Gathers episode metrics from PolicyEvaluator instances."""
episodes = collect_episodes(local_evaluator, remote_evaluators)
return summarize_episodes(episodes, episodes)
episodes, num_dropped = collect_episodes(
local_evaluator, remote_evaluators, timeout_seconds=timeout_seconds)
metrics = summarize_episodes(episodes, episodes, num_dropped)
return metrics
def collect_episodes(local_evaluator, remote_evaluators=[]):
def collect_episodes(local_evaluator,
remote_evaluators=[],
timeout_seconds=180):
"""Gathers new episode metrics tuples from the given evaluators."""
metric_lists = ray.get([
pending = [
a.apply.remote(lambda ev: ev.sampler.get_metrics())
for a in remote_evaluators
])
]
collected, _ = ray.wait(
pending, num_returns=len(pending), timeout=timeout_seconds * 1000)
num_metric_batches_dropped = len(pending) - len(collected)
metric_lists = ray.get(collected)
metric_lists.append(local_evaluator.sampler.get_metrics())
episodes = []
for metrics in metric_lists:
episodes.extend(metrics)
return episodes
return episodes, num_metric_batches_dropped
def summarize_episodes(episodes, new_episodes):
def summarize_episodes(episodes, new_episodes, num_dropped):
"""Summarizes a set of episode metrics tuples.
Arguments:
episodes: smoothed set of episodes including historical ones
new_episodes: just the new episodes in this iteration
num_dropped: number of workers that haven't returned their metrics
"""
if num_dropped > 0:
logger.warn("WARNING: {} workers have NOT returned metrics".format(
num_dropped))
episode_rewards = []
episode_lengths = []
policy_rewards = collections.defaultdict(list)
@@ -65,4 +83,5 @@ def summarize_episodes(episodes, new_episodes):
episode_reward_mean=avg_reward,
episode_len_mean=avg_length,
episodes_this_iter=len(new_episodes),
policy_reward_mean=dict(policy_rewards))
policy_reward_mean=dict(policy_rewards),
num_metric_batches_dropped=num_dropped)
@@ -83,7 +83,7 @@ class PolicyOptimizer(object):
"num_steps_sampled": self.num_steps_sampled,
}
def collect_metrics(self, min_history=100):
def collect_metrics(self, timeout_seconds, min_history=100):
"""Returns evaluator and optimizer stats.
Arguments:
@@ -93,8 +93,10 @@ class PolicyOptimizer(object):
res (dict): A training result dict from evaluator metrics with
`info` replaced with stats from self.
"""
episodes = collect_episodes(self.local_evaluator,
self.remote_evaluators)
episodes, num_dropped = collect_episodes(
self.local_evaluator,
self.remote_evaluators,
timeout_seconds=timeout_seconds)
orig_episodes = list(episodes)
missing = min_history - len(episodes)
if missing > 0:
@@ -102,7 +104,7 @@ class PolicyOptimizer(object):
assert len(episodes) <= min_history
self.episode_history.extend(orig_episodes)
self.episode_history = self.episode_history[-min_history:]
res = summarize_episodes(episodes, orig_episodes)
res = summarize_episodes(episodes, orig_episodes, num_dropped)
res.update(info=self.stats())
return res