diff --git a/model/supervised_finetuning/custom_datasets/__init__.py b/model/supervised_finetuning/custom_datasets/__init__.py index 558ec502..ee061f04 100644 --- a/model/supervised_finetuning/custom_datasets/__init__.py +++ b/model/supervised_finetuning/custom_datasets/__init__.py @@ -9,7 +9,19 @@ from custom_datasets.translation import WMT2019, DiveMT, TEDTalk from sklearn.model_selection import train_test_split from torch.utils.data import Subset -QA_DATASETS = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_nocontext", "gsm8k"] +QA_DATASETS = [ + "squad_v2", + "adversarial_qa", + "trivia_qa_context", + "trivia_qa_nocontext", + "gsm8k", + "wikihow", + "essay_instruction", + "math_qa", + "reddit_eli5", + "reddit_askh", + "reddit_asks", +] SUMMARIZATION_DATASETS = [ "xsum", "cnn_dailymail", @@ -35,16 +47,16 @@ def get_one_dataset(conf, dataset_name): if dataset_name in QA_DATASETS: train = QADataset(dataset_name, conf.cache_dir, "train") - val_name = "validation" if dataset_name not in ["gsm8k"] else "test" - eval = QADataset(dataset_name, conf.cache_dir, val_name) - + if train.no_val: + train, eval = train_val_dataset(train, val_split=0.2) + else: + eval = QADataset(dataset_name, conf.cache_dir, "validation") elif dataset_name in SUMMARIZATION_DATASETS: train = SummarizationDataset(dataset_name, conf.cache_dir, "train") if dataset_name == "debate_sum": train, eval = train_val_dataset(train, val_split=0.2) else: - val_name = "validation" if dataset_name not in ["billsum", "tldr_news"] else "test" - eval = SummarizationDataset(dataset_name, conf.cache_dir, val_name) + eval = SummarizationDataset(dataset_name, conf.cache_dir, "validation") elif "ted_trans" in dataset_name: language_pair = dataset_name.split("_")[-1] dataset = TEDTalk(pair=language_pair, split="train") diff --git a/model/supervised_finetuning/custom_datasets/qa_datasets.py b/model/supervised_finetuning/custom_datasets/qa_datasets.py index 47b1c247..2c5c7ee2 100644 --- a/model/supervised_finetuning/custom_datasets/qa_datasets.py +++ b/model/supervised_finetuning/custom_datasets/qa_datasets.py @@ -49,26 +49,92 @@ def index_gsm8k(example): return example["question"], example["answer"] +def index_wikihow(example): + return example["title"] + ", explain step by step", example["result"] + + +def index_essay_instruction(example): + return example["instructions"], example["titles"].strip() + "\n" + example["essays"] + + +def index_math_qa(example): + """ + we are not including choices, so no need to output the "answer : " part + > if girls is 10 and boys is 20 , then 10 / 20 . so ratio of girls to boys is = 10 / 20 = 1 / 2 answer : a + """ + return example["Problem"], example["Rationale"].split("answer : ", maxsplit=1)[0] + + +def index_eli5(example): + return example["title"], example["answers"]["text"][0] + + class QADataset(Dataset): + """ + How to define a new QA dataset: + + Criteria : the qa dataset doesn't need fancy transform needed between fields rows or list + + 1. Write the transform function, which maps each row into a pair of (question, answer) tuple + + 2. Update DATASET_FORMAT_MAPPING with your dataset name and required parameter + + - index_fn : your transform function + + - name: the dataset name, this will be used when the name is different than huggingface load_dataset name + + - params: if your dataset require a predefined name, create a dictionary with the parameter name-value dictionary + + Feel free to create issues on GH for any suggestion how we can simplify this thing + """ + + DATASET_FORMAT_MAPPING = { + "squad_v2": {"index_fn": index_squad_v2}, + "trivia_qa_nocontext": { + "index_fn": index_trivia_qa_nocontext, + "name": "trivia_qa", + "params": {"name": "rc.nocontext"}, + }, + "trivia_qa_context": {"index_fn": index_trivia_qa_context, "name": "trivia_qa", "params": {"name": "rc"}}, + "adversarial_qa": { + "index_fn": index_adversarial_qa, + "params": {"name": "adversarialQA"}, + }, + "gsm8k": {"index_fn": index_gsm8k, "params": {"name": "main"}, "validation": "test"}, + "wikihow": {"name": "b-mc2/wikihow_lists", "index_fn": index_wikihow, "no_val": True}, + "essay_instruction": { + "name": "ChristophSchuhmann/essays-with-instructions", + "index_fn": index_essay_instruction, + "no_val": True, + }, + "math_qa": { + "index_fn": index_math_qa, + }, + "reddit_eli5": {"name": "eli5", "index_fn": index_eli5, "split_postfix": "_eli5"}, + "reddit_askh": {"name": "eli5", "index_fn": index_eli5, "split_postfix": "_askh"}, + "reddit_asks": {"name": "eli5", "index_fn": index_eli5, "split_postfix": "_asks"}, + } + def __init__(self, dataset, cache_dir, split): - if dataset == "squad_v2": - self.index_fn = index_squad_v2 - self.dataset = load_dataset("squad_v2", cache_dir=cache_dir, split=split) - elif dataset == "trivia_qa_nocontext": - self.index_fn = index_trivia_qa_nocontext - self.dataset = load_dataset("trivia_qa", "rc.nocontext", split=split, cache_dir=cache_dir) - elif dataset == "trivia_qa_context": - self.index_fn = index_trivia_qa_context - self.dataset = load_dataset("trivia_qa", "rc", split=split, cache_dir=cache_dir) - elif dataset == "adversarial_qa": - self.index_fn = index_adversarial_qa - self.dataset = load_dataset("adversarial_qa", "adversarialQA", split=split, cache_dir=cache_dir) - elif dataset == "gsm8k": - self.index_fn = index_gsm8k - self.dataset = load_dataset("gsm8k", "main", split=split, cache_dir=cache_dir) - elif dataset == "adversarial_qa": - self.index_fn = index_adversarial_qa - self.dataset = load_dataset("adversarial_qa", "adversarialQA", split=split, cache_dir=cache_dir) + self.no_val = False + if dataset in self.DATASET_FORMAT_MAPPING: + context = self.DATASET_FORMAT_MAPPING[dataset] + if split == "validation" and "validation" in context: + split = context["validation"] + if "name" not in context: + context["name"] = dataset + if "split_postfix" in context: + # append a postfix to split name, used in eli5 : test_eli5, test_asks, test_askh + split += context["split_postfix"] + if "params" not in context: + context["params"] = {"cache_dir": cache_dir, "split": split} + else: + context["params"]["cache_dir"] = cache_dir + context["params"]["split"] = split + if "no_val" in context: + self.no_val = True + self.index_fn = context["index_fn"] + self.dataset = load_dataset(context["name"], **context["params"]) else: raise ValueError("Unknown dataset : " + dataset) @@ -259,6 +325,3 @@ class JokeExplaination(Dataset): def __getitem__(self, index): return format_pair(self.pairs[index]) - - -# https://huggingface.co/datasets/aquamuse diff --git a/model/supervised_finetuning/custom_datasets/summarization.py b/model/supervised_finetuning/custom_datasets/summarization.py index 85d21a27..834fa16c 100644 --- a/model/supervised_finetuning/custom_datasets/summarization.py +++ b/model/supervised_finetuning/custom_datasets/summarization.py @@ -57,6 +57,8 @@ def index_summary_merge(text, summary): class SummarizationDataset(Dataset): def __init__(self, dataset, cache_dir, split, max_words=512): self.name = dataset + if summarization_config_mapping[dataset][0] in ["billsum", "tldr_news"] & split == "validation": + split = "test" self.dataset = load_dataset(*summarization_config_mapping[dataset], cache_dir=cache_dir, split=split) self.text_column, self.summary_column = summarization_name_mapping[dataset] self.preprocess_fn = index_summary_merge if dataset == "scitldr" else index_summary_default