mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Merge branch 'main' of github.com:LAION-AI/Open-Assistant
This commit is contained in:
@@ -66,6 +66,17 @@ different 500 examples from `prompt_dialogue`.
|
||||
|
||||
This works with `torch.distributed`.
|
||||
|
||||
## Training only on OA internal data:
|
||||
|
||||
To experiment with the Open Assistant data simply run:
|
||||
|
||||
```bash
|
||||
python trainer.py --configs defaults oa_dataset_only galactica-125m
|
||||
```
|
||||
|
||||
Change the `data_path` in the `oa_dataset_only` section of the `configs/config.yaml`
|
||||
file to the correct path.
|
||||
|
||||
## Model
|
||||
|
||||
Normally you should be able to add new models in `configs/config.yaml`
|
||||
|
||||
@@ -17,13 +17,12 @@ defaults:
|
||||
freeze_layer:
|
||||
datasets:
|
||||
- webgpt
|
||||
# - prompt_dialogue
|
||||
- squad_v2
|
||||
- adversarial_qa
|
||||
- trivia_qa_nocontext
|
||||
- xsum
|
||||
- cnn_dailymail
|
||||
- prompt_dialogue
|
||||
- prompt_dialogue # TODO: need to fix the url
|
||||
- multi_news
|
||||
- scitldr
|
||||
- soda
|
||||
@@ -47,6 +46,18 @@ defaults:
|
||||
seq2seqmodel: false
|
||||
poly_eps: 1.0
|
||||
fuse_gelu: true
|
||||
log_wandb: true
|
||||
samples_mixing: false # uses a collator that mixes samples in the batch to create a single sample, possibly with multiple tasks within
|
||||
verbose: false
|
||||
|
||||
oa_dataset_only:
|
||||
datasets:
|
||||
- oa_private:
|
||||
data_path: .cache
|
||||
split: sft
|
||||
val_split: 0.0
|
||||
fraction: 1
|
||||
file: 2023-02-10_oasst_prod.jsonl
|
||||
|
||||
galactica-125m:
|
||||
learning_rate: 5e-5
|
||||
@@ -81,9 +92,12 @@ codegen:
|
||||
per_device_eval_batch_size: 4
|
||||
|
||||
debug:
|
||||
model_name: EleutherAI/pythia-70m-deduped
|
||||
eval_steps: 20
|
||||
eval_size: 20
|
||||
gradient_accumulation_steps: 1
|
||||
per_device_train_batch_size: 1
|
||||
per_device_eval_batch_size: 1
|
||||
quantization: false
|
||||
log_wandb: false
|
||||
verbose: true
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
"""
|
||||
High level functions for model training
|
||||
"""
|
||||
from custom_datasets.prompt_dialogue import InstructionTuning, PrivateInstructionTuning, PromptGeneratedDataset
|
||||
from custom_datasets.prompt_dialogue import (
|
||||
InstructionTuning,
|
||||
OAPrivate,
|
||||
PrivateInstructionTuning,
|
||||
PromptGeneratedDataset,
|
||||
)
|
||||
from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, SODADialogue, TranslatedQA, WebGPT
|
||||
from custom_datasets.summarization import SummarizationDataset
|
||||
from custom_datasets.toxic_conversation import ProsocialDialogue, ProsocialDialogueExplaination
|
||||
@@ -32,73 +37,69 @@ SUMMARIZATION_DATASETS = [
|
||||
"debate_sum",
|
||||
"tldr_news",
|
||||
]
|
||||
OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning", "private_tuning", "oa_translated"]
|
||||
OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning", "private_tuning", "oa_translated", "oa_private"]
|
||||
|
||||
|
||||
def train_val_dataset(dataset, val_split=0.2):
    """Divide ``dataset`` into a training part and a validation part.

    Args:
        dataset: any object supporting ``len()``; it is wrapped in ``Subset``
            views when a split is made.
        val_split: value passed to ``train_test_split`` as ``test_size``.
            ``0`` disables the split.

    Returns:
        ``(train, val)``. With ``val_split == 0`` this is ``(dataset, None)``;
        otherwise two ``Subset`` views over shuffled indices (fixed
        ``random_state=666``, so the partition is reproducible).
    """
    if val_split != 0:
        every_index = list(range(len(dataset)))
        train_idx, val_idx = train_test_split(
            every_index,
            test_size=val_split,
            random_state=666,
            shuffle=True,
        )
        return Subset(dataset, train_idx), Subset(dataset, val_idx)

    # Nothing is held out: hand the dataset back untouched.
    return dataset, None
|
||||
|
||||
|
||||
def get_one_dataset(conf, dataset_name, val_split=0.2, data_path=None, **kwargs):
    """Build the train and validation datasets for one named dataset.

    Datasets that ship a dedicated validation split use it directly. Every
    other dataset is divided with ``train_val_dataset`` using ``val_split``.

    Args:
        conf: training configuration; only ``conf.cache_dir`` is read here, as
            the fallback for ``data_path``.
        dataset_name: dataset identifier, matched case-insensitively.
        val_split: hold-out fraction for datasets without their own
            validation split.
        data_path: directory handed to the dataset constructors; falls back to
            ``conf.cache_dir`` when falsy.
        **kwargs: extra keyword arguments, forwarded to ``OAPrivate`` only.

    Returns:
        ``(train, val)``. ``val`` is whatever ``train_val_dataset`` returns for
        datasets that are split here (``None`` when ``val_split`` is 0).

    Raises:
        ValueError: if ``dataset_name`` is not a known dataset.
    """
    data_path = data_path or conf.cache_dir
    dataset_name = dataset_name.lower()

    # ``unsplit`` holds a dataset that still has to be divided into train/val.
    # Initialising all three names avoids the UnboundLocalError the previous
    # ``"dataset" in locals()`` check produced for QA datasets without a
    # validation split and for ``debate_sum``.
    train = val = unsplit = None

    if dataset_name in QA_DATASETS:
        qa_train = QADataset(dataset_name, data_path, "train")
        if qa_train.no_val:
            unsplit = qa_train
        else:
            train = qa_train
            val = QADataset(dataset_name, data_path, "validation")
    elif dataset_name in SUMMARIZATION_DATASETS:
        summary_train = SummarizationDataset(dataset_name, data_path, "train")
        if dataset_name == "debate_sum":
            # debate_sum is the one summarization dataset split locally.
            unsplit = summary_train
        else:
            train = summary_train
            val = SummarizationDataset(dataset_name, data_path, "validation")
    elif "ted_trans" in dataset_name:
        language_pair = dataset_name.split("_")[-1]
        unsplit = TEDTalk(pair=language_pair, split="train")
    elif "wmt2019" in dataset_name:
        language_pair = dataset_name.split("_")[-1]
        train = WMT2019(pair=language_pair, split="train")
        val = WMT2019(pair=language_pair, split="validation")
    elif dataset_name == "dive_mt":
        unsplit = DiveMT()
    elif dataset_name == "webgpt":
        unsplit = WebGPT()
    elif dataset_name == "prompt_dialogue":
        unsplit = PromptGeneratedDataset(data_path)
    elif dataset_name == "prosocial_dialogue":
        train = ProsocialDialogue(cache_dir=data_path, split="train")
        val = ProsocialDialogue(cache_dir=data_path, split="validation")
    elif dataset_name == "explain_prosocial":
        train = ProsocialDialogueExplaination(cache_dir=data_path, split="train")
        val = ProsocialDialogueExplaination(cache_dir=data_path, split="validation")
    elif dataset_name == "soda":
        unsplit = SODA(data_path)
    elif dataset_name == "soda_dialogue":
        unsplit = SODADialogue(data_path)
    elif dataset_name == "joke":
        unsplit = JokeExplaination(data_path)
    elif dataset_name == "instruct_tuning":
        unsplit = InstructionTuning(data_path)
    elif dataset_name == "private_tuning":
        unsplit = PrivateInstructionTuning(data_path)
    elif dataset_name == "oa_translated":
        unsplit = TranslatedQA(data_path)  # TODO make val_split lower..?
    elif dataset_name == "oa_private":
        unsplit = OAPrivate(data_path, **kwargs)
    else:
        raise ValueError(f"Unknown dataset {dataset_name}")

    if unsplit is not None:
        train, val = train_val_dataset(unsplit, val_split=val_split)

    return train, val
|
||||
|
||||
@@ -7,6 +7,8 @@ import torch
|
||||
from torch.nn import functional as F
|
||||
from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
|
||||
|
||||
from .formatting import QA_SPECIAL_TOKENS
|
||||
|
||||
|
||||
@dataclass
|
||||
class DialogueDataCollator:
|
||||
@@ -28,7 +30,7 @@ class DialogueDataCollator:
|
||||
|
||||
# Add a way for the model to terminate generation
|
||||
# When we predict the start of a new expected question, we want to be able to stop generation
|
||||
messages.append(self.tokenizer.eos_token)
|
||||
messages.append(QA_SPECIAL_TOKENS["Question"])
|
||||
|
||||
flatten_message = self.tokenizer(
|
||||
"".join(messages),
|
||||
@@ -101,7 +103,7 @@ class TrainDialogueDataCollator:
|
||||
|
||||
# Add a way for the model to terminate generation
|
||||
# When we predict the start of a new expected question, we want to be able to stop generation
|
||||
messages.append(self.tokenizer.eos_token)
|
||||
messages.append(QA_SPECIAL_TOKENS["Question"])
|
||||
|
||||
flatten_message = self.tokenizer(
|
||||
"".join(messages),
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
QA_SPECIAL_TOKENS = {"Question": "<human>", "Answer": "<bot>", "StartPrefix": "<prefix>", "EndPrefix": "</prefix>"}


def format_pair(pairs):
    """Format a flat conversation into alternating prompt/reply strings.

    Args:
        pairs: flat sequence ``[question, answer, question, answer, ...]``;
            its length must be even.

    Returns:
        A list of the same length. Elements at even positions (questions) are
        wrapped as ``<Question token>{text}<Answer token>``; elements at odd
        positions (answers) are returned unchanged.

    Raises:
        AssertionError: if ``pairs`` has an odd number of elements.
    """
    assert len(pairs) % 2 == 0, "format_pair expects an even number of messages"
    question_token = QA_SPECIAL_TOKENS["Question"]
    answer_token = QA_SPECIAL_TOKENS["Answer"]
    # Visit every message. Stepping by 2 only reached the questions, so the
    # answers were silently dropped from the output.
    return [
        "{}{}{}".format(question_token, text, answer_token) if position % 2 == 0 else text
        for position, text in enumerate(pairs)
    ]
|
||||
|
||||
@@ -1,11 +1,73 @@
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
from functools import reduce
|
||||
from urllib.request import urlopen
|
||||
|
||||
import numpy as np
|
||||
from custom_datasets.formatting import QA_SPECIAL_TOKENS, format_pair
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class OAPrivate(Dataset):
    """Conversation threads read from a JSONL export, partitioned by task.

    Each line of the file is parsed as a JSON object whose ``"prompt"`` entry
    is a message tree: every node carries ``"role"``, ``"text"`` and
    ``"replies"`` (a list of child nodes). The lines are shuffled with a fixed
    seed and cut into consecutive slices according to ``splits``; one instance
    exposes the slice for a single split name.
    """

    splits = OrderedDict(sft=0.25, reward_model=0.4, rl=0.35)  # fractions per task

    def __init__(self, data_path, split="sft", file="2023-02-10_oasst_prod.jsonl") -> None:
        """Load the slice of ``file`` that belongs to ``split``.

        Args:
            data_path: directory containing the JSONL file.
            split: one of the keys of ``splits``.
            file: name of the JSONL file inside ``data_path``.

        Raises:
            ValueError: if ``split`` is not a key of ``splits``.
        """
        super().__init__()

        total_prob = sum(self.splits.values())
        assert math.isclose(total_prob, 1), "Make sure OAPrivate split ratios add to 1"

        split_names = list(self.splits.keys())
        if split not in split_names:
            raise ValueError(f"Unknown OAPrivate split {split!r}; expected one of {split_names}")

        # Use the ``file`` argument: ``self.file`` was never assigned, so the
        # previous lookup raised AttributeError on every construction.
        jsonl_file = os.path.join(data_path, file)

        with open(jsonl_file, "r", encoding="utf-8") as f:
            lines = f.readlines()

        # take a subset of the dataset based on the split; the fixed seed keeps
        # the slices of the different splits disjoint across instances
        rng = np.random.default_rng(seed=0)
        indices = np.arange(len(lines)).astype(int)
        rng.shuffle(indices)

        # cumsums[i] .. cumsums[i + 1] is the fractional range of split i
        cumsums = np.cumsum([0, *self.splits.values()])
        split_index = split_names.index(split)

        start_index = int(cumsums[split_index] * len(lines))
        end_index = int(cumsums[split_index + 1] * len(lines))

        self.data = [json.loads(lines[index].strip()) for index in indices[start_index:end_index]]

    def __len__(self):
        """Return the number of conversation trees in this split."""
        return len(self.data)

    def __getitem__(self, index):
        """Sample one random root-to-leaf path through the tree at ``index``.

        Starting at the root prompt, a random reply is chosen, then a random
        follow-up prompt under that reply, and so on until a node without
        replies is reached. The collected texts are passed to ``format_pair``.
        """
        prompt = self.data[index]["prompt"]

        pairs = []

        while True:
            assert prompt["role"] == "prompter"

            if not prompt["replies"]:
                break

            reply = random.choice(prompt["replies"])

            # only add the prompt once a reply to it exists
            pairs.append(prompt["text"])
            pairs.append(reply["text"])

            if not reply["replies"]:
                break

            prompt = random.choice(reply["replies"])

        return format_pair(pairs)
|
||||
|
||||
|
||||
class PromptGeneratedDataset(Dataset):
|
||||
"""Generates from flan 11B
|
||||
User: What are the best methods for preventing a slave trade?
|
||||
|
||||
@@ -231,11 +231,10 @@ if __name__ == "__main__":
|
||||
eval_steps=training_conf.eval_steps,
|
||||
save_steps=training_conf.save_steps,
|
||||
eval_accumulation_steps=training_conf.eval_accumulation_steps,
|
||||
report_to="wandb",
|
||||
report_to="wandb" if training_conf.log_wandb else None,
|
||||
)
|
||||
|
||||
assert len(evals) > 0
|
||||
if not training_conf.deepspeed or training_conf.local_rank == 0:
|
||||
if training_conf.log_wandb and not training_conf.deepspeed or training_conf.local_rank == 0:
|
||||
import wandb
|
||||
|
||||
wandb.init(
|
||||
|
||||
@@ -61,16 +61,20 @@ class PerDatasetSampler(Sampler):
|
||||
@classmethod
def build_sampler_from_config(cls, training_conf, datasets):
    """Alternate constructor deriving per-epoch sample counts from the config.

    Reads ``training_conf.datasets`` and ``training_conf.verbose``, asks
    ``get_dataset_fractions`` for one fraction per dataset, and scales each
    dataset's length by its fraction (truncated to an int).
    """
    sizes = list(map(len, datasets))
    per_dataset_fraction = get_dataset_fractions(training_conf.datasets, sizes, verbose=training_conf.verbose)
    samples_per_epoch = []
    for size, frac in zip(sizes, per_dataset_fraction):
        samples_per_epoch.append(int(size * frac))
    return cls(sizes, samples_per_epoch)
|
||||
|
||||
|
||||
def get_dataset_fractions(conf, dataset_sizes):
|
||||
def get_dataset_fractions(conf, dataset_sizes, verbose=False):
|
||||
"""Calculate fraction of each dataset to use per epoch when subsampling"""
|
||||
|
||||
if verbose:
|
||||
print("Creating sampler for datasets:")
|
||||
|
||||
fractions = []
|
||||
for i, data_config in enumerate(conf):
|
||||
dataset_name = get_dataset_name_from_data_config(data_config)
|
||||
dataset_name, _ = get_dataset_name_and_kwargs_from_data_config(data_config)
|
||||
if isinstance(data_config, dict):
|
||||
if "fraction" in data_config[dataset_name]:
|
||||
if data_config[dataset_name]["fraction"] <= 0:
|
||||
@@ -81,9 +85,12 @@ def get_dataset_fractions(conf, dataset_sizes):
|
||||
raise ValueError(f"Please specify a size smaller than number of examples: {dataset_sizes[i]:,.0f}")
|
||||
fractions.append(data_config[dataset_name]["size"] / dataset_sizes[i])
|
||||
else:
|
||||
raise ValueError("Please specify either fraction or size in config.yaml. See README for instructions.")
|
||||
fractions.append(1)
|
||||
else:
|
||||
fractions.append(1)
|
||||
|
||||
if verbose:
|
||||
print(f"Dataset: {dataset_name} fraction chosen: {fractions[-1]:.2f}")
|
||||
return fractions
|
||||
|
||||
|
||||
@@ -220,25 +227,37 @@ def get_model(conf, tokenizer):
|
||||
return model
|
||||
|
||||
|
||||
def get_dataset_name_and_kwargs_from_data_config(data_config):
    """Split one ``datasets`` config entry into a name and constructor kwargs.

    Args:
        data_config: either a plain dataset name, or a single-key dict mapping
            the dataset name to its options.

    Returns:
        ``(name, kwargs)``. For a dict entry, ``kwargs`` is a copy of the
        options without the sampling-only keys ``fraction`` and ``size``.
        For a plain name, ``kwargs`` is an empty dict.
    """
    if not isinstance(data_config, dict):
        return data_config, {}

    name = next(iter(data_config))
    # Work on a copy: popping from the config itself would erase
    # ``fraction``/``size`` for any code that reads the same entry afterwards.
    # ``or {}`` tolerates an entry whose options are left empty (None).
    kwargs = dict(data_config[name] or {})
    # remove 'fraction' or 'size' from kwargs
    kwargs.pop("fraction", None)
    kwargs.pop("size", None)
    return name, kwargs
|
||||
|
||||
|
||||
def get_dataset(conf, tokenizer):
    """Assemble the training set, the evaluation sets and both collators.

    Every entry of ``conf.datasets`` is resolved through ``get_one_dataset``.
    Training parts are concatenated; validation parts are stored per dataset
    name, capped to ``conf.eval_size`` items when that setting is truthy.
    Datasets that yield no validation part are left out of the mapping.

    Returns:
        ``(train, evals, collate_fn, train_collate_fn)``. The training
        collator is a ``TrainDialogueDataCollator`` only when
        ``conf.samples_mixing`` is set; otherwise it is the evaluation
        collator itself.
    """
    train_parts = []
    evals = {}

    for data_config in conf.datasets:
        dataset_name, kwargs = get_dataset_name_and_kwargs_from_data_config(data_config)
        train_part, val_part = get_one_dataset(conf, dataset_name, **kwargs)
        train_parts.append(train_part)

        if val_part is None:
            continue
        if conf.eval_size:
            keep = min(len(val_part), conf.eval_size)
            evals[dataset_name] = Subset(val_part, list(range(keep)))
        else:
            evals[dataset_name] = val_part

    train = ConcatDataset(train_parts)

    collate_fn = DialogueDataCollator(tokenizer, max_length=conf.max_length)
    if conf.samples_mixing:
        train_collate_fn = TrainDialogueDataCollator(tokenizer, max_length=conf.max_length)
    else:
        train_collate_fn = collate_fn

    return train, evals, collate_fn, train_collate_fn
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user