diff --git a/justfile b/justfile index 4d6f30a..fa98cb1 100644 --- a/justfile +++ b/justfile @@ -75,7 +75,7 @@ metamath-queue variant="lora" steps="5000" model="Qwen/Qwen3-0.6B-Base": # Run a single MetaMathQA->GSM8K benchmark for a given variant. # Per-variant lr / target-name defaults are baked in here. -bench-variant model variant steps="5000": +bench-variant model variant steps="5000" block="8": #!/usr/bin/env bash set -euo pipefail lr=1e-4 @@ -100,6 +100,7 @@ bench-variant model variant steps="5000": --steps {{steps}} \ --lr "$lr" \ --target-name "$target" \ + --antipasto-block {{block}} \ --layers all --r "$r" --alpha "$alpha" metamath-queue-all model="Qwen/Qwen3-0.6B-Base" steps="5000" variants="lora pissa delora dora hra ia3 ia3_ff eva antipasto": diff --git a/scripts/metamath_gsm8k_benchmark.py b/scripts/metamath_gsm8k_benchmark.py index a335e54..97271a5 100644 --- a/scripts/metamath_gsm8k_benchmark.py +++ b/scripts/metamath_gsm8k_benchmark.py @@ -533,6 +533,9 @@ def run(args: BenchmarkConfig) -> dict[str, Any]: dtype = getattr(torch, args.torch_dtype) run_commit = current_git_commit() run_id = f"{args.model.replace('/', '--')}__{args.variant}__s{args.steps}__seed{args.seed}" + # arrow's capacity is set by block, not r, so keep block-sweep runs from colliding. + if args.variant == "antipasto_arrow" and args.antipasto_block != 8: + run_id += f"__b{args.antipasto_block}" out_dir = args.output_dir / run_id out_dir.mkdir(parents=True, exist_ok=True) diff --git a/src/lora_lite/variants/antipasto.py b/src/lora_lite/variants/antipasto.py index 59aaa43..71fe778 100644 --- a/src/lora_lite/variants/antipasto.py +++ b/src/lora_lite/variants/antipasto.py @@ -1,42 +1,22 @@ -"""AntiPaSTO: SVD steering with learnable, bounded singular-value reweighting. +"""AntiPaSTO: learnable bounded reweighting of frozen SVD singular values. 
wassname 2026 https://arxiv.org/abs/2601.07473 - W = U diag(S) Vh + W_res (top-r SVD; W_res = W - U_r S_r Vh_r) - learn: g (r,) per-singular-direction gain log/lin-scale - S_eff = S * (1 + ELU(coeff * g)) exp(.) for g<=0, 1+. for g>0 - suppress_only: clamp g<=0 -> factor in (0,1], attenuation only + W = U diag(S) Vh + W_res # top-r SVD; W_res = W - U_r S_r Vh_r, frozen + learn: g (r,) # per-direction gain + S_eff = S * (1 + ELU(coeff * g)) # exp(z) for z<0 (bounded), 1+z for z>0 y = x @ W_res.T + ((x @ Vh.T) * S_eff) @ U.T -Identity at g=0 (or coeff=0): 1+ELU(0)=1 exactly, so S_eff = S and the output is -x @ W^T up to the one-time SVD-residual rounding. No additive sign-symmetry hack -needed: the basis is frozen, so the direction sign is fixed and exp/(1+.) is -sign-preserving. The 1+ELU shape is chosen over linear (sign-flips at g<-1), exp -(amplification blows up), and tanh (arbitrary bound) -- see forward() for why. + suppress_only: clamp g<=0 -> S_eff in (0, S] for coeff>=0 (attenuation only). + coeff: runtime scale; 0 = identity, <0 swaps amplify/suppress. -Changes vs the rotation version this replaces: - - Rotation dropped. Rotating Vh/U leaves the interpretable singular basis (the - SVD-direction / Conjecture property), which is the entire point of steering in - S-space, and the Cayley solve was numerically finicky. The basis is now frozen; - the only learned object is the per-direction gain. If you later want - cross-direction mixing, add a *fixed-basis* core U M Vh (M trainable, U/Vh frozen) - rather than rotating -- that keeps the directions interpretable. It is also far - cheaper than PiSSA: a dense r x r core is r^2 params (~= a rank-8 LoRA at r=256), - versus PiSSA's free A,B at r*(d_in+d_out), which drifts off the SVD basis. - - Additive delta_s -> bounded multiplicative S * (1 + ELU(coeff*g)).
Multiplicative - is "scaled by S" (uniform *relative* control over an orders-of-magnitude spectrum), - stays positive (no S_eff<0 sign-flip -> no incoherence from that path), and the - 1+ELU shape stops the exp blowup. The 4e-4 sign-symmetry hack is gone. - - suppress_only = clamp g<=0 -> factor in (0,1]: attenuation only, structurally - cannot blow up. Matches the eval-awareness use case (turn a direction down). - - coeff: runtime steering scalar (0 = identity, <0 inverts). The per-call alpha - the rotation version lacked. - - group_init activation pooling is configurable: 'rms' weights outliers (ASVD - intuition), 'mean_abs' is the original outlier-robust pooling. +Identity at g=0 or coeff=0: 1+ELU(0)=1, so S_eff=S (up to the bf16 SVD round-trip). +The basis (U, Vh) is frozen, so the singular directions stay interpretable and only +the gain is learned. See forward() for why 1+ELU over linear/exp/tanh. Refs: - paper: https://github.com/wassname/AntiPaSTO - - sibling (whitened, rotation-free, mean-diff): steering-lite/.../sspace.py + - sibling (whitened, mean-diff): steering-lite/.../sspace.py """ from dataclasses import dataclass from typing import Iterable, Literal @@ -107,14 +87,15 @@ class AntiPaSTO: @staticmethod def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None: - """Wanda-style, data-driven dimension selection within the weight SVD. + """Data-driven re-selection of which top-r singular directions to keep. - init() picks the top-r singular dimensions by S alone (PiSSA-style). - group_init() re-selects by score[i] = S[i] * pool|X @ Vh[i]|: dimensions - that are both large in W AND active on real inputs. pool = 'rms' (outlier- - sensitive, the ASVD intuition that activation outliers carry signal) or - 'mean_abs' (the original, outlier-robust). If calibration_data is None the - weight-SVD init from init() is kept. 
+ init(): top-r by S alone (PiSSA-style) + group_init(): top-r by score[i] = S[i] * pool|X @ Vh[i]| (Wanda/ASVD) + pool = 'rms' (outlier-sensitive) | 'mean_abs' (outlier-robust) + + This re-RANKS W's own singular vectors by activation; it does NOT re-orient + the basis (that is CorDA -> antipasto_corda.py). So the kept directions are + still plain weight-SVD directions, just a better subset. None -> keep init(). """ if calibration_data is None: return @@ -158,7 +139,9 @@ class AntiPaSTO: f"AntiPaSTO at {name}: only {X.shape[0]} calibration tokens, need >= r={r}" ) - # Recover W_orig: init() wrote W_res into layer.weight and stored top-r. + # Rebuild the FULL W: init() stored the exact top-r it subtracted, so + # W_res + U_r S_r Vh_r == W (full rank, not a cropped matrix). The SVD + # below therefore re-selects from W's whole spectrum, not a truncation. W_res = layer.weight.data.float() U_old = layer.lora_U.float() S_old = layer.lora_S.float() @@ -200,21 +183,13 @@ class AntiPaSTO: if cfg.suppress_only: g = torch.clamp(g, max=0.0) # factor in (0,1]: attenuation only - # Per-direction reweighting: S_eff = S * (1 + ELU(coeff * g)). - # 1 + ELU(z) = exp(z) for z<=0, 1+z for z>0. - # Why this and not the obvious ones (all of which we tried): - # linear S*(1+z) : constant gradient (stable), but z<-1 -> S_eff<0, - # a sign flip that drives incoherence. Unstable in - # the negatives. - # exp S*exp(z) : positive, but unbounded and the gradient self- - # amplifies (d/dz exp = exp), so amplification blows up. - # tanh S*exp(c*tanh z): bounded, but c is an arbitrary free knob with no - # principled value, and saturation kills the gradient. - # 1+ELU : uses each in its safe regime -- exp only where it is - # bounded in (0,1] (attenuation, cannot go negative), - # linear where exp would diverge (amplification, const - # gradient). C1 at z=0 (both -> 1, slope 1); >0 always. - # coeff=0 or g=0 -> S_eff = S (identity). coeff<0 swaps amplify/suppress. 
+ # S_eff = S * (1 + ELU(z)), z = coeff*g, 1+ELU(z) = exp(z) for z<=0 else 1+z. + # Why 1+ELU and not the obvious alternatives: + # linear S*(1+z) : z<-1 -> S_eff<0, a sign flip that drives incoherence. + # exp S*exp(z) : unbounded, gradient self-amplifies (amplification blows up). + # tanh bounded : arbitrary bound knob, saturation kills the gradient. + # 1+ELU uses each in its safe regime: exp where it is bounded in (0,1] + # (attenuation), linear where exp would diverge (amplification). >0 always. S_eff = S * (1.0 + torch.nn.functional.elu(coeff * g)) h = (x @ Vh.T) * S_eff # input in S-coords, reweighted diff --git a/src/lora_lite/variants/antipasto_ablate.py b/src/lora_lite/variants/antipasto_ablate.py index 34d26a8..be8d1ff 100644 --- a/src/lora_lite/variants/antipasto_ablate.py +++ b/src/lora_lite/variants/antipasto_ablate.py @@ -1,34 +1,26 @@ """AntiPaSTO-Ablate: trainable directional ablation in the weight-SVD output basis. -A contractive sibling of antipasto.py. Instead of reweighting the singular gains, -it projects out a learned direction in the *output* singular basis (the U side): +A contractive sibling of antipasto.py: instead of reweighting the singular gains it +projects out a learned direction in the output (U-side) singular basis. W = U diag(S) Vh + W_res learn: c (r, k) ablation directions, alpha (k,) strengths in [0, 1] - Chat = orthonormal(c) # k unit dirs in S-space - h = (x @ Vh.T) * S # output S-coords (= diag(S) Vh x) - h <- h - coeff * (h @ Chat) * alpha @ Chat.T # project the span out + Chat = orthonormal(c) # k unit dirs in S-space + h = (x @ Vh.T) * S # output S-coords = diag(S) Vh x + h <- h - coeff * (h @ Chat) * alpha @ Chat.T # project the span out y = x @ W_res.T + h @ U.T -Why this instead of gain reweighting (antipasto.py): - - The core (I - alpha Chat Chat^T) is a CONTRACTION: eigenvalues are 1-alpha along - Chat and 1 elsewhere, all in [0, 1] for alpha in [0, 1]. 
It cannot amplify and - cannot blow up, so the failure mode the multiplicative gain fights with bounds is - structurally absent. It is also the natural core to recurse (a contraction composed - with itself converges; an amplifier diverges). - - It is the trainable form of directional ablation (Arditi+ 2024). Ablating Chat in - the middle removes output direction U Chat; for a residual *writer* - (mlp.down_proj, self_attn.o_proj) that is a residual-stream direction -- the - SURGICAL regime in the steering-lite sweeps (directional_ablation topped SI). - Target writers, not all Linears, or you get the broad-suppression regime. +The core (I - alpha Chat Chat^T) is a contraction: eigenvalues 1-alpha along Chat, +1 elsewhere, all in [0, 1]. It cannot amplify, so it cannot blow up -- the instability +the multiplicative gain bounds away is structurally absent (and a contraction is the +natural core to recurse). This is the trainable form of directional ablation (Arditi+ +2024): target residual writers (down_proj, o_proj) for the surgical regime, not all +Linears. -Runtime: coeff is the per-call knob. coeff=0 -> identity. coeff in (0, 1] -> ablate. -coeff < 0 -> *add* the direction back (amplify) -- the bidirectional dual; this is the -side that can grow, so bound coeff there. +Runtime: coeff is the per-call knob. coeff=0 -> identity; (0, 1] -> ablate; <0 adds the +direction back (the side that can grow, so bound coeff there). -Init: alpha small (>0 so c receives gradient), c random-normalized. The strong init is -to warm-start c from the contrastive direction dS in S-space (extract it exactly like -sspace.py: dS = mean(xS_pos) - mean(xS_neg) on persona-branching pairs), then fine-tune. +Refs: antipasto.py (gain sibling), directional ablation Arditi+ 2024 arXiv:2406.11717. 
""" from dataclasses import dataclass from typing import Iterable diff --git a/src/lora_lite/variants/antipasto_arrow.py b/src/lora_lite/variants/antipasto_arrow.py index 8be5cd3..deb5da6 100644 --- a/src/lora_lite/variants/antipasto_arrow.py +++ b/src/lora_lite/variants/antipasto_arrow.py @@ -1,47 +1,22 @@ -"""AntiPaSTO-Arrow: a STRUCTURED fixed-basis core, the cheap way to add cross- -direction mixing that plain antipasto (a diagonal gain) cannot express. +"""AntiPaSTO-Arrow: cross-direction mixing via a cheap arrowhead core. -antipasto's core is diagonal: S_eff = S * (1 + ELU(coeff*g)) reweights each frozen -singular direction independently. It can turn a direction up or down but it can never -let direction i's input drive direction j's output. Yet the behaviour you steer is a -combination Sigma c_i v_i that generically lies OFF any single axis (the same argument -that motivates antipasto_corda), so a diagonal core can only ever approximate it. - -The obvious fix -- a full dense r x r core M, DeltaW = U M Vh -- restores all mixing but -costs r^2 params (r=256 -> 65536, a rank-8 LoRA's worth) and an r x r matmul per forward. -antipasto.py's own header flags this trap: "a dense r x r core is r^2 params ... add a -*fixed-basis* core U M Vh rather than rotating". This file is that core, made cheap by -making it STRUCTURED instead of dense -- an arrowhead, not an r x r. - -Arrowhead structure (dense top-block + diagonal tail): - - core C (r x r, acting on the S-scaled coords) = - - [ B (b x b dense) | 0 ] B couples the top-b directions - [ 0 | diag(1+ELU(c*g)) ] tail = exactly antipasto's gain +antipasto's core is diagonal (S_eff = S * gain): it reweights each singular direction +independently but cannot let direction i drive direction j. A full dense r x r core +restores all mixing but costs r^2 params. The arrowhead is the cheap middle: a dense +block on the top-b directions (where the action lives), the diagonal gain on the rest. 
+ core C (r x r, on the S-scaled coords): + [ B (b x b dense) | 0 ] B = I_b + coeff*M (top-b mixing) + [ 0 | diag(1 + ELU(coeff*g)) ] tail = antipasto's gain DeltaW = U @ C @ diag(S) @ Vh + cost: b^2 + (r-b) params, one b x b matmul per forward. -The top b singular directions (largest S = where PiSSA says the action lives) get a full -b x b interaction block B = I_b + coeff*M; the remaining r-b stay on the cheap bounded -diagonal gain. Cost is b^2 + (r-b) params and one b x b matmul per forward -- for b=8,r=256 -that is 312 params and a 64-MAC corner, versus 65536 for dense r x r and versus the -rotation variant's per-forward Cayley solve (measured 72ms vs 36ms). So: cross-direction -mixing where it matters, at diagonal-core cost. +Identity at init: M=0 -> B=I, g=0 -> 1+ELU(0)=1, so C=I and DeltaW = U diag(S) Vh. +coeff=0 -> C=I too (runtime off). The block is the linear (1+z) regime -- stable but +not strictly bounded; for a can't-blow-up guarantee on the top directions use +antipasto_ablate. -(We call it "arrowhead" after the shape -- a dense head with a diagonal shaft. A true -numerical-LA arrowhead also carries a hub row+column coupling the block to the tail; that -would add 2(r-b) params and is a one-line extension if the top-b span turns out too small. -Not added until measured to be needed.) - -Identity at init: M=0 -> B=I_b, g=0 -> 1+ELU(0)=1, so C=I and DeltaW = U diag(S) Vh exactly -(up to the one-time SVD-residual rounding). coeff=0 -> C=I too (runtime off). The block is -the linear-amplification regime of antipasto's design (a matmul, constant-gradient, no exp -self-amplification); it is stable like 1+ELU's upper branch, not strictly bounded -- if you -need the tail's structural can't-blow-up guarantee on the top directions too, use -antipasto_ablate instead. - -Refs: antipasto.py (diagonal sibling), antipasto_corda.py (the off-axis argument). +Refs: antipasto.py (diagonal sibling), antipasto_corda.py (off-axis basis argument). 
""" from dataclasses import dataclass from typing import Iterable, Literal @@ -64,8 +39,9 @@ CalibrationData = Iterable[CalibrationBatch] class AntiPaSTOArrowConfig(AdapterConfig): variant: str = "antipasto_arrow" r: int = 256 - # Size of the dense interaction block on the top-b singular directions. The ONLY - # quadratic cost (b^2 params); keep small. b=1 degenerates to antipasto. + # Dense interaction block on the top-b singular directions; sets capacity and the + # only quadratic cost (b^2 params/module). b=1 degenerates to antipasto; b->r + # approaches a full dense r-core (~LoRA params) at the cost arrow exists to avoid. block: int = 8 suppress_only: bool = False # clamp the tail g<=0 (attenuate only); block unaffected. # Tail guarantee holds for coeff>=0; coeff<0 inverts the product and re-amplifies. @@ -152,11 +128,9 @@ class AntiPaSTOArrow: U_full, S_full, Vh_full = torch.linalg.svd(W_orig, full_matrices=False) proj = X.to(Vh_full) @ Vh_full.T act_mag = proj.pow(2).mean(0).sqrt() if pool == "rms" else proj.abs().mean(0) - # Select top-r by score, then re-sort ascending by SVD index. Since svd() - # returns S descending, the first b stored dirs (the block's cS[..., :b]) are - # the b LARGEST-S among the selected r -- not the b highest-score. Matches the - # block's "largest S = where the action lives" intent, but a high-S dir dropped - # by score-selection won't be in the block. + # Pick top-r by score, then sort by SVD index. svd() returns S descending, + # so the block's first-b coords are the b largest-S among the selected r + # (= where the action lives), not the b highest-score. 
idx = (S_full * act_mag).argsort(descending=True)[:r].sort().values Ur, Sr, Vhr = U_full[:, idx], S_full[idx], Vh_full[idx] W_res_new = (W_orig - (Ur * Sr) @ Vhr).to(layer.weight.dtype) diff --git a/src/lora_lite/variants/antipasto_corda.py b/src/lora_lite/variants/antipasto_corda.py index 6ee3723..9ca5bf0 100644 --- a/src/lora_lite/variants/antipasto_corda.py +++ b/src/lora_lite/variants/antipasto_corda.py @@ -1,42 +1,23 @@ -"""AntiPaSTO-CorDA: steer in a covariance-ORIENTED basis, not the weight-gain basis. +"""AntiPaSTO-CorDA: reweight in a covariance-oriented basis, not the weight basis. -The complaint that motivates this: plain SVD sorts directions by weight gain ||W v|| -on an *isotropic* input. The behaviour you steer lives where the *data* has energy. -Those orderings disagree, so the behaviour smears off the top singular axes and a -top-r crop in the weight basis throws it away. CorDA (Yang+ 2024, arXiv:2406.05223) -re-orients the decomposition by the input covariance C = E[x x^T], so the top -directions are the ones with the most energy *on real activations*. +Plain SVD sorts directions by weight gain ||W v|| on isotropic input. The behaviour +you steer lives where the DATA has energy, off the top weight-singular axes. CorDA +(Yang+ 2024, arXiv:2406.05223) re-orients the SVD by the input covariance, so the top-r +directions move the output most on real activations. 
-Decomposition (verified: full-rank reconstruction ~1e-5, and on anisotropic data the -top-r data-truncation error drops ~27x vs plain SVD): + C = E[x x^T] (+ eps I) # input second moment on calibration data + C^{1/2}, C^{-1/2} via eigh(C) + U S Vht = SVD(W C^{1/2}) + P = Vht C^{-1/2} # (r, d_in) oblique input projector + W = U diag(S) P (exactly) + S_eff = S * (1 + ELU(coeff*g)) # same bounded gain as antipasto + y = x @ W_res.T + ((x @ P.T) * S_eff) @ U.T - C = E[x x^T] (+ eps I) # input second moment on calibration data - C^{1/2}, C^{-1/2} via eigh(C) - W~ = W C^{1/2}; SVD(W~) = U S V~h - P = V~h C^{-1/2} # (r, d_in) OBLIQUE input projector - W = U diag(S) P (exactly) # so y = x W_res^T + ((x P^T) * S_eff) U^T +Identity at g=0 or coeff=0: S_eff=S. P is oblique (rows not orthonormal -- C^{-1/2} +skews them); fine for gain reweighting and for output-side ablation (the obliqueness +is input-side; U stays orthonormal). No calibration_data -> plain SVD (== antipasto). -S here are the singular values of W weighted by input std, so top-r is the optimal -rank-r in the input-weighted norm E||(W - W_r) x||^2 -- the directions that actually -move the output on your data. - -Connection to the shared/differing-basis problem: C is built from pos AND neg inputs -pooled, so P spans the *shared* activation structure (the common encoder) that -chosen-minus-rejected cancels by construction. A trainable gain on this basis can -therefore reach shared structure that contrastive dS extraction is blind to. - -Core: rotation-free. S_eff = S * (1 + ELU(coeff * g)). This is exp(coeff*g) on the -attenuation side (g<0, bounded, no blow-up) and 1+coeff*g on the amplification side -(g>0, where exp would diverge). g=0 -> identity. coeff is the runtime knob (0=off). - -Basis note: P is OBLIQUE (rows not orthonormal -- C^{-1/2} skews them). 
That is fine -for gain reweighting (we scale oblique coordinates), and also fine for OUTPUT-side -directional ablation: the obliqueness is input-side only, while ablation acts in the -U/output space where U stays orthonormal. antipasto_ablate has a cov_orient flag that -reuses this basis -- at low r it captures the behavior output direction that plain-SVD -top-r drops (measured 1.00 vs 0.65 at r=16). - -Falls back to plain SVD (== antipasto, rotation-free) if no calibration_data. +Refs: antipasto.py (gain + selection sibling), CorDA arXiv:2406.05223. """ from dataclasses import dataclass from typing import Iterable