* Add StarChat2

* Add DPO

* Fix unit test

* Typos

* Typo
This commit is contained in:
lewtun
2024-03-12 17:22:21 +01:00
committed by GitHub
parent ff618a4d13
commit a9b8a50a27
14 changed files with 324 additions and 29 deletions
+18
View File
@@ -27,6 +27,7 @@ from alignment import (
H4ArgumentParser,
ModelArguments,
apply_chat_template,
decontaminate_humaneval,
get_checkpoint,
get_datasets,
get_kbit_device_map,
@@ -103,6 +104,23 @@ def main():
desc="Formatting comparisons with prompt template",
)
##########################
# Decontaminate benchmarks
##########################
num_raw_train_samples = len(raw_datasets["train"])
raw_datasets = raw_datasets.filter(
decontaminate_humaneval,
fn_kwargs={"text_column": "text_chosen"},
batched=True,
batch_size=10_000,
num_proc=1,
desc="Decontaminating HumanEval samples",
)
num_filtered_train_samples = num_raw_train_samples - len(raw_datasets["train"])
logger.info(
f"Decontaminated {num_filtered_train_samples} ({num_filtered_train_samples/num_raw_train_samples * 100:.2f}%) samples from the training set."
)
# Replace column names with what TRL needs, text_chosen -> chosen and text_rejected -> rejected
for split in ["train", "test"]:
raw_datasets[split] = raw_datasets[split].rename_columns(
+43 -24
View File
@@ -24,7 +24,7 @@ import sys
import datasets
import torch
import transformers
from transformers import set_seed
from transformers import AutoModelForCausalLM, set_seed
from alignment import (
DataArguments,
@@ -32,6 +32,7 @@ from alignment import (
ModelArguments,
SFTConfig,
apply_chat_template,
decontaminate_humaneval,
get_checkpoint,
get_datasets,
get_kbit_device_map,
@@ -39,7 +40,7 @@ from alignment import (
get_quantization_config,
get_tokenizer,
)
from trl import SFTTrainer
from trl import SFTTrainer, setup_chat_format
logger = logging.getLogger(__name__)
@@ -95,27 +96,6 @@ def main():
################
tokenizer = get_tokenizer(model_args, data_args)
#####################
# Apply chat template
#####################
raw_datasets = raw_datasets.map(
apply_chat_template,
fn_kwargs={
"tokenizer": tokenizer,
"task": "sft",
"auto_insert_empty_system_msg": data_args.auto_insert_empty_system_msg,
},
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
desc="Applying chat template",
)
train_dataset = raw_datasets["train"]
eval_dataset = raw_datasets["test"]
with training_args.main_process_first(desc="Log a few random samples from the processed training set"):
for index in random.sample(range(len(raw_datasets["train"])), 3):
logger.info(f"Sample {index} of the processed training set:\n\n{raw_datasets['train'][index]['text']}")
#######################
# Load pretrained model
#######################
@@ -135,11 +115,50 @@ def main():
quantization_config=quantization_config,
)
model = model_args.model_name_or_path
# For ChatML we need to add special tokens and resize the embedding layer
if "<|im_start|>" in tokenizer.chat_template:
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)
model, tokenizer = setup_chat_format(model, tokenizer)
model_kwargs = None
#####################
# Apply chat template
#####################
raw_datasets = raw_datasets.map(
apply_chat_template,
fn_kwargs={
"tokenizer": tokenizer,
"task": "sft",
"auto_insert_empty_system_msg": data_args.auto_insert_empty_system_msg,
},
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
desc="Applying chat template",
)
##########################
# Decontaminate benchmarks
##########################
num_raw_train_samples = len(raw_datasets["train"])
raw_datasets = raw_datasets.filter(decontaminate_humaneval, batched=True, batch_size=10_000, num_proc=1)
num_filtered_train_samples = num_raw_train_samples - len(raw_datasets["train"])
logger.info(
f"Decontaminated {num_filtered_train_samples} ({num_filtered_train_samples/num_raw_train_samples * 100:.2f}%) samples from the training set."
)
train_dataset = raw_datasets["train"]
eval_dataset = raw_datasets["test"]
with training_args.main_process_first(desc="Log a few random samples from the processed training set"):
for index in random.sample(range(len(raw_datasets["train"])), 3):
logger.info(f"Sample {index} of the processed training set:\n\n{raw_datasets['train'][index]['text']}")
########################
# Initialize the Trainer
########################
trainer = SFTTrainer(
model=model_args.model_name_or_path,
model=model,
model_init_kwargs=model_kwargs,
args=training_args,
train_dataset=train_dataset,