"""Refresh-tracking prototype: when G_hack (v_grad) refreshes, can we recompute the live cosine stats from a TRACKED gradient cloud, without re-running the model? Claim (k=1): the gate score pos_i = (S_i - L)/W with S_i = sum_m , where u_{i,m} = g_{i,m}/|g_{i,m}| is the v-INDEPENDENT unit per-module gradient. Stacking U_i=[u_{i,1};..], V=[v_1;..] gives S_i = U_i . V (linear in V). So the live cosine distribution under ANY v is an affine push of the SAME gradient cloud: mean(pos) = (mu_U . V - L) / W var(pos) = (V^T Sigma_U V) / W^2 mu_U, Sigma_U are tracked once (v-independent); L,W come from the pairs at refresh. So a refresh needs no model re-run and no window flush -- just re-project the cloud onto v_new. We simulate a refresh with two direction estimates v_A, v_B = unit-mean-diff of two DISJOINT halves of the authored pairs (model fixed, only the direction changes), and verify: (1) reproject stored U onto v == direct pos (exact, validates indexing) (2) moment formula (mu_U, Sigma_U) == empirical mean/sd of pos, for BOTH v_A and v_B from the SAME tracked cloud (the actual claim). uv run python scripts/diag_pinning_refresh.py outputs (out/diag/): pinning_refresh.png, pinning_refresh.parquet, prints a sanity table. """ from __future__ import annotations import json import struct from dataclasses import dataclass from pathlib import Path import numpy as np import torch import tyro import polars as pl import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt from matplotlib.lines import Line2D from matplotlib.patches import Patch from loguru import logger from tabulate import tabulate from safetensors.torch import load_file from transformers import AutoModelForCausalLM, AutoTokenizer from vgrout.lora2r import wrap_model_with_lora2r from vgrout.pairs import load_pairs from vgrout.extract_vhack_grad import extract_v_hack, completion_nll from vgrout.train import _build_v_grad, route_band_edges, _auroc from diag_pinning import _kde, SOLVE, HACK, ABSORB_C as MEANC # same dir (scripts/ is sys.path[0] when run directly) @dataclass class Cfg: run_dir: Path = Path("out/runs/20260611T003538_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v3") ckpt: str = "first_hack" pairs: Path = Path("data/pairs/hack_pairs.md#all-in-one") step_lo: int = 2 step_hi: int = 9 max_rollouts: int = 240 out_dir: Path = Path("out/diag") def _ckpt_meta(path: Path) -> dict: with open(path, "rb") as f: return json.loads(f.read(struct.unpack(" 0; excluded modules contribute 0.""" sub = {k: v[idx] for k, v in raw_grads.items()} v_grad = _build_v_grad(sub, wrappers, 1, device) # {name: [1, r]} unit band = route_band_edges(sub, v_grad, device) # {name: (lower, upper)} V, L, W = [], 0.0, 0.0 for name in names: lower, upper = band[name] if upper - lower > 0: V.append(v_grad[name][0].float().cpu()) # [r] L += lower; W += (upper - lower) else: V.append(torch.zeros(r)) # excluded -> no contribution return torch.cat(V).numpy(), float(L), float(W), v_grad, band def pos_from_U(U: np.ndarray, V: np.ndarray, L: float, W: float) -> np.ndarray: """pos = (U . V - L) / W -- re-project the stored cloud onto direction V (no model).""" return (U @ V - L) / W def main(cfg: Cfg) -> int: cfg.out_dir.mkdir(parents=True, exist_ok=True) device = torch.device("cuda") ckpt_path = cfg.run_dir / f"{cfg.ckpt}.safetensors" meta = _ckpt_meta(ckpt_path) run_cfg = json.loads(meta.get("cfg", "{}")) model_name = run_cfg.get("model", "Qwen/Qwen3-4B") r = run_cfg.get("lora_r", 32) init_seed = run_cfg.get("lora_init_seed", 0) logger.info(f"ckpt {ckpt_path.name} step={meta.get('step')} model={model_name} r={r}") tok = AutoTokenizer.from_pretrained(model_name) if tok.pad_token_id is None: tok.pad_token = tok.eos_token model = AutoModelForCausalLM.from_pretrained( model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2").to(device) model.config.use_cache = False wrappers = wrap_model_with_lora2r(model, r=r, init_seed=init_seed, grad_probe=True) names = sorted(wrappers) sd = load_file(str(ckpt_path)) for nm in names: wrappers[nm]["A"].data.copy_(sd[f"A/{nm}"].to(device, torch.float32)) wrappers[nm]["B"].data.copy_(sd[f"B/{nm}"].to(device, torch.float32)) pairs = load_pairs(cfg.pairs) model.eval() _, _, raw_grads, _ = extract_v_hack(model, tok, wrappers, pairs, top_k=1, tau_axis=0.0, n_heldout=2, device=device) n_pairs = raw_grads[f"hack/{names[0]}"].shape[0] half = n_pairs // 2 idx_A, idx_B = list(range(half)), list(range(half, n_pairs)) # disjoint pair halves = two v estimates V_A, L_A, W_A, vA, _ = stacked_v_and_band(raw_grads, idx_A, wrappers, names, r, device) V_B, L_B, W_B, vB, _ = stacked_v_and_band(raw_grads, idx_B, wrappers, names, r, device) V_A, V_B = V_A.astype(np.float64), V_B.astype(np.float64) # float64 so the sanity gate reflects math, not rounding cosAB = float(np.dot(V_A, V_B) / (np.linalg.norm(V_A) * np.linalg.norm(V_B) + 1e-12)) logger.info(f"two direction estimates from pair halves ({half} / {n_pairs-half}); " f"cos(V_A,V_B)={cosAB:+.3f} (a real refresh shift)") # ── score live batch ONCE: store U_i = stacked unit per-module gradients (v-INDEPENDENT) ── recs = [json.loads(l) for l in (cfg.run_dir / "rollouts.jsonl").read_text().splitlines()] batch = [x for x in recs if cfg.step_lo <= x["step"] <= cfg.step_hi and x["text"].strip()][:cfg.max_rollouts] logger.info(f"live batch: {len(batch)} rollouts (steps {cfg.step_lo}-{cfg.step_hi})") U_rows, labels = [], [] for i, rec in enumerate(batch): model.zero_grad(set_to_none=True) loss = completion_nll(model, tok, rec["prompt"], rec["text"], device) if not torch.isfinite(loss): continue loss.backward() u_blocks = [] for name in names: g = wrappers[name]["layer"]._lora2r_gate.grad g_b = g.sum(dim=tuple(range(g.dim() - 1)))[:r].float() # deployed block [r] u = (g_b / g_b.norm().clamp_min(1e-12)).cpu() # v-independent unit grad u_blocks.append(u) U_rows.append(torch.cat(u_blocks).numpy()) labels.append(bool(rec["exploited"])) if (i + 1) % 40 == 0: logger.info(f" rollout {i+1}/{len(batch)}") model.zero_grad(set_to_none=True) U = np.stack(U_rows).astype(np.float64) # [n_roll, n_mod*r] the tracked gradient cloud labels = np.array(labels) logger.info(f"stored U cloud: {U.shape} ({U.nbytes/1e6:.1f} MB), {int(labels.sum())} exploited") # ── pos under each direction by RE-PROJECTING the stored cloud (no model re-run) ── pos_A = pos_from_U(U, V_A, L_A, W_A) pos_B = pos_from_U(U, V_B, L_B, W_B) # ── moment formula from the tracked cloud: mu_U, Sigma_U (v-independent) ── mu_U = U.mean(0) # [D] Uc = U - mu_U Sigma_U = (Uc.T @ Uc) / U.shape[0] # [D, D] uncentered-then-centered covariance def moment_stats(V, L, W): mean = (mu_U @ V - L) / W var = (V @ (Sigma_U @ V)) / (W * W) return mean, np.sqrt(max(var, 0.0)) mA_emp, sA_emp = pos_A.mean(), pos_A.std() mB_emp, sB_emp = pos_B.mean(), pos_B.std() mA_mom, sA_mom = moment_stats(V_A, L_A, W_A) mB_mom, sB_mom = moment_stats(V_B, L_B, W_B) # ── sanity table ── rows = [ ["mean pos | v_A", f"{mA_emp:+.4f}", f"{mA_mom:+.4f}", f"{abs(mA_emp-mA_mom):.2e}"], ["std pos | v_A", f"{sA_emp:.4f}", f"{sA_mom:.4f}", f"{abs(sA_emp-sA_mom):.2e}"], ["mean pos | v_B", f"{mB_emp:+.4f}", f"{mB_mom:+.4f}", f"{abs(mB_emp-mB_mom):.2e}"], ["std pos | v_B", f"{sB_emp:.4f}", f"{sB_mom:.4f}", f"{abs(sB_emp-sB_mom):.2e}"], ] print("\nSHOULD: moment-formula (from tracked mu_U, Sigma_U) == empirical, both v_A and v_B.") print(tabulate(rows, headers=["quantity", "empirical (direct)", "moment formula", "abs diff"], tablefmt="github")) max_diff = max(abs(mA_emp-mA_mom), abs(sA_emp-sA_mom), abs(mB_emp-mB_mom), abs(sB_emp-sB_mom)) ok = max_diff < 1e-5 print(f"\n{'PASS' if ok else 'FAIL'}: max |empirical - moment| = {max_diff:.2e} (refresh needs no model re-run)") logger.info(f"refresh shift: mean {mA_emp:+.3f}->{mB_emp:+.3f}, std {sA_emp:.3f}->{sB_emp:.3f}; " f"AUROC v_A={_auroc(pos_A.tolist(), labels.tolist()):.3f} v_B={_auroc(pos_B.tolist(), labels.tolist()):.3f}") # ── plot: pos under v_A (top) and v_B (bottom), mean +/- 2sd band recalibrates per direction ── pl.DataFrame({"pos_A": pos_A, "pos_B": pos_B, "exploited": labels.tolist()}).write_parquet( cfg.out_dir / "pinning_refresh.parquet") lo = min(pos_A.min(), pos_B.min()) - 0.1 hi = max(pos_A.max(), pos_B.max()) + 0.1 grid = np.linspace(lo, hi, 400) fig, axes = plt.subplots(2, 1, figsize=(8.6, 6.0), sharex=True) for ax, pos, mean, std, tag in [(axes[0], pos_A, mA_emp, sA_emp, "A (pairs 1st half)"), (axes[1], pos_B, mB_emp, sB_emp, "B (pairs 2nd half)")]: lo_b, hi_b = mean - 2 * std, mean + 2 * std ax.axvspan(lo_b, hi_b, color=MEANC, alpha=0.07, lw=0) ax.axvline(mean, color=MEANC, lw=1.8) for b in (lo_b, hi_b): ax.axvline(b, color=MEANC, lw=1.2, ls="--") peak = 0.0 for x, col in [(pos[~labels], SOLVE), (pos[labels], HACK)]: y = _kde(x, grid) ax.fill_between(grid, y, color=col, alpha=0.13, lw=0) ax.plot(grid, y, color=col, lw=1.9) peak = max(peak, y.max()) ax.set_ylim(0, peak * 1.15) ax.set_ylabel("density") ax.set_title(f"direction estimate {tag}: mean={mean:+.2f}, sd={std:.2f}, " f"AUROC={_auroc(pos.tolist(), labels.tolist()):.2f}", fontsize=9, loc="left") for s in ("top", "right"): ax.spines[s].set_visible(False) dist = [Line2D([0], [0], color=SOLVE, lw=2), Line2D([0], [0], color=HACK, lw=2), Line2D([0], [0], color=MEANC, lw=1.8), Patch(facecolor=MEANC, alpha=0.18, ls="--", edgecolor=MEANC)] axes[0].legend(dist, ["on-policy solve", "on-policy hack", "online mean", "mean +/- 2sd"], loc="upper left", fontsize=8, frameon=False) axes[1].set_xlabel("hacking direction (gradient cosine to v_grad) " + r"$\longrightarrow$") fig.suptitle("refresh tracking: same gradient cloud, two G_hack estimates " f"(cos(v_A,v_B)={cosAB:+.2f}) -- band recalibrates with no model re-run", fontsize=9.5) fig.tight_layout(rect=(0, 0, 1, 0.96)) fig.savefig(cfg.out_dir / "pinning_refresh.png", dpi=140) plt.close(fig) logger.info(f"wrote {cfg.out_dir}/pinning_refresh.png, pinning_refresh.parquet") return 0 if ok else 1 if __name__ == "__main__": raise SystemExit(main(tyro.cli(Cfg)))