mirror of
https://github.com/wassname/bert_adapter.git
synced 2026-06-27 18:03:56 +08:00
Validation
This commit is contained in:
+30
-3
@@ -46,9 +46,15 @@ if __name__ == '__main__':
|
||||
tokenizer = BertTokenizer.from_pretrained(args.bert_model_name)
|
||||
|
||||
train_dataset = TokenizedDataFrameDataset(tokenizer, file_path=args.train_file)
|
||||
val_dataset = TokenizedDataFrameDataset(tokenizer, file_path=args.train_file)
|
||||
train_dataset.df = train_dataset.df.iloc[:-1000]
|
||||
val_dataset.df = val_dataset.df.iloc[-1000:]
|
||||
test_dataset = TokenizedDataFrameDataset(tokenizer, file_path=args.test_file)
|
||||
|
||||
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size, num_workers=args.n_workers)
|
||||
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size,
|
||||
num_workers=args.n_workers)
|
||||
val_loader = DataLoader(val_dataset, batch_size=args.batch_size,
|
||||
num_workers=args.n_workers)
|
||||
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.n_workers)
|
||||
|
||||
# Load pre-trained model (weights)
|
||||
@@ -89,7 +95,7 @@ if __name__ == '__main__':
|
||||
model.eval()
|
||||
labels = []
|
||||
predictions = []
|
||||
for batch in tqdm(test_loader):
|
||||
for batch in tqdm(val_loader):
|
||||
optimizer.zero_grad()
|
||||
|
||||
input_ids, input_mask, segment_ids = batch['x']
|
||||
@@ -105,4 +111,25 @@ if __name__ == '__main__':
|
||||
predictions.append(torch.argmax(logits, dim=1))
|
||||
labels = torch.LongTensor(labels)
|
||||
predictions = torch.LongTensor(predictions)
|
||||
print(f'Epoch: {i}\tTest Accuracy: {(labels == predictions).mean()}')
|
||||
print(f'Epoch: {i}\tVal Accuracy: {(labels == predictions).mean()}')
|
||||
|
||||
model.eval()
|
||||
labels = []
|
||||
predictions = []
|
||||
for batch in tqdm(test_loader):
|
||||
optimizer.zero_grad()
|
||||
|
||||
input_ids, input_mask, segment_ids = batch['x']
|
||||
input_ids = input_ids.to(device)
|
||||
input_mask = input_mask.to(device)
|
||||
segment_ids = segment_ids.to(device)
|
||||
|
||||
y = batch['y']
|
||||
y = y.to(device)
|
||||
|
||||
loss, logits = model.forward(input_ids, input_mask, segment_ids, labels=y)
|
||||
labels.append(torch.argmax(y, dim=1))
|
||||
predictions.append(torch.argmax(logits, dim=1))
|
||||
labels = torch.LongTensor(labels)
|
||||
predictions = torch.LongTensor(predictions)
|
||||
print(f'Epoch: {i}\tTest Accuracy: {(labels == predictions).mean()}')
|
||||
|
||||
Reference in New Issue
Block a user