mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 18:24:15 +08:00
# lora-lite

A hackable, single-file-per-variant LoRA library built on PyTorch forward hooks.

- ~600 LoC total
- One file per variant, ~50 LoC each
- No module replacement, no merge/unmerge, no PEFT config soup
- Save = `torch.save` of `cfg` plus the `state_dict()` entries whose keys contain `lora_`
- LoRA/DeLoRA forward hooks work with `nn.Linear` and with bnb-style `Linear4bit`/`Linear8bitLt` modules, i.e. any module that exposes `in_features`, `out_features`, and `weight`.
- PiSSA is fp-only in v1 because it mutates `weight` into `W_res`; quantized PiSSA would need an explicit dequantize/requantize step.
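
The core mechanism can be sketched in a few lines of plain PyTorch. This is an illustration, not lora-lite's source: the helper `attach_lora_hook`, the Kaiming/zero init, and the `alpha / r` scale are assumptions borrowed from standard LoRA. The adapter tensors become ordinary `lora_*` parameters on the existing `nn.Linear`, and a forward hook returns `y + delta`, so the module is never swapped out and detaching amounts to removing the hook.

```python
import math

import torch
from torch import nn


def attach_lora_hook(layer: nn.Linear, r: int = 8, alpha: float = 16.0):
    """Hypothetical helper: add a LoRA delta to `layer` through a forward hook."""
    # Assigning nn.Parameter attributes registers them on the existing module.
    layer.lora_A = nn.Parameter(torch.empty(r, layer.in_features))
    layer.lora_B = nn.Parameter(torch.zeros(layer.out_features, r))  # zeros -> identity at t=0
    nn.init.kaiming_uniform_(layer.lora_A, a=math.sqrt(5))
    scale = alpha / r

    def hook(module, args, y):
        x = args[0]
        delta = (x @ module.lora_A.T) @ module.lora_B.T
        return y + scale * delta  # a non-None return value replaces the output

    return layer.register_forward_hook(hook)


torch.manual_seed(0)
base = nn.Linear(16, 32)
x = torch.randn(4, 16)
y_base = base(x)

handle = attach_lora_hook(base, r=4)
y_adapted = base(x)  # still the same nn.Linear object, now carrying lora_A/lora_B
print(torch.allclose(y_adapted, y_base))  # True
handle.remove()  # the base forward is back; the lora_* params stay until deleted
```

Because `lora_B` starts at zero, the hooked layer reproduces the base layer exactly at t=0, which is the identity property the smoke test checks.
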

Currently shipped variants:

| Variant | Class | File |
|---|---|---|
| LoRA | A (additive) | [src/lora_lite/variants/lora.py](src/lora_lite/variants/lora.py) |
| PiSSA ([Meng+ 2024](https://arxiv.org/abs/2404.02948)) | A + B (special init mutates W) | [src/lora_lite/variants/pissa.py](src/lora_lite/variants/pissa.py) |
| DeLoRA ([Bini+ 2025](https://arxiv.org/abs/2503.18225)) | A (additive, normalised) | [src/lora_lite/variants/delora.py](src/lora_lite/variants/delora.py) |


See [docs/spec/20260426_lora_lite_plan.md](docs/spec/20260426_lora_lite_plan.md) for goals, status, TODOs, and the current design plan. The original broader design was stress-tested against the [adapters_as_hypotheses](https://github.com/wassname/adapters_as_hypotheses) catalog (~26/27 variants covered with 3 small API tweaks).

## Install

```bash
pip install -e .
```

## Quickstart

```python
import torch
import lora_lite as ll

model = MyTransformer()  # any nn.Module containing linear-like children
cfg = ll.LoraLiteConfig(variant="lora", r=8, alpha=16, dtype=torch.bfloat16)
handles = ll.attach(model, cfg)

# train
trainable = [p for p in model.parameters() if p.requires_grad]
opt = torch.optim.AdamW(trainable, lr=1e-4)
# ... your loop ...

ll.save(model, "adapter.pt")
ll.detach(model)

# later:
ll.load(model, "adapter.pt")
```

Inspect a tensor live:

```python
A = model.layers[5].self_attn.q_proj.lora_A  # just an nn.Parameter
```

## Targeting

By default, lora-lite targets linear-like modules (anything exposing `in_features`, `out_features`, and `weight`) whose shape matches a "reader" (`d_in == d_model`) or "writer" (`d_out == d_model`) role, excluding `lm_head` and `embed_tokens`. Because this test is structural rather than an `nn.Linear` type check, bnb `Linear4bit`/`Linear8bitLt` modules can be targeted without a backend-specific class. Knobs on `LoraLiteConfig`:

- `target_roles`: subset of `("reader", "writer", "inner")`; `()` means all roles.
- `target_names`: regex includes (if non-empty, a module name must match to be targeted).
- `exclude_names`: regex excludes (the default skips `lm_head` and `embed_tokens`).
- `layers`: tuple of layer indices, or `None` for all; matched against `.layers.<idx>.` in the module name.
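
A minimal sketch of such a structural filter, written against plain PyTorch rather than copied from lora-lite. Assumptions, all hypothetical: `d_model` is passed explicitly, a square `d_model -> d_model` module is labelled "reader", "inner" means neither side equals `d_model`, a name must match at least one include pattern, and the names `role_of`/`select_targets` are illustrative only.

```python
import re

from torch import nn


def role_of(layer: nn.Module, d_model: int) -> str:
    # Assumption: a square d_model -> d_model projection counts as a "reader".
    if layer.in_features == d_model:
        return "reader"
    if layer.out_features == d_model:
        return "writer"
    return "inner"  # neither side touches the residual width


def select_targets(model: nn.Module, d_model: int, target_roles=(), target_names=(),
                   exclude_names=(r"lm_head", r"embed_tokens"), layers=None):
    """Hypothetical filter returning [(name, module, role)]."""
    targets = []
    for name, mod in model.named_modules():
        # Structural, not isinstance(mod, nn.Linear): quantized linears pass too.
        if not all(hasattr(mod, a) for a in ("in_features", "out_features", "weight")):
            continue
        if any(re.search(p, name) for p in exclude_names):
            continue
        if target_names and not any(re.search(p, name) for p in target_names):
            continue
        if layers is not None:
            # Pad with dots so a name starting with "layers." also matches.
            m = re.search(r"\.layers\.(\d+)\.", f".{name}.")
            if m is None or int(m.group(1)) not in layers:
                continue
        role = role_of(mod, d_model)
        if target_roles and role not in target_roles:
            continue
        targets.append((name, mod, role))
    return targets


class Block(nn.Module):
    def __init__(self, d: int):
        super().__init__()
        self.up = nn.Linear(d, 4 * d)  # reader: d_in == d_model
        self.down = nn.Linear(4 * d, d)  # writer: d_out == d_model


class Toy(nn.Module):
    def __init__(self, d: int = 8, n: int = 3, vocab: int = 50):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, d)  # no in_features: never linear-like
        self.layers = nn.ModuleList(Block(d) for _ in range(n))
        self.lm_head = nn.Linear(d, vocab)  # linear-like, but excluded by name


toy = Toy()
names = [n for n, _, _ in select_targets(toy, d_model=8, layers=(0, 2))]
print(names)  # ['layers.0.up', 'layers.0.down', 'layers.2.up', 'layers.2.down']
```
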

## Variant API

A variant is a class with a `name` and three static methods (`init` is optional):

```python
@register
class MyVariant:
    name = "myvariant"

    @staticmethod
    def param_specs(d_in, d_out, cfg) -> dict[str, ParamSpec]:
        return {
            "lora_A": ParamSpec((cfg.r, d_in), init="kaiming"),
            # ... one entry per adapter tensor
        }

    @staticmethod
    def init(layer, cfg) -> None:
        # Optional. Runs after params are created. May read/mutate layer.weight.
        ...

    @staticmethod
    def forward(layer, x, y) -> Tensor:
        # Return the layer's NEW output (additive: `return y + delta`).
        ...
```
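
For a concrete feel of an `init` that mutates `layer.weight`, here is a standalone PiSSA-style sketch in plain PyTorch, not the shipped `pissa.py`. The top-`r` singular triplets of `W` seed `lora_A` and `lora_B` (splitting `sqrt(S)` between the two factors, as in the PiSSA paper), and `W` is overwritten with the residual `W_res = W - B A`, so the adapted layer starts out computing the base function. The function names and the omission of any `alpha / r` scaling are assumptions for illustration.

```python
import torch
from torch import nn


@torch.no_grad()
def pissa_style_init(layer: nn.Linear, r: int) -> None:
    """Hypothetical PiSSA-style init: split W into W_res + B @ A via the top-r SVD."""
    W = layer.weight.float()  # (d_out, d_in); SVD in fp32
    U, S, Vh = torch.linalg.svd(W, full_matrices=False)
    sqrt_S = S[:r].sqrt()
    A = sqrt_S[:, None] * Vh[:r]  # (r, d_in) = sqrt(S_r) V_r^T
    B = U[:, :r] * sqrt_S[None, :]  # (d_out, r) = U_r sqrt(S_r)
    layer.lora_A = nn.Parameter(A.to(layer.weight.dtype))
    layer.lora_B = nn.Parameter(B.to(layer.weight.dtype))
    # Overwrite the base weight with the residual; this in-place mutation is why
    # quantized weights would need a dequantize/requantize step.
    layer.weight.copy_((W - B @ A).to(layer.weight.dtype))
    layer.weight.requires_grad_(False)


def pissa_forward(layer: nn.Linear, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Same shape as the Variant API's forward(layer, x, y): y was computed with W_res.
    return y + (x @ layer.lora_A.T) @ layer.lora_B.T


torch.manual_seed(0)
layer = nn.Linear(12, 10)
x = torch.randn(5, 12)
W0 = layer.weight.detach().clone()
y_before = layer(x)

pissa_style_init(layer, r=4)
y_after = pissa_forward(layer, x, layer(x))  # W_res x + b + B A x
print(torch.allclose(y_after, y_before, atol=1e-4))  # True
```
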

Adapter params attached as `layer.lora_*` are ordinary registered `nn.Parameter`s, so they get full-path keys in `state_dict()` automatically (e.g. `model.layers.5.self_attn.q_proj.lora_A`).
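
Because adapter tensors are registered parameters, saving an adapter reduces to filtering `state_dict()` keys. A rough sketch in plain PyTorch; the checkpoint layout `{"cfg": ..., "state": ...}` and the helpers `save_adapter`/`load_adapter` are hypothetical, not the on-disk format of `ll.save`, and this `load_adapter` assumes the adapter parameters are already attached to the model.

```python
import io

import torch
from torch import nn


def save_adapter(model: nn.Module, f, cfg: dict) -> None:
    """Hypothetical: persist only the adapter tensors plus the config."""
    state = {k: v.detach().cpu() for k, v in model.state_dict().items() if "lora_" in k}
    torch.save({"cfg": cfg, "state": state}, f)


def load_adapter(model: nn.Module, f) -> dict:
    """Hypothetical: copy saved lora_* tensors into a model that already has them attached."""
    ckpt = torch.load(f)
    # strict=False because the checkpoint deliberately omits every base weight.
    _, unexpected = model.load_state_dict(ckpt["state"], strict=False)
    if unexpected:
        raise KeyError(f"adapter keys not present on model: {unexpected}")
    return ckpt["cfg"]


model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
model[0].lora_A = nn.Parameter(torch.randn(2, 8))  # registered like any other parameter
print([k for k in model.state_dict() if "lora_" in k])  # ['0.lora_A']

buf = io.BytesIO()  # stands in for "adapter.pt"
save_adapter(model, buf, cfg={"variant": "lora", "r": 2})

saved = model[0].lora_A.detach().clone()
with torch.no_grad():
    model[0].lora_A.zero_()  # clobber the live tensor...
buf.seek(0)
cfg = load_adapter(model, buf)  # ...and restore it from the checkpoint
```
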

## Data-calibrated init

LoRA, PiSSA, and DeLoRA need at most `layer.weight` for init; no calibration data is required.

For variants that DO need data (e.g. AntiPaSTO, LoRA-GA, activation-aware SVD), keep dataloaders out of `cfg` so adapter checkpoints stay serializable. Instead, pass the data to `attach`:

```python
ll.attach(model, cfg, calibration_data=calib)
```

where `calib` is an iterable of whole-model inputs, e.g. `Iterable[dict[str, Tensor]]` for HF models or `Iterable[Tensor]` of token ids. Activation-aware variants implement:

```python
@staticmethod
def group_init(model, targets, cfg, calibration_data): ...
```

`targets` is a `list[(name, layer, role)]`. The variant adds temporary hooks, runs `model(batch)` over `calibration_data`, removes the hooks, then writes the `lora_*` params. Per-layer `init(layer, cfg)` stays weight-only.

Sketch:

```python
import torch


@register
class ActSVD:
    name = "actsvd"

    @staticmethod
    def param_specs(d_in, d_out, cfg): ...

    @staticmethod
    def group_init(model, targets, cfg, calibration_data):
        # Cache each target's input activations with temporary pre-hooks.
        bufs = {name: [] for name, _, _ in targets}
        hooks = [
            layer.register_forward_pre_hook(
                # name=name binds the loop variable now, not at call time
                lambda m, args, name=name: bufs[name].append(args[0].detach().float())
            )
            for name, layer, _ in targets
        ]
        try:
            with torch.no_grad():
                for batch in calibration_data:
                    if isinstance(batch, dict):
                        model(**batch)  # HF-style keyword inputs
                    else:
                        model(batch)  # e.g. a tensor of token ids
        finally:
            for h in hooks:
                h.remove()  # always clean up, even if a forward pass fails
        # For each target: X = torch.cat([b.reshape(-1, b.shape[-1]) for b in bufs[name]]);
        # do SVD; write A/B.
```
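
The sketch stops at "do SVD; write A/B". One simple way to fill that step in, assumed for illustration rather than taken from lora-lite or a specific paper: flatten the cached activations, take the top-`r` right singular vectors of `X` (the dominant input directions), set `A` to those directions and `B = W Aᵀ`, then write `W - B A` back into `layer.weight` PiSSA-style so the layer is unchanged at t=0. The name `actsvd_write` is hypothetical.

```python
import torch
from torch import nn


@torch.no_grad()
def actsvd_write(layer: nn.Linear, acts: list[torch.Tensor], r: int) -> None:
    """Hypothetical: init lora_A/lora_B from the top-r principal input directions."""
    # Hooked inputs may be (batch, seq, d_in); flatten all leading dims into rows.
    X = torch.cat([a.reshape(-1, a.shape[-1]) for a in acts], dim=0)  # (n, d_in)
    # Right singular vectors of X are the eigenvectors of the uncentred covariance X^T X.
    _, _, Vh = torch.linalg.svd(X, full_matrices=False)
    A = Vh[:r].clone()  # (r, d_in): projects inputs onto the dominant activation subspace
    W = layer.weight.float()
    B = W @ A.T  # (d_out, r): base weight restricted to that subspace
    layer.lora_A = nn.Parameter(A.to(layer.weight.dtype))
    layer.lora_B = nn.Parameter(B.to(layer.weight.dtype))
    # Write the residual back so the adapted layer equals the base layer at t=0.
    layer.weight.copy_((W - B @ A).to(layer.weight.dtype))


torch.manual_seed(0)
layer = nn.Linear(16, 6)
acts = [torch.randn(2, 5, 16) for _ in range(3)]  # three cached (batch, seq, d_in) inputs
x = torch.randn(4, 16)
y_before = layer(x)

actsvd_write(layer, acts, r=3)
y_after = layer(x) + (x @ layer.lora_A.T) @ layer.lora_B.T
```

Inside `group_init` this would be called as `actsvd_write(layer, bufs[name], cfg.r)` for each target. Caching every activation costs memory proportional to the number of calibration tokens; accumulating `X^T X` per batch instead would bound it at `d_in × d_in`.
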

## Smoke test

```bash
python tests/smoke.py
```

For each of `lora`, `pissa`, and `delora`, this verifies:

1. Identity at t=0: `max|y_adapter - y_base|` is within float tolerance.
2. A save/load round-trip preserves outputs.
3. 20 SGD steps reduce a random regression loss by >5%.

## What's NOT in v1

| Feature | Why dropped |
|---|---|
| merge/unmerge | reload the base model if you want vanilla outputs |
| 4/8-bit-aware merge | no merge path; on bnb, adapters (including a future DoRA) run in the forward pass only |
| Embedding / Conv adapters | trivial extension; add when needed |
| `adapter_names=` mixed-batch forward | rare; add when needed |
| Multiple named adapters per layer | one variant per `attach()` |
| HF `PeftConfig` / hub upload | `torch.save({cfg, state})` is enough |
| AdaLoRA-style rank scheduling | needs a `Variant.on_step(step)` hook; punted |
| ReFT-style position interventions | need a sibling submodule (different hook site) |


## Status

v0.0.1: lora + pissa + delora + smoke test. See the spec for the next variants (DoRA, VeRA, SSVD).