[tune] Default to TensorboardX and include in requirements. (#6836)

This commit is contained in:
Richard Liaw
2020-01-19 01:49:33 -08:00
committed by GitHub
parent a229bdf272
commit 341ddd0a09
4 changed files with 32 additions and 23 deletions
+18 -10
View File
@@ -315,7 +315,8 @@ class CSVLogger(Logger):
class TBXLogger(Logger):
"""TensorBoardX Logger.
Automatically flattens nested dicts to show on TensorBoard:
Note that hparams will be written only after a trial has terminated.
This logger automatically flattens nested dicts to show on TensorBoard:
{"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2}
"""
@@ -324,7 +325,7 @@ class TBXLogger(Logger):
try:
from tensorboardX import SummaryWriter
except ImportError:
logger.error("pip install tensorboardX to see TensorBoard files.")
logger.error("pip install 'ray[tune]' to see TensorBoard files.")
raise
self._file_writer = SummaryWriter(self.logdir, flush_secs=30)
self.last_result = None
@@ -359,17 +360,24 @@ class TBXLogger(Logger):
def close(self):
if self._file_writer is not None:
if self.trial and self.trial.evaluated_params and self.last_result:
from tensorboardX.summary import hparams
experiment_tag, session_start_tag, session_end_tag = hparams(
hparam_dict=self.trial.evaluated_params,
metric_dict=self.last_result)
self._file_writer.file_writer.add_summary(experiment_tag)
self._file_writer.file_writer.add_summary(session_start_tag)
self._file_writer.file_writer.add_summary(session_end_tag)
self._try_log_hparams(self.last_result)
self._file_writer.close()
def _try_log_hparams(self, result):
# TBX currently errors if the hparams value is None.
scrubbed_params = {
k: v
for k, v in self.trial.evaluated_params.items() if v is not None
}
from tensorboardX.summary import hparams
experiment_tag, session_start_tag, session_end_tag = hparams(
hparam_dict=scrubbed_params, metric_dict=result)
self._file_writer.file_writer.add_summary(experiment_tag)
self._file_writer.file_writer.add_summary(session_start_tag)
self._file_writer.file_writer.add_summary(session_end_tag)
DEFAULT_LOGGERS = (JsonLogger, CSVLogger, tf2_compat_logger)
DEFAULT_LOGGERS = (JsonLogger, CSVLogger, TBXLogger)
class UnifiedLogger(Logger):
+9 -5
View File
@@ -72,16 +72,20 @@ if "RAY_USE_NEW_GCS" in os.environ and os.environ["RAY_USE_NEW_GCS"] == "on":
]
extras = {
"rllib": [
"pyyaml", "gym[atari]", "opencv-python-headless", "lz4", "scipy",
"tabulate"
],
"debug": ["psutil", "setproctitle", "py-spy >= 0.2.0"],
"dashboard": ["aiohttp", "google", "grpcio", "psutil", "setproctitle"],
"serve": ["uvicorn", "pygments", "werkzeug", "flask", "pandas", "blist"],
"tune": ["tabulate"],
"tune": ["tabulate", "tensorboardX"],
}
extras["rllib"] = extras["tune"] + [
"pyyaml",
"gym[atari]",
"opencv-python-headless",
"lz4",
"scipy",
]
extras["all"] = list(set(chain.from_iterable(extras.values())))