mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 11:51:09 +08:00
[tune] Fix PyTorch example after PyTorch v1 (#3500)
* [tune] * fix * lint * fix
This commit is contained in:
committed by
Philipp Moritz
parent
962f18756b
commit
1f4a01cff6
@@ -137,13 +137,12 @@ def train_mnist(args, config, reporter):
|
||||
data, target = Variable(data, volatile=True), Variable(target)
|
||||
output = model(data)
|
||||
test_loss += F.nll_loss(
|
||||
output, target,
|
||||
size_average=False).data[0] # sum up batch loss
|
||||
output, target, size_average=False).item() # sum up batch loss
|
||||
pred = output.data.max(
|
||||
1, keepdim=True)[1] # get the index of the max log-probability
|
||||
correct += pred.eq(target.data.view_as(pred)).long().cpu().sum()
|
||||
|
||||
test_loss = test_loss.item() / len(test_loader.dataset)
|
||||
test_loss = test_loss / len(test_loader.dataset)
|
||||
accuracy = correct.item() / len(test_loader.dataset)
|
||||
reporter(mean_loss=test_loss, mean_accuracy=accuracy)
|
||||
|
||||
|
||||
@@ -145,13 +145,13 @@ class TrainMNIST(Trainable):
|
||||
output = self.model(data)
|
||||
|
||||
# sum up batch loss
|
||||
test_loss += F.nll_loss(output, target, size_average=False).data[0]
|
||||
test_loss += F.nll_loss(output, target, size_average=False).item()
|
||||
|
||||
# get the index of the max log-probability
|
||||
pred = output.data.max(1, keepdim=True)[1]
|
||||
correct += pred.eq(target.data.view_as(pred)).long().cpu().sum()
|
||||
|
||||
test_loss = test_loss.item() / len(self.test_loader.dataset)
|
||||
test_loss = test_loss / len(self.test_loader.dataset)
|
||||
accuracy = correct.item() / len(self.test_loader.dataset)
|
||||
return {"mean_loss": test_loss, "mean_accuracy": accuracy}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user