This commit is contained in:
Yu Meng
2024-05-23 22:53:13 -04:00
commit 586e5e3d0a
12 changed files with 964 additions and 0 deletions
+164
View File
@@ -0,0 +1,164 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
# Temp folders
data/
wandb/
+88
View File
@@ -0,0 +1,88 @@
# Simple Preference Optimization (SimPO)
This repository contains the code and released models for our paper [SimPO: Simple Preference Optimization with a Reference-Free Reward](https://arxiv.org/abs/2405.14734). We propose a simpler and more effective preference optimization algorithm than DPO (Direct Preference Optimization) without using a reference model. SimPO outperforms DPO and its latest variants across AlpacaEval 2, MT-Bench, and Arena-Hard benchmarks under various settings.
<img src="./SimPO.png" width="1000px"></img>
## 🔗 Quick Links
- [SimPO: Simple Preference Optimization with a Reference-Free Reward](#simple-preference-optimization-simpo)
- [Released Models](#released-models)
- [Install Requirements](#install-requirements)
- [Training scripts](#training-scripts)
- [Evaluation](#evaluation)
- [Bugs or Questions?](#bugs-or-questions)
- [Citation](#citation)
## Released Models
## Install Requirements
Our codebase is built upon the [alignment-handbook repo](https://github.com/huggingface/alignment-handbook). The following steps will guide you through the installation process.
First, create a Python virtual environment using e.g. Conda:
```shell
conda create -n handbook python=3.10 && conda activate handbook
```
Next, install PyTorch `v2.2.2`. Since this is hardware-dependent, we
direct you to the [PyTorch Installation Page](https://pytorch.org/get-started/locally/).
You can then install the remaining package dependencies of [alignment-handbook](https://github.com/huggingface/alignment-handbook) as follows:
```shell
git clone https://github.com/huggingface/alignment-handbook.git
cd ./alignment-handbook/
python -m pip install .
```
You will also need Flash Attention 2 installed, which can be done by running:
```shell
python -m pip install flash-attn --no-build-isolation
```
## Training Scripts
We provide four training config files for the four training setups reported in our paper:
* Mistral-Base:
```shell
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml scripts/run_simpo.py training_configs/mistral-7b-base-simpo.yaml
```
* Mistral-Instruct:
```shell
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml scripts/run_simpo.py training_configs/mistral-7b-instruct-simpo.yaml
```
* Llama3-Base:
```shell
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml scripts/run_simpo.py training_configs/llama-3-8b-base-simpo.yaml
```
* Llama3-Instruct:
```shell
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml scripts/run_simpo.py training_configs/llama-3-8b-instruct-simpo.yaml
```
## Evaluation
We follow the official implementation for evaluation on AlpacaEval 2, Arena-Hard, and MT-Bench, as follows:
* AlpacaEval 2: Please refer to the [AlpacaEval repo](https://github.com/tatsu-lab/alpaca_eval) for evaluation.
* Arena-Hard: Please refer to to the [Arena-Hard-Auto repo](https://github.com/lm-sys/arena-hard-auto) for evaluation.
* MT-Bench: Please refer to the [FastChat repo](https://github.com/lm-sys/FastChat) for evaluation.
## Bugs or Questions?
If you have any questions related to the code or the paper, feel free to email Yu (yumeng5@virginia.edu). If you encounter any problems when using the code, or want to report a bug, feel free to open an issue! Please try to specify the problem with details so we can help you better and quicker!
## Citation
Please cite our paper if you find the repo helpful in your work:
```bibtex
@article{meng2024simpo,
title={{SimPO}: Simple Preference Optimization with a Reference-Free Reward},
author={Meng, Yu and Xia, Mengzhou and Chen, Danqi},
year={2024}
}
```
BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 MiB

+22
View File
@@ -0,0 +1,22 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
deepspeed_multinode_launcher: standard
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: true
zero3_save_16bit_model: true
zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
+26
View File
@@ -0,0 +1,26 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_backward_prefetch: BACKWARD_PRE
fsdp_cpu_ram_efficient_loading: true
fsdp_forward_prefetch: true
fsdp_offload_params: false
fsdp_sharding_strategy: FULL_SHARD
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_sync_module_states: true
fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
+16
View File
@@ -0,0 +1,16 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
+334
View File
@@ -0,0 +1,334 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
import sys
import torch
import transformers
from transformers import AutoModelForCausalLM, set_seed
from alignment import (
DataArguments,
DPOConfig,
H4ArgumentParser,
ModelArguments,
maybe_insert_system_message,
is_openai_format,
get_checkpoint,
get_datasets,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
get_tokenizer,
is_adapter_model,
)
from peft import PeftConfig, PeftModel
from simpo_trainer import SimPOTrainer
from dataclasses import dataclass, field
from typing import Optional, Literal
logger = logging.getLogger(__name__)
MISTRAL_CHAT_TEMPLATE = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'].strip() + '\n\n' %}{% else %}{% set loop_messages = messages %}{% set system_message = '' %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{% set content = system_message + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"
@dataclass
class SimPOConfig(DPOConfig):
gamma: Optional[float] = field(
default=0.5,
metadata={"help": "The target reward margin term in SimPO loss."},
)
def apply_chat_template(
example,
tokenizer,
task: Literal["sft", "generation", "rm", "simpo"],
auto_insert_empty_system_msg: bool = True,
change_template = None,
):
if change_template == "mistral":
tokenizer.chat_template = MISTRAL_CHAT_TEMPLATE
if task in ["sft", "generation"]:
messages = example["messages"]
# We add an empty system message if there is none
if auto_insert_empty_system_msg:
maybe_insert_system_message(messages, tokenizer)
example["text"] = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True if task == "generation" else False,
)
elif task == "rm":
if all(k in example.keys() for k in ("chosen", "rejected")):
chosen_messages = example["chosen"]
rejected_messages = example["rejected"]
# We add an empty system message if there is none
if auto_insert_empty_system_msg:
maybe_insert_system_message(chosen_messages, tokenizer)
maybe_insert_system_message(rejected_messages, tokenizer)
example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
else:
raise ValueError(
f"Could not format example as dialogue for `rm` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
)
elif task == "simpo":
if all(k in example.keys() for k in ("chosen", "rejected")):
if not is_openai_format(example["chosen"]) or not is_openai_format(example["rejected"]):
raise ValueError(
f"Could not format example as dialogue for `{task}` task! Require OpenAI format for all messages"
)
# For DPO/ORPO, the inputs are triples of (prompt, chosen, rejected), where `chosen` and `rejected` are the final turn of a dialogue
# We therefore need to extract the N-1 turns to form the prompt
if "prompt" in example and is_openai_format(example["prompt"]):
prompt_messages = example["prompt"]
chosen_messages = example["chosen"]
rejected_messages = example["rejected"]
else:
prompt_messages = example["chosen"][:-1]
# Now we extract the final turn to define chosen/rejected responses
chosen_messages = example["chosen"][-1:]
rejected_messages = example["rejected"][-1:]
# Prepend a system message if the first message is not a system message
if auto_insert_empty_system_msg:
maybe_insert_system_message(prompt_messages, tokenizer)
example["text_prompt"] = tokenizer.apply_chat_template(prompt_messages, tokenize=False)
example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
if example["text_chosen"].startswith(tokenizer.bos_token):
example["text_chosen"] = example["text_chosen"][len(tokenizer.bos_token):]
example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
if example["text_rejected"].startswith(tokenizer.bos_token):
example["text_rejected"] = example["text_rejected"][len(tokenizer.bos_token):]
else:
raise ValueError(
f"Could not format example as dialogue for `{task}` task! Require either the "
f"`[chosen, rejected]` or `[prompt, chosen, rejected]` keys but found {list(example.keys())}"
)
else:
raise ValueError(
f"Task {task} not supported, please ensure that the provided task is one of ['sft', 'generation', 'rm', 'dpo', 'orpo']"
)
return example
def main():
parser = H4ArgumentParser((ModelArguments, DataArguments, SimPOConfig))
model_args, data_args, training_args = parser.parse()
#######
# Setup
#######
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
# Log on each process the small summary:
logger.info(f"Model parameters {model_args}")
logger.info(f"Data parameters {data_args}")
logger.info(f"Training/evaluation parameters {training_args}")
# Check for last checkpoint
last_checkpoint = get_checkpoint(training_args)
if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
# Set seed for reproducibility
set_seed(training_args.seed)
###############
# Load datasets
###############
raw_datasets = get_datasets(
data_args,
splits=data_args.dataset_splits,
configs=data_args.dataset_configs,
columns_to_keep=["messages", "chosen", "rejected", "prompt", "completion", "label"],
seed=training_args.seed,
)
logger.info(
f"Training on the following splits: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
)
column_names = list(raw_datasets["train"].features)
#####################################
# Load tokenizer and process datasets
#####################################
data_args.truncation_side = "left" # Truncate from left to ensure we don't lose labels in final turn
tokenizer = get_tokenizer(model_args, data_args)
if "mistral" in model_args.model_name_or_path.lower():
change_template = "mistral"
else:
change_template = None
#####################
# Apply chat template
#####################
raw_datasets = raw_datasets.map(
apply_chat_template,
fn_kwargs={
"tokenizer": tokenizer,
"task": "simpo",
"auto_insert_empty_system_msg": data_args.auto_insert_empty_system_msg,
"change_template": change_template,
},
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
desc="Formatting comparisons with prompt template",
)
# Replace column names with what TRL needs, text_chosen -> chosen and text_rejected -> rejected
for split in ["train", "test"]:
raw_datasets[split] = raw_datasets[split].rename_columns(
{"text_prompt": "prompt", "text_chosen": "chosen", "text_rejected": "rejected"}
)
# Log a few random samples from the training set:
for index in random.sample(range(len(raw_datasets["train"])), 3):
logger.info(f"Prompt sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['prompt']}")
logger.info(f"Chosen sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['chosen']}")
logger.info(f"Rejected sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['rejected']}")
torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
use_flash_attention_2=model_args.use_flash_attention_2,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
model = model_args.model_name_or_path
if is_adapter_model(model, model_args.model_revision) is True:
logger.info(f"Loading SFT adapter for {model_args.model_name_or_path=}")
peft_config = PeftConfig.from_pretrained(model_args.model_name_or_path, revision=model_args.model_revision)
model_kwargs = dict(
revision=model_args.base_model_revision,
trust_remote_code=model_args.trust_remote_code,
use_flash_attention_2=model_args.use_flash_attention_2,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
base_model = AutoModelForCausalLM.from_pretrained(
peft_config.base_model_name_or_path,
**model_kwargs,
)
model = PeftModel.from_pretrained(
base_model,
model_args.model_name_or_path,
revision=model_args.model_revision,
)
model_kwargs = None
ref_model = model
ref_model_kwargs = model_kwargs
if model_args.use_peft is True:
ref_model = None
ref_model_kwargs = None
#########################
# Instantiate SimPO trainer
#########################
trainer = SimPOTrainer(
model=model,
ref_model=ref_model, # pass in to bypass DPO Trainer check for ref model but is not actually used
model_init_kwargs=model_kwargs,
args=training_args,
beta=training_args.beta,
train_dataset=raw_datasets["train"],
eval_dataset=raw_datasets["test"],
tokenizer=tokenizer,
max_length=training_args.max_length,
max_prompt_length=training_args.max_prompt_length,
peft_config=get_peft_config(model_args),
loss_type=training_args.loss_type,
)
###############
# Training loop
###############
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)
metrics = train_result.metrics
metrics["train_samples"] = len(raw_datasets["train"])
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
logger.info("*** Training complete ***")
##################################
# Save model and create model card
##################################
logger.info("*** Save model ***")
trainer.save_model(training_args.output_dir)
logger.info(f"Model saved to {training_args.output_dir}")
# Save everything else on main process
kwargs = {
"finetuned_from": model_args.model_name_or_path,
"dataset": list(data_args.dataset_mixer.keys()),
"dataset_tags": list(data_args.dataset_mixer.keys()),
"tags": ["alignment-handbook"],
}
if trainer.accelerator.is_main_process:
trainer.create_model_card(**kwargs)
# Restore k,v cache for fast inference
trainer.model.config.use_cache = True
trainer.model.config.save_pretrained(training_args.output_dir)
##########
# Evaluate
##########
if training_args.do_eval:
logger.info("*** Evaluate ***")
metrics = trainer.evaluate()
metrics["eval_samples"] = len(raw_datasets["test"])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
if training_args.push_to_hub is True:
logger.info("Pushing to hub...")
trainer.push_to_hub(**kwargs)
logger.info("*** Training complete! ***")
if __name__ == "__main__":
main()
+134
View File
@@ -0,0 +1,134 @@
from trl import DPOTrainer
import torch
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import torch.nn.functional as F
import torch.nn as nn
class SimPOTrainer(DPOTrainer):
def __init__(self, **kwargs):
super().__init__(**kwargs) # Pass all other arguments using **kwargs
training_args = kwargs["args"]
self.gamma = training_args.gamma
super().__init__(**kwargs)
def simpo_loss(
self,
policy_chosen_logps: torch.FloatTensor,
policy_rejected_logps: torch.FloatTensor,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
"""Compute the SimPO loss for a batch of policy model log probabilities.
Args:
policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
Returns:
A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
The losses tensor contains the SimPO loss for each example in the batch.
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
"""
pi_logratios = policy_chosen_logps - policy_rejected_logps
gamma_logratios = self.gamma / self.beta
pi_logratios = pi_logratios.to(self.accelerator.device)
logits = pi_logratios - gamma_logratios
if self.loss_type == "sigmoid":
losses = (
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
)
elif self.loss_type == "hinge":
losses = torch.relu(1 - self.beta * logits)
else:
raise ValueError(
f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge']"
)
chosen_rewards = self.beta * policy_chosen_logps.to(self.accelerator.device).detach()
rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach()
return losses, chosen_rewards, rejected_rewards
def concatenated_forward(
self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
"""Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
We do this to avoid doing two forward passes, because it's faster for FSDP.
"""
concatenated_batch = self.concatenated_inputs(
batch,
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
padding_value=self.padding_value,
device=self.accelerator.device,
)
len_chosen = batch["chosen_labels"].shape[0]
model_kwargs = (
{
"labels": concatenated_batch["concatenated_labels"],
"decoder_input_ids": concatenated_batch.pop("concatenated_decoder_input_ids", None),
}
if self.is_encoder_decoder
else {}
)
all_logits = model(
concatenated_batch["concatenated_input_ids"],
attention_mask=concatenated_batch["concatenated_attention_mask"],
use_cache=False,
**model_kwargs,
).logits
all_logps = self.get_batch_logps(
all_logits,
concatenated_batch["concatenated_labels"],
average_log_prob=True,
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)
chosen_logps = all_logps[:len_chosen]
rejected_logps = all_logps[len_chosen:]
chosen_logits = all_logits[:len_chosen]
rejected_logits = all_logits[len_chosen:]
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
def get_batch_loss_metrics(
self,
model,
batch: Dict[str, Union[List, torch.LongTensor]],
train_eval: Literal["train", "eval"] = "train",
):
"""Compute the SimPO loss and other metrics for the given batch of inputs for train or test."""
metrics = {}
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
) = self.concatenated_forward(model, batch)
losses, chosen_rewards, rejected_rewards = self.simpo_loss(
policy_chosen_logps,
policy_rejected_logps
)
reward_accuracies = (chosen_rewards > rejected_rewards).float()
prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
return losses.mean(), metrics
@@ -0,0 +1,45 @@
# Model arguments
model_name_or_path: /scratch/gpfs/DANQIC/ym0081/checkpoints_new/llama-3-8b-sft-full-lr-2e-5
torch_dtype: null
use_flash_attention_2: true
# Data training arguments
dataset_mixer:
HuggingFaceH4/ultrafeedback_binarized: 1.0
dataset_splits:
- train_prefs
- test_prefs
preprocessing_num_workers: 12
# SimPOTrainer arguments
bf16: true
beta: 2.0
gamma: 1.0
do_eval: true
evaluation_strategy: steps
eval_steps: 400
gradient_accumulation_steps: 8
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: False
hub_model_id: simpo-exps
learning_rate: 6.0e-7
log_level: info
logging_steps: 5
lr_scheduler_type: cosine
max_length: 2048
max_prompt_length: 1800
num_train_epochs: 1
optim: adamw_torch
output_dir: /scratch/gpfs/DANQIC/ym0081/checkpoints_new/llama-3-8b-base-simpo
run_name: llama-3-8b-base-simpo
per_device_train_batch_size: 2
per_device_eval_batch_size: 4
push_to_hub: false
save_strategy: "steps"
save_steps: 1000000
report_to:
- wandb
save_total_limit: 20
seed: 42
warmup_ratio: 0.1
@@ -0,0 +1,45 @@
# Model arguments
model_name_or_path: /scratch/gpfs/DANQIC/ym0081/hf_cache/Meta-Llama-3-8B-Instruct
torch_dtype: null
use_flash_attention_2: true
# Data training arguments
dataset_mixer:
/scratch/gpfs/DANQIC/ym0081/hf_cache/llama3-ultrafeedback: 1.0
dataset_splits:
- train
- test
preprocessing_num_workers: 12
# SimPOTrainer arguments
bf16: true
beta: 2.5
gamma: 1.4
do_eval: true
evaluation_strategy: steps
eval_steps: 400
gradient_accumulation_steps: 8
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: False
hub_model_id: simpo-exps
learning_rate: 1.0e-6
log_level: info
logging_steps: 5
lr_scheduler_type: cosine
max_length: 2048
max_prompt_length: 1800
num_train_epochs: 1
optim: adamw_torch
output_dir: /scratch/gpfs/DANQIC/ym0081/checkpoints_new/llama-3-8b-instruct-simpo
run_name: llama-3-8b-instruct-simpo
per_device_train_batch_size: 2
per_device_eval_batch_size: 4
push_to_hub: false
save_strategy: "steps"
save_steps: 1000000
report_to:
- wandb
save_total_limit: 20
seed: 42
warmup_ratio: 0.1
@@ -0,0 +1,45 @@
# Model arguments
model_name_or_path: alignment-handbook/zephyr-7b-sft-full
torch_dtype: null
use_flash_attention_2: true
# Data training arguments
dataset_mixer:
HuggingFaceH4/ultrafeedback_binarized: 1.0
dataset_splits:
- train_prefs
- test_prefs
preprocessing_num_workers: 12
# SimPOTrainer arguments
bf16: true
beta: 2.0
gamma: 1.6
do_eval: true
evaluation_strategy: steps
eval_steps: 400
gradient_accumulation_steps: 8
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: False
hub_model_id: simpo-exps
learning_rate: 3.0e-7
log_level: info
logging_steps: 5
lr_scheduler_type: cosine
max_length: 1024
max_prompt_length: 512
num_train_epochs: 1
optim: adamw_torch
output_dir: /scratch/gpfs/DANQIC/ym0081/checkpoints_new/mistral-7b-base-simpo
run_name: mistral-7b-base-simpo
per_device_train_batch_size: 2
per_device_eval_batch_size: 4
push_to_hub: false
save_strategy: "steps"
save_steps: 1000000
report_to:
- wandb
save_total_limit: 20
seed: 42
warmup_ratio: 0.1
@@ -0,0 +1,45 @@
# Model arguments
model_name_or_path: /scratch/gpfs/DANQIC/ym0081/hf_cache/Mistral-7B-Instruct-v0.2
torch_dtype: null
use_flash_attention_2: true
# Data training arguments
dataset_mixer:
snorkelai/Snorkel-Mistral-PairRM-DPO-Dataset: 1.0
dataset_splits:
- train
- test
preprocessing_num_workers: 12
# SimPOTrainer arguments
bf16: true
beta: 2.5
gamma: 0.3
do_eval: true
evaluation_strategy: steps
eval_steps: 400
gradient_accumulation_steps: 8
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: False
hub_model_id: simpo-exps
learning_rate: 5.0e-7
log_level: info
logging_steps: 5
lr_scheduler_type: cosine
max_length: 2048
max_prompt_length: 1800
num_train_epochs: 1
optim: adamw_torch
output_dir: /scratch/gpfs/DANQIC/ym0081/checkpoints_new/mistral-7b-instruct-simpo
run_name: mistral-7b-instruct-simpo
per_device_train_batch_size: 2
per_device_eval_batch_size: 4
push_to_hub: false
save_strategy: "steps"
save_steps: 1000000
report_to:
- wandb
save_total_limit: 20
seed: 42
warmup_ratio: 0.1