Merge pull request #11 from huggingface/zephyr-recipe

Code release
This commit is contained in:
lewtun
2023-11-10 15:54:13 +01:00
committed by GitHub
30 changed files with 1861 additions and 23 deletions
+31
View File
@@ -0,0 +1,31 @@
name: Tests
on:
push:
branches:
- main
- v*-release
pull_request:
branches:
- main
jobs:
unit-tests:
name: Run unit tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Setup Python environment
uses: actions/setup-python@v2
with:
python-version: 3.10.10
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install ".[dev, torch]"
- name: Run unit tests
run: HF_TOKEN=$HF_TOKEN pytest -sv tests/
+4
View File
@@ -158,3 +158,7 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
# Temp folders
data/
wandb/
+10 -10
View File
@@ -3,31 +3,31 @@
# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
export PYTHONPATH = src
check_dirs := src tests
check_dirs := src tests scripts
style:
python -m black --line-length 119 --target-version py310 $(check_dirs) setup.py
python -m isort $(check_dirs) setup.py
black --line-length 119 --target-version py310 $(check_dirs) setup.py
isort $(check_dirs) setup.py
quality:
python -m black --check --line-length 119 --target-version py310 $(check_dirs) setup.py
python -m isort --check-only $(check_dirs) setup.py
python -m flake8 --max-line-length 119 $(check_dirs) setup.py
black --check --line-length 119 --target-version py310 $(check_dirs) setup.py
isort --check-only $(check_dirs) setup.py
flake8 --max-line-length 119 $(check_dirs) setup.py
# Release stuff
pre-release:
python src/alignment/utils/release.py
python src/alignment/release.py
pre-patch:
python src/alignment/utils/release.py --patch
python src/alignment/release.py --patch
post-release:
python src/alignment/utils/release.py --post_release
python src/alignment/release.py --post_release
post-patch:
python src/alignment/utils/release.py --post_release --patch
python src/alignment/release.py --post_release --patch
wheels:
python setup.py bdist_wheel && python setup.py sdist
+31 -3
View File
@@ -10,6 +10,10 @@ However, we know from the [InstructGPT](https://huggingface.co/papers/2203.02155
The Alignment Handbook aims to fill that gap by providing the community with a series of robust training recipes that span the whole pipeline.
## News 🗞️
* November 10, 2023: We release all the training code to replicate Zephyr-7b-β 🪁!
## Links 🔗
* [Zephyr 7B models, datasets, and demos](https://huggingface.co/collections/HuggingFaceH4/zephyr-7b-6538c6d6d5ddd1cbb1744a66)
@@ -32,13 +36,20 @@ To run the code in this project, first create a Python virtual environment using
conda create -n handbook python=3.10 && conda activate handbook
```
Next, install PyTorch v2.1.0. Since this hardware-dependent, we
Next, install PyTorch `v2.1.0` - the precise version is important for reproducibility! Since this is hardware-dependent, we
direct you to the [PyTorch Installation Page](https://pytorch.org/get-started/locally/).
Once PyTorch is installed, you can install the remaining package dependencies as follows:
You can then install the remaining package dependencies as follows:
```shell
pip install .
python -m pip install .
```
You will also need Flash Attention 2 installed, which can be done by running:
_Note: If your machine has less than 96GB of RAM and many CPU cores, reduce the MAX_JOBS., e.g. `MAX_JOBS=4 pip install flash-attn --no-build-isolation` _
```shell
python -m pip install flash-attn --no-build-isolation
```
Next, log into your Hugging Face account as follows:
@@ -53,6 +64,23 @@ Finally, install Git LFS so that you can push models to the Hugging Face Hub:
sudo apt-get install git-lfs
```
You can now checkout the `scripts` and `recipes` directories for instructions on how to train some models 🪁!
## Project structure
```
├── LICENSE
├── Makefile <- Makefile with commands like `make style`
├── README.md <- The top-level README for developers using this project
├── chapters <- Educational content to render on hf.co/learn
├── recipes <- Recipe configs, accelerate configs, slurm scripts
├── scripts <- Scripts to train and evaluate chat models
├── setup.cfg <- Installation config (mostly used for configuring code quality & tests)
├── setup.py <- Makes project pip installable (pip install -e .) so `alignment` can be imported
├── src <- Source code for use in this project
└── tests <- Unit tests
```
## Citation
If you find the content of this repo useful in your work, please cite it as follows:
@@ -0,0 +1,22 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
deepspeed_multinode_launcher: standard
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: true
zero3_save_16bit_model: true
zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
+16
View File
@@ -0,0 +1,16 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
View File
+93
View File
@@ -0,0 +1,93 @@
#!/bin/bash
#SBATCH --ntasks-per-node=1
#SBATCH --exclusive
#SBATCH --gres=gpu:8
#SBATCH --partition=production-cluster # Adjust this for your cluster
#SBATCH --output=/fsx/h4/logs/%x-%j.out # Adjust this for your cluster
#SBATCH --err=/fsx/h4/logs/%x-%j.err # Adjust this for your cluster
set -x -e
source ~/.bashrc
conda activate handbook
echo "START TIME: $(date)"
MODEL=$1
TASK=$2
PRECISION=$3
ACCELERATOR=$4
OPTIONAL_ARGS=$5
# Training setup
NUM_NODES=$SLURM_NNODES
GPUS_PER_NODE=8
WORLD_SIZE=$(($NUM_NODES*$GPUS_PER_NODE))
# Due to conflicts between Accelerate's DeepSpeed configs and Transformers' TrainingArguments, we need to parse the gradient accumulation steps from the config file to ensure they match
CONFIG_FILE=recipes/$MODEL/$TASK/config_$PRECISION.yaml
GRAD_ACC_STEPS=$(grep 'gradient_accumulation_steps' $CONFIG_FILE | awk '{print $2}')
# Split the string into individual arguments
IFS=' ' read -ra ARGS <<< "$OPTIONAL_ARGS"
# Loop through the arguments and find the one with "--gradient_accumulation_steps"
for arg in "${ARGS[@]}"; do
if [[ "$arg" == "--gradient_accumulation_steps="* ]]; then
# Extract the value after the equals sign
GRAD_ACC_STEPS="${arg#*=}"
break # Exit the loop once we find the desired argument
fi
done
echo "Gradient accumulation steps: $GRAD_ACC_STEPS"
# so processes know who to talk to
MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
MASTER_PORT=6000
export CMD=" \
scripts/run_$TASK.py $CONFIG_FILE $OPTIONAL_ARGS
"
export LAUNCHER="ACCELERATE_LOG_LEVEL=info accelerate launch \
--config_file recipes/accelerate_configs/$ACCELERATOR.yaml \
--gradient_accumulation_steps $GRAD_ACC_STEPS \
--num_machines $NUM_NODES \
--num_processes $WORLD_SIZE \
--main_process_ip $MASTER_ADDR \
--main_process_port $MASTER_PORT \
--machine_rank \$SLURM_PROCID \
--rdzv_conf "rdzv_backend=c10d,rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT" \
--max_restarts 1 \
--role \$(hostname -s): \
--tee 3 \
"
# force crashing on nccl issues like hanging broadcast
export NCCL_ASYNC_ERROR_HANDLING=1
# export NCCL_DEBUG=INFO
# export NCCL_DEBUG_SUBSYS=COLL
# export NCCL_SOCKET_NTHREADS=1
# export NCCL_NSOCKS_PERTHREAD=1
# export CUDA_LAUNCH_BLOCKING=1
# Specific configuration optimized for the Hugging Face Compute Cluster
# Be ye warned this may not work on other clusters!
export NCCL_PROTO=simple
export RDMAV_FORK_SAFE=1
export FI_EFA_FORK_SAFE=1
export FI_EFA_USE_DEVICE_RDMA=1
export FI_PROVIDER=efa
export FI_LOG_LEVEL=1
export NCCL_IB_DISABLE=1
export NCCL_SOCKET_IFNAME=ens
# srun error handling:
# --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks
# --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code
SRUN_ARGS=" \
--wait=60 \
--kill-on-bad-exit=1 \
"
clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER --role \$SLURMD_NODENAME: $CMD" 2>&1
echo "END TIME: $(date)"
View File
View File
View File
+29
View File
@@ -0,0 +1,29 @@
# Instructions to Replicate Zephyr 7B
As described in the Zephyr [technical report](https://huggingface.co/papers/2310.16944), training this model proceeds in two steps:
1. Apply SFT to fine-tune Mistral 7B on a filtered version of the UltraChat dataset ([link](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)).
2. Align the SFT model to AI feedback via DPO on a preprocessed version of the UltraFeedback dataset ([link](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)).
See below for commands to train these models using either DeepSpeed ZeRO-3 or LoRA.
## Full training examples
```shell
# Step 1 - SFT
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_sft.py recipes/zephyr-7b-beta/sft/config_full.yaml
# Step 2 - DPO
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_dpo.py recipes/zephyr-7b-beta/beta/beta/beta/dpo/config_full.yaml
```
## LoRA training examples
```shell
# Step 1 - SFT
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_sft.py recipes/zephyr-7b-beta/beta/sft/config_lora.yaml
# Step 2 - DPO
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_dpo.py recipes/zephyr-7b-beta/dpo/config_lora.yaml
```
@@ -0,0 +1,37 @@
# Model arguments
model_name_or_path: alignment-handbook/zephyr-7b-sft-full
# Data training arguments
# For definitions, see: src/h4/training/config.py
dataset_mixer:
HuggingFaceH4/ultrafeedback_binarized: 1.0
dataset_splits:
- train_prefs
- test_prefs
preprocessing_num_workers: 12
# DPOTrainer arguments
bf16: true
beta: 0.1
do_eval: true
evaluation_strategy: steps
eval_steps: 100
gradient_accumulation_steps: 1
gradient_checkpointing: true
hub_model_id: zephyr-7b-dpo-full
learning_rate: 5.0e-7
log_level: info
logging_steps: 10
lr_scheduler_type: linear
max_length: 1024
max_prompt_length: 512
num_train_epochs: 3
optim: rmsprop
output_dir: data/zephyr-7b-dpo-full
per_device_train_batch_size: 8
per_device_eval_batch_size: 4
push_to_hub: true
save_strategy: "no"
save_total_limit: null
seed: 42
warmup_ratio: 0.1
@@ -0,0 +1,51 @@
# Model arguments
model_name_or_path: alignment-handbook/zephyr-7b-sft-lora
torch_dtype: auto
# LoRA arguments
use_peft: true
lora_r: 64
lora_alpha: 16
lora_dropout: 0.1
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
# Data training arguments
dataset_mixer:
HuggingFaceH4/ultrafeedback_binarized: 1.0
dataset_splits:
- train_prefs
- test_prefs
preprocessing_num_workers: 12
# DPOTrainer arguments
bf16: true
beta: 0.1
do_eval: true
evaluation_strategy: epoch
eval_steps: 100
gradient_accumulation_steps: 32
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
hub_model_id: zephyr-7b-dpo-lora
learning_rate: 5.0e-7
log_level: info
logging_steps: 10
lr_scheduler_type: linear
max_length: 1024
max_prompt_length: 512
num_train_epochs: 3
optim: rmsprop
output_dir: data/zephyr-7b-dpo-lora # It is handy to append `hub_model_revision` to keep track of your local experiments
per_device_train_batch_size: 2
per_device_eval_batch_size: 4
push_to_hub: true
save_strategy: "no"
save_total_limit: null
seed: 42
warmup_ratio: 0.1
@@ -0,0 +1,42 @@
# Model arguments
model_name_or_path: mistralai/Mistral-7B-v0.1
model_revision: main
torch_dtype: bfloat16
use_flash_attention_2: true
# Data training arguments
dataset_mixer:
HuggingFaceH4/ultrachat_200k: 1.0
dataset_splits:
- train_sft
- test_sft
preprocessing_num_workers: 12
# SFT trainer config
bf16: true
do_eval: true
evaluation_strategy: epoch
gradient_accumulation_steps: 2
gradient_checkpointing: true
hub_model_id: zephyr-7b-sft-full
hub_strategy: every_save
learning_rate: 2.0e-05
log_level: info
logging_steps: 5
logging_strategy: steps
lr_scheduler_type: cosine
max_seq_length: 2048
max_steps: -1
num_train_epochs: 1
output_dir: data/zephyr-7b-sft-full
overwrite_output_dir: true
per_device_eval_batch_size: 16
per_device_train_batch_size: 32
push_to_hub: true
remove_unused_columns: true
report_to:
- tensorboard
save_strategy: "no"
save_total_limit: null
seed: 42
tf32: true
@@ -0,0 +1,52 @@
# Model arguments
model_name_or_path: mistralai/Mistral-7B-v0.1
torch_dtype: auto
use_flash_attention_2: true
# LoRA arguments
use_peft: true
lora_r: 64
lora_alpha: 16
lora_dropout: 0.1
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
# Data training arguments
dataset_mixer:
HuggingFaceH4/ultrachat_200k: 1.0
dataset_splits:
- train_sft
- test_sft
preprocessing_num_workers: 12
# SFT trainer config
bf16: true
do_eval: true
evaluation_strategy: epoch
gradient_accumulation_steps: 128
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
hub_model_id: zephyr-7b-sft-lora
hub_strategy: every_save
learning_rate: 2.0e-05
log_level: info
logging_steps: 5
logging_strategy: steps
lr_scheduler_type: cosine
max_seq_length: 2048
max_steps: -1
num_train_epochs: 1
output_dir: data/zephyr-7b-sft-lora
overwrite_output_dir: true
per_device_eval_batch_size: 8
per_device_train_batch_size: 4
push_to_hub: true
report_to:
- tensorboard
save_strategy: "no"
save_total_limit: null
seed: 42
+105
View File
@@ -0,0 +1,105 @@
# Scripts to Train and Evaluate Chat Models
## Fine-tuning
In the handbook, we provide three main ways to align LLMs for chat:
- Full fine-tuning on a multi-GPU machine with DeepSpeed ZeRO-3 (tested on an 8 x A100 (80GB) node).
- LoRA or QLoRA fine-tuning on a single consumer 24GB GPU (tested on a RTX 4090).
- LoRA fine-tuning on a multi-GPU machine with DeepSpeed ZeRO-3 (tested on a 2 x A100s (80GB)).
In practice, we find comparable performance for both full and LoRA fine-tuning, with the latter having the advantage of producing small adapter weights that are fast to upload and download from the Hugging Face Hub. Here's the two general commands to fine-tune your models:
```shell
# Full training with ZeRO-3 on 8 GPUs
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_{task}.py recipes/{model_name}/{task}/config_full.yaml
# LoRA training on a single GPU
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_{task}.py recipes/{model_name}/{task}/config_lora.yaml
# QLoRA 4-bit training on a single GPU
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_{task}.py recipes/{model_name}/{task}/config_lora.yaml --load_in_4bit=true
# LoRA training with ZeRO-3 on two or more GPUs
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml --num_processes={num_gpus} scripts/run_{task}.py recipes/{model_name}/{task}/config_lora.yaml
```
Here `{task}` refers to type of training you wish to run (SFT, DPO, etc), while `{model_name}` refers to the choice of recipe in the `recipes` directory. For example, to replicate Zephyr-7B-β you can run:
```shell
# Step 1 - train SFT policy
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_sft.py recipes/zephyr-7b-beta/sft/config_full.yaml
# Step 2 - align with DPO
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_dpo.py recipes/zephyr-7b-beta/dpo/config_full.yaml
```
** 💡 Tip:** If you scale up/down the number of GPUs, we recommend also scaling up the per-device batch size or number of gradient accumulation steps to keep the global batch size constant (and thus replicate our results).
By default, these scripts will push each model to your Hugging Face Hub username, i.e. `{username}/{model_name}-{task}`. You can override the parameters in each YAML config by appending them to the command as follows:
```shell
# Change batch size, number of epochs etc
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_{task}.py recipes/{model_name}/{task}/config_full.yaml --per_device_train_batch_size=42 --num_train_epochs=5
```
By default all training metrics are logged with TensorBoard. If you have a [Weights and Biases](https://wandb.ai/site) account and are logged in, you can view the training metrics by appending `--report_to=wandb`, e.g.
```shell
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_{task}.py recipes/{model_name}/{task}/config_full.yaml --report_to=wandb
```
## Launching jobs on a Slurm cluster
If you have access to a Slurm cluster, we provide a `recipes/launch.slurm` script that will automatically queue training jobs for you. Here's how you can use it:
```shell
sbatch --job-name=handbook_{task} --nodes=1 recipes/launch.slurm {model_name} {task} {precision} {accelerator}
```
Here `{model_name}` and `{task}` are defined as above, while `{precision}` refers to the type of training (full vs LoRA) and `{accelerator}` refers to the choice of 🤗 Accelerate config in `recipes/accelerate_configs`. If you wish to override the default config parameters, you can provide them by appending a space-separated string like `'--arg1=value1 --arg2=value2'. Here's a concrete example to run SFT on 1 node of 8 GPUs:
```shell
# Launch on Slurm and override default hyperparameters
sbatch --job-name=handbook_sft --nodes=1 recipes/launch.slurm zephyr-7b-beta sft full deepspeed_zero3 '--per_device_train_batch_size=42 --num_train_epochs=5'
```
You can scale the number of nodes by increasing the `--nodes` flag.
**⚠️ Note:** the configuration in `recipes/launch.slurm` is optimised for the Hugging Face Compute Cluster and may require tweaking to be adapted to your own compute nodes.
## Fine-tuning on your datasets
Under the hood, each training script uses the `get_datasets()` function which allows one to easily combing multiple datasets with varying proportions. For instance, this is how one can specify multiple datasets and which splits to combine in one of the YAML configs:
```yaml
datasets_mixer:
dataset_1: 0.5 # Use 50% of the training examples
dataset_2: 0.66 # Use 66% of the training examples
dataset_3: 0.10 # Use 10% of the training examples
dataset_splits:
- train_xxx # The training splits to mix
- test_xxx # The test splits to mix
```
If you want to fine-tune on your own datasets, the main thing to keep in mind is how the chat templates are applied to the dataset blend. Since each task (SFT, DPO, etc), requires a different format, we assume the datasets have the following columns:
**SFT**
* `messages`: A list of `dicts` in the form `{"role": "{role}", "content": {content}}`.
* See [ultrachat_200k](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k) for an example.
**DPO**
* `chosen`: A list of `dicts` in the form `{"role": "{role}", "content": {content}}` corresponding to the preferred dialogue.
* `rejected`: A list of `dicts` in the form `{"role": "{role}", "content": {content}}` corresponding to the dispreferred dialogue.
* See [ultrafeedback_binarized](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized) for an example.
We also find it useful to include dedicated splits per task in our datasets, so e.g. we have:
* `{train,test}_sft`: Splits for SFT training.
* `{train,test}_gen`: Splits for generation ranking like rejection sampling or PPO.
* `{train,test}_prefs`: Splits for preference modelling, like reward modelling or DPO.
If you format your dataset in the same way, our training scripts should work out of the box!
+224
View File
@@ -0,0 +1,224 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import sys
import torch
import transformers
from transformers import AutoModelForCausalLM, set_seed
from accelerate import Accelerator
from alignment import (
DataArguments,
DPOConfig,
H4ArgumentParser,
ModelArguments,
apply_chat_template,
get_datasets,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
get_tokenizer,
is_adapter_model,
)
from peft import PeftConfig, PeftModel
from trl import DPOTrainer
logger = logging.getLogger(__name__)
def main():
parser = H4ArgumentParser((ModelArguments, DataArguments, DPOConfig))
model_args, data_args, training_args = parser.parse()
#######
# Setup
#######
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
# Log on each process the small summary:
logger.info(f"Model parameters {model_args}")
logger.info(f"Data parameters {data_args}")
logger.info(f"Training/evaluation parameters {training_args}")
# Set seed for reproducibility
set_seed(training_args.seed)
# Increase distributed timeout to 3h to enable push to Hub to complete
accelerator = Accelerator()
###############
# Load datasets
###############
raw_datasets = get_datasets(data_args, splits=data_args.dataset_splits)
logger.info(
f"Training on the following splits: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
)
column_names = list(raw_datasets["train"].features)
#####################################
# Load tokenizer and process datasets
#####################################
data_args.truncation_side = "left" # Truncate from left to ensure we don't lose labels in final turn
tokenizer = get_tokenizer(model_args, data_args)
#####################
# Apply chat template
#####################
raw_datasets = raw_datasets.map(
apply_chat_template,
fn_kwargs={"tokenizer": tokenizer, "task": "dpo"},
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
desc="Formatting comparisons with prompt template",
)
# Replace column names with what TRL needs, text_chosen -> chosen and text_rejected -> rejected
for split in ["train", "test"]:
raw_datasets[split] = raw_datasets[split].rename_columns(
{"text_prompt": "prompt", "text_chosen": "chosen", "text_rejected": "rejected"}
)
torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
model_kwargs = dict(
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
use_flash_attention_2=model_args.use_flash_attention_2,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map(),
quantization_config=get_quantization_config(model_args),
)
model = model_args.model_name_or_path
if is_adapter_model(model, model_args.model_revision):
# load the model, merge the adapter weights and unload the adapter
# Note: to run QLora, you will need to merge the based model separately as the merged model in 16bit
logger.info(f"Merging peft adapters for {model_args.model_name_or_path=}")
peft_config = PeftConfig.from_pretrained(model_args.model_name_or_path, revision=model_args.model_revision)
model_kwargs = dict(
revision=model_args.base_model_revision,
trust_remote_code=model_args.trust_remote_code,
use_flash_attention_2=model_args.use_flash_attention_2,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
)
base_model = AutoModelForCausalLM.from_pretrained(
peft_config.base_model_name_or_path,
**model_kwargs,
)
model = PeftModel.from_pretrained(
base_model, model_args.model_name_or_path, revision=model_args.model_revision
)
model.eval()
model = model.merge_and_unload()
model_kwargs = None
ref_model = model
ref_model_kwargs = model_kwargs
if model_args.use_peft is True:
ref_model = None
ref_model_kwargs = None
#########################
# Instantiate DPO trainer
#########################
dpo_trainer = DPOTrainer(
model,
ref_model,
model_init_kwargs=model_kwargs,
ref_model_init_kwargs=ref_model_kwargs,
args=training_args,
beta=training_args.beta,
train_dataset=raw_datasets["train"],
eval_dataset=raw_datasets["test"],
tokenizer=tokenizer,
max_length=training_args.max_length,
max_prompt_length=training_args.max_prompt_length,
peft_config=get_peft_config(model_args),
)
###############
# Training loop
###############
train_result = dpo_trainer.train()
metrics = train_result.metrics
max_train_samples = (
data_args.max_train_samples if data_args.max_train_samples is not None else len(raw_datasets["train"])
)
metrics["train_samples"] = min(max_train_samples, len(raw_datasets["train"]))
dpo_trainer.log_metrics("train", metrics)
dpo_trainer.save_metrics("train", metrics)
dpo_trainer.save_state()
logger.info("*** Training complete ***")
##########
# Evaluate
##########
if training_args.do_eval:
logger.info("*** Evaluate ***")
metrics = dpo_trainer.evaluate()
max_eval_samples = (
data_args.max_eval_samples if data_args.max_eval_samples is not None else len(raw_datasets["test"])
)
metrics["eval_samples"] = min(max_eval_samples, len(raw_datasets["test"]))
dpo_trainer.log_metrics("eval", metrics)
dpo_trainer.save_metrics("eval", metrics)
##################################
# Save model and create model card
##################################
dpo_trainer.save_model(training_args.output_dir)
# Save everything else on main process
if accelerator.is_main_process:
kwargs = {
"finetuned_from": model_args.model_name_or_path,
"dataset": list(data_args.dataset_mixer.keys()),
"dataset_tags": list(data_args.dataset_mixer.keys()),
"tags": ["alignment-handbook"],
}
dpo_trainer.create_model_card(**kwargs)
# Restore k,v cache for fast inference
dpo_trainer.model.config.use_cache = True
dpo_trainer.model.config.save_pretrained(training_args.output_dir)
if training_args.push_to_hub is True:
dpo_trainer.push_to_hub()
# Ensure we don't timeout on model save / push to Hub
logger.info("*** Waiting for all processes to finish ***")
accelerator.wait_for_everyone()
logger.info("*** Run complete! ***")
if __name__ == "__main__":
main()
+191
View File
@@ -0,0 +1,191 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Supervised fine-tuning script for decoder language models.
"""
import logging
import random
import sys
import datasets
import torch
import transformers
from transformers import set_seed
from accelerate import Accelerator
from alignment import (
DataArguments,
H4ArgumentParser,
ModelArguments,
SFTConfig,
apply_chat_template,
get_datasets,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
get_tokenizer,
)
from trl import SFTTrainer
logger = logging.getLogger(__name__)
def main():
parser = H4ArgumentParser((ModelArguments, DataArguments, SFTConfig))
model_args, data_args, training_args = parser.parse()
# Set seed for reproducibility
set_seed(training_args.seed)
accelerator = Accelerator()
###############
# Setup logging
###############
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
# Log on each process a small summary
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Model parameters {model_args}")
logger.info(f"Data parameters {data_args}")
logger.info(f"Training/evaluation parameters {training_args}")
###############
# Load datasets
###############
raw_datasets = get_datasets(data_args, splits=data_args.dataset_splits)
logger.info(
f"Training on the following datasets and their proportions: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
)
################
# Load tokenizer
################
tokenizer = get_tokenizer(model_args, data_args)
#####################
# Apply chat template
#####################
raw_datasets = raw_datasets.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer, "task": "sft"})
train_dataset = raw_datasets["train"]
eval_dataset = raw_datasets["test"]
with training_args.main_process_first(desc="Log a few random samples from the processed training set"):
for index in random.sample(range(len(raw_datasets["train"])), 3):
logger.info(f"Sample {index} of the processed training set:\n\n{raw_datasets['train'][index]['text']}")
#######################
# Load pretrained model
#######################
logger.info("*** Load pretrained model ***")
torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
model_kwargs = dict(
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
use_flash_attention_2=model_args.use_flash_attention_2,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map(),
quantization_config=get_quantization_config(model_args),
)
logger.info("*** Model loaded! ***")
########################
# Initialize the Trainer
########################
trainer = SFTTrainer(
model=model_args.model_name_or_path,
model_init_kwargs=model_kwargs,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
dataset_text_field="text",
max_seq_length=training_args.max_seq_length,
tokenizer=tokenizer,
packing=True,
peft_config=get_peft_config(model_args),
)
###############
# Training loop
###############
logger.info("*** Train ***")
train_result = trainer.train()
metrics = train_result.metrics
max_train_samples = data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
##########
# Evaluate
##########
if training_args.do_eval:
logger.info("*** Evaluate ***")
metrics = trainer.evaluate()
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
##################################
# Save model and create model card
##################################
logger.info("*** Save model ***")
trainer.save_model(training_args.output_dir)
logger.info(f"Model saved to {training_args.output_dir}")
# Save everything else on main process
if accelerator.is_main_process:
kwargs = {
"finetuned_from": model_args.model_name_or_path,
"dataset": list(data_args.dataset_mixer.keys()),
"dataset_tags": list(data_args.dataset_mixer.keys()),
"tags": ["alignment-handbook"],
}
trainer.create_model_card(**kwargs)
# Restore k,v cache for fast inference
trainer.model.config.use_cache = True
trainer.model.config.save_pretrained(training_args.output_dir)
if training_args.push_to_hub is True:
logger.info("Pushing to hub...")
trainer.push_to_hub()
accelerator.wait_for_everyone()
if __name__ == "__main__":
main()
+14 -10
View File
@@ -42,28 +42,30 @@ if stale_egg_info.exists():
# * If a dependency is fast-moving (e.g. transformers), pin to the exact version
_deps = [
"accelerate==0.23.0",
"bitsandbytes==0.41.1",
"bitsandbytes==0.41.2.post2",
"black==23.1.0",
"datasets==2.12.0",
"deepspeed==0.9.5",
"einops==0.6.1",
"datasets==2.14.6",
"deepspeed==0.12.2",
"einops>=0.6.1",
"evaluate==0.4.0",
"flake8>=6.0.0",
"hf-doc-builder>=0.4.0",
"huggingface-hub>=0.14.1,<1.0",
"isort>=5.12.0",
"ninja==1.11.1",
"ninja>=1.11.1",
"numpy>=1.24.2",
"packaging>=23.0",
"parameterized>=0.9.0",
"peft==0.5.0",
"peft==0.6.1",
"protobuf<=3.20.2", # Needed to avoid conflicts with `transformers`
"pytest",
"safetensors==0.3.3",
"safetensors>=0.3.3",
"scipy",
"tensorboard",
"torch==2.0.1",
"transformers @ git+https://github.com/huggingface/transformers.git@b3961f7291307ee877ef1a4d057949597d805220",
"trl @ git+https://github.com/huggingface/trl.git@1e56ff0f166888973d69cd9d56be60a9f8edfedb", # TODO bump to next release, added for NEFTune
"torch==2.1.0",
"transformers==4.35.0",
"trl==0.7.4",
"jinja2>=3.0.0",
"tqdm>=4.64.1",
]
@@ -96,12 +98,14 @@ install_requires = [
deps["datasets"],
deps["deepspeed"],
deps["huggingface-hub"],
deps["jinja2"],
deps["ninja"],
deps["numpy"],
deps["packaging"], # utilities from PyPA to e.g., compare versions
deps["peft"],
deps["protobuf"],
deps["safetensors"],
deps["scipy"],
deps["tensorboard"],
deps["tqdm"], # progress bars in model download and training scripts
deps["transformers"],
+4
View File
@@ -1 +1,5 @@
__version__ = "0.2.0.dev0"
from .configs import DataArguments, DPOConfig, H4ArgumentParser, ModelArguments, SFTConfig
from .data import apply_chat_template, get_datasets
from .model_utils import get_kbit_device_map, get_peft_config, get_quantization_config, get_tokenizer, is_adapter_model
+272
View File
@@ -0,0 +1,272 @@
# coding=utf-8
# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import os
import sys
from dataclasses import dataclass, field
from typing import Any, Dict, List, NewType, Optional, Tuple
import transformers
from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, HfArgumentParser
MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
DataClassType = NewType("DataClassType", Any)
class H4ArgumentParser(HfArgumentParser):
def parse_yaml_and_args(self, yaml_arg: str, other_args: Optional[List[str]] = None) -> List[dataclass]:
"""
Parse a YAML file and overwrite the default/loaded values with the values provided to the command line.
Args:
yaml_arg (`str`):
The path to the config file used
other_args (`List[str]`, *optional`):
A list of strings to parse as command line arguments, e.g. ['--arg=val', '--arg2=val2'].
Returns:
[`List[dataclass]`]: a list of dataclasses with the values from the YAML file and the command line
"""
arg_list = self.parse_yaml_file(os.path.abspath(yaml_arg))
outputs = []
# strip other args list into dict of key-value pairs
other_args = {arg.split("=")[0].strip("-"): arg.split("=")[1] for arg in other_args}
used_args = {}
# overwrite the default/loaded value with the value provided to the command line
# adapted from https://github.com/huggingface/transformers/blob/d0b5002378daabf62769159add3e7d66d3f83c3b/src/transformers/hf_argparser.py#L327
for data_yaml, data_class in zip(arg_list, self.dataclass_types):
keys = {f.name for f in dataclasses.fields(data_yaml) if f.init}
inputs = {k: v for k, v in vars(data_yaml).items() if k in keys}
for arg, val in other_args.items():
# add only if in keys
if arg in keys:
base_type = data_yaml.__dataclass_fields__[arg].type
inputs[arg] = val
# cast type for ints, floats (default to strings)
if base_type in [int, float]:
inputs[arg] = base_type(val)
if base_type == List[str]:
inputs[arg] = [str(v) for v in val.split(",")]
# bool of a non-empty string is True, so we manually check for bools
if base_type == bool:
if val in ["true", "True"]:
inputs[arg] = True
else:
inputs[arg] = False
# add to used-args so we can check if double add
if arg not in used_args:
used_args[arg] = val
else:
raise ValueError(f"Duplicate argument provided: {arg}, may cause unexpected behavior")
obj = data_class(**inputs)
outputs.append(obj)
return outputs
def parse(self) -> DataClassType | Tuple[DataClassType]:
if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
# If we pass only one argument to the script and it's the path to a YAML file,
# let's parse it to get our arguments.
output = self.parse_yaml_file(os.path.abspath(sys.argv[1]))
# parse command line args and yaml file
elif len(sys.argv) > 2 and sys.argv[1].endswith(".yaml"):
output = self.parse_yaml_and_args(os.path.abspath(sys.argv[1]), sys.argv[2:])
# parse command line args only
else:
output = self.parse_args_into_dataclasses()
if len(output) == 1:
output = output[0]
return output
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
"""
base_model_revision: Optional[str] = field(
default=None,
metadata={"help": ("The base model checkpoint for weights initialization with PEFT adatpers.")},
)
model_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": (
"The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
)
},
)
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
model_code_revision: str = field(default=None, metadata={"help": "The branch of the IFT model"})
torch_dtype: Optional[str] = field(
default=None,
metadata={
"help": (
"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
"dtype will be automatically derived from the model's weights."
),
"choices": ["auto", "bfloat16", "float16", "float32"],
},
)
trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."})
use_flash_attention_2: bool = field(
default=False,
metadata={
"help": (
"Whether to use flash attention 2. You must install this manually by running `pip install flash-attn --no-build-isolation`"
)
},
)
use_peft: bool = field(
default=False,
metadata={"help": ("Whether to use PEFT or not for training.")},
)
lora_r: Optional[int] = field(
default=16,
metadata={"help": ("LoRA R value.")},
)
lora_alpha: Optional[int] = field(
default=32,
metadata={"help": ("LoRA alpha.")},
)
lora_dropout: Optional[float] = field(
default=0.05,
metadata={"help": ("LoRA dropout.")},
)
lora_target_modules: Optional[List[str]] = field(
default=None,
metadata={"help": ("LoRA target modules.")},
)
lora_modules_to_save: Optional[List[str]] = field(
default=None,
metadata={"help": ("Model layers to unfreeze & train")},
)
load_in_8bit: bool = field(default=False, metadata={"help": "use 8 bit precision"})
load_in_4bit: bool = field(default=False, metadata={"help": "use 4 bit precision"})
bnb_4bit_quant_type: Optional[str] = field(
default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"}
)
use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"})
def __post_init__(self):
if self.load_in_8bit and self.load_in_4bit:
raise ValueError("You can't use 8 bit and 4 bit precision at the same time")
@dataclass
class DataArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
chat_template: Optional[str] = field(default=None, metadata={"help": "The chat template to use."})
dataset_mixer: Optional[Dict[str, float]] = field(
default=None,
metadata={"help": ("Datasets and their proportions to be used for training ift/rl.")},
)
dataset_splits: Optional[List[str]] = field(
default_factory=lambda: ["train", "test"],
metadata={"help": ("List of train test splits to use in the dataset")},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": (
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
)
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": (
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
)
},
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
truncation_side: Optional[str] = field(
default=None, metadata={"help": "Truncation side to use for the tokenizer."}
)
@dataclass
class SFTConfig(transformers.TrainingArguments):
"""
Arguments related to the training process itself. For all parameters, see: https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/trainer#transformers.TrainingArguments
"""
max_seq_length: Optional[int] = field(
default=None,
metadata={"help": ("Used by TRL for reward model training, which tries to read this parameter in init.")},
)
logging_first_step: bool = field(
default=True,
metadata={"help": ("Whether to log and evaluate the first global_step or not.")},
)
optim: Optional[str] = field(default="adamw_torch")
@dataclass
class DPOConfig(transformers.TrainingArguments):
"""
Arguments related to the DPO training process itself. For all parameters, see: https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/trainer#transformers.TrainingArguments
"""
beta: Optional[float] = field(
default=0.1,
metadata={"help": "The beta factor in DPO loss. Higher beta means less divergence from the initial policy."},
)
hub_model_revision: Optional[str] = field(
default="main",
metadata={"help": ("The Hub model branch to push the model to.")},
)
logging_first_step: bool = field(
default=True,
metadata={"help": ("Whether to log and evaluate the first global_step or not.")},
)
max_prompt_length: Optional[int] = field(
default=None,
metadata={"help": ("For DPO, the maximum length of the prompt to use for conditioning the model.")},
)
max_length: Optional[int] = field(
default=None,
metadata={"help": ("Used by TRL for reward model training, which tries to read this parameter in init.")},
)
optim: Optional[str] = field(default="rmsprop")
remove_unused_columns: bool = field(default=False)
+187
View File
@@ -0,0 +1,187 @@
# coding=utf-8
# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import List, Literal, Optional
from datasets import DatasetDict, concatenate_datasets, load_dataset
from .configs import DataArguments
DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
def apply_chat_template(
example, tokenizer, task: Literal["sft", "generation", "rm", "dpo"] = "sft", assistant_prefix="<|assistant|>\n"
):
def _strip_prefix(s, pattern):
# Use re.escape to escape any special characters in the pattern
return re.sub(f"^{re.escape(pattern)}", "", s)
if task in ["sft", "generation"]:
messages = example["messages"]
# We add an empty system message if there is none
if messages[0]["role"] != "system":
messages.insert(0, {"role": "system", "content": ""})
example["text"] = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True if task == "generation" else False
)
elif task == "rm":
if all(k in example.keys() for k in ("chosen", "rejected")):
chosen_messages = example["chosen"]
rejected_messages = example["rejected"]
# We add an empty system message if there is none
if chosen_messages[0]["role"] != "system":
chosen_messages.insert(0, {"role": "system", "content": ""})
if rejected_messages[0]["role"] != "system":
rejected_messages.insert(0, {"role": "system", "content": ""})
example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
else:
raise ValueError(
f"Could not format example as dialogue for `rm` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
)
elif task == "dpo":
if all(k in example.keys() for k in ("chosen", "rejected")):
# Compared to reward modeling, we filter out the prompt, so the text is everything after the last assistant token
prompt_messages = [[msg for msg in example["chosen"] if msg["role"] == "user"][0]]
# Insert system message
if example["chosen"][0]["role"] != "system":
prompt_messages.insert(0, {"role": "system", "content": ""})
else:
prompt_messages.insert(0, example["chosen"][0])
# TODO: handle case where chosen/rejected also have system messages
chosen_messages = example["chosen"][1:]
rejected_messages = example["rejected"][1:]
example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
example["text_prompt"] = tokenizer.apply_chat_template(
prompt_messages, tokenize=False, add_generation_prompt=True
)
example["text_chosen"] = _strip_prefix(example["text_chosen"], assistant_prefix)
example["text_rejected"] = _strip_prefix(example["text_rejected"], assistant_prefix)
else:
raise ValueError(
f"Could not format example as dialogue for `dpo` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
)
return example
def get_datasets(
data_config: DataArguments | dict,
splits: List[str] = ["train", "test"],
shuffle: bool = True,
) -> DatasetDict:
"""
Loads one or more datasets with varying training set proportions.
Args:
data_config (`DataArguments` or `dict`):
Dataset configuration and split proportions.
splits (`List[str]`, *optional*, defaults to `['train', 'test']`):
Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix.
shuffle (`bool`, *optional*, defaults to `True`):
Whether to shuffle the training data.
Returns
[`DatasetDict`]: The dataset dictionary containing the loaded datasets.
"""
if type(data_config) is DataArguments:
# Structure of the config to read the datasets and their mix
# datasets_mixer:
# - 'dataset1': 0.5
# - 'dataset2': 0.3
# - 'dataset3': 0.2
dataset_mixer = data_config.dataset_mixer
elif type(data_config) is dict:
# Structure of the input is:
# dataset_mixer = {
# "dataset1": 0.5,
# "dataset1": 0.3,
# "dataset1": 0.2,
# }
dataset_mixer = data_config
else:
raise ValueError(f"Data config {data_config} not recognized.")
raw_datasets = mix_datasets(dataset_mixer, splits=splits, shuffle=shuffle)
return raw_datasets
def mix_datasets(dataset_mixer: dict, splits: Optional[List[str]] = None, shuffle=True) -> DatasetDict:
"""
Loads and mixes datasets according to proportions specified in `dataset_mixer`.
Args:
dataset_mixer (`dict`):
Dictionary containing the dataset names and their training proportions. By default, all test proportions are 1.
splits (Optional[List[str]], *optional*, defaults to `None`):
Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix.
shuffle (`bool`, *optional*, defaults to `True`):
Whether to shuffle the training data.
"""
raw_datasets = DatasetDict()
raw_train_datasets = []
raw_val_datasets = []
fracs = []
for ds, frac in dataset_mixer.items():
fracs.append(frac)
for split in splits:
if "train" in split:
raw_train_datasets.append(
load_dataset(
ds,
split=split,
)
)
elif "test" in split:
raw_val_datasets.append(
load_dataset(
ds,
split=split,
)
)
else:
raise ValueError(f"Split type {split} not recognized as one of test or train.")
if any(frac < 0 for frac in fracs):
raise ValueError("Dataset fractions cannot be negative.")
if len(raw_train_datasets) > 0:
train_subsets = []
for dataset, frac in zip(raw_train_datasets, fracs):
train_subset = dataset.select(range(int(frac * len(dataset))))
train_subsets.append(train_subset)
if shuffle:
raw_datasets["train"] = concatenate_datasets(train_subsets).shuffle(seed=42)
else:
raw_datasets["train"] = concatenate_datasets(train_subsets)
# No subsampling for test datasets to enable fair comparison across models
if len(raw_val_datasets) > 0:
if shuffle:
raw_datasets["test"] = concatenate_datasets(raw_val_datasets).shuffle(seed=42)
else:
raw_datasets["test"] = concatenate_datasets(raw_val_datasets)
if len(raw_datasets) == 0:
raise ValueError(
f"Dataset {dataset_mixer} not recognized with split {split}. Check the dataset has been correctly formatted."
)
return raw_datasets
+101
View File
@@ -0,0 +1,101 @@
# coding=utf-8
# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
import torch
from transformers import AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizer
from accelerate import Accelerator
from huggingface_hub import list_repo_files
from peft import LoraConfig, PeftConfig
from .configs import DataArguments, ModelArguments
from .data import DEFAULT_CHAT_TEMPLATE
def get_current_device() -> int:
"""Get the current device. For GPU we return the local process index to enable multiple GPU training."""
return Accelerator().local_process_index if torch.cuda.is_available() else "cpu"
def get_kbit_device_map() -> Dict[str, int] | None:
"""Useful for running inference with quantized models by setting `device_map=get_peft_device_map()`"""
return {"": get_current_device()} if torch.cuda.is_available() else None
def get_quantization_config(model_args) -> BitsAndBytesConfig | None:
if model_args.load_in_4bit:
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16, # For consistency with model weights, we use the same value as `torch_dtype` which is float16 for PEFT models
bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
)
elif model_args.load_in_8bit:
quantization_config = BitsAndBytesConfig(
load_in_8bit=True,
)
else:
quantization_config = None
return quantization_config
def get_tokenizer(model_args: ModelArguments, data_args: DataArguments) -> PreTrainedTokenizer:
"""Get the tokenizer for the model."""
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
revision=model_args.model_revision,
)
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
if data_args.truncation_side is not None:
tokenizer.truncation_side = data_args.truncation_side
# Set reasonable default for models without max length
if tokenizer.model_max_length > 100_000:
tokenizer.model_max_length = 2048
if data_args.chat_template is not None:
tokenizer.chat_template = data_args.chat_template
elif tokenizer.chat_template is None:
tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
return tokenizer
def get_peft_config(model_args: ModelArguments) -> PeftConfig | None:
if model_args.use_peft is False:
return None
peft_config = LoraConfig(
r=model_args.lora_r,
lora_alpha=model_args.lora_alpha,
lora_dropout=model_args.lora_dropout,
bias="none",
task_type="CAUSAL_LM",
target_modules=model_args.lora_target_modules,
modules_to_save=model_args.lora_modules_to_save,
)
return peft_config
def is_adapter_model(model_name_or_path: str, revision: str = "main") -> bool:
repo_files = list_repo_files(model_name_or_path, revision=revision)
return "adapter_model.safetensors" in repo_files or "adapter_model.bin" in repo_files
+37
View File
@@ -0,0 +1,37 @@
# Model arguments
model_name_or_path: alignment-handbook/zephyr-7b-sft-full
# Data training arguments
# For definitions, see: src/h4/training/config.py
dataset_mixer:
HuggingFaceH4/ultrafeedback_binarized: 1.0
dataset_splits:
- train_prefs
- test_prefs
preprocessing_num_workers: 12
# DPOTrainer arguments
bf16: true
beta: 0.1
do_eval: true
evaluation_strategy: steps
eval_steps: 100
gradient_accumulation_steps: 1
gradient_checkpointing: true
hub_model_id: zephyr-7b-dpo-full
learning_rate: 5.0e-7
log_level: info
logging_steps: 10
lr_scheduler_type: linear
max_length: 1024
max_prompt_length: 512
num_train_epochs: 3
optim: rmsprop
output_dir: data/zephyr-7b-dpo-full
per_device_train_batch_size: 8
per_device_eval_batch_size: 4
push_to_hub: true
save_strategy: "no"
save_total_limit: null
seed: 42
warmup_ratio: 0.1
+41
View File
@@ -0,0 +1,41 @@
# Model arguments
model_name_or_path: mistralai/Mistral-7B-v0.1
model_revision: main
torch_dtype: bfloat16
use_flash_attention_2: true
# Data training arguments
dataset_mixer:
HuggingFaceH4/ultrachat_200k: 1.0
dataset_splits:
- train_sft
- test_sft
preprocessing_num_workers: 12
# SFT trainer config
bf16: true
do_eval: true
evaluation_strategy: epoch
gradient_accumulation_steps: 2
gradient_checkpointing: true
hub_model_id: zephyr-7b-sft-full
hub_strategy: every_save
learning_rate: 2.0e-05
log_level: info
logging_steps: 5
logging_strategy: steps
lr_scheduler_type: cosine
max_seq_length: 2048
max_steps: -1
num_train_epochs: 1
output_dir: data/zephyr-7b-sft-full
overwrite_output_dir: true
per_device_eval_batch_size: 16
per_device_train_batch_size: 32
push_to_hub: true
remove_unused_columns: true
report_to:
- tensorboard
save_strategy: "no"
save_total_limit: null
seed: 42
+43
View File
@@ -0,0 +1,43 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
from alignment import DataArguments, H4ArgumentParser, ModelArguments, SFTConfig
class H4ArgumentParserTest(unittest.TestCase):
def setUp(self):
self.parser = H4ArgumentParser((ModelArguments, DataArguments, SFTConfig))
self.yaml_file_path = "tests/fixtures/config_sft_full.yaml"
def test_load_yaml(self):
model_args, data_args, training_args = self.parser.parse_yaml_file(os.path.abspath(self.yaml_file_path))
self.assertEqual(model_args.model_name_or_path, "mistralai/Mistral-7B-v0.1")
def test_load_yaml_and_args(self):
command_line_args = [
"--model_name_or_path=test",
"--use_peft=true",
"--lora_r=16",
"--lora_dropout=0.5",
]
model_args, data_args, training_args = self.parser.parse_yaml_and_args(
os.path.abspath(self.yaml_file_path), command_line_args
)
self.assertEqual(model_args.model_name_or_path, "test")
self.assertEqual(model_args.use_peft, True)
self.assertEqual(model_args.lora_r, 16)
self.assertEqual(model_args.lora_dropout, 0.5)
+148
View File
@@ -0,0 +1,148 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import pytest
from datasets import Dataset
from alignment import DataArguments, ModelArguments, apply_chat_template, get_datasets, get_tokenizer
class GetDatasetsTest(unittest.TestCase):
"""Each of these test datasets has 100 examples"""
def test_loading_data_args(self):
dataset_mixer = {
"HuggingFaceH4/testing_alpaca_small": 0.5,
"HuggingFaceH4/testing_self_instruct_small": 0.3,
"HuggingFaceH4/testing_codealpaca_small": 0.2,
}
data_args = DataArguments(dataset_mixer=dataset_mixer)
datasets = get_datasets(data_args)
self.assertEqual(len(datasets["train"]), 100)
self.assertEqual(len(datasets["test"]), 300)
def test_loading_data_dict(self):
dataset_mixer = {
"HuggingFaceH4/testing_alpaca_small": 0.5,
"HuggingFaceH4/testing_self_instruct_small": 0.3,
"HuggingFaceH4/testing_codealpaca_small": 0.2,
}
datasets = get_datasets(dataset_mixer)
self.assertEqual(len(datasets["train"]), 100)
self.assertEqual(len(datasets["test"]), 300)
def test_loading_with_unit_fractions(self):
dataset_mixer = {
"HuggingFaceH4/testing_alpaca_small": 1.0,
"HuggingFaceH4/testing_self_instruct_small": 1.0,
"HuggingFaceH4/testing_codealpaca_small": 1.0,
}
datasets = get_datasets(dataset_mixer)
self.assertEqual(len(datasets["train"]), 300)
self.assertEqual(len(datasets["test"]), 300)
def test_loading_with_fractions_greater_than_unity(self):
dataset_mixer = {
"HuggingFaceH4/testing_alpaca_small": 0.7,
"HuggingFaceH4/testing_self_instruct_small": 0.4,
}
datasets = get_datasets(dataset_mixer)
self.assertEqual(len(datasets["train"]), 70 + 40)
self.assertEqual(len(datasets["test"]), 200)
def test_loading_fails_with_negative_fractions(self):
dataset_mixer = {
"HuggingFaceH4/testing_alpaca_small": 0.7,
"HuggingFaceH4/testing_self_instruct_small": -0.3,
}
with pytest.raises(ValueError, match=r"Dataset fractions cannot be negative."):
get_datasets(dataset_mixer)
def test_loading_single_split_with_unit_fractions(self):
dataset_mixer = {
"HuggingFaceH4/testing_alpaca_small": 1.0,
}
datasets = get_datasets(dataset_mixer, splits=["test"])
self.assertEqual(len(datasets["test"]), 100)
self.assertRaises(KeyError, lambda: datasets["train"])
class ApplyChatTemplateTest(unittest.TestCase):
def setUp(self):
model_args = ModelArguments(model_name_or_path="HuggingFaceH4/zephyr-7b-alpha")
data_args = DataArguments()
self.tokenizer = get_tokenizer(model_args, data_args)
self.dataset = Dataset.from_dict(
{
"prompt": ["Hello!"],
"messages": [[{"role": "user", "content": "Hello!"}, {"role": "assistant", "content": "Bonjour!"}]],
"chosen": [[{"role": "user", "content": "Hello!"}, {"role": "assistant", "content": "Bonjour!"}]],
"rejected": [[{"role": "user", "content": "Hello!"}, {"role": "assistant", "content": "Hola!"}]],
}
)
def test_sft(self):
dataset = self.dataset.map(
apply_chat_template,
fn_kwargs={"tokenizer": self.tokenizer, "task": "sft"},
remove_columns=self.dataset.column_names,
)
self.assertDictEqual(
dataset[0],
{"text": "<|system|>\n</s>\n<|user|>\nHello!</s>\n<|assistant|>\nBonjour!</s>\n"},
)
def test_generation(self):
# Remove last turn from messages
dataset = self.dataset.map(lambda x: {"messages": x["messages"][:-1]})
dataset = dataset.map(
apply_chat_template,
fn_kwargs={"tokenizer": self.tokenizer, "task": "generation"},
remove_columns=self.dataset.column_names,
)
self.assertDictEqual(
dataset[0],
{"text": "<|system|>\n</s>\n<|user|>\nHello!</s>\n<|assistant|>\n"},
)
def test_rm(self):
dataset = self.dataset.map(
apply_chat_template,
fn_kwargs={"tokenizer": self.tokenizer, "task": "rm"},
remove_columns=self.dataset.column_names,
)
self.assertDictEqual(
dataset[0],
{
"text_chosen": "<|system|>\n</s>\n<|user|>\nHello!</s>\n<|assistant|>\nBonjour!</s>\n",
"text_rejected": "<|system|>\n</s>\n<|user|>\nHello!</s>\n<|assistant|>\nHola!</s>\n",
},
)
def test_dpo(self):
dataset = self.dataset.map(
apply_chat_template,
fn_kwargs={"tokenizer": self.tokenizer, "task": "dpo"},
remove_columns=self.dataset.column_names,
)
self.assertDictEqual(
dataset[0],
{
"text_prompt": "<|system|>\n</s>\n<|user|>\nHello!</s>\n<|assistant|>\n",
"text_chosen": "Bonjour!</s>\n",
"text_rejected": "Hola!</s>\n",
},
)
+76
View File
@@ -0,0 +1,76 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from alignment import DataArguments, ModelArguments, get_peft_config, get_quantization_config, get_tokenizer
from alignment.data import DEFAULT_CHAT_TEMPLATE
class GetQuantizationConfigTest(unittest.TestCase):
def test_4bit(self):
model_args = ModelArguments(load_in_4bit=True)
quantization_config = get_quantization_config(model_args)
self.assertTrue(quantization_config.load_in_4bit)
self.assertEqual(quantization_config.bnb_4bit_compute_dtype, torch.float16)
self.assertEqual(quantization_config.bnb_4bit_quant_type, "nf4")
self.assertFalse(quantization_config.bnb_4bit_use_double_quant)
def test_8bit(self):
model_args = ModelArguments(load_in_8bit=True)
quantization_config = get_quantization_config(model_args)
self.assertTrue(quantization_config.load_in_8bit)
def test_no_quantization(self):
model_args = ModelArguments()
quantization_config = get_quantization_config(model_args)
self.assertIsNone(quantization_config)
class GetTokenizerTest(unittest.TestCase):
def setUp(self) -> None:
self.model_args = ModelArguments(model_name_or_path="HuggingFaceH4/zephyr-7b-alpha")
def test_right_truncation_side(self):
tokenizer = get_tokenizer(self.model_args, DataArguments(truncation_side="right"))
self.assertEqual(tokenizer.truncation_side, "right")
def test_left_truncation_side(self):
tokenizer = get_tokenizer(self.model_args, DataArguments(truncation_side="left"))
self.assertEqual(tokenizer.truncation_side, "left")
def test_default_chat_template(self):
tokenizer = get_tokenizer(self.model_args, DataArguments())
self.assertEqual(tokenizer.chat_template, DEFAULT_CHAT_TEMPLATE)
def test_chatml_chat_template(self):
chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
tokenizer = get_tokenizer(self.model_args, DataArguments(chat_template=chat_template))
self.assertEqual(tokenizer.chat_template, chat_template)
class GetPeftConfigTest(unittest.TestCase):
def test_peft_config(self):
model_args = ModelArguments(use_peft=True, lora_r=42, lora_alpha=0.66, lora_dropout=0.99)
peft_config = get_peft_config(model_args)
self.assertEqual(peft_config.r, 42)
self.assertEqual(peft_config.lora_alpha, 0.66)
self.assertEqual(peft_config.lora_dropout, 0.99)
def test_no_peft_config(self):
model_args = ModelArguments(use_peft=False)
peft_config = get_peft_config(model_args)
self.assertIsNone(peft_config)