From 9451aff6ccce98f64e706d511664f1d29f9fc7f2 Mon Sep 17 00:00:00 2001 From: theblackcat102 Date: Sat, 14 Jan 2023 03:49:19 +0000 Subject: [PATCH 1/5] [fix] @ekurtulus major logic bug in summarization --- .../custom_datasets/summarization.py | 57 +++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 model/supervised_finetuning/custom_datasets/summarization.py diff --git a/model/supervised_finetuning/custom_datasets/summarization.py b/model/supervised_finetuning/custom_datasets/summarization.py new file mode 100644 index 00000000..76147928 --- /dev/null +++ b/model/supervised_finetuning/custom_datasets/summarization.py @@ -0,0 +1,57 @@ +import random +from datasets import load_dataset +from torch.utils.data import Dataset + +SUMMARIZATION_SPECIAL_TOKENS = {"Text": "", "Summary": ["TL;DR:", "Summarize this", "Give me the summary"]} + + + +summarization_config_mapping = { + "cnn_dailymail": ("3.0.0",), + "samsum": (), + "xsum": (), + "multi_news": (), + "scitldr": ("AIC",), + "billsum": (), + "reddit": (), +} + +summarization_name_mapping = { + "cnn_dailymail": ("article", "highlights"), + "samsum": ("dialogue", "summary"), + "xsum": ("document", "summary"), + "multi_news": ("document", "summary"), + "scitldr": ("source", "target"), + "billsum": ("text", "summary"), + "reddit": ("content", "summary"), +} + + +def index_summary_default(text, summary): + return text.replace('\n\n', '\n'), summary + + +def index_summary_merge(text, summary): + return " ".join(text), " ".join(summary) + + +class SummarizationDataset(Dataset): + def __init__(self, dataset, cache_dir, split): + self.dataset = load_dataset(dataset, *summarization_config_mapping[dataset], cache_dir=cache_dir, split=split) + self.text_column, self.summary_column = summarization_name_mapping[dataset] + self.preprocess_fn = index_summary_merge if dataset == "scitdlr" else index_summary_default + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + data = self.dataset[idx] + text, summary = data[self.text_column], data[self.summary_column] + text, summary = self.preprocess_fn(text, summary) + prompt = random.choice(SUMMARIZATION_SPECIAL_TOKENS["Summary"]) + + return ( + "".join([SUMMARIZATION_SPECIAL_TOKENS["Text"], ' '.join(text.split(' ')[:256]), prompt]), + summary + ) + From 3966024871697bd47de52f9f4ec846216722f669 Mon Sep 17 00:00:00 2001 From: theblackcat102 Date: Sat, 14 Jan 2023 05:49:22 +0000 Subject: [PATCH 2/5] [fix] Fix summarizer bug and QA typo issue --- model/supervised_finetuning/README.md | 2 + .../supervised_finetuning/configs/config.yaml | 14 +- .../custom_datasets/__init__.py | 152 ++------------- .../custom_datasets/dialogue_collator.py | 5 +- .../custom_datasets/prompt_dialogue.py | 21 +-- .../custom_datasets/qa_datasets.py | 174 ++++++++++++++++++ .../custom_datasets/summarization.py | 22 ++- .../tests/test_datasets.py | 41 +++++ .../supervised_finetuning/tests/test_utils.py | 9 + 9 files changed, 267 insertions(+), 173 deletions(-) create mode 100644 model/supervised_finetuning/custom_datasets/qa_datasets.py create mode 100644 model/supervised_finetuning/tests/test_datasets.py create mode 100644 model/supervised_finetuning/tests/test_utils.py diff --git a/model/supervised_finetuning/README.md b/model/supervised_finetuning/README.md index bd202397..822121d8 100644 --- a/model/supervised_finetuning/README.md +++ b/model/supervised_finetuning/README.md @@ -58,6 +58,8 @@ the end to trigger deepspeed python trainer.py --configs defaults your-model-name --deepspeed ``` +## Dataset choices + ## Results Experimental results in wandb diff --git a/model/supervised_finetuning/configs/config.yaml b/model/supervised_finetuning/configs/config.yaml index bd35f168..b0684bda 100644 --- a/model/supervised_finetuning/configs/config.yaml +++ b/model/supervised_finetuning/configs/config.yaml @@ -6,7 +6,7 @@ defaults: per_device_eval_batch_size: 2 weight_decay: 0.00 warmup_steps: 600 - eval_steps: 100 + eval_steps: 500 save_steps: 500 max_length: 512 num_train_epochs: 3 @@ -18,7 +18,7 @@ defaults: datasets: - webgpt - prompt_dialogue - cache_dir: ~/.cache + cache_dir: .cache loss_fn: CrossEntropyLoss eval_size: log_dir: "base" @@ -48,14 +48,14 @@ gpt-jt: per_device_eval_batch_size: 4 codegen: - learning_rate: 2e-6 + learning_rate: 8e-6 model_name: Salesforce/codegen-2B-multi weight_decay: 0.01 - max_length: 812 - warmup_steps: 600 + max_length: 512 + warmup_steps: 1000 gradient_checkpointing: false - gradient_accumulation_steps: 5 - per_device_train_batch_size: 4 + gradient_accumulation_steps: 10 + per_device_train_batch_size: 2 per_device_eval_batch_size: 4 debug: diff --git a/model/supervised_finetuning/custom_datasets/__init__.py b/model/supervised_finetuning/custom_datasets/__init__.py index c0cd424b..6b7906f7 100644 --- a/model/supervised_finetuning/custom_datasets/__init__.py +++ b/model/supervised_finetuning/custom_datasets/__init__.py @@ -1,136 +1,11 @@ -import numpy as np -from datasets import load_dataset +from custom_datasets.prompt_dialogue import PromptGeneratedDataset +from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, WebGPT +from custom_datasets.summarization import SummarizationDataset from sklearn.model_selection import train_test_split -from torch.utils.data import Dataset, Subset +from torch.utils.data import Subset -from .prompt_dialogue import PromptGeneratedDataset - -QA_SPECIAL_TOKENS = {"Question": "", "Answer": ""} -SUMMARIZATION_SPECIAL_TOKENS = {"Text": "", "Summary": "TL;DR:"} - -summarization_name_mapping = { - "cnn_dailymail": ("article", "highlights"), - "samsum": ("dialogue", "summary"), - "xsum": ("document", "summary"), - "multi_news": ("document", "summary"), - "scitldr": ("source", "target"), - "billsum": ("text", "summary"), - "reddit": ("content", "summary"), -} -summarization_config_mapping = { - "cnn_dailymail": ("3.0.0",), - "samsum": (), - "xsum": (), - "multi_news": (), - "scitldr": ("AIC",), - "billsum": (), - "reddit": (), -} - -QA_DATASETS = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_noconext"] -SUMMARIZATION_DATASETS = ["xsum", "cnn_dailymail", "samsum", "multi_news"] - - -def index_squad_v2(example): - return example["title"] + ". " + example["context"] + " " + example["question"], example["answers"]["text"][0] - - -def index_trivia_qa_nocontext(example): - # dummy return one randomly - return example["question"], example["answer"]["aliases"][np.random.randint(len(example["answer"]["aliases"]))] - - -def index_trivia_qa_context(example): - question = example["question"] - title = example["title"][np.random.randint(len(example["title"]))] - context = example["search_context"][np.random.randint(len(example["search_context"]))] - answer = example["answer"]["aliases"][np.random.randint(len(example["answer"]["aliases"]))] - - return title + ". " + context + " " + question, answer - - -def index_adversarial_qa(example): - return example["title"] + ". " + example["context"] + " " + example["question"], example["answers"]["text"][0] - - -class QADataset(Dataset): - def __init__(self, dataset, cache_dir, split): - if dataset == "squad_v2": - self.index_fn = index_squad_v2 - self.dataset = load_dataset("squad_v2", cache_dir=cache_dir, split=split) - elif dataset == "trivia_qa_nocontext": - self.index_fn = index_trivia_qa_nocontext - self.dataset = load_dataset("trivia_qa", "rc.nocontext") - elif dataset == "trivia_qa_context": - self.index_fn = index_trivia_qa_context - self.dataset = load_dataset("trivia_qa", "rc") - elif dataset == "adversarial_qa": - self.index_fn = index_adversarial_qa - self.dataset = load_dataset("adversarial_qa", "adversarialQA") - else: - raise ValueError("Unknown dataset : " + dataset) - - def __len__(self): - return len(self.dataset) - - def __getitem__(self, idx): - data = self.dataset[idx] - return self.index_fn(data) - - -def index_summary_default(text, summary): - return text, summary - - -def index_summary_merge(text, summary): - return " ".join(text), " ".join(summary) - - -class SummarizationDataset(Dataset): - def __init__(self, dataset, cache_dir, split): - self.dataset = load_dataset(dataset, *summarization_config_mapping[dataset], cache_dir=cache_dir, split=split) - self.summary_column, self.text_column = summarization_name_mapping[dataset] - self.preprocess_fn = index_summary_merge if dataset == "scitdlr" else index_summary_merge - - def __len__(self): - return len(self.dataset) - - def __getitem__(self, idx): - data = self.dataset[idx] - text, summary = data[self.text_column], data[self.summary_column] - text, summary = self.preprocess_fn(text, summary) - - return "".join( - SUMMARIZATION_SPECIAL_TOKENS["Text"], text, " ", SUMMARIZATION_SPECIAL_TOKENS["Summary"], summary - ) - - -class WebGPT(Dataset): - def __init__(self) -> None: - super().__init__() - - dataset = load_dataset("openai/webgpt_comparisons") - questions = {} - # using prompt as our index will allows us - # to add additional generated prompt later - self.index2question = {} - for row in dataset["train"]: - question = row["question"]["full_text"] - if question not in self.index2question: - self.index2question[len(self.index2question)] = question - - # only keep the best answer - questions[question] = row["answer_0" if row["score_0"] > row["score_1"] else "answer_1"] - - self.questions = questions - - def __len__(self): - return len(self.index2question) - - def __getitem__(self, index): - question = self.index2question[index] - answer = self.questions[question] - return [question, answer] +QA_DATASETS = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_nocontext"] +SUMMARIZATION_DATASETS = ["xsum", "cnn_dailymail", "samsum", "multi_news", "scitldr", "billsum", "reddit"] def train_val_dataset(dataset, val_split=0.2): @@ -143,19 +18,24 @@ def train_val_dataset(dataset, val_split=0.2): def get_one_dataset(conf, dataset_name): dataset_name = dataset_name.lower() - if dataset_name in ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_noconext"]: + if dataset_name in QA_DATASETS: train = QADataset(dataset_name, conf.cache_dir, "train") eval = QADataset(dataset_name, conf.cache_dir, "validation") - elif dataset_name in ["xsum", "cnn_dailymail", "samsum", "multi_news", "scitldr", "billsum", "reddit"]: + elif dataset_name in SUMMARIZATION_DATASETS: train = SummarizationDataset(dataset_name, conf.cache_dir, "train") - eval = SummarizationDataset(dataset_name, conf.cache_dir, "validation") - + eval = SummarizationDataset(dataset_name, conf.cache_dir, "validation" if dataset_name != "billsum" else "test") elif dataset_name == "webgpt": dataset = WebGPT() train, eval = train_val_dataset(dataset, val_split=0.2) elif dataset_name == "prompt_dialogue": - dataset = PromptGeneratedDataset() + dataset = PromptGeneratedDataset(conf.cache_dir) + train, eval = train_val_dataset(dataset, val_split=0.2) + elif dataset_name == "soda": + dataset = SODA(conf.cache_dir) + train, eval = train_val_dataset(dataset, val_split=0.1) + elif dataset_name == "joke": + dataset = JokeExplaination(conf.cache_dir) train, eval = train_val_dataset(dataset, val_split=0.2) else: raise ValueError(f"Unknown dataset {dataset_name}") diff --git a/model/supervised_finetuning/custom_datasets/dialogue_collator.py b/model/supervised_finetuning/custom_datasets/dialogue_collator.py index 479931f6..2efe160f 100644 --- a/model/supervised_finetuning/custom_datasets/dialogue_collator.py +++ b/model/supervised_finetuning/custom_datasets/dialogue_collator.py @@ -3,11 +3,10 @@ from typing import Optional, Union import numpy as np import torch +from custom_datasets.qa_datasets import QA_SPECIAL_TOKENS from torch.nn import functional as F from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase -from . import QA_SPECIAL_TOKENS - @dataclass class DialogueDataCollator: @@ -35,7 +34,7 @@ class DialogueDataCollator: # Add a way for the model to terminate generation # When we predict the start of a new expected question, we want to be able to stop generation - messages.append(QA_SPECIAL_TOKENS["Question"]) + messages.append(self.tokenizer.eos_token) flatten_message = self.tokenizer( "".join(messages), diff --git a/model/supervised_finetuning/custom_datasets/prompt_dialogue.py b/model/supervised_finetuning/custom_datasets/prompt_dialogue.py index 17911141..372ea27f 100644 --- a/model/supervised_finetuning/custom_datasets/prompt_dialogue.py +++ b/model/supervised_finetuning/custom_datasets/prompt_dialogue.py @@ -16,10 +16,10 @@ class PromptGeneratedDataset(Dataset): url = "https://github.com/Rallio67/language-model-agents/raw/main/chat_dialogue_v2_c.txt" - def __init__(self) -> None: + def __init__(self, cache_dir) -> None: super().__init__() - os.makedirs("datasets", exist_ok=True) - chat_dialogue = os.path.join("datasets", "chat_dialogue_v2_c.txt") + os.makedirs(cache_dir, exist_ok=True) + chat_dialogue = os.path.join(cache_dir, "chat_dialogue_v2_c.txt") if not os.path.exists(chat_dialogue): with urlopen(self.url) as file: content = file.read().decode() @@ -49,18 +49,3 @@ class PromptGeneratedDataset(Dataset): def __getitem__(self, index): question, answer = self.pairs[index] return question, answer - - -if __name__ == "__main__": - from torch.utils.data import DataLoader - from transformers import AutoTokenizer - - from .dialogue_collator import DialogueDataCollator - - tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-2B-multi") - tokenizer.add_special_tokens({"pad_token": "<|endoftext|>", "sep_token": "<|endoftext|>"}) - dataset = PromptGeneratedDataset() - collate_fn = DialogueDataCollator(tokenizer, padding=True, max_length=128) - dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=5) - for batch in dataloader: - print(batch["input_ids"].shape) diff --git a/model/supervised_finetuning/custom_datasets/qa_datasets.py b/model/supervised_finetuning/custom_datasets/qa_datasets.py new file mode 100644 index 00000000..d0fcffda --- /dev/null +++ b/model/supervised_finetuning/custom_datasets/qa_datasets.py @@ -0,0 +1,174 @@ +import json +import os +from urllib.request import urlopen + +import numpy as np +from datasets import load_dataset +from torch.utils.data import Dataset + +QA_SPECIAL_TOKENS = {"Question": "", "Answer": "", "StartPrefix": "", "EndPrefix": ""} + + +def index_squad_v2(example): + if len(example["answers"]["text"]): + answer = example["answers"]["text"][0] + else: + answer = "I do not have answer for that" + return example["context"] + " " + example["question"], answer + + +def index_trivia_qa_nocontext(example): + # dummy return one randomly + return example["question"], example["answer"]["aliases"][np.random.randint(len(example["answer"]["aliases"]))] + + +def index_trivia_qa_context(example): + question = example["question"] + if len(example["search_results"]["search_context"]): + context = example["search_results"]["search_context"][ + np.random.randint(len(example["search_results"]["search_context"])) + ] + else: + context = "" + answer = example["answer"]["aliases"][np.random.randint(len(example["answer"]["aliases"]))] + + return context + " " + question, answer + + +def index_adversarial_qa(example): + return example["title"] + ". " + example["context"] + " " + example["question"], example["answers"]["text"][0] + + +class QADataset(Dataset): + def __init__(self, dataset, cache_dir, split): + if dataset == "squad_v2": + self.index_fn = index_squad_v2 + self.dataset = load_dataset("squad_v2", cache_dir=cache_dir, split=split) + elif dataset == "trivia_qa_nocontext": + self.index_fn = index_trivia_qa_nocontext + self.dataset = load_dataset("trivia_qa", "rc.nocontext", split=split) + elif dataset == "trivia_qa_context": + self.index_fn = index_trivia_qa_context + self.dataset = load_dataset("trivia_qa", "rc", split=split) + elif dataset == "adversarial_qa": + self.index_fn = index_adversarial_qa + self.dataset = load_dataset("adversarial_qa", "adversarialQA", split=split) + else: + raise ValueError("Unknown dataset : " + dataset) + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + data = self.dataset[idx] + return self.index_fn(data) + + +class WebGPT(Dataset): + def __init__(self) -> None: + super().__init__() + + dataset = load_dataset("openai/webgpt_comparisons") + questions = {} + # using prompt as our index will allows us + # to add additional generated prompt later + self.index2question = {} + for row in dataset["train"]: + question = row["question"]["full_text"] + if question not in self.index2question: + self.index2question[len(self.index2question)] = question + + # only keep the best answer + questions[question] = row["answer_0" if row["score_0"] > row["score_1"] else "answer_1"] + + self.questions = questions + + def __len__(self): + return len(self.index2question) + + def __getitem__(self, index): + question = self.index2question[index] + answer = self.questions[question] + return [question, answer] + + +class SODA(Dataset): + def process_soda_convo(self, data): + pairs = [] + play_as = data["speakers"][1] + prefix = "{}. {}".format(data["narrative"], "your name {}".format(play_as)) + question, answer = "", "" + prefix, postfix = "", "" + previous_chat = [] + + for idx, convo in enumerate(data["dialogue"]): + if idx % 2 == 0: + question = convo + prefix = data["speakers"][idx] + else: + answer = convo + postfix = data["speakers"][idx] + if len(question) and len(answer) and prefix != postfix and postfix == play_as: + history = "".join(["{}{}".format(*p) for p in previous_chat]) + if len(history): + history += "" + pairs.append((prefix + history + question, answer)) + previous_chat.append((question, answer)) + return pairs + + def __init__(self, cache_dir, max_sample_size=10000, input_max_length=1024) -> None: + super().__init__() + + self.pairs = [] + dataset = load_dataset("allenai/soda", cache_dir=cache_dir)["train"] + for data in dataset: + data_pair = self.process_soda_convo(data) + for (prompt, answer) in data_pair: + if len(prompt) < input_max_length: + self.pairs.append((prompt, answer)) + + if len(self.pairs) > max_sample_size: + break + + def __len__(self): + return len(self.pairs) + + def __getitem__(self, index): + question, answer = self.pairs[index] + return question, answer + + +class JokeExplaination(Dataset): + """ """ + + url = "https://gist.github.com/theblackcat102/42b697e24a13fdb499e20edfbf618361/raw/1834dca207898c15f93b809d1195f6f6e47c9e1e/joke_explained.jsonl" + + def __init__(self, cache_dir) -> None: + super().__init__() + os.makedirs(cache_dir, exist_ok=True) + joke_explain_filename = os.path.join(cache_dir, "joke_explaination.jsonl") + if not os.path.exists(joke_explain_filename): + with urlopen(self.url) as file: + content = file.read().decode() + with open(joke_explain_filename, "w") as fout: + fout.write(content) + + question = "" + answer = "" + self.pairs = [] + with open(joke_explain_filename, "r") as f: + for line in f: + data = json.loads(line) + joke = data["joke"] + explanation = data["explaination"] + self.pairs.append((joke, explanation)) + + if len(question) > 0 and len(answer) > 0: + self.pairs.append((question, answer)) + + def __len__(self): + return len(self.pairs) + + def __getitem__(self, index): + question, answer = self.pairs[index] + return question, answer diff --git a/model/supervised_finetuning/custom_datasets/summarization.py b/model/supervised_finetuning/custom_datasets/summarization.py index 76147928..41fa6dc0 100644 --- a/model/supervised_finetuning/custom_datasets/summarization.py +++ b/model/supervised_finetuning/custom_datasets/summarization.py @@ -1,10 +1,14 @@ import random + from datasets import load_dataset from torch.utils.data import Dataset SUMMARIZATION_SPECIAL_TOKENS = {"Text": "", "Summary": ["TL;DR:", "Summarize this", "Give me the summary"]} - +SUMMARY_SPECIAL_PROMPT = { + "multi_news": ["Summarize in bullet points", "Generate summary in list of points"], + "xsum": ["Give me summary in one sentence", "Short TLDR", "Give me a concise summary"], +} summarization_config_mapping = { "cnn_dailymail": ("3.0.0",), @@ -28,7 +32,7 @@ summarization_name_mapping = { def index_summary_default(text, summary): - return text.replace('\n\n', '\n'), summary + return text.replace("\n\n", "\n"), summary def index_summary_merge(text, summary): @@ -37,9 +41,10 @@ def index_summary_merge(text, summary): class SummarizationDataset(Dataset): def __init__(self, dataset, cache_dir, split): + self.name = dataset self.dataset = load_dataset(dataset, *summarization_config_mapping[dataset], cache_dir=cache_dir, split=split) self.text_column, self.summary_column = summarization_name_mapping[dataset] - self.preprocess_fn = index_summary_merge if dataset == "scitdlr" else index_summary_default + self.preprocess_fn = index_summary_merge if dataset == "scitldr" else index_summary_default def __len__(self): return len(self.dataset) @@ -48,10 +53,9 @@ class SummarizationDataset(Dataset): data = self.dataset[idx] text, summary = data[self.text_column], data[self.summary_column] text, summary = self.preprocess_fn(text, summary) - prompt = random.choice(SUMMARIZATION_SPECIAL_TOKENS["Summary"]) - - return ( - "".join([SUMMARIZATION_SPECIAL_TOKENS["Text"], ' '.join(text.split(' ')[:256]), prompt]), - summary - ) + if self.name in SUMMARY_SPECIAL_PROMPT: + prompt = random.choice(SUMMARIZATION_SPECIAL_TOKENS["Summary"]) + else: + prompt = random.choice(SUMMARIZATION_SPECIAL_TOKENS["Summary"]) + return ("".join([SUMMARIZATION_SPECIAL_TOKENS["Text"], " ".join(text.split(" ")[:256]), prompt]), summary) diff --git a/model/supervised_finetuning/tests/test_datasets.py b/model/supervised_finetuning/tests/test_datasets.py new file mode 100644 index 00000000..721fdfd3 --- /dev/null +++ b/model/supervised_finetuning/tests/test_datasets.py @@ -0,0 +1,41 @@ +from argparse import Namespace + +from custom_datasets import get_one_dataset +from custom_datasets.dialogue_collator import DialogueDataCollator + + +def test_all_datasets(): + qa_base = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_nocontext"] + summarize_base = ["scitldr", "xsum", "cnn_dailymail", "samsum", "multi_news", "billsum"] + others = ["prompt_dialogue", "webgpt", "soda", "joke"] + + config = Namespace(cache_dir=".cache") + for dataset_name in others + qa_base + summarize_base: + print(dataset_name) + train, eval = get_one_dataset(config, dataset_name) + # sanity check + for idx in range(min(len(train), 1000)): + train[idx] + for idx in range(min(len(eval), 1000)): + eval[idx] + + +def test_collate_fn(): + from torch.utils.data import DataLoader + from utils import get_tokenizer + + config = Namespace(cache_dir=".cache", model_name="Salesforce/codegen-2B-multi") + tokenizer = get_tokenizer(config) + collate_fn = DialogueDataCollator(tokenizer, max_length=512) + train, eval = get_one_dataset(config, "multi_news") + dataloader = DataLoader(train, collate_fn=collate_fn, batch_size=128) + for batch in dataloader: + # print(batch.keys()) + # print(tokenizer.decode(batch['input_ids'][0])) + # print('-----') + # print(tokenizer.decode(batch['targets'][0][batch['label_masks'][0]])) + assert batch["targets"].shape[1] <= 512 + + +if __name__ == "__main__": + test_all_datasets() diff --git a/model/supervised_finetuning/tests/test_utils.py b/model/supervised_finetuning/tests/test_utils.py new file mode 100644 index 00000000..ad40e534 --- /dev/null +++ b/model/supervised_finetuning/tests/test_utils.py @@ -0,0 +1,9 @@ +from argparse import Namespace + +from utils import get_tokenizer + + +def test_tokenizer(): + get_tokenizer(Namespace(model_name="Salesforce/codegen-2B-multi", cache_dir=".cache")) + get_tokenizer(Namespace(model_name="facebook/galactica-1.3b", cache_dir=".cache")) + get_tokenizer(Namespace(model_name="", cache_dir=".cache")) From 154611109465e3ca1c22efd08ce8bb15e52f519d Mon Sep 17 00:00:00 2001 From: theblackcat102 Date: Sat, 14 Jan 2023 06:24:47 +0000 Subject: [PATCH 3/5] [feature] added GSM8k and code refactoring --- .../supervised_finetuning/configs/config.yaml | 12 +++++++++ .../custom_datasets/__init__.py | 10 ++++--- .../custom_datasets/qa_datasets.py | 16 ++++++++--- .../tests/test_datasets.py | 27 ++++++++++++++----- model/supervised_finetuning/utils.py | 3 ++- 5 files changed, 53 insertions(+), 15 deletions(-) diff --git a/model/supervised_finetuning/configs/config.yaml b/model/supervised_finetuning/configs/config.yaml index b0684bda..0440201a 100644 --- a/model/supervised_finetuning/configs/config.yaml +++ b/model/supervised_finetuning/configs/config.yaml @@ -18,6 +18,18 @@ defaults: datasets: - webgpt - prompt_dialogue + - squad_v2 + - adversarial_qa + - trivia_qa_nocontext + - xsum + - cnn_dailymail + - prompt_dialogue + - multi_news + - scitldr + - soda + - joke + - + - joke cache_dir: .cache loss_fn: CrossEntropyLoss eval_size: diff --git a/model/supervised_finetuning/custom_datasets/__init__.py b/model/supervised_finetuning/custom_datasets/__init__.py index 6b7906f7..e293af3d 100644 --- a/model/supervised_finetuning/custom_datasets/__init__.py +++ b/model/supervised_finetuning/custom_datasets/__init__.py @@ -4,8 +4,8 @@ from custom_datasets.summarization import SummarizationDataset from sklearn.model_selection import train_test_split from torch.utils.data import Subset -QA_DATASETS = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_nocontext"] -SUMMARIZATION_DATASETS = ["xsum", "cnn_dailymail", "samsum", "multi_news", "scitldr", "billsum", "reddit"] +QA_DATASETS = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_nocontext", "gsm8k"] +SUMMARIZATION_DATASETS = ["xsum", "cnn_dailymail", "samsum", "multi_news", "scitldr", "billsum"] def train_val_dataset(dataset, val_split=0.2): @@ -20,11 +20,13 @@ def get_one_dataset(conf, dataset_name): if dataset_name in QA_DATASETS: train = QADataset(dataset_name, conf.cache_dir, "train") - eval = QADataset(dataset_name, conf.cache_dir, "validation") + val_name = "validation" if dataset_name not in ["gsm8k"] else "test" + eval = QADataset(dataset_name, conf.cache_dir, val_name) elif dataset_name in SUMMARIZATION_DATASETS: train = SummarizationDataset(dataset_name, conf.cache_dir, "train") - eval = SummarizationDataset(dataset_name, conf.cache_dir, "validation" if dataset_name != "billsum" else "test") + val_name = "validation" if dataset_name not in ["billsum"] else "test" + eval = SummarizationDataset(dataset_name, conf.cache_dir, val_name) elif dataset_name == "webgpt": dataset = WebGPT() train, eval = train_val_dataset(dataset, val_split=0.2) diff --git a/model/supervised_finetuning/custom_datasets/qa_datasets.py b/model/supervised_finetuning/custom_datasets/qa_datasets.py index d0fcffda..eed9c644 100644 --- a/model/supervised_finetuning/custom_datasets/qa_datasets.py +++ b/model/supervised_finetuning/custom_datasets/qa_datasets.py @@ -39,6 +39,10 @@ def index_adversarial_qa(example): return example["title"] + ". " + example["context"] + " " + example["question"], example["answers"]["text"][0] +def index_gsm8k(example): + return example["question"], example["answer"] + + class QADataset(Dataset): def __init__(self, dataset, cache_dir, split): if dataset == "squad_v2": @@ -46,13 +50,19 @@ class QADataset(Dataset): self.dataset = load_dataset("squad_v2", cache_dir=cache_dir, split=split) elif dataset == "trivia_qa_nocontext": self.index_fn = index_trivia_qa_nocontext - self.dataset = load_dataset("trivia_qa", "rc.nocontext", split=split) + self.dataset = load_dataset("trivia_qa", "rc.nocontext", split=split, cache_dir=cache_dir) elif dataset == "trivia_qa_context": self.index_fn = index_trivia_qa_context - self.dataset = load_dataset("trivia_qa", "rc", split=split) + self.dataset = load_dataset("trivia_qa", "rc", split=split, cache_dir=cache_dir) elif dataset == "adversarial_qa": self.index_fn = index_adversarial_qa - self.dataset = load_dataset("adversarial_qa", "adversarialQA", split=split) + self.dataset = load_dataset("adversarial_qa", "adversarialQA", split=split, cache_dir=cache_dir) + elif dataset == "gsm8k": + self.index_fn = index_gsm8k + self.dataset = load_dataset("gsm8k", "main", split=split, cache_dir=cache_dir) + elif dataset == "adversarial_qa": + self.index_fn = index_adversarial_qa + self.dataset = load_dataset("adversarial_qa", "adversarialQA", split=split, cache_dir=cache_dir) else: raise ValueError("Unknown dataset : " + dataset) diff --git a/model/supervised_finetuning/tests/test_datasets.py b/model/supervised_finetuning/tests/test_datasets.py index 721fdfd3..c9363303 100644 --- a/model/supervised_finetuning/tests/test_datasets.py +++ b/model/supervised_finetuning/tests/test_datasets.py @@ -1,12 +1,12 @@ from argparse import Namespace -from custom_datasets import get_one_dataset +from custom_datasets import QA_DATASETS, SUMMARIZATION_DATASETS, get_one_dataset from custom_datasets.dialogue_collator import DialogueDataCollator def test_all_datasets(): - qa_base = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_nocontext"] - summarize_base = ["scitldr", "xsum", "cnn_dailymail", "samsum", "multi_news", "billsum"] + qa_base = QA_DATASETS + summarize_base = SUMMARIZATION_DATASETS others = ["prompt_dialogue", "webgpt", "soda", "joke"] config = Namespace(cache_dir=".cache") @@ -21,21 +21,34 @@ def test_all_datasets(): def test_collate_fn(): - from torch.utils.data import DataLoader + from torch.utils.data import ConcatDataset, DataLoader from utils import get_tokenizer config = Namespace(cache_dir=".cache", model_name="Salesforce/codegen-2B-multi") tokenizer = get_tokenizer(config) collate_fn = DialogueDataCollator(tokenizer, max_length=512) - train, eval = get_one_dataset(config, "multi_news") - dataloader = DataLoader(train, collate_fn=collate_fn, batch_size=128) + qa_base = QA_DATASETS + summarize_base = SUMMARIZATION_DATASETS + others = ["prompt_dialogue", "webgpt", "soda", "joke", "gsm8k"] + + trains, evals = [], [] + for dataset_name in others + qa_base + summarize_base: + print(dataset_name) + train, eval = get_one_dataset(config, dataset_name) + trains.append(train) + evals.append(eval) + + dataloader = DataLoader(ConcatDataset(trains), collate_fn=collate_fn, batch_size=128) for batch in dataloader: # print(batch.keys()) # print(tokenizer.decode(batch['input_ids'][0])) # print('-----') # print(tokenizer.decode(batch['targets'][0][batch['label_masks'][0]])) assert batch["targets"].shape[1] <= 512 + dataloader = DataLoader(ConcatDataset(evals), collate_fn=collate_fn, batch_size=128) + for batch in dataloader: + assert batch["targets"].shape[1] <= 512 if __name__ == "__main__": - test_all_datasets() + test_collate_fn() diff --git a/model/supervised_finetuning/utils.py b/model/supervised_finetuning/utils.py index f598dde1..85fb86db 100644 --- a/model/supervised_finetuning/utils.py +++ b/model/supervised_finetuning/utils.py @@ -6,8 +6,9 @@ import nltk import numpy as np import transformers import yaml -from custom_datasets import QA_DATASETS, QA_SPECIAL_TOKENS, SUMMARIZATION_DATASETS, get_one_dataset +from custom_datasets import QA_DATASETS, SUMMARIZATION_DATASETS, get_one_dataset from custom_datasets.dialogue_collator import DialogueDataCollator +from custom_datasets.qa_datasets import QA_SPECIAL_TOKENS from losses import CrossEntropyLoss, PolyLoss from models import freeze_top_n_layers, get_specific_model from sklearn.model_selection import train_test_split From 6f6c590e5798b6aa0c37df5be1655bfd19b3eeca Mon Sep 17 00:00:00 2001 From: theblackcat102 Date: Sat, 14 Jan 2023 06:47:21 +0000 Subject: [PATCH 4/5] [fix] Disable task specific evaluation --- .../custom_datasets/summarization.py | 1 + model/supervised_finetuning/utils.py | 57 ++++++++++--------- 2 files changed, 30 insertions(+), 28 deletions(-) diff --git a/model/supervised_finetuning/custom_datasets/summarization.py b/model/supervised_finetuning/custom_datasets/summarization.py index 41fa6dc0..69e4b51d 100644 --- a/model/supervised_finetuning/custom_datasets/summarization.py +++ b/model/supervised_finetuning/custom_datasets/summarization.py @@ -8,6 +8,7 @@ SUMMARIZATION_SPECIAL_TOKENS = {"Text": "", "Summary": ["TL;DR:", "Summarize thi SUMMARY_SPECIAL_PROMPT = { "multi_news": ["Summarize in bullet points", "Generate summary in list of points"], "xsum": ["Give me summary in one sentence", "Short TLDR", "Give me a concise summary"], + "samsum": ["TLDR;", "Summarize this dialogue", "Summarize dialogue"], } summarization_config_mapping = { diff --git a/model/supervised_finetuning/utils.py b/model/supervised_finetuning/utils.py index 85fb86db..7b6e03b6 100644 --- a/model/supervised_finetuning/utils.py +++ b/model/supervised_finetuning/utils.py @@ -1,12 +1,13 @@ -from functools import partial +# from functools import partial from pathlib import Path import evaluate -import nltk -import numpy as np + +# import nltk +# import numpy as np import transformers import yaml -from custom_datasets import QA_DATASETS, SUMMARIZATION_DATASETS, get_one_dataset +from custom_datasets import get_one_dataset from custom_datasets.dialogue_collator import DialogueDataCollator from custom_datasets.qa_datasets import QA_SPECIAL_TOKENS from losses import CrossEntropyLoss, PolyLoss @@ -52,25 +53,25 @@ def preprocess_qa(eval_pred): return (eval_pred.predictions, eval_pred.label_ids) -def postprocess_summarization(preds, labels): - preds = [pred.strip() for pred in preds] - labels = [label.strip() for label in labels] +# def postprocess_summarization(preds, labels): +# preds = [pred.strip() for pred in preds] +# labels = [label.strip() for label in labels] - preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] - labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] +# preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] +# labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] - return preds, labels +# return preds, labels -def preprocess_summarization(eval_pred, tokenizer, ignore_pad_token_for_loss=True): - preds, labels = eval_pred - decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) - if ignore_pad_token_for_loss: - labels = np.where(labels != -100, labels, tokenizer.pad_token_id) - decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) +# def preprocess_summarization(eval_pred, tokenizer, ignore_pad_token_for_loss=True): +# preds, labels = eval_pred +# decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) +# if ignore_pad_token_for_loss: +# labels = np.where(labels != -100, labels, tokenizer.pad_token_id) +# decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) - decoded_preds, decoded_labels = postprocess_summarization(decoded_preds, decoded_labels) - return decoded_preds, decoded_labels +# decoded_preds, decoded_labels = postprocess_summarization(decoded_preds, decoded_labels) +# return decoded_preds, decoded_labels def get_metrics(conf, tokenizer): @@ -78,16 +79,16 @@ def get_metrics(conf, tokenizer): # metrics in the future for more thorough evaluation metrics, preprocess_fns = [evaluate.load("accuracy")], [default_preprocess] - if any(dataset in QA_DATASETS for dataset in conf.datasets): - raise ValueError("TODO") - metrics.append(evaluate.load("squad_v2")) - preprocess_fns.append(preprocess_qa) - if any(dataset in SUMMARIZATION_DATASETS for dataset in conf.datasets): - raise ValueError("TODO") - metrics.append(evaluate.load("rouge")) - preprocess_fns.append( - partial(preprocess_summarization, tokenizer, ignore_pad_token_for_loss=conf.ignore_pad_token_for_loss) - ) + # if any(dataset in QA_DATASETS for dataset in conf.datasets): + # raise ValueError("TODO") + # metrics.append(evaluate.load("squad_v2")) + # preprocess_fns.append(preprocess_qa) + # if any(dataset in SUMMARIZATION_DATASETS for dataset in conf.datasets): + # raise ValueError("TODO") + # metrics.append(evaluate.load("rouge")) + # preprocess_fns.append( + # partial(preprocess_summarization, tokenizer, ignore_pad_token_for_loss=conf.ignore_pad_token_for_loss) + # ) return metrics, preprocess_fns From 670be60ca80e084c116e13847085ec604068c280 Mon Sep 17 00:00:00 2001 From: theblackcat102 Date: Sat, 14 Jan 2023 12:17:58 +0000 Subject: [PATCH 5/5] [fix] Fix config typo --- model/supervised_finetuning/configs/config.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/model/supervised_finetuning/configs/config.yaml b/model/supervised_finetuning/configs/config.yaml index 0440201a..2eaa6686 100644 --- a/model/supervised_finetuning/configs/config.yaml +++ b/model/supervised_finetuning/configs/config.yaml @@ -28,8 +28,8 @@ defaults: - scitldr - soda - joke - - - - joke + - gsm8k + - samsum cache_dir: .cache loss_fn: CrossEntropyLoss eval_size: @@ -63,10 +63,10 @@ codegen: learning_rate: 8e-6 model_name: Salesforce/codegen-2B-multi weight_decay: 0.01 - max_length: 512 + max_length: 520 warmup_steps: 1000 gradient_checkpointing: false - gradient_accumulation_steps: 10 + gradient_accumulation_steps: 9 per_device_train_batch_size: 2 per_device_eval_batch_size: 4