mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 05:16:11 +08:00
[rllib] fixes from dogfooding multi-agent (#3456)
- auto wrap multi-agent dict and tuple spaces by keeping a policy -> preprocessor mapping in the sampler
- add some Q-learning debug stats
- report min, max of custom metrics
- better errors
This commit is contained in:
@@ -253,6 +253,10 @@ class QLoss(object):
|
||||
self.td_error = tf.nn.softmax_cross_entropy_with_logits(
|
||||
labels=m, logits=q_logits_t_selected)
|
||||
self.loss = tf.reduce_mean(self.td_error * importance_weights)
|
||||
self.stats = {
|
||||
# TODO: better Q stats for dist dqn
|
||||
"mean_td_error": tf.reduce_mean(self.td_error),
|
||||
}
|
||||
else:
|
||||
q_tp1_best_masked = (1.0 - done_mask) * q_tp1_best
|
||||
|
||||
@@ -264,6 +268,12 @@ class QLoss(object):
|
||||
q_t_selected - tf.stop_gradient(q_t_selected_target))
|
||||
self.loss = tf.reduce_mean(
|
||||
importance_weights * _huber_loss(self.td_error))
|
||||
self.stats = {
|
||||
"mean_q": tf.reduce_mean(q_t_selected),
|
||||
"min_q": tf.reduce_min(q_t_selected),
|
||||
"max_q": tf.reduce_max(q_t_selected),
|
||||
"mean_td_error": tf.reduce_mean(self.td_error),
|
||||
}
|
||||
|
||||
|
||||
class DQNPolicyGraph(TFPolicyGraph):
|
||||
@@ -430,6 +440,7 @@ class DQNPolicyGraph(TFPolicyGraph):
|
||||
def extra_compute_grad_fetches(self):
|
||||
return {
|
||||
"td_error": self.loss.td_error,
|
||||
"stats": self.loss.stats,
|
||||
}
|
||||
|
||||
def postprocess_trajectory(self,
|
||||
|
||||
+7
@@ -291,6 +291,13 @@ class _MultiAgentEnvToAsync(AsyncVectorEnv):
|
||||
assert isinstance(rewards, dict), "Not a multi-agent reward"
|
||||
assert isinstance(dones, dict), "Not a multi-agent return"
|
||||
assert isinstance(infos, dict), "Not a multi-agent info"
|
||||
if set(obs.keys()) != set(rewards.keys()):
|
||||
raise ValueError(
|
||||
"Key set for obs and rewards must be the same: "
|
||||
"{} vs {}".format(obs.keys(), rewards.keys()))
|
||||
if set(obs.keys()) != set(infos.keys()):
|
||||
raise ValueError("Key set for obs and infos must be the same: "
|
||||
"{} vs {}".format(obs.keys(), infos.keys()))
|
||||
if dones["__all__"]:
|
||||
self.dones.add(env_id)
|
||||
self.env_states[env_id].observe(obs, rewards, dones, infos)
|
||||
|
||||
+1
-1
@@ -56,7 +56,7 @@ class MultiAgentEnv(object):
|
||||
rewards (dict): Reward values for each ready agent. If the
|
||||
episode is just started, the value will be None.
|
||||
dones (dict): Done values for each ready agent. The special key
|
||||
"__all__" is used to indicate env termination.
|
||||
"__all__" (required) is used to indicate env termination.
|
||||
infos (dict): Info values for each ready agent.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -60,6 +60,7 @@ class MultiAgentEpisode(object):
|
||||
self._agent_to_policy = {}
|
||||
self._agent_to_rnn_state = {}
|
||||
self._agent_to_last_obs = {}
|
||||
self._agent_to_last_info = {}
|
||||
self._agent_to_last_action = {}
|
||||
self._agent_to_last_pi_info = {}
|
||||
self._agent_to_prev_action = {}
|
||||
@@ -81,6 +82,11 @@ class MultiAgentEpisode(object):
|
||||
|
||||
return self._agent_to_last_obs.get(agent_id)
|
||||
|
||||
def last_info_for(self, agent_id=_DUMMY_AGENT_ID):
|
||||
"""Returns the last info for the specified agent."""
|
||||
|
||||
return self._agent_to_last_info.get(agent_id)
|
||||
|
||||
def last_action_for(self, agent_id=_DUMMY_AGENT_ID):
|
||||
"""Returns the last action for the specified agent, or zeros."""
|
||||
|
||||
@@ -137,6 +143,9 @@ class MultiAgentEpisode(object):
|
||||
def _set_last_observation(self, agent_id, obs):
|
||||
self._agent_to_last_obs[agent_id] = obs
|
||||
|
||||
def _set_last_info(self, agent_id, info):
|
||||
self._agent_to_last_info[agent_id] = info
|
||||
|
||||
def _set_last_action(self, agent_id, action):
|
||||
self._agent_to_last_action[agent_id] = action
|
||||
|
||||
|
||||
@@ -80,8 +80,16 @@ def summarize_episodes(episodes, new_episodes, num_dropped):
|
||||
for policy_id, rewards in policy_rewards.copy().items():
|
||||
policy_rewards[policy_id] = np.mean(rewards)
|
||||
|
||||
for k, v_list in custom_metrics.items():
|
||||
custom_metrics[k] = np.mean(v_list)
|
||||
for k, v_list in custom_metrics.copy().items():
|
||||
custom_metrics[k + "_mean"] = np.mean(v_list)
|
||||
filt = [v for v in v_list if not np.isnan(v)]
|
||||
if filt:
|
||||
custom_metrics[k + "_min"] = np.min(filt)
|
||||
custom_metrics[k + "_max"] = np.max(filt)
|
||||
else:
|
||||
custom_metrics[k + "_min"] = float("nan")
|
||||
custom_metrics[k + "_max"] = float("nan")
|
||||
del custom_metrics[k]
|
||||
|
||||
return dict(
|
||||
episode_reward_max=max_reward,
|
||||
|
||||
@@ -8,7 +8,6 @@ import pickle
|
||||
import tensorflow as tf
|
||||
|
||||
import ray
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.env.async_vector_env import AsyncVectorEnv
|
||||
from ray.rllib.env.atari_wrappers import wrap_deepmind, is_atari
|
||||
from ray.rllib.env.env_context import EnvContext
|
||||
@@ -19,6 +18,8 @@ from ray.rllib.evaluation.sample_batch import MultiAgentBatch, \
|
||||
from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.models.preprocessors import NoPreprocessor
|
||||
from ray.rllib.utils import merge_dicts
|
||||
from ray.rllib.utils.compression import pack
|
||||
from ray.rllib.utils.filter import get_filter
|
||||
@@ -191,23 +192,21 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
self.sample_batch_size = batch_steps * num_envs
|
||||
self.batch_mode = batch_mode
|
||||
self.compress_observations = compress_observations
|
||||
self.preprocessing_enabled = True
|
||||
|
||||
self.env = env_creator(env_context)
|
||||
if isinstance(self.env, MultiAgentEnv) or \
|
||||
isinstance(self.env, AsyncVectorEnv):
|
||||
|
||||
if model_config.get("custom_preprocessor"):
|
||||
raise ValueError(
|
||||
"Custom preprocessors are not supported for env types "
|
||||
"MultiAgentEnv and AsyncVectorEnv. Please preprocess "
|
||||
"observations in your env directly.")
|
||||
|
||||
def wrap(env):
|
||||
return env # we can't auto-wrap these env types
|
||||
elif is_atari(self.env) and \
|
||||
not model_config.get("custom_preprocessor") and \
|
||||
preprocessor_pref == "deepmind":
|
||||
|
||||
# Deepmind wrappers already handle all preprocessing
|
||||
self.preprocessing_enabled = False
|
||||
|
||||
if clip_rewards is None:
|
||||
clip_rewards = True
|
||||
|
||||
@@ -222,8 +221,6 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
else:
|
||||
|
||||
def wrap(env):
|
||||
env = ModelCatalog.get_preprocessor_as_wrapper(
|
||||
env, model_config)
|
||||
if monitor_path:
|
||||
env = _monitor(env, monitor_path)
|
||||
return env
|
||||
@@ -246,11 +243,11 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
config=tf.ConfigProto(
|
||||
gpu_options=tf.GPUOptions(allow_growth=True)))
|
||||
with self.tf_sess.as_default():
|
||||
self.policy_map = self._build_policy_map(
|
||||
policy_dict, policy_config)
|
||||
self.policy_map, self.preprocessors = \
|
||||
self._build_policy_map(policy_dict, policy_config)
|
||||
else:
|
||||
self.policy_map = self._build_policy_map(policy_dict,
|
||||
policy_config)
|
||||
self.policy_map, self.preprocessors = self._build_policy_map(
|
||||
policy_dict, policy_config)
|
||||
|
||||
self.multiagent = set(self.policy_map.keys()) != {DEFAULT_POLICY_ID}
|
||||
if self.multiagent:
|
||||
@@ -286,6 +283,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
self.async_env,
|
||||
self.policy_map,
|
||||
policy_mapping_fn,
|
||||
self.preprocessors,
|
||||
self.filters,
|
||||
clip_rewards,
|
||||
unroll_length,
|
||||
@@ -300,6 +298,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
self.async_env,
|
||||
self.policy_map,
|
||||
policy_mapping_fn,
|
||||
self.preprocessors,
|
||||
self.filters,
|
||||
clip_rewards,
|
||||
unroll_length,
|
||||
@@ -314,24 +313,26 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
|
||||
def _build_policy_map(self, policy_dict, policy_config):
|
||||
policy_map = {}
|
||||
preprocessors = {}
|
||||
for name, (cls, obs_space, act_space,
|
||||
conf) in sorted(policy_dict.items()):
|
||||
merged_conf = merge_dicts(policy_config, conf)
|
||||
if self.preprocessing_enabled:
|
||||
preprocessor = ModelCatalog.get_preprocessor_for_space(
|
||||
obs_space, merged_conf.get("model"))
|
||||
preprocessors[name] = preprocessor
|
||||
obs_space = preprocessor.observation_space
|
||||
else:
|
||||
preprocessors[name] = NoPreprocessor(obs_space)
|
||||
if isinstance(obs_space, gym.spaces.Dict) or \
|
||||
isinstance(obs_space, gym.spaces.Tuple):
|
||||
raise ValueError(
|
||||
"Found raw Tuple|Dict space as input to policy graph. "
|
||||
"Please preprocess these observations with a "
|
||||
"Tuple|DictFlatteningPreprocessor.")
|
||||
with tf.variable_scope(name):
|
||||
if isinstance(obs_space, gym.spaces.Dict):
|
||||
raise ValueError(
|
||||
"Found raw Dict space as input to policy graph. "
|
||||
"Please preprocess your environment observations "
|
||||
"with DictFlatteningPreprocessor and set the "
|
||||
"obs space to `preprocessor.observation_space`.")
|
||||
elif isinstance(obs_space, gym.spaces.Tuple):
|
||||
raise ValueError(
|
||||
"Found raw Tuple space as input to policy graph. "
|
||||
"Please preprocess your environment observations "
|
||||
"with TupleFlatteningPreprocessor and set the "
|
||||
"obs space to `preprocessor.observation_space`.")
|
||||
policy_map[name] = cls(obs_space, act_space, merged_conf)
|
||||
return policy_map
|
||||
return policy_map, preprocessors
|
||||
|
||||
def sample(self):
|
||||
"""Evaluate the current policies and return a batch of experiences.
|
||||
@@ -554,6 +555,11 @@ def _validate_and_canonicalize(policy_graph, env):
|
||||
elif not issubclass(policy_graph, PolicyGraph):
|
||||
raise ValueError("policy_graph must be a rllib.PolicyGraph class")
|
||||
else:
|
||||
if (isinstance(env, MultiAgentEnv)
|
||||
and not hasattr(env, "observation_space")):
|
||||
raise ValueError(
|
||||
"MultiAgentEnv must have observation_space defined if run "
|
||||
"in a single-agent configuration.")
|
||||
return {
|
||||
DEFAULT_POLICY_ID: (policy_graph, env.observation_space,
|
||||
env.action_space, {})
|
||||
|
||||
@@ -79,6 +79,11 @@ class MultiAgentSampleBatchBuilder(object):
|
||||
self.agent_to_policy = {}
|
||||
self.count = 0 # increment this manually
|
||||
|
||||
def total(self):
|
||||
"""Returns summed number of steps across all agent buffers."""
|
||||
|
||||
return sum(p.count for p in self.policy_builders.values())
|
||||
|
||||
def has_pending_data(self):
|
||||
"""Returns whether there is pending unprocessed data."""
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from ray.rllib.models.action_dist import TupleActions
|
||||
from ray.rllib.utils.tf_run_builder import TFRunBuilder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_large_batch_warned = False
|
||||
|
||||
RolloutMetrics = namedtuple(
|
||||
"RolloutMetrics",
|
||||
@@ -42,6 +43,7 @@ class SyncSampler(object):
|
||||
env,
|
||||
policies,
|
||||
policy_mapping_fn,
|
||||
preprocessors,
|
||||
obs_filters,
|
||||
clip_rewards,
|
||||
unroll_length,
|
||||
@@ -55,13 +57,14 @@ class SyncSampler(object):
|
||||
self.horizon = horizon
|
||||
self.policies = policies
|
||||
self.policy_mapping_fn = policy_mapping_fn
|
||||
self._obs_filters = obs_filters
|
||||
self.preprocessors = preprocessors
|
||||
self.obs_filters = obs_filters
|
||||
self.extra_batches = queue.Queue()
|
||||
self.rollout_provider = _env_runner(
|
||||
self.async_vector_env, self.extra_batches.put, self.policies,
|
||||
self.policy_mapping_fn, self.unroll_length, self.horizon,
|
||||
self._obs_filters, clip_rewards, clip_actions, pack, callbacks,
|
||||
tf_sess)
|
||||
self.preprocessors, self.obs_filters, clip_rewards, clip_actions,
|
||||
pack, callbacks, tf_sess)
|
||||
self.metrics_queue = queue.Queue()
|
||||
|
||||
def get_data(self):
|
||||
@@ -101,6 +104,7 @@ class AsyncSampler(threading.Thread):
|
||||
env,
|
||||
policies,
|
||||
policy_mapping_fn,
|
||||
preprocessors,
|
||||
obs_filters,
|
||||
clip_rewards,
|
||||
unroll_length,
|
||||
@@ -121,7 +125,8 @@ class AsyncSampler(threading.Thread):
|
||||
self.horizon = horizon
|
||||
self.policies = policies
|
||||
self.policy_mapping_fn = policy_mapping_fn
|
||||
self._obs_filters = obs_filters
|
||||
self.preprocessors = preprocessors
|
||||
self.obs_filters = obs_filters
|
||||
self.clip_rewards = clip_rewards
|
||||
self.daemon = True
|
||||
self.pack = pack
|
||||
@@ -140,8 +145,8 @@ class AsyncSampler(threading.Thread):
|
||||
rollout_provider = _env_runner(
|
||||
self.async_vector_env, self.extra_batches.put, self.policies,
|
||||
self.policy_mapping_fn, self.unroll_length, self.horizon,
|
||||
self._obs_filters, self.clip_rewards, self.clip_actions, self.pack,
|
||||
self.callbacks, self.tf_sess)
|
||||
self.preprocessors, self.obs_filters, self.clip_rewards,
|
||||
self.clip_actions, self.pack, self.callbacks, self.tf_sess)
|
||||
while True:
|
||||
# The timeout variable exists because apparently, if one worker
|
||||
# dies, the other workers won't die with it, unless the timeout is
|
||||
@@ -200,6 +205,7 @@ def _env_runner(async_vector_env,
|
||||
policy_mapping_fn,
|
||||
unroll_length,
|
||||
horizon,
|
||||
preprocessors,
|
||||
obs_filters,
|
||||
clip_rewards,
|
||||
clip_actions,
|
||||
@@ -218,6 +224,8 @@ def _env_runner(async_vector_env,
|
||||
unroll_length (int): Number of episode steps before `SampleBatch` is
|
||||
yielded. Set to infinity to yield complete episodes.
|
||||
horizon (int): Horizon of the episode.
|
||||
preprocessors (dict): Map of policy id to preprocessor for the
|
||||
observations prior to filtering.
|
||||
obs_filters (dict): Map of policy id to filter used to process
|
||||
observations for the policy.
|
||||
clip_rewards (bool): Whether to clip rewards before postprocessing.
|
||||
@@ -273,7 +281,7 @@ def _env_runner(async_vector_env,
|
||||
active_envs, to_eval, outputs = _process_observations(
|
||||
async_vector_env, policies, batch_builder_pool, active_episodes,
|
||||
unfiltered_obs, rewards, dones, infos, off_policy_actions, horizon,
|
||||
obs_filters, unroll_length, pack, callbacks)
|
||||
preprocessors, obs_filters, unroll_length, pack, callbacks)
|
||||
for o in outputs:
|
||||
yield o
|
||||
|
||||
@@ -293,8 +301,8 @@ def _env_runner(async_vector_env,
|
||||
|
||||
def _process_observations(async_vector_env, policies, batch_builder_pool,
|
||||
active_episodes, unfiltered_obs, rewards, dones,
|
||||
infos, off_policy_actions, horizon, obs_filters,
|
||||
unroll_length, pack, callbacks):
|
||||
infos, off_policy_actions, horizon, preprocessors,
|
||||
obs_filters, unroll_length, pack, callbacks):
|
||||
"""Record new data from the environment and prepare for policy evaluation.
|
||||
|
||||
Returns:
|
||||
@@ -316,6 +324,21 @@ def _process_observations(async_vector_env, policies, batch_builder_pool,
|
||||
episode.batch_builder.count += 1
|
||||
episode._add_agent_rewards(rewards[env_id])
|
||||
|
||||
global _large_batch_warned
|
||||
if (not _large_batch_warned and
|
||||
episode.batch_builder.total() > max(1000, unroll_length * 10)):
|
||||
_large_batch_warned = True
|
||||
logger.warn(
|
||||
"More than {} observations for {} env steps ".format(
|
||||
episode.batch_builder.total(),
|
||||
episode.batch_builder.count) + "are buffered in "
|
||||
"the sampler. If this is not intentional, check that the "
|
||||
"the `horizon` config is set correctly, or consider setting "
|
||||
"`batch_mode` to 'truncate_episodes'. Note that in "
|
||||
"multi-agent environments, `sample_batch_size` sets the "
|
||||
"batch size based on environment steps, not the steps of "
|
||||
"individual agents.")
|
||||
|
||||
# Check episode termination conditions
|
||||
if dones[env_id]["__all__"] or episode.length >= horizon:
|
||||
all_done = True
|
||||
@@ -336,7 +359,9 @@ def _process_observations(async_vector_env, policies, batch_builder_pool,
|
||||
# For each agent in the environment
|
||||
for agent_id, raw_obs in agent_obs.items():
|
||||
policy_id = episode.policy_for(agent_id)
|
||||
filtered_obs = _get_or_raise(obs_filters, policy_id)(raw_obs)
|
||||
prep_obs = _get_or_raise(preprocessors,
|
||||
policy_id).transform(raw_obs)
|
||||
filtered_obs = _get_or_raise(obs_filters, policy_id)(prep_obs)
|
||||
agent_done = bool(all_done or dones[env_id].get(agent_id))
|
||||
if not agent_done:
|
||||
to_eval[policy_id].append(
|
||||
@@ -347,6 +372,7 @@ def _process_observations(async_vector_env, policies, batch_builder_pool,
|
||||
|
||||
last_observation = episode.last_observation_for(agent_id)
|
||||
episode._set_last_observation(agent_id, filtered_obs)
|
||||
episode._set_last_info(agent_id, infos[env_id][agent_id])
|
||||
|
||||
# Record transition info if applicable
|
||||
if last_observation is not None and \
|
||||
@@ -406,8 +432,10 @@ def _process_observations(async_vector_env, policies, batch_builder_pool,
|
||||
for agent_id, raw_obs in resetted_obs.items():
|
||||
policy_id = episode.policy_for(agent_id)
|
||||
policy = _get_or_raise(policies, policy_id)
|
||||
prep_obs = _get_or_raise(preprocessors,
|
||||
policy_id).transform(raw_obs)
|
||||
filtered_obs = _get_or_raise(obs_filters,
|
||||
policy_id)(raw_obs)
|
||||
policy_id)(prep_obs)
|
||||
episode._set_last_observation(agent_id, filtered_obs)
|
||||
to_eval[policy_id].append(
|
||||
PolicyEvalData(
|
||||
|
||||
@@ -25,10 +25,10 @@ def on_episode_step(info):
|
||||
|
||||
def on_episode_end(info):
|
||||
episode = info["episode"]
|
||||
mean_pole_angle = np.mean(episode.user_data["pole_angles"])
|
||||
pole_angle = np.mean(episode.user_data["pole_angles"])
|
||||
print("episode {} ended with length {} and pole angles {}".format(
|
||||
episode.episode_id, episode.length, mean_pole_angle))
|
||||
episode.custom_metrics["mean_pole_angle"] = mean_pole_angle
|
||||
episode.episode_id, episode.length, pole_angle))
|
||||
episode.custom_metrics["pole_angle"] = pole_angle
|
||||
|
||||
|
||||
def on_sample_end(info):
|
||||
@@ -70,6 +70,8 @@ if __name__ == "__main__":
|
||||
# verify custom metrics for integration tests
|
||||
custom_metrics = trials[0].last_result["custom_metrics"]
|
||||
print(custom_metrics)
|
||||
assert "mean_pole_angle" in custom_metrics
|
||||
assert type(custom_metrics["mean_pole_angle"]) is float
|
||||
assert "pole_angle_mean" in custom_metrics
|
||||
assert "pole_angle_min" in custom_metrics
|
||||
assert "pole_angle_max" in custom_metrics
|
||||
assert type(custom_metrics["pole_angle_mean"]) is float
|
||||
assert "callback_ok" in trials[0].last_result
|
||||
|
||||
@@ -271,15 +271,26 @@ class ModelCatalog(object):
|
||||
|
||||
@staticmethod
|
||||
def get_preprocessor(env, options=None):
|
||||
"""Returns a suitable processor for the given environment.
|
||||
"""Returns a suitable preprocessor for the given env.
|
||||
|
||||
This is a wrapper for get_preprocessor_for_space().
|
||||
"""
|
||||
|
||||
return ModelCatalog.get_preprocessor_for_space(env.observation_space,
|
||||
options)
|
||||
|
||||
@staticmethod
|
||||
def get_preprocessor_for_space(observation_space, options=None):
|
||||
"""Returns a suitable preprocessor for the given observation space.
|
||||
|
||||
Args:
|
||||
env (gym.Env|VectorEnv|ExternalEnv): The environment to wrap.
|
||||
observation_space (Space): The input observation space.
|
||||
options (dict): Options to pass to the preprocessor.
|
||||
|
||||
Returns:
|
||||
preprocessor (Preprocessor): Preprocessor for the env observations.
|
||||
preprocessor (Preprocessor): Preprocessor for the observations.
|
||||
"""
|
||||
|
||||
options = options or MODEL_DEFAULTS
|
||||
for k in options.keys():
|
||||
if k not in MODEL_DEFAULTS:
|
||||
@@ -290,13 +301,13 @@ class ModelCatalog(object):
|
||||
preprocessor = options["custom_preprocessor"]
|
||||
logger.info("Using custom preprocessor {}".format(preprocessor))
|
||||
prep = _global_registry.get(RLLIB_PREPROCESSOR, preprocessor)(
|
||||
env.observation_space, options)
|
||||
observation_space, options)
|
||||
else:
|
||||
cls = get_preprocessor(env.observation_space)
|
||||
prep = cls(env.observation_space, options)
|
||||
cls = get_preprocessor(observation_space)
|
||||
prep = cls(observation_space, options)
|
||||
|
||||
logger.debug("Created preprocessor {}: {} -> {}".format(
|
||||
prep, env.observation_space, prep.shape))
|
||||
prep, observation_space, prep.shape))
|
||||
return prep
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -109,6 +109,9 @@ class OneHotPreprocessor(Preprocessor):
|
||||
|
||||
def transform(self, observation):
|
||||
arr = np.zeros(self._obs_space.n)
|
||||
if not self._obs_space.contains(observation):
|
||||
raise ValueError("Observation outside expected value range",
|
||||
self._obs_space, observation)
|
||||
arr[observation] = 1
|
||||
return arr
|
||||
|
||||
|
||||
@@ -135,6 +135,7 @@ class LearnerThread(threading.Thread):
|
||||
self.daemon = True
|
||||
self.weights_updated = False
|
||||
self.stopped = False
|
||||
self.stats = {}
|
||||
|
||||
def run(self):
|
||||
while not self.stopped:
|
||||
@@ -151,6 +152,8 @@ class LearnerThread(threading.Thread):
|
||||
prio_dict[pid] = (
|
||||
replay.policy_batches[pid]["batch_indexes"],
|
||||
info["td_error"])
|
||||
if "stats" in info:
|
||||
self.stats[pid] = info["stats"]
|
||||
# send `replay` back also so that it gets released by the original
|
||||
# thread: https://github.com/ray-project/ray/issues/2610
|
||||
self.outqueue.put((ra, replay, prio_dict, replay.count))
|
||||
@@ -331,4 +334,6 @@ class AsyncReplayOptimizer(PolicyOptimizer):
|
||||
}
|
||||
if self.debug:
|
||||
stats.update(debug_stats)
|
||||
if self.learner.stats:
|
||||
stats["learner"] = self.learner.stats
|
||||
return dict(PolicyOptimizer.stats(self), **stats)
|
||||
|
||||
@@ -53,6 +53,7 @@ class SyncReplayOptimizer(PolicyOptimizer):
|
||||
self.replay_timer = TimerStat()
|
||||
self.grad_timer = TimerStat()
|
||||
self.throughput = RunningStat()
|
||||
self.learner_stats = {}
|
||||
|
||||
# Set up replay buffer
|
||||
if prioritized_replay:
|
||||
@@ -111,6 +112,8 @@ class SyncReplayOptimizer(PolicyOptimizer):
|
||||
with self.grad_timer:
|
||||
info_dict = self.local_evaluator.compute_apply(samples)
|
||||
for policy_id, info in info_dict.items():
|
||||
if "stats" in info:
|
||||
self.learner_stats[policy_id] = info["stats"]
|
||||
replay_buffer = self.replay_buffers[policy_id]
|
||||
if isinstance(replay_buffer, PrioritizedReplayBuffer):
|
||||
td_error = info["td_error"]
|
||||
@@ -160,4 +163,5 @@ class SyncReplayOptimizer(PolicyOptimizer):
|
||||
"opt_peak_throughput": round(self.grad_timer.mean_throughput,
|
||||
3),
|
||||
"opt_samples": round(self.grad_timer.mean_units_processed, 3),
|
||||
"learner": self.learner_stats,
|
||||
})
|
||||
|
||||
@@ -38,10 +38,10 @@ class EnvWithSubprocess(gym.Env):
|
||||
atexit.register(lambda: self.subproc.kill())
|
||||
|
||||
def reset(self):
|
||||
return [0]
|
||||
return 0
|
||||
|
||||
def step(self, action):
|
||||
return [0], 0, True, {}
|
||||
return 0, 0, True, {}
|
||||
|
||||
|
||||
def leaked_processes():
|
||||
|
||||
@@ -22,6 +22,12 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
|
||||
def one_hot(i, n):
|
||||
out = [0.0] * n
|
||||
out[i] = 1.0
|
||||
return out
|
||||
|
||||
|
||||
class BasicMultiAgent(MultiAgentEnv):
|
||||
"""Env of N independent agents, each of which exits after 25 steps."""
|
||||
|
||||
@@ -64,7 +70,7 @@ class RoundRobinMultiAgent(MultiAgentEnv):
|
||||
self.last_info = {}
|
||||
self.i = 0
|
||||
self.num = num
|
||||
self.observation_space = gym.spaces.Discrete(2)
|
||||
self.observation_space = gym.spaces.Discrete(10)
|
||||
self.action_space = gym.spaces.Discrete(2)
|
||||
|
||||
def reset(self):
|
||||
@@ -290,7 +296,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
|
||||
def testMultiAgentSampleRoundRobin(self):
|
||||
act_space = gym.spaces.Discrete(2)
|
||||
obs_space = gym.spaces.Discrete(2)
|
||||
obs_space = gym.spaces.Discrete(10)
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: RoundRobinMultiAgent(5, increment_obs=True),
|
||||
policy_graph={
|
||||
@@ -303,10 +309,20 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
# since we round robin introduce agents into the env, some of the env
|
||||
# steps don't count as proper transitions
|
||||
self.assertEqual(batch.policy_batches["p0"].count, 42)
|
||||
self.assertEqual(batch.policy_batches["p0"]["obs"].tolist()[:10],
|
||||
[0, 1, 2, 3, 4] * 2)
|
||||
self.assertEqual(batch.policy_batches["p0"]["new_obs"].tolist()[:10],
|
||||
[1, 2, 3, 4, 5] * 2)
|
||||
self.assertEqual(batch.policy_batches["p0"]["obs"].tolist()[:10], [
|
||||
one_hot(0, 10),
|
||||
one_hot(1, 10),
|
||||
one_hot(2, 10),
|
||||
one_hot(3, 10),
|
||||
one_hot(4, 10),
|
||||
] * 2)
|
||||
self.assertEqual(batch.policy_batches["p0"]["new_obs"].tolist()[:10], [
|
||||
one_hot(1, 10),
|
||||
one_hot(2, 10),
|
||||
one_hot(3, 10),
|
||||
one_hot(4, 10),
|
||||
one_hot(5, 10),
|
||||
] * 2)
|
||||
self.assertEqual(batch.policy_batches["p0"]["rewards"].tolist()[:10],
|
||||
[100, 100, 100, 100, 0] * 2)
|
||||
self.assertEqual(batch.policy_batches["p0"]["dones"].tolist()[:10],
|
||||
|
||||
@@ -13,6 +13,8 @@ import unittest
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.pg import PGAgent
|
||||
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
|
||||
from ray.rllib.env import MultiAgentEnv
|
||||
from ray.rllib.env.async_vector_env import AsyncVectorEnv
|
||||
from ray.rllib.env.vector_env import VectorEnv
|
||||
from ray.rllib.models import ModelCatalog
|
||||
@@ -88,6 +90,34 @@ class NestedTupleEnv(gym.Env):
|
||||
return TUPLE_SAMPLES[self.steps], 1, self.steps >= 5, {}
|
||||
|
||||
|
||||
class NestedMultiAgentEnv(MultiAgentEnv):
|
||||
def __init__(self):
|
||||
self.steps = 0
|
||||
|
||||
def reset(self):
|
||||
return {
|
||||
"dict_agent": DICT_SAMPLES[0],
|
||||
"tuple_agent": TUPLE_SAMPLES[0],
|
||||
}
|
||||
|
||||
def step(self, actions):
|
||||
self.steps += 1
|
||||
obs = {
|
||||
"dict_agent": DICT_SAMPLES[self.steps],
|
||||
"tuple_agent": TUPLE_SAMPLES[self.steps],
|
||||
}
|
||||
rew = {
|
||||
"dict_agent": 0,
|
||||
"tuple_agent": 0,
|
||||
}
|
||||
dones = {"__all__": self.steps >= 5}
|
||||
infos = {
|
||||
"dict_agent": {},
|
||||
"tuple_agent": {},
|
||||
}
|
||||
return obs, rew, dones, infos
|
||||
|
||||
|
||||
class InvalidModel(Model):
|
||||
def _build_layers_v2(self, input_dict, num_outputs, options):
|
||||
return "not", "valid"
|
||||
@@ -107,7 +137,8 @@ class DictSpyModel(Model):
|
||||
# redis to communicate back to our suite
|
||||
ray.experimental.internal_kv._internal_kv_put(
|
||||
"d_spy_in_{}".format(DictSpyModel.capture_index),
|
||||
pickle.dumps((pos, front_cam, task)))
|
||||
pickle.dumps((pos, front_cam, task)),
|
||||
overwrite=True)
|
||||
DictSpyModel.capture_index += 1
|
||||
return 0
|
||||
|
||||
@@ -135,7 +166,8 @@ class TupleSpyModel(Model):
|
||||
# redis to communicate back to our suite
|
||||
ray.experimental.internal_kv._internal_kv_put(
|
||||
"t_spy_in_{}".format(TupleSpyModel.capture_index),
|
||||
pickle.dumps((pos, cam, task)))
|
||||
pickle.dumps((pos, cam, task)),
|
||||
overwrite=True)
|
||||
TupleSpyModel.capture_index += 1
|
||||
return 0
|
||||
|
||||
@@ -242,10 +274,8 @@ class NestedSpacesTest(unittest.TestCase):
|
||||
self.doTestNestedDict(lambda _: SimpleServing(NestedDictEnv()))
|
||||
|
||||
def testNestedDictAsync(self):
|
||||
self.assertRaisesRegexp(
|
||||
ValueError, "Found raw Dict space.*",
|
||||
lambda: self.doTestNestedDict(
|
||||
lambda _: AsyncVectorEnv.wrap_async(NestedDictEnv())))
|
||||
self.doTestNestedDict(
|
||||
lambda _: AsyncVectorEnv.wrap_async(NestedDictEnv()))
|
||||
|
||||
def testNestedTupleGym(self):
|
||||
self.doTestNestedTuple(lambda _: NestedTupleEnv())
|
||||
@@ -258,10 +288,57 @@ class NestedSpacesTest(unittest.TestCase):
|
||||
self.doTestNestedTuple(lambda _: SimpleServing(NestedTupleEnv()))
|
||||
|
||||
def testNestedTupleAsync(self):
|
||||
self.assertRaisesRegexp(
|
||||
ValueError, "Found raw Tuple space.*",
|
||||
lambda: self.doTestNestedTuple(
|
||||
lambda _: AsyncVectorEnv.wrap_async(NestedTupleEnv())))
|
||||
self.doTestNestedTuple(
|
||||
lambda _: AsyncVectorEnv.wrap_async(NestedTupleEnv()))
|
||||
|
||||
def testMultiAgentComplexSpaces(self):
|
||||
ModelCatalog.register_custom_model("dict_spy", DictSpyModel)
|
||||
ModelCatalog.register_custom_model("tuple_spy", TupleSpyModel)
|
||||
register_env("nested_ma", lambda _: NestedMultiAgentEnv())
|
||||
act_space = spaces.Discrete(2)
|
||||
pg = PGAgent(
|
||||
env="nested_ma",
|
||||
config={
|
||||
"num_workers": 0,
|
||||
"sample_batch_size": 5,
|
||||
"multiagent": {
|
||||
"policy_graphs": {
|
||||
"tuple_policy": (
|
||||
PGPolicyGraph, TUPLE_SPACE, act_space,
|
||||
{"model": {"custom_model": "tuple_spy"}}),
|
||||
"dict_policy": (
|
||||
PGPolicyGraph, DICT_SPACE, act_space,
|
||||
{"model": {"custom_model": "dict_spy"}}),
|
||||
},
|
||||
"policy_mapping_fn": lambda a: {
|
||||
"tuple_agent": "tuple_policy",
|
||||
"dict_agent": "dict_policy"}[a],
|
||||
},
|
||||
})
|
||||
pg.train()
|
||||
|
||||
for i in range(4):
|
||||
seen = pickle.loads(
|
||||
ray.experimental.internal_kv._internal_kv_get(
|
||||
"d_spy_in_{}".format(i)))
|
||||
pos_i = DICT_SAMPLES[i]["sensors"]["position"].tolist()
|
||||
cam_i = DICT_SAMPLES[i]["sensors"]["front_cam"][0].tolist()
|
||||
task_i = one_hot(
|
||||
DICT_SAMPLES[i]["inner_state"]["job_status"]["task"], 5)
|
||||
self.assertEqual(seen[0][0].tolist(), pos_i)
|
||||
self.assertEqual(seen[1][0].tolist(), cam_i)
|
||||
self.assertEqual(seen[2][0].tolist(), task_i)
|
||||
|
||||
for i in range(4):
|
||||
seen = pickle.loads(
|
||||
ray.experimental.internal_kv._internal_kv_get(
|
||||
"t_spy_in_{}".format(i)))
|
||||
pos_i = TUPLE_SAMPLES[i][0].tolist()
|
||||
cam_i = TUPLE_SAMPLES[i][1][0].tolist()
|
||||
task_i = one_hot(TUPLE_SAMPLES[i][2], 5)
|
||||
self.assertEqual(seen[0][0].tolist(), pos_i)
|
||||
self.assertEqual(seen[1][0].tolist(), cam_i)
|
||||
self.assertEqual(seen[2][0].tolist(), task_i)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -2,9 +2,12 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import logging
|
||||
import numpy as np
|
||||
import threading
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Filter(object):
|
||||
"""Processes input, possibly statefully."""
|
||||
@@ -39,7 +42,10 @@ class NoFilter(Filter):
|
||||
pass
|
||||
|
||||
def __call__(self, x, update=True):
|
||||
return np.asarray(x)
|
||||
try:
|
||||
return np.asarray(x)
|
||||
except Exception:
|
||||
raise ValueError("Failed to convert to array", x)
|
||||
|
||||
def apply_changes(self, other, *args, **kwargs):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user