chat template fix

This commit is contained in:
wassname
2025-06-02 07:27:46 +00:00
parent 2819dd46d0
commit 880d4eda1e
5 changed files with 30 additions and 5 deletions
+13 -4
View File
@@ -2,12 +2,21 @@ I'm using this to train some simple base -> SFT models for my work
```sh
uv sync --no-build-isolation-package flash-attn
# took me ~30mins
MAX_JOBS=10 pip install flash-attn --no-build-isolation
# get a H100 V100 from your favorite cloud provider
# create virtualenv
uv sync --verbose
# took me ~30mins due to building flash-attn
# MAX_JOBS=10 pip install flash-attn --no-build-isolation
#uv add --no-build-isolation-package flash-attn
. ./.venv/bin/activate
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_sft.py recipes/fromSimPO/Qwen3-0.6B_fourchan.yaml
# note --num_processes=1 related to the GPU's you have
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml --num_processes=1 scripts/run_sft.py recipes/fromSimPO/Qwen3-0.6B_fourchan.yaml
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_sft.py recipes/fromSimPO/Qwen3-0.6B_fourchan.yaml
```
+3
View File
@@ -38,6 +38,7 @@ dependencies = [
"setuptools>=80.9.0",
"hatchling>=1.27.0",
"editables>=0.5",
"flash-attn>=2.7.4.post1",
]
@@ -54,3 +55,5 @@ build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["src/alignment"]
[tool.uv.pip]
no-build-isolation-package = ["flash-attn"]
+1 -1
View File
@@ -6,7 +6,7 @@ attn_implementation: flash_attention_2
# use_flash_attention_2: true
# Data training arguments
tokenizer_name_or_path: Qwen/Qwen3-0.6B # Custom tokenizer with <|im_start|> and <|im_end|> tokens
#tokenizer_name_or_path: Qwen/Qwen3-0.6B # Custom tokenizer with <|im_start|> and <|im_end|> tokens
# chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
dataset_mixer:
+1
View File
@@ -124,6 +124,7 @@ def main():
# For ChatML we need to add special tokens and resize the embedding layer
if "<|im_start|>" in tokenizer.chat_template and "gemma-tokenizer-chatml" not in tokenizer.name_or_path:
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)
tokenizer.chat_template = None # Not quite sure why I have to do this, but if we end up with chatml for all models, and it's saved in the output tokeniser, that's fine
model, tokenizer = setup_chat_format(model, tokenizer)
model_kwargs = None
Generated
+12
View File
@@ -155,6 +155,7 @@ dependencies = [
{ name = "einops" },
{ name = "evaluate" },
{ name = "flake8" },
{ name = "flash-attn" },
{ name = "hatchling" },
{ name = "hf-doc-builder" },
{ name = "hf-transfer" },
@@ -197,6 +198,7 @@ requires-dist = [
{ name = "einops", specifier = ">=0.6.1" },
{ name = "evaluate", specifier = "==0.4.0" },
{ name = "flake8", specifier = ">=6.0.0" },
{ name = "flash-attn", specifier = ">=2.7.4.post1" },
{ name = "hatchling", specifier = ">=1.27.0" },
{ name = "hf-doc-builder", specifier = ">=0.4.0" },
{ name = "hf-transfer", specifier = ">=0.1.4" },
@@ -661,6 +663,16 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/83/5c/0627be4c9976d56b1217cb5187b7504e7fd7d3503f8bfd312a04077bd4f7/flake8-7.2.0-py2.py3-none-any.whl", hash = "sha256:93b92ba5bdb60754a6da14fa3b93a9361fd00a59632ada61fd7b130436c40343", size = 57786, upload-time = "2025-03-29T20:08:37.902Z" },
]
[[package]]
name = "flash-attn"
version = "2.7.4.post1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "einops" },
{ name = "torch" },
]
sdist = { url = "https://files.pythonhosted.org/packages/11/34/9bf60e736ed7bbe15055ac2dab48ec67d9dbd088d2b4ae318fd77190ab4e/flash_attn-2.7.4.post1.tar.gz", hash = "sha256:f03485c9a49a4d68d0733acdcad80ab0e72afa025a777fdc2966ceccf9d51765", size = 5986610, upload-time = "2025-01-30T06:39:51.93Z" }
[[package]]
name = "frozenlist"
version = "1.6.0"