diff --git a/model/supervised_finetuning/configs/config.yaml b/model/supervised_finetuning/configs/config.yaml
index 2eaa6686..815c2e75 100644
--- a/model/supervised_finetuning/configs/config.yaml
+++ b/model/supervised_finetuning/configs/config.yaml
@@ -30,6 +30,7 @@ defaults:
- joke
- gsm8k
- samsum
+ - soda_dialogue
cache_dir: .cache
loss_fn: CrossEntropyLoss
eval_size:
diff --git a/model/supervised_finetuning/custom_datasets/__init__.py b/model/supervised_finetuning/custom_datasets/__init__.py
index e293af3d..3bec37e7 100644
--- a/model/supervised_finetuning/custom_datasets/__init__.py
+++ b/model/supervised_finetuning/custom_datasets/__init__.py
@@ -1,5 +1,5 @@
from custom_datasets.prompt_dialogue import PromptGeneratedDataset
-from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, WebGPT
+from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, SODADialogue, WebGPT
from custom_datasets.summarization import SummarizationDataset
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
@@ -36,6 +36,9 @@ def get_one_dataset(conf, dataset_name):
elif dataset_name == "soda":
dataset = SODA(conf.cache_dir)
train, eval = train_val_dataset(dataset, val_split=0.1)
+ elif dataset_name == "soda_dialogue":
+ dataset = SODADialogue(conf.cache_dir)
+ train, eval = train_val_dataset(dataset, val_split=0.1)
elif dataset_name == "joke":
dataset = JokeExplaination(conf.cache_dir)
train, eval = train_val_dataset(dataset, val_split=0.2)
diff --git a/model/supervised_finetuning/custom_datasets/qa_datasets.py b/model/supervised_finetuning/custom_datasets/qa_datasets.py
index eed9c644..d191c56c 100644
--- a/model/supervised_finetuning/custom_datasets/qa_datasets.py
+++ b/model/supervised_finetuning/custom_datasets/qa_datasets.py
@@ -106,7 +106,12 @@ class SODA(Dataset):
def process_soda_convo(self, data):
pairs = []
play_as = data["speakers"][1]
- prefix = "{}. {}".format(data["narrative"], "your name {}".format(play_as))
+ prefix = "{}{}. {}{}".format(
+ QA_SPECIAL_TOKENS["StartPrefix"],
+ data["narrative"],
+ "your name {}".format(play_as),
+ QA_SPECIAL_TOKENS["EndPrefix"],
+ )
question, answer = "", ""
prefix, postfix = "", ""
previous_chat = []
@@ -119,7 +124,9 @@ class SODA(Dataset):
answer = convo
postfix = data["speakers"][idx]
if len(question) and len(answer) and prefix != postfix and postfix == play_as:
- history = "".join(["{}{}".format(*p) for p in previous_chat])
+ history = "".join(
+ ["{}{}{}".format(p[0], QA_SPECIAL_TOKENS["Answer"], p[1]) for p in previous_chat]
+ )
if len(history):
history += ""
pairs.append((prefix + history + question, answer))
@@ -148,6 +155,57 @@ class SODA(Dataset):
return question, answer
+class SODADialogue(Dataset):
+ url = "https://drive.google.com/uc?id=1TOGQfr419n8wpzJpYLLw4nB3tSKD8zXV"
+
+ def __init__(self, cache_dir, verbose=True):
+
+ path = os.path.join(cache_dir, "soda_dialog.jsonl")
+
+ if not os.path.exists(path):
+ import gzip
+ import shutil
+
+ import gdown
+
+ gdown.download(self.url, output=os.path.join(cache_dir, "soda_dialog.jsonl.gz"))
+
+ with gzip.open(os.path.join(cache_dir, "soda_dialog.jsonl.gz"), "rb") as f_in:
+ with open(path, "wb") as f_out:
+ shutil.copyfileobj(f_in, f_out)
+
+ self.pairs = []
+ faulty = 0
+ with open(path) as fin:
+ for line in fin:
+ conversation = json.loads(line)
+ question_answer_pairs = ()
+
+ question_answers = conversation["text"].split("User: ")
+ for question_answer in question_answers[1:]: # first element is empty
+ try:
+ question, answer = question_answer.split("\nAssistant: ")
+ question_answer_pairs += (
+ question,
+ answer,
+ )
+ except ValueError:
+ # there might be some extra 'User: ' or 'Assistant: ' tokens in the dataset that cause trouble..
+ faulty += 1
+ continue
+
+ self.pairs.append(question_answer_pairs)
+
+ if verbose:
+ print("For SODA dialogue dataset found {} faults within the total {} dialogs".format(faulty, len(self)))
+
+ def __len__(self):
+ return len(self.pairs)
+
+ def __getitem__(self, index):
+ return self.pairs[index]
+
+
class JokeExplaination(Dataset):
""" """
diff --git a/model/supervised_finetuning/requirements.txt b/model/supervised_finetuning/requirements.txt
index 0e6eeb51..8f8cc63c 100644
--- a/model/supervised_finetuning/requirements.txt
+++ b/model/supervised_finetuning/requirements.txt
@@ -3,10 +3,12 @@ bitsandbytes==0.36.0.post2
datasets==2.8.0
deepspeed==0.7.7
evaluate==0.4.0
+gdown
mpi4py==3.1.4
nltk==3.8.1
-numpy==1.23.0
-PyYAML==6.0
+numpy>=1.22.4
+py7zr
+PyYAML>=6.0
scikit_learn==1.2.0
-torch==1.13.1
+torch>=1.11.0
transformers==4.25.1