[tune] Fix PyTorch example after PyTorch v1 (#3500)

* [tune]

* fix

* lint

* fix
This commit is contained in:
Richard Liaw
2018-12-10 12:00:53 -08:00
committed by Philipp Moritz
parent 962f18756b
commit 1f4a01cff6
2 changed files with 4 additions and 5 deletions
+2 -3
View File
@@ -137,13 +137,12 @@ def train_mnist(args, config, reporter):
data, target = Variable(data, volatile=True), Variable(target)
output = model(data)
test_loss += F.nll_loss(
output, target,
size_average=False).data[0] # sum up batch loss
output, target, size_average=False).item() # sum up batch loss
pred = output.data.max(
1, keepdim=True)[1] # get the index of the max log-probability
correct += pred.eq(target.data.view_as(pred)).long().cpu().sum()
test_loss = test_loss.item() / len(test_loader.dataset)
test_loss = test_loss / len(test_loader.dataset)
accuracy = correct.item() / len(test_loader.dataset)
reporter(mean_loss=test_loss, mean_accuracy=accuracy)
@@ -145,13 +145,13 @@ class TrainMNIST(Trainable):
output = self.model(data)
# sum up batch loss
test_loss += F.nll_loss(output, target, size_average=False).data[0]
test_loss += F.nll_loss(output, target, size_average=False).item()
# get the index of the max log-probability
pred = output.data.max(1, keepdim=True)[1]
correct += pred.eq(target.data.view_as(pred)).long().cpu().sum()
test_loss = test_loss.item() / len(self.test_loader.dataset)
test_loss = test_loss / len(self.test_loader.dataset)
accuracy = correct.item() / len(self.test_loader.dataset)
return {"mean_loss": test_loss, "mean_accuracy": accuracy}