[tune] Test TF2.0, TF1.14, TF1.12 Tensorboard support (#5931)

This commit is contained in:
Richard Liaw
2019-10-18 13:50:42 -07:00
committed by GitHub
parent 697f765efc
commit 48ba484640
7 changed files with 83 additions and 13 deletions
@@ -14,12 +14,16 @@ from tensorflow.keras.layers import LSTM
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.utils import get_file
from tensorflow.keras.preprocessing.sequence import pad_sequences
from ray.tune import Trainable
from filelock import FileLock
import os
import argparse
import tarfile
import numpy as np
import re
from ray.tune import Trainable
def tokenize(sent):
"""Return the tokens of a sentence including punctuation.
@@ -211,7 +215,8 @@ class MemNNModel(Trainable):
return model
def _setup(self, config):
self.train_stories, self.test_stories = read_data()
with FileLock(os.path.expanduser("~/.tune.lock")):
self.train_stories, self.test_stories = read_data()
model = self.build_model()
rmsprop = RMSprop(
lr=self.config.get("lr", 1e-3), rho=self.config.get("rho", 0.9))
+2 -2
View File
@@ -146,7 +146,7 @@ def tf2_compat_logger(config, logdir, trial=None):
else:
import tensorflow as tf
use_tf2_api = (distutils.version.LooseVersion(tf.__version__) >=
distutils.version.LooseVersion("2.0.0"))
distutils.version.LooseVersion("1.15.0"))
if use_tf2_api:
tf = tf.compat.v2 # setting this for TF2.0
return TF2Logger(config, logdir, trial)
@@ -238,7 +238,7 @@ class TFLogger(Logger):
def _init(self):
logger.debug("Initializing TFLogger instead of TF2Logger.")
self._file_writer = tf.compat.v1.summary.FileWriter(self.logdir)
self._file_writer = tf.summary.FileWriter(self.logdir)
def on_result(self, result):
tmp = result.copy()
+59
View File
@@ -0,0 +1,59 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import namedtuple
import unittest
import tempfile
import shutil
from ray.tune.logger import tf2_compat_logger, JsonLogger, CSVLogger
Trial = namedtuple("MockTrial", ["evaluated_params", "trial_id"])
def result(t, rew):
return dict(
time_total_s=t,
episode_reward_mean=rew,
mean_accuracy=rew * 2,
training_iteration=int(t))
class LoggerSuite(unittest.TestCase):
"""Test built-in loggers."""
def setUp(self):
self.test_dir = tempfile.mkdtemp()
def tearDown(self):
shutil.rmtree(self.test_dir, ignore_errors=True)
def testTensorBoardLogger(self):
config = {"a": 2, "b": 5}
t = Trial(evaluated_params=config, trial_id=5342)
logger = tf2_compat_logger(
config=config, logdir=self.test_dir, trial=t)
logger.on_result(result(2, 4))
logger.on_result(result(2, 4))
logger.close()
def testCSV(self):
config = {"a": 2, "b": 5}
t = Trial(evaluated_params=config, trial_id="csv")
logger = CSVLogger(config=config, logdir=self.test_dir, trial=t)
logger.on_result(result(2, 4))
logger.on_result(result(2, 4))
logger.close()
def testJSON(self):
config = {"a": 2, "b": 5}
t = Trial(evaluated_params=config, trial_id="json")
logger = JsonLogger(config=config, logdir=self.test_dir, trial=t)
logger.on_result(result(2, 4))
logger.on_result(result(2, 4))
logger.close()
if __name__ == "__main__":
unittest.main(verbosity=2)