diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py index bdec2a8bf..71b58f121 100644 --- a/python/ray/tune/logger.py +++ b/python/ray/tune/logger.py @@ -154,11 +154,13 @@ def tf2_compat_logger(config, logdir): class TF2Logger(Logger): def _init(self): - from tensorflow.python.eager import context - self._context = context - self._file_writer = tf.summary.create_file_writer(self.logdir) + self._file_writer = None def on_result(self, result): + if self._file_writer is None: + from tensorflow.python.eager import context + self._context = context + self._file_writer = tf.summary.create_file_writer(self.logdir) with tf.device("/CPU:0"), self._context.eager_mode(): with tf.summary.record_if(True), self._file_writer.as_default(): step = result.get( @@ -181,10 +183,12 @@ class TF2Logger(Logger): self._file_writer.flush() def flush(self): - self._file_writer.flush() + if self._file_writer is not None: + self._file_writer.flush() def close(self): - self._file_writer.close() + if self._file_writer is not None: + self._file_writer.close() def to_tf_values(result, path):