From ac97943be15adcfabf73a6f81f902cac0bb08616 Mon Sep 17 00:00:00 2001 From: Sotirios Anagnostidis Date: Sat, 11 Feb 2023 11:49:58 +0100 Subject: [PATCH] refactor datasets and oa private data selection --- .../supervised_finetuning/configs/config.yaml | 6 ++ .../custom_datasets/__init__.py | 57 ++++++++---------- .../custom_datasets/prompt_dialogue.py | 60 +++++++++++++++++++ 3 files changed, 90 insertions(+), 33 deletions(-) diff --git a/model/supervised_finetuning/configs/config.yaml b/model/supervised_finetuning/configs/config.yaml index 191a7391..92a1a793 100644 --- a/model/supervised_finetuning/configs/config.yaml +++ b/model/supervised_finetuning/configs/config.yaml @@ -50,6 +50,12 @@ defaults: log_wandb: true samples_mixing: false # uses collator that mixes samples in the batch to create a single sample with possible multiple tasks within +oa_dataset_only: + datasets: + - oa_pricate: + data_path: .cache + val_split: 0.0 + galactica-125m: learning_rate: 5e-5 model_name: facebook/galactica-125m diff --git a/model/supervised_finetuning/custom_datasets/__init__.py b/model/supervised_finetuning/custom_datasets/__init__.py index 440cb62b..73fc6d89 100644 --- a/model/supervised_finetuning/custom_datasets/__init__.py +++ b/model/supervised_finetuning/custom_datasets/__init__.py @@ -32,7 +32,7 @@ SUMMARIZATION_DATASETS = [ "debate_sum", "tldr_news", ] -OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning", "private_tuning", "oa_translated"] +OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning", "private_tuning", "oa_translated", "oa_private"] def train_val_dataset(dataset, val_split=0.2): @@ -42,63 +42,54 @@ def train_val_dataset(dataset, val_split=0.2): return Subset(dataset, train_idx), Subset(dataset, val_idx) -def get_one_dataset(conf, dataset_name, val_split=0.2, **kwargs): +def get_one_dataset(conf, dataset_name, val_split=0.2, data_path=None, **kwargs): + data_path = data_path or conf.cache_dir dataset_name = dataset_name.lower() if dataset_name in QA_DATASETS: - train = QADataset(dataset_name, conf.cache_dir, "train") - if train.no_val: - train, eval = train_val_dataset(train, val_split=val_split, **kwargs) - else: - eval = QADataset(dataset_name, conf.cache_dir, "validation") + train = QADataset(dataset_name, data_path, "train") + if not train.no_val: + eval = QADataset(dataset_name, data_path, "validation") elif dataset_name in SUMMARIZATION_DATASETS: - train = SummarizationDataset(dataset_name, conf.cache_dir, "train") - if dataset_name == "debate_sum": - train, eval = train_val_dataset(train, val_split=val_split, **kwargs) - else: - eval = SummarizationDataset(dataset_name, conf.cache_dir, "validation") + train = SummarizationDataset(dataset_name, data_path, "train") + if dataset_name != "debate_sum": + eval = SummarizationDataset(dataset_name, data_path, "validation") elif "ted_trans" in dataset_name: language_pair = dataset_name.split("_")[-1] dataset = TEDTalk(pair=language_pair, split="train") - train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs) elif "wmt2019" in dataset_name: language_pair = dataset_name.split("_")[-1] train = WMT2019(pair=language_pair, split="train") eval = WMT2019(pair=language_pair, split="validation") elif dataset_name == "dive_mt": dataset = DiveMT() - train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs) elif dataset_name == "webgpt": dataset = WebGPT() - train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs) elif dataset_name == "prompt_dialogue": - dataset = PromptGeneratedDataset(conf.cache_dir) - train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs) + dataset = PromptGeneratedDataset(data_path) elif dataset_name == "prosocial_dialogue": - train = ProsocialDialogue(cache_dir=conf.cache_dir, split="train") - eval = ProsocialDialogue(cache_dir=conf.cache_dir, split="validation") + train = ProsocialDialogue(cache_dir=data_path, split="train") + eval = ProsocialDialogue(cache_dir=data_path, split="validation") elif dataset_name == "explain_prosocial": - train = ProsocialDialogueExplaination(cache_dir=conf.cache_dir, split="train") - eval = ProsocialDialogueExplaination(cache_dir=conf.cache_dir, split="validation") + train = ProsocialDialogueExplaination(cache_dir=data_path, split="train") + eval = ProsocialDialogueExplaination(cache_dir=data_path, split="validation") elif dataset_name == "soda": - dataset = SODA(conf.cache_dir) - train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs) + dataset = SODA(data_path) elif dataset_name == "soda_dialogue": - dataset = SODADialogue(conf.cache_dir) - train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs) + dataset = SODADialogue(data_path) elif dataset_name == "joke": - dataset = JokeExplaination(conf.cache_dir) - train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs) + dataset = JokeExplaination(data_path) elif dataset_name == "instruct_tuning": - dataset = InstructionTuning(conf.cache_dir) - train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs) + dataset = InstructionTuning(data_path) elif dataset_name == "private_tuning": - dataset = PrivateInstructionTuning(conf.cache_dir) - train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs) + dataset = PrivateInstructionTuning(data_path) elif dataset_name == "oa_translated": - dataset = TranslatedQA(conf.cache_dir) - train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs) # TODO make val split lower..? + dataset = TranslatedQA(data_path) # TODO make val_split lower..? else: raise ValueError(f"Unknown dataset {dataset_name}") + # if eval not already defined + if "dataset" in locals(): + train, eval = train_val_dataset(dataset, val_split=val_split, **kwargs) + return train, eval diff --git a/model/supervised_finetuning/custom_datasets/prompt_dialogue.py b/model/supervised_finetuning/custom_datasets/prompt_dialogue.py index 4aac2655..1550de6c 100644 --- a/model/supervised_finetuning/custom_datasets/prompt_dialogue.py +++ b/model/supervised_finetuning/custom_datasets/prompt_dialogue.py @@ -1,11 +1,71 @@ import json +import math import os +import random +from collections import OrderedDict +from functools import reduce from urllib.request import urlopen +import numpy as np from custom_datasets.formatting import QA_SPECIAL_TOKENS, format_pair from torch.utils.data import Dataset +class OAPrivate(Dataset): + file = "2023-02-10_oasst_prod.jsonl" + splits = OrderedDict(sft=0.25, reward_model=0.4, rl=0.35) # fractions per task + + def __init__(self, split="sft", data_path=".cache") -> None: + super().__init__() + + total_prob = reduce(lambda prev, split: prev + split[1], self.splits.items(), 0) + assert math.isclose(total_prob, 1), "Make sure OAPrivate split ratios add to 1" + + jsonl_file = os.path.join(data_path, self.file) + + with open(jsonl_file, "r", encoding="utf-8") as f: + lines = f.readlines() + + # take a subset of the dataset based on the split + rng = np.random.default_rng(seed=0) + indices = np.arange(len(lines)).astype(int) + rng.shuffle(indices) + + cumsums = np.cumsum([[0] + list(self.splits.values())]) + split_index = list(self.splits.keys()).index(split) + + start_index, end_index = int(cumsums[split_index] * len(lines)), int(cumsums[split_index + 1] * len(lines)) + + self.data = [json.loads(lines[index].strip()) for index in indices[start_index:end_index]] + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + # Sample randomly from replies + prompt = self.data[index]["prompt"] + + pairs = [] + + while True: + assert prompt["role"] == "prompter" + prompter_text = prompt["text"] + + if len(prompt["replies"]) == 0: + break + + reply = random.choice(prompt["replies"]) + reply_text = reply["text"] + pairs.append([prompter_text, reply_text]) + + if len(reply["replies"]) == 0: + break + + prompt = random.choice(reply["replies"]) + + return pairs + + class PromptGeneratedDataset(Dataset): """Generates from flan 11B User: What are the best methods for preventing a slave trade?