[rllib] tuple space shouldn't assume elements are all the same size (#2637)

* fix

* lint
This commit is contained in:
Eric Liang
2018-08-11 10:57:40 -07:00
committed by GitHub
parent 230b9ab33b
commit 9559873d13
3 changed files with 10 additions and 11 deletions
+5 -7
View File
@@ -6,8 +6,6 @@ import tensorflow as tf
import numpy as np
import distutils.version
from ray.rllib.utils.reshaper import Reshaper
use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >=
distutils.version.LooseVersion("1.5.0"))
@@ -182,10 +180,10 @@ class MultiActionDistribution(ActionDistribution):
inputs (Tensor list): A list of tensors from which to compute samples.
"""
def __init__(self, inputs, action_space, child_distributions):
# you actually have to instantiate the child distributions
self.reshaper = Reshaper(action_space.spaces)
split_inputs = self.reshaper.split_tensor(inputs)
def __init__(self, inputs, action_space, child_distributions, input_lens):
self.input_lens = input_lens
inputs = tf.reshape(inputs, [-1, sum(input_lens)])
split_inputs = tf.split(inputs, self.input_lens, axis=1)
child_list = []
for i, distribution in enumerate(child_distributions):
child_list.append(distribution(split_inputs[i]))
@@ -193,7 +191,7 @@ class MultiActionDistribution(ActionDistribution):
def logp(self, x):
"""The log-likelihood of the action distribution."""
split_list = self.reshaper.split_tensor(x)
split_list = tf.split(x, len(self.input_lens), axis=1)
for i, distribution in enumerate(self.child_distributions):
# Remove extra categorical dimension
if isinstance(distribution, Categorical):
+4 -3
View File
@@ -87,16 +87,17 @@ class ModelCatalog(object):
elif isinstance(action_space, gym.spaces.Discrete):
return Categorical, action_space.n
elif isinstance(action_space, gym.spaces.Tuple):
size = 0
child_dist = []
input_lens = []
for action in action_space.spaces:
dist, action_size = ModelCatalog.get_action_dist(action)
child_dist.append(dist)
size += action_size
input_lens.append(action_size)
return partial(
MultiActionDistribution,
child_distributions=child_dist,
action_space=action_space), size
action_space=action_space,
input_lens=input_lens), sum(input_lens)
raise NotImplementedError("Unsupported args: {} {}".format(
action_space, dist_type))
+1 -1
View File
@@ -6,7 +6,7 @@ import tensorflow as tf
from ray.rllib.models.model import Model
from ray.rllib.models.fcnet import FullyConnectedNetwork
from ray.rllib.models.action_dist import Reshaper
from ray.rllib.utils.reshaper import Reshaper
class MultiAgentFullyConnectedNetwork(Model):