This commit is contained in:
ekurtulus
2023-01-11 11:37:27 +03:00
parent bdb4762359
commit 5b77dd2e9f
5 changed files with 202 additions and 38 deletions
@@ -1,3 +1,4 @@
import numpy as np
from datasets import load_dataset
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, Subset
@@ -5,19 +6,100 @@ from torch.utils.data import Dataset, Subset
from .prompt_dialogue import PromptGeneratedDataset
# Marker strings keyed by the role they tag ("Question" / "Answer").
QA_SPECIAL_TOKENS = {"Question": "<question>", "Answer": "<answer>"}
# Strings placed before the source text ("Text", currently empty) and before
# the summary ("Summary") when SummarizationDataset builds a sample.
SUMMARIZATION_SPECIAL_TOKENS = {"Text": "", "Summary": "TL;DR:"}
# Dataset name -> pair of column names read from each record.
# NOTE(review): every pair looks ordered (source-text column, summary column),
# e.g. ("article", "highlights") — make sure consumers unpack in that order.
summarization_name_mapping = {
"cnn_dailymail": ("article", "highlights"),
"samsum": ("dialogue", "summary"),
"xsum": ("document", "summary"),
"multi_news": ("document", "summary"),
"scitldr": ("source", "target"),
"billsum": ("text", "summary"),
"reddit": ("content", "summary"),
}
# Dataset name -> extra positional arguments splatted into load_dataset right
# after the dataset name; an empty tuple means no extra arguments are passed.
summarization_config_mapping = {
"cnn_dailymail": ("3.0.0",),
"samsum": (),
"xsum": (),
"multi_news": (),
"scitldr": ("AIC",),
"billsum": (),
"reddit": (),
}
class SquadV2Dataset(Dataset):
    """Thin wrapper that loads one split of the ``squad_v2`` dataset."""

    def __init__(self, cache_dir, split):
        """Load ``squad_v2`` for ``split``, caching files under ``cache_dir``."""
        squad_split = load_dataset("squad_v2", cache_dir=cache_dir, split=split)
        self.dataset = squad_split
def index_squad_v2(example):
    """Turn one SQuAD-v2-style record into a ``(prompt, answer)`` pair.

    Args:
        example: Mapping with string fields ``"title"``, ``"context"`` and
            ``"question"``, plus ``example["answers"]["text"]``, a sequence of
            acceptable answer strings.

    Returns:
        A tuple ``(prompt, answer)`` where ``prompt`` is
        ``"<title>. <context> <question>"`` and ``answer`` is the first listed
        answer, or ``""`` when no answer is listed.
    """
    prompt = example["title"] + ". " + example["context"] + " " + example["question"]
    answer_texts = example["answers"]["text"]
    # Indexing [0] on an empty answer list raises IndexError, so fall back to
    # an empty target instead.
    # NOTE(review): confirm "" is the desired target for records without answers.
    answer = answer_texts[0] if len(answer_texts) > 0 else ""
    return prompt, answer
def index_trivia_qa_nocontext(example):
    """Return ``(question, alias)`` for a context-free TriviaQA record.

    The answer is one entry of ``example["answer"]["aliases"]`` chosen
    uniformly at random with ``np.random.randint``, so repeated calls on the
    same record may return different aliases.
    """
    aliases = example["answer"]["aliases"]
    # Placeholder strategy: pick any one alias at random.
    chosen = np.random.randint(len(aliases))
    return example["question"], aliases[chosen]
def index_trivia_qa_context(example):
    """Return ``(prompt, alias)`` for a TriviaQA record that carries context.

    One title, one search context and one answer alias are each drawn at
    random (in that order) with ``np.random.randint``; the prompt is
    ``"<title>. <context> <question>"``.

    NOTE(review): this reads ``example["title"]`` and
    ``example["search_context"]`` as top-level lists — confirm the loaded
    records expose them at that level rather than nested in sub-dicts.
    """
    question = example["question"]

    titles = example["title"]
    picked_title = titles[np.random.randint(len(titles))]

    contexts = example["search_context"]
    picked_context = contexts[np.random.randint(len(contexts))]

    aliases = example["answer"]["aliases"]
    picked_answer = aliases[np.random.randint(len(aliases))]

    prompt = picked_title + ". " + picked_context + " " + question
    return prompt, picked_answer
def index_adversarial_qa(example):
    """Return ``(prompt, first_answer)`` for one adversarial_qa record.

    The prompt is ``"<title>. <context> <question>"``; the answer is the first
    entry of ``example["answers"]["text"]``.
    """
    title = example["title"]
    context = example["context"]
    question = example["question"]
    prompt = title + ". " + context + " " + question
    first_answer = example["answers"]["text"][0]
    return prompt, first_answer
class QADataset(Dataset):
    """Map-style dataset that yields ``(prompt, answer)`` pairs for a QA corpus.

    Each supported dataset name selects both the Hugging Face dataset to load
    and the ``index_*`` function that converts a raw record into a pair.
    """

    def __init__(self, dataset, cache_dir, split):
        """Load ``split`` of the requested QA dataset.

        Args:
            dataset: One of ``"squad_v2"``, ``"trivia_qa_nocontext"``,
                ``"trivia_qa_context"`` or ``"adversarial_qa"``.
            cache_dir: Directory handed to ``load_dataset`` as ``cache_dir``.
            split: Split name handed to ``load_dataset`` as ``split``.

        Raises:
            ValueError: If ``dataset`` is not a supported name.
        """
        if dataset == "squad_v2":
            self.index_fn = index_squad_v2
            load_args = ("squad_v2",)
        # "trivia_qa_noconext" is the spelling used by the dataset-name lists
        # elsewhere in this change set; accept it so those callers do not fall
        # through to the ValueError below.
        elif dataset in ("trivia_qa_nocontext", "trivia_qa_noconext"):
            self.index_fn = index_trivia_qa_nocontext
            load_args = ("trivia_qa", "rc.nocontext")
        elif dataset == "trivia_qa_context":
            self.index_fn = index_trivia_qa_context
            load_args = ("trivia_qa", "rc")
        elif dataset == "adversarial_qa":
            self.index_fn = index_adversarial_qa
            load_args = ("adversarial_qa", "adversarialQA")
        else:
            raise ValueError("Unknown dataset : " + dataset)
        # Every branch forwards cache_dir and split, so self.dataset is always
        # the single requested split and the cache location is honoured.
        self.dataset = load_dataset(*load_args, cache_dir=cache_dir, split=split)

    def __len__(self):
        """Return the number of records in the loaded split."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the ``(prompt, answer)`` pair for record ``idx``."""
        data = self.dataset[idx]
        # Delegate to the dataset-specific indexer instead of assuming the
        # SQuAD record layout for every corpus.
        return self.index_fn(data)
def index_summary_default(text, summary):
    """Pass ``text`` and ``summary`` through unchanged as a 2-tuple."""
    pair = (text, summary)
    return pair
def index_summary_merge(text, summary):
    """Join the string pieces of ``text`` and of ``summary`` with single spaces.

    Both arguments are iterables of strings; the result is a
    ``(joined_text, joined_summary)`` tuple.
    """
    separator = " "
    joined_text = separator.join(text)
    joined_summary = separator.join(summary)
    return joined_text, joined_summary
class SummarizationDataset(Dataset):
    """Map-style dataset that renders each record as one prompt string.

    A sample has the form
    ``<Text token><text> <Summary token><summary>`` using the strings in
    ``SUMMARIZATION_SPECIAL_TOKENS``.
    """

    def __init__(self, dataset, cache_dir, split):
        """Load ``split`` of a summarization dataset.

        Args:
            dataset: Key of ``summarization_name_mapping`` and
                ``summarization_config_mapping``; a name missing from either
                mapping raises ``KeyError``.
            cache_dir: Directory handed to ``load_dataset`` as ``cache_dir``.
            split: Split name handed to ``load_dataset`` as ``split``.
        """
        self.dataset = load_dataset(dataset, *summarization_config_mapping[dataset], cache_dir=cache_dir, split=split)
        # Mapping entries are written as (text column, summary column), e.g.
        # ("article", "highlights"), so they must be unpacked in that order.
        self.text_column, self.summary_column = summarization_name_mapping[dataset]
        # Only "scitldr" goes through the merging preprocessor, which
        # space-joins an iterable of strings; every other dataset is passed
        # through untouched.
        # NOTE(review): assumes scitldr columns hold lists of strings — confirm.
        self.preprocess_fn = index_summary_merge if dataset == "scitldr" else index_summary_default

    def __len__(self):
        """Return the number of records in the loaded split."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the formatted ``text + " " + summary`` string for ``idx``."""
        data = self.dataset[idx]
        text, summary = self.preprocess_fn(data[self.text_column], data[self.summary_column])
        # str.join takes ONE iterable, so the pieces are passed as a tuple.
        pieces = (
            SUMMARIZATION_SPECIAL_TOKENS["Text"],
            text,
            " ",
            SUMMARIZATION_SPECIAL_TOKENS["Summary"],
            summary,
        )
        return "".join(pieces)
class WebGPT(Dataset):
@@ -58,9 +140,14 @@ def train_val_dataset(dataset, val_split=0.2):
def get_one_dataset(conf, dataset_name):
dataset_name = dataset_name.lower()
if dataset_name == "squadv2":
train = SquadV2Dataset(conf.cache_dir, "train")
eval = SquadV2Dataset(conf.cache_dir, "validation")
if dataset_name in ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_noconext"]:
train = QADataset(dataset_name, conf.cache_dir, "train")
eval = QADataset(dataset_name, conf.cache_dir, "validation")
elif dataset_name in ["xsum", "cnn_dailymail", "samsum", "multi_news", "scitldr", "billsum", "reddit"]:
train = SummarizationDataset(dataset_name, conf.cache_dir, "train")
eval = SummarizationDataset(dataset_name, conf.cache_dir, "validation")
elif dataset_name == "webgpt":
dataset = WebGPT()
train, eval = train_val_dataset(dataset, val_split=0.2)
+30
View File
@@ -1,3 +1,5 @@
import torch
import torch.nn.functional as F
from torch import nn
@@ -13,3 +15,31 @@ class CrossEntropyLoss(nn.CrossEntropyLoss):
input = input[mask]
target = target[mask]
return super(CrossEntropyLoss, self).forward(input, target)
class PolyLoss(nn.Module):
    """Cross-entropy plus a first-order polynomial term.

    Per element the loss is ``CE + epsilon * (1 - p_t)``, where ``p_t`` is the
    softmax probability assigned to the target class.
    """

    def __init__(self, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean", epsilon=1.0):
        """Configure the loss.

        Args:
            weight: Optional per-class weights (tensor or array-like) passed to
                the wrapped cross-entropy; ``None`` disables weighting.
            size_average: Forwarded unchanged to the wrapped cross-entropy.
            ignore_index: Target value whose positions contribute zero loss.
            reduce: Forwarded unchanged to the wrapped cross-entropy.
            reduction: ``"mean"``, ``"sum"``, or anything else to return the
                unreduced per-element loss.
            epsilon: Coefficient of the ``(1 - p_t)`` term.
        """
        super().__init__()
        # torch.tensor(None) raises, so only convert when weights were given.
        self.weight = None if weight is None else torch.as_tensor(weight)
        self.ignore_index = ignore_index
        self.reduction = reduction
        # The wrapped loss is always unreduced; reduction happens in forward().
        self.cross_entropy = CrossEntropyLoss(self.weight, size_average, ignore_index, reduce, "none")
        self.epsilon = epsilon

    def forward(self, input, target, mask=None):
        """Compute the loss.

        Args:
            input: Logits with the class dimension last.
            target: Integer class indices; may contain ``ignore_index``.
            mask: Optional mask; when given, ``input`` and ``target`` are
                flattened and only positions where the mask is true are kept.

        Returns:
            A scalar for ``"mean"`` / ``"sum"`` reduction, otherwise the
            per-element loss. ``"mean"`` divides by the number of non-ignored
            targets (0 is returned when there are none).
        """
        if mask is not None:
            mask = mask.view(-1).bool()
            input = input.view(-1, input.size(-1))
            target = target.view(-1)
            input = input[mask]
            target = target[mask]

        # one_hot()/gather() reject negative class ids such as the default
        # ignore_index, so ignored positions are pointed at class 0 for the
        # lookup and zeroed out afterwards.
        valid = target != self.ignore_index
        lookup_target = target.masked_fill(~valid, 0)
        probs = F.softmax(input, -1)
        pt = probs.gather(-1, lookup_target.unsqueeze(-1)).squeeze(-1)

        CE = self.cross_entropy(input, target)
        poly1 = CE + self.epsilon * (1 - pt) * valid.to(pt.dtype)

        if self.reduction == "mean":
            # Equals poly1.mean() when nothing is ignored.
            poly1 = poly1.sum() / valid.sum().clamp(min=1)
        elif self.reduction == "sum":
            poly1 = poly1.sum()
        return poly1
+10 -7
View File
@@ -1,4 +1,4 @@
from transformers import AutoModelForCausalLM
import transformers
# from .gptj import get_model as get_gptj_model
@@ -25,9 +25,12 @@ def freeze_top_n_layers(model, target_layers):
return model
def get_specific_model(model_name, cache_dir, quantization, seq2seqmodel=False):
    """Load a pretrained language model by name.

    Args:
        model_name: Identifier passed to ``from_pretrained``.
        cache_dir: Directory passed to ``from_pretrained`` as ``cache_dir``.
        quantization: Accepted for call-site compatibility; not used here.
        seq2seqmodel: When true, load an encoder-decoder model through
            ``AutoModelForSeq2SeqLM``; otherwise load a causal LM. Defaults to
            ``False`` so three-argument callers keep working.

    Returns:
        The model returned by the selected ``from_pretrained`` call.
    """
    # Encoder-decoder support (e.g. Flan-T5-like models) is selected by an
    # explicit flag for now; detecting it automatically is left for later.
    if seq2seqmodel:
        model_cls = transformers.AutoModelForSeq2SeqLM
    else:
        model_cls = transformers.AutoModelForCausalLM
    return model_cls.from_pretrained(model_name, cache_dir=cache_dir)
+13 -18
View File
@@ -3,21 +3,22 @@ import os
from distutils.util import strtobool
from typing import Any, Dict, List, Optional, Tuple, Union
import bitsandbytes
import torch
from torch import nn
from transformers import PreTrainedModel, Trainer, TrainingArguments
from transformers.training_args import OptimizerNames
from utils import get_dataset, get_loss, get_model, get_tokenizer, read_yamls
from utils import get_dataset, get_loss, get_metrics, get_model, get_tokenizer, read_yamls
# Set the WANDB_PROJECT environment variable when this module is loaded.
# NOTE(review): presumably read by the Weights & Biases logging integration — confirm.
os.environ["WANDB_PROJECT"] = "supervised-finetuning"
def compute_metrics(eval_pred, preprocess_fn, metrics):
    """Evaluate predictions with every metric and merge the results.

    Args:
        eval_pred: Raw evaluation output; passed untouched to ``preprocess_fn``.
        preprocess_fn: Callable mapping ``eval_pred`` to a
            ``(predictions, references)`` pair.
        metrics: Iterable of objects exposing
            ``compute(predictions=..., references=...)`` that returns a dict.

    Returns:
        One dict holding the entries of every metric result. When two metrics
        report the same key, the later metric wins.
    """
    preds, labels = preprocess_fn(eval_pred)

    out = {}
    for metric in metrics:
        scores = metric.compute(predictions=preds, references=labels)
        # Skip metrics that report nothing rather than failing on the merge.
        if scores is not None:
            out.update(scores)
    return out
def preprocess_logits_for_metrics(logits, labels):
@@ -31,12 +32,13 @@ class SFTTrainer(Trainer):
model: Union[PreTrainedModel, nn.Module] = None,
args: TrainingArguments = None,
loss_function: str = "CrossEntropyLoss",
poly_eps: float = 1.0,
**kwargs,
):
super().__init__(model, args, **kwargs)
# By default CrossEntropyLoss ignores padding_index -100, but just in case use our own loss_fct
self.loss_fct = get_loss(loss_function)
self.loss_fct = get_loss(loss_function, poly_eps)
def compute_loss(self, model, inputs, return_outputs=False):
labels_mask = inputs.pop("label_masks")
@@ -92,6 +94,8 @@ def argument_parsing(notebook=False, notebook_args=None):
parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument("--deepspeed", action="store_true")
parser.add_argument("--no-deepspeed", dest="deepspeed", action="store_false")
parser.add_argument("--poly_eps", type=float, default=1.0)
parser.add_argument("--seq2seq_model", action="store_true")
parser.set_defaults(deepspeed=False)
if notebook:
@@ -131,15 +135,7 @@ if __name__ == "__main__":
model = get_model(training_conf, tokenizer)
train, evals, collate_fn = get_dataset(training_conf, tokenizer)
optimizer = OptimizerNames.ADAMW_BNB if training_conf.quantization else None
if training_conf.quantization:
for module in model.modules():
if isinstance(module, torch.nn.Embedding):
bitsandbytes.optim.GlobalOptimManager.get_instance().register_module_override(
module, "weight", {"optim_bits": 32}
)
metrics, preprocess_fn = get_metrics(training_conf)
args = TrainingArguments(
output_dir=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned",
@@ -147,7 +143,6 @@ if __name__ == "__main__":
warmup_steps=training_conf.warmup_steps,
learning_rate=float(training_conf.learning_rate),
deepspeed="configs/zero_config.json" if training_conf.deepspeed else None,
optim=optimizer,
fp16=True,
local_rank=training_conf.local_rank,
gradient_checkpointing=training_conf.gradient_checkpointing,
+54 -5
View File
@@ -1,17 +1,21 @@
from functools import partial
from pathlib import Path
import evaluate
import nltk
import numpy as np
import transformers
import yaml
from custom_datasets import QA_SPECIAL_TOKENS, get_one_dataset
from custom_datasets.dialogue_collator import DialogueDataCollator
from losses import CrossEntropyLoss
from losses import CrossEntropyLoss, PolyLoss
from models import freeze_top_n_layers, get_specific_model
from sklearn.model_selection import train_test_split
from torch.utils.data import ConcatDataset, Subset
from transformers import AutoTokenizer
def get_tokenizer(conf):
tokenizer = AutoTokenizer.from_pretrained(conf.model_name, cache_dir=conf.cache_dir)
tokenizer = transformers.AutoTokenizer.from_pretrained(conf.model_name, cache_dir=conf.cache_dir)
if "galactica" in conf.model_name:
tokenizer.add_special_tokens({"pad_token": "<pad>", "eos_token": "</s>"})
@@ -32,8 +36,51 @@ def get_tokenizer(conf):
return tokenizer
# placeholder for now
def preprocess_qa(eval_pred):
    """Return ``(predictions, label_ids)`` from ``eval_pred`` without changes."""
    predictions = eval_pred.predictions
    references = eval_pred.label_ids
    return (predictions, references)
def postprocess_summarization(preds, labels):
    """Strip each text and put every sentence on its own line.

    Args:
        preds: Iterable of predicted summary strings.
        labels: Iterable of reference summary strings.

    Returns:
        ``(preds, labels)`` as two lists in which each string has been
        stripped, split with ``nltk.sent_tokenize`` and re-joined with
        newlines.
    """

    def _one_sentence_per_line(texts):
        # Re-join the tokenized sentences of each text with "\n".
        return ["\n".join(nltk.sent_tokenize(text)) for text in texts]

    stripped_preds = [pred.strip() for pred in preds]
    stripped_labels = [label.strip() for label in labels]
    return _one_sentence_per_line(stripped_preds), _one_sentence_per_line(stripped_labels)
def preprocess_summarization(eval_pred, tokenizer, ignore_pad_token_for_loss=True):
    """Decode predicted and reference token ids into post-processed text.

    Args:
        eval_pred: Pair ``(preds, labels)`` of token-id arrays.
        tokenizer: Object providing ``batch_decode`` and ``pad_token_id``.
        ignore_pad_token_for_loss: When true, every ``-100`` entry in the
            labels is replaced with ``tokenizer.pad_token_id`` before decoding.

    Returns:
        ``(decoded_preds, decoded_labels)`` as produced by
        ``postprocess_summarization``.
    """
    pred_ids, label_ids = eval_pred
    pred_texts = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    if ignore_pad_token_for_loss:
        # Swap the -100 placeholder for the pad id before decoding.
        label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)
    label_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    return postprocess_summarization(pred_texts, label_texts)
def get_metrics(conf, tokenizer):
    """Pick evaluation metrics and the matching preprocessing function.

    Args:
        conf: Configuration object; ``conf.datasets`` is an iterable of dataset
            names. ``conf.ignore_pad_token_for_loss`` is read for summarization
            tasks and treated as ``True`` when the attribute is absent.
        tokenizer: Tokenizer bound into the summarization preprocessing step.

    Returns:
        ``(metrics, preprocess_fn)``: a list of loaded metric objects and a
        callable that maps an ``eval_pred`` to ``(predictions, references)``.

    Raises:
        ValueError: If none of ``conf.datasets`` is a known QA or
            summarization dataset.
    """
    # Both spellings of the no-context TriviaQA name are listed:
    # "trivia_qa_noconext" appears in this change set's dataset-name lists,
    # while the dataset class matches "trivia_qa_nocontext".
    qa_datasets = ("squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_nocontext", "trivia_qa_noconext")
    # Kept in step with the summarization names accepted by get_one_dataset.
    summarization_datasets = ("xsum", "cnn_dailymail", "samsum", "multi_news", "scitldr", "billsum", "reddit")

    # Metrics are returned as a list so more can be added later for a more
    # thorough evaluation.
    if any(dataset in qa_datasets for dataset in conf.datasets):
        metrics, preprocess_fn = [evaluate.load("squad_v2")], preprocess_qa
    elif any(dataset in summarization_datasets for dataset in conf.datasets):
        # The tokenizer must be bound by keyword: a positional bind would fill
        # preprocess_summarization's first parameter (eval_pred) instead.
        preprocess_fn = partial(
            preprocess_summarization,
            tokenizer=tokenizer,
            ignore_pad_token_for_loss=getattr(conf, "ignore_pad_token_for_loss", True),
        )
        metrics = [evaluate.load("rouge")]
    else:
        raise ValueError("Unknown dataset / task")

    return metrics, preprocess_fn
def get_model(conf, tokenizer):
model = get_specific_model(conf.model_name, conf.cache_dir, conf.quantization)
model = get_specific_model(conf.model_name, conf.cache_dir, conf.quantization, conf.seq2seqmodel)
if len(tokenizer) != model.get_input_embeddings().num_embeddings:
assert not conf.freeze_layer, "Cannot change the number of embeddings if the model is frozen."
@@ -65,9 +112,11 @@ def get_dataset(conf, tokenizer):
return train, evals, collate_fn
def get_loss(loss, poly_eps=1.0):
    """Instantiate the loss selected by name.

    Args:
        loss: ``"CrossEntropyLoss"`` or ``"Poly"``.
        poly_eps: ``epsilon`` handed to ``PolyLoss``; ignored for
            ``"CrossEntropyLoss"``. Defaults to ``1.0`` so one-argument callers
            keep working.

    Returns:
        A ``CrossEntropyLoss`` or ``PolyLoss`` instance.

    Raises:
        ValueError: If ``loss`` is not a supported name.
    """
    if loss == "CrossEntropyLoss":
        return CrossEntropyLoss()
    elif loss == "Poly":
        return PolyLoss(epsilon=poly_eps)
    else:
        raise ValueError(f"Loss {loss} not supported")