From c69ae4b8a5436885b72475914b92dfb341876bf0 Mon Sep 17 00:00:00 2001 From: Nathan Azrak <42650258+nathan-az@users.noreply.github.com> Date: Mon, 8 Jan 2024 17:54:23 +1100 Subject: [PATCH] Check that `default_chat_template` is also None (#83) * Check that `default_chat_template` is also None before overwriting chat template * add unit test to `get_tokenizer` to ensure default behaviour of chat template is not changed --------- Co-authored-by: lewtun --- src/alignment/model_utils.py | 2 +- tests/test_model_utils.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/alignment/model_utils.py b/src/alignment/model_utils.py index ab2025f..abb930c 100644 --- a/src/alignment/model_utils.py +++ b/src/alignment/model_utils.py @@ -73,7 +73,7 @@ def get_tokenizer(model_args: ModelArguments, data_args: DataArguments) -> PreTr if data_args.chat_template is not None: tokenizer.chat_template = data_args.chat_template - elif tokenizer.chat_template is None: + elif tokenizer.chat_template is None and tokenizer.default_chat_template is None: tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE return tokenizer diff --git a/tests/test_model_utils.py b/tests/test_model_utils.py index d20afec..688b3bd 100644 --- a/tests/test_model_utils.py +++ b/tests/test_model_utils.py @@ -15,6 +15,7 @@ import unittest import torch +from transformers import AutoTokenizer from alignment import DataArguments, ModelArguments, get_peft_config, get_quantization_config, get_tokenizer from alignment.data import DEFAULT_CHAT_TEMPLATE @@ -56,6 +57,19 @@ class GetTokenizerTest(unittest.TestCase): tokenizer = get_tokenizer(self.model_args, DataArguments()) self.assertEqual(tokenizer.chat_template, DEFAULT_CHAT_TEMPLATE) + def test_default_chat_template_no_overwrite(self): + """ + If no chat template is passed explicitly in the config, then for models with a + `default_chat_template` but no `chat_template` we do not set a `chat_template`, + and that we do not change `default_chat_template` + """ + model_args = ModelArguments(model_name_or_path="codellama/CodeLlama-7b-Instruct-hf") + base_tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-Instruct-hf") + processed_tokenizer = get_tokenizer(model_args, DataArguments()) + + assert getattr(processed_tokenizer, "chat_template") is None + self.assertEqual(base_tokenizer.default_chat_template, processed_tokenizer.default_chat_template) + def test_chatml_chat_template(self): chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" tokenizer = get_tokenizer(self.model_args, DataArguments(chat_template=chat_template))