[rllib] Report sampler performance metrics (#4427)

This commit is contained in:
Eric Liang
2019-03-27 13:24:23 -07:00
committed by GitHub
parent 12db684f72
commit 2871609296
4 changed files with 58 additions and 24 deletions
+1 -1
View File
@@ -1,7 +1,7 @@
RLlib: Scalable Reinforcement Learning
======================================
RLlib is an open-source library for reinforcement learning that offers both a unified API for a variety of applications and high scalability via distributed eager execution.
RLlib is an open-source library for reinforcement learning that offers both high scalability and a unified API for a variety of applications.
For an overview of RLlib, see the [documentation](http://ray.readthedocs.io/en/latest/rllib.html).
+7
View File
@@ -72,11 +72,14 @@ def summarize_episodes(episodes, new_episodes, num_dropped):
episode_lengths = []
policy_rewards = collections.defaultdict(list)
custom_metrics = collections.defaultdict(list)
perf_stats = collections.defaultdict(list)
for episode in episodes:
episode_lengths.append(episode.episode_length)
episode_rewards.append(episode.episode_reward)
for k, v in episode.custom_metrics.items():
custom_metrics[k].append(v)
for k, v in episode.perf_stats.items():
perf_stats[k].append(v)
for (_, policy_id), reward in episode.agent_rewards.items():
if policy_id != DEFAULT_POLICY_ID:
policy_rewards[policy_id].append(reward)
@@ -103,6 +106,9 @@ def summarize_episodes(episodes, new_episodes, num_dropped):
custom_metrics[k + "_max"] = float("nan")
del custom_metrics[k]
for k, v_list in perf_stats.copy().items():
perf_stats[k] = np.mean(v_list)
estimators = collections.defaultdict(lambda: collections.defaultdict(list))
for e in estimates:
acc = estimators[e.estimator_name]
@@ -121,6 +127,7 @@ def summarize_episodes(episodes, new_episodes, num_dropped):
episodes_this_iter=len(new_episodes),
policy_reward_mean=dict(policy_rewards),
custom_metrics=dict(custom_metrics),
sampler_perf=dict(perf_stats),
off_policy_estimator=dict(estimators),
num_metric_batches_dropped=num_dropped)
+49 -22
View File
@@ -8,6 +8,7 @@ import logging
import numpy as np
import six.moves.queue as queue
import threading
import time
from ray.rllib.evaluation.episode import MultiAgentEpisode, _flatten_action
from ray.rllib.evaluation.sample_batch_builder import \
@@ -23,9 +24,10 @@ from ray.rllib.utils.tf_run_builder import TFRunBuilder
logger = logging.getLogger(__name__)
RolloutMetrics = namedtuple(
"RolloutMetrics",
["episode_length", "episode_reward", "agent_rewards", "custom_metrics"])
RolloutMetrics = namedtuple("RolloutMetrics", [
"episode_length", "episode_reward", "agent_rewards", "custom_metrics",
"perf_stats"
])
PolicyEvalData = namedtuple("PolicyEvalData", [
"env_id", "agent_id", "obs", "info", "rnn_state", "prev_action",
@@ -33,6 +35,23 @@ PolicyEvalData = namedtuple("PolicyEvalData", [
])
class PerfStats(object):
"""Sampler perf stats that will be included in rollout metrics."""
def __init__(self):
self.iters = 0
self.env_wait_time = 0.0
self.processing_time = 0.0
self.inference_time = 0.0
def get(self):
return {
"mean_env_wait_ms": self.env_wait_time * 1000 / self.iters,
"mean_processing_ms": self.processing_time * 1000 / self.iters,
"mean_inference_ms": self.inference_time * 1000 / self.iters
}
class SamplerInput(InputReader):
"""Reads input experiences from an existing sampler."""
@@ -68,11 +87,12 @@ class SyncSampler(SamplerInput):
self.preprocessors = preprocessors
self.obs_filters = obs_filters
self.extra_batches = queue.Queue()
self.perf_stats = PerfStats()
self.rollout_provider = _env_runner(
self.base_env, self.extra_batches.put, self.policies,
self.policy_mapping_fn, self.unroll_length, self.horizon,
self.preprocessors, self.obs_filters, clip_rewards, clip_actions,
pack, callbacks, tf_sess)
pack, callbacks, tf_sess, self.perf_stats)
self.metrics_queue = queue.Queue()
def get_data(self):
@@ -87,7 +107,8 @@ class SyncSampler(SamplerInput):
completed = []
while True:
try:
completed.append(self.metrics_queue.get_nowait())
completed.append(self.metrics_queue.get_nowait()._replace(
perf_stats=self.perf_stats.get()))
except queue.Empty:
break
return completed
@@ -138,6 +159,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
self.callbacks = callbacks
self.clip_actions = clip_actions
self.blackhole_outputs = blackhole_outputs
self.perf_stats = PerfStats()
self.shutdown = False
def run(self):
@@ -159,7 +181,8 @@ class AsyncSampler(threading.Thread, SamplerInput):
self.base_env, extra_batches_putter, self.policies,
self.policy_mapping_fn, self.unroll_length, self.horizon,
self.preprocessors, self.obs_filters, self.clip_rewards,
self.clip_actions, self.pack, self.callbacks, self.tf_sess)
self.clip_actions, self.pack, self.callbacks, self.tf_sess,
self.perf_stats)
while not self.shutdown:
# The timeout variable exists because apparently, if one worker
# dies, the other workers won't die with it, unless the timeout is
@@ -185,7 +208,8 @@ class AsyncSampler(threading.Thread, SamplerInput):
completed = []
while True:
try:
completed.append(self.metrics_queue.get_nowait())
completed.append(self.metrics_queue.get_nowait()._replace(
perf_stats=self.perf_stats.get()))
except queue.Empty:
break
return completed
@@ -225,19 +249,10 @@ def clip_action(action, space):
return action
def _env_runner(base_env,
extra_batch_callback,
policies,
policy_mapping_fn,
unroll_length,
horizon,
preprocessors,
obs_filters,
clip_rewards,
clip_actions,
pack,
callbacks,
tf_sess=None):
def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
unroll_length, horizon, preprocessors, obs_filters,
clip_rewards, clip_actions, pack, callbacks, tf_sess,
perf_stats):
"""This implements the common experience collection logic.
Args:
@@ -261,6 +276,7 @@ def _env_runner(base_env,
callbacks (dict): User callbacks to run on episode events.
tf_sess (Session|None): Optional tensorflow session to use for batching
TF policy evaluations.
perf_stats (PerfStats): Record perf stats into this object.
Yields:
rollout (SampleBatch): Object containing state, action, reward,
@@ -299,9 +315,12 @@ def _env_runner(base_env,
active_episodes = defaultdict(new_episode)
while True:
perf_stats.iters += 1
t0 = time.time()
# Get observations from all ready agents
unfiltered_obs, rewards, dones, infos, off_policy_actions = \
base_env.poll()
perf_stats.env_wait_time += time.time() - t0
if log_once("env_returns"):
logger.info("Raw obs from env: {}".format(
@@ -309,25 +328,33 @@ def _env_runner(base_env,
logger.info("Info return from env: {}".format(summarize(infos)))
# Process observations and prepare for policy evaluation
t1 = time.time()
active_envs, to_eval, outputs = _process_observations(
base_env, policies, batch_builder_pool, active_episodes,
unfiltered_obs, rewards, dones, infos, off_policy_actions, horizon,
preprocessors, obs_filters, unroll_length, pack, callbacks)
perf_stats.processing_time += time.time() - t1
for o in outputs:
yield o
# Do batched policy eval
t2 = time.time()
eval_results = _do_policy_eval(tf_sess, to_eval, policies,
active_episodes)
perf_stats.inference_time += time.time() - t2
# Process results and update episode state
t3 = time.time()
actions_to_send = _process_policy_eval_results(
to_eval, eval_results, active_episodes, active_envs,
off_policy_actions, policies, clip_actions)
perf_stats.processing_time += time.time() - t3
# Return computed actions to ready envs. We also send to envs that have
# taken off-policy actions; those envs are free to ignore the action.
t4 = time.time()
base_env.send_actions(actions_to_send)
perf_stats.env_wait_time += time.time() - t4
def _process_observations(base_env, policies, batch_builder_pool,
@@ -380,7 +407,7 @@ def _process_observations(base_env, policies, batch_builder_pool,
outputs.append(
RolloutMetrics(episode.length, episode.total_reward,
dict(episode.agent_rewards),
episode.custom_metrics))
episode.custom_metrics, {}))
else:
all_done = False
active_envs.add(env_id)
@@ -602,7 +629,7 @@ def _fetch_atari_metrics(base_env):
if not monitor:
return None
for eps_rew, eps_len in monitor.next_episode_results():
atari_out.append(RolloutMetrics(eps_len, eps_rew, {}, {}))
atari_out.append(RolloutMetrics(eps_len, eps_rew, {}, {}, {}))
return atari_out