mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:23:10 +08:00
[tune] Fix tutorial training on GPU (#12914)
This commit is contained in:
@@ -93,7 +93,11 @@ def train_mnist(config):
|
||||
batch_size=64,
|
||||
shuffle=True)
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
model = ConvNet()
|
||||
model.to(device)
|
||||
|
||||
optimizer = optim.SGD(
|
||||
model.parameters(), lr=config["lr"], momentum=config["momentum"])
|
||||
for i in range(10):
|
||||
@@ -161,6 +165,11 @@ space = {
|
||||
hyperopt_search = HyperOptSearch(space, metric="mean_accuracy", mode="max")
|
||||
|
||||
analysis = tune.run(train_mnist, num_samples=10, search_alg=hyperopt_search)
|
||||
|
||||
# To enable GPUs, use this instead:
|
||||
# analysis = tune.run(
|
||||
# train_mnist, config=search_space, resources_per_trial={'gpu': 1})
|
||||
|
||||
# __run_searchalg_end__
|
||||
|
||||
# __run_analysis_begin__
|
||||
|
||||
Reference in New Issue
Block a user