mirror of
https://github.com/wassname/bert_adapter.git
synced 2026-06-27 17:47:35 +08:00
148 lines
5.8 KiB
Python
148 lines
5.8 KiB
Python
import argparse
|
|
|
|
import torch
|
|
from pytorch_pretrained_bert import BertTokenizer, BertModel
|
|
from torch.optim import Adam
|
|
from torch.utils.data import DataLoader
|
|
from tensorboardX import SummaryWriter
|
|
from tqdm import tqdm
|
|
|
|
from data import TokenizedDataFrameDataset
|
|
from modules import add_adapters, AdapterConfig, ClassificationModel
|
|
|
|
def _move_batch(batch, device):
    """Unpack one DataLoader batch and move its tensors to ``device``.

    Returns ``(input_ids, input_mask, segment_ids, y)``. Assumes ``batch['x']``
    unpacks into exactly those three tensors and ``batch['y']`` holds per-class
    targets that are argmax'd over dim 1 (presumably one-hot) -- TODO confirm
    against ``data.TokenizedDataFrameDataset``.
    """
    input_ids, input_mask, segment_ids = batch['x']
    return (
        input_ids.to(device),
        input_mask.to(device),
        segment_ids.to(device),
        batch['y'].to(device),
    )


def _run_epoch(model, loader, device, optimizer=None):
    """Make one pass over ``loader`` and return ``(accuracy, mean_batch_loss)``.

    If ``optimizer`` is given, the model is switched to train mode and updated
    after every batch. Otherwise the model is switched to eval mode and run
    under ``torch.no_grad()`` so no autograd graph (or its activation memory)
    is kept. An empty loader yields ``(nan, nan)`` rather than letting
    ``torch.cat`` fail on an empty list.
    """
    training = optimizer is not None
    model.train(training)
    labels, predictions = [], []
    total_loss = 0.0
    grad_context = torch.enable_grad() if training else torch.no_grad()
    with grad_context:
        for batch in tqdm(loader):
            input_ids, input_mask, segment_ids, y = _move_batch(batch, device)
            if training:
                optimizer.zero_grad()
            loss, logits = model(input_ids, input_mask, segment_ids, labels=y)
            if training:
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            labels.append(torch.argmax(y, dim=1))
            predictions.append(torch.argmax(logits, dim=1))
    if not labels:
        return float('nan'), float('nan')
    n_batches = len(labels)
    all_labels = torch.cat(labels).long()
    all_predictions = torch.cat(predictions).long()
    accuracy = (all_labels == all_predictions).float().mean().item()
    return accuracy, total_loss / n_batches


def main():
    """Fine-tune a BERT-with-adapters classifier; report train/val/test accuracy."""
    parser = argparse.ArgumentParser(description='bert_adapter')
    parser.add_argument('--num_epochs', type=int, default=5, metavar='NI',
                        help='num epochs (default: 5)')
    parser.add_argument('--batch-size', type=int, default=10, metavar='S')
    parser.add_argument('--n_workers', type=int, default=4, metavar='S')
    parser.add_argument('--adapter_size', type=int,
                        default=64, metavar='S')
    parser.add_argument('--num-threads', type=int, default=4,
                        help='num threads (default: 4)')
    parser.add_argument('--dropout', type=float, default=0.4, metavar='D',
                        help='dropout rate (default: 0.4)')
    parser.add_argument('--lr', type=float, default=3e-3, metavar='D',
                        help='learning rate (default: 3e-3)')
    parser.add_argument('--bert_model_name', type=str, default='bert-base-multilingual-cased',
                        help='Bert model name (ex. "bert-base-multilingual-cased")')
    parser.add_argument('--tensorboard', type=str, default='default_tb', metavar='TB',
                        help='Name for tensorboard model')
    parser.add_argument('--train_file', type=str,
                        default='./data/rusentiment/rusentiment_random_posts.csv',
                        metavar='TB',
                        help='Path to RuSentiment train')
    parser.add_argument('--test_file', type=str,
                        default='./data/rusentiment/rusentiment_test.csv',
                        metavar='TB',
                        help='Path to RuSentiment test')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.set_num_threads(args.num_threads)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model_name)

    train_dataset = TokenizedDataFrameDataset(tokenizer, file_path=args.train_file)
    val_dataset = TokenizedDataFrameDataset(tokenizer, file_path=args.train_file)
    # 90/10 train/validation split of the train file. The split index is
    # computed once, from the full row count, before either frame is sliced:
    # recomputing it from the already-truncated train set would start the
    # validation slice at ~81% and overlap the training rows.
    split = int(len(train_dataset.df) * 0.9)
    train_dataset.df = train_dataset.df.iloc[:split]
    val_dataset.df = val_dataset.df.iloc[split:]
    test_dataset = TokenizedDataFrameDataset(tokenizer, file_path=args.test_file)

    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size,
                              num_workers=args.n_workers)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size,
                            num_workers=args.n_workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size,
                             num_workers=args.n_workers)

    # Load pre-trained model (weights)
    bert = BertModel.from_pretrained(args.bert_model_name)
    config = AdapterConfig(
        # Take the hidden width from the checkpoint itself so non-base models
        # (e.g. 1024 for bert-large) work; a literal 768 only fits *-base.
        hidden_size=bert.config.hidden_size, adapter_size=args.adapter_size,
        adapter_act='relu', adapter_initializer_range=1e-2
    )
    bert = add_adapters(bert, config)
    model = ClassificationModel(bert, n_labels=len(train_dataset.y_labels),
                                dropout_prob=args.dropout)
    model.to(device)

    # Hand Adam only trainable tensors; parameters frozen by add_adapters
    # (if it freezes any) would never receive gradients anyway.
    optimizer = Adam([p for p in model.parameters() if p.requires_grad],
                     lr=args.lr, amsgrad=True)

    print('Model have initialized')
    writer = SummaryWriter(args.tensorboard)
    try:
        for epoch in range(args.num_epochs):
            train_accuracy, train_loss = _run_epoch(model, train_loader, device, optimizer)
            valid_accuracy, valid_loss = _run_epoch(model, val_loader, device)
            writer.add_scalar('train/accuracy', train_accuracy, epoch)
            writer.add_scalar('train/loss', train_loss, epoch)
            writer.add_scalar('val/accuracy', valid_accuracy, epoch)
            writer.add_scalar('val/loss', valid_loss, epoch)
            print(f'Epoch: {epoch}\tTrain Accuracy: {train_accuracy}\tVal Accuracy: {valid_accuracy}')

        # Evaluated once after training; does not rely on the loop variable,
        # so it also works with --num_epochs 0.
        test_accuracy, test_loss = _run_epoch(model, test_loader, device)
        writer.add_scalar('test/accuracy', test_accuracy, args.num_epochs)
        writer.add_scalar('test/loss', test_loss, args.num_epochs)
        print(f'Test Accuracy: {test_accuracy}')
    finally:
        # Flush pending events to disk even if training is interrupted.
        writer.close()


if __name__ == '__main__':
    main()
|