diff --git a/doc/source/rllib-env.rst b/doc/source/rllib-env.rst
index 37ea011a0..4f8a4c66a 100644
--- a/doc/source/rllib-env.rst
+++ b/doc/source/rllib-env.rst
@@ -107,6 +107,10 @@ RLlib will auto-vectorize Gym envs for batch evaluation if the ``num_envs_per_wo
Multi-Agent
-----------
+.. note::
+
+ Learn more about multi-agent reinforcement learning in RLlib by reading the `blog post <https://bair.berkeley.edu/blog/2018/12/12/rllib/>`__.
+
A multi-agent environment is one which has multiple acting entities per step, e.g., in a traffic simulation, there may be multiple "car" and "traffic light" agents in the environment. The model for multi-agent in RLlib as follows: (1) as a user you define the number of policies available up front, and (2) a function that maps agent ids to policy ids. This is summarized by the below figure:
.. image:: multi-agent.svg
diff --git a/doc/source/rllib-training.rst b/doc/source/rllib-training.rst
index 4b6630090..dc350d272 100644
--- a/doc/source/rllib-training.rst
+++ b/doc/source/rllib-training.rst
@@ -228,7 +228,7 @@ Ray actors provide high levels of performance, so in more complex cases they can
Callbacks and Custom Metrics
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-You can provide callback functions to be called at points during policy evaluation. These functions have access to an info dict containing state for the current `episode `__. Custom state can be stored for the `episode `__ in the ``info["episode"].user_data`` dict, and custom scalar metrics reported by saving values to the ``info["episode"].custom_metrics`` dict. These custom metrics will be averaged and reported as part of training results. The following example (full code `here `__) logs a custom metric from the environment:
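+You can provide callback functions to be called at points during policy evaluation. These functions have access to an info dict containing state for the current `episode <https://github.com/ray-project/ray/blob/master/python/ray/rllib/evaluation/episode.py>`__. Custom state can be stored for the `episode <https://github.com/ray-project/ray/blob/master/python/ray/rllib/evaluation/episode.py>`__ in the ``info["episode"].user_data`` dict, and custom scalar metrics reported by saving values to the ``info["episode"].custom_metrics`` dict. These custom metrics will be aggregated and reported as part of training results. The following example (full code `here <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_metrics_and_callbacks.py>`__) logs a custom metric from the environment: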
.. code-block:: python
@@ -245,10 +245,10 @@ You can provide callback functions to be called at points during policy evaluati
def on_episode_end(info):
episode = info["episode"]
- mean_pole_angle = np.mean(episode.user_data["pole_angles"])
+ pole_angle = np.mean(episode.user_data["pole_angles"])
print("episode {} ended with length {} and pole angles {}".format(
- episode.episode_id, episode.length, mean_pole_angle))
- episode.custom_metrics["mean_pole_angle"] = mean_pole_angle
+ episode.episode_id, episode.length, pole_angle))
+ episode.custom_metrics["pole_angle"] = pole_angle
def on_train_result(info):
print("agent.train() result: {} -> {} episodes".format(
diff --git a/python/ray/rllib/agents/dqn/dqn_policy_graph.py b/python/ray/rllib/agents/dqn/dqn_policy_graph.py
index c883ef250..af43b7d17 100644
--- a/python/ray/rllib/agents/dqn/dqn_policy_graph.py
+++ b/python/ray/rllib/agents/dqn/dqn_policy_graph.py
@@ -253,6 +253,10 @@ class QLoss(object):
self.td_error = tf.nn.softmax_cross_entropy_with_logits(
labels=m, logits=q_logits_t_selected)
self.loss = tf.reduce_mean(self.td_error * importance_weights)
+ self.stats = {
+ # TODO: better Q stats for dist dqn
+ "mean_td_error": tf.reduce_mean(self.td_error),
+ }
else:
q_tp1_best_masked = (1.0 - done_mask) * q_tp1_best
@@ -264,6 +268,12 @@ class QLoss(object):
q_t_selected - tf.stop_gradient(q_t_selected_target))
self.loss = tf.reduce_mean(
importance_weights * _huber_loss(self.td_error))
+ self.stats = {
+ "mean_q": tf.reduce_mean(q_t_selected),
+ "min_q": tf.reduce_min(q_t_selected),
+ "max_q": tf.reduce_max(q_t_selected),
+ "mean_td_error": tf.reduce_mean(self.td_error),
+ }
class DQNPolicyGraph(TFPolicyGraph):
@@ -430,6 +440,7 @@ class DQNPolicyGraph(TFPolicyGraph):
def extra_compute_grad_fetches(self):
return {
"td_error": self.loss.td_error,
+ "stats": self.loss.stats,
}
def postprocess_trajectory(self,
diff --git a/python/ray/rllib/env/async_vector_env.py b/python/ray/rllib/env/async_vector_env.py
index aff373802..edbf7a233 100644
--- a/python/ray/rllib/env/async_vector_env.py
+++ b/python/ray/rllib/env/async_vector_env.py
@@ -291,6 +291,13 @@ class _MultiAgentEnvToAsync(AsyncVectorEnv):
assert isinstance(rewards, dict), "Not a multi-agent reward"
assert isinstance(dones, dict), "Not a multi-agent return"
assert isinstance(infos, dict), "Not a multi-agent info"
+ if set(obs.keys()) != set(rewards.keys()):
+ raise ValueError(
+ "Key set for obs and rewards must be the same: "
+ "{} vs {}".format(obs.keys(), rewards.keys()))
+ if set(obs.keys()) != set(infos.keys()):
+ raise ValueError("Key set for obs and infos must be the same: "
+ "{} vs {}".format(obs.keys(), infos.keys()))
if dones["__all__"]:
self.dones.add(env_id)
self.env_states[env_id].observe(obs, rewards, dones, infos)
diff --git a/python/ray/rllib/env/multi_agent_env.py b/python/ray/rllib/env/multi_agent_env.py
index 42f7cee8c..2e569230a 100644
--- a/python/ray/rllib/env/multi_agent_env.py
+++ b/python/ray/rllib/env/multi_agent_env.py
@@ -56,7 +56,7 @@ class MultiAgentEnv(object):
rewards (dict): Reward values for each ready agent. If the
episode is just started, the value will be None.
dones (dict): Done values for each ready agent. The special key
- "__all__" is used to indicate env termination.
+ "__all__" (required) is used to indicate env termination.
infos (dict): Info values for each ready agent.
"""
raise NotImplementedError
diff --git a/python/ray/rllib/evaluation/episode.py b/python/ray/rllib/evaluation/episode.py
index 24fa431f9..119777451 100644
--- a/python/ray/rllib/evaluation/episode.py
+++ b/python/ray/rllib/evaluation/episode.py
@@ -60,6 +60,7 @@ class MultiAgentEpisode(object):
self._agent_to_policy = {}
self._agent_to_rnn_state = {}
self._agent_to_last_obs = {}
+ self._agent_to_last_info = {}
self._agent_to_last_action = {}
self._agent_to_last_pi_info = {}
self._agent_to_prev_action = {}
@@ -81,6 +82,11 @@ class MultiAgentEpisode(object):
return self._agent_to_last_obs.get(agent_id)
+ def last_info_for(self, agent_id=_DUMMY_AGENT_ID):
+ """Returns the last info for the specified agent."""
+
+ return self._agent_to_last_info.get(agent_id)
+
def last_action_for(self, agent_id=_DUMMY_AGENT_ID):
"""Returns the last action for the specified agent, or zeros."""
@@ -137,6 +143,9 @@ class MultiAgentEpisode(object):
def _set_last_observation(self, agent_id, obs):
self._agent_to_last_obs[agent_id] = obs
+ def _set_last_info(self, agent_id, info):
+ self._agent_to_last_info[agent_id] = info
+
def _set_last_action(self, agent_id, action):
self._agent_to_last_action[agent_id] = action
diff --git a/python/ray/rllib/evaluation/metrics.py b/python/ray/rllib/evaluation/metrics.py
index fadf2a5a2..92c357d11 100644
--- a/python/ray/rllib/evaluation/metrics.py
+++ b/python/ray/rllib/evaluation/metrics.py
@@ -80,8 +80,16 @@ def summarize_episodes(episodes, new_episodes, num_dropped):
for policy_id, rewards in policy_rewards.copy().items():
policy_rewards[policy_id] = np.mean(rewards)
- for k, v_list in custom_metrics.items():
- custom_metrics[k] = np.mean(v_list)
+ for k, v_list in custom_metrics.copy().items():
+ custom_metrics[k + "_mean"] = np.mean(v_list)
+ filt = [v for v in v_list if not np.isnan(v)]
+ if filt:
+ custom_metrics[k + "_min"] = np.min(filt)
+ custom_metrics[k + "_max"] = np.max(filt)
+ else:
+ custom_metrics[k + "_min"] = float("nan")
+ custom_metrics[k + "_max"] = float("nan")
+ del custom_metrics[k]
return dict(
episode_reward_max=max_reward,
diff --git a/python/ray/rllib/evaluation/policy_evaluator.py b/python/ray/rllib/evaluation/policy_evaluator.py
index 33d5ee219..360139839 100644
--- a/python/ray/rllib/evaluation/policy_evaluator.py
+++ b/python/ray/rllib/evaluation/policy_evaluator.py
@@ -8,7 +8,6 @@ import pickle
import tensorflow as tf
import ray
-from ray.rllib.models import ModelCatalog
from ray.rllib.env.async_vector_env import AsyncVectorEnv
from ray.rllib.env.atari_wrappers import wrap_deepmind, is_atari
from ray.rllib.env.env_context import EnvContext
@@ -19,6 +18,8 @@ from ray.rllib.evaluation.sample_batch import MultiAgentBatch, \
from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
+from ray.rllib.models import ModelCatalog
+from ray.rllib.models.preprocessors import NoPreprocessor
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.compression import pack
from ray.rllib.utils.filter import get_filter
@@ -191,23 +192,21 @@ class PolicyEvaluator(EvaluatorInterface):
self.sample_batch_size = batch_steps * num_envs
self.batch_mode = batch_mode
self.compress_observations = compress_observations
+ self.preprocessing_enabled = True
self.env = env_creator(env_context)
if isinstance(self.env, MultiAgentEnv) or \
isinstance(self.env, AsyncVectorEnv):
- if model_config.get("custom_preprocessor"):
- raise ValueError(
- "Custom preprocessors are not supported for env types "
- "MultiAgentEnv and AsyncVectorEnv. Please preprocess "
- "observations in your env directly.")
-
def wrap(env):
return env # we can't auto-wrap these env types
elif is_atari(self.env) and \
not model_config.get("custom_preprocessor") and \
preprocessor_pref == "deepmind":
+ # Deepmind wrappers already handle all preprocessing
+ self.preprocessing_enabled = False
+
if clip_rewards is None:
clip_rewards = True
@@ -222,8 +221,6 @@ class PolicyEvaluator(EvaluatorInterface):
else:
def wrap(env):
- env = ModelCatalog.get_preprocessor_as_wrapper(
- env, model_config)
if monitor_path:
env = _monitor(env, monitor_path)
return env
@@ -246,11 +243,11 @@ class PolicyEvaluator(EvaluatorInterface):
config=tf.ConfigProto(
gpu_options=tf.GPUOptions(allow_growth=True)))
with self.tf_sess.as_default():
- self.policy_map = self._build_policy_map(
- policy_dict, policy_config)
+ self.policy_map, self.preprocessors = \
+ self._build_policy_map(policy_dict, policy_config)
else:
- self.policy_map = self._build_policy_map(policy_dict,
- policy_config)
+ self.policy_map, self.preprocessors = self._build_policy_map(
+ policy_dict, policy_config)
self.multiagent = set(self.policy_map.keys()) != {DEFAULT_POLICY_ID}
if self.multiagent:
@@ -286,6 +283,7 @@ class PolicyEvaluator(EvaluatorInterface):
self.async_env,
self.policy_map,
policy_mapping_fn,
+ self.preprocessors,
self.filters,
clip_rewards,
unroll_length,
@@ -300,6 +298,7 @@ class PolicyEvaluator(EvaluatorInterface):
self.async_env,
self.policy_map,
policy_mapping_fn,
+ self.preprocessors,
self.filters,
clip_rewards,
unroll_length,
@@ -314,24 +313,26 @@ class PolicyEvaluator(EvaluatorInterface):
def _build_policy_map(self, policy_dict, policy_config):
policy_map = {}
+ preprocessors = {}
for name, (cls, obs_space, act_space,
conf) in sorted(policy_dict.items()):
merged_conf = merge_dicts(policy_config, conf)
+ if self.preprocessing_enabled:
+ preprocessor = ModelCatalog.get_preprocessor_for_space(
+ obs_space, merged_conf.get("model"))
+ preprocessors[name] = preprocessor
+ obs_space = preprocessor.observation_space
+ else:
+ preprocessors[name] = NoPreprocessor(obs_space)
+ if isinstance(obs_space, gym.spaces.Dict) or \
+ isinstance(obs_space, gym.spaces.Tuple):
+ raise ValueError(
+ "Found raw Tuple|Dict space as input to policy graph. "
+ "Please preprocess these observations with a "
+ "Tuple|DictFlatteningPreprocessor.")
with tf.variable_scope(name):
- if isinstance(obs_space, gym.spaces.Dict):
- raise ValueError(
- "Found raw Dict space as input to policy graph. "
- "Please preprocess your environment observations "
- "with DictFlatteningPreprocessor and set the "
- "obs space to `preprocessor.observation_space`.")
- elif isinstance(obs_space, gym.spaces.Tuple):
- raise ValueError(
- "Found raw Tuple space as input to policy graph. "
- "Please preprocess your environment observations "
- "with TupleFlatteningPreprocessor and set the "
- "obs space to `preprocessor.observation_space`.")
policy_map[name] = cls(obs_space, act_space, merged_conf)
- return policy_map
+ return policy_map, preprocessors
def sample(self):
"""Evaluate the current policies and return a batch of experiences.
@@ -554,6 +555,11 @@ def _validate_and_canonicalize(policy_graph, env):
elif not issubclass(policy_graph, PolicyGraph):
raise ValueError("policy_graph must be a rllib.PolicyGraph class")
else:
+ if (isinstance(env, MultiAgentEnv)
+ and not hasattr(env, "observation_space")):
+ raise ValueError(
+ "MultiAgentEnv must have observation_space defined if run "
+ "in a single-agent configuration.")
return {
DEFAULT_POLICY_ID: (policy_graph, env.observation_space,
env.action_space, {})
diff --git a/python/ray/rllib/evaluation/sample_batch.py b/python/ray/rllib/evaluation/sample_batch.py
index caec1bf43..83f2f5ca1 100644
--- a/python/ray/rllib/evaluation/sample_batch.py
+++ b/python/ray/rllib/evaluation/sample_batch.py
@@ -79,6 +79,11 @@ class MultiAgentSampleBatchBuilder(object):
self.agent_to_policy = {}
self.count = 0 # increment this manually
+ def total(self):
+ """Returns summed number of steps across all agent buffers."""
+
+ return sum(p.count for p in self.policy_builders.values())
+
def has_pending_data(self):
"""Returns whether there is pending unprocessed data."""
diff --git a/python/ray/rllib/evaluation/sampler.py b/python/ray/rllib/evaluation/sampler.py
index 0bda18bc0..cc9e32295 100644
--- a/python/ray/rllib/evaluation/sampler.py
+++ b/python/ray/rllib/evaluation/sampler.py
@@ -19,6 +19,7 @@ from ray.rllib.models.action_dist import TupleActions
from ray.rllib.utils.tf_run_builder import TFRunBuilder
logger = logging.getLogger(__name__)
+_large_batch_warned = False
RolloutMetrics = namedtuple(
"RolloutMetrics",
@@ -42,6 +43,7 @@ class SyncSampler(object):
env,
policies,
policy_mapping_fn,
+ preprocessors,
obs_filters,
clip_rewards,
unroll_length,
@@ -55,13 +57,14 @@ class SyncSampler(object):
self.horizon = horizon
self.policies = policies
self.policy_mapping_fn = policy_mapping_fn
- self._obs_filters = obs_filters
+ self.preprocessors = preprocessors
+ self.obs_filters = obs_filters
self.extra_batches = queue.Queue()
self.rollout_provider = _env_runner(
self.async_vector_env, self.extra_batches.put, self.policies,
self.policy_mapping_fn, self.unroll_length, self.horizon,
- self._obs_filters, clip_rewards, clip_actions, pack, callbacks,
- tf_sess)
+ self.preprocessors, self.obs_filters, clip_rewards, clip_actions,
+ pack, callbacks, tf_sess)
self.metrics_queue = queue.Queue()
def get_data(self):
@@ -101,6 +104,7 @@ class AsyncSampler(threading.Thread):
env,
policies,
policy_mapping_fn,
+ preprocessors,
obs_filters,
clip_rewards,
unroll_length,
@@ -121,7 +125,8 @@ class AsyncSampler(threading.Thread):
self.horizon = horizon
self.policies = policies
self.policy_mapping_fn = policy_mapping_fn
- self._obs_filters = obs_filters
+ self.preprocessors = preprocessors
+ self.obs_filters = obs_filters
self.clip_rewards = clip_rewards
self.daemon = True
self.pack = pack
@@ -140,8 +145,8 @@ class AsyncSampler(threading.Thread):
rollout_provider = _env_runner(
self.async_vector_env, self.extra_batches.put, self.policies,
self.policy_mapping_fn, self.unroll_length, self.horizon,
- self._obs_filters, self.clip_rewards, self.clip_actions, self.pack,
- self.callbacks, self.tf_sess)
+ self.preprocessors, self.obs_filters, self.clip_rewards,
+ self.clip_actions, self.pack, self.callbacks, self.tf_sess)
while True:
# The timeout variable exists because apparently, if one worker
# dies, the other workers won't die with it, unless the timeout is
@@ -200,6 +205,7 @@ def _env_runner(async_vector_env,
policy_mapping_fn,
unroll_length,
horizon,
+ preprocessors,
obs_filters,
clip_rewards,
clip_actions,
@@ -218,6 +224,8 @@ def _env_runner(async_vector_env,
unroll_length (int): Number of episode steps before `SampleBatch` is
yielded. Set to infinity to yield complete episodes.
horizon (int): Horizon of the episode.
+ preprocessors (dict): Map of policy id to preprocessor for the
+ observations prior to filtering.
obs_filters (dict): Map of policy id to filter used to process
observations for the policy.
clip_rewards (bool): Whether to clip rewards before postprocessing.
@@ -273,7 +281,7 @@ def _env_runner(async_vector_env,
active_envs, to_eval, outputs = _process_observations(
async_vector_env, policies, batch_builder_pool, active_episodes,
unfiltered_obs, rewards, dones, infos, off_policy_actions, horizon,
- obs_filters, unroll_length, pack, callbacks)
+ preprocessors, obs_filters, unroll_length, pack, callbacks)
for o in outputs:
yield o
@@ -293,8 +301,8 @@ def _env_runner(async_vector_env,
def _process_observations(async_vector_env, policies, batch_builder_pool,
active_episodes, unfiltered_obs, rewards, dones,
- infos, off_policy_actions, horizon, obs_filters,
- unroll_length, pack, callbacks):
+ infos, off_policy_actions, horizon, preprocessors,
+ obs_filters, unroll_length, pack, callbacks):
"""Record new data from the environment and prepare for policy evaluation.
Returns:
@@ -316,6 +324,21 @@ def _process_observations(async_vector_env, policies, batch_builder_pool,
episode.batch_builder.count += 1
episode._add_agent_rewards(rewards[env_id])
+ global _large_batch_warned
+ if (not _large_batch_warned and
+ episode.batch_builder.total() > max(1000, unroll_length * 10)):
+ _large_batch_warned = True
+ logger.warn(
+ "More than {} observations for {} env steps ".format(
+ episode.batch_builder.total(),
+ episode.batch_builder.count) + "are buffered in "
+ "the sampler. If this is not intentional, check that the "
+ "`horizon` config is set correctly, or consider setting "
+ "`batch_mode` to 'truncate_episodes'. Note that in "
+ "multi-agent environments, `sample_batch_size` sets the "
+ "batch size based on environment steps, not the steps of "
+ "individual agents.")
+
# Check episode termination conditions
if dones[env_id]["__all__"] or episode.length >= horizon:
all_done = True
@@ -336,7 +359,9 @@ def _process_observations(async_vector_env, policies, batch_builder_pool,
# For each agent in the environment
for agent_id, raw_obs in agent_obs.items():
policy_id = episode.policy_for(agent_id)
- filtered_obs = _get_or_raise(obs_filters, policy_id)(raw_obs)
+ prep_obs = _get_or_raise(preprocessors,
+ policy_id).transform(raw_obs)
+ filtered_obs = _get_or_raise(obs_filters, policy_id)(prep_obs)
agent_done = bool(all_done or dones[env_id].get(agent_id))
if not agent_done:
to_eval[policy_id].append(
@@ -347,6 +372,7 @@ def _process_observations(async_vector_env, policies, batch_builder_pool,
last_observation = episode.last_observation_for(agent_id)
episode._set_last_observation(agent_id, filtered_obs)
+ episode._set_last_info(agent_id, infos[env_id][agent_id])
# Record transition info if applicable
if last_observation is not None and \
@@ -406,8 +432,10 @@ def _process_observations(async_vector_env, policies, batch_builder_pool,
for agent_id, raw_obs in resetted_obs.items():
policy_id = episode.policy_for(agent_id)
policy = _get_or_raise(policies, policy_id)
+ prep_obs = _get_or_raise(preprocessors,
+ policy_id).transform(raw_obs)
filtered_obs = _get_or_raise(obs_filters,
- policy_id)(raw_obs)
+ policy_id)(prep_obs)
episode._set_last_observation(agent_id, filtered_obs)
to_eval[policy_id].append(
PolicyEvalData(
diff --git a/python/ray/rllib/examples/custom_metrics_and_callbacks.py b/python/ray/rllib/examples/custom_metrics_and_callbacks.py
index c92ae8783..af1d25f16 100644
--- a/python/ray/rllib/examples/custom_metrics_and_callbacks.py
+++ b/python/ray/rllib/examples/custom_metrics_and_callbacks.py
@@ -25,10 +25,10 @@ def on_episode_step(info):
def on_episode_end(info):
episode = info["episode"]
- mean_pole_angle = np.mean(episode.user_data["pole_angles"])
+ pole_angle = np.mean(episode.user_data["pole_angles"])
print("episode {} ended with length {} and pole angles {}".format(
- episode.episode_id, episode.length, mean_pole_angle))
- episode.custom_metrics["mean_pole_angle"] = mean_pole_angle
+ episode.episode_id, episode.length, pole_angle))
+ episode.custom_metrics["pole_angle"] = pole_angle
def on_sample_end(info):
@@ -70,6 +70,8 @@ if __name__ == "__main__":
# verify custom metrics for integration tests
custom_metrics = trials[0].last_result["custom_metrics"]
print(custom_metrics)
- assert "mean_pole_angle" in custom_metrics
- assert type(custom_metrics["mean_pole_angle"]) is float
+ assert "pole_angle_mean" in custom_metrics
+ assert "pole_angle_min" in custom_metrics
+ assert "pole_angle_max" in custom_metrics
+ assert type(custom_metrics["pole_angle_mean"]) is float
assert "callback_ok" in trials[0].last_result
diff --git a/python/ray/rllib/models/catalog.py b/python/ray/rllib/models/catalog.py
index f9e8af282..822af4a37 100644
--- a/python/ray/rllib/models/catalog.py
+++ b/python/ray/rllib/models/catalog.py
@@ -271,15 +271,26 @@ class ModelCatalog(object):
@staticmethod
def get_preprocessor(env, options=None):
- """Returns a suitable processor for the given environment.
+ """Returns a suitable preprocessor for the given env.
+
+ This is a wrapper for get_preprocessor_for_space().
+ """
+
+ return ModelCatalog.get_preprocessor_for_space(env.observation_space,
+ options)
+
+ @staticmethod
+ def get_preprocessor_for_space(observation_space, options=None):
+ """Returns a suitable preprocessor for the given observation space.
Args:
- env (gym.Env|VectorEnv|ExternalEnv): The environment to wrap.
+ observation_space (Space): The input observation space.
options (dict): Options to pass to the preprocessor.
Returns:
- preprocessor (Preprocessor): Preprocessor for the env observations.
+ preprocessor (Preprocessor): Preprocessor for the observations.
"""
+
options = options or MODEL_DEFAULTS
for k in options.keys():
if k not in MODEL_DEFAULTS:
@@ -290,13 +301,13 @@ class ModelCatalog(object):
preprocessor = options["custom_preprocessor"]
logger.info("Using custom preprocessor {}".format(preprocessor))
prep = _global_registry.get(RLLIB_PREPROCESSOR, preprocessor)(
- env.observation_space, options)
+ observation_space, options)
else:
- cls = get_preprocessor(env.observation_space)
- prep = cls(env.observation_space, options)
+ cls = get_preprocessor(observation_space)
+ prep = cls(observation_space, options)
logger.debug("Created preprocessor {}: {} -> {}".format(
- prep, env.observation_space, prep.shape))
+ prep, observation_space, prep.shape))
return prep
@staticmethod
diff --git a/python/ray/rllib/models/preprocessors.py b/python/ray/rllib/models/preprocessors.py
index a4af708b7..66835e864 100644
--- a/python/ray/rllib/models/preprocessors.py
+++ b/python/ray/rllib/models/preprocessors.py
@@ -109,6 +109,9 @@ class OneHotPreprocessor(Preprocessor):
def transform(self, observation):
arr = np.zeros(self._obs_space.n)
+ if not self._obs_space.contains(observation):
+ raise ValueError("Observation outside expected value range",
+ self._obs_space, observation)
arr[observation] = 1
return arr
diff --git a/python/ray/rllib/optimizers/async_replay_optimizer.py b/python/ray/rllib/optimizers/async_replay_optimizer.py
index 3cd5a16ad..1f2167408 100644
--- a/python/ray/rllib/optimizers/async_replay_optimizer.py
+++ b/python/ray/rllib/optimizers/async_replay_optimizer.py
@@ -135,6 +135,7 @@ class LearnerThread(threading.Thread):
self.daemon = True
self.weights_updated = False
self.stopped = False
+ self.stats = {}
def run(self):
while not self.stopped:
@@ -151,6 +152,8 @@ class LearnerThread(threading.Thread):
prio_dict[pid] = (
replay.policy_batches[pid]["batch_indexes"],
info["td_error"])
+ if "stats" in info:
+ self.stats[pid] = info["stats"]
# send `replay` back also so that it gets released by the original
# thread: https://github.com/ray-project/ray/issues/2610
self.outqueue.put((ra, replay, prio_dict, replay.count))
@@ -331,4 +334,6 @@ class AsyncReplayOptimizer(PolicyOptimizer):
}
if self.debug:
stats.update(debug_stats)
+ if self.learner.stats:
+ stats["learner"] = self.learner.stats
return dict(PolicyOptimizer.stats(self), **stats)
diff --git a/python/ray/rllib/optimizers/sync_replay_optimizer.py b/python/ray/rllib/optimizers/sync_replay_optimizer.py
index 73df00601..b15ab3390 100644
--- a/python/ray/rllib/optimizers/sync_replay_optimizer.py
+++ b/python/ray/rllib/optimizers/sync_replay_optimizer.py
@@ -53,6 +53,7 @@ class SyncReplayOptimizer(PolicyOptimizer):
self.replay_timer = TimerStat()
self.grad_timer = TimerStat()
self.throughput = RunningStat()
+ self.learner_stats = {}
# Set up replay buffer
if prioritized_replay:
@@ -111,6 +112,8 @@ class SyncReplayOptimizer(PolicyOptimizer):
with self.grad_timer:
info_dict = self.local_evaluator.compute_apply(samples)
for policy_id, info in info_dict.items():
+ if "stats" in info:
+ self.learner_stats[policy_id] = info["stats"]
replay_buffer = self.replay_buffers[policy_id]
if isinstance(replay_buffer, PrioritizedReplayBuffer):
td_error = info["td_error"]
@@ -160,4 +163,5 @@ class SyncReplayOptimizer(PolicyOptimizer):
"opt_peak_throughput": round(self.grad_timer.mean_throughput,
3),
"opt_samples": round(self.grad_timer.mean_units_processed, 3),
+ "learner": self.learner_stats,
})
diff --git a/python/ray/rllib/test/test_env_with_subprocess.py b/python/ray/rllib/test/test_env_with_subprocess.py
index 70ccb46cc..fc940cdea 100644
--- a/python/ray/rllib/test/test_env_with_subprocess.py
+++ b/python/ray/rllib/test/test_env_with_subprocess.py
@@ -38,10 +38,10 @@ class EnvWithSubprocess(gym.Env):
atexit.register(lambda: self.subproc.kill())
def reset(self):
- return [0]
+ return 0
def step(self, action):
- return [0], 0, True, {}
+ return 0, 0, True, {}
def leaked_processes():
diff --git a/python/ray/rllib/test/test_multi_agent_env.py b/python/ray/rllib/test/test_multi_agent_env.py
index 5712390c0..1fdfa5d74 100644
--- a/python/ray/rllib/test/test_multi_agent_env.py
+++ b/python/ray/rllib/test/test_multi_agent_env.py
@@ -22,6 +22,12 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.tune.registry import register_env
+def one_hot(i, n):
+ out = [0.0] * n
+ out[i] = 1.0
+ return out
+
+
class BasicMultiAgent(MultiAgentEnv):
"""Env of N independent agents, each of which exits after 25 steps."""
@@ -64,7 +70,7 @@ class RoundRobinMultiAgent(MultiAgentEnv):
self.last_info = {}
self.i = 0
self.num = num
- self.observation_space = gym.spaces.Discrete(2)
+ self.observation_space = gym.spaces.Discrete(10)
self.action_space = gym.spaces.Discrete(2)
def reset(self):
@@ -290,7 +296,7 @@ class TestMultiAgentEnv(unittest.TestCase):
def testMultiAgentSampleRoundRobin(self):
act_space = gym.spaces.Discrete(2)
- obs_space = gym.spaces.Discrete(2)
+ obs_space = gym.spaces.Discrete(10)
ev = PolicyEvaluator(
env_creator=lambda _: RoundRobinMultiAgent(5, increment_obs=True),
policy_graph={
@@ -303,10 +309,20 @@ class TestMultiAgentEnv(unittest.TestCase):
# since we round robin introduce agents into the env, some of the env
# steps don't count as proper transitions
self.assertEqual(batch.policy_batches["p0"].count, 42)
- self.assertEqual(batch.policy_batches["p0"]["obs"].tolist()[:10],
- [0, 1, 2, 3, 4] * 2)
- self.assertEqual(batch.policy_batches["p0"]["new_obs"].tolist()[:10],
- [1, 2, 3, 4, 5] * 2)
+ self.assertEqual(batch.policy_batches["p0"]["obs"].tolist()[:10], [
+ one_hot(0, 10),
+ one_hot(1, 10),
+ one_hot(2, 10),
+ one_hot(3, 10),
+ one_hot(4, 10),
+ ] * 2)
+ self.assertEqual(batch.policy_batches["p0"]["new_obs"].tolist()[:10], [
+ one_hot(1, 10),
+ one_hot(2, 10),
+ one_hot(3, 10),
+ one_hot(4, 10),
+ one_hot(5, 10),
+ ] * 2)
self.assertEqual(batch.policy_batches["p0"]["rewards"].tolist()[:10],
[100, 100, 100, 100, 0] * 2)
self.assertEqual(batch.policy_batches["p0"]["dones"].tolist()[:10],
diff --git a/python/ray/rllib/test/test_nested_spaces.py b/python/ray/rllib/test/test_nested_spaces.py
index 490e6af15..95744b7e2 100644
--- a/python/ray/rllib/test/test_nested_spaces.py
+++ b/python/ray/rllib/test/test_nested_spaces.py
@@ -13,6 +13,8 @@ import unittest
import ray
from ray.rllib.agents.pg import PGAgent
+from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
+from ray.rllib.env import MultiAgentEnv
from ray.rllib.env.async_vector_env import AsyncVectorEnv
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.models import ModelCatalog
@@ -88,6 +90,34 @@ class NestedTupleEnv(gym.Env):
return TUPLE_SAMPLES[self.steps], 1, self.steps >= 5, {}
+class NestedMultiAgentEnv(MultiAgentEnv):
+ def __init__(self):
+ self.steps = 0
+
+ def reset(self):
+ return {
+ "dict_agent": DICT_SAMPLES[0],
+ "tuple_agent": TUPLE_SAMPLES[0],
+ }
+
+ def step(self, actions):
+ self.steps += 1
+ obs = {
+ "dict_agent": DICT_SAMPLES[self.steps],
+ "tuple_agent": TUPLE_SAMPLES[self.steps],
+ }
+ rew = {
+ "dict_agent": 0,
+ "tuple_agent": 0,
+ }
+ dones = {"__all__": self.steps >= 5}
+ infos = {
+ "dict_agent": {},
+ "tuple_agent": {},
+ }
+ return obs, rew, dones, infos
+
+
class InvalidModel(Model):
def _build_layers_v2(self, input_dict, num_outputs, options):
return "not", "valid"
@@ -107,7 +137,8 @@ class DictSpyModel(Model):
# redis to communicate back to our suite
ray.experimental.internal_kv._internal_kv_put(
"d_spy_in_{}".format(DictSpyModel.capture_index),
- pickle.dumps((pos, front_cam, task)))
+ pickle.dumps((pos, front_cam, task)),
+ overwrite=True)
DictSpyModel.capture_index += 1
return 0
@@ -135,7 +166,8 @@ class TupleSpyModel(Model):
# redis to communicate back to our suite
ray.experimental.internal_kv._internal_kv_put(
"t_spy_in_{}".format(TupleSpyModel.capture_index),
- pickle.dumps((pos, cam, task)))
+ pickle.dumps((pos, cam, task)),
+ overwrite=True)
TupleSpyModel.capture_index += 1
return 0
@@ -242,10 +274,8 @@ class NestedSpacesTest(unittest.TestCase):
self.doTestNestedDict(lambda _: SimpleServing(NestedDictEnv()))
def testNestedDictAsync(self):
- self.assertRaisesRegexp(
- ValueError, "Found raw Dict space.*",
- lambda: self.doTestNestedDict(
- lambda _: AsyncVectorEnv.wrap_async(NestedDictEnv())))
+ self.doTestNestedDict(
+ lambda _: AsyncVectorEnv.wrap_async(NestedDictEnv()))
def testNestedTupleGym(self):
self.doTestNestedTuple(lambda _: NestedTupleEnv())
@@ -258,10 +288,57 @@ class NestedSpacesTest(unittest.TestCase):
self.doTestNestedTuple(lambda _: SimpleServing(NestedTupleEnv()))
def testNestedTupleAsync(self):
- self.assertRaisesRegexp(
- ValueError, "Found raw Tuple space.*",
- lambda: self.doTestNestedTuple(
- lambda _: AsyncVectorEnv.wrap_async(NestedTupleEnv())))
+ self.doTestNestedTuple(
+ lambda _: AsyncVectorEnv.wrap_async(NestedTupleEnv()))
+
+ def testMultiAgentComplexSpaces(self):
+ ModelCatalog.register_custom_model("dict_spy", DictSpyModel)
+ ModelCatalog.register_custom_model("tuple_spy", TupleSpyModel)
+ register_env("nested_ma", lambda _: NestedMultiAgentEnv())
+ act_space = spaces.Discrete(2)
+ pg = PGAgent(
+ env="nested_ma",
+ config={
+ "num_workers": 0,
+ "sample_batch_size": 5,
+ "multiagent": {
+ "policy_graphs": {
+ "tuple_policy": (
+ PGPolicyGraph, TUPLE_SPACE, act_space,
+ {"model": {"custom_model": "tuple_spy"}}),
+ "dict_policy": (
+ PGPolicyGraph, DICT_SPACE, act_space,
+ {"model": {"custom_model": "dict_spy"}}),
+ },
+ "policy_mapping_fn": lambda a: {
+ "tuple_agent": "tuple_policy",
+ "dict_agent": "dict_policy"}[a],
+ },
+ })
+ pg.train()
+
+ for i in range(4):
+ seen = pickle.loads(
+ ray.experimental.internal_kv._internal_kv_get(
+ "d_spy_in_{}".format(i)))
+ pos_i = DICT_SAMPLES[i]["sensors"]["position"].tolist()
+ cam_i = DICT_SAMPLES[i]["sensors"]["front_cam"][0].tolist()
+ task_i = one_hot(
+ DICT_SAMPLES[i]["inner_state"]["job_status"]["task"], 5)
+ self.assertEqual(seen[0][0].tolist(), pos_i)
+ self.assertEqual(seen[1][0].tolist(), cam_i)
+ self.assertEqual(seen[2][0].tolist(), task_i)
+
+ for i in range(4):
+ seen = pickle.loads(
+ ray.experimental.internal_kv._internal_kv_get(
+ "t_spy_in_{}".format(i)))
+ pos_i = TUPLE_SAMPLES[i][0].tolist()
+ cam_i = TUPLE_SAMPLES[i][1][0].tolist()
+ task_i = one_hot(TUPLE_SAMPLES[i][2], 5)
+ self.assertEqual(seen[0][0].tolist(), pos_i)
+ self.assertEqual(seen[1][0].tolist(), cam_i)
+ self.assertEqual(seen[2][0].tolist(), task_i)
if __name__ == "__main__":
diff --git a/python/ray/rllib/utils/filter.py b/python/ray/rllib/utils/filter.py
index fbdb39ae1..9a1f37dbd 100644
--- a/python/ray/rllib/utils/filter.py
+++ b/python/ray/rllib/utils/filter.py
@@ -2,9 +2,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import logging
import numpy as np
import threading
+logger = logging.getLogger(__name__)
+
class Filter(object):
"""Processes input, possibly statefully."""
@@ -39,7 +42,10 @@ class NoFilter(Filter):
pass
def __call__(self, x, update=True):
- return np.asarray(x)
+ try:
+ return np.asarray(x)
+ except Exception:
+ raise ValueError("Failed to convert to array", x)
def apply_changes(self, other, *args, **kwargs):
pass