Initial implementation.

This commit is contained in:
Anton Kiselev
2019-05-18 14:28:39 +03:00
parent dc0c66246b
commit fc96b0b092
5 changed files with 264 additions and 0 deletions
+70
View File
@@ -0,0 +1,70 @@
from typing import NamedTuple
import torch
import pandas as pd
from torch.utils.data import Dataset
from pytorch_pretrained_bert import BertTokenizer
class BertInput(NamedTuple):
input_ids: torch.LongTensor
input_mask: torch.LongTensor
segment_ids: torch.LongTensor
class TokenizedDataFrameDataset(Dataset):
def __init__(self,
tokenizer: BertTokenizer,
file_path: str,
x_label: str = 'text',
y_label: str = 'label',
max_seq_len: int = 20):
"""
:param data_path: path to data
"""
self.tokenizer = tokenizer
self.x_label = x_label
self.y_label = y_label
self.max_seq_len = max_seq_len
self.df = pd.read_csv(file_path)
self.df[y_label] = self.df[y_label].astype('category')
self.y_labels = self.df[y_label].cat.categories
self.df[y_label] = self.df[y_label].cat.codes
def preprocess_text(self, text: str) -> BertInput:
tokens = self.tokenizer.tokenize(text)
tokens = ["[CLS]"] + tokens + ["[SEP]"]
input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
segment_ids = [0] * len(tokens)
input_mask = [1] * len(input_ids)
padding = [0] * (self.max_seq_len - len(input_ids))
input_ids += padding
input_mask += padding
segment_ids += padding
assert len(input_ids) == self.max_seq_len
assert len(input_mask) == self.max_seq_len
assert len(segment_ids) == self.max_seq_len
return BertInput(*[torch.LongTensor(_) for _ in [input_ids, input_mask, segment_ids]])
def preprocess_label(self, label: int):
result = torch.zeros(len(self.y_labels))
result[label] = 1
return result
def __getitem__(self, index) -> dict:
sample = self.df.iloc[index]
x = sample[self.x_label]
y = sample[self.y_label]
return {
'x': self.preprocess_text(x),
'y': self.preprocess_label(y)
}
def __len__(self):
return len(self.df)
View File
+98
View File
@@ -0,0 +1,98 @@
import logging
from typing import NamedTuple, Callable, Union
import torch.nn as nn
from pytorch_pretrained_bert.modeling import ACT2FN, BertLayerNorm, BertModel, BertSelfOutput
logging.basicConfig(level=logging.INFO)
class AdapterConfig(NamedTuple):
hidden_size: int
adapter_size: int
adapter_act: Union[str, Callable]
adapter_initializer_range: float
class Adapter(nn.Module):
def __init__(self, config: AdapterConfig):
super(Adapter, self).__init__()
self.down_project = nn.Linear(config.hidden_size, config.adapter_size)
nn.init.normal_(self.down_project.weight, std=config.adapter_initializer_range)
nn.init.zeros_(self.down_project.bias)
if isinstance(config.adapter_act, str):
self.activation = ACT2FN[config.adapter_act]
else:
self.activation = config.adapter_act
self.up_project = nn.Linear(config.adapter_size, config.hidden_size)
nn.init.normal_(self.up_project.weight, std=config.adapter_initializer_range)
nn.init.zeros_(self.up_project.bias)
def forward(self, hidden_states):
down_projected = self.down_project(hidden_states)
activated = self.activation(down_projected)
up_projected = self.up_project(activated)
return hidden_states + up_projected
class BertAdaptedSelfOutput(nn.Module):
def __init__(self,
self_output: BertSelfOutput,
config: AdapterConfig):
super(BertAdaptedSelfOutput, self).__init__()
self.self_output = self_output
self.adapter = Adapter(config)
def forward(self, hidden_states, input_tensor):
hidden_states = self.self_output.dense(hidden_states)
hidden_states = self.self_output.dropout(hidden_states)
hidden_states = self.adapter(hidden_states)
hidden_states = self.self_output.LayerNorm(hidden_states + input_tensor)
return hidden_states
def adapt_bert_self_output(config: AdapterConfig):
return lambda self_output: BertAdaptedSelfOutput(self_output, config=config)
def add_adapters(bert_model: BertModel,
config: AdapterConfig) -> BertModel:
bert_encoder = bert_model.encoder
for i in range(len(bert_model.encoder.layer)):
bert_encoder.layer[i].attention.output = adapt_bert_self_output(config)(
bert_encoder.layer[i].attention.output)
# Freeze all parameters
for param in bert_model.parameters():
param.requires_grad = False
# Unfreeze trainable parts — layer norms and adapters
for name, sub_module in bert_model.named_modules():
if isinstance(sub_module, (Adapter, BertLayerNorm)):
for param_name, param in sub_module.named_parameters():
param.requires_grad = True
return bert_model
class ClassificationModel(nn.Module):
def __init__(self, bert: BertModel, n_labels: int, dropout_prob: float):
super(ClassificationModel, self).__init__()
self.n_labels = n_labels
self.bert = bert
self.dropout = nn.Dropout(dropout_prob)
self.classifier = nn.Linear(bert.pooler.dense.out_features, n_labels)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
_, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
output_all_encoded_layers=False)
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
if labels is not None:
loss_function = nn.CrossEntropyLoss()
loss = loss_function(logits.view(-1, self.n_labels), labels.view(-1))
return loss, logits
else:
return logits
+1
View File
@@ -0,0 +1 @@
tensorboardX
+95
View File
@@ -0,0 +1,95 @@
import argparse
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel
from torch.optim import Adam
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from tqdm import tqdm
from data import TokenizedDataFrameDataset
from modules import add_adapters, AdapterConfig, ClassificationModel
if __name__ == '__main__':
# TODO parameters
parser = argparse.ArgumentParser(description='bert_adapter')
parser.add_argument('--num_epochs', type=int, default=5, metavar='NI',
help='num epochs (default: 5)')
parser.add_argument('--batch-size', type=int, default=50, metavar='S')
parser.add_argument('--n_workers', type=int, default=4, metavar='S')
parser.add_argument('--num-threads', type=int, default=4, metavar='BS',
help='num threads (default: 4)')
parser.add_argument('--dropout', type=float, default=0.4, metavar='D',
help='dropout rate (default: 0.4)')
parser.add_argument('--tensorboard', type=str, default='default_tb', metavar='TB',
help='Name for tensorboard model')
parser.add_argument('--train_file', type=str,
default='./data/rusentiment/rusentiment_random_posts.csv',
metavar='TB',
help='Path to RuSentiment train')
parser.add_argument('--test_file', type=str,
default='./data/rusentiment/rusentiment_test.csv',
metavar='TB',
help='Path to RuSentiment test')
args = parser.parse_args()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
writer = SummaryWriter(args.tensorboard)
torch.set_num_threads(args.num_threads)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
train_dataset = TokenizedDataFrameDataset(tokenizer, file_path=args.train_file)
test_dataset = TokenizedDataFrameDataset(tokenizer, file_path=args.test_file)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size, num_workers=args.n_workers)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.n_workers)
# Load pre-trained model (weights)
model = BertModel.from_pretrained('bert-base-uncased')
config = AdapterConfig(
hidden_size=768, adapter_size=5,
adapter_act='relu', adapter_initializer_range=0.1
)
model = add_adapters(model, config)
model = ClassificationModel(model, n_labels=len(train_dataset.y_labels), dropout_prob=0.3)
model.eval()
model.to(device)
optimizer = Adam(model.learnable_parameters(), lr=0.001, amsgrad=True)
print('Model have initialized')
for i in range(args.num_epochs):
model.train()
for batch in tqdm(train_loader):
optimizer.zero_grad()
input_ids, input_mask, segment_ids = batch['x']
y = batch['y']
loss, _ = model.forward(input_ids, input_mask, segment_ids, labels=y)
loss.backward()
optimizer.step()
model.eval()
labels = []
predictions = []
for batch in test_loader:
optimizer.zero_grad()
input_ids, input_mask, segment_ids = batch['x']
y = batch['y']
loss, logits = model.forward(input_ids, input_mask, segment_ids, labels=y)
labels.append(torch.argmax(y, dim=1))
predictions.append(torch.argmax(logits, dim=1))
labels = torch.LongTensor(labels)
predictions = torch.LongTensor(predictions)
print(f'Epoch: {i}\tTest Accuracy: {(labels == predictions).mean()}')