Config changes

This commit is contained in:
Anton Kiselev
2019-05-19 18:53:06 +03:00
parent c368a4a90c
commit 1801526b5a
+6 -5
View File
@@ -18,7 +18,7 @@ if __name__ == '__main__':
parser.add_argument('--batch-size', type=int, default=10, metavar='S')
parser.add_argument('--n_workers', type=int, default=4, metavar='S')
parser.add_argument('--adapter_size', type=int,
default=8, metavar='S')
default=64, metavar='S')
parser.add_argument('--num-threads', type=int, default=4,
help='num threads (default: 4)')
parser.add_argument('--dropout', type=float, default=0.4, metavar='D',
@@ -47,8 +47,8 @@ if __name__ == '__main__':
train_dataset = TokenizedDataFrameDataset(tokenizer, file_path=args.train_file)
val_dataset = TokenizedDataFrameDataset(tokenizer, file_path=args.train_file)
train_dataset.df = train_dataset.df.iloc[:-1000]
val_dataset.df = val_dataset.df.iloc[-1000:]
train_dataset.df = train_dataset.df.iloc[:int(len(train_dataset) * 0.9)]
val_dataset.df = val_dataset.df.iloc[int(len(train_dataset) * 0.9):]
test_dataset = TokenizedDataFrameDataset(tokenizer, file_path=args.test_file)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size,
@@ -62,11 +62,12 @@ if __name__ == '__main__':
config = AdapterConfig(
hidden_size=768, adapter_size=args.adapter_size,
adapter_act='relu', adapter_initializer_range=0.1
adapter_act='relu', adapter_initializer_range=1e-2
)
model = add_adapters(model, config)
model = ClassificationModel(model, n_labels=len(train_dataset.y_labels), dropout_prob=0.3)
model = ClassificationModel(model, n_labels=len(train_dataset.y_labels),
dropout_prob=0.0)
model.eval()
model.to(device)