"""Summarization datasets exposed as (prompt, summary) pairs for supervised fine-tuning."""
import random
from typing import Iterable, Optional, Tuple

from datasets import load_dataset
from torch.utils.data import Dataset

# "Text" is the prefix put in front of the document (currently empty).
# "Summary" lists the candidate prompts; one is drawn at random per item and
# appended after the (truncated) document.
SUMMARIZATION_SPECIAL_TOKENS = {"Text": "", "Summary": ["TL;DR:", "Summarize this", "Give me the summary"]}

# Extra positional arguments forwarded to `load_dataset` right after the
# dataset name (an empty tuple means none are passed).
summarization_config_mapping = {
    "cnn_dailymail": ("3.0.0",),
    "samsum": (),
    "xsum": (),
    "multi_news": (),
    "scitldr": ("AIC",),
    "billsum": (),
    "reddit": (),
}

# (text column, summary column) to read from each row, per dataset.
summarization_name_mapping = {
    "cnn_dailymail": ("article", "highlights"),
    "samsum": ("dialogue", "summary"),
    "xsum": ("document", "summary"),
    "multi_news": ("document", "summary"),
    "scitldr": ("source", "target"),
    "billsum": ("text", "summary"),
    "reddit": ("content", "summary"),
}


def index_summary_default(text: str, summary: str) -> Tuple[str, str]:
    """Return ``(text, summary)`` with double newlines in ``text`` collapsed to one.

    The summary is passed through untouched.
    """
    return text.replace("\n\n", "\n"), summary


def index_summary_merge(text: Iterable[str], summary: Iterable[str]) -> Tuple[str, str]:
    """Return ``(text, summary)`` with each argument's pieces joined by a space.

    Used for datasets whose columns hold several strings per row rather than
    one string (presumably the case for scitldr -- confirm against the
    dataset card).
    """
    return " ".join(text), " ".join(summary)


class SummarizationDataset(Dataset):
    """Map-style dataset yielding ``(prompted_text, summary)`` string pairs.

    Args:
        dataset: Key into ``summarization_config_mapping`` and
            ``summarization_name_mapping``; also the name handed to
            ``load_dataset``.
        cache_dir: Forwarded to ``load_dataset`` as ``cache_dir``.
        split: Forwarded to ``load_dataset`` as ``split``.
        max_words: Maximum number of space-separated words of the source text
            kept in the prompt. Defaults to 256; ``None`` keeps all words.

    Raises:
        KeyError: If ``dataset`` is not a key of the mapping tables.
    """

    def __init__(self, dataset: str, cache_dir, split, max_words: Optional[int] = 256):
        self.dataset = load_dataset(dataset, *summarization_config_mapping[dataset], cache_dir=cache_dir, split=split)
        self.text_column, self.summary_column = summarization_name_mapping[dataset]
        # The original compared against the misspelling "scitdlr", so the merge
        # preprocessor was never selected; the key used everywhere else in this
        # module is "scitldr".
        self.preprocess_fn = index_summary_merge if dataset == "scitldr" else index_summary_default
        self.max_words = max_words

    def __len__(self) -> int:
        """Return the number of rows in the loaded split."""
        return len(self.dataset)

    def __getitem__(self, idx) -> Tuple[str, str]:
        """Return ``(prompted_text, summary)`` for row ``idx``.

        The source text is preprocessed, cut to the first ``max_words``
        space-separated words, and followed by a randomly chosen summary
        prompt.
        """
        row = self.dataset[idx]
        text, summary = self.preprocess_fn(row[self.text_column], row[self.summary_column])
        prompt = random.choice(SUMMARIZATION_SPECIAL_TOKENS["Summary"])

        # Split on a literal space (not arbitrary whitespace) so newlines inside
        # the kept words survive the truncation.
        truncated = " ".join(text.split(" ")[: self.max_words])

        # NOTE(review): the prompt is appended with no separator after the text;
        # kept as-is so the emitted strings do not change.
        return "".join([SUMMARIZATION_SPECIAL_TOKENS["Text"], truncated, prompt]), summary