mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 22:38:16 +08:00
[tune] fix hparams for tbx (#7312)
* fix * test_hist * remove unnecessary value check * pbt * queue * skip_for_now * Apply suggestions from code review
This commit is contained in:
@@ -231,7 +231,12 @@ class TBXLogger(Logger):
|
||||
def close(self):
|
||||
if self._file_writer is not None:
|
||||
if self.trial and self.trial.evaluated_params and self.last_result:
|
||||
self._try_log_hparams(self.last_result)
|
||||
scrubbed_result = {
|
||||
k: value
|
||||
for k, value in self.last_result.items()
|
||||
if type(value) in VALID_SUMMARY_TYPES
|
||||
}
|
||||
self._try_log_hparams(scrubbed_result)
|
||||
self._file_writer.close()
|
||||
|
||||
def _try_log_hparams(self, result):
|
||||
|
||||
@@ -8,12 +8,14 @@ from ray.tune.logger import JsonLogger, CSVLogger, TBXLogger
|
||||
Trial = namedtuple("MockTrial", ["evaluated_params", "trial_id"])
|
||||
|
||||
|
||||
def result(t, rew):
|
||||
return dict(
|
||||
def result(t, rew, **kwargs):
|
||||
results = dict(
|
||||
time_total_s=t,
|
||||
episode_reward_mean=rew,
|
||||
mean_accuracy=rew * 2,
|
||||
training_iteration=int(t))
|
||||
results.update(kwargs)
|
||||
return results
|
||||
|
||||
|
||||
class LoggerSuite(unittest.TestCase):
|
||||
@@ -31,22 +33,25 @@ class LoggerSuite(unittest.TestCase):
|
||||
logger = CSVLogger(config=config, logdir=self.test_dir, trial=t)
|
||||
logger.on_result(result(2, 4))
|
||||
logger.on_result(result(2, 4))
|
||||
logger.on_result(result(2, 4, score=[1, 2, 3]))
|
||||
logger.close()
|
||||
|
||||
def testJSON(self):
|
||||
config = {"a": 2, "b": 5}
|
||||
t = Trial(evaluated_params=config, trial_id="json")
|
||||
logger = JsonLogger(config=config, logdir=self.test_dir, trial=t)
|
||||
logger.on_result(result(2, 4))
|
||||
logger.on_result(result(2, 4))
|
||||
logger.on_result(result(0, 4))
|
||||
logger.on_result(result(1, 4))
|
||||
logger.on_result(result(2, 4, score=[1, 2, 3]))
|
||||
logger.close()
|
||||
|
||||
def testTBX(self):
|
||||
config = {"a": 2, "b": 5}
|
||||
t = Trial(evaluated_params=config, trial_id="tbx")
|
||||
logger = TBXLogger(config=config, logdir=self.test_dir, trial=t)
|
||||
logger.on_result(result(2, 4))
|
||||
logger.on_result(result(2, 4))
|
||||
logger.on_result(result(0, 4))
|
||||
logger.on_result(result(1, 4))
|
||||
logger.on_result(result(2, 4, score=[1, 2, 3]))
|
||||
logger.close()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user