Files
Open-Assistant/model/supervised_finetuning/tests/test_datasets.py
T
2023-01-20 07:26:26 +00:00

52 lines
2.0 KiB
Python

from argparse import Namespace
from custom_datasets import QA_DATASETS, SUMMARIZATION_DATASETS, get_one_dataset
from custom_datasets.dialogue_collator import DialogueDataCollator
def test_all_datasets():
    """Smoke-test that every listed dataset loads and can be indexed.

    Each dataset name is passed to ``get_one_dataset``; up to the first
    ``max_samples`` items of both returned splits are then accessed by
    index. Nothing is asserted about the items themselves: the test fails
    only if loading a dataset or indexing into a split raises.
    """
    # Upper bound on items touched per split, so large datasets are not
    # walked in full.
    max_samples = 1000

    translation = [
        "dive_mt",
        "wmt2019_zh-en",
        "wmt2019_ru-en",
        "wmt2019_de-en",
        "ted_trans_de-ja",
        "ted_trans_nl-en",
    ]
    others = ["prompt_dialogue", "webgpt", "soda", "joke", "instruct_tuning"]
    config = Namespace(cache_dir=".cache")

    for dataset_name in translation + others + SUMMARIZATION_DATASETS + QA_DATASETS:
        # Printed so the failing dataset is identifiable in captured output.
        print(dataset_name)
        train_split, eval_split = get_one_dataset(config, dataset_name)

        # Sanity check: item access must not raise. The train split is
        # walked before the eval split.
        for split in (train_split, eval_split):
            for idx in range(min(len(split), max_samples)):
                _ = split[idx]
def test_collate_fn():
    """Check that collated batches never exceed the collator's ``max_length``.

    All train splits and all eval splits of the listed datasets are
    concatenated separately, batched through ``DialogueDataCollator``, and
    the second dimension of every ``batch["targets"]`` is asserted to be at
    most ``max_length``.
    """
    # Function-local imports: only this test needs them.
    from torch.utils.data import ConcatDataset, DataLoader
    from utils import get_tokenizer

    # A single source of truth for the limit: the same value configures the
    # collator and bounds the assertion, so the two cannot drift apart.
    max_length = 512
    batch_size = 128

    config = Namespace(cache_dir=".cache", model_name="Salesforce/codegen-2B-multi")
    tokenizer = get_tokenizer(config)
    collate_fn = DialogueDataCollator(tokenizer, max_length=max_length)

    others = ["prompt_dialogue", "webgpt", "soda", "joke", "gsm8k"]

    train_splits, eval_splits = [], []
    for dataset_name in others + QA_DATASETS + SUMMARIZATION_DATASETS:
        # Printed so the failing dataset is identifiable in captured output.
        print(dataset_name)
        train_split, eval_split = get_one_dataset(config, dataset_name)
        train_splits.append(train_split)
        eval_splits.append(eval_split)

    # Train splits are checked first, then eval splits.
    # Debugging hint: tokenizer.decode(batch["input_ids"][0]) shows the first
    # example of a batch as text.
    for splits in (train_splits, eval_splits):
        dataloader = DataLoader(
            ConcatDataset(splits),
            collate_fn=collate_fn,
            batch_size=batch_size,
        )
        for batch in dataloader:
            assert batch["targets"].shape[1] <= max_length