dataset args

This commit is contained in:
Sotirios Anagnostidis
2023-02-11 11:02:48 +01:00
parent 714677b5d8
commit 23ee2f24d9
@@ -42,38 +42,38 @@ def train_val_dataset(dataset, val_split=0.2):
return Subset(dataset, train_idx), Subset(dataset, val_idx)
def get_one_dataset(conf, dataset_name):
def get_one_dataset(conf, dataset_name, val_split=0.2, **kwargs):
dataset_name = dataset_name.lower()
if dataset_name in QA_DATASETS:
train = QADataset(dataset_name, conf.cache_dir, "train")
if train.no_val:
train, eval = train_val_dataset(train, val_split=0.2)
train, eval = train_val_dataset(train, val_split=val_split, **kwargs)
else:
eval = QADataset(dataset_name, conf.cache_dir, "validation")
elif dataset_name in SUMMARIZATION_DATASETS:
train = SummarizationDataset(dataset_name, conf.cache_dir, "train")
if dataset_name == "debate_sum":
train, eval = train_val_dataset(train, val_split=0.2)
train, eval = train_val_dataset(train, val_split=val_split, **kwargs)
else:
eval = SummarizationDataset(dataset_name, conf.cache_dir, "validation")
elif "ted_trans" in dataset_name:
language_pair = dataset_name.split("_")[-1]
dataset = TEDTalk(pair=language_pair, split="train")
train, eval = train_val_dataset(dataset, val_split=0.2)
train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs)
elif "wmt2019" in dataset_name:
language_pair = dataset_name.split("_")[-1]
train = WMT2019(pair=language_pair, split="train")
eval = WMT2019(pair=language_pair, split="validation")
elif dataset_name == "dive_mt":
dataset = DiveMT()
train, eval = train_val_dataset(dataset, val_split=0.2)
train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs)
elif dataset_name == "webgpt":
dataset = WebGPT()
train, eval = train_val_dataset(dataset, val_split=0.2)
train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs)
elif dataset_name == "prompt_dialogue":
dataset = PromptGeneratedDataset(conf.cache_dir)
train, eval = train_val_dataset(dataset, val_split=0.2)
train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs)
elif dataset_name == "prosocial_dialogue":
train = ProsocialDialogue(cache_dir=conf.cache_dir, split="train")
eval = ProsocialDialogue(cache_dir=conf.cache_dir, split="validation")
@@ -82,22 +82,22 @@ def get_one_dataset(conf, dataset_name):
eval = ProsocialDialogueExplaination(cache_dir=conf.cache_dir, split="validation")
elif dataset_name == "soda":
dataset = SODA(conf.cache_dir)
train, eval = train_val_dataset(dataset, val_split=0.1)
train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs)
elif dataset_name == "soda_dialogue":
dataset = SODADialogue(conf.cache_dir)
train, eval = train_val_dataset(dataset, val_split=0.1)
train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs)
elif dataset_name == "joke":
dataset = JokeExplaination(conf.cache_dir)
train, eval = train_val_dataset(dataset, val_split=0.2)
train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs)
elif dataset_name == "instruct_tuning":
dataset = InstructionTuning(conf.cache_dir)
train, eval = train_val_dataset(dataset, val_split=0.2)
train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs)
elif dataset_name == "private_tuning":
dataset = PrivateInstructionTuning(conf.cache_dir)
train, eval = train_val_dataset(dataset, val_split=0.2)
train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs)
elif dataset_name == "oa_translated":
dataset = TranslatedQA(conf.cache_dir)
train, eval = train_val_dataset(dataset, val_split=0.01)
train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs) # TODO: use a lower val split for this dataset (it was hard-coded to 0.01 before val_split became an argument)
else:
raise ValueError(f"Unknown dataset {dataset_name}")