mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 16:48:51 +08:00
[tune] Make tf distributed testing smaller (#12173)
This commit is contained in:
@@ -390,6 +390,7 @@ py_test(
|
||||
srcs = ["examples/tf_distributed_keras_example.py"],
|
||||
deps = [":tune_lib"],
|
||||
tags = ["exclusive", "example", "tensorflow"],
|
||||
args = ["--smoke-test"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
|
||||
@@ -13,12 +13,15 @@ from ray.tune.integration.tensorflow import (DistributedTrainableCreator,
|
||||
get_num_workers)
|
||||
|
||||
|
||||
def mnist_dataset(batch_size):
|
||||
def mnist_dataset(batch_size, mini=False):
|
||||
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
|
||||
# The `x` arrays are in uint8 and have values in the range [0, 255].
|
||||
# You need to convert them to float32 with values in the range [0, 1]
|
||||
x_train = x_train / np.float32(255)
|
||||
y_train = y_train.astype(np.int64)
|
||||
if mini:
|
||||
x_train = x_train[:512]
|
||||
y_train = y_train[:512]
|
||||
train_dataset = tf.data.Dataset.from_tensor_slices(
|
||||
(x_train, y_train)).shuffle(60000).repeat().batch(batch_size)
|
||||
return train_dataset
|
||||
@@ -30,7 +33,7 @@ def build_and_compile_cnn_model(config):
|
||||
tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
|
||||
tf.keras.layers.Conv2D(32, 3, activation="relu"),
|
||||
tf.keras.layers.Flatten(),
|
||||
tf.keras.layers.Dense(config.get("hidden", 128), activation="relu"),
|
||||
tf.keras.layers.Dense(config.get("hidden", 16), activation="relu"),
|
||||
tf.keras.layers.Dense(10)
|
||||
])
|
||||
model.compile(
|
||||
@@ -47,13 +50,15 @@ def train_mnist(config, checkpoint_dir=None):
|
||||
per_worker_batch_size = 64
|
||||
num_workers = get_num_workers()
|
||||
global_batch_size = per_worker_batch_size * num_workers
|
||||
multi_worker_dataset = mnist_dataset(global_batch_size)
|
||||
multi_worker_dataset = mnist_dataset(
|
||||
global_batch_size, mini=config.get("use_mini"))
|
||||
steps_per_epoch = 5 if config.get("use_mini") else 70
|
||||
with strategy.scope():
|
||||
multi_worker_model = build_and_compile_cnn_model(config)
|
||||
multi_worker_model.fit(
|
||||
multi_worker_dataset,
|
||||
epochs=2,
|
||||
steps_per_epoch=70,
|
||||
steps_per_epoch=steps_per_epoch,
|
||||
callbacks=[
|
||||
TuneReportCheckpointCallback({
|
||||
"mean_accuracy": "accuracy"
|
||||
@@ -91,6 +96,11 @@ if __name__ == "__main__":
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="enables multi-node tuning")
|
||||
parser.add_argument(
|
||||
"--smoke-test",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="enables small scale testing")
|
||||
args = parser.parse_args()
|
||||
if args.cluster:
|
||||
options = dict(address="auto")
|
||||
@@ -119,10 +129,9 @@ if __name__ == "__main__":
|
||||
},
|
||||
num_samples=1,
|
||||
config={
|
||||
"lr": tune.sample_from(lambda spec: np.random.uniform(0.001, 0.1)),
|
||||
"momentum": tune.sample_from(
|
||||
lambda spec: np.random.uniform(0.1, 0.9)),
|
||||
"hidden": tune.sample_from(
|
||||
lambda spec: np.random.randint(32, 512)),
|
||||
"use_mini": args.smoke_test,
|
||||
"lr": tune.uniform(0.001, 0.1),
|
||||
"momentum": tune.uniform(0.1, 0.9),
|
||||
"hidden": tune.randint(32, 512),
|
||||
})
|
||||
print("Best hyperparameters found were: ", analysis.best_config)
|
||||
|
||||
Reference in New Issue
Block a user