mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 22:53:20 +08:00
[tune] Add np.bool8 and np.int to allowed HPARAMS types (#9297)
This commit is contained in:
@@ -187,7 +187,7 @@ class TBXLogger(Logger):
|
||||
"""
|
||||
|
||||
# NoneType is not supported on the last TBX release yet.
|
||||
VALID_HPARAMS = (str, bool, int, float, list)
|
||||
VALID_HPARAMS = (str, bool, np.bool8, int, np.integer, float, list)
|
||||
|
||||
def _init(self):
|
||||
try:
|
||||
|
||||
@@ -2,6 +2,7 @@ from collections import namedtuple
|
||||
import unittest
|
||||
import tempfile
|
||||
import shutil
|
||||
import numpy as np
|
||||
|
||||
from ray.tune.logger import JsonLogger, CSVLogger, TBXLogger
|
||||
|
||||
@@ -46,7 +47,17 @@ class LoggerSuite(unittest.TestCase):
|
||||
logger.close()
|
||||
|
||||
def testTBX(self):
|
||||
config = {"a": 2, "b": [1, 2], "c": {"c": {"D": 123}}}
|
||||
config = {
|
||||
"a": 2,
|
||||
"b": [1, 2],
|
||||
"c": {
|
||||
"c": {
|
||||
"D": 123
|
||||
}
|
||||
},
|
||||
"d": np.int64(1),
|
||||
"e": np.bool8(True)
|
||||
}
|
||||
t = Trial(evaluated_params=config, trial_id="tbx")
|
||||
logger = TBXLogger(config=config, logdir=self.test_dir, trial=t)
|
||||
logger.on_result(result(0, 4))
|
||||
|
||||
Reference in New Issue
Block a user