diff --git a/model/supervised_finetuning/README.md b/model/supervised_finetuning/README.md index 387e91e4..1a3af828 100644 --- a/model/supervised_finetuning/README.md +++ b/model/supervised_finetuning/README.md @@ -66,6 +66,17 @@ different 500 examples from `prompt_dialogue`. This works with `torch.distributed`. +## Training only on OA internal data + +To experiment with the Open Assistant data, simply run: + +```bash +python trainer.py --configs defaults oa_dataset_only galactica-125m +``` + +Change the `data_path` of the `oa_dataset_only` section in the `configs/config.yaml` +file to the directory that contains the data file. + ## Model Normally you should be able to add new models in `configs/config.yml` diff --git a/model/supervised_finetuning/configs/config.yaml b/model/supervised_finetuning/configs/config.yaml index 79e4751d..63a2a592 100644 --- a/model/supervised_finetuning/configs/config.yaml +++ b/model/supervised_finetuning/configs/config.yaml @@ -17,13 +17,12 @@ defaults: freeze_layer: datasets: - webgpt - # - prompt_dialogue - squad_v2 - adversarial_qa - trivia_qa_nocontext - xsum - cnn_dailymail - - prompt_dialogue + - prompt_dialogue # TODO: need to fix the url - multi_news - scitldr - soda @@ -47,6 +46,18 @@ defaults: seq2seqmodel: false poly_eps: 1.0 fuse_gelu: true + log_wandb: true + samples_mixing: false # uses a collator that mixes samples in the batch to create a single sample with possibly multiple tasks within + verbose: false + +oa_dataset_only: + datasets: + - oa_private: + data_path: .cache + split: sft + val_split: 0.0 + fraction: 1 + file: 2023-02-10_oasst_prod.jsonl galactica-125m: learning_rate: 5e-5 @@ -81,9 +92,12 @@ codegen: per_device_eval_batch_size: 4 debug: + model_name: EleutherAI/pythia-70m-deduped eval_steps: 20 eval_size: 20 gradient_accumulation_steps: 1 per_device_train_batch_size: 1 per_device_eval_batch_size: 1 quantization: false + log_wandb: false + verbose: true diff --git a/model/supervised_finetuning/custom_datasets/__init__.py 
b/model/supervised_finetuning/custom_datasets/__init__.py index 0e5b9a91..def468e1 100644 --- a/model/supervised_finetuning/custom_datasets/__init__.py +++ b/model/supervised_finetuning/custom_datasets/__init__.py @@ -1,7 +1,12 @@ """ High level functions for model training """ -from custom_datasets.prompt_dialogue import InstructionTuning, PrivateInstructionTuning, PromptGeneratedDataset +from custom_datasets.prompt_dialogue import ( + InstructionTuning, + OAPrivate, + PrivateInstructionTuning, + PromptGeneratedDataset, +) from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, SODADialogue, TranslatedQA, WebGPT from custom_datasets.summarization import SummarizationDataset from custom_datasets.toxic_conversation import ProsocialDialogue, ProsocialDialogueExplaination @@ -32,73 +37,69 @@ SUMMARIZATION_DATASETS = [ "debate_sum", "tldr_news", ] -OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning", "private_tuning", "oa_translated"] +OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning", "private_tuning", "oa_translated", "oa_private"] def train_val_dataset(dataset, val_split=0.2): + if val_split == 0: + return dataset, None + train_idx, val_idx = train_test_split( list(range(len(dataset))), test_size=val_split, random_state=666, shuffle=True ) return Subset(dataset, train_idx), Subset(dataset, val_idx) -def get_one_dataset(conf, dataset_name): +def get_one_dataset(conf, dataset_name, val_split=0.2, data_path=None, **kwargs): + data_path = data_path or conf.cache_dir dataset_name = dataset_name.lower() if dataset_name in QA_DATASETS: - train = QADataset(dataset_name, conf.cache_dir, "train") - if train.no_val: - train, eval = train_val_dataset(train, val_split=0.2) - else: - eval = QADataset(dataset_name, conf.cache_dir, "validation") + train = QADataset(dataset_name, data_path, "train") + if not train.no_val: + eval = QADataset(dataset_name, data_path, "validation") elif dataset_name in SUMMARIZATION_DATASETS: - 
train = SummarizationDataset(dataset_name, conf.cache_dir, "train") - if dataset_name == "debate_sum": - train, eval = train_val_dataset(train, val_split=0.2) - else: - eval = SummarizationDataset(dataset_name, conf.cache_dir, "validation") + train = SummarizationDataset(dataset_name, data_path, "train") + if dataset_name != "debate_sum": + eval = SummarizationDataset(dataset_name, data_path, "validation") elif "ted_trans" in dataset_name: language_pair = dataset_name.split("_")[-1] dataset = TEDTalk(pair=language_pair, split="train") - train, eval = train_val_dataset(dataset, val_split=0.2) elif "wmt2019" in dataset_name: language_pair = dataset_name.split("_")[-1] train = WMT2019(pair=language_pair, split="train") eval = WMT2019(pair=language_pair, split="validation") elif dataset_name == "dive_mt": dataset = DiveMT() - train, eval = train_val_dataset(dataset, val_split=0.2) elif dataset_name == "webgpt": dataset = WebGPT() - train, eval = train_val_dataset(dataset, val_split=0.2) elif dataset_name == "prompt_dialogue": - dataset = PromptGeneratedDataset(conf.cache_dir) - train, eval = train_val_dataset(dataset, val_split=0.2) + dataset = PromptGeneratedDataset(data_path) elif dataset_name == "prosocial_dialogue": - train = ProsocialDialogue(cache_dir=conf.cache_dir, split="train") - eval = ProsocialDialogue(cache_dir=conf.cache_dir, split="validation") + train = ProsocialDialogue(cache_dir=data_path, split="train") + eval = ProsocialDialogue(cache_dir=data_path, split="validation") elif dataset_name == "explain_prosocial": - train = ProsocialDialogueExplaination(cache_dir=conf.cache_dir, split="train") - eval = ProsocialDialogueExplaination(cache_dir=conf.cache_dir, split="validation") + train = ProsocialDialogueExplaination(cache_dir=data_path, split="train") + eval = ProsocialDialogueExplaination(cache_dir=data_path, split="validation") elif dataset_name == "soda": - dataset = SODA(conf.cache_dir) - train, eval = train_val_dataset(dataset, val_split=0.1) + 
dataset = SODA(data_path) elif dataset_name == "soda_dialogue": - dataset = SODADialogue(conf.cache_dir) - train, eval = train_val_dataset(dataset, val_split=0.1) + dataset = SODADialogue(data_path) elif dataset_name == "joke": - dataset = JokeExplaination(conf.cache_dir) - train, eval = train_val_dataset(dataset, val_split=0.2) + dataset = JokeExplaination(data_path) elif dataset_name == "instruct_tuning": - dataset = InstructionTuning(conf.cache_dir) - train, eval = train_val_dataset(dataset, val_split=0.2) + dataset = InstructionTuning(data_path) elif dataset_name == "private_tuning": - dataset = PrivateInstructionTuning(conf.cache_dir) - train, eval = train_val_dataset(dataset, val_split=0.2) + dataset = PrivateInstructionTuning(data_path) elif dataset_name == "oa_translated": - dataset = TranslatedQA(conf.cache_dir) - train, eval = train_val_dataset(dataset, val_split=0.01) + dataset = TranslatedQA(data_path) # TODO make val_split lower..? + elif dataset_name == "oa_private": + dataset = OAPrivate(data_path, **kwargs) else: raise ValueError(f"Unknown dataset {dataset_name}") + # if eval not already defined + if "dataset" in locals(): + train, eval = train_val_dataset(dataset, val_split=val_split) + return train, eval diff --git a/model/supervised_finetuning/custom_datasets/dialogue_collator.py b/model/supervised_finetuning/custom_datasets/dialogue_collator.py index 0a0b7a5a..e43d6d8e 100644 --- a/model/supervised_finetuning/custom_datasets/dialogue_collator.py +++ b/model/supervised_finetuning/custom_datasets/dialogue_collator.py @@ -7,6 +7,8 @@ import torch from torch.nn import functional as F from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase +from .formatting import QA_SPECIAL_TOKENS + @dataclass class DialogueDataCollator: @@ -28,7 +30,7 @@ class DialogueDataCollator: # Add a way for the model to terminate generation # When we predict the start of a new expected question, we want to be able to stop generation - 
messages.append(self.tokenizer.eos_token) + messages.append(QA_SPECIAL_TOKENS["Question"]) flatten_message = self.tokenizer( "".join(messages), @@ -101,7 +103,7 @@ class TrainDialogueDataCollator: # Add a way for the model to terminate generation # When we predict the start of a new expected question, we want to be able to stop generation - messages.append(self.tokenizer.eos_token) + messages.append(QA_SPECIAL_TOKENS["Question"]) flatten_message = self.tokenizer( "".join(messages), diff --git a/model/supervised_finetuning/custom_datasets/formatting.py b/model/supervised_finetuning/custom_datasets/formatting.py index a6c1c0d8..e69557a9 100644 --- a/model/supervised_finetuning/custom_datasets/formatting.py +++ b/model/supervised_finetuning/custom_datasets/formatting.py @@ -1,5 +1,11 @@ QA_SPECIAL_TOKENS = {"Question": "", "Answer": "", "StartPrefix": "", "EndPrefix": ""} -def format_pair(pair): - return "{}{}{}".format(QA_SPECIAL_TOKENS["Question"], pair[0], QA_SPECIAL_TOKENS["Answer"]), pair[1] +def format_pair(pairs): + assert len(pairs) % 2 == 0 + return [ + "{}{}{}".format(QA_SPECIAL_TOKENS["Question"], pairs[i], QA_SPECIAL_TOKENS["Answer"]) + if i % 2 == 0 + else pairs[i] + for i in range(0, len(pairs), 2) + ] diff --git a/model/supervised_finetuning/custom_datasets/prompt_dialogue.py b/model/supervised_finetuning/custom_datasets/prompt_dialogue.py index 4aac2655..d7b80761 100644 --- a/model/supervised_finetuning/custom_datasets/prompt_dialogue.py +++ b/model/supervised_finetuning/custom_datasets/prompt_dialogue.py @@ -1,11 +1,73 @@ import json +import math import os +import random +from collections import OrderedDict +from functools import reduce from urllib.request import urlopen +import numpy as np from custom_datasets.formatting import QA_SPECIAL_TOKENS, format_pair from torch.utils.data import Dataset +class OAPrivate(Dataset): + splits = OrderedDict(sft=0.25, reward_model=0.4, rl=0.35) # fractions per task + + def __init__(self, data_path, split="sft", 
file="2023-02-10_oasst_prod.jsonl") -> None: + super().__init__() + + total_prob = reduce(lambda prev, split: prev + split[1], self.splits.items(), 0) + assert math.isclose(total_prob, 1), "Make sure OAPrivate split ratios add to 1" + + jsonl_file = os.path.join(data_path, self.file) + + with open(jsonl_file, "r", encoding="utf-8") as f: + lines = f.readlines() + + # take a subset of the dataset based on the split + rng = np.random.default_rng(seed=0) + indices = np.arange(len(lines)).astype(int) + rng.shuffle(indices) + + cumsums = np.cumsum([[0] + list(self.splits.values())]) + split_index = list(self.splits.keys()).index(split) + + start_index, end_index = int(cumsums[split_index] * len(lines)), int(cumsums[split_index + 1] * len(lines)) + + self.data = [json.loads(lines[index].strip()) for index in indices[start_index:end_index]] + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + # Sample randomly from replies + prompt = self.data[index]["prompt"] + + pairs = [] + + while True: + assert prompt["role"] == "prompter" + prompter_text = prompt["text"] + + if len(prompt["replies"]) == 0: + break + + reply = random.choice(prompt["replies"]) + reply_text = reply["text"] + + # only add if the reply exists + pairs.append(prompter_text) + pairs.append(reply_text) + + if len(reply["replies"]) == 0: + break + + prompt = random.choice(reply["replies"]) + + return format_pair(pairs) + + class PromptGeneratedDataset(Dataset): """Generates from flan 11B User: What are the best methods for preventing a slave trade? 
diff --git a/model/supervised_finetuning/trainer.py b/model/supervised_finetuning/trainer.py index ce80830b..48847862 100644 --- a/model/supervised_finetuning/trainer.py +++ b/model/supervised_finetuning/trainer.py @@ -231,11 +231,10 @@ if __name__ == "__main__": eval_steps=training_conf.eval_steps, save_steps=training_conf.save_steps, eval_accumulation_steps=training_conf.eval_accumulation_steps, - report_to="wandb", + report_to="wandb" if training_conf.log_wandb else None, ) - assert len(evals) > 0 - if not training_conf.deepspeed or training_conf.local_rank == 0: + if training_conf.log_wandb and not training_conf.deepspeed or training_conf.local_rank == 0: import wandb wandb.init( diff --git a/model/supervised_finetuning/utils.py b/model/supervised_finetuning/utils.py index 43377bc9..f95208d5 100644 --- a/model/supervised_finetuning/utils.py +++ b/model/supervised_finetuning/utils.py @@ -61,16 +61,20 @@ class PerDatasetSampler(Sampler): @classmethod def build_sampler_from_config(cls, training_conf, datasets): dataset_sizes = [len(x) for x in datasets] - fractions = get_dataset_fractions(training_conf.datasets, dataset_sizes) + fractions = get_dataset_fractions(training_conf.datasets, dataset_sizes, verbose=training_conf.verbose) dataset_size_per_epoch = [int(size * frac) for size, frac in zip(dataset_sizes, fractions)] return cls(dataset_sizes, dataset_size_per_epoch) -def get_dataset_fractions(conf, dataset_sizes): +def get_dataset_fractions(conf, dataset_sizes, verbose=False): """Calculate fraction of each dataset to use per epoch when subsampling""" + + if verbose: + print("Creating sampler for datasets:") + fractions = [] for i, data_config in enumerate(conf): - dataset_name = get_dataset_name_from_data_config(data_config) + dataset_name, _ = get_dataset_name_and_kwargs_from_data_config(data_config) if isinstance(data_config, dict): if "fraction" in data_config[dataset_name]: if data_config[dataset_name]["fraction"] <= 0: @@ -81,9 +85,12 @@ def 
get_dataset_fractions(conf, dataset_sizes): raise ValueError(f"Please specify a size smaller than number of examples: {dataset_sizes[i]:,.0f}") fractions.append(data_config[dataset_name]["size"] / dataset_sizes[i]) else: - raise ValueError("Please specify either fraction or size in config.yaml. See README for instructions.") + fractions.append(1) else: fractions.append(1) + + if verbose: + print(f"Dataset: {dataset_name} fraction chosen: {fractions[-1]:.2f}") return fractions @@ -220,25 +227,37 @@ def get_model(conf, tokenizer): return model -def get_dataset_name_from_data_config(data_config): +def get_dataset_name_and_kwargs_from_data_config(data_config): if isinstance(data_config, dict): - return list(data_config.keys())[0] - return data_config + name = list(data_config.keys())[0] + kwargs = data_config[name] + # remove 'fraction' or 'size' from kwargs + kwargs.pop("fraction", None) + kwargs.pop("size", None) + return name, kwargs + else: + return data_config, {} def get_dataset(conf, tokenizer): train_datasets, evals = [], {} for data_config in conf.datasets: - dataset_name = get_dataset_name_from_data_config(data_config) - train, val = get_one_dataset(conf, dataset_name) + dataset_name, kwargs = get_dataset_name_and_kwargs_from_data_config(data_config) + train, val = get_one_dataset(conf, dataset_name, **kwargs) train_datasets.append(train) - evals[dataset_name] = Subset(val, list(range(min(len(val), conf.eval_size)))) if conf.eval_size else val + + if val is not None: + evals[dataset_name] = Subset(val, list(range(min(len(val), conf.eval_size)))) if conf.eval_size else val train = ConcatDataset(train_datasets) collate_fn = DialogueDataCollator(tokenizer, max_length=conf.max_length) - train_collate_fn = TrainDialogueDataCollator(tokenizer, max_length=conf.max_length) + + train_collate_fn = ( + TrainDialogueDataCollator(tokenizer, max_length=conf.max_length) if conf.samples_mixing else collate_fn + ) + return train, evals, collate_fn, train_collate_fn