[rllib] Fix support for mixed discrete and continuous action spaces, add to regression test (#2655)

* fix

* lint

* fix
This commit is contained in:
Eric Liang
2018-08-15 10:19:41 -07:00
committed by GitHub
parent 98fed67b45
commit 53f9755594
4 changed files with 22 additions and 5 deletions
+7 -1
View File
@@ -404,7 +404,13 @@ class _MultiAgentEpisode(object):
action = self._agent_to_last_action[agent_id]
# Concatenate tuple actions
if isinstance(action, list):
action = np.concatenate(action, axis=0).flatten()
expanded = []
for a in action:
if len(a.shape) == 1:
expanded.append(np.expand_dims(a, 1))
else:
expanded.append(a)
action = np.concatenate(expanded, axis=1).flatten()
return action
def last_pi_info_for(self, agent_id):
+9 -3
View File
@@ -182,7 +182,6 @@ class MultiActionDistribution(ActionDistribution):
def __init__(self, inputs, action_space, child_distributions, input_lens):
self.input_lens = input_lens
inputs = tf.reshape(inputs, [-1, sum(input_lens)])
split_inputs = tf.split(inputs, self.input_lens, axis=1)
child_list = []
for i, distribution in enumerate(child_distributions):
@@ -191,11 +190,18 @@ class MultiActionDistribution(ActionDistribution):
def logp(self, x):
"""The log-likelihood of the action distribution."""
split_list = tf.split(x, len(self.input_lens), axis=1)
split_indices = []
for dist in self.child_distributions:
if isinstance(dist, Categorical):
split_indices.append(1)
else:
split_indices.append(tf.shape(dist.sample())[1])
split_list = tf.split(x, split_indices, axis=1)
for i, distribution in enumerate(self.child_distributions):
# Remove extra categorical dimension
if isinstance(distribution, Categorical):
split_list[i] = tf.squeeze(split_list[i], axis=-1)
split_list[i] = tf.cast(
tf.squeeze(split_list[i], axis=-1), tf.int32)
log_list = np.asarray([
distribution.logp(split_x) for distribution, split_x in zip(
self.child_distributions, split_list)
@@ -23,6 +23,10 @@ ACTION_SPACES_TO_TEST = {
Box(0.0, 1.0, (5, ), dtype=np.float32),
Box(0.0, 1.0, (5, ), dtype=np.float32)
],
"mixed_tuple": Tuple(
[Discrete(2),
Discrete(3),
Box(0.0, 1.0, (5, ), dtype=np.float32)]),
}
OBSERVATION_SPACES_TO_TEST = {