# Patch summary (commentary only; text before the "diff --git" header is not part of the diff).
#
# Purpose: add a validation split to training.py and report validation and test accuracy separately.
#
# Hunk 1 (@@ -46,9 +46,15 @@):
#   - Builds val_dataset from the same args.train_file that train_dataset uses.
#   - Reassigns .df on both: train keeps df.iloc[:-1000], val keeps df.iloc[-1000:].
#   - Adds val_loader (no shuffle argument is passed) and re-wraps the train_loader call
#     across two lines; the train_loader arguments themselves are unchanged.
# Hunk 2 (@@ -89,7 +95,7 @@):
#   - The existing evaluation loop now iterates val_loader instead of test_loader.
# Hunk 3 (@@ -105,4 +111,25 @@):
#   - The existing print is relabelled from "Test Accuracy" to "Val Accuracy".
#   - A second evaluation pass over test_loader is appended; it prints "Test Accuracy".
#
# NOTE(review): newlines/indentation are collapsed in this copy of the patch, so the nesting of
#   the added statements (e.g. whether the new test pass sits inside the epoch loop) cannot be
#   read from here - confirm against the real file.
# NOTE(review): the training file is loaded into two TokenizedDataFrameDataset objects and then
#   sliced via .df; presumably the dataset reads .df lazily (length/indexing) - confirm that
#   nothing derived from the full frame is cached at construction time.
# NOTE(review): the 1000-row split size is hard-coded; assuming .df is a pandas DataFrame, a
#   training file with <= 1000 rows would leave train_dataset empty - TODO confirm/guard.
# NOTE(review): the added test loop repeats the validation loop almost verbatim, including the
#   optimizer.zero_grad() call after model.eval(); no torch.no_grad() is visible in these hunks.
#   Candidate for a shared evaluate(loader) helper - confirm intent.
# NOTE(review): accuracy is computed as (labels == predictions).mean() on LongTensors; this may
#   need a .float() cast depending on the torch version - TODO verify it runs as written.
diff --git a/training.py b/training.py index 184fe54..60f63d0 100644 --- a/training.py +++ b/training.py @@ -46,9 +46,15 @@ if __name__ == '__main__': tokenizer = BertTokenizer.from_pretrained(args.bert_model_name) train_dataset = TokenizedDataFrameDataset(tokenizer, file_path=args.train_file) + val_dataset = TokenizedDataFrameDataset(tokenizer, file_path=args.train_file) + train_dataset.df = train_dataset.df.iloc[:-1000] + val_dataset.df = val_dataset.df.iloc[-1000:] test_dataset = TokenizedDataFrameDataset(tokenizer, file_path=args.test_file) - train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size, num_workers=args.n_workers) + train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size, + num_workers=args.n_workers) + val_loader = DataLoader(val_dataset, batch_size=args.batch_size, + num_workers=args.n_workers) test_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.n_workers) # Load pre-trained model (weights) @@ -89,7 +95,7 @@ if __name__ == '__main__': model.eval() labels = [] predictions = [] - for batch in tqdm(test_loader): + for batch in tqdm(val_loader): optimizer.zero_grad() input_ids, input_mask, segment_ids = batch['x'] @@ -105,4 +111,25 @@ if __name__ == '__main__': predictions.append(torch.argmax(logits, dim=1)) labels = torch.LongTensor(labels) predictions = torch.LongTensor(predictions) - print(f'Epoch: {i}\tTest Accuracy: {(labels == predictions).mean()}') + print(f'Epoch: {i}\tVal Accuracy: {(labels == predictions).mean()}') + + model.eval() + labels = [] + predictions = [] + for batch in tqdm(test_loader): + optimizer.zero_grad() + + input_ids, input_mask, segment_ids = batch['x'] + input_ids = input_ids.to(device) + input_mask = input_mask.to(device) + segment_ids = segment_ids.to(device) + + y = batch['y'] + y = y.to(device) + + loss, logits = model.forward(input_ids, input_mask, segment_ids, labels=y) + labels.append(torch.argmax(y, dim=1)) + 
predictions.append(torch.argmax(logits, dim=1)) + labels = torch.LongTensor(labels) + predictions = torch.LongTensor(predictions) + print(f'Epoch: {i}\tTest Accuracy: {(labels == predictions).mean()}')