mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-04 17:20:19 +08:00
[fix] linter fix
This commit is contained in:
@@ -60,7 +60,10 @@ python trainer.py --configs defaults your-model-name --deepspeed
|
||||
|
||||
## Dataset choices
|
||||
|
||||
To specify which translation pair for [WMT](https://huggingface.co/datasets/wmt19) and [TED Talk](https://huggingface.co/datasets/ted_talks_iwslt) translation simply add the supported language pair at the postfix
|
||||
To specify which translation pair for
|
||||
[WMT](https://huggingface.co/datasets/wmt19) and
|
||||
[TED Talk](https://huggingface.co/datasets/ted_talks_iwslt) translation simply
|
||||
append the supported language pair as a postfix
|
||||
|
||||
```
|
||||
datasets:
|
||||
|
||||
@@ -4,23 +4,24 @@ currently dataset can be divided into 3 classes
|
||||
|
||||
- language knowledge
|
||||
|
||||
- summarization
|
||||
- summarization
|
||||
|
||||
- translation
|
||||
- translation
|
||||
|
||||
- dialogue : don't let user know you are a robot
|
||||
|
||||
- STEM : knowledge about the world
|
||||
|
||||
- coding
|
||||
|
||||
- world knowledge <= ideally we want to handle this via prefix context
|
||||
- coding
|
||||
|
||||
- world knowledge <= ideally we want to handle this via prefix context
|
||||
|
||||
Issues and TODO:
|
||||
|
||||
* as dataset are growing, how can we update this section less
|
||||
- as the datasets keep growing, how can we update this section less often?
|
||||
|
||||
* ideally we can update the config yaml and new dataset will be download from hub
|
||||
|
||||
* one possible idea is we upload the trasform format of these dataset to the OA hub
|
||||
- ideally we can update the config yaml and new datasets will be downloaded from the
|
||||
hub
|
||||
|
||||
- one possible idea is to upload the transformed format of these datasets to the
|
||||
OA hub
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
High level functions for model training
|
||||
"""
|
||||
from custom_datasets.prompt_dialogue import InstructionTuning, PromptGeneratedDataset
|
||||
from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, WebGPT, SODADialogue
|
||||
from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, SODADialogue, WebGPT
|
||||
from custom_datasets.summarization import SummarizationDataset
|
||||
from custom_datasets.toxic_conversation import ProsocialDialogue, ProsocialDialogueExplaination
|
||||
from custom_datasets.translation import WMT2019, DiveMT, TEDTalk
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
'''
|
||||
"""
|
||||
SFT dataset to reject toxic questions
|
||||
|
||||
'''
|
||||
"""
|
||||
import random
|
||||
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class ProsocialDialogueExplaination(Dataset):
|
||||
name = "prosocial_explain"
|
||||
TEMPLATE = [
|
||||
@@ -15,26 +17,31 @@ class ProsocialDialogueExplaination(Dataset):
|
||||
("'{0}' Why is this sentence {1}", "This sentence may need {1} because {0}"),
|
||||
("'{0}' Explain to me why we need to {1} of this sentence", "We need to be {1} of this sentence because {0}"),
|
||||
]
|
||||
def __init__(self, split='train', cache_dir='.cache') -> None:
|
||||
|
||||
def __init__(self, split="train", cache_dir=".cache") -> None:
|
||||
super().__init__()
|
||||
dataset = load_dataset("allenai/prosocial-dialog", cache_dir=cache_dir)[split]
|
||||
self.pairs = []
|
||||
for row in dataset:
|
||||
for safety_annotation, safe_answer in zip(row['safety_annotations'], row['safety_annotation_reasons']):
|
||||
for safety_annotation, safe_answer in zip(row["safety_annotations"], row["safety_annotation_reasons"]):
|
||||
(prompt_template, answer_template) = random.choice(self.TEMPLATE)
|
||||
self.pairs.append((
|
||||
prompt_template.format(row['context'],safety_annotation),
|
||||
answer_template.format( safe_answer, safety_annotation)
|
||||
))
|
||||
self.pairs.append(
|
||||
(
|
||||
prompt_template.format(row["context"], safety_annotation),
|
||||
answer_template.format(safe_answer, safety_annotation),
|
||||
)
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.pairs)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.pairs[idx]
|
||||
|
||||
|
||||
class ProsocialDialogue(Dataset):
|
||||
name = "prosocial_dialogue"
|
||||
'''
|
||||
"""
|
||||
ProsocialDialog, we set up a human-AI collaborative data creation framework,
|
||||
where GPT-3 generates the potentially unsafe utterances, and crowdworkers
|
||||
provide prosocial responses to them. This approach allows us to circumvent
|
||||
@@ -43,20 +50,16 @@ class ProsocialDialogue(Dataset):
|
||||
between humans
|
||||
(2) asking humans to write unethical, toxic, or problematic utterances could result
|
||||
in psychological harms (Roberts, 2017; Steiger et al., 2021).
|
||||
'''
|
||||
"""
|
||||
PREFIX = "<prefix>You are now a prosocial chatbot, be caution and casual when reply</prefix>"
|
||||
|
||||
|
||||
def __init__(self, split='train', cache_dir='.cache') -> None:
|
||||
def __init__(self, split="train", cache_dir=".cache") -> None:
|
||||
super().__init__()
|
||||
dataset = load_dataset("allenai/prosocial-dialog", cache_dir=cache_dir)[split]
|
||||
self.pairs = []
|
||||
for row in dataset:
|
||||
for answer in row['rots']:
|
||||
self.pairs.append((
|
||||
self.PREFIX+row['context'],
|
||||
answer
|
||||
))
|
||||
for answer in row["rots"]:
|
||||
self.pairs.append((self.PREFIX + row["context"], answer))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.pairs)
|
||||
|
||||
@@ -1,18 +1,19 @@
|
||||
'''
|
||||
"""
|
||||
List of translation dataset
|
||||
|
||||
GroNLP/divemt
|
||||
|
||||
fill in the blanks : https://huggingface.co/datasets/m_lama
|
||||
|
||||
'''
|
||||
"""
|
||||
import random
|
||||
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
# postfix prompt
|
||||
TRANSLATION_PROMPT = {
|
||||
"zh": [ # simplified or any chinese which was not mentioned
|
||||
"zh": [ # simplified or any chinese which was not mentioned
|
||||
"Translate to chinese simplified: {}",
|
||||
"{}, translate to chinese",
|
||||
"{} give me the chinese translation",
|
||||
@@ -20,7 +21,7 @@ TRANSLATION_PROMPT = {
|
||||
"{} 这句中文翻译怎麽写?",
|
||||
"我需要这句话的中文翻译: {}",
|
||||
],
|
||||
"zh-tw": [ # WMT code
|
||||
"zh-tw": [ # WMT code
|
||||
"{}. Translate to chinese traditional",
|
||||
"{}, translate to chinese",
|
||||
"{}. get chinese translation",
|
||||
@@ -47,7 +48,8 @@ TRANSLATION_PROMPT = {
|
||||
"{}. translate to french",
|
||||
"{} write in french",
|
||||
"{} french translation",
|
||||
"{} ,donnez moi la traduction française"],
|
||||
"{} ,donnez moi la traduction française",
|
||||
],
|
||||
"ko": [
|
||||
"{}. translate to Korean",
|
||||
"how do we write in korean: {}",
|
||||
@@ -59,8 +61,7 @@ TRANSLATION_PROMPT = {
|
||||
"{} how do we write in Malay",
|
||||
"{} give me the malay translation",
|
||||
"{} , berikan saya terjemahan dalam bahasa melayu",
|
||||
"{}, Jemahan di bahasa melayu"
|
||||
"{}, jemahkan ayat ini kepada bahasa melayu"
|
||||
# Two distinct Malay prompts. The comma between them is required: without it
# Python concatenates the adjacent string literals into a single template that
# contains two "{}" placeholders, and .format(text) with one argument fails.
"{}, Jemahan di bahasa melayu",
"{}, jemahkan ayat ini kepada bahasa melayu",
|
||||
],
|
||||
"en": ["{}. translate to english", "{} write in english", "english translation: '{}'"],
|
||||
"ru": ["помогите мне перевести это на русский : {}", "{} перевести на русский язык", "russian translation: '{}'"],
|
||||
@@ -70,6 +71,8 @@ TRANSLATION_PROMPT = {
|
||||
"vi": ["{}. Dịch sang tiếng việt nam", "{} write in vietnamese", "vietnamese translation: '{}'"],
|
||||
"ar": ["{}. translate to arabic", "{} write in arabic", "arabic translation: '{}'"],
|
||||
}
|
||||
|
||||
|
||||
class TranslationPair(Dataset):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
@@ -80,79 +83,60 @@ class TranslationPair(Dataset):
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.pairs[index]
|
||||
|
||||
|
||||
|
||||
class WMT2019(TranslationPair):
|
||||
|
||||
def __init__(self, pair='zh-en', split='train') -> None:
|
||||
def __init__(self, pair="zh-en", split="train") -> None:
|
||||
super().__init__()
|
||||
dataset = load_dataset('wmt19', pair)[split]
|
||||
dataset = load_dataset("wmt19", pair)[split]
|
||||
self.pairs = []
|
||||
src, tgt = pair.split('-')
|
||||
src, tgt = pair.split("-")
|
||||
for row in dataset:
|
||||
row = row['translation']
|
||||
row = row["translation"]
|
||||
if random.random() > 0.5:
|
||||
source = random.choice(
|
||||
TRANSLATION_PROMPT[tgt]
|
||||
).format(row[src])
|
||||
source = random.choice(TRANSLATION_PROMPT[tgt]).format(row[src])
|
||||
self.pairs.append((source, row[tgt]))
|
||||
else:# translating in reverse direction
|
||||
source = random.choice(
|
||||
TRANSLATION_PROMPT[src]
|
||||
).format(row[tgt])
|
||||
else: # translating in reverse direction
|
||||
source = random.choice(TRANSLATION_PROMPT[src]).format(row[tgt])
|
||||
self.pairs.append((source, row[src]))
|
||||
|
||||
|
||||
class DiveMT(TranslationPair):
|
||||
|
||||
REMAP = {
|
||||
'tur': 'tr',
|
||||
'ita': 'it',
|
||||
'ukr': 'uk',
|
||||
'nld': 'nl',
|
||||
'vie': 'vi',
|
||||
'ara': 'ar'
|
||||
}
|
||||
REMAP = {"tur": "tr", "ita": "it", "ukr": "uk", "nld": "nl", "vie": "vi", "ara": "ar"}
|
||||
|
||||
def __init__(self, split='train') -> None:
|
||||
def __init__(self, split="train") -> None:
|
||||
super().__init__()
|
||||
dataset = load_dataset('GroNLP/divemt', 'main')[split]
|
||||
tgt, src = 'tgt_text', 'src_text'
|
||||
dataset = load_dataset("GroNLP/divemt", "main")[split]
|
||||
tgt, src = "tgt_text", "src_text"
|
||||
for row in dataset:
|
||||
# ISO 639-2
|
||||
lang_code_2 = row['subject_id'].split('_')[0]
|
||||
lang_code_2 = row["subject_id"].split("_")[0]
|
||||
lang_code = self.REMAP[lang_code_2]
|
||||
if lang_code not in TRANSLATION_PROMPT:
|
||||
continue
|
||||
|
||||
if random.random() > 0.5:
|
||||
source = random.choice(
|
||||
TRANSLATION_PROMPT[lang_code]
|
||||
).format(row[src])
|
||||
source = random.choice(TRANSLATION_PROMPT[lang_code]).format(row[src])
|
||||
self.pairs.append((source, row[tgt]))
|
||||
else:# translating in reverse direction
|
||||
lang_code = 'en'
|
||||
source = random.choice(
|
||||
TRANSLATION_PROMPT[lang_code]
|
||||
).format(row[tgt])
|
||||
else: # translating in reverse direction
|
||||
lang_code = "en"
|
||||
source = random.choice(TRANSLATION_PROMPT[lang_code]).format(row[tgt])
|
||||
self.pairs.append((source, row[src]))
|
||||
|
||||
|
||||
class TEDTalk(TranslationPair):
|
||||
# NOTE: DO NOT use chinese pair, mix with traditional and cantonese, not clean
|
||||
|
||||
def __init__(self, pair='de-ja', split='train', year='2016') -> None:
|
||||
def __init__(self, pair="de-ja", split="train", year="2016") -> None:
|
||||
super().__init__()
|
||||
dataset = load_dataset('ted_talks_iwslt', language_pair=pair.split('-'), year=year)[split]
|
||||
src, tgt = pair.split('-')
|
||||
dataset = load_dataset("ted_talks_iwslt", language_pair=pair.split("-"), year=year)[split]
|
||||
src, tgt = pair.split("-")
|
||||
for row in dataset:
|
||||
row = row['translation']
|
||||
row = row["translation"]
|
||||
if random.random() > 0.5:
|
||||
source = random.choice(
|
||||
TRANSLATION_PROMPT[tgt]
|
||||
).format(row[src])
|
||||
source = random.choice(TRANSLATION_PROMPT[tgt]).format(row[src])
|
||||
self.pairs.append((source, row[tgt]))
|
||||
else:# translating in reverse direction
|
||||
source = random.choice(
|
||||
TRANSLATION_PROMPT[src]
|
||||
).format(row[tgt])
|
||||
else: # translating in reverse direction
|
||||
source = random.choice(TRANSLATION_PROMPT[src]).format(row[tgt])
|
||||
self.pairs.append((source, row[src]))
|
||||
|
||||
Reference in New Issue
Block a user