mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 13:54:27 +08:00
[rllib] Implement custom metrics (#3144)
This commit is contained in:
@@ -26,8 +26,18 @@ COMMON_CONFIG = {
|
||||
# === Debugging ===
|
||||
# Whether to write episode stats and videos to the agent log dir
|
||||
"monitor": False,
|
||||
# Set the RLlib log level for the agent process and its remote evaluators
|
||||
# Set the ray.rllib.* log level for the agent process and its evaluators
|
||||
"log_level": "INFO",
|
||||
# Callbacks that will be run during various phases of training. These all
|
||||
# take a single "info" dict as an argument. For episode callbacks, custom
|
||||
# metrics can be attached to the episode by updating the episode object's
|
||||
# custom metrics dict (see examples/custom_metrics_and_callbacks.py).
|
||||
"callbacks": {
|
||||
"on_episode_start": None, # arg: {"env": .., "episode": ...}
|
||||
"on_episode_step": None, # arg: {"env": .., "episode": ...}
|
||||
"on_episode_end": None, # arg: {"env": .., "episode": ...}
|
||||
"on_sample_end": None, # arg: {"samples": .., "evaluator": ...}
|
||||
},
|
||||
|
||||
# === Policy ===
|
||||
# Arguments to pass to model. See models/catalog.py for a full list of the
|
||||
@@ -184,7 +194,8 @@ class Agent(Trainable):
|
||||
policy_config=config,
|
||||
worker_index=worker_index,
|
||||
monitor_path=self.logdir if config["monitor"] else None,
|
||||
log_level=config["log_level"])
|
||||
log_level=config["log_level"],
|
||||
callbacks=config["callbacks"])
|
||||
|
||||
@classmethod
|
||||
def resource_help(cls, config):
|
||||
|
||||
@@ -7,13 +7,15 @@ import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.env.async_vector_env import _DUMMY_AGENT_ID
|
||||
|
||||
|
||||
class MultiAgentEpisode(object):
|
||||
"""Tracks the current state of a (possibly multi-agent) episode.
|
||||
|
||||
The APIs in this class should be considered experimental, but we should
|
||||
avoid changing things for the sake of changing them since users may
|
||||
depend on them for advanced algorithms.
|
||||
depend on them for custom metrics or advanced algorithms.
|
||||
|
||||
Attributes:
|
||||
new_batch_builder (func): Create a new MultiAgentSampleBatchBuilder.
|
||||
@@ -23,6 +25,8 @@ class MultiAgentEpisode(object):
|
||||
length (int): Length of this episode.
|
||||
episode_id (int): Unique id identifying this trajectory.
|
||||
agent_rewards (dict): Summed rewards broken down by agent.
|
||||
custom_metrics (dict): Dict where the you can add custom metrics.
|
||||
user_data (dict): Dict that you can use for temporary storage.
|
||||
|
||||
Use case 1: Model-based rollouts in multi-agent:
|
||||
A custom compute_actions() function in a policy graph can inspect the
|
||||
@@ -47,6 +51,8 @@ class MultiAgentEpisode(object):
|
||||
self.length = 0
|
||||
self.episode_id = random.randrange(2e9)
|
||||
self.agent_rewards = defaultdict(float)
|
||||
self.custom_metrics = {}
|
||||
self.user_data = {}
|
||||
self._policies = policies
|
||||
self._policy_mapping_fn = policy_mapping_fn
|
||||
self._agent_to_policy = {}
|
||||
@@ -57,7 +63,7 @@ class MultiAgentEpisode(object):
|
||||
self._agent_to_prev_action = {}
|
||||
self._agent_reward_history = defaultdict(list)
|
||||
|
||||
def policy_for(self, agent_id):
|
||||
def policy_for(self, agent_id=_DUMMY_AGENT_ID):
|
||||
"""Returns the policy graph for the specified agent.
|
||||
|
||||
If the agent is new, the policy mapping fn will be called to bind the
|
||||
@@ -68,12 +74,12 @@ class MultiAgentEpisode(object):
|
||||
self._agent_to_policy[agent_id] = self._policy_mapping_fn(agent_id)
|
||||
return self._agent_to_policy[agent_id]
|
||||
|
||||
def last_observation_for(self, agent_id):
|
||||
def last_observation_for(self, agent_id=_DUMMY_AGENT_ID):
|
||||
"""Returns the last observation for the specified agent."""
|
||||
|
||||
return self._agent_to_last_obs.get(agent_id)
|
||||
|
||||
def last_action_for(self, agent_id):
|
||||
def last_action_for(self, agent_id=_DUMMY_AGENT_ID):
|
||||
"""Returns the last action for the specified agent, or zeros."""
|
||||
|
||||
if agent_id in self._agent_to_last_action:
|
||||
@@ -83,7 +89,7 @@ class MultiAgentEpisode(object):
|
||||
flat = _flatten_action(policy.action_space.sample())
|
||||
return np.zeros_like(flat)
|
||||
|
||||
def prev_action_for(self, agent_id):
|
||||
def prev_action_for(self, agent_id=_DUMMY_AGENT_ID):
|
||||
"""Returns the previous action for the specified agent."""
|
||||
|
||||
if agent_id in self._agent_to_prev_action:
|
||||
@@ -92,7 +98,7 @@ class MultiAgentEpisode(object):
|
||||
# We're at t=0, so return all zeros.
|
||||
return np.zeros_like(self.last_action_for(agent_id))
|
||||
|
||||
def prev_reward_for(self, agent_id):
|
||||
def prev_reward_for(self, agent_id=_DUMMY_AGENT_ID):
|
||||
"""Returns the previous reward for the specified agent."""
|
||||
|
||||
history = self._agent_reward_history[agent_id]
|
||||
@@ -102,7 +108,7 @@ class MultiAgentEpisode(object):
|
||||
# We're at t=0, so there is no previous reward, just return zero.
|
||||
return 0.0
|
||||
|
||||
def rnn_state_for(self, agent_id):
|
||||
def rnn_state_for(self, agent_id=_DUMMY_AGENT_ID):
|
||||
"""Returns the last RNN state for the specified agent."""
|
||||
|
||||
if agent_id not in self._agent_to_rnn_state:
|
||||
@@ -110,7 +116,7 @@ class MultiAgentEpisode(object):
|
||||
self._agent_to_rnn_state[agent_id] = policy.get_initial_state()
|
||||
return self._agent_to_rnn_state[agent_id]
|
||||
|
||||
def last_pi_info_for(self, agent_id):
|
||||
def last_pi_info_for(self, agent_id=_DUMMY_AGENT_ID):
|
||||
"""Returns the last info object for the specified agent."""
|
||||
|
||||
return self._agent_to_last_pi_info[agent_id]
|
||||
|
||||
@@ -59,9 +59,12 @@ def summarize_episodes(episodes, new_episodes, num_dropped):
|
||||
episode_rewards = []
|
||||
episode_lengths = []
|
||||
policy_rewards = collections.defaultdict(list)
|
||||
custom_metrics = collections.defaultdict(list)
|
||||
for episode in episodes:
|
||||
episode_lengths.append(episode.episode_length)
|
||||
episode_rewards.append(episode.episode_reward)
|
||||
for k, v in episode.custom_metrics.items():
|
||||
custom_metrics[k].append(v)
|
||||
for (_, policy_id), reward in episode.agent_rewards.items():
|
||||
if policy_id != DEFAULT_POLICY_ID:
|
||||
policy_rewards[policy_id].append(reward)
|
||||
@@ -77,6 +80,9 @@ def summarize_episodes(episodes, new_episodes, num_dropped):
|
||||
for policy_id, rewards in policy_rewards.copy().items():
|
||||
policy_rewards[policy_id] = np.mean(rewards)
|
||||
|
||||
for k, v_list in custom_metrics.items():
|
||||
custom_metrics[k] = np.mean(v_list)
|
||||
|
||||
return dict(
|
||||
episode_reward_max=max_reward,
|
||||
episode_reward_min=min_reward,
|
||||
@@ -84,4 +90,5 @@ def summarize_episodes(episodes, new_episodes, num_dropped):
|
||||
episode_len_mean=avg_length,
|
||||
episodes_this_iter=len(new_episodes),
|
||||
policy_reward_mean=dict(policy_rewards),
|
||||
custom_metrics=dict(custom_metrics),
|
||||
num_metric_batches_dropped=num_dropped)
|
||||
|
||||
@@ -71,7 +71,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
... policy_mapping_fn=lambda agent_id:
|
||||
... random.choice(["car_policy1", "car_policy2"])
|
||||
... if agent_id.startswith("car_") else "traffic_light_policy")
|
||||
>>> print(evaluator.sample().keys())
|
||||
>>> print(evaluator.sample())
|
||||
MultiAgentBatch({
|
||||
"car_policy1": SampleBatch(...),
|
||||
"car_policy2": SampleBatch(...),
|
||||
@@ -102,7 +102,8 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
policy_config=None,
|
||||
worker_index=0,
|
||||
monitor_path=None,
|
||||
log_level=None):
|
||||
log_level=None,
|
||||
callbacks=None):
|
||||
"""Initialize a policy evaluator.
|
||||
|
||||
Arguments:
|
||||
@@ -162,6 +163,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
monitor_path (str): Write out episode stats and videos to this
|
||||
directory if specified.
|
||||
log_level (str): Set the root log level on creation.
|
||||
callbacks (dict): Dict of custom debug callbacks.
|
||||
"""
|
||||
|
||||
if log_level:
|
||||
@@ -170,6 +172,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
env_context = EnvContext(env_config or {}, worker_index)
|
||||
policy_config = policy_config or {}
|
||||
self.policy_config = policy_config
|
||||
self.callbacks = callbacks or {}
|
||||
model_config = model_config or {}
|
||||
policy_mapping_fn = (policy_mapping_fn
|
||||
or (lambda agent_id: DEFAULT_POLICY_ID))
|
||||
@@ -280,6 +283,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
self.filters,
|
||||
clip_rewards,
|
||||
unroll_length,
|
||||
self.callbacks,
|
||||
horizon=episode_horizon,
|
||||
pack=pack_episodes,
|
||||
tf_sess=self.tf_sess)
|
||||
@@ -292,6 +296,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
self.filters,
|
||||
clip_rewards,
|
||||
unroll_length,
|
||||
self.callbacks,
|
||||
horizon=episode_horizon,
|
||||
pack=pack_episodes,
|
||||
tf_sess=self.tf_sess)
|
||||
@@ -342,6 +347,12 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
batches.extend(self.sampler.get_extra_batches())
|
||||
batch = batches[0].concat_samples(batches)
|
||||
|
||||
if self.callbacks.get("on_sample_end"):
|
||||
self.callbacks["on_sample_end"]({
|
||||
"evaluator": self,
|
||||
"samples": batch
|
||||
})
|
||||
|
||||
if self.compress_observations:
|
||||
if isinstance(batch, MultiAgentBatch):
|
||||
for data in batch.policy_batches.values():
|
||||
|
||||
@@ -20,7 +20,8 @@ from ray.rllib.utils.tf_run_builder import TFRunBuilder
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
RolloutMetrics = namedtuple(
|
||||
"RolloutMetrics", ["episode_length", "episode_reward", "agent_rewards"])
|
||||
"RolloutMetrics",
|
||||
["episode_length", "episode_reward", "agent_rewards", "custom_metrics"])
|
||||
|
||||
PolicyEvalData = namedtuple(
|
||||
"PolicyEvalData",
|
||||
@@ -43,6 +44,7 @@ class SyncSampler(object):
|
||||
obs_filters,
|
||||
clip_rewards,
|
||||
unroll_length,
|
||||
callbacks,
|
||||
horizon=None,
|
||||
pack=False,
|
||||
tf_sess=None):
|
||||
@@ -56,7 +58,7 @@ class SyncSampler(object):
|
||||
self.rollout_provider = _env_runner(
|
||||
self.async_vector_env, self.extra_batches.put, self.policies,
|
||||
self.policy_mapping_fn, self.unroll_length, self.horizon,
|
||||
self._obs_filters, clip_rewards, pack, tf_sess)
|
||||
self._obs_filters, clip_rewards, pack, callbacks, tf_sess)
|
||||
self.metrics_queue = queue.Queue()
|
||||
|
||||
def get_data(self):
|
||||
@@ -99,6 +101,7 @@ class AsyncSampler(threading.Thread):
|
||||
obs_filters,
|
||||
clip_rewards,
|
||||
unroll_length,
|
||||
callbacks,
|
||||
horizon=None,
|
||||
pack=False,
|
||||
tf_sess=None):
|
||||
@@ -119,6 +122,7 @@ class AsyncSampler(threading.Thread):
|
||||
self.daemon = True
|
||||
self.pack = pack
|
||||
self.tf_sess = tf_sess
|
||||
self.callbacks = callbacks
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
@@ -131,7 +135,8 @@ class AsyncSampler(threading.Thread):
|
||||
rollout_provider = _env_runner(
|
||||
self.async_vector_env, self.extra_batches.put, self.policies,
|
||||
self.policy_mapping_fn, self.unroll_length, self.horizon,
|
||||
self._obs_filters, self.clip_rewards, self.pack, self.tf_sess)
|
||||
self._obs_filters, self.clip_rewards, self.pack, self.callbacks,
|
||||
self.tf_sess)
|
||||
while True:
|
||||
# The timeout variable exists because apparently, if one worker
|
||||
# dies, the other workers won't die with it, unless the timeout is
|
||||
@@ -193,6 +198,7 @@ def _env_runner(async_vector_env,
|
||||
obs_filters,
|
||||
clip_rewards,
|
||||
pack,
|
||||
callbacks,
|
||||
tf_sess=None):
|
||||
"""This implements the common experience collection logic.
|
||||
|
||||
@@ -211,6 +217,7 @@ def _env_runner(async_vector_env,
|
||||
clip_rewards (bool): Whether to clip rewards before postprocessing.
|
||||
pack (bool): Whether to pack multiple episodes into each batch. This
|
||||
guarantees batches will be exactly `unroll_length` in size.
|
||||
callbacks (dict): User callbacks to run on episode events.
|
||||
tf_sess (Session|None): Optional tensorflow session to use for batching
|
||||
TF policy evaluations.
|
||||
|
||||
@@ -239,8 +246,14 @@ def _env_runner(async_vector_env,
|
||||
return MultiAgentSampleBatchBuilder(policies, clip_rewards)
|
||||
|
||||
def new_episode():
|
||||
return MultiAgentEpisode(policies, policy_mapping_fn,
|
||||
get_batch_builder, extra_batch_callback)
|
||||
episode = MultiAgentEpisode(policies, policy_mapping_fn,
|
||||
get_batch_builder, extra_batch_callback)
|
||||
if callbacks.get("on_episode_start"):
|
||||
callbacks["on_episode_start"]({
|
||||
"env": async_vector_env,
|
||||
"episode": episode
|
||||
})
|
||||
return episode
|
||||
|
||||
active_episodes = defaultdict(new_episode)
|
||||
|
||||
@@ -270,10 +283,11 @@ def _env_runner(async_vector_env,
|
||||
atari_metrics = _fetch_atari_metrics(async_vector_env)
|
||||
if atari_metrics is not None:
|
||||
for m in atari_metrics:
|
||||
yield m
|
||||
yield m._replace(custom_metrics=episode.custom_metrics)
|
||||
else:
|
||||
yield RolloutMetrics(episode.length, episode.total_reward,
|
||||
dict(episode.agent_rewards))
|
||||
dict(episode.agent_rewards),
|
||||
episode.custom_metrics)
|
||||
else:
|
||||
all_done = False
|
||||
# At least send an empty dict if not done
|
||||
@@ -312,6 +326,13 @@ def _env_runner(async_vector_env,
|
||||
new_obs=filtered_obs,
|
||||
**episode.last_pi_info_for(agent_id))
|
||||
|
||||
# Invoke the step callback after the step is logged to the episode
|
||||
if callbacks.get("on_episode_step"):
|
||||
callbacks["on_episode_step"]({
|
||||
"env": async_vector_env,
|
||||
"episode": episode
|
||||
})
|
||||
|
||||
# Cut the batch if we're not packing multiple episodes into one,
|
||||
# or if we've exceeded the requested batch size.
|
||||
if episode.batch_builder.has_pending_data():
|
||||
@@ -325,6 +346,11 @@ def _env_runner(async_vector_env,
|
||||
if all_done:
|
||||
# Handle episode termination
|
||||
batch_builder_pool.append(episode.batch_builder)
|
||||
if callbacks.get("on_episode_end"):
|
||||
callbacks["on_episode_end"]({
|
||||
"env": async_vector_env,
|
||||
"episode": episode
|
||||
})
|
||||
del active_episodes[env_id]
|
||||
resetted_obs = async_vector_env.try_reset(env_id)
|
||||
if resetted_obs is None:
|
||||
@@ -429,7 +455,7 @@ def _fetch_atari_metrics(async_vector_env):
|
||||
if not monitor:
|
||||
return None
|
||||
for eps_rew, eps_len in monitor.next_episode_results():
|
||||
atari_out.append(RolloutMetrics(eps_len, eps_rew, {}))
|
||||
atari_out.append(RolloutMetrics(eps_len, eps_rew, {}, {}))
|
||||
return atari_out
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
"""Example of using RLlib's debug callbacks.
|
||||
|
||||
Here we use callbacks to track the average CartPole pole angle magnitude as a
|
||||
custom metric.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
|
||||
|
||||
def on_episode_start(info):
|
||||
episode = info["episode"]
|
||||
print("episode {} started".format(episode.episode_id))
|
||||
episode.user_data["pole_angles"] = []
|
||||
|
||||
|
||||
def on_episode_step(info):
|
||||
episode = info["episode"]
|
||||
pole_angle = abs(episode.last_observation_for()[2])
|
||||
episode.user_data["pole_angles"].append(pole_angle)
|
||||
|
||||
|
||||
def on_episode_end(info):
|
||||
episode = info["episode"]
|
||||
mean_pole_angle = np.mean(episode.user_data["pole_angles"])
|
||||
print("episode {} ended with length {} and pole angles {}".format(
|
||||
episode.episode_id, episode.length, mean_pole_angle))
|
||||
episode.custom_metrics["mean_pole_angle"] = mean_pole_angle
|
||||
|
||||
|
||||
def on_sample_end(info):
|
||||
print("returned sample batch of size {}".format(info["samples"].count))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--num-iters", type=int, default=2000)
|
||||
args = parser.parse_args()
|
||||
|
||||
ray.init()
|
||||
trials = tune.run_experiments({
|
||||
"test": {
|
||||
"env": "CartPole-v0",
|
||||
"run": "PG",
|
||||
"stop": {
|
||||
"training_iteration": args.num_iters,
|
||||
},
|
||||
"config": {
|
||||
"callbacks": {
|
||||
"on_episode_start": tune.function(on_episode_start),
|
||||
"on_episode_step": tune.function(on_episode_step),
|
||||
"on_episode_end": tune.function(on_episode_end),
|
||||
"on_sample_end": tune.function(on_sample_end),
|
||||
},
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
# verify custom metrics for integration tests
|
||||
custom_metrics = trials[0].last_result["custom_metrics"]
|
||||
print(custom_metrics)
|
||||
assert "mean_pole_angle" in custom_metrics
|
||||
assert type(custom_metrics["mean_pole_angle"]) is float
|
||||
@@ -6,6 +6,7 @@ import gym
|
||||
import numpy as np
|
||||
import time
|
||||
import unittest
|
||||
from collections import Counter
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.pg import PGAgent
|
||||
@@ -150,6 +151,26 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
result2 = agent.train()
|
||||
self.assertLess(result2["info"]["learner"]["cur_lr"], 0.0001)
|
||||
|
||||
def testCallbacks(self):
|
||||
counts = Counter()
|
||||
pg = PGAgent(
|
||||
env="CartPole-v0", config={
|
||||
"num_workers": 0,
|
||||
"sample_batch_size": 50,
|
||||
"callbacks": {
|
||||
"on_episode_start": lambda x: counts.update({"start": 1}),
|
||||
"on_episode_step": lambda x: counts.update({"step": 1}),
|
||||
"on_episode_end": lambda x: counts.update({"end": 1}),
|
||||
"on_sample_end": lambda x: counts.update({"sample": 1}),
|
||||
},
|
||||
})
|
||||
pg.train()
|
||||
self.assertEqual(counts["sample"], 1)
|
||||
self.assertGreater(counts["start"], 0)
|
||||
self.assertGreater(counts["end"], 0)
|
||||
self.assertGreater(counts["step"], 50)
|
||||
self.assertLess(counts["step"], 100)
|
||||
|
||||
def testQueryEvaluators(self):
|
||||
register_env("test", lambda _: gym.make("CartPole-v0"))
|
||||
pg = PGAgent(
|
||||
|
||||
Reference in New Issue
Block a user