diff --git a/.gitignore b/.gitignore
index 1445d93..1905196 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
+.venv/
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
@@ -161,4 +163,4 @@ cython_debug/
 
 # Temp folders
 data/
-wandb/
\ No newline at end of file
+wandb/
diff --git a/justfile b/justfile
new file mode 100644
--- /dev/null
+++ b/justfile
@@ -0,0 +1,34 @@
+# Run supervised fine-tuning with the Llama-3.2-1B base SFT config
+# (launched through accelerate with the DeepSpeed ZeRO-3 config file).
+sft:
+    ACCELERATE_LOG_LEVEL=info accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml scripts/run_simpo.py training_configs/llama-3-2-1b-base-sft.yaml
+
+# Bootstrap a conda env named `handbook` with PyTorch and alignment-handbook.
+# Shebang recipe: just runs each plain recipe line in a separate shell, which
+# would drop the effect of `conda activate` and `cd` before the next line.
+install:
+    #!/usr/bin/env bash
+    set -eo pipefail
+
+    # install miniforge
+    wget https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh
+    bash Miniforge3-Linux-x86_64.sh -b -p /workspace/miniforge3
+    /workspace/miniforge3/bin/conda init zsh
+    # load conda's shell functions into this bash script (the rc file written
+    # by `conda init zsh` is meant for interactive zsh sessions, not for bash)
+    eval "$(/workspace/miniforge3/bin/conda shell.bash hook)"
+
+    # git clone https://github.com/princeton-nlp/SimPO.git
+
+    # make env (-y: do not block on the confirmation prompt)
+    conda create -y -n handbook python=3.10
+    conda activate handbook
+    # install torch
+    /workspace/miniforge3/bin/mamba install -y -n handbook pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia
+
+    # install alignment handbook
+    git clone https://github.com/huggingface/alignment-handbook.git
+    cd ./alignment-handbook/
+    python -m pip install .
+    python -m pip install flash-attn --no-build-isolation
+    cd ..
diff --git a/poetry.toml b/poetry.toml
new file mode 100644
--- /dev/null
+++ b/poetry.toml
@@ -0,0 +1,4 @@
+# Poetry-local settings. Poetry reads these from poetry.toml, not from
+# pyproject.toml. Keeps the virtualenv in ./.venv, which .gitignore excludes.
+[virtualenvs]
+in-project = true
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,50 @@
+[tool.poetry]
+name = "simpo"
+version = "0.1.0"
+description = ""
+authors = ["wassname <1103714+wassname@users.noreply.github.com>"]
+readme = "README.md"
+
+[build-system]
+requires = ["poetry-core"]
+build-backend = "poetry.core.masonry.api"
+
+# PyTorch CUDA 12.1 wheel index. With priority "explicit", Poetry consults it
+# only for dependencies that name it via `source = "pytorch"`.
+[[tool.poetry.source]]
+name = "pytorch"
+url = "https://download.pytorch.org/whl/cu121"
+priority = "explicit"
+
+[tool.poetry.dependencies]
+python = "^3.10"
+torch = { version = ">=2.1.2+cu121", source = "pytorch" }
+accelerate = ">=0.29.2"
+bitsandbytes = ">=0.43.0"
+einops = ">=0.6.1"
+evaluate = "0.4.0"
+datasets = ">=2.18.0"
+deepspeed = ">=0.14.4"
+hf_transfer = ">=0.1.4"
+huggingface-hub = ">=0.19.2,<1.0"
+jinja2 = ">=3.0.0"
+ninja = ">=1.11.1"
+numpy = ">=1.24.2"
+packaging = ">=23.0"
+peft = ">=0.9.0"
+protobuf = "<=3.20.2"
+safetensors = ">=0.3.3"
+sentencepiece = ">=0.1.99"
+scipy = "*"
+tensorboard = "*"
+tqdm = ">=4.64.1"
+transformers = ">=4.39.3"
+trl = ">=0.9.6"
+
+[tool.poetry.group.dev.dependencies]
+pytest = "*"
+parameterized = ">=0.9.0"
+black = ">=24.4.2"
+isort = ">=5.12.0"
+flake8 = ">=6.0.0"
+hf-doc-builder = ">=0.4.0"
diff --git a/training_configs/llama-3-2-1b-base-sft.yaml b/training_configs/llama-3-2-1b-base-sft.yaml
new file mode 100644
--- /dev/null
+++ b/training_configs/llama-3-2-1b-base-sft.yaml
@@ -0,0 +1,50 @@
+# Supervised fine-tuning of NousResearch/Llama-3.2-1B on HuggingFaceH4/ultrachat_200k.
+# Model arguments
+model_name_or_path: NousResearch/Llama-3.2-1B
+model_revision: main
+torch_dtype: bfloat16
+attn_implementation: flash_attention_2
+
+# Data training arguments
+chat_template: "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"
+dataset_mixer:
+  HuggingFaceH4/ultrachat_200k: 1.0
+dataset_splits:
+- train_sft
+- test_sft
+preprocessing_num_workers: 12
+
+# SFT trainer config
+bf16: true
+do_eval: true
+evaluation_strategy: steps
+eval_steps: 200
+gradient_accumulation_steps: 4
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: False
+hub_model_id: llama-3-2-1b-base-sft
+hub_strategy: every_save
+learning_rate: 2.0e-05
+log_level: info
+logging_steps: 5
+logging_strategy: steps
+lr_scheduler_type: cosine
+max_seq_length: 2048
+max_steps: -1
+num_train_epochs: 1
+# Checkpoint directory, relative to the directory the launch command runs in.
+output_dir: outputs/llama-3-2-1b-base-sft
+run_name: llama-3-2-1b-base-sft
+overwrite_output_dir: true
+per_device_eval_batch_size: 8
+per_device_train_batch_size: 8
+push_to_hub: false
+remove_unused_columns: true
+report_to:
+- wandb
+save_strategy: "steps"
+save_steps: 1000000
+save_total_limit: 1
+seed: 42
+warmup_ratio: 0.1