Add SODA dialogue dataset

This commit is contained in:
Sotirios Anagnostidis
2023-01-19 14:25:19 +01:00
parent ef8a00e682
commit 5d9591d82c
4 changed files with 70 additions and 6 deletions
@@ -30,6 +30,7 @@ defaults:
- joke
- gsm8k
- samsum
- soda_dialogue
cache_dir: .cache
loss_fn: CrossEntropyLoss
eval_size:
@@ -1,5 +1,5 @@
from custom_datasets.prompt_dialogue import PromptGeneratedDataset
from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, WebGPT
from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, SODADialogue, WebGPT
from custom_datasets.summarization import SummarizationDataset
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
@@ -36,6 +36,9 @@ def get_one_dataset(conf, dataset_name):
elif dataset_name == "soda":
dataset = SODA(conf.cache_dir)
train, eval = train_val_dataset(dataset, val_split=0.1)
elif dataset_name == "soda_dialogue":
dataset = SODADialogue(conf.cache_dir)
train, eval = train_val_dataset(dataset, val_split=0.1)
elif dataset_name == "joke":
dataset = JokeExplaination(conf.cache_dir)
train, eval = train_val_dataset(dataset, val_split=0.2)
@@ -106,7 +106,12 @@ class SODA(Dataset):
def process_soda_convo(self, data):
pairs = []
play_as = data["speakers"][1]
prefix = "<prefix>{}. {}</prefix>".format(data["narrative"], "your name {}".format(play_as))
prefix = "{}{}. {}{}".format(
QA_SPECIAL_TOKENS["StartPrefix"],
data["narrative"],
"your name {}".format(play_as),
QA_SPECIAL_TOKENS["EndPrefix"],
)
question, answer = "", ""
prefix, postfix = "", ""
previous_chat = []
@@ -119,7 +124,9 @@ class SODA(Dataset):
answer = convo
postfix = data["speakers"][idx]
if len(question) and len(answer) and prefix != postfix and postfix == play_as:
history = "<sep>".join(["{}<bot>{}".format(*p) for p in previous_chat])
history = "<sep>".join(
["{}{}{}".format(p[0], QA_SPECIAL_TOKENS["Answer"], p[1]) for p in previous_chat]
)
if len(history):
history += "<sep>"
pairs.append((prefix + history + question, answer))
@@ -148,6 +155,57 @@ class SODA(Dataset):
return question, answer
class SODADialogue(Dataset):
    """Multi-turn SODA dialogues, one flat tuple of alternating turns per item.

    Each line of ``soda_dialog.jsonl`` is a JSON object whose ``"text"`` field
    holds a whole conversation written as ``User: ...\\nAssistant: ...`` turns.
    Every conversation becomes a single tuple
    ``(question_1, answer_1, question_2, answer_2, ...)``.

    Args:
        cache_dir: Directory that holds (or will receive) ``soda_dialog.jsonl``.
            When the file is missing, the gzipped archive is downloaded from
            ``url`` with ``gdown`` and decompressed into this directory.
        verbose: When true, print how many malformed turns were skipped.
    """

    url = "https://drive.google.com/uc?id=1TOGQfr419n8wpzJpYLLw4nB3tSKD8zXV"

    def __init__(self, cache_dir, verbose=True):
        path = os.path.join(cache_dir, "soda_dialog.jsonl")
        if not os.path.exists(path):
            self._download(cache_dir, path)

        self.pairs = []
        faulty = 0
        # Explicit encoding: the platform default is not UTF-8 everywhere.
        with open(path, encoding="utf-8") as fin:
            for line in fin:
                if not line.strip():
                    # A blank (e.g. trailing) line is not a JSON document.
                    continue
                conversation = json.loads(line)
                turns = []
                # Whatever precedes the first "User: " marker is discarded.
                for question_answer in conversation["text"].split("User: ")[1:]:
                    try:
                        question, answer = question_answer.split("\nAssistant: ")
                    except ValueError:
                        # Extra 'User: ' or 'Assistant: ' markers inside the
                        # text make the split ambiguous; skip only that turn.
                        faulty += 1
                        continue
                    turns.extend((question, answer))
                if turns:
                    # A conversation with no usable turn would otherwise be
                    # stored as an empty tuple, i.e. a sample with no content.
                    self.pairs.append(tuple(turns))

        if verbose:
            print("For SODA dialogue dataset found {} faults within the total {} dialogs".format(faulty, len(self)))

    def _download(self, cache_dir, path):
        """Fetch the gzipped dataset into ``cache_dir`` and decompress it to ``path``."""
        # Imported lazily so these modules are only needed on a cache miss.
        import gzip
        import shutil

        import gdown

        os.makedirs(cache_dir, exist_ok=True)
        archive = os.path.join(cache_dir, "soda_dialog.jsonl.gz")
        gdown.download(self.url, output=archive)

        # Decompress to a temporary name and rename afterwards, so that an
        # interrupted run never leaves a truncated file at ``path`` that a
        # later run would mistake for a complete cache.
        partial = path + ".part"
        with gzip.open(archive, "rb") as f_in, open(partial, "wb") as f_out:
            shutil.copyfileobj(f_in, f_out)
        os.replace(partial, path)

    def __len__(self):
        """Return the number of stored dialogues."""
        return len(self.pairs)

    def __getitem__(self, index):
        """Return the flat ``(question, answer, ...)`` tuple at ``index``."""
        return self.pairs[index]
class JokeExplaination(Dataset):
""" """
+5 -3
View File
@@ -3,10 +3,12 @@ bitsandbytes==0.36.0.post2
datasets==2.8.0
deepspeed==0.7.7
evaluate==0.4.0
gdown
mpi4py==3.1.4
nltk==3.8.1
numpy==1.23.0
PyYAML==6.0
numpy>=1.22.4
py7zr
PyYAML>=6.0
scikit_learn==1.2.0
torch==1.13.1
torch>=1.11.0
transformers==4.25.1