[RLlib] Fix custom multi action distr (#13681)

This commit is contained in:
Sven Mika
2021-01-28 19:28:48 +01:00
committed by GitHub
parent c583113d66
commit 4bc257f4fb
2 changed files with 51 additions and 6 deletions
+3 -2
View File
@@ -199,13 +199,14 @@ class ModelCatalog:
config = config or MODEL_DEFAULTS
# Custom distribution given.
if config.get("custom_action_dist"):
action_dist_name = config["custom_action_dist"]
custom_action_config = config.copy()
action_dist_name = custom_action_config.pop("custom_action_dist")
logger.debug(
"Using custom action distribution {}".format(action_dist_name))
dist_cls = _global_registry.get(RLLIB_ACTION_DIST,
action_dist_name)
return ModelCatalog._get_multi_action_distribution(
dist_cls, action_space, config, framework)
dist_cls, action_space, custom_action_config, framework)
# Dist_type is given directly as a class.
elif type(dist_type) is type and \
+48 -4
View File
@@ -1,13 +1,15 @@
from functools import partial
import gym
from gym.spaces import Box, Discrete
from gym.spaces import Box, Dict, Discrete
import numpy as np
import unittest
import ray
from ray.rllib.models import ModelCatalog, MODEL_DEFAULTS, ActionDistribution
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
from ray.rllib.models import ActionDistribution, ModelCatalog, MODEL_DEFAULTS
from ray.rllib.models.preprocessors import NoPreprocessor, Preprocessor
from ray.rllib.models.tf.tf_action_dist import MultiActionDistribution, \
TFActionDistribution
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.test_utils import framework_iterator
@@ -60,6 +62,12 @@ class CustomActionDistribution(TFActionDistribution):
return tf.zeros(self.output_shape)
class CustomMultiActionDistribution(MultiActionDistribution):
@override(MultiActionDistribution)
def entropy(self):
raise NotImplementedError
class TestModelCatalog(unittest.TestCase):
def tearDown(self):
ray.shutdown()
@@ -161,6 +169,42 @@ class TestModelCatalog(unittest.TestCase):
with self.assertRaises(NotImplementedError):
dist.entropy()
def test_custom_multi_action_distribution(self):
class Model():
pass
ray.init(
object_store_memory=1000 * 1024 * 1024,
ignore_reinit_error=True) # otherwise fails sometimes locally
# registration
ModelCatalog.register_custom_action_dist(
"test", CustomMultiActionDistribution)
s1 = Discrete(5)
s2 = Box(0, 1, shape=(3, ), dtype=np.float32)
spaces = dict(action_1=s1, action_2=s2)
action_space = Dict(spaces)
# test retrieving it
model_config = MODEL_DEFAULTS.copy()
model_config["custom_action_dist"] = "test"
dist_cls, param_shape = ModelCatalog.get_action_dist(
action_space, model_config)
self.assertIsInstance(dist_cls, partial)
self.assertEqual(param_shape, s1.n + 2 * s2.shape[0])
# test the class works as a distribution
dist_input = tf1.placeholder(tf.float32, (None, param_shape))
model = Model()
model.model_config = model_config
dist = dist_cls(dist_input, model=model)
self.assertIsInstance(dist.sample(), dict)
self.assertIn("action_1", dist.sample())
self.assertIn("action_2", dist.sample())
self.assertEqual(dist.sample()["action_1"].dtype, tf.int64)
self.assertEqual(dist.sample()["action_2"].shape[1:], s2.shape)
with self.assertRaises(NotImplementedError):
dist.entropy()
if __name__ == "__main__":
import pytest