diff --git a/model/supervised_finetuning/custom_datasets/__init__.py b/model/supervised_finetuning/custom_datasets/__init__.py index 73fc6d89..def468e1 100644 --- a/model/supervised_finetuning/custom_datasets/__init__.py +++ b/model/supervised_finetuning/custom_datasets/__init__.py @@ -1,7 +1,12 @@ """ High level functions for model training """ -from custom_datasets.prompt_dialogue import InstructionTuning, PrivateInstructionTuning, PromptGeneratedDataset +from custom_datasets.prompt_dialogue import ( + InstructionTuning, + OAPrivate, + PrivateInstructionTuning, + PromptGeneratedDataset, +) from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, SODADialogue, TranslatedQA, WebGPT from custom_datasets.summarization import SummarizationDataset from custom_datasets.toxic_conversation import ProsocialDialogue, ProsocialDialogueExplaination @@ -36,6 +41,9 @@ OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning", "private_ def train_val_dataset(dataset, val_split=0.2): + if val_split == 0: + return dataset, None + train_idx, val_idx = train_test_split( list(range(len(dataset))), test_size=val_split, random_state=666, shuffle=True ) @@ -85,11 +93,13 @@ def get_one_dataset(conf, dataset_name, val_split=0.2, data_path=None, **kwargs) dataset = PrivateInstructionTuning(data_path) elif dataset_name == "oa_translated": dataset = TranslatedQA(data_path) # TODO make val_split lower..? + elif dataset_name == "oa_private": + dataset = OAPrivate(data_path, **kwargs) else: raise ValueError(f"Unknown dataset {dataset_name}") # if eval not already defined if "dataset" in locals(): - train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs) + train, eval = train_val_dataset(dataset, val_split=val_split) return train, eval diff --git a/model/supervised_finetuning/custom_datasets/formatting.py b/model/supervised_finetuning/custom_datasets/formatting.py index a6c1c0d8..e69557a9 100644 --- a/model/supervised_finetuning/custom_datasets/formatting.py +++ b/model/supervised_finetuning/custom_datasets/formatting.py @@ -1,5 +1,11 @@ QA_SPECIAL_TOKENS = {"Question": "", "Answer": "", "StartPrefix": "", "EndPrefix": ""} -def format_pair(pair): - return "{}{}{}".format(QA_SPECIAL_TOKENS["Question"], pair[0], QA_SPECIAL_TOKENS["Answer"]), pair[1] +def format_pair(pairs): + assert len(pairs) % 2 == 0 + return [ + "{}{}{}".format(QA_SPECIAL_TOKENS["Question"], pairs[i], QA_SPECIAL_TOKENS["Answer"]) + if i % 2 == 0 + else pairs[i] + for i in range(0, len(pairs), 2) + ] diff --git a/model/supervised_finetuning/custom_datasets/prompt_dialogue.py b/model/supervised_finetuning/custom_datasets/prompt_dialogue.py index 1550de6c..d7b80761 100644 --- a/model/supervised_finetuning/custom_datasets/prompt_dialogue.py +++ b/model/supervised_finetuning/custom_datasets/prompt_dialogue.py @@ -12,10 +12,9 @@ from torch.utils.data import Dataset class OAPrivate(Dataset): - file = "2023-02-10_oasst_prod.jsonl" splits = OrderedDict(sft=0.25, reward_model=0.4, rl=0.35) # fractions per task - def __init__(self, split="sft", data_path=".cache") -> None: + def __init__(self, data_path, split="sft", file="2023-02-10_oasst_prod.jsonl") -> None: super().__init__() total_prob = reduce(lambda prev, split: prev + split[1], self.splits.items(), 0) @@ -56,14 +55,17 @@ class OAPrivate(Dataset): reply = random.choice(prompt["replies"]) reply_text = reply["text"] - pairs.append([prompter_text, reply_text]) + + # only add if the reply exists + pairs.append(prompter_text) + pairs.append(reply_text) if len(reply["replies"]) == 0: break prompt = random.choice(reply["replies"]) - return pairs + return format_pair(pairs) class PromptGeneratedDataset(Dataset):