[tune] Make tf distributed testing smaller (#12173)

This commit is contained in:
Richard Liaw
2020-11-23 12:15:10 -08:00
committed by GitHub
parent c99c376d66
commit 40428c9b05
2 changed files with 19 additions and 9 deletions
+1
View File
@@ -390,6 +390,7 @@ py_test(
srcs = ["examples/tf_distributed_keras_example.py"],
deps = [":tune_lib"],
tags = ["exclusive", "example", "tensorflow"],
args = ["--smoke-test"]
)
py_test(
@@ -13,12 +13,15 @@ from ray.tune.integration.tensorflow import (DistributedTrainableCreator,
get_num_workers)
def mnist_dataset(batch_size):
def mnist_dataset(batch_size, mini=False):
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
# The `x` arrays are in uint8 and have values in the range [0, 255].
# You need to convert them to float32 with values in the range [0, 1]
x_train = x_train / np.float32(255)
y_train = y_train.astype(np.int64)
if mini:
x_train = x_train[:512]
y_train = y_train[:512]
train_dataset = tf.data.Dataset.from_tensor_slices(
(x_train, y_train)).shuffle(60000).repeat().batch(batch_size)
return train_dataset
@@ -30,7 +33,7 @@ def build_and_compile_cnn_model(config):
tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
tf.keras.layers.Conv2D(32, 3, activation="relu"),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(config.get("hidden", 128), activation="relu"),
tf.keras.layers.Dense(config.get("hidden", 16), activation="relu"),
tf.keras.layers.Dense(10)
])
model.compile(
@@ -47,13 +50,15 @@ def train_mnist(config, checkpoint_dir=None):
per_worker_batch_size = 64
num_workers = get_num_workers()
global_batch_size = per_worker_batch_size * num_workers
multi_worker_dataset = mnist_dataset(global_batch_size)
multi_worker_dataset = mnist_dataset(
global_batch_size, mini=config.get("use_mini"))
steps_per_epoch = 5 if config.get("use_mini") else 70
with strategy.scope():
multi_worker_model = build_and_compile_cnn_model(config)
multi_worker_model.fit(
multi_worker_dataset,
epochs=2,
steps_per_epoch=70,
steps_per_epoch=steps_per_epoch,
callbacks=[
TuneReportCheckpointCallback({
"mean_accuracy": "accuracy"
@@ -91,6 +96,11 @@ if __name__ == "__main__":
action="store_true",
default=False,
help="enables multi-node tuning")
parser.add_argument(
"--smoke-test",
action="store_true",
default=False,
help="enables small scale testing")
args = parser.parse_args()
if args.cluster:
options = dict(address="auto")
@@ -119,10 +129,9 @@ if __name__ == "__main__":
},
num_samples=1,
config={
"lr": tune.sample_from(lambda spec: np.random.uniform(0.001, 0.1)),
"momentum": tune.sample_from(
lambda spec: np.random.uniform(0.1, 0.9)),
"hidden": tune.sample_from(
lambda spec: np.random.randint(32, 512)),
"use_mini": args.smoke_test,
"lr": tune.uniform(0.001, 0.1),
"momentum": tune.uniform(0.1, 0.9),
"hidden": tune.randint(32, 512),
})
print("Best hyperparameters found were: ", analysis.best_config)