update trainer argument

This commit is contained in:
Yu Meng
2024-07-13 16:50:42 -04:00
parent 6c3757f3f2
commit 3e9c4cc3bd
8 changed files with 10 additions and 6 deletions
+2
View File
@@ -218,6 +218,7 @@ def main():
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
attn_implementation=training_args.attn_implementation,
)
model = model_args.model_name_or_path
@@ -245,6 +246,7 @@ def main():
# )
# model_kwargs = None
training_args.model_init_kwargs = model_kwargs
#########################
# Instantiate SimPO trainer
#########################
+2
View File
@@ -68,3 +68,5 @@ class SimPOConfig(TrainingArguments):
model_init_kwargs: Optional[Dict] = None
dataset_num_proc: Optional[int] = None
attn_implementation: str = None