[tune] tbx logger (#6133)

* tbx

* add_hparams

* fix_hparams

* ok

* ok

* fix

* ok

* fix
This commit is contained in:
Richard Liaw
2019-11-15 08:45:44 -08:00
committed by GitHub
parent 8ff393a7bd
commit 62cbc043b4
5 changed files with 85 additions and 4 deletions
+57
View File
@@ -316,6 +316,63 @@ class CSVLogger(Logger):
self._file.close()
class TBXLogger(Logger):
"""TensorBoardX Logger.
Automatically flattens nested dicts to show on TensorBoard:
{"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2}
"""
def _init(self):
try:
from tensorboardX import SummaryWriter
except ImportError:
logger.error("pip install tensorboardX to see TensorBoard files.")
raise
self._file_writer = SummaryWriter(self.logdir, flush_secs=30)
self.last_result = None
def on_result(self, result):
step = result.get(TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]
tmp = result.copy()
for k in [
"config", "pid", "timestamp", TIME_TOTAL_S, TRAINING_ITERATION
]:
if k in tmp:
del tmp[k] # not useful to log these
flat_result = flatten_dict(tmp, delimiter="/")
path = ["ray", "tune"]
valid_result = {
"/".join(path + [attr]): value
for attr, value in flat_result.items()
if type(value) in VALID_SUMMARY_TYPES
}
for attr, value in valid_result.items():
self._file_writer.add_scalar(attr, value, global_step=step)
self.last_result = valid_result
self._file_writer.flush()
def flush(self):
if self._file_writer is not None:
self._file_writer.flush()
def close(self):
if self._file_writer is not None:
if self.trial and self.trial.evaluated_params and self.last_result:
from tensorboardX.summary import hparams
experiment_tag, session_start_tag, session_end_tag = hparams(
hparam_dict=self.trial.evaluated_params,
metric_dict=self.last_result)
self._file_writer.file_writer.add_summary(experiment_tag)
self._file_writer.file_writer.add_summary(session_start_tag)
self._file_writer.file_writer.add_summary(session_end_tag)
self._file_writer.close()
DEFAULT_LOGGERS = (JsonLogger, CSVLogger, tf2_compat_logger)
+9 -1
View File
@@ -7,7 +7,7 @@ import unittest
import tempfile
import shutil
from ray.tune.logger import tf2_compat_logger, JsonLogger, CSVLogger
from ray.tune.logger import tf2_compat_logger, JsonLogger, CSVLogger, TBXLogger
Trial = namedtuple("MockTrial", ["evaluated_params", "trial_id"])
@@ -54,6 +54,14 @@ class LoggerSuite(unittest.TestCase):
logger.on_result(result(2, 4))
logger.close()
def testTBX(self):
config = {"a": 2, "b": 5}
t = Trial(evaluated_params=config, trial_id="tbx")
logger = TBXLogger(config=config, logdir=self.test_dir, trial=t)
logger.on_result(result(2, 4))
logger.on_result(result(2, 4))
logger.close()
if __name__ == "__main__":
unittest.main(verbosity=2)