[tune] TF2.0 TensorBoard support (#5547)

* Fix tensorboard log issue with tensorflow2.0

* tf2 support
This commit is contained in:
idthanm
2019-08-28 01:53:27 +08:00
committed by Eric Liang
parent d20696300e
commit 52a6a1b9f7
+60 -23
View File
@@ -23,7 +23,7 @@ from ray.tune.result import (NODE_IP, TRAINING_ITERATION, TIME_TOTAL_S,
logger = logging.getLogger(__name__)
tf = None
use_tf150_api = True
VALID_SUMMARY_TYPES = [int, float, np.float32, np.float64, np.int32]
class Logger(object):
@@ -135,34 +135,71 @@ class JsonLogger(Logger):
cloudpickle.dump(self.config, f)
def to_tf_values(result, path):
if use_tf150_api:
type_list = [int, float, np.float32, np.float64, np.int32]
def tf2_compat_logger(config, logdir):
global tf
if "RLLIB_TEST_NO_TF_IMPORT" in os.environ:
logger.warning("Not importing TensorFlow for test purposes")
tf = None
raise RuntimeError("Not importing TensorFlow for test purposes")
else:
type_list = [int, float]
import tensorflow as tf
use_tf2_api = (distutils.version.LooseVersion(tf.__version__) >=
distutils.version.LooseVersion("1.14.0"))
if use_tf2_api:
tf = tf.compat.v2 # setting this for 1.14
return TF2Logger(config, logdir)
else:
return TFLogger(config, logdir)
class TF2Logger(Logger):
def _init(self):
self._file_writer = tf.summary.create_file_writer(self.logdir)
def on_result(self, result):
with tf.device("/CPU:0"):
with self._file_writer.as_default():
step = result.get(
TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]
tmp = result.copy()
for k in [
"config", "pid", "timestamp", TIME_TOTAL_S,
TRAINING_ITERATION
]:
if k in tmp:
del tmp[k] # not useful to log these
flat_result = flatten_dict(tmp, delimiter="/")
path = ["ray", "tune"]
for attr, value in flat_result.items():
if type(value) in VALID_SUMMARY_TYPES:
tf.summary.scalar(
"/".join(path + [attr]), value, step=step)
self._file_writer.flush()
def flush(self):
self._file_writer.flush()
def close(self):
self._file_writer.close()
def to_tf_values(result, path):
flat_result = flatten_dict(result, delimiter="/")
values = [
tf.Summary.Value(tag="/".join(path + [attr]), simple_value=value)
for attr, value in flat_result.items() if type(value) in type_list
for attr, value in flat_result.items()
if type(value) in VALID_SUMMARY_TYPES
]
return values
class TFLogger(Logger):
def _init(self):
try:
global tf, use_tf150_api
if "RLLIB_TEST_NO_TF_IMPORT" in os.environ:
logger.warning("Not importing TensorFlow for test purposes")
tf = None
else:
import tensorflow
tf = tensorflow
use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >=
distutils.version.LooseVersion("1.5.0"))
except ImportError:
logger.warning("Couldn't import TensorFlow - "
"disabling TensorBoard logging.")
logger.info(
"Initializing TFLogger instead of TF2Logger. We recommend "
"migrating to TF2.0. This class will be removed in the future.")
self._file_writer = tf.summary.FileWriter(self.logdir)
def on_result(self, result):
@@ -220,7 +257,7 @@ class CSVLogger(Logger):
self._file.close()
DEFAULT_LOGGERS = (JsonLogger, CSVLogger, TFLogger)
DEFAULT_LOGGERS = (JsonLogger, CSVLogger, tf2_compat_logger)
class UnifiedLogger(Logger):
@@ -250,9 +287,9 @@ class UnifiedLogger(Logger):
for cls in self._logger_cls_list:
try:
self._loggers.append(cls(self.config, self.logdir))
except Exception:
logger.warning("Could not instantiate {} - skipping.".format(
str(cls)))
except Exception as exc:
logger.warning("Could not instantiate {}: {}.".format(
cls.__name__, str(exc)))
self._log_syncer = get_log_syncer(
self.logdir,
remote_dir=self.logdir,