[tune] fix hparams for tbx (#7312)

* fix

* test_hist

* remove unnecessary value check

* pbt

* queue

* skip_for_now

* Apply suggestions from code review
This commit is contained in:
Richard Liaw
2020-02-28 11:51:56 -08:00
committed by GitHub
parent ca40b0fcc6
commit fb73d51d4d
2 changed files with 17 additions and 7 deletions
+6 -1
View File
@@ -231,7 +231,12 @@ class TBXLogger(Logger):
def close(self):
if self._file_writer is not None:
if self.trial and self.trial.evaluated_params and self.last_result:
self._try_log_hparams(self.last_result)
scrubbed_result = {
k: value
for k, value in self.last_result.items()
if type(value) in VALID_SUMMARY_TYPES
}
self._try_log_hparams(scrubbed_result)
self._file_writer.close()
def _try_log_hparams(self, result):
+11 -6
View File
@@ -8,12 +8,14 @@ from ray.tune.logger import JsonLogger, CSVLogger, TBXLogger
Trial = namedtuple("MockTrial", ["evaluated_params", "trial_id"])
def result(t, rew):
return dict(
def result(t, rew, **kwargs):
results = dict(
time_total_s=t,
episode_reward_mean=rew,
mean_accuracy=rew * 2,
training_iteration=int(t))
results.update(kwargs)
return results
class LoggerSuite(unittest.TestCase):
@@ -31,22 +33,25 @@ class LoggerSuite(unittest.TestCase):
logger = CSVLogger(config=config, logdir=self.test_dir, trial=t)
logger.on_result(result(2, 4))
logger.on_result(result(2, 4))
logger.on_result(result(2, 4, score=[1, 2, 3]))
logger.close()
def testJSON(self):
config = {"a": 2, "b": 5}
t = Trial(evaluated_params=config, trial_id="json")
logger = JsonLogger(config=config, logdir=self.test_dir, trial=t)
logger.on_result(result(2, 4))
logger.on_result(result(2, 4))
logger.on_result(result(0, 4))
logger.on_result(result(1, 4))
logger.on_result(result(2, 4, score=[1, 2, 3]))
logger.close()
def testTBX(self):
config = {"a": 2, "b": 5}
t = Trial(evaluated_params=config, trial_id="tbx")
logger = TBXLogger(config=config, logdir=self.test_dir, trial=t)
logger.on_result(result(2, 4))
logger.on_result(result(2, 4))
logger.on_result(result(0, 4))
logger.on_result(result(1, 4))
logger.on_result(result(2, 4, score=[1, 2, 3]))
logger.close()