mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 17:16:12 +08:00
test: prove adapter training paths
+4
-1
@@ -7,4 +7,7 @@ __pycache__/
|
||||
*.db
|
||||
*.sqlite3
|
||||
*.log
|
||||
*.bak
|
||||
*.bak
|
||||
logs/
|
||||
outputs/
|
||||
tests/_artifacts/
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
A hackable, single-file-per-variant LoRA library built on PyTorch forward hooks.
|
||||
|
||||
<!-- Human: this is too long! It should be: intro, code, procedure to get started. Why, what it is, an example of a hackable adapter, then minimal features, how to install and use, and links/citation. -->
|
||||
|
||||
- ~600 LoC total
|
||||
- One file per variant, ~50 LoC each
|
||||
- No module replacement, no merge/unmerge, no PEFT config soup
|
||||
@@ -135,13 +137,29 @@ class ActSVD:
|
||||
## Smoke test
|
||||
|
||||
```bash
python tests/smoke.py
just test
just smoke
just qwen-probe
```
|
||||
|
||||
Verifies for each of `lora`, `pissa`, `delora`:
|
||||
`just test` verifies, for each of `lora`, `pissa`, `delora`:
|
||||
|
||||
1. Identity at t=0: `max|y_adapter - y_base|` within float tolerance.
|
||||
2. Save/load round-trip preserves outputs.
|
||||
3. 20 SGD steps reduce a random regression loss by >5%.
|
||||
2. Adapter hooks are live: perturbing only `lora_*` changes outputs.
|
||||
3. Save/load round-trip preserves full-path adapter keys and tensors.
|
||||
4. Missing or unexpected `lora_*` checkpoint keys fail loudly.
|
||||
5. Only `lora_*` parameters are trainable and base parameters get no gradients.
|
||||
6. A 20-step tiny regression training probe gets finite nonzero adapter gradients and >5% loss drop.
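
The "fail loudly" check on checkpoint keys can be sketched without PyTorch. This is a minimal illustration, not the library's implementation: the helper name `check_adapter_keys` is hypothetical, and only the two error phrases ("missing lora keys", "unexpected lora keys") are taken from the tests in this commit.

```python
def check_adapter_keys(expected: set[str], saved: set[str]) -> None:
    """Raise if the checkpoint's lora_* keys differ from what the model expects.

    Hypothetical helper: compares full-path adapter key sets only.
    """
    missing = sorted(k for k in expected - saved if "lora_" in k)
    unexpected = sorted(k for k in saved - expected if "lora_" in k)
    if missing:
        raise RuntimeError(f"missing lora keys: {missing}")
    if unexpected:
        raise RuntimeError(f"unexpected lora keys: {unexpected}")


# Full-path keys, as the save/load round-trip item describes.
expected = {"layers.0.q_proj.lora_A", "layers.0.q_proj.lora_B"}

check_adapter_keys(expected, set(expected))  # identical key sets: no error

try:
    check_adapter_keys(expected, {"layers.0.q_proj.lora_A"})
except RuntimeError as err:
    print(err)
```

Raising on both directions of the set difference is what keeps a truncated or mismatched checkpoint from loading as a silent partial adapter.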
|
||||
|
||||
`just qwen-probe` is the real-model proof. It loads `Qwen/Qwen3-0.6B` fresh per variant, attaches only layer-0 `q_proj`/`v_proj`, trains one fixed LM batch, saves adapters, reloads into a fresh base model, and checks logits match. Last verified on 2026-04-26:
|
||||
|
||||
| variant | targets | trainable | identity err | perturb delta | loss0 | lossN | drop % | grad norm | adapter delta | reload err |
|---|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|
| LoRA | 2 | 20,480 | 0 | 0.3750 | 5.250 | 3.131 | 40.36 | 1.432 | 4.262 | 0 |
| PiSSA | 2 | 20,480 | 0.3125 | 0.7500 | 5.250 | 3.629 | 30.88 | 6.124 | 4.381 | 0 |
| DeLoRA | 2 | 20,482 | 0.3750 | 0.4062 | 5.246 | 5.166 | 1.537 | 0.04778 | 8.196 | 0 |
|
||||
|
||||
This is an interface/training proof, not a benchmark: exact Qwen target names, hook activity, lora-only gradients, loss decrease, adapter tensor save/load, and reload equivalence on a 0.6B HF model.
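
The `trainable` column can be sanity-checked by hand. The sketch below assumes the usual LoRA layout (an `r x in` matrix A and an `out x r` matrix B per target) and assumes Qwen3-0.6B layer-0 projection shapes of 1024 -> 2048 for `q_proj` and 1024 -> 1024 for `v_proj`; neither shape is stated in this document. `r=4` is the probe script's default. The helper name is hypothetical.

```python
def lora_param_count(r: int, in_features: int, out_features: int) -> int:
    """Parameters in one rank-r adapter: A is (r, in), B is (out, r)."""
    return r * (in_features + out_features)


# Assumed (in_features, out_features) shapes, not taken from this document.
Q_PROJ = (1024, 2048)
V_PROJ = (1024, 1024)

total = sum(lora_param_count(4, i, o) for i, o in (Q_PROJ, V_PROJ))
print(total)
```

Under those assumptions the total matches the 20,480 reported for LoRA and PiSSA. The extra 2 parameters in the DeLoRA row would be consistent with one additional scalar per target, though that is an inference from the count, not something the table states.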
|
||||
|
||||
## What's NOT in v1
|
||||
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
# 2026-04-26 code review: testing proof
|
||||
|
||||
## External review
|
||||
|
||||
Reviewer: Gemini 2.5 Flash CLI, read-only prompt.
|
||||
|
||||
Findings:
|
||||
|
||||
- Critical: `tests/smoke.py` could silently pass if base gradients leaked because it did not check non-`lora_*` grads.
- Important: `tests/smoke.py` did not explicitly assert the expected number of attached TinyModel targets.
|
||||
|
||||
Resolution:
|
||||
|
||||
- Added `assert_no_base_grads(model)` to the smoke training loop.
- Added `assert n_targets == 28` immediately after smoke attach.
- Re-ran `just test` and `just smoke`; both passed.
|
||||
|
||||
## Fresh-eyes subagent review
|
||||
|
||||
Verdict: PASS.
|
||||
|
||||
The reviewer could not name a way in which skipped targets, dead hooks, base-gradient leakage, or broken save/load could still produce the collected evidence. Caveat: Qwen coverage is intentionally narrow: layer-0 `q_proj`/`v_proj`, one prompt, and a handful of steps. This supports an interface/training proof, not a claim about downstream finetuning quality.
|
||||
@@ -55,6 +55,59 @@ Last verified log: `/home/wassname/.cache/agent-tmp/lora_lite_smoke_after_review
|
||||
| bnb `Linear8bitLt` | identity `0.000e+00`, grad nonzero |
| bnb `Linear4bit` | identity `0.000e+00`, grad nonzero |
|
||||
|
||||
## 2026-04-26 testing proof pass
|
||||
|
||||
Goal: upgrade from smoke-tested sketch to evidence that the current PEFT-lite interface trains on both toy models and a real HF Qwen model.
|
||||
|
||||
### Scope
|
||||
|
||||
In:
|
||||
|
||||
- Pytest coverage for LoRA, PiSSA, and DeLoRA correctness invariants.
- A real `Qwen/Qwen3-0.6B` probe that trains each current variant on layer-0 `q_proj` and `v_proj`.
- Repeatable `just` recipes and workspace-local logs/artifacts.
|
||||
|
||||
Out:
|
||||
|
||||
- Benchmark claims.
- Quantized Qwen proof for PiSSA. PiSSA remains fp-only because it mutates `weight`.
- Full default-target training over every Qwen layer.
|
||||
|
||||
### Requirements and evidence
|
||||
|
||||
| Requirement | Distinguishing check | Evidence |
|---|---|---|
| R1: toy tests catch skipped targets/hooks | Perturb only `lora_*`; output must change. Missing target must raise. | `just test` -> `8 passed in 2.43s` in `logs/pytest.log` |
| R2: toy tests catch base-gradient leakage | After backward, all non-`lora_*` grads are `None`; all trainable names contain `lora_`. | `just test` -> `8 passed in 2.43s` |
| R3: save/load is exact for adapters | Saved key set equals full-path `lora_*` state; reload tensors equal; missing/extra `lora_*` keys raise. | `just test` -> `8 passed in 2.43s` |
| R4: current variants train on tiny task | 28 TinyModel targets; non-`lora_*` grads stay `None`; 20-step loss drop >5%. | `just smoke` -> LoRA 6.1%, PiSSA 11.5%, DeLoRA 93.4% |
| R5: current variants train on real Qwen | Fresh Qwen per variant; exact targets are layer-0 `q_proj`/`v_proj`; perturb >0; lossN < loss0; reload err < tol. | `pueue` task 70, `logs/qwen_probe.log`, all probes pass |
| R6: cold review cannot explain evidence under silent failure | External review findings fixed, then fresh-eyes subagent says PASS. | `docs/spec/20260426_code_review.md` |
|
||||
|
||||
### Qwen proof table
|
||||
|
||||
Command:
|
||||
|
||||
```bash
pueue add --immediate --follow --label "why: verify warning-free current Qwen probe after dtype API cleanup; resolve: same pass table proves current script" --working-directory "$PWD" --priority 1 -- just qwen-probe
```
|
||||
|
||||
Result from task 70:
|
||||
|
||||
| variant | targets | trainable | id_err | perturb | loss0 | lossN | drop% | grad | dθ | reload | adapter |
|---|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---|
| lora | 2 | 20480 | 0 | 0.375 | 5.25 | 3.131 | 40.36 | 1.432 | 4.262 | 0 | `outputs/qwen_train_probe/lora_adapter.pt` |
| pissa | 2 | 20480 | 0.3125 | 0.75 | 5.25 | 3.629 | 30.88 | 6.124 | 4.381 | 0 | `outputs/qwen_train_probe/pissa_adapter.pt` |
| delora | 2 | 20482 | 0.375 | 0.4062 | 5.246 | 5.166 | 1.537 | 0.04778 | 8.196 | 0 | `outputs/qwen_train_probe/delora_adapter.pt` |
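
The `drop%` column follows the formula in the probe's result row, `100 * (loss0 - lossN) / loss0`. A minimal sketch that recomputes it from the table values; the helper name `drop_pct` is hypothetical. The inputs here are already rounded for display, so the DeLoRA figure comes out slightly below the logged 1.537.

```python
def drop_pct(loss0: float, loss_n: float) -> float:
    """Relative loss decrease in percent: 100 * (loss0 - lossN) / loss0."""
    return 100 * (loss0 - loss_n) / loss0


# (loss0, lossN) copied from the table above; values are display-rounded.
ROWS = {
    "lora": (5.25, 3.131),
    "pissa": (5.25, 3.629),
    "delora": (5.246, 5.166),
}

for variant, (loss0, loss_n) in ROWS.items():
    print(f"{variant}: {drop_pct(loss0, loss_n):.2f}%")
```

The probe only asserts `lossN < loss0`, so a small drop such as DeLoRA's still passes; the percentage is reported for inspection, not used as a threshold.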
|
||||
|
||||
Failure-mode interpretation:
|
||||
|
||||
- If targeting silently skipped, exact target-set assertion would fail before training.
- If hooks were attached but dead, perturb delta would be 0.
- If base params trained, the non-`lora_*` gradient check would fail.
- If adapter grads were absent, `grad` or `dθ` would be 0/non-finite.
- If save/load were broken, adapter tensor equality or reload logit error would fail.
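
The first failure mode, silently skipped targets, is guarded by comparing the attached set against an exact expected set. A minimal sketch of that idea using only `re`: the pattern and expected names are taken from the probe script, while `select_targets` and the use of `re.search` are assumptions for illustration (how `lora_lite` matches `target_names` internally is not shown in this document).

```python
import re

TARGET_PATTERN = r"model\.layers\.0\.self_attn\.(q_proj|v_proj)$"
EXPECTED_TARGETS = {
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.v_proj",
}


def select_targets(module_names: list[str], pattern: str) -> set[str]:
    """Return the module names matched by the pattern (hypothetical helper)."""
    rx = re.compile(pattern)
    return {name for name in module_names if rx.search(name)}


# A few module names in the style of a HF decoder; illustrative only.
names = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
    "model.layers.0.self_attn.o_proj",
    "model.layers.1.self_attn.q_proj",
    "model.layers.10.self_attn.v_proj",
]

selected = select_targets(names, TARGET_PATTERN)
assert selected == EXPECTED_TARGETS, selected
```

Asserting set equality rather than a count catches both a pattern that matches nothing and one that matches too much, for example layer 10 when only layer 0 was intended.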
|
||||
|
||||
## Review history
|
||||
|
||||
A cold subagent review first returned `PASS_WITH_BLOCKERS`:
|
||||
@@ -142,3 +195,6 @@ This repo is good enough for a first real experiment when:
|
||||
2. A 4bit or 8bit loaded model can train LoRA/DeLoRA params with nonzero gradients.
|
||||
3. Saved adapter tensors use full-path keys and reload without calibration data.
|
||||
4. Smoke tests distinguish target-skipping, hook identity drift, and missing-key load failure.
|
||||
|
||||
See interesting adapters at https://github.com/wassname/adapters_as_hypotheses
How PEFT handles 4-bit layers: https://github.com/huggingface/peft/blob/6030f9160ed2fc17220f6f41382a66f1257b6a93/src/peft/tuners/lora/layer.py
|
||||
@@ -0,0 +1,13 @@
|
||||
set shell := ["bash", "-cu"]

default:
    @just --list

test:
    uv run --extra test pytest -q

smoke:
    uv run --extra test python tests/smoke.py

qwen-probe:
    uv run --extra test --extra hf-test python scripts/qwen_train_probe.py
|
||||
+5
-1
@@ -10,7 +10,8 @@ dependencies = [
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
test = ["pytest"]
|
||||
test = ["pytest", "tabulate"]
|
||||
hf-test = ["accelerate>=1.6", "safetensors>=0.5", "transformers>=4.51"]
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools>=68"]
|
||||
@@ -18,3 +19,6 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
|
||||
[tool.uv]
|
||||
exclude-newer = "5 days"
|
||||
@@ -0,0 +1,198 @@
|
||||
from __future__ import annotations

import argparse
import gc
import math
from pathlib import Path

import torch
from tabulate import tabulate
from transformers import AutoModelForCausalLM, AutoTokenizer

import lora_lite as ll


PROMPT = "LoRA-lite probe: Paris is the capital of France. The answer is"
EXPECTED_TARGETS = {
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.v_proj",
}


def cfg_for_variant(variant: str, dtype: torch.dtype, r: int, alpha: float) -> ll.LoraLiteConfig:
    return ll.LoraLiteConfig(
        variant=variant,
        r=r,
        alpha=r if variant == "pissa" else alpha,
        dtype=dtype,
        target_roles=(),
        target_names=(r"model\.layers\.0\.self_attn\.(q_proj|v_proj)$",),
        layers=(0,),
        variant_kwargs={"lambda0": 0.1} if variant == "delora" else {},
    )


def adapter_state(model: torch.nn.Module) -> dict[str, torch.Tensor]:
    return {k: v.detach().clone() for k, v in model.state_dict().items() if "lora_" in k}


def assert_only_lora_trainable(model: torch.nn.Module) -> None:
    trainable = [name for name, p in model.named_parameters() if p.requires_grad]
    assert trainable
    assert all("lora_" in name for name in trainable), trainable[:20]


def assert_no_base_grads(model: torch.nn.Module) -> None:
    leaked = [name for name, p in model.named_parameters() if "lora_" not in name and p.grad is not None]
    assert leaked == [], leaked[:20]


def perturb_first_adapter(model: torch.nn.Module) -> None:
    for name, p in model.named_parameters():
        if "lora_lambda" in name:
            with torch.no_grad():
                p.add_(0.25)
            return
    for name, p in model.named_parameters():
        if "lora_B" in name:
            with torch.no_grad():
                p.flatten()[0].add_(0.25)
            return
    raise AssertionError("no perturbable adapter parameter found")


def load_model(model_id: str, dtype: torch.dtype, device: str):
    model = AutoModelForCausalLM.from_pretrained(model_id, dtype=dtype).to(device)
    model.config.use_cache = False
    return model


def run_variant(args, variant: str, input_ids: torch.Tensor, labels: torch.Tensor, dtype: torch.dtype):
    model = load_model(args.model, dtype, args.device)
    model.train()
    cfg = cfg_for_variant(variant, dtype, args.r, args.alpha)

    with torch.no_grad():
        logits_base = model(input_ids=input_ids).logits.detach().clone()

    ll.attach(model, cfg)
    attached_targets = set(getattr(model, "_lora_lite_attached")["targets"])
    assert attached_targets == EXPECTED_TARGETS, attached_targets
    assert_only_lora_trainable(model)

    with torch.no_grad():
        logits_init = model(input_ids=input_ids).logits.detach().clone()
    identity_err = (logits_init - logits_base).abs().max().item()

    clean_adapter = adapter_state(model)
    perturb_first_adapter(model)
    with torch.no_grad():
        perturb_delta = (model(input_ids=input_ids).logits - logits_init).abs().max().item()
    assert perturb_delta > 1e-7, perturb_delta
    for name, value in clean_adapter.items():
        model.state_dict()[name].copy_(value)

    opt = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=args.lr)
    with torch.no_grad():
        loss0 = model(input_ids=input_ids, labels=labels).loss.item()

    before_train = adapter_state(model)
    first_grad_norm = math.nan
    loss_last = math.nan
    for step in range(args.steps):
        opt.zero_grad()
        loss = model(input_ids=input_ids, labels=labels).loss
        loss.backward()
        assert_no_base_grads(model)
        grad_norm = sum(
            p.grad.detach().float().norm().item()
            for name, p in model.named_parameters()
            if "lora_" in name and p.grad is not None
        )
        assert math.isfinite(grad_norm), grad_norm
        if step == 0:
            first_grad_norm = grad_norm
        opt.step()
        loss_last = loss.item()

    after_train = adapter_state(model)
    adapter_delta = sum((after_train[k] - before_train[k]).float().norm().item() for k in before_train)
    assert first_grad_norm > 0, first_grad_norm
    assert adapter_delta > 0, adapter_delta
    assert loss_last < loss0, (loss0, loss_last)

    model.eval()
    with torch.no_grad():
        logits_trained = model(input_ids=input_ids).logits.detach().clone()

    out_path = args.out_dir / f"{variant}_adapter.pt"
    ll.save(model, str(out_path))
    saved = torch.load(out_path, weights_only=True, map_location="cpu")
    assert set(saved["state"]) == set(after_train)

    del model
    gc.collect()
    torch.cuda.empty_cache()

    loaded_model = load_model(args.model, dtype, args.device)
    loaded_model.eval()
    ll.load(loaded_model, str(out_path))
    loaded_state = adapter_state(loaded_model)
    for name, value in saved["state"].items():
        assert torch.equal(loaded_state[name].cpu(), value)
    with torch.no_grad():
        logits_loaded = loaded_model(input_ids=input_ids).logits.detach().clone()
    reload_err = (logits_loaded - logits_trained).abs().max().item()
    assert reload_err < args.reload_tol, reload_err

    del loaded_model
    gc.collect()
    torch.cuda.empty_cache()

    return {
        "variant": variant,
        "targets": len(attached_targets),
        "trainable": sum(v.numel() for v in after_train.values()),
        "id_err": identity_err,
        "perturb": perturb_delta,
        "loss0": loss0,
        "lossN": loss_last,
        "drop%": 100 * (loss0 - loss_last) / loss0,
        "grad": first_grad_norm,
        "dθ": adapter_delta,
        "reload": reload_err,
        "out": str(out_path),
    }


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="Qwen/Qwen3-0.6B")
    parser.add_argument("--variants", nargs="+", default=["lora", "pissa", "delora"])
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--torch-dtype", default="bfloat16")
    parser.add_argument("--steps", type=int, default=8)
    parser.add_argument("--lr", type=float, default=5e-3)
    parser.add_argument("--r", type=int, default=4)
    parser.add_argument("--alpha", type=float, default=8.0)
    parser.add_argument("--reload-tol", type=float, default=2e-2)
    parser.add_argument("--out-dir", type=Path, default=Path("outputs/qwen_train_probe"))
    args = parser.parse_args()

    if args.device == "cuda" and not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for the default Qwen probe. Pass --device cpu explicitly for local debugging.")

    args.out_dir.mkdir(parents=True, exist_ok=True)
    dtype = getattr(torch, args.torch_dtype)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    input_ids = tokenizer(PROMPT, return_tensors="pt").input_ids.to(args.device)
    labels = input_ids.clone()

    print("SHOULD: exact q_proj/v_proj layer-0 targets, lora-only grads, lossN<loss0, perturb>0, reload<tol. ELSE hook/target/train/save bug.")
    rows = [run_variant(args, variant, input_ids, labels, dtype) for variant in args.variants]
    print(tabulate(rows, headers="keys", tablefmt="tsv", floatfmt=".4g"))
    print("ALL QWEN PROBES PASS")


if __name__ == "__main__":
    main()
|
||||
+28
-20
@@ -15,7 +15,8 @@ BLUF format:
|
||||
SHOULD: loss decreases > 5% over 20 SGD steps for all variants. ELSE grad/wiring bug.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import tempfile, os, sys, math
|
||||
import os, sys, math
|
||||
from pathlib import Path
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
@@ -25,6 +26,14 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
|
||||
import lora_lite as ll # noqa: E402
|
||||
|
||||
|
||||
ARTIFACT_DIR = Path(__file__).parent / "_artifacts"


def assert_no_base_grads(model: nn.Module) -> None:
    leaked = [name for name, p in model.named_parameters() if "lora_" not in name and p.grad is not None]
    assert leaked == [], f"base params received grads: {leaked}"
|
||||
|
||||
|
||||
# ---- a tiny transformer-like stack: 4 blocks of (q,k,v,o, gate,up,down) Linears ----
|
||||
class TinyBlock(nn.Module):
|
||||
def __init__(self, d=64, ff=128):
|
||||
@@ -106,6 +115,7 @@ def variant_test(variant: str, dtype=torch.float32):
|
||||
n_targets = len(handles)
|
||||
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
print(f" attached {n_targets} targets, trainable params={n_trainable}")
|
||||
assert n_targets == 28, f"expected 28 TinyModel targets, got {n_targets}"
|
||||
|
||||
with torch.no_grad():
|
||||
y_adapt = model(ids)
|
||||
@@ -123,25 +133,22 @@ def variant_test(variant: str, dtype=torch.float32):
|
||||
print(f" SHOULD: err<{tol:.1e}. PASS.")
|
||||
|
||||
# save/load round-trip
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
p = os.path.join(d, "adapter.pt")
|
||||
ll.save(model, p)
|
||||
# detach + fresh model + load
|
||||
ll.detach(model)
|
||||
torch.manual_seed(0)
|
||||
model2 = TinyModel().to(dtype)
|
||||
# for PiSSA, base weights got mutated; we need them mutated again for the load
|
||||
# path to make sense. Easiest: re-attach with same cfg first... but that's what
|
||||
# load() does. The catch: load reads cfg from the file, runs attach (which
|
||||
# re-runs PiSSA init -> same SVD on same weights -> same A,B -> mutates W
|
||||
# to the same W_res). Then state_dict overwrites lora_A/B with saved values.
|
||||
ll.load(model2, p)
|
||||
with torch.no_grad():
|
||||
y_loaded = model2(ids)
|
||||
err2 = (y_loaded - y_adapt).abs().max().item()
|
||||
print(f" save/load: max|y_loaded - y_adapt| = {err2:.3e}")
|
||||
assert err2 < tol, f" FAIL save/load: {err2} > {tol}"
|
||||
print(f" SHOULD: err2<{tol:.1e}. PASS.")
|
||||
ARTIFACT_DIR.mkdir(exist_ok=True)
|
||||
p = ARTIFACT_DIR / f"{variant}_smoke_adapter.pt"
|
||||
ll.save(model, str(p))
|
||||
# detach + fresh model + load
|
||||
ll.detach(model)
|
||||
torch.manual_seed(0)
|
||||
model2 = TinyModel().to(dtype)
|
||||
# for PiSSA, base weights got mutated; load() re-runs PiSSA init on the fresh
|
||||
# same-seed base, then overwrites lora_A/B with saved values.
|
||||
ll.load(model2, str(p))
|
||||
with torch.no_grad():
|
||||
y_loaded = model2(ids)
|
||||
err2 = (y_loaded - y_adapt).abs().max().item()
|
||||
print(f" save/load: max|y_loaded - y_adapt| = {err2:.3e}")
|
||||
assert err2 < tol, f" FAIL save/load: {err2} > {tol}"
|
||||
print(f" SHOULD: err2<{tol:.1e}. PASS.")
|
||||
ll.detach(model2)
|
||||
|
||||
# gradient flow: 20 SGD steps on random target.
|
||||
@@ -167,6 +174,7 @@ def variant_test(variant: str, dtype=torch.float32):
|
||||
opt.zero_grad()
|
||||
loss = (model(ids) - target).pow(2).mean()
|
||||
loss.backward()
|
||||
assert_no_base_grads(model)
|
||||
opt.step()
|
||||
losses.append(loss.item())
|
||||
drop = (losses[0] - losses[-1]) / max(losses[0], 1e-12)
|
||||
|
||||
@@ -0,0 +1,248 @@
|
||||
from __future__ import annotations

import math
from pathlib import Path

import pytest
import torch
from torch import nn

import lora_lite as ll


ARTIFACT_DIR = Path(__file__).parent / "_artifacts"


class TinyBlock(nn.Module):
    def __init__(self, d: int = 64, ff: int = 128):
        super().__init__()
        self.q_proj = nn.Linear(d, d, bias=False)
        self.k_proj = nn.Linear(d, d, bias=False)
        self.v_proj = nn.Linear(d, d, bias=False)
        self.o_proj = nn.Linear(d, d, bias=False)
        self.gate_proj = nn.Linear(d, ff, bias=False)
        self.up_proj = nn.Linear(d, ff, bias=False)
        self.down_proj = nn.Linear(ff, d, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.o_proj(self.q_proj(x) + self.k_proj(x) + self.v_proj(x))
        m = self.down_proj(torch.nn.functional.silu(self.gate_proj(x)) * self.up_proj(x))
        return x + h + m


class TinyModel(nn.Module):
    def __init__(self, n_layers: int = 4, d: int = 64, ff: int = 128, vocab: int = 100):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, d)
        self.layers = nn.ModuleList([TinyBlock(d, ff) for _ in range(n_layers)])
        self.lm_head = nn.Linear(d, vocab, bias=False)
        self.config = type("Cfg", (), {"hidden_size": d})()

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        x = self.embed_tokens(ids)
        for block in self.layers:
            x = block(x)
        return self.lm_head(x)


class FakeLinearLike(nn.Module):
    def __init__(self, d_in: int = 8, d_out: int = 8):
        super().__init__()
        self.in_features = d_in
        self.out_features = d_out
        self.weight = nn.Parameter(torch.empty(d_out, d_in))
        nn.init.kaiming_uniform_(self.weight, a=5**0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.linear(x, self.weight)


class FakeBnbModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.config = type("Cfg", (), {"hidden_size": 8})()
        self.layers = nn.ModuleList([FakeLinearLike(8, 8)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers[0](x)


def cfg_for_variant(variant: str, *, training: bool = False) -> ll.LoraLiteConfig:
    return ll.LoraLiteConfig(
        variant=variant,
        r=4,
        alpha=4 if variant == "pissa" else 8,
        dtype=torch.float32,
        variant_kwargs={"lambda0": 0.1 if training else 0.0} if variant == "delora" else {},
    )


def adapter_state(model: nn.Module) -> dict[str, torch.Tensor]:
    return {k: v.detach().clone() for k, v in model.state_dict().items() if "lora_" in k}


def assert_only_lora_trainable(model: nn.Module) -> None:
    trainable_names = [name for name, p in model.named_parameters() if p.requires_grad]
    assert trainable_names
    assert all("lora_" in name for name in trainable_names)


def assert_no_base_grads(model: nn.Module) -> None:
    leaked = [name for name, p in model.named_parameters() if "lora_" not in name and p.grad is not None]
    assert leaked == []


def perturb_first_adapter(model: nn.Module) -> None:
    for name, p in model.named_parameters():
        if "lora_lambda" in name:
            with torch.no_grad():
                p.add_(0.25)
            return
    for name, p in model.named_parameters():
        if "lora_B" in name:
            with torch.no_grad():
                p.flatten()[0].add_(0.25)
            return
    raise AssertionError("no perturbable adapter parameter found")


@pytest.mark.parametrize("variant", ["lora", "pissa", "delora"])
def test_variant_identity_hook_save_load_and_training(variant: str):
    ARTIFACT_DIR.mkdir(exist_ok=True)
    torch.manual_seed(0)
    model = TinyModel()
    ids = torch.randint(0, 100, (2, 16))

    with torch.no_grad():
        y_base = model(ids).clone()

    cfg = cfg_for_variant(variant)
    handles = ll.attach(model, cfg)
    assert len(handles) == 28
    assert_only_lora_trainable(model)

    with torch.no_grad():
        y_init = model(ids).clone()
    identity_err = (y_init - y_base).abs().max().item()
    identity_tol = {"lora": 1e-6, "pissa": 5e-4, "delora": 1e-6}[variant]
    assert identity_err < identity_tol

    before_perturb = adapter_state(model)
    perturb_first_adapter(model)
    with torch.no_grad():
        perturb_delta = (model(ids) - y_init).abs().max().item()
    assert perturb_delta > 1e-7
    for name, value in before_perturb.items():
        model.state_dict()[name].copy_(value)

    path = ARTIFACT_DIR / f"{variant}_adapter.pt"
    ll.save(model, str(path))
    saved = torch.load(path, weights_only=True, map_location="cpu")
    assert set(saved["state"]) == set(adapter_state(model))
    assert any(k.startswith("layers.0.q_proj.lora_") for k in saved["state"])

    torch.manual_seed(0)
    model_loaded = TinyModel()
    ll.load(model_loaded, str(path))
    loaded_state = adapter_state(model_loaded)
    for name, value in saved["state"].items():
        assert torch.equal(loaded_state[name].cpu(), value)
    with torch.no_grad():
        y_loaded = model_loaded(ids)
    assert (y_loaded - y_init).abs().max().item() < identity_tol

    torch.manual_seed(0)
    train_model = TinyModel()
    ll.attach(train_model, cfg_for_variant(variant, training=True))
    assert_only_lora_trainable(train_model)
    target = torch.randn(2, 16, 100) * 0.1
    trainable = [p for p in train_model.parameters() if p.requires_grad]
    opt = torch.optim.Adam(trainable, lr=0.1) if variant == "delora" else torch.optim.SGD(trainable, lr=1e-2)
    losses = []
    first_grad_norm = math.nan
    before_train = adapter_state(train_model)
    for step in range(20):
        opt.zero_grad()
        loss = (train_model(ids) - target).pow(2).mean()
        loss.backward()
        assert_no_base_grads(train_model)
        grad_norm = sum(
            p.grad.detach().float().norm().item()
            for name, p in train_model.named_parameters()
            if "lora_" in name and p.grad is not None
        )
        assert math.isfinite(grad_norm)
        if step == 0:
            first_grad_norm = grad_norm
        opt.step()
        losses.append(loss.item())
    after_train = adapter_state(train_model)
    adapter_delta = sum((after_train[k] - before_train[k]).float().norm().item() for k in before_train)
    drop = (losses[0] - losses[-1]) / losses[0]
    assert first_grad_norm > 0
    assert adapter_delta > 0
    assert drop > 0.05


def test_load_fails_on_missing_and_unexpected_lora_keys():
    ARTIFACT_DIR.mkdir(exist_ok=True)
    torch.manual_seed(0)
    model = TinyModel()
    ll.attach(model, cfg_for_variant("lora"))
    good_path = ARTIFACT_DIR / "lora_good.pt"
    ll.save(model, str(good_path))
    blob = torch.load(good_path, weights_only=True, map_location="cpu")

    missing_blob = {"cfg": blob["cfg"], "state": dict(blob["state"])}
    missing_blob["state"].pop(next(iter(missing_blob["state"])))
    missing_path = ARTIFACT_DIR / "lora_missing.pt"
    torch.save(missing_blob, missing_path)
    with pytest.raises(RuntimeError, match="missing lora keys"):
        ll.load(TinyModel(), str(missing_path))

    unexpected_blob = {"cfg": blob["cfg"], "state": dict(blob["state"])}
    unexpected_blob["state"]["layers.0.q_proj.lora_extra"] = torch.zeros(1)
    unexpected_path = ARTIFACT_DIR / "lora_unexpected.pt"
    torch.save(unexpected_blob, unexpected_path)
    with pytest.raises(RuntimeError, match="unexpected lora keys"):
        ll.load(TinyModel(), str(unexpected_path))


def test_no_target_layers_is_loud_failure():
    cfg = ll.LoraLiteConfig(variant="lora", target_names=("definitely_missing",))
    with pytest.raises(RuntimeError, match="no target layers"):
        ll.attach(TinyModel(), cfg)


@pytest.mark.parametrize("variant", ["lora", "delora"])
def test_structural_non_linear_target_trains_for_forward_only_variants(variant: str):
    torch.manual_seed(0)
    model = FakeBnbModel()
    x = torch.randn(2, 3, 8)
    y_base = model(x).detach()
    cfg = ll.LoraLiteConfig(
        variant=variant,
        r=2,
        alpha=4,
        dtype=torch.float32,
        target_roles=(),
        variant_kwargs={"lambda0": 0.0} if variant == "delora" else {},
    )
    ll.attach(model, cfg)
    y_init = model(x)
    assert (y_init.detach() - y_base).abs().max().item() < 1e-6
    loss = y_init.pow(2).mean()
    loss.backward()
    assert_no_base_grads(model)
    adapter_grad_norm = sum(
        p.grad.detach().float().norm().item()
        for name, p in model.named_parameters()
        if "lora_" in name and p.grad is not None
    )
    assert adapter_grad_norm > 0


def test_pissa_rejects_structural_non_linear_target():
    cfg = ll.LoraLiteConfig(variant="pissa", r=2, alpha=2, dtype=torch.float32, target_roles=())
    with pytest.raises(TypeError, match="plain nn.Linear"):
        ll.attach(FakeBnbModel(), cfg)
|
||||