diff --git a/scripts/run_simpo.py b/scripts/run_simpo.py index 77374a4..d5fabd2 100644 --- a/scripts/run_simpo.py +++ b/scripts/run_simpo.py @@ -26,8 +26,6 @@ from alignment import ( DPOConfig, H4ArgumentParser, ModelArguments, - maybe_insert_system_message, - is_openai_format, get_checkpoint, get_datasets, get_kbit_device_map, @@ -36,6 +34,7 @@ from alignment import ( get_tokenizer, is_adapter_model, ) +from alignment.data import maybe_insert_system_message, is_openai_format from peft import PeftConfig, PeftModel from simpo_trainer import SimPOTrainer from dataclasses import dataclass, field @@ -167,7 +166,7 @@ def main(): splits=data_args.dataset_splits, configs=data_args.dataset_configs, columns_to_keep=["messages", "chosen", "rejected", "prompt", "completion", "label"], - seed=training_args.seed, + # seed=training_args.seed, ) logger.info( f"Training on the following splits: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"