mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 09:57:46 +08:00
[tune] tf2.0 testing and supporting callables (#5738)
This commit is contained in:
@@ -4,13 +4,10 @@ from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
import keras
|
||||
from keras.datasets import mnist
|
||||
from keras.models import Sequential
|
||||
from keras.layers import (Dense, Dropout, Flatten, Conv2D, MaxPooling2D)
|
||||
from tensorflow.keras.datasets import mnist
|
||||
|
||||
from ray.tune.integration.keras import TuneReporterCallback
|
||||
from ray.tune.examples.utils import get_mnist_data, set_keras_threads
|
||||
from ray.tune.examples.utils import get_mnist_data
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
@@ -19,7 +16,11 @@ args, _ = parser.parse_known_args()
|
||||
|
||||
|
||||
def train_mnist(config, reporter):
|
||||
set_keras_threads(config["threads"])
|
||||
# https://github.com/tensorflow/tensorflow/issues/32159
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras.models import Sequential
|
||||
from tensorflow.keras.layers import (Dense, Dropout, Flatten, Conv2D,
|
||||
MaxPooling2D)
|
||||
batch_size = 128
|
||||
num_classes = 10
|
||||
epochs = 12
|
||||
@@ -40,8 +41,8 @@ def train_mnist(config, reporter):
|
||||
model.add(Dense(num_classes, activation="softmax"))
|
||||
|
||||
model.compile(
|
||||
loss=keras.losses.categorical_crossentropy,
|
||||
optimizer=keras.optimizers.SGD(
|
||||
loss=tf.keras.losses.categorical_crossentropy,
|
||||
optimizer=tf.keras.optimizers.SGD(
|
||||
lr=config["lr"], momentum=config["momentum"]),
|
||||
metrics=["accuracy"])
|
||||
|
||||
|
||||
@@ -6,11 +6,11 @@ import copy
|
||||
import logging
|
||||
import os
|
||||
import six
|
||||
import types
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.registry import register_trainable
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
from ray.tune.sample import sample_from
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -145,8 +145,7 @@ class Experiment(object):
|
||||
def _register_if_needed(cls, run_object):
|
||||
"""Registers Trainable or Function at runtime.
|
||||
|
||||
Assumes already registered if run_object is a string. Does not
|
||||
register lambdas because they could be part of variant generation.
|
||||
Assumes already registered if run_object is a string.
|
||||
Also, does not inspect interface of given run_object.
|
||||
|
||||
Arguments:
|
||||
@@ -160,17 +159,16 @@ class Experiment(object):
|
||||
|
||||
if isinstance(run_object, six.string_types):
|
||||
return run_object
|
||||
elif isinstance(run_object, types.FunctionType):
|
||||
if run_object.__name__ == "<lambda>":
|
||||
logger.warning(
|
||||
"Not auto-registering lambdas - resolving as variant.")
|
||||
return run_object
|
||||
else:
|
||||
elif isinstance(run_object, sample_from):
|
||||
logger.warning("Not registering trainable. Resolving as variant.")
|
||||
return run_object
|
||||
elif isinstance(run_object, type) or callable(run_object):
|
||||
name = "DEFAULT"
|
||||
if hasattr(run_object, "__name__"):
|
||||
name = run_object.__name__
|
||||
register_trainable(name, run_object)
|
||||
return name
|
||||
elif isinstance(run_object, type):
|
||||
name = run_object.__name__
|
||||
else:
|
||||
logger.warning(
|
||||
"No name detected on trainable. Using {}.".format(name))
|
||||
register_trainable(name, run_object)
|
||||
return name
|
||||
else:
|
||||
|
||||
@@ -32,7 +32,11 @@ class TuneReporterCallback(keras.callbacks.Callback):
|
||||
for metric in list(logs):
|
||||
if "loss" in metric and "neg_" not in metric:
|
||||
logs["neg_" + metric] = -logs[metric]
|
||||
self.reporter(keras_info=logs, mean_accuracy=logs["acc"])
|
||||
print(logs)
|
||||
if "acc" in logs:
|
||||
self.reporter(keras_info=logs, mean_accuracy=logs["acc"])
|
||||
else:
|
||||
self.reporter(keras_info=logs, mean_accuracy=logs.get("accuracy"))
|
||||
|
||||
def on_epoch_end(self, batch, logs={}):
|
||||
if not self.freq == "epoch":
|
||||
|
||||
@@ -155,7 +155,7 @@ def tf2_compat_logger(config, logdir, trial=None):
|
||||
|
||||
|
||||
class TF2Logger(Logger):
|
||||
"""TensorBoard Logger for TF version >= 1.14.
|
||||
"""TensorBoard Logger for TF version >= 2.0.0.
|
||||
|
||||
Automatically flattens nested dicts to show on TensorBoard:
|
||||
|
||||
@@ -175,7 +175,7 @@ class TF2Logger(Logger):
|
||||
from tensorboard.plugins.hparams import api as hp
|
||||
self._context = context
|
||||
self._file_writer = tf.summary.create_file_writer(self.logdir)
|
||||
with tf.device("/CPU:0"), self._context.eager_mode():
|
||||
with tf.device("/CPU:0"):
|
||||
with tf.summary.record_if(True), self._file_writer.as_default():
|
||||
step = result.get(
|
||||
TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]
|
||||
@@ -226,7 +226,7 @@ def to_tf_values(result, path):
|
||||
|
||||
|
||||
class TFLogger(Logger):
|
||||
"""TensorBoard Logger for TF version < 1.14.
|
||||
"""TensorBoard Logger for TF version < 2.0.0.
|
||||
|
||||
Automatically flattens nested dicts to show on TensorBoard:
|
||||
|
||||
|
||||
@@ -214,11 +214,15 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
pass
|
||||
|
||||
register_trainable("foo", train)
|
||||
Experiment("test", train)
|
||||
register_trainable("foo", B)
|
||||
Experiment("test", B)
|
||||
self.assertRaises(TypeError, lambda: register_trainable("foo", B()))
|
||||
self.assertRaises(TuneError, lambda: Experiment("foo", B()))
|
||||
self.assertRaises(TypeError, lambda: register_trainable("foo", A))
|
||||
self.assertRaises(TypeError, lambda: Experiment("foo", A))
|
||||
|
||||
def testRegisterTrainableCallable(self):
|
||||
def testTrainableCallable(self):
|
||||
def dummy_fn(config, reporter, steps):
|
||||
reporter(timesteps_total=steps, done=True)
|
||||
|
||||
@@ -232,6 +236,9 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
})
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], steps)
|
||||
[trial] = tune.run(partial(dummy_fn, steps=steps)).trials
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], steps)
|
||||
|
||||
def testBuiltInTrainableResources(self):
|
||||
class B(Trainable):
|
||||
|
||||
Reference in New Issue
Block a user