mirror of
https://github.com/wassname/SimPO.git
synced 2026-06-27 18:57:43 +08:00
release
This commit is contained in:
@@ -0,0 +1,134 @@
|
||||
|
||||
from trl import DPOTrainer
|
||||
import torch
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
|
||||
class SimPOTrainer(DPOTrainer):
    """DPOTrainer subclass that trains with the SimPO objective.

    The loss is computed from the policy model's log probabilities only; no
    reference-model log probabilities enter ``simpo_loss``. The target reward
    margin ``gamma`` is read from the training arguments passed as ``args``.

    NOTE(review): ``self.beta``, ``self.loss_type``, ``self.label_smoothing``,
    ``self.accelerator``, ``self.concatenated_inputs`` and
    ``self.get_batch_logps`` are not defined in this class and are presumably
    provided by ``DPOTrainer`` -- confirm against the installed TRL version.
    """

    def __init__(self, **kwargs):
        """Initialise the parent trainer and store the SimPO ``gamma``.

        Args:
            **kwargs: Forwarded unchanged to ``DPOTrainer.__init__``. Must
                contain an ``args`` entry exposing a ``gamma`` attribute.

        Raises:
            KeyError: If ``args`` is not supplied as a keyword argument.
        """
        # The parent constructor is invoked exactly once. Calling it a second
        # time after setting ``gamma`` would only repeat the parent's setup.
        super().__init__(**kwargs)  # Pass all other arguments using **kwargs
        training_args = kwargs["args"]
        self.gamma = training_args.gamma

    def simpo_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Compute the SimPO loss for a batch of policy model log probabilities.

        Args:
            policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
            policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)

        Returns:
            A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
            The losses tensor contains the SimPO loss for each example in the batch.
            The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.

        Raises:
            ValueError: If ``self.loss_type`` is neither ``"sigmoid"`` nor ``"hinge"``.
        """
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        # Dividing gamma by beta here means that, once the logits are scaled
        # by beta below, the margin subtracted from the scaled log-ratio is
        # exactly gamma: beta * logits == beta * pi_logratios - gamma.
        gamma_logratios = self.gamma / self.beta
        pi_logratios = pi_logratios.to(self.accelerator.device)
        logits = pi_logratios - gamma_logratios

        if self.loss_type == "sigmoid":
            # Label-smoothed logistic loss: weight (1 - label_smoothing) on
            # the "chosen preferred" term and label_smoothing on the flipped
            # term.
            losses = (
                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
                - F.logsigmoid(-self.beta * logits) * self.label_smoothing
            )
        elif self.loss_type == "hinge":
            losses = torch.relu(1 - self.beta * logits)
        else:
            raise ValueError(
                f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge']"
            )

        # Rewards are reported for metrics only, so they are detached from
        # the autograd graph.
        chosen_rewards = self.beta * policy_chosen_logps.to(self.accelerator.device).detach()
        rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach()

        return losses, chosen_rewards, rejected_rewards

    def concatenated_forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

        We do this to avoid doing two forward passes, because it's faster for FSDP.

        Args:
            model: The model to run; its output must expose a ``logits`` attribute.
            batch: Batch dictionary; must contain ``"chosen_labels"`` and
                whatever ``self.concatenated_inputs`` requires.

        Returns:
            A tuple ``(chosen_logps, rejected_logps, chosen_logits, rejected_logits)``
            obtained by splitting the concatenated outputs at the number of
            chosen examples.
        """
        concatenated_batch = self.concatenated_inputs(
            batch,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
            padding_value=self.padding_value,
            device=self.accelerator.device,
        )
        # Split point between chosen and rejected rows in the concatenated
        # batch (assumes chosen rows come first -- TODO confirm against
        # ``concatenated_inputs``).
        len_chosen = batch["chosen_labels"].shape[0]

        # Encoder-decoder models additionally receive labels and optional
        # decoder input ids; decoder-only models get no extra kwargs.
        model_kwargs = (
            {
                "labels": concatenated_batch["concatenated_labels"],
                "decoder_input_ids": concatenated_batch.pop("concatenated_decoder_input_ids", None),
            }
            if self.is_encoder_decoder
            else {}
        )

        all_logits = model(
            concatenated_batch["concatenated_input_ids"],
            attention_mask=concatenated_batch["concatenated_attention_mask"],
            use_cache=False,
            **model_kwargs,
        ).logits

        # ``average_log_prob=True`` requests averaged log probabilities from
        # the inherited helper (presumably a per-token mean over non-padded
        # label positions -- verify against the parent implementation).
        all_logps = self.get_batch_logps(
            all_logits,
            concatenated_batch["concatenated_labels"],
            average_log_prob=True,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
        )

        chosen_logps = all_logps[:len_chosen]
        rejected_logps = all_logps[len_chosen:]

        chosen_logits = all_logits[:len_chosen]
        rejected_logits = all_logits[len_chosen:]

        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)

    def get_batch_loss_metrics(
        self,
        model,
        batch: Dict[str, Union[List, torch.LongTensor]],
        train_eval: Literal["train", "eval"] = "train",
    ):
        """Compute the SimPO loss and other metrics for the given batch of inputs for train or test.

        Args:
            model: The policy model passed through to ``concatenated_forward``.
            batch: Batch dictionary passed through to ``concatenated_forward``.
            train_eval: ``"eval"`` prefixes every metric key with ``"eval_"``;
                any other value leaves the keys unprefixed.

        Returns:
            A tuple ``(loss, metrics)`` where ``loss`` is the mean of the
            per-example SimPO losses and ``metrics`` maps metric names to
            scalar tensors moved to the CPU.
        """
        metrics = {}

        (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits,
            policy_rejected_logits,
        ) = self.concatenated_forward(model, batch)

        losses, chosen_rewards, rejected_rewards = self.simpo_loss(
            policy_chosen_logps,
            policy_rejected_logps
        )
        # 1.0 where the chosen response received the higher reward, else 0.0.
        reward_accuracies = (chosen_rewards > rejected_rewards).float()

        prefix = "eval_" if train_eval == "eval" else ""
        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()

        return losses.mean(), metrics
|
||||
Reference in New Issue
Block a user