mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 15:15:40 +08:00
cleanup: delete antipasto.py; attic 7 erase-era scripts (T1/T6)
antipasto.py (PiSSA/lora_frozen_b/old-lora2r wrappers) is dead in the live path -- train.py/extract use lora2r.py, nothing imports antipasto. Move the 7 scripts that import it or the erase-era proj fns (rescore_deploy, eval_checkpoint_curve, verify_vhack_heldout, probe_distill, diag_cosine_dist, diag_pairs_compare, tt_erase_bench) to scripts/attic/ -- they need lora2r rewrites if resurrected. Live imports verified clean. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -1,351 +0,0 @@
|
||||
"""AntiPaSTO full-rank adapter via forward hooks (lora-lite style).
|
||||
|
||||
Per spec.md: each target nn.Linear keeps its original weight intact. We attach
|
||||
frozen buffers U, Vh and a trainable delta_S of shape [r] per layer. A forward
|
||||
post-hook adds the delta contribution:
|
||||
|
||||
y_new = y + U @ (delta_S * (Vh @ x))
|
||||
|
||||
equivalent to W -> W + U diag(delta_S) Vh. At delta_S = 0 the delta is exactly
|
||||
zero, so the wrapped model is bit-identical to the base (no SVD round-trip
|
||||
error on the main path -- W stays as it was loaded). U, Vh stay frozen and
|
||||
double as the basis for v_hack gradient projection (we read delta_S.grad
|
||||
directly; no extra projection math at the gradient step).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from jaxtyping import Float
|
||||
from loguru import logger
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
def svd_cached(
|
||||
W: Float[Tensor, "d_out d_in"],
|
||||
cache_path: Path,
|
||||
device: torch.device,
|
||||
) -> tuple[Tensor, Tensor, Tensor]:
|
||||
"""SVD with disk cache. Compute on `device` in fp32, save as fp32 cpu tensors.
|
||||
|
||||
Cache key = sha256(W.cpu fp32 bytes)[:16] in filename suffix, so weight change
|
||||
invalidates the cache automatically (fail-loud, no silent stale).
|
||||
"""
|
||||
cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
W_fp32 = W.detach().to(torch.float32).cpu().contiguous()
|
||||
sha = hashlib.sha256(W_fp32.numpy().tobytes()).hexdigest()[:16]
|
||||
final = cache_path.with_suffix(f".{sha}.pt")
|
||||
if final.exists():
|
||||
d = torch.load(final, map_location="cpu", weights_only=True)
|
||||
return d["U"], d["S"], d["Vh"]
|
||||
W_gpu = W_fp32.to(device)
|
||||
U, S, Vh = torch.linalg.svd(W_gpu, full_matrices=False)
|
||||
U, S, Vh = U.cpu(), S.cpu(), Vh.cpu()
|
||||
torch.save({"U": U, "S": S, "Vh": Vh}, final)
|
||||
# debug: per-module SVD details only appear on first computation and are
|
||||
# noise on cache-hit runs (252 lines × ~150 char = ~38k chars per extract).
|
||||
logger.debug(f"SVD computed: {final.name} U={tuple(U.shape)} S0={S[0]:.3f} S-1={S[-1]:.3e}")
|
||||
return U, S, Vh
|
||||
|
||||
|
||||
TARGET_SUFFIXES = (
|
||||
# full attention
|
||||
"q_proj", "k_proj", "v_proj", "o_proj",
|
||||
# linear-attention / GatedDeltaNet
|
||||
"in_proj_qkv", "in_proj_z", "in_proj_a", "in_proj_b", "out_proj",
|
||||
# MLP
|
||||
"up_proj", "gate_proj", "down_proj",
|
||||
)
|
||||
|
||||
|
||||
def is_target(name: str) -> bool:
|
||||
return name.split(".")[-1] in TARGET_SUFFIXES
|
||||
|
||||
|
||||
def _delta_hook(layer: nn.Linear, args: tuple, y: Tensor) -> Tensor:
|
||||
"""Add the AntiPaSTO delta to y, in the frozen SVD basis:
|
||||
|
||||
y += U @ ((delta_S + delta_S_hack) * (Vh @ x))
|
||||
|
||||
delta_S = the KEPT/deployed knob; delta_S_hack = the QUARANTINE, parked with
|
||||
the routed (hack-ward) gradient and zeroed at deploy. Both diagonals are
|
||||
shape [r] in the same basis (capacity-balanced, no sink-bias) and both 0 at
|
||||
init -> identity. Routing decides per update which gradient lands in which:
|
||||
erase strips the hack-ward part (proj.py); route parks it in delta_S_hack
|
||||
by subspace projection (proj.py); routeV parks it by a per-rollout
|
||||
calibrated-tau cosine gate (train.py, post-backward).
|
||||
|
||||
For routeV's per-rollout routing (layer._antipasto_grad_probe) we splice a
|
||||
per-token gate c (init 1, forward-identity) onto the delta_S path: after
|
||||
backward c.grad = delta_S * g_b, so train.py recovers the per-rollout delta_S
|
||||
gradient, flags rollouts by cos(g_b, v_grad) vs tau, and routes the flagged
|
||||
contribution into delta_S_hack.grad. No quarantine LoRA, no forward detach.
|
||||
"""
|
||||
(x,) = args
|
||||
Vh = layer._antipasto_Vh # [r, d_in]
|
||||
U = layer._antipasto_U # [d_out, r]
|
||||
delta_S = layer._antipasto_delta_S # [r]
|
||||
delta_S_hack = layer._antipasto_delta_S_hack # [r]
|
||||
|
||||
a = torch.nn.functional.linear(x, Vh) # [..., r]
|
||||
hack = torch.nn.functional.linear(a * delta_S_hack.to(a.dtype), U) # quarantine path
|
||||
if layer._antipasto_grad_probe and torch.is_grad_enabled():
|
||||
# gate c, one entry per (token, axis) since nn.Linear flattens the batch
|
||||
# ([G*s, r]); identity at c=1 so the forward value is unchanged. After
|
||||
# backward c.grad = delta_S * g_b (per-token); train.py reshapes to
|
||||
# [G, s, r], sums each rollout's tokens, divides out delta_S to recover
|
||||
# the per-rollout g_b, and routes post-backward.
|
||||
c = torch.ones(a.shape[0], *([1] * (a.dim() - 2)), a.shape[-1],
|
||||
device=a.device, dtype=a.dtype, requires_grad=True)
|
||||
layer._antipasto_gate = c
|
||||
# Cache the activation As = Vh@x for the act_vote gate (same flattened layout as
|
||||
# the gate, so train.py reshapes both identically). Detached: read-only gate input.
|
||||
layer._antipasto_act = a.detach()
|
||||
kept = torch.nn.functional.linear((a * c) * delta_S.to(a.dtype), U)
|
||||
else:
|
||||
kept = torch.nn.functional.linear(a * delta_S.to(a.dtype), U)
|
||||
return y + (kept + hack).to(y.dtype)
|
||||
|
||||
|
||||
def _lora_hook(layer: nn.Linear, args: tuple, y: Tensor) -> Tensor:
|
||||
"""LoRA-frozen-B delta: y += B @ ((A + A_hack) @ x), with B a FROZEN random
|
||||
up-projection. The trainable is the full down-projection A [r, d_in] (plus the
|
||||
quarantine A_hack [r, d_in]); A=A_hack=0 at init -> identity.
|
||||
|
||||
Routing lives in the r-dim bottleneck h = A@x. Frozen B makes the
|
||||
error->bottleneck map g_h = B^T δ_y a STATIC linear operator -- that is the
|
||||
"static gradient path" frozen-B buys. The kept bottleneck (A@x) and the
|
||||
quarantine bottleneck (A_hack@x) both feed the same frozen B, so they receive
|
||||
the SAME upstream g_h; A.grad == A_hack.grad before routing, and routeV just
|
||||
splits that single gradient (train.py). grad_probe retains h.grad (= g_h) and
|
||||
caches x so the per-rollout split Σ_b f_b Σ_t g_h[t]⊗x[t] can be formed.
|
||||
"""
|
||||
(x,) = args
|
||||
A = layer._lora_A # [r, d_in] trainable (kept) -> info["delta_S"]
|
||||
A_hack = layer._lora_A_hack # [r, d_in] quarantine -> info["delta_S_hack"]
|
||||
B = layer._lora_B # [d_out, r] frozen
|
||||
h = torch.nn.functional.linear(x, A.to(x.dtype)) # [..., r] kept bottleneck
|
||||
h_hack = torch.nn.functional.linear(x, A_hack.to(x.dtype)) # [..., r] quarantine bottleneck
|
||||
if layer._lora_grad_probe and torch.is_grad_enabled():
|
||||
h.retain_grad() # h.grad = g_h = B^T δ_y after backward
|
||||
layer._lora_h = h
|
||||
layer._lora_x = x.detach() # per-token input for the A.grad split
|
||||
delta = torch.nn.functional.linear(h + h_hack, B.to(x.dtype)) # [..., d_out]
|
||||
return y + delta.to(y.dtype)
|
||||
|
||||
|
||||
def wrap_model_with_lora_frozen_b(
|
||||
model: nn.Module,
|
||||
model_name: str,
|
||||
r: int = 32,
|
||||
b_seed: int = 0,
|
||||
grad_probe: bool = False,
|
||||
) -> dict[str, dict]:
|
||||
"""Attach a LoRA-frozen-B adapter to every target Linear (in place).
|
||||
|
||||
Same info-dict interface as wrap_model_with_antipasto (delta_S = A, delta_S_hack
|
||||
= A_hack), so the optimizer collection, ablate_quarantine, and checkpointing work
|
||||
unchanged. ~r*d_in trainable scalars per module (vs r for AntiPaSTO) -- 10-100x
|
||||
more params; use a small r (=32) and a smaller batch if memory binds.
|
||||
|
||||
B is a fixed Haar-ish random matrix scaled 1/sqrt(r) (LoRA-standard up-proj
|
||||
magnitude), seeded by b_seed for reproducibility. No SVD, no W round-trip.
|
||||
"""
|
||||
g = torch.Generator().manual_seed(b_seed)
|
||||
targets = [(n, m) for n, m in model.named_modules()
|
||||
if isinstance(m, nn.Linear) and is_target(n)]
|
||||
logger.info(f"LoRA-frozen-B attach: {len(targets)} target Linear modules, r={r}, b_seed={b_seed}")
|
||||
out: dict[str, dict] = {}
|
||||
for name, linear in targets:
|
||||
d_out, d_in = linear.weight.shape
|
||||
dev, dtype = linear.weight.device, linear.weight.dtype
|
||||
B = (torch.randn(d_out, r, generator=g) / (r ** 0.5)).to(device=dev, dtype=dtype)
|
||||
linear.register_buffer("_lora_B", B, persistent=True)
|
||||
A = nn.Parameter(torch.zeros(r, d_in, device=dev, dtype=torch.float32)) # init 0 -> identity
|
||||
A_hack = nn.Parameter(torch.zeros(r, d_in, device=dev, dtype=torch.float32))
|
||||
linear.register_parameter("_lora_A", A)
|
||||
linear.register_parameter("_lora_A_hack", A_hack)
|
||||
linear._lora_grad_probe = grad_probe
|
||||
linear._lora_h = None
|
||||
linear._lora_x = None
|
||||
info = {"layer": linear, "delta_S": A, "delta_S_hack": A_hack,
|
||||
"handle": linear.register_forward_hook(_lora_hook), "r": r, "B": B}
|
||||
out[name] = info
|
||||
trainable = ("_lora_A", "_lora_A_hack")
|
||||
for n, p in model.named_parameters():
|
||||
if not n.endswith(trainable):
|
||||
p.requires_grad_(False)
|
||||
return out
|
||||
|
||||
|
||||
def _lora2r_hook(layer: nn.Linear, args: tuple, y: Tensor) -> Tensor:
|
||||
"""Rank-2r PiSSA-init LoRA, two blocks: deployed [:r] + quarantine [r:].
|
||||
|
||||
y += B@(A@x) - B0@(A0@x)
|
||||
|
||||
A0/B0 are FROZEN copies of the PiSSA init, so the net delta is exactly 0 at
|
||||
init while h = A@x is alive. (A zero-init would kill the c-probe weight-grad
|
||||
space below AND pair extraction at step 0 -- the whole reason for PiSSA init.)
|
||||
[B|B_q] @ ([A;A_q]@x) has no cross terms (column b_k only ever multiplies row
|
||||
a_k), so the two blocks ARE two independent adapters; block masks on this one
|
||||
tensor implement the SGTM parameter partition.
|
||||
|
||||
Block masks (layer._lora2r_mask = (m, d), set by train.py per loss pass;
|
||||
None = unmasked for generation / gate pass / eval):
|
||||
m [G] quarantine on/off -- m=0: quarantine zero in forward AND backward
|
||||
(SGTM retain trick: deployed trains in its post-ablation state)
|
||||
d [G] deployed detach -- d=1: deployed kept in forward, zero grad
|
||||
(hack-gated rollouts update ONLY the quarantine block)
|
||||
Masks act on branch OUTPUTS so a detach blocks grads to BOTH the A rows and
|
||||
the B columns of that block.
|
||||
|
||||
grad probe: c = ones[..., 2r] spliced as h*c. After backward
|
||||
c.grad = h ⊙ (Bᵀδ_y) = the per-sample WEIGHT grad of a virtual diagonal
|
||||
scale between A and B -- the lora2r analog of delta_S.grad (coincides with
|
||||
the SVD delta_S space at init, so pair extraction ports unchanged).
|
||||
"""
|
||||
(x,) = args
|
||||
A = layer._lora2r_A # [2r, d_in] trainable
|
||||
B = layer._lora2r_B # [d_out, 2r] trainable
|
||||
A0 = layer._lora2r_A0 # frozen PiSSA init copies (subtracted: net delta 0 at init)
|
||||
B0 = layer._lora2r_B0
|
||||
r = layer._lora2r_r
|
||||
h = torch.nn.functional.linear(x, A.to(x.dtype)) # [..., 2r]
|
||||
if layer._lora2r_grad_probe and torch.is_grad_enabled():
|
||||
c = torch.ones(h.shape[0], *([1] * (h.dim() - 2)), h.shape[-1],
|
||||
device=h.device, dtype=h.dtype, requires_grad=True)
|
||||
layer._lora2r_gate = c
|
||||
h = h * c
|
||||
h0 = torch.nn.functional.linear(x, A0.to(x.dtype)) # [..., 2r] frozen init path
|
||||
dep = (torch.nn.functional.linear(h[..., :r], B[:, :r].to(x.dtype))
|
||||
- torch.nn.functional.linear(h0[..., :r], B0[:, :r].to(x.dtype)))
|
||||
quar = (torch.nn.functional.linear(h[..., r:], B[:, r:].to(x.dtype))
|
||||
- torch.nn.functional.linear(h0[..., r:], B0[:, r:].to(x.dtype)))
|
||||
if layer._lora2r_mask is not None:
|
||||
m, d = layer._lora2r_mask # [G] each
|
||||
G = m.shape[0]
|
||||
shape = dep.shape # [G, s, d_out] or [G*s, d_out]
|
||||
dep = dep.reshape(G, -1, shape[-1])
|
||||
quar = quar.reshape(G, -1, shape[-1])
|
||||
d_ = d.view(G, 1, 1).to(dep.dtype)
|
||||
dep = ((1 - d_) * dep + d_ * dep.detach()).reshape(shape)
|
||||
quar = (m.view(G, 1, 1).to(quar.dtype) * quar).reshape(shape)
|
||||
return y + (dep + quar).to(y.dtype)
|
||||
|
||||
|
||||
def wrap_model_with_lora2r(
|
||||
model: nn.Module,
|
||||
model_name: str,
|
||||
cache_root: Path = Path("svd_cache"),
|
||||
svd_device: torch.device | str = "cuda",
|
||||
r: int = 32,
|
||||
grad_probe: bool = False,
|
||||
) -> dict[str, dict]:
|
||||
"""Attach a rank-2r PiSSA-init LoRA (A AND B trainable) to every target Linear.
|
||||
|
||||
PiSSA init: A0 = sqrt(S)·Vh, B0 = U·sqrt(S) on the top-2r SVD axes of W,
|
||||
ALTERNATED between the blocks (deployed even axes, quarantine odd) so the two
|
||||
blocks are spectrum-matched. W stays untouched; the hook subtracts the frozen
|
||||
A0/B0 contribution (unlike PiSSA proper, which edits W). The quarantine's
|
||||
learned delta is (A[r:], B[:, r:]) minus init; deploy ablation resets that
|
||||
block to A0/B0 (eval.ablate_quarantine).
|
||||
|
||||
Info dict per module: {layer, delta_S=A, B, A0, B0, handle, r} -- no separate
|
||||
delta_S_hack tensor; quarantine = block slices. Consumers branch on "A0".
|
||||
"""
|
||||
svd_device_t = torch.device(svd_device) if isinstance(svd_device, str) else svd_device
|
||||
svd_dir = cache_root / model_name.replace("/", "__")
|
||||
targets = [(n, m) for n, m in model.named_modules()
|
||||
if isinstance(m, nn.Linear) and is_target(n)]
|
||||
logger.info(f"lora2r attach: {len(targets)} target Linear modules, "
|
||||
f"r={r}/block (2r={2 * r}), PiSSA init, A+B trainable")
|
||||
out: dict[str, dict] = {}
|
||||
for name, linear in targets:
|
||||
W = linear.weight.data
|
||||
d_out, d_in = W.shape
|
||||
assert 2 * r <= min(d_out, d_in), \
|
||||
f"{name}: 2r={2 * r} exceeds min(d_out,d_in)={min(d_out, d_in)}; lower --lora-r"
|
||||
U, S, Vh = svd_cached(W, svd_dir / f"{name}.pt", device=svd_device_t)
|
||||
# Alternate the top-2r axes: deployed gets even ranks, quarantine odd.
|
||||
order = torch.cat([torch.arange(0, 2 * r, 2), torch.arange(1, 2 * r, 2)])
|
||||
sqrtS = S[:2 * r].sqrt()[order]
|
||||
dev = W.device
|
||||
A0 = (sqrtS.unsqueeze(1) * Vh[:2 * r][order]).to(device=dev, dtype=torch.float32) # [2r, d_in]
|
||||
B0 = (U[:, :2 * r][:, order] * sqrtS).to(device=dev, dtype=torch.float32) # [d_out, 2r]
|
||||
linear.register_buffer("_lora2r_A0", A0, persistent=True)
|
||||
linear.register_buffer("_lora2r_B0", B0, persistent=True)
|
||||
A = nn.Parameter(A0.clone())
|
||||
B = nn.Parameter(B0.clone())
|
||||
linear.register_parameter("_lora2r_A", A)
|
||||
linear.register_parameter("_lora2r_B", B)
|
||||
linear._lora2r_r = r
|
||||
linear._lora2r_grad_probe = grad_probe
|
||||
linear._lora2r_gate = None
|
||||
linear._lora2r_mask = None
|
||||
out[name] = {"layer": linear, "delta_S": A, "B": B, "A0": A0, "B0": B0,
|
||||
"handle": linear.register_forward_hook(_lora2r_hook), "r": r}
|
||||
trainable = ("_lora2r_A", "_lora2r_B")
|
||||
for n, p in model.named_parameters():
|
||||
if not n.endswith(trainable):
|
||||
p.requires_grad_(False)
|
||||
return out
|
||||
|
||||
|
||||
def wrap_model_with_antipasto(
|
||||
model: nn.Module,
|
||||
model_name: str,
|
||||
cache_root: Path = Path("svd_cache"),
|
||||
svd_device: torch.device | str = "cuda",
|
||||
grad_probe: bool = False,
|
||||
) -> dict[str, dict]:
|
||||
"""Attach AntiPaSTO hooks to every target nn.Linear in `model` (in place).
|
||||
|
||||
Returns dict[qualified_name -> dict(layer, delta_S, delta_S_hack, handle, r)].
|
||||
Frozen U/Vh stored on the layer as buffers `_antipasto_{U,Vh}` in the
|
||||
layer's native dtype. delta_S/delta_S_hack kept in fp32 (tiny, ~r per module).
|
||||
|
||||
`grad_probe` (routeV only): splice a per-token gate c into the delta_S path so
|
||||
train.py can recover the per-rollout delta_S gradient and route flagged
|
||||
rollouts into delta_S_hack post-backward. Off -> plain forward (none/erase/route).
|
||||
"""
|
||||
svd_device_t = torch.device(svd_device) if isinstance(svd_device, str) else svd_device
|
||||
safe = model_name.replace("/", "__")
|
||||
svd_dir = cache_root / safe
|
||||
|
||||
targets: list[tuple[str, nn.Linear]] = [
|
||||
(n, m) for n, m in model.named_modules()
|
||||
if isinstance(m, nn.Linear) and is_target(n)
|
||||
]
|
||||
logger.info(f"AntiPaSTO attach: {len(targets)} target Linear modules in {model_name}")
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
out: dict[str, dict] = {}
|
||||
pbar = tqdm(targets, desc="AntiPaSTO attach", mininterval=60)
|
||||
for name, linear in pbar:
|
||||
W = linear.weight.data
|
||||
d_out, d_in = W.shape
|
||||
r = min(d_in, d_out)
|
||||
cache_path = svd_dir / f"{name}.pt"
|
||||
U, S, Vh = svd_cached(W, cache_path, device=svd_device_t)
|
||||
dev, dtype = W.device, W.dtype
|
||||
linear.register_buffer("_antipasto_U", U.to(device=dev, dtype=dtype), persistent=True)
|
||||
linear.register_buffer("_antipasto_Vh", Vh.to(device=dev, dtype=dtype), persistent=True)
|
||||
delta_S = nn.Parameter(torch.zeros(r, device=dev, dtype=torch.float32))
|
||||
delta_S_hack = nn.Parameter(torch.zeros(r, device=dev, dtype=torch.float32))
|
||||
linear.register_parameter("_antipasto_delta_S", delta_S)
|
||||
linear.register_parameter("_antipasto_delta_S_hack", delta_S_hack)
|
||||
info = {"layer": linear, "delta_S": delta_S,
|
||||
"delta_S_hack": delta_S_hack, "handle": None, "r": r}
|
||||
linear._antipasto_grad_probe = grad_probe # routeV: gate the delta_S path
|
||||
linear._antipasto_gate = None # grad-probe leaf, set per forward
|
||||
info["handle"] = linear.register_forward_hook(_delta_hook)
|
||||
out[name] = info
|
||||
|
||||
# freeze everything except the two AntiPaSTO diagonals.
|
||||
trainable = ("_antipasto_delta_S", "_antipasto_delta_S_hack")
|
||||
for n, p in model.named_parameters():
|
||||
if not n.endswith(trainable):
|
||||
p.requires_grad_(False)
|
||||
return out
|
||||
Reference in New Issue
Block a user