mirror of
https://github.com/wassname/alignment-handbook.git
synced 2026-06-27 17:14:25 +08:00
@@ -19,6 +19,7 @@ However, we know from the [InstructGPT](https://huggingface.co/papers/2203.02155
The Alignment Handbook aims to fill that gap by providing the community with a series of robust training recipes that span the whole pipeline.

## News 🗞️
* **March 12, 2024:** We release StarChat2 15B, along with the recipe to train capable coding assistants 🌟
* **March 1, 2024:** We release Zephyr 7B Gemma, which is a new recipe to align Gemma 7B with RLAIF 🔥
* **February 1, 2024:** We release a recipe to align open LLMs with Constitutional AI 📜! See the [recipe](https://github.com/huggingface/alignment-handbook/tree/main/recipes/constitutional-ai) and the [blog post](https://huggingface.co/blog/constitutional_ai) for details.
* **January 18, 2024:** We release a suite of evaluations of DPO vs KTO vs IPO; see the [recipe](recipes/pref_align_scan/README.md) and the [blog post](https://huggingface.co/blog/pref-tuning) for details.
@@ -0,0 +1,21 @@

# Instructions to train StarChat2

Similar to how we trained Zephyr 7B Beta in our [technical report](https://huggingface.co/papers/2310.16944), training this model proceeds in two steps:

1. Apply SFT to fine-tune [StarCoder2 15B](https://huggingface.co/bigcode/starcoder2-15b) on a blend of chat, code, and math datasets. The result is an SFT model like [`starchat2-15b-sft-v0.1`](https://huggingface.co/HuggingFaceH4/starchat2-15b-sft-v0.1).
2. Align the SFT model to AI feedback via DPO on the UltraFeedback and Orca DPO Pairs datasets. The result is a DPO model like [`starchat2-15b-v0.1`](https://huggingface.co/HuggingFaceH4/starchat2-15b-v0.1).

See below for commands to train these models using DeepSpeed ZeRO-3.

## Full training examples

You will need 8 GPUs (80GB of VRAM) to train the full model. Alternatively, you can train on 1 GPU by adjusting `per_device_train_batch_size` and `gradient_accumulation_steps` to keep the global batch size constant. A recipe involving QLoRA will come later 🤗.

```shell
# Step 1 - SFT
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_sft.py recipes/starchat2-15b/sft/config_v0.1.yaml

# Step 2 - DPO
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_dpo.py recipes/starchat2-15b/dpo/config_v0.1.yaml
```
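
To make "keep the global batch size constant" concrete, here is a minimal sketch of the arithmetic. It assumes the global batch size is the product of the per-device batch size, the gradient accumulation steps, and the number of GPUs. The helper names are hypothetical; the numbers 8, 2 and 2, 8 are the `per_device_train_batch_size` and `gradient_accumulation_steps` values from the SFT and DPO configs of this recipe, and the single-GPU per-device batch size of 1 is only an illustrative assumption.

```python
def global_batch_size(per_device: int, grad_accum: int, num_gpus: int) -> int:
    """Number of samples contributing to one optimizer step."""
    return per_device * grad_accum * num_gpus


def grad_accum_for(target_global: int, per_device: int, num_gpus: int) -> int:
    """Accumulation steps needed to reach `target_global` on a given setup."""
    samples_per_forward = per_device * num_gpus
    if target_global % samples_per_forward != 0:
        raise ValueError("target global batch size is not reachable with these settings")
    return target_global // samples_per_forward


# Values taken from the SFT and DPO configs, on 8 GPUs
sft_global = global_batch_size(per_device=8, grad_accum=2, num_gpus=8)
dpo_global = global_batch_size(per_device=2, grad_accum=8, num_gpus=8)

# Hypothetical single-GPU run with a per-device batch size of 1
single_gpu_accum = grad_accum_for(sft_global, per_device=1, num_gpus=1)

print(sft_global, dpo_global, single_gpu_accum)
```

Whether a given per-device batch size actually fits in memory on one GPU depends on the hardware and is not something this sketch checks.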
@@ -0,0 +1,43 @@
# Model arguments
model_name_or_path: HuggingFaceH4/starchat2-15b-sft-v0.1
torch_dtype: bfloat16

# Data training arguments
# For definitions, see: src/h4/training/config.py
dataset_mixer:
  HuggingFaceH4/ultrafeedback_binarized: 1.0
  HuggingFaceH4/orca_dpo_pairs: 1.0
dataset_splits:
- train_prefs
- test_prefs
preprocessing_num_workers: 12

# DPOTrainer arguments
bf16: true
beta: 0.05
do_eval: true
evaluation_strategy: steps
eval_steps: 100
gradient_accumulation_steps: 8
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: False
hub_model_id: starchat2-15b-dpo-v0.1
learning_rate: 5.0e-7
log_level: info
logging_steps: 10
lr_scheduler_type: cosine
max_length: 1024
max_prompt_length: 512
num_train_epochs: 2
optim: adamw_torch
output_dir: data/starchat2-15b-dpo-v0.1
per_device_train_batch_size: 2
per_device_eval_batch_size: 4
push_to_hub: true
report_to:
- tensorboard
- wandb
save_strategy: "no"
seed: 42
warmup_ratio: 0.1
@@ -0,0 +1,49 @@
# Model arguments
model_name_or_path: bigcode/starcoder2-15b
model_revision: main
torch_dtype: bfloat16
use_flash_attention_2: true

# Data training arguments
chat_template: "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
dataset_mixer:
  HuggingFaceH4/airoboros-3.2: 1.0
  HuggingFaceH4/Code-Feedback: 1.0
  HuggingFaceH4/orca-math-word-problems-200k: 1.0
  HuggingFaceH4/SystemChat: 1.0
  HuggingFaceH4/capybara: 1.0
dataset_splits:
- train_sft
- test_sft
preprocessing_num_workers: 24

# SFT trainer config
bf16: true
do_eval: true
evaluation_strategy: epoch
gradient_accumulation_steps: 2
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: starchat2-15b-v0.1
hub_strategy: every_save
learning_rate: 2.0e-05
log_level: info
logging_steps: 5
logging_strategy: steps
lr_scheduler_type: cosine
max_seq_length: 2048
max_steps: -1
num_train_epochs: 3
output_dir: data/starchat2-15b-v0.1
overwrite_output_dir: true
per_device_eval_batch_size: 8
per_device_train_batch_size: 8
push_to_hub: true
remove_unused_columns: true
report_to:
- tensorboard
- wandb
save_strategy: "no"
seed: 42
warmup_ratio: 0.1
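
The `chat_template` in the config above is a Jinja template in the ChatML style. As an illustration of the text it produces, here is a plain-Python sketch that mirrors the template's logic. The function name `render_chatml` is hypothetical, and this is not how the tokenizer renders the template in practice (that goes through Jinja); it only shows the resulting layout.

```python
from typing import Dict, List


def render_chatml(messages: List[Dict[str, str]], add_generation_prompt: bool = False) -> str:
    """Mirror the ChatML template: one <|im_start|>role ... <|im_end|> block per message."""
    text = ""
    for message in messages:
        text += "<|im_start|>" + message["role"] + "\n" + message["content"] + "<|im_end|>" + "\n"
    # Optionally open an assistant turn so the model continues from there
    if add_generation_prompt:
        text += "<|im_start|>assistant\n"
    return text


example = render_chatml(
    [
        {"role": "system", "content": "You are a helpful coding assistant."},
        {"role": "user", "content": "Write a function that adds two numbers."},
    ],
    add_generation_prompt=True,
)
print(example)
```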
@@ -39,4 +39,4 @@ report_to:
|
||||
- wandb
|
||||
save_strategy: "no"
|
||||
seed: 42
|
||||
-warmup_ratio: 0.1
+warmup_ratio: 0.1
|
||||
@@ -28,7 +28,7 @@ hub_model_id: zephyr-7b-gemma-sft
|
||||
hub_strategy: every_save
|
||||
learning_rate: 2.0e-05
|
||||
log_level: info
|
||||
-logging_steps: 5
+logging_steps: 5
|
||||
logging_strategy: steps
|
||||
lr_scheduler_type: cosine
|
||||
max_seq_length: 2048
|
||||
|
||||
@@ -27,6 +27,7 @@ from alignment import (
|
||||
H4ArgumentParser,
|
||||
ModelArguments,
|
||||
apply_chat_template,
|
||||
decontaminate_humaneval,
|
||||
get_checkpoint,
|
||||
get_datasets,
|
||||
get_kbit_device_map,
|
||||
@@ -103,6 +104,23 @@ def main():
         desc="Formatting comparisons with prompt template",
     )

+    ##########################
+    # Decontaminate benchmarks
+    ##########################
+    num_raw_train_samples = len(raw_datasets["train"])
+    raw_datasets = raw_datasets.filter(
+        decontaminate_humaneval,
+        fn_kwargs={"text_column": "text_chosen"},
+        batched=True,
+        batch_size=10_000,
+        num_proc=1,
+        desc="Decontaminating HumanEval samples",
+    )
+    num_filtered_train_samples = num_raw_train_samples - len(raw_datasets["train"])
+    logger.info(
+        f"Decontaminated {num_filtered_train_samples} ({num_filtered_train_samples/num_raw_train_samples * 100:.2f}%) samples from the training set."
+    )
+
     # Replace column names with what TRL needs, text_chosen -> chosen and text_rejected -> rejected
     for split in ["train", "test"]:
         raw_datasets[split] = raw_datasets[split].rename_columns(
+43
-24
@@ -24,7 +24,7 @@ import sys
|
||||
import datasets
|
||||
import torch
|
||||
import transformers
|
||||
-from transformers import set_seed
+from transformers import AutoModelForCausalLM, set_seed
|
||||
|
||||
from alignment import (
|
||||
DataArguments,
|
||||
@@ -32,6 +32,7 @@ from alignment import (
|
||||
ModelArguments,
|
||||
SFTConfig,
|
||||
apply_chat_template,
|
||||
decontaminate_humaneval,
|
||||
get_checkpoint,
|
||||
get_datasets,
|
||||
get_kbit_device_map,
|
||||
@@ -39,7 +40,7 @@ from alignment import (
|
||||
get_quantization_config,
|
||||
get_tokenizer,
|
||||
)
|
||||
-from trl import SFTTrainer
+from trl import SFTTrainer, setup_chat_format
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -95,27 +96,6 @@ def main():
     ################
     tokenizer = get_tokenizer(model_args, data_args)

-    #####################
-    # Apply chat template
-    #####################
-    raw_datasets = raw_datasets.map(
-        apply_chat_template,
-        fn_kwargs={
-            "tokenizer": tokenizer,
-            "task": "sft",
-            "auto_insert_empty_system_msg": data_args.auto_insert_empty_system_msg,
-        },
-        num_proc=data_args.preprocessing_num_workers,
-        remove_columns=column_names,
-        desc="Applying chat template",
-    )
-    train_dataset = raw_datasets["train"]
-    eval_dataset = raw_datasets["test"]
-
-    with training_args.main_process_first(desc="Log a few random samples from the processed training set"):
-        for index in random.sample(range(len(raw_datasets["train"])), 3):
-            logger.info(f"Sample {index} of the processed training set:\n\n{raw_datasets['train'][index]['text']}")
-
     #######################
     # Load pretrained model
     #######################
@@ -135,11 +115,50 @@ def main():
         quantization_config=quantization_config,
     )

+    model = model_args.model_name_or_path
+    # For ChatML we need to add special tokens and resize the embedding layer
+    if "<|im_start|>" in tokenizer.chat_template:
+        model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)
+        model, tokenizer = setup_chat_format(model, tokenizer)
+        model_kwargs = None
+
+    #####################
+    # Apply chat template
+    #####################
+    raw_datasets = raw_datasets.map(
+        apply_chat_template,
+        fn_kwargs={
+            "tokenizer": tokenizer,
+            "task": "sft",
+            "auto_insert_empty_system_msg": data_args.auto_insert_empty_system_msg,
+        },
+        num_proc=data_args.preprocessing_num_workers,
+        remove_columns=column_names,
+        desc="Applying chat template",
+    )
+
+    ##########################
+    # Decontaminate benchmarks
+    ##########################
+    num_raw_train_samples = len(raw_datasets["train"])
+    raw_datasets = raw_datasets.filter(decontaminate_humaneval, batched=True, batch_size=10_000, num_proc=1)
+    num_filtered_train_samples = num_raw_train_samples - len(raw_datasets["train"])
+    logger.info(
+        f"Decontaminated {num_filtered_train_samples} ({num_filtered_train_samples/num_raw_train_samples * 100:.2f}%) samples from the training set."
+    )
+
+    train_dataset = raw_datasets["train"]
+    eval_dataset = raw_datasets["test"]
+
+    with training_args.main_process_first(desc="Log a few random samples from the processed training set"):
+        for index in random.sample(range(len(raw_datasets["train"])), 3):
+            logger.info(f"Sample {index} of the processed training set:\n\n{raw_datasets['train'][index]['text']}")
+
     ########################
     # Initialize the Trainer
     ########################
     trainer = SFTTrainer(
-        model=model_args.model_name_or_path,
+        model=model,
         model_init_kwargs=model_kwargs,
         args=training_args,
         train_dataset=train_dataset,
@@ -65,7 +65,7 @@ _deps = [
|
||||
"scipy",
|
||||
"tensorboard",
|
||||
"torch==2.1.2",
-"transformers>=4.38.2", # Fixes RoPE computation
+"transformers @ git+https://github.com/huggingface/transformers.git@831bc25d8fdb85768402f772cf65cc3d7872b211", # Enable StarCoder2
|
||||
"trl==0.7.10",
|
||||
"jinja2>=3.0.0",
|
||||
"tqdm>=4.64.1",
|
||||
|
||||
@@ -2,6 +2,7 @@ __version__ = "0.3.0.dev0"
|
||||
|
||||
from .configs import DataArguments, DPOConfig, H4ArgumentParser, ModelArguments, SFTConfig
|
||||
from .data import apply_chat_template, get_datasets
|
||||
from .decontaminate import decontaminate_humaneval
|
||||
from .model_utils import (
|
||||
get_checkpoint,
|
||||
get_kbit_device_map,
|
||||
|
||||
@@ -23,6 +23,8 @@ from .configs import DataArguments
|
||||
|
||||
DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
|
||||
|
||||
COLUMNS_TO_KEEP = ["messages", "chosen", "rejected", "prompt", "completion", "label"]
|
||||
|
||||
|
||||
def maybe_insert_system_message(messages, tokenizer):
|
||||
if messages[0]["role"] == "system":
|
||||
@@ -161,6 +163,8 @@ def mix_datasets(dataset_mixer: dict, splits: Optional[List[str]] = None, shuffl
|
||||
# If not, check local dataset
|
||||
dataset = load_from_disk(os.path.join(ds, split))
|
||||
|
||||
# Remove redundant columns to avoid schema conflicts on load
|
||||
dataset = dataset.remove_columns([col for col in dataset.column_names if col not in COLUMNS_TO_KEEP])
|
||||
if "train" in split:
|
||||
raw_train_datasets.append(dataset)
|
||||
elif "test" in split:
|
||||
|
||||
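
The column pruning in this hunk can be illustrated without loading any dataset. The sketch below applies the same list comprehension to two made-up column layouts; the helper name `columns_to_drop` and the column names `source`, `prompt_id`, and `category` are hypothetical. After pruning, both layouts share only columns from `COLUMNS_TO_KEEP`, which is what lets differently shaped datasets be combined without schema conflicts.

```python
from typing import List

COLUMNS_TO_KEEP = ["messages", "chosen", "rejected", "prompt", "completion", "label"]


def columns_to_drop(column_names: List[str]) -> List[str]:
    """Same filter as in mix_datasets: anything not explicitly kept is removed."""
    return [col for col in column_names if col not in COLUMNS_TO_KEEP]


# Two hypothetical datasets with different extra columns
drop_a = columns_to_drop(["messages", "source", "prompt_id"])
drop_b = columns_to_drop(["messages", "category"])
print(drop_a, drop_b)
```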
@@ -0,0 +1,91 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List

from datasets import load_dataset


# HumanEval solutions that are considered simple/generic enough to be kept in the training dataset.
# The last entry uses single quotes on the outside so that the inner " " is part of the string
# (two adjacent double-quoted literals would be silently concatenated instead).
HUMAN_EVAL_STRINGS_OK = ["return x + y", "return len(string)", "return n**2", 'return " ".join(strings)']


def extract_docstring(prompt: str) -> str:
    # Return the text between the docstring delimiters of a HumanEval prompt
    if '"""' in prompt:
        if prompt.count('"""') == 2:
            return prompt.split('"""')[1].strip()
        elif prompt.count('"""') == 4:
            return prompt.split('"""')[3].strip()
        else:
            raise ValueError()
    elif "'''" in prompt:
        assert prompt.count("'''") == 2
        return prompt.split("'''")[1].strip()
    else:
        raise ValueError()


def human_eval_docstrings() -> List[str]:
    ds = load_dataset("openai_humaneval", split="test")
    docstrings = [extract_docstring(v["prompt"]) for v in ds]
    return docstrings


def load_dataset_column(dataset: str, column: str, split: str, name=None) -> List[str]:
    ds = load_dataset(dataset, split=split, name=name)
    res = [sample[column].strip() for sample in ds]
    # Only return non-empty strings
    return [sample for sample in res if len(sample) > 0]


FILTER_OUT = {
    "human_eval_docstrings": human_eval_docstrings(),
    "human_eval_solutions": [
        s
        for s in load_dataset_column("openai_humaneval", "canonical_solution", "test")
        if s not in HUMAN_EVAL_STRINGS_OK
    ],
}


def normalize_whitespace(text: str) -> str:
    return " ".join(text.split())


def decontaminate_humaneval(
    samples: Dict[str, List[Any]], text_column: str = "text", filter_out: Dict[str, List[str]] = FILTER_OUT
) -> List[bool]:
    """
    samples: a batch of examples, mapping each column name to a list of values (as passed by a batched `filter`).
    filter_out: Dict[str, List[str]] mapping from benchmark name to list of strings that need to be
    filtered-out.
    Return a list where each element is True if the corresponding sample should be included in the dataset.
    Otherwise, the element is False.
    """
    output = []

    for content in samples[text_column]:
        content = normalize_whitespace(content.lower())
        matched = False
        for _, substrings in filter_out.items():
            for substring in substrings:
                if normalize_whitespace(substring.lower()) in content:
                    matched = True
                    break
            if matched:
                break
        # we keep samples that are not matched
        output.append(not matched)

    return output
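
The filtering logic above can be tried without downloading HumanEval. The following is a minimal sketch of the same technique (case-insensitive, whitespace-normalized substring matching over a batch), using a toy list of banned strings in place of `FILTER_OUT`. The function name `keep_mask` and the toy data are assumptions for illustration only.

```python
from typing import Dict, List


def normalize_whitespace(text: str) -> str:
    return " ".join(text.split())


def keep_mask(batch: Dict[str, List[str]], banned: List[str], text_column: str = "text") -> List[bool]:
    """Return True for each sample that contains none of the banned substrings."""
    banned_norm = [normalize_whitespace(s.lower()) for s in banned]
    mask = []
    for content in batch[text_column]:
        content = normalize_whitespace(content.lower())
        mask.append(not any(b in content for b in banned_norm))
    return mask


# Toy stand-in for a benchmark docstring
toy_banned = ["Check if in given list of numbers, are any two numbers closer"]

batch = {
    "text": [
        "Hello there",
        'def f():\n    """ Check if in given list   of numbers,\n    are any two numbers closer than t """',
    ]
}
mask = keep_mask(batch, toy_banned)
print(mask)
```

The second sample is dropped even though its line breaks and spacing differ from the banned string, because both sides are lowercased and whitespace-normalized before the comparison.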
@@ -0,0 +1,48 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import TestCase

from datasets import Dataset
from transformers import AutoTokenizer

from alignment import apply_chat_template, decontaminate_humaneval


class DecontaminateHumanEvalTest(TestCase):
    """Test we decontaminate HumanEval samples correctly"""

    def setUp(self) -> None:
        # Create a dataset with a HumanEval sample wrapped in some fake text
        dataset = Dataset.from_dict(
            {
                "messages": [
                    [{"content": "Hello", "role": "user"}],
                    [
                        {
                            "content": 'Hello, I am\nfrom\n\n typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n """ Check if in given list of numbers, are any two numbers closer to each other than\n given threshold.\n >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n False\n >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n True\n """\n',
                            "role": "assistant",
                        }
                    ],
                ]
            }
        )
        tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
        self.dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer, "task": "sft"})

    def test_decontamination(self):
        """Test we decontaminate HumanEval samples correctly"""
        decontaminated_dataset = self.dataset.filter(decontaminate_humaneval, batched=True)
        # Check we recover just the first message
        self.assertEqual(decontaminated_dataset[0]["text"], self.dataset[0]["text"])
@@ -70,8 +70,8 @@ class GetTokenizerTest(unittest.TestCase):
|
||||
`default_chat_template` but no `chat_template` we do not set a `chat_template`,
|
||||
and that we do not change `default_chat_template`
|
||||
"""
|
||||
-model_args = ModelArguments(model_name_or_path="codellama/CodeLlama-7b-Instruct-hf")
-base_tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-Instruct-hf")
+model_args = ModelArguments(model_name_or_path="m-a-p/OpenCodeInterpreter-SC2-7B")
+base_tokenizer = AutoTokenizer.from_pretrained("m-a-p/OpenCodeInterpreter-SC2-7B")
|
||||
processed_tokenizer = get_tokenizer(model_args, DataArguments())
|
||||
|
||||
assert getattr(processed_tokenizer, "chat_template") is None
|
||||
|
||||