Add auto_insert_empty_system_msg config flag (#123)

* Make system messages optional

Also use the `maybe_insert_system_message` in dpo setting

* add `auto_insert_empty_system_msg` flag

* add `auto_insert_empty_system_msg`

* add auto_insert_empty_system_msg

* Update src/alignment/configs.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* make style

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This commit is contained in:
Bram Vanroy
2024-02-28 20:05:44 +01:00
committed by GitHub
parent 87cc800498
commit d17fd7cd3b
4 changed files with 27 additions and 7 deletions
+5 -1
View File
@@ -93,7 +93,11 @@ def main():
#####################
raw_datasets = raw_datasets.map(
apply_chat_template,
fn_kwargs={"tokenizer": tokenizer, "task": "dpo"},
fn_kwargs={
"tokenizer": tokenizer,
"task": "dpo",
"auto_insert_empty_system_msg": data_args.auto_insert_empty_system_msg,
},
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
desc="Formatting comparisons with prompt template",
+5 -1
View File
@@ -100,7 +100,11 @@ def main():
#####################
raw_datasets = raw_datasets.map(
apply_chat_template,
fn_kwargs={"tokenizer": tokenizer, "task": "sft"},
fn_kwargs={
"tokenizer": tokenizer,
"task": "sft",
"auto_insert_empty_system_msg": data_args.auto_insert_empty_system_msg,
},
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
desc="Applying chat template",
+8
View File
@@ -204,6 +204,14 @@ class DataArguments:
truncation_side: Optional[str] = field(
default=None, metadata={"help": "Truncation side to use for the tokenizer."}
)
auto_insert_empty_system_msg: bool = field(
default=True,
metadata={
"help": (
"Whether to automatically insert an empty system message as the first message if `system` is mentioned in the chat template."
)
},
)
@dataclass
+9 -5
View File
@@ -42,11 +42,13 @@ def apply_chat_template(
example,
tokenizer,
task: Literal["sft", "generation", "rm", "dpo"],
auto_insert_empty_system_msg: bool = True,
):
if task in ["sft", "generation"]:
messages = example["messages"]
# We add an empty system message if there is none
maybe_insert_system_message(messages, tokenizer)
if auto_insert_empty_system_msg:
maybe_insert_system_message(messages, tokenizer)
example["text"] = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True if task == "generation" else False
)
@@ -55,8 +57,9 @@ def apply_chat_template(
chosen_messages = example["chosen"]
rejected_messages = example["rejected"]
# We add an empty system message if there is none
maybe_insert_system_message(chosen_messages, tokenizer)
maybe_insert_system_message(rejected_messages, tokenizer)
if auto_insert_empty_system_msg:
maybe_insert_system_message(chosen_messages, tokenizer)
maybe_insert_system_message(rejected_messages, tokenizer)
example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
@@ -70,8 +73,9 @@ def apply_chat_template(
# We therefore need to extract the N-1 turns to form the prompt
prompt_messages = example["chosen"][:-1]
# Prepend a system message if the first message is not a system message
if example["chosen"][0]["role"] != "system":
prompt_messages.insert(0, {"role": "system", "content": ""})
if auto_insert_empty_system_msg:
maybe_insert_system_message(prompt_messages, tokenizer)
# Now we extract the final turn to define chosen/rejected responses
chosen_messages = example["chosen"][-1:]
rejected_messages = example["rejected"][-1:]