diff --git a/training.py b/training.py index 60f63d0..2308bcf 100644 --- a/training.py +++ b/training.py @@ -109,8 +109,8 @@ if __name__ == '__main__': loss, logits = model.forward(input_ids, input_mask, segment_ids, labels=y) labels.append(torch.argmax(y, dim=1)) predictions.append(torch.argmax(logits, dim=1)) - labels = torch.LongTensor(labels) - predictions = torch.LongTensor(predictions) + labels = torch.cat(labels).long() + predictions = torch.cat(predictions).long() print(f'Epoch: {i}\tVal Accuracy: {(labels == predictions).mean()}') model.eval() @@ -130,6 +130,6 @@ if __name__ == '__main__': loss, logits = model.forward(input_ids, input_mask, segment_ids, labels=y) labels.append(torch.argmax(y, dim=1)) predictions.append(torch.argmax(logits, dim=1)) - labels = torch.LongTensor(labels) - predictions = torch.LongTensor(predictions) + labels = torch.cat(labels).long() + predictions = torch.cat(predictions).long() print(f'Epoch: {i}\tTest Accuracy: {(labels == predictions).mean()}')