From 62c6794e30dc3f7a3bad7ed9fd8d4dd5d416ffdf Mon Sep 17 00:00:00 2001 From: wassname Date: Fri, 29 May 2026 10:21:01 +0000 Subject: [PATCH] prune: drop mean_diff and solve_orth_m extractor options Both were negative results (docs Q4, Q9) and are now dead weight. Removes the Config fields, the extract_v_hack params, the rank-1 mean-diff branch, the solve-orth D-projection block, and the extract-vhack-meandiff recipe. The v_hack_*_meandiff / *_18base / *_18solveorth4 artifacts stay on disk as frozen evidence for those table rows. Smoke passes. Co-Authored-By: Claude Opus 4.8 --- justfile | 9 --- src/projected_grpo/extract_vhack_grad.py | 88 ++++++------------------ 2 files changed, 20 insertions(+), 77 deletions(-) diff --git a/justfile b/justfile index e4b06b9..8ce8954 100644 --- a/justfile +++ b/justfile @@ -115,15 +115,6 @@ extract-vhack-full: --out-path=out/v_hack_full.safetensors \ --train-grads-path=out/vhack_grads_train_full.safetensors -# Rank-1 mean-diff basis (alternative to SVD top-k). Honest under small N. -extract-vhack-meandiff: - uv run python -m projected_grpo.extract_vhack_grad \ - --model=Qwen/Qwen3-4B \ - --dtype=bf16 \ - --mean-diff \ - --out-path=out/v_hack_full_meandiff.safetensors \ - --train-grads-path=out/vhack_grads_train_meandiff.safetensors - verify-vhack-smoke: uv run python -m projected_grpo.verify_vhack_heldout \ --model=Qwen/Qwen3.5-0.8B \ diff --git a/src/projected_grpo/extract_vhack_grad.py b/src/projected_grpo/extract_vhack_grad.py index 0cac487..cfd4e34 100644 --- a/src/projected_grpo/extract_vhack_grad.py +++ b/src/projected_grpo/extract_vhack_grad.py @@ -71,20 +71,6 @@ class Config: # v_hack from a half-A-only set of hacks to test cross-mechanism # generalisation (docs/spec/20260528_cross_mechanism_v_hack.md). pairs_from_pool: Path | None = None - # Alternative basis extractor: rank-1 mean-diff direction per module instead - # of top-k SVD. v = mean(g_hack - g_clean) / ||mean(g_hack - g_clean)||. - # Motivation: with N=12 pairs, SVD axes 2..k fit per-axis noise (S_2/S_0 - # was small in current extracts). Mean-diff is the same direction as PCA- - # axis-1 under the assumption that the mean dominates, but it's robust to - # outlier pairs and doesn't waste rank on noise. Saved with k=1 -- train.py - # load_v_hack reads it the same way as SVD output. - mean_diff: bool = False - # solve_orth_m: if >0, strip the top-m "solve" directions (SVD of the clean- - # side gradients G_c, = grads toward our known-good hand-written solutions) - # out of D before extracting v_hack. 0 = off. Aims to keep the projection - # from ablating the legitimate solve signal (pass-rate selectivity). No - # grader/oracle is read — only the clean solutions we wrote. - solve_orth_m: int = 0 def resolve_dtype(s: str) -> torch.dtype: @@ -115,8 +101,6 @@ def extract_v_hack( tau_axis: float, n_heldout: int, device, - mean_diff: bool = False, - solve_orth_m: int = 0, ) -> tuple[ dict[str, Float[torch.Tensor, "k r"]], dict[str, Float[torch.Tensor, "k"]], @@ -177,61 +161,31 @@ def extract_v_hack( v_sv: dict[str, torch.Tensor] = {} rows = [] n_zero = 0 - k = 1 if mean_diff else min(top_k, n_pairs) + k = min(top_k, n_pairs) n_axes_kept_total = 0 for name in grads_hack: G_h = torch.stack(grads_hack[name]) # [n_pairs, r] G_c = torch.stack(grads_clean[name]) D = G_h - G_c - if solve_orth_m > 0: - # Strip the known-solve subspace from D before extracting hack - # directions. B = top-m right singular vectors of G_c (the gradient - # toward our hand-written *correct* clean solutions = the "solve" - # direction; no grader/oracle used, just known-good solutions). - # D = G_h - G_c already carries -G_c, so the solve directions have - # real energy in D; removing them keeps projection from also - # ablating the solve signal (pass-rate selectivity). The SVD below - # then returns hack directions orthogonal to solve, still - # orthonormal, so S/orientation/noise-floor logic is unchanged. - m = min(solve_orth_m, G_c.shape[0]) - _, _, Bh = torch.linalg.svd(G_c, full_matrices=False) - B = Bh[:m] # [m, r], orthonormal solve basis - D = D - (D @ B.T) @ B # D_perp + U_d, S_d, Vh_d = torch.linalg.svd(D, full_matrices=False) + V = Vh_d[:k] # [k, r], rows orthonormal in R^r + # Orient by per-pair majority vote: for each axis i, count pairs where + # d_p @ v_i > 0; if strict majority disagree with current SVD sign, flip. + # More outlier-robust than sign(mean): one extreme pair can't flip a + # consensus direction. Matches repeng's _orient_svd convention. + proj_per_pair = D @ V.T # [n_pairs, k] + n_pos = (proj_per_pair > 0).float().sum(0) # [k] + flip = torch.where(n_pos < n_pairs / 2, -torch.ones(k), torch.ones(k)) + V = V * flip.unsqueeze(1) - if mean_diff: - # Rank-1 mean-diff direction. Honest under small N: SVD axes 2..k on - # N=12 pairs fit noise; mean-diff regularizes to the only direction - # the data robustly supports. v = mean(D)/||mean(D)||, oriented along - # mean(D) by construction so no flip is needed. - mean_D = D.mean(0) # [r] - mean_nrm = mean_D.norm() - if mean_nrm < 1e-12: - V = torch.zeros((1, D.shape[1]), dtype=D.dtype) - S_d = torch.zeros(1, dtype=D.dtype) - else: - V = (mean_D / mean_nrm).unsqueeze(0) # [1, r] - S_d = mean_nrm.unsqueeze(0) # treat ||mean(D)|| as the singular value - n_axes_kept = 1 if mean_nrm >= 1e-12 else 0 - else: - U_d, S_d, Vh_d = torch.linalg.svd(D, full_matrices=False) - V = Vh_d[:k] # [k, r], rows orthonormal in R^r - # Orient by per-pair majority vote: for each axis i, count pairs where - # d_p @ v_i > 0; if strict majority disagree with current SVD sign, flip. - # More outlier-robust than sign(mean): one extreme pair can't flip a - # consensus direction. Matches repeng's _orient_svd convention. - proj_per_pair = D @ V.T # [n_pairs, k] - n_pos = (proj_per_pair > 0).float().sum(0) # [k] - flip = torch.where(n_pos < n_pairs / 2, -torch.ones(k), torch.ones(k)) - V = V * flip.unsqueeze(1) - - # tau_axis: zero rows where S_i/S_0 < tau_axis (diagnostic; see Config comment). - n_axes_kept = k - if tau_axis > 0 and S_d[0] > 1e-12: - ratios = S_d[:k] / S_d[0] - keep = (ratios >= tau_axis).float() - V = V * keep.unsqueeze(1) - n_axes_kept = int(keep.sum()) + # tau_axis: zero rows where S_i/S_0 < tau_axis (diagnostic; see Config comment). + n_axes_kept = k + if tau_axis > 0 and S_d[0] > 1e-12: + ratios = S_d[:k] / S_d[0] + keep = (ratios >= tau_axis).float() + V = V * keep.unsqueeze(1) + n_axes_kept = int(keep.sum()) n_axes_kept_total += n_axes_kept nrm = D.norm() @@ -252,7 +206,7 @@ def extract_v_hack( "||D||": f"{nrm:.2e}", "sv_0": f"{S_d[0].item():.2e}" if S_d.numel() else "-", f"sv_top{k}_frac": f"{(sv_top.sum() / sv_total).item():.2f}", - "sv_ratio_0/1": ("-" if mean_diff or S_d.numel() < 2 + "sv_ratio_0/1": ("-" if S_d.numel() < 2 else f"{(S_d[0] / S_d[1].clamp_min(1e-12)).item():.2f}"), "axes_kept": n_axes_kept, }) @@ -297,11 +251,9 @@ def main(cfg: Config) -> int: model, tokenizer, wrappers, pairs, top_k=cfg.top_k, tau_axis=cfg.tau_axis, n_heldout=cfg.n_heldout, device=device, - mean_diff=cfg.mean_diff, - solve_orth_m=cfg.solve_orth_m, ) n_zero = sum(1 for v in v_hack.values() if v.norm() < 1e-12) - k = 1 if cfg.mean_diff else min(cfg.top_k, len(train_pairs)) + k = min(cfg.top_k, len(train_pairs)) OUT_DIR.mkdir(exist_ok=True) save_file(raw_grads, str(cfg.train_grads_path),