variants: clean docstrings to research pseudocode; arrow block param

Rewrite antipasto/ablate/corda/arrow docstrings to the house style (purpose +
math block + identity line + refs), dropping the rambly meta-commentary aimed at
past design decisions ('Changes vs the rotation version', chat references, inline
measurements). Net -74 lines.

Also answer the FIXMEs left on main's old copy:
  - group_init is Wanda/ASVD *selection* (re-rank W's own singular vectors), NOT
    CorDA re-orientation -- that is antipasto_corda.py.
  - it rebuilds the FULL W exactly (W_res + stored top-r == W), so the re-SVD sees
    the whole spectrum, not a cropped matrix.

Arrow capacity: --antipasto-block CLI knob (justfile bench-variant 4th arg) so the
block can be scaled toward LoRA params; run_id gets a __b<N> suffix so block-sweep
runs do not collide. Smoke green (14 passed).

Co-Authored-By: Claudypoo <noreply@anthropic.com>
wassname
2026-06-15 18:09:53 +08:00
parent 90b5199ed9
commit d9d31a160f
6 changed files with 83 additions and 157 deletions
+2 -1
@@ -75,7 +75,7 @@ metamath-queue variant="lora" steps="5000" model="Qwen/Qwen3-0.6B-Base":
 # Run a single MetaMathQA->GSM8K benchmark for a given variant.
 # Per-variant lr / target-name defaults are baked in here.
-bench-variant model variant steps="5000":
+bench-variant model variant steps="5000" block="8":
     #!/usr/bin/env bash
     set -euo pipefail
     lr=1e-4
@@ -100,6 +100,7 @@ bench-variant model variant steps="5000":
     --steps {{steps}} \
     --lr "$lr" \
     --target-name "$target" \
+    --antipasto-block {{block}} \
     --layers all --r "$r" --alpha "$alpha"
 metamath-queue-all model="Qwen/Qwen3-0.6B-Base" steps="5000" variants="lora pissa delora dora hra ia3 ia3_ff eva antipasto":
+3
@@ -533,6 +533,9 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
     dtype = getattr(torch, args.torch_dtype)
     run_commit = current_git_commit()
     run_id = f"{args.model.replace('/', '--')}__{args.variant}__s{args.steps}__seed{args.seed}"
+    # arrow's capacity is set by block, not r, so keep block-sweep runs from colliding.
+    if args.variant == "antipasto_arrow" and args.antipasto_block != 8:
+        run_id += f"__b{args.antipasto_block}"
     out_dir = args.output_dir / run_id
     out_dir.mkdir(parents=True, exist_ok=True)
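A minimal sketch of the run_id rule in this hunk, pulled out into a standalone function so the collision behaviour is easy to check. `make_run_id` and its arguments are hypothetical names for illustration; the real code builds the string inline from `args`.

```python
def make_run_id(model: str, variant: str, steps: int, seed: int, block: int = 8) -> str:
    """Hypothetical helper mirroring the inline run_id logic in run()."""
    run_id = f"{model.replace('/', '--')}__{variant}__s{steps}__seed{seed}"
    # Only arrow's capacity depends on block; the default block keeps the old name.
    if variant == "antipasto_arrow" and block != 8:
        run_id += f"__b{block}"
    return run_id


default_id = make_run_id("Qwen/Qwen3-0.6B-Base", "antipasto_arrow", 5000, 0)
sweep_id = make_run_id("Qwen/Qwen3-0.6B-Base", "antipasto_arrow", 5000, 0, block=32)
other_id = make_run_id("Qwen/Qwen3-0.6B-Base", "lora", 5000, 0, block=32)
print(default_id)
print(sweep_id)
print(other_id)
```

Runs at the default block keep their existing directory names, so earlier results are not orphaned; only non-default arrow sweeps get the suffix.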
+28 -53
@@ -1,42 +1,22 @@
-"""AntiPaSTO: SVD steering with learnable, bounded singular-value reweighting.
+"""AntiPaSTO: learnable bounded reweighting of frozen SVD singular values.
 wassname 2026 https://arxiv.org/abs/2601.07473
-W = U diag(S) Vh + W_res (top-r SVD; W_res = W - U_r S_r Vh_r)
-learn: g (r,) per-singular-direction gain log/lin-scale
-S_eff = S * (1 + ELU(coeff * g)) exp(.) for g<=0, 1+. for g>0
-suppress_only: clamp g<=0 -> factor in (0,1], attenuation only
+W = U diag(S) Vh + W_res # top-r SVD; W_res = W - U_r S_r Vh_r, frozen
+learn: g (r,) # per-direction gain
+S_eff = S * (1 + ELU(coeff * g)) # exp(z) for z<0 (bounded), 1+z for z>0
 y = x @ W_res.T + ((x @ Vh.T) * S_eff) @ U.T
-Identity at g=0 (or coeff=0): 1+ELU(0)=1 exactly, so S_eff = S and the output is
-x @ W^T up to the one-time SVD-residual rounding. No additive sign-symmetry hack
-needed: the basis is frozen, so the direction sign is fixed and exp/(1+.) is
-sign-preserving. The 1+ELU shape is chosen over linear (sign-flips at g<-1), exp
-(amplification blows up), and tanh (arbitrary bound) -- see forward() for why.
-Changes vs the rotation version this replaces:
-- Rotation dropped. Rotating Vh/U leaves the interpretable singular basis (the
-SVD-direction / Conjecture property), which is the entire point of steering in
-S-space, and the Cayley solve was numerically finicky. The basis is now frozen;
-the only learned object is the per-direction gain. If you later want
-cross-direction mixing, add a *fixed-basis* core U M Vh (M trainable, U/Vh frozen)
-rather than rotating -- that keeps the directions interpretable. It is also far
-cheaper than PiSSA: a dense r x r core is r^2 params (~= a rank-8 LoRA at r=256),
-versus PiSSA's free A,B at r*(d_in+d_out), which drifts off the SVD basis.
-- Additive delta_s -> bounded multiplicative S * (1 + ELU(coeff*g)). Multiplicative
-is "scaled by S" (uniform *relative* control over an orders-of-magnitude spectrum),
-stays positive (no S_eff<0 sign-flip -> no incoherence from that path), and the
-1+ELU shape stops the exp blowup. The 4e-4 sign-symmetry hack is gone.
-- suppress_only = clamp g<=0 -> factor in (0,1]: attenuation only, structurally
-cannot blow up. Matches the eval-awareness use case (turn a direction down).
-- coeff: runtime steering scalar (0 = identity, <0 inverts). The per-call alpha
-the rotation version lacked.
-- group_init activation pooling is configurable: 'rms' weights outliers (ASVD
-intuition), 'mean_abs' is the original outlier-robust pooling.
+suppress_only: clamp g<=0 -> S_eff in (0, S], attenuation only.
+coeff: runtime scale; 0 = identity, <0 swaps amplify/suppress.
+Identity at g=0 or coeff=0: 1+ELU(0)=1, so S_eff=S (up to the bf16 SVD round-trip).
+The basis (U, Vh) is frozen, so the singular directions stay interpretable and only
+the gain is learned. See forward() for why 1+ELU over linear/exp/tanh.
 Refs:
 - paper: https://github.com/wassname/AntiPaSTO
-- sibling (whitened, rotation-free, mean-diff): steering-lite/.../sspace.py
+- sibling (whitened, mean-diff): steering-lite/.../sspace.py
 """
 from dataclasses import dataclass
 from typing import Iterable, Literal
@@ -107,14 +87,15 @@ class AntiPaSTO:
     @staticmethod
     def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None:
-        """Wanda-style, data-driven dimension selection within the weight SVD.
-        init() picks the top-r singular dimensions by S alone (PiSSA-style).
-        group_init() re-selects by score[i] = S[i] * pool|X @ Vh[i]|: dimensions
-        that are both large in W AND active on real inputs. pool = 'rms' (outlier-
-        sensitive, the ASVD intuition that activation outliers carry signal) or
-        'mean_abs' (the original, outlier-robust). If calibration_data is None the
-        weight-SVD init from init() is kept.
+        """Data-driven re-selection of which top-r singular directions to keep.
+        init(): top-r by S alone (PiSSA-style)
+        group_init(): top-r by score[i] = S[i] * pool|X @ Vh[i]| (Wanda/ASVD)
+        pool = 'rms' (outlier-sensitive) | 'mean_abs' (outlier-robust)
+        This re-RANKS W's own singular vectors by activation; it does NOT re-orient
+        the basis (that is CorDA -> antipasto_corda.py). So the kept directions are
+        still plain weight-SVD directions, just a better subset. None -> keep init().
         """
         if calibration_data is None:
             return
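The selection rule in the group_init docstring can be sketched without torch. This is an illustration only, assuming the projections `X @ Vh.T` are already available as nested lists; `select_directions` is a hypothetical name, and the real code works on tensors per layer.

```python
import math


def select_directions(S, proj, r, pool="rms"):
    """Indices of the r kept directions, ascending by SVD index.

    S: singular values (descending). proj[t][i]: token t projected onto Vh[i].
    """
    n_tokens = len(proj)
    scores = []
    for i, s in enumerate(S):
        col = [row[i] for row in proj]
        if pool == "rms":
            act = math.sqrt(sum(v * v for v in col) / n_tokens)
        else:  # "mean_abs"
            act = sum(abs(v) for v in col) / n_tokens
        scores.append(s * act)  # score[i] = S[i] * pool|X @ Vh[i]|
    top = sorted(range(len(S)), key=lambda i: scores[i], reverse=True)[:r]
    return sorted(top)  # re-sort by SVD index, so largest-S comes first


S = [4.0, 3.0, 2.0, 1.0]
proj = [[0.1, 0.0, 1.0, 2.5], [0.1, 0.0, -1.0, -2.5]]
kept = select_directions(S, proj, r=2)
print(kept)
```

In this toy case top-r by S alone would keep directions 0 and 1, but the calibration tokens barely excite them, so the score keeps 2 and 3 instead. The vectors themselves are unchanged: this is selection within W's own SVD, not re-orientation.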
@@ -158,7 +139,9 @@ class AntiPaSTO:
                     f"AntiPaSTO at {name}: only {X.shape[0]} calibration tokens, need >= r={r}"
                 )
-            # Recover W_orig: init() wrote W_res into layer.weight and stored top-r.
+            # Rebuild the FULL W: init() stored the exact top-r it subtracted, so
+            # W_res + U_r S_r Vh_r == W (full rank, not a cropped matrix). The SVD
+            # below therefore re-selects from W's whole spectrum, not a truncation.
             W_res = layer.weight.data.float()
             U_old = layer.lora_U.float()
             S_old = layer.lora_S.float()
@@ -200,21 +183,13 @@ class AntiPaSTO:
         if cfg.suppress_only:
             g = torch.clamp(g, max=0.0) # factor in (0,1]: attenuation only
-        # Per-direction reweighting: S_eff = S * (1 + ELU(coeff * g)).
-        # 1 + ELU(z) = exp(z) for z<=0, 1+z for z>0.
-        # Why this and not the obvious ones (all of which we tried):
-        #   linear S*(1+z) : constant gradient (stable), but z<-1 -> S_eff<0,
-        #     a sign flip that drives incoherence. Unstable in
-        #     the negatives.
-        #   exp S*exp(z) : positive, but unbounded and the gradient self-
-        #     amplifies (d/dz exp = exp), so amplification blows up.
-        #   tanh S*exp(c*tanh z): bounded, but c is an arbitrary free knob with no
-        #     principled value, and saturation kills the gradient.
-        #   1+ELU : uses each in its safe regime -- exp only where it is
-        #     bounded in (0,1] (attenuation, cannot go negative),
-        #     linear where exp would diverge (amplification, const
-        #     gradient). C1 at z=0 (both -> 1, slope 1); >0 always.
-        # coeff=0 or g=0 -> S_eff = S (identity). coeff<0 swaps amplify/suppress.
+        # S_eff = S * (1 + ELU(z)), z = coeff*g, 1+ELU(z) = exp(z) for z<=0 else 1+z.
+        # Why 1+ELU and not the obvious alternatives:
+        #   linear S*(1+z) : z<-1 -> S_eff<0, a sign flip that drives incoherence.
+        #   exp S*exp(z) : unbounded, gradient self-amplifies (amplification blows up).
+        #   tanh bounded : arbitrary bound knob, saturation kills the gradient.
+        # 1+ELU uses each in its safe regime: exp where it is bounded in (0,1]
+        # (attenuation), linear where exp would diverge (amplification). >0 always.
         S_eff = S * (1.0 + torch.nn.functional.elu(coeff * g))
         h = (x @ Vh.T) * S_eff # input in S-coords, reweighted
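To make the 1+ELU comparison concrete, here is a scalar sketch of the gain and of the clamp-then-scale order used in forward(). It uses plain floats rather than tensors; `gain` and `s_eff` are illustrative names, not functions from the repository.

```python
import math


def gain(z: float) -> float:
    """1 + ELU(z): exp(z) for z <= 0, 1 + z for z > 0."""
    return math.exp(z) if z <= 0 else 1.0 + z


def s_eff(S, g, coeff=1.0, suppress_only=False):
    """Per-direction reweighting S * (1 + ELU(coeff * g)) on Python lists."""
    out = []
    for s, gi in zip(S, g):
        if suppress_only:
            gi = min(gi, 0.0)  # clamp g <= 0 before applying coeff
        out.append(s * gain(coeff * gi))
    return out


S = [2.0, 1.0]
print(s_eff(S, [0.0, 0.0]))             # identity at g = 0
print(s_eff(S, [-5.0, 3.0]))            # attenuate one, amplify the other
print(s_eff(S, [-5.0, 3.0], coeff=0.0)) # identity at coeff = 0
print(1.0 + (-5.0), gain(-5.0))         # linear flips sign, 1+ELU stays positive
```

The last line shows the failure the comment block describes: at z = -5 the linear form gives a negative factor, while 1+ELU stays in (0, 1].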
+14 -22
@@ -1,34 +1,26 @@
 """AntiPaSTO-Ablate: trainable directional ablation in the weight-SVD output basis.
-A contractive sibling of antipasto.py. Instead of reweighting the singular gains,
-it projects out a learned direction in the *output* singular basis (the U side):
+A contractive sibling of antipasto.py: instead of reweighting the singular gains it
+projects out a learned direction in the output (U-side) singular basis.
 W = U diag(S) Vh + W_res
 learn: c (r, k) ablation directions, alpha (k,) strengths in [0, 1]
 Chat = orthonormal(c) # k unit dirs in S-space
-h = (x @ Vh.T) * S # output S-coords (= diag(S) Vh x)
+h = (x @ Vh.T) * S # output S-coords = diag(S) Vh x
 h <- h - coeff * (h @ Chat) * alpha @ Chat.T # project the span out
 y = x @ W_res.T + h @ U.T
-Why this instead of gain reweighting (antipasto.py):
-- The core (I - alpha Chat Chat^T) is a CONTRACTION: eigenvalues are 1-alpha along
-Chat and 1 elsewhere, all in [0, 1] for alpha in [0, 1]. It cannot amplify and
-cannot blow up, so the failure mode the multiplicative gain fights with bounds is
-structurally absent. It is also the natural core to recurse (a contraction composed
-with itself converges; an amplifier diverges).
-- It is the trainable form of directional ablation (Arditi+ 2024). Ablating Chat in
-the middle removes output direction U Chat; for a residual *writer*
-(mlp.down_proj, self_attn.o_proj) that is a residual-stream direction -- the
-SURGICAL regime in the steering-lite sweeps (directional_ablation topped SI).
-Target writers, not all Linears, or you get the broad-suppression regime.
-Runtime: coeff is the per-call knob. coeff=0 -> identity. coeff in (0, 1] -> ablate.
-coeff < 0 -> *add* the direction back (amplify) -- the bidirectional dual; this is the
-side that can grow, so bound coeff there.
-Init: alpha small (>0 so c receives gradient), c random-normalized. The strong init is
-to warm-start c from the contrastive direction dS in S-space (extract it exactly like
-sspace.py: dS = mean(xS_pos) - mean(xS_neg) on persona-branching pairs), then fine-tune.
+The core (I - alpha Chat Chat^T) is a contraction: eigenvalues 1-alpha along Chat,
+1 elsewhere, all in [0, 1]. It cannot amplify, so it cannot blow up -- the instability
+the multiplicative gain bounds away is structurally absent (and a contraction is the
+natural core to recurse). This is the trainable form of directional ablation (Arditi+
+2024): target residual writers (down_proj, o_proj) for the surgical regime, not all
+Linears.
+Runtime: coeff is the per-call knob. coeff=0 -> identity; (0, 1] -> ablate; <0 adds the
+direction back (the side that can grow, so bound coeff there).
+Refs: antipasto.py (gain sibling), directional ablation Arditi+ 2024 arXiv:2406.11717.
 """
 from dataclasses import dataclass
 from typing import Iterable
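A small sketch of the ablation step for a single direction (k = 1), to illustrate why the core contracts for coeff in [0, 1] and can grow for coeff < 0. It uses plain lists, and `ablate` is a hypothetical helper, not the module's forward().

```python
import math


def l2(v):
    return math.sqrt(sum(x * x for x in v))


def ablate(h, c, alpha, coeff=1.0):
    """h - coeff * alpha * (h . chat) * chat, with chat = c / ||c||."""
    norm = l2(c)
    chat = [v / norm for v in c]
    dot = sum(a * b for a, b in zip(h, chat))
    return [a - coeff * alpha * dot * b for a, b in zip(h, chat)]


h = [3.0, 4.0, 0.0]
c = [0.0, 2.0, 0.0]                         # unnormalized learned direction
full = ablate(h, c, alpha=1.0)              # component along chat removed
half = ablate(h, c, alpha=0.5)              # component along chat halved
off = ablate(h, c, alpha=1.0, coeff=0.0)    # identity
back = ablate(h, c, alpha=1.0, coeff=-1.0)  # direction added back
print(full, half, off, back)
print(l2(h), l2(full), l2(half), l2(back))
```

For alpha and coeff in [0, 1] the norm never exceeds that of the input, which is the contraction property the docstring relies on; the negative-coeff case is the only one that grows.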
+20 -46
@@ -1,47 +1,22 @@
-"""AntiPaSTO-Arrow: a STRUCTURED fixed-basis core, the cheap way to add cross-
-direction mixing that plain antipasto (a diagonal gain) cannot express.
-antipasto's core is diagonal: S_eff = S * (1 + ELU(coeff*g)) reweights each frozen
-singular direction independently. It can turn a direction up or down but it can never
-let direction i's input drive direction j's output. Yet the behaviour you steer is a
-combination Sigma c_i v_i that generically lies OFF any single axis (the same argument
-that motivates antipasto_corda), so a diagonal core can only ever approximate it.
-The obvious fix -- a full dense r x r core M, DeltaW = U M Vh -- restores all mixing but
-costs r^2 params (r=256 -> 65536, a rank-8 LoRA's worth) and an r x r matmul per forward.
-antipasto.py's own header flags this trap: "a dense r x r core is r^2 params ... add a
-*fixed-basis* core U M Vh rather than rotating". This file is that core, made cheap by
-making it STRUCTURED instead of dense -- an arrowhead, not an r x r.
-Arrowhead structure (dense top-block + diagonal tail):
-core C (r x r, acting on the S-scaled coords) =
-[ B (b x b dense) | 0                ]   B couples the top-b directions
-[ 0               | diag(1+ELU(c*g)) ]   tail = exactly antipasto's gain
+"""AntiPaSTO-Arrow: cross-direction mixing via a cheap arrowhead core.
+antipasto's core is diagonal (S_eff = S * gain): it reweights each singular direction
+independently but cannot let direction i drive direction j. A full dense r x r core
+restores all mixing but costs r^2 params. The arrowhead is the cheap middle: a dense
+block on the top-b directions (where the action lives), the diagonal gain on the rest.
+core C (r x r, on the S-scaled coords):
+[ B (b x b dense) | 0                      ]   B = I_b + coeff*M (top-b mixing)
+[ 0               | diag(1 + ELU(coeff*g)) ]   tail = antipasto's gain
 DeltaW = U @ C @ diag(S) @ Vh
-The top b singular directions (largest S = where PiSSA says the action lives) get a full
-b x b interaction block B = I_b + coeff*M; the remaining r-b stay on the cheap bounded
-diagonal gain. Cost is b^2 + (r-b) params and one b x b matmul per forward -- for b=8,r=256
-that is 312 params and a 64-MAC corner, versus 65536 for dense r x r and versus the
-rotation variant's per-forward Cayley solve (measured 72ms vs 36ms). So: cross-direction
-mixing where it matters, at diagonal-core cost.
-(We call it "arrowhead" after the shape -- a dense head with a diagonal shaft. A true
-numerical-LA arrowhead also carries a hub row+column coupling the block to the tail; that
-would add 2(r-b) params and is a one-line extension if the top-b span turns out too small.
-Not added until measured to be needed.)
-Identity at init: M=0 -> B=I_b, g=0 -> 1+ELU(0)=1, so C=I and DeltaW = U diag(S) Vh exactly
-(up to the one-time SVD-residual rounding). coeff=0 -> C=I too (runtime off). The block is
-the linear-amplification regime of antipasto's design (a matmul, constant-gradient, no exp
-self-amplification); it is stable like 1+ELU's upper branch, not strictly bounded -- if you
-need the tail's structural can't-blow-up guarantee on the top directions too, use
-antipasto_ablate instead.
-Refs: antipasto.py (diagonal sibling), antipasto_corda.py (the off-axis argument).
+cost: b^2 + (r-b) params, one b x b matmul per forward.
+Identity at init: M=0 -> B=I, g=0 -> 1+ELU(0)=1, so C=I and DeltaW = U diag(S) Vh.
+coeff=0 -> C=I too (runtime off). The block is the linear (1+z) regime -- stable but
+not strictly bounded; for a can't-blow-up guarantee on the top directions use
+antipasto_ablate.
+Refs: antipasto.py (diagonal sibling), antipasto_corda.py (off-axis basis argument).
 """
 from dataclasses import dataclass
 from typing import Iterable, Literal
@@ -64,8 +39,9 @@ CalibrationData = Iterable[CalibrationBatch]
 class AntiPaSTOArrowConfig(AdapterConfig):
     variant: str = "antipasto_arrow"
     r: int = 256
-    # Size of the dense interaction block on the top-b singular directions. The ONLY
-    # quadratic cost (b^2 params); keep small. b=1 degenerates to antipasto.
+    # Dense interaction block on the top-b singular directions; sets capacity and the
+    # only quadratic cost (b^2 params/module). b=1 degenerates to antipasto; b->r
+    # approaches a full dense r-core (~LoRA params) at the cost arrow exists to avoid.
     block: int = 8
     suppress_only: bool = False # clamp the tail g<=0 (attenuate only); block unaffected.
     # Tail guarantee holds for coeff>=0; coeff<0 inverts the product and re-amplifies.
@@ -152,11 +128,9 @@ class AntiPaSTOArrow:
             U_full, S_full, Vh_full = torch.linalg.svd(W_orig, full_matrices=False)
             proj = X.to(Vh_full) @ Vh_full.T
             act_mag = proj.pow(2).mean(0).sqrt() if pool == "rms" else proj.abs().mean(0)
-            # Select top-r by score, then re-sort ascending by SVD index. Since svd()
-            # returns S descending, the first b stored dirs (the block's cS[..., :b]) are
-            # the b LARGEST-S among the selected r -- not the b highest-score. Matches the
-            # block's "largest S = where the action lives" intent, but a high-S dir dropped
-            # by score-selection won't be in the block.
+            # Pick top-r by score, then sort by SVD index. svd() returns S descending,
+            # so the block's first-b coords are the b largest-S among the selected r
+            # (= where the action lives), not the b highest-score.
             idx = (S_full * act_mag).argsort(descending=True)[:r].sort().values
             Ur, Sr, Vhr = U_full[:, idx], S_full[idx], Vh_full[idx]
             W_res_new = (W_orig - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
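The arrowhead cost formula and the block-plus-tail core can be checked with a tiny pure-Python sketch. `arrow_params` and `apply_core` are hypothetical names, and treating the block as `(I_b + coeff*M)` applied to a column of S-coordinates is an assumption about orientation; the real forward() may use the transposed convention.

```python
import math


def arrow_params(r: int, b: int) -> int:
    """Trainable params per module: dense b x b block plus (r - b) tail gains."""
    return b * b + (r - b)


def apply_core(h, M, g, coeff=1.0):
    """Apply C = blockdiag(I_b + coeff*M, diag(1 + ELU(coeff*g))) to coords h."""
    b = len(M)
    head = [h[i] + coeff * sum(M[i][j] * h[j] for j in range(b)) for i in range(b)]
    tail = []
    for hi, gi in zip(h[b:], g):
        z = coeff * gi
        tail.append(hi * (math.exp(z) if z <= 0 else 1.0 + z))
    return head + tail


print(arrow_params(256, 8))    # default block
print(arrow_params(256, 1))    # b=1: same count as plain antipasto's r gains
print(arrow_params(256, 256))  # b=r: full dense core

h = [1.0, 2.0, 3.0]
identity = apply_core(h, [[0.0, 0.0], [0.0, 0.0]], [0.0])
mixed = apply_core(h, [[0.0, 1.0], [0.0, 0.0]], [0.0])  # direction 1 drives direction 0
print(identity, mixed)
```

With M = 0 and g = 0 the core is the identity, matching the init claim in the docstring; a single off-diagonal entry in M is enough to let one top direction feed another, which a diagonal gain cannot express.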
+16 -35
@@ -1,42 +1,23 @@
-"""AntiPaSTO-CorDA: steer in a covariance-ORIENTED basis, not the weight-gain basis.
-The complaint that motivates this: plain SVD sorts directions by weight gain ||W v||
-on an *isotropic* input. The behaviour you steer lives where the *data* has energy.
-Those orderings disagree, so the behaviour smears off the top singular axes and a
-top-r crop in the weight basis throws it away. CorDA (Yang+ 2024, arXiv:2406.05223)
-re-orients the decomposition by the input covariance C = E[x x^T], so the top
-directions are the ones with the most energy *on real activations*.
-Decomposition (verified: full-rank reconstruction ~1e-5, and on anisotropic data the
-top-r data-truncation error drops ~27x vs plain SVD):
-C = E[x x^T] (+ eps I) # input second moment on calibration data
-C^{1/2}, C^{-1/2} via eigh(C)
-W~ = W C^{1/2}; SVD(W~) = U S V~h
-P = V~h C^{-1/2} # (r, d_in) OBLIQUE input projector
-W = U diag(S) P (exactly) # so y = x W_res^T + ((x P^T) * S_eff) U^T
-S here are the singular values of W weighted by input std, so top-r is the optimal
-rank-r in the input-weighted norm E||(W - W_r) x||^2 -- the directions that actually
-move the output on your data.
-Connection to the shared/differing-basis problem: C is built from pos AND neg inputs
-pooled, so P spans the *shared* activation structure (the common encoder) that
-chosen-minus-rejected cancels by construction. A trainable gain on this basis can
-therefore reach shared structure that contrastive dS extraction is blind to.
-Core: rotation-free. S_eff = S * (1 + ELU(coeff * g)). This is exp(coeff*g) on the
-attenuation side (g<0, bounded, no blow-up) and 1+coeff*g on the amplification side
-(g>0, where exp would diverge). g=0 -> identity. coeff is the runtime knob (0=off).
-Basis note: P is OBLIQUE (rows not orthonormal -- C^{-1/2} skews them). That is fine
-for gain reweighting (we scale oblique coordinates), and also fine for OUTPUT-side
-directional ablation: the obliqueness is input-side only, while ablation acts in the
-U/output space where U stays orthonormal. antipasto_ablate has a cov_orient flag that
-reuses this basis -- at low r it captures the behavior output direction that plain-SVD
-top-r drops (measured 1.00 vs 0.65 at r=16).
-Falls back to plain SVD (== antipasto, rotation-free) if no calibration_data.
+"""AntiPaSTO-CorDA: reweight in a covariance-oriented basis, not the weight basis.
+Plain SVD sorts directions by weight gain ||W v|| on isotropic input. The behaviour
+you steer lives where the DATA has energy, off the top weight-singular axes. CorDA
+(Yang+ 2024, arXiv:2406.05223) re-orients the SVD by the input covariance, so the top-r
+directions move the output most on real activations.
+C = E[x x^T] (+ eps I) # input second moment on calibration data
+C^{1/2}, C^{-1/2} via eigh(C)
+U S Vht = SVD(W C^{1/2})
+P = Vht C^{-1/2} # (r, d_in) oblique input projector
+W = U diag(S) P (exactly)
+S_eff = S * (1 + ELU(coeff*g)) # same bounded gain as antipasto
+y = x @ W_res.T + ((x @ P.T) * S_eff) @ U.T
+Identity at g=0 or coeff=0: S_eff=S. P is oblique (rows not orthonormal -- C^{-1/2}
+skews them); fine for gain reweighting and for output-side ablation (the obliqueness
+is input-side; U stays orthonormal). No calibration_data -> plain SVD (== antipasto).
+Refs: antipasto.py (gain + selection sibling), CorDA arXiv:2406.05223.
 """
 from dataclasses import dataclass
 from typing import Iterable
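A toy illustration of why the weight ordering and the data-weighted ordering can disagree. It assumes the simplest possible case, where both W and C are diagonal, so the singular values of W C^{1/2} are just |w_i| * sqrt(c_i) and no SVD or eigh call is needed. The function names are hypothetical and this is not the module's decomposition code.

```python
import math


def weight_order(w):
    """Axes sorted by plain weight gain |w_i| (plain SVD of a diagonal W)."""
    return sorted(range(len(w)), key=lambda i: abs(w[i]), reverse=True)


def data_order(w, c):
    """Axes sorted by |w_i| * sqrt(c_i), the singular values of W C^{1/2}."""
    return sorted(range(len(w)), key=lambda i: abs(w[i]) * math.sqrt(c[i]), reverse=True)


def trunc_error(w, c, keep):
    """E||(W - W_r) x||^2 for diagonal W and uncorrelated inputs with E[x_i^2] = c_i."""
    return sum(w[i] ** 2 * c[i] for i in range(len(w)) if i not in keep)


w = [2.0, 1.0, 0.5]   # weight gains per axis
c = [0.01, 4.0, 1.0]  # input second moments per axis (assumed uncorrelated)

plain = weight_order(w)
oriented = data_order(w, c)
print(plain, oriented)
print(trunc_error(w, c, keep=plain[:1]), trunc_error(w, c, keep=oriented[:1]))
```

Axis 0 has the largest weight gain but almost no input energy, so a rank-1 crop in the weight basis keeps the wrong axis and leaves a much larger data-weighted error than the covariance-oriented crop.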