mirror of
https://github.com/wassname/alignment-handbook.git
synced 2026-06-27 18:41:19 +08:00
@@ -27,6 +27,7 @@ from alignment import (
|
||||
H4ArgumentParser,
|
||||
ModelArguments,
|
||||
apply_chat_template,
|
||||
decontaminate_humaneval,
|
||||
get_checkpoint,
|
||||
get_datasets,
|
||||
get_kbit_device_map,
|
||||
@@ -103,6 +104,23 @@ def main():
|
||||
desc="Formatting comparisons with prompt template",
|
||||
)
|
||||
|
||||
##########################
|
||||
# Decontaminate benchmarks
|
||||
##########################
|
||||
num_raw_train_samples = len(raw_datasets["train"])
|
||||
raw_datasets = raw_datasets.filter(
|
||||
decontaminate_humaneval,
|
||||
fn_kwargs={"text_column": "text_chosen"},
|
||||
batched=True,
|
||||
batch_size=10_000,
|
||||
num_proc=1,
|
||||
desc="Decontaminating HumanEval samples",
|
||||
)
|
||||
num_filtered_train_samples = num_raw_train_samples - len(raw_datasets["train"])
|
||||
logger.info(
|
||||
f"Decontaminated {num_filtered_train_samples} ({num_filtered_train_samples/num_raw_train_samples * 100:.2f}%) samples from the training set."
|
||||
)
|
||||
|
||||
# Replace column names with what TRL needs, text_chosen -> chosen and text_rejected -> rejected
|
||||
for split in ["train", "test"]:
|
||||
raw_datasets[split] = raw_datasets[split].rename_columns(
|
||||
|
||||
+43
-24
@@ -24,7 +24,7 @@ import sys
|
||||
import datasets
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import set_seed
|
||||
from transformers import AutoModelForCausalLM, set_seed
|
||||
|
||||
from alignment import (
|
||||
DataArguments,
|
||||
@@ -32,6 +32,7 @@ from alignment import (
|
||||
ModelArguments,
|
||||
SFTConfig,
|
||||
apply_chat_template,
|
||||
decontaminate_humaneval,
|
||||
get_checkpoint,
|
||||
get_datasets,
|
||||
get_kbit_device_map,
|
||||
@@ -39,7 +40,7 @@ from alignment import (
|
||||
get_quantization_config,
|
||||
get_tokenizer,
|
||||
)
|
||||
from trl import SFTTrainer
|
||||
from trl import SFTTrainer, setup_chat_format
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -95,27 +96,6 @@ def main():
|
||||
################
|
||||
tokenizer = get_tokenizer(model_args, data_args)
|
||||
|
||||
#####################
|
||||
# Apply chat template
|
||||
#####################
|
||||
raw_datasets = raw_datasets.map(
|
||||
apply_chat_template,
|
||||
fn_kwargs={
|
||||
"tokenizer": tokenizer,
|
||||
"task": "sft",
|
||||
"auto_insert_empty_system_msg": data_args.auto_insert_empty_system_msg,
|
||||
},
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
desc="Applying chat template",
|
||||
)
|
||||
train_dataset = raw_datasets["train"]
|
||||
eval_dataset = raw_datasets["test"]
|
||||
|
||||
with training_args.main_process_first(desc="Log a few random samples from the processed training set"):
|
||||
for index in random.sample(range(len(raw_datasets["train"])), 3):
|
||||
logger.info(f"Sample {index} of the processed training set:\n\n{raw_datasets['train'][index]['text']}")
|
||||
|
||||
#######################
|
||||
# Load pretrained model
|
||||
#######################
|
||||
@@ -135,11 +115,50 @@ def main():
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
|
||||
model = model_args.model_name_or_path
|
||||
# For ChatML we need to add special tokens and resize the embedding layer
|
||||
if "<|im_start|>" in tokenizer.chat_template:
|
||||
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)
|
||||
model, tokenizer = setup_chat_format(model, tokenizer)
|
||||
model_kwargs = None
|
||||
|
||||
#####################
|
||||
# Apply chat template
|
||||
#####################
|
||||
raw_datasets = raw_datasets.map(
|
||||
apply_chat_template,
|
||||
fn_kwargs={
|
||||
"tokenizer": tokenizer,
|
||||
"task": "sft",
|
||||
"auto_insert_empty_system_msg": data_args.auto_insert_empty_system_msg,
|
||||
},
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
desc="Applying chat template",
|
||||
)
|
||||
|
||||
##########################
|
||||
# Decontaminate benchmarks
|
||||
##########################
|
||||
num_raw_train_samples = len(raw_datasets["train"])
|
||||
raw_datasets = raw_datasets.filter(decontaminate_humaneval, batched=True, batch_size=10_000, num_proc=1)
|
||||
num_filtered_train_samples = num_raw_train_samples - len(raw_datasets["train"])
|
||||
logger.info(
|
||||
f"Decontaminated {num_filtered_train_samples} ({num_filtered_train_samples/num_raw_train_samples * 100:.2f}%) samples from the training set."
|
||||
)
|
||||
|
||||
train_dataset = raw_datasets["train"]
|
||||
eval_dataset = raw_datasets["test"]
|
||||
|
||||
with training_args.main_process_first(desc="Log a few random samples from the processed training set"):
|
||||
for index in random.sample(range(len(raw_datasets["train"])), 3):
|
||||
logger.info(f"Sample {index} of the processed training set:\n\n{raw_datasets['train'][index]['text']}")
|
||||
|
||||
########################
|
||||
# Initialize the Trainer
|
||||
########################
|
||||
trainer = SFTTrainer(
|
||||
model=model_args.model_name_or_path,
|
||||
model=model,
|
||||
model_init_kwargs=model_kwargs,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
|
||||
Reference in New Issue
Block a user