From cff08e19ff1606ef6e718624703e8e0da19b223d Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 26 Mar 2019 00:27:59 -0700 Subject: [PATCH] [rllib] Print out intermediate data shapes on the first iteration (#4426) --- doc/source/rllib-training.rst | 4 +- python/ray/rllib/agents/agent.py | 7 +- python/ray/rllib/agents/ars/ars.py | 7 +- python/ray/rllib/agents/dqn/dqn.py | 3 +- python/ray/rllib/agents/es/es.py | 7 +- python/ray/rllib/env/base_env.py | 2 +- .../ray/rllib/evaluation/policy_evaluator.py | 40 ++++++- python/ray/rllib/evaluation/sample_batch.py | 2 +- .../rllib/evaluation/sample_batch_builder.py | 10 ++ python/ray/rllib/evaluation/sampler.py | 28 ++++- .../ray/rllib/evaluation/tf_policy_graph.py | 10 ++ .../optimizers/async_samples_optimizer.py | 7 +- python/ray/rllib/optimizers/multi_gpu_impl.py | 12 +- python/ray/rllib/rollout.py | 5 +- .../ray/rllib/tests/test_policy_evaluator.py | 15 +-- python/ray/rllib/utils/debug.py | 111 ++++++++++++++++++ 16 files changed, 236 insertions(+), 34 deletions(-) create mode 100644 python/ray/rllib/utils/debug.py diff --git a/doc/source/rllib-training.rst b/doc/source/rllib-training.rst index a95ffc3c5..c192b674f 100644 --- a/doc/source/rllib-training.rst +++ b/doc/source/rllib-training.rst @@ -184,7 +184,7 @@ Accessing Policy State ~~~~~~~~~~~~~~~~~~~~~~ It is common to need to access an agent's internal state, e.g., to set or get internal weights. In RLlib an agent's state is replicated across multiple *policy evaluators* (Ray actors) in the cluster. However, you can easily get and update this state between calls to ``train()`` via ``agent.optimizer.foreach_evaluator()`` or ``agent.optimizer.foreach_evaluator_with_index()``. These functions take a lambda function that is applied with the evaluator as an arg. You can also return values from these functions and those will be returned as a list. 
-You can also access just the "master" copy of the agent state through ``agent.get_policy()`` or ``agent.local_evaluator``, but note that updates here may not be immediately reflected in remote replicas if you have configured ``num_workers > 0``. For example, to access the weights of a local TF policy, you can run ``agent.get_policy().get_weights()``. This is also equivalent to ``agent.local_evaluator.policy_map["default"].get_weights()``: +You can also access just the "master" copy of the agent state through ``agent.get_policy()`` or ``agent.local_evaluator``, but note that updates here may not be immediately reflected in remote replicas if you have configured ``num_workers > 0``. For example, to access the weights of a local TF policy, you can run ``agent.get_policy().get_weights()``. This is also equivalent to ``agent.local_evaluator.policy_map["default_policy"].get_weights()``: .. code-block:: python @@ -192,7 +192,7 @@ You can also access just the "master" copy of the agent state through ``agent.ge agent.get_policy().get_weights() # Same as above - agent.local_evaluator.policy_map["default"].get_weights() + agent.local_evaluator.policy_map["default_policy"].get_weights() # Get list of weights of each evaluator, including remote replicas agent.optimizer.foreach_evaluator(lambda ev: ev.get_policy().get_weights()) diff --git a/python/ray/rllib/agents/agent.py b/python/ray/rllib/agents/agent.py index 596e4df7a..31fc09bb9 100644 --- a/python/ray/rllib/agents/agent.py +++ b/python/ray/rllib/agents/agent.py @@ -41,7 +41,10 @@ COMMON_CONFIG = { # === Debugging === # Whether to write episode stats and videos to the agent log dir "monitor": False, - # Set the ray.rllib.* log level for the agent process and its evaluators + # Set the ray.rllib.* log level for the agent process and its evaluators. + # Should be one of DEBUG, INFO, WARN, or ERROR. 
The DEBUG level will also + # periodically print out summaries of relevant internal dataflow (this is + # also printed out once at startup at the INFO level). "log_level": "INFO", # Callbacks that will be run during various phases of training. These all # take a single "info" dict as an argument. For episode callbacks, custom @@ -407,7 +410,7 @@ class Agent(Trainable): prev_action=None, prev_reward=None, info=None, - policy_id="default"): + policy_id=DEFAULT_POLICY_ID): """Computes an action for the specified policy. Note that you can also access the policy object through diff --git a/python/ray/rllib/agents/ars/ars.py b/python/ray/rllib/agents/ars/ars.py index aafcee7f4..16416e46a 100644 --- a/python/ray/rllib/agents/ars/ars.py +++ b/python/ray/rllib/agents/ars/ars.py @@ -17,6 +17,7 @@ from ray.rllib.agents import Agent, with_common_config from ray.rllib.agents.ars import optimizers from ray.rllib.agents.ars import policies from ray.rllib.agents.ars import utils +from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.annotations import override from ray.rllib.utils import FilterManager @@ -87,7 +88,7 @@ class Worker(object): @property def filters(self): - return {"default": self.policy.get_filter()} + return {DEFAULT_POLICY_ID: self.policy.get_filter()} def sync_filters(self, new_filters): for k in self.filters: @@ -271,7 +272,7 @@ class ARSAgent(Agent): # Now sync the filters FilterManager.synchronize({ - "default": self.policy.get_filter() + DEFAULT_POLICY_ID: self.policy.get_filter() }, self.workers) info = { @@ -335,5 +336,5 @@ class ARSAgent(Agent): self.policy.set_weights(state["weights"]) self.policy.set_filter(state["filter"]) FilterManager.synchronize({ - "default": self.policy.get_filter() + DEFAULT_POLICY_ID: self.policy.get_filter() }, self.workers) diff --git a/python/ray/rllib/agents/dqn/dqn.py b/python/ray/rllib/agents/dqn/dqn.py index 56c064b2d..5b1bb3597 100644 --- a/python/ray/rllib/agents/dqn/dqn.py +++ 
b/python/ray/rllib/agents/dqn/dqn.py @@ -10,6 +10,7 @@ from ray.rllib import optimizers from ray.rllib.agents.agent import Agent, with_common_config from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph from ray.rllib.evaluation.metrics import collect_metrics +from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.annotations import override from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule @@ -195,7 +196,7 @@ class DQNAgent(Agent): policies = info["policy"] episode = info["episode"] episode.custom_metrics["policy_distance"] = policies[ - "default"].pi_distance + DEFAULT_POLICY_ID].pi_distance if end_callback: end_callback(info) diff --git a/python/ray/rllib/agents/es/es.py b/python/ray/rllib/agents/es/es.py index 4aa4a86aa..8ca9689e6 100644 --- a/python/ray/rllib/agents/es/es.py +++ b/python/ray/rllib/agents/es/es.py @@ -16,6 +16,7 @@ from ray.rllib.agents import Agent, with_common_config from ray.rllib.agents.es import optimizers from ray.rllib.agents.es import policies from ray.rllib.agents.es import utils +from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.annotations import override from ray.rllib.utils import FilterManager @@ -91,7 +92,7 @@ class Worker(object): @property def filters(self): - return {"default": self.policy.get_filter()} + return {DEFAULT_POLICY_ID: self.policy.get_filter()} def sync_filters(self, new_filters): for k in self.filters: @@ -268,7 +269,7 @@ class ESAgent(Agent): # Now sync the filters FilterManager.synchronize({ - "default": self.policy.get_filter() + DEFAULT_POLICY_ID: self.policy.get_filter() }, self.workers) info = { @@ -332,5 +333,5 @@ class ESAgent(Agent): self.policy.set_weights(state["weights"]) self.policy.set_filter(state["filter"]) FilterManager.synchronize({ - "default": self.policy.get_filter() + DEFAULT_POLICY_ID: self.policy.get_filter() }, self.workers) diff --git a/python/ray/rllib/env/base_env.py 
b/python/ray/rllib/env/base_env.py index 7dd1921f1..05196a342 100644 --- a/python/ray/rllib/env/base_env.py +++ b/python/ray/rllib/env/base_env.py @@ -186,7 +186,7 @@ class BaseEnv(object): # Fixed agent identifier when there is only the single agent in the env -_DUMMY_AGENT_ID = "single_agent" +_DUMMY_AGENT_ID = "singleton_agent" def _with_dummy_agent_id(env_id_to_values, dummy_id=_DUMMY_AGENT_ID): diff --git a/python/ray/rllib/evaluation/policy_evaluator.py b/python/ray/rllib/evaluation/policy_evaluator.py index c7389b3e4..60e75ed46 100644 --- a/python/ray/rllib/evaluation/policy_evaluator.py +++ b/python/ray/rllib/evaluation/policy_evaluator.py @@ -28,6 +28,8 @@ from ray.rllib.models.preprocessors import NoPreprocessor from ray.rllib.utils import merge_dicts from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.compression import pack +from ray.rllib.utils.debug import disable_log_once_globally, log_once, \ + summarize, enable_periodic_logging from ray.rllib.utils.filter import get_filter from ray.rllib.utils.tf_run_builder import TFRunBuilder @@ -209,10 +211,16 @@ class PolicyEvaluator(EvaluatorInterface): if log_level: logging.getLogger("ray.rllib").setLevel(log_level) + if worker_index > 1: + disable_log_once_globally() # only need 1 evaluator to log + elif log_level == "DEBUG": + enable_periodic_logging() + env_context = EnvContext(env_config or {}, worker_index) policy_config = policy_config or {} self.policy_config = policy_config self.callbacks = callbacks or {} + self.worker_index = worker_index model_config = model_config or {} policy_mapping_fn = (policy_mapping_fn or (lambda agent_id: DEFAULT_POLICY_ID)) @@ -304,6 +312,8 @@ class PolicyEvaluator(EvaluatorInterface): policy.observation_space.shape) for (policy_id, policy) in self.policy_map.items() } + if self.worker_index == 0: + logger.info("Built filter map: {}".format(self.filters)) # Always use vector env for consistency even if num_envs = 1 self.async_env = 
BaseEnv.to_base_env( @@ -390,6 +400,10 @@ class PolicyEvaluator(EvaluatorInterface): SampleBatch|MultiAgentBatch from evaluating the current policies. """ + if log_once("sample_start"): + logger.info("Generating sample batch of size {}".format( + self.sample_batch_size)) + batches = [self.input_reader.next()] steps_so_far = batches[0].count @@ -423,6 +437,10 @@ class PolicyEvaluator(EvaluatorInterface): for estimator in self.reward_estimators: estimator.process(sub_batch) + if log_once("sample_end"): + logger.info("Completed sample batch:\n\n{}\n".format( + summarize(batch))) + if self.compress_observations: if isinstance(batch, MultiAgentBatch): for data in batch.policy_batches.values(): @@ -457,6 +475,9 @@ class PolicyEvaluator(EvaluatorInterface): @override(EvaluatorInterface) def compute_gradients(self, samples): + if log_once("compute_gradients"): + logger.info("Compute gradients on:\n\n{}\n".format( + summarize(samples))) if isinstance(samples, MultiAgentBatch): grad_out, info_out = {}, {} if self.tf_sess is not None: @@ -479,10 +500,15 @@ class PolicyEvaluator(EvaluatorInterface): grad_out, info_out = ( self.policy_map[DEFAULT_POLICY_ID].compute_gradients(samples)) info_out["batch_count"] = samples.count + if log_once("grad_out"): + logger.info("Compute grad info:\n\n{}\n".format( + summarize(info_out))) return grad_out, info_out @override(EvaluatorInterface) def apply_gradients(self, grads): + if log_once("apply_gradients"): + logger.info("Apply gradients:\n\n{}\n".format(summarize(grads))) if isinstance(grads, dict): if self.tf_sess is not None: builder = TFRunBuilder(self.tf_sess, "apply_gradients") @@ -502,6 +528,10 @@ class PolicyEvaluator(EvaluatorInterface): @override(EvaluatorInterface) def learn_on_batch(self, samples): + if log_once("learn_on_batch"): + logger.info( + "Training on concatenated sample batches:\n\n{}\n".format( + summarize(samples))) if isinstance(samples, MultiAgentBatch): info_out = {} if self.tf_sess is not None: @@ -519,11 
+549,12 @@ class PolicyEvaluator(EvaluatorInterface): continue info_out[pid], _ = ( self.policy_map[pid].learn_on_batch(batch)) - return info_out else: - grad_fetch, apply_fetch = ( + info_out, _ = ( self.policy_map[DEFAULT_POLICY_ID].learn_on_batch(samples)) - return grad_fetch + if log_once("learn_out"): + logger.info("Training output:\n\n{}\n".format(summarize(info_out))) + return info_out @DeveloperAPI def get_metrics(self): @@ -659,6 +690,9 @@ class PolicyEvaluator(EvaluatorInterface): "Tuple|DictFlatteningPreprocessor.") with tf.variable_scope(name): policy_map[name] = cls(obs_space, act_space, merged_conf) + if self.worker_index == 0: + logger.info("Built policy map: {}".format(policy_map)) + logger.info("Built preprocessor map: {}".format(preprocessors)) return policy_map, preprocessors def __del__(self): diff --git a/python/ray/rllib/evaluation/sample_batch.py b/python/ray/rllib/evaluation/sample_batch.py index 4e4de2358..4142f720a 100644 --- a/python/ray/rllib/evaluation/sample_batch.py +++ b/python/ray/rllib/evaluation/sample_batch.py @@ -10,7 +10,7 @@ from ray.rllib.utils.annotations import PublicAPI from ray.rllib.utils.memory import concat_aligned # Defaults policy id for single agent environments -DEFAULT_POLICY_ID = "default" +DEFAULT_POLICY_ID = "default_policy" @PublicAPI diff --git a/python/ray/rllib/evaluation/sample_batch_builder.py b/python/ray/rllib/evaluation/sample_batch_builder.py index 211e7075b..2387f5cd0 100644 --- a/python/ray/rllib/evaluation/sample_batch_builder.py +++ b/python/ray/rllib/evaluation/sample_batch_builder.py @@ -3,10 +3,14 @@ from __future__ import division from __future__ import print_function import collections +import logging import numpy as np from ray.rllib.evaluation.sample_batch import SampleBatch, MultiAgentBatch from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI +from ray.rllib.utils.debug import log_once, summarize + +logger = logging.getLogger(__name__) def to_float_array(v): @@ -145,10 +149,16 
@@ class MultiAgentSampleBatchBuilder(object): post_batches[agent_id] = policy.postprocess_trajectory( pre_batch, other_batches, episode) + if log_once("after_post"): + logger.info( + "Trajectory fragment after postprocess_trajectory():\n\n{}\n". + format(summarize(post_batches))) + # Append into policy batches and reset for agent_id, post_batch in sorted(post_batches.items()): self.policy_builders[self.agent_to_policy[agent_id]].add_batch( post_batch) + self.agent_builders.clear() self.agent_to_policy.clear() diff --git a/python/ray/rllib/evaluation/sampler.py b/python/ray/rllib/evaluation/sampler.py index cd6935f8a..381ef7212 100644 --- a/python/ray/rllib/evaluation/sampler.py +++ b/python/ray/rllib/evaluation/sampler.py @@ -18,10 +18,10 @@ from ray.rllib.env.atari_wrappers import get_wrapper_by_cls, MonitorEnv from ray.rllib.models.action_dist import TupleActions from ray.rllib.offline import InputReader from ray.rllib.utils.annotations import override +from ray.rllib.utils.debug import log_once, summarize from ray.rllib.utils.tf_run_builder import TFRunBuilder logger = logging.getLogger(__name__) -_large_batch_warned = False RolloutMetrics = namedtuple( "RolloutMetrics", @@ -303,6 +303,11 @@ def _env_runner(base_env, unfiltered_obs, rewards, dones, infos, off_policy_actions = \ base_env.poll() + if log_once("env_returns"): + logger.info("Raw obs from env: {}".format( + summarize(unfiltered_obs))) + logger.info("Info return from env: {}".format(summarize(infos))) + # Process observations and prepare for policy evaluation active_envs, to_eval, outputs = _process_observations( base_env, policies, batch_builder_pool, active_episodes, @@ -350,10 +355,8 @@ def _process_observations(base_env, policies, batch_builder_pool, episode.batch_builder.count += 1 episode._add_agent_rewards(rewards[env_id]) - global _large_batch_warned - if (not _large_batch_warned and - episode.batch_builder.total() > max(1000, unroll_length * 10)): - _large_batch_warned = True + if 
(episode.batch_builder.total() > max(1000, unroll_length * 10) + and log_once("large_batch_warning")): logger.warning( "More than {} observations for {} env steps ".format( episode.batch_builder.total(), @@ -387,7 +390,13 @@ def _process_observations(base_env, policies, batch_builder_pool, policy_id = episode.policy_for(agent_id) prep_obs = _get_or_raise(preprocessors, policy_id).transform(raw_obs) + if log_once("prep_obs"): + logger.info("Preprocessed obs: {}".format(summarize(prep_obs))) + filtered_obs = _get_or_raise(obs_filters, policy_id)(prep_obs) + if log_once("filtered_obs"): + logger.info("Filtered obs: {}".format(summarize(filtered_obs))) + agent_done = bool(all_done or dones[env_id].get(agent_id)) if not agent_done: to_eval[policy_id].append( @@ -491,6 +500,11 @@ def _do_policy_eval(tf_sess, to_eval, policies, active_episodes): pending_fetches = {} else: builder = None + + if log_once("compute_actions_input"): + logger.info("Example compute_actions() input:\n\n{}\n".format( + summarize(to_eval))) + for policy_id, eval_data in to_eval.items(): rnn_in_cols = _to_column_format([t.rnn_state for t in eval_data]) policy = _get_or_raise(policies, policy_id) @@ -514,6 +528,10 @@ def _do_policy_eval(tf_sess, to_eval, policies, active_episodes): for k, v in pending_fetches.items(): eval_results[k] = builder.get(v) + if log_once("compute_actions_result"): + logger.info("Example compute_actions() result:\n\n{}\n".format( + summarize(eval_results))) + return eval_results diff --git a/python/ray/rllib/evaluation/tf_policy_graph.py b/python/ray/rllib/evaluation/tf_policy_graph.py index 609c9f57b..8febf7738 100644 --- a/python/ray/rllib/evaluation/tf_policy_graph.py +++ b/python/ray/rllib/evaluation/tf_policy_graph.py @@ -13,6 +13,7 @@ import ray.experimental.tf_utils from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.models.lstm import chop_into_sequences from ray.rllib.utils.annotations import override, DeveloperAPI +from ray.rllib.utils.debug 
import log_once, summarize from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule from ray.rllib.utils.tf_run_builder import TFRunBuilder @@ -466,6 +467,15 @@ class TFPolicyGraph(PolicyGraph): for k, v in zip(state_keys, initial_states): feed_dict[self._loss_input_dict[k]] = v feed_dict[self._seq_lens] = seq_lens + + if log_once("rnn_feed_dict"): + logger.info("Padded input for RNN:\n\n{}\n".format( + summarize({ + "features": feature_sequences, + "initial_states": initial_states, + "seq_lens": seq_lens, + "max_seq_len": max_seq_len, + }))) return feed_dict diff --git a/python/ray/rllib/optimizers/async_samples_optimizer.py b/python/ray/rllib/optimizers/async_samples_optimizer.py index de171c0ca..22f33545b 100644 --- a/python/ray/rllib/optimizers/async_samples_optimizer.py +++ b/python/ray/rllib/optimizers/async_samples_optimizer.py @@ -15,6 +15,7 @@ import threading from six.moves import queue import ray +from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer from ray.rllib.utils.actors import TaskPool @@ -321,9 +322,9 @@ class TFMultiGPULearner(LearnerThread): assert self.train_batch_size % len(self.devices) == 0 assert self.train_batch_size >= len(self.devices), "batch too small" - if set(self.local_evaluator.policy_map.keys()) != {"default"}: + if set(self.local_evaluator.policy_map.keys()) != {DEFAULT_POLICY_ID}: raise NotImplementedError("Multi-gpu mode for multi-agent") - self.policy = self.local_evaluator.policy_map["default"] + self.policy = self.local_evaluator.policy_map[DEFAULT_POLICY_ID] # per-GPU graph copies created below must share vars with the policy # reuse is set to AUTO_REUSE because Adam nodes are created after @@ -331,7 +332,7 @@ class TFMultiGPULearner(LearnerThread): self.par_opt = [] with self.local_evaluator.tf_sess.graph.as_default(): with 
self.local_evaluator.tf_sess.as_default(): - with tf.variable_scope("default", reuse=tf.AUTO_REUSE): + with tf.variable_scope(DEFAULT_POLICY_ID, reuse=tf.AUTO_REUSE): if self.policy._state_inputs: rnn_inputs = self.policy._state_inputs + [ self.policy._seq_lens diff --git a/python/ray/rllib/optimizers/multi_gpu_impl.py b/python/ray/rllib/optimizers/multi_gpu_impl.py index 337ca11aa..5a418f79d 100644 --- a/python/ray/rllib/optimizers/multi_gpu_impl.py +++ b/python/ray/rllib/optimizers/multi_gpu_impl.py @@ -4,9 +4,10 @@ from __future__ import print_function from collections import namedtuple import logging - import tensorflow as tf +from ray.rllib.utils.debug import log_once, summarize + # Variable scope in which created variables will be placed under TOWER_SCOPE_NAME = "tower" @@ -134,6 +135,15 @@ class LocalSyncParallelOptimizer(object): The number of tuples loaded per device. """ + if log_once("load_data"): + logger.info( + "Training on concatenated sample batches:\n\n{}\n".format( + summarize({ + "placeholders": self.loss_inputs, + "inputs": inputs, + "state_inputs": state_inputs + }))) + feed_dict = {} assert len(self.loss_inputs) == len(inputs + state_inputs), \ (self.loss_inputs, inputs, state_inputs) diff --git a/python/ray/rllib/rollout.py b/python/ray/rllib/rollout.py index 0bd364583..65f5bce88 100755 --- a/python/ray/rllib/rollout.py +++ b/python/ray/rllib/rollout.py @@ -12,6 +12,7 @@ import pickle import gym import ray from ray.rllib.agents.registry import get_agent_class +from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID from ray.rllib.evaluation.sampler import clip_action from ray.tune.util import merge_dicts @@ -116,7 +117,7 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True): else: env = gym.make(env_name) multiagent = False - use_lstm = {'default': False} + use_lstm = {DEFAULT_POLICY_ID: False} if out is not None: rollouts = [] @@ -148,7 +149,7 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True): 
action_dict[agent_id] = a_action action = action_dict else: - if use_lstm["default"]: + if use_lstm[DEFAULT_POLICY_ID]: action, state_init, _ = agent.compute_action( state, state=state_init) else: diff --git a/python/ray/rllib/tests/test_policy_evaluator.py b/python/ray/rllib/tests/test_policy_evaluator.py index d5466865b..827b737ef 100644 --- a/python/ray/rllib/tests/test_policy_evaluator.py +++ b/python/ray/rllib/tests/test_policy_evaluator.py @@ -16,6 +16,7 @@ from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator from ray.rllib.evaluation.metrics import collect_metrics from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.evaluation.postprocessing import compute_advantages +from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID from ray.rllib.env.vector_env import VectorEnv from ray.tune.registry import register_env @@ -367,7 +368,7 @@ class TestPolicyEvaluator(unittest.TestCase): time.sleep(2) ev.sample() filters = ev.get_filters(flush_after=True) - obs_f = filters["default"] + obs_f = filters[DEFAULT_POLICY_ID] self.assertNotEqual(obs_f.rs.n, 0) self.assertNotEqual(obs_f.buffer.n, 0) @@ -381,8 +382,8 @@ class TestPolicyEvaluator(unittest.TestCase): filters = ev.get_filters(flush_after=False) time.sleep(2) filters2 = ev.get_filters(flush_after=False) - obs_f = filters["default"] - obs_f2 = filters2["default"] + obs_f = filters[DEFAULT_POLICY_ID] + obs_f2 = filters2[DEFAULT_POLICY_ID] self.assertGreaterEqual(obs_f2.rs.n, obs_f.rs.n) self.assertGreaterEqual(obs_f2.buffer.n, obs_f.buffer.n) @@ -396,15 +397,15 @@ class TestPolicyEvaluator(unittest.TestCase): # Current State filters = ev.get_filters(flush_after=False) - obs_f = filters["default"] + obs_f = filters[DEFAULT_POLICY_ID] self.assertLessEqual(obs_f.buffer.n, 20) new_obsf = obs_f.copy() new_obsf.rs._n = 100 - ev.sync_filters({"default": new_obsf}) + ev.sync_filters({DEFAULT_POLICY_ID: new_obsf}) filters = ev.get_filters(flush_after=False) - obs_f = filters["default"] 
+        obs_f = filters[DEFAULT_POLICY_ID]
         self.assertGreaterEqual(obs_f.rs.n, 100)
         self.assertLessEqual(obs_f.buffer.n, 20)
 
@@ -412,7 +413,7 @@ class TestPolicyEvaluator(unittest.TestCase):
         time.sleep(2)
         ev.sample()
         filters = ev.get_filters(flush_after=True)
-        obs_f = filters["default"]
+        obs_f = filters[DEFAULT_POLICY_ID]
         self.assertNotEqual(obs_f.rs.n, 0)
         self.assertNotEqual(obs_f.buffer.n, 0)
         return obs_f
diff --git a/python/ray/rllib/utils/debug.py b/python/ray/rllib/utils/debug.py
new file mode 100644
index 000000000..63638d292
--- /dev/null
+++ b/python/ray/rllib/utils/debug.py
@@ -0,0 +1,136 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import pprint
+import time
+
+from ray.rllib.evaluation.sample_batch import SampleBatch, MultiAgentBatch
+
+# Keys for which log_once() has already returned True in this process.
+_logged = set()
+# When True, log_once() always returns False.
+_disabled = False
+# Off by default so that log_once() really means "once"; only
+# enable_periodic_logging() turns the periodic re-logging on.
+_periodic_log = False
+# time.time() of the most recent key admission or periodic reset.
+_last_logged = 0.0
+_printer = pprint.PrettyPrinter(indent=2, width=60)
+
+
+def log_once(key):
+    """Returns True if this is the "first" call for a given key.
+
+    Various logging settings can adjust the definition of "first": after
+    disable_log_once_globally() this always returns False, and after
+    enable_periodic_logging() the set of seen keys is cleared once more than
+    60 seconds have passed since a key was last logged.
+
+    Example:
+        >>> if log_once("some_key"):
+        ...     logger.info("Some verbose logging statement")
+    """
+
+    global _last_logged
+
+    if _disabled:
+        return False
+    elif key not in _logged:
+        _logged.add(key)
+        _last_logged = time.time()
+        return True
+    elif _periodic_log and time.time() - _last_logged > 60.0:
+        # Start a new round: this call still returns False, and every key
+        # (including this one) is reported again on its next call.
+        _logged.clear()
+        _last_logged = time.time()
+        return False
+    else:
+        return False
+
+
+def disable_log_once_globally():
+    """Make log_once() return False in this process."""
+
+    global _disabled
+    _disabled = True
+
+
+def enable_periodic_logging():
+    """Make log_once() periodically return True in this process."""
+
+    global _periodic_log
+    _periodic_log = True
+
+
+def summarize(obj):
+    """Return a pretty-formatted string for an object.
+
+    This has special handling for pretty-formatting of commonly used data types
+    in RLlib, such as SampleBatch, numpy arrays, etc.
+    """
+
+    return _printer.pformat(_summarize(obj))
+
+
+def _summarize(obj):
+    """Recursively replace bulky values in obj with short descriptions.
+
+    Dicts, lists, tuples, objects exposing _asdict() and sample batches are
+    traversed; numpy arrays become a one-line shape/dtype summary string;
+    anything else is returned unchanged.
+    """
+
+    if isinstance(obj, dict):
+        return {k: _summarize(v) for k, v in obj.items()}
+    elif hasattr(obj, "_asdict"):
+        return {
+            "type": obj.__class__.__name__,
+            "data": _summarize(obj._asdict()),
+        }
+    elif isinstance(obj, list):
+        return [_summarize(x) for x in obj]
+    elif isinstance(obj, tuple):
+        return tuple(_summarize(x) for x in obj)
+    elif isinstance(obj, np.ndarray):
+        if obj.size == 0:
+            # Nothing to index or reduce: obj[0] and np.min() would raise.
+            return _StringValue("np.ndarray({}, dtype={})".format(
+                obj.shape, obj.dtype))
+        elif obj.dtype == object or obj.dtype.kind in ("S", "U"):
+            # Object/string data cannot be reduced to float statistics, so
+            # show the first entry instead. The builtin `object` is used
+            # because the np.object alias is deprecated in NumPy.
+            return _StringValue("np.ndarray({}, dtype={}, head={})".format(
+                obj.shape, obj.dtype, _summarize(obj[0])))
+        else:
+            return _StringValue(
+                "np.ndarray({}, dtype={}, min={}, max={}, mean={})".format(
+                    obj.shape, obj.dtype, round(float(np.min(obj)), 3),
+                    round(float(np.max(obj)), 3), round(
+                        float(np.mean(obj)), 3)))
+    elif isinstance(obj, MultiAgentBatch):
+        return {
+            "type": "MultiAgentBatch",
+            "policy_batches": _summarize(obj.policy_batches),
+            "count": obj.count,
+        }
+    elif isinstance(obj, SampleBatch):
+        return {
+            "type": "SampleBatch",
+            "data": {k: _summarize(v)
+                     for k, v in obj.items()},
+        }
+    else:
+        return obj
+
+
+class _StringValue(object):
+    """Wraps a string so that its repr() is the bare string, without quotes."""
+
+    def __init__(self, value):
+        self.value = value
+
+    def __repr__(self):
+        return self.value