From 1f4a01cff68ad20f3cfc083b6a10d197e5a090b4 Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Mon, 10 Dec 2018 12:00:53 -0800 Subject: [PATCH] [tune] Fix PyTorch example after PyTorch v1 (#3500) * [tune] * fix * lint * fix --- python/ray/tune/examples/mnist_pytorch.py | 5 ++--- python/ray/tune/examples/mnist_pytorch_trainable.py | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/python/ray/tune/examples/mnist_pytorch.py b/python/ray/tune/examples/mnist_pytorch.py index bec73f3d5..d9e336a76 100644 --- a/python/ray/tune/examples/mnist_pytorch.py +++ b/python/ray/tune/examples/mnist_pytorch.py @@ -137,13 +137,12 @@ def train_mnist(args, config, reporter): data, target = Variable(data, volatile=True), Variable(target) output = model(data) test_loss += F.nll_loss( - output, target, - size_average=False).data[0] # sum up batch loss + output, target, size_average=False).item() # sum up batch loss pred = output.data.max( 1, keepdim=True)[1] # get the index of the max log-probability correct += pred.eq(target.data.view_as(pred)).long().cpu().sum() - test_loss = test_loss.item() / len(test_loader.dataset) + test_loss = test_loss / len(test_loader.dataset) accuracy = correct.item() / len(test_loader.dataset) reporter(mean_loss=test_loss, mean_accuracy=accuracy) diff --git a/python/ray/tune/examples/mnist_pytorch_trainable.py b/python/ray/tune/examples/mnist_pytorch_trainable.py index 6005cd79c..75da205f1 100644 --- a/python/ray/tune/examples/mnist_pytorch_trainable.py +++ b/python/ray/tune/examples/mnist_pytorch_trainable.py @@ -145,13 +145,13 @@ class TrainMNIST(Trainable): output = self.model(data) # sum up batch loss - test_loss += F.nll_loss(output, target, size_average=False).data[0] + test_loss += F.nll_loss(output, target, size_average=False).item() # get the index of the max log-probability pred = output.data.max(1, keepdim=True)[1] correct += pred.eq(target.data.view_as(pred)).long().cpu().sum() - test_loss = test_loss.item() / len(self.test_loader.dataset) + test_loss = test_loss / len(self.test_loader.dataset) accuracy = correct.item() / len(self.test_loader.dataset) return {"mean_loss": test_loss, "mean_accuracy": accuracy}