mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-05 17:30:48 +08:00
better
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import Dataset, Subset
|
||||
@@ -5,19 +6,100 @@ from torch.utils.data import Dataset, Subset
|
||||
from .prompt_dialogue import PromptGeneratedDataset
|
||||
|
||||
# Marker tokens used when rendering question/answer prompts.
QA_SPECIAL_TOKENS = dict(Question="<question>", Answer="<answer>")
# The source text gets no prefix; the summary is introduced by "TL;DR:".
SUMMARIZATION_SPECIAL_TOKENS = dict(Text="", Summary="TL;DR:")

# One row per supported summarization dataset:
#   (dataset name, (text column, summary column), extra positional load_dataset args)
_SUMMARIZATION_SOURCES = (
    ("cnn_dailymail", ("article", "highlights"), ("3.0.0",)),
    ("samsum", ("dialogue", "summary"), ()),
    ("xsum", ("document", "summary"), ()),
    ("multi_news", ("document", "summary"), ()),
    ("scitldr", ("source", "target"), ("AIC",)),
    ("billsum", ("text", "summary"), ()),
    ("reddit", ("content", "summary"), ()),
)

# dataset name -> names of the two columns read from each record
summarization_name_mapping = {name: columns for name, columns, _ in _SUMMARIZATION_SOURCES}
# dataset name -> extra positional arguments forwarded to load_dataset (e.g. a config name)
summarization_config_mapping = {name: config for name, _, config in _SUMMARIZATION_SOURCES}
|
||||
|
||||
|
||||
class SquadV2Dataset(Dataset):
|
||||
def __init__(self, cache_dir, split):
|
||||
self.dataset = load_dataset("squad_v2", cache_dir=cache_dir, split=split)
|
||||
def index_squad_v2(example):
    """Turn one SQuAD-v2 style record into a ``(prompt, answer)`` pair.

    Args:
        example: mapping with string fields ``title``, ``context`` and
            ``question`` plus ``answers["text"]``, a sequence of candidate
            answer strings.

    Returns:
        ``("<title>. <context> <question>", answer)`` where ``answer`` is the
        first candidate, or ``""`` when the record carries no answer at all.
    """
    prompt = example["title"] + ". " + example["context"] + " " + example["question"]
    candidates = example["answers"]["text"]
    # Records without any answer have an empty list here; indexing [0]
    # unconditionally would raise IndexError for them.
    answer = candidates[0] if len(candidates) > 0 else ""
    return prompt, answer
|
||||
|
||||
|
||||
def index_trivia_qa_nocontext(example):
    """Build a ``(question, answer)`` pair from a context-free TriviaQA record.

    The answer is one entry of ``example["answer"]["aliases"]`` drawn with
    ``np.random.randint``, so repeated calls on the same record may return
    different aliases.
    """
    question = example["question"]
    aliases = example["answer"]["aliases"]
    # Placeholder strategy: any alias is accepted, so pick one at random.
    picked = np.random.randint(len(aliases))
    return question, aliases[picked]
|
||||
|
||||
|
||||
def index_trivia_qa_context(example):
    """Build a ``(prompt, answer)`` pair from a TriviaQA record with context.

    One title, one search context and one answer alias are each drawn at
    random (in that order) and combined as
    ``("<title>. <context> <question>", alias)``.

    NOTE(review): this reads ``title`` and ``search_context`` as top-level
    keys of ``example`` - confirm the records handed to this function are
    shaped that way.
    """

    def _random_choice(options):
        # Uniform pick via the global NumPy RNG.
        return options[np.random.randint(len(options))]

    question = example["question"]
    title = _random_choice(example["title"])
    context = _random_choice(example["search_context"])
    answer = _random_choice(example["answer"]["aliases"])

    prompt = title + ". " + context + " " + question
    return prompt, answer
|
||||
|
||||
|
||||
def index_adversarial_qa(example):
    """Turn one AdversarialQA record into a ``(prompt, answer)`` pair.

    The prompt is ``"<title>. <context> <question>"`` and the answer is the
    first entry of ``example["answers"]["text"]``.
    """
    title, context, question = example["title"], example["context"], example["question"]
    prompt = title + ". " + context + " " + question
    first_answer = example["answers"]["text"][0]
    return prompt, first_answer
|
||||
|
||||
|
||||
class QADataset(Dataset):
    """Question-answering dataset yielding ``(prompt, answer)`` string pairs.

    Supported ``dataset`` names are ``"squad_v2"``, ``"trivia_qa_nocontext"``,
    ``"trivia_qa_context"`` and ``"adversarial_qa"``.  Each name selects the
    arguments passed to ``load_dataset`` and the function that converts one
    raw record into a pair.

    Args:
        dataset: one of the names above.
        cache_dir: forwarded to ``load_dataset`` as ``cache_dir``.
        split: forwarded to ``load_dataset`` as ``split``.

    Raises:
        ValueError: if ``dataset`` is not a supported name.
    """

    def __init__(self, dataset, cache_dir, split):
        if dataset == "squad_v2":
            self.index_fn = index_squad_v2
            hub_args = ("squad_v2",)
        elif dataset == "trivia_qa_nocontext":
            self.index_fn = index_trivia_qa_nocontext
            hub_args = ("trivia_qa", "rc.nocontext")
        elif dataset == "trivia_qa_context":
            self.index_fn = index_trivia_qa_context
            hub_args = ("trivia_qa", "rc")
        elif dataset == "adversarial_qa":
            self.index_fn = index_adversarial_qa
            hub_args = ("adversarial_qa", "adversarialQA")
        else:
            raise ValueError("Unknown dataset : " + dataset)

        # Every variant must request a concrete split: without `split`,
        # load_dataset returns a mapping of splits, which cannot be indexed
        # by integer and whose len() is the number of splits.
        self.dataset = load_dataset(*hub_args, cache_dir=cache_dir, split=split)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Delegate to the per-dataset converter chosen in __init__.
        return self.index_fn(self.dataset[idx])
|
||||
|
||||
|
||||
def index_summary_default(text, summary):
    """Identity preprocessing: hand back ``(text, summary)`` untouched."""
    pair = (text, summary)
    return pair
|
||||
|
||||
|
||||
def index_summary_merge(text, summary):
    """Collapse two iterables of strings into two space-separated strings.

    Returns ``(" ".join(text), " ".join(summary))``.
    """
    merged_text = " ".join(text)
    merged_summary = " ".join(summary)
    return merged_text, merged_summary
|
||||
|
||||
|
||||
class SummarizationDataset(Dataset):
    """Summarization dataset yielding one ``"<text> TL;DR:<summary>"`` string per item.

    Args:
        dataset: key into ``summarization_name_mapping`` and
            ``summarization_config_mapping``; also passed to ``load_dataset``
            as the dataset name.
        cache_dir: forwarded to ``load_dataset`` as ``cache_dir``.
        split: forwarded to ``load_dataset`` as ``split``.

    Raises:
        KeyError: if ``dataset`` has no entry in the mapping tables.
    """

    def __init__(self, dataset, cache_dir, split):
        self.dataset = load_dataset(dataset, *summarization_config_mapping[dataset], cache_dir=cache_dir, split=split)
        # Mapping entries are ordered (text column, summary column),
        # e.g. ("article", "highlights").
        self.text_column, self.summary_column = summarization_name_mapping[dataset]
        # Only scitldr goes through the merging preprocessor, which joins
        # sequences of strings - presumably because its columns hold lists of
        # sentences (TODO confirm); everything else is passed through as-is.
        self.preprocess_fn = index_summary_merge if dataset == "scitldr" else index_summary_default

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        data = self.dataset[idx]
        text, summary = self.preprocess_fn(data[self.text_column], data[self.summary_column])

        # str.join takes a single iterable, so the pieces are wrapped in a list.
        return "".join(
            [SUMMARIZATION_SPECIAL_TOKENS["Text"], text, " ", SUMMARIZATION_SPECIAL_TOKENS["Summary"], summary]
        )
|
||||
|
||||
|
||||
class WebGPT(Dataset):
|
||||
@@ -58,9 +140,14 @@ def train_val_dataset(dataset, val_split=0.2):
|
||||
def get_one_dataset(conf, dataset_name):
|
||||
dataset_name = dataset_name.lower()
|
||||
|
||||
if dataset_name == "squadv2":
|
||||
train = SquadV2Dataset(conf.cache_dir, "train")
|
||||
eval = SquadV2Dataset(conf.cache_dir, "validation")
|
||||
if dataset_name in ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_noconext"]:
|
||||
train = QADataset(dataset_name, conf.cache_dir, "train")
|
||||
eval = QADataset(dataset_name, conf.cache_dir, "validation")
|
||||
|
||||
elif dataset_name in ["xsum", "cnn_dailymail", "samsum", "multi_news", "scitldr", "billsum", "reddit"]:
|
||||
train = SummarizationDataset(dataset_name, conf.cache_dir, "train")
|
||||
eval = SummarizationDataset(dataset_name, conf.cache_dir, "validation")
|
||||
|
||||
elif dataset_name == "webgpt":
|
||||
dataset = WebGPT()
|
||||
train, eval = train_val_dataset(dataset, val_split=0.2)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
|
||||
@@ -13,3 +15,31 @@ class CrossEntropyLoss(nn.CrossEntropyLoss):
|
||||
input = input[mask]
|
||||
target = target[mask]
|
||||
return super(CrossEntropyLoss, self).forward(input, target)
|
||||
|
||||
|
||||
class PolyLoss(nn.Module):
    """Poly-1 style loss: cross entropy plus ``epsilon * (1 - p_target)``.

    Args:
        weight: optional per-class weights (tensor or array-like); forwarded
            to the wrapped ``CrossEntropyLoss``.
        size_average, reduce: forwarded unchanged to ``CrossEntropyLoss``.
        ignore_index: forwarded to ``CrossEntropyLoss`` and stored on the module.
        reduction: ``"mean"`` or ``"sum"``; any other value returns the
            unreduced per-element loss.
        epsilon: coefficient of the ``(1 - p_target)`` term.
    """

    def __init__(self, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean", epsilon=1.0):
        super().__init__()
        # torch.tensor(None) raises, so the default weight=None must stay None.
        # as_tensor also avoids re-copying a weight that is already a tensor.
        self.weight = None if weight is None else torch.as_tensor(weight)
        self.ignore_index = ignore_index
        self.reduction = reduction
        # The wrapped loss always runs unreduced; reduction is applied in
        # forward() after the polynomial term has been added.
        self.cross_entropy = CrossEntropyLoss(self.weight, size_average, ignore_index, reduce, "none")
        self.epsilon = epsilon

    def forward(self, input, target, mask=None):
        """Compute the loss for ``input`` logits against ``target`` class ids.

        When ``mask`` is given, inputs are flattened to ``(-1, num_classes)``
        and only positions where the flattened mask is true are kept.
        """
        if mask is not None:
            mask = mask.view(-1).bool()
            input = input.view(-1, input.size(-1))
            target = target.view(-1)
            input = input[mask]
            target = target[mask]

        # NOTE(review): one_hot needs valid class indices, so any target still
        # equal to ignore_index is expected to have been removed by `mask`.
        onehot_target = F.one_hot(target, num_classes=input.size(-1)).to(device=input.device, dtype=input.dtype)
        # Probability the model assigns to the target class.
        pt = torch.sum(onehot_target * F.softmax(input, -1), -1)
        ce = self.cross_entropy(input, target)
        poly1 = ce + self.epsilon * (1 - pt)
        if self.reduction == "mean":
            poly1 = poly1.mean()
        elif self.reduction == "sum":
            poly1 = poly1.sum()
        return poly1
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from transformers import AutoModelForCausalLM
|
||||
import transformers
|
||||
|
||||
# from .gptj import get_model as get_gptj_model
|
||||
|
||||
@@ -25,9 +25,12 @@ def freeze_top_n_layers(model, target_layers):
|
||||
return model
|
||||
|
||||
|
||||
def get_specific_model(model_name, cache_dir, quantization, seq2seqmodel=False):
    """Load a pretrained language model by name.

    Args:
        model_name: model identifier passed to ``from_pretrained``.
        cache_dir: forwarded to ``from_pretrained`` as ``cache_dir``.
        quantization: accepted for caller compatibility; not used here.
        seq2seqmodel: when true, load an encoder-decoder model (Flan-T5 like)
            through ``AutoModelForSeq2SeqLM``; otherwise load a causal LM.
            Defaults to ``False`` so three-argument callers keep working.

    Returns:
        The model returned by ``from_pretrained``.
    """
    # Encoder-decoder support is selected by an explicit flag for now; it
    # could be detected automatically in the future.
    if seq2seqmodel:
        model_cls = transformers.AutoModelForSeq2SeqLM
    else:
        model_cls = transformers.AutoModelForCausalLM
    return model_cls.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
|
||||
@@ -3,21 +3,22 @@ import os
|
||||
from distutils.util import strtobool
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import bitsandbytes
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PreTrainedModel, Trainer, TrainingArguments
|
||||
from transformers.training_args import OptimizerNames
|
||||
from utils import get_dataset, get_loss, get_model, get_tokenizer, read_yamls
|
||||
from utils import get_dataset, get_loss, get_metrics, get_model, get_tokenizer, read_yamls
|
||||
|
||||
os.environ["WANDB_PROJECT"] = "supervised-finetuning"
|
||||
|
||||
|
||||
def compute_metrics(eval_pred, preprocess_fn, metrics):
    """Evaluate every metric on the preprocessed predictions.

    Args:
        eval_pred: evaluation output handed to ``preprocess_fn``.
        preprocess_fn: callable mapping ``eval_pred`` to ``(predictions, references)``.
        metrics: iterable of objects exposing
            ``compute(predictions=..., references=...)`` returning a dict.

    Returns:
        A single dict merging the results of all metrics.  If two metrics
        report the same key, the later one wins.
    """
    preds, labels = preprocess_fn(eval_pred)

    out = {}
    for metric in metrics:
        # update() merges in place; rebuilding with dict(**out, **new) raises
        # TypeError as soon as two metrics share a key.
        out.update(metric.compute(predictions=preds, references=labels))

    return out
|
||||
|
||||
|
||||
def preprocess_logits_for_metrics(logits, labels):
|
||||
@@ -31,12 +32,13 @@ class SFTTrainer(Trainer):
|
||||
model: Union[PreTrainedModel, nn.Module] = None,
|
||||
args: TrainingArguments = None,
|
||||
loss_function: str = "CrossEntropyLoss",
|
||||
poly_eps: float = 1.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(model, args, **kwargs)
|
||||
|
||||
# By default CrossEntropyLoss ignores padding_index -100, but just in case use our own loss_fct
|
||||
self.loss_fct = get_loss(loss_function)
|
||||
self.loss_fct = get_loss(loss_function, poly_eps)
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
labels_mask = inputs.pop("label_masks")
|
||||
@@ -92,6 +94,8 @@ def argument_parsing(notebook=False, notebook_args=None):
|
||||
parser.add_argument("--local_rank", type=int, default=-1)
|
||||
parser.add_argument("--deepspeed", action="store_true")
|
||||
parser.add_argument("--no-deepspeed", dest="deepspeed", action="store_false")
|
||||
parser.add_argument("--poly_eps", type=float, default=1.0)
|
||||
parser.add_argument("--seq2seq_model", action="store_true")
|
||||
parser.set_defaults(deepspeed=False)
|
||||
|
||||
if notebook:
|
||||
@@ -131,15 +135,7 @@ if __name__ == "__main__":
|
||||
model = get_model(training_conf, tokenizer)
|
||||
|
||||
train, evals, collate_fn = get_dataset(training_conf, tokenizer)
|
||||
|
||||
optimizer = OptimizerNames.ADAMW_BNB if training_conf.quantization else None
|
||||
|
||||
if training_conf.quantization:
|
||||
for module in model.modules():
|
||||
if isinstance(module, torch.nn.Embedding):
|
||||
bitsandbytes.optim.GlobalOptimManager.get_instance().register_module_override(
|
||||
module, "weight", {"optim_bits": 32}
|
||||
)
|
||||
metrics, preprocess_fn = get_metrics(training_conf)
|
||||
|
||||
args = TrainingArguments(
|
||||
output_dir=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned",
|
||||
@@ -147,7 +143,6 @@ if __name__ == "__main__":
|
||||
warmup_steps=training_conf.warmup_steps,
|
||||
learning_rate=float(training_conf.learning_rate),
|
||||
deepspeed="configs/zero_config.json" if training_conf.deepspeed else None,
|
||||
optim=optimizer,
|
||||
fp16=True,
|
||||
local_rank=training_conf.local_rank,
|
||||
gradient_checkpointing=training_conf.gradient_checkpointing,
|
||||
|
||||
@@ -1,17 +1,21 @@
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import evaluate
|
||||
import nltk
|
||||
import numpy as np
|
||||
import transformers
|
||||
import yaml
|
||||
from custom_datasets import QA_SPECIAL_TOKENS, get_one_dataset
|
||||
from custom_datasets.dialogue_collator import DialogueDataCollator
|
||||
from losses import CrossEntropyLoss
|
||||
from losses import CrossEntropyLoss, PolyLoss
|
||||
from models import freeze_top_n_layers, get_specific_model
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import ConcatDataset, Subset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
def get_tokenizer(conf):
|
||||
tokenizer = AutoTokenizer.from_pretrained(conf.model_name, cache_dir=conf.cache_dir)
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(conf.model_name, cache_dir=conf.cache_dir)
|
||||
|
||||
if "galactica" in conf.model_name:
|
||||
tokenizer.add_special_tokens({"pad_token": "<pad>", "eos_token": "</s>"})
|
||||
@@ -32,8 +36,51 @@ def get_tokenizer(conf):
|
||||
return tokenizer
|
||||
|
||||
|
||||
# placeholder for now
|
||||
def preprocess_qa(eval_pred):
    """Unpack ``eval_pred`` into ``(predictions, label_ids)`` without altering either."""
    predictions = eval_pred.predictions
    references = eval_pred.label_ids
    return predictions, references
|
||||
|
||||
|
||||
def postprocess_summarization(preds, labels):
    """Normalize decoded predictions and references for scoring.

    Each string is stripped of surrounding whitespace and re-joined with one
    sentence per line (sentences found by ``nltk.sent_tokenize``) -
    presumably the layout the summarization metric expects; confirm.

    Returns:
        ``(preds, labels)`` as two new lists of strings.
    """

    def _sentence_per_line(text):
        return "\n".join(nltk.sent_tokenize(text.strip()))

    cleaned_preds = [_sentence_per_line(pred) for pred in preds]
    cleaned_labels = [_sentence_per_line(label) for label in labels]
    return cleaned_preds, cleaned_labels
|
||||
|
||||
|
||||
def preprocess_summarization(eval_pred, tokenizer, ignore_pad_token_for_loss=True):
    """Decode token ids from ``eval_pred`` into cleaned prediction/reference strings.

    Args:
        eval_pred: pair ``(preds, labels)`` of token-id arrays.
        tokenizer: object providing ``batch_decode`` and ``pad_token_id``.
        ignore_pad_token_for_loss: when true, label entries equal to ``-100``
            are replaced with ``tokenizer.pad_token_id`` before decoding.

    Returns:
        ``(decoded_preds, decoded_labels)`` after ``postprocess_summarization``.
    """
    pred_ids, label_ids = eval_pred
    pred_texts = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)

    if ignore_pad_token_for_loss:
        # Swap the -100 placeholder for the pad id so decoding only sees
        # token ids the tokenizer knows.
        label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)
    label_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    pred_texts, label_texts = postprocess_summarization(pred_texts, label_texts)
    return pred_texts, label_texts
|
||||
|
||||
|
||||
def get_metrics(conf, tokenizer=None):
    """Select evaluation metrics and the matching preprocessing function.

    Args:
        conf: configuration exposing ``datasets`` (iterable of dataset names)
            and, for summarization, ``ignore_pad_token_for_loss``.
        tokenizer: tokenizer used to decode predictions; required only when a
            summarization dataset is selected.

    Returns:
        ``(metrics, preprocess_fn)`` where ``metrics`` is a list of loaded
        ``evaluate`` metrics and ``preprocess_fn`` maps an ``eval_pred`` to
        ``(predictions, references)``.

    Raises:
        ValueError: if no configured dataset belongs to a known task, or a
            summarization dataset is configured without a tokenizer.
    """
    qa_datasets = {
        "squad_v2",
        "adversarial_qa",
        "trivia_qa_context",
        "trivia_qa_nocontext",
        # historical misspelling, still accepted so existing configs keep working
        "trivia_qa_noconext",
    }
    summarization_datasets = {"xsum", "cnn_dailymail", "samsum", "multi_news", "scitldr", "billsum", "reddit"}

    # Metrics are returned as a list so more can be added per task later
    # for a more thorough evaluation.
    if any(dataset in qa_datasets for dataset in conf.datasets):
        metrics, preprocess_fn = [evaluate.load("squad_v2")], preprocess_qa
    elif any(dataset in summarization_datasets for dataset in conf.datasets):
        if tokenizer is None:
            raise ValueError("A tokenizer is required to compute summarization metrics")
        metrics = [evaluate.load("rouge")]
        # Bind the tokenizer by keyword: a positional bind would occupy the
        # `eval_pred` slot of preprocess_summarization.
        preprocess_fn = partial(
            preprocess_summarization,
            tokenizer=tokenizer,
            ignore_pad_token_for_loss=conf.ignore_pad_token_for_loss,
        )
    else:
        raise ValueError("Unknown dataset / task")
    return metrics, preprocess_fn
|
||||
|
||||
|
||||
def get_model(conf, tokenizer):
|
||||
model = get_specific_model(conf.model_name, conf.cache_dir, conf.quantization)
|
||||
model = get_specific_model(conf.model_name, conf.cache_dir, conf.quantization, conf.seq2seqmodel)
|
||||
|
||||
if len(tokenizer) != model.get_input_embeddings().num_embeddings:
|
||||
assert not conf.freeze_layer, "Cannot change the number of embeddings if the model is frozen."
|
||||
@@ -65,9 +112,11 @@ def get_dataset(conf, tokenizer):
|
||||
return train, evals, collate_fn
|
||||
|
||||
|
||||
def get_loss(loss, poly_eps=1.0):
    """Instantiate the loss module selected by name.

    Args:
        loss: ``"CrossEntropyLoss"`` or ``"Poly"``.
        poly_eps: epsilon passed to ``PolyLoss``; ignored for cross entropy.
            Defaults to 1.0 so single-argument callers keep working.

    Raises:
        ValueError: if ``loss`` is not a supported name.
    """
    if loss == "CrossEntropyLoss":
        return CrossEntropyLoss()
    if loss == "Poly":
        return PolyLoss(epsilon=poly_eps)
    raise ValueError(f"Loss {loss} not supported")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user