mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Merge pull request #989 from LAION-AI/sft-dataset-update
Add new SFT datasets
This commit is contained in:
@@ -9,7 +9,19 @@ from custom_datasets.translation import WMT2019, DiveMT, TEDTalk
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import Subset
|
||||
|
||||
QA_DATASETS = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_nocontext", "gsm8k"]
|
||||
QA_DATASETS = [
|
||||
"squad_v2",
|
||||
"adversarial_qa",
|
||||
"trivia_qa_context",
|
||||
"trivia_qa_nocontext",
|
||||
"gsm8k",
|
||||
"wikihow",
|
||||
"essay_instruction",
|
||||
"math_qa",
|
||||
"reddit_eli5",
|
||||
"reddit_askh",
|
||||
"reddit_asks",
|
||||
]
|
||||
SUMMARIZATION_DATASETS = [
|
||||
"xsum",
|
||||
"cnn_dailymail",
|
||||
@@ -35,16 +47,16 @@ def get_one_dataset(conf, dataset_name):
|
||||
|
||||
if dataset_name in QA_DATASETS:
|
||||
train = QADataset(dataset_name, conf.cache_dir, "train")
|
||||
val_name = "validation" if dataset_name not in ["gsm8k"] else "test"
|
||||
eval = QADataset(dataset_name, conf.cache_dir, val_name)
|
||||
|
||||
if train.no_val:
|
||||
train, eval = train_val_dataset(train, val_split=0.2)
|
||||
else:
|
||||
eval = QADataset(dataset_name, conf.cache_dir, "validation")
|
||||
elif dataset_name in SUMMARIZATION_DATASETS:
|
||||
train = SummarizationDataset(dataset_name, conf.cache_dir, "train")
|
||||
if dataset_name == "debate_sum":
|
||||
train, eval = train_val_dataset(train, val_split=0.2)
|
||||
else:
|
||||
val_name = "validation" if dataset_name not in ["billsum", "tldr_news"] else "test"
|
||||
eval = SummarizationDataset(dataset_name, conf.cache_dir, val_name)
|
||||
eval = SummarizationDataset(dataset_name, conf.cache_dir, "validation")
|
||||
elif "ted_trans" in dataset_name:
|
||||
language_pair = dataset_name.split("_")[-1]
|
||||
dataset = TEDTalk(pair=language_pair, split="train")
|
||||
|
||||
@@ -49,26 +49,92 @@ def index_gsm8k(example):
|
||||
return example["question"], example["answer"]
|
||||
|
||||
|
||||
def index_wikihow(example):
|
||||
return example["title"] + ", explain step by step", example["result"]
|
||||
|
||||
|
||||
def index_essay_instruction(example):
|
||||
return example["instructions"], example["titles"].strip() + "\n" + example["essays"]
|
||||
|
||||
|
||||
def index_math_qa(example):
|
||||
"""
|
||||
we are not including choices, so no need to output the "answer : <a,b,c,d>" part
|
||||
> if girls is 10 and boys is 20 , then 10 / 20 . so ratio of girls to boys is = 10 / 20 = 1 / 2 answer : a
|
||||
"""
|
||||
return example["Problem"], example["Rationale"].split("answer : ", maxsplit=1)[0]
|
||||
|
||||
|
||||
def index_eli5(example):
|
||||
return example["title"], example["answers"]["text"][0]
|
||||
|
||||
|
||||
class QADataset(Dataset):
|
||||
"""
|
||||
How to define a new QA dataset:
|
||||
|
||||
Criteria : the qa dataset doesn't need fancy transform needed between fields rows or list
|
||||
|
||||
1. Write the transform function, which maps each row into a pair of (question, answer) tuple
|
||||
|
||||
2. Update DATASET_FORMAT_MAPPING with your dataset name and required parameter
|
||||
|
||||
- index_fn : your transform function
|
||||
|
||||
- name: the dataset name, this will be used when the name is different than huggingface load_dataset name
|
||||
|
||||
- params: if your dataset require a predefined name, create a dictionary with the parameter name-value dictionary
|
||||
|
||||
Feel free to create issues on GH for any suggestion how we can simplify this thing
|
||||
"""
|
||||
|
||||
DATASET_FORMAT_MAPPING = {
|
||||
"squad_v2": {"index_fn": index_squad_v2},
|
||||
"trivia_qa_nocontext": {
|
||||
"index_fn": index_trivia_qa_nocontext,
|
||||
"name": "trivia_qa",
|
||||
"params": {"name": "rc.nocontext"},
|
||||
},
|
||||
"trivia_qa_context": {"index_fn": index_trivia_qa_context, "name": "trivia_qa", "params": {"name": "rc"}},
|
||||
"adversarial_qa": {
|
||||
"index_fn": index_adversarial_qa,
|
||||
"params": {"name": "adversarialQA"},
|
||||
},
|
||||
"gsm8k": {"index_fn": index_gsm8k, "params": {"name": "main"}, "validation": "test"},
|
||||
"wikihow": {"name": "b-mc2/wikihow_lists", "index_fn": index_wikihow, "no_val": True},
|
||||
"essay_instruction": {
|
||||
"name": "ChristophSchuhmann/essays-with-instructions",
|
||||
"index_fn": index_essay_instruction,
|
||||
"no_val": True,
|
||||
},
|
||||
"math_qa": {
|
||||
"index_fn": index_math_qa,
|
||||
},
|
||||
"reddit_eli5": {"name": "eli5", "index_fn": index_eli5, "split_postfix": "_eli5"},
|
||||
"reddit_askh": {"name": "eli5", "index_fn": index_eli5, "split_postfix": "_askh"},
|
||||
"reddit_asks": {"name": "eli5", "index_fn": index_eli5, "split_postfix": "_asks"},
|
||||
}
|
||||
|
||||
def __init__(self, dataset, cache_dir, split):
|
||||
if dataset == "squad_v2":
|
||||
self.index_fn = index_squad_v2
|
||||
self.dataset = load_dataset("squad_v2", cache_dir=cache_dir, split=split)
|
||||
elif dataset == "trivia_qa_nocontext":
|
||||
self.index_fn = index_trivia_qa_nocontext
|
||||
self.dataset = load_dataset("trivia_qa", "rc.nocontext", split=split, cache_dir=cache_dir)
|
||||
elif dataset == "trivia_qa_context":
|
||||
self.index_fn = index_trivia_qa_context
|
||||
self.dataset = load_dataset("trivia_qa", "rc", split=split, cache_dir=cache_dir)
|
||||
elif dataset == "adversarial_qa":
|
||||
self.index_fn = index_adversarial_qa
|
||||
self.dataset = load_dataset("adversarial_qa", "adversarialQA", split=split, cache_dir=cache_dir)
|
||||
elif dataset == "gsm8k":
|
||||
self.index_fn = index_gsm8k
|
||||
self.dataset = load_dataset("gsm8k", "main", split=split, cache_dir=cache_dir)
|
||||
elif dataset == "adversarial_qa":
|
||||
self.index_fn = index_adversarial_qa
|
||||
self.dataset = load_dataset("adversarial_qa", "adversarialQA", split=split, cache_dir=cache_dir)
|
||||
self.no_val = False
|
||||
if dataset in self.DATASET_FORMAT_MAPPING:
|
||||
context = self.DATASET_FORMAT_MAPPING[dataset]
|
||||
if split == "validation" and "validation" in context:
|
||||
split = context["validation"]
|
||||
if "name" not in context:
|
||||
context["name"] = dataset
|
||||
if "split_postfix" in context:
|
||||
# append a postfix to split name, used in eli5 : test_eli5, test_asks, test_askh
|
||||
split += context["split_postfix"]
|
||||
if "params" not in context:
|
||||
context["params"] = {"cache_dir": cache_dir, "split": split}
|
||||
else:
|
||||
context["params"]["cache_dir"] = cache_dir
|
||||
context["params"]["split"] = split
|
||||
if "no_val" in context:
|
||||
self.no_val = True
|
||||
self.index_fn = context["index_fn"]
|
||||
self.dataset = load_dataset(context["name"], **context["params"])
|
||||
else:
|
||||
raise ValueError("Unknown dataset : " + dataset)
|
||||
|
||||
@@ -259,6 +325,3 @@ class JokeExplaination(Dataset):
|
||||
|
||||
def __getitem__(self, index):
|
||||
return format_pair(self.pairs[index])
|
||||
|
||||
|
||||
# https://huggingface.co/datasets/aquamuse
|
||||
|
||||
@@ -57,6 +57,8 @@ def index_summary_merge(text, summary):
|
||||
class SummarizationDataset(Dataset):
|
||||
def __init__(self, dataset, cache_dir, split, max_words=512):
|
||||
self.name = dataset
|
||||
if summarization_config_mapping[dataset][0] in ["billsum", "tldr_news"] & split == "validation":
|
||||
split = "test"
|
||||
self.dataset = load_dataset(*summarization_config_mapping[dataset], cache_dir=cache_dir, split=split)
|
||||
self.text_column, self.summary_column = summarization_name_mapping[dataset]
|
||||
self.preprocess_fn = index_summary_merge if dataset == "scitldr" else index_summary_default
|
||||
|
||||
Reference in New Issue
Block a user