From 5d9591d82cb2e31272b4335fffbc164bccf68003 Mon Sep 17 00:00:00 2001 From: Sotirios Anagnostidis Date: Thu, 19 Jan 2023 14:25:19 +0100 Subject: [PATCH] added soda dialogue dataset --- .../supervised_finetuning/configs/config.yaml | 1 + .../custom_datasets/__init__.py | 5 +- .../custom_datasets/qa_datasets.py | 62 ++++++++++++++++++- model/supervised_finetuning/requirements.txt | 8 ++- 4 files changed, 70 insertions(+), 6 deletions(-) diff --git a/model/supervised_finetuning/configs/config.yaml b/model/supervised_finetuning/configs/config.yaml index 2eaa6686..815c2e75 100644 --- a/model/supervised_finetuning/configs/config.yaml +++ b/model/supervised_finetuning/configs/config.yaml @@ -30,6 +30,7 @@ defaults: - joke - gsm8k - samsum + - soda_dialogue cache_dir: .cache loss_fn: CrossEntropyLoss eval_size: diff --git a/model/supervised_finetuning/custom_datasets/__init__.py b/model/supervised_finetuning/custom_datasets/__init__.py index e293af3d..3bec37e7 100644 --- a/model/supervised_finetuning/custom_datasets/__init__.py +++ b/model/supervised_finetuning/custom_datasets/__init__.py @@ -1,5 +1,5 @@ from custom_datasets.prompt_dialogue import PromptGeneratedDataset -from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, WebGPT +from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, SODADialogue, WebGPT from custom_datasets.summarization import SummarizationDataset from sklearn.model_selection import train_test_split from torch.utils.data import Subset @@ -36,6 +36,9 @@ def get_one_dataset(conf, dataset_name): elif dataset_name == "soda": dataset = SODA(conf.cache_dir) train, eval = train_val_dataset(dataset, val_split=0.1) + elif dataset_name == "soda_dialogue": + dataset = SODADialogue(conf.cache_dir) + train, eval = train_val_dataset(dataset, val_split=0.1) elif dataset_name == "joke": dataset = JokeExplaination(conf.cache_dir) train, eval = train_val_dataset(dataset, val_split=0.2) diff --git a/model/supervised_finetuning/custom_datasets/qa_datasets.py b/model/supervised_finetuning/custom_datasets/qa_datasets.py index eed9c644..d191c56c 100644 --- a/model/supervised_finetuning/custom_datasets/qa_datasets.py +++ b/model/supervised_finetuning/custom_datasets/qa_datasets.py @@ -106,7 +106,12 @@ class SODA(Dataset): def process_soda_convo(self, data): pairs = [] play_as = data["speakers"][1] - prefix = "{}. {}".format(data["narrative"], "your name {}".format(play_as)) + prefix = "{}{}. {}{}".format( + QA_SPECIAL_TOKENS["StartPrefix"], + data["narrative"], + "your name {}".format(play_as), + QA_SPECIAL_TOKENS["EndPrefix"], + ) question, answer = "", "" prefix, postfix = "", "" previous_chat = [] @@ -119,7 +124,9 @@ class SODA(Dataset): answer = convo postfix = data["speakers"][idx] if len(question) and len(answer) and prefix != postfix and postfix == play_as: - history = "".join(["{}{}".format(*p) for p in previous_chat]) + history = "".join( + ["{}{}{}".format(p[0], QA_SPECIAL_TOKENS["Answer"], p[1]) for p in previous_chat] + ) if len(history): history += "" pairs.append((prefix + history + question, answer)) @@ -148,6 +155,57 @@ class SODA(Dataset): return question, answer +class SODADialogue(Dataset): + url = "https://drive.google.com/uc?id=1TOGQfr419n8wpzJpYLLw4nB3tSKD8zXV" + + def __init__(self, cache_dir, verbose=True): + + path = os.path.join(cache_dir, "soda_dialog.jsonl") + + if not os.path.exists(path): + import gzip + import shutil + + import gdown + + gdown.download(self.url, output=os.path.join(cache_dir, "soda_dialog.jsonl.gz")) + + with gzip.open(os.path.join(cache_dir, "soda_dialog.jsonl.gz"), "rb") as f_in: + with open(path, "wb") as f_out: + shutil.copyfileobj(f_in, f_out) + + self.pairs = [] + faulty = 0 + with open(path) as fin: + for line in fin: + conversation = json.loads(line) + question_answer_pairs = () + + question_answers = conversation["text"].split("User: ") + for question_answer in question_answers[1:]: # first element is empty + try: + question, answer = question_answer.split("\nAssistant: ") + question_answer_pairs += ( + question, + answer, + ) + except ValueError: + # there might be some extra 'User: ' or 'Assistant: ' tokens in the dataset that cause trouble.. + faulty += 1 + continue + + self.pairs.append(question_answer_pairs) + + if verbose: + print("For SODA dialogue dataset found {} faults within the total {} dialogs".format(faulty, len(self))) + + def __len__(self): + return len(self.pairs) + + def __getitem__(self, index): + return self.pairs[index] + + class JokeExplaination(Dataset): """ """ diff --git a/model/supervised_finetuning/requirements.txt b/model/supervised_finetuning/requirements.txt index 0e6eeb51..8f8cc63c 100644 --- a/model/supervised_finetuning/requirements.txt +++ b/model/supervised_finetuning/requirements.txt @@ -3,10 +3,12 @@ bitsandbytes==0.36.0.post2 datasets==2.8.0 deepspeed==0.7.7 evaluate==0.4.0 +gdown mpi4py==3.1.4 nltk==3.8.1 -numpy==1.23.0 -PyYAML==6.0 +numpy>=1.22.4 +py7zr +PyYAML>=6.0 scikit_learn==1.2.0 -torch==1.13.1 +torch>=1.11.0 transformers==4.25.1