Add rotation-free S-space adapter cores (antipasto family)

Replace antipasto's rotation/Cayley with a bounded 1+ELU gain and split the
S-space idea into four interpretable PiSSA-style cores (frozen U/S/Vh, small
trainable core):

- antipasto: S_eff = S*(1+ELU(coeff*g)). exp-bounded attenuation, linear
  amplification (constant gradient, no runaway). g=0 -> exact identity.
- antipasto_rot: keeps the block-Cayley rotation as a separate variant for
  cost comparison (its per-forward solve is the 72ms vs 36ms gap).
- antipasto_ablate: contractive (I - a c c^T) diag(S), eigenvalues in [0,1],
  cannot blow up. Optional cov_orient (CorDA) basis.
- antipasto_corda: covariance-oriented oblique projector P = Vh C^{-1/2}, the
  data-energy basis rather than the weight-gain basis. 1+ELU gain.

Add scripts/_cost.py + scripts/cost_report.py: one-row-per-variant cost table
(trainable params, peak GPU mem, fwd/bwd ms, added MACs/tok, group_init ms).
Wire all four into the benchmark, smoke test, and __init__ exports.

External review (DeepSeek-v4-pro, docs/reviews/) verified the math; acted on
its one real point (corda g now inits to zeros for exact identity).

Co-Authored-By: Claudypoo <noreply@anthropic.com>
wassname
2026-06-14 19:12:27 +08:00
parent e5048fcaff
commit b80d7778af
11 changed files with 1059 additions and 107 deletions
+38
@@ -0,0 +1,38 @@
## Code Review: Four S-space weight adapters (antipasto family + sspace reference + cost script)
### Summary
The code adds four PiSSA-style adapters (antipasto, antipasto_rot, antipasto_ablate, antipasto_corda) that store frozen top-r SVD buffers and a small trainable gain/core. The key mathematical claims in the docstrings (identity at g=0, contraction of ablation core, exact reconstruction of the CorDA decomposition, signed cosine gate) are correct. However, `group_init()` re-selects/re-orients the SVD basis but does **not** reset the trainable parameters (g, delta_s, rot_T, c, alpha), which introduces a latent but important correctness problem — after a data-driven reinit, gains and direction vectors become misaligned with their intended singular axes.
---
### Critical (must fix)
- **`antipasto_corda.py`, `antipasto_rot.py`, `antipasto_ablate.py`:** `group_init()` re-orients or re-selects the top-r SVD basis but leaves the existing trainable parameters (`lora_g`, `lora_delta_s`, `lora_rot_T`, `lora_c`, `lora_alpha`) untouched, still attached to the old indices.
  - For `antipasto_corda`, `lora_g` is initialised with `N(0, 4e-4)` in `init()`. After `group_init`, those small random gains now multiply a different set of singular directions, producing an uncontrolled (though tiny) perturbation.
  - For `antipasto_rot`, the `+4e-4` bias on `lora_delta_s` remains, now applied to arbitrary resorted directions, and the block rotation `lora_rot_T` is disconnected from the new block structure.
  - For `antipasto_ablate`, the ablation directions and strengths (`lora_c`, `lora_alpha`) are not reset; if any warm-starting or training has happened, they point in the wrong subspace.
**Fix:** After re-selection or re-orientation, either re-initialise the trainable parameters (e.g., zero for g, zeros/small for delta_s and rot_T, random-normalised for c, init_alpha for alpha) or re-index them with the same `idx` mapping used for the buffers. Document that `group_init` must be called **before any training** and that the trainable parameters will be fresh after it.
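The re-indexing option can be sketched in plain Python. This is an illustration only, not code from the repository: `reindex_gains` is a hypothetical helper, and it assumes `group_init` runs once, directly after `init()`, so that old slot `i` still corresponds to full-spectrum direction `i`.

```python
def reindex_gains(g_old, idx, fill=0.0):
    """Carry per-direction gains over to a re-selected basis.

    Assumption: g_old[i] belongs to full-spectrum direction i (init() keeps the
    top-r directions, so old slot i == full index i). `idx` lists the
    full-spectrum indices kept after re-selection. Directions that were not in
    the old set start from `fill`, which is 0.0 for an identity gain.
    """
    r = len(g_old)
    return [g_old[i] if i < r else fill for i in idx]


g_old = [0.10, -0.20, 0.05, 0.00]  # gains for full-spectrum directions 0..3
idx = [0, 2, 5, 7]                 # re-selected directions, sorted ascending
g_new = reindex_gains(g_old, idx)
print(g_new)  # [0.1, 0.05, 0.0, 0.0]
```

Resetting everything to the init values is simpler and is sufficient whenever `group_init` is guaranteed to run before any training.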
### Important (should fix)
- **`antipasto_ablate.py`** Contraction claim and coeff bounds: the forward pass applies `h = h - coeff * (proj * alpha) @ Chat.T`, so with orthonormal `Chat` the eigenvalue along each ablated direction is `1 - coeff * alpha`. The core is guaranteed to be a contraction when both `coeff` and `alpha` are in `[0, 1]` (more generally, when `coeff * alpha` stays in `[0, 1]`). The config's `coeff` can be *any* float (the docstring mentions `coeff<0` for amplification), and there is no runtime clamping of `coeff`. If the intention is to keep the contraction property structurally enforced, clamp `coeff` to `[0, 1]` inside `forward()` or at least validate it.
- **`sspace.py`** Division by `sqrtS` without epsilon: `xS = (y_eff @ U_r) / sqrtS`. For small or zero singular values (e.g., when r is large or W is low-rank) this can produce NaNs/infs. Add a small `ε` to the denominator (consistent with the ε used elsewhere).
- **`antipasto_ablate.py`** `_orthonormal()` calls `torch.linalg.qr(c.float())` every forward pass. For large `r` and many layers this adds non-trivial cost. A lighter reparameterisation (e.g., maintaining `c` as a matrix with a normalisation step that avoids full QR) might be warranted, but for small `(r,k)` pairs the current approach is acceptable. At minimum, add a comment noting the per-forward cost.
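The contraction point can be checked numerically for a single direction (`k=1`). The sketch below is pure Python and is not the adapter's forward pass; `ablate` and `l2` are hypothetical names used only for illustration.

```python
import math


def l2(v):
    """Euclidean norm of a list of floats."""
    return math.sqrt(sum(x * x for x in v))


def ablate(h, c, alpha, coeff):
    """Apply (I - coeff * alpha * c_hat c_hat^T) to h for one direction c."""
    c_norm = l2(c)
    c_hat = [v / c_norm for v in c]
    proj = sum(hv * cv for hv, cv in zip(h, c_hat))
    return [hv - coeff * alpha * proj * cv for hv, cv in zip(h, c_hat)]


h = [3.0, 4.0]
c = [1.0, 0.0]

inside = ablate(h, c, alpha=1.0, coeff=1.0)    # eigenvalue 1 - 1 = 0 along c
outside = ablate(h, c, alpha=1.0, coeff=-1.0)  # eigenvalue 1 + 1 = 2 along c

print(inside, l2(inside) <= l2(h))   # [0.0, 4.0] True
print(outside, l2(outside) > l2(h))  # [6.0, 4.0] True
```

With `coeff` in `[0, 1]` the norm can only shrink; a negative `coeff` grows the component along `c`, which is the case the docstring says must be bounded.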
### Suggestions
- **`antipasto.py` group_init:** after the idx sort, `idx = idx.sort().values` ensures a stable, canonical ordering. This is a nice touch for reproducibility.
- **`antipasto_ablate.py`** The docstring says “CorDA-orient the basis from input covariance … the ablation is OUTPUT-side and CorDA's U stays orthonormal …”. The forward code correctly uses the orthonormal `U` for output projection, so the contraction in S-space carries over to the output.
- **`sspace.py`** The signed gate correctly preserves `cos` sign, so anti-aligned tokens receive a negative `gate * alpha` and are pushed opposite to `dS_hat`. Verified.
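For reference, the selection-then-sort step praised in the first suggestion can be sketched without torch. This is a simplified stand-in for the `group_init` scoring (`score[i] = S[i] * pool|X @ Vh[i]|`); `pool` and `select_directions` are hypothetical names, and `proj` is assumed to already hold the inputs in singular coordinates.

```python
import math


def pool(column, mode="rms"):
    """Pool one direction's activations: 'rms' or 'mean_abs'."""
    if mode == "rms":
        return math.sqrt(sum(v * v for v in column) / len(column))
    return sum(abs(v) for v in column) / len(column)


def select_directions(S, proj, r, mode="rms"):
    """Pick the r directions with the largest S[i] * pooled activation.

    proj[n][i] is token n's coordinate along singular direction i.
    The chosen indices are returned in ascending order, so the result does
    not depend on the ranking among the selected scores.
    """
    k = len(S)
    act = [pool([row[i] for row in proj], mode) for i in range(k)]
    scores = [s * a for s, a in zip(S, act)]
    top = sorted(range(k), key=lambda i: scores[i], reverse=True)[:r]
    return sorted(top)


S = [4.0, 3.0, 2.0, 1.0]          # singular values, descending
proj = [[0.1, 0.0, 1.0, 5.0],     # two tokens with identical coordinates
        [0.1, 0.0, 1.0, 5.0]]
print(select_directions(S, proj, r=2))  # [2, 3]
```

Here the score ranking is direction 3 then direction 2, but the final sort returns them as `[2, 3]`, matching the canonical ordering the review refers to.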
### Positive
- The documentation is thorough, explaining the design choices (why 1+ELU, why contraction, why CorDA), and includes references.
- The PiSSA-style `W_res` decomposition is implemented correctly across all variants.
- `antipasto.py`'s `S_eff = S * (1 + ELU(coeff*g))` is indeed C1, positive, and identity at `g=0`, with no sign-flip bugs.
- `antipasto_ablate.py` enforces orthonormal `Chat` and clamps `alpha` to `[0,1]`, making the contraction property safe when `coeff` is also bounded.
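The properties claimed for the 1+ELU gain are easy to spot-check on scalars. The sketch below mirrors the formula `S_eff = S * (1 + ELU(coeff * g))` from the docstring in plain Python (ELU with its default scale of 1); `one_plus_elu` and `s_eff` are illustrative names, not functions from the codebase.

```python
import math


def one_plus_elu(z):
    """1 + ELU(z): exp(z) for z <= 0, 1 + z for z > 0."""
    return math.exp(z) if z <= 0 else 1.0 + z


def s_eff(s, g, coeff=1.0, suppress_only=False):
    """Effective singular value for one direction."""
    if suppress_only:
        g = min(g, 0.0)  # attenuation only: factor stays in (0, 1]
    return s * one_plus_elu(coeff * g)


print(s_eff(2.0, 0.0))                      # 2.0, exact identity at g=0
print(s_eff(2.0, 1.5))                      # 5.0, linear amplification
print(0.0 < s_eff(2.0, -50.0) < 2.0)        # True, attenuated but still positive
print(s_eff(2.0, 3.0, suppress_only=True))  # 2.0, amplification clamped away
```

Both branches equal 1 with slope 1 at `z=0`, which is the C1 property noted above.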
### Verdict
**REQUEST CHANGES**
`group_init()` must reset the trainable parameters after it changes the SVD basis; without this, the adapters silently poison their own steering gains with a different set of directions. Fix that and the code is ready.
---
*Note: The `sspace.py` and `_cost.py` files are part of the same move into lora-lite and are free of the parameter-reset issue; the only concern in `sspace.py` is the unprotected division by singular values.*
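A guarded version of that division could look like the following sketch. It is pure Python rather than the `sspace.py` code, `whiten` is a hypothetical name, and the value `1e-6` is taken from the `ε = 1e-6` constant defined in `antipasto_ablate.py`. Note that plain Python raises `ZeroDivisionError` on a zero divisor, whereas torch would silently produce inf/NaN.

```python
import math

EPS = 1e-6  # same value as the ε constant in antipasto_ablate.py


def whiten(coords, singular_values, eps=EPS):
    """Divide each S-space coordinate by sqrt(S) + eps instead of sqrt(S)."""
    return [c / (math.sqrt(s) + eps) for c, s in zip(coords, singular_values)]


xS = whiten([1.0, 0.0, 3.0], [4.0, 0.0, 0.0])
print(all(math.isfinite(v) for v in xS))  # True, even with zero singular values
```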
+131
@@ -0,0 +1,131 @@
"""Measure the cost of an attached adapter: params, FLOPs/MACs, time, GPU mem.

Which metric is "best" for comparing adapters? They answer different questions:
  - trainable_params  -- deterministic "size" number. The headline.
  - macs_per_token    -- deterministic, hardware-INDEPENDENT compute. Best for an
                         apples-to-apples comparison: wall-time is noisy and the old
                         rotation adapter paid a per-forward Cayley solve the new ones
                         do not. "adds" (additions) ~= MACs; FLOPs ~= 2 * MACs.
  - fwd_ms / bwd_ms   -- felt cost, but noisy: warmup + median over `iters`, never one run.
  - peak_gpu_mb       -- resident + activation peak around fwd(+bwd).

FLOPs come from torch.utils.flop_counter.FlopCounterMode (built in, no new dep). Its
matmul formula counts a (m,k)@(k,n) product as 2*m*n*k, i.e. FLOPs = 2 * MACs. We expose
both `flops` (as returned) and `macs_per_token = flops / n_tokens`; under that convention
the latter is FLOPs/token, so halve it for true MACs -- calibrate once on a known matmul
if the factor of 2 matters.
"""
from __future__ import annotations
import statistics
import time
import torch
from torch.utils.flop_counter import FlopCounterMode
def _time_call(fn, warmup: int, iters: int, cuda: bool) -> float:
    """Median wall-time of fn() in milliseconds (warmup excluded)."""
    for _ in range(warmup):
        fn()
    if cuda:
        torch.cuda.synchronize()
    samples = []
    for _ in range(iters):
        if cuda:
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            fn()
            end.record()
            torch.cuda.synchronize()
            samples.append(start.elapsed_time(end))
        else:
            t0 = time.perf_counter()
            fn()
            samples.append((time.perf_counter() - t0) * 1e3)
    return statistics.median(samples)
def measure_cost(
    model: torch.nn.Module,
    fwd_fn,
    *,
    bwd_step_fn=None,
    n_tokens: int | None = None,
    adapter_filter: str = "lora_",
    warmup: int = 3,
    iters: int = 10,
) -> dict:
    """Cost of the currently-attached adapter.

    fwd_fn():        run one forward (no grad). Used for FLOPs + fwd timing.
    bwd_step_fn():   zero_grad + forward + loss.backward(). Used for bwd timing.
    n_tokens:        tokens in the fwd_fn batch, for macs_per_token.
    adapter_filter:  substring marking adapter params/buffers (default 'lora_').
    """
    dev = next(model.parameters()).device
    cuda = dev.type == "cuda"

    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    named = list(model.named_parameters()) + list(model.named_buffers())
    adapter_bytes = sum(t.numel() * t.element_size() for n, t in named if adapter_filter in n)

    # FLOPs: one forward under the counter (no grad so we count inference cost).
    # FlopCounterMode can assert on some fused attention shapes; degrade to None.
    try:
        fc = FlopCounterMode(display=False)
        with torch.no_grad(), fc:
            fwd_fn()
        flops = fc.get_total_flops()
    except Exception as e:
        print(f" [warn] FLOP count failed ({type(e).__name__}: {e}); flops=None")
        flops = None

    if cuda:
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
    fwd_ms = _time_call(lambda: _no_grad(fwd_fn), warmup, iters, cuda)
    bwd_ms = _time_call(bwd_step_fn, warmup, iters, cuda) if bwd_step_fn is not None else None
    peak_gpu_mb = (torch.cuda.max_memory_allocated() / 1e6) if cuda else None

    return dict(
        trainable_params=trainable_params,
        adapter_resident_mb=adapter_bytes / 1e6,
        flops=flops,
        macs_per_token=(flops / n_tokens) if (flops and n_tokens) else None,
        fwd_ms=fwd_ms,
        bwd_ms=bwd_ms,
        peak_gpu_mb=peak_gpu_mb,
    )


def _no_grad(fn):
    with torch.no_grad():
        return fn()
class group_init_meter:
    """Context manager: wall-time + peak CPU RAM of a group_init / attach-with-calib.

    CorDA accumulates C = E[xx^T] on CPU and runs eigh(d_in^3) -- the expensive corner.
    Use around ll.attach(model, cfg, calibration_data=...) to log that asymmetry.
    """

    def __init__(self):
        self.ms = None
        self.peak_cpu_mb = None

    def __enter__(self):
        import tracemalloc
        self._tm = tracemalloc
        tracemalloc.start()
        self._t0 = time.perf_counter()
        return self

    def __exit__(self, *exc):
        self.ms = (time.perf_counter() - self._t0) * 1e3
        _, peak = self._tm.get_traced_memory()
        self._tm.stop()
        self.peak_cpu_mb = peak / 1e6
        return False
+142
@@ -0,0 +1,142 @@
"""One-row-per-variant cost table: params, MACs/token, fwd/bwd ms, peak GPU, group_init.

Answers "which is best -- time / flops / adds / params?": MACs/token is the
deterministic apples-to-apples compute number; trainable_params is the size headline;
wall-time is the felt-but-noisy number; group_init is where CorDA's eigh(d_in^3) bites.

Usage:
    uv run --extra benchmark python scripts/cost_report.py \
        --model Qwen/Qwen3-0.6B-Base --variants antipasto antipasto_corda antipasto_ablate lora \
        --target-name 'q_proj$' 'v_proj$' --r 32 --out logs/cost_qwen0.6b.log

Point --target-name at down_proj to see the CorDA covariance corner (large d_in).
"""
from __future__ import annotations
import argparse
import importlib.util
import sys
from pathlib import Path
import torch
from tabulate import tabulate
import lora_lite as ll
_HERE = Path(__file__).resolve().parent
_BENCH = importlib.util.spec_from_file_location("metamath_benchmark", _HERE / "metamath_gsm8k_benchmark.py")
benchmark = importlib.util.module_from_spec(_BENCH)
sys.modules[_BENCH.name] = benchmark
_BENCH.loader.exec_module(benchmark)
_COST = importlib.util.spec_from_file_location("_cost", _HERE / "_cost.py")
cost = importlib.util.module_from_spec(_COST)
sys.modules[_COST.name] = cost
_COST.loader.exec_module(cost)
def build_cfg(variant: str, args, dtype) -> ll.AdapterConfig:
    """Reuse the benchmark's variant->config map; only need r/targets/dtype here."""
    bcfg = benchmark.BenchmarkConfig(
        model=args.model, variant=variant, r=args.r, alpha=float(args.r),
        target_name=list(args.target_name), layers=args.layers, torch_dtype=args.dtype,
        antipasto_cov_orient=args.cov_orient,
    )
    return benchmark.cfg_for_variant(bcfg, dtype)
def main() -> None:
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", default="Qwen/Qwen3-0.6B-Base")
    ap.add_argument("--variants", nargs="+",
                    default=["lora", "antipasto", "antipasto_rot", "antipasto_corda", "antipasto_ablate"])
    ap.add_argument("--target-name", nargs="+", default=[r"q_proj$", r"v_proj$"])
    ap.add_argument("--r", type=int, default=32)
    ap.add_argument("--layers", default="all",
                    help="'all' or comma list e.g. '0,1' -- limit layers (CorDA down_proj eigh is slow).")
    ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    ap.add_argument("--dtype", default="bfloat16")
    ap.add_argument("--seq-len", type=int, default=256)
    ap.add_argument("--batch", type=int, default=2)
    ap.add_argument("--calib-batches", type=int, default=4)
    ap.add_argument("--cov-orient", action="store_true",
                    help="CorDA-orient antipasto_ablate (measure the eigh corner).")
    ap.add_argument("--out", default="logs/cost.log")
    args = ap.parse_args()

    dtype = getattr(torch, args.dtype)
    # eager attention: FlopCounterMode's sdpa_flop_count asserts on GQA (Qwen3) SDPA
    # shapes (q heads != kv heads). eager uses explicit matmuls it can count.
    from transformers import AutoModelForCausalLM, AutoTokenizer
    tok = AutoTokenizer.from_pretrained(args.model)
    model = AutoModelForCausalLM.from_pretrained(
        args.model, dtype=dtype, attn_implementation="eager"
    ).to(args.device)
    model.eval()

    n_tokens = args.batch * args.seq_len
    ids = torch.randint(0, model.config.vocab_size, (args.batch, args.seq_len), device=args.device)
    calib = [{"input_ids": torch.randint(0, model.config.vocab_size,
                                         (args.batch, args.seq_len), device=args.device)}
             for _ in range(args.calib_batches)]

    def fwd():
        model(input_ids=ids)

    def bwd_step():
        model.zero_grad(set_to_none=True)
        loss = model(input_ids=ids).logits.float().pow(2).mean()
        loss.backward()

    # base (no-adapter) cost, so each row can report the adapter's ADDED MACs/token.
    base = cost.measure_cost(model, fwd, bwd_step_fn=bwd_step, n_tokens=n_tokens)
    base_macs = base["macs_per_token"]
    print(f"base (no adapter): MACs/tok={int(base_macs) if base_macs else None} "
          f"fwd_ms={round(base['fwd_ms'],2)} bwd_ms={round(base['bwd_ms'],2)}")

    # base = no adapter; model params left trainable, so this is the full-finetune
    # GPU-mem reference (its backward stores grads for every weight).
    total_params = sum(p.numel() for p in model.parameters())
    rows = [{
        "variant": "base(full-FT)", "train_params": total_params,
        "fwd_ms": round(base["fwd_ms"], 2), "bwd_ms": round(base["bwd_ms"], 2),
        "peak_GPU_MB": round(base["peak_gpu_mb"], 1) if base["peak_gpu_mb"] else None,
        "added_MACs/tok": 0 if base_macs else None,
        "ginit_ms": 0.0, "ginit_CPU_MB": 0.0,
    }]
    for variant in args.variants:
        cfg = build_cfg(variant, args, dtype)
        # group_init / attach cost (CorDA's eigh + C live here).
        with cost.group_init_meter() as gi:
            ll.attach(model, cfg, calibration_data=calib)
        c = cost.measure_cost(model, fwd, bwd_step_fn=bwd_step, n_tokens=n_tokens)
        ll.detach(model)
        rows.append({
            "variant": variant,
            "train_params": c["trainable_params"],
            "fwd_ms": round(c["fwd_ms"], 2),
            "bwd_ms": round(c["bwd_ms"], 2) if c["bwd_ms"] else None,
            "peak_GPU_MB": round(c["peak_gpu_mb"], 1) if c["peak_gpu_mb"] else None,
            # flat across same-r adapters; kept only as a sanity check, not a comparator.
            "added_MACs/tok": int(c["macs_per_token"] - base_macs) if (c["macs_per_token"] and base_macs) else None,
            "ginit_ms": round(gi.ms, 1),
            "ginit_CPU_MB": round(gi.peak_cpu_mb, 1),
        })
        print(f" {variant}: params={rows[-1]['train_params']} "
              f"peak_GPU_MB={rows[-1]['peak_GPU_MB']} bwd_ms={rows[-1]['bwd_ms']} ginit_ms={rows[-1]['ginit_ms']}")

    table = tabulate(rows, headers="keys", tablefmt="pipe")
    header = (f"# cost report: {args.model} targets={args.target_name} r={args.r} "
              f"seq={args.seq_len} batch={args.batch} dtype={args.dtype}\n"
              f"# COMPARATORS: train_params, peak_GPU_MB (fwd+bwd, process-local max), bwd_ms, ginit_ms.\n"
              f"# added_MACs/tok is flat across same-r adapters (sanity check only).\n"
              f"# ginit_CPU_MB undercounts: tracemalloc misses torch C++ tensor allocs (the CorDA C matrix).\n")
    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(header + table + "\n")
    print("\n" + header + table)
    print(f"\nsaved -> {out_path}")


if __name__ == "__main__":
    main()
+20 -3
@@ -34,6 +34,9 @@ CFG_BY_VARIANT = {
"hra": ll.HRAConfig,
"eva": ll.EVAConfig,
"antipasto": ll.AntiPaSTOConfig,
"antipasto_rot": ll.AntiPaSTORotConfig,
"antipasto_ablate": ll.AntiPaSTOAblateConfig,
"antipasto_corda": ll.AntiPaSTOCorDAConfig,
"road": ll.RoadConfig,
}
@@ -43,7 +46,7 @@ class BenchmarkConfig:
"""MetaMathQA -> GSM8K benchmark config. Tyro turns this into the CLI."""
model: str = "Qwen/Qwen3-0.6B-Base"
variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "road"] = "lora"
variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "antipasto_rot", "antipasto_ablate", "antipasto_corda", "road"] = "lora"
mode: Literal["benchmark", "probe"] = "benchmark"
device: str = "cuda"
torch_dtype: str = "bfloat16"
@@ -52,6 +55,13 @@ class BenchmarkConfig:
alpha: float = 64.0
delora_lambda0: float = 0.1
road_group_size: int = 64
# AntiPaSTO family (gain / corda) runtime knobs.
antipasto_coeff: float = 1.0
antipasto_suppress_only: bool = False
# AntiPaSTO-ablate.
antipasto_ablate_k: int = 1
antipasto_cov_orient: bool = False
# AntiPaSTO-rot (legacy rotation variant) basis to rotate.
antipasto_rotate_basis: Literal["V", "U", "none"] = "V"
target_name: list[str] = field(default_factory=lambda: list(DEFAULT_TARGETS))
layers: str = "all"
@@ -124,8 +134,15 @@ def cfg_for_variant(args: BenchmarkConfig, dtype: torch.dtype) -> ll.AdapterConf
extra = {"lambda0": args.delora_lambda0} if args.variant == "delora" else {}
if args.variant == "road":
extra = {"group_size": args.road_group_size}
if args.variant == "antipasto":
if args.variant == "antipasto_rot":
extra = {"rotate_basis": args.antipasto_rotate_basis}
if args.variant == "antipasto":
extra = {"coeff": args.antipasto_coeff, "suppress_only": args.antipasto_suppress_only}
if args.variant == "antipasto_corda":
extra = {"coeff": args.antipasto_coeff, "suppress_only": args.antipasto_suppress_only}
if args.variant == "antipasto_ablate":
extra = {"coeff": args.antipasto_coeff, "k": args.antipasto_ablate_k,
"cov_orient": args.antipasto_cov_orient}
return CFG_BY_VARIANT[args.variant](
r=args.r,
alpha=args.r if args.variant == "pissa" else args.alpha,
@@ -155,7 +172,7 @@ def count_base_grad_leaks(model: torch.nn.Module) -> int:
def perturb_first_adapter(model: torch.nn.Module) -> None:
priority = ("lora_B", "lora_g", "lora_U", "lora_A", "lora_lambda", "lora_gate", "lora_delta_s", "lora_rot_T", "lora_m", "lora_road_theta", "lora_road_alpha")
priority = ("lora_B", "lora_g", "lora_c", "lora_alpha", "lora_U", "lora_A", "lora_lambda", "lora_gate", "lora_delta_s", "lora_m", "lora_road_theta", "lora_road_alpha")
for key in priority:
for _, p in model.named_parameters():
if p.requires_grad and key in _:
+6
@@ -20,6 +20,9 @@ from .variants.dora import DoRAConfig
from .variants.hra import HRAConfig
from .variants.eva import EVAConfig
from .variants.antipasto import AntiPaSTOConfig
from .variants.antipasto_rot import AntiPaSTORotConfig
from .variants.antipasto_ablate import AntiPaSTOAblateConfig
from .variants.antipasto_corda import AntiPaSTOCorDAConfig
from .variants.road import RoadConfig
__all__ = [
@@ -33,6 +36,9 @@ __all__ = [
"HRAConfig",
"EVAConfig",
"AntiPaSTOConfig",
"AntiPaSTORotConfig",
"AntiPaSTOAblateConfig",
"AntiPaSTOCorDAConfig",
"RoadConfig",
"attach",
"detach",
+4 -1
@@ -1 +1,4 @@
from . import lora, pissa, delora, ia3, dora, hra, eva, antipasto, road # noqa: F401 side-effect: register
from . import ( # noqa: F401 side-effect: register
lora, pissa, delora, ia3, dora, hra, eva, antipasto, road,
antipasto_rot, antipasto_ablate, antipasto_corda,
)
+84 -96
@@ -1,31 +1,48 @@
"""AntiPaSTO: SVD steering with learnable singular-value deltas + block-diagonal Cayley rotation.
"""AntiPaSTO: SVD steering with learnable, bounded singular-value reweighting.
wassname 2026 https://arxiv.org/abs/2601.07473
W = U diag(S) Vh + W_res (top-r SVD; W_res = W - U_r S_r Vh_r)
learn: delta_s (r,), rot_T (n_blocks, bs(bs-1)/2)
R = block_diag(Cayley(skew(rot_T))); Vh_eff = R @ Vh (or U_eff = U @ R.T)
y = x @ W_res.T + ((x @ Vh_eff.T) * (S + delta_s)) @ U_eff.T
learn: g (r,) per-singular-direction gain log/lin-scale
S_eff = S * (1 + ELU(coeff * g)) exp(.) for g<=0, 1+. for g>0
suppress_only: clamp g<=0 -> factor in (0,1], attenuation only
y = x @ W_res.T + ((x @ Vh.T) * S_eff) @ U.T
Identity at t=0: rot_T=0 -> R=I, delta_s~4e-4 -> y ≈ x @ W^T (fp32 SVD round-trip, tiny positive bias on delta_s breaks sign symmetry).
Identity at g=0 (or coeff=0): 1+ELU(0)=1 exactly, so S_eff = S and the output is
x @ W^T up to the one-time SVD-residual rounding. No additive sign-symmetry hack
needed: the basis is frozen, so the direction sign is fixed and exp/(1+.) is
sign-preserving. The 1+ELU shape is chosen over linear (sign-flips at g<-1), exp
(amplification blows up), and tanh (arbitrary bound) -- see forward() for why.
Scope cut vs antipasto3: this is a fine-tuning adapter, not the full runtime
steering interface. There is no per-call alpha, so it does not expose the
bidirectional R(+alpha) / R(-alpha) inference symmetry. The V-basis path uses the
opposite chirality to antipasto3's default U-basis path, so checkpoints are not
portable without a sign/basis convention.
Changes vs the rotation version this replaces:
- Rotation dropped. Rotating Vh/U leaves the interpretable singular basis (the
SVD-direction / Conjecture property), which is the entire point of steering in
S-space, and the Cayley solve was numerically finicky. The basis is now frozen;
the only learned object is the per-direction gain. If you later want
cross-direction mixing, add a *fixed-basis* core U M Vh (M trainable, U/Vh frozen)
rather than rotating -- that keeps the directions interpretable. It is also far
cheaper than PiSSA: a dense r x r core is r^2 params (~= a rank-8 LoRA at r=256),
versus PiSSA's free A,B at r*(d_in+d_out), which drifts off the SVD basis.
- Additive delta_s -> bounded multiplicative S * (1 + ELU(coeff*g)). Multiplicative
is "scaled by S" (uniform *relative* control over an orders-of-magnitude spectrum),
stays positive (no S_eff<0 sign-flip -> no incoherence from that path), and the
1+ELU shape stops the exp blowup. The 4e-4 sign-symmetry hack is gone.
- suppress_only = clamp g<=0 -> factor in (0,1]: attenuation only, structurally
cannot blow up. Matches the eval-awareness use case (turn a direction down).
- coeff: runtime steering scalar (0 = identity, <0 inverts). The per-call alpha
the rotation version lacked.
- group_init activation pooling is configurable: 'rms' weights outliers (ASVD
intuition), 'mean_abs' is the original outlier-robust pooling.
Refs:
- paper: https://github.com/wassname/AntiPaSTO
- lite port of: https://github.com/wassname/antipasto3
(offline: docs/refs/antipasto3_svd_adapter.py)
- sibling (whitened, rotation-free, mean-diff): steering-lite/.../sspace.py
"""
import math
from dataclasses import dataclass
from typing import Iterable, Literal
import torch
from einops import einsum, rearrange
from einops import rearrange
from jaxtyping import Float
from torch import nn, Tensor as T
@@ -40,35 +57,16 @@ CalibrationData = Iterable[CalibrationBatch]
@dataclass
class AntiPaSTOConfig(AdapterConfig):
variant: str = "antipasto"
# Higher default than LoRA (r=8) since trainable params scale as r + r/bs*bs*(bs-1)/2, not r*(d_in+d_out).
# Only r + r trainable scalars, so r can be large.
r: int = 256
# Block size for the block-diagonal Cayley rotation. r must be divisible by it.
block_size: int = 4
# Cayley map saturation: bounds rotation angle to ~max_rotation_angle radians.
max_rotation_angle: float = 0.5
# Which singular basis to rotate: 'V' (input), 'U' (output), or 'none'.
rotate_basis: Literal["V", "U", "none"] = "V"
def _cayley(skew: torch.Tensor) -> torch.Tensor:
"""R = (I - X)^-1 (I + X) for X = skew/2; preserves orthogonality."""
bs = skew.shape[-1]
eye = torch.eye(bs, dtype=skew.dtype, device=skew.device).expand_as(skew)
X = skew / 2
return torch.linalg.solve(eye - X, eye + X)
def _build_rotation(rot_T: torch.Tensor, bs: int, max_angle: float) -> torch.Tensor:
"""rot_T: (n_blocks, bs*(bs-1)/2) -> R: (n_blocks, bs, bs) Cayley rotation."""
n_blocks, _ = rot_T.shape
rows, cols = torch.triu_indices(bs, bs, offset=1, device=rot_T.device).unbind(0)
A = torch.zeros(n_blocks, bs, bs, dtype=rot_T.dtype, device=rot_T.device)
A[:, rows, cols] = rot_T
A = 0.5 * (A - A.transpose(-1, -2))
a_limit = 2.0 * math.tan(max_angle / 2.0)
A = a_limit * torch.tanh(A / a_limit)
return _cayley(A)
# Per-direction reweighting is S_eff = S * (1 + ELU(coeff * g)). See forward()
# for the why; identity at g=0 or coeff=0, positive always, no free bound knob.
suppress_only: bool = False # clamp g<=0 -> factor in (0,1]: attenuation only
# Runtime steering scale. 0 = identity. <0 inverts (swaps amplify/suppress).
coeff: float = 1.0
# group_init Wanda-style pooling of |X @ Vh[i]|: 'rms' is outlier-sensitive
# (ASVD intuition), 'mean_abs' is the original outlier-robust pooling.
act_pool: Literal["rms", "mean_abs"] = "rms"
@register
@@ -78,24 +76,14 @@ class AntiPaSTO:
@staticmethod
def param_specs(d_in, d_out, cfg):
r = cfg.r
bs = int(cfg.block_size)
if r % bs != 0:
raise ValueError(f"AntiPaSTO requires r={r} divisible by block_size={bs}")
specs = dict(
# Frozen SVD components captured at init.
return dict(
# Frozen top-r SVD captured at init.
lora_U=ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True),
lora_S=ParamSpec((r,), init="zeros", trainable=False, as_buffer=True),
lora_Vh=ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True),
# Trainable: per-singular-value delta.
# antipasto3 uses 4e-4 + N(0, 4e-4): small positive bias breaks sign
# symmetry (rotation alone can't); zero-init works but trains slower.
lora_delta_s=ParamSpec((r,), init=lambda t: t.normal_(0, 4e-4).add_(4e-4)),
# Trainable per-direction log-scale. init 0 -> 1+ELU(0)=1 -> identity.
lora_g=ParamSpec((r,), init="zeros"),
)
if cfg.rotate_basis != "none":
n_blocks = r // bs
n_triu = bs * (bs - 1) // 2
specs["lora_rot_T"] = ParamSpec((n_blocks, n_triu), init="zeros")
return specs
@staticmethod
def init(layer: nn.Module, cfg) -> None:
@@ -114,17 +102,18 @@ class AntiPaSTO:
layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype))
W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
layer.weight.data.copy_(W_res)
# group_init() refines this to input-aligned directions if calibration_data is given.
# group_init() refines the dimension selection if calibration_data is given.
@staticmethod
def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None:
"""Wanda-style data-driven dimension selection within the weight SVD.
"""Wanda-style, data-driven dimension selection within the weight SVD.
init() picks the top-r singular dimensions by S alone (PiSSA-style).
group_init() re-selects based on S[i] * mean|X @ Vh[i]|: dimensions
that are both large in W AND active given real inputs.
If calibration_data is None the weight-SVD init from init() is kept.
group_init() re-selects by score[i] = S[i] * pool|X @ Vh[i]|: dimensions
that are both large in W AND active on real inputs. pool = 'rms' (outlier-
sensitive, the ASVD intuition that activation outliers carry signal) or
'mean_abs' (the original, outlier-robust). If calibration_data is None the
weight-SVD init from init() is kept.
"""
if calibration_data is None:
return
@@ -160,6 +149,7 @@ class AntiPaSTO:
h.remove()
r = cfg.r
pool = cfg.act_pool
for name, layer in layers.items():
X = torch.cat(captured[name], dim=0) # (N, d_in)
if X.shape[0] < r:
@@ -167,17 +157,19 @@ class AntiPaSTO:
f"AntiPaSTO at {name}: only {X.shape[0]} calibration tokens, need >= r={r}"
)
# Recover W_orig: init() wrote W_res into layer.weight and stored top-r components
# Recover W_orig: init() wrote W_res into layer.weight and stored top-r.
W_res = layer.weight.data.float()
U_old = layer.lora_U.float() # (d_out, r)
S_old = layer.lora_S.float() # (r,)
Vh_old = layer.lora_Vh.float() # (r, d_in)
U_old = layer.lora_U.float()
S_old = layer.lora_S.float()
Vh_old = layer.lora_Vh.float()
W_orig = W_res + (U_old * S_old.unsqueeze(0)) @ Vh_old
# Full SVD to score all dimensions
U_full, S_full, Vh_full = torch.linalg.svd(W_orig, full_matrices=False)
# score[i] = S[i] * mean|X @ Vh[i]| (Wanda: weight magnitude × activation magnitude)
act_mag = (X @ Vh_full.T).abs().mean(dim=0) # (k,)
proj = X.to(Vh_full) @ Vh_full.T # (N, k) input in S-coords (X captured on CPU)
if pool == "rms":
act_mag = proj.pow(2).mean(dim=0).sqrt() # outlier-sensitive
else:
act_mag = proj.abs().mean(dim=0) # outlier-robust (original)
scores = S_full * act_mag
idx = scores.argsort(descending=True)[:r] # top-r by joint importance
idx = idx.sort().values # stable ordering
@@ -198,35 +190,31 @@ class AntiPaSTO:
y: Float[T, '*B o'],
) -> Float[T, '*B o']:
cfg = layer._lora_cfg
bs = int(cfg.block_size)
max_angle = float(cfg.max_rotation_angle)
rotate_basis = cfg.rotate_basis
U = layer.lora_U.to(x.dtype) # (d_out, r)
S = layer.lora_S.to(x.dtype) # (r,)
Vh = layer.lora_Vh.to(x.dtype) # (r, d_in)
g = layer.lora_g.to(x.dtype) # (r,)
coeff = float(cfg.coeff)
if rotate_basis == "none":
U_eff, Vh_eff = U, Vh
else:
R_blocks = _build_rotation(layer.lora_rot_T.float(), bs, max_angle).to(x.dtype)
n_blocks = R_blocks.shape[0] # R_blocks: (n, bs, bs)
if rotate_basis == "V":
Vh_blocks = rearrange(Vh, "(n a) i -> n a i", n=n_blocks)
Vh_rot = einsum(R_blocks, Vh_blocks, "n a b, n b i -> n a i")
Vh_eff = rearrange(Vh_rot, "n a i -> (n a) i")
U_eff = U
elif rotate_basis == "U":
U_blocks = rearrange(U, "d (n b) -> d n b", n=n_blocks)
U_rot = einsum(U_blocks, R_blocks, "d n b, n c b -> d n c")
U_eff = rearrange(U_rot, "d n c -> d (n c)")
Vh_eff = Vh
else:
raise ValueError(f"rotate_basis must be 'U', 'V', or 'none', got {rotate_basis!r}")
if cfg.suppress_only:
g = torch.clamp(g, max=0.0) # factor in (0,1]: attenuation only
# FIXME: try lora_delta_s as [r,k], because the main limit of this adapter is that it's under-parameterised here: `reduce(h @ U_eff.T, '... k -> ...')`. But have to make sure it's not linearly reducible to one adapter.
S_eff = S + layer.lora_delta_s.to(x.dtype) # (r,)
h = x @ Vh_eff.T # x @ Vh_eff.T
h = h * S_eff # diag(S_eff)
delta = h @ U_eff.T # @ U_eff.T
return y + delta
# Per-direction reweighting: S_eff = S * (1 + ELU(coeff * g)).
# 1 + ELU(z) = exp(z) for z<=0, 1+z for z>0.
# Why this and not the obvious ones (all of which we tried):
# linear S*(1+z) : constant gradient (stable), but z<-1 -> S_eff<0,
# a sign flip that drives incoherence. Unstable in
# the negatives.
# exp S*exp(z) : positive, but unbounded and the gradient self-
# amplifies (d/dz exp = exp), so amplification blows up.
# tanh S*exp(c*tanh z): bounded, but c is an arbitrary free knob with no
# principled value, and saturation kills the gradient.
# 1+ELU : uses each in its safe regime -- exp only where it is
# bounded in (0,1] (attenuation, cannot go negative),
# linear where exp would diverge (amplification, const
# gradient). C1 at z=0 (both -> 1, slope 1); >0 always.
# coeff=0 or g=0 -> S_eff = S (identity). coeff<0 swaps amplify/suppress.
S_eff = S * (1.0 + torch.nn.functional.elu(coeff * g))
h = (x @ Vh.T) * S_eff # input in S-coords, reweighted
return y + h @ U.T
+194
@@ -0,0 +1,194 @@
"""AntiPaSTO-Ablate: trainable directional ablation in the weight-SVD output basis.
A contractive sibling of antipasto.py. Instead of reweighting the singular gains,
it projects out a learned direction in the *output* singular basis (the U side):
W = U diag(S) Vh + W_res
learn: c (r, k) ablation directions, alpha (k,) strengths in [0, 1]
Chat = orthonormal(c) # k unit dirs in S-space
h = (x @ Vh.T) * S # output S-coords (= diag(S) Vh x)
h <- h - coeff * ((h @ Chat) * alpha) @ Chat.T # project the span out
y = x @ W_res.T + h @ U.T
Why this instead of gain reweighting (antipasto.py):
- The core (I - alpha Chat Chat^T) is a CONTRACTION: eigenvalues are 1-alpha along
Chat and 1 elsewhere, all in [0, 1] for alpha in [0, 1]. It cannot amplify and
cannot blow up, so the failure mode the multiplicative gain fights with bounds is
structurally absent. It is also the natural core to recurse (a contraction composed
with itself converges; an amplifier diverges).
- It is the trainable form of directional ablation (Arditi+ 2024). Ablating Chat in
the middle removes output direction U Chat; for a residual *writer*
(mlp.down_proj, self_attn.o_proj) that is a residual-stream direction -- the
SURGICAL regime in the steering-lite sweeps (directional_ablation topped SI).
Target writers, not all Linears, or you get the broad-suppression regime.
Runtime: coeff is the per-call knob. coeff=0 -> identity. coeff in (0, 1] -> ablate.
coeff < 0 -> *add* the direction back (amplify) -- the bidirectional dual; this is the
side that can grow, so bound coeff there.
Init: alpha small (>0 so c receives gradient), c random-normalized. The strong init is
to warm-start c from the contrastive direction dS in S-space (extract it exactly like
sspace.py: dS = mean(xS_pos) - mean(xS_neg) on persona-branching pairs), then fine-tune.
"""
from dataclasses import dataclass
from typing import Iterable
import torch
from einops import rearrange
from jaxtyping import Float
from torch import nn, Tensor as T
from ..variant import register, ParamSpec
from ..config import AdapterConfig, register_config
CalibrationBatch = dict | tuple | list | T
CalibrationData = Iterable[CalibrationBatch]
ε = 1e-6
@register_config
@dataclass
class AntiPaSTOAblateConfig(AdapterConfig):
    variant: str = "antipasto_ablate"
    r: int = 256 # top-r SVD captured (or |dS|-selected via group_init)
    k: int = 1 # number of ablation directions (rank of the projection)
    init_alpha: float = 0.05 # small >0 so c gets gradient at step 0
    coeff: float = 1.0 # runtime: 0=identity, (0,1]=ablate, <0=amplify (bound this side)
    # CorDA-orient the basis from input covariance (group_init, needs calibration_data).
    # The ablation is OUTPUT-side and CorDA's U stays orthonormal, so this is a clean
    # contraction; the win is at low r -- the data-oriented top-r captures the behavior
    # output direction that plain-SVD top-r drops (measured 1.00 vs 0.65 at r=16).
    cov_orient: bool = False
    cov_eps: float = 1e-3

@register
class AntiPaSTOAblate:
    name = "antipasto_ablate"
    @staticmethod
    def param_specs(d_in, d_out, cfg):
        r, k = cfg.r, cfg.k
        return dict(
            lora_U=ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True),
            lora_S=ParamSpec((r,), init="zeros", trainable=False, as_buffer=True),
            lora_Vh=ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True),
            # Trainable: k ablation directions in S-space, and their strengths.
            lora_c=ParamSpec((r, k), init=lambda t: t.normal_(0, 1.0 / max(r, 1) ** 0.5)),
            lora_alpha=ParamSpec((k,), init=lambda t: t.fill_(float(cfg.init_alpha))),
        )
    @staticmethod
    def init(layer: nn.Module, cfg) -> None:
        if type(layer) is not nn.Linear:
            raise TypeError("AntiPaSTOAblate mutates layer.weight into W_res; nn.Linear only.")
        with torch.no_grad():
            W = layer.weight.data.float()
            U, S, Vh = torch.linalg.svd(W, full_matrices=False)
            r = cfg.r
            Ur, Sr, Vhr = U[:, :r], S[:r], Vh[:r, :]
            layer.lora_U.copy_(Ur.to(layer.lora_U.dtype))
            layer.lora_S.copy_(Sr.to(layer.lora_S.dtype))
            layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype))
            W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
            layer.weight.data.copy_(W_res)
        # Optional but recommended, and not done by group_init() below: warm-start
        # lora_c from the S-space contrastive direction dS (see sspace.py extract).
        # Random init also trains, just slower and with no guarantee it finds the
        # behavior direction.
    @staticmethod
    def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None:
        """If cov_orient, re-orient each target's SVD by input covariance C=E[x x^T]
        (CorDA) so the data-relevant output directions land in the top-r and the
        behavior direction is fully ablatable at low r. No-op otherwise (keeps the
        plain-SVD basis from init()). C is accumulated on CPU; for down_proj's large
        d_in this is heavy -- exclude it or use plain ablation there."""
        if not getattr(cfg, "cov_orient", False) or calibration_data is None:
            return
        layers = {name: layer for name, layer, _ in targets}
        cov: dict[str, T] = {}
        cnt: dict[str, int] = {n: 0 for n in layers}
        def make_hook(name):
            def _h(module, args, kwargs):
                x = rearrange(args[0].detach(), "... d -> (...) d").to(torch.float32).cpu()
                g = x.T @ x
                cov[name] = g if name not in cov else cov[name] + g
                cnt[name] += x.shape[0]
            return _h
        handles = [layers[n].register_forward_pre_hook(make_hook(n), with_kwargs=True) for n in layers]
        try:
            was_training = model.training
            model.eval()
            with torch.no_grad():
                for batch in calibration_data:
                    if isinstance(batch, dict):
                        model(**batch)
                    elif isinstance(batch, (list, tuple)):
                        model(*batch)
                    else:
                        model(batch)
            if was_training:
                model.train()
        finally:
            for h in handles:
                h.remove()
        r, eps = cfg.r, float(cfg.cov_eps)
        for name, layer in layers.items():
            if cnt[name] < r:
                raise RuntimeError(f"AntiPaSTOAblate at {name}: {cnt[name]} tokens, need >= r={r}")
            W_res = layer.weight.data.float().cpu()
            U_old, S_old, Vh_old = (layer.lora_U.float().cpu(),
                                    layer.lora_S.float().cpu(),
                                    layer.lora_Vh.float().cpu())
            W_orig = W_res + (U_old * S_old) @ Vh_old
            C = cov[name] / cnt[name]
            lam, Q = torch.linalg.eigh(C)
            lam = lam.clamp_min(0) + eps
            Chalf = (Q * lam.sqrt()) @ Q.T
            Cinvhalf = (Q * lam.rsqrt()) @ Q.T
            Ut, St, Vht = torch.linalg.svd(W_orig @ Chalf, full_matrices=False)
            Ur = Ut[:, :r] # orthonormal output basis (ablation acts here)
            Sr = St[:r]
            Pr = Vht[:r] @ Cinvhalf # oblique input projector (input-side only)
            W_res_new = W_orig - (Ur * Sr) @ Pr
            with torch.no_grad():
                layer.lora_U.copy_(Ur.to(layer.lora_U))
                layer.lora_S.copy_(Sr.to(layer.lora_S))
                layer.lora_Vh.copy_(Pr.to(layer.lora_Vh)) # store P in the Vh slot
                layer.weight.data.copy_(W_res_new.to(layer.weight))
    @staticmethod
    def _orthonormal(c: T) -> T:
        """(r, k) -> (r, k) with orthonormal columns. k=1 is a plain normalize."""
        if c.shape[-1] == 1:
            return c / (c.norm(dim=0, keepdim=True) + ε)
        # geqrf has no bf16/fp16 kernel (CPU or CUDA); do the QR in fp32, cast back.
        q, _ = torch.linalg.qr(c.float()) # reduced QR; columns orthonormal
        return q.to(c.dtype)
    @staticmethod
    def forward(
        layer: nn.Module,
        x: Float[T, '*B i'],
        y: Float[T, '*B o'],
    ) -> Float[T, '*B o']:
        cfg = layer._lora_cfg
        U = layer.lora_U.to(x.dtype) # (d_out, r)
        S = layer.lora_S.to(x.dtype) # (r,)
        Vh = layer.lora_Vh.to(x.dtype) # (r, d_in)
        Chat = AntiPaSTOAblate._orthonormal(layer.lora_c.to(x.dtype)) # (r, k)
        alpha = layer.lora_alpha.to(x.dtype).clamp(0.0, 1.0) # (k,)
        coeff = float(cfg.coeff)
        h = (x @ Vh.T) * S # (..., r) output S-coords
        proj = h @ Chat # (..., k) component along each dir
        # contractive removal: h <- h - coeff * Sum_j alpha_j (h . chat_j) chat_j
        h = h - coeff * (proj * alpha) @ Chat.T # (..., r)
        return y + h @ U.T
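A small sketch of the k=1 ablation core, using plain Python lists so it runs without torch. It is meant only to illustrate the contraction claim in the docstring: the component along the unit direction is scaled by `1 - coeff * alpha` and everything orthogonal to it is untouched. The names `ablate` and `l2` are illustrative, not part of the adapter.

```python
import math


def l2(v):
    """Euclidean norm of a list of floats."""
    return math.sqrt(sum(x * x for x in v))


def ablate(h, c, alpha, coeff=1.0):
    """h <- h - coeff * alpha * (h . chat) * chat, with chat = c / ||c||."""
    c_norm = l2(c)
    chat = [v / c_norm for v in c]
    proj = sum(hv * cv for hv, cv in zip(h, chat))
    return [hv - coeff * alpha * proj * cv for hv, cv in zip(h, chat)]


if __name__ == "__main__":
    h = [3.0, 4.0, 0.0]
    c = [1.0, 0.0, 0.0]
    print(ablate(h, c, alpha=0.5))              # half of the first component removed
    print(ablate(h, c, alpha=1.0))              # direction fully projected out
    print(ablate(h, c, alpha=0.5, coeff=0.0))   # coeff = 0 is the identity
    print(ablate(h, c, alpha=0.5, coeff=-1.0))  # negative coeff adds the direction back
```

For `alpha` in [0, 1] and `coeff` in [0, 1] the output norm never exceeds the input norm; the negative-`coeff` case is the one side that grows, which is why the docstring asks for a bound there.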
+203
@@ -0,0 +1,203 @@
"""AntiPaSTO-CorDA: steer in a covariance-ORIENTED basis, not the weight-gain basis.
The complaint that motivates this: plain SVD sorts directions by weight gain ||W v||
on an *isotropic* input. The behaviour you steer lives where the *data* has energy.
Those orderings disagree, so the behaviour smears off the top singular axes and a
top-r crop in the weight basis throws it away. CorDA (Yang+ 2024, arXiv:2406.05223)
re-orients the decomposition by the input covariance C = E[x x^T], so the top
directions are the ones with the most energy *on real activations*.
Decomposition (verified: full-rank reconstruction ~1e-5, and on anisotropic data the
top-r data-truncation error drops ~27x vs plain SVD):
C = E[x x^T] (+ eps I) # input second moment on calibration data
C^{1/2}, C^{-1/2} via eigh(C)
W~ = W C^{1/2}; SVD(W~) = U S V~h
P = V~h C^{-1/2} # (r, d_in) OBLIQUE input projector
W = U diag(S) P (exactly) # so y = x W_res^T + ((x P^T) * S_eff) U^T
S here are the singular values of W weighted by input std, so top-r is the optimal
rank-r in the input-weighted norm E||(W - W_r) x||^2 -- the directions that actually
move the output on your data.
Connection to the shared/differing-basis problem: C is built from pos AND neg inputs
pooled, so P spans the *shared* activation structure (the common encoder) that
chosen-minus-rejected cancels by construction. A trainable gain on this basis can
therefore reach shared structure that contrastive dS extraction is blind to.
Core: rotation-free. S_eff = S * (1 + ELU(coeff * g)). This is exp(coeff*g) on the
attenuation side (coeff*g<=0, bounded, no blow-up) and 1+coeff*g on the amplification
side (coeff*g>0, where exp would diverge). For coeff>0 these are g<0 and g>0; a
negative coeff swaps them. g=0 -> identity. coeff is the runtime knob (0=off).
Basis note: P is OBLIQUE (rows not orthonormal -- C^{-1/2} skews them). That is fine
for gain reweighting (we scale oblique coordinates), and also fine for OUTPUT-side
directional ablation: the obliqueness is input-side only, while ablation acts in the
U/output space where U stays orthonormal. antipasto_ablate has a cov_orient flag that
reuses this basis -- at low r it captures the behavior output direction that plain-SVD
top-r drops (measured 1.00 vs 0.65 at r=16).
Falls back to plain SVD (== antipasto, rotation-free) if no calibration_data.
"""
from dataclasses import dataclass
from typing import Iterable
import torch
import torch.nn.functional as F
from einops import rearrange
from jaxtyping import Float
from torch import nn, Tensor as T
from ..variant import register, ParamSpec
from ..config import AdapterConfig, register_config
CalibrationBatch = dict | tuple | list | T
CalibrationData = Iterable[CalibrationBatch]
@register_config
@dataclass
class AntiPaSTOCorDAConfig(AdapterConfig):
    variant: str = "antipasto_corda"
    r: int = 256
    cov_eps: float = 1e-3 # damping on C eigenvalues; guards C^{-1/2} on rare dirs
    coeff: float = 1.0 # runtime steer knob: 0=identity, scales trained g
    suppress_only: bool = False # clamp g<=0 (attenuate only; no amplification)

def _gain(S: T, g: T, coeff: float, suppress_only: bool) -> T:
    """S_eff = S * (1 + ELU(coeff*g)); exp-bounded attenuation, linear amplification.

    suppress_only clamps g, not coeff*g, so it only guarantees attenuation for coeff >= 0."""
    if suppress_only:
        g = g.clamp(max=0.0)
    return S * (1.0 + F.elu(coeff * g))
@register
class AntiPaSTOCorDA:
    name = "antipasto_corda"
    @staticmethod
    def param_specs(d_in, d_out, cfg):
        r = cfg.r
        return dict(
            lora_U=ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True),
            lora_S=ParamSpec((r,), init="zeros", trainable=False, as_buffer=True),
            # P replaces Vh: oblique covariance-oriented input projector.
            lora_P=ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True),
            # Trainable per-direction log-scale. init 0 -> 1+ELU(0)=1 -> exact identity.
            # No sign-symmetry hack needed (1+ELU is sign-preserving, basis frozen),
            # matching antipasto.py.
            lora_g=ParamSpec((r,), init="zeros"),
        )
    @staticmethod
    def init(layer: nn.Module, cfg) -> None:
        """Plain-SVD fallback so the adapter is valid before group_init. group_init
        refines P/U/S to the covariance-oriented basis when calibration_data is given."""
        if type(layer) is not nn.Linear:
            raise TypeError("AntiPaSTOCorDA mutates layer.weight into W_res; nn.Linear only.")
        with torch.no_grad():
            W = layer.weight.data.float()
            U, S, Vh = torch.linalg.svd(W, full_matrices=False)
            r = cfg.r
            Ur, Sr, Vhr = U[:, :r], S[:r], Vh[:r, :]
            layer.lora_U.copy_(Ur.to(layer.lora_U.dtype))
            layer.lora_S.copy_(Sr.to(layer.lora_S.dtype))
            layer.lora_P.copy_(Vhr.to(layer.lora_P.dtype)) # P := Vh until oriented
            W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
            layer.weight.data.copy_(W_res)
    @staticmethod
    def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None:
        """Re-orient each target's SVD by its input covariance C = E[x x^T].
        Without calibration_data the plain-SVD init from init() is kept (so this
        degrades to antipasto, rotation-free).
        Called by attach() BEFORE any training, so the trainable g is still at its
        zero init when the basis changes -- re-orienting zero gains is a no-op, no
        re-indexing needed. Do not call group_init after training has updated g."""
        if calibration_data is None:
            return
        layers = {name: layer for name, layer, _ in targets}
        # accumulate C = sum x x^T on CPU. Peak GPU cost would otherwise be
        # sum_targets d_in^2 fp32 held at once; for down_proj (d_in=intermediate,
        # e.g. 14336) that is ~0.8 GB *per layer* and OOMs. CPU accumulation bounds
        # GPU use to the live activation; the eigh/SVD below run on CPU (one-time).
        # Diagonal C is NOT a usable shortcut: it misses cross-channel correlation,
        # which is where the orientation gain lives (measured ~= plain SVD).
        # If down_proj's d_in^2 is too big even on CPU/RAM, exclude it from CorDA
        # (leave it on plain antipasto) or use a low-rank C (top-k eig of subsampled
        # inputs) -- not implemented here.
        cov: dict[str, T] = {}
        cnt: dict[str, int] = {n: 0 for n in layers}
        def make_hook(name):
            def _h(module, args, kwargs):
                x = rearrange(args[0].detach(), "... d -> (...) d").to(torch.float32).cpu()
                g = x.T @ x # (d_in, d_in) on CPU
                cov[name] = g if name not in cov else cov[name] + g
                cnt[name] += x.shape[0]
            return _h
        handles = [layers[n].register_forward_pre_hook(make_hook(n), with_kwargs=True) for n in layers]
        try:
            was_training = model.training
            model.eval()
            with torch.no_grad():
                for batch in calibration_data:
                    if isinstance(batch, dict):
                        model(**batch)
                    elif isinstance(batch, (list, tuple)):
                        model(*batch)
                    else:
                        model(batch)
            if was_training:
                model.train()
        finally:
            for h in handles:
                h.remove()
        r, eps = cfg.r, float(cfg.cov_eps)
        for name, layer in layers.items():
            if cnt[name] < r:
                raise RuntimeError(f"AntiPaSTOCorDA at {name}: {cnt[name]} tokens, need >= r={r}")
            # decomposition on CPU (where C lives); copy results back to device buffers.
            W_res = layer.weight.data.float().cpu()
            U_old, S_old, P_old = (layer.lora_U.float().cpu(),
                                   layer.lora_S.float().cpu(),
                                   layer.lora_P.float().cpu())
            W_orig = W_res + (U_old * S_old) @ P_old
            C = cov[name] / cnt[name] # (d_in, d_in) CPU
            lam, Q = torch.linalg.eigh(C)
            lam = lam.clamp_min(0) + eps
            Chalf = (Q * lam.sqrt()) @ Q.T # C^{1/2}
            Cinvhalf = (Q * lam.rsqrt()) @ Q.T # C^{-1/2}
            Ut, St, Vht = torch.linalg.svd(W_orig @ Chalf, full_matrices=False)
            Ur = Ut[:, :r] # (d_out, r)
            Sr = St[:r] # (r,)
            Pr = (Vht[:r] @ Cinvhalf) # (r, d_in) oblique projector
            W_res_new = (W_orig - (Ur * Sr) @ Pr)
            with torch.no_grad():
                layer.lora_U.copy_(Ur.to(layer.lora_U))
                layer.lora_S.copy_(Sr.to(layer.lora_S))
                layer.lora_P.copy_(Pr.to(layer.lora_P))
                layer.weight.data.copy_(W_res_new.to(layer.weight))
    @staticmethod
    def forward(
        layer: nn.Module,
        x: Float[T, '*B i'],
        y: Float[T, '*B o'],
    ) -> Float[T, '*B o']:
        cfg = layer._lora_cfg
        U = layer.lora_U.to(x.dtype) # (d_out, r)
        S = layer.lora_S.to(x.dtype) # (r,)
        P = layer.lora_P.to(x.dtype) # (r, d_in) oblique
        g = layer.lora_g.to(x.dtype) # (r,)
        S_eff = _gain(S, g, float(cfg.coeff), bool(cfg.suppress_only))
        h = (x @ P.T) * S_eff # (..., r)
        return y + h @ U.T
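The covariance-oriented decomposition in `group_init` can be checked on a toy matrix. The sketch below assumes NumPy instead of torch and uses made-up shapes and a hypothetical helper `corda_decompose`; it follows the same steps (eigh of the damped C, SVD of W C^{1/2}, P = V~h C^{-1/2}) and compares the top-r data-weighted truncation error against plain SVD. It does not reproduce the figures quoted in the docstring.

```python
import numpy as np


def corda_decompose(W, X, eps=1e-3):
    """Return (U, S, P) with W == (U * S) @ P at full rank, oriented by C = E[x x^T]."""
    C = X.T @ X / X.shape[0]
    lam, Q = np.linalg.eigh(C)
    lam = np.clip(lam, 0.0, None) + eps           # same damping as cov_eps
    C_half = (Q * np.sqrt(lam)) @ Q.T             # C^{1/2}
    C_invhalf = (Q / np.sqrt(lam)) @ Q.T          # C^{-1/2}
    U, S, Vth = np.linalg.svd(W @ C_half, full_matrices=False)
    return U, S, Vth @ C_invhalf                  # P is oblique (rows not orthonormal)


def data_error(W, W_r, X):
    """Mean squared output error of the rank-r crop on the inputs X."""
    return float(np.mean(np.sum((X @ (W - W_r).T) ** 2, axis=1)))


rng = np.random.default_rng(0)
d_out, d_in, n, r = 6, 8, 512, 3
W = rng.standard_normal((d_out, d_in))
# Anisotropic inputs: each channel gets a different scale.
X = rng.standard_normal((n, d_in)) * np.linspace(0.1, 3.0, d_in)

U, S, P = corda_decompose(W, X)
full_err = float(np.abs((U * S) @ P - W).max())

W_r_corda = (U[:, :r] * S[:r]) @ P[:r]
Up, Sp, Vhp = np.linalg.svd(W, full_matrices=False)
W_r_plain = (Up[:, :r] * Sp[:r]) @ Vhp[:r]
err_corda = data_error(W, W_r_corda, X)
err_plain = data_error(W, W_r_plain, X)

print(f"full-rank reconstruction error: {full_err:.2e}")
print(f"top-{r} data error, CorDA vs plain SVD: {err_corda:.4f} vs {err_plain:.4f}")
```

Because the same damped eigenvalues build both C^{1/2} and C^{-1/2}, the two factors cancel exactly and the full-rank product reproduces W; only the top-r crop differs from plain SVD.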
+232
@@ -0,0 +1,232 @@
"""AntiPaSTO-Rot: the original SVD adapter with learnable singular-value deltas +
block-diagonal Cayley rotation. Kept as a SEPARATE variant so we can benchmark the
rotation version against the rotation-free 1+ELU gain (antipasto.py) head to head.
wassname 2026 https://arxiv.org/abs/2601.07473
W = U diag(S) Vh + W_res (top-r SVD; W_res = W - U_r S_r Vh_r)
learn: delta_s (r,), rot_T (n_blocks, bs(bs-1)/2)
R = block_diag(Cayley(skew(rot_T))); Vh_eff = R @ Vh (or U_eff = U @ R.T)
y = x @ W_res.T + ((x @ Vh_eff.T) * (S + delta_s)) @ U_eff.T
Identity at t=0: rot_T=0 -> R=I, delta_s~4e-4 -> y ~ x @ W^T (tiny positive bias on
delta_s breaks sign symmetry).
Why antipasto.py dropped the rotation: rotating Vh/U leaves the interpretable singular
basis, and the Cayley solve was numerically finicky. This file preserves it for the
all-else-equal comparison (does the cross-direction mixing the rotation buys beat the
cheaper, more stable gain-only adapter on the same targets and budget?).
Refs:
- paper: https://github.com/wassname/AntiPaSTO
- lite port of: https://github.com/wassname/antipasto3
(offline: docs/refs/antipasto3_svd_adapter.py)
"""
import math
from dataclasses import dataclass
from typing import Iterable, Literal
import torch
from einops import einsum, rearrange
from jaxtyping import Float
from torch import nn, Tensor as T
from ..variant import register, ParamSpec
from ..config import AdapterConfig, register_config
CalibrationBatch = dict | tuple | list | T
CalibrationData = Iterable[CalibrationBatch]
@register_config
@dataclass
class AntiPaSTORotConfig(AdapterConfig):
    variant: str = "antipasto_rot"
    # Higher default than LoRA (r=8) since trainable params scale as r + r/bs*bs*(bs-1)/2, not r*(d_in+d_out).
    r: int = 256
    # Block size for the block-diagonal Cayley rotation. r must be divisible by it.
    block_size: int = 4
    # Cayley map saturation: bounds rotation angle to ~max_rotation_angle radians.
    max_rotation_angle: float = 0.5
    # Which singular basis to rotate: 'V' (input), 'U' (output), or 'none'.
    rotate_basis: Literal["V", "U", "none"] = "V"

def _cayley(skew: torch.Tensor) -> torch.Tensor:
    """R = (I - X)^-1 (I + X) for X = skew/2; preserves orthogonality."""
    bs = skew.shape[-1]
    eye = torch.eye(bs, dtype=skew.dtype, device=skew.device).expand_as(skew)
    X = skew / 2
    return torch.linalg.solve(eye - X, eye + X)

def _build_rotation(rot_T: torch.Tensor, bs: int, max_angle: float) -> torch.Tensor:
    """rot_T: (n_blocks, bs*(bs-1)/2) -> R: (n_blocks, bs, bs) Cayley rotation."""
    n_blocks, _ = rot_T.shape
    rows, cols = torch.triu_indices(bs, bs, offset=1, device=rot_T.device).unbind(0)
    A = torch.zeros(n_blocks, bs, bs, dtype=rot_T.dtype, device=rot_T.device)
    A[:, rows, cols] = rot_T
    A = 0.5 * (A - A.transpose(-1, -2))
    a_limit = 2.0 * math.tan(max_angle / 2.0)
    A = a_limit * torch.tanh(A / a_limit)
    return _cayley(A)
@register
class AntiPaSTORot:
    name = "antipasto_rot"
    @staticmethod
    def param_specs(d_in, d_out, cfg):
        r = cfg.r
        bs = int(cfg.block_size)
        if r % bs != 0:
            raise ValueError(f"AntiPaSTORot requires r={r} divisible by block_size={bs}")
        specs = dict(
            # Frozen SVD components captured at init.
            lora_U=ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True),
            lora_S=ParamSpec((r,), init="zeros", trainable=False, as_buffer=True),
            lora_Vh=ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True),
            # Trainable: per-singular-value delta.
            # antipasto3 uses 4e-4 + N(0, 4e-4): small positive bias breaks sign
            # symmetry (rotation alone can't); zero-init works but trains slower.
            lora_delta_s=ParamSpec((r,), init=lambda t: t.normal_(0, 4e-4).add_(4e-4)),
        )
        if cfg.rotate_basis != "none":
            n_blocks = r // bs
            n_triu = bs * (bs - 1) // 2
            specs["lora_rot_T"] = ParamSpec((n_blocks, n_triu), init="zeros")
        return specs
    @staticmethod
    def init(layer: nn.Module, cfg) -> None:
        if type(layer) is not nn.Linear:
            raise TypeError(
                "AntiPaSTORot mutates layer.weight into W_res (like PiSSA), so v1 "
                "only supports plain nn.Linear, not bnb 4/8-bit."
            )
        with torch.no_grad():
            W = layer.weight.data.float()
            U, S, Vh = torch.linalg.svd(W, full_matrices=False)
            r = cfg.r
            Ur, Sr, Vhr = U[:, :r], S[:r], Vh[:r, :]
            layer.lora_U.copy_(Ur.to(layer.lora_U.dtype))
            layer.lora_S.copy_(Sr.to(layer.lora_S.dtype))
            layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype))
            W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
            layer.weight.data.copy_(W_res)
        # group_init() refines this to input-aligned directions if calibration_data is given.
    @staticmethod
    def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None:
        """Wanda-style data-driven dimension selection within the weight SVD.
        init() picks the top-r singular dimensions by S alone (PiSSA-style).
        group_init() re-selects based on S[i] * mean|X @ Vh[i]|: dimensions
        that are both large in W AND active given real inputs.
        If calibration_data is None the weight-SVD init from init() is kept.
        """
        if calibration_data is None:
            return
        layers = {name: layer for name, layer, _ in targets}
        captured: dict[str, list[T]] = {n: [] for n in layers}
        def make_hook(name):
            def _h(module, args, kwargs):
                x = args[0].detach()
                captured[name].append(rearrange(x, "... d -> (...) d").to(torch.float32).cpu())
            return _h
        handles = [
            layers[n].register_forward_pre_hook(make_hook(n), with_kwargs=True)
            for n in layers
        ]
        try:
            was_training = model.training
            model.eval()
            with torch.no_grad():
                for batch in calibration_data:
                    if isinstance(batch, dict):
                        model(**batch)
                    elif isinstance(batch, (list, tuple)):
                        model(*batch)
                    else:
                        model(batch)
            if was_training:
                model.train()
        finally:
            for h in handles:
                h.remove()
        r = cfg.r
        for name, layer in layers.items():
            X = torch.cat(captured[name], dim=0) # (N, d_in)
            if X.shape[0] < r:
                raise RuntimeError(
                    f"AntiPaSTORot at {name}: only {X.shape[0]} calibration tokens, need >= r={r}"
                )
            # Recover W_orig: init() wrote W_res into layer.weight and stored top-r components
            W_res = layer.weight.data.float()
            U_old = layer.lora_U.float() # (d_out, r)
            S_old = layer.lora_S.float() # (r,)
            Vh_old = layer.lora_Vh.float() # (r, d_in)
            W_orig = W_res + (U_old * S_old.unsqueeze(0)) @ Vh_old
            # Full SVD to score all dimensions
            U_full, S_full, Vh_full = torch.linalg.svd(W_orig, full_matrices=False)
            # score[i] = S[i] * mean|X @ Vh[i]| (Wanda: weight magnitude × activation magnitude)
            act_mag = (X.to(Vh_full) @ Vh_full.T).abs().mean(dim=0) # (k,) -- X captured on CPU
            scores = S_full * act_mag
            idx = scores.argsort(descending=True)[:r] # top-r by joint importance
            idx = idx.sort().values # stable ordering
            Ur, Sr, Vhr = U_full[:, idx], S_full[idx], Vh_full[idx]
            W_res_new = (W_orig - (Ur * Sr.unsqueeze(0)) @ Vhr).to(layer.weight.dtype)
            with torch.no_grad():
                layer.lora_U.copy_(Ur.to(layer.lora_U))
                layer.lora_S.copy_(Sr.to(layer.lora_S))
                layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh))
                layer.weight.data.copy_(W_res_new)
    @staticmethod
    def forward(
        layer: nn.Module,
        x: Float[T, '*B i'],
        y: Float[T, '*B o'],
    ) -> Float[T, '*B o']:
        cfg = layer._lora_cfg
        bs = int(cfg.block_size)
        max_angle = float(cfg.max_rotation_angle)
        rotate_basis = cfg.rotate_basis
        U = layer.lora_U.to(x.dtype) # (d_out, r)
        S = layer.lora_S.to(x.dtype) # (r,)
        Vh = layer.lora_Vh.to(x.dtype) # (r, d_in)
        if rotate_basis == "none":
            U_eff, Vh_eff = U, Vh
        else:
            R_blocks = _build_rotation(layer.lora_rot_T.float(), bs, max_angle).to(x.dtype)
            n_blocks = R_blocks.shape[0] # R_blocks: (n, bs, bs)
            if rotate_basis == "V":
                Vh_blocks = rearrange(Vh, "(n a) i -> n a i", n=n_blocks)
                Vh_rot = einsum(R_blocks, Vh_blocks, "n a b, n b i -> n a i")
                Vh_eff = rearrange(Vh_rot, "n a i -> (n a) i")
                U_eff = U
            elif rotate_basis == "U":
                U_blocks = rearrange(U, "d (n b) -> d n b", n=n_blocks)
                U_rot = einsum(U_blocks, R_blocks, "d n b, n c b -> d n c")
                U_eff = rearrange(U_rot, "d n c -> d (n c)")
                Vh_eff = Vh
            else:
                raise ValueError(f"rotate_basis must be 'U', 'V', or 'none', got {rotate_basis!r}")
        S_eff = S + layer.lora_delta_s.to(x.dtype) # (r,)
        h = x @ Vh_eff.T # x @ Vh_eff.T
        h = h * S_eff # diag(S_eff)
        delta = h @ U_eff.T # @ U_eff.T
        return y + delta
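A short NumPy sketch of the Cayley map used by `_cayley` and `_build_rotation`, assuming NumPy rather than torch and a single block. The function names `cayley` and `build_rotation_2x2` are illustrative. It checks two things that the config comments rely on: the result is orthogonal for any skew-symmetric input, and for a 2x2 block the tanh saturation keeps the rotation angle within `max_angle` (for larger blocks the bound is only approximate, as the `~` in the config comment suggests).

```python
import math

import numpy as np


def cayley(skew):
    """R = (I - X)^-1 (I + X) with X = skew / 2, for one square skew-symmetric block."""
    X = skew / 2
    eye = np.eye(skew.shape[-1])
    return np.linalg.solve(eye - X, eye + X)


def build_rotation_2x2(t, max_angle):
    """Single-block analogue of _build_rotation: one free parameter t."""
    A = np.array([[0.0, t], [0.0, 0.0]])
    A = 0.5 * (A - A.T)                        # skew-symmetrise
    a_limit = 2.0 * math.tan(max_angle / 2.0)
    A = a_limit * np.tanh(A / a_limit)         # saturate the entries
    return cayley(A)


def rotation_angle(R):
    """Signed angle of a 2x2 rotation matrix."""
    return math.atan2(R[0, 1], R[0, 0])


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    M = rng.standard_normal((4, 4))
    R4 = cayley(M - M.T)                       # any skew-symmetric input works
    print(np.abs(R4.T @ R4 - np.eye(4)).max()) # close to zero: R4 is orthogonal

    for t in (0.0, 0.1, 1.0, 1e3):
        R2 = build_rotation_2x2(t, max_angle=0.5)
        print(t, rotation_angle(R2))           # magnitude stays within 0.5 rad
```

The linear solve inside `cayley` is the per-forward cost that the rotation-free variants avoid.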
+3 -5
@@ -14,7 +14,6 @@ from __future__ import annotations
import importlib.util
import sys
-from dataclasses import replace
from pathlib import Path
import pytest
@@ -31,11 +30,12 @@ sys.modules[SPEC.name] = benchmark
SPEC.loader.exec_module(benchmark)
-VARIANTS = ["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "road"]
+VARIANTS = ["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva",
+            "antipasto", "antipasto_rot", "antipasto_ablate", "antipasto_corda", "road"]
# Variants that fail loud when attached on a bnb-loaded base (read dense weight in init).
# delora/eva also read weight but currently silently dequant -- they produce sane attach,
# so we don't expect a raise from them in the attach-only smoke.
-BNB_RAISERS = {"pissa", "dora", "antipasto"}
+BNB_RAISERS = {"pissa", "dora", "antipasto", "antipasto_rot", "antipasto_ablate", "antipasto_corda"}
TINY_MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM"
HAS_CUDA = torch.cuda.is_available()
@@ -75,8 +75,6 @@ def quick_cfg(variant: str, tmp_path: Path, quantization: str = "none") -> "benc
log_every=1000,
output_dir=tmp_path / "out",
)
-    if variant == "antipasto":
-        cfg = replace(cfg, alpha=4) # block_size=4 -> need r % 4 == 0
return cfg