diff --git a/model/supervised_finetuning/configs/config.yaml b/model/supervised_finetuning/configs/config.yaml index 616aa828..bd35f168 100644 --- a/model/supervised_finetuning/configs/config.yaml +++ b/model/supervised_finetuning/configs/config.yaml @@ -23,8 +23,10 @@ defaults: eval_size: log_dir: "base" quantization: false + seq2seqmodel: false + poly_eps: 1.0 -galactica-125: +galactica-125m: learning_rate: 5e-5 model_name: facebook/galactica-125m weight_decay: 0.01 @@ -58,7 +60,7 @@ codegen: debug: eval_steps: 20 - eval_size: 100 + eval_size: 20 gradient_accumulation_steps: 1 per_device_train_batch_size: 1 per_device_eval_batch_size: 1 diff --git a/model/supervised_finetuning/custom_datasets/__init__.py b/model/supervised_finetuning/custom_datasets/__init__.py index 5706bfa7..c0cd424b 100644 --- a/model/supervised_finetuning/custom_datasets/__init__.py +++ b/model/supervised_finetuning/custom_datasets/__init__.py @@ -1,3 +1,4 @@ +import numpy as np from datasets import load_dataset from sklearn.model_selection import train_test_split from torch.utils.data import Dataset, Subset @@ -5,19 +6,103 @@ from torch.utils.data import Dataset, Subset from .prompt_dialogue import PromptGeneratedDataset QA_SPECIAL_TOKENS = {"Question": "", "Answer": ""} +SUMMARIZATION_SPECIAL_TOKENS = {"Text": "", "Summary": "TL;DR:"} + +summarization_name_mapping = { + "cnn_dailymail": ("article", "highlights"), + "samsum": ("dialogue", "summary"), + "xsum": ("document", "summary"), + "multi_news": ("document", "summary"), + "scitldr": ("source", "target"), + "billsum": ("text", "summary"), + "reddit": ("content", "summary"), +} +summarization_config_mapping = { + "cnn_dailymail": ("3.0.0",), + "samsum": (), + "xsum": (), + "multi_news": (), + "scitldr": ("AIC",), + "billsum": (), + "reddit": (), +} + +QA_DATASETS = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_noconext"] +SUMMARIZATION_DATASETS = ["xsum", "cnn_dailymail", "samsum", "multi_news"] -class SquadV2Dataset(Dataset): - def __init__(self, cache_dir, split): - self.dataset = load_dataset("squad_v2", cache_dir=cache_dir, split=split) +def index_squad_v2(example): + return example["title"] + ". " + example["context"] + " " + example["question"], example["answers"]["text"][0] + + +def index_trivia_qa_nocontext(example): + # dummy return one randomly + return example["question"], example["answer"]["aliases"][np.random.randint(len(example["answer"]["aliases"]))] + + +def index_trivia_qa_context(example): + question = example["question"] + title = example["title"][np.random.randint(len(example["title"]))] + context = example["search_context"][np.random.randint(len(example["search_context"]))] + answer = example["answer"]["aliases"][np.random.randint(len(example["answer"]["aliases"]))] + + return title + ". " + context + " " + question, answer + + +def index_adversarial_qa(example): + return example["title"] + ". " + example["context"] + " " + example["question"], example["answers"]["text"][0] + + +class QADataset(Dataset): + def __init__(self, dataset, cache_dir, split): + if dataset == "squad_v2": + self.index_fn = index_squad_v2 + self.dataset = load_dataset("squad_v2", cache_dir=cache_dir, split=split) + elif dataset == "trivia_qa_nocontext": + self.index_fn = index_trivia_qa_nocontext + self.dataset = load_dataset("trivia_qa", "rc.nocontext") + elif dataset == "trivia_qa_context": + self.index_fn = index_trivia_qa_context + self.dataset = load_dataset("trivia_qa", "rc") + elif dataset == "adversarial_qa": + self.index_fn = index_adversarial_qa + self.dataset = load_dataset("adversarial_qa", "adversarialQA") + else: + raise ValueError("Unknown dataset : " + dataset) def __len__(self): return len(self.dataset) def __getitem__(self, idx): data = self.dataset[idx] - # return first answer form list of possible answers - return data["title"] + ". " + data["context"] + " " + data["question"], data["answers"]["text"][0] + return self.index_fn(data) + + +def index_summary_default(text, summary): + return text, summary + + +def index_summary_merge(text, summary): + return " ".join(text), " ".join(summary) + + +class SummarizationDataset(Dataset): + def __init__(self, dataset, cache_dir, split): + self.dataset = load_dataset(dataset, *summarization_config_mapping[dataset], cache_dir=cache_dir, split=split) + self.summary_column, self.text_column = summarization_name_mapping[dataset] + self.preprocess_fn = index_summary_merge if dataset == "scitdlr" else index_summary_merge + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + data = self.dataset[idx] + text, summary = data[self.text_column], data[self.summary_column] + text, summary = self.preprocess_fn(text, summary) + + return "".join( + SUMMARIZATION_SPECIAL_TOKENS["Text"], text, " ", SUMMARIZATION_SPECIAL_TOKENS["Summary"], summary + ) class WebGPT(Dataset): @@ -58,9 +143,14 @@ def train_val_dataset(dataset, val_split=0.2): def get_one_dataset(conf, dataset_name): dataset_name = dataset_name.lower() - if dataset_name == "squadv2": - train = SquadV2Dataset(conf.cache_dir, "train") - eval = SquadV2Dataset(conf.cache_dir, "validation") + if dataset_name in ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_noconext"]: + train = QADataset(dataset_name, conf.cache_dir, "train") + eval = QADataset(dataset_name, conf.cache_dir, "validation") + + elif dataset_name in ["xsum", "cnn_dailymail", "samsum", "multi_news", "scitldr", "billsum", "reddit"]: + train = SummarizationDataset(dataset_name, conf.cache_dir, "train") + eval = SummarizationDataset(dataset_name, conf.cache_dir, "validation") + elif dataset_name == "webgpt": dataset = WebGPT() train, eval = train_val_dataset(dataset, val_split=0.2) diff --git a/model/supervised_finetuning/losses.py b/model/supervised_finetuning/losses.py index 0cc639cf..88902ed7 100644 --- a/model/supervised_finetuning/losses.py +++ b/model/supervised_finetuning/losses.py @@ -1,3 +1,5 @@ +import torch +import torch.nn.functional as F from torch import nn @@ -13,3 +15,31 @@ class CrossEntropyLoss(nn.CrossEntropyLoss): input = input[mask] target = target[mask] return super(CrossEntropyLoss, self).forward(input, target) + + +class PolyLoss(nn.Module): + def __init__(self, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean", epsilon=1.0): + super(PolyLoss, self).__init__() + self.weight = torch.tensor(weight) + self.ignore_index = ignore_index + self.reduction = reduction + self.cross_entropy = CrossEntropyLoss(weight, size_average, ignore_index, reduce, "none") + self.epsilon = epsilon + + def forward(self, input, target, mask=None): + if mask is not None: + mask = mask.view(-1).bool() + input = input.view(-1, input.size(-1)) + target = target.view(-1) + input = input[mask] + target = target[mask] + + onehot_target = F.one_hot(target, num_classes=input.size(-1)).to(device=input.device, dtype=input.dtype) + pt = torch.sum(onehot_target * F.softmax(input, -1), -1) + CE = self.cross_entropy(input, target) + poly1 = CE + self.epsilon * (1 - pt) + if self.reduction == "mean": + poly1 = poly1.mean() + elif self.reduction == "sum": + poly1 = poly1.sum() + return poly1 diff --git a/model/supervised_finetuning/models/__init__.py b/model/supervised_finetuning/models/__init__.py index 21ddab9d..510a2738 100644 --- a/model/supervised_finetuning/models/__init__.py +++ b/model/supervised_finetuning/models/__init__.py @@ -1,4 +1,4 @@ -from transformers import AutoModelForCausalLM +import transformers # from .gptj import get_model as get_gptj_model @@ -25,9 +25,12 @@ def freeze_top_n_layers(model, target_layers): return model -def get_specific_model(model_name, cache_dir, quantization): - return AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir) - # if "gpt-j" in model_name.lower(): - # return get_gptj_model(model_name, cache_dir, quantization) - # else: - # return AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir) +def get_specific_model(model_name, cache_dir, quantization, seq2seqmodel): + # encoder-decoder support for Flan-T5 like models + # for now, we can use an argument but in the future, + # we can automate this + if seq2seqmodel: + model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir=cache_dir) + else: + model = transformers.AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir) + return model diff --git a/model/supervised_finetuning/requirements.txt b/model/supervised_finetuning/requirements.txt index c47a1218..0e6eeb51 100644 --- a/model/supervised_finetuning/requirements.txt +++ b/model/supervised_finetuning/requirements.txt @@ -2,7 +2,9 @@ accelerate==0.15.0 bitsandbytes==0.36.0.post2 datasets==2.8.0 deepspeed==0.7.7 +evaluate==0.4.0 mpi4py==3.1.4 +nltk==3.8.1 numpy==1.23.0 PyYAML==6.0 scikit_learn==1.2.0 diff --git a/model/supervised_finetuning/trainer.py b/model/supervised_finetuning/trainer.py index 450854f1..0acb10dd 100644 --- a/model/supervised_finetuning/trainer.py +++ b/model/supervised_finetuning/trainer.py @@ -1,6 +1,6 @@ import argparse -import os from distutils.util import strtobool +from functools import partial from typing import Any, Dict, List, Optional, Tuple, Union import bitsandbytes @@ -8,16 +8,16 @@ import torch from torch import nn from transformers import PreTrainedModel, Trainer, TrainingArguments from transformers.training_args import OptimizerNames -from utils import get_dataset, get_loss, get_model, get_tokenizer, read_yamls - -os.environ["WANDB_PROJECT"] = "supervised-finetuning" +from utils import get_dataset, get_loss, get_metrics, get_model, get_tokenizer, read_yamls -def compute_metrics(eval_pred): - pred_ids = eval_pred.predictions - labels = eval_pred.label_ids +def compute_metrics(eval_pred, preprocess_fns, metrics): + out = {} + for metric, preprocess_fn in zip(metrics, preprocess_fns): + preds, labels = preprocess_fn(eval_pred) + out = dict(**out, **metric.compute(predictions=preds, references=labels)) - return {"accuracy": (pred_ids[labels > 0] == labels[labels > 0]).mean()} + return out def preprocess_logits_for_metrics(logits, labels): @@ -31,18 +31,22 @@ class SFTTrainer(Trainer): model: Union[PreTrainedModel, nn.Module] = None, args: TrainingArguments = None, loss_function: str = "CrossEntropyLoss", + poly_eps: float = 1.0, **kwargs, ): super().__init__(model, args, **kwargs) # By default CrossEntropyLoss ignores padding_index -100, but just in case use our own loss_fct - self.loss_fct = get_loss(loss_function) + self.loss_fct = get_loss(loss_function, poly_eps) def compute_loss(self, model, inputs, return_outputs=False): labels_mask = inputs.pop("label_masks") targets = inputs.pop("targets") - outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs.get("attention_mask", None)) + outputs = model( + input_ids=inputs["input_ids"], + attention_mask=inputs.get("attention_mask", None), + ) loss = self.loss_fct(outputs.get("logits"), targets, mask=labels_mask) @@ -54,7 +58,10 @@ class SFTTrainer(Trainer): labels_mask = inputs.pop("label_masks") targets = inputs.pop("targets") - outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs.get("attention_mask", None)) + outputs = model( + input_ids=inputs["input_ids"], + attention_mask=inputs.get("attention_mask", None), + ) logits = outputs.get("logits") @@ -92,6 +99,7 @@ def argument_parsing(notebook=False, notebook_args=None): parser.add_argument("--local_rank", type=int, default=-1) parser.add_argument("--deepspeed", action="store_true") parser.add_argument("--no-deepspeed", dest="deepspeed", action="store_false") + parser.add_argument("--wandb-entity", type=str, default="open-assistant") parser.set_defaults(deepspeed=False) if notebook: @@ -111,8 +119,10 @@ def argument_parsing(notebook=False, notebook_args=None): else: conf.update(configs[name]) + conf["wandb_entity"] = args.wandb_entity conf["local_rank"] = args.local_rank conf["deepspeed"] = args.deepspeed + # Override config from command-line parser = argparse.ArgumentParser() for key, value in conf.items(): @@ -131,8 +141,9 @@ if __name__ == "__main__": model = get_model(training_conf, tokenizer) train, evals, collate_fn = get_dataset(training_conf, tokenizer) + metrics, preprocess_fns = get_metrics(training_conf, tokenizer) - optimizer = OptimizerNames.ADAMW_BNB if training_conf.quantization else None + optimizer = OptimizerNames.ADAMW_BNB if training_conf.quantization else OptimizerNames.ADAMW_HF if training_conf.quantization: for module in model.modules(): @@ -166,15 +177,26 @@ if __name__ == "__main__": ) assert len(evals) > 0 + + if not training_conf.deepspeed or training_conf.local_rank == 0: + import wandb + + wandb.init( + project="supervised-finetuning", + entity=training_conf.wandb_entity, + name=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned", + ) + trainer = SFTTrainer( model, args, loss_function=training_conf.loss_fn, + poly_eps=training_conf.poly_eps, train_dataset=train, eval_dataset=evals, data_collator=collate_fn, tokenizer=tokenizer, - compute_metrics=compute_metrics, + compute_metrics=partial(compute_metrics, metrics=metrics, preprocess_fns=preprocess_fns), preprocess_logits_for_metrics=preprocess_logits_for_metrics, ) trainer.train() diff --git a/model/supervised_finetuning/utils.py b/model/supervised_finetuning/utils.py index d6abcff2..f598dde1 100644 --- a/model/supervised_finetuning/utils.py +++ b/model/supervised_finetuning/utils.py @@ -1,17 +1,21 @@ +from functools import partial from pathlib import Path +import evaluate +import nltk +import numpy as np +import transformers import yaml -from custom_datasets import QA_SPECIAL_TOKENS, get_one_dataset +from custom_datasets import QA_DATASETS, QA_SPECIAL_TOKENS, SUMMARIZATION_DATASETS, get_one_dataset from custom_datasets.dialogue_collator import DialogueDataCollator -from losses import CrossEntropyLoss +from losses import CrossEntropyLoss, PolyLoss from models import freeze_top_n_layers, get_specific_model from sklearn.model_selection import train_test_split from torch.utils.data import ConcatDataset, Subset -from transformers import AutoTokenizer def get_tokenizer(conf): - tokenizer = AutoTokenizer.from_pretrained(conf.model_name, cache_dir=conf.cache_dir) + tokenizer = transformers.AutoTokenizer.from_pretrained(conf.model_name, cache_dir=conf.cache_dir) if "galactica" in conf.model_name: tokenizer.add_special_tokens({"pad_token": "", "eos_token": ""}) @@ -32,8 +36,63 @@ def get_tokenizer(conf): return tokenizer +def default_preprocess(eval_pred, ignote_negative_labels=True): + preds, labels = eval_pred.predictions, eval_pred.label_ids + + if not ignote_negative_labels: + return preds, labels + + mask = labels > 0 + return preds[mask], labels[mask] + + +# placeholder for now +def preprocess_qa(eval_pred): + return (eval_pred.predictions, eval_pred.label_ids) + + +def postprocess_summarization(preds, labels): + preds = [pred.strip() for pred in preds] + labels = [label.strip() for label in labels] + + preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] + labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] + + return preds, labels + + +def preprocess_summarization(eval_pred, tokenizer, ignore_pad_token_for_loss=True): + preds, labels = eval_pred + decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) + if ignore_pad_token_for_loss: + labels = np.where(labels != -100, labels, tokenizer.pad_token_id) + decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) + + decoded_preds, decoded_labels = postprocess_summarization(decoded_preds, decoded_labels) + return decoded_preds, decoded_labels + + +def get_metrics(conf, tokenizer): + # the reason behind using a list is that we might want to extend the list of our + # metrics in the future for more thorough evaluation + metrics, preprocess_fns = [evaluate.load("accuracy")], [default_preprocess] + + if any(dataset in QA_DATASETS for dataset in conf.datasets): + raise ValueError("TODO") + metrics.append(evaluate.load("squad_v2")) + preprocess_fns.append(preprocess_qa) + if any(dataset in SUMMARIZATION_DATASETS for dataset in conf.datasets): + raise ValueError("TODO") + metrics.append(evaluate.load("rouge")) + preprocess_fns.append( + partial(preprocess_summarization, tokenizer, ignore_pad_token_for_loss=conf.ignore_pad_token_for_loss) + ) + + return metrics, preprocess_fns + + def get_model(conf, tokenizer): - model = get_specific_model(conf.model_name, conf.cache_dir, conf.quantization) + model = get_specific_model(conf.model_name, conf.cache_dir, conf.quantization, conf.seq2seqmodel) if len(tokenizer) != model.get_input_embeddings().num_embeddings: assert not conf.freeze_layer, "Cannot change the number of embeddings if the model is frozen." @@ -65,9 +124,11 @@ def get_dataset(conf, tokenizer): return train, evals, collate_fn -def get_loss(loss): +def get_loss(loss, poly_eps): if loss == "CrossEntropyLoss": return CrossEntropyLoss() + elif loss == "Poly": + return PolyLoss(epsilon=poly_eps) else: raise ValueError(f"Loss {loss} not supported")