antipasto_ablate: warm-start lora_c from S-space output variance

group_init now seeds each lora_c to the top-k principal axes of the S-space
output coords h=diag(S)Vh x (highest-energy output dirs => largest loss-grad on
the ablation strength), so lora_c starts in a high-gradient region not random.
Cheap r x r second moment when not orienting; reuses Sigma xx^T when cov_orient.
Benchmark always calibrates ablate now. This is the data-variance direction, not
a contrastive behavior dir (SFT has no pos/neg split) -- noted in the docstring.

UAT: |cos(lora_c, top output-PC)| = 1.0000 vs ~0.35 chance; smoke green.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-17 18:18:32 +08:00
parent 6cb350a4b6
commit fe562c2b5c
2 changed files with 59 additions and 36 deletions
+3 -3
View File
@@ -604,9 +604,9 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
# downstream task (IPM mode, per CorDA). eva needs only a few batches for its init; # downstream task (IPM mode, per CorDA). eva needs only a few batches for its init;
# corda/asvd/cov-orient estimate an input second moment, so we hand them many more # corda/asvd/cov-orient estimate an input second moment, so we hand them many more
# batches (PEFT calibrates on a few hundred sequences) for a well-conditioned basis. # batches (PEFT calibrates on a few hundred sequences) for a well-conditioned basis.
needs_calib = args.variant in ("eva", "antipasto_corda", "antipasto_asvd") or ( # antipasto_ablate always calibrates now: group_init warm-starts lora_c from the
args.variant == "antipasto_ablate" and args.antipasto_cov_orient # S-space output variance (cov_orient adds the heavier CorDA re-orient on top).
) needs_calib = args.variant in ("eva", "antipasto_corda", "antipasto_asvd", "antipasto_ablate")
init_meter = group_init_meter() # wall-time + peak CPU RAM of group_init init_meter = group_init_meter() # wall-time + peak CPU RAM of group_init
if needs_calib: if needs_calib:
n_batches = min(4, len(batches)) if args.variant == "eva" else min(64, len(batches)) n_batches = min(4, len(batches)) if args.variant == "eva" else min(64, len(batches))
+56 -33
View File
@@ -85,29 +85,43 @@ class AntiPaSTOAblate:
layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype)) layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype))
W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype) W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
layer.weight.data.copy_(W_res) layer.weight.data.copy_(W_res)
# FIXME: lora_c is random-init. A group_init warm-start from the S-space # lora_c starts random here; group_init warm-starts it from the S-space output
# contrastive direction dS (cf. sspace.py extract) would converge faster and # variance when calibration_data is given (see group_init), else it trains from noise.
# land on the behavior direction; not implemented -- random trains, just slower.
@staticmethod @staticmethod
def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None: def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None:
"""If cov_orient, re-orient each target's SVD by input covariance C=E[x x^T] """Warm-start each lora_c from calibration activations (and, if cov_orient,
(CorDA) so the data-relevant output directions land in the top-r and the re-orient the frozen SVD by input covariance C=E[x xᵀ] first, CorDA-style).
behavior direction is fully ablatable at low r. No-op otherwise (keeps the
plain-SVD basis from init()). C is accumulated on CPU; for down_proj's large
d_in this is heavy -- exclude it or use plain ablation there."""
if not getattr(cfg, "cov_orient", False) or calibration_data is None:
return
lora_c is seeded to the top-k principal axes of the S-space OUTPUT coords
h = diag(S) Vh x over the calibration set: the highest-energy output directions,
where the loss-gradient on the ablation strength is largest, so lora_c starts in a
high-gradient region instead of a near-orthogonal random one. NOTE this is the data
VARIANCE direction, not a contrastive behavior direction -- this benchmark is SFT
with no pos/neg split. For steering with contrastive pairs, seed lora_c from
mean(h|pos) - mean(h|neg) instead (cf. steering-lite sspace extract).
Σ xxᵀ (d_in², heavy for down_proj) is only accumulated to orient; the warm-start
alone (cov_orient=False) needs just the cheap r×r second moment Σ hhᵀ."""
if calibration_data is None:
return
orient = bool(getattr(cfg, "cov_orient", False))
layers = {name: layer for name, layer, _ in targets} layers = {name: layer for name, layer, _ in targets}
cov: dict[str, T] = {} gram: dict[str, T] = {} # Σ xxᵀ (d_in²), only when orienting
mom: dict[str, T] = {} # Σ hhᵀ (r²), when not orienting (basis is fixed at init)
cnt: dict[str, int] = {n: 0 for n in layers} cnt: dict[str, int] = {n: 0 for n in layers}
def make_hook(name): def make_hook(name):
layer = layers[name]
def _h(module, args, kwargs): def _h(module, args, kwargs):
x = rearrange(args[0].detach(), "... d -> (...) d").to(torch.float32).cpu() x = rearrange(args[0].detach(), "... d -> (...) d").to(torch.float32).cpu()
g = x.T @ x if orient:
cov[name] = g if name not in cov else cov[name] + g g = x.T @ x
gram[name] = g if name not in gram else gram[name] + g
else:
h = (x @ layer.lora_Vh.float().cpu().T) * layer.lora_S.float().cpu()
m = h.T @ h
mom[name] = m if name not in mom else mom[name] + m
cnt[name] += x.shape[0] cnt[name] += x.shape[0]
return _h return _h
@@ -129,32 +143,41 @@ class AntiPaSTOAblate:
for h in handles: for h in handles:
h.remove() h.remove()
r, eps = cfg.r, float(cfg.cov_eps) r, k, eps = cfg.r, cfg.k, float(cfg.cov_eps)
for name, layer in layers.items(): for name, layer in layers.items():
if cnt[name] < r: if cnt[name] < r:
raise RuntimeError(f"AntiPaSTOAblate at {name}: {cnt[name]} tokens, need >= r={r}") raise RuntimeError(f"AntiPaSTOAblate at {name}: {cnt[name]} tokens, need >= r={r}")
W_res = layer.weight.data.float().cpu() if orient:
U_old, S_old, Vh_old = (layer.lora_U.float().cpu(), W_res = layer.weight.data.float().cpu()
layer.lora_S.float().cpu(), U_old, S_old, Vh_old = (layer.lora_U.float().cpu(),
layer.lora_Vh.float().cpu()) layer.lora_S.float().cpu(),
W_orig = W_res + (U_old * S_old) @ Vh_old layer.lora_Vh.float().cpu())
W_orig = W_res + (U_old * S_old) @ Vh_old
C = cov[name] / cnt[name] C = gram[name] / cnt[name]
lam, Q = torch.linalg.eigh(C) lam, Q = torch.linalg.eigh(C)
lam = lam.clamp_min(0) + eps lam = lam.clamp_min(0) + eps
Chalf = (Q * lam.sqrt()) @ Q.T Chalf = (Q * lam.sqrt()) @ Q.T
Cinvhalf = (Q * lam.rsqrt()) @ Q.T Cinvhalf = (Q * lam.rsqrt()) @ Q.T
Ut, St, Vht = torch.linalg.svd(W_orig @ Chalf, full_matrices=False) Ut, St, Vht = torch.linalg.svd(W_orig @ Chalf, full_matrices=False)
Ur = Ut[:, :r] # orthonormal output basis (ablation acts here) Ur = Ut[:, :r] # orthonormal output basis (ablation acts here)
Sr = St[:r] Sr = St[:r]
Pr = Vht[:r] @ Cinvhalf # oblique input projector (input-side only) Pr = Vht[:r] @ Cinvhalf # oblique input projector (input-side only)
W_res_new = W_orig - (Ur * Sr) @ Pr W_res_new = W_orig - (Ur * Sr) @ Pr
with torch.no_grad():
layer.lora_U.copy_(Ur.to(layer.lora_U))
layer.lora_S.copy_(Sr.to(layer.lora_S))
layer.lora_Vh.copy_(Pr.to(layer.lora_Vh)) # store P in the Vh slot
layer.weight.data.copy_(W_res_new.to(layer.weight))
# output S-space second moment in the (now oriented) basis: diag(S) P Σxxᵀ Pᵀ diag(S)
SP = Sr[:, None] * Pr
M = SP @ gram[name] @ SP.T
else:
M = mom[name] # (r, r) Σ hhᵀ in the init basis
c0 = torch.linalg.eigh(M).eigenvectors[:, -k:] # top-k principal dirs (orthonormal)
with torch.no_grad(): with torch.no_grad():
layer.lora_U.copy_(Ur.to(layer.lora_U)) layer.lora_c.copy_(c0.to(layer.lora_c))
layer.lora_S.copy_(Sr.to(layer.lora_S))
layer.lora_Vh.copy_(Pr.to(layer.lora_Vh)) # store P in the Vh slot
layer.weight.data.copy_(W_res_new.to(layer.weight))
@staticmethod @staticmethod
def _orthonormal(c: T) -> T: def _orthonormal(c: T) -> T: