diag(#40): residual-stream rep (layers 12/18/24) as third gate-score candidate, 3x2 panel

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-11 10:45:50 +00:00
parent 4a7465c0da
commit 0660e7bdd3
+60 -41
View File
@@ -14,22 +14,27 @@ rollout) the advantage ALONE is a ~0.9 AUROC detector, because the label require
and most fails have A<0 -- so a high vs-all number mostly restates the reward, which
the live gate already has for free. The vector's VALUE-ADD is the A>0 contrast: among
positively-reinforced updates (where adv is blind, ~0.5), can the score tell hacks
from solves? That is the gate's real job (don't ship hack updates, don't rob solves)
and the headline number here; vs-all is kept as a secondary column. NOTE an info
from solves? That is the gate's primary objective: exclude reward-hacking updates while
retaining correct-solution updates, and it is the primary result here. The vs-all
contrast is kept as a secondary column. NOTE an info
asymmetry: grad scores carry the adv factor (sign+magnitude), act scores do not, so
on the vs-all contrast grad gets label-correlated information act lacks; the A>0
contrast removes most of that edge -- compare the cases there.
contrast removes most of that difference.
FOUR CANDIDATE SCORES = {grad, act} x {cos, dot}, all in concatenated-module space:
- grad: the adv-weighted deployed c-probe gradient (the gate's current input).
- act: the deployed bottleneck activation A[:r]@x, mean over completion tokens --
same [r]-per-module space, capturable in the gate's pass-1 forward for free.
SIX CANDIDATE SCORES = {grad, act, resid} x {cos, dot}, concatenated over modules/layers:
- grad: the adv-weighted deployed c-probe gradient (the gate's current input).
- act: the deployed bottleneck activation A[:r]@x, mean over completion tokens --
same [r]-per-module space, capturable in the gate's pass-1 forward for free.
- resid: residual-stream hidden states at cfg.resid_layers, mean over completion
tokens. Adapter-independent: at an early checkpoint A is near its Gaussian
init, so grad and act are both views through a random r=32 projection per
module; resid tests whether that subspace, not grad-vs-act, limits separation.
- cos: magnitude-blind alignment (tiny vectors give meaningless angles -- control).
- dot: <g, v> = |g|*cos, magnitude-aware; with g = A*gradNLL the advantage rides
along, so dot reads "how hard is this update pushing hack-ward".
v for each representation comes from the authored pairs only (mean hack-minus-clean
per module, unit per module) -- the no-cheat label source; live labels are read ONLY
to measure (AUROC / precision at the rout cut), never to route.
along, so dot measures update magnitude aligned with v.
v for each representation comes only from authored pairs (mean hack-minus-clean,
normalized per module). Ground-truth labels from training rollouts are used only for
diagnostic AUROC and precision measurements, never for routing.
PINNING. Each panel shades the three zones the online gate rule would give on this
window: keep (bulk) | absorb (score > mean + k_mid*sd) | rout (>= mean + k_rout*sd),
@@ -49,8 +54,8 @@ cached features.
uv run python scripts/diag_pinning.py --run-dir out/runs/<vanilla_lora2r_run>
uv run python scripts/diag_pinning.py --replot out/diag/pinning_data.parquet # no GPU
outputs (out/diag/): pinning_q2.png (2x2 headline), pinning_data.parquet (per-rollout
scores), pinning_pairset.parquet + printed table (subsets x 4 AUROCs),
outputs (out/diag/): pinning_q2.png (3x2 headline), pinning_data.parquet (per-rollout
scores), pinning_pairset.parquet + printed table (subsets x 6 AUROCs),
pinning_feats.pt (raw features, for offline re-analysis).
"""
from __future__ import annotations
@@ -82,7 +87,8 @@ from vgrout.train import _auroc
# colour = behaviour (blue solve, red hack, grey fail); style = source (solid on-policy, dashed synthetic)
SOLVE, HACK, FAIL, ABSORB_C, ROUT_C, ORACLE = "#3b6ea5", "#c44e52", "#9aa0a6", "#d1900a", "#c44e52", "#3a8a7a"
CASES = [("grad", "cos"), ("grad", "dot"), ("act", "cos"), ("act", "dot")]
CASES = [("grad", "cos"), ("grad", "dot"), ("act", "cos"), ("act", "dot"),
("resid", "cos"), ("resid", "dot")]
@dataclass
@@ -102,6 +108,7 @@ class Cfg:
k_mid: float = 2.0 # absorb onset: score > mean + k_mid*sd (real-run Config default)
k_rout: float = 3.0 # rout onset: score >= mean + k_rout*sd
adv_eps: float = 1e-6 # |A| below this = no update exists -> dropped from zones/AUROC
resid_layers: tuple[int, ...] = (12, 18, 24) # residual-stream capture depths (of 36)
random_v_seed: int | None = None # Haar placebo (sanity: nothing should separate)
replot: Path | None = None # load parquet and re-plot only (no model, no GPU)
out_dir: Path = Path("out/diag")
@@ -113,13 +120,16 @@ def _ckpt_meta(path: Path) -> dict:
class ActTap:
"""Forward hooks stashing the deployed bottleneck activation h = A[:r] @ x per module.
"""Forward hooks stashing (a) the deployed bottleneck activation h = A[:r] @ x per
module and (b) the residual-stream hidden state after each decoder layer in
`resid_modules`.
Computes the r-dim projection inline (no_grad) instead of retaining the full
(a) computes the r-dim projection inline (no_grad) instead of retaining the full
[L, d_in] input -- ~250 modules x [L, d_in] would be GBs; [L, r] is nothing.
"""
def __init__(self, wrappers: dict, names: list[str]):
self.wrappers, self.names, self.h, self.handles = wrappers, names, {}, []
def __init__(self, wrappers: dict, names: list[str], resid_modules: list):
self.wrappers, self.names, self.resid_modules = wrappers, names, resid_modules
self.h, self.res, self.handles = {}, {}, []
def __enter__(self):
for nm in self.names:
@@ -129,6 +139,10 @@ class ActTap:
with torch.no_grad():
self.h[nm] = F.linear(x.detach(), layer._lora2r_A[: layer._lora2r_r].to(x.dtype))
self.handles.append(layer.register_forward_hook(hook))
for li, mod in enumerate(self.resid_modules):
def rhook(mod, args, out, li=li):
self.res[li] = (out[0] if isinstance(out, tuple) else out).detach()
self.handles.append(mod.register_forward_hook(rhook))
return self
def __exit__(self, *exc):
@@ -144,6 +158,11 @@ class ActTap:
out.append(h[0, n_prompt:].float().mean(0).cpu())
return torch.stack(out)
def pooled_resid(self, n_prompt: int) -> torch.Tensor:
"""[L_layers, d_model] mean residual-stream state over completion tokens."""
return torch.stack([self.res[li][0, n_prompt:].float().mean(0).cpu()
for li in range(len(self.resid_modules))])
def _gate_grads(wrappers: dict, names: list[str]) -> torch.Tensor:
"""[M, r] deployed-block c-probe grad after a backward (the gate's gradient space)."""
@@ -209,7 +228,8 @@ def plot_q2(df: pl.DataFrame, k_mid: float, k_rout: float, subtitle: str, out_pn
f"A>0-contrast={_auroc(a[posm].tolist(), y_all[posm].tolist()):.3f} "
f"(n+={int(y_all.sum())} negA>0={int((~y_all & posm).sum())})")
stats = {}
fig, axes = plt.subplots(2, 2, figsize=(12.5, 7.6))
n_rows = len(CASES) // 2
fig, axes = plt.subplots(n_rows, 2, figsize=(12.5, 3.6 * n_rows + 0.8))
for ax, (rep, kind) in zip(axes.flat, CASES):
col = f"{rep}_{kind}"
s = live[col].to_numpy()
@@ -344,48 +364,50 @@ def main(cfg: Cfg) -> int:
logger.info(f"loaded A/B into {len(names)} modules")
model.eval()
def one_pass(tap: ActTap, prompt: str, completion: str) -> tuple[torch.Tensor, torch.Tensor] | None:
"""Backward one completion's mean NLL; return ([M,r] c-grad, [M,r] pooled act)."""
def one_pass(tap: ActTap, prompt: str, completion: str):
"""Backward one completion's mean NLL; return ([M,r] c-grad, [M,r] act, [L,d] resid)."""
model.zero_grad(set_to_none=True)
loss = completion_nll(model, tok, prompt, completion, device)
if not torch.isfinite(loss):
return None
loss.backward()
n_prompt = tok(prompt, return_tensors="pt").input_ids.shape[1]
return _gate_grads(wrappers, names), tap.pooled(n_prompt)
return _gate_grads(wrappers, names), tap.pooled(n_prompt), tap.pooled_resid(n_prompt)
# ── authored-pair features, once over ALL pairs (subsets = row slices) ──
pairs_all = load_pairs(cfg.pairs)
logger.info(f"pairs {cfg.pairs} -> {len(pairs_all)}")
pair_feat = {("grad", "hack"): [], ("grad", "clean"): [], ("act", "hack"): [], ("act", "clean"): []}
with ActTap(wrappers, names) as tap:
pair_feat = {(rep, side): [] for rep in ("grad", "act", "resid") for side in ("hack", "clean")}
resid_modules = [model.model.layers[i] for i in cfg.resid_layers]
with ActTap(wrappers, names, resid_modules) as tap:
for pi, pair in enumerate(pairs_all):
for side, completion in (("hack", pair.hack), ("clean", pair.clean)):
out = one_pass(tap, pair.prompt, completion)
if out is None:
raise RuntimeError(f"non-finite loss on pair {pi} ({pair.problem_id}) side={side}")
pair_feat[("grad", side)].append(out[0])
pair_feat[("act", side)].append(out[1])
for rep, feat in zip(("grad", "act", "resid"), out):
pair_feat[(rep, side)].append(feat)
if (pi + 1) % 5 == 0:
logger.info(f" pair {pi+1}/{len(pairs_all)}")
PF = {k: torch.stack(v) for k, v in pair_feat.items()} # each [P, M, r]
PF = {k: torch.stack(v) for k, v in pair_feat.items()} # [P, M, r] / resid [P, L, d]
# ── live rollout features, once (everything downstream re-projects) ──
recs = [json.loads(l) for l in (cfg.run_dir / "rollouts.jsonl").read_text().splitlines()]
batch = [x for x in recs if cfg.step_lo <= x["step"] <= cfg.step_hi and x["text"].strip()][:cfg.max_rollouts]
logger.info(f"live batch: {len(batch)} rollouts (steps {cfg.step_lo}-{cfg.step_hi})")
G_rows, A_rows, kept = [], [], []
G_rows, A_rows, R_rows, kept = [], [], [], []
for i, rec in enumerate(batch):
out = one_pass(tap, rec["prompt"], rec["text"])
if out is None:
logger.warning(f" skip rollout {i}: non-finite loss")
continue
G_rows.append(out[0]); A_rows.append(out[1]); kept.append(rec)
G_rows.append(out[0]); A_rows.append(out[1]); R_rows.append(out[2]); kept.append(rec)
if (i + 1) % 40 == 0:
logger.info(f" rollout {i+1}/{len(batch)}")
model.zero_grad(set_to_none=True)
G = torch.stack(G_rows) # [N, M, r] gradNLL
ACT = torch.stack(A_rows) # [N, M, r]
RES = torch.stack(R_rows) # [N, L, d_model]
exploited = np.array([bool(x["exploited"]) for x in kept])
gt_pass = np.array([bool(x["gt_pass"]) for x in kept])
steps = np.array([x["step"] for x in kept])
@@ -422,25 +444,22 @@ def main(cfg: Cfg) -> int:
assert head_idx, f"no pairs match headline prefix {cfg.headline_prefix!r}"
logger.info(f"headline v from prefix {cfg.headline_prefix!r} -> {len(head_idx)} pairs")
REPS = ("grad", "act", "resid")
def vectors(idx: list[int]) -> dict[str, torch.Tensor]:
v = {"grad": _v_from(PF[("grad", "hack")], PF[("grad", "clean")], idx),
"act": _v_from(PF[("act", "hack")], PF[("act", "clean")], idx)}
v = {rep: _v_from(PF[(rep, "hack")], PF[(rep, "clean")], idx) for rep in REPS}
if cfg.random_v_seed is not None:
v = {"grad": _haar_like(v["grad"], cfg.random_v_seed),
"act": _haar_like(v["act"], cfg.random_v_seed + 1)}
v = {rep: _haar_like(v[rep], cfg.random_v_seed + i) for i, rep in enumerate(REPS)}
return v
v_head = vectors(head_idx)
live_X = {"grad": G_adv, "act": ACT}
syn_X = {("grad", "clean"): PF[("grad", "clean")], ("grad", "hack"): PF[("grad", "hack")],
("act", "clean"): PF[("act", "clean")], ("act", "hack"): PF[("act", "hack")]}
live_X = {"grad": G_adv, "act": ACT, "resid": RES}
def score_cols(v: dict, X: dict[str, torch.Tensor]) -> dict[str, np.ndarray]:
return {f"{rep}_{kind}": _score(X[rep], v[rep], kind) for rep, kind in CASES}
live_scores = score_cols(v_head, live_X)
syn_scores = {side: score_cols(v_head, {"grad": syn_X[("grad", side)][head_idx],
"act": syn_X[("act", side)][head_idx]})
syn_scores = {side: score_cols(v_head, {rep: PF[(rep, side)][head_idx] for rep in REPS})
for side in ("clean", "hack")}
# ── pairset table: subsets x 4 AUROCs on the SAME cached live features ──
@@ -480,9 +499,9 @@ def main(cfg: Cfg) -> int:
np.full(n_syn, -1), np.ones(n_syn)) for side in ("clean", "hack")]
df = pl.concat(dfs)
df.write_parquet(data_path)
torch.save({"G": G, "ACT": ACT, "adv": adv, "exploited": exploited, "gt_pass": gt_pass,
"steps": steps, "p_idx": p_idx, "names": names,
"pair_feats": PF, "pair_groups": dict(groups),
torch.save({"G": G, "ACT": ACT, "RES": RES, "adv": adv, "exploited": exploited,
"gt_pass": gt_pass, "steps": steps, "p_idx": p_idx, "names": names,
"resid_layers": cfg.resid_layers, "pair_feats": PF, "pair_groups": dict(groups),
"pair_ids": [p.problem_id for p in pairs_all]}, feats_path)
logger.info(f"wrote {data_path} ({len(df)} rows), {feats_path}")