Alvaro Bartolome 70769f9e9b Add run_orpo.py (#143)
* Add `ORPOConfig`

* Add `task=orpo` and support `(prompt,chosen,rejected)` datasets

* Add missing `model_init_kwargs` and `dataset_num_proc`

* Add `run_orpo.py` (WIP)

* Update `trl` dependency from source

* Add `setup_chat_format` before `apply_chat_template`

* Add `config_full.yaml` for `mistral-7b-orpo`

* Fix comment indentation

* Use `chat_template=chatml` instead

* Add `kaist-ai/mistral-orpo-capybara-7k` recipe

* Rename `DPOTrainer` to `ORPOTrainer` in `config_full.yaml` files

* Run `black --line-length 119 src`

* Add `is_openai_format` to fix `(prompt,chosen,rejected)` formatting

* Run `black --line-length 119 src`

* Fix `isort` in `run_orpo.py`

* Update `mistral-capybara/orpo/config_full.yaml`

* Check if `test` is available split

* Pin `trl` to `alvarobartt/trl` fork (debugging)

* Add `qwen-capybara` recipe

* Update `mistral-capybara` recipe

* Set `add_generation_prompt=True` if `task="orpo"`

* Reduce `logging_steps` to 10

* Unset `add_generation_prompt` when `task=orpo`

* Add filtering based on prompt length

Done similarly to the original implementation, in order to better reproduce their results

* Fix prompt length filtering

* Update `trl` pinned version

* Remove extra outdated config files

* Update `recipes/mistral-capybara/orpo/config_full.yaml`

* Run `make style`

* Activate BEAST MODE

* Pin deps

* Add readme

* Fix dep

---------

Co-authored-by: Lewis Tunstall <lewis.c.tunstall@gmail.com>
2024-04-11 16:02:20 +02:00


Instructions to train StarChat2

Similar to how we trained Zephyr 7B Beta in our technical report, training this model proceeds in two steps:

  1. Apply SFT to fine-tune StarCoder2 15B on a blend of chat, code, and math datasets. The result is an SFT model like starchat2-15b-sft-v0.1.
  2. Align the SFT model to AI feedback via DPO on the UltraFeedback and Orca DPO Pairs datasets. The result is a DPO model like starchat2-15b-v0.1.
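For reference, step 2 optimizes the standard DPO objective from Rafailov et al. (2023). It is written out here assuming the frozen SFT model from step 1 acts as the reference policy $\pi_{\mathrm{ref}}$, with $\pi_\theta$ the model being trained, $x$ a prompt, $y_w$ the chosen response, $y_l$ the rejected response, and $\sigma$ the logistic function. The temperature $\beta$ is a hyperparameter; its value for this recipe is set in the DPO config, not in this README.

$$
\mathcal{L}_{\mathrm{DPO}}(\pi_\theta; \pi_{\mathrm{ref}}) = -\,\mathbb{E}_{(x,\, y_w,\, y_l) \sim \mathcal{D}} \left[ \log \sigma\!\left( \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)} \right) \right]
$$

Minimizing this loss raises the policy's likelihood of chosen responses relative to rejected ones, measured against the reference model, without training a separate reward model.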

See below for commands to train these models using DeepSpeed ZeRO-3.

Full training examples

You will need 8 GPUs with 80 GB of VRAM each to train the full model. Alternatively, you can train on 1 GPU by adjusting `per_device_train_batch_size` and `gradient_accumulation_steps` to keep the global batch size constant. A recipe involving QLoRA will come later 🤗.
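To make the single-GPU adjustment concrete: the global batch size is the product of the number of GPUs, `per_device_train_batch_size`, and `gradient_accumulation_steps`, so using fewer GPUs has to be offset by more accumulation steps. The sketch below uses hypothetical values (per-device batch 8 and 2 accumulation steps on 8 GPUs, then per-device batch 4 on 1 GPU); the real settings live in the recipe's `config_v0.1.yaml` files. The printed command assumes the training scripts accept `--key=value` overrides after the YAML path and relies on `accelerate launch`'s `--num_processes` option to override the process count in the DeepSpeed config; verify both against your installed versions.

```shell
# Hypothetical values for the original 8-GPU run; read the real ones from
# recipes/starchat2-15b/{sft,dpo}/config_v0.1.yaml.
ORIG_NUM_GPUS=8
ORIG_PER_DEVICE_BS=8
ORIG_GRAD_ACCUM=2

# Global batch size = GPUs x per-device batch size x accumulation steps.
GLOBAL_BS=$(( ORIG_NUM_GPUS * ORIG_PER_DEVICE_BS * ORIG_GRAD_ACCUM ))

# Target: a single GPU with a (hypothetical) smaller per-device batch.
NEW_NUM_GPUS=1
NEW_PER_DEVICE_BS=4
DENOM=$(( NEW_NUM_GPUS * NEW_PER_DEVICE_BS ))

if [ $(( GLOBAL_BS % DENOM )) -ne 0 ]; then
  echo "global batch size ${GLOBAL_BS} is not divisible by ${DENOM}" >&2
else
  # Accumulation steps needed to keep the global batch size unchanged.
  NEW_GRAD_ACCUM=$(( GLOBAL_BS / DENOM ))
  echo "global_batch_size=${GLOBAL_BS} gradient_accumulation_steps=${NEW_GRAD_ACCUM}"
  # Assumption: run_sft.py accepts --key=value overrides after the YAML path.
  # The command is only printed here, not executed.
  echo "ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml --num_processes=${NEW_NUM_GPUS} scripts/run_sft.py recipes/starchat2-15b/sft/config_v0.1.yaml --per_device_train_batch_size=${NEW_PER_DEVICE_BS} --gradient_accumulation_steps=${NEW_GRAD_ACCUM}"
fi
```

With these hypothetical numbers the global batch size is 128, so one GPU at a per-device batch of 4 needs 32 accumulation steps. Expect each optimizer step to take correspondingly longer.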

```shell
# Step 1 - SFT
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_sft.py recipes/starchat2-15b/sft/config_v0.1.yaml

# Step 2 - DPO
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_dpo.py recipes/starchat2-15b/dpo/config_v0.1.yaml
```
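Because the DPO step aligns the SFT model, it only makes sense to start it once SFT has finished successfully. The hypothetical wrapper below chains the two commands above with `&&` so that a failed SFT run stops the pipeline. It assumes the DPO config already points at the checkpoint produced by step 1, and it defaults to a dry run (`DRY_RUN=1`) that only prints the commands; set `DRY_RUN=0` to actually launch training. The `run_step` helper is illustrative and not part of the repository.

```shell
# Hypothetical wrapper: run SFT, then DPO only if SFT succeeded.
# DRY_RUN=1 (default) prints the commands; DRY_RUN=0 executes them.
ACCEL_CFG=recipes/accelerate_configs/deepspeed_zero3.yaml
DRY_RUN=${DRY_RUN:-1}

run_step() {
  # $1 = training script, $2 = recipe config file
  cmd="ACCELERATE_LOG_LEVEL=info accelerate launch --config_file ${ACCEL_CFG} $1 $2"
  if [ "$DRY_RUN" = "1" ]; then
    echo "$cmd"
  else
    ACCELERATE_LOG_LEVEL=info accelerate launch --config_file "$ACCEL_CFG" "$1" "$2"
  fi
}

run_step scripts/run_sft.py recipes/starchat2-15b/sft/config_v0.1.yaml &&
  run_step scripts/run_dpo.py recipes/starchat2-15b/dpo/config_v0.1.yaml
```

Run it once as-is to check the commands, then rerun with `DRY_RUN=0` prefixed to the invocation.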