From b80d7778afa8e0ec1ed143d2db13c680474ebf46 Mon Sep 17 00:00:00 2001 From: wassname <1103714+wassname@users.noreply.github.com> Date: Sun, 14 Jun 2026 19:12:27 +0800 Subject: [PATCH] Add rotation-free S-space adapter cores (antipasto family) Replace antipasto's rotation/Cayley with a bounded 1+ELU gain and split the S-space idea into four interpretable PiSSA-style cores (frozen U/S/Vh, small trainable core): - antipasto: S_eff = S*(1+ELU(coeff*g)). exp-bounded attenuation, linear amplification (constant gradient, no runaway). g=0 -> exact identity. - antipasto_rot: keeps the block-Cayley rotation as a separate variant for cost comparison (its per-forward solve is the 72ms vs 36ms gap). - antipasto_ablate: contractive (I - a c c^T) diag(S), eigenvalues in [0,1], cannot blow up. Optional cov_orient (CorDA) basis. - antipasto_corda: covariance-oriented oblique projector P = Vh C^{-1/2}, the data-energy basis rather than the weight-gain basis. 1+ELU gain. Add scripts/_cost.py + scripts/cost_report.py: one-row-per-variant cost table (trainable params, peak GPU mem, fwd/bwd ms, added MACs/tok, group_init ms). Wire all four into the benchmark, smoke test, and __init__ exports. External review (DeepSeek-v4-pro, docs/reviews/) verified the math; acted on its one real point (corda g now inits to zeros for exact identity). 
Co-Authored-By: Claudypoo --- docs/reviews/review_antipasto.md | 38 ++++ scripts/_cost.py | 131 ++++++++++++ scripts/cost_report.py | 142 +++++++++++++ scripts/metamath_gsm8k_benchmark.py | 23 +- src/lora_lite/__init__.py | 6 + src/lora_lite/variants/__init__.py | 5 +- src/lora_lite/variants/antipasto.py | 184 ++++++++-------- src/lora_lite/variants/antipasto_ablate.py | 194 +++++++++++++++++ src/lora_lite/variants/antipasto_corda.py | 203 ++++++++++++++++++ src/lora_lite/variants/antipasto_rot.py | 232 +++++++++++++++++++++ tests/test_metamath_smoke.py | 8 +- 11 files changed, 1059 insertions(+), 107 deletions(-) create mode 100644 docs/reviews/review_antipasto.md create mode 100644 scripts/_cost.py create mode 100644 scripts/cost_report.py create mode 100644 src/lora_lite/variants/antipasto_ablate.py create mode 100644 src/lora_lite/variants/antipasto_corda.py create mode 100644 src/lora_lite/variants/antipasto_rot.py diff --git a/docs/reviews/review_antipasto.md b/docs/reviews/review_antipasto.md new file mode 100644 index 0000000..dbab5dc --- /dev/null +++ b/docs/reviews/review_antipasto.md @@ -0,0 +1,38 @@ +## Code Review: Four S-space weight adapters (anti-pasto family + sspace reference + cost script) + +### Summary +The code adds four PiSSA-style adapters (antipasto, antipasto_rot, antipasto_ablate, antipasto_corda) that store frozen top-r SVD buffers and a small trainable gain/core. The key mathematical claims in the docstrings (identity at g=0, contraction of ablation core, exact reconstruction of the CorDA decomposition, signed cosine gate) are correct. However, `group_init()` re-selects/re-orients the SVD basis but does **not** reset the trainable parameters (g, delta_s, rot_T, c, alpha), which introduces a latent but important correctness problem — after a data-driven reinit, gains and direction vectors become misaligned with their intended singular axes. 
+ +--- + +### Critical (must fix) +- **`antipasto_corda.py`, `antipasto_rot.py`, `antipasto_ablate.py`:** `group_init()` re-orients or re-selects the top-r SVD basis but leaves the existing trainable parameters (`lora_g`, `lora_delta_s`, `lora_rot_T`, `lora_c`, `lora_alpha`) untouched, still attached to the old indices. + - For `antipasto_corda`, `lora_g` is initialised with `N(0, 4e-4)` in `init()`. After `group_init`, those small random gains now multiply a different set of singular directions, producing an uncontrolled (though tiny) perturbation. + - For `antipasto_rot`, the `+4e-4` bias on `lora_delta_s` remains, now applied to arbitrary resorted directions, and the block rotation `lora_rot_T` is disconnected from the new block structure. + - For `antipasto_ablate`, the ablation directions and strengths (`lora_c`, `lora_alpha`) are not reset; if any warm-starting or training has happened, they point in the wrong subspace. + **Fix:** After re-selection or re-orientation, either re-initialise the trainable parameters (e.g., zero for g, zeros/small for delta_s and rot_T, random-normalised for c, init_alpha for alpha) or re-index them with the same `idx` mapping used for the buffers. Document that `group_init` must be called **before any training** and that the trainable parameters will be fresh after it. + +### Important (should fix) +- **`antipasto_ablate.py`** – Contraction claim and coeff bounds: the forward pass applies `h = h - coeff * (proj * alpha) @ Chat.T`. The core is a contraction only when both `coeff` and `alpha` are in `[0, 1]`. The config’s `coeff` can be *any* float (the docstring mentions `coeff<0` for amplification). There is no runtime clamping of `coeff`. If the intention is to keep the contraction property structurally enforced, clamp `coeff` to `[0, 1]` inside `forward()` or at least validate it. +- **`sspace.py`** – Division by `sqrtS` without epsilon: `xS = (y_eff @ U_r) / sqrtS`. 
For small or zero singular values (e.g., when r is large or W is low-rank) this can produce NaNs/infs. Add a small `ε` denominator (consistent with the ε used elsewhere). +- **`antipasto_ablate.py`** – `_orthonormal()` calls `torch.linalg.qr(c.float())` every forward pass. For large `r` and many layers this adds non-trivial cost. A lighter reparameterisation (e.g., maintaining `c` as a matrix with a normalisation step that avoids full QR) might be warranted, but for small `(r,k)` pairs the current approach is acceptable. At minimum, add a comment noting the per-forward cost. + +### Suggestions +- **`antipasto.py` group_init:** after the idx sort, `idx = idx.sort().values` ensures a stable, canonical ordering. This is a nice touch for reproducibility. +- **`antipasto_ablate.py`** – The docstring says “CorDA-orient the basis from input covariance … the ablation is OUTPUT-side and CorDA's U stays orthonormal …”. The forward code correctly uses the orthonormal `U` for output projection, so the contraction in S-space carries over to the output. +- **`sspace.py`** – The signed gate correctly preserves `cos` sign, so anti-aligned tokens receive a negative `gate * alpha` and are pushed opposite to `dS_hat`. Verified. + +### Positive +- The documentation is thorough, explaining the design choices (why 1+ELU, why contraction, why CorDA), and includes references. +- The PiSSA-style `W_res` decomposition is implemented correctly across all variants. +- `antipasto.py`’s `S_eff = S * (1 + ELU(coeff*g))` is indeed C1, positive, and identity at `g=0` — no sign-flip bugs. +- `antipasto_ablate.py` enforces orthonormal `Chat` and clamps `alpha` to `[0,1]`, making the contraction property safe when `coeff` is also bounded. + +### Verdict +**REQUEST CHANGES** + +`group_init()` must reset the trainable parameters after it changes the SVD basis; without this, the adapters silently poison their own steering gains with a different set of directions. Fix that and the code is ready. 
+ +--- + +*Note: The `sspace.py` and `_cost.py` files are part of the same move into lora-lite and are free of the parameter-reset issue; the only concern in `sspace.py` is the unprotected division by singular values.* \ No newline at end of file diff --git a/scripts/_cost.py b/scripts/_cost.py new file mode 100644 index 0000000..2669d1f --- /dev/null +++ b/scripts/_cost.py @@ -0,0 +1,131 @@ +"""Measure the cost of an attached adapter: params, FLOPs/MACs, time, GPU mem. + +Which metric is "best" for comparing adapters? They answer different questions: + +- trainable_params -- deterministic "size" number. The headline. +- macs_per_token -- deterministic, hardware-INDEPENDENT compute. Best for an + apples-to-apples comparison: wall-time is noisy and the old + rotation adapter paid a per-forward Cayley solve the new ones + do not. "adds" (additions) ~= MACs; FLOPs ~= 2 * MACs. +- fwd_ms / bwd_ms -- felt cost, but noisy: warmup + median over `iters`, never one run. +- peak_gpu_mb -- resident + activation peak around fwd(+bwd). + +FLOPs come from torch.utils.flop_counter.FlopCounterMode (built in, no new dep). Its +matmul formula counts a (m,k)@(k,n) product as 2*m*n*k FLOPs, not m*n*k MACs; we expose +both `flops` (as returned) and `macs_per_token = flops / n_tokens`, so the latter is really +FLOPs/token (~2x true MACs) -- calibrate once on a known matmul to confirm the factor of 2. 
+""" +from __future__ import annotations + +import statistics +import time + +import torch +from torch.utils.flop_counter import FlopCounterMode + + +def _time_call(fn, warmup: int, iters: int, cuda: bool) -> float: + """Median wall-time of fn() in milliseconds (warmup excluded).""" + for _ in range(warmup): + fn() + if cuda: + torch.cuda.synchronize() + samples = [] + for _ in range(iters): + if cuda: + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fn() + end.record() + torch.cuda.synchronize() + samples.append(start.elapsed_time(end)) + else: + t0 = time.perf_counter() + fn() + samples.append((time.perf_counter() - t0) * 1e3) + return statistics.median(samples) + + +def measure_cost( + model: torch.nn.Module, + fwd_fn, + *, + bwd_step_fn=None, + n_tokens: int | None = None, + adapter_filter: str = "lora_", + warmup: int = 3, + iters: int = 10, +) -> dict: + """Cost of the currently-attached adapter. + + fwd_fn(): run one forward (no grad). Used for FLOPs + fwd timing. + bwd_step_fn(): zero_grad + forward + loss.backward(). Used for bwd timing. + n_tokens: tokens in the fwd_fn batch, for macs_per_token. + adapter_filter: substring marking adapter params/buffers (default 'lora_'). + """ + dev = next(model.parameters()).device + cuda = dev.type == "cuda" + + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + named = list(model.named_parameters()) + list(model.named_buffers()) + adapter_bytes = sum(t.numel() * t.element_size() for n, t in named if adapter_filter in n) + + # FLOPs: one forward under the counter (no grad so we count inference cost). + # FlopCounterMode can assert on some fused attention shapes; degrade to None. 
+ try: + fc = FlopCounterMode(display=False) + with torch.no_grad(), fc: + fwd_fn() + flops = fc.get_total_flops() + except Exception as e: + print(f" [warn] FLOP count failed ({type(e).__name__}: {e}); flops=None") + flops = None + + if cuda: + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + fwd_ms = _time_call(lambda: _no_grad(fwd_fn), warmup, iters, cuda) + bwd_ms = _time_call(bwd_step_fn, warmup, iters, cuda) if bwd_step_fn is not None else None + peak_gpu_mb = (torch.cuda.max_memory_allocated() / 1e6) if cuda else None + + return dict( + trainable_params=trainable_params, + adapter_resident_mb=adapter_bytes / 1e6, + flops=flops, + macs_per_token=(flops / n_tokens) if (flops and n_tokens) else None, + fwd_ms=fwd_ms, + bwd_ms=bwd_ms, + peak_gpu_mb=peak_gpu_mb, + ) + + +def _no_grad(fn): + with torch.no_grad(): + return fn() + + +class group_init_meter: + """Context manager: wall-time + peak CPU RAM of a group_init / attach-with-calib. + + CorDA accumulates C = E[xx^T] on CPU and runs eigh(d_in^3) -- the expensive corner. + Use around ll.attach(model, cfg, calibration_data=...) to log that asymmetry. + """ + + def __init__(self): + self.ms = None + self.peak_cpu_mb = None + + def __enter__(self): + import tracemalloc + self._tm = tracemalloc + tracemalloc.start() + self._t0 = time.perf_counter() + return self + + def __exit__(self, *exc): + self.ms = (time.perf_counter() - self._t0) * 1e3 + _, peak = self._tm.get_traced_memory() + self._tm.stop() + self.peak_cpu_mb = peak / 1e6 + return False diff --git a/scripts/cost_report.py b/scripts/cost_report.py new file mode 100644 index 0000000..75534f8 --- /dev/null +++ b/scripts/cost_report.py @@ -0,0 +1,142 @@ +"""One-row-per-variant cost table: params, MACs/token, fwd/bwd ms, peak GPU, group_init. 
+ +Answers "which is best -- time / flops / adds / params?": MACs/token is the +deterministic apples-to-apples compute number; trainable_params is the size headline; +wall-time is the felt-but-noisy number; group_init is where CorDA's eigh(d_in^3) bites. + +Usage: + uv run --extra benchmark python scripts/cost_report.py \ + --model Qwen/Qwen3-0.6B-Base --variants antipasto antipasto_corda antipasto_ablate lora \ + --target-name 'q_proj$' 'v_proj$' --r 32 --out logs/cost_qwen0.6b.log + +Point --target-name at down_proj to see the CorDA covariance corner (large d_in). +""" +from __future__ import annotations + +import argparse +import importlib.util +import sys +from pathlib import Path + +import torch +from tabulate import tabulate + +import lora_lite as ll + +_HERE = Path(__file__).resolve().parent +_BENCH = importlib.util.spec_from_file_location("metamath_benchmark", _HERE / "metamath_gsm8k_benchmark.py") +benchmark = importlib.util.module_from_spec(_BENCH) +sys.modules[_BENCH.name] = benchmark +_BENCH.loader.exec_module(benchmark) + +_COST = importlib.util.spec_from_file_location("_cost", _HERE / "_cost.py") +cost = importlib.util.module_from_spec(_COST) +sys.modules[_COST.name] = cost +_COST.loader.exec_module(cost) + + +def build_cfg(variant: str, args, dtype) -> ll.AdapterConfig: + """Reuse the benchmark's variant->config map; only need r/targets/dtype here.""" + bcfg = benchmark.BenchmarkConfig( + model=args.model, variant=variant, r=args.r, alpha=float(args.r), + target_name=list(args.target_name), layers=args.layers, torch_dtype=args.dtype, + antipasto_cov_orient=args.cov_orient, + ) + return benchmark.cfg_for_variant(bcfg, dtype) + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--model", default="Qwen/Qwen3-0.6B-Base") + ap.add_argument("--variants", nargs="+", + default=["lora", "antipasto", "antipasto_rot", "antipasto_corda", "antipasto_ablate"]) + ap.add_argument("--target-name", nargs="+", default=[r"q_proj$", r"v_proj$"]) + 
ap.add_argument("--r", type=int, default=32) + ap.add_argument("--layers", default="all", + help="'all' or comma list e.g. '0,1' -- limit layers (CorDA down_proj eigh is slow).") + ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu") + ap.add_argument("--dtype", default="bfloat16") + ap.add_argument("--seq-len", type=int, default=256) + ap.add_argument("--batch", type=int, default=2) + ap.add_argument("--calib-batches", type=int, default=4) + ap.add_argument("--cov-orient", action="store_true", + help="CorDA-orient antipasto_ablate (measure the eigh corner).") + ap.add_argument("--out", default="logs/cost.log") + args = ap.parse_args() + + dtype = getattr(torch, args.dtype) + # eager attention: FlopCounterMode's sdpa_flop_count asserts on GQA (Qwen3) SDPA + # shapes (q heads != kv heads). eager uses explicit matmuls it can count. + from transformers import AutoModelForCausalLM, AutoTokenizer + tok = AutoTokenizer.from_pretrained(args.model) + model = AutoModelForCausalLM.from_pretrained( + args.model, dtype=dtype, attn_implementation="eager" + ).to(args.device) + model.eval() + + n_tokens = args.batch * args.seq_len + ids = torch.randint(0, model.config.vocab_size, (args.batch, args.seq_len), device=args.device) + calib = [{"input_ids": torch.randint(0, model.config.vocab_size, + (args.batch, args.seq_len), device=args.device)} + for _ in range(args.calib_batches)] + + def fwd(): + model(input_ids=ids) + + def bwd_step(): + model.zero_grad(set_to_none=True) + loss = model(input_ids=ids).logits.float().pow(2).mean() + loss.backward() + + # base (no-adapter) cost, so each row can report the adapter's ADDED MACs/token. 
+ base = cost.measure_cost(model, fwd, bwd_step_fn=bwd_step, n_tokens=n_tokens) + base_macs = base["macs_per_token"] + print(f"base (no adapter): MACs/tok={int(base_macs) if base_macs else None} " + f"fwd_ms={round(base['fwd_ms'],2)} bwd_ms={round(base['bwd_ms'],2)}") + + # base = no adapter; model params left trainable, so this is the full-finetune + # GPU-mem reference (its backward stores grads for every weight). + total_params = sum(p.numel() for p in model.parameters()) + rows = [{ + "variant": "base(full-FT)", "train_params": total_params, + "fwd_ms": round(base["fwd_ms"], 2), "bwd_ms": round(base["bwd_ms"], 2), + "peak_GPU_MB": round(base["peak_gpu_mb"], 1) if base["peak_gpu_mb"] else None, + "added_MACs/tok": 0 if base_macs else None, + "ginit_ms": 0.0, "ginit_CPU_MB": 0.0, + }] + for variant in args.variants: + cfg = build_cfg(variant, args, dtype) + # group_init / attach cost (CorDA's eigh + C live here). + with cost.group_init_meter() as gi: + ll.attach(model, cfg, calibration_data=calib) + c = cost.measure_cost(model, fwd, bwd_step_fn=bwd_step, n_tokens=n_tokens) + ll.detach(model) + rows.append({ + "variant": variant, + "train_params": c["trainable_params"], + "fwd_ms": round(c["fwd_ms"], 2), + "bwd_ms": round(c["bwd_ms"], 2) if c["bwd_ms"] else None, + "peak_GPU_MB": round(c["peak_gpu_mb"], 1) if c["peak_gpu_mb"] else None, + # flat across same-r adapters; kept only as a sanity check, not a comparator. 
+ "added_MACs/tok": int(c["macs_per_token"] - base_macs) if (c["macs_per_token"] and base_macs) else None, + "ginit_ms": round(gi.ms, 1), + "ginit_CPU_MB": round(gi.peak_cpu_mb, 1), + }) + print(f" {variant}: params={rows[-1]['train_params']} " + f"peak_GPU_MB={rows[-1]['peak_GPU_MB']} bwd_ms={rows[-1]['bwd_ms']} ginit_ms={rows[-1]['ginit_ms']}") + + table = tabulate(rows, headers="keys", tablefmt="pipe") + header = (f"# cost report: {args.model} targets={args.target_name} r={args.r} " + f"seq={args.seq_len} batch={args.batch} dtype={args.dtype}\n" + f"# COMPARATORS: train_params, peak_GPU_MB (fwd+bwd, process-local max), bwd_ms, ginit_ms.\n" + f"# added_MACs/tok is flat across same-r adapters (sanity check only).\n" + f"# ginit_CPU_MB undercounts: tracemalloc misses torch C++ tensor allocs (the CorDA C matrix).\n") + out_path = Path(args.out) + out_path.parent.mkdir(parents=True, exist_ok=True) + out_path.write_text(header + table + "\n") + print("\n" + header + table) + print(f"\nsaved -> {out_path}") + + +if __name__ == "__main__": + main() diff --git a/scripts/metamath_gsm8k_benchmark.py b/scripts/metamath_gsm8k_benchmark.py index eca46d7..867ec45 100644 --- a/scripts/metamath_gsm8k_benchmark.py +++ b/scripts/metamath_gsm8k_benchmark.py @@ -34,6 +34,9 @@ CFG_BY_VARIANT = { "hra": ll.HRAConfig, "eva": ll.EVAConfig, "antipasto": ll.AntiPaSTOConfig, + "antipasto_rot": ll.AntiPaSTORotConfig, + "antipasto_ablate": ll.AntiPaSTOAblateConfig, + "antipasto_corda": ll.AntiPaSTOCorDAConfig, "road": ll.RoadConfig, } @@ -43,7 +46,7 @@ class BenchmarkConfig: """MetaMathQA -> GSM8K benchmark config. 
Tyro turns this into the CLI.""" model: str = "Qwen/Qwen3-0.6B-Base" - variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "road"] = "lora" + variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "antipasto_rot", "antipasto_ablate", "antipasto_corda", "road"] = "lora" mode: Literal["benchmark", "probe"] = "benchmark" device: str = "cuda" torch_dtype: str = "bfloat16" @@ -52,6 +55,13 @@ class BenchmarkConfig: alpha: float = 64.0 delora_lambda0: float = 0.1 road_group_size: int = 64 + # AntiPaSTO family (gain / corda) runtime knobs. + antipasto_coeff: float = 1.0 + antipasto_suppress_only: bool = False + # AntiPaSTO-ablate. + antipasto_ablate_k: int = 1 + antipasto_cov_orient: bool = False + # AntiPaSTO-rot (legacy rotation variant) basis to rotate. antipasto_rotate_basis: Literal["V", "U", "none"] = "V" target_name: list[str] = field(default_factory=lambda: list(DEFAULT_TARGETS)) layers: str = "all" @@ -124,8 +134,15 @@ def cfg_for_variant(args: BenchmarkConfig, dtype: torch.dtype) -> ll.AdapterConf extra = {"lambda0": args.delora_lambda0} if args.variant == "delora" else {} if args.variant == "road": extra = {"group_size": args.road_group_size} - if args.variant == "antipasto": + if args.variant == "antipasto_rot": extra = {"rotate_basis": args.antipasto_rotate_basis} + if args.variant == "antipasto": + extra = {"coeff": args.antipasto_coeff, "suppress_only": args.antipasto_suppress_only} + if args.variant == "antipasto_corda": + extra = {"coeff": args.antipasto_coeff, "suppress_only": args.antipasto_suppress_only} + if args.variant == "antipasto_ablate": + extra = {"coeff": args.antipasto_coeff, "k": args.antipasto_ablate_k, + "cov_orient": args.antipasto_cov_orient} return CFG_BY_VARIANT[args.variant]( r=args.r, alpha=args.r if args.variant == "pissa" else args.alpha, @@ -155,7 +172,7 @@ def count_base_grad_leaks(model: torch.nn.Module) -> int: def perturb_first_adapter(model: 
torch.nn.Module) -> None: - priority = ("lora_B", "lora_g", "lora_U", "lora_A", "lora_lambda", "lora_gate", "lora_delta_s", "lora_rot_T", "lora_m", "lora_road_theta", "lora_road_alpha") + priority = ("lora_B", "lora_g", "lora_c", "lora_alpha", "lora_U", "lora_A", "lora_lambda", "lora_gate", "lora_delta_s", "lora_m", "lora_road_theta", "lora_road_alpha") for key in priority: for _, p in model.named_parameters(): if p.requires_grad and key in _: diff --git a/src/lora_lite/__init__.py b/src/lora_lite/__init__.py index 588e394..fed00d7 100644 --- a/src/lora_lite/__init__.py +++ b/src/lora_lite/__init__.py @@ -20,6 +20,9 @@ from .variants.dora import DoRAConfig from .variants.hra import HRAConfig from .variants.eva import EVAConfig from .variants.antipasto import AntiPaSTOConfig +from .variants.antipasto_rot import AntiPaSTORotConfig +from .variants.antipasto_ablate import AntiPaSTOAblateConfig +from .variants.antipasto_corda import AntiPaSTOCorDAConfig from .variants.road import RoadConfig __all__ = [ @@ -33,6 +36,9 @@ __all__ = [ "HRAConfig", "EVAConfig", "AntiPaSTOConfig", + "AntiPaSTORotConfig", + "AntiPaSTOAblateConfig", + "AntiPaSTOCorDAConfig", "RoadConfig", "attach", "detach", diff --git a/src/lora_lite/variants/__init__.py b/src/lora_lite/variants/__init__.py index fc39fad..6eb9994 100644 --- a/src/lora_lite/variants/__init__.py +++ b/src/lora_lite/variants/__init__.py @@ -1 +1,4 @@ -from . import lora, pissa, delora, ia3, dora, hra, eva, antipasto, road # noqa: F401 side-effect: register +from . 
import ( # noqa: F401 side-effect: register + lora, pissa, delora, ia3, dora, hra, eva, antipasto, road, + antipasto_rot, antipasto_ablate, antipasto_corda, +) diff --git a/src/lora_lite/variants/antipasto.py b/src/lora_lite/variants/antipasto.py index 6830e88..e21d804 100644 --- a/src/lora_lite/variants/antipasto.py +++ b/src/lora_lite/variants/antipasto.py @@ -1,31 +1,48 @@ -"""AntiPaSTO: SVD steering with learnable singular-value deltas + block-diagonal Cayley rotation. +"""AntiPaSTO: SVD steering with learnable, bounded singular-value reweighting. wassname 2026 https://arxiv.org/abs/2601.07473 - W = U diag(S) Vh + W_res (top-r SVD; W_res = W - U_r S_r Vh_r) - learn: delta_s (r,), rot_T (n_blocks, bs(bs-1)/2) - R = block_diag(Cayley(skew(rot_T))); Vh_eff = R @ Vh (or U_eff = U @ R.T) - y = x @ W_res.T + ((x @ Vh_eff.T) * (S + delta_s)) @ U_eff.T + W = U diag(S) Vh + W_res (top-r SVD; W_res = W - U_r S_r Vh_r) + learn: g (r,) per-singular-direction gain log/lin-scale + S_eff = S * (1 + ELU(coeff * g)) exp(.) for g<=0, 1+. for g>0 + suppress_only: clamp g<=0 -> factor in (0,1], attenuation only + y = x @ W_res.T + ((x @ Vh.T) * S_eff) @ U.T -Identity at t=0: rot_T=0 -> R=I, delta_s~4e-4 -> y ≈ x @ W^T (fp32 SVD round-trip, tiny positive bias on delta_s breaks sign symmetry). +Identity at g=0 (or coeff=0): 1+ELU(0)=1 exactly, so S_eff = S and the output is +x @ W^T up to the one-time SVD-residual rounding. No additive sign-symmetry hack +needed: the basis is frozen, so the direction sign is fixed and exp/(1+.) is +sign-preserving. The 1+ELU shape is chosen over linear (sign-flips at g<-1), exp +(amplification blows up), and tanh (arbitrary bound) -- see forward() for why. -Scope cut vs antipasto3: this is a fine-tuning adapter, not the full runtime -steering interface. There is no per-call alpha, so it does not expose the -bidirectional R(+alpha) / R(-alpha) inference symmetry. 
The V-basis path uses the -opposite chirality to antipasto3's default U-basis path, so checkpoints are not -portable without a sign/basis convention. +Changes vs the rotation version this replaces: + - Rotation dropped. Rotating Vh/U leaves the interpretable singular basis (the + SVD-direction / Conjecture property), which is the entire point of steering in + S-space, and the Cayley solve was numerically finicky. The basis is now frozen; + the only learned object is the per-direction gain. If you later want + cross-direction mixing, add a *fixed-basis* core U M Vh (M trainable, U/Vh frozen) + rather than rotating -- that keeps the directions interpretable. It is also far + cheaper than PiSSA: a dense r x r core is r^2 params (~= a rank-8 LoRA at r=256), + versus PiSSA's free A,B at r*(d_in+d_out), which drifts off the SVD basis. + - Additive delta_s -> bounded multiplicative S * (1 + ELU(coeff*g)). Multiplicative + is "scaled by S" (uniform *relative* control over an orders-of-magnitude spectrum), + stays positive (no S_eff<0 sign-flip -> no incoherence from that path), and the + 1+ELU shape stops the exp blowup. The 4e-4 sign-symmetry hack is gone. + - suppress_only = clamp g<=0 -> factor in (0,1]: attenuation only, structurally + cannot blow up. Matches the eval-awareness use case (turn a direction down). + - coeff: runtime steering scalar (0 = identity, <0 inverts). The per-call alpha + the rotation version lacked. + - group_init activation pooling is configurable: 'rms' weights outliers (ASVD + intuition), 'mean_abs' is the original outlier-robust pooling. 
Refs: - paper: https://github.com/wassname/AntiPaSTO - - lite port of: https://github.com/wassname/antipasto3 - (offline: docs/refs/antipasto3_svd_adapter.py) + - sibling (whitened, rotation-free, mean-diff): steering-lite/.../sspace.py """ -import math from dataclasses import dataclass from typing import Iterable, Literal import torch -from einops import einsum, rearrange +from einops import rearrange from jaxtyping import Float from torch import nn, Tensor as T @@ -40,35 +57,16 @@ CalibrationData = Iterable[CalibrationBatch] @dataclass class AntiPaSTOConfig(AdapterConfig): variant: str = "antipasto" - # Higher default than LoRA (r=8) since trainable params scale as r + r/bs*bs*(bs-1)/2, not r*(d_in+d_out). + # Only r trainable scalars (the per-direction gain g), so r can be large. r: int = 256 - - # Block size for the block-diagonal Cayley rotation. r must be divisible by it. - block_size: int = 4 - # Cayley map saturation: bounds rotation angle to ~max_rotation_angle radians. - max_rotation_angle: float = 0.5 - # Which singular basis to rotate: 'V' (input), 'U' (output), or 'none'. 
- rotate_basis: Literal["V", "U", "none"] = "V" - - -def _cayley(skew: torch.Tensor) -> torch.Tensor: - """R = (I - X)^-1 (I + X) for X = skew/2; preserves orthogonality.""" - bs = skew.shape[-1] - eye = torch.eye(bs, dtype=skew.dtype, device=skew.device).expand_as(skew) - X = skew / 2 - return torch.linalg.solve(eye - X, eye + X) - - -def _build_rotation(rot_T: torch.Tensor, bs: int, max_angle: float) -> torch.Tensor: - """rot_T: (n_blocks, bs*(bs-1)/2) -> R: (n_blocks, bs, bs) Cayley rotation.""" - n_blocks, _ = rot_T.shape - rows, cols = torch.triu_indices(bs, bs, offset=1, device=rot_T.device).unbind(0) - A = torch.zeros(n_blocks, bs, bs, dtype=rot_T.dtype, device=rot_T.device) - A[:, rows, cols] = rot_T - A = 0.5 * (A - A.transpose(-1, -2)) - a_limit = 2.0 * math.tan(max_angle / 2.0) - A = a_limit * torch.tanh(A / a_limit) - return _cayley(A) + # Per-direction reweighting is S_eff = S * (1 + ELU(coeff * g)). See forward() + # for the why; identity at g=0 or coeff=0, positive always, no free bound knob. + suppress_only: bool = False # clamp g<=0 -> factor in (0,1] when coeff>=0: attenuation only + # Runtime steering scale. 0 = identity. <0 inverts (swaps amplify/suppress). + coeff: float = 1.0 + # group_init Wanda-style pooling of |X @ Vh[i]|: 'rms' is outlier-sensitive + # (ASVD intuition), 'mean_abs' is the original outlier-robust pooling. + act_pool: Literal["rms", "mean_abs"] = "rms" @register @@ -78,24 +76,14 @@ class AntiPaSTO: @staticmethod def param_specs(d_in, d_out, cfg): r = cfg.r - bs = int(cfg.block_size) - if r % bs != 0: - raise ValueError(f"AntiPaSTO requires r={r} divisible by block_size={bs}") - specs = dict( - # Frozen SVD components captured at init. 
lora_U=ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True), lora_S=ParamSpec((r,), init="zeros", trainable=False, as_buffer=True), lora_Vh=ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True), - # Trainable: per-singular-value delta. - # antipasto3 uses 4e-4 + N(0, 4e-4): small positive bias breaks sign - # symmetry (rotation alone can't); zero-init works but trains slower. - lora_delta_s=ParamSpec((r,), init=lambda t: t.normal_(0, 4e-4).add_(4e-4)), + # Trainable per-direction log-scale. init 0 -> 1+ELU(0)=1 -> identity. + lora_g=ParamSpec((r,), init="zeros"), ) - if cfg.rotate_basis != "none": - n_blocks = r // bs - n_triu = bs * (bs - 1) // 2 - specs["lora_rot_T"] = ParamSpec((n_blocks, n_triu), init="zeros") - return specs @staticmethod def init(layer: nn.Module, cfg) -> None: @@ -114,17 +102,18 @@ class AntiPaSTO: layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype)) W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype) layer.weight.data.copy_(W_res) - # group_init() refines this to input-aligned directions if calibration_data is given. + # group_init() refines the dimension selection if calibration_data is given. @staticmethod def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None: - """Wanda-style data-driven dimension selection within the weight SVD. + """Wanda-style, data-driven dimension selection within the weight SVD. init() picks the top-r singular dimensions by S alone (PiSSA-style). - group_init() re-selects based on S[i] * mean|X @ Vh[i]|: dimensions - that are both large in W AND active given real inputs. - - If calibration_data is None the weight-SVD init from init() is kept. + group_init() re-selects by score[i] = S[i] * pool|X @ Vh[i]|: dimensions + that are both large in W AND active on real inputs. pool = 'rms' (outlier- + sensitive, the ASVD intuition that activation outliers carry signal) or + 'mean_abs' (the original, outlier-robust). 
If calibration_data is None the + weight-SVD init from init() is kept. """ if calibration_data is None: return @@ -160,6 +149,7 @@ class AntiPaSTO: h.remove() r = cfg.r + pool = cfg.act_pool for name, layer in layers.items(): X = torch.cat(captured[name], dim=0) # (N, d_in) if X.shape[0] < r: @@ -167,19 +157,21 @@ class AntiPaSTO: f"AntiPaSTO at {name}: only {X.shape[0]} calibration tokens, need >= r={r}" ) - # Recover W_orig: init() wrote W_res into layer.weight and stored top-r components + # Recover W_orig: init() wrote W_res into layer.weight and stored top-r. W_res = layer.weight.data.float() - U_old = layer.lora_U.float() # (d_out, r) - S_old = layer.lora_S.float() # (r,) - Vh_old = layer.lora_Vh.float() # (r, d_in) + U_old = layer.lora_U.float() + S_old = layer.lora_S.float() + Vh_old = layer.lora_Vh.float() W_orig = W_res + (U_old * S_old.unsqueeze(0)) @ Vh_old - # Full SVD to score all dimensions U_full, S_full, Vh_full = torch.linalg.svd(W_orig, full_matrices=False) - # score[i] = S[i] * mean|X @ Vh[i]| (Wanda: weight magnitude × activation magnitude) - act_mag = (X @ Vh_full.T).abs().mean(dim=0) # (k,) + proj = X.to(Vh_full) @ Vh_full.T # (N, k) input in S-coords (X captured on CPU) + if pool == "rms": + act_mag = proj.pow(2).mean(dim=0).sqrt() # outlier-sensitive + else: + act_mag = proj.abs().mean(dim=0) # outlier-robust (original) scores = S_full * act_mag - idx = scores.argsort(descending=True)[:r] # top-r by joint importance + idx = scores.argsort(descending=True)[:r] # top-r by joint importance idx = idx.sort().values # stable ordering Ur, Sr, Vhr = U_full[:, idx], S_full[idx], Vh_full[idx] @@ -198,35 +190,31 @@ class AntiPaSTO: y: Float[T, '*B o'], ) -> Float[T, '*B o']: cfg = layer._lora_cfg - bs = int(cfg.block_size) - max_angle = float(cfg.max_rotation_angle) - rotate_basis = cfg.rotate_basis - U = layer.lora_U.to(x.dtype) # (d_out, r) S = layer.lora_S.to(x.dtype) # (r,) Vh = layer.lora_Vh.to(x.dtype) # (r, d_in) + g = layer.lora_g.to(x.dtype) 
# (r,) + coeff = float(cfg.coeff) - if rotate_basis == "none": - U_eff, Vh_eff = U, Vh - else: - R_blocks = _build_rotation(layer.lora_rot_T.float(), bs, max_angle).to(x.dtype) - n_blocks = R_blocks.shape[0] # R_blocks: (n, bs, bs) - if rotate_basis == "V": - Vh_blocks = rearrange(Vh, "(n a) i -> n a i", n=n_blocks) - Vh_rot = einsum(R_blocks, Vh_blocks, "n a b, n b i -> n a i") - Vh_eff = rearrange(Vh_rot, "n a i -> (n a) i") - U_eff = U - elif rotate_basis == "U": - U_blocks = rearrange(U, "d (n b) -> d n b", n=n_blocks) - U_rot = einsum(U_blocks, R_blocks, "d n b, n c b -> d n c") - U_eff = rearrange(U_rot, "d n c -> d (n c)") - Vh_eff = Vh - else: - raise ValueError(f"rotate_basis must be 'U', 'V', or 'none', got {rotate_basis!r}") + if cfg.suppress_only: + g = torch.clamp(g, max=0.0) # factor in (0,1]: attenuation only - # FIXME: try lora_delta_s as [r,k] this is because the main limit of this adapter is that it's under parametised here. `reduce(h @ U_eff.T, '... k -> ...'). But have to make sure it's not lienarly reducable to one adapter. - S_eff = S + layer.lora_delta_s.to(x.dtype) # (r,) - h = x @ Vh_eff.T # x @ Vh_eff.T - h = h * S_eff # diag(S_eff) - delta = h @ U_eff.T # @ U_eff.T - return y + delta + # Per-direction reweighting: S_eff = S * (1 + ELU(coeff * g)). + # 1 + ELU(z) = exp(z) for z<=0, 1+z for z>0. + # Why this and not the obvious ones (all of which we tried): + # linear S*(1+z) : constant gradient (stable), but z<-1 -> S_eff<0, + # a sign flip that drives incoherence. Unstable in + # the negatives. + # exp S*exp(z) : positive, but unbounded and the gradient self- + # amplifies (d/dz exp = exp), so amplification blows up. + # tanh S*exp(c*tanh z): bounded, but c is an arbitrary free knob with no + # principled value, and saturation kills the gradient. + # 1+ELU : uses each in its safe regime -- exp only where it is + # bounded in (0,1] (attenuation, cannot go negative), + # linear where exp would diverge (amplification, const + # gradient). 
C1 at z=0 (both -> 1, slope 1); >0 always. + # coeff=0 or g=0 -> S_eff = S (identity). coeff<0 swaps amplify/suppress. + S_eff = S * (1.0 + torch.nn.functional.elu(coeff * g)) + + h = (x @ Vh.T) * S_eff # input in S-coords, reweighted + return y + h @ U.T diff --git a/src/lora_lite/variants/antipasto_ablate.py b/src/lora_lite/variants/antipasto_ablate.py new file mode 100644 index 0000000..34d26a8 --- /dev/null +++ b/src/lora_lite/variants/antipasto_ablate.py @@ -0,0 +1,194 @@ +"""AntiPaSTO-Ablate: trainable directional ablation in the weight-SVD output basis. + +A contractive sibling of antipasto.py. Instead of reweighting the singular gains, +it projects out a learned direction in the *output* singular basis (the U side): + + W = U diag(S) Vh + W_res + learn: c (r, k) ablation directions, alpha (k,) strengths in [0, 1] + Chat = orthonormal(c) # k unit dirs in S-space + h = (x @ Vh.T) * S # output S-coords (= diag(S) Vh x) + h <- h - coeff * (h @ Chat) * alpha @ Chat.T # project the span out + y = x @ W_res.T + h @ U.T + +Why this instead of gain reweighting (antipasto.py): + - The core (I - alpha Chat Chat^T) is a CONTRACTION: eigenvalues are 1-alpha along + Chat and 1 elsewhere, all in [0, 1] for alpha in [0, 1]. It cannot amplify and + cannot blow up, so the failure mode the multiplicative gain fights with bounds is + structurally absent. It is also the natural core to recurse (a contraction composed + with itself converges; an amplifier diverges). + - It is the trainable form of directional ablation (Arditi+ 2024). Ablating Chat in + the middle removes output direction U Chat; for a residual *writer* + (mlp.down_proj, self_attn.o_proj) that is a residual-stream direction -- the + SURGICAL regime in the steering-lite sweeps (directional_ablation topped SI). + Target writers, not all Linears, or you get the broad-suppression regime. + +Runtime: coeff is the per-call knob. coeff=0 -> identity. coeff in (0, 1] -> ablate. 
+coeff < 0 -> *add* the direction back (amplify) -- the bidirectional dual; this is the +side that can grow, so bound coeff there. + +Init: alpha small (>0 so c receives gradient), c random-normalized. The strong init is +to warm-start c from the contrastive direction dS in S-space (extract it exactly like +sspace.py: dS = mean(xS_pos) - mean(xS_neg) on persona-branching pairs), then fine-tune. +""" +from dataclasses import dataclass +from typing import Iterable + +import torch +from einops import rearrange +from jaxtyping import Float +from torch import nn, Tensor as T + +from ..variant import register, ParamSpec +from ..config import AdapterConfig, register_config + +CalibrationBatch = dict | tuple | list | T +CalibrationData = Iterable[CalibrationBatch] + +ε = 1e-6 + + +@register_config +@dataclass +class AntiPaSTOAblateConfig(AdapterConfig): + variant: str = "antipasto_ablate" + r: int = 256 # top-r SVD captured (or |dS|-selected via group_init) + k: int = 1 # number of ablation directions (rank of the projection) + init_alpha: float = 0.05 # small >0 so c gets gradient at step 0 + coeff: float = 1.0 # runtime: 0=identity, (0,1]=ablate, <0=amplify (bound this side) + # CorDA-orient the basis from input covariance (group_init, needs calibration_data). + # The ablation is OUTPUT-side and CorDA's U stays orthonormal, so this is a clean + # contraction; the win is at low r -- the data-oriented top-r captures the behavior + # output direction that plain-SVD top-r drops (measured 1.00 vs 0.65 at r=16). 
+ cov_orient: bool = False + cov_eps: float = 1e-3 + + +@register +class AntiPaSTOAblate: + name = "antipasto_ablate" + + @staticmethod + def param_specs(d_in, d_out, cfg): + r, k = cfg.r, cfg.k + return dict( + lora_U=ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True), + lora_S=ParamSpec((r,), init="zeros", trainable=False, as_buffer=True), + lora_Vh=ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True), + # Trainable: k ablation directions in S-space, and their strengths. + lora_c=ParamSpec((r, k), init=lambda t: t.normal_(0, 1.0 / max(r, 1) ** 0.5)), + lora_alpha=ParamSpec((k,), init=lambda t: t.fill_(float(cfg.init_alpha))), + ) + + @staticmethod + def init(layer: nn.Module, cfg) -> None: + if type(layer) is not nn.Linear: + raise TypeError("AntiPaSTOAblate mutates layer.weight into W_res; nn.Linear only.") + with torch.no_grad(): + W = layer.weight.data.float() + U, S, Vh = torch.linalg.svd(W, full_matrices=False) + r = cfg.r + Ur, Sr, Vhr = U[:, :r], S[:r], Vh[:r, :] + layer.lora_U.copy_(Ur.to(layer.lora_U.dtype)) + layer.lora_S.copy_(Sr.to(layer.lora_S.dtype)) + layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype)) + W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype) + layer.weight.data.copy_(W_res) + # Optional but recommended: group_init() should warm-start lora_c from the + # S-space contrastive direction dS (see sspace.py extract). Random init also + # trains, just slower and with no guarantee it finds the behavior direction. + + @staticmethod + def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None: + """If cov_orient, re-orient each target's SVD by input covariance C=E[x x^T] + (CorDA) so the data-relevant output directions land in the top-r and the + behavior direction is fully ablatable at low r. No-op otherwise (keeps the + plain-SVD basis from init()). 
C is accumulated on CPU; for down_proj's large + d_in this is heavy -- exclude it or use plain ablation there.""" + if not getattr(cfg, "cov_orient", False) or calibration_data is None: + return + + layers = {name: layer for name, layer, _ in targets} + cov: dict[str, T] = {} + cnt: dict[str, int] = {n: 0 for n in layers} + + def make_hook(name): + def _h(module, args, kwargs): + x = rearrange(args[0].detach(), "... d -> (...) d").to(torch.float32).cpu() + g = x.T @ x + cov[name] = g if name not in cov else cov[name] + g + cnt[name] += x.shape[0] + return _h + + handles = [layers[n].register_forward_pre_hook(make_hook(n), with_kwargs=True) for n in layers] + try: + was_training = model.training + model.eval() + with torch.no_grad(): + for batch in calibration_data: + if isinstance(batch, dict): + model(**batch) + elif isinstance(batch, (list, tuple)): + model(*batch) + else: + model(batch) + if was_training: + model.train() + finally: + for h in handles: + h.remove() + + r, eps = cfg.r, float(cfg.cov_eps) + for name, layer in layers.items(): + if cnt[name] < r: + raise RuntimeError(f"AntiPaSTOAblate at {name}: {cnt[name]} tokens, need >= r={r}") + W_res = layer.weight.data.float().cpu() + U_old, S_old, Vh_old = (layer.lora_U.float().cpu(), + layer.lora_S.float().cpu(), + layer.lora_Vh.float().cpu()) + W_orig = W_res + (U_old * S_old) @ Vh_old + + C = cov[name] / cnt[name] + lam, Q = torch.linalg.eigh(C) + lam = lam.clamp_min(0) + eps + Chalf = (Q * lam.sqrt()) @ Q.T + Cinvhalf = (Q * lam.rsqrt()) @ Q.T + Ut, St, Vht = torch.linalg.svd(W_orig @ Chalf, full_matrices=False) + Ur = Ut[:, :r] # orthonormal output basis (ablation acts here) + Sr = St[:r] + Pr = Vht[:r] @ Cinvhalf # oblique input projector (input-side only) + W_res_new = W_orig - (Ur * Sr) @ Pr + + with torch.no_grad(): + layer.lora_U.copy_(Ur.to(layer.lora_U)) + layer.lora_S.copy_(Sr.to(layer.lora_S)) + layer.lora_Vh.copy_(Pr.to(layer.lora_Vh)) # store P in the Vh slot + 
layer.weight.data.copy_(W_res_new.to(layer.weight)) + + @staticmethod + def _orthonormal(c: T) -> T: + """(r, k) -> (r, k) with orthonormal columns. k=1 is a plain normalize.""" + if c.shape[-1] == 1: + return c / (c.norm(dim=0, keepdim=True) + ε) + # geqrf has no bf16/fp16 kernel (CPU or CUDA); do the QR in fp32, cast back. + q, _ = torch.linalg.qr(c.float()) # reduced QR; columns orthonormal + return q.to(c.dtype) + + @staticmethod + def forward( + layer: nn.Module, + x: Float[T, '*B i'], + y: Float[T, '*B o'], + ) -> Float[T, '*B o']: + cfg = layer._lora_cfg + U = layer.lora_U.to(x.dtype) # (d_out, r) + S = layer.lora_S.to(x.dtype) # (r,) + Vh = layer.lora_Vh.to(x.dtype) # (r, d_in) + Chat = AntiPaSTOAblate._orthonormal(layer.lora_c.to(x.dtype)) # (r, k) + alpha = layer.lora_alpha.to(x.dtype).clamp(0.0, 1.0) # (k,) + coeff = float(cfg.coeff) + + h = (x @ Vh.T) * S # (..., r) output S-coords + proj = h @ Chat # (..., k) component along each dir + # contractive removal: h <- h - coeff * Sum_j alpha_j (h . chat_j) chat_j + h = h - coeff * (proj * alpha) @ Chat.T # (..., r) + return y + h @ U.T diff --git a/src/lora_lite/variants/antipasto_corda.py b/src/lora_lite/variants/antipasto_corda.py new file mode 100644 index 0000000..68b989e --- /dev/null +++ b/src/lora_lite/variants/antipasto_corda.py @@ -0,0 +1,203 @@ +"""AntiPaSTO-CorDA: steer in a covariance-ORIENTED basis, not the weight-gain basis. + +The complaint that motivates this: plain SVD sorts directions by weight gain ||W v|| +on an *isotropic* input. The behaviour you steer lives where the *data* has energy. +Those orderings disagree, so the behaviour smears off the top singular axes and a +top-r crop in the weight basis throws it away. CorDA (Yang+ 2024, arXiv:2406.05223) +re-orients the decomposition by the input covariance C = E[x x^T], so the top +directions are the ones with the most energy *on real activations*. 
+ +Decomposition (verified: full-rank reconstruction ~1e-5, and on anisotropic data the +top-r data-truncation error drops ~27x vs plain SVD): + + C = E[x x^T] (+ eps I) # input second moment on calibration data + C^{1/2}, C^{-1/2} via eigh(C) + W~ = W C^{1/2}; SVD(W~) = U S V~h + P = V~h C^{-1/2} # (r, d_in) OBLIQUE input projector + W = U diag(S) P (exactly) # so y = x W_res^T + ((x P^T) * S_eff) U^T + +S here are the singular values of W weighted by input std, so top-r is the optimal +rank-r in the input-weighted norm E||(W - W_r) x||^2 -- the directions that actually +move the output on your data. + +Connection to the shared/differing-basis problem: C is built from pos AND neg inputs +pooled, so P spans the *shared* activation structure (the common encoder) that +chosen-minus-rejected cancels by construction. A trainable gain on this basis can +therefore reach shared structure that contrastive dS extraction is blind to. + +Core: rotation-free. S_eff = S * (1 + ELU(coeff * g)). This is exp(coeff*g) on the +attenuation side (g<0, bounded, no blow-up) and 1+coeff*g on the amplification side +(g>0, where exp would diverge). g=0 -> identity. coeff is the runtime knob (0=off). + +Basis note: P is OBLIQUE (rows not orthonormal -- C^{-1/2} skews them). That is fine +for gain reweighting (we scale oblique coordinates), and also fine for OUTPUT-side +directional ablation: the obliqueness is input-side only, while ablation acts in the +U/output space where U stays orthonormal. antipasto_ablate has a cov_orient flag that +reuses this basis -- at low r it captures the behavior output direction that plain-SVD +top-r drops (measured 1.00 vs 0.65 at r=16). + +Falls back to plain SVD (== antipasto, rotation-free) if no calibration_data. 
+""" +from dataclasses import dataclass +from typing import Iterable + +import torch +import torch.nn.functional as F +from einops import rearrange +from jaxtyping import Float +from torch import nn, Tensor as T + +from ..variant import register, ParamSpec +from ..config import AdapterConfig, register_config + +CalibrationBatch = dict | tuple | list | T +CalibrationData = Iterable[CalibrationBatch] + + +@register_config +@dataclass +class AntiPaSTOCorDAConfig(AdapterConfig): + variant: str = "antipasto_corda" + r: int = 256 + cov_eps: float = 1e-3 # damping on C eigenvalues; guards C^{-1/2} on rare dirs + coeff: float = 1.0 # runtime steer knob: 0=identity, scales trained g + suppress_only: bool = False # clamp g<=0 (attenuate only; no amplification) + + +def _gain(S: T, g: T, coeff: float, suppress_only: bool) -> T: + """S_eff = S * (1 + ELU(coeff*g)); exp-bounded attenuation, linear amplification.""" + if suppress_only: + g = g.clamp(max=0.0) + return S * (1.0 + F.elu(coeff * g)) + + +@register +class AntiPaSTOCorDA: + name = "antipasto_corda" + + @staticmethod + def param_specs(d_in, d_out, cfg): + r = cfg.r + return dict( + lora_U=ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True), + lora_S=ParamSpec((r,), init="zeros", trainable=False, as_buffer=True), + # P replaces Vh: oblique covariance-oriented input projector. + lora_P=ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True), + # Trainable per-direction log-scale. init 0 -> 1+ELU(0)=1 -> exact identity. + # No sign-symmetry hack needed (1+ELU is sign-preserving, basis frozen), + # matching antipasto.py. + lora_g=ParamSpec((r,), init="zeros"), + ) + + @staticmethod + def init(layer: nn.Module, cfg) -> None: + """Plain-SVD fallback so the adapter is valid before group_init. 
group_init + refines P/U/S to the covariance-oriented basis when calibration_data is given.""" + if type(layer) is not nn.Linear: + raise TypeError("AntiPaSTOCorDA mutates layer.weight into W_res; nn.Linear only.") + with torch.no_grad(): + W = layer.weight.data.float() + U, S, Vh = torch.linalg.svd(W, full_matrices=False) + r = cfg.r + Ur, Sr, Vhr = U[:, :r], S[:r], Vh[:r, :] + layer.lora_U.copy_(Ur.to(layer.lora_U.dtype)) + layer.lora_S.copy_(Sr.to(layer.lora_S.dtype)) + layer.lora_P.copy_(Vhr.to(layer.lora_P.dtype)) # P := Vh until oriented + W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype) + layer.weight.data.copy_(W_res) + + @staticmethod + def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None: + """Re-orient each target's SVD by its input covariance C = E[x x^T]. + + Without calibration_data the plain-SVD init from init() is kept (so this + degrades to antipasto, rotation-free). + + Called by attach() BEFORE any training, so the trainable g is still at its + zero init when the basis changes -- re-orienting zero gains is a no-op, no + re-indexing needed. Do not call group_init after training has updated g.""" + if calibration_data is None: + return + + layers = {name: layer for name, layer, _ in targets} + # accumulate C = sum x x^T on CPU. Peak GPU cost would otherwise be + # sum_targets d_in^2 fp32 held at once; for down_proj (d_in=intermediate, + # e.g. 14336) that is ~0.8 GB *per layer* and OOMs. CPU accumulation bounds + # GPU use to the live activation; the eigh/SVD below run on CPU (one-time). + # Diagonal C is NOT a usable shortcut: it misses cross-channel correlation, + # which is where the orientation gain lives (measured ~= plain SVD). + # If down_proj's d_in^2 is too big even on CPU/RAM, exclude it from CorDA + # (leave it on plain antipasto) or use a low-rank C (top-k eig of subsampled + # inputs) -- not implemented here. 
+ cov: dict[str, T] = {} + cnt: dict[str, int] = {n: 0 for n in layers} + + def make_hook(name): + def _h(module, args, kwargs): + x = rearrange(args[0].detach(), "... d -> (...) d").to(torch.float32).cpu() + g = x.T @ x # (d_in, d_in) on CPU + cov[name] = g if name not in cov else cov[name] + g + cnt[name] += x.shape[0] + return _h + + handles = [layers[n].register_forward_pre_hook(make_hook(n), with_kwargs=True) for n in layers] + try: + was_training = model.training + model.eval() + with torch.no_grad(): + for batch in calibration_data: + if isinstance(batch, dict): + model(**batch) + elif isinstance(batch, (list, tuple)): + model(*batch) + else: + model(batch) + if was_training: + model.train() + finally: + for h in handles: + h.remove() + + r, eps = cfg.r, float(cfg.cov_eps) + for name, layer in layers.items(): + if cnt[name] < r: + raise RuntimeError(f"AntiPaSTOCorDA at {name}: {cnt[name]} tokens, need >= r={r}") + # decomposition on CPU (where C lives); copy results back to device buffers. 
+ W_res = layer.weight.data.float().cpu() + U_old, S_old, P_old = (layer.lora_U.float().cpu(), + layer.lora_S.float().cpu(), + layer.lora_P.float().cpu()) + W_orig = W_res + (U_old * S_old) @ P_old + + C = cov[name] / cnt[name] # (d_in, d_in) CPU + lam, Q = torch.linalg.eigh(C) + lam = lam.clamp_min(0) + eps + Chalf = (Q * lam.sqrt()) @ Q.T # C^{1/2} + Cinvhalf = (Q * lam.rsqrt()) @ Q.T # C^{-1/2} + + Ut, St, Vht = torch.linalg.svd(W_orig @ Chalf, full_matrices=False) + Ur = Ut[:, :r] # (d_out, r) + Sr = St[:r] # (r,) + Pr = (Vht[:r] @ Cinvhalf) # (r, d_in) oblique projector + W_res_new = (W_orig - (Ur * Sr) @ Pr) + + with torch.no_grad(): + layer.lora_U.copy_(Ur.to(layer.lora_U)) + layer.lora_S.copy_(Sr.to(layer.lora_S)) + layer.lora_P.copy_(Pr.to(layer.lora_P)) + layer.weight.data.copy_(W_res_new.to(layer.weight)) + + @staticmethod + def forward( + layer: nn.Module, + x: Float[T, '*B i'], + y: Float[T, '*B o'], + ) -> Float[T, '*B o']: + cfg = layer._lora_cfg + U = layer.lora_U.to(x.dtype) # (d_out, r) + S = layer.lora_S.to(x.dtype) # (r,) + P = layer.lora_P.to(x.dtype) # (r, d_in) oblique + g = layer.lora_g.to(x.dtype) # (r,) + S_eff = _gain(S, g, float(cfg.coeff), bool(cfg.suppress_only)) + h = (x @ P.T) * S_eff # (..., r) + return y + h @ U.T diff --git a/src/lora_lite/variants/antipasto_rot.py b/src/lora_lite/variants/antipasto_rot.py new file mode 100644 index 0000000..db7c49d --- /dev/null +++ b/src/lora_lite/variants/antipasto_rot.py @@ -0,0 +1,232 @@ +"""AntiPaSTO-Rot: the original SVD adapter with learnable singular-value deltas + +block-diagonal Cayley rotation. Kept as a SEPARATE variant so we can benchmark the +rotation version against the rotation-free 1+ELU gain (antipasto.py) head to head. 
+ +wassname 2026 https://arxiv.org/abs/2601.07473 + + W = U diag(S) Vh + W_res (top-r SVD; W_res = W - U_r S_r Vh_r) + learn: delta_s (r,), rot_T (n_blocks, bs(bs-1)/2) + R = block_diag(Cayley(skew(rot_T))); Vh_eff = R @ Vh (or U_eff = U @ R.T) + y = x @ W_res.T + ((x @ Vh_eff.T) * (S + delta_s)) @ U_eff.T + +Identity at t=0: rot_T=0 -> R=I, delta_s~4e-4 -> y ~ x @ W^T (tiny positive bias on +delta_s breaks sign symmetry). + +Why antipasto.py dropped the rotation: rotating Vh/U leaves the interpretable singular +basis, and the Cayley solve was numerically finicky. This file preserves it for the +all-else-equal comparison (does the cross-direction mixing the rotation buys beat the +cheaper, more stable gain-only adapter on the same targets and budget?). + +Refs: + - paper: https://github.com/wassname/AntiPaSTO + - lite port of: https://github.com/wassname/antipasto3 + (offline: docs/refs/antipasto3_svd_adapter.py) +""" +import math +from dataclasses import dataclass +from typing import Iterable, Literal + +import torch +from einops import einsum, rearrange +from jaxtyping import Float +from torch import nn, Tensor as T + +from ..variant import register, ParamSpec +from ..config import AdapterConfig, register_config + +CalibrationBatch = dict | tuple | list | T +CalibrationData = Iterable[CalibrationBatch] + + +@register_config +@dataclass +class AntiPaSTORotConfig(AdapterConfig): + variant: str = "antipasto_rot" + # Higher default than LoRA (r=8) since trainable params scale as r + r/bs*bs*(bs-1)/2, not r*(d_in+d_out). + r: int = 256 + # Block size for the block-diagonal Cayley rotation. r must be divisible by it. + block_size: int = 4 + # Cayley map saturation: bounds rotation angle to ~max_rotation_angle radians. + max_rotation_angle: float = 0.5 + # Which singular basis to rotate: 'V' (input), 'U' (output), or 'none'. 
+ rotate_basis: Literal["V", "U", "none"] = "V" + + +def _cayley(skew: torch.Tensor) -> torch.Tensor: + """R = (I - X)^-1 (I + X) for X = skew/2; preserves orthogonality.""" + bs = skew.shape[-1] + eye = torch.eye(bs, dtype=skew.dtype, device=skew.device).expand_as(skew) + X = skew / 2 + return torch.linalg.solve(eye - X, eye + X) + + +def _build_rotation(rot_T: torch.Tensor, bs: int, max_angle: float) -> torch.Tensor: + """rot_T: (n_blocks, bs*(bs-1)/2) -> R: (n_blocks, bs, bs) Cayley rotation.""" + n_blocks, _ = rot_T.shape + rows, cols = torch.triu_indices(bs, bs, offset=1, device=rot_T.device).unbind(0) + A = torch.zeros(n_blocks, bs, bs, dtype=rot_T.dtype, device=rot_T.device) + A[:, rows, cols] = rot_T + A = 0.5 * (A - A.transpose(-1, -2)) + a_limit = 2.0 * math.tan(max_angle / 2.0) + A = a_limit * torch.tanh(A / a_limit) + return _cayley(A) + + +@register +class AntiPaSTORot: + name = "antipasto_rot" + + @staticmethod + def param_specs(d_in, d_out, cfg): + r = cfg.r + bs = int(cfg.block_size) + if r % bs != 0: + raise ValueError(f"AntiPaSTORot requires r={r} divisible by block_size={bs}") + specs = dict( + # Frozen SVD components captured at init. + lora_U=ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True), + lora_S=ParamSpec((r,), init="zeros", trainable=False, as_buffer=True), + lora_Vh=ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True), + # Trainable: per-singular-value delta. + # antipasto3 uses 4e-4 + N(0, 4e-4): small positive bias breaks sign + # symmetry (rotation alone can't); zero-init works but trains slower. 
+ lora_delta_s=ParamSpec((r,), init=lambda t: t.normal_(0, 4e-4).add_(4e-4)), + ) + if cfg.rotate_basis != "none": + n_blocks = r // bs + n_triu = bs * (bs - 1) // 2 + specs["lora_rot_T"] = ParamSpec((n_blocks, n_triu), init="zeros") + return specs + + @staticmethod + def init(layer: nn.Module, cfg) -> None: + if type(layer) is not nn.Linear: + raise TypeError( + "AntiPaSTORot mutates layer.weight into W_res (like PiSSA), so v1 " + "only supports plain nn.Linear, not bnb 4/8-bit." + ) + with torch.no_grad(): + W = layer.weight.data.float() + U, S, Vh = torch.linalg.svd(W, full_matrices=False) + r = cfg.r + Ur, Sr, Vhr = U[:, :r], S[:r], Vh[:r, :] + layer.lora_U.copy_(Ur.to(layer.lora_U.dtype)) + layer.lora_S.copy_(Sr.to(layer.lora_S.dtype)) + layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype)) + W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype) + layer.weight.data.copy_(W_res) + # group_init() refines this to input-aligned directions if calibration_data is given. + + @staticmethod + def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None: + """Wanda-style data-driven dimension selection within the weight SVD. + + init() picks the top-r singular dimensions by S alone (PiSSA-style). + group_init() re-selects based on S[i] * mean|X @ Vh[i]|: dimensions + that are both large in W AND active given real inputs. + + If calibration_data is None the weight-SVD init from init() is kept. + """ + if calibration_data is None: + return + + layers = {name: layer for name, layer, _ in targets} + captured: dict[str, list[T]] = {n: [] for n in layers} + + def make_hook(name): + def _h(module, args, kwargs): + x = args[0].detach() + captured[name].append(rearrange(x, "... d -> (...) 
d").to(torch.float32).cpu()) + return _h + + handles = [ + layers[n].register_forward_pre_hook(make_hook(n), with_kwargs=True) + for n in layers + ] + try: + was_training = model.training + model.eval() + with torch.no_grad(): + for batch in calibration_data: + if isinstance(batch, dict): + model(**batch) + elif isinstance(batch, (list, tuple)): + model(*batch) + else: + model(batch) + if was_training: + model.train() + finally: + for h in handles: + h.remove() + + r = cfg.r + for name, layer in layers.items(): + X = torch.cat(captured[name], dim=0) # (N, d_in) + if X.shape[0] < r: + raise RuntimeError( + f"AntiPaSTORot at {name}: only {X.shape[0]} calibration tokens, need >= r={r}" + ) + + # Recover W_orig: init() wrote W_res into layer.weight and stored top-r components + W_res = layer.weight.data.float() + U_old = layer.lora_U.float() # (d_out, r) + S_old = layer.lora_S.float() # (r,) + Vh_old = layer.lora_Vh.float() # (r, d_in) + W_orig = W_res + (U_old * S_old.unsqueeze(0)) @ Vh_old + + # Full SVD to score all dimensions + U_full, S_full, Vh_full = torch.linalg.svd(W_orig, full_matrices=False) + # score[i] = S[i] * mean|X @ Vh[i]| (Wanda: weight magnitude × activation magnitude) + act_mag = (X.to(Vh_full) @ Vh_full.T).abs().mean(dim=0) # (k,) -- X captured on CPU + scores = S_full * act_mag + idx = scores.argsort(descending=True)[:r] # top-r by joint importance + idx = idx.sort().values # stable ordering + + Ur, Sr, Vhr = U_full[:, idx], S_full[idx], Vh_full[idx] + W_res_new = (W_orig - (Ur * Sr.unsqueeze(0)) @ Vhr).to(layer.weight.dtype) + + with torch.no_grad(): + layer.lora_U.copy_(Ur.to(layer.lora_U)) + layer.lora_S.copy_(Sr.to(layer.lora_S)) + layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh)) + layer.weight.data.copy_(W_res_new) + + @staticmethod + def forward( + layer: nn.Module, + x: Float[T, '*B i'], + y: Float[T, '*B o'], + ) -> Float[T, '*B o']: + cfg = layer._lora_cfg + bs = int(cfg.block_size) + max_angle = float(cfg.max_rotation_angle) + rotate_basis = 
cfg.rotate_basis + + U = layer.lora_U.to(x.dtype) # (d_out, r) + S = layer.lora_S.to(x.dtype) # (r,) + Vh = layer.lora_Vh.to(x.dtype) # (r, d_in) + + if rotate_basis == "none": + U_eff, Vh_eff = U, Vh + else: + R_blocks = _build_rotation(layer.lora_rot_T.float(), bs, max_angle).to(x.dtype) + n_blocks = R_blocks.shape[0] # R_blocks: (n, bs, bs) + if rotate_basis == "V": + Vh_blocks = rearrange(Vh, "(n a) i -> n a i", n=n_blocks) + Vh_rot = einsum(R_blocks, Vh_blocks, "n a b, n b i -> n a i") + Vh_eff = rearrange(Vh_rot, "n a i -> (n a) i") + U_eff = U + elif rotate_basis == "U": + U_blocks = rearrange(U, "d (n b) -> d n b", n=n_blocks) + U_rot = einsum(U_blocks, R_blocks, "d n b, n c b -> d n c") + U_eff = rearrange(U_rot, "d n c -> d (n c)") + Vh_eff = Vh + else: + raise ValueError(f"rotate_basis must be 'U', 'V', or 'none', got {rotate_basis!r}") + + S_eff = S + layer.lora_delta_s.to(x.dtype) # (r,) + h = x @ Vh_eff.T # x @ Vh_eff.T + h = h * S_eff # diag(S_eff) + delta = h @ U_eff.T # @ U_eff.T + return y + delta diff --git a/tests/test_metamath_smoke.py b/tests/test_metamath_smoke.py index 5a5283a..e66e1c6 100644 --- a/tests/test_metamath_smoke.py +++ b/tests/test_metamath_smoke.py @@ -14,7 +14,6 @@ from __future__ import annotations import importlib.util import sys -from dataclasses import replace from pathlib import Path import pytest @@ -31,11 +30,12 @@ sys.modules[SPEC.name] = benchmark SPEC.loader.exec_module(benchmark) -VARIANTS = ["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "road"] +VARIANTS = ["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", + "antipasto", "antipasto_rot", "antipasto_ablate", "antipasto_corda", "road"] # Variants that fail loud when attached on a bnb-loaded base (read dense weight in init). # delora/eva also read weight but currently silently dequant -- they produce sane attach, # so we don't expect a raise from them in the attach-only smoke. 
-BNB_RAISERS = {"pissa", "dora", "antipasto"} +BNB_RAISERS = {"pissa", "dora", "antipasto", "antipasto_rot", "antipasto_ablate", "antipasto_corda"} TINY_MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM" HAS_CUDA = torch.cuda.is_available() @@ -75,8 +75,6 @@ def quick_cfg(variant: str, tmp_path: Path, quantization: str = "none") -> "benc log_every=1000, output_dir=tmp_path / "out", ) - if variant == "antipasto": - cfg = replace(cfg, alpha=4) # block_size=4 -> need r % 4 == 0 return cfg