""" High level functions for model training """ from custom_datasets.prompt_dialogue import ( InstructionTuning, OAPrivate, PrivateInstructionTuning, PromptGeneratedDataset, ) from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, SODADialogue, TranslatedQA, WebGPT from custom_datasets.summarization import SummarizationDataset from custom_datasets.toxic_conversation import ProsocialDialogue, ProsocialDialogueExplaination from custom_datasets.translation import WMT2019, DiveMT, TEDTalk from sklearn.model_selection import train_test_split from torch.utils.data import Subset QA_DATASETS = [ "squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_nocontext", "gsm8k", "wikihow", "essay_instruction", "math_qa", "reddit_eli5", "reddit_askh", "reddit_asks", ] SUMMARIZATION_DATASETS = [ "xsum", "cnn_dailymail", "samsum", "multi_news", "scitldr", "billsum", "debate_sum", "tldr_news", ] OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning", "private_tuning", "oa_translated", "oa_private"] def train_val_dataset(dataset, val_split=0.2): if val_split == 0: return dataset, None train_idx, val_idx = train_test_split( list(range(len(dataset))), test_size=val_split, random_state=666, shuffle=True ) return Subset(dataset, train_idx), Subset(dataset, val_idx) def get_one_dataset(conf, dataset_name, val_split=0.2, data_path=None, **kwargs): data_path = data_path or conf.cache_dir dataset_name = dataset_name.lower() if dataset_name in QA_DATASETS: train = QADataset(dataset_name, data_path, "train") if not train.no_val: eval = QADataset(dataset_name, data_path, "validation") elif dataset_name in SUMMARIZATION_DATASETS: train = SummarizationDataset(dataset_name, data_path, "train") if dataset_name != "debate_sum": eval = SummarizationDataset(dataset_name, data_path, "validation") elif "ted_trans" in dataset_name: language_pair = dataset_name.split("_")[-1] dataset = TEDTalk(pair=language_pair, split="train") elif "wmt2019" in dataset_name: language_pair = dataset_name.split("_")[-1] train = WMT2019(pair=language_pair, split="train") eval = WMT2019(pair=language_pair, split="validation") elif dataset_name == "dive_mt": dataset = DiveMT() elif dataset_name == "webgpt": dataset = WebGPT() elif dataset_name == "prompt_dialogue": dataset = PromptGeneratedDataset(data_path) elif dataset_name == "prosocial_dialogue": train = ProsocialDialogue(cache_dir=data_path, split="train") eval = ProsocialDialogue(cache_dir=data_path, split="validation") elif dataset_name == "explain_prosocial": train = ProsocialDialogueExplaination(cache_dir=data_path, split="train") eval = ProsocialDialogueExplaination(cache_dir=data_path, split="validation") elif dataset_name == "soda": dataset = SODA(data_path) elif dataset_name == "soda_dialogue": dataset = SODADialogue(data_path) elif dataset_name == "joke": dataset = JokeExplaination(data_path) elif dataset_name == "instruct_tuning": dataset = InstructionTuning(data_path) elif dataset_name == "private_tuning": dataset = PrivateInstructionTuning(data_path) elif dataset_name == "oa_translated": dataset = TranslatedQA(data_path) # TODO make val_split lower..? elif dataset_name == "oa_private": dataset = OAPrivate(data_path, **kwargs) else: raise ValueError(f"Unknown dataset {dataset_name}") # if eval not already defined if "dataset" in locals(): train, eval = train_val_dataset(dataset, val_split=val_split) return train, eval