mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-28 16:20:34 +08:00
added soda dialogue dataset
This commit is contained in:
@@ -30,6 +30,7 @@ defaults:
|
||||
- joke
|
||||
- gsm8k
|
||||
- samsum
|
||||
- soda_dialogue
|
||||
cache_dir: .cache
|
||||
loss_fn: CrossEntropyLoss
|
||||
eval_size:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from custom_datasets.prompt_dialogue import PromptGeneratedDataset
|
||||
from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, WebGPT
|
||||
from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, SODADialogue, WebGPT
|
||||
from custom_datasets.summarization import SummarizationDataset
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import Subset
|
||||
@@ -36,6 +36,9 @@ def get_one_dataset(conf, dataset_name):
|
||||
elif dataset_name == "soda":
|
||||
dataset = SODA(conf.cache_dir)
|
||||
train, eval = train_val_dataset(dataset, val_split=0.1)
|
||||
elif dataset_name == "soda_dialogue":
|
||||
dataset = SODADialogue(conf.cache_dir)
|
||||
train, eval = train_val_dataset(dataset, val_split=0.1)
|
||||
elif dataset_name == "joke":
|
||||
dataset = JokeExplaination(conf.cache_dir)
|
||||
train, eval = train_val_dataset(dataset, val_split=0.2)
|
||||
|
||||
@@ -106,7 +106,12 @@ class SODA(Dataset):
|
||||
def process_soda_convo(self, data):
|
||||
pairs = []
|
||||
play_as = data["speakers"][1]
|
||||
prefix = "<prefix>{}. {}</prefix>".format(data["narrative"], "your name {}".format(play_as))
|
||||
prefix = "{}{}. {}{}".format(
|
||||
QA_SPECIAL_TOKENS["StartPrefix"],
|
||||
data["narrative"],
|
||||
"your name {}".format(play_as),
|
||||
QA_SPECIAL_TOKENS["EndPrefix"],
|
||||
)
|
||||
question, answer = "", ""
|
||||
prefix, postfix = "", ""
|
||||
previous_chat = []
|
||||
@@ -119,7 +124,9 @@ class SODA(Dataset):
|
||||
answer = convo
|
||||
postfix = data["speakers"][idx]
|
||||
if len(question) and len(answer) and prefix != postfix and postfix == play_as:
|
||||
history = "<sep>".join(["{}<bot>{}".format(*p) for p in previous_chat])
|
||||
history = "<sep>".join(
|
||||
["{}{}{}".format(p[0], QA_SPECIAL_TOKENS["Answer"], p[1]) for p in previous_chat]
|
||||
)
|
||||
if len(history):
|
||||
history += "<sep>"
|
||||
pairs.append((prefix + history + question, answer))
|
||||
@@ -148,6 +155,57 @@ class SODA(Dataset):
|
||||
return question, answer
|
||||
|
||||
|
||||
class SODADialogue(Dataset):
|
||||
url = "https://drive.google.com/uc?id=1TOGQfr419n8wpzJpYLLw4nB3tSKD8zXV"
|
||||
|
||||
def __init__(self, cache_dir, verbose=True):
|
||||
|
||||
path = os.path.join(cache_dir, "soda_dialog.jsonl")
|
||||
|
||||
if not os.path.exists(path):
|
||||
import gzip
|
||||
import shutil
|
||||
|
||||
import gdown
|
||||
|
||||
gdown.download(self.url, output=os.path.join(cache_dir, "soda_dialog.jsonl.gz"))
|
||||
|
||||
with gzip.open(os.path.join(cache_dir, "soda_dialog.jsonl.gz"), "rb") as f_in:
|
||||
with open(path, "wb") as f_out:
|
||||
shutil.copyfileobj(f_in, f_out)
|
||||
|
||||
self.pairs = []
|
||||
faulty = 0
|
||||
with open(path) as fin:
|
||||
for line in fin:
|
||||
conversation = json.loads(line)
|
||||
question_answer_pairs = ()
|
||||
|
||||
question_answers = conversation["text"].split("User: ")
|
||||
for question_answer in question_answers[1:]: # first element is empty
|
||||
try:
|
||||
question, answer = question_answer.split("\nAssistant: ")
|
||||
question_answer_pairs += (
|
||||
question,
|
||||
answer,
|
||||
)
|
||||
except ValueError:
|
||||
# there might be some extra 'User: ' or 'Assistant: ' tokens in the dataset that cause trouble..
|
||||
faulty += 1
|
||||
continue
|
||||
|
||||
self.pairs.append(question_answer_pairs)
|
||||
|
||||
if verbose:
|
||||
print("For SODA dialogue dataset found {} faults within the total {} dialogs".format(faulty, len(self)))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.pairs)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.pairs[index]
|
||||
|
||||
|
||||
class JokeExplaination(Dataset):
|
||||
""" """
|
||||
|
||||
|
||||
@@ -3,10 +3,12 @@ bitsandbytes==0.36.0.post2
|
||||
datasets==2.8.0
|
||||
deepspeed==0.7.7
|
||||
evaluate==0.4.0
|
||||
gdown
|
||||
mpi4py==3.1.4
|
||||
nltk==3.8.1
|
||||
numpy==1.23.0
|
||||
PyYAML==6.0
|
||||
numpy>=1.22.4
|
||||
py7zr
|
||||
PyYAML>=6.0
|
||||
scikit_learn==1.2.0
|
||||
torch==1.13.1
|
||||
torch>=1.11.0
|
||||
transformers==4.25.1
|
||||
|
||||
Reference in New Issue
Block a user