diff --git a/src/projected_grpo/proj.py b/src/projected_grpo/proj.py index 3b3b956..9319871 100644 --- a/src/projected_grpo/proj.py +++ b/src/projected_grpo/proj.py @@ -51,8 +51,6 @@ def project_delta_S_grad( preserve_magnitude: bool, measure_only: bool = False, gate_mode: str = "one_sided", - v_sv: dict[str, torch.Tensor] | None = None, - drop_top_frac: float = 0.0, ) -> dict[str, float]: """Per-module top-k removal of hack-aligned grad components. @@ -73,83 +71,29 @@ def project_delta_S_grad( `preserve_magnitude`: rescale g' to ||g|| after projection. `measure_only`: same math, but g is not mutated (vanilla arm diagnostic). - Runtime suspicion gate (when v_sv is given and drop_top_frac > 0): - Per axis i, the within-module dimensionless ratio: - p_live_i = |c_i| / ||g|| (fraction of live grad on v_i) - p_extract_i = S_i / sqrt(sum_j S_j^2) (fraction of extract D on v_i) - r_i = p_live_i / p_extract_i - r_i ≈ 1: live grad concentrated on v_i in same proportion as extract — as - expected. r_i ≫ 1: live grad over-concentrated on v_i relative to what - the contrastive-pair signal would predict — likely v_i is spuriously - aligned with a structured coding direction, not hack. - Both ratios are within-module fractions, dimensionless, comparable across - modules with different ||g||. Per-step quantile-gate at (1-drop_top_frac): - suppress projection on axes above the threshold. Default 0 = current - behavior. Fails if v_sv is empty when drop_top_frac>0 (old v1 files). - Diagnostics returned (per call, averaged over modules): mean_cos_in = mean over modules of ||V g||/||g|| (subspace energy fraction in) mean_cos_out = same after projection frac_fired = fraction of modules where at least one direction fired (c_i > 0) - frac_axes_susp = fraction of (module, axis) pairs dropped by suspicion gate """ - # Fail fast on schema mismatch: user opted into the gate but v_hack is v1. - if drop_top_frac > 0 and (v_sv is None or not v_sv): - raise ValueError( - "susp_drop_frac > 0 requires v_sv (singular values, _sv/{name} keys " - "in v_hack file). Re-extract with current extract_vhack_grad.py " - "(saves v2 schema with _sv keys), or set susp_drop_frac=0." - ) - # Pass 1: compute c per module, collect dimensionless suspicion ratios. - # r_i = (|c_i|/||g||) / (S_i/||S||). Both p_live and p_extract are within-module - # fractions in [0,1], so r_i compares like-for-like across modules of different - # sizes / gradient norms. Cache (V, c, r) so pass 2 doesn't recompute. - cache: dict[str, tuple] = {} - all_ratios: list[torch.Tensor] = [] + cos_in_list, cos_out_list, n_fired = [], [], 0 for name, info in wrappers.items(): g = info["delta_S"].grad if g is None: continue V = v_hack[name].to(g.device, dtype=g.dtype) # [k, r] - c = V @ g # [k] - r = None - if v_sv is not None and name in v_sv and drop_top_frac > 0: - S = v_sv[name].to(g.device, dtype=torch.float32) - S_norm = (S / S.pow(2).sum().clamp_min(1e-24).sqrt()).clamp_min(1e-12) # [k] - gn_f = g.float().norm().clamp_min(1e-12) - c_norm = c.float().abs() / gn_f # [k] - r = c_norm / S_norm # [k] - all_ratios.append(r) - cache[name] = (V, c, r) - susp_threshold: float | None = None - if drop_top_frac > 0 and all_ratios: - r_flat = torch.cat(all_ratios) - susp_threshold = torch.quantile(r_flat, 1.0 - drop_top_frac).item() - - cos_in_list, cos_out_list, n_fired = [], [], 0 - n_axes_total, n_axes_susp = 0, 0 - for name, info in wrappers.items(): - g = info["delta_S"].grad - if g is None: - continue - V, c, r = cache[name] gn = g.norm() if gn < 1e-12: cos_in_list.append(0.0); cos_out_list.append(0.0); continue + c = V @ g # [k] cin = c.norm() / gn cos_in_list.append(cin.item()) - if susp_threshold is not None and r is not None: - keep_susp = (r <= susp_threshold).to(c.dtype) - n_axes_total += c.numel() - n_axes_susp += int((1 - keep_susp).sum()) - else: - keep_susp = torch.ones_like(c) if gate_mode == "no_gate": - c_use = c * keep_susp - fired = bool((c_use != 0).any()) + c_use = c + fired = True elif gate_mode == "one_sided": mask = (c > 0).to(c.dtype) - c_use = c * mask * keep_susp + c_use = c * mask fired = bool((c_use != 0).any()) else: raise ValueError(f"unknown gate_mode={gate_mode!r}") @@ -174,5 +118,4 @@ def project_delta_S_grad( "min_cos_out": cout_t.min().item() if cout_t.numel() else float("nan"), "max_cos_out": cout_t.max().item() if cout_t.numel() else float("nan"), "frac_fired": n_fired / len(cos_in_list) if cos_in_list else 0.0, - "frac_axes_susp": n_axes_susp / n_axes_total if n_axes_total else 0.0, } diff --git a/src/projected_grpo/train.py b/src/projected_grpo/train.py index 7828732..98a8b7d 100644 --- a/src/projected_grpo/train.py +++ b/src/projected_grpo/train.py @@ -165,11 +165,6 @@ class Config: v_hack_path: Path | None = None v_hack_extract_top_k: int = 12 # max k to save at extract; n_train_pairs caps it lower v_hack_k: int = 5 # load-time slice; k=1 = mean-diff, k=k_max = full - # Runtime suspicion gate: per step, drop the top-frac (module, axis) pairs by - # r_i = |g·v_i|/S_i. Live alignment ≫ extract-time confidence suggests v_i - # is spuriously aligned with a coding/capability direction, not hack. 0.25 is - # conservative — protects against the worst quartile of suspicious projections. - susp_drop_frac: float = 0.25 out_tag: str = "" # suffix for saved artifact, e.g. "_seed41" # Mixed-pool GRPO: per-prompt rollout pool = G_s live student + G_t cached # teacher rollouts. Teacher pool is a dir of prompt_NNNN.jsonl.gz produced by @@ -394,9 +389,8 @@ def main(cfg: Config) -> int: "top_k": str(min(cfg.v_hack_extract_top_k, len(VHACK_PAIRS) - 2)), "tau_axis": "0.0", "schema": "v2_with_sv"}) # extract zeros grads at exit; opt is built below so no opt-state taint. - v_hack_cpu, v_sv_cpu = load_v_hack(v_hack_path, model_name, wrappers, k_use=cfg.v_hack_k) + v_hack_cpu, _v_sv_cpu = load_v_hack(v_hack_path, model_name, wrappers, k_use=cfg.v_hack_k) v_hack = {name: v.to(device) for name, v in v_hack_cpu.items()} - v_sv = {name: s.to(device) for name, s in v_sv_cpu.items()} if v_sv_cpu else None # Teacher pool: pre-generated rollouts on disk keyed by problem_id. Each step's # G_t teacher rollouts come from a uniform random sample of that prompt's cache, # so we do *not* keep the teacher model in VRAM. Pool is produced by @@ -515,7 +509,7 @@ def main(cfg: Config) -> int: # equivalents" by reading this column at the row of interest. _row_cols = ["step", "ref_eq", "rew", "std", "sprd", "N", "gt", "hack", "hack_s", "hack_t", "gt_s", - "loss", "cin", "cin_s", "cin_t", "cout", "fired", "susp", + "loss", "cin", "cin_s", "cin_t", "cout", "fired", "gen", "fb", "rew_s", "sec"] REF_GENS_PER_STEP = 16 * 16 # ariahw/rl-rewardhacking config.py:num_prompts * num_generations est_gens_per_step = cfg.prompts_per_step * cfg.group # before mixed-pool split @@ -820,11 +814,10 @@ def main(cfg: Config) -> int: wrappers, v_hack, cfg.preserve_magnitude, measure_only=(cfg.arm != "projected"), gate_mode=cfg.gate_mode, - v_sv=v_sv, drop_top_frac=cfg.susp_drop_frac, ) else: diag = {"mean_cos_in": float("nan"), "mean_cos_out": float("nan"), - "frac_fired": float("nan"), "frac_axes_susp": float("nan")} + "frac_fired": float("nan")} diag["mean_cin_s"] = cin_s diag["mean_cin_t"] = cin_t @@ -892,7 +885,6 @@ def main(cfg: Config) -> int: "cin_t": f"{diag['mean_cin_t']:+.3f}", "cout": f"{diag['mean_cos_out']:+.3f}", "fired": f"{diag['frac_fired']:.2f}", - "susp": f"{diag['frac_axes_susp']:.2f}", "gen": f"{t_gen:.0f}", "fb": f"{t_fb:.0f}", "rew_s": f"{t_rew:.0f}",