mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 17:01:14 +08:00
Merge antipasto-svd-cores: rotation-free S-space adapter family
Replaces rotation/Cayley antipasto.py with three bounded, interpretable cores
(gain 1+ELU, contractive ablation, CorDA/ASVD better-basis) + dplr, plus full
GSM8K cost table and the rot-basis ablation. Resolves the three review FIXMEs
from 3af2a2a (rambling removed; CorDA split into its own variant; group_init
recovers W_orig so it no longer runs on cropped matrices).
Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
# Conflicts:
# src/lora_lite/variants/antipasto.py
This commit is contained in:
@@ -47,25 +47,48 @@ just qwen-probe # Qwen/Qwen3-0.6B train/save-load probe
|
||||
|
||||
## Variants
|
||||
|
||||
| Variant | 4bit/8bit | GSM8K % | Params | Peak GPU (GB) |
|
||||
| --------------------------------------------- | --------- | ------- | ---------- | ------------- |
|
||||
| [LoRA](https://arxiv.org/abs/2106.09685) | yes | 63.2% | 4.59M | 11.3 |
|
||||
| [PiSSA](https://arxiv.org/abs/2404.02948) | no | 63.2% | 4.59M | 11.3 |
|
||||
| [DoRA](https://arxiv.org/abs/2402.09353) | no | 62.4% | 4.67M | 11.3 |
|
||||
| [DeLoRA](https://arxiv.org/abs/2503.18225) | yes | 61.5% | 4.59M | 11.3 |
|
||||
| [AntiPaSTO](https://arxiv.org/abs/2601.07473) | no | 61.4% | 35.8K | 11.5 |
|
||||
| [IA3-FF](https://arxiv.org/pdf/2205.05638) | yes | 61.4% | 86K | 11.4 |
|
||||
| [EVA](https://arxiv.org/abs/2410.07170) | no | 60.3% | 4.59M | 11.3 |
|
||||
| [IA3](https://arxiv.org/pdf/2205.05638) | yes | 60.0% | 57K | 11.4 |
|
||||
| [HRA](https://arxiv.org/abs/2405.17484) | yes | 61.6% | 1.84M | 11.3 |
|
||||
Trained on a MetaMathQA subset, tested on GSM8K, all on `Qwen/Qwen3.5-0.8B-Base` targeting
|
||||
`down_proj` in all 24 layers (2500 steps, effective batch 8 = 20k samples). Standard adapters
|
||||
use r=32; the AntiPaSTO family uses r=256 (it tunes only S-space gain, so it needs the rank).
|
||||
|
||||
Params = trainable adapter params. Peak GPU = peak CUDA memory during train+eval (logged from this run onward; older runs predate the column).
|
||||
| Variant | test % | valid % | Params | +MACs/tok | fwd/bwd (ms) | init (s) |
|
||||
| --------------------------------------------- | -----: | ------: | ------: | --------: | -----------: | -------: |
|
||||
| [DoRA](https://arxiv.org/abs/2402.09353) | 60.2 | 68.0 | 3.56M | 3.54M | 161 / 556 | 0.16 |
|
||||
| [LoRA](https://arxiv.org/abs/2106.09685) | 59.8 | 68.0 | 3.54M | 3.54M | 173 / 573 | 0.02 |
|
||||
| [PiSSA](https://arxiv.org/abs/2404.02948) | 59.8 | 76.0 | 3.54M | 3.54M | 146 / 549 | 2.04 |
|
||||
| [HRA](https://arxiv.org/abs/2405.17484) | 59.2 | 70.0 | 2.75M | 2.75M | 225 / 948 | 0.04 |
|
||||
| [EVA](https://arxiv.org/abs/2410.07170) | 59.3 | 74.0 | 3.54M | 3.54M | 151 / 660 | 28.3 |
|
||||
| [IA3-FF](https://arxiv.org/pdf/2205.05638) | 56.3 | 62.0 | 0.086M | 0M | 140 / 510 | 0.01 |
|
||||
| [DeLoRA](https://arxiv.org/abs/2503.18225) | 56.2 | 62.0 | 3.54M | 3.54M | 169 / 593 | 0.21 |
|
||||
| [AntiPaSTO](https://arxiv.org/abs/2601.07473) | 56.0 | 62.0 | 0.0061M | 28.3M | 166 / 571 | 2.5 |
|
||||
| AntiPaSTO-rot | 57.2 | 60.0 | 0.0154M | 28.3M | 165 / 596 | 2.0 |
|
||||
| AntiPaSTO-ablate | 56.0 | 68.0 | 0.0062M | 28.3M | 166 / 580 | 2.2 |
|
||||
| AntiPaSTO-dplr | 56.0 | 64.0 | 0.1044M | 28.4M | 153 / 582 | 3.6 |
|
||||
| AntiPaSTO-ASVD (diag C) | 55.6 | 64.0 | 0.0061M | 28.3M | 150 / 533 | 34 |
|
||||
| AntiPaSTO-CorDA (full C) | 54.7 | 58.0 | 0.0061M | 28.3M | 146 / 576 | 120 |
|
||||
| [IA3](https://arxiv.org/pdf/2205.05638) | 52.3 | 62.0 | 0.0061M | 0M | 161 / 515 | 0.01 |
|
||||
|
||||
Setup: Qwen3-0.6B-Base, MetaMathQA train (5k steps, batch 4 = 20k samples unless noted), r=32, all q/v targets, GSM8K test (1319 examples). HRA used batch 2 (10k samples) due to memory. AntiPaSTO used r=256 (default for this variant).
|
||||
test/valid % = GSM8K exact-match accuracy. Params = trainable adapter params. +MACs/tok = added
|
||||
forward MACs per token (analytic, hardware-independent). fwd/bwd = median ms over one batch.
|
||||
init = one-time calibration (CorDA's `d_in x d_in` covariance eigh; ~0 for the rest). Peak CUDA
|
||||
memory is ~9.8 GB for every row. Empty rows fill in as the sweep lands.
|
||||
|
||||
Reference: PEFT reports LoRA at 49.0% on Llama-3.2-3B (different model, different sample count). Our numbers are not directly comparable but suggest the adapters work.
|
||||
We validate our adapters the same way [PEFT](https://github.com/huggingface/peft/tree/main/method_comparison) does: train on a MetaMathQA subset and check meaningful GSM8K accuracy. See [this file](scripts/metamath_gsm8k_benchmark.py) for details.
|
||||
|
||||
AntiPaSTO at 59.5% with 4.5K trainable params (1000x fewer than LoRA's 4.59M). It trains singular-value deltas + block-Cayley rotation within the SVD subspace, so it can rescale and reorient existing directions but not create new ones. Higher rank (r>32) or data-driven dimension selection (from antipasto3) may close the gap further.
|
||||
AntiPaSTO is the novel row here: instead of adding trainable directions like LoRA, it freezes W's own top-r SVD and learns only a bounded per-direction gain `S_eff = S * (1 + ELU(g))`. The singular basis stays fixed and interpretable, and the adapter is O(r) params (the 6.1K gain is ~580x smaller than LoRA's 3.54M). The variants change only the basis or core: rot learns a small block-rotation of the frozen basis, CorDA/ASVD orient it by the input second moment (full covariance vs diagonal-only, [Yang+ 2024](https://arxiv.org/abs/2406.05223) / [Yuan+ 2023](https://arxiv.org/abs/2312.05821)), ablate learns a contractive directional ablation ([Arditi+ 2024](https://arxiv.org/abs/2406.11717)), dplr adds a small low-rank core for cross-direction mixing.
|
||||
|
||||
CorDA (full C) and ASVD (diag C) are a metric-axis ablation against plain AntiPaSTO (C=I): does
|
||||
covariance orientation earn its `d_in x d_in` eigh over the cheap diagonal or no calibration at
|
||||
all? On GSM8K/down_proj the answer is no: C=I 56.0, diag C 55.6, full C 54.7 (single seed). The
|
||||
off-diagonal orientation is the slowest arm (120 s init vs 2.5 s) and lands slightly *below* no
|
||||
calibration, so plain top-r SVD is the right default for this bounded-gain adapter here.
|
||||
|
||||
AntiPaSTO-rot tunes that basis instead of the metric: a block-diagonal Cayley rotation of the
|
||||
input (V), output (U), or both. The table row is V (the default); the ablation gives V 57.2 >
|
||||
U 56.5 > both 55.6 (single seed). So rotating which inputs feed each frozen direction helps most,
|
||||
the output-side rotation is slightly worse, and doing both is worst -- the second rotation is
|
||||
redundant capacity that hurts. rot(V) is the best small-parameter arm overall (57.2 at 15K params
|
||||
vs LoRA's 59.8 at 3.54M).
|
||||
|
||||
|
||||
## Developer docs
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
# Review request: CorDA / ASVD covariance-oriented SVD adapter init
|
||||
|
||||
You are reviewing the linear-algebra correctness of two PEFT-adapter init routines in a
|
||||
research codebase. This is a frozen-basis bounded-gain adapter ("AntiPaSTO"): it takes the
|
||||
top-r SVD of a Linear weight W (d_out x d_in), freezes (U, S, P), and trains only a
|
||||
per-direction gain g via S_eff = S * (1 + ELU(coeff*g)). At g=0 the adapter must be an
|
||||
EXACT identity (output equals the original W x).
|
||||
|
||||
Two init variants re-orient the SVD basis by the input second moment of calibration data:
|
||||
- CorDA (Yang+ 2024, arXiv:2406.05223): full covariance C = E[x x^T], via eigh.
|
||||
- ASVD (Yuan+ 2023, arXiv:2312.05821): diagonal only, M = diag(E[x_i^2]).
|
||||
|
||||
The two share one function `_covariance_orient(..., diag)`; only the `diag` flag differs.
|
||||
|
||||
## Claims I want you to verify or refute, each with reasoning
|
||||
|
||||
1. **Reconstruction is lossless / identity-at-init holds.** After re-orientation, the code
|
||||
sets `W_res_new = W_orig - (U_r S_r) P_r` and stores (U_r, S_r, P_r). The forward adds
|
||||
`((x @ P^T) * S_eff) @ U^T` to `x @ W_res^T`. At g=0 (S_eff=S_r), is the total output
|
||||
exactly `x @ W_orig^T`, in exact arithmetic? Note P_r is the TRUNCATED top-r projector,
|
||||
not full rank. Is `W_res_new + U_r S_r P_r == W_orig` exactly, or only approximately?
|
||||
|
||||
2. **CorDA whitening form is correct.** The code computes (full case):
|
||||
`C^{1/2}, C^{-1/2}` via eigh; `U,S,Vh = svd(W @ C^{1/2})`; `P_r = Vh[:r] @ C^{-1/2}`;
|
||||
`U_r = U[:, :r]`, `S_r = S[:r]`. Question: is `U_r diag(S_r) P_r` the rank-r truncation
|
||||
that is Eckart-Young optimal for reconstructing W under inputs x ~ N(0, C)? i.e. does
|
||||
minimizing `||(W - W_hat) x||` over rank-r W_hat with x~N(0,C) reduce to truncated SVD
|
||||
of `W C^{1/2}` followed by right-multiply by `C^{-1/2}`? Show the algebra.
|
||||
|
||||
3. **ASVD diagonal form is the consistent diagonal special case.** With `c = E[x_i^2]`
|
||||
(a d_in vector), code does `svd(W * c.sqrt())` (broadcast scales COLUMNS of W) and
|
||||
`P_r = Vh[:r] * c.rsqrt()` (scales COLUMNS of Vh). Is this exactly variant 2 with C
|
||||
replaced by diag(c)? Is the column-broadcast `W * c.sqrt()` equal to `W @ diag(sqrt(c))`?
|
||||
|
||||
4. **The eps damping does not break identity.** `lam = lam.clamp_min(0) + eps` (full) and
|
||||
`c = (...).clamp_min(0) + eps` (diag). The eps enters BOTH the forward map C^{1/2} used
|
||||
in the SVD AND the inverse C^{-1/2} in P. Does the damped C^{1/2} and damped C^{-1/2}
|
||||
still compose to identity inside the reconstruction (so claim 1 still holds with eps>0),
|
||||
or does eps introduce a reconstruction error? Specifically: the SVD is of `W @ M^{1/2}`
|
||||
and P uses the SAME M's `M^{-1/2}`; does `(W M^{1/2}) truncated-then-times M^{-1/2}`
|
||||
telescope regardless of what M is, as long as M^{1/2} and M^{-1/2} are true inverses?
|
||||
|
||||
5. **Covariance estimator.** `m = x.T @ x` summed over tokens, divided by total token
|
||||
count `cnt` = sum of b*s. This is the UNCENTERED second moment E[x x^T], not the
|
||||
centered covariance. Is uncentered correct for this use (we want to reconstruct W x well
|
||||
on the actual activation distribution, which includes the mean)? Any concern?
|
||||
|
||||
6. **Anything wrong, risky, or non-obvious** — numerical (eigh of a d_in x d_in ~3584^2
|
||||
moment in fp32), the clamp_min(0) before adding eps, the `cnt < r` guard, dtype round
|
||||
trips (buffers bf16, math in fp32), or the oblique P (rows not orthonormal) interacting
|
||||
with the gain. Be concrete and cite the line.
|
||||
|
||||
Structure findings by severity (blocker / should-fix / nit). If a claim is correct, say so
|
||||
plainly with the one-line reason; do not invent problems. Answer from the code below; do
|
||||
not say you will read files.
|
||||
|
||||
--- FILE: src/lora_lite/variants/antipasto_corda.py ---
|
||||
@@ -0,0 +1,232 @@
|
||||
# Reference-fidelity check: `antipasto_ablate.py` vs Arditi+ 2024 (arXiv:2406.11717)
|
||||
|
||||
Scope: math/algorithm fidelity + citation only. Fail-fast research code, so no
|
||||
defensive-programming, None-check, or backward-compat flags. Secondary `cov_orient`
|
||||
block checked against CorDA (arXiv:2406.05223).
|
||||
|
||||
Files:
|
||||
- impl: `/media/wassname/SGIronWolf/projects5/2026/lite/lora-lite/src/lora_lite/variants/antipasto_ablate.py`
|
||||
- paper: `/media/wassname/SGIronWolf/projects5/2026/lite/lora-lite/docs/papers/md/arditi_2406.11717.md`
|
||||
- secondary: `/media/wassname/SGIronWolf/projects5/2026/lite/lora-lite/docs/papers/md/corda_2406.05223.md`
|
||||
|
||||
---
|
||||
|
||||
## 1. Ablation operator: residual-stream projection vs per-layer output-singular projection
|
||||
|
||||
Paper, directional ablation operator (arditi md:153-157):
|
||||
|
||||
> 𝐱′ ← 𝐱 − 𝐫̂ 𝐫̂⊺ 𝐱. [...] We perform this operation at every activation 𝐱ᵢ(l) and 𝐱̃ᵢ(l), across all layers l and all token positions i. This effectively prevents the model from ever representing this direction in its residual stream.
|
||||
|
||||
Paper, the weight-space equivalent (the form our code structurally resembles), Eq 5 (arditi md:246-248):
|
||||
|
||||
> W_out′ ← W_out − 𝐫̂ 𝐫̂⊺ W_out. [...] the matrices that write to the residual stream are: the embedding matrix, the positional embedding matrix, attention out matrices, and MLP out matrices. Orthogonalizing all of these matrices [...] with respect to the direction 𝐫̂ effectively prevents the model from ever writing 𝐫̂ to its residual stream.
|
||||
|
||||
Code core (antipasto_ablate.py:182-186):
|
||||
|
||||
```
|
||||
h = (x @ Vh.T) * S # (..., r) output S-coords
|
||||
proj = h @ Chat # (..., k)
|
||||
h = h - coeff * (proj * alpha) @ Chat.T # contractive removal in r-dim U-space
|
||||
return y + h @ U.T # map back to d_out
|
||||
```
|
||||
|
||||
Observation (math): Chat are orthonormal in the r-dim coordinate space that indexes
|
||||
the columns of U (lora_U is (d_out, r), Chat is (r, k); ablate at :183-185 happens on
|
||||
h which lives in this r-space, before `h @ U.T` lifts to d_out at :186). Since U has
|
||||
orthonormal columns, a unit direction `c` in r-space maps to the unit direction
|
||||
`U @ c` in d_out (residual) space, and the projector obeys
|
||||
`U (I - Chat Chatᵀ) Uᵀ = I_{d_out} - (U Chat)(U Chat)ᵀ` on the column space of U. So
|
||||
within the rank-r output subspace this is exactly an Arditi-style outer-product
|
||||
projector `I - d̂ d̂ᵀ` with `d̂ = U @ ĉ`.
|
||||
|
||||
Inference (the mismatch, three parts):
|
||||
1. Per-layer, not shared. Arditi ablates ONE direction 𝐫̂ read off the residual
|
||||
stream and applies the SAME 𝐫̂ at every component (Eq 5 lists embed/attn-out/
|
||||
mlp-out all orthogonalized w.r.t. the same 𝐫̂). Our code learns a SEPARATE
|
||||
direction per target layer (lora_c is per-layer, param_specs:70). These are
|
||||
different objects: Arditi's d̂ is global; ours is a bouquet of per-layer d̂.
|
||||
2. Restricted to top-r U-span. Arditi's projector acts on the full d_model residual
|
||||
vector. Ours can only remove components that lie in span(U[:, :r]); anything in
|
||||
W_res (the frozen remainder, init:86-87) is untouched. With cov_orient=False this
|
||||
is plain-SVD top-r, which need not contain the behavior direction (the docstring
|
||||
itself flags this: ":52-53 measured 1.00 vs 0.65 capture at r=16").
|
||||
3. Output-side only vs residual-stream. Arditi's residual-stream ablation (Eq 4) and
|
||||
its weight-equivalent (Eq 5) zero the direction the layer WRITES. Our `y + h@U.T`
|
||||
subtracts from this layer's additive output contribution, which IS a residual
|
||||
writer for down_proj/o_proj. The docstring scopes the variant to exactly those
|
||||
("target residual writers (down_proj, o_proj)", :17-18), so for those targets the
|
||||
output-side framing is the correct analogue of Eq 5. For a non-writer Linear it
|
||||
would not be (correctly excluded by design).
|
||||
|
||||
VERDICT: DEVIATES-OK. The single-direction outer-product structure is faithfully
|
||||
reproduced inside each layer's output subspace, and for residual-writer targets the
|
||||
output-side action matches Eq 5's "don't write 𝐫̂". The deviations (per-layer learned
|
||||
direction instead of one shared difference-in-means direction; confined to top-r) are
|
||||
deliberate design choices, not the paper's algorithm, and the docstring's framing
|
||||
("trainable form of directional ablation") signals the analogy rather than claiming
|
||||
identity. No bug, but the correspondence is "Arditi-style projector, per-layer,
|
||||
within rank-r" — not the verbatim global single-direction ablation.
|
||||
|
||||
---
|
||||
|
||||
## 2. Contraction / "ablation cannot amplify"
|
||||
|
||||
Paper does not state a contraction theorem in the algebraic sense, but the operator
|
||||
itself is a projection. Closest textual support, Eq 4 (arditi md:155) is the
|
||||
orthogonal projector `I - 𝐫̂𝐫̂ᵀ` (idempotent, eigenvalues {0,1}); and the paper
|
||||
repeatedly frames ablation as removal/erasure, never amplification:
|
||||
|
||||
> Directional ablation "zeroes out" the component along 𝐫̂ for every residual stream activation (arditi md:151)
|
||||
|
||||
> these loss metrics suggest that directional ablation is more surgical than activation addition based methods (arditi md:323)
|
||||
|
||||
(The paper's amplification is a SEPARATE intervention — activation *addition*, Eq 3,
|
||||
md:144 — confirming ablation is the non-amplifying side.)
|
||||
|
||||
Code (forward:179, 185 and clamp):
|
||||
|
||||
```
|
||||
alpha = layer.lora_alpha.to(x.dtype).clamp(0.0, 1.0) # :179
|
||||
h = h - coeff * (proj * alpha) @ Chat.T # :185
|
||||
```
|
||||
|
||||
Observation (math): with coeff=1 the core operator on r-space is
|
||||
`M = I - Chat diag(alpha) Chatᵀ`. Chat orthonormal ⇒ eigenvalues are `1 - alphaⱼ`
|
||||
along each chat_j and `1` on the orthogonal complement. With alpha clamped to [0,1]
|
||||
(:179), every eigenvalue lies in [0,1]. So M is a contraction (operator norm ≤ 1):
|
||||
it cannot amplify. Arditi's exact projector is the alpha=1 endpoint (eigenvalue 0 =
|
||||
full erasure); alpha<1 is partial ablation, still non-amplifying.
|
||||
|
||||
Caveat (the code's own warning, :20-21, :49): coeff is OUTSIDE the clamp. coeff<0
|
||||
flips the sign and ADDS the direction back (eigenvalue `1 + |coeff|·alpha` > 1), and
|
||||
coeff>1 over-subtracts (eigenvalue `1 - coeff·alpha` can go negative, |·|>1). The
|
||||
contraction guarantee holds only for coeff ∈ [0,1], which the docstring states
|
||||
explicitly ("<0 adds the direction back (the side that can grow, so bound coeff
|
||||
there)"). That is correct and self-documented, not a silent bug.
|
||||
|
||||
VERDICT: MATCHES (for the documented operating range coeff∈[0,1], alpha∈[0,1]).
|
||||
Eigenvalues in [0,1] confirmed; the alpha clamp is the load-bearing line. The coeff<0
|
||||
amplifying branch is intentional and flagged.
|
||||
|
||||
---
|
||||
|
||||
## 3. Direction source: fixed difference-in-means vs trainable
|
||||
|
||||
Paper, the direction is a FIXED difference-of-means, selected once (arditi md:124, 134):
|
||||
|
||||
> We then compute the difference-in-means vector 𝐫ᵢ(l) = 𝛍ᵢ(l) − 𝛎ᵢ(l).
|
||||
|
||||
> We notate the selected vector as 𝐫, and its corresponding unit-norm vector as 𝐫̂.
|
||||
|
||||
i.e. 𝐫̂ is computed from mean(harmful) − mean(harmless) and then frozen; the paper
|
||||
performs NO gradient descent on it ("does not require gradient-based optimization",
|
||||
md:236).
|
||||
|
||||
Code: lora_c is a trainable parameter (param_specs:70 `trainable` default; random
|
||||
normal init), optionally warm-started from a contrastive S-space direction dS
|
||||
(docstring init:88-90 "group_init() should warm-start lora_c from the S-space
|
||||
contrastive direction dS").
|
||||
|
||||
Observation: ours is structurally the same OBJECT (a unit direction that gets
|
||||
projected out) but obtained by a different procedure (SGD, optionally seeded from a
|
||||
contrastive diff) rather than a closed-form difference-of-means.
|
||||
|
||||
Inference: calling it "the trainable form of directional ablation (Arditi+ 2024)"
|
||||
(:16-17) is defensible AS A FRAMING: the operator (Eq 4/5 outer-product removal) is
|
||||
Arditi's; the novelty claimed is making the direction learnable. The optional
|
||||
warm-start from a contrastive dS is even closer to Arditi's diff-of-means seed. This
|
||||
is a fair "trainable variant of X" claim, not a misattribution. It would be wrong to
|
||||
claim Arditi's *method* (which is explicitly gradient-free) — but the code claims the
|
||||
*ablation operator*, not the extraction method.
|
||||
|
||||
VERDICT: DEVIATES-OK. "Trainable form of directional ablation" is an honest framing:
|
||||
same operator, deliberately different (learned) direction source. Recommend the
|
||||
docstring keep the word "form"/"trainable" prominent so it is not read as
|
||||
reproducing Arditi's gradient-free diff-of-means extraction.
|
||||
|
||||
---
|
||||
|
||||
## 4. SVD sign disambiguation
|
||||
|
||||
Paper: SILENT on singular-vector sign. Arditi never does an SVD of a weight to get
|
||||
its ablation direction — 𝐫̂ comes from difference-of-means (md:117-124), and the
|
||||
weight-orthogonalization Eq 5 (md:246) uses the outer product 𝐫̂𝐫̂ᵀ. So the paper
|
||||
offers no quote on SVD sign; I reason from the math.
|
||||
|
||||
Code: ablation core is the outer product `Chat Chatᵀ` (forward:183-185), and U is
|
||||
orthonormal from `torch.linalg.svd` (init:80).
|
||||
|
||||
Observation (math):
|
||||
- Chat enters only as `Chat Chatᵀ` (proj `h @ Chat` then `@ Chat.T`, :183-185).
|
||||
Flipping any column sign `chat_j → −chat_j` leaves `chat_j chat_jᵀ` unchanged ⇒
|
||||
the operator is sign-invariant in c. (lora_c is trainable anyway, so its sign is
|
||||
not even an SVD artifact.)
|
||||
- U enters only as `h @ U.T` (:186) AND the projector identity from point 1 is
|
||||
`(U c)(U c)ᵀ`; a sign flip `u_i → −u_i` with the matching `vh_i → −vh_i` leaves
|
||||
`W = (U S) Vh` invariant (init:86) and leaves `(U c)(U c)ᵀ` invariant. So neither
|
||||
the reconstructed weight nor the ablation projector depends on per-vector sign.
|
||||
- S is non-negative by SVD definition; no sign issue.
|
||||
|
||||
Inference: no sign canonicalization is needed anywhere in this file. The ablation is
|
||||
a quadratic form in both the (trainable, hence sign-free) direction and the
|
||||
orthonormal basis, and every place U/Vh appear they appear in sign-paired products
|
||||
or outer products. This is the correct situation (a basis/span use, not a
|
||||
sign-sensitive coordinate use).
|
||||
|
||||
VERDICT: MATCHES (math-derived; paper silent on SVD sign, stated). No
|
||||
canonicalization required and none missing.
|
||||
|
||||
---
|
||||
|
||||
## 5. Citation check
|
||||
|
||||
Docstring (:16-17, :23):
|
||||
|
||||
> This is the trainable form of directional ablation (Arditi+ 2024 [...]). Refs: [...] directional ablation Arditi+ 2024 arXiv:2406.11717.
|
||||
|
||||
Paper title/authors (arditi md:1-21):
|
||||
|
||||
> # Refusal in Language Models Is Mediated by a Single Direction
|
||||
> Andy Arditi [...] Oscar Obeso [...] Aaquib Syed [...] Daniel Paleka [...] Nina Rimsky [...] Wes Gurnee [...] Neel Nanda
|
||||
|
||||
Observation: arXiv:2406.11717 = "Refusal in Language Models Is Mediated by a Single
|
||||
Direction", first author surname Arditi. The term "directional ablation" is the
|
||||
paper's own (§2.4 heading, md:148). The arXiv id, surname, year, and method name all
|
||||
match.
|
||||
|
||||
cov_orient / CorDA attribution. Code config comment (:50) and group_init docstring
|
||||
(:94) say "CorDA" by name:
|
||||
|
||||
> CorDA-orient the basis from input covariance (group_init [...]) (:50)
|
||||
> re-orient each target's SVD by input covariance C=E[x xᵀ] (CorDA) (:94)
|
||||
|
||||
CorDA paper core (corda md:23, line "C=XXᵀ", and `SVD(WC)=UΣVᵀ`, reconstruct
|
||||
`Ŵ=UΣVᵀC⁻¹`):
|
||||
|
||||
> obtain the covariance matrix of the input activation [...] C=XXᵀ [...] perform singular value decomposition for the weight multiplied by the covariance matrix, i.e. SVD(WC)=UΣVᵀ [...] the inverse of these covariance matrices is multiplied with the decomposed components to hold the same inference result with the original model
|
||||
|
||||
Observation: CorDA IS attributed by name (:50, :94). One math nuance worth recording
|
||||
(not a citation error): CorDA whitens with `C` on one side and reconstructs with
|
||||
`C⁻¹` (`SVD(WC)`, then `VᵀC⁻¹`). The code instead uses SYMMETRIC whitening
|
||||
`SVD(W C^{1/2})` with `Pr = Vh C^{-1/2}` (group_init:145-150). Both preserve the
|
||||
forward map (`W C^{1/2} · C^{-1/2} = W`) and both put data-relevant output
|
||||
directions in the top-r, but the orientation is not bit-identical to CorDA's `WC`.
|
||||
This is an acknowledged variant (task list #22 "Ablate whitening: C^1/2 (mine) vs C
|
||||
(PEFT)"), so the "CorDA-orient" name is a fair attribution of the IDEA
|
||||
(covariance-oriented SVD with inverse-covariance reconstruction), with the symmetric
|
||||
square-root being this repo's choice.
|
||||
|
||||
VERDICT: MATCHES (Arditi citation correct: id, surname, method name). CorDA is
|
||||
attributed by name; the C^{1/2} symmetric-whitening detail is a labeled variant of
|
||||
CorDA's C/C⁻¹, not a mis-citation.
|
||||
|
||||
---
|
||||
|
||||
## Bottom line
|
||||
|
||||
No real bugs: the operator reproduces Arditi's outer-product ablation faithfully
|
||||
inside each layer's rank-r output subspace, is a proven contraction for the documented
|
||||
coeff∈[0,1]/alpha∈[0,1] range, needs no SVD sign canonicalization, and both Arditi and
|
||||
CorDA are correctly attributed (the per-layer-learned direction and C^{1/2} whitening
|
||||
are deliberate, self-documented design choices, not deviations from a claimed
|
||||
reproduction).
|
||||
@@ -0,0 +1,252 @@
|
||||
# Reference-fidelity check: antipasto.py vs Wanda / ASVD / PiSSA
|
||||
|
||||
Scope: math/algorithm fidelity and citation honesty only. Fail-fast research code, so
|
||||
missing None-checks / fallbacks / backward-compat are explicitly NOT flagged.
|
||||
|
||||
File under review: `src/lora_lite/variants/antipasto.py`
|
||||
Papers (read in full where relevant):
|
||||
- Wanda: `docs/papers/md/wanda_2306.11695.md`
|
||||
- ASVD: `docs/papers/md/asvd_2312.05821.md`
|
||||
- PiSSA: `docs/papers/pissa_2404.02948.txt`
|
||||
|
||||
Legend per point: block quote (paper, location) -> code (file:line) -> VERDICT.
|
||||
"obs" = directly read; "inf" = my inference from those reads.
|
||||
|
||||
---
|
||||
|
||||
## 1. Wanda metric vs our per-direction selection score
|
||||
|
||||
Paper, Wanda Eq.(1), Sec.3 "Pruning Metric" (wanda_2306.11695.md:83):
|
||||
|
||||
> 𝐒_ij = |𝐖_ij| · ‖𝐗_j‖₂
|
||||
|
||||
and (wanda_2306.11695.md:62, Fig.1 caption):
|
||||
|
||||
> we compute the weight importance as the elementwise product between the weight
|
||||
> magnitude and the norm of input activations (|𝐖| · ‖𝐗‖₂). Weight importance
|
||||
> scores are compared on a per-output basis (within each row in 𝐖), rather than
|
||||
> globally across the entire matrix.
|
||||
|
||||
Our code (antipasto.py:159, with :154-158):
|
||||
|
||||
```python
|
||||
proj = X.to(Vh_full) @ Vh_full.T # (N, r), input projected onto right singular vecs
|
||||
act_mag = proj.pow(2).mean(0).sqrt() # 'rms' (L2-style, per-direction)
|
||||
# = proj.abs().mean(0) # 'mean_abs'
|
||||
scores = S_full * act_mag # score[i] = S[i] * pool|X @ Vh[i]|
|
||||
```
|
||||
|
||||
obs: Wanda's score is per scalar WEIGHT element `W_ij` (a `C_out x C_in` grid of
|
||||
scores), magnitude `|W_ij|` times the L2 norm of the corresponding INPUT-CHANNEL
|
||||
activation `X_j`, compared within each output row to decide which entries to zero.
|
||||
|
||||
obs: Our score is per singular DIRECTION `i in [0,r)`: singular value `S[i]` (a
|
||||
property of the whole rank-1 component `U[:,i] S[i] Vh[i]`) times the pooled
|
||||
magnitude of the activation projected onto the right singular vector `Vh[i]`.
|
||||
The comparison group is the spectrum of one weight matrix; the action is direction
|
||||
SELECTION (keep top-r), not weight zeroing.
|
||||
|
||||
inf: The shared idea is genuine: "importance = magnitude x how much the input
|
||||
actually drives this coordinate," estimated from calibration activations in a
|
||||
single forward pass. With `pool='rms'`, `pool|X @ Vh[i]|` IS exactly an L2-style
|
||||
norm of the activation in the rotated (singular) basis, so it is the Wanda norm
|
||||
applied to `X @ Vh^T` instead of raw `X`. If `Vh` were the identity (axis-aligned
|
||||
channels), `S[i]·rms(X[:,i])` would reduce to a per-channel Wanda score.
|
||||
|
||||
inf: Where the analogy breaks (three real differences, none hidden by the code):
|
||||
(a) granularity: Wanda scores individual weights `W_ij`; ours scores rank-1
|
||||
SVD components. (b) basis: Wanda works in the raw input-channel basis; ours in the
|
||||
right-singular-vector basis (`X @ Vh^T`). (c) operation: Wanda PRUNES (sets weights
|
||||
to 0, keeping the matrix shape); ours SELECTS which directions land in the trainable
|
||||
low-rank core vs the frozen residual. So `|W_ij|` -> `S[i]` is an analogy, not an
|
||||
identity: `S[i]` is the magnitude of a whole component, not of one weight.
|
||||
|
||||
VERDICT: DEVIATES-OK (honest analogy, not an over-claim). The docstring labels it
|
||||
"Wanda-style pooling" / "Wanda/ASVD" (antipasto.py:50, 95) -- "-style" is the right
|
||||
hedge. It would only be an OVERCLAIM if it claimed to BE Wanda; it says
|
||||
"Wanda-style," which is accurate. Note: Wanda found L2 beats L1/L_inf
|
||||
(wanda_2306.11695.md:85); our `'rms'` pool matches that recommendation, our default
|
||||
(antipasto.py:52) is `'rms'`.
|
||||
|
||||
---
|
||||
|
||||
## 2. ASVD: whitening vs intuition-only citation
|
||||
|
||||
Paper, ASVD Sec.3.3, transform + scaled SVD (asvd_2312.05821.md:215-249):
|
||||
|
||||
> 𝐖 = 𝐖𝐒𝐒⁻¹ = (𝐖𝐒)𝐒⁻¹. ... apply SVD to the transformed matrix 𝐖𝐒 ...
|
||||
> (𝐖𝐒):,i = 𝐖:,i 𝐒_ii ... 𝐒_ii⁻¹ scales the i-th channel of the activation
|
||||
|
||||
Paper, magnitude rule for S (asvd_2312.05821.md:254, Eq.8):
|
||||
|
||||
> 𝐒_ii := (1/n Σ_j |𝐗_ij|)^α
|
||||
|
||||
obs: ASVD's mechanism is: build a diagonal (or Cholesky) scaling `S` from activation
|
||||
statistics, decompose the SCALED matrix `W S`, then fold `S^-1` back into `V`. The
|
||||
SVD basis is changed by the activation statistics. ASVD's default magnitude rule is
|
||||
absolute-MEAN per channel raised to alpha (Eq.8), with alpha=0.5 in their experiments
|
||||
(asvd_2312.05821.md:315).
|
||||
|
||||
obs: Our code does plain `torch.linalg.svd(W_orig)` (antipasto.py:153) on the raw,
|
||||
UNscaled weight. There is no `S` whitening matrix, no `W S`, no `S^-1` fold-back.
|
||||
Activations enter ONLY through the post-hoc selection score (antipasto.py:154-159),
|
||||
never into the decomposition basis.
|
||||
|
||||
inf: So antipasto does NOT implement ASVD whitening. It borrows one narrow idea: that
|
||||
an outlier-sensitive (L2/rms) pooling of activations is the right statistic, vs the
|
||||
outlier-robust mean-abs. That is the `act_pool` knob (antipasto.py:51-52, 155-158).
|
||||
ASVD itself uses mean-abs (Eq.8); the "rms = ASVD intuition" comment is a slight
|
||||
liberty -- ASVD's STATED motivation is absorbing activation OUTLIERS
|
||||
(asvd_2312.05821.md:113, 132), and rms/L2 is more outlier-sensitive than mean-abs, so
|
||||
the rms choice is "in the spirit of ASVD's outlier-awareness," even though ASVD's own
|
||||
formula uses mean-abs. The docstring is careful: "'rms' is outlier-sensitive (ASVD
|
||||
intuition)" (antipasto.py:51) and the corda docstring (antipasto.py:99) explicitly
|
||||
says re-orienting the basis "is CorDA -> antipasto_corda.py," i.e. antipasto does NOT
|
||||
whiten.
|
||||
|
||||
VERDICT: DEVIATES-OK / honest. Citation is for INTUITION (outlier-sensitive pooling),
|
||||
not IMPLEMENTATION (whitening), and the word "intuition" is right there in the
|
||||
comment. No over-claim that antipasto performs activation-aware SVD. Minor nit (not a
|
||||
bug): ASVD's own scaling statistic is mean-abs, so attributing rms specifically to
|
||||
ASVD is loose; the honest attribution is "ASVD-style outlier-awareness, but
|
||||
L2-pooled." Optional one-word fix, not required.
|
||||
|
||||
---
|
||||
|
||||
## 3. PiSSA: top-r init vs training the components
|
||||
|
||||
Paper, PiSSA abstract (pissa_2404.02948.txt:13-18):
|
||||
|
||||
> PiSSA ... initializes the adaptor matrices A and B with the principal components
|
||||
> of the original matrix W, and put the remaining components into a residual matrix
|
||||
> W_res ∈ R^{m×n} which is frozen during fine-tuning. ... PiSSA updates the
|
||||
> principal components while freezing the "residual" parts.
|
||||
|
||||
Paper, PiSSA Table 1 (pissa_2404.02948.txt:70-99):
|
||||
|
||||
> A = U[:,:r] S^{1/2}[:r,:r]; B = S^{1/2}[:r,:r] V^T[:,:r];
|
||||
> W_res = U[:,r:] S[r:,r:] V^T[:,r:]; "Fine-tunes principal parts freezing W_res."
|
||||
|
||||
obs: PiSSA makes the top-r principal components `A,B` (i.e. `U_r, S_r, V_r`)
|
||||
TRAINABLE and freezes `W_res = W - U_r S_r V_r`. The singular vectors themselves move
|
||||
during fine-tuning.
|
||||
|
||||
obs: Our init (antipasto.py:78-87) takes the same top-r SVD and the same residual:
|
||||
`W_res = W - (U_r * S_r) @ Vh_r`, written into `layer.weight` (antipasto.py:86-87),
|
||||
matching PiSSA's `W_res`. BUT `lora_U, lora_S, lora_Vh` are registered
|
||||
`trainable=False, as_buffer=True` (antipasto.py:64-66); the ONLY trainable parameter
|
||||
is `lora_g` (antipasto.py:68), a per-direction gain. The forward (antipasto.py:195)
|
||||
keeps `U, S, Vh` frozen and only learns `S_eff = S*(1+ELU(coeff*g))`.
|
||||
|
||||
inf: So antipasto shares PiSSA's INITIALIZATION (top-r SVD + frozen residual) but NOT
|
||||
its training target. PiSSA trains the full `U,S,V`; antipasto freezes the basis and
|
||||
learns only a scalar gain per direction. These are different methods; antipasto is
|
||||
much more constrained (r+r... actually r buffers + r trainable scalars).
|
||||
|
||||
obs: The docstring does NOT over-claim. It cites PiSSA precisely as "top-r SVD init"
|
||||
(antipasto.py:21, 95: "init(): top-r by S alone (PiSSA-style)") and the init() error
|
||||
message says "mutates layer.weight into W_res (like PiSSA)" (antipasto.py:75) --
|
||||
scoped to the W_res construction, not to training the components. Line 14-15 of the
|
||||
module docstring states the basis is frozen and only the gain is learned, the opposite
|
||||
of PiSSA's claim, so no reader would conflate them.
|
||||
|
||||
VERDICT: MATCHES (citation correctly scoped). PiSSA is invoked only for the top-r SVD
|
||||
init / W_res residual idea, explicitly "PiSSA-style," and the docstring repeatedly
|
||||
states the basis is FROZEN -- no false PiSSA-equivalence. Not an over-claim.
|
||||
|
||||
---
|
||||
|
||||
## 4. SVD sign disambiguation (the user's specific question)
|
||||
|
||||
obs: None of the three papers canonicalizes singular-vector signs.
|
||||
- Wanda never decomposes via SVD; it scores `|W_ij|·‖X_j‖₂` on raw weights. Sign is
|
||||
irrelevant by construction (absolute value). No svd_flip.
|
||||
- ASVD does SVD on `W S` (asvd_2312.05821.md:217) but its objective is the
|
||||
reconstruction `U_k Σ_k V_k^T` and the Frobenius output error `‖ΔY‖_F`
|
||||
(asvd_2312.05821.md:186, 260). A simultaneous sign flip of column `U[:,i]` and row
|
||||
`V[:,i]^T` leaves the product (hence the reconstruction and the error) invariant, so
|
||||
ASVD has no reason to canonicalize and the paper does not mention sign/`svd_flip`.
|
||||
- PiSSA sets `A = U[:,:r] S^{1/2}`, `B = S^{1/2} V^T[:,:r]` (pissa_2404.02948.txt:71-76).
|
||||
The product `A B = U_r S_r V_r^T` is sign-invariant under a paired column/row flip,
|
||||
and PiSSA TRAINS `A,B` afterward, so an initial sign is just a starting point. No
|
||||
sign canonicalization in the paper.
|
||||
|
||||
obs: Our code performs NO sign flip anywhere (no svd_flip, no max-abs-positive
|
||||
convention). `torch.linalg.svd` returns whatever signs LAPACK gives.
|
||||
|
||||
inf: Omitting sign canonicalization is correct here, because every place the signs
|
||||
could matter is sign-invariant:
|
||||
- `S > 0` always (singular values are nonnegative), and the gain rides on `S`:
|
||||
`S_eff = S*(1+ELU(coeff*g))` (antipasto.py:195). `1+ELU(.) > 0`, so `S_eff > 0`
|
||||
regardless of `U/Vh` sign. The reconstruction `((x @ Vh^T) * S_eff) @ U^T`
|
||||
(antipasto.py:197-198) is invariant under a paired flip of `Vh[i]` and `U[:,i]`
|
||||
because the flips cancel in the rank-1 term `(x·Vh[i]) S_eff[i] U[:,i]`.
|
||||
- The selection score uses `|X @ Vh[i]|` via `proj.pow(2)` or `proj.abs()`
|
||||
(antipasto.py:156-158), both even in the sign of `Vh[i]`. So a flipped `Vh[i]`
|
||||
gives the identical score.
|
||||
- `lora_g` init is 0 (antipasto.py:68), a sign-symmetric starting point; the learned
|
||||
gain multiplies `S` (positive), not `U/Vh`, so its meaning does not depend on basis
|
||||
sign.
|
||||
|
||||
inf: A sign convention WOULD matter if antipasto ever (a) compared `U`/`Vh` across
|
||||
layers/checkpoints, (b) initialized `g` from a signed activation projection (e.g.
|
||||
`X @ Vh` without abs), or (c) added the rank-1 terms with separate trainable signs.
|
||||
It does none of these. Re: Bro et al. 2008 (sign-determination by data alignment) --
|
||||
none of the three cited papers use it, and antipasto does not need it.
|
||||
|
||||
VERDICT: MATCHES (omission is correct). No cited paper canonicalizes signs, and our
|
||||
two sign-touching quantities (`S`-rode gain, `|X@Vh|` score) are both sign-invariant.
|
||||
Adding svd_flip would be dead code here.
|
||||
|
||||
---
|
||||
|
||||
## 5. The "1+ELU" gain: attribution
|
||||
|
||||
obs: The `S_eff = S*(1+ELU(coeff*g))` reparameterization (antipasto.py:7, 195) is not
|
||||
in Wanda, ASVD, or PiSSA -- none of them learn a per-direction gain at all (Wanda
|
||||
prunes, ASVD truncates, PiSSA trains full A,B). The module header attributes the
|
||||
overall method to "wassname 2026 https://arxiv.org/abs/2601.07473" (antipasto.py:3)
|
||||
and "paper: https://github.com/wassname/AntiPaSTO" (antipasto.py:18).
|
||||
|
||||
obs: The Refs block (antipasto.py:17-21) lists Wanda/ASVD under "selection" and PiSSA
|
||||
under "top-r SVD init" only. Neither the `1+ELU` line nor the forward() rationale
|
||||
(antipasto.py:188-194) cites any of the three papers for the gain. The gain is
|
||||
presented as the authors' own.
|
||||
|
||||
VERDICT: MATCHES (no false attribution). The 1+ELU gain is correctly presented as the
|
||||
authors' own contribution; the three citations are confined to selection and init.
|
||||
(Per instructions, the linear/exp/tanh rationale comment at antipasto.py:189-194 is
|
||||
intentional and not flagged.)
|
||||
|
||||
---
|
||||
|
||||
## 6. Citations: arXiv ids, surnames, years
|
||||
|
||||
obs (from the source files' own headers/URLs):
|
||||
- Wanda: `wanda_2306.11695.md:22-30` -- "A Simple and Effective Pruning Approach...",
|
||||
Mingjie Sun, Zhuang Liu, Anna Bair, J. Zico Kolter; arXiv html id `2306.11695v3`.
|
||||
Code -> antipasto.py:20 "Wanda (Sun+ 2023, arXiv:2306.11695)". Sun, 2023, id match.
|
||||
- ASVD: `asvd_2312.05821.md:26-68` -- "ASVD: Activation-aware Singular Value
|
||||
Decomposition...", Zhihang Yuan (first author) et al.; arXiv html id `2312.05821v5`.
|
||||
Code -> antipasto.py:20 "ASVD (Yuan+ 2023, arXiv:2312.05821)". Yuan, 2023, id match.
|
||||
- PiSSA: `pissa_2404.02948.txt:1-44` -- "PiSSA: Principal Singular Values and Singular
|
||||
Vectors Adaptation...", Fanxu Meng, Zhaohui Wang, Muhan Zhang; NeurIPS 2024;
|
||||
"arXiv:2404.02948v4". Code -> antipasto.py:21 "PiSSA (Meng+ 2024, arXiv:2404.02948)".
|
||||
Meng, 2024, id match.
|
||||
|
||||
inf: All three (surname, year, arXiv id) check out against the papers' own front
|
||||
matter. Wanda and ASVD first appeared on arXiv in 2023 (v1), PiSSA is NeurIPS 2024 --
|
||||
the "+ year" tags are the submission years, which is the conventional choice.
|
||||
|
||||
VERDICT: MATCHES (all citations correct). No CITATION-WRONG.
|
||||
|
||||
---
|
||||
|
||||
## Bottom line
|
||||
|
||||
No real math/algorithm bugs and no dishonest citations: Wanda/ASVD are honestly cited
|
||||
as "-style"/"intuition" (selection + outlier-aware pooling, not literal pruning or
|
||||
whitening), PiSSA is correctly scoped to the top-r SVD/W_res init (not training the
|
||||
components), the no-sign-flip choice is correct because every sign-sensitive quantity
|
||||
is sign-invariant, the 1+ELU gain is the authors' own and not mis-attributed, and all
|
||||
three arXiv ids/surnames/years match.
|
||||
@@ -0,0 +1,209 @@
|
||||
# Ref-check: `antipasto_corda.py` vs CorDA (Yang+ 2024, arXiv:2406.05223)
|
||||
|
||||
Reviewer: skeptical ML reviewer. Scope: math/algorithm fidelity + citation only.
|
||||
Files:
|
||||
- impl: `/media/wassname/SGIronWolf/projects5/2026/lite/lora-lite/src/lora_lite/variants/antipasto_corda.py`
|
||||
- paper: `/media/wassname/SGIronWolf/projects5/2026/lite/lora-lite/docs/papers/md/corda_2406.05223.md`
|
||||
- secondary: `/media/wassname/SGIronWolf/projects5/2026/lite/lora-lite/docs/papers/md/asvd_2312.05821.md`
|
||||
|
||||
Pre-note on the whitening question (drives points 2,3): PEFT's CorDA does `C = L L^T`
|
||||
(Cholesky), `SVD(W L)`, then unwinds with `L^{-1}`. Our code does `C^{1/2}` symmetric via
|
||||
`eigh`, `SVD(W C^{1/2})`, unwinds with `C^{-1/2}`. The paper's *prose* says `SVD(W C)` with
|
||||
`C^{-1}` (full covariance, Eq. 2-3), which is a THIRD object. The distinction matters and is
|
||||
worked through in point 2.
|
||||
|
||||
---
|
||||
|
||||
## 1. Covariance definition + which-half (KPM vs IPM)
|
||||
|
||||
> **Paper, Sec. 3.2 (Context-Oriented Decomposition):** "Denote $X \in \mathbb{R}^{d_{in}\times BL}$
|
||||
> as the input activation of a linear layer ... We have the covariance matrix
|
||||
> $C = XX^T \in \mathbb{R}^{d_{in}\times d_{in}}$."
|
||||
|
||||
> **Paper, Sec. 3.4 (Mode 2: Instruction-Previewed Adaptation):** "we use the first $r$
|
||||
> components with the largest $r$ singular values ... $B = U_{[:,:r]}\sqrt{\Sigma}_{[:r]}$,
|
||||
> $A = \sqrt{\Sigma}_{[:r]}(V^T C^{-1})_{[:r,:]}$" — covariance from "instruction and response
|
||||
> from the training data used for fine-tuning".
|
||||
|
||||
Observation: `C = XX^T` is the un-centered input second moment (a Gram matrix over tokens),
|
||||
NOT a mean-centered covariance, and NOT normalized by token count. CorDA's KPM (Mode 1, Eq. 4)
|
||||
keeps the SMALLEST r (freezes the principal/knowledge directions); IPM (Mode 2, Eq. 5) trains
|
||||
the LARGEST r.
|
||||
|
||||
Code (`antipasto_corda.py:110-113, 145, 151-153`):
|
||||
```python
|
||||
g = x.T @ x # (d_in, d_in) -> sum_tokens x x^T (XX^T)
|
||||
cov[name] = ... cov[name] + g
|
||||
...
|
||||
C = cov[name] / cnt[name] # normalized to a MEAN second moment E[x x^T]
|
||||
...
|
||||
Ut, St, Vht = torch.linalg.svd(W_orig @ Chalf, ...)
|
||||
Ur = Ut[:, :r]; Sr = St[:r] # keeps TOP-r
|
||||
```
|
||||
|
||||
Inference: Code accumulates `sum x x^T` = the paper's `XX^T` exactly (same un-centered second
|
||||
moment). The `/ cnt` rescale (line 145) is a global positive scalar on `C`; it multiplies every
|
||||
singular value of `W C^{1/2}` by `sqrt(1/cnt)` and leaves the singular VECTORS (hence the top-r
|
||||
subspace, U, P-directions) untouched. The absolute `S` differs by `sqrt(1/cnt)`, but `S` is
|
||||
only a frozen init for the trainable gain `g` (`S_eff = S*(1+ELU(coeff*g))`), so the scale is
|
||||
absorbed and behaviorally inert.
|
||||
|
||||
Keeping TOP-r = the paper's IPM (Mode 2). Code collects cov on `calibration_data` passed at
|
||||
attach-time; the docstring calls it "downstream-task samples" — consistent with IPM (instruction
|
||||
data orients the decomposition, train the largest r). The code does NOT implement KPM (bottom-r).
|
||||
|
||||
VERDICT: **MATCHES** (covariance = paper's `XX^T`; `/cnt` is an inert global scale; TOP-r is
|
||||
IPM, which the code is — KPM simply not implemented, not a discrepancy).
|
||||
|
||||
---
|
||||
|
||||
## 2. Whitening object: `SVD(W C^{1/2})` vs paper's `SVD(W C)` vs PEFT's `SVD(W L)`
|
||||
|
||||
> **Paper, Sec. 1 / Eq. 2:** "$\verb|SVD|(WC) = U\Sigma V^T = \sum_{i=1}^R \sigma_i \mathbf{u}_i \mathbf{v}_i^T$"
|
||||
> **Paper, Eq. 3:** "$\hat W = \verb|SVD|(WC)\,C^{-1} = U\Sigma(V^T C^{-1})$".
|
||||
|
||||
Code (`antipasto_corda.py:146-154`):
|
||||
```python
|
||||
lam, Q = torch.linalg.eigh(C)
|
||||
lam = lam.clamp_min(0) + eps
|
||||
Chalf = (Q * lam.sqrt()) @ Q.T # symmetric C^{1/2}
|
||||
Cinvhalf = (Q * lam.rsqrt()) @ Q.T # symmetric C^{-1/2}
|
||||
Ut, St, Vht = torch.linalg.svd(W_orig @ Chalf, full_matrices=False)
|
||||
...
|
||||
Pr = (Vht[:r] @ Cinvhalf)
|
||||
```
|
||||
|
||||
This is the crux. Three candidate objects:
|
||||
|
||||
(a) Paper literal `SVD(W C)`: C is symmetric PSD, so `W C` has singular values = those of
|
||||
`W C` directly. Since `C = C^{1/2} C^{1/2}`, `W C = (W C^{1/2}) C^{1/2}`. The singular values of
|
||||
`W C` are NOT equal to those of `W C^{1/2}` in general (they are squared-ish in the C-spectrum:
|
||||
`W C` weights directions by `lambda_i`, `W C^{1/2}` by `sqrt(lambda_i)`). So the paper's literal
|
||||
`SVD(WC)` and our `SVD(W C^{1/2})` give DIFFERENT singular values and DIFFERENT top-r subspaces.
|
||||
|
||||
(b) PEFT's `SVD(W L)`, `C = L L^T` Cholesky. `L` is a (different) square-root factor of C.
|
||||
|
||||
(c) Our `SVD(W C^{1/2})`, symmetric square root.
|
||||
|
||||
Key linear-algebra fact for (b) vs (c): both `L` and `C^{1/2}` satisfy `M M^T = C`. Any two such
|
||||
factors relate by `C^{1/2} = L Q` for some orthogonal `Q` (polar/QR freedom). Then
|
||||
`W C^{1/2} = (W L) Q`. Right-multiplying by orthogonal `Q` leaves SINGULAR VALUES identical and
|
||||
LEFT singular vectors `U` identical; only the right singular vectors rotate (`V_sym = Q^T V_chol`).
|
||||
Therefore:
|
||||
- captured top-r left-subspace `U_r`: **identical** between (b) and (c).
|
||||
- singular values `S_r`: **identical** between (b) and (c).
|
||||
- the input-side row space of `Pr = Vht[:r] C^{-1/2}` vs PEFT's `(V^T)[:r] L^{-1}`: both equal
|
||||
`(W L)`/`(W C^{1/2})` top-r right-vectors unwound by the matching inverse factor. Since
|
||||
`V_sym^T C^{-1/2} = (Q^T V_chol)^T (L Q)^{-1} = V_chol^T Q Q^{-1} L^{-1} = V_chol^T L^{-1}`,
|
||||
the projector `Pr` is **identical** between (b) and (c) too.
|
||||
|
||||
So our symmetric-sqrt form is exactly PEFT's Cholesky form for both (a) captured top-r subspace
|
||||
and (b) the reconstruction. They are interchangeable — confirmed rigorously, not hand-waved.
|
||||
|
||||
But neither equals the paper's LITERAL `SVD(WC)` (object (a) above). The PEFT reference impl and
|
||||
our code both implement the square-root-whitening variant, which is the numerically sane reading
|
||||
(`WC` mixes a `d_out x d_in` weight with a `d_in x d_in` covariance giving the "wrong" energy
|
||||
weighting `lambda` instead of `sqrt(lambda)`; the Eckart-Young-optimal-under-`x~N(0,C)` story the
|
||||
docstring tells is the `sqrt` version). PEFT — the authors' own HF integration linked in the
|
||||
paper header — uses Cholesky `W L`, confirming the paper's `SVD(WC)` prose is loose and the
|
||||
intended/implemented object is the square-root one.
|
||||
|
||||
Reconstruction exactness (point 3 overlaps): code line 155 `W_res_new = W_orig - (Ur*Sr)@Pr` and
|
||||
`W_orig = U_r S_r P_r + W_res`. With `Pr = Vht[:r] C^{-1/2}`:
|
||||
`(Ur Sr Vht[:r]) C^{-1/2} = [SVD(W C^{1/2}) truncated] C^{-1/2}`. The FULL (untruncated) product
|
||||
`U S Vht C^{-1/2} = (W C^{1/2}) C^{-1/2} = W` exactly. Truncated to r it is the top-r piece, and
|
||||
`W_res` carries the exact remainder by subtraction (same trick the paper uses in Eq. 4:
|
||||
"$W' = W - BA$ ... to avoid the numerical error"). Reconstruction `W = W_res + U_r S_r P_r` is
|
||||
exact by construction.
|
||||
|
||||
VERDICT: **DEVIATES-OK**. Symmetric `C^{1/2}` differs from the paper's *printed* `SVD(WC)` but is
|
||||
provably identical to PEFT's `SVD(WL)` Cholesky reference for both top-r subspace and
|
||||
reconstruction. The paper's literal `WC` is the loose/wrong-energy form; matching PEFT (authors'
|
||||
own impl) is the correct choice. Worth a one-line code comment noting the symmetric-sqrt vs
|
||||
paper-prose `WC` discrepancy so a future reader is not confused.
|
||||
|
||||
---
|
||||
|
||||
## 3. Projector `Pr = Vht[:r] @ Cinvhalf` and reconstruction exactness
|
||||
|
||||
> **Paper, Eq. 3:** "$\hat W = U\Sigma(V^T C^{-1})$", with $\hat{\mathbf v}_i^T$ = i-th row of
|
||||
> $V^T C^{-1}$.
|
||||
> **Paper, Eq. 5 (IPM):** "$A = \sqrt{\Sigma}_{[:r]}(V^T C^{-1})_{[:r,:]}$".
|
||||
|
||||
Code (`antipasto_corda.py:154-155`):
|
||||
```python
|
||||
Pr = (Vht[:r] @ Cinvhalf) # (r, d_in)
|
||||
W_res_new = (W_orig - (Ur * Sr) @ Pr)
|
||||
```
|
||||
|
||||
Observation: paper unwinds with `C^{-1}` (because it whitened by full `C`); code unwinds with
|
||||
`C^{-1/2}` (because it whitened by `C^{1/2}`). These are the matched inverse-factors for their
|
||||
respective forward objects — paper: `SVD(WC) C^{-1}` recovers `W`; code:
|
||||
`SVD(W C^{1/2}) C^{-1/2}` recovers `W`. Both are self-consistent; the code's is the PEFT-equivalent
|
||||
square-root form (point 2). The projector is `(d_in)`-side, oblique (rows not orthonormal because
|
||||
`C^{-1/2}` skews them) — matches the paper's `\hat v_i` being rows of a whitened `V^T C^{-...}`.
|
||||
|
||||
Note one structural difference vs paper's adapter split: paper puts `sqrt(Sigma)` into BOTH `B`
|
||||
and `A` (`B = U sqrt(Sigma)`, `A = sqrt(Sigma) V^T C^{-1}`) so the trained product `B*A* `
|
||||
re-learns the magnitude. Our code keeps `S` whole on the projector side via the runtime gain
|
||||
`S_eff` and an orthonormal `U` (`y + ((x@P^T)*S_eff)@U^T`). This is an intentional architectural
|
||||
choice (gain-reweighting antipasto, not LoRA-style free B/A retraining), not a fidelity bug —
|
||||
the captured subspace and the exact-reconstruction identity are unchanged.
|
||||
|
||||
VERDICT: **MATCHES** (reconstruction exact; projector is the matched square-root unwind; the
|
||||
S-on-one-side split is an intentional antipasto design difference, not a CorDA discrepancy).
|
||||
|
||||
---
|
||||
|
||||
## 4. SVD sign disambiguation
|
||||
|
||||
Paper: searched for "sign", "svd_flip", "flip" — **no mention**. CorDA never canonicalizes
|
||||
singular-vector signs. Eq. 2-5 use `U`, `V^T` straight from SVD; the reconstruction `U S V^T C^{-1}`
|
||||
is sign-invariant anyway (a sign flip on column `u_i` and row `v_i^T` cancels in `u_i sigma_i v_i^T`).
|
||||
|
||||
Code: no `svd_flip` / max-abs / data-alignment anywhere (`init` line 82, `group_init` line 151).
|
||||
Docstring/param note lines 68-70: "No sign-symmetry hack needed (1+ELU is sign-preserving, basis
|
||||
frozen)".
|
||||
|
||||
Inference: Our forward is `S_eff = S*(1 + ELU(coeff*g))`, `g` trained from 0, gain rides on `S>0`.
|
||||
A sign flip on `u_i` flips the corresponding row of `P` (=`v_i`-derived) too, so the rank-1 term
|
||||
`u_i (S_eff)_i (P row_i)` is invariant to the joint sign — exactly as in the paper's
|
||||
sign-invariant `u_i sigma_i v_i^T`. `g` is a scalar magnitude per direction, not tied to any
|
||||
fixed sign convention, and the basis is frozen after `group_init`. So omitting sign
|
||||
canonicalization is correct: neither the paper needs it nor do we.
|
||||
|
||||
(Caveat, not a bug: if any downstream analysis inspected `U` or `P` rows individually and assumed
|
||||
a sign convention, it would break. None does here — the gain is sign-agnostic.)
|
||||
|
||||
VERDICT: **MATCHES** (paper does not canonicalize signs; our reconstruction + sign-invariant gain
|
||||
make omission correct).
|
||||
|
||||
---
|
||||
|
||||
## 5. Citation
|
||||
|
||||
> **Paper header:** "CorDA: Context-Oriented Decomposition Adaptation ... Yibo Yang, Xiaojie Li,
|
||||
> Zhongzhu Zhou, Shuaiwen Leon Song, Jianlong Wu, Liqiang Nie, Bernard Ghanem". arXiv:2406.05223.
|
||||
|
||||
Code docstring (`antipasto_corda.py:5, 20`): "CorDA (Yang+ 2024, arXiv:2406.05223)".
|
||||
|
||||
Observation: first author surname is **Yang** (Yibo Yang). arXiv id 2406.05223 matches the paper
|
||||
file and header. Year 2024 matches (NeurIPS 2024; arXiv June 2024).
|
||||
|
||||
Secondary cite (`antipasto_corda.py:103-104`): "Yuan+ 2023, ASVD, arXiv:2312.05821 is the diagonal
|
||||
case". ASVD paper confirms a diagonal scaling matrix `S` (Sec 3.3, "set the transform matrix as a
|
||||
diagonal matrix", Eq. 8) as the simple case, and a Cholesky `L` of `XX^T` as the better variant
|
||||
(ASVD+, lines 261-267) — so calling ASVD "the diagonal case" of covariance-whitening is accurate.
|
||||
ASVD first author is Zhihang Yuan; arXiv 2312.05821 correct.
|
||||
|
||||
VERDICT: **MATCHES** (Yang+ 2024 / 2406.05223 correct; ASVD Yuan+ 2023 / 2312.05821 correct;
|
||||
"diagonal case" characterization accurate).
|
||||
|
||||
---
|
||||
|
||||
## Bottom line
|
||||
|
||||
No real bugs. The one substantive math note: code whitens with symmetric `C^{1/2}` (eigh), which
|
||||
is provably identical to PEFT's Cholesky `W L` reference (same top-r U/S, same projector, exact
|
||||
reconstruction) but differs from the paper's loosely-printed `SVD(WC)` full-covariance form — an
|
||||
intentional, correct deviation; add a one-line comment flagging it. Everything else matches.
|
||||
@@ -0,0 +1,130 @@
|
||||
# Reference-fidelity check: `antipasto_dplr.py` vs LoRA (Hu+ 2021)
|
||||
|
||||
Implementation: `/media/wassname/SGIronWolf/projects5/2026/lite/lora-lite/src/lora_lite/variants/antipasto_dplr.py`
|
||||
Paper: `docs/papers/lora_2106.09685.txt` (arXiv:2106.09685v2)
|
||||
|
||||
Scope: only math/algorithm fidelity and citation. Fail-fast research code; defensive-programming gaps are out of scope.
|
||||
|
||||
---
|
||||
|
||||
## 1. Low-rank core == LoRA's BA (shapes + rank + A/B mapping)
|
||||
|
||||
Paper, Section 4.1, lines 199-204:
|
||||
|
||||
> For a pre-trained weight matrix W0∈ Rd×k, we constrain its update by representing the latter with a low-rank decomposition W0 + ∆W = W0 + BA, where B∈ Rd×r, A∈ Rr×k, and the rank r≪ min(d, k). [...] For h = W0x, our modified forward pass yields: h = W0x + ∆W x = W0x + BAx (3)
|
||||
|
||||
Code, the core term (file:165):
|
||||
|
||||
> `h = p * S_eff + coeff * (p @ A.T) @ B.T # (..., r)`
|
||||
|
||||
with params (file:68-69):
|
||||
|
||||
> `lora_A=ParamSpec((k, r), init="kaiming"),`
|
||||
> `lora_B=ParamSpec((r, k), init="zeros"),`
|
||||
|
||||
Observation: here the "input" to the core is `p` of dim `r` (the SVD subspace coordinate, file:160), and the core maps `r -> r`. So the core's effective weight matrix is square `(r, r)` in subspace coords, NOT `(d, k)`. The bottleneck rank is `k = lora_rank` (the paper's `r`).
|
||||
|
||||
Shape trace for `core(p) = (p @ A.T) @ B.T` with `p: (..., r)`:
|
||||
- `A.T: (r, k)`, so `p @ A.T : (..., k)` -- this is the DOWN-projection `r -> k` (paper's `A: R^{r_paper x k_paper}`, the down-proj).
|
||||
- `B.T: (k, r)`, so `(p @ A.T) @ B.T : (..., r)` -- this is the UP-projection `k -> r` (paper's `B: R^{d x r_paper}`, the up-proj).
|
||||
|
||||
As an operator on a column vector `p`, `core(p) = B @ A @ p` where `A` is `(k, r)` and `B` is `(r, k)`, i.e. the matrix `B@A` of shape `(r, r)` with rank <= `k`. This composes and matches the paper's `BA` form (paper's `B@A` is `(d, k)` rank-`r_paper`; here it is `(r, r)` rank-`k`).
|
||||
|
||||
Mapping (ours -> paper, using paper's subscript `_p`):
|
||||
- our `lora_A (k, r)` <-> paper `A_p (r_p, k_p)`: the down-projection. (our `k` = paper `r_p`; our `r` = paper `k_p`).
|
||||
- our `lora_B (r, k)` <-> paper `B_p (d_p, r_p)`: the up-projection. (our `r` = paper `d_p`; our `k` = paper `r_p`).
|
||||
- bottleneck rank: our `lora_rank = k` <-> paper `r`. Docstring/config already say "the low-rank mixing core (LoRA's r, but inside the frozen subspace)" (file:43).
|
||||
|
||||
VERDICT: MATCHES. Shapes compose; bottleneck rank is `lora_rank=k`; A is down-proj (`r->k`), B is up-proj (`k->r`), consistent with the paper's `BA` once you note the input is `p` (dim `r`), not `x` (dim `d_in`).
|
||||
|
||||
---
|
||||
|
||||
## 2. Zero-init identity (which side is zero; kaiming vs Gaussian)
|
||||
|
||||
Paper, Section 4.1, lines 205-206:
|
||||
|
||||
> We use a random Gaussian initialization for A and zero for B, so ∆W = BA is zero at the beginning of training.
|
||||
|
||||
Code (file:68-69):
|
||||
|
||||
> `lora_A=ParamSpec((k, r), init="kaiming"),`
|
||||
> `lora_B=ParamSpec((r, k), init="zeros"),`
|
||||
|
||||
Observation: paper zeros `B` (the up-projection) and random-inits `A` (the down-projection). Ours zeros `lora_B`, which I established in point 1 is the up-projection (`k->r`), and random-inits `lora_A`, the down-projection (`r->k`). So the zero is on the up-projection on BOTH sides. The core `(p @ A.T) @ B.T = ... @ 0 = 0` at init regardless of `A`. Combined with `g=0 -> 1+ELU(0)=1`, the adapter output is `h @ U.T` with `h = p*S` = exact reconstruction of the top-r part, so the layer is identity at init. Docstring confirms intent (file:17): "Identity at init (B=0, g=0)".
|
||||
|
||||
Inference on kaiming vs Gaussian: the paper's only stated requirement is that the random side be a random init so symmetry is broken and `BA=0` purely from the zero side. Both kaiming-uniform and `N(0, sigma^2)` are zero-mean symmetry-breaking inits; the network output at init is identical (0, because the zero side kills it) and gradients to `B` at step 1 depend only on `A`'s scale, not its distribution family. Kaiming-uniform is the PyTorch `nn.Linear` default and is what the reference `microsoft/LoRA` code actually uses (`kaiming_uniform_(A, a=sqrt(5))`) despite the paper text saying "Gaussian". So this is not a meaningful deviation; if anything it matches the reference implementation more closely than the paper prose.
|
||||
|
||||
VERDICT: MATCHES (zero on the up-proj on both sides; identity at init holds). Kaiming-vs-Gaussian on the random side: DEVIATES-OK, immaterial.
|
||||
|
||||
---
|
||||
|
||||
## 3. Scaling: paper's alpha/r vs our `coeff`
|
||||
|
||||
Paper, Section 4.1, lines 206-211:
|
||||
|
||||
> We then scale ∆W x by α/r, where α is a constant in r. When optimizing with Adam, tuning α is roughly the same as tuning the learning rate if we scale the initialization appropriately. [...] This scaling helps to reduce the need to retune hyperparameters when we vary r (Yang & Hu, 2021).
|
||||
|
||||
Code (file:155, 165):
|
||||
|
||||
> `coeff = float(cfg.coeff)`
|
||||
> `h = p * S_eff + coeff * (p @ A.T) @ B.T # (..., r)`
|
||||
|
||||
Observation: ours multiplies the core by `coeff` (a runtime scalar, default 1.0, file:47) and has no `alpha/r` factor. `coeff` also scales the gain via `S_eff = S * (1 + ELU(coeff * g))` (file:161), so it is a shared global knob, not a per-adapter LoRA `alpha`.
|
||||
|
||||
Inference, init-time identity: the paper's `alpha/r` is a constant multiplier on a quantity (`BA`) that is exactly zero at init. A constant times zero is zero. So `alpha/r` has NO effect on init-time identity, and neither does its absence here -- identity at init is governed entirely by `B=0`, which holds (point 2). The absence of `alpha/r` does not break identity.
|
||||
|
||||
Inference, training dynamics: `alpha/r` is a fixed scalar folded into the effective learning rate for the LoRA branch. Its purpose (per the quote and Yang & Hu 2021 / muP) is to keep update magnitudes stable as you sweep `r`, so you do not have to retune LR per rank. Dropping it means: if you sweep `lora_rank` here, the effective step size on the core is not auto-normalized, so the optimal LR may shift with rank. That is a real difference in training dynamics, but it is a tuning-convenience factor, not a correctness bug -- the reachable function class is identical (any `alpha/r` is absorbable into `A`,`B` magnitude and LR). `coeff` here additionally couples gain and core scaling, which is a deliberate design choice (single knob: 0=identity), not the paper's per-branch `alpha`.
|
||||
|
||||
VERDICT: DEVIATES-OK. No effect on init-time identity (zero is scale-invariant). For training dynamics it removes the rank-stabilizing `alpha/r` convenience; reachable functions unchanged, but rank sweeps may need LR retuning. Not a bug.
|
||||
|
||||
---
|
||||
|
||||
## 4. Subspace restriction (deliberate deviation from full-space ΔW)
|
||||
|
||||
Paper, Section 4.1, lines 199-200:
|
||||
|
||||
> For a pre-trained weight matrix W0∈ Rd×k, we constrain its update by representing the latter with a low-rank decomposition W0 + ∆W = W0 + BA, where B∈ Rd×r, A∈ Rr×k
|
||||
|
||||
Observation: LoRA's `BA` lives in the full `(d, k)` space; `B`'s column space and `A`'s row space are free, unconstrained by `W0`. Ours instead acts on `p = x @ Vh.T` (file:160), the projection of the input into the frozen top-r right-singular basis, and writes back via `@ U.T` (file:166), the frozen top-r left-singular basis. So the entire adapter (gain + core) is sandwiched as `U @ (...) @ Vh` and is confined to `W`'s top-r row/column subspace by construction.
|
||||
|
||||
Inference, soundness: with `p` in the `Vh`-row-space and output through `U`, the core can only read directions in `W`'s top-r right-singular subspace and only write directions in its top-r left-singular subspace. The `(r,r)` matrix `B@A` mixes WITHIN that subspace -- it can rotate/mix singular direction `i` into singular direction `j` for `i,j <= r`, but cannot inject any component outside the stored `U`/`Vh` span. This is exactly the stated design ("a low-rank mixing core in the frozen SVD basis", file:1; "the rank-k term is LoRA's core ... restricted to W's top-r subspace", file:14-15). It is a deliberate, internally consistent restriction, and the math confirms the core cannot escape the top-r span.
|
||||
|
||||
VERDICT: DEVIATES-OK. Stated explicitly as a deliberate deviation (docstring file:1-17); sound -- the `U .. Vh` sandwich provably confines mixing to W's top-r subspace.
|
||||
|
||||
---
|
||||
|
||||
## 5. SVD sign disambiguation
|
||||
|
||||
Relevant code: basis frozen at init from `torch.linalg.svd` (file:78), core learned from zero (`lora_B` zeros, file:69), forward reads/writes through `U`/`Vh` (file:160, 166).
|
||||
|
||||
There is no paper quote for this -- it is a property of using SVD vectors as a fixed basis, so I reason from the math.
|
||||
|
||||
SVD is sign-ambiguous: for each `i`, flipping `(u_i, v_i) -> (-u_i, -v_i)` leaves `W = sum_i s_i u_i v_i^T` unchanged (the two sign flips cancel). The diagonal gain term is sign-immune already: `p_i = x . v_i`, output contribution `(p_i * S_eff_i) u_i`; flipping both `u_i, v_i` sends `p_i -> -p_i` and `u_i -> -u_i`, product unchanged.
|
||||
|
||||
For the core, consider the learned operator in subspace coords `M = B@A` (shape `(r,r)`), so output `= (p^T M) @ U` writing `y += sum_i (pM)_i u_i^T`. Let `D = diag(+-1)` be any per-direction sign flip applied to the frozen basis (`U -> U D`, `Vh -> D Vh`, hence `p -> D p`). The full reconstruction `U D ... D Vh` is invariant for the gain. For the core, the post-training output uses the LEARNED `M`; under a sign flip of the basis, the SAME target output is produced by the gradient-reachable `M' = D M D`. Because `M` is learned from `M=0` (`B=0` at init) with no prior tying it to any particular sign convention, the optimizer is free to land on `M'` instead of `M` -- the loss landscape is identical up to the relabeling `M <-> D M D`. There is no asymmetry (no nonzero init, no regularizer referencing absolute sign, ELU acts only on the separate gain `g` not on the core) that would make one sign choice reachable and the mirror not.
|
||||
|
||||
Therefore a global (or per-direction) sign flip of the frozen singular vectors is absorbed by the learned-from-zero core; the represented function class and the optimum are sign-invariant.
|
||||
|
||||
VERDICT: MATCHES (no canonicalization needed). Sign ambiguity is absorbed by the from-zero core; confirmed from the math, no paper claim required.
|
||||
|
||||
---
|
||||
|
||||
## 6. Citation
|
||||
|
||||
Code (file:14, file:17, file:19):
|
||||
|
||||
> `subspace (Hu+ 2021, arXiv:2106.09685).` / `lora.py (low-rank core)` / docstring header.
|
||||
|
||||
Paper identity (lines 1-3, 67):
|
||||
|
||||
> LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS -- Edward Hu, Yelong Shen, ... -- arXiv:2106.09685v2 [cs.CL] 16 Oct 2021.
|
||||
|
||||
Observation: first author Hu, year 2021, arXiv ID 2106.09685 all match.
|
||||
|
||||
VERDICT: CITATION-CORRECT.
|
||||
|
||||
---
|
||||
|
||||
## Bottom line
|
||||
|
||||
No real bugs: all six points MATCH or DEVIATE-OK (the two intentional deviations -- subspace restriction and dropping alpha/r -- are stated and sound; kaiming-vs-Gaussian is immaterial). Citation correct.
|
||||
@@ -0,0 +1,38 @@
|
||||
## Code Review: Four S-space weight adapters (anti-pasto family + sspace reference + cost script)
|
||||
|
||||
### Summary
|
||||
The code adds four PiSSA-style adapters (antipasto, antipasto_rot, antipasto_ablate, antipasto_corda) that store frozen top-r SVD buffers and a small trainable gain/core. The key mathematical claims in the docstrings (identity at g=0, contraction of ablation core, exact reconstruction of the CorDA decomposition, signed cosine gate) are correct. However, `group_init()` re-selects/re-orients the SVD basis but does **not** reset the trainable parameters (g, delta_s, rot_T, c, alpha), which introduces a latent but important correctness problem — after a data-driven reinit, gains and direction vectors become misaligned with their intended singular axes.
|
||||
|
||||
---
|
||||
|
||||
### Critical (must fix)
|
||||
- **`antipasto_corda.py`, `antipasto_rot.py`, `antipasto_ablate.py`:** `group_init()` re-orients or re-selects the top-r SVD basis but leaves the existing trainable parameters (`lora_g`, `lora_delta_s`, `lora_rot_T`, `lora_c`, `lora_alpha`) untouched, still attached to the old indices.
|
||||
- For `antipasto_corda`, `lora_g` is initialised with `N(0, 4e-4)` in `init()`. After `group_init`, those small random gains now multiply a different set of singular directions, producing an uncontrolled (though tiny) perturbation.
|
||||
- For `antipasto_rot`, the `+4e-4` bias on `lora_delta_s` remains, now applied to arbitrary resorted directions, and the block rotation `lora_rot_T` is disconnected from the new block structure.
|
||||
- For `antipasto_ablate`, the ablation directions and strengths (`lora_c`, `lora_alpha`) are not reset; if any warm-starting or training has happened, they point in the wrong subspace.
|
||||
**Fix:** After re-selection or re-orientation, either re-initialise the trainable parameters (e.g., zero for g, zeros/small for delta_s and rot_T, random-normalised for c, init_alpha for alpha) or re-index them with the same `idx` mapping used for the buffers. Document that `group_init` must be called **before any training** and that the trainable parameters will be fresh after it.
|
||||
|
||||
### Important (should fix)
|
||||
- **`antipasto_ablate.py`** – Contraction claim and coeff bounds: the forward pass applies `h = h - coeff * (proj * alpha) @ Chat.T`. The core is a contraction only when both `coeff` and `alpha` are in `[0, 1]`. The config’s `coeff` can be *any* float (the docstring mentions `coeff<0` for amplification). There is no runtime clamping of `coeff`. If the intention is to keep the contraction property structurally enforced, clamp `coeff` to `[0, 1]` inside `forward()` or at least validate it.
|
||||
- **`sspace.py`** – Division by `sqrtS` without epsilon: `xS = (y_eff @ U_r) / sqrtS`. For small or zero singular values (e.g., when r is large or W is low-rank) this can produce NaNs/infs. Add a small `ε` denominator (consistent with the ε used elsewhere).
|
||||
- **`antipasto_ablate.py`** – `_orthonormal()` calls `torch.linalg.qr(c.float())` every forward pass. For large `r` and many layers this adds non-trivial cost. A lighter reparameterisation (e.g., maintaining `c` as a matrix with a normalisation step that avoids full QR) might be warranted, but for small `(r,k)` pairs the current approach is acceptable. At minimum, add a comment noting the per-forward cost.
|
||||
|
||||
### Suggestions
|
||||
- **`antipasto.py` group_init:** after the idx sort, `idx = idx.sort().values` ensures a stable, canonical ordering. This is a nice touch for reproducibility.
|
||||
- **`antipasto_ablate.py`** – The docstring says “CorDA-orient the basis from input covariance … the ablation is OUTPUT-side and CorDA's U stays orthonormal …”. The forward code correctly uses the orthonormal `U` for output projection, so the contraction in S-space carries over to the output.
|
||||
- **`sspace.py`** – The signed gate correctly preserves `cos` sign, so anti-aligned tokens receive a negative `gate * alpha` and are pushed opposite to `dS_hat`. Verified.
|
||||
|
||||
### Positive
|
||||
- The documentation is thorough, explaining the design choices (why 1+ELU, why contraction, why CorDA), and includes references.
|
||||
- The PiSSA-style `W_res` decomposition is implemented correctly across all variants.
|
||||
- `antipasto.py`’s `S_eff = S * (1 + ELU(coeff*g))` is indeed C1, positive, and identity at `g=0` — no sign-flip bugs.
|
||||
- `antipasto_ablate.py` enforces orthonormal `Chat` and clamps `alpha` to `[0,1]`, making the contraction property safe when `coeff` is also bounded.
|
||||
|
||||
### Verdict
|
||||
**REQUEST CHANGES**
|
||||
|
||||
`group_init()` must reset the trainable parameters after it changes the SVD basis; without this, the adapters silently poison their own steering gains with a different set of directions. Fix that and the code is ready.
|
||||
|
||||
---
|
||||
|
||||
*Note: The `sspace.py` and `_cost.py` files are part of the same move into lora-lite and are free of the parameter-reset issue; the only concern in `sspace.py` is the unprotected division by singular values.*
|
||||
@@ -0,0 +1,19 @@
|
||||
**Evaluating reconstruction error**
|
||||
|
||||
I'm exploring how eps impacts reconstruction error. It seems eps doesn't introduce additional errors due to finite precision but just alters the decomposition. I think for full SVD, as long as Mhalf and Minvhalf are true inverses, it works. However, for truncated SVD, that's not the case. The residual comes out exact. I should also consider discussing clamp_min(0) in combination with eps. It looks like there’s a bit more to analyze here!**Evaluating identity output**
|
||||
|
||||
I’m considering whether the user needs the original W x output to be exactly the same in arithmetic versus in code using bf16. In terms of exact arithmetic, yes, but it seems like bf16 won't match precisely. If exact precision is necessary, this might be something that should be fixed. I wonder if I could propose computing the residual after quantizing U/S/P as a solution to address this issue.**Clarifying code concepts**
|
||||
|
||||
In the code context, if `r` is less than the rank, then `(Ur*Sr)@Pr` represents the weighted rank-r. Essentially, `W_res_new` is calculated as `W_orig` minus this value. When discussing equality, it's important to note that in precise arithmetic, it becomes trivial. However, the top-r part should be treated as an approximate reconstruction when it stands alone. I need to keep these distinctions clear for better understanding!**Evaluating quant methods**
|
||||
|
||||
Many quantitative methods use activation scaling with RMS that isn't centered. If the mean is nonzero, the uncentered approach can lead to a heavily biased μ. That's an important detail! I wonder if other methods might give different results if a centered approach were used instead. It's interesting to think about how such small changes can impact the outcomes in quantitative analysis. I'll keep this in mind as I explore more about these methods.**Considering token padding issues**
|
||||
|
||||
I'm thinking about how tokenization works and the impact of padding. If activations from masked padding are included, they might skew the calibration results. The user didn’t explicitly ask, but I wonder if representativeness is a concern. I should cite the formula used for counting tokens and note how each position is weighted equally, which might not be ideal if variable lengths include padding. It's important to optimize the data correctly, especially in transformers!**Drafting final verdict**
|
||||
|
||||
I’m preparing the final verdict, concluding that claims 1, 3, and 5 are correct. For claim 2, it holds with the SPD second moment used in a damped objective due to epsilon. Claim 4 reveals that no identity error arises with the residual; the truncated term doesn’t telescope.
|
||||
|
||||
Now, I want to lay out the details clearly but concisely. I’ll include mathematical notation, detailing the relationship between column input and row returns, and then provide proof for claim 2 using the necessary equations.**Clarifying mathematical notation**
|
||||
|
||||
I'm thinking about how to handle the matrix M, where it’s not just C but rather C plus eps times the identity matrix (I). I need to clarify this in the context of ASVD, focusing on the entries. It’s important to accurately express how the addition of eps affects the matrix and its properties. I’ll ensure to incorporate this into the overall explanation properly.**Finalizing structure**
|
||||
|
||||
I’m starting with the initialization process, making sure everything is set up correctly. The code is being reconstructed from previous versions, which feels promising. Now, I want to write the final version clearly and in a structured way. I’ll be careful about keeping it concise and will use code snippets for citations, focusing on exact lines instead of line numbers. This way, the user can easily follow along. Let’s wrap this up!
|
||||
@@ -0,0 +1,116 @@
|
||||
## Review
|
||||
|
||||
**BLOCKER:** none found. The math, save/load round-trip, and variant coverage all check out.
|
||||
|
||||
---
|
||||
|
||||
### SHOULD
|
||||
|
||||
**1. antipasto_corda.py:13-18 — docstring falsely claims fallback behavior**
|
||||
```python
|
||||
Identity at g=0 or coeff=0: S_eff=S. P is oblique (rows not orthonormal -- C^{-1/2}
|
||||
skews them); fine for gain reweighting and for output-side ablation (the obliqueness
|
||||
is input-side; U stays orthonormal). No calibration_data -> plain SVD (== antipasto).
|
||||
```
|
||||
The last sentence is now **wrong**: `group_init` raises `ValueError` when `calibration_data is None` (by design, after the bugfix). The docstring advertises a silent fallback that no longer exists. Remove or replace with "requires calibration_data (raises otherwise)".
|
||||
|
||||
**2. antipasto_corda.py:55-67 — group_init docstring narrates the bugfix history**
|
||||
```python
|
||||
"""Re-orient each target's SVD by its input covariance C = E[x x^T].
|
||||
...
|
||||
Covariance orientation IS this variant's identity, so calibration_data is
|
||||
mandatory -- fail loud rather than silently degrade to plain SVD (which is
|
||||
just antipasto and was the bug that made every corda run a no-op).
|
||||
...
|
||||
Do not call group_init after training has updated g."""
|
||||
```
|
||||
The commentary "was the bug that made every corda run a no-op" is changelog narration. The last sentence lectures the reader. Replace with a single line like `# Requires calibration_data; raises otherwise. Call only at attach-time (before training).`
|
||||
|
||||
**3. antipasto_corda.py:80-88 — long design-rationale block is rambling**
|
||||
```python
|
||||
# accumulate C = sum x x^T on CPU. Peak GPU cost would otherwise be
|
||||
# sum_targets d_in^2 fp32 held at once; for down_proj (d_in=intermediate,
|
||||
# e.g. 14336) that is ~0.8 GB *per layer* and OOMs. CPU accumulation bounds
|
||||
# GPU use to the live activation; the eigh/SVD below run on CPU (one-time).
|
||||
# Diagonal C is NOT a usable shortcut: it misses cross-channel correlation,
|
||||
# which is where the orientation gain lives (measured ~= plain SVD).
|
||||
# If down_proj's d_in^2 is too big even on CPU/RAM, exclude it from CorDA
|
||||
# (leave it on plain antipasto) or use a low-rank C (top-k eig of subsampled
|
||||
# inputs) -- not implemented here.
|
||||
```
|
||||
This is lecture/rambling and future-work speculation. Slim to `# CPU accumulation: d_in^2 per layer OOMs GPU.` The rest is for a design doc or paper, not inline.
|
||||
|
||||
**4. antipasto_dplr.py:14-20 — docstring narrates abandoned variants**
|
||||
```python
|
||||
antipasto's core is diagonal (a per-direction gain); it rescales each singular
|
||||
direction but cannot mix one into another. The arrowhead tried a dense b x b block
|
||||
on the top-b directions, but a dense block is the wrong shape (b^2 params, mixes only
|
||||
the top-b) and -- sitting on the S-scaled coords -- its perturbation is amplified by
|
||||
the largest singular values, so it destabilizes. The fix is LoRA's lesson: a low-rank
|
||||
core. ...
|
||||
```
|
||||
History of "the arrowhead" variant and "the fix is LoRA's lesson" is changelog narration. The paragraph starting "Why the low-rank part is ADDED…" (lines 23-26) repeats the same. Trim to: `# Additive low-rank core B@A in the frozen SVD basis. Independent of S → no amplification edge.` and drop the storytelling.
|
||||
|
||||
**5. antipasto_dplr.py:26 vs 144 — docstring claims operator is `B A`, code computes `(B A)^T`**
|
||||
Docstring:
|
||||
```
|
||||
y = x @ W_res.T + ( (Vh x) * S_eff + coeff * B (A (Vh x)) ) @ U.T
|
||||
```
|
||||
Code:
|
||||
```python
|
||||
h = p * S_eff + coeff * (p @ A.T) @ B.T # p = x @ Vh.T
|
||||
```
|
||||
The effective operator in S-space is `diag(S_eff) + coeff * (B @ A).T`, not `B @ A`. The parameterization can represent the same matrices (swap B↔A^T), so it's not a math bug — but the docstring describes the wrong composition. Fix the docstring or swap `A.T`/`B.T` to `B`/`A` in the forward pass to match the stated convention. The LoRA convention is `x @ A.T @ B.T` for `ΔW = B @ A`; here the natural convention would be `p @ B @ A` for core `B @ A`, but the code produces the transpose.
|
||||
|
||||
---
|
||||
|
||||
### NICE
|
||||
|
||||
**6. antipasto_corda.py:12 — `C^{1/2}` justification is solid; user's note is mistaken**
|
||||
The user asks whether canonical CorDA uses `W @ C` (not `C^{1/2}`) and whether `C^{1/2}` is defensible. **Both the canonical CorDA paper (arXiv:2406.05223) and this code use `W @ C^{1/2}`** — CorDA does Cholesky `C = L L^T` and SVD on `W L`, which is `W @ C^{1/2}`. The `C^{1/2}` form is **correct**: minimizing `||(W - W_r) C^{1/2}||_F` minimizes the expected reconstruction error `E[||W x - W_r x||^2]` for `x ~ N(0, C)`. The symmetric eigh square-root used here is equivalent to the Cholesky factor. The comment on line 93 ("CorDA whitens with full C") in an earlier version of the docstring would be factually wrong — but the current file doesn't contain that claim; only the user's prompt does.
|
||||
|
||||
**7. antipasto.py:97 — misleading comment about CPU capture**
|
||||
```python
|
||||
proj = X.to(Vh_full) @ Vh_full.T # input in S-coords (X captured on CPU)
|
||||
```
|
||||
`X` was captured on CPU, but `X.to(Vh_full)` moves it to GPU for the projection (since `Vh_full` is on GPU after `torch.linalg.svd(W_orig)`). The parenthetical is either stale or ambiguous. Clarify: `# X was accumulated on CPU; moved here to GPU for the projection.`
|
||||
|
||||
**8. antipasto.py vs antipasto_corda.py — duplicated ELU logic**
|
||||
`antipasto.py:forward()` inlines `1.0 + F.elu(coeff * g)` while `antipasto_corda.py` has a factored-out `_gain()` helper. The `antipasto_dplr.py` inlines it again. Unify to `_gain` or a shared utility. Low priority.
|
||||
|
||||
**9. adapter.py:57 — `base_weight_keys` uses `attached_names` from `for` loop, but `attached_names` is only populated if no early exceptions occur**
|
||||
```python
|
||||
attached_names: list[str] = []
|
||||
attached_targets = []
|
||||
for name, layer, role in targets:
|
||||
...
|
||||
attached_names.append(name)
|
||||
attached_targets.append((name, layer, role))
|
||||
```
|
||||
If `variant.init()` or `param_specs` raises mid-loop, `attached_names` is inconsistent with `attached_targets`. Harmless since the exception propagates and `attach()` is discarded, but the partial `attached_targets` is passed to `group_init` which would see only the first N layers. Since `group_init` uses `targets` as a dict keyed by name, and `attached_targets` is only needed for the hook registration after group_init, this could cause group_init to silently miss later layers if an exception were caught — but nothing catches it. Not a bug in current fail-fast style.
|
||||
|
||||
---
|
||||
|
||||
### Save/load completeness
|
||||
|
||||
**Coverage:** All three variants (`antipasto`, `antipasto_corda`, `antipasto_dplr`) have `group_init` that mutates `layer.weight` when `calibration_data` is provided. The `ran_data_init` guard in `attach()` (line 64) only sets `base_weight_keys` when `group_init` is present AND `calibration_data is not None` AND `_skip_group_init` is False — this correctly covers all three. ✓
|
||||
|
||||
**Ordering at load:** `attach(_skip_group_init=True)` runs `init()` (plain-SVD crop), then `load_state_dict` overwrites all buffers AND the base weight (if present in checkpoint). The final state is the checkpoint's, not init's. ✓
|
||||
|
||||
**No double-restore:** `base_weight_keys` is only used during `save()`. During `load()`, the checkpoint's keys are applied via `load_state_dict(strict=False)` indiscriminately; base weights are restored if present. The `base_weight_keys` list from the **loading** model (always `[]` since `_skip_group_init=True`) is irrelevant — it's not consulted during load. ✓
|
||||
|
||||
**Checkpoint cross-variant safety:** If a checkpoint saved from `antipasto_corda` (with base weights) is loaded into an `antipasto` model, the `lora_P` key in checkpoint has no matching parameter in the model → `unexpected_lora` catches it and raises. If a corda checkpoint saved WITHOUT base weights (pre-bugfix, `ran_data_init=False`) is loaded into corda, the checkpoint has plain-SVD `lora_P` but no base weight → model's `init()` W_res and checkpoint's `lora_P` are both plain SVD, so they match. ✓
|
||||
|
||||
**No gap:** Every variant whose `group_init` rewrites `layer.weight` is covered by the `ran_data_init` → `base_weight_keys` path. ✓
|
||||
|
||||
---
|
||||
|
||||
### Perf (minor)
|
||||
|
||||
- No `einops`/`einsum` in forward hot loops. `einops.rearrange` only appears in `group_init` hook callbacks (one-time calibration). ✓
|
||||
|
||||
---
|
||||
|
||||
### Verdict
|
||||
|
||||
No math errors found. The `C^{1/2}` approach is equivalent to canonical CorDA (via Cholesky). Save/load round-trip is correct and complete. The main issues are docstring rot (stale fallback claim, changelog narration, operator transpose mismatch) and overly chatty inline comments.
|
||||
@@ -0,0 +1,112 @@
|
||||
## BLOCKER
|
||||
|
||||
- `src/lora_lite/adapter.py:57` / `:64` / `:66-70`
|
||||
|
||||
> `variant.init(layer, cfg)`
|
||||
> `group_init(model, attached_targets, cfg, calibration_data)`
|
||||
> `for _, layer, _ in attached_targets:`
|
||||
|
||||
`init()` crops every targeted `layer.weight` to `W_res` before calibration, but the runtime hooks that add back the frozen SVD component are only registered after `group_init()`. Therefore `group_init()` runs the calibration forward pass through a damaged model. Inputs/covariances for downstream target layers are collected from activations produced by earlier cropped layers, so Wanda re-selection / CorDA covariance bases are silently data-wrong.
|
||||
|
||||
**Fix:** install the identity adapter hooks before running data-driven `group_init()`, or collect calibration activations before mutating weights. With current variants, hooks-before-`group_init` should preserve identity because `g=0`, `B=0`, and `P/Vh` initially reconstruct the cropped component.
|
||||
|
||||
- `src/lora_lite/adapter.py:75`, `:111`, `:121-122`
|
||||
|
||||
> `base_weight_keys = [f"{n}.weight" for n in attached_names] if ran_data_init else []`
|
||||
> `metadata = {"cfg": json.dumps(state["cfg"].to_dict())}`
|
||||
> `handles = attach(model, cfg, _skip_group_init=True)`
|
||||
> `missing, unexpected = model.load_state_dict(sd, strict=False)`
|
||||
|
||||
Persisting rewritten base residuals is correct for the shown data-driven mutators: `antipasto`, `antipasto_corda`, and `antipasto_dplr` all rewrite `layer.weight` in `group_init()` when data is used. But the fix is incomplete: `base_weight_keys` is process-local attach state, not checkpoint metadata. After loading a data-driven checkpoint, `attach(..., _skip_group_init=True)` sets `base_weight_keys=[]`; `load_state_dict` may restore the saved weights, but a subsequent `save()` drops them again. Also, old/broken checkpoints missing the residual weights load silently with plain-SVD residuals plus data-oriented `U/S/P`, which is wrong.
|
||||
|
||||
**Fix:** save `base_weight_keys` in safetensors metadata, validate those keys exist in `sd` during load, and restore `model._lora_lite_attached["base_weight_keys"]` after load. For `antipasto_corda`, absence of expected base weights should be a hard error for any valid data-oriented checkpoint.
|
||||
|
||||
- `src/lora_lite/adapter.py:92-99`
|
||||
|
||||
> `for pname in variant.param_specs(layer.in_features, layer.out_features, layer._lora_cfg):`
|
||||
> `del layer._parameters[pname]`
|
||||
> `del layer._buffers[pname]`
|
||||
|
||||
`detach()` deletes the adapter state but never restores the frozen component into `layer.weight`. For SVD-mutating variants, the base weight remains `W_res`, so the detached model is silently not the original model.
|
||||
|
||||
**Fix:** before deleting params/buffers, restore the frozen base component:
|
||||
- `antipasto`: `weight += (U * S) @ Vh`
|
||||
- `antipasto_dplr`: same frozen `U diag(S) Vh`; do not merge learned `g/A/B` unless this is a separate `merge()`
|
||||
- `antipasto_corda`: `weight += (U * S) @ P`
|
||||
|
||||
Prefer a variant method like `restore_base(layer)` to avoid hard-coding basis names in `adapter.py`.
|
||||
|
||||
## SHOULD
|
||||
|
||||
- `src/lora_lite/variants/antipasto_corda.py:12`
|
||||
|
||||
> `W = U diag(S) P (exactly)`
|
||||
|
||||
This is false as written for stored rank `r`. Exact reconstruction only holds for the full SVD of `W C_eps^{1/2}`. The implemented rank-`r` adapter reconstructs via residual:
|
||||
|
||||
`W = W_res + U_r diag(S_r) P_r`
|
||||
|
||||
**Fix:** change the pseudocode to include `W_res`. If making the weighted approximation claim, state it precisely: SVD of `W C^{1/2}` gives the Eckart–Young best rank-`r` approximation in `E ||(W - B)x||_2^2 = ||(W-B)C^{1/2}||_F^2`.
|
||||
|
||||
- `src/lora_lite/variants/antipasto_corda.py:10-11`
|
||||
|
||||
> `U S Vht = SVD(W C^{1/2})`
|
||||
> `P = Vht C^{-1/2}`
|
||||
|
||||
The `C^{1/2}` choice is mathematically defensible; it is not obviously wrong. It corresponds to optimal low-rank approximation in the input-covariance-weighted output norm. But it is not the same as canonical PEFT CorDA if that uses `W @ C`.
|
||||
|
||||
**Fix:** document this as a deliberate CorDA-like / covariance-weighted Eckart–Young variant. If exact PEFT compatibility is intended instead, change the algebra to match PEFT and accept that it optimizes a different weighting.
|
||||
|
||||
- `src/lora_lite/variants/antipasto_corda.py:16`
|
||||
|
||||
> `No calibration_data -> plain SVD (== antipasto).`
|
||||
|
||||
Stale after the fail-fast change. `group_init()` now raises on `calibration_data is None`.
|
||||
|
||||
**Fix:** remove this sentence or replace with “load uses plain-SVD seeding; public CorDA attach requires calibration data.”
|
||||
|
||||
- `src/lora_lite/adapter.py:57` + `src/lora_lite/variants/antipasto_corda.py:102-106`
|
||||
|
||||
> `variant.init(layer, cfg)`
|
||||
> `if calibration_data is None:`
|
||||
> `raise ValueError(...)`
|
||||
|
||||
Fail-fast is fine, but this fails after `init()` has already cropped the weights and before `_ATTACHED_ATTR` is set. In an interactive run, catching the `ValueError` leaves a corrupted model that `detach()` cannot find.
|
||||
|
||||
**Fix:** validate CorDA’s calibration-data requirement before mutating weights, e.g. a variant flag like `requires_calibration_data = True` checked in `attach()` before the init loop.
|
||||
|
||||
- `src/lora_lite/variants/antipasto_dplr.py:53` / `:67`
|
||||
|
||||
> `# Params = r (gain) + 2*r*lora_rank. k=0 degenerates to plain antipasto.`
|
||||
> `if not 0 < k <= r:`
|
||||
|
||||
Comment says `k=0` is supported; code rejects it.
|
||||
|
||||
**Fix:** either allow `k=0` by omitting `lora_A/lora_B` and branching forward, or change the comment.
|
||||
|
||||
## NICE
|
||||
|
||||
- `src/lora_lite/variants/antipasto_corda.py:95-98`
|
||||
|
||||
> `mandatory -- fail loud rather than silently degrade to plain SVD`
|
||||
> `just antipasto and was the bug that made every corda run a no-op`
|
||||
|
||||
Changelog narration. Remove the “was the bug” history; keep only the invariant.
|
||||
|
||||
- `src/lora_lite/variants/antipasto.py:187-193`
|
||||
|
||||
> `# Why 1+ELU and not the obvious alternatives:`
|
||||
|
||||
Too lecture-like for slim research code. Compress to one equation/comment plus citation or move rationale to docs.
|
||||
|
||||
- `src/lora_lite/variants/antipasto_dplr.py:4-8`
|
||||
|
||||
> `The arrowhead tried a dense b x b block...`
|
||||
|
||||
Historical narration. Keep the current model equation and maybe one citation; drop the archaeology.
|
||||
|
||||
- `src/lora_lite/variants/antipasto_corda.py:109-117`
|
||||
|
||||
Long CPU/OOM implementation note. Useful, but too much in-code prose. Compress to “accumulate full C on CPU; diagonal C loses correlations” and move sizing advice elsewhere.
|
||||
|
||||
- PERF: no hot-loop `einops` issue. The `rearrange` use is only calibration/init-side, not per-token forward.
|
||||
@@ -1,5 +1,13 @@
|
||||
set shell := ["bash", "-cu"]
|
||||
|
||||
# Base (NOT Instruct) text model: CorDA/PiSSA/ASVD decompose the pretrained weight and
|
||||
# orient by calibration covariance -- the task must not be pre-baked by RLHF, or the
|
||||
# variant differences ceiling out. AutoModelForCausalLM resolves Qwen3.5-0.8B-Base to
|
||||
# the text-only Qwen3_5ForCausalLM (0.75B, no vision tower). It is a hybrid: 18 of 24
|
||||
# layers are GatedDeltaNet (no q/v), 6 are full attention. So we target down_proj (dense
|
||||
# nn.Linear in ALL 24 layers, d_in=3584) -- also CorDA/ASVD's canonical, highest-d_in target.
|
||||
model := "Qwen/Qwen3.5-0.8B-Base"
|
||||
|
||||
default:
|
||||
@just --list
|
||||
|
||||
@@ -25,7 +33,7 @@ qwen-probe variants="lora pissa delora ia3" steps="5":
|
||||
for variant in {{variants}}; do
|
||||
uv run --extra benchmark python scripts/metamath_gsm8k_benchmark.py \
|
||||
--mode probe \
|
||||
--model Qwen/Qwen3-0.6B-Base \
|
||||
--model {{model}} \
|
||||
--variant "$variant" \
|
||||
--steps {{steps}} \
|
||||
--batch-size 1 \
|
||||
@@ -38,7 +46,7 @@ qwen-probe variants="lora pissa delora ia3" steps="5":
|
||||
--alpha 8 \
|
||||
--layers 0 \
|
||||
--lr 5e-3 \
|
||||
--target-name 'model\.layers\.0\.self_attn\.(q_proj|v_proj)$'
|
||||
--target-name 'model\.layers\.0\.mlp\.down_proj$'
|
||||
done
|
||||
|
||||
qwen-queue variants="lora pissa delora ia3" steps="16":
|
||||
@@ -65,7 +73,7 @@ metamath-smoke variant="lora" steps="2" max_train_samples="8" max_eval_samples="
|
||||
--torch-dtype float32 \
|
||||
--device {{device}}
|
||||
|
||||
metamath-queue variant="lora" steps="5000" model="Qwen/Qwen3-0.6B-Base":
|
||||
metamath-queue variant="lora" steps="5000" model=model:
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
pueue add \
|
||||
@@ -75,11 +83,14 @@ metamath-queue variant="lora" steps="5000" model="Qwen/Qwen3-0.6B-Base":
|
||||
|
||||
# Run a single MetaMathQA->GSM8K benchmark for a given variant.
|
||||
# Per-variant lr / target-name defaults are baked in here.
|
||||
bench-variant model variant steps="5000":
|
||||
bench-variant model variant steps="5000" lora_rank="8" r_override="" lr_override="" rotate_basis="V":
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
lr=1e-4
|
||||
target='(q_proj|v_proj)$'
|
||||
# down_proj: dense nn.Linear in all 24 layers of the hybrid Qwen3.5 (q/v exist in
|
||||
# only 6 full-attention layers) and CorDA/ASVD's canonical highest-d_in target.
|
||||
target='(down_proj)$'
|
||||
r=32; alpha=64
|
||||
# IA3 lr: paper uses 3e-3 to 1e-2 (Liu et al. 2022 §3.3). Also a hard
|
||||
# bf16 floor: lora_g inits to 1.0 where bf16 spacing is ~7.8e-3, so
|
||||
# AdamW updates with lr<<3.9e-3 round back to 1.0 and the param freezes.
|
||||
@@ -88,19 +99,36 @@ bench-variant model variant steps="5000":
|
||||
delora) lr=1e-3 ;;
|
||||
ia3) lr=5e-3; target='(k_proj|v_proj)$' ;;
|
||||
ia3_ff) lr=5e-3; target='(down_proj)$' ;;
|
||||
antipasto) lr=5e-3 ;; # small params need higher lr
|
||||
# antipasto cores tune only S-space gain/block (tiny params), so a small
|
||||
# r leaves almost nothing trainable; r=256 is the variant default and
|
||||
# matches the published AntiPaSTO row. alpha=r (no extra scaling).
|
||||
antipasto*) lr=5e-3; r=256; alpha=256 ;;
|
||||
esac
|
||||
# r override (e.g. low-rank corda sweep); alpha tracks r for the antipasto family.
|
||||
if [ -n "{{r_override}}" ]; then r="{{r_override}}"; alpha="{{r_override}}"; fi
|
||||
# lr override (e.g. dplr core wants a tamer lr than the gain's 5e-3).
|
||||
if [ -n "{{lr_override}}" ]; then lr="{{lr_override}}"; fi
|
||||
# 0.8B + large vocab: HF ForCausalLMLoss upcasts logits to fp32 (bs*seq*vocab*4),
|
||||
# which OOMs the 24GB card at the old bs=4/seq=768. micro-batch 2 fits at ~10GB;
|
||||
# grad-accum 4 -> effective batch 8 (optimization quality without the memory).
|
||||
# expandable_segments curbs fragmentation. Same for all variants -> fair comparison.
|
||||
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
|
||||
exec uv run --extra benchmark python scripts/metamath_gsm8k_benchmark.py \
|
||||
--model '{{model}}' \
|
||||
--variant '{{variant}}' \
|
||||
--steps {{steps}} \
|
||||
--lr "$lr" \
|
||||
--target-name "$target" \
|
||||
--layers all --r 32 --alpha 64 "$@"
|
||||
--antipasto-lora-rank {{lora_rank}} \
|
||||
--batch-size 2 --grad-accum 4 --max-seq-length 512 --batch-size-eval 16 \
|
||||
--layers all --r "$r" --alpha "$alpha" \
|
||||
--antipasto-rotate-basis '{{rotate_basis}}'
|
||||
|
||||
metamath-queue-all model="Qwen/Qwen3-0.6B-Base" steps="5000" variants="lora pissa delora dora hra ia3 ia3_ff eva antipasto":
|
||||
metamath-queue-all model=model steps="2500" variants="lora pissa delora dora hra ia3 ia3_ff eva antipasto antipasto_rot antipasto_corda antipasto_asvd antipasto_ablate antipasto_dplr":
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
# One pueue job per variant (each runs the live code at run time, so editing
|
||||
# while queued is safe). Re-queue here whenever the base model changes.
|
||||
for variant in {{variants}}; do
|
||||
pueue add \
|
||||
-l "why: benchmark {{model}} ${variant} on MetaMathQA->GSM8K at {{steps}} steps; resolve: outputs/metamath_gsm8k/results/benchmark_results.tsv gets a row with accuracy commit time method argv and result JSON for ${variant}" \
|
||||
|
||||
@@ -0,0 +1,139 @@
|
||||
"""Measure the cost of an attached adapter: params, FLOPs/MACs, time, GPU mem.
|
||||
|
||||
Which metric is "best" for comparing adapters? They answer different questions:
|
||||
|
||||
- trainable_params -- deterministic "size" number. The headline.
|
||||
- macs_per_token -- deterministic, hardware-INDEPENDENT compute. Best for an
|
||||
apples-to-apples comparison: wall-time is noisy and the old
|
||||
rotation adapter paid a per-forward Cayley solve the new ones
|
||||
do not. "adds" (additions) ~= MACs; FLOPs ~= 2 * MACs.
|
||||
- fwd_ms / bwd_ms -- felt cost, but noisy: warmup + median over `iters`, never one run.
|
||||
- peak_gpu_mb -- resident + activation peak around fwd(+bwd).
|
||||
|
||||
FLOPs come from torch.utils.flop_counter.FlopCounterMode (built in, no new dep). Its
|
||||
convention is MACs (a (m,k)@(k,n) matmul counts as m*n*k); we expose both `flops`
|
||||
(as returned) and `macs_per_token = flops / n_tokens` -- calibrate once on a known
|
||||
matmul if you need to be sure of the factor of 2.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import statistics
|
||||
import time
|
||||
|
||||
import torch
|
||||
from torch.utils.flop_counter import FlopCounterMode
|
||||
|
||||
|
||||
def _time_call(fn, warmup: int, iters: int, cuda: bool) -> float:
|
||||
"""Median wall-time of fn() in milliseconds (warmup excluded)."""
|
||||
for _ in range(warmup):
|
||||
fn()
|
||||
if cuda:
|
||||
torch.cuda.synchronize()
|
||||
samples = []
|
||||
for _ in range(iters):
|
||||
if cuda:
|
||||
start = torch.cuda.Event(enable_timing=True)
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
start.record()
|
||||
fn()
|
||||
end.record()
|
||||
torch.cuda.synchronize()
|
||||
samples.append(start.elapsed_time(end))
|
||||
else:
|
||||
t0 = time.perf_counter()
|
||||
fn()
|
||||
samples.append((time.perf_counter() - t0) * 1e3)
|
||||
return statistics.median(samples)
|
||||
|
||||
|
||||
def measure_cost(
|
||||
model: torch.nn.Module,
|
||||
fwd_fn,
|
||||
*,
|
||||
bwd_step_fn=None,
|
||||
n_tokens: int | None = None,
|
||||
adapter_filter: str = "lora_",
|
||||
warmup: int = 3,
|
||||
iters: int = 10,
|
||||
) -> dict:
|
||||
"""Cost of the currently-attached adapter.
|
||||
|
||||
fwd_fn(): run one forward (no grad). Used for FLOPs + fwd timing.
|
||||
bwd_step_fn(): zero_grad + forward + loss.backward(). Used for bwd timing.
|
||||
n_tokens: tokens in the fwd_fn batch, for macs_per_token.
|
||||
adapter_filter: substring marking adapter params/buffers (default 'lora_').
|
||||
"""
|
||||
dev = next(model.parameters()).device
|
||||
cuda = dev.type == "cuda"
|
||||
|
||||
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
named = list(model.named_parameters()) + list(model.named_buffers())
|
||||
adapter_bytes = sum(t.numel() * t.element_size() for n, t in named if adapter_filter in n)
|
||||
|
||||
# Adapter ADDED MACs/token, analytic and arch-independent (the FLOP counter below
|
||||
# asserts on some fused/linear-attention shapes -> None). Each 2D adapter weight of
|
||||
# shape (a, b) is used once in a per-token matmul, contributing a*b MACs; summing 2D
|
||||
# adapter-tensor numel is therefore the exact added compute for the U/Vh/P/A/B paths.
|
||||
# (Slight undercount for cores that reuse a factor twice, e.g. ablate's C C^T.)
|
||||
added_macs_per_token = sum(t.numel() for n, t in named if adapter_filter in n and t.ndim == 2)
|
||||
|
||||
# FLOPs: one forward under the counter (no grad so we count inference cost).
|
||||
# FlopCounterMode can assert on some fused attention shapes; degrade to None.
|
||||
try:
|
||||
fc = FlopCounterMode(display=False)
|
||||
with torch.no_grad(), fc:
|
||||
fwd_fn()
|
||||
flops = fc.get_total_flops()
|
||||
except Exception as e:
|
||||
print(f" [warn] FLOP count failed ({type(e).__name__}: {e}); flops=None")
|
||||
flops = None
|
||||
|
||||
if cuda:
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
fwd_ms = _time_call(lambda: _no_grad(fwd_fn), warmup, iters, cuda)
|
||||
bwd_ms = _time_call(bwd_step_fn, warmup, iters, cuda) if bwd_step_fn is not None else None
|
||||
peak_gpu_mb = (torch.cuda.max_memory_allocated() / 1e6) if cuda else None
|
||||
|
||||
return dict(
|
||||
trainable_params=trainable_params,
|
||||
adapter_resident_mb=adapter_bytes / 1e6,
|
||||
added_macs_per_token=added_macs_per_token, # adapter-only, always populated
|
||||
flops=flops, # whole model, best-effort (None on hybrid attn)
|
||||
macs_per_token=(flops / n_tokens) if (flops and n_tokens) else None,
|
||||
fwd_ms=fwd_ms,
|
||||
bwd_ms=bwd_ms,
|
||||
peak_gpu_mb=peak_gpu_mb,
|
||||
)
|
||||
|
||||
|
||||
def _no_grad(fn):
|
||||
with torch.no_grad():
|
||||
return fn()
|
||||
|
||||
|
||||
class group_init_meter:
|
||||
"""Context manager: wall-time + peak CPU RAM of a group_init / attach-with-calib.
|
||||
|
||||
CorDA accumulates C = E[xx^T] on CPU and runs eigh(d_in^3) -- the expensive corner.
|
||||
Use around ll.attach(model, cfg, calibration_data=...) to log that asymmetry.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.ms = None
|
||||
self.peak_cpu_mb = None
|
||||
|
||||
def __enter__(self):
|
||||
import tracemalloc
|
||||
self._tm = tracemalloc
|
||||
tracemalloc.start()
|
||||
self._t0 = time.perf_counter()
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc):
|
||||
self.ms = (time.perf_counter() - self._t0) * 1e3
|
||||
_, peak = self._tm.get_traced_memory()
|
||||
self._tm.stop()
|
||||
self.peak_cpu_mb = peak / 1e6
|
||||
return False
|
||||
@@ -0,0 +1,143 @@
|
||||
"""One-row-per-variant cost table: params, MACs/token, fwd/bwd ms, peak GPU, group_init.
|
||||
|
||||
Answers "which is best -- time / flops / adds / params?": MACs/token is the
|
||||
deterministic apples-to-apples compute number; trainable_params is the size headline;
|
||||
wall-time is the felt-but-noisy number; group_init is where CorDA's eigh(d_in^3) bites.
|
||||
|
||||
Usage:
|
||||
uv run --extra benchmark python scripts/cost_report.py \
|
||||
--model Qwen/Qwen3-0.6B-Base --variants antipasto antipasto_corda antipasto_ablate lora \
|
||||
--target-name 'q_proj$' 'v_proj$' --r 32 --out logs/cost_qwen0.6b.log
|
||||
|
||||
Point --target-name at down_proj to see the CorDA covariance corner (large d_in).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from tabulate import tabulate
|
||||
|
||||
import lora_lite as ll
|
||||
|
||||
_HERE = Path(__file__).resolve().parent
|
||||
_BENCH = importlib.util.spec_from_file_location("metamath_benchmark", _HERE / "metamath_gsm8k_benchmark.py")
|
||||
benchmark = importlib.util.module_from_spec(_BENCH)
|
||||
sys.modules[_BENCH.name] = benchmark
|
||||
_BENCH.loader.exec_module(benchmark)
|
||||
|
||||
_COST = importlib.util.spec_from_file_location("_cost", _HERE / "_cost.py")
|
||||
cost = importlib.util.module_from_spec(_COST)
|
||||
sys.modules[_COST.name] = cost
|
||||
_COST.loader.exec_module(cost)
|
||||
|
||||
|
||||
def build_cfg(variant: str, args, dtype) -> ll.AdapterConfig:
|
||||
"""Reuse the benchmark's variant->config map; only need r/targets/dtype here."""
|
||||
bcfg = benchmark.BenchmarkConfig(
|
||||
model=args.model, variant=variant, r=args.r, alpha=float(args.r),
|
||||
target_name=list(args.target_name), layers=args.layers, torch_dtype=args.dtype,
|
||||
antipasto_cov_orient=args.cov_orient,
|
||||
)
|
||||
return benchmark.cfg_for_variant(bcfg, dtype)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--model", default="Qwen/Qwen3-0.6B-Base")
|
||||
ap.add_argument("--variants", nargs="+",
|
||||
default=["lora", "antipasto", "antipasto_rot", "antipasto_corda",
|
||||
"antipasto_ablate", "antipasto_dplr"])
|
||||
ap.add_argument("--target-name", nargs="+", default=[r"q_proj$", r"v_proj$"])
|
||||
ap.add_argument("--r", type=int, default=32)
|
||||
ap.add_argument("--layers", default="all",
|
||||
help="'all' or comma list e.g. '0,1' -- limit layers (CorDA down_proj eigh is slow).")
|
||||
ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
|
||||
ap.add_argument("--dtype", default="bfloat16")
|
||||
ap.add_argument("--seq-len", type=int, default=256)
|
||||
ap.add_argument("--batch", type=int, default=2)
|
||||
ap.add_argument("--calib-batches", type=int, default=4)
|
||||
ap.add_argument("--cov-orient", action="store_true",
|
||||
help="CorDA-orient antipasto_ablate (measure the eigh corner).")
|
||||
ap.add_argument("--out", default="logs/cost.log")
|
||||
args = ap.parse_args()
|
||||
|
||||
dtype = getattr(torch, args.dtype)
|
||||
# eager attention: FlopCounterMode's sdpa_flop_count asserts on GQA (Qwen3) SDPA
|
||||
# shapes (q heads != kv heads). eager uses explicit matmuls it can count.
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
tok = AutoTokenizer.from_pretrained(args.model)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.model, dtype=dtype, attn_implementation="eager"
|
||||
).to(args.device)
|
||||
model.eval()
|
||||
|
||||
n_tokens = args.batch * args.seq_len
|
||||
ids = torch.randint(0, model.config.vocab_size, (args.batch, args.seq_len), device=args.device)
|
||||
calib = [{"input_ids": torch.randint(0, model.config.vocab_size,
|
||||
(args.batch, args.seq_len), device=args.device)}
|
||||
for _ in range(args.calib_batches)]
|
||||
|
||||
def fwd():
|
||||
model(input_ids=ids)
|
||||
|
||||
def bwd_step():
|
||||
model.zero_grad(set_to_none=True)
|
||||
loss = model(input_ids=ids).logits.float().pow(2).mean()
|
||||
loss.backward()
|
||||
|
||||
# base (no-adapter) cost, so each row can report the adapter's ADDED MACs/token.
|
||||
base = cost.measure_cost(model, fwd, bwd_step_fn=bwd_step, n_tokens=n_tokens)
|
||||
base_macs = base["macs_per_token"]
|
||||
print(f"base (no adapter): MACs/tok={int(base_macs) if base_macs else None} "
|
||||
f"fwd_ms={round(base['fwd_ms'],2)} bwd_ms={round(base['bwd_ms'],2)}")
|
||||
|
||||
# base = no adapter; model params left trainable, so this is the full-finetune
|
||||
# GPU-mem reference (its backward stores grads for every weight).
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
rows = [{
|
||||
"variant": "base(full-FT)", "train_params": total_params,
|
||||
"fwd_ms": round(base["fwd_ms"], 2), "bwd_ms": round(base["bwd_ms"], 2),
|
||||
"peak_GPU_MB": round(base["peak_gpu_mb"], 1) if base["peak_gpu_mb"] else None,
|
||||
"added_MACs/tok": 0 if base_macs else None,
|
||||
"ginit_ms": 0.0, "ginit_CPU_MB": 0.0,
|
||||
}]
|
||||
for variant in args.variants:
|
||||
cfg = build_cfg(variant, args, dtype)
|
||||
# group_init / attach cost (CorDA's eigh + C live here).
|
||||
with cost.group_init_meter() as gi:
|
||||
ll.attach(model, cfg, calibration_data=calib)
|
||||
c = cost.measure_cost(model, fwd, bwd_step_fn=bwd_step, n_tokens=n_tokens)
|
||||
ll.detach(model)
|
||||
rows.append({
|
||||
"variant": variant,
|
||||
"train_params": c["trainable_params"],
|
||||
"fwd_ms": round(c["fwd_ms"], 2),
|
||||
"bwd_ms": round(c["bwd_ms"], 2) if c["bwd_ms"] else None,
|
||||
"peak_GPU_MB": round(c["peak_gpu_mb"], 1) if c["peak_gpu_mb"] else None,
|
||||
# flat across same-r adapters; kept only as a sanity check, not a comparator.
|
||||
"added_MACs/tok": int(c["macs_per_token"] - base_macs) if (c["macs_per_token"] and base_macs) else None,
|
||||
"ginit_ms": round(gi.ms, 1),
|
||||
"ginit_CPU_MB": round(gi.peak_cpu_mb, 1),
|
||||
})
|
||||
print(f" {variant}: params={rows[-1]['train_params']} "
|
||||
f"peak_GPU_MB={rows[-1]['peak_GPU_MB']} bwd_ms={rows[-1]['bwd_ms']} ginit_ms={rows[-1]['ginit_ms']}")
|
||||
|
||||
table = tabulate(rows, headers="keys", tablefmt="pipe")
|
||||
header = (f"# cost report: {args.model} targets={args.target_name} r={args.r} "
|
||||
f"seq={args.seq_len} batch={args.batch} dtype={args.dtype}\n"
|
||||
f"# COMPARATORS: train_params, peak_GPU_MB (fwd+bwd, process-local max), bwd_ms, ginit_ms.\n"
|
||||
f"# added_MACs/tok is flat across same-r adapters (sanity check only).\n"
|
||||
f"# ginit_CPU_MB undercounts: tracemalloc misses torch C++ tensor allocs (the CorDA C matrix).\n")
|
||||
out_path = Path(args.out)
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
out_path.write_text(header + table + "\n")
|
||||
print("\n" + header + table)
|
||||
print(f"\nsaved -> {out_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -19,6 +19,7 @@ from tabulate import tabulate
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import lora_lite as ll
|
||||
from _cost import measure_cost, group_init_meter
|
||||
|
||||
|
||||
PROMPT = "Question: {query} Think step by step.\nAnswer:"
|
||||
@@ -34,6 +35,11 @@ CFG_BY_VARIANT = {
|
||||
"hra": ll.HRAConfig,
|
||||
"eva": ll.EVAConfig,
|
||||
"antipasto": ll.AntiPaSTOConfig,
|
||||
"antipasto_rot": ll.AntiPaSTORotConfig,
|
||||
"antipasto_ablate": ll.AntiPaSTOAblateConfig,
|
||||
"antipasto_corda": ll.AntiPaSTOCorDAConfig,
|
||||
"antipasto_asvd": ll.AntiPaSTOASVDConfig,
|
||||
"antipasto_dplr": ll.AntiPaSTODPLRConfig,
|
||||
"road": ll.RoadConfig,
|
||||
}
|
||||
|
||||
@@ -42,8 +48,8 @@ CFG_BY_VARIANT = {
|
||||
class BenchmarkConfig:
|
||||
"""MetaMathQA -> GSM8K benchmark config. Tyro turns this into the CLI."""
|
||||
|
||||
model: str = "Qwen/Qwen3-0.6B-Base"
|
||||
variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "road"] = "lora"
|
||||
model: str = "Qwen/Qwen3.5-0.8B-Base"
|
||||
variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "antipasto_rot", "antipasto_ablate", "antipasto_corda", "antipasto_asvd", "antipasto_dplr", "road"] = "lora"
|
||||
mode: Literal["benchmark", "probe"] = "benchmark"
|
||||
device: str = "cuda"
|
||||
torch_dtype: str = "bfloat16"
|
||||
@@ -52,14 +58,24 @@ class BenchmarkConfig:
|
||||
alpha: float = 64.0
|
||||
delora_lambda0: float = 0.1
|
||||
road_group_size: int = 64
|
||||
antipasto_rotate_basis: Literal["V", "U", "none"] = "V"
|
||||
# AntiPaSTO family (gain / corda) runtime knobs.
|
||||
antipasto_coeff: float = 1.0
|
||||
antipasto_suppress_only: bool = False
|
||||
# AntiPaSTO-ablate.
|
||||
antipasto_ablate_k: int = 1
|
||||
antipasto_cov_orient: bool = False
|
||||
# AntiPaSTO-rot (legacy rotation variant) basis to rotate.
|
||||
antipasto_rotate_basis: Literal["V", "U", "both", "none"] = "V"
|
||||
# AntiPaSTO-dplr: rank of the low-rank mixing core in the frozen subspace.
|
||||
antipasto_lora_rank: int = 8
|
||||
target_name: list[str] = field(default_factory=lambda: list(DEFAULT_TARGETS))
|
||||
layers: str = "all"
|
||||
train_dataset: str = "meta-math/MetaMathQA"
|
||||
eval_dataset: str = "openai/gsm8k"
|
||||
eval_config: str = "main"
|
||||
steps: int = 5000
|
||||
batch_size: int = 4
|
||||
steps: int = 5000 # optimizer updates (each accumulates grad_accum micro-batches)
|
||||
batch_size: int = 4 # micro-batch (memory-bound); effective batch = batch_size * grad_accum
|
||||
grad_accum: int = 1 # gradient accumulation: raise effective batch without more memory
|
||||
batch_size_eval: int = 50
|
||||
max_train_samples: int | None = None
|
||||
max_eval_samples: int | None = None
|
||||
@@ -124,8 +140,16 @@ def cfg_for_variant(args: BenchmarkConfig, dtype: torch.dtype) -> ll.AdapterConf
|
||||
extra = {"lambda0": args.delora_lambda0} if args.variant == "delora" else {}
|
||||
if args.variant == "road":
|
||||
extra = {"group_size": args.road_group_size}
|
||||
if args.variant == "antipasto":
|
||||
if args.variant == "antipasto_rot":
|
||||
extra = {"rotate_basis": args.antipasto_rotate_basis}
|
||||
if args.variant in ("antipasto", "antipasto_corda", "antipasto_asvd"):
|
||||
extra = {"coeff": args.antipasto_coeff, "suppress_only": args.antipasto_suppress_only}
|
||||
if args.variant == "antipasto_ablate":
|
||||
extra = {"coeff": args.antipasto_coeff, "k": args.antipasto_ablate_k,
|
||||
"cov_orient": args.antipasto_cov_orient}
|
||||
if args.variant == "antipasto_dplr":
|
||||
extra = {"coeff": args.antipasto_coeff, "suppress_only": args.antipasto_suppress_only,
|
||||
"lora_rank": args.antipasto_lora_rank}
|
||||
return CFG_BY_VARIANT[args.variant](
|
||||
r=args.r,
|
||||
alpha=args.r if args.variant == "pissa" else args.alpha,
|
||||
@@ -155,7 +179,7 @@ def count_base_grad_leaks(model: torch.nn.Module) -> int:
|
||||
|
||||
|
||||
def perturb_first_adapter(model: torch.nn.Module) -> None:
|
||||
priority = ("lora_B", "lora_g", "lora_U", "lora_A", "lora_lambda", "lora_gate", "lora_delta_s", "lora_rot_T", "lora_m", "lora_road_theta", "lora_road_alpha")
|
||||
priority = ("lora_B", "lora_g", "lora_c", "lora_alpha", "lora_U", "lora_A", "lora_lambda", "lora_gate", "lora_delta_s", "lora_m", "lora_road_theta", "lora_road_alpha")
|
||||
for key in priority:
|
||||
for _, p in model.named_parameters():
|
||||
if p.requires_grad and key in _:
|
||||
@@ -259,7 +283,8 @@ def pad_batch(examples: list[dict[str, torch.Tensor | int]], pad_token_id: int,
|
||||
|
||||
|
||||
def make_train_batches(train_dataset, tokenizer, args: BenchmarkConfig) -> tuple[list[dict[str, torch.Tensor | int]], int]:
|
||||
needed = args.steps * args.batch_size
|
||||
# steps optimizer updates x grad_accum micro-batches/update x batch_size examples/micro-batch.
|
||||
needed = args.steps * args.grad_accum * args.batch_size
|
||||
examples = []
|
||||
skipped_prompt_too_long = 0
|
||||
for row in train_dataset:
|
||||
@@ -301,15 +326,23 @@ def train(model: torch.nn.Module, batches: list[dict[str, torch.Tensor | int]],
|
||||
last_loss = math.nan
|
||||
train_total_tokens = 0
|
||||
probe_batch = batches[0]
|
||||
pbar = tqdm(batches, desc="train", mininterval=60.0, dynamic_ncols=True)
|
||||
for step, batch in enumerate(pbar):
|
||||
accum = args.grad_accum
|
||||
# One optimizer update per `accum` micro-batches: scale each micro-loss by 1/accum so
|
||||
# the accumulated gradient equals a single backward over the effective batch.
|
||||
pbar = tqdm(range(args.steps), desc="train", mininterval=60.0, dynamic_ncols=True)
|
||||
for step in pbar:
|
||||
opt.zero_grad()
|
||||
step_loss = 0.0
|
||||
for micro in range(accum):
|
||||
batch = batches[step * accum + micro]
|
||||
loss = model(
|
||||
input_ids=batch["input_ids"],
|
||||
attention_mask=batch["attention_mask"],
|
||||
labels=batch["labels"],
|
||||
).loss
|
||||
).loss / accum
|
||||
loss.backward()
|
||||
step_loss += loss.item() # micro already /accum -> sum is the mean
|
||||
train_total_tokens += int(batch["label_tokens"])
|
||||
grad_norm = sum(
|
||||
p.grad.detach().float().norm().item()
|
||||
for name, p in model.named_parameters()
|
||||
@@ -317,13 +350,12 @@ def train(model: torch.nn.Module, batches: list[dict[str, torch.Tensor | int]],
|
||||
)
|
||||
if step == 0:
|
||||
first_grad_norm = grad_norm
|
||||
first_loss = loss.item()
|
||||
first_loss = step_loss
|
||||
base_grad_leaks += count_base_grad_leaks(model)
|
||||
torch.nn.utils.clip_grad_norm_([p for p in model.parameters() if p.requires_grad], args.grad_norm_clip)
|
||||
opt.step()
|
||||
scheduler.step()
|
||||
last_loss = loss.item()
|
||||
train_total_tokens += int(batch["label_tokens"])
|
||||
last_loss = step_loss
|
||||
pbar.set_postfix(loss=f"{last_loss:.4g}", grad=f"{grad_norm:.3g}", tok=train_total_tokens)
|
||||
pbar.close()
|
||||
after = adapter_state(model)
|
||||
@@ -420,11 +452,14 @@ def check_probe_reload(
|
||||
ll.load(loaded_model, str(adapter_path))
|
||||
from safetensors.torch import load_file
|
||||
saved_sd = load_file(str(adapter_path), device="cpu")
|
||||
loaded_state = adapter_state(loaded_model)
|
||||
if set(saved_sd) != set(loaded_state):
|
||||
raise AssertionError("loaded adapter keys differ from saved adapter keys")
|
||||
# Every saved tensor (lora_ buffers AND, for data-driven variants, the rewritten
|
||||
# base residuals) must reload bit-identical onto the model.
|
||||
loaded_full = loaded_model.state_dict()
|
||||
missing = set(saved_sd) - set(loaded_full)
|
||||
if missing:
|
||||
raise AssertionError(f"saved adapter keys absent from loaded model: {sorted(missing)[:8]}")
|
||||
for name, value in saved_sd.items():
|
||||
if not torch.equal(loaded_state[name].cpu(), value):
|
||||
if not torch.equal(loaded_full[name].cpu(), value):
|
||||
raise AssertionError(f"loaded adapter tensor differs: {name}")
|
||||
logits_loaded = loaded_model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits.detach().clone()
|
||||
reload_err = (logits_loaded - logits_trained).abs().max().item()
|
||||
@@ -436,6 +471,24 @@ def check_probe_reload(
|
||||
return {"reload_err": reload_err, "saved_tensors": len(saved_sd)}
|
||||
|
||||
|
||||
def print_first_train_sample(tokenizer, batch: dict[str, torch.Tensor | int]) -> None:
|
||||
"""Dump row 0 of the first train batch WITH special tokens + the supervised span.
|
||||
|
||||
Transformers framing (pad side, eos, prompt/response boundary) is the #1 silent
|
||||
fine-tune bug; printing the real encoded batch once is the cheap canary for it.
|
||||
"""
|
||||
ids = batch["input_ids"][0]
|
||||
labels = batch["labels"][0]
|
||||
sup = labels != -100 # positions contributing to the loss
|
||||
print("\n=== first train sample (input_ids[0], special tokens shown) ===")
|
||||
print(repr(tokenizer.decode(ids, skip_special_tokens=False)))
|
||||
print("--- supervised span (labels != -100, what the model is trained to emit) ---")
|
||||
print(repr(tokenizer.decode(ids[sup], skip_special_tokens=False)))
|
||||
print(f"SHOULD: prompt ends with the PROMPT template then the answer+eos; supervised span = answer+eos ONLY "
|
||||
f"(pad_side={tokenizer.padding_side}, eos={tokenizer.eos_token!r}). "
|
||||
f"ELSE prompt/response boundary or pad/eos is mis-encoded. (len={len(ids)}, supervised={int(sup.sum())})\n")
|
||||
|
||||
|
||||
def print_final_report(row: dict[str, Any], result_path: Path, mode: str) -> None:
|
||||
# BLUF: status line first so log tails are immediately readable
|
||||
cue = "🟢" if row.get("base_grad_leaks", 0) == 0 and row.get("grad", 0) > 0 else "🔴"
|
||||
@@ -443,9 +496,10 @@ def print_final_report(row: dict[str, Any], result_path: Path, mode: str) -> Non
|
||||
print()
|
||||
print(f"{cue} test_acc={row['test_acc']:.4g} valid_acc={row['valid_acc']:.4g} grad={row['grad']:.3g} dθ={row['dθ']:.3g} base_grad_leaks={row['base_grad_leaks']} N={n}")
|
||||
print("SHOULD: grad>0, dθ>0, base_grad_leaks=0; test/valid_acc meaningful only in benchmark mode. ELSE adapter or eval wiring is dead/wrong.")
|
||||
print("SHOULD(cost): addMACs_M ~equal across antipasto cores at same r (r*(d_in+d_out)*n_targets added matmul); params_M differs (dplr/ablate add a trainable core); init_ms is large for the calibrated variants (corda/asvd/eva), and corda > asvd (full-covariance eigh vs cheap diagonal). ELSE the cost model is wrong.")
|
||||
print()
|
||||
# ordered: most important / shortest columns first
|
||||
display_keys = ["variant", "test_acc", "valid_acc", "params_M", "peak_mem_GB", "grad", "dθ", "base_grad_leaks", "steps", "samples", "loss0", "lossN", "commit"]
|
||||
display_keys = ["variant", "test_acc", "valid_acc", "params_M", "fwd_ms", "bwd_ms", "addMACs_M", "init_ms", "peak_mem_GB", "grad", "dθ", "base_grad_leaks", "steps", "samples", "loss0", "lossN", "commit"]
|
||||
if "perturb" in row:
|
||||
display_keys += ["perturb", "reload"]
|
||||
display_keys += ["run_id"]
|
||||
@@ -474,6 +528,7 @@ def append_results_row(
|
||||
finished_label = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
|
||||
snapshot_path = results_dir / f"{result['run_id']}__{finished_label}.json"
|
||||
snapshot_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
|
||||
c = result.get("cost", {})
|
||||
row = {
|
||||
"test_acc": result["test_acc"],
|
||||
"valid_acc": result["valid_acc"],
|
||||
@@ -482,6 +537,16 @@ def append_results_row(
|
||||
"samples": result["train_samples"],
|
||||
"params_M": round(result["trainable_param_count"] / 1e6, 4),
|
||||
"peak_mem_GB": round(result.get("peak_cuda_mem_gb", 0.0), 3),
|
||||
# cost profile (one-time, measured at attach; see _cost.py). All deterministic
|
||||
# except the *_ms wall-times (median over warmup+iters), which stay noisy.
|
||||
"fwd_ms": round(c["fwd_ms"], 3) if c.get("fwd_ms") else None,
|
||||
"bwd_ms": round(c["bwd_ms"], 3) if c.get("bwd_ms") else None,
|
||||
"added_macs_per_tok": c.get("added_macs_per_token"), # adapter-only, arch-independent
|
||||
"fwd_macs": c.get("flops"), # whole model, None on hybrid attn
|
||||
"macs_per_tok": round(c["macs_per_token"]) if c.get("macs_per_token") else None,
|
||||
"adapter_mb": round(c["adapter_resident_mb"], 3) if c.get("adapter_resident_mb") else None,
|
||||
"init_ms": round(c["init_ms"], 1) if c.get("init_ms") else None,
|
||||
"init_peak_cpu_mb": round(c["init_peak_cpu_mb"], 1) if c.get("init_peak_cpu_mb") else None,
|
||||
"model": args.model,
|
||||
"commit": run_commit[:12],
|
||||
"wall_time_s": round(result["wall_time_s"]),
|
||||
@@ -494,6 +559,10 @@ def append_results_row(
|
||||
values = "\t".join(str(value) for value in row.values())
|
||||
with lock_path.open("w", encoding="utf-8") as lock_handle:
|
||||
fcntl.flock(lock_handle.fileno(), fcntl.LOCK_EX)
|
||||
# Rotate the file aside if its header no longer matches (e.g. cost columns added),
|
||||
# rather than appending misaligned rows under a stale header.
|
||||
if tsv_path.exists() and tsv_path.read_text(encoding="utf-8").split("\n", 1)[0] != header:
|
||||
tsv_path.rename(results_dir / f"summary.{finished_label}.tsv.bak")
|
||||
if not tsv_path.exists():
|
||||
tsv_path.write_text(header + "\n" + values + "\n", encoding="utf-8")
|
||||
else:
|
||||
@@ -510,20 +579,45 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
|
||||
dtype = getattr(torch, args.torch_dtype)
|
||||
run_commit = current_git_commit()
|
||||
run_id = f"{args.model.replace('/', '--')}__{args.variant}__s{args.steps}__seed{args.seed}"
|
||||
# dplr capacity is set by lora_rank, not r, so keep rank-sweep runs from colliding.
|
||||
if args.variant == "antipasto_dplr" and args.antipasto_lora_rank != 8:
|
||||
run_id += f"__k{args.antipasto_lora_rank}"
|
||||
# antipasto family defaults to r=256; low-rank sweeps get their own dirs.
|
||||
if args.variant.startswith("antipasto") and args.r != 256:
|
||||
run_id += f"__r{args.r}"
|
||||
# antipasto_rot defaults to rotating V; U/both are ablation axes -> own dirs.
|
||||
if args.variant == "antipasto_rot" and args.antipasto_rotate_basis != "V":
|
||||
run_id += f"__rot{args.antipasto_rotate_basis}"
|
||||
# antipasto family defaults to lr=5e-3; lr sweeps get their own dirs (the dense/
|
||||
# low-rank cores want a tamer lr than the gain, so this is a real axis).
|
||||
if args.variant.startswith("antipasto") and abs(args.lr - 5e-3) > 1e-9:
|
||||
run_id += f"__lr{args.lr:g}"
|
||||
out_dir = args.output_dir / run_id
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
datasets = load_datasets(args)
|
||||
model, tokenizer = load_model_and_tokenizer(args.model, dtype, args.device, args.quantization)
|
||||
batches, skipped_train_prompt_too_long = make_train_batches(datasets["train"], tokenizer, args)
|
||||
print_first_train_sample(tokenizer, batches[0])
|
||||
cfg = cfg_for_variant(args, dtype)
|
||||
if args.variant == "eva":
|
||||
# Variants with a data-driven group_init need calibration activations from the
|
||||
# downstream task (IPM mode, per CorDA). eva needs only a few batches for its init;
|
||||
# corda/asvd/cov-orient estimate an input second moment, so we hand them many more
|
||||
# batches (PEFT calibrates on a few hundred sequences) for a well-conditioned basis.
|
||||
# antipasto_ablate always calibrates now: group_init warm-starts lora_c from the
|
||||
# S-space output variance (cov_orient adds the heavier CorDA re-orient on top).
|
||||
needs_calib = args.variant in ("eva", "antipasto_corda", "antipasto_asvd", "antipasto_ablate")
|
||||
init_meter = group_init_meter() # wall-time + peak CPU RAM of group_init
|
||||
if needs_calib:
|
||||
n_batches = min(4, len(batches)) if args.variant == "eva" else min(64, len(batches))
|
||||
calib = [
|
||||
{"input_ids": b["input_ids"], "attention_mask": b["attention_mask"]}
|
||||
for b in batches[: min(4, len(batches))]
|
||||
for b in batches[:n_batches]
|
||||
]
|
||||
with init_meter: # CorDA's d_in^3 eigh on CPU is the cost asymmetry
|
||||
ll.attach(model, cfg, calibration_data=calib)
|
||||
else:
|
||||
with init_meter:
|
||||
ll.attach(model, cfg)
|
||||
attached = getattr(model, "_lora_lite_attached")
|
||||
trainable_names = assert_only_lora_trainable(model)
|
||||
@@ -532,6 +626,22 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
|
||||
probe_metrics = probe_before_train(model, batches[0], attached["targets"])
|
||||
model.train()
|
||||
|
||||
# One-time cost profile of the attached adapter, measured BEFORE training (~free:
|
||||
# ~10 fwd + ~10 fwd/bwd on one batch vs thousands of train steps). Reuses the exact
|
||||
# train loss path (input_ids/attention_mask/labels -> .loss.backward) so fwd/bwd ms
|
||||
# and FLOPs match what training pays. group_init cost captured separately above.
|
||||
b0 = batches[0]
|
||||
n_tokens = b0["input_ids"].numel() # padded positions the FLOP counter processes
|
||||
def _cost_fwd():
|
||||
model(input_ids=b0["input_ids"], attention_mask=b0["attention_mask"])
|
||||
def _cost_bwd_step():
|
||||
model.zero_grad(set_to_none=True)
|
||||
model(input_ids=b0["input_ids"], attention_mask=b0["attention_mask"], labels=b0["labels"]).loss.backward()
|
||||
cost = measure_cost(model, _cost_fwd, bwd_step_fn=_cost_bwd_step, n_tokens=n_tokens)
|
||||
cost["init_ms"] = init_meter.ms
|
||||
cost["init_peak_cpu_mb"] = init_meter.peak_cpu_mb
|
||||
model.zero_grad(set_to_none=True) # clear cost-measurement grads before training
|
||||
|
||||
if args.device == "cuda":
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
started = time.time()
|
||||
@@ -571,7 +681,9 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
|
||||
"steps": args.steps,
|
||||
"batch_size": args.batch_size,
|
||||
"batch_size_eval": args.batch_size_eval,
|
||||
"train_samples": args.steps * args.batch_size,
|
||||
"train_samples": args.steps * args.grad_accum * args.batch_size,
|
||||
"grad_accum": args.grad_accum,
|
||||
"effective_batch": args.grad_accum * args.batch_size,
|
||||
"max_seq_length": args.max_seq_length,
|
||||
"optimizer": "AdamW",
|
||||
"lr": args.lr,
|
||||
@@ -587,6 +699,7 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
|
||||
"adapter_path": str(adapter_path),
|
||||
"wall_time_s": time.time() - started,
|
||||
"peak_cuda_mem_gb": peak_mem_gb,
|
||||
"cost": cost, # params, FLOPs/MACs, fwd/bwd ms, peak gpu mb, group_init ms + peak cpu mb
|
||||
}
|
||||
result_path = out_dir / "result.json"
|
||||
result_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
|
||||
@@ -601,7 +714,7 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
|
||||
"run_id": run_id,
|
||||
"variant": args.variant,
|
||||
"steps": args.steps,
|
||||
"samples": args.steps * args.batch_size,
|
||||
"samples": args.steps * args.grad_accum * args.batch_size,
|
||||
"loss0": train_metrics["train_loss_first"],
|
||||
"lossN": train_metrics["train_loss_last"],
|
||||
"probeΔ": train_metrics["train_loss_probe_delta"],
|
||||
@@ -612,6 +725,11 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
|
||||
"test_acc": test_metrics["accuracy"],
|
||||
"params_M": round(result["trainable_param_count"] / 1e6, 4),
|
||||
"peak_mem_GB": round(peak_mem_gb, 3),
|
||||
# cost profile (see _cost.py). fwd/bwd in ms, macs/token in M, init = group_init.
|
||||
"fwd_ms": round(cost["fwd_ms"], 2) if cost.get("fwd_ms") else None,
|
||||
"bwd_ms": round(cost["bwd_ms"], 2) if cost.get("bwd_ms") else None,
|
||||
"addMACs_M": round(cost["added_macs_per_token"] / 1e6, 2) if cost.get("added_macs_per_token") else None,
|
||||
"init_ms": round(cost["init_ms"], 1) if cost.get("init_ms") else None,
|
||||
"commit": run_commit[:12],
|
||||
"result": str(result_path),
|
||||
}
|
||||
|
||||
@@ -20,6 +20,11 @@ from .variants.dora import DoRAConfig
|
||||
from .variants.hra import HRAConfig
|
||||
from .variants.eva import EVAConfig
|
||||
from .variants.antipasto import AntiPaSTOConfig
|
||||
from .variants.antipasto_rot import AntiPaSTORotConfig
|
||||
from .variants.antipasto_ablate import AntiPaSTOAblateConfig
|
||||
from .variants.antipasto_corda import AntiPaSTOCorDAConfig
|
||||
from .variants.antipasto_asvd import AntiPaSTOASVDConfig
|
||||
from .variants.antipasto_dplr import AntiPaSTODPLRConfig
|
||||
from .variants.road import RoadConfig
|
||||
|
||||
__all__ = [
|
||||
@@ -33,6 +38,11 @@ __all__ = [
|
||||
"HRAConfig",
|
||||
"EVAConfig",
|
||||
"AntiPaSTOConfig",
|
||||
"AntiPaSTORotConfig",
|
||||
"AntiPaSTOAblateConfig",
|
||||
"AntiPaSTOCorDAConfig",
|
||||
"AntiPaSTOASVDConfig",
|
||||
"AntiPaSTODPLRConfig",
|
||||
"RoadConfig",
|
||||
"attach",
|
||||
"detach",
|
||||
|
||||
@@ -62,17 +62,28 @@ def attach(model: nn.Module, cfg: AdapterConfig, calibration_data=None, *, _skip
|
||||
attached_names.append(name)
|
||||
attached_targets.append((name, layer, role))
|
||||
|
||||
group_init = getattr(variant, "group_init", None)
|
||||
if group_init is not None and not _skip_group_init:
|
||||
group_init(model, attached_targets, cfg, calibration_data)
|
||||
|
||||
# Register the adapter hooks BEFORE group_init. init() crops each weight to W_res,
|
||||
# so without the hooks the calibration forward inside group_init would run through a
|
||||
# model missing every target's top-r. At g=0 (and B=0) the hooks reconstruct the
|
||||
# cropped component exactly, so calibration sees the true full W.
|
||||
for _, layer, _ in attached_targets:
|
||||
if hasattr(layer._lora_variant, "forward_input"):
|
||||
handles.append(layer.register_forward_pre_hook(_pre_hook))
|
||||
else:
|
||||
handles.append(layer.register_forward_hook(_hook))
|
||||
|
||||
setattr(model, _ATTACHED_ATTR, {"cfg": cfg, "targets": attached_names, "handles": handles})
|
||||
group_init = getattr(variant, "group_init", None)
|
||||
ran_data_init = group_init is not None and not _skip_group_init and calibration_data is not None
|
||||
if group_init is not None and not _skip_group_init:
|
||||
group_init(model, attached_targets, cfg, calibration_data)
|
||||
|
||||
# A data-driven group_init (CorDA orient, Wanda re-select) rewrites the frozen
|
||||
# base residual W_res into a form init() cannot reproduce at load time (it only
|
||||
# knows the plain top-r crop). So those residuals are part of the saved adapter.
|
||||
base_weight_keys = [f"{n}.weight" for n in attached_names] if ran_data_init else []
|
||||
setattr(model, _ATTACHED_ATTR,
|
||||
{"cfg": cfg, "targets": attached_names, "handles": handles,
|
||||
"base_weight_keys": base_weight_keys})
|
||||
return handles
|
||||
|
||||
|
||||
@@ -87,6 +98,14 @@ def detach(model: nn.Module) -> None:
|
||||
if not hasattr(layer, "_lora_variant"):
|
||||
continue
|
||||
variant = layer._lora_variant
|
||||
# Undo the PiSSA-style crop: init() set weight = W - U_r S_r (Vh|P)_r, so add the
|
||||
# frozen top-r back to recover the original W (the trained gain/core are dropped).
|
||||
# Keyed on the shared SVD-gain buffer convention (antipasto family); variants
|
||||
# without lora_U leave weight untouched (e.g. LoRA never cropped it).
|
||||
if hasattr(layer, "lora_U"):
|
||||
proj = layer.lora_P if hasattr(layer, "lora_P") else layer.lora_Vh
|
||||
with torch.no_grad():
|
||||
layer.weight.data += ((layer.lora_U * layer.lora_S) @ proj).to(layer.weight.dtype)
|
||||
for pname in variant.param_specs(layer.in_features, layer.out_features, layer._lora_cfg):
|
||||
if pname in layer._parameters:
|
||||
del layer._parameters[pname]
|
||||
@@ -102,8 +121,14 @@ def save(model: nn.Module, path: str) -> None:
|
||||
state = getattr(model, _ATTACHED_ATTR, None)
|
||||
if state is None:
|
||||
raise RuntimeError("no adapter attached; call attach() first")
|
||||
sd = {k: v.detach().cpu() for k, v in model.state_dict().items() if "lora_" in k}
|
||||
metadata = {"cfg": json.dumps(state["cfg"].to_dict())}
|
||||
full_sd = model.state_dict()
|
||||
sd = {k: v.detach().cpu() for k, v in full_sd.items() if "lora_" in k}
|
||||
# data-driven variants also persist their rewritten base residuals (see attach()).
|
||||
base_weight_keys = state.get("base_weight_keys", [])
|
||||
for wk in base_weight_keys:
|
||||
sd[wk] = full_sd[wk].detach().cpu()
|
||||
metadata = {"cfg": json.dumps(state["cfg"].to_dict()),
|
||||
"base_weight_keys": json.dumps(base_weight_keys)}
|
||||
from safetensors.torch import save_file
|
||||
save_file(sd, path, metadata=metadata)
|
||||
|
||||
@@ -114,6 +139,12 @@ def load(model: nn.Module, path: str) -> list[RemovableHandle]:
|
||||
metadata = f.metadata()
|
||||
sd = load_file(path, device="cpu")
|
||||
cfg = AdapterConfig.from_dict(json.loads(metadata["cfg"]))
|
||||
# Base residuals a data-driven group_init rewrote: must be in the checkpoint and
|
||||
# are restored by load_state_dict below (init()'s plain crop would be wrong).
|
||||
base_weight_keys = json.loads(metadata.get("base_weight_keys", "[]"))
|
||||
missing_base = [wk for wk in base_weight_keys if wk not in sd]
|
||||
if missing_base:
|
||||
raise RuntimeError(f"checkpoint declares but omits base residuals: {missing_base}")
|
||||
handles = attach(model, cfg, _skip_group_init=True) # creates empty params; data-driven inits restored from state_dict
|
||||
missing, unexpected = model.load_state_dict(sd, strict=False)
|
||||
expected_lora = {k for k in model.state_dict() if "lora_" in k}
|
||||
@@ -123,4 +154,6 @@ def load(model: nn.Module, path: str) -> list[RemovableHandle]:
|
||||
unexpected_lora = [k for k in unexpected if "lora_" in k]
|
||||
if unexpected_lora:
|
||||
raise RuntimeError(f"unexpected lora keys in checkpoint: {unexpected_lora}")
|
||||
# Carry the residual keys onto the attach state so a later save() re-persists them.
|
||||
getattr(model, _ATTACHED_ATTR)["base_weight_keys"] = base_weight_keys
|
||||
return handles
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
"""AdapterConfig: per-variant typed dataclass.
|
||||
|
||||
Replaces the older `LoraLiteConfig` + `variant_kwargs` dict. Each variant
|
||||
ships its own subclass under `variants/*.py` (e.g. `DeLoRAConfig`), adding
|
||||
strongly-typed fields so users discover the knobs via IDE / dataclass
|
||||
introspection instead of stringly-typed dict lookups.
|
||||
Each variant ships its own subclass under `variants/*.py` (e.g. `DeLoRAConfig`),
|
||||
adding strongly-typed fields so the knobs are discoverable via IDE / dataclass
|
||||
introspection rather than stringly-typed dict lookups.
|
||||
|
||||
Wire-up:
|
||||
- `AdapterConfig` holds the universal fields (variant name, rank, alpha,
|
||||
|
||||
@@ -24,9 +24,6 @@ class ParamSpec:
|
||||
# avoid exact-zero dead zone; N(0, 1e-4) is small enough to be
|
||||
# ~identity but nonzero so gradients always have somewhere to go
|
||||
t.normal_(0, 1e-4)
|
||||
elif self.init == "near_one":
|
||||
# avoid exact-one dead zone; 1 + N(0, 1e-4)
|
||||
t.fill_(1.0).add_(torch.randn_like(t).mul_(1e-4))
|
||||
elif self.init == "ones":
|
||||
t.fill_(1.0)
|
||||
elif self.init == "kaiming":
|
||||
@@ -37,7 +34,7 @@ class ParamSpec:
|
||||
return t
|
||||
|
||||
def make(self, dtype: torch.dtype, device) -> nn.Parameter:
|
||||
# legacy entry: returns a Parameter (used for trainable adapter params)
|
||||
# trainable params -> Parameter; buffers go through make_tensor (see attach)
|
||||
if self.as_buffer:
|
||||
raise RuntimeError("as_buffer spec must be installed via register_buffer; see adapter.attach")
|
||||
return nn.Parameter(self._empty(dtype, device), requires_grad=self.trainable)
|
||||
|
||||
@@ -1 +1,4 @@
|
||||
from . import lora, pissa, delora, ia3, dora, hra, eva, antipasto, road # noqa: F401 side-effect: register
|
||||
from . import ( # noqa: F401 side-effect: register
|
||||
lora, pissa, delora, ia3, dora, hra, eva, antipasto, road,
|
||||
antipasto_rot, antipasto_ablate, antipasto_corda, antipasto_asvd, antipasto_dplr,
|
||||
)
|
||||
|
||||
@@ -1,32 +1,30 @@
|
||||
"""AntiPaSTO: SVD steering with learnable singular-value deltas + block-diagonal Cayley rotation.
|
||||
"""AntiPaSTO: learnable bounded reweighting of frozen SVD singular values.
|
||||
|
||||
wassname 2026 https://arxiv.org/abs/2601.07473
|
||||
|
||||
W = U diag(S) Vh + W_res (top-r SVD; W_res = W - U_r S_r Vh_r)
|
||||
learn: delta_s (r,), rot_T (n_blocks, bs(bs-1)/2)
|
||||
R = block_diag(Cayley(skew(rot_T))); Vh_eff = R @ Vh (or U_eff = U @ R.T)
|
||||
y = x @ W_res.T + ((x @ Vh_eff.T) * (S + delta_s)) @ U_eff.T
|
||||
W = U diag(S) Vh + W_res # top-r SVD; W_res = W - U_r S_r Vh_r, frozen
|
||||
learn: g (r,) # per-direction gain
|
||||
S_eff = S * (1 + ELU(coeff * g)) # exp(z) for z<0 (bounded), 1+z for z>0
|
||||
y = x @ W_res.T + ((x @ Vh.T) * S_eff) @ U.T
|
||||
|
||||
Identity at t=0: rot_T=0 -> R=I, delta_s~4e-4 -> y ≈ x @ W^T (fp32 SVD round-trip, tiny positive bias on delta_s breaks sign symmetry).
|
||||
suppress_only: clamp g<=0 -> S_eff in (0, S], attenuation only.
|
||||
coeff: runtime scale; 0 = identity, <0 swaps amplify/suppress.
|
||||
|
||||
TODO remove rambling
|
||||
Scope cut vs antipasto3: this is a fine-tuning adapter, not the full runtime
|
||||
steering interface. There is no per-call alpha, so it does not expose the
|
||||
bidirectional R(+alpha) / R(-alpha) inference symmetry. The V-basis path uses the
|
||||
opposite chirality to antipasto3's default U-basis path, so checkpoints are not
|
||||
portable without a sign/basis convention.
|
||||
Identity at g=0 or coeff=0: 1+ELU(0)=1, so S_eff=S (up to the bf16 SVD round-trip).
|
||||
The basis (U, Vh) is frozen, so the singular directions stay interpretable and only
|
||||
the gain is learned. See forward() for why 1+ELU over linear/exp/tanh.
|
||||
|
||||
Refs:
|
||||
- paper: https://github.com/wassname/AntiPaSTO
|
||||
- lite port of: https://github.com/wassname/antipasto3
|
||||
(offline: docs/refs/antipasto3_svd_adapter.py)
|
||||
- sibling (whitened, mean-diff): steering-lite/.../sspace.py
|
||||
- selection: Wanda (Sun+ 2023, arXiv:2306.11695), ASVD (Yuan+ 2023, arXiv:2312.05821)
|
||||
- top-r SVD init: PiSSA (Meng+ 2024, arXiv:2404.02948)
|
||||
"""
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, Literal
|
||||
|
||||
import torch
|
||||
from einops import einsum, rearrange
|
||||
from einops import rearrange
|
||||
from jaxtyping import Float
|
||||
from torch import nn, Tensor as T
|
||||
|
||||
@@ -41,35 +39,17 @@ CalibrationData = Iterable[CalibrationBatch]
|
||||
@dataclass
|
||||
class AntiPaSTOConfig(AdapterConfig):
|
||||
variant: str = "antipasto"
|
||||
# Higher default than LoRA (r=8) since trainable params scale as r + r/bs*bs*(bs-1)/2, not r*(d_in+d_out).
|
||||
# Only r + r trainable scalars, so r can be large.
|
||||
r: int = 256
|
||||
|
||||
# Block size for the block-diagonal Cayley rotation. r must be divisible by it.
|
||||
block_size: int = 4
|
||||
# Cayley map saturation: bounds rotation angle to ~max_rotation_angle radians.
|
||||
max_rotation_angle: float = 0.5
|
||||
# Which singular basis to rotate: 'V' (input), 'U' (output), or 'none'.
|
||||
rotate_basis: Literal["V", "U", "none"] = "V"
|
||||
|
||||
|
||||
def _cayley(skew: torch.Tensor) -> torch.Tensor:
|
||||
"""R = (I - X)^-1 (I + X) for X = skew/2; preserves orthogonality."""
|
||||
bs = skew.shape[-1]
|
||||
eye = torch.eye(bs, dtype=skew.dtype, device=skew.device).expand_as(skew)
|
||||
X = skew / 2
|
||||
return torch.linalg.solve(eye - X, eye + X)
|
||||
|
||||
|
||||
def _build_rotation(rot_T: torch.Tensor, bs: int, max_angle: float) -> torch.Tensor:
|
||||
"""rot_T: (n_blocks, bs*(bs-1)/2) -> R: (n_blocks, bs, bs) Cayley rotation."""
|
||||
n_blocks, _ = rot_T.shape
|
||||
rows, cols = torch.triu_indices(bs, bs, offset=1, device=rot_T.device).unbind(0)
|
||||
A = torch.zeros(n_blocks, bs, bs, dtype=rot_T.dtype, device=rot_T.device)
|
||||
A[:, rows, cols] = rot_T
|
||||
A = 0.5 * (A - A.transpose(-1, -2))
|
||||
a_limit = 2.0 * math.tan(max_angle / 2.0)
|
||||
A = a_limit * torch.tanh(A / a_limit)
|
||||
return _cayley(A)
|
||||
# Per-direction reweighting is S_eff = S * (1 + ELU(coeff * g)). See forward()
|
||||
# for the why; identity at g=0 or coeff=0, positive always, no free bound knob.
|
||||
suppress_only: bool = False # clamp g<=0 -> factor in (0,1]: attenuation only.
|
||||
# Guarantee holds for coeff>=0; coeff<0 inverts the product and re-amplifies.
|
||||
# Runtime steering scale. 0 = identity. <0 inverts (swaps amplify/suppress).
|
||||
coeff: float = 1.0
|
||||
# group_init Wanda-style pooling of |X @ Vh[i]|: 'rms' is outlier-sensitive
|
||||
# (ASVD intuition), 'mean_abs' is the original outlier-robust pooling.
|
||||
act_pool: Literal["rms", "mean_abs"] = "rms"
|
||||
|
||||
|
||||
@register
|
||||
@@ -79,24 +59,14 @@ class AntiPaSTO:
|
||||
@staticmethod
|
||||
def param_specs(d_in, d_out, cfg):
|
||||
r = cfg.r
|
||||
bs = int(cfg.block_size)
|
||||
if r % bs != 0:
|
||||
raise ValueError(f"AntiPaSTO requires r={r} divisible by block_size={bs}")
|
||||
specs = dict(
|
||||
# Frozen SVD components captured at init.
|
||||
return dict(
|
||||
# Frozen top-r SVD captured at init.
|
||||
lora_U=ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True),
|
||||
lora_S=ParamSpec((r,), init="zeros", trainable=False, as_buffer=True),
|
||||
lora_Vh=ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True),
|
||||
# Trainable: per-singular-value delta.
|
||||
# antipasto3 uses 4e-4 + N(0, 4e-4): small positive bias breaks sign
|
||||
# symmetry (rotation alone can't); zero-init works but trains slower.
|
||||
lora_delta_s=ParamSpec((r,), init=lambda t: t.normal_(0, 4e-4).add_(4e-4)),
|
||||
# Trainable per-direction log-scale. init 0 -> 1+ELU(0)=1 -> identity.
|
||||
lora_g=ParamSpec((r,), init="zeros"),
|
||||
)
|
||||
if cfg.rotate_basis != "none":
|
||||
n_blocks = r // bs
|
||||
n_triu = bs * (bs - 1) // 2
|
||||
specs["lora_rot_T"] = ParamSpec((n_blocks, n_triu), init="zeros")
|
||||
return specs
|
||||
|
||||
@staticmethod
|
||||
def init(layer: nn.Module, cfg) -> None:
|
||||
@@ -115,18 +85,19 @@ class AntiPaSTO:
|
||||
layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype))
|
||||
W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
|
||||
layer.weight.data.copy_(W_res)
|
||||
# group_init() refines this to input-aligned directions if calibration_data is given.
|
||||
# group_init() refines the dimension selection if calibration_data is given.
|
||||
|
||||
@staticmethod
|
||||
def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None:
|
||||
"""Wanda-style data-driven dimension selection within the weight SVD.
|
||||
"""Data-driven re-selection of which top-r singular directions to keep.
|
||||
|
||||
init() picks the top-r singular dimensions by S alone (PiSSA-style).
|
||||
group_init() re-selects based on S[i] * mean|X @ Vh[i]|: dimensions
|
||||
that are both large in W AND active given real inputs.
|
||||
FIXME os that corda? or pissa? wanted corda. and or ASVD
|
||||
init(): top-r by S alone (PiSSA-style)
|
||||
group_init(): top-r by score[i] = S[i] * pool|X @ Vh[i]| (Wanda/ASVD)
|
||||
pool = 'rms' (outlier-sensitive) | 'mean_abs' (outlier-robust)
|
||||
|
||||
If calibration_data is None the weight-SVD init from init() is kept.
|
||||
This re-RANKS W's own singular vectors by activation; it does NOT re-orient
|
||||
the basis (that is CorDA -> antipasto_corda.py). So the kept directions are
|
||||
still plain weight-SVD directions, just a better subset. None -> keep init().
|
||||
"""
|
||||
if calibration_data is None:
|
||||
return
|
||||
@@ -162,6 +133,7 @@ class AntiPaSTO:
|
||||
h.remove()
|
||||
|
||||
r = cfg.r
|
||||
pool = cfg.act_pool
|
||||
for name, layer in layers.items():
|
||||
X = torch.cat(captured[name], dim=0) # (N, d_in)
|
||||
if X.shape[0] < r:
|
||||
@@ -169,18 +141,21 @@ class AntiPaSTO:
|
||||
f"AntiPaSTO at {name}: only {X.shape[0]} calibration tokens, need >= r={r}"
|
||||
)
|
||||
|
||||
# Recover W_orig: init() wrote W_res into layer.weight and stored top-r components
|
||||
# FIXME isnt this run after, not instead of init. so this is using cropped matrixes
|
||||
# Rebuild the FULL W: init() stored the exact top-r it subtracted, so
|
||||
# W_res + U_r S_r Vh_r == W (full rank, not a cropped matrix). The SVD
|
||||
# below therefore re-selects from W's whole spectrum, not a truncation.
|
||||
W_res = layer.weight.data.float()
|
||||
U_old = layer.lora_U.float() # (d_out, r)
|
||||
S_old = layer.lora_S.float() # (r,)
|
||||
Vh_old = layer.lora_Vh.float() # (r, d_in)
|
||||
U_old = layer.lora_U.float()
|
||||
S_old = layer.lora_S.float()
|
||||
Vh_old = layer.lora_Vh.float()
|
||||
W_orig = W_res + (U_old * S_old.unsqueeze(0)) @ Vh_old
|
||||
|
||||
# Full SVD to score all dimensions
|
||||
U_full, S_full, Vh_full = torch.linalg.svd(W_orig, full_matrices=False)
|
||||
# score[i] = S[i] * mean|X @ Vh[i]| (Wanda: weight magnitude × activation magnitude)
|
||||
act_mag = (X @ Vh_full.T).abs().mean(dim=0) # (k,)
|
||||
proj = X.to(Vh_full) @ Vh_full.T # (N, r) input in S-coords (X CPU -> GPU here)
|
||||
if pool == "rms":
|
||||
act_mag = proj.pow(2).mean(dim=0).sqrt() # outlier-sensitive
|
||||
else:
|
||||
act_mag = proj.abs().mean(dim=0) # outlier-robust (original)
|
||||
scores = S_full * act_mag
|
||||
idx = scores.argsort(descending=True)[:r] # top-r by joint importance
|
||||
idx = idx.sort().values # stable ordering
|
||||
@@ -201,35 +176,23 @@ class AntiPaSTO:
|
||||
y: Float[T, '*B o'],
|
||||
) -> Float[T, '*B o']:
|
||||
cfg = layer._lora_cfg
|
||||
bs = int(cfg.block_size)
|
||||
max_angle = float(cfg.max_rotation_angle)
|
||||
rotate_basis = cfg.rotate_basis
|
||||
|
||||
U = layer.lora_U.to(x.dtype) # (d_out, r)
|
||||
S = layer.lora_S.to(x.dtype) # (r,)
|
||||
Vh = layer.lora_Vh.to(x.dtype) # (r, d_in)
|
||||
g = layer.lora_g.to(x.dtype) # (r,)
|
||||
coeff = float(cfg.coeff)
|
||||
|
||||
if rotate_basis == "none":
|
||||
U_eff, Vh_eff = U, Vh
|
||||
else:
|
||||
R_blocks = _build_rotation(layer.lora_rot_T.float(), bs, max_angle).to(x.dtype)
|
||||
n_blocks = R_blocks.shape[0] # R_blocks: (n, bs, bs)
|
||||
if rotate_basis == "V":
|
||||
Vh_blocks = rearrange(Vh, "(n a) i -> n a i", n=n_blocks)
|
||||
Vh_rot = einsum(R_blocks, Vh_blocks, "n a b, n b i -> n a i")
|
||||
Vh_eff = rearrange(Vh_rot, "n a i -> (n a) i")
|
||||
U_eff = U
|
||||
elif rotate_basis == "U":
|
||||
U_blocks = rearrange(U, "d (n b) -> d n b", n=n_blocks)
|
||||
U_rot = einsum(U_blocks, R_blocks, "d n b, n c b -> d n c")
|
||||
U_eff = rearrange(U_rot, "d n c -> d (n c)")
|
||||
Vh_eff = Vh
|
||||
else:
|
||||
raise ValueError(f"rotate_basis must be 'U', 'V', or 'none', got {rotate_basis!r}")
|
||||
if cfg.suppress_only:
|
||||
g = torch.clamp(g, max=0.0) # factor in (0,1]: attenuation only
|
||||
|
||||
# FIXME: try lora_delta_s as [r,k] this is because the main limit of this adapter is that it's under parametised here. `reduce(h @ U_eff.T, '... k -> ...'). But have to make sure it's not lienarly reducable to one adapter.
|
||||
S_eff = S + layer.lora_delta_s.to(x.dtype) # (r,)
|
||||
h = x @ Vh_eff.T # x @ Vh_eff.T
|
||||
h = h * S_eff # diag(S_eff)
|
||||
delta = h @ U_eff.T # @ U_eff.T
|
||||
return y + delta
|
||||
# S_eff = S * (1 + ELU(z)), z = coeff*g, 1+ELU(z) = exp(z) for z<=0 else 1+z.
|
||||
# Why 1+ELU and not the obvious alternatives:
|
||||
# linear S*(1+z) : z<-1 -> S_eff<0, a sign flip that drives incoherence.
|
||||
# exp S*exp(z) : unbounded, gradient self-amplifies (amplification blows up).
|
||||
# tanh bounded : arbitrary bound knob, saturation kills the gradient.
|
||||
# 1+ELU uses each in its safe regime: exp where it is bounded in (0,1]
|
||||
# (attenuation), linear where exp would diverge (amplification). >0 always.
|
||||
S_eff = S * (1.0 + torch.nn.functional.elu(coeff * g))
|
||||
|
||||
h = (x @ Vh.T) * S_eff # input in S-coords, reweighted
|
||||
return y + h @ U.T
|
||||
|
||||
@@ -0,0 +1,209 @@
|
||||
"""AntiPaSTO-Ablate: trainable directional ablation in the weight-SVD output basis.
|
||||
|
||||
A contractive sibling of antipasto.py: instead of reweighting the singular gains it
|
||||
projects out a learned direction in the output (U-side) singular basis.
|
||||
|
||||
W = U diag(S) Vh + W_res
|
||||
learn: c (r, k) ablation directions, alpha (k,) strengths in [0, 1]
|
||||
Chat = orthonormal(c) # k unit dirs in S-space
|
||||
h = (x @ Vh.T) * S # output S-coords = diag(S) Vh x
|
||||
h <- h - coeff * (h @ Chat) * alpha @ Chat.T # project the span out
|
||||
y = x @ W_res.T + h @ U.T
|
||||
|
||||
The core (I - alpha Chat Chat^T) is a contraction: eigenvalues 1-alpha along Chat,
|
||||
1 elsewhere, all in [0, 1]. It cannot amplify, so it cannot blow up -- the instability
|
||||
the multiplicative gain bounds away is structurally absent (and a contraction is the
|
||||
natural core to recurse). This is the trainable form of directional ablation (Arditi+
|
||||
2024): target residual writers (down_proj, o_proj) for the surgical regime, not all
|
||||
Linears.
|
||||
|
||||
Runtime: coeff is the per-call knob. coeff=0 -> identity; (0, 1] -> ablate; <0 adds the
|
||||
direction back (the side that can grow, so bound coeff there).
|
||||
|
||||
Refs: antipasto.py (gain sibling), directional ablation Arditi+ 2024 arXiv:2406.11717.
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from jaxtyping import Float
|
||||
from torch import nn, Tensor as T
|
||||
|
||||
from ..variant import register, ParamSpec
|
||||
from ..config import AdapterConfig, register_config
|
||||
|
||||
CalibrationBatch = dict | tuple | list | T
|
||||
CalibrationData = Iterable[CalibrationBatch]
|
||||
|
||||
ε = 1e-6
|
||||
|
||||
|
||||
@register_config
|
||||
@dataclass
|
||||
class AntiPaSTOAblateConfig(AdapterConfig):
|
||||
variant: str = "antipasto_ablate"
|
||||
r: int = 256 # top-r SVD captured (or |dS|-selected via group_init)
|
||||
k: int = 1 # number of ablation directions (rank of the projection)
|
||||
init_alpha: float = 0.05 # small >0 so c gets gradient at step 0
|
||||
coeff: float = 1.0 # runtime: 0=identity, (0,1]=ablate, <0=amplify (bound this side)
|
||||
# CorDA-orient the basis from input covariance (group_init, needs calibration_data).
|
||||
# The ablation is OUTPUT-side and CorDA's U stays orthonormal, so this is a clean
|
||||
# contraction; the win is at low r -- the data-oriented top-r captures the behavior
|
||||
# output direction that plain-SVD top-r drops (measured 1.00 vs 0.65 at r=16).
|
||||
cov_orient: bool = False
|
||||
cov_eps: float = 1e-3
|
||||
|
||||
|
||||
@register
|
||||
class AntiPaSTOAblate:
|
||||
name = "antipasto_ablate"
|
||||
|
||||
@staticmethod
|
||||
def param_specs(d_in, d_out, cfg):
|
||||
r, k = cfg.r, cfg.k
|
||||
return dict(
|
||||
lora_U=ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True),
|
||||
lora_S=ParamSpec((r,), init="zeros", trainable=False, as_buffer=True),
|
||||
lora_Vh=ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True),
|
||||
# Trainable: k ablation directions in S-space, and their strengths.
|
||||
lora_c=ParamSpec((r, k), init=lambda t: t.normal_(0, 1.0 / max(r, 1) ** 0.5)),
|
||||
lora_alpha=ParamSpec((k,), init=lambda t: t.fill_(float(cfg.init_alpha))),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def init(layer: nn.Module, cfg) -> None:
|
||||
if type(layer) is not nn.Linear:
|
||||
raise TypeError("AntiPaSTOAblate mutates layer.weight into W_res; nn.Linear only.")
|
||||
with torch.no_grad():
|
||||
W = layer.weight.data.float()
|
||||
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
|
||||
r = cfg.r
|
||||
Ur, Sr, Vhr = U[:, :r], S[:r], Vh[:r, :]
|
||||
layer.lora_U.copy_(Ur.to(layer.lora_U.dtype))
|
||||
layer.lora_S.copy_(Sr.to(layer.lora_S.dtype))
|
||||
layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype))
|
||||
W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
|
||||
layer.weight.data.copy_(W_res)
|
||||
# lora_c starts random here; group_init warm-starts it from the S-space output
|
||||
# variance when calibration_data is given (see group_init), else it trains from noise.
|
||||
|
||||
@staticmethod
|
||||
def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None:
|
||||
"""Warm-start each lora_c from calibration activations (and, if cov_orient,
|
||||
re-orient the frozen SVD by input covariance C=E[x xᵀ] first, CorDA-style).
|
||||
|
||||
lora_c is seeded to the top-k principal axes of the S-space OUTPUT coords
|
||||
h = diag(S) Vh x over the calibration set: the highest-energy output directions,
|
||||
where the loss-gradient on the ablation strength is largest, so lora_c starts in a
|
||||
high-gradient region instead of a near-orthogonal random one. NOTE this is the data
|
||||
VARIANCE direction, not a contrastive behavior direction -- this benchmark is SFT
|
||||
with no pos/neg split. For steering with contrastive pairs, seed lora_c from
|
||||
mean(h|pos) - mean(h|neg) instead (cf. steering-lite sspace extract).
|
||||
|
||||
Σ xxᵀ (d_in², heavy for down_proj) is only accumulated to orient; the warm-start
|
||||
alone (cov_orient=False) needs just the cheap r×r second moment Σ hhᵀ."""
|
||||
if calibration_data is None:
|
||||
return
|
||||
orient = bool(getattr(cfg, "cov_orient", False))
|
||||
layers = {name: layer for name, layer, _ in targets}
|
||||
gram: dict[str, T] = {} # Σ xxᵀ (d_in²), only when orienting
|
||||
mom: dict[str, T] = {} # Σ hhᵀ (r²), when not orienting (basis is fixed at init)
|
||||
cnt: dict[str, int] = {n: 0 for n in layers}
|
||||
|
||||
def make_hook(name):
|
||||
layer = layers[name]
|
||||
def _h(module, args, kwargs):
|
||||
x = rearrange(args[0].detach(), "... d -> (...) d").to(torch.float32).cpu()
|
||||
if orient:
|
||||
g = x.T @ x
|
||||
gram[name] = g if name not in gram else gram[name] + g
|
||||
else:
|
||||
h = (x @ layer.lora_Vh.float().cpu().T) * layer.lora_S.float().cpu()
|
||||
m = h.T @ h
|
||||
mom[name] = m if name not in mom else mom[name] + m
|
||||
cnt[name] += x.shape[0]
|
||||
return _h
|
||||
|
||||
handles = [layers[n].register_forward_pre_hook(make_hook(n), with_kwargs=True) for n in layers]
|
||||
try:
|
||||
was_training = model.training
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for batch in calibration_data:
|
||||
if isinstance(batch, dict):
|
||||
model(**batch)
|
||||
elif isinstance(batch, (list, tuple)):
|
||||
model(*batch)
|
||||
else:
|
||||
model(batch)
|
||||
if was_training:
|
||||
model.train()
|
||||
finally:
|
||||
for h in handles:
|
||||
h.remove()
|
||||
|
||||
r, k, eps = cfg.r, cfg.k, float(cfg.cov_eps)
|
||||
for name, layer in layers.items():
|
||||
if cnt[name] < r:
|
||||
raise RuntimeError(f"AntiPaSTOAblate at {name}: {cnt[name]} tokens, need >= r={r}")
|
||||
if orient:
|
||||
W_res = layer.weight.data.float().cpu()
|
||||
U_old, S_old, Vh_old = (layer.lora_U.float().cpu(),
|
||||
layer.lora_S.float().cpu(),
|
||||
layer.lora_Vh.float().cpu())
|
||||
W_orig = W_res + (U_old * S_old) @ Vh_old
|
||||
|
||||
C = gram[name] / cnt[name]
|
||||
lam, Q = torch.linalg.eigh(C)
|
||||
lam = lam.clamp_min(0) + eps
|
||||
Chalf = (Q * lam.sqrt()) @ Q.T
|
||||
Cinvhalf = (Q * lam.rsqrt()) @ Q.T
|
||||
Ut, St, Vht = torch.linalg.svd(W_orig @ Chalf, full_matrices=False)
|
||||
Ur = Ut[:, :r] # orthonormal output basis (ablation acts here)
|
||||
Sr = St[:r]
|
||||
Pr = Vht[:r] @ Cinvhalf # oblique input projector (input-side only)
|
||||
W_res_new = W_orig - (Ur * Sr) @ Pr
|
||||
with torch.no_grad():
|
||||
layer.lora_U.copy_(Ur.to(layer.lora_U))
|
||||
layer.lora_S.copy_(Sr.to(layer.lora_S))
|
||||
layer.lora_Vh.copy_(Pr.to(layer.lora_Vh)) # store P in the Vh slot
|
||||
layer.weight.data.copy_(W_res_new.to(layer.weight))
|
||||
# output S-space second moment in the (now oriented) basis: diag(S) P Σxxᵀ Pᵀ diag(S)
|
||||
SP = Sr[:, None] * Pr
|
||||
M = SP @ gram[name] @ SP.T
|
||||
else:
|
||||
M = mom[name] # (r, r) Σ hhᵀ in the init basis
|
||||
|
||||
c0 = torch.linalg.eigh(M).eigenvectors[:, -k:] # top-k principal dirs (orthonormal)
|
||||
with torch.no_grad():
|
||||
layer.lora_c.copy_(c0.to(layer.lora_c))
|
||||
|
||||
@staticmethod
|
||||
def _orthonormal(c: T) -> T:
|
||||
"""(r, k) -> (r, k) with orthonormal columns. k=1 is a plain normalize."""
|
||||
if c.shape[-1] == 1:
|
||||
return c / (c.norm(dim=0, keepdim=True) + ε)
|
||||
# geqrf has no bf16/fp16 kernel (CPU or CUDA); do the QR in fp32, cast back.
|
||||
q, _ = torch.linalg.qr(c.float()) # reduced QR; columns orthonormal
|
||||
return q.to(c.dtype)
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
layer: nn.Module,
|
||||
x: Float[T, '*B i'],
|
||||
y: Float[T, '*B o'],
|
||||
) -> Float[T, '*B o']:
|
||||
cfg = layer._lora_cfg
|
||||
U = layer.lora_U.to(x.dtype) # (d_out, r)
|
||||
S = layer.lora_S.to(x.dtype) # (r,)
|
||||
Vh = layer.lora_Vh.to(x.dtype) # (r, d_in)
|
||||
Chat = AntiPaSTOAblate._orthonormal(layer.lora_c.to(x.dtype)) # (r, k)
|
||||
alpha = layer.lora_alpha.to(x.dtype).clamp(0.0, 1.0) # (k,)
|
||||
coeff = float(cfg.coeff)
|
||||
|
||||
h = (x @ Vh.T) * S # (..., r) output S-coords
|
||||
proj = h @ Chat # (..., k) component along each dir
|
||||
# contractive removal: h <- h - coeff * Sum_j alpha_j (h . chat_j) chat_j
|
||||
h = h - coeff * (proj * alpha) @ Chat.T # (..., r)
|
||||
return y + h @ U.T
|
||||
@@ -0,0 +1,43 @@
|
||||
"""AntiPaSTO-ASVD: diagonal-covariance sibling of antipasto_corda.
|
||||
|
||||
Same frozen-basis bounded gain, but orients the SVD by the DIAGONAL of the input
|
||||
second moment (per-channel activation scale) instead of the full covariance:
|
||||
|
||||
M = diag(E[x_i^2]) vs CorDA's full C = E[x x^T]
|
||||
|
||||
This is Activation-aware SVD (Yuan+ 2023, arXiv:2312.05821): SVD(W diag(s)) with s a
|
||||
per-channel scale. It is NOT a sub-basis of CorDA -- diag(C)^{1/2} and C^{1/2} are
|
||||
different oblique rotations, so the top-r directions differ and either can win on a task.
|
||||
ASVD is the cheap arm: O(d_in) moment, no d_in x d_in matrix, no eigh. The head-to-head
|
||||
with antipasto_corda isolates whether the off-diagonal of C earns its init cost here.
|
||||
|
||||
Reuses antipasto_corda's buffers (U, S, P, g), plain-SVD init, gain forward, and the
|
||||
shared `_covariance_orient` (only the diag flag differs), so there is one copy of the
|
||||
math to keep in sync.
|
||||
|
||||
Refs: antipasto_corda.py (full-covariance sibling), ASVD arXiv:2312.05821.
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..variant import register
|
||||
from ..config import register_config
|
||||
from .antipasto_corda import AntiPaSTOCorDA, AntiPaSTOCorDAConfig, _covariance_orient
|
||||
|
||||
|
||||
@register_config
|
||||
@dataclass
|
||||
class AntiPaSTOASVDConfig(AntiPaSTOCorDAConfig):
|
||||
variant: str = "antipasto_asvd"
|
||||
|
||||
|
||||
@register
|
||||
class AntiPaSTOASVD:
|
||||
name = "antipasto_asvd"
|
||||
param_specs = staticmethod(AntiPaSTOCorDA.param_specs)
|
||||
init = staticmethod(AntiPaSTOCorDA.init)
|
||||
forward = staticmethod(AntiPaSTOCorDA.forward)
|
||||
|
||||
@staticmethod
|
||||
def group_init(model, targets, cfg, calibration_data) -> None:
|
||||
"""ASVD: re-orient by the diagonal of the input second moment (per-channel)."""
|
||||
_covariance_orient(model, targets, cfg, calibration_data, diag=True)
|
||||
@@ -0,0 +1,203 @@
|
||||
"""AntiPaSTO-CorDA: reweight in a covariance-oriented basis, not the weight basis.
|
||||
|
||||
Plain SVD sorts directions by weight gain ||W v|| on isotropic input. The behaviour
|
||||
you steer lives where the DATA has energy, off the top weight-singular axes. CorDA
|
||||
(Yang+ 2024, arXiv:2406.05223) re-orients the SVD by the input covariance, so the top-r
|
||||
directions move the output most on real activations.
|
||||
|
||||
C = E[x x^T] (+ eps I) # input second moment on calibration data
|
||||
C^{1/2}, C^{-1/2} via eigh(C)
|
||||
U S Vht = SVD(W C^{1/2}) # top-r is Eckart-Young best under x ~ N(0,C)
|
||||
P = Vht C^{-1/2} # (r, d_in) oblique input projector
|
||||
W = W_res + U_r diag(S_r) P_r # exact (residual carries the dropped tail)
|
||||
S_eff = S * (1 + ELU(coeff*g)) # same bounded gain as antipasto
|
||||
y = x @ W_res.T + ((x @ P.T) * S_eff) @ U.T
|
||||
|
||||
Identity at g=0 or coeff=0: S_eff=S. P is oblique (rows not orthonormal -- C^{-1/2}
|
||||
skews them); fine for gain reweighting since U stays orthonormal. Requires
|
||||
calibration_data (group_init raises otherwise).
|
||||
|
||||
Refs: antipasto.py (gain + selection sibling), CorDA arXiv:2406.05223.
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from jaxtyping import Float
|
||||
from torch import nn, Tensor as T
|
||||
|
||||
from ..variant import register, ParamSpec
|
||||
from ..config import AdapterConfig, register_config
|
||||
|
||||
CalibrationBatch = dict | tuple | list | T
|
||||
CalibrationData = Iterable[CalibrationBatch]
|
||||
|
||||
|
||||
@register_config
|
||||
@dataclass
|
||||
class AntiPaSTOCorDAConfig(AdapterConfig):
|
||||
variant: str = "antipasto_corda"
|
||||
r: int = 256
|
||||
cov_eps: float = 1e-3 # damping on C eigenvalues; guards C^{-1/2} on rare dirs
|
||||
coeff: float = 1.0 # runtime steer knob: 0=identity, scales trained g
|
||||
suppress_only: bool = False # clamp g<=0 (attenuate only) -- for coeff>=0;
|
||||
# coeff<0 inverts the product (coeff*g>=0) and re-amplifies.
|
||||
|
||||
|
||||
def _gain(S: T, g: T, coeff: float, suppress_only: bool) -> T:
|
||||
"""S_eff = S * (1 + ELU(coeff*g)); exp-bounded attenuation, linear amplification."""
|
||||
if suppress_only:
|
||||
g = g.clamp(max=0.0)
|
||||
return S * (1.0 + F.elu(coeff * g))
|
||||
|
||||
|
||||
def _covariance_orient(model, targets, cfg, calibration_data, *, diag: bool) -> None:
|
||||
"""Re-orient each target's SVD by its input second moment, then rewrite the frozen
|
||||
buffers (U, S, P) and residual weight in that basis. Shared by CorDA and ASVD:
|
||||
|
||||
diag=False -> CorDA: full C = E[x x^T] (cross-channel covariance, via eigh)
|
||||
diag=True -> ASVD: diag(C) = E[x_i^2] only (per-channel scale, O(d_in), no eigh)
|
||||
|
||||
The off-diagonal of C is the sole difference. g=0 stays exact identity either way --
|
||||
the reconstruction (W_res + U_r S_r P_r = W_orig) is lossless. Accumulated on CPU: a
|
||||
full C is d_in^2 fp32 per target and would crowd the GPU; the diagonal is a d_in vector.
|
||||
Call at attach-time, before training touches g (re-orienting g=0 is a no-op).
|
||||
"""
|
||||
if calibration_data is None:
|
||||
raise ValueError("covariance orientation requires calibration_data; got None.")
|
||||
|
||||
layers = {name: layer for name, layer, _ in targets}
|
||||
moment: dict[str, T] = {} # (d_in,d_in) full, or (d_in,) diagonal
|
||||
cnt: dict[str, int] = {n: 0 for n in layers}
|
||||
keep: dict[str, T] = {} # non-pad mask of the in-flight batch
|
||||
|
||||
def make_hook(name):
|
||||
def _h(module, args, kwargs):
|
||||
x = rearrange(args[0].detach(), "... d -> (...) d").to(torch.float32).cpu()
|
||||
if "mask" in keep:
|
||||
x = x[keep["mask"]] # drop padding positions (see loop below)
|
||||
m = x.pow(2).sum(0) if diag else x.T @ x
|
||||
moment[name] = m if name not in moment else moment[name] + m
|
||||
cnt[name] += x.shape[0] # real (non-pad) tokens accumulated
|
||||
return _h
|
||||
|
||||
handles = [layers[n].register_forward_pre_hook(make_hook(n), with_kwargs=True) for n in layers]
|
||||
try:
|
||||
was_training = model.training
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for batch in calibration_data:
|
||||
# Padding activations are not task-representative; mask them out of the moment
|
||||
# so the oriented basis reflects real tokens (CorDA/SVD-LLM official code does
|
||||
# the same). The mask is per-token, shared across all target layers in a batch.
|
||||
keep.pop("mask", None)
|
||||
if isinstance(batch, dict):
|
||||
if "attention_mask" in batch:
|
||||
keep["mask"] = rearrange(batch["attention_mask"], "... -> (...)").bool().cpu()
|
||||
model(**batch)
|
||||
elif isinstance(batch, (list, tuple)):
|
||||
model(*batch)
|
||||
else:
|
||||
model(batch)
|
||||
if was_training:
|
||||
model.train()
|
||||
finally:
|
||||
for h in handles:
|
||||
h.remove()
|
||||
|
||||
r, eps = cfg.r, float(cfg.cov_eps)
|
||||
for name, layer in layers.items():
|
||||
if cnt[name] < r:
|
||||
raise RuntimeError(f"covariance orient at {name}: {cnt[name]} tokens, need >= r={r}")
|
||||
# decomposition on CPU (where the moment lives); results copied back to device buffers.
|
||||
W_res = layer.weight.data.float().cpu()
|
||||
U_old, S_old, P_old = (layer.lora_U.float().cpu(),
|
||||
layer.lora_S.float().cpu(),
|
||||
layer.lora_P.float().cpu())
|
||||
W_orig = W_res + (U_old * S_old) @ P_old
|
||||
|
||||
if diag:
|
||||
c = (moment[name] / cnt[name]).clamp_min(0) + eps # (d_in,) per-channel scale
|
||||
Ut, St, Vht = torch.linalg.svd(W_orig * c.sqrt(), full_matrices=False) # @ diag(c^1/2)
|
||||
Pr = Vht[:r] * c.rsqrt() # @ diag(c^-1/2): oblique projector
|
||||
else:
|
||||
C = moment[name] / cnt[name] # (d_in,d_in)
|
||||
lam, Q = torch.linalg.eigh(C)
|
||||
lam = lam.clamp_min(0) + eps
|
||||
Mhalf = (Q * lam.sqrt()) @ Q.T # C^{1/2}
|
||||
Minvhalf = (Q * lam.rsqrt()) @ Q.T # C^{-1/2}
|
||||
Ut, St, Vht = torch.linalg.svd(W_orig @ Mhalf, full_matrices=False)
|
||||
Pr = Vht[:r] @ Minvhalf # (r, d_in) oblique projector
|
||||
# Quantize the frozen buffers to their stored dtype FIRST, then form the residual
|
||||
# against those exact (bf16) values. The forward reconstructs from the bf16 buffers,
|
||||
# so W_res + U_r S_r P_r = W_orig to one residual-rounding -- without this, the
|
||||
# residual is built from fp32 U/S/P and the forward also eats the U/S/P quantization
|
||||
# mismatch, so g=0 drifts further from identity.
|
||||
Ur = Ut[:, :r].to(layer.lora_U.dtype)
|
||||
Sr = St[:r].to(layer.lora_S.dtype)
|
||||
Pr = Pr.to(layer.lora_P.dtype)
|
||||
W_res_new = W_orig - (Ur.float() * Sr.float()) @ Pr.float()
|
||||
|
||||
with torch.no_grad():
|
||||
layer.lora_U.copy_(Ur)
|
||||
layer.lora_S.copy_(Sr)
|
||||
layer.lora_P.copy_(Pr)
|
||||
layer.weight.data.copy_(W_res_new.to(layer.weight))
|
||||
|
||||
|
||||
@register
|
||||
class AntiPaSTOCorDA:
|
||||
name = "antipasto_corda"
|
||||
|
||||
@staticmethod
|
||||
def param_specs(d_in, d_out, cfg):
|
||||
r = cfg.r
|
||||
return dict(
|
||||
lora_U=ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True),
|
||||
lora_S=ParamSpec((r,), init="zeros", trainable=False, as_buffer=True),
|
||||
# P replaces Vh: oblique covariance-oriented input projector.
|
||||
lora_P=ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True),
|
||||
# Trainable per-direction log-scale. init 0 -> 1+ELU(0)=1 -> exact identity.
|
||||
# No sign-symmetry hack needed (1+ELU is sign-preserving, basis frozen),
|
||||
# matching antipasto.py.
|
||||
lora_g=ParamSpec((r,), init="zeros"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def init(layer: nn.Module, cfg) -> None:
|
||||
"""Plain-SVD fallback so the adapter is valid before group_init. group_init
|
||||
refines P/U/S to the covariance-oriented basis when calibration_data is given."""
|
||||
if type(layer) is not nn.Linear:
|
||||
raise TypeError("AntiPaSTOCorDA mutates layer.weight into W_res; nn.Linear only.")
|
||||
with torch.no_grad():
|
||||
W = layer.weight.data.float()
|
||||
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
|
||||
r = cfg.r
|
||||
Ur, Sr, Vhr = U[:, :r], S[:r], Vh[:r, :]
|
||||
layer.lora_U.copy_(Ur.to(layer.lora_U.dtype))
|
||||
layer.lora_S.copy_(Sr.to(layer.lora_S.dtype))
|
||||
layer.lora_P.copy_(Vhr.to(layer.lora_P.dtype)) # P := Vh until oriented
|
||||
W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
|
||||
layer.weight.data.copy_(W_res)
|
||||
|
||||
@staticmethod
|
||||
def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None:
|
||||
"""CorDA: re-orient by the full input covariance C = E[x x^T] (cross-channel)."""
|
||||
_covariance_orient(model, targets, cfg, calibration_data, diag=False)
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
layer: nn.Module,
|
||||
x: Float[T, '*B i'],
|
||||
y: Float[T, '*B o'],
|
||||
) -> Float[T, '*B o']:
|
||||
cfg = layer._lora_cfg
|
||||
U = layer.lora_U.to(x.dtype) # (d_out, r)
|
||||
S = layer.lora_S.to(x.dtype) # (r,)
|
||||
P = layer.lora_P.to(x.dtype) # (r, d_in) oblique
|
||||
g = layer.lora_g.to(x.dtype) # (r,)
|
||||
S_eff = _gain(S, g, float(cfg.coeff), bool(cfg.suppress_only))
|
||||
h = (x @ P.T) * S_eff # (..., r)
|
||||
return y + h @ U.T
|
||||
@@ -0,0 +1,166 @@
|
||||
"""AntiPaSTO-DPLR: diagonal gain plus a low-rank mixing core in the frozen SVD basis.
|
||||
|
||||
antipasto's diagonal gain rescales each singular direction but cannot mix one into
|
||||
another. DPLR adds a trainable rank-k core that does, inside the frozen U/Vh basis:
|
||||
|
||||
W = U diag(S) Vh + W_res # frozen top-r SVD
|
||||
learn: g (r,) # diagonal gain
|
||||
A (k,r) kaiming, B (r,k) zero # low-rank mixing core
|
||||
p = x @ Vh.T # (r,) input in the frozen S-basis
|
||||
S_eff = S * (1 + ELU(coeff * g))
|
||||
h = p * S_eff + coeff * (p @ A.T) @ B.T # diagonal gain + rank-k mixing
|
||||
y = x @ W_res.T + h @ U.T
|
||||
|
||||
The rank-k term is LoRA's core (Hu+ 2021, arXiv:2106.09685) restricted to W's top-r
|
||||
subspace, ADDED to the gain rather than folded into diag(S): being independent of S, a
|
||||
unit step moves W by O(1) not O(S), so it has no singular-value amplification. Params
|
||||
= r + 2*r*k. Identity at init (B=0, g=0) and at coeff=0. Basis (U, Vh) stays frozen.
|
||||
|
||||
Refs: antipasto.py (diagonal sibling), lora.py (low-rank core), antipasto_corda.py
|
||||
(oriented basis -- composes with this core).
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, Literal
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from jaxtyping import Float
|
||||
from torch import nn, Tensor as T
|
||||
|
||||
from ..variant import register, ParamSpec
|
||||
from ..config import AdapterConfig, register_config
|
||||
|
||||
CalibrationBatch = dict | tuple | list | T
|
||||
CalibrationData = Iterable[CalibrationBatch]
|
||||
|
||||
|
||||
@register_config
|
||||
@dataclass
|
||||
class AntiPaSTODPLRConfig(AdapterConfig):
|
||||
variant: str = "antipasto_dplr"
|
||||
r: int = 256
|
||||
# Rank of the low-rank mixing core (LoRA's r, but inside the frozen subspace).
|
||||
# Params = r (gain) + 2*r*lora_rank. Requires 1 <= lora_rank <= r.
|
||||
lora_rank: int = 8
|
||||
suppress_only: bool = False # clamp the gain g<=0 (attenuate only); core unaffected.
|
||||
coeff: float = 1.0 # runtime knob: 0=identity, scales gain and core.
|
||||
act_pool: Literal["rms", "mean_abs"] = "rms" # group_init selection, see antipasto.
|
||||
|
||||
|
||||
@register
|
||||
class AntiPaSTODPLR:
|
||||
name = "antipasto_dplr"
|
||||
|
||||
@staticmethod
|
||||
def param_specs(d_in, d_out, cfg):
|
||||
r, k = cfg.r, cfg.lora_rank
|
||||
if not 0 < k <= r:
|
||||
raise ValueError(f"antipasto_dplr needs 0 < lora_rank({k}) <= r({r}).")
|
||||
return dict(
|
||||
lora_U=ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True),
|
||||
lora_S=ParamSpec((r,), init="zeros", trainable=False, as_buffer=True),
|
||||
lora_Vh=ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True),
|
||||
# Diagonal gain (== antipasto). init 0 -> 1+ELU(0)=1 -> identity.
|
||||
lora_g=ParamSpec((r,), init="zeros"),
|
||||
# Low-rank core B@A in the frozen subspace. A down (r->k), B up (k->r).
|
||||
# B=0 at init -> core=0 -> identity (LoRA convention).
|
||||
lora_A=ParamSpec((k, r), init="kaiming"),
|
||||
lora_B=ParamSpec((r, k), init="zeros"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def init(layer: nn.Module, cfg) -> None:
|
||||
if type(layer) is not nn.Linear:
|
||||
raise TypeError("AntiPaSTODPLR mutates layer.weight into W_res; nn.Linear only.")
|
||||
with torch.no_grad():
|
||||
W = layer.weight.data.float()
|
||||
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
|
||||
r = cfg.r
|
||||
Ur, Sr, Vhr = U[:, :r], S[:r], Vh[:r, :]
|
||||
layer.lora_U.copy_(Ur.to(layer.lora_U.dtype))
|
||||
layer.lora_S.copy_(Sr.to(layer.lora_S.dtype))
|
||||
layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype))
|
||||
W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
|
||||
layer.weight.data.copy_(W_res)
|
||||
|
||||
@staticmethod
|
||||
def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None:
|
||||
"""Wanda-style re-selection of the top-r directions, identical to antipasto.
|
||||
Runs before training while g and B are still zero, so the core contributes
|
||||
nothing and re-selecting the basis is a no-op on the adapter output."""
|
||||
if calibration_data is None:
|
||||
return
|
||||
|
||||
layers = {name: layer for name, layer, _ in targets}
|
||||
captured: dict[str, list[T]] = {n: [] for n in layers}
|
||||
|
||||
def make_hook(name):
|
||||
def _h(module, args, kwargs):
|
||||
x = args[0].detach()
|
||||
captured[name].append(rearrange(x, "... d -> (...) d").to(torch.float32).cpu())
|
||||
return _h
|
||||
|
||||
handles = [layers[n].register_forward_pre_hook(make_hook(n), with_kwargs=True) for n in layers]
|
||||
try:
|
||||
was_training = model.training
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for batch in calibration_data:
|
||||
if isinstance(batch, dict):
|
||||
model(**batch)
|
||||
elif isinstance(batch, (list, tuple)):
|
||||
model(*batch)
|
||||
else:
|
||||
model(batch)
|
||||
if was_training:
|
||||
model.train()
|
||||
finally:
|
||||
for h in handles:
|
||||
h.remove()
|
||||
|
||||
r, pool = cfg.r, cfg.act_pool
|
||||
for name, layer in layers.items():
|
||||
X = torch.cat(captured[name], dim=0)
|
||||
if X.shape[0] < r:
|
||||
raise RuntimeError(f"AntiPaSTODPLR at {name}: {X.shape[0]} tokens, need >= r={r}")
|
||||
# Rebuild the FULL W exactly (W_res + stored top-r), then re-select top-r.
|
||||
W_res = layer.weight.data.float()
|
||||
W_orig = W_res + (layer.lora_U.float() * layer.lora_S.float()) @ layer.lora_Vh.float()
|
||||
U_full, S_full, Vh_full = torch.linalg.svd(W_orig, full_matrices=False)
|
||||
proj = X.to(Vh_full) @ Vh_full.T
|
||||
act_mag = proj.pow(2).mean(0).sqrt() if pool == "rms" else proj.abs().mean(0)
|
||||
idx = (S_full * act_mag).argsort(descending=True)[:r].sort().values
|
||||
Ur, Sr, Vhr = U_full[:, idx], S_full[idx], Vh_full[idx]
|
||||
W_res_new = (W_orig - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
|
||||
with torch.no_grad():
|
||||
layer.lora_U.copy_(Ur.to(layer.lora_U))
|
||||
layer.lora_S.copy_(Sr.to(layer.lora_S))
|
||||
layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh))
|
||||
layer.weight.data.copy_(W_res_new)
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
layer: nn.Module,
|
||||
x: Float[T, '*B i'],
|
||||
y: Float[T, '*B o'],
|
||||
) -> Float[T, '*B o']:
|
||||
cfg = layer._lora_cfg
|
||||
U = layer.lora_U.to(x.dtype) # (d_out, r)
|
||||
S = layer.lora_S.to(x.dtype) # (r,)
|
||||
Vh = layer.lora_Vh.to(x.dtype) # (r, d_in)
|
||||
g = layer.lora_g.to(x.dtype) # (r,)
|
||||
A = layer.lora_A.to(x.dtype) # (k, r)
|
||||
B = layer.lora_B.to(x.dtype) # (r, k)
|
||||
coeff = float(cfg.coeff)
|
||||
|
||||
if cfg.suppress_only:
|
||||
g = torch.clamp(g, max=0.0)
|
||||
|
||||
p = x @ Vh.T # (..., r) = Vh x (unscaled)
|
||||
S_eff = S * (1.0 + F.elu(coeff * g)) # diagonal gain (see antipasto.py)
|
||||
# Diagonal part scales each direction; low-rank part B@A mixes across the
|
||||
# subspace. Additive (not * diag(S)), so the core is S-independent: a unit
|
||||
# step in B@A moves W by O(1), not O(S) -- no S-amplification edge.
|
||||
h = p * S_eff + coeff * (p @ A.T) @ B.T # (..., r)
|
||||
return y + h @ U.T
|
||||
@@ -0,0 +1,227 @@
|
||||
"""AntiPaSTO-Rot: SVD adapter with learnable singular-value deltas + a block-diagonal
|
||||
Cayley rotation of the frozen basis. The rotation arm vs antipasto.py's gain-only core.
|
||||
|
||||
wassname 2026 https://arxiv.org/abs/2601.07473
|
||||
|
||||
W = U diag(S) Vh + W_res (top-r SVD; W_res = W - U_r S_r Vh_r)
|
||||
learn: delta_s (r,), rot_T (n_blocks, bs(bs-1)/2)
|
||||
R = block_diag(Cayley(skew(rot_T))); Vh_eff = R @ Vh (or U_eff = U @ R.T)
|
||||
y = x @ W_res.T + ((x @ Vh_eff.T) * (S + delta_s)) @ U_eff.T
|
||||
|
||||
Identity at t=0: rot_T=0 -> R=I, delta_s~4e-4 -> y ~ x @ W^T (tiny positive bias on
|
||||
delta_s breaks sign symmetry; rotation alone can't).
|
||||
|
||||
Refs:
|
||||
- paper: https://github.com/wassname/AntiPaSTO
|
||||
- lite port of: https://github.com/wassname/antipasto3
|
||||
(offline: docs/refs/antipasto3_svd_adapter.py)
|
||||
"""
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, Literal
|
||||
|
||||
import torch
|
||||
from einops import einsum, rearrange
|
||||
from jaxtyping import Float
|
||||
from torch import nn, Tensor as T
|
||||
|
||||
from ..variant import register, ParamSpec
|
||||
from ..config import AdapterConfig, register_config
|
||||
|
||||
CalibrationBatch = dict | tuple | list | T
|
||||
CalibrationData = Iterable[CalibrationBatch]
|
||||
|
||||
|
||||
@register_config
|
||||
@dataclass
|
||||
class AntiPaSTORotConfig(AdapterConfig):
|
||||
variant: str = "antipasto_rot"
|
||||
# Higher default than LoRA (r=8) since trainable params scale as r + r/bs*bs*(bs-1)/2, not r*(d_in+d_out).
|
||||
r: int = 256
|
||||
# Block size for the block-diagonal Cayley rotation. r must be divisible by it.
|
||||
block_size: int = 4
|
||||
# Cayley map saturation: bounds rotation angle to ~max_rotation_angle radians.
|
||||
max_rotation_angle: float = 0.5
|
||||
# Which singular basis to rotate: 'V' (input), 'U' (output), 'both', or 'none'.
|
||||
rotate_basis: Literal["V", "U", "both", "none"] = "V"
|
||||
|
||||
|
||||
def _cayley(skew: torch.Tensor) -> torch.Tensor:
|
||||
"""R = (I - X)^-1 (I + X) for X = skew/2; preserves orthogonality."""
|
||||
bs = skew.shape[-1]
|
||||
eye = torch.eye(bs, dtype=skew.dtype, device=skew.device).expand_as(skew)
|
||||
X = skew / 2
|
||||
return torch.linalg.solve(eye - X, eye + X)
|
||||
|
||||
|
||||
def _build_rotation(rot_T: torch.Tensor, bs: int, max_angle: float) -> torch.Tensor:
|
||||
"""rot_T: (n_blocks, bs*(bs-1)/2) -> R: (n_blocks, bs, bs) Cayley rotation."""
|
||||
n_blocks, _ = rot_T.shape
|
||||
rows, cols = torch.triu_indices(bs, bs, offset=1, device=rot_T.device).unbind(0)
|
||||
A = torch.zeros(n_blocks, bs, bs, dtype=rot_T.dtype, device=rot_T.device)
|
||||
A[:, rows, cols] = rot_T
|
||||
A = 0.5 * (A - A.transpose(-1, -2))
|
||||
a_limit = 2.0 * math.tan(max_angle / 2.0)
|
||||
A = a_limit * torch.tanh(A / a_limit)
|
||||
return _cayley(A)
|
||||
|
||||
|
||||
@register
|
||||
class AntiPaSTORot:
|
||||
name = "antipasto_rot"
|
||||
|
||||
@staticmethod
|
||||
def param_specs(d_in, d_out, cfg):
|
||||
r = cfg.r
|
||||
bs = int(cfg.block_size)
|
||||
if r % bs != 0:
|
||||
raise ValueError(f"AntiPaSTORot requires r={r} divisible by block_size={bs}")
|
||||
specs = dict(
|
||||
# Frozen SVD components captured at init.
|
||||
lora_U=ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True),
|
||||
lora_S=ParamSpec((r,), init="zeros", trainable=False, as_buffer=True),
|
||||
lora_Vh=ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True),
|
||||
# Trainable: per-singular-value delta.
|
||||
# antipasto3 uses 4e-4 + N(0, 4e-4): small positive bias breaks sign
|
||||
# symmetry (rotation alone can't); zero-init works but trains slower.
|
||||
lora_delta_s=ParamSpec((r,), init=lambda t: t.normal_(0, 4e-4).add_(4e-4)),
|
||||
)
|
||||
if cfg.rotate_basis != "none":
|
||||
n_blocks = r // bs
|
||||
n_triu = bs * (bs - 1) // 2
|
||||
specs["lora_rot_T"] = ParamSpec((n_blocks, n_triu), init="zeros")
|
||||
if cfg.rotate_basis == "both":
|
||||
# 'both' rotates V (lora_rot_T) and U independently; lora_rot_T_u is the U-side.
|
||||
specs["lora_rot_T_u"] = ParamSpec((n_blocks, n_triu), init="zeros")
|
||||
return specs
|
||||
|
||||
@staticmethod
|
||||
def init(layer: nn.Module, cfg) -> None:
|
||||
if type(layer) is not nn.Linear:
|
||||
raise TypeError(
|
||||
"AntiPaSTORot mutates layer.weight into W_res (like PiSSA), so v1 "
|
||||
"only supports plain nn.Linear, not bnb 4/8-bit."
|
||||
)
|
||||
with torch.no_grad():
|
||||
W = layer.weight.data.float()
|
||||
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
|
||||
r = cfg.r
|
||||
Ur, Sr, Vhr = U[:, :r], S[:r], Vh[:r, :]
|
||||
layer.lora_U.copy_(Ur.to(layer.lora_U.dtype))
|
||||
layer.lora_S.copy_(Sr.to(layer.lora_S.dtype))
|
||||
layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype))
|
||||
W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
|
||||
layer.weight.data.copy_(W_res)
|
||||
# group_init() refines this to input-aligned directions if calibration_data is given.
|
||||
|
||||
@staticmethod
|
||||
def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None:
|
||||
"""Wanda-style data-driven dimension selection within the weight SVD.
|
||||
|
||||
init() picks the top-r singular dimensions by S alone (PiSSA-style).
|
||||
group_init() re-selects based on S[i] * mean|X @ Vh[i]|: dimensions
|
||||
that are both large in W AND active given real inputs.
|
||||
|
||||
If calibration_data is None the weight-SVD init from init() is kept.
|
||||
"""
|
||||
if calibration_data is None:
|
||||
return
|
||||
|
||||
layers = {name: layer for name, layer, _ in targets}
|
||||
captured: dict[str, list[T]] = {n: [] for n in layers}
|
||||
|
||||
def make_hook(name):
|
||||
def _h(module, args, kwargs):
|
||||
x = args[0].detach()
|
||||
captured[name].append(rearrange(x, "... d -> (...) d").to(torch.float32).cpu())
|
||||
return _h
|
||||
|
||||
handles = [
|
||||
layers[n].register_forward_pre_hook(make_hook(n), with_kwargs=True)
|
||||
for n in layers
|
||||
]
|
||||
try:
|
||||
was_training = model.training
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for batch in calibration_data:
|
||||
if isinstance(batch, dict):
|
||||
model(**batch)
|
||||
elif isinstance(batch, (list, tuple)):
|
||||
model(*batch)
|
||||
else:
|
||||
model(batch)
|
||||
if was_training:
|
||||
model.train()
|
||||
finally:
|
||||
for h in handles:
|
||||
h.remove()
|
||||
|
||||
r = cfg.r
|
||||
for name, layer in layers.items():
|
||||
X = torch.cat(captured[name], dim=0) # (N, d_in)
|
||||
if X.shape[0] < r:
|
||||
raise RuntimeError(
|
||||
f"AntiPaSTORot at {name}: only {X.shape[0]} calibration tokens, need >= r={r}"
|
||||
)
|
||||
|
||||
# Recover W_orig: init() wrote W_res into layer.weight and stored top-r components
|
||||
W_res = layer.weight.data.float()
|
||||
U_old = layer.lora_U.float() # (d_out, r)
|
||||
S_old = layer.lora_S.float() # (r,)
|
||||
Vh_old = layer.lora_Vh.float() # (r, d_in)
|
||||
W_orig = W_res + (U_old * S_old.unsqueeze(0)) @ Vh_old
|
||||
|
||||
# Full SVD to score all dimensions
|
||||
U_full, S_full, Vh_full = torch.linalg.svd(W_orig, full_matrices=False)
|
||||
# score[i] = S[i] * mean|X @ Vh[i]| (Wanda: weight magnitude × activation magnitude)
|
||||
act_mag = (X.to(Vh_full) @ Vh_full.T).abs().mean(dim=0) # (k,) -- X captured on CPU
|
||||
scores = S_full * act_mag
|
||||
idx = scores.argsort(descending=True)[:r] # top-r by joint importance
|
||||
idx = idx.sort().values # stable ordering
|
||||
|
||||
Ur, Sr, Vhr = U_full[:, idx], S_full[idx], Vh_full[idx]
|
||||
W_res_new = (W_orig - (Ur * Sr.unsqueeze(0)) @ Vhr).to(layer.weight.dtype)
|
||||
|
||||
with torch.no_grad():
|
||||
layer.lora_U.copy_(Ur.to(layer.lora_U))
|
||||
layer.lora_S.copy_(Sr.to(layer.lora_S))
|
||||
layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh))
|
||||
layer.weight.data.copy_(W_res_new)
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
layer: nn.Module,
|
||||
x: Float[T, '*B i'],
|
||||
y: Float[T, '*B o'],
|
||||
) -> Float[T, '*B o']:
|
||||
cfg = layer._lora_cfg
|
||||
bs = int(cfg.block_size)
|
||||
max_angle = float(cfg.max_rotation_angle)
|
||||
rotate_basis = cfg.rotate_basis
|
||||
|
||||
U = layer.lora_U.to(x.dtype) # (d_out, r)
|
||||
S = layer.lora_S.to(x.dtype) # (r,)
|
||||
Vh = layer.lora_Vh.to(x.dtype) # (r, d_in)
|
||||
|
||||
if rotate_basis == "none":
|
||||
U_eff, Vh_eff = U, Vh
|
||||
else:
|
||||
R = _build_rotation(layer.lora_rot_T.float(), bs, max_angle).to(x.dtype)
|
||||
n_blocks = R.shape[0] # R: (n, bs, bs)
|
||||
U_eff, Vh_eff = U, Vh
|
||||
# 'V'/'U' rotate that one basis with lora_rot_T; 'both' rotates V with
|
||||
# lora_rot_T and U with a separate lora_rot_T_u (independent rotations).
|
||||
if rotate_basis in ("V", "both"):
|
||||
Vh_blocks = rearrange(Vh, "(n a) i -> n a i", n=n_blocks)
|
||||
Vh_eff = rearrange(einsum(R, Vh_blocks, "n a b, n b i -> n a i"), "n a i -> (n a) i")
|
||||
if rotate_basis in ("U", "both"):
|
||||
R_u = _build_rotation(layer.lora_rot_T_u.float(), bs, max_angle).to(x.dtype) if rotate_basis == "both" else R
|
||||
U_blocks = rearrange(U, "d (n b) -> d n b", n=n_blocks)
|
||||
U_eff = rearrange(einsum(U_blocks, R_u, "d n b, n c b -> d n c"), "d n c -> d (n c)")
|
||||
|
||||
S_eff = S + layer.lora_delta_s.to(x.dtype) # (r,)
|
||||
h = x @ Vh_eff.T # x @ Vh_eff.T
|
||||
h = h * S_eff # diag(S_eff)
|
||||
delta = h @ U_eff.T # @ U_eff.T
|
||||
return y + delta
|
||||
@@ -63,12 +63,15 @@ class EVA:
|
||||
# Collect input activations per target via forward hooks.
|
||||
layers = {name: layer for name, layer, _ in targets}
|
||||
captured: dict[str, list[T]] = {n: [] for n in layers}
|
||||
keep: dict[str, T] = {} # non-pad mask of the in-flight batch
|
||||
|
||||
def make_hook(name):
|
||||
def _h(module, args, kwargs):
|
||||
# signature: pre-forward, args[0] is the input tensor
|
||||
x = args[0].detach()
|
||||
captured[name].append(rearrange(x, "... d -> (...) d").to(torch.float32).cpu())
|
||||
x = rearrange(args[0].detach(), "... d -> (...) d").to(torch.float32).cpu()
|
||||
if "mask" in keep:
|
||||
x = x[keep["mask"]] # drop padding positions (see loop below)
|
||||
captured[name].append(x)
|
||||
return _h
|
||||
|
||||
handles = [
|
||||
@@ -80,7 +83,12 @@ class EVA:
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for batch in calibration_data:
|
||||
# Padding activations are not task-representative; mask them out of the
|
||||
# PCA so the basis reflects real tokens (matches antipasto_corda).
|
||||
keep.pop("mask", None)
|
||||
if isinstance(batch, dict):
|
||||
if "attention_mask" in batch:
|
||||
keep["mask"] = rearrange(batch["attention_mask"], "... -> (...)").bool().cpu()
|
||||
model(**batch)
|
||||
elif isinstance(batch, (list, tuple)):
|
||||
model(*batch)
|
||||
|
||||
@@ -14,7 +14,6 @@ from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
from dataclasses import replace
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
@@ -22,20 +21,25 @@ import torch
|
||||
|
||||
import lora_lite as ll
|
||||
|
||||
_SCRIPTS = Path(__file__).resolve().parent.parent / "scripts"
|
||||
sys.path.insert(0, str(_SCRIPTS)) # benchmark does `from _cost import ...` (sibling module)
|
||||
SPEC = importlib.util.spec_from_file_location(
|
||||
"metamath_benchmark",
|
||||
Path(__file__).resolve().parent.parent / "scripts" / "metamath_gsm8k_benchmark.py",
|
||||
_SCRIPTS / "metamath_gsm8k_benchmark.py",
|
||||
)
|
||||
benchmark = importlib.util.module_from_spec(SPEC)
|
||||
sys.modules[SPEC.name] = benchmark
|
||||
SPEC.loader.exec_module(benchmark)
|
||||
|
||||
|
||||
VARIANTS = ["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "road"]
|
||||
VARIANTS = ["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva",
|
||||
"antipasto", "antipasto_rot", "antipasto_ablate", "antipasto_corda",
|
||||
"antipasto_asvd", "antipasto_dplr", "road"]
|
||||
# Variants that fail loud when attached on a bnb-loaded base (read dense weight in init).
|
||||
# delora/eva also read weight but currently silently dequant -- they produce sane attach,
|
||||
# so we don't expect a raise from them in the attach-only smoke.
|
||||
BNB_RAISERS = {"pissa", "dora", "antipasto"}
|
||||
BNB_RAISERS = {"pissa", "dora", "antipasto", "antipasto_rot", "antipasto_ablate",
|
||||
"antipasto_corda", "antipasto_dplr"}
|
||||
TINY_MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM"
|
||||
|
||||
HAS_CUDA = torch.cuda.is_available()
|
||||
@@ -57,6 +61,7 @@ def quick_cfg(variant: str, tmp_path: Path, quantization: str = "none") -> "benc
|
||||
quantization=quantization,
|
||||
r=4,
|
||||
alpha=8,
|
||||
antipasto_lora_rank=2, # antipasto_dplr needs 0 < lora_rank <= r (r=4 here)
|
||||
target_name=target_name,
|
||||
layers="all",
|
||||
steps=2,
|
||||
@@ -75,8 +80,6 @@ def quick_cfg(variant: str, tmp_path: Path, quantization: str = "none") -> "benc
|
||||
log_every=1000,
|
||||
output_dir=tmp_path / "out",
|
||||
)
|
||||
if variant == "antipasto":
|
||||
cfg = replace(cfg, alpha=4) # block_size=4 -> need r % 4 == 0
|
||||
return cfg
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user