mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
refactor datasets and oa private data selection
This commit is contained in:
@@ -50,6 +50,12 @@ defaults:
|
||||
log_wandb: true
|
||||
samples_mixing: false # uses a collator that mixes samples in the batch to create a single sample with possibly multiple tasks within
|
||||
|
||||
oa_dataset_only:
|
||||
datasets:
|
||||
- oa_private:
|
||||
data_path: .cache
|
||||
val_split: 0.0
|
||||
|
||||
galactica-125m:
|
||||
learning_rate: 5e-5
|
||||
model_name: facebook/galactica-125m
|
||||
|
||||
@@ -32,7 +32,7 @@ SUMMARIZATION_DATASETS = [
|
||||
"debate_sum",
|
||||
"tldr_news",
|
||||
]
|
||||
OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning", "private_tuning", "oa_translated"]
|
||||
OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning", "private_tuning", "oa_translated", "oa_private"]
|
||||
|
||||
|
||||
def train_val_dataset(dataset, val_split=0.2):
|
||||
@@ -42,63 +42,54 @@ def train_val_dataset(dataset, val_split=0.2):
|
||||
return Subset(dataset, train_idx), Subset(dataset, val_idx)
|
||||
|
||||
|
||||
def get_one_dataset(conf, dataset_name, val_split=0.2, data_path=None, **kwargs):
    """Build the train and validation datasets for a single named dataset.

    Args:
        conf: Run configuration; only ``conf.cache_dir`` is read here, as the
            fallback data location.
        dataset_name: Dataset identifier, matched case-insensitively.
        val_split: Fraction handed to ``train_val_dataset`` when the dataset
            does not provide a validation split of its own.
        data_path: Directory passed to the dataset constructors; falls back
            to ``conf.cache_dir`` when falsy.
        **kwargs: Forwarded unchanged to ``train_val_dataset``.

    Returns:
        A ``(train, eval)`` tuple of datasets.

    Raises:
        ValueError: If ``dataset_name`` is not recognised.
    """
    data_path = data_path or conf.cache_dir
    dataset_name = dataset_name.lower()

    # Each branch fills in either `dataset` (one corpus that still has to be
    # split) or `train`, optionally together with `val`. Explicit sentinels
    # replace the fragile `"dataset" in locals()` probe.
    dataset = None
    train = None
    val = None

    if dataset_name in QA_DATASETS:
        train = QADataset(dataset_name, data_path, "train")
        if not train.no_val:
            val = QADataset(dataset_name, data_path, "validation")
    elif dataset_name in SUMMARIZATION_DATASETS:
        train = SummarizationDataset(dataset_name, data_path, "train")
        if dataset_name != "debate_sum":
            val = SummarizationDataset(dataset_name, data_path, "validation")
    elif "ted_trans" in dataset_name:
        language_pair = dataset_name.split("_")[-1]
        dataset = TEDTalk(pair=language_pair, split="train")
    elif "wmt2019" in dataset_name:
        language_pair = dataset_name.split("_")[-1]
        train = WMT2019(pair=language_pair, split="train")
        val = WMT2019(pair=language_pair, split="validation")
    elif dataset_name == "dive_mt":
        dataset = DiveMT()
    elif dataset_name == "webgpt":
        dataset = WebGPT()
    elif dataset_name == "prompt_dialogue":
        dataset = PromptGeneratedDataset(data_path)
    elif dataset_name == "prosocial_dialogue":
        train = ProsocialDialogue(cache_dir=data_path, split="train")
        val = ProsocialDialogue(cache_dir=data_path, split="validation")
    elif dataset_name == "explain_prosocial":
        train = ProsocialDialogueExplaination(cache_dir=data_path, split="train")
        val = ProsocialDialogueExplaination(cache_dir=data_path, split="validation")
    elif dataset_name == "soda":
        dataset = SODA(data_path)
    elif dataset_name == "soda_dialogue":
        dataset = SODADialogue(data_path)
    elif dataset_name == "joke":
        dataset = JokeExplaination(data_path)
    elif dataset_name == "instruct_tuning":
        dataset = InstructionTuning(data_path)
    elif dataset_name == "private_tuning":
        dataset = PrivateInstructionTuning(data_path)
    elif dataset_name == "oa_translated":
        dataset = TranslatedQA(data_path)  # TODO make val_split lower..?
    else:
        raise ValueError(f"Unknown dataset {dataset_name}")

    if dataset is not None:
        # Single-corpus datasets: carve the validation set out of it.
        train, val = train_val_dataset(dataset, val_split=val_split, **kwargs)
    elif val is None:
        # A train split exists but no validation split was loaded (QA
        # datasets with `no_val`, and "debate_sum"); split the train set
        # so the return value is always fully defined.
        train, val = train_val_dataset(train, val_split=val_split, **kwargs)

    return train, val
|
||||
|
||||
@@ -1,11 +1,71 @@
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
from functools import reduce
|
||||
from urllib.request import urlopen
|
||||
|
||||
import numpy as np
|
||||
from custom_datasets.formatting import QA_SPECIAL_TOKENS, format_pair
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class OAPrivate(Dataset):
    """Conversation trees read from a private Open-Assistant JSONL export.

    The records are shuffled with a fixed seed and partitioned according to
    ``splits``, so every task name maps to a reproducible, disjoint slice of
    the file. Each item is a list of ``[prompter_text, reply_text]`` pairs
    obtained by walking one randomly chosen path through the reply tree.
    """

    file = "2023-02-10_oasst_prod.jsonl"
    splits = OrderedDict(sft=0.25, reward_model=0.4, rl=0.35)  # fractions per task

    def __init__(self, split="sft", data_path=".cache") -> None:
        """Load the slice of the export that belongs to ``split``.

        Args:
            split: One of the keys of ``splits``.
            data_path: Directory that contains ``file``.

        Raises:
            ValueError: If ``split`` is not a key of ``splits``.
        """
        super().__init__()

        total_prob = sum(self.splits.values())
        assert math.isclose(total_prob, 1), "Make sure OAPrivate split ratios add to 1"

        split_names = list(self.splits)
        if split not in split_names:
            raise ValueError(f"Unknown OAPrivate split {split!r}; expected one of {split_names}")
        split_index = split_names.index(split)

        jsonl_file = os.path.join(data_path, self.file)

        with open(jsonl_file, "r", encoding="utf-8") as f:
            # Blank lines are not records; dropping them keeps json.loads
            # from failing on an empty string.
            lines = [stripped for stripped in (raw.strip() for raw in f) if stripped]

        num_records = len(lines)

        # take a subset of the dataset based on the split
        rng = np.random.default_rng(seed=0)
        indices = np.arange(num_records).astype(int)
        rng.shuffle(indices)

        cumsums = np.cumsum([0] + list(self.splits.values()))
        start_index = int(cumsums[split_index] * num_records)
        if split_index == len(split_names) - 1:
            # The ratios only need to be *close* to 1, so the last cumulative
            # value may sit just below it; pin the final boundary to the end
            # so float truncation can never drop trailing records.
            end_index = num_records
        else:
            end_index = int(cumsums[split_index + 1] * num_records)

        self.data = [json.loads(lines[index]) for index in indices[start_index:end_index]]

    def __len__(self):
        """Return the number of conversation trees in this split."""
        return len(self.data)

    def __getitem__(self, index):
        """Return one random prompter/reply path through tree ``index``.

        Starting at the root prompt, a reply is picked at random, then one of
        that reply's follow-up prompts, and so on until a node has no
        children. The result is a list of ``[prompter_text, reply_text]``
        pairs; it is empty when the root prompt has no replies.
        """
        # Sample randomly from replies
        prompt = self.data[index]["prompt"]

        pairs = []

        while True:
            assert prompt["role"] == "prompter"
            prompter_text = prompt["text"]

            if not prompt["replies"]:
                break

            reply = random.choice(prompt["replies"])
            pairs.append([prompter_text, reply["text"]])

            if not reply["replies"]:
                break

            prompt = random.choice(reply["replies"])

        return pairs
|
||||
|
||||
|
||||
class PromptGeneratedDataset(Dataset):
|
||||
"""Generates from flan 11B
|
||||
User: What are the best methods for preventing a slave trade?
|
||||
|
||||
Reference in New Issue
Block a user