From 23ee2f24d9091b1e30b4bdb5867d3e989d1123e7 Mon Sep 17 00:00:00 2001 From: Sotirios Anagnostidis Date: Sat, 11 Feb 2023 11:02:48 +0100 Subject: [PATCH] dataset args --- .../custom_datasets/__init__.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/model/supervised_finetuning/custom_datasets/__init__.py b/model/supervised_finetuning/custom_datasets/__init__.py index 0e5b9a91..440cb62b 100644 --- a/model/supervised_finetuning/custom_datasets/__init__.py +++ b/model/supervised_finetuning/custom_datasets/__init__.py @@ -42,38 +42,38 @@ def train_val_dataset(dataset, val_split=0.2): return Subset(dataset, train_idx), Subset(dataset, val_idx) -def get_one_dataset(conf, dataset_name): +def get_one_dataset(conf, dataset_name, val_split=0.2, **kwargs): dataset_name = dataset_name.lower() if dataset_name in QA_DATASETS: train = QADataset(dataset_name, conf.cache_dir, "train") if train.no_val: - train, eval = train_val_dataset(train, val_split=0.2) + train, eval = train_val_dataset(train, val_split=val_split, **kwargs) else: eval = QADataset(dataset_name, conf.cache_dir, "validation") elif dataset_name in SUMMARIZATION_DATASETS: train = SummarizationDataset(dataset_name, conf.cache_dir, "train") if dataset_name == "debate_sum": - train, eval = train_val_dataset(train, val_split=0.2) + train, eval = train_val_dataset(train, val_split=val_split, **kwargs) else: eval = SummarizationDataset(dataset_name, conf.cache_dir, "validation") elif "ted_trans" in dataset_name: language_pair = dataset_name.split("_")[-1] dataset = TEDTalk(pair=language_pair, split="train") - train, eval = train_val_dataset(dataset, val_split=0.2) + train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs) elif "wmt2019" in dataset_name: language_pair = dataset_name.split("_")[-1] train = WMT2019(pair=language_pair, split="train") eval = WMT2019(pair=language_pair, split="validation") elif dataset_name == "dive_mt": dataset = DiveMT() - train, eval = train_val_dataset(dataset, val_split=0.2) + train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs) elif dataset_name == "webgpt": dataset = WebGPT() - train, eval = train_val_dataset(dataset, val_split=0.2) + train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs) elif dataset_name == "prompt_dialogue": dataset = PromptGeneratedDataset(conf.cache_dir) - train, eval = train_val_dataset(dataset, val_split=0.2) + train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs) elif dataset_name == "prosocial_dialogue": train = ProsocialDialogue(cache_dir=conf.cache_dir, split="train") eval = ProsocialDialogue(cache_dir=conf.cache_dir, split="validation") @@ -82,22 +82,22 @@ def get_one_dataset(conf, dataset_name): eval = ProsocialDialogueExplaination(cache_dir=conf.cache_dir, split="validation") elif dataset_name == "soda": dataset = SODA(conf.cache_dir) - train, eval = train_val_dataset(dataset, val_split=0.1) + train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs) elif dataset_name == "soda_dialogue": dataset = SODADialogue(conf.cache_dir) - train, eval = train_val_dataset(dataset, val_split=0.1) + train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs) elif dataset_name == "joke": dataset = JokeExplaination(conf.cache_dir) - train, eval = train_val_dataset(dataset, val_split=0.2) + train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs) elif dataset_name == "instruct_tuning": dataset = InstructionTuning(conf.cache_dir) - train, eval = train_val_dataset(dataset, val_split=0.2) + train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs) elif dataset_name == "private_tuning": dataset = PrivateInstructionTuning(conf.cache_dir) - train, eval = train_val_dataset(dataset, val_split=0.2) + train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs) elif dataset_name == "oa_translated": dataset = TranslatedQA(conf.cache_dir) - train, eval = train_val_dataset(dataset, val_split=0.01) + train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs) # TODO make val split lower..? else: raise ValueError(f"Unknown dataset {dataset_name}")