mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 09:05:47 +08:00
[tune] tbx logger (#6133)
* tbx * add_hparams * fix_hparams * ok * ok * fix * ok * fix
This commit is contained in:
@@ -316,6 +316,63 @@ class CSVLogger(Logger):
|
||||
self._file.close()
|
||||
|
||||
|
||||
class TBXLogger(Logger):
|
||||
"""TensorBoardX Logger.
|
||||
|
||||
Automatically flattens nested dicts to show on TensorBoard:
|
||||
|
||||
{"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2}
|
||||
"""
|
||||
|
||||
def _init(self):
|
||||
try:
|
||||
from tensorboardX import SummaryWriter
|
||||
except ImportError:
|
||||
logger.error("pip install tensorboardX to see TensorBoard files.")
|
||||
raise
|
||||
self._file_writer = SummaryWriter(self.logdir, flush_secs=30)
|
||||
self.last_result = None
|
||||
|
||||
def on_result(self, result):
|
||||
step = result.get(TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]
|
||||
|
||||
tmp = result.copy()
|
||||
for k in [
|
||||
"config", "pid", "timestamp", TIME_TOTAL_S, TRAINING_ITERATION
|
||||
]:
|
||||
if k in tmp:
|
||||
del tmp[k] # not useful to log these
|
||||
|
||||
flat_result = flatten_dict(tmp, delimiter="/")
|
||||
path = ["ray", "tune"]
|
||||
valid_result = {
|
||||
"/".join(path + [attr]): value
|
||||
for attr, value in flat_result.items()
|
||||
if type(value) in VALID_SUMMARY_TYPES
|
||||
}
|
||||
|
||||
for attr, value in valid_result.items():
|
||||
self._file_writer.add_scalar(attr, value, global_step=step)
|
||||
self.last_result = valid_result
|
||||
self._file_writer.flush()
|
||||
|
||||
def flush(self):
|
||||
if self._file_writer is not None:
|
||||
self._file_writer.flush()
|
||||
|
||||
def close(self):
|
||||
if self._file_writer is not None:
|
||||
if self.trial and self.trial.evaluated_params and self.last_result:
|
||||
from tensorboardX.summary import hparams
|
||||
experiment_tag, session_start_tag, session_end_tag = hparams(
|
||||
hparam_dict=self.trial.evaluated_params,
|
||||
metric_dict=self.last_result)
|
||||
self._file_writer.file_writer.add_summary(experiment_tag)
|
||||
self._file_writer.file_writer.add_summary(session_start_tag)
|
||||
self._file_writer.file_writer.add_summary(session_end_tag)
|
||||
self._file_writer.close()
|
||||
|
||||
|
||||
DEFAULT_LOGGERS = (JsonLogger, CSVLogger, tf2_compat_logger)
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import unittest
|
||||
import tempfile
|
||||
import shutil
|
||||
|
||||
from ray.tune.logger import tf2_compat_logger, JsonLogger, CSVLogger
|
||||
from ray.tune.logger import tf2_compat_logger, JsonLogger, CSVLogger, TBXLogger
|
||||
|
||||
Trial = namedtuple("MockTrial", ["evaluated_params", "trial_id"])
|
||||
|
||||
@@ -54,6 +54,14 @@ class LoggerSuite(unittest.TestCase):
|
||||
logger.on_result(result(2, 4))
|
||||
logger.close()
|
||||
|
||||
def testTBX(self):
|
||||
config = {"a": 2, "b": 5}
|
||||
t = Trial(evaluated_params=config, trial_id="tbx")
|
||||
logger = TBXLogger(config=config, logdir=self.test_dir, trial=t)
|
||||
logger.on_result(result(2, 4))
|
||||
logger.on_result(result(2, 4))
|
||||
logger.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user