prune: drop mean_diff and solve_orth_m extractor options

Both were negative results (docs Q4, Q9) and are now dead weight. Removes the
Config fields, the extract_v_hack params, the rank-1 mean-diff branch, the
solve-orth D-projection block, and the extract-vhack-meandiff recipe. The
v_hack_*_meandiff / *_18base / *_18solveorth4 artifacts stay on disk as frozen
evidence for those table rows. Smoke test passes.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
wassname
2026-05-29 10:21:01 +00:00
parent 5d83adbb25
commit 62c6794e30
2 changed files with 20 additions and 77 deletions
-9
View File
@@ -115,15 +115,6 @@ extract-vhack-full:
--out-path=out/v_hack_full.safetensors \
--train-grads-path=out/vhack_grads_train_full.safetensors
# Rank-1 mean-diff basis (alternative to SVD top-k). Honest under small N.
extract-vhack-meandiff:
uv run python -m projected_grpo.extract_vhack_grad \
--model=Qwen/Qwen3-4B \
--dtype=bf16 \
--mean-diff \
--out-path=out/v_hack_full_meandiff.safetensors \
--train-grads-path=out/vhack_grads_train_meandiff.safetensors
verify-vhack-smoke:
uv run python -m projected_grpo.verify_vhack_heldout \
--model=Qwen/Qwen3.5-0.8B \
+20 -68
View File
@@ -71,20 +71,6 @@ class Config:
# v_hack from a half-A-only set of hacks to test cross-mechanism
# generalisation (docs/spec/20260528_cross_mechanism_v_hack.md).
pairs_from_pool: Path | None = None
# Alternative basis extractor: rank-1 mean-diff direction per module instead
# of top-k SVD. v = mean(g_hack - g_clean) / ||mean(g_hack - g_clean)||.
# Motivation: with N=12 pairs, SVD axes 2..k fit per-axis noise (S_2/S_0
# was small in current extracts). Mean-diff is the same direction as PCA-
# axis-1 under the assumption that the mean dominates, but it's robust to
# outlier pairs and doesn't waste rank on noise. Saved with k=1 -- train.py
# load_v_hack reads it the same way as SVD output.
mean_diff: bool = False
# solve_orth_m: if >0, strip the top-m "solve" directions (SVD of the clean-
# side gradients G_c, = grads toward our known-good hand-written solutions)
# out of D before extracting v_hack. 0 = off. Aims to keep the projection
# from ablating the legitimate solve signal (pass-rate selectivity). No
# grader/oracle is read — only the clean solutions we wrote.
solve_orth_m: int = 0
def resolve_dtype(s: str) -> torch.dtype:
@@ -115,8 +101,6 @@ def extract_v_hack(
tau_axis: float,
n_heldout: int,
device,
mean_diff: bool = False,
solve_orth_m: int = 0,
) -> tuple[
dict[str, Float[torch.Tensor, "k r"]],
dict[str, Float[torch.Tensor, "k"]],
@@ -177,61 +161,31 @@ def extract_v_hack(
v_sv: dict[str, torch.Tensor] = {}
rows = []
n_zero = 0
k = 1 if mean_diff else min(top_k, n_pairs)
k = min(top_k, n_pairs)
n_axes_kept_total = 0
for name in grads_hack:
G_h = torch.stack(grads_hack[name]) # [n_pairs, r]
G_c = torch.stack(grads_clean[name])
D = G_h - G_c
if solve_orth_m > 0:
# Strip the known-solve subspace from D before extracting hack
# directions. B = top-m right singular vectors of G_c (the gradient
# toward our hand-written *correct* clean solutions = the "solve"
# direction; no grader/oracle used, just known-good solutions).
# D = G_h - G_c already carries -G_c, so the solve directions have
# real energy in D; removing them keeps projection from also
# ablating the solve signal (pass-rate selectivity). The SVD below
# then returns hack directions orthogonal to solve, still
# orthonormal, so S/orientation/noise-floor logic is unchanged.
m = min(solve_orth_m, G_c.shape[0])
_, _, Bh = torch.linalg.svd(G_c, full_matrices=False)
B = Bh[:m] # [m, r], orthonormal solve basis
D = D - (D @ B.T) @ B # D_perp
U_d, S_d, Vh_d = torch.linalg.svd(D, full_matrices=False)
V = Vh_d[:k] # [k, r], rows orthonormal in R^r
# Orient by per-pair majority vote: for each axis i, count pairs where
# d_p @ v_i > 0; if strict majority disagree with current SVD sign, flip.
# More outlier-robust than sign(mean): one extreme pair can't flip a
# consensus direction. Matches repeng's _orient_svd convention.
proj_per_pair = D @ V.T # [n_pairs, k]
n_pos = (proj_per_pair > 0).float().sum(0) # [k]
flip = torch.where(n_pos < n_pairs / 2, -torch.ones(k), torch.ones(k))
V = V * flip.unsqueeze(1)
if mean_diff:
# Rank-1 mean-diff direction. Honest under small N: SVD axes 2..k on
# N=12 pairs fit noise; mean-diff regularizes to the only direction
# the data robustly supports. v = mean(D)/||mean(D)||, oriented along
# mean(D) by construction so no flip is needed.
mean_D = D.mean(0) # [r]
mean_nrm = mean_D.norm()
if mean_nrm < 1e-12:
V = torch.zeros((1, D.shape[1]), dtype=D.dtype)
S_d = torch.zeros(1, dtype=D.dtype)
else:
V = (mean_D / mean_nrm).unsqueeze(0) # [1, r]
S_d = mean_nrm.unsqueeze(0) # treat ||mean(D)|| as the singular value
n_axes_kept = 1 if mean_nrm >= 1e-12 else 0
else:
U_d, S_d, Vh_d = torch.linalg.svd(D, full_matrices=False)
V = Vh_d[:k] # [k, r], rows orthonormal in R^r
# Orient by per-pair majority vote: for each axis i, count pairs where
# d_p @ v_i > 0; if strict majority disagree with current SVD sign, flip.
# More outlier-robust than sign(mean): one extreme pair can't flip a
# consensus direction. Matches repeng's _orient_svd convention.
proj_per_pair = D @ V.T # [n_pairs, k]
n_pos = (proj_per_pair > 0).float().sum(0) # [k]
flip = torch.where(n_pos < n_pairs / 2, -torch.ones(k), torch.ones(k))
V = V * flip.unsqueeze(1)
# tau_axis: zero rows where S_i/S_0 < tau_axis (diagnostic; see Config comment).
n_axes_kept = k
if tau_axis > 0 and S_d[0] > 1e-12:
ratios = S_d[:k] / S_d[0]
keep = (ratios >= tau_axis).float()
V = V * keep.unsqueeze(1)
n_axes_kept = int(keep.sum())
# tau_axis: zero rows where S_i/S_0 < tau_axis (diagnostic; see Config comment).
n_axes_kept = k
if tau_axis > 0 and S_d[0] > 1e-12:
ratios = S_d[:k] / S_d[0]
keep = (ratios >= tau_axis).float()
V = V * keep.unsqueeze(1)
n_axes_kept = int(keep.sum())
n_axes_kept_total += n_axes_kept
nrm = D.norm()
@@ -252,7 +206,7 @@ def extract_v_hack(
"||D||": f"{nrm:.2e}",
"sv_0": f"{S_d[0].item():.2e}" if S_d.numel() else "-",
f"sv_top{k}_frac": f"{(sv_top.sum() / sv_total).item():.2f}",
"sv_ratio_0/1": ("-" if mean_diff or S_d.numel() < 2
"sv_ratio_0/1": ("-" if S_d.numel() < 2
else f"{(S_d[0] / S_d[1].clamp_min(1e-12)).item():.2f}"),
"axes_kept": n_axes_kept,
})
@@ -297,11 +251,9 @@ def main(cfg: Config) -> int:
model, tokenizer, wrappers, pairs,
top_k=cfg.top_k, tau_axis=cfg.tau_axis,
n_heldout=cfg.n_heldout, device=device,
mean_diff=cfg.mean_diff,
solve_orth_m=cfg.solve_orth_m,
)
n_zero = sum(1 for v in v_hack.values() if v.norm() < 1e-12)
k = 1 if cfg.mean_diff else min(cfg.top_k, len(train_pairs))
k = min(cfg.top_k, len(train_pairs))
OUT_DIR.mkdir(exist_ok=True)
save_file(raw_grads, str(cfg.train_grads_path),