# Mirror of https://github.com/wassname/alignment-handbook.git (synced 2026-06-27 16:14:07 +08:00; 80 lines, 3.1 KiB, Python)
# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
|
|
|
|
import pytest
|
|
|
|
from alignment import DataArguments, get_datasets
|
|
|
|
|
|
class GetDatasetsTest(unittest.TestCase):
    """Tests for ``get_datasets`` dataset mixing and split selection.

    Each of these test datasets has 100 examples, so every expected length
    below is derived from that size. As the assertions encode, a mixer
    fraction ``f`` contributes ``f * 100`` rows to the ``train`` split, while
    the ``test`` split holds all 100 rows of every dataset in the mixer.
    """

    @staticmethod
    def _fractional_mixer():
        """Return a fresh three-dataset mixer whose fractions sum to 1.0.

        A new dict is built on every call so no test can observe changes made
        to a shared object by another test (or by ``get_datasets`` itself).
        """
        return {
            "HuggingFaceH4/testing_alpaca_small": 0.5,
            "HuggingFaceH4/testing_self_instruct_small": 0.3,
            "HuggingFaceH4/testing_codealpaca_small": 0.2,
        }

    def test_loading_data_args(self):
        # The mixer is supplied wrapped in a DataArguments config object.
        data_args = DataArguments(dataset_mixer=self._fractional_mixer())
        datasets = get_datasets(data_args)
        self.assertEqual(len(datasets["train"]), 100)  # 50 + 30 + 20
        self.assertEqual(len(datasets["test"]), 300)  # 3 datasets x 100

    def test_loading_data_dict(self):
        # The same mixer passed as a plain dict must give the same counts.
        datasets = get_datasets(self._fractional_mixer())
        self.assertEqual(len(datasets["train"]), 100)  # 50 + 30 + 20
        self.assertEqual(len(datasets["test"]), 300)  # 3 datasets x 100

    def test_loading_with_unit_fractions(self):
        dataset_mixer = {
            "HuggingFaceH4/testing_alpaca_small": 1.0,
            "HuggingFaceH4/testing_self_instruct_small": 1.0,
            "HuggingFaceH4/testing_codealpaca_small": 1.0,
        }
        datasets = get_datasets(dataset_mixer)
        self.assertEqual(len(datasets["train"]), 300)  # 3 x 100, nothing dropped
        self.assertEqual(len(datasets["test"]), 300)  # 3 datasets x 100

    def test_loading_with_fractions_greater_than_unity(self):
        # Fractions are not required to sum to 1.0; here they sum to 1.1.
        dataset_mixer = {
            "HuggingFaceH4/testing_alpaca_small": 0.7,
            "HuggingFaceH4/testing_self_instruct_small": 0.4,
        }
        datasets = get_datasets(dataset_mixer)
        self.assertEqual(len(datasets["train"]), 70 + 40)
        self.assertEqual(len(datasets["test"]), 200)  # 2 datasets x 100

    def test_loading_fails_with_negative_fractions(self):
        dataset_mixer = {
            "HuggingFaceH4/testing_alpaca_small": 0.7,
            "HuggingFaceH4/testing_self_instruct_small": -0.3,
        }
        # assertRaisesRegex matches with re.search, like pytest.raises(match=...).
        # The trailing period is escaped so it only matches a literal ".".
        with self.assertRaisesRegex(ValueError, r"Dataset fractions cannot be negative\."):
            get_datasets(dataset_mixer)

    def test_loading_single_split_with_unit_fractions(self):
        dataset_mixer = {
            "HuggingFaceH4/testing_alpaca_small": 1.0,
        }
        datasets = get_datasets(dataset_mixer, splits=["test"])
        self.assertEqual(len(datasets["test"]), 100)
        # Only the requested split is returned, so "train" must be absent.
        with self.assertRaises(KeyError):
            _ = datasets["train"]
|