Tweaks in configs

This commit is contained in:
Anton Kiselev
2019-05-19 18:14:24 +03:00
parent 96d37e23e5
commit 79f3878488
+5 -5
View File
@@ -12,13 +12,14 @@ from modules import add_adapters, AdapterConfig, ClassificationModel
if __name__ == '__main__':
# TODO parameters
parser = argparse.ArgumentParser(description='bert_adapter')
parser.add_argument('--num_epochs', type=int, default=5, metavar='NI',
help='num epochs (default: 5)')
parser.add_argument('--batch-size', type=int, default=10, metavar='S')
parser.add_argument('--n_workers', type=int, default=2, metavar='S')
parser.add_argument('--num-threads', type=int, default=2,
parser.add_argument('--n_workers', type=int, default=4, metavar='S')
parser.add_argument('--adapter_size', type=int,
default=8, metavar='S')
parser.add_argument('--num-threads', type=int, default=4,
help='num threads (default: 4)')
parser.add_argument('--dropout', type=float, default=0.4, metavar='D',
help='dropout rate (default: 0.4)')
@@ -54,7 +55,7 @@ if __name__ == '__main__':
model = BertModel.from_pretrained(args.bert_model_name)
config = AdapterConfig(
hidden_size=768, adapter_size=5,
hidden_size=768, adapter_size=args.adapter_size,
adapter_act='relu', adapter_initializer_range=0.1
)
model = add_adapters(model, config)
@@ -99,7 +100,6 @@ if __name__ == '__main__':
y = batch['y']
y = y.to(device)
loss, logits = model.forward(input_ids, input_mask, segment_ids, labels=y)
labels.append(torch.argmax(y, dim=1))
predictions.append(torch.argmax(logits, dim=1))