diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py index d7544c997..29775b6cd 100644 --- a/python/ray/tune/logger.py +++ b/python/ray/tune/logger.py @@ -186,6 +186,9 @@ class TBXLogger(Logger): {"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2} """ + # NoneType is not supported on the last TBX release yet. + VALID_HPARAMS = (str, bool, int, float, list) + def _init(self): try: from tensorboardX import SummaryWriter @@ -253,14 +256,31 @@ class TBXLogger(Logger): flat_params = flatten_dict(self.trial.evaluated_params) scrubbed_params = { k: v - for k, v in flat_params.items() if v is not None + for k, v in flat_params.items() + if isinstance(v, self.VALID_HPARAMS) } + + removed = { + k: v + for k, v in flat_params.items() + if not isinstance(v, self.VALID_HPARAMS) + } + if removed: + logger.info( + "Removed the following hyperparameter values when " + "logging to tensorboard: %s", str(removed)) + from tensorboardX.summary import hparams - experiment_tag, session_start_tag, session_end_tag = hparams( - hparam_dict=scrubbed_params, metric_dict=result) - self._file_writer.file_writer.add_summary(experiment_tag) - self._file_writer.file_writer.add_summary(session_start_tag) - self._file_writer.file_writer.add_summary(session_end_tag) + try: + experiment_tag, session_start_tag, session_end_tag = hparams( + hparam_dict=scrubbed_params, metric_dict=result) + self._file_writer.file_writer.add_summary(experiment_tag) + self._file_writer.file_writer.add_summary(session_start_tag) + self._file_writer.file_writer.add_summary(session_end_tag) + except Exception: + logger.exception("TensorboardX failed to log hparams. " + "This may be due to an unsupported type " + "in the hyperparameter values.") DEFAULT_LOGGERS = (JsonLogger, CSVLogger, TBXLogger) diff --git a/python/ray/tune/tests/test_logger.py b/python/ray/tune/tests/test_logger.py index 9b3c8d946..9a52ec61c 100644 --- a/python/ray/tune/tests/test_logger.py +++ b/python/ray/tune/tests/test_logger.py @@ -46,7 +46,7 @@ class LoggerSuite(unittest.TestCase): logger.close() def testTBX(self): - config = {"a": 2, "b": 5, "c": {"c": {"D": 123}, "e": None}} + config = {"a": 2, "b": [1, 2], "c": {"c": {"D": 123}}} t = Trial(evaluated_params=config, trial_id="tbx") logger = TBXLogger(config=config, logdir=self.test_dir, trial=t) logger.on_result(result(0, 4)) @@ -54,6 +54,25 @@ class LoggerSuite(unittest.TestCase): logger.on_result(result(2, 4, score=[1, 2, 3], hello={"world": 1})) logger.close() + def testBadTBX(self): + config = {"b": (1, 2, 3)} + t = Trial(evaluated_params=config, trial_id="tbx") + logger = TBXLogger(config=config, logdir=self.test_dir, trial=t) + logger.on_result(result(0, 4)) + logger.on_result(result(2, 4, score=[1, 2, 3], hello={"world": 1})) + with self.assertLogs("ray.tune.logger", level="INFO") as cm: + logger.close() + assert "INFO" in cm.output[0] + + config = {"None": None} + t = Trial(evaluated_params=config, trial_id="tbx") + logger = TBXLogger(config=config, logdir=self.test_dir, trial=t) + logger.on_result(result(0, 4)) + logger.on_result(result(2, 4, score=[1, 2, 3], hello={"world": 1})) + with self.assertLogs("ray.tune.logger", level="INFO") as cm: + logger.close() + assert "INFO" in cm.output[0] + if __name__ == "__main__": import pytest