Add model utils tests

This commit is contained in:
Lewis Tunstall
2023-11-10 09:42:15 +00:00
parent 0af8011993
commit 2ed5a45d25
5 changed files with 152 additions and 7 deletions
+2 -2
View File
@@ -17,7 +17,7 @@ import dataclasses
import os
import sys
from dataclasses import dataclass, field
from typing import Any, Dict, List, NewType, Optional, Tuple, Union
from typing import Any, Dict, List, NewType, Optional, Tuple
import transformers
from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, HfArgumentParser
@@ -87,7 +87,7 @@ class H4ArgumentParser(HfArgumentParser):
return outputs
def parse(self) -> Union[DataClassType, Tuple[DataClassType]]:
def parse(self) -> DataClassType | Tuple[DataClassType]:
if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
# If we pass only one argument to the script and it's the path to a YAML file,
# let's parse it to get our arguments.
+2 -2
View File
@@ -1,5 +1,5 @@
import re
from typing import List, Literal, Optional, Union
from typing import List, Literal, Optional
from datasets import DatasetDict, concatenate_datasets, load_dataset
@@ -67,7 +67,7 @@ def apply_chat_template(
def get_datasets(
data_config: Union[DataArguments, dict],
data_config: DataArguments | dict,
splits: List[str] = ["train", "test"],
shuffle: bool = True,
) -> DatasetDict:
+2 -2
View File
@@ -1,4 +1,4 @@
from typing import Dict, Union
from typing import Dict
import torch
from transformers import AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizer
@@ -62,7 +62,7 @@ def get_tokenizer(model_args: ModelArguments, data_args: DataArguments) -> PreTr
return tokenizer
def get_peft_config(model_args: ModelArguments) -> Union[PeftConfig, None]:
def get_peft_config(model_args: ModelArguments) -> PeftConfig | None:
if model_args.use_peft is False:
return None
+70 -1
View File
@@ -15,8 +15,9 @@
import unittest
import pytest
from datasets import Dataset
from alignment import DataArguments, get_datasets
from alignment import DataArguments, ModelArguments, apply_chat_template, get_datasets, get_tokenizer
class GetDatasetsTest(unittest.TestCase):
@@ -77,3 +78,71 @@ class GetDatasetsTest(unittest.TestCase):
datasets = get_datasets(dataset_mixer, splits=["test"])
self.assertEqual(len(datasets["test"]), 100)
self.assertRaises(KeyError, lambda: datasets["train"])
class ApplyChatTemplateTest(unittest.TestCase):
def setUp(self):
model_args = ModelArguments(model_name_or_path="HuggingFaceH4/zephyr-7b-alpha")
data_args = DataArguments()
self.tokenizer = get_tokenizer(model_args, data_args)
self.dataset = Dataset.from_dict(
{
"prompt": ["Hello!"],
"messages": [[{"role": "user", "content": "Hello!"}, {"role": "assistant", "content": "Bonjour!"}]],
"chosen": [[{"role": "user", "content": "Hello!"}, {"role": "assistant", "content": "Bonjour!"}]],
"rejected": [[{"role": "user", "content": "Hello!"}, {"role": "assistant", "content": "Hola!"}]],
}
)
def test_sft(self):
dataset = self.dataset.map(
apply_chat_template,
fn_kwargs={"tokenizer": self.tokenizer, "task": "sft"},
remove_columns=self.dataset.column_names,
)
self.assertDictEqual(
dataset[0],
{"text": "<|system|>\n</s>\n<|user|>\nHello!</s>\n<|assistant|>\nBonjour!</s>\n"},
)
def test_generation(self):
# Remove last turn from messages
dataset = self.dataset.map(lambda x: {"messages": x["messages"][:-1]})
dataset = dataset.map(
apply_chat_template,
fn_kwargs={"tokenizer": self.tokenizer, "task": "generation"},
remove_columns=self.dataset.column_names,
)
self.assertDictEqual(
dataset[0],
{"text": "<|system|>\n</s>\n<|user|>\nHello!</s>\n<|assistant|>\n"},
)
def test_rm(self):
dataset = self.dataset.map(
apply_chat_template,
fn_kwargs={"tokenizer": self.tokenizer, "task": "rm"},
remove_columns=self.dataset.column_names,
)
self.assertDictEqual(
dataset[0],
{
"text_chosen": "<|system|>\n</s>\n<|user|>\nHello!</s>\n<|assistant|>\nBonjour!</s>\n",
"text_rejected": "<|system|>\n</s>\n<|user|>\nHello!</s>\n<|assistant|>\nHola!</s>\n",
},
)
def test_dpo(self):
dataset = self.dataset.map(
apply_chat_template,
fn_kwargs={"tokenizer": self.tokenizer, "task": "dpo"},
remove_columns=self.dataset.column_names,
)
self.assertDictEqual(
dataset[0],
{
"text_prompt": "<|system|>\n</s>\n<|user|>\nHello!</s>\n<|assistant|>\n",
"text_chosen": "Bonjour!</s>\n",
"text_rejected": "Hola!</s>\n",
},
)
+76
View File
@@ -0,0 +1,76 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from alignment import DataArguments, ModelArguments, get_peft_config, get_quantization_config, get_tokenizer
from alignment.data import DEFAULT_CHAT_TEMPLATE
class GetQuantizationConfigTest(unittest.TestCase):
def test_4bit(self):
model_args = ModelArguments(load_in_4bit=True)
quantization_config = get_quantization_config(model_args)
self.assertTrue(quantization_config.load_in_4bit)
self.assertEqual(quantization_config.bnb_4bit_compute_dtype, torch.float16)
self.assertEqual(quantization_config.bnb_4bit_quant_type, "nf4")
self.assertFalse(quantization_config.bnb_4bit_use_double_quant)
def test_8bit(self):
model_args = ModelArguments(load_in_8bit=True)
quantization_config = get_quantization_config(model_args)
self.assertTrue(quantization_config.load_in_8bit)
def test_no_quantization(self):
model_args = ModelArguments()
quantization_config = get_quantization_config(model_args)
self.assertIsNone(quantization_config)
class GetTokenizerTest(unittest.TestCase):
def setUp(self) -> None:
self.model_args = ModelArguments(model_name_or_path="HuggingFaceH4/zephyr-7b-alpha")
def test_right_truncation_side(self):
tokenizer = get_tokenizer(self.model_args, DataArguments(truncation_side="right"))
self.assertEqual(tokenizer.truncation_side, "right")
def test_left_truncation_side(self):
tokenizer = get_tokenizer(self.model_args, DataArguments(truncation_side="left"))
self.assertEqual(tokenizer.truncation_side, "left")
def test_default_chat_template(self):
tokenizer = get_tokenizer(self.model_args, DataArguments())
self.assertEqual(tokenizer.chat_template, DEFAULT_CHAT_TEMPLATE)
def test_chatml_chat_template(self):
chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
tokenizer = get_tokenizer(self.model_args, DataArguments(chat_template=chat_template))
self.assertEqual(tokenizer.chat_template, chat_template)
class GetPeftConfigTest(unittest.TestCase):
def test_peft_config(self):
model_args = ModelArguments(use_peft=True, lora_r=42, lora_alpha=0.66, lora_dropout=0.99)
peft_config = get_peft_config(model_args)
self.assertEqual(peft_config.r, 42)
self.assertEqual(peft_config.lora_alpha, 0.66)
self.assertEqual(peft_config.lora_dropout, 0.99)
def test_no_peft_config(self):
model_args = ModelArguments(use_peft=False)
peft_config = get_peft_config(model_args)
self.assertIsNone(peft_config)