[rllib] Implement custom metrics (#3144)

This commit is contained in:
Eric Liang
2018-11-03 18:48:32 -07:00
committed by GitHub
parent 7d69c77a19
commit 369cb833fe
12 changed files with 248 additions and 23 deletions
+13 -2
View File
@@ -26,8 +26,18 @@ COMMON_CONFIG = {
# === Debugging ===
# Whether to write episode stats and videos to the agent log dir
"monitor": False,
# Set the RLlib log level for the agent process and its remote evaluators
# Set the ray.rllib.* log level for the agent process and its evaluators
"log_level": "INFO",
# Callbacks that will be run during various phases of training. These all
# take a single "info" dict as an argument. For episode callbacks, custom
# metrics can be attached to the episode by updating the episode object's
# custom metrics dict (see examples/custom_metrics_and_callbacks.py).
"callbacks": {
"on_episode_start": None, # arg: {"env": .., "episode": ...}
"on_episode_step": None, # arg: {"env": .., "episode": ...}
"on_episode_end": None, # arg: {"env": .., "episode": ...}
"on_sample_end": None, # arg: {"samples": .., "evaluator": ...}
},
# === Policy ===
# Arguments to pass to model. See models/catalog.py for a full list of the
@@ -184,7 +194,8 @@ class Agent(Trainable):
policy_config=config,
worker_index=worker_index,
monitor_path=self.logdir if config["monitor"] else None,
log_level=config["log_level"])
log_level=config["log_level"],
callbacks=config["callbacks"])
@classmethod
def resource_help(cls, config):
+14 -8
View File
@@ -7,13 +7,15 @@ import random
import numpy as np
from ray.rllib.env.async_vector_env import _DUMMY_AGENT_ID
class MultiAgentEpisode(object):
"""Tracks the current state of a (possibly multi-agent) episode.
The APIs in this class should be considered experimental, but we should
avoid changing things for the sake of changing them since users may
depend on them for advanced algorithms.
depend on them for custom metrics or advanced algorithms.
Attributes:
new_batch_builder (func): Create a new MultiAgentSampleBatchBuilder.
@@ -23,6 +25,8 @@ class MultiAgentEpisode(object):
length (int): Length of this episode.
episode_id (int): Unique id identifying this trajectory.
agent_rewards (dict): Summed rewards broken down by agent.
custom_metrics (dict): Dict where the you can add custom metrics.
user_data (dict): Dict that you can use for temporary storage.
Use case 1: Model-based rollouts in multi-agent:
A custom compute_actions() function in a policy graph can inspect the
@@ -47,6 +51,8 @@ class MultiAgentEpisode(object):
self.length = 0
self.episode_id = random.randrange(2e9)
self.agent_rewards = defaultdict(float)
self.custom_metrics = {}
self.user_data = {}
self._policies = policies
self._policy_mapping_fn = policy_mapping_fn
self._agent_to_policy = {}
@@ -57,7 +63,7 @@ class MultiAgentEpisode(object):
self._agent_to_prev_action = {}
self._agent_reward_history = defaultdict(list)
def policy_for(self, agent_id):
def policy_for(self, agent_id=_DUMMY_AGENT_ID):
"""Returns the policy graph for the specified agent.
If the agent is new, the policy mapping fn will be called to bind the
@@ -68,12 +74,12 @@ class MultiAgentEpisode(object):
self._agent_to_policy[agent_id] = self._policy_mapping_fn(agent_id)
return self._agent_to_policy[agent_id]
def last_observation_for(self, agent_id):
def last_observation_for(self, agent_id=_DUMMY_AGENT_ID):
"""Returns the last observation for the specified agent."""
return self._agent_to_last_obs.get(agent_id)
def last_action_for(self, agent_id):
def last_action_for(self, agent_id=_DUMMY_AGENT_ID):
"""Returns the last action for the specified agent, or zeros."""
if agent_id in self._agent_to_last_action:
@@ -83,7 +89,7 @@ class MultiAgentEpisode(object):
flat = _flatten_action(policy.action_space.sample())
return np.zeros_like(flat)
def prev_action_for(self, agent_id):
def prev_action_for(self, agent_id=_DUMMY_AGENT_ID):
"""Returns the previous action for the specified agent."""
if agent_id in self._agent_to_prev_action:
@@ -92,7 +98,7 @@ class MultiAgentEpisode(object):
# We're at t=0, so return all zeros.
return np.zeros_like(self.last_action_for(agent_id))
def prev_reward_for(self, agent_id):
def prev_reward_for(self, agent_id=_DUMMY_AGENT_ID):
"""Returns the previous reward for the specified agent."""
history = self._agent_reward_history[agent_id]
@@ -102,7 +108,7 @@ class MultiAgentEpisode(object):
# We're at t=0, so there is no previous reward, just return zero.
return 0.0
def rnn_state_for(self, agent_id):
def rnn_state_for(self, agent_id=_DUMMY_AGENT_ID):
"""Returns the last RNN state for the specified agent."""
if agent_id not in self._agent_to_rnn_state:
@@ -110,7 +116,7 @@ class MultiAgentEpisode(object):
self._agent_to_rnn_state[agent_id] = policy.get_initial_state()
return self._agent_to_rnn_state[agent_id]
def last_pi_info_for(self, agent_id):
def last_pi_info_for(self, agent_id=_DUMMY_AGENT_ID):
"""Returns the last info object for the specified agent."""
return self._agent_to_last_pi_info[agent_id]
+7
View File
@@ -59,9 +59,12 @@ def summarize_episodes(episodes, new_episodes, num_dropped):
episode_rewards = []
episode_lengths = []
policy_rewards = collections.defaultdict(list)
custom_metrics = collections.defaultdict(list)
for episode in episodes:
episode_lengths.append(episode.episode_length)
episode_rewards.append(episode.episode_reward)
for k, v in episode.custom_metrics.items():
custom_metrics[k].append(v)
for (_, policy_id), reward in episode.agent_rewards.items():
if policy_id != DEFAULT_POLICY_ID:
policy_rewards[policy_id].append(reward)
@@ -77,6 +80,9 @@ def summarize_episodes(episodes, new_episodes, num_dropped):
for policy_id, rewards in policy_rewards.copy().items():
policy_rewards[policy_id] = np.mean(rewards)
for k, v_list in custom_metrics.items():
custom_metrics[k] = np.mean(v_list)
return dict(
episode_reward_max=max_reward,
episode_reward_min=min_reward,
@@ -84,4 +90,5 @@ def summarize_episodes(episodes, new_episodes, num_dropped):
episode_len_mean=avg_length,
episodes_this_iter=len(new_episodes),
policy_reward_mean=dict(policy_rewards),
custom_metrics=dict(custom_metrics),
num_metric_batches_dropped=num_dropped)
@@ -71,7 +71,7 @@ class PolicyEvaluator(EvaluatorInterface):
... policy_mapping_fn=lambda agent_id:
... random.choice(["car_policy1", "car_policy2"])
... if agent_id.startswith("car_") else "traffic_light_policy")
>>> print(evaluator.sample().keys())
>>> print(evaluator.sample())
MultiAgentBatch({
"car_policy1": SampleBatch(...),
"car_policy2": SampleBatch(...),
@@ -102,7 +102,8 @@ class PolicyEvaluator(EvaluatorInterface):
policy_config=None,
worker_index=0,
monitor_path=None,
log_level=None):
log_level=None,
callbacks=None):
"""Initialize a policy evaluator.
Arguments:
@@ -162,6 +163,7 @@ class PolicyEvaluator(EvaluatorInterface):
monitor_path (str): Write out episode stats and videos to this
directory if specified.
log_level (str): Set the root log level on creation.
callbacks (dict): Dict of custom debug callbacks.
"""
if log_level:
@@ -170,6 +172,7 @@ class PolicyEvaluator(EvaluatorInterface):
env_context = EnvContext(env_config or {}, worker_index)
policy_config = policy_config or {}
self.policy_config = policy_config
self.callbacks = callbacks or {}
model_config = model_config or {}
policy_mapping_fn = (policy_mapping_fn
or (lambda agent_id: DEFAULT_POLICY_ID))
@@ -280,6 +283,7 @@ class PolicyEvaluator(EvaluatorInterface):
self.filters,
clip_rewards,
unroll_length,
self.callbacks,
horizon=episode_horizon,
pack=pack_episodes,
tf_sess=self.tf_sess)
@@ -292,6 +296,7 @@ class PolicyEvaluator(EvaluatorInterface):
self.filters,
clip_rewards,
unroll_length,
self.callbacks,
horizon=episode_horizon,
pack=pack_episodes,
tf_sess=self.tf_sess)
@@ -342,6 +347,12 @@ class PolicyEvaluator(EvaluatorInterface):
batches.extend(self.sampler.get_extra_batches())
batch = batches[0].concat_samples(batches)
if self.callbacks.get("on_sample_end"):
self.callbacks["on_sample_end"]({
"evaluator": self,
"samples": batch
})
if self.compress_observations:
if isinstance(batch, MultiAgentBatch):
for data in batch.policy_batches.values():
+34 -8
View File
@@ -20,7 +20,8 @@ from ray.rllib.utils.tf_run_builder import TFRunBuilder
logger = logging.getLogger(__name__)
RolloutMetrics = namedtuple(
"RolloutMetrics", ["episode_length", "episode_reward", "agent_rewards"])
"RolloutMetrics",
["episode_length", "episode_reward", "agent_rewards", "custom_metrics"])
PolicyEvalData = namedtuple(
"PolicyEvalData",
@@ -43,6 +44,7 @@ class SyncSampler(object):
obs_filters,
clip_rewards,
unroll_length,
callbacks,
horizon=None,
pack=False,
tf_sess=None):
@@ -56,7 +58,7 @@ class SyncSampler(object):
self.rollout_provider = _env_runner(
self.async_vector_env, self.extra_batches.put, self.policies,
self.policy_mapping_fn, self.unroll_length, self.horizon,
self._obs_filters, clip_rewards, pack, tf_sess)
self._obs_filters, clip_rewards, pack, callbacks, tf_sess)
self.metrics_queue = queue.Queue()
def get_data(self):
@@ -99,6 +101,7 @@ class AsyncSampler(threading.Thread):
obs_filters,
clip_rewards,
unroll_length,
callbacks,
horizon=None,
pack=False,
tf_sess=None):
@@ -119,6 +122,7 @@ class AsyncSampler(threading.Thread):
self.daemon = True
self.pack = pack
self.tf_sess = tf_sess
self.callbacks = callbacks
def run(self):
try:
@@ -131,7 +135,8 @@ class AsyncSampler(threading.Thread):
rollout_provider = _env_runner(
self.async_vector_env, self.extra_batches.put, self.policies,
self.policy_mapping_fn, self.unroll_length, self.horizon,
self._obs_filters, self.clip_rewards, self.pack, self.tf_sess)
self._obs_filters, self.clip_rewards, self.pack, self.callbacks,
self.tf_sess)
while True:
# The timeout variable exists because apparently, if one worker
# dies, the other workers won't die with it, unless the timeout is
@@ -193,6 +198,7 @@ def _env_runner(async_vector_env,
obs_filters,
clip_rewards,
pack,
callbacks,
tf_sess=None):
"""This implements the common experience collection logic.
@@ -211,6 +217,7 @@ def _env_runner(async_vector_env,
clip_rewards (bool): Whether to clip rewards before postprocessing.
pack (bool): Whether to pack multiple episodes into each batch. This
guarantees batches will be exactly `unroll_length` in size.
callbacks (dict): User callbacks to run on episode events.
tf_sess (Session|None): Optional tensorflow session to use for batching
TF policy evaluations.
@@ -239,8 +246,14 @@ def _env_runner(async_vector_env,
return MultiAgentSampleBatchBuilder(policies, clip_rewards)
def new_episode():
return MultiAgentEpisode(policies, policy_mapping_fn,
get_batch_builder, extra_batch_callback)
episode = MultiAgentEpisode(policies, policy_mapping_fn,
get_batch_builder, extra_batch_callback)
if callbacks.get("on_episode_start"):
callbacks["on_episode_start"]({
"env": async_vector_env,
"episode": episode
})
return episode
active_episodes = defaultdict(new_episode)
@@ -270,10 +283,11 @@ def _env_runner(async_vector_env,
atari_metrics = _fetch_atari_metrics(async_vector_env)
if atari_metrics is not None:
for m in atari_metrics:
yield m
yield m._replace(custom_metrics=episode.custom_metrics)
else:
yield RolloutMetrics(episode.length, episode.total_reward,
dict(episode.agent_rewards))
dict(episode.agent_rewards),
episode.custom_metrics)
else:
all_done = False
# At least send an empty dict if not done
@@ -312,6 +326,13 @@ def _env_runner(async_vector_env,
new_obs=filtered_obs,
**episode.last_pi_info_for(agent_id))
# Invoke the step callback after the step is logged to the episode
if callbacks.get("on_episode_step"):
callbacks["on_episode_step"]({
"env": async_vector_env,
"episode": episode
})
# Cut the batch if we're not packing multiple episodes into one,
# or if we've exceeded the requested batch size.
if episode.batch_builder.has_pending_data():
@@ -325,6 +346,11 @@ def _env_runner(async_vector_env,
if all_done:
# Handle episode termination
batch_builder_pool.append(episode.batch_builder)
if callbacks.get("on_episode_end"):
callbacks["on_episode_end"]({
"env": async_vector_env,
"episode": episode
})
del active_episodes[env_id]
resetted_obs = async_vector_env.try_reset(env_id)
if resetted_obs is None:
@@ -429,7 +455,7 @@ def _fetch_atari_metrics(async_vector_env):
if not monitor:
return None
for eps_rew, eps_len in monitor.next_episode_results():
atari_out.append(RolloutMetrics(eps_len, eps_rew, {}))
atari_out.append(RolloutMetrics(eps_len, eps_rew, {}, {}))
return atari_out
@@ -0,0 +1,66 @@
"""Example of using RLlib's debug callbacks.
Here we use callbacks to track the average CartPole pole angle magnitude as a
custom metric.
"""
import argparse
import numpy as np
import ray
from ray import tune
def on_episode_start(info):
episode = info["episode"]
print("episode {} started".format(episode.episode_id))
episode.user_data["pole_angles"] = []
def on_episode_step(info):
episode = info["episode"]
pole_angle = abs(episode.last_observation_for()[2])
episode.user_data["pole_angles"].append(pole_angle)
def on_episode_end(info):
episode = info["episode"]
mean_pole_angle = np.mean(episode.user_data["pole_angles"])
print("episode {} ended with length {} and pole angles {}".format(
episode.episode_id, episode.length, mean_pole_angle))
episode.custom_metrics["mean_pole_angle"] = mean_pole_angle
def on_sample_end(info):
print("returned sample batch of size {}".format(info["samples"].count))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num-iters", type=int, default=2000)
args = parser.parse_args()
ray.init()
trials = tune.run_experiments({
"test": {
"env": "CartPole-v0",
"run": "PG",
"stop": {
"training_iteration": args.num_iters,
},
"config": {
"callbacks": {
"on_episode_start": tune.function(on_episode_start),
"on_episode_step": tune.function(on_episode_step),
"on_episode_end": tune.function(on_episode_end),
"on_sample_end": tune.function(on_sample_end),
},
},
}
})
# verify custom metrics for integration tests
custom_metrics = trials[0].last_result["custom_metrics"]
print(custom_metrics)
assert "mean_pole_angle" in custom_metrics
assert type(custom_metrics["mean_pole_angle"]) is float
@@ -6,6 +6,7 @@ import gym
import numpy as np
import time
import unittest
from collections import Counter
import ray
from ray.rllib.agents.pg import PGAgent
@@ -150,6 +151,26 @@ class TestPolicyEvaluator(unittest.TestCase):
result2 = agent.train()
self.assertLess(result2["info"]["learner"]["cur_lr"], 0.0001)
def testCallbacks(self):
counts = Counter()
pg = PGAgent(
env="CartPole-v0", config={
"num_workers": 0,
"sample_batch_size": 50,
"callbacks": {
"on_episode_start": lambda x: counts.update({"start": 1}),
"on_episode_step": lambda x: counts.update({"step": 1}),
"on_episode_end": lambda x: counts.update({"end": 1}),
"on_sample_end": lambda x: counts.update({"sample": 1}),
},
})
pg.train()
self.assertEqual(counts["sample"], 1)
self.assertGreater(counts["start"], 0)
self.assertGreater(counts["end"], 0)
self.assertGreater(counts["step"], 50)
self.assertLess(counts["step"], 100)
def testQueryEvaluators(self):
register_env("test", lambda _: gym.make("CartPole-v0"))
pg = PGAgent(