diff --git a/src/alignment/model_utils.py b/src/alignment/model_utils.py index ab2025f..abb930c 100644 --- a/src/alignment/model_utils.py +++ b/src/alignment/model_utils.py @@ -73,7 +73,7 @@ def get_tokenizer(model_args: ModelArguments, data_args: DataArguments) -> PreTr if data_args.chat_template is not None: tokenizer.chat_template = data_args.chat_template - elif tokenizer.chat_template is None: + elif tokenizer.chat_template is None and tokenizer.default_chat_template is None: tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE return tokenizer diff --git a/tests/test_model_utils.py b/tests/test_model_utils.py index d20afec..688b3bd 100644 --- a/tests/test_model_utils.py +++ b/tests/test_model_utils.py @@ -15,6 +15,7 @@ import unittest import torch +from transformers import AutoTokenizer from alignment import DataArguments, ModelArguments, get_peft_config, get_quantization_config, get_tokenizer from alignment.data import DEFAULT_CHAT_TEMPLATE @@ -56,6 +57,19 @@ class GetTokenizerTest(unittest.TestCase): tokenizer = get_tokenizer(self.model_args, DataArguments()) self.assertEqual(tokenizer.chat_template, DEFAULT_CHAT_TEMPLATE) + def test_default_chat_template_no_overwrite(self): + """ + If no chat template is passed explicitly in the config, then for models with a + `default_chat_template` but no `chat_template` we do not set a `chat_template`, + and that we do not change `default_chat_template` + """ + model_args = ModelArguments(model_name_or_path="codellama/CodeLlama-7b-Instruct-hf") + base_tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-Instruct-hf") + processed_tokenizer = get_tokenizer(model_args, DataArguments()) + + assert getattr(processed_tokenizer, "chat_template") is None + self.assertEqual(base_tokenizer.default_chat_template, processed_tokenizer.default_chat_template) + def test_chatml_chat_template(self): chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" tokenizer = get_tokenizer(self.model_args, DataArguments(chat_template=chat_template))