mirror of
https://github.com/wassname/SimPO.git
synced 2026-06-27 18:41:04 +08:00
134 lines
5.5 KiB
Python
134 lines
5.5 KiB
Python
|
|
from trl import DPOTrainer
|
|
import torch
|
|
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
|
|
import torch.nn.functional as F
|
|
import torch.nn as nn
|
|
|
|
class SimPOTrainer(DPOTrainer):
    """DPOTrainer subclass that trains with the SimPO preference loss.

    The loss is computed from the policy model's log probabilities alone
    (no reference-model log probabilities are used anywhere in this class)
    and subtracts a target margin ``gamma`` from the scaled chosen/rejected
    log-probability gap.

    The class overrides three pieces of the parent trainer: the loss itself
    (``simpo_loss``), the joint chosen/rejected forward pass
    (``concatenated_forward``) and the per-batch loss/metric computation
    (``get_batch_loss_metrics``). ``self.beta``, ``self.loss_type``,
    ``self.label_smoothing``, ``self.accelerator`` and the padding-related
    attributes are not assigned here; they are presumably provided by
    ``DPOTrainer`` -- verify against the installed TRL version.
    """

    def __init__(self, **kwargs):
        """Initialize the parent trainer and read the SimPO margin ``gamma``.

        Args:
            **kwargs: Forwarded unchanged to ``DPOTrainer.__init__``. Must
                include an ``args`` entry whose ``gamma`` attribute holds the
                target reward margin.

        Raises:
            KeyError: If ``args`` was not supplied as a keyword argument.
            AttributeError: If ``args`` has no ``gamma`` attribute.
        """
        # The parent constructor is invoked exactly once. Calling it a second
        # time after setting ``gamma`` would only repeat whatever setup the
        # parent performs without changing the resulting state used here.
        super().__init__(**kwargs)  # Pass all other arguments using **kwargs
        training_args = kwargs["args"]
        self.gamma = training_args.gamma

    def simpo_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Compute the SimPO loss for a batch of policy model log probabilities.

        Args:
            policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
            policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)

        Returns:
            A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
            The losses tensor contains the SimPO loss for each example in the batch.
            The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.

        Raises:
            ValueError: If ``self.loss_type`` is neither ``"sigmoid"`` nor ``"hinge"``.
        """
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        # Expressing the margin as gamma / beta means that, once multiplied by
        # beta below, the loss argument is beta * (chosen - rejected) - gamma.
        gamma_logratios = self.gamma / self.beta
        pi_logratios = pi_logratios.to(self.accelerator.device)
        logits = pi_logratios - gamma_logratios

        if self.loss_type == "sigmoid":
            # Label-smoothed logistic loss: weight (1 - label_smoothing) on the
            # given preference direction and label_smoothing on the flipped one.
            losses = (
                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
                - F.logsigmoid(-self.beta * logits) * self.label_smoothing
            )
        elif self.loss_type == "hinge":
            losses = torch.relu(1 - self.beta * logits)
        else:
            raise ValueError(
                f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge']"
            )

        # Rewards are reported for logging only, so they are detached from
        # the autograd graph.
        chosen_rewards = self.beta * policy_chosen_logps.to(self.accelerator.device).detach()
        rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach()

        return losses, chosen_rewards, rejected_rewards

    def concatenated_forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

        We do this to avoid doing two forward passes, because it's faster for FSDP.

        Args:
            model: Model to run; its output must expose a ``logits`` attribute.
            batch: Batch dict. ``batch["chosen_labels"]`` is read directly here;
                the remaining keys are consumed by the inherited
                ``concatenated_inputs`` helper.

        Returns:
            A tuple ``(chosen_logps, rejected_logps, chosen_logits, rejected_logits)``
            obtained by splitting the concatenated outputs at the number of
            chosen examples.
        """
        concatenated_batch = self.concatenated_inputs(
            batch,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
            padding_value=self.padding_value,
            device=self.accelerator.device,
        )
        # Size of the first dimension of the chosen labels; used below as the
        # split point (assumes the helper places chosen rows before rejected
        # rows -- TODO confirm against the parent implementation).
        len_chosen = batch["chosen_labels"].shape[0]

        # Extra keyword arguments are only passed for encoder-decoder models.
        model_kwargs = (
            {
                "labels": concatenated_batch["concatenated_labels"],
                "decoder_input_ids": concatenated_batch.pop("concatenated_decoder_input_ids", None),
            }
            if self.is_encoder_decoder
            else {}
        )

        all_logits = model(
            concatenated_batch["concatenated_input_ids"],
            attention_mask=concatenated_batch["concatenated_attention_mask"],
            use_cache=False,
            **model_kwargs,
        ).logits

        # average_log_prob=True asks the inherited helper for averaged rather
        # than summed log probabilities.
        all_logps = self.get_batch_logps(
            all_logits,
            concatenated_batch["concatenated_labels"],
            average_log_prob=True,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
        )

        chosen_logps = all_logps[:len_chosen]
        rejected_logps = all_logps[len_chosen:]

        chosen_logits = all_logits[:len_chosen]
        rejected_logits = all_logits[len_chosen:]

        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)

    def get_batch_loss_metrics(
        self,
        model,
        batch: Dict[str, Union[List, torch.LongTensor]],
        train_eval: Literal["train", "eval"] = "train",
    ):
        """Compute the SimPO loss and other metrics for the given batch of inputs for train or test.

        Args:
            model: Model passed through to ``concatenated_forward``.
            batch: Batch dict passed through to ``concatenated_forward``.
            train_eval: ``"eval"`` prefixes every metric key with ``eval_``;
                any other value leaves the keys unprefixed.

        Returns:
            A tuple ``(loss, metrics)`` where ``loss`` is the mean of the
            per-example losses and ``metrics`` maps metric names to scalar
            tensors moved to the CPU.
        """
        metrics = {}

        (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits,
            policy_rejected_logits,
        ) = self.concatenated_forward(model, batch)

        losses, chosen_rewards, rejected_rewards = self.simpo_loss(
            policy_chosen_logps,
            policy_rejected_logps,
        )
        # An example counts as correct when the chosen reward strictly exceeds
        # the rejected reward.
        reward_accuracies = (chosen_rewards > rejected_rewards).float()

        prefix = "eval_" if train_eval == "eval" else ""
        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()

        return losses.mean(), metrics