From 840fb5543b6ffca119cfb080b6cfe62060893f54 Mon Sep 17 00:00:00 2001 From: internetcoffeephone Date: Tue, 22 Sep 2020 02:08:31 +0200 Subject: [PATCH] Change get_action_shape so that it uses the dtype of the Discrete object, rather than overwriting it with tf.int64. (#8424) --- rllib/models/catalog.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rllib/models/catalog.py b/rllib/models/catalog.py index 26483a190..0dbf71f15 100644 --- a/rllib/models/catalog.py +++ b/rllib/models/catalog.py @@ -215,7 +215,7 @@ class ModelCatalog: """ if isinstance(action_space, gym.spaces.Discrete): - return (tf.int64, (None, )) + return (action_space.dtype, (None, )) elif isinstance(action_space, (gym.spaces.Box, Simplex)): return (tf.float32, (None, ) + action_space.shape) elif isinstance(action_space, gym.spaces.MultiDiscrete):