[feature] Add mix conversation augmentation

This commit is contained in:
theblackcat102
2023-02-01 22:14:11 +00:00
parent 638d8c1572
commit f8eba68544
3 changed files with 95 additions and 12 deletions
@@ -3,6 +3,7 @@
"""
import json
import os
import random
import re
from urllib.request import urlopen
@@ -115,7 +116,7 @@ class QADataset(Dataset):
"reddit_asks": {"name": "eli5", "index_fn": index_eli5, "split_postfix": "_asks"},
}
def __init__(self, dataset, cache_dir, split):
def __init__(self, dataset, cache_dir, split, mix_prob=0.2):
self.no_val = False
if dataset in self.DATASET_FORMAT_MAPPING:
context = self.DATASET_FORMAT_MAPPING[dataset]
@@ -137,11 +138,25 @@ class QADataset(Dataset):
self.dataset = load_dataset(context["name"], **context["params"])
else:
raise ValueError("Unknown dataset : " + dataset)
self.length = len(self.dataset)
self.mix_prob = mix_prob
def __len__(self):
return len(self.dataset)
return self.length
def __getitem__(self, idx):
if self.mix_prob > 0 and random.random() < self.mix_prob and idx > 5 and idx < (self.length - 5):
additional = random.randint(0, 10) - 5
while additional == idx:
additional = random.randint(0, 10) - 5
answer_pair = self.index_fn(self.dataset[additional + idx])
history_text = "".join(format_pair(answer_pair))
question, answer = self.index_fn(self.dataset[idx])
question = history_text + question
return format_pair((question, answer))
data = self.dataset[idx]
return format_pair(self.index_fn(data))
@@ -297,8 +312,9 @@ class JokeExplaination(Dataset):
name = "joke"
url = "https://gist.github.com/theblackcat102/42b697e24a13fdb499e20edfbf618361/raw/1834dca207898c15f93b809d1195f6f6e47c9e1e/joke_explained.jsonl"
def __init__(self, cache_dir) -> None:
def __init__(self, cache_dir, mix_prob=0.2) -> None:
super().__init__()
self.mix_prob = mix_prob
os.makedirs(cache_dir, exist_ok=True)
joke_explain_filename = os.path.join(cache_dir, "joke_explaination.jsonl")
if not os.path.exists(joke_explain_filename):
@@ -319,9 +335,62 @@ class JokeExplaination(Dataset):
if len(question) > 0 and len(answer) > 0:
self.pairs.append((question, answer))
self.length = len(self.pairs)
def __len__(self):
return len(self.pairs)
return self.length
def __getitem__(self, index):
if random.random() < self.mix_prob and index > 5 and index < (self.length - 5):
additional = random.randint(0, 10) - 5
while additional == index:
additional = random.randint(0, 10) - 5
history_text = "".join(format_pair(self.pairs[additional + index]))
question, answer = self.pairs[index]
question = history_text + question
return format_pair((question, answer))
return format_pair(self.pairs[index])
class TranslatedQA(Dataset):
name = "oa_translated"
def __init__(self, cache_dir, mix_prob=0.2) -> None:
super().__init__()
self.mix_prob = mix_prob
os.makedirs(cache_dir, exist_ok=True)
path = os.path.join(cache_dir, "oa_translated")
os.makedirs(path, exist_ok=True)
import glob
self.pairs = []
for translated_jsonl in glob.glob(os.path.join(path, "*.jsonl")):
with open(translated_jsonl, "r") as f:
for line in f:
data = json.loads(line)
if "Python " in data["text"]:
continue
# incorrect, TODO: fix later
for convo_round in data["translate"]:
self.pairs.append((convo_round["human"], convo_round["answer"]))
self.length = len(self.pairs)
def __len__(self):
return self.length
def __getitem__(self, index):
if random.random() < self.mix_prob and index > 5 and index < (self.length - 5):
additional = random.randint(0, 10) - 5
while additional == index:
additional = random.randint(0, 10) - 5
history_text = "".join(format_pair(self.pairs[additional + index]))
question, answer = self.pairs[index]
question = history_text + question
return format_pair((question, answer))
return format_pair(self.pairs[index])
@@ -57,7 +57,7 @@ def index_summary_merge(text, summary):
class SummarizationDataset(Dataset):
def __init__(self, dataset, cache_dir, split, max_words=512):
self.name = dataset
if summarization_config_mapping[dataset][0] in ["billsum", "tldr_news"] & split == "validation":
if (dataset in ["billsum", "tldr_news"]) and (split == "validation"):
split = "test"
self.dataset = load_dataset(*summarization_config_mapping[dataset], cache_dir=cache_dir, split=split)
self.text_column, self.summary_column = summarization_name_mapping[dataset]
@@ -75,20 +75,34 @@ TRANSLATION_PROMPT = {
class TranslationPair(Dataset):
def __init__(self) -> None:
def __init__(self, mix_prob=0.2) -> None:
super().__init__()
self.pairs = []
self.length = -1
self.mix_prob = mix_prob
def __len__(self):
if self.length < 0:
self.length = len(self.pairs)
return len(self.pairs)
def __getitem__(self, index):
if random.random() < self.mix_prob and index > 5 and index < (self.length - 5):
additional = random.randint(0, 10) - 5
while additional == index:
additional = random.randint(0, 10) - 5
history_text = "".join(format_pair(self.pairs[additional + index]))
question, answer = self.pairs[index]
question = history_text + question
return format_pair((question, answer))
return format_pair(self.pairs[index])
class WMT2019(TranslationPair):
def __init__(self, pair="zh-en", split="train") -> None:
super().__init__()
def __init__(self, pair="zh-en", split="train", mix_prob=0.2) -> None:
super().__init__(mix_prob=mix_prob)
dataset = load_dataset("wmt19", pair)[split]
self.pairs = []
src, tgt = pair.split("-")
@@ -108,8 +122,8 @@ class DiveMT(TranslationPair):
REMAP = {"tur": "tr", "ita": "it", "ukr": "uk", "nld": "nl", "vie": "vi", "ara": "ar"}
def __init__(self, split="train") -> None:
super().__init__()
def __init__(self, split="train", mix_prob=0.2) -> None:
super().__init__(mix_prob=mix_prob)
dataset = load_dataset("GroNLP/divemt", "main")[split]
tgt, src = "tgt_text", "src_text"
for row in dataset:
@@ -131,8 +145,8 @@ class DiveMT(TranslationPair):
class TEDTalk(TranslationPair):
# NOTE: DO NOT use chinese pair, mix with traditional and cantonese, not clean
def __init__(self, pair="de-ja", split="train", year="2016") -> None:
super().__init__()
def __init__(self, pair="de-ja", split="train", year="2016", mix_prob=0.2) -> None:
super().__init__(mix_prob=mix_prob)
dataset = load_dataset("ted_talks_iwslt", language_pair=pair.split("-"), year=year)[split]
src, tgt = pair.split("-")
for row in dataset: