diff --git a/README.md b/README.md
index 25662c4..a288a3d 100644
--- a/README.md
+++ b/README.md
@@ -6,6 +6,7 @@ This repository contains the code and released models for our paper [SimPO: Simp
 ## 🔗 Quick Links
 - [SimPO: Simple Preference Optimization with a Reference-Free Reward](#simple-preference-optimization-simpo)
+  - [Tips for Running SimPO](#tips-for-running-simpo)
   - [Released Models](#released-models)
   - [Install Requirements](#install-requirements)
   - [Training scripts](#training-scripts)
@@ -13,9 +14,34 @@ This repository contains the code and released models for our paper [SimPO: Simp
   - [Bugs or Questions?](#bugs-or-questions)
   - [Citation](#citation)
-## Released Models
-Below is the complete list of models evaluated in our preprint. The following Llama3 models utilize the initial Llama3 tokenizer (before this [PR](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/commit/339ce92d052f002cdbac4a4bd551d1c61dd8345e)). We found that using the updated llama3 tokenizer with vLLM sometimes introduces two BOS tokens, potentially affecting evaluation results, particularly for Arena-Hard. Therefore, please ensure that **only one BOS token** is included in the prompt after applying the Llama3 chat template during any evaluation.
+## Tips for Running SimPO
+Given the various inquiries about SimPO, we provide a list of tips to help you reproduce our paper results and achieve better outcomes when running SimPO on your own tasks.
+### Hyperparameter tuning
+Hyperparameter tuning is crucial for SimPO. The three main hyperparameters to focus on are `learning_rate`, `beta`, and `gamma`.
+- `learning_rate`: The learning rate is the most critical hyperparameter for preference optimization. A large learning rate (e.g., 1e-5) can significantly degrade performance, causing the model to produce incoherent sentences or completely repetitive responses. We recommend a grid search over 3e-7, 5e-7, and 1e-6 if resources allow.
+- `beta`: Beta controls the reward scaling between winning and losing responses. In our preprint, we used a small beta (e.g., 2.0 or 2.5), but researchers from Meta suggest that a larger beta (e.g., 10) could yield better results.
+- `gamma`: Gamma controls the target reward margin. We suggest tuning gamma in tandem with beta, where gamma = c * beta. We recommend grid searching over 0.25, 0.3, and 0.4. A well-tuned gamma can provide a modest improvement, but it is not as critical as the other hyperparameters.
+We used the following hyperparameters for training the released models.
+| Setting           | β   | γ   | Learning rate  |
+|-------------------|-----|-----|----------------|
+| Mistral-Base      | 2.0 | 1.6 | 3e-7           |
+| Mistral-Instruct  | 2.5 | 0.3 | 5e-7           |
+| Llama3-Base       | 2.0 | 1.0 | 6e-7           |
+| Llama3-Instruct   | 2.5 | 1.4 | 1e-6           |
+
+### Training and evaluation consistency in BOS
+Our released Llama3 models use the initial version of the Llama3 tokenizer (prior to this [PR](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/commit/339ce92d052f002cdbac4a4bd551d1c61dd8345e)). We have found that the updated Llama3 tokenizer with vLLM occasionally introduces two BOS tokens, which can affect evaluation results. Therefore, please ensure that only one BOS token is included in the prompt after applying the Llama3 chat template during any evaluation.
+
+*Notably, if you are training Llama3 and evaluating the trained models on AlpacaEval 2 and Arena-Hard using the templates provided in this repo, please make sure to use the pre-update Llama3 tokenizer (i.e., the one before the PR).*
+
+
+### Adding an extra SFT loss
+We have observed that, in some cases, adding an SFT loss can help improve results. These findings have been initially validated in the [CPO_SIMPO](https://github.com/fe1ixxu/CPO_SIMPO/tree/main) repository. We are currently working on integrating this improvement into our main repository.
+
+
+## Released Models
+Below is the complete list of models evaluated in our preprint.
 | models | | AE2 LC | AE2 WR | AH |
 |------------------------------|-----------------------------------------------------------------------------------------------------------|:------:|:------:|:----:|
 | Mistral Base 7B SFT | [alignment-handbook/zephyr-7b-sft-full](https://huggingface.co/alignment-handbook/zephyr-7b-sft-full) | 8.4 | 6.2 | 1.3 |
diff --git a/accelerate_configs/deepspeed_zero3.yaml b/accelerate_configs/deepspeed_zero3.yaml
index b5a1201..ad0d4e6 100644
--- a/accelerate_configs/deepspeed_zero3.yaml
+++ b/accelerate_configs/deepspeed_zero3.yaml
@@ -13,7 +13,7 @@ machine_rank: 0
 main_training_function: main
 mixed_precision: bf16
 num_machines: 1
-num_processes: 8
+num_processes: 4
 rdzv_backend: static
 same_network: true
 tpu_env: []
diff --git a/run.sh b/run.sh
new file mode 100644
index 0000000..0805f1a
--- /dev/null
+++ b/run.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+#SBATCH --ntasks-per-node=1
+#SBATCH --mem=512G
+#SBATCH --gres=gpu:4
+#SBATCH --time=10:00:00
+#SBATCH --partition=pli-c
+#SBATCH --output=/scratch/gpfs/mengzhou/space17/out/slurm/%x-%j.out
+#SBATCH --err=/scratch/gpfs/mengzhou/space17/out/slurm/%x-%j.err
+
+conda activate handbook
+
+cd $n/space17/SimPO
+
+seed=${1:-1}
+output_dir=$n/space17/out/simpo_seed${seed}
+mkdir -p $output_dir
+
+# 4 gpus
+ACCELERATE_LOG_LEVEL=info accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml scripts/run_simpo.py training_configs/llama-3-8b-instruct-simpo.yaml --seed=$seed --output_dir=$output_dir
\ No newline at end of file
diff --git a/scripts/run_simpo.py b/scripts/run_simpo.py
index d5fabd2..e546853 100644
--- a/scripts/run_simpo.py
+++ b/scripts/run_simpo.py
@@ -227,28 +227,29 @@ def main():
     )
     model = model_args.model_name_or_path
-    if is_adapter_model(model, model_args.model_revision) is True:
-        logger.info(f"Loading SFT adapter for {model_args.model_name_or_path=}")
-        peft_config = PeftConfig.from_pretrained(model_args.model_name_or_path, revision=model_args.model_revision)
-        model_kwargs = dict(
-            revision=model_args.base_model_revision,
-            trust_remote_code=model_args.trust_remote_code,
-            use_flash_attention_2=model_args.use_flash_attention_2,
-            torch_dtype=torch_dtype,
-            use_cache=False if training_args.gradient_checkpointing else True,
-            device_map=get_kbit_device_map() if quantization_config is not None else None,
-            quantization_config=quantization_config,
-        )
-        base_model = AutoModelForCausalLM.from_pretrained(
-            peft_config.base_model_name_or_path,
-            **model_kwargs,
-        )
-        model = PeftModel.from_pretrained(
-            base_model,
-            model_args.model_name_or_path,
-            revision=model_args.model_revision,
-        )
-        model_kwargs = None
+    # seems to require internet
+    # if is_adapter_model(model, model_args.model_revision) is True:
+    #     logger.info(f"Loading SFT adapter for {model_args.model_name_or_path=}")
+    #     peft_config = PeftConfig.from_pretrained(model_args.model_name_or_path, revision=model_args.model_revision)
+    #     model_kwargs = dict(
+    #         revision=model_args.base_model_revision,
+    #         trust_remote_code=model_args.trust_remote_code,
+    #         use_flash_attention_2=model_args.use_flash_attention_2,
+    #         torch_dtype=torch_dtype,
+    #         use_cache=False if training_args.gradient_checkpointing else True,
+    #         device_map=get_kbit_device_map() if quantization_config is not None else None,
+    #         quantization_config=quantization_config,
+    #     )
+    #     base_model = AutoModelForCausalLM.from_pretrained(
+    #         peft_config.base_model_name_or_path,
+    #         **model_kwargs,
+    #     )
+    #     model = PeftModel.from_pretrained(
+    #         base_model,
+    #         model_args.model_name_or_path,
+    #         revision=model_args.model_revision,
+    #     )
+    #     model_kwargs = None
     ref_model = model
     ref_model_kwargs = model_kwargs
diff --git a/training_configs/llama-3-8b-instruct-simpo.yaml b/training_configs/llama-3-8b-instruct-simpo.yaml
index d0c91ad..17b643d 100644
--- a/training_configs/llama-3-8b-instruct-simpo.yaml
+++ b/training_configs/llama-3-8b-instruct-simpo.yaml
@@ -18,7 +18,7 @@ gamma: 1.4
 do_eval: true
 evaluation_strategy: steps
 eval_steps: 400
-gradient_accumulation_steps: 8
+gradient_accumulation_steps: 16
 gradient_checkpointing: true
 gradient_checkpointing_kwargs:
   use_reentrant: False
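
Note on the two config changes: halving `num_processes` (8 to 4) while doubling `gradient_accumulation_steps` (8 to 16) appears intended to keep the global batch size unchanged on a 4-GPU node. The sketch below checks that arithmetic. It assumes the per-device batch size is not touched by this diff; the value `2` is a hypothetical placeholder, since `per_device_train_batch_size` is not shown here, and the helper name `effective_batch_size` is illustrative only.

```python
def effective_batch_size(per_device_batch_size: int,
                         num_processes: int,
                         grad_accum_steps: int) -> int:
    """Number of samples contributing to one optimizer step."""
    return per_device_batch_size * num_processes * grad_accum_steps


# Hypothetical per-device batch size; the real value lives in the training
# config and is not part of this diff.
PER_DEVICE = 2

# Values taken from the diff: (num_processes, gradient_accumulation_steps).
before = effective_batch_size(PER_DEVICE, num_processes=8, grad_accum_steps=8)
after = effective_batch_size(PER_DEVICE, num_processes=4, grad_accum_steps=16)

print(f"before: {before}, after: {after}, unchanged: {before == after}")
```

Because the product `num_processes * gradient_accumulation_steps` is 64 both before and after, the two settings match for any per-device batch size, so the learning rates recommended in the README should carry over under that assumption.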