mirror of
https://github.com/wassname/SimPO.git
synced 2026-06-27 18:57:43 +08:00
update
This commit is contained in:
+23
-22
@@ -227,28 +227,29 @@ def main():
|
||||
)
|
||||
|
||||
model = model_args.model_name_or_path
|
||||
if is_adapter_model(model, model_args.model_revision) is True:
|
||||
logger.info(f"Loading SFT adapter for {model_args.model_name_or_path=}")
|
||||
peft_config = PeftConfig.from_pretrained(model_args.model_name_or_path, revision=model_args.model_revision)
|
||||
model_kwargs = dict(
|
||||
revision=model_args.base_model_revision,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
use_flash_attention_2=model_args.use_flash_attention_2,
|
||||
torch_dtype=torch_dtype,
|
||||
use_cache=False if training_args.gradient_checkpointing else True,
|
||||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
base_model = AutoModelForCausalLM.from_pretrained(
|
||||
peft_config.base_model_name_or_path,
|
||||
**model_kwargs,
|
||||
)
|
||||
model = PeftModel.from_pretrained(
|
||||
base_model,
|
||||
model_args.model_name_or_path,
|
||||
revision=model_args.model_revision,
|
||||
)
|
||||
model_kwargs = None
|
||||
# seems to require internet
|
||||
# if is_adapter_model(model, model_args.model_revision) is True:
|
||||
# logger.info(f"Loading SFT adapter for {model_args.model_name_or_path=}")
|
||||
# peft_config = PeftConfig.from_pretrained(model_args.model_name_or_path, revision=model_args.model_revision)
|
||||
# model_kwargs = dict(
|
||||
# revision=model_args.base_model_revision,
|
||||
# trust_remote_code=model_args.trust_remote_code,
|
||||
# use_flash_attention_2=model_args.use_flash_attention_2,
|
||||
# torch_dtype=torch_dtype,
|
||||
# use_cache=False if training_args.gradient_checkpointing else True,
|
||||
# device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
# quantization_config=quantization_config,
|
||||
# )
|
||||
# base_model = AutoModelForCausalLM.from_pretrained(
|
||||
# peft_config.base_model_name_or_path,
|
||||
# **model_kwargs,
|
||||
# )
|
||||
# model = PeftModel.from_pretrained(
|
||||
# base_model,
|
||||
# model_args.model_name_or_path,
|
||||
# revision=model_args.model_revision,
|
||||
# )
|
||||
# model_kwargs = None
|
||||
|
||||
ref_model = model
|
||||
ref_model_kwargs = model_kwargs
|
||||
|
||||
Reference in New Issue
Block a user