mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 16:15:50 +08:00
init
This commit is contained in:
@@ -0,0 +1,161 @@
|
||||
# lora-lite
|
||||
|
||||
A hackable, single-file-per-variant LoRA library built on PyTorch forward hooks.
|
||||
|
||||
- ~600 LoC total
|
||||
- One file per variant, ~50 LoC each
|
||||
- No module replacement, no merge/unmerge, no PEFT config soup
|
||||
- Save = `torch.save({cfg, state_dict_filtered_by_'lora_'})`
|
||||
- LoRA/DeLoRA forward hooks work with `nn.Linear` and bnb-style `Linear{4bit,8bitLt}` modules that expose `in_features`, `out_features`, and `weight`.
|
||||
- PiSSA is fp-only in v1 because it mutates `weight` into `W_res`; quantized PiSSA needs explicit dequantize/requantize.
|
||||
|
||||
Currently shipped variants:
|
||||
|
||||
| Variant | Class | File |
|
||||
|---|---|---|
|
||||
| LoRA | A (additive) | [src/lora_lite/variants/lora.py](src/lora_lite/variants/lora.py) |
|
||||
| PiSSA ([Meng+ 2024](https://arxiv.org/abs/2404.02948)) | A + B (special init mutates W) | [src/lora_lite/variants/pissa.py](src/lora_lite/variants/pissa.py) |
|
||||
| DeLoRA ([Bini+ 2025](https://arxiv.org/abs/2503.18225)) | A (additive, normalised) | [src/lora_lite/variants/delora.py](src/lora_lite/variants/delora.py) |
|
||||
|
||||
See [docs/spec/20260426_lora_lite_plan.md](docs/spec/20260426_lora_lite_plan.md) for goals, status, TODOs, and the current design plan. The original broader design was stress-tested against the [adapters_as_hypotheses](https://github.com/wassname/adapters_as_hypotheses) catalog (~26/27 variants covered with 3 small API tweaks).
|
||||
|
||||
## Install
|
||||
|
||||
```bash
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
## Quickstart
|
||||
|
||||
```python
|
||||
import torch, lora_lite as ll
|
||||
|
||||
model = MyTransformer() # any nn.Module containing linear-like children
|
||||
cfg = ll.LoraLiteConfig(variant="lora", r=8, alpha=16, dtype=torch.bfloat16)
|
||||
handles = ll.attach(model, cfg)
|
||||
|
||||
# train
|
||||
trainable = [p for p in model.parameters() if p.requires_grad]
|
||||
opt = torch.optim.AdamW(trainable, lr=1e-4)
|
||||
# ... your loop ...
|
||||
|
||||
ll.save(model, "adapter.pt")
|
||||
ll.detach(model)
|
||||
# later:
|
||||
ll.load(model, "adapter.pt")
|
||||
```
|
||||
|
||||
Inspect a tensor live:
|
||||
|
||||
```python
|
||||
A = model.layers[5].self_attn.q_proj.lora_A # just an nn.Parameter
|
||||
```
|
||||
|
||||
## Targeting
|
||||
|
||||
By default we target linear-like modules (`in_features`, `out_features`, `weight`) whose shape matches a "reader" (`d_in == d_model`) or "writer" (`d_out == d_model`) role, excluding `lm_head` and `embed_tokens`. This structural test is what lets bnb Linear4bit/8bitLt modules be targeted without a backend-specific class. Knobs on `LoraLiteConfig`:
|
||||
|
||||
- `target_roles`: subset of `("reader", "writer", "inner")`. `()` = all.
|
||||
- `target_names`: regex includes (must match if non-empty).
|
||||
- `exclude_names`: regex excludes (default skips `lm_head`, `embed_tokens`).
|
||||
- `layers`: tuple of layer indices, or `None` for all (matches `.layers.<idx>.` in module name).
|
||||
|
||||
## Variant API
|
||||
|
||||
A variant is a class with a `name` and three statics:
|
||||
|
||||
```python
|
||||
@register
|
||||
class MyVariant:
|
||||
name = "myvariant"
|
||||
|
||||
@staticmethod
|
||||
def param_specs(d_in, d_out, cfg) -> dict[str, ParamSpec]:
|
||||
return {"lora_A": ParamSpec((cfg.r, d_in), init="kaiming"), ...}
|
||||
|
||||
@staticmethod
|
||||
def init(layer, cfg) -> None:
|
||||
# Optional. Run after params are created. May read/mutate layer.weight.
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
def forward(layer, x, y) -> Tensor:
|
||||
# Return the layer's NEW output (additive: `return y + delta`).
|
||||
...
|
||||
```
|
||||
|
||||
Adapter params attached as `layer.lora_*` get full-path keys in `state_dict()` automatically (e.g. `model.layers.5.self_attn.q_proj.lora_A`).
|
||||
|
||||
## Data-calibrated init
|
||||
|
||||
PiSSA, DeLoRA, and LoRA only use `layer.weight` for init -- no calibration data needed.
|
||||
|
||||
For variants that DO need data (e.g. AntiPaSTO, LoRA-GA, activation-aware SVD), keep dataloaders out of `cfg` so adapter checkpoints stay serializable. Use:
|
||||
|
||||
```python
|
||||
ll.attach(model, cfg, calibration_data=calib)
|
||||
```
|
||||
|
||||
where `calib` is an iterable of whole-model inputs, e.g. `Iterable[dict[str, Tensor]]` for HF models or `Iterable[Tensor]` of token ids. Activation-aware variants implement:
|
||||
|
||||
```python
|
||||
@staticmethod
|
||||
def group_init(model, targets, cfg, calibration_data): ...
|
||||
```
|
||||
|
||||
`targets` is `list[(name, layer, role)]`. The variant adds temporary hooks, runs `model(batch)` over `calibration_data`, removes the hooks, then writes `lora_*` params. Per-layer `init(layer, cfg)` stays weight-only.
|
||||
|
||||
Sketch:
|
||||
|
||||
```python
|
||||
@register
|
||||
class ActSVD:
|
||||
name = "actsvd"
|
||||
@staticmethod
|
||||
def param_specs(d_in, d_out, cfg): ...
|
||||
@staticmethod
|
||||
def group_init(model, targets, cfg, calibration_data):
|
||||
bufs = {name: [] for name, _, _ in targets}
|
||||
hooks = [
|
||||
layer.register_forward_pre_hook(
|
||||
lambda m, args, name=name: bufs[name].append(args[0].detach().float())
|
||||
)
|
||||
for name, layer, _ in targets
|
||||
]
|
||||
try:
|
||||
with torch.no_grad():
|
||||
for batch in calibration_data:
|
||||
model(**batch) if isinstance(batch, dict) else model(batch)
|
||||
finally:
|
||||
for h in hooks:
|
||||
h.remove()
|
||||
# For each target: X = torch.cat(bufs[name], dim=0); do SVD; write A/B.
|
||||
```
|
||||
|
||||
## Smoke test
|
||||
|
||||
```bash
|
||||
python tests/smoke.py
|
||||
```
|
||||
|
||||
Verifies for each of `lora`, `pissa`, `delora`:
|
||||
1. Identity at t=0: `max|y_adapter - y_base|` within float tolerance.
|
||||
2. Save/load round-trip preserves outputs.
|
||||
3. 20 SGD steps reduce a random regression loss by >5%.
|
||||
|
||||
## What's NOT in v1
|
||||
|
||||
| Feature | Why dropped |
|
||||
|---|---|
|
||||
| merge/unmerge | reload base if you want vanilla |
|
||||
| 4/8-bit-aware merge | DoRA on bnb supported in forward only (drop merge path) |
|
||||
| Embedding / Conv adapters | trivial extension; add when needed |
|
||||
| `adapter_names=` mixed batch forward | rare; add when needed |
|
||||
| Multiple named adapters per layer | one variant per `attach()` |
|
||||
| HF `PeftConfig` / hub upload | `torch.save({cfg, state})` is enough |
|
||||
| AdaLoRA-style rank scheduling | needs `Variant.on_step(step)` -- punt |
|
||||
| ReFT-style position interventions | sibling submodule (different hook site) |
|
||||
|
||||
## Status
|
||||
|
||||
v0.0.1: lora + pissa + delora + smoke test. See spec for next variants (DoRA, VeRA, SSVD).
|
||||
@@ -0,0 +1,144 @@
|
||||
# lora-lite plan and status
|
||||
|
||||
## Goal
|
||||
|
||||
Build a small, hackable LoRA-family adapter library for research experiments.
|
||||
|
||||
The core bet is that adapter variants should own the relationship between `(x, layer.weight, layer.lora_*)` and the layer output, while the library only handles targeting, parameter attachment, hooks, and save/load.
|
||||
|
||||
## Non-goals
|
||||
|
||||
- No PEFT compatibility layer.
|
||||
- No module replacement.
|
||||
- No merge/unmerge.
|
||||
- No multiple named adapters per layer.
|
||||
- No backward compatibility promises.
|
||||
- No silent fallbacks.
|
||||
|
||||
## Design constraints
|
||||
|
||||
- Adapter params are attached directly to target layers as `lora_*` parameters.
|
||||
- Save/load uses normal `state_dict()` keys, filtered by `"lora_"`.
|
||||
- Forward hooks return the layer's new output, not just a delta.
|
||||
- Targeting is structural: modules with `in_features`, `out_features`, and `weight` are linear-like.
|
||||
- LoRA/DeLoRA support bnb-style 4/8-bit forward paths because the quantized base layer computes `y`; the hook only adds adapter math.
|
||||
- PiSSA is fp-only in v1 because it mutates `layer.weight` into `W_res`.
|
||||
- Data-calibrated variants use `group_init(model, targets, cfg, calibration_data)`; dataloaders stay out of `cfg` so checkpoints are serializable.
|
||||
|
||||
## Implemented v0.0.1
|
||||
|
||||
| Area | Status | Evidence |
|
||||
|---|---:|---|
|
||||
| `LoraLiteConfig` | done | `src/lora_lite/config.py` |
|
||||
| Variant registry + `ParamSpec` | done | `src/lora_lite/variant.py` |
|
||||
| Structural target discovery | done | `src/lora_lite/target.py` |
|
||||
| `attach` / `detach` / `save` / `load` | done | `src/lora_lite/adapter.py` |
|
||||
| LoRA | done | `src/lora_lite/variants/lora.py` |
|
||||
| PiSSA | done, fp-only | `src/lora_lite/variants/pissa.py` |
|
||||
| DeLoRA | done | `src/lora_lite/variants/delora.py` |
|
||||
| Smoke tests | done | `tests/smoke.py` |
|
||||
| bnb minimal forward smoke | done | `Linear8bitLt` and `Linear4bit` pass on CUDA |
|
||||
|
||||
## Current smoke evidence
|
||||
|
||||
Last verified log: `/home/wassname/.cache/agent-tmp/lora_lite_smoke_after_review.log`
|
||||
|
||||
| Check | Result |
|
||||
|---|---|
|
||||
| LoRA identity | `0.000e+00` |
|
||||
| LoRA loss drop | `6.1%` |
|
||||
| PiSSA identity | `1.550e-06` |
|
||||
| PiSSA loss drop | `11.5%` |
|
||||
| DeLoRA identity | `0.000e+00` |
|
||||
| DeLoRA loss drop | `93.4%` |
|
||||
| fake non-`nn.Linear` target | attaches, identity `0.000e+00`, grad nonzero |
|
||||
| bnb `Linear8bitLt` | identity `0.000e+00`, grad nonzero |
|
||||
| bnb `Linear4bit` | identity `0.000e+00`, grad nonzero |
|
||||
|
||||
## Review history
|
||||
|
||||
A cold subagent review first returned `PASS_WITH_BLOCKERS`:
|
||||
|
||||
1. bnb modules were not targeted.
|
||||
2. Hook cast `y` to `cfg.dtype`, which could round base outputs.
|
||||
3. PiSSA overclaimed bnb support.
|
||||
4. `load()` did not fail on missing adapter keys.
|
||||
5. Data-calibrated init needed model-level access.
|
||||
|
||||
Fixes applied:
|
||||
|
||||
1. Structural `is_linear_like()` target predicate.
|
||||
2. Hook only casts `x`, keeps `y` in base output dtype.
|
||||
3. PiSSA fail-fast rejects non-plain `nn.Linear`.
|
||||
4. `load()` fails on missing or unexpected `lora_` keys.
|
||||
5. `attach(..., calibration_data=None)` plus optional `group_init(model, targets, cfg, calibration_data)`.
|
||||
|
||||
Second cold review verdict: `PASS` for the minimal 4bit-enabled scope.
|
||||
|
||||
## TODO / status
|
||||
|
||||
### Next implementation goals
|
||||
|
||||
- [ ] Add DoRA.
|
||||
- Verify: fp32/bf16 identity at init, finite gradients, and smoke loss drop.
|
||||
- Caveat: bnb DoRA needs explicit weight dequantization for norm computation or should be fp-only at first.
|
||||
|
||||
- [ ] Add VeRA.
|
||||
- Verify: shared buffers are allocated once, target slices match shape, identity or near-identity at init.
|
||||
|
||||
- [ ] Add SSVD or AntiPaSTO-style SVD variant.
|
||||
- Verify: reconstruction or intended rotation invariant at init.
|
||||
|
||||
- [ ] Add real activation-calibrated toy variant using `group_init`.
|
||||
- Verify: `calibration_data` is consumed during `attach`, hooks are removed, checkpoint is serializable, and `load()` does not require calibration data.
|
||||
|
||||
- [ ] Add load path that can skip calibration init for future `group_init` variants.
|
||||
- Current caveat: `load()` calls `attach(model, cfg)` with `calibration_data=None`; fine for current variants, but future calibrated variants should separate param creation from calibration.
|
||||
|
||||
- [ ] Add a tiny HF-model smoke when convenient.
|
||||
- Verify: target names look like real transformer modules and state dict keys match full paths.
|
||||
|
||||
### Design TODOs
|
||||
|
||||
- [ ] Decide whether `group_init` should run before or after forward hooks are registered.
|
||||
- Current choice: after params are attached, before adapter forward hooks are registered.
|
||||
|
||||
- [ ] Decide whether replacing variants need `runs_base_layer=False` or can always transform `y`.
|
||||
- OFT-like variants can rotate `y`; variants that truly avoid base forward need module replacement or pre-hook rewriting, likely out of v1.
|
||||
|
||||
- [ ] Add `weight_mode` for BitFit/SHiRA if those variants become in-scope.
|
||||
- Minimal surface: `weight_mode in {"frozen", "bias_only", "sparse_grad"}`.
|
||||
|
||||
## Variant contract
|
||||
|
||||
```python
|
||||
class Variant:
|
||||
name: str
|
||||
|
||||
@staticmethod
|
||||
def param_specs(d_in, d_out, cfg) -> dict[str, ParamSpec]: ...
|
||||
|
||||
@staticmethod
|
||||
def init(layer, cfg) -> None:
|
||||
# weight-only init; may mutate plain fp weights
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
def group_init(model, targets, cfg, calibration_data) -> None:
|
||||
# optional model-level init for data-calibrated or cross-layer variants
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
def forward(layer, x, y) -> Tensor:
|
||||
# return NEW output; additive variants return y + delta
|
||||
...
|
||||
```
|
||||
|
||||
## Done means
|
||||
|
||||
This repo is good enough for a first real experiment when:
|
||||
|
||||
1. A Qwen/Llama model can attach LoRA adapters to intended target layers.
|
||||
2. A 4bit or 8bit loaded model can train LoRA/DeLoRA params with nonzero gradients.
|
||||
3. Saved adapter tensors use full-path keys and reload without calibration data.
|
||||
4. Smoke tests distinguish target-skipping, hook identity drift, and missing-key load failure.
|
||||
@@ -0,0 +1,20 @@
|
||||
[project]
|
||||
name = "lora-lite"
|
||||
version = "0.0.1"
|
||||
description = "Hackable forward-hook LoRA. ~600 LoC. One file per variant."
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"torch>=2.1",
|
||||
"einops>=0.7",
|
||||
"jaxtyping>=0.2.25",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
test = ["pytest"]
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools>=68"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
@@ -0,0 +1,16 @@
|
||||
from .config import LoraLiteConfig
|
||||
from .adapter import attach, detach, save, load
|
||||
from .variant import REGISTRY, register, ParamSpec, Variant
|
||||
from . import variants # noqa: F401 triggers variant registration
|
||||
|
||||
__all__ = [
|
||||
"LoraLiteConfig",
|
||||
"attach",
|
||||
"detach",
|
||||
"save",
|
||||
"load",
|
||||
"REGISTRY",
|
||||
"register",
|
||||
"ParamSpec",
|
||||
"Variant",
|
||||
]
|
||||
@@ -0,0 +1,103 @@
|
||||
"""attach / detach / save / load. The whole runtime."""
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
|
||||
from .config import LoraLiteConfig
|
||||
from .variant import REGISTRY
|
||||
from .target import find_targets
|
||||
|
||||
|
||||
_ATTACHED_ATTR = "_lora_lite_attached"
|
||||
|
||||
|
||||
def _hook(layer, args, y):
|
||||
(x,) = args
|
||||
cfg: LoraLiteConfig = layer._lora_cfg
|
||||
x_cast = x.to(cfg.dtype)
|
||||
out = layer._lora_variant.forward(layer, x_cast, y)
|
||||
return out.to(y.dtype)
|
||||
|
||||
|
||||
def attach(model: nn.Module, cfg: LoraLiteConfig, calibration_data=None) -> list[RemovableHandle]:
|
||||
if cfg.variant not in REGISTRY:
|
||||
raise KeyError(f"unknown variant {cfg.variant!r}; registered: {list(REGISTRY)}")
|
||||
variant = REGISTRY[cfg.variant]
|
||||
targets = find_targets(model, cfg)
|
||||
if not targets:
|
||||
raise RuntimeError("no target layers matched cfg")
|
||||
|
||||
# freeze base
|
||||
for p in model.parameters():
|
||||
p.requires_grad_(False)
|
||||
|
||||
handles: list[RemovableHandle] = []
|
||||
attached_names: list[str] = []
|
||||
attached_targets = []
|
||||
for name, layer, role in targets:
|
||||
d_in, d_out = layer.in_features, layer.out_features
|
||||
for pname, spec in variant.param_specs(d_in, d_out, cfg).items():
|
||||
if hasattr(layer, pname):
|
||||
raise RuntimeError(f"{name} already has attribute {pname}; detach first")
|
||||
p = spec.make(cfg.dtype, layer.weight.device)
|
||||
layer.register_parameter(pname, p)
|
||||
layer._lora_cfg = cfg
|
||||
layer._lora_variant = variant
|
||||
layer._lora_role = role
|
||||
variant.init(layer, cfg)
|
||||
attached_names.append(name)
|
||||
attached_targets.append((name, layer, role))
|
||||
|
||||
group_init = getattr(variant, "group_init", None)
|
||||
if group_init is not None:
|
||||
group_init(model, attached_targets, cfg, calibration_data)
|
||||
|
||||
for _, layer, _ in attached_targets:
|
||||
handles.append(layer.register_forward_hook(_hook))
|
||||
|
||||
setattr(model, _ATTACHED_ATTR, {"cfg": cfg, "targets": attached_names, "handles": handles})
|
||||
return handles
|
||||
|
||||
|
||||
def detach(model: nn.Module) -> None:
|
||||
state = getattr(model, _ATTACHED_ATTR, None)
|
||||
if state is None:
|
||||
return
|
||||
for h in state["handles"]:
|
||||
h.remove()
|
||||
# remove params + scratch attrs
|
||||
for name, layer in model.named_modules():
|
||||
if not hasattr(layer, "_lora_variant"):
|
||||
continue
|
||||
variant = layer._lora_variant
|
||||
for pname in variant.param_specs(layer.in_features, layer.out_features, layer._lora_cfg):
|
||||
if pname in layer._parameters:
|
||||
del layer._parameters[pname]
|
||||
for attr in ("_lora_cfg", "_lora_variant", "_lora_role"):
|
||||
if hasattr(layer, attr):
|
||||
delattr(layer, attr)
|
||||
delattr(model, _ATTACHED_ATTR)
|
||||
|
||||
|
||||
def save(model: nn.Module, path: str) -> None:
|
||||
state = getattr(model, _ATTACHED_ATTR, None)
|
||||
if state is None:
|
||||
raise RuntimeError("no adapter attached; call attach() first")
|
||||
sd = {k: v.detach().cpu() for k, v in model.state_dict().items() if "lora_" in k}
|
||||
torch.save({"cfg": state["cfg"].to_dict(), "state": sd}, path)
|
||||
|
||||
|
||||
def load(model: nn.Module, path: str) -> list[RemovableHandle]:
|
||||
blob = torch.load(path, weights_only=True, map_location="cpu")
|
||||
cfg = LoraLiteConfig.from_dict(blob["cfg"])
|
||||
handles = attach(model, cfg) # creates empty params with right shapes
|
||||
missing, unexpected = model.load_state_dict(blob["state"], strict=False)
|
||||
expected_lora = {k for k in model.state_dict() if "lora_" in k}
|
||||
missing_lora = sorted(expected_lora.intersection(missing))
|
||||
if missing_lora:
|
||||
raise RuntimeError(f"missing lora keys in checkpoint: {missing_lora}")
|
||||
unexpected_lora = [k for k in unexpected if "lora_" in k]
|
||||
if unexpected_lora:
|
||||
raise RuntimeError(f"unexpected lora keys in checkpoint: {unexpected_lora}")
|
||||
return handles
|
||||
@@ -0,0 +1,38 @@
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from typing import Any
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoraLiteConfig:
|
||||
variant: str = "lora"
|
||||
r: int = 8
|
||||
alpha: float = 16.0
|
||||
dropout: float = 0.0 # currently ignored; variants may use cfg.variant_kwargs
|
||||
dtype: torch.dtype = torch.bfloat16
|
||||
|
||||
# targeting
|
||||
target_roles: tuple[str, ...] = ("reader", "writer")
|
||||
target_names: tuple[str, ...] = ()
|
||||
exclude_names: tuple[str, ...] = ("lm_head", "embed_tokens")
|
||||
layers: tuple[int, ...] | None = None
|
||||
|
||||
# variant-specific bag (e.g. lambda0 for DeLoRA)
|
||||
variant_kwargs: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
d = asdict(self)
|
||||
d["dtype"] = str(self.dtype).removeprefix("torch.")
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict) -> "LoraLiteConfig":
|
||||
d = dict(d)
|
||||
if isinstance(d.get("dtype"), str):
|
||||
d["dtype"] = getattr(torch, d["dtype"])
|
||||
if isinstance(d.get("layers"), list):
|
||||
d["layers"] = tuple(d["layers"])
|
||||
for k in ("target_roles", "target_names", "exclude_names"):
|
||||
if isinstance(d.get(k), list):
|
||||
d[k] = tuple(d[k])
|
||||
return cls(**d)
|
||||
@@ -0,0 +1,59 @@
|
||||
"""Find linear-like targets by shape (reader/writer/inner) + name regex.
|
||||
|
||||
Structural matching is deliberate: bnb Linear4bit/8bitLt are not nn.Linear, but
|
||||
they expose in_features/out_features/weight and their forward already handles
|
||||
dequantization.
|
||||
"""
|
||||
import re
|
||||
from torch import nn
|
||||
|
||||
|
||||
def is_linear_like(m: nn.Module) -> bool:
|
||||
return (
|
||||
hasattr(m, "in_features")
|
||||
and hasattr(m, "out_features")
|
||||
and hasattr(m, "weight")
|
||||
and callable(m)
|
||||
)
|
||||
|
||||
|
||||
def _layer_idx(name: str) -> int | None:
|
||||
m = re.search(r"\.layers?\.(\d+)\.", name)
|
||||
return int(m.group(1)) if m else None
|
||||
|
||||
|
||||
def _classify(m: nn.Module, d_model: int, name: str) -> str:
|
||||
di, do = m.in_features, m.out_features
|
||||
if di == d_model and do != d_model:
|
||||
return "reader"
|
||||
if do == d_model and di != d_model:
|
||||
return "writer"
|
||||
if di == d_model and do == d_model:
|
||||
return "writer" if any(s in name for s in ("o_proj", "out_proj")) else "reader"
|
||||
return "inner"
|
||||
|
||||
|
||||
def find_targets(model: nn.Module, cfg) -> list[tuple[str, nn.Module, str]]:
|
||||
# discover d_model: prefer config.hidden_size, else infer from largest Linear in_features
|
||||
d_model = getattr(getattr(model, "config", None), "hidden_size", None)
|
||||
if d_model is None:
|
||||
dims = [m.in_features for m in model.modules() if is_linear_like(m)]
|
||||
d_model = max(dims) if dims else 0
|
||||
|
||||
out = []
|
||||
for name, m in model.named_modules():
|
||||
if not is_linear_like(m):
|
||||
continue
|
||||
if any(re.search(p, name) for p in cfg.exclude_names):
|
||||
continue
|
||||
if cfg.layers is not None:
|
||||
li = _layer_idx(name)
|
||||
if li is None or li not in cfg.layers:
|
||||
continue
|
||||
role = _classify(m, d_model, name)
|
||||
if cfg.target_roles and role not in cfg.target_roles:
|
||||
continue
|
||||
if cfg.target_names and not any(re.search(p, name) for p in cfg.target_names):
|
||||
continue
|
||||
out.append((name, m, role))
|
||||
return out
|
||||
@@ -0,0 +1,56 @@
|
||||
"""Variant protocol + registry. Variants own (x, layer.weight, layer.lora_*) -> y_new."""
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Protocol, Any
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .config import LoraLiteConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParamSpec:
|
||||
shape: tuple[int, ...]
|
||||
init: str | Callable[[torch.Tensor], None] = "zeros" # 'zeros'|'kaiming'|'ones'|callable(t)
|
||||
trainable: bool = True
|
||||
|
||||
def make(self, dtype: torch.dtype, device) -> nn.Parameter:
|
||||
t = torch.empty(self.shape, dtype=dtype, device=device)
|
||||
if callable(self.init):
|
||||
self.init(t)
|
||||
elif self.init == "zeros":
|
||||
t.zero_()
|
||||
elif self.init == "ones":
|
||||
t.fill_(1.0)
|
||||
elif self.init == "kaiming":
|
||||
# match nn.Linear default: kaiming_uniform_(a=sqrt(5))
|
||||
nn.init.kaiming_uniform_(t, a=5 ** 0.5) if t.ndim >= 2 else t.normal_(0, 0.02)
|
||||
else:
|
||||
raise ValueError(f"unknown init: {self.init}")
|
||||
return nn.Parameter(t, requires_grad=self.trainable)
|
||||
|
||||
|
||||
class Variant(Protocol):
|
||||
name: str
|
||||
|
||||
@staticmethod
|
||||
def param_specs(d_in: int, d_out: int, cfg: LoraLiteConfig) -> dict[str, ParamSpec]: ...
|
||||
|
||||
@staticmethod
|
||||
def init(layer: nn.Linear, cfg: LoraLiteConfig) -> None: ...
|
||||
|
||||
@staticmethod
|
||||
def forward(layer: nn.Linear, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
"""Return the layer's NEW output (not just delta).
|
||||
Additive variants: `return y + delta`.
|
||||
Replacing variants: ignore `y`, return new value."""
|
||||
...
|
||||
|
||||
|
||||
REGISTRY: dict[str, type] = {}
|
||||
|
||||
|
||||
def register(cls):
|
||||
if not getattr(cls, "name", None):
|
||||
raise ValueError(f"variant {cls} missing .name")
|
||||
REGISTRY[cls.name] = cls
|
||||
return cls
|
||||
@@ -0,0 +1 @@
|
||||
from . import lora, pissa, delora # noqa: F401 side-effect: register
|
||||
@@ -0,0 +1,43 @@
|
||||
"""DeLoRA: column-normalised A, B, scaled by lambda/r. Bini et al. 2025 arXiv:2503.18225.
|
||||
|
||||
NOTE on identity at t=0: paper uses kaiming for both A and B with a learned lambda
|
||||
init at 0 (or small) so the effective delta starts near zero. We honour that:
|
||||
default lambda0 == 0 gives bit-identity; user can override via variant_kwargs.
|
||||
"""
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import einsum
|
||||
from torch import nn
|
||||
|
||||
from ..variant import register, ParamSpec
|
||||
|
||||
|
||||
@register
|
||||
class DeLoRA:
|
||||
name = "delora"
|
||||
|
||||
@staticmethod
|
||||
def param_specs(d_in, d_out, cfg):
|
||||
lam0 = float(cfg.variant_kwargs.get("lambda0", 0.0))
|
||||
return {
|
||||
"lora_A": ParamSpec((cfg.r, d_in), init="kaiming", trainable=True),
|
||||
"lora_B": ParamSpec((d_out, cfg.r), init="kaiming", trainable=True),
|
||||
"lora_lambda": ParamSpec(
|
||||
(), init=lambda t: t.fill_(lam0), trainable=True
|
||||
),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def init(layer: nn.Linear, cfg) -> None:
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def forward(layer: nn.Linear, x, y):
|
||||
cfg = layer._lora_cfg
|
||||
# rows of A unit, cols of B unit (per paper)
|
||||
A = F.normalize(layer.lora_A, dim=1) # (r, d_in)
|
||||
B = F.normalize(layer.lora_B, dim=0) # (d_out, r)
|
||||
scale = layer.lora_lambda / cfg.r
|
||||
h = einsum(x, A, "... i, r i -> ... r")
|
||||
delta = einsum(h, B, "... r, o r -> ... o")
|
||||
return y + scale * delta
|
||||
@@ -0,0 +1,31 @@
|
||||
"""Vanilla LoRA. Reference variant. y = Wx + (alpha/r) * B @ A @ x."""
|
||||
from einops import einsum
|
||||
from torch import nn
|
||||
import torch
|
||||
|
||||
from ..variant import register, ParamSpec
|
||||
|
||||
|
||||
@register
|
||||
class LoRA:
|
||||
name = "lora"
|
||||
|
||||
@staticmethod
|
||||
def param_specs(d_in, d_out, cfg):
|
||||
return {
|
||||
"lora_A": ParamSpec((cfg.r, d_in), init="kaiming", trainable=True),
|
||||
"lora_B": ParamSpec((d_out, cfg.r), init="zeros", trainable=True),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def init(layer: nn.Linear, cfg) -> None:
|
||||
# B is zeros => delta=0 at t=0; identity invariant holds.
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def forward(layer: nn.Linear, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
cfg = layer._lora_cfg
|
||||
scale = cfg.alpha / cfg.r
|
||||
h = einsum(x, layer.lora_A, "... i, r i -> ... r")
|
||||
delta = einsum(h, layer.lora_B, "... r, o r -> ... o")
|
||||
return y + scale * delta
|
||||
@@ -0,0 +1,54 @@
|
||||
"""PiSSA: top-r SVD of W into A,B; replace W with W_res = W - B@A.
|
||||
|
||||
Meng et al. 2024 https://arxiv.org/abs/2404.02948
|
||||
W_eff(t=0) = W_res + B@A = W (numerically; bf16 round-trip not bit-exact).
|
||||
"""
|
||||
import torch
|
||||
from einops import einsum
|
||||
from torch import nn
|
||||
|
||||
from ..variant import register, ParamSpec
|
||||
|
||||
|
||||
@register
|
||||
class PiSSA:
|
||||
name = "pissa"
|
||||
|
||||
@staticmethod
|
||||
def param_specs(d_in, d_out, cfg):
|
||||
return {
|
||||
"lora_A": ParamSpec((cfg.r, d_in), init="zeros", trainable=True),
|
||||
"lora_B": ParamSpec((d_out, cfg.r), init="zeros", trainable=True),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def init(layer: nn.Linear, cfg) -> None:
|
||||
if type(layer) is not nn.Linear:
|
||||
raise TypeError(
|
||||
"PiSSA mutates layer.weight into W_res, so v1 only supports plain nn.Linear. "
|
||||
"For bnb 4/8-bit, use LoRA/DeLoRA or implement explicit dequantize/requantize."
|
||||
)
|
||||
W = layer.weight.data.float() # (d_out, d_in)
|
||||
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
|
||||
r = cfg.r
|
||||
Ur, Sr, Vhr = U[:, :r], S[:r], Vh[:r, :]
|
||||
sqrtS = Sr.sqrt()
|
||||
# B @ A = Ur diag(Sr) Vhr; pick B = Ur sqrt(Sr), A = sqrt(Sr) * Vhr
|
||||
B = (Ur * sqrtS).to(cfg.dtype)
|
||||
A = (sqrtS[:, None] * Vhr).to(cfg.dtype)
|
||||
layer.lora_B.data.copy_(B)
|
||||
layer.lora_A.data.copy_(A)
|
||||
# Compute BA in fp32 for the subtraction so W_res is accurate.
|
||||
BA = (B.float() @ A.float())
|
||||
# NOTE: PiSSA uses scale=1 (alpha==r) implicitly via init. We let the user pick;
|
||||
# for fidelity at t=0, the convention is alpha==r so scale==1. Document in README.
|
||||
scale = cfg.alpha / cfg.r
|
||||
layer.weight.data.copy_((W - scale * BA).to(layer.weight.dtype))
|
||||
|
||||
@staticmethod
|
||||
def forward(layer: nn.Linear, x, y):
|
||||
cfg = layer._lora_cfg
|
||||
scale = cfg.alpha / cfg.r
|
||||
h = einsum(x, layer.lora_A, "... i, r i -> ... r")
|
||||
delta = einsum(h, layer.lora_B, "... r, o r -> ... o")
|
||||
return y + scale * delta
|
||||
+243
@@ -0,0 +1,243 @@
|
||||
"""Smoke test: lora / pissa / delora on a tiny synthetic transformer-like model.
|
||||
|
||||
Verifies:
|
||||
1. Identity at t=0 (delta ~ 0, output close to base).
|
||||
2. Save/load round-trip preserves outputs.
|
||||
3. A few SGD steps reduce a random loss (gradients flow).
|
||||
|
||||
Run:
|
||||
cd lora-lite
|
||||
python -m pip install -e .
|
||||
python tests/smoke.py
|
||||
|
||||
BLUF format:
|
||||
SHOULD: max|y_adapter - y_base| < tol_init for all variants. ELSE init or hook bug.
|
||||
SHOULD: loss decreases > 5% over 20 SGD steps for all variants. ELSE grad/wiring bug.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import tempfile, os, sys, math
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
# allow running as `python tests/smoke.py` without install
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
|
||||
|
||||
import lora_lite as ll # noqa: E402
|
||||
|
||||
|
||||
# ---- a tiny transformer-like stack: 4 blocks of (q,k,v,o, gate,up,down) Linears ----
|
||||
class TinyBlock(nn.Module):
|
||||
def __init__(self, d=64, ff=128):
|
||||
super().__init__()
|
||||
self.q_proj = nn.Linear(d, d, bias=False)
|
||||
self.k_proj = nn.Linear(d, d, bias=False)
|
||||
self.v_proj = nn.Linear(d, d, bias=False)
|
||||
self.o_proj = nn.Linear(d, d, bias=False)
|
||||
self.gate_proj = nn.Linear(d, ff, bias=False)
|
||||
self.up_proj = nn.Linear(d, ff, bias=False)
|
||||
self.down_proj = nn.Linear(ff, d, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
h = self.o_proj(self.q_proj(x) + self.k_proj(x) + self.v_proj(x))
|
||||
m = self.down_proj(torch.nn.functional.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
return x + h + m
|
||||
|
||||
|
||||
class TinyModel(nn.Module):
|
||||
def __init__(self, n_layers=4, d=64, ff=128, vocab=100):
|
||||
super().__init__()
|
||||
self.embed_tokens = nn.Embedding(vocab, d)
|
||||
self.layers = nn.ModuleList([TinyBlock(d, ff) for _ in range(n_layers)])
|
||||
self.lm_head = nn.Linear(d, vocab, bias=False)
|
||||
|
||||
class Cfg: # mimic HF .config.hidden_size
|
||||
hidden_size = d
|
||||
self.config = Cfg()
|
||||
|
||||
def forward(self, ids):
|
||||
x = self.embed_tokens(ids)
|
||||
for blk in self.layers:
|
||||
x = blk(x)
|
||||
return self.lm_head(x)
|
||||
|
||||
|
||||
class FakeLinearLike(nn.Module):
|
||||
"""Not nn.Linear, but structurally bnb-like enough for target discovery."""
|
||||
|
||||
def __init__(self, d_in=8, d_out=8):
|
||||
super().__init__()
|
||||
self.in_features = d_in
|
||||
self.out_features = d_out
|
||||
self.weight = nn.Parameter(torch.empty(d_out, d_in))
|
||||
nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5)
|
||||
|
||||
def forward(self, x):
|
||||
return torch.nn.functional.linear(x, self.weight)
|
||||
|
||||
|
||||
class FakeBnbModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.config = type("Cfg", (), {"hidden_size": 8})()
|
||||
self.layers = nn.ModuleList([FakeLinearLike(8, 8)])
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers[0](x)
|
||||
|
||||
|
||||
def variant_test(variant: str, dtype=torch.float32):
|
||||
print(f"\n=== variant={variant} dtype={dtype} ===")
|
||||
torch.manual_seed(0)
|
||||
model = TinyModel().to(dtype)
|
||||
ids = torch.randint(0, 100, (2, 16))
|
||||
|
||||
with torch.no_grad():
|
||||
y_base = model(ids).clone()
|
||||
|
||||
cfg = ll.LoraLiteConfig(
|
||||
variant=variant,
|
||||
r=4,
|
||||
alpha=4 if variant == "pissa" else 8, # PiSSA needs scale==1 for clean recon
|
||||
dtype=dtype,
|
||||
# delora identity-at-init demands lambda0==0 (then delta * scale = 0)
|
||||
variant_kwargs={"lambda0": 0.0} if variant == "delora" else {},
|
||||
)
|
||||
handles = ll.attach(model, cfg)
|
||||
n_targets = len(handles)
|
||||
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
print(f" attached {n_targets} targets, trainable params={n_trainable}")
|
||||
|
||||
with torch.no_grad():
|
||||
y_adapt = model(ids)
|
||||
err = (y_adapt - y_base).abs().max().item()
|
||||
base_scale = y_base.abs().max().item()
|
||||
print(f" t=0 identity: max|y_adapt - y_base| = {err:.3e} (base scale {base_scale:.3e})")
|
||||
|
||||
# variant-specific identity tolerance
|
||||
tol = {
|
||||
"lora": 1e-6,
|
||||
"pissa": 5e-4, # SVD recon in fp32 is tight; bf16 would be ~1e-2
|
||||
"delora": 1e-6, # lambda0=0
|
||||
}[variant] * max(1.0, base_scale)
|
||||
assert err < tol, f" FAIL identity: err {err} > tol {tol}"
|
||||
print(f" SHOULD: err<{tol:.1e}. PASS.")
|
||||
|
||||
# save/load round-trip
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
p = os.path.join(d, "adapter.pt")
|
||||
ll.save(model, p)
|
||||
# detach + fresh model + load
|
||||
ll.detach(model)
|
||||
torch.manual_seed(0)
|
||||
model2 = TinyModel().to(dtype)
|
||||
# for PiSSA, base weights got mutated; we need them mutated again for the load
|
||||
# path to make sense. Easiest: re-attach with same cfg first... but that's what
|
||||
# load() does. The catch: load reads cfg from the file, runs attach (which
|
||||
# re-runs PiSSA init -> same SVD on same weights -> same A,B -> mutates W
|
||||
# to the same W_res). Then state_dict overwrites lora_A/B with saved values.
|
||||
ll.load(model2, p)
|
||||
with torch.no_grad():
|
||||
y_loaded = model2(ids)
|
||||
err2 = (y_loaded - y_adapt).abs().max().item()
|
||||
print(f" save/load: max|y_loaded - y_adapt| = {err2:.3e}")
|
||||
assert err2 < tol, f" FAIL save/load: {err2} > {tol}"
|
||||
print(f" SHOULD: err2<{tol:.1e}. PASS.")
|
||||
ll.detach(model2)
|
||||
|
||||
# gradient flow: 20 SGD steps on random target.
|
||||
# For delora, lambda0==0 makes A,B grads zero (scale=0); use lambda0=0.1 for training.
|
||||
torch.manual_seed(0)
|
||||
model = TinyModel().to(dtype)
|
||||
train_cfg = cfg
|
||||
if variant == "delora":
|
||||
train_cfg = ll.LoraLiteConfig(
|
||||
variant=cfg.variant, r=cfg.r, alpha=cfg.alpha, dtype=cfg.dtype,
|
||||
variant_kwargs={"lambda0": 0.1},
|
||||
)
|
||||
ll.attach(model, train_cfg)
|
||||
target = torch.randn(2, 16, 100, dtype=dtype) * 0.1
|
||||
trainable = [p for p in model.parameters() if p.requires_grad]
|
||||
# delora has tightly-normalised updates; use Adam with higher lr to see signal in 20 steps
|
||||
if variant == "delora":
|
||||
opt = torch.optim.Adam(trainable, lr=1e-1)
|
||||
else:
|
||||
opt = torch.optim.SGD(trainable, lr=1e-2)
|
||||
losses = []
|
||||
for step in range(20):
|
||||
opt.zero_grad()
|
||||
loss = (model(ids) - target).pow(2).mean()
|
||||
loss.backward()
|
||||
opt.step()
|
||||
losses.append(loss.item())
|
||||
drop = (losses[0] - losses[-1]) / max(losses[0], 1e-12)
|
||||
print(f" loss[0]={losses[0]:.4f} loss[-1]={losses[-1]:.4f} drop={100*drop:.1f}%")
|
||||
assert drop > 0.05, f" FAIL: loss drop only {drop:.2%}, expected >5%"
|
||||
print(f" SHOULD: drop>5%. PASS.")
|
||||
|
||||
|
||||
def structural_linear_like_test():
|
||||
print("\n=== structural linear-like target test (bnb-style, not nn.Linear) ===")
|
||||
torch.manual_seed(0)
|
||||
model = FakeBnbModel()
|
||||
x = torch.randn(2, 3, 8)
|
||||
y_base = model(x).detach()
|
||||
ll.attach(model, ll.LoraLiteConfig(variant="lora", r=2, alpha=4, dtype=torch.float32, target_roles=()))
|
||||
layer = model.layers[0]
|
||||
assert hasattr(layer, "lora_A") and hasattr(layer, "lora_B")
|
||||
y = model(x)
|
||||
err = (y.detach() - y_base).abs().max().item()
|
||||
loss = y.pow(2).mean()
|
||||
loss.backward()
|
||||
grad_nonzero = layer.lora_B.grad.abs().sum().item() > 0
|
||||
print(f" attached lora_A={tuple(layer.lora_A.shape)} lora_B={tuple(layer.lora_B.shape)}")
|
||||
print(f" identity_err={err:.3e} grad_nonzero={grad_nonzero}")
|
||||
assert err == 0.0
|
||||
assert grad_nonzero
|
||||
print(" SHOULD: structural target attaches and lora_B receives grad. PASS.")
|
||||
|
||||
|
||||
def bitsandbytes_cuda_smoke():
|
||||
print("\n=== optional bitsandbytes CUDA smoke ===")
|
||||
if not torch.cuda.is_available():
|
||||
print(" SKIP: CUDA unavailable; real bnb 4/8-bit forward needs GPU on this machine.")
|
||||
return
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
except ImportError:
|
||||
print(" SKIP: bitsandbytes unavailable.")
|
||||
return
|
||||
|
||||
class BnbModel(nn.Module):
|
||||
def __init__(self, Layer):
|
||||
super().__init__()
|
||||
self.config = type("Cfg", (), {"hidden_size": 8})()
|
||||
self.layers = nn.ModuleList([Layer(8, 8, bias=False)]).cuda()
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers[0](x)
|
||||
|
||||
for layer_cls in (bnb.nn.Linear8bitLt, bnb.nn.Linear4bit):
|
||||
torch.manual_seed(0)
|
||||
model = BnbModel(layer_cls)
|
||||
x = torch.randn(2, 3, 8, device="cuda")
|
||||
y_base = model(x).detach()
|
||||
ll.attach(model, ll.LoraLiteConfig(variant="lora", r=2, alpha=4, dtype=torch.float16, target_roles=()))
|
||||
y = model(x)
|
||||
err = (y.detach() - y_base).abs().max().item()
|
||||
y.pow(2).mean().backward()
|
||||
grad_nonzero = model.layers[0].lora_B.grad.abs().sum().item() > 0
|
||||
print(f" {layer_cls.__name__}: identity_err={err:.3e} grad_nonzero={grad_nonzero}")
|
||||
assert err == 0.0
|
||||
assert grad_nonzero
|
||||
|
||||
|
||||
def main():
|
||||
for v in ("lora", "pissa", "delora"):
|
||||
variant_test(v, dtype=torch.float32)
|
||||
structural_linear_like_test()
|
||||
bitsandbytes_cuda_smoke()
|
||||
print("\nALL PASS.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user