From 22e3ab1a890876691ab382e811b6a486f2fa3eeb Mon Sep 17 00:00:00 2001 From: theblackcat102 Date: Fri, 20 Jan 2023 07:23:02 +0000 Subject: [PATCH] [fix] linter fix --- model/supervised_finetuning/README.md | 5 +- .../custom_datasets/README.md | 19 ++-- .../custom_datasets/__init__.py | 2 +- .../custom_datasets/toxic_conversation.py | 37 ++++---- .../custom_datasets/translation.py | 86 ++++++++----------- 5 files changed, 70 insertions(+), 79 deletions(-) diff --git a/model/supervised_finetuning/README.md b/model/supervised_finetuning/README.md index 9f200847..d5b10e01 100644 --- a/model/supervised_finetuning/README.md +++ b/model/supervised_finetuning/README.md @@ -60,7 +60,10 @@ python trainer.py --configs defaults your-model-name --deepspeed ## Dataset choices -To specify which translation pair for [WMT](https://huggingface.co/datasets/wmt19) and [TED Talk](https://huggingface.co/datasets/ted_talks_iwslt) translation simply add the supported language pair at the postfix +To specify which translation pair for +[WMT](https://huggingface.co/datasets/wmt19) and +[TED Talk](https://huggingface.co/datasets/ted_talks_iwslt) translation simply +add the supported language pair at the postfix ``` datasets: diff --git a/model/supervised_finetuning/custom_datasets/README.md b/model/supervised_finetuning/custom_datasets/README.md index 9c825932..56a28574 100644 --- a/model/supervised_finetuning/custom_datasets/README.md +++ b/model/supervised_finetuning/custom_datasets/README.md @@ -4,23 +4,24 @@ currently dataset can be divided into 3 classes - language knowledge - - summarization + - summarization - - translation + - translation - dialogue : don't let user know you are a robot - STEM : knowledge about the world - - coding - - - world knowledge <= ideally we want to handle this via prefix context + - coding + + - world knowledge <= ideally we want to handle this via prefix context Issues and TODO: -* as dataset are growing, how can we update this section less +- as dataset are growing, how can we update this section less -* ideally we can update the config yaml and new dataset will be download from hub - - * one possible idea is we upload the trasform format of these dataset to the OA hub +- ideally we can update the config yaml and new dataset will be download from + hub + - one possible idea is we upload the trasform format of these dataset to the + OA hub diff --git a/model/supervised_finetuning/custom_datasets/__init__.py b/model/supervised_finetuning/custom_datasets/__init__.py index cef3a409..2e1e4b30 100644 --- a/model/supervised_finetuning/custom_datasets/__init__.py +++ b/model/supervised_finetuning/custom_datasets/__init__.py @@ -2,7 +2,7 @@ High level functions for model training """ from custom_datasets.prompt_dialogue import InstructionTuning, PromptGeneratedDataset -from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, WebGPT, SODADialogue +from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, SODADialogue, WebGPT from custom_datasets.summarization import SummarizationDataset from custom_datasets.toxic_conversation import ProsocialDialogue, ProsocialDialogueExplaination from custom_datasets.translation import WMT2019, DiveMT, TEDTalk diff --git a/model/supervised_finetuning/custom_datasets/toxic_conversation.py b/model/supervised_finetuning/custom_datasets/toxic_conversation.py index 6ef29163..815ac722 100644 --- a/model/supervised_finetuning/custom_datasets/toxic_conversation.py +++ b/model/supervised_finetuning/custom_datasets/toxic_conversation.py @@ -1,11 +1,13 @@ -''' +""" SFT dataset to reject toxic questions -''' +""" import random + from datasets import load_dataset from torch.utils.data import Dataset + class ProsocialDialogueExplaination(Dataset): name = "prosocial_explain" TEMPLATE = [ @@ -15,26 +17,31 @@ class ProsocialDialogueExplaination(Dataset): ("'{0}' Why is this sentence {1}", "This sentence may need {1} because {0}"), ("'{0}' Explain to me why we need to {1} of this sentence", "We need to be {1} of this sentence because {0}"), ] - def __init__(self, split='train', cache_dir='.cache') -> None: + + def __init__(self, split="train", cache_dir=".cache") -> None: super().__init__() dataset = load_dataset("allenai/prosocial-dialog", cache_dir=cache_dir)[split] self.pairs = [] for row in dataset: - for safety_annotation, safe_answer in zip(row['safety_annotations'], row['safety_annotation_reasons']): + for safety_annotation, safe_answer in zip(row["safety_annotations"], row["safety_annotation_reasons"]): (prompt_template, answer_template) = random.choice(self.TEMPLATE) - self.pairs.append(( - prompt_template.format(row['context'],safety_annotation), - answer_template.format( safe_answer, safety_annotation) - )) + self.pairs.append( + ( + prompt_template.format(row["context"], safety_annotation), + answer_template.format(safe_answer, safety_annotation), + ) + ) + def __len__(self): return len(self.pairs) def __getitem__(self, idx): return self.pairs[idx] + class ProsocialDialogue(Dataset): name = "prosocial_dialogue" - ''' + """ ProsocialDialog, we set up a human-AI collaborative data creation framework, where GPT-3 generates the potentially unsafe utterances, and crowdworkers provide prosocial responses to them. This approach allows us to circumvent @@ -43,20 +50,16 @@ class ProsocialDialogue(Dataset): between humans (2) asking humans to write unethical, toxic, or problematic utterances could result in psychological harms (Roberts, 2017; Steiger et al., 2021). - ''' + """ PREFIX = "You are now a prosocial chatbot, be caution and casual when reply" - - def __init__(self, split='train', cache_dir='.cache') -> None: + def __init__(self, split="train", cache_dir=".cache") -> None: super().__init__() dataset = load_dataset("allenai/prosocial-dialog", cache_dir=cache_dir)[split] self.pairs = [] for row in dataset: - for answer in row['rots']: - self.pairs.append(( - self.PREFIX+row['context'], - answer - )) + for answer in row["rots"]: + self.pairs.append((self.PREFIX + row["context"], answer)) def __len__(self): return len(self.pairs) diff --git a/model/supervised_finetuning/custom_datasets/translation.py b/model/supervised_finetuning/custom_datasets/translation.py index 79fff0d1..694d31ce 100644 --- a/model/supervised_finetuning/custom_datasets/translation.py +++ b/model/supervised_finetuning/custom_datasets/translation.py @@ -1,18 +1,19 @@ -''' +""" List of translation dataset GroNLP/divemt fill in the blanks : https://huggingface.co/datasets/m_lama -''' +""" import random + from datasets import load_dataset from torch.utils.data import Dataset # postfix prompt TRANSLATION_PROMPT = { - "zh": [ # simplified or any chinese which was not mentioned + "zh": [ # simplified or any chinese which was not mentioned "Translate to chinese simplified: {}", "{}, translate to chinese", "{} give me the chinese translation", @@ -20,7 +21,7 @@ TRANSLATION_PROMPT = { "{} 这句中文翻译怎麽写?", "我需要这句话的中文翻译: {}", ], - "zh-tw": [ # WMT code + "zh-tw": [ # WMT code "{}. Translate to chinese traditional", "{}, translate to chinese", "{}. get chinese translation", @@ -47,7 +48,8 @@ TRANSLATION_PROMPT = { "{}. translate to french", "{} write in french", "{} french translation", - "{} ,donnez moi la traduction française"], + "{} ,donnez moi la traduction française", + ], "ko": [ "{}. translate to Korean", "how do we write in korean: {}", @@ -59,8 +61,7 @@ TRANSLATION_PROMPT = { "{} how do we write in Malay", "{} give me the malay translation", "{} , berikan saya terjemahan dalam bahasa melayu", - "{}, Jemahan di bahasa melayu" - "{}, jemahkan ayat ini kepada bahasa melayu" + "{}, Jemahan di bahasa melayu" "{}, jemahkan ayat ini kepada bahasa melayu", ], "en": ["{}. translate to english", "{} write in english", "english translation: '{}'"], "ru": ["помогите мне перевести это на русский : {}", "{} перевести на русский язык", "russian translation: '{}'"], @@ -70,6 +71,8 @@ TRANSLATION_PROMPT = { "vi": ["{}. Dịch sang tiếng việt nam", "{} write in vietnamese", "vietnamese translation: '{}'"], "ar": ["{}. translate to arabic", "{} write in arabic", "arabic translation: '{}'"], } + + class TranslationPair(Dataset): def __init__(self) -> None: super().__init__() @@ -80,79 +83,60 @@ class TranslationPair(Dataset): def __getitem__(self, index): return self.pairs[index] - + class WMT2019(TranslationPair): - - def __init__(self, pair='zh-en', split='train') -> None: + def __init__(self, pair="zh-en", split="train") -> None: super().__init__() - dataset = load_dataset('wmt19', pair)[split] + dataset = load_dataset("wmt19", pair)[split] self.pairs = [] - src, tgt = pair.split('-') + src, tgt = pair.split("-") for row in dataset: - row = row['translation'] + row = row["translation"] if random.random() > 0.5: - source = random.choice( - TRANSLATION_PROMPT[tgt] - ).format(row[src]) + source = random.choice(TRANSLATION_PROMPT[tgt]).format(row[src]) self.pairs.append((source, row[tgt])) - else:# translating in reverse direction - source = random.choice( - TRANSLATION_PROMPT[src] - ).format(row[tgt]) + else: # translating in reverse direction + source = random.choice(TRANSLATION_PROMPT[src]).format(row[tgt]) self.pairs.append((source, row[src])) + class DiveMT(TranslationPair): - REMAP = { - 'tur': 'tr', - 'ita': 'it', - 'ukr': 'uk', - 'nld': 'nl', - 'vie': 'vi', - 'ara': 'ar' - } + REMAP = {"tur": "tr", "ita": "it", "ukr": "uk", "nld": "nl", "vie": "vi", "ara": "ar"} - def __init__(self, split='train') -> None: + def __init__(self, split="train") -> None: super().__init__() - dataset = load_dataset('GroNLP/divemt', 'main')[split] - tgt, src = 'tgt_text', 'src_text' + dataset = load_dataset("GroNLP/divemt", "main")[split] + tgt, src = "tgt_text", "src_text" for row in dataset: # ISO 639-2 - lang_code_2 = row['subject_id'].split('_')[0] + lang_code_2 = row["subject_id"].split("_")[0] lang_code = self.REMAP[lang_code_2] if lang_code not in TRANSLATION_PROMPT: continue if random.random() > 0.5: - source = random.choice( - TRANSLATION_PROMPT[lang_code] - ).format(row[src]) + source = random.choice(TRANSLATION_PROMPT[lang_code]).format(row[src]) self.pairs.append((source, row[tgt])) - else:# translating in reverse direction - lang_code = 'en' - source = random.choice( - TRANSLATION_PROMPT[lang_code] - ).format(row[tgt]) + else: # translating in reverse direction + lang_code = "en" + source = random.choice(TRANSLATION_PROMPT[lang_code]).format(row[tgt]) self.pairs.append((source, row[src])) class TEDTalk(TranslationPair): # NOTE: DO NOT use chinese pair, mix with traditional and cantonese, not clean - def __init__(self, pair='de-ja', split='train', year='2016') -> None: + def __init__(self, pair="de-ja", split="train", year="2016") -> None: super().__init__() - dataset = load_dataset('ted_talks_iwslt', language_pair=pair.split('-'), year=year)[split] - src, tgt = pair.split('-') + dataset = load_dataset("ted_talks_iwslt", language_pair=pair.split("-"), year=year)[split] + src, tgt = pair.split("-") for row in dataset: - row = row['translation'] + row = row["translation"] if random.random() > 0.5: - source = random.choice( - TRANSLATION_PROMPT[tgt] - ).format(row[src]) + source = random.choice(TRANSLATION_PROMPT[tgt]).format(row[src]) self.pairs.append((source, row[tgt])) - else:# translating in reverse direction - source = random.choice( - TRANSLATION_PROMPT[src] - ).format(row[tgt]) + else: # translating in reverse direction + source = random.choice(TRANSLATION_PROMPT[src]).format(row[tgt]) self.pairs.append((source, row[src]))