mirror of
https://github.com/wassname/alignment-handbook.git
synced 2026-06-27 17:14:25 +08:00
chat template fix
This commit is contained in:
@@ -2,12 +2,21 @@ I'm using this to train some simple base -> SFT models for my work
|
||||
|
||||
|
||||
```sh
|
||||
uv sync --no-build-isolation-package flash-attn
|
||||
# took me ~30mins
|
||||
MAX_JOBS=10 pip install flash-attn --no-build-isolation
|
||||
# get an H100 or V100 from your favorite cloud provider
|
||||
|
||||
# create virtualenv
|
||||
uv sync --verbose
|
||||
# took me ~30mins due to building flash-attn
|
||||
|
||||
|
||||
# MAX_JOBS=10 pip install flash-attn --no-build-isolation
|
||||
#uv add --no-build-isolation-package flash-attn
|
||||
. ./.venv/bin/activate
|
||||
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_sft.py recipes/fromSimPO/Qwen3-0.6B_fourchan.yaml
|
||||
# note: set --num_processes to the number of GPUs you have (1 here)
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml --num_processes=1 scripts/run_sft.py recipes/fromSimPO/Qwen3-0.6B_fourchan.yaml
|
||||
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_sft.py recipes/fromSimPO/Qwen3-0.6B_fourchan.yaml
|
||||
```
|
||||
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ dependencies = [
|
||||
"setuptools>=80.9.0",
|
||||
"hatchling>=1.27.0",
|
||||
"editables>=0.5",
|
||||
"flash-attn>=2.7.4.post1",
|
||||
]
|
||||
|
||||
|
||||
@@ -54,3 +55,5 @@ build-backend = "hatchling.build"
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/alignment"]
|
||||
|
||||
[tool.uv.pip]
|
||||
no-build-isolation-package = ["flash-attn"]
|
||||
|
||||
@@ -6,7 +6,7 @@ attn_implementation: flash_attention_2
|
||||
# use_flash_attention_2: true
|
||||
|
||||
# Data training arguments
|
||||
tokenizer_name_or_path: Qwen/Qwen3-0.6B # Custom tokenizer with <|im_start|> and <|im_end|> tokens
|
||||
#tokenizer_name_or_path: Qwen/Qwen3-0.6B # Custom tokenizer with <|im_start|> and <|im_end|> tokens
|
||||
# chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
|
||||
|
||||
dataset_mixer:
|
||||
|
||||
@@ -124,6 +124,7 @@ def main():
|
||||
# For ChatML we need to add special tokens and resize the embedding layer
|
||||
if "<|im_start|>" in tokenizer.chat_template and "gemma-tokenizer-chatml" not in tokenizer.name_or_path:
|
||||
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)
|
||||
tokenizer.chat_template = None # Not quite sure why I have to do this, but if we end up with chatml for all models, and it's saved in the output tokeniser, that's fine
|
||||
model, tokenizer = setup_chat_format(model, tokenizer)
|
||||
model_kwargs = None
|
||||
|
||||
|
||||
@@ -155,6 +155,7 @@ dependencies = [
|
||||
{ name = "einops" },
|
||||
{ name = "evaluate" },
|
||||
{ name = "flake8" },
|
||||
{ name = "flash-attn" },
|
||||
{ name = "hatchling" },
|
||||
{ name = "hf-doc-builder" },
|
||||
{ name = "hf-transfer" },
|
||||
@@ -197,6 +198,7 @@ requires-dist = [
|
||||
{ name = "einops", specifier = ">=0.6.1" },
|
||||
{ name = "evaluate", specifier = "==0.4.0" },
|
||||
{ name = "flake8", specifier = ">=6.0.0" },
|
||||
{ name = "flash-attn", specifier = ">=2.7.4.post1" },
|
||||
{ name = "hatchling", specifier = ">=1.27.0" },
|
||||
{ name = "hf-doc-builder", specifier = ">=0.4.0" },
|
||||
{ name = "hf-transfer", specifier = ">=0.1.4" },
|
||||
@@ -661,6 +663,16 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/83/5c/0627be4c9976d56b1217cb5187b7504e7fd7d3503f8bfd312a04077bd4f7/flake8-7.2.0-py2.py3-none-any.whl", hash = "sha256:93b92ba5bdb60754a6da14fa3b93a9361fd00a59632ada61fd7b130436c40343", size = 57786, upload-time = "2025-03-29T20:08:37.902Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "flash-attn"
|
||||
version = "2.7.4.post1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "einops" },
|
||||
{ name = "torch" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/11/34/9bf60e736ed7bbe15055ac2dab48ec67d9dbd088d2b4ae318fd77190ab4e/flash_attn-2.7.4.post1.tar.gz", hash = "sha256:f03485c9a49a4d68d0733acdcad80ab0e72afa025a777fdc2966ceccf9d51765", size = 5986610, upload-time = "2025-01-30T06:39:51.93Z" }
|
||||
|
||||
[[package]]
|
||||
name = "frozenlist"
|
||||
version = "1.6.0"
|
||||
|
||||
Reference in New Issue
Block a user