mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 02:49:33 +08:00
[hotfix] fix Travis action dist test (#5428)
This commit is contained in:
@@ -41,7 +41,7 @@ class CustomActionDistribution(TFActionDistribution):
|
||||
return action_space.shape
|
||||
|
||||
def _build_sample_op(self):
|
||||
custom_options = self.model_config["custom_options"]
|
||||
custom_options = self.model.model_config["custom_options"]
|
||||
if "output_dim" in custom_options:
|
||||
output_shape = tf.concat(
|
||||
[tf.shape(self.inputs)[:1], custom_options["output_dim"]],
|
||||
@@ -115,6 +115,9 @@ class ModelCatalogTest(unittest.TestCase):
|
||||
self.assertEqual(str(type(p1)), str(CustomModel))
|
||||
|
||||
def testCustomActionDistribution(self):
|
||||
class Model():
|
||||
pass
|
||||
|
||||
ray.init()
|
||||
# registration
|
||||
ModelCatalog.register_custom_action_dist("test",
|
||||
@@ -131,7 +134,9 @@ class ModelCatalogTest(unittest.TestCase):
|
||||
|
||||
# test the class works as a distribution
|
||||
dist_input = tf.placeholder(tf.float32, (None, ) + param_shape)
|
||||
dist = dist_cls(dist_input, model_config=model_config)
|
||||
model = Model()
|
||||
model.model_config = model_config
|
||||
dist = dist_cls(dist_input, model=model)
|
||||
self.assertEqual(dist.sample().shape[1:], dist_input.shape[1:])
|
||||
self.assertIsInstance(dist.sample(), tf.Tensor)
|
||||
with self.assertRaises(NotImplementedError):
|
||||
@@ -143,7 +148,8 @@ class ModelCatalogTest(unittest.TestCase):
|
||||
action_space, model_config)
|
||||
self.assertEqual(param_shape, (3, ))
|
||||
dist_input = tf.placeholder(tf.float32, (None, ) + param_shape)
|
||||
dist = dist_cls(dist_input, model_config=model_config)
|
||||
model.model_config = model_config
|
||||
dist = dist_cls(dist_input, model=model)
|
||||
self.assertEqual(dist.sample().shape[1:], dist_input.shape[1:])
|
||||
self.assertIsInstance(dist.sample(), tf.Tensor)
|
||||
with self.assertRaises(NotImplementedError):
|
||||
@@ -151,4 +157,4 @@ class ModelCatalogTest(unittest.TestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
unittest.main(verbosity=1)
|
||||
|
||||
Reference in New Issue
Block a user