diff --git a/model/supervised_finetuning/custom_datasets/__init__.py b/model/supervised_finetuning/custom_datasets/__init__.py index 2e1e4b30..558ec502 100644 --- a/model/supervised_finetuning/custom_datasets/__init__.py +++ b/model/supervised_finetuning/custom_datasets/__init__.py @@ -43,7 +43,7 @@ def get_one_dataset(conf, dataset_name): if dataset_name == "debate_sum": train, eval = train_val_dataset(train, val_split=0.2) else: - val_name = "validation" if dataset_name not in ["billsum"] else "test" + val_name = "validation" if dataset_name not in ["billsum", "tldr_news"] else "test" eval = SummarizationDataset(dataset_name, conf.cache_dir, val_name) elif "ted_trans" in dataset_name: language_pair = dataset_name.split("_")[-1] diff --git a/model/supervised_finetuning/custom_datasets/dialogue_collator.py b/model/supervised_finetuning/custom_datasets/dialogue_collator.py index 719fa0d6..c96ed576 100644 --- a/model/supervised_finetuning/custom_datasets/dialogue_collator.py +++ b/model/supervised_finetuning/custom_datasets/dialogue_collator.py @@ -3,7 +3,6 @@ from typing import Optional, Union import numpy as np import torch -from custom_datasets.qa_datasets import QA_SPECIAL_TOKENS from torch.nn import functional as F from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase @@ -23,15 +22,8 @@ class DialogueDataCollator: flatten_messages = [] label_masks = [] - for feature_one in features: - assert len(feature_one) % 2 == 0, "Number of messages must be even" - # TODO: we should push this to dataset __getitem__ - messages = [ - (QA_SPECIAL_TOKENS["Question"] if i % 2 == 0 else "") - + x - + (QA_SPECIAL_TOKENS["Answer"] if i % 2 == 0 else "") - for i, x in enumerate(feature_one) - ] + for messages in features: + messages = list(messages) # Add a way for the model to terminate generation # When we predict the start of a new expected question, we want to be able to stop generation diff --git a/model/supervised_finetuning/custom_datasets/formatting.py b/model/supervised_finetuning/custom_datasets/formatting.py new file mode 100644 index 00000000..a6c1c0d8 --- /dev/null +++ b/model/supervised_finetuning/custom_datasets/formatting.py @@ -0,0 +1,5 @@ +QA_SPECIAL_TOKENS = {"Question": "", "Answer": "", "StartPrefix": "", "EndPrefix": ""} + + +def format_pair(pair): + return "{}{}{}".format(QA_SPECIAL_TOKENS["Question"], pair[0], QA_SPECIAL_TOKENS["Answer"]), pair[1] diff --git a/model/supervised_finetuning/custom_datasets/prompt_dialogue.py b/model/supervised_finetuning/custom_datasets/prompt_dialogue.py index 4a1d83a3..1c823934 100644 --- a/model/supervised_finetuning/custom_datasets/prompt_dialogue.py +++ b/model/supervised_finetuning/custom_datasets/prompt_dialogue.py @@ -2,6 +2,7 @@ import json import os from urllib.request import urlopen +from custom_datasets.formatting import format_pair from torch.utils.data import Dataset @@ -49,8 +50,7 @@ class PromptGeneratedDataset(Dataset): return len(self.pairs) def __getitem__(self, index): - question, answer = self.pairs[index] - return question, answer + return format_pair(self.pairs[index]) class InstructionTuning(Dataset): @@ -101,5 +101,4 @@ class InstructionTuning(Dataset): return len(self.pairs) def __getitem__(self, index): - question, answer = self.pairs[index] - return question, answer + return format_pair(self.pairs[index]) diff --git a/model/supervised_finetuning/custom_datasets/qa_datasets.py b/model/supervised_finetuning/custom_datasets/qa_datasets.py index 7d9c7f48..47b1c247 100644 --- a/model/supervised_finetuning/custom_datasets/qa_datasets.py +++ b/model/supervised_finetuning/custom_datasets/qa_datasets.py @@ -7,14 +7,13 @@ import re from urllib.request import urlopen import numpy as np +from custom_datasets.formatting import QA_SPECIAL_TOKENS, format_pair from datasets import load_dataset from torch.utils.data import Dataset # @agoryuno contributed this re_reference_remove = re.compile(r"\[\d+(?:,\s*\d+)*?\]") -QA_SPECIAL_TOKENS = {"Question": "", "Answer": "", "StartPrefix": "", "EndPrefix": ""} - def index_squad_v2(example): if len(example["answers"]["text"]): @@ -78,7 +77,7 @@ class QADataset(Dataset): def __getitem__(self, idx): data = self.dataset[idx] - return self.index_fn(data) + return format_pair(self.index_fn(data)) class WebGPT(Dataset): @@ -111,7 +110,7 @@ class WebGPT(Dataset): def __getitem__(self, index): question = self.index2question[index] answer = self.questions[question] - return [question, answer] + return format_pair((question, answer)) class SODA(Dataset): @@ -121,14 +120,14 @@ class SODA(Dataset): def process_soda_convo(self, data): pairs = [] play_as = data["speakers"][1] - prefix = "{}{}. {}{}".format( - QA_SPECIAL_TOKENS["StartPrefix"], - data["narrative"], - "your name {}".format(play_as), - QA_SPECIAL_TOKENS["EndPrefix"], - ) question, answer = "", "" prefix, postfix = "", "" + dialogue_bg = "{}{} {}{}".format( + QA_SPECIAL_TOKENS["StartPrefix"], + data["narrative"], + "your are {}".format(play_as), + QA_SPECIAL_TOKENS["EndPrefix"], + ) previous_chat = [] for idx, convo in enumerate(data["dialogue"]): @@ -138,14 +137,20 @@ class SODA(Dataset): else: answer = convo postfix = data["speakers"][idx] + if len(question) and len(answer) and prefix != postfix and postfix == play_as: history = "".join( - ["{}{}{}".format(p[0], QA_SPECIAL_TOKENS["Answer"], p[1]) for p in previous_chat] + [ + "{}{}{}{}".format(QA_SPECIAL_TOKENS["Question"], p[0], QA_SPECIAL_TOKENS["Answer"], p[1]) + for p in previous_chat + ] ) if len(history): history += "" - pairs.append((prefix + history + question, answer)) + prompt = QA_SPECIAL_TOKENS["Question"] + question + QA_SPECIAL_TOKENS["Answer"] + pairs.append((dialogue_bg + history + prompt, answer)) previous_chat.append((question, answer)) + return pairs def __init__(self, cache_dir, max_sample_size=10000, input_max_length=1024) -> None: @@ -166,8 +171,8 @@ class SODA(Dataset): return len(self.pairs) def __getitem__(self, index): - question, answer = self.pairs[index] - return question, answer + # special token added during preprocess + return self.pairs[index] class SODADialogue(Dataset): @@ -218,7 +223,7 @@ class SODADialogue(Dataset): return len(self.pairs) def __getitem__(self, index): - return self.pairs[index] + return format_pair(self.pairs[index]) class JokeExplaination(Dataset): @@ -253,8 +258,7 @@ class JokeExplaination(Dataset): return len(self.pairs) def __getitem__(self, index): - question, answer = self.pairs[index] - return question, answer + return format_pair(self.pairs[index]) # https://huggingface.co/datasets/aquamuse diff --git a/model/supervised_finetuning/custom_datasets/summarization.py b/model/supervised_finetuning/custom_datasets/summarization.py index 2a097fe7..85d21a27 100644 --- a/model/supervised_finetuning/custom_datasets/summarization.py +++ b/model/supervised_finetuning/custom_datasets/summarization.py @@ -3,6 +3,7 @@ """ import random +from custom_datasets.formatting import format_pair from datasets import load_dataset from torch.utils.data import Dataset @@ -54,11 +55,12 @@ def index_summary_merge(text, summary): class SummarizationDataset(Dataset): - def __init__(self, dataset, cache_dir, split): + def __init__(self, dataset, cache_dir, split, max_words=512): self.name = dataset self.dataset = load_dataset(*summarization_config_mapping[dataset], cache_dir=cache_dir, split=split) self.text_column, self.summary_column = summarization_name_mapping[dataset] self.preprocess_fn = index_summary_merge if dataset == "scitldr" else index_summary_default + self.max_words = max_words def __len__(self): return len(self.dataset) @@ -72,4 +74,5 @@ class SummarizationDataset(Dataset): else: prompt = random.choice(SUMMARIZATION_SPECIAL_TOKENS["Summary"]) - return ("".join([SUMMARIZATION_SPECIAL_TOKENS["Text"], " ".join(text.split(" ")[:256]), prompt]), summary) + context = "".join([SUMMARIZATION_SPECIAL_TOKENS["Text"], " ".join(text.split(" ")[: self.max_words]), prompt]) + return format_pair((context, summary)) diff --git a/model/supervised_finetuning/custom_datasets/toxic_conversation.py b/model/supervised_finetuning/custom_datasets/toxic_conversation.py index 815ac722..640b8d8d 100644 --- a/model/supervised_finetuning/custom_datasets/toxic_conversation.py +++ b/model/supervised_finetuning/custom_datasets/toxic_conversation.py @@ -4,12 +4,13 @@ """ import random +from custom_datasets.formatting import QA_SPECIAL_TOKENS, format_pair from datasets import load_dataset from torch.utils.data import Dataset class ProsocialDialogueExplaination(Dataset): - name = "prosocial_explain" + name = "explain_prosocial" TEMPLATE = [ # 0 : reply or sentence of interest, 1 : reason of caution ("'{0}' Why is this sentence {1}", "This sentence is {1} because {0}"), @@ -36,7 +37,7 @@ class ProsocialDialogueExplaination(Dataset): return len(self.pairs) def __getitem__(self, idx): - return self.pairs[idx] + return format_pair(self.pairs[idx]) class ProsocialDialogue(Dataset): @@ -58,8 +59,9 @@ class ProsocialDialogue(Dataset): dataset = load_dataset("allenai/prosocial-dialog", cache_dir=cache_dir)[split] self.pairs = [] for row in dataset: + prompt = QA_SPECIAL_TOKENS["Question"] + row["context"] + QA_SPECIAL_TOKENS["Answer"] for answer in row["rots"]: - self.pairs.append((self.PREFIX + row["context"], answer)) + self.pairs.append((self.PREFIX + prompt, answer)) def __len__(self): return len(self.pairs) diff --git a/model/supervised_finetuning/custom_datasets/translation.py b/model/supervised_finetuning/custom_datasets/translation.py index 694d31ce..f9a71a8e 100644 --- a/model/supervised_finetuning/custom_datasets/translation.py +++ b/model/supervised_finetuning/custom_datasets/translation.py @@ -8,6 +8,7 @@ """ import random +from custom_datasets.formatting import format_pair from datasets import load_dataset from torch.utils.data import Dataset @@ -82,7 +83,7 @@ class TranslationPair(Dataset): return len(self.pairs) def __getitem__(self, index): - return self.pairs[index] + return format_pair(self.pairs[index]) class WMT2019(TranslationPair): @@ -99,6 +100,8 @@ class WMT2019(TranslationPair): else: # translating in reverse direction source = random.choice(TRANSLATION_PROMPT[src]).format(row[tgt]) self.pairs.append((source, row[src])) + if len(self.pairs) > 100000: + break class DiveMT(TranslationPair): diff --git a/model/supervised_finetuning/tests/test_datasets.py b/model/supervised_finetuning/tests/test_datasets.py index 3b59f289..8d5ad08f 100644 --- a/model/supervised_finetuning/tests/test_datasets.py +++ b/model/supervised_finetuning/tests/test_datasets.py @@ -7,8 +7,8 @@ from custom_datasets.dialogue_collator import DialogueDataCollator def test_all_datasets(): qa_base = QA_DATASETS summarize_base = SUMMARIZATION_DATASETS - others = ["prompt_dialogue", "webgpt", "soda", "joke", "instruct_tuning"] - translation = ["dive_mt", "wmt2019_zh-en", "wmt2019_ru-en", "wmt2019_de-en", "ted_trans_de-ja", "ted_trans_nl-en"] + others = ["prompt_dialogue", "webgpt", "soda", "joke", "instruct_tuning", "explain_prosocial", "prosocial_dialogue"] + translation = ["dive_mt", "wmt2019_zh-en", "wmt2019_ru-en", "ted_trans_de-ja", "ted_trans_nl-en"] config = Namespace(cache_dir=".cache") for dataset_name in translation + others + summarize_base + qa_base: @@ -31,7 +31,6 @@ def test_collate_fn(): qa_base = QA_DATASETS summarize_base = SUMMARIZATION_DATASETS others = ["prompt_dialogue", "webgpt", "soda", "joke", "gsm8k"] - trains, evals = [], [] for dataset_name in others + qa_base + summarize_base: print(dataset_name) @@ -41,10 +40,10 @@ def test_collate_fn(): dataloader = DataLoader(ConcatDataset(trains), collate_fn=collate_fn, batch_size=128) for batch in dataloader: - # print(batch.keys()) - # print(tokenizer.decode(batch['input_ids'][0])) - # print('-----') - # print(tokenizer.decode(batch['targets'][0][batch['label_masks'][0]])) + print(batch.keys()) + print(tokenizer.decode(batch["input_ids"][0])) + print("-----") + print(tokenizer.decode(batch["targets"][0][batch["label_masks"][0]])) assert batch["targets"].shape[1] <= 512 dataloader = DataLoader(ConcatDataset(evals), collate_fn=collate_fn, batch_size=128) for batch in dataloader: diff --git a/model/supervised_finetuning/utils.py b/model/supervised_finetuning/utils.py index 7b6e03b6..f7a0ab15 100644 --- a/model/supervised_finetuning/utils.py +++ b/model/supervised_finetuning/utils.py @@ -25,6 +25,10 @@ def get_tokenizer(conf): tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token, "sep_token": "<|extratoken_100|>"}) elif "codegen" in conf.model_name: tokenizer.add_special_tokens({"pad_token": "<|endoftext|>", "sep_token": "<|endoftext|>"}) + elif "pythia" in conf.model_name: + tokenizer.add_special_tokens( + {"pad_token": "<|padding|>", "sep_token": "<|endoftext|>", "eos_token": "<|endoftext|>"} + ) additional_special_tokens = ( []