This commit is contained in:
wassname
2026-04-26 14:10:20 +08:00
parent de97724b65
commit 4db5cee5a9
13 changed files with 969 additions and 0 deletions
+161
View File
@@ -0,0 +1,161 @@
# lora-lite
A hackable, single-file-per-variant LoRA library built on PyTorch forward hooks.
- ~600 LoC total
- One file per variant, ~50 LoC each
- No module replacement, no merge/unmerge, no PEFT config soup
- Save = `torch.save({cfg, state_dict_filtered_by_'lora_'})`
- LoRA/DeLoRA forward hooks work with `nn.Linear` and bnb-style `Linear{4bit,8bitLt}` modules that expose `in_features`, `out_features`, and `weight`.
- PiSSA is fp-only in v1 because it mutates `weight` into `W_res`; quantized PiSSA needs explicit dequantize/requantize.
Currently shipped variants:
| Variant | Class | File |
|---|---|---|
| LoRA | A (additive) | [src/lora_lite/variants/lora.py](src/lora_lite/variants/lora.py) |
| PiSSA ([Meng+ 2024](https://arxiv.org/abs/2404.02948)) | A + B (special init mutates W) | [src/lora_lite/variants/pissa.py](src/lora_lite/variants/pissa.py) |
| DeLoRA ([Bini+ 2025](https://arxiv.org/abs/2503.18225)) | A (additive, normalised) | [src/lora_lite/variants/delora.py](src/lora_lite/variants/delora.py) |
See [docs/spec/20260426_lora_lite_plan.md](docs/spec/20260426_lora_lite_plan.md) for goals, status, TODOs, and the current design plan. The original broader design was stress-tested against the [adapters_as_hypotheses](https://github.com/wassname/adapters_as_hypotheses) catalog (~26/27 variants covered with 3 small API tweaks).
## Install
```bash
pip install -e .
```
## Quickstart
```python
import torch, lora_lite as ll
model = MyTransformer() # any nn.Module containing linear-like children
cfg = ll.LoraLiteConfig(variant="lora", r=8, alpha=16, dtype=torch.bfloat16)
handles = ll.attach(model, cfg)
# train
trainable = [p for p in model.parameters() if p.requires_grad]
opt = torch.optim.AdamW(trainable, lr=1e-4)
# ... your loop ...
ll.save(model, "adapter.pt")
ll.detach(model)
# later:
ll.load(model, "adapter.pt")
```
Inspect a tensor live:
```python
A = model.layers[5].self_attn.q_proj.lora_A # just an nn.Parameter
```
## Targeting
By default we target linear-like modules (`in_features`, `out_features`, `weight`) whose shape matches a "reader" (`d_in == d_model`) or "writer" (`d_out == d_model`) role, excluding `lm_head` and `embed_tokens`. This structural test is what lets bnb Linear4bit/8bitLt modules be targeted without a backend-specific class. Knobs on `LoraLiteConfig`:
- `target_roles`: subset of `("reader", "writer", "inner")`. `()` = all.
- `target_names`: regex includes (must match if non-empty).
- `exclude_names`: regex excludes (default skips `lm_head`, `embed_tokens`).
- `layers`: tuple of layer indices, or `None` for all (matches `.layers.<idx>.` in module name).
## Variant API
A variant is a class with a `name` and three statics:
```python
@register
class MyVariant:
name = "myvariant"
@staticmethod
def param_specs(d_in, d_out, cfg) -> dict[str, ParamSpec]:
return {"lora_A": ParamSpec((cfg.r, d_in), init="kaiming"), ...}
@staticmethod
def init(layer, cfg) -> None:
# Optional. Run after params are created. May read/mutate layer.weight.
...
@staticmethod
def forward(layer, x, y) -> Tensor:
# Return the layer's NEW output (additive: `return y + delta`).
...
```
Adapter params attached as `layer.lora_*` get full-path keys in `state_dict()` automatically (e.g. `model.layers.5.self_attn.q_proj.lora_A`).
## Data-calibrated init
PiSSA, DeLoRA, and LoRA only use `layer.weight` for init -- no calibration data needed.
For variants that DO need data (e.g. AntiPaSTO, LoRA-GA, activation-aware SVD), keep dataloaders out of `cfg` so adapter checkpoints stay serializable. Use:
```python
ll.attach(model, cfg, calibration_data=calib)
```
where `calib` is an iterable of whole-model inputs, e.g. `Iterable[dict[str, Tensor]]` for HF models or `Iterable[Tensor]` of token ids. Activation-aware variants implement:
```python
@staticmethod
def group_init(model, targets, cfg, calibration_data): ...
```
`targets` is `list[(name, layer, role)]`. The variant adds temporary hooks, runs `model(batch)` over `calibration_data`, removes the hooks, then writes `lora_*` params. Per-layer `init(layer, cfg)` stays weight-only.
Sketch:
```python
@register
class ActSVD:
name = "actsvd"
@staticmethod
def param_specs(d_in, d_out, cfg): ...
@staticmethod
def group_init(model, targets, cfg, calibration_data):
bufs = {name: [] for name, _, _ in targets}
hooks = [
layer.register_forward_pre_hook(
lambda m, args, name=name: bufs[name].append(args[0].detach().float())
)
for name, layer, _ in targets
]
try:
with torch.no_grad():
for batch in calibration_data:
model(**batch) if isinstance(batch, dict) else model(batch)
finally:
for h in hooks:
h.remove()
# For each target: X = torch.cat(bufs[name], dim=0); do SVD; write A/B.
```
## Smoke test
```bash
python tests/smoke.py
```
Verifies for each of `lora`, `pissa`, `delora`:
1. Identity at t=0: `max|y_adapter - y_base|` within float tolerance.
2. Save/load round-trip preserves outputs.
3. 20 SGD steps reduce a random regression loss by >5%.
## What's NOT in v1
| Feature | Why dropped |
|---|---|
| merge/unmerge | reload base if you want vanilla |
| 4/8-bit-aware merge | DoRA on bnb supported in forward only (drop merge path) |
| Embedding / Conv adapters | trivial extension; add when needed |
| `adapter_names=` mixed batch forward | rare; add when needed |
| Multiple named adapters per layer | one variant per `attach()` |
| HF `PeftConfig` / hub upload | `torch.save({cfg, state})` is enough |
| AdaLoRA-style rank scheduling | needs `Variant.on_step(step)` -- punt |
| ReFT-style position interventions | sibling submodule (different hook site) |
## Status
v0.0.1: lora + pissa + delora + smoke test. See spec for next variants (DoRA, VeRA, SSVD).
+144
View File
@@ -0,0 +1,144 @@
# lora-lite plan and status
## Goal
Build a small, hackable LoRA-family adapter library for research experiments.
The core bet is that adapter variants should own the relationship between `(x, layer.weight, layer.lora_*)` and the layer output, while the library only handles targeting, parameter attachment, hooks, and save/load.
## Non-goals
- No PEFT compatibility layer.
- No module replacement.
- No merge/unmerge.
- No multiple named adapters per layer.
- No backward compatibility promises.
- No silent fallbacks.
## Design constraints
- Adapter params are attached directly to target layers as `lora_*` parameters.
- Save/load uses normal `state_dict()` keys, filtered by `"lora_"`.
- Forward hooks return the layer's new output, not just a delta.
- Targeting is structural: modules with `in_features`, `out_features`, and `weight` are linear-like.
- LoRA/DeLoRA support bnb-style 4/8-bit forward paths because the quantized base layer computes `y`; the hook only adds adapter math.
- PiSSA is fp-only in v1 because it mutates `layer.weight` into `W_res`.
- Data-calibrated variants use `group_init(model, targets, cfg, calibration_data)`; dataloaders stay out of `cfg` so checkpoints are serializable.
## Implemented v0.0.1
| Area | Status | Evidence |
|---|---:|---|
| `LoraLiteConfig` | done | `src/lora_lite/config.py` |
| Variant registry + `ParamSpec` | done | `src/lora_lite/variant.py` |
| Structural target discovery | done | `src/lora_lite/target.py` |
| `attach` / `detach` / `save` / `load` | done | `src/lora_lite/adapter.py` |
| LoRA | done | `src/lora_lite/variants/lora.py` |
| PiSSA | done, fp-only | `src/lora_lite/variants/pissa.py` |
| DeLoRA | done | `src/lora_lite/variants/delora.py` |
| Smoke tests | done | `tests/smoke.py` |
| bnb minimal forward smoke | done | `Linear8bitLt` and `Linear4bit` pass on CUDA |
## Current smoke evidence
Last verified log: `/home/wassname/.cache/agent-tmp/lora_lite_smoke_after_review.log`
| Check | Result |
|---|---|
| LoRA identity | `0.000e+00` |
| LoRA loss drop | `6.1%` |
| PiSSA identity | `1.550e-06` |
| PiSSA loss drop | `11.5%` |
| DeLoRA identity | `0.000e+00` |
| DeLoRA loss drop | `93.4%` |
| fake non-`nn.Linear` target | attaches, identity `0.000e+00`, grad nonzero |
| bnb `Linear8bitLt` | identity `0.000e+00`, grad nonzero |
| bnb `Linear4bit` | identity `0.000e+00`, grad nonzero |
## Review history
A cold subagent review first returned `PASS_WITH_BLOCKERS`:
1. bnb modules were not targeted.
2. Hook cast `y` to `cfg.dtype`, which could round base outputs.
3. PiSSA overclaimed bnb support.
4. `load()` did not fail on missing adapter keys.
5. Data-calibrated init needed model-level access.
Fixes applied:
1. Structural `is_linear_like()` target predicate.
2. Hook only casts `x`, keeps `y` in base output dtype.
3. PiSSA fail-fast rejects non-plain `nn.Linear`.
4. `load()` fails on missing or unexpected `lora_` keys.
5. `attach(..., calibration_data=None)` plus optional `group_init(model, targets, cfg, calibration_data)`.
Second cold review verdict: `PASS` for the minimal 4bit-enabled scope.
## TODO / status
### Next implementation goals
- [ ] Add DoRA.
- Verify: fp32/bf16 identity at init, finite gradients, and smoke loss drop.
- Caveat: bnb DoRA needs explicit weight dequantization for norm computation or should be fp-only at first.
- [ ] Add VeRA.
- Verify: shared buffers are allocated once, target slices match shape, identity or near-identity at init.
- [ ] Add SSVD or AntiPaSTO-style SVD variant.
- Verify: reconstruction or intended rotation invariant at init.
- [ ] Add real activation-calibrated toy variant using `group_init`.
- Verify: `calibration_data` is consumed during `attach`, hooks are removed, checkpoint is serializable, and `load()` does not require calibration data.
- [ ] Add load path that can skip calibration init for future `group_init` variants.
- Current caveat: `load()` calls `attach(model, cfg)` with `calibration_data=None`; fine for current variants, but future calibrated variants should separate param creation from calibration.
- [ ] Add a tiny HF-model smoke when convenient.
- Verify: target names look like real transformer modules and state dict keys match full paths.
### Design TODOs
- [ ] Decide whether `group_init` should run before or after forward hooks are registered.
- Current choice: after params are attached, before adapter forward hooks are registered.
- [ ] Decide whether replacing variants need `runs_base_layer=False` or can always transform `y`.
- OFT-like variants can rotate `y`; variants that truly avoid base forward need module replacement or pre-hook rewriting, likely out of v1.
- [ ] Add `weight_mode` for BitFit/SHiRA if those variants become in-scope.
- Minimal surface: `weight_mode in {"frozen", "bias_only", "sparse_grad"}`.
## Variant contract
```python
class Variant:
name: str
@staticmethod
def param_specs(d_in, d_out, cfg) -> dict[str, ParamSpec]: ...
@staticmethod
def init(layer, cfg) -> None:
# weight-only init; may mutate plain fp weights
...
@staticmethod
def group_init(model, targets, cfg, calibration_data) -> None:
# optional model-level init for data-calibrated or cross-layer variants
...
@staticmethod
def forward(layer, x, y) -> Tensor:
# return NEW output; additive variants return y + delta
...
```
## Done means
This repo is good enough for a first real experiment when:
1. A Qwen/Llama model can attach LoRA adapters to intended target layers.
2. A 4bit or 8bit loaded model can train LoRA/DeLoRA params with nonzero gradients.
3. Saved adapter tensors use full-path keys and reload without calibration data.
4. Smoke tests distinguish target-skipping, hook identity drift, and missing-key load failure.
+20
View File
@@ -0,0 +1,20 @@
[project]
name = "lora-lite"
version = "0.0.1"
description = "Hackable forward-hook LoRA. ~600 LoC. One file per variant."
requires-python = ">=3.10"
dependencies = [
"torch>=2.1",
"einops>=0.7",
"jaxtyping>=0.2.25",
]
[project.optional-dependencies]
test = ["pytest"]
[build-system]
requires = ["setuptools>=68"]
build-backend = "setuptools.build_meta"
[tool.setuptools.packages.find]
where = ["src"]
+16
View File
@@ -0,0 +1,16 @@
from .config import LoraLiteConfig
from .adapter import attach, detach, save, load
from .variant import REGISTRY, register, ParamSpec, Variant
from . import variants # noqa: F401 triggers variant registration
__all__ = [
"LoraLiteConfig",
"attach",
"detach",
"save",
"load",
"REGISTRY",
"register",
"ParamSpec",
"Variant",
]
+103
View File
@@ -0,0 +1,103 @@
"""attach / detach / save / load. The whole runtime."""
from __future__ import annotations
import torch
from torch import nn
from torch.utils.hooks import RemovableHandle
from .config import LoraLiteConfig
from .variant import REGISTRY
from .target import find_targets
_ATTACHED_ATTR = "_lora_lite_attached"
def _hook(layer, args, y):
(x,) = args
cfg: LoraLiteConfig = layer._lora_cfg
x_cast = x.to(cfg.dtype)
out = layer._lora_variant.forward(layer, x_cast, y)
return out.to(y.dtype)
def attach(model: nn.Module, cfg: LoraLiteConfig, calibration_data=None) -> list[RemovableHandle]:
if cfg.variant not in REGISTRY:
raise KeyError(f"unknown variant {cfg.variant!r}; registered: {list(REGISTRY)}")
variant = REGISTRY[cfg.variant]
targets = find_targets(model, cfg)
if not targets:
raise RuntimeError("no target layers matched cfg")
# freeze base
for p in model.parameters():
p.requires_grad_(False)
handles: list[RemovableHandle] = []
attached_names: list[str] = []
attached_targets = []
for name, layer, role in targets:
d_in, d_out = layer.in_features, layer.out_features
for pname, spec in variant.param_specs(d_in, d_out, cfg).items():
if hasattr(layer, pname):
raise RuntimeError(f"{name} already has attribute {pname}; detach first")
p = spec.make(cfg.dtype, layer.weight.device)
layer.register_parameter(pname, p)
layer._lora_cfg = cfg
layer._lora_variant = variant
layer._lora_role = role
variant.init(layer, cfg)
attached_names.append(name)
attached_targets.append((name, layer, role))
group_init = getattr(variant, "group_init", None)
if group_init is not None:
group_init(model, attached_targets, cfg, calibration_data)
for _, layer, _ in attached_targets:
handles.append(layer.register_forward_hook(_hook))
setattr(model, _ATTACHED_ATTR, {"cfg": cfg, "targets": attached_names, "handles": handles})
return handles
def detach(model: nn.Module) -> None:
state = getattr(model, _ATTACHED_ATTR, None)
if state is None:
return
for h in state["handles"]:
h.remove()
# remove params + scratch attrs
for name, layer in model.named_modules():
if not hasattr(layer, "_lora_variant"):
continue
variant = layer._lora_variant
for pname in variant.param_specs(layer.in_features, layer.out_features, layer._lora_cfg):
if pname in layer._parameters:
del layer._parameters[pname]
for attr in ("_lora_cfg", "_lora_variant", "_lora_role"):
if hasattr(layer, attr):
delattr(layer, attr)
delattr(model, _ATTACHED_ATTR)
def save(model: nn.Module, path: str) -> None:
state = getattr(model, _ATTACHED_ATTR, None)
if state is None:
raise RuntimeError("no adapter attached; call attach() first")
sd = {k: v.detach().cpu() for k, v in model.state_dict().items() if "lora_" in k}
torch.save({"cfg": state["cfg"].to_dict(), "state": sd}, path)
def load(model: nn.Module, path: str) -> list[RemovableHandle]:
blob = torch.load(path, weights_only=True, map_location="cpu")
cfg = LoraLiteConfig.from_dict(blob["cfg"])
handles = attach(model, cfg) # creates empty params with right shapes
missing, unexpected = model.load_state_dict(blob["state"], strict=False)
expected_lora = {k for k in model.state_dict() if "lora_" in k}
missing_lora = sorted(expected_lora.intersection(missing))
if missing_lora:
raise RuntimeError(f"missing lora keys in checkpoint: {missing_lora}")
unexpected_lora = [k for k in unexpected if "lora_" in k]
if unexpected_lora:
raise RuntimeError(f"unexpected lora keys in checkpoint: {unexpected_lora}")
return handles
+38
View File
@@ -0,0 +1,38 @@
from dataclasses import dataclass, field, asdict
from typing import Any
import torch
@dataclass
class LoraLiteConfig:
variant: str = "lora"
r: int = 8
alpha: float = 16.0
dropout: float = 0.0 # currently ignored; variants may use cfg.variant_kwargs
dtype: torch.dtype = torch.bfloat16
# targeting
target_roles: tuple[str, ...] = ("reader", "writer")
target_names: tuple[str, ...] = ()
exclude_names: tuple[str, ...] = ("lm_head", "embed_tokens")
layers: tuple[int, ...] | None = None
# variant-specific bag (e.g. lambda0 for DeLoRA)
variant_kwargs: dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> dict:
d = asdict(self)
d["dtype"] = str(self.dtype).removeprefix("torch.")
return d
@classmethod
def from_dict(cls, d: dict) -> "LoraLiteConfig":
d = dict(d)
if isinstance(d.get("dtype"), str):
d["dtype"] = getattr(torch, d["dtype"])
if isinstance(d.get("layers"), list):
d["layers"] = tuple(d["layers"])
for k in ("target_roles", "target_names", "exclude_names"):
if isinstance(d.get(k), list):
d[k] = tuple(d[k])
return cls(**d)
+59
View File
@@ -0,0 +1,59 @@
"""Find linear-like targets by shape (reader/writer/inner) + name regex.
Structural matching is deliberate: bnb Linear4bit/8bitLt are not nn.Linear, but
they expose in_features/out_features/weight and their forward already handles
dequantization.
"""
import re
from torch import nn
def is_linear_like(m: nn.Module) -> bool:
return (
hasattr(m, "in_features")
and hasattr(m, "out_features")
and hasattr(m, "weight")
and callable(m)
)
def _layer_idx(name: str) -> int | None:
m = re.search(r"\.layers?\.(\d+)\.", name)
return int(m.group(1)) if m else None
def _classify(m: nn.Module, d_model: int, name: str) -> str:
di, do = m.in_features, m.out_features
if di == d_model and do != d_model:
return "reader"
if do == d_model and di != d_model:
return "writer"
if di == d_model and do == d_model:
return "writer" if any(s in name for s in ("o_proj", "out_proj")) else "reader"
return "inner"
def find_targets(model: nn.Module, cfg) -> list[tuple[str, nn.Module, str]]:
# discover d_model: prefer config.hidden_size, else infer from largest Linear in_features
d_model = getattr(getattr(model, "config", None), "hidden_size", None)
if d_model is None:
dims = [m.in_features for m in model.modules() if is_linear_like(m)]
d_model = max(dims) if dims else 0
out = []
for name, m in model.named_modules():
if not is_linear_like(m):
continue
if any(re.search(p, name) for p in cfg.exclude_names):
continue
if cfg.layers is not None:
li = _layer_idx(name)
if li is None or li not in cfg.layers:
continue
role = _classify(m, d_model, name)
if cfg.target_roles and role not in cfg.target_roles:
continue
if cfg.target_names and not any(re.search(p, name) for p in cfg.target_names):
continue
out.append((name, m, role))
return out
+56
View File
@@ -0,0 +1,56 @@
"""Variant protocol + registry. Variants own (x, layer.weight, layer.lora_*) -> y_new."""
from dataclasses import dataclass
from typing import Callable, Protocol, Any
import torch
from torch import nn
from .config import LoraLiteConfig
@dataclass
class ParamSpec:
shape: tuple[int, ...]
init: str | Callable[[torch.Tensor], None] = "zeros" # 'zeros'|'kaiming'|'ones'|callable(t)
trainable: bool = True
def make(self, dtype: torch.dtype, device) -> nn.Parameter:
t = torch.empty(self.shape, dtype=dtype, device=device)
if callable(self.init):
self.init(t)
elif self.init == "zeros":
t.zero_()
elif self.init == "ones":
t.fill_(1.0)
elif self.init == "kaiming":
# match nn.Linear default: kaiming_uniform_(a=sqrt(5))
nn.init.kaiming_uniform_(t, a=5 ** 0.5) if t.ndim >= 2 else t.normal_(0, 0.02)
else:
raise ValueError(f"unknown init: {self.init}")
return nn.Parameter(t, requires_grad=self.trainable)
class Variant(Protocol):
name: str
@staticmethod
def param_specs(d_in: int, d_out: int, cfg: LoraLiteConfig) -> dict[str, ParamSpec]: ...
@staticmethod
def init(layer: nn.Linear, cfg: LoraLiteConfig) -> None: ...
@staticmethod
def forward(layer: nn.Linear, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""Return the layer's NEW output (not just delta).
Additive variants: `return y + delta`.
Replacing variants: ignore `y`, return new value."""
...
REGISTRY: dict[str, type] = {}
def register(cls):
if not getattr(cls, "name", None):
raise ValueError(f"variant {cls} missing .name")
REGISTRY[cls.name] = cls
return cls
+1
View File
@@ -0,0 +1 @@
from . import lora, pissa, delora # noqa: F401 side-effect: register
+43
View File
@@ -0,0 +1,43 @@
"""DeLoRA: column-normalised A, B, scaled by lambda/r. Bini et al. 2025 arXiv:2503.18225.
NOTE on identity at t=0: paper uses kaiming for both A and B with a learned lambda
init at 0 (or small) so the effective delta starts near zero. We honour that:
default lambda0 == 0 gives bit-identity; user can override via variant_kwargs.
"""
import torch
import torch.nn.functional as F
from einops import einsum
from torch import nn
from ..variant import register, ParamSpec
@register
class DeLoRA:
name = "delora"
@staticmethod
def param_specs(d_in, d_out, cfg):
lam0 = float(cfg.variant_kwargs.get("lambda0", 0.0))
return {
"lora_A": ParamSpec((cfg.r, d_in), init="kaiming", trainable=True),
"lora_B": ParamSpec((d_out, cfg.r), init="kaiming", trainable=True),
"lora_lambda": ParamSpec(
(), init=lambda t: t.fill_(lam0), trainable=True
),
}
@staticmethod
def init(layer: nn.Linear, cfg) -> None:
return
@staticmethod
def forward(layer: nn.Linear, x, y):
cfg = layer._lora_cfg
# rows of A unit, cols of B unit (per paper)
A = F.normalize(layer.lora_A, dim=1) # (r, d_in)
B = F.normalize(layer.lora_B, dim=0) # (d_out, r)
scale = layer.lora_lambda / cfg.r
h = einsum(x, A, "... i, r i -> ... r")
delta = einsum(h, B, "... r, o r -> ... o")
return y + scale * delta
+31
View File
@@ -0,0 +1,31 @@
"""Vanilla LoRA. Reference variant. y = Wx + (alpha/r) * B @ A @ x."""
from einops import einsum
from torch import nn
import torch
from ..variant import register, ParamSpec
@register
class LoRA:
name = "lora"
@staticmethod
def param_specs(d_in, d_out, cfg):
return {
"lora_A": ParamSpec((cfg.r, d_in), init="kaiming", trainable=True),
"lora_B": ParamSpec((d_out, cfg.r), init="zeros", trainable=True),
}
@staticmethod
def init(layer: nn.Linear, cfg) -> None:
# B is zeros => delta=0 at t=0; identity invariant holds.
return
@staticmethod
def forward(layer: nn.Linear, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
cfg = layer._lora_cfg
scale = cfg.alpha / cfg.r
h = einsum(x, layer.lora_A, "... i, r i -> ... r")
delta = einsum(h, layer.lora_B, "... r, o r -> ... o")
return y + scale * delta
+54
View File
@@ -0,0 +1,54 @@
"""PiSSA: top-r SVD of W into A,B; replace W with W_res = W - B@A.
Meng et al. 2024 https://arxiv.org/abs/2404.02948
W_eff(t=0) = W_res + B@A = W (numerically; bf16 round-trip not bit-exact).
"""
import torch
from einops import einsum
from torch import nn
from ..variant import register, ParamSpec
@register
class PiSSA:
name = "pissa"
@staticmethod
def param_specs(d_in, d_out, cfg):
return {
"lora_A": ParamSpec((cfg.r, d_in), init="zeros", trainable=True),
"lora_B": ParamSpec((d_out, cfg.r), init="zeros", trainable=True),
}
@staticmethod
def init(layer: nn.Linear, cfg) -> None:
if type(layer) is not nn.Linear:
raise TypeError(
"PiSSA mutates layer.weight into W_res, so v1 only supports plain nn.Linear. "
"For bnb 4/8-bit, use LoRA/DeLoRA or implement explicit dequantize/requantize."
)
W = layer.weight.data.float() # (d_out, d_in)
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
r = cfg.r
Ur, Sr, Vhr = U[:, :r], S[:r], Vh[:r, :]
sqrtS = Sr.sqrt()
# B @ A = Ur diag(Sr) Vhr; pick B = Ur sqrt(Sr), A = sqrt(Sr) * Vhr
B = (Ur * sqrtS).to(cfg.dtype)
A = (sqrtS[:, None] * Vhr).to(cfg.dtype)
layer.lora_B.data.copy_(B)
layer.lora_A.data.copy_(A)
# Compute BA in fp32 for the subtraction so W_res is accurate.
BA = (B.float() @ A.float())
# NOTE: PiSSA uses scale=1 (alpha==r) implicitly via init. We let the user pick;
# for fidelity at t=0, the convention is alpha==r so scale==1. Document in README.
scale = cfg.alpha / cfg.r
layer.weight.data.copy_((W - scale * BA).to(layer.weight.dtype))
@staticmethod
def forward(layer: nn.Linear, x, y):
cfg = layer._lora_cfg
scale = cfg.alpha / cfg.r
h = einsum(x, layer.lora_A, "... i, r i -> ... r")
delta = einsum(h, layer.lora_B, "... r, o r -> ... o")
return y + scale * delta
+243
View File
@@ -0,0 +1,243 @@
"""Smoke test: lora / pissa / delora on a tiny synthetic transformer-like model.
Verifies:
1. Identity at t=0 (delta ~ 0, output close to base).
2. Save/load round-trip preserves outputs.
3. A few SGD steps reduce a random loss (gradients flow).
Run:
cd lora-lite
python -m pip install -e .
python tests/smoke.py
BLUF format:
SHOULD: max|y_adapter - y_base| < tol_init for all variants. ELSE init or hook bug.
SHOULD: loss decreases > 5% over 20 SGD steps for all variants. ELSE grad/wiring bug.
"""
from __future__ import annotations
import tempfile, os, sys, math
import torch
from torch import nn
# allow running as `python tests/smoke.py` without install
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
import lora_lite as ll # noqa: E402
# ---- a tiny transformer-like stack: 4 blocks of (q,k,v,o, gate,up,down) Linears ----
class TinyBlock(nn.Module):
def __init__(self, d=64, ff=128):
super().__init__()
self.q_proj = nn.Linear(d, d, bias=False)
self.k_proj = nn.Linear(d, d, bias=False)
self.v_proj = nn.Linear(d, d, bias=False)
self.o_proj = nn.Linear(d, d, bias=False)
self.gate_proj = nn.Linear(d, ff, bias=False)
self.up_proj = nn.Linear(d, ff, bias=False)
self.down_proj = nn.Linear(ff, d, bias=False)
def forward(self, x):
h = self.o_proj(self.q_proj(x) + self.k_proj(x) + self.v_proj(x))
m = self.down_proj(torch.nn.functional.silu(self.gate_proj(x)) * self.up_proj(x))
return x + h + m
class TinyModel(nn.Module):
def __init__(self, n_layers=4, d=64, ff=128, vocab=100):
super().__init__()
self.embed_tokens = nn.Embedding(vocab, d)
self.layers = nn.ModuleList([TinyBlock(d, ff) for _ in range(n_layers)])
self.lm_head = nn.Linear(d, vocab, bias=False)
class Cfg: # mimic HF .config.hidden_size
hidden_size = d
self.config = Cfg()
def forward(self, ids):
x = self.embed_tokens(ids)
for blk in self.layers:
x = blk(x)
return self.lm_head(x)
class FakeLinearLike(nn.Module):
"""Not nn.Linear, but structurally bnb-like enough for target discovery."""
def __init__(self, d_in=8, d_out=8):
super().__init__()
self.in_features = d_in
self.out_features = d_out
self.weight = nn.Parameter(torch.empty(d_out, d_in))
nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5)
def forward(self, x):
return torch.nn.functional.linear(x, self.weight)
class FakeBnbModel(nn.Module):
def __init__(self):
super().__init__()
self.config = type("Cfg", (), {"hidden_size": 8})()
self.layers = nn.ModuleList([FakeLinearLike(8, 8)])
def forward(self, x):
return self.layers[0](x)
def variant_test(variant: str, dtype=torch.float32):
print(f"\n=== variant={variant} dtype={dtype} ===")
torch.manual_seed(0)
model = TinyModel().to(dtype)
ids = torch.randint(0, 100, (2, 16))
with torch.no_grad():
y_base = model(ids).clone()
cfg = ll.LoraLiteConfig(
variant=variant,
r=4,
alpha=4 if variant == "pissa" else 8, # PiSSA needs scale==1 for clean recon
dtype=dtype,
# delora identity-at-init demands lambda0==0 (then delta * scale = 0)
variant_kwargs={"lambda0": 0.0} if variant == "delora" else {},
)
handles = ll.attach(model, cfg)
n_targets = len(handles)
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f" attached {n_targets} targets, trainable params={n_trainable}")
with torch.no_grad():
y_adapt = model(ids)
err = (y_adapt - y_base).abs().max().item()
base_scale = y_base.abs().max().item()
print(f" t=0 identity: max|y_adapt - y_base| = {err:.3e} (base scale {base_scale:.3e})")
# variant-specific identity tolerance
tol = {
"lora": 1e-6,
"pissa": 5e-4, # SVD recon in fp32 is tight; bf16 would be ~1e-2
"delora": 1e-6, # lambda0=0
}[variant] * max(1.0, base_scale)
assert err < tol, f" FAIL identity: err {err} > tol {tol}"
print(f" SHOULD: err<{tol:.1e}. PASS.")
# save/load round-trip
with tempfile.TemporaryDirectory() as d:
p = os.path.join(d, "adapter.pt")
ll.save(model, p)
# detach + fresh model + load
ll.detach(model)
torch.manual_seed(0)
model2 = TinyModel().to(dtype)
# for PiSSA, base weights got mutated; we need them mutated again for the load
# path to make sense. Easiest: re-attach with same cfg first... but that's what
# load() does. The catch: load reads cfg from the file, runs attach (which
# re-runs PiSSA init -> same SVD on same weights -> same A,B -> mutates W
# to the same W_res). Then state_dict overwrites lora_A/B with saved values.
ll.load(model2, p)
with torch.no_grad():
y_loaded = model2(ids)
err2 = (y_loaded - y_adapt).abs().max().item()
print(f" save/load: max|y_loaded - y_adapt| = {err2:.3e}")
assert err2 < tol, f" FAIL save/load: {err2} > {tol}"
print(f" SHOULD: err2<{tol:.1e}. PASS.")
ll.detach(model2)
# gradient flow: 20 SGD steps on random target.
# For delora, lambda0==0 makes A,B grads zero (scale=0); use lambda0=0.1 for training.
torch.manual_seed(0)
model = TinyModel().to(dtype)
train_cfg = cfg
if variant == "delora":
train_cfg = ll.LoraLiteConfig(
variant=cfg.variant, r=cfg.r, alpha=cfg.alpha, dtype=cfg.dtype,
variant_kwargs={"lambda0": 0.1},
)
ll.attach(model, train_cfg)
target = torch.randn(2, 16, 100, dtype=dtype) * 0.1
trainable = [p for p in model.parameters() if p.requires_grad]
# delora has tightly-normalised updates; use Adam with higher lr to see signal in 20 steps
if variant == "delora":
opt = torch.optim.Adam(trainable, lr=1e-1)
else:
opt = torch.optim.SGD(trainable, lr=1e-2)
losses = []
for step in range(20):
opt.zero_grad()
loss = (model(ids) - target).pow(2).mean()
loss.backward()
opt.step()
losses.append(loss.item())
drop = (losses[0] - losses[-1]) / max(losses[0], 1e-12)
print(f" loss[0]={losses[0]:.4f} loss[-1]={losses[-1]:.4f} drop={100*drop:.1f}%")
assert drop > 0.05, f" FAIL: loss drop only {drop:.2%}, expected >5%"
print(f" SHOULD: drop>5%. PASS.")
def structural_linear_like_test():
print("\n=== structural linear-like target test (bnb-style, not nn.Linear) ===")
torch.manual_seed(0)
model = FakeBnbModel()
x = torch.randn(2, 3, 8)
y_base = model(x).detach()
ll.attach(model, ll.LoraLiteConfig(variant="lora", r=2, alpha=4, dtype=torch.float32, target_roles=()))
layer = model.layers[0]
assert hasattr(layer, "lora_A") and hasattr(layer, "lora_B")
y = model(x)
err = (y.detach() - y_base).abs().max().item()
loss = y.pow(2).mean()
loss.backward()
grad_nonzero = layer.lora_B.grad.abs().sum().item() > 0
print(f" attached lora_A={tuple(layer.lora_A.shape)} lora_B={tuple(layer.lora_B.shape)}")
print(f" identity_err={err:.3e} grad_nonzero={grad_nonzero}")
assert err == 0.0
assert grad_nonzero
print(" SHOULD: structural target attaches and lora_B receives grad. PASS.")
def bitsandbytes_cuda_smoke():
print("\n=== optional bitsandbytes CUDA smoke ===")
if not torch.cuda.is_available():
print(" SKIP: CUDA unavailable; real bnb 4/8-bit forward needs GPU on this machine.")
return
try:
import bitsandbytes as bnb
except ImportError:
print(" SKIP: bitsandbytes unavailable.")
return
class BnbModel(nn.Module):
def __init__(self, Layer):
super().__init__()
self.config = type("Cfg", (), {"hidden_size": 8})()
self.layers = nn.ModuleList([Layer(8, 8, bias=False)]).cuda()
def forward(self, x):
return self.layers[0](x)
for layer_cls in (bnb.nn.Linear8bitLt, bnb.nn.Linear4bit):
torch.manual_seed(0)
model = BnbModel(layer_cls)
x = torch.randn(2, 3, 8, device="cuda")
y_base = model(x).detach()
ll.attach(model, ll.LoraLiteConfig(variant="lora", r=2, alpha=4, dtype=torch.float16, target_roles=()))
y = model(x)
err = (y.detach() - y_base).abs().max().item()
y.pow(2).mean().backward()
grad_nonzero = model.layers[0].lora_B.grad.abs().sum().item() > 0
print(f" {layer_cls.__name__}: identity_err={err:.3e} grad_nonzero={grad_nonzero}")
assert err == 0.0
assert grad_nonzero
def main():
for v in ("lora", "pissa", "delora"):
variant_test(v, dtype=torch.float32)
structural_linear_like_test()
bitsandbytes_cuda_smoke()
print("\nALL PASS.")
if __name__ == "__main__":
main()