diff --git a/scripts/run_sft.py b/scripts/run_sft.py index a38accd..522ce86 100644 --- a/scripts/run_sft.py +++ b/scripts/run_sft.py @@ -122,7 +122,7 @@ def main(): model = model_args.model_name_or_path # For ChatML we need to add special tokens and resize the embedding layer - if "<|im_start|>" in tokenizer.chat_template: + if "<|im_start|>" in tokenizer.chat_template and "gemma-tokenizer-chatml" not in tokenizer.name_or_path: model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs) model, tokenizer = setup_chat_format(model, tokenizer) model_kwargs = None