Add antipasto_arrow: structured fixed-basis core (cross-direction mixing)

antipasto's diagonal core can only rescale each frozen singular direction; it
can never let direction i's input drive direction j's output, yet the steered
behaviour is an off-axis combination. A dense r x r core fixes that but costs
r^2 params. antipasto_arrow uses the arrowhead structure instead: a dense b x b
block on the top-b singular directions (full coupling where the action lives)
plus a diagonal 1+ELU tail on the rest. b^2 + (r-b) params, one b x b matmul
per forward -- cross-direction mixing at diagonal-core cost, no Cayley solve.

Identity at init (M=0 -> B=I, g=0 -> gain=1). Verified on a Linear: rel_err
1.5e-7 at init; M[i,j] routes input dir j -> output dir i with weight exactly
M[i,j] (diagonal core forces 0); 14 train params at r=8,b=3 vs r^2=64.

Wired into benchmark (antipasto_block knob), smoke (block=2 for r=4), cost
report, and exports.

Co-Authored-By: Claudypoo <noreply@anthropic.com>
This commit is contained in:
wassname
2026-06-14 19:18:59 +08:00
parent b80d7778af
commit 0d40cc9b38
6 changed files with 205 additions and 5 deletions
+2 -1
View File
@@ -49,7 +49,8 @@ def main() -> None:
ap = argparse.ArgumentParser()
ap.add_argument("--model", default="Qwen/Qwen3-0.6B-Base")
ap.add_argument("--variants", nargs="+",
default=["lora", "antipasto", "antipasto_rot", "antipasto_corda", "antipasto_ablate"])
default=["lora", "antipasto", "antipasto_rot", "antipasto_corda",
"antipasto_ablate", "antipasto_arrow"])
ap.add_argument("--target-name", nargs="+", default=[r"q_proj$", r"v_proj$"])
ap.add_argument("--r", type=int, default=32)
ap.add_argument("--layers", default="all",
+7 -1
View File
@@ -37,6 +37,7 @@ CFG_BY_VARIANT = {
"antipasto_rot": ll.AntiPaSTORotConfig,
"antipasto_ablate": ll.AntiPaSTOAblateConfig,
"antipasto_corda": ll.AntiPaSTOCorDAConfig,
"antipasto_arrow": ll.AntiPaSTOArrowConfig,
"road": ll.RoadConfig,
}
@@ -46,7 +47,7 @@ class BenchmarkConfig:
"""MetaMathQA -> GSM8K benchmark config. Tyro turns this into the CLI."""
model: str = "Qwen/Qwen3-0.6B-Base"
variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "antipasto_rot", "antipasto_ablate", "antipasto_corda", "road"] = "lora"
variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "antipasto_rot", "antipasto_ablate", "antipasto_corda", "antipasto_arrow", "road"] = "lora"
mode: Literal["benchmark", "probe"] = "benchmark"
device: str = "cuda"
torch_dtype: str = "bfloat16"
@@ -63,6 +64,8 @@ class BenchmarkConfig:
antipasto_cov_orient: bool = False
# AntiPaSTO-rot (legacy rotation variant) basis to rotate.
antipasto_rotate_basis: Literal["V", "U", "none"] = "V"
# AntiPaSTO-arrow: dense interaction block size on the top-b singular directions.
antipasto_block: int = 8
target_name: list[str] = field(default_factory=lambda: list(DEFAULT_TARGETS))
layers: str = "all"
train_dataset: str = "meta-math/MetaMathQA"
@@ -143,6 +146,9 @@ def cfg_for_variant(args: BenchmarkConfig, dtype: torch.dtype) -> ll.AdapterConf
if args.variant == "antipasto_ablate":
extra = {"coeff": args.antipasto_coeff, "k": args.antipasto_ablate_k,
"cov_orient": args.antipasto_cov_orient}
if args.variant == "antipasto_arrow":
extra = {"coeff": args.antipasto_coeff, "suppress_only": args.antipasto_suppress_only,
"block": args.antipasto_block}
return CFG_BY_VARIANT[args.variant](
r=args.r,
alpha=args.r if args.variant == "pissa" else args.alpha,
+2
View File
@@ -23,6 +23,7 @@ from .variants.antipasto import AntiPaSTOConfig
from .variants.antipasto_rot import AntiPaSTORotConfig
from .variants.antipasto_ablate import AntiPaSTOAblateConfig
from .variants.antipasto_corda import AntiPaSTOCorDAConfig
from .variants.antipasto_arrow import AntiPaSTOArrowConfig
from .variants.road import RoadConfig
__all__ = [
@@ -39,6 +40,7 @@ __all__ = [
"AntiPaSTORotConfig",
"AntiPaSTOAblateConfig",
"AntiPaSTOCorDAConfig",
"AntiPaSTOArrowConfig",
"RoadConfig",
"attach",
"detach",
+1 -1
View File
@@ -1,4 +1,4 @@
from . import ( # noqa: F401 side-effect: register
lora, pissa, delora, ia3, dora, hra, eva, antipasto, road,
antipasto_rot, antipasto_ablate, antipasto_corda,
antipasto_rot, antipasto_ablate, antipasto_corda, antipasto_arrow,
)
+188
View File
@@ -0,0 +1,188 @@
"""AntiPaSTO-Arrow: a STRUCTURED fixed-basis core, the cheap way to add cross-
direction mixing that plain antipasto (a diagonal gain) cannot express.
antipasto's core is diagonal: S_eff = S * (1 + ELU(coeff*g)) reweights each frozen
singular direction independently. It can turn a direction up or down but it can never
let direction i's input drive direction j's output. Yet the behaviour you steer is a
combination Sigma c_i v_i that generically lies OFF any single axis (the same argument
that motivates antipasto_corda), so a diagonal core can only ever approximate it.
The obvious fix -- a full dense r x r core M, DeltaW = U M Vh -- restores all mixing but
costs r^2 params (r=256 -> 65536, a rank-8 LoRA's worth) and an r x r matmul per forward.
antipasto.py's own header flags this trap: "a dense r x r core is r^2 params ... add a
*fixed-basis* core U M Vh rather than rotating". This file is that core, made cheap by
making it STRUCTURED instead of dense -- an arrowhead, not an r x r.
Arrowhead structure (dense top-block + diagonal tail):
core C (r x r, acting on the S-scaled coords) =
[ B (b x b dense) | 0 ] B couples the top-b directions
[ 0 | diag(1+ELU(c*g)) ] tail = exactly antipasto's gain
DeltaW = U @ C @ diag(S) @ Vh
The top b singular directions (largest S = where PiSSA says the action lives) get a full
b x b interaction block B = I_b + coeff*M; the remaining r-b stay on the cheap bounded
diagonal gain. Cost is b^2 + (r-b) params and one b x b matmul per forward -- for b=8,r=256
that is 312 params and a 64-MAC corner, versus 65536 for dense r x r and versus the
rotation variant's per-forward Cayley solve (measured 72ms vs 36ms). So: cross-direction
mixing where it matters, at diagonal-core cost.
(We call it "arrowhead" after the shape -- a dense head with a diagonal shaft. A true
numerical-LA arrowhead also carries a hub row+column coupling the block to the tail; that
would add 2(r-b) params and is a one-line extension if the top-b span turns out too small.
Not added until measured to be needed.)
Identity at init: M=0 -> B=I_b, g=0 -> 1+ELU(0)=1, so C=I and DeltaW = U diag(S) Vh exactly
(up to the one-time SVD-residual rounding). coeff=0 -> C=I too (runtime off). The block is
the linear-amplification regime of antipasto's design (a matmul, constant-gradient, no exp
self-amplification); it is stable like 1+ELU's upper branch, not strictly bounded -- if you
need the tail's structural can't-blow-up guarantee on the top directions too, use
antipasto_ablate instead.
Refs: antipasto.py (diagonal sibling), antipasto_corda.py (the off-axis argument).
"""
from dataclasses import dataclass
from typing import Iterable, Literal
import torch
import torch.nn.functional as F
from einops import rearrange
from jaxtyping import Float
from torch import nn, Tensor as T
from ..variant import register, ParamSpec
from ..config import AdapterConfig, register_config
CalibrationBatch = dict | tuple | list | T
CalibrationData = Iterable[CalibrationBatch]
@register_config
@dataclass
class AntiPaSTOArrowConfig(AdapterConfig):
variant: str = "antipasto_arrow"
r: int = 256
# Size of the dense interaction block on the top-b singular directions. The ONLY
# quadratic cost (b^2 params); keep small. b=1 degenerates to antipasto.
block: int = 8
suppress_only: bool = False # clamp the tail g<=0 (attenuate only); block unaffected
coeff: float = 1.0 # runtime knob: 0=identity, scales both block and tail
act_pool: Literal["rms", "mean_abs"] = "rms" # group_init selection, see antipasto
@register
class AntiPaSTOArrow:
name = "antipasto_arrow"
@staticmethod
def param_specs(d_in, d_out, cfg):
r, b = cfg.r, cfg.block
if not 1 <= b < r:
raise ValueError(f"antipasto_arrow needs 1 <= block({b}) < r({r}).")
return dict(
lora_U=ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True),
lora_S=ParamSpec((r,), init="zeros", trainable=False, as_buffer=True),
lora_Vh=ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True),
# Dense b x b interaction on the top-b directions. init 0 -> B=I -> identity.
lora_M=ParamSpec((b, b), init="zeros"),
# Diagonal bounded gain on the remaining r-b directions (== antipasto's g).
lora_g=ParamSpec((r - b,), init="zeros"),
)
@staticmethod
def init(layer: nn.Module, cfg) -> None:
if type(layer) is not nn.Linear:
raise TypeError("AntiPaSTOArrow mutates layer.weight into W_res; nn.Linear only.")
with torch.no_grad():
W = layer.weight.data.float()
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
r = cfg.r
Ur, Sr, Vhr = U[:, :r], S[:r], Vh[:r, :]
layer.lora_U.copy_(Ur.to(layer.lora_U.dtype))
layer.lora_S.copy_(Sr.to(layer.lora_S.dtype))
layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype))
W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
layer.weight.data.copy_(W_res)
@staticmethod
def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None:
"""Wanda-style data-driven dimension selection, identical to antipasto: re-pick
the top-r directions by S[i] * pool|X @ Vh[i]|. Runs before training (g, M at
their zero init), so re-selecting the basis is a harmless no-op on the core."""
if calibration_data is None:
return
layers = {name: layer for name, layer, _ in targets}
captured: dict[str, list[T]] = {n: [] for n in layers}
def make_hook(name):
def _h(module, args, kwargs):
x = args[0].detach()
captured[name].append(rearrange(x, "... d -> (...) d").to(torch.float32).cpu())
return _h
handles = [layers[n].register_forward_pre_hook(make_hook(n), with_kwargs=True) for n in layers]
try:
was_training = model.training
model.eval()
with torch.no_grad():
for batch in calibration_data:
if isinstance(batch, dict):
model(**batch)
elif isinstance(batch, (list, tuple)):
model(*batch)
else:
model(batch)
if was_training:
model.train()
finally:
for h in handles:
h.remove()
r, pool = cfg.r, cfg.act_pool
for name, layer in layers.items():
X = torch.cat(captured[name], dim=0)
if X.shape[0] < r:
raise RuntimeError(f"AntiPaSTOArrow at {name}: {X.shape[0]} tokens, need >= r={r}")
W_res = layer.weight.data.float()
W_orig = W_res + (layer.lora_U.float() * layer.lora_S.float()) @ layer.lora_Vh.float()
U_full, S_full, Vh_full = torch.linalg.svd(W_orig, full_matrices=False)
proj = X.to(Vh_full) @ Vh_full.T
act_mag = proj.pow(2).mean(0).sqrt() if pool == "rms" else proj.abs().mean(0)
idx = (S_full * act_mag).argsort(descending=True)[:r].sort().values
Ur, Sr, Vhr = U_full[:, idx], S_full[idx], Vh_full[idx]
W_res_new = (W_orig - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
with torch.no_grad():
layer.lora_U.copy_(Ur.to(layer.lora_U))
layer.lora_S.copy_(Sr.to(layer.lora_S))
layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh))
layer.weight.data.copy_(W_res_new)
@staticmethod
def forward(
layer: nn.Module,
x: Float[T, '*B i'],
y: Float[T, '*B o'],
) -> Float[T, '*B o']:
cfg = layer._lora_cfg
U = layer.lora_U.to(x.dtype) # (d_out, r)
S = layer.lora_S.to(x.dtype) # (r,)
Vh = layer.lora_Vh.to(x.dtype) # (r, d_in)
M = layer.lora_M.to(x.dtype) # (b, b)
g = layer.lora_g.to(x.dtype) # (r-b,)
coeff, b = float(cfg.coeff), cfg.block
cS = (x @ Vh.T) * S # (..., r) = diag(S) Vh x
# Top-b: dense block B = I_b + coeff*M couples the top singular directions.
eye = torch.eye(b, dtype=x.dtype, device=x.device)
top = cS[..., :b] @ (eye + coeff * M).T # (..., b)
# Tail: antipasto's bounded diagonal gain (see antipasto.py for the 1+ELU why).
if cfg.suppress_only:
g = torch.clamp(g, max=0.0)
tail = cS[..., b:] * (1.0 + F.elu(coeff * g)) # (..., r-b)
h = torch.cat([top, tail], dim=-1) # (..., r)
return y + h @ U.T
+5 -2
View File
@@ -31,11 +31,13 @@ SPEC.loader.exec_module(benchmark)
VARIANTS = ["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva",
"antipasto", "antipasto_rot", "antipasto_ablate", "antipasto_corda", "road"]
"antipasto", "antipasto_rot", "antipasto_ablate", "antipasto_corda",
"antipasto_arrow", "road"]
# Variants that fail loud when attached on a bnb-loaded base (read dense weight in init).
# delora/eva also read weight but currently silently dequant -- they produce sane attach,
# so we don't expect a raise from them in the attach-only smoke.
BNB_RAISERS = {"pissa", "dora", "antipasto", "antipasto_rot", "antipasto_ablate", "antipasto_corda"}
BNB_RAISERS = {"pissa", "dora", "antipasto", "antipasto_rot", "antipasto_ablate",
"antipasto_corda", "antipasto_arrow"}
TINY_MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM"
HAS_CUDA = torch.cuda.is_available()
@@ -57,6 +59,7 @@ def quick_cfg(variant: str, tmp_path: Path, quantization: str = "none") -> "benc
quantization=quantization,
r=4,
alpha=8,
antipasto_block=2, # antipasto_arrow needs block < r (r=4 here)
target_name=target_name,
layers="all",
steps=2,