[feature] added translation, rallio instruct tuning dataset, prosocial for safety, new summary dataset

This commit is contained in:
theblackcat102
2023-01-20 03:02:07 +00:00
parent da6a3b687e
commit 74cb9aaa5a
11 changed files with 399 additions and 18 deletions
@@ -2,7 +2,7 @@ model_name: microsoft/deberta-v3-base
learning_rate: 1e-5
scheduler: cosine
gradient_checkpointing: false
gradient_accumulation_steps: 32
gradient_accumulation_steps: 16
per_device_train_batch_size: 2
warmup_steps: 600
eval_steps: 200
@@ -29,6 +29,14 @@ defaults:
- soda
- joke
- gsm8k
- dive_mt
- wmt2019_zh-en
- wmt2019_ru-en
- wmt2019_de-en
- ted_trans_nl-en
- ted_trans_de-ja
- instruct_tuning
- wmt2019_de-en
- samsum
cache_dir: .cache
loss_fn: CrossEntropyLoss
@@ -0,0 +1,26 @@
# Dataset collections overview:
Currently, the datasets can be divided into 3 classes:
- language knowledge
- summarization
- translation
- dialogue : don't let user know you are a robot
- STEM : knowledge about the world
- coding
- world knowledge <= ideally we want to handle this via prefix context
Issues and TODO:
* As the number of datasets grows, how can we reduce how often this section needs to be updated?
* Ideally, we only update the config YAML and any new dataset is downloaded from the hub.
* One possible idea is to upload the transformed format of these datasets to the OA hub.
@@ -1,11 +1,26 @@
from custom_datasets.prompt_dialogue import PromptGeneratedDataset
"""
High level functions for model training
"""
from custom_datasets.prompt_dialogue import InstructionTuning, PromptGeneratedDataset
from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, WebGPT
from custom_datasets.summarization import SummarizationDataset
from custom_datasets.toxic_conversation import ProsocialDialogue, ProsocialDialogueExplaination
from custom_datasets.translation import WMT2019, DiveMT, TEDTalk
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
QA_DATASETS = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_nocontext", "gsm8k"]
SUMMARIZATION_DATASETS = ["xsum", "cnn_dailymail", "samsum", "multi_news", "scitldr", "billsum"]
SUMMARIZATION_DATASETS = [
"xsum",
"cnn_dailymail",
"samsum",
"multi_news",
"scitldr",
"billsum",
"debate_sum",
"tldr_news",
]
OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning"]
def train_val_dataset(dataset, val_split=0.2):
@@ -25,20 +40,43 @@ def get_one_dataset(conf, dataset_name):
elif dataset_name in SUMMARIZATION_DATASETS:
train = SummarizationDataset(dataset_name, conf.cache_dir, "train")
val_name = "validation" if dataset_name not in ["billsum"] else "test"
eval = SummarizationDataset(dataset_name, conf.cache_dir, val_name)
if dataset_name == "debate_sum":
train, eval = train_val_dataset(train, val_split=0.2)
else:
val_name = "validation" if dataset_name not in ["billsum"] else "test"
eval = SummarizationDataset(dataset_name, conf.cache_dir, val_name)
elif "ted_trans" in dataset_name:
language_pair = dataset_name.split("_")[-1]
dataset = TEDTalk(pair=language_pair, split="train")
train, eval = train_val_dataset(dataset, val_split=0.2)
elif "wmt2019" in dataset_name:
language_pair = dataset_name.split("_")[-1]
train = WMT2019(pair=language_pair, split="train")
eval = WMT2019(pair=language_pair, split="validation")
elif dataset_name == "dive_mt":
dataset = DiveMT()
train, eval = train_val_dataset(dataset, val_split=0.2)
elif dataset_name == "webgpt":
dataset = WebGPT()
train, eval = train_val_dataset(dataset, val_split=0.2)
elif dataset_name == "prompt_dialogue":
dataset = PromptGeneratedDataset(conf.cache_dir)
train, eval = train_val_dataset(dataset, val_split=0.2)
elif dataset_name == "prosocial_dialogue":
train = ProsocialDialogue(cache_dir=conf.cache_dir, split="train")
eval = ProsocialDialogue(cache_dir=conf.cache_dir, split="validation")
elif dataset_name == "explain_prosocial":
train = ProsocialDialogueExplaination(cache_dir=conf.cache_dir, split="train")
eval = ProsocialDialogueExplaination(cache_dir=conf.cache_dir, split="validation")
elif dataset_name == "soda":
dataset = SODA(conf.cache_dir)
train, eval = train_val_dataset(dataset, val_split=0.1)
elif dataset_name == "joke":
dataset = JokeExplaination(conf.cache_dir)
train, eval = train_val_dataset(dataset, val_split=0.2)
elif dataset_name == "instruct_tuning":
dataset = InstructionTuning(conf.cache_dir)
train, eval = train_val_dataset(dataset, val_split=0.2)
else:
raise ValueError(f"Unknown dataset {dataset_name}")
@@ -25,6 +25,7 @@ class DialogueDataCollator:
for feature_one in features:
assert len(feature_one) % 2 == 0, "Number of messages must be even"
# TODO: we should push this to dataset __getitem__
messages = [
(QA_SPECIAL_TOKENS["Question"] if i % 2 == 0 else "")
+ x
@@ -1,3 +1,4 @@
import json
import os
from urllib.request import urlopen
@@ -14,6 +15,7 @@ class PromptGeneratedDataset(Dataset):
we are ignoring results with multiple lines for now
"""
name = "prompt_dialogue"
url = "https://github.com/Rallio67/language-model-agents/raw/main/chat_dialogue_v2_c.txt"
def __init__(self, cache_dir) -> None:
@@ -49,3 +51,55 @@ class PromptGeneratedDataset(Dataset):
def __getitem__(self, index):
question, answer = self.pairs[index]
return question, answer
class InstructionTuning(Dataset):
    """
    Instruction-tuning question/answer pairs.

    The data is a mix of instruction-style samples derived from datasets
    available online. It is shipped as two JSON files, each a list of
    (source, instruction_response_pair) entries:
        - instruction_tuning_dataset_alpha_part1.json
        - instruction_tuning_dataset_alpha_part2.json

    Not to be confused with the "unnatural instructions" dataset.
    """

    name = "instruction_dataset"
    url_part_1 = (
        "https://github.com/Rallio67/language-model-agents/raw/main/instruction_tuning_dataset_alpha_part1.json"
    )
    url_part_2 = (
        "https://github.com/Rallio67/language-model-agents/raw/main/instruction_tuning_dataset_alpha_part2.json"
    )

    def __init__(self, cache_dir) -> None:
        """Load both parts, downloading each into *cache_dir* if it is not there yet."""
        super().__init__()
        os.makedirs(cache_dir, exist_ok=True)
        self.pairs = []

        for url in (self.url_part_1, self.url_part_2):
            # The cached copy is named after the last URL segment.
            local_path = os.path.join(cache_dir, url.split("/")[-1])

            if not os.path.exists(local_path):
                with urlopen(url) as response:
                    payload = response.read().decode()
                with open(local_path, "w", encoding="utf-8") as cache_file:
                    cache_file.write(payload)

            with open(local_path, "r", encoding="utf-8") as cached:
                records = json.load(cached)

            for _source, text in records:
                # The first blank line separates the instruction from its response.
                prompt, reply = text.split("\n\n", maxsplit=1)
                reply = reply.replace("<|endoftext|>", "").strip()
                self.pairs.append((prompt, reply))

    def __len__(self):
        """Return the number of (question, answer) pairs."""
        return len(self.pairs)

    def __getitem__(self, index):
        """Return the (question, answer) pair stored at *index*."""
        prompt, reply = self.pairs[index]
        return prompt, reply
@@ -1,11 +1,18 @@
"""
Open / close book QA datasets
"""
import json
import os
import re
from urllib.request import urlopen
import numpy as np
from datasets import load_dataset
from torch.utils.data import Dataset
# @agoryuno contributed this
re_reference_remove = re.compile(r"\[\d+(?:,\s*\d+)*?\]")
QA_SPECIAL_TOKENS = {"Question": "<human>", "Answer": "<bot>", "StartPrefix": "<prefix>", "EndPrefix": "</prefix>"}
@@ -75,6 +82,9 @@ class QADataset(Dataset):
class WebGPT(Dataset):
name = "webgpt"
def __init__(self) -> None:
super().__init__()
@@ -89,7 +99,9 @@ class WebGPT(Dataset):
self.index2question[len(self.index2question)] = question
# only keep the best answer
questions[question] = row["answer_0" if row["score_0"] > row["score_1"] else "answer_1"]
questions[question] = re_reference_remove.sub(
"", row["answer_0" if row["score_0"] > row["score_1"] else "answer_1"]
)
self.questions = questions
@@ -103,6 +115,9 @@ class WebGPT(Dataset):
class SODA(Dataset):
name = "soda"
def process_soda_convo(self, data):
pairs = []
play_as = data["speakers"][1]
@@ -149,8 +164,8 @@ class SODA(Dataset):
class JokeExplaination(Dataset):
""" """
name = "joke"
url = "https://gist.github.com/theblackcat102/42b697e24a13fdb499e20edfbf618361/raw/1834dca207898c15f93b809d1195f6f6e47c9e1e/joke_explained.jsonl"
def __init__(self, cache_dir) -> None:
@@ -182,3 +197,6 @@ class JokeExplaination(Dataset):
def __getitem__(self, index):
question, answer = self.pairs[index]
return question, answer
# https://huggingface.co/datasets/aquamuse
@@ -1,3 +1,6 @@
"""
Summarize different spectrum of documents
"""
import random
from datasets import load_dataset
@@ -12,13 +15,21 @@ SUMMARY_SPECIAL_PROMPT = {
}
summarization_config_mapping = {
"cnn_dailymail": ("3.0.0",),
"samsum": (),
"xsum": (),
"multi_news": (),
"scitldr": ("AIC",),
"billsum": (),
"reddit": (),
"cnn_dailymail": (
"cnn_dailymail",
"3.0.0",
),
"samsum": ("samsum",),
"xsum": ("xsum",),
"multi_news": ("multi_news",),
"scitldr": (
"scitldr",
"AIC",
),
"billsum": ("billsum",),
"reddit": ("reddit",),
"tldr_news": ("JulesBelveze/tldr_news",), # need to fix : JulesBelveze/tldr_news
"debate_sum": ("Hellisotherpeople/DebateSum",), # Hellisotherpeople/DebateSum
}
summarization_name_mapping = {
@@ -29,6 +40,8 @@ summarization_name_mapping = {
"scitldr": ("source", "target"),
"billsum": ("text", "summary"),
"reddit": ("content", "summary"),
"tldr_news": ("content", "headline"),
"debate_sum": ("Full-Document", "Extract"),
}
@@ -43,7 +56,7 @@ def index_summary_merge(text, summary):
class SummarizationDataset(Dataset):
def __init__(self, dataset, cache_dir, split):
self.name = dataset
self.dataset = load_dataset(dataset, *summarization_config_mapping[dataset], cache_dir=cache_dir, split=split)
self.dataset = load_dataset(*summarization_config_mapping[dataset], cache_dir=cache_dir, split=split)
self.text_column, self.summary_column = summarization_name_mapping[dataset]
self.preprocess_fn = index_summary_merge if dataset == "scitldr" else index_summary_default
@@ -0,0 +1,65 @@
'''
SFT dataset to reject toxic questions
'''
import random
from datasets import load_dataset
from torch.utils.data import Dataset
class ProsocialDialogueExplaination(Dataset):
    """
    Question/answer pairs asking why an utterance received a given safety label.

    Built from the "allenai/prosocial-dialog" dataset: every safety annotation
    of a row is paired with its annotation reason and rendered through one
    randomly chosen (prompt, answer) template.
    """

    name = "prosocial_explain"
    # Placeholders: {0} is the sentence of interest in the prompt and the
    # reason in the answer; {1} is the safety annotation in both.
    TEMPLATE = [
        ("'{0}' Why is this sentence {1}", "This sentence is {1} because {0}"),
        ("Explain to me why this sentence is {1}: {0}", "This sentence is {1} because {0}"),
        ("'{0}' Why is this sentence {1}", "This sentence may need {1} because {0}"),
        ("'{0}' Explain to me why we need to {1} of this sentence", "We need to be {1} of this sentence because {0}"),
    ]

    def __init__(self, split="train", cache_dir=".cache") -> None:
        """Load *split* of the dataset and expand every annotation into one pair."""
        super().__init__()
        rows = load_dataset("allenai/prosocial-dialog", cache_dir=cache_dir)[split]
        self.pairs = []
        for row in rows:
            annotated = zip(row["safety_annotations"], row["safety_annotation_reasons"])
            for label, reason in annotated:
                # One template is drawn per annotation, so prompts vary within a row.
                question_tpl, reply_tpl = random.choice(self.TEMPLATE)
                question = question_tpl.format(row["context"], label)
                reply = reply_tpl.format(reason, label)
                self.pairs.append((question, reply))

    def __len__(self):
        """Return the number of generated pairs."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return the (question, answer) pair at *idx*."""
        return self.pairs[idx]
class ProsocialDialogue(Dataset):
    """
    Prefix-conditioned pairs built from the "allenai/prosocial-dialog" dataset.

    From the ProsocialDialog description: a human-AI collaborative data
    creation framework where GPT-3 generates the potentially unsafe
    utterances and crowdworkers provide prosocial responses to them. This
    approach circumvents two substantial challenges:
    (1) there are no available large-scale corpora of multiturn prosocial
        conversations between humans
    (2) asking humans to write unethical, toxic, or problematic utterances
        could result in psychological harms (Roberts, 2017; Steiger et al., 2021).

    Each row's context is prepended with PREFIX and paired with every entry
    of the row's "rots" field.
    """

    name = "prosocial_dialogue"
    PREFIX = "<prefix>You are now a prosocial chatbot, be caution and casual when reply</prefix>"

    def __init__(self, split="train", cache_dir=".cache") -> None:
        """Load *split* of the dataset and build one pair per "rots" entry."""
        super().__init__()
        rows = load_dataset("allenai/prosocial-dialog", cache_dir=cache_dir)[split]
        self.pairs = []
        for row in rows:
            for reply in row["rots"]:
                question = self.PREFIX + row["context"]
                self.pairs.append((question, reply))

    def __len__(self):
        """Return the number of generated pairs."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return the (question, answer) pair at *idx*."""
        return self.pairs[idx]
@@ -0,0 +1,157 @@
'''
List of translation dataset
GroNLP/divemt
fill in the blanks : https://huggingface.co/datasets/m_lama
'''
import random
from datasets import load_dataset
from torch.utils.data import Dataset
# postfix prompt
# Prompt templates keyed by the language code of the language to translate INTO.
# Every template must contain exactly one "{}" placeholder, which receives the
# text to translate: callers render them with `template.format(text)`.
TRANSLATION_PROMPT = {
    "zh": [  # simplified or any chinese which was not mentioned
        "Translate to chinese simplified: {}",
        "{}, translate to chinese",
        "{} give me the chinese translation",
        "翻译成中文: {}",
        "{} 这句中文翻译怎麽写?",
        "我需要这句话的中文翻译: {}",
    ],
    "zh-tw": [  # WMT code
        "{}. Translate to chinese traditional",
        "{}, translate to chinese",
        "{}. get chinese translation",
        "中文翻譯: {}",
        "幫我翻譯成中文: '{}'",
        "{} 這句中文翻譯怎麼寫?",
    ],
    "ja": [
        "{}: help me translate to japanese",
        "Need japanese translation: {}",
        "{}: にほんごやくをよこす",
        "{}: にほんごやくをおくれ",
        "{}: にほんごやくを じょす",
        "give me the japanese translation, {}",
    ],
    "de": [
        "{}: translate to german",
        "give me the german translation {}",
        "I want german translation {}",
        "{}, ins Deutsche übersetzen",
        "{}, Übersetzen ins Deutsche",
    ],
    "fr": [
        "{}. translate to french",
        "{} write in french",
        "{} french translation",
        "{} ,donnez moi la traduction française",
    ],
    "ko": [
        "{}. translate to Korean",
        "how do we write in korean: {}",
        "give me the korean translation: {}",
        "{}, 한국어 번역을 해주세요",
    ],
    "ms": [
        "{} translate to malay",
        "{} how do we write in Malay",
        "{} give me the malay translation",
        "{} , berikan saya terjemahan dalam bahasa melayu",
        # A comma was missing after the next entry: the two strings were silently
        # concatenated into one template with two "{}" placeholders, which makes
        # `.format(text)` raise IndexError.
        "{}, Jemahan di bahasa melayu",
        "{}, jemahkan ayat ini kepada bahasa melayu",
    ],
    "en": ["{}. translate to english", "{} write in english", "english translation: '{}'"],
    "tr": ["{}. translate to turkish", "{} write in turkish", "turkish translation: '{}'"],
    "it": ["{}. translate to italian", "{} write in italian", "italian translation: '{}'"],
    "nl": ["{}. translate to dutch", "{} write in dutch", "dutch translation: '{}'"],
    "vi": ["{}. translate to vietnamese", "{} write in vietnamese", "vietnamese translation: '{}'"],
    "ar": ["{}. translate to arabic", "{} write in arabic", "arabic translation: '{}'"],
    # Needed for the "ru-en" WMT pair: without it, prompting for a translation
    # into Russian fails with KeyError.
    "ru": ["{}. translate to russian", "{} write in russian", "russian translation: '{}'"],
}
class TranslationPair(Dataset):
    """Base dataset that serves the (prompt, translation) tuples stored in ``self.pairs``."""

    def __init__(self) -> None:
        """Start with an empty pair list; subclasses append to it."""
        super().__init__()
        self.pairs = []

    def __len__(self):
        """Return how many pairs are stored."""
        return len(self.pairs)

    def __getitem__(self, index):
        """Return the pair stored at *index*."""
        return self.pairs[index]
class WMT2019(TranslationPair):
    """
    Prompted translation pairs built from the 'wmt19' dataset.

    Every row contributes one (prompt, translation) pair. When both languages
    of the pair have templates in TRANSLATION_PROMPT, the translation
    direction is picked at random per row. When only one of them has
    templates, only the direction translating INTO that language is used
    instead of failing with KeyError on the other one.
    """

    def __init__(self, pair='zh-en', split='train') -> None:
        """
        pair: language pair as '<src>-<tgt>'; also used as the dataset config name.
        split: dataset split to load.

        Raises KeyError when neither language of the pair has a prompt template.
        """
        super().__init__()  # already creates the empty self.pairs list
        dataset = load_dataset('wmt19', pair)[split]
        src, tgt = pair.split('-')

        # (language of the text, language of the answer), forward direction
        # first, kept only when the answer language can be prompted for.
        directions = [
            (text_lang, answer_lang)
            for text_lang, answer_lang in ((src, tgt), (tgt, src))
            if answer_lang in TRANSLATION_PROMPT
        ]
        if not directions:
            raise KeyError(f"no translation prompt for either language of pair '{pair}'")

        for row in dataset:
            translation = row['translation']
            # Same coin flip as before when both directions are available:
            # > 0.5 translates src -> tgt, otherwise tgt -> src.
            if len(directions) == 2 and random.random() <= 0.5:
                text_lang, answer_lang = directions[1]
            else:
                text_lang, answer_lang = directions[0]
            template = random.choice(TRANSLATION_PROMPT[answer_lang])
            source = template.format(translation[text_lang])
            self.pairs.append((source, translation[answer_lang]))
class DiveMT(TranslationPair):
    """
    Prompted translation pairs built from the 'GroNLP/divemt' dataset ('main' config).

    The language of a row is taken from the prefix of its 'subject_id' and
    mapped through REMAP; rows whose mapped language has no entry in
    TRANSLATION_PROMPT are skipped. For the remaining rows the direction is
    random: either 'src_text' is prompted for translation into the row's
    language, or 'tgt_text' is prompted for translation into English.
    """

    # Three-letter language code (subject_id prefix) -> key style used by TRANSLATION_PROMPT
    REMAP = {
        'tur': 'tr',
        'ita': 'it',
        'ukr': 'uk',
        'nld': 'nl',
        'vie': 'vi',
        'ara': 'ar',
    }

    def __init__(self, split='train') -> None:
        """Load *split* and build one (prompt, translation) pair per usable row."""
        super().__init__()
        rows = load_dataset('GroNLP/divemt', 'main')[split]
        src_column, tgt_column = 'src_text', 'tgt_text'
        for row in rows:
            iso3 = row['subject_id'].split('_')[0]
            row_lang = self.REMAP[iso3]
            if row_lang not in TRANSLATION_PROMPT:
                continue

            if random.random() > 0.5:
                prompt_lang, text_column, answer_column = row_lang, src_column, tgt_column
            else:
                # translating in reverse direction, back into English
                prompt_lang, text_column, answer_column = 'en', tgt_column, src_column

            template = random.choice(TRANSLATION_PROMPT[prompt_lang])
            source = template.format(row[text_column])
            self.pairs.append((source, row[answer_column]))
class TEDTalk(TranslationPair):
    """
    Prompted translation pairs built from the 'ted_talks_iwslt' dataset.

    NOTE: DO NOT use a chinese pair; it is mixed with traditional and
    cantonese text and is not clean.
    """

    def __init__(self, pair='de-ja', split='train', year='2016') -> None:
        """
        pair: language pair as '<src>-<tgt>'.
        split: dataset split to load.
        year: forwarded to load_dataset as the 'year' option.
        """
        super().__init__()
        rows = load_dataset('ted_talks_iwslt', language_pair=pair.split('-'), year=year)[split]
        src, tgt = pair.split('-')
        for row in rows:
            translation = row['translation']
            # Coin flip per row: > 0.5 translates src -> tgt, otherwise the reverse.
            if random.random() > 0.5:
                text_lang, answer_lang = src, tgt
            else:
                text_lang, answer_lang = tgt, src
            template = random.choice(TRANSLATION_PROMPT[answer_lang])
            source = template.format(translation[text_lang])
            self.pairs.append((source, translation[answer_lang]))
@@ -7,10 +7,11 @@ from custom_datasets.dialogue_collator import DialogueDataCollator
def test_all_datasets():
qa_base = QA_DATASETS
summarize_base = SUMMARIZATION_DATASETS
others = ["prompt_dialogue", "webgpt", "soda", "joke"]
others = ["prompt_dialogue", "webgpt", "soda", "joke", "instruct_tuning"]
translation = ["dive_mt", "wmt2019_zh-en", "wmt2019_ru-en", "wmt2019_de-en", "ted_trans_de-ja", "ted_trans_nl-en"]
config = Namespace(cache_dir=".cache")
for dataset_name in others + qa_base + summarize_base:
for dataset_name in translation:
print(dataset_name)
train, eval = get_one_dataset(config, dataset_name)
# sanity check
@@ -51,4 +52,4 @@ def test_collate_fn():
if __name__ == "__main__":
test_collate_fn()
test_all_datasets()