diff --git a/training.py b/training.py index 2308bcf..69b5fe5 100644 --- a/training.py +++ b/training.py @@ -111,7 +111,7 @@ if __name__ == '__main__': predictions.append(torch.argmax(logits, dim=1)) labels = torch.cat(labels).long() predictions = torch.cat(predictions).long() - print(f'Epoch: {i}\tVal Accuracy: {(labels == predictions).mean()}') + print(f'Epoch: {i}\tVal Accuracy: {(labels == predictions).float().mean()}') model.eval() labels = [] @@ -132,4 +132,4 @@ if __name__ == '__main__': predictions.append(torch.argmax(logits, dim=1)) labels = torch.cat(labels).long() predictions = torch.cat(predictions).long() - print(f'Epoch: {i}\tTest Accuracy: {(labels == predictions).mean()}') + print(f'Epoch: {i}\tTest Accuracy: {(labels == predictions).float().mean()}')