mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 10:33:16 +08:00
[tune] Remove unused TF loggers (#7090)
This commit is contained in:
@@ -3,7 +3,6 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import yaml
|
||||
import distutils.version
|
||||
import numbers
|
||||
|
||||
import numpy as np
|
||||
@@ -132,155 +131,6 @@ class JsonLogger(Logger):
|
||||
cloudpickle.dump(self.config, f)
|
||||
|
||||
|
||||
def tf2_compat_logger(config, logdir, trial=None):
|
||||
"""Chooses TensorBoard logger depending on imported TF version."""
|
||||
global tf
|
||||
if "RLLIB_TEST_NO_TF_IMPORT" in os.environ:
|
||||
logger.warning("Not importing TensorFlow for test purposes")
|
||||
tf = None
|
||||
raise RuntimeError("Not importing TensorFlow for test purposes")
|
||||
else:
|
||||
import tensorflow as tf
|
||||
use_tf2_api = (distutils.version.LooseVersion(tf.__version__) >=
|
||||
distutils.version.LooseVersion("1.15.0"))
|
||||
if use_tf2_api:
|
||||
# This is temporarily for RLlib because it disables v2 behavior...
|
||||
from tensorflow.python import tf2
|
||||
if not tf2.enabled():
|
||||
tf = tf.compat.v1
|
||||
return TFLogger(config, logdir, trial)
|
||||
tf = tf.compat.v2 # setting this for TF2.0
|
||||
return TF2Logger(config, logdir, trial)
|
||||
else:
|
||||
return TFLogger(config, logdir, trial)
|
||||
|
||||
|
||||
class TF2Logger(Logger):
|
||||
"""TensorBoard Logger for TF version >= 2.0.0.
|
||||
|
||||
Automatically flattens nested dicts to show on TensorBoard:
|
||||
|
||||
{"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2}
|
||||
|
||||
If you need to do more advanced logging, it is recommended
|
||||
to use a Summary Writer in the Trainable yourself.
|
||||
"""
|
||||
|
||||
def _init(self):
|
||||
global tf
|
||||
if tf is None:
|
||||
import tensorflow as tf
|
||||
tf = tf.compat.v2 # setting this for TF2.0
|
||||
self._file_writer = None
|
||||
self._hp_logged = False
|
||||
|
||||
def on_result(self, result):
|
||||
if self._file_writer is None:
|
||||
from tensorflow.python.eager import context
|
||||
from tensorboard.plugins.hparams import api as hp
|
||||
self._context = context
|
||||
self._file_writer = tf.summary.create_file_writer(self.logdir)
|
||||
with tf.device("/CPU:0"):
|
||||
with tf.summary.record_if(True), self._file_writer.as_default():
|
||||
step = result.get(
|
||||
TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]
|
||||
|
||||
tmp = result.copy()
|
||||
if not self._hp_logged:
|
||||
if self.trial and self.trial.evaluated_params:
|
||||
try:
|
||||
hp.hparams(
|
||||
self.trial.evaluated_params,
|
||||
trial_id=self.trial.trial_id)
|
||||
except Exception as exc:
|
||||
logger.error("HParams failed with %s", exc)
|
||||
self._hp_logged = True
|
||||
|
||||
for k in [
|
||||
"config", "pid", "timestamp", TIME_TOTAL_S,
|
||||
TRAINING_ITERATION
|
||||
]:
|
||||
if k in tmp:
|
||||
del tmp[k] # not useful to log these
|
||||
|
||||
flat_result = flatten_dict(tmp, delimiter="/")
|
||||
path = ["ray", "tune"]
|
||||
for attr, value in flat_result.items():
|
||||
if type(value) in VALID_SUMMARY_TYPES:
|
||||
tf.summary.scalar(
|
||||
"/".join(path + [attr]), value, step=step)
|
||||
self._file_writer.flush()
|
||||
|
||||
def flush(self):
|
||||
if self._file_writer is not None:
|
||||
self._file_writer.flush()
|
||||
|
||||
def close(self):
|
||||
if self._file_writer is not None:
|
||||
self._file_writer.close()
|
||||
|
||||
|
||||
def to_tf_values(result, path):
|
||||
from tensorboardX.summary import make_histogram
|
||||
flat_result = flatten_dict(result, delimiter="/")
|
||||
values = []
|
||||
for attr, value in flat_result.items():
|
||||
if type(value) in VALID_SUMMARY_TYPES:
|
||||
values.append(
|
||||
tf.Summary.Value(
|
||||
tag="/".join(path + [attr]), simple_value=value))
|
||||
elif type(value) is list and len(value) > 0:
|
||||
values.append(
|
||||
tf.Summary.Value(
|
||||
tag="/".join(path + [attr]),
|
||||
histo=make_histogram(values=np.array(value), bins=10)))
|
||||
return values
|
||||
|
||||
|
||||
class TFLogger(Logger):
|
||||
"""TensorBoard Logger for TF version < 2.0.0.
|
||||
|
||||
Automatically flattens nested dicts to show on TensorBoard:
|
||||
|
||||
{"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2}
|
||||
|
||||
If you need to do more advanced logging, it is recommended
|
||||
to use a Summary Writer in the Trainable yourself.
|
||||
"""
|
||||
|
||||
def _init(self):
|
||||
global tf
|
||||
if tf is None:
|
||||
import tensorflow as tf
|
||||
tf = tf.compat.v1 # setting this for regular TF logger
|
||||
logger.debug("Initializing TFLogger instead of TF2Logger.")
|
||||
self._file_writer = tf.summary.FileWriter(self.logdir)
|
||||
|
||||
def on_result(self, result):
|
||||
tmp = result.copy()
|
||||
for k in [
|
||||
"config", "pid", "timestamp", TIME_TOTAL_S, TRAINING_ITERATION
|
||||
]:
|
||||
if k in tmp:
|
||||
del tmp[k] # not useful to tf log these
|
||||
values = to_tf_values(tmp, ["ray", "tune"])
|
||||
train_stats = tf.Summary(value=values)
|
||||
t = result.get(TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]
|
||||
self._file_writer.add_summary(train_stats, t)
|
||||
iteration_value = to_tf_values({
|
||||
TRAINING_ITERATION: result[TRAINING_ITERATION]
|
||||
}, ["ray", "tune"])
|
||||
iteration_stats = tf.Summary(value=iteration_value)
|
||||
self._file_writer.add_summary(iteration_stats, t)
|
||||
self._file_writer.flush()
|
||||
|
||||
def flush(self):
|
||||
self._file_writer.flush()
|
||||
|
||||
def close(self):
|
||||
self._file_writer.close()
|
||||
|
||||
|
||||
class CSVLogger(Logger):
|
||||
"""Logs results to progress.csv under the trial directory.
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import unittest
|
||||
import tempfile
|
||||
import shutil
|
||||
|
||||
from ray.tune.logger import tf2_compat_logger, JsonLogger, CSVLogger, TBXLogger
|
||||
from ray.tune.logger import JsonLogger, CSVLogger, TBXLogger
|
||||
|
||||
Trial = namedtuple("MockTrial", ["evaluated_params", "trial_id"])
|
||||
|
||||
@@ -25,15 +25,6 @@ class LoggerSuite(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.test_dir, ignore_errors=True)
|
||||
|
||||
def testTensorBoardLogger(self):
|
||||
config = {"a": 2, "b": 5}
|
||||
t = Trial(evaluated_params=config, trial_id=5342)
|
||||
logger = tf2_compat_logger(
|
||||
config=config, logdir=self.test_dir, trial=t)
|
||||
logger.on_result(result(2, 4))
|
||||
logger.on_result(result(2, 4))
|
||||
logger.close()
|
||||
|
||||
def testCSV(self):
|
||||
config = {"a": 2, "b": 5}
|
||||
t = Trial(evaluated_params=config, trial_id="csv")
|
||||
|
||||
Reference in New Issue
Block a user