mirror of
https://github.com/wassname/bert_adapter.git
synced 2026-06-27 19:30:10 +08:00
Config changes
This commit is contained in:
+6
-5
@@ -18,7 +18,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--batch-size', type=int, default=10, metavar='S')
|
||||
parser.add_argument('--n_workers', type=int, default=4, metavar='S')
|
||||
parser.add_argument('--adapter_size', type=int,
|
||||
default=8, metavar='S')
|
||||
default=64, metavar='S')
|
||||
parser.add_argument('--num-threads', type=int, default=4,
|
||||
help='num threads (default: 4)')
|
||||
parser.add_argument('--dropout', type=float, default=0.4, metavar='D',
|
||||
@@ -47,8 +47,8 @@ if __name__ == '__main__':
|
||||
|
||||
train_dataset = TokenizedDataFrameDataset(tokenizer, file_path=args.train_file)
|
||||
val_dataset = TokenizedDataFrameDataset(tokenizer, file_path=args.train_file)
|
||||
train_dataset.df = train_dataset.df.iloc[:-1000]
|
||||
val_dataset.df = val_dataset.df.iloc[-1000:]
|
||||
train_dataset.df = train_dataset.df.iloc[:int(len(train_dataset) * 0.9)]
|
||||
val_dataset.df = val_dataset.df.iloc[int(len(train_dataset) * 0.9):]
|
||||
test_dataset = TokenizedDataFrameDataset(tokenizer, file_path=args.test_file)
|
||||
|
||||
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size,
|
||||
@@ -62,11 +62,12 @@ if __name__ == '__main__':
|
||||
|
||||
config = AdapterConfig(
|
||||
hidden_size=768, adapter_size=args.adapter_size,
|
||||
adapter_act='relu', adapter_initializer_range=0.1
|
||||
adapter_act='relu', adapter_initializer_range=1e-2
|
||||
)
|
||||
model = add_adapters(model, config)
|
||||
|
||||
model = ClassificationModel(model, n_labels=len(train_dataset.y_labels), dropout_prob=0.3)
|
||||
model = ClassificationModel(model, n_labels=len(train_dataset.y_labels),
|
||||
dropout_prob=0.0)
|
||||
|
||||
model.eval()
|
||||
model.to(device)
|
||||
|
||||
Reference in New Issue
Block a user