mirror of
https://github.com/wassname/ray.git
synced 2026-07-04 21:55:40 +08:00
Solve hang caused by ray.get in collect_metrics (#3096)
This commit is contained in:
@@ -82,7 +82,8 @@ class A3CAgent(Agent):
|
||||
start = time.time()
|
||||
while time.time() - start < self.config["min_iter_time_s"]:
|
||||
self.optimizer.step()
|
||||
result = self.optimizer.collect_metrics()
|
||||
result.update(timesteps_this_iter=self.optimizer.num_steps_sampled -
|
||||
prev_steps)
|
||||
result = self.optimizer.collect_metrics(
|
||||
self.config["collect_metrics_timeout"])
|
||||
result.update(
|
||||
timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps)
|
||||
return result
|
||||
|
||||
@@ -87,6 +87,8 @@ COMMON_CONFIG = {
|
||||
"compress_observations": False,
|
||||
# Allocate a fraction of a GPU instead of one (e.g., 0.3 GPUs)
|
||||
"gpu_fraction": 1,
|
||||
# Drop metric batches from unresponsive workers after this timeout (seconds)
|
||||
"collect_metrics_timeout": 180,
|
||||
|
||||
# === Multiagent ===
|
||||
"multiagent": {
|
||||
|
||||
@@ -236,10 +236,13 @@ class DQNAgent(Agent):
|
||||
# Only collect metrics from the third of workers with lowest eps
|
||||
result = collect_metrics(
|
||||
self.local_evaluator,
|
||||
self.remote_evaluators[-len(self.remote_evaluators) // 3:])
|
||||
self.remote_evaluators[-len(self.remote_evaluators) // 3:],
|
||||
timeout_seconds=self.config["collect_metrics_timeout"])
|
||||
else:
|
||||
result = collect_metrics(self.local_evaluator,
|
||||
self.remote_evaluators)
|
||||
result = collect_metrics(
|
||||
self.local_evaluator,
|
||||
self.remote_evaluators,
|
||||
timeout_seconds=self.config["collect_metrics_timeout"])
|
||||
|
||||
result.update(
|
||||
timesteps_this_iter=self.global_timestep - start_timestep,
|
||||
|
||||
@@ -109,7 +109,8 @@ class ImpalaAgent(Agent):
|
||||
self.optimizer.step()
|
||||
while time.time() - start < self.config["min_iter_time_s"]:
|
||||
self.optimizer.step()
|
||||
result = self.optimizer.collect_metrics()
|
||||
result.update(timesteps_this_iter=self.optimizer.num_steps_sampled -
|
||||
prev_steps)
|
||||
result = self.optimizer.collect_metrics(
|
||||
self.config["collect_metrics_timeout"])
|
||||
result.update(
|
||||
timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps)
|
||||
return result
|
||||
|
||||
@@ -49,7 +49,8 @@ class PGAgent(Agent):
|
||||
def _train(self):
|
||||
prev_steps = self.optimizer.num_steps_sampled
|
||||
self.optimizer.step()
|
||||
result = self.optimizer.collect_metrics()
|
||||
result.update(timesteps_this_iter=self.optimizer.num_steps_sampled -
|
||||
prev_steps)
|
||||
result = self.optimizer.collect_metrics(
|
||||
self.config["collect_metrics_timeout"])
|
||||
result.update(
|
||||
timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps)
|
||||
return result
|
||||
|
||||
@@ -138,7 +138,8 @@ class PPOAgent(Agent):
|
||||
# multi-agent
|
||||
self.local_evaluator.foreach_trainable_policy(
|
||||
lambda pi, pi_id: pi.update_kl(fetches[pi_id]["kl"]))
|
||||
res = self.optimizer.collect_metrics()
|
||||
res = self.optimizer.collect_metrics(
|
||||
self.config["collect_metrics_timeout"])
|
||||
res.update(
|
||||
timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps,
|
||||
info=dict(fetches, **res.get("info", {})))
|
||||
|
||||
@@ -2,42 +2,60 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import logging
|
||||
import numpy as np
|
||||
import collections
|
||||
|
||||
import ray
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def collect_metrics(local_evaluator, remote_evaluators=[]):
|
||||
|
||||
def collect_metrics(local_evaluator, remote_evaluators=[],
|
||||
timeout_seconds=180):
|
||||
"""Gathers episode metrics from PolicyEvaluator instances."""
|
||||
|
||||
episodes = collect_episodes(local_evaluator, remote_evaluators)
|
||||
return summarize_episodes(episodes, episodes)
|
||||
episodes, num_dropped = collect_episodes(
|
||||
local_evaluator, remote_evaluators, timeout_seconds=timeout_seconds)
|
||||
metrics = summarize_episodes(episodes, episodes, num_dropped)
|
||||
return metrics
|
||||
|
||||
|
||||
def collect_episodes(local_evaluator, remote_evaluators=[]):
|
||||
def collect_episodes(local_evaluator,
|
||||
remote_evaluators=[],
|
||||
timeout_seconds=180):
|
||||
"""Gathers new episodes metrics tuples from the given evaluators."""
|
||||
|
||||
metric_lists = ray.get([
|
||||
pending = [
|
||||
a.apply.remote(lambda ev: ev.sampler.get_metrics())
|
||||
for a in remote_evaluators
|
||||
])
|
||||
]
|
||||
collected, _ = ray.wait(
|
||||
pending, num_returns=len(pending), timeout=timeout_seconds * 1000)
|
||||
num_metric_batches_dropped = len(pending) - len(collected)
|
||||
|
||||
metric_lists = ray.get(collected)
|
||||
metric_lists.append(local_evaluator.sampler.get_metrics())
|
||||
episodes = []
|
||||
for metrics in metric_lists:
|
||||
episodes.extend(metrics)
|
||||
return episodes
|
||||
return episodes, num_metric_batches_dropped
|
||||
|
||||
|
||||
def summarize_episodes(episodes, new_episodes):
|
||||
def summarize_episodes(episodes, new_episodes, num_dropped):
|
||||
"""Summarizes a set of episode metrics tuples.
|
||||
|
||||
Arguments:
|
||||
episodes: smoothed set of episodes including historical ones
|
||||
new_episodes: just the new episodes in this iteration
|
||||
num_dropped: number of workers haven't returned their metrics
|
||||
"""
|
||||
|
||||
if num_dropped > 0:
|
||||
logger.warn("WARNING: {} workers have NOT returned metrics".format(
|
||||
num_dropped))
|
||||
|
||||
episode_rewards = []
|
||||
episode_lengths = []
|
||||
policy_rewards = collections.defaultdict(list)
|
||||
@@ -65,4 +83,5 @@ def summarize_episodes(episodes, new_episodes):
|
||||
episode_reward_mean=avg_reward,
|
||||
episode_len_mean=avg_length,
|
||||
episodes_this_iter=len(new_episodes),
|
||||
policy_reward_mean=dict(policy_rewards))
|
||||
policy_reward_mean=dict(policy_rewards),
|
||||
num_metric_batches_dropped=num_dropped)
|
||||
|
||||
@@ -83,7 +83,7 @@ class PolicyOptimizer(object):
|
||||
"num_steps_sampled": self.num_steps_sampled,
|
||||
}
|
||||
|
||||
def collect_metrics(self, min_history=100):
|
||||
def collect_metrics(self, timeout_seconds, min_history=100):
|
||||
"""Returns evaluator and optimizer stats.
|
||||
|
||||
Arguments:
|
||||
@@ -93,8 +93,10 @@ class PolicyOptimizer(object):
|
||||
res (dict): A training result dict from evaluator metrics with
|
||||
`info` replaced with stats from self.
|
||||
"""
|
||||
episodes = collect_episodes(self.local_evaluator,
|
||||
self.remote_evaluators)
|
||||
episodes, num_dropped = collect_episodes(
|
||||
self.local_evaluator,
|
||||
self.remote_evaluators,
|
||||
timeout_seconds=timeout_seconds)
|
||||
orig_episodes = list(episodes)
|
||||
missing = min_history - len(episodes)
|
||||
if missing > 0:
|
||||
@@ -102,7 +104,7 @@ class PolicyOptimizer(object):
|
||||
assert len(episodes) <= min_history
|
||||
self.episode_history.extend(orig_episodes)
|
||||
self.episode_history = self.episode_history[-min_history:]
|
||||
res = summarize_episodes(episodes, orig_episodes)
|
||||
res = summarize_episodes(episodes, orig_episodes, num_dropped)
|
||||
res.update(info=self.stats())
|
||||
return res
|
||||
|
||||
|
||||
Reference in New Issue
Block a user