Merge branch 'main' of github.com:LAION-AI/Open-Assistant

This commit is contained in:
Yannic Kilcher
2023-02-11 14:12:35 +01:00
8 changed files with 168 additions and 54 deletions
+11
View File
@@ -66,6 +66,17 @@ different 500 examples from `prompt_dialogue`.
This works with `torch.distributed`.
## Training only on OA internal data:
To experiment with the Open Assistant data simply run:
```bash
python trainer.py --configs defaults oa_dataset_only galactica-125m
```
Set `data_path` in the `oa_dataset_only` section of `configs/config.yaml` to
the directory that contains the Open Assistant data file.
## Model
Normally you should be able to add new models in `configs/config.yml`
@@ -17,13 +17,12 @@ defaults:
freeze_layer:
datasets:
- webgpt
# - prompt_dialogue
- squad_v2
- adversarial_qa
- trivia_qa_nocontext
- xsum
- cnn_dailymail
- prompt_dialogue
- prompt_dialogue # TODO: need to fix the url
- multi_news
- scitldr
- soda
@@ -47,6 +46,18 @@ defaults:
seq2seqmodel: false
poly_eps: 1.0
fuse_gelu: true
log_wandb: true
samples_mixing: false # use a collator that mixes samples in the batch into a single sample, possibly spanning multiple tasks
verbose: false
oa_dataset_only:
datasets:
- oa_private:
data_path: .cache
split: sft
val_split: 0.0
fraction: 1
file: 2023-02-10_oasst_prod.jsonl
galactica-125m:
learning_rate: 5e-5
@@ -81,9 +92,12 @@ codegen:
per_device_eval_batch_size: 4
debug:
model_name: EleutherAI/pythia-70m-deduped
eval_steps: 20
eval_size: 20
gradient_accumulation_steps: 1
per_device_train_batch_size: 1
per_device_eval_batch_size: 1
quantization: false
log_wandb: false
verbose: true
@@ -1,7 +1,12 @@
"""
High level functions for model training
"""
from custom_datasets.prompt_dialogue import InstructionTuning, PrivateInstructionTuning, PromptGeneratedDataset
from custom_datasets.prompt_dialogue import (
InstructionTuning,
OAPrivate,
PrivateInstructionTuning,
PromptGeneratedDataset,
)
from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, SODADialogue, TranslatedQA, WebGPT
from custom_datasets.summarization import SummarizationDataset
from custom_datasets.toxic_conversation import ProsocialDialogue, ProsocialDialogueExplaination
@@ -32,73 +37,69 @@ SUMMARIZATION_DATASETS = [
"debate_sum",
"tldr_news",
]
OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning", "private_tuning", "oa_translated"]
OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning", "private_tuning", "oa_translated", "oa_private"]
def train_val_dataset(dataset, val_split=0.2):
    """Split ``dataset`` into a training subset and a validation subset.

    Args:
        dataset: any object supporting ``len()``; its positions are what get
            partitioned.
        val_split: forwarded as ``test_size`` to ``train_test_split``.
            A value equal to ``0`` disables splitting.

    Returns:
        ``(dataset, None)`` when ``val_split`` is 0, otherwise a pair of
        ``Subset`` objects wrapping the train and validation indices.
    """
    if val_split == 0:
        return dataset, None

    all_positions = list(range(len(dataset)))
    # Fixed random_state so repeated calls produce the same partition.
    train_positions, val_positions = train_test_split(
        all_positions,
        test_size=val_split,
        random_state=666,
        shuffle=True,
    )
    train_subset = Subset(dataset, train_positions)
    val_subset = Subset(dataset, val_positions)
    return train_subset, val_subset
def get_one_dataset(conf, dataset_name, val_split=0.2, data_path=None, **kwargs):
    """Build the train and validation datasets for one named dataset.

    Each branch either produces a ready-made ``(train, val)`` pair, or a
    single ``dataset`` that is split afterwards with ``train_val_dataset``.

    Args:
        conf: training configuration; only ``conf.cache_dir`` is read here, as
            the fallback when ``data_path`` is not given.
        dataset_name: dataset identifier, matched case-insensitively.
        val_split: hold-out fraction passed to ``train_val_dataset`` for
            datasets that do not provide their own validation split.
        data_path: directory handed to the dataset constructors; defaults to
            ``conf.cache_dir``.
        **kwargs: extra options, forwarded to ``OAPrivate`` only.

    Returns:
        ``(train, val)``. ``val`` is ``None`` when the dataset has to be split
        here and ``val_split`` is 0.

    Raises:
        ValueError: if ``dataset_name`` is not a known dataset.
    """
    data_path = data_path or conf.cache_dir
    dataset_name = dataset_name.lower()

    # `dataset` is set by branches that need a train/val split; the other
    # branches fill `train` and `val` directly. Initialising all three avoids
    # returning an unbound name and removes the need to inspect locals().
    dataset = train = val = None

    if dataset_name in QA_DATASETS:
        qa_train = QADataset(dataset_name, data_path, "train")
        if qa_train.no_val:
            # No validation split available: carve one out of the train set.
            dataset = qa_train
        else:
            train = qa_train
            val = QADataset(dataset_name, data_path, "validation")
    elif dataset_name in SUMMARIZATION_DATASETS:
        summarization_train = SummarizationDataset(dataset_name, data_path, "train")
        if dataset_name == "debate_sum":
            # debate_sum is split here rather than loaded with a validation split.
            dataset = summarization_train
        else:
            train = summarization_train
            val = SummarizationDataset(dataset_name, data_path, "validation")
    elif "ted_trans" in dataset_name:
        language_pair = dataset_name.split("_")[-1]
        dataset = TEDTalk(pair=language_pair, split="train")
    elif "wmt2019" in dataset_name:
        language_pair = dataset_name.split("_")[-1]
        train = WMT2019(pair=language_pair, split="train")
        val = WMT2019(pair=language_pair, split="validation")
    elif dataset_name == "dive_mt":
        dataset = DiveMT()
    elif dataset_name == "webgpt":
        dataset = WebGPT()
    elif dataset_name == "prompt_dialogue":
        dataset = PromptGeneratedDataset(data_path)
    elif dataset_name == "prosocial_dialogue":
        train = ProsocialDialogue(cache_dir=data_path, split="train")
        val = ProsocialDialogue(cache_dir=data_path, split="validation")
    elif dataset_name == "explain_prosocial":
        train = ProsocialDialogueExplaination(cache_dir=data_path, split="train")
        val = ProsocialDialogueExplaination(cache_dir=data_path, split="validation")
    elif dataset_name == "soda":
        dataset = SODA(data_path)
    elif dataset_name == "soda_dialogue":
        dataset = SODADialogue(data_path)
    elif dataset_name == "joke":
        dataset = JokeExplaination(data_path)
    elif dataset_name == "instruct_tuning":
        dataset = InstructionTuning(data_path)
    elif dataset_name == "private_tuning":
        dataset = PrivateInstructionTuning(data_path)
    elif dataset_name == "oa_translated":
        dataset = TranslatedQA(data_path)  # TODO make val_split lower..?
    elif dataset_name == "oa_private":
        dataset = OAPrivate(data_path, **kwargs)
    else:
        raise ValueError(f"Unknown dataset {dataset_name}")

    if dataset is not None:
        train, val = train_val_dataset(dataset, val_split=val_split)

    return train, val
@@ -7,6 +7,8 @@ import torch
from torch.nn import functional as F
from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
from .formatting import QA_SPECIAL_TOKENS
@dataclass
class DialogueDataCollator:
@@ -28,7 +30,7 @@ class DialogueDataCollator:
# Add a way for the model to terminate generation
# When we predict the start of a new expected question, we want to be able to stop generation
messages.append(self.tokenizer.eos_token)
messages.append(QA_SPECIAL_TOKENS["Question"])
flatten_message = self.tokenizer(
"".join(messages),
@@ -101,7 +103,7 @@ class TrainDialogueDataCollator:
# Add a way for the model to terminate generation
# When we predict the start of a new expected question, we want to be able to stop generation
messages.append(self.tokenizer.eos_token)
messages.append(QA_SPECIAL_TOKENS["Question"])
flatten_message = self.tokenizer(
"".join(messages),
@@ -1,5 +1,11 @@
# Marker strings placed around conversation turns and prefixes.
QA_SPECIAL_TOKENS = {"Question": "<human>", "Answer": "<bot>", "StartPrefix": "<prefix>", "EndPrefix": "</prefix>"}


def format_pair(pairs):
    """Wrap the questions of an alternating question/answer sequence in tokens.

    Args:
        pairs: sequence of texts alternating question, answer, question,
            answer, ...; its length must be even.

    Returns:
        A list of the same length as ``pairs``. Items at even positions are
        the question text wrapped as ``<human>{text}<bot>``; items at odd
        positions are the answer text unchanged. A two-element input therefore
        yields ``[formatted_question, answer]``.

    Raises:
        ValueError: if ``pairs`` has an odd number of items.
    """
    if len(pairs) % 2 != 0:
        raise ValueError(f"Expected an even number of question/answer texts, got {len(pairs)}")

    question_token = QA_SPECIAL_TOKENS["Question"]
    answer_token = QA_SPECIAL_TOKENS["Answer"]
    # Visit every position: stepping by two would silently drop the answers.
    return [
        "{}{}{}".format(question_token, text, answer_token) if position % 2 == 0 else text
        for position, text in enumerate(pairs)
    ]
@@ -1,11 +1,73 @@
import json
import math
import os
import random
from collections import OrderedDict
from functools import reduce
from urllib.request import urlopen
import numpy as np
from custom_datasets.formatting import QA_SPECIAL_TOKENS, format_pair
from torch.utils.data import Dataset
class OAPrivate(Dataset):
    """Conversation threads loaded from a JSONL file, restricted to one split.

    The non-blank lines of the file are shuffled with a fixed seed and cut
    into consecutive slices sized by ``splits``; an instance keeps only the
    slice belonging to the requested split.
    """

    # fractions per task; they must add up to 1
    splits = OrderedDict(sft=0.25, reward_model=0.4, rl=0.35)

    def __init__(self, data_path, split="sft", file="2023-02-10_oasst_prod.jsonl") -> None:
        """Load the records of ``split`` from ``data_path/file``.

        Args:
            data_path: directory containing the JSONL file.
            split: one of the keys of ``splits``.
            file: name of the JSONL file inside ``data_path``.

        Raises:
            ValueError: if ``split`` is not a key of ``splits``.
        """
        super().__init__()

        total_prob = sum(self.splits.values())
        assert math.isclose(total_prob, 1), "Make sure OAPrivate split ratios add to 1"

        if split not in self.splits:
            raise ValueError(f"Unknown split {split!r}, expected one of {list(self.splits)}")

        # Use the `file` argument directly; no `self.file` attribute exists.
        jsonl_file = os.path.join(data_path, file)
        with open(jsonl_file, "r", encoding="utf-8") as f:
            # Skip blank lines (e.g. a trailing newline), which are not valid JSON.
            lines = [line for line in f if line.strip()]

        # take a subset of the dataset based on the split
        rng = np.random.default_rng(seed=0)  # fixed seed: same partition on every run
        indices = np.arange(len(lines))
        rng.shuffle(indices)

        # Cumulative fractions give the slice boundaries of each split.
        cumsums = np.cumsum([0] + list(self.splits.values()))
        split_index = list(self.splits).index(split)
        start_index = int(cumsums[split_index] * len(lines))
        end_index = int(cumsums[split_index + 1] * len(lines))

        self.data = [json.loads(lines[index]) for index in indices[start_index:end_index]]

    def __len__(self):
        """Return the number of records in this split."""
        return len(self.data)

    def __getitem__(self, index):
        """Return one randomly sampled conversation path of record ``index``.

        Starting at the record's ``"prompt"`` node, a reply is chosen at
        random at every level until a node without replies is reached. The
        collected prompt/reply texts are passed through ``format_pair``.

        Raises:
            ValueError: if a node expected to be a prompt has another role.
        """
        prompt = self.data[index]["prompt"]
        pairs = []

        while True:
            if prompt["role"] != "prompter":
                raise ValueError(f"Expected a node with role 'prompter', got {prompt['role']!r}")

            # only add the prompt if a reply exists
            if not prompt["replies"]:
                break

            # Sample randomly from replies
            reply = random.choice(prompt["replies"])
            pairs.append(prompt["text"])
            pairs.append(reply["text"])

            if not reply["replies"]:
                break
            prompt = random.choice(reply["replies"])

        return format_pair(pairs)
class PromptGeneratedDataset(Dataset):
"""Generates from flan 11B
User: What are the best methods for preventing a slave trade?
+2 -3
View File
@@ -231,11 +231,10 @@ if __name__ == "__main__":
eval_steps=training_conf.eval_steps,
save_steps=training_conf.save_steps,
eval_accumulation_steps=training_conf.eval_accumulation_steps,
report_to="wandb",
report_to="wandb" if training_conf.log_wandb else None,
)
assert len(evals) > 0
if not training_conf.deepspeed or training_conf.local_rank == 0:
if training_conf.log_wandb and not training_conf.deepspeed or training_conf.local_rank == 0:
import wandb
wandb.init(
+30 -11
View File
@@ -61,16 +61,20 @@ class PerDatasetSampler(Sampler):
@classmethod
def build_sampler_from_config(cls, training_conf, datasets):
    """Create a sampler whose per-epoch sizes follow the dataset config.

    Args:
        training_conf: training configuration; ``datasets`` and ``verbose``
            are read.
        datasets: sized datasets, in the same order as
            ``training_conf.datasets``.

    Returns:
        An instance of ``cls`` built from the full dataset sizes and the
        number of examples to draw from each dataset per epoch.
    """
    dataset_sizes = [len(dataset) for dataset in datasets]
    # Fractions are computed once; the pre-change call without `verbose` is dropped.
    fractions = get_dataset_fractions(training_conf.datasets, dataset_sizes, verbose=training_conf.verbose)
    dataset_size_per_epoch = [int(size * frac) for size, frac in zip(dataset_sizes, fractions)]
    return cls(dataset_sizes, dataset_size_per_epoch)
def get_dataset_fractions(conf, dataset_sizes):
def get_dataset_fractions(conf, dataset_sizes, verbose=False):
"""Calculate fraction of each dataset to use per epoch when subsampling"""
if verbose:
print("Creating sampler for datasets:")
fractions = []
for i, data_config in enumerate(conf):
dataset_name = get_dataset_name_from_data_config(data_config)
dataset_name, _ = get_dataset_name_and_kwargs_from_data_config(data_config)
if isinstance(data_config, dict):
if "fraction" in data_config[dataset_name]:
if data_config[dataset_name]["fraction"] <= 0:
@@ -81,9 +85,12 @@ def get_dataset_fractions(conf, dataset_sizes):
raise ValueError(f"Please specify a size smaller than number of examples: {dataset_sizes[i]:,.0f}")
fractions.append(data_config[dataset_name]["size"] / dataset_sizes[i])
else:
raise ValueError("Please specify either fraction or size in config.yaml. See README for instructions.")
fractions.append(1)
else:
fractions.append(1)
if verbose:
print(f"Dataset: {dataset_name} fraction chosen: {fractions[-1]:.2f}")
return fractions
@@ -220,25 +227,37 @@ def get_model(conf, tokenizer):
return model
def get_dataset_name_and_kwargs_from_data_config(data_config):
    """Split one ``datasets`` config entry into a name and constructor options.

    Args:
        data_config: either a plain dataset name, or a single-key dict mapping
            the dataset name to its options.

    Returns:
        ``(name, kwargs)``. ``kwargs`` is a new dict holding the entry's
        options without the sampling keys ``fraction`` and ``size``; it is
        empty for plain names and for entries without options.
    """
    if not isinstance(data_config, dict):
        return data_config, {}

    name = list(data_config.keys())[0]
    # Work on a copy: popping from the config itself would hide 'fraction' and
    # 'size' from code that reads the same entry afterwards.
    kwargs = dict(data_config[name] or {})
    # remove 'fraction' or 'size' from kwargs
    kwargs.pop("fraction", None)
    kwargs.pop("size", None)
    return name, kwargs
def get_dataset(conf, tokenizer):
    """Assemble the training set, the evaluation sets and the collators.

    Args:
        conf: training configuration; ``datasets``, ``eval_size``,
            ``max_length`` and ``samples_mixing`` are read here.
        tokenizer: tokenizer handed to the collators.

    Returns:
        ``(train, evals, collate_fn, train_collate_fn)`` where ``train`` is
        the concatenation of all training datasets and ``evals`` maps dataset
        names to their validation sets. Datasets without a validation set are
        left out of ``evals``.
    """
    train_datasets, evals = [], {}

    for data_config in conf.datasets:
        dataset_name, kwargs = get_dataset_name_and_kwargs_from_data_config(data_config)
        train_part, val_part = get_one_dataset(conf, dataset_name, **kwargs)
        train_datasets.append(train_part)

        # A dataset may come without validation data (e.g. val_split of 0).
        if val_part is None:
            continue
        if conf.eval_size:
            # Cap the evaluation set at the first `eval_size` examples.
            kept = list(range(min(len(val_part), conf.eval_size)))
            evals[dataset_name] = Subset(val_part, kept)
        else:
            evals[dataset_name] = val_part

    train = ConcatDataset(train_datasets)

    collate_fn = DialogueDataCollator(tokenizer, max_length=conf.max_length)
    # The sample-mixing collator is used for training only when enabled.
    if conf.samples_mixing:
        train_collate_fn = TrainDialogueDataCollator(tokenizer, max_length=conf.max_length)
    else:
        train_collate_fn = collate_fn

    return train, evals, collate_fn, train_collate_fn