mirror of
https://github.com/wassname/SimPO.git
synced 2026-06-27 18:41:04 +08:00
release
This commit is contained in:
@@ -0,0 +1,334 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import random
|
||||
import sys
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import AutoModelForCausalLM, set_seed
|
||||
|
||||
from alignment import (
|
||||
DataArguments,
|
||||
DPOConfig,
|
||||
H4ArgumentParser,
|
||||
ModelArguments,
|
||||
maybe_insert_system_message,
|
||||
is_openai_format,
|
||||
get_checkpoint,
|
||||
get_datasets,
|
||||
get_kbit_device_map,
|
||||
get_peft_config,
|
||||
get_quantization_config,
|
||||
get_tokenizer,
|
||||
is_adapter_model,
|
||||
)
|
||||
from peft import PeftConfig, PeftModel
|
||||
from simpo_trainer import SimPOTrainer
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Literal
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MISTRAL_CHAT_TEMPLATE = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'].strip() + '\n\n' %}{% else %}{% set loop_messages = messages %}{% set system_message = '' %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{% set content = system_message + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"
|
||||
|
||||
@dataclass
|
||||
class SimPOConfig(DPOConfig):
|
||||
gamma: Optional[float] = field(
|
||||
default=0.5,
|
||||
metadata={"help": "The target reward margin term in SimPO loss."},
|
||||
)
|
||||
|
||||
def apply_chat_template(
|
||||
example,
|
||||
tokenizer,
|
||||
task: Literal["sft", "generation", "rm", "simpo"],
|
||||
auto_insert_empty_system_msg: bool = True,
|
||||
change_template = None,
|
||||
):
|
||||
if change_template == "mistral":
|
||||
tokenizer.chat_template = MISTRAL_CHAT_TEMPLATE
|
||||
if task in ["sft", "generation"]:
|
||||
messages = example["messages"]
|
||||
# We add an empty system message if there is none
|
||||
if auto_insert_empty_system_msg:
|
||||
maybe_insert_system_message(messages, tokenizer)
|
||||
example["text"] = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True if task == "generation" else False,
|
||||
)
|
||||
elif task == "rm":
|
||||
if all(k in example.keys() for k in ("chosen", "rejected")):
|
||||
chosen_messages = example["chosen"]
|
||||
rejected_messages = example["rejected"]
|
||||
# We add an empty system message if there is none
|
||||
if auto_insert_empty_system_msg:
|
||||
maybe_insert_system_message(chosen_messages, tokenizer)
|
||||
maybe_insert_system_message(rejected_messages, tokenizer)
|
||||
|
||||
example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
|
||||
example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Could not format example as dialogue for `rm` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
|
||||
)
|
||||
elif task == "simpo":
|
||||
if all(k in example.keys() for k in ("chosen", "rejected")):
|
||||
if not is_openai_format(example["chosen"]) or not is_openai_format(example["rejected"]):
|
||||
raise ValueError(
|
||||
f"Could not format example as dialogue for `{task}` task! Require OpenAI format for all messages"
|
||||
)
|
||||
|
||||
# For DPO/ORPO, the inputs are triples of (prompt, chosen, rejected), where `chosen` and `rejected` are the final turn of a dialogue
|
||||
# We therefore need to extract the N-1 turns to form the prompt
|
||||
if "prompt" in example and is_openai_format(example["prompt"]):
|
||||
prompt_messages = example["prompt"]
|
||||
chosen_messages = example["chosen"]
|
||||
rejected_messages = example["rejected"]
|
||||
else:
|
||||
prompt_messages = example["chosen"][:-1]
|
||||
# Now we extract the final turn to define chosen/rejected responses
|
||||
chosen_messages = example["chosen"][-1:]
|
||||
rejected_messages = example["rejected"][-1:]
|
||||
|
||||
# Prepend a system message if the first message is not a system message
|
||||
if auto_insert_empty_system_msg:
|
||||
maybe_insert_system_message(prompt_messages, tokenizer)
|
||||
|
||||
example["text_prompt"] = tokenizer.apply_chat_template(prompt_messages, tokenize=False)
|
||||
example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
|
||||
if example["text_chosen"].startswith(tokenizer.bos_token):
|
||||
example["text_chosen"] = example["text_chosen"][len(tokenizer.bos_token):]
|
||||
example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
|
||||
if example["text_rejected"].startswith(tokenizer.bos_token):
|
||||
example["text_rejected"] = example["text_rejected"][len(tokenizer.bos_token):]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Could not format example as dialogue for `{task}` task! Require either the "
|
||||
f"`[chosen, rejected]` or `[prompt, chosen, rejected]` keys but found {list(example.keys())}"
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Task {task} not supported, please ensure that the provided task is one of ['sft', 'generation', 'rm', 'dpo', 'orpo']"
|
||||
)
|
||||
return example
|
||||
|
||||
|
||||
def main():
|
||||
parser = H4ArgumentParser((ModelArguments, DataArguments, SimPOConfig))
|
||||
model_args, data_args, training_args = parser.parse()
|
||||
|
||||
#######
|
||||
# Setup
|
||||
#######
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
log_level = training_args.get_process_log_level()
|
||||
logger.setLevel(log_level)
|
||||
transformers.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
# Log on each process the small summary:
|
||||
logger.info(f"Model parameters {model_args}")
|
||||
logger.info(f"Data parameters {data_args}")
|
||||
logger.info(f"Training/evaluation parameters {training_args}")
|
||||
|
||||
# Check for last checkpoint
|
||||
last_checkpoint = get_checkpoint(training_args)
|
||||
if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
|
||||
logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
|
||||
|
||||
# Set seed for reproducibility
|
||||
set_seed(training_args.seed)
|
||||
|
||||
###############
|
||||
# Load datasets
|
||||
###############
|
||||
raw_datasets = get_datasets(
|
||||
data_args,
|
||||
splits=data_args.dataset_splits,
|
||||
configs=data_args.dataset_configs,
|
||||
columns_to_keep=["messages", "chosen", "rejected", "prompt", "completion", "label"],
|
||||
seed=training_args.seed,
|
||||
)
|
||||
logger.info(
|
||||
f"Training on the following splits: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
|
||||
)
|
||||
column_names = list(raw_datasets["train"].features)
|
||||
|
||||
#####################################
|
||||
# Load tokenizer and process datasets
|
||||
#####################################
|
||||
data_args.truncation_side = "left" # Truncate from left to ensure we don't lose labels in final turn
|
||||
tokenizer = get_tokenizer(model_args, data_args)
|
||||
|
||||
if "mistral" in model_args.model_name_or_path.lower():
|
||||
change_template = "mistral"
|
||||
else:
|
||||
change_template = None
|
||||
#####################
|
||||
# Apply chat template
|
||||
#####################
|
||||
raw_datasets = raw_datasets.map(
|
||||
apply_chat_template,
|
||||
fn_kwargs={
|
||||
"tokenizer": tokenizer,
|
||||
"task": "simpo",
|
||||
"auto_insert_empty_system_msg": data_args.auto_insert_empty_system_msg,
|
||||
"change_template": change_template,
|
||||
},
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
desc="Formatting comparisons with prompt template",
|
||||
)
|
||||
|
||||
# Replace column names with what TRL needs, text_chosen -> chosen and text_rejected -> rejected
|
||||
for split in ["train", "test"]:
|
||||
raw_datasets[split] = raw_datasets[split].rename_columns(
|
||||
{"text_prompt": "prompt", "text_chosen": "chosen", "text_rejected": "rejected"}
|
||||
)
|
||||
|
||||
# Log a few random samples from the training set:
|
||||
for index in random.sample(range(len(raw_datasets["train"])), 3):
|
||||
logger.info(f"Prompt sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['prompt']}")
|
||||
logger.info(f"Chosen sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['chosen']}")
|
||||
logger.info(f"Rejected sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['rejected']}")
|
||||
|
||||
torch_dtype = (
|
||||
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
|
||||
)
|
||||
quantization_config = get_quantization_config(model_args)
|
||||
|
||||
model_kwargs = dict(
|
||||
revision=model_args.model_revision,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
use_flash_attention_2=model_args.use_flash_attention_2,
|
||||
torch_dtype=torch_dtype,
|
||||
use_cache=False if training_args.gradient_checkpointing else True,
|
||||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
|
||||
model = model_args.model_name_or_path
|
||||
if is_adapter_model(model, model_args.model_revision) is True:
|
||||
logger.info(f"Loading SFT adapter for {model_args.model_name_or_path=}")
|
||||
peft_config = PeftConfig.from_pretrained(model_args.model_name_or_path, revision=model_args.model_revision)
|
||||
model_kwargs = dict(
|
||||
revision=model_args.base_model_revision,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
use_flash_attention_2=model_args.use_flash_attention_2,
|
||||
torch_dtype=torch_dtype,
|
||||
use_cache=False if training_args.gradient_checkpointing else True,
|
||||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
base_model = AutoModelForCausalLM.from_pretrained(
|
||||
peft_config.base_model_name_or_path,
|
||||
**model_kwargs,
|
||||
)
|
||||
model = PeftModel.from_pretrained(
|
||||
base_model,
|
||||
model_args.model_name_or_path,
|
||||
revision=model_args.model_revision,
|
||||
)
|
||||
model_kwargs = None
|
||||
|
||||
ref_model = model
|
||||
ref_model_kwargs = model_kwargs
|
||||
|
||||
if model_args.use_peft is True:
|
||||
ref_model = None
|
||||
ref_model_kwargs = None
|
||||
|
||||
#########################
|
||||
# Instantiate SimPO trainer
|
||||
#########################
|
||||
trainer = SimPOTrainer(
|
||||
model=model,
|
||||
ref_model=ref_model, # pass in to bypass DPO Trainer check for ref model but is not actually used
|
||||
model_init_kwargs=model_kwargs,
|
||||
args=training_args,
|
||||
beta=training_args.beta,
|
||||
train_dataset=raw_datasets["train"],
|
||||
eval_dataset=raw_datasets["test"],
|
||||
tokenizer=tokenizer,
|
||||
max_length=training_args.max_length,
|
||||
max_prompt_length=training_args.max_prompt_length,
|
||||
peft_config=get_peft_config(model_args),
|
||||
loss_type=training_args.loss_type,
|
||||
)
|
||||
|
||||
###############
|
||||
# Training loop
|
||||
###############
|
||||
checkpoint = None
|
||||
if training_args.resume_from_checkpoint is not None:
|
||||
checkpoint = training_args.resume_from_checkpoint
|
||||
elif last_checkpoint is not None:
|
||||
checkpoint = last_checkpoint
|
||||
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||
metrics = train_result.metrics
|
||||
metrics["train_samples"] = len(raw_datasets["train"])
|
||||
trainer.log_metrics("train", metrics)
|
||||
trainer.save_metrics("train", metrics)
|
||||
trainer.save_state()
|
||||
|
||||
logger.info("*** Training complete ***")
|
||||
|
||||
##################################
|
||||
# Save model and create model card
|
||||
##################################
|
||||
logger.info("*** Save model ***")
|
||||
trainer.save_model(training_args.output_dir)
|
||||
logger.info(f"Model saved to {training_args.output_dir}")
|
||||
|
||||
# Save everything else on main process
|
||||
kwargs = {
|
||||
"finetuned_from": model_args.model_name_or_path,
|
||||
"dataset": list(data_args.dataset_mixer.keys()),
|
||||
"dataset_tags": list(data_args.dataset_mixer.keys()),
|
||||
"tags": ["alignment-handbook"],
|
||||
}
|
||||
if trainer.accelerator.is_main_process:
|
||||
trainer.create_model_card(**kwargs)
|
||||
# Restore k,v cache for fast inference
|
||||
trainer.model.config.use_cache = True
|
||||
trainer.model.config.save_pretrained(training_args.output_dir)
|
||||
|
||||
##########
|
||||
# Evaluate
|
||||
##########
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
metrics = trainer.evaluate()
|
||||
metrics["eval_samples"] = len(raw_datasets["test"])
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
if training_args.push_to_hub is True:
|
||||
logger.info("Pushing to hub...")
|
||||
trainer.push_to_hub(**kwargs)
|
||||
|
||||
logger.info("*** Training complete! ***")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,134 @@
|
||||
|
||||
from trl import DPOTrainer
|
||||
import torch
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
|
||||
class SimPOTrainer(DPOTrainer):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs) # Pass all other arguments using **kwargs
|
||||
training_args = kwargs["args"]
|
||||
self.gamma = training_args.gamma
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def simpo_loss(
|
||||
self,
|
||||
policy_chosen_logps: torch.FloatTensor,
|
||||
policy_rejected_logps: torch.FloatTensor,
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
||||
"""Compute the SimPO loss for a batch of policy model log probabilities.
|
||||
|
||||
Args:
|
||||
policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
|
||||
policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
|
||||
|
||||
Returns:
|
||||
A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
|
||||
The losses tensor contains the SimPO loss for each example in the batch.
|
||||
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
|
||||
"""
|
||||
pi_logratios = policy_chosen_logps - policy_rejected_logps
|
||||
gamma_logratios = self.gamma / self.beta
|
||||
pi_logratios = pi_logratios.to(self.accelerator.device)
|
||||
logits = pi_logratios - gamma_logratios
|
||||
|
||||
if self.loss_type == "sigmoid":
|
||||
losses = (
|
||||
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
|
||||
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
|
||||
)
|
||||
elif self.loss_type == "hinge":
|
||||
losses = torch.relu(1 - self.beta * logits)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge']"
|
||||
)
|
||||
|
||||
chosen_rewards = self.beta * policy_chosen_logps.to(self.accelerator.device).detach()
|
||||
rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach()
|
||||
|
||||
return losses, chosen_rewards, rejected_rewards
|
||||
|
||||
def concatenated_forward(
|
||||
self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
||||
"""Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
|
||||
|
||||
We do this to avoid doing two forward passes, because it's faster for FSDP.
|
||||
"""
|
||||
concatenated_batch = self.concatenated_inputs(
|
||||
batch,
|
||||
is_encoder_decoder=self.is_encoder_decoder,
|
||||
label_pad_token_id=self.label_pad_token_id,
|
||||
padding_value=self.padding_value,
|
||||
device=self.accelerator.device,
|
||||
)
|
||||
len_chosen = batch["chosen_labels"].shape[0]
|
||||
|
||||
model_kwargs = (
|
||||
{
|
||||
"labels": concatenated_batch["concatenated_labels"],
|
||||
"decoder_input_ids": concatenated_batch.pop("concatenated_decoder_input_ids", None),
|
||||
}
|
||||
if self.is_encoder_decoder
|
||||
else {}
|
||||
)
|
||||
|
||||
all_logits = model(
|
||||
concatenated_batch["concatenated_input_ids"],
|
||||
attention_mask=concatenated_batch["concatenated_attention_mask"],
|
||||
use_cache=False,
|
||||
**model_kwargs,
|
||||
).logits
|
||||
|
||||
all_logps = self.get_batch_logps(
|
||||
all_logits,
|
||||
concatenated_batch["concatenated_labels"],
|
||||
average_log_prob=True,
|
||||
is_encoder_decoder=self.is_encoder_decoder,
|
||||
label_pad_token_id=self.label_pad_token_id,
|
||||
)
|
||||
|
||||
chosen_logps = all_logps[:len_chosen]
|
||||
rejected_logps = all_logps[len_chosen:]
|
||||
|
||||
chosen_logits = all_logits[:len_chosen]
|
||||
rejected_logits = all_logits[len_chosen:]
|
||||
|
||||
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
|
||||
|
||||
def get_batch_loss_metrics(
|
||||
self,
|
||||
model,
|
||||
batch: Dict[str, Union[List, torch.LongTensor]],
|
||||
train_eval: Literal["train", "eval"] = "train",
|
||||
):
|
||||
"""Compute the SimPO loss and other metrics for the given batch of inputs for train or test."""
|
||||
metrics = {}
|
||||
|
||||
(
|
||||
policy_chosen_logps,
|
||||
policy_rejected_logps,
|
||||
policy_chosen_logits,
|
||||
policy_rejected_logits,
|
||||
) = self.concatenated_forward(model, batch)
|
||||
|
||||
losses, chosen_rewards, rejected_rewards = self.simpo_loss(
|
||||
policy_chosen_logps,
|
||||
policy_rejected_logps
|
||||
)
|
||||
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
||||
|
||||
prefix = "eval_" if train_eval == "eval" else ""
|
||||
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
|
||||
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
|
||||
metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
|
||||
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
|
||||
metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
|
||||
metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
|
||||
metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
|
||||
metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
|
||||
|
||||
return losses.mean(), metrics
|
||||
Reference in New Issue
Block a user