[rllib] Print out intermediate data shapes on the first iteration (#4426)

This commit is contained in:
Eric Liang
2019-03-26 00:27:59 -07:00
committed by GitHub
parent 8ee240f40e
commit cff08e19ff
16 changed files with 236 additions and 34 deletions
+5 -2
View File
@@ -41,7 +41,10 @@ COMMON_CONFIG = {
# === Debugging ===
# Whether to write episode stats and videos to the agent log dir
"monitor": False,
# Set the ray.rllib.* log level for the agent process and its evaluators
# Set the ray.rllib.* log level for the agent process and its evaluators.
# Should be one of DEBUG, INFO, WARN, or ERROR. The DEBUG level will also
# periodically print out summaries of relevant internal dataflow (this is
# also printed out once at startup at the INFO level).
"log_level": "INFO",
# Callbacks that will be run during various phases of training. These all
# take a single "info" dict as an argument. For episode callbacks, custom
@@ -407,7 +410,7 @@ class Agent(Trainable):
prev_action=None,
prev_reward=None,
info=None,
policy_id="default"):
policy_id=DEFAULT_POLICY_ID):
"""Computes an action for the specified policy.
Note that you can also access the policy object through
+4 -3
View File
@@ -17,6 +17,7 @@ from ray.rllib.agents import Agent, with_common_config
from ray.rllib.agents.ars import optimizers
from ray.rllib.agents.ars import policies
from ray.rllib.agents.ars import utils
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import override
from ray.rllib.utils import FilterManager
@@ -87,7 +88,7 @@ class Worker(object):
@property
def filters(self):
return {"default": self.policy.get_filter()}
return {DEFAULT_POLICY_ID: self.policy.get_filter()}
def sync_filters(self, new_filters):
for k in self.filters:
@@ -271,7 +272,7 @@ class ARSAgent(Agent):
# Now sync the filters
FilterManager.synchronize({
"default": self.policy.get_filter()
DEFAULT_POLICY_ID: self.policy.get_filter()
}, self.workers)
info = {
@@ -335,5 +336,5 @@ class ARSAgent(Agent):
self.policy.set_weights(state["weights"])
self.policy.set_filter(state["filter"])
FilterManager.synchronize({
"default": self.policy.get_filter()
DEFAULT_POLICY_ID: self.policy.get_filter()
}, self.workers)
+2 -1
View File
@@ -10,6 +10,7 @@ from ray.rllib import optimizers
from ray.rllib.agents.agent import Agent, with_common_config
from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import override
from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule
@@ -195,7 +196,7 @@ class DQNAgent(Agent):
policies = info["policy"]
episode = info["episode"]
episode.custom_metrics["policy_distance"] = policies[
"default"].pi_distance
DEFAULT_POLICY_ID].pi_distance
if end_callback:
end_callback(info)
+4 -3
View File
@@ -16,6 +16,7 @@ from ray.rllib.agents import Agent, with_common_config
from ray.rllib.agents.es import optimizers
from ray.rllib.agents.es import policies
from ray.rllib.agents.es import utils
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import override
from ray.rllib.utils import FilterManager
@@ -91,7 +92,7 @@ class Worker(object):
@property
def filters(self):
return {"default": self.policy.get_filter()}
return {DEFAULT_POLICY_ID: self.policy.get_filter()}
def sync_filters(self, new_filters):
for k in self.filters:
@@ -268,7 +269,7 @@ class ESAgent(Agent):
# Now sync the filters
FilterManager.synchronize({
"default": self.policy.get_filter()
DEFAULT_POLICY_ID: self.policy.get_filter()
}, self.workers)
info = {
@@ -332,5 +333,5 @@ class ESAgent(Agent):
self.policy.set_weights(state["weights"])
self.policy.set_filter(state["filter"])
FilterManager.synchronize({
"default": self.policy.get_filter()
DEFAULT_POLICY_ID: self.policy.get_filter()
}, self.workers)
+1 -1
View File
@@ -186,7 +186,7 @@ class BaseEnv(object):
# Fixed agent identifier when there is only the single agent in the env
_DUMMY_AGENT_ID = "single_agent"
_DUMMY_AGENT_ID = "singleton_agent"
def _with_dummy_agent_id(env_id_to_values, dummy_id=_DUMMY_AGENT_ID):
@@ -28,6 +28,8 @@ from ray.rllib.models.preprocessors import NoPreprocessor
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.compression import pack
from ray.rllib.utils.debug import disable_log_once_globally, log_once, \
summarize, enable_periodic_logging
from ray.rllib.utils.filter import get_filter
from ray.rllib.utils.tf_run_builder import TFRunBuilder
@@ -209,10 +211,16 @@ class PolicyEvaluator(EvaluatorInterface):
if log_level:
logging.getLogger("ray.rllib").setLevel(log_level)
if worker_index > 1:
disable_log_once_globally() # only need 1 evaluator to log
elif log_level == "DEBUG":
enable_periodic_logging()
env_context = EnvContext(env_config or {}, worker_index)
policy_config = policy_config or {}
self.policy_config = policy_config
self.callbacks = callbacks or {}
self.worker_index = worker_index
model_config = model_config or {}
policy_mapping_fn = (policy_mapping_fn
or (lambda agent_id: DEFAULT_POLICY_ID))
@@ -304,6 +312,8 @@ class PolicyEvaluator(EvaluatorInterface):
policy.observation_space.shape)
for (policy_id, policy) in self.policy_map.items()
}
if self.worker_index == 0:
logger.info("Built filter map: {}".format(self.filters))
# Always use vector env for consistency even if num_envs = 1
self.async_env = BaseEnv.to_base_env(
@@ -390,6 +400,10 @@ class PolicyEvaluator(EvaluatorInterface):
SampleBatch|MultiAgentBatch from evaluating the current policies.
"""
if log_once("sample_start"):
logger.info("Generating sample batch of size {}".format(
self.sample_batch_size))
batches = [self.input_reader.next()]
steps_so_far = batches[0].count
@@ -423,6 +437,10 @@ class PolicyEvaluator(EvaluatorInterface):
for estimator in self.reward_estimators:
estimator.process(sub_batch)
if log_once("sample_end"):
logger.info("Completed sample batch:\n\n{}\n".format(
summarize(batch)))
if self.compress_observations:
if isinstance(batch, MultiAgentBatch):
for data in batch.policy_batches.values():
@@ -457,6 +475,9 @@ class PolicyEvaluator(EvaluatorInterface):
@override(EvaluatorInterface)
def compute_gradients(self, samples):
if log_once("compute_gradients"):
logger.info("Compute gradients on:\n\n{}\n".format(
summarize(samples)))
if isinstance(samples, MultiAgentBatch):
grad_out, info_out = {}, {}
if self.tf_sess is not None:
@@ -479,10 +500,15 @@ class PolicyEvaluator(EvaluatorInterface):
grad_out, info_out = (
self.policy_map[DEFAULT_POLICY_ID].compute_gradients(samples))
info_out["batch_count"] = samples.count
if log_once("grad_out"):
logger.info("Compute grad info:\n\n{}\n".format(
summarize(info_out)))
return grad_out, info_out
@override(EvaluatorInterface)
def apply_gradients(self, grads):
if log_once("apply_gradients"):
logger.info("Apply gradients:\n\n{}\n".format(summarize(grads)))
if isinstance(grads, dict):
if self.tf_sess is not None:
builder = TFRunBuilder(self.tf_sess, "apply_gradients")
@@ -502,6 +528,10 @@ class PolicyEvaluator(EvaluatorInterface):
@override(EvaluatorInterface)
def learn_on_batch(self, samples):
if log_once("learn_on_batch"):
logger.info(
"Training on concatenated sample batches:\n\n{}\n".format(
summarize(samples)))
if isinstance(samples, MultiAgentBatch):
info_out = {}
if self.tf_sess is not None:
@@ -519,11 +549,12 @@ class PolicyEvaluator(EvaluatorInterface):
continue
info_out[pid], _ = (
self.policy_map[pid].learn_on_batch(batch))
return info_out
else:
grad_fetch, apply_fetch = (
info_out, _ = (
self.policy_map[DEFAULT_POLICY_ID].learn_on_batch(samples))
return grad_fetch
if log_once("learn_out"):
logger.info("Training output:\n\n{}\n".format(summarize(info_out)))
return info_out
@DeveloperAPI
def get_metrics(self):
@@ -659,6 +690,9 @@ class PolicyEvaluator(EvaluatorInterface):
"Tuple|DictFlatteningPreprocessor.")
with tf.variable_scope(name):
policy_map[name] = cls(obs_space, act_space, merged_conf)
if self.worker_index == 0:
logger.info("Built policy map: {}".format(policy_map))
logger.info("Built preprocessor map: {}".format(preprocessors))
return policy_map, preprocessors
def __del__(self):
+1 -1
View File
@@ -10,7 +10,7 @@ from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.memory import concat_aligned
# Default policy id for single agent environments
DEFAULT_POLICY_ID = "default"
DEFAULT_POLICY_ID = "default_policy"
@PublicAPI
@@ -3,10 +3,14 @@ from __future__ import division
from __future__ import print_function
import collections
import logging
import numpy as np
from ray.rllib.evaluation.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI
from ray.rllib.utils.debug import log_once, summarize
logger = logging.getLogger(__name__)
def to_float_array(v):
@@ -145,10 +149,16 @@ class MultiAgentSampleBatchBuilder(object):
post_batches[agent_id] = policy.postprocess_trajectory(
pre_batch, other_batches, episode)
if log_once("after_post"):
logger.info(
"Trajectory fragment after postprocess_trajectory():\n\n{}\n".
format(summarize(post_batches)))
# Append into policy batches and reset
for agent_id, post_batch in sorted(post_batches.items()):
self.policy_builders[self.agent_to_policy[agent_id]].add_batch(
post_batch)
self.agent_builders.clear()
self.agent_to_policy.clear()
+23 -5
View File
@@ -18,10 +18,10 @@ from ray.rllib.env.atari_wrappers import get_wrapper_by_cls, MonitorEnv
from ray.rllib.models.action_dist import TupleActions
from ray.rllib.offline import InputReader
from ray.rllib.utils.annotations import override
from ray.rllib.utils.debug import log_once, summarize
from ray.rllib.utils.tf_run_builder import TFRunBuilder
logger = logging.getLogger(__name__)
_large_batch_warned = False
RolloutMetrics = namedtuple(
"RolloutMetrics",
@@ -303,6 +303,11 @@ def _env_runner(base_env,
unfiltered_obs, rewards, dones, infos, off_policy_actions = \
base_env.poll()
if log_once("env_returns"):
logger.info("Raw obs from env: {}".format(
summarize(unfiltered_obs)))
logger.info("Info return from env: {}".format(summarize(infos)))
# Process observations and prepare for policy evaluation
active_envs, to_eval, outputs = _process_observations(
base_env, policies, batch_builder_pool, active_episodes,
@@ -350,10 +355,8 @@ def _process_observations(base_env, policies, batch_builder_pool,
episode.batch_builder.count += 1
episode._add_agent_rewards(rewards[env_id])
global _large_batch_warned
if (not _large_batch_warned and
episode.batch_builder.total() > max(1000, unroll_length * 10)):
_large_batch_warned = True
if (episode.batch_builder.total() > max(1000, unroll_length * 10)
and log_once("large_batch_warning")):
logger.warning(
"More than {} observations for {} env steps ".format(
episode.batch_builder.total(),
@@ -387,7 +390,13 @@ def _process_observations(base_env, policies, batch_builder_pool,
policy_id = episode.policy_for(agent_id)
prep_obs = _get_or_raise(preprocessors,
policy_id).transform(raw_obs)
if log_once("prep_obs"):
logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))
filtered_obs = _get_or_raise(obs_filters, policy_id)(prep_obs)
if log_once("filtered_obs"):
logger.info("Filtered obs: {}".format(summarize(filtered_obs)))
agent_done = bool(all_done or dones[env_id].get(agent_id))
if not agent_done:
to_eval[policy_id].append(
@@ -491,6 +500,11 @@ def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
pending_fetches = {}
else:
builder = None
if log_once("compute_actions_input"):
logger.info("Example compute_actions() input:\n\n{}\n".format(
summarize(to_eval)))
for policy_id, eval_data in to_eval.items():
rnn_in_cols = _to_column_format([t.rnn_state for t in eval_data])
policy = _get_or_raise(policies, policy_id)
@@ -514,6 +528,10 @@ def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
for k, v in pending_fetches.items():
eval_results[k] = builder.get(v)
if log_once("compute_actions_result"):
logger.info("Example compute_actions() result:\n\n{}\n".format(
summarize(eval_results)))
return eval_results
@@ -13,6 +13,7 @@ import ray.experimental.tf_utils
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.models.lstm import chop_into_sequences
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.debug import log_once, summarize
from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
from ray.rllib.utils.tf_run_builder import TFRunBuilder
@@ -466,6 +467,15 @@ class TFPolicyGraph(PolicyGraph):
for k, v in zip(state_keys, initial_states):
feed_dict[self._loss_input_dict[k]] = v
feed_dict[self._seq_lens] = seq_lens
if log_once("rnn_feed_dict"):
logger.info("Padded input for RNN:\n\n{}\n".format(
summarize({
"features": feature_sequences,
"initial_states": initial_states,
"seq_lens": seq_lens,
"max_seq_len": max_seq_len,
})))
return feed_dict
@@ -15,6 +15,7 @@ import threading
from six.moves import queue
import ray
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.utils.actors import TaskPool
@@ -321,9 +322,9 @@ class TFMultiGPULearner(LearnerThread):
assert self.train_batch_size % len(self.devices) == 0
assert self.train_batch_size >= len(self.devices), "batch too small"
if set(self.local_evaluator.policy_map.keys()) != {"default"}:
if set(self.local_evaluator.policy_map.keys()) != {DEFAULT_POLICY_ID}:
raise NotImplementedError("Multi-gpu mode for multi-agent")
self.policy = self.local_evaluator.policy_map["default"]
self.policy = self.local_evaluator.policy_map[DEFAULT_POLICY_ID]
# per-GPU graph copies created below must share vars with the policy
# reuse is set to AUTO_REUSE because Adam nodes are created after
@@ -331,7 +332,7 @@ class TFMultiGPULearner(LearnerThread):
self.par_opt = []
with self.local_evaluator.tf_sess.graph.as_default():
with self.local_evaluator.tf_sess.as_default():
with tf.variable_scope("default", reuse=tf.AUTO_REUSE):
with tf.variable_scope(DEFAULT_POLICY_ID, reuse=tf.AUTO_REUSE):
if self.policy._state_inputs:
rnn_inputs = self.policy._state_inputs + [
self.policy._seq_lens
+11 -1
View File
@@ -4,9 +4,10 @@ from __future__ import print_function
from collections import namedtuple
import logging
import tensorflow as tf
from ray.rllib.utils.debug import log_once, summarize
# Variable scope under which created variables will be placed
TOWER_SCOPE_NAME = "tower"
@@ -134,6 +135,15 @@ class LocalSyncParallelOptimizer(object):
The number of tuples loaded per device.
"""
if log_once("load_data"):
logger.info(
"Training on concatenated sample batches:\n\n{}\n".format(
summarize({
"placeholders": self.loss_inputs,
"inputs": inputs,
"state_inputs": state_inputs
})))
feed_dict = {}
assert len(self.loss_inputs) == len(inputs + state_inputs), \
(self.loss_inputs, inputs, state_inputs)
+3 -2
View File
@@ -12,6 +12,7 @@ import pickle
import gym
import ray
from ray.rllib.agents.registry import get_agent_class
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.evaluation.sampler import clip_action
from ray.tune.util import merge_dicts
@@ -116,7 +117,7 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True):
else:
env = gym.make(env_name)
multiagent = False
use_lstm = {'default': False}
use_lstm = {DEFAULT_POLICY_ID: False}
if out is not None:
rollouts = []
@@ -148,7 +149,7 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True):
action_dict[agent_id] = a_action
action = action_dict
else:
if use_lstm["default"]:
if use_lstm[DEFAULT_POLICY_ID]:
action, state_init, _ = agent.compute_action(
state, state=state_init)
else:
@@ -16,6 +16,7 @@ from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.env.vector_env import VectorEnv
from ray.tune.registry import register_env
@@ -367,7 +368,7 @@ class TestPolicyEvaluator(unittest.TestCase):
time.sleep(2)
ev.sample()
filters = ev.get_filters(flush_after=True)
obs_f = filters["default"]
obs_f = filters[DEFAULT_POLICY_ID]
self.assertNotEqual(obs_f.rs.n, 0)
self.assertNotEqual(obs_f.buffer.n, 0)
@@ -381,8 +382,8 @@ class TestPolicyEvaluator(unittest.TestCase):
filters = ev.get_filters(flush_after=False)
time.sleep(2)
filters2 = ev.get_filters(flush_after=False)
obs_f = filters["default"]
obs_f2 = filters2["default"]
obs_f = filters[DEFAULT_POLICY_ID]
obs_f2 = filters2[DEFAULT_POLICY_ID]
self.assertGreaterEqual(obs_f2.rs.n, obs_f.rs.n)
self.assertGreaterEqual(obs_f2.buffer.n, obs_f.buffer.n)
@@ -396,15 +397,15 @@ class TestPolicyEvaluator(unittest.TestCase):
# Current State
filters = ev.get_filters(flush_after=False)
obs_f = filters["default"]
obs_f = filters[DEFAULT_POLICY_ID]
self.assertLessEqual(obs_f.buffer.n, 20)
new_obsf = obs_f.copy()
new_obsf.rs._n = 100
ev.sync_filters({"default": new_obsf})
ev.sync_filters({DEFAULT_POLICY_ID: new_obsf})
filters = ev.get_filters(flush_after=False)
obs_f = filters["default"]
obs_f = filters[DEFAULT_POLICY_ID]
self.assertGreaterEqual(obs_f.rs.n, 100)
self.assertLessEqual(obs_f.buffer.n, 20)
@@ -412,7 +413,7 @@ class TestPolicyEvaluator(unittest.TestCase):
time.sleep(2)
ev.sample()
filters = ev.get_filters(flush_after=True)
obs_f = filters["default"]
obs_f = filters[DEFAULT_POLICY_ID]
self.assertNotEqual(obs_f.rs.n, 0)
self.assertNotEqual(obs_f.buffer.n, 0)
return obs_f
+111
View File
@@ -0,0 +1,111 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import pprint
import time
from ray.rllib.evaluation.sample_batch import SampleBatch, MultiAgentBatch
# Keys for which log_once() has already returned True in the current period.
_logged = set()
# When True, log_once() returns False unconditionally in this process.
_disabled = False
# Periodic re-logging is opt-in: it is only meant to be switched on through
# enable_periodic_logging(), so it has to start out disabled. Starting at True
# would make that function a no-op and re-log every minute at any log level.
_periodic_log = False
# time.time() of the most recent first-time log or periodic reset.
_last_logged = 0.0
# Shared pretty-printer used to format summaries.
_printer = pprint.PrettyPrinter(indent=2, width=60)


def log_once(key):
    """Returns True if this is the "first" call for a given key.

    Various logging settings can adjust the definition of "first":

    * If logging has been disabled for this process, the result is always
      False.
    * If periodic logging is enabled and more than 60 seconds have passed
      since the last first-time log, the set of seen keys is cleared. That
      call itself still returns False; the next call for any key is then
      treated as a first call again.

    Example:
        >>> if log_once("some_key"):
        ...     logger.info("Some verbose logging statement")

    Args:
        key: Hashable identifier of the log statement.

    Returns:
        bool: True if the caller should emit its log statement.
    """
    global _last_logged

    if _disabled:
        return False
    elif key not in _logged:
        _logged.add(key)
        _last_logged = time.time()
        return True
    elif _periodic_log and time.time() - _last_logged > 60.0:
        # Start a new logging period; keys are re-logged on their next call.
        _logged.clear()
        _last_logged = time.time()
        return False
    else:
        return False
def disable_log_once_globally():
    """Silence log_once() for the rest of this process.

    Sets the module-level ``_disabled`` flag to True; log_once() checks that
    flag first and returns False whenever it is set.
    """
    globals()["_disabled"] = True
def enable_periodic_logging():
    """Turn on periodic re-logging for log_once() in this process.

    Sets the module-level ``_periodic_log`` flag to True, which lets
    log_once() forget the keys it has seen once enough time has passed.
    """
    globals()["_periodic_log"] = True
def summarize(obj):
    """Build a pretty-printed, human-readable description of an object.

    The object is first reduced by _summarize(), which has special handling
    for data types commonly used in RLlib (SampleBatch, numpy arrays, etc.),
    and the reduced structure is then formatted by the module's shared
    pretty-printer.

    Args:
        obj: The object to describe.

    Returns:
        str: The formatted summary.
    """
    condensed = _summarize(obj)
    return _printer.pformat(condensed)
def _summarize(obj):
    """Recursively reduce obj to a compact structure suited to printing.

    Containers (dict, list, tuple, objects exposing ``_asdict``, SampleBatch,
    MultiAgentBatch) are traversed recursively. Numpy arrays are replaced by
    a one-line description (shape, dtype and either the first element or
    min/max/mean statistics). Anything else is returned unchanged.

    Args:
        obj: The object to reduce.

    Returns:
        A structure of dicts/lists/tuples whose leaves are either the
        original non-container values or _StringValue descriptions.
    """
    if isinstance(obj, dict):
        return {k: _summarize(v) for k, v in obj.items()}
    elif hasattr(obj, "_asdict"):
        # Checked before the plain tuple branch so that objects exposing
        # _asdict() (e.g. namedtuples) keep their field names.
        return {
            "type": obj.__class__.__name__,
            "data": _summarize(obj._asdict()),
        }
    elif isinstance(obj, list):
        return [_summarize(x) for x in obj]
    elif isinstance(obj, tuple):
        return tuple(_summarize(x) for x in obj)
    elif isinstance(obj, np.ndarray):
        if obj.size == 0:
            # Neither a head element nor min/max/mean exist for an empty
            # array; report only shape and dtype instead of raising.
            return _StringValue("np.ndarray({}, dtype={})".format(
                obj.shape, obj.dtype))
        elif obj.dtype == object:
            # Compare against the builtin `object`: the `np.object` alias
            # was deprecated in NumPy 1.20 and removed in 1.24.
            # A 0-d array cannot be indexed, so unwrap it with item().
            head = obj[0] if obj.ndim else obj.item()
            return _StringValue("np.ndarray({}, dtype={}, head={})".format(
                obj.shape, obj.dtype, _summarize(head)))
        else:
            return _StringValue(
                "np.ndarray({}, dtype={}, min={}, max={}, mean={})".format(
                    obj.shape, obj.dtype, round(float(np.min(obj)), 3),
                    round(float(np.max(obj)), 3), round(
                        float(np.mean(obj)), 3)))
    elif isinstance(obj, MultiAgentBatch):
        return {
            "type": "MultiAgentBatch",
            "policy_batches": _summarize(obj.policy_batches),
            "count": obj.count,
        }
    elif isinstance(obj, SampleBatch):
        return {
            "type": "SampleBatch",
            "data": {k: _summarize(v)
                     for k, v in obj.items()},
        }
    else:
        return obj
class _StringValue(object):
    """Holder for a string whose repr() is the string itself.

    _summarize() wraps its numpy array descriptions in this class; because
    __repr__ returns the stored text verbatim, the description is not
    rendered as a quoted string literal when the summary is formatted.
    """

    def __init__(self, value):
        # value: the text that __repr__ returns unchanged.
        self.value = value

    def __repr__(self):
        return self.value