mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 16:30:30 +08:00
diag(#40): residual-stream rep (layers 12/18/24) as third gate-score candidate, 3x2 panel
Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
+60
-41
@@ -14,22 +14,27 @@ rollout) the advantage ALONE is a ~0.9 AUROC detector, because the label require
|
||||
and most fails have A<0 -- so a high vs-all number mostly restates the reward, which
|
||||
the live gate already has for free. The vector's VALUE-ADD is the A>0 contrast: among
|
||||
positively-reinforced updates (where adv is blind, ~0.5), can the score tell hacks
|
||||
from solves? That is the gate's real job (don't ship hack updates, don't rob solves)
|
||||
and the headline number here; vs-all is kept as a secondary column. NOTE an info
|
||||
from solves? That is the gate's primary objective: exclude reward-hacking updates while
|
||||
retaining correct-solution updates, and it is the primary result here. The vs-all
|
||||
contrast is kept as a secondary column. NOTE an info
|
||||
asymmetry: grad scores carry the adv factor (sign+magnitude), act scores do not, so
|
||||
on the vs-all contrast grad gets label-correlated information act lacks; the A>0
|
||||
contrast removes most of that edge -- compare the cases there.
|
||||
contrast removes most of that difference.
|
||||
|
||||
FOUR CANDIDATE SCORES = {grad, act} x {cos, dot}, all in concatenated-module space:
|
||||
- grad: the adv-weighted deployed c-probe gradient (the gate's current input).
|
||||
- act: the deployed bottleneck activation A[:r]@x, mean over completion tokens --
|
||||
same [r]-per-module space, capturable in the gate's pass-1 forward for free.
|
||||
SIX CANDIDATE SCORES = {grad, act, resid} x {cos, dot}, concatenated over modules/layers:
|
||||
- grad: the adv-weighted deployed c-probe gradient (the gate's current input).
|
||||
- act: the deployed bottleneck activation A[:r]@x, mean over completion tokens --
|
||||
same [r]-per-module space, capturable in the gate's pass-1 forward for free.
|
||||
- resid: residual-stream hidden states at cfg.resid_layers, mean over completion
|
||||
tokens. Adapter-independent: at an early checkpoint A is near its Gaussian
|
||||
init, so grad and act are both views through a random r=32 projection per
|
||||
module; resid tests whether that subspace, not grad-vs-act, limits separation.
|
||||
- cos: magnitude-blind alignment (tiny vectors give meaningless angles -- control).
|
||||
- dot: <g, v> = |g|*cos, magnitude-aware; with g = A*gradNLL the advantage rides
|
||||
along, so dot reads "how hard is this update pushing hack-ward".
|
||||
v for each representation comes from the authored pairs only (mean hack-minus-clean
|
||||
per module, unit per module) -- the no-cheat label source; live labels are read ONLY
|
||||
to measure (AUROC / precision at the rout cut), never to route.
|
||||
along, so dot measures update magnitude aligned with v.
|
||||
v for each representation comes only from authored pairs (mean hack-minus-clean,
|
||||
normalized per module). Ground-truth labels from training rollouts are used only for
|
||||
diagnostic AUROC and precision measurements, never for routing.
|
||||
|
||||
PINNING. Each panel shades the three zones the online gate rule would give on this
|
||||
window: keep (bulk) | absorb (score > mean + k_mid*sd) | rout (>= mean + k_rout*sd),
|
||||
@@ -49,8 +54,8 @@ cached features.
|
||||
|
||||
uv run python scripts/diag_pinning.py --run-dir out/runs/<vanilla_lora2r_run>
|
||||
uv run python scripts/diag_pinning.py --replot out/diag/pinning_data.parquet # no GPU
|
||||
outputs (out/diag/): pinning_q2.png (2x2 headline), pinning_data.parquet (per-rollout
|
||||
scores), pinning_pairset.parquet + printed table (subsets x 4 AUROCs),
|
||||
outputs (out/diag/): pinning_q2.png (3x2 headline), pinning_data.parquet (per-rollout
|
||||
scores), pinning_pairset.parquet + printed table (subsets x 6 AUROCs),
|
||||
pinning_feats.pt (raw features, for offline re-analysis).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
@@ -82,7 +87,8 @@ from vgrout.train import _auroc
|
||||
|
||||
# colour = behaviour (blue solve, red hack, grey fail); style = source (solid on-policy, dashed synthetic)
|
||||
SOLVE, HACK, FAIL, ABSORB_C, ROUT_C, ORACLE = "#3b6ea5", "#c44e52", "#9aa0a6", "#d1900a", "#c44e52", "#3a8a7a"
|
||||
CASES = [("grad", "cos"), ("grad", "dot"), ("act", "cos"), ("act", "dot")]
|
||||
CASES = [("grad", "cos"), ("grad", "dot"), ("act", "cos"), ("act", "dot"),
|
||||
("resid", "cos"), ("resid", "dot")]
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -102,6 +108,7 @@ class Cfg:
|
||||
k_mid: float = 2.0 # absorb onset: score > mean + k_mid*sd (real-run Config default)
|
||||
k_rout: float = 3.0 # rout onset: score >= mean + k_rout*sd
|
||||
adv_eps: float = 1e-6 # |A| below this = no update exists -> dropped from zones/AUROC
|
||||
resid_layers: tuple[int, ...] = (12, 18, 24) # residual-stream capture depths (of 36)
|
||||
random_v_seed: int | None = None # Haar placebo (sanity: nothing should separate)
|
||||
replot: Path | None = None # load parquet and re-plot only (no model, no GPU)
|
||||
out_dir: Path = Path("out/diag")
|
||||
@@ -113,13 +120,16 @@ def _ckpt_meta(path: Path) -> dict:
|
||||
|
||||
|
||||
class ActTap:
|
||||
"""Forward hooks stashing the deployed bottleneck activation h = A[:r] @ x per module.
|
||||
"""Forward hooks stashing (a) the deployed bottleneck activation h = A[:r] @ x per
|
||||
module and (b) the residual-stream hidden state after each decoder layer in
|
||||
`resid_modules`.
|
||||
|
||||
Computes the r-dim projection inline (no_grad) instead of retaining the full
|
||||
(a) computes the r-dim projection inline (no_grad) instead of retaining the full
|
||||
[L, d_in] input -- ~250 modules x [L, d_in] would be GBs; [L, r] is nothing.
|
||||
"""
|
||||
def __init__(self, wrappers: dict, names: list[str]):
|
||||
self.wrappers, self.names, self.h, self.handles = wrappers, names, {}, []
|
||||
def __init__(self, wrappers: dict, names: list[str], resid_modules: list):
|
||||
self.wrappers, self.names, self.resid_modules = wrappers, names, resid_modules
|
||||
self.h, self.res, self.handles = {}, {}, []
|
||||
|
||||
def __enter__(self):
|
||||
for nm in self.names:
|
||||
@@ -129,6 +139,10 @@ class ActTap:
|
||||
with torch.no_grad():
|
||||
self.h[nm] = F.linear(x.detach(), layer._lora2r_A[: layer._lora2r_r].to(x.dtype))
|
||||
self.handles.append(layer.register_forward_hook(hook))
|
||||
for li, mod in enumerate(self.resid_modules):
|
||||
def rhook(mod, args, out, li=li):
|
||||
self.res[li] = (out[0] if isinstance(out, tuple) else out).detach()
|
||||
self.handles.append(mod.register_forward_hook(rhook))
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc):
|
||||
@@ -144,6 +158,11 @@ class ActTap:
|
||||
out.append(h[0, n_prompt:].float().mean(0).cpu())
|
||||
return torch.stack(out)
|
||||
|
||||
def pooled_resid(self, n_prompt: int) -> torch.Tensor:
|
||||
"""[L_layers, d_model] mean residual-stream state over completion tokens."""
|
||||
return torch.stack([self.res[li][0, n_prompt:].float().mean(0).cpu()
|
||||
for li in range(len(self.resid_modules))])
|
||||
|
||||
|
||||
def _gate_grads(wrappers: dict, names: list[str]) -> torch.Tensor:
|
||||
"""[M, r] deployed-block c-probe grad after a backward (the gate's gradient space)."""
|
||||
@@ -209,7 +228,8 @@ def plot_q2(df: pl.DataFrame, k_mid: float, k_rout: float, subtitle: str, out_pn
|
||||
f"A>0-contrast={_auroc(a[posm].tolist(), y_all[posm].tolist()):.3f} "
|
||||
f"(n+={int(y_all.sum())} negA>0={int((~y_all & posm).sum())})")
|
||||
stats = {}
|
||||
fig, axes = plt.subplots(2, 2, figsize=(12.5, 7.6))
|
||||
n_rows = len(CASES) // 2
|
||||
fig, axes = plt.subplots(n_rows, 2, figsize=(12.5, 3.6 * n_rows + 0.8))
|
||||
for ax, (rep, kind) in zip(axes.flat, CASES):
|
||||
col = f"{rep}_{kind}"
|
||||
s = live[col].to_numpy()
|
||||
@@ -344,48 +364,50 @@ def main(cfg: Cfg) -> int:
|
||||
logger.info(f"loaded A/B into {len(names)} modules")
|
||||
model.eval()
|
||||
|
||||
def one_pass(tap: ActTap, prompt: str, completion: str) -> tuple[torch.Tensor, torch.Tensor] | None:
|
||||
"""Backward one completion's mean NLL; return ([M,r] c-grad, [M,r] pooled act)."""
|
||||
def one_pass(tap: ActTap, prompt: str, completion: str):
|
||||
"""Backward one completion's mean NLL; return ([M,r] c-grad, [M,r] act, [L,d] resid), or None if the loss is non-finite."""
|
||||
model.zero_grad(set_to_none=True)
|
||||
loss = completion_nll(model, tok, prompt, completion, device)
|
||||
if not torch.isfinite(loss):
|
||||
return None
|
||||
loss.backward()
|
||||
n_prompt = tok(prompt, return_tensors="pt").input_ids.shape[1]
|
||||
return _gate_grads(wrappers, names), tap.pooled(n_prompt)
|
||||
return _gate_grads(wrappers, names), tap.pooled(n_prompt), tap.pooled_resid(n_prompt)
|
||||
|
||||
# ── authored-pair features, once over ALL pairs (subsets = row slices) ──
|
||||
pairs_all = load_pairs(cfg.pairs)
|
||||
logger.info(f"pairs {cfg.pairs} -> {len(pairs_all)}")
|
||||
pair_feat = {("grad", "hack"): [], ("grad", "clean"): [], ("act", "hack"): [], ("act", "clean"): []}
|
||||
with ActTap(wrappers, names) as tap:
|
||||
pair_feat = {(rep, side): [] for rep in ("grad", "act", "resid") for side in ("hack", "clean")}
|
||||
resid_modules = [model.model.layers[i] for i in cfg.resid_layers]
|
||||
with ActTap(wrappers, names, resid_modules) as tap:
|
||||
for pi, pair in enumerate(pairs_all):
|
||||
for side, completion in (("hack", pair.hack), ("clean", pair.clean)):
|
||||
out = one_pass(tap, pair.prompt, completion)
|
||||
if out is None:
|
||||
raise RuntimeError(f"non-finite loss on pair {pi} ({pair.problem_id}) side={side}")
|
||||
pair_feat[("grad", side)].append(out[0])
|
||||
pair_feat[("act", side)].append(out[1])
|
||||
for rep, feat in zip(("grad", "act", "resid"), out):
|
||||
pair_feat[(rep, side)].append(feat)
|
||||
if (pi + 1) % 5 == 0:
|
||||
logger.info(f" pair {pi+1}/{len(pairs_all)}")
|
||||
PF = {k: torch.stack(v) for k, v in pair_feat.items()} # each [P, M, r]
|
||||
PF = {k: torch.stack(v) for k, v in pair_feat.items()} # [P, M, r] / resid [P, L, d]
|
||||
|
||||
# ── live rollout features, once (everything downstream re-projects) ──
|
||||
recs = [json.loads(l) for l in (cfg.run_dir / "rollouts.jsonl").read_text().splitlines()]
|
||||
batch = [x for x in recs if cfg.step_lo <= x["step"] <= cfg.step_hi and x["text"].strip()][:cfg.max_rollouts]
|
||||
logger.info(f"live batch: {len(batch)} rollouts (steps {cfg.step_lo}-{cfg.step_hi})")
|
||||
G_rows, A_rows, kept = [], [], []
|
||||
G_rows, A_rows, R_rows, kept = [], [], [], []
|
||||
for i, rec in enumerate(batch):
|
||||
out = one_pass(tap, rec["prompt"], rec["text"])
|
||||
if out is None:
|
||||
logger.warning(f" skip rollout {i}: non-finite loss")
|
||||
continue
|
||||
G_rows.append(out[0]); A_rows.append(out[1]); kept.append(rec)
|
||||
G_rows.append(out[0]); A_rows.append(out[1]); R_rows.append(out[2]); kept.append(rec)
|
||||
if (i + 1) % 40 == 0:
|
||||
logger.info(f" rollout {i+1}/{len(batch)}")
|
||||
model.zero_grad(set_to_none=True)
|
||||
G = torch.stack(G_rows) # [N, M, r] gradNLL
|
||||
ACT = torch.stack(A_rows) # [N, M, r]
|
||||
RES = torch.stack(R_rows) # [N, L, d_model]
|
||||
exploited = np.array([bool(x["exploited"]) for x in kept])
|
||||
gt_pass = np.array([bool(x["gt_pass"]) for x in kept])
|
||||
steps = np.array([x["step"] for x in kept])
|
||||
@@ -422,25 +444,22 @@ def main(cfg: Cfg) -> int:
|
||||
assert head_idx, f"no pairs match headline prefix {cfg.headline_prefix!r}"
|
||||
logger.info(f"headline v from prefix {cfg.headline_prefix!r} -> {len(head_idx)} pairs")
|
||||
|
||||
REPS = ("grad", "act", "resid")
|
||||
|
||||
def vectors(idx: list[int]) -> dict[str, torch.Tensor]:
|
||||
v = {"grad": _v_from(PF[("grad", "hack")], PF[("grad", "clean")], idx),
|
||||
"act": _v_from(PF[("act", "hack")], PF[("act", "clean")], idx)}
|
||||
v = {rep: _v_from(PF[(rep, "hack")], PF[(rep, "clean")], idx) for rep in REPS}
|
||||
if cfg.random_v_seed is not None:
|
||||
v = {"grad": _haar_like(v["grad"], cfg.random_v_seed),
|
||||
"act": _haar_like(v["act"], cfg.random_v_seed + 1)}
|
||||
v = {rep: _haar_like(v[rep], cfg.random_v_seed + i) for i, rep in enumerate(REPS)}
|
||||
return v
|
||||
|
||||
v_head = vectors(head_idx)
|
||||
live_X = {"grad": G_adv, "act": ACT}
|
||||
syn_X = {("grad", "clean"): PF[("grad", "clean")], ("grad", "hack"): PF[("grad", "hack")],
|
||||
("act", "clean"): PF[("act", "clean")], ("act", "hack"): PF[("act", "hack")]}
|
||||
live_X = {"grad": G_adv, "act": ACT, "resid": RES}
|
||||
|
||||
def score_cols(v: dict, X: dict[str, torch.Tensor]) -> dict[str, np.ndarray]:
|
||||
return {f"{rep}_{kind}": _score(X[rep], v[rep], kind) for rep, kind in CASES}
|
||||
|
||||
live_scores = score_cols(v_head, live_X)
|
||||
syn_scores = {side: score_cols(v_head, {"grad": syn_X[("grad", side)][head_idx],
|
||||
"act": syn_X[("act", side)][head_idx]})
|
||||
syn_scores = {side: score_cols(v_head, {rep: PF[(rep, side)][head_idx] for rep in REPS})
|
||||
for side in ("clean", "hack")}
|
||||
|
||||
# ── pairset table: subsets x 6 AUROCs on the SAME cached live features ──
|
||||
@@ -480,9 +499,9 @@ def main(cfg: Cfg) -> int:
|
||||
np.full(n_syn, -1), np.ones(n_syn)) for side in ("clean", "hack")]
|
||||
df = pl.concat(dfs)
|
||||
df.write_parquet(data_path)
|
||||
torch.save({"G": G, "ACT": ACT, "adv": adv, "exploited": exploited, "gt_pass": gt_pass,
|
||||
"steps": steps, "p_idx": p_idx, "names": names,
|
||||
"pair_feats": PF, "pair_groups": dict(groups),
|
||||
torch.save({"G": G, "ACT": ACT, "RES": RES, "adv": adv, "exploited": exploited,
|
||||
"gt_pass": gt_pass, "steps": steps, "p_idx": p_idx, "names": names,
|
||||
"resid_layers": cfg.resid_layers, "pair_feats": PF, "pair_groups": dict(groups),
|
||||
"pair_ids": [p.problem_id for p in pairs_all]}, feats_path)
|
||||
logger.info(f"wrote {data_path} ({len(df)} rows), {feats_path}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user