This commit is contained in:
xiamengzhou
2024-07-03 09:44:33 -04:00
parent 4142d762bc
commit 4b96360354
5 changed files with 72 additions and 26 deletions
+23 -22
View File
@@ -227,28 +227,29 @@ def main():
)
model = model_args.model_name_or_path
if is_adapter_model(model, model_args.model_revision) is True:
logger.info(f"Loading SFT adapter for {model_args.model_name_or_path=}")
peft_config = PeftConfig.from_pretrained(model_args.model_name_or_path, revision=model_args.model_revision)
model_kwargs = dict(
revision=model_args.base_model_revision,
trust_remote_code=model_args.trust_remote_code,
use_flash_attention_2=model_args.use_flash_attention_2,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
base_model = AutoModelForCausalLM.from_pretrained(
peft_config.base_model_name_or_path,
**model_kwargs,
)
model = PeftModel.from_pretrained(
base_model,
model_args.model_name_or_path,
revision=model_args.model_revision,
)
model_kwargs = None
# NOTE: adapter loading is disabled; the block below seems to require internet access.
# if is_adapter_model(model, model_args.model_revision) is True:
# logger.info(f"Loading SFT adapter for {model_args.model_name_or_path=}")
# peft_config = PeftConfig.from_pretrained(model_args.model_name_or_path, revision=model_args.model_revision)
# model_kwargs = dict(
# revision=model_args.base_model_revision,
# trust_remote_code=model_args.trust_remote_code,
# use_flash_attention_2=model_args.use_flash_attention_2,
# torch_dtype=torch_dtype,
# use_cache=False if training_args.gradient_checkpointing else True,
# device_map=get_kbit_device_map() if quantization_config is not None else None,
# quantization_config=quantization_config,
# )
# base_model = AutoModelForCausalLM.from_pretrained(
# peft_config.base_model_name_or_path,
# **model_kwargs,
# )
# model = PeftModel.from_pretrained(
# base_model,
# model_args.model_name_or_path,
# revision=model_args.model_revision,
# )
# model_kwargs = None
ref_model = model
ref_model_kwargs = model_kwargs