mirror of
https://github.com/wassname/bert_adapter.git
synced 2026-06-27 16:29:32 +08:00
fix
This commit is contained in:
+4
-4
@@ -109,8 +109,8 @@ if __name__ == '__main__':
|
||||
loss, logits = model.forward(input_ids, input_mask, segment_ids, labels=y)
|
||||
labels.append(torch.argmax(y, dim=1))
|
||||
predictions.append(torch.argmax(logits, dim=1))
|
||||
labels = torch.LongTensor(labels)
|
||||
predictions = torch.LongTensor(predictions)
|
||||
labels = torch.cat(labels).long()
|
||||
predictions = torch.cat(predictions).long()
|
||||
print(f'Epoch: {i}\tVal Accuracy: {(labels == predictions).mean()}')
|
||||
|
||||
model.eval()
|
||||
@@ -130,6 +130,6 @@ if __name__ == '__main__':
|
||||
loss, logits = model.forward(input_ids, input_mask, segment_ids, labels=y)
|
||||
labels.append(torch.argmax(y, dim=1))
|
||||
predictions.append(torch.argmax(logits, dim=1))
|
||||
labels = torch.LongTensor(labels)
|
||||
predictions = torch.LongTensor(predictions)
|
||||
labels = torch.cat(labels).long()
|
||||
predictions = torch.cat(predictions).long()
|
||||
print(f'Epoch: {i}\tTest Accuracy: {(labels == predictions).mean()}')
|
||||
|
||||
Reference in New Issue
Block a user