mirror of
https://github.com/wassname/alignment-handbook.git
synced 2026-06-27 17:29:09 +08:00
Add fsdp+qlora support (#160)
This commit is contained in:
@@ -0,0 +1,25 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
distributed_type: FSDP
|
||||
downcast_bf16: 'no'
|
||||
fsdp_config:
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_backward_prefetch: BACKWARD_PRE
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_forward_prefetch: false
|
||||
fsdp_offload_params: true
|
||||
fsdp_sharding_strategy: FULL_SHARD
|
||||
fsdp_state_dict_type: SHARDED_STATE_DICT
|
||||
fsdp_sync_module_states: true
|
||||
fsdp_use_orig_params: false
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: 'no'
|
||||
num_machines: 1
|
||||
num_processes: 2
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
@@ -7,6 +7,7 @@ In the handbook, we provide three main ways to align LLMs for chat:
|
||||
- Full fine-tuning on a multi-GPU machine with DeepSpeed ZeRO-3 (tested on an 8 x A100 (80GB) node).
|
||||
- LoRA or QLoRA fine-tuning on a single consumer 24GB GPU (tested on an RTX 4090).
|
||||
- LoRA fine-tuning on a multi-GPU machine with DeepSpeed ZeRO-3 (tested on a 2 x A100s (80GB)).
|
||||
- QLoRA fine-tuning on multi-GPU machine with FSDP (tested on a 2 x A6000s (48GB)).
|
||||
|
||||
In practice, we find comparable performance for both full and QLoRA fine-tuning, with the latter having the advantage of producing small adapter weights that are fast to upload and download from the Hugging Face Hub. Here are the general commands to fine-tune your models:
|
||||
|
||||
@@ -22,6 +23,9 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con
|
||||
|
||||
# LoRA training with ZeRO-3 on two or more GPUs
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml --num_processes={num_gpus} scripts/run_{task}.py recipes/{model_name}/{task}/config_qlora.yaml --load_in_4bit=false
|
||||
|
||||
# QLoRA training with FSDP on two or more GPUs
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/fsdp+qlora.yaml --num_processes={num_gpus} scripts/run_{task}.py recipes/{model_name}/{task}/config_qlora.yaml --torch_dtype=bfloat16 --bnb_4bit_quant_storage=bfloat16
|
||||
```
|
||||
|
||||
Here `{task}` refers to the type of training you wish to run. Currently the following tasks are supported:
|
||||
|
||||
@@ -42,7 +42,7 @@ if stale_egg_info.exists():
|
||||
# * If a dependency is fast-moving (e.g. transformers), pin to the exact version
|
||||
_deps = [
|
||||
"accelerate>=0.29.2",
|
||||
"bitsandbytes==0.41.2.post2",
|
||||
"bitsandbytes>=0.43.0",
|
||||
"black==23.1.0",
|
||||
"datasets>=2.18.0",
|
||||
"deepspeed==0.12.2",
|
||||
@@ -57,7 +57,7 @@ _deps = [
|
||||
"numpy>=1.24.2",
|
||||
"packaging>=23.0",
|
||||
"parameterized>=0.9.0",
|
||||
"peft==0.7.1",
|
||||
"peft>=0.9.0",
|
||||
"protobuf<=3.20.2", # Needed to avoid conflicts with `transformers`
|
||||
"pytest",
|
||||
"safetensors>=0.3.3",
|
||||
|
||||
@@ -185,6 +185,9 @@ class ModelArguments:
|
||||
default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"}
|
||||
)
|
||||
use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"})
|
||||
bnb_4bit_quant_storage: Optional[str] = field(
|
||||
default="uint8", metadata={"help": "storage type to pack the quanitzed 4-bit prarams."}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.load_in_8bit and self.load_in_4bit:
|
||||
|
||||
@@ -51,6 +51,7 @@ def get_quantization_config(model_args: ModelArguments) -> BitsAndBytesConfig |
|
||||
bnb_4bit_compute_dtype=compute_dtype,
|
||||
bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
|
||||
bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
|
||||
bnb_4bit_quant_storage=model_args.bnb_4bit_quant_storage,
|
||||
)
|
||||
elif model_args.load_in_8bit:
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
|
||||
Reference in New Issue
Block a user