From 0660e7bdd35d975cb254b4f65881660660bb2477 Mon Sep 17 00:00:00 2001 From: wassname <1103714+wassname@users.noreply.github.com> Date: Thu, 11 Jun 2026 10:45:50 +0000 Subject: [PATCH] diag(#40): residual-stream rep (layers 12/18/24) as third gate-score candidate, 3x2 panel Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com> --- scripts/diag_pinning.py | 101 ++++++++++++++++++++++++---------------- 1 file changed, 60 insertions(+), 41 deletions(-) diff --git a/scripts/diag_pinning.py b/scripts/diag_pinning.py index 2484391..7605dc3 100644 --- a/scripts/diag_pinning.py +++ b/scripts/diag_pinning.py @@ -14,22 +14,27 @@ rollout) the advantage ALONE is a ~0.9 AUROC detector, because the label require and most fails have A<0 -- so a high vs-all number mostly restates the reward, which the live gate already has for free. The vector's VALUE-ADD is the A>0 contrast: among positively-reinforced updates (where adv is blind, ~0.5), can the score tell hacks -from solves? That is the gate's real job (don't ship hack updates, don't rob solves) -and the headline number here; vs-all is kept as a secondary column. NOTE an info +from solves? That is the gate's primary objective: exclude reward-hacking updates while +retaining correct-solution updates, and it is the primary result here. The vs-all +contrast is kept as a secondary column. NOTE an info asymmetry: grad scores carry the adv factor (sign+magnitude), act scores do not, so on the vs-all contrast grad gets label-correlated information act lacks; the A>0 -contrast removes most of that edge -- compare the cases there. +contrast removes most of that difference. -FOUR CANDIDATE SCORES = {grad, act} x {cos, dot}, all in concatenated-module space: - - grad: the adv-weighted deployed c-probe gradient (the gate's current input). - - act: the deployed bottleneck activation A[:r]@x, mean over completion tokens -- - same [r]-per-module space, capturable in the gate's pass-1 forward for free. +SIX CANDIDATE SCORES = {grad, act, resid} x {cos, dot}, concatenated over modules/layers: + - grad: the adv-weighted deployed c-probe gradient (the gate's current input). + - act: the deployed bottleneck activation A[:r]@x, mean over completion tokens -- + same [r]-per-module space, capturable in the gate's pass-1 forward for free. + - resid: residual-stream hidden states at cfg.resid_layers, mean over completion + tokens. Adapter-independent: at an early checkpoint A is near its Gaussian + init, so grad and act are both views through a random r=32 projection per + module; resid tests whether that subspace, not grad-vs-act, limits separation. - cos: magnitude-blind alignment (tiny vectors give meaningless angles -- control). - dot: = |g|*cos, magnitude-aware; with g = A*gradNLL the advantage rides - along, so dot reads "how hard is this update pushing hack-ward". -v for each representation comes from the authored pairs only (mean hack-minus-clean -per module, unit per module) -- the no-cheat label source; live labels are read ONLY -to measure (AUROC / precision at the rout cut), never to route. + along, so dot measures update magnitude aligned with v. +v for each representation comes only from authored pairs (mean hack-minus-clean, +normalized per module). Ground-truth labels from training rollouts are used only for +diagnostic AUROC and precision measurements, never for routing. PINNING. Each panel shades the three zones the online gate rule would give on this window: keep (bulk) | absorb (score > mean + k_mid*sd) | rout (>= mean + k_rout*sd), @@ -49,8 +54,8 @@ cached features. uv run python scripts/diag_pinning.py --run-dir out/runs/ uv run python scripts/diag_pinning.py --replot out/diag/pinning_data.parquet # no GPU -outputs (out/diag/): pinning_q2.png (2x2 headline), pinning_data.parquet (per-rollout -scores), pinning_pairset.parquet + printed table (subsets x 4 AUROCs), +outputs (out/diag/): pinning_q2.png (3x2 headline), pinning_data.parquet (per-rollout +scores), pinning_pairset.parquet + printed table (subsets x 6 AUROCs), pinning_feats.pt (raw features, for offline re-analysis). """ from __future__ import annotations @@ -82,7 +87,8 @@ from vgrout.train import _auroc # colour = behaviour (blue solve, red hack, grey fail); style = source (solid on-policy, dashed synthetic) SOLVE, HACK, FAIL, ABSORB_C, ROUT_C, ORACLE = "#3b6ea5", "#c44e52", "#9aa0a6", "#d1900a", "#c44e52", "#3a8a7a" -CASES = [("grad", "cos"), ("grad", "dot"), ("act", "cos"), ("act", "dot")] +CASES = [("grad", "cos"), ("grad", "dot"), ("act", "cos"), ("act", "dot"), + ("resid", "cos"), ("resid", "dot")] @dataclass @@ -102,6 +108,7 @@ class Cfg: k_mid: float = 2.0 # absorb onset: score > mean + k_mid*sd (real-run Config default) k_rout: float = 3.0 # rout onset: score >= mean + k_rout*sd adv_eps: float = 1e-6 # |A| below this = no update exists -> dropped from zones/AUROC + resid_layers: tuple[int, ...] = (12, 18, 24) # residual-stream capture depths (of 36) random_v_seed: int | None = None # Haar placebo (sanity: nothing should separate) replot: Path | None = None # load parquet and re-plot only (no model, no GPU) out_dir: Path = Path("out/diag") @@ -113,13 +120,16 @@ def _ckpt_meta(path: Path) -> dict: class ActTap: - """Forward hooks stashing the deployed bottleneck activation h = A[:r] @ x per module. + """Forward hooks stashing (a) the deployed bottleneck activation h = A[:r] @ x per + module and (b) the residual-stream hidden state after each decoder layer in + `resid_modules`. - Computes the r-dim projection inline (no_grad) instead of retaining the full + (a) computes the r-dim projection inline (no_grad) instead of retaining the full [L, d_in] input -- ~250 modules x [L, d_in] would be GBs; [L, r] is nothing. """ - def __init__(self, wrappers: dict, names: list[str]): - self.wrappers, self.names, self.h, self.handles = wrappers, names, {}, [] + def __init__(self, wrappers: dict, names: list[str], resid_modules: list): + self.wrappers, self.names, self.resid_modules = wrappers, names, resid_modules + self.h, self.res, self.handles = {}, {}, [] def __enter__(self): for nm in self.names: @@ -129,6 +139,10 @@ class ActTap: with torch.no_grad(): self.h[nm] = F.linear(x.detach(), layer._lora2r_A[: layer._lora2r_r].to(x.dtype)) self.handles.append(layer.register_forward_hook(hook)) + for li, mod in enumerate(self.resid_modules): + def rhook(mod, args, out, li=li): + self.res[li] = (out[0] if isinstance(out, tuple) else out).detach() + self.handles.append(mod.register_forward_hook(rhook)) return self def __exit__(self, *exc): @@ -144,6 +158,11 @@ class ActTap: out.append(h[0, n_prompt:].float().mean(0).cpu()) return torch.stack(out) + def pooled_resid(self, n_prompt: int) -> torch.Tensor: + """[L_layers, d_model] mean residual-stream state over completion tokens.""" + return torch.stack([self.res[li][0, n_prompt:].float().mean(0).cpu() + for li in range(len(self.resid_modules))]) + def _gate_grads(wrappers: dict, names: list[str]) -> torch.Tensor: """[M, r] deployed-block c-probe grad after a backward (the gate's gradient space).""" @@ -209,7 +228,8 @@ def plot_q2(df: pl.DataFrame, k_mid: float, k_rout: float, subtitle: str, out_pn f"A>0-contrast={_auroc(a[posm].tolist(), y_all[posm].tolist()):.3f} " f"(n+={int(y_all.sum())} negA>0={int((~y_all & posm).sum())})") stats = {} - fig, axes = plt.subplots(2, 2, figsize=(12.5, 7.6)) + n_rows = len(CASES) // 2 + fig, axes = plt.subplots(n_rows, 2, figsize=(12.5, 3.6 * n_rows + 0.8)) for ax, (rep, kind) in zip(axes.flat, CASES): col = f"{rep}_{kind}" s = live[col].to_numpy() @@ -344,48 +364,50 @@ def main(cfg: Cfg) -> int: logger.info(f"loaded A/B into {len(names)} modules") model.eval() - def one_pass(tap: ActTap, prompt: str, completion: str) -> tuple[torch.Tensor, torch.Tensor] | None: - """Backward one completion's mean NLL; return ([M,r] c-grad, [M,r] pooled act).""" + def one_pass(tap: ActTap, prompt: str, completion: str): + """Backward one completion's mean NLL; return ([M,r] c-grad, [M,r] act, [L,d] resid).""" model.zero_grad(set_to_none=True) loss = completion_nll(model, tok, prompt, completion, device) if not torch.isfinite(loss): return None loss.backward() n_prompt = tok(prompt, return_tensors="pt").input_ids.shape[1] - return _gate_grads(wrappers, names), tap.pooled(n_prompt) + return _gate_grads(wrappers, names), tap.pooled(n_prompt), tap.pooled_resid(n_prompt) # ── authored-pair features, once over ALL pairs (subsets = row slices) ── pairs_all = load_pairs(cfg.pairs) logger.info(f"pairs {cfg.pairs} -> {len(pairs_all)}") - pair_feat = {("grad", "hack"): [], ("grad", "clean"): [], ("act", "hack"): [], ("act", "clean"): []} - with ActTap(wrappers, names) as tap: + pair_feat = {(rep, side): [] for rep in ("grad", "act", "resid") for side in ("hack", "clean")} + resid_modules = [model.model.layers[i] for i in cfg.resid_layers] + with ActTap(wrappers, names, resid_modules) as tap: for pi, pair in enumerate(pairs_all): for side, completion in (("hack", pair.hack), ("clean", pair.clean)): out = one_pass(tap, pair.prompt, completion) if out is None: raise RuntimeError(f"non-finite loss on pair {pi} ({pair.problem_id}) side={side}") - pair_feat[("grad", side)].append(out[0]) - pair_feat[("act", side)].append(out[1]) + for rep, feat in zip(("grad", "act", "resid"), out): + pair_feat[(rep, side)].append(feat) if (pi + 1) % 5 == 0: logger.info(f" pair {pi+1}/{len(pairs_all)}") - PF = {k: torch.stack(v) for k, v in pair_feat.items()} # each [P, M, r] + PF = {k: torch.stack(v) for k, v in pair_feat.items()} # [P, M, r] / resid [P, L, d] # ── live rollout features, once (everything downstream re-projects) ── recs = [json.loads(l) for l in (cfg.run_dir / "rollouts.jsonl").read_text().splitlines()] batch = [x for x in recs if cfg.step_lo <= x["step"] <= cfg.step_hi and x["text"].strip()][:cfg.max_rollouts] logger.info(f"live batch: {len(batch)} rollouts (steps {cfg.step_lo}-{cfg.step_hi})") - G_rows, A_rows, kept = [], [], [] + G_rows, A_rows, R_rows, kept = [], [], [], [] for i, rec in enumerate(batch): out = one_pass(tap, rec["prompt"], rec["text"]) if out is None: logger.warning(f" skip rollout {i}: non-finite loss") continue - G_rows.append(out[0]); A_rows.append(out[1]); kept.append(rec) + G_rows.append(out[0]); A_rows.append(out[1]); R_rows.append(out[2]); kept.append(rec) if (i + 1) % 40 == 0: logger.info(f" rollout {i+1}/{len(batch)}") model.zero_grad(set_to_none=True) G = torch.stack(G_rows) # [N, M, r] gradNLL ACT = torch.stack(A_rows) # [N, M, r] + RES = torch.stack(R_rows) # [N, L, d_model] exploited = np.array([bool(x["exploited"]) for x in kept]) gt_pass = np.array([bool(x["gt_pass"]) for x in kept]) steps = np.array([x["step"] for x in kept]) @@ -422,25 +444,22 @@ def main(cfg: Cfg) -> int: assert head_idx, f"no pairs match headline prefix {cfg.headline_prefix!r}" logger.info(f"headline v from prefix {cfg.headline_prefix!r} -> {len(head_idx)} pairs") + REPS = ("grad", "act", "resid") + def vectors(idx: list[int]) -> dict[str, torch.Tensor]: - v = {"grad": _v_from(PF[("grad", "hack")], PF[("grad", "clean")], idx), - "act": _v_from(PF[("act", "hack")], PF[("act", "clean")], idx)} + v = {rep: _v_from(PF[(rep, "hack")], PF[(rep, "clean")], idx) for rep in REPS} if cfg.random_v_seed is not None: - v = {"grad": _haar_like(v["grad"], cfg.random_v_seed), - "act": _haar_like(v["act"], cfg.random_v_seed + 1)} + v = {rep: _haar_like(v[rep], cfg.random_v_seed + i) for i, rep in enumerate(REPS)} return v v_head = vectors(head_idx) - live_X = {"grad": G_adv, "act": ACT} - syn_X = {("grad", "clean"): PF[("grad", "clean")], ("grad", "hack"): PF[("grad", "hack")], - ("act", "clean"): PF[("act", "clean")], ("act", "hack"): PF[("act", "hack")]} + live_X = {"grad": G_adv, "act": ACT, "resid": RES} def score_cols(v: dict, X: dict[str, torch.Tensor]) -> dict[str, np.ndarray]: return {f"{rep}_{kind}": _score(X[rep], v[rep], kind) for rep, kind in CASES} live_scores = score_cols(v_head, live_X) - syn_scores = {side: score_cols(v_head, {"grad": syn_X[("grad", side)][head_idx], - "act": syn_X[("act", side)][head_idx]}) + syn_scores = {side: score_cols(v_head, {rep: PF[(rep, side)][head_idx] for rep in REPS}) for side in ("clean", "hack")} # ── pairset table: subsets x 4 AUROCs on the SAME cached live features ── @@ -480,9 +499,9 @@ def main(cfg: Cfg) -> int: np.full(n_syn, -1), np.ones(n_syn)) for side in ("clean", "hack")] df = pl.concat(dfs) df.write_parquet(data_path) - torch.save({"G": G, "ACT": ACT, "adv": adv, "exploited": exploited, "gt_pass": gt_pass, - "steps": steps, "p_idx": p_idx, "names": names, - "pair_feats": PF, "pair_groups": dict(groups), + torch.save({"G": G, "ACT": ACT, "RES": RES, "adv": adv, "exploited": exploited, + "gt_pass": gt_pass, "steps": steps, "p_idx": p_idx, "names": names, + "resid_layers": cfg.resid_layers, "pair_feats": PF, "pair_groups": dict(groups), "pair_ids": [p.problem_id for p in pairs_all]}, feats_path) logger.info(f"wrote {data_path} ({len(df)} rows), {feats_path}")