wassname 4db5cee5a9 init
2026-04-26 14:10:20 +08:00


# lora-lite
A hackable, single-file-per-variant LoRA library built on PyTorch forward hooks.
- ~600 LoC total
- One file per variant, ~50 LoC each
- No module replacement, no merge/unmerge, no PEFT config soup
- Save = `torch.save({cfg, state_dict})`, with the state dict filtered to keys containing `lora_`
- LoRA/DeLoRA forward hooks work with `nn.Linear` and bnb-style `Linear{4bit,8bitLt}` modules that expose `in_features`, `out_features`, and `weight`.
- PiSSA is fp-only in v1 because it mutates `weight` into `W_res`; quantized PiSSA needs explicit dequantize/requantize.
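
The hook mechanism behind these bullets can be sketched in plain PyTorch. This is a minimal sketch, not lora-lite's code. The helper names `attach_lora` and `lora_state_dict` are hypothetical. The `alpha / r` scaling and the Kaiming init for `lora_A` follow common LoRA practice and are assumptions here. The sketch shows three things: adapter tensors live directly on the target layer as `lora_*` attributes, a forward hook returns the adapted output, and removing the hook handle gives back the base forward without any merge/unmerge step.

```python
import math
import torch
import torch.nn as nn


def attach_lora(layer: nn.Linear, r: int = 4, alpha: float = 8.0):
    """Hypothetical helper: add LoRA params to `layer` and hook its forward."""
    layer.weight.requires_grad_(False)  # base weight stays frozen
    # Params assigned as attributes named lora_* appear in state_dict() automatically.
    layer.lora_A = nn.Parameter(torch.empty(r, layer.in_features))
    layer.lora_B = nn.Parameter(torch.zeros(layer.out_features, r))  # zero => identity at t=0
    nn.init.kaiming_uniform_(layer.lora_A, a=math.sqrt(5))
    scale = alpha / r

    def hook(module, args, output):
        x = args[0]
        # Returning a tensor from a forward hook replaces the module's output.
        return output + scale * (x @ module.lora_A.T @ module.lora_B.T)

    return layer.register_forward_hook(hook)


def lora_state_dict(model: nn.Module) -> dict:
    # Keep only adapter tensors, keyed by their full module path.
    return {k: v for k, v in model.state_dict().items() if "lora_" in k}


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))
x = torch.randn(4, 16)
y_base = model(x)

handles = [attach_lora(m) for m in model if isinstance(m, nn.Linear)]
y_adapted = model(x)
print(sorted(lora_state_dict(model)))
# ['0.lora_A', '0.lora_B', '2.lora_A', '2.lora_B']

for h in handles:
    h.remove()  # hook gone: forward is the base layer again (lora_* attrs remain)
```

Because no module is replaced, the same hook works for any module that exposes `in_features`, `out_features`, and `weight` and accepts `x` as its first positional input.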

Currently shipped variants:

| Variant | Class | File |
|---|---|---|
| LoRA | A (additive) | [src/lora_lite/variants/lora.py](src/lora_lite/variants/lora.py) |
| PiSSA ([Meng+ 2024](https://arxiv.org/abs/2404.02948)) | A + B (special init mutates W) | [src/lora_lite/variants/pissa.py](src/lora_lite/variants/pissa.py) |
| DeLoRA ([Bini+ 2025](https://arxiv.org/abs/2503.18225)) | A (additive, normalised) | [src/lora_lite/variants/delora.py](src/lora_lite/variants/delora.py) |

See [docs/spec/20260426_lora_lite_plan.md](docs/spec/20260426_lora_lite_plan.md) for goals, status, TODOs, and the current design plan. The original broader design was stress-tested against the [adapters_as_hypotheses](https://github.com/wassname/adapters_as_hypotheses) catalog (~26/27 variants covered with 3 small API tweaks).
## Install
```bash
pip install -e .
```
## Quickstart
```python
import torch, lora_lite as ll
model = MyTransformer() # any nn.Module containing linear-like children
cfg = ll.LoraLiteConfig(variant="lora", r=8, alpha=16, dtype=torch.bfloat16)
handles = ll.attach(model, cfg)
# train
trainable = [p for p in model.parameters() if p.requires_grad]
opt = torch.optim.AdamW(trainable, lr=1e-4)
# ... your loop ...
ll.save(model, "adapter.pt")
ll.detach(model)
# later:
ll.load(model, "adapter.pt")
```
Inspect a tensor live:
```python
A = model.layers[5].self_attn.q_proj.lora_A # just an nn.Parameter
```
## Targeting
By default we target linear-like modules (`in_features`, `out_features`, `weight`) whose shape matches a "reader" (`d_in == d_model`) or "writer" (`d_out == d_model`) role, excluding `lm_head` and `embed_tokens`. This structural test is what lets bnb Linear4bit/8bitLt modules be targeted without a backend-specific class. Knobs on `LoraLiteConfig`:
- `target_roles`: subset of `("reader", "writer", "inner")`. `()` = all.
- `target_names`: regex includes (must match if non-empty).
- `exclude_names`: regex excludes (default skips `lm_head`, `embed_tokens`).
- `layers`: tuple of layer indices, or `None` for all (matches `.layers.<idx>.` in module name).
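
The structural test and the filter knobs can be sketched as a single pass over `named_modules()`. This is a hypothetical re-implementation for illustration, not lora-lite's targeting code, and it makes three assumptions. First, a square `d_model -> d_model` layer (such as `q_proj`) is tagged `"reader"`. Second, `"inner"` means neither side has width `d_model`. Third, the default role filter is `("reader", "writer")`. `classify_role`, `select_targets`, and the toy model are all made-up names.

```python
import re
import torch.nn as nn


def classify_role(d_in: int, d_out: int, d_model: int) -> str:
    # Assumption: a square d_model -> d_model layer is tagged "reader".
    if d_in == d_model:
        return "reader"
    if d_out == d_model:
        return "writer"
    return "inner"  # assumed: neither side has the residual width


def select_targets(model, d_model, target_roles=("reader", "writer"),
                   target_names=(), exclude_names=(r"lm_head", r"embed_tokens"),
                   layers=None):
    targets = []
    for name, mod in model.named_modules():
        # Structural test: anything exposing these three attributes is linear-like.
        if not all(hasattr(mod, a) for a in ("in_features", "out_features", "weight")):
            continue
        if any(re.search(p, name) for p in exclude_names):
            continue
        if target_names and not any(re.search(p, name) for p in target_names):
            continue
        if layers is not None:
            # Pad with dots so a leading "layers.<idx>." also matches.
            m = re.search(r"\.layers\.(\d+)\.", f".{name}.")
            if m is None or int(m.group(1)) not in layers:
                continue
        role = classify_role(mod.in_features, mod.out_features, d_model)
        if target_roles and role not in target_roles:
            continue
        targets.append((name, mod, role))
    return targets


class Block(nn.Module):
    def __init__(self, d, h):
        super().__init__()
        self.up = nn.Linear(d, h)    # d_model -> h: reader
        self.mid = nn.Linear(h, h)   # h -> h: inner
        self.down = nn.Linear(h, d)  # h -> d_model: writer


class Toy(nn.Module):
    def __init__(self, d=8, h=32, n_layers=2, vocab=100):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, d)  # no in_features: never linear-like
        self.layers = nn.ModuleList([Block(d, h) for _ in range(n_layers)])
        self.lm_head = nn.Linear(d, vocab)          # linear-like, but excluded by name


toy = Toy()
print([(n, r) for n, _, r in select_targets(toy, d_model=8)])
# [('layers.0.up', 'reader'), ('layers.0.down', 'writer'), ('layers.1.up', 'reader'), ('layers.1.down', 'writer')]
print([(n, r) for n, _, r in select_targets(toy, d_model=8, target_roles=(), layers=(1,))])
# [('layers.1.up', 'reader'), ('layers.1.mid', 'inner'), ('layers.1.down', 'writer')]
```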
## Variant API
A variant is a class with a `name` attribute and three static methods (`init` is optional):
```python
from torch import Tensor
# `register` and `ParamSpec` are provided by lora_lite.

@register
class MyVariant:
    name = "myvariant"

    @staticmethod
    def param_specs(d_in, d_out, cfg) -> dict[str, ParamSpec]:
        # One entry per adapter tensor; add more lora_* entries as needed.
        return {"lora_A": ParamSpec((cfg.r, d_in), init="kaiming")}

    @staticmethod
    def init(layer, cfg) -> None:
        # Optional. Run after params are created. May read/mutate layer.weight.
        ...

    @staticmethod
    def forward(layer, x, y) -> Tensor:
        # x: layer input, y: base layer output.
        # Return the layer's NEW output (additive: `return y + delta`).
        ...
```
Adapter params attached as `layer.lora_*` get full-path keys in `state_dict()` automatically (e.g. `model.layers.5.self_attn.q_proj.lora_A`).
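
As a concrete illustration of the three statics, here is a PiSSA-style variant. `init` takes the rank-`r` SVD of `W` and puts the principal part into `lora_A`/`lora_B`. It then overwrites `weight` with the residual `W_res = W - B @ A`, so the output is unchanged at t=0. This is a sketch, not the contents of `src/lora_lite/variants/pissa.py`. `ParamSpec`, `register`, `Cfg`, and `attach_one` are hypothetical stand-ins whose fields and behaviour are assumed. Scaling is omitted (assumed equivalent to `alpha == r`).

```python
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch import Tensor


@dataclass
class ParamSpec:  # stand-in; the real ParamSpec's fields are assumed
    shape: tuple
    init: str = "zeros"


REGISTRY = {}


def register(cls):  # stand-in registry decorator
    REGISTRY[cls.name] = cls
    return cls


@dataclass
class Cfg:  # hypothetical minimal config
    r: int = 4


@register
class PissaSketch:
    name = "pissa_sketch"

    @staticmethod
    def param_specs(d_in, d_out, cfg) -> dict:
        # Values are overwritten by init(), so the init tag is a placeholder.
        return {"lora_A": ParamSpec((cfg.r, d_in)), "lora_B": ParamSpec((d_out, cfg.r))}

    @staticmethod
    def init(layer, cfg) -> None:
        W = layer.weight.detach().float().clone()
        U, S, Vh = torch.linalg.svd(W, full_matrices=False)
        sqrt_s = S[: cfg.r].sqrt()
        A = sqrt_s[:, None] * Vh[: cfg.r]      # (r, d_in)
        B = U[:, : cfg.r] * sqrt_s[None, :]    # (d_out, r)
        layer.lora_A.data.copy_(A)
        layer.lora_B.data.copy_(B)
        # Mutate weight into the residual: W_res x + B A x == W x at t=0.
        layer.weight.data.copy_((W - B @ A).to(layer.weight.dtype))

    @staticmethod
    def forward(layer, x, y) -> Tensor:
        # y was computed with W_res; add the principal low-rank part back.
        return y + x @ layer.lora_A.T @ layer.lora_B.T


def attach_one(layer, variant, cfg):
    """Hypothetical driver: create params from specs, run init, hook forward."""
    for pname, spec in variant.param_specs(layer.in_features, layer.out_features, cfg).items():
        setattr(layer, pname, nn.Parameter(torch.zeros(spec.shape)))
    variant.init(layer, cfg)
    return layer.register_forward_hook(lambda m, args, y: variant.forward(m, args[0], y))


torch.manual_seed(0)
layer = nn.Linear(16, 12)
x = torch.randn(5, 16)
y_before = layer(x)
W_orig = layer.weight.detach().clone()
handle = attach_one(layer, REGISTRY["pissa_sketch"], Cfg(r=4))
y_after = layer(x)
print(torch.allclose(y_before, y_after, atol=1e-5))  # True: identity at t=0
```

Because `init` mutates `weight`, removing the hook alone does not restore the base layer. The README's note that PiSSA needs a base reload (or explicit dequantize/requantize for quantized weights) follows from this.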
## Data-calibrated init
LoRA, PiSSA, and DeLoRA need at most `layer.weight` for init, so no calibration data is needed.

For variants that DO need data (e.g. AntiPaSTO, LoRA-GA, activation-aware SVD), keep dataloaders out of `cfg` so adapter checkpoints stay serializable. Instead, pass the data to `attach`:
```python
ll.attach(model, cfg, calibration_data=calib)
```
where `calib` is an iterable of whole-model inputs, e.g. `Iterable[dict[str, Tensor]]` for HF models or `Iterable[Tensor]` of token ids. Activation-aware variants implement:
```python
@staticmethod
def group_init(model, targets, cfg, calibration_data): ...
```
`targets` is `list[(name, layer, role)]`. The variant adds temporary hooks, runs `model(batch)` over `calibration_data`, removes the hooks, then writes `lora_*` params. Per-layer `init(layer, cfg)` stays weight-only.

Sketch:
```python
import torch

@register
class ActSVD:
    name = "actsvd"

    @staticmethod
    def param_specs(d_in, d_out, cfg): ...

    @staticmethod
    def group_init(model, targets, cfg, calibration_data):
        # Record each target's input activations with temporary pre-hooks.
        bufs = {name: [] for name, _, _ in targets}
        hooks = [
            layer.register_forward_pre_hook(
                lambda m, args, name=name: bufs[name].append(args[0].detach().float())
            )
            for name, layer, _ in targets
        ]
        try:
            with torch.no_grad():
                for batch in calibration_data:
                    if isinstance(batch, dict):
                        model(**batch)
                    else:
                        model(batch)
        finally:
            for h in hooks:
                h.remove()
        # For each target: flatten to (tokens, d_in), e.g.
        # X = torch.cat([b.reshape(-1, b.shape[-1]) for b in bufs[name]])
        # then do the SVD and write lora_A / lora_B.
```
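
The last step of the sketch ("do SVD; write A/B") can be filled in many ways. The version below makes one simple, hypothetical choice. It sets `lora_A` to the top-`r` right singular vectors of the recorded inputs (the input directions with the most energy in the calibration data) and zeroes `lora_B` so the model is unchanged at t=0. This is not AntiPaSTO, LoRA-GA, or lora-lite's own ActSVD. `actsvd_group_init` and its `r` argument are illustrative names, and it writes params by plain attribute assignment rather than through `param_specs`.

```python
import torch
import torch.nn as nn


def actsvd_group_init(model, targets, r, calibration_data):
    """Hypothetical: lora_A <- top-r input directions seen during calibration."""
    bufs = {name: [] for name, _, _ in targets}
    hooks = [
        layer.register_forward_pre_hook(
            lambda m, args, name=name: bufs[name].append(args[0].detach().float())
        )
        for name, layer, _ in targets
    ]
    try:
        with torch.no_grad():
            for batch in calibration_data:
                if isinstance(batch, dict):
                    model(**batch)
                else:
                    model(batch)
    finally:
        for h in hooks:
            h.remove()

    for name, layer, _ in targets:
        # Flatten (batch, seq, d_in) activations to (tokens, d_in) before the SVD.
        X = torch.cat([b.reshape(-1, b.shape[-1]) for b in bufs[name]], dim=0)
        _, _, Vh = torch.linalg.svd(X, full_matrices=False)
        dtype = layer.weight.dtype
        layer.lora_A = nn.Parameter(Vh[:r].to(dtype))  # (r, d_in), orthonormal rows
        # Zero B keeps the adapted output equal to the base output at t=0.
        layer.lora_B = nn.Parameter(torch.zeros(layer.out_features, r, dtype=dtype))


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))
targets = [("0", model[0], "reader"), ("2", model[2], "writer")]
calib = [torch.randn(3, 5, 8) for _ in range(4)]  # whole-model inputs (batch, seq, d_model)
actsvd_group_init(model, targets, r=4, calibration_data=calib)
print(model[0].lora_A.shape, model[2].lora_A.shape)
# torch.Size([4, 8]) torch.Size([4, 16])
```

The `try`/`finally` ensures the temporary hooks are removed even if a calibration batch raises. Without it, a failed calibration run would leave recording hooks on the model.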
## Smoke test
```bash
python tests/smoke.py
```
Verifies for each of `lora`, `pissa`, `delora`:
1. Identity at t=0: `max|y_adapter - y_base|` within float tolerance.
2. Save/load round-trip preserves outputs.
3. 20 SGD steps reduce a random regression loss by >5%.
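
Checks 1 and 2 can be reproduced in isolation with a toy two-layer model. This is a standalone sketch, not `tests/smoke.py`. The minimal LoRA hook, the perturbation that stands in for training, and the checkpoint layout `{"cfg": ..., "state": ...}` are all assumptions made for illustration.

```python
import io
import math
import torch
import torch.nn as nn


def base_model(seed: int = 0) -> nn.Module:
    torch.manual_seed(seed)  # same seed => identical base weights
    return nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))


def adapted_model(seed: int = 0, r: int = 4) -> nn.Module:
    model = base_model(seed)
    for layer in model:
        if isinstance(layer, nn.Linear):
            # Hypothetical minimal LoRA: random A, zero B, additive forward hook.
            layer.lora_A = nn.Parameter(torch.randn(r, layer.in_features) / math.sqrt(layer.in_features))
            layer.lora_B = nn.Parameter(torch.zeros(layer.out_features, r))
            layer.register_forward_hook(lambda m, args, y: y + args[0] @ m.lora_A.T @ m.lora_B.T)
    return model


x = torch.randn(8, 16, generator=torch.Generator().manual_seed(1))

# 1. Identity at t=0: B == 0, so the adapter adds exactly nothing.
y_base = base_model()(x)
model = adapted_model()
max_err = (model(x) - y_base).abs().max().item()
print(max_err)  # 0.0

# Stand-in for training: move lora_B away from zero.
with torch.no_grad():
    for name, p in model.named_parameters():
        if "lora_B" in name:
            p.add_(0.1 * torch.randn_like(p))
y_trained = model(x)

# 2. Save/load round-trip of the lora_* tensors through an in-memory file.
buf = io.BytesIO()
state = {k: v for k, v in model.state_dict().items() if "lora_" in k}
torch.save({"cfg": {"r": 4}, "state": state}, buf)
buf.seek(0)
ckpt = torch.load(buf)
fresh = adapted_model()
result = fresh.load_state_dict(ckpt["state"], strict=False)  # base keys are "missing" by design
y_loaded = fresh(x)
print(result.unexpected_keys, torch.allclose(y_loaded, y_trained))  # [] True
```

`strict=False` is needed because the checkpoint deliberately omits the base weights. Checking `unexpected_keys` still catches adapter keys that do not match the model.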
## What's NOT in v1
| Feature | Why dropped |
|---|---|
| merge/unmerge | reload base if you want vanilla |
| 4/8-bit-aware merge | DoRA on bnb supported in forward only (drop merge path) |
| Embedding / Conv adapters | trivial extension; add when needed |
| `adapter_names=` mixed batch forward | rare; add when needed |
| Multiple named adapters per layer | one variant per `attach()` |
| HF `PeftConfig` / hub upload | `torch.save({cfg, state})` is enough |
| AdaLoRA-style rank scheduling | needs `Variant.on_step(step)` -- punt |
| ReFT-style position interventions | sibling submodule (different hook site) |
## Status
v0.0.1: lora + pissa + delora + smoke test. See spec for next variants (DoRA, VeRA, SSVD).