From f87235f232c5aaf728c9bb7476c7b6fa450cdc62 Mon Sep 17 00:00:00 2001 From: Andrew Tan Date: Thu, 2 May 2019 10:16:48 -0700 Subject: [PATCH] [tune] Example for Tune blog post (#4673) --- doc/source/tune-examples.rst | 2 +- python/ray/tune/examples/tune_mnist_keras.py | 200 ++++--------------- python/ray/tune/examples/utils.py | 61 ++++++ 3 files changed, 104 insertions(+), 159 deletions(-) create mode 100644 python/ray/tune/examples/utils.py diff --git a/doc/source/tune-examples.rst b/doc/source/tune-examples.rst index f11cd1b86..a15f5ec4e 100644 --- a/doc/source/tune-examples.rst +++ b/doc/source/tune-examples.rst @@ -30,7 +30,7 @@ Keras Examples -------------- - `tune_mnist_keras `__: - Converts the Keras MNIST example to use Tune with the function-based API and a Keras callback. Also shows how to easily convert something relying on argparse to use Tune. + Converts the Keras MNIST example to use Tune with the function-based API and a Keras callback. PyTorch Examples diff --git a/python/ray/tune/examples/tune_mnist_keras.py b/python/ray/tune/examples/tune_mnist_keras.py index 485dd818d..676f650cc 100644 --- a/python/ray/tune/examples/tune_mnist_keras.py +++ b/python/ray/tune/examples/tune_mnist_keras.py @@ -2,93 +2,47 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np import argparse +import numpy as np import keras from keras.datasets import mnist from keras.models import Sequential -from keras.layers import Dense, Dropout, Flatten -from keras.layers import Conv2D, MaxPooling2D -from keras import backend as K +from keras.layers import (Dense, Dropout, Flatten, Conv2D, MaxPooling2D) -import ray -from ray import tune -from ray.tune.schedulers import AsyncHyperBandScheduler +from ray.tune.examples.utils import (TuneKerasCallback, get_mnist_data, + set_keras_threads) + +parser = argparse.ArgumentParser() +parser.add_argument( + "--smoke-test", action="store_true", help="Finish quickly for testing") +args, _ = parser.parse_known_args() -class TuneCallback(keras.callbacks.Callback): - def __init__(self, reporter, logs={}): - self.reporter = reporter - self.iteration = 0 - - def on_train_end(self, epoch, logs={}): - self.reporter( - timesteps_total=self.iteration, done=1, mean_accuracy=logs["acc"]) - - def on_batch_end(self, batch, logs={}): - self.iteration += 1 - self.reporter( - timesteps_total=self.iteration, mean_accuracy=logs["acc"]) - - -def train_mnist(args, cfg, reporter): - # We set threads here to avoid contention, as Keras - # is heavily parallelized across multiple cores. - K.set_session( - K.tf.Session( - config=K.tf.ConfigProto( - intra_op_parallelism_threads=args.threads, - inter_op_parallelism_threads=args.threads))) - vars(args).update(cfg) +def train_mnist(config, reporter): + set_keras_threads(config["threads"]) batch_size = 128 num_classes = 10 epochs = 12 - # input image dimensions - img_rows, img_cols = 28, 28 - - # the data, split between train and test sets - (x_train, y_train), (x_test, y_test) = mnist.load_data() - - if K.image_data_format() == "channels_first": - x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols) - x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols) - input_shape = (1, img_rows, img_cols) - else: - x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1) - x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1) - input_shape = (img_rows, img_cols, 1) - - x_train = x_train.astype("float32") - x_test = x_test.astype("float32") - x_train /= 255 - x_test /= 255 - print("x_train shape:", x_train.shape) - print(x_train.shape[0], "train samples") - print(x_test.shape[0], "test samples") - - # convert class vectors to binary class matrices - y_train = keras.utils.to_categorical(y_train, num_classes) - y_test = keras.utils.to_categorical(y_test, num_classes) + x_train, y_train, x_test, y_test, input_shape = get_mnist_data() model = Sequential() model.add( Conv2D( - 32, - kernel_size=(args.kernel1, args.kernel1), - activation="relu", + 32, kernel_size=(3, 3), activation="relu", input_shape=input_shape)) - model.add(Conv2D(64, (args.kernel2, args.kernel2), activation="relu")) - model.add(MaxPooling2D(pool_size=(args.poolsize, args.poolsize))) - model.add(Dropout(args.dropout1)) + model.add(Conv2D(64, (3, 3), activation="relu")) + model.add(MaxPooling2D(pool_size=(2, 2))) + model.add(Dropout(0.5)) model.add(Flatten()) - model.add(Dense(args.hidden, activation="relu")) - model.add(Dropout(args.dropout2)) + model.add(Dense(config["hidden"], activation="relu")) + model.add(Dropout(0.5)) model.add(Dense(num_classes, activation="softmax")) model.compile( loss=keras.losses.categorical_crossentropy, - optimizer=keras.optimizers.SGD(lr=args.lr, momentum=args.momentum), + optimizer=keras.optimizers.SGD( + lr=config["lr"], momentum=config["momentum"]), metrics=["accuracy"]) model.fit( @@ -98,77 +52,14 @@ def train_mnist(args, cfg, reporter): epochs=epochs, verbose=0, validation_data=(x_test, y_test), - callbacks=[TuneCallback(reporter)]) - - -def create_parser(): - parser = argparse.ArgumentParser(description="Keras MNIST Example") - parser.add_argument( - "--smoke-test", action="store_true", help="Finish quickly for testing") - parser.add_argument( - "--use-gpu", action="store_true", help="Use GPU in training.") - parser.add_argument( - "--jobs", - type=int, - default=1, - help="number of jobs to run concurrently (default: 1)") - parser.add_argument( - "--threads", - type=int, - default=2, - help="threads used in operations (default: 2)") - parser.add_argument( - "--steps", - type=float, - default=0.01, - metavar="LR", - help="learning rate (default: 0.01)") - parser.add_argument( - "--lr", - type=float, - default=0.01, - metavar="LR", - help="learning rate (default: 0.01)") - parser.add_argument( - "--momentum", - type=float, - default=0.5, - metavar="M", - help="SGD momentum (default: 0.5)") - parser.add_argument( - "--kernel1", - type=int, - default=3, - help="Size of first kernel (default: 3)") - parser.add_argument( - "--kernel2", - type=int, - default=3, - help="Size of second kernel (default: 3)") - parser.add_argument( - "--poolsize", type=int, default=2, help="Size of Pooling (default: 2)") - parser.add_argument( - "--dropout1", - type=float, - default=0.25, - help="Size of first kernel (default: 0.25)") - parser.add_argument( - "--hidden", - type=int, - default=128, - help="Size of Hidden Layer (default: 128)") - parser.add_argument( - "--dropout2", - type=float, - default=0.5, - help="Size of first kernel (default: 0.5)") - return parser + callbacks=[TuneKerasCallback(reporter)]) if __name__ == "__main__": - parser = create_parser() - args = parser.parse_args() - mnist.load_data() # we do this because it's not threadsafe + import ray + from ray import tune + from ray.tune.schedulers import AsyncHyperBandScheduler + mnist.load_data() # we do this on the driver because it's not threadsafe ray.init() sched = AsyncHyperBandScheduler( @@ -177,31 +68,24 @@ if __name__ == "__main__": max_t=400, grace_period=20) - tune.register_trainable( - "TRAIN_FN", - lambda config, reporter: train_mnist(args, config, reporter)) tune.run( - "TRAIN_FN", + train_mnist, name="exp", scheduler=sched, - **{ - "stop": { - "mean_accuracy": 0.99, - "timesteps_total": 10 if args.smoke_test else 300 - }, - "num_samples": 1 if args.smoke_test else 10, - "resources_per_trial": { - "cpu": args.threads, - "gpu": 0.5 if args.use_gpu else 0 - }, - "config": { - "lr": tune.sample_from( - lambda spec: np.random.uniform(0.001, 0.1)), - "momentum": tune.sample_from( - lambda spec: np.random.uniform(0.1, 0.9)), - "hidden": tune.sample_from( - lambda spec: np.random.randint(32, 512)), - "dropout1": tune.sample_from( - lambda spec: np.random.uniform(0.2, 0.8)), - } + stop={ + "mean_accuracy": 0.99, + "training_iteration": 5 if args.smoke_test else 300 + }, + num_samples=10, + resources_per_trial={ + "cpu": 2, + "gpu": 0 + }, + config={ + "threads": 2, + "lr": tune.sample_from(lambda spec: np.random.uniform(0.001, 0.1)), + "momentum": tune.sample_from( + lambda spec: np.random.uniform(0.1, 0.9)), + "hidden": tune.sample_from( + lambda spec: np.random.randint(32, 512)), }) diff --git a/python/ray/tune/examples/utils.py b/python/ray/tune/examples/utils.py new file mode 100644 index 000000000..3c73bce2b --- /dev/null +++ b/python/ray/tune/examples/utils.py @@ -0,0 +1,61 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import keras +from keras.datasets import mnist +from keras import backend as K + + +class TuneKerasCallback(keras.callbacks.Callback): + def __init__(self, reporter, logs={}): + self.reporter = reporter + self.iteration = 0 + super(TuneKerasCallback, self).__init__() + + def on_train_end(self, epoch, logs={}): + self.reporter( + timesteps_total=self.iteration, done=1, mean_accuracy=logs["acc"]) + + def on_batch_end(self, batch, logs={}): + self.iteration += 1 + self.reporter( + timesteps_total=self.iteration, mean_accuracy=logs["acc"]) + + +def get_mnist_data(): + img_rows, img_cols = 28, 28 + num_classes = 10 + + # the data, split between train and test sets + (x_train, y_train), (x_test, y_test) = mnist.load_data() + + if K.image_data_format() == "channels_first": + x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols) + x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols) + input_shape = (1, img_rows, img_cols) + else: + x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1) + x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1) + input_shape = (img_rows, img_cols, 1) + + x_train = x_train.astype("float32") + x_test = x_test.astype("float32") + x_train /= 255 + x_test /= 255 + + # convert class vectors to binary class matrices + y_train = keras.utils.to_categorical(y_train, num_classes) + y_test = keras.utils.to_categorical(y_test, num_classes) + + return x_train, y_train, x_test, y_test, input_shape + + +def set_keras_threads(threads): + # We set threads here to avoid contention, as Keras + # is heavily parallelized across multiple cores. + K.set_session( + K.tf.Session( + config=K.tf.ConfigProto( + intra_op_parallelism_threads=threads, + inter_op_parallelism_threads=threads)))