[fix] @ekurtulus major logic bug in summarization

This commit is contained in:
theblackcat102
2023-01-14 03:49:19 +00:00
parent 8656d9c9b7
commit 9451aff6cc
@@ -0,0 +1,57 @@
import random
from datasets import load_dataset
from torch.utils.data import Dataset
SUMMARIZATION_SPECIAL_TOKENS = {"Text": "", "Summary": ["TL;DR:", "Summarize this", "Give me the summary"]}
summarization_config_mapping = {
"cnn_dailymail": ("3.0.0",),
"samsum": (),
"xsum": (),
"multi_news": (),
"scitldr": ("AIC",),
"billsum": (),
"reddit": (),
}
summarization_name_mapping = {
"cnn_dailymail": ("article", "highlights"),
"samsum": ("dialogue", "summary"),
"xsum": ("document", "summary"),
"multi_news": ("document", "summary"),
"scitldr": ("source", "target"),
"billsum": ("text", "summary"),
"reddit": ("content", "summary"),
}
def index_summary_default(text, summary):
return text.replace('\n\n', '\n'), summary
def index_summary_merge(text, summary):
return " ".join(text), " ".join(summary)
class SummarizationDataset(Dataset):
def __init__(self, dataset, cache_dir, split):
self.dataset = load_dataset(dataset, *summarization_config_mapping[dataset], cache_dir=cache_dir, split=split)
self.text_column, self.summary_column = summarization_name_mapping[dataset]
self.preprocess_fn = index_summary_merge if dataset == "scitdlr" else index_summary_default
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
data = self.dataset[idx]
text, summary = data[self.text_column], data[self.summary_column]
text, summary = self.preprocess_fn(text, summary)
prompt = random.choice(SUMMARIZATION_SPECIAL_TOKENS["Summary"])
return (
"".join([SUMMARIZATION_SPECIAL_TOKENS["Text"], ' '.join(text.split(' ')[:256]), prompt]),
summary
)