[tune] tf2.0 testing and supporting callables (#5738)

This commit is contained in:
Richard Liaw
2019-09-21 17:01:14 -07:00
committed by GitHub
parent c91a37f622
commit e00071721a
6 changed files with 44 additions and 26 deletions
+9 -8
View File
@@ -4,13 +4,10 @@ from __future__ import print_function
import argparse
import numpy as np
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import (Dense, Dropout, Flatten, Conv2D, MaxPooling2D)
from tensorflow.keras.datasets import mnist
from ray.tune.integration.keras import TuneReporterCallback
from ray.tune.examples.utils import get_mnist_data, set_keras_threads
from ray.tune.examples.utils import get_mnist_data
parser = argparse.ArgumentParser()
parser.add_argument(
@@ -19,7 +16,11 @@ args, _ = parser.parse_known_args()
def train_mnist(config, reporter):
set_keras_threads(config["threads"])
# https://github.com/tensorflow/tensorflow/issues/32159
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (Dense, Dropout, Flatten, Conv2D,
MaxPooling2D)
batch_size = 128
num_classes = 10
epochs = 12
@@ -40,8 +41,8 @@ def train_mnist(config, reporter):
model.add(Dense(num_classes, activation="softmax"))
model.compile(
loss=keras.losses.categorical_crossentropy,
optimizer=keras.optimizers.SGD(
loss=tf.keras.losses.categorical_crossentropy,
optimizer=tf.keras.optimizers.SGD(
lr=config["lr"], momentum=config["momentum"]),
metrics=["accuracy"])
+11 -13
View File
@@ -6,11 +6,11 @@ import copy
import logging
import os
import six
import types
from ray.tune.error import TuneError
from ray.tune.registry import register_trainable
from ray.tune.result import DEFAULT_RESULTS_DIR
from ray.tune.sample import sample_from
logger = logging.getLogger(__name__)
@@ -145,8 +145,7 @@ class Experiment(object):
def _register_if_needed(cls, run_object):
"""Registers Trainable or Function at runtime.
Assumes already registered if run_object is a string. Does not
register lambdas because they could be part of variant generation.
Assumes already registered if run_object is a string.
Also, does not inspect interface of given run_object.
Arguments:
@@ -160,17 +159,16 @@ class Experiment(object):
if isinstance(run_object, six.string_types):
return run_object
elif isinstance(run_object, types.FunctionType):
if run_object.__name__ == "<lambda>":
logger.warning(
"Not auto-registering lambdas - resolving as variant.")
return run_object
else:
elif isinstance(run_object, sample_from):
logger.warning("Not registering trainable. Resolving as variant.")
return run_object
elif isinstance(run_object, type) or callable(run_object):
name = "DEFAULT"
if hasattr(run_object, "__name__"):
name = run_object.__name__
register_trainable(name, run_object)
return name
elif isinstance(run_object, type):
name = run_object.__name__
else:
logger.warning(
"No name detected on trainable. Using {}.".format(name))
register_trainable(name, run_object)
return name
else:
+5 -1
View File
@@ -32,7 +32,11 @@ class TuneReporterCallback(keras.callbacks.Callback):
for metric in list(logs):
if "loss" in metric and "neg_" not in metric:
logs["neg_" + metric] = -logs[metric]
self.reporter(keras_info=logs, mean_accuracy=logs["acc"])
print(logs)
if "acc" in logs:
self.reporter(keras_info=logs, mean_accuracy=logs["acc"])
else:
self.reporter(keras_info=logs, mean_accuracy=logs.get("accuracy"))
def on_epoch_end(self, batch, logs={}):
if not self.freq == "epoch":
+3 -3
View File
@@ -155,7 +155,7 @@ def tf2_compat_logger(config, logdir, trial=None):
class TF2Logger(Logger):
"""TensorBoard Logger for TF version >= 1.14.
"""TensorBoard Logger for TF version >= 2.0.0.
Automatically flattens nested dicts to show on TensorBoard:
@@ -175,7 +175,7 @@ class TF2Logger(Logger):
from tensorboard.plugins.hparams import api as hp
self._context = context
self._file_writer = tf.summary.create_file_writer(self.logdir)
with tf.device("/CPU:0"), self._context.eager_mode():
with tf.device("/CPU:0"):
with tf.summary.record_if(True), self._file_writer.as_default():
step = result.get(
TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]
@@ -226,7 +226,7 @@ def to_tf_values(result, path):
class TFLogger(Logger):
"""TensorBoard Logger for TF version < 1.14.
"""TensorBoard Logger for TF version < 2.0.0.
Automatically flattens nested dicts to show on TensorBoard:
+8 -1
View File
@@ -214,11 +214,15 @@ class TrainableFunctionApiTest(unittest.TestCase):
pass
register_trainable("foo", train)
Experiment("test", train)
register_trainable("foo", B)
Experiment("test", B)
self.assertRaises(TypeError, lambda: register_trainable("foo", B()))
self.assertRaises(TuneError, lambda: Experiment("foo", B()))
self.assertRaises(TypeError, lambda: register_trainable("foo", A))
self.assertRaises(TypeError, lambda: Experiment("foo", A))
def testRegisterTrainableCallable(self):
def testTrainableCallable(self):
def dummy_fn(config, reporter, steps):
reporter(timesteps_total=steps, done=True)
@@ -232,6 +236,9 @@ class TrainableFunctionApiTest(unittest.TestCase):
})
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], steps)
[trial] = tune.run(partial(dummy_fn, steps=steps)).trials
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], steps)
def testBuiltInTrainableResources(self):
class B(Trainable):