diff --git a/python/ray/rllib/common.py b/python/ray/rllib/common.py index 23b6e6da0..30a14b913 100644 --- a/python/ray/rllib/common.py +++ b/python/ray/rllib/common.py @@ -22,11 +22,13 @@ logger.setLevel(logging.INFO) class RLLibEncoder(json.JSONEncoder): def default(self, value): - if isinstance(value, np.float32) or isinstance(value, np.float64): + if np.issubdtype(value, float): if np.isnan(value): return None else: return float(value) + elif np.issubdtype(value, int): + return int(value) class RLLibLogger(object):