mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 19:58:40 +08:00
[rllib] Report sampler performance metrics (#4427)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
RLlib: Scalable Reinforcement Learning
|
||||
======================================
|
||||
|
||||
RLlib is an open-source library for reinforcement learning that offers both a unified API for a variety of applications and high scalability via distributed eager execution.
|
||||
RLlib is an open-source library for reinforcement learning that offers both high scalability and a unified API for a variety of applications.
|
||||
|
||||
For an overview of RLlib, see the [documentation](http://ray.readthedocs.io/en/latest/rllib.html).
|
||||
|
||||
|
||||
@@ -72,11 +72,14 @@ def summarize_episodes(episodes, new_episodes, num_dropped):
|
||||
episode_lengths = []
|
||||
policy_rewards = collections.defaultdict(list)
|
||||
custom_metrics = collections.defaultdict(list)
|
||||
perf_stats = collections.defaultdict(list)
|
||||
for episode in episodes:
|
||||
episode_lengths.append(episode.episode_length)
|
||||
episode_rewards.append(episode.episode_reward)
|
||||
for k, v in episode.custom_metrics.items():
|
||||
custom_metrics[k].append(v)
|
||||
for k, v in episode.perf_stats.items():
|
||||
perf_stats[k].append(v)
|
||||
for (_, policy_id), reward in episode.agent_rewards.items():
|
||||
if policy_id != DEFAULT_POLICY_ID:
|
||||
policy_rewards[policy_id].append(reward)
|
||||
@@ -103,6 +106,9 @@ def summarize_episodes(episodes, new_episodes, num_dropped):
|
||||
custom_metrics[k + "_max"] = float("nan")
|
||||
del custom_metrics[k]
|
||||
|
||||
for k, v_list in perf_stats.copy().items():
|
||||
perf_stats[k] = np.mean(v_list)
|
||||
|
||||
estimators = collections.defaultdict(lambda: collections.defaultdict(list))
|
||||
for e in estimates:
|
||||
acc = estimators[e.estimator_name]
|
||||
@@ -121,6 +127,7 @@ def summarize_episodes(episodes, new_episodes, num_dropped):
|
||||
episodes_this_iter=len(new_episodes),
|
||||
policy_reward_mean=dict(policy_rewards),
|
||||
custom_metrics=dict(custom_metrics),
|
||||
sampler_perf=dict(perf_stats),
|
||||
off_policy_estimator=dict(estimators),
|
||||
num_metric_batches_dropped=num_dropped)
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import logging
|
||||
import numpy as np
|
||||
import six.moves.queue as queue
|
||||
import threading
|
||||
import time
|
||||
|
||||
from ray.rllib.evaluation.episode import MultiAgentEpisode, _flatten_action
|
||||
from ray.rllib.evaluation.sample_batch_builder import \
|
||||
@@ -23,9 +24,10 @@ from ray.rllib.utils.tf_run_builder import TFRunBuilder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
RolloutMetrics = namedtuple(
|
||||
"RolloutMetrics",
|
||||
["episode_length", "episode_reward", "agent_rewards", "custom_metrics"])
|
||||
RolloutMetrics = namedtuple("RolloutMetrics", [
|
||||
"episode_length", "episode_reward", "agent_rewards", "custom_metrics",
|
||||
"perf_stats"
|
||||
])
|
||||
|
||||
PolicyEvalData = namedtuple("PolicyEvalData", [
|
||||
"env_id", "agent_id", "obs", "info", "rnn_state", "prev_action",
|
||||
@@ -33,6 +35,23 @@ PolicyEvalData = namedtuple("PolicyEvalData", [
|
||||
])
|
||||
|
||||
|
||||
class PerfStats(object):
|
||||
"""Sampler perf stats that will be included in rollout metrics."""
|
||||
|
||||
def __init__(self):
|
||||
self.iters = 0
|
||||
self.env_wait_time = 0.0
|
||||
self.processing_time = 0.0
|
||||
self.inference_time = 0.0
|
||||
|
||||
def get(self):
|
||||
return {
|
||||
"mean_env_wait_ms": self.env_wait_time * 1000 / self.iters,
|
||||
"mean_processing_ms": self.processing_time * 1000 / self.iters,
|
||||
"mean_inference_ms": self.inference_time * 1000 / self.iters
|
||||
}
|
||||
|
||||
|
||||
class SamplerInput(InputReader):
|
||||
"""Reads input experiences from an existing sampler."""
|
||||
|
||||
@@ -68,11 +87,12 @@ class SyncSampler(SamplerInput):
|
||||
self.preprocessors = preprocessors
|
||||
self.obs_filters = obs_filters
|
||||
self.extra_batches = queue.Queue()
|
||||
self.perf_stats = PerfStats()
|
||||
self.rollout_provider = _env_runner(
|
||||
self.base_env, self.extra_batches.put, self.policies,
|
||||
self.policy_mapping_fn, self.unroll_length, self.horizon,
|
||||
self.preprocessors, self.obs_filters, clip_rewards, clip_actions,
|
||||
pack, callbacks, tf_sess)
|
||||
pack, callbacks, tf_sess, self.perf_stats)
|
||||
self.metrics_queue = queue.Queue()
|
||||
|
||||
def get_data(self):
|
||||
@@ -87,7 +107,8 @@ class SyncSampler(SamplerInput):
|
||||
completed = []
|
||||
while True:
|
||||
try:
|
||||
completed.append(self.metrics_queue.get_nowait())
|
||||
completed.append(self.metrics_queue.get_nowait()._replace(
|
||||
perf_stats=self.perf_stats.get()))
|
||||
except queue.Empty:
|
||||
break
|
||||
return completed
|
||||
@@ -138,6 +159,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
||||
self.callbacks = callbacks
|
||||
self.clip_actions = clip_actions
|
||||
self.blackhole_outputs = blackhole_outputs
|
||||
self.perf_stats = PerfStats()
|
||||
self.shutdown = False
|
||||
|
||||
def run(self):
|
||||
@@ -159,7 +181,8 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
||||
self.base_env, extra_batches_putter, self.policies,
|
||||
self.policy_mapping_fn, self.unroll_length, self.horizon,
|
||||
self.preprocessors, self.obs_filters, self.clip_rewards,
|
||||
self.clip_actions, self.pack, self.callbacks, self.tf_sess)
|
||||
self.clip_actions, self.pack, self.callbacks, self.tf_sess,
|
||||
self.perf_stats)
|
||||
while not self.shutdown:
|
||||
# The timeout variable exists because apparently, if one worker
|
||||
# dies, the other workers won't die with it, unless the timeout is
|
||||
@@ -185,7 +208,8 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
||||
completed = []
|
||||
while True:
|
||||
try:
|
||||
completed.append(self.metrics_queue.get_nowait())
|
||||
completed.append(self.metrics_queue.get_nowait()._replace(
|
||||
perf_stats=self.perf_stats.get()))
|
||||
except queue.Empty:
|
||||
break
|
||||
return completed
|
||||
@@ -225,19 +249,10 @@ def clip_action(action, space):
|
||||
return action
|
||||
|
||||
|
||||
def _env_runner(base_env,
|
||||
extra_batch_callback,
|
||||
policies,
|
||||
policy_mapping_fn,
|
||||
unroll_length,
|
||||
horizon,
|
||||
preprocessors,
|
||||
obs_filters,
|
||||
clip_rewards,
|
||||
clip_actions,
|
||||
pack,
|
||||
callbacks,
|
||||
tf_sess=None):
|
||||
def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
|
||||
unroll_length, horizon, preprocessors, obs_filters,
|
||||
clip_rewards, clip_actions, pack, callbacks, tf_sess,
|
||||
perf_stats):
|
||||
"""This implements the common experience collection logic.
|
||||
|
||||
Args:
|
||||
@@ -261,6 +276,7 @@ def _env_runner(base_env,
|
||||
callbacks (dict): User callbacks to run on episode events.
|
||||
tf_sess (Session|None): Optional tensorflow session to use for batching
|
||||
TF policy evaluations.
|
||||
perf_stats (PerfStats): Record perf stats into this object.
|
||||
|
||||
Yields:
|
||||
rollout (SampleBatch): Object containing state, action, reward,
|
||||
@@ -299,9 +315,12 @@ def _env_runner(base_env,
|
||||
active_episodes = defaultdict(new_episode)
|
||||
|
||||
while True:
|
||||
perf_stats.iters += 1
|
||||
t0 = time.time()
|
||||
# Get observations from all ready agents
|
||||
unfiltered_obs, rewards, dones, infos, off_policy_actions = \
|
||||
base_env.poll()
|
||||
perf_stats.env_wait_time += time.time() - t0
|
||||
|
||||
if log_once("env_returns"):
|
||||
logger.info("Raw obs from env: {}".format(
|
||||
@@ -309,25 +328,33 @@ def _env_runner(base_env,
|
||||
logger.info("Info return from env: {}".format(summarize(infos)))
|
||||
|
||||
# Process observations and prepare for policy evaluation
|
||||
t1 = time.time()
|
||||
active_envs, to_eval, outputs = _process_observations(
|
||||
base_env, policies, batch_builder_pool, active_episodes,
|
||||
unfiltered_obs, rewards, dones, infos, off_policy_actions, horizon,
|
||||
preprocessors, obs_filters, unroll_length, pack, callbacks)
|
||||
perf_stats.processing_time += time.time() - t1
|
||||
for o in outputs:
|
||||
yield o
|
||||
|
||||
# Do batched policy eval
|
||||
t2 = time.time()
|
||||
eval_results = _do_policy_eval(tf_sess, to_eval, policies,
|
||||
active_episodes)
|
||||
perf_stats.inference_time += time.time() - t2
|
||||
|
||||
# Process results and update episode state
|
||||
t3 = time.time()
|
||||
actions_to_send = _process_policy_eval_results(
|
||||
to_eval, eval_results, active_episodes, active_envs,
|
||||
off_policy_actions, policies, clip_actions)
|
||||
perf_stats.processing_time += time.time() - t3
|
||||
|
||||
# Return computed actions to ready envs. We also send to envs that have
|
||||
# taken off-policy actions; those envs are free to ignore the action.
|
||||
t4 = time.time()
|
||||
base_env.send_actions(actions_to_send)
|
||||
perf_stats.env_wait_time += time.time() - t4
|
||||
|
||||
|
||||
def _process_observations(base_env, policies, batch_builder_pool,
|
||||
@@ -380,7 +407,7 @@ def _process_observations(base_env, policies, batch_builder_pool,
|
||||
outputs.append(
|
||||
RolloutMetrics(episode.length, episode.total_reward,
|
||||
dict(episode.agent_rewards),
|
||||
episode.custom_metrics))
|
||||
episode.custom_metrics, {}))
|
||||
else:
|
||||
all_done = False
|
||||
active_envs.add(env_id)
|
||||
@@ -602,7 +629,7 @@ def _fetch_atari_metrics(base_env):
|
||||
if not monitor:
|
||||
return None
|
||||
for eps_rew, eps_len in monitor.next_episode_results():
|
||||
atari_out.append(RolloutMetrics(eps_len, eps_rew, {}, {}))
|
||||
atari_out.append(RolloutMetrics(eps_len, eps_rew, {}, {}, {}))
|
||||
return atari_out
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user