mirror of
https://github.com/wassname/SimPO.git
synced 2026-06-27 18:57:43 +08:00
update trainer argument
This commit is contained in:
@@ -218,6 +218,7 @@ def main():
|
||||
use_cache=False if training_args.gradient_checkpointing else True,
|
||||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
quantization_config=quantization_config,
|
||||
attn_implementation=training_args.attn_implementation,
|
||||
)
|
||||
|
||||
model = model_args.model_name_or_path
|
||||
@@ -245,6 +246,7 @@ def main():
|
||||
# )
|
||||
# model_kwargs = None
|
||||
|
||||
training_args.model_init_kwargs = model_kwargs
|
||||
#########################
|
||||
# Instantiate SimPO trainer
|
||||
#########################
|
||||
|
||||
@@ -68,3 +68,5 @@ class SimPOConfig(TrainingArguments):
|
||||
model_init_kwargs: Optional[Dict] = None
|
||||
|
||||
dataset_num_proc: Optional[int] = None
|
||||
|
||||
attn_implementation: str = None
|
||||
|
||||
Reference in New Issue
Block a user