Instructions to train StarChat2
As with Zephyr 7B Beta, whose training is described in our technical report, training this model proceeds in two steps:
- Apply SFT to fine-tune StarCoder2 15B on a blend of chat, code, and math datasets. The result is an SFT model like starchat2-15b-sft-v0.1.
- Align the SFT model to AI feedback via DPO on the UltraFeedback and Orca DPO Pairs datasets. The result is a DPO model like starchat2-15b-v0.1.
See below for commands to train these models using DeepSpeed ZeRO-3.
Full training examples
You will require 8 GPUs with 80GB of VRAM each to train the full model. Alternatively, you can train on 1 GPU by adjusting `per_device_train_batch_size` and `gradient_accumulation_steps` to keep the global batch size constant. A recipe involving QLoRA will come later 🤗.
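Keeping the global batch size constant is plain arithmetic: the global batch is the number of GPUs times `per_device_train_batch_size` times `gradient_accumulation_steps`. The sketch below recomputes `gradient_accumulation_steps` for a single GPU and prints an adjusted SFT command instead of running it. The starting batch values are hypothetical placeholders, not the recipe's real settings, and the printed command assumes that accelerate's `--num_processes` flag overrides the process count in `deepspeed_zero3.yaml` and that `run_sft.py` accepts `--key=value` overrides appended after the YAML path; check both against your checkout.

```shell
#!/usr/bin/env bash
# Sketch: recompute gradient_accumulation_steps so that
#   global batch = num_gpus * per_device_train_batch_size * gradient_accumulation_steps
# stays the same when moving from 8 GPUs to 1.
# ASSUMPTION: the batch values below are hypothetical placeholders, not the ones in
# recipes/starchat2-15b/sft/config_v0.1.yaml; copy the real values from that file.
set -euo pipefail

ORIG_GPUS=8
ORIG_PER_DEVICE=4     # hypothetical per_device_train_batch_size in the recipe
ORIG_GRAD_ACCUM=2     # hypothetical gradient_accumulation_steps in the recipe

NEW_GPUS=1
NEW_PER_DEVICE=4      # hypothetical: whatever fits in memory on your single GPU

GLOBAL_BATCH=$(( ORIG_GPUS * ORIG_PER_DEVICE * ORIG_GRAD_ACCUM ))
PER_STEP=$(( NEW_GPUS * NEW_PER_DEVICE ))

# The global batch must split evenly; otherwise pick a different NEW_PER_DEVICE.
if (( GLOBAL_BATCH % PER_STEP != 0 )); then
  echo "global batch ${GLOBAL_BATCH} is not divisible by ${PER_STEP}" >&2
  exit 1
fi
NEW_GRAD_ACCUM=$(( GLOBAL_BATCH / PER_STEP ))

echo "global batch size: ${GLOBAL_BATCH}"
# Print (rather than run) the adjusted single-GPU SFT command.
echo "ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml --num_processes=${NEW_GPUS} scripts/run_sft.py recipes/starchat2-15b/sft/config_v0.1.yaml --per_device_train_batch_size=${NEW_PER_DEVICE} --gradient_accumulation_steps=${NEW_GRAD_ACCUM}"
```

With these placeholder values the global batch is 64, so a single GPU needs 16 accumulation steps. The same calculation applies to the DPO step using the values from `recipes/starchat2-15b/dpo/config_v0.1.yaml`.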
```shell
# Step 1 - SFT
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_sft.py recipes/starchat2-15b/sft/config_v0.1.yaml

# Step 2 - DPO
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_dpo.py recipes/starchat2-15b/dpo/config_v0.1.yaml
```
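If you want both steps to run unattended, a small wrapper can chain them so that DPO only starts once SFT has exited successfully. This is a sketch rather than part of the handbook: `run_stage` and the `DRY_RUN` switch are hypothetical names, and by default the script only prints the two commands instead of launching them.

```shell
#!/usr/bin/env bash
# Sketch: run the SFT and DPO steps back to back, stopping if SFT fails.
# ASSUMPTION: run_stage and DRY_RUN are hypothetical helpers, not part of the handbook.
# DRY_RUN=1 (the default here) only prints the commands; DRY_RUN=0 runs them.
set -euo pipefail

DRY_RUN="${DRY_RUN:-1}"
ACCEL_CONFIG="recipes/accelerate_configs/deepspeed_zero3.yaml"

run_stage() {
  local script="$1" recipe="$2"
  local cmd=(accelerate launch --config_file "$ACCEL_CONFIG" "$script" "$recipe")
  if [[ "$DRY_RUN" == "1" ]]; then
    echo "ACCELERATE_LOG_LEVEL=info ${cmd[*]}"
  else
    ACCELERATE_LOG_LEVEL=info "${cmd[@]}"
  fi
}

# Step 1 - SFT
run_stage scripts/run_sft.py recipes/starchat2-15b/sft/config_v0.1.yaml
# Step 2 - DPO: only reached if Step 1 succeeded, because set -e aborts on failure
run_stage scripts/run_dpo.py recipes/starchat2-15b/dpo/config_v0.1.yaml
```

Set `DRY_RUN=0` to actually launch training. Whether the DPO step aligns the weights you just trained or a published SFT checkpoint depends on the model the DPO YAML points to (typically its `model_name_or_path` field), so check that before chaining the stages.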