mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
[feature] added translation, rallio instruct tuning dataset, prosocial for safety, new summary dataset
This commit is contained in:
@@ -2,7 +2,7 @@ model_name: microsoft/deberta-v3-base
|
||||
learning_rate: 1e-5
|
||||
scheduler: cosine
|
||||
gradient_checkpointing: false
|
||||
gradient_accumulation_steps: 32
|
||||
gradient_accumulation_steps: 16
|
||||
per_device_train_batch_size: 2
|
||||
warmup_steps: 600
|
||||
eval_steps: 200
|
||||
|
||||
@@ -29,6 +29,14 @@ defaults:
|
||||
- soda
|
||||
- joke
|
||||
- gsm8k
|
||||
- dive_mt
|
||||
- wmt2019_zh-en
|
||||
- wmt2019_ru-en
|
||||
- wmt2019_de-en
|
||||
- ted_trans_nl-en
|
||||
- ted_trans_de-ja
|
||||
- instruct_tuning
|
||||
- wmt2019_de-en
|
||||
- samsum
|
||||
cache_dir: .cache
|
||||
loss_fn: CrossEntropyLoss
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
# Dataset collections overview:
|
||||
|
||||
currently dataset can be divided into 3 classes
|
||||
|
||||
- language knowledge
|
||||
|
||||
- summarization
|
||||
|
||||
- translation
|
||||
|
||||
- dialogue : don't let user know you are a robot
|
||||
|
||||
- STEM : knowledge about the world
|
||||
|
||||
- coding
|
||||
|
||||
- world knowledge <= ideally we want to handle this via prefix context
|
||||
|
||||
Issues and TODO:
|
||||
|
||||
* as dataset are growing, how can we update this section less
|
||||
|
||||
* ideally we can update the config yaml and new dataset will be download from hub
|
||||
|
||||
* one possible idea is we upload the trasform format of these dataset to the OA hub
|
||||
|
||||
@@ -1,11 +1,26 @@
|
||||
from custom_datasets.prompt_dialogue import PromptGeneratedDataset
|
||||
"""
|
||||
High level functions for model training
|
||||
"""
|
||||
from custom_datasets.prompt_dialogue import InstructionTuning, PromptGeneratedDataset
|
||||
from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, WebGPT
|
||||
from custom_datasets.summarization import SummarizationDataset
|
||||
from custom_datasets.toxic_conversation import ProsocialDialogue, ProsocialDialogueExplaination
|
||||
from custom_datasets.translation import WMT2019, DiveMT, TEDTalk
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import Subset
|
||||
|
||||
QA_DATASETS = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_nocontext", "gsm8k"]
|
||||
SUMMARIZATION_DATASETS = ["xsum", "cnn_dailymail", "samsum", "multi_news", "scitldr", "billsum"]
|
||||
SUMMARIZATION_DATASETS = [
|
||||
"xsum",
|
||||
"cnn_dailymail",
|
||||
"samsum",
|
||||
"multi_news",
|
||||
"scitldr",
|
||||
"billsum",
|
||||
"debate_sum",
|
||||
"tldr_news",
|
||||
]
|
||||
OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning"]
|
||||
|
||||
|
||||
def train_val_dataset(dataset, val_split=0.2):
|
||||
@@ -25,20 +40,43 @@ def get_one_dataset(conf, dataset_name):
|
||||
|
||||
elif dataset_name in SUMMARIZATION_DATASETS:
|
||||
train = SummarizationDataset(dataset_name, conf.cache_dir, "train")
|
||||
val_name = "validation" if dataset_name not in ["billsum"] else "test"
|
||||
eval = SummarizationDataset(dataset_name, conf.cache_dir, val_name)
|
||||
if dataset_name == "debate_sum":
|
||||
train, eval = train_val_dataset(train, val_split=0.2)
|
||||
else:
|
||||
val_name = "validation" if dataset_name not in ["billsum"] else "test"
|
||||
eval = SummarizationDataset(dataset_name, conf.cache_dir, val_name)
|
||||
elif "ted_trans" in dataset_name:
|
||||
language_pair = dataset_name.split("_")[-1]
|
||||
dataset = TEDTalk(pair=language_pair, split="train")
|
||||
train, eval = train_val_dataset(dataset, val_split=0.2)
|
||||
elif "wmt2019" in dataset_name:
|
||||
language_pair = dataset_name.split("_")[-1]
|
||||
train = WMT2019(pair=language_pair, split="train")
|
||||
eval = WMT2019(pair=language_pair, split="validation")
|
||||
elif dataset_name == "dive_mt":
|
||||
dataset = DiveMT()
|
||||
train, eval = train_val_dataset(dataset, val_split=0.2)
|
||||
elif dataset_name == "webgpt":
|
||||
dataset = WebGPT()
|
||||
train, eval = train_val_dataset(dataset, val_split=0.2)
|
||||
elif dataset_name == "prompt_dialogue":
|
||||
dataset = PromptGeneratedDataset(conf.cache_dir)
|
||||
train, eval = train_val_dataset(dataset, val_split=0.2)
|
||||
elif dataset_name == "prosocial_dialogue":
|
||||
train = ProsocialDialogue(cache_dir=conf.cache_dir, split="train")
|
||||
eval = ProsocialDialogue(cache_dir=conf.cache_dir, split="validation")
|
||||
elif dataset_name == "explain_prosocial":
|
||||
train = ProsocialDialogueExplaination(cache_dir=conf.cache_dir, split="train")
|
||||
eval = ProsocialDialogueExplaination(cache_dir=conf.cache_dir, split="validation")
|
||||
elif dataset_name == "soda":
|
||||
dataset = SODA(conf.cache_dir)
|
||||
train, eval = train_val_dataset(dataset, val_split=0.1)
|
||||
elif dataset_name == "joke":
|
||||
dataset = JokeExplaination(conf.cache_dir)
|
||||
train, eval = train_val_dataset(dataset, val_split=0.2)
|
||||
elif dataset_name == "instruct_tuning":
|
||||
dataset = InstructionTuning(conf.cache_dir)
|
||||
train, eval = train_val_dataset(dataset, val_split=0.2)
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset {dataset_name}")
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ class DialogueDataCollator:
|
||||
|
||||
for feature_one in features:
|
||||
assert len(feature_one) % 2 == 0, "Number of messages must be even"
|
||||
# TODO: we should push this to dataset __getitem__
|
||||
messages = [
|
||||
(QA_SPECIAL_TOKENS["Question"] if i % 2 == 0 else "")
|
||||
+ x
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import os
|
||||
from urllib.request import urlopen
|
||||
|
||||
@@ -14,6 +15,7 @@ class PromptGeneratedDataset(Dataset):
|
||||
we are ignoring results with multiple lines for now
|
||||
"""
|
||||
|
||||
name = "prompt_dialogue"
|
||||
url = "https://github.com/Rallio67/language-model-agents/raw/main/chat_dialogue_v2_c.txt"
|
||||
|
||||
def __init__(self, cache_dir) -> None:
|
||||
@@ -49,3 +51,55 @@ class PromptGeneratedDataset(Dataset):
|
||||
def __getitem__(self, index):
|
||||
question, answer = self.pairs[index]
|
||||
return question, answer
|
||||
|
||||
|
||||
class InstructionTuning(Dataset):
|
||||
"""
|
||||
We have seen some promising capabilities from instruction tuning
|
||||
with the following mix of datasets that are derived from datasets
|
||||
available online.
|
||||
The files for this data are in json format as a list of tuples
|
||||
where each tuple is (source,instruction_response_pair)
|
||||
|
||||
- instruction_tuning_dataset_alpha_part1.json
|
||||
- instruction_tuning_dataset_alpha_part2.json
|
||||
|
||||
Not to be confused with unatural instruction
|
||||
"""
|
||||
|
||||
name = "instruction_dataset"
|
||||
url_part_2 = (
|
||||
"https://github.com/Rallio67/language-model-agents/raw/main/instruction_tuning_dataset_alpha_part2.json"
|
||||
)
|
||||
url_part_1 = (
|
||||
"https://github.com/Rallio67/language-model-agents/raw/main/instruction_tuning_dataset_alpha_part1.json"
|
||||
)
|
||||
|
||||
def __init__(self, cache_dir) -> None:
|
||||
super().__init__()
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
self.pairs = []
|
||||
for file_link in [self.url_part_1, self.url_part_2]:
|
||||
basename = file_link.split("/")[-1]
|
||||
instruction_tune_file = os.path.join(cache_dir, basename)
|
||||
if not os.path.exists(instruction_tune_file):
|
||||
with urlopen(file_link) as file:
|
||||
content = file.read().decode()
|
||||
with open(instruction_tune_file, "w", encoding="utf-8") as fout:
|
||||
fout.write(content)
|
||||
|
||||
with open(instruction_tune_file, "r", encoding="utf-8") as f:
|
||||
datasets = json.load(f)
|
||||
for row in datasets:
|
||||
_, response_pair = row
|
||||
question, answer = response_pair.split("\n\n", maxsplit=1)
|
||||
answer = answer.replace("<|endoftext|>", "").strip()
|
||||
self.pairs.append((question, answer))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.pairs)
|
||||
|
||||
def __getitem__(self, index):
|
||||
question, answer = self.pairs[index]
|
||||
return question, answer
|
||||
|
||||
@@ -1,11 +1,18 @@
|
||||
"""
|
||||
Open / close book QA datasets
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from urllib.request import urlopen
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
# @agoryuno contributed this
|
||||
re_reference_remove = re.compile(r"\[\d+(?:,\s*\d+)*?\]")
|
||||
|
||||
QA_SPECIAL_TOKENS = {"Question": "<human>", "Answer": "<bot>", "StartPrefix": "<prefix>", "EndPrefix": "</prefix>"}
|
||||
|
||||
|
||||
@@ -75,6 +82,9 @@ class QADataset(Dataset):
|
||||
|
||||
|
||||
class WebGPT(Dataset):
|
||||
|
||||
name = "webgpt"
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -89,7 +99,9 @@ class WebGPT(Dataset):
|
||||
self.index2question[len(self.index2question)] = question
|
||||
|
||||
# only keep the best answer
|
||||
questions[question] = row["answer_0" if row["score_0"] > row["score_1"] else "answer_1"]
|
||||
questions[question] = re_reference_remove.sub(
|
||||
"", row["answer_0" if row["score_0"] > row["score_1"] else "answer_1"]
|
||||
)
|
||||
|
||||
self.questions = questions
|
||||
|
||||
@@ -103,6 +115,9 @@ class WebGPT(Dataset):
|
||||
|
||||
|
||||
class SODA(Dataset):
|
||||
|
||||
name = "soda"
|
||||
|
||||
def process_soda_convo(self, data):
|
||||
pairs = []
|
||||
play_as = data["speakers"][1]
|
||||
@@ -149,8 +164,8 @@ class SODA(Dataset):
|
||||
|
||||
|
||||
class JokeExplaination(Dataset):
|
||||
""" """
|
||||
|
||||
name = "joke"
|
||||
url = "https://gist.github.com/theblackcat102/42b697e24a13fdb499e20edfbf618361/raw/1834dca207898c15f93b809d1195f6f6e47c9e1e/joke_explained.jsonl"
|
||||
|
||||
def __init__(self, cache_dir) -> None:
|
||||
@@ -182,3 +197,6 @@ class JokeExplaination(Dataset):
|
||||
def __getitem__(self, index):
|
||||
question, answer = self.pairs[index]
|
||||
return question, answer
|
||||
|
||||
|
||||
# https://huggingface.co/datasets/aquamuse
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
"""
|
||||
Summarize different spectrum of documents
|
||||
"""
|
||||
import random
|
||||
|
||||
from datasets import load_dataset
|
||||
@@ -12,13 +15,21 @@ SUMMARY_SPECIAL_PROMPT = {
|
||||
}
|
||||
|
||||
summarization_config_mapping = {
|
||||
"cnn_dailymail": ("3.0.0",),
|
||||
"samsum": (),
|
||||
"xsum": (),
|
||||
"multi_news": (),
|
||||
"scitldr": ("AIC",),
|
||||
"billsum": (),
|
||||
"reddit": (),
|
||||
"cnn_dailymail": (
|
||||
"cnn_dailymail",
|
||||
"3.0.0",
|
||||
),
|
||||
"samsum": ("samsum",),
|
||||
"xsum": ("xsum",),
|
||||
"multi_news": ("multi_news",),
|
||||
"scitldr": (
|
||||
"scitldr",
|
||||
"AIC",
|
||||
),
|
||||
"billsum": ("billsum",),
|
||||
"reddit": ("reddit",),
|
||||
"tldr_news": ("JulesBelveze/tldr_news",), # need to fix : JulesBelveze/tldr_news
|
||||
"debate_sum": ("Hellisotherpeople/DebateSum",), # Hellisotherpeople/DebateSum
|
||||
}
|
||||
|
||||
summarization_name_mapping = {
|
||||
@@ -29,6 +40,8 @@ summarization_name_mapping = {
|
||||
"scitldr": ("source", "target"),
|
||||
"billsum": ("text", "summary"),
|
||||
"reddit": ("content", "summary"),
|
||||
"tldr_news": ("content", "headline"),
|
||||
"debate_sum": ("Full-Document", "Extract"),
|
||||
}
|
||||
|
||||
|
||||
@@ -43,7 +56,7 @@ def index_summary_merge(text, summary):
|
||||
class SummarizationDataset(Dataset):
|
||||
def __init__(self, dataset, cache_dir, split):
|
||||
self.name = dataset
|
||||
self.dataset = load_dataset(dataset, *summarization_config_mapping[dataset], cache_dir=cache_dir, split=split)
|
||||
self.dataset = load_dataset(*summarization_config_mapping[dataset], cache_dir=cache_dir, split=split)
|
||||
self.text_column, self.summary_column = summarization_name_mapping[dataset]
|
||||
self.preprocess_fn = index_summary_merge if dataset == "scitldr" else index_summary_default
|
||||
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
'''
|
||||
SFT dataset to reject toxic questions
|
||||
|
||||
'''
|
||||
import random
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
class ProsocialDialogueExplaination(Dataset):
|
||||
name = "prosocial_explain"
|
||||
TEMPLATE = [
|
||||
# 0 : reply or sentence of interest, 1 : reason of caution
|
||||
("'{0}' Why is this sentence {1}", "This sentence is {1} because {0}"),
|
||||
("Explain to me why this sentence is {1}: {0}", "This sentence is {1} because {0}"),
|
||||
("'{0}' Why is this sentence {1}", "This sentence may need {1} because {0}"),
|
||||
("'{0}' Explain to me why we need to {1} of this sentence", "We need to be {1} of this sentence because {0}"),
|
||||
]
|
||||
def __init__(self, split='train', cache_dir='.cache') -> None:
|
||||
super().__init__()
|
||||
dataset = load_dataset("allenai/prosocial-dialog", cache_dir=cache_dir)[split]
|
||||
self.pairs = []
|
||||
for row in dataset:
|
||||
for safety_annotation, safe_answer in zip(row['safety_annotations'], row['safety_annotation_reasons']):
|
||||
(prompt_template, answer_template) = random.choice(self.TEMPLATE)
|
||||
self.pairs.append((
|
||||
prompt_template.format(row['context'],safety_annotation),
|
||||
answer_template.format( safe_answer, safety_annotation)
|
||||
))
|
||||
def __len__(self):
|
||||
return len(self.pairs)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.pairs[idx]
|
||||
|
||||
class ProsocialDialogue(Dataset):
|
||||
name = "prosocial_dialogue"
|
||||
'''
|
||||
ProsocialDialog, we set up a human-AI collaborative data creation framework,
|
||||
where GPT-3 generates the potentially unsafe utterances, and crowdworkers
|
||||
provide prosocial responses to them. This approach allows us to circumvent
|
||||
two substantial challenges:
|
||||
(1) there are no available large-scale corpora of multiturn prosocial conversations
|
||||
between humans
|
||||
(2) asking humans to write unethical, toxic, or problematic utterances could result
|
||||
in psychological harms (Roberts, 2017; Steiger et al., 2021).
|
||||
'''
|
||||
PREFIX = "<prefix>You are now a prosocial chatbot, be caution and casual when reply</prefix>"
|
||||
|
||||
|
||||
def __init__(self, split='train', cache_dir='.cache') -> None:
|
||||
super().__init__()
|
||||
dataset = load_dataset("allenai/prosocial-dialog", cache_dir=cache_dir)[split]
|
||||
self.pairs = []
|
||||
for row in dataset:
|
||||
for answer in row['rots']:
|
||||
self.pairs.append((
|
||||
self.PREFIX+row['context'],
|
||||
answer
|
||||
))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.pairs)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.pairs[idx]
|
||||
@@ -0,0 +1,157 @@
|
||||
'''
|
||||
List of translation dataset
|
||||
|
||||
GroNLP/divemt
|
||||
|
||||
fill in the blanks : https://huggingface.co/datasets/m_lama
|
||||
|
||||
'''
|
||||
import random
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
# postfix prompt
|
||||
TRANSLATION_PROMPT = {
|
||||
"zh": [ # simplified or any chinese which was not mentioned
|
||||
"Translate to chinese simplified: {}",
|
||||
"{}, translate to chinese",
|
||||
"{} give me the chinese translation",
|
||||
"翻译成中文: {}",
|
||||
"{} 这句中文翻译怎麽写?",
|
||||
"我需要这句话的中文翻译: {}",
|
||||
],
|
||||
"zh-tw": [ # WMT code
|
||||
"{}. Translate to chinese traditional",
|
||||
"{}, translate to chinese",
|
||||
"{}. get chinese translation",
|
||||
"中文翻譯: {}",
|
||||
"幫我翻譯成中文: '{}'",
|
||||
"{} 這句中文翻譯怎麼寫?",
|
||||
],
|
||||
"ja": [
|
||||
"{}: help me translate to japanese",
|
||||
"Need japanese translation: {}",
|
||||
"{}: にほんごやくをよこす",
|
||||
"{}: にほんごやくをおくれ",
|
||||
"{}: にほんごやくを じょす",
|
||||
"give me the japanese translation, {}",
|
||||
],
|
||||
"de": [
|
||||
"{}: translate to german",
|
||||
"give me the german translation {}",
|
||||
"I want german translation {}",
|
||||
"{}, ins Deutsche übersetzen",
|
||||
"{}, Übersetzen ins Deutsche",
|
||||
],
|
||||
"fr": [
|
||||
"{}. translate to french",
|
||||
"{} write in french",
|
||||
"{} french translation",
|
||||
"{} ,donnez moi la traduction française"],
|
||||
"ko": [
|
||||
"{}. translate to Korean",
|
||||
"how do we write in korean: {}",
|
||||
"give me the korean translation: {}",
|
||||
"{}, 한국어 번역을 해주세요",
|
||||
],
|
||||
"ms": [
|
||||
"{} translate to malay",
|
||||
"{} how do we write in Malay",
|
||||
"{} give me the malay translation",
|
||||
"{} , berikan saya terjemahan dalam bahasa melayu",
|
||||
"{}, Jemahan di bahasa melayu"
|
||||
"{}, jemahkan ayat ini kepada bahasa melayu"
|
||||
],
|
||||
"en": ["{}. translate to english", "{} write in english", "english translation: '{}'"],
|
||||
"tr": ["{}. translate to turkish", "{} write in turkish", "turkish translation: '{}'"],
|
||||
"it": ["{}. translate to italian", "{} write in italian", "italian translation: '{}'"],
|
||||
"nl": ["{}. translate to dutch", "{} write in dutch", "dutch translation: '{}'"],
|
||||
"vi": ["{}. translate to vietnamese", "{} write in vietnamese", "vietnamese translation: '{}'"],
|
||||
"ar": ["{}. translate to arabic", "{} write in arabic", "arabic translation: '{}'"],
|
||||
}
|
||||
class TranslationPair(Dataset):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.pairs = []
|
||||
|
||||
def __len__(self):
|
||||
return len(self.pairs)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.pairs[index]
|
||||
|
||||
|
||||
class WMT2019(TranslationPair):
|
||||
|
||||
def __init__(self, pair='zh-en', split='train') -> None:
|
||||
super().__init__()
|
||||
dataset = load_dataset('wmt19', pair)[split]
|
||||
self.pairs = []
|
||||
src, tgt = pair.split('-')
|
||||
for row in dataset:
|
||||
row = row['translation']
|
||||
if random.random() > 0.5:
|
||||
source = random.choice(
|
||||
TRANSLATION_PROMPT[tgt]
|
||||
).format(row[src])
|
||||
self.pairs.append((source, row[tgt]))
|
||||
else:# translating in reverse direction
|
||||
source = random.choice(
|
||||
TRANSLATION_PROMPT[src]
|
||||
).format(row[tgt])
|
||||
self.pairs.append((source, row[src]))
|
||||
|
||||
class DiveMT(TranslationPair):
|
||||
|
||||
REMAP = {
|
||||
'tur': 'tr',
|
||||
'ita': 'it',
|
||||
'ukr': 'uk',
|
||||
'nld': 'nl',
|
||||
'vie': 'vi',
|
||||
'ara': 'ar'
|
||||
}
|
||||
|
||||
def __init__(self, split='train') -> None:
|
||||
super().__init__()
|
||||
dataset = load_dataset('GroNLP/divemt', 'main')[split]
|
||||
tgt, src = 'tgt_text', 'src_text'
|
||||
for row in dataset:
|
||||
# ISO 639-2
|
||||
lang_code_2 = row['subject_id'].split('_')[0]
|
||||
lang_code = self.REMAP[lang_code_2]
|
||||
if lang_code not in TRANSLATION_PROMPT:
|
||||
continue
|
||||
|
||||
if random.random() > 0.5:
|
||||
source = random.choice(
|
||||
TRANSLATION_PROMPT[lang_code]
|
||||
).format(row[src])
|
||||
self.pairs.append((source, row[tgt]))
|
||||
else:# translating in reverse direction
|
||||
lang_code = 'en'
|
||||
source = random.choice(
|
||||
TRANSLATION_PROMPT[lang_code]
|
||||
).format(row[tgt])
|
||||
self.pairs.append((source, row[src]))
|
||||
|
||||
|
||||
class TEDTalk(TranslationPair):
|
||||
# NOTE: DO NOT use chinese pair, mix with traditional and cantonese, not clean
|
||||
|
||||
def __init__(self, pair='de-ja', split='train', year='2016') -> None:
|
||||
super().__init__()
|
||||
dataset = load_dataset('ted_talks_iwslt', language_pair=pair.split('-'), year=year)[split]
|
||||
src, tgt = pair.split('-')
|
||||
for row in dataset:
|
||||
row = row['translation']
|
||||
if random.random() > 0.5:
|
||||
source = random.choice(
|
||||
TRANSLATION_PROMPT[tgt]
|
||||
).format(row[src])
|
||||
self.pairs.append((source, row[tgt]))
|
||||
else:# translating in reverse direction
|
||||
source = random.choice(
|
||||
TRANSLATION_PROMPT[src]
|
||||
).format(row[tgt])
|
||||
self.pairs.append((source, row[src]))
|
||||
@@ -7,10 +7,11 @@ from custom_datasets.dialogue_collator import DialogueDataCollator
|
||||
def test_all_datasets():
|
||||
qa_base = QA_DATASETS
|
||||
summarize_base = SUMMARIZATION_DATASETS
|
||||
others = ["prompt_dialogue", "webgpt", "soda", "joke"]
|
||||
others = ["prompt_dialogue", "webgpt", "soda", "joke", "instruct_tuning"]
|
||||
translation = ["dive_mt", "wmt2019_zh-en", "wmt2019_ru-en", "wmt2019_de-en", "ted_trans_de-ja", "ted_trans_nl-en"]
|
||||
|
||||
config = Namespace(cache_dir=".cache")
|
||||
for dataset_name in others + qa_base + summarize_base:
|
||||
for dataset_name in translation:
|
||||
print(dataset_name)
|
||||
train, eval = get_one_dataset(config, dataset_name)
|
||||
# sanity check
|
||||
@@ -51,4 +52,4 @@ def test_collate_fn():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_collate_fn()
|
||||
test_all_datasets()
|
||||
|
||||
Reference in New Issue
Block a user