Remove runtime suspicion gate

It was a fixed-budget regularizer dressed up as a detector — by
construction, quantile gate dropped exactly drop_top_frac of axes per
step regardless of whether anything was genuinely suspicious. The susp
diagnostic column was 100% determined by the config knob, zero
information content.

The principled defense against noise axes is extract-time tau_axis
(drop singular axes below noise floor once at save), not a runtime
quantile. In high-d (r=2560), expected damage from carrying a noise
axis through to runtime projection is ~||g||/sqrt(r) ≈ 2%/axis, so
the cost is bounded anyway.

Kept: load_v_hack still returns (v_hack, v_sv) tuple for callers that
need S values offline. The _sv/{name} keys remain in saved files for
future use (extract-time tau_axis, diagnostics).

Per-source cin (cin_s, cin_t) stays — that's the actual discriminator
for whether v_hack projects hack > non-hack. #51 already showed
cin_t/cin_s ~= 2.0 across early steps, so the direction is doing real
work.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
wassname
2026-05-27 07:06:50 +00:00
parent 5f196e3108
commit 8d170a0753
2 changed files with 8 additions and 73 deletions
+5 -62
View File
@@ -51,8 +51,6 @@ def project_delta_S_grad(
preserve_magnitude: bool,
measure_only: bool = False,
gate_mode: str = "one_sided",
v_sv: dict[str, torch.Tensor] | None = None,
drop_top_frac: float = 0.0,
) -> dict[str, float]:
"""Per-module top-k removal of hack-aligned grad components.
@@ -73,83 +71,29 @@ def project_delta_S_grad(
`preserve_magnitude`: rescale g' to ||g|| after projection.
`measure_only`: same math, but g is not mutated (vanilla arm diagnostic).
Runtime suspicion gate (when v_sv is given and drop_top_frac > 0):
Per axis i, the within-module dimensionless ratio:
p_live_i = |c_i| / ||g|| (fraction of live grad on v_i)
p_extract_i = S_i / sqrt(sum_j S_j^2) (fraction of extract D on v_i)
r_i = p_live_i / p_extract_i
r_i ≈ 1: live grad concentrated on v_i in same proportion as extract — as
expected. r_i ≫ 1: live grad over-concentrated on v_i relative to what
the contrastive-pair signal would predict — likely v_i is spuriously
aligned with a structured coding direction, not hack.
Both ratios are within-module fractions, dimensionless, comparable across
modules with different ||g||. Per-step quantile-gate at (1-drop_top_frac):
suppress projection on axes above the threshold. Default 0 = current
behavior. Fails if v_sv is empty when drop_top_frac>0 (old v1 files).
Diagnostics returned (per call, averaged over modules):
mean_cos_in = mean over modules of ||V g||/||g|| (subspace energy fraction in)
mean_cos_out = same after projection
frac_fired = fraction of modules where at least one direction fired (c_i > 0)
frac_axes_susp = fraction of (module, axis) pairs dropped by suspicion gate
"""
# Fail fast on schema mismatch: user opted into the gate but v_hack is v1.
if drop_top_frac > 0 and (v_sv is None or not v_sv):
raise ValueError(
"susp_drop_frac > 0 requires v_sv (singular values, _sv/{name} keys "
"in v_hack file). Re-extract with current extract_vhack_grad.py "
"(saves v2 schema with _sv keys), or set susp_drop_frac=0."
)
# Pass 1: compute c per module, collect dimensionless suspicion ratios.
# r_i = (|c_i|/||g||) / (S_i/||S||). Both p_live and p_extract are within-module
# fractions in [0,1], so r_i compares like-for-like across modules of different
# sizes / gradient norms. Cache (V, c, r) so pass 2 doesn't recompute.
cache: dict[str, tuple] = {}
all_ratios: list[torch.Tensor] = []
cos_in_list, cos_out_list, n_fired = [], [], 0
for name, info in wrappers.items():
g = info["delta_S"].grad
if g is None:
continue
V = v_hack[name].to(g.device, dtype=g.dtype) # [k, r]
c = V @ g # [k]
r = None
if v_sv is not None and name in v_sv and drop_top_frac > 0:
S = v_sv[name].to(g.device, dtype=torch.float32)
S_norm = (S / S.pow(2).sum().clamp_min(1e-24).sqrt()).clamp_min(1e-12) # [k]
gn_f = g.float().norm().clamp_min(1e-12)
c_norm = c.float().abs() / gn_f # [k]
r = c_norm / S_norm # [k]
all_ratios.append(r)
cache[name] = (V, c, r)
susp_threshold: float | None = None
if drop_top_frac > 0 and all_ratios:
r_flat = torch.cat(all_ratios)
susp_threshold = torch.quantile(r_flat, 1.0 - drop_top_frac).item()
cos_in_list, cos_out_list, n_fired = [], [], 0
n_axes_total, n_axes_susp = 0, 0
for name, info in wrappers.items():
g = info["delta_S"].grad
if g is None:
continue
V, c, r = cache[name]
gn = g.norm()
if gn < 1e-12:
cos_in_list.append(0.0); cos_out_list.append(0.0); continue
c = V @ g # [k]
cin = c.norm() / gn
cos_in_list.append(cin.item())
if susp_threshold is not None and r is not None:
keep_susp = (r <= susp_threshold).to(c.dtype)
n_axes_total += c.numel()
n_axes_susp += int((1 - keep_susp).sum())
else:
keep_susp = torch.ones_like(c)
if gate_mode == "no_gate":
c_use = c * keep_susp
fired = bool((c_use != 0).any())
c_use = c
fired = True
elif gate_mode == "one_sided":
mask = (c > 0).to(c.dtype)
c_use = c * mask * keep_susp
c_use = c * mask
fired = bool((c_use != 0).any())
else:
raise ValueError(f"unknown gate_mode={gate_mode!r}")
@@ -174,5 +118,4 @@ def project_delta_S_grad(
"min_cos_out": cout_t.min().item() if cout_t.numel() else float("nan"),
"max_cos_out": cout_t.max().item() if cout_t.numel() else float("nan"),
"frac_fired": n_fired / len(cos_in_list) if cos_in_list else 0.0,
"frac_axes_susp": n_axes_susp / n_axes_total if n_axes_total else 0.0,
}
+3 -11
View File
@@ -165,11 +165,6 @@ class Config:
v_hack_path: Path | None = None
v_hack_extract_top_k: int = 12 # max k to save at extract; n_train_pairs caps it lower
v_hack_k: int = 5 # load-time slice; k=1 = mean-diff, k=k_max = full
# Runtime suspicion gate: per step, drop the top-frac (module, axis) pairs by
# r_i = |g·v_i|/S_i. Live alignment ≫ extract-time confidence suggests v_i
# is spuriously aligned with a coding/capability direction, not hack. 0.25 is
# conservative — protects against the worst quartile of suspicious projections.
susp_drop_frac: float = 0.25
out_tag: str = "" # suffix for saved artifact, e.g. "_seed41"
# Mixed-pool GRPO: per-prompt rollout pool = G_s live student + G_t cached
# teacher rollouts. Teacher pool is a dir of prompt_NNNN.jsonl.gz produced by
@@ -394,9 +389,8 @@ def main(cfg: Config) -> int:
"top_k": str(min(cfg.v_hack_extract_top_k, len(VHACK_PAIRS) - 2)),
"tau_axis": "0.0", "schema": "v2_with_sv"})
# extract zeros grads at exit; opt is built below so no opt-state taint.
v_hack_cpu, v_sv_cpu = load_v_hack(v_hack_path, model_name, wrappers, k_use=cfg.v_hack_k)
v_hack_cpu, _v_sv_cpu = load_v_hack(v_hack_path, model_name, wrappers, k_use=cfg.v_hack_k)
v_hack = {name: v.to(device) for name, v in v_hack_cpu.items()}
v_sv = {name: s.to(device) for name, s in v_sv_cpu.items()} if v_sv_cpu else None
# Teacher pool: pre-generated rollouts on disk keyed by problem_id. Each step's
# G_t teacher rollouts come from a uniform random sample of that prompt's cache,
# so we do *not* keep the teacher model in VRAM. Pool is produced by
@@ -515,7 +509,7 @@ def main(cfg: Config) -> int:
# equivalents" by reading this column at the row of interest.
_row_cols = ["step", "ref_eq", "rew", "std", "sprd", "N",
"gt", "hack", "hack_s", "hack_t", "gt_s",
"loss", "cin", "cin_s", "cin_t", "cout", "fired", "susp",
"loss", "cin", "cin_s", "cin_t", "cout", "fired",
"gen", "fb", "rew_s", "sec"]
REF_GENS_PER_STEP = 16 * 16 # ariahw/rl-rewardhacking config.py:num_prompts * num_generations
est_gens_per_step = cfg.prompts_per_step * cfg.group # before mixed-pool split
@@ -820,11 +814,10 @@ def main(cfg: Config) -> int:
wrappers, v_hack, cfg.preserve_magnitude,
measure_only=(cfg.arm != "projected"),
gate_mode=cfg.gate_mode,
v_sv=v_sv, drop_top_frac=cfg.susp_drop_frac,
)
else:
diag = {"mean_cos_in": float("nan"), "mean_cos_out": float("nan"),
"frac_fired": float("nan"), "frac_axes_susp": float("nan")}
"frac_fired": float("nan")}
diag["mean_cin_s"] = cin_s
diag["mean_cin_t"] = cin_t
@@ -892,7 +885,6 @@ def main(cfg: Config) -> int:
"cin_t": f"{diag['mean_cin_t']:+.3f}",
"cout": f"{diag['mean_cos_out']:+.3f}",
"fired": f"{diag['frac_fired']:.2f}",
"susp": f"{diag['frac_axes_susp']:.2f}",
"gen": f"{t_gen:.0f}",
"fb": f"{t_fb:.0f}",
"rew_s": f"{t_rew:.0f}",