diff --git a/python/ray/tune/tests/tutorial.py b/python/ray/tune/tests/tutorial.py index 2aa442279..2a11f12a0 100644 --- a/python/ray/tune/tests/tutorial.py +++ b/python/ray/tune/tests/tutorial.py @@ -93,7 +93,11 @@ def train_mnist(config): batch_size=64, shuffle=True) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = ConvNet() + model.to(device) + optimizer = optim.SGD( model.parameters(), lr=config["lr"], momentum=config["momentum"]) for i in range(10): @@ -161,6 +165,11 @@ space = { hyperopt_search = HyperOptSearch(space, metric="mean_accuracy", mode="max") analysis = tune.run(train_mnist, num_samples=10, search_alg=hyperopt_search) + +# To enable GPUs, use this instead: +# analysis = tune.run( +# train_mnist, config=search_space, resources_per_trial={'gpu': 1}) + # __run_searchalg_end__ # __run_analysis_begin__