mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-01 16:50:12 +08:00
Merge pull request #638 from LAION-AI/ekurtulus/main
Changes on #619. Datasets is getting a bit dirty. I will do a refactoring this week.
This commit is contained in:
@@ -23,8 +23,10 @@ defaults:
|
||||
eval_size:
|
||||
log_dir: "base"
|
||||
quantization: false
|
||||
seq2seqmodel: false
|
||||
poly_eps: 1.0
|
||||
|
||||
galactica-125:
|
||||
galactica-125m:
|
||||
learning_rate: 5e-5
|
||||
model_name: facebook/galactica-125m
|
||||
weight_decay: 0.01
|
||||
@@ -58,7 +60,7 @@ codegen:
|
||||
|
||||
debug:
|
||||
eval_steps: 20
|
||||
eval_size: 100
|
||||
eval_size: 20
|
||||
gradient_accumulation_steps: 1
|
||||
per_device_train_batch_size: 1
|
||||
per_device_eval_batch_size: 1
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import Dataset, Subset
|
||||
@@ -5,19 +6,103 @@ from torch.utils.data import Dataset, Subset
|
||||
from .prompt_dialogue import PromptGeneratedDataset
|
||||
|
||||
# Special-token strings for question-answering samples, keyed by the role they mark.
QA_SPECIAL_TOKENS = dict(Question="<question>", Answer="<answer>")
# Strings placed before the text (empty) and before the summary ("TL;DR:") of a summarization sample.
SUMMARIZATION_SPECIAL_TOKENS = dict(Text="", Summary="TL;DR:")
|
||||
|
||||
# One row per summarization dataset:
#   name -> ((first column, second column), extra positional config args for load_dataset)
# The column pairs read as (source-text column, summary column), e.g. ("article", "highlights").
_SUMMARIZATION_SPECS = {
    "cnn_dailymail": (("article", "highlights"), ("3.0.0",)),
    "samsum": (("dialogue", "summary"), ()),
    "xsum": (("document", "summary"), ()),
    "multi_news": (("document", "summary"), ()),
    "scitldr": (("source", "target"), ("AIC",)),
    "billsum": (("text", "summary"), ()),
    "reddit": (("content", "summary"), ()),
}

# dataset name -> column-name pair
summarization_name_mapping = {name: columns for name, (columns, _) in _SUMMARIZATION_SPECS.items()}
# dataset name -> tuple of extra positional arguments (config name), empty when none is needed
summarization_config_mapping = {name: config for name, (_, config) in _SUMMARIZATION_SPECS.items()}
|
||||
|
||||
# Dataset names handled as question-answering datasets.
# "trivia_qa_nocontext" must match the spelling QADataset dispatches on.
QA_DATASETS = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_nocontext"]
# Dataset names handled as summarization datasets when selecting metrics.
SUMMARIZATION_DATASETS = ["xsum", "cnn_dailymail", "samsum", "multi_news"]
|
||||
|
||||
|
||||
class SquadV2Dataset(Dataset):
|
||||
def __init__(self, cache_dir, split):
|
||||
self.dataset = load_dataset("squad_v2", cache_dir=cache_dir, split=split)
|
||||
def index_squad_v2(example):
    """Turn one SQuAD v2 record into a ``(prompt, answer)`` pair.

    The prompt is ``"<title>. <context> <question>"``. The answer is the first
    entry of ``example["answers"]["text"]``; when that list is empty (an
    unanswerable question) the answer is the empty string instead of raising
    ``IndexError``.
    """
    prompt = example["title"] + ". " + example["context"] + " " + example["question"]
    answer_texts = example["answers"]["text"]
    # Guard against records that carry no answer at all.
    answer = answer_texts[0] if answer_texts else ""
    return prompt, answer
|
||||
|
||||
|
||||
def index_trivia_qa_nocontext(example):
    """Return ``(question, answer)`` where the answer is one randomly chosen alias."""
    aliases = example["answer"]["aliases"]
    # No preferred alias is known, so draw one index uniformly at random.
    chosen = np.random.randint(len(aliases))
    return example["question"], aliases[chosen]
|
||||
|
||||
|
||||
def index_trivia_qa_context(example):
    """Return ``(prompt, answer)`` built from randomly chosen title, context and alias.

    The prompt is ``"<title>. <context> <question>"``. One entry is drawn at
    random from each of ``example["title"]``, ``example["search_context"]``
    and ``example["answer"]["aliases"]``, in that order.
    """

    def draw(options):
        # Uniformly pick a single element of the sequence.
        return options[np.random.randint(len(options))]

    question = example["question"]
    picked_title = draw(example["title"])
    picked_context = draw(example["search_context"])
    picked_answer = draw(example["answer"]["aliases"])

    prompt = picked_title + ". " + picked_context + " " + question
    return prompt, picked_answer
|
||||
|
||||
|
||||
def index_adversarial_qa(example):
    """Return ``("<title>. <context> <question>", first answer text)`` for one record."""
    prompt_parts = (example["title"], ". ", example["context"], " ", example["question"])
    first_answer = example["answers"]["text"][0]
    return "".join(prompt_parts), first_answer
|
||||
|
||||
|
||||
class QADataset(Dataset):
    """Question-answering dataset wrapper.

    Loads one of the supported datasets and exposes each record as whatever
    the dataset-specific index function returns (a ``(prompt, answer)`` pair).

    Args:
        dataset: one of ``"squad_v2"``, ``"trivia_qa_nocontext"``,
            ``"trivia_qa_context"`` or ``"adversarial_qa"``.
        cache_dir: forwarded to ``load_dataset`` as ``cache_dir``.
        split: forwarded to ``load_dataset`` as ``split``.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """

    def __init__(self, dataset, cache_dir, split):
        # dataset name -> (index function, positional arguments for load_dataset)
        sources = {
            "squad_v2": (index_squad_v2, ("squad_v2",)),
            "trivia_qa_nocontext": (index_trivia_qa_nocontext, ("trivia_qa", "rc.nocontext")),
            "trivia_qa_context": (index_trivia_qa_context, ("trivia_qa", "rc")),
            "adversarial_qa": (index_adversarial_qa, ("adversarial_qa", "adversarialQA")),
        }
        if dataset not in sources:
            raise ValueError("Unknown dataset : " + dataset)

        self.index_fn, load_args = sources[dataset]
        # Every source gets the same cache_dir/split so self.dataset is always
        # a single split that can be indexed by integer position.
        self.dataset = load_dataset(*load_args, cache_dir=cache_dir, split=split)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Delegate formatting to the dataset-specific index function.
        return self.index_fn(self.dataset[idx])
|
||||
|
||||
|
||||
def index_summary_default(text, summary):
    """Identity preprocessing: hand back ``(text, summary)`` untouched."""
    unchanged = (text, summary)
    return unchanged
|
||||
|
||||
|
||||
def index_summary_merge(text, summary):
    """Join the pieces of ``text`` and of ``summary`` with single spaces.

    Both arguments are iterated by ``str.join``, so they are expected to be
    sequences of strings.
    """
    separator = " "
    merged_text = separator.join(text)
    merged_summary = separator.join(summary)
    return merged_text, merged_summary
|
||||
|
||||
|
||||
class SummarizationDataset(Dataset):
    """Summarization dataset wrapper producing one prompt string per record.

    Args:
        dataset: key of ``summarization_name_mapping`` /
            ``summarization_config_mapping``; also passed to ``load_dataset``
            as the dataset name.
        cache_dir: forwarded to ``load_dataset`` as ``cache_dir``.
        split: forwarded to ``load_dataset`` as ``split``.
    """

    def __init__(self, dataset, cache_dir, split):
        self.dataset = load_dataset(dataset, *summarization_config_mapping[dataset], cache_dir=cache_dir, split=split)
        # The mapping stores (text column, summary column), in that order.
        self.text_column, self.summary_column = summarization_name_mapping[dataset]
        # Only scitldr needs its fields merged with spaces (presumably they are
        # lists of strings -- confirm against the dataset); every other dataset
        # is passed through unchanged.
        self.preprocess_fn = index_summary_merge if dataset == "scitldr" else index_summary_default

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        data = self.dataset[idx]
        text, summary = self.preprocess_fn(data[self.text_column], data[self.summary_column])

        # str.join takes a single iterable, so the pieces are wrapped in a tuple.
        return "".join(
            (SUMMARIZATION_SPECIAL_TOKENS["Text"], text, " ", SUMMARIZATION_SPECIAL_TOKENS["Summary"], summary)
        )
|
||||
|
||||
|
||||
class WebGPT(Dataset):
|
||||
@@ -58,9 +143,14 @@ def train_val_dataset(dataset, val_split=0.2):
|
||||
def get_one_dataset(conf, dataset_name):
|
||||
dataset_name = dataset_name.lower()
|
||||
|
||||
if dataset_name == "squadv2":
|
||||
train = SquadV2Dataset(conf.cache_dir, "train")
|
||||
eval = SquadV2Dataset(conf.cache_dir, "validation")
|
||||
if dataset_name in ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_noconext"]:
|
||||
train = QADataset(dataset_name, conf.cache_dir, "train")
|
||||
eval = QADataset(dataset_name, conf.cache_dir, "validation")
|
||||
|
||||
elif dataset_name in ["xsum", "cnn_dailymail", "samsum", "multi_news", "scitldr", "billsum", "reddit"]:
|
||||
train = SummarizationDataset(dataset_name, conf.cache_dir, "train")
|
||||
eval = SummarizationDataset(dataset_name, conf.cache_dir, "validation")
|
||||
|
||||
elif dataset_name == "webgpt":
|
||||
dataset = WebGPT()
|
||||
train, eval = train_val_dataset(dataset, val_split=0.2)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
|
||||
@@ -13,3 +15,31 @@ class CrossEntropyLoss(nn.CrossEntropyLoss):
|
||||
input = input[mask]
|
||||
target = target[mask]
|
||||
return super(CrossEntropyLoss, self).forward(input, target)
|
||||
|
||||
|
||||
class PolyLoss(nn.Module):
    """Poly-1 loss: per-token cross entropy plus ``epsilon * (1 - p_target)``.

    Args:
        weight: optional per-class weights (tensor or sequence); ``None``
            disables class weighting.
        size_average, reduce: legacy arguments forwarded to the underlying
            cross-entropy module.
        ignore_index: target value whose positions are excluded from the loss.
        reduction: ``"mean"``, ``"sum"`` or anything else for no reduction.
        epsilon: coefficient of the ``(1 - p_target)`` term.
    """

    def __init__(self, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean", epsilon=1.0):
        super(PolyLoss, self).__init__()
        # weight defaults to None, which torch.tensor() cannot convert.
        self.weight = None if weight is None else torch.as_tensor(weight)
        self.ignore_index = ignore_index
        self.reduction = reduction
        # Masking is done in forward(), so the stock per-element cross entropy
        # is sufficient here; the reduction is applied after the poly term.
        self.cross_entropy = nn.CrossEntropyLoss(
            weight=self.weight,
            size_average=size_average,
            ignore_index=ignore_index,
            reduce=reduce,
            reduction="none",
        )
        self.epsilon = epsilon

    def forward(self, input, target, mask=None):
        """Compute the loss.

        Args:
            input: logits whose last dimension is the class dimension.
            target: class indices with ``input``'s shape minus the last dim.
            mask: optional tensor with one entry per target; truthy entries
                are kept.

        Returns:
            A scalar for ``"mean"``/``"sum"`` reduction, otherwise one value
            per kept position.
        """
        # Flatten to (positions, classes) / (positions,) so batched inputs work
        # with or without a mask.
        input = input.reshape(-1, input.size(-1))
        target = target.reshape(-1)

        # Drop ignored targets up front: they are not valid class indices and
        # must not contribute to the mean.
        keep = target != self.ignore_index
        if mask is not None:
            keep = keep & mask.reshape(-1).bool()
        input = input[keep]
        target = target[keep]

        # Probability assigned to the target class (gather avoids building a
        # positions x classes one-hot matrix).
        pt = F.softmax(input, dim=-1).gather(-1, target.unsqueeze(-1)).squeeze(-1)
        CE = self.cross_entropy(input, target)
        poly1 = CE + self.epsilon * (1 - pt)
        if self.reduction == "mean":
            poly1 = poly1.mean()
        elif self.reduction == "sum":
            poly1 = poly1.sum()
        return poly1
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from transformers import AutoModelForCausalLM
|
||||
import transformers
|
||||
|
||||
# from .gptj import get_model as get_gptj_model
|
||||
|
||||
@@ -25,9 +25,12 @@ def freeze_top_n_layers(model, target_layers):
|
||||
return model
|
||||
|
||||
|
||||
def get_specific_model(model_name, cache_dir, quantization):
|
||||
return AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
# if "gpt-j" in model_name.lower():
|
||||
# return get_gptj_model(model_name, cache_dir, quantization)
|
||||
# else:
|
||||
# return AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
def get_specific_model(model_name, cache_dir, quantization, seq2seqmodel):
    """Load a pretrained model by name.

    Args:
        model_name: identifier passed to ``from_pretrained``.
        cache_dir: forwarded to ``from_pretrained`` as ``cache_dir``.
        quantization: accepted but not used by this function.
        seq2seqmodel: truthy to load an encoder-decoder model
            (``AutoModelForSeq2SeqLM``), otherwise a causal LM.
    """
    # Encoder-decoder support for Flan-T5 like models. The architecture is
    # chosen by an explicit flag for now; this could be automated later.
    auto_class = transformers.AutoModelForSeq2SeqLM if seq2seqmodel else transformers.AutoModelForCausalLM
    return auto_class.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
|
||||
@@ -2,7 +2,9 @@ accelerate==0.15.0
|
||||
bitsandbytes==0.36.0.post2
|
||||
datasets==2.8.0
|
||||
deepspeed==0.7.7
|
||||
evaluate==0.4.0
|
||||
mpi4py==3.1.4
|
||||
nltk==3.8.1
|
||||
numpy==1.23.0
|
||||
PyYAML==6.0
|
||||
scikit_learn==1.2.0
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import argparse
|
||||
import os
|
||||
from distutils.util import strtobool
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import bitsandbytes
|
||||
@@ -8,16 +8,16 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import PreTrainedModel, Trainer, TrainingArguments
|
||||
from transformers.training_args import OptimizerNames
|
||||
from utils import get_dataset, get_loss, get_model, get_tokenizer, read_yamls
|
||||
|
||||
os.environ["WANDB_PROJECT"] = "supervised-finetuning"
|
||||
from utils import get_dataset, get_loss, get_metrics, get_model, get_tokenizer, read_yamls
|
||||
|
||||
|
||||
def compute_metrics(eval_pred, preprocess_fns, metrics):
    """Run every metric on the evaluation output and merge the results.

    Args:
        eval_pred: evaluation output handed to each preprocess function.
        preprocess_fns: callables mapping ``eval_pred`` to
            ``(predictions, references)``; paired positionally with ``metrics``.
        metrics: objects exposing ``compute(predictions=..., references=...)``
            that returns a dict of scores.

    Returns:
        One dict holding the scores of all metrics. If two metrics report the
        same key, the later metric's value wins.
    """
    out = {}
    for metric, preprocess_fn in zip(metrics, preprocess_fns):
        preds, labels = preprocess_fn(eval_pred)
        # update() tolerates repeated keys, unlike dict(**out, **scores).
        out.update(metric.compute(predictions=preds, references=labels))
    return out
|
||||
|
||||
|
||||
def preprocess_logits_for_metrics(logits, labels):
|
||||
@@ -31,18 +31,22 @@ class SFTTrainer(Trainer):
|
||||
model: Union[PreTrainedModel, nn.Module] = None,
|
||||
args: TrainingArguments = None,
|
||||
loss_function: str = "CrossEntropyLoss",
|
||||
poly_eps: float = 1.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(model, args, **kwargs)
|
||||
|
||||
# By default CrossEntropyLoss ignores padding_index -100, but just in case use our own loss_fct
|
||||
self.loss_fct = get_loss(loss_function)
|
||||
self.loss_fct = get_loss(loss_function, poly_eps)
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
labels_mask = inputs.pop("label_masks")
|
||||
targets = inputs.pop("targets")
|
||||
|
||||
outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs.get("attention_mask", None))
|
||||
outputs = model(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs.get("attention_mask", None),
|
||||
)
|
||||
|
||||
loss = self.loss_fct(outputs.get("logits"), targets, mask=labels_mask)
|
||||
|
||||
@@ -54,7 +58,10 @@ class SFTTrainer(Trainer):
|
||||
labels_mask = inputs.pop("label_masks")
|
||||
targets = inputs.pop("targets")
|
||||
|
||||
outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs.get("attention_mask", None))
|
||||
outputs = model(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs.get("attention_mask", None),
|
||||
)
|
||||
|
||||
logits = outputs.get("logits")
|
||||
|
||||
@@ -92,6 +99,7 @@ def argument_parsing(notebook=False, notebook_args=None):
|
||||
parser.add_argument("--local_rank", type=int, default=-1)
|
||||
parser.add_argument("--deepspeed", action="store_true")
|
||||
parser.add_argument("--no-deepspeed", dest="deepspeed", action="store_false")
|
||||
parser.add_argument("--wandb-entity", type=str, default="open-assistant")
|
||||
parser.set_defaults(deepspeed=False)
|
||||
|
||||
if notebook:
|
||||
@@ -111,8 +119,10 @@ def argument_parsing(notebook=False, notebook_args=None):
|
||||
else:
|
||||
conf.update(configs[name])
|
||||
|
||||
conf["wandb_entity"] = args.wandb_entity
|
||||
conf["local_rank"] = args.local_rank
|
||||
conf["deepspeed"] = args.deepspeed
|
||||
|
||||
# Override config from command-line
|
||||
parser = argparse.ArgumentParser()
|
||||
for key, value in conf.items():
|
||||
@@ -131,8 +141,9 @@ if __name__ == "__main__":
|
||||
model = get_model(training_conf, tokenizer)
|
||||
|
||||
train, evals, collate_fn = get_dataset(training_conf, tokenizer)
|
||||
metrics, preprocess_fns = get_metrics(training_conf, tokenizer)
|
||||
|
||||
optimizer = OptimizerNames.ADAMW_BNB if training_conf.quantization else None
|
||||
optimizer = OptimizerNames.ADAMW_BNB if training_conf.quantization else OptimizerNames.ADAMW_HF
|
||||
|
||||
if training_conf.quantization:
|
||||
for module in model.modules():
|
||||
@@ -166,15 +177,26 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
assert len(evals) > 0
|
||||
|
||||
if not training_conf.deepspeed or training_conf.local_rank == 0:
|
||||
import wandb
|
||||
|
||||
wandb.init(
|
||||
project="supervised-finetuning",
|
||||
entity=training_conf.wandb_entity,
|
||||
name=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned",
|
||||
)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model,
|
||||
args,
|
||||
loss_function=training_conf.loss_fn,
|
||||
poly_eps=training_conf.poly_eps,
|
||||
train_dataset=train,
|
||||
eval_dataset=evals,
|
||||
data_collator=collate_fn,
|
||||
tokenizer=tokenizer,
|
||||
compute_metrics=compute_metrics,
|
||||
compute_metrics=partial(compute_metrics, metrics=metrics, preprocess_fns=preprocess_fns),
|
||||
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
@@ -1,17 +1,21 @@
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import evaluate
|
||||
import nltk
|
||||
import numpy as np
|
||||
import transformers
|
||||
import yaml
|
||||
from custom_datasets import QA_SPECIAL_TOKENS, get_one_dataset
|
||||
from custom_datasets import QA_DATASETS, QA_SPECIAL_TOKENS, SUMMARIZATION_DATASETS, get_one_dataset
|
||||
from custom_datasets.dialogue_collator import DialogueDataCollator
|
||||
from losses import CrossEntropyLoss
|
||||
from losses import CrossEntropyLoss, PolyLoss
|
||||
from models import freeze_top_n_layers, get_specific_model
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import ConcatDataset, Subset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
def get_tokenizer(conf):
|
||||
tokenizer = AutoTokenizer.from_pretrained(conf.model_name, cache_dir=conf.cache_dir)
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(conf.model_name, cache_dir=conf.cache_dir)
|
||||
|
||||
if "galactica" in conf.model_name:
|
||||
tokenizer.add_special_tokens({"pad_token": "<pad>", "eos_token": "</s>"})
|
||||
@@ -32,8 +36,63 @@ def get_tokenizer(conf):
|
||||
return tokenizer
|
||||
|
||||
|
||||
def default_preprocess(eval_pred, ignote_negative_labels=True):
    """Extract ``(predictions, labels)`` from ``eval_pred`` for metric computation.

    With ``ignote_negative_labels`` left truthy, only positions whose label is
    strictly greater than zero are kept (so label ``0`` is dropped together
    with negative labels). Otherwise both arrays are returned unfiltered.
    """
    predictions = eval_pred.predictions
    references = eval_pred.label_ids

    if ignote_negative_labels:
        keep = references > 0
        return predictions[keep], references[keep]

    return predictions, references
|
||||
|
||||
|
||||
# placeholder for now
|
||||
def preprocess_qa(eval_pred):
    """Placeholder QA preprocessing: pass predictions and label ids through as-is."""
    predictions, references = eval_pred.predictions, eval_pred.label_ids
    return predictions, references
|
||||
|
||||
|
||||
def postprocess_summarization(preds, labels):
    """Strip each string and put every sentence on its own line.

    Sentences are found with ``nltk.sent_tokenize`` and re-joined with
    newlines. Returns the processed ``(preds, labels)`` as two lists.
    """

    def one_sentence_per_line(texts):
        # Strip surrounding whitespace first, then split into sentences.
        return ["\n".join(nltk.sent_tokenize(text.strip())) for text in texts]

    return one_sentence_per_line(preds), one_sentence_per_line(labels)
|
||||
|
||||
|
||||
def preprocess_summarization(eval_pred, tokenizer, ignore_pad_token_for_loss=True):
    """Decode predicted and reference token ids into text for summarization metrics.

    Args:
        eval_pred: unpacked as ``(predictions, labels)``.
        tokenizer: provides ``batch_decode`` and ``pad_token_id``.
        ignore_pad_token_for_loss: when truthy, label entries equal to ``-100``
            are replaced by the pad token id before decoding.

    Returns:
        ``(decoded predictions, decoded labels)`` after
        ``postprocess_summarization``.
    """
    pred_ids, label_ids = eval_pred
    pred_texts = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)

    if ignore_pad_token_for_loss:
        # Swap the -100 placeholder for the pad id so the labels can be decoded.
        label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)
    label_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    pred_texts, label_texts = postprocess_summarization(pred_texts, label_texts)
    return pred_texts, label_texts
|
||||
|
||||
|
||||
def get_metrics(conf, tokenizer):
    """Build the evaluation metrics and their matching preprocess functions.

    Args:
        conf: configuration; ``conf.datasets`` is scanned for QA and
            summarization dataset names.
        tokenizer: bound into the summarization preprocess function.

    Returns:
        ``(metrics, preprocess_fns)``: two lists paired by position.

    Raises:
        ValueError: if any configured dataset is a QA or summarization dataset;
            their metrics are not enabled yet.
    """
    # Lists are used because more metrics may be added later for a more
    # thorough evaluation.
    metrics, preprocess_fns = [evaluate.load("accuracy")], [default_preprocess]

    if any(dataset in QA_DATASETS for dataset in conf.datasets):
        # NOTE: QA metrics are deliberately disabled; the lines after the raise
        # are the intended implementation once they are enabled.
        raise ValueError("TODO")
        metrics.append(evaluate.load("squad_v2"))
        preprocess_fns.append(preprocess_qa)
    if any(dataset in SUMMARIZATION_DATASETS for dataset in conf.datasets):
        # NOTE: summarization metrics are deliberately disabled as well.
        raise ValueError("TODO")
        metrics.append(evaluate.load("rouge"))
        preprocess_fns.append(
            # Bind the tokenizer by keyword: the first positional parameter of
            # preprocess_summarization is eval_pred, supplied later by the caller.
            partial(
                preprocess_summarization,
                tokenizer=tokenizer,
                ignore_pad_token_for_loss=conf.ignore_pad_token_for_loss,
            )
        )

    return metrics, preprocess_fns
|
||||
|
||||
|
||||
def get_model(conf, tokenizer):
|
||||
model = get_specific_model(conf.model_name, conf.cache_dir, conf.quantization)
|
||||
model = get_specific_model(conf.model_name, conf.cache_dir, conf.quantization, conf.seq2seqmodel)
|
||||
|
||||
if len(tokenizer) != model.get_input_embeddings().num_embeddings:
|
||||
assert not conf.freeze_layer, "Cannot change the number of embeddings if the model is frozen."
|
||||
@@ -65,9 +124,11 @@ def get_dataset(conf, tokenizer):
|
||||
return train, evals, collate_fn
|
||||
|
||||
|
||||
def get_loss(loss):
|
||||
def get_loss(loss, poly_eps):
    """Instantiate the loss module selected by name.

    Args:
        loss: ``"CrossEntropyLoss"`` or ``"Poly"``.
        poly_eps: epsilon passed to ``PolyLoss``; unused for cross entropy.

    Raises:
        ValueError: for any other loss name.
    """
    if loss == "CrossEntropyLoss":
        return CrossEntropyLoss()
    if loss == "Poly":
        return PolyLoss(epsilon=poly_eps)
    raise ValueError(f"Loss {loss} not supported")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user