diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py index 044448d47..d2fae3723 100644 --- a/python/ray/tune/logger.py +++ b/python/ray/tune/logger.py @@ -187,7 +187,7 @@ class TBXLogger(Logger): """ # NoneType is not supported on the last TBX release yet. - VALID_HPARAMS = (str, bool, int, float, list) + VALID_HPARAMS = (str, bool, np.bool8, int, np.integer, float, list) def _init(self): try: diff --git a/python/ray/tune/tests/test_logger.py b/python/ray/tune/tests/test_logger.py index 9a52ec61c..17215b36c 100644 --- a/python/ray/tune/tests/test_logger.py +++ b/python/ray/tune/tests/test_logger.py @@ -2,6 +2,7 @@ from collections import namedtuple import unittest import tempfile import shutil +import numpy as np from ray.tune.logger import JsonLogger, CSVLogger, TBXLogger @@ -46,7 +47,17 @@ class LoggerSuite(unittest.TestCase): logger.close() def testTBX(self): - config = {"a": 2, "b": [1, 2], "c": {"c": {"D": 123}}} + config = { + "a": 2, + "b": [1, 2], + "c": { + "c": { + "D": 123 + } + }, + "d": np.int64(1), + "e": np.bool8(True) + } t = Trial(evaluated_params=config, trial_id="tbx") logger = TBXLogger(config=config, logdir=self.test_dir, trial=t) logger.on_result(result(0, 4))