mirror of
https://github.com/wassname/bert_adapter.git
synced 2026-06-27 16:29:32 +08:00
Tweaks in configs
This commit is contained in:
+5
-5
@@ -12,13 +12,14 @@ from modules import add_adapters, AdapterConfig, ClassificationModel
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# TODO parameters
|
||||
parser = argparse.ArgumentParser(description='bert_adapter')
|
||||
parser.add_argument('--num_epochs', type=int, default=5, metavar='NI',
|
||||
help='num epochs (default: 5)')
|
||||
parser.add_argument('--batch-size', type=int, default=10, metavar='S')
|
||||
parser.add_argument('--n_workers', type=int, default=2, metavar='S')
|
||||
parser.add_argument('--num-threads', type=int, default=2,
|
||||
parser.add_argument('--n_workers', type=int, default=4, metavar='S')
|
||||
parser.add_argument('--adapter_size', type=int,
|
||||
default=8, metavar='S')
|
||||
parser.add_argument('--num-threads', type=int, default=4,
|
||||
help='num threads (default: 4)')
|
||||
parser.add_argument('--dropout', type=float, default=0.4, metavar='D',
|
||||
help='dropout rate (default: 0.4)')
|
||||
@@ -54,7 +55,7 @@ if __name__ == '__main__':
|
||||
model = BertModel.from_pretrained(args.bert_model_name)
|
||||
|
||||
config = AdapterConfig(
|
||||
hidden_size=768, adapter_size=5,
|
||||
hidden_size=768, adapter_size=args.adapter_size,
|
||||
adapter_act='relu', adapter_initializer_range=0.1
|
||||
)
|
||||
model = add_adapters(model, config)
|
||||
@@ -99,7 +100,6 @@ if __name__ == '__main__':
|
||||
y = batch['y']
|
||||
y = y.to(device)
|
||||
|
||||
|
||||
loss, logits = model.forward(input_ids, input_mask, segment_ids, labels=y)
|
||||
labels.append(torch.argmax(y, dim=1))
|
||||
predictions.append(torch.argmax(logits, dim=1))
|
||||
|
||||
Reference in New Issue
Block a user