diff --git a/README.md b/README.md index 50621e7..558da85 100644 --- a/README.md +++ b/README.md @@ -48,47 +48,31 @@ just qwen-probe # Qwen/Qwen3-0.6B train/save-load probe ## Variants Trained on a MetaMathQA subset, tested on GSM8K, all on `Qwen/Qwen3.5-0.8B-Base` targeting -`down_proj` in all 24 layers (2500 steps, effective batch 8 = 20k samples). Standard adapters -use r=32; the AntiPaSTO family uses r=256 (it tunes only S-space gain, so it needs the rank). +`down_proj` in all 24 layers (2500 steps, effective batch 8 = 20k samples). -| Variant | test % | valid % | Params | +MACs/tok | fwd/bwd (ms) | init (s) | -| --------------------------------------------- | -----: | ------: | ------: | --------: | -----------: | -------: | -| [DoRA](https://arxiv.org/abs/2402.09353) | 60.2 | 68.0 | 3.56M | 3.54M | 161 / 556 | 0.16 | -| [LoRA](https://arxiv.org/abs/2106.09685) | 59.8 | 68.0 | 3.54M | 3.54M | 173 / 573 | 0.02 | -| [PiSSA](https://arxiv.org/abs/2404.02948) | 59.8 | 76.0 | 3.54M | 3.54M | 146 / 549 | 2.04 | -| [HRA](https://arxiv.org/abs/2405.17484) | 59.2 | 70.0 | 2.75M | 2.75M | 225 / 948 | 0.04 | -| [EVA](https://arxiv.org/abs/2410.07170) | 59.3 | 74.0 | 3.54M | 3.54M | 151 / 660 | 28.3 | -| [IA3-FF](https://arxiv.org/pdf/2205.05638) | 56.3 | 62.0 | 0.086M | 0M | 140 / 510 | 0.01 | -| [DeLoRA](https://arxiv.org/abs/2503.18225) | 56.2 | 62.0 | 3.54M | 3.54M | 169 / 593 | 0.21 | -| [AntiPaSTO](https://arxiv.org/abs/2601.07473) | 56.0 | 62.0 | 0.0061M | 28.3M | 166 / 571 | 2.5 | -| AntiPaSTO-rot | 57.2 | 60.0 | 0.0154M | 28.3M | 165 / 596 | 2.0 | -| AntiPaSTO-ablate | 56.0 | 68.0 | 0.0062M | 28.3M | 166 / 580 | 2.2 | -| AntiPaSTO-dplr | 56.0 | 64.0 | 0.1044M | 28.4M | 153 / 582 | 3.6 | -| AntiPaSTO-ASVD (diag C) | 55.6 | 64.0 | 0.0061M | 28.3M | 150 / 533 | 34 | -| AntiPaSTO-CorDA (full C) | 54.7 | 58.0 | 0.0061M | 28.3M | 146 / 576 | 120 | -| [IA3](https://arxiv.org/pdf/2205.05638) | 52.3 | 62.0 | 0.0061M | 0M | 161 / 515 | 0.01 | +| Variant | r | test % | 
valid % | Params | +MACs/tok | fwd/bwd (ms) | init (s) | +| --------------------------------------------- | ---: | -----: | ------: | ------: | --------: | -----------: | -------: | +| [DoRA](https://arxiv.org/abs/2402.09353) | 32 | 60.2 | 68.0 | 3.56M | 3.54M | 161 / 556 | 0.16 | +| [LoRA](https://arxiv.org/abs/2106.09685) | 32 | 59.8 | 68.0 | 3.54M | 3.54M | 173 / 573 | 0.02 | +| [PiSSA](https://arxiv.org/abs/2404.02948) | 32 | 59.8 | 76.0 | 3.54M | 3.54M | 146 / 549 | 2.04 | +| [EVA](https://arxiv.org/abs/2410.07170) | 32 | 59.3 | 74.0 | 3.54M | 3.54M | 151 / 660 | 28.3 | +| [HRA](https://arxiv.org/abs/2405.17484) | 32 | 59.2 | 70.0 | 2.75M | 2.75M | 225 / 948 | 0.04 | +| [AntiPaSTO](https://arxiv.org/abs/2601.07473) | 256 | 57.2 | 60.0 | 0.015M | 28.3M | 165 / 596 | 2.0 | +| [IA3-FF](https://arxiv.org/pdf/2205.05638) | — | 56.3 | 62.0 | 0.086M | 0M | 140 / 510 | 0.01 | +| [DeLoRA](https://arxiv.org/abs/2503.18225) | 32 | 56.2 | 62.0 | 3.54M | 3.54M | 169 / 593 | 0.21 | +| [IA3](https://arxiv.org/pdf/2205.05638) | — | 52.3 | 62.0 | 0.006M | 0M | 161 / 515 | 0.01 | -test/valid % = GSM8K exact-match accuracy. Params = trainable adapter params. +MACs/tok = added -forward MACs per token (analytic, hardware-independent). fwd/bwd = median ms over one batch. -init = one-time calibration (CorDA's `d_in x d_in` covariance eigh; ~0 for the rest). Peak CUDA -memory is ~9.8 GB for every row. Empty rows fill in as the sweep lands. +r = adapter rank (— = not a low-rank method). test/valid % = GSM8K exact-match accuracy. Params = +trainable adapter params. +MACs/tok = added forward MACs per token (analytic, hardware-independent). +fwd/bwd = median ms over one batch. init = one-time calibration (EVA's PCA; ~0 for the rest). Peak +CUDA memory is ~9.8 GB for every row. Single seed, so accuracy differences within ~1.4pp (test +SE at n=1319) are noise. 
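The ~1.4pp noise figure quoted above can be reproduced from the binomial standard error of a proportion. The snippet below is a minimal sketch, not part of the benchmark code: `binomial_se` is a hypothetical helper, and it assumes each of the 1319 test items is an independent Bernoulli trial.

```python
import math


def binomial_se(p: float, n: int) -> float:
    """Standard error of an accuracy p measured on n independent items (hypothetical helper)."""
    return math.sqrt(p * (1.0 - p) / n)


n_test = 1319  # GSM8K test-set size quoted in the README text

# Lowest, AntiPaSTO, and highest test accuracies from the table above.
for acc in (0.523, 0.572, 0.602):
    se_pp = 100.0 * binomial_se(acc, n_test)
    print(f"acc={acc:.3f}  SE={se_pp:.2f}pp")

# SE of the difference between two independent accuracies near 0.6.
diff_se_pp = 100.0 * math.sqrt(2.0) * binomial_se(0.6, n_test)
print(f"difference SE ~ {diff_se_pp:.2f}pp")
```

Under the independence assumption, the SE of a difference between two rows is larger than the single-row SE by a factor of sqrt(2). All rows are scored on the same test items, so a paired comparison could be tighter; treat these numbers as a rough guide only.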
We validate our adapters the same way [PEFT](https://github.com/huggingface/peft/tree/main/method_comparison) does: train on a MetaMathQA subset and check meaningful GSM8K accuracy. See [this file](scripts/metamath_gsm8k_benchmark.py) for details. -AntiPaSTO is the novel row here: instead of adding trainable directions like LoRA, it freezes W's own top-r SVD and learns only a bounded per-direction gain `S_eff = S * (1 + ELU(g))`. The singular basis stays fixed and interpretable, and the adapter is O(r) params (the 6.1K gain is ~580x smaller than LoRA's 3.54M). The variants change only the basis or core: rot learns a small block-rotation of the frozen basis, CorDA/ASVD orient it by the input second moment (full covariance vs diagonal-only, [Yang+ 2024](https://arxiv.org/abs/2406.05223) / [Yuan+ 2023](https://arxiv.org/abs/2312.05821)), ablate learns a contractive directional ablation ([Arditi+ 2024](https://arxiv.org/abs/2406.11717)), dplr adds a small low-rank core for cross-direction mixing. +AntiPaSTO is the novel row here: instead of adding trainable directions like LoRA, it freezes W's own top-r SVD and learns only a per-direction singular-value delta plus a block-diagonal Cayley rotation of that frozen basis. The singular directions stay interpretable and the adapter is tiny (15K params, ~230x smaller than LoRA's 3.54M) yet lands within ~3pp of the best r=32 adapters (57.2 vs 59.2-60.2 test, single seed). The default rotates the input basis (V); rotating the output (U), both, or neither are `rotate_basis` ablation axes. -CorDA (full C) and ASVD (diag C) are a metric-axis ablation against plain AntiPaSTO (C=I): does -covariance orientation earn its `d_in x d_in` eigh over the cheap diagonal or no calibration at -all? On GSM8K/down_proj the answer is no: C=I 56.0, diag C 55.6, full C 54.7 (single seed). The -off-diagonal orientation is the slowest arm (120 s init vs 2.5 s) and lands slightly *below* no -calibration, so plain top-r SVD is the right default for this bounded-gain adapter here. 
- -AntiPaSTO-rot tunes that basis instead of the metric: a block-diagonal Cayley rotation of the -input (V), output (U), or both. The table row is V (the default); the ablation gives V 57.2 > -U 56.5 > both 55.6 (single seed). So rotating which inputs feed each frozen direction helps most, -the output-side rotation is slightly worse, and doing both is worst -- the second rotation is -redundant capacity that hurts. rot(V) is the best small-parameter arm overall (57.2 at 15K params -vs LoRA's 59.8 at 3.54M). +The full AntiPaSTO family (rotation-free gain core, the U/both rotation arms, contractive directional ablation [Arditi+ 2024](https://arxiv.org/abs/2406.11717), a low-rank mixing core, and CorDA/ASVD covariance-oriented bases [Yang+ 2024](https://arxiv.org/abs/2406.05223) / [Yuan+ 2023](https://arxiv.org/abs/2312.05821)) lives on the [`antipasto-variants`](https://github.com/wassname/lora-lite/tree/antipasto-variants) branch with its own ablation table. On GSM8K/down_proj none of those arms separated from this one at a single seed, and the covariance-oriented bases cost 34-120 s of init for no gain, so main keeps the cheapest arm that led: rotation of V. ## Developer docs diff --git a/justfile b/justfile index f5fe24c..a3cb4c5 100644 --- a/justfile +++ b/justfile @@ -83,7 +83,7 @@ metamath-queue variant="lora" steps="5000" model=model: # Run a single MetaMathQA->GSM8K benchmark for a given variant. # Per-variant lr / target-name defaults are baked in here. 
-bench-variant model variant steps="5000" lora_rank="8" r_override="" lr_override="" rotate_basis="V" seed="0": +bench-variant model variant steps="5000" r_override="" lr_override="" rotate_basis="V" seed="0": #!/usr/bin/env bash set -euo pipefail lr=1e-4 @@ -99,14 +99,14 @@ bench-variant model variant steps="5000" lora_rank="8" r_override="" lr_override delora) lr=1e-3 ;; ia3) lr=5e-3; target='(k_proj|v_proj)$' ;; ia3_ff) lr=5e-3; target='(down_proj)$' ;; - # antipasto cores tune only S-space gain/block (tiny params), so a small - # r leaves almost nothing trainable; r=256 is the variant default and - # matches the published AntiPaSTO row. alpha=r (no extra scaling). - antipasto*) lr=5e-3; r=256; alpha=256 ;; + # antipasto tunes only S-space deltas + a small block rotation (tiny params), + # so a small r leaves almost nothing trainable; r=256 is the variant default + # and matches the published AntiPaSTO row. alpha=r (no extra scaling). + antipasto) lr=5e-3; r=256; alpha=256 ;; esac - # r override (e.g. low-rank corda sweep); alpha tracks r for the antipasto family. + # r override (e.g. low-rank sweep); alpha tracks r for antipasto. if [ -n "{{r_override}}" ]; then r="{{r_override}}"; alpha="{{r_override}}"; fi - # lr override (e.g. dplr core wants a tamer lr than the gain's 5e-3). + # lr override (e.g. a tamer lr than antipasto's 5e-3 default). if [ -n "{{lr_override}}" ]; then lr="{{lr_override}}"; fi # 0.8B + large vocab: HF ForCausalLMLoss upcasts logits to fp32 (bs*seq*vocab*4), # which OOMs the 24GB card at the old bs=4/seq=768. 
micro-batch 2 fits at ~10GB; @@ -119,13 +119,12 @@ bench-variant model variant steps="5000" lora_rank="8" r_override="" lr_override --steps {{steps}} \ --lr "$lr" \ --target-name "$target" \ - --antipasto-lora-rank {{lora_rank}} \ --batch-size 2 --grad-accum 4 --max-seq-length 512 --batch-size-eval 16 \ --layers all --r "$r" --alpha "$alpha" \ --antipasto-rotate-basis '{{rotate_basis}}' \ --seed {{seed}} -metamath-queue-all model=model steps="2500" variants="lora pissa delora dora hra ia3 ia3_ff eva antipasto antipasto_rot antipasto_corda antipasto_asvd antipasto_ablate antipasto_dplr": +metamath-queue-all model=model steps="2500" variants="lora pissa delora dora hra ia3 ia3_ff eva antipasto": #!/usr/bin/env bash set -euo pipefail # One pueue job per variant (each runs the live code at run time, so editing diff --git a/scripts/cost_report.py b/scripts/cost_report.py index 3e9f9ea..993e415 100644 --- a/scripts/cost_report.py +++ b/scripts/cost_report.py @@ -2,14 +2,12 @@ Answers "which is best -- time / flops / adds / params?": MACs/token is the deterministic apples-to-apples compute number; trainable_params is the size headline; -wall-time is the felt-but-noisy number; group_init is where CorDA's eigh(d_in^3) bites. +wall-time is the felt-but-noisy number; group_init is the one-time init cost. Usage: uv run --extra benchmark python scripts/cost_report.py \ - --model Qwen/Qwen3-0.6B-Base --variants antipasto antipasto_corda antipasto_ablate lora \ + --model Qwen/Qwen3-0.6B-Base --variants antipasto lora pissa \ --target-name 'q_proj$' 'v_proj$' --r 32 --out logs/cost_qwen0.6b.log - -Point --target-name at down_proj to see the CorDA covariance corner (large d_in). 
""" from __future__ import annotations @@ -40,7 +38,6 @@ def build_cfg(variant: str, args, dtype) -> ll.AdapterConfig: bcfg = benchmark.BenchmarkConfig( model=args.model, variant=variant, r=args.r, alpha=float(args.r), target_name=list(args.target_name), layers=args.layers, torch_dtype=args.dtype, - antipasto_cov_orient=args.cov_orient, ) return benchmark.cfg_for_variant(bcfg, dtype) @@ -49,19 +46,16 @@ def main() -> None: ap = argparse.ArgumentParser() ap.add_argument("--model", default="Qwen/Qwen3-0.6B-Base") ap.add_argument("--variants", nargs="+", - default=["lora", "antipasto", "antipasto_rot", "antipasto_corda", - "antipasto_ablate", "antipasto_dplr"]) + default=["lora", "pissa", "antipasto"]) ap.add_argument("--target-name", nargs="+", default=[r"q_proj$", r"v_proj$"]) ap.add_argument("--r", type=int, default=32) ap.add_argument("--layers", default="all", - help="'all' or comma list e.g. '0,1' -- limit layers (CorDA down_proj eigh is slow).") + help="'all' or comma list e.g. '0,1' -- limit layers.") ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu") ap.add_argument("--dtype", default="bfloat16") ap.add_argument("--seq-len", type=int, default=256) ap.add_argument("--batch", type=int, default=2) ap.add_argument("--calib-batches", type=int, default=4) - ap.add_argument("--cov-orient", action="store_true", - help="CorDA-orient antipasto_ablate (measure the eigh corner).") ap.add_argument("--out", default="logs/cost.log") args = ap.parse_args() diff --git a/scripts/metamath_gsm8k_benchmark.py b/scripts/metamath_gsm8k_benchmark.py index 806227e..b6a27a7 100644 --- a/scripts/metamath_gsm8k_benchmark.py +++ b/scripts/metamath_gsm8k_benchmark.py @@ -35,11 +35,6 @@ CFG_BY_VARIANT = { "hra": ll.HRAConfig, "eva": ll.EVAConfig, "antipasto": ll.AntiPaSTOConfig, - "antipasto_rot": ll.AntiPaSTORotConfig, - "antipasto_ablate": ll.AntiPaSTOAblateConfig, - "antipasto_corda": ll.AntiPaSTOCorDAConfig, - "antipasto_asvd": ll.AntiPaSTOASVDConfig, - 
"antipasto_dplr": ll.AntiPaSTODPLRConfig, "road": ll.RoadConfig, } @@ -49,7 +44,7 @@ class BenchmarkConfig: """MetaMathQA -> GSM8K benchmark config. Tyro turns this into the CLI.""" model: str = "Qwen/Qwen3.5-0.8B-Base" - variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "antipasto_rot", "antipasto_ablate", "antipasto_corda", "antipasto_asvd", "antipasto_dplr", "road"] = "lora" + variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "road"] = "lora" mode: Literal["benchmark", "probe"] = "benchmark" device: str = "cuda" torch_dtype: str = "bfloat16" @@ -58,16 +53,8 @@ class BenchmarkConfig: alpha: float = 64.0 delora_lambda0: float = 0.1 road_group_size: int = 64 - # AntiPaSTO family (gain / corda) runtime knobs. - antipasto_coeff: float = 1.0 - antipasto_suppress_only: bool = False - # AntiPaSTO-ablate. - antipasto_ablate_k: int = 1 - antipasto_cov_orient: bool = False - # AntiPaSTO-rot (legacy rotation variant) basis to rotate. + # AntiPaSTO singular basis to rotate: V (default) / U / both / none (ablation axes). antipasto_rotate_basis: Literal["V", "U", "both", "none"] = "V" - # AntiPaSTO-dplr: rank of the low-rank mixing core in the frozen subspace. 
- antipasto_lora_rank: int = 8 target_name: list[str] = field(default_factory=lambda: list(DEFAULT_TARGETS)) layers: str = "all" train_dataset: str = "meta-math/MetaMathQA" @@ -140,16 +127,8 @@ def cfg_for_variant(args: BenchmarkConfig, dtype: torch.dtype) -> ll.AdapterConf extra = {"lambda0": args.delora_lambda0} if args.variant == "delora" else {} if args.variant == "road": extra = {"group_size": args.road_group_size} - if args.variant == "antipasto_rot": + if args.variant == "antipasto": extra = {"rotate_basis": args.antipasto_rotate_basis} - if args.variant in ("antipasto", "antipasto_corda", "antipasto_asvd"): - extra = {"coeff": args.antipasto_coeff, "suppress_only": args.antipasto_suppress_only} - if args.variant == "antipasto_ablate": - extra = {"coeff": args.antipasto_coeff, "k": args.antipasto_ablate_k, - "cov_orient": args.antipasto_cov_orient} - if args.variant == "antipasto_dplr": - extra = {"coeff": args.antipasto_coeff, "suppress_only": args.antipasto_suppress_only, - "lora_rank": args.antipasto_lora_rank} return CFG_BY_VARIANT[args.variant]( r=args.r, alpha=args.r if args.variant == "pissa" else args.alpha, @@ -579,18 +558,14 @@ def run(args: BenchmarkConfig) -> dict[str, Any]: dtype = getattr(torch, args.torch_dtype) run_commit = current_git_commit() run_id = f"{args.model.replace('/', '--')}__{args.variant}__s{args.steps}__seed{args.seed}" - # dplr capacity is set by lora_rank, not r, so keep rank-sweep runs from colliding. - if args.variant == "antipasto_dplr" and args.antipasto_lora_rank != 8: - run_id += f"__k{args.antipasto_lora_rank}" - # antipasto family defaults to r=256; low-rank sweeps get their own dirs. - if args.variant.startswith("antipasto") and args.r != 256: + # antipasto defaults to r=256; low-rank sweeps get their own dirs. + if args.variant == "antipasto" and args.r != 256: run_id += f"__r{args.r}" - # antipasto_rot defaults to rotating V; U/both are ablation axes -> own dirs. 
- if args.variant == "antipasto_rot" and args.antipasto_rotate_basis != "V": + # antipasto defaults to rotating V; U/both/none are ablation axes -> own dirs. + if args.variant == "antipasto" and args.antipasto_rotate_basis != "V": run_id += f"__rot{args.antipasto_rotate_basis}" - # antipasto family defaults to lr=5e-3; lr sweeps get their own dirs (the dense/ - # low-rank cores want a tamer lr than the gain, so this is a real axis). - if args.variant.startswith("antipasto") and abs(args.lr - 5e-3) > 1e-9: + # antipasto defaults to lr=5e-3; lr sweeps get their own dirs. + if args.variant == "antipasto" and abs(args.lr - 5e-3) > 1e-9: run_id += f"__lr{args.lr:g}" out_dir = args.output_dir / run_id out_dir.mkdir(parents=True, exist_ok=True) @@ -600,13 +575,9 @@ def run(args: BenchmarkConfig) -> dict[str, Any]: batches, skipped_train_prompt_too_long = make_train_batches(datasets["train"], tokenizer, args) print_first_train_sample(tokenizer, batches[0]) cfg = cfg_for_variant(args, dtype) - # Variants with a data-driven group_init need calibration activations from the - # downstream task (IPM mode, per CorDA). eva needs only a few batches for its init; - # corda/asvd/cov-orient estimate an input second moment, so we hand them many more - # batches (PEFT calibrates on a few hundred sequences) for a well-conditioned basis. - needs_calib = args.variant in ("eva", "antipasto_corda", "antipasto_asvd") or ( - args.variant == "antipasto_ablate" and args.antipasto_cov_orient - ) + # eva needs a few calibration batches for its data-driven init. antipasto runs + # without calibration (plain weight-SVD init), matching how it was benchmarked. 
+ needs_calib = args.variant == "eva" init_meter = group_init_meter() # wall-time + peak CPU RAM of group_init if needs_calib: n_batches = min(4, len(batches)) if args.variant == "eva" else min(64, len(batches)) diff --git a/src/lora_lite/__init__.py b/src/lora_lite/__init__.py index 562e3b3..588e394 100644 --- a/src/lora_lite/__init__.py +++ b/src/lora_lite/__init__.py @@ -20,11 +20,6 @@ from .variants.dora import DoRAConfig from .variants.hra import HRAConfig from .variants.eva import EVAConfig from .variants.antipasto import AntiPaSTOConfig -from .variants.antipasto_rot import AntiPaSTORotConfig -from .variants.antipasto_ablate import AntiPaSTOAblateConfig -from .variants.antipasto_corda import AntiPaSTOCorDAConfig -from .variants.antipasto_asvd import AntiPaSTOASVDConfig -from .variants.antipasto_dplr import AntiPaSTODPLRConfig from .variants.road import RoadConfig __all__ = [ @@ -38,11 +33,6 @@ __all__ = [ "HRAConfig", "EVAConfig", "AntiPaSTOConfig", - "AntiPaSTORotConfig", - "AntiPaSTOAblateConfig", - "AntiPaSTOCorDAConfig", - "AntiPaSTOASVDConfig", - "AntiPaSTODPLRConfig", "RoadConfig", "attach", "detach", diff --git a/src/lora_lite/variants/__init__.py b/src/lora_lite/variants/__init__.py index abb7416..8e3045a 100644 --- a/src/lora_lite/variants/__init__.py +++ b/src/lora_lite/variants/__init__.py @@ -1,4 +1,3 @@ from . import ( # noqa: F401 side-effect: register lora, pissa, delora, ia3, dora, hra, eva, antipasto, road, - antipasto_rot, antipasto_ablate, antipasto_corda, antipasto_asvd, antipasto_dplr, ) diff --git a/src/lora_lite/variants/antipasto.py b/src/lora_lite/variants/antipasto.py index 6172f3a..3c33f41 100644 --- a/src/lora_lite/variants/antipasto.py +++ b/src/lora_lite/variants/antipasto.py @@ -1,30 +1,32 @@ -"""AntiPaSTO: learnable bounded reweighting of frozen SVD singular values. 
+"""AntiPaSTO: SVD adapter that freezes W's own top-r basis and learns a per-direction +singular-value delta plus a block-diagonal Cayley rotation of that frozen basis. wassname 2026 https://arxiv.org/abs/2601.07473 - W = U diag(S) Vh + W_res # top-r SVD; W_res = W - U_r S_r Vh_r, frozen - learn: g (r,) # per-direction gain - S_eff = S * (1 + ELU(coeff * g)) # exp(z) for z<0 (bounded), 1+z for z>0 - y = x @ W_res.T + ((x @ Vh.T) * S_eff) @ U.T + W = U diag(S) Vh + W_res (top-r SVD; W_res = W - U_r S_r Vh_r) + learn: delta_s (r,), rot_T (n_blocks, bs(bs-1)/2) + R = block_diag(Cayley(skew(rot_T))); Vh_eff = R @ Vh (or U_eff = U @ R.T) + y = x @ W_res.T + ((x @ Vh_eff.T) * (S + delta_s)) @ U_eff.T - suppress_only: clamp g<=0 -> S_eff in (0, S], attenuation only. - coeff: runtime scale; 0 = identity, <0 swaps amplify/suppress. +The default rotates the input basis (V): on GSM8K/down_proj this beat rotating the +output basis (U) or both, and beat a rotation-free gain core -- rotating which inputs +feed each frozen direction is the cheapest knob that helps (57.2 at 15K params). U / +both / none are kept as `rotate_basis` ablation axes. -Identity at g=0 or coeff=0: 1+ELU(0)=1, so S_eff=S (up to the bf16 SVD round-trip). -The basis (U, Vh) is frozen, so the singular directions stay interpretable and only -the gain is learned. See forward() for why 1+ELU over linear/exp/tanh. +Identity at t=0: rot_T=0 -> R=I, delta_s~4e-4 -> y ~ x @ W^T (tiny positive bias on +delta_s breaks sign symmetry; rotation alone can't). 
Refs: - paper: https://github.com/wassname/AntiPaSTO - - sibling (whitened, mean-diff): steering-lite/.../sspace.py - - selection: Wanda (Sun+ 2023, arXiv:2306.11695), ASVD (Yuan+ 2023, arXiv:2312.05821) - - top-r SVD init: PiSSA (Meng+ 2024, arXiv:2404.02948) + - lite port of: https://github.com/wassname/antipasto3 + (offline: docs/refs/antipasto3_svd_adapter.py) """ +import math from dataclasses import dataclass from typing import Iterable, Literal import torch -from einops import rearrange +from einops import einsum, rearrange from jaxtyping import Float from torch import nn, Tensor as T @@ -39,17 +41,34 @@ CalibrationData = Iterable[CalibrationBatch] @dataclass class AntiPaSTOConfig(AdapterConfig): variant: str = "antipasto" - # Only r + r trainable scalars, so r can be large. + # Higher default than LoRA (r=8) since trainable params scale as r + r/bs*bs*(bs-1)/2, not r*(d_in+d_out). r: int = 256 - # Per-direction reweighting is S_eff = S * (1 + ELU(coeff * g)). See forward() - # for the why; identity at g=0 or coeff=0, positive always, no free bound knob. - suppress_only: bool = False # clamp g<=0 -> factor in (0,1]: attenuation only. - # Guarantee holds for coeff>=0; coeff<0 inverts the product and re-amplifies. - # Runtime steering scale. 0 = identity. <0 inverts (swaps amplify/suppress). - coeff: float = 1.0 - # group_init Wanda-style pooling of |X @ Vh[i]|: 'rms' is outlier-sensitive - # (ASVD intuition), 'mean_abs' is the original outlier-robust pooling. - act_pool: Literal["rms", "mean_abs"] = "rms" + # Block size for the block-diagonal Cayley rotation. r must be divisible by it. + block_size: int = 4 + # Cayley map saturation: bounds rotation angle to ~max_rotation_angle radians. + max_rotation_angle: float = 0.5 + # Which singular basis to rotate: 'V' (input, default), 'U' (output), 'both', or 'none'. 
+ rotate_basis: Literal["V", "U", "both", "none"] = "V" + + +def _cayley(skew: torch.Tensor) -> torch.Tensor: + """R = (I - X)^-1 (I + X) for X = skew/2; preserves orthogonality.""" + bs = skew.shape[-1] + eye = torch.eye(bs, dtype=skew.dtype, device=skew.device).expand_as(skew) + X = skew / 2 + return torch.linalg.solve(eye - X, eye + X) + + +def _build_rotation(rot_T: torch.Tensor, bs: int, max_angle: float) -> torch.Tensor: + """rot_T: (n_blocks, bs*(bs-1)/2) -> R: (n_blocks, bs, bs) Cayley rotation.""" + n_blocks, _ = rot_T.shape + rows, cols = torch.triu_indices(bs, bs, offset=1, device=rot_T.device).unbind(0) + A = torch.zeros(n_blocks, bs, bs, dtype=rot_T.dtype, device=rot_T.device) + A[:, rows, cols] = rot_T + A = 0.5 * (A - A.transpose(-1, -2)) + a_limit = 2.0 * math.tan(max_angle / 2.0) + A = a_limit * torch.tanh(A / a_limit) + return _cayley(A) @register @@ -59,14 +78,27 @@ class AntiPaSTO: @staticmethod def param_specs(d_in, d_out, cfg): r = cfg.r - return dict( - # Frozen top-r SVD captured at init. + bs = int(cfg.block_size) + if r % bs != 0: + raise ValueError(f"AntiPaSTO requires r={r} divisible by block_size={bs}") + specs = dict( + # Frozen SVD components captured at init. lora_U=ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True), lora_S=ParamSpec((r,), init="zeros", trainable=False, as_buffer=True), lora_Vh=ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True), - # Trainable per-direction log-scale. init 0 -> 1+ELU(0)=1 -> identity. - lora_g=ParamSpec((r,), init="zeros"), + # Trainable: per-singular-value delta. + # antipasto3 uses 4e-4 + N(0, 4e-4): small positive bias breaks sign + # symmetry (rotation alone can't); zero-init works but trains slower. 
+ lora_delta_s=ParamSpec((r,), init=lambda t: t.normal_(0, 4e-4).add_(4e-4)), ) + if cfg.rotate_basis != "none": + n_blocks = r // bs + n_triu = bs * (bs - 1) // 2 + specs["lora_rot_T"] = ParamSpec((n_blocks, n_triu), init="zeros") + if cfg.rotate_basis == "both": + # 'both' rotates V (lora_rot_T) and U independently; lora_rot_T_u is the U-side. + specs["lora_rot_T_u"] = ParamSpec((n_blocks, n_triu), init="zeros") + return specs @staticmethod def init(layer: nn.Module, cfg) -> None: @@ -85,19 +117,17 @@ class AntiPaSTO: layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype)) W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype) layer.weight.data.copy_(W_res) - # group_init() refines the dimension selection if calibration_data is given. + # group_init() refines this to input-aligned directions if calibration_data is given. @staticmethod def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None: - """Data-driven re-selection of which top-r singular directions to keep. + """Wanda-style data-driven dimension selection within the weight SVD. - init(): top-r by S alone (PiSSA-style) - group_init(): top-r by score[i] = S[i] * pool|X @ Vh[i]| (Wanda/ASVD) - pool = 'rms' (outlier-sensitive) | 'mean_abs' (outlier-robust) + init() picks the top-r singular dimensions by S alone (PiSSA-style). + group_init() re-selects based on S[i] * mean|X @ Vh[i]|: dimensions + that are both large in W AND active given real inputs. - This re-RANKS W's own singular vectors by activation; it does NOT re-orient - the basis (that is CorDA -> antipasto_corda.py). So the kept directions are - still plain weight-SVD directions, just a better subset. None -> keep init(). + If calibration_data is None the weight-SVD init from init() is kept. 
""" if calibration_data is None: return @@ -133,7 +163,6 @@ class AntiPaSTO: h.remove() r = cfg.r - pool = cfg.act_pool for name, layer in layers.items(): X = torch.cat(captured[name], dim=0) # (N, d_in) if X.shape[0] < r: @@ -141,23 +170,19 @@ class AntiPaSTO: f"AntiPaSTO at {name}: only {X.shape[0]} calibration tokens, need >= r={r}" ) - # Rebuild the FULL W: init() stored the exact top-r it subtracted, so - # W_res + U_r S_r Vh_r == W (full rank, not a cropped matrix). The SVD - # below therefore re-selects from W's whole spectrum, not a truncation. + # Recover W_orig: init() wrote W_res into layer.weight and stored top-r components W_res = layer.weight.data.float() - U_old = layer.lora_U.float() - S_old = layer.lora_S.float() - Vh_old = layer.lora_Vh.float() + U_old = layer.lora_U.float() # (d_out, r) + S_old = layer.lora_S.float() # (r,) + Vh_old = layer.lora_Vh.float() # (r, d_in) W_orig = W_res + (U_old * S_old.unsqueeze(0)) @ Vh_old + # Full SVD to score all dimensions U_full, S_full, Vh_full = torch.linalg.svd(W_orig, full_matrices=False) - proj = X.to(Vh_full) @ Vh_full.T # (N, r) input in S-coords (X CPU -> GPU here) - if pool == "rms": - act_mag = proj.pow(2).mean(dim=0).sqrt() # outlier-sensitive - else: - act_mag = proj.abs().mean(dim=0) # outlier-robust (original) + # score[i] = S[i] * mean|X @ Vh[i]| (Wanda: weight magnitude × activation magnitude) + act_mag = (X.to(Vh_full) @ Vh_full.T).abs().mean(dim=0) # (k,) -- X captured on CPU scores = S_full * act_mag - idx = scores.argsort(descending=True)[:r] # top-r by joint importance + idx = scores.argsort(descending=True)[:r] # top-r by joint importance idx = idx.sort().values # stable ordering Ur, Sr, Vhr = U_full[:, idx], S_full[idx], Vh_full[idx] @@ -176,23 +201,32 @@ class AntiPaSTO: y: Float[T, '*B o'], ) -> Float[T, '*B o']: cfg = layer._lora_cfg + bs = int(cfg.block_size) + max_angle = float(cfg.max_rotation_angle) + rotate_basis = cfg.rotate_basis + U = layer.lora_U.to(x.dtype) # (d_out, r) S = 
layer.lora_S.to(x.dtype) # (r,) Vh = layer.lora_Vh.to(x.dtype) # (r, d_in) - g = layer.lora_g.to(x.dtype) # (r,) - coeff = float(cfg.coeff) - if cfg.suppress_only: - g = torch.clamp(g, max=0.0) # factor in (0,1]: attenuation only + if rotate_basis == "none": + U_eff, Vh_eff = U, Vh + else: + R = _build_rotation(layer.lora_rot_T.float(), bs, max_angle).to(x.dtype) + n_blocks = R.shape[0] # R: (n, bs, bs) + U_eff, Vh_eff = U, Vh + # 'V'/'U' rotate that one basis with lora_rot_T; 'both' rotates V with + # lora_rot_T and U with a separate lora_rot_T_u (independent rotations). + if rotate_basis in ("V", "both"): + Vh_blocks = rearrange(Vh, "(n a) i -> n a i", n=n_blocks) + Vh_eff = rearrange(einsum(R, Vh_blocks, "n a b, n b i -> n a i"), "n a i -> (n a) i") + if rotate_basis in ("U", "both"): + R_u = _build_rotation(layer.lora_rot_T_u.float(), bs, max_angle).to(x.dtype) if rotate_basis == "both" else R + U_blocks = rearrange(U, "d (n b) -> d n b", n=n_blocks) + U_eff = rearrange(einsum(U_blocks, R_u, "d n b, n c b -> d n c"), "d n c -> d (n c)") - # S_eff = S * (1 + ELU(z)), z = coeff*g, 1+ELU(z) = exp(z) for z<=0 else 1+z. - # Why 1+ELU and not the obvious alternatives: - # linear S*(1+z) : z<-1 -> S_eff<0, a sign flip that drives incoherence. - # exp S*exp(z) : unbounded, gradient self-amplifies (amplification blows up). - # tanh bounded : arbitrary bound knob, saturation kills the gradient. - # 1+ELU uses each in its safe regime: exp where it is bounded in (0,1] - # (attenuation), linear where exp would diverge (amplification). >0 always. 
- S_eff = S * (1.0 + torch.nn.functional.elu(coeff * g)) - - h = (x @ Vh.T) * S_eff # input in S-coords, reweighted - return y + h @ U.T + S_eff = S + layer.lora_delta_s.to(x.dtype) # (r,) + h = x @ Vh_eff.T # x @ Vh_eff.T + h = h * S_eff # diag(S_eff) + delta = h @ U_eff.T # @ U_eff.T + return y + delta diff --git a/src/lora_lite/variants/antipasto_ablate.py b/src/lora_lite/variants/antipasto_ablate.py deleted file mode 100644 index c8b2c79..0000000 --- a/src/lora_lite/variants/antipasto_ablate.py +++ /dev/null @@ -1,190 +0,0 @@ -"""AntiPaSTO-Ablate: trainable directional ablation in the weight-SVD output basis. - -A contractive sibling of antipasto.py: instead of reweighting the singular gains it -projects out a learned direction in the output (U-side) singular basis. - - W = U diag(S) Vh + W_res - learn: c (r, k) ablation directions, alpha (k,) strengths in [0, 1] - Chat = orthonormal(c) # k unit dirs in S-space - h = (x @ Vh.T) * S # output S-coords = diag(S) Vh x - h <- h - coeff * (h @ Chat) * alpha @ Chat.T # project the span out - y = x @ W_res.T + h @ U.T - -The core (I - alpha Chat Chat^T) is a contraction: eigenvalues 1-alpha along Chat, -1 elsewhere, all in [0, 1]. It cannot amplify, so it cannot blow up -- the instability -the multiplicative gain bounds away is structurally absent (and a contraction is the -natural core to recurse). This is the trainable form of directional ablation (Arditi+ -2024): target residual writers (down_proj, o_proj) for the surgical regime, not all -Linears. - -Runtime: coeff is the per-call knob. coeff=0 -> identity; (0, 1] -> ablate; <0 adds the -direction back (the side that can grow, so bound coeff there). - -Refs: antipasto.py (gain sibling), directional ablation Arditi+ 2024 arXiv:2406.11717. 
-""" -from dataclasses import dataclass -from typing import Iterable - -import torch -from einops import rearrange -from jaxtyping import Float -from torch import nn, Tensor as T - -from ..variant import register, ParamSpec -from ..config import AdapterConfig, register_config - -CalibrationBatch = dict | tuple | list | T -CalibrationData = Iterable[CalibrationBatch] - -ε = 1e-6 - - -@register_config -@dataclass -class AntiPaSTOAblateConfig(AdapterConfig): - variant: str = "antipasto_ablate" - r: int = 256 # top-r SVD captured (or |dS|-selected via group_init) - k: int = 1 # number of ablation directions (rank of the projection) - init_alpha: float = 0.05 # small >0 so c gets gradient at step 0 - coeff: float = 1.0 # runtime: 0=identity, (0,1]=ablate, <0=amplify (bound this side) - # CorDA-orient the basis from input covariance (group_init, needs calibration_data). - # The ablation is OUTPUT-side and CorDA's U stays orthonormal, so this is a clean - # contraction; the win is at low r -- the data-oriented top-r captures the behavior - # output direction that plain-SVD top-r drops (measured 1.00 vs 0.65 at r=16). - cov_orient: bool = False - cov_eps: float = 1e-3 - - -@register -class AntiPaSTOAblate: - name = "antipasto_ablate" - - @staticmethod - def param_specs(d_in, d_out, cfg): - r, k = cfg.r, cfg.k - return dict( - lora_U=ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True), - lora_S=ParamSpec((r,), init="zeros", trainable=False, as_buffer=True), - lora_Vh=ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True), - # Trainable: k ablation directions in S-space, and their strengths. 
- lora_c=ParamSpec((r, k), init=lambda t: t.normal_(0, 1.0 / max(r, 1) ** 0.5)), - lora_alpha=ParamSpec((k,), init=lambda t: t.fill_(float(cfg.init_alpha))), - ) - - @staticmethod - def init(layer: nn.Module, cfg) -> None: - if type(layer) is not nn.Linear: - raise TypeError("AntiPaSTOAblate mutates layer.weight into W_res; nn.Linear only.") - with torch.no_grad(): - W = layer.weight.data.float() - U, S, Vh = torch.linalg.svd(W, full_matrices=False) - r = cfg.r - Ur, Sr, Vhr = U[:, :r], S[:r], Vh[:r, :] - layer.lora_U.copy_(Ur.to(layer.lora_U.dtype)) - layer.lora_S.copy_(Sr.to(layer.lora_S.dtype)) - layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype)) - W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype) - layer.weight.data.copy_(W_res) - # lora_c is random-init. Tried (job 94) seeding it from the top-k S-space - # output-VARIANCE PC: equal-or-worse on GSM8K (55.6/64.0 vs random 56.0/68.0, - # single seed) and +31s calib init -- the optimal ablation dir is loss-defined, - # not variance-defined, so a variance seed buys nothing on SFT. Reverted. - # FIXME the contrastive dS seed (mean(h|pos)-mean(h|neg), cf. sspace.py) is the - # one that should land on the behavior dir, but it needs pos/neg pairs this SFT - # benchmark lacks -- only worth it for steering with labelled contrastive data. - - @staticmethod - def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None: - """If cov_orient, re-orient each target's SVD by input covariance C=E[x x^T] - (CorDA) so the data-relevant output directions land in the top-r and the - behavior direction is fully ablatable at low r. No-op otherwise (keeps the - plain-SVD basis from init()). 
C is accumulated on CPU; for down_proj's large - d_in this is heavy -- exclude it or use plain ablation there.""" - if not getattr(cfg, "cov_orient", False) or calibration_data is None: - return - - layers = {name: layer for name, layer, _ in targets} - cov: dict[str, T] = {} - cnt: dict[str, int] = {n: 0 for n in layers} - - def make_hook(name): - def _h(module, args, kwargs): - x = rearrange(args[0].detach(), "... d -> (...) d").to(torch.float32).cpu() - g = x.T @ x - cov[name] = g if name not in cov else cov[name] + g - cnt[name] += x.shape[0] - return _h - - handles = [layers[n].register_forward_pre_hook(make_hook(n), with_kwargs=True) for n in layers] - try: - was_training = model.training - model.eval() - with torch.no_grad(): - for batch in calibration_data: - if isinstance(batch, dict): - model(**batch) - elif isinstance(batch, (list, tuple)): - model(*batch) - else: - model(batch) - if was_training: - model.train() - finally: - for h in handles: - h.remove() - - r, eps = cfg.r, float(cfg.cov_eps) - for name, layer in layers.items(): - if cnt[name] < r: - raise RuntimeError(f"AntiPaSTOAblate at {name}: {cnt[name]} tokens, need >= r={r}") - W_res = layer.weight.data.float().cpu() - U_old, S_old, Vh_old = (layer.lora_U.float().cpu(), - layer.lora_S.float().cpu(), - layer.lora_Vh.float().cpu()) - W_orig = W_res + (U_old * S_old) @ Vh_old - - C = cov[name] / cnt[name] - lam, Q = torch.linalg.eigh(C) - lam = lam.clamp_min(0) + eps - Chalf = (Q * lam.sqrt()) @ Q.T - Cinvhalf = (Q * lam.rsqrt()) @ Q.T - Ut, St, Vht = torch.linalg.svd(W_orig @ Chalf, full_matrices=False) - Ur = Ut[:, :r] # orthonormal output basis (ablation acts here) - Sr = St[:r] - Pr = Vht[:r] @ Cinvhalf # oblique input projector (input-side only) - W_res_new = W_orig - (Ur * Sr) @ Pr - - with torch.no_grad(): - layer.lora_U.copy_(Ur.to(layer.lora_U)) - layer.lora_S.copy_(Sr.to(layer.lora_S)) - layer.lora_Vh.copy_(Pr.to(layer.lora_Vh)) # store P in the Vh slot - 
layer.weight.data.copy_(W_res_new.to(layer.weight)) - - @staticmethod - def _orthonormal(c: T) -> T: - """(r, k) -> (r, k) with orthonormal columns. k=1 is a plain normalize.""" - if c.shape[-1] == 1: - return c / (c.norm(dim=0, keepdim=True) + ε) - # geqrf has no bf16/fp16 kernel (CPU or CUDA); do the QR in fp32, cast back. - q, _ = torch.linalg.qr(c.float()) # reduced QR; columns orthonormal - return q.to(c.dtype) - - @staticmethod - def forward( - layer: nn.Module, - x: Float[T, '*B i'], - y: Float[T, '*B o'], - ) -> Float[T, '*B o']: - cfg = layer._lora_cfg - U = layer.lora_U.to(x.dtype) # (d_out, r) - S = layer.lora_S.to(x.dtype) # (r,) - Vh = layer.lora_Vh.to(x.dtype) # (r, d_in) - Chat = AntiPaSTOAblate._orthonormal(layer.lora_c.to(x.dtype)) # (r, k) - alpha = layer.lora_alpha.to(x.dtype).clamp(0.0, 1.0) # (k,) - coeff = float(cfg.coeff) - - h = (x @ Vh.T) * S # (..., r) output S-coords - proj = h @ Chat # (..., k) component along each dir - # contractive removal: h <- h - coeff * Sum_j alpha_j (h . chat_j) chat_j - h = h - coeff * (proj * alpha) @ Chat.T # (..., r) - return y + h @ U.T diff --git a/src/lora_lite/variants/antipasto_asvd.py b/src/lora_lite/variants/antipasto_asvd.py deleted file mode 100644 index dea43cc..0000000 --- a/src/lora_lite/variants/antipasto_asvd.py +++ /dev/null @@ -1,43 +0,0 @@ -"""AntiPaSTO-ASVD: diagonal-covariance sibling of antipasto_corda. - -Same frozen-basis bounded gain, but orients the SVD by the DIAGONAL of the input -second moment (per-channel activation scale) instead of the full covariance: - - M = diag(E[x_i^2]) vs CorDA's full C = E[x x^T] - -This is Activation-aware SVD (Yuan+ 2023, arXiv:2312.05821): SVD(W diag(s)) with s a -per-channel scale. It is NOT a sub-basis of CorDA -- diag(C)^{1/2} and C^{1/2} are -different oblique rotations, so the top-r directions differ and either can win on a task. -ASVD is the cheap arm: O(d_in) moment, no d_in x d_in matrix, no eigh. 
The head-to-head -with antipasto_corda isolates whether the off-diagonal of C earns its init cost here. - -Reuses antipasto_corda's buffers (U, S, P, g), plain-SVD init, gain forward, and the -shared `_covariance_orient` (only the diag flag differs), so there is one copy of the -math to keep in sync. - -Refs: antipasto_corda.py (full-covariance sibling), ASVD arXiv:2312.05821. -""" -from dataclasses import dataclass - -from ..variant import register -from ..config import register_config -from .antipasto_corda import AntiPaSTOCorDA, AntiPaSTOCorDAConfig, _covariance_orient - - -@register_config -@dataclass -class AntiPaSTOASVDConfig(AntiPaSTOCorDAConfig): - variant: str = "antipasto_asvd" - - -@register -class AntiPaSTOASVD: - name = "antipasto_asvd" - param_specs = staticmethod(AntiPaSTOCorDA.param_specs) - init = staticmethod(AntiPaSTOCorDA.init) - forward = staticmethod(AntiPaSTOCorDA.forward) - - @staticmethod - def group_init(model, targets, cfg, calibration_data) -> None: - """ASVD: re-orient by the diagonal of the input second moment (per-channel).""" - _covariance_orient(model, targets, cfg, calibration_data, diag=True) diff --git a/src/lora_lite/variants/antipasto_corda.py b/src/lora_lite/variants/antipasto_corda.py deleted file mode 100644 index adc8fb3..0000000 --- a/src/lora_lite/variants/antipasto_corda.py +++ /dev/null @@ -1,203 +0,0 @@ -"""AntiPaSTO-CorDA: reweight in a covariance-oriented basis, not the weight basis. - -Plain SVD sorts directions by weight gain ||W v|| on isotropic input. The behaviour -you steer lives where the DATA has energy, off the top weight-singular axes. CorDA -(Yang+ 2024, arXiv:2406.05223) re-orients the SVD by the input covariance, so the top-r -directions move the output most on real activations. 
- - C = E[x x^T] (+ eps I) # input second moment on calibration data - C^{1/2}, C^{-1/2} via eigh(C) - U S Vht = SVD(W C^{1/2}) # top-r is Eckart-Young best under x ~ N(0,C) - P = Vht C^{-1/2} # (r, d_in) oblique input projector - W = W_res + U_r diag(S_r) P_r # exact (residual carries the dropped tail) - S_eff = S * (1 + ELU(coeff*g)) # same bounded gain as antipasto - y = x @ W_res.T + ((x @ P.T) * S_eff) @ U.T - -Identity at g=0 or coeff=0: S_eff=S. P is oblique (rows not orthonormal -- C^{-1/2} -skews them); fine for gain reweighting since U stays orthonormal. Requires -calibration_data (group_init raises otherwise). - -Refs: antipasto.py (gain + selection sibling), CorDA arXiv:2406.05223. -""" -from dataclasses import dataclass -from typing import Iterable - -import torch -import torch.nn.functional as F -from einops import rearrange -from jaxtyping import Float -from torch import nn, Tensor as T - -from ..variant import register, ParamSpec -from ..config import AdapterConfig, register_config - -CalibrationBatch = dict | tuple | list | T -CalibrationData = Iterable[CalibrationBatch] - - -@register_config -@dataclass -class AntiPaSTOCorDAConfig(AdapterConfig): - variant: str = "antipasto_corda" - r: int = 256 - cov_eps: float = 1e-3 # damping on C eigenvalues; guards C^{-1/2} on rare dirs - coeff: float = 1.0 # runtime steer knob: 0=identity, scales trained g - suppress_only: bool = False # clamp g<=0 (attenuate only) -- for coeff>=0; - # coeff<0 inverts the product (coeff*g>=0) and re-amplifies. - - -def _gain(S: T, g: T, coeff: float, suppress_only: bool) -> T: - """S_eff = S * (1 + ELU(coeff*g)); exp-bounded attenuation, linear amplification.""" - if suppress_only: - g = g.clamp(max=0.0) - return S * (1.0 + F.elu(coeff * g)) - - -def _covariance_orient(model, targets, cfg, calibration_data, *, diag: bool) -> None: - """Re-orient each target's SVD by its input second moment, then rewrite the frozen - buffers (U, S, P) and residual weight in that basis. 
Shared by CorDA and ASVD: - - diag=False -> CorDA: full C = E[x x^T] (cross-channel covariance, via eigh) - diag=True -> ASVD: diag(C) = E[x_i^2] only (per-channel scale, O(d_in), no eigh) - - The off-diagonal of C is the sole difference. g=0 stays exact identity either way -- - the reconstruction (W_res + U_r S_r P_r = W_orig) is lossless. Accumulated on CPU: a - full C is d_in^2 fp32 per target and would crowd the GPU; the diagonal is a d_in vector. - Call at attach-time, before training touches g (re-orienting g=0 is a no-op). - """ - if calibration_data is None: - raise ValueError("covariance orientation requires calibration_data; got None.") - - layers = {name: layer for name, layer, _ in targets} - moment: dict[str, T] = {} # (d_in,d_in) full, or (d_in,) diagonal - cnt: dict[str, int] = {n: 0 for n in layers} - keep: dict[str, T] = {} # non-pad mask of the in-flight batch - - def make_hook(name): - def _h(module, args, kwargs): - x = rearrange(args[0].detach(), "... d -> (...) d").to(torch.float32).cpu() - if "mask" in keep: - x = x[keep["mask"]] # drop padding positions (see loop below) - m = x.pow(2).sum(0) if diag else x.T @ x - moment[name] = m if name not in moment else moment[name] + m - cnt[name] += x.shape[0] # real (non-pad) tokens accumulated - return _h - - handles = [layers[n].register_forward_pre_hook(make_hook(n), with_kwargs=True) for n in layers] - try: - was_training = model.training - model.eval() - with torch.no_grad(): - for batch in calibration_data: - # Padding activations are not task-representative; mask them out of the moment - # so the oriented basis reflects real tokens (CorDA/SVD-LLM official code does - # the same). The mask is per-token, shared across all target layers in a batch. - keep.pop("mask", None) - if isinstance(batch, dict): - if "attention_mask" in batch: - keep["mask"] = rearrange(batch["attention_mask"], "... 
-> (...)").bool().cpu() - model(**batch) - elif isinstance(batch, (list, tuple)): - model(*batch) - else: - model(batch) - if was_training: - model.train() - finally: - for h in handles: - h.remove() - - r, eps = cfg.r, float(cfg.cov_eps) - for name, layer in layers.items(): - if cnt[name] < r: - raise RuntimeError(f"covariance orient at {name}: {cnt[name]} tokens, need >= r={r}") - # decomposition on CPU (where the moment lives); results copied back to device buffers. - W_res = layer.weight.data.float().cpu() - U_old, S_old, P_old = (layer.lora_U.float().cpu(), - layer.lora_S.float().cpu(), - layer.lora_P.float().cpu()) - W_orig = W_res + (U_old * S_old) @ P_old - - if diag: - c = (moment[name] / cnt[name]).clamp_min(0) + eps # (d_in,) per-channel scale - Ut, St, Vht = torch.linalg.svd(W_orig * c.sqrt(), full_matrices=False) # @ diag(c^1/2) - Pr = Vht[:r] * c.rsqrt() # @ diag(c^-1/2): oblique projector - else: - C = moment[name] / cnt[name] # (d_in,d_in) - lam, Q = torch.linalg.eigh(C) - lam = lam.clamp_min(0) + eps - Mhalf = (Q * lam.sqrt()) @ Q.T # C^{1/2} - Minvhalf = (Q * lam.rsqrt()) @ Q.T # C^{-1/2} - Ut, St, Vht = torch.linalg.svd(W_orig @ Mhalf, full_matrices=False) - Pr = Vht[:r] @ Minvhalf # (r, d_in) oblique projector - # Quantize the frozen buffers to their stored dtype FIRST, then form the residual - # against those exact (bf16) values. The forward reconstructs from the bf16 buffers, - # so W_res + U_r S_r P_r = W_orig to one residual-rounding -- without this, the - # residual is built from fp32 U/S/P and the forward also eats the U/S/P quantization - # mismatch, so g=0 drifts further from identity. 
- Ur = Ut[:, :r].to(layer.lora_U.dtype) - Sr = St[:r].to(layer.lora_S.dtype) - Pr = Pr.to(layer.lora_P.dtype) - W_res_new = W_orig - (Ur.float() * Sr.float()) @ Pr.float() - - with torch.no_grad(): - layer.lora_U.copy_(Ur) - layer.lora_S.copy_(Sr) - layer.lora_P.copy_(Pr) - layer.weight.data.copy_(W_res_new.to(layer.weight)) - - -@register -class AntiPaSTOCorDA: - name = "antipasto_corda" - - @staticmethod - def param_specs(d_in, d_out, cfg): - r = cfg.r - return dict( - lora_U=ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True), - lora_S=ParamSpec((r,), init="zeros", trainable=False, as_buffer=True), - # P replaces Vh: oblique covariance-oriented input projector. - lora_P=ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True), - # Trainable per-direction log-scale. init 0 -> 1+ELU(0)=1 -> exact identity. - # No sign-symmetry hack needed (1+ELU is sign-preserving, basis frozen), - # matching antipasto.py. - lora_g=ParamSpec((r,), init="zeros"), - ) - - @staticmethod - def init(layer: nn.Module, cfg) -> None: - """Plain-SVD fallback so the adapter is valid before group_init. 
group_init - refines P/U/S to the covariance-oriented basis when calibration_data is given.""" - if type(layer) is not nn.Linear: - raise TypeError("AntiPaSTOCorDA mutates layer.weight into W_res; nn.Linear only.") - with torch.no_grad(): - W = layer.weight.data.float() - U, S, Vh = torch.linalg.svd(W, full_matrices=False) - r = cfg.r - Ur, Sr, Vhr = U[:, :r], S[:r], Vh[:r, :] - layer.lora_U.copy_(Ur.to(layer.lora_U.dtype)) - layer.lora_S.copy_(Sr.to(layer.lora_S.dtype)) - layer.lora_P.copy_(Vhr.to(layer.lora_P.dtype)) # P := Vh until oriented - W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype) - layer.weight.data.copy_(W_res) - - @staticmethod - def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None: - """CorDA: re-orient by the full input covariance C = E[x x^T] (cross-channel).""" - _covariance_orient(model, targets, cfg, calibration_data, diag=False) - - @staticmethod - def forward( - layer: nn.Module, - x: Float[T, '*B i'], - y: Float[T, '*B o'], - ) -> Float[T, '*B o']: - cfg = layer._lora_cfg - U = layer.lora_U.to(x.dtype) # (d_out, r) - S = layer.lora_S.to(x.dtype) # (r,) - P = layer.lora_P.to(x.dtype) # (r, d_in) oblique - g = layer.lora_g.to(x.dtype) # (r,) - S_eff = _gain(S, g, float(cfg.coeff), bool(cfg.suppress_only)) - h = (x @ P.T) * S_eff # (..., r) - return y + h @ U.T diff --git a/src/lora_lite/variants/antipasto_dplr.py b/src/lora_lite/variants/antipasto_dplr.py deleted file mode 100644 index db15a02..0000000 --- a/src/lora_lite/variants/antipasto_dplr.py +++ /dev/null @@ -1,166 +0,0 @@ -"""AntiPaSTO-DPLR: diagonal gain plus a low-rank mixing core in the frozen SVD basis. - -antipasto's diagonal gain rescales each singular direction but cannot mix one into -another. 
DPLR adds a trainable rank-k core that does, inside the frozen U/Vh basis: - - W = U diag(S) Vh + W_res # frozen top-r SVD - learn: g (r,) # diagonal gain - A (k,r) kaiming, B (r,k) zero # low-rank mixing core - p = x @ Vh.T # (r,) input in the frozen S-basis - S_eff = S * (1 + ELU(coeff * g)) - h = p * S_eff + coeff * (p @ A.T) @ B.T # diagonal gain + rank-k mixing - y = x @ W_res.T + h @ U.T - -The rank-k term is LoRA's core (Hu+ 2021, arXiv:2106.09685) restricted to W's top-r -subspace, ADDED to the gain rather than folded into diag(S): being independent of S, a -unit step moves W by O(1) not O(S), so it has no singular-value amplification. Params -= r + 2*r*k. Identity at init (B=0, g=0) and at coeff=0. Basis (U, Vh) stays frozen. - -Refs: antipasto.py (diagonal sibling), lora.py (low-rank core), antipasto_corda.py -(oriented basis -- composes with this core). -""" -from dataclasses import dataclass -from typing import Iterable, Literal - -import torch -import torch.nn.functional as F -from einops import rearrange -from jaxtyping import Float -from torch import nn, Tensor as T - -from ..variant import register, ParamSpec -from ..config import AdapterConfig, register_config - -CalibrationBatch = dict | tuple | list | T -CalibrationData = Iterable[CalibrationBatch] - - -@register_config -@dataclass -class AntiPaSTODPLRConfig(AdapterConfig): - variant: str = "antipasto_dplr" - r: int = 256 - # Rank of the low-rank mixing core (LoRA's r, but inside the frozen subspace). - # Params = r (gain) + 2*r*lora_rank. Requires 1 <= lora_rank <= r. - lora_rank: int = 8 - suppress_only: bool = False # clamp the gain g<=0 (attenuate only); core unaffected. - coeff: float = 1.0 # runtime knob: 0=identity, scales gain and core. - act_pool: Literal["rms", "mean_abs"] = "rms" # group_init selection, see antipasto. 
- - -@register -class AntiPaSTODPLR: - name = "antipasto_dplr" - - @staticmethod - def param_specs(d_in, d_out, cfg): - r, k = cfg.r, cfg.lora_rank - if not 0 < k <= r: - raise ValueError(f"antipasto_dplr needs 0 < lora_rank({k}) <= r({r}).") - return dict( - lora_U=ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True), - lora_S=ParamSpec((r,), init="zeros", trainable=False, as_buffer=True), - lora_Vh=ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True), - # Diagonal gain (== antipasto). init 0 -> 1+ELU(0)=1 -> identity. - lora_g=ParamSpec((r,), init="zeros"), - # Low-rank core B@A in the frozen subspace. A down (r->k), B up (k->r). - # B=0 at init -> core=0 -> identity (LoRA convention). - lora_A=ParamSpec((k, r), init="kaiming"), - lora_B=ParamSpec((r, k), init="zeros"), - ) - - @staticmethod - def init(layer: nn.Module, cfg) -> None: - if type(layer) is not nn.Linear: - raise TypeError("AntiPaSTODPLR mutates layer.weight into W_res; nn.Linear only.") - with torch.no_grad(): - W = layer.weight.data.float() - U, S, Vh = torch.linalg.svd(W, full_matrices=False) - r = cfg.r - Ur, Sr, Vhr = U[:, :r], S[:r], Vh[:r, :] - layer.lora_U.copy_(Ur.to(layer.lora_U.dtype)) - layer.lora_S.copy_(Sr.to(layer.lora_S.dtype)) - layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype)) - W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype) - layer.weight.data.copy_(W_res) - - @staticmethod - def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None: - """Wanda-style re-selection of the top-r directions, identical to antipasto. 
- Runs before training while g and B are still zero, so the core contributes - nothing and re-selecting the basis is a no-op on the adapter output.""" - if calibration_data is None: - return - - layers = {name: layer for name, layer, _ in targets} - captured: dict[str, list[T]] = {n: [] for n in layers} - - def make_hook(name): - def _h(module, args, kwargs): - x = args[0].detach() - captured[name].append(rearrange(x, "... d -> (...) d").to(torch.float32).cpu()) - return _h - - handles = [layers[n].register_forward_pre_hook(make_hook(n), with_kwargs=True) for n in layers] - try: - was_training = model.training - model.eval() - with torch.no_grad(): - for batch in calibration_data: - if isinstance(batch, dict): - model(**batch) - elif isinstance(batch, (list, tuple)): - model(*batch) - else: - model(batch) - if was_training: - model.train() - finally: - for h in handles: - h.remove() - - r, pool = cfg.r, cfg.act_pool - for name, layer in layers.items(): - X = torch.cat(captured[name], dim=0) - if X.shape[0] < r: - raise RuntimeError(f"AntiPaSTODPLR at {name}: {X.shape[0]} tokens, need >= r={r}") - # Rebuild the FULL W exactly (W_res + stored top-r), then re-select top-r. 
- W_res = layer.weight.data.float() - W_orig = W_res + (layer.lora_U.float() * layer.lora_S.float()) @ layer.lora_Vh.float() - U_full, S_full, Vh_full = torch.linalg.svd(W_orig, full_matrices=False) - proj = X.to(Vh_full) @ Vh_full.T - act_mag = proj.pow(2).mean(0).sqrt() if pool == "rms" else proj.abs().mean(0) - idx = (S_full * act_mag).argsort(descending=True)[:r].sort().values - Ur, Sr, Vhr = U_full[:, idx], S_full[idx], Vh_full[idx] - W_res_new = (W_orig - (Ur * Sr) @ Vhr).to(layer.weight.dtype) - with torch.no_grad(): - layer.lora_U.copy_(Ur.to(layer.lora_U)) - layer.lora_S.copy_(Sr.to(layer.lora_S)) - layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh)) - layer.weight.data.copy_(W_res_new) - - @staticmethod - def forward( - layer: nn.Module, - x: Float[T, '*B i'], - y: Float[T, '*B o'], - ) -> Float[T, '*B o']: - cfg = layer._lora_cfg - U = layer.lora_U.to(x.dtype) # (d_out, r) - S = layer.lora_S.to(x.dtype) # (r,) - Vh = layer.lora_Vh.to(x.dtype) # (r, d_in) - g = layer.lora_g.to(x.dtype) # (r,) - A = layer.lora_A.to(x.dtype) # (k, r) - B = layer.lora_B.to(x.dtype) # (r, k) - coeff = float(cfg.coeff) - - if cfg.suppress_only: - g = torch.clamp(g, max=0.0) - - p = x @ Vh.T # (..., r) = Vh x (unscaled) - S_eff = S * (1.0 + F.elu(coeff * g)) # diagonal gain (see antipasto.py) - # Diagonal part scales each direction; low-rank part B@A mixes across the - # subspace. Additive (not * diag(S)), so the core is S-independent: a unit - # step in B@A moves W by O(1), not O(S) -- no S-amplification edge. - h = p * S_eff + coeff * (p @ A.T) @ B.T # (..., r) - return y + h @ U.T diff --git a/src/lora_lite/variants/antipasto_rot.py b/src/lora_lite/variants/antipasto_rot.py deleted file mode 100644 index 159a5c3..0000000 --- a/src/lora_lite/variants/antipasto_rot.py +++ /dev/null @@ -1,227 +0,0 @@ -"""AntiPaSTO-Rot: SVD adapter with learnable singular-value deltas + a block-diagonal -Cayley rotation of the frozen basis. The rotation arm vs antipasto.py's gain-only core. 
- -wassname 2026 https://arxiv.org/abs/2601.07473 - - W = U diag(S) Vh + W_res (top-r SVD; W_res = W - U_r S_r Vh_r) - learn: delta_s (r,), rot_T (n_blocks, bs(bs-1)/2) - R = block_diag(Cayley(skew(rot_T))); Vh_eff = R @ Vh (or U_eff = U @ R.T) - y = x @ W_res.T + ((x @ Vh_eff.T) * (S + delta_s)) @ U_eff.T - -Identity at t=0: rot_T=0 -> R=I, delta_s~4e-4 -> y ~ x @ W^T (tiny positive bias on -delta_s breaks sign symmetry; rotation alone can't). - -Refs: - - paper: https://github.com/wassname/AntiPaSTO - - lite port of: https://github.com/wassname/antipasto3 - (offline: docs/refs/antipasto3_svd_adapter.py) -""" -import math -from dataclasses import dataclass -from typing import Iterable, Literal - -import torch -from einops import einsum, rearrange -from jaxtyping import Float -from torch import nn, Tensor as T - -from ..variant import register, ParamSpec -from ..config import AdapterConfig, register_config - -CalibrationBatch = dict | tuple | list | T -CalibrationData = Iterable[CalibrationBatch] - - -@register_config -@dataclass -class AntiPaSTORotConfig(AdapterConfig): - variant: str = "antipasto_rot" - # Higher default than LoRA (r=8) since trainable params scale as r + r/bs*bs*(bs-1)/2, not r*(d_in+d_out). - r: int = 256 - # Block size for the block-diagonal Cayley rotation. r must be divisible by it. - block_size: int = 4 - # Cayley map saturation: bounds rotation angle to ~max_rotation_angle radians. - max_rotation_angle: float = 0.5 - # Which singular basis to rotate: 'V' (input), 'U' (output), 'both', or 'none'. 
- rotate_basis: Literal["V", "U", "both", "none"] = "V" - - -def _cayley(skew: torch.Tensor) -> torch.Tensor: - """R = (I - X)^-1 (I + X) for X = skew/2; preserves orthogonality.""" - bs = skew.shape[-1] - eye = torch.eye(bs, dtype=skew.dtype, device=skew.device).expand_as(skew) - X = skew / 2 - return torch.linalg.solve(eye - X, eye + X) - - -def _build_rotation(rot_T: torch.Tensor, bs: int, max_angle: float) -> torch.Tensor: - """rot_T: (n_blocks, bs*(bs-1)/2) -> R: (n_blocks, bs, bs) Cayley rotation.""" - n_blocks, _ = rot_T.shape - rows, cols = torch.triu_indices(bs, bs, offset=1, device=rot_T.device).unbind(0) - A = torch.zeros(n_blocks, bs, bs, dtype=rot_T.dtype, device=rot_T.device) - A[:, rows, cols] = rot_T - A = 0.5 * (A - A.transpose(-1, -2)) - a_limit = 2.0 * math.tan(max_angle / 2.0) - A = a_limit * torch.tanh(A / a_limit) - return _cayley(A) - - -@register -class AntiPaSTORot: - name = "antipasto_rot" - - @staticmethod - def param_specs(d_in, d_out, cfg): - r = cfg.r - bs = int(cfg.block_size) - if r % bs != 0: - raise ValueError(f"AntiPaSTORot requires r={r} divisible by block_size={bs}") - specs = dict( - # Frozen SVD components captured at init. - lora_U=ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True), - lora_S=ParamSpec((r,), init="zeros", trainable=False, as_buffer=True), - lora_Vh=ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True), - # Trainable: per-singular-value delta. - # antipasto3 uses 4e-4 + N(0, 4e-4): small positive bias breaks sign - # symmetry (rotation alone can't); zero-init works but trains slower. - lora_delta_s=ParamSpec((r,), init=lambda t: t.normal_(0, 4e-4).add_(4e-4)), - ) - if cfg.rotate_basis != "none": - n_blocks = r // bs - n_triu = bs * (bs - 1) // 2 - specs["lora_rot_T"] = ParamSpec((n_blocks, n_triu), init="zeros") - if cfg.rotate_basis == "both": - # 'both' rotates V (lora_rot_T) and U independently; lora_rot_T_u is the U-side. 
- specs["lora_rot_T_u"] = ParamSpec((n_blocks, n_triu), init="zeros") - return specs - - @staticmethod - def init(layer: nn.Module, cfg) -> None: - if type(layer) is not nn.Linear: - raise TypeError( - "AntiPaSTORot mutates layer.weight into W_res (like PiSSA), so v1 " - "only supports plain nn.Linear, not bnb 4/8-bit." - ) - with torch.no_grad(): - W = layer.weight.data.float() - U, S, Vh = torch.linalg.svd(W, full_matrices=False) - r = cfg.r - Ur, Sr, Vhr = U[:, :r], S[:r], Vh[:r, :] - layer.lora_U.copy_(Ur.to(layer.lora_U.dtype)) - layer.lora_S.copy_(Sr.to(layer.lora_S.dtype)) - layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype)) - W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype) - layer.weight.data.copy_(W_res) - # group_init() refines this to input-aligned directions if calibration_data is given. - - @staticmethod - def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None: - """Wanda-style data-driven dimension selection within the weight SVD. - - init() picks the top-r singular dimensions by S alone (PiSSA-style). - group_init() re-selects based on S[i] * mean|X @ Vh[i]|: dimensions - that are both large in W AND active given real inputs. - - If calibration_data is None the weight-SVD init from init() is kept. - """ - if calibration_data is None: - return - - layers = {name: layer for name, layer, _ in targets} - captured: dict[str, list[T]] = {n: [] for n in layers} - - def make_hook(name): - def _h(module, args, kwargs): - x = args[0].detach() - captured[name].append(rearrange(x, "... d -> (...) 
d").to(torch.float32).cpu()) - return _h - - handles = [ - layers[n].register_forward_pre_hook(make_hook(n), with_kwargs=True) - for n in layers - ] - try: - was_training = model.training - model.eval() - with torch.no_grad(): - for batch in calibration_data: - if isinstance(batch, dict): - model(**batch) - elif isinstance(batch, (list, tuple)): - model(*batch) - else: - model(batch) - if was_training: - model.train() - finally: - for h in handles: - h.remove() - - r = cfg.r - for name, layer in layers.items(): - X = torch.cat(captured[name], dim=0) # (N, d_in) - if X.shape[0] < r: - raise RuntimeError( - f"AntiPaSTORot at {name}: only {X.shape[0]} calibration tokens, need >= r={r}" - ) - - # Recover W_orig: init() wrote W_res into layer.weight and stored top-r components - W_res = layer.weight.data.float() - U_old = layer.lora_U.float() # (d_out, r) - S_old = layer.lora_S.float() # (r,) - Vh_old = layer.lora_Vh.float() # (r, d_in) - W_orig = W_res + (U_old * S_old.unsqueeze(0)) @ Vh_old - - # Full SVD to score all dimensions - U_full, S_full, Vh_full = torch.linalg.svd(W_orig, full_matrices=False) - # score[i] = S[i] * mean|X @ Vh[i]| (Wanda: weight magnitude × activation magnitude) - act_mag = (X.to(Vh_full) @ Vh_full.T).abs().mean(dim=0) # (k,) -- X captured on CPU - scores = S_full * act_mag - idx = scores.argsort(descending=True)[:r] # top-r by joint importance - idx = idx.sort().values # stable ordering - - Ur, Sr, Vhr = U_full[:, idx], S_full[idx], Vh_full[idx] - W_res_new = (W_orig - (Ur * Sr.unsqueeze(0)) @ Vhr).to(layer.weight.dtype) - - with torch.no_grad(): - layer.lora_U.copy_(Ur.to(layer.lora_U)) - layer.lora_S.copy_(Sr.to(layer.lora_S)) - layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh)) - layer.weight.data.copy_(W_res_new) - - @staticmethod - def forward( - layer: nn.Module, - x: Float[T, '*B i'], - y: Float[T, '*B o'], - ) -> Float[T, '*B o']: - cfg = layer._lora_cfg - bs = int(cfg.block_size) - max_angle = float(cfg.max_rotation_angle) - rotate_basis = 
cfg.rotate_basis - - U = layer.lora_U.to(x.dtype) # (d_out, r) - S = layer.lora_S.to(x.dtype) # (r,) - Vh = layer.lora_Vh.to(x.dtype) # (r, d_in) - - if rotate_basis == "none": - U_eff, Vh_eff = U, Vh - else: - R = _build_rotation(layer.lora_rot_T.float(), bs, max_angle).to(x.dtype) - n_blocks = R.shape[0] # R: (n, bs, bs) - U_eff, Vh_eff = U, Vh - # 'V'/'U' rotate that one basis with lora_rot_T; 'both' rotates V with - # lora_rot_T and U with a separate lora_rot_T_u (independent rotations). - if rotate_basis in ("V", "both"): - Vh_blocks = rearrange(Vh, "(n a) i -> n a i", n=n_blocks) - Vh_eff = rearrange(einsum(R, Vh_blocks, "n a b, n b i -> n a i"), "n a i -> (n a) i") - if rotate_basis in ("U", "both"): - R_u = _build_rotation(layer.lora_rot_T_u.float(), bs, max_angle).to(x.dtype) if rotate_basis == "both" else R - U_blocks = rearrange(U, "d (n b) -> d n b", n=n_blocks) - U_eff = rearrange(einsum(U_blocks, R_u, "d n b, n c b -> d n c"), "d n c -> d (n c)") - - S_eff = S + layer.lora_delta_s.to(x.dtype) # (r,) - h = x @ Vh_eff.T # x @ Vh_eff.T - h = h * S_eff # diag(S_eff) - delta = h @ U_eff.T # @ U_eff.T - return y + delta diff --git a/src/lora_lite/variants/eva.py b/src/lora_lite/variants/eva.py index 51a4dba..f5bbcaa 100644 --- a/src/lora_lite/variants/eva.py +++ b/src/lora_lite/variants/eva.py @@ -84,7 +84,7 @@ class EVA: with torch.no_grad(): for batch in calibration_data: # Padding activations are not task-representative; mask them out of the - # PCA so the basis reflects real tokens (matches antipasto_corda). + # PCA so the basis reflects real tokens. 
keep.pop("mask", None) if isinstance(batch, dict): if "attention_mask" in batch: diff --git a/tests/test_metamath_smoke.py b/tests/test_metamath_smoke.py index 43874dd..178440e 100644 --- a/tests/test_metamath_smoke.py +++ b/tests/test_metamath_smoke.py @@ -33,13 +33,11 @@ SPEC.loader.exec_module(benchmark) VARIANTS = ["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", - "antipasto", "antipasto_rot", "antipasto_ablate", "antipasto_corda", - "antipasto_asvd", "antipasto_dplr", "road"] + "antipasto", "road"] # Variants that fail loud when attached on a bnb-loaded base (read dense weight in init). # delora/eva also read weight but currently silently dequant -- they produce sane attach, # so we don't expect a raise from them in the attach-only smoke. -BNB_RAISERS = {"pissa", "dora", "antipasto", "antipasto_rot", "antipasto_ablate", - "antipasto_corda", "antipasto_dplr"} +BNB_RAISERS = {"pissa", "dora", "antipasto"} TINY_MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM" HAS_CUDA = torch.cuda.is_available() @@ -61,7 +59,6 @@ def quick_cfg(variant: str, tmp_path: Path, quantization: str = "none") -> "benc quantization=quantization, r=4, alpha=8, - antipasto_lora_rank=2, # antipasto_dplr needs 0 < lora_rank <= r (r=4 here) target_name=target_name, layers="all", steps=2,
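Reviewer note on the rotation arm: the forward this diff leaves in `antipasto.py` multiplies by `Vh_eff` and `U_eff`, which the removed `antipasto_rot.py` built with `_build_rotation` and `_cayley`. The sketch below is not the repository code. It is a standard-library closed form for a single 2x2 block, assuming the same three steps (skew-symmetrize, tanh saturation, Cayley map), and it shows why `max_rotation_angle` bounds the rotation. The names `cayley_2x2` and `rotation_angle` are hypothetical, and `block_size=2` is chosen only so the algebra stays closed-form (the config default is 4).

```python
import math


def cayley_2x2(t: float, max_angle: float = 0.5) -> list[list[float]]:
    """Closed-form Cayley rotation for one 2x2 block (illustrative sketch).

    Follows the same steps as the removed _build_rotation/_cayley, specialised
    to bs=2 where the skew matrix has a single free entry.
    """
    # A = 0.5 * (A - A^T): the upper-triangular entry t becomes t / 2.
    a = 0.5 * t
    # Saturation: |a| is squashed below a_limit = 2 * tan(max_angle / 2).
    a_limit = 2.0 * math.tan(max_angle / 2.0)
    a = a_limit * math.tanh(a / a_limit)
    # Cayley map R = (I - X)^-1 (I + X) with X = A / 2 = [[0, x], [-x, 0]].
    x = a / 2.0
    det = 1.0 + x * x              # det(I - X)
    c = (1.0 - x * x) / det        # equals cos(theta), where tan(theta / 2) = x
    s = 2.0 * x / det              # equals sin(theta)
    return [[c, s], [-s, c]]


def rotation_angle(R: list[list[float]]) -> float:
    """Signed rotation angle (radians) of a 2x2 rotation [[c, s], [-s, c]]."""
    return math.atan2(R[0][1], R[0][0])
```

Since `|x| < tan(max_angle / 2)` after the tanh step and the angle is `2 * atan(x)`, the rotation stays within `max_angle` however large the raw parameter grows, and `t = 0` gives the identity. That matches the "identity at t=0" property stated in the removed docstring.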