From c368a4a90c487ee019879b05de3d4071b14b196a Mon Sep 17 00:00:00 2001 From: Anton Kiselev Date: Sun, 19 May 2019 18:46:22 +0300 Subject: [PATCH] fix --- training.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/training.py b/training.py index 2308bcf..69b5fe5 100644 --- a/training.py +++ b/training.py @@ -111,7 +111,7 @@ if __name__ == '__main__': predictions.append(torch.argmax(logits, dim=1)) labels = torch.cat(labels).long() predictions = torch.cat(predictions).long() - print(f'Epoch: {i}\tVal Accuracy: {(labels == predictions).mean()}') + print(f'Epoch: {i}\tVal Accuracy: {(labels == predictions).float().mean()}') model.eval() labels = [] @@ -132,4 +132,4 @@ if __name__ == '__main__': predictions.append(torch.argmax(logits, dim=1)) labels = torch.cat(labels).long() predictions = torch.cat(predictions).long() - print(f'Epoch: {i}\tTest Accuracy: {(labels == predictions).mean()}') + print(f'Epoch: {i}\tTest Accuracy: {(labels == predictions).float().mean()}')