[fix] linter fix

This commit is contained in:
theblackcat102
2023-01-20 07:23:02 +00:00
parent c255148dc6
commit 22e3ab1a89
5 changed files with 70 additions and 79 deletions
+4 -1
View File
@@ -60,7 +60,10 @@ python trainer.py --configs defaults your-model-name --deepspeed
## Dataset choices
To specify which translation pair for [WMT](https://huggingface.co/datasets/wmt19) and [TED Talk](https://huggingface.co/datasets/ted_talks_iwslt) translation simply add the supported language pair at the postfix
To specify which translation pair for
[WMT](https://huggingface.co/datasets/wmt19) and
[TED Talk](https://huggingface.co/datasets/ted_talks_iwslt) translation simply
add the supported language pair at the postfix
```
datasets:
@@ -4,23 +4,24 @@ currently dataset can be divided into 3 classes
- language knowledge
- summarization
- summarization
- translation
- translation
- dialogue : don't let user know you are a robot
- STEM : knowledge about the world
- coding
- world knowledge <= ideally we want to handle this via prefix context
- coding
- world knowledge <= ideally we want to handle this via prefix context
Issues and TODO:
* as dataset are growing, how can we update this section less
- as dataset are growing, how can we update this section less
* ideally we can update the config yaml and new dataset will be download from hub
* one possible idea is we upload the trasform format of these dataset to the OA hub
- ideally we can update the config yaml and new dataset will be download from
hub
- one possible idea is we upload the trasform format of these dataset to the
OA hub
@@ -2,7 +2,7 @@
High level functions for model training
"""
from custom_datasets.prompt_dialogue import InstructionTuning, PromptGeneratedDataset
from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, WebGPT, SODADialogue
from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, SODADialogue, WebGPT
from custom_datasets.summarization import SummarizationDataset
from custom_datasets.toxic_conversation import ProsocialDialogue, ProsocialDialogueExplaination
from custom_datasets.translation import WMT2019, DiveMT, TEDTalk
@@ -1,11 +1,13 @@
'''
"""
SFT dataset to reject toxic questions
'''
"""
import random
from datasets import load_dataset
from torch.utils.data import Dataset
class ProsocialDialogueExplaination(Dataset):
name = "prosocial_explain"
TEMPLATE = [
@@ -15,26 +17,31 @@ class ProsocialDialogueExplaination(Dataset):
("'{0}' Why is this sentence {1}", "This sentence may need {1} because {0}"),
("'{0}' Explain to me why we need to {1} of this sentence", "We need to be {1} of this sentence because {0}"),
]
def __init__(self, split='train', cache_dir='.cache') -> None:
def __init__(self, split="train", cache_dir=".cache") -> None:
super().__init__()
dataset = load_dataset("allenai/prosocial-dialog", cache_dir=cache_dir)[split]
self.pairs = []
for row in dataset:
for safety_annotation, safe_answer in zip(row['safety_annotations'], row['safety_annotation_reasons']):
for safety_annotation, safe_answer in zip(row["safety_annotations"], row["safety_annotation_reasons"]):
(prompt_template, answer_template) = random.choice(self.TEMPLATE)
self.pairs.append((
prompt_template.format(row['context'],safety_annotation),
answer_template.format( safe_answer, safety_annotation)
))
self.pairs.append(
(
prompt_template.format(row["context"], safety_annotation),
answer_template.format(safe_answer, safety_annotation),
)
)
def __len__(self):
return len(self.pairs)
def __getitem__(self, idx):
return self.pairs[idx]
class ProsocialDialogue(Dataset):
name = "prosocial_dialogue"
'''
"""
ProsocialDialog, we set up a human-AI collaborative data creation framework,
where GPT-3 generates the potentially unsafe utterances, and crowdworkers
provide prosocial responses to them. This approach allows us to circumvent
@@ -43,20 +50,16 @@ class ProsocialDialogue(Dataset):
between humans
(2) asking humans to write unethical, toxic, or problematic utterances could result
in psychological harms (Roberts, 2017; Steiger et al., 2021).
'''
"""
PREFIX = "<prefix>You are now a prosocial chatbot, be caution and casual when reply</prefix>"
def __init__(self, split='train', cache_dir='.cache') -> None:
def __init__(self, split="train", cache_dir=".cache") -> None:
super().__init__()
dataset = load_dataset("allenai/prosocial-dialog", cache_dir=cache_dir)[split]
self.pairs = []
for row in dataset:
for answer in row['rots']:
self.pairs.append((
self.PREFIX+row['context'],
answer
))
for answer in row["rots"]:
self.pairs.append((self.PREFIX + row["context"], answer))
def __len__(self):
return len(self.pairs)
@@ -1,18 +1,19 @@
'''
"""
List of translation dataset
GroNLP/divemt
fill in the blanks : https://huggingface.co/datasets/m_lama
'''
"""
import random
from datasets import load_dataset
from torch.utils.data import Dataset
# postfix prompt
TRANSLATION_PROMPT = {
"zh": [ # simplified or any chinese which was not mentioned
"zh": [ # simplified or any chinese which was not mentioned
"Translate to chinese simplified: {}",
"{}, translate to chinese",
"{} give me the chinese translation",
@@ -20,7 +21,7 @@ TRANSLATION_PROMPT = {
"{} 这句中文翻译怎麽写?",
"我需要这句话的中文翻译: {}",
],
"zh-tw": [ # WMT code
"zh-tw": [ # WMT code
"{}. Translate to chinese traditional",
"{}, translate to chinese",
"{}. get chinese translation",
@@ -47,7 +48,8 @@ TRANSLATION_PROMPT = {
"{}. translate to french",
"{} write in french",
"{} french translation",
"{} ,donnez moi la traduction française"],
"{} ,donnez moi la traduction française",
],
"ko": [
"{}. translate to Korean",
"how do we write in korean: {}",
@@ -59,8 +61,7 @@ TRANSLATION_PROMPT = {
"{} how do we write in Malay",
"{} give me the malay translation",
"{} , berikan saya terjemahan dalam bahasa melayu",
"{}, Jemahan di bahasa melayu"
"{}, jemahkan ayat ini kepada bahasa melayu"
"{}, Jemahan di bahasa melayu" "{}, jemahkan ayat ini kepada bahasa melayu",
],
"en": ["{}. translate to english", "{} write in english", "english translation: '{}'"],
"ru": ["помогите мне перевести это на русский : {}", "{} перевести на русский язык", "russian translation: '{}'"],
@@ -70,6 +71,8 @@ TRANSLATION_PROMPT = {
"vi": ["{}. Dịch sang tiếng việt nam", "{} write in vietnamese", "vietnamese translation: '{}'"],
"ar": ["{}. translate to arabic", "{} write in arabic", "arabic translation: '{}'"],
}
class TranslationPair(Dataset):
def __init__(self) -> None:
super().__init__()
@@ -80,79 +83,60 @@ class TranslationPair(Dataset):
def __getitem__(self, index):
return self.pairs[index]
class WMT2019(TranslationPair):
def __init__(self, pair='zh-en', split='train') -> None:
def __init__(self, pair="zh-en", split="train") -> None:
super().__init__()
dataset = load_dataset('wmt19', pair)[split]
dataset = load_dataset("wmt19", pair)[split]
self.pairs = []
src, tgt = pair.split('-')
src, tgt = pair.split("-")
for row in dataset:
row = row['translation']
row = row["translation"]
if random.random() > 0.5:
source = random.choice(
TRANSLATION_PROMPT[tgt]
).format(row[src])
source = random.choice(TRANSLATION_PROMPT[tgt]).format(row[src])
self.pairs.append((source, row[tgt]))
else:# translating in reverse direction
source = random.choice(
TRANSLATION_PROMPT[src]
).format(row[tgt])
else: # translating in reverse direction
source = random.choice(TRANSLATION_PROMPT[src]).format(row[tgt])
self.pairs.append((source, row[src]))
class DiveMT(TranslationPair):
REMAP = {
'tur': 'tr',
'ita': 'it',
'ukr': 'uk',
'nld': 'nl',
'vie': 'vi',
'ara': 'ar'
}
REMAP = {"tur": "tr", "ita": "it", "ukr": "uk", "nld": "nl", "vie": "vi", "ara": "ar"}
def __init__(self, split='train') -> None:
def __init__(self, split="train") -> None:
super().__init__()
dataset = load_dataset('GroNLP/divemt', 'main')[split]
tgt, src = 'tgt_text', 'src_text'
dataset = load_dataset("GroNLP/divemt", "main")[split]
tgt, src = "tgt_text", "src_text"
for row in dataset:
# ISO 639-2
lang_code_2 = row['subject_id'].split('_')[0]
lang_code_2 = row["subject_id"].split("_")[0]
lang_code = self.REMAP[lang_code_2]
if lang_code not in TRANSLATION_PROMPT:
continue
if random.random() > 0.5:
source = random.choice(
TRANSLATION_PROMPT[lang_code]
).format(row[src])
source = random.choice(TRANSLATION_PROMPT[lang_code]).format(row[src])
self.pairs.append((source, row[tgt]))
else:# translating in reverse direction
lang_code = 'en'
source = random.choice(
TRANSLATION_PROMPT[lang_code]
).format(row[tgt])
else: # translating in reverse direction
lang_code = "en"
source = random.choice(TRANSLATION_PROMPT[lang_code]).format(row[tgt])
self.pairs.append((source, row[src]))
class TEDTalk(TranslationPair):
# NOTE: DO NOT use chinese pair, mix with traditional and cantonese, not clean
def __init__(self, pair='de-ja', split='train', year='2016') -> None:
def __init__(self, pair="de-ja", split="train", year="2016") -> None:
super().__init__()
dataset = load_dataset('ted_talks_iwslt', language_pair=pair.split('-'), year=year)[split]
src, tgt = pair.split('-')
dataset = load_dataset("ted_talks_iwslt", language_pair=pair.split("-"), year=year)[split]
src, tgt = pair.split("-")
for row in dataset:
row = row['translation']
row = row["translation"]
if random.random() > 0.5:
source = random.choice(
TRANSLATION_PROMPT[tgt]
).format(row[src])
source = random.choice(TRANSLATION_PROMPT[tgt]).format(row[src])
self.pairs.append((source, row[tgt]))
else:# translating in reverse direction
source = random.choice(
TRANSLATION_PROMPT[src]
).format(row[tgt])
else: # translating in reverse direction
source = random.choice(TRANSLATION_PROMPT[src]).format(row[tgt])
self.pairs.append((source, row[src]))