mirror of
https://github.com/wassname/SimPO.git
synced 2026-06-27 17:46:46 +08:00
release
This commit is contained in:
+164
@@ -0,0 +1,164 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
.idea/
|
||||
|
||||
# Temp folders
|
||||
data/
|
||||
wandb/
|
||||
@@ -0,0 +1,88 @@
|
||||
# Simple Preference Optimization (SimPO)
|
||||
|
||||
This repository contains the code and released models for our paper [SimPO: Simple Preference Optimization with a Reference-Free Reward](https://arxiv.org/abs/2405.14734). We propose a simpler and more effective preference optimization algorithm than DPO (Direct Preference Optimization) without using a reference model. SimPO outperforms DPO and its latest variants across AlpacaEval 2, MT-Bench, and Arena-Hard benchmarks under various settings.
|
||||
|
||||
<img src="./SimPO.png" width="1000px"></img>
|
||||
|
||||
## 🔗 Quick Links
|
||||
- [SimPO: Simple Preference Optimization with a Reference-Free Reward](#simple-preference-optimization-simpo)
|
||||
- [Released Models](#released-models)
|
||||
- [Install Requirements](#install-requirements)
|
||||
- [Training scripts](#training-scripts)
|
||||
- [Evaluation](#evaluation)
|
||||
- [Bugs or Questions?](#bugs-or-questions)
|
||||
- [Citation](#citation)
|
||||
|
||||
## Released Models
|
||||
|
||||
|
||||
## Install Requirements
|
||||
|
||||
Our codebase is built upon the [alignment-handbook repo](https://github.com/huggingface/alignment-handbook). The following steps will guide you through the installation process.
|
||||
|
||||
First, create a Python virtual environment using e.g. Conda:
|
||||
```shell
|
||||
conda create -n handbook python=3.10 && conda activate handbook
|
||||
```
|
||||
|
||||
Next, install PyTorch `v2.2.2`. Since this is hardware-dependent, we
|
||||
direct you to the [PyTorch Installation Page](https://pytorch.org/get-started/locally/).
|
||||
|
||||
You can then install the remaining package dependencies of [alignment-handbook](https://github.com/huggingface/alignment-handbook) as follows:
|
||||
|
||||
```shell
|
||||
git clone https://github.com/huggingface/alignment-handbook.git
|
||||
cd ./alignment-handbook/
|
||||
python -m pip install .
|
||||
```
|
||||
|
||||
You will also need Flash Attention 2 installed, which can be done by running:
|
||||
|
||||
```shell
|
||||
python -m pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
## Training Scripts
|
||||
|
||||
We provide four training config files for the four training setups reported in our paper:
|
||||
|
||||
* Mistral-Base:
|
||||
```shell
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml scripts/run_simpo.py training_configs/mistral-7b-base-simpo.yaml
|
||||
```
|
||||
* Mistral-Instruct:
|
||||
```shell
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml scripts/run_simpo.py training_configs/mistral-7b-instruct-simpo.yaml
|
||||
```
|
||||
* Llama3-Base:
|
||||
```shell
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml scripts/run_simpo.py training_configs/llama-3-8b-base-simpo.yaml
|
||||
```
|
||||
* Llama3-Instruct:
|
||||
```shell
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml scripts/run_simpo.py training_configs/llama-3-8b-instruct-simpo.yaml
|
||||
```
|
||||
|
||||
## Evaluation
|
||||
|
||||
We follow the official implementation for evaluation on AlpacaEval 2, Arena-Hard, and MT-Bench, as follows:
|
||||
|
||||
* AlpacaEval 2: Please refer to the [AlpacaEval repo](https://github.com/tatsu-lab/alpaca_eval) for evaluation.
|
||||
|
||||
* Arena-Hard: Please refer to to the [Arena-Hard-Auto repo](https://github.com/lm-sys/arena-hard-auto) for evaluation.
|
||||
|
||||
* MT-Bench: Please refer to the [FastChat repo](https://github.com/lm-sys/FastChat) for evaluation.
|
||||
|
||||
## Bugs or Questions?
|
||||
If you have any questions related to the code or the paper, feel free to email Yu (yumeng5@virginia.edu). If you encounter any problems when using the code, or want to report a bug, feel free to open an issue! Please try to specify the problem with details so we can help you better and quicker!
|
||||
|
||||
## Citation
|
||||
Please cite our paper if you find the repo helpful in your work:
|
||||
|
||||
```bibtex
|
||||
@article{meng2024simpo,
|
||||
title={{SimPO}: Simple Preference Optimization with a Reference-Free Reward},
|
||||
author={Meng, Yu and Xia, Mengzhou and Chen, Danqi},
|
||||
year={2024}
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,22 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
deepspeed_config:
|
||||
deepspeed_multinode_launcher: standard
|
||||
offload_optimizer_device: none
|
||||
offload_param_device: none
|
||||
zero3_init_flag: true
|
||||
zero3_save_16bit_model: true
|
||||
zero_stage: 3
|
||||
distributed_type: DEEPSPEED
|
||||
downcast_bf16: 'no'
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
@@ -0,0 +1,26 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
distributed_type: FSDP
|
||||
downcast_bf16: 'no'
|
||||
enable_cpu_affinity: false
|
||||
fsdp_config:
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_backward_prefetch: BACKWARD_PRE
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_forward_prefetch: true
|
||||
fsdp_offload_params: false
|
||||
fsdp_sharding_strategy: FULL_SHARD
|
||||
fsdp_state_dict_type: SHARDED_STATE_DICT
|
||||
fsdp_sync_module_states: true
|
||||
fsdp_use_orig_params: true
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
@@ -0,0 +1,16 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
distributed_type: MULTI_GPU
|
||||
downcast_bf16: 'no'
|
||||
gpu_ids: all
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
@@ -0,0 +1,334 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import random
|
||||
import sys
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import AutoModelForCausalLM, set_seed
|
||||
|
||||
from alignment import (
|
||||
DataArguments,
|
||||
DPOConfig,
|
||||
H4ArgumentParser,
|
||||
ModelArguments,
|
||||
maybe_insert_system_message,
|
||||
is_openai_format,
|
||||
get_checkpoint,
|
||||
get_datasets,
|
||||
get_kbit_device_map,
|
||||
get_peft_config,
|
||||
get_quantization_config,
|
||||
get_tokenizer,
|
||||
is_adapter_model,
|
||||
)
|
||||
from peft import PeftConfig, PeftModel
|
||||
from simpo_trainer import SimPOTrainer
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Literal
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MISTRAL_CHAT_TEMPLATE = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'].strip() + '\n\n' %}{% else %}{% set loop_messages = messages %}{% set system_message = '' %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{% set content = system_message + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"
|
||||
|
||||
@dataclass
|
||||
class SimPOConfig(DPOConfig):
|
||||
gamma: Optional[float] = field(
|
||||
default=0.5,
|
||||
metadata={"help": "The target reward margin term in SimPO loss."},
|
||||
)
|
||||
|
||||
def apply_chat_template(
|
||||
example,
|
||||
tokenizer,
|
||||
task: Literal["sft", "generation", "rm", "simpo"],
|
||||
auto_insert_empty_system_msg: bool = True,
|
||||
change_template = None,
|
||||
):
|
||||
if change_template == "mistral":
|
||||
tokenizer.chat_template = MISTRAL_CHAT_TEMPLATE
|
||||
if task in ["sft", "generation"]:
|
||||
messages = example["messages"]
|
||||
# We add an empty system message if there is none
|
||||
if auto_insert_empty_system_msg:
|
||||
maybe_insert_system_message(messages, tokenizer)
|
||||
example["text"] = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True if task == "generation" else False,
|
||||
)
|
||||
elif task == "rm":
|
||||
if all(k in example.keys() for k in ("chosen", "rejected")):
|
||||
chosen_messages = example["chosen"]
|
||||
rejected_messages = example["rejected"]
|
||||
# We add an empty system message if there is none
|
||||
if auto_insert_empty_system_msg:
|
||||
maybe_insert_system_message(chosen_messages, tokenizer)
|
||||
maybe_insert_system_message(rejected_messages, tokenizer)
|
||||
|
||||
example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
|
||||
example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Could not format example as dialogue for `rm` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
|
||||
)
|
||||
elif task == "simpo":
|
||||
if all(k in example.keys() for k in ("chosen", "rejected")):
|
||||
if not is_openai_format(example["chosen"]) or not is_openai_format(example["rejected"]):
|
||||
raise ValueError(
|
||||
f"Could not format example as dialogue for `{task}` task! Require OpenAI format for all messages"
|
||||
)
|
||||
|
||||
# For DPO/ORPO, the inputs are triples of (prompt, chosen, rejected), where `chosen` and `rejected` are the final turn of a dialogue
|
||||
# We therefore need to extract the N-1 turns to form the prompt
|
||||
if "prompt" in example and is_openai_format(example["prompt"]):
|
||||
prompt_messages = example["prompt"]
|
||||
chosen_messages = example["chosen"]
|
||||
rejected_messages = example["rejected"]
|
||||
else:
|
||||
prompt_messages = example["chosen"][:-1]
|
||||
# Now we extract the final turn to define chosen/rejected responses
|
||||
chosen_messages = example["chosen"][-1:]
|
||||
rejected_messages = example["rejected"][-1:]
|
||||
|
||||
# Prepend a system message if the first message is not a system message
|
||||
if auto_insert_empty_system_msg:
|
||||
maybe_insert_system_message(prompt_messages, tokenizer)
|
||||
|
||||
example["text_prompt"] = tokenizer.apply_chat_template(prompt_messages, tokenize=False)
|
||||
example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
|
||||
if example["text_chosen"].startswith(tokenizer.bos_token):
|
||||
example["text_chosen"] = example["text_chosen"][len(tokenizer.bos_token):]
|
||||
example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
|
||||
if example["text_rejected"].startswith(tokenizer.bos_token):
|
||||
example["text_rejected"] = example["text_rejected"][len(tokenizer.bos_token):]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Could not format example as dialogue for `{task}` task! Require either the "
|
||||
f"`[chosen, rejected]` or `[prompt, chosen, rejected]` keys but found {list(example.keys())}"
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Task {task} not supported, please ensure that the provided task is one of ['sft', 'generation', 'rm', 'dpo', 'orpo']"
|
||||
)
|
||||
return example
|
||||
|
||||
|
||||
def main():
|
||||
parser = H4ArgumentParser((ModelArguments, DataArguments, SimPOConfig))
|
||||
model_args, data_args, training_args = parser.parse()
|
||||
|
||||
#######
|
||||
# Setup
|
||||
#######
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
log_level = training_args.get_process_log_level()
|
||||
logger.setLevel(log_level)
|
||||
transformers.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
# Log on each process the small summary:
|
||||
logger.info(f"Model parameters {model_args}")
|
||||
logger.info(f"Data parameters {data_args}")
|
||||
logger.info(f"Training/evaluation parameters {training_args}")
|
||||
|
||||
# Check for last checkpoint
|
||||
last_checkpoint = get_checkpoint(training_args)
|
||||
if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
|
||||
logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
|
||||
|
||||
# Set seed for reproducibility
|
||||
set_seed(training_args.seed)
|
||||
|
||||
###############
|
||||
# Load datasets
|
||||
###############
|
||||
raw_datasets = get_datasets(
|
||||
data_args,
|
||||
splits=data_args.dataset_splits,
|
||||
configs=data_args.dataset_configs,
|
||||
columns_to_keep=["messages", "chosen", "rejected", "prompt", "completion", "label"],
|
||||
seed=training_args.seed,
|
||||
)
|
||||
logger.info(
|
||||
f"Training on the following splits: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
|
||||
)
|
||||
column_names = list(raw_datasets["train"].features)
|
||||
|
||||
#####################################
|
||||
# Load tokenizer and process datasets
|
||||
#####################################
|
||||
data_args.truncation_side = "left" # Truncate from left to ensure we don't lose labels in final turn
|
||||
tokenizer = get_tokenizer(model_args, data_args)
|
||||
|
||||
if "mistral" in model_args.model_name_or_path.lower():
|
||||
change_template = "mistral"
|
||||
else:
|
||||
change_template = None
|
||||
#####################
|
||||
# Apply chat template
|
||||
#####################
|
||||
raw_datasets = raw_datasets.map(
|
||||
apply_chat_template,
|
||||
fn_kwargs={
|
||||
"tokenizer": tokenizer,
|
||||
"task": "simpo",
|
||||
"auto_insert_empty_system_msg": data_args.auto_insert_empty_system_msg,
|
||||
"change_template": change_template,
|
||||
},
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
desc="Formatting comparisons with prompt template",
|
||||
)
|
||||
|
||||
# Replace column names with what TRL needs, text_chosen -> chosen and text_rejected -> rejected
|
||||
for split in ["train", "test"]:
|
||||
raw_datasets[split] = raw_datasets[split].rename_columns(
|
||||
{"text_prompt": "prompt", "text_chosen": "chosen", "text_rejected": "rejected"}
|
||||
)
|
||||
|
||||
# Log a few random samples from the training set:
|
||||
for index in random.sample(range(len(raw_datasets["train"])), 3):
|
||||
logger.info(f"Prompt sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['prompt']}")
|
||||
logger.info(f"Chosen sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['chosen']}")
|
||||
logger.info(f"Rejected sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['rejected']}")
|
||||
|
||||
torch_dtype = (
|
||||
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
|
||||
)
|
||||
quantization_config = get_quantization_config(model_args)
|
||||
|
||||
model_kwargs = dict(
|
||||
revision=model_args.model_revision,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
use_flash_attention_2=model_args.use_flash_attention_2,
|
||||
torch_dtype=torch_dtype,
|
||||
use_cache=False if training_args.gradient_checkpointing else True,
|
||||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
|
||||
model = model_args.model_name_or_path
|
||||
if is_adapter_model(model, model_args.model_revision) is True:
|
||||
logger.info(f"Loading SFT adapter for {model_args.model_name_or_path=}")
|
||||
peft_config = PeftConfig.from_pretrained(model_args.model_name_or_path, revision=model_args.model_revision)
|
||||
model_kwargs = dict(
|
||||
revision=model_args.base_model_revision,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
use_flash_attention_2=model_args.use_flash_attention_2,
|
||||
torch_dtype=torch_dtype,
|
||||
use_cache=False if training_args.gradient_checkpointing else True,
|
||||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
base_model = AutoModelForCausalLM.from_pretrained(
|
||||
peft_config.base_model_name_or_path,
|
||||
**model_kwargs,
|
||||
)
|
||||
model = PeftModel.from_pretrained(
|
||||
base_model,
|
||||
model_args.model_name_or_path,
|
||||
revision=model_args.model_revision,
|
||||
)
|
||||
model_kwargs = None
|
||||
|
||||
ref_model = model
|
||||
ref_model_kwargs = model_kwargs
|
||||
|
||||
if model_args.use_peft is True:
|
||||
ref_model = None
|
||||
ref_model_kwargs = None
|
||||
|
||||
#########################
|
||||
# Instantiate SimPO trainer
|
||||
#########################
|
||||
trainer = SimPOTrainer(
|
||||
model=model,
|
||||
ref_model=ref_model, # pass in to bypass DPO Trainer check for ref model but is not actually used
|
||||
model_init_kwargs=model_kwargs,
|
||||
args=training_args,
|
||||
beta=training_args.beta,
|
||||
train_dataset=raw_datasets["train"],
|
||||
eval_dataset=raw_datasets["test"],
|
||||
tokenizer=tokenizer,
|
||||
max_length=training_args.max_length,
|
||||
max_prompt_length=training_args.max_prompt_length,
|
||||
peft_config=get_peft_config(model_args),
|
||||
loss_type=training_args.loss_type,
|
||||
)
|
||||
|
||||
###############
|
||||
# Training loop
|
||||
###############
|
||||
checkpoint = None
|
||||
if training_args.resume_from_checkpoint is not None:
|
||||
checkpoint = training_args.resume_from_checkpoint
|
||||
elif last_checkpoint is not None:
|
||||
checkpoint = last_checkpoint
|
||||
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||
metrics = train_result.metrics
|
||||
metrics["train_samples"] = len(raw_datasets["train"])
|
||||
trainer.log_metrics("train", metrics)
|
||||
trainer.save_metrics("train", metrics)
|
||||
trainer.save_state()
|
||||
|
||||
logger.info("*** Training complete ***")
|
||||
|
||||
##################################
|
||||
# Save model and create model card
|
||||
##################################
|
||||
logger.info("*** Save model ***")
|
||||
trainer.save_model(training_args.output_dir)
|
||||
logger.info(f"Model saved to {training_args.output_dir}")
|
||||
|
||||
# Save everything else on main process
|
||||
kwargs = {
|
||||
"finetuned_from": model_args.model_name_or_path,
|
||||
"dataset": list(data_args.dataset_mixer.keys()),
|
||||
"dataset_tags": list(data_args.dataset_mixer.keys()),
|
||||
"tags": ["alignment-handbook"],
|
||||
}
|
||||
if trainer.accelerator.is_main_process:
|
||||
trainer.create_model_card(**kwargs)
|
||||
# Restore k,v cache for fast inference
|
||||
trainer.model.config.use_cache = True
|
||||
trainer.model.config.save_pretrained(training_args.output_dir)
|
||||
|
||||
##########
|
||||
# Evaluate
|
||||
##########
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
metrics = trainer.evaluate()
|
||||
metrics["eval_samples"] = len(raw_datasets["test"])
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
if training_args.push_to_hub is True:
|
||||
logger.info("Pushing to hub...")
|
||||
trainer.push_to_hub(**kwargs)
|
||||
|
||||
logger.info("*** Training complete! ***")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,134 @@
|
||||
|
||||
from trl import DPOTrainer
|
||||
import torch
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
|
||||
class SimPOTrainer(DPOTrainer):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs) # Pass all other arguments using **kwargs
|
||||
training_args = kwargs["args"]
|
||||
self.gamma = training_args.gamma
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def simpo_loss(
|
||||
self,
|
||||
policy_chosen_logps: torch.FloatTensor,
|
||||
policy_rejected_logps: torch.FloatTensor,
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
||||
"""Compute the SimPO loss for a batch of policy model log probabilities.
|
||||
|
||||
Args:
|
||||
policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
|
||||
policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
|
||||
|
||||
Returns:
|
||||
A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
|
||||
The losses tensor contains the SimPO loss for each example in the batch.
|
||||
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
|
||||
"""
|
||||
pi_logratios = policy_chosen_logps - policy_rejected_logps
|
||||
gamma_logratios = self.gamma / self.beta
|
||||
pi_logratios = pi_logratios.to(self.accelerator.device)
|
||||
logits = pi_logratios - gamma_logratios
|
||||
|
||||
if self.loss_type == "sigmoid":
|
||||
losses = (
|
||||
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
|
||||
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
|
||||
)
|
||||
elif self.loss_type == "hinge":
|
||||
losses = torch.relu(1 - self.beta * logits)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge']"
|
||||
)
|
||||
|
||||
chosen_rewards = self.beta * policy_chosen_logps.to(self.accelerator.device).detach()
|
||||
rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach()
|
||||
|
||||
return losses, chosen_rewards, rejected_rewards
|
||||
|
||||
def concatenated_forward(
|
||||
self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
||||
"""Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
|
||||
|
||||
We do this to avoid doing two forward passes, because it's faster for FSDP.
|
||||
"""
|
||||
concatenated_batch = self.concatenated_inputs(
|
||||
batch,
|
||||
is_encoder_decoder=self.is_encoder_decoder,
|
||||
label_pad_token_id=self.label_pad_token_id,
|
||||
padding_value=self.padding_value,
|
||||
device=self.accelerator.device,
|
||||
)
|
||||
len_chosen = batch["chosen_labels"].shape[0]
|
||||
|
||||
model_kwargs = (
|
||||
{
|
||||
"labels": concatenated_batch["concatenated_labels"],
|
||||
"decoder_input_ids": concatenated_batch.pop("concatenated_decoder_input_ids", None),
|
||||
}
|
||||
if self.is_encoder_decoder
|
||||
else {}
|
||||
)
|
||||
|
||||
all_logits = model(
|
||||
concatenated_batch["concatenated_input_ids"],
|
||||
attention_mask=concatenated_batch["concatenated_attention_mask"],
|
||||
use_cache=False,
|
||||
**model_kwargs,
|
||||
).logits
|
||||
|
||||
all_logps = self.get_batch_logps(
|
||||
all_logits,
|
||||
concatenated_batch["concatenated_labels"],
|
||||
average_log_prob=True,
|
||||
is_encoder_decoder=self.is_encoder_decoder,
|
||||
label_pad_token_id=self.label_pad_token_id,
|
||||
)
|
||||
|
||||
chosen_logps = all_logps[:len_chosen]
|
||||
rejected_logps = all_logps[len_chosen:]
|
||||
|
||||
chosen_logits = all_logits[:len_chosen]
|
||||
rejected_logits = all_logits[len_chosen:]
|
||||
|
||||
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
|
||||
|
||||
def get_batch_loss_metrics(
|
||||
self,
|
||||
model,
|
||||
batch: Dict[str, Union[List, torch.LongTensor]],
|
||||
train_eval: Literal["train", "eval"] = "train",
|
||||
):
|
||||
"""Compute the SimPO loss and other metrics for the given batch of inputs for train or test."""
|
||||
metrics = {}
|
||||
|
||||
(
|
||||
policy_chosen_logps,
|
||||
policy_rejected_logps,
|
||||
policy_chosen_logits,
|
||||
policy_rejected_logits,
|
||||
) = self.concatenated_forward(model, batch)
|
||||
|
||||
losses, chosen_rewards, rejected_rewards = self.simpo_loss(
|
||||
policy_chosen_logps,
|
||||
policy_rejected_logps
|
||||
)
|
||||
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
||||
|
||||
prefix = "eval_" if train_eval == "eval" else ""
|
||||
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
|
||||
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
|
||||
metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
|
||||
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
|
||||
metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
|
||||
metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
|
||||
metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
|
||||
metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
|
||||
|
||||
return losses.mean(), metrics
|
||||
@@ -0,0 +1,45 @@
|
||||
# Model arguments
|
||||
model_name_or_path: /scratch/gpfs/DANQIC/ym0081/checkpoints_new/llama-3-8b-sft-full-lr-2e-5
|
||||
torch_dtype: null
|
||||
use_flash_attention_2: true
|
||||
|
||||
# Data training arguments
|
||||
dataset_mixer:
|
||||
HuggingFaceH4/ultrafeedback_binarized: 1.0
|
||||
dataset_splits:
|
||||
- train_prefs
|
||||
- test_prefs
|
||||
preprocessing_num_workers: 12
|
||||
|
||||
# SimPOTrainer arguments
|
||||
bf16: true
|
||||
beta: 2.0
|
||||
gamma: 1.0
|
||||
do_eval: true
|
||||
evaluation_strategy: steps
|
||||
eval_steps: 400
|
||||
gradient_accumulation_steps: 8
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: False
|
||||
hub_model_id: simpo-exps
|
||||
learning_rate: 6.0e-7
|
||||
log_level: info
|
||||
logging_steps: 5
|
||||
lr_scheduler_type: cosine
|
||||
max_length: 2048
|
||||
max_prompt_length: 1800
|
||||
num_train_epochs: 1
|
||||
optim: adamw_torch
|
||||
output_dir: /scratch/gpfs/DANQIC/ym0081/checkpoints_new/llama-3-8b-base-simpo
|
||||
run_name: llama-3-8b-base-simpo
|
||||
per_device_train_batch_size: 2
|
||||
per_device_eval_batch_size: 4
|
||||
push_to_hub: false
|
||||
save_strategy: "steps"
|
||||
save_steps: 1000000
|
||||
report_to:
|
||||
- wandb
|
||||
save_total_limit: 20
|
||||
seed: 42
|
||||
warmup_ratio: 0.1
|
||||
@@ -0,0 +1,45 @@
|
||||
# Model arguments
|
||||
model_name_or_path: /scratch/gpfs/DANQIC/ym0081/hf_cache/Meta-Llama-3-8B-Instruct
|
||||
torch_dtype: null
|
||||
use_flash_attention_2: true
|
||||
|
||||
# Data training arguments
|
||||
dataset_mixer:
|
||||
/scratch/gpfs/DANQIC/ym0081/hf_cache/llama3-ultrafeedback: 1.0
|
||||
dataset_splits:
|
||||
- train
|
||||
- test
|
||||
preprocessing_num_workers: 12
|
||||
|
||||
# SimPOTrainer arguments
|
||||
bf16: true
|
||||
beta: 2.5
|
||||
gamma: 1.4
|
||||
do_eval: true
|
||||
evaluation_strategy: steps
|
||||
eval_steps: 400
|
||||
gradient_accumulation_steps: 8
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: False
|
||||
hub_model_id: simpo-exps
|
||||
learning_rate: 1.0e-6
|
||||
log_level: info
|
||||
logging_steps: 5
|
||||
lr_scheduler_type: cosine
|
||||
max_length: 2048
|
||||
max_prompt_length: 1800
|
||||
num_train_epochs: 1
|
||||
optim: adamw_torch
|
||||
output_dir: /scratch/gpfs/DANQIC/ym0081/checkpoints_new/llama-3-8b-instruct-simpo
|
||||
run_name: llama-3-8b-instruct-simpo
|
||||
per_device_train_batch_size: 2
|
||||
per_device_eval_batch_size: 4
|
||||
push_to_hub: false
|
||||
save_strategy: "steps"
|
||||
save_steps: 1000000
|
||||
report_to:
|
||||
- wandb
|
||||
save_total_limit: 20
|
||||
seed: 42
|
||||
warmup_ratio: 0.1
|
||||
@@ -0,0 +1,45 @@
|
||||
# Model arguments
|
||||
model_name_or_path: alignment-handbook/zephyr-7b-sft-full
|
||||
torch_dtype: null
|
||||
use_flash_attention_2: true
|
||||
|
||||
# Data training arguments
|
||||
dataset_mixer:
|
||||
HuggingFaceH4/ultrafeedback_binarized: 1.0
|
||||
dataset_splits:
|
||||
- train_prefs
|
||||
- test_prefs
|
||||
preprocessing_num_workers: 12
|
||||
|
||||
# SimPOTrainer arguments
|
||||
bf16: true
|
||||
beta: 2.0
|
||||
gamma: 1.6
|
||||
do_eval: true
|
||||
evaluation_strategy: steps
|
||||
eval_steps: 400
|
||||
gradient_accumulation_steps: 8
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: False
|
||||
hub_model_id: simpo-exps
|
||||
learning_rate: 3.0e-7
|
||||
log_level: info
|
||||
logging_steps: 5
|
||||
lr_scheduler_type: cosine
|
||||
max_length: 1024
|
||||
max_prompt_length: 512
|
||||
num_train_epochs: 1
|
||||
optim: adamw_torch
|
||||
output_dir: /scratch/gpfs/DANQIC/ym0081/checkpoints_new/mistral-7b-base-simpo
|
||||
run_name: mistral-7b-base-simpo
|
||||
per_device_train_batch_size: 2
|
||||
per_device_eval_batch_size: 4
|
||||
push_to_hub: false
|
||||
save_strategy: "steps"
|
||||
save_steps: 1000000
|
||||
report_to:
|
||||
- wandb
|
||||
save_total_limit: 20
|
||||
seed: 42
|
||||
warmup_ratio: 0.1
|
||||
@@ -0,0 +1,45 @@
|
||||
# Model arguments
|
||||
model_name_or_path: /scratch/gpfs/DANQIC/ym0081/hf_cache/Mistral-7B-Instruct-v0.2
|
||||
torch_dtype: null
|
||||
use_flash_attention_2: true
|
||||
|
||||
# Data training arguments
|
||||
dataset_mixer:
|
||||
snorkelai/Snorkel-Mistral-PairRM-DPO-Dataset: 1.0
|
||||
dataset_splits:
|
||||
- train
|
||||
- test
|
||||
preprocessing_num_workers: 12
|
||||
|
||||
# SimPOTrainer arguments
|
||||
bf16: true
|
||||
beta: 2.5
|
||||
gamma: 0.3
|
||||
do_eval: true
|
||||
evaluation_strategy: steps
|
||||
eval_steps: 400
|
||||
gradient_accumulation_steps: 8
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: False
|
||||
hub_model_id: simpo-exps
|
||||
learning_rate: 5.0e-7
|
||||
log_level: info
|
||||
logging_steps: 5
|
||||
lr_scheduler_type: cosine
|
||||
max_length: 2048
|
||||
max_prompt_length: 1800
|
||||
num_train_epochs: 1
|
||||
optim: adamw_torch
|
||||
output_dir: /scratch/gpfs/DANQIC/ym0081/checkpoints_new/mistral-7b-instruct-simpo
|
||||
run_name: mistral-7b-instruct-simpo
|
||||
per_device_train_batch_size: 2
|
||||
per_device_eval_batch_size: 4
|
||||
push_to_hub: false
|
||||
save_strategy: "steps"
|
||||
save_steps: 1000000
|
||||
report_to:
|
||||
- wandb
|
||||
save_total_limit: 20
|
||||
seed: 42
|
||||
warmup_ratio: 0.1
|
||||
Reference in New Issue
Block a user