From d17fd7cd3b71c6a7bf7af34d8dc73135bb7ea8e9 Mon Sep 17 00:00:00 2001 From: Bram Vanroy <2779410+BramVanroy@users.noreply.github.com> Date: Wed, 28 Feb 2024 20:05:44 +0100 Subject: [PATCH] Add `auto_insert_empty_system_msg` config flag (#123) * Make system messages optional Also use the `maybe_insert_system_message` in dpo setting * add `auto_insert_empty_system_msg` flag * add `auto_insert_empty_system_msg` * add auto_insert_empty_system_msg * Update src/alignment/configs.py Co-authored-by: lewtun * make style --------- Co-authored-by: lewtun --- scripts/run_dpo.py | 6 +++++- scripts/run_sft.py | 6 +++++- src/alignment/configs.py | 8 ++++++++ src/alignment/data.py | 14 +++++++++----- 4 files changed, 27 insertions(+), 7 deletions(-) diff --git a/scripts/run_dpo.py b/scripts/run_dpo.py index 3a41c37..c20b184 100644 --- a/scripts/run_dpo.py +++ b/scripts/run_dpo.py @@ -93,7 +93,11 @@ def main(): ##################### raw_datasets = raw_datasets.map( apply_chat_template, - fn_kwargs={"tokenizer": tokenizer, "task": "dpo"}, + fn_kwargs={ + "tokenizer": tokenizer, + "task": "dpo", + "auto_insert_empty_system_msg": data_args.auto_insert_empty_system_msg, + }, num_proc=data_args.preprocessing_num_workers, remove_columns=column_names, desc="Formatting comparisons with prompt template", diff --git a/scripts/run_sft.py b/scripts/run_sft.py index eafb259..e6fa4e0 100644 --- a/scripts/run_sft.py +++ b/scripts/run_sft.py @@ -100,7 +100,11 @@ def main(): ##################### raw_datasets = raw_datasets.map( apply_chat_template, - fn_kwargs={"tokenizer": tokenizer, "task": "sft"}, + fn_kwargs={ + "tokenizer": tokenizer, + "task": "sft", + "auto_insert_empty_system_msg": data_args.auto_insert_empty_system_msg, + }, num_proc=data_args.preprocessing_num_workers, remove_columns=column_names, desc="Applying chat template", diff --git a/src/alignment/configs.py b/src/alignment/configs.py index 2a71ea4..ba96178 100644 --- a/src/alignment/configs.py +++ b/src/alignment/configs.py @@ -204,6 +204,14 @@ class DataArguments: truncation_side: Optional[str] = field( default=None, metadata={"help": "Truncation side to use for the tokenizer."} ) + auto_insert_empty_system_msg: bool = field( + default=True, + metadata={ + "help": ( + "Whether to automatically insert an empty system message as the first message if `system` is mentioned in the chat template." + ) + }, + ) @dataclass diff --git a/src/alignment/data.py b/src/alignment/data.py index 267fd2e..16b095f 100644 --- a/src/alignment/data.py +++ b/src/alignment/data.py @@ -42,11 +42,13 @@ def apply_chat_template( example, tokenizer, task: Literal["sft", "generation", "rm", "dpo"], + auto_insert_empty_system_msg: bool = True, ): if task in ["sft", "generation"]: messages = example["messages"] # We add an empty system message if there is none - maybe_insert_system_message(messages, tokenizer) + if auto_insert_empty_system_msg: + maybe_insert_system_message(messages, tokenizer) example["text"] = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True if task == "generation" else False ) @@ -55,8 +57,9 @@ def apply_chat_template( chosen_messages = example["chosen"] rejected_messages = example["rejected"] # We add an empty system message if there is none - maybe_insert_system_message(chosen_messages, tokenizer) - maybe_insert_system_message(rejected_messages, tokenizer) + if auto_insert_empty_system_msg: + maybe_insert_system_message(chosen_messages, tokenizer) + maybe_insert_system_message(rejected_messages, tokenizer) example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False) example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False) @@ -70,8 +73,9 @@ def apply_chat_template( # We therefore need to extract the N-1 turns to form the prompt prompt_messages = example["chosen"][:-1] # Prepend a system message if the first message is not a system message - if example["chosen"][0]["role"] != "system": - prompt_messages.insert(0, {"role": "system", "content": ""}) + if auto_insert_empty_system_msg: + maybe_insert_system_message(prompt_messages, tokenizer) + # Now we extract the final turn to define chosen/rejected responses chosen_messages = example["chosen"][-1:] rejected_messages = example["rejected"][-1:]