* Add Gemma 7B recipe

* Use Gemma template

* Make it work for dolly lol

* Enable cache

* Clean up

* DPO to the max

* DPO, DPO, DPO

* Add openhermes

* Add custom configs

* Add kwargs

* Fix config

* Bump deps

* Move old recipes

* Add doc

* Add note

* Re-enable cache

* Nuke

* Clean

* Apply suggestions from code review

Co-authored-by: Alvaro Bartolome <alvaro@argilla.io>

* Fix isort

* Update README.md

* Update config_full.yaml

---------

Co-authored-by: Alvaro Bartolome <alvaro@argilla.io>
Co-authored-by: Philipp Schmid <32632186+philschmid@users.noreply.github.com>
lewtun
2024-03-01 17:29:42 +01:00
committed by GitHub
parent d17fd7cd3b
commit ff618a4d13
10 changed files with 150 additions and 25 deletions
+1
@@ -19,6 +19,7 @@ However, we know from the [InstructGPT](https://huggingface.co/papers/2203.02155
The Alignment Handbook aims to fill that gap by providing the community with a series of robust training recipes that span the whole pipeline.
## News 🗞️
* **March 1, 2024:** We release Zephyr 7B Gemma, which is a new recipe to align Gemma 7B with RLAIF 🔥
* **February 1, 2024:** We release a recipe to align open LLMs with Constitutional AI 📜! See the [recipe](https://github.com/huggingface/alignment-handbook/tree/main/recipes/constitutional-ai) and the [blog post](https://huggingface.co/blog/constitutional_ai) for details.
* **January 18, 2024:** We release a suite of evaluations of DPO vs KTO vs IPO, see the [recipe](recipes/pref_align_scan/README.md) and the [blog post](https://huggingface.co/blog/pref-tuning) for details.
* **November 10, 2023:** We release all the training code to replicate Zephyr-7b-β 🪁! We also release [No Robots](https://huggingface.co/datasets/HuggingFaceH4/no_robots), a brand new dataset of 10,000 instructions and demonstrations written entirely by skilled human annotators.
+21
@@ -0,0 +1,21 @@
# Instructions to Replicate Zephyr 7B Gemma
Similar to how we trained Zephyr 7B Beta in our [technical report](https://huggingface.co/papers/2310.16944), training this model proceeds in two steps:
1. Apply SFT to fine-tune Gemma 7B on the Deita 10k dataset ([link](https://huggingface.co/datasets/HuggingFaceH4/deita-10k-v0-sft)). The result is an SFT model like [`zephyr-7b-gemma-sft`](https://huggingface.co/HuggingFaceH4/zephyr-7b-gemma-sft-v0.1).
2. Align the SFT model to AI feedback via DPO on a curated mix of 7k examples by Argilla ([link](https://huggingface.co/datasets/argilla/dpo-mix-7k)). The result is a DPO model like [`zephyr-7b-gemma`](https://huggingface.co/HuggingFaceH4/zephyr-7b-gemma-v0.1).
See below for commands to train these models using DeepSpeed ZeRO-3.
## Full training examples
You will require 8 GPUs (80GB of VRAM each) to train the full model. Alternatively, you can train on 1 GPU by adjusting the micro batch size and gradient accumulation steps to keep the global batch size constant. A recipe involving QLoRA will come later 🤗.
```shell
# Step 1 - SFT
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_sft.py recipes/zephyr-7b-gemma/sft/config_full.yaml
# Step 2 - DPO
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_dpo.py recipes/zephyr-7b-gemma/dpo/config_full.yaml
```
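The global batch size mentioned above is the product of the per-device batch size, the gradient accumulation steps, and the number of GPUs. The sketch below works through that arithmetic with the values from this recipe's `config_full.yaml` files (SFT: 4 x 4 x 8, DPO: 2 x 8 x 8). The helper names `global_batch_size` and `grad_accum_for` are hypothetical and are not part of the handbook.

```python
def global_batch_size(per_device_batch_size: int, grad_accum_steps: int, num_gpus: int) -> int:
    """Number of samples that contribute to one optimizer step."""
    return per_device_batch_size * grad_accum_steps * num_gpus


def grad_accum_for(target_global: int, per_device_batch_size: int, num_gpus: int) -> int:
    """Gradient accumulation steps needed to reach a target global batch size."""
    samples_per_micro_step = per_device_batch_size * num_gpus
    if target_global % samples_per_micro_step != 0:
        raise ValueError("target global batch size is not divisible by the per-step sample count")
    return target_global // samples_per_micro_step


# Values taken from the SFT and DPO configs in this recipe, on 8 GPUs.
sft_global = global_batch_size(per_device_batch_size=4, grad_accum_steps=4, num_gpus=8)
dpo_global = global_batch_size(per_device_batch_size=2, grad_accum_steps=8, num_gpus=8)

# Moving SFT to a single GPU with the same micro batch size.
sft_single_gpu_accum = grad_accum_for(sft_global, per_device_batch_size=4, num_gpus=1)
# If memory forces a micro batch size of 1 on that single GPU.
sft_single_gpu_accum_mbs1 = grad_accum_for(sft_global, per_device_batch_size=1, num_gpus=1)

print(sft_global, dpo_global, sft_single_gpu_accum, sft_single_gpu_accum_mbs1)
```

The resulting values would then be written to `gradient_accumulation_steps` and `per_device_train_batch_size` in the config; whether a single 80GB GPU can actually hold the full model at a given micro batch size is not something this arithmetic can tell you.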
@@ -0,0 +1,42 @@
# Model arguments
model_name_or_path: HuggingFaceH4/zephyr-7b-gemma-sft-v0.1
torch_dtype: bfloat16
# Data training arguments
# For definitions, see: src/h4/training/config.py
dataset_mixer:
  argilla/dpo-mix-7k: 1.0
dataset_splits:
- train
- test
preprocessing_num_workers: 12
# DPOTrainer arguments
bf16: true
beta: 0.05
do_eval: true
evaluation_strategy: steps
eval_steps: 100
gradient_accumulation_steps: 8
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: False
hub_model_id: zephyr-7b-gemma-dpo
learning_rate: 5.0e-7
log_level: info
logging_steps: 10
lr_scheduler_type: cosine
max_length: 1024
max_prompt_length: 512
num_train_epochs: 2
optim: adamw_torch
output_dir: data/zephyr-7b-gemma-dpo
per_device_train_batch_size: 2
per_device_eval_batch_size: 4
push_to_hub: true
report_to:
- tensorboard
- wandb
save_strategy: "no"
seed: 42
warmup_ratio: 0.1
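To make the role of `beta: 0.05` in the config above more concrete, here is a minimal pure-Python sketch of the standard DPO loss for a single preference pair. It is an illustration only, not TRL's `DPOTrainer` implementation, and the log-probabilities in the example are made-up numbers.

```python
import math


def dpo_loss(
    policy_chosen_logp: float,
    policy_rejected_logp: float,
    ref_chosen_logp: float,
    ref_rejected_logp: float,
    beta: float = 0.05,
) -> float:
    """DPO loss for one (chosen, rejected) pair, given summed sequence log-probs."""
    # How much more the policy prefers each response than the reference model does.
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    # beta scales the margin between the two log-ratios.
    scaled_margin = beta * (chosen_logratio - rejected_logratio)
    # -log(sigmoid(x)) written as log(1 + exp(-x)).
    return math.log1p(math.exp(-scaled_margin))


# Policy identical to the reference: the margin is zero and the loss is log(2).
loss_at_init = dpo_loss(-10.0, -10.0, -10.0, -10.0)

# Hypothetical log-probs where the policy has shifted towards the chosen response.
loss_after_shift = dpo_loss(-8.0, -12.0, -10.0, -10.0)

print(loss_at_init, loss_after_shift)
```

A smaller `beta` flattens the loss with respect to the margin, so the policy is pushed away from the reference model more gently.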
@@ -0,0 +1,48 @@
# Model arguments
model_name_or_path: google/gemma-7b
model_revision: main
tokenizer_name_or_path: philschmid/gemma-tokenizer-chatml # Custom tokenizer with <|im_start|> and <|im_end|> tokens
torch_dtype: bfloat16
use_flash_attention_2: true
# Data training arguments
dataset_mixer:
  HuggingFaceH4/deita-10k-v0-sft: 1.0
dataset_splits:
- train_sft
- test_sft
preprocessing_num_workers: 12
# SFT trainer config
bf16: true
dataset_kwargs:
  add_special_tokens: false # We already wrap <bos> and <eos> in the chat template
  append_concat_token: false # No need to add <eos> across samples
do_eval: true
evaluation_strategy: epoch
gradient_accumulation_steps: 4
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: zephyr-7b-gemma-sft
hub_strategy: every_save
learning_rate: 2.0e-05
log_level: info
logging_steps: 5
logging_strategy: steps
lr_scheduler_type: cosine
max_seq_length: 2048
max_steps: -1
num_train_epochs: 3
output_dir: data/zephyr-7b-gemma-sft
overwrite_output_dir: true
per_device_eval_batch_size: 4
per_device_train_batch_size: 4
push_to_hub: true
remove_unused_columns: true
report_to:
- tensorboard
- wandb
save_strategy: "no"
seed: 42
warmup_ratio: 0.1
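The combination of `learning_rate: 2.0e-05`, `lr_scheduler_type: cosine`, and `warmup_ratio: 0.1` in the config above describes a linear warmup followed by a cosine decay to zero. The sketch below reproduces that shape in plain Python. It assumes a hypothetical run of 100 optimizer steps, since the real step count depends on dataset size and batch size, and the function name `lr_at` is illustrative only.

```python
import math


def lr_at(step: int, total_steps: int, peak_lr: float = 2.0e-5, warmup_ratio: float = 0.1) -> float:
    """Learning rate at a given optimizer step: linear warmup, then cosine decay."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Linear ramp from 0 up to the peak learning rate.
        return peak_lr * step / max(1, warmup_steps)
    # Fraction of the post-warmup phase that has elapsed, in [0, 1].
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


TOTAL_STEPS = 100  # hypothetical value, for illustration only
schedule = [lr_at(s, TOTAL_STEPS) for s in range(TOTAL_STEPS + 1)]

print(schedule[0], schedule[5], schedule[10], schedule[55], schedule[100])
```

With these assumptions the rate reaches its peak at step 10, sits at half the peak midway through the decay, and ends at zero.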
+10 -10
@@ -197,16 +197,6 @@ def main():
     logger.info("*** Training complete ***")
-    ##########
-    # Evaluate
-    ##########
-    if training_args.do_eval:
-        logger.info("*** Evaluate ***")
-        metrics = trainer.evaluate()
-        metrics["eval_samples"] = len(raw_datasets["test"])
-        trainer.log_metrics("eval", metrics)
-        trainer.save_metrics("eval", metrics)
     ##################################
     # Save model and create model card
     ##################################
@@ -227,6 +217,16 @@ def main():
         trainer.model.config.use_cache = True
         trainer.model.config.save_pretrained(training_args.output_dir)
+    ##########
+    # Evaluate
+    ##########
+    if training_args.do_eval:
+        logger.info("*** Evaluate ***")
+        metrics = trainer.evaluate()
+        metrics["eval_samples"] = len(raw_datasets["test"])
+        trainer.log_metrics("eval", metrics)
+        trainer.save_metrics("eval", metrics)
     if training_args.push_to_hub is True:
         logger.info("Pushing to hub...")
         trainer.push_to_hub(**kwargs)
+11 -11
@@ -134,7 +134,6 @@ def main():
         device_map=get_kbit_device_map() if quantization_config is not None else None,
         quantization_config=quantization_config,
     )
-    logger.info("*** Model loaded! ***")
     ########################
     # Initialize the Trainer
@@ -150,6 +149,7 @@ def main():
         tokenizer=tokenizer,
         packing=True,
         peft_config=get_peft_config(model_args),
+        dataset_kwargs=training_args.dataset_kwargs,
     )
     ###############
@@ -168,16 +168,6 @@ def main():
     trainer.save_metrics("train", metrics)
     trainer.save_state()
-    ##########
-    # Evaluate
-    ##########
-    if training_args.do_eval:
-        logger.info("*** Evaluate ***")
-        metrics = trainer.evaluate()
-        metrics["eval_samples"] = len(eval_dataset)
-        trainer.log_metrics("eval", metrics)
-        trainer.save_metrics("eval", metrics)
     ##################################
     # Save model and create model card
     ##################################
@@ -198,6 +188,16 @@ def main():
         trainer.model.config.use_cache = True
         trainer.model.config.save_pretrained(training_args.output_dir)
+    ##########
+    # Evaluate
+    ##########
+    if training_args.do_eval:
+        logger.info("*** Evaluate ***")
+        metrics = trainer.evaluate()
+        metrics["eval_samples"] = len(eval_dataset)
+        trainer.log_metrics("eval", metrics)
+        trainer.save_metrics("eval", metrics)
     if training_args.push_to_hub is True:
         logger.info("Pushing to hub...")
         trainer.push_to_hub(**kwargs)
+2 -2
@@ -41,7 +41,7 @@ if stale_egg_info.exists():
 # IMPORTANT: all dependencies should be listed here with their version requirements, if any.
 #   * If a dependency is fast-moving (e.g. transformers), pin to the exact version
 _deps = [
-    "accelerate==0.23.0",
+    "accelerate==0.27.2",
     "bitsandbytes==0.41.2.post2",
     "black==23.1.0",
     "datasets==2.14.6",
@@ -65,7 +65,7 @@ _deps = [
     "scipy",
     "tensorboard",
     "torch==2.1.2",
-    "transformers==4.36.2",
+    "transformers>=4.38.2",  # Fixes RoPE computation
     "trl==0.7.10",
     "jinja2>=3.0.0",
     "tqdm>=4.64.1",
+11
@@ -136,6 +136,14 @@ class ModelArguments:
             "choices": ["auto", "bfloat16", "float16", "float32"],
         },
     )
+    tokenizer_name_or_path: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "The path to the tokenizer. Useful if you want to use a different tokenizer to the one stored in `model_name_or_path`."
+            )
+        },
+    )
     trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."})
     use_flash_attention_2: bool = field(
default=False,
@@ -220,6 +228,9 @@ class SFTConfig(transformers.TrainingArguments):
     Arguments related to the training process itself. For all parameters, see: https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/trainer#transformers.TrainingArguments
     """
+    dataset_kwargs: Optional[Dict[str, Any]] = field(
+        default=None, metadata={"help": "Dataset kwargs for the SFTTrainer"}
+    )
     max_seq_length: Optional[int] = field(
         default=None,
         metadata={"help": ("Used by TRL for reward model training, which tries to read this parameter in init.")},
+1 -1
@@ -34,7 +34,7 @@ def maybe_insert_system_message(messages, tokenizer):
         chat_template = tokenizer.default_chat_template
     # confirm the jinja template refers to a system message before inserting
-    if "system" in chat_template:
+    if "system" in chat_template or "<|im_start|>" in chat_template:
         messages.insert(0, {"role": "system", "content": ""})
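The effect of the widened condition can be seen in a small standalone sketch. The hunk only shows the tail of `maybe_insert_system_message`, so the early return and the `chat_template` lookup below are assumptions about the rest of the function, and the `SimpleNamespace` object is a hypothetical stand-in for a real tokenizer.

```python
from types import SimpleNamespace


def maybe_insert_system_message(messages, tokenizer):
    """Prepend an empty system turn when the chat template appears to expect one."""
    # Assumption: nothing to do if the conversation already starts with a system turn.
    if messages[0]["role"] == "system":
        return

    # Assumption: prefer a custom template and fall back to the default one.
    chat_template = tokenizer.chat_template
    if chat_template is None:
        chat_template = tokenizer.default_chat_template

    # Condition as it reads after this commit: ChatML-style templates also qualify.
    if "system" in chat_template or "<|im_start|>" in chat_template:
        messages.insert(0, {"role": "system", "content": ""})


# Hypothetical ChatML-style template that never mentions the word "system".
chatml_tokenizer = SimpleNamespace(
    chat_template="{% for m in messages %}<|im_start|>{{ m['role'] }}\n{{ m['content'] }}<|im_end|>\n{% endfor %}",
    default_chat_template=None,
)
# Hypothetical template with neither marker: the messages are left untouched.
plain_tokenizer = SimpleNamespace(
    chat_template="{% for m in messages %}{{ m['content'] }}{% endfor %}",
    default_chat_template=None,
)

chatml_messages = [{"role": "user", "content": "Hi!"}]
plain_messages = [{"role": "user", "content": "Hi!"}]
maybe_insert_system_message(chatml_messages, chatml_tokenizer)
maybe_insert_system_message(plain_messages, plain_tokenizer)

print(chatml_messages[0])   # {'role': 'system', 'content': ''}
print(len(plain_messages))  # 1
```

Under the old condition the ChatML template above would have been skipped, because the literal string `system` does not appear in it.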
+3 -1
@@ -65,7 +65,9 @@ def get_quantization_config(model_args: ModelArguments) -> BitsAndBytesConfig |
 def get_tokenizer(model_args: ModelArguments, data_args: DataArguments) -> PreTrainedTokenizer:
     """Get the tokenizer for the model."""
     tokenizer = AutoTokenizer.from_pretrained(
-        model_args.model_name_or_path,
+        model_args.model_name_or_path
+        if model_args.tokenizer_name_or_path is None
+        else model_args.tokenizer_name_or_path,
         revision=model_args.model_revision,
     )
     if tokenizer.pad_token_id is None:
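The fallback introduced in `get_tokenizer` can be isolated as follows. This is a minimal sketch: `resolve_tokenizer_path` is a hypothetical helper name that does not exist in the handbook, and the repository IDs are the ones from the SFT config in this commit.

```python
from typing import Optional


def resolve_tokenizer_path(model_name_or_path: str, tokenizer_name_or_path: Optional[str] = None) -> str:
    """Pick the tokenizer source: an explicit override wins, else the model path."""
    if tokenizer_name_or_path is None:
        return model_name_or_path
    return tokenizer_name_or_path


# With the SFT config from this commit, the custom ChatML tokenizer is selected.
with_override = resolve_tokenizer_path("google/gemma-7b", "philschmid/gemma-tokenizer-chatml")

# Without `tokenizer_name_or_path`, behaviour is unchanged from before the commit.
without_override = resolve_tokenizer_path("google/gemma-7b")

print(with_override, without_override)
```

In the hunk above, the resolved path is what gets passed as the first argument to `AutoTokenizer.from_pretrained`, while `revision` is still taken from `model_revision`.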