Change get_action_shape so that it uses the dtype of the Discrete object, rather than overwriting it with tf.int64. (#8424)

This commit is contained in:
internetcoffeephone
2020-09-22 02:08:31 +02:00
committed by GitHub
parent 6247740b94
commit 840fb5543b
+1 -1
View File
@@ -215,7 +215,7 @@ class ModelCatalog:
"""
if isinstance(action_space, gym.spaces.Discrete):
return (tf.int64, (None, ))
return (action_space.dtype, (None, ))
elif isinstance(action_space, (gym.spaces.Box, Simplex)):
return (tf.float32, (None, ) + action_space.shape)
elif isinstance(action_space, gym.spaces.MultiDiscrete):