commit 586e5e3d0aeb97c0804daa06b54ee01110e13616 Author: Yu Meng Date: Thu May 23 22:53:13 2024 -0400 release diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1445d93 --- /dev/null +++ b/.gitignore @@ -0,0 +1,164 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +.idea/ + +# Temp folders +data/ +wandb/ \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..f121e9e --- /dev/null +++ b/README.md @@ -0,0 +1,88 @@ +# Simple Preference Optimization (SimPO) + +This repository contains the code and released models for our paper [SimPO: Simple Preference Optimization with a Reference-Free Reward](https://arxiv.org/abs/2405.14734). We propose a simpler and more effective preference optimization algorithm than DPO (Direct Preference Optimization) without using a reference model. SimPO outperforms DPO and its latest variants across AlpacaEval 2, MT-Bench, and Arena-Hard benchmarks under various settings. + + + +## 🔗 Quick Links +- [SimPO: Simple Preference Optimization with a Reference-Free Reward](#simple-preference-optimization-simpo) + - [Released Models](#released-models) + - [Install Requirements](#install-requirements) + - [Training scripts](#training-scripts) + - [Evaluation](#evaluation) + - [Bugs or Questions?](#bugs-or-questions) + - [Citation](#citation) + +## Released Models + + +## Install Requirements + +Our codebase is built upon the [alignment-handbook repo](https://github.com/huggingface/alignment-handbook). The following steps will guide you through the installation process. + +First, create a Python virtual environment using e.g. Conda: +```shell +conda create -n handbook python=3.10 && conda activate handbook +``` + +Next, install PyTorch `v2.2.2`. Since this is hardware-dependent, we +direct you to the [PyTorch Installation Page](https://pytorch.org/get-started/locally/). + +You can then install the remaining package dependencies of [alignment-handbook](https://github.com/huggingface/alignment-handbook) as follows: + +```shell +git clone https://github.com/huggingface/alignment-handbook.git +cd ./alignment-handbook/ +python -m pip install . +``` + +You will also need Flash Attention 2 installed, which can be done by running: + +```shell +python -m pip install flash-attn --no-build-isolation +``` + +## Training Scripts + +We provide four training config files for the four training setups reported in our paper: + +* Mistral-Base: +```shell +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml scripts/run_simpo.py training_configs/mistral-7b-base-simpo.yaml +``` +* Mistral-Instruct: +```shell +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml scripts/run_simpo.py training_configs/mistral-7b-instruct-simpo.yaml +``` +* Llama3-Base: +```shell +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml scripts/run_simpo.py training_configs/llama-3-8b-base-simpo.yaml +``` +* Llama3-Instruct: +```shell +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml scripts/run_simpo.py training_configs/llama-3-8b-instruct-simpo.yaml +``` + +## Evaluation + +We follow the official implementation for evaluation on AlpacaEval 2, Arena-Hard, and MT-Bench, as follows: + +* AlpacaEval 2: Please refer to the [AlpacaEval repo](https://github.com/tatsu-lab/alpaca_eval) for evaluation. + +* Arena-Hard: Please refer to to the [Arena-Hard-Auto repo](https://github.com/lm-sys/arena-hard-auto) for evaluation. + +* MT-Bench: Please refer to the [FastChat repo](https://github.com/lm-sys/FastChat) for evaluation. + +## Bugs or Questions? +If you have any questions related to the code or the paper, feel free to email Yu (yumeng5@virginia.edu). If you encounter any problems when using the code, or want to report a bug, feel free to open an issue! Please try to specify the problem with details so we can help you better and quicker! + +## Citation +Please cite our paper if you find the repo helpful in your work: + +```bibtex +@article{meng2024simpo, + title={{SimPO}: Simple Preference Optimization with a Reference-Free Reward}, + author={Meng, Yu and Xia, Mengzhou and Chen, Danqi}, + year={2024} +} +``` diff --git a/SimPO.png b/SimPO.png new file mode 100644 index 0000000..6146741 Binary files /dev/null and b/SimPO.png differ diff --git a/accelerate_configs/deepspeed_zero3.yaml b/accelerate_configs/deepspeed_zero3.yaml new file mode 100644 index 0000000..b5a1201 --- /dev/null +++ b/accelerate_configs/deepspeed_zero3.yaml @@ -0,0 +1,22 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 3 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/accelerate_configs/fsdp.yaml b/accelerate_configs/fsdp.yaml new file mode 100644 index 0000000..2ec9116 --- /dev/null +++ b/accelerate_configs/fsdp.yaml @@ -0,0 +1,26 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: FSDP +downcast_bf16: 'no' +enable_cpu_affinity: false +fsdp_config: + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_backward_prefetch: BACKWARD_PRE + fsdp_cpu_ram_efficient_loading: true + fsdp_forward_prefetch: true + fsdp_offload_params: false + fsdp_sharding_strategy: FULL_SHARD + fsdp_state_dict_type: SHARDED_STATE_DICT + fsdp_sync_module_states: true + fsdp_use_orig_params: true +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/accelerate_configs/multi_gpu.yaml b/accelerate_configs/multi_gpu.yaml new file mode 100644 index 0000000..4f05571 --- /dev/null +++ b/accelerate_configs/multi_gpu.yaml @@ -0,0 +1,16 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: MULTI_GPU +downcast_bf16: 'no' +gpu_ids: all +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/scripts/run_simpo.py b/scripts/run_simpo.py new file mode 100644 index 0000000..77374a4 --- /dev/null +++ b/scripts/run_simpo.py @@ -0,0 +1,334 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import random +import sys + +import torch +import transformers +from transformers import AutoModelForCausalLM, set_seed + +from alignment import ( + DataArguments, + DPOConfig, + H4ArgumentParser, + ModelArguments, + maybe_insert_system_message, + is_openai_format, + get_checkpoint, + get_datasets, + get_kbit_device_map, + get_peft_config, + get_quantization_config, + get_tokenizer, + is_adapter_model, +) +from peft import PeftConfig, PeftModel +from simpo_trainer import SimPOTrainer +from dataclasses import dataclass, field +from typing import Optional, Literal + +logger = logging.getLogger(__name__) + +MISTRAL_CHAT_TEMPLATE = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'].strip() + '\n\n' %}{% else %}{% set loop_messages = messages %}{% set system_message = '' %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{% set content = system_message + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}" + +@dataclass +class SimPOConfig(DPOConfig): + gamma: Optional[float] = field( + default=0.5, + metadata={"help": "The target reward margin term in SimPO loss."}, + ) + +def apply_chat_template( + example, + tokenizer, + task: Literal["sft", "generation", "rm", "simpo"], + auto_insert_empty_system_msg: bool = True, + change_template = None, +): + if change_template == "mistral": + tokenizer.chat_template = MISTRAL_CHAT_TEMPLATE + if task in ["sft", "generation"]: + messages = example["messages"] + # We add an empty system message if there is none + if auto_insert_empty_system_msg: + maybe_insert_system_message(messages, tokenizer) + example["text"] = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True if task == "generation" else False, + ) + elif task == "rm": + if all(k in example.keys() for k in ("chosen", "rejected")): + chosen_messages = example["chosen"] + rejected_messages = example["rejected"] + # We add an empty system message if there is none + if auto_insert_empty_system_msg: + maybe_insert_system_message(chosen_messages, tokenizer) + maybe_insert_system_message(rejected_messages, tokenizer) + + example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False) + example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False) + else: + raise ValueError( + f"Could not format example as dialogue for `rm` task! Require `[chosen, rejected]` keys but found {list(example.keys())}" + ) + elif task == "simpo": + if all(k in example.keys() for k in ("chosen", "rejected")): + if not is_openai_format(example["chosen"]) or not is_openai_format(example["rejected"]): + raise ValueError( + f"Could not format example as dialogue for `{task}` task! Require OpenAI format for all messages" + ) + + # For DPO/ORPO, the inputs are triples of (prompt, chosen, rejected), where `chosen` and `rejected` are the final turn of a dialogue + # We therefore need to extract the N-1 turns to form the prompt + if "prompt" in example and is_openai_format(example["prompt"]): + prompt_messages = example["prompt"] + chosen_messages = example["chosen"] + rejected_messages = example["rejected"] + else: + prompt_messages = example["chosen"][:-1] + # Now we extract the final turn to define chosen/rejected responses + chosen_messages = example["chosen"][-1:] + rejected_messages = example["rejected"][-1:] + + # Prepend a system message if the first message is not a system message + if auto_insert_empty_system_msg: + maybe_insert_system_message(prompt_messages, tokenizer) + + example["text_prompt"] = tokenizer.apply_chat_template(prompt_messages, tokenize=False) + example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False) + if example["text_chosen"].startswith(tokenizer.bos_token): + example["text_chosen"] = example["text_chosen"][len(tokenizer.bos_token):] + example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False) + if example["text_rejected"].startswith(tokenizer.bos_token): + example["text_rejected"] = example["text_rejected"][len(tokenizer.bos_token):] + else: + raise ValueError( + f"Could not format example as dialogue for `{task}` task! Require either the " + f"`[chosen, rejected]` or `[prompt, chosen, rejected]` keys but found {list(example.keys())}" + ) + else: + raise ValueError( + f"Task {task} not supported, please ensure that the provided task is one of ['sft', 'generation', 'rm', 'dpo', 'orpo']" + ) + return example + + +def main(): + parser = H4ArgumentParser((ModelArguments, DataArguments, SimPOConfig)) + model_args, data_args, training_args = parser.parse() + + ####### + # Setup + ####### + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + # Log on each process the small summary: + logger.info(f"Model parameters {model_args}") + logger.info(f"Data parameters {data_args}") + logger.info(f"Training/evaluation parameters {training_args}") + + # Check for last checkpoint + last_checkpoint = get_checkpoint(training_args) + if last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.") + + # Set seed for reproducibility + set_seed(training_args.seed) + + ############### + # Load datasets + ############### + raw_datasets = get_datasets( + data_args, + splits=data_args.dataset_splits, + configs=data_args.dataset_configs, + columns_to_keep=["messages", "chosen", "rejected", "prompt", "completion", "label"], + seed=training_args.seed, + ) + logger.info( + f"Training on the following splits: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}" + ) + column_names = list(raw_datasets["train"].features) + + ##################################### + # Load tokenizer and process datasets + ##################################### + data_args.truncation_side = "left" # Truncate from left to ensure we don't lose labels in final turn + tokenizer = get_tokenizer(model_args, data_args) + + if "mistral" in model_args.model_name_or_path.lower(): + change_template = "mistral" + else: + change_template = None + ##################### + # Apply chat template + ##################### + raw_datasets = raw_datasets.map( + apply_chat_template, + fn_kwargs={ + "tokenizer": tokenizer, + "task": "simpo", + "auto_insert_empty_system_msg": data_args.auto_insert_empty_system_msg, + "change_template": change_template, + }, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + desc="Formatting comparisons with prompt template", + ) + + # Replace column names with what TRL needs, text_chosen -> chosen and text_rejected -> rejected + for split in ["train", "test"]: + raw_datasets[split] = raw_datasets[split].rename_columns( + {"text_prompt": "prompt", "text_chosen": "chosen", "text_rejected": "rejected"} + ) + + # Log a few random samples from the training set: + for index in random.sample(range(len(raw_datasets["train"])), 3): + logger.info(f"Prompt sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['prompt']}") + logger.info(f"Chosen sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['chosen']}") + logger.info(f"Rejected sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['rejected']}") + + torch_dtype = ( + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) + ) + quantization_config = get_quantization_config(model_args) + + model_kwargs = dict( + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + use_flash_attention_2=model_args.use_flash_attention_2, + torch_dtype=torch_dtype, + use_cache=False if training_args.gradient_checkpointing else True, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + ) + + model = model_args.model_name_or_path + if is_adapter_model(model, model_args.model_revision) is True: + logger.info(f"Loading SFT adapter for {model_args.model_name_or_path=}") + peft_config = PeftConfig.from_pretrained(model_args.model_name_or_path, revision=model_args.model_revision) + model_kwargs = dict( + revision=model_args.base_model_revision, + trust_remote_code=model_args.trust_remote_code, + use_flash_attention_2=model_args.use_flash_attention_2, + torch_dtype=torch_dtype, + use_cache=False if training_args.gradient_checkpointing else True, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + ) + base_model = AutoModelForCausalLM.from_pretrained( + peft_config.base_model_name_or_path, + **model_kwargs, + ) + model = PeftModel.from_pretrained( + base_model, + model_args.model_name_or_path, + revision=model_args.model_revision, + ) + model_kwargs = None + + ref_model = model + ref_model_kwargs = model_kwargs + + if model_args.use_peft is True: + ref_model = None + ref_model_kwargs = None + + ######################### + # Instantiate SimPO trainer + ######################### + trainer = SimPOTrainer( + model=model, + ref_model=ref_model, # pass in to bypass DPO Trainer check for ref model but is not actually used + model_init_kwargs=model_kwargs, + args=training_args, + beta=training_args.beta, + train_dataset=raw_datasets["train"], + eval_dataset=raw_datasets["test"], + tokenizer=tokenizer, + max_length=training_args.max_length, + max_prompt_length=training_args.max_prompt_length, + peft_config=get_peft_config(model_args), + loss_type=training_args.loss_type, + ) + + ############### + # Training loop + ############### + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + train_result = trainer.train(resume_from_checkpoint=checkpoint) + metrics = train_result.metrics + metrics["train_samples"] = len(raw_datasets["train"]) + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + logger.info("*** Training complete ***") + + ################################## + # Save model and create model card + ################################## + logger.info("*** Save model ***") + trainer.save_model(training_args.output_dir) + logger.info(f"Model saved to {training_args.output_dir}") + + # Save everything else on main process + kwargs = { + "finetuned_from": model_args.model_name_or_path, + "dataset": list(data_args.dataset_mixer.keys()), + "dataset_tags": list(data_args.dataset_mixer.keys()), + "tags": ["alignment-handbook"], + } + if trainer.accelerator.is_main_process: + trainer.create_model_card(**kwargs) + # Restore k,v cache for fast inference + trainer.model.config.use_cache = True + trainer.model.config.save_pretrained(training_args.output_dir) + + ########## + # Evaluate + ########## + if training_args.do_eval: + logger.info("*** Evaluate ***") + metrics = trainer.evaluate() + metrics["eval_samples"] = len(raw_datasets["test"]) + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + if training_args.push_to_hub is True: + logger.info("Pushing to hub...") + trainer.push_to_hub(**kwargs) + + logger.info("*** Training complete! ***") + + +if __name__ == "__main__": + main() diff --git a/scripts/simpo_trainer.py b/scripts/simpo_trainer.py new file mode 100644 index 0000000..d28c700 --- /dev/null +++ b/scripts/simpo_trainer.py @@ -0,0 +1,134 @@ + +from trl import DPOTrainer +import torch +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union +import torch.nn.functional as F +import torch.nn as nn + +class SimPOTrainer(DPOTrainer): + + def __init__(self, **kwargs): + super().__init__(**kwargs) # Pass all other arguments using **kwargs + training_args = kwargs["args"] + self.gamma = training_args.gamma + super().__init__(**kwargs) + + def simpo_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Compute the SimPO loss for a batch of policy model log probabilities. + + Args: + policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) + policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) + + Returns: + A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). + The losses tensor contains the SimPO loss for each example in the batch. + The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively. + """ + pi_logratios = policy_chosen_logps - policy_rejected_logps + gamma_logratios = self.gamma / self.beta + pi_logratios = pi_logratios.to(self.accelerator.device) + logits = pi_logratios - gamma_logratios + + if self.loss_type == "sigmoid": + losses = ( + -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + - F.logsigmoid(-self.beta * logits) * self.label_smoothing + ) + elif self.loss_type == "hinge": + losses = torch.relu(1 - self.beta * logits) + else: + raise ValueError( + f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge']" + ) + + chosen_rewards = self.beta * policy_chosen_logps.to(self.accelerator.device).detach() + rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach() + + return losses, chosen_rewards, rejected_rewards + + def concatenated_forward( + self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]] + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. + + We do this to avoid doing two forward passes, because it's faster for FSDP. + """ + concatenated_batch = self.concatenated_inputs( + batch, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + padding_value=self.padding_value, + device=self.accelerator.device, + ) + len_chosen = batch["chosen_labels"].shape[0] + + model_kwargs = ( + { + "labels": concatenated_batch["concatenated_labels"], + "decoder_input_ids": concatenated_batch.pop("concatenated_decoder_input_ids", None), + } + if self.is_encoder_decoder + else {} + ) + + all_logits = model( + concatenated_batch["concatenated_input_ids"], + attention_mask=concatenated_batch["concatenated_attention_mask"], + use_cache=False, + **model_kwargs, + ).logits + + all_logps = self.get_batch_logps( + all_logits, + concatenated_batch["concatenated_labels"], + average_log_prob=True, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + ) + + chosen_logps = all_logps[:len_chosen] + rejected_logps = all_logps[len_chosen:] + + chosen_logits = all_logits[:len_chosen] + rejected_logits = all_logits[len_chosen:] + + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits) + + def get_batch_loss_metrics( + self, + model, + batch: Dict[str, Union[List, torch.LongTensor]], + train_eval: Literal["train", "eval"] = "train", + ): + """Compute the SimPO loss and other metrics for the given batch of inputs for train or test.""" + metrics = {} + + ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits, + policy_rejected_logits, + ) = self.concatenated_forward(model, batch) + + losses, chosen_rewards, rejected_rewards = self.simpo_loss( + policy_chosen_logps, + policy_rejected_logps + ) + reward_accuracies = (chosen_rewards > rejected_rewards).float() + + prefix = "eval_" if train_eval == "eval" else "" + metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu() + metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu() + metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu() + metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu() + metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu() + metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu() + metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu() + metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu() + + return losses.mean(), metrics \ No newline at end of file diff --git a/training_configs/llama-3-8b-base-simpo.yaml b/training_configs/llama-3-8b-base-simpo.yaml new file mode 100644 index 0000000..d56ef87 --- /dev/null +++ b/training_configs/llama-3-8b-base-simpo.yaml @@ -0,0 +1,45 @@ +# Model arguments +model_name_or_path: /scratch/gpfs/DANQIC/ym0081/checkpoints_new/llama-3-8b-sft-full-lr-2e-5 +torch_dtype: null +use_flash_attention_2: true + +# Data training arguments +dataset_mixer: + HuggingFaceH4/ultrafeedback_binarized: 1.0 +dataset_splits: +- train_prefs +- test_prefs +preprocessing_num_workers: 12 + +# SimPOTrainer arguments +bf16: true +beta: 2.0 +gamma: 1.0 +do_eval: true +evaluation_strategy: steps +eval_steps: 400 +gradient_accumulation_steps: 8 +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: False +hub_model_id: simpo-exps +learning_rate: 6.0e-7 +log_level: info +logging_steps: 5 +lr_scheduler_type: cosine +max_length: 2048 +max_prompt_length: 1800 +num_train_epochs: 1 +optim: adamw_torch +output_dir: /scratch/gpfs/DANQIC/ym0081/checkpoints_new/llama-3-8b-base-simpo +run_name: llama-3-8b-base-simpo +per_device_train_batch_size: 2 +per_device_eval_batch_size: 4 +push_to_hub: false +save_strategy: "steps" +save_steps: 1000000 +report_to: +- wandb +save_total_limit: 20 +seed: 42 +warmup_ratio: 0.1 diff --git a/training_configs/llama-3-8b-instruct-simpo.yaml b/training_configs/llama-3-8b-instruct-simpo.yaml new file mode 100644 index 0000000..fe51723 --- /dev/null +++ b/training_configs/llama-3-8b-instruct-simpo.yaml @@ -0,0 +1,45 @@ +# Model arguments +model_name_or_path: /scratch/gpfs/DANQIC/ym0081/hf_cache/Meta-Llama-3-8B-Instruct +torch_dtype: null +use_flash_attention_2: true + +# Data training arguments +dataset_mixer: + /scratch/gpfs/DANQIC/ym0081/hf_cache/llama3-ultrafeedback: 1.0 +dataset_splits: +- train +- test +preprocessing_num_workers: 12 + +# SimPOTrainer arguments +bf16: true +beta: 2.5 +gamma: 1.4 +do_eval: true +evaluation_strategy: steps +eval_steps: 400 +gradient_accumulation_steps: 8 +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: False +hub_model_id: simpo-exps +learning_rate: 1.0e-6 +log_level: info +logging_steps: 5 +lr_scheduler_type: cosine +max_length: 2048 +max_prompt_length: 1800 +num_train_epochs: 1 +optim: adamw_torch +output_dir: /scratch/gpfs/DANQIC/ym0081/checkpoints_new/llama-3-8b-instruct-simpo +run_name: llama-3-8b-instruct-simpo +per_device_train_batch_size: 2 +per_device_eval_batch_size: 4 +push_to_hub: false +save_strategy: "steps" +save_steps: 1000000 +report_to: +- wandb +save_total_limit: 20 +seed: 42 +warmup_ratio: 0.1 diff --git a/training_configs/mistral-7b-base-simpo.yaml b/training_configs/mistral-7b-base-simpo.yaml new file mode 100644 index 0000000..87d329c --- /dev/null +++ b/training_configs/mistral-7b-base-simpo.yaml @@ -0,0 +1,45 @@ +# Model arguments +model_name_or_path: alignment-handbook/zephyr-7b-sft-full +torch_dtype: null +use_flash_attention_2: true + +# Data training arguments +dataset_mixer: + HuggingFaceH4/ultrafeedback_binarized: 1.0 +dataset_splits: +- train_prefs +- test_prefs +preprocessing_num_workers: 12 + +# SimPOTrainer arguments +bf16: true +beta: 2.0 +gamma: 1.6 +do_eval: true +evaluation_strategy: steps +eval_steps: 400 +gradient_accumulation_steps: 8 +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: False +hub_model_id: simpo-exps +learning_rate: 3.0e-7 +log_level: info +logging_steps: 5 +lr_scheduler_type: cosine +max_length: 1024 +max_prompt_length: 512 +num_train_epochs: 1 +optim: adamw_torch +output_dir: /scratch/gpfs/DANQIC/ym0081/checkpoints_new/mistral-7b-base-simpo +run_name: mistral-7b-base-simpo +per_device_train_batch_size: 2 +per_device_eval_batch_size: 4 +push_to_hub: false +save_strategy: "steps" +save_steps: 1000000 +report_to: +- wandb +save_total_limit: 20 +seed: 42 +warmup_ratio: 0.1 diff --git a/training_configs/mistral-7b-instruct-simpo.yaml b/training_configs/mistral-7b-instruct-simpo.yaml new file mode 100644 index 0000000..08f6842 --- /dev/null +++ b/training_configs/mistral-7b-instruct-simpo.yaml @@ -0,0 +1,45 @@ +# Model arguments +model_name_or_path: /scratch/gpfs/DANQIC/ym0081/hf_cache/Mistral-7B-Instruct-v0.2 +torch_dtype: null +use_flash_attention_2: true + +# Data training arguments +dataset_mixer: + snorkelai/Snorkel-Mistral-PairRM-DPO-Dataset: 1.0 +dataset_splits: +- train +- test +preprocessing_num_workers: 12 + +# SimPOTrainer arguments +bf16: true +beta: 2.5 +gamma: 0.3 +do_eval: true +evaluation_strategy: steps +eval_steps: 400 +gradient_accumulation_steps: 8 +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: False +hub_model_id: simpo-exps +learning_rate: 5.0e-7 +log_level: info +logging_steps: 5 +lr_scheduler_type: cosine +max_length: 2048 +max_prompt_length: 1800 +num_train_epochs: 1 +optim: adamw_torch +output_dir: /scratch/gpfs/DANQIC/ym0081/checkpoints_new/mistral-7b-instruct-simpo +run_name: mistral-7b-instruct-simpo +per_device_train_batch_size: 2 +per_device_eval_batch_size: 4 +push_to_hub: false +save_strategy: "steps" +save_steps: 1000000 +report_to: +- wandb +save_total_limit: 20 +seed: 42 +warmup_ratio: 0.1