mirror of
https://github.com/wassname/alignment-handbook.git
synced 2026-06-27 16:14:07 +08:00
🪁 (#129)
* Add Gemma 7B recipe * Use Gemma template * Make it work for dolly lol * Enable cahce * Clean up * DPO to the max * DPO, DPO, DPO * Add openhermes * Add custom configs * Add kwargs * Fix config * Bump deps * Move old recipes * Add doc * Add norte * Renable cache * Nuke * Clean * Apply suggestions from code review Co-authored-by: Alvaro Bartolome <alvaro@argilla.io> * Fix isort * Update README.md * Update config_full.yaml --------- Co-authored-by: Alvaro Bartolome <alvaro@argilla.io> Co-authored-by: Philipp Schmid <32632186+philschmid@users.noreply.github.com>
This commit is contained in:
@@ -19,6 +19,7 @@ However, we know from the [InstructGPT](https://huggingface.co/papers/2203.02155
|
||||
The Alignment Handbook aims to fill that gap by providing the community with a series of robust training recipes that span the whole pipeline.
|
||||
|
||||
## News 🗞️
|
||||
* **March 1, 2024:** We release Zephyr 7B Gemma, which is a new recipe to align Gemma 7B with RLAIF 🔥
|
||||
* **February 1, 2024:** We release a recipe to align open LLMs with Constitutional AI 📜! See the [recipe](https://github.com/huggingface/alignment-handbook/tree/main/recipes/constitutional-ai) and the [blog post](https://huggingface.co/blog/constitutional_ai) for details.
|
||||
* **January 18, 2024:** We release a suite of evaluations of DPO vs KTO vs IPO, see the [recipe](recipes/pref_align_scan/README.md) and the [blog post](https://huggingface.co/blog/pref-tuning) for details.
|
||||
* **November 10, 2023:** We release all the training code to replicate Zephyr-7b-β 🪁! We also release [No Robots](https://huggingface.co/datasets/HuggingFaceH4/no_robots), a brand new dataset of 10,000 instructions and demonstrations written entirely by skilled human annotators.
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
|
||||
# Instructions to Replicate Zephyr 7B Gemma
|
||||
|
||||
Similar to how we trained Zephyr 7B Beta in our [technical report](https://huggingface.co/papers/2310.16944), training this model proceeds in two steps:
|
||||
|
||||
1. Apply SFT to fine-tune Gemma 7B on the Deita 10k dataset ([link](https://huggingface.co/datasets/HuggingFaceH4/deita-10k-v0-sft)). The result is an SFT model like [`zephyr-7b-gemma-sft`](https://huggingface.co/HuggingFaceH4/zephyr-7b-gemma-sft-v0.1).
|
||||
2. Align the SFT model to AI feedback via DPO on a curated mix of 7k examples by Argilla ([link](https://huggingface.co/datasets/argilla/dpo-mix-7k)). The result is a DPO model like [`zephyr-7b-gemma`](HuggingFaceH4/zephyr-7b-gemma-v0.1).
|
||||
|
||||
See below for commands to train these models using either DeepSpeed ZeRO-3 or LoRA.
|
||||
|
||||
## Full training examples
|
||||
|
||||
You will require 8 GPUs (80GB of VRAM) to train the full model - alternatively, you can train on 1 GPU by adjusting the micro batch size and gradient accumulation steps to keep the global batch size constant. A recipe involving QLoRA will come later 🤗.
|
||||
|
||||
```shell
|
||||
# Step 1 - SFT
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_sft.py recipes/zephyr-7b-gemma/sft/config_full.yaml
|
||||
|
||||
# Step 2 - DPO
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_dpo.py recipes/zephyr-7b-gemma/dpo/config_full.yaml
|
||||
```
|
||||
@@ -0,0 +1,42 @@
|
||||
# Model arguments
|
||||
model_name_or_path: HuggingFaceH4/zephyr-7b-gemma-sft-v0.1
|
||||
torch_dtype: bfloat16
|
||||
|
||||
# Data training arguments
|
||||
# For definitions, see: src/h4/training/config.py
|
||||
dataset_mixer:
|
||||
argilla/dpo-mix-7k: 1.0
|
||||
dataset_splits:
|
||||
- train
|
||||
- test
|
||||
preprocessing_num_workers: 12
|
||||
|
||||
# DPOTrainer arguments
|
||||
bf16: true
|
||||
beta: 0.05
|
||||
do_eval: true
|
||||
evaluation_strategy: steps
|
||||
eval_steps: 100
|
||||
gradient_accumulation_steps: 8
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: False
|
||||
hub_model_id: zephyr-7b-gemma-dpo
|
||||
learning_rate: 5.0e-7
|
||||
log_level: info
|
||||
logging_steps: 10
|
||||
lr_scheduler_type: cosine
|
||||
max_length: 1024
|
||||
max_prompt_length: 512
|
||||
num_train_epochs: 2
|
||||
optim: adamw_torch
|
||||
output_dir: data/zephyr-7b-gemma-dpo
|
||||
per_device_train_batch_size: 2
|
||||
per_device_eval_batch_size: 4
|
||||
push_to_hub: true
|
||||
report_to:
|
||||
- tensorboard
|
||||
- wandb
|
||||
save_strategy: "no"
|
||||
seed: 42
|
||||
warmup_ratio: 0.1
|
||||
@@ -0,0 +1,48 @@
|
||||
# Model arguments
|
||||
model_name_or_path: google/gemma-7b
|
||||
model_revision: main
|
||||
tokenizer_name_or_path: philschmid/gemma-tokenizer-chatml # Custom tokenizer with <|im_start|> and <|im_end|> tokens
|
||||
torch_dtype: bfloat16
|
||||
use_flash_attention_2: true
|
||||
|
||||
# Data training arguments
|
||||
dataset_mixer:
|
||||
HuggingFaceH4/deita-10k-v0-sft: 1.0
|
||||
dataset_splits:
|
||||
- train_sft
|
||||
- test_sft
|
||||
preprocessing_num_workers: 12
|
||||
|
||||
# SFT trainer config
|
||||
bf16: true
|
||||
dataset_kwargs:
|
||||
add_special_tokens: false # We already wrap <bos> and <eos> in the chat template
|
||||
append_concat_token: false # No need to add <eos> across samples
|
||||
do_eval: true
|
||||
evaluation_strategy: epoch
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: zephyr-7b-gemma-sft
|
||||
hub_strategy: every_save
|
||||
learning_rate: 2.0e-05
|
||||
log_level: info
|
||||
logging_steps: 5
|
||||
logging_strategy: steps
|
||||
lr_scheduler_type: cosine
|
||||
max_seq_length: 2048
|
||||
max_steps: -1
|
||||
num_train_epochs: 3
|
||||
output_dir: data/zephyr-7b-gemma-sft
|
||||
overwrite_output_dir: true
|
||||
per_device_eval_batch_size: 4
|
||||
per_device_train_batch_size: 4
|
||||
push_to_hub: true
|
||||
remove_unused_columns: true
|
||||
report_to:
|
||||
- tensorboard
|
||||
- wandb
|
||||
save_strategy: "no"
|
||||
seed: 42
|
||||
warmup_ratio: 0.1
|
||||
+10
-10
@@ -197,16 +197,6 @@ def main():
|
||||
|
||||
logger.info("*** Training complete ***")
|
||||
|
||||
##########
|
||||
# Evaluate
|
||||
##########
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
metrics = trainer.evaluate()
|
||||
metrics["eval_samples"] = len(raw_datasets["test"])
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
##################################
|
||||
# Save model and create model card
|
||||
##################################
|
||||
@@ -227,6 +217,16 @@ def main():
|
||||
trainer.model.config.use_cache = True
|
||||
trainer.model.config.save_pretrained(training_args.output_dir)
|
||||
|
||||
##########
|
||||
# Evaluate
|
||||
##########
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
metrics = trainer.evaluate()
|
||||
metrics["eval_samples"] = len(raw_datasets["test"])
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
if training_args.push_to_hub is True:
|
||||
logger.info("Pushing to hub...")
|
||||
trainer.push_to_hub(**kwargs)
|
||||
|
||||
+11
-11
@@ -134,7 +134,6 @@ def main():
|
||||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
logger.info("*** Model loaded! ***")
|
||||
|
||||
########################
|
||||
# Initialize the Trainer
|
||||
@@ -150,6 +149,7 @@ def main():
|
||||
tokenizer=tokenizer,
|
||||
packing=True,
|
||||
peft_config=get_peft_config(model_args),
|
||||
dataset_kwargs=training_args.dataset_kwargs,
|
||||
)
|
||||
|
||||
###############
|
||||
@@ -168,16 +168,6 @@ def main():
|
||||
trainer.save_metrics("train", metrics)
|
||||
trainer.save_state()
|
||||
|
||||
##########
|
||||
# Evaluate
|
||||
##########
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
metrics = trainer.evaluate()
|
||||
metrics["eval_samples"] = len(eval_dataset)
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
##################################
|
||||
# Save model and create model card
|
||||
##################################
|
||||
@@ -198,6 +188,16 @@ def main():
|
||||
trainer.model.config.use_cache = True
|
||||
trainer.model.config.save_pretrained(training_args.output_dir)
|
||||
|
||||
##########
|
||||
# Evaluate
|
||||
##########
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
metrics = trainer.evaluate()
|
||||
metrics["eval_samples"] = len(eval_dataset)
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
if training_args.push_to_hub is True:
|
||||
logger.info("Pushing to hub...")
|
||||
trainer.push_to_hub(**kwargs)
|
||||
|
||||
@@ -41,7 +41,7 @@ if stale_egg_info.exists():
|
||||
# IMPORTANT: all dependencies should be listed here with their version requirements, if any.
|
||||
# * If a dependency is fast-moving (e.g. transformers), pin to the exact version
|
||||
_deps = [
|
||||
"accelerate==0.23.0",
|
||||
"accelerate==0.27.2",
|
||||
"bitsandbytes==0.41.2.post2",
|
||||
"black==23.1.0",
|
||||
"datasets==2.14.6",
|
||||
@@ -65,7 +65,7 @@ _deps = [
|
||||
"scipy",
|
||||
"tensorboard",
|
||||
"torch==2.1.2",
|
||||
"transformers==4.36.2",
|
||||
"transformers>=4.38.2", # Fixes RoPE computation
|
||||
"trl==0.7.10",
|
||||
"jinja2>=3.0.0",
|
||||
"tqdm>=4.64.1",
|
||||
|
||||
@@ -136,6 +136,14 @@ class ModelArguments:
|
||||
"choices": ["auto", "bfloat16", "float16", "float32"],
|
||||
},
|
||||
)
|
||||
tokenizer_name_or_path: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"The path to the tokenizer. Useful if you want to use a different tokenizer to the one stored in `model_name_or_path`."
|
||||
)
|
||||
},
|
||||
)
|
||||
trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."})
|
||||
use_flash_attention_2: bool = field(
|
||||
default=False,
|
||||
@@ -220,6 +228,9 @@ class SFTConfig(transformers.TrainingArguments):
|
||||
Arguments related to the training process itself. For all parameters, see: https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/trainer#transformers.TrainingArguments
|
||||
"""
|
||||
|
||||
dataset_kwargs: Optional[Dict[str, Any]] = field(
|
||||
default=None, metadata={"help": "Dataset kwargs for the SFTTrainer"}
|
||||
)
|
||||
max_seq_length: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": ("Used by TRL for reward model training, which tries to read this parameter in init.")},
|
||||
|
||||
@@ -34,7 +34,7 @@ def maybe_insert_system_message(messages, tokenizer):
|
||||
chat_template = tokenizer.default_chat_template
|
||||
|
||||
# confirm the jinja template refers to a system message before inserting
|
||||
if "system" in chat_template:
|
||||
if "system" in chat_template or "<|im_start|>" in chat_template:
|
||||
messages.insert(0, {"role": "system", "content": ""})
|
||||
|
||||
|
||||
|
||||
@@ -65,7 +65,9 @@ def get_quantization_config(model_args: ModelArguments) -> BitsAndBytesConfig |
|
||||
def get_tokenizer(model_args: ModelArguments, data_args: DataArguments) -> PreTrainedTokenizer:
|
||||
"""Get the tokenizer for the model."""
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
model_args.model_name_or_path
|
||||
if model_args.tokenizer_name_or_path is None
|
||||
else model_args.tokenizer_name_or_path,
|
||||
revision=model_args.model_revision,
|
||||
)
|
||||
if tokenizer.pad_token_id is None:
|
||||
|
||||
Reference in New Issue
Block a user