mirror of
https://github.com/wassname/alignment-handbook.git
synced 2026-06-27 17:14:25 +08:00
Add auto_insert_empty_system_msg config flag (#123)
* Make system messages optional

  Also use the `maybe_insert_system_message` in dpo setting
* add `auto_insert_empty_system_msg` flag
* add `auto_insert_empty_system_msg`
* add auto_insert_empty_system_msg
* Update src/alignment/configs.py

  Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* make style

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This commit is contained in:
+5
-1
@@ -93,7 +93,11 @@ def main():
|
||||
#####################
|
||||
raw_datasets = raw_datasets.map(
|
||||
apply_chat_template,
|
||||
fn_kwargs={"tokenizer": tokenizer, "task": "dpo"},
|
||||
fn_kwargs={
|
||||
"tokenizer": tokenizer,
|
||||
"task": "dpo",
|
||||
"auto_insert_empty_system_msg": data_args.auto_insert_empty_system_msg,
|
||||
},
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
desc="Formatting comparisons with prompt template",
|
||||
|
||||
+5
-1
@@ -100,7 +100,11 @@ def main():
|
||||
#####################
|
||||
raw_datasets = raw_datasets.map(
|
||||
apply_chat_template,
|
||||
fn_kwargs={"tokenizer": tokenizer, "task": "sft"},
|
||||
fn_kwargs={
|
||||
"tokenizer": tokenizer,
|
||||
"task": "sft",
|
||||
"auto_insert_empty_system_msg": data_args.auto_insert_empty_system_msg,
|
||||
},
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
desc="Applying chat template",
|
||||
|
||||
@@ -204,6 +204,14 @@ class DataArguments:
|
||||
truncation_side: Optional[str] = field(
|
||||
default=None, metadata={"help": "Truncation side to use for the tokenizer."}
|
||||
)
|
||||
auto_insert_empty_system_msg: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": (
|
||||
"Whether to automatically insert an empty system message as the first message if `system` is mentioned in the chat template."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -42,11 +42,13 @@ def apply_chat_template(
|
||||
example,
|
||||
tokenizer,
|
||||
task: Literal["sft", "generation", "rm", "dpo"],
|
||||
auto_insert_empty_system_msg: bool = True,
|
||||
):
|
||||
if task in ["sft", "generation"]:
|
||||
messages = example["messages"]
|
||||
# We add an empty system message if there is none
|
||||
maybe_insert_system_message(messages, tokenizer)
|
||||
if auto_insert_empty_system_msg:
|
||||
maybe_insert_system_message(messages, tokenizer)
|
||||
example["text"] = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True if task == "generation" else False
|
||||
)
|
||||
@@ -55,8 +57,9 @@ def apply_chat_template(
|
||||
chosen_messages = example["chosen"]
|
||||
rejected_messages = example["rejected"]
|
||||
# We add an empty system message if there is none
|
||||
maybe_insert_system_message(chosen_messages, tokenizer)
|
||||
maybe_insert_system_message(rejected_messages, tokenizer)
|
||||
if auto_insert_empty_system_msg:
|
||||
maybe_insert_system_message(chosen_messages, tokenizer)
|
||||
maybe_insert_system_message(rejected_messages, tokenizer)
|
||||
|
||||
example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
|
||||
example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
|
||||
@@ -70,8 +73,9 @@ def apply_chat_template(
|
||||
# We therefore need to extract the N-1 turns to form the prompt
|
||||
prompt_messages = example["chosen"][:-1]
|
||||
# Prepend a system message if the first message is not a system message
|
||||
if example["chosen"][0]["role"] != "system":
|
||||
prompt_messages.insert(0, {"role": "system", "content": ""})
|
||||
if auto_insert_empty_system_msg:
|
||||
maybe_insert_system_message(prompt_messages, tokenizer)
|
||||
|
||||
# Now we extract the final turn to define chosen/rejected responses
|
||||
chosen_messages = example["chosen"][-1:]
|
||||
rejected_messages = example["rejected"][-1:]
|
||||
|
||||
Reference in New Issue
Block a user