From fb73d51d4dcb43ef347d6b23d116a7dc0b2a16e9 Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Fri, 28 Feb 2020 11:51:56 -0800 Subject: [PATCH] [tune] fix hparams for tbx (#7312) * fix * test_hist * remove unnecessary value check * pbt * queue * skip_for_now * Apply suggestions from code review --- python/ray/tune/logger.py | 7 ++++++- python/ray/tune/tests/test_logger.py | 17 +++++++++++------ 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py index 857981e42..4ee45d38a 100644 --- a/python/ray/tune/logger.py +++ b/python/ray/tune/logger.py @@ -231,7 +231,12 @@ class TBXLogger(Logger): def close(self): if self._file_writer is not None: if self.trial and self.trial.evaluated_params and self.last_result: - self._try_log_hparams(self.last_result) + scrubbed_result = { + k: value + for k, value in self.last_result.items() + if type(value) in VALID_SUMMARY_TYPES + } + self._try_log_hparams(scrubbed_result) self._file_writer.close() def _try_log_hparams(self, result): diff --git a/python/ray/tune/tests/test_logger.py b/python/ray/tune/tests/test_logger.py index 28b2b6990..636821f23 100644 --- a/python/ray/tune/tests/test_logger.py +++ b/python/ray/tune/tests/test_logger.py @@ -8,12 +8,14 @@ from ray.tune.logger import JsonLogger, CSVLogger, TBXLogger Trial = namedtuple("MockTrial", ["evaluated_params", "trial_id"]) -def result(t, rew): - return dict( +def result(t, rew, **kwargs): + results = dict( time_total_s=t, episode_reward_mean=rew, mean_accuracy=rew * 2, training_iteration=int(t)) + results.update(kwargs) + return results class LoggerSuite(unittest.TestCase): @@ -31,22 +33,25 @@ class LoggerSuite(unittest.TestCase): logger = CSVLogger(config=config, logdir=self.test_dir, trial=t) logger.on_result(result(2, 4)) logger.on_result(result(2, 4)) + logger.on_result(result(2, 4, score=[1, 2, 3])) logger.close() def testJSON(self): config = {"a": 2, "b": 5} t = Trial(evaluated_params=config, trial_id="json") logger = JsonLogger(config=config, logdir=self.test_dir, trial=t) - logger.on_result(result(2, 4)) - logger.on_result(result(2, 4)) + logger.on_result(result(0, 4)) + logger.on_result(result(1, 4)) + logger.on_result(result(2, 4, score=[1, 2, 3])) logger.close() def testTBX(self): config = {"a": 2, "b": 5} t = Trial(evaluated_params=config, trial_id="tbx") logger = TBXLogger(config=config, logdir=self.test_dir, trial=t) - logger.on_result(result(2, 4)) - logger.on_result(result(2, 4)) + logger.on_result(result(0, 4)) + logger.on_result(result(1, 4)) + logger.on_result(result(2, 4, score=[1, 2, 3])) logger.close()