mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 15:15:55 +08:00
Add antipasto_arrow: structured fixed-basis core (cross-direction mixing)
antipasto's diagonal core can only rescale each frozen singular direction; it can never let direction i's input drive direction j's output, yet the steered behaviour is an off-axis combination. A dense r x r core fixes that but costs r^2 params. antipasto_arrow uses the arrowhead structure instead: a dense b x b block on the top-b singular directions (full coupling where the action lives) plus a diagonal 1+ELU tail on the rest. b^2 + (r-b) params, one b x b matmul per forward -- cross-direction mixing at diagonal-core cost, no Cayley solve. Identity at init (M=0 -> B=I, g=0 -> gain=1). Verified on a Linear: rel_err 1.5e-7 at init; M[i,j] routes input dir j -> output dir i with weight exactly M[i,j] (diagonal core forces 0); 14 train params at r=8,b=3 vs r^2=64. Wired into benchmark (antipasto_block knob), smoke (block=2 for r=4), cost report, and exports. Co-Authored-By: Claudypoo <noreply@anthropic.com>
This commit is contained in:
@@ -49,7 +49,8 @@ def main() -> None:
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--model", default="Qwen/Qwen3-0.6B-Base")
|
||||
ap.add_argument("--variants", nargs="+",
|
||||
default=["lora", "antipasto", "antipasto_rot", "antipasto_corda", "antipasto_ablate"])
|
||||
default=["lora", "antipasto", "antipasto_rot", "antipasto_corda",
|
||||
"antipasto_ablate", "antipasto_arrow"])
|
||||
ap.add_argument("--target-name", nargs="+", default=[r"q_proj$", r"v_proj$"])
|
||||
ap.add_argument("--r", type=int, default=32)
|
||||
ap.add_argument("--layers", default="all",
|
||||
|
||||
@@ -37,6 +37,7 @@ CFG_BY_VARIANT = {
|
||||
"antipasto_rot": ll.AntiPaSTORotConfig,
|
||||
"antipasto_ablate": ll.AntiPaSTOAblateConfig,
|
||||
"antipasto_corda": ll.AntiPaSTOCorDAConfig,
|
||||
"antipasto_arrow": ll.AntiPaSTOArrowConfig,
|
||||
"road": ll.RoadConfig,
|
||||
}
|
||||
|
||||
@@ -46,7 +47,7 @@ class BenchmarkConfig:
|
||||
"""MetaMathQA -> GSM8K benchmark config. Tyro turns this into the CLI."""
|
||||
|
||||
model: str = "Qwen/Qwen3-0.6B-Base"
|
||||
variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "antipasto_rot", "antipasto_ablate", "antipasto_corda", "road"] = "lora"
|
||||
variant: Literal["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva", "antipasto", "antipasto_rot", "antipasto_ablate", "antipasto_corda", "antipasto_arrow", "road"] = "lora"
|
||||
mode: Literal["benchmark", "probe"] = "benchmark"
|
||||
device: str = "cuda"
|
||||
torch_dtype: str = "bfloat16"
|
||||
@@ -63,6 +64,8 @@ class BenchmarkConfig:
|
||||
antipasto_cov_orient: bool = False
|
||||
# AntiPaSTO-rot (legacy rotation variant) basis to rotate.
|
||||
antipasto_rotate_basis: Literal["V", "U", "none"] = "V"
|
||||
# AntiPaSTO-arrow: dense interaction block size on the top-b singular directions.
|
||||
antipasto_block: int = 8
|
||||
target_name: list[str] = field(default_factory=lambda: list(DEFAULT_TARGETS))
|
||||
layers: str = "all"
|
||||
train_dataset: str = "meta-math/MetaMathQA"
|
||||
@@ -143,6 +146,9 @@ def cfg_for_variant(args: BenchmarkConfig, dtype: torch.dtype) -> ll.AdapterConf
|
||||
if args.variant == "antipasto_ablate":
|
||||
extra = {"coeff": args.antipasto_coeff, "k": args.antipasto_ablate_k,
|
||||
"cov_orient": args.antipasto_cov_orient}
|
||||
if args.variant == "antipasto_arrow":
|
||||
extra = {"coeff": args.antipasto_coeff, "suppress_only": args.antipasto_suppress_only,
|
||||
"block": args.antipasto_block}
|
||||
return CFG_BY_VARIANT[args.variant](
|
||||
r=args.r,
|
||||
alpha=args.r if args.variant == "pissa" else args.alpha,
|
||||
|
||||
@@ -23,6 +23,7 @@ from .variants.antipasto import AntiPaSTOConfig
|
||||
from .variants.antipasto_rot import AntiPaSTORotConfig
|
||||
from .variants.antipasto_ablate import AntiPaSTOAblateConfig
|
||||
from .variants.antipasto_corda import AntiPaSTOCorDAConfig
|
||||
from .variants.antipasto_arrow import AntiPaSTOArrowConfig
|
||||
from .variants.road import RoadConfig
|
||||
|
||||
__all__ = [
|
||||
@@ -39,6 +40,7 @@ __all__ = [
|
||||
"AntiPaSTORotConfig",
|
||||
"AntiPaSTOAblateConfig",
|
||||
"AntiPaSTOCorDAConfig",
|
||||
"AntiPaSTOArrowConfig",
|
||||
"RoadConfig",
|
||||
"attach",
|
||||
"detach",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from . import ( # noqa: F401 side-effect: register
|
||||
lora, pissa, delora, ia3, dora, hra, eva, antipasto, road,
|
||||
antipasto_rot, antipasto_ablate, antipasto_corda,
|
||||
antipasto_rot, antipasto_ablate, antipasto_corda, antipasto_arrow,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,188 @@
|
||||
"""AntiPaSTO-Arrow: a STRUCTURED fixed-basis core, the cheap way to add cross-
|
||||
direction mixing that plain antipasto (a diagonal gain) cannot express.
|
||||
|
||||
antipasto's core is diagonal: S_eff = S * (1 + ELU(coeff*g)) reweights each frozen
|
||||
singular direction independently. It can turn a direction up or down but it can never
|
||||
let direction i's input drive direction j's output. Yet the behaviour you steer is a
|
||||
combination Sigma c_i v_i that generically lies OFF any single axis (the same argument
|
||||
that motivates antipasto_corda), so a diagonal core can only ever approximate it.
|
||||
|
||||
The obvious fix -- a full dense r x r core M, DeltaW = U M Vh -- restores all mixing but
|
||||
costs r^2 params (r=256 -> 65536, a rank-8 LoRA's worth) and an r x r matmul per forward.
|
||||
antipasto.py's own header flags this trap: "a dense r x r core is r^2 params ... add a
|
||||
*fixed-basis* core U M Vh rather than rotating". This file is that core, made cheap by
|
||||
making it STRUCTURED instead of dense -- an arrowhead, not an r x r.
|
||||
|
||||
Arrowhead structure (dense top-block + diagonal tail):
|
||||
|
||||
core C (r x r, acting on the S-scaled coords) =
|
||||
|
||||
[ B (b x b dense) | 0 ] B couples the top-b directions
|
||||
[ 0 | diag(1+ELU(c*g)) ] tail = exactly antipasto's gain
|
||||
|
||||
DeltaW = U @ C @ diag(S) @ Vh
|
||||
|
||||
The top b singular directions (largest S = where PiSSA says the action lives) get a full
|
||||
b x b interaction block B = I_b + coeff*M; the remaining r-b stay on the cheap bounded
|
||||
diagonal gain. Cost is b^2 + (r-b) params and one b x b matmul per forward -- for b=8,r=256
|
||||
that is 312 params and a 64-MAC corner, versus 65536 for dense r x r and versus the
|
||||
rotation variant's per-forward Cayley solve (measured 72ms vs 36ms). So: cross-direction
|
||||
mixing where it matters, at diagonal-core cost.
|
||||
|
||||
(We call it "arrowhead" after the shape -- a dense head with a diagonal shaft. A true
|
||||
numerical-LA arrowhead also carries a hub row+column coupling the block to the tail; that
|
||||
would add 2(r-b) params and is a one-line extension if the top-b span turns out too small.
|
||||
Not added until measured to be needed.)
|
||||
|
||||
Identity at init: M=0 -> B=I_b, g=0 -> 1+ELU(0)=1, so C=I and DeltaW = U diag(S) Vh exactly
|
||||
(up to the one-time SVD-residual rounding). coeff=0 -> C=I too (runtime off). The block is
|
||||
the linear-amplification regime of antipasto's design (a matmul, constant-gradient, no exp
|
||||
self-amplification); it is stable like 1+ELU's upper branch, not strictly bounded -- if you
|
||||
need the tail's structural can't-blow-up guarantee on the top directions too, use
|
||||
antipasto_ablate instead.
|
||||
|
||||
Refs: antipasto.py (diagonal sibling), antipasto_corda.py (the off-axis argument).
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, Literal
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from jaxtyping import Float
|
||||
from torch import nn, Tensor as T
|
||||
|
||||
from ..variant import register, ParamSpec
|
||||
from ..config import AdapterConfig, register_config
|
||||
|
||||
CalibrationBatch = dict | tuple | list | T
|
||||
CalibrationData = Iterable[CalibrationBatch]
|
||||
|
||||
|
||||
@register_config
|
||||
@dataclass
|
||||
class AntiPaSTOArrowConfig(AdapterConfig):
|
||||
variant: str = "antipasto_arrow"
|
||||
r: int = 256
|
||||
# Size of the dense interaction block on the top-b singular directions. The ONLY
|
||||
# quadratic cost (b^2 params); keep small. b=1 degenerates to antipasto.
|
||||
block: int = 8
|
||||
suppress_only: bool = False # clamp the tail g<=0 (attenuate only); block unaffected
|
||||
coeff: float = 1.0 # runtime knob: 0=identity, scales both block and tail
|
||||
act_pool: Literal["rms", "mean_abs"] = "rms" # group_init selection, see antipasto
|
||||
|
||||
|
||||
@register
|
||||
class AntiPaSTOArrow:
|
||||
name = "antipasto_arrow"
|
||||
|
||||
@staticmethod
|
||||
def param_specs(d_in, d_out, cfg):
|
||||
r, b = cfg.r, cfg.block
|
||||
if not 1 <= b < r:
|
||||
raise ValueError(f"antipasto_arrow needs 1 <= block({b}) < r({r}).")
|
||||
return dict(
|
||||
lora_U=ParamSpec((d_out, r), init="zeros", trainable=False, as_buffer=True),
|
||||
lora_S=ParamSpec((r,), init="zeros", trainable=False, as_buffer=True),
|
||||
lora_Vh=ParamSpec((r, d_in), init="zeros", trainable=False, as_buffer=True),
|
||||
# Dense b x b interaction on the top-b directions. init 0 -> B=I -> identity.
|
||||
lora_M=ParamSpec((b, b), init="zeros"),
|
||||
# Diagonal bounded gain on the remaining r-b directions (== antipasto's g).
|
||||
lora_g=ParamSpec((r - b,), init="zeros"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def init(layer: nn.Module, cfg) -> None:
|
||||
if type(layer) is not nn.Linear:
|
||||
raise TypeError("AntiPaSTOArrow mutates layer.weight into W_res; nn.Linear only.")
|
||||
with torch.no_grad():
|
||||
W = layer.weight.data.float()
|
||||
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
|
||||
r = cfg.r
|
||||
Ur, Sr, Vhr = U[:, :r], S[:r], Vh[:r, :]
|
||||
layer.lora_U.copy_(Ur.to(layer.lora_U.dtype))
|
||||
layer.lora_S.copy_(Sr.to(layer.lora_S.dtype))
|
||||
layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype))
|
||||
W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
|
||||
layer.weight.data.copy_(W_res)
|
||||
|
||||
@staticmethod
|
||||
def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None:
|
||||
"""Wanda-style data-driven dimension selection, identical to antipasto: re-pick
|
||||
the top-r directions by S[i] * pool|X @ Vh[i]|. Runs before training (g, M at
|
||||
their zero init), so re-selecting the basis is a harmless no-op on the core."""
|
||||
if calibration_data is None:
|
||||
return
|
||||
|
||||
layers = {name: layer for name, layer, _ in targets}
|
||||
captured: dict[str, list[T]] = {n: [] for n in layers}
|
||||
|
||||
def make_hook(name):
|
||||
def _h(module, args, kwargs):
|
||||
x = args[0].detach()
|
||||
captured[name].append(rearrange(x, "... d -> (...) d").to(torch.float32).cpu())
|
||||
return _h
|
||||
|
||||
handles = [layers[n].register_forward_pre_hook(make_hook(n), with_kwargs=True) for n in layers]
|
||||
try:
|
||||
was_training = model.training
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for batch in calibration_data:
|
||||
if isinstance(batch, dict):
|
||||
model(**batch)
|
||||
elif isinstance(batch, (list, tuple)):
|
||||
model(*batch)
|
||||
else:
|
||||
model(batch)
|
||||
if was_training:
|
||||
model.train()
|
||||
finally:
|
||||
for h in handles:
|
||||
h.remove()
|
||||
|
||||
r, pool = cfg.r, cfg.act_pool
|
||||
for name, layer in layers.items():
|
||||
X = torch.cat(captured[name], dim=0)
|
||||
if X.shape[0] < r:
|
||||
raise RuntimeError(f"AntiPaSTOArrow at {name}: {X.shape[0]} tokens, need >= r={r}")
|
||||
W_res = layer.weight.data.float()
|
||||
W_orig = W_res + (layer.lora_U.float() * layer.lora_S.float()) @ layer.lora_Vh.float()
|
||||
U_full, S_full, Vh_full = torch.linalg.svd(W_orig, full_matrices=False)
|
||||
proj = X.to(Vh_full) @ Vh_full.T
|
||||
act_mag = proj.pow(2).mean(0).sqrt() if pool == "rms" else proj.abs().mean(0)
|
||||
idx = (S_full * act_mag).argsort(descending=True)[:r].sort().values
|
||||
Ur, Sr, Vhr = U_full[:, idx], S_full[idx], Vh_full[idx]
|
||||
W_res_new = (W_orig - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
|
||||
with torch.no_grad():
|
||||
layer.lora_U.copy_(Ur.to(layer.lora_U))
|
||||
layer.lora_S.copy_(Sr.to(layer.lora_S))
|
||||
layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh))
|
||||
layer.weight.data.copy_(W_res_new)
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
layer: nn.Module,
|
||||
x: Float[T, '*B i'],
|
||||
y: Float[T, '*B o'],
|
||||
) -> Float[T, '*B o']:
|
||||
cfg = layer._lora_cfg
|
||||
U = layer.lora_U.to(x.dtype) # (d_out, r)
|
||||
S = layer.lora_S.to(x.dtype) # (r,)
|
||||
Vh = layer.lora_Vh.to(x.dtype) # (r, d_in)
|
||||
M = layer.lora_M.to(x.dtype) # (b, b)
|
||||
g = layer.lora_g.to(x.dtype) # (r-b,)
|
||||
coeff, b = float(cfg.coeff), cfg.block
|
||||
|
||||
cS = (x @ Vh.T) * S # (..., r) = diag(S) Vh x
|
||||
|
||||
# Top-b: dense block B = I_b + coeff*M couples the top singular directions.
|
||||
eye = torch.eye(b, dtype=x.dtype, device=x.device)
|
||||
top = cS[..., :b] @ (eye + coeff * M).T # (..., b)
|
||||
# Tail: antipasto's bounded diagonal gain (see antipasto.py for the 1+ELU why).
|
||||
if cfg.suppress_only:
|
||||
g = torch.clamp(g, max=0.0)
|
||||
tail = cS[..., b:] * (1.0 + F.elu(coeff * g)) # (..., r-b)
|
||||
|
||||
h = torch.cat([top, tail], dim=-1) # (..., r)
|
||||
return y + h @ U.T
|
||||
@@ -31,11 +31,13 @@ SPEC.loader.exec_module(benchmark)
|
||||
|
||||
|
||||
VARIANTS = ["lora", "pissa", "delora", "ia3", "ia3_ff", "dora", "hra", "eva",
|
||||
"antipasto", "antipasto_rot", "antipasto_ablate", "antipasto_corda", "road"]
|
||||
"antipasto", "antipasto_rot", "antipasto_ablate", "antipasto_corda",
|
||||
"antipasto_arrow", "road"]
|
||||
# Variants that fail loud when attached on a bnb-loaded base (read dense weight in init).
|
||||
# delora/eva also read weight but currently silently dequant -- they produce sane attach,
|
||||
# so we don't expect a raise from them in the attach-only smoke.
|
||||
BNB_RAISERS = {"pissa", "dora", "antipasto", "antipasto_rot", "antipasto_ablate", "antipasto_corda"}
|
||||
BNB_RAISERS = {"pissa", "dora", "antipasto", "antipasto_rot", "antipasto_ablate",
|
||||
"antipasto_corda", "antipasto_arrow"}
|
||||
TINY_MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM"
|
||||
|
||||
HAS_CUDA = torch.cuda.is_available()
|
||||
@@ -57,6 +59,7 @@ def quick_cfg(variant: str, tmp_path: Path, quantization: str = "none") -> "benc
|
||||
quantization=quantization,
|
||||
r=4,
|
||||
alpha=8,
|
||||
antipasto_block=2, # antipasto_arrow needs block < r (r=4 here)
|
||||
target_name=target_name,
|
||||
layers="all",
|
||||
steps=2,
|
||||
|
||||
Reference in New Issue
Block a user