This commit is contained in:
Anton Kiselev
2019-05-19 18:46:22 +03:00
parent 7544e0b6b6
commit c368a4a90c
+2 -2
View File
@@ -111,7 +111,7 @@ if __name__ == '__main__':
predictions.append(torch.argmax(logits, dim=1))
labels = torch.cat(labels).long()
predictions = torch.cat(predictions).long()
print(f'Epoch: {i}\tVal Accuracy: {(labels == predictions).mean()}')
print(f'Epoch: {i}\tVal Accuracy: {(labels == predictions).float().mean()}')
model.eval()
labels = []
@@ -132,4 +132,4 @@ if __name__ == '__main__':
predictions.append(torch.argmax(logits, dim=1))
labels = torch.cat(labels).long()
predictions = torch.cat(predictions).long()
print(f'Epoch: {i}\tTest Accuracy: {(labels == predictions).mean()}')
print(f'Epoch: {i}\tTest Accuracy: {(labels == predictions).float().mean()}')