mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 17:18:45 +08:00
[rllib] Print out intermediate data shapes on the first iteration (#4426)
This commit is contained in:
@@ -41,7 +41,10 @@ COMMON_CONFIG = {
|
||||
# === Debugging ===
|
||||
# Whether to write episode stats and videos to the agent log dir
|
||||
"monitor": False,
|
||||
# Set the ray.rllib.* log level for the agent process and its evaluators
|
||||
# Set the ray.rllib.* log level for the agent process and its evaluators.
|
||||
# Should be one of DEBUG, INFO, WARN, or ERROR. The DEBUG level will also
|
||||
# periodically print out summaries of relevant internal dataflow (this is
|
||||
# also printed out once at startup at the INFO level).
|
||||
"log_level": "INFO",
|
||||
# Callbacks that will be run during various phases of training. These all
|
||||
# take a single "info" dict as an argument. For episode callbacks, custom
|
||||
@@ -407,7 +410,7 @@ class Agent(Trainable):
|
||||
prev_action=None,
|
||||
prev_reward=None,
|
||||
info=None,
|
||||
policy_id="default"):
|
||||
policy_id=DEFAULT_POLICY_ID):
|
||||
"""Computes an action for the specified policy.
|
||||
|
||||
Note that you can also access the policy object through
|
||||
|
||||
@@ -17,6 +17,7 @@ from ray.rllib.agents import Agent, with_common_config
|
||||
from ray.rllib.agents.ars import optimizers
|
||||
from ray.rllib.agents.ars import policies
|
||||
from ray.rllib.agents.ars import utils
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils import FilterManager
|
||||
|
||||
@@ -87,7 +88,7 @@ class Worker(object):
|
||||
|
||||
@property
|
||||
def filters(self):
|
||||
return {"default": self.policy.get_filter()}
|
||||
return {DEFAULT_POLICY_ID: self.policy.get_filter()}
|
||||
|
||||
def sync_filters(self, new_filters):
|
||||
for k in self.filters:
|
||||
@@ -271,7 +272,7 @@ class ARSAgent(Agent):
|
||||
|
||||
# Now sync the filters
|
||||
FilterManager.synchronize({
|
||||
"default": self.policy.get_filter()
|
||||
DEFAULT_POLICY_ID: self.policy.get_filter()
|
||||
}, self.workers)
|
||||
|
||||
info = {
|
||||
@@ -335,5 +336,5 @@ class ARSAgent(Agent):
|
||||
self.policy.set_weights(state["weights"])
|
||||
self.policy.set_filter(state["filter"])
|
||||
FilterManager.synchronize({
|
||||
"default": self.policy.get_filter()
|
||||
DEFAULT_POLICY_ID: self.policy.get_filter()
|
||||
}, self.workers)
|
||||
|
||||
@@ -10,6 +10,7 @@ from ray.rllib import optimizers
|
||||
from ray.rllib.agents.agent import Agent, with_common_config
|
||||
from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph
|
||||
from ray.rllib.evaluation.metrics import collect_metrics
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule
|
||||
|
||||
@@ -195,7 +196,7 @@ class DQNAgent(Agent):
|
||||
policies = info["policy"]
|
||||
episode = info["episode"]
|
||||
episode.custom_metrics["policy_distance"] = policies[
|
||||
"default"].pi_distance
|
||||
DEFAULT_POLICY_ID].pi_distance
|
||||
if end_callback:
|
||||
end_callback(info)
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ from ray.rllib.agents import Agent, with_common_config
|
||||
from ray.rllib.agents.es import optimizers
|
||||
from ray.rllib.agents.es import policies
|
||||
from ray.rllib.agents.es import utils
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils import FilterManager
|
||||
|
||||
@@ -91,7 +92,7 @@ class Worker(object):
|
||||
|
||||
@property
|
||||
def filters(self):
|
||||
return {"default": self.policy.get_filter()}
|
||||
return {DEFAULT_POLICY_ID: self.policy.get_filter()}
|
||||
|
||||
def sync_filters(self, new_filters):
|
||||
for k in self.filters:
|
||||
@@ -268,7 +269,7 @@ class ESAgent(Agent):
|
||||
|
||||
# Now sync the filters
|
||||
FilterManager.synchronize({
|
||||
"default": self.policy.get_filter()
|
||||
DEFAULT_POLICY_ID: self.policy.get_filter()
|
||||
}, self.workers)
|
||||
|
||||
info = {
|
||||
@@ -332,5 +333,5 @@ class ESAgent(Agent):
|
||||
self.policy.set_weights(state["weights"])
|
||||
self.policy.set_filter(state["filter"])
|
||||
FilterManager.synchronize({
|
||||
"default": self.policy.get_filter()
|
||||
DEFAULT_POLICY_ID: self.policy.get_filter()
|
||||
}, self.workers)
|
||||
|
||||
Vendored
+1
-1
@@ -186,7 +186,7 @@ class BaseEnv(object):
|
||||
|
||||
|
||||
# Fixed agent identifier when there is only the single agent in the env
|
||||
_DUMMY_AGENT_ID = "single_agent"
|
||||
_DUMMY_AGENT_ID = "singleton_agent"
|
||||
|
||||
|
||||
def _with_dummy_agent_id(env_id_to_values, dummy_id=_DUMMY_AGENT_ID):
|
||||
|
||||
@@ -28,6 +28,8 @@ from ray.rllib.models.preprocessors import NoPreprocessor
|
||||
from ray.rllib.utils import merge_dicts
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils.compression import pack
|
||||
from ray.rllib.utils.debug import disable_log_once_globally, log_once, \
|
||||
summarize, enable_periodic_logging
|
||||
from ray.rllib.utils.filter import get_filter
|
||||
from ray.rllib.utils.tf_run_builder import TFRunBuilder
|
||||
|
||||
@@ -209,10 +211,16 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
if log_level:
|
||||
logging.getLogger("ray.rllib").setLevel(log_level)
|
||||
|
||||
if worker_index > 1:
|
||||
disable_log_once_globally() # only need 1 evaluator to log
|
||||
elif log_level == "DEBUG":
|
||||
enable_periodic_logging()
|
||||
|
||||
env_context = EnvContext(env_config or {}, worker_index)
|
||||
policy_config = policy_config or {}
|
||||
self.policy_config = policy_config
|
||||
self.callbacks = callbacks or {}
|
||||
self.worker_index = worker_index
|
||||
model_config = model_config or {}
|
||||
policy_mapping_fn = (policy_mapping_fn
|
||||
or (lambda agent_id: DEFAULT_POLICY_ID))
|
||||
@@ -304,6 +312,8 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
policy.observation_space.shape)
|
||||
for (policy_id, policy) in self.policy_map.items()
|
||||
}
|
||||
if self.worker_index == 0:
|
||||
logger.info("Built filter map: {}".format(self.filters))
|
||||
|
||||
# Always use vector env for consistency even if num_envs = 1
|
||||
self.async_env = BaseEnv.to_base_env(
|
||||
@@ -390,6 +400,10 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
SampleBatch|MultiAgentBatch from evaluating the current policies.
|
||||
"""
|
||||
|
||||
if log_once("sample_start"):
|
||||
logger.info("Generating sample batch of size {}".format(
|
||||
self.sample_batch_size))
|
||||
|
||||
batches = [self.input_reader.next()]
|
||||
steps_so_far = batches[0].count
|
||||
|
||||
@@ -423,6 +437,10 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
for estimator in self.reward_estimators:
|
||||
estimator.process(sub_batch)
|
||||
|
||||
if log_once("sample_end"):
|
||||
logger.info("Completed sample batch:\n\n{}\n".format(
|
||||
summarize(batch)))
|
||||
|
||||
if self.compress_observations:
|
||||
if isinstance(batch, MultiAgentBatch):
|
||||
for data in batch.policy_batches.values():
|
||||
@@ -457,6 +475,9 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
|
||||
@override(EvaluatorInterface)
|
||||
def compute_gradients(self, samples):
|
||||
if log_once("compute_gradients"):
|
||||
logger.info("Compute gradients on:\n\n{}\n".format(
|
||||
summarize(samples)))
|
||||
if isinstance(samples, MultiAgentBatch):
|
||||
grad_out, info_out = {}, {}
|
||||
if self.tf_sess is not None:
|
||||
@@ -479,10 +500,15 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
grad_out, info_out = (
|
||||
self.policy_map[DEFAULT_POLICY_ID].compute_gradients(samples))
|
||||
info_out["batch_count"] = samples.count
|
||||
if log_once("grad_out"):
|
||||
logger.info("Compute grad info:\n\n{}\n".format(
|
||||
summarize(info_out)))
|
||||
return grad_out, info_out
|
||||
|
||||
@override(EvaluatorInterface)
|
||||
def apply_gradients(self, grads):
|
||||
if log_once("apply_gradients"):
|
||||
logger.info("Apply gradients:\n\n{}\n".format(summarize(grads)))
|
||||
if isinstance(grads, dict):
|
||||
if self.tf_sess is not None:
|
||||
builder = TFRunBuilder(self.tf_sess, "apply_gradients")
|
||||
@@ -502,6 +528,10 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
|
||||
@override(EvaluatorInterface)
|
||||
def learn_on_batch(self, samples):
|
||||
if log_once("learn_on_batch"):
|
||||
logger.info(
|
||||
"Training on concatenated sample batches:\n\n{}\n".format(
|
||||
summarize(samples)))
|
||||
if isinstance(samples, MultiAgentBatch):
|
||||
info_out = {}
|
||||
if self.tf_sess is not None:
|
||||
@@ -519,11 +549,12 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
continue
|
||||
info_out[pid], _ = (
|
||||
self.policy_map[pid].learn_on_batch(batch))
|
||||
return info_out
|
||||
else:
|
||||
grad_fetch, apply_fetch = (
|
||||
info_out, _ = (
|
||||
self.policy_map[DEFAULT_POLICY_ID].learn_on_batch(samples))
|
||||
return grad_fetch
|
||||
if log_once("learn_out"):
|
||||
logger.info("Training output:\n\n{}\n".format(summarize(info_out)))
|
||||
return info_out
|
||||
|
||||
@DeveloperAPI
|
||||
def get_metrics(self):
|
||||
@@ -659,6 +690,9 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
"Tuple|DictFlatteningPreprocessor.")
|
||||
with tf.variable_scope(name):
|
||||
policy_map[name] = cls(obs_space, act_space, merged_conf)
|
||||
if self.worker_index == 0:
|
||||
logger.info("Built policy map: {}".format(policy_map))
|
||||
logger.info("Built preprocessor map: {}".format(preprocessors))
|
||||
return policy_map, preprocessors
|
||||
|
||||
def __del__(self):
|
||||
|
||||
@@ -10,7 +10,7 @@ from ray.rllib.utils.annotations import PublicAPI
|
||||
from ray.rllib.utils.memory import concat_aligned
|
||||
|
||||
# Defaults policy id for single agent environments
|
||||
DEFAULT_POLICY_ID = "default"
|
||||
DEFAULT_POLICY_ID = "default_policy"
|
||||
|
||||
|
||||
@PublicAPI
|
||||
|
||||
@@ -3,10 +3,14 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch, MultiAgentBatch
|
||||
from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI
|
||||
from ray.rllib.utils.debug import log_once, summarize
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def to_float_array(v):
|
||||
@@ -145,10 +149,16 @@ class MultiAgentSampleBatchBuilder(object):
|
||||
post_batches[agent_id] = policy.postprocess_trajectory(
|
||||
pre_batch, other_batches, episode)
|
||||
|
||||
if log_once("after_post"):
|
||||
logger.info(
|
||||
"Trajectory fragment after postprocess_trajectory():\n\n{}\n".
|
||||
format(summarize(post_batches)))
|
||||
|
||||
# Append into policy batches and reset
|
||||
for agent_id, post_batch in sorted(post_batches.items()):
|
||||
self.policy_builders[self.agent_to_policy[agent_id]].add_batch(
|
||||
post_batch)
|
||||
|
||||
self.agent_builders.clear()
|
||||
self.agent_to_policy.clear()
|
||||
|
||||
|
||||
@@ -18,10 +18,10 @@ from ray.rllib.env.atari_wrappers import get_wrapper_by_cls, MonitorEnv
|
||||
from ray.rllib.models.action_dist import TupleActions
|
||||
from ray.rllib.offline import InputReader
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.debug import log_once, summarize
|
||||
from ray.rllib.utils.tf_run_builder import TFRunBuilder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_large_batch_warned = False
|
||||
|
||||
RolloutMetrics = namedtuple(
|
||||
"RolloutMetrics",
|
||||
@@ -303,6 +303,11 @@ def _env_runner(base_env,
|
||||
unfiltered_obs, rewards, dones, infos, off_policy_actions = \
|
||||
base_env.poll()
|
||||
|
||||
if log_once("env_returns"):
|
||||
logger.info("Raw obs from env: {}".format(
|
||||
summarize(unfiltered_obs)))
|
||||
logger.info("Info return from env: {}".format(summarize(infos)))
|
||||
|
||||
# Process observations and prepare for policy evaluation
|
||||
active_envs, to_eval, outputs = _process_observations(
|
||||
base_env, policies, batch_builder_pool, active_episodes,
|
||||
@@ -350,10 +355,8 @@ def _process_observations(base_env, policies, batch_builder_pool,
|
||||
episode.batch_builder.count += 1
|
||||
episode._add_agent_rewards(rewards[env_id])
|
||||
|
||||
global _large_batch_warned
|
||||
if (not _large_batch_warned and
|
||||
episode.batch_builder.total() > max(1000, unroll_length * 10)):
|
||||
_large_batch_warned = True
|
||||
if (episode.batch_builder.total() > max(1000, unroll_length * 10)
|
||||
and log_once("large_batch_warning")):
|
||||
logger.warning(
|
||||
"More than {} observations for {} env steps ".format(
|
||||
episode.batch_builder.total(),
|
||||
@@ -387,7 +390,13 @@ def _process_observations(base_env, policies, batch_builder_pool,
|
||||
policy_id = episode.policy_for(agent_id)
|
||||
prep_obs = _get_or_raise(preprocessors,
|
||||
policy_id).transform(raw_obs)
|
||||
if log_once("prep_obs"):
|
||||
logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))
|
||||
|
||||
filtered_obs = _get_or_raise(obs_filters, policy_id)(prep_obs)
|
||||
if log_once("filtered_obs"):
|
||||
logger.info("Filtered obs: {}".format(summarize(filtered_obs)))
|
||||
|
||||
agent_done = bool(all_done or dones[env_id].get(agent_id))
|
||||
if not agent_done:
|
||||
to_eval[policy_id].append(
|
||||
@@ -491,6 +500,11 @@ def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
|
||||
pending_fetches = {}
|
||||
else:
|
||||
builder = None
|
||||
|
||||
if log_once("compute_actions_input"):
|
||||
logger.info("Example compute_actions() input:\n\n{}\n".format(
|
||||
summarize(to_eval)))
|
||||
|
||||
for policy_id, eval_data in to_eval.items():
|
||||
rnn_in_cols = _to_column_format([t.rnn_state for t in eval_data])
|
||||
policy = _get_or_raise(policies, policy_id)
|
||||
@@ -514,6 +528,10 @@ def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
|
||||
for k, v in pending_fetches.items():
|
||||
eval_results[k] = builder.get(v)
|
||||
|
||||
if log_once("compute_actions_result"):
|
||||
logger.info("Example compute_actions() result:\n\n{}\n".format(
|
||||
summarize(eval_results)))
|
||||
|
||||
return eval_results
|
||||
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import ray.experimental.tf_utils
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.models.lstm import chop_into_sequences
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils.debug import log_once, summarize
|
||||
from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
|
||||
from ray.rllib.utils.tf_run_builder import TFRunBuilder
|
||||
|
||||
@@ -466,6 +467,15 @@ class TFPolicyGraph(PolicyGraph):
|
||||
for k, v in zip(state_keys, initial_states):
|
||||
feed_dict[self._loss_input_dict[k]] = v
|
||||
feed_dict[self._seq_lens] = seq_lens
|
||||
|
||||
if log_once("rnn_feed_dict"):
|
||||
logger.info("Padded input for RNN:\n\n{}\n".format(
|
||||
summarize({
|
||||
"features": feature_sequences,
|
||||
"initial_states": initial_states,
|
||||
"seq_lens": seq_lens,
|
||||
"max_seq_len": max_seq_len,
|
||||
})))
|
||||
return feed_dict
|
||||
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ import threading
|
||||
from six.moves import queue
|
||||
|
||||
import ray
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer
|
||||
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.utils.actors import TaskPool
|
||||
@@ -321,9 +322,9 @@ class TFMultiGPULearner(LearnerThread):
|
||||
assert self.train_batch_size % len(self.devices) == 0
|
||||
assert self.train_batch_size >= len(self.devices), "batch too small"
|
||||
|
||||
if set(self.local_evaluator.policy_map.keys()) != {"default"}:
|
||||
if set(self.local_evaluator.policy_map.keys()) != {DEFAULT_POLICY_ID}:
|
||||
raise NotImplementedError("Multi-gpu mode for multi-agent")
|
||||
self.policy = self.local_evaluator.policy_map["default"]
|
||||
self.policy = self.local_evaluator.policy_map[DEFAULT_POLICY_ID]
|
||||
|
||||
# per-GPU graph copies created below must share vars with the policy
|
||||
# reuse is set to AUTO_REUSE because Adam nodes are created after
|
||||
@@ -331,7 +332,7 @@ class TFMultiGPULearner(LearnerThread):
|
||||
self.par_opt = []
|
||||
with self.local_evaluator.tf_sess.graph.as_default():
|
||||
with self.local_evaluator.tf_sess.as_default():
|
||||
with tf.variable_scope("default", reuse=tf.AUTO_REUSE):
|
||||
with tf.variable_scope(DEFAULT_POLICY_ID, reuse=tf.AUTO_REUSE):
|
||||
if self.policy._state_inputs:
|
||||
rnn_inputs = self.policy._state_inputs + [
|
||||
self.policy._seq_lens
|
||||
|
||||
@@ -4,9 +4,10 @@ from __future__ import print_function
|
||||
|
||||
from collections import namedtuple
|
||||
import logging
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from ray.rllib.utils.debug import log_once, summarize
|
||||
|
||||
# Variable scope in which created variables will be placed under
|
||||
TOWER_SCOPE_NAME = "tower"
|
||||
|
||||
@@ -134,6 +135,15 @@ class LocalSyncParallelOptimizer(object):
|
||||
The number of tuples loaded per device.
|
||||
"""
|
||||
|
||||
if log_once("load_data"):
|
||||
logger.info(
|
||||
"Training on concatenated sample batches:\n\n{}\n".format(
|
||||
summarize({
|
||||
"placeholders": self.loss_inputs,
|
||||
"inputs": inputs,
|
||||
"state_inputs": state_inputs
|
||||
})))
|
||||
|
||||
feed_dict = {}
|
||||
assert len(self.loss_inputs) == len(inputs + state_inputs), \
|
||||
(self.loss_inputs, inputs, state_inputs)
|
||||
|
||||
@@ -12,6 +12,7 @@ import pickle
|
||||
import gym
|
||||
import ray
|
||||
from ray.rllib.agents.registry import get_agent_class
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.evaluation.sampler import clip_action
|
||||
from ray.tune.util import merge_dicts
|
||||
|
||||
@@ -116,7 +117,7 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True):
|
||||
else:
|
||||
env = gym.make(env_name)
|
||||
multiagent = False
|
||||
use_lstm = {'default': False}
|
||||
use_lstm = {DEFAULT_POLICY_ID: False}
|
||||
|
||||
if out is not None:
|
||||
rollouts = []
|
||||
@@ -148,7 +149,7 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True):
|
||||
action_dict[agent_id] = a_action
|
||||
action = action_dict
|
||||
else:
|
||||
if use_lstm["default"]:
|
||||
if use_lstm[DEFAULT_POLICY_ID]:
|
||||
action, state_init, _ = agent.compute_action(
|
||||
state, state=state_init)
|
||||
else:
|
||||
|
||||
@@ -16,6 +16,7 @@ from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.evaluation.metrics import collect_metrics
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.env.vector_env import VectorEnv
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
@@ -367,7 +368,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
time.sleep(2)
|
||||
ev.sample()
|
||||
filters = ev.get_filters(flush_after=True)
|
||||
obs_f = filters["default"]
|
||||
obs_f = filters[DEFAULT_POLICY_ID]
|
||||
self.assertNotEqual(obs_f.rs.n, 0)
|
||||
self.assertNotEqual(obs_f.buffer.n, 0)
|
||||
|
||||
@@ -381,8 +382,8 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
filters = ev.get_filters(flush_after=False)
|
||||
time.sleep(2)
|
||||
filters2 = ev.get_filters(flush_after=False)
|
||||
obs_f = filters["default"]
|
||||
obs_f2 = filters2["default"]
|
||||
obs_f = filters[DEFAULT_POLICY_ID]
|
||||
obs_f2 = filters2[DEFAULT_POLICY_ID]
|
||||
self.assertGreaterEqual(obs_f2.rs.n, obs_f.rs.n)
|
||||
self.assertGreaterEqual(obs_f2.buffer.n, obs_f.buffer.n)
|
||||
|
||||
@@ -396,15 +397,15 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
|
||||
# Current State
|
||||
filters = ev.get_filters(flush_after=False)
|
||||
obs_f = filters["default"]
|
||||
obs_f = filters[DEFAULT_POLICY_ID]
|
||||
|
||||
self.assertLessEqual(obs_f.buffer.n, 20)
|
||||
|
||||
new_obsf = obs_f.copy()
|
||||
new_obsf.rs._n = 100
|
||||
ev.sync_filters({"default": new_obsf})
|
||||
ev.sync_filters({DEFAULT_POLICY_ID: new_obsf})
|
||||
filters = ev.get_filters(flush_after=False)
|
||||
obs_f = filters["default"]
|
||||
obs_f = filters[DEFAULT_POLICY_ID]
|
||||
self.assertGreaterEqual(obs_f.rs.n, 100)
|
||||
self.assertLessEqual(obs_f.buffer.n, 20)
|
||||
|
||||
@@ -412,7 +413,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
time.sleep(2)
|
||||
ev.sample()
|
||||
filters = ev.get_filters(flush_after=True)
|
||||
obs_f = filters["default"]
|
||||
obs_f = filters[DEFAULT_POLICY_ID]
|
||||
self.assertNotEqual(obs_f.rs.n, 0)
|
||||
self.assertNotEqual(obs_f.buffer.n, 0)
|
||||
return obs_f
|
||||
|
||||
@@ -0,0 +1,111 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import pprint
|
||||
import time
|
||||
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch, MultiAgentBatch
|
||||
|
||||
_logged = set()
|
||||
_disabled = False
|
||||
_periodic_log = True
|
||||
_last_logged = 0.0
|
||||
_printer = pprint.PrettyPrinter(indent=2, width=60)
|
||||
|
||||
|
||||
def log_once(key):
|
||||
"""Returns True if this is the "first" call for a given key.
|
||||
|
||||
Various logging settings can adjust the definition of "first".
|
||||
|
||||
Example:
|
||||
>>> if log_once("some_key"):
|
||||
... logger.info("Some verbose logging statement")
|
||||
"""
|
||||
|
||||
global _last_logged
|
||||
|
||||
if _disabled:
|
||||
return False
|
||||
elif key not in _logged:
|
||||
_logged.add(key)
|
||||
_last_logged = time.time()
|
||||
return True
|
||||
elif _periodic_log and time.time() - _last_logged > 60.0:
|
||||
_logged.clear()
|
||||
_last_logged = time.time()
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def disable_log_once_globally():
|
||||
"""Make log_once() return False in this process."""
|
||||
|
||||
global _disabled
|
||||
_disabled = True
|
||||
|
||||
|
||||
def enable_periodic_logging():
|
||||
"""Make log_once() periodically return True in this process."""
|
||||
|
||||
global _periodic_log
|
||||
_periodic_log = True
|
||||
|
||||
|
||||
def summarize(obj):
|
||||
"""Return a pretty-formatted string for an object.
|
||||
|
||||
This has special handling for pretty-formatting of commonly used data types
|
||||
in RLlib, such as SampleBatch, numpy arrays, etc.
|
||||
"""
|
||||
|
||||
return _printer.pformat(_summarize(obj))
|
||||
|
||||
|
||||
def _summarize(obj):
|
||||
if isinstance(obj, dict):
|
||||
return {k: _summarize(v) for k, v in obj.items()}
|
||||
elif hasattr(obj, "_asdict"):
|
||||
return {
|
||||
"type": obj.__class__.__name__,
|
||||
"data": _summarize(obj._asdict()),
|
||||
}
|
||||
elif isinstance(obj, list):
|
||||
return [_summarize(x) for x in obj]
|
||||
elif isinstance(obj, tuple):
|
||||
return tuple(_summarize(x) for x in obj)
|
||||
elif isinstance(obj, np.ndarray):
|
||||
if obj.dtype == np.object:
|
||||
return _StringValue("np.ndarray({}, dtype={}, head={})".format(
|
||||
obj.shape, obj.dtype, _summarize(obj[0])))
|
||||
else:
|
||||
return _StringValue(
|
||||
"np.ndarray({}, dtype={}, min={}, max={}, mean={})".format(
|
||||
obj.shape, obj.dtype, round(float(np.min(obj)), 3),
|
||||
round(float(np.max(obj)), 3), round(
|
||||
float(np.mean(obj)), 3)))
|
||||
elif isinstance(obj, MultiAgentBatch):
|
||||
return {
|
||||
"type": "MultiAgentBatch",
|
||||
"policy_batches": _summarize(obj.policy_batches),
|
||||
"count": obj.count,
|
||||
}
|
||||
elif isinstance(obj, SampleBatch):
|
||||
return {
|
||||
"type": "SampleBatch",
|
||||
"data": {k: _summarize(v)
|
||||
for k, v in obj.items()},
|
||||
}
|
||||
else:
|
||||
return obj
|
||||
|
||||
|
||||
class _StringValue(object):
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def __repr__(self):
|
||||
return self.value
|
||||
Reference in New Issue
Block a user