Validation

This commit is contained in:
Anton Kiselev
2019-05-19 18:19:09 +03:00
parent 79f3878488
commit 80e8959ceb
+30 -3
View File
@@ -46,9 +46,15 @@ if __name__ == '__main__':
tokenizer = BertTokenizer.from_pretrained(args.bert_model_name)
train_dataset = TokenizedDataFrameDataset(tokenizer, file_path=args.train_file)
val_dataset = TokenizedDataFrameDataset(tokenizer, file_path=args.train_file)
train_dataset.df = train_dataset.df.iloc[:-1000]
val_dataset.df = val_dataset.df.iloc[-1000:]
test_dataset = TokenizedDataFrameDataset(tokenizer, file_path=args.test_file)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size, num_workers=args.n_workers)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size,
num_workers=args.n_workers)
val_loader = DataLoader(val_dataset, batch_size=args.batch_size,
num_workers=args.n_workers)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.n_workers)
# Load pre-trained model (weights)
@@ -89,7 +95,7 @@ if __name__ == '__main__':
model.eval()
labels = []
predictions = []
for batch in tqdm(test_loader):
for batch in tqdm(val_loader):
optimizer.zero_grad()
input_ids, input_mask, segment_ids = batch['x']
@@ -105,4 +111,25 @@ if __name__ == '__main__':
predictions.append(torch.argmax(logits, dim=1))
labels = torch.LongTensor(labels)
predictions = torch.LongTensor(predictions)
print(f'Epoch: {i}\tTest Accuracy: {(labels == predictions).mean()}')
print(f'Epoch: {i}\tVal Accuracy: {(labels == predictions).mean()}')
model.eval()
labels = []
predictions = []
for batch in tqdm(test_loader):
optimizer.zero_grad()
input_ids, input_mask, segment_ids = batch['x']
input_ids = input_ids.to(device)
input_mask = input_mask.to(device)
segment_ids = segment_ids.to(device)
y = batch['y']
y = y.to(device)
loss, logits = model.forward(input_ids, input_mask, segment_ids, labels=y)
labels.append(torch.argmax(y, dim=1))
predictions.append(torch.argmax(logits, dim=1))
labels = torch.LongTensor(labels)
predictions = torch.LongTensor(predictions)
print(f'Epoch: {i}\tTest Accuracy: {(labels == predictions).mean()}')