diff --git a/model/supervised_finetuning/custom_datasets/qa_datasets.py b/model/supervised_finetuning/custom_datasets/qa_datasets.py index 2c5c7ee2..2acf9106 100644 --- a/model/supervised_finetuning/custom_datasets/qa_datasets.py +++ b/model/supervised_finetuning/custom_datasets/qa_datasets.py @@ -3,6 +3,7 @@ """ import json import os +import random import re from urllib.request import urlopen @@ -115,7 +116,7 @@ class QADataset(Dataset): "reddit_asks": {"name": "eli5", "index_fn": index_eli5, "split_postfix": "_asks"}, } - def __init__(self, dataset, cache_dir, split): + def __init__(self, dataset, cache_dir, split, mix_prob=0.2): self.no_val = False if dataset in self.DATASET_FORMAT_MAPPING: context = self.DATASET_FORMAT_MAPPING[dataset] @@ -137,11 +138,25 @@ class QADataset(Dataset): self.dataset = load_dataset(context["name"], **context["params"]) else: raise ValueError("Unknown dataset : " + dataset) + self.length = len(self.dataset) + self.mix_prob = mix_prob def __len__(self): - return len(self.dataset) + return self.length def __getitem__(self, idx): + if self.mix_prob > 0 and random.random() < self.mix_prob and idx > 5 and idx < (self.length - 5): + + additional = random.randint(0, 10) - 5 + while additional == idx: + additional = random.randint(0, 10) - 5 + + answer_pair = self.index_fn(self.dataset[additional + idx]) + history_text = "".join(format_pair(answer_pair)) + question, answer = self.index_fn(self.dataset[idx]) + question = history_text + question + return format_pair((question, answer)) + data = self.dataset[idx] return format_pair(self.index_fn(data)) @@ -297,8 +312,9 @@ class JokeExplaination(Dataset): name = "joke" url = "https://gist.github.com/theblackcat102/42b697e24a13fdb499e20edfbf618361/raw/1834dca207898c15f93b809d1195f6f6e47c9e1e/joke_explained.jsonl" - def __init__(self, cache_dir) -> None: + def __init__(self, cache_dir, mix_prob=0.2) -> None: super().__init__() + self.mix_prob = mix_prob os.makedirs(cache_dir, exist_ok=True) joke_explain_filename = os.path.join(cache_dir, "joke_explaination.jsonl") if not os.path.exists(joke_explain_filename): @@ -319,9 +335,62 @@ class JokeExplaination(Dataset): if len(question) > 0 and len(answer) > 0: self.pairs.append((question, answer)) + self.length = len(self.pairs) def __len__(self): - return len(self.pairs) + return self.length def __getitem__(self, index): + if random.random() < self.mix_prob and index > 5 and index < (self.length - 5): + additional = random.randint(0, 10) - 5 + while additional == index: + additional = random.randint(0, 10) - 5 + + history_text = "".join(format_pair(self.pairs[additional + index])) + question, answer = self.pairs[index] + question = history_text + question + return format_pair((question, answer)) + + return format_pair(self.pairs[index]) + + +class TranslatedQA(Dataset): + + name = "oa_translated" + + def __init__(self, cache_dir, mix_prob=0.2) -> None: + super().__init__() + self.mix_prob = mix_prob + os.makedirs(cache_dir, exist_ok=True) + path = os.path.join(cache_dir, "oa_translated") + os.makedirs(path, exist_ok=True) + import glob + + self.pairs = [] + for translated_jsonl in glob.glob(os.path.join(path, "*.jsonl")): + with open(translated_jsonl, "r") as f: + for line in f: + data = json.loads(line) + if "Python " in data["text"]: + continue + # incorrect, TODO: fix later + for convo_round in data["translate"]: + self.pairs.append((convo_round["human"], convo_round["answer"])) + + self.length = len(self.pairs) + + def __len__(self): + return self.length + + def __getitem__(self, index): + if random.random() < self.mix_prob and index > 5 and index < (self.length - 5): + additional = random.randint(0, 10) - 5 + while additional == index: + additional = random.randint(0, 10) - 5 + + history_text = "".join(format_pair(self.pairs[additional + index])) + question, answer = self.pairs[index] + question = history_text + question + return format_pair((question, answer)) + return format_pair(self.pairs[index]) diff --git a/model/supervised_finetuning/custom_datasets/summarization.py b/model/supervised_finetuning/custom_datasets/summarization.py index 834fa16c..aa644691 100644 --- a/model/supervised_finetuning/custom_datasets/summarization.py +++ b/model/supervised_finetuning/custom_datasets/summarization.py @@ -57,7 +57,7 @@ def index_summary_merge(text, summary): class SummarizationDataset(Dataset): def __init__(self, dataset, cache_dir, split, max_words=512): self.name = dataset - if summarization_config_mapping[dataset][0] in ["billsum", "tldr_news"] & split == "validation": + if (dataset in ["billsum", "tldr_news"]) and (split == "validation"): split = "test" self.dataset = load_dataset(*summarization_config_mapping[dataset], cache_dir=cache_dir, split=split) self.text_column, self.summary_column = summarization_name_mapping[dataset] diff --git a/model/supervised_finetuning/custom_datasets/translation.py b/model/supervised_finetuning/custom_datasets/translation.py index f9a71a8e..18cb9a09 100644 --- a/model/supervised_finetuning/custom_datasets/translation.py +++ b/model/supervised_finetuning/custom_datasets/translation.py @@ -75,20 +75,34 @@ TRANSLATION_PROMPT = { class TranslationPair(Dataset): - def __init__(self) -> None: + def __init__(self, mix_prob=0.2) -> None: super().__init__() self.pairs = [] + self.length = -1 + self.mix_prob = mix_prob def __len__(self): + if self.length < 0: + self.length = len(self.pairs) return len(self.pairs) def __getitem__(self, index): + if random.random() < self.mix_prob and index > 5 and index < (self.length - 5): + additional = random.randint(0, 10) - 5 + while additional == index: + additional = random.randint(0, 10) - 5 + + history_text = "".join(format_pair(self.pairs[additional + index])) + question, answer = self.pairs[index] + question = history_text + question + return format_pair((question, answer)) + return format_pair(self.pairs[index]) class WMT2019(TranslationPair): - def __init__(self, pair="zh-en", split="train") -> None: - super().__init__() + def __init__(self, pair="zh-en", split="train", mix_prob=0.2) -> None: + super().__init__(mix_prob=mix_prob) dataset = load_dataset("wmt19", pair)[split] self.pairs = [] src, tgt = pair.split("-") @@ -108,8 +122,8 @@ class DiveMT(TranslationPair): REMAP = {"tur": "tr", "ita": "it", "ukr": "uk", "nld": "nl", "vie": "vi", "ara": "ar"} - def __init__(self, split="train") -> None: - super().__init__() + def __init__(self, split="train", mix_prob=0.2) -> None: + super().__init__(mix_prob=mix_prob) dataset = load_dataset("GroNLP/divemt", "main")[split] tgt, src = "tgt_text", "src_text" for row in dataset: @@ -131,8 +145,8 @@ class DiveMT(TranslationPair): class TEDTalk(TranslationPair): # NOTE: DO NOT use chinese pair, mix with traditional and cantonese, not clean - def __init__(self, pair="de-ja", split="train", year="2016") -> None: - super().__init__() + def __init__(self, pair="de-ja", split="train", year="2016", mix_prob=0.2) -> None: + super().__init__(mix_prob=mix_prob) dataset = load_dataset("ted_talks_iwslt", language_pair=pair.split("-"), year=year)[split] src, tgt = pair.split("-") for row in dataset: