diff --git a/scripts/metamath_gsm8k_benchmark.py b/scripts/metamath_gsm8k_benchmark.py index 806227e..cb6a5bf 100644 --- a/scripts/metamath_gsm8k_benchmark.py +++ b/scripts/metamath_gsm8k_benchmark.py @@ -604,9 +604,9 @@ def run(args: BenchmarkConfig) -> dict[str, Any]: # downstream task (IPM mode, per CorDA). eva needs only a few batches for its init; # corda/asvd/cov-orient estimate an input second moment, so we hand them many more # batches (PEFT calibrates on a few hundred sequences) for a well-conditioned basis. - needs_calib = args.variant in ("eva", "antipasto_corda", "antipasto_asvd") or ( - args.variant == "antipasto_ablate" and args.antipasto_cov_orient - ) + # antipasto_ablate always calibrates now: group_init warm-starts lora_c from the + # S-space output variance (cov_orient adds the heavier CorDA re-orient on top). + needs_calib = args.variant in ("eva", "antipasto_corda", "antipasto_asvd", "antipasto_ablate") init_meter = group_init_meter() # wall-time + peak CPU RAM of group_init if needs_calib: n_batches = min(4, len(batches)) if args.variant == "eva" else min(64, len(batches)) diff --git a/src/lora_lite/variants/antipasto_ablate.py b/src/lora_lite/variants/antipasto_ablate.py index 13806df..3711bcb 100644 --- a/src/lora_lite/variants/antipasto_ablate.py +++ b/src/lora_lite/variants/antipasto_ablate.py @@ -85,29 +85,43 @@ class AntiPaSTOAblate: layer.lora_Vh.copy_(Vhr.to(layer.lora_Vh.dtype)) W_res = (W - (Ur * Sr) @ Vhr).to(layer.weight.dtype) layer.weight.data.copy_(W_res) - # FIXME: lora_c is random-init. A group_init warm-start from the S-space - # contrastive direction dS (cf. sspace.py extract) would converge faster and - # land on the behavior direction; not implemented -- random trains, just slower. + # lora_c starts random here; group_init warm-starts it from the S-space output + # variance when calibration_data is given (see group_init), else it trains from noise. @staticmethod def group_init(model: nn.Module, targets, cfg, calibration_data: CalibrationData | None) -> None: - """If cov_orient, re-orient each target's SVD by input covariance C=E[x x^T] - (CorDA) so the data-relevant output directions land in the top-r and the - behavior direction is fully ablatable at low r. No-op otherwise (keeps the - plain-SVD basis from init()). C is accumulated on CPU; for down_proj's large - d_in this is heavy -- exclude it or use plain ablation there.""" - if not getattr(cfg, "cov_orient", False) or calibration_data is None: - return + """Warm-start each lora_c from calibration activations (and, if cov_orient, + re-orient the frozen SVD by input covariance C=E[x xᵀ] first, CorDA-style). + lora_c is seeded to the top-k principal axes of the S-space OUTPUT coords + h = diag(S) Vh x over the calibration set: the highest-energy output directions, + where the loss-gradient on the ablation strength is largest, so lora_c starts in a + high-gradient region instead of a near-orthogonal random one. NOTE this is the data + VARIANCE direction, not a contrastive behavior direction -- this benchmark is SFT + with no pos/neg split. For steering with contrastive pairs, seed lora_c from + mean(h|pos) - mean(h|neg) instead (cf. steering-lite sspace extract). + + Σ xxᵀ (d_in², heavy for down_proj) is only accumulated to orient; the warm-start + alone (cov_orient=False) needs just the cheap r×r second moment Σ hhᵀ.""" + if calibration_data is None: + return + orient = bool(getattr(cfg, "cov_orient", False)) layers = {name: layer for name, layer, _ in targets} - cov: dict[str, T] = {} + gram: dict[str, T] = {} # Σ xxᵀ (d_in²), only when orienting + mom: dict[str, T] = {} # Σ hhᵀ (r²), when not orienting (basis is fixed at init) cnt: dict[str, int] = {n: 0 for n in layers} def make_hook(name): + layer = layers[name] def _h(module, args, kwargs): x = rearrange(args[0].detach(), "... d -> (...) d").to(torch.float32).cpu() - g = x.T @ x - cov[name] = g if name not in cov else cov[name] + g + if orient: + g = x.T @ x + gram[name] = g if name not in gram else gram[name] + g + else: + h = (x @ layer.lora_Vh.float().cpu().T) * layer.lora_S.float().cpu() + m = h.T @ h + mom[name] = m if name not in mom else mom[name] + m cnt[name] += x.shape[0] return _h @@ -129,32 +143,41 @@ class AntiPaSTOAblate: for h in handles: h.remove() - r, eps = cfg.r, float(cfg.cov_eps) + r, k, eps = cfg.r, cfg.k, float(cfg.cov_eps) for name, layer in layers.items(): if cnt[name] < r: raise RuntimeError(f"AntiPaSTOAblate at {name}: {cnt[name]} tokens, need >= r={r}") - W_res = layer.weight.data.float().cpu() - U_old, S_old, Vh_old = (layer.lora_U.float().cpu(), - layer.lora_S.float().cpu(), - layer.lora_Vh.float().cpu()) - W_orig = W_res + (U_old * S_old) @ Vh_old + if orient: + W_res = layer.weight.data.float().cpu() + U_old, S_old, Vh_old = (layer.lora_U.float().cpu(), + layer.lora_S.float().cpu(), + layer.lora_Vh.float().cpu()) + W_orig = W_res + (U_old * S_old) @ Vh_old - C = cov[name] / cnt[name] - lam, Q = torch.linalg.eigh(C) - lam = lam.clamp_min(0) + eps - Chalf = (Q * lam.sqrt()) @ Q.T - Cinvhalf = (Q * lam.rsqrt()) @ Q.T - Ut, St, Vht = torch.linalg.svd(W_orig @ Chalf, full_matrices=False) - Ur = Ut[:, :r] # orthonormal output basis (ablation acts here) - Sr = St[:r] - Pr = Vht[:r] @ Cinvhalf # oblique input projector (input-side only) - W_res_new = W_orig - (Ur * Sr) @ Pr + C = gram[name] / cnt[name] + lam, Q = torch.linalg.eigh(C) + lam = lam.clamp_min(0) + eps + Chalf = (Q * lam.sqrt()) @ Q.T + Cinvhalf = (Q * lam.rsqrt()) @ Q.T + Ut, St, Vht = torch.linalg.svd(W_orig @ Chalf, full_matrices=False) + Ur = Ut[:, :r] # orthonormal output basis (ablation acts here) + Sr = St[:r] + Pr = Vht[:r] @ Cinvhalf # oblique input projector (input-side only) + W_res_new = W_orig - (Ur * Sr) @ Pr + with torch.no_grad(): + layer.lora_U.copy_(Ur.to(layer.lora_U)) + layer.lora_S.copy_(Sr.to(layer.lora_S)) + layer.lora_Vh.copy_(Pr.to(layer.lora_Vh)) # store P in the Vh slot + layer.weight.data.copy_(W_res_new.to(layer.weight)) + # output S-space second moment in the (now oriented) basis: diag(S) P Σxxᵀ Pᵀ diag(S) + SP = Sr[:, None] * Pr + M = SP @ gram[name] @ SP.T + else: + M = mom[name] # (r, r) Σ hhᵀ in the init basis + c0 = torch.linalg.eigh(M).eigenvectors[:, -k:] # top-k principal dirs (orthonormal) with torch.no_grad(): - layer.lora_U.copy_(Ur.to(layer.lora_U)) - layer.lora_S.copy_(Sr.to(layer.lora_S)) - layer.lora_Vh.copy_(Pr.to(layer.lora_Vh)) # store P in the Vh slot - layer.weight.data.copy_(W_res_new.to(layer.weight)) + layer.lora_c.copy_(c0.to(layer.lora_c)) @staticmethod def _orthonormal(c: T) -> T: