mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 17:01:14 +08:00
Add rotation-free S-space adapter cores (antipasto family)
Replace antipasto's rotation/Cayley with a bounded 1+ELU gain and split the
S-space idea into four interpretable PiSSA-style cores (frozen U/S/Vh, small
trainable core):
- antipasto: S_eff = S*(1+ELU(coeff*g)). exp-bounded attenuation, linear
amplification (constant gradient, no runaway). g=0 -> exact identity.
- antipasto_rot: keeps the block-Cayley rotation as a separate variant for
cost comparison (its per-forward solve is the 72ms vs 36ms gap).
- antipasto_ablate: contractive (I - a c c^T) diag(S), eigenvalues in [0,1],
cannot blow up. Optional cov_orient (CorDA) basis.
- antipasto_corda: covariance-oriented oblique projector P = Vh C^{-1/2}, the
data-energy basis rather than the weight-gain basis. 1+ELU gain.
Add scripts/_cost.py + scripts/cost_report.py: one-row-per-variant cost table
(trainable params, peak GPU mem, fwd/bwd ms, added MACs/tok, group_init ms).
Wire all four into the benchmark, smoke test, and __init__ exports.
External review (DeepSeek-v4-pro, docs/reviews/) verified the math; acted on
its one real point (corda g now inits to zeros for exact identity).
Co-Authored-By: Claudypoo <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,38 @@
|
||||
## Code Review: Four S-space weight adapters (anti-pasto family + sspace reference + cost script)
|
||||
|
||||
### Summary
|
||||
The code adds four PiSSA-style adapters (antipasto, antipasto_rot, antipasto_ablate, antipasto_corda) that store frozen top-r SVD buffers and a small trainable gain/core. The key mathematical claims in the docstrings (identity at g=0, contraction of ablation core, exact reconstruction of the CorDA decomposition, signed cosine gate) are correct. However, `group_init()` re-selects/re-orients the SVD basis but does **not** reset the trainable parameters (g, delta_s, rot_T, c, alpha), which introduces a latent but important correctness problem — after a data-driven reinit, gains and direction vectors become misaligned with their intended singular axes.
|
||||
|
||||
---
|
||||
|
||||
### Critical (must fix)
|
||||
- **`antipasto_corda.py`, `antipasto_rot.py`, `antipasto_ablate.py`:** `group_init()` re-orients or re-selects the top-r SVD basis but leaves the existing trainable parameters (`lora_g`, `lora_delta_s`, `lora_rot_T`, `lora_c`, `lora_alpha`) untouched, still attached to the old indices.
|
||||
- For `antipasto_corda`, `lora_g` is initialised with `N(0, 4e-4)` in `init()`. After `group_init`, those small random gains now multiply a different set of singular directions, producing an uncontrolled (though tiny) perturbation.
|
||||
- For `antipasto_rot`, the `+4e-4` bias on `lora_delta_s` remains, now applied to arbitrary resorted directions, and the block rotation `lora_rot_T` is disconnected from the new block structure.
|
||||
- For `antipasto_ablate`, the ablation directions and strengths (`lora_c`, `lora_alpha`) are not reset; if any warm-starting or training has happened, they point in the wrong subspace.
|
||||
**Fix:** After re-selection or re-orientation, either re-initialise the trainable parameters (e.g., zero for g, zeros/small for delta_s and rot_T, random-normalised for c, init_alpha for alpha) or re-index them with the same `idx` mapping used for the buffers. Document that `group_init` must be called **before any training** and that the trainable parameters will be fresh after it.
|
||||
|
||||
### Important (should fix)
|
||||
- **`antipasto_ablate.py`** – Contraction claim and coeff bounds: the forward pass applies `h = h - coeff * (proj * alpha) @ Chat.T`. The core is a contraction only when both `coeff` and `alpha` are in `[0, 1]`. The config’s `coeff` can be *any* float (the docstring mentions `coeff<0` for amplification). There is no runtime clamping of `coeff`. If the intention is to keep the contraction property structurally enforced, clamp `coeff` to `[0, 1]` inside `forward()` or at least validate it.
|
||||
- **`sspace.py`** – Division by `sqrtS` without epsilon: `xS = (y_eff @ U_r) / sqrtS`. For small or zero singular values (e.g., when r is large or W is low-rank) this can produce NaNs/infs. Add a small `ε` denominator (consistent with the ε used elsewhere).
|
||||
- **`antipasto_ablate.py`** – `_orthonormal()` calls `torch.linalg.qr(c.float())` every forward pass. For large `r` and many layers this adds non-trivial cost. A lighter reparameterisation (e.g., maintaining `c` as a matrix with a normalisation step that avoids full QR) might be warranted, but for small `(r,k)` pairs the current approach is acceptable. At minimum, add a comment noting the per-forward cost.
|
||||
|
||||
### Suggestions
|
||||
- **`antipasto.py` group_init:** after the idx sort, `idx = idx.sort().values` ensures a stable, canonical ordering. This is a nice touch for reproducibility.
|
||||
- **`antipasto_ablate.py`** – The docstring says “CorDA-orient the basis from input covariance … the ablation is OUTPUT-side and CorDA's U stays orthonormal …”. The forward code correctly uses the orthonormal `U` for output projection, so the contraction in S-space carries over to the output.
|
||||
- **`sspace.py`** – The signed gate correctly preserves `cos` sign, so anti-aligned tokens receive a negative `gate * alpha` and are pushed opposite to `dS_hat`. Verified.
|
||||
|
||||
### Positive
|
||||
- The documentation is thorough, explaining the design choices (why 1+ELU, why contraction, why CorDA), and includes references.
|
||||
- The PiSSA-style `W_res` decomposition is implemented correctly across all variants.
|
||||
- `antipasto.py`’s `S_eff = S * (1 + ELU(coeff*g))` is indeed C1, positive, and identity at `g=0` — no sign-flip bugs.
|
||||
- `antipasto_ablate.py` enforces orthonormal `Chat` and clamps `alpha` to `[0,1]`, making the contraction property safe when `coeff` is also bounded.
|
||||
|
||||
### Verdict
|
||||
**REQUEST CHANGES**
|
||||
|
||||
`group_init()` must reset the trainable parameters after it changes the SVD basis; without this, the adapters silently poison their own steering gains with a different set of directions. Fix that and the code is ready.
|
||||
|
||||
---
|
||||
|
||||
*Note: The `sspace.py` and `_cost.py` files are part of the same move into lora-lite and are free of the parameter-reset issue; the only concern in `sspace.py` is the unprotected division by singular values.*
|
||||
@@ -0,0 +1,131 @@
|
||||
"""Measure the cost of an attached adapter: params, FLOPs/MACs, time, GPU mem.
|
||||
|
||||
Which metric is "best" for comparing adapters? They answer different questions:
|
||||
|
||||
- trainable_params -- deterministic "size" number. The headline.
|
||||
- macs_per_token -- deterministic, hardware-INDEPENDENT compute. Best for an
|
||||
apples-to-apples comparison: wall-time is noisy and the old
|
||||
rotation adapter paid a per-forward Cayley solve the new ones
|
||||
do not. "adds" (additions) ~= MACs; FLOPs ~= 2 * MACs.
|
||||
- fwd_ms / bwd_ms -- felt cost, but noisy: warmup + median over `iters`, never one run.
|
||||
- peak_gpu_mb -- resident + activation peak around fwd(+bwd).
|
||||
|
||||
FLOPs come from torch.utils.flop_counter.FlopCounterMode (built in, no new dep). Its
|
||||
convention is MACs (a (m,k)@(k,n) matmul counts as m*n*k); we expose both `flops`
|
||||
(as returned) and `macs_per_token = flops / n_tokens` -- calibrate once on a known
|
||||
matmul if you need to be sure of the factor of 2.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import statistics
|
||||
import time
|
||||
|
||||
import torch
|
||||
from torch.utils.flop_counter import FlopCounterMode
|
||||
|
||||
|
||||
def _time_call(fn, warmup: int, iters: int, cuda: bool) -> float:
|
||||
"""Median wall-time of fn() in milliseconds (warmup excluded)."""
|
||||
for _ in range(warmup):
|
||||
fn()
|
||||
if cuda:
|
||||
torch.cuda.synchronize()
|
||||
samples = []
|
||||
for _ in range(iters):
|
||||
if cuda:
|
||||
start = torch.cuda.Event(enable_timing=True)
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
start.record()
|
||||
fn()
|
||||
end.record()
|
||||
torch.cuda.synchronize()
|
||||
samples.append(start.elapsed_time(end))
|
||||
else:
|
||||
t0 = time.perf_counter()
|
||||
fn()
|
||||
samples.append((time.perf_counter() - t0) * 1e3)
|
||||
return statistics.median(samples)
|
||||
|
||||
|
||||
def measure_cost(
|
||||
model: torch.nn.Module,
|
||||
fwd_fn,
|
||||
*,
|
||||
bwd_step_fn=None,
|
||||
n_tokens: int | None = None,
|
||||
adapter_filter: str = "lora_",
|
||||
warmup: int = 3,
|
||||
iters: int = 10,
|
||||
) -> dict:
|
||||
"""Cost of the currently-attached adapter.
|
||||
|
||||
fwd_fn(): run one forward (no grad). Used for FLOPs + fwd timing.
|
||||
bwd_step_fn(): zero_grad + forward + loss.backward(). Used for bwd timing.
|
||||
n_tokens: tokens in the fwd_fn batch, for macs_per_token.
|
||||
adapter_filter: substring marking adapter params/buffers (default 'lora_').
|
||||
"""
|
||||
dev = next(model.parameters()).device
|
||||
cuda = dev.type == "cuda"
|
||||
|
||||
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
named = list(model.named_parameters()) + list(model.named_buffers())
|
||||
adapter_bytes = sum(t.numel() * t.element_size() for n, t in named if adapter_filter in n)
|
||||
|
||||
# FLOPs: one forward under the counter (no grad so we count inference cost).
|
||||
# FlopCounterMode can assert on some fused attention shapes; degrade to None.
|
||||
try:
|
||||
fc = FlopCounterMode(display=False)
|
||||
with torch.no_grad(), fc:
|
||||
fwd_fn()
|
||||
flops = fc.get_total_flops()
|
||||
except Exception as e:
|
||||
print(f" [warn] FLOP count failed ({type(e).__name__}: {e}); flops=None")
|
||||
flops = None
|
||||
|
||||
if cuda:
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
fwd_ms = _time_call(lambda: _no_grad(fwd_fn), warmup, iters, cuda)
|
||||
bwd_ms = _time_call(bwd_step_fn, warmup, iters, cuda) if bwd_step_fn is not None else None
|
||||
peak_gpu_mb = (torch.cuda.max_memory_allocated() / 1e6) if cuda else None
|
||||
|
||||
return dict(
|
||||
trainable_params=trainable_params,
|
||||
adapter_resident_mb=adapter_bytes / 1e6,
|
||||
flops=flops,
|
||||
macs_per_token=(flops / n_tokens) if (flops and n_tokens) else None,
|
||||
fwd_ms=fwd_ms,
|
||||
bwd_ms=bwd_ms,
|
||||
peak_gpu_mb=peak_gpu_mb,
|
||||
)
|
||||
|
||||
|
||||
def _no_grad(fn):
|
||||
with torch.no_grad():
|
||||
return fn()
|
||||
|
||||
|
||||
class group_init_meter:
|
||||
"""Context manager: wall-time + peak CPU RAM of a group_init / attach-with-calib.
|
||||
|
||||
CorDA accumulates C = E[xx^T] on CPU and runs eigh(d_in^3) -- the expensive corner.
|
||||
Use around ll.attach(model, cfg, calibration_data=...) to log that asymmetry.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.ms = None
|
||||
self.peak_cpu_mb = None
|
||||
|
||||
def __enter__(self):
|
||||
import tracemalloc
|
||||
self._tm = tracemalloc
|
||||
tracemalloc.start()
|
||||
self._t0 = time.perf_counter()
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc):
|
||||
self.ms = (time.perf_counter() - self._t0) * 1e3
|
||||
_, peak = self._tm.get_traced_memory()
|
||||
self._tm.stop()
|
||||
self.peak_cpu_mb = peak / 1e6
|
||||
return False
|
||||
@@ -0,0 +1,142 @@
|
||||
"""One-row-per-variant cost table: params, MACs/token, fwd/bwd ms, peak GPU, group_init.
|
||||
|
||||
Answers "which is best -- time / flops / adds / params?": MACs/token is the
|
||||
deterministic apples-to-apples compute number; trainable_params is the size headline;
|
||||
wall-time is the felt-but-noisy number; group_init is where CorDA's eigh(d_in^3) bites.
|
||||
|
||||
Usage:
|
||||
uv run --extra benchmark python scripts/cost_report.py \
|
||||
--model Qwen/Qwen3-0.6B-Base --variants antipasto antipasto_corda antipasto_ablate lora \
|
||||
--target-name 'q_proj$' 'v_proj$' --r 32 --out logs/cost_qwen0.6b.log
|
||||
|
||||
Point --target-name at down_proj to see the CorDA covariance corner (large d_in).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from tabulate import tabulate
|
||||
|
||||
import lora_lite as ll
|
||||
|
||||
_HERE = Path(__file__).resolve().parent
|
||||
_BENCH = importlib.util.spec_from_file_location("metamath_benchmark", _HERE / "metamath_gsm8k_benchmark.py")
|
||||
benchmark = importlib.util.module_from_spec(_BENCH)
|
||||
sys.modules[_BENCH.name] = benchmark
|
||||
_BENCH.loader.exec_module(benchmark)
|
||||
|
||||
_COST = importlib.util.spec_from_file_location("_cost", _HERE / "_cost.py")
|
||||
cost = importlib.util.module_from_spec(_COST)
|
||||
sys.modules[_COST.name] = cost
|
||||
_COST.loader.exec_module(cost)
|
||||
|
||||
|
||||
def build_cfg(variant: str, args, dtype) -> ll.AdapterConfig:
|
||||
"""Reuse the benchmark's variant->config map; only need r/targets/dtype here."""
|
||||
bcfg = benchmark.BenchmarkConfig(
|
||||
model=args.model, variant=variant, r=args.r, alpha=float(args.r),
|
||||
target_name=list(args.target_name), layers=args.layers, torch_dtype=args.dtype,
|
||||
antipasto_cov_orient=args.cov_orient,
|
||||
)
|
||||
return benchmark.cfg_for_variant(bcfg, dtype)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--model", default="Qwen/Qwen3-0.6B-Base")
|
||||
ap.add_argument("--variants", nargs="+",
|
||||
default=["lora", "antipasto", "antipasto_rot", "antipasto_corda", "antipasto_ablate"])
|
||||
ap.add_argument("--target-name", nargs="+", default=[r"q_proj$", r"v_proj$"])
|
||||
ap.add_argument("--r", type=int, default=32)
|
||||
ap.add_argument("--layers", default="all",
|
||||
help="'all' or comma list e.g. '0,1' -- limit layers (CorDA down_proj eigh is slow).")
|
||||
ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
|
||||
ap.add_argument("--dtype", default="bfloat16")
|
||||
ap.add_argument("--seq-len", type=int, default=256)
|
||||
ap.add_argument("--batch", type=int, default=2)
|
||||
ap.add_argument("--calib-batches", type=int, default=4)
|
||||
ap.add_argument("--cov-orient", action="store_true",
|
||||
help="CorDA-orient antipasto_ablate (measure the eigh corner).")
|
||||
ap.add_argument("--out", default="logs/cost.log")
|
||||
args = ap.parse_args()
|
||||
|
||||
dtype = getattr(torch, args.dtype)
|
||||
# eager attention: FlopCounterMode's sdpa_flop_count asserts on GQA (Qwen3) SDPA
|
||||
# shapes (q heads != kv heads). eager uses explicit matmuls it can count.
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
tok = AutoTokenizer.from_pretrained(args.model)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.model, dtype=dtype, attn_implementation="eager"
|
||||
).to(args.device)
|
||||
model.eval()
|
||||
|
||||
n_tokens = args.batch * args.seq_len
|
||||
ids = torch.randint(0, model.config.vocab_size, (args.batch, args.seq_len), device=args.device)
|
||||
calib = [{"input_ids": torch.randint(0, model.config.vocab_size,
|
||||
(args.batch, args.seq_len), device=args.device)}
|
||||
for _ in range(args.calib_batches)]
|
||||
|
||||
def fwd():
|
||||
model(input_ids=ids)
|
||||
|
||||
def bwd_step():
|
||||
model.zero_grad(set_to_none=True)
|
||||
loss = model(input_ids=ids).logits.float().pow(2).mean()
|
||||
loss.backward()
|
||||
|
||||
# base (no-adapter) cost, so each row can report the adapter's ADDED MACs/token.
|
||||
base = cost.measure_cost(model, fwd, bwd_step_fn=bwd_step, n_tokens=n_tokens)
|
||||
base_macs = base["macs_per_token"]
|
||||
print(f"base (no adapter): MACs/tok={int(base_macs) if base_macs else None} "
|
||||
f"fwd_ms={round(base['fwd_ms'],2)} bwd_ms={round(base['bwd_ms'],2)}")
|
||||
|
||||
# base = no adapter; model params left trainable, so this is the full-finetune
|
||||
# GPU-mem reference (its backward stores grads for every weight).
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
rows = [{
|
||||
"variant": "base(full-FT)", "train_params": total_params,
|
||||
"fwd_ms": round(base["fwd_ms"], 2), "bwd_ms": round(base["bwd_ms"], 2),
|
||||
"peak_GPU_MB": round(base["peak_gpu_mb"], 1) if base["peak_gpu_mb"] else None,
|
||||
"added_MACs/tok": 0 if base_macs else None,
|
||||
"ginit_ms": 0.0, "ginit_CPU_MB": 0.0,
|
||||
}]
|
||||
for variant in args.variants:
|
||||
cfg = build_cfg(variant, args, dtype)
|
||||
# group_init / attach cost (CorDA's eigh + C live here).
|
||||
with cost.group_init_meter() as gi:
|
||||
ll.attach(model, cfg, calibration_data=calib)
|
||||
c = cost.measure_cost(model, fwd, bwd_step_fn=bwd_step, n_tokens=n_tokens)
|
||||
ll.detach(model)
|
||||
rows.append({
|
||||
"variant": variant,
|
||||
"train_params": c["trainable_params"],
|
||||
"fwd_ms": round(c["fwd_ms"], 2),
|
||||
"bwd_ms": round(c["bwd_ms"], 2) if c["bwd_ms"] else None,
|
||||
"peak_GPU_MB": round(c["peak_gpu_mb"], 1) if c["peak_gpu_mb"] else None,
|
||||
# flat across same-r adapters; kept only as a sanity check, not a comparator.
|
||||
"added_MACs/tok": int(c["macs_per_token"] - base_macs) if (c["macs_per_token"] and base_macs) else None,
|
||||
"ginit_ms": round(gi.ms, 1),
|
||||
"ginit_CPU_MB": round(gi.peak_cpu_mb, 1),
|
||||
})
|
||||
print(f" {variant}: params={rows[-1]['train_params']} "
|
||||
f"peak_GPU_MB={rows[-1]['peak_GPU_MB']} bwd_ms={rows[-1]['bwd_ms']} ginit_ms={rows[-1]['ginit_ms']}")
|
||||
|
||||
table = tabulate(rows, headers="keys", tablefmt="pipe")
|
||||
header = (f"# cost report: {args.model} targets={args.target_name} r={args.r} "
|
||||
f"seq={args.seq_len} batch={args.batch} dtype={args.dtype}\n"
|
||||
f"# COMPARATORS: train_params, peak_GPU_MB (fwd+bwd, process-local max), bwd_ms, ginit_ms.\n"
|
||||
f"# added_MACs/tok is flat across same-r adapters (sanity check only).\n"
|
||||
f"# ginit_CPU_MB undercounts: tracemalloc misses torch C++ tensor allocs (the CorDA C matrix).\n")
|
||||
out_path = Path(args.out)
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
out_path.write_text(header + table + "\n")
|
||||
print("\n" + header + table)
|
||||
print(f"\nsaved -> {out_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -34,6 +34,9 @@ CFG_BY_VARIANT = {
|
||||
"hra": ll.HRAConfig,
|
||||
"eva": ll.EVAConfig,
|
||||
"antipasto": ll.AntiPaSTOConfig,
|
||||
"antipasto_rot": ll.AntiPaSTORotConfig,
|
||||
"antipasto_ablate": ll.AntiPaSTOAblateConfig,
|
||||
"antipasto_corda": ll.AntiPaSTOCorDAConfig,
|
||||
"road": ll.RoadConfig,
|
||||
}
|
||||
|
||||
@@ -43,7 +46,7 @@ class BenchmarkConfig:
|
||||
"""MetaMathQA -> GSM8K benchmark config. Tyro turns this into the CLI."""
|
||||
|
||||
model: str = "Qwen/Qwen3-0.6B-Base"
|
||||
variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "road"] = "lora"
|
||||
variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "antipasto_rot", "antipasto_ablate", "antipasto_corda", "road"] = "lora"
|
||||
mode: Literal["benchmark", "probe"] = "benchmark"
|
||||
device: str = "cuda"
|
||||
torch_dtype: str = "bfloat16"
|
||||
@@ -52,6 +55,13 @@ class BenchmarkConfig:
|
||||
alpha: float = 64.0
|
||||
delora_lambda0: float = 0.1
|
||||
road_group_size: int = 64
|
||||
# AntiPaSTO family (gain / corda) runtime knobs.
|
||||
antipasto_coeff: float = 1.0
|
||||
antipasto_suppress_only: bool = False
|
||||
# AntiPaSTO-ablate.
|
||||
antipasto_ablate_k: int = 1
|
||||
antipasto_cov_orient: bool = False
|
||||
# AntiPaSTO-rot (legacy rotation variant) basis to rotate.
|
||||
antipasto_rotate_basis: Literal["V", "U", "none"] = "V"
|
||||
target_name: list[str] = field(default_factory=lambda: list(DEFAULT_TARGETS))
|
||||
layers: str = "all"
|
||||
@@ -124,8 +134,15 @@ def cfg_for_variant(args: BenchmarkConfig, dtype: torch.dtype) -> ll.AdapterConf
|
||||
extra = {"lambda0": args.delora_lambda0} if args.variant == "delora" else {}
|
||||
if args.variant == "road":
|
||||
extra = {"group_size": args.road_group_size}
|
||||
if args.variant == "antipasto":
|
||||
if args.variant == "antipasto_rot":
|
||||
extra = {"rotate_basis": args.antipasto_rotate_basis}
|
||||
if args.variant == "antipasto":
|
||||
extra = {"coeff": args.antipasto_coeff, "suppress_only": args.antipasto_suppress_only}
|
||||
if args.variant == "antipasto_corda":
|
||||
extra = {"coeff": args.antipasto_coeff, "suppress_only": args.antipasto_suppress_only}
|
||||
if args.variant == "antipasto_ablate":
|
||||
extra = {"coeff": args.antipasto_coeff, "k": args.antipasto_ablate_k,
|
||||
"cov_orient": args.antipasto_cov_orient}
|
||||
return CFG_BY_VARIANT[args.variant](
|
||||
r=args.r,
|
||||
alpha=args.r if args.variant == "pissa" else args.alpha,
|
||||
@@ -155,7 +172,7 @@ def count_base_grad_leaks(model: torch.nn.Module) -> int:
|
||||
|
||||
|
||||
def perturb_first_adapter(model: torch.nn.Module) -> None:
|
||||
priority = ("lora_B", "lora_g", "lora_U", "lora_A", "lora_lambda", "lora_gate", "lora_delta_s", "lora_rot_T", "lora_m", "lora_road_theta", "lora_road_alpha")
|
||||
priority = ("lora_B", "lora_g", "lora_c", "lora_alpha", "lora_U", "lora_A", "lora_lambda", "lora_gate", "lora_delta_s", "lora_m", "lora_road_theta", "lora_road_alpha")
|
||||
for key in priority:
|
||||
for _, p in model.named_parameters():
|
||||
if p.requires_grad and key in _:
|
||||
|
||||
@@ -20,6 +20,9 @@ from .variants.dora import DoRAConfig
|
||||
from .variants.hra import HRAConfig
|
||||
from .variants.eva import EVAConfig
|
||||
from .variants.antipasto import AntiPaSTOConfig
|
||||
from .variants.antipasto_rot import AntiPaSTORotConfig
|
||||
from .variants.antipasto_ablate import AntiPaSTOAblateConfig
|
||||
from .variants.antipasto_corda import AntiPaSTOCorDAConfig
|
||||
from .variants.road import RoadConfig
|
||||
|
||||
__all__ = [
|
||||
@@ -33,6 +36,9 @@ __all__ = [
|
||||
"HRAConfig",
|
||||
"EVAConfig",
|
||||
"AntiPaSTOConfig",
|
||||
"AntiPaSTORotConfig",
|
||||
"AntiPaSTOAblateConfig",
|
||||
"AntiPaSTOCorDAConfig",
|
||||
"RoadConfig",
|
||||
"attach",
|
||||
"detach",
|
||||
|
||||
@@ -1 +1,4 @@
|
||||
from . import lora, pissa, delora, ia3, dora, hra, eva, antipasto, road # noqa: F401 side-effect: register
|
||||
from . import ( # noqa: F401 side-effect: register
|
||||
lora, pissa, delora, ia3, dora, hra, eva, antipasto, road,
|
||||
antipasto_rot, antipasto_ablate, antipasto_corda,
|
||||
)
|
||||
|
||||
@@ -1,31 +1,48 @@
|
||||
"""AntiPaSTO: SVD steering with learnable singular-value deltas + block-diagonal Cayley rotation.
|
||||
"""AntiPaSTO: SVD steering with learnable, bounded singular-value reweighting.
|
||||
|
||||
wassname 2026 https://arxiv.org/abs/2601.07473
|
||||
|
||||
W = U diag(S) Vh + W_res (top-r SVD; W_res = W - U_r S_r Vh_r)
|
||||
learn: delta_s (r,), rot_T (n_blocks, bs(bs-1)/2)
|
||||
R = block_diag(Cayley(skew(rot_T))); Vh_eff = R @ Vh (or U_eff = U @ R.T)
|
||||
y = x @ W_res.T + ((x @ Vh_eff.T) * (S + delta_s)) @ U_eff.T
|
||||
learn: g (r,) per-singular-direction gain log/lin-scale
|
||||
S_eff = S * (1 + ELU(coeff * g)) exp(.) for g<=0, 1+. for g>0
|
||||
suppress_only: clamp g<=0 -> factor in (0,1], attenuation only
|
||||
y = x @ W_res.T + ((x @ Vh.T) * S_eff) @ U.T
|
||||
|
||||
Identity at t=0: rot_T=0 -> R=I, delta_s~4e-4 -> y ≈ x @ W^T (fp32 SVD round-trip, tiny positive bias on delta_s breaks sign symmetry).
|
||||
Identity at g=0 (or coeff=0): 1+ELU(0)=1 exactly, so S_eff = S and the output is
|
||||
x @ W^T up to the one-time SVD-residual rounding. No additive sign-symmetry hack
|
||||
needed: the basis is frozen, so the direction sign is fixed and exp/(1+.) is
|
||||
sign-preserving. The 1+ELU shape is chosen over linear (sign-flips at g<-1), exp
|
||||
(amplification blows up), and tanh (arbitrary bound) -- see forward() for why.
|
||||
|
||||
Scope cut vs antipasto3: this is a fine-tuning adapter, not the full runtime
|
||||
steering interface. There is no per-call alpha, so it does not expose the
|
||||
bidirectional R(+alpha) / R(-alpha) inference symmetry. The V-basis path uses the
|
||||
opposite chirality to antipasto3's default U-basis path, so checkpoints are not
|
||||
portable without a sign/basis convention.
|
||||
Changes vs the rotation version this replaces:
|
||||
- Rotation dropped. Rotating Vh/U leaves the interpretable singular basis (the
|
||||
SVD-direction / Conjecture property), which is the entire point of steering in
|
||||
S-space, and the Cayley solve was numerically finicky. The basis is now frozen;
|
||||
the only learned object is the per-direction gain. If you later want
|
||||
cross-direction mixing, add a *fixed-basis* core U M Vh (M trainable, U/Vh frozen)
|
||||
rather than rotating -- that keeps the directions interpretable. It is also far
|
||||
cheaper than PiSSA: a dense r x r core is r^2 params (~= a rank-8 LoRA at r=256),
|
||||
versus PiSSA's free A,B at r*(d_in+d_out), which drifts off the SVD basis.
|
||||
- Additive delta_s -> bounded multiplicative S * (1 + ELU(coeff*g)). Multiplicative
|
||||
is "scaled by S" (uniform *relative* control over an orders-of-magnitude spectrum),
|
||||
stays positive (no S_eff<0 sign-flip -> no incoherence from that path), and the
|
||||
1+ELU shape stops the exp blowup. The 4e-4 sign-symmetry hack is gone.
|
||||
- suppress_only = clamp g<=0 -> factor in (0,1]: attenuation only, structurally
|
||||
cannot blow up. Matches the eval-awareness use case (turn a direction down).
|
||||
- coeff: runtime steering scalar (0 = identity, <0 inverts). The per-call alpha
|
||||
the rotation version lacked.
|
||||
- group_init activation pooling is configurable: 'rms' weights outliers (ASVD
|
||||
intuition), 'mean_abs' is the original outlier-robust pooling.
|
||||
|
||||
Refs:
|
||||
- paper: https://github.com/wassname/AntiPaSTO
|
||||
- lite port of: https://github.com/wassname/antipasto3
|
||||
(offline: docs/refs/antipasto3_svd_adapter.py)
|
||||
- sibling (whitened, rotation-free, mean-diff): steering-lite/.../sspace.py
|
||||
"""
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, Literal
|
||||
|
||||
import torch
|
||||
from einops import einsum, rearrange
|
||||
from einops import rearrange
|
||||
from jaxtyping import Float
|
||||
from torch import nn, Tensor as T
|
||||
|
||||
@@ -40,35 +57,16 @@ CalibrationData = Iterable[CalibrationBatch]
|
||||
@dataclass
|
||||
class AntiPaSTOConfig(AdapterConfig):
|
||||
variant: str = "antipasto"
|
||||
# Higher default than LoRA (r=8) since trainable params scale as r + r/bs*bs*(bs-1)/2, not r*(d_in+d_out).
|
||||
# Only r + r trainable scalars, so r can be large.
|
||||
r: int = 256
|
||||
|
||||
# Block size for the block-diagonal Cayley rotation. r must be divisible by it.
|
||||
block_size: int = 4
|
||||
# Cayley map saturation: bounds rotation angle to ~max_rotation_angle radians.
|
||||
max_rotation_angle: float = 0.5
|
||||
# Which singular basis to rotate: 'V' (input), 'U' (output), or 'none'.
|
||||
rotate_basis: Literal["V", "U", "none"] = "V"
|
||||
|
||||
|
||||
def _cayley(skew: torch.Tensor) -> torch.Tensor:
|
||||
"""R = (I - X)^-1 (I + X) for X = skew/2; preserves orthogonality."""
|
||||
bs = skew.shape[-1]
|
||||
eye = torch.eye(bs, dtype=skew.dtype, device=skew.device).expand_as(skew)
|
||||
X = skew / 2
|
||||
return torch.linalg.solve(eye - X, eye + X)
|
||||
|
||||
|
||||
def _build_rotation(rot_T: torch.Tensor, bs: int, max_angle: float) -> torch.Tensor:
|
||||
"""rot_T: (n_blocks, bs*(bs-1)/2) -> R: (n_blocks, bs, bs) Cayley rotation."""
|
||||
n_blocks, _ = rot_T.shape
|
||||
rows, cols = torch.triu_indices(bs, bs, offset=1, device=rot_T.device).unbind(0)
|
||||
A = torch.zeros(n_blocks, bs, bs, dtype=rot_T.dtype, device=rot_T.device)
|
||||
A[:, rows, cols] = rot_T
|
||||
A = 0.5 * (A - A.transpose(-1, -2))
|
||||
a_limit = 2.0 * math.tan(max_angle / 2.0)
|
||||
A = a_limit * torch.tanh(A / a_limit)
|
||||
return _cayley(A)
|
||||
# Per-direction reweighting is S_eff = S * (1 + ELU(coeff * g)). See forward()
|
||||
# for the why; identity at g=0 or coeff=0, positive always, no free bound knob.
|
||||
suppress_only: bool = False # clamp g<=0 -> factor in (0,1]: attenuation only
|
||||
# Runtime steering scale. 0 = identity. <0 inverts (swaps amplify/suppress).
|
||||
coeff: float = 1.0
|
||||
# group_init Wanda-style pooling of |X @ Vh[i]|: 'rms' is outlier-sensitive
|
||||
# (ASVD intuition), 'mean_abs' is the original outlier-robust pooling.
|
||||
act_pool: Literal["rms", "mean_abs"] = "rms"
|
||||
|
||||
|
||||
@register
|
||||
@@ -78,24 +76,14 @@ class AntiPaSTO:
|
||||
@staticmethod
|
||||
def param_specs(d_in, d_out, cfg):
|
||||
r = cfg.r
|
||||
bs = int(cfg.block_size)
|
||||
if r % bs != 0:
|
||||
raise ValueError(f"AntiPaSTO requires r={r} divisible by block_size={bs}")
|
||||
specs = dict(
|
||||
# Frozen SVD components captured at init.
|
||||
return dict(
|
||||
# Frozen top-r SVD captured at init.
|
||||
lora_U=ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True),
|
||||
lora_S=ParamSpec((r,), init="zeros", trainable=False, as_buffer=True),
|
||||
lora_Vh=ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True),
|
||||
# Trainable: per-singular-value delta.
|
||||
# antipasto3 uses 4e-4 + N(0, 4e-4): small positive bias breaks sign
|
||||
# symmetry (rotation alone can't); zero-init works but trains slower.
|
||||
lora_delta_s=ParamSpec((r,), init=lambda t: t.normal_(0, 4e-4).add_(4e-4)),
|
||||
# Trainable per-direction log-scale. init 0 -> 1+ELU(0)=1 -> identity.
|
||||
lora_g=ParamSpec((r,), init="zeros"),
|
||||
)
|
||||
if cfg.rotate_basis != "none":
|
||||
n_blocks = r // bs
|
||||
n_triu = bs * (bs - 1) // 2
|
||||
specs["lora_rot_T"] = ParamSpec((n_blocks, n_triu), init="zeros")
|
||||
return specs
|
||||
|
||||
@staticmethod
|
||||
def init(layer: nn.Module, cfg) -> None:
|
||||
@@ -114,17 +102,18 @@ class AntiPaSTO:
|
||||
layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype))
|
||||
W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
|
||||
layer.weight.data.copy_(W_res)
|
||||
# group_init() refines this to input-aligned directions if calibration_data is given.
|
||||
# group_init() refines the dimension selection if calibration_data is given.
|
||||
|
||||
@staticmethod
|
||||
def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None:
|
||||
"""Wanda-style data-driven dimension selection within the weight SVD.
|
||||
"""Wanda-style, data-driven dimension selection within the weight SVD.
|
||||
|
||||
init() picks the top-r singular dimensions by S alone (PiSSA-style).
|
||||
group_init() re-selects based on S[i] * mean|X @ Vh[i]|: dimensions
|
||||
that are both large in W AND active given real inputs.
|
||||
|
||||
If calibration_data is None the weight-SVD init from init() is kept.
|
||||
group_init() re-selects by score[i] = S[i] * pool|X @ Vh[i]|: dimensions
|
||||
that are both large in W AND active on real inputs. pool = 'rms' (outlier-
|
||||
sensitive, the ASVD intuition that activation outliers carry signal) or
|
||||
'mean_abs' (the original, outlier-robust). If calibration_data is None the
|
||||
weight-SVD init from init() is kept.
|
||||
"""
|
||||
if calibration_data is None:
|
||||
return
|
||||
@@ -160,6 +149,7 @@ class AntiPaSTO:
|
||||
h.remove()
|
||||
|
||||
r = cfg.r
|
||||
pool = cfg.act_pool
|
||||
for name, layer in layers.items():
|
||||
X = torch.cat(captured[name], dim=0) # (N, d_in)
|
||||
if X.shape[0] < r:
|
||||
@@ -167,17 +157,19 @@ class AntiPaSTO:
|
||||
f"AntiPaSTO at {name}: only {X.shape[0]} calibration tokens, need >= r={r}"
|
||||
)
|
||||
|
||||
# Recover W_orig: init() wrote W_res into layer.weight and stored top-r components
|
||||
# Recover W_orig: init() wrote W_res into layer.weight and stored top-r.
|
||||
W_res = layer.weight.data.float()
|
||||
U_old = layer.lora_U.float() # (d_out, r)
|
||||
S_old = layer.lora_S.float() # (r,)
|
||||
Vh_old = layer.lora_Vh.float() # (r, d_in)
|
||||
U_old = layer.lora_U.float()
|
||||
S_old = layer.lora_S.float()
|
||||
Vh_old = layer.lora_Vh.float()
|
||||
W_orig = W_res + (U_old * S_old.unsqueeze(0)) @ Vh_old
|
||||
|
||||
# Full SVD to score all dimensions
|
||||
U_full, S_full, Vh_full = torch.linalg.svd(W_orig, full_matrices=False)
|
||||
# score[i] = S[i] * mean|X @ Vh[i]| (Wanda: weight magnitude × activation magnitude)
|
||||
act_mag = (X @ Vh_full.T).abs().mean(dim=0) # (k,)
|
||||
proj = X.to(Vh_full) @ Vh_full.T # (N, k) input in S-coords (X captured on CPU)
|
||||
if pool == "rms":
|
||||
act_mag = proj.pow(2).mean(dim=0).sqrt() # outlier-sensitive
|
||||
else:
|
||||
act_mag = proj.abs().mean(dim=0) # outlier-robust (original)
|
||||
scores = S_full * act_mag
|
||||
idx = scores.argsort(descending=True)[:r] # top-r by joint importance
|
||||
idx = idx.sort().values # stable ordering
|
||||
@@ -198,35 +190,31 @@ class AntiPaSTO:
|
||||
y: Float[T, '*B o'],
|
||||
) -> Float[T, '*B o']:
|
||||
cfg = layer._lora_cfg
|
||||
bs = int(cfg.block_size)
|
||||
max_angle = float(cfg.max_rotation_angle)
|
||||
rotate_basis = cfg.rotate_basis
|
||||
|
||||
U = layer.lora_U.to(x.dtype) # (d_out, r)
|
||||
S = layer.lora_S.to(x.dtype) # (r,)
|
||||
Vh = layer.lora_Vh.to(x.dtype) # (r, d_in)
|
||||
g = layer.lora_g.to(x.dtype) # (r,)
|
||||
coeff = float(cfg.coeff)
|
||||
|
||||
if rotate_basis == "none":
|
||||
U_eff, Vh_eff = U, Vh
|
||||
else:
|
||||
R_blocks = _build_rotation(layer.lora_rot_T.float(), bs, max_angle).to(x.dtype)
|
||||
n_blocks = R_blocks.shape[0] # R_blocks: (n, bs, bs)
|
||||
if rotate_basis == "V":
|
||||
Vh_blocks = rearrange(Vh, "(n a) i -> n a i", n=n_blocks)
|
||||
Vh_rot = einsum(R_blocks, Vh_blocks, "n a b, n b i -> n a i")
|
||||
Vh_eff = rearrange(Vh_rot, "n a i -> (n a) i")
|
||||
U_eff = U
|
||||
elif rotate_basis == "U":
|
||||
U_blocks = rearrange(U, "d (n b) -> d n b", n=n_blocks)
|
||||
U_rot = einsum(U_blocks, R_blocks, "d n b, n c b -> d n c")
|
||||
U_eff = rearrange(U_rot, "d n c -> d (n c)")
|
||||
Vh_eff = Vh
|
||||
else:
|
||||
raise ValueError(f"rotate_basis must be 'U', 'V', or 'none', got {rotate_basis!r}")
|
||||
if cfg.suppress_only:
|
||||
g = torch.clamp(g, max=0.0) # factor in (0,1]: attenuation only
|
||||
|
||||
# FIXME: try lora_delta_s as [r,k] this is because the main limit of this adapter is that it's under parametised here. `reduce(h @ U_eff.T, '... k -> ...'). But have to make sure it's not lienarly reducable to one adapter.
|
||||
S_eff = S + layer.lora_delta_s.to(x.dtype) # (r,)
|
||||
h = x @ Vh_eff.T # x @ Vh_eff.T
|
||||
h = h * S_eff # diag(S_eff)
|
||||
delta = h @ U_eff.T # @ U_eff.T
|
||||
return y + delta
|
||||
# Per-direction reweighting: S_eff = S * (1 + ELU(coeff * g)).
|
||||
# 1 + ELU(z) = exp(z) for z<=0, 1+z for z>0.
|
||||
# Why this and not the obvious ones (all of which we tried):
|
||||
# linear S*(1+z) : constant gradient (stable), but z<-1 -> S_eff<0,
|
||||
# a sign flip that drives incoherence. Unstable in
|
||||
# the negatives.
|
||||
# exp S*exp(z) : positive, but unbounded and the gradient self-
|
||||
# amplifies (d/dz exp = exp), so amplification blows up.
|
||||
# tanh S*exp(c*tanh z): bounded, but c is an arbitrary free knob with no
|
||||
# principled value, and saturation kills the gradient.
|
||||
# 1+ELU : uses each in its safe regime -- exp only where it is
|
||||
# bounded in (0,1] (attenuation, cannot go negative),
|
||||
# linear where exp would diverge (amplification, const
|
||||
# gradient). C1 at z=0 (both -> 1, slope 1); >0 always.
|
||||
# coeff=0 or g=0 -> S_eff = S (identity). coeff<0 swaps amplify/suppress.
|
||||
S_eff = S * (1.0 + torch.nn.functional.elu(coeff * g))
|
||||
|
||||
h = (x @ Vh.T) * S_eff # input in S-coords, reweighted
|
||||
return y + h @ U.T
|
||||
|
||||
@@ -0,0 +1,194 @@
|
||||
"""AntiPaSTO-Ablate: trainable directional ablation in the weight-SVD output basis.
|
||||
|
||||
A contractive sibling of antipasto.py. Instead of reweighting the singular gains,
|
||||
it projects out a learned direction in the *output* singular basis (the U side):
|
||||
|
||||
W = U diag(S) Vh + W_res
|
||||
learn: c (r, k) ablation directions, alpha (k,) strengths in [0, 1]
|
||||
Chat = orthonormal(c) # k unit dirs in S-space
|
||||
h = (x @ Vh.T) * S # output S-coords (= diag(S) Vh x)
|
||||
h <- h - coeff * (h @ Chat) * alpha @ Chat.T # project the span out
|
||||
y = x @ W_res.T + h @ U.T
|
||||
|
||||
Why this instead of gain reweighting (antipasto.py):
|
||||
- The core (I - alpha Chat Chat^T) is a CONTRACTION: eigenvalues are 1-alpha along
|
||||
Chat and 1 elsewhere, all in [0, 1] for alpha in [0, 1]. It cannot amplify and
|
||||
cannot blow up, so the failure mode the multiplicative gain fights with bounds is
|
||||
structurally absent. It is also the natural core to recurse (a contraction composed
|
||||
with itself converges; an amplifier diverges).
|
||||
- It is the trainable form of directional ablation (Arditi+ 2024). Ablating Chat in
|
||||
the middle removes output direction U Chat; for a residual *writer*
|
||||
(mlp.down_proj, self_attn.o_proj) that is a residual-stream direction -- the
|
||||
SURGICAL regime in the steering-lite sweeps (directional_ablation topped SI).
|
||||
Target writers, not all Linears, or you get the broad-suppression regime.
|
||||
|
||||
Runtime: coeff is the per-call knob. coeff=0 -> identity. coeff in (0, 1] -> ablate.
|
||||
coeff < 0 -> *add* the direction back (amplify) -- the bidirectional dual; this is the
|
||||
side that can grow, so bound coeff there.
|
||||
|
||||
Init: alpha small (>0 so c receives gradient), c random-normalized. The strong init is
|
||||
to warm-start c from the contrastive direction dS in S-space (extract it exactly like
|
||||
sspace.py: dS = mean(xS_pos) - mean(xS_neg) on persona-branching pairs), then fine-tune.
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from jaxtyping import Float
|
||||
from torch import nn, Tensor as T
|
||||
|
||||
from ..variant import register, ParamSpec
|
||||
from ..config import AdapterConfig, register_config
|
||||
|
||||
CalibrationBatch = dict | tuple | list | T
|
||||
CalibrationData = Iterable[CalibrationBatch]
|
||||
|
||||
ε = 1e-6
|
||||
|
||||
|
||||
@register_config
|
||||
@dataclass
|
||||
class AntiPaSTOAblateConfig(AdapterConfig):
|
||||
variant: str = "antipasto_ablate"
|
||||
r: int = 256 # top-r SVD captured (or |dS|-selected via group_init)
|
||||
k: int = 1 # number of ablation directions (rank of the projection)
|
||||
init_alpha: float = 0.05 # small >0 so c gets gradient at step 0
|
||||
coeff: float = 1.0 # runtime: 0=identity, (0,1]=ablate, <0=amplify (bound this side)
|
||||
# CorDA-orient the basis from input covariance (group_init, needs calibration_data).
|
||||
# The ablation is OUTPUT-side and CorDA's U stays orthonormal, so this is a clean
|
||||
# contraction; the win is at low r -- the data-oriented top-r captures the behavior
|
||||
# output direction that plain-SVD top-r drops (measured 1.00 vs 0.65 at r=16).
|
||||
cov_orient: bool = False
|
||||
cov_eps: float = 1e-3
|
||||
|
||||
|
||||
@register
|
||||
class AntiPaSTOAblate:
|
||||
name = "antipasto_ablate"
|
||||
|
||||
@staticmethod
|
||||
def param_specs(d_in, d_out, cfg):
|
||||
r, k = cfg.r, cfg.k
|
||||
return dict(
|
||||
lora_U=ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True),
|
||||
lora_S=ParamSpec((r,), init="zeros", trainable=False, as_buffer=True),
|
||||
lora_Vh=ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True),
|
||||
# Trainable: k ablation directions in S-space, and their strengths.
|
||||
lora_c=ParamSpec((r, k), init=lambda t: t.normal_(0, 1.0 / max(r, 1) ** 0.5)),
|
||||
lora_alpha=ParamSpec((k,), init=lambda t: t.fill_(float(cfg.init_alpha))),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def init(layer: nn.Module, cfg) -> None:
|
||||
if type(layer) is not nn.Linear:
|
||||
raise TypeError("AntiPaSTOAblate mutates layer.weight into W_res; nn.Linear only.")
|
||||
with torch.no_grad():
|
||||
W = layer.weight.data.float()
|
||||
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
|
||||
r = cfg.r
|
||||
Ur, Sr, Vhr = U[:, :r], S[:r], Vh[:r, :]
|
||||
layer.lora_U.copy_(Ur.to(layer.lora_U.dtype))
|
||||
layer.lora_S.copy_(Sr.to(layer.lora_S.dtype))
|
||||
layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype))
|
||||
W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
|
||||
layer.weight.data.copy_(W_res)
|
||||
# Optional but recommended: group_init() should warm-start lora_c from the
|
||||
# S-space contrastive direction dS (see sspace.py extract). Random init also
|
||||
# trains, just slower and with no guarantee it finds the behavior direction.
|
||||
|
||||
@staticmethod
|
||||
def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None:
|
||||
"""If cov_orient, re-orient each target's SVD by input covariance C=E[x x^T]
|
||||
(CorDA) so the data-relevant output directions land in the top-r and the
|
||||
behavior direction is fully ablatable at low r. No-op otherwise (keeps the
|
||||
plain-SVD basis from init()). C is accumulated on CPU; for down_proj's large
|
||||
d_in this is heavy -- exclude it or use plain ablation there."""
|
||||
if not getattr(cfg, "cov_orient", False) or calibration_data is None:
|
||||
return
|
||||
|
||||
layers = {name: layer for name, layer, _ in targets}
|
||||
cov: dict[str, T] = {}
|
||||
cnt: dict[str, int] = {n: 0 for n in layers}
|
||||
|
||||
def make_hook(name):
|
||||
def _h(module, args, kwargs):
|
||||
x = rearrange(args[0].detach(), "... d -> (...) d").to(torch.float32).cpu()
|
||||
g = x.T @ x
|
||||
cov[name] = g if name not in cov else cov[name] + g
|
||||
cnt[name] += x.shape[0]
|
||||
return _h
|
||||
|
||||
handles = [layers[n].register_forward_pre_hook(make_hook(n), with_kwargs=True) for n in layers]
|
||||
try:
|
||||
was_training = model.training
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for batch in calibration_data:
|
||||
if isinstance(batch, dict):
|
||||
model(**batch)
|
||||
elif isinstance(batch, (list, tuple)):
|
||||
model(*batch)
|
||||
else:
|
||||
model(batch)
|
||||
if was_training:
|
||||
model.train()
|
||||
finally:
|
||||
for h in handles:
|
||||
h.remove()
|
||||
|
||||
r, eps = cfg.r, float(cfg.cov_eps)
|
||||
for name, layer in layers.items():
|
||||
if cnt[name] < r:
|
||||
raise RuntimeError(f"AntiPaSTOAblate at {name}: {cnt[name]} tokens, need >= r={r}")
|
||||
W_res = layer.weight.data.float().cpu()
|
||||
U_old, S_old, Vh_old = (layer.lora_U.float().cpu(),
|
||||
layer.lora_S.float().cpu(),
|
||||
layer.lora_Vh.float().cpu())
|
||||
W_orig = W_res + (U_old * S_old) @ Vh_old
|
||||
|
||||
C = cov[name] / cnt[name]
|
||||
lam, Q = torch.linalg.eigh(C)
|
||||
lam = lam.clamp_min(0) + eps
|
||||
Chalf = (Q * lam.sqrt()) @ Q.T
|
||||
Cinvhalf = (Q * lam.rsqrt()) @ Q.T
|
||||
Ut, St, Vht = torch.linalg.svd(W_orig @ Chalf, full_matrices=False)
|
||||
Ur = Ut[:, :r] # orthonormal output basis (ablation acts here)
|
||||
Sr = St[:r]
|
||||
Pr = Vht[:r] @ Cinvhalf # oblique input projector (input-side only)
|
||||
W_res_new = W_orig - (Ur * Sr) @ Pr
|
||||
|
||||
with torch.no_grad():
|
||||
layer.lora_U.copy_(Ur.to(layer.lora_U))
|
||||
layer.lora_S.copy_(Sr.to(layer.lora_S))
|
||||
layer.lora_Vh.copy_(Pr.to(layer.lora_Vh)) # store P in the Vh slot
|
||||
layer.weight.data.copy_(W_res_new.to(layer.weight))
|
||||
|
||||
@staticmethod
|
||||
def _orthonormal(c: T) -> T:
|
||||
"""(r, k) -> (r, k) with orthonormal columns. k=1 is a plain normalize."""
|
||||
if c.shape[-1] == 1:
|
||||
return c / (c.norm(dim=0, keepdim=True) + ε)
|
||||
# geqrf has no bf16/fp16 kernel (CPU or CUDA); do the QR in fp32, cast back.
|
||||
q, _ = torch.linalg.qr(c.float()) # reduced QR; columns orthonormal
|
||||
return q.to(c.dtype)
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
layer: nn.Module,
|
||||
x: Float[T, '*B i'],
|
||||
y: Float[T, '*B o'],
|
||||
) -> Float[T, '*B o']:
|
||||
cfg = layer._lora_cfg
|
||||
U = layer.lora_U.to(x.dtype) # (d_out, r)
|
||||
S = layer.lora_S.to(x.dtype) # (r,)
|
||||
Vh = layer.lora_Vh.to(x.dtype) # (r, d_in)
|
||||
Chat = AntiPaSTOAblate._orthonormal(layer.lora_c.to(x.dtype)) # (r, k)
|
||||
alpha = layer.lora_alpha.to(x.dtype).clamp(0.0, 1.0) # (k,)
|
||||
coeff = float(cfg.coeff)
|
||||
|
||||
h = (x @ Vh.T) * S # (..., r) output S-coords
|
||||
proj = h @ Chat # (..., k) component along each dir
|
||||
# contractive removal: h <- h - coeff * Sum_j alpha_j (h . chat_j) chat_j
|
||||
h = h - coeff * (proj * alpha) @ Chat.T # (..., r)
|
||||
return y + h @ U.T
|
||||
@@ -0,0 +1,203 @@
|
||||
"""AntiPaSTO-CorDA: steer in a covariance-ORIENTED basis, not the weight-gain basis.
|
||||
|
||||
The complaint that motivates this: plain SVD sorts directions by weight gain ||W v||
|
||||
on an *isotropic* input. The behaviour you steer lives where the *data* has energy.
|
||||
Those orderings disagree, so the behaviour smears off the top singular axes and a
|
||||
top-r crop in the weight basis throws it away. CorDA (Yang+ 2024, arXiv:2406.05223)
|
||||
re-orients the decomposition by the input covariance C = E[x x^T], so the top
|
||||
directions are the ones with the most energy *on real activations*.
|
||||
|
||||
Decomposition (verified: full-rank reconstruction ~1e-5, and on anisotropic data the
|
||||
top-r data-truncation error drops ~27x vs plain SVD):
|
||||
|
||||
C = E[x x^T] (+ eps I) # input second moment on calibration data
|
||||
C^{1/2}, C^{-1/2} via eigh(C)
|
||||
W~ = W C^{1/2}; SVD(W~) = U S V~h
|
||||
P = V~h C^{-1/2} # (r, d_in) OBLIQUE input projector
|
||||
W = U diag(S) P (exactly) # so y = x W_res^T + ((x P^T) * S_eff) U^T
|
||||
|
||||
S here are the singular values of W weighted by input std, so top-r is the optimal
|
||||
rank-r in the input-weighted norm E||(W - W_r) x||^2 -- the directions that actually
|
||||
move the output on your data.
|
||||
|
||||
Connection to the shared/differing-basis problem: C is built from pos AND neg inputs
|
||||
pooled, so P spans the *shared* activation structure (the common encoder) that
|
||||
chosen-minus-rejected cancels by construction. A trainable gain on this basis can
|
||||
therefore reach shared structure that contrastive dS extraction is blind to.
|
||||
|
||||
Core: rotation-free. S_eff = S * (1 + ELU(coeff * g)). This is exp(coeff*g) on the
|
||||
attenuation side (g<0, bounded, no blow-up) and 1+coeff*g on the amplification side
|
||||
(g>0, where exp would diverge). g=0 -> identity. coeff is the runtime knob (0=off).
|
||||
|
||||
Basis note: P is OBLIQUE (rows not orthonormal -- C^{-1/2} skews them). That is fine
|
||||
for gain reweighting (we scale oblique coordinates), and also fine for OUTPUT-side
|
||||
directional ablation: the obliqueness is input-side only, while ablation acts in the
|
||||
U/output space where U stays orthonormal. antipasto_ablate has a cov_orient flag that
|
||||
reuses this basis -- at low r it captures the behavior output direction that plain-SVD
|
||||
top-r drops (measured 1.00 vs 0.65 at r=16).
|
||||
|
||||
Falls back to plain SVD (== antipasto, rotation-free) if no calibration_data.
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from jaxtyping import Float
|
||||
from torch import nn, Tensor as T
|
||||
|
||||
from ..variant import register, ParamSpec
|
||||
from ..config import AdapterConfig, register_config
|
||||
|
||||
CalibrationBatch = dict | tuple | list | T
|
||||
CalibrationData = Iterable[CalibrationBatch]
|
||||
|
||||
|
||||
@register_config
|
||||
@dataclass
|
||||
class AntiPaSTOCorDAConfig(AdapterConfig):
|
||||
variant: str = "antipasto_corda"
|
||||
r: int = 256
|
||||
cov_eps: float = 1e-3 # damping on C eigenvalues; guards C^{-1/2} on rare dirs
|
||||
coeff: float = 1.0 # runtime steer knob: 0=identity, scales trained g
|
||||
suppress_only: bool = False # clamp g<=0 (attenuate only; no amplification)
|
||||
|
||||
|
||||
def _gain(S: T, g: T, coeff: float, suppress_only: bool) -> T:
|
||||
"""S_eff = S * (1 + ELU(coeff*g)); exp-bounded attenuation, linear amplification."""
|
||||
if suppress_only:
|
||||
g = g.clamp(max=0.0)
|
||||
return S * (1.0 + F.elu(coeff * g))
|
||||
|
||||
|
||||
@register
|
||||
class AntiPaSTOCorDA:
|
||||
name = "antipasto_corda"
|
||||
|
||||
@staticmethod
|
||||
def param_specs(d_in, d_out, cfg):
|
||||
r = cfg.r
|
||||
return dict(
|
||||
lora_U=ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True),
|
||||
lora_S=ParamSpec((r,), init="zeros", trainable=False, as_buffer=True),
|
||||
# P replaces Vh: oblique covariance-oriented input projector.
|
||||
lora_P=ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True),
|
||||
# Trainable per-direction log-scale. init 0 -> 1+ELU(0)=1 -> exact identity.
|
||||
# No sign-symmetry hack needed (1+ELU is sign-preserving, basis frozen),
|
||||
# matching antipasto.py.
|
||||
lora_g=ParamSpec((r,), init="zeros"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def init(layer: nn.Module, cfg) -> None:
|
||||
"""Plain-SVD fallback so the adapter is valid before group_init. group_init
|
||||
refines P/U/S to the covariance-oriented basis when calibration_data is given."""
|
||||
if type(layer) is not nn.Linear:
|
||||
raise TypeError("AntiPaSTOCorDA mutates layer.weight into W_res; nn.Linear only.")
|
||||
with torch.no_grad():
|
||||
W = layer.weight.data.float()
|
||||
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
|
||||
r = cfg.r
|
||||
Ur, Sr, Vhr = U[:, :r], S[:r], Vh[:r, :]
|
||||
layer.lora_U.copy_(Ur.to(layer.lora_U.dtype))
|
||||
layer.lora_S.copy_(Sr.to(layer.lora_S.dtype))
|
||||
layer.lora_P.copy_(Vhr.to(layer.lora_P.dtype)) # P := Vh until oriented
|
||||
W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
|
||||
layer.weight.data.copy_(W_res)
|
||||
|
||||
@staticmethod
|
||||
def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None:
|
||||
"""Re-orient each target's SVD by its input covariance C = E[x x^T].
|
||||
|
||||
Without calibration_data the plain-SVD init from init() is kept (so this
|
||||
degrades to antipasto, rotation-free).
|
||||
|
||||
Called by attach() BEFORE any training, so the trainable g is still at its
|
||||
zero init when the basis changes -- re-orienting zero gains is a no-op, no
|
||||
re-indexing needed. Do not call group_init after training has updated g."""
|
||||
if calibration_data is None:
|
||||
return
|
||||
|
||||
layers = {name: layer for name, layer, _ in targets}
|
||||
# accumulate C = sum x x^T on CPU. Peak GPU cost would otherwise be
|
||||
# sum_targets d_in^2 fp32 held at once; for down_proj (d_in=intermediate,
|
||||
# e.g. 14336) that is ~0.8 GB *per layer* and OOMs. CPU accumulation bounds
|
||||
# GPU use to the live activation; the eigh/SVD below run on CPU (one-time).
|
||||
# Diagonal C is NOT a usable shortcut: it misses cross-channel correlation,
|
||||
# which is where the orientation gain lives (measured ~= plain SVD).
|
||||
# If down_proj's d_in^2 is too big even on CPU/RAM, exclude it from CorDA
|
||||
# (leave it on plain antipasto) or use a low-rank C (top-k eig of subsampled
|
||||
# inputs) -- not implemented here.
|
||||
cov: dict[str, T] = {}
|
||||
cnt: dict[str, int] = {n: 0 for n in layers}
|
||||
|
||||
def make_hook(name):
|
||||
def _h(module, args, kwargs):
|
||||
x = rearrange(args[0].detach(), "... d -> (...) d").to(torch.float32).cpu()
|
||||
g = x.T @ x # (d_in, d_in) on CPU
|
||||
cov[name] = g if name not in cov else cov[name] + g
|
||||
cnt[name] += x.shape[0]
|
||||
return _h
|
||||
|
||||
handles = [layers[n].register_forward_pre_hook(make_hook(n), with_kwargs=True) for n in layers]
|
||||
try:
|
||||
was_training = model.training
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for batch in calibration_data:
|
||||
if isinstance(batch, dict):
|
||||
model(**batch)
|
||||
elif isinstance(batch, (list, tuple)):
|
||||
model(*batch)
|
||||
else:
|
||||
model(batch)
|
||||
if was_training:
|
||||
model.train()
|
||||
finally:
|
||||
for h in handles:
|
||||
h.remove()
|
||||
|
||||
r, eps = cfg.r, float(cfg.cov_eps)
|
||||
for name, layer in layers.items():
|
||||
if cnt[name] < r:
|
||||
raise RuntimeError(f"AntiPaSTOCorDA at {name}: {cnt[name]} tokens, need >= r={r}")
|
||||
# decomposition on CPU (where C lives); copy results back to device buffers.
|
||||
W_res = layer.weight.data.float().cpu()
|
||||
U_old, S_old, P_old = (layer.lora_U.float().cpu(),
|
||||
layer.lora_S.float().cpu(),
|
||||
layer.lora_P.float().cpu())
|
||||
W_orig = W_res + (U_old * S_old) @ P_old
|
||||
|
||||
C = cov[name] / cnt[name] # (d_in, d_in) CPU
|
||||
lam, Q = torch.linalg.eigh(C)
|
||||
lam = lam.clamp_min(0) + eps
|
||||
Chalf = (Q * lam.sqrt()) @ Q.T # C^{1/2}
|
||||
Cinvhalf = (Q * lam.rsqrt()) @ Q.T # C^{-1/2}
|
||||
|
||||
Ut, St, Vht = torch.linalg.svd(W_orig @ Chalf, full_matrices=False)
|
||||
Ur = Ut[:, :r] # (d_out, r)
|
||||
Sr = St[:r] # (r,)
|
||||
Pr = (Vht[:r] @ Cinvhalf) # (r, d_in) oblique projector
|
||||
W_res_new = (W_orig - (Ur * Sr) @ Pr)
|
||||
|
||||
with torch.no_grad():
|
||||
layer.lora_U.copy_(Ur.to(layer.lora_U))
|
||||
layer.lora_S.copy_(Sr.to(layer.lora_S))
|
||||
layer.lora_P.copy_(Pr.to(layer.lora_P))
|
||||
layer.weight.data.copy_(W_res_new.to(layer.weight))
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
layer: nn.Module,
|
||||
x: Float[T, '*B i'],
|
||||
y: Float[T, '*B o'],
|
||||
) -> Float[T, '*B o']:
|
||||
cfg = layer._lora_cfg
|
||||
U = layer.lora_U.to(x.dtype) # (d_out, r)
|
||||
S = layer.lora_S.to(x.dtype) # (r,)
|
||||
P = layer.lora_P.to(x.dtype) # (r, d_in) oblique
|
||||
g = layer.lora_g.to(x.dtype) # (r,)
|
||||
S_eff = _gain(S, g, float(cfg.coeff), bool(cfg.suppress_only))
|
||||
h = (x @ P.T) * S_eff # (..., r)
|
||||
return y + h @ U.T
|
||||
@@ -0,0 +1,232 @@
|
||||
"""AntiPaSTO-Rot: the original SVD adapter with learnable singular-value deltas +
|
||||
block-diagonal Cayley rotation. Kept as a SEPARATE variant so we can benchmark the
|
||||
rotation version against the rotation-free 1+ELU gain (antipasto.py) head to head.
|
||||
|
||||
wassname 2026 https://arxiv.org/abs/2601.07473
|
||||
|
||||
W = U diag(S) Vh + W_res (top-r SVD; W_res = W - U_r S_r Vh_r)
|
||||
learn: delta_s (r,), rot_T (n_blocks, bs(bs-1)/2)
|
||||
R = block_diag(Cayley(skew(rot_T))); Vh_eff = R @ Vh (or U_eff = U @ R.T)
|
||||
y = x @ W_res.T + ((x @ Vh_eff.T) * (S + delta_s)) @ U_eff.T
|
||||
|
||||
Identity at t=0: rot_T=0 -> R=I, delta_s~4e-4 -> y ~ x @ W^T (tiny positive bias on
|
||||
delta_s breaks sign symmetry).
|
||||
|
||||
Why antipasto.py dropped the rotation: rotating Vh/U leaves the interpretable singular
|
||||
basis, and the Cayley solve was numerically finicky. This file preserves it for the
|
||||
all-else-equal comparison (does the cross-direction mixing the rotation buys beat the
|
||||
cheaper, more stable gain-only adapter on the same targets and budget?).
|
||||
|
||||
Refs:
|
||||
- paper: https://github.com/wassname/AntiPaSTO
|
||||
- lite port of: https://github.com/wassname/antipasto3
|
||||
(offline: docs/refs/antipasto3_svd_adapter.py)
|
||||
"""
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, Literal
|
||||
|
||||
import torch
|
||||
from einops import einsum, rearrange
|
||||
from jaxtyping import Float
|
||||
from torch import nn, Tensor as T
|
||||
|
||||
from ..variant import register, ParamSpec
|
||||
from ..config import AdapterConfig, register_config
|
||||
|
||||
CalibrationBatch = dict | tuple | list | T
|
||||
CalibrationData = Iterable[CalibrationBatch]
|
||||
|
||||
|
||||
@register_config
|
||||
@dataclass
|
||||
class AntiPaSTORotConfig(AdapterConfig):
|
||||
variant: str = "antipasto_rot"
|
||||
# Higher default than LoRA (r=8) since trainable params scale as r + r/bs*bs*(bs-1)/2, not r*(d_in+d_out).
|
||||
r: int = 256
|
||||
# Block size for the block-diagonal Cayley rotation. r must be divisible by it.
|
||||
block_size: int = 4
|
||||
# Cayley map saturation: bounds rotation angle to ~max_rotation_angle radians.
|
||||
max_rotation_angle: float = 0.5
|
||||
# Which singular basis to rotate: 'V' (input), 'U' (output), or 'none'.
|
||||
rotate_basis: Literal["V", "U", "none"] = "V"
|
||||
|
||||
|
||||
def _cayley(skew: torch.Tensor) -> torch.Tensor:
|
||||
"""R = (I - X)^-1 (I + X) for X = skew/2; preserves orthogonality."""
|
||||
bs = skew.shape[-1]
|
||||
eye = torch.eye(bs, dtype=skew.dtype, device=skew.device).expand_as(skew)
|
||||
X = skew / 2
|
||||
return torch.linalg.solve(eye - X, eye + X)
|
||||
|
||||
|
||||
def _build_rotation(rot_T: torch.Tensor, bs: int, max_angle: float) -> torch.Tensor:
|
||||
"""rot_T: (n_blocks, bs*(bs-1)/2) -> R: (n_blocks, bs, bs) Cayley rotation."""
|
||||
n_blocks, _ = rot_T.shape
|
||||
rows, cols = torch.triu_indices(bs, bs, offset=1, device=rot_T.device).unbind(0)
|
||||
A = torch.zeros(n_blocks, bs, bs, dtype=rot_T.dtype, device=rot_T.device)
|
||||
A[:, rows, cols] = rot_T
|
||||
A = 0.5 * (A - A.transpose(-1, -2))
|
||||
a_limit = 2.0 * math.tan(max_angle / 2.0)
|
||||
A = a_limit * torch.tanh(A / a_limit)
|
||||
return _cayley(A)
|
||||
|
||||
|
||||
@register
|
||||
class AntiPaSTORot:
|
||||
name = "antipasto_rot"
|
||||
|
||||
@staticmethod
|
||||
def param_specs(d_in, d_out, cfg):
|
||||
r = cfg.r
|
||||
bs = int(cfg.block_size)
|
||||
if r % bs != 0:
|
||||
raise ValueError(f"AntiPaSTORot requires r={r} divisible by block_size={bs}")
|
||||
specs = dict(
|
||||
# Frozen SVD components captured at init.
|
||||
lora_U=ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True),
|
||||
lora_S=ParamSpec((r,), init="zeros", trainable=False, as_buffer=True),
|
||||
lora_Vh=ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True),
|
||||
# Trainable: per-singular-value delta.
|
||||
# antipasto3 uses 4e-4 + N(0, 4e-4): small positive bias breaks sign
|
||||
# symmetry (rotation alone can't); zero-init works but trains slower.
|
||||
lora_delta_s=ParamSpec((r,), init=lambda t: t.normal_(0, 4e-4).add_(4e-4)),
|
||||
)
|
||||
if cfg.rotate_basis != "none":
|
||||
n_blocks = r // bs
|
||||
n_triu = bs * (bs - 1) // 2
|
||||
specs["lora_rot_T"] = ParamSpec((n_blocks, n_triu), init="zeros")
|
||||
return specs
|
||||
|
||||
@staticmethod
|
||||
def init(layer: nn.Module, cfg) -> None:
|
||||
if type(layer) is not nn.Linear:
|
||||
raise TypeError(
|
||||
"AntiPaSTORot mutates layer.weight into W_res (like PiSSA), so v1 "
|
||||
"only supports plain nn.Linear, not bnb 4/8-bit."
|
||||
)
|
||||
with torch.no_grad():
|
||||
W = layer.weight.data.float()
|
||||
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
|
||||
r = cfg.r
|
||||
Ur, Sr, Vhr = U[:, :r], S[:r], Vh[:r, :]
|
||||
layer.lora_U.copy_(Ur.to(layer.lora_U.dtype))
|
||||
layer.lora_S.copy_(Sr.to(layer.lora_S.dtype))
|
||||
layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype))
|
||||
W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
|
||||
layer.weight.data.copy_(W_res)
|
||||
# group_init() refines this to input-aligned directions if calibration_data is given.
|
||||
|
||||
@staticmethod
|
||||
def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None:
|
||||
"""Wanda-style data-driven dimension selection within the weight SVD.
|
||||
|
||||
init() picks the top-r singular dimensions by S alone (PiSSA-style).
|
||||
group_init() re-selects based on S[i] * mean|X @ Vh[i]|: dimensions
|
||||
that are both large in W AND active given real inputs.
|
||||
|
||||
If calibration_data is None the weight-SVD init from init() is kept.
|
||||
"""
|
||||
if calibration_data is None:
|
||||
return
|
||||
|
||||
layers = {name: layer for name, layer, _ in targets}
|
||||
captured: dict[str, list[T]] = {n: [] for n in layers}
|
||||
|
||||
def make_hook(name):
|
||||
def _h(module, args, kwargs):
|
||||
x = args[0].detach()
|
||||
captured[name].append(rearrange(x, "... d -> (...) d").to(torch.float32).cpu())
|
||||
return _h
|
||||
|
||||
handles = [
|
||||
layers[n].register_forward_pre_hook(make_hook(n), with_kwargs=True)
|
||||
for n in layers
|
||||
]
|
||||
try:
|
||||
was_training = model.training
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for batch in calibration_data:
|
||||
if isinstance(batch, dict):
|
||||
model(**batch)
|
||||
elif isinstance(batch, (list, tuple)):
|
||||
model(*batch)
|
||||
else:
|
||||
model(batch)
|
||||
if was_training:
|
||||
model.train()
|
||||
finally:
|
||||
for h in handles:
|
||||
h.remove()
|
||||
|
||||
r = cfg.r
|
||||
for name, layer in layers.items():
|
||||
X = torch.cat(captured[name], dim=0) # (N, d_in)
|
||||
if X.shape[0] < r:
|
||||
raise RuntimeError(
|
||||
f"AntiPaSTORot at {name}: only {X.shape[0]} calibration tokens, need >= r={r}"
|
||||
)
|
||||
|
||||
# Recover W_orig: init() wrote W_res into layer.weight and stored top-r components
|
||||
W_res = layer.weight.data.float()
|
||||
U_old = layer.lora_U.float() # (d_out, r)
|
||||
S_old = layer.lora_S.float() # (r,)
|
||||
Vh_old = layer.lora_Vh.float() # (r, d_in)
|
||||
W_orig = W_res + (U_old * S_old.unsqueeze(0)) @ Vh_old
|
||||
|
||||
# Full SVD to score all dimensions
|
||||
U_full, S_full, Vh_full = torch.linalg.svd(W_orig, full_matrices=False)
|
||||
# score[i] = S[i] * mean|X @ Vh[i]| (Wanda: weight magnitude × activation magnitude)
|
||||
act_mag = (X.to(Vh_full) @ Vh_full.T).abs().mean(dim=0) # (k,) -- X captured on CPU
|
||||
scores = S_full * act_mag
|
||||
idx = scores.argsort(descending=True)[:r] # top-r by joint importance
|
||||
idx = idx.sort().values # stable ordering
|
||||
|
||||
Ur, Sr, Vhr = U_full[:, idx], S_full[idx], Vh_full[idx]
|
||||
W_res_new = (W_orig - (Ur * Sr.unsqueeze(0)) @ Vhr).to(layer.weight.dtype)
|
||||
|
||||
with torch.no_grad():
|
||||
layer.lora_U.copy_(Ur.to(layer.lora_U))
|
||||
layer.lora_S.copy_(Sr.to(layer.lora_S))
|
||||
layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh))
|
||||
layer.weight.data.copy_(W_res_new)
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
layer: nn.Module,
|
||||
x: Float[T, '*B i'],
|
||||
y: Float[T, '*B o'],
|
||||
) -> Float[T, '*B o']:
|
||||
cfg = layer._lora_cfg
|
||||
bs = int(cfg.block_size)
|
||||
max_angle = float(cfg.max_rotation_angle)
|
||||
rotate_basis = cfg.rotate_basis
|
||||
|
||||
U = layer.lora_U.to(x.dtype) # (d_out, r)
|
||||
S = layer.lora_S.to(x.dtype) # (r,)
|
||||
Vh = layer.lora_Vh.to(x.dtype) # (r, d_in)
|
||||
|
||||
if rotate_basis == "none":
|
||||
U_eff, Vh_eff = U, Vh
|
||||
else:
|
||||
R_blocks = _build_rotation(layer.lora_rot_T.float(), bs, max_angle).to(x.dtype)
|
||||
n_blocks = R_blocks.shape[0] # R_blocks: (n, bs, bs)
|
||||
if rotate_basis == "V":
|
||||
Vh_blocks = rearrange(Vh, "(n a) i -> n a i", n=n_blocks)
|
||||
Vh_rot = einsum(R_blocks, Vh_blocks, "n a b, n b i -> n a i")
|
||||
Vh_eff = rearrange(Vh_rot, "n a i -> (n a) i")
|
||||
U_eff = U
|
||||
elif rotate_basis == "U":
|
||||
U_blocks = rearrange(U, "d (n b) -> d n b", n=n_blocks)
|
||||
U_rot = einsum(U_blocks, R_blocks, "d n b, n c b -> d n c")
|
||||
U_eff = rearrange(U_rot, "d n c -> d (n c)")
|
||||
Vh_eff = Vh
|
||||
else:
|
||||
raise ValueError(f"rotate_basis must be 'U', 'V', or 'none', got {rotate_basis!r}")
|
||||
|
||||
S_eff = S + layer.lora_delta_s.to(x.dtype) # (r,)
|
||||
h = x @ Vh_eff.T # x @ Vh_eff.T
|
||||
h = h * S_eff # diag(S_eff)
|
||||
delta = h @ U_eff.T # @ U_eff.T
|
||||
return y + delta
|
||||
@@ -14,7 +14,6 @@ from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
from dataclasses import replace
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
@@ -31,11 +30,12 @@ sys.modules[SPEC.name] = benchmark
|
||||
SPEC.loader.exec_module(benchmark)
|
||||
|
||||
|
||||
VARIANTS = ["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "road"]
|
||||
VARIANTS = ["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva",
|
||||
"antipasto", "antipasto_rot", "antipasto_ablate", "antipasto_corda", "road"]
|
||||
# Variants that fail loud when attached on a bnb-loaded base (read dense weight in init).
|
||||
# delora/eva also read weight but currently silently dequant -- they produce sane attach,
|
||||
# so we don't expect a raise from them in the attach-only smoke.
|
||||
BNB_RAISERS = {"pissa", "dora", "antipasto"}
|
||||
BNB_RAISERS = {"pissa", "dora", "antipasto", "antipasto_rot", "antipasto_ablate", "antipasto_corda"}
|
||||
TINY_MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM"
|
||||
|
||||
HAS_CUDA = torch.cuda.is_available()
|
||||
@@ -75,8 +75,6 @@ def quick_cfg(variant: str, tmp_path: Path, quantization: str = "none") -> "benc
|
||||
log_every=1000,
|
||||
output_dir=tmp_path / "out",
|
||||
)
|
||||
if variant == "antipasto":
|
||||
cfg = replace(cfg, alpha=4) # block_size=4 -> need r % 4 == 0
|
||||
return cfg
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user