mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 15:00:33 +08:00
Add reference-impl URLs to variant docstrings + V2 external review
- Fetch canonical reference impls for offline review:
* peft_{lora,hra,delora,ia3}_layer.py + peft_lora_{dora,variants}.py
* orig_pissa_init.py (MuLabPKU/PiSSA)
* orig_hra_layer.py (DaShenZi721/HRA)
* orig_delora.py (ExplainableML/DeLoRA author fork)
- Add reference-impl URLs to all 6 variant docstrings
- Document HRA gate=0 dead-grad issue and DoRA detach-omission in their docstrings
- Re-run external review (codex) with refs available -> docs/audit/variants_review_v2.md
Major NEW findings vs paper-only review:
* DeLoRA: scalar W.norm() should be per-input-channel norm(dim=0)
* HRA: PEFT uses symmetric repeated-column init (no dead grad), not zero gate
* IA3: FFN targets need input-side gating, not output-side; our up_proj advice was wrong
* All LoRA-family: cfg.dropout silently ignored (no-op)
* DeLoRA: wnorm should be persistent buffer, not Parameter
HRA and DeLoRA upgraded to BUGGY (from Partial)
@@ -0,0 +1,74 @@
# Per-variant paper-faithfulness audit V2 (with reference implementations)

Re-audit of `lora-lite` after adding canonical reference implementation URLs to
each variant docstring. Your job: for each variant, **directly compare** our
implementation against the reference impl (peft and/or paper-author repo), not
just against the paper text. This is round 2. The previous review (you can
read `docs/audit/variants_review.md`) found:

- HRA gate=0 init kills `lora_U` gradient on step 0
- DeLoRA has the same pattern with lambda0=0
- IA3 targets q/v, not the paper's k/v/ffn-down (deviation documented but untested)
- PiSSA bf16 init err 0.31 on Qwen
- Saved adapters don't preserve PiSSA W_res mutation

Your job now is to verify those findings against the **reference code**, and
look for anything the prior review missed once you have the reference in hand.
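
To make the first prior finding concrete, here is a toy sketch in plain Python. It assumes the adapter's contribution is multiplied by a scalar gate; the names `adapter_delta`, `grad_wrt_u`, and `grad_wrt_gate` are hypothetical, and this is not the `lora-lite` forward pass. With the gate at zero, the gradient with respect to the direction vector is exactly zero, while the gate itself still receives gradient:

```python
def adapter_delta(gate, u, x):
    """Toy gated adapter output: gate * <u, x> (hypothetical, for illustration)."""
    return gate * sum(ui * xi for ui, xi in zip(u, x))


def grad_wrt_u(gate, u, x):
    """Analytic gradient of adapter_delta w.r.t. u: d/du_i = gate * x_i."""
    return [gate * xi for xi in x]


def grad_wrt_gate(gate, u, x):
    """Analytic gradient of adapter_delta w.r.t. gate: <u, x>."""
    return sum(ui * xi for ui, xi in zip(u, x))


x = [1.0, -2.0, 0.5]
u = [0.3, 0.1, -0.4]

dead = grad_wrt_u(0.0, u, x)          # gate initialised to zero
live = grad_wrt_u(0.1, u, x)          # small non-zero gate
gate_grad = grad_wrt_gate(0.0, u, x)  # the gate still gets a signal

print("grad wrt u, gate=0.0:", dead)
print("grad wrt u, gate=0.1:", live)
print("grad wrt gate, gate=0.0:", gate_grad)
```

In this toy model only the gate moves on the first step. Whether that matters in practice depends on the optimizer and on how the real variant composes its gate, which is what the comparison against the reference code should establish.
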

## Inputs

- Our code: `src/lora_lite/variants/{lora,pissa,dora,ia3,hra,delora}.py`
- Adapter plumbing: `src/lora_lite/{adapter.py,target.py,variant.py,config.py}`
- Papers (text): `docs/papers/*_*.txt`
- **Reference implementations** (just added):
  - `docs/refs/peft_lora_layer.py`: peft LoRA Linear (and PiSSA init paths)
  - `docs/refs/peft_lora_dora.py`: peft DoRA helper module
  - `docs/refs/peft_lora_variants.py`: peft per-variant init dispatch (PiSSA, OLoRA, etc.)
  - `docs/refs/peft_ia3_layer.py`: peft IA3 layer
  - `docs/refs/peft_hra_layer.py`: peft HRA layer (clean, has apply_GS toggle)
  - `docs/refs/peft_delora_layer.py`: peft DeLoRA layer (upstreamed)
  - `docs/refs/orig_pissa_init.py`: PiSSA paper authors' init script (MuLabPKU)
  - `docs/refs/orig_hra_layer.py`: HRA paper authors' OFT-with-HRA layer (DaShenZi721)
  - `docs/refs/orig_delora.py`: DeLoRA paper authors' fork-of-peft impl (ExplainableML)
- Logs: `logs/smoke.log`, `logs/qwen_probe.log`
- Prior review: `docs/audit/variants_review.md` (do NOT just restate it)

## What to deliver per variant (LoRA, PiSSA, DoRA, IA3, HRA, DeLoRA)

1. **Reference impl ground-truth.** What does the *reference* code actually do
   for: parameter shapes, initialization, scale factor, forward equation,
   save/load, target placement? Quote ≤10 lines with file/line cites from
   `docs/refs/`.

2. **Our code.** Quote our impl (≤10 lines, with `src/lora_lite/variants/<v>.py:LN` cites).

3. **Diff.** Bullet list of every meaningful difference.
   Mark each one as: `[OK-doc]` (acceptable, documented), `[OK-undoc]` (acceptable,
   should add to docstring), `[BUG]` (likely wrong), `[STYLE]` (cosmetic).

4. **Did the prior review get it right?** Quote the relevant prior verdict
   line and either confirm or correct.

5. **Verdict.** Faithful / Faithful-with-doc-gap / Partial / Buggy.
   One-line reason.

## Final aggregate

Markdown table:

| variant | prior verdict | new verdict | new bugs found | doc gaps |

And a 5-bullet "what to fix next" list, ordered by severity.

## Hard rules

- Quote evidence from `docs/refs/` files. If you can't find the relevant
  reference function, say so explicitly; don't guess.
- Do NOT edit code. Output review only.
- Be specific about line numbers from the references. "peft does X" is not
  enough; "peft_lora_layer.py:L1234 does X" is.
- If you find a NEW bug not flagged in `variants_review.md`, mark it
  `[NEW-BUG]` and explain the failure mode.
- If the prior review was wrong (false positive), mark it `[OVERTURN]`.

Write to stdout. I will redirect to `docs/audit/variants_review_v2.md`.
File diff suppressed because it is too large
@@ -0,0 +1,446 @@
|
||||
import importlib
|
||||
import math
|
||||
import re
|
||||
import warnings
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers.pytorch_utils import Conv1D
|
||||
|
||||
from ..utils import PeftConfig, PeftType, transpose
|
||||
|
||||
|
||||
def is_bnb_available():
|
||||
return importlib.util.find_spec("bitsandbytes") is not None
|
||||
|
||||
|
||||
if is_bnb_available():
|
||||
import bitsandbytes as bnb
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeloraConfig(PeftConfig):
|
||||
"""
|
||||
This is the configuration class to store the configuration of a [`~peft.Delora`].
|
||||
|
||||
Args:
|
||||
r (`int`): Delora attention dimension
|
||||
target_modules (`Union[List[str],str]`): The names of the modules to apply Delora to.
|
||||
delora_lambda (`float`): The lambda parameter for Delora scaling.
|
||||
delora_dropout (`float`): The dropout probability for Delora layers.
|
||||
merge_weights (`bool`):
|
||||
Whether to merge the weights of the Delora layers with the base transformer model in `eval` mode.
|
||||
fan_in_fan_out (`bool`): Set this to True if the layer to replace stores weight like (fan_in, fan_out)
|
||||
enable_delora ( `List[bool]`): Used with `delora.MergedLinear`.
|
||||
bias (`str`): Bias type for Delora. Can be 'none', 'all' or 'delora_only'
|
||||
modules_to_save (`List[str]`):List of modules apart from Delora layers to be set as trainable
|
||||
and saved in the final checkpoint.
|
||||
"""
|
||||
|
||||
r: int = field(default=8, metadata={"help": "Delora attention dimension"})
|
||||
target_modules: Optional[Union[List[str], str]] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "List of module names or regex expression of the module names to replace with Delora."
|
||||
"For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' "
|
||||
},
|
||||
)
|
||||
delora_lambda: int = field(default=None, metadata={"help": "Delora lambda"})
|
||||
delora_dropout: float = field(default=None, metadata={"help": "Delora dropout"})
|
||||
Wdecompose_target_modules: Optional[Union[List[str], str]] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "List of module names or regex expression of the module names to only tune the magnitude part"
|
||||
"For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' "
|
||||
},
|
||||
)
|
||||
merge_weights: bool = field(
|
||||
default=False, metadata={"help": "Merge weights of the original model and the Delora model"}
|
||||
)
|
||||
fan_in_fan_out: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"},
|
||||
)
|
||||
enable_delora: Optional[List[bool]] = field(default=None, metadata={"help": "Used with `delora.MergedLinear`."})
|
||||
bias: str = field(default="none", metadata={"help": "Bias type for Delora. Can be 'none', 'all' or 'delora_only'"})
|
||||
modules_to_save: Optional[List[str]] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "List of modules apart from Delora layers to be set as trainable and saved in the final checkpoint. "
|
||||
"For example, in Sequence Classification or Token Classification tasks, "
|
||||
"the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved."
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
self.peft_type = PeftType.DELORA
|
||||
|
||||
|
||||
class DeloraModel(torch.nn.Module):
|
||||
"""
|
||||
Creates Decoupled Low Rank Adapter (Delora) model from a pretrained transformers model.
|
||||
|
||||
Args:
|
||||
model ([`transformers.PreTrainedModel`]): The model to be adapted.
|
||||
config ([`DeloraConfig`]): The configuration of the Delora model.
|
||||
|
||||
Returns:
|
||||
`torch.nn.Module`: The Delora model.
|
||||
|
||||
Example::
|
||||
        >>> from transformers import AutoModelForSeq2SeqLM
        >>> from peft import DeloraModel, DeloraConfig
        >>> config = DeloraConfig(
        ...     peft_type="DELORA", task_type="SEQ_2_SEQ_LM", r=8, delora_lambda=32,
        ...     target_modules=["q", "v"], delora_dropout=0.01,
        ... )
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
        >>> delora_model = DeloraModel(config, model)
|
||||
|
||||
**Attributes**:
|
||||
- **model** ([`transformers.PreTrainedModel`]) -- The model to be adapted.
|
||||
- **peft_config** ([`DeloraConfig`]): The configuration of the Delora model.
|
||||
"""
|
||||
|
||||
def __init__(self, config, model):
|
||||
super().__init__()
|
||||
self.peft_config = config
|
||||
print(self.peft_config)
|
||||
self.model = model
|
||||
self._find_and_replace()
|
||||
mark_only_delora_as_trainable(self.model, self.peft_config.bias)
|
||||
self.forward = self.model.forward
|
||||
|
||||
def _find_and_replace(self):
|
||||
loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False)
|
||||
if loaded_in_8bit and not is_bnb_available():
|
||||
raise ImportError(
|
||||
"To use Delora with 8-bit quantization, please install the `bitsandbytes` package. "
|
||||
"You can install it with `pip install bitsandbytes`."
|
||||
)
|
||||
is_target_modules_in_base_model = False
|
||||
is_hf_device_map_available = hasattr(self.model, "hf_device_map")
|
||||
kwargs = {
|
||||
"r": self.peft_config.r,
|
||||
"delora_lambda": self.peft_config.delora_lambda,
|
||||
"delora_dropout": self.peft_config.delora_dropout,
|
||||
"fan_in_fan_out": self.peft_config.fan_in_fan_out,
|
||||
"merge_weights": (self.peft_config.merge_weights or self.peft_config.inference_mode)
|
||||
and not is_hf_device_map_available,
|
||||
}
|
||||
key_list = [key for key, _ in self.model.named_modules()]
|
||||
for key in key_list:
|
||||
if isinstance(self.peft_config.target_modules, str):
|
||||
target_module_found = re.fullmatch(self.peft_config.target_modules, key)
|
||||
else:
|
||||
target_module_found = any(key.endswith(target_key) for target_key in self.peft_config.target_modules)
|
||||
|
||||
if target_module_found:
|
||||
if not is_target_modules_in_base_model:
|
||||
is_target_modules_in_base_model = True
|
||||
parent, target, target_name = self._get_submodules(key)
|
||||
bias = target.bias is not None
|
||||
if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
|
||||
kwargs.update(
|
||||
{
|
||||
"has_fp16_weights": target.state.has_fp16_weights,
|
||||
"memory_efficient_backward": target.state.memory_efficient_backward,
|
||||
"threshold": target.state.threshold,
|
||||
"index": target.index,
|
||||
}
|
||||
)
|
||||
if self.peft_config.enable_delora is None:
|
||||
print("8 bit delora")
|
||||
new_module = Linear8bitLt(target.in_features, target.out_features, bias=bias, **kwargs)
|
||||
else:
|
||||
kwargs.update({"enable_delora": self.peft_config.enable_delora})
|
||||
new_module = MergedLinear8bitLt(target.in_features, target.out_features, bias=bias, **kwargs)
|
||||
elif isinstance(target, torch.nn.Linear) and self.peft_config.enable_delora is None:
|
||||
new_module = Linear(target.in_features, target.out_features, bias=bias, **kwargs)
|
||||
elif self.peft_config.enable_delora is not None:
|
||||
kwargs.update({"enable_delora": self.peft_config.enable_delora})
|
||||
if isinstance(target, Conv1D):
|
||||
in_features, out_features = (
|
||||
target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape
|
||||
)
|
||||
else:
|
||||
in_features, out_features = target.in_features, target.out_features
|
||||
if kwargs["fan_in_fan_out"]:
|
||||
warnings.warn(
|
||||
"fan_in_fan_out is set to True but the target module is not a Conv1D. "
|
||||
"Setting fan_in_fan_out to False."
|
||||
)
|
||||
kwargs["fan_in_fan_out"] = self.peft_config.fan_in_fan_out = False
|
||||
new_module = MergedLinear(in_features, out_features, bias=bias, **kwargs)
|
||||
self._replace_module(parent, target_name, new_module, target)
|
||||
if not is_target_modules_in_base_model:
|
||||
raise ValueError(
|
||||
f"Target modules {self.peft_config.target_modules} not found in the base model. "
|
||||
f"Please check the target modules and try again."
|
||||
)
|
||||
|
||||
def _get_submodules(self, key):
|
||||
parent = self.model.get_submodule(".".join(key.split(".")[:-1]))
|
||||
target_name = key.split(".")[-1]
|
||||
target = self.model.get_submodule(key)
|
||||
return parent, target, target_name
|
||||
|
||||
def _replace_module(self, parent_module, child_name, new_module, old_module):
|
||||
setattr(parent_module, child_name, new_module)
|
||||
new_module.weight = old_module.weight
|
||||
if old_module.bias is not None:
|
||||
new_module.bias = old_module.bias
|
||||
if getattr(old_module, "state", None) is not None:
|
||||
new_module.state = old_module.state
|
||||
new_module.to(old_module.weight.device)
|
||||
|
||||
# dispatch to correct device
|
||||
for name, module in new_module.named_modules():
|
||||
if "delora_" in name:
|
||||
module.to(old_module.weight.device)
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
"""Forward missing attributes to the wrapped module."""
|
||||
try:
|
||||
return super().__getattr__(name) # defer to nn.Module's logic
|
||||
except AttributeError:
|
||||
return getattr(self.model, name)
|
||||
|
||||
@property
|
||||
def modules_to_save(self):
|
||||
return None
|
||||
|
||||
def get_peft_config_as_dict(self, inference: bool = False):
|
||||
config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(self.peft_config).items()}
|
||||
if inference:
|
||||
config["inference_mode"] = True
|
||||
return config
|
||||
|
||||
def _set_adapter_layers(self, enabled=True):
|
||||
for module in self.model.modules():
|
||||
if isinstance(module, DeloraLayer):
|
||||
module.disable_adapters = False if enabled else True
|
||||
|
||||
def enable_adapter_layers(self):
|
||||
self._set_adapter_layers(enabled=True)
|
||||
|
||||
def disable_adapter_layers(self):
|
||||
self._set_adapter_layers(enabled=False)
|
||||
|
||||
|
||||
# Below code is based on https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
|
||||
# and modified to work with PyTorch FSDP
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
# had to adapt it for `delora_only` to work
|
||||
def mark_only_delora_as_trainable(model: nn.Module, bias: str = "none") -> None:
|
||||
for n, p in model.named_parameters():
|
||||
if "delora_" not in n:
|
||||
p.requires_grad = False
|
||||
if bias == "none":
|
||||
return
|
||||
elif bias == "all":
|
||||
for n, p in model.named_parameters():
|
||||
if "bias" in n:
|
||||
p.requires_grad = True
|
||||
elif bias == "delora_only":
|
||||
for m in model.modules():
|
||||
if isinstance(m, DeloraLayer) and hasattr(m, "bias") and m.bias is not None:
|
||||
m.bias.requires_grad = True
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class DeloraLayer:
|
||||
def __init__(
|
||||
self,
|
||||
r: int,
|
||||
delora_lambda_value: int,
|
||||
delora_dropout: float,
|
||||
merge_weights: bool,
|
||||
):
|
||||
self.r = r
|
||||
self.delora_lambda_value = delora_lambda_value
|
||||
# Optional dropout
|
||||
if delora_dropout > 0.0:
|
||||
self.delora_dropout = nn.Dropout(p=delora_dropout)
|
||||
else:
|
||||
self.delora_dropout = lambda x: x
|
||||
# Mark the weight as unmerged
|
||||
self.merged = False
|
||||
self.merge_weights = merge_weights
|
||||
self.disable_adapters = False
|
||||
|
||||
|
||||
class Linear(nn.Linear, DeloraLayer):
|
||||
# Delora implemented in a dense layer
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
r: int = 0,
|
||||
delora_lambda: float = 1.,
|
||||
delora_dropout: float = 0.0,
|
||||
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
|
||||
merge_weights: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
nn.Linear.__init__(self, in_features, out_features, **kwargs)
|
||||
DeloraLayer.__init__(self, r=r, delora_lambda_value=delora_lambda, delora_dropout=delora_dropout, merge_weights=merge_weights)
|
||||
|
||||
self.fan_in_fan_out = fan_in_fan_out
|
||||
|
||||
if r > 0:
|
||||
self.delora_A = nn.Linear(in_features, r, bias=False)
|
||||
self.delora_B = nn.Linear(r, out_features, bias=False)
|
||||
|
||||
self.delora_lambda = nn.Parameter(torch.full((1,), delora_lambda), requires_grad=True)
|
||||
|
||||
# Frozen parameters
|
||||
self.frozen_C = nn.Parameter(torch.empty_like(self.delora_A.weight).copy_(self.delora_A.weight))
|
||||
self.frozen_C.requires_grad = False
|
||||
self.frozen_D = nn.Parameter(torch.empty_like(self.delora_B.weight).copy_(self.delora_B.weight))
|
||||
self.frozen_D.requires_grad = False
|
||||
|
||||
# Freezing the pre-trained weight matrix
|
||||
self.weight.requires_grad = False
|
||||
|
||||
self.reset_parameters()
|
||||
if fan_in_fan_out:
|
||||
self.weight.data = self.weight.data.T
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.Linear.reset_parameters(self)
|
||||
if hasattr(self, "delora_A"):
|
||||
# initialize A the same way as the default for nn.Linear and B to zero
|
||||
nn.init.kaiming_uniform_(self.delora_A.weight, a=math.sqrt(5))
|
||||
nn.init.kaiming_uniform_(self.delora_B.weight, a=math.sqrt(5))
|
||||
nn.init.constant_(self.delora_lambda, self.delora_lambda_value)
|
||||
self.frozen_C.data = self.delora_A.weight.data
|
||||
self.frozen_D.data = self.delora_B.weight.data
|
||||
    def get_ABCD(self):
        # Get weights
        delora_A_weight = self.delora_A.weight  # shape: (r, in_features)
        delora_B_weight = self.delora_B.weight  # shape: (out_features, r)

        # Get norms
        delora_A_norm = delora_A_weight.norm(dim=1)  # shape: (r,)
        delora_B_norm = delora_B_weight.norm(dim=0)  # shape: (r,)
        frozen_C_norm = self.frozen_C.norm(dim=1)  # shape: (r,)
        frozen_D_norm = self.frozen_D.norm(dim=0)  # shape: (r,)

        # AB normalization
        diag12 = torch.div(self.delora_lambda / self.r, torch.mul(delora_A_norm, delora_B_norm))
        diag12 = torch.diag_embed(diag12)

        diag34 = torch.div(self.delora_lambda / self.r, torch.mul(frozen_C_norm, frozen_D_norm))
        diag34 = torch.diag_embed(diag34)

        # Get ABCD
        ABCD = delora_B_weight @ diag12 @ delora_A_weight
        ABCD = ABCD - self.frozen_D @ diag34 @ self.frozen_C

        # W scaling
        Wnorm = self.weight.data.norm(dim=0)  # shape: (in_features,)
        ABCD = torch.mul(ABCD, Wnorm.unsqueeze(0))  # shape: (out_features, in_features)

        return ABCD
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
nn.Linear.train(self, mode)
|
||||
self.delora_A.train(mode)
|
||||
self.delora_B.train(mode)
|
||||
self.delora_lambda.requires_grad = mode
|
||||
|
||||
if not mode and self.merge_weights and not self.merged:
|
||||
# Merge the weights and mark it
|
||||
if self.r > 0:
|
||||
self.weight.data += self.get_ABCD().to(self.weight.device, dtype=self.weight.dtype)
|
||||
self.merged = True
|
||||
elif self.merge_weights and self.merged:
|
||||
# Make sure that the weights are not merged
|
||||
if self.r > 0:
|
||||
self.weight.data -= self.get_ABCD()
|
||||
self.merged = False
|
||||
|
||||
def eval(self):
|
||||
nn.Linear.eval(self)
|
||||
self.delora_A.eval()
|
||||
self.delora_B.eval()
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
previous_dtype = self.weight.dtype
|
||||
if self.disable_adapters:
|
||||
if self.r > 0 and self.merged:
|
||||
self.weight.data -= self.get_ABCD()
|
||||
self.merged = False
|
||||
|
||||
result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
|
||||
elif self.r > 0 and not self.merged:
|
||||
result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
|
||||
if self.r > 0:
|
||||
result += F.linear(self.delora_dropout(x), self.get_ABCD(), bias=None)
|
||||
else:
|
||||
result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
|
||||
|
||||
if result.dtype != previous_dtype:
|
||||
result = result.to(previous_dtype)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class MergedLinear(nn.Linear, DeloraLayer):
|
||||
# Delora implemented in a dense layer
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
r: int = 0,
|
||||
delora_lambda: int = 1,
|
||||
delora_dropout: float = 0.0,
|
||||
enable_delora: List[bool] = [False],
|
||||
fan_in_fan_out: bool = False,
|
||||
merge_weights: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
if is_bnb_available():
|
||||
|
||||
class Linear8bitLt(bnb.nn.Linear8bitLt, DeloraLayer):
|
||||
# Delora implemented in a dense layer
|
||||
def __init__(
|
||||
self,
|
||||
in_features,
|
||||
out_features,
|
||||
r: int = 0,
|
||||
delora_lambda: int = 1,
|
||||
delora_dropout: float = 0.0,
|
||||
Wdecompose: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
class MergedLinear8bitLt(bnb.nn.Linear8bitLt, DeloraLayer):
|
||||
# Delora implemented in a dense layer
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
r: int = 0,
|
||||
delora_lambda: int = 1,
|
||||
delora_dropout: float = 0.0,
|
||||
enable_delora: List[bool] = [False],
|
||||
**kwargs,
|
||||
):
|
||||
raise NotImplementedError
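
Reviewer note on the `# W scaling` step in `get_ABCD` above: it multiplies each column of the update by the norm of the matching column of the frozen weight (`norm(dim=0)`, one value per input feature). A minimal plain-Python sketch of how that differs from a single scalar norm of the whole matrix follows. It uses nested lists rather than tensors, and `column_norms`, `frobenius_norm`, and `scale_columns` are hypothetical helper names, not part of either code base:

```python
import math


def column_norms(W):
    """L2 norm of each column of W (rows = out_features, cols = in_features)."""
    n_cols = len(W[0])
    return [math.sqrt(sum(row[j] ** 2 for row in W)) for j in range(n_cols)]


def frobenius_norm(W):
    """Single scalar norm over every entry of W."""
    return math.sqrt(sum(v ** 2 for row in W for v in row))


def scale_columns(delta, scales):
    """Multiply column j of delta by scales[j]."""
    return [[v * s for v, s in zip(row, scales)] for row in delta]


W = [[3.0, 0.0],
     [4.0, 1.0]]
delta = [[1.0, 1.0],
         [1.0, 1.0]]

per_channel = scale_columns(delta, column_norms(W))
scalar = [[v * frobenius_norm(W) for v in row] for row in delta]

print(per_channel)  # columns scaled by 5.0 and 1.0 respectively
print(scalar)       # every entry scaled by sqrt(26)
```

With per-channel scaling, an input feature whose frozen weights are small gets a proportionally small update; a scalar norm applies the same factor to every column.
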
|
||||
@@ -0,0 +1,420 @@
|
||||
# Copyright 2023-present the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from typing import Any, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from peft.tuners.lycoris_utils import LycorisLayer, check_adapters_to_merge
|
||||
|
||||
|
||||
class OFTLayer(nn.Module, LycorisLayer):
|
||||
# All names of layers that may contain adapter weights
|
||||
adapter_layer_names = ("oft_r",)
|
||||
|
||||
# other_param_names is defined on parent class
|
||||
|
||||
def __init__(self, base_layer: nn.Module):
|
||||
super().__init__()
|
||||
LycorisLayer.__init__(self, base_layer)
|
||||
|
||||
# OFT info
|
||||
self.oft_r = nn.ParameterDict({})
|
||||
self.coft = {}
|
||||
self.eps = {}
|
||||
self.block_share = {}
|
||||
|
||||
@property
|
||||
def _available_adapters(self) -> Set[str]:
|
||||
return {*self.oft_r}
|
||||
|
||||
def create_adapter_parameters(self, adapter_name: str, r: int, shape: Tuple[int, ...], block_share: bool):
|
||||
# if block_share:
|
||||
# self.oft_r[adapter_name] = nn.Parameter(torch.empty(1, math.ceil(shape[0] / r), math.ceil(shape[0] / r)))
|
||||
# else:
|
||||
# self.oft_r[adapter_name] = nn.Parameter(torch.empty(r, math.ceil(shape[0] / r), math.ceil(shape[0] / r)))
|
||||
weight = getattr(self.get_base_layer(), "weight", None)
|
||||
# self.oft_r[adapter_name] = nn.Parameter(torch.cat([weight.new_ones(r, r), weight.new_zeros(shape[0]-r, r)], dim=0))
|
||||
self.oft_r[adapter_name] = nn.Parameter(
|
||||
torch.cat([torch.eye(r, device=weight.device, dtype=weight.dtype),
|
||||
torch.zeros(shape[0] - r, r, device=weight.device, dtype=weight.dtype)], dim=0))
|
||||
|
||||
def reset_adapter_parameters(self, adapter_name: str):
|
||||
# nn.init.zeros_(self.oft_r[adapter_name])
|
||||
nn.init.kaiming_uniform_(self.oft_r[adapter_name], a=1 / self.eps[adapter_name])
|
||||
|
||||
def reset_adapter_parameters_random(self, adapter_name: str):
|
||||
# nn.init.kaiming_uniform_(self.oft_r[adapter_name], a=math.sqrt(5))
|
||||
nn.init.kaiming_uniform_(self.oft_r[adapter_name], a=1 / self.eps[adapter_name])
|
||||
|
||||
def update_layer(
|
||||
self,
|
||||
adapter_name: str,
|
||||
r: int,
|
||||
module_dropout: float,
|
||||
init_weights: bool,
|
||||
coft: bool = False,
|
||||
eps: float = 6e-5,
|
||||
block_share: bool = False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Internal function to create oft adapter
|
||||
|
||||
Args:
|
||||
adapter_name (`str`): Name for the adapter to add.
|
||||
r (`int`): Rank for the added adapter.
|
||||
module_dropout (`float`): The dropout probability for disabling adapter during training.
|
||||
init_weights (`bool`): Whether to initialize weights.
|
||||
coft (`bool`): Whether to use the constrained variant of OFT or not.
|
||||
eps (`float`):
|
||||
The control strength of COFT. The freedom of rotation. Only has an effect if `coft` is set to True.
|
||||
block_share (`bool`): Whether to share the OFT parameters between blocks or not.
|
||||
"""
|
||||
if r <= 0:
|
||||
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
|
||||
|
||||
self.r[adapter_name] = r
|
||||
self.module_dropout[adapter_name] = module_dropout
|
||||
self.coft[adapter_name] = coft
|
||||
self.block_share[adapter_name] = block_share
|
||||
|
||||
# Determine shape of OFT weights
|
||||
base_layer = self.get_base_layer()
|
||||
if isinstance(base_layer, nn.Linear):
|
||||
shape = tuple(base_layer.weight.shape)
|
||||
elif isinstance(base_layer, nn.Conv2d):
|
||||
shape = (
|
||||
base_layer.out_channels,
|
||||
base_layer.in_channels * base_layer.kernel_size[0] * base_layer.kernel_size[1],
|
||||
)
|
||||
else:
|
||||
raise TypeError(f"OFT is not implemented for base layers of type {type(base_layer).__name__}")
|
||||
|
||||
# self.eps[adapter_name] = eps * math.ceil(shape[0] / r) * math.ceil(shape[0] / r)
|
||||
self.eps[adapter_name] = eps
|
||||
|
||||
# Create weights with provided shape
|
||||
self.create_adapter_parameters(adapter_name, r, shape, block_share)
|
||||
|
||||
# Initialize weights
|
||||
# if init_weights:
|
||||
# self.reset_adapter_parameters(adapter_name)
|
||||
# else:
|
||||
# self.reset_adapter_parameters_random(adapter_name)
|
||||
|
||||
# Move new weights to device
|
||||
weight = getattr(self.get_base_layer(), "weight", None)
|
||||
if weight is not None:
|
||||
# the layer is already completely initialized, this is an update
|
||||
if weight.dtype.is_floating_point or weight.dtype.is_complex:
|
||||
self.to(weight.device, dtype=weight.dtype)
|
||||
else:
|
||||
self.to(weight.device)
|
||||
self.set_adapter(self.active_adapters)
|
||||
|
||||
def unscale_layer(self, scale=None) -> None:
|
||||
# scale is not used
|
||||
pass
|
||||
|
||||
def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
|
||||
"""
|
||||
Merge the active adapter weights into the base weights
|
||||
|
||||
Args:
|
||||
safe_merge (`bool`, *optional*):
|
||||
If `True`, the merge operation will be performed in a copy of the original weights and check for NaNs
|
||||
before merging the weights. This is useful if you want to check if the merge operation will produce
|
||||
NaNs. Defaults to `False`.
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
The list of adapter names that should be merged. If `None`, all active adapters will be merged.
|
||||
Defaults to `None`.
|
||||
"""
|
||||
adapter_names = check_adapters_to_merge(self, adapter_names)
|
||||
if not adapter_names:
|
||||
# no adapter to merge
|
||||
return
|
||||
|
||||
for active_adapter in adapter_names:
|
||||
if active_adapter in self._available_adapters:
|
||||
base_layer = self.get_base_layer()
|
||||
|
||||
orig_weights = base_layer.weight.data
|
||||
if isinstance(base_layer, nn.Linear):
|
||||
orig_weights = torch.transpose(orig_weights, 0, 1)
|
||||
elif isinstance(base_layer, nn.Conv2d):
|
||||
orig_weights = orig_weights.view(
|
||||
[
|
||||
base_layer.out_channels,
|
||||
base_layer.in_channels * base_layer.kernel_size[0] * base_layer.kernel_size[1],
|
||||
]
|
||||
)
|
||||
orig_weights = torch.transpose(orig_weights, 0, 1)
|
||||
delta_weight = self.get_delta_weight(active_adapter)
|
||||
if orig_weights.shape[1] != delta_weight.shape[1]:
|
||||
# when in channels is not divisible by r
|
||||
delta_weight = delta_weight[: orig_weights.shape[1], : orig_weights.shape[1]]
|
||||
# delta_weight=delta_weight.to(orig_weights.device, dtype=orig_weights.dtype)
|
||||
new_weights = torch.mm(orig_weights, delta_weight)
|
||||
if isinstance(base_layer, nn.Linear):
|
||||
new_weights = torch.transpose(new_weights, 0, 1)
|
||||
elif isinstance(base_layer, nn.Conv2d):
|
||||
new_weights = torch.transpose(new_weights, 0, 1)
|
||||
new_weights = new_weights.view(
|
||||
[
|
||||
base_layer.out_channels,
|
||||
base_layer.in_channels,
|
||||
base_layer.kernel_size[0],
|
||||
base_layer.kernel_size[1],
|
||||
]
|
||||
)
|
||||
|
||||
if safe_merge and not torch.isfinite(new_weights).all():
|
||||
raise ValueError(
|
||||
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
|
||||
)
|
||||
|
||||
base_layer.weight.data = new_weights
|
||||
self.merged_adapters.append(active_adapter)
|
||||
|
||||
def unmerge(self) -> None:
|
||||
"""
|
||||
This method unmerges all merged adapter layers from the base weights.
|
||||
"""
|
||||
if not self.merged:
|
||||
warnings.warn("Already unmerged. Nothing to do.")
|
||||
return
|
||||
while len(self.merged_adapters) > 0:
|
||||
active_adapter = self.merged_adapters.pop()
|
||||
if active_adapter in self._available_adapters:
|
||||
base_layer = self.get_base_layer()
|
||||
new_weights = base_layer.weight.data
|
||||
if isinstance(base_layer, nn.Linear):
|
||||
new_weights = torch.transpose(new_weights, 0, 1)
|
||||
elif isinstance(base_layer, nn.Conv2d):
|
||||
new_weights = new_weights.view(
|
||||
[
|
||||
base_layer.out_channels,
|
||||
base_layer.in_channels * base_layer.kernel_size[0] * base_layer.kernel_size[1],
|
||||
]
|
||||
)
|
||||
new_weights = torch.transpose(new_weights, 0, 1)
|
||||
delta_weight = self.get_delta_weight(active_adapter)
|
||||
if new_weights.shape[1] != delta_weight.shape[1]:
|
||||
# when in channels is not divisible by r
|
||||
delta_weight = delta_weight[: new_weights.shape[1], : new_weights.shape[1]]
|
||||
delta_inv = torch.inverse(delta_weight)
|
||||
orig_weights = torch.mm(new_weights, delta_inv)
|
||||
|
||||
if isinstance(base_layer, nn.Linear):
|
||||
orig_weights = torch.transpose(orig_weights, 0, 1)
|
||||
elif isinstance(base_layer, nn.Conv2d):
|
||||
orig_weights = torch.transpose(orig_weights, 0, 1)
|
||||
orig_weights = orig_weights.reshape(
|
||||
[
|
||||
base_layer.out_channels,
|
||||
base_layer.in_channels,
|
||||
base_layer.kernel_size[0],
|
||||
base_layer.kernel_size[1],
|
||||
]
|
||||
)
|
||||
base_layer.weight.data = orig_weights
|
||||
|
||||
    def get_delta_weight(self, adapter_name: str) -> torch.Tensor:
        # rank = self.r[adapter_name]
        # coft = self.coft[adapter_name]
        # eps = self.eps[adapter_name]
        # opt_r = self.oft_r[adapter_name]

        # if coft:
        # with torch.no_grad():
        # opt_r.copy_(self._project_batch(opt_r, eps=eps))
        #
        # orth_rotate = self._cayley_batch(opt_r)
        # weight = self._block_diagonal(orth_rotate, rank)
        rank = self.r[adapter_name]
        hrft_v = self.oft_r[adapter_name]
        in_features = self.oft_r[adapter_name].size(0)
        device = self.oft_r[adapter_name].device
        dtype = self.oft_r[adapter_name].dtype

        # unit_v_list = [hrft_v[:, i].view(-1,1) / (torch.sqrt(torch.sum(hrft_v[:,i] ** 2) + self.eps[adapter_name])) for i in range(8)]

        # weight = torch.eye(in_features, device=device, dtype=dtype)
        # for unit_v in unit_v_list:
        # weight = torch.mm(weight, torch.eye(in_features, device=device, dtype=dtype) - 2 * unit_v @ unit_v.t())

        U_list = []
        U_list.append((hrft_v[:, 0] / hrft_v[:, 0].norm()).view(-1, 1))
        for i in range(1, rank):
            Ui = hrft_v[:, i].view(-1, 1)
            for j in range(i):
                Ui = Ui - (U_list[j].t() @ Ui) * U_list[j]
            U_list.append((Ui / Ui.norm()).view(-1, 1))
        U_list = torch.cat(U_list, dim=1)
        weight = torch.eye(in_features, device=device, dtype=dtype) - 2 * U_list @ U_list.t()
        # weight = torch.eye(in_features, device=device) - 2 * U_list @ U_list.t()

        return weight
|
||||
# Copied from https://github.com/Zeju1997/oft/blob/84cebb965df69781e3d9c3c875f5980b421eaf24/oft-control/oft.py#L144
|
||||
def _cayley_batch(self, data: torch.Tensor) -> torch.Tensor:
|
||||
b, r, c = data.shape
|
||||
# Ensure the input matrix is skew-symmetric
|
||||
skew = 0.5 * (data - data.transpose(1, 2))
|
||||
I = torch.eye(r, device=data.device).unsqueeze(0).expand(b, r, c) # noqa: E741
|
||||
|
||||
# Perform the Cayley parametrization
|
||||
Q = torch.bmm(I - skew, torch.inverse(I + skew))
|
||||
|
||||
return Q
|
||||
|
||||
# Copied from https://github.com/Zeju1997/oft/blob/84cebb965df69781e3d9c3c875f5980b421eaf24/oft-control/oft.py#L155
|
||||
def _block_diagonal(self, oft_r: torch.Tensor, rank: int) -> torch.Tensor:
|
||||
if oft_r.shape[0] == 1:
|
||||
# block share
|
||||
blocks = [oft_r[0, ...] for i in range(rank)]
|
||||
else:
|
||||
blocks = [oft_r[i, ...] for i in range(rank)]
|
||||
|
||||
# Use torch.block_diag to create the block diagonal matrix
|
||||
A = torch.block_diag(*blocks)
|
||||
|
||||
return A
|
||||
|
||||
# Copied from https://github.com/Zeju1997/oft/blob/84cebb965df69781e3d9c3c875f5980b421eaf24/oft-control/oft.py#L52
|
||||
def _project_batch(self, oft_r, eps=1e-5):
|
||||
# scaling factor for each of the smaller block matrix
|
||||
eps = eps * 1 / torch.sqrt(torch.tensor(oft_r.shape[0]))
|
||||
I = ( # noqa: E741
|
||||
torch.zeros((oft_r.size(1), oft_r.size(1)), device=oft_r.device, dtype=oft_r.dtype)
|
||||
.unsqueeze(0)
|
||||
.expand_as(oft_r)
|
||||
)
|
||||
diff = oft_r - I
|
||||
norm_diff = torch.norm(oft_r - I, dim=(1, 2), keepdim=True)
|
||||
mask = (norm_diff <= eps).bool()
|
||||
out = torch.where(mask, oft_r, I + eps * (diff / norm_diff))
|
||||
return out
|
||||
|
||||
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
||||
previous_dtype = x.dtype
|
||||
|
||||
if self.disable_adapters:
|
||||
if self.merged:
|
||||
self.unmerge()
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
elif self.merged:
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
else:
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
if len(result.shape) == 4:
|
||||
result = result.permute(0, 2, 3, 1)
|
||||
|
||||
base_layer = self.get_base_layer()
|
||||
base_bias = base_layer.bias
|
||||
if base_bias is not None:
|
||||
# Bias should be added after OFT forward
|
||||
result = result - base_bias.data
|
||||
|
||||
# Execute all the adapters
|
||||
for active_adapter in self.active_adapters:
|
||||
if active_adapter not in self._available_adapters:
|
||||
continue
|
||||
|
||||
module_dropout = self.module_dropout[active_adapter]
|
||||
|
||||
# Modify current execution weights
|
||||
if (not self.training) or (self.training and torch.rand(1) > module_dropout):
|
||||
result = self._get_delta_activations(active_adapter, result, *args, **kwargs)
|
||||
|
||||
if base_bias is not None:
|
||||
result = result + base_bias.data
|
||||
if len(result.shape) == 4:
|
||||
result = result.permute(0, 3, 1, 2)
|
||||
|
||||
result = result.to(previous_dtype)
|
||||
return result
|
||||
|
||||
|
||||
class Linear(OFTLayer):
|
||||
"""OFT implemented in Linear layer"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_layer: nn.Module,
|
||||
adapter_name: str = "default",
|
||||
r: int = 0,
|
||||
module_dropout: float = 0.0,
|
||||
init_weights: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(base_layer)
|
||||
|
||||
# Create adapter and set it active
|
||||
self._active_adapter = adapter_name
|
||||
self.update_layer(adapter_name, r, module_dropout, init_weights, **kwargs)
|
||||
|
||||
def _get_delta_activations(
|
||||
self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any
|
||||
) -> torch.Tensor:
|
||||
delta_weight = self.get_delta_weight(adapter_name)
|
||||
|
||||
base_layer = self.get_base_layer()
|
||||
base_weight = base_layer.weight.data
|
||||
delta_weight = delta_weight[: base_weight.shape[0], : base_weight.shape[0]]
|
||||
|
||||
# don't add bias here, because the bias will be added after OFT forward
|
||||
return torch.matmul(input, delta_weight)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
rep = super().__repr__()
|
||||
return "oft." + rep
|
||||
|
||||
|
||||
class Conv2d(OFTLayer):
|
||||
"""OFT implemented in Conv2d layer"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_layer: nn.Module,
|
||||
adapter_name: str = "default",
|
||||
r: int = 0,
|
||||
module_dropout: float = 0.0,
|
||||
init_weights: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(base_layer)
|
||||
|
||||
# Create adapter and set it active
|
||||
self._active_adapter = adapter_name
|
||||
self.update_layer(adapter_name, r, module_dropout, init_weights, **kwargs)
|
||||
|
||||
def _get_delta_activations(
|
||||
self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any
|
||||
) -> torch.Tensor:
|
||||
delta_weight = self.get_delta_weight(adapter_name)
|
||||
|
||||
base_layer = self.get_base_layer()
|
||||
base_weight = base_layer.weight.data
|
||||
delta_weight = delta_weight[: base_weight.shape[0], : base_weight.shape[0]]
|
||||
|
||||
# don't add bias here, because the bias will be added after OFT forward
|
||||
return torch.matmul(input, delta_weight)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
rep = super().__repr__()
|
||||
return "oft." + rep
|
||||
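The `get_delta_weight` in the file above orthonormalises the columns of `oft_r` with Gram-Schmidt and returns `I - 2 U U^T`. The following is a minimal pure-Python sketch of that construction, written without torch so it can be checked by hand. The helper names `gram_schmidt`, `reflection_from_columns` and `matmul` are illustrative assumptions and do not exist in the HRA repository. Because the columns of `U` are orthonormal, the resulting matrix should be symmetric and its own inverse.

```python
import math


def gram_schmidt(columns):
    """Orthonormalise column vectors, projecting out each earlier basis vector in turn."""
    basis = []
    for col in columns:
        v = list(col)
        for q in basis:
            proj = sum(qi * vi for qi, vi in zip(q, v))
            v = [vi - proj * qi for qi, vi in zip(q, v)]
        norm = math.sqrt(sum(vi * vi for vi in v))
        basis.append([vi / norm for vi in v])
    return basis


def reflection_from_columns(columns):
    """Return I - 2 * U @ U.T, where U holds the orthonormalised columns."""
    basis = gram_schmidt(columns)
    n = len(columns[0])
    return [
        [(1.0 if i == j else 0.0) - 2.0 * sum(q[i] * q[j] for q in basis) for j in range(n)]
        for i in range(n)
    ]


def matmul(a, b):
    """Plain nested-list matrix product."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


# Two hypothetical columns with in_features = 4 (rank r = 2).
columns = [[1.0, 2.0, 0.0, -1.0], [0.5, -1.0, 3.0, 2.0]]
H = reflection_from_columns(columns)
HH = matmul(H, H)  # expected to be (numerically) the identity
```

Under these assumptions `H @ H` recovers the identity, which is why an `unmerge` that multiplies by an explicit `torch.inverse(delta_weight)` is not strictly required for this parametrisation, although the reference code does compute the inverse.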
@@ -0,0 +1,60 @@
|
||||
# Copyright 2023-present the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
import os
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM
import argparse

parser = argparse.ArgumentParser(description="Separate the principal singular value and singular vectors from base model")
parser.add_argument("--base_model_path", type=str, required=True, help="The name or path of the base model.")
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument("--bits", type=str, default="bf16", choices=["bf16", "fp16", "fp32"])
parser.add_argument("--init_weights", type=str, default="pissa", help="(`['pissa', 'pissa_niter_[number of iters]']`)")
parser.add_argument("--lora_r", type=int, default=128)
parser.add_argument("--lora_alpha", type=int, default=128)
parser.add_argument("--lora_dropout", type=float, default=0)
parser.add_argument('--target_modules', nargs='+', help='', required=True)
script_args = parser.parse_args()
print(script_args)

model = AutoModelForCausalLM.from_pretrained(
    script_args.base_model_path,
    torch_dtype=(
        torch.float16
        if script_args.bits == "fp16"
        else (torch.bfloat16 if script_args.bits == "bf16" else torch.float32)
    ),
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_path)
tokenizer.pad_token_id = tokenizer.eos_token_id
lora_config = LoraConfig(
    r=script_args.lora_r,
    lora_alpha=script_args.lora_alpha,
    init_lora_weights=True if script_args.init_weights=="True" else script_args.init_weights,
    lora_dropout=script_args.lora_dropout,
    target_modules=script_args.target_modules,
)
peft_model = get_peft_model(model, lora_config)

# Save PiSSA modules:
peft_model.peft_config["default"].init_lora_weights = True
peft_model.save_pretrained(os.path.join(script_args.output_dir, "pissa_init"))
# Save residual model:
peft_model = peft_model.unload()
peft_model.save_pretrained(script_args.output_dir)
# Save the tokenizer:
tokenizer.save_pretrained(script_args.output_dir)
|
||||
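The script above saves two artefacts: the PiSSA-initialised adapter and the residual base model. The idea behind that split can be sketched in pure Python for the rank-1 case. This is an illustration only: it uses power iteration as a stand-in for a real SVD, it is not how PEFT computes the factors, and the names `top_singular_triplet` and `pissa_rank1_split` are hypothetical. The factor shapes follow the LoRA convention assumed here, `B` of shape (out, r) and `A` of shape (r, in), with the singular value split evenly between them.

```python
import math


def top_singular_triplet(W, iters=200):
    """Power iteration on W^T W; returns (sigma, u, v) for the largest singular value."""
    rows, cols = len(W), len(W[0])
    v = [1.0 / math.sqrt(cols)] * cols
    for _ in range(iters):
        u = [sum(W[i][j] * v[j] for j in range(cols)) for i in range(rows)]
        w = [sum(W[i][j] * u[i] for i in range(rows)) for j in range(cols)]
        norm = math.sqrt(sum(x * x for x in w))
        v = [x / norm for x in w]
    u = [sum(W[i][j] * v[j] for j in range(cols)) for i in range(rows)]
    sigma = math.sqrt(sum(x * x for x in u))
    u = [x / sigma for x in u]
    return sigma, u, v


def pissa_rank1_split(W):
    """Split W into a rank-1 principal part B @ A and a residual W_res = W - B @ A."""
    sigma, u, v = top_singular_triplet(W)
    s = math.sqrt(sigma)  # sqrt(sigma) goes into each factor
    B = [[s * ui] for ui in u]  # shape (out, 1)
    A = [[s * vj for vj in v]]  # shape (1, in)
    W_res = [
        [W[i][j] - B[i][0] * A[0][j] for j in range(len(v))]
        for i in range(len(u))
    ]
    return sigma, A, B, W_res


def frobenius(M):
    return math.sqrt(sum(x * x for row in M for x in row))


W = [[3.0, 1.0, 0.0], [1.0, 2.0, 1.0]]  # toy weight, out=2, in=3
sigma, A, B, W_res = pissa_rank1_split(W)
recon = [[W_res[i][j] + B[i][0] * A[0][j] for j in range(3)] for i in range(2)]
```

The residual plus the adapter product reproduces `W`, and the residual carries strictly less Frobenius norm than `W`. The reconstruction only holds if the residual is stored at sufficient precision, which is relevant to the bf16 init error noted in the earlier review.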
@@ -0,0 +1,274 @@
|
||||
# Copyright 2025-present the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from peft.tuners._buffer_dict import BufferDict
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
|
||||
|
||||
from .config import DeloraConfig
|
||||
|
||||
|
||||
class DeloraLayer(BaseTunerLayer):
|
||||
# All names of layers that may contain (trainable) adapter weights
|
||||
adapter_layer_names = (
|
||||
"delora_A",
|
||||
"delora_B",
|
||||
"delora_lambda",
|
||||
)
|
||||
# All names of other parameters that may contain adapter-related parameters
|
||||
other_param_names = (
|
||||
"r",
|
||||
"delora_dropout",
|
||||
"delora_w_norm",
|
||||
)
|
||||
|
||||
def __init__(self, base_layer: nn.Module, **kwargs) -> None:
|
||||
self.base_layer = base_layer
|
||||
self.r = {}
|
||||
self.delora_dropout = nn.ModuleDict({})
|
||||
self.delora_A = nn.ParameterDict({})
|
||||
self.delora_B = nn.ParameterDict({})
|
||||
self.delora_lambda = nn.ParameterDict({})
|
||||
# Use persistent buffers so they are included in state_dict and saved.
|
||||
self.delora_w_norm = BufferDict({}, persistent=True)
|
||||
# Mark the weight as unmerged
|
||||
self._disable_adapters = False
|
||||
self.merged_adapters = []
|
||||
self.kwargs = kwargs
|
||||
|
||||
base_layer_mod = self.get_base_layer()
|
||||
if isinstance(base_layer_mod, nn.Linear):
|
||||
self.in_features, self.out_features = base_layer_mod.in_features, base_layer_mod.out_features
|
||||
else:
|
||||
raise ValueError(f"Unsupported layer type {type(base_layer_mod)}")
|
||||
|
||||
    @staticmethod
    def _compute_delta(
        A: torch.Tensor, B: torch.Tensor, delora_lambda: torch.Tensor, r: int, w_norm: torch.Tensor
    ) -> torch.Tensor:
        """Compute delta = B @ diag(delora_lambda/r / (||A_i||*||B^j||)) @ A, scaled by provided w_norm (per-input channel)"""
        An = torch.clamp(A.norm(dim=1), min=1e-4)
        Bn = torch.clamp(B.norm(dim=0), min=1e-4)
        diag = torch.diag_embed(delora_lambda / r / (An * Bn))
        delta = B @ diag @ A
        delta = delta * w_norm.unsqueeze(0)
        return delta
|
||||
def get_delta_weight(self, adapter: str) -> torch.Tensor:
|
||||
if adapter not in self.delora_A or adapter not in self.delora_B:
|
||||
raise ValueError(f"Adapter {adapter} not found.")
|
||||
|
||||
delta = self._compute_delta(
|
||||
self.delora_A[adapter],
|
||||
self.delora_B[adapter],
|
||||
self.delora_lambda[adapter],
|
||||
self.r[adapter],
|
||||
self.delora_w_norm[adapter],
|
||||
)
|
||||
return delta
|
||||
|
||||
def update_layer(
|
||||
self,
|
||||
adapter_name: str,
|
||||
r: int,
|
||||
delora_lambda: float,
|
||||
config: DeloraConfig,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Internal function to create delora adapter
|
||||
|
||||
Args:
|
||||
adapter_name (`str`): Name for the adapter to add.
|
||||
r (`int`): Rank for the added adapter.
|
||||
delora_lambda (`float`): Boundary for the adapter's norm.
|
||||
config (`DeloraConfig`): The adapter configuration for this layer.
|
||||
"""
|
||||
module_dropout = config.module_dropout
|
||||
init_weights = config.init_weights
|
||||
inference_mode = config.inference_mode
|
||||
|
||||
if r <= 0:
|
||||
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
|
||||
|
||||
self.r[adapter_name] = r
|
||||
self.delora_A[adapter_name] = nn.Parameter(torch.empty(r, self.in_features))
|
||||
self.delora_B[adapter_name] = nn.Parameter(torch.empty(self.out_features, r))
|
||||
self.delora_lambda[adapter_name] = nn.Parameter(torch.empty(1))
|
||||
if module_dropout > 0.0:
|
||||
module_dropout_layer = nn.Dropout(p=module_dropout)
|
||||
else:
|
||||
module_dropout_layer = nn.Identity()
|
||||
self.delora_dropout.update(nn.ModuleDict({adapter_name: module_dropout_layer}))
|
||||
|
||||
# Initialize weights
|
||||
self.reset_delora_parameters(adapter_name, init_weights, delora_lambda)
|
||||
|
||||
# Move new weights to device
|
||||
self._move_adapter_to_device_of_base_layer(adapter_name)
|
||||
self.set_adapter(self.active_adapters, inference_mode=inference_mode)
|
||||
|
||||
    def reset_delora_parameters(
        self,
        adapter_name: str,
        init_weights: bool = True,
        delora_lambda: float = 15.0,
    ) -> None:
        if adapter_name not in self.delora_A.keys():
            return

        if init_weights is True:
            nn.init.kaiming_uniform_(self.delora_A[adapter_name], a=math.sqrt(5))
            nn.init.zeros_(self.delora_B[adapter_name])
        else:
            nn.init.kaiming_uniform_(self.delora_A[adapter_name], a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.delora_B[adapter_name], a=math.sqrt(5))

        self.delora_lambda[adapter_name].data.fill_(float(delora_lambda))

        # capture a fixed norm for this adapter to use for future delta computations
        with torch.no_grad():
            w = self.get_base_layer().weight
            if w.device.type != "meta":
                w_norm = torch.norm(w.data, dim=0).detach()
            else:
                # For meta tensors, we can't compute the norm, so use a default value
                w_norm = torch.ones(w.shape[1], device=w.device)
            self.delora_w_norm[adapter_name] = w_norm
|
||||
|
||||
class DeloraLinear(nn.Module, DeloraLayer):
|
||||
# DeLoRA implemented in a dense layer
|
||||
def __init__(
|
||||
self,
|
||||
base_layer,
|
||||
adapter_name: str,
|
||||
config: DeloraConfig,
|
||||
r: int,
|
||||
delora_lambda: float,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
DeloraLayer.__init__(self, base_layer, **kwargs)
|
||||
self._active_adapter = adapter_name
|
||||
self.update_layer(adapter_name, r, delora_lambda, config=config)
|
||||
|
||||
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
|
||||
"""
|
||||
Merge the active adapter weights into the base weights
|
||||
|
||||
Args:
|
||||
safe_merge (`bool`, *optional*):
|
||||
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
|
||||
before merging the weights. This is useful if you want to check if the merge operation will produce
|
||||
NaNs. Defaults to `False`.
|
||||
adapter_names (`list[str]`, *optional*):
|
||||
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
|
||||
to `None`.
|
||||
"""
|
||||
adapter_names = check_adapters_to_merge(self, adapter_names)
|
||||
if not adapter_names:
|
||||
return
|
||||
|
||||
for active_adapter in adapter_names:
|
||||
if active_adapter in self.delora_A.keys():
|
||||
base_layer = self.get_base_layer()
|
||||
delta_weight = (
|
||||
self.get_delta_weight(active_adapter)
|
||||
.detach()
|
||||
.to(dtype=base_layer.weight.dtype, device=base_layer.weight.device)
|
||||
)
|
||||
with torch.no_grad():
|
||||
if safe_merge:
|
||||
orig_weights = base_layer.weight.data.clone()
|
||||
orig_weights = orig_weights + delta_weight
|
||||
|
||||
if not torch.isfinite(orig_weights).all():
|
||||
raise ValueError(
|
||||
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
|
||||
)
|
||||
|
||||
base_layer.weight.data = orig_weights
|
||||
else:
|
||||
base_layer.weight.data.add_(delta_weight)
|
||||
|
||||
self.merged_adapters.append(active_adapter)
|
||||
|
||||
def unmerge(self) -> None:
|
||||
"""
|
||||
Unmerge all merged adapter layers from the base weights.
|
||||
"""
|
||||
if not self.merged:
|
||||
warnings.warn("Already unmerged. Nothing to do.")
|
||||
return
|
||||
while len(self.merged_adapters) > 0:
|
||||
active_adapter = self.merged_adapters.pop()
|
||||
if active_adapter in self.delora_A.keys():
|
||||
self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter)
|
||||
|
||||
    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        previous_dtype = x.dtype

        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            result = self.base_layer(x, *args, **kwargs)
        elif self.merged:
            result = self.base_layer(x, *args, **kwargs)
        else:
            if not self.active_adapters:
                return self.base_layer(x, *args, **kwargs).to(previous_dtype)

            base_out = self.base_layer(x, *args, **kwargs)
            add_out = torch.zeros_like(base_out)

            for adapter in self.active_adapters:
                if adapter not in self.delora_A:
                    continue

                x_d = self.delora_dropout[adapter](x)

                # Decomposed delta calculation
                # 1. (x * w_norm) @ A.T
                h = nn.functional.linear(x_d * self.delora_w_norm[adapter], self.delora_A[adapter])

                # 2. h @ diag
                An = torch.clamp(self.delora_A[adapter].norm(dim=1), min=1e-4)
                Bn = torch.clamp(self.delora_B[adapter].norm(dim=0), min=1e-4)
                scaling = (self.delora_lambda[adapter] / self.r[adapter]) / (An * Bn)

                h = h * scaling

                # 3. h @ B.T
                h = nn.functional.linear(h, self.delora_B[adapter])

                add_out += h

            result = base_out + add_out.to(base_out.dtype)

        result = result.to(previous_dtype)
        return result
|
||||
def supports_lora_conversion(self, adapter_name: str = "default") -> bool:
|
||||
return True
|
||||
|
||||
def __repr__(self) -> str:
|
||||
rep = super().__repr__()
|
||||
return "delora." + rep
|
||||
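The DeLoRA layer above computes the same update two ways: `_compute_delta` builds the dense matrix `B @ diag(lambda/r / (||A_i|| * ||B^j||)) @ A` scaled per input channel by `w_norm`, while `forward` applies the three factors to the activations in sequence. A minimal pure-Python sketch, assuming a single unbatched input vector and hypothetical helper names (`delora_scaling`, `delora_delta`, `delora_forward`), can confirm that the two paths agree and that the unscaled update has Frobenius norm at most `lambda`.

```python
import math


def vec_norm(v):
    return math.sqrt(sum(x * x for x in v))


def delora_scaling(A, B, lam, eps=1e-4):
    """Per-rank factor lam / r / (||A_k|| * ||B^k||), with norms clamped from below at eps."""
    r = len(A)
    return [
        lam / r / (max(vec_norm(A[k]), eps) * max(vec_norm([row[k] for row in B]), eps))
        for k in range(r)
    ]


def delora_delta(A, B, lam, w_norm):
    """Dense delta of shape (out, in), scaled per input channel by w_norm."""
    scale = delora_scaling(A, B, lam)
    return [
        [sum(B[o][k] * scale[k] * A[k][i] for k in range(len(A))) * w_norm[i]
         for i in range(len(A[0]))]
        for o in range(len(B))
    ]


def delora_forward(x, A, B, lam, w_norm):
    """Decomposed path: ((x * w_norm) @ A.T) * scaling, then @ B.T."""
    scale = delora_scaling(A, B, lam)
    xs = [xi * wi for xi, wi in zip(x, w_norm)]
    h = [sum(A[k][i] * xs[i] for i in range(len(xs))) * scale[k] for k in range(len(A))]
    return [sum(B[o][k] * h[k] for k in range(len(A))) for o in range(len(B))]


# Toy values: r = 2, in_features = 3, out_features = 2.
A = [[1.0, 2.0, -1.0], [0.5, 0.0, 3.0]]
B = [[0.2, -1.0], [1.5, 0.4]]
lam = 15.0
w_norm = [2.0, 0.5, 1.0]
x = [1.0, -2.0, 0.5]

delta = delora_delta(A, B, lam, w_norm)
y_dense = [sum(delta[o][i] * x[i] for i in range(3)) for o in range(2)]
y_fwd = delora_forward(x, A, B, lam, w_norm)

# With w_norm = 1 the update is a mean of r unit-norm rank-1 terms times lam.
unit_delta = delora_delta(A, B, lam, [1.0, 1.0, 1.0])
unit_fro = math.sqrt(sum(v * v for row in unit_delta for v in row))
```

In this sketch `w_norm` has one entry per input channel, matching `torch.norm(w.data, dim=0)` in the reference. A scalar norm would rescale every channel equally and would not reproduce the same `delta`.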
@@ -0,0 +1,462 @@
|
||||
# Copyright 2024-present the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
|
||||
|
||||
from .config import HRAConfig
|
||||
|
||||
|
||||
class HRALayer(BaseTunerLayer):
|
||||
# All names of layers that may contain (trainable) adapter weights
|
||||
adapter_layer_names = ("hra_u",)
|
||||
# All names of other parameters that may contain adapter-related parameters
|
||||
other_param_names = ("hra_r", "hra_apply_GS")
|
||||
|
||||
def __init__(self, base_layer: nn.Module, **kwargs) -> None:
|
||||
self.base_layer = base_layer
|
||||
self.hra_r = {}
|
||||
self.hra_apply_GS = {}
|
||||
self.hra_u = nn.ParameterDict({})
|
||||
# Mark the weight as unmerged
|
||||
self._disable_adapters = False
|
||||
self.merged_adapters = []
|
||||
# flag to enable/disable casting of input to weight dtype during forward call
|
||||
self.cast_input_dtype_enabled = True
|
||||
self.kwargs = kwargs
|
||||
|
||||
base_layer = self.get_base_layer()
|
||||
if isinstance(base_layer, nn.Linear):
|
||||
self.in_features, self.out_features = base_layer.in_features, base_layer.out_features
|
||||
elif isinstance(base_layer, nn.Conv2d):
|
||||
self.in_features, self.out_features = base_layer.in_channels, base_layer.out_channels
|
||||
else:
|
||||
raise ValueError(f"Unsupported layer type {type(base_layer)}")
|
||||
|
||||
def update_layer(
|
||||
self,
|
||||
adapter_name: str,
|
||||
r: int,
|
||||
config: HRAConfig,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Internal function to create hra adapter
|
||||
|
||||
Args:
|
||||
adapter_name (`str`): Name for the adapter to add.
|
||||
r (`int`): Rank for the added adapter.
|
||||
config (`HRAConfig`): The adapter configuration for this layer.
|
||||
"""
|
||||
apply_GS = config.apply_GS
|
||||
init_weights = config.init_weights
|
||||
inference_mode = config.inference_mode
|
||||
|
||||
if r <= 0:
|
||||
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
|
||||
|
||||
self.hra_r[adapter_name] = r
|
||||
self.hra_apply_GS[adapter_name] = apply_GS
|
||||
|
||||
# Determine shape of HRA weights
|
||||
base_layer = self.get_base_layer()
|
||||
if isinstance(base_layer, nn.Linear):
|
||||
self.hra_u[adapter_name] = nn.Parameter(torch.empty(self.in_features, r), requires_grad=True)
|
||||
elif isinstance(base_layer, nn.Conv2d):
|
||||
self.hra_u[adapter_name] = nn.Parameter(
|
||||
torch.empty(self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0], r),
|
||||
requires_grad=True,
|
||||
)
|
||||
else:
|
||||
raise TypeError(f"HRA is not implemented for base layers of type {type(base_layer).__name__}")
|
||||
|
||||
# Initialize weights
|
||||
if init_weights:
|
||||
self.reset_hra_parameters(adapter_name)
|
||||
else:
|
||||
self.reset_hra_parameters_random(adapter_name)
|
||||
|
||||
# Move new weights to device
|
||||
self._move_adapter_to_device_of_base_layer(adapter_name)
|
||||
self.set_adapter(self.active_adapters, inference_mode=inference_mode)
|
||||
|
||||
    def reset_hra_parameters(self, adapter_name: str):
        if self.hra_r[adapter_name] % 2 != 0:
            warnings.warn("The symmetric initialization can NOT be performed when r is odd!")
            nn.init.kaiming_uniform_(self.hra_u[adapter_name], a=math.sqrt(5))
        else:
            shape = self.hra_u[adapter_name].shape
            half_u = torch.zeros(shape[0], shape[1] // 2)
            nn.init.kaiming_uniform_(half_u, a=math.sqrt(5))
            self.hra_u[adapter_name] = nn.Parameter(torch.repeat_interleave(half_u, 2, dim=1))
|
||||
def reset_hra_parameters_random(self, adapter_name: str):
|
||||
nn.init.kaiming_uniform_(self.hra_u[adapter_name], a=math.sqrt(5))
|
||||
|
||||
def scale_layer(self, scale: float) -> None:
|
||||
if scale == 1:
|
||||
return
|
||||
|
||||
for active_adapter in self.active_adapters:
|
||||
if active_adapter not in self.hra_u.keys():
|
||||
continue
|
||||
|
||||
warnings.warn("Scaling operation for HRA not supported! Automatically set scale to 1.")
|
||||
|
||||
def unscale_layer(self, scale=None) -> None:
|
||||
for active_adapter in self.active_adapters:
|
||||
if active_adapter not in self.hra_u.keys():
|
||||
continue
|
||||
|
||||
warnings.warn("Unscaling operation for HRA not supported! Keeping scale at 1.")
|
||||
|
||||
|
||||
class HRALinear(nn.Module, HRALayer):
|
||||
"""
|
||||
HRA implemented in a dense layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_layer,
|
||||
adapter_name: str,
|
||||
config: HRAConfig,
|
||||
r: int = 0,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
HRALayer.__init__(self, base_layer, **kwargs)
|
||||
self._active_adapter = adapter_name
|
||||
self.update_layer(adapter_name, r, config=config, **kwargs)
|
||||
|
||||
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
|
||||
"""
|
||||
Merge the active adapter weights into the base weights
|
||||
|
||||
Args:
|
||||
safe_merge (`bool`, *optional*):
|
||||
If `True`, the merge operation will be performed in a copy of the original weights and check for NaNs
|
||||
before merging the weights. This is useful if you want to check if the merge operation will produce
|
||||
NaNs. Defaults to `False`.
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
The list of adapter names that should be merged. If `None`, all active adapters will be merged.
|
||||
Defaults to `None`.
|
||||
"""
|
||||
adapter_names = check_adapters_to_merge(self, adapter_names)
|
||||
if not adapter_names:
|
||||
# no adapter to merge
|
||||
return
|
||||
|
||||
for active_adapter in adapter_names:
|
||||
if active_adapter in self.hra_u.keys():
|
||||
base_layer = self.get_base_layer()
|
||||
orig_dtype = base_layer.weight.dtype
|
||||
if safe_merge:
|
||||
# Note that safe_merge will be slower than the normal merge
|
||||
# because of the copy operation.
|
||||
orig_weight = base_layer.weight.data.clone()
|
||||
delta_weight = self.get_delta_weight(active_adapter)
|
||||
orig_weight = torch.mm(orig_weight.to(delta_weight.dtype), delta_weight)
|
||||
|
||||
if not torch.isfinite(orig_weight).all():
|
||||
raise ValueError(
|
||||
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
|
||||
)
|
||||
|
||||
base_layer.weight.data = orig_weight.to(orig_dtype)
|
||||
else:
|
||||
delta_weight = self.get_delta_weight(active_adapter)
|
||||
new_weight = torch.mm(base_layer.weight.data.to(delta_weight.dtype), delta_weight)
|
||||
base_layer.weight.data = new_weight.to(orig_dtype)
|
||||
self.merged_adapters.append(active_adapter)
|
||||
|
||||
def unmerge(self) -> None:
|
||||
"""
|
||||
This method unmerges all merged adapter layers from the base weights.
|
||||
"""
|
||||
if not self.merged:
|
||||
warnings.warn("Already unmerged. Nothing to do.")
|
||||
return
|
||||
|
||||
while len(self.merged_adapters) > 0:
|
||||
active_adapter = self.merged_adapters.pop()
|
||||
base_layer = self.get_base_layer()
|
||||
orig_dtype = base_layer.weight.dtype
|
||||
if active_adapter in self.hra_u.keys():
|
||||
orig_weight = base_layer.weight.data.clone()
|
||||
delta_weight = self.get_delta_weight(active_adapter, reverse=True)
|
||||
new_weight = torch.mm(orig_weight.to(delta_weight.dtype), delta_weight)
|
||||
base_layer.weight.data = new_weight.to(orig_dtype)
|
||||
|
||||
    def get_delta_weight(self, adapter_name: str, reverse: bool = False) -> torch.Tensor:
        rank = self.hra_r[adapter_name]
        apply_GS = self.hra_apply_GS[adapter_name]
        opt_u = self.hra_u[adapter_name]
        shape = opt_u.shape

        if apply_GS:
            weight = [(opt_u[:, 0] / opt_u[:, 0].norm()).view(-1, 1)]
            for i in range(1, rank):
                ui = opt_u[:, i].view(-1, 1)
                for j in range(i):
                    ui = ui - (weight[j].t() @ ui) * weight[j]
                weight.append((ui / ui.norm()).view(-1, 1))
            weight = torch.cat(weight, dim=1)
            weight = torch.eye(shape[0], device=opt_u.device, dtype=opt_u.dtype) - 2 * weight @ weight.t()

        else:
            opt_u = opt_u / opt_u.norm(dim=0)
            weight = torch.eye(shape[0], device=opt_u.device, dtype=opt_u.dtype)
            if reverse:
                indices = range(rank - 1, -1, -1)
            else:
                indices = range(rank)

            for i in indices:
                ui = opt_u[:, i].view(-1, 1)
                weight = weight - 2 * weight @ ui @ ui.t()

        return weight
|
||||
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
previous_dtype = x.dtype
|
||||
|
||||
if self.disable_adapters:
|
||||
if self.merged:
|
||||
self.unmerge()
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
elif self.merged:
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
else:
|
||||
new_weight = torch.eye(self.in_features, device=x.device)
|
||||
|
||||
for active_adapter in self.active_adapters:
|
||||
if active_adapter not in self.hra_u.keys():
|
||||
continue
|
||||
delta_weight = self.get_delta_weight(active_adapter)
|
||||
new_weight = torch.mm(new_weight.to(delta_weight.dtype), delta_weight)
|
||||
|
||||
orig_weight = self.get_base_layer().weight.data
|
||||
orig_weight = self._cast_input_dtype(orig_weight, new_weight.dtype)
|
||||
new_weight = torch.mm(orig_weight, new_weight)
|
||||
bias = self._cast_input_dtype(self.base_layer.bias, new_weight.dtype)
|
||||
|
||||
if self.cast_input_dtype_enabled:
|
||||
x = self._cast_input_dtype(x, new_weight.dtype)
|
||||
else:
|
||||
x = x.to(self.get_base_layer().weight.data.dtype)
|
||||
result = F.linear(input=x, weight=new_weight, bias=bias)
|
||||
|
||||
result = result.to(previous_dtype)
|
||||
return result
|
||||
|
||||
def __repr__(self) -> str:
|
||||
rep = super().__repr__()
|
||||
return "hra." + rep
|
||||
|
||||
|
||||
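The symmetric initialisation in `reset_hra_parameters` repeats each random column twice, and the non-Gram-Schmidt branch of `get_delta_weight` chains one Householder reflection per column. Since a reflection is its own inverse, each repeated pair should cancel, so the adapter starts as the identity while every column still carries a non-zero value. The pure-Python sketch below illustrates this under the assumption that `apply_GS` is false; `householder_chain` is a hypothetical helper, not a PEFT function.

```python
import math


def householder_chain(columns):
    """Apply weight <- weight - 2 * weight @ u @ u.T for each normalised column u."""
    n = len(columns[0])
    weight = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
    for col in columns:
        norm = math.sqrt(sum(c * c for c in col))
        u = [c / norm for c in col]
        wu = [sum(weight[i][k] * u[k] for k in range(n)) for i in range(n)]
        weight = [[weight[i][j] - 2.0 * wu[i] * u[j] for j in range(n)] for i in range(n)]
    return weight


# Hypothetical half of the init: two random-looking columns, in_features = 3.
half = [[0.3, -1.2, 0.7], [2.0, 0.1, -0.4]]

# Mirrors torch.repeat_interleave(half_u, 2, dim=1): columns become [u1, u1, u2, u2].
symmetric = [col for col in half for _ in range(2)]

W_symmetric = householder_chain(symmetric)  # expected: identity
W_plain = householder_chain(half)           # generally not the identity
```

This is the contrast with a zero-gate init: the transform is the identity at step 0 without any multiplicative gate that would zero out the gradient of the reflection vectors.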
class HRAConv2d(nn.Module, HRALayer):
|
||||
"""HRA implemented in Conv2d layer"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_layer,
|
||||
adapter_name: str,
|
||||
config: HRAConfig,
|
||||
r: int = 0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
HRALayer.__init__(self, base_layer)
|
||||
self._active_adapter = adapter_name
|
||||
self.update_layer(adapter_name, r, config=config, **kwargs)
|
||||
|
||||
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
|
||||
"""
|
||||
Merge the active adapter weights into the base weights
|
||||
|
||||
Args:
|
||||
safe_merge (`bool`, *optional*):
|
||||
If `True`, the merge operation will be performed in a copy of the original weights and check for NaNs
|
||||
before merging the weights. This is useful if you want to check if the merge operation will produce
|
||||
NaNs. Defaults to `False`.
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
The list of adapter names that should be merged. If `None`, all active adapters will be merged.
|
||||
Defaults to `None`.
|
||||
"""
|
||||
adapter_names = check_adapters_to_merge(self, adapter_names)
|
||||
if not adapter_names:
|
||||
# no adapter to merge
|
||||
return
|
||||
|
||||
for active_adapter in adapter_names:
|
||||
if active_adapter in self.hra_u.keys():
|
||||
base_layer = self.get_base_layer()
|
||||
orig_dtype = base_layer.weight.dtype
|
||||
if safe_merge:
|
||||
# Note that safe_merge will be slower than the normal merge
|
||||
# because of the copy operation.
|
||||
orig_weight = base_layer.weight.data.clone()
|
||||
orig_weight = orig_weight.view(
|
||||
self.out_features,
|
||||
self.in_features * base_layer.kernel_size[0] * self.base_layer.kernel_size[0],
|
||||
)
|
||||
delta_weight = self.get_delta_weight(active_adapter)
|
||||
orig_weight = torch.mm(orig_weight.to(delta_weight.dtype), delta_weight)
|
||||
orig_weight = orig_weight.view(
|
||||
self.out_features,
|
||||
self.in_features,
|
||||
base_layer.kernel_size[0],
|
||||
base_layer.kernel_size[0],
|
||||
)
|
||||
|
||||
if not torch.isfinite(orig_weight).all():
|
||||
raise ValueError(
|
||||
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
|
||||
)
|
||||
|
||||
base_layer.weight.data = orig_weight.to(orig_dtype)
|
||||
else:
|
||||
orig_weight = base_layer.weight.data
|
||||
orig_weight = orig_weight.view(
|
||||
self.out_features,
|
||||
self.in_features * self.base_layer.kernel_size[0] * self.base_layer.kernel_size[0],
|
||||
)
|
||||
delta_weight = self.get_delta_weight(active_adapter)
|
||||
orig_weight = torch.mm(orig_weight.to(delta_weight.dtype), delta_weight)
|
||||
orig_weight = orig_weight.view(
|
||||
self.out_features,
|
||||
self.in_features,
|
||||
base_layer.kernel_size[0],
|
||||
base_layer.kernel_size[0],
|
||||
)
|
||||
|
||||
base_layer.weight.data = orig_weight.to(orig_dtype)
|
||||
self.merged_adapters.append(active_adapter)
|
||||
|
||||
def unmerge(self) -> None:
|
||||
"""
|
||||
This method unmerges all merged adapter layers from the base weights.
|
||||
"""
|
||||
if not self.merged:
|
||||
warnings.warn("Already unmerged. Nothing to do.")
|
||||
return
|
||||
while len(self.merged_adapters) > 0:
|
||||
active_adapter = self.merged_adapters.pop()
|
||||
base_layer = self.get_base_layer()
|
||||
orig_dtype = base_layer.weight.dtype
|
||||
if active_adapter in self.hra_u.keys():
|
||||
orig_weight = base_layer.weight.data.clone()
|
||||
orig_weight = orig_weight.view(
|
||||
self.out_features,
|
||||
self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0],
|
||||
)
|
||||
delta_weight = self.get_delta_weight(active_adapter, reverse=True)
|
||||
orig_weight = torch.mm(orig_weight.to(delta_weight.dtype), delta_weight)
|
||||
orig_weight = orig_weight.view(
|
||||
self.out_features, self.in_features, base_layer.kernel_size[0], base_layer.kernel_size[0]
|
||||
)
|
||||
|
||||
base_layer.weight.data = orig_weight.to(orig_dtype)
|
||||
|
||||
def get_delta_weight(self, adapter_name: str, reverse: bool = False) -> torch.Tensor:
|
||||
rank = self.hra_r[adapter_name]
|
||||
apply_GS = self.hra_apply_GS[adapter_name]
|
||||
opt_u = self.hra_u[adapter_name]
|
||||
shape = opt_u.shape
|
||||
|
||||
if apply_GS:
|
||||
weight = [(opt_u[:, 0] / opt_u[:, 0].norm()).view(-1, 1)]
|
||||
for i in range(1, rank):
|
||||
ui = opt_u[:, i].view(-1, 1)
|
||||
for j in range(i):
|
||||
ui = ui - (weight[j].t() @ ui) * weight[j]
|
||||
weight.append((ui / ui.norm()).view(-1, 1))
|
||||
weight = torch.cat(weight, dim=1)
|
||||
weight = torch.eye(shape[0], device=opt_u.device, dtype=opt_u.dtype) - 2 * weight @ weight.t()
|
||||
|
||||
else:
|
||||
opt_u = opt_u / opt_u.norm(dim=0)
|
||||
weight = torch.eye(shape[0], device=opt_u.device, dtype=opt_u.dtype)
|
||||
if reverse:
|
||||
indices = range(rank - 1, -1, -1)
|
||||
else:
|
||||
indices = range(rank)
|
||||
|
||||
for i in indices:
|
||||
ui = opt_u[:, i].view(-1, 1)
|
||||
weight = weight - 2 * weight @ ui @ ui.t()
|
||||
|
||||
return weight
|
||||
|
||||
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
previous_dtype = x.dtype
|
||||
|
||||
if self.disable_adapters:
|
||||
if self.merged:
|
||||
self.unmerge()
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
elif self.merged:
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
else:
|
||||
new_weight = torch.eye(
|
||||
self.in_features * self.base_layer.kernel_size[0] * self.base_layer.kernel_size[0],
|
||||
device=x.device,
|
||||
)
|
||||
for active_adapter in self.active_adapters:
|
||||
if active_adapter not in self.hra_u.keys():
|
||||
continue
|
||||
delta_weight = self.get_delta_weight(active_adapter)
|
||||
new_weight = torch.mm(new_weight.to(delta_weight.dtype), delta_weight)
|
||||
|
||||
orig_weight = self.base_layer.weight.data
|
||||
orig_weight = orig_weight.view(
|
||||
self.out_features,
|
||||
self.in_features * self.base_layer.kernel_size[0] * self.base_layer.kernel_size[0],
|
||||
)
|
||||
orig_weight = self._cast_input_dtype(orig_weight, new_weight.dtype)
|
||||
bias = self._cast_input_dtype(self.base_layer.bias, new_weight.dtype)
|
||||
|
||||
new_weight = torch.mm(orig_weight, new_weight)
|
||||
new_weight = new_weight.view(
|
||||
self.out_features,
|
||||
self.in_features,
|
||||
self.base_layer.kernel_size[0],
|
||||
self.base_layer.kernel_size[0],
|
||||
)
|
||||
|
||||
if self.cast_input_dtype_enabled:
|
||||
x = self._cast_input_dtype(x, new_weight.dtype)
|
||||
else:
|
||||
x = x.to(self.get_base_layer().weight.data.dtype)
|
||||
result = F.conv2d(
|
||||
input=x,
|
||||
weight=new_weight,
|
||||
bias=bias,
|
||||
padding=self.base_layer.padding[0],
|
||||
stride=self.base_layer.stride[0],
|
||||
)
|
||||
|
||||
result = result.to(previous_dtype)
|
||||
return result
|
||||
|
||||
def __repr__(self) -> str:
|
||||
rep = super().__repr__()
|
||||
return "hra." + rep
|
||||
@@ -0,0 +1,336 @@
|
||||
# Copyright 2023-present the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import warnings
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers.pytorch_utils import Conv1D
|
||||
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
|
||||
from peft.utils import transpose
|
||||
|
||||
from .config import IA3Config
|
||||
|
||||
|
||||
class IA3Layer(BaseTunerLayer):
|
||||
# All names of layers that may contain adapter weights
|
||||
adapter_layer_names = ("ia3_l",)
|
||||
|
||||
def __init__(self, base_layer: nn.Module, is_feedforward: bool, **kwargs) -> None:
|
||||
self.base_layer = base_layer
|
||||
self.ia3_l = nn.ParameterDict({})
|
||||
# Mark the weight as unmerged
|
||||
self._disable_adapters = False
|
||||
self.merged_adapters = []
|
||||
self.is_feedforward = is_feedforward
|
||||
|
||||
base_layer = self.get_base_layer()
|
||||
if isinstance(base_layer, nn.Linear):
|
||||
in_features, out_features = base_layer.in_features, base_layer.out_features
|
||||
elif isinstance(base_layer, (nn.Conv2d, nn.Conv3d)):
|
||||
in_features, out_features = base_layer.in_channels, base_layer.out_channels
|
||||
elif isinstance(base_layer, nn.Embedding):
|
||||
in_features, out_features = base_layer.num_embeddings, base_layer.embedding_dim
|
||||
elif isinstance(base_layer, Conv1D):
|
||||
in_features, out_features = (
|
||||
base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported layer type {type(base_layer)}")
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
|
||||
def update_layer(self, adapter_name: str, config: IA3Config, **kwargs):
|
||||
init_ia3_weights = config.init_ia3_weights
|
||||
inference_mode = config.inference_mode
|
||||
|
||||
# This code works for linear layers, override for other layer types
|
||||
# Actual trainable parameters
|
||||
if self.is_feedforward:
|
||||
weight = torch.randn((1, self.in_features))
|
||||
else:
|
||||
weight = torch.randn((self.out_features, 1))
|
||||
self.ia3_l[adapter_name] = nn.Parameter(weight)
|
||||
if init_ia3_weights:
|
||||
self.reset_ia3_parameters(adapter_name)
|
||||
self._move_adapter_to_device_of_base_layer(adapter_name)
|
||||
self.set_adapter(self.active_adapters, inference_mode=inference_mode)
|
||||
|
||||
def reset_ia3_parameters(self, adapter_name):
|
||||
if adapter_name in self.ia3_l.keys():
|
||||
# initialize learned vector with torch.ones
|
||||
nn.init.constant_(self.ia3_l[adapter_name], 1.0)
|
||||
|
||||
|
||||
class Linear(nn.Module, IA3Layer):
|
||||
# (IA)^3 implemented in a dense layer
|
||||
def __init__(
|
||||
self,
|
||||
base_layer: nn.Module,
|
||||
adapter_name: str,
|
||||
config: IA3Config,
|
||||
is_feedforward: bool = False, # Set to True if the layer is treated as a feedforward layer
|
||||
is_target_conv_1d_layer: bool = False, # whether target module is a conv1d layer. useful while unloading later
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
IA3Layer.__init__(self, base_layer, is_feedforward=is_feedforward)
|
||||
self.fan_in_fan_out = config.fan_in_fan_out
|
||||
self.is_target_conv_1d_layer = is_target_conv_1d_layer
|
||||
self._active_adapter = adapter_name
|
||||
self.update_layer(adapter_name, config=config)
|
||||
|
||||
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
|
||||
"""
|
||||
Merge the active adapter weights into the base weights
|
||||
|
||||
Args:
|
||||
safe_merge (`bool`, *optional*):
|
||||
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
|
||||
before merging the weights. This is useful if you want to check if the merge operation will produce
|
||||
NaNs. Defaults to `False`.
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
|
||||
to `None`.
|
||||
"""
|
||||
adapter_names = check_adapters_to_merge(self, adapter_names)
|
||||
if not adapter_names:
|
||||
# no adapter to merge
|
||||
return
|
||||
|
||||
for active_adapter in adapter_names:
|
||||
if active_adapter in self.ia3_l.keys():
|
||||
base_layer = self.get_base_layer()
|
||||
ia3_l = transpose(self.ia3_l[active_adapter].data, self.fan_in_fan_out)
|
||||
orig_dtype = base_layer.weight.data.dtype
|
||||
if safe_merge:
|
||||
orig_weights = base_layer.weight.data
|
||||
orig_weights = torch.mul(orig_weights, ia3_l)
|
||||
|
||||
if not torch.isfinite(orig_weights).all():
|
||||
raise ValueError(
|
||||
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
|
||||
)
|
||||
base_layer.weight.data = orig_weights.to(orig_dtype)
|
||||
else:
|
||||
base_layer.weight.data = torch.mul(base_layer.weight.data, ia3_l).to(orig_dtype)
|
||||
|
||||
if not self.is_feedforward and (base_layer.bias is not None):
|
||||
scaling = self.ia3_l[active_adapter].reshape(base_layer.bias.shape)
|
||||
orig_dtype = base_layer.bias.data.dtype
|
||||
base_layer.bias.data = torch.mul(base_layer.bias.data, scaling.data).to(orig_dtype)
|
||||
|
||||
self.merged_adapters.append(active_adapter)
|
||||
|
||||
def unmerge(self) -> None:
|
||||
"""
|
||||
This method unmerges all merged adapter layers from the base weights.
|
||||
"""
|
||||
if not self.merged:
|
||||
warnings.warn("Already unmerged. Nothing to do.")
|
||||
return
|
||||
|
||||
warnings.warn("Unmerge result can be inaccurate for (IA)^3.")
|
||||
while len(self.merged_adapters) > 0:
|
||||
active_adapter = self.merged_adapters.pop()
|
||||
if active_adapter in self.ia3_l.keys():
|
||||
base_layer = self.get_base_layer()
|
||||
# Add tolerance to avoid division by zero
|
||||
ia3_l = transpose(self.ia3_l[active_adapter].data, self.fan_in_fan_out) + 1e-8
|
||||
orig_dtype = base_layer.weight.data.dtype
|
||||
base_layer.weight.data = torch.div(base_layer.weight.data, ia3_l).to(orig_dtype)
|
||||
|
||||
if not self.is_feedforward and (base_layer.bias is not None):
|
||||
scaling = self.ia3_l[active_adapter].reshape(base_layer.bias.shape)
|
||||
orig_dtype = base_layer.bias.data.dtype
|
||||
base_layer.bias.data = torch.div(base_layer.bias.data, scaling.data + 1e-8).to(orig_dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
dtype = previous_dtype = x.dtype
|
||||
if self.disable_adapters:
|
||||
if self.merged:
|
||||
self.unmerge()
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
elif self.merged:
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
else:
|
||||
ia3_scaling = 1
|
||||
for active_adapter in self.active_adapters:
|
||||
if active_adapter not in self.ia3_l.keys():
|
||||
continue
|
||||
dtype = self.ia3_l[active_adapter].dtype
|
||||
ia3_scaling *= self.ia3_l[active_adapter].flatten()
|
||||
|
||||
if self.is_feedforward:
|
||||
x = x.to(dtype)
|
||||
# TODO: weight.dtype can be != self.ia3_l[self.active_adapters].dtype
|
||||
# e.g. bf16 vs fp32. Is that okay?
|
||||
interm = (x * ia3_scaling).to(previous_dtype)
|
||||
result = self.base_layer(interm, *args, **kwargs)
|
||||
else:
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
result_dtype = result.dtype
|
||||
result = (result * ia3_scaling).to(result_dtype)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class _ConvNd(nn.Module, IA3Layer):
|
||||
def __init__(
|
||||
self,
|
||||
base_layer: nn.Module,
|
||||
adapter_name: str,
|
||||
config: IA3Config,
|
||||
is_feedforward: bool = False, # Set to True if the layer is treated as a feedforward layer
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
IA3Layer.__init__(self, base_layer, is_feedforward=is_feedforward)
|
||||
self.fan_in_fan_out = config.fan_in_fan_out
|
||||
self._active_adapter = adapter_name
|
||||
self._kernel_dim = base_layer.weight.dim()
|
||||
|
||||
self.update_layer(adapter_name, config=config)
|
||||
|
||||
def update_layer(self, adapter_name: str, config: IA3Config, **kwargs):
|
||||
init_ia3_weights = config.init_ia3_weights
|
||||
inference_mode = config.inference_mode
|
||||
|
||||
# Actual trainable parameters
|
||||
num_features = self.in_features if self.is_feedforward else self.out_features
|
||||
weights_size = (1, num_features) + (1,) * (self._kernel_dim - 2)
|
||||
weight = torch.randn(weights_size)
|
||||
self.ia3_l[adapter_name] = nn.Parameter(weight)
|
||||
if init_ia3_weights:
|
||||
self.reset_ia3_parameters(adapter_name)
|
||||
self._move_adapter_to_device_of_base_layer(adapter_name)
|
||||
self.set_adapter(self.active_adapters, inference_mode=inference_mode)
|
||||
|
||||
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
|
||||
"""
|
||||
Merge the active adapter weights into the base weights
|
||||
|
||||
Args:
|
||||
safe_merge (`bool`, *optional*):
|
||||
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
|
||||
before merging the weights. This is useful if you want to check if the merge operation will produce
|
||||
NaNs. Defaults to `False`.
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
|
||||
to `None`.
|
||||
"""
|
||||
adapter_names = check_adapters_to_merge(self, adapter_names)
|
||||
if not adapter_names:
|
||||
# no adapter to merge
|
||||
return
|
||||
|
||||
for active_adapter in adapter_names:
|
||||
if active_adapter in self.ia3_l.keys():
|
||||
base_layer = self.get_base_layer()
|
||||
orig_dtype = base_layer.weight.data.dtype
|
||||
ia3_scaling = self.ia3_l[active_adapter].data
|
||||
if not self.is_feedforward:
|
||||
ia3_scaling = ia3_scaling.transpose(0, 1)
|
||||
|
||||
if safe_merge:
|
||||
output_weight = torch.mul(base_layer.weight.data, ia3_scaling).clone()
|
||||
|
||||
if not torch.isfinite(output_weight).all():
|
||||
raise ValueError(
|
||||
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
|
||||
)
|
||||
|
||||
base_layer.weight.data = output_weight.to(orig_dtype)
|
||||
else:
|
||||
base_layer.weight.data = torch.mul(base_layer.weight.data, ia3_scaling).to(orig_dtype)
|
||||
|
||||
if not self.is_feedforward and (base_layer.bias is not None):
|
||||
scaling = self.ia3_l[active_adapter].reshape(base_layer.bias.shape)
|
||||
base_layer.bias.data = torch.mul(base_layer.bias.data, scaling.data).to(orig_dtype)
|
||||
|
||||
self.merged_adapters.append(active_adapter)
|
||||
|
||||
def unmerge(self) -> None:
|
||||
"""
|
||||
This method unmerges all merged adapter layers from the base weights.
|
||||
"""
|
||||
if not self.merged:
|
||||
warnings.warn("Already unmerged. Nothing to do.")
|
||||
return
|
||||
|
||||
warnings.warn("Unmerge result can be inaccurate for (IA)^3.")
|
||||
while len(self.merged_adapters) > 0:
|
||||
active_adapter = self.merged_adapters.pop()
|
||||
if active_adapter in self.ia3_l.keys():
|
||||
base_layer = self.get_base_layer()
|
||||
orig_dtype = base_layer.weight.data.dtype
|
||||
# Divide by the (IA)^3 vector. Add tolerance to avoid division by zero
|
||||
ia3_scaling = self.ia3_l[active_adapter].data
|
||||
if not self.is_feedforward:
|
||||
ia3_scaling = ia3_scaling.transpose(0, 1)
|
||||
base_layer.weight.data = torch.div(base_layer.weight.data, ia3_scaling + 1e-8).to(orig_dtype)
|
||||
|
||||
if not self.is_feedforward and (base_layer.bias is not None):
|
||||
scaling = self.ia3_l[active_adapter].reshape(base_layer.bias.shape)
|
||||
orig_dtype = base_layer.bias.data.dtype
|
||||
base_layer.bias.data = torch.div(base_layer.bias.data, scaling.data + 1e-8).to(orig_dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
dtype = previous_dtype = x.dtype
|
||||
|
||||
if self.disable_adapters:
|
||||
if self.merged:
|
||||
self.unmerge()
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
elif self.merged:
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
else:
|
||||
ia3_scaling = 1
|
||||
for active_adapter in self.active_adapters:
|
||||
if active_adapter not in self.ia3_l.keys():
|
||||
continue
|
||||
dtype = self.ia3_l[active_adapter].dtype
|
||||
ia3_scaling *= self.ia3_l[active_adapter]
|
||||
|
||||
if self.is_feedforward:
|
||||
x = x.to(dtype)
|
||||
# TODO: weight.dtype can be != self.ia3_l[self.active_adapters].dtype
|
||||
# e.g. bf16 vs fp32. Is that okay?
|
||||
interm = (x * ia3_scaling).to(self.get_base_layer().weight.dtype)
|
||||
result = self.base_layer(interm, *args, **kwargs)
|
||||
else:
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
result = result.to(dtype) * ia3_scaling
|
||||
|
||||
result = result.to(previous_dtype)
|
||||
return result
|
||||
|
||||
|
||||
class Conv2d(_ConvNd):
|
||||
# IA3 implemented in a 2D convolutional layer
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if not self._kernel_dim == 4:
|
||||
raise ValueError(f"Conv2d layer kernel must have 4 dimensions, not {self._kernel_dim}")
|
||||
|
||||
|
||||
class Conv3d(_ConvNd):
|
||||
# IA3 implemented in a 3D convolutional layer
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if not self._kernel_dim == 5:
|
||||
raise ValueError(f"Conv3d layer kernel must have 5 dimensions, not {self._kernel_dim}")
|
||||
@@ -0,0 +1,287 @@
|
||||
# Copyright 2024-present the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from copy import deepcopy
|
||||
from functools import wraps
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from peft.utils.integrations import dequantize_module_weight, gather_params_ctx
|
||||
from peft.utils.other import transpose
|
||||
|
||||
|
||||
ENABLE_DORA_CACHING = False
|
||||
"""Whether to enable DoRA caching, which makes it faster at inference but requires more memory"""
|
||||
|
||||
|
||||
def cache_decorator(cache_key: str):
|
||||
"""Caching decorator for DoRA
|
||||
|
||||
Caching is only enabled if ENABLE_DORA_CACHING is set to True (default: False), when in eval mode, and when the
|
||||
adapter_name is passed (e.g. not during layer initialization).
|
||||
|
||||
"""
|
||||
|
||||
def cache_value(func):
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
# if adapter_name is not passed, no caching
|
||||
adapter_name = kwargs.get("adapter_name")
|
||||
if (not ENABLE_DORA_CACHING) or self.training or (adapter_name is None):
|
||||
self._cache_clear()
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
cache_key_adapter = f"{cache_key}-{adapter_name}"
|
||||
output = self._cache_get(cache_key_adapter, None)
|
||||
if output is not None:
|
||||
return output
|
||||
|
||||
output = func(self, *args, **kwargs)
|
||||
self._cache_store(cache_key_adapter, output)
|
||||
return output
|
||||
|
||||
return wrapper
|
||||
|
||||
return cache_value
|
||||
|
||||
|
||||
class DoraLinearLayer(nn.Module):
|
||||
def __init__(self, fan_in_fan_out):
|
||||
super().__init__()
|
||||
self.fan_in_fan_out = fan_in_fan_out
|
||||
self._dora_cache: dict[str, Any] = {} # small ad hoc cache; values are not part of the state_dict
|
||||
|
||||
def _cache_store(self, key: str, value: Any) -> None:
|
||||
# cache intermediate values, e.g. weight norm of DoRA
|
||||
self._dora_cache[key] = value
|
||||
|
||||
def _cache_get(self, key: str, default: Optional[Any]) -> Optional[Any]:
|
||||
# retrieve from ad hoc cache
|
||||
return self._dora_cache.get(key, default)
|
||||
|
||||
def _cache_clear(self) -> None:
|
||||
self._dora_cache.clear()
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
if mode:
|
||||
self._cache_clear()
|
||||
super().train(mode=mode)
|
||||
return self
|
||||
|
||||
@cache_decorator("weight-norm")
|
||||
def get_weight_norm(self, weight, lora_weight, scaling, adapter_name: Optional[str] = None) -> torch.Tensor:
|
||||
# calculate L2 norm of the weight matrix over dim=1, i.e. one norm per output feature (row of the [out, in] weight); upstream calls this column-wise
|
||||
weight = transpose(weight, self.fan_in_fan_out)
|
||||
weight = weight + scaling * lora_weight
|
||||
weight_norm = torch.linalg.norm(weight, dim=1).to(weight.dtype)
|
||||
return weight_norm
|
||||
|
||||
@cache_decorator("lora-weight")
|
||||
def get_lora_weight(self, lora_A, lora_B, adapter_name: Optional[str] = None):
|
||||
# Don't use `lora_weight = lora_B.weight @ lora_A.weight` because this causes errors with FSDP. Instead,
|
||||
# calculate the same but using forward.
|
||||
x_eye = torch.eye(lora_A.weight.shape[1], device=lora_A.weight.device, dtype=lora_A.weight.dtype)
|
||||
lora_weight = lora_B(lora_A(x_eye)).T
|
||||
return lora_weight
|
||||
|
||||
def update_layer(self, *, base_layer, lora_A, lora_B, scaling, place_on_cpu=False) -> None:
|
||||
# temporarily convert fp16 to fp32, as fp16 can cause trouble on CPU with PyTorch < 2.2
|
||||
dtype_is_fp16 = lora_A.dtype == torch.float16
|
||||
if dtype_is_fp16:
|
||||
lora_A = lora_A.float()
|
||||
lora_B = lora_B.float()
|
||||
|
||||
with gather_params_ctx(base_layer.parameters()):
|
||||
if base_layer.__class__.__name__ == "Linear4bit":
|
||||
# We have to create a copy of the base layer, otherwise, FSDP will throw an error. 8bit does not work
|
||||
# yet because Int8Params cannot be correctly deep-copied (attributes vanish)
|
||||
base_layer = deepcopy(base_layer)
|
||||
|
||||
weight = dequantize_module_weight(base_layer)
|
||||
if weight.data.ndim >= 3: # For handling LoRAs applied to Conv layers.
|
||||
r = lora_A.shape[0]
|
||||
lora_weight = torch.mm(lora_B.view([-1, r]), lora_A.view([r, -1]))
|
||||
lora_weight = lora_weight.reshape(weight.shape)
|
||||
else:
|
||||
lora_weight = lora_B @ lora_A
|
||||
|
||||
if dtype_is_fp16:
|
||||
lora_weight = lora_weight.half()
|
||||
weight_norm = self.get_weight_norm(
|
||||
weight=weight.to(lora_A.device), lora_weight=lora_weight, scaling=scaling
|
||||
)
|
||||
|
||||
if place_on_cpu:
|
||||
weight_norm = weight_norm.to("cpu")
|
||||
self.weight = nn.Parameter(weight_norm, requires_grad=True)
|
||||
|
||||
def forward(self, x, *, lora_A, lora_B, scaling, base_layer, base_result=None, adapter_name="default"):
|
||||
"""
|
||||
For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
|
||||
output.
|
||||
"""
|
||||
lora_weight = self.get_lora_weight(lora_A=lora_A, lora_B=lora_B, adapter_name=adapter_name)
|
||||
lora_weight = lora_weight.to(x.dtype)
|
||||
|
||||
magnitude = self.weight
|
||||
weight = dequantize_module_weight(base_layer)
|
||||
weight = weight.to(x.dtype)
|
||||
weight_norm = self.get_weight_norm(
|
||||
weight=weight, lora_weight=lora_weight.detach(), scaling=scaling, adapter_name=adapter_name
|
||||
)
|
||||
# see section 4.3 of DoRA (https://huggingface.co/papers/2402.09353)
|
||||
# "[...] we suggest treating ||V +∆V ||_c in
|
||||
# Eq. (5) as a constant, thereby detaching it from the gradient
|
||||
# graph. This means that while ||V + ∆V ||_c dynamically
|
||||
# reflects the updates of ∆V , it won’t receive any gradient
|
||||
# during backpropagation"
|
||||
weight_norm = weight_norm.detach()
|
||||
mag_norm_scale = (magnitude / weight_norm).view(1, -1)
|
||||
|
||||
lora_result = lora_B(lora_A(x))
|
||||
|
||||
bias = None
|
||||
if base_result is not None:
|
||||
bias = base_layer.bias
|
||||
if bias is not None:
|
||||
base_result = base_result - bias
|
||||
else:
|
||||
base_result = F.linear(x, transpose(weight, self.fan_in_fan_out))
|
||||
|
||||
result_dora = (mag_norm_scale - 1) * base_result + mag_norm_scale * lora_result * scaling
|
||||
return result_dora
|
||||
|
||||
def __repr__(self) -> str:
|
||||
rep = super().__repr__()
|
||||
return "lora.dora." + rep
|
||||
|
||||
|
||||
class DoraEmbeddingLayer(DoraLinearLayer):
|
||||
@cache_decorator("lora-weight")
|
||||
def get_lora_weight(self, lora_A, lora_B, adapter_name: Optional[str] = None):
|
||||
return (lora_A @ lora_B).T
|
||||
|
||||
def forward(self, x, *, lora_A, lora_B, scaling, base_layer, embed_fn, adapter_name="default"):
|
||||
"""
|
||||
For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
|
||||
output.
|
||||
"""
|
||||
lora_weight = self.get_lora_weight(lora_A=lora_A, lora_B=lora_B, adapter_name=adapter_name)
|
||||
magnitude = self.weight
|
||||
weight = base_layer.weight
|
||||
weight_norm = self.get_weight_norm(
|
||||
weight=weight, lora_weight=lora_weight.detach(), scaling=scaling, adapter_name=adapter_name
|
||||
)
|
||||
# see section 4.3 of DoRA (https://huggingface.co/papers/2402.09353)
|
||||
# "[...] we suggest treating ||V +∆V ||_c in
|
||||
# Eq. (5) as a constant, thereby detaching it from the gradient
|
||||
# graph. This means that while ||V + ∆V ||_c dynamically
|
||||
# reflects the updates of ∆V , it won’t receive any gradient
|
||||
# during backpropagation"
|
||||
weight_norm = weight_norm.detach()
|
||||
mag_norm_scale = magnitude / weight_norm
|
||||
result_dora = mag_norm_scale * (embed_fn(x, lora_A) @ lora_B) * scaling
|
||||
return mag_norm_scale, result_dora
|
||||
|
||||
def __repr__(self) -> str:
|
||||
rep = super().__repr__()
|
||||
return "lora.dora." + rep
|
||||
|
||||
|
||||
class _DoraConvNdLayer(DoraLinearLayer):
|
||||
@cache_decorator("weight-norm")
|
||||
def get_weight_norm(self, weight, lora_weight, scaling, adapter_name: Optional[str] = None) -> torch.Tensor:
|
||||
# calculate L2 norm of weight matrix, column-wise
|
||||
weight = weight + scaling * lora_weight
|
||||
# the following is needed to have compatibility with the 4/5D weight tensors of Conv2D/3D
|
||||
dim = tuple(range(1, weight.dim()))
|
||||
weight_norm = weight.norm(p=2, dim=dim, keepdim=True).transpose(1, 0)
|
||||
return weight_norm
|
||||
|
||||
@cache_decorator("lora-weight")
|
||||
def get_lora_weight(self, lora_A, lora_B, adapter_name: Optional[str] = None) -> torch.Tensor:
|
||||
# Don't use `lora_weight = lora_B.weight @ lora_A.weight` because this causes errors with FSDP. Instead,
|
||||
# calculate the same but using forward.
|
||||
r = lora_A.weight.shape[0]
|
||||
lora_weight = torch.mm(lora_B.weight.view([-1, r]), lora_A.weight.view([r, -1]))
|
||||
return lora_weight
|
||||
|
||||
def forward(
|
||||
self, x, *, lora_A, lora_B, scaling, base_layer, base_result=None, adapter_name: str = "default"
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
|
||||
output.
|
||||
"""
|
||||
weight = base_layer.weight
|
||||
lora_weight = self.get_lora_weight(lora_A=lora_A, lora_B=lora_B, adapter_name=adapter_name).reshape(
|
||||
weight.shape
|
||||
)
|
||||
magnitude = self.weight
|
||||
weight_norm = self.get_weight_norm(
|
||||
weight=weight, lora_weight=lora_weight.detach(), scaling=scaling, adapter_name=adapter_name
|
||||
)
|
||||
# see section 4.3 of DoRA (https://huggingface.co/papers/2402.09353)
|
||||
# "[...] we suggest treating ||V +∆V ||_c in
|
||||
# Eq. (5) as a constant, thereby detaching it from the gradient
|
||||
# graph. This means that while ||V + ∆V ||_c dynamically
|
||||
# reflects the updates of ∆V , it won’t receive any gradient
|
||||
# during backpropagation"
|
||||
weight_norm = weight_norm.detach()
|
||||
mag_norm_scale = magnitude / weight_norm
|
||||
|
||||
if base_result is None:
|
||||
base_result = self.conv_fn(
|
||||
x,
|
||||
weight,
|
||||
bias=None,
|
||||
stride=base_layer.stride,
|
||||
padding=base_layer.padding,
|
||||
dilation=base_layer.dilation,
|
||||
groups=base_layer.groups,
|
||||
)
|
||||
else:
|
||||
bias = base_layer.bias
|
||||
if bias is not None:
|
||||
# reshape bias to (1, -1, 1, ...)
|
||||
bias_shape = (1, -1) + (1,) * (base_result.dim() - 2)
|
||||
base_result = base_result - bias.view(*bias_shape)
|
||||
|
||||
result_dora = (mag_norm_scale - 1) * base_result + mag_norm_scale * lora_B(lora_A(x)) * scaling
|
||||
return result_dora
|
||||
|
||||
def __repr__(self) -> str:
|
||||
rep = super().__repr__()
|
||||
return "lora.dora." + rep
|
||||
|
||||
|
||||
class DoraConv1dLayer(_DoraConvNdLayer):
|
||||
def __init__(self, fan_in_fan_out):
|
||||
super().__init__(fan_in_fan_out)
|
||||
self.conv_fn = F.conv1d
|
||||
|
||||
|
||||
class DoraConv2dLayer(_DoraConvNdLayer):
|
||||
def __init__(self, fan_in_fan_out):
|
||||
super().__init__(fan_in_fan_out)
|
||||
self.conv_fn = F.conv2d
|
||||
|
||||
|
||||
class DoraConv3dLayer(_DoraConvNdLayer):
|
||||
def __init__(self, fan_in_fan_out):
|
||||
super().__init__(fan_in_fan_out)
|
||||
self.conv_fn = F.conv3d
|
||||
@@ -0,0 +1,923 @@
|
||||
# Copyright 2023-present the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import warnings
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from accelerate.utils.imports import is_xpu_available
|
||||
from torch import nn
|
||||
|
||||
from peft.tuners.lora.config import BdLoraConfig
|
||||
from peft.utils.other import transpose
|
||||
|
||||
from .arrow import ArrowLoraLinearLayer
|
||||
from .config import LoraConfig, PeftConfig
|
||||
from .dora import DoraConv1dLayer, DoraConv2dLayer, DoraConv3dLayer, DoraEmbeddingLayer, DoraLinearLayer
|
||||
from .layer import Conv1d, Conv2d, Conv3d, Embedding, Linear, LoraVariant, _ConvNd
|
||||
|
||||
|
||||
class ArrowLinearVariant(LoraVariant):
|
||||
@staticmethod
|
||||
def init(module: Linear, adapter_name: str, config: LoraConfig, **kwargs):
|
||||
"""
|
||||
Initialise the ArrowLoraLinearLayer() inside lora_arrow. lora_arrow is nn.ModuleDict(), serving as a container
|
||||
for ArrowLoraLinearLayer(). A layer of the base model with LoRA adapter loaded on it will be like:
|
||||
----------------------------------------------------
|
||||
(qkv_proj): lora.Linear4bit or lora.Linear(
|
||||
(base_layer): Linear4bit or Linear (lora_dropout): ModuleDict( ... ) (lora_A): ModuleDict( ... )
|
||||
(lora_B): ModuleDict( ... ) (lora_embedding_A): ParameterDict( ... ) (lora_embedding_B): ParameterDict(
|
||||
... ) (lora_magnitude_vector): ModuleDict( ... ) (lora_arrow): ModuleDict(
|
||||
(arrow_router): ArrowLoraLinearLayer() )
|
||||
)
|
||||
----------------------------------------------------
|
||||
|
||||
Args:
|
||||
module (Linear): LoRA Layer of the model, containing base_layer, lora_A, lora_B, etc.
|
||||
adapter_name (str): name of the adapter that will be put in lora_arrow.
|
||||
The adapter_name is "arrow_router" by default, set in create_arrow_model() in ./arrow.py
|
||||
"""
|
||||
# Checking for arrow necessary config
|
||||
arrow_config = config.arrow_config
|
||||
if arrow_config is None:
|
||||
raise ValueError("ArrowLinearVariant.init() did not receive an arrow_config")
|
||||
|
||||
# 1-a) build the ArrowLoRALayer
|
||||
arrow_layer = ArrowLoraLinearLayer(
|
||||
in_features=module.in_features,
|
||||
arrow_config=arrow_config,
|
||||
).to(module.weight.device)
|
||||
|
||||
# 1-b) register a container if it doesn’t exist yet
|
||||
if not hasattr(module, "lora_arrow"):
|
||||
module.lora_arrow = nn.ModuleDict()
|
||||
|
||||
module.lora_arrow[adapter_name] = arrow_layer
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
module: Linear,
|
||||
*,
|
||||
active_adapter: str,
|
||||
x: torch.Tensor,
|
||||
result: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Parameters mirror those in PEFT’s `LoraVariant.forward`. Called every time the host Linear does a fwd pass.
|
||||
|
||||
build_prototypes() and gen_know_sub() should run only once before routing. Both are implemented in
|
||||
ArrowLoraLinearLayer (see ./arrow.py). They are lazily invoked in the forward pass below. Attributes of
|
||||
ArrowLoraLinearLayer() class ensure they execute only a single time.
|
||||
|
||||
Args:
|
||||
module (Linear): LoRA Layer of the model
|
||||
active_adapter (str): name of the arrow route, which should be active to perform arrow.
|
||||
x (torch.Tensor): input to the layer
|
||||
result (torch.Tensor): output of the base layer.
|
||||
|
||||
Return value:
|
||||
output of the base model + delta weight computed by arrow layer.
|
||||
"""
|
||||
arrow = module.lora_arrow[active_adapter] # ArrowLoraLinearLayer
|
||||
# Apply GenKnowSub the 1st time if applicable. By calling arrow/on_adapter_change(),
|
||||
# gen_know_sub() is redone for newly added adapters after arrow.create_arrow_model().
|
||||
arrow.gen_know_sub(module.lora_A, module.lora_B)
|
||||
# lazily build prototypes the 1st time after GenKnowSub. By calling arrow/on_adapter_change(),
|
||||
# build_prototypes() is redone for newly added adapters after arrow.create_arrow_model().
|
||||
arrow.build_prototypes(module.lora_A, module.lora_B)
|
||||
|
||||
# Call the forward pass of ArrowLoraLinearLayer, which performs the routing.
|
||||
# Accept and ignore extra variant kwargs (e.g., 'alora_offsets') for compatibility
|
||||
delta = arrow(
|
||||
x,
|
||||
lora_A=module.lora_A,
|
||||
lora_B=module.lora_B,
|
||||
dropout=module.lora_dropout[active_adapter],
|
||||
scaling=module.scaling,
|
||||
)
|
||||
return result + delta
|
||||
|
||||
"""
|
||||
Since Arrow is a Mixture-of-Experts (MoE) approach, merging adapters is not meaningful or even possible: for each
|
||||
token, the top-k LoRA experts are dynamically selected and routed. Because of this per-token routing, there is no
|
||||
single set of weights that can represent a merged adapter.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def merge_safe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
raise RuntimeError("Cannot merge an active Arrow router adapter. Remove it first.")
|
||||
|
||||
@staticmethod
|
||||
def merge_unsafe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> None:
|
||||
raise RuntimeError("Cannot merge an active Arrow router adapter. Remove it first.")
|
||||
|
||||
@staticmethod
|
||||
def unmerge(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
raise RuntimeError("Cannot unmerge an active Arrow router adapter. Remove it first.")
|
||||
|
||||
|
||||
class DoraLinearVariant(LoraVariant):
|
||||
@staticmethod
|
||||
def init(module: Linear, adapter_name: str, **kwargs: Any) -> None:
|
||||
if not module.lora_magnitude_vector:
|
||||
# first dora layer being added, add lora_magnitude_vector to the list of learnable parameters
|
||||
module.adapter_layer_names = module.adapter_layer_names[:] + ("lora_magnitude_vector",)
|
||||
|
||||
dora_layer = DoraLinearLayer(fan_in_fan_out=getattr(module, "fan_in_fan_out", False))
|
||||
lora_A = module.lora_A[adapter_name].weight
|
||||
lora_B = module.lora_B[adapter_name].weight
|
||||
place_on_cpu = module.ephemeral_gpu_offload and (lora_A.device.type == "cpu" or lora_B.device.type == "cpu")
|
||||
if module.ephemeral_gpu_offload:
|
||||
if lora_A.device.type in ["cuda", "xpu"]:
|
||||
lora_B = lora_B.to(lora_A.device)
|
||||
else:
|
||||
if lora_B.device.type not in ["cuda", "xpu"]:
|
||||
if is_xpu_available():
|
||||
lora_B = lora_B.to("xpu")
|
||||
else:
|
||||
lora_B = lora_B.to("cuda")
|
||||
lora_A = lora_A.to(lora_B.device)
|
||||
scaling = module.scaling[adapter_name]
|
||||
dora_layer.update_layer(
|
||||
base_layer=module.get_base_layer(),
|
||||
lora_A=lora_A,
|
||||
lora_B=lora_B,
|
||||
scaling=scaling,
|
||||
place_on_cpu=place_on_cpu,
|
||||
)
|
||||
module.lora_magnitude_vector[adapter_name] = dora_layer
|
||||
|
||||
@staticmethod
|
||||
def merge_safe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
orig_dtype = orig_weight.dtype
|
||||
delta_weight = module.get_delta_weight(active_adapter)
|
||||
|
||||
# since delta_weight already includes scaling, set it to 1 here
|
||||
weight_norm = (
|
||||
module.lora_magnitude_vector[active_adapter]
|
||||
.get_weight_norm(orig_weight, transpose(delta_weight, module.fan_in_fan_out), scaling=1)
|
||||
.detach()
|
||||
)
|
||||
# We need to cache weight_norm because it has to be based on the original weights. We
|
||||
# cannot calculate it on the fly based on the merged weights when unmerging because it's a
|
||||
# different value
|
||||
module._cache_store(f"{active_adapter}-weight_norm", weight_norm)
|
||||
dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm
|
||||
dora_factor = transpose(dora_factor.view(-1, 1), module.fan_in_fan_out)
|
||||
new_weight = dora_factor * (orig_weight + delta_weight)
|
||||
new_weight = new_weight.to(orig_dtype)
|
||||
return new_weight
|
||||
|
||||
@staticmethod
|
||||
def merge_unsafe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> None:
|
||||
orig_dtype = orig_weight.dtype
|
||||
delta_weight = module.get_delta_weight(active_adapter)
|
||||
weight_norm = (
|
||||
module.lora_magnitude_vector[active_adapter]
|
||||
.get_weight_norm(orig_weight, transpose(delta_weight, module.fan_in_fan_out), scaling=1)
|
||||
.detach()
|
||||
)
|
||||
# We need to cache weight_norm because it has to be based on the original weights. We
|
||||
# cannot calculate it on the fly based on the merged weights when unmerging because it's a
|
||||
# different value
|
||||
module._cache_store(f"{active_adapter}-weight_norm", weight_norm)
|
||||
dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm
|
||||
dora_factor = transpose(dora_factor.view(-1, 1), module.fan_in_fan_out)
|
||||
new_weight = dora_factor * (orig_weight.data + delta_weight)
|
||||
new_weight = new_weight.to(orig_dtype)
|
||||
orig_weight.data = new_weight
|
||||
|
||||
@staticmethod
|
||||
def unmerge(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
orig_dtype = orig_weight.dtype
|
||||
delta_weight = module.get_delta_weight(active_adapter)
|
||||
weight_norm = module._cache_pop(f"{active_adapter}-weight_norm")
|
||||
dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm
|
||||
new_weight = orig_weight.data / dora_factor.view(-1, 1) - delta_weight
|
||||
new_weight = new_weight.to(orig_dtype)
|
||||
return new_weight
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
module: Linear,
|
||||
active_adapter: str,
|
||||
x: torch.Tensor,
|
||||
result: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
lora_A = module.lora_A[active_adapter]
|
||||
lora_B = module.lora_B[active_adapter]
|
||||
dropout = module.lora_dropout[active_adapter]
|
||||
scaling = module.scaling[active_adapter]
|
||||
|
||||
if isinstance(dropout, nn.Identity) or not module.training:
|
||||
base_result = result
|
||||
else:
|
||||
x = dropout(x)
|
||||
base_result = None
|
||||
|
||||
result = result + module.lora_magnitude_vector[active_adapter](
|
||||
x,
|
||||
lora_A=lora_A,
|
||||
lora_B=lora_B,
|
||||
scaling=scaling,
|
||||
base_layer=module.get_base_layer(),
|
||||
base_result=base_result,
|
||||
adapter_name=active_adapter,
|
||||
)
|
||||
return result
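
The comments in `merge_safe` and `merge_unsafe` above say the weight norm must be cached because it cannot be recomputed from the merged weights. A minimal sketch of that round trip follows. It assumes a linear layer with a per-output-row L2 norm, and it uses a plain dict as a stand-in for `_cache_store` / `_cache_pop`. All values are hypothetical.

```python
import torch

torch.manual_seed(0)
out_f, in_f, r = 6, 4, 2
W = torch.randn(out_f, in_f)
# Delta weight with scaling already folded in, as get_delta_weight would return it.
delta = (torch.randn(out_f, r) @ torch.randn(r, in_f)) * 0.1
magnitude = torch.linalg.norm(W, dim=1) * 1.2  # hypothetical trained magnitude

# Merge: the norm is taken on the ORIGINAL weight plus delta, then cached.
weight_norm = torch.linalg.norm(W + delta, dim=1)
cache = {"default-weight_norm": weight_norm}   # stand-in for module._cache_store
merged = (magnitude / weight_norm).view(-1, 1) * (W + delta)

# Unmerge: reuse the cached norm to invert the merge exactly.
cached = cache.pop("default-weight_norm")      # stand-in for module._cache_pop
restored = merged / (magnitude / cached).view(-1, 1) - delta
assert torch.allclose(restored, W, atol=1e-5)

# Recomputing the norm from the merged weight yields the magnitude itself,
# not the original norm, so it cannot be used to undo the merge.
norm_from_merged = torch.linalg.norm(merged, dim=1)
assert torch.allclose(norm_from_merged, magnitude, atol=1e-5)
assert not torch.allclose(norm_from_merged, weight_norm)
```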
|
||||
|
||||
|
||||
class DoraEmbeddingVariant(DoraLinearVariant):
|
||||
@staticmethod
|
||||
def init(module: Embedding, adapter_name: str, **kwargs: Any) -> None:
|
||||
if module.lora_magnitude_vector is None:
|
||||
# first dora layer being added, add lora_magnitude_vector to the list of learnable parameters
|
||||
module.adapter_layer_names = module.adapter_layer_names[:] + ("lora_magnitude_vector",)
|
||||
|
||||
dora_layer = DoraEmbeddingLayer(fan_in_fan_out=True)
|
||||
lora_embedding_A = module.lora_embedding_A[adapter_name]
|
||||
lora_embedding_B = module.lora_embedding_B[adapter_name]
|
||||
scaling = module.scaling[adapter_name]
|
||||
dora_layer.update_layer(
|
||||
base_layer=module.get_base_layer(), lora_A=lora_embedding_A, lora_B=lora_embedding_B, scaling=scaling
|
||||
)
|
||||
module.lora_magnitude_vector[adapter_name] = dora_layer
|
||||
|
||||
@staticmethod
|
||||
def merge_safe(module: Embedding, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
orig_dtype = orig_weight.dtype
|
||||
delta_weight = module.get_delta_weight(active_adapter)
|
||||
|
||||
# since delta_weight already includes scaling, set it to 1 here
|
||||
weight_norm = (
|
||||
module.lora_magnitude_vector[active_adapter]
|
||||
.get_weight_norm(orig_weight, delta_weight.T, scaling=1)
|
||||
.detach()
|
||||
)
|
||||
# We need to cache weight_norm because it has to be based on the original weights. We
|
||||
# cannot calculate it on the fly based on the merged weights when unmerging because it's a
|
||||
# different value
|
||||
module._cache_store(f"{active_adapter}-weight_norm", weight_norm)
|
||||
dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm
|
||||
dora_factor = dora_factor.view(1, -1)
|
||||
new_weight = dora_factor * (orig_weight + delta_weight)
|
||||
new_weight = new_weight.to(orig_dtype)
|
||||
return new_weight
|
||||
|
||||
@staticmethod
|
||||
def merge_unsafe(module: Embedding, active_adapter: str, orig_weight: torch.Tensor) -> None:
|
||||
orig_dtype = orig_weight.dtype
|
||||
delta_weight = module.get_delta_weight(active_adapter)
|
||||
weight_norm = (
|
||||
module.lora_magnitude_vector[active_adapter]
|
||||
.get_weight_norm(orig_weight, delta_weight.T, scaling=1)
|
||||
.detach()
|
||||
)
|
||||
# We need to cache weight_norm because it has to be based on the original weights. We
|
||||
# cannot calculate it on the fly based on the merged weights when unmerging because it's a
|
||||
# different value
|
||||
module._cache_store(f"{active_adapter}-weight_norm", weight_norm)
|
||||
dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm
|
||||
dora_factor = dora_factor.view(1, -1)
|
||||
new_weight = dora_factor * (orig_weight.data + delta_weight)
|
||||
new_weight = new_weight.to(orig_dtype)
|
||||
orig_weight.data = new_weight
|
||||
|
||||
@staticmethod
|
||||
def unmerge(module: Embedding, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
orig_dtype = orig_weight.dtype
|
||||
delta_weight = module.get_delta_weight(active_adapter)
|
||||
weight_norm = module._cache_pop(f"{active_adapter}-weight_norm")
|
||||
dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm
|
||||
new_weight = orig_weight.data / dora_factor.view(1, -1) - delta_weight
|
||||
new_weight = new_weight.to(orig_dtype)
|
||||
return new_weight
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
module: Embedding,
|
||||
active_adapter: str,
|
||||
x: torch.Tensor,
|
||||
result: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
embedding_A = module.lora_embedding_A[active_adapter].T
|
||||
embedding_B = module.lora_embedding_B[active_adapter].T
|
||||
scaling = module.scaling[active_adapter]
|
||||
|
||||
mag_norm_scale, dora_result = module.lora_magnitude_vector[active_adapter](
|
||||
x,
|
||||
lora_A=embedding_A,
|
||||
lora_B=embedding_B,
|
||||
scaling=scaling,
|
||||
base_layer=module.get_base_layer(),
|
||||
embed_fn=module._embed,
|
||||
adapter_name=active_adapter,
|
||||
)
|
||||
|
||||
# Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method.
|
||||
# Since base_layer(x) already includes this scaling, we need to apply it to DoRA contributions too.
|
||||
# Note: embed_scale is applied AFTER weight norm calculation to preserve DoRA's weight geometry semantics.
|
||||
embed_scale = module._get_embed_scale()
|
||||
if embed_scale is not None:
|
||||
dora_result = dora_result * embed_scale.to(dora_result.dtype)
|
||||
|
||||
result = mag_norm_scale * result + dora_result
|
||||
return result
|
||||
|
||||
|
||||
class _DoraConvNdVariant(LoraVariant):
|
||||
@staticmethod
|
||||
def init_convd_variant(module: _ConvNd, adapter_name: str, dora_layer: nn.Module) -> None:
|
||||
if module.lora_magnitude_vector is None:
|
||||
# first dora layer being added, add lora_magnitude_vector to the list of learnable parameters
|
||||
module.adapter_layer_names = module.adapter_layer_names[:] + ("lora_magnitude_vector",)
|
||||
|
||||
lora_A = module.lora_A[adapter_name].weight
|
||||
lora_B = module.lora_B[adapter_name].weight
|
||||
scaling = module.scaling[adapter_name]
|
||||
dora_layer.update_layer(base_layer=module.get_base_layer(), lora_A=lora_A, lora_B=lora_B, scaling=scaling)
|
||||
module.lora_magnitude_vector[adapter_name] = dora_layer
|
||||
|
||||
@staticmethod
|
||||
def merge_safe(module: _ConvNd, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
orig_dtype = orig_weight.dtype
|
||||
delta_weight = module.get_delta_weight(active_adapter)
|
||||
|
||||
# since delta_weight already includes scaling, set it to 1 here
|
||||
weight_norm = (
|
||||
module.lora_magnitude_vector[active_adapter].get_weight_norm(orig_weight, delta_weight, scaling=1).detach()
|
||||
)
|
||||
# We need to cache weight_norm because it has to be based on the original weights. We
|
||||
# cannot calculate it on the fly based on the merged weights when unmerging because it's a
|
||||
# different value
|
||||
module._cache_store(f"{active_adapter}-weight_norm", weight_norm)
|
||||
dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm
|
||||
new_weight = dora_factor.view(*module._get_dora_factor_view()) * (orig_weight + delta_weight)
|
||||
new_weight = new_weight.to(orig_dtype)
|
||||
return new_weight
|
||||
|
||||
@staticmethod
|
||||
def merge_unsafe(module: _ConvNd, active_adapter: str, orig_weight: torch.Tensor) -> None:
|
||||
orig_dtype = orig_weight.dtype
|
||||
delta_weight = module.get_delta_weight(active_adapter)
|
||||
# since delta_weight already includes scaling, set it to 1 here
|
||||
weight_norm = (
|
||||
module.lora_magnitude_vector[active_adapter].get_weight_norm(orig_weight, delta_weight, scaling=1).detach()
|
||||
)
|
||||
# We need to cache weight_norm because it has to be based on the original weights. We
|
||||
# cannot calculate it on the fly based on the merged weights when unmerging because it's a
|
||||
# different value
|
||||
module._cache_store(f"{active_adapter}-weight_norm", weight_norm)
|
||||
dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm
|
||||
new_weight = dora_factor.view(*module._get_dora_factor_view()) * (orig_weight.data + delta_weight)
|
||||
new_weight = new_weight.to(orig_dtype)
|
||||
orig_weight.data = new_weight
|
||||
|
||||
@staticmethod
|
||||
def unmerge(module: _ConvNd, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
orig_dtype = orig_weight.dtype
|
||||
delta_weight = module.get_delta_weight(active_adapter)
|
||||
weight_norm = module._cache_pop(f"{active_adapter}-weight_norm")
|
||||
dora_factor = module.lora_magnitude_vector[active_adapter].weight / weight_norm
|
||||
new_weight = orig_weight.data / dora_factor.view(*module._get_dora_factor_view()) - delta_weight
|
||||
new_weight = new_weight.to(orig_dtype)
|
||||
return new_weight
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
module: _ConvNd,
|
||||
active_adapter: str,
|
||||
x: torch.Tensor,
|
||||
result: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
lora_A = module.lora_A[active_adapter]
|
||||
lora_B = module.lora_B[active_adapter]
|
||||
dropout = module.lora_dropout[active_adapter]
|
||||
scaling = module.scaling[active_adapter]
|
||||
|
||||
if isinstance(dropout, nn.Identity) or not module.training:
|
||||
base_result = result
|
||||
else:
|
||||
x = dropout(x)
|
||||
base_result = None
|
||||
|
||||
result = result + module.lora_magnitude_vector[active_adapter](
|
||||
x,
|
||||
lora_A=lora_A,
|
||||
lora_B=lora_B,
|
||||
scaling=scaling,
|
||||
base_layer=module.get_base_layer(),
|
||||
base_result=base_result,
|
||||
adapter_name=active_adapter,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class DoraConv1dVariant(_DoraConvNdVariant):
|
||||
@staticmethod
|
||||
def init(module: Conv1d, adapter_name: str, config: LoraConfig, **kwargs: Any) -> None:
|
||||
dora_layer = DoraConv1dLayer(fan_in_fan_out=False)
|
||||
_DoraConvNdVariant.init_convd_variant(module, adapter_name, dora_layer=dora_layer)
|
||||
|
||||
|
||||
class DoraConv2dVariant(_DoraConvNdVariant):
|
||||
@staticmethod
|
||||
def init(module: Conv2d, adapter_name: str, config: LoraConfig, **kwargs: Any) -> None:
|
||||
dora_layer = DoraConv2dLayer(fan_in_fan_out=False)
|
||||
_DoraConvNdVariant.init_convd_variant(module, adapter_name, dora_layer=dora_layer)
|
||||
|
||||
|
||||
class DoraConv3dVariant(_DoraConvNdVariant):
|
||||
@staticmethod
|
||||
def init(module: Conv3d, adapter_name: str, config: LoraConfig, **kwargs: Any) -> None:
|
||||
dora_layer = DoraConv3dLayer(fan_in_fan_out=False)
|
||||
_DoraConvNdVariant.init_convd_variant(module, adapter_name, dora_layer=dora_layer)
|
||||
|
||||
|
||||
class QALoraLinearVariant(LoraVariant):
|
||||
@staticmethod
|
||||
def init(module: Linear, adapter_name: str, config: LoraConfig, **kwargs: Any) -> None:
|
||||
"""
|
||||
Initializes QALoRA specific parameters for a given adapter.
|
||||
|
||||
Args:
|
||||
module (Linear): The linear module to be adapted.
|
||||
adapter_name (str): The name of the adapter.
|
||||
config (LoraConfig): The config of the LoRA adapter.
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
qalora_group_size = config.qalora_group_size
|
||||
if module.in_features is not None and module.in_features % qalora_group_size != 0:
|
||||
raise ValueError(
|
||||
f"`use_qalora=True` requires `module.in_features` ({module.in_features}) to be "
|
||||
f"divisible by 'qalora_group_size' ({qalora_group_size})"
|
||||
)
|
||||
|
||||
if "qalora_group_size" not in module.other_param_names:
|
||||
module.other_param_names = module.other_param_names + ("qalora_group_size",)
|
||||
|
||||
if not hasattr(module, "qalora_group_size"):
|
||||
module.qalora_group_size = {}
|
||||
module.qalora_group_size[adapter_name] = qalora_group_size
|
||||
|
||||
old_lora_A_layer = module.lora_A[adapter_name]
|
||||
r = old_lora_A_layer.out_features
|
||||
device = old_lora_A_layer.weight.device
|
||||
dtype = old_lora_A_layer.weight.dtype
|
||||
|
||||
new_lora_A_layer = nn.Linear(
|
||||
old_lora_A_layer.in_features // module.qalora_group_size[adapter_name],
|
||||
r,
|
||||
bias=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
module.lora_A[adapter_name] = new_lora_A_layer
|
||||
|
||||
@staticmethod
|
||||
def get_delta_weight(module: Linear, active_adapter: str) -> torch.Tensor:
|
||||
raise NotImplementedError("QALoRA for GPTQ layers does not support 'get_delta_weight'.")
|
||||
|
||||
@staticmethod
|
||||
def merge_safe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError("QALoRA for GPTQ layers does not support 'safe_merge'.")
|
||||
|
||||
@staticmethod
|
||||
def merge_unsafe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> None:
|
||||
raise NotImplementedError("QALoRA for GPTQ layers does not support 'merge_unsafe'.")
|
||||
|
||||
@staticmethod
|
||||
def unmerge(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError("QALoRA for GPTQ layers does not support 'unmerge'.")
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
module: Linear,
|
||||
active_adapter: str,
|
||||
x: torch.Tensor,
|
||||
result: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
lora_A_weight = module.lora_A[active_adapter].weight
|
||||
lora_B_weight = module.lora_B[active_adapter].weight
|
||||
dropout = module.lora_dropout[active_adapter]
|
||||
scaling = module.scaling[active_adapter]
|
||||
group_size = module.qalora_group_size[active_adapter]
|
||||
|
||||
x_dropped = dropout(x) if module.training and not isinstance(dropout, nn.Identity) else x
|
||||
orig_shape = x_dropped.shape
|
||||
|
||||
# Reshape to 2D
|
||||
if len(orig_shape) > 2:
|
||||
x_flat = x_dropped.view(-1, module.in_features)
|
||||
else:
|
||||
x_flat = x_dropped
|
||||
|
||||
batch_size, in_features = x_flat.shape
|
||||
pooled_features = in_features // group_size
|
||||
|
||||
x_pooled = x_flat.view(batch_size, pooled_features, group_size).mean(dim=2)
|
||||
|
||||
x_pooled_scaled = x_pooled * pooled_features
|
||||
|
||||
# LoRA computation
|
||||
delta = x_pooled_scaled @ lora_A_weight.t() @ lora_B_weight.t() * scaling
|
||||
|
||||
# Reshape back
|
||||
if len(orig_shape) > 2:
|
||||
delta = delta.view(orig_shape[:-1] + (delta.size(-1),))
|
||||
|
||||
return result + delta
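
The pooling step in the QALoRA forward above can be read as a dense LoRA whose `A` matrix is constant within each input group. The sketch below illustrates that reading under assumed toy shapes. The names `A_full` and `group` are hypothetical and not part of the reference code. It reproduces the group-mean pooling and the `pooled_features` rescaling, then checks the result against an equivalent full-width `A`.

```python
import torch

torch.manual_seed(0)
in_f, out_f, r, group = 8, 5, 2, 4
pooled = in_f // group               # number of pooled input features
A = torch.randn(r, pooled)           # lora_A acts on the pooled input
B = torch.randn(out_f, r)
x = torch.randn(3, in_f)

# Same steps as the forward: mean over each group, then rescale.
x_pooled = x.view(3, pooled, group).mean(dim=2)
delta = (x_pooled * pooled) @ A.T @ B.T

# Equivalent dense A: repeat every pooled column `group` times and fold
# the mean (1 / group) and the rescaling (pooled) into the weights.
A_full = A.repeat_interleave(group, dim=1) * (pooled / group)
delta_full = x @ A_full.T @ B.T
assert torch.allclose(delta, delta_full, atol=1e-5)
```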
|
||||
|
||||
|
||||
class ALoraLinearVariant(LoraVariant):
|
||||
@staticmethod
|
||||
def init(module: Linear, adapter_name: str, config: LoraConfig, **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def merge_safe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError("aLoRA does not support safe merging.")
|
||||
|
||||
@staticmethod
|
||||
def merge_unsafe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> None:
|
||||
raise NotImplementedError("aLoRA does not support merging.")
|
||||
|
||||
@staticmethod
|
||||
def unmerge(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError("aLoRA does not support unmerging.")
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
module: Linear,
|
||||
active_adapter: str,
|
||||
x: torch.Tensor,
|
||||
result: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
alora_offsets = kwargs.get("alora_offsets", None)
|
||||
lora_A = module.lora_A[active_adapter]
|
||||
lora_B = module.lora_B[active_adapter]
|
||||
dropout = module.lora_dropout[active_adapter]
|
||||
scaling = module.scaling[active_adapter]
|
||||
x = x.to(lora_A.weight.dtype)
|
||||
result_shape = result.shape
|
||||
B = result_shape[0] # batch
|
||||
if len(result_shape) == 3:
|
||||
T = result_shape[1] # tokens
|
||||
else:
|
||||
T = 1
|
||||
D = result_shape[-1] # dimensions
|
||||
Dx = x.shape[-1]
|
||||
device = result.device
|
||||
if alora_offsets is None: # use base model only, but ensure 0 gradient
|
||||
mask = torch.zeros((B, T), dtype=torch.bool)
|
||||
else:
|
||||
# If alora_offsets[i] is None, this means that the invocation sequence was not found in the
|
||||
# input. As a result, the weights should not be activated anywhere (equivalent to base model).
|
||||
# Convert None -> 0 and clip to T
|
||||
offsets = torch.tensor(
|
||||
[0 if o is None else min(int(o), T) for o in alora_offsets],
|
||||
device=device,
|
||||
dtype=torch.long,
|
||||
)
|
||||
# Mask True on the last `offsets[i]` positions for each row i
|
||||
pos = torch.arange(T, device=device).unsqueeze(0) # [1, T]
|
||||
mask = pos >= (T - offsets).unsqueeze(1)
|
||||
|
||||
# Flatten for vectorization
|
||||
x_flat = x.view(-1, Dx)
|
||||
res_flat = result.view(-1, D)
|
||||
mask_flat = mask.view(-1)
|
||||
|
||||
# Compute adapter on the selected tokens only
|
||||
res_flat[mask_flat] += lora_B(lora_A(dropout(x_flat[mask_flat]))) * scaling
|
||||
return result
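
The mask logic in the aLoRA forward above activates the adapter only on the last `offsets[i]` positions of each row. A minimal sketch of that construction follows. The helper name `alora_mask` is hypothetical, and device handling is left out.

```python
import torch

def alora_mask(alora_offsets, T):
    """Boolean [B, T] mask that is True on the last offsets[i] tokens of row i."""
    # None means the invocation sequence was not found: adapter stays off.
    offsets = torch.tensor(
        [0 if o is None else min(int(o), T) for o in alora_offsets],
        dtype=torch.long,
    )
    pos = torch.arange(T).unsqueeze(0)          # [1, T]
    return pos >= (T - offsets).unsqueeze(1)    # broadcast to [B, T]

mask = alora_mask([2, None, 9], T=5)
```

Here row 0 enables the adapter on the final two tokens, row 1 leaves it off everywhere, and row 2 has an offset larger than `T`, which is clipped so the adapter applies to every token.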
|
||||
|
||||
|
||||
def calculate_alora_offsets(
|
||||
peft_config: PeftConfig, active_adapter: str, input_ids: torch.Tensor, adapter_names: Optional[list[str]] = None
|
||||
) -> list[Optional[int]]:
|
||||
"""
|
||||
This is a helper function for Activated LoRA (aLoRA) that searches each input token sequence for the last
|
||||
occurrence of the appropriate "alora_invocation_tokens" invocation sequence. The calculated alora_offset is the
|
||||
location of the *start* of the invocation tokens, counting backward from the end (will therefore always be >=
|
||||
len(alora_invocation_tokens)). If adapter_names is passed, then each input uses the appropriate invocation sequence
|
||||
for the specified adapter for that row. Logic is provided to handle mixed collections of adapters for which not all
|
||||
are aLoRAs (e.g. some base model, some LoRA).
|
||||
"""
|
||||
if input_ids is None:
|
||||
return []
|
||||
|
||||
batch_size = input_ids.shape[0]
|
||||
alora_offsets = [None] * batch_size
|
||||
|
||||
cached_invocation_tensors = {}
|
||||
adapters_to_process_indices = collections.defaultdict(list)
|
||||
|
||||
for i in range(batch_size):
|
||||
current_adapter_name = adapter_names[i] if adapter_names and i < len(adapter_names) else active_adapter
|
||||
|
||||
if current_adapter_name == "__base__":
|
||||
alora_offsets[i] = None
|
||||
continue
|
||||
|
||||
if current_adapter_name not in peft_config:
|
||||
warnings.warn(f"Adapter '{current_adapter_name}' not found in peft_config. Using base model for row {i}.")
|
||||
alora_offsets[i] = None
|
||||
continue
|
||||
|
||||
current_peft_config = peft_config[current_adapter_name]
|
||||
|
||||
invocation_tokens = getattr(current_peft_config, "alora_invocation_tokens", None)
|
||||
if invocation_tokens is None:
|
||||
alora_offsets[i] = None # Not an aLoRA adapter or wrong type
|
||||
continue
|
||||
|
||||
if current_adapter_name not in cached_invocation_tensors:
|
||||
cached_invocation_tensors[current_adapter_name] = torch.tensor(
|
||||
invocation_tokens, dtype=torch.long, device=input_ids.device
|
||||
)
|
||||
|
||||
adapters_to_process_indices[current_adapter_name].append(i)
|
||||
|
||||
for adapter_name_to_process, indices in adapters_to_process_indices.items():
|
||||
current_invocation_ids_tensor = cached_invocation_tensors[adapter_name_to_process]
|
||||
invocation_len = len(current_invocation_ids_tensor)
|
||||
|
||||
for i in indices:
|
||||
sequence = input_ids[i]
|
||||
seq_len = len(sequence)
|
||||
best_match_start_idx = -1
|
||||
|
||||
possible_starts = (sequence == current_invocation_ids_tensor[0]).nonzero(as_tuple=True)[0]
|
||||
|
||||
for start_idx_tensor in possible_starts:
|
||||
idx = start_idx_tensor.item()
|
||||
if idx + invocation_len <= seq_len:
|
||||
if torch.equal(sequence[idx : idx + invocation_len], current_invocation_ids_tensor):
|
||||
if idx > best_match_start_idx:
|
||||
best_match_start_idx = idx
|
||||
|
||||
if best_match_start_idx != -1:
|
||||
offset_val = seq_len - best_match_start_idx
|
||||
alora_offsets[i] = offset_val if offset_val > 0 else None
|
||||
else: # Invocation sequence not found in input
|
||||
alora_offsets[i] = None
|
||||
return alora_offsets
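
The per-row search in `calculate_alora_offsets` above finds the last occurrence of the invocation tokens and converts its start index into a distance from the end of the sequence. A simplified sketch of that step on plain Python lists follows. The function name and the empty-invocation guard are assumptions added for illustration.

```python
def last_invocation_offset(sequence, invocation):
    """Distance from the end of `sequence` to the start of the last match, or None."""
    n, m = len(sequence), len(invocation)
    if m == 0:
        return None  # assumption: an empty invocation never matches
    best = -1
    for idx in range(n - m + 1):
        if sequence[idx : idx + m] == invocation:
            best = idx  # later matches overwrite earlier ones
    return None if best == -1 else n - best

offset = last_invocation_offset([5, 1, 2, 9, 1, 2, 7], [1, 2])
missing = last_invocation_offset([5, 6, 7], [1, 2])
```

The invocation `[1, 2]` occurs at indices 1 and 4. The later match wins, so the offset is `7 - 4 = 3`, which is at least the invocation length, as the docstring states.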
|
||||
|
||||
|
||||
def is_alora_relevant_in_batch(model: nn.Module, adapter_names: Optional[list[str]] = None):
|
||||
"""
|
||||
Helper function to determine if the current batch has any aLoRA adapters.
|
||||
"""
|
||||
is_alora_relevant = False
|
||||
if getattr(model.active_peft_config, "alora_invocation_tokens", None):
|
||||
is_alora_relevant = True
|
||||
elif adapter_names:
|
||||
for name in adapter_names:
|
||||
if name == "__base__":
|
||||
continue
|
||||
config_ = model.peft_config.get(name)
|
||||
if config_ and getattr(config_, "alora_invocation_tokens", None):
|
||||
is_alora_relevant = True
|
||||
break
|
||||
|
||||
return is_alora_relevant
|
||||
|
||||
|
||||
def get_alora_offsets_for_forward(
|
||||
model: nn.Module, input_ids: torch.Tensor = None, inputs_embeds: torch.Tensor = None, **kwargs
|
||||
):
|
||||
"""
|
||||
Wrapper around calculate_alora_offsets, for the .forward of the model. It only calculates alora_offsets if the
|
||||
batch contains aLoRA adapters.
|
||||
"""
|
||||
adapter_names_for_offset_calc = kwargs.get("adapter_names", None)
|
||||
if not is_alora_relevant_in_batch(model, adapter_names_for_offset_calc):
|
||||
# Nothing to compute
|
||||
return kwargs
|
||||
alora_offsets = kwargs.get("alora_offsets")
|
||||
if alora_offsets is None:
|
||||
if input_ids is None and inputs_embeds is not None:
|
||||
warnings.warn(
|
||||
"Cannot calculate aLoRA offsets when only inputs_embeds are provided. Disabling aLoRA for this forward pass."
|
||||
)
|
||||
kwargs["alora_offsets"] = None
|
||||
elif input_ids is not None:
|
||||
kwargs["alora_offsets"] = calculate_alora_offsets(
|
||||
model.peft_config,
|
||||
model.active_adapter,
|
||||
input_ids,
|
||||
adapter_names=adapter_names_for_offset_calc,
|
||||
)
|
||||
else:
|
||||
kwargs["alora_offsets"] = None
|
||||
return kwargs
|
||||
|
||||
|
||||
def get_alora_offsets_for_generate(model: nn.Module, *args, **kwargs):
|
||||
"""
|
||||
Wrapper around calculate_alora_offsets, for the .generate of the model. It only calculates alora_offsets if the
|
||||
batch contains aLoRA adapters.
|
||||
"""
|
||||
adapter_names_for_offset_calc = kwargs.get("adapter_names")
|
||||
if not is_alora_relevant_in_batch(model, adapter_names_for_offset_calc):
|
||||
# Nothing to compute
|
||||
return kwargs
|
||||
alora_offsets_from_kwargs = kwargs.get("alora_offsets")
|
||||
if alora_offsets_from_kwargs is None:
|
||||
current_input_ids = kwargs.get("input_ids")
|
||||
if current_input_ids is None: # args[0] is usually input_ids
|
||||
if args and isinstance(args[0], torch.Tensor):
|
||||
current_input_ids = args[0]
|
||||
else:
|
||||
current_input_ids = None
|
||||
|
||||
if current_input_ids is not None:
|
||||
if current_input_ids.ndim == 1:
|
||||
current_input_ids = current_input_ids.unsqueeze(0)
|
||||
calculated_offsets = calculate_alora_offsets(
|
||||
model.peft_config,
|
||||
model.active_adapter,
|
||||
current_input_ids,
|
||||
adapter_names=adapter_names_for_offset_calc,
|
||||
)
|
||||
kwargs["alora_offsets"] = calculated_offsets
|
||||
|
||||
else:
|
||||
warnings.warn(
|
||||
"Cannot calculate aLoRA offsets during generate as input_ids are not available. Disabling aLoRA."
|
||||
)
|
||||
|
||||
kwargs["alora_offsets"] = None
|
||||
return kwargs
|
||||
|
||||
|
||||
class BlockDiagonalLinear(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
nblocks: int,
|
||||
init_zero: bool = False,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.nblocks = nblocks
|
||||
if self.in_features % nblocks != 0 or self.out_features % nblocks != 0:
|
||||
raise ValueError(
|
||||
f"self.in_features={self.in_features} or self.out_features={self.out_features} not divisible by {self.nblocks}"
|
||||
)
|
||||
# Create weight with specified dtype and device
|
||||
self.weight = nn.Parameter(torch.empty(out_features, in_features // nblocks, dtype=dtype, device=device))
|
||||
|
||||
if init_zero:
|
||||
torch.nn.init.zeros_(self.weight)
|
||||
else:
|
||||
torch.nn.init.kaiming_uniform_(self.weight)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
first_dims = x.shape[:-1]
|
||||
if x.dim() != 2:
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
B = x.shape[0]
|
||||
nb = self.nblocks
|
||||
m = x.shape[-1] // nb
|
||||
n = self.out_features // nb
|
||||
x = x.reshape(B, nb, m)
|
||||
w = self.weight.view(nb, n, m)
|
||||
out = torch.einsum("bim,inm->bin", x, w)
|
||||
return out.reshape(*first_dims, -1)
|
||||
|
||||
def weight_as_blockdiagonal_matrix(self):
|
||||
"""Returns weight in a format similar to a vanilla LoRA adapter. For this, we stack the blocks on the diagonal,
|
||||
leaving the off-diagonals padded with zero."""
|
||||
return torch.block_diag(*torch.chunk(self.weight, self.nblocks, dim=0))
|
||||
|
||||
|
||||
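The relationship between `BlockDiagonalLinear.forward` and `weight_as_blockdiagonal_matrix` can be illustrated on plain Python lists: multiplying each input chunk by its own block should give the same result as multiplying the full input by the dense matrix with zero off-diagonal blocks. This is a dependency-free sketch for a single input vector; the function names are hypothetical, and only the `(out_features, in_features // nblocks)` weight layout is taken from the class above.

```python
def block_diag_forward(x, weight, nblocks):
    """Apply stacked blocks to x, one input chunk per block."""
    m = len(x) // nblocks        # input width per block
    n = len(weight) // nblocks   # output width per block
    out = []
    for i in range(nblocks):
        x_block = x[i * m:(i + 1) * m]
        for row in weight[i * n:(i + 1) * n]:
            out.append(sum(w * v for w, v in zip(row, x_block)))
    return out


def to_dense(weight, nblocks):
    """Expand stacked blocks into a dense (out, in) matrix with zero off-diagonals."""
    m = len(weight[0])
    n = len(weight) // nblocks
    dense = []
    for i in range(nblocks):
        for row in weight[i * n:(i + 1) * n]:
            dense.append([0] * (i * m) + list(row) + [0] * ((nblocks - 1 - i) * m))
    return dense


def dense_forward(x, dense):
    """Ordinary matrix-vector product."""
    return [sum(w * v for w, v in zip(row, x)) for row in dense]


weight = [[1, 2], [3, 4], [5, 6], [7, 8]]  # out_features=4, in_features=4, nblocks=2
x = [1, 0, 2, -1]
y_blocks = block_diag_forward(x, weight, nblocks=2)
y_dense = dense_forward(x, to_dense(weight, nblocks=2))
```

The stacked layout stores only `out_features * in_features / nblocks` values, which is why the merge path has to expand it to the dense form before adding it to the base weight.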
class BdLoraLinearVariant(LoraVariant):
|
||||
@staticmethod
|
||||
def init(module: Linear, adapter_name: str, config: LoraConfig, **kwargs) -> None:
|
||||
use_bdlora = config.use_bdlora
|
||||
target_name = kwargs.get("target_name", "")
|
||||
|
||||
# Handle case where use_bdlora is a dict (from saved config) instead of BdLoraConfig object
|
||||
if isinstance(use_bdlora, dict):
|
||||
use_bdlora = BdLoraConfig(**use_bdlora)
|
||||
|
||||
lora_a_blockdiagonal_pattern = use_bdlora.target_modules_bd_a or []
|
||||
lora_b_blockdiagonal_pattern = use_bdlora.target_modules_bd_b or []
|
||||
nblocks = use_bdlora.nblocks
|
||||
|
||||
has_lora_a_blockdiagonal = any(pattern in target_name for pattern in lora_a_blockdiagonal_pattern)
|
||||
has_lora_b_blockdiagonal = any(pattern in target_name for pattern in lora_b_blockdiagonal_pattern)
|
||||
|
||||
if has_lora_a_blockdiagonal and has_lora_b_blockdiagonal:
|
||||
raise ValueError(f"Target {target_name} matches both A and B block-diagonal patterns")
|
||||
if use_bdlora.match_strict and not (has_lora_a_blockdiagonal or has_lora_b_blockdiagonal):
|
||||
raise ValueError(
|
||||
f"Target {target_name} matches neither A nor B block-diagonal patterns. "
|
||||
"If this is intentional, set match_strict=False in BdLoraConfig during initialization. "
|
||||
)
|
||||
|
||||
if has_lora_a_blockdiagonal:
|
||||
r = module.lora_A[adapter_name].out_features
|
||||
base_layer = module.get_base_layer()
|
||||
layer = BlockDiagonalLinear(
|
||||
base_layer.in_features,
|
||||
r,
|
||||
nblocks=nblocks,
|
||||
init_zero=False,
|
||||
dtype=base_layer.weight.dtype,
|
||||
device=base_layer.weight.device,
|
||||
)
|
||||
module.lora_A[adapter_name] = layer
|
||||
elif has_lora_b_blockdiagonal:
|
||||
r = module.lora_B[adapter_name].in_features
|
||||
base_layer = module.get_base_layer()
|
||||
layer = BlockDiagonalLinear(
|
||||
r,
|
||||
base_layer.out_features,
|
||||
nblocks=nblocks,
|
||||
init_zero=True,
|
||||
dtype=base_layer.weight.dtype,
|
||||
device=base_layer.weight.device,
|
||||
)
|
||||
module.lora_B[adapter_name] = layer
|
||||
|
||||
@staticmethod
|
||||
def forward(module: Linear, active_adapter: str, x: torch.Tensor, result: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
lora_A = module.lora_A[active_adapter]
|
||||
lora_B = module.lora_B[active_adapter]
|
||||
dropout = module.lora_dropout[active_adapter]
|
||||
scaling = module.scaling[active_adapter]
|
||||
x = dropout(x)
|
||||
# Cast input dtype to match lora_A weight dtype
|
||||
x = module._cast_input_dtype(x, lora_A.weight.dtype)
|
||||
result += lora_B(lora_A(x)) * scaling
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _get_weight_from_module_maybe_blockdiagonal(module: nn.Module) -> torch.Tensor:
|
||||
if isinstance(module, BlockDiagonalLinear):
|
||||
return module.weight_as_blockdiagonal_matrix()
|
||||
else:
|
||||
return module.weight # type: ignore
|
||||
|
||||
@staticmethod
|
||||
def _get_bdlora_delta_weight(module: Linear, adapter: str) -> torch.Tensor:
|
||||
"""Similar to get_delta_weight for a linear module, but we have to eventually reshape the blocks
|
||||
of the weights."""
|
||||
device = module.lora_B[adapter].weight.device
|
||||
# Use base layer dtype to ensure compatibility with merge/unmerge operations
|
||||
base_layer = module.get_base_layer()
|
||||
dtype = base_layer.weight.dtype
|
||||
cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)
|
||||
|
||||
weight_A = BdLoraLinearVariant._get_weight_from_module_maybe_blockdiagonal(module.lora_A[adapter])
|
||||
weight_B = BdLoraLinearVariant._get_weight_from_module_maybe_blockdiagonal(module.lora_B[adapter])
|
||||
|
||||
if cast_to_fp32:
|
||||
weight_A = weight_A.float()
|
||||
weight_B = weight_B.float()
|
||||
|
||||
output_tensor = transpose(weight_B @ weight_A, module.fan_in_fan_out) * module.scaling[adapter]
|
||||
|
||||
if cast_to_fp32:
|
||||
output_tensor = output_tensor.to(dtype=dtype)
|
||||
|
||||
# Ensure output tensor matches base layer dtype
|
||||
return output_tensor.to(dtype=dtype)
|
||||
|
||||
@staticmethod
|
||||
def merge_safe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
return orig_weight + BdLoraLinearVariant._get_bdlora_delta_weight(module, active_adapter)
|
||||
|
||||
@staticmethod
|
||||
def merge_unsafe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> None:
|
||||
orig_weight.data += BdLoraLinearVariant._get_bdlora_delta_weight(module, active_adapter)
|
||||
|
||||
@staticmethod
|
||||
def unmerge(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
return orig_weight - BdLoraLinearVariant._get_bdlora_delta_weight(module, active_adapter)
|
||||
@@ -1,6 +1,6 @@
|
||||
"""DeLoRA: column-normalised A, B, scaled by lambda * ||W||_F / r.
|
||||
|
||||
Bini et al. 2025 https://arxiv.org/abs/2503.18225
|
||||
Bini et al. 2025 (ICLR'25) https://arxiv.org/abs/2503.18225
|
||||
|
||||
Paper Eq. 8: W' = W + (lambda * ||W||_F / r) B Xi A
|
||||
where Xi_{i,i} = 1 / (||b_i|| ||a_i||) makes each rank-1 component unit-norm.
|
||||
@@ -12,7 +12,22 @@ Identity at t=0: paper uses kaiming init for both A and B with `lambda` initiali
|
||||
to 0 (or small) so the effective delta starts near zero. We honour that:
|
||||
default lambda0 == 0 gives bit-identity; user can override via variant_kwargs.
|
||||
|
||||
KNOWN GRADIENT ISSUE (flagged by external review 2026-04-26):
With lambda0=0 the *forward* is the identity, but `A` and `B` get zero gradient on
step 0 (delta = lambda * ..., so d_output/d_A is proportional to lambda). Only
`lora_lambda` moves on the first step. With lambda0>0, `A` and `B` train but
identity is broken. The paper's actual initialization (frozen-copy trick, see Eq. 9)
achieves both; we do NOT implement that here.
|
||||
|
||||
The frozen ||W||_F factor is captured once at init() into a buffer `lora_wnorm`.
|
||||
|
||||
Reference implementations (for review/cross-check):
|
||||
- DeLoRA paper authors (ExplainableML/DeLoRA) -- their fork of peft:
|
||||
https://github.com/ExplainableML/DeLoRA/blob/main/peft/src/peft/tuners/delora.py
|
||||
(offline: docs/refs/orig_delora.py)
|
||||
- peft DeLoRA (upstreamed):
|
||||
https://github.com/huggingface/peft/blob/main/src/peft/tuners/delora/layer.py
|
||||
(offline: docs/refs/peft_delora_layer.py)
|
||||
"""
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
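The zero-gradient note in the DeLoRA docstring above follows from Eq. 8 as quoted there. A short derivation, treating the norm factor as the frozen constant the docstring describes (the symbol \Delta W for the adapter update is introduced here for illustration):

```latex
\begin{align}
\Delta W &= \frac{\lambda\,\lVert W\rVert_F}{r}\; B\,\Xi\,A,
  \qquad \Xi_{ii} = \frac{1}{\lVert b_i\rVert\,\lVert a_i\rVert} \\
\frac{\partial\,\Delta W}{\partial A} &\propto \lambda,
  \qquad
\frac{\partial\,\Delta W}{\partial B} \propto \lambda
  \quad\Longrightarrow\quad \text{both vanish at } \lambda = 0 \\
\frac{\partial\,\Delta W}{\partial \lambda} &= \frac{\lVert W\rVert_F}{r}\; B\,\Xi\,A \;\neq\; 0
\end{align}
```

Every term that depends on A or B, including the normalisation in \Xi, is multiplied by \lambda, so at \lambda = 0 only the \lambda gradient is nonzero on the first step.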
@@ -6,6 +6,17 @@ At t=0: B=0 -> V=W -> y_new = (m_init / ||W||_c) (Wx + 0) = Wx when m_init =
|
||||
|
||||
Limitation: requires materializing the dense weight to compute ||V||_c. v1 supports
|
||||
plain nn.Linear only; bnb 4/8-bit layers raise loudly.
|
||||
|
||||
DEVIATION (numerical):
- We differentiate through ||V||_c on every forward. Sec. 4.3 of the paper suggests
  a 'cost-saving' variant that detaches ||V||_c in backward (gradient flows only
  through V); we do NOT do that. Real impact: a slower step and a slightly different
  gradient direction. Faithful to the Eq. 5 forward, not the optimized one.
|
||||
|
||||
Reference implementations (for review/cross-check):
|
||||
- peft DoRA (separate file under lora/):
|
||||
https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/dora.py
|
||||
(offline: docs/refs/peft_lora_dora.py)
|
||||
"""
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
@@ -13,10 +13,25 @@ Identity at t=0: `lora_gate` is initialized to 0 and gates each Householder
|
||||
vector, so the effective u_i starts at 0 -> H_i = I -> R = I -> y' = y.
|
||||
At training time the gate scales the active reflection direction.
|
||||
|
||||
KNOWN GRADIENT ISSUE (flagged by external review 2026-04-26):
Forward is `x + gate * (Rx - x)`. With gate=0 at init, d_output/d_U is
proportional to gate, so on step 0 ONLY `lora_gate` receives gradient;
`lora_U` is dead. Once gate moves off zero, U starts learning. This deviates
from the paper, which has no such gate -- the paper uses an orthogonal init of U
so R != I from step 0. We trade paper-faithful init for identity-at-init.

OMITTED: the paper also adds an orthogonality regularizer
    lambda * sum_{i != j} (u_i^T u_j)^2   (Eq. 6 / Sec. 3.3)
which is a loss term, not a forward-pass change. Add it in your training loop if
you want the regularized HRA variant.
|
||||
|
||||
Reference implementations (for review/cross-check):
|
||||
- HRA paper authors (DaShenZi721/HRA), llama variant of OFT layer with HRA:
|
||||
https://github.com/DaShenZi721/HRA/blob/master/llama/peft/oft/layer_GS_HRA.py
|
||||
(offline: docs/refs/orig_hra_layer.py)
|
||||
- peft HRA layer (cleaner, includes apply_GS toggle for orthogonalization):
|
||||
https://github.com/huggingface/peft/blob/main/src/peft/tuners/hra/layer.py
|
||||
(offline: docs/refs/peft_hra_layer.py)
|
||||
"""
|
||||
import torch
|
||||
from einops import einsum
|
||||
|
||||
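The dead-gradient note in the HRA docstring above can be checked numerically on a toy case. The sketch below assumes the forward form quoted in that docstring, `x + gate * (Rx - x)`, with a single Householder reflection, and uses central finite differences on plain Python lists instead of autograd. All function names are hypothetical and not part of lora-lite.

```python
def householder_apply(u, x):
    """Reflect x across the hyperplane orthogonal to u: (I - 2 u u^T / u^T u) x."""
    uu = sum(c * c for c in u)
    ux = sum(a * b for a, b in zip(u, x))
    return [xi - 2.0 * ux / uu * ui for xi, ui in zip(x, u)]


def gated_forward(u, gate, x):
    """Gated reflection: x + gate * (Rx - x)."""
    rx = householder_apply(u, x)
    return [xi + gate * (ri - xi) for xi, ri in zip(x, rx)]


def finite_diff_wrt_u0(u, gate, x, eps=1e-3):
    """Central finite difference of the output with respect to u[0]."""
    hi = gated_forward([u[0] + eps] + u[1:], gate, x)
    lo = gated_forward([u[0] - eps] + u[1:], gate, x)
    return [(a - b) / (2 * eps) for a, b in zip(hi, lo)]


u = [1.0, 2.0]
x = [3.0, -1.0]
grad_u_at_zero_gate = finite_diff_wrt_u0(u, 0.0, x)   # gate=0: output ignores u
grad_u_at_half_gate = finite_diff_wrt_u0(u, 0.5, x)   # gate!=0: u gets signal
grad_gate = [r - xi for r, xi in zip(householder_apply(u, x), x)]  # d_out/d_gate = Rx - x
```

At gate=0 the output equals `x` regardless of `u`, so the difference quotient is exactly zero, while the derivative with respect to the gate is `Rx - x`, which is nonzero whenever the reflection moves `x`.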
@@ -16,6 +16,11 @@ DEVIATION FROM PAPER:
|
||||
|
||||
`up_proj` is the closest stand-in for "FFN intermediate" in gated-MLP blocks
|
||||
(Llama uses gate * up; gating the up branch is the IA3-spirit choice).
|
||||
|
||||
Reference implementations (for review/cross-check):
|
||||
- peft IA3 layer (uses ia3_l elementwise scaling, fan_in_fan_out aware):
|
||||
https://github.com/huggingface/peft/blob/main/src/peft/tuners/ia3/layer.py
|
||||
(offline: docs/refs/peft_ia3_layer.py)
|
||||
"""
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
@@ -3,6 +3,11 @@
|
||||
h = W x + (alpha/r) B A x
|
||||
|
||||
Identity at t=0 from B=0. Faithful to the paper.
|
||||
|
||||
Reference implementations (for review/cross-check):
|
||||
- peft Linear.update_layer + lora_A/B init, forward:
|
||||
https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/layer.py
|
||||
(see docs/refs/peft_lora_layer.py for offline copy)
|
||||
"""
|
||||
from einops import einsum
|
||||
from torch import nn
|
||||
|
||||
@@ -2,6 +2,21 @@
|
||||
|
||||
Meng et al. 2024 https://arxiv.org/abs/2404.02948
|
||||
W_eff(t=0) = W_res + B@A = W (numerically; bf16 round-trip not bit-exact).
|
||||
|
||||
DEVIATION FROM PAPER (documented):
- Paper sets adapter scale = 1 (no alpha/r factor); we keep LoRA's alpha/r
  pipeline, so callers must pass alpha=r to get paper-faithful identity.
- Saved adapter does NOT include W_res; load() recomputes the PiSSA init on the
  *same-seed base* before overwriting A/B. Reload is exact only on identical
  base weights.
|
||||
|
||||
Reference implementations (for review/cross-check):
|
||||
- PiSSA original (NeurIPS'24 spotlight) init script (SVD on dequant W):
|
||||
https://github.com/MuLabPKU/PiSSA/blob/main/utils/init_pissa.py
|
||||
(offline: docs/refs/orig_pissa_init.py)
|
||||
- peft PiSSA flavor (init_lora_weights='pissa') in:
|
||||
https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/layer.py
|
||||
(offline: docs/refs/peft_lora_layer.py, see pissa_init / loftq_init paths)
|
||||
"""
|
||||
import torch
|
||||
from einops import einsum
|
||||
|
||||