mirror of
https://github.com/wassname/alignment-handbook.git
synced 2026-06-27 18:22:17 +08:00
Add model utils tests
This commit is contained in:
@@ -17,7 +17,7 @@ import dataclasses
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, NewType, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, NewType, Optional, Tuple
|
||||
|
||||
import transformers
|
||||
from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, HfArgumentParser
|
||||
@@ -87,7 +87,7 @@ class H4ArgumentParser(HfArgumentParser):
|
||||
|
||||
return outputs
|
||||
|
||||
def parse(self) -> Union[DataClassType, Tuple[DataClassType]]:
|
||||
def parse(self) -> DataClassType | Tuple[DataClassType]:
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
|
||||
# If we pass only one argument to the script and it's the path to a YAML file,
|
||||
# let's parse it to get our arguments.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import re
|
||||
from typing import List, Literal, Optional, Union
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from datasets import DatasetDict, concatenate_datasets, load_dataset
|
||||
|
||||
@@ -67,7 +67,7 @@ def apply_chat_template(
|
||||
|
||||
|
||||
def get_datasets(
|
||||
data_config: Union[DataArguments, dict],
|
||||
data_config: DataArguments | dict,
|
||||
splits: List[str] = ["train", "test"],
|
||||
shuffle: bool = True,
|
||||
) -> DatasetDict:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Dict, Union
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizer
|
||||
@@ -62,7 +62,7 @@ def get_tokenizer(model_args: ModelArguments, data_args: DataArguments) -> PreTr
|
||||
return tokenizer
|
||||
|
||||
|
||||
def get_peft_config(model_args: ModelArguments) -> Union[PeftConfig, None]:
|
||||
def get_peft_config(model_args: ModelArguments) -> PeftConfig | None:
|
||||
if model_args.use_peft is False:
|
||||
return None
|
||||
|
||||
|
||||
+70
-1
@@ -15,8 +15,9 @@
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from datasets import Dataset
|
||||
|
||||
from alignment import DataArguments, get_datasets
|
||||
from alignment import DataArguments, ModelArguments, apply_chat_template, get_datasets, get_tokenizer
|
||||
|
||||
|
||||
class GetDatasetsTest(unittest.TestCase):
|
||||
@@ -77,3 +78,71 @@ class GetDatasetsTest(unittest.TestCase):
|
||||
datasets = get_datasets(dataset_mixer, splits=["test"])
|
||||
self.assertEqual(len(datasets["test"]), 100)
|
||||
self.assertRaises(KeyError, lambda: datasets["train"])
|
||||
|
||||
|
||||
class ApplyChatTemplateTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
model_args = ModelArguments(model_name_or_path="HuggingFaceH4/zephyr-7b-alpha")
|
||||
data_args = DataArguments()
|
||||
self.tokenizer = get_tokenizer(model_args, data_args)
|
||||
self.dataset = Dataset.from_dict(
|
||||
{
|
||||
"prompt": ["Hello!"],
|
||||
"messages": [[{"role": "user", "content": "Hello!"}, {"role": "assistant", "content": "Bonjour!"}]],
|
||||
"chosen": [[{"role": "user", "content": "Hello!"}, {"role": "assistant", "content": "Bonjour!"}]],
|
||||
"rejected": [[{"role": "user", "content": "Hello!"}, {"role": "assistant", "content": "Hola!"}]],
|
||||
}
|
||||
)
|
||||
|
||||
def test_sft(self):
|
||||
dataset = self.dataset.map(
|
||||
apply_chat_template,
|
||||
fn_kwargs={"tokenizer": self.tokenizer, "task": "sft"},
|
||||
remove_columns=self.dataset.column_names,
|
||||
)
|
||||
self.assertDictEqual(
|
||||
dataset[0],
|
||||
{"text": "<|system|>\n</s>\n<|user|>\nHello!</s>\n<|assistant|>\nBonjour!</s>\n"},
|
||||
)
|
||||
|
||||
def test_generation(self):
|
||||
# Remove last turn from messages
|
||||
dataset = self.dataset.map(lambda x: {"messages": x["messages"][:-1]})
|
||||
dataset = dataset.map(
|
||||
apply_chat_template,
|
||||
fn_kwargs={"tokenizer": self.tokenizer, "task": "generation"},
|
||||
remove_columns=self.dataset.column_names,
|
||||
)
|
||||
self.assertDictEqual(
|
||||
dataset[0],
|
||||
{"text": "<|system|>\n</s>\n<|user|>\nHello!</s>\n<|assistant|>\n"},
|
||||
)
|
||||
|
||||
def test_rm(self):
|
||||
dataset = self.dataset.map(
|
||||
apply_chat_template,
|
||||
fn_kwargs={"tokenizer": self.tokenizer, "task": "rm"},
|
||||
remove_columns=self.dataset.column_names,
|
||||
)
|
||||
self.assertDictEqual(
|
||||
dataset[0],
|
||||
{
|
||||
"text_chosen": "<|system|>\n</s>\n<|user|>\nHello!</s>\n<|assistant|>\nBonjour!</s>\n",
|
||||
"text_rejected": "<|system|>\n</s>\n<|user|>\nHello!</s>\n<|assistant|>\nHola!</s>\n",
|
||||
},
|
||||
)
|
||||
|
||||
def test_dpo(self):
|
||||
dataset = self.dataset.map(
|
||||
apply_chat_template,
|
||||
fn_kwargs={"tokenizer": self.tokenizer, "task": "dpo"},
|
||||
remove_columns=self.dataset.column_names,
|
||||
)
|
||||
self.assertDictEqual(
|
||||
dataset[0],
|
||||
{
|
||||
"text_prompt": "<|system|>\n</s>\n<|user|>\nHello!</s>\n<|assistant|>\n",
|
||||
"text_chosen": "Bonjour!</s>\n",
|
||||
"text_rejected": "Hola!</s>\n",
|
||||
},
|
||||
)
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from alignment import DataArguments, ModelArguments, get_peft_config, get_quantization_config, get_tokenizer
|
||||
from alignment.data import DEFAULT_CHAT_TEMPLATE
|
||||
|
||||
|
||||
class GetQuantizationConfigTest(unittest.TestCase):
|
||||
def test_4bit(self):
|
||||
model_args = ModelArguments(load_in_4bit=True)
|
||||
quantization_config = get_quantization_config(model_args)
|
||||
self.assertTrue(quantization_config.load_in_4bit)
|
||||
self.assertEqual(quantization_config.bnb_4bit_compute_dtype, torch.float16)
|
||||
self.assertEqual(quantization_config.bnb_4bit_quant_type, "nf4")
|
||||
self.assertFalse(quantization_config.bnb_4bit_use_double_quant)
|
||||
|
||||
def test_8bit(self):
|
||||
model_args = ModelArguments(load_in_8bit=True)
|
||||
quantization_config = get_quantization_config(model_args)
|
||||
self.assertTrue(quantization_config.load_in_8bit)
|
||||
|
||||
def test_no_quantization(self):
|
||||
model_args = ModelArguments()
|
||||
quantization_config = get_quantization_config(model_args)
|
||||
self.assertIsNone(quantization_config)
|
||||
|
||||
|
||||
class GetTokenizerTest(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.model_args = ModelArguments(model_name_or_path="HuggingFaceH4/zephyr-7b-alpha")
|
||||
|
||||
def test_right_truncation_side(self):
|
||||
tokenizer = get_tokenizer(self.model_args, DataArguments(truncation_side="right"))
|
||||
self.assertEqual(tokenizer.truncation_side, "right")
|
||||
|
||||
def test_left_truncation_side(self):
|
||||
tokenizer = get_tokenizer(self.model_args, DataArguments(truncation_side="left"))
|
||||
self.assertEqual(tokenizer.truncation_side, "left")
|
||||
|
||||
def test_default_chat_template(self):
|
||||
tokenizer = get_tokenizer(self.model_args, DataArguments())
|
||||
self.assertEqual(tokenizer.chat_template, DEFAULT_CHAT_TEMPLATE)
|
||||
|
||||
def test_chatml_chat_template(self):
|
||||
chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
|
||||
tokenizer = get_tokenizer(self.model_args, DataArguments(chat_template=chat_template))
|
||||
self.assertEqual(tokenizer.chat_template, chat_template)
|
||||
|
||||
|
||||
class GetPeftConfigTest(unittest.TestCase):
|
||||
def test_peft_config(self):
|
||||
model_args = ModelArguments(use_peft=True, lora_r=42, lora_alpha=0.66, lora_dropout=0.99)
|
||||
peft_config = get_peft_config(model_args)
|
||||
self.assertEqual(peft_config.r, 42)
|
||||
self.assertEqual(peft_config.lora_alpha, 0.66)
|
||||
self.assertEqual(peft_config.lora_dropout, 0.99)
|
||||
|
||||
def test_no_peft_config(self):
|
||||
model_args = ModelArguments(use_peft=False)
|
||||
peft_config = get_peft_config(model_args)
|
||||
self.assertIsNone(peft_config)
|
||||
Reference in New Issue
Block a user