Instructions to Replicate Zephyr-7b-β

As described in the Zephyr technical report, training this model proceeds in two steps:

  1. Run supervised fine-tuning (SFT) of Mistral 7B on a filtered version of the UltraChat dataset. The result is an SFT model like zephyr-7b-sft-full or zephyr-7b-sft-qlora.
  2. Align the SFT model to AI feedback via direct preference optimization (DPO) on a preprocessed version of the UltraFeedback dataset. The result is a DPO model like zephyr-7b-dpo-full or zephyr-7b-dpo-qlora.

Note: after the release of Zephyr, the team at Argilla found that the source UltraFeedback dataset had a few thousand incorrect preference labels from GPT-4. Additionally, TRL's SFTTrainer had a bug in the learning rate scheduler that terminated training early. Accounting for these changes led us to a better set of hyperparameters than those described in the technical report. In particular, for DPO training we found that 1 epoch with beta=0.01 was sufficient to achieve performance comparable to zephyr-7b-beta (vs. 3 epochs with beta=0.1).
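
For context on the beta values quoted above: in the standard (sigmoid) DPO objective from the original DPO paper, beta scales the log-probability ratios between the policy being trained, pi_theta, and a frozen reference model, pi_ref (here, the SFT model), for a prompt x with a preferred response y_w and a rejected response y_l:

$$
\mathcal{L}_{\mathrm{DPO}} = -\,\mathbb{E}_{(x,\,y_w,\,y_l)\sim\mathcal{D}}\left[\log\sigma\!\left(\beta\log\frac{\pi_\theta(y_w\mid x)}{\pi_{\mathrm{ref}}(y_w\mid x)} - \beta\log\frac{\pi_\theta(y_l\mid x)}{\pi_{\mathrm{ref}}(y_l\mid x)}\right)\right]
$$

In the derivation of this loss, beta is the strength of the implicit KL penalty that keeps pi_theta close to pi_ref, so a smaller value such as 0.01 lets the policy move further from the SFT model than 0.1 does.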

The commands below train these models either in full with DeepSpeed ZeRO-3 or with QLoRA.

Full training examples

You will need 8 GPUs with 80GB of VRAM each to train the full model.
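
Before launching, it can help to confirm that the machine actually has eight 80GB-class GPUs. The helper below is a minimal sketch and not part of the handbook: `count_gpus_min_mem` is a hypothetical name, and the default threshold `MIN_MIB=80000` is an assumption set slightly below the nominal 80GB because the reported total differs a little between cards. It reads one per-GPU total (in MiB) per line on stdin, which is the format printed by `nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits`.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of the alignment-handbook scripts).
# Reads per-GPU total memory in MiB, one value per line, on stdin and prints
# how many GPUs have at least MIN_MIB MiB (assumed default: 80000).
count_gpus_min_mem() {
  local min=${MIN_MIB:-80000} mem count=0
  # `|| [ -n "$mem" ]` also processes a final line without a trailing newline.
  while read -r mem || [ -n "$mem" ]; do
    mem=${mem//[[:space:]]/}   # strip any stray whitespace
    [ -n "$mem" ] || continue  # skip blank lines
    if [ "$mem" -ge "$min" ]; then
      count=$((count + 1))
    fi
  done
  echo "$count"
}
```

For example, `[ "$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits | count_gpus_min_mem)" -ge 8 ]` succeeds only when at least eight such GPUs are visible.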

# Step 1 - SFT
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_sft.py recipes/zephyr-7b-beta/sft/config_full.yaml

# Step 2 - DPO
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_dpo.py recipes/zephyr-7b-beta/dpo/config_full.yaml
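
Because step 2 aligns the SFT model, you may want to run both commands as one job that stops if SFT fails. This is a minimal sketch, assuming you run it from the repository root so the relative paths resolve. `run_zephyr_full` is a hypothetical name, and the checkpoint DPO starts from is whatever the DPO config specifies, so check that it points to the SFT model you intend to use.

```shell
#!/usr/bin/env bash
# Hypothetical wrapper (not part of the handbook): run the full-training SFT
# and DPO steps in sequence, skipping DPO when SFT exits with an error.
ZERO3_CONFIG=recipes/accelerate_configs/deepspeed_zero3.yaml

run_zephyr_full() {
  # Step 1 - SFT; `|| return` hands SFT's non-zero exit status to the caller.
  ACCELERATE_LOG_LEVEL=info accelerate launch --config_file "$ZERO3_CONFIG" \
    scripts/run_sft.py recipes/zephyr-7b-beta/sft/config_full.yaml || return

  # Step 2 - DPO; reached only if step 1 succeeded.
  ACCELERATE_LOG_LEVEL=info accelerate launch --config_file "$ZERO3_CONFIG" \
    scripts/run_dpo.py recipes/zephyr-7b-beta/dpo/config_full.yaml
}
```

Source the file and call `run_zephyr_full`; it returns the exit status of SFT if that step fails, otherwise the exit status of DPO.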

QLoRA training examples

Train faster with FlashAttention-2 (requires a GPU that supports FA2, e.g. A100 or H100):

# Step 1 - SFT
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_sft.py recipes/zephyr-7b-beta/sft/config_qlora.yaml --load_in_4bit=true

# Step 2 - DPO
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_dpo.py recipes/zephyr-7b-beta/dpo/config_qlora.yaml

P.S. Using FlashAttention-2 also lets you increase the batch size substantially (2x in my case).
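
A rough estimate of model-state memory suggests why full training is sharded across 8 GPUs while the QLoRA run with `--load_in_4bit=true` uses a single process. This is a back-of-the-envelope sketch under stated assumptions: roughly 7 × 10^9 parameters for Mistral 7B; the usual mixed-precision Adam accounting of 16 bytes per parameter (2 bytes for half-precision weights, 2 for gradients, 12 for fp32 master weights, momentum and variance); 0.5 bytes per parameter for 4-bit weights; and activations, quantization constants, and the small LoRA adapters with their optimizer states all ignored:

$$
\begin{aligned}
\text{full training:}&\quad 7\times10^{9} \times 16\ \text{bytes} \approx 112\ \text{GB} \;\Rightarrow\; 112/8 \approx 14\ \text{GB per GPU with ZeRO-3 sharding},\\
\text{QLoRA, frozen 4-bit base:}&\quad 7\times10^{9} \times 0.5\ \text{bytes} \approx 3.5\ \text{GB}.
\end{aligned}
$$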

Train without FlashAttention-2 (i.e. with PyTorch's scaled dot-product attention, SDPA):

# Step 1 - SFT
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_sft.py recipes/zephyr-7b-beta/sft/config_qlora.yaml --load_in_4bit=true --attn_implementation=sdpa

# Step 2 - DPO
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_dpo.py recipes/zephyr-7b-beta/dpo/config_qlora.yaml --attn_implementation=sdpa
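
If you are unsure which of the two QLoRA variants your hardware supports, a small helper can choose the `--attn_implementation` value from each GPU's CUDA compute capability. This is a sketch under assumptions. `pick_attn_impl` is a hypothetical name. It treats compute capability 8.0 and above (Ampere and newer, e.g. A100 and H100) as FlashAttention-2 capable, but it does not check that the `flash-attn` package is installed. `flash_attention_2` is the value Hugging Face Transformers uses for FA2 in its `attn_implementation` option. Input is one compute capability per line, the format printed by `nvidia-smi --query-gpu=compute_cap --format=csv,noheader` on reasonably recent drivers.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of the handbook): choose the attention
# implementation from CUDA compute capabilities read on stdin, e.g. "8.0".
# Prints flash_attention_2 if every listed GPU is 8.x or newer (assumed FA2
# support) and sdpa otherwise; returns 1 if no capability was read.
pick_attn_impl() {
  local cap impl=flash_attention_2 seen=0
  while read -r cap || [ -n "$cap" ]; do
    cap=${cap//[[:space:]]/}
    [ -n "$cap" ] || continue
    seen=1
    # Compare the major version only: "8.6" -> 8, "7.5" -> 7.
    if [ "${cap%%.*}" -lt 8 ]; then
      impl=sdpa
    fi
  done
  [ "$seen" -eq 1 ] || return 1
  echo "$impl"
}
```

For example, append `--attn_implementation="$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | pick_attn_impl)"` to either QLoRA command. Requiring every GPU to qualify is a conservative choice for mixed machines.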