diff --git a/model/supervised_finetuning/README.md b/model/supervised_finetuning/README.md index bd202397..822121d8 100644 --- a/model/supervised_finetuning/README.md +++ b/model/supervised_finetuning/README.md @@ -58,6 +58,8 @@ the end to trigger deepspeed python trainer.py --configs defaults your-model-name --deepspeed ``` +## Dataset choices + ## Results Experimental results in wandb diff --git a/model/supervised_finetuning/configs/config.yaml b/model/supervised_finetuning/configs/config.yaml index bd35f168..2eaa6686 100644 --- a/model/supervised_finetuning/configs/config.yaml +++ b/model/supervised_finetuning/configs/config.yaml @@ -6,7 +6,7 @@ defaults: per_device_eval_batch_size: 2 weight_decay: 0.00 warmup_steps: 600 - eval_steps: 100 + eval_steps: 500 save_steps: 500 max_length: 512 num_train_epochs: 3 @@ -18,7 +18,19 @@ defaults: datasets: - webgpt - prompt_dialogue - cache_dir: ~/.cache + - squad_v2 + - adversarial_qa + - trivia_qa_nocontext + - xsum + - cnn_dailymail + - prompt_dialogue + - multi_news + - scitldr + - soda + - joke + - gsm8k + - samsum + cache_dir: .cache loss_fn: CrossEntropyLoss eval_size: log_dir: "base" @@ -48,14 +60,14 @@ gpt-jt: per_device_eval_batch_size: 4 codegen: - learning_rate: 2e-6 + learning_rate: 8e-6 model_name: Salesforce/codegen-2B-multi weight_decay: 0.01 - max_length: 812 - warmup_steps: 600 + max_length: 520 + warmup_steps: 1000 gradient_checkpointing: false - gradient_accumulation_steps: 5 - per_device_train_batch_size: 4 + gradient_accumulation_steps: 9 + per_device_train_batch_size: 2 per_device_eval_batch_size: 4 debug: diff --git a/model/supervised_finetuning/custom_datasets/__init__.py b/model/supervised_finetuning/custom_datasets/__init__.py index c0cd424b..e293af3d 100644 --- a/model/supervised_finetuning/custom_datasets/__init__.py +++ b/model/supervised_finetuning/custom_datasets/__init__.py @@ -1,136 +1,11 @@ -import numpy as np -from datasets import load_dataset +from custom_datasets.prompt_dialogue import PromptGeneratedDataset +from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, WebGPT +from custom_datasets.summarization import SummarizationDataset from sklearn.model_selection import train_test_split -from torch.utils.data import Dataset, Subset +from torch.utils.data import Subset -from .prompt_dialogue import PromptGeneratedDataset - -QA_SPECIAL_TOKENS = {"Question": "", "Answer": ""} -SUMMARIZATION_SPECIAL_TOKENS = {"Text": "", "Summary": "TL;DR:"} - -summarization_name_mapping = { - "cnn_dailymail": ("article", "highlights"), - "samsum": ("dialogue", "summary"), - "xsum": ("document", "summary"), - "multi_news": ("document", "summary"), - "scitldr": ("source", "target"), - "billsum": ("text", "summary"), - "reddit": ("content", "summary"), -} -summarization_config_mapping = { - "cnn_dailymail": ("3.0.0",), - "samsum": (), - "xsum": (), - "multi_news": (), - "scitldr": ("AIC",), - "billsum": (), - "reddit": (), -} - -QA_DATASETS = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_noconext"] -SUMMARIZATION_DATASETS = ["xsum", "cnn_dailymail", "samsum", "multi_news"] - - -def index_squad_v2(example): - return example["title"] + ". " + example["context"] + " " + example["question"], example["answers"]["text"][0] - - -def index_trivia_qa_nocontext(example): - # dummy return one randomly - return example["question"], example["answer"]["aliases"][np.random.randint(len(example["answer"]["aliases"]))] - - -def index_trivia_qa_context(example): - question = example["question"] - title = example["title"][np.random.randint(len(example["title"]))] - context = example["search_context"][np.random.randint(len(example["search_context"]))] - answer = example["answer"]["aliases"][np.random.randint(len(example["answer"]["aliases"]))] - - return title + ". " + context + " " + question, answer - - -def index_adversarial_qa(example): - return example["title"] + ". " + example["context"] + " " + example["question"], example["answers"]["text"][0] - - -class QADataset(Dataset): - def __init__(self, dataset, cache_dir, split): - if dataset == "squad_v2": - self.index_fn = index_squad_v2 - self.dataset = load_dataset("squad_v2", cache_dir=cache_dir, split=split) - elif dataset == "trivia_qa_nocontext": - self.index_fn = index_trivia_qa_nocontext - self.dataset = load_dataset("trivia_qa", "rc.nocontext") - elif dataset == "trivia_qa_context": - self.index_fn = index_trivia_qa_context - self.dataset = load_dataset("trivia_qa", "rc") - elif dataset == "adversarial_qa": - self.index_fn = index_adversarial_qa - self.dataset = load_dataset("adversarial_qa", "adversarialQA") - else: - raise ValueError("Unknown dataset : " + dataset) - - def __len__(self): - return len(self.dataset) - - def __getitem__(self, idx): - data = self.dataset[idx] - return self.index_fn(data) - - -def index_summary_default(text, summary): - return text, summary - - -def index_summary_merge(text, summary): - return " ".join(text), " ".join(summary) - - -class SummarizationDataset(Dataset): - def __init__(self, dataset, cache_dir, split): - self.dataset = load_dataset(dataset, *summarization_config_mapping[dataset], cache_dir=cache_dir, split=split) - self.summary_column, self.text_column = summarization_name_mapping[dataset] - self.preprocess_fn = index_summary_merge if dataset == "scitdlr" else index_summary_merge - - def __len__(self): - return len(self.dataset) - - def __getitem__(self, idx): - data = self.dataset[idx] - text, summary = data[self.text_column], data[self.summary_column] - text, summary = self.preprocess_fn(text, summary) - - return "".join( - SUMMARIZATION_SPECIAL_TOKENS["Text"], text, " ", SUMMARIZATION_SPECIAL_TOKENS["Summary"], summary - ) - - -class WebGPT(Dataset): - def __init__(self) -> None: - super().__init__() - - dataset = load_dataset("openai/webgpt_comparisons") - questions = {} - # using prompt as our index will allows us - # to add additional generated prompt later - self.index2question = {} - for row in dataset["train"]: - question = row["question"]["full_text"] - if question not in self.index2question: - self.index2question[len(self.index2question)] = question - - # only keep the best answer - questions[question] = row["answer_0" if row["score_0"] > row["score_1"] else "answer_1"] - - self.questions = questions - - def __len__(self): - return len(self.index2question) - - def __getitem__(self, index): - question = self.index2question[index] - answer = self.questions[question] - return [question, answer] +QA_DATASETS = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_nocontext", "gsm8k"] +SUMMARIZATION_DATASETS = ["xsum", "cnn_dailymail", "samsum", "multi_news", "scitldr", "billsum"] def train_val_dataset(dataset, val_split=0.2): @@ -143,19 +18,26 @@ def train_val_dataset(dataset, val_split=0.2): def get_one_dataset(conf, dataset_name): dataset_name = dataset_name.lower() - if dataset_name in ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_noconext"]: + if dataset_name in QA_DATASETS: train = QADataset(dataset_name, conf.cache_dir, "train") - eval = QADataset(dataset_name, conf.cache_dir, "validation") + val_name = "validation" if dataset_name not in ["gsm8k"] else "test" + eval = QADataset(dataset_name, conf.cache_dir, val_name) - elif dataset_name in ["xsum", "cnn_dailymail", "samsum", "multi_news", "scitldr", "billsum", "reddit"]: + elif dataset_name in SUMMARIZATION_DATASETS: train = SummarizationDataset(dataset_name, conf.cache_dir, "train") - eval = SummarizationDataset(dataset_name, conf.cache_dir, "validation") - + val_name = "validation" if dataset_name not in ["billsum"] else "test" + eval = SummarizationDataset(dataset_name, conf.cache_dir, val_name) elif dataset_name == "webgpt": dataset = WebGPT() train, eval = train_val_dataset(dataset, val_split=0.2) elif dataset_name == "prompt_dialogue": - dataset = PromptGeneratedDataset() + dataset = PromptGeneratedDataset(conf.cache_dir) + train, eval = train_val_dataset(dataset, val_split=0.2) + elif dataset_name == "soda": + dataset = SODA(conf.cache_dir) + train, eval = train_val_dataset(dataset, val_split=0.1) + elif dataset_name == "joke": + dataset = JokeExplaination(conf.cache_dir) train, eval = train_val_dataset(dataset, val_split=0.2) else: raise ValueError(f"Unknown dataset {dataset_name}") diff --git a/model/supervised_finetuning/custom_datasets/dialogue_collator.py b/model/supervised_finetuning/custom_datasets/dialogue_collator.py index 479931f6..2efe160f 100644 --- a/model/supervised_finetuning/custom_datasets/dialogue_collator.py +++ b/model/supervised_finetuning/custom_datasets/dialogue_collator.py @@ -3,11 +3,10 @@ from typing import Optional, Union import numpy as np import torch +from custom_datasets.qa_datasets import QA_SPECIAL_TOKENS from torch.nn import functional as F from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase -from . import QA_SPECIAL_TOKENS - @dataclass class DialogueDataCollator: @@ -35,7 +34,7 @@ class DialogueDataCollator: # Add a way for the model to terminate generation # When we predict the start of a new expected question, we want to be able to stop generation - messages.append(QA_SPECIAL_TOKENS["Question"]) + messages.append(self.tokenizer.eos_token) flatten_message = self.tokenizer( "".join(messages), diff --git a/model/supervised_finetuning/custom_datasets/prompt_dialogue.py b/model/supervised_finetuning/custom_datasets/prompt_dialogue.py index 17911141..372ea27f 100644 --- a/model/supervised_finetuning/custom_datasets/prompt_dialogue.py +++ b/model/supervised_finetuning/custom_datasets/prompt_dialogue.py @@ -16,10 +16,10 @@ class PromptGeneratedDataset(Dataset): url = "https://github.com/Rallio67/language-model-agents/raw/main/chat_dialogue_v2_c.txt" - def __init__(self) -> None: + def __init__(self, cache_dir) -> None: super().__init__() - os.makedirs("datasets", exist_ok=True) - chat_dialogue = os.path.join("datasets", "chat_dialogue_v2_c.txt") + os.makedirs(cache_dir, exist_ok=True) + chat_dialogue = os.path.join(cache_dir, "chat_dialogue_v2_c.txt") if not os.path.exists(chat_dialogue): with urlopen(self.url) as file: content = file.read().decode() @@ -49,18 +49,3 @@ class PromptGeneratedDataset(Dataset): def __getitem__(self, index): question, answer = self.pairs[index] return question, answer - - -if __name__ == "__main__": - from torch.utils.data import DataLoader - from transformers import AutoTokenizer - - from .dialogue_collator import DialogueDataCollator - - tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-2B-multi") - tokenizer.add_special_tokens({"pad_token": "<|endoftext|>", "sep_token": "<|endoftext|>"}) - dataset = PromptGeneratedDataset() - collate_fn = DialogueDataCollator(tokenizer, padding=True, max_length=128) - dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=5) - for batch in dataloader: - print(batch["input_ids"].shape) diff --git a/model/supervised_finetuning/custom_datasets/qa_datasets.py b/model/supervised_finetuning/custom_datasets/qa_datasets.py new file mode 100644 index 00000000..eed9c644 --- /dev/null +++ b/model/supervised_finetuning/custom_datasets/qa_datasets.py @@ -0,0 +1,184 @@ +import json +import os +from urllib.request import urlopen + +import numpy as np +from datasets import load_dataset +from torch.utils.data import Dataset + +QA_SPECIAL_TOKENS = {"Question": "", "Answer": "", "StartPrefix": "", "EndPrefix": ""} + + +def index_squad_v2(example): + if len(example["answers"]["text"]): + answer = example["answers"]["text"][0] + else: + answer = "I do not have answer for that" + return example["context"] + " " + example["question"], answer + + +def index_trivia_qa_nocontext(example): + # dummy return one randomly + return example["question"], example["answer"]["aliases"][np.random.randint(len(example["answer"]["aliases"]))] + + +def index_trivia_qa_context(example): + question = example["question"] + if len(example["search_results"]["search_context"]): + context = example["search_results"]["search_context"][ + np.random.randint(len(example["search_results"]["search_context"])) + ] + else: + context = "" + answer = example["answer"]["aliases"][np.random.randint(len(example["answer"]["aliases"]))] + + return context + " " + question, answer + + +def index_adversarial_qa(example): + return example["title"] + ". " + example["context"] + " " + example["question"], example["answers"]["text"][0] + + +def index_gsm8k(example): + return example["question"], example["answer"] + + +class QADataset(Dataset): + def __init__(self, dataset, cache_dir, split): + if dataset == "squad_v2": + self.index_fn = index_squad_v2 + self.dataset = load_dataset("squad_v2", cache_dir=cache_dir, split=split) + elif dataset == "trivia_qa_nocontext": + self.index_fn = index_trivia_qa_nocontext + self.dataset = load_dataset("trivia_qa", "rc.nocontext", split=split, cache_dir=cache_dir) + elif dataset == "trivia_qa_context": + self.index_fn = index_trivia_qa_context + self.dataset = load_dataset("trivia_qa", "rc", split=split, cache_dir=cache_dir) + elif dataset == "adversarial_qa": + self.index_fn = index_adversarial_qa + self.dataset = load_dataset("adversarial_qa", "adversarialQA", split=split, cache_dir=cache_dir) + elif dataset == "gsm8k": + self.index_fn = index_gsm8k + self.dataset = load_dataset("gsm8k", "main", split=split, cache_dir=cache_dir) + elif dataset == "adversarial_qa": + self.index_fn = index_adversarial_qa + self.dataset = load_dataset("adversarial_qa", "adversarialQA", split=split, cache_dir=cache_dir) + else: + raise ValueError("Unknown dataset : " + dataset) + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + data = self.dataset[idx] + return self.index_fn(data) + + +class WebGPT(Dataset): + def __init__(self) -> None: + super().__init__() + + dataset = load_dataset("openai/webgpt_comparisons") + questions = {} + # using prompt as our index will allows us + # to add additional generated prompt later + self.index2question = {} + for row in dataset["train"]: + question = row["question"]["full_text"] + if question not in self.index2question: + self.index2question[len(self.index2question)] = question + + # only keep the best answer + questions[question] = row["answer_0" if row["score_0"] > row["score_1"] else "answer_1"] + + self.questions = questions + + def __len__(self): + return len(self.index2question) + + def __getitem__(self, index): + question = self.index2question[index] + answer = self.questions[question] + return [question, answer] + + +class SODA(Dataset): + def process_soda_convo(self, data): + pairs = [] + play_as = data["speakers"][1] + prefix = "{}. {}".format(data["narrative"], "your name {}".format(play_as)) + question, answer = "", "" + prefix, postfix = "", "" + previous_chat = [] + + for idx, convo in enumerate(data["dialogue"]): + if idx % 2 == 0: + question = convo + prefix = data["speakers"][idx] + else: + answer = convo + postfix = data["speakers"][idx] + if len(question) and len(answer) and prefix != postfix and postfix == play_as: + history = "".join(["{}{}".format(*p) for p in previous_chat]) + if len(history): + history += "" + pairs.append((prefix + history + question, answer)) + previous_chat.append((question, answer)) + return pairs + + def __init__(self, cache_dir, max_sample_size=10000, input_max_length=1024) -> None: + super().__init__() + + self.pairs = [] + dataset = load_dataset("allenai/soda", cache_dir=cache_dir)["train"] + for data in dataset: + data_pair = self.process_soda_convo(data) + for (prompt, answer) in data_pair: + if len(prompt) < input_max_length: + self.pairs.append((prompt, answer)) + + if len(self.pairs) > max_sample_size: + break + + def __len__(self): + return len(self.pairs) + + def __getitem__(self, index): + question, answer = self.pairs[index] + return question, answer + + +class JokeExplaination(Dataset): + """ """ + + url = "https://gist.github.com/theblackcat102/42b697e24a13fdb499e20edfbf618361/raw/1834dca207898c15f93b809d1195f6f6e47c9e1e/joke_explained.jsonl" + + def __init__(self, cache_dir) -> None: + super().__init__() + os.makedirs(cache_dir, exist_ok=True) + joke_explain_filename = os.path.join(cache_dir, "joke_explaination.jsonl") + if not os.path.exists(joke_explain_filename): + with urlopen(self.url) as file: + content = file.read().decode() + with open(joke_explain_filename, "w") as fout: + fout.write(content) + + question = "" + answer = "" + self.pairs = [] + with open(joke_explain_filename, "r") as f: + for line in f: + data = json.loads(line) + joke = data["joke"] + explanation = data["explaination"] + self.pairs.append((joke, explanation)) + + if len(question) > 0 and len(answer) > 0: + self.pairs.append((question, answer)) + + def __len__(self): + return len(self.pairs) + + def __getitem__(self, index): + question, answer = self.pairs[index] + return question, answer diff --git a/model/supervised_finetuning/custom_datasets/summarization.py b/model/supervised_finetuning/custom_datasets/summarization.py new file mode 100644 index 00000000..69e4b51d --- /dev/null +++ b/model/supervised_finetuning/custom_datasets/summarization.py @@ -0,0 +1,62 @@ +import random + +from datasets import load_dataset +from torch.utils.data import Dataset + +SUMMARIZATION_SPECIAL_TOKENS = {"Text": "", "Summary": ["TL;DR:", "Summarize this", "Give me the summary"]} + +SUMMARY_SPECIAL_PROMPT = { + "multi_news": ["Summarize in bullet points", "Generate summary in list of points"], + "xsum": ["Give me summary in one sentence", "Short TLDR", "Give me a concise summary"], + "samsum": ["TLDR;", "Summarize this dialogue", "Summarize dialogue"], +} + +summarization_config_mapping = { + "cnn_dailymail": ("3.0.0",), + "samsum": (), + "xsum": (), + "multi_news": (), + "scitldr": ("AIC",), + "billsum": (), + "reddit": (), +} + +summarization_name_mapping = { + "cnn_dailymail": ("article", "highlights"), + "samsum": ("dialogue", "summary"), + "xsum": ("document", "summary"), + "multi_news": ("document", "summary"), + "scitldr": ("source", "target"), + "billsum": ("text", "summary"), + "reddit": ("content", "summary"), +} + + +def index_summary_default(text, summary): + return text.replace("\n\n", "\n"), summary + + +def index_summary_merge(text, summary): + return " ".join(text), " ".join(summary) + + +class SummarizationDataset(Dataset): + def __init__(self, dataset, cache_dir, split): + self.name = dataset + self.dataset = load_dataset(dataset, *summarization_config_mapping[dataset], cache_dir=cache_dir, split=split) + self.text_column, self.summary_column = summarization_name_mapping[dataset] + self.preprocess_fn = index_summary_merge if dataset == "scitldr" else index_summary_default + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + data = self.dataset[idx] + text, summary = data[self.text_column], data[self.summary_column] + text, summary = self.preprocess_fn(text, summary) + if self.name in SUMMARY_SPECIAL_PROMPT: + prompt = random.choice(SUMMARIZATION_SPECIAL_TOKENS["Summary"]) + else: + prompt = random.choice(SUMMARIZATION_SPECIAL_TOKENS["Summary"]) + + return ("".join([SUMMARIZATION_SPECIAL_TOKENS["Text"], " ".join(text.split(" ")[:256]), prompt]), summary) diff --git a/model/supervised_finetuning/tests/test_datasets.py b/model/supervised_finetuning/tests/test_datasets.py new file mode 100644 index 00000000..c9363303 --- /dev/null +++ b/model/supervised_finetuning/tests/test_datasets.py @@ -0,0 +1,54 @@ +from argparse import Namespace + +from custom_datasets import QA_DATASETS, SUMMARIZATION_DATASETS, get_one_dataset +from custom_datasets.dialogue_collator import DialogueDataCollator + + +def test_all_datasets(): + qa_base = QA_DATASETS + summarize_base = SUMMARIZATION_DATASETS + others = ["prompt_dialogue", "webgpt", "soda", "joke"] + + config = Namespace(cache_dir=".cache") + for dataset_name in others + qa_base + summarize_base: + print(dataset_name) + train, eval = get_one_dataset(config, dataset_name) + # sanity check + for idx in range(min(len(train), 1000)): + train[idx] + for idx in range(min(len(eval), 1000)): + eval[idx] + + +def test_collate_fn(): + from torch.utils.data import ConcatDataset, DataLoader + from utils import get_tokenizer + + config = Namespace(cache_dir=".cache", model_name="Salesforce/codegen-2B-multi") + tokenizer = get_tokenizer(config) + collate_fn = DialogueDataCollator(tokenizer, max_length=512) + qa_base = QA_DATASETS + summarize_base = SUMMARIZATION_DATASETS + others = ["prompt_dialogue", "webgpt", "soda", "joke", "gsm8k"] + + trains, evals = [], [] + for dataset_name in others + qa_base + summarize_base: + print(dataset_name) + train, eval = get_one_dataset(config, dataset_name) + trains.append(train) + evals.append(eval) + + dataloader = DataLoader(ConcatDataset(trains), collate_fn=collate_fn, batch_size=128) + for batch in dataloader: + # print(batch.keys()) + # print(tokenizer.decode(batch['input_ids'][0])) + # print('-----') + # print(tokenizer.decode(batch['targets'][0][batch['label_masks'][0]])) + assert batch["targets"].shape[1] <= 512 + dataloader = DataLoader(ConcatDataset(evals), collate_fn=collate_fn, batch_size=128) + for batch in dataloader: + assert batch["targets"].shape[1] <= 512 + + +if __name__ == "__main__": + test_collate_fn() diff --git a/model/supervised_finetuning/tests/test_utils.py b/model/supervised_finetuning/tests/test_utils.py new file mode 100644 index 00000000..ad40e534 --- /dev/null +++ b/model/supervised_finetuning/tests/test_utils.py @@ -0,0 +1,9 @@ +from argparse import Namespace + +from utils import get_tokenizer + + +def test_tokenizer(): + get_tokenizer(Namespace(model_name="Salesforce/codegen-2B-multi", cache_dir=".cache")) + get_tokenizer(Namespace(model_name="facebook/galactica-1.3b", cache_dir=".cache")) + get_tokenizer(Namespace(model_name="", cache_dir=".cache")) diff --git a/model/supervised_finetuning/utils.py b/model/supervised_finetuning/utils.py index f598dde1..7b6e03b6 100644 --- a/model/supervised_finetuning/utils.py +++ b/model/supervised_finetuning/utils.py @@ -1,13 +1,15 @@ -from functools import partial +# from functools import partial from pathlib import Path import evaluate -import nltk -import numpy as np + +# import nltk +# import numpy as np import transformers import yaml -from custom_datasets import QA_DATASETS, QA_SPECIAL_TOKENS, SUMMARIZATION_DATASETS, get_one_dataset +from custom_datasets import get_one_dataset from custom_datasets.dialogue_collator import DialogueDataCollator +from custom_datasets.qa_datasets import QA_SPECIAL_TOKENS from losses import CrossEntropyLoss, PolyLoss from models import freeze_top_n_layers, get_specific_model from sklearn.model_selection import train_test_split @@ -51,25 +53,25 @@ def preprocess_qa(eval_pred): return (eval_pred.predictions, eval_pred.label_ids) -def postprocess_summarization(preds, labels): - preds = [pred.strip() for pred in preds] - labels = [label.strip() for label in labels] +# def postprocess_summarization(preds, labels): +# preds = [pred.strip() for pred in preds] +# labels = [label.strip() for label in labels] - preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] - labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] +# preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] +# labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] - return preds, labels +# return preds, labels -def preprocess_summarization(eval_pred, tokenizer, ignore_pad_token_for_loss=True): - preds, labels = eval_pred - decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) - if ignore_pad_token_for_loss: - labels = np.where(labels != -100, labels, tokenizer.pad_token_id) - decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) +# def preprocess_summarization(eval_pred, tokenizer, ignore_pad_token_for_loss=True): +# preds, labels = eval_pred +# decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) +# if ignore_pad_token_for_loss: +# labels = np.where(labels != -100, labels, tokenizer.pad_token_id) +# decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) - decoded_preds, decoded_labels = postprocess_summarization(decoded_preds, decoded_labels) - return decoded_preds, decoded_labels +# decoded_preds, decoded_labels = postprocess_summarization(decoded_preds, decoded_labels) +# return decoded_preds, decoded_labels def get_metrics(conf, tokenizer): @@ -77,16 +79,16 @@ def get_metrics(conf, tokenizer): # metrics in the future for more thorough evaluation metrics, preprocess_fns = [evaluate.load("accuracy")], [default_preprocess] - if any(dataset in QA_DATASETS for dataset in conf.datasets): - raise ValueError("TODO") - metrics.append(evaluate.load("squad_v2")) - preprocess_fns.append(preprocess_qa) - if any(dataset in SUMMARIZATION_DATASETS for dataset in conf.datasets): - raise ValueError("TODO") - metrics.append(evaluate.load("rouge")) - preprocess_fns.append( - partial(preprocess_summarization, tokenizer, ignore_pad_token_for_loss=conf.ignore_pad_token_for_loss) - ) + # if any(dataset in QA_DATASETS for dataset in conf.datasets): + # raise ValueError("TODO") + # metrics.append(evaluate.load("squad_v2")) + # preprocess_fns.append(preprocess_qa) + # if any(dataset in SUMMARIZATION_DATASETS for dataset in conf.datasets): + # raise ValueError("TODO") + # metrics.append(evaluate.load("rouge")) + # preprocess_fns.append( + # partial(preprocess_summarization, tokenizer, ignore_pad_token_for_loss=conf.ignore_pad_token_for_loss) + # ) return metrics, preprocess_fns