From 4bc257f4fb7054073cd15bb25f31f1708d02c64b Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Thu, 28 Jan 2021 19:28:48 +0100 Subject: [PATCH] [RLlib] Fix custom multi action distr (#13681) --- rllib/models/catalog.py | 5 ++-- rllib/tests/test_catalog.py | 52 ++++++++++++++++++++++++++++++++++--- 2 files changed, 51 insertions(+), 6 deletions(-) diff --git a/rllib/models/catalog.py b/rllib/models/catalog.py index 6d0bfd111..66796d71f 100644 --- a/rllib/models/catalog.py +++ b/rllib/models/catalog.py @@ -199,13 +199,14 @@ class ModelCatalog: config = config or MODEL_DEFAULTS # Custom distribution given. if config.get("custom_action_dist"): - action_dist_name = config["custom_action_dist"] + custom_action_config = config.copy() + action_dist_name = custom_action_config.pop("custom_action_dist") logger.debug( "Using custom action distribution {}".format(action_dist_name)) dist_cls = _global_registry.get(RLLIB_ACTION_DIST, action_dist_name) return ModelCatalog._get_multi_action_distribution( - dist_cls, action_space, config, framework) + dist_cls, action_space, custom_action_config, framework) # Dist_type is given directly as a class. elif type(dist_type) is type and \ diff --git a/rllib/tests/test_catalog.py b/rllib/tests/test_catalog.py index b98f7143a..bbd1ec1bb 100644 --- a/rllib/tests/test_catalog.py +++ b/rllib/tests/test_catalog.py @@ -1,13 +1,15 @@ +from functools import partial import gym -from gym.spaces import Box, Discrete +from gym.spaces import Box, Dict, Discrete import numpy as np import unittest import ray -from ray.rllib.models import ModelCatalog, MODEL_DEFAULTS, ActionDistribution -from ray.rllib.models.tf.tf_modelv2 import TFModelV2 -from ray.rllib.models.tf.tf_action_dist import TFActionDistribution +from ray.rllib.models import ActionDistribution, ModelCatalog, MODEL_DEFAULTS from ray.rllib.models.preprocessors import NoPreprocessor, Preprocessor +from ray.rllib.models.tf.tf_action_dist import MultiActionDistribution, \ + TFActionDistribution +from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.test_utils import framework_iterator @@ -60,6 +62,12 @@ class CustomActionDistribution(TFActionDistribution): return tf.zeros(self.output_shape) +class CustomMultiActionDistribution(MultiActionDistribution): + @override(MultiActionDistribution) + def entropy(self): + raise NotImplementedError + + class TestModelCatalog(unittest.TestCase): def tearDown(self): ray.shutdown() @@ -161,6 +169,42 @@ class TestModelCatalog(unittest.TestCase): with self.assertRaises(NotImplementedError): dist.entropy() + def test_custom_multi_action_distribution(self): + class Model(): + pass + + ray.init( + object_store_memory=1000 * 1024 * 1024, + ignore_reinit_error=True) # otherwise fails sometimes locally + # registration + ModelCatalog.register_custom_action_dist( + "test", CustomMultiActionDistribution) + s1 = Discrete(5) + s2 = Box(0, 1, shape=(3, ), dtype=np.float32) + spaces = dict(action_1=s1, action_2=s2) + action_space = Dict(spaces) + # test retrieving it + model_config = MODEL_DEFAULTS.copy() + model_config["custom_action_dist"] = "test" + dist_cls, param_shape = ModelCatalog.get_action_dist( + action_space, model_config) + self.assertIsInstance(dist_cls, partial) + self.assertEqual(param_shape, s1.n + 2 * s2.shape[0]) + + # test the class works as a distribution + dist_input = tf1.placeholder(tf.float32, (None, param_shape)) + model = Model() + model.model_config = model_config + dist = dist_cls(dist_input, model=model) + self.assertIsInstance(dist.sample(), dict) + self.assertIn("action_1", dist.sample()) + self.assertIn("action_2", dist.sample()) + self.assertEqual(dist.sample()["action_1"].dtype, tf.int64) + self.assertEqual(dist.sample()["action_2"].shape[1:], s2.shape) + + with self.assertRaises(NotImplementedError): + dist.entropy() + if __name__ == "__main__": import pytest