From 4db5cee5a933c397f02ff274744a0a597e474720 Mon Sep 17 00:00:00 2001 From: wassname <1103714+wassname@users.noreply.github.com> Date: Sun, 26 Apr 2026 14:10:20 +0800 Subject: [PATCH] init --- README.md | 161 ++++++++++++++++++ docs/spec/20260426_lora_lite_plan.md | 144 ++++++++++++++++ pyproject.toml | 20 +++ src/lora_lite/__init__.py | 16 ++ src/lora_lite/adapter.py | 103 ++++++++++++ src/lora_lite/config.py | 38 +++++ src/lora_lite/target.py | 59 +++++++ src/lora_lite/variant.py | 56 ++++++ src/lora_lite/variants/__init__.py | 1 + src/lora_lite/variants/delora.py | 43 +++++ src/lora_lite/variants/lora.py | 31 ++++ src/lora_lite/variants/pissa.py | 54 ++++++ tests/smoke.py | 243 +++++++++++++++++++++++++++ 13 files changed, 969 insertions(+) create mode 100644 README.md create mode 100644 docs/spec/20260426_lora_lite_plan.md create mode 100644 pyproject.toml create mode 100644 src/lora_lite/__init__.py create mode 100644 src/lora_lite/adapter.py create mode 100644 src/lora_lite/config.py create mode 100644 src/lora_lite/target.py create mode 100644 src/lora_lite/variant.py create mode 100644 src/lora_lite/variants/__init__.py create mode 100644 src/lora_lite/variants/delora.py create mode 100644 src/lora_lite/variants/lora.py create mode 100644 src/lora_lite/variants/pissa.py create mode 100644 tests/smoke.py diff --git a/README.md b/README.md new file mode 100644 index 0000000..eecd032 --- /dev/null +++ b/README.md @@ -0,0 +1,161 @@ +# lora-lite + +A hackable, single-file-per-variant LoRA library built on PyTorch forward hooks. + +- ~600 LoC total +- One file per variant, ~50 LoC each +- No module replacement, no merge/unmerge, no PEFT config soup +- Save = `torch.save({cfg, state_dict_filtered_by_'lora_'})` +- LoRA/DeLoRA forward hooks work with `nn.Linear` and bnb-style `Linear{4bit,8bitLt}` modules that expose `in_features`, `out_features`, and `weight`. 
+- PiSSA is fp-only in v1 because it mutates `weight` into `W_res`; quantized PiSSA needs explicit dequantize/requantize. + +Currently shipped variants: + +| Variant | Class | File | +|---|---|---| +| LoRA | A (additive) | [src/lora_lite/variants/lora.py](src/lora_lite/variants/lora.py) | +| PiSSA ([Meng+ 2024](https://arxiv.org/abs/2404.02948)) | A + B (special init mutates W) | [src/lora_lite/variants/pissa.py](src/lora_lite/variants/pissa.py) | +| DeLoRA ([Bini+ 2025](https://arxiv.org/abs/2503.18225)) | A (additive, normalised) | [src/lora_lite/variants/delora.py](src/lora_lite/variants/delora.py) | + +See [docs/spec/20260426_lora_lite_plan.md](docs/spec/20260426_lora_lite_plan.md) for goals, status, TODOs, and the current design plan. The original broader design was stress-tested against the [adapters_as_hypotheses](https://github.com/wassname/adapters_as_hypotheses) catalog (~26/27 variants covered with 3 small API tweaks). + +## Install + +```bash +pip install -e . +``` + +## Quickstart + +```python +import torch, lora_lite as ll + +model = MyTransformer() # any nn.Module containing linear-like children +cfg = ll.LoraLiteConfig(variant="lora", r=8, alpha=16, dtype=torch.bfloat16) +handles = ll.attach(model, cfg) + +# train +trainable = [p for p in model.parameters() if p.requires_grad] +opt = torch.optim.AdamW(trainable, lr=1e-4) +# ... your loop ... + +ll.save(model, "adapter.pt") +ll.detach(model) +# later: +ll.load(model, "adapter.pt") +``` + +Inspect a tensor live: + +```python +A = model.layers[5].self_attn.q_proj.lora_A # just an nn.Parameter +``` + +## Targeting + +By default we target linear-like modules (`in_features`, `out_features`, `weight`) whose shape matches a "reader" (`d_in == d_model`) or "writer" (`d_out == d_model`) role, excluding `lm_head` and `embed_tokens`. This structural test is what lets bnb Linear4bit/8bitLt modules be targeted without a backend-specific class. 
Knobs on `LoraLiteConfig`: + +- `target_roles`: subset of `("reader", "writer", "inner")`. `()` = all. +- `target_names`: regex includes (must match if non-empty). +- `exclude_names`: regex excludes (default skips `lm_head`, `embed_tokens`). +- `layers`: tuple of layer indices, or `None` for all (the index is read from a `.layers.N.` segment of the module name). + +## Variant API + +A variant is a class with a `name` and three statics: + +```python +@register +class MyVariant: + name = "myvariant" + + @staticmethod + def param_specs(d_in, d_out, cfg) -> dict[str, ParamSpec]: + return {"lora_A": ParamSpec((cfg.r, d_in), init="kaiming"), ...} + + @staticmethod + def init(layer, cfg) -> None: + # Optional. Run after params are created. May read/mutate layer.weight. + ... + + @staticmethod + def forward(layer, x, y) -> Tensor: + # Return the layer's NEW output (additive: `return y + delta`). + ... +``` + +Adapter params attached as `layer.lora_*` get full-path keys in `state_dict()` automatically (e.g. `model.layers.5.self_attn.q_proj.lora_A`). + +## Data-calibrated init + +None of the shipped variants needs calibration data: PiSSA initialises from `layer.weight` alone, while LoRA and DeLoRA use only the random/zero/constant init declared in their `param_specs`. + +For variants that DO need data (e.g. AntiPaSTO, LoRA-GA, activation-aware SVD), keep dataloaders out of `cfg` so adapter checkpoints stay serializable. Use: + +```python +ll.attach(model, cfg, calibration_data=calib) +``` + +where `calib` is an iterable of whole-model inputs, e.g. `Iterable[dict[str, Tensor]]` for HF models or `Iterable[Tensor]` of token ids. Activation-aware variants implement: + +```python +@staticmethod +def group_init(model, targets, cfg, calibration_data): ... +``` + +`targets` is `list[(name, layer, role)]`. The variant adds temporary hooks, runs `model(batch)` over `calibration_data`, removes the hooks, then writes `lora_*` params. Per-layer `init(layer, cfg)` stays weight-only. + +Sketch: + +```python +@register +class ActSVD: + name = "actsvd" + @staticmethod + def param_specs(d_in, d_out, cfg): ... 
+ @staticmethod + def group_init(model, targets, cfg, calibration_data): + bufs = {name: [] for name, _, _ in targets} + hooks = [ + layer.register_forward_pre_hook( + lambda m, args, name=name: bufs[name].append(args[0].detach().float()) + ) + for name, layer, _ in targets + ] + try: + with torch.no_grad(): + for batch in calibration_data: + model(**batch) if isinstance(batch, dict) else model(batch) + finally: + for h in hooks: + h.remove() + # For each target: X = torch.cat(bufs[name], dim=0); do SVD; write A/B. +``` + +## Smoke test + +```bash +python tests/smoke.py +``` + +Verifies for each of `lora`, `pissa`, `delora`: +1. Identity at t=0: `max|y_adapter - y_base|` within float tolerance. +2. Save/load round-trip preserves outputs. +3. 20 SGD steps reduce a random regression loss by >5%. + +## What's NOT in v1 + +| Feature | Why dropped | +|---|---| +| merge/unmerge | reload base if you want vanilla | +| 4/8-bit-aware merge | DoRA on bnb supported in forward only (drop merge path) | +| Embedding / Conv adapters | trivial extension; add when needed | +| `adapter_names=` mixed batch forward | rare; add when needed | +| Multiple named adapters per layer | one variant per `attach()` | +| HF `PeftConfig` / hub upload | `torch.save({cfg, state})` is enough | +| AdaLoRA-style rank scheduling | needs `Variant.on_step(step)` -- punt | +| ReFT-style position interventions | sibling submodule (different hook site) | + +## Status + +v0.0.1: lora + pissa + delora + smoke test. See spec for next variants (DoRA, VeRA, SSVD). diff --git a/docs/spec/20260426_lora_lite_plan.md b/docs/spec/20260426_lora_lite_plan.md new file mode 100644 index 0000000..c6bf106 --- /dev/null +++ b/docs/spec/20260426_lora_lite_plan.md @@ -0,0 +1,144 @@ +# lora-lite plan and status + +## Goal + +Build a small, hackable LoRA-family adapter library for research experiments. 
+ +The core bet is that adapter variants should own the relationship between `(x, layer.weight, layer.lora_*)` and the layer output, while the library only handles targeting, parameter attachment, hooks, and save/load. + +## Non-goals + +- No PEFT compatibility layer. +- No module replacement. +- No merge/unmerge. +- No multiple named adapters per layer. +- No backward compatibility promises. +- No silent fallbacks. + +## Design constraints + +- Adapter params are attached directly to target layers as `lora_*` parameters. +- Save/load uses normal `state_dict()` keys, filtered by `"lora_"`. +- Forward hooks return the layer's new output, not just a delta. +- Targeting is structural: modules with `in_features`, `out_features`, and `weight` are linear-like. +- LoRA/DeLoRA support bnb-style 4/8-bit forward paths because the quantized base layer computes `y`; the hook only adds adapter math. +- PiSSA is fp-only in v1 because it mutates `layer.weight` into `W_res`. +- Data-calibrated variants use `group_init(model, targets, cfg, calibration_data)`; dataloaders stay out of `cfg` so checkpoints are serializable. 
+ +## Implemented v0.0.1 + +| Area | Status | Evidence | +|---|---:|---| +| `LoraLiteConfig` | done | `src/lora_lite/config.py` | +| Variant registry + `ParamSpec` | done | `src/lora_lite/variant.py` | +| Structural target discovery | done | `src/lora_lite/target.py` | +| `attach` / `detach` / `save` / `load` | done | `src/lora_lite/adapter.py` | +| LoRA | done | `src/lora_lite/variants/lora.py` | +| PiSSA | done, fp-only | `src/lora_lite/variants/pissa.py` | +| DeLoRA | done | `src/lora_lite/variants/delora.py` | +| Smoke tests | done | `tests/smoke.py` | +| bnb minimal forward smoke | done | `Linear8bitLt` and `Linear4bit` pass on CUDA | + +## Current smoke evidence + +Last verified log: `/home/wassname/.cache/agent-tmp/lora_lite_smoke_after_review.log` + +| Check | Result | +|---|---| +| LoRA identity | `0.000e+00` | +| LoRA loss drop | `6.1%` | +| PiSSA identity | `1.550e-06` | +| PiSSA loss drop | `11.5%` | +| DeLoRA identity | `0.000e+00` | +| DeLoRA loss drop | `93.4%` | +| fake non-`nn.Linear` target | attaches, identity `0.000e+00`, grad nonzero | +| bnb `Linear8bitLt` | identity `0.000e+00`, grad nonzero | +| bnb `Linear4bit` | identity `0.000e+00`, grad nonzero | + +## Review history + +A cold subagent review first returned `PASS_WITH_BLOCKERS`: + +1. bnb modules were not targeted. +2. Hook cast `y` to `cfg.dtype`, which could round base outputs. +3. PiSSA overclaimed bnb support. +4. `load()` did not fail on missing adapter keys. +5. Data-calibrated init needed model-level access. + +Fixes applied: + +1. Structural `is_linear_like()` target predicate. +2. Hook only casts `x`, keeps `y` in base output dtype. +3. PiSSA fail-fast rejects non-plain `nn.Linear`. +4. `load()` fails on missing or unexpected `lora_` keys. +5. `attach(..., calibration_data=None)` plus optional `group_init(model, targets, cfg, calibration_data)`. + +Second cold review verdict: `PASS` for the minimal 4bit-enabled scope. 
+ +## TODO / status + +### Next implementation goals + +- [ ] Add DoRA. + - Verify: fp32/bf16 identity at init, finite gradients, and smoke loss drop. + - Caveat: bnb DoRA needs explicit weight dequantization for norm computation or should be fp-only at first. + +- [ ] Add VeRA. + - Verify: shared buffers are allocated once, target slices match shape, identity or near-identity at init. + +- [ ] Add SSVD or AntiPaSTO-style SVD variant. + - Verify: reconstruction or intended rotation invariant at init. + +- [ ] Add real activation-calibrated toy variant using `group_init`. + - Verify: `calibration_data` is consumed during `attach`, hooks are removed, checkpoint is serializable, and `load()` does not require calibration data. + +- [ ] Add load path that can skip calibration init for future `group_init` variants. + - Current caveat: `load()` calls `attach(model, cfg)` with `calibration_data=None`; fine for current variants, but future calibrated variants should separate param creation from calibration. + +- [ ] Add a tiny HF-model smoke when convenient. + - Verify: target names look like real transformer modules and state dict keys match full paths. + +### Design TODOs + +- [ ] Decide whether `group_init` should run before or after forward hooks are registered. + - Current choice: after params are attached, before adapter forward hooks are registered. + +- [ ] Decide whether replacing variants need `runs_base_layer=False` or can always transform `y`. + - OFT-like variants can rotate `y`; variants that truly avoid base forward need module replacement or pre-hook rewriting, likely out of v1. + +- [ ] Add `weight_mode` for BitFit/SHiRA if those variants become in-scope. + - Minimal surface: `weight_mode in {"frozen", "bias_only", "sparse_grad"}`. + +## Variant contract + +```python +class Variant: + name: str + + @staticmethod + def param_specs(d_in, d_out, cfg) -> dict[str, ParamSpec]: ... 
+ + @staticmethod + def init(layer, cfg) -> None: + # weight-only init; may mutate plain fp weights + ... + + @staticmethod + def group_init(model, targets, cfg, calibration_data) -> None: + # optional model-level init for data-calibrated or cross-layer variants + ... + + @staticmethod + def forward(layer, x, y) -> Tensor: + # return NEW output; additive variants return y + delta + ... +``` + +## Done means + +This repo is good enough for a first real experiment when: + +1. A Qwen/Llama model can attach LoRA adapters to intended target layers. +2. A 4bit or 8bit loaded model can train LoRA/DeLoRA params with nonzero gradients. +3. Saved adapter tensors use full-path keys and reload without calibration data. +4. Smoke tests distinguish target-skipping, hook identity drift, and missing-key load failure. diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..48aa034 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,20 @@ +[project] +name = "lora-lite" +version = "0.0.1" +description = "Hackable forward-hook LoRA. ~600 LoC. One file per variant." +requires-python = ">=3.10" +dependencies = [ + "torch>=2.1", + "einops>=0.7", + "jaxtyping>=0.2.25", +] + +[project.optional-dependencies] +test = ["pytest"] + +[build-system] +requires = ["setuptools>=68"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/src/lora_lite/__init__.py b/src/lora_lite/__init__.py new file mode 100644 index 0000000..da02615 --- /dev/null +++ b/src/lora_lite/__init__.py @@ -0,0 +1,16 @@ +from .config import LoraLiteConfig +from .adapter import attach, detach, save, load +from .variant import REGISTRY, register, ParamSpec, Variant +from . 
# Attribute on the root model holding attach() bookkeeping (cfg, target names, hook handles).
_ATTACHED_ATTR = "_lora_lite_attached"


def _hook(layer, args, y):
    """Forward hook: let the layer's variant compute the layer's new output.

    Args:
        layer: the hooked module; carries ``_lora_cfg`` and ``_lora_variant``.
        args: positional inputs of the layer's forward; exactly one is supported.
        y: the output the base layer already produced.

    Returns:
        The variant's output, cast back to the dtype of ``y``.
    """
    (x,) = args  # unpacking fails loudly if the layer got more than one positional input
    cfg: LoraLiteConfig = layer._lora_cfg
    # Only the input is cast to the adapter dtype; `y` is handed over untouched so the
    # base output is never rounded to cfg.dtype before the variant sees it.
    out = layer._lora_variant.forward(layer, x.to(cfg.dtype), y)
    return out.to(y.dtype)


def attach(model: nn.Module, cfg: LoraLiteConfig, calibration_data=None) -> list[RemovableHandle]:
    """Freeze ``model``, add the variant's ``lora_*`` params to every target, hook them.

    Order of operations: validate -> freeze every existing parameter -> per target
    create params and run ``variant.init`` -> optional ``variant.group_init`` ->
    register forward hooks.

    Args:
        model: root module to adapt (mutated in place).
        cfg: adapter configuration; ``cfg.variant`` selects the registered variant.
        calibration_data: passed through unchanged to ``variant.group_init`` when the
            variant defines one.

    Returns:
        One forward-hook handle per attached layer.

    Raises:
        KeyError: ``cfg.variant`` is not registered.
        RuntimeError: the model already has an adapter attached, no layer matched
            ``cfg``, or a target already has an attribute named like an adapter param.
    """
    if cfg.variant not in REGISTRY:
        raise KeyError(f"unknown variant {cfg.variant!r}; registered: {list(REGISTRY)}")
    # A second attach would overwrite the bookkeeping below, after which detach() could
    # no longer remove the first set of hooks. Refuse instead of corrupting state.
    if getattr(model, _ATTACHED_ATTR, None) is not None:
        raise RuntimeError("model already has an adapter attached; detach first")
    variant = REGISTRY[cfg.variant]
    targets = find_targets(model, cfg)
    if not targets:
        raise RuntimeError("no target layers matched cfg")

    # Resolve specs and check every target for name clashes BEFORE mutating anything,
    # so a clash on a late target cannot leave earlier targets half-attached.
    planned = []
    for name, layer, role in targets:
        specs = variant.param_specs(layer.in_features, layer.out_features, cfg)
        for pname in specs:
            if hasattr(layer, pname):
                raise RuntimeError(f"{name} already has attribute {pname}; detach first")
        planned.append((name, layer, role, specs))

    # freeze base (adapter params are created afterwards, so they keep their own flag)
    for p in model.parameters():
        p.requires_grad_(False)

    attached_names: list[str] = []
    attached_targets = []
    for name, layer, role, specs in planned:
        for pname, spec in specs.items():
            layer.register_parameter(pname, spec.make(cfg.dtype, layer.weight.device))
        layer._lora_cfg = cfg
        layer._lora_variant = variant
        layer._lora_role = role
        variant.init(layer, cfg)
        attached_names.append(name)
        attached_targets.append((name, layer, role))

    # Model-level init runs after params exist but before adapter hooks are registered.
    group_init = getattr(variant, "group_init", None)
    if group_init is not None:
        group_init(model, attached_targets, cfg, calibration_data)

    handles: list[RemovableHandle] = [
        layer.register_forward_hook(_hook) for _, layer, _ in attached_targets
    ]
    setattr(model, _ATTACHED_ATTR, {"cfg": cfg, "targets": attached_names, "handles": handles})
    return handles


def detach(model: nn.Module) -> None:
    """Remove hooks, adapter params and scratch attributes added by ``attach``.

    No-op when nothing is attached. This does not touch ``requires_grad`` on the base
    parameters and does not undo weight changes a variant's ``init`` made.
    """
    state = getattr(model, _ATTACHED_ATTR, None)
    if state is None:
        return
    for h in state["handles"]:
        h.remove()
    # remove params + scratch attrs
    for layer in model.modules():
        variant = getattr(layer, "_lora_variant", None)
        if variant is None:
            continue
        # param_specs is called only for the param names; nothing is allocated.
        for pname in variant.param_specs(layer.in_features, layer.out_features, layer._lora_cfg):
            layer._parameters.pop(pname, None)
        for attr in ("_lora_cfg", "_lora_variant", "_lora_role"):
            if hasattr(layer, attr):
                delattr(layer, attr)
    delattr(model, _ATTACHED_ATTR)


def save(model: nn.Module, path: str) -> None:
    """Write ``{"cfg": ..., "state": ...}`` to ``path`` with ``torch.save``.

    ``state`` holds every ``state_dict()`` entry whose key contains ``"lora_"``,
    detached and moved to CPU.

    Raises:
        RuntimeError: no adapter is attached to ``model``.
    """
    state = getattr(model, _ATTACHED_ATTR, None)
    if state is None:
        raise RuntimeError("no adapter attached; call attach() first")
    sd = {k: v.detach().cpu() for k, v in model.state_dict().items() if "lora_" in k}
    torch.save({"cfg": state["cfg"].to_dict(), "state": sd}, path)


def load(model: nn.Module, path: str) -> list[RemovableHandle]:
    """Attach an adapter using the saved cfg, then load the saved ``lora_*`` tensors.

    ``attach`` is called without calibration data. If loading or key validation
    fails, the freshly attached adapter is detached again before the error
    propagates, so the model is not left with a randomly initialised adapter.
    Weight changes made by a variant's ``init`` are not reverted by that cleanup.

    Returns:
        The hook handles created by ``attach``.

    Raises:
        RuntimeError: the checkpoint lacks a ``lora_*`` key the model expects, or
            contains a ``lora_*`` key the model does not have.
    """
    blob = torch.load(path, weights_only=True, map_location="cpu")
    cfg = LoraLiteConfig.from_dict(blob["cfg"])
    handles = attach(model, cfg)  # creates params with the right shapes
    try:
        # strict=False because the checkpoint holds adapter keys only; the base keys
        # are all "missing" by design, so adapter keys are validated by hand below.
        missing, unexpected = model.load_state_dict(blob["state"], strict=False)
        expected_lora = {k for k in model.state_dict() if "lora_" in k}
        missing_lora = sorted(expected_lora.intersection(missing))
        if missing_lora:
            raise RuntimeError(f"missing lora keys in checkpoint: {missing_lora}")
        unexpected_lora = [k for k in unexpected if "lora_" in k]
        if unexpected_lora:
            raise RuntimeError(f"unexpected lora keys in checkpoint: {unexpected_lora}")
    except Exception:
        detach(model)
        raise
    return handles
a/src/lora_lite/config.py b/src/lora_lite/config.py new file mode 100644 index 0000000..3558783 --- /dev/null +++ b/src/lora_lite/config.py @@ -0,0 +1,38 @@ +from dataclasses import dataclass, field, asdict +from typing import Any +import torch + + +@dataclass +class LoraLiteConfig: + variant: str = "lora" + r: int = 8 + alpha: float = 16.0 + dropout: float = 0.0 # currently ignored; variants may use cfg.variant_kwargs + dtype: torch.dtype = torch.bfloat16 + + # targeting + target_roles: tuple[str, ...] = ("reader", "writer") + target_names: tuple[str, ...] = () + exclude_names: tuple[str, ...] = ("lm_head", "embed_tokens") + layers: tuple[int, ...] | None = None + + # variant-specific bag (e.g. lambda0 for DeLoRA) + variant_kwargs: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict: + d = asdict(self) + d["dtype"] = str(self.dtype).removeprefix("torch.") + return d + + @classmethod + def from_dict(cls, d: dict) -> "LoraLiteConfig": + d = dict(d) + if isinstance(d.get("dtype"), str): + d["dtype"] = getattr(torch, d["dtype"]) + if isinstance(d.get("layers"), list): + d["layers"] = tuple(d["layers"]) + for k in ("target_roles", "target_names", "exclude_names"): + if isinstance(d.get(k), list): + d[k] = tuple(d[k]) + return cls(**d) diff --git a/src/lora_lite/target.py b/src/lora_lite/target.py new file mode 100644 index 0000000..2309585 --- /dev/null +++ b/src/lora_lite/target.py @@ -0,0 +1,59 @@ +"""Find linear-like targets by shape (reader/writer/inner) + name regex. + +Structural matching is deliberate: bnb Linear4bit/8bitLt are not nn.Linear, but +they expose in_features/out_features/weight and their forward already handles +dequantization. 
+""" +import re +from torch import nn + + +def is_linear_like(m: nn.Module) -> bool: + return ( + hasattr(m, "in_features") + and hasattr(m, "out_features") + and hasattr(m, "weight") + and callable(m) + ) + + +def _layer_idx(name: str) -> int | None: + m = re.search(r"\.layers?\.(\d+)\.", name) + return int(m.group(1)) if m else None + + +def _classify(m: nn.Module, d_model: int, name: str) -> str: + di, do = m.in_features, m.out_features + if di == d_model and do != d_model: + return "reader" + if do == d_model and di != d_model: + return "writer" + if di == d_model and do == d_model: + return "writer" if any(s in name for s in ("o_proj", "out_proj")) else "reader" + return "inner" + + +def find_targets(model: nn.Module, cfg) -> list[tuple[str, nn.Module, str]]: + # discover d_model: prefer config.hidden_size, else infer from largest Linear in_features + d_model = getattr(getattr(model, "config", None), "hidden_size", None) + if d_model is None: + dims = [m.in_features for m in model.modules() if is_linear_like(m)] + d_model = max(dims) if dims else 0 + + out = [] + for name, m in model.named_modules(): + if not is_linear_like(m): + continue + if any(re.search(p, name) for p in cfg.exclude_names): + continue + if cfg.layers is not None: + li = _layer_idx(name) + if li is None or li not in cfg.layers: + continue + role = _classify(m, d_model, name) + if cfg.target_roles and role not in cfg.target_roles: + continue + if cfg.target_names and not any(re.search(p, name) for p in cfg.target_names): + continue + out.append((name, m, role)) + return out diff --git a/src/lora_lite/variant.py b/src/lora_lite/variant.py new file mode 100644 index 0000000..dacb844 --- /dev/null +++ b/src/lora_lite/variant.py @@ -0,0 +1,56 @@ +"""Variant protocol + registry. 
Variants own (x, layer.weight, layer.lora_*) -> y_new.""" +from dataclasses import dataclass +from typing import Callable, Protocol, Any +import torch +from torch import nn + +from .config import LoraLiteConfig + + +@dataclass +class ParamSpec: + shape: tuple[int, ...] + init: str | Callable[[torch.Tensor], None] = "zeros" # 'zeros'|'kaiming'|'ones'|callable(t) + trainable: bool = True + + def make(self, dtype: torch.dtype, device) -> nn.Parameter: + t = torch.empty(self.shape, dtype=dtype, device=device) + if callable(self.init): + self.init(t) + elif self.init == "zeros": + t.zero_() + elif self.init == "ones": + t.fill_(1.0) + elif self.init == "kaiming": + # match nn.Linear default: kaiming_uniform_(a=sqrt(5)) + nn.init.kaiming_uniform_(t, a=5 ** 0.5) if t.ndim >= 2 else t.normal_(0, 0.02) + else: + raise ValueError(f"unknown init: {self.init}") + return nn.Parameter(t, requires_grad=self.trainable) + + +class Variant(Protocol): + name: str + + @staticmethod + def param_specs(d_in: int, d_out: int, cfg: LoraLiteConfig) -> dict[str, ParamSpec]: ... + + @staticmethod + def init(layer: nn.Linear, cfg: LoraLiteConfig) -> None: ... + + @staticmethod + def forward(layer: nn.Linear, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """Return the layer's NEW output (not just delta). + Additive variants: `return y + delta`. + Replacing variants: ignore `y`, return new value.""" + ... + + +REGISTRY: dict[str, type] = {} + + +def register(cls): + if not getattr(cls, "name", None): + raise ValueError(f"variant {cls} missing .name") + REGISTRY[cls.name] = cls + return cls diff --git a/src/lora_lite/variants/__init__.py b/src/lora_lite/variants/__init__.py new file mode 100644 index 0000000..ea22188 --- /dev/null +++ b/src/lora_lite/variants/__init__.py @@ -0,0 +1 @@ +from . 
@register
class DeLoRA:
    """Additive adapter with normalised factors and a learned scalar strength.

    The delta is ``(lora_lambda / r) * normalize(B, dim=0) @ normalize(A, dim=1) @ x``,
    i.e. each row of A and each column of B is rescaled to unit norm before use.
    ``lora_lambda`` starts at ``cfg.variant_kwargs["lambda0"]`` (default 0.0); at 0 the
    scale is 0, so the adapted output equals the base output at init.
    """

    name = "delora"

    @staticmethod
    def param_specs(d_in, d_out, cfg):
        """Specs for A (r, d_in), B (d_out, r) and the scalar ``lora_lambda``."""
        lam0 = float(cfg.variant_kwargs.get("lambda0", 0.0))

        def _fill_lambda(t):
            # constant init for the 0-dim lambda parameter
            t.fill_(lam0)

        spec_a = ParamSpec((cfg.r, d_in), init="kaiming", trainable=True)
        spec_b = ParamSpec((d_out, cfg.r), init="kaiming", trainable=True)
        spec_lambda = ParamSpec((), init=_fill_lambda, trainable=True)
        return {"lora_A": spec_a, "lora_B": spec_b, "lora_lambda": spec_lambda}

    @staticmethod
    def init(layer: nn.Linear, cfg) -> None:
        """Nothing to do: every parameter is fully initialised by its ParamSpec."""
        return

    @staticmethod
    def forward(layer: nn.Linear, x, y):
        """Return ``y`` plus the normalised low-rank delta scaled by ``lora_lambda / r``."""
        unit_rows_a = F.normalize(layer.lora_A, dim=1)  # (r, d_in)
        unit_cols_b = F.normalize(layer.lora_B, dim=0)  # (d_out, r)
        strength = layer.lora_lambda / layer._lora_cfg.r
        low_rank = einsum(x, unit_rows_a, "... i, r i -> ... r")
        delta = einsum(low_rank, unit_cols_b, "... r, o r -> ... o")
        return y + strength * delta
@register
class PiSSA:
    """PiSSA: seed A, B from the top-r SVD of W and store the residual in ``layer.weight``.

    After ``init`` the layer holds ``W_res = W - (alpha / r) * B @ A`` and ``forward``
    adds ``(alpha / r) * B @ A @ x`` back, so the output at init matches the original
    layer up to rounding. ``init`` overwrites ``layer.weight`` in place: running it a
    second time on the same layer would decompose the residual, not the original W.
    """

    name = "pissa"

    @staticmethod
    def param_specs(d_in, d_out, cfg):
        """A is (r, d_in), B is (d_out, r); both are zero placeholders until ``init``."""
        return {
            "lora_A": ParamSpec((cfg.r, d_in), init="zeros", trainable=True),
            "lora_B": ParamSpec((d_out, cfg.r), init="zeros", trainable=True),
        }

    @staticmethod
    def init(layer: nn.Linear, cfg) -> None:
        """Fill ``lora_A`` / ``lora_B`` from the SVD of the weight and replace the weight by W_res.

        Raises:
            TypeError: ``layer`` is not exactly ``nn.Linear``.
            ValueError: ``cfg.r`` is not in ``1..min(d_out, d_in)``.
        """
        # Exact type check on purpose: anything that is not a plain nn.Linear is rejected,
        # because the weight is read and overwritten as an ordinary dense tensor below.
        if type(layer) is not nn.Linear:
            raise TypeError(
                "PiSSA mutates layer.weight into W_res, so v1 only supports plain nn.Linear. "
                "For bnb 4/8-bit, use LoRA/DeLoRA or implement explicit dequantize/requantize."
            )
        r = cfg.r
        max_rank = min(layer.weight.shape)
        # The reduced SVD has only min(d_out, d_in) components; a larger r would yield
        # factors narrower than the (d_out, r) / (r, d_in) parameters they are copied into.
        if not 1 <= r <= max_rank:
            raise ValueError(
                f"PiSSA needs 1 <= r <= min(d_out, d_in) = {max_rank}, got r={r}"
            )
        W = layer.weight.data.float()  # (d_out, d_in)
        U, S, Vh = torch.linalg.svd(W, full_matrices=False)
        sqrtS = S[:r].sqrt()
        # B @ A = Ur diag(Sr) Vhr; pick B = Ur sqrt(Sr), A = sqrt(Sr) * Vhr
        B = (U[:, :r] * sqrtS).to(cfg.dtype)
        A = (sqrtS[:, None] * Vh[:r, :]).to(cfg.dtype)
        layer.lora_B.data.copy_(B)
        layer.lora_A.data.copy_(A)
        # Subtract the product of the *rounded* factors, computed in fp32, so that the
        # residual cancels exactly what forward() adds back.
        BA = B.float() @ A.float()
        scale = cfg.alpha / cfg.r
        layer.weight.data.copy_((W - scale * BA).to(layer.weight.dtype))

    @staticmethod
    def forward(layer: nn.Linear, x, y):
        """Return ``y + (alpha / r) * B @ A @ x``; ``y`` was computed with the residual weight."""
        cfg = layer._lora_cfg
        scale = cfg.alpha / cfg.r
        h = einsum(x, layer.lora_A, "... i, r i -> ... r")
        delta = einsum(h, layer.lora_B, "... r, o r -> ... o")
        return y + scale * delta
# ---- a tiny transformer-like stack: 4 blocks of (q,k,v,o, gate,up,down) Linears ----
class TinyBlock(nn.Module):
    """One residual block: a sum-of-projections "attention" path plus a gated MLP path."""

    def __init__(self, d=64, ff=128):
        super().__init__()
        # Creation order (q, k, v, o, gate, up, down) fixes both the state_dict key
        # order and the order in which the RNG is consumed.
        for proj in ("q_proj", "k_proj", "v_proj", "o_proj"):
            setattr(self, proj, nn.Linear(d, d, bias=False))
        for proj in ("gate_proj", "up_proj"):
            setattr(self, proj, nn.Linear(d, ff, bias=False))
        self.down_proj = nn.Linear(ff, d, bias=False)

    def forward(self, x):
        """Return ``x + o(q(x) + k(x) + v(x)) + down(silu(gate(x)) * up(x))``."""
        mixed = self.q_proj(x) + self.k_proj(x) + self.v_proj(x)
        attn_out = self.o_proj(mixed)
        gated = torch.nn.functional.silu(self.gate_proj(x)) * self.up_proj(x)
        mlp_out = self.down_proj(gated)
        return x + attn_out + mlp_out


class TinyModel(nn.Module):
    """Embedding -> ``n_layers`` TinyBlocks -> vocabulary head."""

    def __init__(self, n_layers=4, d=64, ff=128, vocab=100):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, d)
        blocks = []
        for _ in range(n_layers):
            blocks.append(TinyBlock(d, ff))
        self.layers = nn.ModuleList(blocks)
        self.lm_head = nn.Linear(d, vocab, bias=False)
        # mimic HF .config.hidden_size
        self.config = type("Cfg", (), {"hidden_size": d})()

    def forward(self, ids):
        """Map token ids to logits of shape ``ids.shape + (vocab,)``."""
        hidden = self.embed_tokens(ids)
        for block in self.layers:
            hidden = block(hidden)
        return self.lm_head(hidden)
{"hidden_size": 8})() + self.layers = nn.ModuleList([FakeLinearLike(8, 8)]) + + def forward(self, x): + return self.layers[0](x) + + +def variant_test(variant: str, dtype=torch.float32): + print(f"\n=== variant={variant} dtype={dtype} ===") + torch.manual_seed(0) + model = TinyModel().to(dtype) + ids = torch.randint(0, 100, (2, 16)) + + with torch.no_grad(): + y_base = model(ids).clone() + + cfg = ll.LoraLiteConfig( + variant=variant, + r=4, + alpha=4 if variant == "pissa" else 8, # PiSSA needs scale==1 for clean recon + dtype=dtype, + # delora identity-at-init demands lambda0==0 (then delta * scale = 0) + variant_kwargs={"lambda0": 0.0} if variant == "delora" else {}, + ) + handles = ll.attach(model, cfg) + n_targets = len(handles) + n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(f" attached {n_targets} targets, trainable params={n_trainable}") + + with torch.no_grad(): + y_adapt = model(ids) + err = (y_adapt - y_base).abs().max().item() + base_scale = y_base.abs().max().item() + print(f" t=0 identity: max|y_adapt - y_base| = {err:.3e} (base scale {base_scale:.3e})") + + # variant-specific identity tolerance + tol = { + "lora": 1e-6, + "pissa": 5e-4, # SVD recon in fp32 is tight; bf16 would be ~1e-2 + "delora": 1e-6, # lambda0=0 + }[variant] * max(1.0, base_scale) + assert err < tol, f" FAIL identity: err {err} > tol {tol}" + print(f" SHOULD: err<{tol:.1e}. PASS.") + + # save/load round-trip + with tempfile.TemporaryDirectory() as d: + p = os.path.join(d, "adapter.pt") + ll.save(model, p) + # detach + fresh model + load + ll.detach(model) + torch.manual_seed(0) + model2 = TinyModel().to(dtype) + # for PiSSA, base weights got mutated; we need them mutated again for the load + # path to make sense. Easiest: re-attach with same cfg first... but that's what + # load() does. The catch: load reads cfg from the file, runs attach (which + # re-runs PiSSA init -> same SVD on same weights -> same A,B -> mutates W + # to the same W_res). 
Then state_dict overwrites lora_A/B with saved values. + ll.load(model2, p) + with torch.no_grad(): + y_loaded = model2(ids) + err2 = (y_loaded - y_adapt).abs().max().item() + print(f" save/load: max|y_loaded - y_adapt| = {err2:.3e}") + assert err2 < tol, f" FAIL save/load: {err2} > {tol}" + print(f" SHOULD: err2<{tol:.1e}. PASS.") + ll.detach(model2) + + # gradient flow: 20 SGD steps on random target. + # For delora, lambda0==0 makes A,B grads zero (scale=0); use lambda0=0.1 for training. + torch.manual_seed(0) + model = TinyModel().to(dtype) + train_cfg = cfg + if variant == "delora": + train_cfg = ll.LoraLiteConfig( + variant=cfg.variant, r=cfg.r, alpha=cfg.alpha, dtype=cfg.dtype, + variant_kwargs={"lambda0": 0.1}, + ) + ll.attach(model, train_cfg) + target = torch.randn(2, 16, 100, dtype=dtype) * 0.1 + trainable = [p for p in model.parameters() if p.requires_grad] + # delora has tightly-normalised updates; use Adam with higher lr to see signal in 20 steps + if variant == "delora": + opt = torch.optim.Adam(trainable, lr=1e-1) + else: + opt = torch.optim.SGD(trainable, lr=1e-2) + losses = [] + for step in range(20): + opt.zero_grad() + loss = (model(ids) - target).pow(2).mean() + loss.backward() + opt.step() + losses.append(loss.item()) + drop = (losses[0] - losses[-1]) / max(losses[0], 1e-12) + print(f" loss[0]={losses[0]:.4f} loss[-1]={losses[-1]:.4f} drop={100*drop:.1f}%") + assert drop > 0.05, f" FAIL: loss drop only {drop:.2%}, expected >5%" + print(f" SHOULD: drop>5%. 
PASS.") + + +def structural_linear_like_test(): + print("\n=== structural linear-like target test (bnb-style, not nn.Linear) ===") + torch.manual_seed(0) + model = FakeBnbModel() + x = torch.randn(2, 3, 8) + y_base = model(x).detach() + ll.attach(model, ll.LoraLiteConfig(variant="lora", r=2, alpha=4, dtype=torch.float32, target_roles=())) + layer = model.layers[0] + assert hasattr(layer, "lora_A") and hasattr(layer, "lora_B") + y = model(x) + err = (y.detach() - y_base).abs().max().item() + loss = y.pow(2).mean() + loss.backward() + grad_nonzero = layer.lora_B.grad.abs().sum().item() > 0 + print(f" attached lora_A={tuple(layer.lora_A.shape)} lora_B={tuple(layer.lora_B.shape)}") + print(f" identity_err={err:.3e} grad_nonzero={grad_nonzero}") + assert err == 0.0 + assert grad_nonzero + print(" SHOULD: structural target attaches and lora_B receives grad. PASS.") + + +def bitsandbytes_cuda_smoke(): + print("\n=== optional bitsandbytes CUDA smoke ===") + if not torch.cuda.is_available(): + print(" SKIP: CUDA unavailable; real bnb 4/8-bit forward needs GPU on this machine.") + return + try: + import bitsandbytes as bnb + except ImportError: + print(" SKIP: bitsandbytes unavailable.") + return + + class BnbModel(nn.Module): + def __init__(self, Layer): + super().__init__() + self.config = type("Cfg", (), {"hidden_size": 8})() + self.layers = nn.ModuleList([Layer(8, 8, bias=False)]).cuda() + + def forward(self, x): + return self.layers[0](x) + + for layer_cls in (bnb.nn.Linear8bitLt, bnb.nn.Linear4bit): + torch.manual_seed(0) + model = BnbModel(layer_cls) + x = torch.randn(2, 3, 8, device="cuda") + y_base = model(x).detach() + ll.attach(model, ll.LoraLiteConfig(variant="lora", r=2, alpha=4, dtype=torch.float16, target_roles=())) + y = model(x) + err = (y.detach() - y_base).abs().max().item() + y.pow(2).mean().backward() + grad_nonzero = model.layers[0].lora_B.grad.abs().sum().item() > 0 + print(f" {layer_cls.__name__}: identity_err={err:.3e} grad_nonzero={grad_nonzero}") + 
assert err == 0.0 + assert grad_nonzero + + +def main(): + for v in ("lora", "pissa", "delora"): + variant_test(v, dtype=torch.float32) + structural_linear_like_test() + bitsandbytes_cuda_smoke() + print("\nALL PASS.") + + +if __name__ == "__main__": + main()