Files
lora-lite/docs/refs/peft_lora_layer.py
wassname fdb4c77d6c Add reference-impl URLs to variant docstrings + V2 external review
- Fetch canonical reference impls for offline review:
  * peft_{lora,hra,delora,ia3}_layer.py + peft_lora_{dora,variants}.py
  * orig_pissa_init.py (MuLabPKU/PiSSA)
  * orig_hra_layer.py (DaShenZi721/HRA)
  * orig_delora.py (ExplainableML/DeLoRA author fork)
- Add reference-impl URLs to all 6 variant docstrings
- Document HRA gate=0 dead-grad issue and DoRA detach-omission in their docstrings
- Re-run external review (codex) with refs available -> docs/audit/variants_review_v2.md
  Major NEW findings vs paper-only review:
    * DeLoRA: scalar W.norm() should be per-input-channel norm(dim=0)
    * HRA: PEFT uses symmetric repeated-column init (no dead grad), not zero gate
    * IA3: FFN targets need input-side gating, not output, our up_proj advice wrong
    * All LoRA-family: cfg.dropout silently ignored (no-op)
    * DeLoRA: wnorm should be persistent buffer, not Parameter
  HRA and DeLoRA upgraded to BUGGY (from Partial)
2026-04-26 19:27:47 +08:00

2511 lines
111 KiB
Python

# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import copy
import math
import warnings
from collections.abc import Callable
from contextlib import contextmanager
from typing import Any, Optional, Union
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import svd_lowrank
from transformers.pytorch_utils import Conv1D
from peft.import_utils import is_transformers_ge_v5_4_0
from peft.tuners._buffer_dict import BufferDict
from peft.tuners.tuners_utils import BaseTunerLayer, _get_in_out_features, check_adapters_to_merge
from peft.utils import ALLOWED_COMPUTE_DTYPES, UPCAST_DTYPES
from peft.utils.integrations import (
dequantize_module_weight,
gather_params_ctx,
get_bnb_param_type,
skip_init_on_device,
)
from peft.utils.loftq_utils import loftq_init
from peft.utils.other import transpose
from peft.utils.warning import PeftWarning
from .config import LoraConfig
# Names of forward keyword arguments that are reserved for LoRA variants. `LoraLayer._mixed_batch_forward` pops these
# keys from its kwargs so that they are not passed on to the base layer and forwards them to the variant instead.
VARIANT_KWARG_KEYS = ["alora_offsets"]
class LoraVariant:
    """
    Base class for LoRA variants, e.g. DoRA.

    This class should be subclassed and the methods below should be implemented accordingly. The methods should be
    implemented as static methods, this makes it easier to combine variants.

    Every hook of this base class raises `NotImplementedError`, so a variant that forgets to override one of them
    fails loudly instead of silently doing nothing.

    Note for developers: These methods are prone to change and should thus be considered to be "private". Use at your
    own discretion.
    """

    @staticmethod
    def init(module: LoraLayer, adapter_name: str, **kwargs) -> None:
        """Initialization code for the LoRA variant, it's called within `update_layer`

        Args:
            module (LoraLayer): The layer the adapter is being added to
            adapter_name (str): The name of the adapter that is being initialized
            **kwargs: Extra keyword arguments forwarded by `update_layer` (the `config` and its own `**kwargs`)
        """
        raise NotImplementedError

    @staticmethod
    def merge_safe(module: LoraLayer, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
        """Safe merging of the weights from `merge(..., safe_merge=True)`, should return a new tensor"""
        raise NotImplementedError

    @staticmethod
    def merge_unsafe(module: LoraLayer, active_adapter: str, orig_weight: torch.Tensor) -> None:
        """Unsafe merging of the weights from `merge(..., safe_merge=False)`, should modify the weight in-place"""
        raise NotImplementedError

    @staticmethod
    def unmerge(module: LoraLayer, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
        """Remove the adapter weights from the original weights, then return them"""
        raise NotImplementedError

    @staticmethod
    def forward(
        module: LoraLayer,
        active_adapter: str,
        x: torch.Tensor,
        result: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """
        The forward pass of the LoRA variant, should return the overall result (not just the diff)

        Args:
            module (LoraLayer): The module on which the forward pass is called
            active_adapter (str): The name of the active adapter
            x (torch.Tensor): The input to the forward call
            result (torch.Tensor): The result from the base model
            **kwargs: Additional arguments passed from [`LoraLayer.forward`].
        """
        raise NotImplementedError
class LoraLayer(BaseTunerLayer):
    """Mixin that stores the per-adapter LoRA state (rank, alpha, scaling, dropout, A/B weights) for a base layer."""

    # All names of layers that may contain (trainable) adapter weights
    adapter_layer_names: tuple[str, ...] = ("lora_A", "lora_B", "lora_embedding_A", "lora_embedding_B")
    # All names of other parameters that may contain adapter-related parameters
    other_param_names: tuple[str, ...] = ("r", "lora_alpha", "scaling", "lora_dropout")
def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, **kwargs) -> None:
    """
    Create the (initially empty) per-adapter containers around `base_layer`.

    No adapter is added here; the containers are filled by `update_layer`.

    Args:
        base_layer (`nn.Module`): The module wrapped by this LoRA layer.
        ephemeral_gpu_offload (`bool`): Stored on the instance as `self.ephemeral_gpu_offload`.
        **kwargs: Stored unchanged as `self.kwargs`.
    """
    self.base_layer = base_layer
    # Per-adapter hyperparameters, keyed by adapter name
    self.r = {}
    self.lora_alpha = {}
    self.scaling = {}
    # Per-adapter sub-modules, keyed by adapter name
    self.lora_dropout = nn.ModuleDict({})
    self.lora_A = nn.ModuleDict({})
    self.lora_B = nn.ModuleDict({})
    # For Embedding layer
    self.lora_embedding_A = nn.ParameterDict({})
    self.lora_embedding_B = nn.ParameterDict({})
    # Mark the weight as unmerged
    self._disable_adapters = False
    self.merged_adapters = []
    self.use_dora: dict[str, bool] = {}  # not actively used anymore after #2443, keep it for BC
    self.use_rslora: dict[str, bool] = {}
    self.lora_bias: dict[str, bool] = {}
    self.lora_magnitude_vector = torch.nn.ModuleDict()  # for DoRA
    self._caches: dict[str, Any] = {}  # small ad hoc cache; values are not part of the state_dict
    self.ephemeral_gpu_offload: bool = ephemeral_gpu_offload
    # flag to enable/disable casting of input to weight dtype during forward call
    self.cast_input_dtype_enabled: bool = True
    # Per-adapter LoRA variant (e.g. DoRA); adapters without an entry are treated as vanilla LoRA
    self.lora_variant: dict[str, LoraVariant] = {}
    self.kwargs = kwargs
    # `get_base_layer()` is used instead of the raw argument to determine the feature sizes
    base_layer = self.get_base_layer()
    in_features, out_features = self._get_in_out_features(base_layer)
    self.in_features = in_features
    self.out_features = out_features
def _get_in_out_features(self, module: nn.Module) -> tuple[int, int] | tuple[None, None]:
    """Return the `(in_features, out_features)` pair of `module` as computed by the shared tuner helper."""
    # Inside this method the name resolves to the module-level helper, not to this method.
    feature_sizes = _get_in_out_features(module)
    return feature_sizes
def resolve_lora_variant(self, *, config: LoraConfig, **kwargs) -> Optional[LoraVariant]:
    """Pick the LoRA variant that matches this layer type and `config`.

    Subclasses override this hook to map config options to a variant object, e.g. return the DoRA variant when
    `use_dora=True` or the aLoRA variant when `use_alora=True`. This base implementation knows no variant and
    therefore always answers with `None`, which means "plain LoRA".

    Note: A layer type that cannot support a requested variant at all should, by convention, raise an error in its
    `__init__` and not in this method.
    """
    no_variant = None
    return no_variant
def update_layer(
    self,
    adapter_name: str,
    r: int,
    lora_alpha: int,
    config: LoraConfig,
    **kwargs,
) -> None:
    """
    Add the adapter `adapter_name` to this layer and initialize its weights.

    Stores the hyperparameters, creates the dropout module and the `lora_A`/`lora_B` linear layers, initializes them
    according to `config.init_lora_weights`, moves the adapter to the device of the base layer, runs the variant's
    `init` (if a variant was resolved) and finally calls `set_adapter` with the currently active adapters.

    Args:
        adapter_name (`str`): Key under which the adapter is stored.
        r (`int`): LoRA rank, must be positive.
        lora_alpha (`int`): LoRA alpha; the scaling is `lora_alpha / r`, or `lora_alpha / sqrt(r)` with rsLoRA.
        config (`LoraConfig`): Source of `lora_dropout`, `init_lora_weights`, `use_rslora`, `lora_bias`,
            `inference_mode` and `use_dora`.
        **kwargs: Read for `target_name` and `tied_adapter`; also forwarded to the variant's `init`.

    Raises:
        ValueError: If `r` is not positive.
    """
    # collect the kwargs
    lora_dropout = config.lora_dropout
    init_lora_weights = config.init_lora_weights
    use_rslora = config.use_rslora
    lora_bias = config.lora_bias
    inference_mode = config.inference_mode
    # NOTE(review): `kwargs` is not overwritten anywhere in this method, so the next two lines only guarantee that
    # `target_name` is always present in `kwargs` (defaulting to "") when they are forwarded to the variant's init.
    target_name = kwargs.get("target_name", "")
    kwargs["target_name"] = target_name
    tied_adapter = kwargs.get("tied_adapter", None)
    # This code works for linear layers, override for other layer types
    if r <= 0:
        raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
    if lora_bias and (getattr(self.get_base_layer(), "bias", None) is None):
        warnings.warn(
            f"`lora_bias=True` was passed but the targeted layer of type {type(self.get_base_layer()).__name__} "
            "has no bias. This means that merging LoRA weights won't be possible.",
            PeftWarning,
        )
    # Only adapters with a resolved variant get an entry in `self.lora_variant`
    lora_variant = self.resolve_lora_variant(config=config)
    if lora_variant is not None:
        self.lora_variant[adapter_name] = lora_variant
    self.r[adapter_name] = r
    self.lora_alpha[adapter_name] = lora_alpha
    if lora_dropout > 0.0:
        lora_dropout_layer = nn.Dropout(p=lora_dropout)
    else:
        lora_dropout_layer = nn.Identity()
    self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))
    # Actual trainable parameters
    self.lora_A[adapter_name] = nn.Linear(self.in_features, r, bias=False)
    self.lora_B[adapter_name] = nn.Linear(r, self.out_features, bias=lora_bias)
    # Tying adapters is only implemented for Linear layers
    # where the source is the embedding layer.
    # Currently, this is the most prevalent way of tying layers (weight tying)
    if tied_adapter:
        lora_A_params = tied_adapter["lora_A"]
        lora_B_params = tied_adapter["lora_B"]
        self.lora_A[adapter_name].weight = torch.nn.Parameter(lora_A_params)
        self.lora_B[adapter_name].weight = torch.nn.Parameter(lora_B_params)
    self.lora_bias[adapter_name] = lora_bias
    if use_rslora:
        self.scaling[adapter_name] = lora_alpha / math.sqrt(r)
    else:
        self.scaling[adapter_name] = lora_alpha / r
    self.use_rslora[adapter_name] = use_rslora
    self.use_dora[adapter_name] = config.use_dora
    # for inits that require access to the base weight, use gather_param_ctx so that the weight is gathered when using DeepSpeed
    if isinstance(init_lora_weights, str) and init_lora_weights.startswith("pissa"):
        with gather_params_ctx(self.get_base_layer().weight):
            self.pissa_init(adapter_name, init_lora_weights)
    elif isinstance(init_lora_weights, str) and init_lora_weights.startswith("corda"):
        with gather_params_ctx(self.get_base_layer().weight):
            self.corda_init(adapter_name, init_lora_weights)
    elif isinstance(init_lora_weights, str) and init_lora_weights.lower() == "olora":
        with gather_params_ctx(self.get_base_layer().weight):
            self.olora_init(adapter_name)
    elif init_lora_weights == "loftq":
        with gather_params_ctx(self.get_base_layer().weight):
            self.loftq_init(adapter_name, config)
    elif init_lora_weights == "eva":
        # only B is zeroed here; A keeps the default `nn.Linear` initialization in this method
        nn.init.zeros_(self.lora_B[adapter_name].weight)
    elif init_lora_weights == "orthogonal":
        with gather_params_ctx(self.get_base_layer().weight):
            self.orthogonal_init(adapter_name)
    elif init_lora_weights == "lora_ga":
        with gather_params_ctx(self.get_base_layer().weight):
            self.lora_ga_init(adapter_name, config.lora_ga_config)
    elif init_lora_weights:
        # any other truthy value (e.g. True or "gaussian") is handled by the generic reset
        self.reset_lora_parameters(adapter_name, init_lora_weights)
    # call this before init of the lora variants
    self._move_adapter_to_device_of_base_layer(adapter_name)
    if adapter_name in self.lora_variant:
        self.lora_variant[adapter_name].init(self, adapter_name=adapter_name, config=config, **kwargs)
    self.set_adapter(self.active_adapters, inference_mode=inference_mode)
    # Check for adapters that were added or removed from the arrow_model.
    # The arrow model may be modified after creation by adding new experts
    # (pre-trained or trainable) or by removing existing ones. Whenever such
    # a change occurs, on_adapter_change() is called to update the set of
    # active task-specific experts and, if needed, to handle recomputing prototypes
    # and doing general knowledge subtraction (GKS) again.
    if hasattr(self, "lora_arrow"):
        for adapter in self.lora_variant:
            if adapter in self.lora_arrow:
                self.lora_arrow[adapter].on_adapter_change(self.lora_A, self.lora_B)
def reset_lora_parameters(self, adapter_name, init_lora_weights):
    """
    (Re-)initialize the weights of `adapter_name` and synchronize them across tensor-parallel ranks.

    Args:
        adapter_name: Name of the adapter whose weights are initialized.
        init_lora_weights: `False` skips the initialization, `True` uses kaiming-uniform for A, `"gaussian"`
            (case-insensitive) uses a normal distribution with `std=1/r` for A. In both cases B (and its bias, if
            any) is set to zero.

    Raises:
        ValueError: If `init_lora_weights` is neither `False`, `True` nor `"gaussian"` and the adapter has a
            `lora_A` entry.
    """
    if init_lora_weights is not False:
        if adapter_name in self.lora_A.keys():
            if init_lora_weights is True:
                # initialize A the same way as the default for nn.Linear and B to zero
                # https://github.com/microsoft/LoRA/blob/a0a92e0f26c067cf94747bdbf1ce73793fa44d19/loralib/layers.py#L124
                nn.init.kaiming_uniform_(self.lora_A[adapter_name].weight, a=math.sqrt(5))
            elif init_lora_weights.lower() == "gaussian":
                nn.init.normal_(self.lora_A[adapter_name].weight, std=1 / self.r[adapter_name])
            else:
                raise ValueError(f"Unknown initialization {init_lora_weights=}")
            nn.init.zeros_(self.lora_B[adapter_name].weight)
            if self.lora_bias[adapter_name]:
                nn.init.zeros_(self.lora_B[adapter_name].bias)
        if adapter_name in self.lora_embedding_A.keys():
            # Initialize A to zeros and B the same way as the default for nn.Embedding, see:
            # https://github.com/microsoft/LoRA/blob/4c0333854cb905966f8cc4e9a74068c1e507c7b7/loralib/layers.py#L59-L60
            nn.init.zeros_(self.lora_embedding_A[adapter_name])
            nn.init.normal_(self.lora_embedding_B[adapter_name])
            if self.lora_bias[adapter_name]:
                # embeddings are not supported at the moment, but still adding this for consistency
                # NOTE(review): `lora_embedding_B` is an `nn.ParameterDict`; a plain `nn.Parameter` entry has no
                # `.bias` attribute, so this line would presumably raise if it were ever reached -- confirm.
                nn.init.zeros_(self.lora_embedding_B[adapter_name].bias)
    # Always synchronize non-sharded LoRA weights across TP ranks, regardless of
    # init_lora_weights, since each rank initializes weights independently.
    # We could skip some broadcast, for instance when the lora weights are initialized to zero,
    # but this is a minor optimization and would add extra complexity to the code.
    if dist.is_available() and dist.is_initialized():
        base_layer = self.get_base_layer()
        tp_plan = getattr(base_layer, "_hf_tp_plan", None)
        device_mesh = getattr(base_layer, "_hf_device_mesh", None)
        if device_mesh is not None:
            if tp_plan not in ["colwise", "rowwise", "embedding_rowwise"]:
                # only a warning: execution continues, but none of the broadcast branches below will match
                warnings.warn(
                    f'tp_plan "{tp_plan}" found on the base layer is not supported for LoRA weight '
                    'synchronization. Expected one of "colwise", "rowwise", "embedding_rowwise". '
                    "LoRA weights may not be synchronized across ranks."
                )
            pg = device_mesh.get_group()
            # broadcast from the first rank of the process group
            src = dist.get_global_rank(pg, 0)
            if tp_plan == "colwise" and adapter_name in self.lora_A:
                # The adapter weights need to be on device for broadcasting
                self._move_adapter_to_device_of_base_layer(adapter_name)
                dist.broadcast(self.lora_A[adapter_name].weight.data, src=src, group=pg)
            elif tp_plan == "rowwise" and adapter_name in self.lora_B:
                # The adapter weights need to be on device for broadcasting
                self._move_adapter_to_device_of_base_layer(adapter_name)
                dist.broadcast(self.lora_B[adapter_name].weight.data, src=src, group=pg)
            elif tp_plan == "embedding_rowwise" and adapter_name in self.lora_embedding_B:
                self._move_adapter_to_device_of_base_layer(adapter_name)
                dist.broadcast(self.lora_embedding_B[adapter_name].data, src=src, group=pg)
def olora_init(self, adapter_name):
    """
    Initialize the adapter from a QR decomposition of the base weight (OLoRA).

    The first `r` columns of Q become `lora_B`, the first `r` rows of R become `lora_A`, and the scaled product
    `scaling * B @ A` is subtracted from the base weight. bitsandbytes-quantized weights are dequantized first and
    wrapped in their original parameter class again afterwards.

    Raises:
        TypeError: If the base weight is neither a bitsandbytes parameter nor float32/float16/bfloat16.
    """
    base_layer = self.get_base_layer()
    orig_weight = base_layer.weight
    bnb_param_type = get_bnb_param_type(orig_weight)
    dtype = orig_weight.dtype
    if bnb_param_type:
        # check without importing bitsandbytes and robust to bnb_4bit_quant_storage=float*
        weight_tensor = dequantize_module_weight(base_layer)
    elif dtype in [torch.float32, torch.float16, torch.bfloat16]:
        weight_tensor = orig_weight
    else:
        raise TypeError(f"Unsupported data type for the base layer. Got {dtype}.")
    scale_factor = self.scaling[adapter_name]
    r = self.r[adapter_name]
    # the decomposition is done in float32
    weight_tensor = weight_tensor.to(torch.float32)
    Q, R = torch.linalg.qr(weight_tensor.data)
    Qr, Rr = Q[:, :r], R[:r]
    self.lora_A[adapter_name].weight.data = Rr.contiguous()
    self.lora_B[adapter_name].weight.data = Qr.contiguous()
    # remove the part that is now represented by the adapter from the (float32) weight
    weight_tensor.data -= scale_factor * self.lora_B[adapter_name].weight @ self.lora_A[adapter_name].weight
    if bnb_param_type == "4bit":
        # re-create the 4bit parameter with the quantization settings of the original one
        weight_tensor = orig_weight.__class__(
            weight_tensor,
            quant_type=orig_weight.quant_type,
            quant_storage=orig_weight.quant_storage,
            compress_statistics=orig_weight.compress_statistics,
            module=orig_weight.module,
        ).to(orig_weight.device)
        base_layer.weight = weight_tensor
    elif bnb_param_type == "8bit":
        # re-create the 8bit parameter with the settings of the original one
        weight_tensor = orig_weight.__class__(
            weight_tensor,
            requires_grad=orig_weight.requires_grad,
            has_fp16_weights=orig_weight.has_fp16_weights,
        ).to(orig_weight.device)
        base_layer.weight = weight_tensor
    else:
        # plain float weight: cast back to the original dtype
        weight_tensor = weight_tensor.to(dtype)
        base_layer.weight.data = weight_tensor
def pissa_init(self, adapter_name, init_lora_weights):
    """
    Initialize the adapter from the principal singular components of the base weight (PiSSA).

    `init_lora_weights="pissa"` runs a full SVD, `"pissa_niter_<n>"` runs `svd_lowrank` with `n` iterations. The top
    `r` singular values are divided by the adapter scaling and their square root is folded into both factors. The
    base weight is then replaced by the residual `W - scaling * B @ A`, cast back to its original dtype.

    Raises:
        TypeError: If the base weight is not float32, float16 or bfloat16.
        ValueError: If `init_lora_weights` is neither `"pissa"` nor of the form `"pissa_niter_<n>"`.
    """
    base_weight = self.get_base_layer().weight
    orig_dtype = base_weight.dtype
    if orig_dtype not in (torch.float32, torch.float16, torch.bfloat16):
        raise TypeError(
            "Please initialize PiSSA under float32, float16, or bfloat16. "
            "Subsequently, re-quantize the residual model to help minimize quantization errors."
        )

    # work on a float32 copy laid out as (out_channel, in_channel)
    w32 = transpose(base_weight.to(torch.float32), self.fan_in_fan_out)
    rank = self.r[adapter_name]
    niter_parts = init_lora_weights.split("_niter_")

    if init_lora_weights == "pissa":
        # USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel},
        left, singular, right_h = torch.linalg.svd(w32.data, full_matrices=False)
        left_r = left[:, :rank]
        singular_r = singular[:rank]
        right_hr = right_h[:rank]
    elif len(niter_parts) == 2:
        left_r, singular_r, right_r = svd_lowrank(w32.data, rank, niter=int(niter_parts[-1]))
        right_hr = right_r.t()
    else:
        raise ValueError(
            f"init_lora_weights should be 'pissa' or 'pissa_niter_[number of iters]', got {init_lora_weights} instead."
        )

    # compensate for the scaling that is applied to B @ A
    singular_r /= self.scaling[adapter_name]
    root_s = torch.sqrt(singular_r)
    lora_A = torch.diag(root_s) @ right_hr
    lora_B = left_r @ torch.diag(root_s)
    self.lora_A[adapter_name].weight.data = lora_A
    self.lora_B[adapter_name].weight.data = lora_B

    residual = w32.data - self.scaling[adapter_name] * lora_B @ lora_A
    self.get_base_layer().weight.data = transpose(residual.to(orig_dtype), self.fan_in_fan_out)
def corda_init(self, adapter_name, init_lora_weights):
    """
    Initialize the adapter from the precomputed CorDA decomposition stored on the base layer.

    Reads `U_WC`, `S_WC` and `V_WC` from the base layer's `eigens` attribute, validates them, builds `lora_A` and
    `lora_B` from them, subtracts `scaling * B @ A` from the base weight and finally deletes `eigens` from the base
    layer. `init_lora_weights` is not used inside this method.

    Raises:
        TypeError: If the base weight is not float32, float16 or bfloat16.
        ValueError: If `eigens` is missing, contains NaN/inf values, or its shapes do not match the layer and rank.
    """
    linear = self.get_base_layer()
    weight = linear.weight
    dtype = weight.dtype
    if dtype not in [torch.float32, torch.float16, torch.bfloat16]:
        raise TypeError(
            "Please initialize CorDA under float32, float16, or bfloat16. "
            "Subsequently, re-quantize the residual model to help minimize quantization errors."
        )
    weight = weight.to(torch.float32)
    # For Conv1D, weight is stored as (in_features, out_features), transposed compared to Linear
    if isinstance(linear, Conv1D):
        out_dim = weight.data.size(1)
        in_dim = weight.data.size(0)
    else:
        out_dim = weight.data.size(0)
        in_dim = weight.data.size(1)
    # Calculate WC from covariance matrix
    if not hasattr(linear, "eigens"):
        raise ValueError(
            "`eigens` attribute not found for layer, please run `preprocess_corda` first. "
            "More information can be found at examples/corda_finetuning/README.md."
        )
    eigens = linear.eigens
    U = eigens.U_WC
    S = eigens.S_WC
    V = eigens.V_WC
    r = self.r[adapter_name]
    # nan or inf check
    if torch.isnan(S).any() or torch.isinf(S).any():
        raise ValueError(
            "Invalid value found in matrix S. Please file an issue at https://github.com/huggingface/peft/issues."
        )
    if torch.isnan(U).any() or torch.isinf(U).any():
        raise ValueError(
            "Invalid value found in matrix U. Please file an issue at https://github.com/huggingface/peft/issues."
        )
    if torch.isnan(V).any() or torch.isinf(V).any():
        raise ValueError(
            "Invalid value found in matrix V. Please file an issue at https://github.com/huggingface/peft/issues."
        )
    # Sanity check
    if U.size(0) != out_dim or U.size(1) != r:
        raise ValueError(
            f"Matrix U size mismatch: {U.size()} vs. ({out_dim}, {r}). Please make sure the `lora_config` and "
            "`model` argument of `preprocess_corda` is consistent with `get_peft_model`. If you're using cache "
            "in `preprocess_corda`, please make sure the cache is built with the same model and LoRA rank."
        )
    if S.size(0) != r:
        raise ValueError(
            f"Matrix S size mismatch: {S.size()} vs. ({r},). Please make sure the `lora_config` and `model` argument "
            "of `preprocess_corda` is consistent with `get_peft_model`. If you're using cache in `preprocess_corda`, "
            "please make sure the cache is built with the same model and LoRA rank."
        )
    if V.size(0) != in_dim or V.size(1) != r:
        raise ValueError(
            f"Matrix V size mismatch: {V.size()} vs. ({in_dim}, {r}). Please make sure the `lora_config` and "
            "`model` argument of `preprocess_corda` is consistent with `get_peft_model`. If you're using cache "
            "in `preprocess_corda`, please make sure the cache is built with the same model and LoRA rank."
        )
    # Apply alpha
    # (in-place: this also modifies `eigens.S_WC`, which is deleted at the end of this method)
    S /= self.scaling[adapter_name]
    # Init lora_A and lora_B weights
    lora_A = V.t().mul(S.sqrt().view(-1, 1)).contiguous()
    lora_B = U.mul(S.sqrt()).contiguous()
    self.lora_A[adapter_name].weight.data = lora_A
    self.lora_B[adapter_name].weight.data = lora_B
    # For Conv1D, lora_B @ lora_A gives (out_dim, in_dim) but weight is (in_dim, out_dim)
    # So we need to transpose before subtraction
    delta = self.scaling[adapter_name] * lora_B @ lora_A
    delta = transpose(delta, fan_in_fan_out=self.fan_in_fan_out)
    weight = weight.data - delta
    weight = weight.to(dtype)
    self.get_base_layer().weight.data = weight
    # Remove redundant fields
    del linear.eigens
def loftq_init(self, adapter_name, config: LoraConfig):
    """
    Initialize the adapter and the base weight with the results of the LoftQ helper.

    The helper is called with the base weight, the number of bits and iterations from `config.loftq_config` and the
    adapter rank; it returns a replacement for the base weight together with the A and B factors.

    Args:
        adapter_name: Name of the adapter to initialize.
        config (`LoraConfig`): Provides `loftq_config["loftq_bits"]` and `loftq_config["loftq_iter"]`.
    """
    weight = self.get_base_layer().weight
    kwargs = {
        "num_bits": config.loftq_config["loftq_bits"],
        "reduced_rank": self.r[adapter_name],
        "num_iter": config.loftq_config["loftq_iter"],
    }
    # Inside this method the name `loftq_init` resolves to the module-level helper imported from
    # `peft.utils.loftq_utils`, not to this method.
    qweight, lora_A, lora_B = loftq_init(weight, **kwargs)
    if adapter_name in self.lora_A.keys():
        # use the factors returned by LoftQ as the LoRA A and B weights
        self.lora_A[adapter_name].weight.data = lora_A
        self.lora_B[adapter_name].weight.data = lora_B
    if adapter_name in self.lora_embedding_A.keys():
        # NOTE(review): `lora_embedding_A`/`lora_embedding_B` are `nn.ParameterDict`s; a plain `nn.Parameter` entry
        # has no `.weight` attribute, so these two lines would presumably raise if this branch were ever reached.
        # Whether the factors also need to be transposed for embeddings cannot be told from here -- confirm.
        self.lora_embedding_A[adapter_name].weight.data = lora_A
        self.lora_embedding_B[adapter_name].weight.data = lora_B
    # replace the base weight with the weight returned by LoftQ
    self.get_base_layer().weight.data = qweight
@torch.no_grad()
def orthogonal_init(self, adapter_name):
    """
    Initialize `lora_A` and `lora_B` from the two interleaved halves of a random orthogonal matrix.

    A random `rank x rank` matrix is orthogonalized with QR; every second row starting at row 0 is used to build
    `lora_A`, every second row starting at row 1 is used to build `lora_B`. Both are random projections of those rows
    scaled by 1/10 and stored as new parameters in the dtype of the base weight.

    See https://datta0.github.io/posts/rethink-lora-init/#orthogonal-initialisation

    Raises:
        ValueError: If the LoRA rank is odd.
    """
    rank = self.r[adapter_name]
    if rank % 2:
        raise ValueError(f"Orthogonal initialization requires the LoRA rank to be even, got {rank} instead.")
    half_rank = rank // 2

    ortho, _ = torch.linalg.qr(torch.randn(rank, rank))
    rows_for_a = ortho[::2]  # rows 0, 2, 4, ...
    rows_for_b = ortho[1::2]  # rows 1, 3, 5, ...
    target_dtype = self.get_base_layer().weight.dtype

    # the order of the two random draws is kept: first the one for A, then the one for B
    proj_a = torch.randn(self.in_features, half_rank)
    a_weight = proj_a.mm(rows_for_a).T / 10.0
    proj_b = torch.randn(half_rank, self.out_features)
    b_weight = proj_b.T.mm(rows_for_b) / 10.0

    self.lora_A[adapter_name].weight = nn.Parameter(a_weight.contiguous().to(target_dtype))
    self.lora_B[adapter_name].weight = nn.Parameter(b_weight.contiguous().to(target_dtype))
def lora_ga_init(self, adapter_name, lora_ga_config):
    """
    Initialize LoRA weights using gradient approximation.

    Uses a low-rank SVD of the gradient matrix to initialize the adapter in a way that aligns with the direction of
    full fine-tuning, then subtracts `scaling * B @ A` from the base weight.

    Expects that `preprocess_loraga` has been called before this, which attaches the `_peft_loraga_grad` attribute
    to the base layer. The attribute is deleted again once the initialization is done.

    If the gradient is not available (e.g. when loading from a saved adapter), this falls back to the default
    initialization (`reset_lora_parameters(..., init_lora_weights=True)`). The weights will be overwritten by the
    state_dict anyway.

    Args:
        adapter_name: Name of the adapter to initialize.
        lora_ga_config: Object providing `direction` (one of "ArBr", "A2rBr", "ArB2r", "random"), `scale`
            ("stable", "weight_svd" and "gd_scale" are handled, any other value applies no extra scaling) and
            `stable_gamma`.

    Raises:
        ValueError: If a gradient is attached but `lora_ga_config` is None, or if `direction` is not supported.
    """
    base_layer = self.get_base_layer()

    # Check for gradient attached by preprocess_loraga
    if not hasattr(base_layer, "_peft_loraga_grad"):
        # When loading from saved adapter, gradients won't be available
        # Fall back to the default initialization (weights will be overwritten by state_dict)
        self.reset_lora_parameters(adapter_name, init_lora_weights=True)
        return

    # Check for lora_ga_config
    if lora_ga_config is None:
        raise ValueError(
            "lora_ga_config must be provided when init_lora_weights='lora_ga'. "
            "Please pass lora_ga_config=LoraGAConfig(...) to LoraConfig."
        )

    direction = lora_ga_config.direction
    scale = lora_ga_config.scale
    stable_gamma = lora_ga_config.stable_gamma

    # Fail early and clearly: without this check an unsupported direction would only surface after the SVD as an
    # unbound local variable.
    supported_directions = ("ArBr", "A2rBr", "ArB2r", "random")
    if direction not in supported_directions:
        raise ValueError(f"Unknown LoRA-GA direction {direction!r}, expected one of {supported_directions}.")

    weight = base_layer.weight
    dtype = weight.dtype
    grad = transpose(base_layer._peft_loraga_grad.to(torch.float32), self.fan_in_fan_out)
    r = self.r[adapter_name]

    # torch.svd_lowrank returns (U, S, V) where grad ≈ U @ diag(S) @ V.T
    # So V is shape (in_features, k) and we need V.T which is (k, in_features) for lora_A
    U, S, V = torch.svd_lowrank(grad, q=min(4 * r, min(grad.shape)), niter=4)
    Vh = V.t()
    U = U[:, : 2 * r]
    S = S[: 2 * r]
    Vh = Vh[: 2 * r, :]

    # Select which of the top 2r singular directions go to A (rows of Vh) and to B (columns of U)
    if direction == "ArBr":
        # Alternating: A takes rows at odd indices [1,3,5,7], B takes columns at even indices [0,2,4,6]
        lora_A_weight = Vh[1 : 2 * r : 2, :]  # Shape: (r, in_features)
        lora_B_weight = U[:, 0 : 2 * r : 2]  # Shape: (out_features, r)
        S_B = S[0 : 2 * r : 2]
    elif direction == "A2rBr":
        # A takes second half rows [r:2r], B takes first half columns [:r]
        lora_A_weight = Vh[r : 2 * r, :]  # Shape: (r, in_features)
        lora_B_weight = U[:, :r]  # Shape: (out_features, r)
        S_B = S[:r]
    elif direction == "ArB2r":
        # A takes first half rows [:r], B takes second half columns [r:2r]
        lora_A_weight = Vh[:r, :]  # Shape: (r, in_features)
        lora_B_weight = U[:, r : 2 * r]  # Shape: (out_features, r)
        S_B = S[r : 2 * r]
    else:  # direction == "random"
        indices = torch.randperm(2 * r)[:r]
        lora_A_weight = Vh[indices, :]  # Shape: (r, in_features)
        lora_B_weight = U[:, indices]  # Shape: (out_features, r)
        S_B = S[indices]
    # B additionally carries the singular values that belong to its columns
    lora_B_weight = lora_B_weight @ torch.diag(S_B)

    scaling_factor = self.scaling[adapter_name]
    # NOTE(review): this is dim 0 of the weight as stored; with `fan_in_fan_out=True` that is presumably the input
    # dimension rather than the output dimension -- confirm.
    out_features = weight.shape[0]
    if scale == "stable":
        scale_factor = (out_features**0.25) / (stable_gamma**0.5)
        lora_B_weight = lora_B_weight * scale_factor
    elif scale == "weight_svd":
        weight_data = transpose(weight.data.to(torch.float32), self.fan_in_fan_out)
        _, weight_S, _ = torch.svd_lowrank(weight_data, q=r, niter=4)
        if S_B[0] > 0:
            scale_factor = weight_S[0] / S_B[0]
            lora_B_weight = lora_B_weight * scale_factor
    elif scale == "gd_scale":
        lora_A_weight = lora_A_weight / scaling_factor
        lora_B_weight = lora_B_weight / scaling_factor

    # Convert to target dtype first to ensure weight offset matches adapter precision
    lora_A_weight = lora_A_weight.to(dtype)
    lora_B_weight = lora_B_weight.to(dtype)

    # Assign LoRA weights
    # lora_A should be (r, in_features), lora_B should be (out_features, r)
    self.lora_A[adapter_name].weight.data = lora_A_weight.contiguous()
    self.lora_B[adapter_name].weight.data = lora_B_weight.contiguous()

    # Modify base weights: W_new = W_old - scaling * (B @ A)
    # Important: compute offset in fp32 using dtype-converted weights to match forward pass precision
    weight_data = transpose(weight.data.to(torch.float32), self.fan_in_fan_out)
    weight_offset = scaling_factor * (lora_B_weight.float() @ lora_A_weight.float())
    weight_data = weight_data - weight_offset
    weight_data = transpose(weight_data.to(dtype), self.fan_in_fan_out)
    base_layer.weight.data = weight_data

    # Remove redundant fields
    del base_layer._peft_loraga_grad
def _cache_store(self, key: str, value: Any) -> None:
    """Remember `value` under `key` in the ad hoc cache (used for intermediate values such as DoRA's weight norm)."""
    self._caches.update({key: value})
def _cache_pop(self, key: str) -> Any:
    """Return the value cached under `key` and drop it from the ad hoc cache; a missing key raises `KeyError`."""
    return self._caches.pop(key)
def set_scale(self, adapter: str, scale: float | int) -> None:
    """Set the scaling of `adapter` to `scale` times its initial scaling.

    The initial scaling follows from the configured rank and alpha: `lora_alpha / r`, or `lora_alpha / sqrt(r)` if
    the adapter uses rsLoRA. Adapters that are unknown to this layer are silently ignored.
    """
    if adapter in self.scaling:
        numerator = scale * self.lora_alpha[adapter]
        rank = self.r[adapter]
        denominator = math.sqrt(rank) if self.use_rslora.get(adapter, False) else rank
        self.scaling[adapter] = numerator / denominator
def scale_layer(self, scale: float | int) -> None:
    """Multiply the current scaling of every active adapter that lives on this layer by `scale`."""
    if scale != 1:
        for name in self.active_adapters:
            if name in self.lora_A.keys():
                self.scaling[name] *= scale
def unscale_layer(self, scale: Optional[float | int] = None) -> None:
    """Undo a previous scaling of all active adapters on this layer.

    With a number, the current scaling is divided by `scale`. With `scale=None`, the scaling is reset to its initial
    value, i.e. `lora_alpha / r`, or `lora_alpha / sqrt(r)` for adapters that use rsLoRA.
    """
    for name in self.active_adapters:
        if name in self.lora_A.keys():
            if scale is not None:
                self.scaling[name] = self.scaling[name] / scale
                continue
            if self.use_rslora.get(name, False):
                initial = self.lora_alpha[name] / math.sqrt(self.r[name])
            else:
                initial = self.lora_alpha[name] / self.r[name]
            self.scaling[name] = initial
def _check_forward_args(self, x, *args, **kwargs):
    """Validate the `adapter_names` forward argument against the inputs and the state of this layer.

    Nothing is checked when `adapter_names` is absent or None.

    Raises:
        ValueError: If the number of adapter names differs from `len(x)`, if any adapter is currently merged, or if
            one of the named adapters (other than the "__base__" placeholder) uses DoRA.
    """
    requested = kwargs.get("adapter_names", None)
    if requested is None:
        return

    num_inputs, num_names = len(x), len(requested)
    if num_inputs != num_names:
        raise ValueError(
            "Length of `adapter_names` should be the same as the number of inputs, but got "
            f"{num_names} and {num_inputs} respectively."
        )

    if self.merged:
        # It is unclear what the right behavior would be when adapter_names are passed while adapters are merged,
        # so this is rejected.
        raise ValueError(
            "Cannot pass `adapter_names` when there are merged adapters, please call `unmerge_adapter` first."
        )

    # DoRA is not supported (yet) for mixed batches. "__base__" is only the placeholder for the base model and is
    # therefore excluded from the check.
    real_adapters = set(requested) - {"__base__"}
    if any(self.use_dora.get(name, False) for name in real_adapters):
        raise ValueError("Cannot pass `adapter_names` when DoRA is enabled.")
def _mixed_batch_forward(
    self, x: torch.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any
) -> torch.Tensor:
    """Forward pass for a batch in which each sample may use a different adapter (inference-time feature).

    `adapter_names` holds one adapter name per sample. The base layer is run on the whole batch; afterwards the rows
    of each adapter are selected, passed through that adapter and written back into the result. Rows that belong to
    "__base__" or to an adapter that is not present on this layer keep the base output.
    """
    # variant-only kwargs must not reach the base layer
    variant_kwargs = {}
    for key in VARIANT_KWARG_KEYS:
        variant_kwargs[key] = kwargs.pop(key, None)

    result = self.base_layer(x, *args, **kwargs)
    result_dtype = result.dtype

    # group the batch positions by the adapter they requested
    rows_by_adapter = {name: [] for name in set(adapter_names)}
    for row, name in enumerate(adapter_names):
        rows_by_adapter[name].append(row)

    alora_offsets = variant_kwargs.get("alora_offsets", None)
    for name, rows in rows_by_adapter.items():
        if name == "__base__" or name not in self.lora_A.keys():
            continue

        lora_A = self.lora_A[name]
        lora_B = self.lora_B[name]
        dropout = self.lora_dropout[name]
        scaling = self.scaling[name]

        # take the rows of this adapter and cast them to the dtype of the adapter weights
        sub_batch = x[rows].to(lora_A.weight.dtype)
        if name in self.lora_variant:
            if alora_offsets is not None:
                # hand the variant only the offsets that belong to its rows
                variant_kwargs["alora_offsets"] = [alora_offsets[j] for j in rows]
            variant_output = self.lora_variant[name].forward(
                self,
                active_adapter=name,
                x=sub_batch,
                result=result[rows],
                **variant_kwargs,
                **kwargs,
            )
            # the variant returns the full result for its rows, so it replaces them
            result[rows] = variant_output.to(result_dtype)
        else:  # vanilla LoRA: add the scaled low-rank update to the base output
            lora_delta = lora_B(lora_A(dropout(sub_batch))) * scaling
            result[rows] += lora_delta.to(result_dtype)
    return result
# Below code is based on https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
# and modified to work with PyTorch FSDP
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
class Linear(nn.Module, LoraLayer):
# Lora implemented in a dense layer
def __init__(
    self,
    base_layer,
    adapter_name: str,
    config: LoraConfig,
    r: int = 0,
    lora_alpha: int = 1,
    is_target_conv_1d_layer: bool = False,
    **kwargs,
) -> None:
    """
    Wrap `base_layer` with a LoRA adapter named `adapter_name`.

    Args:
        base_layer: The layer to wrap.
        adapter_name (`str`): Name of the first adapter; it is also set as the active adapter.
        config (`LoraConfig`): Provides `fan_in_fan_out` and is passed on to `update_layer`.
        r (`int`): LoRA rank, passed on to `update_layer`.
            NOTE(review): the default of 0 is presumably rejected by `update_layer` (`r <= 0`), so callers have to
            pass a positive rank -- confirm.
        lora_alpha (`int`): LoRA alpha, passed on to `update_layer`.
        is_target_conv_1d_layer (`bool`): Stored on the instance as `self.is_target_conv_1d_layer`.
        **kwargs: Passed on to both `LoraLayer.__init__` and `update_layer`.
    """
    # zero-argument `super()` initializes the first base class (`nn.Module`); the LoRA state is set up explicitly
    super().__init__()
    LoraLayer.__init__(self, base_layer, **kwargs)
    # set before `update_layer`, which may need it for the weight initialization
    self.fan_in_fan_out = config.fan_in_fan_out
    self._active_adapter = adapter_name
    self.update_layer(
        adapter_name,
        r,
        lora_alpha=lora_alpha,
        config=config,
        **kwargs,
    )
    self.is_target_conv_1d_layer = is_target_conv_1d_layer
def resolve_lora_variant(self, config: LoraConfig, **kwargs) -> Optional[LoraVariant]:
    """Map the options in `config` to the LoRA variant for linear layers, or `None` for vanilla LoRA.

    The options are checked in this order of precedence: Arrow (`arrow_config`), BD-LoRA (`use_bdlora`), aLoRA
    (`alora_invocation_tokens`), DoRA (`use_dora`). The variant classes are imported lazily.
    """
    if config.arrow_config is not None:
        from .variants import ArrowLinearVariant

        return ArrowLinearVariant()

    if config.use_bdlora is not None:
        from .variants import BdLoraLinearVariant

        return BdLoraLinearVariant()

    wants_alora = config.alora_invocation_tokens is not None
    if not (config.use_dora or wants_alora):
        return None

    from .variants import ALoraLinearVariant, DoraLinearVariant

    # aLoRA takes precedence over DoRA when both are requested
    variant_cls = ALoraLinearVariant if wants_alora else DoraLinearVariant
    return variant_cls()
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
    """
    Merge the active adapter weights into the base weights

    Args:
        safe_merge (`bool`, *optional*):
            If True, the merge operation will be performed in a copy of the original weights and check for NaNs
            before merging the weights. This is useful if you want to check if the merge operation will produce
            NaNs. Defaults to `False`.
        adapter_names (`list[str]`, *optional*):
            The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
            to `None`.

    Raises:
        ValueError: With `safe_merge=True`, if the merged weight or bias contains non-finite values.
        RuntimeError: If the adapter has `lora_bias` enabled but the base layer has no bias.
    """
    adapter_names = check_adapters_to_merge(self, adapter_names)
    if not adapter_names:
        # no adapter to merge
        return

    for active_adapter in adapter_names:
        if active_adapter not in self.lora_A.keys():
            continue

        base_layer = self.get_base_layer()
        is_vanilla = active_adapter not in self.lora_variant

        if safe_merge:
            # Work on a copy so that the base weight is only replaced after the finiteness check passed.
            merged_weight = base_layer.weight.data.clone()
            orig_dtype = merged_weight.dtype
            if is_vanilla:
                merged_weight += self.get_delta_weight(active_adapter).to(orig_dtype)
            else:
                merged_weight = self.lora_variant[active_adapter].merge_safe(self, active_adapter, merged_weight)

            if not torch.isfinite(merged_weight).all():
                raise ValueError(
                    f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                )
            base_layer.weight.data = merged_weight

            if self.lora_bias[active_adapter]:
                if getattr(base_layer, "bias", None) is None:
                    raise RuntimeError(
                        "Impossible to merge LoRA with `lora_bias=True` because the base layer has no bias."
                    )
                merged_bias = base_layer.bias + self.lora_B[active_adapter].bias * self.scaling[active_adapter]
                if not torch.isfinite(merged_bias).all():
                    raise ValueError(
                        f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                    )
                base_layer.bias.data = merged_bias.to(orig_dtype)
        else:
            # Fast path: the base weight is modified in place without any checks.
            if is_vanilla:
                base_layer.weight.data += self.get_delta_weight(active_adapter)
            else:
                self.lora_variant[active_adapter].merge_unsafe(self, active_adapter, base_layer.weight)

            if self.lora_bias[active_adapter]:
                if getattr(base_layer, "bias", None) is None:
                    raise RuntimeError(
                        "Impossible to merge LoRA with `lora_bias=True` because the base layer has no bias."
                    )
                base_layer.bias.data += self.lora_B[active_adapter].bias * self.scaling[active_adapter]

        self.merged_adapters.append(active_adapter)
def unmerge(self) -> None:
    """
    This method unmerges all merged adapter layers from the base weights.

    Adapters are processed in reverse merge order (last merged, first unmerged). A warning is emitted when
    nothing is merged.
    """
    if not self.merged:
        warnings.warn("Already unmerged. Nothing to do.")
        return

    while self.merged_adapters:
        active_adapter = self.merged_adapters.pop()
        if active_adapter not in self.lora_A.keys():
            continue

        weight = self.get_base_layer().weight
        if active_adapter in self.lora_variant:
            # The variant computes the restored weight itself.
            weight.data = self.lora_variant[active_adapter].unmerge(self, active_adapter, weight)
        else:  # vanilla LoRA
            target_dtype = weight.dtype
            weight.data -= self.get_delta_weight(active_adapter).to(target_dtype)

        if self.lora_bias[active_adapter]:
            self.get_base_layer().bias.data -= self.lora_B[active_adapter].bias * self.scaling[active_adapter]
def get_delta_weight(self, adapter) -> torch.Tensor:
    """
    Compute the delta weight for the given adapter.

    Args:
        adapter (str):
            The name of the adapter for which the delta weight should be computed.

    Returns:
        `torch.Tensor`: `B @ A` passed through `transpose(..., self.fan_in_fan_out)` and multiplied by the
        adapter's scaling.
    """
    lora_b_weight = self.lora_B[adapter].weight
    device = lora_b_weight.device
    dtype = lora_b_weight.dtype

    # In case users wants to merge the adapter weights that are in
    # (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
    # (b)float16 because some CPUs have slow bf16/fp16 matmuls.
    needs_fp32 = device.type == "cpu" and dtype in (torch.float16, torch.bfloat16)

    factor_a = self.lora_A[adapter].weight
    factor_b = self.lora_B[adapter].weight
    if needs_fp32:
        factor_a = factor_a.float()
        factor_b = factor_b.float()

    delta = transpose(factor_b @ factor_a, self.fan_in_fan_out) * self.scaling[adapter]
    if not needs_fp32:
        return delta

    delta = delta.to(dtype=dtype)
    # cast back the weights
    self.lora_A[adapter].weight.data = factor_a.to(dtype)
    self.lora_B[adapter].weight.data = factor_b.to(dtype)
    return delta
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
    """
    Run the base layer and add the contribution of every active adapter.

    `adapter_names` and the keys listed in `VARIANT_KWARG_KEYS` are removed from `kwargs` before the base layer
    is called. With adapters disabled (unmerging first if needed) or already merged, only the base layer runs.
    When `adapter_names` is given, the call is delegated to `_mixed_batch_forward`.
    """
    self._check_forward_args(x, *args, **kwargs)
    adapter_names = kwargs.pop("adapter_names", None)
    # don't pass these to base_layer
    variant_kwargs = {key: kwargs.pop(key, None) for key in VARIANT_KWARG_KEYS}

    if self.disable_adapters:
        if self.merged:
            self.unmerge()
        return self.base_layer(x, *args, **kwargs)

    if adapter_names is not None:
        return self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **variant_kwargs, **kwargs)

    if self.merged:
        return self.base_layer(x, *args, **kwargs)

    result = self.base_layer(x, *args, **kwargs)
    output_dtype = result.dtype
    known_adapters = self.lora_A.keys()

    for active_adapter in self.active_adapters:
        if active_adapter not in known_adapters:
            continue

        lora_A = self.lora_A[active_adapter]
        lora_B = self.lora_B[active_adapter]
        dropout = self.lora_dropout[active_adapter]
        scaling = self.scaling[active_adapter]
        x = self._cast_input_dtype(x, lora_A.weight.dtype)

        if active_adapter in self.lora_variant:
            result = self.lora_variant[active_adapter].forward(
                self,
                active_adapter=active_adapter,
                x=x,
                result=result,
                **variant_kwargs,
                **kwargs,
            )
        else:  # vanilla LoRA
            result = result + lora_B(lora_A(dropout(x))) * scaling

    # Return in the dtype the base layer produced.
    return result.to(output_dtype)
def supports_lora_conversion(self, adapter_name: str = "default") -> bool:
    """Report that this layer supports LoRA conversion; the answer is `True` for every `adapter_name`."""
    return True
def __repr__(self) -> str:
    """Return the parent representation prefixed with `"lora."`."""
    return "lora." + super().__repr__()
class _LoraEmbeddingAHolder(nn.Module):
"""
A "fake" module to hold the lora_embedding_A weights for the TP hooks.
"""
def __init__(self, lora_embedding_A_weight):
super().__init__()
self.weight = lora_embedding_A_weight.T # lora_embedding_A shape is (r, vocab_size)
class Embedding(nn.Module, LoraLayer):
# LoRA implemented in a Embedding layer
def __init__(
    self,
    base_layer: nn.Module,
    adapter_name: str,
    config: LoraConfig,
    r: int = 0,
    lora_alpha: int = 1,
    init_lora_weights: Union[bool, str] = True,
    **kwargs,
) -> None:
    """
    Wrap an embedding `base_layer` with LoRA and register the first adapter.

    Args:
        base_layer (`nn.Module`): The embedding layer that is being adapted.
        adapter_name (`str`): Name of the adapter created during construction.
        config (`LoraConfig`): Adapter configuration.
        r (`int`): LoRA rank, forwarded to `update_layer`.
        lora_alpha (`int`): LoRA alpha, forwarded to `update_layer`.
        init_lora_weights (`bool` or `str`): Accepted but not read here; `update_layer` takes the value from
            `config`.

    Raises:
        ValueError: If `config.lora_bias` is set, or if the base layer carries a tensor parallel plan other than
            `"embedding_rowwise"`.
        RuntimeError: If a tensor parallel plan is present but `is_transformers_ge_v5_4_0` is falsy.
    """
    if config.lora_bias:
        # lora_bias=True is not supported (yet) for embedding layers, as they use nn.Parameter
        raise ValueError(f"lora_bias={config.lora_bias} is not supported for {self.__class__.__name__}.")

    super().__init__()
    LoraLayer.__init__(self, base_layer)
    self.fan_in_fan_out = config.fan_in_fan_out

    # Tensor parallel bookkeeping; the per-adapter hooks are filled in by update_layer.
    self.device_mesh = getattr(base_layer, "_hf_device_mesh", None)
    self.tp_layer = None
    self.input_fns = {}
    self.output_fns = {}

    tp_plan = getattr(base_layer, "_hf_tp_plan", None)
    if tp_plan is not None:
        if not is_transformers_ge_v5_4_0:
            raise RuntimeError("Tensor Parallel with LoRA is only supported for transformers v5.4.0 and above. ")
        if tp_plan != "embedding_rowwise":
            raise ValueError(
                f'Unsupported tensor parallel plan {tp_plan} for embedding layers. Only "embedding_rowwise" is '
                "supported."
            )
        from transformers.integrations.tensor_parallel import ALL_PARALLEL_STYLES

        self.tp_layer = copy.deepcopy(ALL_PARALLEL_STYLES[tp_plan])

    self._active_adapter = adapter_name
    self.update_layer(adapter_name, r, lora_alpha=lora_alpha, config=config)
def resolve_lora_variant(self, *, config: LoraConfig, **kwargs) -> Optional[LoraVariant]:
    """Return a `DoraEmbeddingVariant` when `config.use_dora` is truthy, otherwise `None` (vanilla LoRA)."""
    if config.use_dora:
        from .variants import DoraEmbeddingVariant

        return DoraEmbeddingVariant()
    return None
def update_layer(
    self,
    adapter_name: str,
    r: int,
    lora_alpha: int,
    config: LoraConfig,
    **kwargs,
) -> None:
    """
    Create (or replace) the adapter `adapter_name` on this embedding layer.

    Args:
        adapter_name (`str`): Name under which all per-adapter state is stored.
        r (`int`): LoRA rank; must be positive.
        lora_alpha (`int`): Numerator of the scaling factor (`lora_alpha / r`, or `lora_alpha / sqrt(r)` when
            `config.use_rslora` is set).
        config (`LoraConfig`): Source of dropout, init, rsLoRA, bias, DoRA and inference-mode settings.
        **kwargs: Forwarded to the variant's `init`, if a variant is used.

    Raises:
        ValueError: If `r` is not positive.
    """
    dropout_p = config.lora_dropout
    init_lora_weights = config.init_lora_weights
    use_rslora = config.use_rslora
    lora_bias = config.lora_bias
    inference_mode = config.inference_mode

    if r <= 0:
        raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

    variant = self.resolve_lora_variant(config=config)
    if variant is not None:
        self.lora_variant[adapter_name] = variant

    self.r[adapter_name] = r
    self.lora_alpha[adapter_name] = lora_alpha
    self.lora_dropout[adapter_name] = nn.Dropout(p=dropout_p) if dropout_p > 0.0 else nn.Identity()

    # Actual trainable parameters (A is drawn before B so the RNG consumption order is fixed)
    weight_A = torch.randn((r, self.in_features))
    weight_B = torch.randn((self.out_features, r))
    self.lora_embedding_A[adapter_name] = nn.Parameter(weight_A)
    self.lora_embedding_B[adapter_name] = nn.Parameter(weight_B)
    self.lora_bias[adapter_name] = lora_bias

    self.scaling[adapter_name] = lora_alpha / math.sqrt(r) if use_rslora else lora_alpha / r
    self.use_rslora[adapter_name] = use_rslora
    self.use_dora[adapter_name] = config.use_dora

    if init_lora_weights == "loftq":
        self.loftq_init(adapter_name)
    elif init_lora_weights == "lora_ga":
        # Embedding layers don't support LoRA-GA, fall back to standard initialization
        self.reset_lora_parameters(adapter_name, True)
    elif init_lora_weights:
        self.reset_lora_parameters(adapter_name, init_lora_weights)

    # call this before init of the lora variants
    self._move_adapter_to_device_of_base_layer(adapter_name)

    if adapter_name in self.lora_variant:
        self.lora_variant[adapter_name].init(self, adapter_name=adapter_name, config=config, **kwargs)

    self.set_adapter(self.active_adapters, inference_mode=inference_mode)

    # If there is tensor parallelism, we register the hooks for `self._embed`.
    if self.tp_layer is not None:
        holder = _LoraEmbeddingAHolder(self.lora_embedding_A[adapter_name])

        def input_fn(inputs):
            return self.tp_layer._prepare_input_fn(holder, inputs, self.device_mesh)

        def output_fn(outputs):
            return self.tp_layer._prepare_output_fn(holder, outputs, self.device_mesh)

        self.input_fns[adapter_name] = input_fn
        self.output_fns[adapter_name] = output_fn
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
    """
    Merge the active adapter weights into the base weights

    Args:
        safe_merge (`bool`, *optional*):
            If True, the merge operation will be performed in a copy of the original weights and check for NaNs
            before merging the weights. This is useful if you want to check if the merge operation will produce
            NaNs. Defaults to `False`.
        adapter_names (`list[str]`, *optional*):
            The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
            to `None`.

    Raises:
        ValueError: With `safe_merge=True`, if the merged weight contains non-finite values.
    """
    adapter_names = check_adapters_to_merge(self, adapter_names)
    if not adapter_names:
        # no adapter to merge
        return

    for active_adapter in adapter_names:
        if active_adapter not in self.lora_embedding_A.keys():
            continue

        base_layer = self.get_base_layer()
        orig_dtype = base_layer.weight.dtype
        is_vanilla = active_adapter not in self.lora_variant

        if safe_merge:
            # Work on a copy so that the base weight is only replaced after the finiteness check passed.
            merged_weight = base_layer.weight.data.clone()
            if is_vanilla:
                merged_weight += self.get_delta_weight(active_adapter).to(orig_dtype)
            else:
                merged_weight = self.lora_variant[active_adapter].merge_safe(self, active_adapter, merged_weight)

            if not torch.isfinite(merged_weight).all():
                raise ValueError(
                    f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                )
            base_layer.weight.data = merged_weight
        elif is_vanilla:
            base_layer.weight.data += self.get_delta_weight(active_adapter).to(orig_dtype)
        else:
            self.lora_variant[active_adapter].merge_unsafe(self, active_adapter, base_layer.weight)

        self.merged_adapters.append(active_adapter)
def unmerge(self) -> None:
    """
    This method unmerges all merged adapter layers from the base weights.

    Adapters are processed in reverse merge order (last merged, first unmerged). A warning is emitted when
    nothing is merged.
    """
    if not self.merged:
        warnings.warn("Already unmerged. Nothing to do.")
        return

    while self.merged_adapters:
        active_adapter = self.merged_adapters.pop()
        target_dtype = self.get_base_layer().weight.dtype
        if active_adapter not in self.lora_embedding_A.keys():
            continue

        weight = self.get_base_layer().weight
        if active_adapter in self.lora_variant:
            # The variant computes the restored weight itself.
            weight.data = self.lora_variant[active_adapter].unmerge(self, active_adapter, weight)
        else:  # vanilla LoRA
            weight.data -= self.get_delta_weight(active_adapter).to(target_dtype)
def get_delta_weight(self, adapter) -> torch.Tensor:
    """
    Compute the delta weight for the given adapter.

    Args:
        adapter (str):
            The name of the adapter for which the delta weight should be computed.

    Returns:
        `torch.Tensor`: `transpose(B @ A, True)` multiplied by the adapter's scaling, in the dtype of the stored
        A matrix.
    """
    weight_A = self.lora_embedding_A[adapter]
    weight_B = self.lora_embedding_B[adapter]
    device = weight_B.device
    dtype = weight_A.dtype

    # In case users wants to merge the adapter weights that are in
    # (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
    # (b)float16 because some CPUs have slow bf16/fp16 matmuls.
    cast_to_fp32 = device.type == "cpu" and dtype in (torch.float16, torch.bfloat16)

    if cast_to_fp32:
        # `.float()` returns new tensors; the stored adapter parameters are left untouched.
        weight_A = weight_A.float()
        weight_B = weight_B.float()

    output_tensor = transpose(weight_B @ weight_A, True) * self.scaling[adapter]

    if cast_to_fp32:
        # Only the result needs to go back to the original dtype. The stored parameters were never modified,
        # so they are deliberately NOT re-assigned here: writing fresh tensors into the adapter containers would
        # replace the registered parameter objects (breaking references held e.g. by an optimizer and possibly
        # resetting `requires_grad`).
        output_tensor = output_tensor.to(dtype=dtype)

    return output_tensor
def _mixed_batch_forward(
self, x: torch.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any
) -> torch.Tensor:
# This is a special method that handles the case when users pass the argument `adapter_names`. This is an
# extra argument that allows mixing different adapters in the same batch at inference time.
result = self.base_layer(x, *args, **kwargs)
# Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method.
# Since base_layer(x) already includes this scaling, we need to apply it to LoRA contributions too.
embed_scale = self._get_embed_scale()
unique_adapters = set(adapter_names)
sub_batch_indices_list = []
for adapter in unique_adapters:
sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter])
for i, active_adapter in enumerate(unique_adapters):
if active_adapter == "__base__":
continue
if active_adapter not in self.lora_embedding_A.keys():
continue
embedding_A = self.lora_embedding_A[active_adapter].T
embedding_B = self.lora_embedding_B[active_adapter].T
scaling = self.scaling[active_adapter]
# getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear
# layer output
sub_batch = x[sub_batch_indices_list[i]]
after_A = self._embed(sub_batch, embedding_A)
adapter_output = (after_A @ embedding_B) * scaling
# Apply embed_scale to match the base layer's scaling
if embed_scale is not None:
adapter_output = adapter_output * embed_scale.to(adapter_output.dtype)
result[sub_batch_indices_list[i]] += adapter_output
return result
def _embed(
self,
input: torch.Tensor,
weight: torch.Tensor,
input_fn: Callable[[tuple[torch.Tensor]], torch.Tensor] | None = None,
output_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
) -> torch.Tensor:
base_layer = self.get_base_layer()
if input_fn is not None:
input = input_fn((input,))
output = F.embedding(
input,
weight,
padding_idx=base_layer.padding_idx,
max_norm=base_layer.max_norm,
norm_type=base_layer.norm_type,
scale_grad_by_freq=base_layer.scale_grad_by_freq,
sparse=base_layer.sparse,
)
if output_fn is not None:
output = output_fn(output)
return output
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
    """
    Run the base embedding and add the contribution of every active adapter.

    `adapter_names` and the keys listed in `VARIANT_KWARG_KEYS` are removed from `kwargs` before the base layer
    is called. With adapters disabled (unmerging first if needed) or already merged, only the base layer runs.
    When `adapter_names` is given, the call is delegated to `_mixed_batch_forward`.
    """
    # TODO: no dtype conversion here, unlike in Linear, is that correct?
    self._check_forward_args(x, *args, **kwargs)
    adapter_names = kwargs.pop("adapter_names", None)
    # don't pass these to base_layer
    variant_kwargs = {key: kwargs.pop(key, None) for key in VARIANT_KWARG_KEYS}

    if self.disable_adapters:
        if self.merged:
            self.unmerge()
        return self.base_layer(x, *args, **kwargs)

    if adapter_names is not None:
        return self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)

    if self.merged:
        return self.base_layer(x, *args, **kwargs)

    result = self.base_layer(x, *args, **kwargs)
    output_dtype = result.dtype

    # Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method.
    # Since base_layer(x) already includes this scaling, we need to apply it to LoRA contributions too.
    embed_scale = self._get_embed_scale()

    for active_adapter in self.active_adapters:
        if active_adapter not in self.lora_embedding_A:
            continue

        if active_adapter in self.lora_variant:
            result = self.lora_variant[active_adapter].forward(
                self,
                active_adapter=active_adapter,
                x=x,
                result=result,
                **variant_kwargs,
                **kwargs,
            )
            continue

        # vanilla LoRA
        embedding_A = self.lora_embedding_A[active_adapter].T
        embedding_B = self.lora_embedding_B[active_adapter].T
        scaling = self.scaling[active_adapter]

        # input and output function hooks for TP support.
        input_fn = self.input_fns.get(active_adapter, None)
        output_fn = self.output_fns.get(active_adapter, None)
        after_A = self._embed(x, embedding_A, input_fn=input_fn, output_fn=output_fn)
        adapter_output = (after_A @ embedding_B) * scaling

        # Apply embed_scale to match the base layer's scaling
        if embed_scale is not None:
            adapter_output = adapter_output * embed_scale.to(adapter_output.dtype)
        result = result + adapter_output

    # Return in the dtype the base layer produced.
    return result.to(output_dtype)
def __repr__(self) -> str:
    """Return the parent representation prefixed with `"lora."`."""
    return "lora." + super().__repr__()
class _ConvNd(nn.Module, LoraLayer):
# Lora implemented in a conv(2,3)d layer
def __init__(
    self,
    base_layer: nn.Module,
    adapter_name: str,
    config: LoraConfig,
    r: int = 0,
    lora_alpha: int = 1,
    **kwargs,
) -> None:
    """
    Wrap a convolution `base_layer` with LoRA and register the first adapter.

    Args:
        base_layer (`nn.Module`): The convolution layer that is being adapted.
        adapter_name (`str`): Name of the adapter created during construction.
        config (`LoraConfig`): Adapter configuration, forwarded to `update_layer`.
        r (`int`): LoRA rank; must be divisible by `base_layer.groups`.
        lora_alpha (`int`): LoRA alpha, forwarded to `update_layer`.
        **kwargs: Only `use_alora` is inspected here.

    Raises:
        ValueError: If `use_alora` is requested, or if `groups > 1` and `r` is not divisible by `groups`.
    """
    super().__init__()
    LoraLayer.__init__(self, base_layer)

    if kwargs.get("use_alora", False):
        raise ValueError("aLoRA does not support adapting conv layers.")

    groups = base_layer.groups
    if groups > 1:
        warnings.warn("LoRA adapter added to ConvNd layer with groups > 1. Merging is not supported.")

    if r % groups != 0:
        raise ValueError(
            f"Targeting a {base_layer.__class__.__name__} with groups={base_layer.groups} and rank {r}. "
            "Currently, support is limited to conv layers where the rank is divisible by groups. "
            "Either choose a different rank or do not target this specific layer."
        )

    self._active_adapter = adapter_name
    # Number of dimensions of the base kernel tensor; subclasses validate it.
    self._kernel_dim = base_layer.weight.dim()

    self.update_layer(adapter_name, r, lora_alpha=lora_alpha, config=config)
def update_layer(
    self,
    adapter_name: str,
    r: int,
    lora_alpha: int,
    config: LoraConfig,
    **kwargs,
) -> None:
    """
    Create (or replace) the adapter `adapter_name` on this convolution layer.

    `lora_A` is a convolution of the same class as the base layer using the base kernel size, stride and
    padding; `lora_B` is a convolution with kernel size and stride 1 that carries the base layer's `groups`.

    Args:
        adapter_name (`str`): Name under which all per-adapter state is stored.
        r (`int`): LoRA rank; must be positive.
        lora_alpha (`int`): Numerator of the scaling factor (`lora_alpha / r`, or `lora_alpha / sqrt(r)` when
            `config.use_rslora` is set).
        config (`LoraConfig`): Source of dropout, init, rsLoRA, bias, DoRA and inference-mode settings.
        **kwargs: Forwarded to the variant's `init`, if a variant is used.

    Raises:
        ValueError: If `r` is not positive.
    """
    dropout_p = config.lora_dropout
    init_lora_weights = config.init_lora_weights
    use_rslora = config.use_rslora
    lora_bias = config.lora_bias
    inference_mode = config.inference_mode

    if r <= 0:
        raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

    if lora_bias and (getattr(self.get_base_layer(), "bias", None) is None):
        warnings.warn(
            f"`lora_bias=True` was passed but the targeted layer of type {type(self.get_base_layer()).__name__} "
            "has no bias. This means that merging LoRA weights won't be possible.",
            PeftWarning,
        )

    variant = self.resolve_lora_variant(config=config)
    if variant is not None:
        self.lora_variant[adapter_name] = variant

    self.r[adapter_name] = r
    self.lora_alpha[adapter_name] = lora_alpha
    self.lora_dropout[adapter_name] = nn.Dropout(p=dropout_p) if dropout_p > 0.0 else nn.Identity()

    # Actual trainable parameters (A is built before B so the RNG consumption order is fixed)
    base_layer = self.get_base_layer()
    conv_layer = type(base_layer)
    unit_shape = (1,) * (self._kernel_dim - 2)
    self.lora_A[adapter_name] = conv_layer(
        self.in_features, r, base_layer.kernel_size, base_layer.stride, base_layer.padding, bias=False
    )
    self.lora_B[adapter_name] = conv_layer(
        r, self.out_features, unit_shape, unit_shape, groups=base_layer.groups, bias=lora_bias
    )
    self.lora_bias[adapter_name] = lora_bias

    self.scaling[adapter_name] = lora_alpha / math.sqrt(r) if use_rslora else lora_alpha / r
    self.use_rslora[adapter_name] = use_rslora
    self.use_dora[adapter_name] = config.use_dora

    if init_lora_weights == "loftq":
        self.loftq_init(adapter_name)
    elif init_lora_weights == "lora_ga":
        # Conv layers don't support LoRA-GA, fall back to standard initialization
        self.reset_lora_parameters(adapter_name, True)
    elif init_lora_weights:
        self.reset_lora_parameters(adapter_name, init_lora_weights)

    # call this before init of the lora variants
    self._move_adapter_to_device_of_base_layer(adapter_name)

    if adapter_name in self.lora_variant:
        self.lora_variant[adapter_name].init(self, adapter_name=adapter_name, config=config, **kwargs)

    self.set_adapter(self.active_adapters, inference_mode=inference_mode)
def _get_dora_factor_view(self):
return (-1,) + (1,) * (self._kernel_dim - 1)
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
    """
    Merge the active adapter weights inside the base weights

    Args:
        safe_merge (`bool`, *optional*):
            If True, the merge operation will be performed in a copy of the original weights and check for NaNs
            before merging the weights. This is useful if you want to check if the merge operation will produce
            NaNs. Defaults to `False`.
        adapter_names (`list[str]`, *optional*):
            The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
            to `None`.

    Raises:
        NotImplementedError: If the base layer has `groups > 1`.
        ValueError: With `safe_merge=True`, if the merged weight or bias contains non-finite values.
        RuntimeError: If the adapter has `lora_bias` enabled but the base layer has no bias.
    """
    adapter_names = check_adapters_to_merge(self, adapter_names)
    if not adapter_names:
        # no adapter to merge
        return

    for active_adapter in adapter_names:
        if active_adapter not in self.lora_A.keys():
            continue

        base_layer = self.get_base_layer()
        orig_dtype = base_layer.weight.dtype

        if base_layer.groups > 1:
            # https://github.com/huggingface/peft/pull/2403
            raise NotImplementedError("Merging is not supported for _ConvNd layers with groups > 1!")

        is_vanilla = active_adapter not in self.lora_variant

        if safe_merge:
            # Work on a copy so that the base weight is only replaced after the finiteness check passed.
            merged_weight = base_layer.weight.data.clone()
            if is_vanilla:
                merged_weight += self.get_delta_weight(active_adapter).to(orig_dtype)
            else:
                merged_weight = self.lora_variant[active_adapter].merge_safe(self, active_adapter, merged_weight)

            if not torch.isfinite(merged_weight).all():
                raise ValueError(
                    f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                )
            base_layer.weight.data = merged_weight

            if self.lora_bias[active_adapter]:
                if getattr(base_layer, "bias", None) is None:
                    raise RuntimeError(
                        "Impossible to merge LoRA with `lora_bias=True` because the base layer has no bias."
                    )
                merged_bias = base_layer.bias + self.lora_B[active_adapter].bias * self.scaling[active_adapter]
                if not torch.isfinite(merged_bias).all():
                    raise ValueError(
                        f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                    )
                base_layer.bias.data = merged_bias.to(orig_dtype)
        else:
            # Fast path: the base weight is modified in place without any checks.
            if is_vanilla:
                base_layer.weight.data += self.get_delta_weight(active_adapter).to(orig_dtype)
            else:
                self.lora_variant[active_adapter].merge_unsafe(self, active_adapter, base_layer.weight)

            if self.lora_bias[active_adapter]:
                if getattr(base_layer, "bias", None) is None:
                    raise RuntimeError(
                        "Impossible to merge LoRA with `lora_bias=True` because the base layer has no bias."
                    )
                base_layer.bias.data += self.lora_B[active_adapter].bias * self.scaling[active_adapter]

        self.merged_adapters.append(active_adapter)
def unmerge(self) -> None:
    """
    This method unmerges all merged adapter layers from the base weights.

    Adapters are processed in reverse merge order (last merged, first unmerged). A warning is emitted when
    nothing is merged.
    """
    if not self.merged:
        warnings.warn("Already unmerged. Nothing to do.")
        return

    while self.merged_adapters:
        active_adapter = self.merged_adapters.pop()
        if active_adapter not in self.lora_A.keys():
            continue

        weight = self.get_base_layer().weight
        if active_adapter in self.lora_variant:
            # The variant computes the restored weight itself.
            weight.data = self.lora_variant[active_adapter].unmerge(self, active_adapter, weight)
        else:  # vanilla LoRA
            target_dtype = weight.dtype
            weight.data -= self.get_delta_weight(active_adapter).to(target_dtype)

        if self.lora_bias[active_adapter]:
            self.get_base_layer().bias.data -= self.lora_B[active_adapter].bias * self.scaling[active_adapter]
def get_delta_weight(self, adapter) -> torch.Tensor:
    """
    Compute the delta weight for the given adapter.

    Args:
        adapter (str):
            The name of the adapter for which the delta weight should be computed.

    Returns:
        `torch.Tensor`: The composed LoRA kernel multiplied by the adapter's scaling. For `groups == 1` it is
        laid out as `(out_channels, in_channels, *kernel)`.
    """
    device = self.lora_B[adapter].weight.device
    dtype = self.lora_A[adapter].weight.dtype

    # In case users wants to merge the adapter weights that are in
    # (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
    # (b)float16 because some CPUs have slow bf16/fp16 matmuls.
    cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)

    weight_A = self.lora_A[adapter].weight
    weight_B = self.lora_B[adapter].weight

    if cast_to_fp32:
        weight_A = weight_A.float()
        weight_B = weight_B.float()

    base_layer = self.get_base_layer()
    scaling = self.scaling[adapter]

    # https://github.com/bmaltais/kohya_ss/blob/feb6728762a8f463d15ba936d189d4c3abfaa1ab/networks/lora.py#L117
    # The matmul shortcut squeezes exactly two spatial dims, so it is only valid for 4-D (conv2d) kernels.
    # Comparing the full spatial shape keeps 5-D kernels such as (1, 1, k) or (1, 1, 1) out of this branch;
    # they are handled correctly by the generic convolution path below.
    if tuple(base_layer.weight.shape[2:]) == (1, 1):
        # conv2d 1x1
        output_tensor = (weight_B.squeeze(3).squeeze(2) @ weight_A.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(
            3
        ) * scaling
    else:
        # Compose the two kernels by convolving A (treated as a batch of inputs) with B.
        output_tensor = self.conv_fn(weight_A.transpose(0, 1), weight_B)

        if base_layer.groups > 1:
            output_tensor = output_tensor * scaling
        else:
            output_tensor = output_tensor.transpose(0, 1) * scaling

    if cast_to_fp32:
        output_tensor = output_tensor.to(dtype=dtype)

        # cast back the weights
        self.lora_A[adapter].weight.data = weight_A.to(dtype)
        self.lora_B[adapter].weight.data = weight_B.to(dtype)

    return output_tensor
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    """
    Run the base convolution and add the contribution of every active adapter.

    `adapter_names` and the keys listed in `VARIANT_KWARG_KEYS` are removed from `kwargs` before the base layer
    is called. With adapters disabled (unmerging first if needed) or already merged, only the base layer runs.
    When `adapter_names` is given, the call is delegated to `_mixed_batch_forward`.
    """
    self._check_forward_args(x, *args, **kwargs)
    adapter_names = kwargs.pop("adapter_names", None)
    # don't pass these to base_layer
    variant_kwargs = {key: kwargs.pop(key, None) for key in VARIANT_KWARG_KEYS}

    if self.disable_adapters:
        if self.merged:
            self.unmerge()
        return self.base_layer(x, *args, **kwargs)

    if adapter_names is not None:
        return self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)

    if self.merged:
        return self.base_layer(x, *args, **kwargs)

    result = self.base_layer(x, *args, **kwargs)
    output_dtype = result.dtype

    for active_adapter in self.active_adapters:
        if active_adapter not in self.lora_A.keys():
            continue

        lora_A = self.lora_A[active_adapter]
        lora_B = self.lora_B[active_adapter]
        dropout = self.lora_dropout[active_adapter]
        scaling = self.scaling[active_adapter]
        x = self._cast_input_dtype(x, lora_A.weight.dtype)

        if active_adapter in self.lora_variant:
            result = self.lora_variant[active_adapter].forward(
                self,
                active_adapter=active_adapter,
                x=x,
                result=result,
                **variant_kwargs,
                **kwargs,
            )
        else:  # vanilla LoRA
            result = result + lora_B(lora_A(dropout(x))) * scaling

    # Return in the dtype the base layer produced.
    return result.to(output_dtype)
def __repr__(self) -> str:
    """Return the parent representation prefixed with `"lora."`."""
    return "lora." + super().__repr__()
class Conv2d(_ConvNd):
    # Lora implemented in a conv2d layer
    def __init__(self, *args, **kwargs):
        """Build the layer via `_ConvNd`, then require a 4-dimensional kernel and select `F.conv2d`."""
        super().__init__(*args, **kwargs)
        if self._kernel_dim != 4:
            raise ValueError(f"Conv2d layer kernel must have 4 dimensions, not {self._kernel_dim}")
        self.conv_fn = F.conv2d

    def resolve_lora_variant(self, *, config: LoraConfig, **kwargs) -> Optional[LoraVariant]:
        """Return a `DoraConv2dVariant` when `config.use_dora` is truthy, otherwise `None` (vanilla LoRA)."""
        if config.use_dora:
            from .variants import DoraConv2dVariant

            return DoraConv2dVariant()
        return None
class Conv1d(_ConvNd):
    # Lora implemented in a conv1d layer
    def __init__(self, *args, **kwargs):
        """Build the layer via `_ConvNd`, then require a 3-dimensional kernel and select `F.conv1d`."""
        super().__init__(*args, **kwargs)
        if self._kernel_dim != 3:
            raise ValueError(f"Conv1d layer kernel must have 3 dimensions, not {self._kernel_dim}")
        self.conv_fn = F.conv1d

    def resolve_lora_variant(self, *, config: LoraConfig, **kwargs) -> Optional[LoraVariant]:
        """Return a `DoraConv1dVariant` when `config.use_dora` is truthy, otherwise `None` (vanilla LoRA)."""
        if config.use_dora:
            from .variants import DoraConv1dVariant

            return DoraConv1dVariant()
        return None
class Conv3d(_ConvNd):
    # Lora implemented in a conv3d layer
    def __init__(self, *args, **kwargs):
        """Build the layer via `_ConvNd`, then require a 5-dimensional kernel and select `F.conv3d`."""
        super().__init__(*args, **kwargs)
        if self._kernel_dim != 5:
            raise ValueError(f"Conv3d layer kernel must have 5 dimensions, not {self._kernel_dim}")
        self.conv_fn = F.conv3d

    def resolve_lora_variant(self, *, config: LoraConfig, **kwargs) -> Optional[LoraVariant]:
        """Return a `DoraConv3dVariant` when `config.use_dora` is truthy, otherwise `None` (vanilla LoRA)."""
        if config.use_dora:
            from .variants import DoraConv3dVariant

            return DoraConv3dVariant()
        return None
class MultiheadAttention(nn.Module, LoraLayer):
"""LoRA implemented in a multihead attention layer
This is currently only implemented for the case of `_qkv_same_embed_dim = True`, i.e. query, key, and value having
the same dimension.
Note: LoRA is applied to both the in_proj (query/key/value) and out_proj. There is currently no way to specify only
one of them. Don't try to apply LoRA to the out_proj of MultiheadAttention by targeting that layer specifically,
since the forward method of that layer is not being used, hence the LoRA adapter would be ignored.
This is a little bit hacky because of the way that MultiheadAttention is implemented in PyTorch: There are no
`nn.Linear` layers which we can hook onto or, in case of output projection, `.forward` is not used. This
implementation works around these problems by merging the weights before the forward call and unmerging them after
the forward call.
"""
def __init__(
    self,
    base_layer,
    adapter_name: str,
    config: LoraConfig,
    r: int = 0,
    lora_alpha: int = 1,
    **kwargs,
) -> None:
    """
    Wrap a multihead attention `base_layer` with LoRA on both in_proj and out_proj.

    The base layer's `out_proj` is replaced by a LoRA `Linear` wrapper before the in_proj adapter is created
    through `update_layer`.

    Raises:
        ValueError: If the base layer reports `_qkv_same_embed_dim` as falsy, if DoRA or aLoRA is requested, or
            if `out_proj` is not an `nn.Linear`.
    """
    # TODO work with separate weights
    if not getattr(base_layer, "_qkv_same_embed_dim", True):
        # default for this value appears to be True:
        # https://github.com/pytorch/pytorch/blob/701ba5203fe68d55d655bd4d6c008be94cf34ea5/torch/nn/modules/activation.py#L1128-L1130
        raise ValueError(
            f"Only same embed for query/key/value is supported as of now for {self.__class__.__name__}."
        )

    if config.use_dora:
        # TODO: probably not so hard to implement
        raise ValueError(f"{self.__class__.__name__} does not support DoRA (yet), please set use_dora to False")

    if kwargs.get("use_alora", False):
        raise ValueError(f"{self.__class__.__name__} does not support aLoRA (yet), please set use_alora to False")

    super().__init__()
    LoraLayer.__init__(self, base_layer, **kwargs)

    # Note: LoRA is applied to both in_proj and out_proj. There is currently no way to only specify one of them.
    if not isinstance(base_layer.out_proj, nn.Linear):
        raise ValueError(f"out_proj must be an instance of nn.Linear for {self.__class__.__name__}.")

    self.base_layer.out_proj = Linear(
        base_layer.out_proj,
        adapter_name,
        r=r,
        lora_alpha=lora_alpha,
        config=config,
        **kwargs,
    )

    self._active_adapter = adapter_name
    self.update_layer(adapter_name, r, lora_alpha=lora_alpha, config=config)
@property
def embed_dim(self) -> int:
    """`embed_dim` of the wrapped base layer."""
    wrapped = self.get_base_layer()
    return wrapped.embed_dim

@property
def kdim(self) -> Optional[int]:
    """`kdim` of the wrapped base layer."""
    wrapped = self.get_base_layer()
    return wrapped.kdim

@property
def vdim(self) -> Optional[int]:
    """`vdim` of the wrapped base layer."""
    wrapped = self.get_base_layer()
    return wrapped.vdim

@property
def _qkv_same_embed_dim(self) -> bool:
    """`_qkv_same_embed_dim` of the wrapped base layer."""
    wrapped = self.get_base_layer()
    return wrapped._qkv_same_embed_dim

@property
def num_heads(self) -> int:
    """`num_heads` of the wrapped base layer."""
    wrapped = self.get_base_layer()
    return wrapped.num_heads

@property
def dropout(self) -> float:
    """`dropout` of the wrapped base layer."""
    wrapped = self.get_base_layer()
    return wrapped.dropout

@property
def batch_first(self) -> bool:
    """`batch_first` of the wrapped base layer."""
    wrapped = self.get_base_layer()
    return wrapped.batch_first

@property
def head_dim(self) -> int:
    """`head_dim` of the wrapped base layer."""
    wrapped = self.get_base_layer()
    return wrapped.head_dim

@property
def in_proj_weight(self) -> nn.Parameter:
    """`in_proj_weight` of the wrapped base layer."""
    wrapped = self.get_base_layer()
    return wrapped.in_proj_weight

@property
def in_proj_bias(self) -> nn.Parameter:
    """`in_proj_bias` of the wrapped base layer."""
    wrapped = self.get_base_layer()
    return wrapped.in_proj_bias

@property
def out_proj(self) -> nn.Module:
    """Base layer underneath the base layer's `out_proj` (which is itself unwrapped via `get_base_layer`)."""
    wrapped_out_proj = self.get_base_layer().out_proj
    return wrapped_out_proj.get_base_layer()

@property
def bias_k(self) -> Optional[nn.Parameter]:
    """`bias_k` of the wrapped base layer."""
    wrapped = self.get_base_layer()
    return wrapped.bias_k

@property
def bias_v(self) -> Optional[nn.Parameter]:
    """`bias_v` of the wrapped base layer."""
    wrapped = self.get_base_layer()
    return wrapped.bias_v

def merge_masks(self, *args, **kwargs) -> tuple[Optional[torch.Tensor], Optional[int]]:
    """Delegate to `merge_masks` of the wrapped base layer, passing all arguments through."""
    wrapped = self.get_base_layer()
    return wrapped.merge_masks(*args, **kwargs)

@property
def add_zero_attn(self) -> bool:
    """`add_zero_attn` of the wrapped base layer."""
    wrapped = self.get_base_layer()
    return wrapped.add_zero_attn
def update_layer(self, *args, **kwargs) -> None:
    """Create or update the adapter on this layer and on the wrapped ``out_proj`` layer.

    All positional and keyword arguments are forwarded unchanged, first to the parent class implementation and
    then to ``update_layer`` of the ``out_proj`` wrapper stored on ``self.base_layer``.
    """
    super().update_layer(*args, **kwargs)
    # in_proj and out_proj are always adapted together; there is currently no way to target only one of them.
    wrapped_out_proj = self.base_layer.out_proj
    wrapped_out_proj.update_layer(*args, **kwargs)
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
    """
    Merge the active adapter weights into the base weights

    For every adapter to merge that exists in `self.lora_A`, the delta weight of this layer is added to
    `in_proj_weight` of the base layer, and the delta weight of the wrapped `out_proj` layer is added to the
    `out_proj` weight. The registered parameters are deleted and replaced by the merged tensors (plain attributes),
    after which `merge` of the `out_proj` wrapper is called and the adapter name is recorded in
    `self.merged_adapters`.

    Args:
        safe_merge (`bool`, *optional*):
            If True, the merge operation will be performed in a copy of the original weights and check for NaNs
            before merging the weights. This is useful if you want to check if the merge operation will produce
            NaNs. Defaults to `False`.
        adapter_names (`List[str]`, *optional*):
            The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
            to `None`.

    Raises:
        ValueError: If `safe_merge=True` and a merged weight contains non-finite values (the check uses
            `torch.isfinite`).
    """
    adapter_names = check_adapters_to_merge(self, adapter_names)
    if not adapter_names:
        # no adapter to merge
        return

    # Implementation follows this:
    # https://github.com/Baijiong-Lin/LoRA-Torch/blob/4bfed6820b64fcf47064c30f30606a190a4f0d2e/loratorch/layers.py#L73-L79
    # Notably, instead of mutating the weight, we delete the original weight and replace it by the merged weight
    # TODO: work with separate weights
    for active_adapter in adapter_names:
        if active_adapter in self.lora_A.keys():
            base_layer = self.get_base_layer()
            # the dtype of the out_proj weight is used as the target dtype for both in_proj and out_proj deltas
            orig_dtype = base_layer.out_proj.weight.dtype
            if safe_merge:
                # TODO: work with separate weights
                # merging in_proj (nn.Parameter)
                # work on a detached clone so the registered weight is untouched if the check below fails
                orig_weight_in = base_layer.in_proj_weight.data.detach().clone()
                orig_weight_in += self.get_delta_weight(active_adapter).to(orig_dtype)
                if not torch.isfinite(orig_weight_in).all():
                    raise ValueError(
                        f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                    )

                # merging out_proj (subclass of nn.Linear)
                orig_weight_out = base_layer.out_proj.weight.data.detach().clone()
                orig_weight_out += base_layer.out_proj.get_delta_weight(active_adapter).to(orig_dtype)
                if not torch.isfinite(orig_weight_out).all():
                    raise ValueError(
                        f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                    )

                # unregister parameter implicitly and overwrite using merged weights; gradients are computed after
                # forward and, thus, after unmerging (see forward()), therefore this is safe to do.
                del base_layer.in_proj_weight
                base_layer.in_proj_weight = orig_weight_in

                del base_layer.out_proj.get_base_layer().weight
                base_layer.out_proj.get_base_layer().weight = orig_weight_out
                # NOTE(review): out_proj.merge() is called after the weight was already replaced by the merged
                # tensor; confirm that this does not add the out_proj delta a second time.
                base_layer.out_proj.merge(adapter_names=[active_adapter])
            else:
                # merging in_proj (nn.Parameter)
                # TODO: work with separate weights
                delta_weight = self.get_delta_weight(active_adapter).to(orig_dtype)
                weight_merged = base_layer.in_proj_weight.data.detach() + delta_weight

                # unregister parameter implicitly and overwrite using merged weights; gradients are computed after
                # forward and, thus, after unmerging (see forward()), therefore this is safe to do.
                del base_layer.in_proj_weight
                base_layer.in_proj_weight = weight_merged

                # merging out_proj (subclass of nn.Linear)
                delta_weight = base_layer.out_proj.get_delta_weight(active_adapter).to(orig_dtype)
                weight_merged = base_layer.out_proj.weight.data.detach() + delta_weight
                del base_layer.out_proj.get_base_layer().weight
                base_layer.out_proj.get_base_layer().weight = weight_merged
                # NOTE(review): same as in the safe_merge branch; verify that the delta is not applied twice here.
                base_layer.out_proj.merge(adapter_names=[active_adapter])
            # bookkeeping: unmerge() pops from this list to know which deltas to subtract
            self.merged_adapters.append(active_adapter)
def unmerge(self) -> None:
    """
    This method unmerges all merged adapter layers from the base weights.

    Adapters are popped from `self.merged_adapters` one by one. For each adapter present in `self.lora_A`, its
    delta weight is subtracted from `in_proj_weight` and from the `out_proj` base weight, and the results are
    registered again as `nn.Parameter` with `requires_grad=False`. Finally, `unmerge` is called on the `out_proj`
    wrapper. If nothing is merged, only a warning is issued.
    """
    if not self.merged:
        warnings.warn("Already unmerged. Nothing to do.")
        return

    # TODO work with separate weights
    base_layer = self.get_base_layer()
    # the dtype of the out_proj base weight is used as the target dtype for both deltas
    orig_dtype = base_layer.out_proj.base_layer.weight.dtype
    while len(self.merged_adapters) > 0:
        active_adapter = self.merged_adapters.pop()
        if active_adapter in self.lora_A.keys():
            # Ensure that requires_grad=False for the base weights after unmerging. This may not matter since
            # requires_grad was False when the optimizer was initialized, but still let's try to be correct here.

            # in_proj
            delta_weight = self.get_delta_weight(active_adapter).to(orig_dtype)
            old_weight = base_layer.in_proj_weight.data - delta_weight
            # drop the merged tensor and register the restored weight as a real parameter again
            del base_layer.in_proj_weight
            base_layer.register_parameter("in_proj_weight", nn.Parameter(old_weight, requires_grad=False))

            # out_proj
            delta_weight = base_layer.out_proj.get_delta_weight(active_adapter).to(orig_dtype)
            old_weight = base_layer.out_proj.base_layer.weight.data - delta_weight
            del base_layer.out_proj.base_layer.weight
            base_layer.out_proj.base_layer.register_parameter(
                "weight", nn.Parameter(old_weight, requires_grad=False)
            )

    # let the out_proj wrapper undo its own merge as well
    self.get_base_layer().out_proj.unmerge()
def unload_and_optionally_merge_module(
    self, merge: bool, safe_merge: bool, adapter_names: Optional[list[str]]
) -> nn.MultiheadAttention:
    """
    Unload the LoRA wrapper and return the plain base layer, optionally merging the adapters first.

    MultiheadAttention needs extra steps compared to the normal merge_and_unload code path: the in_proj and
    out_proj weights are re-registered as parameters, and the wrapped out_proj is replaced by its own base layer.

    Args:
        merge (`bool`): Whether to call `merge` before unloading.
        safe_merge (`bool`): Passed on to `merge`.
        adapter_names (`List[str]`, *optional*): Passed on to `merge`.

    Returns:
        The base layer with re-registered weights and an unwrapped `out_proj`.
    """

    def reregister(owner, attr_name):
        # swap whatever is currently stored under attr_name for a freshly registered nn.Parameter
        tensor = getattr(owner, attr_name)
        delattr(owner, attr_name)
        owner.register_parameter(attr_name, nn.Parameter(tensor.data, requires_grad=tensor.requires_grad))

    if merge:
        self.merge(safe_merge=safe_merge, adapter_names=adapter_names)

    attention = self.get_base_layer()
    # in_proj
    reregister(attention, "in_proj_weight")
    # out_proj: restore the weight, then drop the LoRA wrapper around it
    plain_out_proj = attention.out_proj.get_base_layer()
    reregister(plain_out_proj, "weight")
    attention.out_proj = plain_out_proj
    return attention
def get_delta_weight(self, adapter) -> torch.Tensor:
    """
    Compute the delta weight for the given adapter.

    The result is `lora_B.weight @ lora_A.weight`, multiplied by the adapter's scaling factor.

    Args:
        adapter (str):
            The name of the adapter for which the delta weight should be computed.
    """
    lora_a = self.lora_A[adapter]
    lora_b = self.lora_B[adapter]
    device = lora_b.weight.device
    dtype = lora_b.weight.dtype

    # float16 weights on CPU are multiplied in float32 and cast back afterwards, because matmul is in general
    # not supported for torch + cpu + fp16.
    needs_upcast = dtype == torch.float16 and device.type == "cpu"

    weight_A = lora_a.weight
    weight_B = lora_b.weight
    if needs_upcast:
        weight_A, weight_B = weight_A.float(), weight_B.float()

    delta = (weight_B @ weight_A) * self.scaling[adapter]
    if not needs_upcast:
        return delta

    delta = delta.to(dtype=dtype)
    # cast back the weights
    lora_a.weight.data = weight_A.to(dtype)
    lora_b.weight.data = weight_B.to(dtype)
    return delta
def _check_forward_args(self, x, *args, **kwargs):
    """Reject mixed adapter batches, then run the checks of the parent class.

    Raises:
        TypeError: If the keyword argument `adapter_names` is present (whatever its value).
    """
    mixed_batch_requested = "adapter_names" in kwargs
    if mixed_batch_requested:
        raise TypeError(f"lora.{self.__class__.__name__} does not support mixed adapter batches.")
    super()._check_forward_args(x, *args, **kwargs)
def forward(self, query: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
    """
    Run the wrapped attention layer, with the LoRA weights of the active adapters applied.

    Three cases are handled:

    - adapters disabled: merged adapters (if any) are unmerged first, then the base layer is called;
    - adapters already merged: the base layer is called as is;
    - otherwise: the active adapters are merged temporarily, the base layer is called, and the adapters are
      unmerged again, even if the base layer raises.

    Args:
        query: The query tensor. Its dtype determines the dtype of the returned tensors.
        *args, **kwargs: Passed on to the base layer unchanged.

    Returns:
        The 2-tuple returned by the base layer, with the first element cast to the dtype of `query` and the second
        element cast likewise unless it is None.

    Raises:
        ValueError: If the active adapters of the wrapped out_proj differ from the active adapters of this layer.
    """
    previous_dtype = query.dtype
    self._check_forward_args(query, *args, **kwargs)

    if self.disable_adapters:
        if self.merged:
            self.unmerge()
        result = self.base_layer(query, *args, **kwargs)
    elif self.merged:
        result = self.base_layer(query, *args, **kwargs)
    else:
        out_proj = self.get_base_layer().out_proj
        if out_proj.active_adapters != self.active_adapters:
            # We have a case that in_proj and out_proj have diverging merged adapters. We cannot
            # really deal with this correctly, thus it's better to raise than possibly create a hard to debug mess
            cls_name = self.get_base_layer().__class__.__name__
            raise ValueError(
                f"The out_proj layer of {cls_name} has merged layers but {cls_name} itself doesn't; please ensure "
                "that either both or none have merged layers"
            )

        # Merge all adapters that are active for this module, i.e. the LoRA weights for in_proj and out_proj.
        # in_proj uses nn.Parameters, therefore, there is no forward method to be used and we have to explicitly
        # merge for the LoRA weights to have an effect:
        # https://github.com/pytorch/pytorch/blob/6ebb26d572d5fcdc6ac0d1297bdf8d1eb5d20722/torch/nn/modules/activation.py#L1020
        # For out_proj, we have an nn.Linear (or rather: NonDynamicallyQuantizableLinear), but its forward method
        # is not used:
        # https://github.com/pytorch/pytorch/blob/6ebb26d572d5fcdc6ac0d1297bdf8d1eb5d20722/torch/nn/modules/activation.py#L1267-L1271
        # Therefore, its LoRA weights also need to be merged to have an effect.
        active_adapters = [a for a in self.active_adapters if a in self.lora_A]
        try:
            self.merge(adapter_names=active_adapters)
            result = self.base_layer(query, *args, **kwargs)
        finally:
            # it's safe to call unmerge(), which unmerges all adapters, because we checked that not self.merged,
            # i.e. there was no merged layer before. Only unmerge if something was actually merged: when no active
            # adapter lives on this layer (or merge() failed before merging anything), calling unmerge() would
            # only emit a misleading "Already unmerged" warning on every forward pass.
            if self.merged:
                self.unmerge()

    result = (result[0].to(previous_dtype), result[1].to(previous_dtype) if result[1] is not None else result[1])
    return result
# The decorator is needed in case low_cpu_mem_usage=True is used, as we don't want the base layer weights to be
# moved to meta device. This requires the use of PEFT's implementation of init_empty_weight instead of using the one
# from accelerate.
@skip_init_on_device
def _restore_weights(self):
    """Re-register the in_proj and out_proj weights as parameters on the base layer.

    Merging and unmerging (which is required for forward to work correctly) replaces the registered parameters by
    plain attributes, so the Module "forgets" them. Calling register_parameter during merging/unmerging is not an
    option because that cuts the weights off from the autograd graph. This is hacky: every method that relies on
    the weights being registered has to call `_restore_weights` itself.
    """
    # TODO work with separate weights
    attention = self.get_base_layer()
    targets = (
        (attention, "in_proj_weight"),  # in_proj
        (attention.out_proj.get_base_layer(), "weight"),  # out_proj
    )
    for owner, attr_name in targets:
        tensor = getattr(owner, attr_name)
        delattr(owner, attr_name)
        owner.register_parameter(attr_name, nn.Parameter(tensor.data, requires_grad=tensor.requires_grad))
def state_dict(self, *args, **kwargs):
    """Return the parent class' state dict after making sure the weights are registered as parameters."""
    self._restore_weights()
    snapshot = super().state_dict(*args, **kwargs)
    return snapshot


def named_modules(self, *args, **kwargs):
    """Return the parent class' `named_modules` result after making sure the weights are registered."""
    # Note: no need to also implement modules(), as modules() calls named_modules() under the hood
    self._restore_weights()
    module_iter = super().named_modules(*args, **kwargs)
    return module_iter


def __repr__(self) -> str:
    """Return the parent class' repr with a `lora.` prefix."""
    return f"lora.{super().__repr__()}"
class _LoraParameterProxy(nn.Module):
    """Parametrization module that adds a fixed LoRA delta weight to an `nn.Parameter` targeted with LoRA.

    Intended to be used in conjunction with `nn.utils.parametrize`, see `ParamWrapper`.
    """

    def __init__(self, delta_weight):
        """Store `delta_weight`, which is added to the parametrized weight in `forward`."""
        super().__init__()
        self.delta_weight = delta_weight

    @staticmethod
    def _low_prec_add(x, y):
        """Return `x + y` in the dtype of `x`, with the addition itself performed in the dtype of `y`.

        Raises:
            RuntimeError: If the dtype of `y` is not one of `ALLOWED_COMPUTE_DTYPES`.
        """
        # addition in fp8 is not directly supported, need to use a higher precision
        target_dtype = x.dtype
        compute_dtype = y.dtype
        if compute_dtype not in ALLOWED_COMPUTE_DTYPES:
            raise RuntimeError(
                f"There is an attempt to upcast the targeted parameter to {compute_dtype} "
                f"but the only supported are: {ALLOWED_COMPUTE_DTYPES}."
            )

        # this operation can be quite costly
        total = x.to(compute_dtype) + y
        # clamp to valid range before casting down, as this is *not* performed automatically and can thus result
        # in NANs
        limits = torch.finfo(target_dtype)
        return total.clamp(min=limits.min, max=limits.max).to(target_dtype)

    def forward(self, W):
        """Return `W` plus the stored delta weight, using `_low_prec_add` if the dtype of `W` needs upcasting."""
        for dtype_name in UPCAST_DTYPES:
            # entries of UPCAST_DTYPES are looked up by name on torch; names torch does not provide never match
            if getattr(torch, dtype_name, None) == W.dtype:
                return self._low_prec_add(W, self.delta_weight)
        return W + self.delta_weight
# copied from:
# https://github.com/pytorch/pytorch/blob/5e386eec9426f174eea130c0c012d9f65ebe65fb/torch/nn/utils/parametrize.py#L75-L79
def _register_parameter_or_buffer(module, name, X):
    """Register `X` on `module` under `name`: as a parameter if it is an `nn.Parameter`, otherwise as a buffer."""
    register = module.register_parameter if isinstance(X, nn.Parameter) else module.register_buffer
    register(name, X)
class ParamWrapper(nn.Module, LoraLayer):
    """A LoRA wrapper for `nn.Parameter`. This layer is dispatched if users target a parameter directly with
    `lora_config.target_parameters`

    Instead of adding a LoRA branch to the forward pass, the delta weight is added to the targeted parameter through
    a temporary `nn.utils.parametrize` parametrization that is only active while the base layer's forward runs.

    Note:

    - When accessing the wrapped nn.Parameter directly, e.g. via `module.weight`, the LoRA weights are *not*
      applied.
    - It is currently not implemented to target multiple parameters on the same module. To achieve this, it is
      currently required to create a separate LoRA adapter (with another adapter name) and activate both at the
      same time.
    """

    def __init__(
        self,
        base_layer,
        adapter_name: str,
        parameter_name: str,
        config: LoraConfig,
        r: int = 0,
        lora_alpha: int = 1,
        is_target_conv_1d_layer: bool = False,
        **kwargs,
    ) -> None:
        """Wrap the parameter `parameter_name` of `base_layer` and create the adapter `adapter_name`.

        Raises:
            ValueError: If the config requests `lora_dropout`, `fan_in_fan_out`, `lora_bias` or `use_dora`, or if
                `is_target_conv_1d_layer` is True; none of these are supported by this layer.
        """
        # stored before the parent initializers run; `get_param` (used by `_get_in_out_features`) reads it
        # presumably LoraLayer.__init__ calls `_get_in_out_features`; verify against the parent class
        self.parameter_name = parameter_name
        super().__init__()
        LoraLayer.__init__(self, base_layer, **kwargs)

        if config.lora_dropout:
            # It's not possible to factor out x from lora_B(lora_A(dropout(x))), so dropout can't be correctly
            # implemented
            raise ValueError(f"lora.{self.__class__.__name__} does not work with lora_dropout != 0.")
        if config.fan_in_fan_out:
            raise ValueError(f"lora.{self.__class__.__name__} does not work with fan_in_fan_out.")
        if config.lora_bias:
            raise ValueError(f"lora.{self.__class__.__name__} does not work with lora_bias=True.")
        if config.use_dora:
            raise ValueError(f"lora.{self.__class__.__name__} does not work with use_dora=True.")
        if is_target_conv_1d_layer:
            raise ValueError(f"lora.{self.__class__.__name__} does not work with is_target_conv_1d_layer=True.")

        self.fan_in_fan_out = config.fan_in_fan_out
        self._did_swap_in_out_features = False  # ensure we swap only once
        self._active_adapter = adapter_name
        self.update_layer(
            adapter_name,
            r,
            lora_alpha=lora_alpha,
            config=config,
        )

    def _get_in_out_features(self, module: nn.Module) -> tuple[int, int] | tuple[None, None]:
        """Derive `in_features` and `out_features` from the targeted parameter and store `self.num_experts`.

        A 3d parameter is read as `(num_experts, in_features, out_features)`, a 2d parameter as
        `(out_features, in_features)` with a single expert. `module` is not used.

        Raises:
            ValueError: If the targeted parameter is neither 2d nor 3d.
        """
        # For ParamWrapper, we don't derive the in_features and out_features based on the base layer type, but directly
        # from the targeted parameter.
        param = self.get_param()
        # Validate the rank *before* indexing into the shape: otherwise a 0d or 1d parameter fails with an
        # IndexError on `param.shape[1]` instead of the descriptive ValueError below.
        if param.ndim not in (2, 3):
            raise ValueError(
                f"lora.{self.__class__.__name__} was initialized with {param.ndim} dimensional Parameter, but only 2d "
                "and 3d are supported."
            )

        if param.ndim == 3:
            num_experts, in_features, out_features = param.shape
        else:
            num_experts, in_features, out_features = 1, param.shape[1], param.shape[0]

        # we have to store the num_experts attribute here, as the parent class only stores in_features and out_features.
        self.num_experts = num_experts
        return in_features, out_features

    def update_layer(
        self,
        adapter_name: str,
        r: int,
        lora_alpha: int,
        config: LoraConfig,
        **kwargs,
    ) -> None:
        """Create the LoRA weights for `adapter_name` and initialize them according to `config`.

        Raises:
            ValueError: If `r` is not positive, if a LoRA variant is resolved from `config`, or if
                `config.lora_dropout` is greater than 0.
        """
        # same method as in lora.Linear but taking into account that there can be multiple experts (3d parameter)
        lora_dropout = config.lora_dropout
        init_lora_weights = config.init_lora_weights
        use_rslora = config.use_rslora
        lora_bias = config.lora_bias
        inference_mode = config.inference_mode

        # This code works for linear layers, override for other layer types
        if r <= 0:
            raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

        lora_variant = self.resolve_lora_variant(config=config)
        if lora_variant is not None:
            raise ValueError(f"lora.{self.__class__.__name__} does not work with LoRA variants like DoRA.")

        # for some MoE layers, the order is (experts, out_features, in_features)
        is_transposed = getattr(self.get_base_layer(), "is_transposed", False)
        swap_in_out_features = (self.get_param().ndim == 3) and not is_transposed
        if swap_in_out_features and not self._did_swap_in_out_features:
            self.in_features, self.out_features = self.out_features, self.in_features
            # update_layer may run once per adapter; the flag keeps later calls from swapping back
            self._did_swap_in_out_features = True

        self.r[adapter_name] = r
        self.lora_alpha[adapter_name] = lora_alpha
        if lora_dropout > 0.0:
            # It's not possible to factor out x from lora_B(lora_A(dropout(x))), so dropout can't be correctly
            # implemented
            raise ValueError(f"lora.{self.__class__.__name__} does not work with lora_dropout != 0.")
        else:
            lora_dropout_layer = nn.Identity()

        self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))

        # Actual trainable parameters
        # Difference to normal update_layer: consider experts. LoRA layers still use nn.Linear for consistency with
        # lora.Linear.
        self.lora_A[adapter_name] = nn.Linear(self.in_features, r * self.num_experts, bias=False)
        self.lora_B[adapter_name] = nn.Linear(r * self.num_experts, self.out_features, bias=lora_bias)
        self.lora_bias[adapter_name] = lora_bias

        if use_rslora:
            self.scaling[adapter_name] = lora_alpha / math.sqrt(r)
        else:
            self.scaling[adapter_name] = lora_alpha / r

        self.use_rslora[adapter_name] = use_rslora
        self.use_dora[adapter_name] = config.use_dora

        # for inits that require access to the base weight, use gather_param_ctx so that the weight is gathered when using DeepSpeed
        # NOTE(review): these branches access `.weight` of the base layer rather than `self.get_param()`; confirm
        # that this is intended when the targeted parameter has a different name.
        if isinstance(init_lora_weights, str) and init_lora_weights.startswith("pissa"):
            with gather_params_ctx(self.get_base_layer().weight):
                self.pissa_init(adapter_name, init_lora_weights)
        elif isinstance(init_lora_weights, str) and init_lora_weights.startswith("corda"):
            with gather_params_ctx(self.get_base_layer().weight):
                self.corda_init(adapter_name, init_lora_weights)
        elif isinstance(init_lora_weights, str) and init_lora_weights.lower() == "olora":
            with gather_params_ctx(self.get_base_layer().weight):
                self.olora_init(adapter_name)
        elif init_lora_weights == "loftq":
            with gather_params_ctx(self.get_base_layer().weight):
                self.loftq_init(adapter_name)
        elif init_lora_weights == "eva":
            nn.init.zeros_(self.lora_B[adapter_name].weight)
        elif init_lora_weights == "orthogonal":
            with gather_params_ctx(self.get_base_layer().weight):
                self.orthogonal_init(adapter_name)
        elif init_lora_weights == "lora_ga":
            with gather_params_ctx(self.get_base_layer().weight):
                self.lora_ga_init(adapter_name, config.lora_ga_config)
        elif init_lora_weights:
            self.reset_lora_parameters(adapter_name, init_lora_weights)

        # call this before init of the lora variants
        self._move_adapter_to_device_of_base_layer(adapter_name)

        if adapter_name in self.lora_variant:
            self.lora_variant[adapter_name].init(self, config=config, **kwargs)

        self.set_adapter(self.active_adapters, inference_mode=inference_mode)

    def _move_adapter_to_device_of_base_layer(self, adapter_name: str, device: Optional[torch.device] = None) -> None:
        """
        Move the adapter of the given name to the device of the base layer. Needs special handling for nn.Parameter

        The adapter is moved to the device of the targeted parameter and, if that parameter has a floating point or
        complex dtype, also cast to its dtype. Adapter containers that hold parameters on the meta device are
        skipped.

        Note that the `device` argument is currently ignored: it is always overwritten by the device of the targeted
        parameter.
        """
        device = self.get_param().device
        meta = torch.device("meta")
        param = self.get_param()

        for adapter_layer_name in self.adapter_layer_names + self.other_param_names:
            adapter_layer = getattr(self, adapter_layer_name, None)
            if not isinstance(adapter_layer, (nn.ModuleDict, nn.ParameterDict, BufferDict)):
                continue
            if adapter_name not in adapter_layer:
                continue
            if any(p.device == meta for p in adapter_layer.parameters()):
                continue

            if param.dtype.is_floating_point or param.dtype.is_complex:
                adapter_layer[adapter_name] = adapter_layer[adapter_name].to(device, dtype=param.dtype)
            else:
                adapter_layer[adapter_name] = adapter_layer[adapter_name].to(device)

    def get_param(self):
        """Return the targeted parameter, i.e. the attribute `self.parameter_name` of the base layer."""
        param = getattr(self.get_base_layer(), self.parameter_name)
        return param

    def get_delta_weight(self, adapter_name, *args, **kwargs):
        """Compute the delta weight of `adapter_name`, on the device of the targeted parameter.

        With a single expert, the computation is delegated to `Linear.get_delta_weight`. With multiple experts, the
        LoRA weights are reshaped per expert and combined with an einsum. The result is cast to the dtype of the
        targeted parameter only if that dtype is one of `ALLOWED_COMPUTE_DTYPES`.
        """
        if self.num_experts == 1:
            # could actually be a normal layer or experts stacked block-diagonally, acting like a single layer
            delta_weight = Linear.get_delta_weight(self, adapter_name, *args, **kwargs)
        else:
            weight_A = self.lora_A[adapter_name].weight
            weight_B = self.lora_B[adapter_name].weight
            # shape: experts x rank x in_features
            weight_A = weight_A.reshape(self.num_experts, -1, weight_A.shape[-1])
            # shape: out_features x rank x experts
            weight_B = weight_B.reshape(weight_B.shape[0], -1, self.num_experts)
            # fan_in_fan_out must be False, so no transpose call here
            if not self._did_swap_in_out_features:
                delta_weight = torch.einsum("o r e, e r i -> e i o", weight_B, weight_A) * self.scaling[adapter_name]
            else:
                # for some MoE layers, the order is (experts, out_features, in_features)
                delta_weight = torch.einsum("o r e, e r i -> e o i", weight_B, weight_A) * self.scaling[adapter_name]

        param = self.get_param()
        if param.dtype in ALLOWED_COMPUTE_DTYPES:
            delta_weight = delta_weight.to(param.device, param.dtype)
        else:
            # don't cast dW to weight dtype if it is in torch.float8_e4m3fn etc. as these low precision dtypes because
            # we want to perform the W+dW addition in high precision before downcasting, see _low_prec_add.
            delta_weight = delta_weight.to(param.device)
        return delta_weight

    @contextmanager
    def _activate_lora(self, active_adapters: list[str]):
        """Context manager that applies the summed delta weights of `active_adapters` to the targeted parameter.

        The delta is attached as a `_LoraParameterProxy` parametrization for the duration of the context and removed
        again on exit, also when an exception is raised inside the context. If none of the given adapters exists on
        this layer, the context does nothing.
        """
        if not active_adapters or not any(adapter in self.lora_A for adapter in active_adapters):
            # no active adapters for this layer
            yield
            return

        # accumulate the deltas of all active adapters that exist on this layer
        delta_weight = None
        for active_adapter in active_adapters:
            if active_adapter not in self.lora_A:
                continue
            if delta_weight is None:
                delta_weight = self.get_delta_weight(active_adapter)
            else:
                delta_weight = delta_weight + self.get_delta_weight(active_adapter)

        base_layer = self.get_base_layer()
        requires_grad_before = self.get_param().requires_grad
        nn.utils.parametrize.register_parametrization(
            base_layer, self.parameter_name, _LoraParameterProxy(delta_weight)
        )
        # set requires_grad, as it defaults to False
        base_layer.parametrizations[self.parameter_name].original.requires_grad_(requires_grad_before)

        try:
            with nn.utils.parametrize.cached():
                yield
        finally:
            self._remove_parametrizations()

    def _remove_parametrizations(self):
        """Remove the LoRA parametrization of the targeted parameter from the base layer.

        Raises:
            ValueError: If the targeted parameter has no parametrization at all.
        """
        # Remove the parametrization of this specific parameter
        base_layer = self.get_base_layer()
        parameter_name = self.parameter_name
        if parameter_name not in base_layer.parametrizations:
            raise ValueError(
                "Something went wrong, please report this issue on PEFT: https://github.com/huggingface/peft/issues"
            )

        param_list = base_layer.parametrizations[parameter_name]
        if len(param_list) == 1:
            # last parametrization, we can safely remove it completely
            nn.utils.parametrize.remove_parametrizations(base_layer, parameter_name, leave_parametrized=False)
            return

        # If there are multiple parametrizations for the same parameter_name, we only want to remove the LoRA proxy.
        # Unfortunately, PyTorch does not support this directly, so we need to take care of it manually. To achieve
        # this, we check the ParameterList from the back until we find the _LoraParameterProxy instance and then remove
        # it.
        reversed_indices = reversed(range(len(param_list)))
        for i in reversed_indices:
            module = param_list[i]
            if isinstance(module, _LoraParameterProxy):
                del param_list[i]
                break
        else:  # no break encountered
            # this should not happen, but raising an error is probably not necessary
            warnings.warn(
                f"Could not find any LoRA parametrization on {self}, please open an issue on "
                "https://github.com/huggingface/peft/issues and report this warning."
            )

    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
        """Add the delta weights of the given adapters to the targeted parameter in place.

        Args:
            safe_merge (`bool`, *optional*):
                If True, merge into a copy of the weight first and check the result with `torch.isfinite` before
                assigning it. Defaults to `False`.
            adapter_names (`List[str]`, *optional*):
                The adapters to merge; passed through `check_adapters_to_merge`. Defaults to `None`.

        Raises:
            ValueError: If `safe_merge=True` and the merged weight contains non-finite values.
        """
        # same as lora.Linear.merge but not hard-coding base_layer.weight and without special cases like variants removed
        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            # no adapter to merge
            return

        for active_adapter in adapter_names:
            if active_adapter in self.lora_A.keys():
                base_layer = self.get_base_layer()
                param = getattr(base_layer, self.parameter_name)
                if safe_merge:
                    # Note that safe_merge will be slower than the normal merge
                    # because of the copy operation.
                    orig_weight = param.data.clone()
                    orig_dtype = orig_weight.dtype
                    delta_weight = self.get_delta_weight(active_adapter)
                    orig_weight += delta_weight.to(orig_dtype)

                    if not torch.isfinite(orig_weight).all():
                        raise ValueError(
                            f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                        )

                    param.data = orig_weight
                else:
                    delta_weight = self.get_delta_weight(active_adapter)
                    param.data += delta_weight

                self.merged_adapters.append(active_adapter)

    def unmerge(self) -> None:
        """Subtract the delta weights of all merged adapters from the targeted parameter again.

        Only issues a warning if nothing is merged.
        """
        # same as lora.Linear.unmerge but not hard-coding base_layer.weight and without special cases like variants removed
        if not self.merged:
            warnings.warn("Already unmerged. Nothing to do.")
            return

        while len(self.merged_adapters) > 0:
            active_adapter = self.merged_adapters.pop()
            if active_adapter in self.lora_A.keys():
                param = getattr(self.get_base_layer(), self.parameter_name)
                orig_dtype = param.dtype
                delta_weight = self.get_delta_weight(active_adapter)
                param.data -= delta_weight.to(orig_dtype)

    def _check_forward_args(self, x, *args, **kwargs):
        """Check if the arguments are compatible with the configs and state of the model

        Raises:
            ValueError: If a truthy `adapter_names` keyword argument is passed.
        """
        if kwargs.get("adapter_names", None):
            raise ValueError(f"lora.{self.__class__.__name__} does not support mixed adapter batches yet.")
        super()._check_forward_args(x, *args, **kwargs)

    def unload_and_optionally_merge_module(self, merge: bool, safe_merge: bool, adapter_names: Optional[list[str]]):
        """Return the innermost base layer, merging this wrapper and all nested `ParamWrapper`s first if `merge`."""
        base_layer = self.base_layer
        # ParamWrappers can be nested, so merge and retrieve base layer recursively
        if merge:
            self.merge(safe_merge=safe_merge, adapter_names=adapter_names)
            while isinstance(base_layer, ParamWrapper):
                base_layer.merge(safe_merge=safe_merge, adapter_names=adapter_names)
                base_layer = base_layer.base_layer
        else:
            base_layer = self.get_base_layer()
        return base_layer

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Call the base layer, with the LoRA delta of the active adapters applied to the targeted parameter.

        If adapters are disabled, merged adapters are unmerged first and the base layer is called without LoRA. If
        adapters are merged, the base layer is called as is. Otherwise the delta is applied through
        `_activate_lora` for the duration of the base layer call.

        Raises:
            ValueError: If `adapter_names` is passed (mixed adapter batches are not supported).
        """
        self._check_forward_args(x, *args, **kwargs)
        # remove the key so that it is never passed on to the base layer
        adapter_names = kwargs.pop("adapter_names", None)

        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            result = self.base_layer(x, *args, **kwargs)
        elif adapter_names is not None:
            raise ValueError(f"lora.{self.__class__.__name__} does not support mixed batch inference")
        elif self.merged:
            result = self.base_layer(x, *args, **kwargs)
        else:
            with self._activate_lora(self.active_adapters):
                result = self.base_layer(x, *args, **kwargs)
        return result

    def __repr__(self) -> str:
        """Return the parent repr with a `lora.` prefix and the targeted parameter name inserted."""
        rep = super().__repr__()
        idx = rep.find("(") + 1
        # insert the name of the parameter to allow the repr to be disambiguous when multiple parameters on the same
        # module are being targeted
        rep = f"{rep[:idx]}\n parameter_name='{self.parameter_name}',{rep[idx:]}"
        return "lora." + rep
def dispatch_default(
    target: torch.nn.Module,
    adapter_name: str,
    config: LoraConfig,
    parameter_name: Optional[str] = None,
    **kwargs,
) -> Optional[torch.nn.Module]:
    """Create the LoRA layer that matches the type of `target`.

    If `parameter_name` is given, a `ParamWrapper` is created regardless of the layer type. Otherwise the type of
    the base layer of `target` (or of `target` itself if it is not a `BaseTunerLayer`) selects the LoRA layer
    class. For `torch.nn.Linear` and `Conv1D` targets, `config.fan_in_fan_out` is corrected in place, with a
    warning, if it does not fit the layer type.

    Returns:
        The new LoRA module, or None if no layer type matched.
    """
    target_base_layer = target.get_base_layer() if isinstance(target, BaseTunerLayer) else target

    # targeting a parameter directly takes precedence over the layer type
    if parameter_name is not None:
        return ParamWrapper(target, adapter_name, parameter_name=parameter_name, config=config, **kwargs)

    if isinstance(target_base_layer, torch.nn.Embedding):
        return Embedding(target, adapter_name, config=config, **kwargs)

    if isinstance(target_base_layer, torch.nn.Conv2d):
        return Conv2d(target, adapter_name, config=config, **kwargs)

    if isinstance(target_base_layer, torch.nn.Conv3d):
        return Conv3d(target, adapter_name, config=config, **kwargs)

    if isinstance(target_base_layer, nn.Conv1d):
        return Conv1d(target, adapter_name, config=config, **kwargs)

    if isinstance(target_base_layer, torch.nn.MultiheadAttention):
        return MultiheadAttention(target, adapter_name, config=config, **kwargs)

    if isinstance(target_base_layer, torch.nn.Linear):
        if config.fan_in_fan_out:
            warnings.warn(
                "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
                "Setting fan_in_fan_out to False."
            )
            config.fan_in_fan_out = False
        return Linear(target, adapter_name, config=config, **kwargs)

    if isinstance(target_base_layer, Conv1D):
        if not config.fan_in_fan_out:
            warnings.warn(
                "fan_in_fan_out is set to False but the target module is `Conv1D`. Setting fan_in_fan_out to True."
            )
            config.fan_in_fan_out = True
        return Linear(target, adapter_name, is_target_conv_1d_layer=True, config=config, **kwargs)

    # no supported layer type matched
    return None