os private dataset

This commit is contained in:
Sotirios Anagnostidis
2023-02-11 13:20:42 +01:00
parent 5b1427d811
commit 6a68139b91
3 changed files with 26 additions and 8 deletions
@@ -1,7 +1,12 @@
"""
High level functions for model training
"""
from custom_datasets.prompt_dialogue import InstructionTuning, PrivateInstructionTuning, PromptGeneratedDataset
from custom_datasets.prompt_dialogue import (
InstructionTuning,
OAPrivate,
PrivateInstructionTuning,
PromptGeneratedDataset,
)
from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, SODADialogue, TranslatedQA, WebGPT
from custom_datasets.summarization import SummarizationDataset
from custom_datasets.toxic_conversation import ProsocialDialogue, ProsocialDialogueExplaination
@@ -36,6 +41,9 @@ OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning", "private_
def train_val_dataset(dataset, val_split=0.2):
if val_split == 0:
return dataset, None
train_idx, val_idx = train_test_split(
list(range(len(dataset))), test_size=val_split, random_state=666, shuffle=True
)
@@ -85,11 +93,13 @@ def get_one_dataset(conf, dataset_name, val_split=0.2, data_path=None, **kwargs)
dataset = PrivateInstructionTuning(data_path)
elif dataset_name == "oa_translated":
dataset = TranslatedQA(data_path) # TODO make val_split lower..?
elif dataset_name == "oa_private":
dataset = OAPrivate(data_path, **kwargs)
else:
raise ValueError(f"Unknown dataset {dataset_name}")
# if eval not already defined
if "dataset" in locals():
train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs)
train, eval = train_val_dataset(dataset, val_split=val_split)
return train, eval
@@ -1,5 +1,11 @@
QA_SPECIAL_TOKENS = {"Question": "<human>", "Answer": "<bot>", "StartPrefix": "<prefix>", "EndPrefix": "</prefix>"}
def format_pair(pair):
return "{}{}{}".format(QA_SPECIAL_TOKENS["Question"], pair[0], QA_SPECIAL_TOKENS["Answer"]), pair[1]
def format_pair(pairs):
assert len(pairs) % 2 == 0
return [
"{}{}{}".format(QA_SPECIAL_TOKENS["Question"], pairs[i], QA_SPECIAL_TOKENS["Answer"])
if i % 2 == 0
else pairs[i]
for i in range(0, len(pairs), 2)
]
@@ -12,10 +12,9 @@ from torch.utils.data import Dataset
class OAPrivate(Dataset):
file = "2023-02-10_oasst_prod.jsonl"
splits = OrderedDict(sft=0.25, reward_model=0.4, rl=0.35) # fractions per task
def __init__(self, split="sft", data_path=".cache") -> None:
def __init__(self, data_path, split="sft", file="2023-02-10_oasst_prod.jsonl") -> None:
super().__init__()
total_prob = reduce(lambda prev, split: prev + split[1], self.splits.items(), 0)
@@ -56,14 +55,17 @@ class OAPrivate(Dataset):
reply = random.choice(prompt["replies"])
reply_text = reply["text"]
pairs.append([prompter_text, reply_text])
# only add if the reply exists
pairs.append(prompter_text)
pairs.append(reply_text)
if len(reply["replies"]) == 0:
break
prompt = random.choice(reply["replies"])
return pairs
return format_pair(pairs)
class PromptGeneratedDataset(Dataset):