diff --git a/model/supervised_finetuning/custom_datasets/__init__.py b/model/supervised_finetuning/custom_datasets/__init__.py index 48641225..0e5b9a91 100644 --- a/model/supervised_finetuning/custom_datasets/__init__.py +++ b/model/supervised_finetuning/custom_datasets/__init__.py @@ -32,7 +32,7 @@ SUMMARIZATION_DATASETS = [ "debate_sum", "tldr_news", ] -OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning", "private_tuning"] +OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning", "private_tuning", "oa_translated"] def train_val_dataset(dataset, val_split=0.2): @@ -95,7 +95,7 @@ def get_one_dataset(conf, dataset_name): elif dataset_name == "private_tuning": dataset = PrivateInstructionTuning(conf.cache_dir) train, eval = train_val_dataset(dataset, val_split=0.2) - elif dataset_name == "translate_qa": + elif dataset_name == "oa_translated": dataset = TranslatedQA(conf.cache_dir) train, eval = train_val_dataset(dataset, val_split=0.01) else: diff --git a/model/supervised_finetuning/custom_datasets/qa_datasets.py b/model/supervised_finetuning/custom_datasets/qa_datasets.py index 7876a920..145aee72 100644 --- a/model/supervised_finetuning/custom_datasets/qa_datasets.py +++ b/model/supervised_finetuning/custom_datasets/qa_datasets.py @@ -349,8 +349,21 @@ class TranslatedQA(Dataset): if "Python " in data["text"]: continue # incorrect, TODO: fix later + prefix = "" for convo_round in data["translate"]: - self.pairs.append((convo_round["human"], convo_round["answer"])) + human, answer = format_pair((convo_round["human"], convo_round["answer"])) + if convo_round["round"] > 2: + # TODO: remove this later + self.pairs.append(("{}{}{}".format(prefix, "", human), answer)) + else: + self.pairs.append((human, answer)) + + prefix += "{}{}{}{}".format( + QA_SPECIAL_TOKENS["Question"], + convo_round["human"], + QA_SPECIAL_TOKENS["Answer"], + convo_round["answer"], + ) self.length = len(self.pairs) @@ -358,4 +371,4 @@ class TranslatedQA(Dataset): return self.length def __getitem__(self, index): - return format_pair(self.pairs[index]) + return self.pairs[index]