mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 16:15:50 +08:00
fix V3 review must-fixes: DoRA bias passthrough + EVA load path
V3 external review (docs/audit/variants_review_v3.md, 97KB) found 3 must-fix bugs. DoRA: bias was being scaled by m/||V|| because we operated on the full base layer output. Now subtract bias before normalization, add back after. Matches peft DoRA exactly (docs/refs/peft_lora_dora.py:157-161). New smoke dora_bias_smoke verifies identity at t=0 with bias=True. EVA load: adapter.load() called attach() which called group_init() which required calibration_data and raised. Added _skip_group_init flag to attach(); load() passes it. EVA group_init still raises loudly when called directly without data. New smoke verifies save+load WITHOUT calibration data on load path. Also tightened EVA error message. Smoke now covers 8 variants + EVA roundtrip + DoRA-bias roundtrip + bnb 4/8-bit. ALL PASS. V3 nice-to-haves (PiSSA scaling, AntiPaSTO init choice, stale GH refs) deferred -- documented as intentional in module docstrings.
This commit is contained in:
@@ -0,0 +1,107 @@
|
||||
# V3 Variant Review — per-component audit
|
||||
|
||||
You are an expert ML engineer reviewing a from-scratch PEFT library
|
||||
(`lora-lite`, ~500 LOC) that re-implements 8 LoRA variants. Two prior reviews
|
||||
already happened (V1 paper-vs-code, V2 with reference implementations
|
||||
provided). Your job is V3: a tight per-component audit focused on
|
||||
correctness-of-mechanism rather than overall design.
|
||||
|
||||
# Scope
|
||||
|
||||
8 variants live in `src/lora_lite/variants/`:
|
||||
- lora.py
|
||||
- pissa.py
|
||||
- delora.py
|
||||
- ia3.py (two registered: `ia3` and `ia3_ff`)
|
||||
- dora.py
|
||||
- hra.py
|
||||
- eva.py (NEW since V2)
|
||||
- antipasto.py (NEW since V2)
|
||||
|
||||
Plus the runtime in `src/lora_lite/{adapter.py,variant.py,target.py,config.py}`
|
||||
and the smoke test in `tests/smoke.py`.
|
||||
|
||||
Reference implementations are in `docs/refs/` and the URLs are also pasted in
|
||||
each variant's module docstring. Compare against those.
|
||||
|
||||
# What I want from you (per variant, in this order, every time)
|
||||
|
||||
For EACH variant, work through these five checkpoints, using only that
|
||||
variant's file and its referenced peft/author code:
|
||||
|
||||
1. **PARAMS** — list every spec returned by `param_specs`. For each:
|
||||
shape, dtype (cfg.dtype unless overridden), trainable, as_buffer.
|
||||
Does the shape match the reference impl? Are buffers vs Parameters
|
||||
chosen correctly (no Parameter that should be a buffer; no buffer
|
||||
that we want to learn)? Does as_buffer mean it persists in
|
||||
state_dict (check `register_buffer(..., persistent=True)` in
|
||||
adapter.py)?
|
||||
|
||||
2. **INIT** — what does `init()` (and `group_init()` if defined) do?
|
||||
Does it match the reference exactly? Pay special attention to
|
||||
ZERO INITS — they often kill gradient flow on dependent params.
|
||||
Walk the gradient: at t=0, given this init, which trainable params
|
||||
actually receive non-zero gradient on the first SGD step?
|
||||
Are dtype casts placed correctly (fp32 SVD, then to cfg.dtype)?
|
||||
|
||||
3. **DTYPE** — trace dtype through init -> param storage -> forward.
|
||||
Where could silent precision loss happen? Is bf16 or fp16 used
|
||||
anywhere it shouldn't be? Does identity-at-init survive bf16?
|
||||
|
||||
4. **FORWARD** — write the math the forward implements, in the same
|
||||
convention as the reference (peft/author paper). Compare term by
|
||||
term. Common mistakes to look for:
|
||||
- wrong scale (alpha/r vs 1/r vs alpha vs 1)
|
||||
- missing or doubled normalization
|
||||
- wrong basis (rotating U vs V; gating input vs output)
|
||||
- dropout placement (we have no dropout — flag if any variant
|
||||
references one; see config.py)
|
||||
|
||||
5. **LINK SANITY** — open the URLs in the docstring. Verify:
|
||||
- the paper arxiv link goes to the right paper
|
||||
- the github link points to a real file
|
||||
- the offline `docs/refs/` snapshot matches what the URL serves
|
||||
today (the snapshots may be stale; if so, flag the drift)
|
||||
|
||||
# Output format
|
||||
|
||||
For each variant, write at most ~60 lines. Use this template:
|
||||
|
||||
## <variant>
|
||||
|
||||
### params
|
||||
- <one bullet per ParamSpec; flag bug if any>
|
||||
|
||||
### init / group_init
|
||||
- <bullets; identify GRADIENT FLOW at t=0 explicitly>
|
||||
|
||||
### dtype
|
||||
- <bullets>
|
||||
|
||||
### forward
|
||||
Math: <one-line equation in our convention>
|
||||
Ref math: <one-line equation in reference convention>
|
||||
Match? YES / NO + one-line explanation
|
||||
|
||||
### links
|
||||
- paper: OK / WRONG / DEAD
|
||||
- peft ref: OK / DEAD
|
||||
- author ref (if any): OK / DEAD
|
||||
- offline snapshot drift: NONE / MINOR / MAJOR
|
||||
|
||||
### verdict
|
||||
CORRECT / PARTIAL / BUGGY -- one-sentence reason
|
||||
|
||||
After all variants, write a "## summary" with a markdown table of verdicts and
|
||||
a numbered list of MUST-FIX bugs (severity high) vs nice-to-haves.
|
||||
|
||||
# Hard rules
|
||||
|
||||
- Be specific. Cite line numbers (`src/lora_lite/variants/foo.py:NN`) for
|
||||
every claim.
|
||||
- Do NOT propose redesigns. Only flag correctness issues against the
|
||||
references.
|
||||
- If an issue is intentional and documented, say so and move on -- don't
|
||||
re-flag known deviations from the docstrings.
|
||||
- If you can't tell whether something is a bug, say "AMBIGUOUS" with the
|
||||
question you'd need answered.
|
||||
File diff suppressed because it is too large
Load Diff
@@ -28,7 +28,7 @@ def _pre_hook(layer, args):
|
||||
return (x_new.to(x.dtype),)
|
||||
|
||||
|
||||
def attach(model: nn.Module, cfg: LoraLiteConfig, calibration_data=None) -> list[RemovableHandle]:
|
||||
def attach(model: nn.Module, cfg: LoraLiteConfig, calibration_data=None, *, _skip_group_init: bool = False) -> list[RemovableHandle]:
|
||||
if cfg.variant not in REGISTRY:
|
||||
raise KeyError(f"unknown variant {cfg.variant!r}; registered: {list(REGISTRY)}")
|
||||
variant = REGISTRY[cfg.variant]
|
||||
@@ -62,7 +62,7 @@ def attach(model: nn.Module, cfg: LoraLiteConfig, calibration_data=None) -> list
|
||||
attached_targets.append((name, layer, role))
|
||||
|
||||
group_init = getattr(variant, "group_init", None)
|
||||
if group_init is not None:
|
||||
if group_init is not None and not _skip_group_init:
|
||||
group_init(model, attached_targets, cfg, calibration_data)
|
||||
|
||||
for _, layer, _ in attached_targets:
|
||||
@@ -132,7 +132,7 @@ def save(model: nn.Module, path: str) -> None:
|
||||
def load(model: nn.Module, path: str) -> list[RemovableHandle]:
|
||||
blob = torch.load(path, weights_only=True, map_location="cpu")
|
||||
cfg = LoraLiteConfig.from_dict(blob["cfg"])
|
||||
handles = attach(model, cfg) # creates empty params with right shapes
|
||||
handles = attach(model, cfg, _skip_group_init=True) # creates empty params; data-driven inits restored from state_dict
|
||||
missing, unexpected = model.load_state_dict(blob["state"], strict=False)
|
||||
expected_lora = {k for k in model.state_dict() if "lora_" in k}
|
||||
missing_lora = sorted(expected_lora.intersection(missing))
|
||||
|
||||
@@ -59,8 +59,12 @@ class DoRA:
|
||||
BA = einsum(layer.lora_B, layer.lora_A, "o r, r i -> o i")
|
||||
V = layer.weight + scale * BA # (d_out, d_in)
|
||||
v_norm = V.norm(dim=1).clamp_min(1e-12) # (d_out,)
|
||||
# y' = (m / ||V||_c) * (Wx + scale * BAx) = (m / ||V||_c) * (y + scale * BAx)
|
||||
# Bias passes through UNSCALED -- only Wx + scale*BAx is normalized.
|
||||
# Matches peft DoRA forward (docs/refs/peft_lora_dora.py:157-161).
|
||||
bias = getattr(layer, "bias", None)
|
||||
wx = y if bias is None else (y - bias)
|
||||
h = einsum(x, layer.lora_A, "... i, r i -> ... r")
|
||||
delta = einsum(h, layer.lora_B, "... r, o r -> ... o")
|
||||
combined = y + scale * delta
|
||||
return (layer.lora_m / v_norm) * combined
|
||||
combined = wx + scale * delta
|
||||
out = (layer.lora_m / v_norm) * combined
|
||||
return out if bias is None else out + bias
|
||||
|
||||
@@ -61,12 +61,14 @@ class EVA:
|
||||
|
||||
@staticmethod
|
||||
def group_init(model: nn.Module, targets, cfg, calibration_data) -> None:
|
||||
# adapter.load() passes _skip_group_init=True so this is only called on
|
||||
# the live attach path where calibration_data is required.
|
||||
if calibration_data is None:
|
||||
raise ValueError(
|
||||
"EVA requires calibration_data: an iterable of model inputs "
|
||||
"(dicts of kwargs to model.forward, or single tensors) used to "
|
||||
"estimate the input PCA per layer. Pass via "
|
||||
"lora_lite.attach(model, cfg, calibration_data=batches)."
|
||||
"(dicts of kwargs to model.forward, tuples of positional args, "
|
||||
"or single tensors) used to estimate the per-layer input PCA. "
|
||||
"Pass via lora_lite.attach(model, cfg, calibration_data=batches)."
|
||||
)
|
||||
# Collect input activations per target via forward hooks.
|
||||
layers = {name: layer for name, layer, _ in targets}
|
||||
|
||||
@@ -309,6 +309,24 @@ def eva_smoke():
|
||||
assert all(n > 0 for n in a_norms), "EVA lora_A buffers all zero -> group_init never ran"
|
||||
print(f" SHOULD: lora_A buffers populated. PASS (mean ||A||={sum(a_norms)/len(a_norms):.3f}).")
|
||||
|
||||
# save/load round-trip WITHOUT calibration data on load (load path uses _skip_group_init)
|
||||
ARTIFACT_DIR.mkdir(exist_ok=True)
|
||||
p = ARTIFACT_DIR / "eva_smoke_adapter.pt"
|
||||
ll.save(model, str(p))
|
||||
ll.detach(model)
|
||||
torch.manual_seed(0)
|
||||
model2 = TinyModel().to(torch.float32)
|
||||
ll.load(model2, str(p)) # must NOT require calibration_data
|
||||
with torch.no_grad():
|
||||
y_loaded = model2(ids)
|
||||
err2 = (y_loaded - y_adapt).abs().max().item()
|
||||
print(f" save/load (no calibration on load): max err = {err2:.3e}")
|
||||
assert err2 < 1e-6, f"EVA save/load mismatch {err2}"
|
||||
print(" SHOULD: load without calibration_data works (uses _skip_group_init). PASS.")
|
||||
ll.detach(model2)
|
||||
# re-attach model for training section below
|
||||
ll.attach(model, cfg, calibration_data=calib)
|
||||
|
||||
# gradient flow: only B trains
|
||||
target = torch.randn(2, 16, 100, dtype=torch.float32) * 0.1
|
||||
trainable = [p for p in model.parameters() if p.requires_grad]
|
||||
@@ -328,6 +346,36 @@ def eva_smoke():
|
||||
ll.detach(model)
|
||||
|
||||
|
||||
def dora_bias_smoke():
|
||||
"""V3 review caught: DoRA was scaling bias by m/||V||. Fixed; bias passes through."""
|
||||
print("\n=== dora bias passthrough (V3 fix) ===")
|
||||
torch.manual_seed(0)
|
||||
d = 16
|
||||
layer = nn.Linear(d, d, bias=True).to(torch.float32)
|
||||
x = torch.randn(2, d)
|
||||
y_base = layer(x).detach()
|
||||
|
||||
class Wrap(nn.Module):
|
||||
def __init__(self, lin):
|
||||
super().__init__()
|
||||
self.config = type("Cfg", (), {"hidden_size": d})()
|
||||
self.layers = nn.ModuleList([lin])
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers[0](x)
|
||||
|
||||
model = Wrap(layer)
|
||||
cfg = ll.LoraLiteConfig(variant="dora", r=2, alpha=4, dtype=torch.float32, target_roles=())
|
||||
ll.attach(model, cfg)
|
||||
with torch.no_grad():
|
||||
y_adapt = model(x)
|
||||
err = (y_adapt - y_base).abs().max().item()
|
||||
print(f" identity with bias=True: max err = {err:.3e}")
|
||||
assert err < 1e-5, f"DoRA bias-passthrough broken: err {err} (likely bias being scaled)"
|
||||
print(" SHOULD: identity err < 1e-5 even with bias. PASS.")
|
||||
ll.detach(model)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--require-bnb", action="store_true")
|
||||
@@ -336,6 +384,7 @@ def main():
|
||||
for v in ("lora", "pissa", "delora", "ia3", "dora", "hra", "antipasto"):
|
||||
variant_test(v, dtype=torch.float32)
|
||||
eva_smoke()
|
||||
dora_bias_smoke()
|
||||
structural_linear_like_test()
|
||||
bitsandbytes_cuda_smoke(args.require_bnb)
|
||||
print("\nALL PASS.")
|
||||
|
||||
Reference in New Issue
Block a user