mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
[feature] Add mix conversation augmentation
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
from urllib.request import urlopen
|
||||
|
||||
@@ -115,7 +116,7 @@ class QADataset(Dataset):
|
||||
"reddit_asks": {"name": "eli5", "index_fn": index_eli5, "split_postfix": "_asks"},
|
||||
}
|
||||
|
||||
def __init__(self, dataset, cache_dir, split):
|
||||
def __init__(self, dataset, cache_dir, split, mix_prob=0.2):
|
||||
self.no_val = False
|
||||
if dataset in self.DATASET_FORMAT_MAPPING:
|
||||
context = self.DATASET_FORMAT_MAPPING[dataset]
|
||||
@@ -137,11 +138,25 @@ class QADataset(Dataset):
|
||||
self.dataset = load_dataset(context["name"], **context["params"])
|
||||
else:
|
||||
raise ValueError("Unknown dataset : " + dataset)
|
||||
self.length = len(self.dataset)
|
||||
self.mix_prob = mix_prob
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
return self.length
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if self.mix_prob > 0 and random.random() < self.mix_prob and idx > 5 and idx < (self.length - 5):
|
||||
|
||||
additional = random.randint(0, 10) - 5
|
||||
while additional == idx:
|
||||
additional = random.randint(0, 10) - 5
|
||||
|
||||
answer_pair = self.index_fn(self.dataset[additional + idx])
|
||||
history_text = "".join(format_pair(answer_pair))
|
||||
question, answer = self.index_fn(self.dataset[idx])
|
||||
question = history_text + question
|
||||
return format_pair((question, answer))
|
||||
|
||||
data = self.dataset[idx]
|
||||
return format_pair(self.index_fn(data))
|
||||
|
||||
@@ -297,8 +312,9 @@ class JokeExplaination(Dataset):
|
||||
name = "joke"
|
||||
url = "https://gist.github.com/theblackcat102/42b697e24a13fdb499e20edfbf618361/raw/1834dca207898c15f93b809d1195f6f6e47c9e1e/joke_explained.jsonl"
|
||||
|
||||
def __init__(self, cache_dir) -> None:
|
||||
def __init__(self, cache_dir, mix_prob=0.2) -> None:
|
||||
super().__init__()
|
||||
self.mix_prob = mix_prob
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
joke_explain_filename = os.path.join(cache_dir, "joke_explaination.jsonl")
|
||||
if not os.path.exists(joke_explain_filename):
|
||||
@@ -319,9 +335,62 @@ class JokeExplaination(Dataset):
|
||||
|
||||
if len(question) > 0 and len(answer) > 0:
|
||||
self.pairs.append((question, answer))
|
||||
self.length = len(self.pairs)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.pairs)
|
||||
return self.length
|
||||
|
||||
def __getitem__(self, index):
|
||||
if random.random() < self.mix_prob and index > 5 and index < (self.length - 5):
|
||||
additional = random.randint(0, 10) - 5
|
||||
while additional == index:
|
||||
additional = random.randint(0, 10) - 5
|
||||
|
||||
history_text = "".join(format_pair(self.pairs[additional + index]))
|
||||
question, answer = self.pairs[index]
|
||||
question = history_text + question
|
||||
return format_pair((question, answer))
|
||||
|
||||
return format_pair(self.pairs[index])
|
||||
|
||||
|
||||
class TranslatedQA(Dataset):
|
||||
|
||||
name = "oa_translated"
|
||||
|
||||
def __init__(self, cache_dir, mix_prob=0.2) -> None:
|
||||
super().__init__()
|
||||
self.mix_prob = mix_prob
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
path = os.path.join(cache_dir, "oa_translated")
|
||||
os.makedirs(path, exist_ok=True)
|
||||
import glob
|
||||
|
||||
self.pairs = []
|
||||
for translated_jsonl in glob.glob(os.path.join(path, "*.jsonl")):
|
||||
with open(translated_jsonl, "r") as f:
|
||||
for line in f:
|
||||
data = json.loads(line)
|
||||
if "Python " in data["text"]:
|
||||
continue
|
||||
# incorrect, TODO: fix later
|
||||
for convo_round in data["translate"]:
|
||||
self.pairs.append((convo_round["human"], convo_round["answer"]))
|
||||
|
||||
self.length = len(self.pairs)
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
def __getitem__(self, index):
|
||||
if random.random() < self.mix_prob and index > 5 and index < (self.length - 5):
|
||||
additional = random.randint(0, 10) - 5
|
||||
while additional == index:
|
||||
additional = random.randint(0, 10) - 5
|
||||
|
||||
history_text = "".join(format_pair(self.pairs[additional + index]))
|
||||
question, answer = self.pairs[index]
|
||||
question = history_text + question
|
||||
return format_pair((question, answer))
|
||||
|
||||
return format_pair(self.pairs[index])
|
||||
|
||||
@@ -57,7 +57,7 @@ def index_summary_merge(text, summary):
|
||||
class SummarizationDataset(Dataset):
|
||||
def __init__(self, dataset, cache_dir, split, max_words=512):
|
||||
self.name = dataset
|
||||
if summarization_config_mapping[dataset][0] in ["billsum", "tldr_news"] & split == "validation":
|
||||
if (dataset in ["billsum", "tldr_news"]) and (split == "validation"):
|
||||
split = "test"
|
||||
self.dataset = load_dataset(*summarization_config_mapping[dataset], cache_dir=cache_dir, split=split)
|
||||
self.text_column, self.summary_column = summarization_name_mapping[dataset]
|
||||
|
||||
@@ -75,20 +75,34 @@ TRANSLATION_PROMPT = {
|
||||
|
||||
|
||||
class TranslationPair(Dataset):
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, mix_prob=0.2) -> None:
|
||||
super().__init__()
|
||||
self.pairs = []
|
||||
self.length = -1
|
||||
self.mix_prob = mix_prob
|
||||
|
||||
def __len__(self):
|
||||
if self.length < 0:
|
||||
self.length = len(self.pairs)
|
||||
return len(self.pairs)
|
||||
|
||||
def __getitem__(self, index):
|
||||
if random.random() < self.mix_prob and index > 5 and index < (self.length - 5):
|
||||
additional = random.randint(0, 10) - 5
|
||||
while additional == index:
|
||||
additional = random.randint(0, 10) - 5
|
||||
|
||||
history_text = "".join(format_pair(self.pairs[additional + index]))
|
||||
question, answer = self.pairs[index]
|
||||
question = history_text + question
|
||||
return format_pair((question, answer))
|
||||
|
||||
return format_pair(self.pairs[index])
|
||||
|
||||
|
||||
class WMT2019(TranslationPair):
|
||||
def __init__(self, pair="zh-en", split="train") -> None:
|
||||
super().__init__()
|
||||
def __init__(self, pair="zh-en", split="train", mix_prob=0.2) -> None:
|
||||
super().__init__(mix_prob=mix_prob)
|
||||
dataset = load_dataset("wmt19", pair)[split]
|
||||
self.pairs = []
|
||||
src, tgt = pair.split("-")
|
||||
@@ -108,8 +122,8 @@ class DiveMT(TranslationPair):
|
||||
|
||||
REMAP = {"tur": "tr", "ita": "it", "ukr": "uk", "nld": "nl", "vie": "vi", "ara": "ar"}
|
||||
|
||||
def __init__(self, split="train") -> None:
|
||||
super().__init__()
|
||||
def __init__(self, split="train", mix_prob=0.2) -> None:
|
||||
super().__init__(mix_prob=mix_prob)
|
||||
dataset = load_dataset("GroNLP/divemt", "main")[split]
|
||||
tgt, src = "tgt_text", "src_text"
|
||||
for row in dataset:
|
||||
@@ -131,8 +145,8 @@ class DiveMT(TranslationPair):
|
||||
class TEDTalk(TranslationPair):
|
||||
# NOTE: DO NOT use chinese pair, mix with traditional and cantonese, not clean
|
||||
|
||||
def __init__(self, pair="de-ja", split="train", year="2016") -> None:
|
||||
super().__init__()
|
||||
def __init__(self, pair="de-ja", split="train", year="2016", mix_prob=0.2) -> None:
|
||||
super().__init__(mix_prob=mix_prob)
|
||||
dataset = load_dataset("ted_talks_iwslt", language_pair=pair.split("-"), year=year)[split]
|
||||
src, tgt = pair.split("-")
|
||||
for row in dataset:
|
||||
|
||||
Reference in New Issue
Block a user