mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 16:15:50 +08:00
Collapse antipasto family to one variant: rot(V) becomes canonical antipasto
main keeps a single antipasto = the rotation+delta SVD adapter (the published
method, paper 2601.07473), default rotate_basis=V. On GSM8K/down_proj rot(V)
led the family (57.2) and at a single seed nothing separated from it, while the
covariance-oriented arms cost 34-120s init for no gain. The full family (gain
core, U/both rotations, ablate, dplr, corda, asvd) is preserved on the
antipasto-variants branch.
- antipasto.py is now the rotation implementation, registered as "antipasto"
- delete antipasto_{rot,ablate,corda,asvd,dplr}.py + their config exports
- benchmark/justfile/cost_report/smoke: drop the removed variants + dead knobs
(antipasto_coeff/suppress_only/ablate_k/cov_orient/lora_rank); keep
--antipasto-rotate-basis as antipasto's V/U/both/none ablation axis
- README: subset table to one antipasto row, add rank column, note single-seed
noise floor (~1.4pp), point the full family at the branch
smoke: 10 passed
Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -48,47 +48,31 @@ just qwen-probe # Qwen/Qwen3-0.6B train/save-load probe
|
||||
## Variants
|
||||
|
||||
Trained on a MetaMathQA subset, tested on GSM8K, all on `Qwen/Qwen3.5-0.8B-Base` targeting
|
||||
`down_proj` in all 24 layers (2500 steps, effective batch 8 = 20k samples). Standard adapters
|
||||
use r=32; the AntiPaSTO family uses r=256 (it tunes only S-space gain, so it needs the rank).
|
||||
`down_proj` in all 24 layers (2500 steps, effective batch 8 = 20k samples).
|
||||
|
||||
| Variant | test % | valid % | Params | +MACs/tok | fwd/bwd (ms) | init (s) |
|
||||
| --------------------------------------------- | -----: | ------: | ------: | --------: | -----------: | -------: |
|
||||
| [DoRA](https://arxiv.org/abs/2402.09353) | 60.2 | 68.0 | 3.56M | 3.54M | 161 / 556 | 0.16 |
|
||||
| [LoRA](https://arxiv.org/abs/2106.09685) | 59.8 | 68.0 | 3.54M | 3.54M | 173 / 573 | 0.02 |
|
||||
| [PiSSA](https://arxiv.org/abs/2404.02948) | 59.8 | 76.0 | 3.54M | 3.54M | 146 / 549 | 2.04 |
|
||||
| [HRA](https://arxiv.org/abs/2405.17484) | 59.2 | 70.0 | 2.75M | 2.75M | 225 / 948 | 0.04 |
|
||||
| [EVA](https://arxiv.org/abs/2410.07170) | 59.3 | 74.0 | 3.54M | 3.54M | 151 / 660 | 28.3 |
|
||||
| [IA3-FF](https://arxiv.org/pdf/2205.05638) | 56.3 | 62.0 | 0.086M | 0M | 140 / 510 | 0.01 |
|
||||
| [DeLoRA](https://arxiv.org/abs/2503.18225) | 56.2 | 62.0 | 3.54M | 3.54M | 169 / 593 | 0.21 |
|
||||
| [AntiPaSTO](https://arxiv.org/abs/2601.07473) | 56.0 | 62.0 | 0.0061M | 28.3M | 166 / 571 | 2.5 |
|
||||
| AntiPaSTO-rot | 57.2 | 60.0 | 0.0154M | 28.3M | 165 / 596 | 2.0 |
|
||||
| AntiPaSTO-ablate | 56.0 | 68.0 | 0.0062M | 28.3M | 166 / 580 | 2.2 |
|
||||
| AntiPaSTO-dplr | 56.0 | 64.0 | 0.1044M | 28.4M | 153 / 582 | 3.6 |
|
||||
| AntiPaSTO-ASVD (diag C) | 55.6 | 64.0 | 0.0061M | 28.3M | 150 / 533 | 34 |
|
||||
| AntiPaSTO-CorDA (full C) | 54.7 | 58.0 | 0.0061M | 28.3M | 146 / 576 | 120 |
|
||||
| [IA3](https://arxiv.org/pdf/2205.05638) | 52.3 | 62.0 | 0.0061M | 0M | 161 / 515 | 0.01 |
|
||||
| Variant | r | test % | valid % | Params | +MACs/tok | fwd/bwd (ms) | init (s) |
|
||||
| --------------------------------------------- | ---: | -----: | ------: | ------: | --------: | -----------: | -------: |
|
||||
| [DoRA](https://arxiv.org/abs/2402.09353) | 32 | 60.2 | 68.0 | 3.56M | 3.54M | 161 / 556 | 0.16 |
|
||||
| [LoRA](https://arxiv.org/abs/2106.09685) | 32 | 59.8 | 68.0 | 3.54M | 3.54M | 173 / 573 | 0.02 |
|
||||
| [PiSSA](https://arxiv.org/abs/2404.02948) | 32 | 59.8 | 76.0 | 3.54M | 3.54M | 146 / 549 | 2.04 |
|
||||
| [EVA](https://arxiv.org/abs/2410.07170) | 32 | 59.3 | 74.0 | 3.54M | 3.54M | 151 / 660 | 28.3 |
|
||||
| [HRA](https://arxiv.org/abs/2405.17484) | 32 | 59.2 | 70.0 | 2.75M | 2.75M | 225 / 948 | 0.04 |
|
||||
| [AntiPaSTO](https://arxiv.org/abs/2601.07473) | 256 | 57.2 | 60.0 | 0.015M | 28.3M | 165 / 596 | 2.0 |
|
||||
| [IA3-FF](https://arxiv.org/pdf/2205.05638) | — | 56.3 | 62.0 | 0.086M | 0M | 140 / 510 | 0.01 |
|
||||
| [DeLoRA](https://arxiv.org/abs/2503.18225) | 32 | 56.2 | 62.0 | 3.54M | 3.54M | 169 / 593 | 0.21 |
|
||||
| [IA3](https://arxiv.org/pdf/2205.05638) | — | 52.3 | 62.0 | 0.006M | 0M | 161 / 515 | 0.01 |
|
||||
|
||||
test/valid % = GSM8K exact-match accuracy. Params = trainable adapter params. +MACs/tok = added
|
||||
forward MACs per token (analytic, hardware-independent). fwd/bwd = median ms over one batch.
|
||||
init = one-time calibration (CorDA's `d_in x d_in` covariance eigh; ~0 for the rest). Peak CUDA
|
||||
memory is ~9.8 GB for every row. Empty rows fill in as the sweep lands.
|
||||
r = adapter rank (— = not a low-rank method). test/valid % = GSM8K exact-match accuracy. Params =
|
||||
trainable adapter params. +MACs/tok = added forward MACs per token (analytic, hardware-independent).
|
||||
fwd/bwd = median ms over one batch. init = one-time calibration (EVA's PCA; ~0 for the rest). Peak
|
||||
CUDA memory is ~9.8 GB for every row. Single seed, so accuracy differences within ~1.4pp (test
|
||||
SE at n=1319) are noise.
|
||||
|
||||
We validate our adapters the same way [PEFT](https://github.com/huggingface/peft/tree/main/method_comparison) does: train on a MetaMathQA subset and check meaningful GSM8K accuracy. See [this file](scripts/metamath_gsm8k_benchmark.py) for details.
|
||||
|
||||
AntiPaSTO is the novel row here: instead of adding trainable directions like LoRA, it freezes W's own top-r SVD and learns only a bounded per-direction gain `S_eff = S * (1 + ELU(g))`. The singular basis stays fixed and interpretable, and the adapter is O(r) params (the 6.1K gain is ~580x smaller than LoRA's 3.54M). The variants change only the basis or core: rot learns a small block-rotation of the frozen basis, CorDA/ASVD orient it by the input second moment (full covariance vs diagonal-only, [Yang+ 2024](https://arxiv.org/abs/2406.05223) / [Yuan+ 2023](https://arxiv.org/abs/2312.05821)), ablate learns a contractive directional ablation ([Arditi+ 2024](https://arxiv.org/abs/2406.11717)), dplr adds a small low-rank core for cross-direction mixing.
|
||||
AntiPaSTO is the novel row here: instead of adding trainable directions like LoRA, it freezes W's own top-r SVD and learns only a per-direction singular-value delta plus a block-diagonal Cayley rotation of that frozen basis. The singular directions stay interpretable and the adapter is tiny (15K params, ~230x smaller than LoRA's 3.54M) yet stays within noise of the full-rank adapters. The default rotates the input basis (V); rotating the output (U), both, or neither are `rotate_basis` ablation axes.
|
||||
|
||||
CorDA (full C) and ASVD (diag C) are a metric-axis ablation against plain AntiPaSTO (C=I): does
|
||||
covariance orientation earn its `d_in x d_in` eigh over the cheap diagonal or no calibration at
|
||||
all? On GSM8K/down_proj the answer is no: C=I 56.0, diag C 55.6, full C 54.7 (single seed). The
|
||||
off-diagonal orientation is the slowest arm (120 s init vs 2.5 s) and lands slightly *below* no
|
||||
calibration, so plain top-r SVD is the right default for this bounded-gain adapter here.
|
||||
|
||||
AntiPaSTO-rot tunes that basis instead of the metric: a block-diagonal Cayley rotation of the
|
||||
input (V), output (U), or both. The table row is V (the default); the ablation gives V 57.2 >
|
||||
U 56.5 > both 55.6 (single seed). So rotating which inputs feed each frozen direction helps most,
|
||||
the output-side rotation is slightly worse, and doing both is worst -- the second rotation is
|
||||
redundant capacity that hurts. rot(V) is the best small-parameter arm overall (57.2 at 15K params
|
||||
vs LoRA's 59.8 at 3.54M).
|
||||
The full AntiPaSTO family (rotation-free gain core, the U/both rotation arms, contractive directional ablation [Arditi+ 2024](https://arxiv.org/abs/2406.11717), a low-rank mixing core, and CorDA/ASVD covariance-oriented bases [Yang+ 2024](https://arxiv.org/abs/2406.05223) / [Yuan+ 2023](https://arxiv.org/abs/2312.05821)) lives on the [`antipasto-variants`](https://github.com/wassname/lora-lite/tree/antipasto-variants) branch with its own ablation table. On GSM8K/down_proj none of those arms separated from this one at a single seed, and the covariance-oriented bases cost 34-120 s of init for no gain, so main keeps the cheapest arm that led: rotation of V.
|
||||
|
||||
|
||||
## Developer docs
|
||||
|
||||
@@ -83,7 +83,7 @@ metamath-queue variant="lora" steps="5000" model=model:
|
||||
|
||||
# Run a single MetaMathQA->GSM8K benchmark for a given variant.
|
||||
# Per-variant lr / target-name defaults are baked in here.
|
||||
bench-variant model variant steps="5000" lora_rank="8" r_override="" lr_override="" rotate_basis="V" seed="0":
|
||||
bench-variant model variant steps="5000" r_override="" lr_override="" rotate_basis="V" seed="0":
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
lr=1e-4
|
||||
@@ -99,14 +99,14 @@ bench-variant model variant steps="5000" lora_rank="8" r_override="" lr_override
|
||||
delora) lr=1e-3 ;;
|
||||
ia3) lr=5e-3; target='(k_proj|v_proj)$' ;;
|
||||
ia3_ff) lr=5e-3; target='(down_proj)$' ;;
|
||||
# antipasto cores tune only S-space gain/block (tiny params), so a small
|
||||
# r leaves almost nothing trainable; r=256 is the variant default and
|
||||
# matches the published AntiPaSTO row. alpha=r (no extra scaling).
|
||||
antipasto*) lr=5e-3; r=256; alpha=256 ;;
|
||||
# antipasto tunes only S-space deltas + a small block rotation (tiny params),
|
||||
# so a small r leaves almost nothing trainable; r=256 is the variant default
|
||||
# and matches the published AntiPaSTO row. alpha=r (no extra scaling).
|
||||
antipasto) lr=5e-3; r=256; alpha=256 ;;
|
||||
esac
|
||||
# r override (e.g. low-rank corda sweep); alpha tracks r for the antipasto family.
|
||||
# r override (e.g. low-rank sweep); alpha tracks r for antipasto.
|
||||
if [ -n "{{r_override}}" ]; then r="{{r_override}}"; alpha="{{r_override}}"; fi
|
||||
# lr override (e.g. dplr core wants a tamer lr than the gain's 5e-3).
|
||||
# lr override (e.g. a tamer lr than antipasto's 5e-3 default).
|
||||
if [ -n "{{lr_override}}" ]; then lr="{{lr_override}}"; fi
|
||||
# 0.8B + large vocab: HF ForCausalLMLoss upcasts logits to fp32 (bs*seq*vocab*4),
|
||||
# which OOMs the 24GB card at the old bs=4/seq=768. micro-batch 2 fits at ~10GB;
|
||||
@@ -119,13 +119,12 @@ bench-variant model variant steps="5000" lora_rank="8" r_override="" lr_override
|
||||
--steps {{steps}} \
|
||||
--lr "$lr" \
|
||||
--target-name "$target" \
|
||||
--antipasto-lora-rank {{lora_rank}} \
|
||||
--batch-size 2 --grad-accum 4 --max-seq-length 512 --batch-size-eval 16 \
|
||||
--layers all --r "$r" --alpha "$alpha" \
|
||||
--antipasto-rotate-basis '{{rotate_basis}}' \
|
||||
--seed {{seed}}
|
||||
|
||||
metamath-queue-all model=model steps="2500" variants="lora pissa delora dora hra ia3 ia3_ff eva antipasto antipasto_rot antipasto_corda antipasto_asvd antipasto_ablate antipasto_dplr":
|
||||
metamath-queue-all model=model steps="2500" variants="lora pissa delora dora hra ia3 ia3_ff eva antipasto":
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
# One pueue job per variant (each runs the live code at run time, so editing
|
||||
|
||||
+4
-10
@@ -2,14 +2,12 @@
|
||||
|
||||
Answers "which is best -- time / flops / adds / params?": MACs/token is the
|
||||
deterministic apples-to-apples compute number; trainable_params is the size headline;
|
||||
wall-time is the felt-but-noisy number; group_init is where CorDA's eigh(d_in^3) bites.
|
||||
wall-time is the felt-but-noisy number; group_init is the one-time init cost.
|
||||
|
||||
Usage:
|
||||
uv run --extra benchmark python scripts/cost_report.py \
|
||||
--model Qwen/Qwen3-0.6B-Base --variants antipasto antipasto_corda antipasto_ablate lora \
|
||||
--model Qwen/Qwen3-0.6B-Base --variants antipasto lora pissa \
|
||||
--target-name 'q_proj$' 'v_proj$' --r 32 --out logs/cost_qwen0.6b.log
|
||||
|
||||
Point --target-name at down_proj to see the CorDA covariance corner (large d_in).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -40,7 +38,6 @@ def build_cfg(variant: str, args, dtype) -> ll.AdapterConfig:
|
||||
bcfg = benchmark.BenchmarkConfig(
|
||||
model=args.model, variant=variant, r=args.r, alpha=float(args.r),
|
||||
target_name=list(args.target_name), layers=args.layers, torch_dtype=args.dtype,
|
||||
antipasto_cov_orient=args.cov_orient,
|
||||
)
|
||||
return benchmark.cfg_for_variant(bcfg, dtype)
|
||||
|
||||
@@ -49,19 +46,16 @@ def main() -> None:
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--model", default="Qwen/Qwen3-0.6B-Base")
|
||||
ap.add_argument("--variants", nargs="+",
|
||||
default=["lora", "antipasto", "antipasto_rot", "antipasto_corda",
|
||||
"antipasto_ablate", "antipasto_dplr"])
|
||||
default=["lora", "pissa", "antipasto"])
|
||||
ap.add_argument("--target-name", nargs="+", default=[r"q_proj$", r"v_proj$"])
|
||||
ap.add_argument("--r", type=int, default=32)
|
||||
ap.add_argument("--layers", default="all",
|
||||
help="'all' or comma list e.g. '0,1' -- limit layers (CorDA down_proj eigh is slow).")
|
||||
help="'all' or comma list e.g. '0,1' -- limit layers.")
|
||||
ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
|
||||
ap.add_argument("--dtype", default="bfloat16")
|
||||
ap.add_argument("--seq-len", type=int, default=256)
|
||||
ap.add_argument("--batch", type=int, default=2)
|
||||
ap.add_argument("--calib-batches", type=int, default=4)
|
||||
ap.add_argument("--cov-orient", action="store_true",
|
||||
help="CorDA-orient antipasto_ablate (measure the eigh corner).")
|
||||
ap.add_argument("--out", default="logs/cost.log")
|
||||
args = ap.parse_args()
|
||||
|
||||
|
||||
@@ -35,11 +35,6 @@ CFG_BY_VARIANT = {
|
||||
"hra": ll.HRAConfig,
|
||||
"eva": ll.EVAConfig,
|
||||
"antipasto": ll.AntiPaSTOConfig,
|
||||
"antipasto_rot": ll.AntiPaSTORotConfig,
|
||||
"antipasto_ablate": ll.AntiPaSTOAblateConfig,
|
||||
"antipasto_corda": ll.AntiPaSTOCorDAConfig,
|
||||
"antipasto_asvd": ll.AntiPaSTOASVDConfig,
|
||||
"antipasto_dplr": ll.AntiPaSTODPLRConfig,
|
||||
"road": ll.RoadConfig,
|
||||
}
|
||||
|
||||
@@ -49,7 +44,7 @@ class BenchmarkConfig:
|
||||
"""MetaMathQA -> GSM8K benchmark config. Tyro turns this into the CLI."""
|
||||
|
||||
model: str = "Qwen/Qwen3.5-0.8B-Base"
|
||||
variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "antipasto_rot", "antipasto_ablate", "antipasto_corda", "antipasto_asvd", "antipasto_dplr", "road"] = "lora"
|
||||
variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "road"] = "lora"
|
||||
mode: Literal["benchmark", "probe"] = "benchmark"
|
||||
device: str = "cuda"
|
||||
torch_dtype: str = "bfloat16"
|
||||
@@ -58,16 +53,8 @@ class BenchmarkConfig:
|
||||
alpha: float = 64.0
|
||||
delora_lambda0: float = 0.1
|
||||
road_group_size: int = 64
|
||||
# AntiPaSTO family (gain / corda) runtime knobs.
|
||||
antipasto_coeff: float = 1.0
|
||||
antipasto_suppress_only: bool = False
|
||||
# AntiPaSTO-ablate.
|
||||
antipasto_ablate_k: int = 1
|
||||
antipasto_cov_orient: bool = False
|
||||
# AntiPaSTO-rot (legacy rotation variant) basis to rotate.
|
||||
# AntiPaSTO singular basis to rotate: V (default) / U / both / none (ablation axes).
|
||||
antipasto_rotate_basis: Literal["V", "U", "both", "none"] = "V"
|
||||
# AntiPaSTO-dplr: rank of the low-rank mixing core in the frozen subspace.
|
||||
antipasto_lora_rank: int = 8
|
||||
target_name: list[str] = field(default_factory=lambda: list(DEFAULT_TARGETS))
|
||||
layers: str = "all"
|
||||
train_dataset: str = "meta-math/MetaMathQA"
|
||||
@@ -140,16 +127,8 @@ def cfg_for_variant(args: BenchmarkConfig, dtype: torch.dtype) -> ll.AdapterConf
|
||||
extra = {"lambda0": args.delora_lambda0} if args.variant == "delora" else {}
|
||||
if args.variant == "road":
|
||||
extra = {"group_size": args.road_group_size}
|
||||
if args.variant == "antipasto_rot":
|
||||
if args.variant == "antipasto":
|
||||
extra = {"rotate_basis": args.antipasto_rotate_basis}
|
||||
if args.variant in ("antipasto", "antipasto_corda", "antipasto_asvd"):
|
||||
extra = {"coeff": args.antipasto_coeff, "suppress_only": args.antipasto_suppress_only}
|
||||
if args.variant == "antipasto_ablate":
|
||||
extra = {"coeff": args.antipasto_coeff, "k": args.antipasto_ablate_k,
|
||||
"cov_orient": args.antipasto_cov_orient}
|
||||
if args.variant == "antipasto_dplr":
|
||||
extra = {"coeff": args.antipasto_coeff, "suppress_only": args.antipasto_suppress_only,
|
||||
"lora_rank": args.antipasto_lora_rank}
|
||||
return CFG_BY_VARIANT[args.variant](
|
||||
r=args.r,
|
||||
alpha=args.r if args.variant == "pissa" else args.alpha,
|
||||
@@ -579,18 +558,14 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
|
||||
dtype = getattr(torch, args.torch_dtype)
|
||||
run_commit = current_git_commit()
|
||||
run_id = f"{args.model.replace('/', '--')}__{args.variant}__s{args.steps}__seed{args.seed}"
|
||||
# dplr capacity is set by lora_rank, not r, so keep rank-sweep runs from colliding.
|
||||
if args.variant == "antipasto_dplr" and args.antipasto_lora_rank != 8:
|
||||
run_id += f"__k{args.antipasto_lora_rank}"
|
||||
# antipasto family defaults to r=256; low-rank sweeps get their own dirs.
|
||||
if args.variant.startswith("antipasto") and args.r != 256:
|
||||
# antipasto defaults to r=256; low-rank sweeps get their own dirs.
|
||||
if args.variant == "antipasto" and args.r != 256:
|
||||
run_id += f"__r{args.r}"
|
||||
# antipasto_rot defaults to rotating V; U/both are ablation axes -> own dirs.
|
||||
if args.variant == "antipasto_rot" and args.antipasto_rotate_basis != "V":
|
||||
# antipasto defaults to rotating V; U/both/none are ablation axes -> own dirs.
|
||||
if args.variant == "antipasto" and args.antipasto_rotate_basis != "V":
|
||||
run_id += f"__rot{args.antipasto_rotate_basis}"
|
||||
# antipasto family defaults to lr=5e-3; lr sweeps get their own dirs (the dense/
|
||||
# low-rank cores want a tamer lr than the gain, so this is a real axis).
|
||||
if args.variant.startswith("antipasto") and abs(args.lr - 5e-3) > 1e-9:
|
||||
# antipasto defaults to lr=5e-3; lr sweeps get their own dirs.
|
||||
if args.variant == "antipasto" and abs(args.lr - 5e-3) > 1e-9:
|
||||
run_id += f"__lr{args.lr:g}"
|
||||
out_dir = args.output_dir / run_id
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -600,13 +575,9 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
|
||||
batches, skipped_train_prompt_too_long = make_train_batches(datasets["train"], tokenizer, args)
|
||||
print_first_train_sample(tokenizer, batches[0])
|
||||
cfg = cfg_for_variant(args, dtype)
|
||||
# Variants with a data-driven group_init need calibration activations from the
|
||||
# downstream task (IPM mode, per CorDA). eva needs only a few batches for its init;
|
||||
# corda/asvd/cov-orient estimate an input second moment, so we hand them many more
|
||||
# batches (PEFT calibrates on a few hundred sequences) for a well-conditioned basis.
|
||||
needs_calib = args.variant in ("eva", "antipasto_corda", "antipasto_asvd") or (
|
||||
args.variant == "antipasto_ablate" and args.antipasto_cov_orient
|
||||
)
|
||||
# eva needs a few calibration batches for its data-driven init. antipasto runs
|
||||
# without calibration (plain weight-SVD init), matching how it was benchmarked.
|
||||
needs_calib = args.variant == "eva"
|
||||
init_meter = group_init_meter() # wall-time + peak CPU RAM of group_init
|
||||
if needs_calib:
|
||||
n_batches = min(4, len(batches)) if args.variant == "eva" else min(64, len(batches))
|
||||
|
||||
@@ -20,11 +20,6 @@ from .variants.dora import DoRAConfig
|
||||
from .variants.hra import HRAConfig
|
||||
from .variants.eva import EVAConfig
|
||||
from .variants.antipasto import AntiPaSTOConfig
|
||||
from .variants.antipasto_rot import AntiPaSTORotConfig
|
||||
from .variants.antipasto_ablate import AntiPaSTOAblateConfig
|
||||
from .variants.antipasto_corda import AntiPaSTOCorDAConfig
|
||||
from .variants.antipasto_asvd import AntiPaSTOASVDConfig
|
||||
from .variants.antipasto_dplr import AntiPaSTODPLRConfig
|
||||
from .variants.road import RoadConfig
|
||||
|
||||
__all__ = [
|
||||
@@ -38,11 +33,6 @@ __all__ = [
|
||||
"HRAConfig",
|
||||
"EVAConfig",
|
||||
"AntiPaSTOConfig",
|
||||
"AntiPaSTORotConfig",
|
||||
"AntiPaSTOAblateConfig",
|
||||
"AntiPaSTOCorDAConfig",
|
||||
"AntiPaSTOASVDConfig",
|
||||
"AntiPaSTODPLRConfig",
|
||||
"RoadConfig",
|
||||
"attach",
|
||||
"detach",
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from . import ( # noqa: F401 side-effect: register
|
||||
lora, pissa, delora, ia3, dora, hra, eva, antipasto, road,
|
||||
antipasto_rot, antipasto_ablate, antipasto_corda, antipasto_asvd, antipasto_dplr,
|
||||
)
|
||||
|
||||
@@ -1,30 +1,32 @@
|
||||
"""AntiPaSTO: learnable bounded reweighting of frozen SVD singular values.
|
||||
"""AntiPaSTO: SVD adapter that freezes W's own top-r basis and learns a per-direction
|
||||
singular-value delta plus a block-diagonal Cayley rotation of that frozen basis.
|
||||
|
||||
wassname 2026 https://arxiv.org/abs/2601.07473
|
||||
|
||||
W = U diag(S) Vh + W_res # top-r SVD; W_res = W - U_r S_r Vh_r, frozen
|
||||
learn: g (r,) # per-direction gain
|
||||
S_eff = S * (1 + ELU(coeff * g)) # exp(z) for z<0 (bounded), 1+z for z>0
|
||||
y = x @ W_res.T + ((x @ Vh.T) * S_eff) @ U.T
|
||||
W = U diag(S) Vh + W_res (top-r SVD; W_res = W - U_r S_r Vh_r)
|
||||
learn: delta_s (r,), rot_T (n_blocks, bs(bs-1)/2)
|
||||
R = block_diag(Cayley(skew(rot_T))); Vh_eff = R @ Vh (or U_eff = U @ R.T)
|
||||
y = x @ W_res.T + ((x @ Vh_eff.T) * (S + delta_s)) @ U_eff.T
|
||||
|
||||
suppress_only: clamp g<=0 -> S_eff in (0, S], attenuation only.
|
||||
coeff: runtime scale; 0 = identity, <0 swaps amplify/suppress.
|
||||
The default rotates the input basis (V): on GSM8K/down_proj this beat rotating the
|
||||
output basis (U) or both, and beat a rotation-free gain core -- rotating which inputs
|
||||
feed each frozen direction is the cheapest knob that helps (57.2 at 15K params). U /
|
||||
both / none are kept as `rotate_basis` ablation axes.
|
||||
|
||||
Identity at g=0 or coeff=0: 1+ELU(0)=1, so S_eff=S (up to the bf16 SVD round-trip).
|
||||
The basis (U, Vh) is frozen, so the singular directions stay interpretable and only
|
||||
the gain is learned. See forward() for why 1+ELU over linear/exp/tanh.
|
||||
Identity at t=0: rot_T=0 -> R=I, delta_s~4e-4 -> y ~ x @ W^T (tiny positive bias on
|
||||
delta_s breaks sign symmetry; rotation alone can't).
|
||||
|
||||
Refs:
|
||||
- paper: https://github.com/wassname/AntiPaSTO
|
||||
- sibling (whitened, mean-diff): steering-lite/.../sspace.py
|
||||
- selection: Wanda (Sun+ 2023, arXiv:2306.11695), ASVD (Yuan+ 2023, arXiv:2312.05821)
|
||||
- top-r SVD init: PiSSA (Meng+ 2024, arXiv:2404.02948)
|
||||
- lite port of: https://github.com/wassname/antipasto3
|
||||
(offline: docs/refs/antipasto3_svd_adapter.py)
|
||||
"""
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, Literal
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from einops import einsum, rearrange
|
||||
from jaxtyping import Float
|
||||
from torch import nn, Tensor as T
|
||||
|
||||
@@ -39,17 +41,34 @@ CalibrationData = Iterable[CalibrationBatch]
|
||||
@dataclass
|
||||
class AntiPaSTOConfig(AdapterConfig):
|
||||
variant: str = "antipasto"
|
||||
# Only r + r trainable scalars, so r can be large.
|
||||
# Higher default than LoRA (r=8) since trainable params scale as r + r/bs*bs*(bs-1)/2, not r*(d_in+d_out).
|
||||
r: int = 256
|
||||
# Per-direction reweighting is S_eff = S * (1 + ELU(coeff * g)). See forward()
|
||||
# for the why; identity at g=0 or coeff=0, positive always, no free bound knob.
|
||||
suppress_only: bool = False # clamp g<=0 -> factor in (0,1]: attenuation only.
|
||||
# Guarantee holds for coeff>=0; coeff<0 inverts the product and re-amplifies.
|
||||
# Runtime steering scale. 0 = identity. <0 inverts (swaps amplify/suppress).
|
||||
coeff: float = 1.0
|
||||
# group_init Wanda-style pooling of |X @ Vh[i]|: 'rms' is outlier-sensitive
|
||||
# (ASVD intuition), 'mean_abs' is the original outlier-robust pooling.
|
||||
act_pool: Literal["rms", "mean_abs"] = "rms"
|
||||
# Block size for the block-diagonal Cayley rotation. r must be divisible by it.
|
||||
block_size: int = 4
|
||||
# Cayley map saturation: bounds rotation angle to ~max_rotation_angle radians.
|
||||
max_rotation_angle: float = 0.5
|
||||
# Which singular basis to rotate: 'V' (input, default), 'U' (output), 'both', or 'none'.
|
||||
rotate_basis: Literal["V", "U", "both", "none"] = "V"
|
||||
|
||||
|
||||
def _cayley(skew: torch.Tensor) -> torch.Tensor:
|
||||
"""R = (I - X)^-1 (I + X) for X = skew/2; preserves orthogonality."""
|
||||
bs = skew.shape[-1]
|
||||
eye = torch.eye(bs, dtype=skew.dtype, device=skew.device).expand_as(skew)
|
||||
X = skew / 2
|
||||
return torch.linalg.solve(eye - X, eye + X)
|
||||
|
||||
|
||||
def _build_rotation(rot_T: torch.Tensor, bs: int, max_angle: float) -> torch.Tensor:
|
||||
"""rot_T: (n_blocks, bs*(bs-1)/2) -> R: (n_blocks, bs, bs) Cayley rotation."""
|
||||
n_blocks, _ = rot_T.shape
|
||||
rows, cols = torch.triu_indices(bs, bs, offset=1, device=rot_T.device).unbind(0)
|
||||
A = torch.zeros(n_blocks, bs, bs, dtype=rot_T.dtype, device=rot_T.device)
|
||||
A[:, rows, cols] = rot_T
|
||||
A = 0.5 * (A - A.transpose(-1, -2))
|
||||
a_limit = 2.0 * math.tan(max_angle / 2.0)
|
||||
A = a_limit * torch.tanh(A / a_limit)
|
||||
return _cayley(A)
|
||||
|
||||
|
||||
@register
|
||||
@@ -59,14 +78,27 @@ class AntiPaSTO:
|
||||
@staticmethod
|
||||
def param_specs(d_in, d_out, cfg):
|
||||
r = cfg.r
|
||||
return dict(
|
||||
# Frozen top-r SVD captured at init.
|
||||
bs = int(cfg.block_size)
|
||||
if r % bs != 0:
|
||||
raise ValueError(f"AntiPaSTO requires r={r} divisible by block_size={bs}")
|
||||
specs = dict(
|
||||
# Frozen SVD components captured at init.
|
||||
lora_U=ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True),
|
||||
lora_S=ParamSpec((r,), init="zeros", trainable=False, as_buffer=True),
|
||||
lora_Vh=ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True),
|
||||
# Trainable per-direction log-scale. init 0 -> 1+ELU(0)=1 -> identity.
|
||||
lora_g=ParamSpec((r,), init="zeros"),
|
||||
# Trainable: per-singular-value delta.
|
||||
# antipasto3 uses 4e-4 + N(0, 4e-4): small positive bias breaks sign
|
||||
# symmetry (rotation alone can't); zero-init works but trains slower.
|
||||
lora_delta_s=ParamSpec((r,), init=lambda t: t.normal_(0, 4e-4).add_(4e-4)),
|
||||
)
|
||||
if cfg.rotate_basis != "none":
|
||||
n_blocks = r // bs
|
||||
n_triu = bs * (bs - 1) // 2
|
||||
specs["lora_rot_T"] = ParamSpec((n_blocks, n_triu), init="zeros")
|
||||
if cfg.rotate_basis == "both":
|
||||
# 'both' rotates V (lora_rot_T) and U independently; lora_rot_T_u is the U-side.
|
||||
specs["lora_rot_T_u"] = ParamSpec((n_blocks, n_triu), init="zeros")
|
||||
return specs
|
||||
|
||||
@staticmethod
|
||||
def init(layer: nn.Module, cfg) -> None:
|
||||
@@ -85,19 +117,17 @@ class AntiPaSTO:
|
||||
layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype))
|
||||
W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
|
||||
layer.weight.data.copy_(W_res)
|
||||
# group_init() refines the dimension selection if calibration_data is given.
|
||||
# group_init() refines this to input-aligned directions if calibration_data is given.
|
||||
|
||||
@staticmethod
|
||||
def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None:
|
||||
"""Data-driven re-selection of which top-r singular directions to keep.
|
||||
"""Wanda-style data-driven dimension selection within the weight SVD.
|
||||
|
||||
init(): top-r by S alone (PiSSA-style)
|
||||
group_init(): top-r by score[i] = S[i] * pool|X @ Vh[i]| (Wanda/ASVD)
|
||||
pool = 'rms' (outlier-sensitive) | 'mean_abs' (outlier-robust)
|
||||
init() picks the top-r singular dimensions by S alone (PiSSA-style).
|
||||
group_init() re-selects based on S[i] * mean|X @ Vh[i]|: dimensions
|
||||
that are both large in W AND active given real inputs.
|
||||
|
||||
This re-RANKS W's own singular vectors by activation; it does NOT re-orient
|
||||
the basis (that is CorDA -> antipasto_corda.py). So the kept directions are
|
||||
still plain weight-SVD directions, just a better subset. None -> keep init().
|
||||
If calibration_data is None the weight-SVD init from init() is kept.
|
||||
"""
|
||||
if calibration_data is None:
|
||||
return
|
||||
@@ -133,7 +163,6 @@ class AntiPaSTO:
|
||||
h.remove()
|
||||
|
||||
r = cfg.r
|
||||
pool = cfg.act_pool
|
||||
for name, layer in layers.items():
|
||||
X = torch.cat(captured[name], dim=0) # (N, d_in)
|
||||
if X.shape[0] < r:
|
||||
@@ -141,23 +170,19 @@ class AntiPaSTO:
|
||||
f"AntiPaSTO at {name}: only {X.shape[0]} calibration tokens, need >= r={r}"
|
||||
)
|
||||
|
||||
# Rebuild the FULL W: init() stored the exact top-r it subtracted, so
|
||||
# W_res + U_r S_r Vh_r == W (full rank, not a cropped matrix). The SVD
|
||||
# below therefore re-selects from W's whole spectrum, not a truncation.
|
||||
# Recover W_orig: init() wrote W_res into layer.weight and stored top-r components
|
||||
W_res = layer.weight.data.float()
|
||||
U_old = layer.lora_U.float()
|
||||
S_old = layer.lora_S.float()
|
||||
Vh_old = layer.lora_Vh.float()
|
||||
U_old = layer.lora_U.float() # (d_out, r)
|
||||
S_old = layer.lora_S.float() # (r,)
|
||||
Vh_old = layer.lora_Vh.float() # (r, d_in)
|
||||
W_orig = W_res + (U_old * S_old.unsqueeze(0)) @ Vh_old
|
||||
|
||||
# Full SVD to score all dimensions
|
||||
U_full, S_full, Vh_full = torch.linalg.svd(W_orig, full_matrices=False)
|
||||
proj = X.to(Vh_full) @ Vh_full.T # (N, r) input in S-coords (X CPU -> GPU here)
|
||||
if pool == "rms":
|
||||
act_mag = proj.pow(2).mean(dim=0).sqrt() # outlier-sensitive
|
||||
else:
|
||||
act_mag = proj.abs().mean(dim=0) # outlier-robust (original)
|
||||
# score[i] = S[i] * mean|X @ Vh[i]| (Wanda: weight magnitude × activation magnitude)
|
||||
act_mag = (X.to(Vh_full) @ Vh_full.T).abs().mean(dim=0) # (k,) -- X captured on CPU
|
||||
scores = S_full * act_mag
|
||||
idx = scores.argsort(descending=True)[:r] # top-r by joint importance
|
||||
idx = scores.argsort(descending=True)[:r] # top-r by joint importance
|
||||
idx = idx.sort().values # stable ordering
|
||||
|
||||
Ur, Sr, Vhr = U_full[:, idx], S_full[idx], Vh_full[idx]
|
||||
@@ -176,23 +201,32 @@ class AntiPaSTO:
|
||||
y: Float[T, '*B o'],
|
||||
) -> Float[T, '*B o']:
|
||||
cfg = layer._lora_cfg
|
||||
bs = int(cfg.block_size)
|
||||
max_angle = float(cfg.max_rotation_angle)
|
||||
rotate_basis = cfg.rotate_basis
|
||||
|
||||
U = layer.lora_U.to(x.dtype) # (d_out, r)
|
||||
S = layer.lora_S.to(x.dtype) # (r,)
|
||||
Vh = layer.lora_Vh.to(x.dtype) # (r, d_in)
|
||||
g = layer.lora_g.to(x.dtype) # (r,)
|
||||
coeff = float(cfg.coeff)
|
||||
|
||||
if cfg.suppress_only:
|
||||
g = torch.clamp(g, max=0.0) # factor in (0,1]: attenuation only
|
||||
if rotate_basis == "none":
|
||||
U_eff, Vh_eff = U, Vh
|
||||
else:
|
||||
R = _build_rotation(layer.lora_rot_T.float(), bs, max_angle).to(x.dtype)
|
||||
n_blocks = R.shape[0] # R: (n, bs, bs)
|
||||
U_eff, Vh_eff = U, Vh
|
||||
# 'V'/'U' rotate that one basis with lora_rot_T; 'both' rotates V with
|
||||
# lora_rot_T and U with a separate lora_rot_T_u (independent rotations).
|
||||
if rotate_basis in ("V", "both"):
|
||||
Vh_blocks = rearrange(Vh, "(n a) i -> n a i", n=n_blocks)
|
||||
Vh_eff = rearrange(einsum(R, Vh_blocks, "n a b, n b i -> n a i"), "n a i -> (n a) i")
|
||||
if rotate_basis in ("U", "both"):
|
||||
R_u = _build_rotation(layer.lora_rot_T_u.float(), bs, max_angle).to(x.dtype) if rotate_basis == "both" else R
|
||||
U_blocks = rearrange(U, "d (n b) -> d n b", n=n_blocks)
|
||||
U_eff = rearrange(einsum(U_blocks, R_u, "d n b, n c b -> d n c"), "d n c -> d (n c)")
|
||||
|
||||
# S_eff = S * (1 + ELU(z)), z = coeff*g, 1+ELU(z) = exp(z) for z<=0 else 1+z.
|
||||
# Why 1+ELU and not the obvious alternatives:
|
||||
# linear S*(1+z) : z<-1 -> S_eff<0, a sign flip that drives incoherence.
|
||||
# exp S*exp(z) : unbounded, gradient self-amplifies (amplification blows up).
|
||||
# tanh bounded : arbitrary bound knob, saturation kills the gradient.
|
||||
# 1+ELU uses each in its safe regime: exp where it is bounded in (0,1]
|
||||
# (attenuation), linear where exp would diverge (amplification). >0 always.
|
||||
S_eff = S * (1.0 + torch.nn.functional.elu(coeff * g))
|
||||
|
||||
h = (x @ Vh.T) * S_eff # input in S-coords, reweighted
|
||||
return y + h @ U.T
|
||||
S_eff = S + layer.lora_delta_s.to(x.dtype) # (r,)
|
||||
h = x @ Vh_eff.T # x @ Vh_eff.T
|
||||
h = h * S_eff # diag(S_eff)
|
||||
delta = h @ U_eff.T # @ U_eff.T
|
||||
return y + delta
|
||||
|
||||
@@ -1,190 +0,0 @@
|
||||
"""AntiPaSTO-Ablate: trainable directional ablation in the weight-SVD output basis.
|
||||
|
||||
A contractive sibling of antipasto.py: instead of reweighting the singular gains it
|
||||
projects out a learned direction in the output (U-side) singular basis.
|
||||
|
||||
W = U diag(S) Vh + W_res
|
||||
learn: c (r, k) ablation directions, alpha (k,) strengths in [0, 1]
|
||||
Chat = orthonormal(c) # k unit dirs in S-space
|
||||
h = (x @ Vh.T) * S # output S-coords = diag(S) Vh x
|
||||
h <- h - coeff * (h @ Chat) * alpha @ Chat.T # project the span out
|
||||
y = x @ W_res.T + h @ U.T
|
||||
|
||||
The core (I - alpha Chat Chat^T) is a contraction: eigenvalues 1-alpha along Chat,
|
||||
1 elsewhere, all in [0, 1]. It cannot amplify, so it cannot blow up -- the instability
|
||||
the multiplicative gain bounds away is structurally absent (and a contraction is the
|
||||
natural core to recurse). This is the trainable form of directional ablation (Arditi+
|
||||
2024): target residual writers (down_proj, o_proj) for the surgical regime, not all
|
||||
Linears.
|
||||
|
||||
Runtime: coeff is the per-call knob. coeff=0 -> identity; (0, 1] -> ablate; <0 adds the
|
||||
direction back (the side that can grow, so bound coeff there).
|
||||
|
||||
Refs: antipasto.py (gain sibling), directional ablation Arditi+ 2024 arXiv:2406.11717.
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from jaxtyping import Float
|
||||
from torch import nn, Tensor as T
|
||||
|
||||
from ..variant import register, ParamSpec
|
||||
from ..config import AdapterConfig, register_config
|
||||
|
||||
CalibrationBatch = dict | tuple | list | T
|
||||
CalibrationData = Iterable[CalibrationBatch]
|
||||
|
||||
ε = 1e-6
|
||||
|
||||
|
||||
@register_config
|
||||
@dataclass
|
||||
class AntiPaSTOAblateConfig(AdapterConfig):
|
||||
variant: str = "antipasto_ablate"
|
||||
r: int = 256 # top-r SVD captured (or |dS|-selected via group_init)
|
||||
k: int = 1 # number of ablation directions (rank of the projection)
|
||||
init_alpha: float = 0.05 # small >0 so c gets gradient at step 0
|
||||
coeff: float = 1.0 # runtime: 0=identity, (0,1]=ablate, <0=amplify (bound this side)
|
||||
# CorDA-orient the basis from input covariance (group_init, needs calibration_data).
|
||||
# The ablation is OUTPUT-side and CorDA's U stays orthonormal, so this is a clean
|
||||
# contraction; the win is at low r -- the data-oriented top-r captures the behavior
|
||||
# output direction that plain-SVD top-r drops (measured 1.00 vs 0.65 at r=16).
|
||||
cov_orient: bool = False
|
||||
cov_eps: float = 1e-3
|
||||
|
||||
|
||||
@register
|
||||
class AntiPaSTOAblate:
|
||||
name = "antipasto_ablate"
|
||||
|
||||
@staticmethod
|
||||
def param_specs(d_in, d_out, cfg):
|
||||
r, k = cfg.r, cfg.k
|
||||
return dict(
|
||||
lora_U=ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True),
|
||||
lora_S=ParamSpec((r,), init="zeros", trainable=False, as_buffer=True),
|
||||
lora_Vh=ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True),
|
||||
# Trainable: k ablation directions in S-space, and their strengths.
|
||||
lora_c=ParamSpec((r, k), init=lambda t: t.normal_(0, 1.0 / max(r, 1) ** 0.5)),
|
||||
lora_alpha=ParamSpec((k,), init=lambda t: t.fill_(float(cfg.init_alpha))),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def init(layer: nn.Module, cfg) -> None:
|
||||
if type(layer) is not nn.Linear:
|
||||
raise TypeError("AntiPaSTOAblate mutates layer.weight into W_res; nn.Linear only.")
|
||||
with torch.no_grad():
|
||||
W = layer.weight.data.float()
|
||||
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
|
||||
r = cfg.r
|
||||
Ur, Sr, Vhr = U[:, :r], S[:r], Vh[:r, :]
|
||||
layer.lora_U.copy_(Ur.to(layer.lora_U.dtype))
|
||||
layer.lora_S.copy_(Sr.to(layer.lora_S.dtype))
|
||||
layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype))
|
||||
W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
|
||||
layer.weight.data.copy_(W_res)
|
||||
# lora_c is random-init. Tried (job 94) seeding it from the top-k S-space
|
||||
# output-VARIANCE PC: equal-or-worse on GSM8K (55.6/64.0 vs random 56.0/68.0,
|
||||
# single seed) and +31s calib init -- the optimal ablation dir is loss-defined,
|
||||
# not variance-defined, so a variance seed buys nothing on SFT. Reverted.
|
||||
# FIXME the contrastive dS seed (mean(h|pos)-mean(h|neg), cf. sspace.py) is the
|
||||
# one that should land on the behavior dir, but it needs pos/neg pairs this SFT
|
||||
# benchmark lacks -- only worth it for steering with labelled contrastive data.
|
||||
|
||||
@staticmethod
|
||||
def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None:
|
||||
"""If cov_orient, re-orient each target's SVD by input covariance C=E[x x^T]
|
||||
(CorDA) so the data-relevant output directions land in the top-r and the
|
||||
behavior direction is fully ablatable at low r. No-op otherwise (keeps the
|
||||
plain-SVD basis from init()). C is accumulated on CPU; for down_proj's large
|
||||
d_in this is heavy -- exclude it or use plain ablation there."""
|
||||
if not getattr(cfg, "cov_orient", False) or calibration_data is None:
|
||||
return
|
||||
|
||||
layers = {name: layer for name, layer, _ in targets}
|
||||
cov: dict[str, T] = {}
|
||||
cnt: dict[str, int] = {n: 0 for n in layers}
|
||||
|
||||
def make_hook(name):
|
||||
def _h(module, args, kwargs):
|
||||
x = rearrange(args[0].detach(), "... d -> (...) d").to(torch.float32).cpu()
|
||||
g = x.T @ x
|
||||
cov[name] = g if name not in cov else cov[name] + g
|
||||
cnt[name] += x.shape[0]
|
||||
return _h
|
||||
|
||||
handles = [layers[n].register_forward_pre_hook(make_hook(n), with_kwargs=True) for n in layers]
|
||||
try:
|
||||
was_training = model.training
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for batch in calibration_data:
|
||||
if isinstance(batch, dict):
|
||||
model(**batch)
|
||||
elif isinstance(batch, (list, tuple)):
|
||||
model(*batch)
|
||||
else:
|
||||
model(batch)
|
||||
if was_training:
|
||||
model.train()
|
||||
finally:
|
||||
for h in handles:
|
||||
h.remove()
|
||||
|
||||
r, eps = cfg.r, float(cfg.cov_eps)
|
||||
for name, layer in layers.items():
|
||||
if cnt[name] < r:
|
||||
raise RuntimeError(f"AntiPaSTOAblate at {name}: {cnt[name]} tokens, need >= r={r}")
|
||||
W_res = layer.weight.data.float().cpu()
|
||||
U_old, S_old, Vh_old = (layer.lora_U.float().cpu(),
|
||||
layer.lora_S.float().cpu(),
|
||||
layer.lora_Vh.float().cpu())
|
||||
W_orig = W_res + (U_old * S_old) @ Vh_old
|
||||
|
||||
C = cov[name] / cnt[name]
|
||||
lam, Q = torch.linalg.eigh(C)
|
||||
lam = lam.clamp_min(0) + eps
|
||||
Chalf = (Q * lam.sqrt()) @ Q.T
|
||||
Cinvhalf = (Q * lam.rsqrt()) @ Q.T
|
||||
Ut, St, Vht = torch.linalg.svd(W_orig @ Chalf, full_matrices=False)
|
||||
Ur = Ut[:, :r] # orthonormal output basis (ablation acts here)
|
||||
Sr = St[:r]
|
||||
Pr = Vht[:r] @ Cinvhalf # oblique input projector (input-side only)
|
||||
W_res_new = W_orig - (Ur * Sr) @ Pr
|
||||
|
||||
with torch.no_grad():
|
||||
layer.lora_U.copy_(Ur.to(layer.lora_U))
|
||||
layer.lora_S.copy_(Sr.to(layer.lora_S))
|
||||
layer.lora_Vh.copy_(Pr.to(layer.lora_Vh)) # store P in the Vh slot
|
||||
layer.weight.data.copy_(W_res_new.to(layer.weight))
|
||||
|
||||
@staticmethod
|
||||
def _orthonormal(c: T) -> T:
|
||||
"""(r, k) -> (r, k) with orthonormal columns. k=1 is a plain normalize."""
|
||||
if c.shape[-1] == 1:
|
||||
return c / (c.norm(dim=0, keepdim=True) + ε)
|
||||
# geqrf has no bf16/fp16 kernel (CPU or CUDA); do the QR in fp32, cast back.
|
||||
q, _ = torch.linalg.qr(c.float()) # reduced QR; columns orthonormal
|
||||
return q.to(c.dtype)
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
layer: nn.Module,
|
||||
x: Float[T, '*B i'],
|
||||
y: Float[T, '*B o'],
|
||||
) -> Float[T, '*B o']:
|
||||
cfg = layer._lora_cfg
|
||||
U = layer.lora_U.to(x.dtype) # (d_out, r)
|
||||
S = layer.lora_S.to(x.dtype) # (r,)
|
||||
Vh = layer.lora_Vh.to(x.dtype) # (r, d_in)
|
||||
Chat = AntiPaSTOAblate._orthonormal(layer.lora_c.to(x.dtype)) # (r, k)
|
||||
alpha = layer.lora_alpha.to(x.dtype).clamp(0.0, 1.0) # (k,)
|
||||
coeff = float(cfg.coeff)
|
||||
|
||||
h = (x @ Vh.T) * S # (..., r) output S-coords
|
||||
proj = h @ Chat # (..., k) component along each dir
|
||||
# contractive removal: h <- h - coeff * Sum_j alpha_j (h . chat_j) chat_j
|
||||
h = h - coeff * (proj * alpha) @ Chat.T # (..., r)
|
||||
return y + h @ U.T
|
||||
@@ -1,43 +0,0 @@
|
||||
"""AntiPaSTO-ASVD: diagonal-covariance sibling of antipasto_corda.
|
||||
|
||||
Same frozen-basis bounded gain, but orients the SVD by the DIAGONAL of the input
|
||||
second moment (per-channel activation scale) instead of the full covariance:
|
||||
|
||||
M = diag(E[x_i^2]) vs CorDA's full C = E[x x^T]
|
||||
|
||||
This is Activation-aware SVD (Yuan+ 2023, arXiv:2312.05821): SVD(W diag(s)) with s a
|
||||
per-channel scale. It is NOT a sub-basis of CorDA -- diag(C)^{1/2} and C^{1/2} are
|
||||
different oblique rotations, so the top-r directions differ and either can win on a task.
|
||||
ASVD is the cheap arm: O(d_in) moment, no d_in x d_in matrix, no eigh. The head-to-head
|
||||
with antipasto_corda isolates whether the off-diagonal of C earns its init cost here.
|
||||
|
||||
Reuses antipasto_corda's buffers (U, S, P, g), plain-SVD init, gain forward, and the
|
||||
shared `_covariance_orient` (only the diag flag differs), so there is one copy of the
|
||||
math to keep in sync.
|
||||
|
||||
Refs: antipasto_corda.py (full-covariance sibling), ASVD arXiv:2312.05821.
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..variant import register
|
||||
from ..config import register_config
|
||||
from .antipasto_corda import AntiPaSTOCorDA, AntiPaSTOCorDAConfig, _covariance_orient
|
||||
|
||||
|
||||
@register_config
|
||||
@dataclass
|
||||
class AntiPaSTOASVDConfig(AntiPaSTOCorDAConfig):
|
||||
variant: str = "antipasto_asvd"
|
||||
|
||||
|
||||
@register
|
||||
class AntiPaSTOASVD:
|
||||
name = "antipasto_asvd"
|
||||
param_specs = staticmethod(AntiPaSTOCorDA.param_specs)
|
||||
init = staticmethod(AntiPaSTOCorDA.init)
|
||||
forward = staticmethod(AntiPaSTOCorDA.forward)
|
||||
|
||||
@staticmethod
|
||||
def group_init(model, targets, cfg, calibration_data) -> None:
|
||||
"""ASVD: re-orient by the diagonal of the input second moment (per-channel)."""
|
||||
_covariance_orient(model, targets, cfg, calibration_data, diag=True)
|
||||
@@ -1,203 +0,0 @@
|
||||
"""AntiPaSTO-CorDA: reweight in a covariance-oriented basis, not the weight basis.
|
||||
|
||||
Plain SVD sorts directions by weight gain ||W v|| on isotropic input. The behaviour
|
||||
you steer lives where the DATA has energy, off the top weight-singular axes. CorDA
|
||||
(Yang+ 2024, arXiv:2406.05223) re-orients the SVD by the input covariance, so the top-r
|
||||
directions move the output most on real activations.
|
||||
|
||||
C = E[x x^T] (+ eps I) # input second moment on calibration data
|
||||
C^{1/2}, C^{-1/2} via eigh(C)
|
||||
U S Vht = SVD(W C^{1/2}) # top-r is Eckart-Young best under x ~ N(0,C)
|
||||
P = Vht C^{-1/2} # (r, d_in) oblique input projector
|
||||
W = W_res + U_r diag(S_r) P_r # exact (residual carries the dropped tail)
|
||||
S_eff = S * (1 + ELU(coeff*g)) # same bounded gain as antipasto
|
||||
y = x @ W_res.T + ((x @ P.T) * S_eff) @ U.T
|
||||
|
||||
Identity at g=0 or coeff=0: S_eff=S. P is oblique (rows not orthonormal -- C^{-1/2}
|
||||
skews them); fine for gain reweighting since U stays orthonormal. Requires
|
||||
calibration_data (group_init raises otherwise).
|
||||
|
||||
Refs: antipasto.py (gain + selection sibling), CorDA arXiv:2406.05223.
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from jaxtyping import Float
|
||||
from torch import nn, Tensor as T
|
||||
|
||||
from ..variant import register, ParamSpec
|
||||
from ..config import AdapterConfig, register_config
|
||||
|
||||
CalibrationBatch = dict | tuple | list | T
|
||||
CalibrationData = Iterable[CalibrationBatch]
|
||||
|
||||
|
||||
@register_config
|
||||
@dataclass
|
||||
class AntiPaSTOCorDAConfig(AdapterConfig):
|
||||
variant: str = "antipasto_corda"
|
||||
r: int = 256
|
||||
cov_eps: float = 1e-3 # damping on C eigenvalues; guards C^{-1/2} on rare dirs
|
||||
coeff: float = 1.0 # runtime steer knob: 0=identity, scales trained g
|
||||
suppress_only: bool = False # clamp g<=0 (attenuate only) -- for coeff>=0;
|
||||
# coeff<0 inverts the product (coeff*g>=0) and re-amplifies.
|
||||
|
||||
|
||||
def _gain(S: T, g: T, coeff: float, suppress_only: bool) -> T:
|
||||
"""S_eff = S * (1 + ELU(coeff*g)); exp-bounded attenuation, linear amplification."""
|
||||
if suppress_only:
|
||||
g = g.clamp(max=0.0)
|
||||
return S * (1.0 + F.elu(coeff * g))
|
||||
|
||||
|
||||
def _covariance_orient(model, targets, cfg, calibration_data, *, diag: bool) -> None:
|
||||
"""Re-orient each target's SVD by its input second moment, then rewrite the frozen
|
||||
buffers (U, S, P) and residual weight in that basis. Shared by CorDA and ASVD:
|
||||
|
||||
diag=False -> CorDA: full C = E[x x^T] (cross-channel covariance, via eigh)
|
||||
diag=True -> ASVD: diag(C) = E[x_i^2] only (per-channel scale, O(d_in), no eigh)
|
||||
|
||||
The off-diagonal of C is the sole difference. g=0 stays exact identity either way --
|
||||
the reconstruction (W_res + U_r S_r P_r = W_orig) is lossless. Accumulated on CPU: a
|
||||
full C is d_in^2 fp32 per target and would crowd the GPU; the diagonal is a d_in vector.
|
||||
Call at attach-time, before training touches g (re-orienting g=0 is a no-op).
|
||||
"""
|
||||
if calibration_data is None:
|
||||
raise ValueError("covariance orientation requires calibration_data; got None.")
|
||||
|
||||
layers = {name: layer for name, layer, _ in targets}
|
||||
moment: dict[str, T] = {} # (d_in,d_in) full, or (d_in,) diagonal
|
||||
cnt: dict[str, int] = {n: 0 for n in layers}
|
||||
keep: dict[str, T] = {} # non-pad mask of the in-flight batch
|
||||
|
||||
def make_hook(name):
|
||||
def _h(module, args, kwargs):
|
||||
x = rearrange(args[0].detach(), "... d -> (...) d").to(torch.float32).cpu()
|
||||
if "mask" in keep:
|
||||
x = x[keep["mask"]] # drop padding positions (see loop below)
|
||||
m = x.pow(2).sum(0) if diag else x.T @ x
|
||||
moment[name] = m if name not in moment else moment[name] + m
|
||||
cnt[name] += x.shape[0] # real (non-pad) tokens accumulated
|
||||
return _h
|
||||
|
||||
handles = [layers[n].register_forward_pre_hook(make_hook(n), with_kwargs=True) for n in layers]
|
||||
try:
|
||||
was_training = model.training
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for batch in calibration_data:
|
||||
# Padding activations are not task-representative; mask them out of the moment
|
||||
# so the oriented basis reflects real tokens (CorDA/SVD-LLM official code does
|
||||
# the same). The mask is per-token, shared across all target layers in a batch.
|
||||
keep.pop("mask", None)
|
||||
if isinstance(batch, dict):
|
||||
if "attention_mask" in batch:
|
||||
keep["mask"] = rearrange(batch["attention_mask"], "... -> (...)").bool().cpu()
|
||||
model(**batch)
|
||||
elif isinstance(batch, (list, tuple)):
|
||||
model(*batch)
|
||||
else:
|
||||
model(batch)
|
||||
if was_training:
|
||||
model.train()
|
||||
finally:
|
||||
for h in handles:
|
||||
h.remove()
|
||||
|
||||
r, eps = cfg.r, float(cfg.cov_eps)
|
||||
for name, layer in layers.items():
|
||||
if cnt[name] < r:
|
||||
raise RuntimeError(f"covariance orient at {name}: {cnt[name]} tokens, need >= r={r}")
|
||||
# decomposition on CPU (where the moment lives); results copied back to device buffers.
|
||||
W_res = layer.weight.data.float().cpu()
|
||||
U_old, S_old, P_old = (layer.lora_U.float().cpu(),
|
||||
layer.lora_S.float().cpu(),
|
||||
layer.lora_P.float().cpu())
|
||||
W_orig = W_res + (U_old * S_old) @ P_old
|
||||
|
||||
if diag:
|
||||
c = (moment[name] / cnt[name]).clamp_min(0) + eps # (d_in,) per-channel scale
|
||||
Ut, St, Vht = torch.linalg.svd(W_orig * c.sqrt(), full_matrices=False) # @ diag(c^1/2)
|
||||
Pr = Vht[:r] * c.rsqrt() # @ diag(c^-1/2): oblique projector
|
||||
else:
|
||||
C = moment[name] / cnt[name] # (d_in,d_in)
|
||||
lam, Q = torch.linalg.eigh(C)
|
||||
lam = lam.clamp_min(0) + eps
|
||||
Mhalf = (Q * lam.sqrt()) @ Q.T # C^{1/2}
|
||||
Minvhalf = (Q * lam.rsqrt()) @ Q.T # C^{-1/2}
|
||||
Ut, St, Vht = torch.linalg.svd(W_orig @ Mhalf, full_matrices=False)
|
||||
Pr = Vht[:r] @ Minvhalf # (r, d_in) oblique projector
|
||||
# Quantize the frozen buffers to their stored dtype FIRST, then form the residual
|
||||
# against those exact (bf16) values. The forward reconstructs from the bf16 buffers,
|
||||
# so W_res + U_r S_r P_r = W_orig to one residual-rounding -- without this, the
|
||||
# residual is built from fp32 U/S/P and the forward also eats the U/S/P quantization
|
||||
# mismatch, so g=0 drifts further from identity.
|
||||
Ur = Ut[:, :r].to(layer.lora_U.dtype)
|
||||
Sr = St[:r].to(layer.lora_S.dtype)
|
||||
Pr = Pr.to(layer.lora_P.dtype)
|
||||
W_res_new = W_orig - (Ur.float() * Sr.float()) @ Pr.float()
|
||||
|
||||
with torch.no_grad():
|
||||
layer.lora_U.copy_(Ur)
|
||||
layer.lora_S.copy_(Sr)
|
||||
layer.lora_P.copy_(Pr)
|
||||
layer.weight.data.copy_(W_res_new.to(layer.weight))
|
||||
|
||||
|
||||
@register
|
||||
class AntiPaSTOCorDA:
|
||||
name = "antipasto_corda"
|
||||
|
||||
@staticmethod
|
||||
def param_specs(d_in, d_out, cfg):
|
||||
r = cfg.r
|
||||
return dict(
|
||||
lora_U=ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True),
|
||||
lora_S=ParamSpec((r,), init="zeros", trainable=False, as_buffer=True),
|
||||
# P replaces Vh: oblique covariance-oriented input projector.
|
||||
lora_P=ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True),
|
||||
# Trainable per-direction log-scale. init 0 -> 1+ELU(0)=1 -> exact identity.
|
||||
# No sign-symmetry hack needed (1+ELU is sign-preserving, basis frozen),
|
||||
# matching antipasto.py.
|
||||
lora_g=ParamSpec((r,), init="zeros"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def init(layer: nn.Module, cfg) -> None:
|
||||
"""Plain-SVD fallback so the adapter is valid before group_init. group_init
|
||||
refines P/U/S to the covariance-oriented basis when calibration_data is given."""
|
||||
if type(layer) is not nn.Linear:
|
||||
raise TypeError("AntiPaSTOCorDA mutates layer.weight into W_res; nn.Linear only.")
|
||||
with torch.no_grad():
|
||||
W = layer.weight.data.float()
|
||||
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
|
||||
r = cfg.r
|
||||
Ur, Sr, Vhr = U[:, :r], S[:r], Vh[:r, :]
|
||||
layer.lora_U.copy_(Ur.to(layer.lora_U.dtype))
|
||||
layer.lora_S.copy_(Sr.to(layer.lora_S.dtype))
|
||||
layer.lora_P.copy_(Vhr.to(layer.lora_P.dtype)) # P := Vh until oriented
|
||||
W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
|
||||
layer.weight.data.copy_(W_res)
|
||||
|
||||
@staticmethod
|
||||
def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None:
|
||||
"""CorDA: re-orient by the full input covariance C = E[x x^T] (cross-channel)."""
|
||||
_covariance_orient(model, targets, cfg, calibration_data, diag=False)
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
layer: nn.Module,
|
||||
x: Float[T, '*B i'],
|
||||
y: Float[T, '*B o'],
|
||||
) -> Float[T, '*B o']:
|
||||
cfg = layer._lora_cfg
|
||||
U = layer.lora_U.to(x.dtype) # (d_out, r)
|
||||
S = layer.lora_S.to(x.dtype) # (r,)
|
||||
P = layer.lora_P.to(x.dtype) # (r, d_in) oblique
|
||||
g = layer.lora_g.to(x.dtype) # (r,)
|
||||
S_eff = _gain(S, g, float(cfg.coeff), bool(cfg.suppress_only))
|
||||
h = (x @ P.T) * S_eff # (..., r)
|
||||
return y + h @ U.T
|
||||
@@ -1,166 +0,0 @@
|
||||
"""AntiPaSTO-DPLR: diagonal gain plus a low-rank mixing core in the frozen SVD basis.
|
||||
|
||||
antipasto's diagonal gain rescales each singular direction but cannot mix one into
|
||||
another. DPLR adds a trainable rank-k core that does, inside the frozen U/Vh basis:
|
||||
|
||||
W = U diag(S) Vh + W_res # frozen top-r SVD
|
||||
learn: g (r,) # diagonal gain
|
||||
A (k,r) kaiming, B (r,k) zero # low-rank mixing core
|
||||
p = x @ Vh.T # (r,) input in the frozen S-basis
|
||||
S_eff = S * (1 + ELU(coeff * g))
|
||||
h = p * S_eff + coeff * (p @ A.T) @ B.T # diagonal gain + rank-k mixing
|
||||
y = x @ W_res.T + h @ U.T
|
||||
|
||||
The rank-k term is LoRA's core (Hu+ 2021, arXiv:2106.09685) restricted to W's top-r
|
||||
subspace, ADDED to the gain rather than folded into diag(S): being independent of S, a
|
||||
unit step moves W by O(1) not O(S), so it has no singular-value amplification. Params
|
||||
= r + 2*r*k. Identity at init (B=0, g=0) and at coeff=0. Basis (U, Vh) stays frozen.
|
||||
|
||||
Refs: antipasto.py (diagonal sibling), lora.py (low-rank core), antipasto_corda.py
|
||||
(oriented basis -- composes with this core).
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, Literal
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from jaxtyping import Float
|
||||
from torch import nn, Tensor as T
|
||||
|
||||
from ..variant import register, ParamSpec
|
||||
from ..config import AdapterConfig, register_config
|
||||
|
||||
CalibrationBatch = dict | tuple | list | T
|
||||
CalibrationData = Iterable[CalibrationBatch]
|
||||
|
||||
|
||||
@register_config
|
||||
@dataclass
|
||||
class AntiPaSTODPLRConfig(AdapterConfig):
|
||||
variant: str = "antipasto_dplr"
|
||||
r: int = 256
|
||||
# Rank of the low-rank mixing core (LoRA's r, but inside the frozen subspace).
|
||||
# Params = r (gain) + 2*r*lora_rank. Requires 1 <= lora_rank <= r.
|
||||
lora_rank: int = 8
|
||||
suppress_only: bool = False # clamp the gain g<=0 (attenuate only); core unaffected.
|
||||
coeff: float = 1.0 # runtime knob: 0=identity, scales gain and core.
|
||||
act_pool: Literal["rms", "mean_abs"] = "rms" # group_init selection, see antipasto.
|
||||
|
||||
|
||||
@register
|
||||
class AntiPaSTODPLR:
|
||||
name = "antipasto_dplr"
|
||||
|
||||
@staticmethod
|
||||
def param_specs(d_in, d_out, cfg):
|
||||
r, k = cfg.r, cfg.lora_rank
|
||||
if not 0 < k <= r:
|
||||
raise ValueError(f"antipasto_dplr needs 0 < lora_rank({k}) <= r({r}).")
|
||||
return dict(
|
||||
lora_U=ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True),
|
||||
lora_S=ParamSpec((r,), init="zeros", trainable=False, as_buffer=True),
|
||||
lora_Vh=ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True),
|
||||
# Diagonal gain (== antipasto). init 0 -> 1+ELU(0)=1 -> identity.
|
||||
lora_g=ParamSpec((r,), init="zeros"),
|
||||
# Low-rank core B@A in the frozen subspace. A down (r->k), B up (k->r).
|
||||
# B=0 at init -> core=0 -> identity (LoRA convention).
|
||||
lora_A=ParamSpec((k, r), init="kaiming"),
|
||||
lora_B=ParamSpec((r, k), init="zeros"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def init(layer: nn.Module, cfg) -> None:
|
||||
if type(layer) is not nn.Linear:
|
||||
raise TypeError("AntiPaSTODPLR mutates layer.weight into W_res; nn.Linear only.")
|
||||
with torch.no_grad():
|
||||
W = layer.weight.data.float()
|
||||
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
|
||||
r = cfg.r
|
||||
Ur, Sr, Vhr = U[:, :r], S[:r], Vh[:r, :]
|
||||
layer.lora_U.copy_(Ur.to(layer.lora_U.dtype))
|
||||
layer.lora_S.copy_(Sr.to(layer.lora_S.dtype))
|
||||
layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype))
|
||||
W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
|
||||
layer.weight.data.copy_(W_res)
|
||||
|
||||
@staticmethod
|
||||
def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None:
|
||||
"""Wanda-style re-selection of the top-r directions, identical to antipasto.
|
||||
Runs before training while g and B are still zero, so the core contributes
|
||||
nothing and re-selecting the basis is a no-op on the adapter output."""
|
||||
if calibration_data is None:
|
||||
return
|
||||
|
||||
layers = {name: layer for name, layer, _ in targets}
|
||||
captured: dict[str, list[T]] = {n: [] for n in layers}
|
||||
|
||||
def make_hook(name):
|
||||
def _h(module, args, kwargs):
|
||||
x = args[0].detach()
|
||||
captured[name].append(rearrange(x, "... d -> (...) d").to(torch.float32).cpu())
|
||||
return _h
|
||||
|
||||
handles = [layers[n].register_forward_pre_hook(make_hook(n), with_kwargs=True) for n in layers]
|
||||
try:
|
||||
was_training = model.training
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for batch in calibration_data:
|
||||
if isinstance(batch, dict):
|
||||
model(**batch)
|
||||
elif isinstance(batch, (list, tuple)):
|
||||
model(*batch)
|
||||
else:
|
||||
model(batch)
|
||||
if was_training:
|
||||
model.train()
|
||||
finally:
|
||||
for h in handles:
|
||||
h.remove()
|
||||
|
||||
r, pool = cfg.r, cfg.act_pool
|
||||
for name, layer in layers.items():
|
||||
X = torch.cat(captured[name], dim=0)
|
||||
if X.shape[0] < r:
|
||||
raise RuntimeError(f"AntiPaSTODPLR at {name}: {X.shape[0]} tokens, need >= r={r}")
|
||||
# Rebuild the FULL W exactly (W_res + stored top-r), then re-select top-r.
|
||||
W_res = layer.weight.data.float()
|
||||
W_orig = W_res + (layer.lora_U.float() * layer.lora_S.float()) @ layer.lora_Vh.float()
|
||||
U_full, S_full, Vh_full = torch.linalg.svd(W_orig, full_matrices=False)
|
||||
proj = X.to(Vh_full) @ Vh_full.T
|
||||
act_mag = proj.pow(2).mean(0).sqrt() if pool == "rms" else proj.abs().mean(0)
|
||||
idx = (S_full * act_mag).argsort(descending=True)[:r].sort().values
|
||||
Ur, Sr, Vhr = U_full[:, idx], S_full[idx], Vh_full[idx]
|
||||
W_res_new = (W_orig - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
|
||||
with torch.no_grad():
|
||||
layer.lora_U.copy_(Ur.to(layer.lora_U))
|
||||
layer.lora_S.copy_(Sr.to(layer.lora_S))
|
||||
layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh))
|
||||
layer.weight.data.copy_(W_res_new)
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
layer: nn.Module,
|
||||
x: Float[T, '*B i'],
|
||||
y: Float[T, '*B o'],
|
||||
) -> Float[T, '*B o']:
|
||||
cfg = layer._lora_cfg
|
||||
U = layer.lora_U.to(x.dtype) # (d_out, r)
|
||||
S = layer.lora_S.to(x.dtype) # (r,)
|
||||
Vh = layer.lora_Vh.to(x.dtype) # (r, d_in)
|
||||
g = layer.lora_g.to(x.dtype) # (r,)
|
||||
A = layer.lora_A.to(x.dtype) # (k, r)
|
||||
B = layer.lora_B.to(x.dtype) # (r, k)
|
||||
coeff = float(cfg.coeff)
|
||||
|
||||
if cfg.suppress_only:
|
||||
g = torch.clamp(g, max=0.0)
|
||||
|
||||
p = x @ Vh.T # (..., r) = Vh x (unscaled)
|
||||
S_eff = S * (1.0 + F.elu(coeff * g)) # diagonal gain (see antipasto.py)
|
||||
# Diagonal part scales each direction; low-rank part B@A mixes across the
|
||||
# subspace. Additive (not * diag(S)), so the core is S-independent: a unit
|
||||
# step in B@A moves W by O(1), not O(S) -- no S-amplification edge.
|
||||
h = p * S_eff + coeff * (p @ A.T) @ B.T # (..., r)
|
||||
return y + h @ U.T
|
||||
@@ -1,227 +0,0 @@
|
||||
"""AntiPaSTO-Rot: SVD adapter with learnable singular-value deltas + a block-diagonal
|
||||
Cayley rotation of the frozen basis. The rotation arm vs antipasto.py's gain-only core.
|
||||
|
||||
wassname 2026 https://arxiv.org/abs/2601.07473
|
||||
|
||||
W = U diag(S) Vh + W_res (top-r SVD; W_res = W - U_r S_r Vh_r)
|
||||
learn: delta_s (r,), rot_T (n_blocks, bs(bs-1)/2)
|
||||
R = block_diag(Cayley(skew(rot_T))); Vh_eff = R @ Vh (or U_eff = U @ R.T)
|
||||
y = x @ W_res.T + ((x @ Vh_eff.T) * (S + delta_s)) @ U_eff.T
|
||||
|
||||
Identity at t=0: rot_T=0 -> R=I, delta_s~4e-4 -> y ~ x @ W^T (tiny positive bias on
|
||||
delta_s breaks sign symmetry; rotation alone can't).
|
||||
|
||||
Refs:
|
||||
- paper: https://github.com/wassname/AntiPaSTO
|
||||
- lite port of: https://github.com/wassname/antipasto3
|
||||
(offline: docs/refs/antipasto3_svd_adapter.py)
|
||||
"""
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, Literal
|
||||
|
||||
import torch
|
||||
from einops import einsum, rearrange
|
||||
from jaxtyping import Float
|
||||
from torch import nn, Tensor as T
|
||||
|
||||
from ..variant import register, ParamSpec
|
||||
from ..config import AdapterConfig, register_config
|
||||
|
||||
CalibrationBatch = dict | tuple | list | T
|
||||
CalibrationData = Iterable[CalibrationBatch]
|
||||
|
||||
|
||||
@register_config
|
||||
@dataclass
|
||||
class AntiPaSTORotConfig(AdapterConfig):
|
||||
variant: str = "antipasto_rot"
|
||||
# Higher default than LoRA (r=8) since trainable params scale as r + r/bs*bs*(bs-1)/2, not r*(d_in+d_out).
|
||||
r: int = 256
|
||||
# Block size for the block-diagonal Cayley rotation. r must be divisible by it.
|
||||
block_size: int = 4
|
||||
# Cayley map saturation: bounds rotation angle to ~max_rotation_angle radians.
|
||||
max_rotation_angle: float = 0.5
|
||||
# Which singular basis to rotate: 'V' (input), 'U' (output), 'both', or 'none'.
|
||||
rotate_basis: Literal["V", "U", "both", "none"] = "V"
|
||||
|
||||
|
||||
def _cayley(skew: torch.Tensor) -> torch.Tensor:
|
||||
"""R = (I - X)^-1 (I + X) for X = skew/2; preserves orthogonality."""
|
||||
bs = skew.shape[-1]
|
||||
eye = torch.eye(bs, dtype=skew.dtype, device=skew.device).expand_as(skew)
|
||||
X = skew / 2
|
||||
return torch.linalg.solve(eye - X, eye + X)
|
||||
|
||||
|
||||
def _build_rotation(rot_T: torch.Tensor, bs: int, max_angle: float) -> torch.Tensor:
|
||||
"""rot_T: (n_blocks, bs*(bs-1)/2) -> R: (n_blocks, bs, bs) Cayley rotation."""
|
||||
n_blocks, _ = rot_T.shape
|
||||
rows, cols = torch.triu_indices(bs, bs, offset=1, device=rot_T.device).unbind(0)
|
||||
A = torch.zeros(n_blocks, bs, bs, dtype=rot_T.dtype, device=rot_T.device)
|
||||
A[:, rows, cols] = rot_T
|
||||
A = 0.5 * (A - A.transpose(-1, -2))
|
||||
a_limit = 2.0 * math.tan(max_angle / 2.0)
|
||||
A = a_limit * torch.tanh(A / a_limit)
|
||||
return _cayley(A)
|
||||
|
||||
|
||||
@register
|
||||
class AntiPaSTORot:
|
||||
name = "antipasto_rot"
|
||||
|
||||
@staticmethod
|
||||
def param_specs(d_in, d_out, cfg):
|
||||
r = cfg.r
|
||||
bs = int(cfg.block_size)
|
||||
if r % bs != 0:
|
||||
raise ValueError(f"AntiPaSTORot requires r={r} divisible by block_size={bs}")
|
||||
specs = dict(
|
||||
# Frozen SVD components captured at init.
|
||||
lora_U=ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True),
|
||||
lora_S=ParamSpec((r,), init="zeros", trainable=False, as_buffer=True),
|
||||
lora_Vh=ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True),
|
||||
# Trainable: per-singular-value delta.
|
||||
# antipasto3 uses 4e-4 + N(0, 4e-4): small positive bias breaks sign
|
||||
# symmetry (rotation alone can't); zero-init works but trains slower.
|
||||
lora_delta_s=ParamSpec((r,), init=lambda t: t.normal_(0, 4e-4).add_(4e-4)),
|
||||
)
|
||||
if cfg.rotate_basis != "none":
|
||||
n_blocks = r // bs
|
||||
n_triu = bs * (bs - 1) // 2
|
||||
specs["lora_rot_T"] = ParamSpec((n_blocks, n_triu), init="zeros")
|
||||
if cfg.rotate_basis == "both":
|
||||
# 'both' rotates V (lora_rot_T) and U independently; lora_rot_T_u is the U-side.
|
||||
specs["lora_rot_T_u"] = ParamSpec((n_blocks, n_triu), init="zeros")
|
||||
return specs
|
||||
|
||||
@staticmethod
|
||||
def init(layer: nn.Module, cfg) -> None:
|
||||
if type(layer) is not nn.Linear:
|
||||
raise TypeError(
|
||||
"AntiPaSTORot mutates layer.weight into W_res (like PiSSA), so v1 "
|
||||
"only supports plain nn.Linear, not bnb 4/8-bit."
|
||||
)
|
||||
with torch.no_grad():
|
||||
W = layer.weight.data.float()
|
||||
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
|
||||
r = cfg.r
|
||||
Ur, Sr, Vhr = U[:, :r], S[:r], Vh[:r, :]
|
||||
layer.lora_U.copy_(Ur.to(layer.lora_U.dtype))
|
||||
layer.lora_S.copy_(Sr.to(layer.lora_S.dtype))
|
||||
layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype))
|
||||
W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
|
||||
layer.weight.data.copy_(W_res)
|
||||
# group_init() refines this to input-aligned directions if calibration_data is given.
|
||||
|
||||
@staticmethod
|
||||
def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None:
|
||||
"""Wanda-style data-driven dimension selection within the weight SVD.
|
||||
|
||||
init() picks the top-r singular dimensions by S alone (PiSSA-style).
|
||||
group_init() re-selects based on S[i] * mean|X @ Vh[i]|: dimensions
|
||||
that are both large in W AND active given real inputs.
|
||||
|
||||
If calibration_data is None the weight-SVD init from init() is kept.
|
||||
"""
|
||||
if calibration_data is None:
|
||||
return
|
||||
|
||||
layers = {name: layer for name, layer, _ in targets}
|
||||
captured: dict[str, list[T]] = {n: [] for n in layers}
|
||||
|
||||
def make_hook(name):
|
||||
def _h(module, args, kwargs):
|
||||
x = args[0].detach()
|
||||
captured[name].append(rearrange(x, "... d -> (...) d").to(torch.float32).cpu())
|
||||
return _h
|
||||
|
||||
handles = [
|
||||
layers[n].register_forward_pre_hook(make_hook(n), with_kwargs=True)
|
||||
for n in layers
|
||||
]
|
||||
try:
|
||||
was_training = model.training
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for batch in calibration_data:
|
||||
if isinstance(batch, dict):
|
||||
model(**batch)
|
||||
elif isinstance(batch, (list, tuple)):
|
||||
model(*batch)
|
||||
else:
|
||||
model(batch)
|
||||
if was_training:
|
||||
model.train()
|
||||
finally:
|
||||
for h in handles:
|
||||
h.remove()
|
||||
|
||||
r = cfg.r
|
||||
for name, layer in layers.items():
|
||||
X = torch.cat(captured[name], dim=0) # (N, d_in)
|
||||
if X.shape[0] < r:
|
||||
raise RuntimeError(
|
||||
f"AntiPaSTORot at {name}: only {X.shape[0]} calibration tokens, need >= r={r}"
|
||||
)
|
||||
|
||||
# Recover W_orig: init() wrote W_res into layer.weight and stored top-r components
|
||||
W_res = layer.weight.data.float()
|
||||
U_old = layer.lora_U.float() # (d_out, r)
|
||||
S_old = layer.lora_S.float() # (r,)
|
||||
Vh_old = layer.lora_Vh.float() # (r, d_in)
|
||||
W_orig = W_res + (U_old * S_old.unsqueeze(0)) @ Vh_old
|
||||
|
||||
# Full SVD to score all dimensions
|
||||
U_full, S_full, Vh_full = torch.linalg.svd(W_orig, full_matrices=False)
|
||||
# score[i] = S[i] * mean|X @ Vh[i]| (Wanda: weight magnitude × activation magnitude)
|
||||
act_mag = (X.to(Vh_full) @ Vh_full.T).abs().mean(dim=0) # (k,) -- X captured on CPU
|
||||
scores = S_full * act_mag
|
||||
idx = scores.argsort(descending=True)[:r] # top-r by joint importance
|
||||
idx = idx.sort().values # stable ordering
|
||||
|
||||
Ur, Sr, Vhr = U_full[:, idx], S_full[idx], Vh_full[idx]
|
||||
W_res_new = (W_orig - (Ur * Sr.unsqueeze(0)) @ Vhr).to(layer.weight.dtype)
|
||||
|
||||
with torch.no_grad():
|
||||
layer.lora_U.copy_(Ur.to(layer.lora_U))
|
||||
layer.lora_S.copy_(Sr.to(layer.lora_S))
|
||||
layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh))
|
||||
layer.weight.data.copy_(W_res_new)
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
layer: nn.Module,
|
||||
x: Float[T, '*B i'],
|
||||
y: Float[T, '*B o'],
|
||||
) -> Float[T, '*B o']:
|
||||
cfg = layer._lora_cfg
|
||||
bs = int(cfg.block_size)
|
||||
max_angle = float(cfg.max_rotation_angle)
|
||||
rotate_basis = cfg.rotate_basis
|
||||
|
||||
U = layer.lora_U.to(x.dtype) # (d_out, r)
|
||||
S = layer.lora_S.to(x.dtype) # (r,)
|
||||
Vh = layer.lora_Vh.to(x.dtype) # (r, d_in)
|
||||
|
||||
if rotate_basis == "none":
|
||||
U_eff, Vh_eff = U, Vh
|
||||
else:
|
||||
R = _build_rotation(layer.lora_rot_T.float(), bs, max_angle).to(x.dtype)
|
||||
n_blocks = R.shape[0] # R: (n, bs, bs)
|
||||
U_eff, Vh_eff = U, Vh
|
||||
# 'V'/'U' rotate that one basis with lora_rot_T; 'both' rotates V with
|
||||
# lora_rot_T and U with a separate lora_rot_T_u (independent rotations).
|
||||
if rotate_basis in ("V", "both"):
|
||||
Vh_blocks = rearrange(Vh, "(n a) i -> n a i", n=n_blocks)
|
||||
Vh_eff = rearrange(einsum(R, Vh_blocks, "n a b, n b i -> n a i"), "n a i -> (n a) i")
|
||||
if rotate_basis in ("U", "both"):
|
||||
R_u = _build_rotation(layer.lora_rot_T_u.float(), bs, max_angle).to(x.dtype) if rotate_basis == "both" else R
|
||||
U_blocks = rearrange(U, "d (n b) -> d n b", n=n_blocks)
|
||||
U_eff = rearrange(einsum(U_blocks, R_u, "d n b, n c b -> d n c"), "d n c -> d (n c)")
|
||||
|
||||
S_eff = S + layer.lora_delta_s.to(x.dtype) # (r,)
|
||||
h = x @ Vh_eff.T # x @ Vh_eff.T
|
||||
h = h * S_eff # diag(S_eff)
|
||||
delta = h @ U_eff.T # @ U_eff.T
|
||||
return y + delta
|
||||
@@ -84,7 +84,7 @@ class EVA:
|
||||
with torch.no_grad():
|
||||
for batch in calibration_data:
|
||||
# Padding activations are not task-representative; mask them out of the
|
||||
# PCA so the basis reflects real tokens (matches antipasto_corda).
|
||||
# PCA so the basis reflects real tokens.
|
||||
keep.pop("mask", None)
|
||||
if isinstance(batch, dict):
|
||||
if "attention_mask" in batch:
|
||||
|
||||
@@ -33,13 +33,11 @@ SPEC.loader.exec_module(benchmark)
|
||||
|
||||
|
||||
VARIANTS = ["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva",
|
||||
"antipasto", "antipasto_rot", "antipasto_ablate", "antipasto_corda",
|
||||
"antipasto_asvd", "antipasto_dplr", "road"]
|
||||
"antipasto", "road"]
|
||||
# Variants that fail loud when attached on a bnb-loaded base (read dense weight in init).
|
||||
# delora/eva also read weight but currently silently dequant -- they produce sane attach,
|
||||
# so we don't expect a raise from them in the attach-only smoke.
|
||||
BNB_RAISERS = {"pissa", "dora", "antipasto", "antipasto_rot", "antipasto_ablate",
|
||||
"antipasto_corda", "antipasto_dplr"}
|
||||
BNB_RAISERS = {"pissa", "dora", "antipasto"}
|
||||
TINY_MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM"
|
||||
|
||||
HAS_CUDA = torch.cuda.is_available()
|
||||
@@ -61,7 +59,6 @@ def quick_cfg(variant: str, tmp_path: Path, quantization: str = "none") -> "benc
|
||||
quantization=quantization,
|
||||
r=4,
|
||||
alpha=8,
|
||||
antipasto_lora_rank=2, # antipasto_dplr needs 0 < lora_rank <= r (r=4 here)
|
||||
target_name=target_name,
|
||||
layers="all",
|
||||
steps=2,
|
||||
|
||||
Reference in New Issue
Block a user