mirror of
https://github.com/wassname/alignment-handbook.git
synced 2026-06-27 17:47:01 +08:00
Instructions to train StarChat2
Similar to how we trained Zephyr 7B Beta in our technical report, training this model proceeds in two steps:
- Apply SFT to fine-tune StarCoder2 15B on a blend of chat, code, and math datasets. The result is an SFT model like starchat2-15b-sft-v0.1.
- Align the SFT model to AI feedback via DPO on the UltraFeedback and Orca DPO Pairs datasets. The result is a DPO model like starchat2-15b-v0.1.
See below for commands to train these models using DeepSpeed ZeRO-3.
Full training examples
You will require 8 GPUs (80GB of VRAM each) to train the full model. Alternatively, you can train on 1 GPU by adjusting per_device_train_batch_size and gradient_accumulation_steps to keep the global batch size constant. A recipe involving QLoRA will come later 🤗.
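To make the batch-size adjustment concrete, here is a minimal sketch of the arithmetic. It assumes the usual definition of global batch size as per_device_train_batch_size × gradient_accumulation_steps × number of GPUs. The helper names and the numbers are hypothetical and are not taken from the recipe configs, so substitute the values from your own config file.

```python
def global_batch_size(per_device_train_batch_size, gradient_accumulation_steps, num_gpus):
    """Number of samples that contribute to a single optimizer step."""
    return per_device_train_batch_size * gradient_accumulation_steps * num_gpus


def accumulation_steps_for(target_global_batch_size, per_device_train_batch_size, num_gpus):
    """Gradient accumulation steps needed to reach the target global batch size."""
    samples_per_forward_pass = per_device_train_batch_size * num_gpus
    if target_global_batch_size % samples_per_forward_pass != 0:
        raise ValueError(
            "target global batch size must be a multiple of "
            "per_device_train_batch_size * num_gpus"
        )
    return target_global_batch_size // samples_per_forward_pass


# Hypothetical 8-GPU reference setup (illustrative values, not read from the recipe).
reference = global_batch_size(
    per_device_train_batch_size=8, gradient_accumulation_steps=2, num_gpus=8
)

# Hypothetical single-GPU setup with a smaller per-device batch to fit in memory.
single_gpu_steps = accumulation_steps_for(
    reference, per_device_train_batch_size=4, num_gpus=1
)

print(reference)         # 128
print(single_gpu_steps)  # 32
```

With these illustrative numbers, a single GPU with a per-device batch of 4 needs 32 accumulation steps to match the 8-GPU global batch size of 128.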
# Step 1 - SFT
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_sft.py recipes/starchat2-15b/sft/config_v0.1.yaml
# Step 2 - DPO
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_dpo.py recipes/starchat2-15b/dpo/config_v0.1.yaml
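The two commands above differ only in the training script and the recipe path. As a sketch, they can be assembled from a small table, which may help when scripting both steps back to back. The script only prints the commands and does not launch anything; the `STEPS` table and `build_command` helper are hypothetical names introduced for illustration, while the paths are copied from the commands above.

```python
import shlex

# Shared accelerate config, copied from the commands above.
ACCELERATE_CONFIG = "recipes/accelerate_configs/deepspeed_zero3.yaml"

# (step name, training script, recipe config); order matters: DPO starts from the SFT model.
STEPS = [
    ("sft", "scripts/run_sft.py", "recipes/starchat2-15b/sft/config_v0.1.yaml"),
    ("dpo", "scripts/run_dpo.py", "recipes/starchat2-15b/dpo/config_v0.1.yaml"),
]


def build_command(script, recipe):
    """Return the accelerate launch argument list for one training step."""
    return ["accelerate", "launch", "--config_file", ACCELERATE_CONFIG, script, recipe]


commands = [build_command(script, recipe) for _, script, recipe in STEPS]

for (name, _, _), command in zip(STEPS, commands):
    print(f"# {name}")
    print("ACCELERATE_LOG_LEVEL=info " + shlex.join(command))
```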