mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 19:15:35 +08:00
fdb4c77d6c
- Fetch canonical reference impls for offline review:
* peft_{lora,hra,delora,ia3}_layer.py + peft_lora_{dora,variants}.py
* orig_pissa_init.py (MuLabPKU/PiSSA)
* orig_hra_layer.py (DaShenZi721/HRA)
* orig_delora.py (ExplainableML/DeLoRA author fork)
- Add reference-impl URLs to all 6 variant docstrings
- Document HRA gate=0 dead-grad issue and DoRA detach-omission in their docstrings
- Re-run external review (codex) with refs available -> docs/audit/variants_review_v2.md
Major NEW findings vs paper-only review:
* DeLoRA: scalar W.norm() should be per-input-channel norm(dim=0)
* HRA: PEFT uses symmetric repeated-column init (no dead grad), not zero gate
* IA3: FFN targets need input-side gating, not output-side gating; our up_proj advice was wrong
* All LoRA-family: cfg.dropout silently ignored (no-op)
* DeLoRA: wnorm should be persistent buffer, not Parameter
HRA and DeLoRA upgraded to BUGGY (from Partial)
2511 lines
111 KiB
Python
2511 lines
111 KiB
Python
# Copyright 2023-present the HuggingFace Inc. team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from __future__ import annotations
|
|
|
|
import copy
|
|
import math
|
|
import warnings
|
|
from collections.abc import Callable
|
|
from contextlib import contextmanager
|
|
from typing import Any, Optional, Union
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch import svd_lowrank
|
|
from transformers.pytorch_utils import Conv1D
|
|
|
|
from peft.import_utils import is_transformers_ge_v5_4_0
|
|
from peft.tuners._buffer_dict import BufferDict
|
|
from peft.tuners.tuners_utils import BaseTunerLayer, _get_in_out_features, check_adapters_to_merge
|
|
from peft.utils import ALLOWED_COMPUTE_DTYPES, UPCAST_DTYPES
|
|
from peft.utils.integrations import (
|
|
dequantize_module_weight,
|
|
gather_params_ctx,
|
|
get_bnb_param_type,
|
|
skip_init_on_device,
|
|
)
|
|
from peft.utils.loftq_utils import loftq_init
|
|
from peft.utils.other import transpose
|
|
from peft.utils.warning import PeftWarning
|
|
|
|
from .config import LoraConfig
|
|
|
|
|
|
# Keyword arguments that are meant for LoRA variants only (currently aLoRA's per-sample offsets).
# `_mixed_batch_forward` pops these out of ``**kwargs`` so that they are never forwarded to the base layer.
VARIANT_KWARG_KEYS = [
    "alora_offsets",
]
|
|
|
|
|
|
class LoraVariant:
    """
    Base class for LoRA variants, e.g. DoRA.

    This class should be subclassed and the methods below should be implemented accordingly. The methods should be
    implemented as static methods, this makes it easier to combine variants.

    Every hook raises `NotImplementedError` by default, so a variant that forgets to override one fails loudly
    instead of silently doing nothing (e.g. an `unmerge` that returns `None` instead of a tensor).

    Note for developers: These methods are prone to change and should thus be considered "private". Use at your own
    discretion.
    """

    @staticmethod
    def init(module: LoraLayer, adapter_name: str, **kwargs: Any) -> None:
        """Initialization code for the LoRA variant, it's called within `update_layer`.

        `update_layer` invokes it as ``init(module, adapter_name=..., config=..., **kwargs)``, so implementations
        must accept (and may ignore) additional keyword arguments.
        """
        raise NotImplementedError

    @staticmethod
    def merge_safe(module: LoraLayer, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
        """Safe merging of the weights from `merge(..., safe_merge=True)`, should return a new tensor"""
        raise NotImplementedError

    @staticmethod
    def merge_unsafe(module: LoraLayer, active_adapter: str, orig_weight: torch.Tensor) -> None:
        """Unsafe merging of the weights from `merge(..., safe_merge=False)`, should modify the weight in-place"""
        raise NotImplementedError

    @staticmethod
    def unmerge(module: LoraLayer, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
        """Remove the adapter weights from the original weights, then return them"""
        raise NotImplementedError

    @staticmethod
    def forward(
        module: LoraLayer,
        active_adapter: str,
        x: torch.Tensor,
        result: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """
        The forward pass of the LoRA variant, should return the overall result (not just the diff)

        Args:
            module (LoraLayer): The module on which the forward pass is called
            active_adapter (str): The name of the active adapter
            x (torch.Tensor): The input to the forward call
            result (torch.Tensor): The result from the base model
            **kwargs: Additional arguments passed from [`LoraLayer.forward`].
        """
        raise NotImplementedError
|
|
|
|
|
|
class LoraLayer(BaseTunerLayer):
    """Tuner layer that wraps a `base_layer` and keeps per-adapter LoRA state (A/B factors, scaling, dropout and
    optional LoRA variants), keyed by adapter name."""

    # All names of layers that may contain (trainable) adapter weights
    adapter_layer_names: tuple[str, ...] = ("lora_A", "lora_B", "lora_embedding_A", "lora_embedding_B")
    # All names of other parameters that may contain adapter-related parameters
    other_param_names: tuple[str, ...] = ("r", "lora_alpha", "scaling", "lora_dropout")
|
|
|
|
def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, **kwargs) -> None:
    """Set up empty per-adapter LoRA containers around ``base_layer``.

    Adapters are added later (see ``update_layer``); here only the bookkeeping dicts are created and the
    in/out feature sizes of the (unwrapped) base layer are recorded.

    Args:
        base_layer: The wrapped module.
        ephemeral_gpu_offload: Stored unchanged as ``self.ephemeral_gpu_offload``.
        **kwargs: Stored unchanged as ``self.kwargs``.
    """
    self.base_layer = base_layer
    # Per-adapter hyperparameters, keyed by adapter name.
    self.r = {}
    self.lora_alpha = {}
    self.scaling = {}
    # Per-adapter modules, keyed by adapter name.
    self.lora_dropout = nn.ModuleDict({})
    self.lora_A = nn.ModuleDict({})
    self.lora_B = nn.ModuleDict({})
    # For Embedding layer
    self.lora_embedding_A = nn.ParameterDict({})
    self.lora_embedding_B = nn.ParameterDict({})
    # Adapters start enabled; mark the weight as unmerged
    self._disable_adapters = False
    self.merged_adapters = []
    # not actively used anymore after #2443, keep it for BC (still written by `update_layer` and read by
    # `_check_forward_args`)
    self.use_dora: dict[str, bool] = {}  # not actively used anymore after #2443, keep it for BC
    self.use_rslora: dict[str, bool] = {}
    self.lora_bias: dict[str, bool] = {}
    self.lora_magnitude_vector = torch.nn.ModuleDict()  # for DoRA
    self._caches: dict[str, Any] = {}  # small ad hoc cache; values are not part of the state_dict
    self.ephemeral_gpu_offload: bool = ephemeral_gpu_offload
    # flag to enable/disable casting of input to weight dtype during forward call
    self.cast_input_dtype_enabled: bool = True
    # adapter name -> LoraVariant (only for adapters that use a variant such as DoRA/aLoRA)
    self.lora_variant: dict[str, LoraVariant] = {}
    self.kwargs = kwargs

    # NOTE(review): `get_base_layer` presumably reads `self.base_layer` (assigned above) and may unwrap nested
    # tuner layers — confirm in BaseTunerLayer.
    base_layer = self.get_base_layer()
    in_features, out_features = self._get_in_out_features(base_layer)
    self.in_features = in_features
    self.out_features = out_features
|
|
|
|
def _get_in_out_features(self, module: nn.Module) -> tuple[int, int] | tuple[None, None]:
    """Report the feature sizes of ``module`` by delegating to the shared tuner helper of the same name.

    Kept as an instance method so that ``__init__`` can call it through ``self``.

    Returns:
        ``(in_features, out_features)``, or ``(None, None)`` when they cannot be determined.
    """
    features = _get_in_out_features(module)
    return features
|
|
|
|
def resolve_lora_variant(self, *, config: LoraConfig, **kwargs) -> Optional[LoraVariant]:
    """Pick the LoRA variant that this layer type should use for ``config``.

    Layer subclasses override this hook. For example, they return the DoRA variant when ``use_dora=True``, or
    the aLoRA variant when ``use_alora=True``. This base implementation knows no variants, so it always answers
    ``None``, meaning plain LoRA.

    Note: a layer type that cannot support some variant at all should reject it in ``__init__`` (the usual
    convention), not here.
    """
    no_variant = None
    return no_variant
|
|
|
|
def update_layer(
    self,
    adapter_name: str,
    r: int,
    lora_alpha: int,
    config: LoraConfig,
    **kwargs,
) -> None:
    """Create the adapter ``adapter_name`` on this layer and initialize its weights.

    Args:
        adapter_name: Key under which all per-adapter state is stored.
        r: LoRA rank; must be > 0.
        lora_alpha: Numerator of the scaling factor (``lora_alpha / r``, or ``lora_alpha / sqrt(r)`` with rsLoRA).
        config: Source of dropout, init scheme, rsLoRA/DoRA/bias flags and inference mode.
        **kwargs: Forwarded to the variant's ``init``. ``target_name`` and ``tied_adapter`` are read here.

    Raises:
        ValueError: If ``r`` is not positive.
    """
    # collect the kwargs
    lora_dropout = config.lora_dropout
    init_lora_weights = config.init_lora_weights
    use_rslora = config.use_rslora
    lora_bias = config.lora_bias
    inference_mode = config.inference_mode

    # guarantee that `target_name` is always present in the kwargs forwarded to the variant's `init`
    target_name = kwargs.get("target_name", "")  # preserve target_name before overwriting kwargs
    kwargs["target_name"] = target_name  # restore target_name
    tied_adapter = kwargs.get("tied_adapter", None)

    # This code works for linear layers, override for other layer types
    if r <= 0:
        raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

    if lora_bias and (getattr(self.get_base_layer(), "bias", None) is None):
        warnings.warn(
            f"`lora_bias=True` was passed but the targeted layer of type {type(self.get_base_layer()).__name__} "
            "has no bias. This means that merging LoRA weights won't be possible.",
            PeftWarning,
        )

    # NOTE(review): only `config` is passed here; `**kwargs` are not forwarded to `resolve_lora_variant`.
    lora_variant = self.resolve_lora_variant(config=config)
    if lora_variant is not None:
        self.lora_variant[adapter_name] = lora_variant

    self.r[adapter_name] = r
    self.lora_alpha[adapter_name] = lora_alpha
    if lora_dropout > 0.0:
        lora_dropout_layer = nn.Dropout(p=lora_dropout)
    else:
        lora_dropout_layer = nn.Identity()

    self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))

    # Actual trainable parameters
    self.lora_A[adapter_name] = nn.Linear(self.in_features, r, bias=False)
    self.lora_B[adapter_name] = nn.Linear(r, self.out_features, bias=lora_bias)

    # Tying adapters is only implemented for Linear layers
    # where the source is the embedding layer.
    # Currently, this is the most prevalent way of tying layers (weight tying)
    if tied_adapter:
        lora_A_params = tied_adapter["lora_A"]
        lora_B_params = tied_adapter["lora_B"]

        self.lora_A[adapter_name].weight = torch.nn.Parameter(lora_A_params)
        self.lora_B[adapter_name].weight = torch.nn.Parameter(lora_B_params)

    self.lora_bias[adapter_name] = lora_bias

    if use_rslora:
        self.scaling[adapter_name] = lora_alpha / math.sqrt(r)
    else:
        self.scaling[adapter_name] = lora_alpha / r

    self.use_rslora[adapter_name] = use_rslora

    self.use_dora[adapter_name] = config.use_dora

    # for inits that require access to the base weight, use gather_param_ctx so that the weight is gathered when using DeepSpeed
    if isinstance(init_lora_weights, str) and init_lora_weights.startswith("pissa"):
        with gather_params_ctx(self.get_base_layer().weight):
            self.pissa_init(adapter_name, init_lora_weights)
    elif isinstance(init_lora_weights, str) and init_lora_weights.startswith("corda"):
        with gather_params_ctx(self.get_base_layer().weight):
            self.corda_init(adapter_name, init_lora_weights)
    elif isinstance(init_lora_weights, str) and init_lora_weights.lower() == "olora":
        with gather_params_ctx(self.get_base_layer().weight):
            self.olora_init(adapter_name)
    elif init_lora_weights == "loftq":
        with gather_params_ctx(self.get_base_layer().weight):
            self.loftq_init(adapter_name, config)
    elif init_lora_weights == "eva":
        # only B is zeroed here; A keeps nn.Linear's default init.
        # NOTE(review): presumably A is filled in later by the EVA procedure — confirm.
        nn.init.zeros_(self.lora_B[adapter_name].weight)
    elif init_lora_weights == "orthogonal":
        with gather_params_ctx(self.get_base_layer().weight):
            self.orthogonal_init(adapter_name)
    elif init_lora_weights == "lora_ga":
        with gather_params_ctx(self.get_base_layer().weight):
            self.lora_ga_init(adapter_name, config.lora_ga_config)
    elif init_lora_weights:
        # True / "gaussian" (any falsy value skips initialization entirely)
        self.reset_lora_parameters(adapter_name, init_lora_weights)
    # call this before init of the lora variants
    self._move_adapter_to_device_of_base_layer(adapter_name)

    if adapter_name in self.lora_variant:
        self.lora_variant[adapter_name].init(self, adapter_name=adapter_name, config=config, **kwargs)

    # re-apply the currently active adapters, passing `inference_mode` through
    self.set_adapter(self.active_adapters, inference_mode=inference_mode)

    # Check for adapters that were added or removed from the arrow_model.
    # The arrow model may be modified after creation by adding new experts
    # (pre-trained or trainable) or by removing existing ones. Whenever such
    # a change occurs, on_adapter_change() is called to update the set of
    # active task-specific experts and, if needed, to handle recomputing prototypes
    # and doing general knowledge subtraction (GKS) again.
    if hasattr(self, "lora_arrow"):
        for adapter in self.lora_variant:
            if adapter in self.lora_arrow:
                self.lora_arrow[adapter].on_adapter_change(self.lora_A, self.lora_B)
|
|
|
|
def reset_lora_parameters(self, adapter_name, init_lora_weights):
    """(Re)initialize the factors of ``adapter_name``, then sync replicated factors across tensor-parallel ranks.

    For Linear-style adapters (present in ``lora_A``), the init depends on ``init_lora_weights``:
    - ``True``: Kaiming-uniform A (the nn.Linear default).
    - ``"gaussian"`` (case-insensitive): A drawn from a normal distribution with std ``1 / r``.
    - Any other string: raises ValueError.

    In the first two cases B is zeroed, and so is B's bias when ``lora_bias`` is set. Embedding adapters get a
    zero A and a standard-normal B. With ``init_lora_weights=False`` nothing is re-initialized, but the
    distributed synchronization below still runs.

    Raises:
        ValueError: For an unknown ``init_lora_weights`` string on a Linear-style adapter.
    """
    if init_lora_weights is not False:
        if adapter_name in self.lora_A.keys():
            if init_lora_weights is True:
                # initialize A the same way as the default for nn.Linear and B to zero
                # https://github.com/microsoft/LoRA/blob/a0a92e0f26c067cf94747bdbf1ce73793fa44d19/loralib/layers.py#L124
                nn.init.kaiming_uniform_(self.lora_A[adapter_name].weight, a=math.sqrt(5))
            elif init_lora_weights.lower() == "gaussian":
                nn.init.normal_(self.lora_A[adapter_name].weight, std=1 / self.r[adapter_name])
            else:
                raise ValueError(f"Unknown initialization {init_lora_weights=}")
            # B = 0 makes the adapter's initial contribution zero
            nn.init.zeros_(self.lora_B[adapter_name].weight)
            if self.lora_bias[adapter_name]:
                nn.init.zeros_(self.lora_B[adapter_name].bias)
        if adapter_name in self.lora_embedding_A.keys():
            # Initialize A to zeros and B the same way as the default for nn.Embedding, see:
            # https://github.com/microsoft/LoRA/blob/4c0333854cb905966f8cc4e9a74068c1e507c7b7/loralib/layers.py#L59-L60
            nn.init.zeros_(self.lora_embedding_A[adapter_name])
            nn.init.normal_(self.lora_embedding_B[adapter_name])
            if self.lora_bias[adapter_name]:
                # embeddings are not supported at the moment, but still adding this for consistency
                # NOTE(review): ParameterDict entries are bare Parameters without a `.bias`; this line would raise
                # AttributeError if it were ever reached.
                nn.init.zeros_(self.lora_embedding_B[adapter_name].bias)

    # Always synchronize non-sharded LoRA weights across TP ranks, regardless of
    # init_lora_weights, since each rank initializes weights independently.
    # We could skip some broadcast, for instance when the lora weights are initialized to zero,
    # but this is a minor optimization and would add extra complexity to the code.
    if dist.is_available() and dist.is_initialized():
        base_layer = self.get_base_layer()
        tp_plan = getattr(base_layer, "_hf_tp_plan", None)
        device_mesh = getattr(base_layer, "_hf_device_mesh", None)
        if device_mesh is not None:
            if tp_plan not in ["colwise", "rowwise", "embedding_rowwise"]:
                warnings.warn(
                    f'tp_plan "{tp_plan}" found on the base layer is not supported for LoRA weight '
                    'synchronization. Expected one of "colwise", "rowwise", "embedding_rowwise". '
                    "LoRA weights may not be synchronized across ranks."
                )
            pg = device_mesh.get_group()
            # broadcast from the global rank that is rank 0 within the TP process group
            src = dist.get_global_rank(pg, 0)
            if tp_plan == "colwise" and adapter_name in self.lora_A:
                # The adapter weights need to be on device for broadcasting
                self._move_adapter_to_device_of_base_layer(adapter_name)
                dist.broadcast(self.lora_A[adapter_name].weight.data, src=src, group=pg)
            elif tp_plan == "rowwise" and adapter_name in self.lora_B:
                # The adapter weights need to be on device for broadcasting
                self._move_adapter_to_device_of_base_layer(adapter_name)
                dist.broadcast(self.lora_B[adapter_name].weight.data, src=src, group=pg)
            elif tp_plan == "embedding_rowwise" and adapter_name in self.lora_embedding_B:
                self._move_adapter_to_device_of_base_layer(adapter_name)
                dist.broadcast(self.lora_embedding_B[adapter_name].data, src=src, group=pg)
|
|
|
|
def olora_init(self, adapter_name):
    """Initialize ``adapter_name`` with OLoRA (QR decomposition of the base weight).

    The base weight W, viewed as ``(out_features, in_features)``, is factored as W = QR. The first ``r`` columns
    of Q become ``lora_B`` and the first ``r`` rows of R become ``lora_A``. Then ``scaling * B @ A`` is subtracted
    from the base weight, and the residual is written back to the base layer. For bitsandbytes 4-bit/8-bit
    weights, the residual is re-wrapped in the original parameter class.

    Raises:
        TypeError: If a non-bitsandbytes base weight is not float32, float16 or bfloat16.
    """
    base_layer = self.get_base_layer()
    orig_weight = base_layer.weight
    bnb_param_type = get_bnb_param_type(orig_weight)
    dtype = orig_weight.dtype
    # Conv1D-style layers (fan_in_fan_out=True) store the weight as (in_features, out_features), while the LoRA
    # factors are laid out for (out_features, in_features). Decompose the transposed view and transpose back
    # afterwards, like `pissa_init`. Layers without the attribute are treated as (out, in).
    fan_in_fan_out = getattr(self, "fan_in_fan_out", False)

    if bnb_param_type:
        # check without importing bitsandbytes and robust to bnb_4bit_quant_storage=float*
        weight_tensor = dequantize_module_weight(base_layer)
    elif dtype in [torch.float32, torch.float16, torch.bfloat16]:
        weight_tensor = orig_weight
    else:
        raise TypeError(f"Unsupported data type for the base layer. Got {dtype}.")

    scale_factor = self.scaling[adapter_name]
    r = self.r[adapter_name]
    weight_tensor = transpose(weight_tensor.to(torch.float32), fan_in_fan_out)
    Q, R = torch.linalg.qr(weight_tensor.data)

    # rank-r factors: leading columns of Q (out x r) and leading rows of R (r x in)
    Qr, Rr = Q[:, :r], R[:r]

    self.lora_A[adapter_name].weight.data = Rr.contiguous()
    self.lora_B[adapter_name].weight.data = Qr.contiguous()

    # residual base weight so that base + scaling * B @ A equals W at initialization
    weight_tensor.data -= scale_factor * self.lora_B[adapter_name].weight @ self.lora_A[adapter_name].weight
    # back to the base layer's storage layout (no-op unless fan_in_fan_out)
    weight_tensor = transpose(weight_tensor, fan_in_fan_out)
    if bnb_param_type == "4bit":
        weight_tensor = orig_weight.__class__(
            weight_tensor,
            quant_type=orig_weight.quant_type,
            quant_storage=orig_weight.quant_storage,
            compress_statistics=orig_weight.compress_statistics,
            module=orig_weight.module,
        ).to(orig_weight.device)
        base_layer.weight = weight_tensor
    elif bnb_param_type == "8bit":
        weight_tensor = orig_weight.__class__(
            weight_tensor,
            requires_grad=orig_weight.requires_grad,
            has_fp16_weights=orig_weight.has_fp16_weights,
        ).to(orig_weight.device)
        base_layer.weight = weight_tensor
    else:
        weight_tensor = weight_tensor.to(dtype)
        base_layer.weight.data = weight_tensor
|
|
|
|
def pissa_init(self, adapter_name, init_lora_weights):
    """Initialize ``adapter_name`` with PiSSA (principal singular values/vectors of the base weight).

    ``init_lora_weights`` selects the decomposition:
    - ``"pissa"``: exact SVD.
    - ``"pissa_niter_<k>"``: randomized ``svd_lowrank`` with ``k`` subspace iterations.

    The top-``r`` singular values are divided by ``scaling`` and split symmetrically as ``A = sqrt(S) V^T`` and
    ``B = U sqrt(S)``. ``scaling * B @ A`` is then subtracted from the base weight, so the layer output is
    unchanged at init.

    Raises:
        TypeError: If the base weight is not float32, float16 or bfloat16.
        ValueError: If ``init_lora_weights`` is not ``"pissa"`` or ``"pissa_niter_<int>"``.
    """
    weight = self.get_base_layer().weight
    dtype = weight.dtype
    if dtype not in [torch.float32, torch.float16, torch.bfloat16]:
        raise TypeError(
            "Please initialize PiSSA under float32, float16, or bfloat16. "
            "Subsequently, re-quantize the residual model to help minimize quantization errors."
        )
    weight = transpose(weight.to(torch.float32), self.fan_in_fan_out)
    r = self.r[adapter_name]
    invalid_value_msg = (
        f"init_lora_weights should be 'pissa' or 'pissa_niter_[number of iters]', got {init_lora_weights} instead."
    )
    if init_lora_weights == "pissa":
        # USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel},
        V, S, Uh = torch.linalg.svd(weight.data, full_matrices=False)
        Vr = V[:, :r]
        Sr = S[:r]
        Sr /= self.scaling[adapter_name]
        Uhr = Uh[:r]
    elif len(init_lora_weights.split("_niter_")) == 2:
        niter_str = init_lora_weights.split("_niter_")[-1]
        try:
            niter = int(niter_str)
        except ValueError as err:
            # e.g. "pissa_niter_" or "pissa_niter_abc": report the expected format instead of int()'s message
            raise ValueError(invalid_value_msg) from err
        # svd_lowrank returns (U, S, V) with weight ~= U @ diag(S) @ V^T
        Vr, Sr, Ur = svd_lowrank(weight.data, r, niter=niter)
        Sr /= self.scaling[adapter_name]
        Uhr = Ur.t()
    else:
        raise ValueError(invalid_value_msg)

    lora_A = torch.diag(torch.sqrt(Sr)) @ Uhr
    lora_B = Vr @ torch.diag(torch.sqrt(Sr))
    self.lora_A[adapter_name].weight.data = lora_A
    self.lora_B[adapter_name].weight.data = lora_B
    weight = weight.data - self.scaling[adapter_name] * lora_B @ lora_A
    weight = transpose(weight.to(dtype), self.fan_in_fan_out)
    self.get_base_layer().weight.data = weight
|
|
|
|
def corda_init(self, adapter_name, init_lora_weights):
    """Initialize ``adapter_name`` with CorDA, using the decomposition precomputed by ``preprocess_corda``.

    The factors ``U_WC``, ``S_WC`` and ``V_WC`` are read from the ``eigens`` attribute of the base layer and
    validated: all values finite, shapes ``(out, r)``, ``(r,)`` and ``(in, r)``. Then:
    - ``A = sqrt(S') V^T`` and ``B = U sqrt(S')``, where ``S' = S / scaling``.
    - ``scaling * B @ A`` is subtracted from the base weight.
    - ``eigens`` is deleted from the layer.

    ``init_lora_weights`` is accepted for call-site symmetry but not read.

    Raises:
        TypeError: If the base weight is not float32, float16 or bfloat16.
        ValueError: If ``eigens`` is missing, contains NaN/Inf, or has shapes inconsistent with the layer/rank.
    """
    linear = self.get_base_layer()
    weight = linear.weight
    dtype = weight.dtype
    if dtype not in [torch.float32, torch.float16, torch.bfloat16]:
        raise TypeError(
            "Please initialize CorDA under float32, float16, or bfloat16. "
            "Subsequently, re-quantize the residual model to help minimize quantization errors."
        )
    weight = weight.to(torch.float32)
    # For Conv1D, weight is stored as (in_features, out_features), transposed compared to Linear
    if isinstance(linear, Conv1D):
        out_dim = weight.data.size(1)
        in_dim = weight.data.size(0)
    else:
        out_dim = weight.data.size(0)
        in_dim = weight.data.size(1)

    # Calculate WC from covariance matrix
    if not hasattr(linear, "eigens"):
        raise ValueError(
            "`eigens` attribute not found for layer, please run `preprocess_corda` first. "
            "More information can be found at examples/corda_finetuning/README.md."
        )
    eigens = linear.eigens
    U = eigens.U_WC
    S = eigens.S_WC
    V = eigens.V_WC
    r = self.r[adapter_name]

    # nan or inf check (isfinite is False exactly for NaN and +/-Inf)
    for matrix_name, matrix in (("S", S), ("U", U), ("V", V)):
        if not torch.isfinite(matrix).all():
            raise ValueError(
                f"Invalid value found in matrix {matrix_name}. Please file an issue at "
                "https://github.com/huggingface/peft/issues."
            )

    # Sanity check
    if U.size(0) != out_dim or U.size(1) != r:
        raise ValueError(
            f"Matrix U size mismatch: {U.size()} vs. ({out_dim}, {r}). Please make sure the `lora_config` and "
            "`model` argument of `preprocess_corda` is consistent with `get_peft_model`. If you're using cache "
            "in `preprocess_corda`, please make sure the cache is built with the same model and LoRA rank."
        )
    if S.size(0) != r:
        raise ValueError(
            f"Matrix S size mismatch: {S.size()} vs. ({r},). Please make sure the `lora_config` and `model` argument "
            "of `preprocess_corda` is consistent with `get_peft_model`. If you're using cache in `preprocess_corda`, "
            "please make sure the cache is built with the same model and LoRA rank."
        )
    if V.size(0) != in_dim or V.size(1) != r:
        raise ValueError(
            f"Matrix V size mismatch: {V.size()} vs. ({in_dim}, {r}). Please make sure the `lora_config` and "
            "`model` argument of `preprocess_corda` is consistent with `get_peft_model`. If you're using cache "
            "in `preprocess_corda`, please make sure the cache is built with the same model and LoRA rank."
        )

    # Apply alpha. Out-of-place on purpose: `S` is the very tensor stored in `eigens.S_WC`, and an in-place
    # division would also alter it for any other holder of that object.
    S = S / self.scaling[adapter_name]

    # Init lora_A and lora_B weights
    lora_A = V.t().mul(S.sqrt().view(-1, 1)).contiguous()
    lora_B = U.mul(S.sqrt()).contiguous()
    self.lora_A[adapter_name].weight.data = lora_A
    self.lora_B[adapter_name].weight.data = lora_B

    # For Conv1D, lora_B @ lora_A gives (out_dim, in_dim) but weight is (in_dim, out_dim)
    # So we need to transpose before subtraction
    delta = self.scaling[adapter_name] * lora_B @ lora_A
    delta = transpose(delta, fan_in_fan_out=self.fan_in_fan_out)
    weight = weight.data - delta
    weight = weight.to(dtype)
    self.get_base_layer().weight.data = weight

    # Remove redundant fields
    del linear.eigens
|
|
|
|
def loftq_init(self, adapter_name, config: LoraConfig):
    """Initialize ``adapter_name`` with LoftQ and replace the base weight with the weight LoftQ returns.

    The module-level ``loftq_init`` helper is called on the base weight with these arguments:
    - ``num_bits`` from ``config.loftq_config["loftq_bits"]``;
    - ``num_iter`` from ``config.loftq_config["loftq_iter"]``;
    - ``reduced_rank`` set to the adapter rank.

    It returns ``(qweight, lora_A, lora_B)``. ``lora_A`` and ``lora_B`` are laid out for a Linear adapter: they
    are assigned directly to ``lora_A.weight`` and ``lora_B.weight``.
    """
    weight = self.get_base_layer().weight
    kwargs = {
        "num_bits": config.loftq_config["loftq_bits"],
        "reduced_rank": self.r[adapter_name],
        "num_iter": config.loftq_config["loftq_iter"],
    }

    qweight, lora_A, lora_B = loftq_init(weight, **kwargs)
    if adapter_name in self.lora_A.keys():
        # Linear-style adapter: the LoftQ low-rank factors replace the A/B weights as-is
        self.lora_A[adapter_name].weight.data = lora_A
        self.lora_B[adapter_name].weight.data = lora_B
    if adapter_name in self.lora_embedding_A.keys():
        # Embedding adapters live in ParameterDicts, so their entries are bare Parameters without a `.weight`.
        # They are also laid out transposed relative to Linear.
        # NOTE(review): this assumes lora_embedding_A is (r, num_embeddings), lora_embedding_B is
        # (embedding_dim, r), and the embedding delta is (B @ A).T. Under that layout the Linear factors must be
        # swapped and transposed. Confirm against the Embedding layer's `get_delta_weight`.
        self.lora_embedding_A[adapter_name].data = lora_B.T.contiguous()
        self.lora_embedding_B[adapter_name].data = lora_A.T.contiguous()
    self.get_base_layer().weight.data = qweight
|
|
|
|
@torch.no_grad()
def orthogonal_init(self, adapter_name):
    """Initialize ``adapter_name`` so that both A and B are non-zero while ``B @ A == 0``.

    See https://datta0.github.io/posts/rethink-lora-init/#orthogonal-initialisation.

    A random ``r x r`` orthogonal matrix Q is split into two row sets:
    - rows at even indices (0, 2, ...) build A;
    - rows at odd indices (1, 3, ...) build B.

    These two row sets are mutually orthogonal, so ``B @ A`` vanishes at initialization.

    The new factors take the base weight's dtype when that dtype is floating point or complex. Otherwise, e.g.
    for packed integer quantized storage, they keep the current dtype of ``lora_A``.

    Raises:
        ValueError: If the rank is odd.
    """
    rank = self.r[adapter_name]
    if rank % 2 != 0:
        raise ValueError(f"Orthogonal initialization requires the LoRA rank to be even, got {rank} instead.")

    X = torch.randn(rank, rank)
    Q, _ = torch.linalg.qr(X)
    q_rows_even = Q[0::2, :]  # rows 0, 2, 4, ...
    q_rows_odd = Q[1::2, :]  # rows 1, 3, 5, ...
    base_dtype = self.get_base_layer().weight.dtype
    if base_dtype.is_floating_point or base_dtype.is_complex:
        dtype = base_dtype
    else:
        # casting to an integer storage dtype would destroy the values and make nn.Parameter fail
        dtype = self.lora_A[adapter_name].weight.dtype
    lora_A = torch.randn(self.in_features, rank // 2).mm(q_rows_even).T / 10.0
    lora_B = torch.randn(rank // 2, self.out_features).T.mm(q_rows_odd) / 10.0
    self.lora_A[adapter_name].weight = nn.Parameter(lora_A.contiguous().to(dtype))
    self.lora_B[adapter_name].weight = nn.Parameter(lora_B.contiguous().to(dtype))
|
|
|
|
def lora_ga_init(self, adapter_name, lora_ga_config):
    """
    Initialize LoRA weights using gradient approximation (LoRA-GA).

    Uses SVD on the gradient matrix to initialize adapters in a way that aligns with the direction of full
    fine-tuning. The resulting ``scaling * B @ A`` is subtracted from the base weight.

    Expects that `preprocess_loraga` has been called before this, which attaches the `_peft_loraga_grad` attribute
    to the base layer; the attribute is deleted once consumed.

    If gradients are not available (e.g., when loading from a saved adapter), falls back to
    `reset_lora_parameters(adapter_name, init_lora_weights=True)`, i.e. Kaiming-uniform A and zero B. The
    weights will be overwritten by the state_dict anyway.

    Args:
        adapter_name: Adapter to initialize.
        lora_ga_config: Provides `direction`, `scale` and `stable_gamma`. Valid directions are "ArBr", "A2rBr",
            "ArB2r" and "random".

    Raises:
        ValueError: If `lora_ga_config` is None or its `direction` is not supported.
    """
    base_layer = self.get_base_layer()

    # Check for gradient attached by preprocess_loraga
    if not hasattr(base_layer, "_peft_loraga_grad"):
        # When loading from saved adapter, gradients won't be available.
        # Fall back to the default init (weights will be overwritten by state_dict)
        self.reset_lora_parameters(adapter_name, init_lora_weights=True)
        return

    grad = base_layer._peft_loraga_grad

    # Check for lora_ga_config
    if lora_ga_config is None:
        raise ValueError(
            "lora_ga_config must be provided when init_lora_weights='lora_ga'. "
            "Please pass lora_ga_config=LoraGAConfig(...) to LoraConfig."
        )
    direction = lora_ga_config.direction
    scale = lora_ga_config.scale
    stable_gamma = lora_ga_config.stable_gamma
    if direction not in ("ArBr", "A2rBr", "ArB2r", "random"):
        # fail fast; otherwise an unknown direction surfaced as an UnboundLocalError after the SVD
        raise ValueError(
            f"Unknown lora_ga_config.direction={direction!r}; expected one of 'ArBr', 'A2rBr', 'ArB2r', 'random'."
        )
    weight = base_layer.weight
    dtype = weight.dtype

    # Work in (out_features, in_features) layout regardless of how the base layer stores its weight
    grad = transpose(grad.to(torch.float32), self.fan_in_fan_out)

    r = self.r[adapter_name]

    # torch.svd_lowrank returns (U, S, V) where grad ≈ U @ diag(S) @ V.T
    # So V is shape (in_features, k) and we need V.T which is (k, in_features) for lora_A
    U, S, V = torch.svd_lowrank(grad, q=min(4 * r, min(grad.shape)), niter=4)
    Vh = V.t()

    U = U[:, : 2 * r]
    S = S[: 2 * r]
    Vh = Vh[: 2 * r, :]

    if direction == "ArBr":
        # Alternating: A takes rows at odd indices [1,3,5,7], B takes columns at even indices [0,2,4,6]
        lora_A_weight = Vh[1 : 2 * r : 2, :]  # Shape: (r, in_features)
        lora_B_weight = U[:, 0 : 2 * r : 2]  # Shape: (out_features, r)
        S_B = S[0 : 2 * r : 2]
        lora_B_weight = lora_B_weight @ torch.diag(S_B)
    elif direction == "A2rBr":
        # A takes second half rows [r:2r], B takes first half columns [:r]
        lora_A_weight = Vh[r : 2 * r, :]  # Shape: (r, in_features)
        lora_B_weight = U[:, :r]  # Shape: (out_features, r)
        S_B = S[:r]
        lora_B_weight = lora_B_weight @ torch.diag(S_B)
    elif direction == "ArB2r":
        # A takes first half rows [:r], B takes second half columns [r:2r]
        lora_A_weight = Vh[:r, :]  # Shape: (r, in_features)
        lora_B_weight = U[:, r : 2 * r]  # Shape: (out_features, r)
        S_B = S[r : 2 * r]
        lora_B_weight = lora_B_weight @ torch.diag(S_B)
    else:  # direction == "random" (validated above)
        indices = torch.randperm(2 * r)[:r]
        lora_A_weight = Vh[indices, :]  # Shape: (r, in_features)
        S_B = S[indices]
        lora_B_weight = U[:, indices] @ torch.diag(S_B)  # Shape: (out_features, r)

    scaling_factor = self.scaling[adapter_name]
    # `grad` is already (out_features, in_features); `weight.shape[0]` would be in_features for
    # fan_in_fan_out (Conv1D-style) layers.
    out_features = grad.shape[0]

    if scale == "stable":
        scale_factor = (out_features**0.25) / (stable_gamma**0.5)
        lora_B_weight = lora_B_weight * scale_factor
    elif scale == "weight_svd":
        weight_data = transpose(weight.data.to(torch.float32), self.fan_in_fan_out)
        _, weight_S, _ = torch.svd_lowrank(weight_data, q=r, niter=4)
        if S_B[0] > 0:
            scale_factor = weight_S[0] / S_B[0]
            lora_B_weight = lora_B_weight * scale_factor
    elif scale == "gd_scale":
        lora_A_weight = lora_A_weight / scaling_factor
        lora_B_weight = lora_B_weight / scaling_factor

    # Convert to target dtype first to ensure weight offset matches adapter precision
    lora_A_weight = lora_A_weight.to(dtype)
    lora_B_weight = lora_B_weight.to(dtype)

    # Assign LoRA weights
    # lora_A should be (r, in_features), lora_B should be (out_features, r)
    self.lora_A[adapter_name].weight.data = lora_A_weight.contiguous()
    self.lora_B[adapter_name].weight.data = lora_B_weight.contiguous()

    # Modify base weights: W_new = W_old - scaling * (B @ A)
    # Important: compute offset in fp32 using dtype-converted weights to match forward pass precision
    weight_data = transpose(weight.data.to(torch.float32), self.fan_in_fan_out)
    weight_offset = scaling_factor * (lora_B_weight.float() @ lora_A_weight.float())
    weight_data = weight_data - weight_offset
    weight_data = transpose(weight_data.to(dtype), self.fan_in_fan_out)
    self.get_base_layer().weight.data = weight_data

    # Remove redundant fields
    del base_layer._peft_loraga_grad
|
|
|
|
def _cache_store(self, key: str, value: Any) -> None:
|
|
# cache intermediate values, e.g. weight norm of DoRA
|
|
self._caches[key] = value
|
|
|
|
def _cache_pop(self, key: str) -> Any:
|
|
# retrieve and remove from ad hoc cache
|
|
value = self._caches.pop(key)
|
|
return value
|
|
|
|
def set_scale(self, adapter: str, scale: float | int) -> None:
    """Set ``adapter``'s scaling to ``scale`` times its initial value.

    The initial value is ``lora_alpha / r``, or ``lora_alpha / sqrt(r)`` when rsLoRA is enabled for the adapter.
    Adapters this layer does not know about are silently ignored.
    """
    if adapter in self.scaling:
        rank = self.r[adapter]
        denominator = math.sqrt(rank) if self.use_rslora.get(adapter, False) else rank
        self.scaling[adapter] = scale * self.lora_alpha[adapter] / denominator
|
|
|
|
def scale_layer(self, scale: float | int) -> None:
    """Multiply the current scaling of each active adapter present on this layer by ``scale``.

    A factor of exactly ``1`` is a no-op. Active adapters without LoRA weights on this layer are skipped.
    """
    if scale != 1:
        present = (name for name in self.active_adapters if name in self.lora_A.keys())
        for name in present:
            self.scaling[name] *= scale
|
|
|
|
def unscale_layer(self, scale: Optional[float | int] = None) -> None:
    """Undo scaling for each active adapter that has LoRA weights on this layer.

    With ``scale=None``, the scaling is reset to its initial value. That is ``lora_alpha / r``, or
    ``lora_alpha / sqrt(r)`` when rsLoRA is enabled for the adapter. Otherwise the current scaling is divided by
    ``scale``.
    """
    for name in self.active_adapters:
        if name not in self.lora_A.keys():
            continue

        if scale is not None:
            self.scaling[name] = self.scaling[name] / scale
            continue

        alpha, rank = self.lora_alpha[name], self.r[name]
        if self.use_rslora.get(name, False):
            self.scaling[name] = alpha / math.sqrt(rank)
        else:
            self.scaling[name] = alpha / rank
|
|
|
|
def _check_forward_args(self, x, *args, **kwargs):
|
|
"""Check if the arguments are compatible with the configs and state of the model"""
|
|
adapter_names = kwargs.get("adapter_names", None)
|
|
if adapter_names is None:
|
|
return
|
|
|
|
if len(x) != len(adapter_names):
|
|
msg = (
|
|
"Length of `adapter_names` should be the same as the number of inputs, but got "
|
|
f"{len(adapter_names)} and {len(x)} respectively."
|
|
)
|
|
raise ValueError(msg)
|
|
|
|
if self.merged:
|
|
# It is unclear what would be the right thing to do if users pass adapter_names and there are merged
|
|
# adapters. Therefore, it is better to raise an error in this case.
|
|
msg = "Cannot pass `adapter_names` when there are merged adapters, please call `unmerge_adapter` first."
|
|
raise ValueError(msg)
|
|
|
|
# DoRA is not supported (yet), check that it's not being used. Don't check "__base__", as this is the
|
|
# placeholder for the base model.
|
|
unique_adapters = {name for name in adapter_names if name != "__base__"}
|
|
for adapter_name in unique_adapters:
|
|
if self.use_dora.get(adapter_name, False):
|
|
msg = "Cannot pass `adapter_names` when DoRA is enabled."
|
|
raise ValueError(msg)
|
|
|
|
def _mixed_batch_forward(
    self, x: torch.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any
) -> torch.Tensor:
    """Forward pass in which each sample of the batch may use a different adapter.

    This handles the case when users pass the extra argument `adapter_names`, which allows mixing different adapters
    in the same batch at inference time. `adapter_names[i]` selects the adapter for `x[i]`. Samples tagged
    `"__base__"`, or with a name not held in `self.lora_A`, keep the plain base-layer output.
    """
    # Variant-specific keyword arguments must not be forwarded to the base layer.
    variant_kwargs = {key: kwargs.pop(key, None) for key in VARIANT_KWARG_KEYS}
    result = self.base_layer(x, *args, **kwargs)
    result_dtype = result.dtype

    adapter_set = set(adapter_names)
    # Batch row indices routed to each adapter, in the same order as iteration over `adapter_set`.
    row_groups = [[row for row, name in enumerate(adapter_names) if name == adapter] for adapter in adapter_set]
    alora_offsets = variant_kwargs.get("alora_offsets", None)

    for adapter, rows in zip(adapter_set, row_groups):
        if adapter == "__base__" or adapter not in self.lora_A.keys():
            continue

        lora_A = self.lora_A[adapter]
        lora_B = self.lora_B[adapter]
        dropout = self.lora_dropout[adapter]
        scaling = self.scaling[adapter]

        # Gather this adapter's sub-batch in the adapter's dtype; its output is written back to the same rows.
        sub_batch = x[rows].to(lora_A.weight.dtype)

        if adapter in self.lora_variant:
            # aLoRA offsets are per sample, so restrict them to the rows of this sub-batch.
            if alora_offsets is not None:
                variant_kwargs["alora_offsets"] = [alora_offsets[row] for row in rows]
            variant_output = self.lora_variant[adapter].forward(
                self,
                active_adapter=adapter,
                x=sub_batch,
                result=result[rows],
                **variant_kwargs,
                **kwargs,
            )
            # The variant receives the rows' current output and returns their updated output, so assign, don't add.
            result[rows] = variant_output.to(result_dtype)
        else:
            delta = lora_B(lora_A(dropout(sub_batch))) * scaling
            result[rows] += delta.to(result_dtype)

    return result
|
|
|
|
|
|
# Below code is based on https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
|
|
# and modified to work with PyTorch FSDP
|
|
|
|
|
|
# ------------------------------------------------------------------------------------------
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
|
# ------------------------------------------------------------------------------------------
|
|
|
|
|
|
class Linear(nn.Module, LoraLayer):
    """LoRA implemented in a dense layer.

    Wraps `base_layer` and, in the un-merged forward pass, adds `lora_B(lora_A(dropout(x))) * scaling` for every active
    adapter. If a LoRA variant was resolved for an adapter (see `resolve_lora_variant`), that variant computes the
    adapter's contribution instead.
    """

    def __init__(
        self,
        base_layer,
        adapter_name: str,
        config: LoraConfig,
        r: int = 0,
        lora_alpha: int = 1,
        is_target_conv_1d_layer: bool = False,
        **kwargs,
    ) -> None:
        super().__init__()
        LoraLayer.__init__(self, base_layer, **kwargs)
        # When True, the delta weight is transposed (via `transpose`) before being merged into the base weight.
        self.fan_in_fan_out = config.fan_in_fan_out

        self._active_adapter = adapter_name
        self.update_layer(
            adapter_name,
            r,
            lora_alpha=lora_alpha,
            config=config,
            **kwargs,
        )
        self.is_target_conv_1d_layer = is_target_conv_1d_layer

    def resolve_lora_variant(self, config: LoraConfig, **kwargs) -> Optional[LoraVariant]:
        """Pick the LoRA variant for this layer from `config`, or return `None` for vanilla LoRA.

        Precedence: Arrow, then BD-LoRA, then aLoRA (when `alora_invocation_tokens` is set), then DoRA.
        """
        if config.arrow_config is not None:
            from .variants import ArrowLinearVariant

            return ArrowLinearVariant()

        # NOTE(review): presumably `use_bdlora` holds a config object that is None when disabled. If it is a plain
        # bool, `False is not None` would select the BD-LoRA variant for every layer -- confirm against LoraConfig.
        if config.use_bdlora is not None:
            from .variants import BdLoraLinearVariant

            return BdLoraLinearVariant()

        use_alora = config.alora_invocation_tokens is not None
        if not config.use_dora and not use_alora:
            return None

        from .variants import ALoraLinearVariant, DoraLinearVariant

        if use_alora:
            return ALoraLinearVariant()
        else:
            return DoraLinearVariant()

    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
        """
        Merge the active adapter weights into the base weights

        Args:
            safe_merge (`bool`, *optional*):
                If True, the merge operation will be performed in a copy of the original weights and check for NaNs
                before merging the weights. This is useful if you want to check if the merge operation will produce
                NaNs. Defaults to `False`.
            adapter_names (`list[str]`, *optional*):
                The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
                to `None`.

        Raises:
            ValueError: with `safe_merge=True`, if the merged weight or bias contains non-finite values.
            RuntimeError: if the adapter uses `lora_bias=True` but the base layer has no bias.
        """
        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            # no adapter to merge
            return

        for active_adapter in adapter_names:
            if active_adapter in self.lora_A.keys():
                base_layer = self.get_base_layer()
                if safe_merge:
                    # Note that safe_merge will be slower than the normal merge
                    # because of the copy operation.
                    orig_weight = base_layer.weight.data.clone()
                    orig_dtype = orig_weight.dtype
                    if active_adapter not in self.lora_variant:  # vanilla LoRA
                        delta_weight = self.get_delta_weight(active_adapter)
                        orig_weight += delta_weight.to(orig_dtype)
                    else:
                        orig_weight = self.lora_variant[active_adapter].merge_safe(self, active_adapter, orig_weight)

                    if not torch.isfinite(orig_weight).all():
                        raise ValueError(
                            f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                        )

                    base_layer.weight.data = orig_weight

                    if self.lora_bias[active_adapter]:
                        if getattr(base_layer, "bias", None) is None:
                            raise RuntimeError(
                                "Impossible to merge LoRA with `lora_bias=True` because the base layer has no bias."
                            )
                        new_bias = base_layer.bias + self.lora_B[active_adapter].bias * self.scaling[active_adapter]
                        if not torch.isfinite(new_bias).all():
                            raise ValueError(
                                f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                            )
                        # Keep the bias in its own dtype (not the weight's), exactly like the in-place add in the
                        # unsafe path below, so safe and unsafe merges leave the bias dtype unchanged.
                        base_layer.bias.data = new_bias.to(base_layer.bias.dtype)

                else:
                    if active_adapter not in self.lora_variant:  # vanilla LoRA
                        delta_weight = self.get_delta_weight(active_adapter)
                        # Cast explicitly to the base weight's dtype, as `unmerge` and the safe path do.
                        base_layer.weight.data += delta_weight.to(base_layer.weight.dtype)
                    else:
                        self.lora_variant[active_adapter].merge_unsafe(self, active_adapter, base_layer.weight)

                    if self.lora_bias[active_adapter]:
                        if getattr(base_layer, "bias", None) is None:
                            raise RuntimeError(
                                "Impossible to merge LoRA with `lora_bias=True` because the base layer has no bias."
                            )
                        base_layer.bias.data += self.lora_B[active_adapter].bias * self.scaling[active_adapter]

                self.merged_adapters.append(active_adapter)

    def unmerge(self) -> None:
        """
        This method unmerges all merged adapter layers from the base weights.

        Adapters are unmerged in reverse merge order (popped from `self.merged_adapters`). Warns and returns early if
        nothing is merged.
        """
        if not self.merged:
            warnings.warn("Already unmerged. Nothing to do.")
            return
        while len(self.merged_adapters) > 0:
            active_adapter = self.merged_adapters.pop()
            if active_adapter in self.lora_A.keys():
                weight = self.get_base_layer().weight
                if active_adapter not in self.lora_variant:  # vanilla LoRA
                    orig_dtype = weight.dtype
                    delta_weight = self.get_delta_weight(active_adapter)
                    weight.data -= delta_weight.to(orig_dtype)
                else:
                    unmerged = self.lora_variant[active_adapter].unmerge(self, active_adapter, weight)
                    weight.data = unmerged

                if self.lora_bias[active_adapter]:
                    self.get_base_layer().bias.data -= self.lora_B[active_adapter].bias * self.scaling[active_adapter]

    def get_delta_weight(self, adapter) -> torch.Tensor:
        """
        Compute the delta weight for the given adapter.

        Args:
            adapter (str):
                The name of the adapter for which the delta weight should be computed.

        Returns:
            `torch.Tensor`: `transpose(B @ A, fan_in_fan_out) * scaling`, in the dtype of `lora_B`'s weight.
        """
        device = self.lora_B[adapter].weight.device
        dtype = self.lora_B[adapter].weight.dtype

        # In case users wants to merge the adapter weights that are in
        # (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
        # (b)float16 because some CPUs have slow bf16/fp16 matmuls.
        cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)

        weight_A = self.lora_A[adapter].weight
        weight_B = self.lora_B[adapter].weight

        if cast_to_fp32:
            weight_A = weight_A.float()
            weight_B = weight_B.float()

        output_tensor = transpose(weight_B @ weight_A, self.fan_in_fan_out) * self.scaling[adapter]

        if cast_to_fp32:
            output_tensor = output_tensor.to(dtype=dtype)

        # cast back the weights
        self.lora_A[adapter].weight.data = weight_A.to(dtype)
        self.lora_B[adapter].weight.data = weight_B.to(dtype)

        return output_tensor

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Run the base layer and, unless adapters are disabled or merged, add the active adapters' contributions.

        Passing `adapter_names=[...]` routes each sample to its own adapter (see `_mixed_batch_forward`). The result is
        cast back to the base layer's output dtype.
        """
        self._check_forward_args(x, *args, **kwargs)
        adapter_names = kwargs.pop("adapter_names", None)
        variant_kwargs = {k: kwargs.pop(k, None) for k in VARIANT_KWARG_KEYS}  # don't pass these to base_layer

        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            result = self.base_layer(x, *args, **kwargs)
        elif adapter_names is not None:
            result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **variant_kwargs, **kwargs)
        elif self.merged:
            # Adapter weights already live inside the base weights.
            result = self.base_layer(x, *args, **kwargs)
        else:
            result = self.base_layer(x, *args, **kwargs)
            torch_result_dtype = result.dtype

            lora_A_keys = self.lora_A.keys()
            for active_adapter in self.active_adapters:
                if active_adapter not in lora_A_keys:
                    continue

                lora_A = self.lora_A[active_adapter]
                lora_B = self.lora_B[active_adapter]
                dropout = self.lora_dropout[active_adapter]
                scaling = self.scaling[active_adapter]
                x = self._cast_input_dtype(x, lora_A.weight.dtype)
                if active_adapter not in self.lora_variant:  # vanilla LoRA
                    result = result + lora_B(lora_A(dropout(x))) * scaling
                else:
                    result = self.lora_variant[active_adapter].forward(
                        self,
                        active_adapter=active_adapter,
                        x=x,
                        result=result,
                        **variant_kwargs,
                        **kwargs,
                    )

            result = result.to(torch_result_dtype)

        return result

    def supports_lora_conversion(self, adapter_name: str = "default") -> bool:
        """Return True: this layer type supports LoRA conversion for any adapter."""
        return True

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "lora." + rep
|
|
|
|
|
|
class _LoraEmbeddingAHolder(nn.Module):
|
|
"""
|
|
A "fake" module to hold the lora_embedding_A weights for the TP hooks.
|
|
"""
|
|
|
|
def __init__(self, lora_embedding_A_weight):
|
|
super().__init__()
|
|
self.weight = lora_embedding_A_weight.T # lora_embedding_A shape is (r, vocab_size)
|
|
|
|
|
|
class Embedding(nn.Module, LoraLayer):
    """LoRA implemented in an Embedding layer.

    The adapter keeps `lora_embedding_A` (shape `(r, in_features)`) and `lora_embedding_B` (shape `(out_features, r)`)
    as raw `nn.Parameter`s. In the un-merged forward pass, the input ids are embedded with `A.T`, multiplied by `B.T`,
    scaled, and added to the base layer's output.
    """

    def __init__(
        self,
        base_layer: nn.Module,
        adapter_name: str,
        config: LoraConfig,
        r: int = 0,
        lora_alpha: int = 1,
        init_lora_weights: Union[bool, str] = True,  # unused here: update_layer reads config.init_lora_weights
        **kwargs,
    ) -> None:
        if config.lora_bias:
            # lora_bias=True is not supported (yet) for embedding layers, as they use nn.Parameter
            raise ValueError(f"lora_bias={config.lora_bias} is not supported for {self.__class__.__name__}.")

        super().__init__()
        LoraLayer.__init__(self, base_layer)
        self.fan_in_fan_out = config.fan_in_fan_out

        tp_plan = getattr(base_layer, "_hf_tp_plan", None)

        self.device_mesh = getattr(base_layer, "_hf_device_mesh", None)
        self.tp_layer = None

        # Per-adapter TP hooks applied around `self._embed` (filled in by update_layer when TP is active).
        self.input_fns = {}
        self.output_fns = {}

        if tp_plan is not None:
            # NOTE(review): assumes `is_transformers_ge_v5_4_0` is a bool; if it is a function, `not` on it is always
            # False and this version check never fires -- confirm.
            if not is_transformers_ge_v5_4_0:
                raise RuntimeError("Tensor Parallel with LoRA is only supported for transformers v5.4.0 and above. ")

            if tp_plan != "embedding_rowwise":
                raise ValueError(
                    f'Unsupported tensor parallel plan {tp_plan} for embedding layers. Only "embedding_rowwise" is '
                    "supported."
                )

            from transformers.integrations.tensor_parallel import ALL_PARALLEL_STYLES

            self.tp_layer = copy.deepcopy(ALL_PARALLEL_STYLES[tp_plan])

        self._active_adapter = adapter_name
        self.update_layer(
            adapter_name,
            r,
            lora_alpha=lora_alpha,
            config=config,
        )

    def resolve_lora_variant(self, *, config: LoraConfig, **kwargs) -> Optional[LoraVariant]:
        """Return the DoRA embedding variant when `config.use_dora` is set, otherwise `None` (vanilla LoRA)."""
        if not config.use_dora:
            return None

        from .variants import DoraEmbeddingVariant

        return DoraEmbeddingVariant()

    def update_layer(
        self,
        adapter_name: str,
        r: int,
        lora_alpha: int,
        config: LoraConfig,
        **kwargs,
    ) -> None:
        """Create (or re-create) the parameters, scaling and optional variant for `adapter_name`.

        Raises:
            ValueError: if `r <= 0`.
        """
        lora_dropout = config.lora_dropout
        init_lora_weights = config.init_lora_weights
        use_rslora = config.use_rslora
        lora_bias = config.lora_bias
        inference_mode = config.inference_mode

        if r <= 0:
            raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

        lora_variant = self.resolve_lora_variant(config=config)
        if lora_variant is not None:
            self.lora_variant[adapter_name] = lora_variant

        self.r[adapter_name] = r
        self.lora_alpha[adapter_name] = lora_alpha
        if lora_dropout > 0.0:
            lora_dropout_layer = nn.Dropout(p=lora_dropout)
        else:
            lora_dropout_layer = nn.Identity()

        self.lora_dropout[adapter_name] = lora_dropout_layer
        # Actual trainable parameters
        weight_A = torch.randn((r, self.in_features))
        weight_B = torch.randn((self.out_features, r))
        self.lora_embedding_A[adapter_name] = nn.Parameter(weight_A)
        self.lora_embedding_B[adapter_name] = nn.Parameter(weight_B)
        self.lora_bias[adapter_name] = lora_bias

        if use_rslora:
            self.scaling[adapter_name] = lora_alpha / math.sqrt(r)
        else:
            self.scaling[adapter_name] = lora_alpha / r

        self.use_rslora[adapter_name] = use_rslora

        self.use_dora[adapter_name] = config.use_dora

        if init_lora_weights == "loftq":
            self.loftq_init(adapter_name)
        elif init_lora_weights == "lora_ga":
            # Embedding layers don't support LoRA-GA, fall back to standard initialization
            self.reset_lora_parameters(adapter_name, True)
        elif init_lora_weights:
            self.reset_lora_parameters(adapter_name, init_lora_weights)

        # call this before init of the lora variants
        self._move_adapter_to_device_of_base_layer(adapter_name)

        if adapter_name in self.lora_variant:
            self.lora_variant[adapter_name].init(self, adapter_name=adapter_name, config=config, **kwargs)

        self.set_adapter(self.active_adapters, inference_mode=inference_mode)

        # If there is tensor parallelism, we register the hooks for `self._embed`.
        if self.tp_layer is not None:
            mod = _LoraEmbeddingAHolder(self.lora_embedding_A[adapter_name])

            def input_fn(inputs):
                return self.tp_layer._prepare_input_fn(mod, inputs, self.device_mesh)

            def output_fn(outputs):
                return self.tp_layer._prepare_output_fn(mod, outputs, self.device_mesh)

            self.input_fns[adapter_name] = input_fn
            self.output_fns[adapter_name] = output_fn

    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
        """
        Merge the active adapter weights into the base weights

        Args:
            safe_merge (`bool`, *optional*):
                If True, the merge operation will be performed in a copy of the original weights and check for NaNs
                before merging the weights. This is useful if you want to check if the merge operation will produce
                NaNs. Defaults to `False`.
            adapter_names (`list[str]`, *optional*):
                The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
                to `None`.

        Raises:
            ValueError: with `safe_merge=True`, if the merged weight contains non-finite values.
        """
        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            # no adapter to merge
            return

        for active_adapter in adapter_names:
            if active_adapter in self.lora_embedding_A.keys():
                base_layer = self.get_base_layer()
                orig_dtype = base_layer.weight.dtype
                if safe_merge:
                    # Note that safe_merge will be slower than the normal merge
                    # because of the copy operation.
                    orig_weight = base_layer.weight.data.clone()
                    if active_adapter not in self.lora_variant:  # vanilla LoRA
                        orig_weight += self.get_delta_weight(active_adapter).to(orig_dtype)
                    else:
                        orig_weight = self.lora_variant[active_adapter].merge_safe(self, active_adapter, orig_weight)

                    if not torch.isfinite(orig_weight).all():
                        raise ValueError(
                            f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                        )

                    base_layer.weight.data = orig_weight
                else:
                    if active_adapter not in self.lora_variant:  # vanilla LoRA
                        base_layer.weight.data += self.get_delta_weight(active_adapter).to(orig_dtype)
                    else:
                        self.lora_variant[active_adapter].merge_unsafe(self, active_adapter, base_layer.weight)
                self.merged_adapters.append(active_adapter)

    def unmerge(self) -> None:
        """
        This method unmerges all merged adapter layers from the base weights.

        Adapters are unmerged in reverse merge order. Warns and returns early if nothing is merged.
        """
        if not self.merged:
            warnings.warn("Already unmerged. Nothing to do.")
            return
        while len(self.merged_adapters) > 0:
            active_adapter = self.merged_adapters.pop()
            orig_dtype = self.get_base_layer().weight.dtype
            if active_adapter in self.lora_embedding_A.keys():
                weight = self.get_base_layer().weight
                if active_adapter not in self.lora_variant:  # vanilla LoRA
                    weight.data -= self.get_delta_weight(active_adapter).to(orig_dtype)
                else:
                    unmerged = self.lora_variant[active_adapter].unmerge(self, active_adapter, weight)
                    weight.data = unmerged

    def get_delta_weight(self, adapter) -> torch.Tensor:
        """
        Compute the delta weight for the given adapter.

        Args:
            adapter (str):
                The name of the adapter for which the delta weight should be computed.

        Returns:
            `torch.Tensor`: `transpose(B @ A, True) * scaling`, in the dtype of `lora_embedding_A`.

        The adapter parameters themselves are never modified.
        """
        device = self.lora_embedding_B[adapter].device
        dtype = self.lora_embedding_A[adapter].dtype

        # In case users wants to merge the adapter weights that are in
        # (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
        # (b)float16 because some CPUs have slow bf16/fp16 matmuls.
        cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)

        weight_A = self.lora_embedding_A[adapter]
        weight_B = self.lora_embedding_B[adapter]

        if cast_to_fp32:
            # `.float()` returns new tensors; the registered parameters stay in their original dtype.
            weight_A = weight_A.float()
            weight_B = weight_B.float()

        output_tensor = transpose(weight_B @ weight_A, True) * self.scaling[adapter]

        if cast_to_fp32:
            output_tensor = output_tensor.to(dtype=dtype)

        # Nothing to cast back: the parameters were never modified above. Assigning tensors back into the
        # ParameterDict (`self.lora_embedding_A[adapter] = ...`) would register *new* nn.Parameter objects, dropping the
        # originals that an optimizer or the TP `_LoraEmbeddingAHolder` views may still reference, and could retype
        # `lora_embedding_B` to A's dtype.
        return output_tensor

    def _mixed_batch_forward(
        self, x: torch.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any
    ) -> torch.Tensor:
        # This is a special method that handles the case when users pass the argument `adapter_names`. This is an
        # extra argument that allows mixing different adapters in the same batch at inference time.
        result = self.base_layer(x, *args, **kwargs)

        # Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method.
        # Since base_layer(x) already includes this scaling, we need to apply it to LoRA contributions too.
        embed_scale = self._get_embed_scale()

        unique_adapters = set(adapter_names)
        sub_batch_indices_list = []
        for adapter in unique_adapters:
            sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter])

        for i, active_adapter in enumerate(unique_adapters):
            if active_adapter == "__base__":
                continue
            if active_adapter not in self.lora_embedding_A.keys():
                continue

            embedding_A = self.lora_embedding_A[active_adapter].T
            embedding_B = self.lora_embedding_B[active_adapter].T
            scaling = self.scaling[active_adapter]

            # getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear
            # layer output
            sub_batch = x[sub_batch_indices_list[i]]
            # NOTE(review): unlike `forward`, the TP input/output hooks (`self.input_fns` / `self.output_fns`) are not
            # applied here, so mixed-batch inference presumably does not work under tensor parallelism -- confirm.
            after_A = self._embed(sub_batch, embedding_A)
            adapter_output = (after_A @ embedding_B) * scaling

            # Apply embed_scale to match the base layer's scaling
            if embed_scale is not None:
                adapter_output = adapter_output * embed_scale.to(adapter_output.dtype)

            result[sub_batch_indices_list[i]] += adapter_output

        return result

    def _embed(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        input_fn: Callable[[tuple[torch.Tensor]], torch.Tensor] | None = None,
        output_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
    ) -> torch.Tensor:
        """Embed `input` with `weight`, reusing the base layer's embedding options (padding_idx, max_norm, ...).

        The optional `input_fn` / `output_fn` hooks (used for tensor parallelism) run before and after the lookup.
        """
        base_layer = self.get_base_layer()
        if input_fn is not None:
            input = input_fn((input,))
        output = F.embedding(
            input,
            weight,
            padding_idx=base_layer.padding_idx,
            max_norm=base_layer.max_norm,
            norm_type=base_layer.norm_type,
            scale_grad_by_freq=base_layer.scale_grad_by_freq,
            sparse=base_layer.sparse,
        )
        if output_fn is not None:
            output = output_fn(output)
        return output

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Run the base embedding and, unless adapters are disabled or merged, add the active adapters' contributions.

        The result is cast back to the base layer's output dtype.
        """
        # TODO: no dtype conversion here, unlike in Linear, is that correct?
        self._check_forward_args(x, *args, **kwargs)
        adapter_names = kwargs.pop("adapter_names", None)
        variant_kwargs = {k: kwargs.pop(k, None) for k in VARIANT_KWARG_KEYS}  # don't pass these to base_layer
        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            result = self.base_layer(x, *args, **kwargs)
        elif adapter_names is not None:
            result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
        elif self.merged:
            result = self.base_layer(x, *args, **kwargs)
        else:
            result = self.base_layer(x, *args, **kwargs)
            torch_result_dtype = result.dtype

            # Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method.
            # Since base_layer(x) already includes this scaling, we need to apply it to LoRA contributions too.
            embed_scale = self._get_embed_scale()

            for active_adapter in self.active_adapters:
                if active_adapter not in self.lora_embedding_A:
                    continue

                if active_adapter not in self.lora_variant:  # vanilla LoRA
                    embedding_A = self.lora_embedding_A[active_adapter].T
                    embedding_B = self.lora_embedding_B[active_adapter].T
                    scaling = self.scaling[active_adapter]
                    # input and output function hooks for TP support.
                    input_fn = self.input_fns.get(active_adapter, None)
                    output_fn = self.output_fns.get(active_adapter, None)
                    after_A = self._embed(x, embedding_A, input_fn=input_fn, output_fn=output_fn)
                    adapter_output = (after_A @ embedding_B) * scaling

                    # Apply embed_scale to match the base layer's scaling
                    if embed_scale is not None:
                        adapter_output = adapter_output * embed_scale.to(adapter_output.dtype)

                    result = result + adapter_output
                else:
                    result = self.lora_variant[active_adapter].forward(
                        self,
                        active_adapter=active_adapter,
                        x=x,
                        result=result,
                        **variant_kwargs,
                        **kwargs,
                    )
            result = result.to(torch_result_dtype)

        return result

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "lora." + rep
|
|
|
|
|
|
class _ConvNd(nn.Module, LoraLayer):
|
|
# Lora implemented in a conv(2,3)d layer
|
|
def __init__(
|
|
self,
|
|
base_layer: nn.Module,
|
|
adapter_name: str,
|
|
config: LoraConfig,
|
|
r: int = 0,
|
|
lora_alpha: int = 1,
|
|
**kwargs,
|
|
) -> None:
|
|
super().__init__()
|
|
LoraLayer.__init__(self, base_layer)
|
|
if kwargs.get("use_alora", False):
|
|
raise ValueError("aLoRA does not support adapting conv layers.")
|
|
if base_layer.groups > 1:
|
|
warnings.warn("LoRA adapter added to ConvNd layer with groups > 1. Merging is not supported.")
|
|
|
|
if r % base_layer.groups != 0:
|
|
raise ValueError(
|
|
f"Targeting a {base_layer.__class__.__name__} with groups={base_layer.groups} and rank {r}. "
|
|
"Currently, support is limited to conv layers where the rank is divisible by groups. "
|
|
"Either choose a different rank or do not target this specific layer."
|
|
)
|
|
|
|
self._active_adapter = adapter_name
|
|
self._kernel_dim = base_layer.weight.dim()
|
|
|
|
self.update_layer(
|
|
adapter_name,
|
|
r,
|
|
lora_alpha=lora_alpha,
|
|
config=config,
|
|
)
|
|
|
|
def update_layer(
|
|
self,
|
|
adapter_name: str,
|
|
r: int,
|
|
lora_alpha: int,
|
|
config: LoraConfig,
|
|
**kwargs,
|
|
) -> None:
|
|
lora_dropout = config.lora_dropout
|
|
init_lora_weights = config.init_lora_weights
|
|
use_rslora = config.use_rslora
|
|
lora_bias = config.lora_bias
|
|
inference_mode = config.inference_mode
|
|
|
|
if r <= 0:
|
|
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
|
|
|
|
if lora_bias and (getattr(self.get_base_layer(), "bias", None) is None):
|
|
warnings.warn(
|
|
f"`lora_bias=True` was passed but the targeted layer of type {type(self.get_base_layer()).__name__} "
|
|
"has no bias. This means that merging LoRA weights won't be possible.",
|
|
PeftWarning,
|
|
)
|
|
|
|
lora_variant = self.resolve_lora_variant(config=config)
|
|
if lora_variant is not None:
|
|
self.lora_variant[adapter_name] = lora_variant
|
|
|
|
self.r[adapter_name] = r
|
|
self.lora_alpha[adapter_name] = lora_alpha
|
|
if lora_dropout > 0.0:
|
|
lora_dropout_layer = nn.Dropout(p=lora_dropout)
|
|
else:
|
|
lora_dropout_layer = nn.Identity()
|
|
|
|
self.lora_dropout[adapter_name] = lora_dropout_layer
|
|
# Actual trainable parameters
|
|
base_layer = self.get_base_layer()
|
|
kernel_size = base_layer.kernel_size
|
|
stride = base_layer.stride
|
|
padding = base_layer.padding
|
|
conv_layer = type(base_layer)
|
|
out_kernel = out_stride = (1,) * (self._kernel_dim - 2)
|
|
self.lora_A[adapter_name] = conv_layer(self.in_features, r, kernel_size, stride, padding, bias=False)
|
|
self.lora_B[adapter_name] = conv_layer(
|
|
r, self.out_features, out_kernel, out_stride, groups=base_layer.groups, bias=lora_bias
|
|
)
|
|
self.lora_bias[adapter_name] = lora_bias
|
|
|
|
if use_rslora:
|
|
self.scaling[adapter_name] = lora_alpha / math.sqrt(r)
|
|
else:
|
|
self.scaling[adapter_name] = lora_alpha / r
|
|
|
|
self.use_rslora[adapter_name] = use_rslora
|
|
|
|
self.use_dora[adapter_name] = config.use_dora
|
|
|
|
if init_lora_weights == "loftq":
|
|
self.loftq_init(adapter_name)
|
|
elif init_lora_weights == "lora_ga":
|
|
# Conv layers don't support LoRA-GA, fall back to standard initialization
|
|
self.reset_lora_parameters(adapter_name, True)
|
|
elif init_lora_weights:
|
|
self.reset_lora_parameters(adapter_name, init_lora_weights)
|
|
|
|
# call this before init of the lora variants
|
|
self._move_adapter_to_device_of_base_layer(adapter_name)
|
|
|
|
if adapter_name in self.lora_variant:
|
|
self.lora_variant[adapter_name].init(self, adapter_name=adapter_name, config=config, **kwargs)
|
|
|
|
self.set_adapter(self.active_adapters, inference_mode=inference_mode)
|
|
|
|
def _get_dora_factor_view(self):
|
|
return (-1,) + (1,) * (self._kernel_dim - 1)
|
|
|
|
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
|
|
"""
|
|
Merge the active adapter weights inside the base weights
|
|
|
|
Args:
|
|
safe_merge (`bool`, *optional*):
|
|
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
|
|
before merging the weights. This is useful if you want to check if the merge operation will produce
|
|
NaNs. Defaults to `False`.
|
|
adapter_names (`list[str]`, *optional*):
|
|
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
|
|
to `None`.
|
|
"""
|
|
adapter_names = check_adapters_to_merge(self, adapter_names)
|
|
if not adapter_names:
|
|
# no adapter to merge
|
|
return
|
|
|
|
for active_adapter in adapter_names:
|
|
if active_adapter in self.lora_A.keys():
|
|
base_layer = self.get_base_layer()
|
|
orig_dtype = base_layer.weight.dtype
|
|
|
|
if base_layer.groups > 1:
|
|
# https://github.com/huggingface/peft/pull/2403
|
|
raise NotImplementedError("Merging is not supported for _ConvNd layers with groups > 1!")
|
|
|
|
if safe_merge:
|
|
# Note that safe_merge will be slower than the normal merge
|
|
# because of the copy operation.
|
|
orig_weight = base_layer.weight.data.clone()
|
|
if active_adapter not in self.lora_variant: # vanilla LoRA
|
|
delta_weight = self.get_delta_weight(active_adapter)
|
|
orig_weight += delta_weight.to(orig_dtype)
|
|
else:
|
|
orig_weight = self.lora_variant[active_adapter].merge_safe(self, active_adapter, orig_weight)
|
|
|
|
if not torch.isfinite(orig_weight).all():
|
|
raise ValueError(
|
|
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
|
|
)
|
|
|
|
base_layer.weight.data = orig_weight
|
|
|
|
if self.lora_bias[active_adapter]:
|
|
if getattr(base_layer, "bias", None) is None:
|
|
raise RuntimeError(
|
|
"Impossible to merge LoRA with `lora_bias=True` because the base layer has no bias."
|
|
)
|
|
new_bias = base_layer.bias + self.lora_B[active_adapter].bias * self.scaling[active_adapter]
|
|
if not torch.isfinite(new_bias).all():
|
|
raise ValueError(
|
|
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
|
|
)
|
|
base_layer.bias.data = new_bias.to(orig_dtype)
|
|
|
|
else:
|
|
if active_adapter not in self.lora_variant: # vanilla LoRA
|
|
delta_weight = self.get_delta_weight(active_adapter)
|
|
base_layer.weight.data += delta_weight.to(orig_dtype)
|
|
else:
|
|
self.lora_variant[active_adapter].merge_unsafe(self, active_adapter, base_layer.weight)
|
|
|
|
if self.lora_bias[active_adapter]:
|
|
if getattr(base_layer, "bias", None) is None:
|
|
raise RuntimeError(
|
|
"Impossible to merge LoRA with `lora_bias=True` because the base layer has no bias."
|
|
)
|
|
base_layer.bias.data += self.lora_B[active_adapter].bias * self.scaling[active_adapter]
|
|
|
|
self.merged_adapters.append(active_adapter)
|
|
|
|
def unmerge(self) -> None:
|
|
"""
|
|
This method unmerges all merged adapter layers from the base weights.
|
|
"""
|
|
if not self.merged:
|
|
warnings.warn("Already unmerged. Nothing to do.")
|
|
return
|
|
while len(self.merged_adapters) > 0:
|
|
active_adapter = self.merged_adapters.pop()
|
|
if active_adapter in self.lora_A.keys():
|
|
weight = self.get_base_layer().weight
|
|
if active_adapter not in self.lora_variant: # vanilla LoRA
|
|
orig_dtype = weight.dtype
|
|
delta_weight = self.get_delta_weight(active_adapter)
|
|
weight.data -= delta_weight.to(orig_dtype)
|
|
else:
|
|
unmerged = self.lora_variant[active_adapter].unmerge(self, active_adapter, weight)
|
|
weight.data = unmerged
|
|
|
|
if self.lora_bias[active_adapter]:
|
|
self.get_base_layer().bias.data -= self.lora_B[active_adapter].bias * self.scaling[active_adapter]
|
|
|
|
def get_delta_weight(self, adapter) -> torch.Tensor:
|
|
"""
|
|
Compute the delta weight for the given adapter.
|
|
|
|
Args:
|
|
adapter (str):
|
|
The name of the adapter for which the delta weight should be computed.
|
|
"""
|
|
device = self.lora_B[adapter].weight.device
|
|
dtype = self.lora_A[adapter].weight.dtype
|
|
|
|
# In case users wants to merge the adapter weights that are in
|
|
# (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
|
|
# (b)float16 because some CPUs have slow bf16/fp16 matmuls.
|
|
cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)
|
|
|
|
weight_A = self.lora_A[adapter].weight
|
|
weight_B = self.lora_B[adapter].weight
|
|
|
|
if cast_to_fp32:
|
|
weight_A = weight_A.float()
|
|
weight_B = weight_B.float()
|
|
|
|
# https://github.com/bmaltais/kohya_ss/blob/feb6728762a8f463d15ba936d189d4c3abfaa1ab/networks/lora.py#L117
|
|
if self.get_base_layer().weight.size()[2:4] == (1, 1):
|
|
# conv2d 1x1
|
|
output_tensor = (weight_B.squeeze(3).squeeze(2) @ weight_A.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(
|
|
3
|
|
) * self.scaling[adapter]
|
|
else:
|
|
output_tensor = self.conv_fn(weight_A.transpose(0, 1), weight_B)
|
|
|
|
if self.get_base_layer().groups > 1:
|
|
output_tensor = output_tensor * self.scaling[adapter]
|
|
else:
|
|
output_tensor = output_tensor.transpose(0, 1) * self.scaling[adapter]
|
|
|
|
if cast_to_fp32:
|
|
output_tensor = output_tensor.to(dtype=dtype)
|
|
|
|
# cast back the weights
|
|
self.lora_A[adapter].weight.data = weight_A.to(dtype)
|
|
self.lora_B[adapter].weight.data = weight_B.to(dtype)
|
|
|
|
return output_tensor
|
|
|
|
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    """Run the base convolution and add the contributions of all active LoRA adapters.

    ``adapter_names`` (mixed-batch inference) and every key in ``VARIANT_KWARG_KEYS`` are removed from ``kwargs``
    before the base layer is called. The variant keys are only forwarded to LoRA variants.
    """
    self._check_forward_args(x, *args, **kwargs)
    adapter_names = kwargs.pop("adapter_names", None)
    # variant-only kwargs must never reach the base layer
    variant_kwargs = {key: kwargs.pop(key, None) for key in VARIANT_KWARG_KEYS}

    if self.disable_adapters:
        # adapters switched off: restore clean base weights first, then act like the plain layer
        if self.merged:
            self.unmerge()
        return self.base_layer(x, *args, **kwargs)

    if adapter_names is not None:
        return self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)

    if self.merged:
        # the LoRA deltas already live inside the base weights
        return self.base_layer(x, *args, **kwargs)

    output = self.base_layer(x, *args, **kwargs)
    output_dtype = output.dtype

    for name in self.active_adapters:
        if name not in self.lora_A.keys():
            continue  # this adapter does not target this layer
        lora_A = self.lora_A[name]
        lora_B = self.lora_B[name]
        dropout = self.lora_dropout[name]
        scaling = self.scaling[name]
        x = self._cast_input_dtype(x, lora_A.weight.dtype)

        if name in self.lora_variant:
            output = self.lora_variant[name].forward(
                self,
                active_adapter=name,
                x=x,
                result=output,
                **variant_kwargs,
                **kwargs,
            )
        else:
            # vanilla LoRA
            output = output + lora_B(lora_A(dropout(x))) * scaling

    return output.to(output_dtype)
|
|
|
|
def __repr__(self) -> str:
    # prefix the regular module repr so LoRA-wrapped layers stand out when printing a model
    return f"lora.{super().__repr__()}"
|
|
|
|
|
|
class Conv2d(_ConvNd):
    """LoRA adapter for 2d convolution base layers (4-dimensional kernels)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self._kernel_dim != 4:
            raise ValueError(f"Conv2d layer kernel must have 4 dimensions, not {self._kernel_dim}")
        # convolution used to combine the lora_A and lora_B kernels into one delta weight
        self.conv_fn = F.conv2d

    def resolve_lora_variant(self, *, config: LoraConfig, **kwargs) -> Optional[LoraVariant]:
        """Return the 2d DoRA variant when ``config.use_dora`` is set, otherwise ``None``."""
        if config.use_dora:
            from .variants import DoraConv2dVariant

            return DoraConv2dVariant()
        return None
|
|
|
|
|
|
class Conv1d(_ConvNd):
    """LoRA adapter for 1d convolution base layers (3-dimensional kernels)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self._kernel_dim != 3:
            raise ValueError(f"Conv1d layer kernel must have 3 dimensions, not {self._kernel_dim}")
        # convolution used to combine the lora_A and lora_B kernels into one delta weight
        self.conv_fn = F.conv1d

    def resolve_lora_variant(self, *, config: LoraConfig, **kwargs) -> Optional[LoraVariant]:
        """Return the 1d DoRA variant when ``config.use_dora`` is set, otherwise ``None``."""
        if config.use_dora:
            from .variants import DoraConv1dVariant

            return DoraConv1dVariant()
        return None
|
|
|
|
|
|
class Conv3d(_ConvNd):
    """LoRA adapter for 3d convolution base layers (5-dimensional kernels)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self._kernel_dim != 5:
            raise ValueError(f"Conv3d layer kernel must have 5 dimensions, not {self._kernel_dim}")
        # convolution used to combine the lora_A and lora_B kernels into one delta weight
        self.conv_fn = F.conv3d

    def resolve_lora_variant(self, *, config: LoraConfig, **kwargs) -> Optional[LoraVariant]:
        """Return the 3d DoRA variant when ``config.use_dora`` is set, otherwise ``None``."""
        if config.use_dora:
            from .variants import DoraConv3dVariant

            return DoraConv3dVariant()
        return None
|
|
|
|
|
|
class MultiheadAttention(nn.Module, LoraLayer):
    """LoRA implemented in a multihead attention layer

    This is currently only implemented for the case of `_qkv_same_embed_dim = True`, i.e. query, key, and value having
    the same dimension.

    Note: LoRA is applied to both the in_proj (query/key/value) and out_proj. There is currently no way to specify only
    one of them. Don't try to apply LoRA to the out_proj of MultiheadAttention by targeting that layer specifically,
    since the forward method of that layer is not being used, hence the LoRA adapter would be ignored.

    This is a little bit hacky because of the way that MultiheadAttention is implemented in PyTorch: There are no
    `nn.Linear` layers which we can hook onto or, in case of output projection, `.forward` is not used. This
    implementation works around these problems by merging the weights before the forward call and unmerging them after
    the forward call.
    """

    def __init__(
        self,
        base_layer,
        adapter_name: str,
        config: LoraConfig,
        r: int = 0,
        lora_alpha: int = 1,
        **kwargs,
    ) -> None:
        """
        Args:
            base_layer: the multihead attention module to adapt; it must use a shared q/k/v embedding dimension.
            adapter_name (str): name of the initial adapter.
            config (LoraConfig): LoRA configuration; ``use_dora`` is not supported.
            r (int): LoRA rank.
            lora_alpha (int): LoRA alpha hyperparameter.
            **kwargs: forwarded to ``LoraLayer.__init__`` and to the ``Linear`` wrapper created for ``out_proj``;
                ``use_alora=True`` is rejected.

        Raises:
            ValueError: for separate q/k/v embedding dims, DoRA, aLoRA, or an ``out_proj`` that is not ``nn.Linear``.
        """
        # TODO work with separate weights
        if not getattr(base_layer, "_qkv_same_embed_dim", True):
            # default for this value appears to be True:
            # https://github.com/pytorch/pytorch/blob/701ba5203fe68d55d655bd4d6c008be94cf34ea5/torch/nn/modules/activation.py#L1128-L1130
            raise ValueError(
                f"Only same embed for query/key/value is supported as of now for {self.__class__.__name__}."
            )
        if config.use_dora:
            # TODO: probably not so hard to implement
            raise ValueError(f"{self.__class__.__name__} does not support DoRA (yet), please set use_dora to False")
        if kwargs.get("use_alora", False):
            raise ValueError(f"{self.__class__.__name__} does not support aLoRA (yet), please set use_alora to False")
        super().__init__()
        LoraLayer.__init__(self, base_layer, **kwargs)

        # Note: LoRA is applied to both in_proj and out_proj. There is currently no way to only specify one of them.
        # out_proj is wrapped in a LoRA Linear that holds the out_proj adapter weights.
        if isinstance(base_layer.out_proj, nn.Linear):
            self.base_layer.out_proj = Linear(
                base_layer.out_proj,
                adapter_name,
                r=r,
                lora_alpha=lora_alpha,
                config=config,
                **kwargs,
            )
        else:
            raise ValueError(f"out_proj must be an instance of nn.Linear for {self.__class__.__name__}.")

        self._active_adapter = adapter_name
        # NOTE(review): update_layer below also calls out_proj.update_layer for adapter_name. If the Linear
        # constructor above already did so (not visible here), the out_proj adapter gets initialized twice — confirm.
        self.update_layer(adapter_name, r, lora_alpha=lora_alpha, config=config)

    # Read-only proxies to attributes of the wrapped multihead attention module.
    @property
    def embed_dim(self) -> int:
        return self.get_base_layer().embed_dim

    @property
    def kdim(self) -> Optional[int]:
        return self.get_base_layer().kdim

    @property
    def vdim(self) -> Optional[int]:
        return self.get_base_layer().vdim

    @property
    def _qkv_same_embed_dim(self) -> bool:
        return self.get_base_layer()._qkv_same_embed_dim

    @property
    def num_heads(self) -> int:
        return self.get_base_layer().num_heads

    @property
    def dropout(self) -> float:
        return self.get_base_layer().dropout

    @property
    def batch_first(self) -> bool:
        return self.get_base_layer().batch_first

    @property
    def head_dim(self) -> int:
        return self.get_base_layer().head_dim

    @property
    def in_proj_weight(self) -> nn.Parameter:
        return self.get_base_layer().in_proj_weight

    @property
    def in_proj_bias(self) -> nn.Parameter:
        return self.get_base_layer().in_proj_bias

    @property
    def out_proj(self) -> nn.Module:
        # returns the unwrapped output projection, i.e. the base layer of the LoRA Linear wrapper
        return self.get_base_layer().out_proj.get_base_layer()

    @property
    def bias_k(self) -> Optional[nn.Parameter]:
        return self.get_base_layer().bias_k

    @property
    def bias_v(self) -> Optional[nn.Parameter]:
        return self.get_base_layer().bias_v

    def merge_masks(self, *args, **kwargs) -> tuple[Optional[torch.Tensor], Optional[int]]:
        return self.get_base_layer().merge_masks(*args, **kwargs)

    @property
    def add_zero_attn(self) -> bool:
        return self.get_base_layer().add_zero_attn

    def update_layer(self, *args, **kwargs) -> None:
        """Create/update the in_proj adapter (via ``LoraLayer``) and mirror the same call on the out_proj wrapper."""
        super().update_layer(*args, **kwargs)
        # Note: LoRA is applied to both in_proj and out_proj. There is currently no way to only specify one of them.
        self.base_layer.out_proj.update_layer(*args, **kwargs)

    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
        """
        Merge the active adapter weights into the base weights

        The in_proj weight and the out_proj weight are removed as registered parameters and replaced by plain merged
        tensors (see ``_restore_weights`` for how they are re-registered).

        Args:
            safe_merge (`bool`, *optional*):
                If True, the merge operation will be performed in a copy of the original weights and check for NaNs
                before merging the weights. This is useful if you want to check if the merge operation will produce
                NaNs. Defaults to `False`.
            adapter_names (`List[str]`, *optional*):
                The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
                to `None`.

        Raises:
            ValueError: with ``safe_merge=True``, if a merged weight contains non-finite values.
        """
        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            # no adapter to merge
            return

        # Implementation follows this:
        # https://github.com/Baijiong-Lin/LoRA-Torch/blob/4bfed6820b64fcf47064c30f30606a190a4f0d2e/loratorch/layers.py#L73-L79
        # Notably, instead of mutating the weight, we delete the original weight and replace it by the merged weight
        # TODO: work with separate weights
        for active_adapter in adapter_names:
            if active_adapter in self.lora_A.keys():
                base_layer = self.get_base_layer()
                # NOTE(review): out_proj's dtype is also used to cast the in_proj delta; this assumes both
                # projections share a dtype — confirm.
                orig_dtype = base_layer.out_proj.weight.dtype
                if safe_merge:
                    # TODO: work with separate weights
                    # merging in_proj (nn.Parameter)
                    orig_weight_in = base_layer.in_proj_weight.data.detach().clone()
                    orig_weight_in += self.get_delta_weight(active_adapter).to(orig_dtype)
                    if not torch.isfinite(orig_weight_in).all():
                        raise ValueError(
                            f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                        )

                    # merging out_proj (subclass of nn.Linear)
                    orig_weight_out = base_layer.out_proj.weight.data.detach().clone()
                    orig_weight_out += base_layer.out_proj.get_delta_weight(active_adapter).to(orig_dtype)
                    if not torch.isfinite(orig_weight_out).all():
                        raise ValueError(
                            f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                        )

                    # unregister parameter implicitly and overwrite using merged weights; gradients are computed after
                    # forward and, thus, after unmerging (see forward()), therefore this is safe to do.
                    del base_layer.in_proj_weight
                    base_layer.in_proj_weight = orig_weight_in

                    del base_layer.out_proj.get_base_layer().weight
                    base_layer.out_proj.get_base_layer().weight = orig_weight_out
                    # NOTE(review): the out_proj weight above already contains W + delta. If lora.Linear.merge also
                    # adds its delta to base_layer.weight (not visible here), the out_proj delta is applied twice
                    # while merged. unmerge() subtracts it twice as well — verify against lora.Linear.merge.
                    base_layer.out_proj.merge(adapter_names=[active_adapter])
                else:
                    # merging in_proj (nn.Parameter)
                    # TODO: work with separate weights
                    delta_weight = self.get_delta_weight(active_adapter).to(orig_dtype)
                    # delta_weight is not detached here, so the merged tensor keeps the autograd link to the LoRA
                    # weights.
                    weight_merged = base_layer.in_proj_weight.data.detach() + delta_weight

                    # unregister parameter implicitly and overwrite using merged weights; gradients are computed after
                    # forward and, thus, after unmerging (see forward()), therefore this is safe to do.
                    del base_layer.in_proj_weight
                    base_layer.in_proj_weight = weight_merged

                    # merging out_proj (subclass of nn.Linear)
                    delta_weight = base_layer.out_proj.get_delta_weight(active_adapter).to(orig_dtype)
                    weight_merged = base_layer.out_proj.weight.data.detach() + delta_weight
                    del base_layer.out_proj.get_base_layer().weight
                    base_layer.out_proj.get_base_layer().weight = weight_merged
                    # NOTE(review): same possible double application of the out_proj delta as in the safe branch.
                    base_layer.out_proj.merge(adapter_names=[active_adapter])
                self.merged_adapters.append(active_adapter)

    def unmerge(self) -> None:
        """
        This method unmerges all merged adapter layers from the base weights.

        Adapters are unmerged in reverse merge order. The base weights are re-registered as ``nn.Parameter`` with
        ``requires_grad=False``. Warns and does nothing if nothing is merged.
        """
        if not self.merged:
            warnings.warn("Already unmerged. Nothing to do.")
            return

        # TODO work with separate weights
        base_layer = self.get_base_layer()
        orig_dtype = base_layer.out_proj.base_layer.weight.dtype
        while len(self.merged_adapters) > 0:
            active_adapter = self.merged_adapters.pop()
            if active_adapter in self.lora_A.keys():
                # Ensure that requires_grad=False for the base weights after unmerging. This may not matter since
                # requires_grad was False when the optimizer was initialized, but still let's try to be correct here.

                # in_proj
                delta_weight = self.get_delta_weight(active_adapter).to(orig_dtype)
                old_weight = base_layer.in_proj_weight.data - delta_weight
                del base_layer.in_proj_weight
                base_layer.register_parameter("in_proj_weight", nn.Parameter(old_weight, requires_grad=False))

                # out_proj
                delta_weight = base_layer.out_proj.get_delta_weight(active_adapter).to(orig_dtype)
                old_weight = base_layer.out_proj.base_layer.weight.data - delta_weight
                del base_layer.out_proj.base_layer.weight
                base_layer.out_proj.base_layer.register_parameter(
                    "weight", nn.Parameter(old_weight, requires_grad=False)
                )

        # also unmerge the out_proj LoRA wrapper, which was merged via out_proj.merge() in merge()
        self.get_base_layer().out_proj.unmerge()

    def unload_and_optionally_merge_module(
        self, merge: bool, safe_merge: bool, adapter_names: Optional[list[str]]
    ) -> nn.MultiheadAttention:
        """
        Merging and unloading of the MultiheadAttention module

        This requires an extra step for MultiheadAttention, which is why there is this special method instead of
        relying on the normal merge_and_unload code path.

        Args:
            merge (bool): whether to merge the adapters before unloading.
            safe_merge (bool): passed to ``merge``.
            adapter_names (list[str] | None): passed to ``merge``.

        Returns:
            nn.MultiheadAttention: the base module, with its weights registered as parameters again and ``out_proj``
            replaced by its unwrapped base layer.
        """
        if merge:
            self.merge(safe_merge=safe_merge, adapter_names=adapter_names)
        base_layer = self.get_base_layer()

        # extra steps: re-register weights, take care of out_proj layer
        # in_proj
        weight = base_layer.in_proj_weight
        del base_layer.in_proj_weight
        base_layer.register_parameter("in_proj_weight", nn.Parameter(weight.data, requires_grad=weight.requires_grad))

        # out_proj
        out_proj_layer = base_layer.out_proj.get_base_layer()
        weight = out_proj_layer.weight
        del out_proj_layer.weight
        out_proj_layer.register_parameter("weight", nn.Parameter(weight.data, requires_grad=weight.requires_grad))

        base_layer.out_proj = out_proj_layer
        return base_layer

    def get_delta_weight(self, adapter) -> torch.Tensor:
        """
        Compute the delta weight for the given adapter.

        This is the in_proj delta ``(lora_B @ lora_A) * scaling``. The out_proj delta comes from the out_proj wrapper.

        Args:
            adapter (str):
                The name of the adapter for which the delta weight should be computed.
        """
        device = self.lora_B[adapter].weight.device
        dtype = self.lora_B[adapter].weight.dtype

        # In case users wants to merge the adapter weights that are in
        # float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
        # float16 because the `@` and matmul operation in general is not supported in torch + cpu + fp16.
        # NOTE(review): unlike the conv get_delta_weight in this file, bfloat16 is not upcast here.
        cast_to_fp32 = device.type == "cpu" and dtype == torch.float16

        weight_A = self.lora_A[adapter].weight
        weight_B = self.lora_B[adapter].weight

        if cast_to_fp32:
            weight_A = weight_A.float()
            weight_B = weight_B.float()

        output_tensor = (weight_B @ weight_A) * self.scaling[adapter]

        if cast_to_fp32:
            output_tensor = output_tensor.to(dtype=dtype)

            # cast back the weights
            self.lora_A[adapter].weight.data = weight_A.to(dtype)
            self.lora_B[adapter].weight.data = weight_B.to(dtype)

        return output_tensor

    def _check_forward_args(self, x, *args, **kwargs):
        """Reject mixed adapter batches, then run the generic checks."""
        if "adapter_names" in kwargs:
            raise TypeError(f"lora.{self.__class__.__name__} does not support mixed adapter batches.")
        super()._check_forward_args(x, *args, **kwargs)

    def forward(
        self, query: torch.Tensor, *args: Any, **kwargs: Any
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run attention with all active adapters temporarily merged into the in_proj/out_proj weights.

        Returns the base module's ``(output, attention_weights)`` tuple, cast back to ``query``'s dtype. The second
        element is kept as ``None`` if the base module returned ``None``.
        """
        previous_dtype = query.dtype
        self._check_forward_args(query, *args, **kwargs)

        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            result = self.base_layer(query, *args, **kwargs)
        elif self.merged:
            result = self.base_layer(query, *args, **kwargs)
        else:
            out_proj = self.get_base_layer().out_proj
            if out_proj.active_adapters != self.active_adapters:
                # We have a case that in_proj and out_proj have diverging merged adapters. We cannot
                # really deal with this correctly, thus it's better to raise than possibly create a hard to debug mess
                # NOTE(review): the condition compares *active* adapters, but the message talks about merged layers.
                cls_name = self.get_base_layer().__class__.__name__
                raise ValueError(
                    f"The out_proj layer of {cls_name} has merged layers but {cls_name} itself doesn't; please ensure "
                    "that either both or none have merged layers"
                )

            # Merge all adapters that are active for this module, i.e. the LoRA weights for in_proj and out_proj.
            # in_proj uses nn.Parameters, therefore, there is no forward method to be used and we have to explicitly
            # merge for the LoRA weights to have an effect:
            # https://github.com/pytorch/pytorch/blob/6ebb26d572d5fcdc6ac0d1297bdf8d1eb5d20722/torch/nn/modules/activation.py#L1020
            # For out_proj, we have an nn.Linear (or rather: NonDynamicallyQuantizableLinear), but its forward method
            # is not used:
            # https://github.com/pytorch/pytorch/blob/6ebb26d572d5fcdc6ac0d1297bdf8d1eb5d20722/torch/nn/modules/activation.py#L1267-L1271
            # Therefore, its LoRA weights also need to be merged to have an effect.
            active_adapters = [a for a in self.active_adapters if a in self.lora_A]
            try:
                self.merge(adapter_names=active_adapters)
                result = self.base_layer(query, *args, **kwargs)
            finally:
                # it's safe to call unmerge(), which unmerges all adapters, because we checked that not self.merged,
                # i.e. there is was no merged layer before
                self.unmerge()

        result = (result[0].to(previous_dtype), result[1].to(previous_dtype) if result[1] is not None else result[1])
        return result

    # The decorator is needed in case low_cpu_mem_usage=True is used, as we don't want the base layer weights to be
    # moved to meta device. This requires the use of PEFT's implementation of init_empty_weight instead of using the one
    # from accelerate.
    @skip_init_on_device
    def _restore_weights(self):
        # Restore the weights as registered parameters on the base layer.
        # This is necessary because the way that weights are merged/unmerged (which is necessary for forward to work
        # correctly), the Module "forgets" these attributes. Therefore, we need to call register_parameter explicitly.
        # We cannot call register_parameter for merging/unmerging because that cuts them off from the autograd graph.
        # Note that this is hacky, since we need to ensure that _restore_weights is called by each method that needs it.

        # in_proj
        # TODO work with separate weights
        base_layer = self.get_base_layer()
        weight = base_layer.in_proj_weight
        del base_layer.in_proj_weight
        base_layer.register_parameter("in_proj_weight", nn.Parameter(weight.data, requires_grad=weight.requires_grad))

        # out_proj
        base_layer = base_layer.out_proj.get_base_layer()
        weight = base_layer.weight
        del base_layer.weight
        base_layer.register_parameter("weight", nn.Parameter(weight.data, requires_grad=weight.requires_grad))

    def state_dict(self, *args, **kwargs):
        # re-register the base weights first so they appear in the state dict
        self._restore_weights()
        return super().state_dict(*args, **kwargs)

    def named_modules(self, *args, **kwargs):
        # Note: no need to also implement modules(), as modules() calls named_modules() under the hood
        self._restore_weights()
        return super().named_modules(*args, **kwargs)

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "lora." + rep
|
|
|
|
|
|
class _LoraParameterProxy(nn.Module):
    """Parametrization that adds a LoRA delta to an ``nn.Parameter`` targeted by LoRA.

    Meant to be registered through ``nn.utils.parametrize`` (see ``ParamWrapper``). Whenever the parametrized tensor
    is read, ``forward`` returns ``W + delta_weight``.
    """

    def __init__(self, delta_weight):
        super().__init__()
        self.delta_weight = delta_weight

    @staticmethod
    def _low_prec_add(x, y):
        """Compute ``x + y`` in ``y``'s dtype, clamp to ``x``'s representable range and cast back to ``x``'s dtype.

        Used for dtypes (such as fp8) where addition is not directly supported.

        Raises:
            RuntimeError: if ``y``'s dtype is not one of ``ALLOWED_COMPUTE_DTYPES``.
        """
        target_dtype = x.dtype
        compute_dtype = y.dtype
        if compute_dtype not in ALLOWED_COMPUTE_DTYPES:
            raise RuntimeError(
                f"There is an attempt to upcast the targeted parameter to {compute_dtype} "
                f"but the only supported are: {ALLOWED_COMPUTE_DTYPES}."
            )

        # upcasting the whole parameter can be quite costly
        total = x.to(compute_dtype) + y
        # the down-cast does *not* saturate by itself and can produce NaNs, so clamp to the valid range first
        bounds = torch.finfo(target_dtype)
        return total.clamp(min=bounds.min, max=bounds.max).to(target_dtype)

    def forward(self, W):
        needs_upcast = any(getattr(torch, dtype_name, None) == W.dtype for dtype_name in UPCAST_DTYPES)
        if needs_upcast:
            return self._low_prec_add(W, self.delta_weight)
        return W + self.delta_weight
|
|
|
|
|
|
# copied from:
# https://github.com/pytorch/pytorch/blob/5e386eec9426f174eea130c0c012d9f65ebe65fb/torch/nn/utils/parametrize.py#L75-L79
def _register_parameter_or_buffer(module, name, X):
    """Register ``X`` on ``module`` under ``name``: as a parameter if it is an ``nn.Parameter``, else as a buffer."""
    registrar = module.register_parameter if isinstance(X, nn.Parameter) else module.register_buffer
    registrar(name, X)
|
|
|
|
|
|
class ParamWrapper(nn.Module, LoraLayer):
|
|
"""A LoRA wrapper for `nn.Parameter`. This layer is dispatched if users target a parameter directly with
|
|
`lora_config.target_parameters`
|
|
Note:
|
|
- When accessing the wrapped nn.Parameter directly, e.g. via `module.weight`, the LoRA weights are *not*
|
|
applied.
|
|
- It is currently not implemented to target multiple parameters on the same module. To achieve this, it is
|
|
currently required to create a separate LoRA adapter (with another adapter name) and activate both at the
|
|
same time.
|
|
"""
|
|
|
|
def __init__(
    self,
    base_layer,
    adapter_name: str,
    parameter_name: str,
    config: LoraConfig,
    r: int = 0,
    lora_alpha: int = 1,
    is_target_conv_1d_layer: bool = False,
    **kwargs,
) -> None:
    """Wrap ``getattr(base_layer, parameter_name)`` with a LoRA adapter named ``adapter_name``.

    Raises:
        ValueError: if ``config`` enables lora_dropout, fan_in_fan_out, lora_bias or DoRA, or if
            ``is_target_conv_1d_layer`` is set. None of these are supported for a directly targeted parameter.
    """
    # set before the base initializers run: get_param() reads it (TODO confirm whether LoraLayer.__init__ needs it)
    self.parameter_name = parameter_name
    super().__init__()
    LoraLayer.__init__(self, base_layer, **kwargs)

    cls_name = self.__class__.__name__
    # It's not possible to factor out x from lora_B(lora_A(dropout(x))), so dropout can't be correctly
    # implemented
    if config.lora_dropout:
        raise ValueError(f"lora.{cls_name} does not work with lora_dropout != 0.")
    if config.fan_in_fan_out:
        raise ValueError(f"lora.{cls_name} does not work with fan_in_fan_out.")
    if config.lora_bias:
        raise ValueError(f"lora.{cls_name} does not work with lora_bias=True.")
    if config.use_dora:
        raise ValueError(f"lora.{cls_name} does not work with use_dora=True.")
    if is_target_conv_1d_layer:
        raise ValueError(f"lora.{cls_name} does not work with is_target_conv_1d_layer=True.")

    # guaranteed falsy by the check above
    self.fan_in_fan_out = config.fan_in_fan_out
    # in/out features of 3d (expert) parameters must be swapped at most once, see update_layer
    self._did_swap_in_out_features = False
    self._active_adapter = adapter_name
    self.update_layer(adapter_name, r, lora_alpha=lora_alpha, config=config)
|
|
|
|
def _get_in_out_features(self, module: nn.Module) -> tuple[int, int] | tuple[None, None]:
|
|
# For ParamWrapper, we don't derive the in_features and out_features based on the base layer type, but directly
|
|
# from the targeted parameter.
|
|
param = self.get_param()
|
|
if param.ndim == 3:
|
|
num_experts, in_features, out_features = param.shape
|
|
else:
|
|
num_experts, in_features, out_features = 1, param.shape[1], param.shape[0]
|
|
if param.ndim not in (2, 3):
|
|
raise ValueError(
|
|
f"lora.{self.__class__.__name__} was initialized with {param.ndim} dimensional Parameter, but only 2d "
|
|
"and 3d are supported."
|
|
)
|
|
# we have to store the num_experts attribute here, as the parent class only stores in_features and out_features.
|
|
self.num_experts = num_experts
|
|
return in_features, out_features
|
|
|
|
def update_layer(
    self,
    adapter_name: str,
    r: int,
    lora_alpha: int,
    config: LoraConfig,
    **kwargs,
) -> None:
    """Create (or replace) the LoRA weights for ``adapter_name`` on the targeted parameter.

    For a 3d parameter with ``num_experts`` experts, ``lora_A`` maps ``in_features -> r * num_experts`` and
    ``lora_B`` maps ``r * num_experts -> out_features``. ``get_delta_weight`` splits them per expert.

    Args:
        adapter_name (str): name of the adapter to create.
        r (int): LoRA rank per expert; must be positive.
        lora_alpha (int): alpha used for ``scaling`` (``alpha / r``, or ``alpha / sqrt(r)`` with rsLoRA).
        config (LoraConfig): supplies dropout, init, rsLoRA, bias and inference-mode settings.

    Raises:
        ValueError: if ``r <= 0``, if a LoRA variant (e.g. DoRA) is requested, or if ``lora_dropout > 0``.
    """
    # same method as in lora.Linear but taking into account that there can be multiple experts (3d parameter)
    lora_dropout = config.lora_dropout
    init_lora_weights = config.init_lora_weights
    use_rslora = config.use_rslora
    lora_bias = config.lora_bias
    inference_mode = config.inference_mode
    # This code works for linear layers, override for other layer types
    if r <= 0:
        raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

    lora_variant = self.resolve_lora_variant(config=config)
    if lora_variant is not None:
        raise ValueError(f"lora.{self.__class__.__name__} does not work with LoRA variants like DoRA.")

    # for some MoE layers, the order is (experts, out_features, in_features)
    # The swap is guarded by _did_swap_in_out_features so that adding a second adapter does not swap back.
    is_transposed = getattr(self.get_base_layer(), "is_transposed", False)
    swap_in_out_features = (self.get_param().ndim == 3) and not is_transposed
    if swap_in_out_features and not self._did_swap_in_out_features:
        self.in_features, self.out_features = self.out_features, self.in_features
        self._did_swap_in_out_features = True

    self.r[adapter_name] = r
    self.lora_alpha[adapter_name] = lora_alpha
    if lora_dropout > 0.0:
        # It's not possible to factor out x from lora_B(lora_A(dropout(x))), so dropout can't be correctly
        # implemented
        raise ValueError(f"lora.{self.__class__.__name__} does not work with lora_dropout != 0.")
    else:
        lora_dropout_layer = nn.Identity()

    self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))
    # Actual trainable parameters
    # Difference to normal update_layer: consider experts. LoRA layers still use nn.Linear for consistency with
    # lora.Linear.
    self.lora_A[adapter_name] = nn.Linear(self.in_features, r * self.num_experts, bias=False)
    self.lora_B[adapter_name] = nn.Linear(r * self.num_experts, self.out_features, bias=lora_bias)
    self.lora_bias[adapter_name] = lora_bias

    if use_rslora:
        self.scaling[adapter_name] = lora_alpha / math.sqrt(r)
    else:
        self.scaling[adapter_name] = lora_alpha / r

    self.use_rslora[adapter_name] = use_rslora

    self.use_dora[adapter_name] = config.use_dora

    # for inits that require access to the base weight, use gather_param_ctx so that the weight is gathered when using DeepSpeed
    # NOTE(review): these branches use `self.get_base_layer().weight`, not `self.get_param()`. If the targeted
    # parameter has another name, or the base layer has no `.weight`, this looks wrong — confirm.
    if isinstance(init_lora_weights, str) and init_lora_weights.startswith("pissa"):
        with gather_params_ctx(self.get_base_layer().weight):
            self.pissa_init(adapter_name, init_lora_weights)
    elif isinstance(init_lora_weights, str) and init_lora_weights.startswith("corda"):
        with gather_params_ctx(self.get_base_layer().weight):
            self.corda_init(adapter_name, init_lora_weights)
    elif isinstance(init_lora_weights, str) and init_lora_weights.lower() == "olora":
        with gather_params_ctx(self.get_base_layer().weight):
            self.olora_init(adapter_name)
    elif init_lora_weights == "loftq":
        with gather_params_ctx(self.get_base_layer().weight):
            self.loftq_init(adapter_name)
    elif init_lora_weights == "eva":
        # only lora_B is zeroed here; lora_A keeps the default nn.Linear init
        nn.init.zeros_(self.lora_B[adapter_name].weight)
    elif init_lora_weights == "orthogonal":
        with gather_params_ctx(self.get_base_layer().weight):
            self.orthogonal_init(adapter_name)
    elif init_lora_weights == "lora_ga":
        with gather_params_ctx(self.get_base_layer().weight):
            self.lora_ga_init(adapter_name, config.lora_ga_config)
    elif init_lora_weights:
        self.reset_lora_parameters(adapter_name, init_lora_weights)
    # call this before init of the lora variants
    self._move_adapter_to_device_of_base_layer(adapter_name)

    # NOTE(review): variants are rejected above, so this branch is presumably unreachable here
    if adapter_name in self.lora_variant:
        self.lora_variant[adapter_name].init(self, config=config, **kwargs)

    self.set_adapter(self.active_adapters, inference_mode=inference_mode)
|
|
|
|
def _move_adapter_to_device_of_base_layer(self, adapter_name: str, device: Optional[torch.device] = None) -> None:
    """
    Move the adapter of the given name to the device of the base layer. Needs special handling for nn.Parameter

    The reference tensor is the targeted parameter (``self.get_param()``) rather than ``base_layer.weight``. If that
    parameter has a floating point or complex dtype, the adapter is also cast to that dtype; otherwise only the device
    changes.

    Args:
        adapter_name (str): name of the adapter whose sub-modules/parameters should be moved.
        device (torch.device, *optional*): explicit target device. Defaults to the device of the targeted parameter.
    """
    param = self.get_param()
    if device is None:
        device = param.device
    meta = torch.device("meta")
    cast_dtype = param.dtype.is_floating_point or param.dtype.is_complex

    for adapter_layer_name in self.adapter_layer_names + self.other_param_names:
        adapter_layer = getattr(self, adapter_layer_name, None)
        if not isinstance(adapter_layer, (nn.ModuleDict, nn.ParameterDict, BufferDict)):
            continue
        if adapter_name not in adapter_layer:
            continue
        # skip containers that still hold parameters on the meta device (checked across all adapters in the container)
        if any(p.device == meta for p in adapter_layer.parameters()):
            continue

        if cast_dtype:
            adapter_layer[adapter_name] = adapter_layer[adapter_name].to(device, dtype=param.dtype)
        else:
            adapter_layer[adapter_name] = adapter_layer[adapter_name].to(device)
|
|
|
|
def get_param(self):
    """Return the targeted tensor, i.e. the attribute ``self.parameter_name`` of the base layer."""
    return getattr(self.get_base_layer(), self.parameter_name)
|
|
|
|
def get_delta_weight(self, adapter_name, *args, **kwargs):
    """Compute the scaled LoRA delta for the targeted parameter, placed on the parameter's device.

    With a single expert, ``Linear.get_delta_weight`` is reused. With several experts, ``lora_A``/``lora_B`` are
    split per expert and combined with an einsum whose output layout follows ``self._did_swap_in_out_features``.
    The delta is cast to the parameter's dtype only if that dtype is in ``ALLOWED_COMPUTE_DTYPES``.
    """
    if self.num_experts == 1:
        # could actually be a normal layer or experts stacked block-diagonally, acting like a single layer
        delta = Linear.get_delta_weight(self, adapter_name, *args, **kwargs)
    else:
        experts = self.num_experts
        a = self.lora_A[adapter_name].weight
        b = self.lora_B[adapter_name].weight
        # shape: experts x rank x in_features
        a = a.reshape(experts, -1, a.shape[-1])
        # shape: out_features x rank x experts
        b = b.reshape(b.shape[0], -1, experts)
        # fan_in_fan_out must be False, so no transpose call here.
        # For some MoE layers the order is (experts, out_features, in_features); that is the swapped case.
        if self._did_swap_in_out_features:
            pattern = "o r e, e r i -> e o i"
        else:
            pattern = "o r e, e r i -> e i o"
        delta = torch.einsum(pattern, b, a) * self.scaling[adapter_name]

    param = self.get_param()
    if param.dtype in ALLOWED_COMPUTE_DTYPES:
        return delta.to(param.device, param.dtype)
    # Low precision dtypes such as torch.float8_e4m3fn: keep the delta's own dtype so that W + dW can be computed in
    # high precision before downcasting, see _low_prec_add.
    return delta.to(param.device)
|
|
|
|
@contextmanager
def _activate_lora(self, active_adapters: list[str]):
    """Temporarily add the summed delta of ``active_adapters`` to the targeted parameter via a parametrization.

    Inside the context, reading the targeted attribute of the base layer yields ``W + sum(deltas)``, with values cached
    through ``nn.utils.parametrize.cached``. The parametrization is removed on exit, also when the body raises. If
    none of the adapters belong to this layer, nothing is changed.
    """
    delta_weight = None
    for name in active_adapters or ():
        if name not in self.lora_A:
            continue
        current = self.get_delta_weight(name)
        delta_weight = current if delta_weight is None else delta_weight + current

    if delta_weight is None:
        # no active adapters for this layer
        yield
        return

    base_layer = self.get_base_layer()
    had_grad = self.get_param().requires_grad
    nn.utils.parametrize.register_parametrization(base_layer, self.parameter_name, _LoraParameterProxy(delta_weight))
    # requires_grad of the original tensor defaults to False after registering, so restore the previous value
    base_layer.parametrizations[self.parameter_name].original.requires_grad_(had_grad)
    try:
        with nn.utils.parametrize.cached():
            yield
    finally:
        self._remove_parametrizations()
|
|
|
|
def _remove_parametrizations(self):
    """Remove this wrapper's LoRA parametrization from the targeted parameter of the base layer.

    If there is exactly one parametrization (the LoRA proxy), it is removed completely and the original, unmodified
    tensor is restored. Otherwise only the most recently added ``_LoraParameterProxy`` is dropped, and any other
    parametrizations stay in place.

    Raises:
        ValueError: if the targeted parameter is not parametrized at all.
    """
    # Remove the parametrization of this specific parameter
    base_layer = self.get_base_layer()
    parameter_name = self.parameter_name
    # `is_parametrized` also handles modules without a `parametrizations` attribute (never parametrized, or whose
    # last parametrization was already removed). Accessing that attribute directly would raise an AttributeError
    # instead of this error.
    if not nn.utils.parametrize.is_parametrized(base_layer, parameter_name):
        raise ValueError(
            "Something went wrong, please report this issue on PEFT: https://github.com/huggingface/peft/issues"
        )

    param_list = base_layer.parametrizations[parameter_name]
    if len(param_list) == 1:
        # last parametrization, we can safely remove it completely
        nn.utils.parametrize.remove_parametrizations(base_layer, parameter_name, leave_parametrized=False)
        return

    # If there are multiple parametrizations for the same parameter_name, we only want to remove the LoRA proxy.
    # Unfortunately, PyTorch does not support this directly, so we need to take care of it manually. To achieve
    # this, we check the ParameterList from the back until we find the _LoraParameterProxy instance and then remove
    # it.
    for i in reversed(range(len(param_list))):
        if isinstance(param_list[i], _LoraParameterProxy):
            del param_list[i]
            break
    else:  # no break encountered
        # this should not happen, but raising an error is probably not necessary
        warnings.warn(
            f"Could not find any LoRA parametrization on {self}, please open an issue on "
            "https://github.com/huggingface/peft/issues and report this warning."
        )
|
|
|
|
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
    """Merge the selected LoRA adapters into the wrapped parameter in place.

    Behaves like ``lora.Linear.merge``, except that the target tensor is resolved through
    ``self.parameter_name`` instead of being hard-coded to ``base_layer.weight``, and the
    variant-specific special cases are not handled here.

    Args:
        safe_merge: If True, the delta is first added to a cloned copy. The copy is only
            committed if every element is finite. This is slower because of the extra copy.
        adapter_names: Adapters to merge, normalized by ``check_adapters_to_merge``.

    Raises:
        ValueError: With ``safe_merge=True``, if the merged tensor contains non-finite values.
    """
    adapter_names = check_adapters_to_merge(self, adapter_names)
    if not adapter_names:
        # nothing left to merge
        return

    for adapter in adapter_names:
        if adapter not in self.lora_A:
            continue
        param = getattr(self.get_base_layer(), self.parameter_name)

        if not safe_merge:
            param.data += self.get_delta_weight(adapter)
        else:
            # Work on a copy so a broken adapter cannot corrupt the live parameter.
            candidate = param.data.clone()
            candidate += self.get_delta_weight(adapter).to(candidate.dtype)
            if not torch.isfinite(candidate).all():
                raise ValueError(
                    f"NaNs detected in the merged weights. The adapter {adapter} seems to be broken"
                )
            param.data = candidate

        self.merged_adapters.append(adapter)
|
|
|
|
def unmerge(self) -> None:
    """Subtract every merged adapter's delta from the wrapped parameter.

    Behaves like ``lora.Linear.unmerge``, except that the target tensor is looked up by
    ``self.parameter_name`` and variant special cases are not handled. Adapters are popped
    from ``self.merged_adapters`` in LIFO order. Names that are not LoRA adapters of this
    layer are discarded without touching the parameter. Emits a warning and does nothing
    when no adapter is merged.
    """
    if not self.merged:
        warnings.warn("Already unmerged. Nothing to do.")
        return

    while self.merged_adapters:
        adapter = self.merged_adapters.pop()
        if adapter not in self.lora_A:
            continue
        target = getattr(self.get_base_layer(), self.parameter_name)
        delta = self.get_delta_weight(adapter)
        # Cast so the in-place subtraction happens in the parameter's own dtype.
        target.data -= delta.to(target.dtype)
|
|
|
|
def _check_forward_args(self, x, *args, **kwargs):
    """Validate ``forward`` arguments against this wrapper's capabilities and state.

    Per-sample adapter routing is not supported here. Any truthy ``adapter_names`` keyword
    is rejected before the parent class performs its own checks.
    """
    requested_adapters = kwargs.get("adapter_names")
    if requested_adapters:
        raise ValueError(f"lora.{self.__class__.__name__} does not support mixed adapter batches yet.")
    super()._check_forward_args(x, *args, **kwargs)
|
|
|
|
def unload_and_optionally_merge_module(self, merge: bool, safe_merge: bool, adapter_names: Optional[list[str]]):
    """Return the innermost wrapped module, optionally merging adapters on the way down.

    ParamWrappers can be nested, one per targeted parameter. When ``merge`` is True, this
    wrapper and every ParamWrapper below it merge their adapters. The first
    non-ParamWrapper module in the chain is then returned. Otherwise ``get_base_layer()``
    is returned unchanged.
    """
    if not merge:
        return self.get_base_layer()

    inner = self.base_layer
    self.merge(safe_merge=safe_merge, adapter_names=adapter_names)
    while isinstance(inner, ParamWrapper):
        inner.merge(safe_merge=safe_merge, adapter_names=adapter_names)
        inner = inner.base_layer
    return inner
|
|
|
|
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
    """Call the wrapped module with the active LoRA adapters applied to the targeted parameter.

    Cases:
    * adapters disabled: unmerge first if needed, then call the base layer unchanged;
    * ``adapter_names`` passed: mixed-batch inference is unsupported, so ValueError is raised;
    * already merged: the delta already lives in the parameter, so call the base layer;
    * otherwise: run the base layer inside ``self._activate_lora(self.active_adapters)``.

    Raises:
        ValueError: If ``adapter_names`` is supplied.
    """
    self._check_forward_args(x, *args, **kwargs)
    adapter_names = kwargs.pop("adapter_names", None)

    if self.disable_adapters:
        if self.merged:
            self.unmerge()
        return self.base_layer(x, *args, **kwargs)

    if adapter_names is not None:
        raise ValueError(f"lora.{self.__class__.__name__} does not support mixed batch inference")

    if self.merged:
        return self.base_layer(x, *args, **kwargs)

    with self._activate_lora(self.active_adapters):
        return self.base_layer(x, *args, **kwargs)
|
|
|
|
def __repr__(self) -> str:
    """Return the parent repr prefixed with ``lora.`` and annotated with ``parameter_name``.

    The targeted parameter's name is inserted right after the first opening parenthesis. This
    keeps the repr unambiguous when several parameters of the same module are wrapped.
    """
    base_repr = super().__repr__()
    cut = base_repr.find("(") + 1
    head, tail = base_repr[:cut], base_repr[cut:]
    return f"lora.{head}\n parameter_name='{self.parameter_name}',{tail}"
|
|
|
|
|
|
def dispatch_default(
    target: torch.nn.Module,
    adapter_name: str,
    config: LoraConfig,
    parameter_name: Optional[str] = None,
    **kwargs,
) -> Optional[torch.nn.Module]:
    """Build the default LoRA layer matching the type of ``target``'s underlying module.

    If ``target`` is already a tuner layer, dispatch is based on its base layer. When
    ``parameter_name`` is given, a ParamWrapper targeting that parameter is returned
    regardless of module type.

    For ``torch.nn.Linear`` and ``Conv1D`` targets, ``config.fan_in_fan_out`` is forced to
    the value the layer type requires. A warning is emitted and the shared ``config`` is
    mutated when it disagrees.

    Returns:
        The new LoRA module, or ``None`` if the target type is not handled here.
    """
    base = target.get_base_layer() if isinstance(target, BaseTunerLayer) else target

    if parameter_name is not None:
        return ParamWrapper(target, adapter_name, parameter_name=parameter_name, config=config, **kwargs)

    # Tried in this order; the first isinstance match wins.
    plain_layer_types = (
        (torch.nn.Embedding, Embedding),
        (torch.nn.Conv2d, Conv2d),
        (torch.nn.Conv3d, Conv3d),
        (nn.Conv1d, Conv1d),
        (torch.nn.MultiheadAttention, MultiheadAttention),
    )
    for torch_type, lora_type in plain_layer_types:
        if isinstance(base, torch_type):
            return lora_type(target, adapter_name, config=config, **kwargs)

    if isinstance(base, torch.nn.Linear):
        if config.fan_in_fan_out:
            warnings.warn(
                "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
                "Setting fan_in_fan_out to False."
            )
            config.fan_in_fan_out = False
        return Linear(target, adapter_name, config=config, **kwargs)

    if isinstance(base, Conv1D):
        # Conv1D stores its weight transposed relative to nn.Linear, hence fan_in_fan_out=True.
        if not config.fan_in_fan_out:
            warnings.warn(
                "fan_in_fan_out is set to False but the target module is `Conv1D`. Setting fan_in_fan_out to True."
            )
            config.fan_in_fan_out = True
        return Linear(target, adapter_name, is_target_conv_1d_layer=True, config=config, **kwargs)

    return None
|