mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
os private dataset
This commit is contained in:
@@ -1,7 +1,12 @@
|
||||
"""
|
||||
High level functions for model training
|
||||
"""
|
||||
from custom_datasets.prompt_dialogue import InstructionTuning, PrivateInstructionTuning, PromptGeneratedDataset
|
||||
from custom_datasets.prompt_dialogue import (
|
||||
InstructionTuning,
|
||||
OAPrivate,
|
||||
PrivateInstructionTuning,
|
||||
PromptGeneratedDataset,
|
||||
)
|
||||
from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, SODADialogue, TranslatedQA, WebGPT
|
||||
from custom_datasets.summarization import SummarizationDataset
|
||||
from custom_datasets.toxic_conversation import ProsocialDialogue, ProsocialDialogueExplaination
|
||||
@@ -36,6 +41,9 @@ OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning", "private_
|
||||
|
||||
|
||||
def train_val_dataset(dataset, val_split=0.2):
|
||||
if val_split == 0:
|
||||
return dataset, None
|
||||
|
||||
train_idx, val_idx = train_test_split(
|
||||
list(range(len(dataset))), test_size=val_split, random_state=666, shuffle=True
|
||||
)
|
||||
@@ -85,11 +93,13 @@ def get_one_dataset(conf, dataset_name, val_split=0.2, data_path=None, **kwargs)
|
||||
dataset = PrivateInstructionTuning(data_path)
|
||||
elif dataset_name == "oa_translated":
|
||||
dataset = TranslatedQA(data_path) # TODO make val_split lower..?
|
||||
elif dataset_name == "oa_private":
|
||||
dataset = OAPrivate(data_path, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset {dataset_name}")
|
||||
|
||||
# if eval not already defined
|
||||
if "dataset" in locals():
|
||||
train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs)
|
||||
train, eval = train_val_dataset(dataset, val_split=val_split)
|
||||
|
||||
return train, eval
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
QA_SPECIAL_TOKENS = {"Question": "<human>", "Answer": "<bot>", "StartPrefix": "<prefix>", "EndPrefix": "</prefix>"}
|
||||
|
||||
|
||||
def format_pair(pair):
    """Wrap the prompt of a two-item pair in the QA special tokens.

    The first item of ``pair`` is prefixed with the "Question" token and
    suffixed with the "Answer" token; the second item is returned unchanged.

    Returns:
        A 2-tuple ``(wrapped_first_item, second_item)``.
    """
    wrapped_prompt = f'{QA_SPECIAL_TOKENS["Question"]}{pair[0]}{QA_SPECIAL_TOKENS["Answer"]}'
    return wrapped_prompt, pair[1]
|
||||
def format_pair(pairs):
    """Format a flat, alternating prompt/reply sequence with the QA tokens.

    ``pairs`` is laid out as ``[prompt_0, reply_0, prompt_1, reply_1, ...]``.
    Items at even indices are wrapped as ``<Question>item<Answer>`` using
    ``QA_SPECIAL_TOKENS``; items at odd indices are passed through unchanged.

    Args:
        pairs: Sequence with an even number of items.

    Returns:
        A list of the same length as ``pairs`` with the even-indexed items
        wrapped in the question/answer tokens.

    Raises:
        AssertionError: If ``pairs`` has an odd number of items.
    """
    assert len(pairs) % 2 == 0
    # Visit every index: stepping by 2 would only ever see the even (prompt)
    # positions, making the pass-through branch unreachable and dropping
    # every odd-indexed item from the result.
    return [
        "{}{}{}".format(QA_SPECIAL_TOKENS["Question"], text, QA_SPECIAL_TOKENS["Answer"])
        if index % 2 == 0
        else text
        for index, text in enumerate(pairs)
    ]
|
||||
|
||||
@@ -12,10 +12,9 @@ from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class OAPrivate(Dataset):
|
||||
file = "2023-02-10_oasst_prod.jsonl"
|
||||
splits = OrderedDict(sft=0.25, reward_model=0.4, rl=0.35) # fractions per task
|
||||
|
||||
def __init__(self, split="sft", data_path=".cache") -> None:
|
||||
def __init__(self, data_path, split="sft", file="2023-02-10_oasst_prod.jsonl") -> None:
|
||||
super().__init__()
|
||||
|
||||
total_prob = reduce(lambda prev, split: prev + split[1], self.splits.items(), 0)
|
||||
@@ -56,14 +55,17 @@ class OAPrivate(Dataset):
|
||||
|
||||
reply = random.choice(prompt["replies"])
|
||||
reply_text = reply["text"]
|
||||
pairs.append([prompter_text, reply_text])
|
||||
|
||||
# only add if the reply exists
|
||||
pairs.append(prompter_text)
|
||||
pairs.append(reply_text)
|
||||
|
||||
if len(reply["replies"]) == 0:
|
||||
break
|
||||
|
||||
prompt = random.choice(reply["replies"])
|
||||
|
||||
return pairs
|
||||
return format_pair(pairs)
|
||||
|
||||
|
||||
class PromptGeneratedDataset(Dataset):
|
||||
|
||||
Reference in New Issue
Block a user