feat(#30): mean+k*std online gate replaces fixed quantile; always-show route cols

Gate calibration: route by live mean + route_std_mid/route_std_rout * std of the
pooled cosine-to-v_grad, not a fixed quantile tail. Self-silences -- only the tail
that genuinely exceeds the spread routes, so qmass tracks real separation instead
of a forced fraction. The authored absolute band is mis-placed (live pos sits far
below the synthetic-hack edge; even synthetic solve out-aligns on-policy hack).

tablelog: auroc/rout/routE/keep/resid/qmass cols always shown (nan on vanilla) so
arm tables line up.

Diagnostics: scripts/diag_pinning.py (4-population calibration view, mean+/-2sd band)
and scripts/diag_pinning_refresh.py (proves cosine stats recompute from a tracked
v-independent gradient cloud on a v_grad refresh -- exact for k=1, sanity 2.5e-16).

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-11 02:56:07 +00:00
parent 4f60f94072
commit 979daf84fd
5 changed files with 520 additions and 36 deletions
+252
View File
@@ -0,0 +1,252 @@
"""Pinning calibration: where does the live gate sit, and would mean +/- 2sd route it?
The routeV gate scores each rollout by the width-pooled, band-normalized cosine of its
deployed-block c-probe gradient to v_grad (exactly train.py:_lora2r_gate_labels) -- call it
the position on the HACKING DIRECTION. Before #30 the gate routed a FIXED quantile tail.
This plots four populations on that axis:
on-policy solve / hack -- live rollouts (oracle `exploited` read for COLOUR/AUROC only)
synthetic solve / hack -- the authored pairs v_grad was built from (the only label source)
and marks the PROPOSED pinning: online mean and mean +/- 2*sd (a symmetric band that
self-calibrates to the live spread instead of forcing a fixed %).
uv run python scripts/diag_pinning.py --run-dir out/runs/<vanilla_lora2r_run>
uv run python scripts/diag_pinning.py --replot out/diag/pinning_data.parquet # no GPU, restyle only
outputs (out/diag/): pinning_calib.png, pinning_data.parquet (the 4 populations, regenerates the plot).
"""
from __future__ import annotations
import json
import struct
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import torch
import tyro
import polars as pl
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from loguru import logger
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM, AutoTokenizer
from vgrout.lora2r import wrap_model_with_lora2r
from vgrout.pairs import load_pairs
from vgrout.extract_vhack_grad import extract_v_hack, completion_nll
from vgrout.train import _build_v_grad, route_band_edges, _haar_unit_dirs, _auroc
# colour = behaviour (blue solve, red hack); style = source (solid+fill on-policy, dashed synthetic)
SOLVE, HACK, MEANC, ORACLE = "#3b6ea5", "#c44e52", "#d1900a", "#3a8a7a"
# Command-line options for the pinning-calibration diagnostic; parsed by tyro.cli at the script entry.
@dataclass
class Cfg:
    # Run directory: main() reads `<ckpt>.safetensors` and `rollouts.jsonl` from it.
    run_dir: Path = Path("out/runs/20260611T003538_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v3")
    # Checkpoint file stem inside run_dir (".safetensors" is appended by main()).
    ckpt: str = "first_hack"
    # Authored pairs spec handed to load_pairs(); v_grad and the synthetic populations come from these.
    pairs: Path = Path("data/pairs/hack_pairs.md#all-in-one")
    # Coherent emergence window. This vanilla v3 used the pre-fix lr=5e-4/warmup-0.1 and
    # DIVERGED at step 10 (exploited 20/24 -> 0/24); 2-9 = hacks emerging, model still sane.
    # Both bounds are inclusive (main() keeps records with step_lo <= step <= step_hi).
    step_lo: int = 2
    step_hi: int = 9
    # Cap on the number of live rollouts kept after the step/text filter.
    max_rollouts: int = 240
    random_v_seed: int | None = None  # Haar placebo (sanity: pins should NOT separate)
    replot: Path | None = None  # load this parquet and re-plot only (no model, no GPU)
    # Output directory for pinning_calib.png and pinning_data.parquet (created if missing).
    out_dir: Path = Path("out/diag")
def _ckpt_meta(path: Path) -> dict:
    """Return the ``__metadata__`` mapping stored in a checkpoint header.

    Only the header is read: an 8-byte little-endian unsigned length, followed by
    that many bytes of JSON. Returns ``{}`` when the header has no
    ``__metadata__`` entry.
    """
    with open(path, "rb") as fh:
        (header_len,) = struct.unpack("<Q", fh.read(8))
        header = json.loads(fh.read(header_len))
    return header.get("__metadata__", {})
def pooled_pos(c_grads: dict, v_grad: dict, route_band: dict, names: list, r: int) -> tuple[float, float]:
    """Pool one rollout's per-module gradient cosines into a single band-normalized position.

    For every module in ``names`` whose band ``(lower, upper)`` has positive width, the
    first ``r`` entries of its gradient are compared against each row of ``v_grad[name]``;
    the largest dot product divided by the gradient norm is that module's cosine.
    (Intended to mirror train.py:_lora2r_gate_labels -- not verifiable from this file.)

    Returns:
        ``(pos, mean_norm)`` where ``pos = sum(cos - lower) / sum(upper - lower)`` over the
        included modules and ``mean_norm`` is their average gradient norm.

    Raises:
        RuntimeError: if no module has a band of positive width.
    """
    shifted_cos = 0.0  # running sum of (cos - lower)
    total_width = 0.0  # running sum of (upper - lower)
    norm_sum = 0.0
    n_used = 0
    for mod in names:
        band_lo, band_hi = route_band[mod]
        if band_hi - band_lo <= 0:
            continue  # degenerate band: module carries no routing signal
        directions = v_grad[mod]
        grad = c_grads[mod][:r].float().to(directions.device)
        grad_norm = grad.norm()
        best_dot = torch.einsum("r, k r -> k", grad, directions).max()
        cosine = (best_dot / grad_norm.clamp_min(1e-12)).item()
        shifted_cos += cosine - band_lo
        total_width += band_hi - band_lo
        norm_sum += grad_norm.item()
        n_used += 1
    if n_used == 0:
        raise RuntimeError("no module has positive band width")
    return shifted_cos / total_width, norm_sum / n_used
def _kde(x: np.ndarray, grid: np.ndarray) -> np.ndarray:
    """Evaluate a Gaussian kernel density estimate of ``x`` on ``grid`` (no scipy).

    Bandwidth is Silverman-style: ``0.9 * sigma * n**-0.2`` with ``sigma`` the smaller of
    the sample std (ddof=1) and IQR/1.349 (std alone when the IQR is zero), floored at
    1e-3. A zero ``sigma`` is replaced by 1.0. Fewer than two samples yields all zeros.
    """
    samples = np.asarray(x, float)
    n = len(samples)
    if n < 2:
        return np.zeros_like(grid)
    q75, q25 = np.percentile(samples, [75, 25])
    iqr = np.subtract(q75, q25)
    std = samples.std(ddof=1)
    if iqr > 0:
        sigma = min(std, iqr / 1.349)
    else:
        sigma = std
    bw = max(0.9 * (sigma or 1.0) * n ** (-0.2), 1e-3)
    z = (grid[:, None] - samples[None, :]) / bw
    kernel_sum = np.exp(-0.5 * z ** 2).sum(1)
    return kernel_sum / (n * bw * np.sqrt(2 * np.pi))
def plot(df: pl.DataFrame, subtitle: str, out_png: Path) -> None:
    """Regenerate the figure from the saved 4 populations -- no GPU needed.

    Args:
        df: long-form frame with a "pop" column (values "on_solve", "on_hack",
            "syn_solve", "syn_hack") and a "pos" column; other columns are ignored.
        subtitle: first line(s) of the figure title.
        out_png: where the PNG is written (dpi=140).

    Draws one KDE per population, the live mean, the mean +/- 2*sd band, and the
    Youden-J best split, then logs the summary statistics.
    """
    arr = lambda p: df.filter(pl.col("pop") == p)["pos"].to_numpy()
    on_solve, on_hack = arr("on_solve"), arr("on_hack")
    syn_solve, syn_hack = arr("syn_solve"), arr("syn_hack")
    # Live population = on-policy solve followed by on-policy hack; labels mark the hack part.
    pos_live = np.concatenate([on_solve, on_hack])
    labels = np.concatenate([np.zeros(len(on_solve), bool), np.ones(len(on_hack), bool)])
    # numpy's default std is the population form (ddof=0).
    mean, sd = float(pos_live.mean()), float(pos_live.std())
    lo_b, hi_b = mean - 2 * sd, mean + 2 * sd  # proposed routing band
    auroc = _auroc(pos_live.tolist(), labels.tolist())
    thr = np.unique(pos_live)  # oracle divider (Youden J) -- diagnostic only
    # NOTE(review): if either live class is empty the .mean() below is nan for every
    # threshold and the "best split" is meaningless -- confirm callers always have both.
    j = [(pos_live[labels] >= t).mean() - (pos_live[~labels] >= t).mean() for t in thr]
    oracle = float(thr[int(np.argmax(j))])
    # x-range: the live upper edge is clipped at the 99th percentile so a few outliers
    # do not stretch the axis; synthetic extremes are always included.
    lo = min(pos_live.min(), syn_solve.min()) - 0.1
    hi = max(np.quantile(pos_live, 0.99), syn_hack.max()) + 0.1
    grid = np.linspace(lo, hi, 400)
    # (samples, colour, is_on_policy, legend label)
    POPS = [(on_solve, SOLVE, True, "on-policy solve"), (on_hack, HACK, True, "on-policy hack"),
            (syn_solve, SOLVE, False, "synthetic solve"), (syn_hack, HACK, False, "synthetic hack")]
    kdes = {i: _kde(x, grid) for i, (x, *_ ) in enumerate(POPS)}
    ymax = max(y.max() for y in kdes.values()) * 1.15
    fig, ax = plt.subplots(figsize=(8.6, 4.8))
    # proposed pinning: mean +/- 2sd band, drawn first so curves sit on top
    ax.axvspan(lo_b, hi_b, color=MEANC, alpha=0.07, lw=0)
    ax.axvline(mean, color=MEANC, lw=1.8)
    for b in (lo_b, hi_b):
        ax.axvline(b, color=MEANC, lw=1.2, ls="--")
    ax.axvline(oracle, color=ORACLE, lw=1.3, ls="-.")
    # on-policy: solid line + stronger fill; synthetic: dashed line + faint fill
    for i, (x, col, on_policy, _) in enumerate(POPS):
        y = kdes[i]
        if on_policy:
            ax.fill_between(grid, y, color=col, alpha=0.13, lw=0)
            ax.plot(grid, y, color=col, lw=1.9)
        else:
            ax.fill_between(grid, y, color=col, alpha=0.05, lw=0)
            ax.plot(grid, y, color=col, lw=2.4, ls=(0, (5, 2)))
    ax.set_ylim(0, ymax)
    for s in ("top", "right"):
        ax.spines[s].set_visible(False)
    ax.set_ylabel("density")
    ax.set_xlabel("hacking direction (gradient cosine to v_grad)  " + r"$\longrightarrow$")
    dist_handles = [Line2D([0], [0], color=c, lw=2.2, ls="-" if op else (0, (5, 2))) for _, c, op, _ in POPS]
    leg1 = ax.legend(dist_handles, [lab for *_, lab in POPS], loc="upper left",
                     fontsize=8, frameon=False, title="distributions", title_fontsize=8)
    leg1._legend_box.align = "left"
    # A second ax.legend() call would replace the first, so re-attach it as an artist.
    ax.add_artist(leg1)
    mark_handles = [
        Line2D([0], [0], color=MEANC, lw=1.8),
        Patch(facecolor=MEANC, alpha=0.18, edgecolor=MEANC, ls="--"),
        Line2D([0], [0], color=ORACLE, lw=1.3, ls="-."),
    ]
    ax.legend(mark_handles,
              [f"online mean ({mean:+.2f})", f"mean +/- 2sd (sd={sd:.2f}) = proposed band",
               f"best hack/solve split ({oracle:+.2f}, diagnostic)"],
              loc="upper right", fontsize=8, frameon=False, title="pinning", title_fontsize=8)
    # Share of live rollouts strictly above the upper band edge.
    routed_hi = float((pos_live > hi_b).mean())
    ax.set_title(f"{subtitle}\nAUROC(hack-dir -> hack)={auroc:.2f}    "
                 f"{routed_hi:.0%} of rollouts beyond mean+2sd", fontsize=9.5)
    fig.tight_layout()
    fig.savefig(out_png, dpi=140)
    plt.close(fig)
    logger.info(f"mean={mean:+.3f} sd={sd:.3f} band=[{lo_b:+.3f},{hi_b:+.3f}] oracle={oracle:+.3f} "
                f"AUROC={auroc:.3f} routed(>mean+2sd)={routed_hi:.2f}")
    logger.info(f"wrote {out_png}")
def main(cfg: Cfg) -> int:
    """Score the four populations on the hacking direction, save them, and plot.

    With ``cfg.replot`` set, only re-plots a previously saved parquet (no model, no GPU).
    Otherwise: loads the checkpoint's A/B weights into a LoRA-wrapped model, builds
    ``v_grad`` and the routing band from the authored pairs, scores the synthetic pairs
    and the live rollouts in the configured step window, writes ``pinning_data.parquet``
    and ``pinning_calib.png`` under ``cfg.out_dir``.

    Returns:
        0 on success.

    Raises:
        RuntimeError: if no live rollout could be scored (empty step window or every
            loss non-finite), since the plot and statistics would be undefined.
    """
    cfg.out_dir.mkdir(parents=True, exist_ok=True)
    data_path = cfg.out_dir / "pinning_data.parquet"
    out_png = cfg.out_dir / "pinning_calib.png"
    if cfg.replot is not None:
        df = pl.read_parquet(cfg.replot)
        plot(df, f"pinning calibration (replot) -- {cfg.replot.name}", out_png)
        return 0
    device = torch.device("cuda")
    ckpt_path = cfg.run_dir / f"{cfg.ckpt}.safetensors"
    meta = _ckpt_meta(ckpt_path)
    # The training config is stored as a JSON string under the "cfg" metadata key.
    run_cfg = json.loads(meta.get("cfg", "{}"))
    model_name = run_cfg.get("model", meta.get("model", "Qwen/Qwen3-4B"))
    r = run_cfg.get("lora_r", 32)
    init_seed = run_cfg.get("lora_init_seed", 0)
    logger.info(f"ckpt {ckpt_path.name} step={meta.get('step')} hack_rate={meta.get('hack_rate')} "
                f"model={model_name} r={r} init_seed={init_seed}")
    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2").to(device)
    model.config.use_cache = False
    wrappers = wrap_model_with_lora2r(model, r=r, init_seed=init_seed, grad_probe=True)
    names = sorted(wrappers)
    sd = load_file(str(ckpt_path))
    for nm in names:
        wrappers[nm]["A"].data.copy_(sd[f"A/{nm}"].to(device, torch.float32))
        wrappers[nm]["B"].data.copy_(sd[f"B/{nm}"].to(device, torch.float32))
    logger.info(f"loaded A/B into {len(names)} modules")
    pairs = load_pairs(cfg.pairs)
    logger.info(f"pairs {cfg.pairs} -> {len(pairs)}")
    model.eval()
    _, _, raw_grads, _ = extract_v_hack(model, tok, wrappers, pairs,
                                        top_k=1, tau_axis=0.0, n_heldout=2, device=device)
    v_grad = _build_v_grad(raw_grads, wrappers, 1, device)
    if cfg.random_v_seed is not None:
        v_grad = _haar_unit_dirs(v_grad, cfg.random_v_seed, device)
        logger.info(f"OVERRODE v_grad with Haar dirs seed={cfg.random_v_seed} (placebo)")
    route_band = route_band_edges(raw_grads, v_grad, device)

    def pair_pos(side: str) -> np.ndarray:
        """Pooled position of every authored pair on one side ("clean" or "hack")."""
        n = raw_grads[f"{side}/{names[0]}"].shape[0]
        out = []
        for i in range(n):
            # Pad with r zeros so the vector is 2r wide; pooled_pos only reads the first r.
            cg = {nm: torch.cat([raw_grads[f"{side}/{nm}"][i], torch.zeros(r)]) for nm in names}
            out.append(pooled_pos(cg, v_grad, route_band, names, r)[0])
        return np.array(out)

    syn_solve, syn_hack = pair_pos("clean"), pair_pos("hack")
    rollouts_path = cfg.run_dir / "rollouts.jsonl"
    # Skip blank lines so a stray empty line in the log does not crash json.loads.
    recs = [json.loads(l) for l in rollouts_path.read_text().splitlines() if l.strip()]
    batch = [r_ for r_ in recs if cfg.step_lo <= r_["step"] <= cfg.step_hi and r_["text"].strip()][:cfg.max_rollouts]
    logger.info(f"live batch: {len(batch)} rollouts (steps {cfg.step_lo}-{cfg.step_hi})")
    pos_live, labels, steps = [], [], []
    for i, rec in enumerate(batch):
        model.zero_grad(set_to_none=True)
        loss = completion_nll(model, tok, rec["prompt"], rec["text"], device)
        if not torch.isfinite(loss):
            continue
        loss.backward()
        # Sum the gate gradient over every leading dim, leaving one vector per module.
        cg = {nm: wrappers[nm]["layer"]._lora2r_gate.grad.sum(dim=tuple(range(
            wrappers[nm]["layer"]._lora2r_gate.grad.dim() - 1))) for nm in names}
        pos_live.append(pooled_pos(cg, v_grad, route_band, names, r)[0])
        labels.append(bool(rec["exploited"])); steps.append(rec["step"])
        if (i + 1) % 40 == 0:
            logger.info(f"  rollout {i+1}/{len(batch)}")
    model.zero_grad(set_to_none=True)
    if not pos_live:
        # Without this, np.array([]) is float and `~labels` below raises an unrelated TypeError.
        raise RuntimeError(
            f"no live rollouts scored from {rollouts_path} for steps {cfg.step_lo}-{cfg.step_hi} "
            f"({len(batch)} candidates; all filtered or non-finite loss)")
    pos_live = np.array(pos_live); labels = np.array(labels, dtype=bool); steps = np.array(steps)
    # Per-step AUROC only where both classes are present in that step.
    per_step = {int(s): round(_auroc(pos_live[steps == s].tolist(), labels[steps == s].tolist()), 2)
                for s in sorted(set(steps.tolist())) if labels[steps == s].any() and (~labels[steps == s]).any()}
    logger.info(f"live: {len(labels)} rollouts, {int(labels.sum())} exploited; per-step AUROC={per_step}")
    # save the 4 populations long-form -> regenerates the plot with --replot (no GPU)
    df = pl.concat([
        pl.DataFrame({"pop": "on_solve", "pos": pos_live[~labels], "step": steps[~labels]}),
        pl.DataFrame({"pop": "on_hack", "pos": pos_live[labels], "step": steps[labels]}),
        pl.DataFrame({"pop": "syn_solve", "pos": syn_solve, "step": np.full(len(syn_solve), -1)}),
        pl.DataFrame({"pop": "syn_hack", "pos": syn_hack, "step": np.full(len(syn_hack), -1)}),
    ])
    df.write_parquet(data_path)
    logger.info(f"wrote {data_path} ({len(df)} rows)")
    plot(df, f"{cfg.run_dir.name}\n{cfg.ckpt} v_grad, live steps {cfg.step_lo}-{cfg.step_hi}, "
             f"{int(labels.sum())}/{len(labels)} exploited", out_png)
    return 0
# Script entry: build Cfg from the command line via tyro and exit with main()'s return code.
if __name__ == "__main__":
    raise SystemExit(main(tyro.cli(Cfg)))
+227
View File
@@ -0,0 +1,227 @@
"""Refresh-tracking prototype: when G_hack (v_grad) refreshes, can we recompute the
live cosine stats from a TRACKED gradient cloud, without re-running the model?
Claim (k=1): the gate score pos_i = (S_i - L)/W with S_i = sum_m <u_{i,m}, v_m>, where
u_{i,m} = g_{i,m}/|g_{i,m}| is the v-INDEPENDENT unit per-module gradient. Stacking
U_i=[u_{i,1};..], V=[v_1;..] gives S_i = U_i . V (linear in V). So the live cosine
distribution under ANY v is an affine push of the SAME gradient cloud:
mean(pos) = (mu_U . V - L) / W
var(pos) = (V^T Sigma_U V) / W^2
mu_U, Sigma_U are tracked once (v-independent); L,W come from the pairs at refresh. So a
refresh needs no model re-run and no window flush -- just re-project the cloud onto v_new.
We simulate a refresh with two direction estimates v_A, v_B = unit-mean-diff of two DISJOINT
halves of the authored pairs (model fixed, only the direction changes), and verify:
(1) reproject stored U onto v == direct pos (exact, validates indexing)
(2) moment formula (mu_U, Sigma_U) == empirical mean/sd of pos, for BOTH v_A and v_B
from the SAME tracked cloud (the actual claim).
uv run python scripts/diag_pinning_refresh.py
outputs (out/diag/): pinning_refresh.png, pinning_refresh.parquet, prints a sanity table.
"""
from __future__ import annotations
import json
import struct
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import torch
import tyro
import polars as pl
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from loguru import logger
from tabulate import tabulate
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM, AutoTokenizer
from vgrout.lora2r import wrap_model_with_lora2r
from vgrout.pairs import load_pairs
from vgrout.extract_vhack_grad import extract_v_hack, completion_nll
from vgrout.train import _build_v_grad, route_band_edges, _auroc
from diag_pinning import _kde, SOLVE, HACK, MEANC # same dir (scripts/ is sys.path[0] when run directly)
# Command-line options for the refresh-tracking prototype; parsed by tyro.cli at the script entry.
@dataclass
class Cfg:
    # Run directory: main() reads `<ckpt>.safetensors` and `rollouts.jsonl` from it.
    run_dir: Path = Path("out/runs/20260611T003538_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v3")
    # Checkpoint file stem inside run_dir (".safetensors" is appended by main()).
    ckpt: str = "first_hack"
    # Authored pairs spec handed to load_pairs(); split into two halves to simulate a refresh.
    pairs: Path = Path("data/pairs/hack_pairs.md#all-in-one")
    # Inclusive step window of live rollouts to score (step_lo <= step <= step_hi).
    step_lo: int = 2
    step_hi: int = 9
    # Cap on the number of live rollouts kept after the step/text filter.
    max_rollouts: int = 240
    # Output directory for pinning_refresh.png and pinning_refresh.parquet (created if missing).
    out_dir: Path = Path("out/diag")
def _ckpt_meta(path: Path) -> dict:
    """Read the ``__metadata__`` mapping from a checkpoint header (``{}`` if absent).

    The file starts with an 8-byte little-endian unsigned length, then that many bytes
    of JSON; nothing past the header is read.
    """
    with open(path, "rb") as fh:
        size_prefix = fh.read(8)
        header_len = struct.unpack("<Q", size_prefix)[0]
        header = json.loads(fh.read(header_len))
    return header.get("__metadata__", {})
def stacked_v_and_band(raw_grads: dict, idx, wrappers, names, r, device):
    """Direction estimate + band from a SUBSET of pairs (simulated refresh).

    Every entry of ``raw_grads`` is indexed by ``idx`` before ``_build_v_grad`` and
    ``route_band_edges`` are applied to the subset.

    Returns a 5-tuple ``(V, L, W, v_grad, band)``:
        V: numpy vector, one length-``r`` block per module in ``names`` order; zeros for
           modules whose band width is not positive (they contribute nothing).
        L: sum of the lower band edges over included modules.
        W: sum of the band widths over included modules.
        v_grad, band: the per-module direction and (lower, upper) mappings as returned by
           the helpers (presumably {name: [1, r]} and {name: (lower, upper)} -- confirm
           against vgrout.train).

    Inclusion rule matches pooled_pos: a module enters iff its band width > 0.
    """
    subset = {key: grads[idx] for key, grads in raw_grads.items()}
    v_grad = _build_v_grad(subset, wrappers, 1, device)
    band = route_band_edges(subset, v_grad, device)
    blocks = []
    lower_sum = 0.0
    width_sum = 0.0
    for mod in names:
        band_lo, band_hi = band[mod]
        if band_hi - band_lo > 0:
            blocks.append(v_grad[mod][0].float().cpu())  # first (only) direction row
            lower_sum += band_lo
            width_sum += (band_hi - band_lo)
        else:
            blocks.append(torch.zeros(r))  # excluded -> no contribution
    return torch.cat(blocks).numpy(), float(lower_sum), float(width_sum), v_grad, band
def pos_from_U(U: np.ndarray, V: np.ndarray, L: float, W: float) -> np.ndarray:
    """pos = (U . V - L) / W -- re-project the stored cloud onto direction V (no model).

    Args:
        U: stacked unit-gradient cloud, one row per rollout.
        V: stacked direction vector (same trailing length as a row of U).
        L: pooled lower band edge.
        W: pooled band width; must be strictly positive.

    Raises:
        RuntimeError: if ``W`` is not > 0 (including nan). W is a sum of positive band
            widths, so this means no module was included -- the same condition under
            which pooled_pos raises, instead of silently returning inf/nan.
    """
    if not W > 0:
        raise RuntimeError("no module has positive band width")
    return (U @ V - L) / W
def main(cfg: Cfg) -> int:
    """Verify that live cosine stats can be recomputed from a tracked gradient cloud.

    Builds two direction estimates (v_A, v_B) from disjoint halves of the authored pairs,
    scores the live rollouts ONCE into a matrix U of stacked unit per-module gradients,
    then compares the empirical mean/std of the re-projected positions against the
    moment formula using only mu_U and Sigma_U. Writes ``pinning_refresh.parquet`` and
    ``pinning_refresh.png`` under ``cfg.out_dir`` and prints a sanity table.

    Returns:
        0 if the largest |empirical - moment| difference is below 1e-5, else 1.

    Raises:
        RuntimeError: if fewer than two pairs are available, or no live rollout could be
            scored.
    """
    cfg.out_dir.mkdir(parents=True, exist_ok=True)
    device = torch.device("cuda")
    ckpt_path = cfg.run_dir / f"{cfg.ckpt}.safetensors"
    meta = _ckpt_meta(ckpt_path)
    run_cfg = json.loads(meta.get("cfg", "{}"))
    # Same fallback chain as scripts/diag_pinning.py so both diagnostics load the same
    # model for a given checkpoint (cfg["model"], then top-level metadata, then default).
    model_name = run_cfg.get("model", meta.get("model", "Qwen/Qwen3-4B"))
    r = run_cfg.get("lora_r", 32)
    init_seed = run_cfg.get("lora_init_seed", 0)
    logger.info(f"ckpt {ckpt_path.name} step={meta.get('step')} model={model_name} r={r}")
    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2").to(device)
    model.config.use_cache = False
    wrappers = wrap_model_with_lora2r(model, r=r, init_seed=init_seed, grad_probe=True)
    names = sorted(wrappers)
    sd = load_file(str(ckpt_path))
    for nm in names:
        wrappers[nm]["A"].data.copy_(sd[f"A/{nm}"].to(device, torch.float32))
        wrappers[nm]["B"].data.copy_(sd[f"B/{nm}"].to(device, torch.float32))
    pairs = load_pairs(cfg.pairs)
    model.eval()
    _, _, raw_grads, _ = extract_v_hack(model, tok, wrappers, pairs,
                                        top_k=1, tau_axis=0.0, n_heldout=2, device=device)
    n_pairs = raw_grads[f"hack/{names[0]}"].shape[0]
    half = n_pairs // 2
    if half == 0:
        # An empty first half would give an undefined direction estimate.
        raise RuntimeError(f"need at least 2 authored pairs to form two direction estimates, got {n_pairs}")
    idx_A, idx_B = list(range(half)), list(range(half, n_pairs))  # disjoint pair halves = two v estimates
    V_A, L_A, W_A, vA, _ = stacked_v_and_band(raw_grads, idx_A, wrappers, names, r, device)
    V_B, L_B, W_B, vB, _ = stacked_v_and_band(raw_grads, idx_B, wrappers, names, r, device)
    V_A, V_B = V_A.astype(np.float64), V_B.astype(np.float64)  # float64 so the sanity gate reflects math, not rounding
    cosAB = float(np.dot(V_A, V_B) / (np.linalg.norm(V_A) * np.linalg.norm(V_B) + 1e-12))
    logger.info(f"two direction estimates from pair halves ({half} / {n_pairs-half}); "
                f"cos(V_A,V_B)={cosAB:+.3f} (a real refresh shift)")

    # ── score live batch ONCE: store U_i = stacked unit per-module gradients (v-INDEPENDENT) ──
    rollouts_path = cfg.run_dir / "rollouts.jsonl"
    # Skip blank lines so a stray empty line in the log does not crash json.loads.
    recs = [json.loads(l) for l in rollouts_path.read_text().splitlines() if l.strip()]
    batch = [x for x in recs if cfg.step_lo <= x["step"] <= cfg.step_hi and x["text"].strip()][:cfg.max_rollouts]
    logger.info(f"live batch: {len(batch)} rollouts (steps {cfg.step_lo}-{cfg.step_hi})")
    U_rows, labels = [], []
    for i, rec in enumerate(batch):
        model.zero_grad(set_to_none=True)
        loss = completion_nll(model, tok, rec["prompt"], rec["text"], device)
        if not torch.isfinite(loss):
            continue
        loss.backward()
        u_blocks = []
        for name in names:
            g = wrappers[name]["layer"]._lora2r_gate.grad
            g_b = g.sum(dim=tuple(range(g.dim() - 1)))[:r].float()  # deployed block [r]
            u = (g_b / g_b.norm().clamp_min(1e-12)).cpu()  # v-independent unit grad
            u_blocks.append(u)
        U_rows.append(torch.cat(u_blocks).numpy())
        labels.append(bool(rec["exploited"]))
        if (i + 1) % 40 == 0:
            logger.info(f"  rollout {i+1}/{len(batch)}")
    model.zero_grad(set_to_none=True)
    if not U_rows:
        raise RuntimeError(
            f"no live rollouts scored from {rollouts_path} for steps {cfg.step_lo}-{cfg.step_hi} "
            f"({len(batch)} candidates; all filtered or non-finite loss)")
    U = np.stack(U_rows).astype(np.float64)  # [n_roll, n_mod*r] the tracked gradient cloud
    labels = np.array(labels, dtype=bool)
    logger.info(f"stored U cloud: {U.shape} ({U.nbytes/1e6:.1f} MB), {int(labels.sum())} exploited")

    # ── pos under each direction by RE-PROJECTING the stored cloud (no model re-run) ──
    pos_A = pos_from_U(U, V_A, L_A, W_A)
    pos_B = pos_from_U(U, V_B, L_B, W_B)

    # ── moment formula from the tracked cloud: mu_U, Sigma_U (v-independent) ──
    mu_U = U.mean(0)  # [D]
    Uc = U - mu_U
    # Population (1/N) covariance of the centered cloud, so that sqrt(V^T Sigma_U V)
    # matches np.std's default ddof=0 used for the empirical side below.
    Sigma_U = (Uc.T @ Uc) / U.shape[0]  # [D, D]

    def moment_stats(V, L, W):
        """Mean and std of pos under direction V, from mu_U / Sigma_U only."""
        mean = (mu_U @ V - L) / W
        var = (V @ (Sigma_U @ V)) / (W * W)
        return mean, np.sqrt(max(var, 0.0))  # clamp tiny negative rounding before sqrt

    mA_emp, sA_emp = pos_A.mean(), pos_A.std()
    mB_emp, sB_emp = pos_B.mean(), pos_B.std()
    mA_mom, sA_mom = moment_stats(V_A, L_A, W_A)
    mB_mom, sB_mom = moment_stats(V_B, L_B, W_B)

    # ── sanity table ──
    rows = [
        ["mean pos | v_A", f"{mA_emp:+.4f}", f"{mA_mom:+.4f}", f"{abs(mA_emp-mA_mom):.2e}"],
        ["std  pos | v_A", f"{sA_emp:.4f}", f"{sA_mom:.4f}", f"{abs(sA_emp-sA_mom):.2e}"],
        ["mean pos | v_B", f"{mB_emp:+.4f}", f"{mB_mom:+.4f}", f"{abs(mB_emp-mB_mom):.2e}"],
        ["std  pos | v_B", f"{sB_emp:.4f}", f"{sB_mom:.4f}", f"{abs(sB_emp-sB_mom):.2e}"],
    ]
    print("\nSHOULD: moment-formula (from tracked mu_U, Sigma_U) == empirical, both v_A and v_B.")
    print(tabulate(rows, headers=["quantity", "empirical (direct)", "moment formula", "abs diff"],
                   tablefmt="github"))
    max_diff = max(abs(mA_emp-mA_mom), abs(sA_emp-sA_mom), abs(mB_emp-mB_mom), abs(sB_emp-sB_mom))
    ok = max_diff < 1e-5
    print(f"\n{'PASS' if ok else 'FAIL'}: max |empirical - moment| = {max_diff:.2e}  (refresh needs no model re-run)")
    logger.info(f"refresh shift: mean {mA_emp:+.3f}->{mB_emp:+.3f}, std {sA_emp:.3f}->{sB_emp:.3f}; "
                f"AUROC v_A={_auroc(pos_A.tolist(), labels.tolist()):.3f} v_B={_auroc(pos_B.tolist(), labels.tolist()):.3f}")

    # ── plot: pos under v_A (top) and v_B (bottom), mean +/- 2sd band recalibrates per direction ──
    pl.DataFrame({"pos_A": pos_A, "pos_B": pos_B, "exploited": labels.tolist()}).write_parquet(
        cfg.out_dir / "pinning_refresh.parquet")
    lo = min(pos_A.min(), pos_B.min()) - 0.1
    hi = max(pos_A.max(), pos_B.max()) + 0.1
    grid = np.linspace(lo, hi, 400)
    fig, axes = plt.subplots(2, 1, figsize=(8.6, 6.0), sharex=True)
    for ax, pos, mean, std, tag in [(axes[0], pos_A, mA_emp, sA_emp, "A (pairs 1st half)"),
                                    (axes[1], pos_B, mB_emp, sB_emp, "B (pairs 2nd half)")]:
        lo_b, hi_b = mean - 2 * std, mean + 2 * std
        # Band and mean lines first so the density curves are drawn on top.
        ax.axvspan(lo_b, hi_b, color=MEANC, alpha=0.07, lw=0)
        ax.axvline(mean, color=MEANC, lw=1.8)
        for b in (lo_b, hi_b):
            ax.axvline(b, color=MEANC, lw=1.2, ls="--")
        peak = 0.0
        for x, col in [(pos[~labels], SOLVE), (pos[labels], HACK)]:
            y = _kde(x, grid)
            ax.fill_between(grid, y, color=col, alpha=0.13, lw=0)
            ax.plot(grid, y, color=col, lw=1.9)
            peak = max(peak, y.max())
        ax.set_ylim(0, peak * 1.15)
        ax.set_ylabel("density")
        ax.set_title(f"direction estimate {tag}:  mean={mean:+.2f}, sd={std:.2f}, "
                     f"AUROC={_auroc(pos.tolist(), labels.tolist()):.2f}", fontsize=9, loc="left")
        for s in ("top", "right"):
            ax.spines[s].set_visible(False)
    dist = [Line2D([0], [0], color=SOLVE, lw=2), Line2D([0], [0], color=HACK, lw=2),
            Line2D([0], [0], color=MEANC, lw=1.8), Patch(facecolor=MEANC, alpha=0.18, ls="--", edgecolor=MEANC)]
    axes[0].legend(dist, ["on-policy solve", "on-policy hack", "online mean", "mean +/- 2sd"],
                   loc="upper left", fontsize=8, frameon=False)
    axes[1].set_xlabel("hacking direction (gradient cosine to v_grad)  " + r"$\longrightarrow$")
    fig.suptitle("refresh tracking: same gradient cloud, two G_hack estimates "
                 f"(cos(v_A,v_B)={cosAB:+.2f}) -- band recalibrates with no model re-run", fontsize=9.5)
    fig.tight_layout(rect=(0, 0, 1, 0.96))
    fig.savefig(cfg.out_dir / "pinning_refresh.png", dpi=140)
    plt.close(fig)
    logger.info(f"wrote {cfg.out_dir}/pinning_refresh.png, pinning_refresh.parquet")
    return 0 if ok else 1
# Script entry: build Cfg from the command line via tyro; the exit code is main()'s
# return value (0 when the sanity check passes, 1 otherwise).
if __name__ == "__main__":
    raise SystemExit(main(tyro.cli(Cfg)))
+15 -16
View File
@@ -72,8 +72,8 @@ class StepLogger:
def __init__(self, arm: str, modes: list[str], mode_code: dict[str, str],
show_ablate: bool = False) -> None:
# routeV reports routing diagnostics; absorb shares qmass (zone cols read nan).
is_route = arm in ("routingV_lora2r", "absorb_lora2r")
# Routing diagnostics are ALWAYS shown (nan on vanilla, whose gate never runs) so the
# column layout is identical across arms -- vanilla/routeV/absorb tables line up.
cols: list[_Col] = [
_Col("step", 4, "step", "d", "GRPO step"),
_Col("ref_eq", 6, "ref_eq", ".2f", "vanilla-equiv step (cum_gens/256)"),
@@ -99,21 +99,20 @@ class StepLogger:
_Col("gn", 7, "gn", ".1e", "pre-clip L2 norm of A/B grads (vs grad_clip)"),
_Col("lr", 7, "lr", ".1e", "scheduled learning rate"),
]
# routeV reports unit and energy shares across the routing band.
if is_route:
cols += [
_Col("auroc", 6, "auroc", ".2f", "AUROC of pooled cos(g,v_grad) as a hack detector vs the hack-label (student exploited + teacher cached); MEASUREMENT only, never routes. ~0.5 = v_grad blind to live hacks (no threshold helps); high but rout~0 = pure threshold/scale problem; a drop at a refresh = refresh destroyed separation"),
_Col("cosU", 6, "cosU", "+.2f", "pooled cos(v_grad, summed-rollout c-grad): is the net update moving hack-ward this step"),
_Col("qmass", 6, "qmass", ".2f", "quarantine energy share ||g_quar||/(||g_keep||+||g_quar||): fraction of the update landing in the throwaway quarantine block"),
_Col("keep", 6, "keep", ".2f", "rollout share labelled clean (below band) -> deployed-only, quarantine off"),
_Col("resid", 6, "resid", ".2f", "rollout share labelled mid (inside band) -> both blocks train (absorption)"),
_Col("rout", 6, "rout", ".2f", "rollout share labelled hack (above band) -> quarantine-only, deployed detached"),
_Col("keepE", 6, "keepE", ".2f", "energy-weighted keep (grad-norm-weighted clean share)"),
_Col("residE", 6, "residE", ".2f", "energy-weighted resid"),
_Col("routE", 6, "routE", ".2f", "energy-weighted rout"),
]
# routeV reports unit and energy shares across the routing band (nan on vanilla/absorb).
cols += [
_Col("auroc", 6, "auroc", ".2f", "AUROC of pooled cos(g,v_grad) as a hack detector vs the hack-label (student exploited + teacher cached); MEASUREMENT only, never routes. ~0.5 = v_grad blind to live hacks (no threshold helps); high but rout~0 = pure threshold/scale problem; a drop at a refresh = refresh destroyed separation"),
_Col("cosU", 6, "cosU", "+.2f", "pooled cos(v_grad, summed-rollout c-grad): is the net update moving hack-ward this step"),
_Col("qmass", 6, "qmass", ".2f", "quarantine energy share ||g_quar||/(||g_keep||+||g_quar||): fraction of the update landing in the throwaway quarantine block"),
_Col("keep", 6, "keep", ".2f", "rollout share labelled clean (below band) -> deployed-only, quarantine off"),
_Col("resid", 6, "resid", ".2f", "rollout share labelled mid (inside band) -> both blocks train (absorption)"),
_Col("rout", 6, "rout", ".2f", "rollout share labelled hack (above band) -> quarantine-only, deployed detached"),
_Col("keepE", 6, "keepE", ".2f", "energy-weighted keep (grad-norm-weighted clean share)"),
_Col("residE", 6, "residE", ".2f", "energy-weighted resid"),
_Col("routE", 6, "routE", ".2f", "energy-weighted rout"),
]
# Show the training-prompt deploy proxy only when an ablated slice exists.
if is_route and show_ablate:
if show_ablate:
cols += [
_Col("hack_abl", 6, "hk_abl", "frac", "per-step deploy proxy: hack rate on the ablated (deploy-mode) rollout slice; train prompts, noisier than hk_dep"),
_Col("solve_abl", 6, "slv_abl", "frac", "per-step deploy proxy: solve rate on the ablated (deploy-mode) rollout slice; train prompts"),
+14 -13
View File
@@ -602,22 +602,23 @@ def main(cfg: Config) -> int:
if n_inc == 0:
raise RuntimeError("no module has positive band width; pairs separate nowhere")
pos = num / den; w /= n_inc
# ── online-stats gate (three zones, keep = bulk) ── The authored absolute thresholds
# (clean<=0, hack>=1) don't work: live hacks sit far below the authored p75-hack pooled
# edge, so rout~0. Route by live QUANTILES of pos instead: top route_quantile tail ->
# hack (rout), the next route_quantile band -> mid (absorption), everything below ->
# keep. keep is the BULK; the previous t_lo=quantile(route_quantile) put keep at the
# bottom route_quantile and routed ~(1-route_quantile) into mid, so the quarantine
# trained on almost every rollout (qmass~0.5). v_grad stays authored-only; only the
# thresholds follow the live distribution. The window includes this batch, so step 0
# self-calibrates. TODO(#30): calibrate the rout/mid widths from the batch distribution.
# ── online-stats gate (#30): mean + k*std, three zones, keep = bulk ── The authored
# absolute band is mis-placed (live pos sits far below the synthetic-hack edge; even
# synthetic SOLVE is more hack-aligned than on-policy hack -- out/diag/pinning_calib.png),
# and a fixed quantile FORCES route_quantile out every step even when nothing separates.
# Route by the live MEAN + k*STD instead: pos > mean+route_std_mid*std -> mid, pos >=
# mean+route_std_rout*std -> rout, below -> keep (bulk). Self-silencing: only the tail
# that genuinely exceeds the spread routes, so qmass tracks real separation rather than a
# forced fraction. v_grad stays authored-only; the threshold follows the live distribution.
# The window includes this batch, so step 0 self-calibrates; flushed on v_grad refresh.
route_pos_window.extend(pos.detach().cpu().tolist())
ref = torch.tensor(list(route_pos_window))
t_hi = ref.quantile(1.0 - cfg.route_quantile).item() # top route_quantile -> rout
t_lo = ref.quantile(1.0 - 2.0 * cfg.route_quantile).item() # next band -> mid; below -> keep
mu_pos, sd_pos = ref.mean().item(), ref.std().item()
t_lo = mu_pos + cfg.route_std_mid * sd_pos # mid onset
t_hi = mu_pos + cfg.route_std_rout * sd_pos # rout onset (>= mid)
logger.debug(f"gate pos: n_inc={n_inc}/{len(wrappers)} min={pos.min().item():+.2f} "
f"max={pos.max().item():+.2f} | online t_lo={t_lo:+.2f} t_hi={t_hi:+.2f} "
f"win={len(route_pos_window)}")
f"max={pos.max().item():+.2f} | mean={mu_pos:+.2f} std={sd_pos:.2f} "
f"t_lo={t_lo:+.2f} t_hi={t_hi:+.2f} win={len(route_pos_window)}")
m = (pos > t_lo).float() # mid + rout -> quarantine trains (keep = bulk below t_lo)
d = (pos >= t_hi).float() # top tail -> hack -> deployed detached
return m, d, 0.5 * m + 0.5 * d, w, pos, cosU_sum / n_inc
+12 -7
View File
@@ -49,13 +49,18 @@ class Config:
# (alignment to ANY known hack sub-mode) -- catches multi-modal hack signal one mean
# washes out. k=1 stays mean-diff (not SVD-top-1) so "mean-mass vs top-k" is a clean A/B.
v_grad_k: int = 1
# Online-stats gate: route by live QUANTILES of the pooled cosine-to-v_grad, not the
# authored absolute band (whose p75-hack edge live hacks never reach -> rout~0). Each
# step the top route_quantile tail -> hack (deployed detached), bottom -> keep, middle
# -> mid. route_window = sliding buffer of recent pooled positions, flushed on refresh
# (positions are measured against one v_grad). Direction stays authored-only; only the
# threshold follows the live distribution. TODO(#30): center+width calibration vs fixed %.
route_quantile: float = 0.05
# Online-stats gate (#30): route by the live MEAN + k*STD of the pooled cosine-to-v_grad,
# not a fixed quantile. The authored absolute band is mis-placed (live pos sits far below
# the synthetic-hack edge; even synthetic SOLVE is more hack-aligned than on-policy hack --
# see out/diag/pinning_calib.png), and a fixed quantile FORCES route_quantile of the batch
# out every step even when nothing separates. mean+k*std self-silences: it only routes the
# tail that genuinely exceeds the spread. pos > mean + route_std_mid*std -> mid (absorption);
# pos >= mean + route_std_rout*std -> rout (hack, deployed detached); below -> keep (bulk).
# route_window = sliding buffer of recent pooled positions, flushed on v_grad refresh
# (positions measured against one v_grad). Direction stays authored-only; only the threshold
# follows the live distribution.
route_std_mid: float = 2.0
route_std_rout: float = 3.0
route_window: int = 512
# Haar-random direction control (placebo): same routing machinery, no pair signal.
routeV_random_v_seed: int | None = None