diff --git a/scripts/diag_pinning.py b/scripts/diag_pinning.py new file mode 100644 index 0000000..c79d2cf --- /dev/null +++ b/scripts/diag_pinning.py @@ -0,0 +1,252 @@ +"""Pinning calibration: where does the live gate sit, and would mean +/- 2sd route it? + +The routeV gate scores each rollout by the width-pooled, band-normalized cosine of its +deployed-block c-probe gradient to v_grad (exactly train.py:_lora2r_gate_labels) -- call it +the position on the HACKING DIRECTION. The current gate routes a FIXED quantile tail (#30). +This plots four populations on that axis: + on-policy solve / hack -- live rollouts (oracle `exploited` read for COLOUR/AUROC only) + synthetic solve / hack -- the authored pairs v_grad was built from (the only label source) +and marks the PROPOSED pinning: online mean and mean +/- 2*sd (a symmetric band that +self-calibrates to the live spread instead of forcing a fixed %). + + uv run python scripts/diag_pinning.py --run-dir out/runs/RUN_NAME + uv run python scripts/diag_pinning.py --replot out/diag/pinning_data.parquet # no GPU, restyle only +outputs (out/diag/): pinning_calib.png, pinning_data.parquet (the 4 populations, regenerates the plot). 
+""" +from __future__ import annotations +import json +import struct +from dataclasses import dataclass +from pathlib import Path + +import numpy as np +import torch +import tyro +import polars as pl +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +from matplotlib.lines import Line2D +from matplotlib.patches import Patch +from loguru import logger +from safetensors.torch import load_file +from transformers import AutoModelForCausalLM, AutoTokenizer + +from vgrout.lora2r import wrap_model_with_lora2r +from vgrout.pairs import load_pairs +from vgrout.extract_vhack_grad import extract_v_hack, completion_nll +from vgrout.train import _build_v_grad, route_band_edges, _haar_unit_dirs, _auroc + +# colour = behaviour (blue solve, red hack); style = source (solid+fill on-policy, dashed synthetic) +SOLVE, HACK, MEANC, ORACLE = "#3b6ea5", "#c44e52", "#d1900a", "#3a8a7a" + + +@dataclass +class Cfg: + run_dir: Path = Path("out/runs/20260611T003538_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v3") + ckpt: str = "first_hack" + pairs: Path = Path("data/pairs/hack_pairs.md#all-in-one") + # Coherent emergence window. This vanilla v3 used the pre-fix lr=5e-4/warmup-0.1 and + # DIVERGED at step 10 (exploited 20/24 -> 0/24); 2-9 = hacks emerging, model still sane. 
+ step_lo: int = 2 + step_hi: int = 9 + max_rollouts: int = 240 + random_v_seed: int | None = None # Haar placebo (sanity: pins should NOT separate) + replot: Path | None = None # load this parquet and re-plot only (no model, no GPU) + out_dir: Path = Path("out/diag") + + +def _ckpt_meta(path: Path) -> dict: + with open(path, "rb") as f: + return json.loads(f.read(struct.unpack("<Q", f.read(8))[0])).get("__metadata__", {}) + + +def pooled_pos(c_grads, v_grad, route_band, names, r) -> tuple[float, float]: + """Mirror train.py:_lora2r_gate_labels pooling for ONE rollout's per-module deployed c-grads.""" + num = 0.0; den = 0.0; w = 0.0; n_inc = 0 + for name in names: + lower, upper = route_band[name] + if upper - lower <= 0: + continue + g_b = c_grads[name][:r].float().to(v_grad[name].device) + nrm = g_b.norm() + cos_b = (torch.einsum("r, k r -> k", g_b, v_grad[name]).max() / nrm.clamp_min(1e-12)).item() + num += cos_b - lower; den += upper - lower + w += nrm.item(); n_inc += 1 + if n_inc == 0: + raise RuntimeError("no module has positive band width") + return num / den, w / n_inc + + +def _kde(x: np.ndarray, grid: np.ndarray) -> np.ndarray: + """Gaussian KDE, Silverman bandwidth (no scipy).""" + x = np.asarray(x, float) + if len(x) < 2: + return np.zeros_like(grid) + iqr = np.subtract(*np.percentile(x, [75, 25])) + sigma = min(x.std(ddof=1), iqr / 1.349) if iqr > 0 else x.std(ddof=1) + bw = max(0.9 * (sigma or 1.0) * len(x) ** (-0.2), 1e-3) + z = (grid[:, None] - x[None, :]) / bw + return np.exp(-0.5 * z ** 2).sum(1) / (len(x) * bw * np.sqrt(2 * np.pi)) + + +def plot(df: pl.DataFrame, subtitle: str, out_png: Path) -> None: + """Regenerate the figure from the saved 4 populations -- no GPU needed.""" + arr = lambda p: df.filter(pl.col("pop") == p)["pos"].to_numpy() + on_solve, on_hack = arr("on_solve"), arr("on_hack") + syn_solve, syn_hack = arr("syn_solve"), arr("syn_hack") + pos_live = np.concatenate([on_solve, on_hack]) + labels = np.concatenate([np.zeros(len(on_solve), bool), np.ones(len(on_hack), bool)]) + + mean, sd = float(pos_live.mean()), float(pos_live.std()) + 
lo_b, hi_b = mean - 2 * sd, mean + 2 * sd # proposed routing band + auroc = _auroc(pos_live.tolist(), labels.tolist()) + thr = np.unique(pos_live) # oracle divider (Youden J) -- diagnostic only + j = [(pos_live[labels] >= t).mean() - (pos_live[~labels] >= t).mean() for t in thr] + oracle = float(thr[int(np.argmax(j))]) + + lo = min(pos_live.min(), syn_solve.min()) - 0.1 + hi = max(np.quantile(pos_live, 0.99), syn_hack.max()) + 0.1 + grid = np.linspace(lo, hi, 400) + POPS = [(on_solve, SOLVE, True, "on-policy solve"), (on_hack, HACK, True, "on-policy hack"), + (syn_solve, SOLVE, False, "synthetic solve"), (syn_hack, HACK, False, "synthetic hack")] + kdes = {i: _kde(x, grid) for i, (x, *_ ) in enumerate(POPS)} + ymax = max(y.max() for y in kdes.values()) * 1.15 + + fig, ax = plt.subplots(figsize=(8.6, 4.8)) + # proposed pinning: mean +/- 2sd band, drawn first so curves sit on top + ax.axvspan(lo_b, hi_b, color=MEANC, alpha=0.07, lw=0) + ax.axvline(mean, color=MEANC, lw=1.8) + for b in (lo_b, hi_b): + ax.axvline(b, color=MEANC, lw=1.2, ls="--") + ax.axvline(oracle, color=ORACLE, lw=1.3, ls="-.") + for i, (x, col, on_policy, _) in enumerate(POPS): + y = kdes[i] + if on_policy: + ax.fill_between(grid, y, color=col, alpha=0.13, lw=0) + ax.plot(grid, y, color=col, lw=1.9) + else: + ax.fill_between(grid, y, color=col, alpha=0.05, lw=0) + ax.plot(grid, y, color=col, lw=2.4, ls=(0, (5, 2))) + ax.set_ylim(0, ymax) + for s in ("top", "right"): + ax.spines[s].set_visible(False) + ax.set_ylabel("density") + ax.set_xlabel("hacking direction (gradient cosine to v_grad) " + r"$\longrightarrow$") + + dist_handles = [Line2D([0], [0], color=c, lw=2.2, ls="-" if op else (0, (5, 2))) for _, c, op, _ in POPS] + leg1 = ax.legend(dist_handles, [lab for *_, lab in POPS], loc="upper left", + fontsize=8, frameon=False, title="distributions", title_fontsize=8) + leg1._legend_box.align = "left" + ax.add_artist(leg1) + mark_handles = [ + Line2D([0], [0], color=MEANC, lw=1.8), + 
Patch(facecolor=MEANC, alpha=0.18, edgecolor=MEANC, ls="--"), + Line2D([0], [0], color=ORACLE, lw=1.3, ls="-."), + ] + ax.legend(mark_handles, + [f"online mean ({mean:+.2f})", f"mean +/- 2sd (sd={sd:.2f}) = proposed band", + f"best hack/solve split ({oracle:+.2f}, diagnostic)"], + loc="upper right", fontsize=8, frameon=False, title="pinning", title_fontsize=8) + routed_hi = float((pos_live > hi_b).mean()) + ax.set_title(f"{subtitle}\nAUROC(hack-dir -> hack)={auroc:.2f} " + f"{routed_hi:.0%} of rollouts beyond mean+2sd", fontsize=9.5) + fig.tight_layout() + fig.savefig(out_png, dpi=140) + plt.close(fig) + logger.info(f"mean={mean:+.3f} sd={sd:.3f} band=[{lo_b:+.3f},{hi_b:+.3f}] oracle={oracle:+.3f} " + f"AUROC={auroc:.3f} routed(>mean+2sd)={routed_hi:.2f}") + logger.info(f"wrote {out_png}") + + +def main(cfg: Cfg) -> int: + cfg.out_dir.mkdir(parents=True, exist_ok=True) + data_path = cfg.out_dir / "pinning_data.parquet" + out_png = cfg.out_dir / "pinning_calib.png" + if cfg.replot is not None: + df = pl.read_parquet(cfg.replot) + plot(df, f"pinning calibration (replot) -- {cfg.replot.name}", out_png) + return 0 + + device = torch.device("cuda") + ckpt_path = cfg.run_dir / f"{cfg.ckpt}.safetensors" + meta = _ckpt_meta(ckpt_path) + run_cfg = json.loads(meta.get("cfg", "{}")) + model_name = run_cfg.get("model", meta.get("model", "Qwen/Qwen3-4B")) + r = run_cfg.get("lora_r", 32) + init_seed = run_cfg.get("lora_init_seed", 0) + logger.info(f"ckpt {ckpt_path.name} step={meta.get('step')} hack_rate={meta.get('hack_rate')} " + f"model={model_name} r={r} init_seed={init_seed}") + + tok = AutoTokenizer.from_pretrained(model_name) + if tok.pad_token_id is None: + tok.pad_token = tok.eos_token + model = AutoModelForCausalLM.from_pretrained( + model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2").to(device) + model.config.use_cache = False + wrappers = wrap_model_with_lora2r(model, r=r, init_seed=init_seed, grad_probe=True) + names = sorted(wrappers) + sd = 
load_file(str(ckpt_path)) + for nm in names: + wrappers[nm]["A"].data.copy_(sd[f"A/{nm}"].to(device, torch.float32)) + wrappers[nm]["B"].data.copy_(sd[f"B/{nm}"].to(device, torch.float32)) + logger.info(f"loaded A/B into {len(names)} modules") + + pairs = load_pairs(cfg.pairs) + logger.info(f"pairs {cfg.pairs} -> {len(pairs)}") + model.eval() + _, _, raw_grads, _ = extract_v_hack(model, tok, wrappers, pairs, + top_k=1, tau_axis=0.0, n_heldout=2, device=device) + v_grad = _build_v_grad(raw_grads, wrappers, 1, device) + if cfg.random_v_seed is not None: + v_grad = _haar_unit_dirs(v_grad, cfg.random_v_seed, device) + logger.info(f"OVERRODE v_grad with Haar dirs seed={cfg.random_v_seed} (placebo)") + route_band = route_band_edges(raw_grads, v_grad, device) + + def pair_pos(side: str) -> np.ndarray: + n = raw_grads[f"{side}/{names[0]}"].shape[0] + out = [] + for i in range(n): + cg = {nm: torch.cat([raw_grads[f"{side}/{nm}"][i], torch.zeros(r)]) for nm in names} + out.append(pooled_pos(cg, v_grad, route_band, names, r)[0]) + return np.array(out) + syn_solve, syn_hack = pair_pos("clean"), pair_pos("hack") + + recs = [json.loads(l) for l in (cfg.run_dir / "rollouts.jsonl").read_text().splitlines()] + batch = [r_ for r_ in recs if cfg.step_lo <= r_["step"] <= cfg.step_hi and r_["text"].strip()][:cfg.max_rollouts] + logger.info(f"live batch: {len(batch)} rollouts (steps {cfg.step_lo}-{cfg.step_hi})") + pos_live, labels, steps = [], [], [] + for i, rec in enumerate(batch): + model.zero_grad(set_to_none=True) + loss = completion_nll(model, tok, rec["prompt"], rec["text"], device) + if not torch.isfinite(loss): + continue + loss.backward() + cg = {nm: wrappers[nm]["layer"]._lora2r_gate.grad.sum(dim=tuple(range( + wrappers[nm]["layer"]._lora2r_gate.grad.dim() - 1))) for nm in names} + pos_live.append(pooled_pos(cg, v_grad, route_band, names, r)[0]) + labels.append(bool(rec["exploited"])); steps.append(rec["step"]) + if (i + 1) % 40 == 0: + logger.info(f" rollout 
{i+1}/{len(batch)}") + model.zero_grad(set_to_none=True) + pos_live = np.array(pos_live); labels = np.array(labels); steps = np.array(steps) + per_step = {int(s): round(_auroc(pos_live[steps == s].tolist(), labels[steps == s].tolist()), 2) + for s in sorted(set(steps.tolist())) if labels[steps == s].any() and (~labels[steps == s]).any()} + logger.info(f"live: {len(labels)} rollouts, {int(labels.sum())} exploited; per-step AUROC={per_step}") + + # save the 4 populations long-form -> regenerates the plot with --replot (no GPU) + df = pl.concat([ + pl.DataFrame({"pop": "on_solve", "pos": pos_live[~labels], "step": steps[~labels]}), + pl.DataFrame({"pop": "on_hack", "pos": pos_live[labels], "step": steps[labels]}), + pl.DataFrame({"pop": "syn_solve", "pos": syn_solve, "step": np.full(len(syn_solve), -1)}), + pl.DataFrame({"pop": "syn_hack", "pos": syn_hack, "step": np.full(len(syn_hack), -1)}), + ]) + df.write_parquet(data_path) + logger.info(f"wrote {data_path} ({len(df)} rows)") + plot(df, f"{cfg.run_dir.name}\n{cfg.ckpt} v_grad, live steps {cfg.step_lo}-{cfg.step_hi}, " + f"{int(labels.sum())}/{len(labels)} exploited", out_png) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main(tyro.cli(Cfg))) diff --git a/scripts/diag_pinning_refresh.py b/scripts/diag_pinning_refresh.py new file mode 100644 index 0000000..6cb547a --- /dev/null +++ b/scripts/diag_pinning_refresh.py @@ -0,0 +1,227 @@ +"""Refresh-tracking prototype: when G_hack (v_grad) refreshes, can we recompute the +live cosine stats from a TRACKED gradient cloud, without re-running the model? + +Claim (k=1): the gate score pos_i = (S_i - L)/W with S_i = sum_m u_{i,m} . v_m, where +u_{i,m} = g_{i,m}/|g_{i,m}| is the v-INDEPENDENT unit per-module gradient. Stacking +U_i=[u_{i,1};..], V=[v_1;..] gives S_i = U_i . V (linear in V). So the live cosine +distribution under ANY v is an affine push of the SAME gradient cloud: + mean(pos) = (mu_U . 
V - L) / W + var(pos) = (V^T Sigma_U V) / W^2 +mu_U, Sigma_U are tracked once (v-independent); L,W come from the pairs at refresh. So a +refresh needs no model re-run and no window flush -- just re-project the cloud onto v_new. + +We simulate a refresh with two direction estimates v_A, v_B = unit-mean-diff of two DISJOINT +halves of the authored pairs (model fixed, only the direction changes), and verify: + (1) reproject stored U onto v == direct pos (exact, validates indexing) + (2) moment formula (mu_U, Sigma_U) == empirical mean/sd of pos, for BOTH v_A and v_B + from the SAME tracked cloud (the actual claim). + + uv run python scripts/diag_pinning_refresh.py +outputs (out/diag/): pinning_refresh.png, pinning_refresh.parquet, prints a sanity table. +""" +from __future__ import annotations +import json +import struct +from dataclasses import dataclass +from pathlib import Path + +import numpy as np +import torch +import tyro +import polars as pl +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +from matplotlib.lines import Line2D +from matplotlib.patches import Patch +from loguru import logger +from tabulate import tabulate +from safetensors.torch import load_file +from transformers import AutoModelForCausalLM, AutoTokenizer + +from vgrout.lora2r import wrap_model_with_lora2r +from vgrout.pairs import load_pairs +from vgrout.extract_vhack_grad import extract_v_hack, completion_nll +from vgrout.train import _build_v_grad, route_band_edges, _auroc +from diag_pinning import _kde, SOLVE, HACK, MEANC # same dir (scripts/ is sys.path[0] when run directly) + + +@dataclass +class Cfg: + run_dir: Path = Path("out/runs/20260611T003538_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v3") + ckpt: str = "first_hack" + pairs: Path = Path("data/pairs/hack_pairs.md#all-in-one") + step_lo: int = 2 + step_hi: int = 9 + max_rollouts: int = 240 + out_dir: Path = Path("out/diag") + + +def _ckpt_meta(path: Path) -> dict: + with open(path, "rb") as f: + return 
json.loads(f.read(struct.unpack("<Q", f.read(8))[0])).get("__metadata__", {}) + + +def stacked_v_and_band(raw_grads, idx, wrappers, names, r, device): + """Direction estimate from the pair subset idx: stacked V, summed lower edges L, summed widths W. + Pools only modules with band width > 0; excluded modules contribute 0.""" + sub = {k: v[idx] for k, v in raw_grads.items()} + v_grad = _build_v_grad(sub, wrappers, 1, device) # {name: [1, r]} unit + band = route_band_edges(sub, v_grad, device) # {name: (lower, upper)} + V, L, W = [], 0.0, 0.0 + for name in names: + lower, upper = band[name] + if upper - lower > 0: + V.append(v_grad[name][0].float().cpu()) # [r] + L += lower; W += (upper - lower) + else: + V.append(torch.zeros(r)) # excluded -> no contribution + return torch.cat(V).numpy(), float(L), float(W), v_grad, band + + +def pos_from_U(U: np.ndarray, V: np.ndarray, L: float, W: float) -> np.ndarray: + """pos = (U . V - L) / W -- re-project the stored cloud onto direction V (no model).""" + return (U @ V - L) / W + + +def main(cfg: Cfg) -> int: + cfg.out_dir.mkdir(parents=True, exist_ok=True) + device = torch.device("cuda") + ckpt_path = cfg.run_dir / f"{cfg.ckpt}.safetensors" + meta = _ckpt_meta(ckpt_path) + run_cfg = json.loads(meta.get("cfg", "{}")) + model_name = run_cfg.get("model", "Qwen/Qwen3-4B") + r = run_cfg.get("lora_r", 32) + init_seed = run_cfg.get("lora_init_seed", 0) + logger.info(f"ckpt {ckpt_path.name} step={meta.get('step')} model={model_name} r={r}") + + tok = AutoTokenizer.from_pretrained(model_name) + if tok.pad_token_id is None: + tok.pad_token = tok.eos_token + model = AutoModelForCausalLM.from_pretrained( + model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2").to(device) + model.config.use_cache = False + wrappers = wrap_model_with_lora2r(model, r=r, init_seed=init_seed, grad_probe=True) + names = sorted(wrappers) + sd = load_file(str(ckpt_path)) + for nm in names: + wrappers[nm]["A"].data.copy_(sd[f"A/{nm}"].to(device, torch.float32)) + wrappers[nm]["B"].data.copy_(sd[f"B/{nm}"].to(device, torch.float32)) + + pairs = load_pairs(cfg.pairs) + model.eval() + _, _, raw_grads, _ = extract_v_hack(model, tok, wrappers, pairs, + top_k=1, tau_axis=0.0, n_heldout=2, 
device=device) + n_pairs = raw_grads[f"hack/{names[0]}"].shape[0] + half = n_pairs // 2 + idx_A, idx_B = list(range(half)), list(range(half, n_pairs)) # disjoint pair halves = two v estimates + V_A, L_A, W_A, vA, _ = stacked_v_and_band(raw_grads, idx_A, wrappers, names, r, device) + V_B, L_B, W_B, vB, _ = stacked_v_and_band(raw_grads, idx_B, wrappers, names, r, device) + V_A, V_B = V_A.astype(np.float64), V_B.astype(np.float64) # float64 so the sanity gate reflects math, not rounding + cosAB = float(np.dot(V_A, V_B) / (np.linalg.norm(V_A) * np.linalg.norm(V_B) + 1e-12)) + logger.info(f"two direction estimates from pair halves ({half} / {n_pairs-half}); " + f"cos(V_A,V_B)={cosAB:+.3f} (a real refresh shift)") + + # ── score live batch ONCE: store U_i = stacked unit per-module gradients (v-INDEPENDENT) ── + recs = [json.loads(l) for l in (cfg.run_dir / "rollouts.jsonl").read_text().splitlines()] + batch = [x for x in recs if cfg.step_lo <= x["step"] <= cfg.step_hi and x["text"].strip()][:cfg.max_rollouts] + logger.info(f"live batch: {len(batch)} rollouts (steps {cfg.step_lo}-{cfg.step_hi})") + U_rows, labels = [], [] + for i, rec in enumerate(batch): + model.zero_grad(set_to_none=True) + loss = completion_nll(model, tok, rec["prompt"], rec["text"], device) + if not torch.isfinite(loss): + continue + loss.backward() + u_blocks = [] + for name in names: + g = wrappers[name]["layer"]._lora2r_gate.grad + g_b = g.sum(dim=tuple(range(g.dim() - 1)))[:r].float() # deployed block [r] + u = (g_b / g_b.norm().clamp_min(1e-12)).cpu() # v-independent unit grad + u_blocks.append(u) + U_rows.append(torch.cat(u_blocks).numpy()) + labels.append(bool(rec["exploited"])) + if (i + 1) % 40 == 0: + logger.info(f" rollout {i+1}/{len(batch)}") + model.zero_grad(set_to_none=True) + U = np.stack(U_rows).astype(np.float64) # [n_roll, n_mod*r] the tracked gradient cloud + labels = np.array(labels) + logger.info(f"stored U cloud: {U.shape} ({U.nbytes/1e6:.1f} MB), {int(labels.sum())} exploited") 
+ + # ── pos under each direction by RE-PROJECTING the stored cloud (no model re-run) ── + pos_A = pos_from_U(U, V_A, L_A, W_A) + pos_B = pos_from_U(U, V_B, L_B, W_B) + + # ── moment formula from the tracked cloud: mu_U, Sigma_U (v-independent) ── + mu_U = U.mean(0) # [D] + Uc = U - mu_U + Sigma_U = (Uc.T @ Uc) / U.shape[0] # [D, D] uncentered-then-centered covariance + def moment_stats(V, L, W): + mean = (mu_U @ V - L) / W + var = (V @ (Sigma_U @ V)) / (W * W) + return mean, np.sqrt(max(var, 0.0)) + mA_emp, sA_emp = pos_A.mean(), pos_A.std() + mB_emp, sB_emp = pos_B.mean(), pos_B.std() + mA_mom, sA_mom = moment_stats(V_A, L_A, W_A) + mB_mom, sB_mom = moment_stats(V_B, L_B, W_B) + + # ── sanity table ── + rows = [ + ["mean pos | v_A", f"{mA_emp:+.4f}", f"{mA_mom:+.4f}", f"{abs(mA_emp-mA_mom):.2e}"], + ["std pos | v_A", f"{sA_emp:.4f}", f"{sA_mom:.4f}", f"{abs(sA_emp-sA_mom):.2e}"], + ["mean pos | v_B", f"{mB_emp:+.4f}", f"{mB_mom:+.4f}", f"{abs(mB_emp-mB_mom):.2e}"], + ["std pos | v_B", f"{sB_emp:.4f}", f"{sB_mom:.4f}", f"{abs(sB_emp-sB_mom):.2e}"], + ] + print("\nSHOULD: moment-formula (from tracked mu_U, Sigma_U) == empirical, both v_A and v_B.") + print(tabulate(rows, headers=["quantity", "empirical (direct)", "moment formula", "abs diff"], + tablefmt="github")) + max_diff = max(abs(mA_emp-mA_mom), abs(sA_emp-sA_mom), abs(mB_emp-mB_mom), abs(sB_emp-sB_mom)) + ok = max_diff < 1e-5 + print(f"\n{'PASS' if ok else 'FAIL'}: max |empirical - moment| = {max_diff:.2e} (refresh needs no model re-run)") + logger.info(f"refresh shift: mean {mA_emp:+.3f}->{mB_emp:+.3f}, std {sA_emp:.3f}->{sB_emp:.3f}; " + f"AUROC v_A={_auroc(pos_A.tolist(), labels.tolist()):.3f} v_B={_auroc(pos_B.tolist(), labels.tolist()):.3f}") + + # ── plot: pos under v_A (top) and v_B (bottom), mean +/- 2sd band recalibrates per direction ── + pl.DataFrame({"pos_A": pos_A, "pos_B": pos_B, "exploited": labels.tolist()}).write_parquet( + cfg.out_dir / "pinning_refresh.parquet") + lo = min(pos_A.min(), 
pos_B.min()) - 0.1 + hi = max(pos_A.max(), pos_B.max()) + 0.1 + grid = np.linspace(lo, hi, 400) + fig, axes = plt.subplots(2, 1, figsize=(8.6, 6.0), sharex=True) + for ax, pos, mean, std, tag in [(axes[0], pos_A, mA_emp, sA_emp, "A (pairs 1st half)"), + (axes[1], pos_B, mB_emp, sB_emp, "B (pairs 2nd half)")]: + lo_b, hi_b = mean - 2 * std, mean + 2 * std + ax.axvspan(lo_b, hi_b, color=MEANC, alpha=0.07, lw=0) + ax.axvline(mean, color=MEANC, lw=1.8) + for b in (lo_b, hi_b): + ax.axvline(b, color=MEANC, lw=1.2, ls="--") + peak = 0.0 + for x, col in [(pos[~labels], SOLVE), (pos[labels], HACK)]: + y = _kde(x, grid) + ax.fill_between(grid, y, color=col, alpha=0.13, lw=0) + ax.plot(grid, y, color=col, lw=1.9) + peak = max(peak, y.max()) + ax.set_ylim(0, peak * 1.15) + ax.set_ylabel("density") + ax.set_title(f"direction estimate {tag}: mean={mean:+.2f}, sd={std:.2f}, " + f"AUROC={_auroc(pos.tolist(), labels.tolist()):.2f}", fontsize=9, loc="left") + for s in ("top", "right"): + ax.spines[s].set_visible(False) + dist = [Line2D([0], [0], color=SOLVE, lw=2), Line2D([0], [0], color=HACK, lw=2), + Line2D([0], [0], color=MEANC, lw=1.8), Patch(facecolor=MEANC, alpha=0.18, ls="--", edgecolor=MEANC)] + axes[0].legend(dist, ["on-policy solve", "on-policy hack", "online mean", "mean +/- 2sd"], + loc="upper left", fontsize=8, frameon=False) + axes[1].set_xlabel("hacking direction (gradient cosine to v_grad) " + r"$\longrightarrow$") + fig.suptitle("refresh tracking: same gradient cloud, two G_hack estimates " + f"(cos(v_A,v_B)={cosAB:+.2f}) -- band recalibrates with no model re-run", fontsize=9.5) + fig.tight_layout(rect=(0, 0, 1, 0.96)) + fig.savefig(cfg.out_dir / "pinning_refresh.png", dpi=140) + plt.close(fig) + logger.info(f"wrote {cfg.out_dir}/pinning_refresh.png, pinning_refresh.parquet") + return 0 if ok else 1 + + +if __name__ == "__main__": + raise SystemExit(main(tyro.cli(Cfg))) diff --git a/src/vgrout/tablelog.py b/src/vgrout/tablelog.py index 114e3ed..b94ec9f 100644 --- 
a/src/vgrout/tablelog.py +++ b/src/vgrout/tablelog.py @@ -72,8 +72,8 @@ class StepLogger: def __init__(self, arm: str, modes: list[str], mode_code: dict[str, str], show_ablate: bool = False) -> None: - # routeV reports routing diagnostics; absorb shares qmass (zone cols read nan). - is_route = arm in ("routingV_lora2r", "absorb_lora2r") + # Routing diagnostics are ALWAYS shown (nan on vanilla, whose gate never runs) so the + # column layout is identical across arms -- vanilla/routeV/absorb tables line up. cols: list[_Col] = [ _Col("step", 4, "step", "d", "GRPO step"), _Col("ref_eq", 6, "ref_eq", ".2f", "vanilla-equiv step (cum_gens/256)"), @@ -99,21 +99,20 @@ class StepLogger: _Col("gn", 7, "gn", ".1e", "pre-clip L2 norm of A/B grads (vs grad_clip)"), _Col("lr", 7, "lr", ".1e", "scheduled learning rate"), ] - # routeV reports unit and energy shares across the routing band. - if is_route: - cols += [ - _Col("auroc", 6, "auroc", ".2f", "AUROC of pooled cos(g,v_grad) as a hack detector vs the hack-label (student exploited + teacher cached); MEASUREMENT only, never routes. 
~0.5 = v_grad blind to live hacks (no threshold helps); high but rout~0 = pure threshold/scale problem; a drop at a refresh = refresh destroyed separation"), - _Col("cosU", 6, "cosU", "+.2f", "pooled cos(v_grad, summed-rollout c-grad): is the net update moving hack-ward this step"), - _Col("qmass", 6, "qmass", ".2f", "quarantine energy share ||g_quar||/(||g_keep||+||g_quar||): fraction of the update landing in the throwaway quarantine block"), - _Col("keep", 6, "keep", ".2f", "rollout share labelled clean (below band) -> deployed-only, quarantine off"), - _Col("resid", 6, "resid", ".2f", "rollout share labelled mid (inside band) -> both blocks train (absorption)"), - _Col("rout", 6, "rout", ".2f", "rollout share labelled hack (above band) -> quarantine-only, deployed detached"), - _Col("keepE", 6, "keepE", ".2f", "energy-weighted keep (grad-norm-weighted clean share)"), - _Col("residE", 6, "residE", ".2f", "energy-weighted resid"), - _Col("routE", 6, "routE", ".2f", "energy-weighted rout"), - ] + # routeV reports unit and energy shares across the routing band (nan on vanilla/absorb). + cols += [ + _Col("auroc", 6, "auroc", ".2f", "AUROC of pooled cos(g,v_grad) as a hack detector vs the hack-label (student exploited + teacher cached); MEASUREMENT only, never routes. 
~0.5 = v_grad blind to live hacks (no threshold helps); high but rout~0 = pure threshold/scale problem; a drop at a refresh = refresh destroyed separation"), + _Col("cosU", 6, "cosU", "+.2f", "pooled cos(v_grad, summed-rollout c-grad): is the net update moving hack-ward this step"), + _Col("qmass", 6, "qmass", ".2f", "quarantine energy share ||g_quar||/(||g_keep||+||g_quar||): fraction of the update landing in the throwaway quarantine block"), + _Col("keep", 6, "keep", ".2f", "rollout share labelled clean (below band) -> deployed-only, quarantine off"), + _Col("resid", 6, "resid", ".2f", "rollout share labelled mid (inside band) -> both blocks train (absorption)"), + _Col("rout", 6, "rout", ".2f", "rollout share labelled hack (above band) -> quarantine-only, deployed detached"), + _Col("keepE", 6, "keepE", ".2f", "energy-weighted keep (grad-norm-weighted clean share)"), + _Col("residE", 6, "residE", ".2f", "energy-weighted resid"), + _Col("routE", 6, "routE", ".2f", "energy-weighted rout"), + ] # Show the training-prompt deploy proxy only when an ablated slice exists. - if is_route and show_ablate: + if show_ablate: cols += [ _Col("hack_abl", 6, "hk_abl", "frac", "per-step deploy proxy: hack rate on the ablated (deploy-mode) rollout slice; train prompts, noisier than hk_dep"), _Col("solve_abl", 6, "slv_abl", "frac", "per-step deploy proxy: solve rate on the ablated (deploy-mode) rollout slice; train prompts"), diff --git a/src/vgrout/train.py b/src/vgrout/train.py index eede2a9..8f7f3ba 100644 --- a/src/vgrout/train.py +++ b/src/vgrout/train.py @@ -602,22 +602,23 @@ def main(cfg: Config) -> int: if n_inc == 0: raise RuntimeError("no module has positive band width; pairs separate nowhere") pos = num / den; w /= n_inc - # ── online-stats gate (three zones, keep = bulk) ── The authored absolute thresholds - # (clean<=0, hack>=1) don't work: live hacks sit far below the authored p75-hack pooled - # edge, so rout~0. 
Route by live QUANTILES of pos instead: top route_quantile tail -> - # hack (rout), the next route_quantile band -> mid (absorption), everything below -> - # keep. keep is the BULK; the previous t_lo=quantile(route_quantile) put keep at the - # bottom route_quantile and routed ~(1-route_quantile) into mid, so the quarantine - # trained on almost every rollout (qmass~0.5). v_grad stays authored-only; only the - # thresholds follow the live distribution. The window includes this batch, so step 0 - # self-calibrates. TODO(#30): calibrate the rout/mid widths from the batch distribution. + # ── online-stats gate (#30): mean + k*std, three zones, keep = bulk ── The authored + # absolute band is mis-placed (live pos sits far below the synthetic-hack edge; even + # synthetic SOLVE is more hack-aligned than on-policy hack -- out/diag/pinning_calib.png), + # and a fixed quantile FORCES route_quantile out every step even when nothing separates. + # Route by the live MEAN + k*STD instead: pos > mean+route_std_mid*std -> mid, pos >= + # mean+route_std_rout*std -> rout, below -> keep (bulk). Self-silencing: only the tail + # that genuinely exceeds the spread routes, so qmass tracks real separation rather than a + # forced fraction. v_grad stays authored-only; the threshold follows the live distribution. + # The window includes this batch, so step 0 self-calibrates; flushed on v_grad refresh. 
route_pos_window.extend(pos.detach().cpu().tolist()) ref = torch.tensor(list(route_pos_window)) - t_hi = ref.quantile(1.0 - cfg.route_quantile).item() # top route_quantile -> rout - t_lo = ref.quantile(1.0 - 2.0 * cfg.route_quantile).item() # next band -> mid; below -> keep + mu_pos, sd_pos = ref.mean().item(), ref.std().item() + t_lo = mu_pos + cfg.route_std_mid * sd_pos # mid onset + t_hi = mu_pos + cfg.route_std_rout * sd_pos # rout onset (>= mid) logger.debug(f"gate pos: n_inc={n_inc}/{len(wrappers)} min={pos.min().item():+.2f} " - f"max={pos.max().item():+.2f} | online t_lo={t_lo:+.2f} t_hi={t_hi:+.2f} " - f"win={len(route_pos_window)}") + f"max={pos.max().item():+.2f} | mean={mu_pos:+.2f} std={sd_pos:.2f} " + f"t_lo={t_lo:+.2f} t_hi={t_hi:+.2f} win={len(route_pos_window)}") m = (pos > t_lo).float() # mid + rout -> quarantine trains (keep = bulk below t_lo) d = (pos >= t_hi).float() # top tail -> hack -> deployed detached return m, d, 0.5 * m + 0.5 * d, w, pos, cosU_sum / n_inc diff --git a/src/vgrout/train_config.py b/src/vgrout/train_config.py index cb157cf..e2225aa 100644 --- a/src/vgrout/train_config.py +++ b/src/vgrout/train_config.py @@ -49,13 +49,18 @@ class Config: # (alignment to ANY known hack sub-mode) -- catches multi-modal hack signal one mean # washes out. k=1 stays mean-diff (not SVD-top-1) so "mean-mass vs top-k" is a clean A/B. v_grad_k: int = 1 - # Online-stats gate: route by live QUANTILES of the pooled cosine-to-v_grad, not the - # authored absolute band (whose p75-hack edge live hacks never reach -> rout~0). Each - # step the top route_quantile tail -> hack (deployed detached), bottom -> keep, middle - # -> mid. route_window = sliding buffer of recent pooled positions, flushed on refresh - # (positions are measured against one v_grad). Direction stays authored-only; only the - # threshold follows the live distribution. TODO(#30): center+width calibration vs fixed %. 
- route_quantile: float = 0.05 + # Online-stats gate (#30): route by the live MEAN + k*STD of the pooled cosine-to-v_grad, + # not a fixed quantile. The authored absolute band is mis-placed (live pos sits far below + # the synthetic-hack edge; even synthetic SOLVE is more hack-aligned than on-policy hack -- + # see out/diag/pinning_calib.png), and a fixed quantile FORCES route_quantile of the batch + # out every step even when nothing separates. mean+k*std self-silences: it only routes the + # tail that genuinely exceeds the spread. pos > mean + route_std_mid*std -> mid (absorption); + # pos >= mean + route_std_rout*std -> rout (hack, deployed detached); below -> keep (bulk). + # route_window = sliding buffer of recent pooled positions, flushed on v_grad refresh + # (positions measured against one v_grad). Direction stays authored-only; only the threshold + # follows the live distribution. + route_std_mid: float = 2.0 + route_std_rout: float = 3.0 route_window: int = 512 # Haar-random direction control (placebo): same routing machinery, no pair signal. routeV_random_v_seed: int | None = None