mirror of
https://github.com/wassname/bert_adapter.git
synced 2026-06-27 19:30:10 +08:00
Initial implementation.
This commit is contained in:
@@ -0,0 +1,70 @@
|
||||
from typing import NamedTuple
|
||||
|
||||
import torch
|
||||
import pandas as pd
|
||||
from torch.utils.data import Dataset
|
||||
from pytorch_pretrained_bert import BertTokenizer
|
||||
|
||||
|
||||
class BertInput(NamedTuple):
|
||||
input_ids: torch.LongTensor
|
||||
input_mask: torch.LongTensor
|
||||
segment_ids: torch.LongTensor
|
||||
|
||||
|
||||
class TokenizedDataFrameDataset(Dataset):
|
||||
def __init__(self,
|
||||
tokenizer: BertTokenizer,
|
||||
file_path: str,
|
||||
x_label: str = 'text',
|
||||
y_label: str = 'label',
|
||||
max_seq_len: int = 20):
|
||||
"""
|
||||
:param data_path: path to data
|
||||
"""
|
||||
self.tokenizer = tokenizer
|
||||
self.x_label = x_label
|
||||
self.y_label = y_label
|
||||
self.max_seq_len = max_seq_len
|
||||
|
||||
self.df = pd.read_csv(file_path)
|
||||
self.df[y_label] = self.df[y_label].astype('category')
|
||||
|
||||
self.y_labels = self.df[y_label].cat.categories
|
||||
self.df[y_label] = self.df[y_label].cat.codes
|
||||
|
||||
def preprocess_text(self, text: str) -> BertInput:
|
||||
tokens = self.tokenizer.tokenize(text)
|
||||
tokens = ["[CLS]"] + tokens + ["[SEP]"]
|
||||
input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
|
||||
|
||||
segment_ids = [0] * len(tokens)
|
||||
input_mask = [1] * len(input_ids)
|
||||
|
||||
padding = [0] * (self.max_seq_len - len(input_ids))
|
||||
input_ids += padding
|
||||
input_mask += padding
|
||||
segment_ids += padding
|
||||
|
||||
assert len(input_ids) == self.max_seq_len
|
||||
assert len(input_mask) == self.max_seq_len
|
||||
assert len(segment_ids) == self.max_seq_len
|
||||
|
||||
return BertInput(*[torch.LongTensor(_) for _ in [input_ids, input_mask, segment_ids]])
|
||||
|
||||
def preprocess_label(self, label: int):
|
||||
result = torch.zeros(len(self.y_labels))
|
||||
result[label] = 1
|
||||
return result
|
||||
|
||||
def __getitem__(self, index) -> dict:
|
||||
sample = self.df.iloc[index]
|
||||
x = sample[self.x_label]
|
||||
y = sample[self.y_label]
|
||||
return {
|
||||
'x': self.preprocess_text(x),
|
||||
'y': self.preprocess_label(y)
|
||||
}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.df)
|
||||
+98
@@ -0,0 +1,98 @@
|
||||
import logging
|
||||
from typing import NamedTuple, Callable, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from pytorch_pretrained_bert.modeling import ACT2FN, BertLayerNorm, BertModel, BertSelfOutput
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
class AdapterConfig(NamedTuple):
|
||||
hidden_size: int
|
||||
adapter_size: int
|
||||
adapter_act: Union[str, Callable]
|
||||
adapter_initializer_range: float
|
||||
|
||||
|
||||
class Adapter(nn.Module):
|
||||
def __init__(self, config: AdapterConfig):
|
||||
super(Adapter, self).__init__()
|
||||
self.down_project = nn.Linear(config.hidden_size, config.adapter_size)
|
||||
nn.init.normal_(self.down_project.weight, std=config.adapter_initializer_range)
|
||||
nn.init.zeros_(self.down_project.bias)
|
||||
|
||||
if isinstance(config.adapter_act, str):
|
||||
self.activation = ACT2FN[config.adapter_act]
|
||||
else:
|
||||
self.activation = config.adapter_act
|
||||
|
||||
self.up_project = nn.Linear(config.adapter_size, config.hidden_size)
|
||||
nn.init.normal_(self.up_project.weight, std=config.adapter_initializer_range)
|
||||
nn.init.zeros_(self.up_project.bias)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
down_projected = self.down_project(hidden_states)
|
||||
activated = self.activation(down_projected)
|
||||
up_projected = self.up_project(activated)
|
||||
return hidden_states + up_projected
|
||||
|
||||
|
||||
class BertAdaptedSelfOutput(nn.Module):
|
||||
def __init__(self,
|
||||
self_output: BertSelfOutput,
|
||||
config: AdapterConfig):
|
||||
super(BertAdaptedSelfOutput, self).__init__()
|
||||
self.self_output = self_output
|
||||
self.adapter = Adapter(config)
|
||||
|
||||
def forward(self, hidden_states, input_tensor):
|
||||
hidden_states = self.self_output.dense(hidden_states)
|
||||
hidden_states = self.self_output.dropout(hidden_states)
|
||||
hidden_states = self.adapter(hidden_states)
|
||||
hidden_states = self.self_output.LayerNorm(hidden_states + input_tensor)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def adapt_bert_self_output(config: AdapterConfig):
|
||||
return lambda self_output: BertAdaptedSelfOutput(self_output, config=config)
|
||||
|
||||
|
||||
def add_adapters(bert_model: BertModel,
|
||||
config: AdapterConfig) -> BertModel:
|
||||
bert_encoder = bert_model.encoder
|
||||
for i in range(len(bert_model.encoder.layer)):
|
||||
bert_encoder.layer[i].attention.output = adapt_bert_self_output(config)(
|
||||
bert_encoder.layer[i].attention.output)
|
||||
|
||||
# Freeze all parameters
|
||||
for param in bert_model.parameters():
|
||||
param.requires_grad = False
|
||||
# Unfreeze trainable parts — layer norms and adapters
|
||||
for name, sub_module in bert_model.named_modules():
|
||||
if isinstance(sub_module, (Adapter, BertLayerNorm)):
|
||||
for param_name, param in sub_module.named_parameters():
|
||||
param.requires_grad = True
|
||||
return bert_model
|
||||
|
||||
|
||||
class ClassificationModel(nn.Module):
|
||||
def __init__(self, bert: BertModel, n_labels: int, dropout_prob: float):
|
||||
super(ClassificationModel, self).__init__()
|
||||
self.n_labels = n_labels
|
||||
self.bert = bert
|
||||
|
||||
self.dropout = nn.Dropout(dropout_prob)
|
||||
self.classifier = nn.Linear(bert.pooler.dense.out_features, n_labels)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
|
||||
_, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
|
||||
output_all_encoded_layers=False)
|
||||
pooled_output = self.dropout(pooled_output)
|
||||
logits = self.classifier(pooled_output)
|
||||
|
||||
if labels is not None:
|
||||
loss_function = nn.CrossEntropyLoss()
|
||||
loss = loss_function(logits.view(-1, self.n_labels), labels.view(-1))
|
||||
return loss, logits
|
||||
else:
|
||||
return logits
|
||||
@@ -0,0 +1 @@
|
||||
tensorboardX
|
||||
+95
@@ -0,0 +1,95 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from pytorch_pretrained_bert import BertTokenizer, BertModel
|
||||
from torch.optim import Adam
|
||||
from torch.utils.data import DataLoader
|
||||
from tensorboardX import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
from data import TokenizedDataFrameDataset
|
||||
from modules import add_adapters, AdapterConfig, ClassificationModel
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# TODO parameters
|
||||
parser = argparse.ArgumentParser(description='bert_adapter')
|
||||
parser.add_argument('--num_epochs', type=int, default=5, metavar='NI',
|
||||
help='num epochs (default: 5)')
|
||||
parser.add_argument('--batch-size', type=int, default=50, metavar='S')
|
||||
parser.add_argument('--n_workers', type=int, default=4, metavar='S')
|
||||
parser.add_argument('--num-threads', type=int, default=4, metavar='BS',
|
||||
help='num threads (default: 4)')
|
||||
parser.add_argument('--dropout', type=float, default=0.4, metavar='D',
|
||||
help='dropout rate (default: 0.4)')
|
||||
parser.add_argument('--tensorboard', type=str, default='default_tb', metavar='TB',
|
||||
help='Name for tensorboard model')
|
||||
parser.add_argument('--train_file', type=str,
|
||||
default='./data/rusentiment/rusentiment_random_posts.csv',
|
||||
metavar='TB',
|
||||
help='Path to RuSentiment train')
|
||||
parser.add_argument('--test_file', type=str,
|
||||
default='./data/rusentiment/rusentiment_test.csv',
|
||||
metavar='TB',
|
||||
help='Path to RuSentiment test')
|
||||
args = parser.parse_args()
|
||||
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
writer = SummaryWriter(args.tensorboard)
|
||||
|
||||
torch.set_num_threads(args.num_threads)
|
||||
|
||||
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||
|
||||
train_dataset = TokenizedDataFrameDataset(tokenizer, file_path=args.train_file)
|
||||
test_dataset = TokenizedDataFrameDataset(tokenizer, file_path=args.test_file)
|
||||
|
||||
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size, num_workers=args.n_workers)
|
||||
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.n_workers)
|
||||
|
||||
# Load pre-trained model (weights)
|
||||
model = BertModel.from_pretrained('bert-base-uncased')
|
||||
|
||||
config = AdapterConfig(
|
||||
hidden_size=768, adapter_size=5,
|
||||
adapter_act='relu', adapter_initializer_range=0.1
|
||||
)
|
||||
model = add_adapters(model, config)
|
||||
|
||||
model = ClassificationModel(model, n_labels=len(train_dataset.y_labels), dropout_prob=0.3)
|
||||
|
||||
model.eval()
|
||||
model.to(device)
|
||||
|
||||
optimizer = Adam(model.learnable_parameters(), lr=0.001, amsgrad=True)
|
||||
|
||||
print('Model have initialized')
|
||||
for i in range(args.num_epochs):
|
||||
model.train()
|
||||
for batch in tqdm(train_loader):
|
||||
optimizer.zero_grad()
|
||||
|
||||
input_ids, input_mask, segment_ids = batch['x']
|
||||
y = batch['y']
|
||||
|
||||
loss, _ = model.forward(input_ids, input_mask, segment_ids, labels=y)
|
||||
loss.backward()
|
||||
|
||||
optimizer.step()
|
||||
|
||||
model.eval()
|
||||
labels = []
|
||||
predictions = []
|
||||
for batch in test_loader:
|
||||
optimizer.zero_grad()
|
||||
|
||||
input_ids, input_mask, segment_ids = batch['x']
|
||||
y = batch['y']
|
||||
|
||||
loss, logits = model.forward(input_ids, input_mask, segment_ids, labels=y)
|
||||
labels.append(torch.argmax(y, dim=1))
|
||||
predictions.append(torch.argmax(logits, dim=1))
|
||||
labels = torch.LongTensor(labels)
|
||||
predictions = torch.LongTensor(predictions)
|
||||
print(f'Epoch: {i}\tTest Accuracy: {(labels == predictions).mean()}')
|
||||
Reference in New Issue
Block a user