mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 18:23:57 +08:00
979daf84fd
Gate calibration: route by live mean + route_std_mid/route_std_rout * std of the pooled cosine-to-v_grad, not a fixed quantile tail. Self-silences -- only the tail that genuinely exceeds the spread routes, so qmass tracks real separation instead of a forced fraction. The authored absolute band is mis-placed (live pos sits far below the synthetic-hack edge; even synthetic solve out-aligns on-policy hack). tablelog: auroc/rout/routE/keep/resid/qmass cols always shown (nan on vanilla) so arm tables line up. Diagnostics: scripts/diag_pinning.py (4-population calibration view, mean+/-2sd band) and scripts/diag_pinning_refresh.py (proves cosine stats recompute from a tracked v-independent gradient cloud on a v_grad refresh -- exact for k=1, sanity 2.5e-16). Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
228 lines
11 KiB
Python
228 lines
11 KiB
Python
"""Refresh-tracking prototype: when G_hack (v_grad) refreshes, can we recompute the
live cosine stats from a TRACKED gradient cloud, without re-running the model?

Claim (k=1): the gate score pos_i = (S_i - L)/W with S_i = sum_m <u_{i,m}, v_m>, where
u_{i,m} = g_{i,m}/|g_{i,m}| is the v-INDEPENDENT unit per-module gradient. Stacking
U_i=[u_{i,1};..], V=[v_1;..] gives S_i = U_i . V (linear in V). So the live cosine
distribution under ANY v is an affine push of the SAME gradient cloud:

    mean(pos) = (mu_U . V - L) / W
    var(pos)  = (V^T Sigma_U V) / W^2

mu_U, Sigma_U are tracked once (v-independent); L,W come from the pairs at refresh. So a
refresh needs no model re-run and no window flush -- just re-project the cloud onto v_new.

We simulate a refresh with two direction estimates v_A, v_B = unit-mean-diff of two DISJOINT
halves of the authored pairs (model fixed, only the direction changes), and verify:

  (1) reproject stored U onto v == direct pos (exact, validates indexing)
  (2) moment formula (mu_U, Sigma_U) == empirical mean/sd of pos, for BOTH v_A and v_B
      from the SAME tracked cloud (the actual claim).

Usage:

    uv run python scripts/diag_pinning_refresh.py

outputs (out/diag/): pinning_refresh.png, pinning_refresh.parquet, prints a sanity table.
"""
|
|
from __future__ import annotations
|
|
import json
|
|
import struct
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import torch
|
|
import tyro
|
|
import polars as pl
|
|
import matplotlib
|
|
matplotlib.use("Agg")
|
|
import matplotlib.pyplot as plt
|
|
from matplotlib.lines import Line2D
|
|
from matplotlib.patches import Patch
|
|
from loguru import logger
|
|
from tabulate import tabulate
|
|
from safetensors.torch import load_file
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
from vgrout.lora2r import wrap_model_with_lora2r
|
|
from vgrout.pairs import load_pairs
|
|
from vgrout.extract_vhack_grad import extract_v_hack, completion_nll
|
|
from vgrout.train import _build_v_grad, route_band_edges, _auroc
|
|
from diag_pinning import _kde, SOLVE, HACK, MEANC # same dir (scripts/ is sys.path[0] when run directly)
|
|
|
|
|
|
@dataclass
class Cfg:
    """Command-line configuration for the refresh-tracking diagnostic.

    Instances are built by ``tyro.cli(Cfg)`` in the script entry point and
    consumed by ``main``.
    """

    # Training-run directory; main() reads "<ckpt>.safetensors" and "rollouts.jsonl" from it.
    run_dir: Path = Path("out/runs/20260611T003538_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v3")
    # Checkpoint file stem inside run_dir (".safetensors" is appended by main()).
    ckpt: str = "first_hack"
    # Pair source handed unchanged to vgrout.pairs.load_pairs.
    pairs: Path = Path("data/pairs/hack_pairs.md#all-in-one")
    # Inclusive lower bound on the rollout "step" field kept for the live batch.
    step_lo: int = 2
    # Inclusive upper bound on the rollout "step" field kept for the live batch.
    step_hi: int = 9
    # Cap on the number of rollouts scored (applied after the step/text filter).
    max_rollouts: int = 240
    # Directory that receives pinning_refresh.png and pinning_refresh.parquet.
    out_dir: Path = Path("out/diag")
|
|
|
|
|
|
def _ckpt_meta(path: Path) -> dict:
    """Return the ``__metadata__`` mapping stored in a checkpoint file's header.

    The file is read as: an 8-byte little-endian unsigned length, followed by
    that many bytes of JSON. Only the header is read; the rest of the file is
    never touched.

    Args:
        path: Location of the checkpoint file.

    Returns:
        The header's ``"__metadata__"`` entry, or ``{}`` when the header has none.
    """
    with open(path, "rb") as fh:
        # First 8 bytes: size of the JSON header that follows ("<Q" = LE uint64).
        (header_len,) = struct.unpack("<Q", fh.read(8))
        header = json.loads(fh.read(header_len))
    return header.get("__metadata__", {})
|
|
|
|
|
|
def stacked_v_and_band(raw_grads: dict, idx, wrappers, names, r, device):
    """Direction estimate + band from a SUBSET of pairs (simulated refresh).

    Every entry of ``raw_grads`` is indexed with ``idx``; the subset is passed to
    ``_build_v_grad`` (k=1) and ``route_band_edges``. Modules are then walked in
    ``names`` order: a module enters iff its band width ``upper - lower`` is > 0,
    otherwise it contributes a zero block of length ``r`` and nothing to L or W
    (matches pooled_pos, per the original author's note).

    Args:
        raw_grads: mapping of key -> per-pair gradients, indexable by ``idx``.
        idx: indices of the pairs forming this direction estimate.
        wrappers: forwarded to ``_build_v_grad``.
        names: module names; fixes the stacking order of the returned vector.
        r: length of the zero block used for excluded modules.
        device: forwarded to ``_build_v_grad`` and ``route_band_edges``.

    Returns:
        ``(V, L, W, v_grad, band)``: V is the concatenated per-module direction as
        a numpy array, L the summed lower edges and W the summed widths of the
        included modules (both floats), followed by the raw ``v_grad`` and
        ``band`` mappings.
    """
    subset = {key: grads[idx] for key, grads in raw_grads.items()}
    v_grad = _build_v_grad(subset, wrappers, 1, device)  # {name: [1, r]} unit (author's note)
    band = route_band_edges(subset, v_grad, device)  # {name: (lower, upper)}

    blocks = []
    lower_sum, width_sum = 0.0, 0.0
    for name in names:
        lower, upper = band[name]
        width = upper - lower
        if not width > 0:
            # Excluded module -> contributes nothing to the projection.
            blocks.append(torch.zeros(r))
            continue
        blocks.append(v_grad[name][0].float().cpu())  # [r]
        lower_sum += lower
        width_sum += width
    return torch.cat(blocks).numpy(), float(lower_sum), float(width_sum), v_grad, band
|
|
|
|
|
|
def pos_from_U(U: np.ndarray, V: np.ndarray, L: float, W: float) -> np.ndarray:
    """Re-project the stored cloud onto direction V (no model): pos = (U . V - L) / W.

    Args:
        U: stored gradient cloud, one row per rollout.
        V: stacked direction vector to project onto.
        L: offset subtracted from every projection.
        W: scale the shifted projection is divided by.

    Returns:
        One score per row of ``U``.
    """
    projected = np.matmul(U, V)
    return (projected - L) / W
|
|
|
|
|
|
def _collect_unit_grad_cloud(model, tok, wrappers, names, r, batch, device):
    """Score each rollout once and stack its unit per-module gradients.

    For every record the completion NLL is back-propagated and, per module in
    ``names`` order, the gate-probe gradient is summed over all leading axes,
    cut to its first ``r`` entries and L2-normalised. Rollouts whose loss is
    not finite are skipped (and counted in a warning).

    Returns:
        ``(U, labels)``: U is float64 with one row per scored rollout; labels is
        a boolean array built from each record's ``"exploited"`` field.

    Raises:
        ValueError: if no rollout could be scored.
    """
    U_rows, labels = [], []
    n_skipped = 0
    for i, rec in enumerate(batch):
        model.zero_grad(set_to_none=True)
        loss = completion_nll(model, tok, rec["prompt"], rec["text"], device)
        if not torch.isfinite(loss):
            n_skipped += 1
            continue
        loss.backward()
        u_blocks = []
        for name in names:
            g = wrappers[name]["layer"]._lora2r_gate.grad
            # Reduce every axis but the last, keep the first r entries
            # (the "deployed block" per the original author's note).
            g_b = g.sum(dim=tuple(range(g.dim() - 1)))[:r].float()
            # Unit vector: carries no information about any direction v.
            u_blocks.append((g_b / g_b.norm().clamp_min(1e-12)).cpu())
        U_rows.append(torch.cat(u_blocks).numpy())
        labels.append(bool(rec["exploited"]))
        if (i + 1) % 40 == 0:
            logger.info(f" rollout {i+1}/{len(batch)}")
    model.zero_grad(set_to_none=True)
    if n_skipped:
        logger.warning(f"skipped {n_skipped} rollouts with non-finite loss")
    if not U_rows:
        # np.stack([]) would raise a ValueError anyway; say why instead.
        raise ValueError(
            f"no rollouts could be scored ({len(batch)} selected, {n_skipped} non-finite); "
            "check run_dir/rollouts.jsonl and the step_lo/step_hi window")
    return np.stack(U_rows).astype(np.float64), np.array(labels)


def _plot_refresh(pos_A, pos_B, stats_A, stats_B, labels, cosAB, png_path):
    """Draw pos under v_A (top) and v_B (bottom) with a mean +/- 2sd band each.

    ``stats_A`` / ``stats_B`` are ``(mean, std)`` pairs for the matching panel.
    The figure is saved to ``png_path`` and closed.
    """
    lo = min(pos_A.min(), pos_B.min()) - 0.1
    hi = max(pos_A.max(), pos_B.max()) + 0.1
    grid = np.linspace(lo, hi, 400)
    fig, axes = plt.subplots(2, 1, figsize=(8.6, 6.0), sharex=True)
    panels = [(axes[0], pos_A, *stats_A, "A (pairs 1st half)"),
              (axes[1], pos_B, *stats_B, "B (pairs 2nd half)")]
    for ax, pos, mean, std, tag in panels:
        lo_b, hi_b = mean - 2 * std, mean + 2 * std
        ax.axvspan(lo_b, hi_b, color=MEANC, alpha=0.07, lw=0)
        ax.axvline(mean, color=MEANC, lw=1.8)
        for b in (lo_b, hi_b):
            ax.axvline(b, color=MEANC, lw=1.2, ls="--")
        peak = 0.0
        for x, col in [(pos[~labels], SOLVE), (pos[labels], HACK)]:
            y = _kde(x, grid)
            ax.fill_between(grid, y, color=col, alpha=0.13, lw=0)
            ax.plot(grid, y, color=col, lw=1.9)
            peak = max(peak, y.max())
        ax.set_ylim(0, peak * 1.15)
        ax.set_ylabel("density")
        ax.set_title(f"direction estimate {tag}: mean={mean:+.2f}, sd={std:.2f}, "
                     f"AUROC={_auroc(pos.tolist(), labels.tolist()):.2f}", fontsize=9, loc="left")
        for s in ("top", "right"):
            ax.spines[s].set_visible(False)
    dist = [Line2D([0], [0], color=SOLVE, lw=2), Line2D([0], [0], color=HACK, lw=2),
            Line2D([0], [0], color=MEANC, lw=1.8), Patch(facecolor=MEANC, alpha=0.18, ls="--", edgecolor=MEANC)]
    axes[0].legend(dist, ["on-policy solve", "on-policy hack", "online mean", "mean +/- 2sd"],
                   loc="upper left", fontsize=8, frameon=False)
    axes[1].set_xlabel("hacking direction (gradient cosine to v_grad) " + r"$\longrightarrow$")
    fig.suptitle("refresh tracking: same gradient cloud, two G_hack estimates "
                 f"(cos(v_A,v_B)={cosAB:+.2f}) -- band recalibrates with no model re-run", fontsize=9.5)
    fig.tight_layout(rect=(0, 0, 1, 0.96))
    fig.savefig(png_path, dpi=140)
    plt.close(fig)


def main(cfg: Cfg) -> int:
    """Run the refresh-tracking check end to end.

    Steps: load the checkpointed LoRA weights into a freshly wrapped model,
    build two direction estimates from disjoint halves of the pairs, score the
    live rollouts once into a unit-gradient cloud U, then compare the empirical
    mean/std of pos under each direction against the moment formula computed
    from ``mu_U`` / ``Sigma_U``. Writes ``pinning_refresh.parquet`` and
    ``pinning_refresh.png`` into ``cfg.out_dir`` and prints a sanity table.

    Args:
        cfg: run directory, checkpoint name, pair source, step window and output dir.

    Returns:
        0 when max |empirical - moment| < 1e-5, else 1 (used as the exit code).

    Raises:
        ValueError: if fewer than two pairs are available, or no rollout can be scored.
    """
    cfg.out_dir.mkdir(parents=True, exist_ok=True)
    device = torch.device("cuda")
    ckpt_path = cfg.run_dir / f"{cfg.ckpt}.safetensors"
    meta = _ckpt_meta(ckpt_path)
    run_cfg = json.loads(meta.get("cfg", "{}"))
    model_name = run_cfg.get("model", "Qwen/Qwen3-4B")
    r = run_cfg.get("lora_r", 32)
    init_seed = run_cfg.get("lora_init_seed", 0)
    logger.info(f"ckpt {ckpt_path.name} step={meta.get('step')} model={model_name} r={r}")

    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2").to(device)
    model.config.use_cache = False
    wrappers = wrap_model_with_lora2r(model, r=r, init_seed=init_seed, grad_probe=True)
    names = sorted(wrappers)
    sd = load_file(str(ckpt_path))
    for nm in names:
        wrappers[nm]["A"].data.copy_(sd[f"A/{nm}"].to(device, torch.float32))
        wrappers[nm]["B"].data.copy_(sd[f"B/{nm}"].to(device, torch.float32))

    pairs = load_pairs(cfg.pairs)
    model.eval()
    _, _, raw_grads, _ = extract_v_hack(model, tok, wrappers, pairs,
                                        top_k=1, tau_axis=0.0, n_heldout=2, device=device)
    n_pairs = raw_grads[f"hack/{names[0]}"].shape[0]
    if n_pairs < 2:
        # One half would be empty, so there would be no second direction to refresh to.
        raise ValueError(f"need at least 2 pairs to form two disjoint halves, got {n_pairs}")
    half = n_pairs // 2
    idx_A, idx_B = list(range(half)), list(range(half, n_pairs))  # disjoint pair halves = two v estimates
    V_A, L_A, W_A, _, _ = stacked_v_and_band(raw_grads, idx_A, wrappers, names, r, device)
    V_B, L_B, W_B, _, _ = stacked_v_and_band(raw_grads, idx_B, wrappers, names, r, device)
    V_A, V_B = V_A.astype(np.float64), V_B.astype(np.float64)  # float64 so the sanity gate reflects math, not rounding
    cosAB = float(np.dot(V_A, V_B) / (np.linalg.norm(V_A) * np.linalg.norm(V_B) + 1e-12))
    logger.info(f"two direction estimates from pair halves ({half} / {n_pairs-half}); "
                f"cos(V_A,V_B)={cosAB:+.3f} (a real refresh shift)")

    # ── score live batch ONCE: store U_i = stacked unit per-module gradients (v-INDEPENDENT) ──
    # Explicit UTF-8 (rollout text is model output) and blank lines are skipped
    # so a stray empty line does not abort the whole run.
    rollout_text = (cfg.run_dir / "rollouts.jsonl").read_text(encoding="utf-8")
    recs = [json.loads(line) for line in rollout_text.splitlines() if line.strip()]
    batch = [x for x in recs if cfg.step_lo <= x["step"] <= cfg.step_hi and x["text"].strip()][:cfg.max_rollouts]
    logger.info(f"live batch: {len(batch)} rollouts (steps {cfg.step_lo}-{cfg.step_hi})")
    U, labels = _collect_unit_grad_cloud(model, tok, wrappers, names, r, batch, device)
    logger.info(f"stored U cloud: {U.shape} ({U.nbytes/1e6:.1f} MB), {int(labels.sum())} exploited")

    # ── pos under each direction by RE-PROJECTING the stored cloud (no model re-run) ──
    pos_A = pos_from_U(U, V_A, L_A, W_A)
    pos_B = pos_from_U(U, V_B, L_B, W_B)

    # ── moment formula from the tracked cloud: mu_U, Sigma_U (v-independent) ──
    mu_U = U.mean(0)  # [D]
    Uc = U - mu_U
    # Population covariance (divide by n), so it lines up with ndarray.std()'s default.
    Sigma_U = (Uc.T @ Uc) / U.shape[0]  # [D, D]

    def moment_stats(V, L, W):
        mean = (mu_U @ V - L) / W
        var = (V @ (Sigma_U @ V)) / (W * W)
        # Clamp tiny negative round-off before the square root.
        return mean, np.sqrt(max(var, 0.0))

    mA_emp, sA_emp = pos_A.mean(), pos_A.std()
    mB_emp, sB_emp = pos_B.mean(), pos_B.std()
    mA_mom, sA_mom = moment_stats(V_A, L_A, W_A)
    mB_mom, sB_mom = moment_stats(V_B, L_B, W_B)

    # ── sanity table ──
    rows = [
        ["mean pos | v_A", f"{mA_emp:+.4f}", f"{mA_mom:+.4f}", f"{abs(mA_emp-mA_mom):.2e}"],
        ["std pos | v_A", f"{sA_emp:.4f}", f"{sA_mom:.4f}", f"{abs(sA_emp-sA_mom):.2e}"],
        ["mean pos | v_B", f"{mB_emp:+.4f}", f"{mB_mom:+.4f}", f"{abs(mB_emp-mB_mom):.2e}"],
        ["std pos | v_B", f"{sB_emp:.4f}", f"{sB_mom:.4f}", f"{abs(sB_emp-sB_mom):.2e}"],
    ]
    print("\nSHOULD: moment-formula (from tracked mu_U, Sigma_U) == empirical, both v_A and v_B.")
    print(tabulate(rows, headers=["quantity", "empirical (direct)", "moment formula", "abs diff"],
                   tablefmt="github"))
    max_diff = max(abs(mA_emp-mA_mom), abs(sA_emp-sA_mom), abs(mB_emp-mB_mom), abs(sB_emp-sB_mom))
    ok = max_diff < 1e-5
    print(f"\n{'PASS' if ok else 'FAIL'}: max |empirical - moment| = {max_diff:.2e} (refresh needs no model re-run)")
    logger.info(f"refresh shift: mean {mA_emp:+.3f}->{mB_emp:+.3f}, std {sA_emp:.3f}->{sB_emp:.3f}; "
                f"AUROC v_A={_auroc(pos_A.tolist(), labels.tolist()):.3f} v_B={_auroc(pos_B.tolist(), labels.tolist()):.3f}")

    # ── persist scores, then plot: mean +/- 2sd band recalibrates per direction ──
    pl.DataFrame({"pos_A": pos_A, "pos_B": pos_B, "exploited": labels.tolist()}).write_parquet(
        cfg.out_dir / "pinning_refresh.parquet")
    _plot_refresh(pos_A, pos_B, (mA_emp, sA_emp), (mB_emp, sB_emp), labels, cosAB,
                  cfg.out_dir / "pinning_refresh.png")
    logger.info(f"wrote {cfg.out_dir}/pinning_refresh.png, pinning_refresh.parquet")
    return 0 if ok else 1
|
|
|
|
|
|
if __name__ == "__main__":
    # Build the config from command-line flags, run the diagnostic, and hand
    # main()'s return value (0 = PASS, 1 = FAIL) to the shell as the exit status.
    _cli_cfg = tyro.cli(Cfg)
    raise SystemExit(main(_cli_cfg))
|