[hotfix] fix Travis action dist test (#5428)

This commit is contained in:
Eric Liang
2019-08-10 17:59:54 -07:00
committed by GitHub
parent 983f3c83d8
commit cc86271cf8
+10 -4
View File
@@ -41,7 +41,7 @@ class CustomActionDistribution(TFActionDistribution):
return action_space.shape
def _build_sample_op(self):
custom_options = self.model_config["custom_options"]
custom_options = self.model.model_config["custom_options"]
if "output_dim" in custom_options:
output_shape = tf.concat(
[tf.shape(self.inputs)[:1], custom_options["output_dim"]],
@@ -115,6 +115,9 @@ class ModelCatalogTest(unittest.TestCase):
self.assertEqual(str(type(p1)), str(CustomModel))
def testCustomActionDistribution(self):
class Model():
pass
ray.init()
# registration
ModelCatalog.register_custom_action_dist("test",
@@ -131,7 +134,9 @@ class ModelCatalogTest(unittest.TestCase):
# test the class works as a distribution
dist_input = tf.placeholder(tf.float32, (None, ) + param_shape)
dist = dist_cls(dist_input, model_config=model_config)
model = Model()
model.model_config = model_config
dist = dist_cls(dist_input, model=model)
self.assertEqual(dist.sample().shape[1:], dist_input.shape[1:])
self.assertIsInstance(dist.sample(), tf.Tensor)
with self.assertRaises(NotImplementedError):
@@ -143,7 +148,8 @@ class ModelCatalogTest(unittest.TestCase):
action_space, model_config)
self.assertEqual(param_shape, (3, ))
dist_input = tf.placeholder(tf.float32, (None, ) + param_shape)
dist = dist_cls(dist_input, model_config=model_config)
model.model_config = model_config
dist = dist_cls(dist_input, model=model)
self.assertEqual(dist.sample().shape[1:], dist_input.shape[1:])
self.assertIsInstance(dist.sample(), tf.Tensor)
with self.assertRaises(NotImplementedError):
@@ -151,4 +157,4 @@ class ModelCatalogTest(unittest.TestCase):
if __name__ == "__main__":
unittest.main(verbosity=2)
unittest.main(verbosity=1)