Add check that parameters are not intended to be offloaded (#51)

* Add check that parameters are not intended to be offloaded

* Only set `device_map` (which pushes the model to a device at load time) if a quantization config is set.
This commit is contained in:
Nathan Azrak
2023-12-04 19:10:41 +11:00
committed by GitHub
parent 15279e7157
commit 3f368a0748
2 changed files with 7 additions and 4 deletions
+4 -2
View File
@@ -105,14 +105,16 @@ def main():
torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
+ quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
use_flash_attention_2=model_args.use_flash_attention_2,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
- device_map=get_kbit_device_map(),
- quantization_config=get_quantization_config(model_args),
+ device_map=get_kbit_device_map() if quantization_config is not None else None,
+ quantization_config=quantization_config,
)
model = model_args.model_name_or_path
+3 -2
View File
@@ -109,6 +109,7 @@ def main():
torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
+ quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
revision=model_args.model_revision,
@@ -116,8 +117,8 @@ def main():
use_flash_attention_2=model_args.use_flash_attention_2,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
- device_map=get_kbit_device_map(),
- quantization_config=get_quantization_config(model_args),
+ device_map=get_kbit_device_map() if quantization_config is not None else None,
+ quantization_config=quantization_config,
)
logger.info("*** Model loaded! ***")