mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 12:37:14 +08:00
[rllib] Add test for multi-agent support and fix IMPALA multi-agent (#3289)
IMPALA's multi-agent support was broken because IMPALA requires batch sizes to be a multiple of a fixed length (its `sample_batch_size`), whereas multi-agent envs can produce variable-length per-agent batches. Fix this by zero-padding the batches as needed (similar to the RNN case).
This commit is contained in:
@@ -219,7 +219,8 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
prev_action_input=prev_actions,
|
||||
prev_reward_input=prev_rewards,
|
||||
seq_lens=self.model.seq_lens,
|
||||
max_seq_len=self.config["model"]["max_seq_len"])
|
||||
max_seq_len=self.config["model"]["max_seq_len"],
|
||||
batch_divisibility_req=self.config["sample_batch_size"])
|
||||
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
|
||||
|
||||
@@ -106,6 +106,10 @@ class PPOAgent(Agent):
|
||||
and not self.config["use_gae"]):
|
||||
raise ValueError(
|
||||
"Episode truncation is not supported without a value function")
|
||||
if (self.config["multiagent"]["policy_graphs"]
|
||||
and not self.config["simple_optimizer"]):
|
||||
logger.warn("forcing simple_optimizer=True in multi-agent mode")
|
||||
self.config["simple_optimizer"] = True
|
||||
|
||||
def _train(self):
|
||||
prev_steps = self.optimizer.num_steps_sampled
|
||||
|
||||
@@ -55,6 +55,8 @@ class MultiAgentEpisode(object):
|
||||
self.user_data = {}
|
||||
self._policies = policies
|
||||
self._policy_mapping_fn = policy_mapping_fn
|
||||
self._next_agent_index = 0
|
||||
self._agent_to_index = {}
|
||||
self._agent_to_policy = {}
|
||||
self._agent_to_rnn_state = {}
|
||||
self._agent_to_last_obs = {}
|
||||
@@ -141,6 +143,12 @@ class MultiAgentEpisode(object):
|
||||
def _set_last_pi_info(self, agent_id, pi_info):
|
||||
self._agent_to_last_pi_info[agent_id] = pi_info
|
||||
|
||||
def _agent_index(self, agent_id):
|
||||
if agent_id not in self._agent_to_index:
|
||||
self._agent_to_index[agent_id] = self._next_agent_index
|
||||
self._next_agent_index += 1
|
||||
return self._agent_to_index[agent_id]
|
||||
|
||||
|
||||
def _flatten_action(action):
|
||||
# Concatenate tuple actions
|
||||
|
||||
@@ -316,6 +316,7 @@ def _env_runner(async_vector_env,
|
||||
policy_id,
|
||||
t=episode.length - 1,
|
||||
eps_id=episode.episode_id,
|
||||
agent_index=episode._agent_index(agent_id),
|
||||
obs=last_observation,
|
||||
actions=episode.last_action_for(agent_id),
|
||||
rewards=rewards[env_id][agent_id],
|
||||
|
||||
@@ -52,7 +52,8 @@ class TFPolicyGraph(PolicyGraph):
|
||||
prev_action_input=None,
|
||||
prev_reward_input=None,
|
||||
seq_lens=None,
|
||||
max_seq_len=20):
|
||||
max_seq_len=20,
|
||||
batch_divisibility_req=1):
|
||||
"""Initialize the policy graph.
|
||||
|
||||
Arguments:
|
||||
@@ -78,6 +79,9 @@ class TFPolicyGraph(PolicyGraph):
|
||||
[NUM_SEQUENCES]. Note that NUM_SEQUENCES << BATCH_SIZE. See
|
||||
models/lstm.py for more information.
|
||||
max_seq_len (int): max sequence length for LSTM training.
|
||||
batch_divisibility_req (int): pad all agent experience batches to
|
||||
multiples of this value. This only has an effect if not using
|
||||
an LSTM model.
|
||||
"""
|
||||
|
||||
self.observation_space = observation_space
|
||||
@@ -97,6 +101,7 @@ class TFPolicyGraph(PolicyGraph):
|
||||
self._loss_input_dict["state_in_{}".format(i)] = ph
|
||||
self._seq_lens = seq_lens
|
||||
self._max_seq_len = max_seq_len
|
||||
self._batch_divisibility_req = batch_divisibility_req
|
||||
self._optimizer = self.optimizer()
|
||||
self._grads_and_vars = [(g, v)
|
||||
for (g, v) in self.gradients(self._optimizer)
|
||||
@@ -162,21 +167,37 @@ class TFPolicyGraph(PolicyGraph):
|
||||
|
||||
def _get_loss_inputs_dict(self, batch):
|
||||
feed_dict = {}
|
||||
if self._batch_divisibility_req > 1:
|
||||
meets_divisibility_reqs = (
|
||||
len(batch["obs"]) % self._batch_divisibility_req == 0
|
||||
and max(batch["agent_index"]) == 0) # not multiagent
|
||||
else:
|
||||
meets_divisibility_reqs = True
|
||||
|
||||
# Simple case
|
||||
if not self._state_inputs:
|
||||
# Simple case: not RNN nor do we need to pad
|
||||
if not self._state_inputs and meets_divisibility_reqs:
|
||||
for k, ph in self._loss_inputs:
|
||||
feed_dict[ph] = batch[k]
|
||||
return feed_dict
|
||||
|
||||
# RNN case
|
||||
if self._state_inputs:
|
||||
max_seq_len = self._max_seq_len
|
||||
dynamic_max = True
|
||||
else:
|
||||
max_seq_len = self._batch_divisibility_req
|
||||
dynamic_max = False
|
||||
|
||||
# RNN or multi-agent case
|
||||
feature_keys = [k for k, v in self._loss_inputs]
|
||||
state_keys = [
|
||||
"state_in_{}".format(i) for i in range(len(self._state_inputs))
|
||||
]
|
||||
feature_sequences, initial_states, seq_lens = chop_into_sequences(
|
||||
batch["eps_id"], [batch[k] for k in feature_keys],
|
||||
[batch[k] for k in state_keys], self._max_seq_len)
|
||||
batch["eps_id"],
|
||||
batch["agent_index"], [batch[k] for k in feature_keys],
|
||||
[batch[k] for k in state_keys],
|
||||
max_seq_len,
|
||||
dynamic_max=dynamic_max)
|
||||
for k, v in zip(feature_keys, feature_sequences):
|
||||
feed_dict[self._loss_input_dict[k]] = v
|
||||
for k, v in zip(state_keys, initial_states):
|
||||
|
||||
@@ -206,6 +206,8 @@ class ModelCatalog(object):
|
||||
logger.debug("Created model {}: ({} of {}, {}, {}) -> {}, {}".format(
|
||||
model, input_dict, obs_space, state_in, seq_lens, model.outputs,
|
||||
model.state_out))
|
||||
|
||||
model._validate_output_shape()
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -52,15 +52,24 @@ def add_time_dimension(padded_inputs, seq_lens):
|
||||
return tf.reshape(padded_inputs, new_shape)
|
||||
|
||||
|
||||
def chop_into_sequences(episode_ids, feature_columns, state_columns,
|
||||
max_seq_len):
|
||||
def chop_into_sequences(episode_ids,
|
||||
agent_indices,
|
||||
feature_columns,
|
||||
state_columns,
|
||||
max_seq_len,
|
||||
dynamic_max=True):
|
||||
"""Truncate and pad experiences into fixed-length sequences.
|
||||
|
||||
Arguments:
|
||||
episode_ids (list): List of episode ids for each step.
|
||||
agent_indices (list): List of agent ids for each step. Note that this
|
||||
has to be combined with episode_ids for uniqueness.
|
||||
feature_columns (list): List of arrays containing features.
|
||||
state_columns (list): List of arrays containing LSTM state values.
|
||||
max_seq_len (int): Max length of sequences before truncation.
|
||||
dynamic_max (bool): Whether to dynamically shrink the max seq len.
|
||||
For example, if max len is 20 and the actual max seq len in the
|
||||
data is 7, it will be shrunk to 7.
|
||||
|
||||
Returns:
|
||||
f_pad (list): Padded feature columns. These will be of shape
|
||||
@@ -88,19 +97,21 @@ def chop_into_sequences(episode_ids, feature_columns, state_columns,
|
||||
prev_id = None
|
||||
seq_lens = []
|
||||
seq_len = 0
|
||||
for eps_id in episode_ids:
|
||||
if (prev_id is not None and eps_id != prev_id) or \
|
||||
unique_ids = np.add(episode_ids, agent_indices)
|
||||
for uid in unique_ids:
|
||||
if (prev_id is not None and uid != prev_id) or \
|
||||
seq_len >= max_seq_len:
|
||||
seq_lens.append(seq_len)
|
||||
seq_len = 0
|
||||
seq_len += 1
|
||||
prev_id = eps_id
|
||||
prev_id = uid
|
||||
if seq_len:
|
||||
seq_lens.append(seq_len)
|
||||
assert sum(seq_lens) == len(episode_ids)
|
||||
assert sum(seq_lens) == len(unique_ids)
|
||||
|
||||
# Dynamically shrink max len as needed to optimize memory usage
|
||||
max_seq_len = max(seq_lens)
|
||||
if dynamic_max:
|
||||
max_seq_len = max(seq_lens)
|
||||
|
||||
feature_sequences = []
|
||||
for f in feature_columns:
|
||||
@@ -113,7 +124,7 @@ def chop_into_sequences(episode_ids, feature_columns, state_columns,
|
||||
f_pad[seq_base + seq_offset] = f[i]
|
||||
i += 1
|
||||
seq_base += max_seq_len
|
||||
assert i == len(episode_ids), f
|
||||
assert i == len(unique_ids), f
|
||||
feature_sequences.append(f_pad)
|
||||
|
||||
initial_states = []
|
||||
|
||||
@@ -62,6 +62,7 @@ class Model(object):
|
||||
self.seq_lens = tf.placeholder(
|
||||
dtype=tf.int32, shape=[None], name="seq_lens")
|
||||
|
||||
self._num_outputs = num_outputs
|
||||
if options.get("free_log_std"):
|
||||
assert num_outputs % 2 == 0
|
||||
num_outputs = num_outputs // 2
|
||||
@@ -73,18 +74,6 @@ class Model(object):
|
||||
self.outputs, self.last_layer = self._build_layers(
|
||||
input_dict["obs"], num_outputs, options)
|
||||
|
||||
# Validate the output shape
|
||||
try:
|
||||
out = tf.convert_to_tensor(self.outputs)
|
||||
shape = out.shape.as_list()
|
||||
except Exception:
|
||||
raise ValueError("Output is not a tensor: {}".format(self.outputs))
|
||||
else:
|
||||
if len(shape) != 2 or shape[1] != num_outputs:
|
||||
raise ValueError(
|
||||
"Expected output shape of [None, {}], got {}".format(
|
||||
num_outputs, shape))
|
||||
|
||||
if options.get("free_log_std", False):
|
||||
log_std = tf.get_variable(
|
||||
name="log_std",
|
||||
@@ -93,6 +82,19 @@ class Model(object):
|
||||
self.outputs = tf.concat(
|
||||
[self.outputs, 0.0 * self.outputs + log_std], 1)
|
||||
|
||||
def _validate_output_shape(self):
|
||||
"""Checks that the model has the correct number of outputs."""
|
||||
try:
|
||||
out = tf.convert_to_tensor(self.outputs)
|
||||
shape = out.shape.as_list()
|
||||
except Exception:
|
||||
raise ValueError("Output is not a tensor: {}".format(self.outputs))
|
||||
else:
|
||||
if len(shape) != 2 or shape[1] != self._num_outputs:
|
||||
raise ValueError(
|
||||
"Expected output shape of [None, {}], got {}".format(
|
||||
self._num_outputs, shape))
|
||||
|
||||
def _build_layers(self, inputs, num_outputs, options):
|
||||
"""Builds and returns the output and last layer of the network.
|
||||
|
||||
|
||||
@@ -10,10 +10,12 @@ from ray.rllib.models.lstm import chop_into_sequences
|
||||
class LSTMUtilsTest(unittest.TestCase):
|
||||
def testBasic(self):
|
||||
eps_ids = [1, 1, 1, 5, 5, 5, 5, 5]
|
||||
agent_ids = [1, 1, 1, 1, 1, 1, 1, 1]
|
||||
f = [[101, 102, 103, 201, 202, 203, 204, 205],
|
||||
[[101], [102], [103], [201], [202], [203], [204], [205]]]
|
||||
s = [[209, 208, 207, 109, 108, 107, 106, 105]]
|
||||
f_pad, s_init, seq_lens = chop_into_sequences(eps_ids, f, s, 4)
|
||||
f_pad, s_init, seq_lens = chop_into_sequences(eps_ids, agent_ids, f, s,
|
||||
4)
|
||||
self.assertEqual([f.tolist() for f in f_pad], [
|
||||
[101, 102, 103, 0, 201, 202, 203, 204, 205, 0, 0, 0],
|
||||
[[101], [102], [103], [0], [201], [202], [203], [204], [205], [0],
|
||||
@@ -22,11 +24,25 @@ class LSTMUtilsTest(unittest.TestCase):
|
||||
self.assertEqual([s.tolist() for s in s_init], [[209, 109, 105]])
|
||||
self.assertEqual(seq_lens.tolist(), [3, 4, 1])
|
||||
|
||||
def testMultiAgent(self):
|
||||
eps_ids = [1, 1, 1, 5, 5, 5, 5, 5]
|
||||
agent_ids = [1, 1, 2, 1, 1, 2, 2, 3]
|
||||
f = [[101, 102, 103, 201, 202, 203, 204, 205],
|
||||
[[101], [102], [103], [201], [202], [203], [204], [205]]]
|
||||
s = [[209, 208, 207, 109, 108, 107, 106, 105]]
|
||||
f_pad, s_init, seq_lens = chop_into_sequences(
|
||||
eps_ids, agent_ids, f, s, 4, dynamic_max=False)
|
||||
self.assertEqual(seq_lens.tolist(), [2, 1, 2, 2, 1])
|
||||
self.assertEqual(len(f_pad[0]), 20)
|
||||
self.assertEqual(len(s_init[0]), 5)
|
||||
|
||||
def testDynamicMaxLen(self):
|
||||
eps_ids = [5, 2, 2]
|
||||
agent_ids = [2, 2, 2]
|
||||
f = [[1, 1, 1]]
|
||||
s = [[1, 1, 1]]
|
||||
f_pad, s_init, seq_lens = chop_into_sequences(eps_ids, f, s, 4)
|
||||
f_pad, s_init, seq_lens = chop_into_sequences(eps_ids, agent_ids, f, s,
|
||||
4)
|
||||
self.assertEqual([f.tolist() for f in f_pad], [[1, 0, 1, 1]])
|
||||
self.assertEqual([s.tolist() for s in s_init], [[1, 1]])
|
||||
self.assertEqual(seq_lens.tolist(), [1, 2])
|
||||
|
||||
@@ -100,25 +100,32 @@ class RoundRobinMultiAgent(MultiAgentEnv):
|
||||
return obs, rew, done, info
|
||||
|
||||
|
||||
class MultiCartpole(MultiAgentEnv):
|
||||
def __init__(self, num):
|
||||
self.agents = [gym.make("CartPole-v0") for _ in range(num)]
|
||||
self.dones = set()
|
||||
self.observation_space = self.agents[0].observation_space
|
||||
self.action_space = self.agents[0].action_space
|
||||
def make_multiagent(env_name):
|
||||
class MultiEnv(MultiAgentEnv):
|
||||
def __init__(self, num):
|
||||
self.agents = [gym.make(env_name) for _ in range(num)]
|
||||
self.dones = set()
|
||||
self.observation_space = self.agents[0].observation_space
|
||||
self.action_space = self.agents[0].action_space
|
||||
|
||||
def reset(self):
|
||||
self.dones = set()
|
||||
return {i: a.reset() for i, a in enumerate(self.agents)}
|
||||
def reset(self):
|
||||
self.dones = set()
|
||||
return {i: a.reset() for i, a in enumerate(self.agents)}
|
||||
|
||||
def step(self, action_dict):
|
||||
obs, rew, done, info = {}, {}, {}, {}
|
||||
for i, action in action_dict.items():
|
||||
obs[i], rew[i], done[i], info[i] = self.agents[i].step(action)
|
||||
if done[i]:
|
||||
self.dones.add(i)
|
||||
done["__all__"] = len(self.dones) == len(self.agents)
|
||||
return obs, rew, done, info
|
||||
def step(self, action_dict):
|
||||
obs, rew, done, info = {}, {}, {}, {}
|
||||
for i, action in action_dict.items():
|
||||
obs[i], rew[i], done[i], info[i] = self.agents[i].step(action)
|
||||
if done[i]:
|
||||
self.dones.add(i)
|
||||
done["__all__"] = len(self.dones) == len(self.agents)
|
||||
return obs, rew, done, info
|
||||
|
||||
return MultiEnv
|
||||
|
||||
|
||||
MultiCartpole = make_multiagent("CartPole-v0")
|
||||
MultiMountainCar = make_multiagent("MountainCarContinuous-v0")
|
||||
|
||||
|
||||
class TestMultiAgentEnv(unittest.TestCase):
|
||||
|
||||
@@ -93,6 +93,11 @@ class InvalidModel(Model):
|
||||
return "not", "valid"
|
||||
|
||||
|
||||
class InvalidModel2(Model):
|
||||
def _build_layers_v2(self, input_dict, num_outputs, options):
|
||||
return tf.constant(0), tf.constant(0)
|
||||
|
||||
|
||||
class DictSpyModel(Model):
|
||||
capture_index = 0
|
||||
|
||||
@@ -158,6 +163,17 @@ class NestedSpacesTest(unittest.TestCase):
|
||||
},
|
||||
}))
|
||||
|
||||
def testInvalidModel2(self):
|
||||
ModelCatalog.register_custom_model("invalid2", InvalidModel2)
|
||||
self.assertRaisesRegexp(
|
||||
ValueError, "Expected output.*",
|
||||
lambda: PGAgent(
|
||||
env="CartPole-v0", config={
|
||||
"model": {
|
||||
"custom_model": "invalid2",
|
||||
},
|
||||
}))
|
||||
|
||||
def doTestNestedDict(self, make_env):
|
||||
ModelCatalog.register_custom_model("composite", DictSpyModel)
|
||||
register_env("nested", make_env)
|
||||
|
||||
@@ -9,6 +9,7 @@ import sys
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.agent import get_agent_class
|
||||
from ray.rllib.test.test_multi_agent_env import MultiCartpole, MultiMountainCar
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
@@ -88,9 +89,27 @@ def check_support(alg, config, stats, check_bounds=False):
|
||||
stats[alg, a_name, o_name] = stat
|
||||
|
||||
|
||||
def check_support_multiagent(alg, config):
|
||||
register_env("multi_mountaincar", lambda _: MultiMountainCar(2))
|
||||
register_env("multi_cartpole", lambda _: MultiCartpole(2))
|
||||
if alg == "DDPG":
|
||||
a = get_agent_class(alg)(config=config, env="multi_mountaincar")
|
||||
else:
|
||||
a = get_agent_class(alg)(config=config, env="multi_cartpole")
|
||||
try:
|
||||
a.train()
|
||||
finally:
|
||||
a.stop()
|
||||
|
||||
|
||||
class ModelSupportedSpaces(unittest.TestCase):
|
||||
def setUp(self):
|
||||
ray.init(num_cpus=4)
|
||||
|
||||
def tearDown(self):
|
||||
ray.shutdown()
|
||||
|
||||
def testAll(self):
|
||||
ray.init()
|
||||
stats = {}
|
||||
check_support("IMPALA", {"num_gpus": 0}, stats)
|
||||
check_support("DDPG", {"timesteps_per_iteration": 1}, stats)
|
||||
@@ -137,6 +156,27 @@ class ModelSupportedSpaces(unittest.TestCase):
|
||||
stat)
|
||||
self.assertEqual(num_unexpected_errors, 0)
|
||||
|
||||
def testMultiAgent(self):
|
||||
check_support_multiagent("IMPALA", {"num_gpus": 0})
|
||||
check_support_multiagent("DQN", {"timesteps_per_iteration": 1})
|
||||
check_support_multiagent("A3C", {
|
||||
"num_workers": 1,
|
||||
"optimizer": {
|
||||
"grads_per_step": 1
|
||||
}
|
||||
})
|
||||
check_support_multiagent(
|
||||
"PPO", {
|
||||
"num_workers": 1,
|
||||
"num_sgd_iter": 1,
|
||||
"train_batch_size": 10,
|
||||
"sample_batch_size": 10,
|
||||
"sgd_minibatch_size": 1,
|
||||
"simple_optimizer": True,
|
||||
})
|
||||
check_support_multiagent("PG", {"num_workers": 1, "optimizer": {}})
|
||||
check_support_multiagent("DDPG", {"timesteps_per_iteration": 1})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) > 1 and sys.argv[1] == "--smoke":
|
||||
|
||||
Reference in New Issue
Block a user