Revert ablate lora_c warm-start: variance-PC seed didn't help on SFT

Job 94 result (Qwen3.5-0.8B, GSM8K, 2500 steps, single seed):
  warm-start (top-k S-space output-variance PC):  test 55.6 / valid 64.0, init 33.2s
  random-init (prior default):                    test 56.0 / valid 68.0, init  2.2s

Equal-or-worse accuracy (within single-seed noise) for +31s of calibration init.
The optimal ablation direction is loss-defined, not variance-defined, so seeding
lora_c from the data-variance PC buys nothing here. Reverts fe562c2; ablate is
back to the cheap random-init default. cov_orient (CorDA re-orient) path kept.
The FIXME's actual proposal -- a *contrastive* dS seed -- stays open but needs
pos/neg pairs this SFT benchmark lacks (only relevant for labelled steering).

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-17 20:18:41 +08:00
parent 458c3861e8
commit 09dcfe0d41
2 changed files with 40 additions and 59 deletions
+3 -3
View File
@@ -604,9 +604,9 @@ def run(args: BenchmarkConfig) -> dict[str, Any]:
# downstream task (IPM mode, per CorDA). eva needs only a few batches for its init;
# corda/asvd/cov-orient estimate an input second moment, so we hand them many more
# batches (PEFT calibrates on a few hundred sequences) for a well-conditioned basis.
# antipasto_ablate always calibrates now: group_init warm-starts lora_c from the
# S-space output variance (cov_orient adds the heavier CorDA re-orient on top).
needs_calib = args.variant in ("eva", "antipasto_corda", "antipasto_asvd", "antipasto_ablate")
needs_calib = args.variant in ("eva", "antipasto_corda", "antipasto_asvd") or (
args.variant == "antipasto_ablate" and args.antipasto_cov_orient
)
init_meter = group_init_meter() # wall-time + peak CPU RAM of group_init
if needs_calib:
n_batches = min(4, len(batches)) if args.variant == "eva" else min(64, len(batches))
+37 -56
View File
@@ -85,43 +85,33 @@ class AntiPaSTOAblate:
layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype))
W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype)
layer.weight.data.copy_(W_res)
# lora_c starts random here; group_init warm-starts it from the S-space output
# variance when calibration_data is given (see group_init), else it trains from noise.
# lora_c is random-init. Tried (job 94) seeding it from the top-k S-space
# output-VARIANCE PC: equal-or-worse on GSM8K (55.6/64.0 vs random 56.0/68.0,
# single seed) and +31s calib init -- the optimal ablation dir is loss-defined,
# not variance-defined, so a variance seed buys nothing on SFT. Reverted.
# FIXME the contrastive dS seed (mean(h|pos)-mean(h|neg), cf. sspace.py) is the
# one that should land on the behavior dir, but it needs pos/neg pairs this SFT
# benchmark lacks -- only worth it for steering with labelled contrastive data.
@staticmethod
def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None:
"""Warm-start each lora_c from calibration activations (and, if cov_orient,
re-orient the frozen SVD by input covariance C=E[x xᵀ] first, CorDA-style).
lora_c is seeded to the top-k principal axes of the S-space OUTPUT coords
h = diag(S) Vh x over the calibration set: the highest-energy output directions,
where the loss-gradient on the ablation strength is largest, so lora_c starts in a
high-gradient region instead of a near-orthogonal random one. NOTE this is the data
VARIANCE direction, not a contrastive behavior direction -- this benchmark is SFT
with no pos/neg split. For steering with contrastive pairs, seed lora_c from
mean(h|pos) - mean(h|neg) instead (cf. steering-lite sspace extract).
Σ xxᵀ (d_in², heavy for down_proj) is only accumulated to orient; the warm-start
alone (cov_orient=False) needs just the cheap r×r second moment Σ hhᵀ."""
if calibration_data is None:
"""If cov_orient, re-orient each target's SVD by input covariance C=E[x x^T]
(CorDA) so the data-relevant output directions land in the top-r and the
behavior direction is fully ablatable at low r. No-op otherwise (keeps the
plain-SVD basis from init()). C is accumulated on CPU; for down_proj's large
d_in this is heavy -- exclude it or use plain ablation there."""
if not getattr(cfg, "cov_orient", False) or calibration_data is None:
return
orient = bool(getattr(cfg, "cov_orient", False))
layers = {name: layer for name, layer, _ in targets}
gram: dict[str, T] = {} # Σ xxᵀ (d_in²), only when orienting
mom: dict[str, T] = {} # Σ hhᵀ (r²), when not orienting (basis is fixed at init)
cov: dict[str, T] = {}
cnt: dict[str, int] = {n: 0 for n in layers}
def make_hook(name):
layer = layers[name]
def _h(module, args, kwargs):
x = rearrange(args[0].detach(), "... d -> (...) d").to(torch.float32).cpu()
if orient:
g = x.T @ x
gram[name] = g if name not in gram else gram[name] + g
else:
h = (x @ layer.lora_Vh.float().cpu().T) * layer.lora_S.float().cpu()
m = h.T @ h
mom[name] = m if name not in mom else mom[name] + m
g = x.T @ x
cov[name] = g if name not in cov else cov[name] + g
cnt[name] += x.shape[0]
return _h
@@ -143,41 +133,32 @@ class AntiPaSTOAblate:
for h in handles:
h.remove()
r, k, eps = cfg.r, cfg.k, float(cfg.cov_eps)
r, eps = cfg.r, float(cfg.cov_eps)
for name, layer in layers.items():
if cnt[name] < r:
raise RuntimeError(f"AntiPaSTOAblate at {name}: {cnt[name]} tokens, need >= r={r}")
if orient:
W_res = layer.weight.data.float().cpu()
U_old, S_old, Vh_old = (layer.lora_U.float().cpu(),
layer.lora_S.float().cpu(),
layer.lora_Vh.float().cpu())
W_orig = W_res + (U_old * S_old) @ Vh_old
W_res = layer.weight.data.float().cpu()
U_old, S_old, Vh_old = (layer.lora_U.float().cpu(),
layer.lora_S.float().cpu(),
layer.lora_Vh.float().cpu())
W_orig = W_res + (U_old * S_old) @ Vh_old
C = gram[name] / cnt[name]
lam, Q = torch.linalg.eigh(C)
lam = lam.clamp_min(0) + eps
Chalf = (Q * lam.sqrt()) @ Q.T
Cinvhalf = (Q * lam.rsqrt()) @ Q.T
Ut, St, Vht = torch.linalg.svd(W_orig @ Chalf, full_matrices=False)
Ur = Ut[:, :r] # orthonormal output basis (ablation acts here)
Sr = St[:r]
Pr = Vht[:r] @ Cinvhalf # oblique input projector (input-side only)
W_res_new = W_orig - (Ur * Sr) @ Pr
with torch.no_grad():
layer.lora_U.copy_(Ur.to(layer.lora_U))
layer.lora_S.copy_(Sr.to(layer.lora_S))
layer.lora_Vh.copy_(Pr.to(layer.lora_Vh)) # store P in the Vh slot
layer.weight.data.copy_(W_res_new.to(layer.weight))
# output S-space second moment in the (now oriented) basis: diag(S) P Σxxᵀ Pᵀ diag(S)
SP = Sr[:, None] * Pr
M = SP @ gram[name] @ SP.T
else:
M = mom[name] # (r, r) Σ hhᵀ in the init basis
C = cov[name] / cnt[name]
lam, Q = torch.linalg.eigh(C)
lam = lam.clamp_min(0) + eps
Chalf = (Q * lam.sqrt()) @ Q.T
Cinvhalf = (Q * lam.rsqrt()) @ Q.T
Ut, St, Vht = torch.linalg.svd(W_orig @ Chalf, full_matrices=False)
Ur = Ut[:, :r] # orthonormal output basis (ablation acts here)
Sr = St[:r]
Pr = Vht[:r] @ Cinvhalf # oblique input projector (input-side only)
W_res_new = W_orig - (Ur * Sr) @ Pr
c0 = torch.linalg.eigh(M).eigenvectors[:, -k:] # top-k principal dirs (orthonormal)
with torch.no_grad():
layer.lora_c.copy_(c0.to(layer.lora_c))
layer.lora_U.copy_(Ur.to(layer.lora_U))
layer.lora_S.copy_(Sr.to(layer.lora_S))
layer.lora_Vh.copy_(Pr.to(layer.lora_Vh)) # store P in the Vh slot
layer.weight.data.copy_(W_res_new.to(layer.weight))
@staticmethod
def _orthonormal(c: T) -> T: