mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
dataset args
This commit is contained in:
@@ -42,38 +42,38 @@ def train_val_dataset(dataset, val_split=0.2):
|
||||
return Subset(dataset, train_idx), Subset(dataset, val_idx)
|
||||
|
||||
|
||||
def get_one_dataset(conf, dataset_name):
|
||||
def get_one_dataset(conf, dataset_name, val_split=0.2, **kwargs):
|
||||
dataset_name = dataset_name.lower()
|
||||
|
||||
if dataset_name in QA_DATASETS:
|
||||
train = QADataset(dataset_name, conf.cache_dir, "train")
|
||||
if train.no_val:
|
||||
train, eval = train_val_dataset(train, val_split=0.2)
|
||||
train, eval = train_val_dataset(train, val_split=val_split, **kwargs)
|
||||
else:
|
||||
eval = QADataset(dataset_name, conf.cache_dir, "validation")
|
||||
elif dataset_name in SUMMARIZATION_DATASETS:
|
||||
train = SummarizationDataset(dataset_name, conf.cache_dir, "train")
|
||||
if dataset_name == "debate_sum":
|
||||
train, eval = train_val_dataset(train, val_split=0.2)
|
||||
train, eval = train_val_dataset(train, val_split=val_split, **kwargs)
|
||||
else:
|
||||
eval = SummarizationDataset(dataset_name, conf.cache_dir, "validation")
|
||||
elif "ted_trans" in dataset_name:
|
||||
language_pair = dataset_name.split("_")[-1]
|
||||
dataset = TEDTalk(pair=language_pair, split="train")
|
||||
train, eval = train_val_dataset(dataset, val_split=0.2)
|
||||
train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs)
|
||||
elif "wmt2019" in dataset_name:
|
||||
language_pair = dataset_name.split("_")[-1]
|
||||
train = WMT2019(pair=language_pair, split="train")
|
||||
eval = WMT2019(pair=language_pair, split="validation")
|
||||
elif dataset_name == "dive_mt":
|
||||
dataset = DiveMT()
|
||||
train, eval = train_val_dataset(dataset, val_split=0.2)
|
||||
train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs)
|
||||
elif dataset_name == "webgpt":
|
||||
dataset = WebGPT()
|
||||
train, eval = train_val_dataset(dataset, val_split=0.2)
|
||||
train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs)
|
||||
elif dataset_name == "prompt_dialogue":
|
||||
dataset = PromptGeneratedDataset(conf.cache_dir)
|
||||
train, eval = train_val_dataset(dataset, val_split=0.2)
|
||||
train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs)
|
||||
elif dataset_name == "prosocial_dialogue":
|
||||
train = ProsocialDialogue(cache_dir=conf.cache_dir, split="train")
|
||||
eval = ProsocialDialogue(cache_dir=conf.cache_dir, split="validation")
|
||||
@@ -82,22 +82,22 @@ def get_one_dataset(conf, dataset_name):
|
||||
eval = ProsocialDialogueExplaination(cache_dir=conf.cache_dir, split="validation")
|
||||
elif dataset_name == "soda":
|
||||
dataset = SODA(conf.cache_dir)
|
||||
train, eval = train_val_dataset(dataset, val_split=0.1)
|
||||
train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs)
|
||||
elif dataset_name == "soda_dialogue":
|
||||
dataset = SODADialogue(conf.cache_dir)
|
||||
train, eval = train_val_dataset(dataset, val_split=0.1)
|
||||
train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs)
|
||||
elif dataset_name == "joke":
|
||||
dataset = JokeExplaination(conf.cache_dir)
|
||||
train, eval = train_val_dataset(dataset, val_split=0.2)
|
||||
train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs)
|
||||
elif dataset_name == "instruct_tuning":
|
||||
dataset = InstructionTuning(conf.cache_dir)
|
||||
train, eval = train_val_dataset(dataset, val_split=0.2)
|
||||
train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs)
|
||||
elif dataset_name == "private_tuning":
|
||||
dataset = PrivateInstructionTuning(conf.cache_dir)
|
||||
train, eval = train_val_dataset(dataset, val_split=0.2)
|
||||
train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs)
|
||||
elif dataset_name == "oa_translated":
|
||||
dataset = TranslatedQA(conf.cache_dir)
|
||||
train, eval = train_val_dataset(dataset, val_split=0.01)
|
||||
train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs) # TODO make val split lower..?
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset {dataset_name}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user