diff --git a/scripts/cost_report.py b/scripts/cost_report.py
index 75534f8..df36457 100644
--- a/scripts/cost_report.py
+++ b/scripts/cost_report.py
@@ -49,7 +49,8 @@ def main() -> None:
     ap = argparse.ArgumentParser()
     ap.add_argument("--model", default="Qwen/Qwen3-0.6B-Base")
     ap.add_argument("--variants", nargs="+",
-                    default=["lora", "antipasto", "antipasto_rot", "antipasto_corda", "antipasto_ablate"])
+                    default=["lora", "antipasto", "antipasto_rot", "antipasto_corda",
+                             "antipasto_ablate", "antipasto_arrow"])
     ap.add_argument("--target-name", nargs="+", default=[r"q_proj$", r"v_proj$"])
     ap.add_argument("--r", type=int, default=32)
     ap.add_argument("--layers", default="all",
diff --git a/scripts/metamath_gsm8k_benchmark.py b/scripts/metamath_gsm8k_benchmark.py
index 867ec45..a335e54 100644
--- a/scripts/metamath_gsm8k_benchmark.py
+++ b/scripts/metamath_gsm8k_benchmark.py
@@ -37,6 +37,7 @@ CFG_BY_VARIANT = {
     "antipasto_rot": ll.AntiPaSTORotConfig,
     "antipasto_ablate": ll.AntiPaSTOAblateConfig,
     "antipasto_corda": ll.AntiPaSTOCorDAConfig,
+    "antipasto_arrow": ll.AntiPaSTOArrowConfig,
     "road": ll.RoadConfig,
 }
 
@@ -46,7 +47,7 @@ class BenchmarkConfig:
     """MetaMathQA -> GSM8K benchmark config. Tyro turns this into the CLI."""
 
     model: str = "Qwen/Qwen3-0.6B-Base"
-    variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "antipasto_rot", "antipasto_ablate", "antipasto_corda", "road"] = "lora"
+    variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "antipasto_rot", "antipasto_ablate", "antipasto_corda", "antipasto_arrow", "road"] = "lora"
     mode: Literal["benchmark", "probe"] = "benchmark"
     device: str = "cuda"
     torch_dtype: str = "bfloat16"
@@ -63,6 +64,8 @@ class BenchmarkConfig:
     antipasto_cov_orient: bool = False
     # AntiPaSTO-rot (legacy rotation variant) basis to rotate.
     antipasto_rotate_basis: Literal["V", "U", "none"] = "V"
+    # AntiPaSTO-arrow: dense interaction block size on the top-b singular directions.
+    antipasto_block: int = 8
     target_name: list[str] = field(default_factory=lambda: list(DEFAULT_TARGETS))
     layers: str = "all"
     train_dataset: str = "meta-math/MetaMathQA"
@@ -143,6 +146,9 @@ def cfg_for_variant(args: BenchmarkConfig, dtype: torch.dtype) -> ll.AdapterConf
     if args.variant == "antipasto_ablate":
         extra = {"coeff": args.antipasto_coeff, "k": args.antipasto_ablate_k,
                  "cov_orient": args.antipasto_cov_orient}
+    if args.variant == "antipasto_arrow":
+        extra = {"coeff": args.antipasto_coeff, "suppress_only": args.antipasto_suppress_only,
+                 "block": args.antipasto_block}
     return CFG_BY_VARIANT[args.variant](
         r=args.r,
         alpha=args.r if args.variant == "pissa" else args.alpha,
diff --git a/src/lora_lite/__init__.py b/src/lora_lite/__init__.py
index fed00d7..a67401e 100644
--- a/src/lora_lite/__init__.py
+++ b/src/lora_lite/__init__.py
@@ -23,6 +23,7 @@ from .variants.antipasto import AntiPaSTOConfig
 from .variants.antipasto_rot import AntiPaSTORotConfig
 from .variants.antipasto_ablate import AntiPaSTOAblateConfig
 from .variants.antipasto_corda import AntiPaSTOCorDAConfig
+from .variants.antipasto_arrow import AntiPaSTOArrowConfig
 from .variants.road import RoadConfig
 
 __all__ = [
@@ -39,6 +40,7 @@ __all__ = [
     "AntiPaSTORotConfig",
     "AntiPaSTOAblateConfig",
     "AntiPaSTOCorDAConfig",
+    "AntiPaSTOArrowConfig",
     "RoadConfig",
     "attach",
     "detach",
diff --git a/src/lora_lite/variants/__init__.py b/src/lora_lite/variants/__init__.py
index 6eb9994..07ef5c3 100644
--- a/src/lora_lite/variants/__init__.py
+++ b/src/lora_lite/variants/__init__.py
@@ -1,4 +1,4 @@
 from . import (  # noqa: F401  side-effect: register
     lora, pissa, delora, ia3, dora, hra, eva, antipasto, road,
-    antipasto_rot, antipasto_ablate, antipasto_corda,
+    antipasto_rot, antipasto_ablate, antipasto_corda, antipasto_arrow,
 )
diff --git a/src/lora_lite/variants/antipasto_arrow.py b/src/lora_lite/variants/antipasto_arrow.py
new file mode 100644
index 0000000..1f3b7a4
--- /dev/null
+++ b/src/lora_lite/variants/antipasto_arrow.py
@@ -0,0 +1,212 @@
+"""AntiPaSTO-Arrow: a STRUCTURED fixed-basis core, the cheap way to add cross-
+direction mixing that plain antipasto (a diagonal gain) cannot express.
+
+antipasto's core is diagonal: S_eff = S * (1 + ELU(coeff*g)) reweights each frozen
+singular direction independently. It can turn a direction up or down but it can never
+let direction i's input drive direction j's output. Yet the behaviour you steer is a
+combination Sigma c_i v_i that generically lies OFF any single axis (the same argument
+that motivates antipasto_corda), so a diagonal core can only ever approximate it.
+
+The obvious fix -- a full dense r x r core M, DeltaW = U M Vh -- restores all mixing but
+costs r^2 params (r=256 -> 65536, a rank-8 LoRA's worth) and an r x r matmul per forward.
+antipasto.py's own header flags this trap: "a dense r x r core is r^2 params ... add a
+*fixed-basis* core U M Vh rather than rotating". This file is that core, made cheap by
+making it STRUCTURED instead of dense -- an arrowhead, not an r x r.
+
+Arrowhead structure (dense top-block + diagonal tail):
+
+    core C (r x r, acting on the S-scaled coords) =
+
+        [ B (b x b dense) |        0         ]   B couples the top-b directions
+        [        0        | diag(1+ELU(c*g)) ]   tail = exactly antipasto's gain
+
+    DeltaW = U @ C @ diag(S) @ Vh
+
+The top b singular directions (largest S = where PiSSA says the action lives) get a full
+b x b interaction block B = I_b + coeff*M; the remaining r-b stay on the cheap bounded
+diagonal gain. Cost is b^2 + (r-b) params and one b x b matmul per forward -- for b=8,r=256
+that is 312 params and a 64-MAC corner, versus 65536 for dense r x r and versus the
+rotation variant's per-forward Cayley solve (measured 72ms vs 36ms). So: cross-direction
+mixing where it matters, at diagonal-core cost.
+
+(We call it "arrowhead" after the shape -- a dense head with a diagonal shaft. A true
+numerical-LA arrowhead also carries a hub row+column coupling the block to the tail; that
+would add 2(r-b) params and is a one-line extension if the top-b span turns out too small.
+Not added until measured to be needed.)
+
+Identity at init: M=0 -> B=I_b, g=0 -> 1+ELU(0)=1, so C=I and DeltaW = U diag(S) Vh exactly
+(up to the one-time SVD-residual rounding). coeff=0 -> C=I too (runtime off). The block is
+the linear-amplification regime of antipasto's design (a matmul, constant-gradient, no exp
+self-amplification); it is stable like 1+ELU's upper branch, not strictly bounded -- if you
+need the tail's structural can't-blow-up guarantee on the top directions too, use
+antipasto_ablate instead.
+
+Refs: antipasto.py (diagonal sibling), antipasto_corda.py (the off-axis argument).
+"""
+from dataclasses import dataclass
+from typing import Iterable, Literal
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from jaxtyping import Float
+from torch import nn, Tensor as T
+
+from ..variant import register, ParamSpec
+from ..config import AdapterConfig, register_config
+
+CalibrationBatch = dict | tuple | list | T
+CalibrationData = Iterable[CalibrationBatch]
+
+
+@register_config
+@dataclass
+class AntiPaSTOArrowConfig(AdapterConfig):
+    variant: str = "antipasto_arrow"
+    r: int = 256
+    # Size of the dense interaction block on the top-b singular directions. The ONLY
+    # quadratic cost (b^2 params); keep small. b=1 is a purely diagonal core: like
+    # antipasto, except the top direction's gain is the linear 1 + coeff*M, not 1+ELU.
+    block: int = 8
+    suppress_only: bool = False  # clamp the tail g<=0 (attenuate only); block unaffected
+    coeff: float = 1.0  # runtime knob: 0=identity, scales both block and tail
+    act_pool: Literal["rms", "mean_abs"] = "rms"  # group_init selection, see antipasto
+
+
+@register
+class AntiPaSTOArrow:
+    name = "antipasto_arrow"
+
+    @staticmethod
+    def param_specs(d_in, d_out, cfg):
+        """Declare the frozen SVD buffers (U, S, Vh) and the trainable core (M, g).
+
+        Raises ValueError unless 1 <= block < r <= min(d_in, d_out).
+        """
+        r, b = cfg.r, cfg.block
+        if not 1 <= b < r:
+            raise ValueError(f"antipasto_arrow needs 1 <= block({b}) < r({r}).")
+        # The thin SVD in init() yields only min(d_in, d_out) directions; a larger r
+        # would otherwise surface there as an opaque copy_ shape mismatch.
+        if r > min(d_in, d_out):
+            raise ValueError(
+                f"antipasto_arrow needs r({r}) <= min(d_in={d_in}, d_out={d_out})."
+            )
+        return dict(
+            lora_U=ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True),
+            lora_S=ParamSpec((r,), init="zeros", trainable=False, as_buffer=True),
+            lora_Vh=ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True),
+            # Dense b x b interaction on the top-b directions. init 0 -> B=I -> identity.
+            lora_M=ParamSpec((b, b), init="zeros"),
+            # Diagonal bounded gain on the remaining r-b directions (== antipasto's g).
+            lora_g=ParamSpec((r - b,), init="zeros"),
+        )
+
+    @staticmethod
+    def init(layer: nn.Module, cfg) -> None:
+        """Move the top-r SVD component of layer.weight into the frozen U/S/Vh buffers.
+
+        layer.weight is overwritten in place with the residual W - U diag(S) Vh.
+        """
+        if type(layer) is not nn.Linear:
+            raise TypeError("AntiPaSTOArrow mutates layer.weight into W_res; nn.Linear only.")
+        with torch.no_grad():
+            W = layer.weight.data.float()
+            U, S, Vh = torch.linalg.svd(W, full_matrices=False)
+            r = cfg.r
+            Ur, Sr, Vhr = U[:, :r], S[:r], Vh[:r, :]
+            layer.lora_U.copy_(Ur.to(layer.lora_U.dtype))
+            layer.lora_S.copy_(Sr.to(layer.lora_S.dtype))
+            layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype))
+            W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
+            layer.weight.data.copy_(W_res)
+
+    @staticmethod
+    def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None:
+        """Wanda-style data-driven dimension selection, identical to antipasto: re-pick
+        the top-r directions by S[i] * pool|X @ Vh[i]|. Runs before training (g, M at
+        their zero init), so re-selecting the basis is a harmless no-op on the core.
+
+        No-op when calibration_data is None. Raises RuntimeError when a target layer
+        captured fewer than r tokens (including none at all).
+        """
+        if calibration_data is None:
+            return
+
+        layers = {name: layer for name, layer, _ in targets}
+        captured: dict[str, list[T]] = {n: [] for n in layers}
+
+        def make_hook(name):
+            def _h(module, args, kwargs):
+                # nn.Linear names its argument `input`; accept it by keyword too.
+                x = (args[0] if args else kwargs["input"]).detach()
+                captured[name].append(rearrange(x, "... d -> (...) d").to(torch.float32).cpu())
+            return _h
+
+        was_training = model.training
+        handles = [layers[n].register_forward_pre_hook(make_hook(n), with_kwargs=True) for n in layers]
+        try:
+            model.eval()
+            with torch.no_grad():
+                for batch in calibration_data:
+                    if isinstance(batch, dict):
+                        model(**batch)
+                    elif isinstance(batch, (list, tuple)):
+                        model(*batch)
+                    else:
+                        model(batch)
+        finally:
+            for h in handles:
+                h.remove()
+            # Restore the caller's mode even when a calibration forward raises.
+            if was_training:
+                model.train()
+
+        r, pool = cfg.r, cfg.act_pool
+        for name, layer in layers.items():
+            if not captured[name]:
+                raise RuntimeError(f"AntiPaSTOArrow at {name}: 0 tokens, need >= r={r}")
+            X = torch.cat(captured[name], dim=0)
+            if X.shape[0] < r:
+                raise RuntimeError(f"AntiPaSTOArrow at {name}: {X.shape[0]} tokens, need >= r={r}")
+            W_res = layer.weight.data.float()
+            W_orig = W_res + (layer.lora_U.float() * layer.lora_S.float()) @ layer.lora_Vh.float()
+            U_full, S_full, Vh_full = torch.linalg.svd(W_orig, full_matrices=False)
+            proj = X.to(Vh_full) @ Vh_full.T
+            act_mag = proj.pow(2).mean(0).sqrt() if pool == "rms" else proj.abs().mean(0)
+            idx = (S_full * act_mag).argsort(descending=True)[:r].sort().values
+            Ur, Sr, Vhr = U_full[:, idx], S_full[idx], Vh_full[idx]
+            W_res_new = (W_orig - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
+            with torch.no_grad():
+                layer.lora_U.copy_(Ur.to(layer.lora_U))
+                layer.lora_S.copy_(Sr.to(layer.lora_S))
+                layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh))
+                layer.weight.data.copy_(W_res_new)
+
+    @staticmethod
+    def forward(
+        layer: nn.Module,
+        x: Float[T, '*B i'],
+        y: Float[T, '*B o'],
+    ) -> Float[T, '*B o']:
+        """Return y plus the adapter term U C diag(S) Vh x (C = arrowhead core, see module docstring)."""
+        cfg = layer._lora_cfg
+        U = layer.lora_U.to(x.dtype)    # (d_out, r)
+        S = layer.lora_S.to(x.dtype)    # (r,)
+        Vh = layer.lora_Vh.to(x.dtype)  # (r, d_in)
+        M = layer.lora_M.to(x.dtype)    # (b, b)
+        g = layer.lora_g.to(x.dtype)    # (r-b,)
+        coeff, b = float(cfg.coeff), cfg.block
+
+        cS = (x @ Vh.T) * S  # (..., r) = diag(S) Vh x
+
+        # Top-b: dense block B = I_b + coeff*M couples the top singular directions.
+        eye = torch.eye(b, dtype=x.dtype, device=x.device)
+        top = cS[..., :b] @ (eye + coeff * M).T  # (..., b)
+        # Tail: antipasto's bounded diagonal gain (see antipasto.py for the 1+ELU why).
+        if cfg.suppress_only:
+            g = torch.clamp(g, max=0.0)
+        tail = cS[..., b:] * (1.0 + F.elu(coeff * g))  # (..., r-b)
+
+        h = torch.cat([top, tail], dim=-1)  # (..., r)
+        return y + h @ U.T
diff --git a/tests/test_metamath_smoke.py b/tests/test_metamath_smoke.py
index e66e1c6..b37a5f3 100644
--- a/tests/test_metamath_smoke.py
+++ b/tests/test_metamath_smoke.py
@@ -31,11 +31,13 @@ SPEC.loader.exec_module(benchmark)
 
 
 VARIANTS = ["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva",
-            "antipasto", "antipasto_rot", "antipasto_ablate", "antipasto_corda", "road"]
+            "antipasto", "antipasto_rot", "antipasto_ablate", "antipasto_corda",
+            "antipasto_arrow", "road"]
 # Variants that fail loud when attached on a bnb-loaded base (read dense weight in init).
 # delora/eva also read weight but currently silently dequant -- they produce sane attach,
 # so we don't expect a raise from them in the attach-only smoke.
-BNB_RAISERS = {"pissa", "dora", "antipasto", "antipasto_rot", "antipasto_ablate", "antipasto_corda"}
+BNB_RAISERS = {"pissa", "dora", "antipasto", "antipasto_rot", "antipasto_ablate",
+               "antipasto_corda", "antipasto_arrow"}
 
 TINY_MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM"
 HAS_CUDA = torch.cuda.is_available()
@@ -57,6 +59,7 @@ def quick_cfg(variant: str, tmp_path: Path, quantization: str = "none") -> "benc
         quantization=quantization,
         r=4,
         alpha=8,
+        antipasto_block=2,  # antipasto_arrow needs block < r (r=4 here)
         target_name=target_name,
         layers="all",
         steps=2,