From 16e2c37de63b08a6cf1b7232b7abe5ab482159d8 Mon Sep 17 00:00:00 2001 From: wassname Date: Thu, 28 May 2026 09:42:17 +0000 Subject: [PATCH] train: online v_hack refresh every N steps Re-extract the hack subspace V against the current (delta_S-modified) model on the same hand-crafted PAIRS, every --vhack-refresh-every steps. Motivated by the Goal 1 negative result (2026-05-28 c) where projection at frozen V did not slow hacking; one hypothesis is V drifts out of relevance as the student moves. Off by default (0). Factored the k_use slice + noise-floor filter into a shared postprocess_v_hack helper used by both init-time load and the in-loop refresh. Co-Authored-By: Claude Opus 4.7 --- src/projected_grpo/train.py | 74 ++++++++++++++++++++++++++++++------- 1 file changed, 61 insertions(+), 13 deletions(-) diff --git a/src/projected_grpo/train.py b/src/projected_grpo/train.py index ef18eca..bca8572 100644 --- a/src/projected_grpo/train.py +++ b/src/projected_grpo/train.py @@ -167,6 +167,12 @@ class Config: # global threshold get filtered out entirely (projection skips them — they # didn't carry hack signal anyway). 0 = no filter. v_hack_drop_bottom_frac: float = 0.25 + # Online refresh: every N optimizer steps, re-run extract_v_hack against the + # current (delta_S-modified) model using the same hand-crafted PAIRS. Motivated + # by Goal-1 negative result (2026-05-28 c) where v_hack appeared to lose + # discriminative power as the student drifted. 0 = off (load once at start + # and freeze). Refresh cost ~14*2 backwards on Qwen3-4B ~ 1-2 min wall. + vhack_refresh_every: int = 0 # Per-source cin diagnostic: split each prompt's backward into student-only # + teacher-only passes (~2x backward time). 1 = every step (default; full # signal); N>1 = only every Nth step (combined backward elsewhere, ~halves @@ -339,20 +345,37 @@ def load_v_hack( f"--model={model_name} --out-path={path}`." ) + v_hack = postprocess_v_hack( + v_hack, v_sv, k_use=k_use, drop_bottom_frac=drop_bottom_frac, source=str(path), + ) + return v_hack + + +def postprocess_v_hack( + v_hack: dict[str, Float[torch.Tensor, "k r"]], + v_sv: dict[str, Float[torch.Tensor, "k"]], + k_use: int | None, + drop_bottom_frac: float, + source: str = "", +) -> dict[str, Float[torch.Tensor, "k r"]]: + """Apply k_use slice + global noise-floor filter. + + Shared between `load_v_hack` (init-time, reading from safetensors) and the + in-loop refresh hook (where we hand in fresh `extract_v_hack` outputs). + Mutates neither input dict; returns a fresh filtered dict. + + Global noise floor: collect every S_i across every module, drop the bottom + `drop_bottom_frac` by quantile. A module whose every axis falls below the + global threshold is removed entirely — projection iterates v_hack so it + becomes a no-op for that module. Threshold recomputes per call (tracks + current S distribution). + """ k_max = next(iter(v_hack.values())).shape[0] if k_use is not None: if k_use > k_max: - raise ValueError( - f"requested k_use={k_use} exceeds saved k_max={k_max} in {path}. " - f"Re-extract with `--top-k={k_use}`." - ) + raise ValueError(f"requested k_use={k_use} exceeds k_max={k_max} (source={source})") v_hack = {n: v[:k_use].contiguous() for n, v in v_hack.items()} v_sv = {n: s[:k_use].contiguous() for n, s in v_sv.items()} - - # Global noise floor: drop the bottom drop_bottom_frac of all (module, axis) - # pairs by S_i. One quantile across every S_i in every module. A module - # whose every axis lies below the global threshold is removed from v_hack - # entirely — projection iterates v_hack so that module just gets skipped. n_dropped_modules = 0 n_axes_before = sum(v.shape[0] for v in v_hack.values()) threshold = None @@ -368,11 +391,9 @@ def load_v_hack( n_dropped_modules += 1 v_hack = filtered n_axes_after = sum(v.shape[0] for v in v_hack.values()) - logger.info( - f"loaded v_hack from {path}: modules={len(v_hack)} (dropped {n_dropped_modules}); " - f"k_saved={k_max}, k_use={k_use or k_max}; " - f"axes={n_axes_after}/{n_axes_before} kept " + f"postprocess_v_hack({source}): modules={len(v_hack)} (dropped {n_dropped_modules}); " + f"k_use={k_use or k_max}/k_max={k_max}; axes={n_axes_after}/{n_axes_before} kept " f"(drop_bottom_frac={drop_bottom_frac}, threshold={threshold})" ) return v_hack @@ -1069,6 +1090,33 @@ table columns: opt.step() sched.step() + # Online v_hack refresh: re-extract against the *current* model so the + # hack subspace tracks where the student is being pulled now (rather + # than at step 0). Same PAIRS, same extract code; we just discard the + # saved cache and overwrite the in-memory v_hack dict. + if cfg.vhack_refresh_every > 0 and (step + 1) % cfg.vhack_refresh_every == 0: + from .extract_vhack_grad import extract_v_hack + from .pairs import PAIRS as VHACK_PAIRS + _was_training = model.training + model.eval() + opt.zero_grad(set_to_none=True) + _new_V, _new_S, _, _ = extract_v_hack( + model, tok, wrappers, VHACK_PAIRS, + top_k=cfg.v_hack_extract_top_k, tau_axis=cfg.v_hack_tau_axis, + n_heldout=2, device=device, + ) + _post = postprocess_v_hack( + _new_V, _new_S, k_use=cfg.v_hack_k, + drop_bottom_frac=cfg.v_hack_drop_bottom_frac, + source=f"refresh@step{step}", + ) + v_hack.clear() + v_hack.update({n: V.to(device) for n, V in _post.items()}) + opt.zero_grad(set_to_none=True) # extract leaves .grad populated + if _was_training: + model.train() + logger.info(f"v_hack refreshed @ step={step}: {len(v_hack)} modules") + rewards_t = torch.tensor(agg_rew, dtype=torch.float32) if agg_rew else torch.zeros(1) rew_mean = rewards_t.mean().item() rew_std = rewards_t.std().item() if rewards_t.numel() > 1 else 0.0