mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 12:36:50 +08:00
[RLlib] Fix custom multi action distr (#13681)
This commit is contained in:
@@ -199,13 +199,14 @@ class ModelCatalog:
|
||||
config = config or MODEL_DEFAULTS
|
||||
# Custom distribution given.
|
||||
if config.get("custom_action_dist"):
|
||||
action_dist_name = config["custom_action_dist"]
|
||||
custom_action_config = config.copy()
|
||||
action_dist_name = custom_action_config.pop("custom_action_dist")
|
||||
logger.debug(
|
||||
"Using custom action distribution {}".format(action_dist_name))
|
||||
dist_cls = _global_registry.get(RLLIB_ACTION_DIST,
|
||||
action_dist_name)
|
||||
return ModelCatalog._get_multi_action_distribution(
|
||||
dist_cls, action_space, config, framework)
|
||||
dist_cls, action_space, custom_action_config, framework)
|
||||
|
||||
# Dist_type is given directly as a class.
|
||||
elif type(dist_type) is type and \
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
from functools import partial
|
||||
import gym
|
||||
from gym.spaces import Box, Discrete
|
||||
from gym.spaces import Box, Dict, Discrete
|
||||
import numpy as np
|
||||
import unittest
|
||||
|
||||
import ray
|
||||
from ray.rllib.models import ModelCatalog, MODEL_DEFAULTS, ActionDistribution
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
|
||||
from ray.rllib.models import ActionDistribution, ModelCatalog, MODEL_DEFAULTS
|
||||
from ray.rllib.models.preprocessors import NoPreprocessor, Preprocessor
|
||||
from ray.rllib.models.tf.tf_action_dist import MultiActionDistribution, \
|
||||
TFActionDistribution
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.test_utils import framework_iterator
|
||||
@@ -60,6 +62,12 @@ class CustomActionDistribution(TFActionDistribution):
|
||||
return tf.zeros(self.output_shape)
|
||||
|
||||
|
||||
class CustomMultiActionDistribution(MultiActionDistribution):
    """MultiActionDistribution subclass whose entropy is left unimplemented.

    Overriding `entropy` to raise gives callers a recognizable marker: if
    `entropy()` raises NotImplementedError, this subclass (and not the stock
    MultiActionDistribution) is the one in use.
    """

    @override(MultiActionDistribution)
    def entropy(self):
        """Always raise NotImplementedError."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class TestModelCatalog(unittest.TestCase):
|
||||
def tearDown(self):
    """Shut Ray down after each test case."""
    ray.shutdown()
|
||||
@@ -161,6 +169,42 @@ class TestModelCatalog(unittest.TestCase):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
dist.entropy()
|
||||
|
||||
def test_custom_multi_action_distribution(self):
    """Register a custom multi-action distribution and exercise it.

    Checks that a distribution registered under a custom name is resolved
    by `ModelCatalog.get_action_dist` for a Dict action space, that the
    reported parameter size matches the component spaces, and that the
    resulting distribution samples a dict with both action keys while
    still exposing the subclass' `entropy` override.
    """

    class Model():
        """Bare stand-in that only carries a `model_config` attribute."""

    ray.init(
        object_store_memory=1000 * 1024 * 1024,
        ignore_reinit_error=True)  # otherwise fails sometimes locally

    # Register the custom distribution under the name "test".
    ModelCatalog.register_custom_action_dist(
        "test", CustomMultiActionDistribution)

    # Dict action space with one discrete and one continuous component.
    discrete_space = Discrete(5)
    box_space = Box(0, 1, shape=(3, ), dtype=np.float32)
    action_space = Dict({
        "action_1": discrete_space,
        "action_2": box_space,
    })

    # Look the distribution up through the catalog.
    cfg = MODEL_DEFAULTS.copy()
    cfg["custom_action_dist"] = "test"
    dist_cls, param_shape = ModelCatalog.get_action_dist(action_space, cfg)
    self.assertIsInstance(dist_cls, partial)
    # The test expects 5 inputs for the Discrete(5) part and two per Box
    # dimension.
    expected_size = discrete_space.n + 2 * box_space.shape[0]
    self.assertEqual(param_shape, expected_size)

    # Instantiate the distribution and check what it samples.
    inputs = tf1.placeholder(tf.float32, (None, param_shape))
    stub_model = Model()
    stub_model.model_config = cfg
    dist = dist_cls(inputs, model=stub_model)
    self.assertIsInstance(dist.sample(), dict)
    for key in ("action_1", "action_2"):
        self.assertIn(key, dist.sample())
    self.assertEqual(dist.sample()["action_1"].dtype, tf.int64)
    self.assertEqual(dist.sample()["action_2"].shape[1:], box_space.shape)

    # The subclass' entropy override must be the one that is called.
    with self.assertRaises(NotImplementedError):
        dist.entropy()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
|
||||
Reference in New Issue
Block a user