mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 00:55:31 +08:00
[rllib] tuple space shouldn't assume elements are all the same size (#2637)
* fix * lint
This commit is contained in:
@@ -6,8 +6,6 @@ import tensorflow as tf
|
||||
import numpy as np
|
||||
import distutils.version
|
||||
|
||||
from ray.rllib.utils.reshaper import Reshaper
|
||||
|
||||
use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >=
|
||||
distutils.version.LooseVersion("1.5.0"))
|
||||
|
||||
@@ -182,10 +180,10 @@ class MultiActionDistribution(ActionDistribution):
|
||||
inputs (Tensor list): A list of tensors from which to compute samples.
|
||||
"""
|
||||
|
||||
def __init__(self, inputs, action_space, child_distributions):
|
||||
# you actually have to instantiate the child distributions
|
||||
self.reshaper = Reshaper(action_space.spaces)
|
||||
split_inputs = self.reshaper.split_tensor(inputs)
|
||||
def __init__(self, inputs, action_space, child_distributions, input_lens):
|
||||
self.input_lens = input_lens
|
||||
inputs = tf.reshape(inputs, [-1, sum(input_lens)])
|
||||
split_inputs = tf.split(inputs, self.input_lens, axis=1)
|
||||
child_list = []
|
||||
for i, distribution in enumerate(child_distributions):
|
||||
child_list.append(distribution(split_inputs[i]))
|
||||
@@ -193,7 +191,7 @@ class MultiActionDistribution(ActionDistribution):
|
||||
|
||||
def logp(self, x):
|
||||
"""The log-likelihood of the action distribution."""
|
||||
split_list = self.reshaper.split_tensor(x)
|
||||
split_list = tf.split(x, len(self.input_lens), axis=1)
|
||||
for i, distribution in enumerate(self.child_distributions):
|
||||
# Remove extra categorical dimension
|
||||
if isinstance(distribution, Categorical):
|
||||
|
||||
@@ -87,16 +87,17 @@ class ModelCatalog(object):
|
||||
elif isinstance(action_space, gym.spaces.Discrete):
|
||||
return Categorical, action_space.n
|
||||
elif isinstance(action_space, gym.spaces.Tuple):
|
||||
size = 0
|
||||
child_dist = []
|
||||
input_lens = []
|
||||
for action in action_space.spaces:
|
||||
dist, action_size = ModelCatalog.get_action_dist(action)
|
||||
child_dist.append(dist)
|
||||
size += action_size
|
||||
input_lens.append(action_size)
|
||||
return partial(
|
||||
MultiActionDistribution,
|
||||
child_distributions=child_dist,
|
||||
action_space=action_space), size
|
||||
action_space=action_space,
|
||||
input_lens=input_lens), sum(input_lens)
|
||||
|
||||
raise NotImplementedError("Unsupported args: {} {}".format(
|
||||
action_space, dist_type))
|
||||
|
||||
@@ -6,7 +6,7 @@ import tensorflow as tf
|
||||
|
||||
from ray.rllib.models.model import Model
|
||||
from ray.rllib.models.fcnet import FullyConnectedNetwork
|
||||
from ray.rllib.models.action_dist import Reshaper
|
||||
from ray.rllib.utils.reshaper import Reshaper
|
||||
|
||||
|
||||
class MultiAgentFullyConnectedNetwork(Model):
|
||||
|
||||
Reference in New Issue
Block a user