mirror of
https://github.com/wassname/alignment-handbook.git
synced 2026-06-27 16:14:07 +08:00
Add check that parameters are not intended to be offloaded (#51)
* Add check that parameters are not intended to be offloaded
* Only push model to device if quantization config is set.
This commit is contained in:
+4
-2
@@ -105,14 +105,16 @@ def main():
|
||||
torch_dtype = (
|
||||
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
|
||||
)
|
||||
quantization_config = get_quantization_config(model_args)
|
||||
|
||||
model_kwargs = dict(
|
||||
revision=model_args.model_revision,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
use_flash_attention_2=model_args.use_flash_attention_2,
|
||||
torch_dtype=torch_dtype,
|
||||
use_cache=False if training_args.gradient_checkpointing else True,
|
||||
device_map=get_kbit_device_map(),
|
||||
quantization_config=get_quantization_config(model_args),
|
||||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
|
||||
model = model_args.model_name_or_path
|
||||
|
||||
+3
-2
@@ -109,6 +109,7 @@ def main():
|
||||
torch_dtype = (
|
||||
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
|
||||
)
|
||||
quantization_config = get_quantization_config(model_args)
|
||||
|
||||
model_kwargs = dict(
|
||||
revision=model_args.model_revision,
|
||||
@@ -116,8 +117,8 @@ def main():
|
||||
use_flash_attention_2=model_args.use_flash_attention_2,
|
||||
torch_dtype=torch_dtype,
|
||||
use_cache=False if training_args.gradient_checkpointing else True,
|
||||
device_map=get_kbit_device_map(),
|
||||
quantization_config=get_quantization_config(model_args),
|
||||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
logger.info("*** Model loaded! ***")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user