diff --git a/rllib/models/catalog.py b/rllib/models/catalog.py index 26483a190..0dbf71f15 100644 --- a/rllib/models/catalog.py +++ b/rllib/models/catalog.py @@ -215,7 +215,7 @@ class ModelCatalog: """ if isinstance(action_space, gym.spaces.Discrete): - return (tf.int64, (None, )) + return (action_space.dtype, (None, )) elif isinstance(action_space, (gym.spaces.Box, Simplex)): return (tf.float32, (None, ) + action_space.shape) elif isinstance(action_space, gym.spaces.MultiDiscrete):