This commit is contained in:
Anton Kiselev
2019-05-19 18:44:38 +03:00
parent 80e8959ceb
commit 7544e0b6b6
+4 -4
View File
@@ -109,8 +109,8 @@ if __name__ == '__main__':
loss, logits = model.forward(input_ids, input_mask, segment_ids, labels=y)
labels.append(torch.argmax(y, dim=1))
predictions.append(torch.argmax(logits, dim=1))
labels = torch.LongTensor(labels)
predictions = torch.LongTensor(predictions)
labels = torch.cat(labels).long()
predictions = torch.cat(predictions).long()
print(f'Epoch: {i}\tVal Accuracy: {(labels == predictions).mean()}')
model.eval()
@@ -130,6 +130,6 @@ if __name__ == '__main__':
loss, logits = model.forward(input_ids, input_mask, segment_ids, labels=y)
labels.append(torch.argmax(y, dim=1))
predictions.append(torch.argmax(logits, dim=1))
labels = torch.LongTensor(labels)
predictions = torch.LongTensor(predictions)
labels = torch.cat(labels).long()
predictions = torch.cat(predictions).long()
print(f'Epoch: {i}\tTest Accuracy: {(labels == predictions).mean()}')