fix V3 review must-fixes: DoRA bias passthrough + EVA load path

V3 external review (docs/audit/variants_review_v3.md, 97KB) found 3
must-fix bugs.

DoRA: bias was being scaled by m/||V|| because we operated on the full
base layer output. Now subtract bias before normalization, add back
after. Matches peft DoRA exactly (docs/refs/peft_lora_dora.py:157-161).
New smoke dora_bias_smoke verifies identity at t=0 with bias=True.

EVA load: adapter.load() called attach() which called group_init() which
required calibration_data and raised. Added _skip_group_init flag to
attach(); load() passes it. EVA group_init still raises loudly when
called directly without data. New smoke verifies save+load WITHOUT
calibration data on load path.

Also tightened EVA error message.

Smoke now covers 8 variants + EVA roundtrip + DoRA-bias roundtrip + bnb
4/8-bit. ALL PASS.

V3 nice-to-haves (PiSSA scaling, AntiPaSTO init choice, stale GH refs)
deferred -- documented as intentional in module docstrings.
This commit is contained in:
copilot
2026-04-26 19:50:48 +08:00
parent 185eb29c70
commit 55757e829d
6 changed files with 1841 additions and 9 deletions
+107
View File
@@ -0,0 +1,107 @@
# V3 Variant Review — per-component audit
You are an expert ML engineer reviewing a from-scratch PEFT library
(`lora-lite`, ~500 LOC) that re-implements 8 LoRA variants. Two prior reviews
already happened (V1 paper-vs-code, V2 with reference implementations
provided). Your job is V3: a tight per-component audit focused on
correctness-of-mechanism rather than overall design.
# Scope
8 variants live in `src/lora_lite/variants/`:
- lora.py
- pissa.py
- delora.py
- ia3.py (two registered: `ia3` and `ia3_ff`)
- dora.py
- hra.py
- eva.py (NEW since V2)
- antipasto.py (NEW since V2)
Plus the runtime in `src/lora_lite/{adapter.py,variant.py,target.py,config.py}`
and the smoke test in `tests/smoke.py`.
Reference implementations are in `docs/refs/` and the URLs are also pasted in
each variant's module docstring. Compare against those.
# What I want from you (per variant, in this order, every time)
For EACH variant, work through these five checkpoints, using only that
variant's file and its referenced peft/author code:
1. **PARAMS** — list every spec returned by `param_specs`. For each:
shape, dtype (cfg.dtype unless overridden), trainable, as_buffer.
Does the shape match the reference impl? Are buffers vs Parameters
chosen correctly (no Parameter that should be a buffer; no buffer
that we want to learn)? Does as_buffer mean it persists in
state_dict (check `register_buffer(..., persistent=True)` in
adapter.py)?
2. **INIT** — what does `init()` (and `group_init()` if defined) do?
Does it match the reference exactly? Pay special attention to
ZERO INITS — they often kill gradient flow on dependent params.
Walk the gradient: at t=0, given this init, which trainable params
actually receive non-zero gradient on the first SGD step?
Are dtype casts placed correctly (fp32 SVD, then to cfg.dtype)?
3. **DTYPE** — trace dtype through init -> param storage -> forward.
Where could silent precision loss happen? Is bf16 or fp16 used
anywhere it shouldn't be? Does identity-at-init survive bf16?
4. **FORWARD** — write the math the forward implements, in the same
convention as the reference (peft/author paper). Compare term by
term. Common mistakes to look for:
- wrong scale (alpha/r vs 1/r vs alpha vs 1)
- missing or doubled normalization
- wrong basis (rotating U vs V; gating input vs output)
- dropout placement (we have no dropout — flag if any variant
references one; see config.py)
5. **LINK SANITY** — open the URLs in the docstring. Verify:
- the paper arxiv link goes to the right paper
- the github link points to a real file
- the offline `docs/refs/` snapshot matches what the URL serves
today (the snapshots may be stale; if so, flag the drift)
# Output format
For each variant, write at most ~60 lines. Use this template:
## <variant>
### params
- <one bullet per ParamSpec; flag bug if any>
### init / group_init
- <bullets; identify GRADIENT FLOW at t=0 explicitly>
### dtype
- <bullets>
### forward
Math: <one-line equation in our convention>
Ref math: <one-line equation in reference convention>
Match? YES / NO + one-line explanation
### links
- paper: OK / WRONG / DEAD
- peft ref: OK / DEAD
- author ref (if any): OK / DEAD
- offline snapshot drift: NONE / MINOR / MAJOR
### verdict
CORRECT / PARTIAL / BUGGY -- one-sentence reason
After all variants, write a "## summary" with a markdown table of verdicts and
a numbered list of MUST-FIX bugs (severity high) vs nice-to-haves.
# Hard rules
- Be specific. Cite line numbers (`src/lora_lite/variants/foo.py:NN`) for
every claim.
- Do NOT propose redesigns. Only flag correctness issues against the
references.
- If an issue is intentional and documented, say so and move on -- don't
re-flag known deviations from the docstrings.
- If you can't tell whether something is a bug, say "AMBIGUOUS" with the
question you'd need answered.
File diff suppressed because it is too large Load Diff
+3 -3
View File
@@ -28,7 +28,7 @@ def _pre_hook(layer, args):
return (x_new.to(x.dtype),)
def attach(model: nn.Module, cfg: LoraLiteConfig, calibration_data=None) -> list[RemovableHandle]:
def attach(model: nn.Module, cfg: LoraLiteConfig, calibration_data=None, *, _skip_group_init: bool = False) -> list[RemovableHandle]:
if cfg.variant not in REGISTRY:
raise KeyError(f"unknown variant {cfg.variant!r}; registered: {list(REGISTRY)}")
variant = REGISTRY[cfg.variant]
@@ -62,7 +62,7 @@ def attach(model: nn.Module, cfg: LoraLiteConfig, calibration_data=None) -> list
attached_targets.append((name, layer, role))
group_init = getattr(variant, "group_init", None)
if group_init is not None:
if group_init is not None and not _skip_group_init:
group_init(model, attached_targets, cfg, calibration_data)
for _, layer, _ in attached_targets:
@@ -132,7 +132,7 @@ def save(model: nn.Module, path: str) -> None:
def load(model: nn.Module, path: str) -> list[RemovableHandle]:
blob = torch.load(path, weights_only=True, map_location="cpu")
cfg = LoraLiteConfig.from_dict(blob["cfg"])
handles = attach(model, cfg) # creates empty params with right shapes
handles = attach(model, cfg, _skip_group_init=True) # creates empty params; data-driven inits restored from state_dict
missing, unexpected = model.load_state_dict(blob["state"], strict=False)
expected_lora = {k for k in model.state_dict() if "lora_" in k}
missing_lora = sorted(expected_lora.intersection(missing))
+7 -3
View File
@@ -59,8 +59,12 @@ class DoRA:
BA = einsum(layer.lora_B, layer.lora_A, "o r, r i -> o i")
V = layer.weight + scale * BA # (d_out, d_in)
v_norm = V.norm(dim=1).clamp_min(1e-12) # (d_out,)
# y' = (m / ||V||_c) * (Wx + scale * BAx) = (m / ||V||_c) * (y + scale * BAx)
# Bias passes through UNSCALED -- only Wx + scale*BAx is normalized.
# Matches peft DoRA forward (docs/refs/peft_lora_dora.py:157-161).
bias = getattr(layer, "bias", None)
wx = y if bias is None else (y - bias)
h = einsum(x, layer.lora_A, "... i, r i -> ... r")
delta = einsum(h, layer.lora_B, "... r, o r -> ... o")
combined = y + scale * delta
return (layer.lora_m / v_norm) * combined
combined = wx + scale * delta
out = (layer.lora_m / v_norm) * combined
return out if bias is None else out + bias
+5 -3
View File
@@ -61,12 +61,14 @@ class EVA:
@staticmethod
def group_init(model: nn.Module, targets, cfg, calibration_data) -> None:
# adapter.load() passes _skip_group_init=True so this is only called on
# the live attach path where calibration_data is required.
if calibration_data is None:
raise ValueError(
"EVA requires calibration_data: an iterable of model inputs "
"(dicts of kwargs to model.forward, or single tensors) used to "
"estimate the input PCA per layer. Pass via "
"lora_lite.attach(model, cfg, calibration_data=batches)."
"(dicts of kwargs to model.forward, tuples of positional args, "
"or single tensors) used to estimate the per-layer input PCA. "
"Pass via lora_lite.attach(model, cfg, calibration_data=batches)."
)
# Collect input activations per target via forward hooks.
layers = {name: layer for name, layer, _ in targets}
+49
View File
@@ -309,6 +309,24 @@ def eva_smoke():
assert all(n > 0 for n in a_norms), "EVA lora_A buffers all zero -> group_init never ran"
print(f" SHOULD: lora_A buffers populated. PASS (mean ||A||={sum(a_norms)/len(a_norms):.3f}).")
# save/load round-trip WITHOUT calibration data on load (load path uses _skip_group_init)
ARTIFACT_DIR.mkdir(exist_ok=True)
p = ARTIFACT_DIR / "eva_smoke_adapter.pt"
ll.save(model, str(p))
ll.detach(model)
torch.manual_seed(0)
model2 = TinyModel().to(torch.float32)
ll.load(model2, str(p)) # must NOT require calibration_data
with torch.no_grad():
y_loaded = model2(ids)
err2 = (y_loaded - y_adapt).abs().max().item()
print(f" save/load (no calibration on load): max err = {err2:.3e}")
assert err2 < 1e-6, f"EVA save/load mismatch {err2}"
print(" SHOULD: load without calibration_data works (uses _skip_group_init). PASS.")
ll.detach(model2)
# re-attach model for training section below
ll.attach(model, cfg, calibration_data=calib)
# gradient flow: only B trains
target = torch.randn(2, 16, 100, dtype=torch.float32) * 0.1
trainable = [p for p in model.parameters() if p.requires_grad]
@@ -328,6 +346,36 @@ def eva_smoke():
ll.detach(model)
def dora_bias_smoke():
"""V3 review caught: DoRA was scaling bias by m/||V||. Fixed; bias passes through."""
print("\n=== dora bias passthrough (V3 fix) ===")
torch.manual_seed(0)
d = 16
layer = nn.Linear(d, d, bias=True).to(torch.float32)
x = torch.randn(2, d)
y_base = layer(x).detach()
class Wrap(nn.Module):
def __init__(self, lin):
super().__init__()
self.config = type("Cfg", (), {"hidden_size": d})()
self.layers = nn.ModuleList([lin])
def forward(self, x):
return self.layers[0](x)
model = Wrap(layer)
cfg = ll.LoraLiteConfig(variant="dora", r=2, alpha=4, dtype=torch.float32, target_roles=())
ll.attach(model, cfg)
with torch.no_grad():
y_adapt = model(x)
err = (y_adapt - y_base).abs().max().item()
print(f" identity with bias=True: max err = {err:.3e}")
assert err < 1e-5, f"DoRA bias-passthrough broken: err {err} (likely bias being scaled)"
print(" SHOULD: identity err < 1e-5 even with bias. PASS.")
ll.detach(model)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--require-bnb", action="store_true")
@@ -336,6 +384,7 @@ def main():
for v in ("lora", "pissa", "delora", "ia3", "dora", "hra", "antipasto"):
variant_test(v, dtype=torch.float32)
eva_smoke()
dora_bias_smoke()
structural_linear_like_test()
bitsandbytes_cuda_smoke(args.require_bnb)
print("\nALL PASS.")