From 5b77dd2e9f21e24553173bf5c7f5221ea0bbd6f8 Mon Sep 17 00:00:00 2001 From: ekurtulus Date: Wed, 11 Jan 2023 11:37:27 +0300 Subject: [PATCH] better --- .../custom_datasets/__init__.py | 103 ++++++++++++++++-- model/supervised_finetuning/losses.py | 30 +++++ .../supervised_finetuning/models/__init__.py | 17 +-- model/supervised_finetuning/trainer.py | 31 +++--- model/supervised_finetuning/utils.py | 59 +++++++++- 5 files changed, 202 insertions(+), 38 deletions(-) diff --git a/model/supervised_finetuning/custom_datasets/__init__.py b/model/supervised_finetuning/custom_datasets/__init__.py index 5706bfa7..c6fd3ada 100644 --- a/model/supervised_finetuning/custom_datasets/__init__.py +++ b/model/supervised_finetuning/custom_datasets/__init__.py @@ -1,3 +1,4 @@ +import numpy as np from datasets import load_dataset from sklearn.model_selection import train_test_split from torch.utils.data import Dataset, Subset @@ -5,19 +6,100 @@ from torch.utils.data import Dataset, Subset from .prompt_dialogue import PromptGeneratedDataset QA_SPECIAL_TOKENS = {"Question": "", "Answer": ""} +SUMMARIZATION_SPECIAL_TOKENS = {"Text": "", "Summary": "TL;DR:"} + +summarization_name_mapping = { + "cnn_dailymail": ("article", "highlights"), + "samsum": ("dialogue", "summary"), + "xsum": ("document", "summary"), + "multi_news": ("document", "summary"), + "scitldr": ("source", "target"), + "billsum": ("text", "summary"), + "reddit": ("content", "summary"), +} +summarization_config_mapping = { + "cnn_dailymail": ("3.0.0",), + "samsum": (), + "xsum": (), + "multi_news": (), + "scitldr": ("AIC",), + "billsum": (), + "reddit": (), +} -class SquadV2Dataset(Dataset): - def __init__(self, cache_dir, split): - self.dataset = load_dataset("squad_v2", cache_dir=cache_dir, split=split) +def index_squad_v2(example): + return example["title"] + ". " + example["context"] + " " + example["question"], example["answers"]["text"][0] + + +def index_trivia_qa_nocontext(example): + # dummy return one randomly + return example["question"], example["answer"]["aliases"][np.random.randint(len(example["answer"]["aliases"]))] + + +def index_trivia_qa_context(example): + question = example["question"] + title = example["title"][np.random.randint(len(example["title"]))] + context = example["search_context"][np.random.randint(len(example["search_context"]))] + answer = example["answer"]["aliases"][np.random.randint(len(example["answer"]["aliases"]))] + + return title + ". " + context + " " + question, answer + + +def index_adversarial_qa(example): + return example["title"] + ". " + example["context"] + " " + example["question"], example["answers"]["text"][0] + + +class QADataset(Dataset): + def __init__(self, dataset, cache_dir, split): + if dataset == "squad_v2": + self.index_fn = index_squad_v2 + self.dataset = load_dataset("squad_v2", cache_dir=cache_dir, split=split) + elif dataset == "trivia_qa_nocontext": + self.index_fn = index_trivia_qa_nocontext + self.dataset = load_dataset("trivia_qa", "rc.nocontext") + elif dataset == "trivia_qa_context": + self.index_fn = index_trivia_qa_context + self.dataset = load_dataset("trivia_qa", "rc") + elif dataset == "adversarial_qa": + self.index_fn = index_adversarial_qa + self.dataset = load_dataset("adversarial_qa", "adversarialQA") + else: + raise ValueError("Unknown dataset : " + dataset) def __len__(self): return len(self.dataset) def __getitem__(self, idx): data = self.dataset[idx] - # return first answer form list of possible answers - return data["title"] + ". " + data["context"] + " " + data["question"], data["answers"]["text"][0] + return self.index_fn(data) + + +def index_summary_default(text, summary): + return text, summary + + +def index_summary_merge(text, summary): + return " ".join(text), " ".join(summary) + + +class SummarizationDataset(Dataset): + def __init__(self, dataset, cache_dir, split): + self.dataset = load_dataset(dataset, *summarization_config_mapping[dataset], cache_dir=cache_dir, split=split) + self.summary_column, self.text_column = summarization_name_mapping[dataset] + self.preprocess_fn = index_summary_merge if dataset == "scitdlr" else index_summary_merge + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + data = self.dataset[idx] + text, summary = data[self.text_column], data[self.summary_column] + text, summary = self.preprocess_fn(text, summary) + + return "".join( + SUMMARIZATION_SPECIAL_TOKENS["Text"], text, " ", SUMMARIZATION_SPECIAL_TOKENS["Summary"], summary + ) class WebGPT(Dataset): @@ -58,9 +140,14 @@ def train_val_dataset(dataset, val_split=0.2): def get_one_dataset(conf, dataset_name): dataset_name = dataset_name.lower() - if dataset_name == "squadv2": - train = SquadV2Dataset(conf.cache_dir, "train") - eval = SquadV2Dataset(conf.cache_dir, "validation") + if dataset_name in ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_noconext"]: + train = QADataset(dataset_name, conf.cache_dir, "train") + eval = QADataset(dataset_name, conf.cache_dir, "validation") + + elif dataset_name in ["xsum", "cnn_dailymail", "samsum", "multi_news", "scitldr", "billsum", "reddit"]: + train = SummarizationDataset(dataset_name, conf.cache_dir, "train") + eval = SummarizationDataset(dataset_name, conf.cache_dir, "validation") + elif dataset_name == "webgpt": dataset = WebGPT() train, eval = train_val_dataset(dataset, val_split=0.2) diff --git a/model/supervised_finetuning/losses.py b/model/supervised_finetuning/losses.py index 0cc639cf..88902ed7 100644 --- a/model/supervised_finetuning/losses.py +++ b/model/supervised_finetuning/losses.py @@ -1,3 +1,5 @@ +import torch +import torch.nn.functional as F from torch import nn @@ -13,3 +15,31 @@ class CrossEntropyLoss(nn.CrossEntropyLoss): input = input[mask] target = target[mask] return super(CrossEntropyLoss, self).forward(input, target) + + +class PolyLoss(nn.Module): + def __init__(self, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean", epsilon=1.0): + super(PolyLoss, self).__init__() + self.weight = torch.tensor(weight) + self.ignore_index = ignore_index + self.reduction = reduction + self.cross_entropy = CrossEntropyLoss(weight, size_average, ignore_index, reduce, "none") + self.epsilon = epsilon + + def forward(self, input, target, mask=None): + if mask is not None: + mask = mask.view(-1).bool() + input = input.view(-1, input.size(-1)) + target = target.view(-1) + input = input[mask] + target = target[mask] + + onehot_target = F.one_hot(target, num_classes=input.size(-1)).to(device=input.device, dtype=input.dtype) + pt = torch.sum(onehot_target * F.softmax(input, -1), -1) + CE = self.cross_entropy(input, target) + poly1 = CE + self.epsilon * (1 - pt) + if self.reduction == "mean": + poly1 = poly1.mean() + elif self.reduction == "sum": + poly1 = poly1.sum() + return poly1 diff --git a/model/supervised_finetuning/models/__init__.py b/model/supervised_finetuning/models/__init__.py index 21ddab9d..510a2738 100644 --- a/model/supervised_finetuning/models/__init__.py +++ b/model/supervised_finetuning/models/__init__.py @@ -1,4 +1,4 @@ -from transformers import AutoModelForCausalLM +import transformers # from .gptj import get_model as get_gptj_model @@ -25,9 +25,12 @@ def freeze_top_n_layers(model, target_layers): return model -def get_specific_model(model_name, cache_dir, quantization): - return AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir) - # if "gpt-j" in model_name.lower(): - # return get_gptj_model(model_name, cache_dir, quantization) - # else: - # return AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir) +def get_specific_model(model_name, cache_dir, quantization, seq2seqmodel): + # encoder-decoder support for Flan-T5 like models + # for now, we can use an argument but in the future, + # we can automate this + if seq2seqmodel: + model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir=cache_dir) + else: + model = transformers.AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir) + return model diff --git a/model/supervised_finetuning/trainer.py b/model/supervised_finetuning/trainer.py index 450854f1..54a45408 100644 --- a/model/supervised_finetuning/trainer.py +++ b/model/supervised_finetuning/trainer.py @@ -3,21 +3,22 @@ import os from distutils.util import strtobool from typing import Any, Dict, List, Optional, Tuple, Union -import bitsandbytes import torch from torch import nn from transformers import PreTrainedModel, Trainer, TrainingArguments -from transformers.training_args import OptimizerNames -from utils import get_dataset, get_loss, get_model, get_tokenizer, read_yamls +from utils import get_dataset, get_loss, get_metrics, get_model, get_tokenizer, read_yamls os.environ["WANDB_PROJECT"] = "supervised-finetuning" -def compute_metrics(eval_pred): - pred_ids = eval_pred.predictions - labels = eval_pred.label_ids +def compute_metrics(eval_pred, preprocess_fn, metrics): + preds, labels = preprocess_fn(eval_pred) - return {"accuracy": (pred_ids[labels > 0] == labels[labels > 0]).mean()} + out = {} + for metric in metrics: + out = dict(**out, **metric.compute(predictions=preds, references=labels)) + + return out def preprocess_logits_for_metrics(logits, labels): @@ -31,12 +32,13 @@ class SFTTrainer(Trainer): model: Union[PreTrainedModel, nn.Module] = None, args: TrainingArguments = None, loss_function: str = "CrossEntropyLoss", + poly_eps: float = 1.0, **kwargs, ): super().__init__(model, args, **kwargs) # By default CrossEntropyLoss ignores padding_index -100, but just in case use our own loss_fct - self.loss_fct = get_loss(loss_function) + self.loss_fct = get_loss(loss_function, poly_eps) def compute_loss(self, model, inputs, return_outputs=False): labels_mask = inputs.pop("label_masks") @@ -92,6 +94,8 @@ def argument_parsing(notebook=False, notebook_args=None): parser.add_argument("--local_rank", type=int, default=-1) parser.add_argument("--deepspeed", action="store_true") parser.add_argument("--no-deepspeed", dest="deepspeed", action="store_false") + parser.add_argument("--poly_eps", type=float, default=1.0) + parser.add_argument("--seq2seq_model", action="store_true") parser.set_defaults(deepspeed=False) if notebook: @@ -131,15 +135,7 @@ if __name__ == "__main__": model = get_model(training_conf, tokenizer) train, evals, collate_fn = get_dataset(training_conf, tokenizer) - - optimizer = OptimizerNames.ADAMW_BNB if training_conf.quantization else None - - if training_conf.quantization: - for module in model.modules(): - if isinstance(module, torch.nn.Embedding): - bitsandbytes.optim.GlobalOptimManager.get_instance().register_module_override( - module, "weight", {"optim_bits": 32} - ) + metrics, preprocess_fn = get_metrics(training_conf) args = TrainingArguments( output_dir=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned", @@ -147,7 +143,6 @@ if __name__ == "__main__": warmup_steps=training_conf.warmup_steps, learning_rate=float(training_conf.learning_rate), deepspeed="configs/zero_config.json" if training_conf.deepspeed else None, - optim=optimizer, fp16=True, local_rank=training_conf.local_rank, gradient_checkpointing=training_conf.gradient_checkpointing, diff --git a/model/supervised_finetuning/utils.py b/model/supervised_finetuning/utils.py index d6abcff2..8dd21773 100644 --- a/model/supervised_finetuning/utils.py +++ b/model/supervised_finetuning/utils.py @@ -1,17 +1,21 @@ +from functools import partial from pathlib import Path +import evaluate +import nltk +import numpy as np +import transformers import yaml from custom_datasets import QA_SPECIAL_TOKENS, get_one_dataset from custom_datasets.dialogue_collator import DialogueDataCollator -from losses import CrossEntropyLoss +from losses import CrossEntropyLoss, PolyLoss from models import freeze_top_n_layers, get_specific_model from sklearn.model_selection import train_test_split from torch.utils.data import ConcatDataset, Subset -from transformers import AutoTokenizer def get_tokenizer(conf): - tokenizer = AutoTokenizer.from_pretrained(conf.model_name, cache_dir=conf.cache_dir) + tokenizer = transformers.AutoTokenizer.from_pretrained(conf.model_name, cache_dir=conf.cache_dir) if "galactica" in conf.model_name: tokenizer.add_special_tokens({"pad_token": "", "eos_token": ""}) @@ -32,8 +36,51 @@ def get_tokenizer(conf): return tokenizer +# placeholder for now +def preprocess_qa(eval_pred): + return (eval_pred.predictions, eval_pred.label_ids) + + +def postprocess_summarization(preds, labels): + preds = [pred.strip() for pred in preds] + labels = [label.strip() for label in labels] + + preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] + labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] + + return preds, labels + + +def preprocess_summarization(eval_pred, tokenizer, ignore_pad_token_for_loss=True): + preds, labels = eval_pred + decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) + if ignore_pad_token_for_loss: + labels = np.where(labels != -100, labels, tokenizer.pad_token_id) + decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) + + decoded_preds, decoded_labels = postprocess_summarization(decoded_preds, decoded_labels) + return decoded_preds, decoded_labels + + +def get_metrics(conf, tokenizer): + qa_datasets = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_noconext"] + summarization_datasets = ["xsum", "cnn_dailymail", "samsum", "multi_news"] + + # the reason behind using a list is that we might want to extend the list of our + # metrics in the future for more thorough evaluation + if any(dataset in qa_datasets for dataset in conf.datasets): + metrics, preprocess_fn = [evaluate.load("squad_v2")], preprocess_qa + elif any(dataset in summarization_datasets for dataset in conf.datasets): + metrics, preprocess_fn = [evaluate.load("rouge")], partial( + preprocess_summarization, tokenizer, ignore_pad_token_for_loss=conf.ignore_pad_token_for_loss + ) + else: + raise ValueError("Unknown dataset / task") + return metrics, preprocess_fn + + def get_model(conf, tokenizer): - model = get_specific_model(conf.model_name, conf.cache_dir, conf.quantization) + model = get_specific_model(conf.model_name, conf.cache_dir, conf.quantization, conf.seq2seqmodel) if len(tokenizer) != model.get_input_embeddings().num_embeddings: assert not conf.freeze_layer, "Cannot change the number of embeddings if the model is frozen." @@ -65,9 +112,11 @@ def get_dataset(conf, tokenizer): return train, evals, collate_fn -def get_loss(loss): +def get_loss(loss, poly_eps): if loss == "CrossEntropyLoss": return CrossEntropyLoss() + elif loss == "Poly": + return PolyLoss(epsilon=poly_eps) else: raise ValueError(f"Loss {loss} not supported")