Merge pull request #638 from LAION-AI/ekurtulus/main

Changes on #619. Datasets is getting a bit dirty. I will do a refactoring this week.
This commit is contained in:
theblackcat102
2023-01-12 08:23:37 +08:00
committed by GitHub
7 changed files with 246 additions and 36 deletions
@@ -23,8 +23,10 @@ defaults:
eval_size:
log_dir: "base"
quantization: false
seq2seqmodel: false
poly_eps: 1.0
galactica-125:
galactica-125m:
learning_rate: 5e-5
model_name: facebook/galactica-125m
weight_decay: 0.01
@@ -58,7 +60,7 @@ codegen:
debug:
eval_steps: 20
eval_size: 100
eval_size: 20
gradient_accumulation_steps: 1
per_device_train_batch_size: 1
per_device_eval_batch_size: 1
@@ -1,3 +1,4 @@
import numpy as np
from datasets import load_dataset
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, Subset
@@ -5,19 +6,103 @@ from torch.utils.data import Dataset, Subset
from .prompt_dialogue import PromptGeneratedDataset
QA_SPECIAL_TOKENS = {"Question": "<question>", "Answer": "<answer>"}
SUMMARIZATION_SPECIAL_TOKENS = {"Text": "", "Summary": "TL;DR:"}
summarization_name_mapping = {
"cnn_dailymail": ("article", "highlights"),
"samsum": ("dialogue", "summary"),
"xsum": ("document", "summary"),
"multi_news": ("document", "summary"),
"scitldr": ("source", "target"),
"billsum": ("text", "summary"),
"reddit": ("content", "summary"),
}
summarization_config_mapping = {
"cnn_dailymail": ("3.0.0",),
"samsum": (),
"xsum": (),
"multi_news": (),
"scitldr": ("AIC",),
"billsum": (),
"reddit": (),
}
QA_DATASETS = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_noconext"]
SUMMARIZATION_DATASETS = ["xsum", "cnn_dailymail", "samsum", "multi_news"]
class SquadV2Dataset(Dataset):
def __init__(self, cache_dir, split):
self.dataset = load_dataset("squad_v2", cache_dir=cache_dir, split=split)
def index_squad_v2(example):
return example["title"] + ". " + example["context"] + " " + example["question"], example["answers"]["text"][0]
def index_trivia_qa_nocontext(example):
# dummy return one randomly
return example["question"], example["answer"]["aliases"][np.random.randint(len(example["answer"]["aliases"]))]
def index_trivia_qa_context(example):
question = example["question"]
title = example["title"][np.random.randint(len(example["title"]))]
context = example["search_context"][np.random.randint(len(example["search_context"]))]
answer = example["answer"]["aliases"][np.random.randint(len(example["answer"]["aliases"]))]
return title + ". " + context + " " + question, answer
def index_adversarial_qa(example):
return example["title"] + ". " + example["context"] + " " + example["question"], example["answers"]["text"][0]
class QADataset(Dataset):
def __init__(self, dataset, cache_dir, split):
if dataset == "squad_v2":
self.index_fn = index_squad_v2
self.dataset = load_dataset("squad_v2", cache_dir=cache_dir, split=split)
elif dataset == "trivia_qa_nocontext":
self.index_fn = index_trivia_qa_nocontext
self.dataset = load_dataset("trivia_qa", "rc.nocontext")
elif dataset == "trivia_qa_context":
self.index_fn = index_trivia_qa_context
self.dataset = load_dataset("trivia_qa", "rc")
elif dataset == "adversarial_qa":
self.index_fn = index_adversarial_qa
self.dataset = load_dataset("adversarial_qa", "adversarialQA")
else:
raise ValueError("Unknown dataset : " + dataset)
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
data = self.dataset[idx]
# return first answer form list of possible answers
return data["title"] + ". " + data["context"] + " " + data["question"], data["answers"]["text"][0]
return self.index_fn(data)
def index_summary_default(text, summary):
return text, summary
def index_summary_merge(text, summary):
return " ".join(text), " ".join(summary)
class SummarizationDataset(Dataset):
def __init__(self, dataset, cache_dir, split):
self.dataset = load_dataset(dataset, *summarization_config_mapping[dataset], cache_dir=cache_dir, split=split)
self.summary_column, self.text_column = summarization_name_mapping[dataset]
self.preprocess_fn = index_summary_merge if dataset == "scitdlr" else index_summary_merge
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
data = self.dataset[idx]
text, summary = data[self.text_column], data[self.summary_column]
text, summary = self.preprocess_fn(text, summary)
return "".join(
SUMMARIZATION_SPECIAL_TOKENS["Text"], text, " ", SUMMARIZATION_SPECIAL_TOKENS["Summary"], summary
)
class WebGPT(Dataset):
@@ -58,9 +143,14 @@ def train_val_dataset(dataset, val_split=0.2):
def get_one_dataset(conf, dataset_name):
dataset_name = dataset_name.lower()
if dataset_name == "squadv2":
train = SquadV2Dataset(conf.cache_dir, "train")
eval = SquadV2Dataset(conf.cache_dir, "validation")
if dataset_name in ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_noconext"]:
train = QADataset(dataset_name, conf.cache_dir, "train")
eval = QADataset(dataset_name, conf.cache_dir, "validation")
elif dataset_name in ["xsum", "cnn_dailymail", "samsum", "multi_news", "scitldr", "billsum", "reddit"]:
train = SummarizationDataset(dataset_name, conf.cache_dir, "train")
eval = SummarizationDataset(dataset_name, conf.cache_dir, "validation")
elif dataset_name == "webgpt":
dataset = WebGPT()
train, eval = train_val_dataset(dataset, val_split=0.2)
+30
View File
@@ -1,3 +1,5 @@
import torch
import torch.nn.functional as F
from torch import nn
@@ -13,3 +15,31 @@ class CrossEntropyLoss(nn.CrossEntropyLoss):
input = input[mask]
target = target[mask]
return super(CrossEntropyLoss, self).forward(input, target)
class PolyLoss(nn.Module):
def __init__(self, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean", epsilon=1.0):
super(PolyLoss, self).__init__()
self.weight = torch.tensor(weight)
self.ignore_index = ignore_index
self.reduction = reduction
self.cross_entropy = CrossEntropyLoss(weight, size_average, ignore_index, reduce, "none")
self.epsilon = epsilon
def forward(self, input, target, mask=None):
if mask is not None:
mask = mask.view(-1).bool()
input = input.view(-1, input.size(-1))
target = target.view(-1)
input = input[mask]
target = target[mask]
onehot_target = F.one_hot(target, num_classes=input.size(-1)).to(device=input.device, dtype=input.dtype)
pt = torch.sum(onehot_target * F.softmax(input, -1), -1)
CE = self.cross_entropy(input, target)
poly1 = CE + self.epsilon * (1 - pt)
if self.reduction == "mean":
poly1 = poly1.mean()
elif self.reduction == "sum":
poly1 = poly1.sum()
return poly1
+10 -7
View File
@@ -1,4 +1,4 @@
from transformers import AutoModelForCausalLM
import transformers
# from .gptj import get_model as get_gptj_model
@@ -25,9 +25,12 @@ def freeze_top_n_layers(model, target_layers):
return model
def get_specific_model(model_name, cache_dir, quantization):
return AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
# if "gpt-j" in model_name.lower():
# return get_gptj_model(model_name, cache_dir, quantization)
# else:
# return AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
def get_specific_model(model_name, cache_dir, quantization, seq2seqmodel):
# encoder-decoder support for Flan-T5 like models
# for now, we can use an argument but in the future,
# we can automate this
if seq2seqmodel:
model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir=cache_dir)
else:
model = transformers.AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
return model
@@ -2,7 +2,9 @@ accelerate==0.15.0
bitsandbytes==0.36.0.post2
datasets==2.8.0
deepspeed==0.7.7
evaluate==0.4.0
mpi4py==3.1.4
nltk==3.8.1
numpy==1.23.0
PyYAML==6.0
scikit_learn==1.2.0
+35 -13
View File
@@ -1,6 +1,6 @@
import argparse
import os
from distutils.util import strtobool
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union
import bitsandbytes
@@ -8,16 +8,16 @@ import torch
from torch import nn
from transformers import PreTrainedModel, Trainer, TrainingArguments
from transformers.training_args import OptimizerNames
from utils import get_dataset, get_loss, get_model, get_tokenizer, read_yamls
os.environ["WANDB_PROJECT"] = "supervised-finetuning"
from utils import get_dataset, get_loss, get_metrics, get_model, get_tokenizer, read_yamls
def compute_metrics(eval_pred):
pred_ids = eval_pred.predictions
labels = eval_pred.label_ids
def compute_metrics(eval_pred, preprocess_fns, metrics):
out = {}
for metric, preprocess_fn in zip(metrics, preprocess_fns):
preds, labels = preprocess_fn(eval_pred)
out = dict(**out, **metric.compute(predictions=preds, references=labels))
return {"accuracy": (pred_ids[labels > 0] == labels[labels > 0]).mean()}
return out
def preprocess_logits_for_metrics(logits, labels):
@@ -31,18 +31,22 @@ class SFTTrainer(Trainer):
model: Union[PreTrainedModel, nn.Module] = None,
args: TrainingArguments = None,
loss_function: str = "CrossEntropyLoss",
poly_eps: float = 1.0,
**kwargs,
):
super().__init__(model, args, **kwargs)
# By default CrossEntropyLoss ignores padding_index -100, but just in case use our own loss_fct
self.loss_fct = get_loss(loss_function)
self.loss_fct = get_loss(loss_function, poly_eps)
def compute_loss(self, model, inputs, return_outputs=False):
labels_mask = inputs.pop("label_masks")
targets = inputs.pop("targets")
outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs.get("attention_mask", None))
outputs = model(
input_ids=inputs["input_ids"],
attention_mask=inputs.get("attention_mask", None),
)
loss = self.loss_fct(outputs.get("logits"), targets, mask=labels_mask)
@@ -54,7 +58,10 @@ class SFTTrainer(Trainer):
labels_mask = inputs.pop("label_masks")
targets = inputs.pop("targets")
outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs.get("attention_mask", None))
outputs = model(
input_ids=inputs["input_ids"],
attention_mask=inputs.get("attention_mask", None),
)
logits = outputs.get("logits")
@@ -92,6 +99,7 @@ def argument_parsing(notebook=False, notebook_args=None):
parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument("--deepspeed", action="store_true")
parser.add_argument("--no-deepspeed", dest="deepspeed", action="store_false")
parser.add_argument("--wandb-entity", type=str, default="open-assistant")
parser.set_defaults(deepspeed=False)
if notebook:
@@ -111,8 +119,10 @@ def argument_parsing(notebook=False, notebook_args=None):
else:
conf.update(configs[name])
conf["wandb_entity"] = args.wandb_entity
conf["local_rank"] = args.local_rank
conf["deepspeed"] = args.deepspeed
# Override config from command-line
parser = argparse.ArgumentParser()
for key, value in conf.items():
@@ -131,8 +141,9 @@ if __name__ == "__main__":
model = get_model(training_conf, tokenizer)
train, evals, collate_fn = get_dataset(training_conf, tokenizer)
metrics, preprocess_fns = get_metrics(training_conf, tokenizer)
optimizer = OptimizerNames.ADAMW_BNB if training_conf.quantization else None
optimizer = OptimizerNames.ADAMW_BNB if training_conf.quantization else OptimizerNames.ADAMW_HF
if training_conf.quantization:
for module in model.modules():
@@ -166,15 +177,26 @@ if __name__ == "__main__":
)
assert len(evals) > 0
if not training_conf.deepspeed or training_conf.local_rank == 0:
import wandb
wandb.init(
project="supervised-finetuning",
entity=training_conf.wandb_entity,
name=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned",
)
trainer = SFTTrainer(
model,
args,
loss_function=training_conf.loss_fn,
poly_eps=training_conf.poly_eps,
train_dataset=train,
eval_dataset=evals,
data_collator=collate_fn,
tokenizer=tokenizer,
compute_metrics=compute_metrics,
compute_metrics=partial(compute_metrics, metrics=metrics, preprocess_fns=preprocess_fns),
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
trainer.train()
+67 -6
View File
@@ -1,17 +1,21 @@
from functools import partial
from pathlib import Path
import evaluate
import nltk
import numpy as np
import transformers
import yaml
from custom_datasets import QA_SPECIAL_TOKENS, get_one_dataset
from custom_datasets import QA_DATASETS, QA_SPECIAL_TOKENS, SUMMARIZATION_DATASETS, get_one_dataset
from custom_datasets.dialogue_collator import DialogueDataCollator
from losses import CrossEntropyLoss
from losses import CrossEntropyLoss, PolyLoss
from models import freeze_top_n_layers, get_specific_model
from sklearn.model_selection import train_test_split
from torch.utils.data import ConcatDataset, Subset
from transformers import AutoTokenizer
def get_tokenizer(conf):
tokenizer = AutoTokenizer.from_pretrained(conf.model_name, cache_dir=conf.cache_dir)
tokenizer = transformers.AutoTokenizer.from_pretrained(conf.model_name, cache_dir=conf.cache_dir)
if "galactica" in conf.model_name:
tokenizer.add_special_tokens({"pad_token": "<pad>", "eos_token": "</s>"})
@@ -32,8 +36,63 @@ def get_tokenizer(conf):
return tokenizer
def default_preprocess(eval_pred, ignote_negative_labels=True):
preds, labels = eval_pred.predictions, eval_pred.label_ids
if not ignote_negative_labels:
return preds, labels
mask = labels > 0
return preds[mask], labels[mask]
# placeholder for now
def preprocess_qa(eval_pred):
return (eval_pred.predictions, eval_pred.label_ids)
def postprocess_summarization(preds, labels):
preds = [pred.strip() for pred in preds]
labels = [label.strip() for label in labels]
preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
return preds, labels
def preprocess_summarization(eval_pred, tokenizer, ignore_pad_token_for_loss=True):
preds, labels = eval_pred
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
if ignore_pad_token_for_loss:
labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
decoded_preds, decoded_labels = postprocess_summarization(decoded_preds, decoded_labels)
return decoded_preds, decoded_labels
def get_metrics(conf, tokenizer):
# the reason behind using a list is that we might want to extend the list of our
# metrics in the future for more thorough evaluation
metrics, preprocess_fns = [evaluate.load("accuracy")], [default_preprocess]
if any(dataset in QA_DATASETS for dataset in conf.datasets):
raise ValueError("TODO")
metrics.append(evaluate.load("squad_v2"))
preprocess_fns.append(preprocess_qa)
if any(dataset in SUMMARIZATION_DATASETS for dataset in conf.datasets):
raise ValueError("TODO")
metrics.append(evaluate.load("rouge"))
preprocess_fns.append(
partial(preprocess_summarization, tokenizer, ignore_pad_token_for_loss=conf.ignore_pad_token_for_loss)
)
return metrics, preprocess_fns
def get_model(conf, tokenizer):
model = get_specific_model(conf.model_name, conf.cache_dir, conf.quantization)
model = get_specific_model(conf.model_name, conf.cache_dir, conf.quantization, conf.seq2seqmodel)
if len(tokenizer) != model.get_input_embeddings().num_embeddings:
assert not conf.freeze_layer, "Cannot change the number of embeddings if the model is frozen."
@@ -65,9 +124,11 @@ def get_dataset(conf, tokenizer):
return train, evals, collate_fn
def get_loss(loss):
def get_loss(loss, poly_eps):
if loss == "CrossEntropyLoss":
return CrossEntropyLoss()
elif loss == "Poly":
return PolyLoss(epsilon=poly_eps)
else:
raise ValueError(f"Loss {loss} not supported")