From f0ffa0d7a6ab666b1f80f3f7dbb3c6364ac31967 Mon Sep 17 00:00:00 2001 From: lewtun Date: Wed, 10 Jan 2024 17:42:24 +1100 Subject: [PATCH] Update Zephyr configs to account for UltraFeedback & TRL fixes (#88) * Add files * Add checkpointing * Add checkpointing to SFT * Add loss type * Fix setup| * Clean SFT * Add lora config * Rename config * Remove max eval samples * Add kwargs tp push to hub * Add DPO configs * Fix dpo configs * Extend chat template test to multi-turn * Add warmup * Refactor * Fix LoRA -> QLoRA * Fix configs * Specify chat template * Add sample logging * Fix push to hub hanging * Add reentrant * Fix quality * Add transformer logging * Tweak grad acc * Add null type * Add doc --- README.md | 10 +-- recipes/launch.slurm | 13 +-- recipes/zephyr-7b-beta/README.md | 13 +-- recipes/zephyr-7b-beta/dpo/config_full.yaml | 20 +++-- recipes/zephyr-7b-beta/dpo/config_lora.yaml | 51 ------------ recipes/zephyr-7b-beta/dpo/config_qlora.yaml | 56 +++++++++++++ recipes/zephyr-7b-beta/sft/config_full.yaml | 16 ++-- .../{config_lora.yaml => config_qlora.yaml} | 28 ++++--- scripts/README.md | 16 ++-- scripts/run_dpo.py | 82 +++++++++++-------- scripts/run_sft.py | 38 +++++---- setup.py | 9 +- src/alignment/__init__.py | 9 +- src/alignment/configs.py | 1 + src/alignment/data.py | 32 +++----- src/alignment/model_utils.py | 11 ++- tests/test_data.py | 48 ++++++++--- 17 files changed, 266 insertions(+), 187 deletions(-) delete mode 100644 recipes/zephyr-7b-beta/dpo/config_lora.yaml create mode 100644 recipes/zephyr-7b-beta/dpo/config_qlora.yaml rename recipes/zephyr-7b-beta/sft/{config_lora.yaml => config_qlora.yaml} (50%) diff --git a/README.md b/README.md index 5a5d4e0..ee0c6a5 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ To run the code in this project, first, create a Python virtual environment usin conda create -n handbook python=3.10 && conda activate handbook ``` -Next, install PyTorch `v2.1.0` - the precise version is important for reproducibility! 
Since this is hardware-dependent, we +Next, install PyTorch `v2.1.2` - the precise version is important for reproducibility! Since this is hardware-dependent, we direct you to the [PyTorch Installation Page](https://pytorch.org/get-started/locally/). You can then install the remaining package dependencies as follows: @@ -71,13 +71,13 @@ python -m pip install . You will also need Flash Attention 2 installed, which can be done by running: -> **Note** -> If your machine has less than 96GB of RAM and many CPU cores, reduce the MAX_JOBS., e.g. `MAX_JOBS=4 pip install flash-attn --no-build-isolation` - ```shell -python -m pip install flash-attn --no-build-isolation +python -m pip install flash-attn==2.3.6 --no-build-isolation ``` +> **Note** +> If your machine has less than 96GB of RAM and many CPU cores, reduce the `MAX_JOBS` argument, e.g. `MAX_JOBS=4 pip install flash-attn==2.3.6 --no-build-isolation` + Next, log into your Hugging Face account as follows: ```shell diff --git a/recipes/launch.slurm b/recipes/launch.slurm index a5f4359..d90fdae 100644 --- a/recipes/launch.slurm +++ b/recipes/launch.slurm @@ -2,7 +2,7 @@ #SBATCH --ntasks-per-node=1 #SBATCH --exclusive #SBATCH --gres=gpu:8 -#SBATCH --partition=production-cluster # Adjust this for your cluster +#SBATCH --partition=hopper-prod # Adjust this for your cluster #SBATCH --output=/fsx/h4/logs/%x-%j.out # Adjust this for your cluster #SBATCH --err=/fsx/h4/logs/%x-%j.err # Adjust this for your cluster @@ -47,7 +47,7 @@ export CMD=" \ scripts/run_$TASK.py $CONFIG_FILE $OPTIONAL_ARGS " -export LAUNCHER="ACCELERATE_LOG_LEVEL=info accelerate launch \ +export LAUNCHER="HF_HUB_ENABLE_HF_TRANSFER=1 ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch \ --config_file recipes/accelerate_configs/$ACCELERATOR.yaml \ --gradient_accumulation_steps $GRAD_ACC_STEPS \ --num_machines $NUM_NODES \ @@ -71,14 +71,7 @@ export NCCL_ASYNC_ERROR_HANDLING=1 # Specific configuration optimized for the Hugging Face Compute 
Cluster # Be ye warned this may not work on other clusters! -export NCCL_PROTO=simple -export RDMAV_FORK_SAFE=1 -export FI_EFA_FORK_SAFE=1 -export FI_EFA_USE_DEVICE_RDMA=1 -export FI_PROVIDER=efa -export FI_LOG_LEVEL=1 -export NCCL_IB_DISABLE=1 -export NCCL_SOCKET_IFNAME=ens +module load cuda/12.1 # srun error handling: # --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks diff --git a/recipes/zephyr-7b-beta/README.md b/recipes/zephyr-7b-beta/README.md index 836bcc3..1134e71 100644 --- a/recipes/zephyr-7b-beta/README.md +++ b/recipes/zephyr-7b-beta/README.md @@ -3,12 +3,15 @@ As described in the Zephyr [technical report](https://huggingface.co/papers/2310.16944), training this model proceeds in two steps: -1. Apply SFT to fine-tune Mistral 7B on a filtered version of the UltraChat dataset ([link](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)). The result is an SFT model like [`zephyr-7b-sft-full`](https://huggingface.co/alignment-handbook/zephyr-7b-sft-full) or [`zephyr-7b-sft-lora`](https://huggingface.co/alignment-handbook/zephyr-7b-sft-lora). -2. Align the SFT model to AI feedback via DPO on a preprocessed version of the UltraFeedback dataset ([link](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)). The result is an DPO model like [`zephyr-7b-dpo-full`](https://huggingface.co/alignment-handbook/zephyr-7b-dpo-full) or [`zephyr-7b-dpo-lora`](https://huggingface.co/alignment-handbook/zephyr-7b-dpo-lora). +1. Apply SFT to fine-tune Mistral 7B on a filtered version of the UltraChat dataset ([link](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)). The result is an SFT model like [`zephyr-7b-sft-full`](https://huggingface.co/alignment-handbook/zephyr-7b-sft-full) or [`zephyr-7b-sft-qlora`](https://huggingface.co/alignment-handbook/zephyr-7b-sft-qlora). +2. 
Align the SFT model to AI feedback via DPO on a preprocessed version of the UltraFeedback dataset ([link](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)). The result is a DPO model like [`zephyr-7b-dpo-full`](https://huggingface.co/alignment-handbook/zephyr-7b-dpo-full) or [`zephyr-7b-dpo-qlora`](https://huggingface.co/alignment-handbook/zephyr-7b-dpo-qlora). + +**Note:** after the release of Zephyr, the team at [Argilla](https://argilla.io) found that the source UltraFeedback dataset had a few thousand incorrect preference labels from GPT-4. Additionally, TRL's `SFTTrainer` had a bug in the learning rate scheduler which terminated training early. Accounting for these changes led us to find a better set of hyperparameters than those described in the technical report. In particular, for DPO training we found that training for 1 epoch with `beta=0.01` was sufficient to achieve comparable performance to `zephyr-7b-beta` (vs. 3 epochs with `beta=0.1`). See below for commands to train these models using either DeepSpeed ZeRO-3 or LoRA. ## Full training examples + You will require 8 GPUs (80GB of VRAM) to train the full model. 
```shell # Step 1 - SFT @@ -18,12 +21,12 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_dpo.py recipes/zephyr-7b-beta/dpo/config_full.yaml ``` -## LoRA training examples +## QLoRA training examples ```shell # Step 1 - SFT -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_sft.py recipes/zephyr-7b-beta/sft/config_lora.yaml +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_sft.py recipes/zephyr-7b-beta/sft/config_qlora.yaml --load_in_4bit=true # Step 2 - DPO -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_dpo.py recipes/zephyr-7b-beta/dpo/config_lora.yaml +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_dpo.py recipes/zephyr-7b-beta/dpo/config_qlora.yaml ``` \ No newline at end of file diff --git a/recipes/zephyr-7b-beta/dpo/config_full.yaml b/recipes/zephyr-7b-beta/dpo/config_full.yaml index 5110f59..9ea336b 100644 --- a/recipes/zephyr-7b-beta/dpo/config_full.yaml +++ b/recipes/zephyr-7b-beta/dpo/config_full.yaml @@ -1,5 +1,6 @@ # Model arguments model_name_or_path: alignment-handbook/zephyr-7b-sft-full +torch_dtype: null # Data training arguments # For definitions, see: src/h4/training/config.py @@ -12,26 +13,29 @@ preprocessing_num_workers: 12 # DPOTrainer arguments bf16: true -beta: 0.1 +beta: 0.01 do_eval: true evaluation_strategy: steps eval_steps: 100 -gradient_accumulation_steps: 1 +gradient_accumulation_steps: 2 gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: False hub_model_id: zephyr-7b-dpo-full learning_rate: 5.0e-7 log_level: info logging_steps: 10 
-lr_scheduler_type: linear +lr_scheduler_type: cosine max_length: 1024 max_prompt_length: 512 -num_train_epochs: 3 -optim: rmsprop +num_train_epochs: 1 +optim: adamw_torch output_dir: data/zephyr-7b-dpo-full per_device_train_batch_size: 8 -per_device_eval_batch_size: 4 +per_device_eval_batch_size: 8 push_to_hub: true -save_strategy: "no" -save_total_limit: null +save_strategy: "steps" +save_steps: 100 +save_total_limit: 1 seed: 42 warmup_ratio: 0.1 \ No newline at end of file diff --git a/recipes/zephyr-7b-beta/dpo/config_lora.yaml b/recipes/zephyr-7b-beta/dpo/config_lora.yaml deleted file mode 100644 index afeb8b4..0000000 --- a/recipes/zephyr-7b-beta/dpo/config_lora.yaml +++ /dev/null @@ -1,51 +0,0 @@ -# Model arguments -model_name_or_path: alignment-handbook/zephyr-7b-sft-lora -torch_dtype: auto - -# LoRA arguments -use_peft: true -lora_r: 64 -lora_alpha: 16 -lora_dropout: 0.1 -lora_target_modules: -- q_proj -- k_proj -- v_proj -- o_proj - -# Data training arguments - -dataset_mixer: - HuggingFaceH4/ultrafeedback_binarized: 1.0 -dataset_splits: -- train_prefs -- test_prefs -preprocessing_num_workers: 12 - -# DPOTrainer arguments -bf16: true -beta: 0.1 -do_eval: true -evaluation_strategy: epoch -eval_steps: 100 -gradient_accumulation_steps: 32 -gradient_checkpointing: true -gradient_checkpointing_kwargs: - use_reentrant: false -hub_model_id: zephyr-7b-dpo-lora -learning_rate: 5.0e-7 -log_level: info -logging_steps: 10 -lr_scheduler_type: linear -max_length: 1024 -max_prompt_length: 512 -num_train_epochs: 3 -optim: rmsprop -output_dir: data/zephyr-7b-dpo-lora # It is handy to append `hub_model_revision` to keep track of your local experiments -per_device_train_batch_size: 2 -per_device_eval_batch_size: 4 -push_to_hub: true -save_strategy: "no" -save_total_limit: null -seed: 42 -warmup_ratio: 0.1 \ No newline at end of file diff --git a/recipes/zephyr-7b-beta/dpo/config_qlora.yaml b/recipes/zephyr-7b-beta/dpo/config_qlora.yaml new file mode 100644 index 
0000000..65adcf3 --- /dev/null +++ b/recipes/zephyr-7b-beta/dpo/config_qlora.yaml @@ -0,0 +1,56 @@ +# Model arguments +model_name_or_path: alignment-handbook/zephyr-7b-sft-qlora +torch_dtype: float16 + +# LoRA arguments +use_peft: true +load_in_4bit: true +lora_r: 16 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: +- q_proj +- k_proj +- v_proj +- o_proj +- gate_proj +- up_proj +- down_proj + +# Data training arguments + +dataset_mixer: + HuggingFaceH4/ultrafeedback_binarized: 1.0 +dataset_splits: +- train_prefs +- test_prefs +preprocessing_num_workers: 12 + +# DPOTrainer arguments +bf16: true +beta: 0.01 +do_eval: true +evaluation_strategy: steps +eval_steps: 100 +gradient_accumulation_steps: 2 +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +hub_model_id: zephyr-7b-dpo-qlora +learning_rate: 5.0e-6 +log_level: info +logging_steps: 10 +lr_scheduler_type: cosine +max_length: 1024 +max_prompt_length: 512 +num_train_epochs: 1 +optim: paged_adamw_32bit +output_dir: data/zephyr-7b-dpo-qlora # It is handy to append `hub_model_revision` to keep track of your local experiments +per_device_train_batch_size: 4 +per_device_eval_batch_size: 8 +push_to_hub: true +save_strategy: "steps" +save_steps: 100 +save_total_limit: 1 +seed: 42 +warmup_ratio: 0.1 \ No newline at end of file diff --git a/recipes/zephyr-7b-beta/sft/config_full.yaml b/recipes/zephyr-7b-beta/sft/config_full.yaml index 4d8d2d1..f5eb440 100644 --- a/recipes/zephyr-7b-beta/sft/config_full.yaml +++ b/recipes/zephyr-7b-beta/sft/config_full.yaml @@ -5,6 +5,7 @@ torch_dtype: bfloat16 use_flash_attention_2: true # Data training arguments +chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif 
%}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" dataset_mixer: HuggingFaceH4/ultrachat_200k: 1.0 dataset_splits: @@ -16,8 +17,10 @@ preprocessing_num_workers: 12 bf16: true do_eval: true evaluation_strategy: epoch -gradient_accumulation_steps: 2 +gradient_accumulation_steps: 1 gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: False hub_model_id: zephyr-7b-sft-full hub_strategy: every_save learning_rate: 2.0e-05 @@ -30,13 +33,14 @@ max_steps: -1 num_train_epochs: 1 output_dir: data/zephyr-7b-sft-full overwrite_output_dir: true -per_device_eval_batch_size: 16 -per_device_train_batch_size: 32 +per_device_eval_batch_size: 8 +per_device_train_batch_size: 16 push_to_hub: true remove_unused_columns: true report_to: - tensorboard -save_strategy: "no" -save_total_limit: null +save_strategy: "steps" +save_steps: 100 +save_total_limit: 1 seed: 42 -tf32: true \ No newline at end of file +warmup_ratio: 0.1 \ No newline at end of file diff --git a/recipes/zephyr-7b-beta/sft/config_lora.yaml b/recipes/zephyr-7b-beta/sft/config_qlora.yaml similarity index 50% rename from recipes/zephyr-7b-beta/sft/config_lora.yaml rename to recipes/zephyr-7b-beta/sft/config_qlora.yaml index f45c5fc..3b09218 100644 --- a/recipes/zephyr-7b-beta/sft/config_lora.yaml +++ b/recipes/zephyr-7b-beta/sft/config_qlora.yaml @@ -1,20 +1,25 @@ # Model arguments model_name_or_path: mistralai/Mistral-7B-v0.1 -torch_dtype: auto -use_flash_attention_2: true +model_revision: main +torch_dtype: float16 # LoRA arguments +load_in_4bit: true use_peft: true -lora_r: 64 +lora_r: 16 lora_alpha: 16 -lora_dropout: 0.1 +lora_dropout: 0.05 lora_target_modules: - q_proj - k_proj - v_proj - o_proj +- gate_proj +- up_proj +- down_proj # Data training arguments +chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ 
'<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" dataset_mixer: HuggingFaceH4/ultrachat_200k: 1.0 dataset_splits: @@ -26,13 +31,13 @@ preprocessing_num_workers: 12 bf16: true do_eval: true evaluation_strategy: epoch -gradient_accumulation_steps: 128 +gradient_accumulation_steps: 2 gradient_checkpointing: true gradient_checkpointing_kwargs: use_reentrant: false -hub_model_id: zephyr-7b-sft-lora +hub_model_id: zephyr-7b-sft-qlora hub_strategy: every_save -learning_rate: 2.0e-05 +learning_rate: 2.0e-04 log_level: info logging_steps: 5 logging_strategy: steps @@ -40,14 +45,15 @@ lr_scheduler_type: cosine max_seq_length: 2048 max_steps: -1 num_train_epochs: 1 -output_dir: data/zephyr-7b-sft-lora +output_dir: data/zephyr-7b-sft-qlora overwrite_output_dir: true per_device_eval_batch_size: 8 per_device_train_batch_size: 4 push_to_hub: true report_to: - tensorboard -save_strategy: "no" -save_total_limit: null +save_strategy: "steps" +save_steps: 100 +save_total_limit: 1 seed: 42 -warmup_ratio: 0.1 +warmup_ratio: 0.1 \ No newline at end of file diff --git a/scripts/README.md b/scripts/README.md index 07e2976..5e28cee 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -9,20 +9,20 @@ In the handbook, we provide three main ways to align LLMs for chat: - LoRA or QLoRA fine-tuning on a single consumer 24GB GPU (tested on an RTX 4090). - LoRA fine-tuning on a multi-GPU machine with DeepSpeed ZeRO-3 (tested on a 2 x A100s (80GB)). -In practice, we find comparable performance for both full and LoRA fine-tuning, with the latter having the advantage of producing small adapter weights that are fast to upload and download from the Hugging Face Hub. 
Here are the general commands to fine-tune your models: +In practice, we find comparable performance for both full and QLoRA fine-tuning, with the latter having the advantage of producing small adapter weights that are fast to upload and download from the Hugging Face Hub. Here are the general commands to fine-tune your models: ```shell # Full training with ZeRO-3 on 8 GPUs ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_{task}.py recipes/{model_name}/{task}/config_full.yaml -# LoRA training on a single GPU -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_{task}.py recipes/{model_name}/{task}/config_lora.yaml - # QLoRA 4-bit training on a single GPU -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_{task}.py recipes/{model_name}/{task}/config_lora.yaml --load_in_4bit=true +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_{task}.py recipes/{model_name}/{task}/config_qlora.yaml + +# LoRA training on a single GPU +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_{task}.py recipes/{model_name}/{task}/config_qlora.yaml --load_in_4bit=false # LoRA training with ZeRO-3 on two or more GPUs -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml --num_processes={num_gpus} scripts/run_{task}.py recipes/{model_name}/{task}/config_lora.yaml +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml --num_processes={num_gpus} scripts/run_{task}.py recipes/{model_name}/{task}/config_qlora.yaml --load_in_4bit=false ``` Here `{task}` refers to the type of training you wish to run (SFT, DPO, etc), 
while `{model_name}` refers to the choice of a recipe in the `recipes` directory. For example, to replicate Zephyr-7B-β you can run: @@ -44,6 +44,8 @@ By default, these scripts will push each model to your Hugging Face Hub username ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_{task}.py recipes/{model_name}/{task}/config_full.yaml --per_device_train_batch_size=42 --num_train_epochs=5 ``` +## Logging with Weights and Biases + By default all training metrics are logged with TensorBoard. If you have a [Weights and Biases](https://wandb.ai/site) account and are logged in, you can view the training metrics by appending `--report_to=wandb`, e.g. ```shell @@ -58,7 +60,7 @@ If you have access to a Slurm cluster, we provide a `recipes/launch.slurm` scrip sbatch --job-name=handbook_{task} --nodes=1 recipes/launch.slurm {model_name} {task} {precision} {accelerator} ``` -Here `{model_name}` and `{task}` are defined as above, while `{precision}` refers to the type of training (`full` vs `lora`) and `{accelerator}` refers to the choice of 🤗 Accelerate config in `recipes/accelerate_configs`. If you wish to override the default config parameters, you can provide them by appending a space-separated string like `'--arg1=value1 --arg2=value2'. Here's a concrete example to run SFT on 1 node of 8 GPUs: +Here `{model_name}` and `{task}` are defined as above, while `{precision}` refers to the type of training (`full` vs `qlora`) and `{accelerator}` refers to the choice of 🤗 Accelerate config in `recipes/accelerate_configs`. If you wish to override the default config parameters, you can provide them by appending a space-separated string like `'--arg1=value1 --arg2=value2'`. 
Here's a concrete example to run SFT on 1 node of 8 GPUs: ```shell # Launch on Slurm and override default hyperparameters diff --git a/scripts/run_dpo.py b/scripts/run_dpo.py index fbd084e..58b97ed 100644 --- a/scripts/run_dpo.py +++ b/scripts/run_dpo.py @@ -14,19 +14,20 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import random import sys import torch import transformers from transformers import AutoModelForCausalLM, set_seed -from accelerate import Accelerator from alignment import ( DataArguments, DPOConfig, H4ArgumentParser, ModelArguments, apply_chat_template, + get_checkpoint, get_datasets, get_kbit_device_map, get_peft_config, @@ -64,12 +65,14 @@ def main(): logger.info(f"Data parameters {data_args}") logger.info(f"Training/evaluation parameters {training_args}") + # Check for last checkpoint + last_checkpoint = get_checkpoint(training_args) + if last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.") + # Set seed for reproducibility set_seed(training_args.seed) - # Increase distributed timeout to 3h to enable push to Hub to complete - accelerator = Accelerator() - ############### # Load datasets ############### @@ -102,6 +105,12 @@ def main(): {"text_prompt": "prompt", "text_chosen": "chosen", "text_rejected": "rejected"} ) + # Log a few random samples from the training set: + for index in random.sample(range(len(raw_datasets["train"])), 3): + logger.info(f"Prompt sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['prompt']}") + logger.info(f"Chosen sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['chosen']}") + logger.info(f"Rejected sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['rejected']}") + torch_dtype = ( model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, 
model_args.torch_dtype) ) @@ -118,10 +127,10 @@ def main(): ) model = model_args.model_name_or_path - if is_adapter_model(model, model_args.model_revision): - # load the model, merge the adapter weights and unload the adapter - # Note: to run QLora, you will need to merge the based model separately as the merged model in 16bit - logger.info(f"Merging peft adapters for {model_args.model_name_or_path=}") + if is_adapter_model(model, model_args.model_revision) is True: + # Load the base model, merge the adapter weights and unload the adapter + # Note: to run QLoRA, you will need to merge the base model separately as the merged model in 16bit + logger.info(f"Merging PEFT adapters for {model_args.model_name_or_path=}") peft_config = PeftConfig.from_pretrained(model_args.model_name_or_path, revision=model_args.model_revision) @@ -153,7 +162,7 @@ def main(): ######################### # Instantiate DPO trainer ######################### - dpo_trainer = DPOTrainer( + trainer = DPOTrainer( model, ref_model, model_init_kwargs=model_kwargs, @@ -166,17 +175,23 @@ def main(): max_length=training_args.max_length, max_prompt_length=training_args.max_prompt_length, peft_config=get_peft_config(model_args), + loss_type=training_args.loss_type, ) ############### # Training loop ############### - train_result = dpo_trainer.train() + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + train_result = trainer.train(resume_from_checkpoint=checkpoint) metrics = train_result.metrics metrics["train_samples"] = len(raw_datasets["train"]) - dpo_trainer.log_metrics("train", metrics) - dpo_trainer.save_metrics("train", metrics) - dpo_trainer.save_state() + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() logger.info("*** Training complete ***") @@ -185,35 +200,36 @@ def main(): ########## if 
training_args.do_eval: logger.info("*** Evaluate ***") - metrics = dpo_trainer.evaluate() + metrics = trainer.evaluate() metrics["eval_samples"] = len(raw_datasets["test"]) - dpo_trainer.log_metrics("eval", metrics) - dpo_trainer.save_metrics("eval", metrics) + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) ################################## # Save model and create model card ################################## - dpo_trainer.save_model(training_args.output_dir) + logger.info("*** Save model ***") + trainer.save_model(training_args.output_dir) + logger.info(f"Model saved to {training_args.output_dir}") + # Save everything else on main process - if accelerator.is_main_process: - kwargs = { - "finetuned_from": model_args.model_name_or_path, - "dataset": list(data_args.dataset_mixer.keys()), - "dataset_tags": list(data_args.dataset_mixer.keys()), - "tags": ["alignment-handbook"], - } - dpo_trainer.create_model_card(**kwargs) + kwargs = { + "finetuned_from": model_args.model_name_or_path, + "dataset": list(data_args.dataset_mixer.keys()), + "dataset_tags": list(data_args.dataset_mixer.keys()), + "tags": ["alignment-handbook"], + } + if trainer.accelerator.is_main_process: + trainer.create_model_card(**kwargs) # Restore k,v cache for fast inference - dpo_trainer.model.config.use_cache = True - dpo_trainer.model.config.save_pretrained(training_args.output_dir) - if training_args.push_to_hub is True: - dpo_trainer.push_to_hub() + trainer.model.config.use_cache = True + trainer.model.config.save_pretrained(training_args.output_dir) - # Ensure we don't timeout on model save / push to Hub - logger.info("*** Waiting for all processes to finish ***") - accelerator.wait_for_everyone() + if training_args.push_to_hub is True: + logger.info("Pushing to hub...") + trainer.push_to_hub(**kwargs) - logger.info("*** Run complete! ***") + logger.info("*** Training complete! 
***") if __name__ == "__main__": diff --git a/scripts/run_sft.py b/scripts/run_sft.py index e0d892f..eafb259 100644 --- a/scripts/run_sft.py +++ b/scripts/run_sft.py @@ -26,13 +26,13 @@ import torch import transformers from transformers import set_seed -from accelerate import Accelerator from alignment import ( DataArguments, H4ArgumentParser, ModelArguments, SFTConfig, apply_chat_template, + get_checkpoint, get_datasets, get_kbit_device_map, get_peft_config, @@ -52,8 +52,6 @@ def main(): # Set seed for reproducibility set_seed(training_args.seed) - accelerator = Accelerator() - ############### # Setup logging ############### @@ -78,6 +76,11 @@ def main(): logger.info(f"Data parameters {data_args}") logger.info(f"Training/evaluation parameters {training_args}") + # Check for last checkpoint + last_checkpoint = get_checkpoint(training_args) + if last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.") + ############### # Load datasets ############### @@ -149,7 +152,12 @@ def main(): # Training loop ############### logger.info("*** Train ***") - train_result = trainer.train() + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + train_result = trainer.train(resume_from_checkpoint=checkpoint) metrics = train_result.metrics metrics["train_samples"] = len(train_dataset) trainer.log_metrics("train", metrics) @@ -174,23 +182,23 @@ def main(): logger.info(f"Model saved to {training_args.output_dir}") # Save everything else on main process - if accelerator.is_main_process: - kwargs = { - "finetuned_from": model_args.model_name_or_path, - "dataset": list(data_args.dataset_mixer.keys()), - "dataset_tags": list(data_args.dataset_mixer.keys()), - "tags": ["alignment-handbook"], - } + kwargs = { + "finetuned_from": 
model_args.model_name_or_path, + "dataset": list(data_args.dataset_mixer.keys()), + "dataset_tags": list(data_args.dataset_mixer.keys()), + "tags": ["alignment-handbook"], + } + if trainer.accelerator.is_main_process: trainer.create_model_card(**kwargs) # Restore k,v cache for fast inference trainer.model.config.use_cache = True trainer.model.config.save_pretrained(training_args.output_dir) - if training_args.push_to_hub is True: - logger.info("Pushing to hub...") - trainer.push_to_hub() + if training_args.push_to_hub is True: + logger.info("Pushing to hub...") + trainer.push_to_hub(**kwargs) - accelerator.wait_for_everyone() + logger.info("*** Training complete ***") if __name__ == "__main__": diff --git a/setup.py b/setup.py index efdac57..51437e2 100644 --- a/setup.py +++ b/setup.py @@ -50,21 +50,22 @@ _deps = [ "evaluate==0.4.0", "flake8>=6.0.0", "hf-doc-builder>=0.4.0", + "hf_transfer>=0.1.4", "huggingface-hub>=0.14.1,<1.0", "isort>=5.12.0", "ninja>=1.11.1", "numpy>=1.24.2", "packaging>=23.0", "parameterized>=0.9.0", - "peft==0.6.1", + "peft==0.7.1", "protobuf<=3.20.2", # Needed to avoid conflicts with `transformers` "pytest", "safetensors>=0.3.3", "scipy", "tensorboard", - "torch==2.1.0", - "transformers==4.35.0", - "trl==0.7.4", + "torch==2.1.2", + "transformers==4.36.2", + "trl==0.7.7", "jinja2>=3.0.0", "tqdm>=4.64.1", ] diff --git a/src/alignment/__init__.py b/src/alignment/__init__.py index d20d6d7..81317f8 100644 --- a/src/alignment/__init__.py +++ b/src/alignment/__init__.py @@ -2,4 +2,11 @@ __version__ = "0.3.0.dev0" from .configs import DataArguments, DPOConfig, H4ArgumentParser, ModelArguments, SFTConfig from .data import apply_chat_template, get_datasets -from .model_utils import get_kbit_device_map, get_peft_config, get_quantization_config, get_tokenizer, is_adapter_model +from .model_utils import ( + get_checkpoint, + get_kbit_device_map, + get_peft_config, + get_quantization_config, + get_tokenizer, + is_adapter_model, +) diff --git 
a/src/alignment/configs.py b/src/alignment/configs.py index 9ca7c8e..2a71ea4 100644 --- a/src/alignment/configs.py +++ b/src/alignment/configs.py @@ -251,3 +251,4 @@ class DPOConfig(transformers.TrainingArguments): ) optim: Optional[str] = field(default="rmsprop") remove_unused_columns: bool = field(default=False) + loss_type: Optional[str] = field(default="sigmoid", metadata={"help": ("The loss type for DPO.")}) diff --git a/src/alignment/data.py b/src/alignment/data.py index 532bddc..1617d3c 100644 --- a/src/alignment/data.py +++ b/src/alignment/data.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -import re from typing import List, Literal, Optional from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk @@ -26,12 +25,10 @@ DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == def apply_chat_template( - example, tokenizer, task: Literal["sft", "generation", "rm", "dpo"] = "sft", assistant_prefix="<|assistant|>\n" + example, + tokenizer, + task: Literal["sft", "generation", "rm", "dpo"], ): - def _strip_prefix(s, pattern): - # Use re.escape to escape any special characters in the pattern - return re.sub(f"^{re.escape(pattern)}", "", s) - if task in ["sft", "generation"]: messages = example["messages"] # We add an empty system message if there is none @@ -57,23 +54,18 @@ def apply_chat_template( ) elif task == "dpo": if all(k in example.keys() for k in ("chosen", "rejected")): - # Compared to reward modeling, we filter out the prompt, so the text is everything after the last assistant token - prompt_messages = [[msg for msg in example["chosen"] if msg["role"] == "user"][0]] - # Insert system message + # For DPO, the inputs are triples of (prompt, chosen, rejected), where `chosen` and `rejected` are the final turn of a dialogue + # We therefore need to extract the N-1 turns to form the prompt + prompt_messages = 
example["chosen"][:-1] + # Prepend a system message if the first message is not a system message if example["chosen"][0]["role"] != "system": prompt_messages.insert(0, {"role": "system", "content": ""}) - else: - prompt_messages.insert(0, example["chosen"][0]) - # TODO: handle case where chosen/rejected also have system messages - chosen_messages = example["chosen"][1:] - rejected_messages = example["rejected"][1:] + # Now we extract the final turn to define chosen/rejected responses + chosen_messages = example["chosen"][-1:] + rejected_messages = example["rejected"][-1:] example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False) example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False) - example["text_prompt"] = tokenizer.apply_chat_template( - prompt_messages, tokenize=False, add_generation_prompt=True - ) - example["text_chosen"] = _strip_prefix(example["text_chosen"], assistant_prefix) - example["text_rejected"] = _strip_prefix(example["text_rejected"], assistant_prefix) + example["text_prompt"] = tokenizer.apply_chat_template(prompt_messages, tokenize=False) else: raise ValueError( f"Could not format example as dialogue for `dpo` task! Require `[chosen, rejected]` keys but found {list(example.keys())}" @@ -112,7 +104,7 @@ def get_datasets( # - 'dataset2': 0.3 # - 'dataset3': 0.2 dataset_mixer = data_config.dataset_mixer - elif type(data_config) is dict: + elif isinstance(data_config, dict): # Structure of the input is: # dataset_mixer = { # "dataset1": 0.5, diff --git a/src/alignment/model_utils.py b/src/alignment/model_utils.py index abb930c..52bf00f 100644 --- a/src/alignment/model_utils.py +++ b/src/alignment/model_utils.py @@ -13,17 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os +from pathlib import Path from typing import Dict import torch from transformers import AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizer +from transformers.trainer_utils import get_last_checkpoint from accelerate import Accelerator from huggingface_hub import list_repo_files from huggingface_hub.utils._validators import HFValidationError from peft import LoraConfig, PeftConfig -from .configs import DataArguments, ModelArguments +from .configs import DataArguments, DPOConfig, ModelArguments, SFTConfig from .data import DEFAULT_CHAT_TEMPLATE @@ -104,3 +106,10 @@ def is_adapter_model(model_name_or_path: str, revision: str = "main") -> bool: # If not, check local repo repo_files = os.listdir(model_name_or_path) return "adapter_model.safetensors" in repo_files or "adapter_model.bin" in repo_files + + +def get_checkpoint(training_args: SFTConfig | DPOConfig) -> Path | None: + last_checkpoint = None + if os.path.isdir(training_args.output_dir): + last_checkpoint = get_last_checkpoint(training_args.output_dir) + return last_checkpoint diff --git a/tests/test_data.py b/tests/test_data.py index 6a6b63c..00d10b5 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -88,9 +88,33 @@ class ApplyChatTemplateTest(unittest.TestCase): self.dataset = Dataset.from_dict( { "prompt": ["Hello!"], - "messages": [[{"role": "user", "content": "Hello!"}, {"role": "assistant", "content": "Bonjour!"}]], - "chosen": [[{"role": "user", "content": "Hello!"}, {"role": "assistant", "content": "Bonjour!"}]], - "rejected": [[{"role": "user", "content": "Hello!"}, {"role": "assistant", "content": "Hola!"}]], + "messages": [ + [ + {"role": "system", "content": "You are a happy chatbot"}, + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Bonjour!"}, + {"role": "user", "content": "How are you?"}, + {"role": "assistant", "content": "I am doing well, thanks!"}, + ] + ], + "chosen": [ + [ + {"role": "system", "content": "You are a happy chatbot"}, + {"role": 
"user", "content": "Hello!"}, + {"role": "assistant", "content": "Bonjour!"}, + {"role": "user", "content": "How are you?"}, + {"role": "assistant", "content": "I am doing well, thanks!"}, + ] + ], + "rejected": [ + [ + {"role": "system", "content": "You are a happy chatbot"}, + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Bonjour!"}, + {"role": "user", "content": "How are you?"}, + {"role": "assistant", "content": "Not so good tbh"}, + ] + ], } ) @@ -102,7 +126,9 @@ class ApplyChatTemplateTest(unittest.TestCase): ) self.assertDictEqual( dataset[0], - {"text": "<|system|>\n\n<|user|>\nHello!\n<|assistant|>\nBonjour!\n"}, + { + "text": "<|system|>\nYou are a happy chatbot\n<|user|>\nHello!\n<|assistant|>\nBonjour!\n<|user|>\nHow are you?\n<|assistant|>\nI am doing well, thanks!\n" + }, ) def test_generation(self): @@ -115,7 +141,9 @@ class ApplyChatTemplateTest(unittest.TestCase): ) self.assertDictEqual( dataset[0], - {"text": "<|system|>\n\n<|user|>\nHello!\n<|assistant|>\n"}, + { + "text": "<|system|>\nYou are a happy chatbot\n<|user|>\nHello!\n<|assistant|>\nBonjour!\n<|user|>\nHow are you?\n<|assistant|>\n" + }, ) def test_rm(self): @@ -127,8 +155,8 @@ class ApplyChatTemplateTest(unittest.TestCase): self.assertDictEqual( dataset[0], { - "text_chosen": "<|system|>\n\n<|user|>\nHello!\n<|assistant|>\nBonjour!\n", - "text_rejected": "<|system|>\n\n<|user|>\nHello!\n<|assistant|>\nHola!\n", + "text_chosen": "<|system|>\nYou are a happy chatbot\n<|user|>\nHello!\n<|assistant|>\nBonjour!\n<|user|>\nHow are you?\n<|assistant|>\nI am doing well, thanks!\n", + "text_rejected": "<|system|>\nYou are a happy chatbot\n<|user|>\nHello!\n<|assistant|>\nBonjour!\n<|user|>\nHow are you?\n<|assistant|>\nNot so good tbh\n", }, ) @@ -141,8 +169,8 @@ class ApplyChatTemplateTest(unittest.TestCase): self.assertDictEqual( dataset[0], { - "text_prompt": "<|system|>\n\n<|user|>\nHello!\n<|assistant|>\n", - "text_chosen": "Bonjour!\n", - "text_rejected": 
"Hola!\n", + "text_prompt": "<|system|>\nYou are a happy chatbot\n<|user|>\nHello!\n<|assistant|>\nBonjour!\n<|user|>\nHow are you?\n", + "text_chosen": "<|assistant|>\nI am doing well, thanks!\n", + "text_rejected": "<|assistant|>\nNot so good tbh\n", }, )