Merge pull request #989 from LAION-AI/sft-dataset-update

Add new SFT datasets
This commit is contained in:
sanagnos
2023-01-29 13:34:44 +01:00
committed by GitHub
3 changed files with 104 additions and 27 deletions
@@ -9,7 +9,19 @@ from custom_datasets.translation import WMT2019, DiveMT, TEDTalk
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
QA_DATASETS = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_nocontext", "gsm8k"]
QA_DATASETS = [
"squad_v2",
"adversarial_qa",
"trivia_qa_context",
"trivia_qa_nocontext",
"gsm8k",
"wikihow",
"essay_instruction",
"math_qa",
"reddit_eli5",
"reddit_askh",
"reddit_asks",
]
SUMMARIZATION_DATASETS = [
"xsum",
"cnn_dailymail",
@@ -35,16 +47,16 @@ def get_one_dataset(conf, dataset_name):
if dataset_name in QA_DATASETS:
train = QADataset(dataset_name, conf.cache_dir, "train")
val_name = "validation" if dataset_name not in ["gsm8k"] else "test"
eval = QADataset(dataset_name, conf.cache_dir, val_name)
if train.no_val:
train, eval = train_val_dataset(train, val_split=0.2)
else:
eval = QADataset(dataset_name, conf.cache_dir, "validation")
elif dataset_name in SUMMARIZATION_DATASETS:
train = SummarizationDataset(dataset_name, conf.cache_dir, "train")
if dataset_name == "debate_sum":
train, eval = train_val_dataset(train, val_split=0.2)
else:
val_name = "validation" if dataset_name not in ["billsum", "tldr_news"] else "test"
eval = SummarizationDataset(dataset_name, conf.cache_dir, val_name)
eval = SummarizationDataset(dataset_name, conf.cache_dir, "validation")
elif "ted_trans" in dataset_name:
language_pair = dataset_name.split("_")[-1]
dataset = TEDTalk(pair=language_pair, split="train")
@@ -49,26 +49,92 @@ def index_gsm8k(example):
return example["question"], example["answer"]
def index_wikihow(example):
return example["title"] + ", explain step by step", example["result"]
def index_essay_instruction(example):
return example["instructions"], example["titles"].strip() + "\n" + example["essays"]
def index_math_qa(example):
"""
we are not including choices, so no need to output the "answer : <a,b,c,d>" part
> if girls is 10 and boys is 20 , then 10 / 20 . so ratio of girls to boys is = 10 / 20 = 1 / 2 answer : a
"""
return example["Problem"], example["Rationale"].split("answer : ", maxsplit=1)[0]
def index_eli5(example):
return example["title"], example["answers"]["text"][0]
class QADataset(Dataset):
"""
How to define a new QA dataset:
Criteria : the qa dataset doesn't need fancy transform needed between fields rows or list
1. Write the transform function, which maps each row into a pair of (question, answer) tuple
2. Update DATASET_FORMAT_MAPPING with your dataset name and required parameter
- index_fn : your transform function
- name: the dataset name, this will be used when the name is different than huggingface load_dataset name
- params: if your dataset require a predefined name, create a dictionary with the parameter name-value dictionary
Feel free to create issues on GH for any suggestion how we can simplify this thing
"""
DATASET_FORMAT_MAPPING = {
"squad_v2": {"index_fn": index_squad_v2},
"trivia_qa_nocontext": {
"index_fn": index_trivia_qa_nocontext,
"name": "trivia_qa",
"params": {"name": "rc.nocontext"},
},
"trivia_qa_context": {"index_fn": index_trivia_qa_context, "name": "trivia_qa", "params": {"name": "rc"}},
"adversarial_qa": {
"index_fn": index_adversarial_qa,
"params": {"name": "adversarialQA"},
},
"gsm8k": {"index_fn": index_gsm8k, "params": {"name": "main"}, "validation": "test"},
"wikihow": {"name": "b-mc2/wikihow_lists", "index_fn": index_wikihow, "no_val": True},
"essay_instruction": {
"name": "ChristophSchuhmann/essays-with-instructions",
"index_fn": index_essay_instruction,
"no_val": True,
},
"math_qa": {
"index_fn": index_math_qa,
},
"reddit_eli5": {"name": "eli5", "index_fn": index_eli5, "split_postfix": "_eli5"},
"reddit_askh": {"name": "eli5", "index_fn": index_eli5, "split_postfix": "_askh"},
"reddit_asks": {"name": "eli5", "index_fn": index_eli5, "split_postfix": "_asks"},
}
def __init__(self, dataset, cache_dir, split):
if dataset == "squad_v2":
self.index_fn = index_squad_v2
self.dataset = load_dataset("squad_v2", cache_dir=cache_dir, split=split)
elif dataset == "trivia_qa_nocontext":
self.index_fn = index_trivia_qa_nocontext
self.dataset = load_dataset("trivia_qa", "rc.nocontext", split=split, cache_dir=cache_dir)
elif dataset == "trivia_qa_context":
self.index_fn = index_trivia_qa_context
self.dataset = load_dataset("trivia_qa", "rc", split=split, cache_dir=cache_dir)
elif dataset == "adversarial_qa":
self.index_fn = index_adversarial_qa
self.dataset = load_dataset("adversarial_qa", "adversarialQA", split=split, cache_dir=cache_dir)
elif dataset == "gsm8k":
self.index_fn = index_gsm8k
self.dataset = load_dataset("gsm8k", "main", split=split, cache_dir=cache_dir)
elif dataset == "adversarial_qa":
self.index_fn = index_adversarial_qa
self.dataset = load_dataset("adversarial_qa", "adversarialQA", split=split, cache_dir=cache_dir)
self.no_val = False
if dataset in self.DATASET_FORMAT_MAPPING:
context = self.DATASET_FORMAT_MAPPING[dataset]
if split == "validation" and "validation" in context:
split = context["validation"]
if "name" not in context:
context["name"] = dataset
if "split_postfix" in context:
# append a postfix to split name, used in eli5 : test_eli5, test_asks, test_askh
split += context["split_postfix"]
if "params" not in context:
context["params"] = {"cache_dir": cache_dir, "split": split}
else:
context["params"]["cache_dir"] = cache_dir
context["params"]["split"] = split
if "no_val" in context:
self.no_val = True
self.index_fn = context["index_fn"]
self.dataset = load_dataset(context["name"], **context["params"])
else:
raise ValueError("Unknown dataset : " + dataset)
@@ -259,6 +325,3 @@ class JokeExplaination(Dataset):
def __getitem__(self, index):
return format_pair(self.pairs[index])
# https://huggingface.co/datasets/aquamuse
@@ -57,6 +57,8 @@ def index_summary_merge(text, summary):
class SummarizationDataset(Dataset):
def __init__(self, dataset, cache_dir, split, max_words=512):
self.name = dataset
if summarization_config_mapping[dataset][0] in ["billsum", "tldr_news"] & split == "validation":
split = "test"
self.dataset = load_dataset(*summarization_config_mapping[dataset], cache_dir=cache_dir, split=split)
self.text_column, self.summary_column = summarization_name_mapping[dataset]
self.preprocess_fn = index_summary_merge if dataset == "scitldr" else index_summary_default