diff --git a/README.md b/README.md index 3c8c457..6226ed0 100644 --- a/README.md +++ b/README.md @@ -2,12 +2,21 @@ I'm using this to train some simple base -> SFT models for my work ```sh -uv sync --no-build-isolation-package flash-attn -# took me ~30mins -MAX_JOBS=10 pip install flash-attn --no-build-isolation +# get a H100 V100 from your favorite cloud provider + +# create virtualenv +uv sync --verbose +# took me ~30mins due to building flash-attn + + +# MAX_JOBS=10 pip install flash-attn --no-build-isolation +#uv add --no-build-isolation-package flash-attn . ./.venv/bin/activate -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_sft.py recipes/fromSimPO/Qwen3-0.6B_fourchan.yaml +# note --num_processes=1 related to the GPU's you have +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml --num_processes=1 scripts/run_sft.py recipes/fromSimPO/Qwen3-0.6B_fourchan.yaml + +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_sft.py recipes/fromSimPO/Qwen3-0.6B_fourchan.yaml ``` diff --git a/pyproject.toml b/pyproject.toml index 69e10a9..17d23e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ dependencies = [ "setuptools>=80.9.0", "hatchling>=1.27.0", "editables>=0.5", + "flash-attn>=2.7.4.post1", ] @@ -54,3 +55,5 @@ build-backend = "hatchling.build" [tool.hatch.build.targets.wheel] packages = ["src/alignment"] +[tool.uv.pip] +no-build-isolation-package = ["flash-attn"] diff --git a/recipes/fromSimPO/Qwen3-0.6B_fourchan.yaml b/recipes/fromSimPO/Qwen3-0.6B_fourchan.yaml index 7ae8197..7b2c6dc 100644 --- a/recipes/fromSimPO/Qwen3-0.6B_fourchan.yaml +++ b/recipes/fromSimPO/Qwen3-0.6B_fourchan.yaml @@ -6,7 +6,7 @@ attn_implementation: flash_attention_2 # use_flash_attention_2: true # Data training arguments -tokenizer_name_or_path: Qwen/Qwen3-0.6B # Custom tokenizer with <|im_start|> and <|im_end|> tokens +#tokenizer_name_or_path: Qwen/Qwen3-0.6B # Custom tokenizer with <|im_start|> and <|im_end|> tokens # chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" dataset_mixer: diff --git a/scripts/run_sft.py b/scripts/run_sft.py index 60a2dfd..318565a 100644 --- a/scripts/run_sft.py +++ b/scripts/run_sft.py @@ -124,6 +124,7 @@ def main(): # For ChatML we need to add special tokens and resize the embedding layer if "<|im_start|>" in tokenizer.chat_template and "gemma-tokenizer-chatml" not in tokenizer.name_or_path: model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs) + tokenizer.chat_template = None # Not quite sure why I have to do this, but if we end up with chatml for all models, and it's saved in the output tokeniser, that's fine model, tokenizer = setup_chat_format(model, tokenizer) model_kwargs = None diff --git a/uv.lock b/uv.lock index 0268816..13f8b25 100644 --- a/uv.lock +++ b/uv.lock @@ -155,6 +155,7 @@ dependencies = [ { name = "einops" }, { name = "evaluate" }, { name = "flake8" }, + { name = "flash-attn" }, { name = "hatchling" }, { name = "hf-doc-builder" }, { name = "hf-transfer" }, @@ -197,6 +198,7 @@ requires-dist = [ { name = "einops", specifier = ">=0.6.1" }, { name = "evaluate", specifier = "==0.4.0" }, { name = "flake8", specifier = ">=6.0.0" }, + { name = "flash-attn", specifier = ">=2.7.4.post1" }, { name = "hatchling", specifier = ">=1.27.0" }, { name = "hf-doc-builder", specifier = ">=0.4.0" }, { name = "hf-transfer", specifier = ">=0.1.4" }, @@ -661,6 +663,16 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/83/5c/0627be4c9976d56b1217cb5187b7504e7fd7d3503f8bfd312a04077bd4f7/flake8-7.2.0-py2.py3-none-any.whl", hash = "sha256:93b92ba5bdb60754a6da14fa3b93a9361fd00a59632ada61fd7b130436c40343", size = 57786, upload-time = "2025-03-29T20:08:37.902Z" }, ] +[[package]] +name = "flash-attn" +version = "2.7.4.post1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "einops" }, + { name = "torch" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/11/34/9bf60e736ed7bbe15055ac2dab48ec67d9dbd088d2b4ae318fd77190ab4e/flash_attn-2.7.4.post1.tar.gz", hash = "sha256:f03485c9a49a4d68d0733acdcad80ab0e72afa025a777fdc2966ceccf9d51765", size = 5986610, upload-time = "2025-01-30T06:39:51.93Z" } + [[package]] name = "frozenlist" version = "1.6.0"