From 40428c9b05dca1c2a0eb38b5f9159bae5bf98e3b Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Mon, 23 Nov 2020 12:15:10 -0800 Subject: [PATCH] [tune] Make tf distributed testing smaller (#12173) --- python/ray/tune/BUILD | 1 + .../examples/tf_distributed_keras_example.py | 27 ++++++++++++------- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/python/ray/tune/BUILD b/python/ray/tune/BUILD index 7f2f563f3..f10df3ec9 100644 --- a/python/ray/tune/BUILD +++ b/python/ray/tune/BUILD @@ -390,6 +390,7 @@ py_test( srcs = ["examples/tf_distributed_keras_example.py"], deps = [":tune_lib"], tags = ["exclusive", "example", "tensorflow"], + args = ["--smoke-test"] ) py_test( diff --git a/python/ray/tune/examples/tf_distributed_keras_example.py b/python/ray/tune/examples/tf_distributed_keras_example.py index 99e0fe217..dbbec5987 100644 --- a/python/ray/tune/examples/tf_distributed_keras_example.py +++ b/python/ray/tune/examples/tf_distributed_keras_example.py @@ -13,12 +13,15 @@ from ray.tune.integration.tensorflow import (DistributedTrainableCreator, get_num_workers) -def mnist_dataset(batch_size): +def mnist_dataset(batch_size, mini=False): (x_train, y_train), _ = tf.keras.datasets.mnist.load_data() # The `x` arrays are in uint8 and have values in the range [0, 255]. # You need to convert them to float32 with values in the range [0, 1] x_train = x_train / np.float32(255) y_train = y_train.astype(np.int64) + if mini: + x_train = x_train[:512] + y_train = y_train[:512] train_dataset = tf.data.Dataset.from_tensor_slices( (x_train, y_train)).shuffle(60000).repeat().batch(batch_size) return train_dataset @@ -30,7 +33,7 @@ def build_and_compile_cnn_model(config): tf.keras.layers.Reshape(target_shape=(28, 28, 1)), tf.keras.layers.Conv2D(32, 3, activation="relu"), tf.keras.layers.Flatten(), - tf.keras.layers.Dense(config.get("hidden", 128), activation="relu"), + tf.keras.layers.Dense(config.get("hidden", 16), activation="relu"), tf.keras.layers.Dense(10) ]) model.compile( @@ -47,13 +50,15 @@ def train_mnist(config, checkpoint_dir=None): per_worker_batch_size = 64 num_workers = get_num_workers() global_batch_size = per_worker_batch_size * num_workers - multi_worker_dataset = mnist_dataset(global_batch_size) + multi_worker_dataset = mnist_dataset( + global_batch_size, mini=config.get("use_mini")) + steps_per_epoch = 5 if config.get("use_mini") else 70 with strategy.scope(): multi_worker_model = build_and_compile_cnn_model(config) multi_worker_model.fit( multi_worker_dataset, epochs=2, - steps_per_epoch=70, + steps_per_epoch=steps_per_epoch, callbacks=[ TuneReportCheckpointCallback({ "mean_accuracy": "accuracy" @@ -91,6 +96,11 @@ if __name__ == "__main__": action="store_true", default=False, help="enables multi-node tuning") + parser.add_argument( + "--smoke-test", + action="store_true", + default=False, + help="enables small scale testing") args = parser.parse_args() if args.cluster: options = dict(address="auto") @@ -119,10 +129,9 @@ if __name__ == "__main__": }, num_samples=1, config={ - "lr": tune.sample_from(lambda spec: np.random.uniform(0.001, 0.1)), - "momentum": tune.sample_from( - lambda spec: np.random.uniform(0.1, 0.9)), - "hidden": tune.sample_from( - lambda spec: np.random.randint(32, 512)), + "use_mini": args.smoke_test, + "lr": tune.uniform(0.001, 0.1), + "momentum": tune.uniform(0.1, 0.9), + "hidden": tune.randint(32, 512), }) print("Best hyperparameters found were: ", analysis.best_config)