From 57508b5c2d2c9fee90baa2906677ac013725f892 Mon Sep 17 00:00:00 2001 From: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Date: Thu, 4 Jan 2024 05:55:58 +0100 Subject: [PATCH] Make SFT script consistent with DPO script (#86) * Add argument * Make scripts consistent * Fix style --------- Co-authored-by: lewtun --- scripts/run_sft.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/scripts/run_sft.py b/scripts/run_sft.py index 748dd71..97cc051 100644 --- a/scripts/run_sft.py +++ b/scripts/run_sft.py @@ -85,6 +85,7 @@ def main(): logger.info( f"Training on the following datasets and their proportions: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}" ) + column_names = list(raw_datasets["train"].features) ################ # Load tokenizer @@ -94,7 +95,13 @@ def main(): ##################### # Apply chat template ##################### - raw_datasets = raw_datasets.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer, "task": "sft"}) + raw_datasets = raw_datasets.map( + apply_chat_template, + fn_kwargs={"tokenizer": tokenizer, "task": "sft"}, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + desc="Applying chat template", + ) train_dataset = raw_datasets["train"] eval_dataset = raw_datasets["test"]