mirror of
https://github.com/wassname/alignment-handbook.git
synced 2026-06-27 16:14:07 +08:00
update API to use latest TRL (#182)
* update API * update deepspeed * update black * remove unused import * fix typos * fix typos in readmes * fix grammer * removed as it exists in superclass * fixes in readme * Update README.md Co-authored-by: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> * Update src/alignment/configs.py Co-authored-by: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> * Update src/alignment/configs.py Co-authored-by: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> * Update src/alignment/configs.py Co-authored-by: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> * Update src/alignment/configs.py Co-authored-by: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> * add back dataset_kwargs * use hub_model_revision in sft and dpo * fix duplicate --------- Co-authored-by: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
This commit is contained in:
@@ -49,8 +49,8 @@ If you would like to train chat models on your own datasets, we recommend follow
|
||||
|
||||
The initial release of the handbook will focus on the following techniques:
|
||||
|
||||
* **Continued pretraining:** adapt language models to a new language or domain, or simply improve it by continue pretraning (causal language modeling) on a new dataset.
|
||||
* **Supervised fine-tuning:** teach language models to follow instructions and tips on how to collect and curate your own training dataset.
|
||||
* **Continued pretraining:** adapt language models to a new language or domain, or simply improve it by continued pretraining (causal language modeling) on a new dataset.
|
||||
* **Supervised fine-tuning:** teach language models to follow instructions and tips on how to collect and curate your training dataset.
|
||||
* **Reward modeling:** teach language models to distinguish model responses according to human or AI preferences.
|
||||
* **Rejection sampling:** a simple, but powerful technique to boost the performance of your SFT model.
|
||||
* **Direct preference optimisation (DPO):** a powerful and promising alternative to PPO.
|
||||
|
||||
@@ -21,4 +21,4 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con
|
||||
|
||||
## Advanced: generating you own dataset
|
||||
|
||||
To generate the constitutional AI dataset, see https://github.com/huggingface/llm-swarm/tree/main/examples/constitutional-ai for detailed instructions if you want build or customize the dataset.
|
||||
To generate the constitutional AI dataset, see https://github.com/huggingface/llm-swarm/tree/main/examples/constitutional-ai for detailed instructions if you want to build or customize the dataset.
|
||||
|
||||
@@ -17,7 +17,7 @@ bf16: true
|
||||
beta: 0.1
|
||||
do_eval: true
|
||||
do_train: true
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 1000
|
||||
gradient_accumulation_steps: 1
|
||||
gradient_checkpointing: true
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
model_name_or_path: mistralai/Mistral-7B-v0.1
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
use_flash_attention_2: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
# Data training arguments
|
||||
chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
|
||||
@@ -18,7 +18,7 @@ preprocessing_num_workers: 12
|
||||
bf16: true
|
||||
do_eval: true
|
||||
do_train: true
|
||||
evaluation_strategy: epoch # One of ["no", "steps", "epoch"]
|
||||
eval_strategy: epoch # One of ["no", "steps", "epoch"]
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
model_name_or_path: mistralai/Mistral-7B-v0.1
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
use_flash_attention_2: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
# Data training arguments
|
||||
chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
|
||||
@@ -18,7 +18,7 @@ preprocessing_num_workers: 12
|
||||
bf16: true
|
||||
do_eval: true
|
||||
do_train: true
|
||||
evaluation_strategy: epoch # One of ["no", "steps", "epoch"]
|
||||
eval_strategy: epoch # One of ["no", "steps", "epoch"]
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
This directory shows a base example of how to use continued pretraining and further tuning to adapt a language model to new data (e.g. a new language or domain).
|
||||
|
||||
Three steps are needed: continued pretraining (`cpt`), supervised finetuning (`sft`), and direct preference optimisation (`dpo`). In this dummy example we'll continue pretraining gpt2 on Dutch raw data, then sft-tuning it, and finally aligning it with DPO. Note that no extensive hyperparameters were tested in this example and that the output models are bad - it is just to show you how you can use the scripts for LM adaptation. The scripts work on 4x 3090s (24GB VRAM). If you have less powerful hardware you may need to reduce the batch size.
|
||||
Three steps are needed: continued pretraining (`cpt`), supervised finetuning (`sft`), and direct preference optimisation (`dpo`). In this dummy example, we'll continue pretraining gpt2 on Dutch raw data, then sft-tuning it, and finally aligning it with DPO. Note that no extensive hyperparameters were tested in this example and that the output models are bad - it is just to show you how you can use the scripts for LM adaptation. The scripts work on 4x 3090s (24GB VRAM). If you have less powerful hardware you may need to reduce the batch size.
|
||||
|
||||
## Continued pretraining
|
||||
|
||||
@@ -18,7 +18,7 @@ ACCELERATE_LOG_LEVEL=info accelerate launch \
|
||||
|
||||
## Supervised finetuning
|
||||
|
||||
As other recipes, such as the famous zephyr-7b-beta recipe, have shown, we can then teach our model how to hold a conversation by finetuning it on chat-formatted data. As a base model we'll make use of the output of the previous step.
|
||||
As other recipes, such as the famous zephyr-7b-beta recipe, have shown, we can then teach our model how to hold a conversation by finetuning it on chat-formatted data. As a base model, we'll make use of the output of the previous step.
|
||||
|
||||
```shell
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch \
|
||||
|
||||
@@ -15,7 +15,7 @@ preprocessing_num_workers: 12
|
||||
# SFT trainer config
|
||||
bf16: true
|
||||
do_eval: False
|
||||
evaluation_strategy: "no"
|
||||
eval_strategy: "no"
|
||||
gradient_accumulation_steps: 1
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
|
||||
@@ -16,7 +16,7 @@ preprocessing_num_workers: 12
|
||||
bf16: true
|
||||
beta: 0.1
|
||||
do_eval: true
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 100
|
||||
gradient_accumulation_steps: 8
|
||||
gradient_checkpointing: true
|
||||
|
||||
@@ -15,7 +15,7 @@ preprocessing_num_workers: 12
|
||||
# SFT trainer config
|
||||
bf16: true
|
||||
do_eval: true
|
||||
evaluation_strategy: epoch
|
||||
eval_strategy: epoch
|
||||
gradient_accumulation_steps: 1
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
|
||||
@@ -5,13 +5,14 @@ This directory contains various comparisons for three algorithms: DPO, IPO, and
|
||||
- OpenHermes-2.5 and the OpenOrca datasets
|
||||
|
||||
We release a collection containing the datasets and models used for these experiments, if you require the other trained models, we can release them on request.
|
||||
You can find a longer decription of there results in our [blogpost](https://huggingface.co/blog/pref-tuning)
|
||||
You can find a longer description of these results in our [blogpost](https://huggingface.co/blog/pref-tuning)
|
||||
|
||||
## Comparisons
|
||||
For each algorithm, we aim to tune the beta parameter for a fixed learning rate. We vary beta from 0.1-0.9 in steps of 0.1, we have also found that in certain configurations a tiny value of beta, 0.01, can be effective. So we have included this smaller value in all our comparisons.
|
||||
|
||||
## Usage
|
||||
The experiments can be launched with the following bash script:
|
||||
```
|
||||
```bash
|
||||
#!/bin/bash
|
||||
|
||||
# Define an array containing the base configs we wish to fine tune
|
||||
|
||||
@@ -16,7 +16,7 @@ beta: 0.01
|
||||
loss_type: sigmoid
|
||||
do_eval: true
|
||||
do_train: true
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 100
|
||||
gradient_accumulation_steps: 2
|
||||
gradient_checkpointing: true
|
||||
|
||||
@@ -15,7 +15,7 @@ bf16: true
|
||||
beta: 0.01
|
||||
loss_type: sigmoid
|
||||
do_eval: true
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 100
|
||||
gradient_accumulation_steps: 2
|
||||
gradient_checkpointing: true
|
||||
|
||||
@@ -16,7 +16,7 @@ preprocessing_num_workers: 12
|
||||
bf16: true
|
||||
beta: 0.05
|
||||
do_eval: true
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 100
|
||||
gradient_accumulation_steps: 8
|
||||
gradient_checkpointing: true
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
model_name_or_path: bigcode/starcoder2-15b
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
use_flash_attention_2: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
# Data training arguments
|
||||
chat_template: "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
|
||||
@@ -20,7 +20,7 @@ preprocessing_num_workers: 24
|
||||
# SFT trainer config
|
||||
bf16: true
|
||||
do_eval: true
|
||||
evaluation_strategy: epoch
|
||||
eval_strategy: epoch
|
||||
gradient_accumulation_steps: 2
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
model_name_or_path: mistral-community/Mixtral-8x22B-v0.1
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
use_flash_attention_2: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
# Data training arguments
|
||||
chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
|
||||
|
||||
@@ -4,9 +4,9 @@
|
||||
As described in the Zephyr [technical report](https://huggingface.co/papers/2310.16944), training this model proceeds in two steps:
|
||||
|
||||
1. Apply SFT to fine-tune Mistral 7B on a filtered version of the UltraChat dataset ([link](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)). The result is an SFT model like [`zephyr-7b-sft-full`](https://huggingface.co/alignment-handbook/zephyr-7b-sft-full) or [`zephyr-7b-sft-qlora`](https://huggingface.co/alignment-handbook/zephyr-7b-sft-qlora).
|
||||
2. Align the SFT model to AI feedback via DPO on a preprocessed version of the UltraFeedback dataset ([link](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)). The result is an DPO model like [`zephyr-7b-dpo-full`](https://huggingface.co/alignment-handbook/zephyr-7b-dpo-full) or [`zephyr-7b-dpo-qlora`](https://huggingface.co/alignment-handbook/zephyr-7b-dpo-qlora).
|
||||
2. Align the SFT model to AI feedback via DPO on a preprocessed version of the UltraFeedback dataset ([link](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)). The result is a DPO model like [`zephyr-7b-dpo-full`](https://huggingface.co/alignment-handbook/zephyr-7b-dpo-full) or [`zephyr-7b-dpo-qlora`](https://huggingface.co/alignment-handbook/zephyr-7b-dpo-qlora).
|
||||
|
||||
**Note:** after the release of Zephyr, the team at [Argilla](https://argilla.io) found that the source UltraFeedback dataset had a few thousand incorrect preference labels from GPT-4. Additionally, TRL's `SFTTrainer` had a bug in the learning rate scheduler which terminated training early. Accounting for these changes led us to find a better set of hyperparameters from those described in the technical report. In particular, for DPO training we found that training for 1 epoch with `beta=0.01` was suffucient to achieve comparable performance to `zephyr-7b-beta` (vs. 3 epochs with `beta=0.1`).
|
||||
**Note:** after the release of Zephyr, the team at [Argilla](https://argilla.io) found that the source UltraFeedback dataset had a few thousand incorrect preference labels from GPT-4. Additionally, TRL's `SFTTrainer` had a bug in the learning rate scheduler which terminated training early. Accounting for these changes led us to find a better set of hyperparameters from those described in the technical report. In particular, for DPO training we found that training for 1 epoch with `beta=0.01` was sufficient to achieve comparable performance to `zephyr-7b-beta` (vs. 3 epochs with `beta=0.1`).
|
||||
|
||||
See below for commands to train these models using either DeepSpeed ZeRO-3 or LoRA.
|
||||
|
||||
@@ -34,11 +34,11 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con
|
||||
|
||||
P.S. Using Flash Attention also allows you to drastically increase the batch size (x2 in my case)
|
||||
|
||||
Train without flash-attention:
|
||||
Train without flash-attention (i.e. via PyTorch's scaled dot product attention):
|
||||
```````shell
|
||||
# Step 1 - SFT
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_sft.py recipes/zephyr-7b-beta/sft/config_qlora.yaml --load_in_4bit=true --use_flash_attention_2=false
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_sft.py recipes/zephyr-7b-beta/sft/config_qlora.yaml --load_in_4bit=true --attn_implementation=sdpa
|
||||
|
||||
# Step 2 - DPO
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_dpo.py recipes/zephyr-7b-beta/dpo/config_qlora.yaml --use_flash_attention_2=false
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_dpo.py recipes/zephyr-7b-beta/dpo/config_qlora.yaml --attn_implementation=sdpa
|
||||
```````
|
||||
@@ -15,7 +15,7 @@ preprocessing_num_workers: 12
|
||||
bf16: true
|
||||
beta: 0.01
|
||||
do_eval: true
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 100
|
||||
gradient_accumulation_steps: 2
|
||||
gradient_checkpointing: true
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Model arguments
|
||||
model_name_or_path: alignment-handbook/zephyr-7b-sft-qlora
|
||||
torch_dtype: bfloat16
|
||||
use_flash_attention_2: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
# LoRA arguments
|
||||
use_peft: true
|
||||
@@ -31,7 +31,7 @@ preprocessing_num_workers: 12
|
||||
bf16: true
|
||||
beta: 0.01
|
||||
do_eval: true
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 100
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_checkpointing: true
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
model_name_or_path: mistralai/Mistral-7B-v0.1
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
use_flash_attention_2: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
# Data training arguments
|
||||
chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
|
||||
@@ -16,7 +16,7 @@ preprocessing_num_workers: 12
|
||||
# SFT trainer config
|
||||
bf16: true
|
||||
do_eval: true
|
||||
evaluation_strategy: epoch
|
||||
eval_strategy: epoch
|
||||
gradient_accumulation_steps: 1
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
model_name_or_path: mistralai/Mistral-7B-v0.1
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
use_flash_attention_2: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
# LoRA arguments
|
||||
load_in_4bit: true
|
||||
@@ -31,7 +31,7 @@ preprocessing_num_workers: 12
|
||||
# SFT trainer config
|
||||
bf16: true
|
||||
do_eval: true
|
||||
evaluation_strategy: epoch
|
||||
eval_strategy: epoch
|
||||
gradient_accumulation_steps: 2
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
|
||||
@@ -15,7 +15,7 @@ preprocessing_num_workers: 12
|
||||
bf16: true
|
||||
beta: 0.05
|
||||
do_eval: true
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 100
|
||||
gradient_accumulation_steps: 8
|
||||
gradient_checkpointing: true
|
||||
|
||||
@@ -3,7 +3,7 @@ model_name_or_path: google/gemma-7b
|
||||
model_revision: main
|
||||
tokenizer_name_or_path: philschmid/gemma-tokenizer-chatml # Custom tokenizer with <|im_start|> and <|im_end|> tokens
|
||||
torch_dtype: bfloat16
|
||||
use_flash_attention_2: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
# Data training arguments
|
||||
dataset_mixer:
|
||||
@@ -19,7 +19,7 @@ dataset_kwargs:
|
||||
add_special_tokens: false # We already wrap <bos> and <eos> in the chat template
|
||||
append_concat_token: false # No need to add <eos> across samples
|
||||
do_eval: true
|
||||
evaluation_strategy: epoch
|
||||
eval_strategy: epoch
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
|
||||
+4
-5
@@ -28,7 +28,7 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/fsdp+qlora.yaml --num_processes={num_gpus} scripts/run_{task}.py recipes/{model_name}/{task}/config_qlora.yaml --torch_dtype=bfloat16 --bnb_4bit_quant_storage=bfloat16
|
||||
```
|
||||
|
||||
Here `{task}` refers to the type of training you wish to run. Currently the following tasks are supported:
|
||||
Here `{task}` refers to the type of training you wish to run. Currently, the following tasks are supported:
|
||||
* continued pretraining `cpt` (note that `cpt` is only present in the `gpt-nl` example recipe)
|
||||
* supervised finetuning `sft`
|
||||
* direct preference optimisation `dpo`
|
||||
@@ -54,8 +54,7 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con
|
||||
```
|
||||
|
||||
## Logging with Weights and Biases
|
||||
|
||||
By default all training metrics are logged with TensorBoard. If you have a [Weights and Biases](https://wandb.ai/site) account and are logged in, you can view the training metrics by appending `--report_to=wandb`, e.g.
|
||||
By default, all training metrics are logged with TensorBoard. If you have a [Weights and Biases](https://wandb.ai/site) account and are logged in, you can view the training metrics by appending `--report_to=wandb`, e.g.
|
||||
|
||||
```shell
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_{task}.py recipes/{model_name}/{task}/config_full.yaml --report_to=wandb
|
||||
@@ -120,7 +119,7 @@ If you format your dataset in the same way, our training scripts should work out
|
||||
We recommend benchmarking chat models on:
|
||||
|
||||
* [MT-Bench](https://huggingface.co/spaces/lmsys/mt-bench): a multi-turn benchmark spanning 80 dialogues and 10 domains.
|
||||
* [AlpacaEval](https://github.com/tatsu-lab/alpaca_eval): a single-turn benchmark which evaluates the helpfulness of chat and instruct models against `text-davinci-003`.
|
||||
* [AlpacaEval](https://github.com/tatsu-lab/alpaca_eval): a single-turn benchmark that evaluates the helpfulness of chat and instruct models against `text-davinci-003`.
|
||||
|
||||
For both benchmarks, we have added support for the [Zephyr chat template](https://huggingface.co/alignment-handbook/zephyr-7b-sft-full/blob/ac6e600eefcce74f5e8bae1035d4f66019e93190/tokenizer_config.json#L30) (which is the default produced by our scripts), so you can evaluate models produced by our scripts as follows:
|
||||
|
||||
@@ -137,6 +136,6 @@ For both benchmarks, we have added support for the [Zephyr chat template](https:
|
||||
* Next, update the [config name](https://github.com/tatsu-lab/alpaca_eval/blob/2daa6e11b194653043ca74f735728dc068e04aae/src/alpaca_eval/models_configs/zephyr-7b-beta/configs.yaml#L1) and [Hub model ID](https://github.com/tatsu-lab/alpaca_eval/blob/2daa6e11b194653043ca74f735728dc068e04aae/src/alpaca_eval/models_configs/zephyr-7b-beta/configs.yaml#L5) to match your model name.
|
||||
* Follow the steps to evaluate your model [here](https://github.com/tatsu-lab/alpaca_eval/tree/main#evaluating-a-model).
|
||||
|
||||
Note that MT-Bench and AlpacaEval rely on LLMs like GPT-4 to judge the quality of the model responses, and thus the ranking exhibit various biases including a preference for models distilled from GPTs. For that reason, we also recommend submitting your best models for human evaluation in:
|
||||
Note that MT-Bench and AlpacaEval rely on LLMs like GPT-4 to judge the quality of the model responses, and thus the ranking exhibits various biases including a preference for models distilled from GPTs. For that reason, we also recommend submitting your best models for human evaluation in:
|
||||
|
||||
* [Chatbot Arena](https://chat.lmsys.org): a live, human evaluation of chat models in head-to-head comparisons.
|
||||
|
||||
+2
-2
@@ -127,7 +127,7 @@ def main():
|
||||
model_kwargs = dict(
|
||||
revision=model_args.model_revision,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
use_flash_attention_2=model_args.use_flash_attention_2,
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
use_cache=False if training_args.gradient_checkpointing else True,
|
||||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
@@ -200,7 +200,7 @@ def main():
|
||||
|
||||
if training_args.push_to_hub is True:
|
||||
logger.info("Pushing to hub...")
|
||||
trainer.push_to_hub(**kwargs)
|
||||
trainer.push_to_hub(revision=training_args.hub_model_revision, **kwargs)
|
||||
|
||||
logger.info("*** Training complete ***")
|
||||
|
||||
|
||||
+3
-3
@@ -146,7 +146,7 @@ def main():
|
||||
model_kwargs = dict(
|
||||
revision=model_args.model_revision,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
use_flash_attention_2=model_args.use_flash_attention_2,
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
use_cache=False if training_args.gradient_checkpointing else True,
|
||||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
@@ -160,7 +160,7 @@ def main():
|
||||
model_kwargs = dict(
|
||||
revision=model_args.base_model_revision,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
use_flash_attention_2=model_args.use_flash_attention_2,
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
use_cache=False if training_args.gradient_checkpointing else True,
|
||||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
@@ -252,7 +252,7 @@ def main():
|
||||
|
||||
if training_args.push_to_hub is True:
|
||||
logger.info("Pushing to hub...")
|
||||
trainer.push_to_hub(**kwargs)
|
||||
trainer.push_to_hub(revision=training_args.hub_model_revision, **kwargs)
|
||||
|
||||
logger.info("*** Training complete! ***")
|
||||
|
||||
|
||||
+2
-3
@@ -36,8 +36,7 @@ from alignment import (
|
||||
get_quantization_config,
|
||||
get_tokenizer,
|
||||
)
|
||||
from alignment.configs import ORPOConfig
|
||||
from trl import ORPOTrainer, setup_chat_format
|
||||
from trl import ORPOConfig, ORPOTrainer, setup_chat_format
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -107,7 +106,7 @@ def main():
|
||||
model_args.model_name_or_path,
|
||||
revision=model_args.model_revision,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
use_flash_attention_2=model_args.use_flash_attention_2,
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
use_cache=False if training_args.gradient_checkpointing else True,
|
||||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
|
||||
+2
-2
@@ -113,7 +113,7 @@ def main():
|
||||
model_kwargs = dict(
|
||||
revision=model_args.model_revision,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
use_flash_attention_2=model_args.use_flash_attention_2,
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
use_cache=False if training_args.gradient_checkpointing else True,
|
||||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
@@ -224,7 +224,7 @@ def main():
|
||||
|
||||
if training_args.push_to_hub is True:
|
||||
logger.info("Pushing to hub...")
|
||||
trainer.push_to_hub(**kwargs)
|
||||
trainer.push_to_hub(revision=training_args.hub_model_revision, **kwargs)
|
||||
|
||||
logger.info("*** Training complete ***")
|
||||
|
||||
|
||||
@@ -43,9 +43,9 @@ if stale_egg_info.exists():
|
||||
_deps = [
|
||||
"accelerate>=0.29.2",
|
||||
"bitsandbytes>=0.43.0",
|
||||
"black==23.1.0",
|
||||
"black>=24.4.2",
|
||||
"datasets>=2.18.0",
|
||||
"deepspeed==0.12.2",
|
||||
"deepspeed>=0.14.4",
|
||||
"einops>=0.6.1",
|
||||
"evaluate==0.4.0",
|
||||
"flake8>=6.0.0",
|
||||
@@ -64,9 +64,9 @@ _deps = [
|
||||
"sentencepiece>=0.1.99",
|
||||
"scipy",
|
||||
"tensorboard",
|
||||
"torch==2.1.2",
|
||||
"torch>=2.1.2",
|
||||
"transformers>=4.39.3",
|
||||
"trl>=0.8.2",
|
||||
"trl>=0.9.6",
|
||||
"jinja2>=3.0.0",
|
||||
"tqdm>=4.64.1",
|
||||
]
|
||||
|
||||
+21
-95
@@ -18,9 +18,10 @@ import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, NewType, Optional, Tuple
|
||||
|
||||
import transformers
|
||||
from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, HfArgumentParser
|
||||
|
||||
import trl
|
||||
|
||||
|
||||
MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
|
||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||
@@ -70,7 +71,7 @@ class H4ArgumentParser(HfArgumentParser):
|
||||
inputs[arg] = [str(v) for v in val.split(",")]
|
||||
|
||||
# bool of a non-empty string is True, so we manually check for bools
|
||||
if base_type == bool:
|
||||
if base_type is bool:
|
||||
if val in ["true", "True"]:
|
||||
inputs[arg] = True
|
||||
else:
|
||||
@@ -146,11 +147,11 @@ class ModelArguments:
|
||||
},
|
||||
)
|
||||
trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."})
|
||||
use_flash_attention_2: bool = field(
|
||||
default=False,
|
||||
attn_implementation: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"Whether to use flash attention 2. You must install this manually by running `pip install flash-attn --no-build-isolation`"
|
||||
"Which attention implementation to use; you can use --attn_implementation=flash_attention_2, in which case you must install this manually by running `pip install flash-attn --no-build-isolation`"
|
||||
)
|
||||
},
|
||||
)
|
||||
@@ -186,7 +187,8 @@ class ModelArguments:
|
||||
)
|
||||
use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"})
|
||||
bnb_4bit_quant_storage: Optional[str] = field(
|
||||
default="uint8", metadata={"help": "storage type to pack the quanitzed 4-bit prarams."}
|
||||
default="uint8",
|
||||
metadata={"help": "storage type to pack the quanitzed 4-bit prarams."},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -235,36 +237,12 @@ class DataArguments:
|
||||
|
||||
|
||||
@dataclass
|
||||
class SFTConfig(transformers.TrainingArguments):
|
||||
class SFTConfig(trl.SFTConfig):
|
||||
"""
|
||||
Arguments related to the training process itself. For all parameters, see: https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/trainer#transformers.TrainingArguments
|
||||
Arguments related to the training process itself. For all parameters, see: https://huggingface.co/docs/transformers/v4.39.3/en/main_classes/trainer#transformers.TrainingArguments
|
||||
Also used for the continued pretraining task.
|
||||
"""
|
||||
|
||||
dataset_kwargs: Optional[Dict[str, Any]] = field(
|
||||
default=None, metadata={"help": "Dataset kwargs for the SFTTrainer"}
|
||||
)
|
||||
max_seq_length: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": ("Used by TRL for reward model training, which tries to read this parameter in init.")},
|
||||
)
|
||||
logging_first_step: bool = field(
|
||||
default=True,
|
||||
metadata={"help": ("Whether to log and evaluate the first global_step or not.")},
|
||||
)
|
||||
optim: Optional[str] = field(default="adamw_torch")
|
||||
|
||||
|
||||
@dataclass
|
||||
class DPOConfig(transformers.TrainingArguments):
|
||||
"""
|
||||
Arguments related to the DPO training process itself. For all parameters, see: https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/trainer#transformers.TrainingArguments
|
||||
"""
|
||||
|
||||
beta: Optional[float] = field(
|
||||
default=0.1,
|
||||
metadata={"help": "The beta factor in DPO loss. Higher beta means less divergence from the initial policy."},
|
||||
)
|
||||
hub_model_revision: Optional[str] = field(
|
||||
default="main",
|
||||
metadata={"help": ("The Hub model branch to push the model to.")},
|
||||
@@ -273,73 +251,21 @@ class DPOConfig(transformers.TrainingArguments):
|
||||
default=True,
|
||||
metadata={"help": ("Whether to log and evaluate the first global_step or not.")},
|
||||
)
|
||||
max_prompt_length: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": ("For DPO, the maximum length of the prompt to use for conditioning the model.")},
|
||||
)
|
||||
max_length: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": ("Used by TRL for reward model training, which tries to read this parameter in init.")},
|
||||
)
|
||||
optim: Optional[str] = field(default="rmsprop")
|
||||
remove_unused_columns: bool = field(default=False)
|
||||
loss_type: Optional[str] = field(default="sigmoid", metadata={"help": ("The loss type for DPO.")})
|
||||
|
||||
|
||||
@dataclass
|
||||
class ORPOConfig(transformers.TrainingArguments):
|
||||
max_length: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The maximum length of the sequences in the batch."},
|
||||
)
|
||||
max_prompt_length: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The maximum length of the prompt."},
|
||||
)
|
||||
max_completion_length: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The maximum length of the completions."},
|
||||
)
|
||||
class DPOConfig(trl.DPOConfig):
|
||||
"""
|
||||
Arguments related to the DPO training process itself. For all parameters, see: https://huggingface.co/docs/transformers/v4.39.3/en/main_classes/trainer#transformers.TrainingArguments
|
||||
"""
|
||||
|
||||
beta: float = field(
|
||||
default=0.1,
|
||||
metadata={
|
||||
"help": "The beta factor in ORPO loss (lambda/alpha in paper/code) that is the weight of the relative loss ratio in the SFT loss."
|
||||
},
|
||||
hub_model_revision: Optional[str] = field(
|
||||
default="main",
|
||||
metadata={"help": ("The Hub model branch to push the model to.")},
|
||||
)
|
||||
disable_dropout: bool = field(
|
||||
logging_first_step: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether or not to disable dropouts in `model`."},
|
||||
)
|
||||
|
||||
label_pad_token_id: int = field(
|
||||
default=-100,
|
||||
metadata={"help": "The label pad token id."},
|
||||
)
|
||||
padding_value: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The padding value if it is different to the tokenizer's pad_token_id."},
|
||||
)
|
||||
truncation_mode: str = field(
|
||||
default="keep_end",
|
||||
metadata={"help": "The truncation mode to use, either `keep_end` or `keep_start`."},
|
||||
)
|
||||
|
||||
generate_during_eval: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to sample and log generations during evaluation step."},
|
||||
)
|
||||
is_encoder_decoder: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": ("If no model is provided, we need to know if the model_init returns an encoder-decoder.")},
|
||||
)
|
||||
|
||||
model_init_kwargs: Optional[Dict] = field(
|
||||
default=None,
|
||||
metadata={"help": ("Dict of Optional kwargs to pass when instantiating the model from a string")},
|
||||
)
|
||||
|
||||
dataset_num_proc: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": ("The number of workers to use to tokenize the data.")},
|
||||
metadata={"help": ("Whether to log and evaluate the first global_step or not.")},
|
||||
)
|
||||
optim: Optional[str] = field(default="rmsprop")
|
||||
remove_unused_columns: bool = field(default=False)
|
||||
|
||||
@@ -68,9 +68,11 @@ def get_tokenizer(
|
||||
) -> PreTrainedTokenizer:
|
||||
"""Get the tokenizer for the model."""
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path
|
||||
if model_args.tokenizer_name_or_path is None
|
||||
else model_args.tokenizer_name_or_path,
|
||||
(
|
||||
model_args.model_name_or_path
|
||||
if model_args.tokenizer_name_or_path is None
|
||||
else model_args.tokenizer_name_or_path
|
||||
),
|
||||
revision=model_args.model_revision,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
)
|
||||
|
||||
Vendored
+1
-1
@@ -14,7 +14,7 @@ preprocessing_num_workers: 12
|
||||
bf16: true
|
||||
beta: 0.1
|
||||
do_eval: true
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 100
|
||||
gradient_accumulation_steps: 1
|
||||
gradient_checkpointing: true
|
||||
|
||||
Vendored
+2
-2
@@ -2,7 +2,7 @@
|
||||
model_name_or_path: mistralai/Mistral-7B-v0.1
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
use_flash_attention_2: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
# Data training arguments
|
||||
dataset_mixer:
|
||||
@@ -15,7 +15,7 @@ preprocessing_num_workers: 12
|
||||
# SFT trainer config
|
||||
bf16: true
|
||||
do_eval: true
|
||||
evaluation_strategy: epoch
|
||||
eval_strategy: epoch
|
||||
gradient_accumulation_steps: 2
|
||||
gradient_checkpointing: true
|
||||
hub_model_id: zephyr-7b-sft-full
|
||||
|
||||
+1
-1
@@ -124,7 +124,7 @@ class ApplyChatTemplateTest(unittest.TestCase):
|
||||
def test_maybe_insert_system_message(self):
|
||||
# does not accept system prompt
|
||||
mistral_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
|
||||
# accepts system prompt. use codellama since it has no HF token reqiurement
|
||||
# accepts system prompt. use codellama since it has no HF token requirement
|
||||
llama_tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-hf")
|
||||
messages_sys_excl = [{"role": "user", "content": "Tell me a joke."}]
|
||||
messages_sys_incl = [{"role": "system", "content": ""}, {"role": "user", "content": "Tell me a joke."}]
|
||||
|
||||
Reference in New Issue
Block a user