Change attention implementation argument: remove `use_flash_attention_2` from `model_kwargs`

This commit is contained in:
Yu Meng
2024-08-05 00:24:47 -04:00
parent f9f7042105
commit 54545e803b
-1
View File
@@ -213,7 +213,6 @@ def main():
model_kwargs = dict( model_kwargs = dict(
revision=model_args.model_revision, revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code, trust_remote_code=model_args.trust_remote_code,
use_flash_attention_2=model_args.use_flash_attention_2,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True, use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None, device_map=get_kbit_device_map() if quantization_config is not None else None,