Added gemma models

This commit is contained in:
xiamengzhou
2024-07-17 11:17:36 -04:00
committed by GitHub
parent 17f3559c88
commit aa5c3062fc
+12 -7
View File
@@ -29,7 +29,7 @@ We provide an [environment file](https://github.com/princeton-nlp/SimPO/blob/mai
### Hyperparameter tuning
Hyperparameter tuning is crucial for SimPO (and other preference optimization algorithms in general). The three main hyperparameters of SimPO to focus on are `learning_rate`, `beta`, and `gamma`.
- `learning_rate`: It is the most critical hyperparameter for preference optimization. A large learning rate (e.g., 1e-5) can significantly degrade performance, causing the model to produce incoherent sentences or completely repetitive responses. We recommend grid searching over 3e-7, 5e-7, and 1e-6, if resources allow.
- `learning_rate`: It is the most critical hyperparameter for preference optimization. A large learning rate (e.g., 1e-5) can significantly degrade performance, causing the model to produce incoherent sentences or completely repetitive responses. We recommend grid searching over 3e-7, 5e-7, 8e-7, and 1e-6, if resources allow. **We find that a smaller learning rate (e.g., 5e-7) is more suitable for reasoning intensive domains like math for both DPO and SimPO.**
- `beta`: Beta controls the reward scaling between winning and losing responses. SimPO requires a much larger `beta` than DPO. In our preprint, we used a beta of `2.0` or `2.5`, but in many cases, an even larger beta (e.g., `10`) could yield better results.
- `gamma`: Gamma controls the target reward margin. We suggest tuning the ratio of gamma to beta (i.e., `gamma / beta`). We recommend using `0.5` as a starting point for `gamma_beta_ratio` and grid searching between `0` and `1`. A well-tuned `gamma_beta_ratio` can provide a modest improvement, but it is not as critical as other hyperparameters.
@@ -41,6 +41,7 @@ We used the following hyperparameters for training the released models (note tha
| Llama3-Base | 2.0 | 0.5 | 6e-7 |
| Llama3-Instruct | 2.5 | 0.55 | 1e-6 |
| Llama3-Instruct v0.2 | 10 | 0.3 | 1e-6 |
| Gemma | 10 | 0.5 | 8e-7 |
For DPO, the best hyperparameters for each setting are as follows.
| Setting | β | Learning Rate |
@@ -50,6 +51,7 @@ For DPO, the best hyperparameters for each setting are as follows.
| Llama3-Base | 0.01 | 5e-7 |
| Llama3-Instruct | 0.01 | 7e-7 |
| Llama3-Instruct v0.2 | 0.01 | 3e-7 |
| Gemma | 0.01 | 5e-7 |
### Training and evaluation consistency in BOS
@@ -66,13 +68,16 @@ The [CPO_SIMPO](https://github.com/fe1ixxu/CPO_SIMPO/tree/main) repository did p
## Released Models
<!--- ### Gemma
We release the following two models that are built on top of the strong [google/gemma-2-9b-it](https://huggingface.co/google/gemma-2-9b-it) model.
### Gemma
We release the following two models that are built on top of the strong [google/gemma-2-9b-it](https://huggingface.co/google/gemma-2-9b-it) model by training DPO and SimPO on the on-policy dataset [princeton-nlp/gemma2-ultrafeedback-armorm](https://huggingface.co/datasets/princeton-nlp/gemma2-ultrafeedback-armorm). For GSM and MMLU, we use the [EvalZero](https://github.com/yuchenlin/ZeroEval) reporistory which aims to evaluate instruction-tuned LLMs (i.e., chat models instead of base models) for their zero-shot performance on reasoning and knowledge heavy tasks. More results on [WildBench](https://huggingface.co/spaces/allenai/WildBench) are coming soon.
| models | AE2 LC | AE2 WR | AH |
|-----------------------------------------------------------------------------------------------------------|:------:|:------:|:----:|
| [princeton-nlp/gemma-2-9b-it-DPO](https://huggingface.co/princeton-nlp/gemma-2-9b-it-DPO) | 37.9 | 31.6 | 28.8 |
| [princeton-nlp/gemma-2-9b-it-SimPO](https://huggingface.co/princeton-nlp/gemma-2-9b-it-SimPO) | 33.9 | 32.5 | 29.3 | -->
| models | AE2 LC | AE2 WR | AE2 Length | AH | AH Length | GSM | GSM Length | MMLU | MMLU Length |
|-----------------------------------|:------:|:------:|:----------:|:----:|:---------:|:----:|:----------:|:----:|:-----------:|
| [google/gemma-2-9b-it](https://huggingface.co/google/gemma-2-9b-it) | 51.1 | 38.1 | 1571 | 40.8 | 545 | 87.4 | 395 | 72.7 | 515 |
| [princeton-nlp/gemma-2-9b-it-DPO](https://huggingface.co/princeton-nlp/gemma-2-9b-it-DPO) | 69.6 | 67.2 | 2016 | 58.9 | 717 | 88.5 | 392 | 72.2 | 624 |
| [princeton-nlp/gemma-2-9b-it-SimPO](https://huggingface.co/princeton-nlp/gemma-2-9b-it-SimPO) | 73.2 | 66.7 | 1833 | 59.1 | 693 | 88.0 | 341 | 72.2 | 441 |
Compared to the llama3 models, we found that the gemma models exhibit significantly less catastrophic forgetting on math tasks (e.g., GSM) and MMLU, despite the ultrafeedback dataset having limited math-related data. This demonstrates that the [google/gemma-2-9b-it](https://huggingface.co/google/gemma-2-9b-it) model is more suitable for continued preference optimization.
### v0.2