[tune] Fix tutorial training on GPU (#12914)

This commit is contained in:
Kai Fricke
2020-12-18 10:31:40 +01:00
committed by GitHub
parent a442cd17e0
commit 426f8a8d15
+9
View File
@@ -93,7 +93,11 @@ def train_mnist(config):
batch_size=64,
shuffle=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ConvNet()
model.to(device)
optimizer = optim.SGD(
model.parameters(), lr=config["lr"], momentum=config["momentum"])
for i in range(10):
@@ -161,6 +165,11 @@ space = {
hyperopt_search = HyperOptSearch(space, metric="mean_accuracy", mode="max")
analysis = tune.run(train_mnist, num_samples=10, search_alg=hyperopt_search)
# To enable GPUs, use this instead:
# analysis = tune.run(
# train_mnist, config=search_space, resources_per_trial={'gpu': 1})
# __run_searchalg_end__
# __run_analysis_begin__