From cc86271cf8e01e5f97e52a32c33b0e07de61be58 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sat, 10 Aug 2019 17:59:54 -0700 Subject: [PATCH] [hotfix] fix Travis action dist test (#5428) --- rllib/tests/test_catalog.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/rllib/tests/test_catalog.py b/rllib/tests/test_catalog.py index c72f5f37d..99fcd4c8e 100644 --- a/rllib/tests/test_catalog.py +++ b/rllib/tests/test_catalog.py @@ -41,7 +41,7 @@ class CustomActionDistribution(TFActionDistribution): return action_space.shape def _build_sample_op(self): - custom_options = self.model_config["custom_options"] + custom_options = self.model.model_config["custom_options"] if "output_dim" in custom_options: output_shape = tf.concat( [tf.shape(self.inputs)[:1], custom_options["output_dim"]], @@ -115,6 +115,9 @@ class ModelCatalogTest(unittest.TestCase): self.assertEqual(str(type(p1)), str(CustomModel)) def testCustomActionDistribution(self): + class Model(): + pass + ray.init() # registration ModelCatalog.register_custom_action_dist("test", @@ -131,7 +134,9 @@ class ModelCatalogTest(unittest.TestCase): # test the class works as a distribution dist_input = tf.placeholder(tf.float32, (None, ) + param_shape) - dist = dist_cls(dist_input, model_config=model_config) + model = Model() + model.model_config = model_config + dist = dist_cls(dist_input, model=model) self.assertEqual(dist.sample().shape[1:], dist_input.shape[1:]) self.assertIsInstance(dist.sample(), tf.Tensor) with self.assertRaises(NotImplementedError): @@ -143,7 +148,8 @@ class ModelCatalogTest(unittest.TestCase): action_space, model_config) self.assertEqual(param_shape, (3, )) dist_input = tf.placeholder(tf.float32, (None, ) + param_shape) - dist = dist_cls(dist_input, model_config=model_config) + model.model_config = model_config + dist = dist_cls(dist_input, model=model) self.assertEqual(dist.sample().shape[1:], dist_input.shape[1:]) self.assertIsInstance(dist.sample(), tf.Tensor) with self.assertRaises(NotImplementedError): @@ -151,4 +157,4 @@ class ModelCatalogTest(unittest.TestCase): if __name__ == "__main__": - unittest.main(verbosity=2) + unittest.main(verbosity=1)