[fix] wait it pass?

This commit is contained in:
theblackcat102
2023-01-20 07:26:26 +00:00
parent 22e3ab1a89
commit aca3e9de89
@@ -11,7 +11,7 @@ def test_all_datasets():
translation = ["dive_mt", "wmt2019_zh-en", "wmt2019_ru-en", "wmt2019_de-en", "ted_trans_de-ja", "ted_trans_nl-en"]
config = Namespace(cache_dir=".cache")
for dataset_name in translation:
for dataset_name in translation + others + summarize_base + qa_base:
print(dataset_name)
train, eval = get_one_dataset(config, dataset_name)
# sanity check
@@ -49,7 +49,3 @@ def test_collate_fn():
dataloader = DataLoader(ConcatDataset(evals), collate_fn=collate_fn, batch_size=128)
for batch in dataloader:
assert batch["targets"].shape[1] <= 512
if __name__ == "__main__":
test_all_datasets()