Files
Open-Assistant/model/supervised_finetuning/custom_datasets/translation.py
T
2023-01-22 00:56:17 +00:00

146 lines
5.4 KiB
Python

"""
List of translation datasets
GroNLP/divemt
fill in the blanks : https://huggingface.co/datasets/m_lama
"""
import random
from custom_datasets.formatting import format_pair
from datasets import load_dataset
from torch.utils.data import Dataset
# Prompt templates keyed by the language code of the language to translate INTO.
# Every template must contain exactly one "{}" placeholder: the dataset classes
# pick a template with random.choice() and call .format(text) with a single
# argument, so a template with two placeholders raises IndexError at runtime.
TRANSLATION_PROMPT = {
    "zh": [  # simplified or any chinese which was not mentioned
        "Translate to chinese simplified: {}",
        "{}, translate to chinese",
        "{} give me the chinese translation",
        "翻译成中文: {}",
        "{} 这句中文翻译怎麽写?",
        "我需要这句话的中文翻译: {}",
    ],
    "zh-tw": [  # WMT code
        "{}. Translate to chinese traditional",
        "{}, translate to chinese",
        "{}. get chinese translation",
        "中文翻譯: {}",
        "幫我翻譯成中文: '{}'",
        "{} 這句中文翻譯怎麼寫?",
    ],
    "ja": [
        "{}: help me translate to japanese",
        "Need japanese translation: {}",
        "{}: にほんごやくをよこす",
        "{}: にほんごやくをおくれ",
        "{}: にほんごやくを じょす",
        "give me the japanese translation, {}",
    ],
    "de": [
        "{}: translate to german",
        "give me the german translation {}",
        "I want german translation {}",
        "{}, ins Deutsche übersetzen",
        "{}, Übersetzen ins Deutsche",
    ],
    "fr": [
        "{}. translate to french",
        "{} write in french",
        "{} french translation",
        "{} ,donnez moi la traduction française",
    ],
    "ko": [
        "{}. translate to Korean",
        "how do we write in korean: {}",
        "give me the korean translation: {}",
        "{}, 한국어 번역을 해주세요",
    ],
    "ms": [
        "{} translate to malay",
        "{} how do we write in Malay",
        "{} give me the malay translation",
        "{} , berikan saya terjemahan dalam bahasa melayu",
        # These two entries were previously fused by implicit string
        # concatenation (missing comma), producing one two-placeholder template.
        "{}, Jemahan di bahasa melayu",
        "{}, jemahkan ayat ini kepada bahasa melayu",
    ],
    "en": ["{}. translate to english", "{} write in english", "english translation: '{}'"],
    "ru": ["помогите мне перевести это на русский : {}", "{} перевести на русский язык", "russian translation: '{}'"],
    "tr": ["{}. türkçeye çevi̇ri̇n", "{} write in turkish", "turkish translation: '{}'", "türkçeye çevi̇rmek: {}"],
    "it": ["{}. translate to italian", "{} write in italian", "italian translation: '{}'"],
    "nl": ["{}. translate to dutch", "{} write in dutch", "dutch translation: '{}'"],
    "vi": ["{}. Dịch sang tiếng việt nam", "{} write in vietnamese", "vietnamese translation: '{}'"],
    "ar": ["{}. translate to arabic", "{} write in arabic", "arabic translation: '{}'"],
}
class TranslationPair(Dataset):
    """Base dataset that stores translation examples in ``self.pairs``.

    The subclasses in this module append ``(prompt, translation)`` tuples to
    ``self.pairs``; each item is passed through ``format_pair`` when accessed.
    """

    def __init__(self) -> None:
        super().__init__()
        # Starts empty; populated by subclass constructors.
        self.pairs: list = []

    def __len__(self) -> int:
        """Return the number of stored pairs."""
        stored = self.pairs
        return len(stored)

    def __getitem__(self, index):
        """Return the pair at ``index`` after applying ``format_pair`` to it."""
        raw_pair = self.pairs[index]
        return format_pair(raw_pair)
class WMT2019(TranslationPair):
    """Prompted translation pairs built from one WMT19 language pair.

    Each row contributes one ``(prompt, translation)`` tuple. The translation
    direction is chosen at random per row, and the prompt template is drawn
    from ``TRANSLATION_PROMPT`` for the language being translated into.

    Args:
        pair: WMT19 config name of the form ``"<src>-<tgt>"``; both codes must
            be keys of ``TRANSLATION_PROMPT``.
        split: Name of the dataset split to read.
        max_pairs: Upper bound on the number of pairs collected. ``None``
            disables the cap.

    Raises:
        KeyError: If a language code from ``pair`` has no entry in
            ``TRANSLATION_PROMPT``.
    """

    def __init__(self, pair="zh-en", split="train", max_pairs=100000) -> None:
        super().__init__()
        dataset = load_dataset("wmt19", pair)[split]
        src, tgt = pair.split("-")
        for row in dataset:
            # Check the cap before adding a pair so the dataset never holds more
            # than max_pairs entries (the old post-append ``>`` check let one
            # extra pair through).
            if max_pairs is not None and len(self.pairs) >= max_pairs:
                break
            # row["translation"] is indexed by the language codes from ``pair``.
            row = row["translation"]
            if random.random() > 0.5:
                source = random.choice(TRANSLATION_PROMPT[tgt]).format(row[src])
                self.pairs.append((source, row[tgt]))
            else:  # translating in reverse direction
                source = random.choice(TRANSLATION_PROMPT[src]).format(row[tgt])
                self.pairs.append((source, row[src]))
class DiveMT(TranslationPair):
    """Prompted translation pairs built from the ``GroNLP/divemt`` dataset.

    The language of each row is taken from the prefix of its ``subject_id``
    and mapped through ``REMAP``; rows whose mapped code has no entry in
    ``TRANSLATION_PROMPT`` are skipped. The direction is chosen at random per
    row, with ``"en"`` prompts used for the reverse direction.
    """

    # Three-letter code (prefix of ``subject_id``) -> key used in TRANSLATION_PROMPT.
    REMAP = {"tur": "tr", "ita": "it", "ukr": "uk", "nld": "nl", "vie": "vi", "ara": "ar"}

    def __init__(self, split="train") -> None:
        super().__init__()
        records = load_dataset("GroNLP/divemt", "main")[split]
        src_key, tgt_key = "src_text", "tgt_text"
        for record in records:
            # ISO 639-2
            three_letter = record["subject_id"].split("_")[0]
            target_lang = self.REMAP[three_letter]
            if target_lang in TRANSLATION_PROMPT:
                if random.random() > 0.5:
                    prompt_lang, text_key, answer_key = target_lang, src_key, tgt_key
                else:
                    # Reverse direction: feed the target text and ask for English.
                    prompt_lang, text_key, answer_key = "en", tgt_key, src_key
                template = random.choice(TRANSLATION_PROMPT[prompt_lang])
                prompt = template.format(record[text_key])
                self.pairs.append((prompt, record[answer_key]))
class TEDTalk(TranslationPair):
    """Prompted translation pairs built from the ``ted_talks_iwslt`` dataset.

    ``pair`` has the form ``"<lang_a>-<lang_b>"``. For every row one of the two
    directions is picked at random and a prompt template for the language
    being translated into is drawn from ``TRANSLATION_PROMPT``.
    """

    # NOTE: DO NOT use chinese pair, mix with traditional and cantonese, not clean
    def __init__(self, pair="de-ja", split="train", year="2016") -> None:
        super().__init__()
        examples = load_dataset("ted_talks_iwslt", language_pair=pair.split("-"), year=year)[split]
        first_lang, second_lang = pair.split("-")
        for example in examples:
            texts = example["translation"]
            if random.random() > 0.5:
                into, out_of = second_lang, first_lang
            else:
                # Reverse direction: translate from the second language into the first.
                into, out_of = first_lang, second_lang
            template = random.choice(TRANSLATION_PROMPT[into])
            prompt = template.format(texts[out_of])
            self.pairs.append((prompt, texts[into]))