diff --git a/python/ray/rllib/models/catalog.py b/python/ray/rllib/models/catalog.py index 776773552..ce91742c3 100644 --- a/python/ray/rllib/models/catalog.py +++ b/python/ray/rllib/models/catalog.py @@ -147,7 +147,7 @@ class ModelCatalog(object): elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete): if torch: raise NotImplementedError - return MultiCategorical, sum(action_space.nvec) + return MultiCategorical, int(sum(action_space.nvec)) raise NotImplementedError("Unsupported args: {} {}".format( action_space, dist_type))