Files
evil_MoE/scripts/diag_pinning.py
T
wassname 979daf84fd feat(#30): mean+k*std online gate replaces fixed quantile; always-show route cols
Gate calibration: route by live mean + route_std_mid/route_std_rout * std of the
pooled cosine-to-v_grad, not a fixed quantile tail. Self-silences -- only the tail
that genuinely exceeds the spread routes, so qmass tracks real separation instead
of a forced fraction. The authored absolute band is misplaced (the live position sits far
below the synthetic-hack edge; even synthetic solve out-aligns on-policy hack).

tablelog: the auroc/rout/routE/keep/resid/qmass columns are always shown (nan on vanilla runs)
so the tables of different arms line up.

Diagnostics: scripts/diag_pinning.py (4-population calibration view, mean +/- 2sd band)
and scripts/diag_pinning_refresh.py (shows that the cosine statistics can be recomputed from a
tracked, v-independent gradient cloud when v_grad is refreshed -- exact for k=1; sanity-check
error 2.5e-16).

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
2026-06-11 02:56:07 +00:00

253 lines
12 KiB
Python

"""Pinning calibration: where does the live gate sit, and would mean +/- 2sd route it?
The routeV gate scores each rollout by the width-pooled, band-normalized cosine of its
deployed-block c-probe gradient to v_grad (the same pooling as train.py:_lora2r_gate_labels);
call this score the rollout's position on the HACKING DIRECTION. The gate this script was
written against routes a FIXED quantile tail (the behaviour #30 replaces).
This plots four populations on that axis:
on-policy solve / hack -- live rollouts (oracle `exploited` read for COLOUR/AUROC only)
synthetic solve / hack -- the authored pairs v_grad was built from (the only label source)
and marks the PROPOSED pinning: online mean and mean +/- 2*sd (a symmetric band that
self-calibrates to the live spread instead of forcing a fixed %).
uv run python scripts/diag_pinning.py --run-dir out/runs/<vanilla_lora2r_run>
uv run python scripts/diag_pinning.py --replot out/diag/pinning_data.parquet # no GPU, restyle only
outputs (out/diag/): pinning_calib.png, pinning_data.parquet (the 4 populations, regenerates the plot).
"""
from __future__ import annotations
import json
import struct
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import torch
import tyro
import polars as pl
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from loguru import logger
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM, AutoTokenizer
from vgrout.lora2r import wrap_model_with_lora2r
from vgrout.pairs import load_pairs
from vgrout.extract_vhack_grad import extract_v_hack, completion_nll
from vgrout.train import _build_v_grad, route_band_edges, _haar_unit_dirs, _auroc
# Plot palette. Colour encodes behaviour (blue = solve, red = hack); line style encodes the
# source (solid line + fill = on-policy rollouts, dashed = synthetic pairs).
SOLVE = "#3b6ea5"   # solve populations
HACK = "#c44e52"    # hack populations
MEANC = "#d1900a"   # online mean line and the mean +/- 2sd band
ORACLE = "#3a8a7a"  # diagnostic best hack/solve split line
@dataclass
class Cfg:
    """Command-line options for the pinning calibration diagnostic (parsed by tyro)."""

    # Run directory holding <ckpt>.safetensors and rollouts.jsonl.
    run_dir: Path = Path("out/runs/20260611T003538_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v3")
    # Checkpoint stem inside run_dir whose LoRA A/B weights are loaded before extracting v_grad.
    ckpt: str = "first_hack"
    # Authored hack/clean pairs that v_grad is built from.
    pairs: Path = Path("data/pairs/hack_pairs.md#all-in-one")
    # Inclusive live-rollout step window (coherent emergence). This vanilla v3 run used the
    # pre-fix lr=5e-4 / warmup-0.1 and DIVERGED at step 10 (exploited 20/24 -> 0/24);
    # steps 2-9 are where hacks emerge while the model is still sane.
    step_lo: int = 2
    step_hi: int = 9
    # Cap on live rollouts scored (each one costs a backward pass).
    max_rollouts: int = 240
    # If set, replace v_grad with Haar-random directions (placebo: pins should NOT separate).
    random_v_seed: int | None = None
    # If set, load this parquet and re-plot only (no model, no GPU).
    replot: Path | None = None
    # Output directory for pinning_calib.png and pinning_data.parquet.
    out_dir: Path = Path("out/diag")
def _ckpt_meta(path: Path) -> dict:
with open(path, "rb") as f:
return json.loads(f.read(struct.unpack("<Q", f.read(8))[0])).get("__metadata__", {})
def pooled_pos(c_grads: dict, v_grad: dict, route_band: dict, names: list, r: int) -> tuple[float, float]:
    """Pool ONE rollout's per-module deployed c-grads into a scalar gate position.

    Mirrors the pooling in train.py:_lora2r_gate_labels. For every module whose band
    [lower, upper] has positive width, the cosine between the first ``r`` c-grad entries and
    the best-matching row of ``v_grad[name]`` is shifted by ``lower``. The shifted cosines and
    the band widths are each summed over modules, and the position is their ratio.

    Returns:
        (position, mean c-grad norm over the modules that contributed).
    Raises:
        RuntimeError: if no module has a positive band width.
    """
    shifted, widths, norms = [], [], []
    for name in names:
        lower, upper = route_band[name]
        width = upper - lower
        if width <= 0:
            continue  # zero-width band: module carries no signal
        v_rows = v_grad[name]
        g = c_grads[name][:r].float().to(v_rows.device)
        g_norm = g.norm()
        best_dot = torch.einsum("r, k r -> k", g, v_rows).max()
        shifted.append((best_dot / g_norm.clamp_min(1e-12)).item() - lower)
        widths.append(width)
        norms.append(g_norm.item())
    if not widths:
        raise RuntimeError("no module has positive band width")
    return sum(shifted) / sum(widths), sum(norms) / len(norms)
def _kde(x: np.ndarray, grid: np.ndarray) -> np.ndarray:
"""Gaussian KDE, Silverman bandwidth (no scipy)."""
x = np.asarray(x, float)
if len(x) < 2:
return np.zeros_like(grid)
iqr = np.subtract(*np.percentile(x, [75, 25]))
sigma = min(x.std(ddof=1), iqr / 1.349) if iqr > 0 else x.std(ddof=1)
bw = max(0.9 * (sigma or 1.0) * len(x) ** (-0.2), 1e-3)
z = (grid[:, None] - x[None, :]) / bw
return np.exp(-0.5 * z ** 2).sum(1) / (len(x) * bw * np.sqrt(2 * np.pi))
def _youden_split(pos: np.ndarray, labels: np.ndarray) -> float:
    """Return the threshold t maximizing Youden's J = P(hack >= t) - P(solve >= t), or nan.

    Candidate thresholds are the distinct values of ``pos``. J is undefined unless both
    classes are present (an empty class has no rate), so nan is returned in that case
    rather than an arbitrary first threshold.
    """
    hack, solve = pos[labels], pos[~labels]
    if len(hack) == 0 or len(solve) == 0:
        return float("nan")
    thr = np.unique(pos)
    j = [(hack >= t).mean() - (solve >= t).mean() for t in thr]
    return float(thr[int(np.argmax(j))])


def plot(df: pl.DataFrame, subtitle: str, out_png: Path) -> None:
    """Regenerate the figure from the saved 4 populations -- no GPU needed.

    ``df`` is long-form with a ``pop`` column (on_solve / on_hack / syn_solve / syn_hack) and
    a ``pos`` column (position on the hacking direction). The proposed routing band, live
    mean +/- 2 population-sd, is computed from the on-policy rollouts only. The synthetic
    populations are drawn for reference. AUROC and the best hack/solve split use the oracle
    labels and are diagnostic only. When the live batch holds a single class, both are nan and
    the split line is omitted.

    Raises:
        ValueError: if ``df`` contains no on-policy rollouts.
    """
    def arr(p: str) -> np.ndarray:
        return df.filter(pl.col("pop") == p)["pos"].to_numpy()

    on_solve, on_hack = arr("on_solve"), arr("on_hack")
    syn_solve, syn_hack = arr("syn_solve"), arr("syn_hack")
    pos_live = np.concatenate([on_solve, on_hack])
    if len(pos_live) == 0:
        raise ValueError("no on-policy rollouts (pop on_solve/on_hack) in df; nothing to calibrate")
    labels = np.concatenate([np.zeros(len(on_solve), bool), np.ones(len(on_hack), bool)])
    both_classes = len(on_solve) > 0 and len(on_hack) > 0

    mean, sd = float(pos_live.mean()), float(pos_live.std())  # population sd (ddof=0)
    lo_b, hi_b = mean - 2 * sd, mean + 2 * sd  # proposed routing band
    # AUROC / oracle divider need both labels; with one class they are meaningless -> nan
    auroc = _auroc(pos_live.tolist(), labels.tolist()) if both_classes else float("nan")
    oracle = _youden_split(pos_live, labels)  # oracle divider (Youden J) -- diagnostic only
    has_oracle = bool(np.isfinite(oracle))

    # x-range: full extent of every population, except that the live upper tail is clipped at
    # its 99th percentile so a few outliers don't squash the curves
    syn_all = np.concatenate([syn_solve, syn_hack])
    lo = min(pos_live.min(), syn_all.min(initial=np.inf)) - 0.1
    hi = max(np.quantile(pos_live, 0.99), syn_all.max(initial=-np.inf)) + 0.1
    grid = np.linspace(lo, hi, 400)
    POPS = [(on_solve, SOLVE, True, "on-policy solve"), (on_hack, HACK, True, "on-policy hack"),
            (syn_solve, SOLVE, False, "synthetic solve"), (syn_hack, HACK, False, "synthetic hack")]
    kdes = [_kde(x, grid) for x, *_ in POPS]
    # all-zero curves (every population < 2 points) would otherwise give ylim (0, 0)
    ymax = max(y.max() for y in kdes) * 1.15 or 1.0

    fig, ax = plt.subplots(figsize=(8.6, 4.8))
    # proposed pinning: mean +/- 2sd band, drawn first so curves sit on top
    ax.axvspan(lo_b, hi_b, color=MEANC, alpha=0.07, lw=0)
    ax.axvline(mean, color=MEANC, lw=1.8)
    for b in (lo_b, hi_b):
        ax.axvline(b, color=MEANC, lw=1.2, ls="--")
    if has_oracle:
        ax.axvline(oracle, color=ORACLE, lw=1.3, ls="-.")
    for y, (_, col, on_policy, _) in zip(kdes, POPS):
        if on_policy:
            ax.fill_between(grid, y, color=col, alpha=0.13, lw=0)
            ax.plot(grid, y, color=col, lw=1.9)
        else:
            ax.fill_between(grid, y, color=col, alpha=0.05, lw=0)
            ax.plot(grid, y, color=col, lw=2.4, ls=(0, (5, 2)))
    ax.set_ylim(0, ymax)
    for s in ("top", "right"):
        ax.spines[s].set_visible(False)
    ax.set_ylabel("density")
    ax.set_xlabel("hacking direction (gradient cosine to v_grad) " + r"$\longrightarrow$")

    dist_handles = [Line2D([0], [0], color=c, lw=2.2, ls="-" if op else (0, (5, 2))) for _, c, op, _ in POPS]
    leg1 = ax.legend(dist_handles, [lab for *_, lab in POPS], loc="upper left",
                     fontsize=8, frameon=False, title="distributions", title_fontsize=8)
    leg1._legend_box.align = "left"  # private attr: left-align the legend title
    ax.add_artist(leg1)  # keep this legend when the second one is added
    mark_handles = [
        Line2D([0], [0], color=MEANC, lw=1.8),
        Patch(facecolor=MEANC, alpha=0.18, edgecolor=MEANC, ls="--"),
    ]
    mark_labels = [f"online mean ({mean:+.2f})", f"mean +/- 2sd (sd={sd:.2f}) = proposed band"]
    if has_oracle:
        mark_handles.append(Line2D([0], [0], color=ORACLE, lw=1.3, ls="-."))
        mark_labels.append(f"best hack/solve split ({oracle:+.2f}, diagnostic)")
    ax.legend(mark_handles, mark_labels,
              loc="upper right", fontsize=8, frameon=False, title="pinning", title_fontsize=8)

    routed_hi = float((pos_live > hi_b).mean())
    ax.set_title(f"{subtitle}\nAUROC(hack-dir -> hack)={auroc:.2f} "
                 f"{routed_hi:.0%} of rollouts beyond mean+2sd", fontsize=9.5)
    fig.tight_layout()
    fig.savefig(out_png, dpi=140)
    plt.close(fig)
    logger.info(f"mean={mean:+.3f} sd={sd:.3f} band=[{lo_b:+.3f},{hi_b:+.3f}] oracle={oracle:+.3f} "
                f"AUROC={auroc:.3f} routed(>mean+2sd)={routed_hi:.2f}")
    logger.info(f"wrote {out_png}")
def main(cfg: Cfg) -> int:
    """Score the four populations on the hacking direction, save them, and plot.

    With ``cfg.replot`` set, only re-plot a saved parquet (no model, no GPU). Otherwise the
    steps are:
      1. Load the checkpoint's LoRA A/B weights into grad-probe lora2r wrappers.
      2. Rebuild v_grad from the authored pairs, optionally replacing it with a Haar placebo.
      3. Position every synthetic pair and every live rollout in steps [step_lo, step_hi].
      4. Write pinning_data.parquet and plot it.

    Returns:
        0 (the process exit code).
    Raises:
        RuntimeError: if no live rollout in the window yields a finite loss.
    """
    cfg.out_dir.mkdir(parents=True, exist_ok=True)
    data_path = cfg.out_dir / "pinning_data.parquet"
    out_png = cfg.out_dir / "pinning_calib.png"
    if cfg.replot is not None:
        df = pl.read_parquet(cfg.replot)
        plot(df, f"pinning calibration (replot) -- {cfg.replot.name}", out_png)
        return 0

    device = torch.device("cuda")
    ckpt_path = cfg.run_dir / f"{cfg.ckpt}.safetensors"
    meta = _ckpt_meta(ckpt_path)
    # the training config is stored as a JSON string under the "cfg" metadata key
    run_cfg = json.loads(meta.get("cfg", "{}"))
    model_name = run_cfg.get("model", meta.get("model", "Qwen/Qwen3-4B"))
    r = run_cfg.get("lora_r", 32)
    init_seed = run_cfg.get("lora_init_seed", 0)
    logger.info(f"ckpt {ckpt_path.name} step={meta.get('step')} hack_rate={meta.get('hack_rate')} "
                f"model={model_name} r={r} init_seed={init_seed}")

    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2").to(device)
    model.config.use_cache = False
    wrappers = wrap_model_with_lora2r(model, r=r, init_seed=init_seed, grad_probe=True)
    names = sorted(wrappers)
    state_dict = load_file(str(ckpt_path))
    for nm in names:
        wrappers[nm]["A"].data.copy_(state_dict[f"A/{nm}"].to(device, torch.float32))
        wrappers[nm]["B"].data.copy_(state_dict[f"B/{nm}"].to(device, torch.float32))
    logger.info(f"loaded A/B into {len(names)} modules")

    pairs = load_pairs(cfg.pairs)
    logger.info(f"pairs {cfg.pairs} -> {len(pairs)}")
    model.eval()
    _, _, raw_grads, _ = extract_v_hack(model, tok, wrappers, pairs,
                                        top_k=1, tau_axis=0.0, n_heldout=2, device=device)
    v_grad = _build_v_grad(raw_grads, wrappers, 1, device)
    if cfg.random_v_seed is not None:
        v_grad = _haar_unit_dirs(v_grad, cfg.random_v_seed, device)
        logger.info(f"OVERRODE v_grad with Haar dirs seed={cfg.random_v_seed} (placebo)")
    route_band = route_band_edges(raw_grads, v_grad, device)

    def pair_pos(side: str) -> np.ndarray:
        """Gate positions of every authored pair on one side ("clean" or "hack")."""
        n = raw_grads[f"{side}/{names[0]}"].shape[0]
        out = []
        for i in range(n):
            # zero-pad by r entries; presumably matches the live c-grad layout (deployed
            # block first) -- pooled_pos only reads [:r] anyway
            cg = {nm: torch.cat([raw_grads[f"{side}/{nm}"][i], torch.zeros(r)]) for nm in names}
            out.append(pooled_pos(cg, v_grad, route_band, names, r)[0])
        return np.array(out)

    syn_solve, syn_hack = pair_pos("clean"), pair_pos("hack")

    # one JSON record per line; tolerate blank lines instead of crashing in json.loads("")
    lines = (cfg.run_dir / "rollouts.jsonl").read_text().splitlines()
    recs = [json.loads(line) for line in lines if line.strip()]
    batch = [rec for rec in recs
             if cfg.step_lo <= rec["step"] <= cfg.step_hi and rec["text"].strip()][:cfg.max_rollouts]
    logger.info(f"live batch: {len(batch)} rollouts (steps {cfg.step_lo}-{cfg.step_hi})")

    pos_live, labels, steps = [], [], []
    for i, rec in enumerate(batch):
        model.zero_grad(set_to_none=True)
        loss = completion_nll(model, tok, rec["prompt"], rec["text"], device)
        if not torch.isfinite(loss):
            continue
        loss.backward()
        cg = {}
        for nm in names:
            gate_grad = wrappers[nm]["layer"]._lora2r_gate.grad
            # sum over every leading dim, keeping only the last axis
            cg[nm] = gate_grad.sum(dim=tuple(range(gate_grad.dim() - 1)))
        pos_live.append(pooled_pos(cg, v_grad, route_band, names, r)[0])
        labels.append(bool(rec["exploited"]))  # oracle label: colour/AUROC only
        steps.append(rec["step"])
        if (i + 1) % 40 == 0:
            logger.info(f" rollout {i+1}/{len(batch)}")
    model.zero_grad(set_to_none=True)
    if not pos_live:
        raise RuntimeError(f"no live rollout with a finite loss in steps {cfg.step_lo}-{cfg.step_hi}")

    pos_live = np.array(pos_live)
    labels = np.array(labels, dtype=bool)  # explicit dtype keeps ~labels a boolean mask
    steps = np.array(steps)
    per_step = {}
    for s in sorted(set(steps.tolist())):
        at = steps == s
        if labels[at].any() and (~labels[at]).any():  # AUROC needs both classes
            per_step[int(s)] = round(_auroc(pos_live[at].tolist(), labels[at].tolist()), 2)
    logger.info(f"live: {len(labels)} rollouts, {int(labels.sum())} exploited; per-step AUROC={per_step}")

    def frame(pop: str, pos: np.ndarray, step: np.ndarray) -> pl.DataFrame:
        """One long-form population; pl.lit broadcasts the label even for an empty population."""
        return (pl.DataFrame({"pos": pos, "step": step})
                .with_columns(pl.lit(pop).alias("pop"))
                .select("pop", "pos", "step"))

    # save the 4 populations long-form -> regenerates the plot with --replot (no GPU)
    df = pl.concat([
        frame("on_solve", pos_live[~labels], steps[~labels]),
        frame("on_hack", pos_live[labels], steps[labels]),
        frame("syn_solve", syn_solve, np.full(len(syn_solve), -1)),
        frame("syn_hack", syn_hack, np.full(len(syn_hack), -1)),
    ])
    df.write_parquet(data_path)
    logger.info(f"wrote {data_path} ({len(df)} rows)")
    plot(df, f"{cfg.run_dir.name}\n{cfg.ckpt} v_grad, live steps {cfg.step_lo}-{cfg.step_hi}, "
             f"{int(labels.sum())}/{len(labels)} exploited", out_png)
    return 0
if __name__ == "__main__":
    # tyro builds the CLI from Cfg's fields; main's return value becomes the exit status
    cli_cfg = tyro.cli(Cfg)
    raise SystemExit(main(cli_cfg))