diff --git a/model/supervised_finetuning/tests/test_datasets.py b/model/supervised_finetuning/tests/test_datasets.py index 2ac43613..3b59f289 100644 --- a/model/supervised_finetuning/tests/test_datasets.py +++ b/model/supervised_finetuning/tests/test_datasets.py @@ -11,7 +11,7 @@ def test_all_datasets(): translation = ["dive_mt", "wmt2019_zh-en", "wmt2019_ru-en", "wmt2019_de-en", "ted_trans_de-ja", "ted_trans_nl-en"] config = Namespace(cache_dir=".cache") - for dataset_name in translation: + for dataset_name in translation + others + summarize_base + qa_base: print(dataset_name) train, eval = get_one_dataset(config, dataset_name) # sanity check @@ -49,7 +49,3 @@ def test_collate_fn(): dataloader = DataLoader(ConcatDataset(evals), collate_fn=collate_fn, batch_size=128) for batch in dataloader: assert batch["targets"].shape[1] <= 512 - - -if __name__ == "__main__": - test_all_datasets()