diff --git a/python/ray/rllib/models/action_dist.py b/python/ray/rllib/models/action_dist.py index b104230bf..c3a816709 100644 --- a/python/ray/rllib/models/action_dist.py +++ b/python/ray/rllib/models/action_dist.py @@ -6,8 +6,6 @@ import tensorflow as tf import numpy as np import distutils.version -from ray.rllib.utils.reshaper import Reshaper - use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >= distutils.version.LooseVersion("1.5.0")) @@ -182,10 +180,10 @@ class MultiActionDistribution(ActionDistribution): inputs (Tensor list): A list of tensors from which to compute samples. """ - def __init__(self, inputs, action_space, child_distributions): - # you actually have to instantiate the child distributions - self.reshaper = Reshaper(action_space.spaces) - split_inputs = self.reshaper.split_tensor(inputs) + def __init__(self, inputs, action_space, child_distributions, input_lens): + self.input_lens = input_lens + inputs = tf.reshape(inputs, [-1, sum(input_lens)]) + split_inputs = tf.split(inputs, self.input_lens, axis=1) child_list = [] for i, distribution in enumerate(child_distributions): child_list.append(distribution(split_inputs[i])) @@ -193,7 +191,7 @@ class MultiActionDistribution(ActionDistribution): def logp(self, x): """The log-likelihood of the action distribution.""" - split_list = self.reshaper.split_tensor(x) + split_list = tf.split(x, len(self.input_lens), axis=1) for i, distribution in enumerate(self.child_distributions): # Remove extra categorical dimension if isinstance(distribution, Categorical): diff --git a/python/ray/rllib/models/catalog.py b/python/ray/rllib/models/catalog.py index 71e181d10..b98061fdd 100644 --- a/python/ray/rllib/models/catalog.py +++ b/python/ray/rllib/models/catalog.py @@ -87,16 +87,17 @@ class ModelCatalog(object): elif isinstance(action_space, gym.spaces.Discrete): return Categorical, action_space.n elif isinstance(action_space, gym.spaces.Tuple): - size = 0 child_dist = [] + input_lens = [] for action in action_space.spaces: dist, action_size = ModelCatalog.get_action_dist(action) child_dist.append(dist) - size += action_size + input_lens.append(action_size) return partial( MultiActionDistribution, child_distributions=child_dist, - action_space=action_space), size + action_space=action_space, + input_lens=input_lens), sum(input_lens) raise NotImplementedError("Unsupported args: {} {}".format( action_space, dist_type)) diff --git a/python/ray/rllib/models/multiagentfcnet.py b/python/ray/rllib/models/multiagentfcnet.py index d000e95df..dad7f2983 100644 --- a/python/ray/rllib/models/multiagentfcnet.py +++ b/python/ray/rllib/models/multiagentfcnet.py @@ -6,7 +6,7 @@ import tensorflow as tf from ray.rllib.models.model import Model from ray.rllib.models.fcnet import FullyConnectedNetwork -from ray.rllib.models.action_dist import Reshaper +from ray.rllib.utils.reshaper import Reshaper class MultiAgentFullyConnectedNetwork(Model):