Files
evil_MoE/scripts/diag_pinning_refresh.py
T
wassname 979daf84fd feat(#30): mean+k*std online gate replaces fixed quantile; always-show route cols
Gate calibration: route by the live mean + k * std of the pooled cosine-to-v_grad
(k = route_std_mid or route_std_rout), not a fixed quantile tail. The gate self-silences --
only the tail that genuinely exceeds the spread routes, so qmass tracks real separation
instead of a forced fraction. The authored absolute band is misplaced (live pos sits far
below the synthetic-hack edge; even synthetic solve out-aligns on-policy hack).

tablelog: auroc/rout/routE/keep/resid/qmass cols always shown (nan on vanilla) so
arm tables line up.

Diagnostics: scripts/diag_pinning.py (4-population calibration view, mean+/-2sd band)
and scripts/diag_pinning_refresh.py (proves cosine stats recompute from a tracked
v-independent gradient cloud on a v_grad refresh -- exact for k=1, sanity 2.5e-16).

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
2026-06-11 02:56:07 +00:00

228 lines
11 KiB
Python

"""Refresh-tracking prototype: when G_hack (v_grad) refreshes, can we recompute the
live cosine stats from a TRACKED gradient cloud, without re-running the model?
Claim (k=1): the gate score pos_i = (S_i - L)/W with S_i = sum_m <u_{i,m}, v_m>, where
u_{i,m} = g_{i,m}/|g_{i,m}| is the v-INDEPENDENT unit per-module gradient. Stacking
U_i=[u_{i,1};..], V=[v_1;..] gives S_i = U_i . V (linear in V). So the live cosine
distribution under ANY v is an affine push of the SAME gradient cloud:
mean(pos) = (mu_U . V - L) / W
var(pos) = (V^T Sigma_U V) / W^2
mu_U, Sigma_U are tracked once (v-independent); L,W come from the pairs at refresh. So a
refresh needs no model re-run and no window flush -- just re-project the cloud onto v_new.
We simulate a refresh with two direction estimates v_A, v_B = unit-mean-diff of two DISJOINT
halves of the authored pairs (model fixed, only the direction changes), and verify:
(1) reproject stored U onto v == direct pos (exact, validates indexing)
(2) moment formula (mu_U, Sigma_U) == empirical mean/sd of pos, for BOTH v_A and v_B
from the SAME tracked cloud (the actual claim).
uv run python scripts/diag_pinning_refresh.py
outputs (out/diag/): pinning_refresh.png, pinning_refresh.parquet, prints a sanity table.
"""
from __future__ import annotations
import json
import struct
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import torch
import tyro
import polars as pl
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from loguru import logger
from tabulate import tabulate
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM, AutoTokenizer
from vgrout.lora2r import wrap_model_with_lora2r
from vgrout.pairs import load_pairs
from vgrout.extract_vhack_grad import extract_v_hack, completion_nll
from vgrout.train import _build_v_grad, route_band_edges, _auroc
from diag_pinning import _kde, SOLVE, HACK, MEANC # same dir (scripts/ is sys.path[0] when run directly)
@dataclass
class Cfg:
    """CLI config for the refresh-tracking diagnostic (parsed with tyro in __main__)."""
    run_dir: Path = Path("out/runs/20260611T003538_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v3")  # run dir holding {ckpt}.safetensors and rollouts.jsonl
    ckpt: str = "first_hack"  # checkpoint stem; loads run_dir/{ckpt}.safetensors
    pairs: Path = Path("data/pairs/hack_pairs.md#all-in-one")  # authored pairs spec passed to load_pairs
    step_lo: int = 2  # first rollout step (inclusive) taken into the live batch
    step_hi: int = 9  # last rollout step (inclusive) taken into the live batch
    max_rollouts: int = 240  # cap on live rollouts, applied after the step / non-empty-text filter
    out_dir: Path = Path("out/diag")  # output dir for pinning_refresh.png and pinning_refresh.parquet
def _ckpt_meta(path: Path) -> dict:
with open(path, "rb") as f:
return json.loads(f.read(struct.unpack("<Q", f.read(8))[0])).get("__metadata__", {})
def stacked_v_and_band(raw_grads: dict, idx, wrappers, names, r, device):
    """Direction estimate + band from a SUBSET of pairs (simulated refresh).

    Matches pooled_pos: a module enters iff its band width > 0; excluded modules contribute 0.

    Args:
        raw_grads: per-pair gradient tensors; every value is indexed along dim 0 by ``idx``.
        idx: indices of the pairs forming this direction estimate.
        wrappers, device: forwarded to ``_build_v_grad`` / ``route_band_edges``.
        names: module names, in the order their blocks are stacked into V.
        r: per-module block width (size of the zero block used for excluded modules).

    Returns:
        (V, L, W, v_grad, band):
          V      -- stacked direction [len(names) * r] as a numpy array, zeroed on excluded modules;
          L, W   -- floats: summed lower edges and summed band widths of the included modules;
          v_grad -- raw ``_build_v_grad`` output ({name: [1, r]});
          band   -- raw ``route_band_edges`` output ({name: (lower, upper)}).

    Raises:
        ValueError: if no module has a positive band width. W would then be 0 and
            pos = (U . V - L) / W undefined.
    """
    sub = {k: v[idx] for k, v in raw_grads.items()}
    v_grad = _build_v_grad(sub, wrappers, 1, device)  # {name: [1, r]} unit
    band = route_band_edges(sub, v_grad, device)  # {name: (lower, upper)}
    blocks, L, W = [], 0.0, 0.0
    for name in names:
        lower, upper = band[name]
        width = upper - lower
        if width > 0:
            blocks.append(v_grad[name][0].float().cpu())  # [r]
            L += lower
            W += width
        else:
            blocks.append(torch.zeros(r))  # excluded -> no contribution
    if not W > 0:
        raise ValueError(
            f"all {len(names)} module bands have width <= 0 for this pair subset "
            f"({len(idx)} pairs); pos = (U.V - L)/W is undefined")
    return torch.cat(blocks).numpy(), float(L), float(W), v_grad, band
def pos_from_U(U: np.ndarray, V: np.ndarray, L: float, W: float) -> np.ndarray:
    """Re-project a stored gradient cloud onto direction V and express it in band units.

    Row-wise pos = (U @ V - L) / W; no model forward/backward is involved.
    U is [n, D] and V is [D]. L and W are the band's summed lower edge and summed width,
    as produced by stacked_v_and_band.
    """
    raw_score = U @ V
    shifted = raw_score - L
    return shifted / W
def _score_rollouts(model, tok, wrappers, names, r, batch, device):
    """Backprop each rollout's completion NLL once and stack its unit per-module gradients.

    Returns (U, labels). U is a float64 array [n_scored, len(names) * r], one row per rollout
    whose loss was finite; this is the v-INDEPENDENT tracked gradient cloud. labels is a bool
    array of the matching ``exploited`` flags. Rollouts with a non-finite loss are skipped and
    counted in a warning.

    Raises:
        RuntimeError: if no rollout produced a finite loss (the cloud would be empty).
    """
    U_rows, labels, n_skipped = [], [], 0
    for i, rec in enumerate(batch):
        model.zero_grad(set_to_none=True)
        loss = completion_nll(model, tok, rec["prompt"], rec["text"], device)
        if not torch.isfinite(loss):
            n_skipped += 1
            continue
        loss.backward()
        u_blocks = []
        for name in names:
            g = wrappers[name]["layer"]._lora2r_gate.grad
            g_b = g.sum(dim=tuple(range(g.dim() - 1)))[:r].float()  # deployed block [r]
            u_blocks.append((g_b / g_b.norm().clamp_min(1e-12)).cpu())  # v-independent unit grad
        U_rows.append(torch.cat(u_blocks).numpy())
        labels.append(bool(rec["exploited"]))
        if (i + 1) % 40 == 0:
            logger.info(f" rollout {i+1}/{len(batch)}")
    model.zero_grad(set_to_none=True)
    if n_skipped:
        logger.warning(f"skipped {n_skipped}/{len(batch)} rollouts with non-finite loss")
    if not U_rows:
        raise RuntimeError(f"no rollout produced a finite loss ({len(batch)} in batch); "
                           "check step_lo/step_hi and rollouts.jsonl")
    return np.stack(U_rows).astype(np.float64), np.array(labels, dtype=bool)


def _plot_refresh(panels, labels, cosAB, out_path):
    """Plot solve-vs-hack KDEs of pos, one row per direction estimate, with a mean +/- 2sd band.

    panels: sequence of (pos, mean, std, auroc, tag), one stacked subplot each.
    labels: bool array aligned with every pos (True = exploited rollout).
    Saves the figure to out_path at 140 dpi and closes it.
    """
    lo = min(p[0].min() for p in panels) - 0.1
    hi = max(p[0].max() for p in panels) + 0.1
    grid = np.linspace(lo, hi, 400)
    fig, axes = plt.subplots(len(panels), 1, figsize=(8.6, 6.0), sharex=True, squeeze=False)
    axes = axes[:, 0]
    for ax, (pos, mean, std, auc, tag) in zip(axes, panels):
        lo_b, hi_b = mean - 2 * std, mean + 2 * std
        ax.axvspan(lo_b, hi_b, color=MEANC, alpha=0.07, lw=0)
        ax.axvline(mean, color=MEANC, lw=1.8)
        for b in (lo_b, hi_b):
            ax.axvline(b, color=MEANC, lw=1.2, ls="--")
        peak = 0.0
        for x, col in [(pos[~labels], SOLVE), (pos[labels], HACK)]:
            if x.size == 0:  # class absent from this batch: no samples to estimate a density from
                continue
            y = _kde(x, grid)
            ax.fill_between(grid, y, color=col, alpha=0.13, lw=0)
            ax.plot(grid, y, color=col, lw=1.9)
            peak = max(peak, y.max())
        ax.set_ylim(0, peak * 1.15)
        ax.set_ylabel("density")
        ax.set_title(f"direction estimate {tag}: mean={mean:+.2f}, sd={std:.2f}, "
                     f"AUROC={auc:.2f}", fontsize=9, loc="left")
        for s in ("top", "right"):
            ax.spines[s].set_visible(False)
    dist = [Line2D([0], [0], color=SOLVE, lw=2), Line2D([0], [0], color=HACK, lw=2),
            Line2D([0], [0], color=MEANC, lw=1.8), Patch(facecolor=MEANC, alpha=0.18, ls="--", edgecolor=MEANC)]
    axes[0].legend(dist, ["on-policy solve", "on-policy hack", "online mean", "mean +/- 2sd"],
                   loc="upper left", fontsize=8, frameon=False)
    axes[-1].set_xlabel("hacking direction (gradient cosine to v_grad) " + r"$\longrightarrow$")
    fig.suptitle("refresh tracking: same gradient cloud, two G_hack estimates "
                 f"(cos(v_A,v_B)={cosAB:+.2f}) -- band recalibrates with no model re-run", fontsize=9.5)
    fig.tight_layout(rect=(0, 0, 1, 0.96))
    fig.savefig(out_path, dpi=140)
    plt.close(fig)


def main(cfg: Cfg) -> int:
    """Simulate a v_grad refresh and check the cloud-moment formula against direct re-projection.

    Steps:
      1. Load the LoRA2r checkpoint ``cfg.run_dir/{cfg.ckpt}.safetensors``.
      2. Build two direction estimates from disjoint halves of the authored pairs.
      3. Score the live rollouts once into a v-independent unit-gradient cloud U.
      4. For both directions, compare the empirical mean/sd of pos with the (mu_U, Sigma_U)
         moment formula.

    Writes pinning_refresh.parquet and pinning_refresh.png into cfg.out_dir.

    Returns:
        0 if the max |empirical - moment| is below 1e-5, else 1. Used as the process exit code.

    Raises:
        ValueError: if there are fewer than 2 pairs (no two disjoint estimates).
        RuntimeError: if no live rollout yields a finite loss.
    """
    cfg.out_dir.mkdir(parents=True, exist_ok=True)
    device = torch.device("cuda")
    ckpt_path = cfg.run_dir / f"{cfg.ckpt}.safetensors"
    meta = _ckpt_meta(ckpt_path)
    run_cfg = json.loads(meta.get("cfg", "{}"))
    model_name = run_cfg.get("model", "Qwen/Qwen3-4B")
    r = run_cfg.get("lora_r", 32)
    init_seed = run_cfg.get("lora_init_seed", 0)
    logger.info(f"ckpt {ckpt_path.name} step={meta.get('step')} model={model_name} r={r}")

    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2").to(device)
    model.config.use_cache = False
    wrappers = wrap_model_with_lora2r(model, r=r, init_seed=init_seed, grad_probe=True)
    names = sorted(wrappers)
    sd = load_file(str(ckpt_path))
    for nm in names:
        wrappers[nm]["A"].data.copy_(sd[f"A/{nm}"].to(device, torch.float32))
        wrappers[nm]["B"].data.copy_(sd[f"B/{nm}"].to(device, torch.float32))
    pairs = load_pairs(cfg.pairs)
    model.eval()
    _, _, raw_grads, _ = extract_v_hack(model, tok, wrappers, pairs,
                                        top_k=1, tau_axis=0.0, n_heldout=2, device=device)

    # ── simulated refresh: two direction estimates from DISJOINT halves of the pairs ──
    n_pairs = raw_grads[f"hack/{names[0]}"].shape[0]
    half = n_pairs // 2
    if half == 0:
        raise ValueError(f"need >= 2 pairs for two disjoint direction estimates, got {n_pairs}")
    idx_A, idx_B = list(range(half)), list(range(half, n_pairs))  # disjoint pair halves = two v estimates
    V_A, L_A, W_A, _, _ = stacked_v_and_band(raw_grads, idx_A, wrappers, names, r, device)
    V_B, L_B, W_B, _, _ = stacked_v_and_band(raw_grads, idx_B, wrappers, names, r, device)
    V_A, V_B = V_A.astype(np.float64), V_B.astype(np.float64)  # float64 so the sanity gate reflects math, not rounding
    cosAB = float(np.dot(V_A, V_B) / (np.linalg.norm(V_A) * np.linalg.norm(V_B) + 1e-12))
    logger.info(f"two direction estimates from pair halves ({half} / {n_pairs-half}); "
                f"cos(V_A,V_B)={cosAB:+.3f} (a real refresh shift)")

    # ── score live batch ONCE: store U_i = stacked unit per-module gradients (v-INDEPENDENT) ──
    rollout_lines = (cfg.run_dir / "rollouts.jsonl").read_text(encoding="utf-8").splitlines()
    recs = [json.loads(line) for line in rollout_lines if line.strip()]  # tolerate blank lines
    batch = [x for x in recs
             if cfg.step_lo <= x["step"] <= cfg.step_hi and x["text"].strip()][:cfg.max_rollouts]
    logger.info(f"live batch: {len(batch)} rollouts (steps {cfg.step_lo}-{cfg.step_hi})")
    U, labels = _score_rollouts(model, tok, wrappers, names, r, batch, device)  # [n_roll, n_mod*r] tracked cloud
    logger.info(f"stored U cloud: {U.shape} ({U.nbytes/1e6:.1f} MB), {int(labels.sum())} exploited")

    # ── pos under each direction by RE-PROJECTING the stored cloud (no model re-run) ──
    pos_A = pos_from_U(U, V_A, L_A, W_A)
    pos_B = pos_from_U(U, V_B, L_B, W_B)

    # ── moment formula from the tracked cloud: mu_U, Sigma_U (v-independent) ──
    mu_U = U.mean(0)  # [D]
    Uc = U - mu_U
    # Population (1/N) covariance, matching the ddof=0 default of np.std used for the empirical sd.
    # Dense [D, D] with D = n_mod*r, so this is the script's main memory cost.
    Sigma_U = (Uc.T @ Uc) / U.shape[0]

    def moment_stats(V, L, W):
        """(mean, sd) of pos predicted from (mu_U, Sigma_U) alone, for direction V and band L, W."""
        mean = (mu_U @ V - L) / W
        var = (V @ (Sigma_U @ V)) / (W * W)
        return mean, np.sqrt(max(var, 0.0))  # clamp tiny negative round-off before sqrt

    mA_emp, sA_emp = pos_A.mean(), pos_A.std()
    mB_emp, sB_emp = pos_B.mean(), pos_B.std()
    mA_mom, sA_mom = moment_stats(V_A, L_A, W_A)
    mB_mom, sB_mom = moment_stats(V_B, L_B, W_B)

    # ── sanity table ──
    rows = [
        ["mean pos | v_A", f"{mA_emp:+.4f}", f"{mA_mom:+.4f}", f"{abs(mA_emp-mA_mom):.2e}"],
        ["std pos | v_A", f"{sA_emp:.4f}", f"{sA_mom:.4f}", f"{abs(sA_emp-sA_mom):.2e}"],
        ["mean pos | v_B", f"{mB_emp:+.4f}", f"{mB_mom:+.4f}", f"{abs(mB_emp-mB_mom):.2e}"],
        ["std pos | v_B", f"{sB_emp:.4f}", f"{sB_mom:.4f}", f"{abs(sB_emp-sB_mom):.2e}"],
    ]
    print("\nSHOULD: moment-formula (from tracked mu_U, Sigma_U) == empirical, both v_A and v_B.")
    print(tabulate(rows, headers=["quantity", "empirical (direct)", "moment formula", "abs diff"],
                   tablefmt="github"))
    max_diff = max(abs(mA_emp-mA_mom), abs(sA_emp-sA_mom), abs(mB_emp-mB_mom), abs(sB_emp-sB_mom))
    ok = max_diff < 1e-5
    print(f"\n{'PASS' if ok else 'FAIL'}: max |empirical - moment| = {max_diff:.2e} (refresh needs no model re-run)")

    auc_A = _auroc(pos_A.tolist(), labels.tolist())
    auc_B = _auroc(pos_B.tolist(), labels.tolist())
    logger.info(f"refresh shift: mean {mA_emp:+.3f}->{mB_emp:+.3f}, std {sA_emp:.3f}->{sB_emp:.3f}; "
                f"AUROC v_A={auc_A:.3f} v_B={auc_B:.3f}")

    # ── outputs: per-rollout pos table + plot (mean +/- 2sd band recalibrates per direction) ──
    pl.DataFrame({"pos_A": pos_A, "pos_B": pos_B, "exploited": labels.tolist()}).write_parquet(
        cfg.out_dir / "pinning_refresh.parquet")
    _plot_refresh([(pos_A, mA_emp, sA_emp, auc_A, "A (pairs 1st half)"),
                   (pos_B, mB_emp, sB_emp, auc_B, "B (pairs 2nd half)")],
                  labels, cosAB, cfg.out_dir / "pinning_refresh.png")
    logger.info(f"wrote {cfg.out_dir}/pinning_refresh.png, pinning_refresh.parquet")
    return 0 if ok else 1
if __name__ == "__main__":
    # Parse Cfg from the command line, run the diagnostic, and exit with its status (0 = PASS).
    exit_code = main(tyro.cli(Cfg))
    raise SystemExit(exit_code)