Add run_orpo.py (#143)

* Add `ORPOConfig`

* Add `task=orpo` and support `(prompt,chosen,rejected)` datasets

* Add missing `model_init_kwargs` and `dataset_num_proc`

* Add `run_orpo.py` (WIP)

* Update `trl` dependency from source

* Add `setup_chat_format` before `apply_chat_template`

* Add `config_full.yaml` for `mistral-7b-orpo`

* Fix comment indentation

* Use `chat_template=chatml` instead

* Add `kaist-ai/mistral-orpo-capybara-7k` recipe

* Rename `DPOTrainer` to `ORPOTrainer` in `config_full.yaml` files

* Run `black --line-length 119 src`

* Add `is_openai_format` to fix `(prompt,chosen,rejected)` formatting

* Run `black --line-length 119 src`

* Fix `isort` in `run_orpo.py`

* Update `mistral-capybara/orpo/config_full.yaml`

* Check if `test` is available split

* Pin `trl` to `alvarobartt/trl` fork (debugging)

* Add `qwen-capybara` recipe

* Update `mistral-capybara` recipe

* Set `add_generation_prompt=True` if `task="orpo"`

* Reduce `logging_steps` to 10

* Unset `add_generation_prompt` when `task=orpo`

* Add filtering based on prompt length

Done similarly to the original implementation, in order to better reproduce their results

* Fix prompt length filtering

* Update `trl` pinned version

* Remove extra outdate config files

* Update `recipes/mistral-capybara/orpo/config_full.yaml`

* Run `make style`

* Activate BEAST MODE

* Pin deps

* Add readme

* Fix dep

---------

Co-authored-by: Lewis Tunstall <lewis.c.tunstall@gmail.com>
This commit is contained in:
Alvaro Bartolome
2024-04-11 16:02:20 +02:00
committed by GitHub
parent a83b1f617f
commit 70769f9e9b
8 changed files with 468 additions and 17 deletions
+26
View File
@@ -0,0 +1,26 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_backward_prefetch: BACKWARD_PRE
fsdp_cpu_ram_efficient_loading: true
fsdp_forward_prefetch: true
fsdp_offload_params: false
fsdp_sharding_strategy: FULL_SHARD
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_sync_module_states: true
fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
+1 -1
View File
@@ -1,5 +1,5 @@
# Instructions to StarChat2
# Instructions to train StarChat2
Similar to how we trained Zephyr 7B Beta in our [technical report](https://huggingface.co/papers/2310.16944), training this model proceeds in two steps:
+23
View File
@@ -0,0 +1,23 @@
# Instructions to train Zephyr-141B-A35B with ORPO
This model is fine-tuned via a novel alignment algorithm called [Odds Ratio Preference Optimization (ORPO)](https://huggingface.co/papers/2403.07691). ORPO does not require an SFT step to achieve high performance and is thus much more computationally efficient than methods like DPO and PPO. To train Zephyr-141B-A35B, we used the [`argilla/distilabel-capybara-dpo-7k-binarized`](https://huggingface.co/datasets/argilla/distilabel-capybara-dpo-7k-binarized) preference dataset, which consists of synthetic, high-quality, multi-turn preferences that have been scored via LLMs.
See below for commands to train these models using FSDP. **Note:** we found it was not possible to train this large model with DeepSpeed ZeRO-3 due to unresolved NCCL errors which cause GPUs to hang.
## Full training examples
You will require 4 nodes of 8 GPUs (80GB of VRAM) to train the full model - alternatively, you may be able to train on fewer GPUs by adjusting `per_device_train_batch_size` and `gradient_accumulation_steps` and `num_train_epochs` to keep the global batch size constant. A recipe involving QLoRA will come later 🤗.
To run with Slurm, use:
```shell
sbatch --job-name=handbook_sft --nodes=4 recipes/launch.slurm zephyr-141b-A35b orpo full fsdp
```
Under the hood, this calls the following script which can be adapted to other models and datasets:
```shell
ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch --config_file recipes/accelerate_configs/fsdp.yaml scripts/run_orpo.py recipes/zephyr-141b-A35b/orpo/config_full.yaml
```
@@ -0,0 +1,39 @@
# Model arguments
model_name_or_path: mistral-community/Mixtral-8x22B-v0.1
model_revision: main
torch_dtype: bfloat16
use_flash_attention_2: true
# Data training arguments
chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
dataset_mixer:
argilla/distilabel-capybara-dpo-7k-binarized: 1.0
dataset_splits:
- train
preprocessing_num_workers: 8
# ORPOTrainer arguments
bf16: true
beta: 0.05
gradient_accumulation_steps: 1
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true
hub_model_id: zephyr-orpo-141b-A35b
learning_rate: 5.0e-6
log_level: info
logging_steps: 10
lr_scheduler_type: inverse_sqrt
max_length: 2048
max_prompt_length: 1792
num_train_epochs: 3
optim: adamw_bnb_8bit
output_dir: data/zephyr-orpo-141b-A35b
per_device_train_batch_size: 1
push_to_hub: true
report_to:
- tensorboard
- wandb
save_strategy: "no"
seed: 42
warmup_steps: 100
+271
View File
@@ -0,0 +1,271 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
import sys
from typing import Any, Dict
import torch
import transformers
from transformers import AutoModelForCausalLM, set_seed
from alignment import (
DataArguments,
H4ArgumentParser,
ModelArguments,
apply_chat_template,
decontaminate_humaneval,
get_checkpoint,
get_datasets,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
get_tokenizer,
)
from alignment.configs import ORPOConfig
from trl import ORPOTrainer, setup_chat_format
logger = logging.getLogger(__name__)
def main():
parser = H4ArgumentParser((ModelArguments, DataArguments, ORPOConfig))
model_args, data_args, training_args = parser.parse()
#######
# Setup
#######
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
# Log on each process the small summary:
logger.info(f"Model parameters {model_args}")
logger.info(f"Data parameters {data_args}")
logger.info(f"Training/evaluation parameters {training_args}")
# Check for last checkpoint
last_checkpoint = get_checkpoint(training_args)
if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
# Set seed for reproducibility
set_seed(training_args.seed)
###############
# Load datasets
###############
raw_datasets = get_datasets(
data_args,
splits=data_args.dataset_splits,
configs=data_args.dataset_configs,
columns_to_keep=[
"prompt",
"chosen",
"rejected",
],
)
logger.info(
f"Training on the following splits: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
)
column_names = list(raw_datasets["train"].features)
#####################################
# Load tokenizer and process datasets
#####################################
data_args.truncation_side = "left" # Truncate from left to ensure we don't lose labels in final turn
tokenizer = get_tokenizer(model_args, data_args)
torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
quantization_config = get_quantization_config(model_args)
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
use_flash_attention_2=model_args.use_flash_attention_2,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
# For ChatML we need to add special tokens and resize the embedding layer
if "<|im_start|>" in tokenizer.chat_template:
model, tokenizer = setup_chat_format(model, tokenizer)
#####################
# Apply chat template
#####################
raw_datasets = raw_datasets.map(
apply_chat_template,
fn_kwargs={
"tokenizer": tokenizer,
"task": "orpo",
"auto_insert_empty_system_msg": data_args.auto_insert_empty_system_msg,
},
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
desc="Formatting comparisons with prompt template",
)
#############################
# Filter out seq > max_length
#############################
if training_args.max_prompt_length is not None:
unfiltered_train_samples = len(raw_datasets["train"])
if "test" in raw_datasets:
unfiltered_test_samples = len(raw_datasets["test"])
def filter_fn(sample: Dict[str, Any]) -> Dict[str, Any]:
prompt_length = tokenizer(
sample["text_prompt"],
return_tensors="pt",
)[
"input_ids"
].size(dim=-1)
return prompt_length < training_args.max_prompt_length
raw_datasets = raw_datasets.filter(
filter_fn,
desc="Filtering out the samples where len(text_prompt) > max_prompt_length",
)
filtered_train_samples = unfiltered_train_samples - len(raw_datasets["train"])
logger.info(
f"Filtered out {filtered_train_samples} training samples out of the {unfiltered_train_samples} samples."
)
if "test" in raw_datasets:
filtered_test_samples = unfiltered_test_samples - len(raw_datasets["test"])
logger.info(
f"Filtered out {filtered_test_samples} test samples out of the {unfiltered_test_samples} samples."
)
##########################
# Decontaminate benchmarks
##########################
num_raw_train_samples = len(raw_datasets["train"])
raw_datasets = raw_datasets.filter(
decontaminate_humaneval,
fn_kwargs={"text_column": "text_chosen"},
batched=True,
batch_size=10_000,
num_proc=1,
desc="Decontaminating HumanEval samples",
)
num_filtered_train_samples = num_raw_train_samples - len(raw_datasets["train"])
logger.info(
f"Decontaminated {num_filtered_train_samples} ({num_filtered_train_samples/num_raw_train_samples * 100:.2f}%) samples from the training set."
)
# Replace column names with what TRL needs, text_prompt -> prompt, text_chosen -> chosen and text_rejected -> rejected
for split in raw_datasets.keys():
raw_datasets[split] = raw_datasets[split].rename_columns(
{
"text_prompt": "prompt",
"text_chosen": "chosen",
"text_rejected": "rejected",
}
)
# Log a few random samples from the training set:
for index in random.sample(range(len(raw_datasets["train"])), 3):
logger.info(f"Prompt sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['prompt']}")
logger.info(f"Chosen sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['chosen']}")
logger.info(f"Rejected sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['rejected']}")
##########################
# Instantiate ORPO trainer
##########################
trainer = ORPOTrainer(
model,
args=training_args,
train_dataset=raw_datasets["train"],
eval_dataset=raw_datasets["test"] if "test" in raw_datasets else None,
tokenizer=tokenizer,
peft_config=get_peft_config(model_args), # type: ignore
)
###############
# Training loop
###############
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)
metrics = train_result.metrics
metrics["train_samples"] = len(raw_datasets["train"])
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
logger.info("*** Training complete ***")
##################################
# Save model and create model card
##################################
logger.info("*** Save model ***")
if trainer.is_fsdp_enabled:
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
trainer.save_model(training_args.output_dir)
logger.info(f"Model saved to {training_args.output_dir}")
# Save everything else on main process
kwargs = {
"finetuned_from": model_args.model_name_or_path,
"dataset": list(data_args.dataset_mixer.keys()),
"dataset_tags": list(data_args.dataset_mixer.keys()),
"tags": ["alignment-handbook"],
}
if trainer.accelerator.is_main_process:
trainer.create_model_card(**kwargs)
# Restore k,v cache for fast inference
trainer.model.config.use_cache = True
trainer.model.config.save_pretrained(training_args.output_dir)
##########
# Evaluate
##########
if training_args.do_eval and "test" in raw_datasets:
logger.info("*** Evaluate ***")
metrics = trainer.evaluate()
metrics["eval_samples"] = len(raw_datasets["test"])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
if training_args.push_to_hub is True:
logger.info("Pushing to hub...")
trainer.push_to_hub(**kwargs)
logger.info("*** Training complete! ***")
if __name__ == "__main__":
main()
+4 -4
View File
@@ -41,10 +41,10 @@ if stale_egg_info.exists():
# IMPORTANT: all dependencies should be listed here with their version requirements, if any.
# * If a dependency is fast-moving (e.g. transformers), pin to the exact version
_deps = [
"accelerate==0.27.2",
"accelerate>=0.29.2",
"bitsandbytes==0.41.2.post2",
"black==23.1.0",
"datasets==2.14.6",
"datasets>=2.18.0",
"deepspeed==0.12.2",
"einops>=0.6.1",
"evaluate==0.4.0",
@@ -65,8 +65,8 @@ _deps = [
"scipy",
"tensorboard",
"torch==2.1.2",
"transformers @ git+https://github.com/huggingface/transformers.git@831bc25d8fdb85768402f772cf65cc3d7872b211", # Enable StarCoder2
"trl==0.7.10",
"transformers>=4.39.3",
"trl>=0.8.2",
"jinja2>=3.0.0",
"tqdm>=4.64.1",
]
+59
View File
@@ -281,3 +281,62 @@ class DPOConfig(transformers.TrainingArguments):
optim: Optional[str] = field(default="rmsprop")
remove_unused_columns: bool = field(default=False)
loss_type: Optional[str] = field(default="sigmoid", metadata={"help": ("The loss type for DPO.")})
@dataclass
class ORPOConfig(transformers.TrainingArguments):
max_length: Optional[int] = field(
default=None,
metadata={"help": "The maximum length of the sequences in the batch."},
)
max_prompt_length: Optional[int] = field(
default=None,
metadata={"help": "The maximum length of the prompt."},
)
max_completion_length: Optional[int] = field(
default=None,
metadata={"help": "The maximum length of the completions."},
)
beta: float = field(
default=0.1,
metadata={
"help": "The beta factor in ORPO loss (lambda/alpha in paper/code) that is the weight of the relative loss ratio in the SFT loss."
},
)
disable_dropout: bool = field(
default=True,
metadata={"help": "Whether or not to disable dropouts in `model`."},
)
label_pad_token_id: int = field(
default=-100,
metadata={"help": "The label pad token id."},
)
padding_value: Optional[int] = field(
default=None,
metadata={"help": "The padding value if it is different to the tokenizer's pad_token_id."},
)
truncation_mode: str = field(
default="keep_end",
metadata={"help": "The truncation mode to use, either `keep_end` or `keep_start`."},
)
generate_during_eval: bool = field(
default=False,
metadata={"help": "Whether to sample and log generations during evaluation step."},
)
is_encoder_decoder: Optional[bool] = field(
default=None,
metadata={"help": ("If no model is provided, we need to know if the model_init returns an encoder-decoder.")},
)
model_init_kwargs: Optional[Dict] = field(
default=None,
metadata={"help": ("Dict of Optional kwargs to pass when instantiating the model from a string")},
)
dataset_num_proc: Optional[int] = field(
default=None,
metadata={"help": ("The number of workers to use to tokenize the data.")},
)
+45 -12
View File
@@ -12,8 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import List, Literal, Optional
from typing import Any, List, Literal, Optional
from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
from datasets.builder import DatasetGenerationError
@@ -50,7 +51,9 @@ def apply_chat_template(
if auto_insert_empty_system_msg:
maybe_insert_system_message(messages, tokenizer)
example["text"] = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True if task == "generation" else False
messages,
tokenize=False,
add_generation_prompt=True if task == "generation" else False,
)
elif task == "rm":
if all(k in example.keys() for k in ("chosen", "rejected")):
@@ -67,32 +70,58 @@ def apply_chat_template(
raise ValueError(
f"Could not format example as dialogue for `rm` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
)
elif task == "dpo":
elif task in ["dpo", "orpo"]:
if all(k in example.keys() for k in ("chosen", "rejected")):
# For DPO, the inputs are triples of (prompt, chosen, rejected), where `chosen` and `rejected` are the final turn of a dialogue
if not is_openai_format(example["chosen"]) or not is_openai_format(example["rejected"]):
raise ValueError(
f"Could not format example as dialogue for `{task}` task! Require OpenAI format for all messages"
)
# For DPO/ORPO, the inputs are triples of (prompt, chosen, rejected), where `chosen` and `rejected` are the final turn of a dialogue
# We therefore need to extract the N-1 turns to form the prompt
prompt_messages = example["chosen"][:-1]
if "prompt" in example and is_openai_format(example["prompt"]):
prompt_messages = example["prompt"]
chosen_messages = example["chosen"]
rejected_messages = example["rejected"]
else:
prompt_messages = example["chosen"][:-1]
# Now we extract the final turn to define chosen/rejected responses
chosen_messages = example["chosen"][-1:]
rejected_messages = example["rejected"][-1:]
# Prepend a system message if the first message is not a system message
if auto_insert_empty_system_msg:
maybe_insert_system_message(prompt_messages, tokenizer)
# Now we extract the final turn to define chosen/rejected responses
chosen_messages = example["chosen"][-1:]
rejected_messages = example["rejected"][-1:]
example["text_prompt"] = tokenizer.apply_chat_template(prompt_messages, tokenize=False)
example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
example["text_prompt"] = tokenizer.apply_chat_template(prompt_messages, tokenize=False)
else:
raise ValueError(
f"Could not format example as dialogue for `dpo` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
f"Could not format example as dialogue for `{task}` task! Require either the "
f"`[chosen, rejected]` or `[prompt, chosen, rejected]` keys but found {list(example.keys())}"
)
else:
raise ValueError(
f"Task {task} not supported, please ensure that the provided task is one of {['sft', 'generation', 'rm', 'dpo']}"
f"Task {task} not supported, please ensure that the provided task is one of ['sft', 'generation', 'rm', 'dpo', 'orpo']"
)
return example
def is_openai_format(messages: Any) -> bool:
"""
Check if the input messages are in OpenAI format.
Args:
messages (`Any`):
Messages to check.
Returns:
`bool`: Whether the messages are in OpenAI format.
"""
if isinstance(messages, list) and all(isinstance(message, dict) for message in messages):
return all("role" in message and "content" in message for message in messages)
return False
def get_datasets(
data_config: DataArguments | dict,
splits: Optional[List[str]] = None,
@@ -138,7 +167,11 @@ def get_datasets(
raise ValueError(f"Data config {data_config} not recognized.")
raw_datasets = mix_datasets(
dataset_mixer, splits=splits, configs=configs, columns_to_keep=columns_to_keep, shuffle=shuffle
dataset_mixer,
splits=splits,
configs=configs,
columns_to_keep=columns_to_keep,
shuffle=shuffle,
)
return raw_datasets