Check that default_chat_template is also None (#83)

* Check that `default_chat_template` is also None before overwriting chat template

* add unit test to `get_tokenizer` to ensure default behaviour of chat template is not changed

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This commit is contained in:
Nathan Azrak
2024-01-08 17:54:23 +11:00
committed by GitHub
parent 98fe28fb14
commit c69ae4b8a5
2 changed files with 15 additions and 1 deletions
+1 -1
View File
@@ -73,7 +73,7 @@ def get_tokenizer(model_args: ModelArguments, data_args: DataArguments) -> PreTr
if data_args.chat_template is not None:
tokenizer.chat_template = data_args.chat_template
elif tokenizer.chat_template is None:
elif tokenizer.chat_template is None and tokenizer.default_chat_template is None:
tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
return tokenizer
+14
View File
@@ -15,6 +15,7 @@
import unittest
import torch
from transformers import AutoTokenizer
from alignment import DataArguments, ModelArguments, get_peft_config, get_quantization_config, get_tokenizer
from alignment.data import DEFAULT_CHAT_TEMPLATE
@@ -56,6 +57,19 @@ class GetTokenizerTest(unittest.TestCase):
tokenizer = get_tokenizer(self.model_args, DataArguments())
self.assertEqual(tokenizer.chat_template, DEFAULT_CHAT_TEMPLATE)
def test_default_chat_template_no_overwrite(self):
"""
If no chat template is passed explicitly in the config, then for models with a
`default_chat_template` but no `chat_template` we do not set a `chat_template`,
and that we do not change `default_chat_template`
"""
model_args = ModelArguments(model_name_or_path="codellama/CodeLlama-7b-Instruct-hf")
base_tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-Instruct-hf")
processed_tokenizer = get_tokenizer(model_args, DataArguments())
assert getattr(processed_tokenizer, "chat_template") is None
self.assertEqual(base_tokenizer.default_chat_template, processed_tokenizer.default_chat_template)
def test_chatml_chat_template(self):
chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
tokenizer = get_tokenizer(self.model_args, DataArguments(chat_template=chat_template))