mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 12:41:43 +08:00
[tune] TF2.0 TensorBoard support (#5547)
* Fix tensorboard log issue with tensorflow2.0 * tf2 support
This commit is contained in:
+60
-23
@@ -23,7 +23,7 @@ from ray.tune.result import (NODE_IP, TRAINING_ITERATION, TIME_TOTAL_S,
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
tf = None
|
||||
use_tf150_api = True
|
||||
VALID_SUMMARY_TYPES = [int, float, np.float32, np.float64, np.int32]
|
||||
|
||||
|
||||
class Logger(object):
|
||||
@@ -135,34 +135,71 @@ class JsonLogger(Logger):
|
||||
cloudpickle.dump(self.config, f)
|
||||
|
||||
|
||||
def to_tf_values(result, path):
|
||||
if use_tf150_api:
|
||||
type_list = [int, float, np.float32, np.float64, np.int32]
|
||||
def tf2_compat_logger(config, logdir):
|
||||
global tf
|
||||
if "RLLIB_TEST_NO_TF_IMPORT" in os.environ:
|
||||
logger.warning("Not importing TensorFlow for test purposes")
|
||||
tf = None
|
||||
raise RuntimeError("Not importing TensorFlow for test purposes")
|
||||
else:
|
||||
type_list = [int, float]
|
||||
import tensorflow as tf
|
||||
use_tf2_api = (distutils.version.LooseVersion(tf.__version__) >=
|
||||
distutils.version.LooseVersion("1.14.0"))
|
||||
if use_tf2_api:
|
||||
tf = tf.compat.v2 # setting this for 1.14
|
||||
return TF2Logger(config, logdir)
|
||||
else:
|
||||
return TFLogger(config, logdir)
|
||||
|
||||
|
||||
class TF2Logger(Logger):
|
||||
def _init(self):
|
||||
self._file_writer = tf.summary.create_file_writer(self.logdir)
|
||||
|
||||
def on_result(self, result):
|
||||
with tf.device("/CPU:0"):
|
||||
with self._file_writer.as_default():
|
||||
step = result.get(
|
||||
TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]
|
||||
|
||||
tmp = result.copy()
|
||||
for k in [
|
||||
"config", "pid", "timestamp", TIME_TOTAL_S,
|
||||
TRAINING_ITERATION
|
||||
]:
|
||||
if k in tmp:
|
||||
del tmp[k] # not useful to log these
|
||||
|
||||
flat_result = flatten_dict(tmp, delimiter="/")
|
||||
path = ["ray", "tune"]
|
||||
for attr, value in flat_result.items():
|
||||
if type(value) in VALID_SUMMARY_TYPES:
|
||||
tf.summary.scalar(
|
||||
"/".join(path + [attr]), value, step=step)
|
||||
self._file_writer.flush()
|
||||
|
||||
def flush(self):
|
||||
self._file_writer.flush()
|
||||
|
||||
def close(self):
|
||||
self._file_writer.close()
|
||||
|
||||
|
||||
def to_tf_values(result, path):
|
||||
flat_result = flatten_dict(result, delimiter="/")
|
||||
values = [
|
||||
tf.Summary.Value(tag="/".join(path + [attr]), simple_value=value)
|
||||
for attr, value in flat_result.items() if type(value) in type_list
|
||||
for attr, value in flat_result.items()
|
||||
if type(value) in VALID_SUMMARY_TYPES
|
||||
]
|
||||
return values
|
||||
|
||||
|
||||
class TFLogger(Logger):
|
||||
def _init(self):
|
||||
try:
|
||||
global tf, use_tf150_api
|
||||
if "RLLIB_TEST_NO_TF_IMPORT" in os.environ:
|
||||
logger.warning("Not importing TensorFlow for test purposes")
|
||||
tf = None
|
||||
else:
|
||||
import tensorflow
|
||||
tf = tensorflow
|
||||
use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >=
|
||||
distutils.version.LooseVersion("1.5.0"))
|
||||
except ImportError:
|
||||
logger.warning("Couldn't import TensorFlow - "
|
||||
"disabling TensorBoard logging.")
|
||||
logger.info(
|
||||
"Initializing TFLogger instead of TF2Logger. We recommend "
|
||||
"migrating to TF2.0. This class will be removed in the future.")
|
||||
self._file_writer = tf.summary.FileWriter(self.logdir)
|
||||
|
||||
def on_result(self, result):
|
||||
@@ -220,7 +257,7 @@ class CSVLogger(Logger):
|
||||
self._file.close()
|
||||
|
||||
|
||||
DEFAULT_LOGGERS = (JsonLogger, CSVLogger, TFLogger)
|
||||
DEFAULT_LOGGERS = (JsonLogger, CSVLogger, tf2_compat_logger)
|
||||
|
||||
|
||||
class UnifiedLogger(Logger):
|
||||
@@ -250,9 +287,9 @@ class UnifiedLogger(Logger):
|
||||
for cls in self._logger_cls_list:
|
||||
try:
|
||||
self._loggers.append(cls(self.config, self.logdir))
|
||||
except Exception:
|
||||
logger.warning("Could not instantiate {} - skipping.".format(
|
||||
str(cls)))
|
||||
except Exception as exc:
|
||||
logger.warning("Could not instantiate {}: {}.".format(
|
||||
cls.__name__, str(exc)))
|
||||
self._log_syncer = get_log_syncer(
|
||||
self.logdir,
|
||||
remote_dir=self.logdir,
|
||||
|
||||
Reference in New Issue
Block a user