mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 11:10:25 +08:00
[tune] Test TF2.0, TF1.14, TF1.12 Tensorboard support (#5931)
This commit is contained in:
@@ -14,12 +14,16 @@ from tensorflow.keras.layers import LSTM
|
||||
from tensorflow.keras.optimizers import RMSprop
|
||||
from tensorflow.keras.utils import get_file
|
||||
from tensorflow.keras.preprocessing.sequence import pad_sequences
|
||||
from ray.tune import Trainable
|
||||
|
||||
from filelock import FileLock
|
||||
import os
|
||||
import argparse
|
||||
import tarfile
|
||||
import numpy as np
|
||||
import re
|
||||
|
||||
from ray.tune import Trainable
|
||||
|
||||
|
||||
def tokenize(sent):
|
||||
"""Return the tokens of a sentence including punctuation.
|
||||
@@ -211,7 +215,8 @@ class MemNNModel(Trainable):
|
||||
return model
|
||||
|
||||
def _setup(self, config):
|
||||
self.train_stories, self.test_stories = read_data()
|
||||
with FileLock(os.path.expanduser("~/.tune.lock")):
|
||||
self.train_stories, self.test_stories = read_data()
|
||||
model = self.build_model()
|
||||
rmsprop = RMSprop(
|
||||
lr=self.config.get("lr", 1e-3), rho=self.config.get("rho", 0.9))
|
||||
|
||||
@@ -146,7 +146,7 @@ def tf2_compat_logger(config, logdir, trial=None):
|
||||
else:
|
||||
import tensorflow as tf
|
||||
use_tf2_api = (distutils.version.LooseVersion(tf.__version__) >=
|
||||
distutils.version.LooseVersion("2.0.0"))
|
||||
distutils.version.LooseVersion("1.15.0"))
|
||||
if use_tf2_api:
|
||||
tf = tf.compat.v2 # setting this for TF2.0
|
||||
return TF2Logger(config, logdir, trial)
|
||||
@@ -238,7 +238,7 @@ class TFLogger(Logger):
|
||||
|
||||
def _init(self):
|
||||
logger.debug("Initializing TFLogger instead of TF2Logger.")
|
||||
self._file_writer = tf.compat.v1.summary.FileWriter(self.logdir)
|
||||
self._file_writer = tf.summary.FileWriter(self.logdir)
|
||||
|
||||
def on_result(self, result):
|
||||
tmp = result.copy()
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from collections import namedtuple
|
||||
import unittest
|
||||
import tempfile
|
||||
import shutil
|
||||
|
||||
from ray.tune.logger import tf2_compat_logger, JsonLogger, CSVLogger
|
||||
|
||||
Trial = namedtuple("MockTrial", ["evaluated_params", "trial_id"])
|
||||
|
||||
|
||||
def result(t, rew):
|
||||
return dict(
|
||||
time_total_s=t,
|
||||
episode_reward_mean=rew,
|
||||
mean_accuracy=rew * 2,
|
||||
training_iteration=int(t))
|
||||
|
||||
|
||||
class LoggerSuite(unittest.TestCase):
|
||||
"""Test built-in loggers."""
|
||||
|
||||
def setUp(self):
|
||||
self.test_dir = tempfile.mkdtemp()
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.test_dir, ignore_errors=True)
|
||||
|
||||
def testTensorBoardLogger(self):
|
||||
config = {"a": 2, "b": 5}
|
||||
t = Trial(evaluated_params=config, trial_id=5342)
|
||||
logger = tf2_compat_logger(
|
||||
config=config, logdir=self.test_dir, trial=t)
|
||||
logger.on_result(result(2, 4))
|
||||
logger.on_result(result(2, 4))
|
||||
logger.close()
|
||||
|
||||
def testCSV(self):
|
||||
config = {"a": 2, "b": 5}
|
||||
t = Trial(evaluated_params=config, trial_id="csv")
|
||||
logger = CSVLogger(config=config, logdir=self.test_dir, trial=t)
|
||||
logger.on_result(result(2, 4))
|
||||
logger.on_result(result(2, 4))
|
||||
logger.close()
|
||||
|
||||
def testJSON(self):
|
||||
config = {"a": 2, "b": 5}
|
||||
t = Trial(evaluated_params=config, trial_id="json")
|
||||
logger = JsonLogger(config=config, logdir=self.test_dir, trial=t)
|
||||
logger.on_result(result(2, 4))
|
||||
logger.on_result(result(2, 4))
|
||||
logger.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
Reference in New Issue
Block a user