mirror of
https://github.com/wassname/bert_adapter.git
synced 2026-06-27 17:00:00 +08:00
LR parameter, train accuracy.
This commit is contained in:
+15
-4
@@ -23,6 +23,8 @@ if __name__ == '__main__':
|
||||
help='num threads (default: 4)')
|
||||
parser.add_argument('--dropout', type=float, default=0.4, metavar='D',
|
||||
help='dropout rate (default: 0.4)')
|
||||
parser.add_argument('--lr', type=float, default=3e-3, metavar='D',
|
||||
help='learning rate (default: 3e-3)')
|
||||
parser.add_argument('--bert_model_name', type=str, default='bert-base-multilingual-cased',
|
||||
help='Bert model name (ex. "bert-base-multilingual-cased")')
|
||||
parser.add_argument('--tensorboard', type=str, default='default_tb', metavar='TB',
|
||||
@@ -67,16 +69,18 @@ if __name__ == '__main__':
|
||||
model = add_adapters(model, config)
|
||||
|
||||
model = ClassificationModel(model, n_labels=len(train_dataset.y_labels),
|
||||
dropout_prob=0.0)
|
||||
dropout_prob=args.dropout)
|
||||
|
||||
model.eval()
|
||||
model.to(device)
|
||||
|
||||
optimizer = Adam(model.parameters(), lr=0.001, amsgrad=True)
|
||||
optimizer = Adam(model.parameters(), lr=args.lr, amsgrad=True)
|
||||
|
||||
print('Model have initialized')
|
||||
for i in range(args.num_epochs):
|
||||
model.train()
|
||||
labels = []
|
||||
predictions = []
|
||||
for batch in tqdm(train_loader):
|
||||
optimizer.zero_grad()
|
||||
|
||||
@@ -88,10 +92,15 @@ if __name__ == '__main__':
|
||||
y = batch['y']
|
||||
y = y.to(device)
|
||||
|
||||
loss, _ = model.forward(input_ids, input_mask, segment_ids, labels=y)
|
||||
loss, logits = model.forward(input_ids, input_mask, segment_ids, labels=y)
|
||||
loss.backward()
|
||||
|
||||
optimizer.step()
|
||||
labels.append(torch.argmax(y, dim=1))
|
||||
predictions.append(torch.argmax(logits, dim=1))
|
||||
labels = torch.cat(labels).long()
|
||||
predictions = torch.cat(predictions).long()
|
||||
train_accuracy = (labels == predictions).float().mean()
|
||||
|
||||
model.eval()
|
||||
labels = []
|
||||
@@ -112,7 +121,9 @@ if __name__ == '__main__':
|
||||
predictions.append(torch.argmax(logits, dim=1))
|
||||
labels = torch.cat(labels).long()
|
||||
predictions = torch.cat(predictions).long()
|
||||
print(f'Epoch: {i}\tVal Accuracy: {(labels == predictions).float().mean()}')
|
||||
|
||||
valid_accuracy = (labels == predictions).float().mean()
|
||||
print(f'Epoch: {i}\tTrain Accuracy: {train_accuracy}\tVal Accuracy: {valid_accuracy}')
|
||||
|
||||
model.eval()
|
||||
labels = []
|
||||
|
||||
Reference in New Issue
Block a user