diff --git a/model/reward/instructor/configs/deberta-v3-base.yml b/model/reward/instructor/configs/deberta-v3-base.yml index 7023709c..134cfdaa 100644 --- a/model/reward/instructor/configs/deberta-v3-base.yml +++ b/model/reward/instructor/configs/deberta-v3-base.yml @@ -2,7 +2,7 @@ model_name: microsoft/deberta-v3-base learning_rate: 1e-5 scheduler: cosine gradient_checkpointing: false -gradient_accumulation_steps: 32 +gradient_accumulation_steps: 16 per_device_train_batch_size: 2 warmup_steps: 600 eval_steps: 200 diff --git a/model/supervised_finetuning/configs/config.yaml b/model/supervised_finetuning/configs/config.yaml index 2eaa6686..42a0ae2c 100644 --- a/model/supervised_finetuning/configs/config.yaml +++ b/model/supervised_finetuning/configs/config.yaml @@ -29,6 +29,14 @@ defaults: - soda - joke - gsm8k + - dive_mt + - wmt2019_zh-en + - wmt2019_ru-en + - wmt2019_de-en + - ted_trans_nl-en + - ted_trans_de-ja + - instruct_tuning + - wmt2019_de-en - samsum cache_dir: .cache loss_fn: CrossEntropyLoss diff --git a/model/supervised_finetuning/custom_datasets/README.md b/model/supervised_finetuning/custom_datasets/README.md new file mode 100644 index 00000000..9c825932 --- /dev/null +++ b/model/supervised_finetuning/custom_datasets/README.md @@ -0,0 +1,26 @@ +# Dataset collections overview: + +currently dataset can be divided into 3 classes + +- language knowledge + + - summarization + + - translation + +- dialogue : don't let user know you are a robot + +- STEM : knowledge about the world + + - coding + + - world knowledge <= ideally we want to handle this via prefix context + +Issues and TODO: + +* as dataset are growing, how can we update this section less + +* ideally we can update the config yaml and new dataset will be download from hub + + * one possible idea is we upload the trasform format of these dataset to the OA hub + diff --git a/model/supervised_finetuning/custom_datasets/__init__.py b/model/supervised_finetuning/custom_datasets/__init__.py index e293af3d..cb844777 100644 --- a/model/supervised_finetuning/custom_datasets/__init__.py +++ b/model/supervised_finetuning/custom_datasets/__init__.py @@ -1,11 +1,26 @@ -from custom_datasets.prompt_dialogue import PromptGeneratedDataset +""" + High level functions for model training +""" +from custom_datasets.prompt_dialogue import InstructionTuning, PromptGeneratedDataset from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, WebGPT from custom_datasets.summarization import SummarizationDataset +from custom_datasets.toxic_conversation import ProsocialDialogue, ProsocialDialogueExplaination +from custom_datasets.translation import WMT2019, DiveMT, TEDTalk from sklearn.model_selection import train_test_split from torch.utils.data import Subset QA_DATASETS = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_nocontext", "gsm8k"] -SUMMARIZATION_DATASETS = ["xsum", "cnn_dailymail", "samsum", "multi_news", "scitldr", "billsum"] +SUMMARIZATION_DATASETS = [ + "xsum", + "cnn_dailymail", + "samsum", + "multi_news", + "scitldr", + "billsum", + "debate_sum", + "tldr_news", +] +OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning"] def train_val_dataset(dataset, val_split=0.2): @@ -25,20 +40,43 @@ def get_one_dataset(conf, dataset_name): elif dataset_name in SUMMARIZATION_DATASETS: train = SummarizationDataset(dataset_name, conf.cache_dir, "train") - val_name = "validation" if dataset_name not in ["billsum"] else "test" - eval = SummarizationDataset(dataset_name, conf.cache_dir, val_name) + if dataset_name == "debate_sum": + train, eval = train_val_dataset(train, val_split=0.2) + else: + val_name = "validation" if dataset_name not in ["billsum"] else "test" + eval = SummarizationDataset(dataset_name, conf.cache_dir, val_name) + elif "ted_trans" in dataset_name: + language_pair = dataset_name.split("_")[-1] + dataset = TEDTalk(pair=language_pair, split="train") + train, eval = train_val_dataset(dataset, val_split=0.2) + elif "wmt2019" in dataset_name: + language_pair = dataset_name.split("_")[-1] + train = WMT2019(pair=language_pair, split="train") + eval = WMT2019(pair=language_pair, split="validation") + elif dataset_name == "dive_mt": + dataset = DiveMT() + train, eval = train_val_dataset(dataset, val_split=0.2) elif dataset_name == "webgpt": dataset = WebGPT() train, eval = train_val_dataset(dataset, val_split=0.2) elif dataset_name == "prompt_dialogue": dataset = PromptGeneratedDataset(conf.cache_dir) train, eval = train_val_dataset(dataset, val_split=0.2) + elif dataset_name == "prosocial_dialogue": + train = ProsocialDialogue(cache_dir=conf.cache_dir, split="train") + eval = ProsocialDialogue(cache_dir=conf.cache_dir, split="validation") + elif dataset_name == "explain_prosocial": + train = ProsocialDialogueExplaination(cache_dir=conf.cache_dir, split="train") + eval = ProsocialDialogueExplaination(cache_dir=conf.cache_dir, split="validation") elif dataset_name == "soda": dataset = SODA(conf.cache_dir) train, eval = train_val_dataset(dataset, val_split=0.1) elif dataset_name == "joke": dataset = JokeExplaination(conf.cache_dir) train, eval = train_val_dataset(dataset, val_split=0.2) + elif dataset_name == "instruct_tuning": + dataset = InstructionTuning(conf.cache_dir) + train, eval = train_val_dataset(dataset, val_split=0.2) else: raise ValueError(f"Unknown dataset {dataset_name}") diff --git a/model/supervised_finetuning/custom_datasets/dialogue_collator.py b/model/supervised_finetuning/custom_datasets/dialogue_collator.py index 2efe160f..719fa0d6 100644 --- a/model/supervised_finetuning/custom_datasets/dialogue_collator.py +++ b/model/supervised_finetuning/custom_datasets/dialogue_collator.py @@ -25,6 +25,7 @@ class DialogueDataCollator: for feature_one in features: assert len(feature_one) % 2 == 0, "Number of messages must be even" + # TODO: we should push this to dataset __getitem__ messages = [ (QA_SPECIAL_TOKENS["Question"] if i % 2 == 0 else "") + x diff --git a/model/supervised_finetuning/custom_datasets/prompt_dialogue.py b/model/supervised_finetuning/custom_datasets/prompt_dialogue.py index 372ea27f..4a1d83a3 100644 --- a/model/supervised_finetuning/custom_datasets/prompt_dialogue.py +++ b/model/supervised_finetuning/custom_datasets/prompt_dialogue.py @@ -1,3 +1,4 @@ +import json import os from urllib.request import urlopen @@ -14,6 +15,7 @@ class PromptGeneratedDataset(Dataset): we are ignoring results with multiple lines for now """ + name = "prompt_dialogue" url = "https://github.com/Rallio67/language-model-agents/raw/main/chat_dialogue_v2_c.txt" def __init__(self, cache_dir) -> None: @@ -49,3 +51,55 @@ class PromptGeneratedDataset(Dataset): def __getitem__(self, index): question, answer = self.pairs[index] return question, answer + + +class InstructionTuning(Dataset): + """ + We have seen some promising capabilities from instruction tuning + with the following mix of datasets that are derived from datasets + available online. + The files for this data are in json format as a list of tuples + where each tuple is (source,instruction_response_pair) + + - instruction_tuning_dataset_alpha_part1.json + - instruction_tuning_dataset_alpha_part2.json + + Not to be confused with unatural instruction + """ + + name = "instruction_dataset" + url_part_2 = ( + "https://github.com/Rallio67/language-model-agents/raw/main/instruction_tuning_dataset_alpha_part2.json" + ) + url_part_1 = ( + "https://github.com/Rallio67/language-model-agents/raw/main/instruction_tuning_dataset_alpha_part1.json" + ) + + def __init__(self, cache_dir) -> None: + super().__init__() + os.makedirs(cache_dir, exist_ok=True) + + self.pairs = [] + for file_link in [self.url_part_1, self.url_part_2]: + basename = file_link.split("/")[-1] + instruction_tune_file = os.path.join(cache_dir, basename) + if not os.path.exists(instruction_tune_file): + with urlopen(file_link) as file: + content = file.read().decode() + with open(instruction_tune_file, "w", encoding="utf-8") as fout: + fout.write(content) + + with open(instruction_tune_file, "r", encoding="utf-8") as f: + datasets = json.load(f) + for row in datasets: + _, response_pair = row + question, answer = response_pair.split("\n\n", maxsplit=1) + answer = answer.replace("<|endoftext|>", "").strip() + self.pairs.append((question, answer)) + + def __len__(self): + return len(self.pairs) + + def __getitem__(self, index): + question, answer = self.pairs[index] + return question, answer diff --git a/model/supervised_finetuning/custom_datasets/qa_datasets.py b/model/supervised_finetuning/custom_datasets/qa_datasets.py index eed9c644..789b8f58 100644 --- a/model/supervised_finetuning/custom_datasets/qa_datasets.py +++ b/model/supervised_finetuning/custom_datasets/qa_datasets.py @@ -1,11 +1,18 @@ +""" + Open / close book QA datasets +""" import json import os +import re from urllib.request import urlopen import numpy as np from datasets import load_dataset from torch.utils.data import Dataset +# @agoryuno contributed this +re_reference_remove = re.compile(r"\[\d+(?:,\s*\d+)*?\]") + QA_SPECIAL_TOKENS = {"Question": "", "Answer": "", "StartPrefix": "", "EndPrefix": ""} @@ -75,6 +82,9 @@ class QADataset(Dataset): class WebGPT(Dataset): + + name = "webgpt" + def __init__(self) -> None: super().__init__() @@ -89,7 +99,9 @@ class WebGPT(Dataset): self.index2question[len(self.index2question)] = question # only keep the best answer - questions[question] = row["answer_0" if row["score_0"] > row["score_1"] else "answer_1"] + questions[question] = re_reference_remove.sub( + "", row["answer_0" if row["score_0"] > row["score_1"] else "answer_1"] + ) self.questions = questions @@ -103,6 +115,9 @@ class WebGPT(Dataset): class SODA(Dataset): + + name = "soda" + def process_soda_convo(self, data): pairs = [] play_as = data["speakers"][1] @@ -149,8 +164,8 @@ class SODA(Dataset): class JokeExplaination(Dataset): - """ """ + name = "joke" url = "https://gist.github.com/theblackcat102/42b697e24a13fdb499e20edfbf618361/raw/1834dca207898c15f93b809d1195f6f6e47c9e1e/joke_explained.jsonl" def __init__(self, cache_dir) -> None: @@ -182,3 +197,6 @@ class JokeExplaination(Dataset): def __getitem__(self, index): question, answer = self.pairs[index] return question, answer + + +# https://huggingface.co/datasets/aquamuse diff --git a/model/supervised_finetuning/custom_datasets/summarization.py b/model/supervised_finetuning/custom_datasets/summarization.py index 69e4b51d..2a097fe7 100644 --- a/model/supervised_finetuning/custom_datasets/summarization.py +++ b/model/supervised_finetuning/custom_datasets/summarization.py @@ -1,3 +1,6 @@ +""" + Summarize different spectrum of documents +""" import random from datasets import load_dataset @@ -12,13 +15,21 @@ SUMMARY_SPECIAL_PROMPT = { } summarization_config_mapping = { - "cnn_dailymail": ("3.0.0",), - "samsum": (), - "xsum": (), - "multi_news": (), - "scitldr": ("AIC",), - "billsum": (), - "reddit": (), + "cnn_dailymail": ( + "cnn_dailymail", + "3.0.0", + ), + "samsum": ("samsum",), + "xsum": ("xsum",), + "multi_news": ("multi_news",), + "scitldr": ( + "scitldr", + "AIC", + ), + "billsum": ("billsum",), + "reddit": ("reddit",), + "tldr_news": ("JulesBelveze/tldr_news",), # need to fix : JulesBelveze/tldr_news + "debate_sum": ("Hellisotherpeople/DebateSum",), # Hellisotherpeople/DebateSum } summarization_name_mapping = { @@ -29,6 +40,8 @@ summarization_name_mapping = { "scitldr": ("source", "target"), "billsum": ("text", "summary"), "reddit": ("content", "summary"), + "tldr_news": ("content", "headline"), + "debate_sum": ("Full-Document", "Extract"), } @@ -43,7 +56,7 @@ def index_summary_merge(text, summary): class SummarizationDataset(Dataset): def __init__(self, dataset, cache_dir, split): self.name = dataset - self.dataset = load_dataset(dataset, *summarization_config_mapping[dataset], cache_dir=cache_dir, split=split) + self.dataset = load_dataset(*summarization_config_mapping[dataset], cache_dir=cache_dir, split=split) self.text_column, self.summary_column = summarization_name_mapping[dataset] self.preprocess_fn = index_summary_merge if dataset == "scitldr" else index_summary_default diff --git a/model/supervised_finetuning/custom_datasets/toxic_conversation.py b/model/supervised_finetuning/custom_datasets/toxic_conversation.py new file mode 100644 index 00000000..6ef29163 --- /dev/null +++ b/model/supervised_finetuning/custom_datasets/toxic_conversation.py @@ -0,0 +1,65 @@ +''' + SFT dataset to reject toxic questions + +''' +import random +from datasets import load_dataset +from torch.utils.data import Dataset + +class ProsocialDialogueExplaination(Dataset): + name = "prosocial_explain" + TEMPLATE = [ + # 0 : reply or sentence of interest, 1 : reason of caution + ("'{0}' Why is this sentence {1}", "This sentence is {1} because {0}"), + ("Explain to me why this sentence is {1}: {0}", "This sentence is {1} because {0}"), + ("'{0}' Why is this sentence {1}", "This sentence may need {1} because {0}"), + ("'{0}' Explain to me why we need to {1} of this sentence", "We need to be {1} of this sentence because {0}"), + ] + def __init__(self, split='train', cache_dir='.cache') -> None: + super().__init__() + dataset = load_dataset("allenai/prosocial-dialog", cache_dir=cache_dir)[split] + self.pairs = [] + for row in dataset: + for safety_annotation, safe_answer in zip(row['safety_annotations'], row['safety_annotation_reasons']): + (prompt_template, answer_template) = random.choice(self.TEMPLATE) + self.pairs.append(( + prompt_template.format(row['context'],safety_annotation), + answer_template.format( safe_answer, safety_annotation) + )) + def __len__(self): + return len(self.pairs) + + def __getitem__(self, idx): + return self.pairs[idx] + +class ProsocialDialogue(Dataset): + name = "prosocial_dialogue" + ''' + ProsocialDialog, we set up a human-AI collaborative data creation framework, + where GPT-3 generates the potentially unsafe utterances, and crowdworkers + provide prosocial responses to them. This approach allows us to circumvent + two substantial challenges: + (1) there are no available large-scale corpora of multiturn prosocial conversations + between humans + (2) asking humans to write unethical, toxic, or problematic utterances could result + in psychological harms (Roberts, 2017; Steiger et al., 2021). + ''' + PREFIX = "You are now a prosocial chatbot, be caution and casual when reply" + + + def __init__(self, split='train', cache_dir='.cache') -> None: + super().__init__() + dataset = load_dataset("allenai/prosocial-dialog", cache_dir=cache_dir)[split] + self.pairs = [] + for row in dataset: + for answer in row['rots']: + self.pairs.append(( + self.PREFIX+row['context'], + answer + )) + + def __len__(self): + return len(self.pairs) + + def __getitem__(self, idx): + return self.pairs[idx] diff --git a/model/supervised_finetuning/custom_datasets/translation.py b/model/supervised_finetuning/custom_datasets/translation.py new file mode 100644 index 00000000..a6d46e9e --- /dev/null +++ b/model/supervised_finetuning/custom_datasets/translation.py @@ -0,0 +1,157 @@ +''' + List of translation dataset + + GroNLP/divemt + + fill in the blanks : https://huggingface.co/datasets/m_lama + +''' +import random +from datasets import load_dataset +from torch.utils.data import Dataset + +# postfix prompt +TRANSLATION_PROMPT = { + "zh": [ # simplified or any chinese which was not mentioned + "Translate to chinese simplified: {}", + "{}, translate to chinese", + "{} give me the chinese translation", + "翻译成中文: {}", + "{} 这句中文翻译怎麽写?", + "我需要这句话的中文翻译: {}", + ], + "zh-tw": [ # WMT code + "{}. Translate to chinese traditional", + "{}, translate to chinese", + "{}. get chinese translation", + "中文翻譯: {}", + "幫我翻譯成中文: '{}'", + "{} 這句中文翻譯怎麼寫?", + ], + "ja": [ + "{}: help me translate to japanese", + "Need japanese translation: {}", + "{}: にほんごやくをよこす", + "{}: にほんごやくをおくれ", + "{}: にほんごやくを じょす", + "give me the japanese translation, {}", + ], + "de": [ + "{}: translate to german", + "give me the german translation {}", + "I want german translation {}", + "{}, ins Deutsche übersetzen", + "{}, Übersetzen ins Deutsche", + ], + "fr": [ + "{}. translate to french", + "{} write in french", + "{} french translation", + "{} ,donnez moi la traduction française"], + "ko": [ + "{}. translate to Korean", + "how do we write in korean: {}", + "give me the korean translation: {}", + "{}, 한국어 번역을 해주세요", + ], + "ms": [ + "{} translate to malay", + "{} how do we write in Malay", + "{} give me the malay translation", + "{} , berikan saya terjemahan dalam bahasa melayu", + "{}, Jemahan di bahasa melayu" + "{}, jemahkan ayat ini kepada bahasa melayu" + ], + "en": ["{}. translate to english", "{} write in english", "english translation: '{}'"], + "tr": ["{}. translate to turkish", "{} write in turkish", "turkish translation: '{}'"], + "it": ["{}. translate to italian", "{} write in italian", "italian translation: '{}'"], + "nl": ["{}. translate to dutch", "{} write in dutch", "dutch translation: '{}'"], + "vi": ["{}. translate to vietnamese", "{} write in vietnamese", "vietnamese translation: '{}'"], + "ar": ["{}. translate to arabic", "{} write in arabic", "arabic translation: '{}'"], +} +class TranslationPair(Dataset): + def __init__(self) -> None: + super().__init__() + self.pairs = [] + + def __len__(self): + return len(self.pairs) + + def __getitem__(self, index): + return self.pairs[index] + + +class WMT2019(TranslationPair): + + def __init__(self, pair='zh-en', split='train') -> None: + super().__init__() + dataset = load_dataset('wmt19', pair)[split] + self.pairs = [] + src, tgt = pair.split('-') + for row in dataset: + row = row['translation'] + if random.random() > 0.5: + source = random.choice( + TRANSLATION_PROMPT[tgt] + ).format(row[src]) + self.pairs.append((source, row[tgt])) + else:# translating in reverse direction + source = random.choice( + TRANSLATION_PROMPT[src] + ).format(row[tgt]) + self.pairs.append((source, row[src])) + +class DiveMT(TranslationPair): + + REMAP = { + 'tur': 'tr', + 'ita': 'it', + 'ukr': 'uk', + 'nld': 'nl', + 'vie': 'vi', + 'ara': 'ar' + } + + def __init__(self, split='train') -> None: + super().__init__() + dataset = load_dataset('GroNLP/divemt', 'main')[split] + tgt, src = 'tgt_text', 'src_text' + for row in dataset: + # ISO 639-2 + lang_code_2 = row['subject_id'].split('_')[0] + lang_code = self.REMAP[lang_code_2] + if lang_code not in TRANSLATION_PROMPT: + continue + + if random.random() > 0.5: + source = random.choice( + TRANSLATION_PROMPT[lang_code] + ).format(row[src]) + self.pairs.append((source, row[tgt])) + else:# translating in reverse direction + lang_code = 'en' + source = random.choice( + TRANSLATION_PROMPT[lang_code] + ).format(row[tgt]) + self.pairs.append((source, row[src])) + + +class TEDTalk(TranslationPair): + # NOTE: DO NOT use chinese pair, mix with traditional and cantonese, not clean + + def __init__(self, pair='de-ja', split='train', year='2016') -> None: + super().__init__() + dataset = load_dataset('ted_talks_iwslt', language_pair=pair.split('-'), year=year)[split] + src, tgt = pair.split('-') + for row in dataset: + row = row['translation'] + if random.random() > 0.5: + source = random.choice( + TRANSLATION_PROMPT[tgt] + ).format(row[src]) + self.pairs.append((source, row[tgt])) + else:# translating in reverse direction + source = random.choice( + TRANSLATION_PROMPT[src] + ).format(row[tgt]) + self.pairs.append((source, row[src])) diff --git a/model/supervised_finetuning/tests/test_datasets.py b/model/supervised_finetuning/tests/test_datasets.py index c9363303..2ac43613 100644 --- a/model/supervised_finetuning/tests/test_datasets.py +++ b/model/supervised_finetuning/tests/test_datasets.py @@ -7,10 +7,11 @@ from custom_datasets.dialogue_collator import DialogueDataCollator def test_all_datasets(): qa_base = QA_DATASETS summarize_base = SUMMARIZATION_DATASETS - others = ["prompt_dialogue", "webgpt", "soda", "joke"] + others = ["prompt_dialogue", "webgpt", "soda", "joke", "instruct_tuning"] + translation = ["dive_mt", "wmt2019_zh-en", "wmt2019_ru-en", "wmt2019_de-en", "ted_trans_de-ja", "ted_trans_nl-en"] config = Namespace(cache_dir=".cache") - for dataset_name in others + qa_base + summarize_base: + for dataset_name in translation: print(dataset_name) train, eval = get_one_dataset(config, dataset_name) # sanity check @@ -51,4 +52,4 @@ def test_collate_fn(): if __name__ == "__main__": - test_collate_fn() + test_all_datasets()