mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-01 16:50:12 +08:00
Merge pull request #885 from LAION-AI/sft-formatting
Fix prefix formatting
This commit is contained in:
@@ -43,7 +43,7 @@ def get_one_dataset(conf, dataset_name):
|
||||
if dataset_name == "debate_sum":
|
||||
train, eval = train_val_dataset(train, val_split=0.2)
|
||||
else:
|
||||
val_name = "validation" if dataset_name not in ["billsum"] else "test"
|
||||
val_name = "validation" if dataset_name not in ["billsum", "tldr_news"] else "test"
|
||||
eval = SummarizationDataset(dataset_name, conf.cache_dir, val_name)
|
||||
elif "ted_trans" in dataset_name:
|
||||
language_pair = dataset_name.split("_")[-1]
|
||||
|
||||
@@ -3,7 +3,6 @@ from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from custom_datasets.qa_datasets import QA_SPECIAL_TOKENS
|
||||
from torch.nn import functional as F
|
||||
from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
|
||||
|
||||
@@ -23,15 +22,8 @@ class DialogueDataCollator:
|
||||
flatten_messages = []
|
||||
label_masks = []
|
||||
|
||||
for feature_one in features:
|
||||
assert len(feature_one) % 2 == 0, "Number of messages must be even"
|
||||
# TODO: we should push this to dataset __getitem__
|
||||
messages = [
|
||||
(QA_SPECIAL_TOKENS["Question"] if i % 2 == 0 else "")
|
||||
+ x
|
||||
+ (QA_SPECIAL_TOKENS["Answer"] if i % 2 == 0 else "")
|
||||
for i, x in enumerate(feature_one)
|
||||
]
|
||||
for messages in features:
|
||||
messages = list(messages)
|
||||
|
||||
# Add a way for the model to terminate generation
|
||||
# When we predict the start of a new expected question, we want to be able to stop generation
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
# Marker strings inserted into prompts: "Question"/"Answer" delimit a dialogue
# turn, "StartPrefix"/"EndPrefix" delimit a context prefix placed before it.
QA_SPECIAL_TOKENS = {"Question": "<human>", "Answer": "<bot>", "StartPrefix": "<prefix>", "EndPrefix": "</prefix>"}


def format_pair(pair):
    """Wrap the question of a (question, answer) pair in the QA special tokens.

    Args:
        pair: an indexable whose element 0 is the question and element 1 is
            the answer (a tuple or a list both work).

    Returns:
        A 2-tuple ``(prompt, answer)`` where ``prompt`` is the question with
        the "Question" token before it and the "Answer" token after it, and
        ``answer`` is ``pair[1]`` returned untouched.
    """
    # Index rather than unpack so that only the first two elements are read.
    question, answer = pair[0], pair[1]
    prompt = "{}{}{}".format(QA_SPECIAL_TOKENS["Question"], question, QA_SPECIAL_TOKENS["Answer"])
    return prompt, answer
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
import os
|
||||
from urllib.request import urlopen
|
||||
|
||||
from custom_datasets.formatting import format_pair
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
@@ -49,8 +50,7 @@ class PromptGeneratedDataset(Dataset):
|
||||
return len(self.pairs)
|
||||
|
||||
def __getitem__(self, index):
|
||||
question, answer = self.pairs[index]
|
||||
return question, answer
|
||||
return format_pair(self.pairs[index])
|
||||
|
||||
|
||||
class InstructionTuning(Dataset):
|
||||
@@ -101,5 +101,4 @@ class InstructionTuning(Dataset):
|
||||
return len(self.pairs)
|
||||
|
||||
def __getitem__(self, index):
|
||||
question, answer = self.pairs[index]
|
||||
return question, answer
|
||||
return format_pair(self.pairs[index])
|
||||
|
||||
@@ -7,14 +7,13 @@ import re
|
||||
from urllib.request import urlopen
|
||||
|
||||
import numpy as np
|
||||
from custom_datasets.formatting import QA_SPECIAL_TOKENS, format_pair
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
# @agoryuno contributed this
|
||||
re_reference_remove = re.compile(r"\[\d+(?:,\s*\d+)*?\]")
|
||||
|
||||
QA_SPECIAL_TOKENS = {"Question": "<human>", "Answer": "<bot>", "StartPrefix": "<prefix>", "EndPrefix": "</prefix>"}
|
||||
|
||||
|
||||
def index_squad_v2(example):
|
||||
if len(example["answers"]["text"]):
|
||||
@@ -78,7 +77,7 @@ class QADataset(Dataset):
|
||||
|
||||
def __getitem__(self, idx):
|
||||
data = self.dataset[idx]
|
||||
return self.index_fn(data)
|
||||
return format_pair(self.index_fn(data))
|
||||
|
||||
|
||||
class WebGPT(Dataset):
|
||||
@@ -111,7 +110,7 @@ class WebGPT(Dataset):
|
||||
def __getitem__(self, index):
|
||||
question = self.index2question[index]
|
||||
answer = self.questions[question]
|
||||
return [question, answer]
|
||||
return format_pair((question, answer))
|
||||
|
||||
|
||||
class SODA(Dataset):
|
||||
@@ -121,14 +120,14 @@ class SODA(Dataset):
|
||||
def process_soda_convo(self, data):
|
||||
pairs = []
|
||||
play_as = data["speakers"][1]
|
||||
prefix = "{}{}. {}{}".format(
|
||||
QA_SPECIAL_TOKENS["StartPrefix"],
|
||||
data["narrative"],
|
||||
"your name {}".format(play_as),
|
||||
QA_SPECIAL_TOKENS["EndPrefix"],
|
||||
)
|
||||
question, answer = "", ""
|
||||
prefix, postfix = "", ""
|
||||
dialogue_bg = "{}{} {}{}".format(
|
||||
QA_SPECIAL_TOKENS["StartPrefix"],
|
||||
data["narrative"],
|
||||
"your are {}".format(play_as),
|
||||
QA_SPECIAL_TOKENS["EndPrefix"],
|
||||
)
|
||||
previous_chat = []
|
||||
|
||||
for idx, convo in enumerate(data["dialogue"]):
|
||||
@@ -138,14 +137,20 @@ class SODA(Dataset):
|
||||
else:
|
||||
answer = convo
|
||||
postfix = data["speakers"][idx]
|
||||
|
||||
if len(question) and len(answer) and prefix != postfix and postfix == play_as:
|
||||
history = "<sep>".join(
|
||||
["{}{}{}".format(p[0], QA_SPECIAL_TOKENS["Answer"], p[1]) for p in previous_chat]
|
||||
[
|
||||
"{}{}{}{}".format(QA_SPECIAL_TOKENS["Question"], p[0], QA_SPECIAL_TOKENS["Answer"], p[1])
|
||||
for p in previous_chat
|
||||
]
|
||||
)
|
||||
if len(history):
|
||||
history += "<sep>"
|
||||
pairs.append((prefix + history + question, answer))
|
||||
prompt = QA_SPECIAL_TOKENS["Question"] + question + QA_SPECIAL_TOKENS["Answer"]
|
||||
pairs.append((dialogue_bg + history + prompt, answer))
|
||||
previous_chat.append((question, answer))
|
||||
|
||||
return pairs
|
||||
|
||||
def __init__(self, cache_dir, max_sample_size=10000, input_max_length=1024) -> None:
|
||||
@@ -166,8 +171,8 @@ class SODA(Dataset):
|
||||
return len(self.pairs)
|
||||
|
||||
def __getitem__(self, index):
|
||||
question, answer = self.pairs[index]
|
||||
return question, answer
|
||||
# special token added during preprocess
|
||||
return self.pairs[index]
|
||||
|
||||
|
||||
class SODADialogue(Dataset):
|
||||
@@ -218,7 +223,7 @@ class SODADialogue(Dataset):
|
||||
return len(self.pairs)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.pairs[index]
|
||||
return format_pair(self.pairs[index])
|
||||
|
||||
|
||||
class JokeExplaination(Dataset):
|
||||
@@ -253,8 +258,7 @@ class JokeExplaination(Dataset):
|
||||
return len(self.pairs)
|
||||
|
||||
def __getitem__(self, index):
|
||||
question, answer = self.pairs[index]
|
||||
return question, answer
|
||||
return format_pair(self.pairs[index])
|
||||
|
||||
|
||||
# https://huggingface.co/datasets/aquamuse
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
"""
|
||||
import random
|
||||
|
||||
from custom_datasets.formatting import format_pair
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
@@ -54,11 +55,12 @@ def index_summary_merge(text, summary):
|
||||
|
||||
|
||||
class SummarizationDataset(Dataset):
|
||||
def __init__(self, dataset, cache_dir, split):
|
||||
def __init__(self, dataset, cache_dir, split, max_words=512):
|
||||
self.name = dataset
|
||||
self.dataset = load_dataset(*summarization_config_mapping[dataset], cache_dir=cache_dir, split=split)
|
||||
self.text_column, self.summary_column = summarization_name_mapping[dataset]
|
||||
self.preprocess_fn = index_summary_merge if dataset == "scitldr" else index_summary_default
|
||||
self.max_words = max_words
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
@@ -72,4 +74,5 @@ class SummarizationDataset(Dataset):
|
||||
else:
|
||||
prompt = random.choice(SUMMARIZATION_SPECIAL_TOKENS["Summary"])
|
||||
|
||||
return ("".join([SUMMARIZATION_SPECIAL_TOKENS["Text"], " ".join(text.split(" ")[:256]), prompt]), summary)
|
||||
context = "".join([SUMMARIZATION_SPECIAL_TOKENS["Text"], " ".join(text.split(" ")[: self.max_words]), prompt])
|
||||
return format_pair((context, summary))
|
||||
|
||||
@@ -4,12 +4,13 @@
|
||||
"""
|
||||
import random
|
||||
|
||||
from custom_datasets.formatting import QA_SPECIAL_TOKENS, format_pair
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class ProsocialDialogueExplaination(Dataset):
|
||||
name = "prosocial_explain"
|
||||
name = "explain_prosocial"
|
||||
TEMPLATE = [
|
||||
# 0 : reply or sentence of interest, 1 : reason of caution
|
||||
("'{0}' Why is this sentence {1}", "This sentence is {1} because {0}"),
|
||||
@@ -36,7 +37,7 @@ class ProsocialDialogueExplaination(Dataset):
|
||||
return len(self.pairs)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.pairs[idx]
|
||||
return format_pair(self.pairs[idx])
|
||||
|
||||
|
||||
class ProsocialDialogue(Dataset):
|
||||
@@ -58,8 +59,9 @@ class ProsocialDialogue(Dataset):
|
||||
dataset = load_dataset("allenai/prosocial-dialog", cache_dir=cache_dir)[split]
|
||||
self.pairs = []
|
||||
for row in dataset:
|
||||
prompt = QA_SPECIAL_TOKENS["Question"] + row["context"] + QA_SPECIAL_TOKENS["Answer"]
|
||||
for answer in row["rots"]:
|
||||
self.pairs.append((self.PREFIX + row["context"], answer))
|
||||
self.pairs.append((self.PREFIX + prompt, answer))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.pairs)
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
"""
|
||||
import random
|
||||
|
||||
from custom_datasets.formatting import format_pair
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
@@ -82,7 +83,7 @@ class TranslationPair(Dataset):
|
||||
return len(self.pairs)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.pairs[index]
|
||||
return format_pair(self.pairs[index])
|
||||
|
||||
|
||||
class WMT2019(TranslationPair):
|
||||
@@ -99,6 +100,8 @@ class WMT2019(TranslationPair):
|
||||
else: # translating in reverse direction
|
||||
source = random.choice(TRANSLATION_PROMPT[src]).format(row[tgt])
|
||||
self.pairs.append((source, row[src]))
|
||||
if len(self.pairs) > 100000:
|
||||
break
|
||||
|
||||
|
||||
class DiveMT(TranslationPair):
|
||||
|
||||
@@ -7,8 +7,8 @@ from custom_datasets.dialogue_collator import DialogueDataCollator
|
||||
def test_all_datasets():
|
||||
qa_base = QA_DATASETS
|
||||
summarize_base = SUMMARIZATION_DATASETS
|
||||
others = ["prompt_dialogue", "webgpt", "soda", "joke", "instruct_tuning"]
|
||||
translation = ["dive_mt", "wmt2019_zh-en", "wmt2019_ru-en", "wmt2019_de-en", "ted_trans_de-ja", "ted_trans_nl-en"]
|
||||
others = ["prompt_dialogue", "webgpt", "soda", "joke", "instruct_tuning", "explain_prosocial", "prosocial_dialogue"]
|
||||
translation = ["dive_mt", "wmt2019_zh-en", "wmt2019_ru-en", "ted_trans_de-ja", "ted_trans_nl-en"]
|
||||
|
||||
config = Namespace(cache_dir=".cache")
|
||||
for dataset_name in translation + others + summarize_base + qa_base:
|
||||
@@ -31,7 +31,6 @@ def test_collate_fn():
|
||||
qa_base = QA_DATASETS
|
||||
summarize_base = SUMMARIZATION_DATASETS
|
||||
others = ["prompt_dialogue", "webgpt", "soda", "joke", "gsm8k"]
|
||||
|
||||
trains, evals = [], []
|
||||
for dataset_name in others + qa_base + summarize_base:
|
||||
print(dataset_name)
|
||||
@@ -41,10 +40,10 @@ def test_collate_fn():
|
||||
|
||||
dataloader = DataLoader(ConcatDataset(trains), collate_fn=collate_fn, batch_size=128)
|
||||
for batch in dataloader:
|
||||
# print(batch.keys())
|
||||
# print(tokenizer.decode(batch['input_ids'][0]))
|
||||
# print('-----')
|
||||
# print(tokenizer.decode(batch['targets'][0][batch['label_masks'][0]]))
|
||||
print(batch.keys())
|
||||
print(tokenizer.decode(batch["input_ids"][0]))
|
||||
print("-----")
|
||||
print(tokenizer.decode(batch["targets"][0][batch["label_masks"][0]]))
|
||||
assert batch["targets"].shape[1] <= 512
|
||||
dataloader = DataLoader(ConcatDataset(evals), collate_fn=collate_fn, batch_size=128)
|
||||
for batch in dataloader:
|
||||
|
||||
@@ -25,6 +25,10 @@ def get_tokenizer(conf):
|
||||
tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token, "sep_token": "<|extratoken_100|>"})
|
||||
elif "codegen" in conf.model_name:
|
||||
tokenizer.add_special_tokens({"pad_token": "<|endoftext|>", "sep_token": "<|endoftext|>"})
|
||||
elif "pythia" in conf.model_name:
|
||||
tokenizer.add_special_tokens(
|
||||
{"pad_token": "<|padding|>", "sep_token": "<|endoftext|>", "eos_token": "<|endoftext|>"}
|
||||
)
|
||||
|
||||
additional_special_tokens = (
|
||||
[]
|
||||
|
||||
Reference in New Issue
Block a user