Merge pull request #1455 from LAION-AI/add-dataset

Add more datasets and some fixes
This commit is contained in:
sanagnos
2023-02-11 10:23:59 +01:00
committed by GitHub
11 changed files with 372 additions and 52 deletions
+42
View File
@@ -18,6 +18,7 @@
"""
import os
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
@@ -298,3 +299,44 @@ class AnthropicRLHF(Dataset):
context, pair = self.pairs[index]
return context, [pair]
class OAPrivate(Dataset):
"""
{
"prompt": <prompt string>,
"history": [("prompt1", "answer2"), ("prompt2", "answer2")],
"pos": <pos answer string>,
"neg_replies": [list of bad answers]
}
"""
split_name_mapping = {
"train": "rm_train.jsonl",
"test": "rm_test.jsonl",
"val": "rm_val.jsonl",
}
def __init__(self, split="train", sep_token="<sep>", data_path=".cache") -> None:
super().__init__()
import json
jsonl_file = os.path.join(data_path, self.split_name_mapping[split])
self.pairs = []
with open(jsonl_file, "r", encoding="utf-8") as f:
for line in f:
data = json.loads(line)
prefix = sep_token.join([sep_token.join(p) for p in data["history"][-2:]])
prefix += sep_token + data["prompt"]
pair = []
for neg_text in data["neg_replies"]:
pair.append((data["pos"], neg_text))
self.pairs.append((prefix, pair))
def __len__(self):
return len(self.pairs)
def __getitem__(self, index):
context, pair = self.pairs[index]
return context, pair
+7 -1
View File
@@ -115,7 +115,7 @@ def argument_parsing(parser):
def get_datasets(dataset_list: List[AnyStr], tokenizer):
from rank_datasets import AnthropicRLHF, GPTJSynthetic, HFSummary, WebGPT
from rank_datasets import AnthropicRLHF, GPTJSynthetic, HFSummary, OAPrivate, WebGPT
from torch.utils.data import ConcatDataset
train_datasets, evals = [], {}
@@ -141,5 +141,11 @@ def get_datasets(dataset_list: List[AnyStr], tokenizer):
eval = AnthropicRLHF("test", tokenizer.sep_token)
train_datasets.append(train)
evals["anthropic_rlhf"] = eval
elif "oa_private" == dataset_name:
train = OAPrivate(split="train", sep_token=tokenizer.sep_token)
eval = OAPrivate(split="val", sep_token=tokenizer.sep_token)
train_datasets.append(train)
evals["oa_private"] = eval
train = ConcatDataset(train_datasets)
return train, evals
@@ -1,8 +1,8 @@
"""
High level functions for model training
"""
from custom_datasets.prompt_dialogue import InstructionTuning, PromptGeneratedDataset
from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, SODADialogue, WebGPT
from custom_datasets.prompt_dialogue import InstructionTuning, PrivateInstructionTuning, PromptGeneratedDataset
from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, SODADialogue, TranslatedQA, WebGPT
from custom_datasets.summarization import SummarizationDataset
from custom_datasets.toxic_conversation import ProsocialDialogue, ProsocialDialogueExplaination
from custom_datasets.translation import WMT2019, DiveMT, TEDTalk
@@ -32,7 +32,7 @@ SUMMARIZATION_DATASETS = [
"debate_sum",
"tldr_news",
]
OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning"]
OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning", "private_tuning", "oa_translated"]
def train_val_dataset(dataset, val_split=0.2):
@@ -92,6 +92,12 @@ def get_one_dataset(conf, dataset_name):
elif dataset_name == "instruct_tuning":
dataset = InstructionTuning(conf.cache_dir)
train, eval = train_val_dataset(dataset, val_split=0.2)
elif dataset_name == "private_tuning":
dataset = PrivateInstructionTuning(conf.cache_dir)
train, eval = train_val_dataset(dataset, val_split=0.2)
elif dataset_name == "oa_translated":
dataset = TranslatedQA(conf.cache_dir)
train, eval = train_val_dataset(dataset, val_split=0.01)
else:
raise ValueError(f"Unknown dataset {dataset_name}")
@@ -1,3 +1,4 @@
import random
from dataclasses import dataclass
from typing import Optional, Union
@@ -76,3 +77,108 @@ class DialogueDataCollator:
batch["targets"] = torch.roll(batch["input_ids"], -1, -1)
return batch
@dataclass
class TrainDialogueDataCollator:
"""
Expects a list of texts corresponding to a sequence of [question, answer, question, answer, ...] pairs.
"""
tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
mix_length_threshold: Optional[int] = 256
mix_probability: Optional[int] = 0.6
pad_to_multiple_of: Optional[int] = None
def __call__(self, features):
flatten_messages = []
label_masks = []
total_short_context = 0
for messages in features:
messages = list(messages)
# Add a way for the model to terminate generation
# When we predict the start of a new expected question, we want to be able to stop generation
messages.append(self.tokenizer.eos_token)
flatten_message = self.tokenizer(
"".join(messages),
truncation=True,
max_length=self.max_length,
return_offsets_mapping=True,
)
message_change_indices = np.cumsum([len(x) for x in messages[:-1]])
# for each token an integer indicating the index of the message it belongs to. Just to create the label mask.
# Label mask is true when predicting a token that is part of the answer, false otherwise.
# TEXT: Question: Hello, how are you? Answer: I am fine. Question: What is your name? Answer: My name is John. Question:
# MESSAGE_INDICES: 0 0 0 0 0 0 1 1 1 2 2 2 2 2 2 3 3 3 3 -2
# LABEL_MASK: 0 0 0 0 0 1 1 1 1 0 0 0 0 0 1 1 1 1 1 0
# If no result in next, we are predicting the last termination token(s)
message_indices = list(
map(
lambda x: next((i for i, val in enumerate(message_change_indices) if val >= x), -2),
list(map(lambda x: x[1], flatten_message["offset_mapping"])),
)
)
label_mask = np.roll(list(map(lambda x: x % 2 == 1, message_indices)), -1, -1)
try:
label_mask[[i for i in range(len(message_indices)) if message_indices[i] == -2][0] - 1] = True
except IndexError:
# due to truncation, we might not have the last termination token
label_mask[-1] = False
label_masks.append(label_mask)
if len(flatten_message["input_ids"]) < self.mix_length_threshold:
total_short_context += len(flatten_message["input_ids"])
flatten_messages.append({k: v for k, v in flatten_message.items() if k != "offset_mapping"})
# packing
if total_short_context > 2:
_flatten_messages, _label_masks = [], []
prev_short_msg, prev_short_mask = None, None
for flatten_msg, label_mask in zip(flatten_messages, label_masks):
if len(flatten_msg["input_ids"]) < self.mix_length_threshold and random.random() > 0.6:
if prev_short_msg is not None:
for key in flatten_msg.keys():
flatten_msg[key] += prev_short_msg[key]
flatten_msg[key] = flatten_msg[key][: self.max_length]
label_mask = np.concatenate([label_mask, prev_short_mask])
_label_masks.append(label_mask[: self.max_length])
_flatten_messages.append(flatten_msg)
# reset
prev_short_msg, prev_short_mask = None, None
else:
# prime
prev_short_msg, prev_short_mask = flatten_msg, label_mask
else:
_label_masks.append(label_mask)
_flatten_messages.append(flatten_msg)
if prev_short_msg is not None:
for key in flatten_msg.keys():
flatten_msg[key] += prev_short_msg[key]
flatten_msg[key] = flatten_msg[key][: self.max_length]
label_mask = np.concatenate([label_mask, prev_short_mask])[: self.max_length]
_label_masks.append(label_mask)
_flatten_messages.append(flatten_msg)
label_masks = _label_masks
flatten_messages = _flatten_messages
batch = self.tokenizer.pad(
flatten_messages,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors="pt",
)
dim = batch["input_ids"].shape[-1]
batch["label_masks"] = torch.stack(
[F.pad(torch.tensor(x), (0, dim - len(x)), value=False) for x in label_masks]
)
batch["targets"] = torch.roll(batch["input_ids"], -1, -1)
return batch
@@ -2,7 +2,7 @@ import json
import os
from urllib.request import urlopen
from custom_datasets.formatting import format_pair
from custom_datasets.formatting import QA_SPECIAL_TOKENS, format_pair
from torch.utils.data import Dataset
@@ -102,3 +102,45 @@ class InstructionTuning(Dataset):
def __getitem__(self, index):
return format_pair(self.pairs[index])
class PrivateInstructionTuning(Dataset):
"""
We have seen some promising capabilities from instruction tuning
with the following mix of datasets that are derived from datasets
available online.
The files for this data are in json format as a list of tuples
where each tuple is (source,instruction_response_pair)
Not to be confused with unatural instruction
"""
name = "private_tuning"
filename = "oa_v3_fixed_plus_safety.jsonl"
def __init__(self, cache_dir) -> None:
super().__init__()
os.makedirs(cache_dir, exist_ok=True)
self.pairs = []
for file_link in [self.filename]:
basename = file_link.split("/")[-1]
instruction_tune_file = os.path.join(cache_dir, basename)
with open(instruction_tune_file, "r", encoding="utf-8") as f:
for line in f:
row = json.loads(line)
prefix = ""
for _, convo in enumerate(row["text"].split("User:")):
if "Assistant" in convo:
prompt, answer = convo.split("Assistant:", maxsplit=1)
answer = answer.replace("<|endoftext|>", "").strip()
self.pairs.append((prefix + QA_SPECIAL_TOKENS["Question"] + prompt, answer))
prefix += "".join(format_pair((prompt, answer)))
def __len__(self):
return len(self.pairs)
def __getitem__(self, index):
prompt, answer = self.pairs[index]
return "{}{}".format(prompt, QA_SPECIAL_TOKENS["Answer"]), answer
@@ -1,6 +1,7 @@
"""
Open / close book QA datasets
"""
import glob
import json
import os
import re
@@ -137,11 +138,13 @@ class QADataset(Dataset):
self.dataset = load_dataset(context["name"], **context["params"])
else:
raise ValueError("Unknown dataset : " + dataset)
self.length = len(self.dataset)
def __len__(self):
return len(self.dataset)
return self.length
def __getitem__(self, idx):
data = self.dataset[idx]
return format_pair(self.index_fn(data))
@@ -311,14 +314,75 @@ class JokeExplaination(Dataset):
for line in f:
data = json.loads(line)
joke = data["joke"]
explanation = data["explanation"]
# DO NOT change this
# its the data that had syntax error
explanation = data["explaination"]
self.pairs.append((joke, explanation))
if len(question) > 0 and len(answer) > 0:
self.pairs.append((question, answer))
self.length = len(self.pairs)
def __len__(self):
return len(self.pairs)
return self.length
def __getitem__(self, index):
return format_pair(self.pairs[index])
class TranslatedQA(Dataset):
"""
Translation OA v3 results
a list of non english translation of OA v3 instruction generated text in jsonl
format for each line:
{
"text": "User: ... Assistant: ....",
"meta": {"source": ... },
"translate": [
{ "round": 1, "human":"...", "answer": "..."},
...
{ "round": K, "human":"...", "answer": "..."},
]
}
Since OA contain some code we needed to reference the original text to skip these
"""
name = "oa_translated"
def __init__(self, cache_dir) -> None:
super().__init__()
os.makedirs(cache_dir, exist_ok=True)
path = os.path.join(cache_dir, self.name)
os.makedirs(path, exist_ok=True)
self.pairs = []
for translated_jsonl in glob.glob(os.path.join(path, "*.jsonl")):
with open(translated_jsonl, "r") as fin:
for line in fin:
data = json.loads(line)
if "Python " in data["text"]:
# translation currently doesn't ignore code
# so we will have to reference original text
# for ignoring the translation
continue
prefix = ""
for convo_round in data["translate"]:
human, answer = format_pair((convo_round["human"], convo_round["answer"]))
if convo_round["round"] > 2:
self.pairs.append(("{}{}{}".format(prefix, "<sep>", human), answer))
else:
self.pairs.append((human, answer))
prefix += "{}{}{}{}".format(
QA_SPECIAL_TOKENS["Question"],
convo_round["human"],
QA_SPECIAL_TOKENS["Answer"],
convo_round["answer"],
)
self.length = len(self.pairs)
def __len__(self):
return self.length
def __getitem__(self, index):
return self.pairs[index]
@@ -57,7 +57,7 @@ def index_summary_merge(text, summary):
class SummarizationDataset(Dataset):
def __init__(self, dataset, cache_dir, split, max_words=512):
self.name = dataset
if summarization_config_mapping[dataset][0] in ["billsum", "tldr_news"] & split == "validation":
if (dataset in ["billsum", "tldr_news"]) and (split == "validation"):
split = "test"
self.dataset = load_dataset(*summarization_config_mapping[dataset], cache_dir=cache_dir, split=split)
self.text_column, self.summary_column = summarization_name_mapping[dataset]
@@ -62,7 +62,8 @@ TRANSLATION_PROMPT = {
"{} how do we write in Malay",
"{} give me the malay translation",
"{} , berikan saya terjemahan dalam bahasa melayu",
"{}, Jemahan di bahasa melayu" "{}, jemahkan ayat ini kepada bahasa melayu",
"{}, Jemahan di bahasa melayu",
"{}, jemahkan ayat ini kepada bahasa melayu",
],
"en": ["{}. translate to english", "{} write in english", "english translation: '{}'"],
"ru": ["помогите мне перевести это на русский : {}", "{} перевести на русский язык", "russian translation: '{}'"],
@@ -71,24 +72,40 @@ TRANSLATION_PROMPT = {
"nl": ["{}. translate to dutch", "{} write in dutch", "dutch translation: '{}'"],
"vi": ["{}. Dịch sang tiếng việt nam", "{} write in vietnamese", "vietnamese translation: '{}'"],
"ar": ["{}. translate to arabic", "{} write in arabic", "arabic translation: '{}'"],
"es": ["{}. translate to spanish", "{} write in spanish", "spanish translation: '{}'"],
"hi": ["{}. translate to hindi", "{}. translate to bengali", "{} write in hindi", "bengali translation: '{}'"],
}
class TranslationPair(Dataset):
def __init__(self) -> None:
def __init__(self, mix_prob=0.2) -> None:
super().__init__()
self.pairs = []
self.length = -1
self.mix_prob = mix_prob
def __len__(self):
if self.length < 0:
self.length = len(self.pairs)
return len(self.pairs)
def __getitem__(self, index):
if random.random() < self.mix_prob and index > 5 and index < (self.length - 5):
additional = random.randint(0, 10) - 5
while additional == index:
additional = random.randint(0, 10) - 5
history_text = "".join(format_pair(self.pairs[additional + index]))
question, answer = self.pairs[index]
question = history_text + question
return format_pair((question, answer))
return format_pair(self.pairs[index])
class WMT2019(TranslationPair):
def __init__(self, pair="zh-en", split="train") -> None:
super().__init__()
def __init__(self, pair="zh-en", split="train", mix_prob=0.2) -> None:
super().__init__(mix_prob=mix_prob)
dataset = load_dataset("wmt19", pair)[split]
self.pairs = []
src, tgt = pair.split("-")
@@ -106,8 +123,8 @@ class DiveMT(TranslationPair):
REMAP = {"tur": "tr", "ita": "it", "ukr": "uk", "nld": "nl", "vie": "vi", "ara": "ar"}
def __init__(self, split="train") -> None:
super().__init__()
def __init__(self, split="train", mix_prob=0.2) -> None:
super().__init__(mix_prob=mix_prob)
dataset = load_dataset("GroNLP/divemt", "main")[split]
tgt, src = "tgt_text", "src_text"
for row in dataset:
@@ -129,8 +146,8 @@ class DiveMT(TranslationPair):
class TEDTalk(TranslationPair):
# NOTE: DO NOT use chinese pair, mix with traditional and cantonese, not clean
def __init__(self, pair="de-ja", split="train", year="2016") -> None:
super().__init__()
def __init__(self, pair="de-ja", split="train", year="2016", mix_prob=0.2) -> None:
super().__init__(mix_prob=mix_prob)
dataset = load_dataset("ted_talks_iwslt", language_pair=pair.split("-"), year=year)[split]
src, tgt = pair.split("-")
for row in dataset:
@@ -27,7 +27,7 @@ def test_collate_fn():
config = Namespace(cache_dir=".cache", model_name="Salesforce/codegen-2B-multi")
tokenizer = get_tokenizer(config)
collate_fn = DialogueDataCollator(tokenizer, max_length=512)
collate_fn = DialogueDataCollator(tokenizer, max_length=620)
qa_base = QA_DATASETS
summarize_base = SUMMARIZATION_DATASETS
others = ["prompt_dialogue", "webgpt", "soda", "joke", "gsm8k"]
@@ -40,11 +40,14 @@ def test_collate_fn():
dataloader = DataLoader(ConcatDataset(trains), collate_fn=collate_fn, batch_size=128)
for batch in dataloader:
print(batch.keys())
print(batch["targets"].shape[0])
print(tokenizer.decode(batch["input_ids"][0]))
print("-----")
print(tokenizer.decode(batch["targets"][0][batch["label_masks"][0]]))
assert batch["targets"].shape[1] <= 512
assert batch["targets"].shape[1] <= 620
dataloader = DataLoader(ConcatDataset(evals), collate_fn=collate_fn, batch_size=128)
for batch in dataloader:
assert batch["targets"].shape[1] <= 512
assert batch["targets"].shape[1] <= 620
test_collate_fn()
+54 -27
View File
@@ -1,14 +1,19 @@
import argparse
from distutils.util import strtobool
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import bitsandbytes
import datasets
import torch
from efficiency_utils import fuse_gelu
from torch import nn
from torch.utils.data import DataLoader
from transformers import PreTrainedModel, Trainer, TrainingArguments
from transformers.trainer_pt_utils import IterableDatasetShard
from transformers.trainer_utils import seed_worker
from transformers.training_args import OptimizerNames
from transformers.utils import is_datasets_available
from utils import PerDatasetSampler, get_dataset, get_loss, get_metrics, get_model, get_tokenizer, read_yamls
@@ -34,10 +39,11 @@ class SFTTrainer(Trainer):
sampler: torch.utils.data.sampler.Sampler = None,
loss_function: str = "CrossEntropyLoss",
poly_eps: float = 1.0,
train_collate_fn: Callable = None,
**kwargs,
):
super().__init__(model, args, **kwargs)
self.train_collate_fn = train_collate_fn
# By default CrossEntropyLoss ignores padding_index -100, but just in case use our own loss_fct
self.loss_fct = get_loss(loss_function, poly_eps)
self.sampler = sampler
@@ -92,30 +98,51 @@ class SFTTrainer(Trainer):
return (loss, logits, labels)
def get_train_dataloader(self):
"""Inject custom data sampling behaviour into training loop"""
if self.sampler is None:
torch.utils.data.DataLoader(
self.train_dataset,
batch_size=self.args.per_device_train_batch_size,
shuffle=True,
collate_fn=self.data_collator,
)
else:
dataloader = torch.utils.data.DataLoader(
self.train_dataset,
batch_size=self.args.per_device_train_batch_size,
sampler=self.sampler,
collate_fn=self.data_collator,
)
if torch.cuda.device_count() <= 1:
return dataloader
else:
# Not strictly necessary to use accelerate, currently just
# ensures batches are padded to be divisible by # devices
from accelerate import Accelerator
"""
Inject custom data sampling behaviour into training loop
and use custom task mixing collate function : train_collate_fn
accelerator = Accelerator()
return accelerator.prepare(dataloader)
rewrite from:
https://github.com/huggingface/transformers/blob/67d074874d285e616393c65a0e670088e1b6b74a/src/transformers/trainer.py#L846
"""
data_collator = self.train_collate_fn
train_dataset = self.train_dataset
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
train_dataset = self._remove_unused_columns(train_dataset, description="training")
if isinstance(train_dataset, torch.utils.data.IterableDataset):
# if we are using iterable dataset it means no weight sampling
# added for backward compat
if self.args.world_size > 1:
train_dataset = IterableDatasetShard(
train_dataset,
batch_size=self._train_batch_size,
drop_last=self.args.dataloader_drop_last,
num_processes=self.args.world_size,
process_index=self.args.process_index,
)
return DataLoader(
train_dataset,
batch_size=self.args.per_device_train_batch_size,
collate_fn=data_collator,
num_workers=self.args.dataloader_num_workers,
pin_memory=self.args.dataloader_pin_memory,
)
if self.sampler is None:
train_sampler = self._get_train_sampler()
else:
train_sampler = self.sampler
return DataLoader(
train_dataset,
batch_size=self._train_batch_size,
sampler=train_sampler,
collate_fn=data_collator,
drop_last=self.args.dataloader_drop_last,
num_workers=self.args.dataloader_num_workers,
pin_memory=self.args.dataloader_pin_memory,
worker_init_fn=seed_worker,
)
def _strtobool(x):
@@ -168,8 +195,7 @@ if __name__ == "__main__":
tokenizer = get_tokenizer(training_conf)
model = get_model(training_conf, tokenizer)
train, evals, collate_fn = get_dataset(training_conf, tokenizer)
train, evals, collate_fn, train_collate_fn = get_dataset(training_conf, tokenizer)
sampler = PerDatasetSampler.build_sampler_from_config(training_conf, train.datasets)
metrics, preprocess_fns = get_metrics(training_conf, tokenizer)
optimizer = OptimizerNames.ADAMW_BNB if training_conf.quantization else OptimizerNames.ADAMW_HF
@@ -222,6 +248,7 @@ if __name__ == "__main__":
model=model,
args=args,
sampler=sampler,
train_collate_fn=train_collate_fn,
loss_function=training_conf.loss_fn,
poly_eps=training_conf.poly_eps,
train_dataset=train,
+11 -4
View File
@@ -6,7 +6,7 @@ import evaluate
import transformers
import yaml
from custom_datasets import get_one_dataset
from custom_datasets.dialogue_collator import DialogueDataCollator
from custom_datasets.dialogue_collator import DialogueDataCollator, TrainDialogueDataCollator
from custom_datasets.qa_datasets import QA_SPECIAL_TOKENS
from losses import CrossEntropyLoss, PolyLoss
from models import freeze_top_n_layers, get_specific_model
@@ -126,7 +126,14 @@ def get_tokenizer(conf) -> transformers.AutoTokenizer:
if tokenizer_config.special_tokens:
if "GPT-JT" in conf.model_name:
tokenizer_config.special_tokens.pad_token = tokenizer.eos_token
tokenizer.add_special_tokens(tokenizer_config.special_tokens)
# SpecialTokens : latest in 4.25, 4.26
tokenizer.add_special_tokens(
{
"pad_token": tokenizer_config.special_tokens.pad_token,
"eos_token": tokenizer_config.special_tokens.eos_token,
"sep_token": tokenizer_config.special_tokens.sep_token,
}
)
additional_special_tokens = (
[]
@@ -231,8 +238,8 @@ def get_dataset(conf, tokenizer):
train = ConcatDataset(train_datasets)
collate_fn = DialogueDataCollator(tokenizer, max_length=conf.max_length)
return train, evals, collate_fn
train_collate_fn = TrainDialogueDataCollator(tokenizer, max_length=conf.max_length)
return train, evals, collate_fn, train_collate_fn
def get_loss(loss, poly_eps):