mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 18:24:15 +08:00
fdb4c77d6c
- Fetch canonical reference impls for offline review:
* peft_{lora,hra,delora,ia3}_layer.py + peft_lora_{dora,variants}.py
* orig_pissa_init.py (MuLabPKU/PiSSA)
* orig_hra_layer.py (DaShenZi721/HRA)
* orig_delora.py (ExplainableML/DeLoRA author fork)
- Add reference-impl URLs to all 6 variant docstrings
- Document HRA gate=0 dead-grad issue and DoRA detach-omission in their docstrings
- Re-run external review (codex) with refs available -> docs/audit/variants_review_v2.md
Major NEW findings vs paper-only review:
* DeLoRA: scalar W.norm() should be per-input-channel norm(dim=0)
* HRA: PEFT uses symmetric repeated-column init (no dead grad), not zero gate
* IA3: FFN targets need input-side gating, not output, our up_proj advice wrong
* All LoRA-family: cfg.dropout silently ignored (no-op)
* DeLoRA: wnorm should be persistent buffer, not Parameter
HRA and DeLoRA upgraded to BUGGY (from Partial)
924 lines
39 KiB
Python
924 lines
39 KiB
Python
# Copyright 2023-present the HuggingFace Inc. team.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
from __future__ import annotations
|
||
|
||
import collections
|
||
import warnings
|
||
from typing import Any, Optional
|
||
|
||
import torch
|
||
from accelerate.utils.imports import is_xpu_available
|
||
from torch import nn
|
||
|
||
from peft.tuners.lora.config import BdLoraConfig
|
||
from peft.utils.other import transpose
|
||
|
||
from .arrow import ArrowLoraLinearLayer
|
||
from .config import LoraConfig, PeftConfig
|
||
from .dora import DoraConv1dLayer, DoraConv2dLayer, DoraConv3dLayer, DoraEmbeddingLayer, DoraLinearLayer
|
||
from .layer import Conv1d, Conv2d, Conv3d, Embedding, Linear, LoraVariant, _ConvNd
|
||
|
||
|
||
class ArrowLinearVariant(LoraVariant):
    """LoRA variant that attaches an Arrow router (``ArrowLoraLinearLayer``) to a LoRA ``Linear`` layer."""

    @staticmethod
    def init(module: Linear, adapter_name: str, config: LoraConfig, **kwargs):
        """
        Create the ``ArrowLoraLinearLayer`` for ``adapter_name`` and store it in ``module.lora_arrow``.

        ``module.lora_arrow`` is an ``nn.ModuleDict`` that is created on first use. After initialisation a wrapped
        layer looks roughly like:

            (qkv_proj): lora.Linear4bit or lora.Linear(
                (base_layer): Linear4bit or Linear
                (lora_dropout): ModuleDict( ... )
                (lora_A): ModuleDict( ... )
                (lora_B): ModuleDict( ... )
                (lora_embedding_A): ParameterDict( ... )
                (lora_embedding_B): ParameterDict( ... )
                (lora_magnitude_vector): ModuleDict( ... )
                (lora_arrow): ModuleDict( (arrow_router): ArrowLoraLinearLayer() )
            )

        Args:
            module (Linear): LoRA layer of the model, containing base_layer, lora_A, lora_B, etc.
            adapter_name (str): key under which the router is stored in ``lora_arrow``. The original notes state
                this is "arrow_router" by default, set in create_arrow_model() in ./arrow.py.
            config (LoraConfig): adapter config; its ``arrow_config`` attribute must not be None.

        Raises:
            ValueError: if ``config.arrow_config`` is None.
        """
        arrow_cfg = config.arrow_config
        if arrow_cfg is None:
            raise ValueError("ArrowLinearVariant.init() did not receive an arrow_config")

        # Build the router, then move it to the device of the wrapped layer's weight.
        router = ArrowLoraLinearLayer(in_features=module.in_features, arrow_config=arrow_cfg)
        router = router.to(module.weight.device)

        # Lazily create the container that holds one router per adapter name.
        if not hasattr(module, "lora_arrow"):
            module.lora_arrow = nn.ModuleDict()
        module.lora_arrow[adapter_name] = router

    @staticmethod
    def forward(
        module: Linear,
        *,
        active_adapter: str,
        x: torch.Tensor,
        result: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """
        Add the Arrow-routed LoRA contribution to the base layer output.

        ``gen_know_sub()`` and ``build_prototypes()`` are called on every forward pass; per the original notes they
        are implemented in ``ArrowLoraLinearLayer`` (./arrow.py) so that the actual work runs only once, and is
        redone for adapters added later via arrow/on_adapter_change().

        Args:
            module (Linear): LoRA layer of the model.
            active_adapter (str): name of the arrow router that should perform the routing.
            x (torch.Tensor): input to the layer.
            result (torch.Tensor): output of the base layer.
            **kwargs: extra variant kwargs (e.g. 'alora_offsets'); accepted and ignored for compatibility.

        Returns:
            ``result`` plus the delta computed by the arrow layer.
        """
        router = module.lora_arrow[active_adapter]  # ArrowLoraLinearLayer
        experts_a = module.lora_A
        experts_b = module.lora_B

        # GenKnowSub first (if applicable), then prototype construction.
        router.gen_know_sub(experts_a, experts_b)
        router.build_prototypes(experts_a, experts_b)

        # Calling the router performs the routing and yields the LoRA delta for this input.
        routed_delta = router(
            x,
            lora_A=experts_a,
            lora_B=experts_b,
            dropout=module.lora_dropout[active_adapter],
            scaling=module.scaling,
        )
        return result + routed_delta

    # Since Arrow is a Mixture-of-Experts (MoE) approach, merging adapters is not meaningful or even possible: for
    # each token, the top-k LoRA experts are dynamically selected and routed. Because of this per-token routing,
    # there is no single set of weights that can represent a merged adapter. All merge hooks therefore raise.

    @staticmethod
    def merge_safe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
        """Always raises ``RuntimeError``: an Arrow router cannot be merged."""
        raise RuntimeError("Cannot merge an active Arrow router adapter. Remove it first.")

    @staticmethod
    def merge_unsafe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> None:
        """Always raises ``RuntimeError``: an Arrow router cannot be merged."""
        raise RuntimeError("Cannot merge an active Arrow router adapter. Remove it first.")

    @staticmethod
    def unmerge(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
        """Always raises ``RuntimeError``: an Arrow router cannot be unmerged."""
        raise RuntimeError("Cannot unmerge an active Arrow router adapter. Remove it first.")
|
||
|
||
|
||
class DoraLinearVariant(LoraVariant):
    """DoRA variant for LoRA ``Linear`` layers.

    The per-adapter ``DoraLinearLayer`` lives in ``module.lora_magnitude_vector``. Merging rescales
    ``orig_weight + delta_weight`` by ``magnitude / weight_norm``; unmerging inverts that using the weight norm
    that was cached at merge time.
    """

    @staticmethod
    def init(module: Linear, adapter_name: str, **kwargs: Any) -> None:
        """Create the DoRA layer for ``adapter_name`` and register it in ``module.lora_magnitude_vector``."""
        if not module.lora_magnitude_vector:
            # first dora layer being added, add lora_magnitude_vector to the list of learnable parameters
            module.adapter_layer_names = module.adapter_layer_names[:] + ("lora_magnitude_vector",)

        dora_layer = DoraLinearLayer(fan_in_fan_out=getattr(module, "fan_in_fan_out", False))
        lora_A = module.lora_A[adapter_name].weight
        lora_B = module.lora_B[adapter_name].weight
        place_on_cpu = module.ephemeral_gpu_offload and (lora_A.device.type == "cpu" or lora_B.device.type == "cpu")
        if module.ephemeral_gpu_offload:
            # Bring both LoRA matrices onto one accelerator device before update_layer is called.
            if lora_A.device.type in ["cuda", "xpu"]:
                lora_B = lora_B.to(lora_A.device)
            else:
                if lora_B.device.type not in ["cuda", "xpu"]:
                    if is_xpu_available():
                        lora_B = lora_B.to("xpu")
                    else:
                        lora_B = lora_B.to("cuda")
                lora_A = lora_A.to(lora_B.device)
        scaling = module.scaling[adapter_name]
        dora_layer.update_layer(
            base_layer=module.get_base_layer(),
            lora_A=lora_A,
            lora_B=lora_B,
            scaling=scaling,
            place_on_cpu=place_on_cpu,
        )
        module.lora_magnitude_vector[adapter_name] = dora_layer

    @staticmethod
    def merge_safe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
        """Return the merged weight as a new tensor in ``orig_weight``'s dtype and cache the weight norm."""
        orig_dtype = orig_weight.dtype
        delta_weight = module.get_delta_weight(active_adapter)

        # since delta_weight already includes scaling, set it to 1 here
        weight_norm = (
            module.lora_magnitude_vector[active_adapter]
            .get_weight_norm(orig_weight, transpose(delta_weight, module.fan_in_fan_out), scaling=1)
            .detach()
        )
        # We need to cache weight_norm because it has to be based on the original weights. We
        # cannot calculate it on the fly based on the merged weights when unmerging because its a
        # different value
        module._cache_store(f"{active_adapter}-weight_norm", weight_norm)
        dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm
        dora_factor = transpose(dora_factor.view(-1, 1), module.fan_in_fan_out)
        new_weight = dora_factor * (orig_weight + delta_weight)
        new_weight = new_weight.to(orig_dtype)
        return new_weight

    @staticmethod
    def merge_unsafe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> None:
        """Merge in place by overwriting ``orig_weight.data``; caches the weight norm for a later unmerge."""
        orig_dtype = orig_weight.dtype
        delta_weight = module.get_delta_weight(active_adapter)
        # since delta_weight already includes scaling, set it to 1 here
        weight_norm = (
            module.lora_magnitude_vector[active_adapter]
            .get_weight_norm(orig_weight, transpose(delta_weight, module.fan_in_fan_out), scaling=1)
            .detach()
        )
        # We need to cache weight_norm because it has to be based on the original weights. We
        # cannot calculate it on the fly based on the merged weights when unmerging because its a
        # different value
        module._cache_store(f"{active_adapter}-weight_norm", weight_norm)
        dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm
        dora_factor = transpose(dora_factor.view(-1, 1), module.fan_in_fan_out)
        new_weight = dora_factor * (orig_weight.data + delta_weight)
        new_weight = new_weight.to(orig_dtype)
        orig_weight.data = new_weight

    @staticmethod
    def unmerge(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
        """Return the un-merged weight, consuming the weight norm cached by the preceding merge."""
        orig_dtype = orig_weight.dtype
        delta_weight = module.get_delta_weight(active_adapter)
        weight_norm = module._cache_pop(f"{active_adapter}-weight_norm")
        dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm
        # Orient the factor exactly as merge_safe/merge_unsafe do, so that unmerge is the true inverse of merge
        # for fan_in_fan_out layers as well. For fan_in_fan_out=False this is a no-op.
        dora_factor = transpose(dora_factor.view(-1, 1), module.fan_in_fan_out)
        new_weight = orig_weight.data / dora_factor - delta_weight
        new_weight = new_weight.to(orig_dtype)
        return new_weight

    @staticmethod
    def forward(
        module: Linear,
        active_adapter: str,
        x: torch.Tensor,
        result: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Add the DoRA layer's output for ``active_adapter`` to ``result`` (the base layer output)."""
        lora_A = module.lora_A[active_adapter]
        lora_B = module.lora_B[active_adapter]
        dropout = module.lora_dropout[active_adapter]
        scaling = module.scaling[active_adapter]

        if isinstance(dropout, nn.Identity) or not module.training:
            # No dropout is applied, so the already-computed base output is handed to the DoRA layer.
            base_result = result
        else:
            # Dropout changes x, so no base result is passed along for this input.
            x = dropout(x)
            base_result = None

        result = result + module.lora_magnitude_vector[active_adapter](
            x,
            lora_A=lora_A,
            lora_B=lora_B,
            scaling=scaling,
            base_layer=module.get_base_layer(),
            base_result=base_result,
            adapter_name=active_adapter,
        )
        return result
|
||
|
||
|
||
class DoraEmbeddingVariant(DoraLinearVariant):
    """DoRA variant for LoRA ``Embedding`` layers; uses ``lora_embedding_A`` / ``lora_embedding_B``."""

    @staticmethod
    def init(module: Embedding, adapter_name: str, **kwargs: Any) -> None:
        """Create the ``DoraEmbeddingLayer`` for ``adapter_name`` and register it on ``module``."""
        # Emptiness check (covers both None and an empty container), matching DoraLinearVariant.init. A pure
        # `is None` check could never succeed usefully, because the item assignment below would fail on None.
        if not module.lora_magnitude_vector:
            # first dora layer being added, add lora_magnitude_vector to the list of learnable parameters
            module.adapter_layer_names = module.adapter_layer_names[:] + ("lora_magnitude_vector",)

        dora_layer = DoraEmbeddingLayer(fan_in_fan_out=True)
        lora_embedding_A = module.lora_embedding_A[adapter_name]
        lora_embedding_B = module.lora_embedding_B[adapter_name]
        scaling = module.scaling[adapter_name]
        dora_layer.update_layer(
            base_layer=module.get_base_layer(), lora_A=lora_embedding_A, lora_B=lora_embedding_B, scaling=scaling
        )
        module.lora_magnitude_vector[adapter_name] = dora_layer

    @staticmethod
    def merge_safe(module: Embedding, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
        """Return the merged embedding weight as a new tensor and cache the weight norm for unmerging."""
        orig_dtype = orig_weight.dtype
        delta_weight = module.get_delta_weight(active_adapter)

        # since delta_weight already includes scaling, set it to 1 here
        weight_norm = (
            module.lora_magnitude_vector[active_adapter]
            .get_weight_norm(orig_weight, delta_weight.T, scaling=1)
            .detach()
        )
        # We need to cache weight_norm because it has to be based on the original weights. We
        # cannot calculate it on the fly based on the merged weights when unmerging because its a
        # different value
        module._cache_store(f"{active_adapter}-weight_norm", weight_norm)
        dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm
        # The factor is broadcast along the last dimension of the embedding weight.
        dora_factor = dora_factor.view(1, -1)
        new_weight = dora_factor * (orig_weight + delta_weight)
        new_weight = new_weight.to(orig_dtype)
        return new_weight

    @staticmethod
    def merge_unsafe(module: Embedding, active_adapter: str, orig_weight: torch.Tensor) -> None:
        """Merge in place by overwriting ``orig_weight.data``; caches the weight norm for a later unmerge."""
        orig_dtype = orig_weight.dtype
        delta_weight = module.get_delta_weight(active_adapter)
        # since delta_weight already includes scaling, set it to 1 here
        weight_norm = (
            module.lora_magnitude_vector[active_adapter]
            .get_weight_norm(orig_weight, delta_weight.T, scaling=1)
            .detach()
        )
        # We need to cache weight_norm because it has to be based on the original weights. We
        # cannot calculate it on the fly based on the merged weights when unmerging because its a
        # different value
        module._cache_store(f"{active_adapter}-weight_norm", weight_norm)
        dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm
        dora_factor = dora_factor.view(1, -1)
        new_weight = dora_factor * (orig_weight.data + delta_weight)
        new_weight = new_weight.to(orig_dtype)
        orig_weight.data = new_weight

    @staticmethod
    def unmerge(module: Embedding, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
        """Return the un-merged embedding weight, consuming the weight norm cached by the preceding merge."""
        orig_dtype = orig_weight.dtype
        delta_weight = module.get_delta_weight(active_adapter)
        weight_norm = module._cache_pop(f"{active_adapter}-weight_norm")
        dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm
        new_weight = orig_weight.data / dora_factor.view(1, -1) - delta_weight
        new_weight = new_weight.to(orig_dtype)
        return new_weight

    @staticmethod
    def forward(
        module: Embedding,
        active_adapter: str,
        x: torch.Tensor,
        result: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Combine the base embedding output ``result`` with the DoRA contribution for ``active_adapter``."""
        embedding_A = module.lora_embedding_A[active_adapter].T
        embedding_B = module.lora_embedding_B[active_adapter].T
        scaling = module.scaling[active_adapter]

        # The embedding DoRA layer returns a pair: a scale for the base output and the DoRA result itself.
        mag_norm_scale, dora_result = module.lora_magnitude_vector[active_adapter](
            x,
            lora_A=embedding_A,
            lora_B=embedding_B,
            scaling=scaling,
            base_layer=module.get_base_layer(),
            embed_fn=module._embed,
            adapter_name=active_adapter,
        )

        # Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method.
        # Since base_layer(x) already includes this scaling, we need to apply it to DoRA contributions too.
        # Note: embed_scale is applied AFTER weight norm calculation to preserve DoRA's weight geometry semantics.
        embed_scale = module._get_embed_scale()
        if embed_scale is not None:
            dora_result = dora_result * embed_scale.to(dora_result.dtype)

        result = mag_norm_scale * result + dora_result
        return result
|
||
|
||
|
||
class _DoraConvNdVariant(LoraVariant):
    """Shared DoRA logic for convolutional LoRA layers; subclasses supply the concrete DoRA layer in ``init``."""

    @staticmethod
    def init_convd_variant(module: _ConvNd, adapter_name: str, dora_layer: nn.Module) -> None:
        """Initialise ``dora_layer`` from the adapter's LoRA weights and register it on ``module``."""
        # Emptiness check (covers both None and an empty container), matching DoraLinearVariant.init. A pure
        # `is None` check could never succeed usefully, because the item assignment below would fail on None.
        if not module.lora_magnitude_vector:
            # first dora layer being added, add lora_magnitude_vector to the list of learnable parameters
            module.adapter_layer_names = module.adapter_layer_names[:] + ("lora_magnitude_vector",)

        lora_A = module.lora_A[adapter_name].weight
        lora_B = module.lora_B[adapter_name].weight
        scaling = module.scaling[adapter_name]
        dora_layer.update_layer(base_layer=module.get_base_layer(), lora_A=lora_A, lora_B=lora_B, scaling=scaling)
        module.lora_magnitude_vector[adapter_name] = dora_layer

    @staticmethod
    def merge_safe(module: _ConvNd, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
        """Return the merged conv weight as a new tensor and cache the weight norm for unmerging."""
        orig_dtype = orig_weight.dtype
        delta_weight = module.get_delta_weight(active_adapter)

        # since delta_weight already includes scaling, set it to 1 here
        weight_norm = (
            module.lora_magnitude_vector[active_adapter].get_weight_norm(orig_weight, delta_weight, scaling=1).detach()
        )
        # We need to cache weight_norm because it has to be based on the original weights. We
        # cannot calculate it on the fly based on the merged weights when unmerging because its a
        # different value
        module._cache_store(f"{active_adapter}-weight_norm", weight_norm)
        dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm
        # The module provides the view shape used to broadcast the factor over the conv weight.
        new_weight = dora_factor.view(*module._get_dora_factor_view()) * (orig_weight + delta_weight)
        new_weight = new_weight.to(orig_dtype)
        return new_weight

    @staticmethod
    def merge_unsafe(module: _ConvNd, active_adapter: str, orig_weight: torch.Tensor) -> None:
        """Merge in place by overwriting ``orig_weight.data``; caches the weight norm for a later unmerge."""
        orig_dtype = orig_weight.dtype
        delta_weight = module.get_delta_weight(active_adapter)
        # since delta_weight already includes scaling, set it to 1 here
        weight_norm = (
            module.lora_magnitude_vector[active_adapter].get_weight_norm(orig_weight, delta_weight, scaling=1).detach()
        )
        # We need to cache weight_norm because it has to be based on the original weights. We
        # cannot calculate it on the fly based on the merged weights when unmerging because its a
        # different value
        module._cache_store(f"{active_adapter}-weight_norm", weight_norm)
        dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm
        new_weight = dora_factor.view(*module._get_dora_factor_view()) * (orig_weight.data + delta_weight)
        new_weight = new_weight.to(orig_dtype)
        orig_weight.data = new_weight

    @staticmethod
    def unmerge(module: _ConvNd, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
        """Return the un-merged conv weight, consuming the weight norm cached by the preceding merge."""
        orig_dtype = orig_weight.dtype
        delta_weight = module.get_delta_weight(active_adapter)
        weight_norm = module._cache_pop(f"{active_adapter}-weight_norm")
        dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm
        new_weight = orig_weight.data / dora_factor.view(*module._get_dora_factor_view()) - delta_weight
        new_weight = new_weight.to(orig_dtype)
        return new_weight

    @staticmethod
    def forward(
        module: _ConvNd,
        active_adapter: str,
        x: torch.Tensor,
        result: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Add the DoRA layer's output for ``active_adapter`` to ``result`` (the base layer output)."""
        lora_A = module.lora_A[active_adapter]
        lora_B = module.lora_B[active_adapter]
        dropout = module.lora_dropout[active_adapter]
        scaling = module.scaling[active_adapter]

        if isinstance(dropout, nn.Identity) or not module.training:
            # No dropout is applied, so the already-computed base output is handed to the DoRA layer.
            base_result = result
        else:
            # Dropout changes x, so no base result is passed along for this input.
            x = dropout(x)
            base_result = None

        result = result + module.lora_magnitude_vector[active_adapter](
            x,
            lora_A=lora_A,
            lora_B=lora_B,
            scaling=scaling,
            base_layer=module.get_base_layer(),
            base_result=base_result,
            adapter_name=active_adapter,
        )
        return result
|
||
|
||
|
||
class DoraConv1dVariant(_DoraConvNdVariant):
    """DoRA variant for 1-D convolutional LoRA layers."""

    @staticmethod
    def init(module: Conv1d, adapter_name: str, config: LoraConfig, **kwargs: Any) -> None:
        """Register a ``DoraConv1dLayer`` for ``adapter_name`` via the shared conv initialiser."""
        _DoraConvNdVariant.init_convd_variant(
            module,
            adapter_name,
            dora_layer=DoraConv1dLayer(fan_in_fan_out=False),
        )
|
||
|
||
|
||
class DoraConv2dVariant(_DoraConvNdVariant):
    """DoRA variant for 2-D convolutional LoRA layers."""

    @staticmethod
    def init(module: Conv2d, adapter_name: str, config: LoraConfig, **kwargs: Any) -> None:
        """Register a ``DoraConv2dLayer`` for ``adapter_name`` via the shared conv initialiser."""
        _DoraConvNdVariant.init_convd_variant(
            module,
            adapter_name,
            dora_layer=DoraConv2dLayer(fan_in_fan_out=False),
        )
|
||
|
||
|
||
class DoraConv3dVariant(_DoraConvNdVariant):
    """DoRA variant for 3-D convolutional LoRA layers."""

    @staticmethod
    def init(module: Conv3d, adapter_name: str, config: LoraConfig, **kwargs: Any) -> None:
        """Register a ``DoraConv3dLayer`` for ``adapter_name`` via the shared conv initialiser."""
        _DoraConvNdVariant.init_convd_variant(
            module,
            adapter_name,
            dora_layer=DoraConv3dLayer(fan_in_fan_out=False),
        )
|
||
|
||
|
||
class QALoraLinearVariant(LoraVariant):
    """QALoRA variant: the LoRA A matrix operates on group-pooled input features. Merging is unsupported."""

    @staticmethod
    def init(module: Linear, adapter_name: str, config: LoraConfig, **kwargs: Any) -> None:
        """
        Initializes QALoRA specific parameters for a given adapter.

        Records the group size on the module and replaces ``module.lora_A[adapter_name]`` with a new bias-free
        ``nn.Linear`` whose input width is ``in_features // qalora_group_size``.

        Args:
            module (Linear): The linear module to be adapted.
            adapter_name (str): The name of the adapter.
            config (LoraConfig): The config of the LoRA adapter.
            **kwargs: Additional keyword arguments.

        Raises:
            ValueError: if ``module.in_features`` is not divisible by ``config.qalora_group_size``.
        """
        qalora_group_size = config.qalora_group_size
        if module.in_features is not None and module.in_features % qalora_group_size != 0:
            raise ValueError(
                f"`use_qalora=True` requires `module.in_features` ({module.in_features}) to be "
                f"divisible by 'qalora_group_size' ({qalora_group_size})"
            )

        if "qalora_group_size" not in module.other_param_names:
            module.other_param_names = module.other_param_names + ("qalora_group_size",)

        if not hasattr(module, "qalora_group_size"):
            module.qalora_group_size = {}
        module.qalora_group_size[adapter_name] = qalora_group_size

        # Rebuild lora_A with the pooled input width, keeping rank, device and dtype of the existing layer.
        old_lora_A_layer = module.lora_A[adapter_name]
        r = old_lora_A_layer.out_features
        device = old_lora_A_layer.weight.device
        dtype = old_lora_A_layer.weight.dtype

        new_lora_A_layer = nn.Linear(
            old_lora_A_layer.in_features // module.qalora_group_size[adapter_name],
            r,
            bias=False,
            device=device,
            dtype=dtype,
        )
        module.lora_A[adapter_name] = new_lora_A_layer

    @staticmethod
    def get_delta_weight(module: Linear, active_adapter: str) -> torch.Tensor:
        """Always raises ``NotImplementedError``."""
        raise NotImplementedError("QALoRA for GPTQ layers does not support 'get_delta_weight'.")

    @staticmethod
    def merge_safe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
        """Always raises ``NotImplementedError``."""
        raise NotImplementedError("QALoRA for GPTQ layers does not support 'safe_merge'.")

    @staticmethod
    def merge_unsafe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> None:
        """Always raises ``NotImplementedError``."""
        raise NotImplementedError("QALoRA for GPTQ layers does not support 'merge_unsafe'.")

    @staticmethod
    def unmerge(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
        """Always raises ``NotImplementedError``."""
        raise NotImplementedError("QALoRA for GPTQ layers does not support 'unmerge'.")

    @staticmethod
    def forward(
        module: Linear,
        active_adapter: str,
        x: torch.Tensor,
        result: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """
        Add the QALoRA delta to ``result``.

        The (optionally dropped-out) input is flattened to 2D, each consecutive group of ``group_size`` features
        is averaged, the pooled values are multiplied by the number of groups, and the LoRA A/B weights and the
        adapter scaling are applied. The delta is reshaped back to the input's leading dimensions.
        """
        lora_A_weight = module.lora_A[active_adapter].weight
        lora_B_weight = module.lora_B[active_adapter].weight
        dropout = module.lora_dropout[active_adapter]
        scaling = module.scaling[active_adapter]
        group_size = module.qalora_group_size[active_adapter]

        x_dropped = dropout(x) if module.training and not isinstance(dropout, nn.Identity) else x
        orig_shape = x_dropped.shape

        # Reshape to 2D
        if len(orig_shape) > 2:
            x_flat = x_dropped.view(-1, module.in_features)
        else:
            x_flat = x_dropped

        batch_size, in_features = x_flat.shape
        pooled_features = in_features // group_size

        # Mean over each group of `group_size` consecutive features, then rescale by the number of groups.
        x_pooled = x_flat.view(batch_size, pooled_features, group_size).mean(dim=2)

        x_pooled_scaled = x_pooled * pooled_features

        # LoRA computation
        delta = x_pooled_scaled @ lora_A_weight.t() @ lora_B_weight.t() * scaling

        # Reshape back
        if len(orig_shape) > 2:
            delta = delta.view(orig_shape[:-1] + (delta.size(-1),))

        return result + delta
|
||
|
||
|
||
class ALoraLinearVariant(LoraVariant):
    """Activated LoRA (aLoRA) variant: the adapter is applied only to a trailing span of tokens per row."""

    @staticmethod
    def init(module: Linear, adapter_name: str, config: LoraConfig, **kwargs: Any) -> None:
        """No variant-specific state is needed; this is a no-op."""
        pass

    @staticmethod
    def merge_safe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
        """Always raises ``NotImplementedError``."""
        raise NotImplementedError("aLoRA does not support safe merging.")

    @staticmethod
    def merge_unsafe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> None:
        """Always raises ``NotImplementedError``."""
        raise NotImplementedError("aLoRA does not support merging.")

    @staticmethod
    def unmerge(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
        """Always raises ``NotImplementedError``."""
        raise NotImplementedError("aLoRA does not support unmerging.")

    @staticmethod
    def forward(
        module: Linear,
        active_adapter: str,
        x: torch.Tensor,
        result: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """
        Add the LoRA update in place to the last ``alora_offsets[i]`` positions of each row of ``result``.

        Args:
            module (Linear): LoRA layer of the model.
            active_adapter (str): name of the adapter to apply.
            x (torch.Tensor): input to the layer; cast to the dtype of the LoRA A weight.
            result (torch.Tensor): output of the base layer; modified in place through a flattened view.
            **kwargs: may contain ``alora_offsets``, one entry per row (``None`` entries activate nothing).

        Returns:
            ``result``, updated on the selected positions.
        """
        alora_offsets = kwargs.get("alora_offsets", None)
        lora_A = module.lora_A[active_adapter]
        lora_B = module.lora_B[active_adapter]
        dropout = module.lora_dropout[active_adapter]
        scaling = module.scaling[active_adapter]
        x = x.to(lora_A.weight.dtype)
        result_shape = result.shape
        B = result_shape[0]  # batch
        if len(result_shape) == 3:
            T = result_shape[1]  # tokens
        else:
            T = 1
        D = result_shape[-1]  # dimensions
        Dx = x.shape[-1]
        device = result.device
        if alora_offsets is None:  # use base model only, but ensure 0 gradient
            # Create the mask on the same device as `result`, as the offsets branch below does.
            mask = torch.zeros((B, T), dtype=torch.bool, device=device)
        else:
            # If alora_offsets[i] is None, this means that the invocation sequence was not found in the
            # input. As a result, the weights should not be activated anywhere (equivalent to base model).
            # Convert None -> 0 and clip to T
            offsets = torch.tensor(
                [0 if o is None else min(int(o), T) for o in alora_offsets],
                device=device,
                dtype=torch.long,
            )
            # Mask True on the last `offsets[i]` positions for each row i
            pos = torch.arange(T, device=device).unsqueeze(0)  # [1, T]
            mask = pos >= (T - offsets).unsqueeze(1)

        # Flatten for vectorization
        x_flat = x.view(-1, Dx)
        res_flat = result.view(-1, D)
        mask_flat = mask.view(-1)

        # Compute adapter on the selected tokens only
        res_flat[mask_flat] += lora_B(lora_A(dropout(x_flat[mask_flat]))) * scaling
        return result
|
||
|
||
|
||
def calculate_alora_offsets(
    peft_config: PeftConfig, active_adapter: str, input_ids: torch.Tensor, adapter_names: Optional[list[str]] = None
) -> list[Optional[int]]:
    """
    This is a helper function for Activated LoRA (aLoRA) that searches each input token sequence for the last
    occurrence of the appropriate "alora_invocation_tokens" invocation sequence. The calculated alora_offset is the
    location of the *start* of the invocation tokens, counting backward from the end (it will therefore always be
    >= len(alora_invocation_tokens)). If adapter_names is passed, then each input uses the appropriate invocation
    sequence for the specified adapter for that row. Logic is provided to handle mixed collections of adapters for
    which not all are aLoRAs (e.g. some base model, some LoRA).

    Args:
        peft_config: mapping from adapter name to its config.
        active_adapter: adapter used for rows that have no entry in ``adapter_names``.
        input_ids: token ids, indexed by row; ``None`` yields an empty list.
        adapter_names: optional per-row adapter names; ``"__base__"`` marks a base-model row.

    Returns:
        One entry per row: the offset, or ``None`` when the row uses the base model, an unknown adapter (a warning
        is issued), an adapter without (or with empty) invocation tokens, or when the sequence is not found.
    """
    if input_ids is None:
        return []

    batch_size = input_ids.shape[0]
    alora_offsets = [None] * batch_size

    cached_invocation_tensors = {}
    adapters_to_process_indices = collections.defaultdict(list)

    # Pass 1: decide per row which adapter applies and group the rows by adapter.
    for i in range(batch_size):
        current_adapter_name = adapter_names[i] if adapter_names and i < len(adapter_names) else active_adapter

        if current_adapter_name == "__base__":
            continue

        if current_adapter_name not in peft_config:
            warnings.warn(f"Adapter '{current_adapter_name}' not found in peft_config. Using base model for row {i}.")
            continue

        invocation_tokens = getattr(peft_config[current_adapter_name], "alora_invocation_tokens", None)
        # Not an aLoRA adapter or wrong type. An empty token list is treated the same way (there is nothing to
        # search for), in line with the truthiness test used by is_alora_relevant_in_batch.
        if invocation_tokens is None or len(invocation_tokens) == 0:
            continue

        if current_adapter_name not in cached_invocation_tensors:
            cached_invocation_tensors[current_adapter_name] = torch.tensor(
                invocation_tokens, dtype=torch.long, device=input_ids.device
            )

        adapters_to_process_indices[current_adapter_name].append(i)

    # Pass 2: for every grouped row, locate the last full occurrence of the invocation sequence.
    for adapter_name_to_process, indices in adapters_to_process_indices.items():
        current_invocation_ids_tensor = cached_invocation_tensors[adapter_name_to_process]
        invocation_len = len(current_invocation_ids_tensor)

        for i in indices:
            sequence = input_ids[i]
            seq_len = len(sequence)

            # Candidate starts are the positions holding the first invocation token, in ascending order.
            possible_starts = (sequence == current_invocation_ids_tensor[0]).nonzero(as_tuple=True)[0].tolist()

            # Scan from the end so that the first full match found is the last occurrence.
            for idx in reversed(possible_starts):
                if idx + invocation_len <= seq_len and torch.equal(
                    sequence[idx : idx + invocation_len], current_invocation_ids_tensor
                ):
                    alora_offsets[i] = seq_len - idx
                    break
    return alora_offsets
|
||
|
||
|
||
def is_alora_relevant_in_batch(model: nn.Module, adapter_names: Optional[list[str]] = None):
    """
    Tell whether the current batch involves at least one aLoRA adapter.

    The active adapter config is checked first; only if it has no (non-empty) ``alora_invocation_tokens`` are the
    per-row ``adapter_names`` looked up in ``model.peft_config``. ``"__base__"`` entries and unknown names are
    skipped.
    """
    if getattr(model.active_peft_config, "alora_invocation_tokens", None):
        return True

    for adapter_name in adapter_names or ():
        if adapter_name == "__base__":
            continue
        adapter_config = model.peft_config.get(adapter_name)
        if adapter_config and getattr(adapter_config, "alora_invocation_tokens", None):
            return True

    return False
|
||
|
||
|
||
def get_alora_offsets_for_forward(
    model: nn.Module, input_ids: torch.Tensor = None, inputs_embeds: torch.Tensor = None, **kwargs
):
    """
    Fill in ``kwargs["alora_offsets"]`` for the model's ``.forward`` when the batch involves aLoRA adapters.

    ``kwargs`` is returned untouched if no aLoRA adapter is relevant or if offsets were already supplied.
    Otherwise the offsets are computed from ``input_ids``; when no ``input_ids`` are available the entry is set
    to ``None`` (with a warning if only ``inputs_embeds`` were given).
    """
    row_adapter_names = kwargs.get("adapter_names", None)

    # Nothing to compute when no aLoRA adapter takes part in this batch.
    if not is_alora_relevant_in_batch(model, row_adapter_names):
        return kwargs

    # Offsets provided by the caller take precedence.
    if kwargs.get("alora_offsets") is not None:
        return kwargs

    if input_ids is not None:
        kwargs["alora_offsets"] = calculate_alora_offsets(
            model.peft_config,
            model.active_adapter,
            input_ids,
            adapter_names=row_adapter_names,
        )
        return kwargs

    if inputs_embeds is not None:
        warnings.warn(
            "Cannot calculate aLoRA offsets when only inputs_embeds are provided. Disabling aLoRA for this forward pass."
        )
    kwargs["alora_offsets"] = None
    return kwargs
|
||
|
||
|
||
def get_alora_offsets_for_generate(model: nn.Module, *args, **kwargs):
    """
    Fill in `kwargs["alora_offsets"]` for a `.generate` call when the batch involves aLoRA adapters.

    `kwargs` is returned untouched when no aLoRA adapter is relevant for the batch or when the caller already passed
    non-`None` `alora_offsets`. Otherwise the input ids are taken from `kwargs["input_ids"]` or, failing that, from
    the first positional argument if it is a tensor. A 1-D tensor of ids is given a leading dimension before the
    offsets are computed with `calculate_alora_offsets`. If no input ids can be found, a warning is emitted and the
    offsets are set to `None`.

    Returns:
        The (possibly updated) `kwargs` dict.
    """
    adapter_names_for_offset_calc = kwargs.get("adapter_names")
    if not is_alora_relevant_in_batch(model, adapter_names_for_offset_calc):
        # Nothing to compute
        return kwargs

    if kwargs.get("alora_offsets") is not None:
        # Offsets supplied by the caller are kept as they are.
        return kwargs

    current_input_ids = kwargs.get("input_ids")
    if current_input_ids is None and args and isinstance(args[0], torch.Tensor):
        # args[0] is usually input_ids
        current_input_ids = args[0]

    if current_input_ids is None:
        warnings.warn(
            "Cannot calculate aLoRA offsets during generate as input_ids are not available. Disabling aLoRA."
        )
        kwargs["alora_offsets"] = None
        return kwargs

    if current_input_ids.ndim == 1:
        current_input_ids = current_input_ids.unsqueeze(0)
    kwargs["alora_offsets"] = calculate_alora_offsets(
        model.peft_config,
        model.active_adapter,
        current_input_ids,
        adapter_names=adapter_names_for_offset_calc,
    )
    return kwargs
|
||
|
||
|
||
class BlockDiagonalLinear(nn.Module):
    """
    Bias-free linear map whose weight matrix is block-diagonal with `nblocks` equally sized blocks.

    Only the blocks are stored: `weight` has shape `(out_features, in_features // nblocks)` and is interpreted as
    `nblocks` stacked blocks of shape `(out_features // nblocks, in_features // nblocks)`.

    Args:
        in_features: Size of the last input dimension; must be divisible by `nblocks`.
        out_features: Size of the last output dimension; must be divisible by `nblocks`.
        nblocks: Number of diagonal blocks.
        init_zero: If `True` the weight is zero-initialized, otherwise `kaiming_uniform_` is used.
        dtype: Data type of the weight.
        device: Device on which the weight is created.

    Raises:
        ValueError: If `in_features` or `out_features` is not divisible by `nblocks`.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        nblocks: int,
        init_zero: bool = False,
        dtype: torch.dtype = torch.float32,
        device: torch.device = torch.device("cpu"),
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.nblocks = nblocks
        if self.in_features % nblocks != 0 or self.out_features % nblocks != 0:
            raise ValueError(
                f"self.in_features={self.in_features} or self.out_features={self.out_features} not divisible by {self.nblocks}"
            )
        # Create weight with specified dtype and device
        self.weight = nn.Parameter(torch.empty(out_features, in_features // nblocks, dtype=dtype, device=device))

        if init_zero:
            torch.nn.init.zeros_(self.weight)
        else:
            torch.nn.init.kaiming_uniform_(self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block-diagonal map to the last dimension of `x`, keeping all leading dimensions."""
        leading_dims = x.shape[:-1]
        if x.dim() != 2:
            x = x.reshape(-1, x.shape[-1])
        batch = x.shape[0]
        nb = self.nblocks
        block_in = x.shape[-1] // nb
        block_out = self.out_features // nb
        # Split the features into one slice per block and multiply each slice with its own block.
        x_blocks = x.reshape(batch, nb, block_in)
        w_blocks = self.weight.view(nb, block_out, block_in)
        out = torch.einsum("bim,inm->bin", x_blocks, w_blocks)
        # Give the output width explicitly: inferring it with -1 is ambiguous when a leading
        # dimension is 0 (empty batch) and makes reshape raise.
        return out.reshape(*leading_dims, self.out_features)

    def weight_as_blockdiagonal_matrix(self):
        """Returns weight in a format similar to a vanilla LoRA adapter. For this, we stack the blocks on the diagonal,
        leaving the off-diagonals padded with zero."""
        return torch.block_diag(*torch.chunk(self.weight, self.nblocks, dim=0))
|
||
|
||
|
||
class BdLoraLinearVariant(LoraVariant):
    """
    LoRA variant for linear layers in which either the A or the B adapter is replaced by a `BlockDiagonalLinear`.

    Which side (if any) becomes block-diagonal is decided per target module from the patterns in
    `config.use_bdlora`.
    """

    @staticmethod
    def init(module: Linear, adapter_name: str, config: LoraConfig, **kwargs) -> None:
        """
        Replace `lora_A` or `lora_B` of `adapter_name` with a block-diagonal layer when the target name matches.

        The target name is read from `kwargs["target_name"]` (default `""`). A pattern matches when it is contained
        in the target name (`pattern in target_name`). The replacement layer takes its dtype and device from the
        base layer weight; a block-diagonal A is randomly initialized while a block-diagonal B starts at zero.

        Raises:
            ValueError: If the target matches both the A and the B patterns, or if it matches neither while
                `match_strict` is set.
        """
        use_bdlora = config.use_bdlora
        target_name = kwargs.get("target_name", "")

        # Handle case where use_bdlora is a dict (from saved config) instead of BdLoraConfig object
        if isinstance(use_bdlora, dict):
            use_bdlora = BdLoraConfig(**use_bdlora)

        lora_a_blockdiagonal_pattern = use_bdlora.target_modules_bd_a or []
        lora_b_blockdiagonal_pattern = use_bdlora.target_modules_bd_b or []
        nblocks = use_bdlora.nblocks

        has_lora_a_blockdiagonal = any(pattern in target_name for pattern in lora_a_blockdiagonal_pattern)
        has_lora_b_blockdiagonal = any(pattern in target_name for pattern in lora_b_blockdiagonal_pattern)

        if has_lora_a_blockdiagonal and has_lora_b_blockdiagonal:
            raise ValueError(f"Target {target_name} matches both A and B block-diagonal patterns")
        if use_bdlora.match_strict and not (has_lora_a_blockdiagonal or has_lora_b_blockdiagonal):
            raise ValueError(
                f"Target {target_name} matches neither A nor B block-diagonal patterns. "
                "If this is intentional, set match_strict=False in BdLoraConfig during initialization."
            )

        if has_lora_a_blockdiagonal:
            # Keep the rank of the adapter that is being replaced.
            r = module.lora_A[adapter_name].out_features
            base_layer = module.get_base_layer()
            module.lora_A[adapter_name] = BlockDiagonalLinear(
                base_layer.in_features,
                r,
                nblocks=nblocks,
                init_zero=False,
                dtype=base_layer.weight.dtype,
                device=base_layer.weight.device,
            )
        elif has_lora_b_blockdiagonal:
            r = module.lora_B[adapter_name].in_features
            base_layer = module.get_base_layer()
            module.lora_B[adapter_name] = BlockDiagonalLinear(
                r,
                base_layer.out_features,
                nblocks=nblocks,
                init_zero=True,
                dtype=base_layer.weight.dtype,
                device=base_layer.weight.device,
            )

    @staticmethod
    def forward(module: Linear, active_adapter: str, x: torch.Tensor, result: torch.Tensor, **kwargs) -> torch.Tensor:
        """Add the scaled adapter output `lora_B(lora_A(dropout(x)))` to `result` in place and return `result`."""
        lora_A = module.lora_A[active_adapter]
        lora_B = module.lora_B[active_adapter]
        dropout = module.lora_dropout[active_adapter]
        scaling = module.scaling[active_adapter]
        x = dropout(x)
        # Cast input dtype to match lora_A weight dtype
        x = module._cast_input_dtype(x, lora_A.weight.dtype)
        result += lora_B(lora_A(x)) * scaling
        return result

    @staticmethod
    def _get_weight_from_module_maybe_blockdiagonal(module: nn.Module) -> torch.Tensor:
        """Return the module's weight, expanded to a full block-diagonal matrix for `BlockDiagonalLinear`."""
        if isinstance(module, BlockDiagonalLinear):
            return module.weight_as_blockdiagonal_matrix()
        return module.weight  # type: ignore

    @staticmethod
    def _get_bdlora_delta_weight(module: Linear, adapter: str) -> torch.Tensor:
        """Similar to get_delta_weight for a linear module, but we have to eventually reshape the blocks
        of the weights.

        The result is `B @ A` (transposed according to `module.fan_in_fan_out`) times the adapter scaling, returned
        in the dtype of the base layer weight.
        """
        device = module.lora_B[adapter].weight.device
        # Use base layer dtype to ensure compatibility with merge/unmerge operations
        base_layer = module.get_base_layer()
        dtype = base_layer.weight.dtype
        # On CPU with a float16/bfloat16 base layer, the product is computed in float32.
        cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)

        weight_A = BdLoraLinearVariant._get_weight_from_module_maybe_blockdiagonal(module.lora_A[adapter])
        weight_B = BdLoraLinearVariant._get_weight_from_module_maybe_blockdiagonal(module.lora_B[adapter])

        if cast_to_fp32:
            weight_A = weight_A.float()
            weight_B = weight_B.float()

        output_tensor = transpose(weight_B @ weight_A, module.fan_in_fan_out) * module.scaling[adapter]

        # Ensure output tensor matches base layer dtype (this also undoes the float32 upcast above)
        return output_tensor.to(dtype=dtype)

    @staticmethod
    def merge_safe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
        """Return a new tensor equal to `orig_weight` plus the adapter's delta weight."""
        return orig_weight + BdLoraLinearVariant._get_bdlora_delta_weight(module, active_adapter)

    @staticmethod
    def merge_unsafe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> None:
        """Add the adapter's delta weight to `orig_weight.data` in place."""
        orig_weight.data += BdLoraLinearVariant._get_bdlora_delta_weight(module, active_adapter)

    @staticmethod
    def unmerge(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
        """Return a new tensor equal to `orig_weight` minus the adapter's delta weight."""
        return orig_weight - BdLoraLinearVariant._get_bdlora_delta_weight(module, active_adapter)
|