mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-07-05 13:43:38 +08:00
prune: drop mean_diff and solve_orth_m extractor options
Both were negative results (docs Q4, Q9) and are now dead weight. Removes the Config fields, the extract_v_hack params, the rank-1 mean-diff branch, the solve-orth D-projection block, and the extract-vhack-meandiff recipe. The v_hack_*_meandiff / *_18base / *_18solveorth4 artifacts stay on disk as frozen evidence for those table rows. Smoke passes.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -115,15 +115,6 @@ extract-vhack-full:
|
||||
--out-path=out/v_hack_full.safetensors \
|
||||
--train-grads-path=out/vhack_grads_train_full.safetensors
|
||||
|
||||
# Rank-1 mean-diff basis (alternative to SVD top-k). Honest under small N.
|
||||
extract-vhack-meandiff:
|
||||
uv run python -m projected_grpo.extract_vhack_grad \
|
||||
--model=Qwen/Qwen3-4B \
|
||||
--dtype=bf16 \
|
||||
--mean-diff \
|
||||
--out-path=out/v_hack_full_meandiff.safetensors \
|
||||
--train-grads-path=out/vhack_grads_train_meandiff.safetensors
|
||||
|
||||
verify-vhack-smoke:
|
||||
uv run python -m projected_grpo.verify_vhack_heldout \
|
||||
--model=Qwen/Qwen3.5-0.8B \
|
||||
|
||||
@@ -71,20 +71,6 @@ class Config:
|
||||
# v_hack from a half-A-only set of hacks to test cross-mechanism
|
||||
# generalisation (docs/spec/20260528_cross_mechanism_v_hack.md).
|
||||
pairs_from_pool: Path | None = None
|
||||
# Alternative basis extractor: rank-1 mean-diff direction per module instead
|
||||
# of top-k SVD. v = mean(g_hack - g_clean) / ||mean(g_hack - g_clean)||.
|
||||
# Motivation: with N=12 pairs, SVD axes 2..k fit per-axis noise (S_2/S_0
|
||||
# was small in current extracts). Mean-diff is the same direction as PCA-
|
||||
# axis-1 under the assumption that the mean dominates, but it's robust to
|
||||
# outlier pairs and doesn't waste rank on noise. Saved with k=1 -- train.py
|
||||
# load_v_hack reads it the same way as SVD output.
|
||||
mean_diff: bool = False
|
||||
# solve_orth_m: if >0, strip the top-m "solve" directions (SVD of the clean-
|
||||
# side gradients G_c, = grads toward our known-good hand-written solutions)
|
||||
# out of D before extracting v_hack. 0 = off. Aims to keep the projection
|
||||
# from ablating the legitimate solve signal (pass-rate selectivity). No
|
||||
# grader/oracle is read — only the clean solutions we wrote.
|
||||
solve_orth_m: int = 0
|
||||
|
||||
|
||||
def resolve_dtype(s: str) -> torch.dtype:
|
||||
@@ -115,8 +101,6 @@ def extract_v_hack(
|
||||
tau_axis: float,
|
||||
n_heldout: int,
|
||||
device,
|
||||
mean_diff: bool = False,
|
||||
solve_orth_m: int = 0,
|
||||
) -> tuple[
|
||||
dict[str, Float[torch.Tensor, "k r"]],
|
||||
dict[str, Float[torch.Tensor, "k"]],
|
||||
@@ -177,61 +161,31 @@ def extract_v_hack(
|
||||
v_sv: dict[str, torch.Tensor] = {}
|
||||
rows = []
|
||||
n_zero = 0
|
||||
k = 1 if mean_diff else min(top_k, n_pairs)
|
||||
k = min(top_k, n_pairs)
|
||||
n_axes_kept_total = 0
|
||||
for name in grads_hack:
|
||||
G_h = torch.stack(grads_hack[name]) # [n_pairs, r]
|
||||
G_c = torch.stack(grads_clean[name])
|
||||
D = G_h - G_c
|
||||
|
||||
if solve_orth_m > 0:
|
||||
# Strip the known-solve subspace from D before extracting hack
|
||||
# directions. B = top-m right singular vectors of G_c (the gradient
|
||||
# toward our hand-written *correct* clean solutions = the "solve"
|
||||
# direction; no grader/oracle used, just known-good solutions).
|
||||
# D = G_h - G_c already carries -G_c, so the solve directions have
|
||||
# real energy in D; removing them keeps projection from also
|
||||
# ablating the solve signal (pass-rate selectivity). The SVD below
|
||||
# then returns hack directions orthogonal to solve, still
|
||||
# orthonormal, so S/orientation/noise-floor logic is unchanged.
|
||||
m = min(solve_orth_m, G_c.shape[0])
|
||||
_, _, Bh = torch.linalg.svd(G_c, full_matrices=False)
|
||||
B = Bh[:m] # [m, r], orthonormal solve basis
|
||||
D = D - (D @ B.T) @ B # D_perp
|
||||
U_d, S_d, Vh_d = torch.linalg.svd(D, full_matrices=False)
|
||||
V = Vh_d[:k] # [k, r], rows orthonormal in R^r
|
||||
# Orient by per-pair majority vote: for each axis i, count pairs where
|
||||
# d_p @ v_i > 0; if strict majority disagree with current SVD sign, flip.
|
||||
# More outlier-robust than sign(mean): one extreme pair can't flip a
|
||||
# consensus direction. Matches repeng's _orient_svd convention.
|
||||
proj_per_pair = D @ V.T # [n_pairs, k]
|
||||
n_pos = (proj_per_pair > 0).float().sum(0) # [k]
|
||||
flip = torch.where(n_pos < n_pairs / 2, -torch.ones(k), torch.ones(k))
|
||||
V = V * flip.unsqueeze(1)
|
||||
|
||||
if mean_diff:
|
||||
# Rank-1 mean-diff direction. Honest under small N: SVD axes 2..k on
|
||||
# N=12 pairs fit noise; mean-diff regularizes to the only direction
|
||||
# the data robustly supports. v = mean(D)/||mean(D)||, oriented along
|
||||
# mean(D) by construction so no flip is needed.
|
||||
mean_D = D.mean(0) # [r]
|
||||
mean_nrm = mean_D.norm()
|
||||
if mean_nrm < 1e-12:
|
||||
V = torch.zeros((1, D.shape[1]), dtype=D.dtype)
|
||||
S_d = torch.zeros(1, dtype=D.dtype)
|
||||
else:
|
||||
V = (mean_D / mean_nrm).unsqueeze(0) # [1, r]
|
||||
S_d = mean_nrm.unsqueeze(0) # treat ||mean(D)|| as the singular value
|
||||
n_axes_kept = 1 if mean_nrm >= 1e-12 else 0
|
||||
else:
|
||||
U_d, S_d, Vh_d = torch.linalg.svd(D, full_matrices=False)
|
||||
V = Vh_d[:k] # [k, r], rows orthonormal in R^r
|
||||
# Orient by per-pair majority vote: for each axis i, count pairs where
|
||||
# d_p @ v_i > 0; if strict majority disagree with current SVD sign, flip.
|
||||
# More outlier-robust than sign(mean): one extreme pair can't flip a
|
||||
# consensus direction. Matches repeng's _orient_svd convention.
|
||||
proj_per_pair = D @ V.T # [n_pairs, k]
|
||||
n_pos = (proj_per_pair > 0).float().sum(0) # [k]
|
||||
flip = torch.where(n_pos < n_pairs / 2, -torch.ones(k), torch.ones(k))
|
||||
V = V * flip.unsqueeze(1)
|
||||
|
||||
# tau_axis: zero rows where S_i/S_0 < tau_axis (diagnostic; see Config comment).
|
||||
n_axes_kept = k
|
||||
if tau_axis > 0 and S_d[0] > 1e-12:
|
||||
ratios = S_d[:k] / S_d[0]
|
||||
keep = (ratios >= tau_axis).float()
|
||||
V = V * keep.unsqueeze(1)
|
||||
n_axes_kept = int(keep.sum())
|
||||
# tau_axis: zero rows where S_i/S_0 < tau_axis (diagnostic; see Config comment).
|
||||
n_axes_kept = k
|
||||
if tau_axis > 0 and S_d[0] > 1e-12:
|
||||
ratios = S_d[:k] / S_d[0]
|
||||
keep = (ratios >= tau_axis).float()
|
||||
V = V * keep.unsqueeze(1)
|
||||
n_axes_kept = int(keep.sum())
|
||||
n_axes_kept_total += n_axes_kept
|
||||
|
||||
nrm = D.norm()
|
||||
@@ -252,7 +206,7 @@ def extract_v_hack(
|
||||
"||D||": f"{nrm:.2e}",
|
||||
"sv_0": f"{S_d[0].item():.2e}" if S_d.numel() else "-",
|
||||
f"sv_top{k}_frac": f"{(sv_top.sum() / sv_total).item():.2f}",
|
||||
"sv_ratio_0/1": ("-" if mean_diff or S_d.numel() < 2
|
||||
"sv_ratio_0/1": ("-" if S_d.numel() < 2
|
||||
else f"{(S_d[0] / S_d[1].clamp_min(1e-12)).item():.2f}"),
|
||||
"axes_kept": n_axes_kept,
|
||||
})
|
||||
@@ -297,11 +251,9 @@ def main(cfg: Config) -> int:
|
||||
model, tokenizer, wrappers, pairs,
|
||||
top_k=cfg.top_k, tau_axis=cfg.tau_axis,
|
||||
n_heldout=cfg.n_heldout, device=device,
|
||||
mean_diff=cfg.mean_diff,
|
||||
solve_orth_m=cfg.solve_orth_m,
|
||||
)
|
||||
n_zero = sum(1 for v in v_hack.values() if v.norm() < 1e-12)
|
||||
k = 1 if cfg.mean_diff else min(cfg.top_k, len(train_pairs))
|
||||
k = min(cfg.top_k, len(train_pairs))
|
||||
|
||||
OUT_DIR.mkdir(exist_ok=True)
|
||||
save_file(raw_grads, str(cfg.train_grads_path),
|
||||
|
||||
Reference in New Issue
Block a user