mirror of
https://github.com/wassname/bert_adapter.git
synced 2026-06-27 16:44:45 +08:00
Changes in configuration.
This commit is contained in:
+3
-1
@@ -1 +1,3 @@
|
|||||||
tensorboardX
|
tensorboardX
|
||||||
|
pytorch_pretrained_bert
|
||||||
|
tqdm
|
||||||
+7
-5
@@ -16,12 +16,14 @@ if __name__ == '__main__':
|
|||||||
parser = argparse.ArgumentParser(description='bert_adapter')
|
parser = argparse.ArgumentParser(description='bert_adapter')
|
||||||
parser.add_argument('--num_epochs', type=int, default=5, metavar='NI',
|
parser.add_argument('--num_epochs', type=int, default=5, metavar='NI',
|
||||||
help='num epochs (default: 5)')
|
help='num epochs (default: 5)')
|
||||||
parser.add_argument('--batch-size', type=int, default=50, metavar='S')
|
parser.add_argument('--batch-size', type=int, default=10, metavar='S')
|
||||||
parser.add_argument('--n_workers', type=int, default=4, metavar='S')
|
parser.add_argument('--n_workers', type=int, default=2, metavar='S')
|
||||||
parser.add_argument('--num-threads', type=int, default=4, metavar='BS',
|
parser.add_argument('--num-threads', type=int, default=2,
|
||||||
help='num threads (default: 4)')
|
help='num threads (default: 4)')
|
||||||
parser.add_argument('--dropout', type=float, default=0.4, metavar='D',
|
parser.add_argument('--dropout', type=float, default=0.4, metavar='D',
|
||||||
help='dropout rate (default: 0.4)')
|
help='dropout rate (default: 0.4)')
|
||||||
|
parser.add_argument('--bert_model_name', type=str, default='bert-base-multilingual-cased',
|
||||||
|
help='Bert model name (ex. "bert-base-multilingual-cased")')
|
||||||
parser.add_argument('--tensorboard', type=str, default='default_tb', metavar='TB',
|
parser.add_argument('--tensorboard', type=str, default='default_tb', metavar='TB',
|
||||||
help='Name for tensorboard model')
|
help='Name for tensorboard model')
|
||||||
parser.add_argument('--train_file', type=str,
|
parser.add_argument('--train_file', type=str,
|
||||||
@@ -40,7 +42,7 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
torch.set_num_threads(args.num_threads)
|
torch.set_num_threads(args.num_threads)
|
||||||
|
|
||||||
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
tokenizer = BertTokenizer.from_pretrained(args.bert_model_name)
|
||||||
|
|
||||||
train_dataset = TokenizedDataFrameDataset(tokenizer, file_path=args.train_file)
|
train_dataset = TokenizedDataFrameDataset(tokenizer, file_path=args.train_file)
|
||||||
test_dataset = TokenizedDataFrameDataset(tokenizer, file_path=args.test_file)
|
test_dataset = TokenizedDataFrameDataset(tokenizer, file_path=args.test_file)
|
||||||
@@ -49,7 +51,7 @@ if __name__ == '__main__':
|
|||||||
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.n_workers)
|
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.n_workers)
|
||||||
|
|
||||||
# Load pre-trained model (weights)
|
# Load pre-trained model (weights)
|
||||||
model = BertModel.from_pretrained('bert-base-uncased')
|
model = BertModel.from_pretrained(args.bert_model_name)
|
||||||
|
|
||||||
config = AdapterConfig(
|
config = AdapterConfig(
|
||||||
hidden_size=768, adapter_size=5,
|
hidden_size=768, adapter_size=5,
|
||||||
|
|||||||
Reference in New Issue
Block a user