Replace typing.Union annotations with PEP 604 unions and add unit tests.

* src/alignment/{configs,data,model_utils}.py: rewrite Union[X, Y] annotations
  as X | Y and drop Union from the typing imports.
* tests/test_data.py: add ApplyChatTemplateTest covering the "sft",
  "generation", "rm" and "dpo" tasks of apply_chat_template.
* tests/test_model_utils.py: new tests for get_quantization_config,
  get_tokenizer and get_peft_config.

NOTE(review): data.py and model_utils.py begin with a plain import on line 1,
so they cannot carry `from __future__ import annotations`; the new `X | Y`
annotations are therefore evaluated when the functions are defined and need
Python >= 3.10. Confirm this matches the package's supported versions.
NOTE(review): the new tests build a tokenizer for
"HuggingFaceH4/zephyr-7b-alpha" via get_tokenizer; presumably this needs Hub
access or a warm cache -- confirm for CI.

diff --git a/src/alignment/configs.py b/src/alignment/configs.py
index 890790e..d785e16 100644
--- a/src/alignment/configs.py
+++ b/src/alignment/configs.py
@@ -17,7 +17,7 @@ import dataclasses
 import os
 import sys
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, NewType, Optional, Tuple, Union
+from typing import Any, Dict, List, NewType, Optional, Tuple
 
 import transformers
 from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, HfArgumentParser
@@ -87,7 +87,7 @@ class H4ArgumentParser(HfArgumentParser):
 
         return outputs
 
-    def parse(self) -> Union[DataClassType, Tuple[DataClassType]]:
+    def parse(self) -> DataClassType | Tuple[DataClassType]:
         if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
             # If we pass only one argument to the script and it's the path to a YAML file,
             # let's parse it to get our arguments.
diff --git a/src/alignment/data.py b/src/alignment/data.py
index de25e61..2150095 100644
--- a/src/alignment/data.py
+++ b/src/alignment/data.py
@@ -1,5 +1,5 @@
 import re
-from typing import List, Literal, Optional, Union
+from typing import List, Literal, Optional
 
 from datasets import DatasetDict, concatenate_datasets, load_dataset
 
@@ -67,7 +67,7 @@ def apply_chat_template(
 
 
 def get_datasets(
-    data_config: Union[DataArguments, dict],
+    data_config: DataArguments | dict,
     splits: List[str] = ["train", "test"],
     shuffle: bool = True,
 ) -> DatasetDict:
diff --git a/src/alignment/model_utils.py b/src/alignment/model_utils.py
index a01f1b4..f237745 100644
--- a/src/alignment/model_utils.py
+++ b/src/alignment/model_utils.py
@@ -1,4 +1,4 @@
-from typing import Dict, Union
+from typing import Dict
 
 import torch
 from transformers import AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizer
@@ -62,7 +62,7 @@ def get_tokenizer(model_args: ModelArguments, data_args: DataArguments) -> PreTr
     return tokenizer
 
 
-def get_peft_config(model_args: ModelArguments) -> Union[PeftConfig, None]:
+def get_peft_config(model_args: ModelArguments) -> PeftConfig | None:
     if model_args.use_peft is False:
         return None
 
diff --git a/tests/test_data.py b/tests/test_data.py
index 63e390a..6a6b63c 100644
--- a/tests/test_data.py
+++ b/tests/test_data.py
@@ -15,8 +15,9 @@
 import unittest
 
 import pytest
+from datasets import Dataset
 
-from alignment import DataArguments, get_datasets
+from alignment import DataArguments, ModelArguments, apply_chat_template, get_datasets, get_tokenizer
 
 
 class GetDatasetsTest(unittest.TestCase):
@@ -77,3 +78,71 @@ class GetDatasetsTest(unittest.TestCase):
         datasets = get_datasets(dataset_mixer, splits=["test"])
         self.assertEqual(len(datasets["test"]), 100)
         self.assertRaises(KeyError, lambda: datasets["train"])
+
+
+class ApplyChatTemplateTest(unittest.TestCase):
+    def setUp(self):
+        model_args = ModelArguments(model_name_or_path="HuggingFaceH4/zephyr-7b-alpha")
+        data_args = DataArguments()
+        self.tokenizer = get_tokenizer(model_args, data_args)
+        self.dataset = Dataset.from_dict(
+            {
+                "prompt": ["Hello!"],
+                "messages": [[{"role": "user", "content": "Hello!"}, {"role": "assistant", "content": "Bonjour!"}]],
+                "chosen": [[{"role": "user", "content": "Hello!"}, {"role": "assistant", "content": "Bonjour!"}]],
+                "rejected": [[{"role": "user", "content": "Hello!"}, {"role": "assistant", "content": "Hola!"}]],
+            }
+        )
+
+    def test_sft(self):
+        dataset = self.dataset.map(
+            apply_chat_template,
+            fn_kwargs={"tokenizer": self.tokenizer, "task": "sft"},
+            remove_columns=self.dataset.column_names,
+        )
+        self.assertDictEqual(
+            dataset[0],
+            {"text": "<|system|>\n\n<|user|>\nHello!\n<|assistant|>\nBonjour!\n"},
+        )
+
+    def test_generation(self):
+        # Remove last turn from messages
+        dataset = self.dataset.map(lambda x: {"messages": x["messages"][:-1]})
+        dataset = dataset.map(
+            apply_chat_template,
+            fn_kwargs={"tokenizer": self.tokenizer, "task": "generation"},
+            remove_columns=self.dataset.column_names,
+        )
+        self.assertDictEqual(
+            dataset[0],
+            {"text": "<|system|>\n\n<|user|>\nHello!\n<|assistant|>\n"},
+        )
+
+    def test_rm(self):
+        dataset = self.dataset.map(
+            apply_chat_template,
+            fn_kwargs={"tokenizer": self.tokenizer, "task": "rm"},
+            remove_columns=self.dataset.column_names,
+        )
+        self.assertDictEqual(
+            dataset[0],
+            {
+                "text_chosen": "<|system|>\n\n<|user|>\nHello!\n<|assistant|>\nBonjour!\n",
+                "text_rejected": "<|system|>\n\n<|user|>\nHello!\n<|assistant|>\nHola!\n",
+            },
+        )
+
+    def test_dpo(self):
+        dataset = self.dataset.map(
+            apply_chat_template,
+            fn_kwargs={"tokenizer": self.tokenizer, "task": "dpo"},
+            remove_columns=self.dataset.column_names,
+        )
+        self.assertDictEqual(
+            dataset[0],
+            {
+                "text_prompt": "<|system|>\n\n<|user|>\nHello!\n<|assistant|>\n",
+                "text_chosen": "Bonjour!\n",
+                "text_rejected": "Hola!\n",
+            },
+        )
diff --git a/tests/test_model_utils.py b/tests/test_model_utils.py
new file mode 100644
index 0000000..d20afec
--- /dev/null
+++ b/tests/test_model_utils.py
@@ -0,0 +1,76 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest
+
+import torch
+
+from alignment import DataArguments, ModelArguments, get_peft_config, get_quantization_config, get_tokenizer
+from alignment.data import DEFAULT_CHAT_TEMPLATE
+
+
+class GetQuantizationConfigTest(unittest.TestCase):
+    def test_4bit(self):
+        model_args = ModelArguments(load_in_4bit=True)
+        quantization_config = get_quantization_config(model_args)
+        self.assertTrue(quantization_config.load_in_4bit)
+        self.assertEqual(quantization_config.bnb_4bit_compute_dtype, torch.float16)
+        self.assertEqual(quantization_config.bnb_4bit_quant_type, "nf4")
+        self.assertFalse(quantization_config.bnb_4bit_use_double_quant)
+
+    def test_8bit(self):
+        model_args = ModelArguments(load_in_8bit=True)
+        quantization_config = get_quantization_config(model_args)
+        self.assertTrue(quantization_config.load_in_8bit)
+
+    def test_no_quantization(self):
+        model_args = ModelArguments()
+        quantization_config = get_quantization_config(model_args)
+        self.assertIsNone(quantization_config)
+
+
+class GetTokenizerTest(unittest.TestCase):
+    def setUp(self) -> None:
+        self.model_args = ModelArguments(model_name_or_path="HuggingFaceH4/zephyr-7b-alpha")
+
+    def test_right_truncation_side(self):
+        tokenizer = get_tokenizer(self.model_args, DataArguments(truncation_side="right"))
+        self.assertEqual(tokenizer.truncation_side, "right")
+
+    def test_left_truncation_side(self):
+        tokenizer = get_tokenizer(self.model_args, DataArguments(truncation_side="left"))
+        self.assertEqual(tokenizer.truncation_side, "left")
+
+    def test_default_chat_template(self):
+        tokenizer = get_tokenizer(self.model_args, DataArguments())
+        self.assertEqual(tokenizer.chat_template, DEFAULT_CHAT_TEMPLATE)
+
+    def test_chatml_chat_template(self):
+        chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
+        tokenizer = get_tokenizer(self.model_args, DataArguments(chat_template=chat_template))
+        self.assertEqual(tokenizer.chat_template, chat_template)
+
+
+class GetPeftConfigTest(unittest.TestCase):
+    def test_peft_config(self):
+        model_args = ModelArguments(use_peft=True, lora_r=42, lora_alpha=0.66, lora_dropout=0.99)
+        peft_config = get_peft_config(model_args)
+        self.assertEqual(peft_config.r, 42)
+        self.assertEqual(peft_config.lora_alpha, 0.66)
+        self.assertEqual(peft_config.lora_dropout, 0.99)
+
+    def test_no_peft_config(self):
+        model_args = ModelArguments(use_peft=False)
+        peft_config = get_peft_config(model_args)
+        self.assertIsNone(peft_config)