mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 14:12:00 +08:00
[tune] Improve mnist_pytorch.py example (#3894)
## What do these changes do? * Improved --no-cuda handling * Removed deprecated Variable usage ## Related issue number Fixes #3873 <!-- Are there any issues opened that will be resolved by merging this change? -->
This commit is contained in:
@@ -8,7 +8,6 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from torchvision import datasets, transforms
|
||||
from torch.autograd import Variable
|
||||
|
||||
# Training settings
|
||||
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
|
||||
@@ -120,7 +119,6 @@ def train_mnist(args, config, reporter):
|
||||
for batch_idx, (data, target) in enumerate(train_loader):
|
||||
if args.cuda:
|
||||
data, target = data.cuda(), target.cuda()
|
||||
data, target = Variable(data), Variable(target)
|
||||
optimizer.zero_grad()
|
||||
output = model(data)
|
||||
loss = F.nll_loss(output, target)
|
||||
@@ -131,16 +129,17 @@ def train_mnist(args, config, reporter):
|
||||
model.eval()
|
||||
test_loss = 0
|
||||
correct = 0
|
||||
for data, target in test_loader:
|
||||
if args.cuda:
|
||||
data, target = data.cuda(), target.cuda()
|
||||
data, target = Variable(data, volatile=True), Variable(target)
|
||||
output = model(data)
|
||||
test_loss += F.nll_loss(
|
||||
output, target, size_average=False).item() # sum up batch loss
|
||||
pred = output.data.max(
|
||||
1, keepdim=True)[1] # get the index of the max log-probability
|
||||
correct += pred.eq(target.data.view_as(pred)).long().cpu().sum()
|
||||
with torch.no_grad():
|
||||
for data, target in test_loader:
|
||||
if args.cuda:
|
||||
data, target = data.cuda(), target.cuda()
|
||||
output = model(data)
|
||||
# sum up batch loss
|
||||
test_loss += F.nll_loss(output, target, reduction='sum').item()
|
||||
# get the index of the max log-probability
|
||||
pred = output.argmax(dim=1, keepdim=True)
|
||||
correct += pred.eq(
|
||||
target.data.view_as(pred)).long().cpu().sum()
|
||||
|
||||
test_loss = test_loss / len(test_loader.dataset)
|
||||
accuracy = correct.item() / len(test_loader.dataset)
|
||||
@@ -176,7 +175,8 @@ if __name__ == '__main__':
|
||||
"training_iteration": 1 if args.smoke_test else 20
|
||||
},
|
||||
"resources_per_trial": {
|
||||
"cpu": 3
|
||||
"cpu": 3,
|
||||
"gpu": int(not args.no_cuda)
|
||||
},
|
||||
"run": "train_mnist",
|
||||
"num_samples": 1 if args.smoke_test else 10,
|
||||
|
||||
@@ -9,7 +9,6 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from torchvision import datasets, transforms
|
||||
from torch.autograd import Variable
|
||||
|
||||
from ray.tune import Trainable
|
||||
|
||||
@@ -127,7 +126,6 @@ class TrainMNIST(Trainable):
|
||||
for batch_idx, (data, target) in enumerate(self.train_loader):
|
||||
if self.args.cuda:
|
||||
data, target = data.cuda(), target.cuda()
|
||||
data, target = Variable(data), Variable(target)
|
||||
self.optimizer.zero_grad()
|
||||
output = self.model(data)
|
||||
loss = F.nll_loss(output, target)
|
||||
@@ -138,18 +136,17 @@ class TrainMNIST(Trainable):
|
||||
self.model.eval()
|
||||
test_loss = 0
|
||||
correct = 0
|
||||
for data, target in self.test_loader:
|
||||
if self.args.cuda:
|
||||
data, target = data.cuda(), target.cuda()
|
||||
data, target = Variable(data, volatile=True), Variable(target)
|
||||
output = self.model(data)
|
||||
|
||||
# sum up batch loss
|
||||
test_loss += F.nll_loss(output, target, size_average=False).item()
|
||||
|
||||
# get the index of the max log-probability
|
||||
pred = output.data.max(1, keepdim=True)[1]
|
||||
correct += pred.eq(target.data.view_as(pred)).long().cpu().sum()
|
||||
with torch.no_grad():
|
||||
for data, target in self.test_loader:
|
||||
if self.args.cuda:
|
||||
data, target = data.cuda(), target.cuda()
|
||||
output = self.model(data)
|
||||
# sum up batch loss
|
||||
test_loss += F.nll_loss(output, target, reduction='sum').item()
|
||||
# get the index of the max log-probability
|
||||
pred = output.argmax(dim=1, keepdim=True)
|
||||
correct += pred.eq(
|
||||
target.data.view_as(pred)).long().cpu().sum()
|
||||
|
||||
test_loss = test_loss / len(self.test_loader.dataset)
|
||||
accuracy = correct.item() / len(self.test_loader.dataset)
|
||||
@@ -188,7 +185,8 @@ if __name__ == '__main__':
|
||||
"training_iteration": 1 if args.smoke_test else 20,
|
||||
},
|
||||
"resources_per_trial": {
|
||||
"cpu": 3
|
||||
"cpu": 3,
|
||||
"gpu": int(not args.no_cuda)
|
||||
},
|
||||
"run": TrainMNIST,
|
||||
"num_samples": 1 if args.smoke_test else 20,
|
||||
|
||||
Reference in New Issue
Block a user