Merge pull request #885 from LAION-AI/sft-formatting

Fix prefix formatting
This commit is contained in:
sanagnos
2023-01-23 09:51:45 +01:00
committed by GitHub
10 changed files with 56 additions and 45 deletions
@@ -43,7 +43,7 @@ def get_one_dataset(conf, dataset_name):
if dataset_name == "debate_sum":
train, eval = train_val_dataset(train, val_split=0.2)
else:
val_name = "validation" if dataset_name not in ["billsum"] else "test"
val_name = "validation" if dataset_name not in ["billsum", "tldr_news"] else "test"
eval = SummarizationDataset(dataset_name, conf.cache_dir, val_name)
elif "ted_trans" in dataset_name:
language_pair = dataset_name.split("_")[-1]
@@ -3,7 +3,6 @@ from typing import Optional, Union
import numpy as np
import torch
from custom_datasets.qa_datasets import QA_SPECIAL_TOKENS
from torch.nn import functional as F
from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
@@ -23,15 +22,8 @@ class DialogueDataCollator:
flatten_messages = []
label_masks = []
for feature_one in features:
assert len(feature_one) % 2 == 0, "Number of messages must be even"
# TODO: we should push this to dataset __getitem__
messages = [
(QA_SPECIAL_TOKENS["Question"] if i % 2 == 0 else "")
+ x
+ (QA_SPECIAL_TOKENS["Answer"] if i % 2 == 0 else "")
for i, x in enumerate(feature_one)
]
for messages in features:
messages = list(messages)
# Add a way for the model to terminate generation
# When we predict the start of a new expected question, we want to be able to stop generation
@@ -0,0 +1,5 @@
QA_SPECIAL_TOKENS = {"Question": "<human>", "Answer": "<bot>", "StartPrefix": "<prefix>", "EndPrefix": "</prefix>"}
def format_pair(pair):
return "{}{}{}".format(QA_SPECIAL_TOKENS["Question"], pair[0], QA_SPECIAL_TOKENS["Answer"]), pair[1]
@@ -2,6 +2,7 @@ import json
import os
from urllib.request import urlopen
from custom_datasets.formatting import format_pair
from torch.utils.data import Dataset
@@ -49,8 +50,7 @@ class PromptGeneratedDataset(Dataset):
return len(self.pairs)
def __getitem__(self, index):
question, answer = self.pairs[index]
return question, answer
return format_pair(self.pairs[index])
class InstructionTuning(Dataset):
@@ -101,5 +101,4 @@ class InstructionTuning(Dataset):
return len(self.pairs)
def __getitem__(self, index):
question, answer = self.pairs[index]
return question, answer
return format_pair(self.pairs[index])
@@ -7,14 +7,13 @@ import re
from urllib.request import urlopen
import numpy as np
from custom_datasets.formatting import QA_SPECIAL_TOKENS, format_pair
from datasets import load_dataset
from torch.utils.data import Dataset
# @agoryuno contributed this
re_reference_remove = re.compile(r"\[\d+(?:,\s*\d+)*?\]")
QA_SPECIAL_TOKENS = {"Question": "<human>", "Answer": "<bot>", "StartPrefix": "<prefix>", "EndPrefix": "</prefix>"}
def index_squad_v2(example):
if len(example["answers"]["text"]):
@@ -78,7 +77,7 @@ class QADataset(Dataset):
def __getitem__(self, idx):
data = self.dataset[idx]
return self.index_fn(data)
return format_pair(self.index_fn(data))
class WebGPT(Dataset):
@@ -111,7 +110,7 @@ class WebGPT(Dataset):
def __getitem__(self, index):
question = self.index2question[index]
answer = self.questions[question]
return [question, answer]
return format_pair((question, answer))
class SODA(Dataset):
@@ -121,14 +120,14 @@ class SODA(Dataset):
def process_soda_convo(self, data):
pairs = []
play_as = data["speakers"][1]
prefix = "{}{}. {}{}".format(
QA_SPECIAL_TOKENS["StartPrefix"],
data["narrative"],
"your name {}".format(play_as),
QA_SPECIAL_TOKENS["EndPrefix"],
)
question, answer = "", ""
prefix, postfix = "", ""
dialogue_bg = "{}{} {}{}".format(
QA_SPECIAL_TOKENS["StartPrefix"],
data["narrative"],
"your are {}".format(play_as),
QA_SPECIAL_TOKENS["EndPrefix"],
)
previous_chat = []
for idx, convo in enumerate(data["dialogue"]):
@@ -138,14 +137,20 @@ class SODA(Dataset):
else:
answer = convo
postfix = data["speakers"][idx]
if len(question) and len(answer) and prefix != postfix and postfix == play_as:
history = "<sep>".join(
["{}{}{}".format(p[0], QA_SPECIAL_TOKENS["Answer"], p[1]) for p in previous_chat]
[
"{}{}{}{}".format(QA_SPECIAL_TOKENS["Question"], p[0], QA_SPECIAL_TOKENS["Answer"], p[1])
for p in previous_chat
]
)
if len(history):
history += "<sep>"
pairs.append((prefix + history + question, answer))
prompt = QA_SPECIAL_TOKENS["Question"] + question + QA_SPECIAL_TOKENS["Answer"]
pairs.append((dialogue_bg + history + prompt, answer))
previous_chat.append((question, answer))
return pairs
def __init__(self, cache_dir, max_sample_size=10000, input_max_length=1024) -> None:
@@ -166,8 +171,8 @@ class SODA(Dataset):
return len(self.pairs)
def __getitem__(self, index):
question, answer = self.pairs[index]
return question, answer
# special token added during preprocess
return self.pairs[index]
class SODADialogue(Dataset):
@@ -218,7 +223,7 @@ class SODADialogue(Dataset):
return len(self.pairs)
def __getitem__(self, index):
return self.pairs[index]
return format_pair(self.pairs[index])
class JokeExplaination(Dataset):
@@ -253,8 +258,7 @@ class JokeExplaination(Dataset):
return len(self.pairs)
def __getitem__(self, index):
question, answer = self.pairs[index]
return question, answer
return format_pair(self.pairs[index])
# https://huggingface.co/datasets/aquamuse
@@ -3,6 +3,7 @@
"""
import random
from custom_datasets.formatting import format_pair
from datasets import load_dataset
from torch.utils.data import Dataset
@@ -54,11 +55,12 @@ def index_summary_merge(text, summary):
class SummarizationDataset(Dataset):
def __init__(self, dataset, cache_dir, split):
def __init__(self, dataset, cache_dir, split, max_words=512):
self.name = dataset
self.dataset = load_dataset(*summarization_config_mapping[dataset], cache_dir=cache_dir, split=split)
self.text_column, self.summary_column = summarization_name_mapping[dataset]
self.preprocess_fn = index_summary_merge if dataset == "scitldr" else index_summary_default
self.max_words = max_words
def __len__(self):
return len(self.dataset)
@@ -72,4 +74,5 @@ class SummarizationDataset(Dataset):
else:
prompt = random.choice(SUMMARIZATION_SPECIAL_TOKENS["Summary"])
return ("".join([SUMMARIZATION_SPECIAL_TOKENS["Text"], " ".join(text.split(" ")[:256]), prompt]), summary)
context = "".join([SUMMARIZATION_SPECIAL_TOKENS["Text"], " ".join(text.split(" ")[: self.max_words]), prompt])
return format_pair((context, summary))
@@ -4,12 +4,13 @@
"""
import random
from custom_datasets.formatting import QA_SPECIAL_TOKENS, format_pair
from datasets import load_dataset
from torch.utils.data import Dataset
class ProsocialDialogueExplaination(Dataset):
name = "prosocial_explain"
name = "explain_prosocial"
TEMPLATE = [
# 0 : reply or sentence of interest, 1 : reason of caution
("'{0}' Why is this sentence {1}", "This sentence is {1} because {0}"),
@@ -36,7 +37,7 @@ class ProsocialDialogueExplaination(Dataset):
return len(self.pairs)
def __getitem__(self, idx):
return self.pairs[idx]
return format_pair(self.pairs[idx])
class ProsocialDialogue(Dataset):
@@ -58,8 +59,9 @@ class ProsocialDialogue(Dataset):
dataset = load_dataset("allenai/prosocial-dialog", cache_dir=cache_dir)[split]
self.pairs = []
for row in dataset:
prompt = QA_SPECIAL_TOKENS["Question"] + row["context"] + QA_SPECIAL_TOKENS["Answer"]
for answer in row["rots"]:
self.pairs.append((self.PREFIX + row["context"], answer))
self.pairs.append((self.PREFIX + prompt, answer))
def __len__(self):
return len(self.pairs)
@@ -8,6 +8,7 @@
"""
import random
from custom_datasets.formatting import format_pair
from datasets import load_dataset
from torch.utils.data import Dataset
@@ -82,7 +83,7 @@ class TranslationPair(Dataset):
return len(self.pairs)
def __getitem__(self, index):
return self.pairs[index]
return format_pair(self.pairs[index])
class WMT2019(TranslationPair):
@@ -99,6 +100,8 @@ class WMT2019(TranslationPair):
else: # translating in reverse direction
source = random.choice(TRANSLATION_PROMPT[src]).format(row[tgt])
self.pairs.append((source, row[src]))
if len(self.pairs) > 100000:
break
class DiveMT(TranslationPair):
@@ -7,8 +7,8 @@ from custom_datasets.dialogue_collator import DialogueDataCollator
def test_all_datasets():
qa_base = QA_DATASETS
summarize_base = SUMMARIZATION_DATASETS
others = ["prompt_dialogue", "webgpt", "soda", "joke", "instruct_tuning"]
translation = ["dive_mt", "wmt2019_zh-en", "wmt2019_ru-en", "wmt2019_de-en", "ted_trans_de-ja", "ted_trans_nl-en"]
others = ["prompt_dialogue", "webgpt", "soda", "joke", "instruct_tuning", "explain_prosocial", "prosocial_dialogue"]
translation = ["dive_mt", "wmt2019_zh-en", "wmt2019_ru-en", "ted_trans_de-ja", "ted_trans_nl-en"]
config = Namespace(cache_dir=".cache")
for dataset_name in translation + others + summarize_base + qa_base:
@@ -31,7 +31,6 @@ def test_collate_fn():
qa_base = QA_DATASETS
summarize_base = SUMMARIZATION_DATASETS
others = ["prompt_dialogue", "webgpt", "soda", "joke", "gsm8k"]
trains, evals = [], []
for dataset_name in others + qa_base + summarize_base:
print(dataset_name)
@@ -41,10 +40,10 @@ def test_collate_fn():
dataloader = DataLoader(ConcatDataset(trains), collate_fn=collate_fn, batch_size=128)
for batch in dataloader:
# print(batch.keys())
# print(tokenizer.decode(batch['input_ids'][0]))
# print('-----')
# print(tokenizer.decode(batch['targets'][0][batch['label_masks'][0]]))
print(batch.keys())
print(tokenizer.decode(batch["input_ids"][0]))
print("-----")
print(tokenizer.decode(batch["targets"][0][batch["label_masks"][0]]))
assert batch["targets"].shape[1] <= 512
dataloader = DataLoader(ConcatDataset(evals), collate_fn=collate_fn, batch_size=128)
for batch in dataloader:
+4
View File
@@ -25,6 +25,10 @@ def get_tokenizer(conf):
tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token, "sep_token": "<|extratoken_100|>"})
elif "codegen" in conf.model_name:
tokenizer.add_special_tokens({"pad_token": "<|endoftext|>", "sep_token": "<|endoftext|>"})
elif "pythia" in conf.model_name:
tokenizer.add_special_tokens(
{"pad_token": "<|padding|>", "sep_token": "<|endoftext|>", "eos_token": "<|endoftext|>"}
)
additional_special_tokens = (
[]