Add reference-impl URLs to variant docstrings + V2 external review

- Fetch canonical reference impls for offline review:
  * peft_{lora,hra,delora,ia3}_layer.py + peft_lora_{dora,variants}.py
  * orig_pissa_init.py (MuLabPKU/PiSSA)
  * orig_hra_layer.py (DaShenZi721/HRA)
  * orig_delora.py (ExplainableML/DeLoRA author fork)
- Add reference-impl URLs to all 6 variant docstrings
- Document HRA gate=0 dead-grad issue and DoRA detach-omission in their docstrings
- Re-run external review (codex) with refs available -> docs/audit/variants_review_v2.md
  Major NEW findings vs paper-only review:
    * DeLoRA: scalar W.norm() should be per-input-channel norm(dim=0)
    * HRA: PEFT uses symmetric repeated-column init (no dead grad), not zero gate
    * IA3: FFN targets need input-side gating, not output, our up_proj advice wrong
    * All LoRA-family: cfg.dropout silently ignored (no-op)
    * DeLoRA: wnorm should be persistent buffer, not Parameter
  HRA and DeLoRA upgraded to BUGGY (from Partial)
This commit is contained in:
wassname
2026-04-26 19:27:47 +08:00
parent d0b4c52740
commit fdb4c77d6c
17 changed files with 7137 additions and 1 deletions
+74
View File
@@ -0,0 +1,74 @@
# Per-variant paper-faithfulness audit V2 (with reference implementations)
Re-audit of `lora-lite` after adding canonical reference implementation URLs to
each variant docstring. Your job: for each variant, **directly compare** our
implementation against the reference impl (peft and/or paper-author repo), not
just against the paper text. This is round 2 — the previous review (you can
read `docs/audit/variants_review.md`) found:
- HRA gate=0 init kills `lora_U` gradient on step 0
- DeLoRA same pattern with lambda0=0
- IA3 targets q/v not paper k/v/ffn-down (deviation documented but untested)
- PiSSA bf16 init err 0.31 on Qwen
- Saved adapters don't preserve PiSSA W_res mutation
Your job now is to verify those findings against the **reference code**, and
look for anything the prior review missed once you have the reference in hand.
## Inputs
- Our code: `src/lora_lite/variants/{lora,pissa,dora,ia3,hra,delora}.py`
- Adapter plumbing: `src/lora_lite/{adapter.py,target.py,variant.py,config.py}`
- Papers (text): `docs/papers/*_*.txt`
- **Reference implementations** (just added):
- `docs/refs/peft_lora_layer.py` — peft LoRA Linear (and PiSSA init paths)
- `docs/refs/peft_lora_dora.py` — peft DoRA helper module
- `docs/refs/peft_lora_variants.py` — peft per-variant init dispatch (PiSSA, OLoRA, etc.)
- `docs/refs/peft_ia3_layer.py` — peft IA3 layer
- `docs/refs/peft_hra_layer.py` — peft HRA layer (clean, has apply_GS toggle)
- `docs/refs/peft_delora_layer.py` — peft DeLoRA layer (upstreamed)
- `docs/refs/orig_pissa_init.py` — PiSSA paper authors' init script (MuLabPKU)
- `docs/refs/orig_hra_layer.py` — HRA paper authors' OFT-with-HRA layer (DaShenZi721)
- `docs/refs/orig_delora.py` — DeLoRA paper authors' fork-of-peft impl (ExplainableML)
- Logs: `logs/smoke.log`, `logs/qwen_probe.log`
- Prior review: `docs/audit/variants_review.md` (do NOT just restate it)
## What to deliver per variant (LoRA, PiSSA, DoRA, IA3, HRA, DeLoRA)
1. **Reference impl ground-truth** — what does the *reference* code actually do
for: parameter shapes, initialization, scale factor, forward equation,
save/load, target placement? Quote ≤10 lines with file/line cites from
`docs/refs/`.
2. **Our code** — quote our impl (≤10 lines, with `src/lora_lite/variants/<v>.py:LN` cites).
3. **Diff** — bullet list of every meaningful difference.
Mark each one as: `[OK-doc]` (acceptable, documented), `[OK-undoc]` (acceptable,
should add to docstring), `[BUG]` (likely wrong), `[STYLE]` (cosmetic).
4. **Did the prior review get it right?** Quote the relevant prior verdict
line and either confirm or correct.
5. **Verdict** — Faithful / Faithful-with-doc-gap / Partial / Buggy.
One-line reason.
## Final aggregate
Markdown table:
| variant | prior verdict | new verdict | new bugs found | doc gaps |
And a 5-bullet "what to fix next" list, ordered by severity.
## Hard rules
- Quote evidence from `docs/refs/` files. If you can't find the relevant
reference function, say so explicitly — don't guess.
- Do NOT edit code. Output review only.
- Be specific about line numbers from the references. "peft does X" is not
enough; "peft_lora_layer.py:L1234 does X" is.
- If you find a NEW bug not flagged in `variants_review.md`, mark it
`[NEW-BUG]` and explain the failure mode.
- If the prior review was wrong (false positive), mark it `[OVERTURN]`.
Write to stdout. I will redirect to `docs/audit/variants_review_v2.md`.
File diff suppressed because it is too large Load Diff
+446
View File
@@ -0,0 +1,446 @@
import importlib
import math
import re
import warnings
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import List, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.pytorch_utils import Conv1D
from ..utils import PeftConfig, PeftType, transpose
def is_bnb_available():
return importlib.util.find_spec("bitsandbytes") is not None
if is_bnb_available():
import bitsandbytes as bnb
@dataclass
class DeloraConfig(PeftConfig):
"""
This is the configuration class to store the configuration of a [`~peft.Delora`].
Args:
r (`int`): Delora attention dimension
target_modules (`Union[List[str],str]`): The names of the modules to apply Delora to.
delora_lambda (`float`): The lambda parameter for Delora scaling.
delora_dropout (`float`): The dropout probability for Delora layers.
merge_weights (`bool`):
Whether to merge the weights of the Delora layers with the base transformer model in `eval` mode.
fan_in_fan_out (`bool`): Set this to True if the layer to replace stores weight like (fan_in, fan_out)
enable_delora ( `List[bool]`): Used with `delora.MergedLinear`.
bias (`str`): Bias type for Delora. Can be 'none', 'all' or 'delora_only'
modules_to_save (`List[str]`):List of modules apart from Delora layers to be set as trainable
and saved in the final checkpoint.
"""
r: int = field(default=8, metadata={"help": "Delora attention dimension"})
target_modules: Optional[Union[List[str], str]] = field(
default=None,
metadata={
"help": "List of module names or regex expression of the module names to replace with Delora."
"For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' "
},
)
delora_lambda: int = field(default=None, metadata={"help": "Delora lambda"})
delora_dropout: float = field(default=None, metadata={"help": "Delora dropout"})
Wdecompose_target_modules: Optional[Union[List[str], str]] = field(
default=None,
metadata={
"help": "List of module names or regex expression of the module names to only tune the magnitude part"
"For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' "
},
)
merge_weights: bool = field(
default=False, metadata={"help": "Merge weights of the original model and the Delora model"}
)
fan_in_fan_out: bool = field(
default=False,
metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"},
)
enable_delora: Optional[List[bool]] = field(default=None, metadata={"help": "Used with `delora.MergedLinear`."})
bias: str = field(default="none", metadata={"help": "Bias type for Delora. Can be 'none', 'all' or 'delora_only'"})
modules_to_save: Optional[List[str]] = field(
default=None,
metadata={
"help": "List of modules apart from Delora layers to be set as trainable and saved in the final checkpoint. "
"For example, in Sequence Classification or Token Classification tasks, "
"the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved."
},
)
def __post_init__(self):
self.peft_type = PeftType.DELORA
class DeloraModel(torch.nn.Module):
"""
Creates Decoupled Low Rank Adapter (Delora) model from a pretrained transformers model.
Args:
model ([`transformers.PreTrainedModel`]): The model to be adapted.
config ([`DeloraConfig`]): The configuration of the Delora model.
Returns:
`torch.nn.Module`: The Delora model.
Example::
>>> from transformers import AutoModelForSeq2SeqLM, DeloraConfig >>> from peft import DeloraModel, DeloraConfig >>>
config = DeloraConfig(
peft_type="DELORA", task_type="SEQ_2_SEQ_LM", r=8, delora_lambda=32, target_modules=["q", "v"],
delora_dropout=0.01, )
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") >>> delora_model = DeloraModel(config, model)
**Attributes**:
- **model** ([`transformers.PreTrainedModel`]) -- The model to be adapted.
- **peft_config** ([`DeloraConfig`]): The configuration of the Delora model.
"""
def __init__(self, config, model):
super().__init__()
self.peft_config = config
print(self.peft_config)
self.model = model
self._find_and_replace()
mark_only_delora_as_trainable(self.model, self.peft_config.bias)
self.forward = self.model.forward
def _find_and_replace(self):
loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False)
if loaded_in_8bit and not is_bnb_available():
raise ImportError(
"To use Delora with 8-bit quantization, please install the `bitsandbytes` package. "
"You can install it with `pip install bitsandbytes`."
)
is_target_modules_in_base_model = False
is_hf_device_map_available = hasattr(self.model, "hf_device_map")
kwargs = {
"r": self.peft_config.r,
"delora_lambda": self.peft_config.delora_lambda,
"delora_dropout": self.peft_config.delora_dropout,
"fan_in_fan_out": self.peft_config.fan_in_fan_out,
"merge_weights": (self.peft_config.merge_weights or self.peft_config.inference_mode)
and not is_hf_device_map_available,
}
key_list = [key for key, _ in self.model.named_modules()]
for key in key_list:
if isinstance(self.peft_config.target_modules, str):
target_module_found = re.fullmatch(self.peft_config.target_modules, key)
else:
target_module_found = any(key.endswith(target_key) for target_key in self.peft_config.target_modules)
if target_module_found:
if not is_target_modules_in_base_model:
is_target_modules_in_base_model = True
parent, target, target_name = self._get_submodules(key)
bias = target.bias is not None
if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
kwargs.update(
{
"has_fp16_weights": target.state.has_fp16_weights,
"memory_efficient_backward": target.state.memory_efficient_backward,
"threshold": target.state.threshold,
"index": target.index,
}
)
if self.peft_config.enable_delora is None:
print("8 bit delora")
new_module = Linear8bitLt(target.in_features, target.out_features, bias=bias, **kwargs)
else:
kwargs.update({"enable_delora": self.peft_config.enable_delora})
new_module = MergedLinear8bitLt(target.in_features, target.out_features, bias=bias, **kwargs)
elif isinstance(target, torch.nn.Linear) and self.peft_config.enable_delora is None:
new_module = Linear(target.in_features, target.out_features, bias=bias, **kwargs)
elif self.peft_config.enable_delora is not None:
kwargs.update({"enable_delora": self.peft_config.enable_delora})
if isinstance(target, Conv1D):
in_features, out_features = (
target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape
)
else:
in_features, out_features = target.in_features, target.out_features
if kwargs["fan_in_fan_out"]:
warnings.warn(
"fan_in_fan_out is set to True but the target module is not a Conv1D. "
"Setting fan_in_fan_out to False."
)
kwargs["fan_in_fan_out"] = self.peft_config.fan_in_fan_out = False
new_module = MergedLinear(in_features, out_features, bias=bias, **kwargs)
self._replace_module(parent, target_name, new_module, target)
if not is_target_modules_in_base_model:
raise ValueError(
f"Target modules {self.peft_config.target_modules} not found in the base model. "
f"Please check the target modules and try again."
)
def _get_submodules(self, key):
parent = self.model.get_submodule(".".join(key.split(".")[:-1]))
target_name = key.split(".")[-1]
target = self.model.get_submodule(key)
return parent, target, target_name
def _replace_module(self, parent_module, child_name, new_module, old_module):
setattr(parent_module, child_name, new_module)
new_module.weight = old_module.weight
if old_module.bias is not None:
new_module.bias = old_module.bias
if getattr(old_module, "state", None) is not None:
new_module.state = old_module.state
new_module.to(old_module.weight.device)
# dispatch to correct device
for name, module in new_module.named_modules():
if "delora_" in name:
module.to(old_module.weight.device)
def __getattr__(self, name: str):
"""Forward missing attributes to the wrapped module."""
try:
return super().__getattr__(name) # defer to nn.Module's logic
except AttributeError:
return getattr(self.model, name)
@property
def modules_to_save(self):
return None
def get_peft_config_as_dict(self, inference: bool = False):
config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(self.peft_config).items()}
if inference:
config["inference_mode"] = True
return config
def _set_adapter_layers(self, enabled=True):
for module in self.model.modules():
if isinstance(module, DeloraLayer):
module.disable_adapters = False if enabled else True
def enable_adapter_layers(self):
self._set_adapter_layers(enabled=True)
def disable_adapter_layers(self):
self._set_adapter_layers(enabled=False)
# Below code is based on https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
# and modified to work with PyTorch FSDP
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
# had to adapt it for `delora_only` to work
def mark_only_delora_as_trainable(model: nn.Module, bias: str = "none") -> None:
for n, p in model.named_parameters():
if "delora_" not in n:
p.requires_grad = False
if bias == "none":
return
elif bias == "all":
for n, p in model.named_parameters():
if "bias" in n:
p.requires_grad = True
elif bias == "delora_only":
for m in model.modules():
if isinstance(m, DeloraLayer) and hasattr(m, "bias") and m.bias is not None:
m.bias.requires_grad = True
else:
raise NotImplementedError
class DeloraLayer:
def __init__(
self,
r: int,
delora_lambda_value: int,
delora_dropout: float,
merge_weights: bool,
):
self.r = r
self.delora_lambda_value = delora_lambda_value
# Optional dropout
if delora_dropout > 0.0:
self.delora_dropout = nn.Dropout(p=delora_dropout)
else:
self.delora_dropout = lambda x: x
# Mark the weight as unmerged
self.merged = False
self.merge_weights = merge_weights
self.disable_adapters = False
class Linear(nn.Linear, DeloraLayer):
# Delora implemented in a dense layer
def __init__(
self,
in_features: int,
out_features: int,
r: int = 0,
delora_lambda: float = 1.,
delora_dropout: float = 0.0,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
merge_weights: bool = False,
**kwargs,
):
nn.Linear.__init__(self, in_features, out_features, **kwargs)
DeloraLayer.__init__(self, r=r, delora_lambda_value=delora_lambda, delora_dropout=delora_dropout, merge_weights=merge_weights)
self.fan_in_fan_out = fan_in_fan_out
if r > 0:
self.delora_A = nn.Linear(in_features, r, bias=False)
self.delora_B = nn.Linear(r, out_features, bias=False)
self.delora_lambda = nn.Parameter(torch.full((1,), delora_lambda), requires_grad=True)
# Frozen parameters
self.frozen_C = nn.Parameter(torch.empty_like(self.delora_A.weight).copy_(self.delora_A.weight))
self.frozen_C.requires_grad = False
self.frozen_D = nn.Parameter(torch.empty_like(self.delora_B.weight).copy_(self.delora_B.weight))
self.frozen_D.requires_grad = False
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.reset_parameters()
if fan_in_fan_out:
self.weight.data = self.weight.data.T
def reset_parameters(self):
nn.Linear.reset_parameters(self)
if hasattr(self, "delora_A"):
# initialize A the same way as the default for nn.Linear and B to zero
nn.init.kaiming_uniform_(self.delora_A.weight, a=math.sqrt(5))
nn.init.kaiming_uniform_(self.delora_B.weight, a=math.sqrt(5))
nn.init.constant_(self.delora_lambda, self.delora_lambda_value)
self.frozen_C.data = self.delora_A.weight.data
self.frozen_D.data = self.delora_B.weight.data
def get_ABCD(self):
# Get weights
delora_A_weight = self.delora_A.weight # shape: (r, in_features)
delora_B_weight = self.delora_B.weight # shape: (out_features, r)
# Get norms
delora_A_norm = delora_A_weight.norm(dim=1) # shape: (r,)
delora_B_norm = delora_B_weight.norm(dim=0) # shape: (r,)
frozen_C_norm = self.frozen_C.norm(dim=1) # shape: (r,)
frozen_D_norm = self.frozen_D.norm(dim=0) # shape: (r,)
# AB normalization
diag12 = torch.div(self.delora_lambda / self.r, torch.mul(delora_A_norm, delora_B_norm))
diag12 = torch.diag_embed(diag12)
diag34 = torch.div(self.delora_lambda / self.r, torch.mul(frozen_C_norm, frozen_D_norm))
diag34 = torch.diag_embed(diag34)
# Get ABCD
ABCD = delora_B_weight @ diag12 @ delora_A_weight
ABCD = ABCD - self.frozen_D @ diag34 @ self.frozen_C
# W scaling
Wnorm = self.weight.data.norm(dim=0) # shape: (in_features,)
ABCD = torch.mul(ABCD, Wnorm.unsqueeze(0)) # shape: (out_features, in_features)
return ABCD
def train(self, mode: bool = True):
nn.Linear.train(self, mode)
self.delora_A.train(mode)
self.delora_B.train(mode)
self.delora_lambda.requires_grad = mode
if not mode and self.merge_weights and not self.merged:
# Merge the weights and mark it
if self.r > 0:
self.weight.data += self.get_ABCD().to(self.weight.device, dtype=self.weight.dtype)
self.merged = True
elif self.merge_weights and self.merged:
# Make sure that the weights are not merged
if self.r > 0:
self.weight.data -= self.get_ABCD()
self.merged = False
def eval(self):
nn.Linear.eval(self)
self.delora_A.eval()
self.delora_B.eval()
def forward(self, x: torch.Tensor):
previous_dtype = self.weight.dtype
if self.disable_adapters:
if self.r > 0 and self.merged:
self.weight.data -= self.get_ABCD()
self.merged = False
result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
elif self.r > 0 and not self.merged:
result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
if self.r > 0:
result += F.linear(self.delora_dropout(x), self.get_ABCD(), bias=None)
else:
result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
if result.dtype != previous_dtype:
result = result.to(previous_dtype)
return result
class MergedLinear(nn.Linear, DeloraLayer):
# Delora implemented in a dense layer
def __init__(
self,
in_features: int,
out_features: int,
r: int = 0,
delora_lambda: int = 1,
delora_dropout: float = 0.0,
enable_delora: List[bool] = [False],
fan_in_fan_out: bool = False,
merge_weights: bool = True,
**kwargs,
):
raise NotImplementedError
if is_bnb_available():
class Linear8bitLt(bnb.nn.Linear8bitLt, DeloraLayer):
# Delora implemented in a dense layer
def __init__(
self,
in_features,
out_features,
r: int = 0,
delora_lambda: int = 1,
delora_dropout: float = 0.0,
Wdecompose: bool = False,
**kwargs,
):
raise NotImplementedError
class MergedLinear8bitLt(bnb.nn.Linear8bitLt, DeloraLayer):
# Delora implemented in a dense layer
def __init__(
self,
in_features: int,
out_features: int,
r: int = 0,
delora_lambda: int = 1,
delora_dropout: float = 0.0,
enable_delora: List[bool] = [False],
**kwargs,
):
raise NotImplementedError
+420
View File
@@ -0,0 +1,420 @@
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
from typing import Any, List, Optional, Set, Tuple
import torch
import torch.nn as nn
from peft.tuners.lycoris_utils import LycorisLayer, check_adapters_to_merge
class OFTLayer(nn.Module, LycorisLayer):
# All names of layers that may contain adapter weights
adapter_layer_names = ("oft_r",)
# other_param_names is defined on parent class
def __init__(self, base_layer: nn.Module):
super().__init__()
LycorisLayer.__init__(self, base_layer)
# OFT info
self.oft_r = nn.ParameterDict({})
self.coft = {}
self.eps = {}
self.block_share = {}
@property
def _available_adapters(self) -> Set[str]:
return {*self.oft_r}
def create_adapter_parameters(self, adapter_name: str, r: int, shape: Tuple[int, ...], block_share: bool):
# if block_share:
# self.oft_r[adapter_name] = nn.Parameter(torch.empty(1, math.ceil(shape[0] / r), math.ceil(shape[0] / r)))
# else:
# self.oft_r[adapter_name] = nn.Parameter(torch.empty(r, math.ceil(shape[0] / r), math.ceil(shape[0] / r)))
weight = getattr(self.get_base_layer(), "weight", None)
# self.oft_r[adapter_name] = nn.Parameter(torch.cat([weight.new_ones(r, r), weight.new_zeros(shape[0]-r, r)], dim=0))
self.oft_r[adapter_name] = nn.Parameter(
torch.cat([torch.eye(r, device=weight.device, dtype=weight.dtype),
torch.zeros(shape[0] - r, r, device=weight.device, dtype=weight.dtype)], dim=0))
def reset_adapter_parameters(self, adapter_name: str):
# nn.init.zeros_(self.oft_r[adapter_name])
nn.init.kaiming_uniform_(self.oft_r[adapter_name], a=1 / self.eps[adapter_name])
def reset_adapter_parameters_random(self, adapter_name: str):
# nn.init.kaiming_uniform_(self.oft_r[adapter_name], a=math.sqrt(5))
nn.init.kaiming_uniform_(self.oft_r[adapter_name], a=1 / self.eps[adapter_name])
def update_layer(
self,
adapter_name: str,
r: int,
module_dropout: float,
init_weights: bool,
coft: bool = False,
eps: float = 6e-5,
block_share: bool = False,
**kwargs,
) -> None:
"""Internal function to create oft adapter
Args:
adapter_name (`str`): Name for the adapter to add.
r (`int`): Rank for the added adapter.
module_dropout (`float`): The dropout probability for disabling adapter during training.
init_weights (`bool`): Whether to initialize weights.
coft (`bool`): Whether to use the constrained variant of OFT or not.
eps (`float`):
The control strength of COFT. The freedom of rotation. Only has an effect if `coft` is set to True.
block_share (`bool`): Whether to share the OFT parameters between blocks or not.
"""
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
self.r[adapter_name] = r
self.module_dropout[adapter_name] = module_dropout
self.coft[adapter_name] = coft
self.block_share[adapter_name] = block_share
# Determine shape of OFT weights
base_layer = self.get_base_layer()
if isinstance(base_layer, nn.Linear):
shape = tuple(base_layer.weight.shape)
elif isinstance(base_layer, nn.Conv2d):
shape = (
base_layer.out_channels,
base_layer.in_channels * base_layer.kernel_size[0] * base_layer.kernel_size[1],
)
else:
raise TypeError(f"OFT is not implemented for base layers of type {type(base_layer).__name__}")
# self.eps[adapter_name] = eps * math.ceil(shape[0] / r) * math.ceil(shape[0] / r)
self.eps[adapter_name] = eps
# Create weights with provided shape
self.create_adapter_parameters(adapter_name, r, shape, block_share)
# Initialize weights
# if init_weights:
# self.reset_adapter_parameters(adapter_name)
# else:
# self.reset_adapter_parameters_random(adapter_name)
# Move new weights to device
weight = getattr(self.get_base_layer(), "weight", None)
if weight is not None:
# the layer is already completely initialized, this is an update
if weight.dtype.is_floating_point or weight.dtype.is_complex:
self.to(weight.device, dtype=weight.dtype)
else:
self.to(weight.device)
self.set_adapter(self.active_adapters)
def unscale_layer(self, scale=None) -> None:
# scale is not used
pass
def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
"""
Merge the active adapter weights into the base weights
Args:
safe_merge (`bool`, *optional*):
If `True`, the merge operation will be performed in a copy of the original weights and check for NaNs
before merging the weights. This is useful if you want to check if the merge operation will produce
NaNs. Defaults to `False`.
adapter_names (`List[str]`, *optional*):
The list of adapter names that should be merged. If `None`, all active adapters will be merged.
Defaults to `None`.
"""
adapter_names = check_adapters_to_merge(self, adapter_names)
if not adapter_names:
# no adapter to merge
return
for active_adapter in adapter_names:
if active_adapter in self._available_adapters:
base_layer = self.get_base_layer()
orig_weights = base_layer.weight.data
if isinstance(base_layer, nn.Linear):
orig_weights = torch.transpose(orig_weights, 0, 1)
elif isinstance(base_layer, nn.Conv2d):
orig_weights = orig_weights.view(
[
base_layer.out_channels,
base_layer.in_channels * base_layer.kernel_size[0] * base_layer.kernel_size[1],
]
)
orig_weights = torch.transpose(orig_weights, 0, 1)
delta_weight = self.get_delta_weight(active_adapter)
if orig_weights.shape[1] != delta_weight.shape[1]:
# when in channels is not divisible by r
delta_weight = delta_weight[: orig_weights.shape[1], : orig_weights.shape[1]]
# delta_weight=delta_weight.to(orig_weights.device, dtype=orig_weights.dtype)
new_weights = torch.mm(orig_weights, delta_weight)
if isinstance(base_layer, nn.Linear):
new_weights = torch.transpose(new_weights, 0, 1)
elif isinstance(base_layer, nn.Conv2d):
new_weights = torch.transpose(new_weights, 0, 1)
new_weights = new_weights.view(
[
base_layer.out_channels,
base_layer.in_channels,
base_layer.kernel_size[0],
base_layer.kernel_size[1],
]
)
if safe_merge and not torch.isfinite(new_weights).all():
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)
base_layer.weight.data = new_weights
self.merged_adapters.append(active_adapter)
def unmerge(self) -> None:
"""
This method unmerges all merged adapter layers from the base weights.
"""
if not self.merged:
warnings.warn("Already unmerged. Nothing to do.")
return
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
if active_adapter in self._available_adapters:
base_layer = self.get_base_layer()
new_weights = base_layer.weight.data
if isinstance(base_layer, nn.Linear):
new_weights = torch.transpose(new_weights, 0, 1)
elif isinstance(base_layer, nn.Conv2d):
new_weights = new_weights.view(
[
base_layer.out_channels,
base_layer.in_channels * base_layer.kernel_size[0] * base_layer.kernel_size[1],
]
)
new_weights = torch.transpose(new_weights, 0, 1)
delta_weight = self.get_delta_weight(active_adapter)
if new_weights.shape[1] != delta_weight.shape[1]:
# when in channels is not divisible by r
delta_weight = delta_weight[: new_weights.shape[1], : new_weights.shape[1]]
delta_inv = torch.inverse(delta_weight)
orig_weights = torch.mm(new_weights, delta_inv)
if isinstance(base_layer, nn.Linear):
orig_weights = torch.transpose(orig_weights, 0, 1)
elif isinstance(base_layer, nn.Conv2d):
orig_weights = torch.transpose(orig_weights, 0, 1)
orig_weights = orig_weights.reshape(
[
base_layer.out_channels,
base_layer.in_channels,
base_layer.kernel_size[0],
base_layer.kernel_size[1],
]
)
base_layer.weight.data = orig_weights
def get_delta_weight(self, adapter_name: str) -> torch.Tensor:
# rank = self.r[adapter_name]
# coft = self.coft[adapter_name]
# eps = self.eps[adapter_name]
# opt_r = self.oft_r[adapter_name]
# if coft:
# with torch.no_grad():
# opt_r.copy_(self._project_batch(opt_r, eps=eps))
#
# orth_rotate = self._cayley_batch(opt_r)
# weight = self._block_diagonal(orth_rotate, rank)
rank = self.r[adapter_name]
hrft_v = self.oft_r[adapter_name]
in_features = self.oft_r[adapter_name].size(0)
device = self.oft_r[adapter_name].device
dtype = self.oft_r[adapter_name].dtype
# unit_v_list = [hrft_v[:, i].view(-1,1) / (torch.sqrt(torch.sum(hrft_v[:,i] ** 2) + self.eps[adapter_name])) for i in range(8)]
# weight = torch.eye(in_features, device=device, dtype=dtype)
# for unit_v in unit_v_list:
# weight = torch.mm(weight, torch.eye(in_features, device=device, dtype=dtype) - 2 * unit_v @ unit_v.t())
U_list = []
U_list.append((hrft_v[:, 0] / hrft_v[:, 0].norm()).view(-1, 1))
for i in range(1, rank):
Ui = hrft_v[:, i].view(-1, 1)
for j in range(i):
Ui = Ui - (U_list[j].t() @ Ui) * U_list[j]
U_list.append((Ui / Ui.norm()).view(-1, 1))
U_list = torch.cat(U_list, dim=1)
weight = torch.eye(in_features, device=device, dtype=dtype) - 2 * U_list @ U_list.t()
# weight = torch.eye(in_features, device=device) - 2 * U_list @ U_list.t()
return weight
# Copied from https://github.com/Zeju1997/oft/blob/84cebb965df69781e3d9c3c875f5980b421eaf24/oft-control/oft.py#L144
def _cayley_batch(self, data: torch.Tensor) -> torch.Tensor:
b, r, c = data.shape
# Ensure the input matrix is skew-symmetric
skew = 0.5 * (data - data.transpose(1, 2))
I = torch.eye(r, device=data.device).unsqueeze(0).expand(b, r, c) # noqa: E741
# Perform the Cayley parametrization
Q = torch.bmm(I - skew, torch.inverse(I + skew))
return Q
# Copied from https://github.com/Zeju1997/oft/blob/84cebb965df69781e3d9c3c875f5980b421eaf24/oft-control/oft.py#L155
def _block_diagonal(self, oft_r: torch.Tensor, rank: int) -> torch.Tensor:
if oft_r.shape[0] == 1:
# block share
blocks = [oft_r[0, ...] for i in range(rank)]
else:
blocks = [oft_r[i, ...] for i in range(rank)]
# Use torch.block_diag to create the block diagonal matrix
A = torch.block_diag(*blocks)
return A
# Copied from https://github.com/Zeju1997/oft/blob/84cebb965df69781e3d9c3c875f5980b421eaf24/oft-control/oft.py#L52
def _project_batch(self, oft_r, eps=1e-5):
# scaling factor for each of the smaller block matrix
eps = eps * 1 / torch.sqrt(torch.tensor(oft_r.shape[0]))
I = ( # noqa: E741
torch.zeros((oft_r.size(1), oft_r.size(1)), device=oft_r.device, dtype=oft_r.dtype)
.unsqueeze(0)
.expand_as(oft_r)
)
diff = oft_r - I
norm_diff = torch.norm(oft_r - I, dim=(1, 2), keepdim=True)
mask = (norm_diff <= eps).bool()
out = torch.where(mask, oft_r, I + eps * (diff / norm_diff))
return out
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
previous_dtype = x.dtype
if self.disable_adapters:
if self.merged:
self.unmerge()
result = self.base_layer(x, *args, **kwargs)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
result = self.base_layer(x, *args, **kwargs)
if len(result.shape) == 4:
result = result.permute(0, 2, 3, 1)
base_layer = self.get_base_layer()
base_bias = base_layer.bias
if base_bias is not None:
# Bias should be added after OFT forward
result = result - base_bias.data
# Execute all the adapters
for active_adapter in self.active_adapters:
if active_adapter not in self._available_adapters:
continue
module_dropout = self.module_dropout[active_adapter]
# Modify current execution weights
if (not self.training) or (self.training and torch.rand(1) > module_dropout):
result = self._get_delta_activations(active_adapter, result, *args, **kwargs)
if base_bias is not None:
result = result + base_bias.data
if len(result.shape) == 4:
result = result.permute(0, 3, 1, 2)
result = result.to(previous_dtype)
return result
class Linear(OFTLayer):
"""OFT implemented in Linear layer"""
def __init__(
self,
base_layer: nn.Module,
adapter_name: str = "default",
r: int = 0,
module_dropout: float = 0.0,
init_weights: bool = True,
**kwargs,
):
super().__init__(base_layer)
# Create adapter and set it active
self._active_adapter = adapter_name
self.update_layer(adapter_name, r, module_dropout, init_weights, **kwargs)
def _get_delta_activations(
self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any
) -> torch.Tensor:
delta_weight = self.get_delta_weight(adapter_name)
base_layer = self.get_base_layer()
base_weight = base_layer.weight.data
delta_weight = delta_weight[: base_weight.shape[0], : base_weight.shape[0]]
# don't add bias here, because the bias will be added after OFT forward
return torch.matmul(input, delta_weight)
def __repr__(self) -> str:
rep = super().__repr__()
return "oft." + rep
class Conv2d(OFTLayer):
"""OFT implemented in Conv2d layer"""
def __init__(
self,
base_layer: nn.Module,
adapter_name: str = "default",
r: int = 0,
module_dropout: float = 0.0,
init_weights: bool = True,
**kwargs,
):
super().__init__(base_layer)
# Create adapter and set it active
self._active_adapter = adapter_name
self.update_layer(adapter_name, r, module_dropout, init_weights, **kwargs)
def _get_delta_activations(
self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any
) -> torch.Tensor:
delta_weight = self.get_delta_weight(adapter_name)
base_layer = self.get_base_layer()
base_weight = base_layer.weight.data
delta_weight = delta_weight[: base_weight.shape[0], : base_weight.shape[0]]
# don't add bias here, because the bias will be added after OFT forward
return torch.matmul(input, delta_weight)
def __repr__(self) -> str:
rep = super().__repr__()
return "oft." + rep
+60
View File
@@ -0,0 +1,60 @@
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import os
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM
import argparse
parser = argparse.ArgumentParser(description="Separate the principal singular value and singular vectors from base model")
parser.add_argument("--base_model_path", type=str, required=True, help="The name or path of the base model.")
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument("--bits", type=str, default="bf16", choices=["bf16", "fp16", "fp32"])
parser.add_argument("--init_weights", type=str, default="pissa", help="(`['pissa', 'pissa_niter_[number of iters]']`)")
parser.add_argument("--lora_r", type=int, default=128)
parser.add_argument("--lora_alpha", type=int, default=128)
parser.add_argument("--lora_dropout", type=float, default=0)
parser.add_argument('--target_modules', nargs='+', help='', required=True)
script_args = parser.parse_args()
print(script_args)
model = AutoModelForCausalLM.from_pretrained(
script_args.base_model_path,
torch_dtype=(
torch.float16
if script_args.bits == "fp16"
else (torch.bfloat16 if script_args.bits == "bf16" else torch.float32)
),
device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_path)
tokenizer.pad_token_id = tokenizer.eos_token_id
lora_config = LoraConfig(
r=script_args.lora_r,
lora_alpha=script_args.lora_alpha,
init_lora_weights=True if script_args.init_weights=="True" else script_args.init_weights,
lora_dropout=script_args.lora_dropout,
target_modules=script_args.target_modules,
)
peft_model = get_peft_model(model, lora_config)
# Save PiSSA modules:
peft_model.peft_config["default"].init_lora_weights = True
peft_model.save_pretrained(os.path.join(script_args.output_dir, "pissa_init"))
# Save residual model:
peft_model = peft_model.unload()
peft_model.save_pretrained(script_args.output_dir)
# Save the tokenizer:
tokenizer.save_pretrained(script_args.output_dir)
+274
View File
@@ -0,0 +1,274 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
import warnings
from typing import Any, Optional
import torch
import torch.nn as nn
from peft.tuners._buffer_dict import BufferDict
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
from .config import DeloraConfig
class DeloraLayer(BaseTunerLayer):
# All names of layers that may contain (trainable) adapter weights
adapter_layer_names = (
"delora_A",
"delora_B",
"delora_lambda",
)
# All names of other parameters that may contain adapter-related parameters
other_param_names = (
"r",
"delora_dropout",
"delora_w_norm",
)
def __init__(self, base_layer: nn.Module, **kwargs) -> None:
self.base_layer = base_layer
self.r = {}
self.delora_dropout = nn.ModuleDict({})
self.delora_A = nn.ParameterDict({})
self.delora_B = nn.ParameterDict({})
self.delora_lambda = nn.ParameterDict({})
# Use persistent buffers so they are included in state_dict and saved.
self.delora_w_norm = BufferDict({}, persistent=True)
# Mark the weight as unmerged
self._disable_adapters = False
self.merged_adapters = []
self.kwargs = kwargs
base_layer_mod = self.get_base_layer()
if isinstance(base_layer_mod, nn.Linear):
self.in_features, self.out_features = base_layer_mod.in_features, base_layer_mod.out_features
else:
raise ValueError(f"Unsupported layer type {type(base_layer_mod)}")
@staticmethod
def _compute_delta(
A: torch.Tensor, B: torch.Tensor, delora_lambda: torch.Tensor, r: int, w_norm: torch.Tensor
) -> torch.Tensor:
"""Compute delta = B @ diag(delora_lambda/r / (||A_i||*||B^j||)) @ A, scaled by provided w_norm (per-input channel)"""
An = torch.clamp(A.norm(dim=1), min=1e-4)
Bn = torch.clamp(B.norm(dim=0), min=1e-4)
diag = torch.diag_embed(delora_lambda / r / (An * Bn))
delta = B @ diag @ A
delta = delta * w_norm.unsqueeze(0)
return delta
def get_delta_weight(self, adapter: str) -> torch.Tensor:
if adapter not in self.delora_A or adapter not in self.delora_B:
raise ValueError(f"Adapter {adapter} not found.")
delta = self._compute_delta(
self.delora_A[adapter],
self.delora_B[adapter],
self.delora_lambda[adapter],
self.r[adapter],
self.delora_w_norm[adapter],
)
return delta
def update_layer(
self,
adapter_name: str,
r: int,
delora_lambda: float,
config: DeloraConfig,
**kwargs: Any,
) -> None:
"""Internal function to create delora adapter
Args:
adapter_name (`str`): Name for the adapter to add.
r (`int`): Rank for the added adapter.
delora_lambda (`float`): Boundary for the adapter's norm.
config (`DeloraConfig`): The adapter configuration for this layer.
"""
module_dropout = config.module_dropout
init_weights = config.init_weights
inference_mode = config.inference_mode
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
self.r[adapter_name] = r
self.delora_A[adapter_name] = nn.Parameter(torch.empty(r, self.in_features))
self.delora_B[adapter_name] = nn.Parameter(torch.empty(self.out_features, r))
self.delora_lambda[adapter_name] = nn.Parameter(torch.empty(1))
if module_dropout > 0.0:
module_dropout_layer = nn.Dropout(p=module_dropout)
else:
module_dropout_layer = nn.Identity()
self.delora_dropout.update(nn.ModuleDict({adapter_name: module_dropout_layer}))
# Initialize weights
self.reset_delora_parameters(adapter_name, init_weights, delora_lambda)
# Move new weights to device
self._move_adapter_to_device_of_base_layer(adapter_name)
self.set_adapter(self.active_adapters, inference_mode=inference_mode)
def reset_delora_parameters(
self,
adapter_name: str,
init_weights: bool = True,
delora_lambda: float = 15.0,
) -> None:
if adapter_name not in self.delora_A.keys():
return
if init_weights is True:
nn.init.kaiming_uniform_(self.delora_A[adapter_name], a=math.sqrt(5))
nn.init.zeros_(self.delora_B[adapter_name])
else:
nn.init.kaiming_uniform_(self.delora_A[adapter_name], a=math.sqrt(5))
nn.init.kaiming_uniform_(self.delora_B[adapter_name], a=math.sqrt(5))
self.delora_lambda[adapter_name].data.fill_(float(delora_lambda))
# capture a fixed norm for this adapter to use for future delta computations
with torch.no_grad():
w = self.get_base_layer().weight
if w.device.type != "meta":
w_norm = torch.norm(w.data, dim=0).detach()
else:
# For meta tensors, we can't compute the norm, so use a default value
w_norm = torch.ones(w.shape[1], device=w.device)
self.delora_w_norm[adapter_name] = w_norm
class DeloraLinear(nn.Module, DeloraLayer):
# DeLoRA implemented in a dense layer
def __init__(
self,
base_layer,
adapter_name: str,
config: DeloraConfig,
r: int,
delora_lambda: float,
**kwargs,
) -> None:
super().__init__()
DeloraLayer.__init__(self, base_layer, **kwargs)
self._active_adapter = adapter_name
self.update_layer(adapter_name, r, delora_lambda, config=config)
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
"""
Merge the active adapter weights into the base weights
Args:
safe_merge (`bool`, *optional*):
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
before merging the weights. This is useful if you want to check if the merge operation will produce
NaNs. Defaults to `False`.
adapter_names (`list[str]`, *optional*):
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
to `None`.
"""
adapter_names = check_adapters_to_merge(self, adapter_names)
if not adapter_names:
return
for active_adapter in adapter_names:
if active_adapter in self.delora_A.keys():
base_layer = self.get_base_layer()
delta_weight = (
self.get_delta_weight(active_adapter)
.detach()
.to(dtype=base_layer.weight.dtype, device=base_layer.weight.device)
)
with torch.no_grad():
if safe_merge:
orig_weights = base_layer.weight.data.clone()
orig_weights = orig_weights + delta_weight
if not torch.isfinite(orig_weights).all():
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)
base_layer.weight.data = orig_weights
else:
base_layer.weight.data.add_(delta_weight)
self.merged_adapters.append(active_adapter)
def unmerge(self) -> None:
"""
Unmerge all merged adapter layers from the base weights.
"""
if not self.merged:
warnings.warn("Already unmerged. Nothing to do.")
return
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
if active_adapter in self.delora_A.keys():
self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter)
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
previous_dtype = x.dtype
if self.disable_adapters:
if self.merged:
self.unmerge()
result = self.base_layer(x, *args, **kwargs)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
if not self.active_adapters:
return self.base_layer(x, *args, **kwargs).to(previous_dtype)
base_out = self.base_layer(x, *args, **kwargs)
add_out = torch.zeros_like(base_out)
for adapter in self.active_adapters:
if adapter not in self.delora_A:
continue
x_d = self.delora_dropout[adapter](x)
# Decomposed delta calculation
# 1. (x * w_norm) @ A.T
h = nn.functional.linear(x_d * self.delora_w_norm[adapter], self.delora_A[adapter])
# 2. h @ diag
An = torch.clamp(self.delora_A[adapter].norm(dim=1), min=1e-4)
Bn = torch.clamp(self.delora_B[adapter].norm(dim=0), min=1e-4)
scaling = (self.delora_lambda[adapter] / self.r[adapter]) / (An * Bn)
h = h * scaling
# 3. h @ B.T
h = nn.functional.linear(h, self.delora_B[adapter])
add_out += h
result = base_out + add_out.to(base_out.dtype)
result = result.to(previous_dtype)
return result
def supports_lora_conversion(self, adapter_name: str = "default") -> bool:
return True
def __repr__(self) -> str:
rep = super().__repr__()
return "delora." + rep
+462
View File
@@ -0,0 +1,462 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
from typing import Any, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
from .config import HRAConfig
class HRALayer(BaseTunerLayer):
# All names of layers that may contain (trainable) adapter weights
adapter_layer_names = ("hra_u",)
# All names of other parameters that may contain adapter-related parameters
other_param_names = ("hra_r", "hra_apply_GS")
def __init__(self, base_layer: nn.Module, **kwargs) -> None:
self.base_layer = base_layer
self.hra_r = {}
self.hra_apply_GS = {}
self.hra_u = nn.ParameterDict({})
# Mark the weight as unmerged
self._disable_adapters = False
self.merged_adapters = []
# flag to enable/disable casting of input to weight dtype during forward call
self.cast_input_dtype_enabled = True
self.kwargs = kwargs
base_layer = self.get_base_layer()
if isinstance(base_layer, nn.Linear):
self.in_features, self.out_features = base_layer.in_features, base_layer.out_features
elif isinstance(base_layer, nn.Conv2d):
self.in_features, self.out_features = base_layer.in_channels, base_layer.out_channels
else:
raise ValueError(f"Unsupported layer type {type(base_layer)}")
def update_layer(
self,
adapter_name: str,
r: int,
config: HRAConfig,
**kwargs,
) -> None:
"""Internal function to create hra adapter
Args:
adapter_name (`str`): Name for the adapter to add.
r (`int`): Rank for the added adapter.
config (`HRAConfig`): The adapter configuration for this layer.
"""
apply_GS = config.apply_GS
init_weights = config.init_weights
inference_mode = config.inference_mode
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
self.hra_r[adapter_name] = r
self.hra_apply_GS[adapter_name] = apply_GS
# Determine shape of HRA weights
base_layer = self.get_base_layer()
if isinstance(base_layer, nn.Linear):
self.hra_u[adapter_name] = nn.Parameter(torch.empty(self.in_features, r), requires_grad=True)
elif isinstance(base_layer, nn.Conv2d):
self.hra_u[adapter_name] = nn.Parameter(
torch.empty(self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0], r),
requires_grad=True,
)
else:
raise TypeError(f"HRA is not implemented for base layers of type {type(base_layer).__name__}")
# Initialize weights
if init_weights:
self.reset_hra_parameters(adapter_name)
else:
self.reset_hra_parameters_random(adapter_name)
# Move new weights to device
self._move_adapter_to_device_of_base_layer(adapter_name)
self.set_adapter(self.active_adapters, inference_mode=inference_mode)
def reset_hra_parameters(self, adapter_name: str):
if self.hra_r[adapter_name] % 2 != 0:
warnings.warn("The symmetric initialization can NOT be performed when r is odd!")
nn.init.kaiming_uniform_(self.hra_u[adapter_name], a=math.sqrt(5))
else:
shape = self.hra_u[adapter_name].shape
half_u = torch.zeros(shape[0], shape[1] // 2)
nn.init.kaiming_uniform_(half_u, a=math.sqrt(5))
self.hra_u[adapter_name] = nn.Parameter(torch.repeat_interleave(half_u, 2, dim=1))
def reset_hra_parameters_random(self, adapter_name: str):
nn.init.kaiming_uniform_(self.hra_u[adapter_name], a=math.sqrt(5))
def scale_layer(self, scale: float) -> None:
if scale == 1:
return
for active_adapter in self.active_adapters:
if active_adapter not in self.hra_u.keys():
continue
warnings.warn("Scaling operation for HRA not supported! Automatically set scale to 1.")
def unscale_layer(self, scale=None) -> None:
for active_adapter in self.active_adapters:
if active_adapter not in self.hra_u.keys():
continue
warnings.warn("Unscaling operation for HRA not supported! Keeping scale at 1.")
class HRALinear(nn.Module, HRALayer):
"""
HRA implemented in a dense layer.
"""
def __init__(
self,
base_layer,
adapter_name: str,
config: HRAConfig,
r: int = 0,
**kwargs,
) -> None:
super().__init__()
HRALayer.__init__(self, base_layer, **kwargs)
self._active_adapter = adapter_name
self.update_layer(adapter_name, r, config=config, **kwargs)
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
"""
Merge the active adapter weights into the base weights
Args:
safe_merge (`bool`, *optional*):
If `True`, the merge operation will be performed in a copy of the original weights and check for NaNs
before merging the weights. This is useful if you want to check if the merge operation will produce
NaNs. Defaults to `False`.
adapter_names (`List[str]`, *optional*):
The list of adapter names that should be merged. If `None`, all active adapters will be merged.
Defaults to `None`.
"""
adapter_names = check_adapters_to_merge(self, adapter_names)
if not adapter_names:
# no adapter to merge
return
for active_adapter in adapter_names:
if active_adapter in self.hra_u.keys():
base_layer = self.get_base_layer()
orig_dtype = base_layer.weight.dtype
if safe_merge:
# Note that safe_merge will be slower than the normal merge
# because of the copy operation.
orig_weight = base_layer.weight.data.clone()
delta_weight = self.get_delta_weight(active_adapter)
orig_weight = torch.mm(orig_weight.to(delta_weight.dtype), delta_weight)
if not torch.isfinite(orig_weight).all():
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)
base_layer.weight.data = orig_weight.to(orig_dtype)
else:
delta_weight = self.get_delta_weight(active_adapter)
new_weight = torch.mm(base_layer.weight.data.to(delta_weight.dtype), delta_weight)
base_layer.weight.data = new_weight.to(orig_dtype)
self.merged_adapters.append(active_adapter)
def unmerge(self) -> None:
"""
This method unmerges all merged adapter layers from the base weights.
"""
if not self.merged:
warnings.warn("Already unmerged. Nothing to do.")
return
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
base_layer = self.get_base_layer()
orig_dtype = base_layer.weight.dtype
if active_adapter in self.hra_u.keys():
orig_weight = base_layer.weight.data.clone()
delta_weight = self.get_delta_weight(active_adapter, reverse=True)
new_weight = torch.mm(orig_weight.to(delta_weight.dtype), delta_weight)
base_layer.weight.data = new_weight.to(orig_dtype)
def get_delta_weight(self, adapter_name: str, reverse: bool = False) -> torch.Tensor:
rank = self.hra_r[adapter_name]
apply_GS = self.hra_apply_GS[adapter_name]
opt_u = self.hra_u[adapter_name]
shape = opt_u.shape
if apply_GS:
weight = [(opt_u[:, 0] / opt_u[:, 0].norm()).view(-1, 1)]
for i in range(1, rank):
ui = opt_u[:, i].view(-1, 1)
for j in range(i):
ui = ui - (weight[j].t() @ ui) * weight[j]
weight.append((ui / ui.norm()).view(-1, 1))
weight = torch.cat(weight, dim=1)
weight = torch.eye(shape[0], device=opt_u.device, dtype=opt_u.dtype) - 2 * weight @ weight.t()
else:
opt_u = opt_u / opt_u.norm(dim=0)
weight = torch.eye(shape[0], device=opt_u.device, dtype=opt_u.dtype)
if reverse:
indices = range(rank - 1, -1, -1)
else:
indices = range(rank)
for i in indices:
ui = opt_u[:, i].view(-1, 1)
weight = weight - 2 * weight @ ui @ ui.t()
return weight
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
previous_dtype = x.dtype
if self.disable_adapters:
if self.merged:
self.unmerge()
result = self.base_layer(x, *args, **kwargs)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
new_weight = torch.eye(self.in_features, device=x.device)
for active_adapter in self.active_adapters:
if active_adapter not in self.hra_u.keys():
continue
delta_weight = self.get_delta_weight(active_adapter)
new_weight = torch.mm(new_weight.to(delta_weight.dtype), delta_weight)
orig_weight = self.get_base_layer().weight.data
orig_weight = self._cast_input_dtype(orig_weight, new_weight.dtype)
new_weight = torch.mm(orig_weight, new_weight)
bias = self._cast_input_dtype(self.base_layer.bias, new_weight.dtype)
if self.cast_input_dtype_enabled:
x = self._cast_input_dtype(x, new_weight.dtype)
else:
x = x.to(self.get_base_layer().weight.data.dtype)
result = F.linear(input=x, weight=new_weight, bias=bias)
result = result.to(previous_dtype)
return result
def __repr__(self) -> str:
rep = super().__repr__()
return "hra." + rep
class HRAConv2d(nn.Module, HRALayer):
"""HRA implemented in Conv2d layer"""
def __init__(
self,
base_layer,
adapter_name: str,
config: HRAConfig,
r: int = 0,
**kwargs,
):
super().__init__()
HRALayer.__init__(self, base_layer)
self._active_adapter = adapter_name
self.update_layer(adapter_name, r, config=config, **kwargs)
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
"""
Merge the active adapter weights into the base weights
Args:
safe_merge (`bool`, *optional*):
If `True`, the merge operation will be performed in a copy of the original weights and check for NaNs
before merging the weights. This is useful if you want to check if the merge operation will produce
NaNs. Defaults to `False`.
adapter_names (`List[str]`, *optional*):
The list of adapter names that should be merged. If `None`, all active adapters will be merged.
Defaults to `None`.
"""
adapter_names = check_adapters_to_merge(self, adapter_names)
if not adapter_names:
# no adapter to merge
return
for active_adapter in adapter_names:
if active_adapter in self.hra_u.keys():
base_layer = self.get_base_layer()
orig_dtype = base_layer.weight.dtype
if safe_merge:
# Note that safe_merge will be slower than the normal merge
# because of the copy operation.
orig_weight = base_layer.weight.data.clone()
orig_weight = orig_weight.view(
self.out_features,
self.in_features * base_layer.kernel_size[0] * self.base_layer.kernel_size[0],
)
delta_weight = self.get_delta_weight(active_adapter)
orig_weight = torch.mm(orig_weight.to(delta_weight.dtype), delta_weight)
orig_weight = orig_weight.view(
self.out_features,
self.in_features,
base_layer.kernel_size[0],
base_layer.kernel_size[0],
)
if not torch.isfinite(orig_weight).all():
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)
base_layer.weight.data = orig_weight.to(orig_dtype)
else:
orig_weight = base_layer.weight.data
orig_weight = orig_weight.view(
self.out_features,
self.in_features * self.base_layer.kernel_size[0] * self.base_layer.kernel_size[0],
)
delta_weight = self.get_delta_weight(active_adapter)
orig_weight = torch.mm(orig_weight.to(delta_weight.dtype), delta_weight)
orig_weight = orig_weight.view(
self.out_features,
self.in_features,
base_layer.kernel_size[0],
base_layer.kernel_size[0],
)
base_layer.weight.data = orig_weight.to(orig_dtype)
self.merged_adapters.append(active_adapter)
def unmerge(self) -> None:
"""
This method unmerges all merged adapter layers from the base weights.
"""
if not self.merged:
warnings.warn("Already unmerged. Nothing to do.")
return
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
base_layer = self.get_base_layer()
orig_dtype = base_layer.weight.dtype
if active_adapter in self.hra_u.keys():
orig_weight = base_layer.weight.data.clone()
orig_weight = orig_weight.view(
self.out_features,
self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0],
)
delta_weight = self.get_delta_weight(active_adapter, reverse=True)
orig_weight = torch.mm(orig_weight.to(delta_weight.dtype), delta_weight)
orig_weight = orig_weight.view(
self.out_features, self.in_features, base_layer.kernel_size[0], base_layer.kernel_size[0]
)
base_layer.weight.data = orig_weight.to(orig_dtype)
def get_delta_weight(self, adapter_name: str, reverse: bool = False) -> torch.Tensor:
rank = self.hra_r[adapter_name]
apply_GS = self.hra_apply_GS[adapter_name]
opt_u = self.hra_u[adapter_name]
shape = opt_u.shape
if apply_GS:
weight = [(opt_u[:, 0] / opt_u[:, 0].norm()).view(-1, 1)]
for i in range(1, rank):
ui = opt_u[:, i].view(-1, 1)
for j in range(i):
ui = ui - (weight[j].t() @ ui) * weight[j]
weight.append((ui / ui.norm()).view(-1, 1))
weight = torch.cat(weight, dim=1)
weight = torch.eye(shape[0], device=opt_u.device, dtype=opt_u.dtype) - 2 * weight @ weight.t()
else:
opt_u = opt_u / opt_u.norm(dim=0)
weight = torch.eye(shape[0], device=opt_u.device, dtype=opt_u.dtype)
if reverse:
indices = range(rank - 1, -1, -1)
else:
indices = range(rank)
for i in indices:
ui = opt_u[:, i].view(-1, 1)
weight = weight - 2 * weight @ ui @ ui.t()
return weight
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
previous_dtype = x.dtype
if self.disable_adapters:
if self.merged:
self.unmerge()
result = self.base_layer(x, *args, **kwargs)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
new_weight = torch.eye(
self.in_features * self.base_layer.kernel_size[0] * self.base_layer.kernel_size[0],
device=x.device,
)
for active_adapter in self.active_adapters:
if active_adapter not in self.hra_u.keys():
continue
delta_weight = self.get_delta_weight(active_adapter)
new_weight = torch.mm(new_weight.to(delta_weight.dtype), delta_weight)
orig_weight = self.base_layer.weight.data
orig_weight = orig_weight.view(
self.out_features,
self.in_features * self.base_layer.kernel_size[0] * self.base_layer.kernel_size[0],
)
orig_weight = self._cast_input_dtype(orig_weight, new_weight.dtype)
bias = self._cast_input_dtype(self.base_layer.bias, new_weight.dtype)
new_weight = torch.mm(orig_weight, new_weight)
new_weight = new_weight.view(
self.out_features,
self.in_features,
self.base_layer.kernel_size[0],
self.base_layer.kernel_size[0],
)
if self.cast_input_dtype_enabled:
x = self._cast_input_dtype(x, new_weight.dtype)
else:
x = x.to(self.get_base_layer().weight.data.dtype)
result = F.conv2d(
input=x,
weight=new_weight,
bias=bias,
padding=self.base_layer.padding[0],
stride=self.base_layer.stride[0],
)
result = result.to(previous_dtype)
return result
def __repr__(self) -> str:
rep = super().__repr__()
return "hra." + rep
+336
View File
@@ -0,0 +1,336 @@
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Any, Optional
import torch
import torch.nn as nn
from transformers.pytorch_utils import Conv1D
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
from peft.utils import transpose
from .config import IA3Config
class IA3Layer(BaseTunerLayer):
# All names of layers that may contain adapter weights
adapter_layer_names = ("ia3_l",)
def __init__(self, base_layer: nn.Module, is_feedforward: bool, **kwargs) -> None:
self.base_layer = base_layer
self.ia3_l = nn.ParameterDict({})
# Mark the weight as unmerged
self._disable_adapters = False
self.merged_adapters = []
self.is_feedforward = is_feedforward
base_layer = self.get_base_layer()
if isinstance(base_layer, nn.Linear):
in_features, out_features = base_layer.in_features, base_layer.out_features
elif isinstance(base_layer, (nn.Conv2d, nn.Conv3d)):
in_features, out_features = base_layer.in_channels, base_layer.out_channels
elif isinstance(base_layer, nn.Embedding):
in_features, out_features = base_layer.num_embeddings, base_layer.embedding_dim
elif isinstance(base_layer, Conv1D):
in_features, out_features = (
base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape
)
else:
raise ValueError(f"Unsupported layer type {type(base_layer)}")
self.in_features = in_features
self.out_features = out_features
def update_layer(self, adapter_name: str, config: IA3Config, **kwargs):
init_ia3_weights = config.init_ia3_weights
inference_mode = config.inference_mode
# This code works for linear layers, override for other layer types
# Actual trainable parameters
if self.is_feedforward:
weight = torch.randn((1, self.in_features))
else:
weight = torch.randn((self.out_features, 1))
self.ia3_l[adapter_name] = nn.Parameter(weight)
if init_ia3_weights:
self.reset_ia3_parameters(adapter_name)
self._move_adapter_to_device_of_base_layer(adapter_name)
self.set_adapter(self.active_adapters, inference_mode=inference_mode)
def reset_ia3_parameters(self, adapter_name):
if adapter_name in self.ia3_l.keys():
# initialize learned vector with torch.ones
nn.init.constant_(self.ia3_l[adapter_name], 1.0)
class Linear(nn.Module, IA3Layer):
# (IA)^3 implemented in a dense layer
def __init__(
self,
base_layer: nn.Module,
adapter_name: str,
config: IA3Config,
is_feedforward: bool = False, # Set to True if the layer is treated as a feedforward layer
is_target_conv_1d_layer: bool = False, # whether target module is a conv1d layer. useful while unloading later
**kwargs,
) -> None:
super().__init__()
IA3Layer.__init__(self, base_layer, is_feedforward=is_feedforward)
self.fan_in_fan_out = config.fan_in_fan_out
self.is_target_conv_1d_layer = is_target_conv_1d_layer
self._active_adapter = adapter_name
self.update_layer(adapter_name, config=config)
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
"""
Merge the active adapter weights into the base weights
Args:
safe_merge (`bool`, *optional*):
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
before merging the weights. This is useful if you want to check if the merge operation will produce
NaNs. Defaults to `False`.
adapter_names (`List[str]`, *optional*):
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
to `None`.
"""
adapter_names = check_adapters_to_merge(self, adapter_names)
if not adapter_names:
# no adapter to merge
return
for active_adapter in adapter_names:
if active_adapter in self.ia3_l.keys():
base_layer = self.get_base_layer()
ia3_l = transpose(self.ia3_l[active_adapter].data, self.fan_in_fan_out)
orig_dtype = base_layer.weight.data.dtype
if safe_merge:
orig_weights = base_layer.weight.data
orig_weights = torch.mul(orig_weights, ia3_l)
if not torch.isfinite(orig_weights).all():
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)
base_layer.weight.data = orig_weights.to(orig_dtype)
else:
base_layer.weight.data = torch.mul(base_layer.weight.data, ia3_l).to(orig_dtype)
if not self.is_feedforward and (base_layer.bias is not None):
scaling = self.ia3_l[active_adapter].reshape(base_layer.bias.shape)
orig_dtype = base_layer.bias.data.dtype
base_layer.bias.data = torch.mul(base_layer.bias.data, scaling.data).to(orig_dtype)
self.merged_adapters.append(active_adapter)
def unmerge(self) -> None:
"""
This method unmerges all merged adapter layers from the base weights.
"""
if not self.merged:
warnings.warn("Already unmerged. Nothing to do.")
return
warnings.warn("Unmerge result can be inaccurate for (IA)^3.")
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
if active_adapter in self.ia3_l.keys():
base_layer = self.get_base_layer()
# Add tolerace to avoid division by zero
ia3_l = transpose(self.ia3_l[active_adapter].data, self.fan_in_fan_out) + 1e-8
orig_dtype = base_layer.weight.data.dtype
base_layer.weight.data = torch.div(base_layer.weight.data, ia3_l).to(orig_dtype)
if not self.is_feedforward and (base_layer.bias is not None):
scaling = self.ia3_l[active_adapter].reshape(base_layer.bias.shape)
orig_dtype = base_layer.bias.data.dtype
base_layer.bias.data = torch.div(base_layer.bias.data, scaling.data + 1e-8).to(orig_dtype)
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
dtype = previous_dtype = x.dtype
if self.disable_adapters:
if self.merged:
self.unmerge()
result = self.base_layer(x, *args, **kwargs)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
ia3_scaling = 1
for active_adapter in self.active_adapters:
if active_adapter not in self.ia3_l.keys():
continue
dtype = self.ia3_l[active_adapter].dtype
ia3_scaling *= self.ia3_l[active_adapter].flatten()
if self.is_feedforward:
x = x.to(dtype)
# TODO: weight.dtype can be != self.ia3_l[self.active_adapters].dtype
# e.g. bf16 vs fp32. Is that okay?
interm = (x * ia3_scaling).to(previous_dtype)
result = self.base_layer(interm, *args, **kwargs)
else:
result = self.base_layer(x, *args, **kwargs)
result_dtype = result.dtype
result = (result * ia3_scaling).to(result_dtype)
return result
class _ConvNd(nn.Module, IA3Layer):
def __init__(
self,
base_layer: nn.Module,
adapter_name: str,
config: IA3Config,
is_feedforward: bool = False, # Set to True if the layer is treated as a feedforward layer
**kwargs,
) -> None:
super().__init__()
IA3Layer.__init__(self, base_layer, is_feedforward=is_feedforward)
self.fan_in_fan_out = config.fan_in_fan_out
self._active_adapter = adapter_name
self._kernel_dim = base_layer.weight.dim()
self.update_layer(adapter_name, config=config)
def update_layer(self, adapter_name: str, config: IA3Config, **kwargs):
init_ia3_weights = config.init_ia3_weights
inference_mode = config.inference_mode
# Actual trainable parameters
num_features = self.in_features if self.is_feedforward else self.out_features
weights_size = (1, num_features) + (1,) * (self._kernel_dim - 2)
weight = torch.randn(weights_size)
self.ia3_l[adapter_name] = nn.Parameter(weight)
if init_ia3_weights:
self.reset_ia3_parameters(adapter_name)
self._move_adapter_to_device_of_base_layer(adapter_name)
self.set_adapter(self.active_adapters, inference_mode=inference_mode)
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
"""
Merge the active adapter weights into the base weights
Args:
safe_merge (`bool`, *optional*):
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
before merging the weights. This is useful if you want to check if the merge operation will produce
NaNs. Defaults to `False`.
adapter_names (`List[str]`, *optional*):
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
to `None`.
"""
adapter_names = check_adapters_to_merge(self, adapter_names)
if not adapter_names:
# no adapter to merge
return
for active_adapter in adapter_names:
if active_adapter in self.ia3_l.keys():
base_layer = self.get_base_layer()
orig_dtype = base_layer.weight.data.dtype
ia3_scaling = self.ia3_l[active_adapter].data
if not self.is_feedforward:
ia3_scaling = ia3_scaling.transpose(0, 1)
if safe_merge:
output_weight = torch.mul(base_layer.weight.data, ia3_scaling).clone()
if not torch.isfinite(output_weight).all():
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)
base_layer.weight.data = output_weight.to(orig_dtype)
else:
base_layer.weight.data = torch.mul(base_layer.weight.data, ia3_scaling).to(orig_dtype)
if not self.is_feedforward and (base_layer.bias is not None):
scaling = self.ia3_l[active_adapter].reshape(base_layer.bias.shape)
base_layer.bias.data = torch.mul(base_layer.bias.data, scaling.data).to(orig_dtype)
self.merged_adapters.append(active_adapter)
def unmerge(self) -> None:
"""
This method unmerges all merged adapter layers from the base weights.
"""
if not self.merged:
warnings.warn("Already unmerged. Nothing to do.")
return
warnings.warn("Unmerge result can be inaccurate for (IA)^3.")
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
if active_adapter in self.ia3_l.keys():
base_layer = self.get_base_layer()
orig_dtype = base_layer.weight.data.dtype
# divide by (IA)^3 vector. Add tolerace to avoid division by zero
ia3_scaling = self.ia3_l[active_adapter].data
if not self.is_feedforward:
ia3_scaling = ia3_scaling.transpose(0, 1)
base_layer.weight.data = torch.div(base_layer.weight.data, ia3_scaling + 1e-8).to(orig_dtype)
if not self.is_feedforward and (base_layer.bias is not None):
scaling = self.ia3_l[active_adapter].reshape(base_layer.bias.shape)
orig_dtype = base_layer.bias.data.dtype
base_layer.bias.data = torch.mul(base_layer.bias.data, scaling.data).to(orig_dtype)
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
dtype = previous_dtype = x.dtype
if self.disable_adapters:
if self.merged:
self.unmerge()
result = self.base_layer(x, *args, **kwargs)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
ia3_scaling = 1
for active_adapter in self.active_adapters:
if active_adapter not in self.ia3_l.keys():
continue
dtype = self.ia3_l[active_adapter].dtype
ia3_scaling *= self.ia3_l[active_adapter]
if self.is_feedforward:
x = x.to(dtype)
# TODO: weight.dtype can be != self.ia3_l[self.active_adapters].dtype
# e.g. bf16 vs fp32. Is that okay?
interm = (x * ia3_scaling).to(self.get_base_layer().weight.dtype)
result = self.base_layer(interm, *args, **kwargs)
else:
result = self.base_layer(x, *args, **kwargs)
result = result.to(dtype) * ia3_scaling
result = result.to(previous_dtype)
return result
class Conv2d(_ConvNd):
# IA3 implemented in a 2D convolutional layer
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not self._kernel_dim == 4:
raise ValueError(f"Conv2d layer kernel must have 4 dimensions, not {self._kernel_dim}")
class Conv3d(_ConvNd):
# IA3 implemented in a 3D convolutional layer
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not self._kernel_dim == 5:
raise ValueError(f"Conv2d layer kernel must have 5 dimensions, not {self._kernel_dim}")
+287
View File
@@ -0,0 +1,287 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from functools import wraps
from typing import Any, Optional
import torch
import torch.nn.functional as F
from torch import nn
from peft.utils.integrations import dequantize_module_weight, gather_params_ctx
from peft.utils.other import transpose
ENABLE_DORA_CACHING = False
"""Whether to enable DoRA caching, which makes it faster at inference but requires more memory"""
def cache_decorator(cache_key: str):
"""Caching decorator for DoRA
Caching is only enabled if ENABLE_DORA_CACHING is set to True (default: False), when in eval mode, and when the
adapter_name is passed (e.g. not during layer initialization).
"""
def cache_value(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
# if adapter_name is not passed, no caching
adapter_name = kwargs.get("adapter_name")
if (not ENABLE_DORA_CACHING) or self.training or (adapter_name is None):
self._cache_clear()
return func(self, *args, **kwargs)
cache_key_adapter = f"{cache_key}-{adapter_name}"
output = self._cache_get(cache_key_adapter, None)
if output is not None:
return output
output = func(self, *args, **kwargs)
self._cache_store(cache_key_adapter, output)
return output
return wrapper
return cache_value
class DoraLinearLayer(nn.Module):
def __init__(self, fan_in_fan_out):
super().__init__()
self.fan_in_fan_out = fan_in_fan_out
self._dora_cache: dict[str, Any] = {} # small ad hoc cache; values are not part of the state_dict
def _cache_store(self, key: str, value: Any) -> None:
# cache intermediate values, e.g. weight norm of DoRA
self._dora_cache[key] = value
def _cache_get(self, key: str, default: Optional[Any]) -> Optional[Any]:
# retrieve from ad hoc cache
return self._dora_cache.get(key, default)
def _cache_clear(self) -> None:
self._dora_cache.clear()
def train(self, mode: bool = True):
if mode:
self._cache_clear()
super().train(mode=mode)
return self
@cache_decorator("weight-norm")
def get_weight_norm(self, weight, lora_weight, scaling, adapter_name: Optional[str] = None) -> torch.Tensor:
# calculate L2 norm of weight matrix, column-wise
weight = transpose(weight, self.fan_in_fan_out)
weight = weight + scaling * lora_weight
weight_norm = torch.linalg.norm(weight, dim=1).to(weight.dtype)
return weight_norm
@cache_decorator("lora-weight")
def get_lora_weight(self, lora_A, lora_B, adapter_name: Optional[str] = None):
# Don't use `lora_weight = lora_B.weight @ lora_A.weight` because this causes errors with FSDP. Instead,
# calculate the same but using forward.
x_eye = torch.eye(lora_A.weight.shape[1], device=lora_A.weight.device, dtype=lora_A.weight.dtype)
lora_weight = lora_B(lora_A(x_eye)).T
return lora_weight
def update_layer(self, *, base_layer, lora_A, lora_B, scaling, place_on_cpu=False) -> None:
# temporarily convert fp16 to fp32, as fp16 can cause trouble on CPU with PyTorch < 2.2
dtype_is_fp16 = lora_A.dtype == torch.float16
if dtype_is_fp16:
lora_A = lora_A.float()
lora_B = lora_B.float()
with gather_params_ctx(base_layer.parameters()):
if base_layer.__class__.__name__ == "Linear4bit":
# We have to create a copy of the base layer, otherwise, FSDP will throw an error. 8bit does not work
# yet because Int8Params cannot be correctly deep-copied (attributes vanish)
base_layer = deepcopy(base_layer)
weight = dequantize_module_weight(base_layer)
if weight.data.ndim >= 3: # For handling LoRAs applied to Conv layers.
r = lora_A.shape[0]
lora_weight = torch.mm(lora_B.view([-1, r]), lora_A.view([r, -1]))
lora_weight = lora_weight.reshape(weight.shape)
else:
lora_weight = lora_B @ lora_A
if dtype_is_fp16:
lora_weight = lora_weight.half()
weight_norm = self.get_weight_norm(
weight=weight.to(lora_A.device), lora_weight=lora_weight, scaling=scaling
)
if place_on_cpu:
weight_norm = weight_norm.to("cpu")
self.weight = nn.Parameter(weight_norm, requires_grad=True)
def forward(self, x, *, lora_A, lora_B, scaling, base_layer, base_result=None, adapter_name="default"):
"""
For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
output.
"""
lora_weight = self.get_lora_weight(lora_A=lora_A, lora_B=lora_B, adapter_name=adapter_name)
lora_weight = lora_weight.to(x.dtype)
magnitude = self.weight
weight = dequantize_module_weight(base_layer)
weight = weight.to(x.dtype)
weight_norm = self.get_weight_norm(
weight=weight, lora_weight=lora_weight.detach(), scaling=scaling, adapter_name=adapter_name
)
# see section 4.3 of DoRA (https://huggingface.co/papers/2402.09353)
# "[...] we suggest treating ||V +∆V ||_c in
# Eq. (5) as a constant, thereby detaching it from the gradient
# graph. This means that while ||V + ∆V ||_c dynamically
# reflects the updates of ∆V , it wont receive any gradient
# during backpropagation"
weight_norm = weight_norm.detach()
mag_norm_scale = (magnitude / weight_norm).view(1, -1)
lora_result = lora_B(lora_A(x))
bias = None
if base_result is not None:
bias = base_layer.bias
if bias is not None:
base_result = base_result - bias
else:
base_result = F.linear(x, transpose(weight, self.fan_in_fan_out))
result_dora = (mag_norm_scale - 1) * base_result + mag_norm_scale * lora_result * scaling
return result_dora
def __repr__(self) -> str:
rep = super().__repr__()
return "lora.dora." + rep
class DoraEmbeddingLayer(DoraLinearLayer):
@cache_decorator("lora-weight")
def get_lora_weight(self, lora_A, lora_B, adapter_name: Optional[str] = None):
return (lora_A @ lora_B).T
def forward(self, x, *, lora_A, lora_B, scaling, base_layer, embed_fn, adapter_name="default"):
"""
For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
output.
"""
lora_weight = self.get_lora_weight(lora_A=lora_A, lora_B=lora_B, adapter_name=adapter_name)
magnitude = self.weight
weight = base_layer.weight
weight_norm = self.get_weight_norm(
weight=weight, lora_weight=lora_weight.detach(), scaling=scaling, adapter_name=adapter_name
)
# see section 4.3 of DoRA (https://huggingface.co/papers/2402.09353)
# "[...] we suggest treating ||V +∆V ||_c in
# Eq. (5) as a constant, thereby detaching it from the gradient
# graph. This means that while ||V + ∆V ||_c dynamically
# reflects the updates of ∆V , it wont receive any gradient
# during backpropagation"
weight_norm = weight_norm.detach()
mag_norm_scale = magnitude / weight_norm
result_dora = mag_norm_scale * (embed_fn(x, lora_A) @ lora_B) * scaling
return mag_norm_scale, result_dora
def __repr__(self) -> str:
rep = super().__repr__()
return "lora.dora." + rep
class _DoraConvNdLayer(DoraLinearLayer):
@cache_decorator("weight-norm")
def get_weight_norm(self, weight, lora_weight, scaling, adapter_name: Optional[str] = None) -> torch.Tensor:
# calculate L2 norm of weight matrix, column-wise
weight = weight + scaling * lora_weight
# the following is needed to have compatibility with the 4/5D weight tensors of Conv2D/3D
dim = tuple(range(1, weight.dim()))
weight_norm = weight.norm(p=2, dim=dim, keepdim=True).transpose(1, 0)
return weight_norm
@cache_decorator("lora-weight")
def get_lora_weight(self, lora_A, lora_B, adapter_name: Optional[str] = None) -> torch.Tensor:
# Don't use `lora_weight = lora_B.weight @ lora_A.weight` because this causes errors with FSDP. Instead,
# calculate the same but using forward.
r = lora_A.weight.shape[0]
lora_weight = torch.mm(lora_B.weight.view([-1, r]), lora_A.weight.view([r, -1]))
return lora_weight
def forward(
self, x, *, lora_A, lora_B, scaling, base_layer, base_result=None, adapter_name: str = "default"
) -> torch.Tensor:
"""
For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
output.
"""
weight = base_layer.weight
lora_weight = self.get_lora_weight(lora_A=lora_A, lora_B=lora_B, adapter_name=adapter_name).reshape(
weight.shape
)
magnitude = self.weight
weight_norm = self.get_weight_norm(
weight=weight, lora_weight=lora_weight.detach(), scaling=scaling, adapter_name=adapter_name
)
# see section 4.3 of DoRA (https://huggingface.co/papers/2402.09353)
# "[...] we suggest treating ||V +∆V ||_c in
# Eq. (5) as a constant, thereby detaching it from the gradient
# graph. This means that while ||V + ∆V ||_c dynamically
# reflects the updates of ∆V , it wont receive any gradient
# during backpropagation"
weight_norm = weight_norm.detach()
mag_norm_scale = magnitude / weight_norm
if base_result is None:
base_result = self.conv_fn(
x,
weight,
bias=None,
stride=base_layer.stride,
padding=base_layer.padding,
dilation=base_layer.dilation,
groups=base_layer.groups,
)
else:
bias = base_layer.bias
if bias is not None:
# reshape bias to (1, -1, 1, ...)
bias_shape = (1, -1) + (1,) * (base_result.dim() - 2)
base_result = base_result - bias.view(*bias_shape)
result_dora = (mag_norm_scale - 1) * base_result + mag_norm_scale * lora_B(lora_A(x)) * scaling
return result_dora
def __repr__(self) -> str:
rep = super().__repr__()
return "lora.dora." + rep
class DoraConv1dLayer(_DoraConvNdLayer):
def __init__(self, fan_in_fan_out):
super().__init__(fan_in_fan_out)
self.conv_fn = F.conv1d
class DoraConv2dLayer(_DoraConvNdLayer):
def __init__(self, fan_in_fan_out):
super().__init__(fan_in_fan_out)
self.conv_fn = F.conv2d
class DoraConv3dLayer(_DoraConvNdLayer):
def __init__(self, fan_in_fan_out):
super().__init__(fan_in_fan_out)
self.conv_fn = F.conv3d
File diff suppressed because it is too large Load Diff
+923
View File
@@ -0,0 +1,923 @@
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import collections
import warnings
from typing import Any, Optional
import torch
from accelerate.utils.imports import is_xpu_available
from torch import nn
from peft.tuners.lora.config import BdLoraConfig
from peft.utils.other import transpose
from .arrow import ArrowLoraLinearLayer
from .config import LoraConfig, PeftConfig
from .dora import DoraConv1dLayer, DoraConv2dLayer, DoraConv3dLayer, DoraEmbeddingLayer, DoraLinearLayer
from .layer import Conv1d, Conv2d, Conv3d, Embedding, Linear, LoraVariant, _ConvNd
class ArrowLinearVariant(LoraVariant):
@staticmethod
def init(module: Linear, adapter_name: str, config: LoraConfig, **kwargs):
"""
Initialise the ArrowLoraLinearLayer() inside lora_arrow. lora_arrow is nn.ModuleDict(), serving as a container
for ArrowLoraLinearLayer(). A layer of the base model with LoRA adapter loaded on it will be like:
----------------------------------------------------
(qkv_proj): lora.Linear4bit or lora.Linear(
(base_layer): Linear4bit or Linear (lora_dropout): ModuleDict( ... ) (lora_A): ModuleDict( ... )
(lora_B): ModuleDict( ... ) (lora_embedding_A): ParameterDict( ... ) (lora_embedding_B): ParameterDict(
... ) (lora_magnitude_vector): ModuleDict( ... ) (lora_arrow): ModuleDict(
(arrow_router): ArrowLoraLinearLayer() )
)
----------------------------------------------------
Args:
module (Linear): LoRA Layer of the model, containing base_layer, lora_A, lora_B, etc.
adapter_name (str): name of the adapter that will be put in lora_arrow.
The adapter_name is "arrow_router" by default, set in create_arrow_model() in ./arrow.py
"""
# Checking for arrow necessary config
arrow_config = config.arrow_config
if arrow_config is None:
raise ValueError("ArrowLinearVariant.init() did not receive an arrow_config")
# 1-a) build the ArrowLoRALayer
arrow_layer = ArrowLoraLinearLayer(
in_features=module.in_features,
arrow_config=arrow_config,
).to(module.weight.device)
# 1-b) register a container if it doesnt exist yet
if not hasattr(module, "lora_arrow"):
module.lora_arrow = nn.ModuleDict()
module.lora_arrow[adapter_name] = arrow_layer
@staticmethod
def forward(
module: Linear,
*,
active_adapter: str,
x: torch.Tensor,
result: torch.Tensor,
**kwargs,
) -> torch.Tensor:
"""
Parameters mirror those in PEFTs `LoraVariant.forward`. Called every time the host Linear does a fwd pass.
build_prototypes() and gen_know_sub() should run only once before routing. Both are implemented in
ArrowLoraLinearLayer (see ./arrow.py). They are lazily invoked in the forward pass below. Attributes of
ArrowLoraLinearLayer() class ensure they execute only a single time.
Args:
module (Linear): LoRA Layer of the model
active_adapter (str): name of the arrow route, which should be active to perform arrow.
x (torch.Tensor): input to the layer
result (torch.Tensor): output of the base layer.
Return value:
output of the base model + delta weight computed by arrow layer.
"""
arrow = module.lora_arrow[active_adapter] # ArrowLoraLinearLayer
# Apply GenKnowSub the 1st time if applcable. By calling arrow/on_adapter_change(),
# gen_know_sub() is redone for newly added adapters after arrow.create_arrow_model().
arrow.gen_know_sub(module.lora_A, module.lora_B)
# lazily build prototypes the 1st time after GenKnowSub. By calling arrow/on_adapter_change(),
# build_prototypes() is redone for newly added adapters after arrow.create_arrow_model().
arrow.build_prototypes(module.lora_A, module.lora_B)
# A forward path of ArrowLoraLinearLayer is called so routing performs.
# Accept and ignore extra variant kwargs (e.g., 'alora_offsets') for compatibility
delta = arrow(
x,
lora_A=module.lora_A,
lora_B=module.lora_B,
dropout=module.lora_dropout[active_adapter],
scaling=module.scaling,
)
return result + delta
"""
Since Arrow is a Mixture-of-Experts (MoE) approach, merging adapters is not meaningful or even possible: for each
token, the top-k LoRA experts are dynamically selected and routed. Because of this per-token routing, there is no
single set of weights that can represent a merged adapter.
"""
@staticmethod
def merge_safe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
raise RuntimeError("Cannot merge an active Arrow router adapter. Remove it first.")
@staticmethod
def merge_unsafe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> None:
raise RuntimeError("Cannot merge an active Arrow router adapter. Remove it first.")
@staticmethod
def unmerge(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
raise RuntimeError("Cannot unmerge an active Arrow router adapter. Remove it first.")
class DoraLinearVariant(LoraVariant):
@staticmethod
def init(module: Linear, adapter_name: str, **kwargs: Any) -> None:
if not module.lora_magnitude_vector:
# first dora layer being added, add lora_magnitude_vector to the list of learnable parameters
module.adapter_layer_names = module.adapter_layer_names[:] + ("lora_magnitude_vector",)
dora_layer = DoraLinearLayer(fan_in_fan_out=getattr(module, "fan_in_fan_out", False))
lora_A = module.lora_A[adapter_name].weight
lora_B = module.lora_B[adapter_name].weight
place_on_cpu = module.ephemeral_gpu_offload and (lora_A.device.type == "cpu" or lora_B.device.type == "cpu")
if module.ephemeral_gpu_offload:
if lora_A.device.type in ["cuda", "xpu"]:
lora_B = lora_B.to(lora_A.device)
else:
if lora_B.device.type not in ["cuda", "xpu"]:
if is_xpu_available():
lora_B = lora_B.to("xpu")
else:
lora_B = lora_B.to("cuda")
lora_A = lora_A.to(lora_B.device)
scaling = module.scaling[adapter_name]
dora_layer.update_layer(
base_layer=module.get_base_layer(),
lora_A=lora_A,
lora_B=lora_B,
scaling=scaling,
place_on_cpu=place_on_cpu,
)
module.lora_magnitude_vector[adapter_name] = dora_layer
@staticmethod
def merge_safe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
orig_dtype = orig_weight.dtype
delta_weight = module.get_delta_weight(active_adapter)
# since delta_weight already includes scaling, set it to 1 here
weight_norm = (
module.lora_magnitude_vector[active_adapter]
.get_weight_norm(orig_weight, transpose(delta_weight, module.fan_in_fan_out), scaling=1)
.detach()
)
# We need to cache weight_norm because it has to be based on the original weights. We
# cannot calculate it on the fly based on the merged weights when unmerging because its a
# different value
module._cache_store(f"{active_adapter}-weight_norm", weight_norm)
dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm
dora_factor = transpose(dora_factor.view(-1, 1), module.fan_in_fan_out)
new_weight = dora_factor * (orig_weight + delta_weight)
new_weight = new_weight.to(orig_dtype)
return new_weight
@staticmethod
def merge_unsafe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> None:
orig_dtype = orig_weight.dtype
delta_weight = module.get_delta_weight(active_adapter)
weight_norm = (
module.lora_magnitude_vector[active_adapter]
.get_weight_norm(orig_weight, transpose(delta_weight, module.fan_in_fan_out), scaling=1)
.detach()
)
# We need to cache weight_norm because it has to be based on the original weights. We
# cannot calculate it on the fly based on the merged weights when unmerging because its a
# different value
module._cache_store(f"{active_adapter}-weight_norm", weight_norm)
dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm
dora_factor = transpose(dora_factor.view(-1, 1), module.fan_in_fan_out)
new_weight = dora_factor * (orig_weight.data + delta_weight)
new_weight = new_weight.to(orig_dtype)
orig_weight.data = new_weight
@staticmethod
def unmerge(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
orig_dtype = orig_weight.dtype
delta_weight = module.get_delta_weight(active_adapter)
weight_norm = module._cache_pop(f"{active_adapter}-weight_norm")
dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm
new_weight = orig_weight.data / dora_factor.view(-1, 1) - delta_weight
new_weight = new_weight.to(orig_dtype)
return new_weight
@staticmethod
def forward(
module: Linear,
active_adapter: str,
x: torch.Tensor,
result: torch.Tensor,
**kwargs,
) -> torch.Tensor:
lora_A = module.lora_A[active_adapter]
lora_B = module.lora_B[active_adapter]
dropout = module.lora_dropout[active_adapter]
scaling = module.scaling[active_adapter]
if isinstance(dropout, nn.Identity) or not module.training:
base_result = result
else:
x = dropout(x)
base_result = None
result = result + module.lora_magnitude_vector[active_adapter](
x,
lora_A=lora_A,
lora_B=lora_B,
scaling=scaling,
base_layer=module.get_base_layer(),
base_result=base_result,
adapter_name=active_adapter,
)
return result
class DoraEmbeddingVariant(DoraLinearVariant):
@staticmethod
def init(module: Embedding, adapter_name: str, **kwargs: Any) -> None:
if module.lora_magnitude_vector is None:
# first dora layer being added, add lora_magnitude_vector to the list of learnable parameters
module.adapter_layer_names = module.adapter_layer_names[:] + ("lora_magnitude_vector",)
dora_layer = DoraEmbeddingLayer(fan_in_fan_out=True)
lora_embedding_A = module.lora_embedding_A[adapter_name]
lora_embedding_B = module.lora_embedding_B[adapter_name]
scaling = module.scaling[adapter_name]
dora_layer.update_layer(
base_layer=module.get_base_layer(), lora_A=lora_embedding_A, lora_B=lora_embedding_B, scaling=scaling
)
module.lora_magnitude_vector[adapter_name] = dora_layer
@staticmethod
def merge_safe(module: Embedding, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
orig_dtype = orig_weight.dtype
delta_weight = module.get_delta_weight(active_adapter)
# since delta_weight already includes scaling, set it to 1 here
weight_norm = (
module.lora_magnitude_vector[active_adapter]
.get_weight_norm(orig_weight, delta_weight.T, scaling=1)
.detach()
)
# We need to cache weight_norm because it has to be based on the original weights. We
# cannot calculate it on the fly based on the merged weights when unmerging because its a
# different value
module._cache_store(f"{active_adapter}-weight_norm", weight_norm)
dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm
dora_factor = dora_factor.view(1, -1)
new_weight = dora_factor * (orig_weight + delta_weight)
new_weight = new_weight.to(orig_dtype)
return new_weight
@staticmethod
def merge_unsafe(module: Embedding, active_adapter: str, orig_weight: torch.Tensor) -> None:
orig_dtype = orig_weight.dtype
delta_weight = module.get_delta_weight(active_adapter)
weight_norm = (
module.lora_magnitude_vector[active_adapter]
.get_weight_norm(orig_weight, delta_weight.T, scaling=1)
.detach()
)
# We need to cache weight_norm because it has to be based on the original weights. We
# cannot calculate it on the fly based on the merged weights when unmerging because its a
# different value
module._cache_store(f"{active_adapter}-weight_norm", weight_norm)
dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm
dora_factor = dora_factor.view(1, -1)
new_weight = dora_factor * (orig_weight.data + delta_weight)
new_weight = new_weight.to(orig_dtype)
orig_weight.data = new_weight
@staticmethod
def unmerge(module: Embedding, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
orig_dtype = orig_weight.dtype
delta_weight = module.get_delta_weight(active_adapter)
weight_norm = module._cache_pop(f"{active_adapter}-weight_norm")
dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm
new_weight = orig_weight.data / dora_factor.view(1, -1) - delta_weight
new_weight = new_weight.to(orig_dtype)
return new_weight
@staticmethod
def forward(
module: Embedding,
active_adapter: str,
x: torch.Tensor,
result: torch.Tensor,
**kwargs,
) -> torch.Tensor:
embedding_A = module.lora_embedding_A[active_adapter].T
embedding_B = module.lora_embedding_B[active_adapter].T
scaling = module.scaling[active_adapter]
mag_norm_scale, dora_result = module.lora_magnitude_vector[active_adapter](
x,
lora_A=embedding_A,
lora_B=embedding_B,
scaling=scaling,
base_layer=module.get_base_layer(),
embed_fn=module._embed,
adapter_name=active_adapter,
)
# Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method.
# Since base_layer(x) already includes this scaling, we need to apply it to DoRA contributions too.
# Note: embed_scale is applied AFTER weight norm calculation to preserve DoRA's weight geometry semantics.
embed_scale = module._get_embed_scale()
if embed_scale is not None:
dora_result = dora_result * embed_scale.to(dora_result.dtype)
result = mag_norm_scale * result + dora_result
return result
class _DoraConvNdVariant(LoraVariant):
@staticmethod
def init_convd_variant(module: _ConvNd, adapter_name: str, dora_layer: nn.Module) -> None:
if module.lora_magnitude_vector is None:
# first dora layer being added, add lora_magnitude_vector to the list of learnable parameters
module.adapter_layer_names = module.adapter_layer_names[:] + ("lora_magnitude_vector",)
lora_A = module.lora_A[adapter_name].weight
lora_B = module.lora_B[adapter_name].weight
scaling = module.scaling[adapter_name]
dora_layer.update_layer(base_layer=module.get_base_layer(), lora_A=lora_A, lora_B=lora_B, scaling=scaling)
module.lora_magnitude_vector[adapter_name] = dora_layer
@staticmethod
def merge_safe(module: _ConvNd, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
orig_dtype = orig_weight.dtype
delta_weight = module.get_delta_weight(active_adapter)
# since delta_weight already includes scaling, set it to 1 here
weight_norm = (
module.lora_magnitude_vector[active_adapter].get_weight_norm(orig_weight, delta_weight, scaling=1).detach()
)
# We need to cache weight_norm because it has to be based on the original weights. We
# cannot calculate it on the fly based on the merged weights when unmerging because its a
# different value
module._cache_store(f"{active_adapter}-weight_norm", weight_norm)
dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm
new_weight = dora_factor.view(*module._get_dora_factor_view()) * (orig_weight + delta_weight)
new_weight = new_weight.to(orig_dtype)
return new_weight
@staticmethod
def merge_unsafe(module: _ConvNd, active_adapter: str, orig_weight: torch.Tensor) -> None:
orig_dtype = orig_weight.dtype
delta_weight = module.get_delta_weight(active_adapter)
# since delta_weight already includes scaling, set it to 1 here
weight_norm = (
module.lora_magnitude_vector[active_adapter].get_weight_norm(orig_weight, delta_weight, scaling=1).detach()
)
# We need to cache weight_norm because it has to be based on the original weights. We
# cannot calculate it on the fly based on the merged weights when unmerging because its a
# different value
module._cache_store(f"{active_adapter}-weight_norm", weight_norm)
dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm
new_weight = dora_factor.view(*module._get_dora_factor_view()) * (orig_weight.data + delta_weight)
new_weight = new_weight.to(orig_dtype)
orig_weight.data = new_weight
@staticmethod
def unmerge(module: _ConvNd, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
orig_dtype = orig_weight.dtype
delta_weight = module.get_delta_weight(active_adapter)
weight_norm = module._cache_pop(f"{active_adapter}-weight_norm")
dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm
new_weight = orig_weight.data / dora_factor.view(*module._get_dora_factor_view()) - delta_weight
new_weight = new_weight.to(orig_dtype)
return new_weight
@staticmethod
def forward(
module: _ConvNd,
active_adapter: str,
x: torch.Tensor,
result: torch.Tensor,
**kwargs,
) -> torch.Tensor:
lora_A = module.lora_A[active_adapter]
lora_B = module.lora_B[active_adapter]
dropout = module.lora_dropout[active_adapter]
scaling = module.scaling[active_adapter]
if isinstance(dropout, nn.Identity) or not module.training:
base_result = result
else:
x = dropout(x)
base_result = None
result = result + module.lora_magnitude_vector[active_adapter](
x,
lora_A=lora_A,
lora_B=lora_B,
scaling=scaling,
base_layer=module.get_base_layer(),
base_result=base_result,
adapter_name=active_adapter,
)
return result
class DoraConv1dVariant(_DoraConvNdVariant):
@staticmethod
def init(module: Conv1d, adapter_name: str, config: LoraConfig, **kwargs: Any) -> None:
dora_layer = DoraConv1dLayer(fan_in_fan_out=False)
_DoraConvNdVariant.init_convd_variant(module, adapter_name, dora_layer=dora_layer)
class DoraConv2dVariant(_DoraConvNdVariant):
@staticmethod
def init(module: Conv2d, adapter_name: str, config: LoraConfig, **kwargs: Any) -> None:
dora_layer = DoraConv2dLayer(fan_in_fan_out=False)
_DoraConvNdVariant.init_convd_variant(module, adapter_name, dora_layer=dora_layer)
class DoraConv3dVariant(_DoraConvNdVariant):
@staticmethod
def init(module: Conv3d, adapter_name: str, config: LoraConfig, **kwargs: Any) -> None:
dora_layer = DoraConv3dLayer(fan_in_fan_out=False)
_DoraConvNdVariant.init_convd_variant(module, adapter_name, dora_layer=dora_layer)
class QALoraLinearVariant(LoraVariant):
@staticmethod
def init(module: Linear, adapter_name: str, config: LoraConfig, **kwargs: Any) -> None:
"""
Initializes QALoRA specific parameters for a given adapter.
Args:
module (Linear): The linear module to be adapted.
adapter_name (str): The name of the adapter.
config (LoraConfig): The config of the LoRA adapter.
**kwargs: Additional keyword arguments.
"""
qalora_group_size = config.qalora_group_size
if module.in_features is not None and module.in_features % qalora_group_size != 0:
raise ValueError(
f"`use_qalora=True` requires `module.in_features` ({module.in_features}) to be"
f"divisible by 'qalora_group_size' ({qalora_group_size})"
)
if "qalora_group_size" not in module.other_param_names:
module.other_param_names = module.other_param_names + ("qalora_group_size",)
if not hasattr(module, "qalora_group_size"):
module.qalora_group_size = {}
module.qalora_group_size[adapter_name] = qalora_group_size
old_lora_A_layer = module.lora_A[adapter_name]
r = old_lora_A_layer.out_features
device = old_lora_A_layer.weight.device
dtype = old_lora_A_layer.weight.dtype
new_lora_A_layer = nn.Linear(
old_lora_A_layer.in_features // module.qalora_group_size[adapter_name],
r,
bias=False,
device=device,
dtype=dtype,
)
module.lora_A[adapter_name] = new_lora_A_layer
@staticmethod
def get_delta_weight(module: Linear, active_adapter: str) -> torch.Tensor:
raise NotImplementedError("QALoRA for GPTQ layers does not support 'get_delta_weight'.")
@staticmethod
def merge_safe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
raise NotImplementedError("QALoRA for GPTQ layers does not support 'safe_merge'.")
@staticmethod
def merge_unsafe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> None:
raise NotImplementedError("QALoRA for GPTQ layers does not support 'merge_unsafe'.")
@staticmethod
def unmerge(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
raise NotImplementedError("QALoRA for GPTQ layers does not support 'unmerge'.")
@staticmethod
def forward(
module: Linear,
active_adapter: str,
x: torch.Tensor,
result: torch.Tensor,
**kwargs,
) -> torch.Tensor:
lora_A_weight = module.lora_A[active_adapter].weight
lora_B_weight = module.lora_B[active_adapter].weight
dropout = module.lora_dropout[active_adapter]
scaling = module.scaling[active_adapter]
group_size = module.qalora_group_size[active_adapter]
x_dropped = dropout(x) if module.training and not isinstance(dropout, nn.Identity) else x
orig_shape = x_dropped.shape
# Reshape to 2D
if len(orig_shape) > 2:
x_flat = x_dropped.view(-1, module.in_features)
else:
x_flat = x_dropped
batch_size, in_features = x_flat.shape
pooled_features = in_features // group_size
x_pooled = x_flat.view(batch_size, pooled_features, group_size).mean(dim=2)
x_pooled_scaled = x_pooled * pooled_features
# LoRA computation
delta = x_pooled_scaled @ lora_A_weight.t() @ lora_B_weight.t() * scaling
# Reshape back
if len(orig_shape) > 2:
delta = delta.view(orig_shape[:-1] + (delta.size(-1),))
return result + delta
class ALoraLinearVariant(LoraVariant):
@staticmethod
def init(module: Linear, adapter_name: str, config: LoraConfig, **kwargs: Any) -> None:
pass
@staticmethod
def merge_safe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
raise NotImplementedError("aLoRA does not support safe merging.")
@staticmethod
def merge_unsafe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> None:
raise NotImplementedError("aLoRA does not support merging.")
@staticmethod
def unmerge(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
raise NotImplementedError("aLoRA does not support unmerging.")
@staticmethod
def forward(
module: Linear,
active_adapter: str,
x: torch.Tensor,
result: torch.Tensor,
**kwargs,
) -> torch.Tensor:
alora_offsets = kwargs.get("alora_offsets", None)
lora_A = module.lora_A[active_adapter]
lora_B = module.lora_B[active_adapter]
dropout = module.lora_dropout[active_adapter]
scaling = module.scaling[active_adapter]
x = x.to(lora_A.weight.dtype)
result_shape = result.shape
B = result_shape[0] # batch
if len(result_shape) == 3:
T = result_shape[1] # tokens
else:
T = 1
D = result_shape[-1] # dimensions
Dx = x.shape[-1]
device = result.device
if alora_offsets is None: # use base model only, but ensure 0 gradient
mask = torch.zeros((B, T), dtype=torch.bool)
else:
# If alora_offsets[i] is None, this means that the invocation sequence was not found in the
# input. As a result, the weights should not be activated anywhere (equivalent to base model).
# Convert None -> 0 and clip to T
offsets = torch.tensor(
[0 if o is None else min(int(o), T) for o in alora_offsets],
device=device,
dtype=torch.long,
)
# Mask True on the last `offsets[i]` positions for each row i
pos = torch.arange(T, device=device).unsqueeze(0) # [1, T]
mask = pos >= (T - offsets).unsqueeze(1)
# Flatten for vectorization
x_flat = x.view(-1, Dx)
res_flat = result.view(-1, D)
mask_flat = mask.view(-1)
# Compute adapter on the selected tokens only
res_flat[mask_flat] += lora_B(lora_A(dropout(x_flat[mask_flat]))) * scaling
return result
def calculate_alora_offsets(
peft_config: PeftConfig, active_adapter: str, input_ids: torch.Tensor, adapter_names: Optional[list[str]] = None
) -> list[int]:
"""
This is a helper function for Activated LoRA (aLoRA) that searches each input token sequence for the last
occurrence of the appropriate "alora_invocation_tokens" invocation sequence. The calculated alora_offset is the
location of the *start* of the invocation tokens, counting backward from the end (will therefore always be >=
len(alora_invocation_tokens). If adapter_names is passed, then each input uses the appropriate invocation sequence
for the specified adapter for that row. Logic is provided to handle mixed collections of adapters for which not all
are aLoRAs (e.g. some base model, some LoRA).
"""
if input_ids is None:
return []
batch_size = input_ids.shape[0]
alora_offsets = [None] * batch_size
cached_invocation_tensors = {}
adapters_to_process_indices = collections.defaultdict(list)
for i in range(batch_size):
current_adapter_name = adapter_names[i] if adapter_names and i < len(adapter_names) else active_adapter
if current_adapter_name == "__base__":
alora_offsets[i] = None
continue
if current_adapter_name not in peft_config:
warnings.warn(f"Adapter '{current_adapter_name}' not found in peft_config. Using base model for row {i}.")
alora_offsets[i] = None
continue
current_peft_config = peft_config[current_adapter_name]
invocation_tokens = getattr(current_peft_config, "alora_invocation_tokens", None)
if invocation_tokens is None:
alora_offsets[i] = None # Not an aLoRA adapter or wrong type
continue
if current_adapter_name not in cached_invocation_tensors:
cached_invocation_tensors[current_adapter_name] = torch.tensor(
invocation_tokens, dtype=torch.long, device=input_ids.device
)
adapters_to_process_indices[current_adapter_name].append(i)
for adapter_name_to_process, indices in adapters_to_process_indices.items():
current_invocation_ids_tensor = cached_invocation_tensors[adapter_name_to_process]
invocation_len = len(current_invocation_ids_tensor)
for i in indices:
sequence = input_ids[i]
seq_len = len(sequence)
best_match_start_idx = -1
possible_starts = (sequence == current_invocation_ids_tensor[0]).nonzero(as_tuple=True)[0]
for start_idx_tensor in possible_starts:
idx = start_idx_tensor.item()
if idx + invocation_len <= seq_len:
if torch.equal(sequence[idx : idx + invocation_len], current_invocation_ids_tensor):
if idx > best_match_start_idx:
best_match_start_idx = idx
if best_match_start_idx != -1:
offset_val = seq_len - best_match_start_idx
alora_offsets[i] = offset_val if offset_val > 0 else None
else: # Invocation sequence not found in input
alora_offsets[i] = None
return alora_offsets
def is_alora_relevant_in_batch(model: nn.Module, adapter_names: Optional[list[str]] = None):
"""
Helper function to determine if the current batch has any aLoRA adapters.
"""
is_alora_relevant = False
if getattr(model.active_peft_config, "alora_invocation_tokens", None):
is_alora_relevant = True
elif adapter_names:
for name in adapter_names:
if name == "__base__":
continue
config_ = model.peft_config.get(name)
if config_ and getattr(config_, "alora_invocation_tokens", None):
is_alora_relevant = True
break
return is_alora_relevant
def get_alora_offsets_for_forward(
model: nn.Module, input_ids: torch.Tensor = None, inputs_embeds: torch.Tensor = None, **kwargs
):
"""
Wrapper around calculate_alora_offsets, for the .forward of the model. It only calculates alora_offsets if the
batch contains aLoRA adapters.
"""
adapter_names_for_offset_calc = kwargs.get("adapter_names", None)
if not is_alora_relevant_in_batch(model, adapter_names_for_offset_calc):
# Nothing to compute
return kwargs
alora_offsets = kwargs.get("alora_offsets")
if alora_offsets is None:
if input_ids is None and inputs_embeds is not None:
warnings.warn(
"Cannot calculate aLoRA offsets when only inputs_embeds are provided. Disabling aLoRA for this forward pass."
)
kwargs["alora_offsets"] = None
elif input_ids is not None:
kwargs["alora_offsets"] = calculate_alora_offsets(
model.peft_config,
model.active_adapter,
input_ids,
adapter_names=adapter_names_for_offset_calc,
)
else:
kwargs["alora_offsets"] = None
return kwargs
def get_alora_offsets_for_generate(model: nn.module, *args, **kwargs):
"""
Wrapper around calculate_alora_offsets, for the .generate of the model. It only calculates alora_offsets if the
batch contains aLoRA adapters.
"""
adapter_names_for_offset_calc = kwargs.get("adapter_names")
if not is_alora_relevant_in_batch(model, adapter_names_for_offset_calc):
# Nothing to compute
return kwargs
alora_offsets_from_kwargs = kwargs.get("alora_offsets")
if alora_offsets_from_kwargs is None:
current_input_ids = kwargs.get("input_ids")
if current_input_ids is None: # args[0] is usually input_ids
if args and isinstance(args[0], torch.Tensor):
current_input_ids = args[0]
else:
current_input_ids = None
if current_input_ids is not None:
if current_input_ids.ndim == 1:
current_input_ids = current_input_ids.unsqueeze(0)
calculated_offsets = calculate_alora_offsets(
model.peft_config,
model.active_adapter,
current_input_ids,
adapter_names=adapter_names_for_offset_calc,
)
kwargs["alora_offsets"] = calculated_offsets
else:
warnings.warn(
"Cannot calculate aLoRA offsets during generate as input_ids are not available. Disabling aLoRA."
)
kwargs["alora_offsets"] = None
return kwargs
class BlockDiagonalLinear(nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
nblocks: int,
init_zero: bool = False,
dtype: torch.dtype = torch.float32,
device: torch.device = torch.device("cpu"),
):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.nblocks = nblocks
if self.in_features % nblocks != 0 or self.out_features % nblocks != 0:
raise ValueError(
f"self.in_features={self.in_features} or self.out_features={self.out_features} not divisible by {self.nblocks}"
)
# Create weight with specified dtype and device
self.weight = nn.Parameter(torch.empty(out_features, in_features // nblocks, dtype=dtype, device=device))
if init_zero:
torch.nn.init.zeros_(self.weight)
else:
torch.nn.init.kaiming_uniform_(self.weight)
def forward(self, x: torch.Tensor) -> torch.Tensor:
first_dims = x.shape[:-1]
if x.dim() != 2:
x = x.reshape(-1, x.shape[-1])
B = x.shape[0]
nb = self.nblocks
m = x.shape[-1] // nb
n = self.out_features // nb
x = x.reshape(B, nb, m)
w = self.weight.view(nb, n, m)
out = torch.einsum("bim,inm->bin", x, w)
return out.reshape(*first_dims, -1)
def weight_as_blockdiagonal_matrix(self):
"""Returns weight in a format similar to a vanilla LoRA adapter. For this, we stack the blocks on the diagonal,
leaving the off-diagonals padded with zero."""
return torch.block_diag(*torch.chunk(self.weight, self.nblocks, dim=0))
class BdLoraLinearVariant(LoraVariant):
@staticmethod
def init(module: Linear, adapter_name: str, config: LoraConfig, **kwargs) -> None:
use_bdlora = config.use_bdlora
target_name = kwargs.get("target_name", "")
# Handle case where use_bdlora is a dict (from saved config) instead of BdLoraConfig object
if isinstance(use_bdlora, dict):
use_bdlora = BdLoraConfig(**use_bdlora)
lora_a_blockdiagonal_pattern = use_bdlora.target_modules_bd_a or []
lora_b_blockdiagonal_pattern = use_bdlora.target_modules_bd_b or []
nblocks = use_bdlora.nblocks
has_lora_a_blockdiagonal = any(pattern in target_name for pattern in lora_a_blockdiagonal_pattern)
has_lora_b_blockdiagonal = any(pattern in target_name for pattern in lora_b_blockdiagonal_pattern)
if has_lora_a_blockdiagonal and has_lora_b_blockdiagonal:
raise ValueError(f"Target {target_name} matches both A and B block-diagonal patterns")
if use_bdlora.match_strict and not (has_lora_a_blockdiagonal or has_lora_b_blockdiagonal):
raise ValueError(
f"Target {target_name} matches neither A nor B block-diagonal patterns."
"If this is intentional, set match_strict=False in BdLoraConfig during initialization. "
)
if has_lora_a_blockdiagonal:
r = module.lora_A[adapter_name].out_features
base_layer = module.get_base_layer()
layer = BlockDiagonalLinear(
base_layer.in_features,
r,
nblocks=nblocks,
init_zero=False,
dtype=base_layer.weight.dtype,
device=base_layer.weight.device,
)
module.lora_A[adapter_name] = layer
elif has_lora_b_blockdiagonal:
r = module.lora_B[adapter_name].in_features
base_layer = module.get_base_layer()
layer = BlockDiagonalLinear(
r,
base_layer.out_features,
nblocks=nblocks,
init_zero=True,
dtype=base_layer.weight.dtype,
device=base_layer.weight.device,
)
module.lora_B[adapter_name] = layer
@staticmethod
def forward(module: Linear, active_adapter: str, x: torch.Tensor, result: torch.Tensor, **kwargs) -> torch.Tensor:
lora_A = module.lora_A[active_adapter]
lora_B = module.lora_B[active_adapter]
dropout = module.lora_dropout[active_adapter]
scaling = module.scaling[active_adapter]
x = dropout(x)
# Cast input dtype to match lora_A weight dtype
x = module._cast_input_dtype(x, lora_A.weight.dtype)
result += lora_B(lora_A(x)) * scaling
return result
@staticmethod
def _get_weight_from_module_maybe_blockdiagonal(module: nn.Module) -> torch.Tensor:
if isinstance(module, BlockDiagonalLinear):
return module.weight_as_blockdiagonal_matrix()
else:
return module.weight # type: ignore
@staticmethod
def _get_bdlora_delta_weight(module: Linear, adapter: str) -> torch.Tensor:
"""Similar to get_delta_weight for a linear module, but we have to eventually reshape the blocks
of the weights."""
device = module.lora_B[adapter].weight.device
# Use base layer dtype to ensure compatibility with merge/unmerge operations
base_layer = module.get_base_layer()
dtype = base_layer.weight.dtype
cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)
weight_A = BdLoraLinearVariant._get_weight_from_module_maybe_blockdiagonal(module.lora_A[adapter])
weight_B = BdLoraLinearVariant._get_weight_from_module_maybe_blockdiagonal(module.lora_B[adapter])
if cast_to_fp32:
weight_A = weight_A.float()
weight_B = weight_B.float()
output_tensor = transpose(weight_B @ weight_A, module.fan_in_fan_out) * module.scaling[adapter]
if cast_to_fp32:
output_tensor = output_tensor.to(dtype=dtype)
# Ensure output tensor matches base layer dtype
return output_tensor.to(dtype=dtype)
@staticmethod
def merge_safe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
return orig_weight + BdLoraLinearVariant._get_bdlora_delta_weight(module, active_adapter)
@staticmethod
def merge_unsafe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> None:
orig_weight.data += BdLoraLinearVariant._get_bdlora_delta_weight(module, active_adapter)
@staticmethod
def unmerge(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
return orig_weight - BdLoraLinearVariant._get_bdlora_delta_weight(module, active_adapter)