Add --lr (learning rate) command-line parameter and report train accuracy.

This commit is contained in:
Anton Kiselev
2019-05-19 19:03:37 +03:00
parent 1801526b5a
commit 775560263a
+15 -4
View File
@@ -23,6 +23,8 @@ if __name__ == '__main__':
help='num threads (default: 4)')
parser.add_argument('--dropout', type=float, default=0.4, metavar='D',
help='dropout rate (default: 0.4)')
parser.add_argument('--lr', type=float, default=3e-3, metavar='D',
help='learning rate (default: 3e-3)')
parser.add_argument('--bert_model_name', type=str, default='bert-base-multilingual-cased',
help='Bert model name (ex. "bert-base-multilingual-cased")')
parser.add_argument('--tensorboard', type=str, default='default_tb', metavar='TB',
@@ -67,16 +69,18 @@ if __name__ == '__main__':
model = add_adapters(model, config)
model = ClassificationModel(model, n_labels=len(train_dataset.y_labels),
dropout_prob=0.0)
dropout_prob=args.dropout)
model.eval()
model.to(device)
optimizer = Adam(model.parameters(), lr=0.001, amsgrad=True)
optimizer = Adam(model.parameters(), lr=args.lr, amsgrad=True)
print('Model have initialized')
for i in range(args.num_epochs):
model.train()
labels = []
predictions = []
for batch in tqdm(train_loader):
optimizer.zero_grad()
@@ -88,10 +92,15 @@ if __name__ == '__main__':
y = batch['y']
y = y.to(device)
loss, _ = model.forward(input_ids, input_mask, segment_ids, labels=y)
loss, logits = model.forward(input_ids, input_mask, segment_ids, labels=y)
loss.backward()
optimizer.step()
labels.append(torch.argmax(y, dim=1))
predictions.append(torch.argmax(logits, dim=1))
labels = torch.cat(labels).long()
predictions = torch.cat(predictions).long()
train_accuracy = (labels == predictions).float().mean()
model.eval()
labels = []
@@ -112,7 +121,9 @@ if __name__ == '__main__':
predictions.append(torch.argmax(logits, dim=1))
labels = torch.cat(labels).long()
predictions = torch.cat(predictions).long()
print(f'Epoch: {i}\tVal Accuracy: {(labels == predictions).float().mean()}')
valid_accuracy = (labels == predictions).float().mean()
print(f'Epoch: {i}\tTrain Accuracy: {train_accuracy}\tVal Accuracy: {valid_accuracy}')
model.eval()
labels = []