diff --git a/README.md b/README.md index bda60c8..83a2022 100644 --- a/README.md +++ b/README.md @@ -49,8 +49,8 @@ If you would like to train chat models on your own datasets, we recommend follow The initial release of the handbook will focus on the following techniques: -* **Continued pretraining:** adapt language models to a new language or domain, or simply improve it by continue pretraning (causal language modeling) on a new dataset. -* **Supervised fine-tuning:** teach language models to follow instructions and tips on how to collect and curate your own training dataset. +* **Continued pretraining:** adapt language models to a new language or domain, or simply improve it by continued pretraining (causal language modeling) on a new dataset. +* **Supervised fine-tuning:** teach language models to follow instructions and tips on how to collect and curate your training dataset. * **Reward modeling:** teach language models to distinguish model responses according to human or AI preferences. * **Rejection sampling:** a simple, but powerful technique to boost the performance of your SFT model. * **Direct preference optimisation (DPO):** a powerful and promising alternative to PPO. diff --git a/recipes/constitutional-ai/README.md b/recipes/constitutional-ai/README.md index 71b073b..08f4520 100644 --- a/recipes/constitutional-ai/README.md +++ b/recipes/constitutional-ai/README.md @@ -21,4 +21,4 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con ## Advanced: generating you own dataset -To generate the constitutional AI dataset, see https://github.com/huggingface/llm-swarm/tree/main/examples/constitutional-ai for detailed instructions if you want build or customize the dataset. +To generate the constitutional AI dataset, see https://github.com/huggingface/llm-swarm/tree/main/examples/constitutional-ai for detailed instructions if you want to build or customize the dataset. diff --git a/recipes/constitutional-ai/dpo/config_anthropic.yaml b/recipes/constitutional-ai/dpo/config_anthropic.yaml index 0ef0801..48f5767 100644 --- a/recipes/constitutional-ai/dpo/config_anthropic.yaml +++ b/recipes/constitutional-ai/dpo/config_anthropic.yaml @@ -17,7 +17,7 @@ bf16: true beta: 0.1 do_eval: true do_train: true -evaluation_strategy: steps +eval_strategy: steps eval_steps: 1000 gradient_accumulation_steps: 1 gradient_checkpointing: true diff --git a/recipes/constitutional-ai/sft/config_anthropic.yaml b/recipes/constitutional-ai/sft/config_anthropic.yaml index 6414528..6724de0 100644 --- a/recipes/constitutional-ai/sft/config_anthropic.yaml +++ b/recipes/constitutional-ai/sft/config_anthropic.yaml @@ -2,7 +2,7 @@ model_name_or_path: mistralai/Mistral-7B-v0.1 model_revision: main torch_dtype: bfloat16 -use_flash_attention_2: true +attn_implementation: flash_attention_2 # Data training arguments chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" @@ -18,7 +18,7 @@ preprocessing_num_workers: 12 bf16: true do_eval: true do_train: true -evaluation_strategy: epoch # One of ["no", "steps", "epoch"] +eval_strategy: epoch # One of ["no", "steps", "epoch"] gradient_accumulation_steps: 4 gradient_checkpointing: true gradient_checkpointing_kwargs: diff --git a/recipes/constitutional-ai/sft/config_grok.yaml b/recipes/constitutional-ai/sft/config_grok.yaml index 6740ac1..c79031d 100644 --- a/recipes/constitutional-ai/sft/config_grok.yaml +++ b/recipes/constitutional-ai/sft/config_grok.yaml @@ -2,7 +2,7 @@ model_name_or_path: mistralai/Mistral-7B-v0.1 model_revision: main torch_dtype: bfloat16 -use_flash_attention_2: true +attn_implementation: flash_attention_2 # Data training arguments chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" @@ -18,7 +18,7 @@ preprocessing_num_workers: 12 bf16: true do_eval: true do_train: true -evaluation_strategy: epoch # One of ["no", "steps", "epoch"] +eval_strategy: epoch # One of ["no", "steps", "epoch"] gradient_accumulation_steps: 4 gradient_checkpointing: true gradient_checkpointing_kwargs: diff --git a/recipes/gpt2-nl/README.md b/recipes/gpt2-nl/README.md index 366ae92..68eccfc 100644 --- a/recipes/gpt2-nl/README.md +++ b/recipes/gpt2-nl/README.md @@ -2,7 +2,7 @@ This directory shows a base example of how to use continued pretraining and further tuning to adapt a language model to new data (e.g. a new language or domain). -Three steps are needed: continued pretraining (`cpt`), supervised finetuning (`sft`), and direct preference optimisation (`dpo`). In this dummy example we'll continue pretraining gpt2 on Dutch raw data, then sft-tuning it, and finally aligning it with DPO. Note that no extensive hyperparameters were tested in this example and that the output models are bad - it is just to show you how you can use the scripts for LM adaptation. The scripts work on 4x 3090s (24GB VRAM). If you have less powerful hardware you may need to reduce the batch size. +Three steps are needed: continued pretraining (`cpt`), supervised finetuning (`sft`), and direct preference optimisation (`dpo`). In this dummy example, we'll continue pretraining gpt2 on Dutch raw data, then sft-tuning it, and finally aligning it with DPO. Note that no extensive hyperparameters were tested in this example and that the output models are bad - it is just to show you how you can use the scripts for LM adaptation. The scripts work on 4x 3090s (24GB VRAM). If you have less powerful hardware you may need to reduce the batch size. ## Continued pretraining @@ -18,7 +18,7 @@ ACCELERATE_LOG_LEVEL=info accelerate launch \ ## Supervised finetuning -As other recipes, such as the famous zephyr-7b-beta recipe, have shown, we can then teach our model how to hold a conversation by finetuning it on chat-formatted data. As a base model we'll make use of the output of the previous step. +As other recipes, such as the famous zephyr-7b-beta recipe, have shown, we can then teach our model how to hold a conversation by finetuning it on chat-formatted data. As a base model, we'll make use of the output of the previous step. ```shell ACCELERATE_LOG_LEVEL=info accelerate launch \ diff --git a/recipes/gpt2-nl/cpt/config_full.yaml b/recipes/gpt2-nl/cpt/config_full.yaml index 69d5437..9c7056c 100644 --- a/recipes/gpt2-nl/cpt/config_full.yaml +++ b/recipes/gpt2-nl/cpt/config_full.yaml @@ -15,7 +15,7 @@ preprocessing_num_workers: 12 # SFT trainer config bf16: true do_eval: False -evaluation_strategy: "no" +eval_strategy: "no" gradient_accumulation_steps: 1 gradient_checkpointing: true gradient_checkpointing_kwargs: diff --git a/recipes/gpt2-nl/dpo/config_full.yaml b/recipes/gpt2-nl/dpo/config_full.yaml index a2552f3..976c253 100644 --- a/recipes/gpt2-nl/dpo/config_full.yaml +++ b/recipes/gpt2-nl/dpo/config_full.yaml @@ -16,7 +16,7 @@ preprocessing_num_workers: 12 bf16: true beta: 0.1 do_eval: true -evaluation_strategy: steps +eval_strategy: steps eval_steps: 100 gradient_accumulation_steps: 8 gradient_checkpointing: true diff --git a/recipes/gpt2-nl/sft/config_full.yaml b/recipes/gpt2-nl/sft/config_full.yaml index fef3d5e..f80d8ef 100644 --- a/recipes/gpt2-nl/sft/config_full.yaml +++ b/recipes/gpt2-nl/sft/config_full.yaml @@ -15,7 +15,7 @@ preprocessing_num_workers: 12 # SFT trainer config bf16: true do_eval: true -evaluation_strategy: epoch +eval_strategy: epoch gradient_accumulation_steps: 1 gradient_checkpointing: true gradient_checkpointing_kwargs: diff --git a/recipes/pref_align_scan/README.md b/recipes/pref_align_scan/README.md index 767a742..f9c81a5 100644 --- a/recipes/pref_align_scan/README.md +++ b/recipes/pref_align_scan/README.md @@ -5,13 +5,14 @@ This directory contains various comparisons for three algorithms: DPO, IPO, and - OpenHermes-2.5 and the OpenOrca datasets We release a collection containing the datasets and models used for these experiments, if you require the other trained models, we can release them on request. -You can find a longer decription of there results in our [blogpost](https://huggingface.co/blog/pref-tuning) +You can find a longer description of these results in our [blogpost](https://huggingface.co/blog/pref-tuning) + ## Comparisons For each algorithm, we aim to tune the beta parameter for a fixed learning rate. We vary beta from 0.1-0.9 in steps of 0.1, we have also found that in certain configurations a tiny value of beta, 0.01, can be effective. So we have included this smaller value in all our comparisons. ## Usage The experiments can be launched with the following bash script: -``` +```bash #!/bin/bash # Define an array containing the base configs we wish to fine tune diff --git a/recipes/pref_align_scan/dpo/config_openhermes.yaml b/recipes/pref_align_scan/dpo/config_openhermes.yaml index 93d9ef3..43e8a23 100644 --- a/recipes/pref_align_scan/dpo/config_openhermes.yaml +++ b/recipes/pref_align_scan/dpo/config_openhermes.yaml @@ -16,7 +16,7 @@ beta: 0.01 loss_type: sigmoid do_eval: true do_train: true -evaluation_strategy: steps +eval_strategy: steps eval_steps: 100 gradient_accumulation_steps: 2 gradient_checkpointing: true diff --git a/recipes/pref_align_scan/dpo/config_zephyr.yaml b/recipes/pref_align_scan/dpo/config_zephyr.yaml index 01899bd..0dd6d37 100644 --- a/recipes/pref_align_scan/dpo/config_zephyr.yaml +++ b/recipes/pref_align_scan/dpo/config_zephyr.yaml @@ -15,7 +15,7 @@ bf16: true beta: 0.01 loss_type: sigmoid do_eval: true -evaluation_strategy: steps +eval_strategy: steps eval_steps: 100 gradient_accumulation_steps: 2 gradient_checkpointing: true diff --git a/recipes/starchat2-15b/dpo/config_v0.1.yaml b/recipes/starchat2-15b/dpo/config_v0.1.yaml index d53c812..cf0ddb3 100644 --- a/recipes/starchat2-15b/dpo/config_v0.1.yaml +++ b/recipes/starchat2-15b/dpo/config_v0.1.yaml @@ -16,7 +16,7 @@ preprocessing_num_workers: 12 bf16: true beta: 0.05 do_eval: true -evaluation_strategy: steps +eval_strategy: steps eval_steps: 100 gradient_accumulation_steps: 8 gradient_checkpointing: true diff --git a/recipes/starchat2-15b/sft/config_v0.1.yaml b/recipes/starchat2-15b/sft/config_v0.1.yaml index bd65890..f5892de 100644 --- a/recipes/starchat2-15b/sft/config_v0.1.yaml +++ b/recipes/starchat2-15b/sft/config_v0.1.yaml @@ -2,7 +2,7 @@ model_name_or_path: bigcode/starcoder2-15b model_revision: main torch_dtype: bfloat16 -use_flash_attention_2: true +attn_implementation: flash_attention_2 # Data training arguments chat_template: "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" @@ -20,7 +20,7 @@ preprocessing_num_workers: 24 # SFT trainer config bf16: true do_eval: true -evaluation_strategy: epoch +eval_strategy: epoch gradient_accumulation_steps: 2 gradient_checkpointing: true gradient_checkpointing_kwargs: diff --git a/recipes/zephyr-141b-A35b/orpo/config_full.yaml b/recipes/zephyr-141b-A35b/orpo/config_full.yaml index 57ae439..b521013 100644 --- a/recipes/zephyr-141b-A35b/orpo/config_full.yaml +++ b/recipes/zephyr-141b-A35b/orpo/config_full.yaml @@ -2,7 +2,7 @@ model_name_or_path: mistral-community/Mixtral-8x22B-v0.1 model_revision: main torch_dtype: bfloat16 -use_flash_attention_2: true +attn_implementation: flash_attention_2 # Data training arguments chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" diff --git a/recipes/zephyr-7b-beta/README.md b/recipes/zephyr-7b-beta/README.md index d27de43..8c082f1 100644 --- a/recipes/zephyr-7b-beta/README.md +++ b/recipes/zephyr-7b-beta/README.md @@ -4,9 +4,9 @@ As described in the Zephyr [technical report](https://huggingface.co/papers/2310.16944), training this model proceeds in two steps: 1. Apply SFT to fine-tune Mistral 7B on a filtered version of the UltraChat dataset ([link](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)). The result is an SFT model like [`zephyr-7b-sft-full`](https://huggingface.co/alignment-handbook/zephyr-7b-sft-full) or [`zephyr-7b-sft-qlora`](https://huggingface.co/alignment-handbook/zephyr-7b-sft-qlora). -2. Align the SFT model to AI feedback via DPO on a preprocessed version of the UltraFeedback dataset ([link](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)). The result is an DPO model like [`zephyr-7b-dpo-full`](https://huggingface.co/alignment-handbook/zephyr-7b-dpo-full) or [`zephyr-7b-dpo-qlora`](https://huggingface.co/alignment-handbook/zephyr-7b-dpo-qlora). +2. Align the SFT model to AI feedback via DPO on a preprocessed version of the UltraFeedback dataset ([link](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)). The result is a DPO model like [`zephyr-7b-dpo-full`](https://huggingface.co/alignment-handbook/zephyr-7b-dpo-full) or [`zephyr-7b-dpo-qlora`](https://huggingface.co/alignment-handbook/zephyr-7b-dpo-qlora). -**Note:** after the release of Zephyr, the team at [Argilla](https://argilla.io) found that the source UltraFeedback dataset had a few thousand incorrect preference labels from GPT-4. Additionally, TRL's `SFTTrainer` had a bug in the learning rate scheduler which terminated training early. Accounting for these changes led us to find a better set of hyperparameters from those described in the technical report. In particular, for DPO training we found that training for 1 epoch with `beta=0.01` was suffucient to achieve comparable performance to `zephyr-7b-beta` (vs. 3 epochs with `beta=0.1`). +**Note:** after the release of Zephyr, the team at [Argilla](https://argilla.io) found that the source UltraFeedback dataset had a few thousand incorrect preference labels from GPT-4. Additionally, TRL's `SFTTrainer` had a bug in the learning rate scheduler which terminated training early. Accounting for these changes led us to find a better set of hyperparameters from those described in the technical report. In particular, for DPO training we found that training for 1 epoch with `beta=0.01` was sufficient to achieve comparable performance to `zephyr-7b-beta` (vs. 3 epochs with `beta=0.1`). See below for commands to train these models using either DeepSpeed ZeRO-3 or LoRA. @@ -34,11 +34,11 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con P.S. Using Flash Attention also allows you to drastically increase the batch size (x2 in my case) -Train without flash-attention: +Train without flash-attention (i.e. via PyTorch's scaled dot product attention): ```````shell # Step 1 - SFT -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_sft.py recipes/zephyr-7b-beta/sft/config_qlora.yaml --load_in_4bit=true --use_flash_attention_2=false +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_sft.py recipes/zephyr-7b-beta/sft/config_qlora.yaml --load_in_4bit=true --attn_implementation=sdpa # Step 2 - DPO -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_dpo.py recipes/zephyr-7b-beta/dpo/config_qlora.yaml --use_flash_attention_2=false +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_dpo.py recipes/zephyr-7b-beta/dpo/config_qlora.yaml --attn_implementation=sdpa ``````` \ No newline at end of file diff --git a/recipes/zephyr-7b-beta/dpo/config_full.yaml b/recipes/zephyr-7b-beta/dpo/config_full.yaml index 9ea336b..12b47b1 100644 --- a/recipes/zephyr-7b-beta/dpo/config_full.yaml +++ b/recipes/zephyr-7b-beta/dpo/config_full.yaml @@ -15,7 +15,7 @@ preprocessing_num_workers: 12 bf16: true beta: 0.01 do_eval: true -evaluation_strategy: steps +eval_strategy: steps eval_steps: 100 gradient_accumulation_steps: 2 gradient_checkpointing: true diff --git a/recipes/zephyr-7b-beta/dpo/config_qlora.yaml b/recipes/zephyr-7b-beta/dpo/config_qlora.yaml index 7774255..46fbccd 100644 --- a/recipes/zephyr-7b-beta/dpo/config_qlora.yaml +++ b/recipes/zephyr-7b-beta/dpo/config_qlora.yaml @@ -1,7 +1,7 @@ # Model arguments model_name_or_path: alignment-handbook/zephyr-7b-sft-qlora torch_dtype: bfloat16 -use_flash_attention_2: true +attn_implementation: flash_attention_2 # LoRA arguments use_peft: true @@ -31,7 +31,7 @@ preprocessing_num_workers: 12 bf16: true beta: 0.01 do_eval: true -evaluation_strategy: steps +eval_strategy: steps eval_steps: 100 gradient_accumulation_steps: 4 gradient_checkpointing: true diff --git a/recipes/zephyr-7b-beta/sft/config_full.yaml b/recipes/zephyr-7b-beta/sft/config_full.yaml index f5eb440..f1e8457 100644 --- a/recipes/zephyr-7b-beta/sft/config_full.yaml +++ b/recipes/zephyr-7b-beta/sft/config_full.yaml @@ -2,7 +2,7 @@ model_name_or_path: mistralai/Mistral-7B-v0.1 model_revision: main torch_dtype: bfloat16 -use_flash_attention_2: true +attn_implementation: flash_attention_2 # Data training arguments chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" @@ -16,7 +16,7 @@ preprocessing_num_workers: 12 # SFT trainer config bf16: true do_eval: true -evaluation_strategy: epoch +eval_strategy: epoch gradient_accumulation_steps: 1 gradient_checkpointing: true gradient_checkpointing_kwargs: diff --git a/recipes/zephyr-7b-beta/sft/config_qlora.yaml b/recipes/zephyr-7b-beta/sft/config_qlora.yaml index 1984083..4881757 100644 --- a/recipes/zephyr-7b-beta/sft/config_qlora.yaml +++ b/recipes/zephyr-7b-beta/sft/config_qlora.yaml @@ -2,7 +2,7 @@ model_name_or_path: mistralai/Mistral-7B-v0.1 model_revision: main torch_dtype: bfloat16 -use_flash_attention_2: true +attn_implementation: flash_attention_2 # LoRA arguments load_in_4bit: true @@ -31,7 +31,7 @@ preprocessing_num_workers: 12 # SFT trainer config bf16: true do_eval: true -evaluation_strategy: epoch +eval_strategy: epoch gradient_accumulation_steps: 2 gradient_checkpointing: true gradient_checkpointing_kwargs: diff --git a/recipes/zephyr-7b-gemma/dpo/config_full.yaml b/recipes/zephyr-7b-gemma/dpo/config_full.yaml index d643b94..f17ac68 100644 --- a/recipes/zephyr-7b-gemma/dpo/config_full.yaml +++ b/recipes/zephyr-7b-gemma/dpo/config_full.yaml @@ -15,7 +15,7 @@ preprocessing_num_workers: 12 bf16: true beta: 0.05 do_eval: true -evaluation_strategy: steps +eval_strategy: steps eval_steps: 100 gradient_accumulation_steps: 8 gradient_checkpointing: true diff --git a/recipes/zephyr-7b-gemma/sft/config_full.yaml b/recipes/zephyr-7b-gemma/sft/config_full.yaml index a28f0e4..03226ab 100644 --- a/recipes/zephyr-7b-gemma/sft/config_full.yaml +++ b/recipes/zephyr-7b-gemma/sft/config_full.yaml @@ -3,7 +3,7 @@ model_name_or_path: google/gemma-7b model_revision: main tokenizer_name_or_path: philschmid/gemma-tokenizer-chatml # Custom tokenizer with <|im_start|> and <|im_end|> tokens torch_dtype: bfloat16 -use_flash_attention_2: true +attn_implementation: flash_attention_2 # Data training arguments dataset_mixer: @@ -19,7 +19,7 @@ dataset_kwargs: add_special_tokens: false # We already wrap and in the chat template append_concat_token: false # No need to add across samples do_eval: true -evaluation_strategy: epoch +eval_strategy: epoch gradient_accumulation_steps: 4 gradient_checkpointing: true gradient_checkpointing_kwargs: diff --git a/scripts/README.md b/scripts/README.md index 3860e41..1613d8c 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -28,7 +28,7 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/fsdp+qlora.yaml --num_processes={num_gpus} scripts/run_{task}.py recipes/{model_name}/{task}/config_qlora.yaml --torch_dtype=bfloat16 --bnb_4bit_quant_storage=bfloat16 ``` -Here `{task}` refers to the type of training you wish to run. Currently the following tasks are supported: +Here `{task}` refers to the type of training you wish to run. Currently, the following tasks are supported: * continued pretraining `cpt` (note that `cpt` is only present in the `gpt-nl` example recipe) * supervised finetuning `sft` * direct preference optimisation `dpo` @@ -54,8 +54,7 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con ``` ## Logging with Weights and Biases - -By default all training metrics are logged with TensorBoard. If you have a [Weights and Biases](https://wandb.ai/site) account and are logged in, you can view the training metrics by appending `--report_to=wandb`, e.g. +By default, all training metrics are logged with TensorBoard. If you have a [Weights and Biases](https://wandb.ai/site) account and are logged in, you can view the training metrics by appending `--report_to=wandb`, e.g. ```shell ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_{task}.py recipes/{model_name}/{task}/config_full.yaml --report_to=wandb @@ -120,7 +119,7 @@ If you format your dataset in the same way, our training scripts should work out We recommend benchmarking chat models on: * [MT-Bench](https://huggingface.co/spaces/lmsys/mt-bench): a multi-turn benchmark spanning 80 dialogues and 10 domains. -* [AlpacaEval](https://github.com/tatsu-lab/alpaca_eval): a single-turn benchmark which evaluates the helpfulness of chat and instruct models against `text-davinci-003`. +* [AlpacaEval](https://github.com/tatsu-lab/alpaca_eval): a single-turn benchmark that evaluates the helpfulness of chat and instruct models against `text-davinci-003`. For both benchmarks, we have added support for the [Zephyr chat template](https://huggingface.co/alignment-handbook/zephyr-7b-sft-full/blob/ac6e600eefcce74f5e8bae1035d4f66019e93190/tokenizer_config.json#L30) (which is the default produced by our scripts), so you can evaluate models produced by our scripts as follows: @@ -137,6 +136,6 @@ For both benchmarks, we have added support for the [Zephyr chat template](https: * Next, update the [config name](https://github.com/tatsu-lab/alpaca_eval/blob/2daa6e11b194653043ca74f735728dc068e04aae/src/alpaca_eval/models_configs/zephyr-7b-beta/configs.yaml#L1) and [Hub model ID](https://github.com/tatsu-lab/alpaca_eval/blob/2daa6e11b194653043ca74f735728dc068e04aae/src/alpaca_eval/models_configs/zephyr-7b-beta/configs.yaml#L5) to match your model name. * Follow the steps to evaluate your model [here](https://github.com/tatsu-lab/alpaca_eval/tree/main#evaluating-a-model). -Note that MT-Bench and AlpacaEval rely on LLMs like GPT-4 to judge the quality of the model responses, and thus the ranking exhibit various biases including a preference for models distilled from GPTs. For that reason, we also recommend submitting your best models for human evaluation in: +Note that MT-Bench and AlpacaEval rely on LLMs like GPT-4 to judge the quality of the model responses, and thus the ranking exhibits various biases including a preference for models distilled from GPTs. For that reason, we also recommend submitting your best models for human evaluation in: * [Chatbot Arena](https://chat.lmsys.org): a live, human evaluation of chat models in head-to-head comparisons. diff --git a/scripts/run_cpt.py b/scripts/run_cpt.py index 4e104c4..5553c15 100644 --- a/scripts/run_cpt.py +++ b/scripts/run_cpt.py @@ -127,7 +127,7 @@ def main(): model_kwargs = dict( revision=model_args.model_revision, trust_remote_code=model_args.trust_remote_code, - use_flash_attention_2=model_args.use_flash_attention_2, + attn_implementation=model_args.attn_implementation, torch_dtype=torch_dtype, use_cache=False if training_args.gradient_checkpointing else True, device_map=get_kbit_device_map() if quantization_config is not None else None, @@ -200,7 +200,7 @@ def main(): if training_args.push_to_hub is True: logger.info("Pushing to hub...") - trainer.push_to_hub(**kwargs) + trainer.push_to_hub(revision=training_args.hub_model_revision, **kwargs) logger.info("*** Training complete ***") diff --git a/scripts/run_dpo.py b/scripts/run_dpo.py index e500a1a..8c944af 100644 --- a/scripts/run_dpo.py +++ b/scripts/run_dpo.py @@ -146,7 +146,7 @@ def main(): model_kwargs = dict( revision=model_args.model_revision, trust_remote_code=model_args.trust_remote_code, - use_flash_attention_2=model_args.use_flash_attention_2, + attn_implementation=model_args.attn_implementation, torch_dtype=torch_dtype, use_cache=False if training_args.gradient_checkpointing else True, device_map=get_kbit_device_map() if quantization_config is not None else None, @@ -160,7 +160,7 @@ def main(): model_kwargs = dict( revision=model_args.base_model_revision, trust_remote_code=model_args.trust_remote_code, - use_flash_attention_2=model_args.use_flash_attention_2, + attn_implementation=model_args.attn_implementation, torch_dtype=torch_dtype, use_cache=False if training_args.gradient_checkpointing else True, device_map=get_kbit_device_map() if quantization_config is not None else None, @@ -252,7 +252,7 @@ def main(): if training_args.push_to_hub is True: logger.info("Pushing to hub...") - trainer.push_to_hub(**kwargs) + trainer.push_to_hub(revision=training_args.hub_model_revision, **kwargs) logger.info("*** Training complete! ***") diff --git a/scripts/run_orpo.py b/scripts/run_orpo.py index 08b3cc6..ce864d3 100644 --- a/scripts/run_orpo.py +++ b/scripts/run_orpo.py @@ -36,8 +36,7 @@ from alignment import ( get_quantization_config, get_tokenizer, ) -from alignment.configs import ORPOConfig -from trl import ORPOTrainer, setup_chat_format +from trl import ORPOConfig, ORPOTrainer, setup_chat_format logger = logging.getLogger(__name__) @@ -107,7 +106,7 @@ def main(): model_args.model_name_or_path, revision=model_args.model_revision, trust_remote_code=model_args.trust_remote_code, - use_flash_attention_2=model_args.use_flash_attention_2, + attn_implementation=model_args.attn_implementation, torch_dtype=torch_dtype, use_cache=False if training_args.gradient_checkpointing else True, device_map=get_kbit_device_map() if quantization_config is not None else None, diff --git a/scripts/run_sft.py b/scripts/run_sft.py index 522ce86..62c8493 100644 --- a/scripts/run_sft.py +++ b/scripts/run_sft.py @@ -113,7 +113,7 @@ def main(): model_kwargs = dict( revision=model_args.model_revision, trust_remote_code=model_args.trust_remote_code, - use_flash_attention_2=model_args.use_flash_attention_2, + attn_implementation=model_args.attn_implementation, torch_dtype=torch_dtype, use_cache=False if training_args.gradient_checkpointing else True, device_map=get_kbit_device_map() if quantization_config is not None else None, @@ -224,7 +224,7 @@ def main(): if training_args.push_to_hub is True: logger.info("Pushing to hub...") - trainer.push_to_hub(**kwargs) + trainer.push_to_hub(revision=training_args.hub_model_revision, **kwargs) logger.info("*** Training complete ***") diff --git a/setup.py b/setup.py index 5ae2312..73214bf 100644 --- a/setup.py +++ b/setup.py @@ -43,9 +43,9 @@ if stale_egg_info.exists(): _deps = [ "accelerate>=0.29.2", "bitsandbytes>=0.43.0", - "black==23.1.0", + "black>=24.4.2", "datasets>=2.18.0", - "deepspeed==0.12.2", + "deepspeed>=0.14.4", "einops>=0.6.1", "evaluate==0.4.0", "flake8>=6.0.0", @@ -64,9 +64,9 @@ _deps = [ "sentencepiece>=0.1.99", "scipy", "tensorboard", - "torch==2.1.2", + "torch>=2.1.2", "transformers>=4.39.3", - "trl>=0.8.2", + "trl>=0.9.6", "jinja2>=3.0.0", "tqdm>=4.64.1", ] diff --git a/src/alignment/configs.py b/src/alignment/configs.py index 208be0e..aff0792 100644 --- a/src/alignment/configs.py +++ b/src/alignment/configs.py @@ -18,9 +18,10 @@ import sys from dataclasses import dataclass, field from typing import Any, Dict, List, NewType, Optional, Tuple -import transformers from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, HfArgumentParser +import trl + MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) @@ -70,7 +71,7 @@ class H4ArgumentParser(HfArgumentParser): inputs[arg] = [str(v) for v in val.split(",")] # bool of a non-empty string is True, so we manually check for bools - if base_type == bool: + if base_type is bool: if val in ["true", "True"]: inputs[arg] = True else: @@ -146,11 +147,11 @@ class ModelArguments: }, ) trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."}) - use_flash_attention_2: bool = field( - default=False, + attn_implementation: Optional[str] = field( + default=None, metadata={ "help": ( - "Whether to use flash attention 2. You must install this manually by running `pip install flash-attn --no-build-isolation`" + "Which attention implementation to use; you can use --attn_implementation=flash_attention_2, in which case you must install this manually by running `pip install flash-attn --no-build-isolation`" ) }, ) @@ -186,7 +187,8 @@ class ModelArguments: ) use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"}) bnb_4bit_quant_storage: Optional[str] = field( - default="uint8", metadata={"help": "storage type to pack the quanitzed 4-bit prarams."} + default="uint8", + metadata={"help": "storage type to pack the quanitzed 4-bit prarams."}, ) def __post_init__(self): @@ -235,36 +237,12 @@ class DataArguments: @dataclass -class SFTConfig(transformers.TrainingArguments): +class SFTConfig(trl.SFTConfig): """ - Arguments related to the training process itself. For all parameters, see: https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/trainer#transformers.TrainingArguments + Arguments related to the training process itself. For all parameters, see: https://huggingface.co/docs/transformers/v4.39.3/en/main_classes/trainer#transformers.TrainingArguments Also used for the continued pretraining task. """ - dataset_kwargs: Optional[Dict[str, Any]] = field( - default=None, metadata={"help": "Dataset kwargs for the SFTTrainer"} - ) - max_seq_length: Optional[int] = field( - default=None, - metadata={"help": ("Used by TRL for reward model training, which tries to read this parameter in init.")}, - ) - logging_first_step: bool = field( - default=True, - metadata={"help": ("Whether to log and evaluate the first global_step or not.")}, - ) - optim: Optional[str] = field(default="adamw_torch") - - -@dataclass -class DPOConfig(transformers.TrainingArguments): - """ - Arguments related to the DPO training process itself. For all parameters, see: https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/trainer#transformers.TrainingArguments - """ - - beta: Optional[float] = field( - default=0.1, - metadata={"help": "The beta factor in DPO loss. Higher beta means less divergence from the initial policy."}, - ) hub_model_revision: Optional[str] = field( default="main", metadata={"help": ("The Hub model branch to push the model to.")}, @@ -273,73 +251,21 @@ class DPOConfig(transformers.TrainingArguments): default=True, metadata={"help": ("Whether to log and evaluate the first global_step or not.")}, ) - max_prompt_length: Optional[int] = field( - default=None, - metadata={"help": ("For DPO, the maximum length of the prompt to use for conditioning the model.")}, - ) - max_length: Optional[int] = field( - default=None, - metadata={"help": ("Used by TRL for reward model training, which tries to read this parameter in init.")}, - ) - optim: Optional[str] = field(default="rmsprop") - remove_unused_columns: bool = field(default=False) - loss_type: Optional[str] = field(default="sigmoid", metadata={"help": ("The loss type for DPO.")}) @dataclass -class ORPOConfig(transformers.TrainingArguments): - max_length: Optional[int] = field( - default=None, - metadata={"help": "The maximum length of the sequences in the batch."}, - ) - max_prompt_length: Optional[int] = field( - default=None, - metadata={"help": "The maximum length of the prompt."}, - ) - max_completion_length: Optional[int] = field( - default=None, - metadata={"help": "The maximum length of the completions."}, - ) +class DPOConfig(trl.DPOConfig): + """ + Arguments related to the DPO training process itself. For all parameters, see: https://huggingface.co/docs/transformers/v4.39.3/en/main_classes/trainer#transformers.TrainingArguments + """ - beta: float = field( - default=0.1, - metadata={ - "help": "The beta factor in ORPO loss (lambda/alpha in paper/code) that is the weight of the relative loss ratio in the SFT loss." - }, + hub_model_revision: Optional[str] = field( + default="main", + metadata={"help": ("The Hub model branch to push the model to.")}, ) - disable_dropout: bool = field( + logging_first_step: bool = field( default=True, - metadata={"help": "Whether or not to disable dropouts in `model`."}, - ) - - label_pad_token_id: int = field( - default=-100, - metadata={"help": "The label pad token id."}, - ) - padding_value: Optional[int] = field( - default=None, - metadata={"help": "The padding value if it is different to the tokenizer's pad_token_id."}, - ) - truncation_mode: str = field( - default="keep_end", - metadata={"help": "The truncation mode to use, either `keep_end` or `keep_start`."}, - ) - - generate_during_eval: bool = field( - default=False, - metadata={"help": "Whether to sample and log generations during evaluation step."}, - ) - is_encoder_decoder: Optional[bool] = field( - default=None, - metadata={"help": ("If no model is provided, we need to know if the model_init returns an encoder-decoder.")}, - ) - - model_init_kwargs: Optional[Dict] = field( - default=None, - metadata={"help": ("Dict of Optional kwargs to pass when instantiating the model from a string")}, - ) - - dataset_num_proc: Optional[int] = field( - default=None, - metadata={"help": ("The number of workers to use to tokenize the data.")}, + metadata={"help": ("Whether to log and evaluate the first global_step or not.")}, ) + optim: Optional[str] = field(default="rmsprop") + remove_unused_columns: bool = field(default=False) diff --git a/src/alignment/model_utils.py b/src/alignment/model_utils.py index fe1ecad..122da3c 100644 --- a/src/alignment/model_utils.py +++ b/src/alignment/model_utils.py @@ -68,9 +68,11 @@ def get_tokenizer( ) -> PreTrainedTokenizer: """Get the tokenizer for the model.""" tokenizer = AutoTokenizer.from_pretrained( - model_args.model_name_or_path - if model_args.tokenizer_name_or_path is None - else model_args.tokenizer_name_or_path, + ( + model_args.model_name_or_path + if model_args.tokenizer_name_or_path is None + else model_args.tokenizer_name_or_path + ), revision=model_args.model_revision, trust_remote_code=model_args.trust_remote_code, ) diff --git a/tests/fixtures/config_dpo_full.yaml b/tests/fixtures/config_dpo_full.yaml index 5110f59..9ed1387 100644 --- a/tests/fixtures/config_dpo_full.yaml +++ b/tests/fixtures/config_dpo_full.yaml @@ -14,7 +14,7 @@ preprocessing_num_workers: 12 bf16: true beta: 0.1 do_eval: true -evaluation_strategy: steps +eval_strategy: steps eval_steps: 100 gradient_accumulation_steps: 1 gradient_checkpointing: true diff --git a/tests/fixtures/config_sft_full.yaml b/tests/fixtures/config_sft_full.yaml index adf13da..297dc06 100644 --- a/tests/fixtures/config_sft_full.yaml +++ b/tests/fixtures/config_sft_full.yaml @@ -2,7 +2,7 @@ model_name_or_path: mistralai/Mistral-7B-v0.1 model_revision: main torch_dtype: bfloat16 -use_flash_attention_2: true +attn_implementation: flash_attention_2 # Data training arguments dataset_mixer: @@ -15,7 +15,7 @@ preprocessing_num_workers: 12 # SFT trainer config bf16: true do_eval: true -evaluation_strategy: epoch +eval_strategy: epoch gradient_accumulation_steps: 2 gradient_checkpointing: true hub_model_id: zephyr-7b-sft-full diff --git a/tests/test_data.py b/tests/test_data.py index f2d73ee..28483c3 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -124,7 +124,7 @@ class ApplyChatTemplateTest(unittest.TestCase): def test_maybe_insert_system_message(self): # does not accept system prompt mistral_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2") - # accepts system prompt. use codellama since it has no HF token reqiurement + # accepts system prompt. use codellama since it has no HF token requirement llama_tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-hf") messages_sys_excl = [{"role": "user", "content": "Tell me a joke."}] messages_sys_incl = [{"role": "system", "content": ""}, {"role": "user", "content": "Tell me a joke."}]