From 426f8a8d15449c5e78efc4604e2b0fd4a8577fb8 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Fri, 18 Dec 2020 10:31:40 +0100 Subject: [PATCH] [tune] Fix tutorial training on GPU (#12914) --- python/ray/tune/tests/tutorial.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/python/ray/tune/tests/tutorial.py b/python/ray/tune/tests/tutorial.py index 2aa442279..2a11f12a0 100644 --- a/python/ray/tune/tests/tutorial.py +++ b/python/ray/tune/tests/tutorial.py @@ -93,7 +93,11 @@ def train_mnist(config): batch_size=64, shuffle=True) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = ConvNet() + model.to(device) + optimizer = optim.SGD( model.parameters(), lr=config["lr"], momentum=config["momentum"]) for i in range(10): @@ -161,6 +165,11 @@ space = { hyperopt_search = HyperOptSearch(space, metric="mean_accuracy", mode="max") analysis = tune.run(train_mnist, num_samples=10, search_alg=hyperopt_search) + +# To enable GPUs, use this instead: +# analysis = tune.run( +# train_mnist, config=search_space, resources_per_trial={'gpu': 1}) + # __run_searchalg_end__ # __run_analysis_begin__