[rllib] fixes from dogfooding multi-agent (#3456)

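Auto-wrap multi-agent Dict and Tuple observation spaces by keeping a policy -> preprocessor map in the sampler.
Add Q-learning debug stats (mean/min/max Q and mean TD error).
Report the mean, min, and max of custom metrics.
Better error messages for malformed multi-agent env returns and invalid observations.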
This commit is contained in:
Eric Liang
2018-12-05 23:31:45 -08:00
committed by GitHub
parent 7a79b7f62c
commit d864f299d7
19 changed files with 277 additions and 75 deletions
@@ -253,6 +253,10 @@ class QLoss(object):
self.td_error = tf.nn.softmax_cross_entropy_with_logits(
labels=m, logits=q_logits_t_selected)
self.loss = tf.reduce_mean(self.td_error * importance_weights)
self.stats = {
# TODO: better Q stats for dist dqn
"mean_td_error": tf.reduce_mean(self.td_error),
}
else:
q_tp1_best_masked = (1.0 - done_mask) * q_tp1_best
@@ -264,6 +268,12 @@ class QLoss(object):
q_t_selected - tf.stop_gradient(q_t_selected_target))
self.loss = tf.reduce_mean(
importance_weights * _huber_loss(self.td_error))
self.stats = {
"mean_q": tf.reduce_mean(q_t_selected),
"min_q": tf.reduce_min(q_t_selected),
"max_q": tf.reduce_max(q_t_selected),
"mean_td_error": tf.reduce_mean(self.td_error),
}
class DQNPolicyGraph(TFPolicyGraph):
@@ -430,6 +440,7 @@ class DQNPolicyGraph(TFPolicyGraph):
def extra_compute_grad_fetches(self):
    """Build the extra fetches dict from this policy graph's loss object.

    Returns:
        dict: ``self.loss.td_error`` under "td_error" and
            ``self.loss.stats`` under "stats".
    """
    loss = self.loss
    return dict(td_error=loss.td_error, stats=loss.stats)
def postprocess_trajectory(self,
+7
View File
@@ -291,6 +291,13 @@ class _MultiAgentEnvToAsync(AsyncVectorEnv):
assert isinstance(rewards, dict), "Not a multi-agent reward"
assert isinstance(dones, dict), "Not a multi-agent return"
assert isinstance(infos, dict), "Not a multi-agent info"
if set(obs.keys()) != set(rewards.keys()):
raise ValueError(
"Key set for obs and rewards must be the same: "
"{} vs {}".format(obs.keys(), rewards.keys()))
if set(obs.keys()) != set(infos.keys()):
raise ValueError("Key set for obs and infos must be the same: "
"{} vs {}".format(obs.keys(), infos.keys()))
if dones["__all__"]:
self.dones.add(env_id)
self.env_states[env_id].observe(obs, rewards, dones, infos)
+1 -1
View File
@@ -56,7 +56,7 @@ class MultiAgentEnv(object):
rewards (dict): Reward values for each ready agent. If the
episode is just started, the value will be None.
dones (dict): Done values for each ready agent. The special key
"__all__" is used to indicate env termination.
"__all__" (required) is used to indicate env termination.
infos (dict): Info values for each ready agent.
"""
raise NotImplementedError
+9
View File
@@ -60,6 +60,7 @@ class MultiAgentEpisode(object):
self._agent_to_policy = {}
self._agent_to_rnn_state = {}
self._agent_to_last_obs = {}
self._agent_to_last_info = {}
self._agent_to_last_action = {}
self._agent_to_last_pi_info = {}
self._agent_to_prev_action = {}
@@ -81,6 +82,11 @@ class MultiAgentEpisode(object):
return self._agent_to_last_obs.get(agent_id)
def last_info_for(self, agent_id=_DUMMY_AGENT_ID):
    """Look up the most recently stored info for an agent.

    Args:
        agent_id: Key of the agent; defaults to ``_DUMMY_AGENT_ID``.

    Returns:
        The stored info, or None if nothing is stored for ``agent_id``.
    """
    recorded = self._agent_to_last_info
    return recorded.get(agent_id)
def last_action_for(self, agent_id=_DUMMY_AGENT_ID):
"""Returns the last action for the specified agent, or zeros."""
@@ -137,6 +143,9 @@ class MultiAgentEpisode(object):
def _set_last_observation(self, agent_id, obs):
    """Store ``obs`` as the latest observation for ``agent_id``."""
    last_obs = self._agent_to_last_obs
    last_obs[agent_id] = obs
def _set_last_info(self, agent_id, info):
    """Store ``info`` as the latest info for ``agent_id``."""
    last_info = self._agent_to_last_info
    last_info[agent_id] = info
def _set_last_action(self, agent_id, action):
self._agent_to_last_action[agent_id] = action
+10 -2
View File
@@ -80,8 +80,16 @@ def summarize_episodes(episodes, new_episodes, num_dropped):
for policy_id, rewards in policy_rewards.copy().items():
policy_rewards[policy_id] = np.mean(rewards)
for k, v_list in custom_metrics.items():
custom_metrics[k] = np.mean(v_list)
for k, v_list in custom_metrics.copy().items():
custom_metrics[k + "_mean"] = np.mean(v_list)
filt = [v for v in v_list if not np.isnan(v)]
if filt:
custom_metrics[k + "_min"] = np.min(filt)
custom_metrics[k + "_max"] = np.max(filt)
else:
custom_metrics[k + "_min"] = float("nan")
custom_metrics[k + "_max"] = float("nan")
del custom_metrics[k]
return dict(
episode_reward_max=max_reward,
+32 -26
View File
@@ -8,7 +8,6 @@ import pickle
import tensorflow as tf
import ray
from ray.rllib.models import ModelCatalog
from ray.rllib.env.async_vector_env import AsyncVectorEnv
from ray.rllib.env.atari_wrappers import wrap_deepmind, is_atari
from ray.rllib.env.env_context import EnvContext
@@ -19,6 +18,8 @@ from ray.rllib.evaluation.sample_batch import MultiAgentBatch, \
from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.models import ModelCatalog
from ray.rllib.models.preprocessors import NoPreprocessor
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.compression import pack
from ray.rllib.utils.filter import get_filter
@@ -191,23 +192,21 @@ class PolicyEvaluator(EvaluatorInterface):
self.sample_batch_size = batch_steps * num_envs
self.batch_mode = batch_mode
self.compress_observations = compress_observations
self.preprocessing_enabled = True
self.env = env_creator(env_context)
if isinstance(self.env, MultiAgentEnv) or \
isinstance(self.env, AsyncVectorEnv):
if model_config.get("custom_preprocessor"):
raise ValueError(
"Custom preprocessors are not supported for env types "
"MultiAgentEnv and AsyncVectorEnv. Please preprocess "
"observations in your env directly.")
def wrap(env):
return env # we can't auto-wrap these env types
elif is_atari(self.env) and \
not model_config.get("custom_preprocessor") and \
preprocessor_pref == "deepmind":
# Deepmind wrappers already handle all preprocessing
self.preprocessing_enabled = False
if clip_rewards is None:
clip_rewards = True
@@ -222,8 +221,6 @@ class PolicyEvaluator(EvaluatorInterface):
else:
def wrap(env):
env = ModelCatalog.get_preprocessor_as_wrapper(
env, model_config)
if monitor_path:
env = _monitor(env, monitor_path)
return env
@@ -246,11 +243,11 @@ class PolicyEvaluator(EvaluatorInterface):
config=tf.ConfigProto(
gpu_options=tf.GPUOptions(allow_growth=True)))
with self.tf_sess.as_default():
self.policy_map = self._build_policy_map(
policy_dict, policy_config)
self.policy_map, self.preprocessors = \
self._build_policy_map(policy_dict, policy_config)
else:
self.policy_map = self._build_policy_map(policy_dict,
policy_config)
self.policy_map, self.preprocessors = self._build_policy_map(
policy_dict, policy_config)
self.multiagent = set(self.policy_map.keys()) != {DEFAULT_POLICY_ID}
if self.multiagent:
@@ -286,6 +283,7 @@ class PolicyEvaluator(EvaluatorInterface):
self.async_env,
self.policy_map,
policy_mapping_fn,
self.preprocessors,
self.filters,
clip_rewards,
unroll_length,
@@ -300,6 +298,7 @@ class PolicyEvaluator(EvaluatorInterface):
self.async_env,
self.policy_map,
policy_mapping_fn,
self.preprocessors,
self.filters,
clip_rewards,
unroll_length,
@@ -314,24 +313,26 @@ class PolicyEvaluator(EvaluatorInterface):
def _build_policy_map(self, policy_dict, policy_config):
policy_map = {}
preprocessors = {}
for name, (cls, obs_space, act_space,
conf) in sorted(policy_dict.items()):
merged_conf = merge_dicts(policy_config, conf)
if self.preprocessing_enabled:
preprocessor = ModelCatalog.get_preprocessor_for_space(
obs_space, merged_conf.get("model"))
preprocessors[name] = preprocessor
obs_space = preprocessor.observation_space
else:
preprocessors[name] = NoPreprocessor(obs_space)
if isinstance(obs_space, gym.spaces.Dict) or \
isinstance(obs_space, gym.spaces.Tuple):
raise ValueError(
"Found raw Tuple|Dict space as input to policy graph. "
"Please preprocess these observations with a "
"Tuple|DictFlatteningPreprocessor.")
with tf.variable_scope(name):
if isinstance(obs_space, gym.spaces.Dict):
raise ValueError(
"Found raw Dict space as input to policy graph. "
"Please preprocess your environment observations "
"with DictFlatteningPreprocessor and set the "
"obs space to `preprocessor.observation_space`.")
elif isinstance(obs_space, gym.spaces.Tuple):
raise ValueError(
"Found raw Tuple space as input to policy graph. "
"Please preprocess your environment observations "
"with TupleFlatteningPreprocessor and set the "
"obs space to `preprocessor.observation_space`.")
policy_map[name] = cls(obs_space, act_space, merged_conf)
return policy_map
return policy_map, preprocessors
def sample(self):
"""Evaluate the current policies and return a batch of experiences.
@@ -554,6 +555,11 @@ def _validate_and_canonicalize(policy_graph, env):
elif not issubclass(policy_graph, PolicyGraph):
raise ValueError("policy_graph must be a rllib.PolicyGraph class")
else:
if (isinstance(env, MultiAgentEnv)
and not hasattr(env, "observation_space")):
raise ValueError(
"MultiAgentEnv must have observation_space defined if run "
"in a single-agent configuration.")
return {
DEFAULT_POLICY_ID: (policy_graph, env.observation_space,
env.action_space, {})
@@ -79,6 +79,11 @@ class MultiAgentSampleBatchBuilder(object):
self.agent_to_policy = {}
self.count = 0 # increment this manually
def total(self):
    """Sum the ``count`` of every policy builder held by this object.

    Returns:
        The summed ``count`` values; 0 when there are no policy builders.
    """
    steps = 0
    for builder in self.policy_builders.values():
        steps += builder.count
    return steps
def has_pending_data(self):
"""Returns whether there is pending unprocessed data."""
+39 -11
View File
@@ -19,6 +19,7 @@ from ray.rllib.models.action_dist import TupleActions
from ray.rllib.utils.tf_run_builder import TFRunBuilder
logger = logging.getLogger(__name__)
_large_batch_warned = False
RolloutMetrics = namedtuple(
"RolloutMetrics",
@@ -42,6 +43,7 @@ class SyncSampler(object):
env,
policies,
policy_mapping_fn,
preprocessors,
obs_filters,
clip_rewards,
unroll_length,
@@ -55,13 +57,14 @@ class SyncSampler(object):
self.horizon = horizon
self.policies = policies
self.policy_mapping_fn = policy_mapping_fn
self._obs_filters = obs_filters
self.preprocessors = preprocessors
self.obs_filters = obs_filters
self.extra_batches = queue.Queue()
self.rollout_provider = _env_runner(
self.async_vector_env, self.extra_batches.put, self.policies,
self.policy_mapping_fn, self.unroll_length, self.horizon,
self._obs_filters, clip_rewards, clip_actions, pack, callbacks,
tf_sess)
self.preprocessors, self.obs_filters, clip_rewards, clip_actions,
pack, callbacks, tf_sess)
self.metrics_queue = queue.Queue()
def get_data(self):
@@ -101,6 +104,7 @@ class AsyncSampler(threading.Thread):
env,
policies,
policy_mapping_fn,
preprocessors,
obs_filters,
clip_rewards,
unroll_length,
@@ -121,7 +125,8 @@ class AsyncSampler(threading.Thread):
self.horizon = horizon
self.policies = policies
self.policy_mapping_fn = policy_mapping_fn
self._obs_filters = obs_filters
self.preprocessors = preprocessors
self.obs_filters = obs_filters
self.clip_rewards = clip_rewards
self.daemon = True
self.pack = pack
@@ -140,8 +145,8 @@ class AsyncSampler(threading.Thread):
rollout_provider = _env_runner(
self.async_vector_env, self.extra_batches.put, self.policies,
self.policy_mapping_fn, self.unroll_length, self.horizon,
self._obs_filters, self.clip_rewards, self.clip_actions, self.pack,
self.callbacks, self.tf_sess)
self.preprocessors, self.obs_filters, self.clip_rewards,
self.clip_actions, self.pack, self.callbacks, self.tf_sess)
while True:
# The timeout variable exists because apparently, if one worker
# dies, the other workers won't die with it, unless the timeout is
@@ -200,6 +205,7 @@ def _env_runner(async_vector_env,
policy_mapping_fn,
unroll_length,
horizon,
preprocessors,
obs_filters,
clip_rewards,
clip_actions,
@@ -218,6 +224,8 @@ def _env_runner(async_vector_env,
unroll_length (int): Number of episode steps before `SampleBatch` is
yielded. Set to infinity to yield complete episodes.
horizon (int): Horizon of the episode.
preprocessors (dict): Map of policy id to preprocessor for the
observations prior to filtering.
obs_filters (dict): Map of policy id to filter used to process
observations for the policy.
clip_rewards (bool): Whether to clip rewards before postprocessing.
@@ -273,7 +281,7 @@ def _env_runner(async_vector_env,
active_envs, to_eval, outputs = _process_observations(
async_vector_env, policies, batch_builder_pool, active_episodes,
unfiltered_obs, rewards, dones, infos, off_policy_actions, horizon,
obs_filters, unroll_length, pack, callbacks)
preprocessors, obs_filters, unroll_length, pack, callbacks)
for o in outputs:
yield o
@@ -293,8 +301,8 @@ def _env_runner(async_vector_env,
def _process_observations(async_vector_env, policies, batch_builder_pool,
active_episodes, unfiltered_obs, rewards, dones,
infos, off_policy_actions, horizon, obs_filters,
unroll_length, pack, callbacks):
infos, off_policy_actions, horizon, preprocessors,
obs_filters, unroll_length, pack, callbacks):
"""Record new data from the environment and prepare for policy evaluation.
Returns:
@@ -316,6 +324,21 @@ def _process_observations(async_vector_env, policies, batch_builder_pool,
episode.batch_builder.count += 1
episode._add_agent_rewards(rewards[env_id])
global _large_batch_warned
if (not _large_batch_warned and
episode.batch_builder.total() > max(1000, unroll_length * 10)):
_large_batch_warned = True
logger.warn(
"More than {} observations for {} env steps ".format(
episode.batch_builder.total(),
episode.batch_builder.count) + "are buffered in "
"the sampler. If this is not intentional, check that the "
"the `horizon` config is set correctly, or consider setting "
"`batch_mode` to 'truncate_episodes'. Note that in "
"multi-agent environments, `sample_batch_size` sets the "
"batch size based on environment steps, not the steps of "
"individual agents.")
# Check episode termination conditions
if dones[env_id]["__all__"] or episode.length >= horizon:
all_done = True
@@ -336,7 +359,9 @@ def _process_observations(async_vector_env, policies, batch_builder_pool,
# For each agent in the environment
for agent_id, raw_obs in agent_obs.items():
policy_id = episode.policy_for(agent_id)
filtered_obs = _get_or_raise(obs_filters, policy_id)(raw_obs)
prep_obs = _get_or_raise(preprocessors,
policy_id).transform(raw_obs)
filtered_obs = _get_or_raise(obs_filters, policy_id)(prep_obs)
agent_done = bool(all_done or dones[env_id].get(agent_id))
if not agent_done:
to_eval[policy_id].append(
@@ -347,6 +372,7 @@ def _process_observations(async_vector_env, policies, batch_builder_pool,
last_observation = episode.last_observation_for(agent_id)
episode._set_last_observation(agent_id, filtered_obs)
episode._set_last_info(agent_id, infos[env_id][agent_id])
# Record transition info if applicable
if last_observation is not None and \
@@ -406,8 +432,10 @@ def _process_observations(async_vector_env, policies, batch_builder_pool,
for agent_id, raw_obs in resetted_obs.items():
policy_id = episode.policy_for(agent_id)
policy = _get_or_raise(policies, policy_id)
prep_obs = _get_or_raise(preprocessors,
policy_id).transform(raw_obs)
filtered_obs = _get_or_raise(obs_filters,
policy_id)(raw_obs)
policy_id)(prep_obs)
episode._set_last_observation(agent_id, filtered_obs)
to_eval[policy_id].append(
PolicyEvalData(
@@ -25,10 +25,10 @@ def on_episode_step(info):
def on_episode_end(info):
episode = info["episode"]
mean_pole_angle = np.mean(episode.user_data["pole_angles"])
pole_angle = np.mean(episode.user_data["pole_angles"])
print("episode {} ended with length {} and pole angles {}".format(
episode.episode_id, episode.length, mean_pole_angle))
episode.custom_metrics["mean_pole_angle"] = mean_pole_angle
episode.episode_id, episode.length, pole_angle))
episode.custom_metrics["pole_angle"] = pole_angle
def on_sample_end(info):
@@ -70,6 +70,8 @@ if __name__ == "__main__":
# verify custom metrics for integration tests
custom_metrics = trials[0].last_result["custom_metrics"]
print(custom_metrics)
assert "mean_pole_angle" in custom_metrics
assert type(custom_metrics["mean_pole_angle"]) is float
assert "pole_angle_mean" in custom_metrics
assert "pole_angle_min" in custom_metrics
assert "pole_angle_max" in custom_metrics
assert type(custom_metrics["pole_angle_mean"]) is float
assert "callback_ok" in trials[0].last_result
+18 -7
View File
@@ -271,15 +271,26 @@ class ModelCatalog(object):
@staticmethod
def get_preprocessor(env, options=None):
"""Returns a suitable processor for the given environment.
"""Returns a suitable preprocessor for the given env.
This is a wrapper for get_preprocessor_for_space().
"""
return ModelCatalog.get_preprocessor_for_space(env.observation_space,
options)
@staticmethod
def get_preprocessor_for_space(observation_space, options=None):
"""Returns a suitable preprocessor for the given observation space.
Args:
env (gym.Env|VectorEnv|ExternalEnv): The environment to wrap.
observation_space (Space): The input observation space.
options (dict): Options to pass to the preprocessor.
Returns:
preprocessor (Preprocessor): Preprocessor for the env observations.
preprocessor (Preprocessor): Preprocessor for the observations.
"""
options = options or MODEL_DEFAULTS
for k in options.keys():
if k not in MODEL_DEFAULTS:
@@ -290,13 +301,13 @@ class ModelCatalog(object):
preprocessor = options["custom_preprocessor"]
logger.info("Using custom preprocessor {}".format(preprocessor))
prep = _global_registry.get(RLLIB_PREPROCESSOR, preprocessor)(
env.observation_space, options)
observation_space, options)
else:
cls = get_preprocessor(env.observation_space)
prep = cls(env.observation_space, options)
cls = get_preprocessor(observation_space)
prep = cls(observation_space, options)
logger.debug("Created preprocessor {}: {} -> {}".format(
prep, env.observation_space, prep.shape))
prep, observation_space, prep.shape))
return prep
@staticmethod
+3
View File
@@ -109,6 +109,9 @@ class OneHotPreprocessor(Preprocessor):
def transform(self, observation):
    """One-hot encode ``observation``.

    Args:
        observation: Value to encode. It must be accepted by
            ``self._obs_space.contains`` and is used as an array index.

    Returns:
        A numpy array of length ``self._obs_space.n`` that is zero
        everywhere except for a 1 at index ``observation``.

    Raises:
        ValueError: If the observation space does not contain
            ``observation``.
    """
    space = self._obs_space
    encoded = np.zeros(space.n)
    if space.contains(observation):
        encoded[observation] = 1
        return encoded
    raise ValueError("Observation outside expected value range", space,
                     observation)
@@ -135,6 +135,7 @@ class LearnerThread(threading.Thread):
self.daemon = True
self.weights_updated = False
self.stopped = False
self.stats = {}
def run(self):
while not self.stopped:
@@ -151,6 +152,8 @@ class LearnerThread(threading.Thread):
prio_dict[pid] = (
replay.policy_batches[pid]["batch_indexes"],
info["td_error"])
if "stats" in info:
self.stats[pid] = info["stats"]
# send `replay` back also so that it gets released by the original
# thread: https://github.com/ray-project/ray/issues/2610
self.outqueue.put((ra, replay, prio_dict, replay.count))
@@ -331,4 +334,6 @@ class AsyncReplayOptimizer(PolicyOptimizer):
}
if self.debug:
stats.update(debug_stats)
if self.learner.stats:
stats["learner"] = self.learner.stats
return dict(PolicyOptimizer.stats(self), **stats)
@@ -53,6 +53,7 @@ class SyncReplayOptimizer(PolicyOptimizer):
self.replay_timer = TimerStat()
self.grad_timer = TimerStat()
self.throughput = RunningStat()
self.learner_stats = {}
# Set up replay buffer
if prioritized_replay:
@@ -111,6 +112,8 @@ class SyncReplayOptimizer(PolicyOptimizer):
with self.grad_timer:
info_dict = self.local_evaluator.compute_apply(samples)
for policy_id, info in info_dict.items():
if "stats" in info:
self.learner_stats[policy_id] = info["stats"]
replay_buffer = self.replay_buffers[policy_id]
if isinstance(replay_buffer, PrioritizedReplayBuffer):
td_error = info["td_error"]
@@ -160,4 +163,5 @@ class SyncReplayOptimizer(PolicyOptimizer):
"opt_peak_throughput": round(self.grad_timer.mean_throughput,
3),
"opt_samples": round(self.grad_timer.mean_units_processed, 3),
"learner": self.learner_stats,
})
@@ -38,10 +38,10 @@ class EnvWithSubprocess(gym.Env):
atexit.register(lambda: self.subproc.kill())
def reset(self):
return [0]
return 0
def step(self, action):
return [0], 0, True, {}
return 0, 0, True, {}
def leaked_processes():
+22 -6
View File
@@ -22,6 +22,12 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.tune.registry import register_env
def one_hot(i, n):
    """Return a length-``n`` list of floats that is 1.0 at ``i``, else 0.0.

    Args:
        i (int): Position of the hot entry; must satisfy ``0 <= i < n``.
        n (int): Length of the returned list.

    Returns:
        list: ``n`` floats, all 0.0 except a single 1.0 at index ``i``.

    Raises:
        IndexError: If ``i`` is outside ``[0, n)``. Negative values are
            rejected explicitly because plain list indexing would silently
            wrap them around to the end of the list.
    """
    if not 0 <= i < n:
        raise IndexError(
            "one_hot index {} out of range for size {}".format(i, n))
    out = [0.0] * n
    out[i] = 1.0
    return out
class BasicMultiAgent(MultiAgentEnv):
"""Env of N independent agents, each of which exits after 25 steps."""
@@ -64,7 +70,7 @@ class RoundRobinMultiAgent(MultiAgentEnv):
self.last_info = {}
self.i = 0
self.num = num
self.observation_space = gym.spaces.Discrete(2)
self.observation_space = gym.spaces.Discrete(10)
self.action_space = gym.spaces.Discrete(2)
def reset(self):
@@ -290,7 +296,7 @@ class TestMultiAgentEnv(unittest.TestCase):
def testMultiAgentSampleRoundRobin(self):
act_space = gym.spaces.Discrete(2)
obs_space = gym.spaces.Discrete(2)
obs_space = gym.spaces.Discrete(10)
ev = PolicyEvaluator(
env_creator=lambda _: RoundRobinMultiAgent(5, increment_obs=True),
policy_graph={
@@ -303,10 +309,20 @@ class TestMultiAgentEnv(unittest.TestCase):
# since we round robin introduce agents into the env, some of the env
# steps don't count as proper transitions
self.assertEqual(batch.policy_batches["p0"].count, 42)
self.assertEqual(batch.policy_batches["p0"]["obs"].tolist()[:10],
[0, 1, 2, 3, 4] * 2)
self.assertEqual(batch.policy_batches["p0"]["new_obs"].tolist()[:10],
[1, 2, 3, 4, 5] * 2)
self.assertEqual(batch.policy_batches["p0"]["obs"].tolist()[:10], [
one_hot(0, 10),
one_hot(1, 10),
one_hot(2, 10),
one_hot(3, 10),
one_hot(4, 10),
] * 2)
self.assertEqual(batch.policy_batches["p0"]["new_obs"].tolist()[:10], [
one_hot(1, 10),
one_hot(2, 10),
one_hot(3, 10),
one_hot(4, 10),
one_hot(5, 10),
] * 2)
self.assertEqual(batch.policy_batches["p0"]["rewards"].tolist()[:10],
[100, 100, 100, 100, 0] * 2)
self.assertEqual(batch.policy_batches["p0"]["dones"].tolist()[:10],
+87 -10
View File
@@ -13,6 +13,8 @@ import unittest
import ray
from ray.rllib.agents.pg import PGAgent
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
from ray.rllib.env import MultiAgentEnv
from ray.rllib.env.async_vector_env import AsyncVectorEnv
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.models import ModelCatalog
@@ -88,6 +90,34 @@ class NestedTupleEnv(gym.Env):
return TUPLE_SAMPLES[self.steps], 1, self.steps >= 5, {}
class NestedMultiAgentEnv(MultiAgentEnv):
    """Two-agent env that replays the pre-built nested observation samples.

    ``dict_agent`` observes ``DICT_SAMPLES[step]`` and ``tuple_agent``
    observes ``TUPLE_SAMPLES[step]``. Rewards are always 0 and the episode
    ends for all agents once five steps have been taken.
    """

    def __init__(self):
        # Index of the sample pair returned by the most recent reset/step.
        self.steps = 0

    def reset(self):
        """Start a new episode and return the first observation per agent."""
        # Restart the counter so each episode replays the samples from index
        # 0; otherwise a later episode would keep advancing the index from
        # where the previous one stopped and be reported done immediately.
        self.steps = 0
        return {
            "dict_agent": DICT_SAMPLES[0],
            "tuple_agent": TUPLE_SAMPLES[0],
        }

    def step(self, actions):
        """Advance one step; the given actions are ignored.

        Returns:
            tuple: ``(obs, rew, dones, infos)`` dicts. ``obs``, ``rew`` and
                ``infos`` are keyed by agent id; ``dones`` only carries the
                "__all__" key.
        """
        self.steps += 1
        obs = {
            "dict_agent": DICT_SAMPLES[self.steps],
            "tuple_agent": TUPLE_SAMPLES[self.steps],
        }
        rew = {
            "dict_agent": 0,
            "tuple_agent": 0,
        }
        dones = {"__all__": self.steps >= 5}
        infos = {
            "dict_agent": {},
            "tuple_agent": {},
        }
        return obs, rew, dones, infos
class InvalidModel(Model):
def _build_layers_v2(self, input_dict, num_outputs, options):
return "not", "valid"
@@ -107,7 +137,8 @@ class DictSpyModel(Model):
# redis to communicate back to our suite
ray.experimental.internal_kv._internal_kv_put(
"d_spy_in_{}".format(DictSpyModel.capture_index),
pickle.dumps((pos, front_cam, task)))
pickle.dumps((pos, front_cam, task)),
overwrite=True)
DictSpyModel.capture_index += 1
return 0
@@ -135,7 +166,8 @@ class TupleSpyModel(Model):
# redis to communicate back to our suite
ray.experimental.internal_kv._internal_kv_put(
"t_spy_in_{}".format(TupleSpyModel.capture_index),
pickle.dumps((pos, cam, task)))
pickle.dumps((pos, cam, task)),
overwrite=True)
TupleSpyModel.capture_index += 1
return 0
@@ -242,10 +274,8 @@ class NestedSpacesTest(unittest.TestCase):
self.doTestNestedDict(lambda _: SimpleServing(NestedDictEnv()))
def testNestedDictAsync(self):
self.assertRaisesRegexp(
ValueError, "Found raw Dict space.*",
lambda: self.doTestNestedDict(
lambda _: AsyncVectorEnv.wrap_async(NestedDictEnv())))
self.doTestNestedDict(
lambda _: AsyncVectorEnv.wrap_async(NestedDictEnv()))
def testNestedTupleGym(self):
self.doTestNestedTuple(lambda _: NestedTupleEnv())
@@ -258,10 +288,57 @@ class NestedSpacesTest(unittest.TestCase):
self.doTestNestedTuple(lambda _: SimpleServing(NestedTupleEnv()))
def testNestedTupleAsync(self):
self.assertRaisesRegexp(
ValueError, "Found raw Tuple space.*",
lambda: self.doTestNestedTuple(
lambda _: AsyncVectorEnv.wrap_async(NestedTupleEnv())))
self.doTestNestedTuple(
lambda _: AsyncVectorEnv.wrap_async(NestedTupleEnv()))
def testMultiAgentComplexSpaces(self):
    """Train PG on the nested multi-agent env and verify the spy inputs.

    Registers the dict/tuple spy models and ``NestedMultiAgentEnv``, runs
    one ``train()`` call with one policy per agent, then compares the first
    four inputs each spy model stored in the internal KV store against the
    corresponding ``DICT_SAMPLES`` / ``TUPLE_SAMPLES`` entries.
    """
    ModelCatalog.register_custom_model("dict_spy", DictSpyModel)
    ModelCatalog.register_custom_model("tuple_spy", TupleSpyModel)
    register_env("nested_ma", lambda _: NestedMultiAgentEnv())

    act_space = spaces.Discrete(2)
    tuple_policy = (PGPolicyGraph, TUPLE_SPACE, act_space, {
        "model": {
            "custom_model": "tuple_spy"
        }
    })
    dict_policy = (PGPolicyGraph, DICT_SPACE, act_space, {
        "model": {
            "custom_model": "dict_spy"
        }
    })
    agent_to_policy = {
        "tuple_agent": "tuple_policy",
        "dict_agent": "dict_policy",
    }
    config = {
        "num_workers": 0,
        "sample_batch_size": 5,
        "multiagent": {
            "policy_graphs": {
                "tuple_policy": tuple_policy,
                "dict_policy": dict_policy,
            },
            "policy_mapping_fn": lambda a: agent_to_policy[a],
        },
    }
    pg = PGAgent(env="nested_ma", config=config)
    pg.train()

    # Inputs captured by the dict spy model, keyed by capture index.
    for step in range(4):
        captured = pickle.loads(
            ray.experimental.internal_kv._internal_kv_get(
                "d_spy_in_{}".format(step)))
        sample = DICT_SAMPLES[step]
        want_pos = sample["sensors"]["position"].tolist()
        want_cam = sample["sensors"]["front_cam"][0].tolist()
        want_task = one_hot(sample["inner_state"]["job_status"]["task"], 5)
        self.assertEqual(captured[0][0].tolist(), want_pos)
        self.assertEqual(captured[1][0].tolist(), want_cam)
        self.assertEqual(captured[2][0].tolist(), want_task)

    # Inputs captured by the tuple spy model, keyed by capture index.
    for step in range(4):
        captured = pickle.loads(
            ray.experimental.internal_kv._internal_kv_get(
                "t_spy_in_{}".format(step)))
        sample = TUPLE_SAMPLES[step]
        want_pos = sample[0].tolist()
        want_cam = sample[1][0].tolist()
        want_task = one_hot(sample[2], 5)
        self.assertEqual(captured[0][0].tolist(), want_pos)
        self.assertEqual(captured[1][0].tolist(), want_cam)
        self.assertEqual(captured[2][0].tolist(), want_task)
if __name__ == "__main__":
+7 -1
View File
@@ -2,9 +2,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import numpy as np
import threading
logger = logging.getLogger(__name__)
class Filter(object):
"""Processes input, possibly statefully."""
@@ -39,7 +42,10 @@ class NoFilter(Filter):
pass
def __call__(self, x, update=True):
return np.asarray(x)
try:
return np.asarray(x)
except Exception:
raise ValueError("Failed to convert to array", x)
def apply_changes(self, other, *args, **kwargs):
pass