diff --git a/scripts/run_simpo.py b/scripts/run_simpo.py index be69f3f..668a942 100644 --- a/scripts/run_simpo.py +++ b/scripts/run_simpo.py @@ -213,7 +213,6 @@ def main(): model_kwargs = dict( revision=model_args.model_revision, trust_remote_code=model_args.trust_remote_code, - use_flash_attention_2=model_args.use_flash_attention_2, torch_dtype=torch_dtype, use_cache=False if training_args.gradient_checkpointing else True, device_map=get_kbit_device_map() if quantization_config is not None else None,