# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import pytest

from alignment import DataArguments, get_datasets


class GetDatasetsTest(unittest.TestCase):
    """Check the split sizes returned by ``get_datasets`` for several mixers.

    Each of the fixture datasets has 100 examples. The assertions below
    expect the "train" split to shrink with the requested fraction while the
    "test" split keeps all 100 examples of every dataset in the mixer.
    """

    # Fixture dataset names, in the order they are added to a mixer.
    _FIXTURES = (
        "HuggingFaceH4/testing_alpaca_small",
        "HuggingFaceH4/testing_self_instruct_small",
        "HuggingFaceH4/testing_codealpaca_small",
    )

    @classmethod
    def _mixer(cls, *fractions):
        """Return a new ``{dataset name: fraction}`` dict.

        The i-th fraction is paired with the i-th fixture name; passing fewer
        fractions than fixtures simply uses the leading fixtures. A fresh
        dict is built on every call so no test shares state with another.
        """
        return dict(zip(cls._FIXTURES, fractions))

    def _assert_split_sizes(self, splits, *, train, test):
        """Assert the lengths of the "train" and "test" splits, in that order."""
        self.assertEqual(len(splits["train"]), train)
        self.assertEqual(len(splits["test"]), test)

    def test_loading_data_args(self):
        # The mixer is wrapped in a DataArguments instance.
        arguments = DataArguments(dataset_mixer=self._mixer(0.5, 0.3, 0.2))
        loaded = get_datasets(arguments)
        self._assert_split_sizes(loaded, train=100, test=300)

    def test_loading_data_dict(self):
        # Same mixer as above, passed directly as a plain dict.
        loaded = get_datasets(self._mixer(0.5, 0.3, 0.2))
        self._assert_split_sizes(loaded, train=100, test=300)

    def test_loading_with_unit_fractions(self):
        # A fraction of 1.0 keeps every training example of each dataset.
        loaded = get_datasets(self._mixer(1.0, 1.0, 1.0))
        self._assert_split_sizes(loaded, train=300, test=300)

    def test_loading_with_fractions_greater_than_unity(self):
        # The fractions sum to 1.1; each is applied to its own dataset.
        loaded = get_datasets(self._mixer(0.7, 0.4))
        self._assert_split_sizes(loaded, train=70 + 40, test=200)

    def test_loading_fails_with_negative_fractions(self):
        mixer = self._mixer(0.7, -0.3)
        with pytest.raises(ValueError, match=r"Dataset fractions cannot be negative."):
            get_datasets(mixer)

    def test_loading_single_split_with_unit_fractions(self):
        # Only the "test" split is requested, so "train" must be absent.
        loaded = get_datasets(self._mixer(1.0), splits=["test"])
        self.assertEqual(len(loaded["test"]), 100)
        with self.assertRaises(KeyError):
            _ = loaded["train"]