mirror of
https://github.com/wassname/ray.git
synced 2026-07-04 17:56:52 +08:00
[tune] Default to TensorboardX and include in requirements. (#6836)
This commit is contained in:
+18
-10
@@ -315,7 +315,8 @@ class CSVLogger(Logger):
|
||||
class TBXLogger(Logger):
|
||||
"""TensorBoardX Logger.
|
||||
|
||||
Automatically flattens nested dicts to show on TensorBoard:
|
||||
Note that hparams will be written only after a trial has terminated.
|
||||
This logger automatically flattens nested dicts to show on TensorBoard:
|
||||
|
||||
{"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2}
|
||||
"""
|
||||
@@ -324,7 +325,7 @@ class TBXLogger(Logger):
|
||||
try:
|
||||
from tensorboardX import SummaryWriter
|
||||
except ImportError:
|
||||
logger.error("pip install tensorboardX to see TensorBoard files.")
|
||||
logger.error("pip install 'ray[tune]' to see TensorBoard files.")
|
||||
raise
|
||||
self._file_writer = SummaryWriter(self.logdir, flush_secs=30)
|
||||
self.last_result = None
|
||||
@@ -359,17 +360,24 @@ class TBXLogger(Logger):
|
||||
def close(self):
|
||||
if self._file_writer is not None:
|
||||
if self.trial and self.trial.evaluated_params and self.last_result:
|
||||
from tensorboardX.summary import hparams
|
||||
experiment_tag, session_start_tag, session_end_tag = hparams(
|
||||
hparam_dict=self.trial.evaluated_params,
|
||||
metric_dict=self.last_result)
|
||||
self._file_writer.file_writer.add_summary(experiment_tag)
|
||||
self._file_writer.file_writer.add_summary(session_start_tag)
|
||||
self._file_writer.file_writer.add_summary(session_end_tag)
|
||||
self._try_log_hparams(self.last_result)
|
||||
self._file_writer.close()
|
||||
|
||||
def _try_log_hparams(self, result):
|
||||
# TBX currently errors if the hparams value is None.
|
||||
scrubbed_params = {
|
||||
k: v
|
||||
for k, v in self.trial.evaluated_params.items() if v is not None
|
||||
}
|
||||
from tensorboardX.summary import hparams
|
||||
experiment_tag, session_start_tag, session_end_tag = hparams(
|
||||
hparam_dict=scrubbed_params, metric_dict=result)
|
||||
self._file_writer.file_writer.add_summary(experiment_tag)
|
||||
self._file_writer.file_writer.add_summary(session_start_tag)
|
||||
self._file_writer.file_writer.add_summary(session_end_tag)
|
||||
|
||||
DEFAULT_LOGGERS = (JsonLogger, CSVLogger, tf2_compat_logger)
|
||||
|
||||
DEFAULT_LOGGERS = (JsonLogger, CSVLogger, TBXLogger)
|
||||
|
||||
|
||||
class UnifiedLogger(Logger):
|
||||
|
||||
+9
-5
@@ -72,16 +72,20 @@ if "RAY_USE_NEW_GCS" in os.environ and os.environ["RAY_USE_NEW_GCS"] == "on":
|
||||
]
|
||||
|
||||
extras = {
|
||||
"rllib": [
|
||||
"pyyaml", "gym[atari]", "opencv-python-headless", "lz4", "scipy",
|
||||
"tabulate"
|
||||
],
|
||||
"debug": ["psutil", "setproctitle", "py-spy >= 0.2.0"],
|
||||
"dashboard": ["aiohttp", "google", "grpcio", "psutil", "setproctitle"],
|
||||
"serve": ["uvicorn", "pygments", "werkzeug", "flask", "pandas", "blist"],
|
||||
"tune": ["tabulate"],
|
||||
"tune": ["tabulate", "tensorboardX"],
|
||||
}
|
||||
|
||||
extras["rllib"] = extras["tune"] + [
|
||||
"pyyaml",
|
||||
"gym[atari]",
|
||||
"opencv-python-headless",
|
||||
"lz4",
|
||||
"scipy",
|
||||
]
|
||||
|
||||
extras["all"] = list(set(chain.from_iterable(extras.values())))
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user