mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 16:45:42 +08:00
103d0acc2c
antipasto.py (PiSSA/lora_frozen_b/old-lora2r wrappers) is dead in the live path -- train.py/extract use lora2r.py, nothing imports antipasto. Move the 7 scripts that import it or the erase-era proj fns (rescore_deploy, eval_checkpoint_curve, verify_vhack_heldout, probe_distill, diag_cosine_dist, diag_pairs_compare, tt_erase_bench) to scripts/attic/ -- they need lora2r rewrites if resurrected. Live imports verified clean. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
316 lines
16 KiB
Python
316 lines
16 KiB
Python
"""Pinning diagnostic: how well does the contrastive hack direction PREDICT a live
|
|
hack rollout, in GRADIENT vs ACTIVATION space, by cosine vs projection vs magnitude?
|
|
|
|
One checkpoint (default: job 9 per-token `first_hack` = step 7). We build the
|
|
contrastive hack direction from the authored pairs in two spaces:
|
|
|
|
GRAD: v_grad = unit(mean_pairs(g_hack - g_clean)) on delta_S (g = per-rollout NLL grad)
|
|
ACT: As_hack = unit(mean_pairs(As_hack - As_clean)) (As = Vh@x mean over comp tokens)
|
|
|
|
Then for live rollouts (oracle `exploited`, offline plot-only label) we score each
|
|
rollout against the direction and ask how well the score separates hack from clean.
|
|
Three score families (GLOBAL = concat over modules; v is unit per module):
|
|
cosine = <g, v> / (|g| |v|) -- direction only (magnitude-blind)
|
|
projection = <g, v> / |v| = |g| cos -- direction x gradient magnitude
|
|
magnitude = |g| -- magnitude only (null check)
|
|
each in {grad, act} x {all modules, noise-floor-filtered (drop bottom-25% singular value)}.
|
|
|
|
Separability metric (we want high PRECISION / confident tail, not coverage):
|
|
AUROC + precision@10 / @20 of "score predicts exploited".
|
|
|
|
uv run python scripts/diag_cosine_dist.py # 4B, free GPU, ~4-6 min
|
|
outputs (out/diag/): cosine_grad.png, cosine_act.png (histograms),
|
|
cosine_dist.parquet (histogram long-form), live_scores.parquet (per-rollout scores),
|
|
separability.csv (methods x AUROC/precision). Replay without a GPU: nbs/cosine_dist.ipynb.
|
|
"""
|
|
from __future__ import annotations
|
|
import json
|
|
import struct
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import tyro
|
|
import polars as pl
|
|
import matplotlib
|
|
matplotlib.use("Agg")
|
|
import matplotlib.pyplot as plt
|
|
from loguru import logger
|
|
from tabulate import tabulate
|
|
from safetensors.torch import load_file
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
from vgrout.antipasto import wrap_model_with_antipasto
|
|
from vgrout.extract_vhack_grad import extract_v_hack, completion_nll
|
|
from vgrout.pairs_from_pool import load_pairs_json
|
|
from vgrout.train import CACHE_ROOT
|
|
|
|
# Directory holding the contrastive pair-set JSON files that main() loads.
_PS = Path("out/pairsets")
|
|
|
|
|
|
@dataclass
class Cfg:
    """CLI configuration for the cosine/projection separability diagnostic.

    The field comments double as documentation for each option; defaults point
    at one specific training run and its `first_hack` checkpoint.
    """

    # Training run directory; main() reads `{ckpt}.safetensors`,
    # `{ckpt}_hack.safetensors` and `rollouts.jsonl` from it.
    run_dir: Path = Path("out/runs/20260607T134234_fast_routingV_seed43_dir6_routeV_pertoken_s43")
    # Checkpoint file stem inside run_dir.
    ckpt: str = "first_hack"
    # Inclusive range of training steps whose live rollouts are scored.
    step_lo: int = 5
    step_hi: int = 9
    # Upper bound on the number of live rollouts taken from that step range.
    max_rollouts: int = 140
    bins: int = 15  # histogram bins (wider = less spiky)
    # Pair set used to build the direction: all | think | funcname | concept
    # (these are the only keys main() accepts).
    pairs: str = "all"
    # Directory that receives the PNG / parquet / CSV outputs.
    out_dir: Path = Path("out/diag")
|
|
|
|
|
|
def _auroc(score, label) -> float:
    """Return P(score of a positive > score of a negative), counting ties as 0.5.

    Compares every positive against every negative, so the cost is quadratic in
    the number of rollouts (fine for n~140). Returns NaN when either class is
    absent, because the probability is then undefined.
    """
    positives, negatives = [], []
    for value, is_pos in zip(score, label):
        (positives if is_pos else negatives).append(value)
    if not (positives and negatives):
        return float("nan")
    wins = 0.0
    for hi in positives:
        for lo in negatives:
            if hi > lo:
                wins += 1.0
            elif hi == lo:
                wins += 0.5
    return wins / (len(positives) * len(negatives))
|
|
|
|
|
|
def _prec_at_k(score, label, k) -> float:
    """Return the fraction of positives among the `k` highest-scoring items.

    Args:
        score: per-item scores; higher means "more likely positive".
        label: per-item truthy/falsy ground truth, aligned with `score`.
        k: how many top-ranked items to inspect.

    Returns:
        Precision over the items actually inspected. When fewer than `k` items
        exist the denominator is the number available (dividing by `k` would
        understate precision); NaN when nothing is inspected (empty input or
        k <= 0).
    """
    if k <= 0:
        return float("nan")
    top = sorted(range(len(score)), key=lambda i: score[i], reverse=True)[:k]
    if not top:
        return float("nan")
    return sum(label[i] for i in top) / len(top)
|
|
|
|
|
|
def _concat(d: dict, names: list) -> torch.Tensor:
    """Flatten `d[n]` for each `n` in `names` (in that order) into one 1-D float32 CPU tensor."""
    pieces = []
    for key in names:
        flat = d[key].flatten()
        pieces.append(flat.float().cpu())
    return torch.cat(pieces)
|
|
|
|
|
|
def main(cfg: Cfg) -> int:
    """Score live rollouts against the contrastive hack direction and report separability.

    Loads the checkpoint named by `cfg`, builds the hack direction from the
    selected pair set in gradient and activation space, scores each live rollout
    from `rollouts.jsonl`, and writes histograms, per-rollout scores and a
    separability table to `cfg.out_dir`.

    Args:
        cfg: run directory, checkpoint stem, step range, pair set and output dir.

    Returns:
        0 on success; 1 when no live rollout could be scored (empty step range
        or every loss non-finite), in which case nothing is written.

    Raises:
        ValueError: if `cfg.pairs` is not a known pair set, or the pair set is
            too small to leave any training pair after the held-out split.
    """
    # Pair-set key -> file under _PS. Only the selected file is read, so a missing
    # unrelated pair set cannot abort the run.
    pair_files = {
        "all": "pairs_authored.json",
        "think": "pairs_intent_think.json",
        "funcname": "pairs_intent_funcname.json",
        "concept": "pairs_intent_concept.json",
    }
    if cfg.pairs not in pair_files:
        # Fail before the (slow) model load rather than with a bare KeyError after it.
        raise ValueError(f"unknown pairs={cfg.pairs!r}; expected one of {sorted(pair_files)}")
    # Number of trailing pairs held out by extract_v_hack; the ACT direction below
    # must drop the same pairs so both spaces use the same training pairs.
    n_heldout = 2

    def unit(d: torch.Tensor) -> torch.Tensor:
        """Scale `d` to unit L2 norm (norm clamped to avoid division by zero)."""
        return d / d.norm().clamp_min(1e-12)

    device = torch.device("cuda")
    kept_path = cfg.run_dir / f"{cfg.ckpt}.safetensors"
    hack_path = cfg.run_dir / f"{cfg.ckpt}_hack.safetensors"
    with open(kept_path, "rb") as f:
        # safetensors layout: 8-byte little-endian header length, then a JSON header.
        header_len = struct.unpack("<Q", f.read(8))[0]
        meta = json.loads(f.read(header_len)).get("__metadata__", {})
    model_name = meta.get("model", "Qwen/Qwen3-4B")
    logger.info(f"ckpt {kept_path.name} step={meta.get('step')} hack_rate={meta.get('hack_rate')} model={model_name}")

    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2").to(device)
    model.config.use_cache = False
    wrappers = wrap_model_with_antipasto(model, model_name, CACHE_ROOT, device, grad_probe=False)
    names = sorted(wrappers)
    kept, hack = load_file(str(kept_path)), load_file(str(hack_path))
    for nm in names:
        wrappers[nm]["delta_S"].data.copy_(kept[nm].to(device))
        wrappers[nm]["delta_S_hack"].data.copy_(hack[nm].to(device))
    logger.info(f"loaded adapter into {len(names)} modules")

    # pair selection: 'all' = the authored pairs; 'think' / 'funcname' / 'concept' = the
    # intent-specific pair sets -- tests whether mechanism-match lifts AUROC.
    pairs = load_pairs_json(_PS / pair_files[cfg.pairs])
    logger.info(f"pairs={cfg.pairs} -> {len(pairs)} pairs")
    if len(pairs) <= n_heldout:
        raise ValueError(f"pairs={cfg.pairs!r} has {len(pairs)} pairs; need more than {n_heldout}")

    # ── GRAD direction + per-module singular value (for noise floor) ──
    model.eval()
    _, v_sv, raw_grads, _ = extract_v_hack(model, tok, wrappers, pairs,
                                           top_k=1, tau_axis=0.0, n_heldout=n_heldout, device=device)
    v_grad = {nm: unit((raw_grads[f"hack/{nm}"] - raw_grads[f"clean/{nm}"]).mean(0))
              for nm in names}  # cpu unit
    sv0 = torch.tensor([v_sv[nm][0].item() for nm in names])  # [n_mod] grad contrastive |D_m| (=S0)

    # filter levels keep the strongest-separating modules by a per-space weight w (=|D_m|):
    # all/keep75/top25/top15/top05 = keep modules with w >= q{0, .25, .75, .85, .95}.
    def make_masks(w):
        return {"all": torch.ones(len(names), dtype=bool),
                "keep75": w >= w.quantile(0.25),
                "top25": w >= w.quantile(0.75),
                "top15": w >= w.quantile(0.85),
                "top05": w >= w.quantile(0.95)}

    # ── activation capture hooks (after grad extract) ──
    As_cap: dict[str, torch.Tensor] = {}
    state = {"plen": 0}  # prompt token count of the sequence currently being forwarded

    def make_hook(nm):
        Vh = wrappers[nm]["layer"]._antipasto_Vh

        def hook(_l, inp, _o):
            a = F.linear(inp[0], Vh)
            # mean over positions from the last prompt token onward, first batch row only
            As_cap[nm] = a[0, state["plen"] - 1:, :].mean(0).detach().float().cpu()
        return hook

    handles = [wrappers[nm]["layer"].register_forward_hook(make_hook(nm)) for nm in names]

    def grab_As(prompt, completion):
        state["plen"] = tok(prompt, return_tensors="pt").input_ids.shape[1]
        ids = tok(prompt + completion, return_tensors="pt").input_ids.to(device)
        with torch.no_grad():
            model(ids)
        return {nm: As_cap[nm].clone() for nm in names}

    # ── ACT direction from the same train pairs ──
    train_pairs = pairs[:-n_heldout]
    As_h = {nm: [] for nm in names}
    As_c = {nm: [] for nm in names}
    for p in train_pairs:
        ah, ac = grab_As(p.prompt, p.hack), grab_As(p.prompt, p.clean)
        for nm in names:
            As_h[nm].append(ah[nm])
            As_c[nm].append(ac[nm])
    As_D = {nm: (torch.stack(As_h[nm]) - torch.stack(As_c[nm])).mean(0) for nm in names}  # act contrastive
    As_dir = {nm: unit(As_D[nm]) for nm in names}  # cpu unit
    act_w = torch.tensor([As_D[nm].norm().item() for nm in names])  # act |D_m|

    # directions are ragged across modules -> keep them per-module and reduce with helpers
    def per_module_dot_norm(g: dict, vdir: dict):
        """returns dot[n_mod], gnorm[n_mod] (cpu float), v assumed unit per module."""
        dot = torch.tensor([(g[nm].flatten().float().cpu() @ vdir[nm].flatten().float()).item() for nm in names])
        gn = torch.tensor([g[nm].flatten().float().cpu().norm().item() for nm in names])
        return dot, gn

    # ── live rollouts: grad (backward) + act (same forward) per-module dot/norm ──
    rollout_lines = (cfg.run_dir / "rollouts.jsonl").read_text().splitlines()
    # skip blank lines (e.g. a trailing newline block) instead of crashing in json.loads
    recs = [json.loads(line) for line in rollout_lines if line.strip()]
    batch = [r for r in recs if cfg.step_lo <= r["step"] <= cfg.step_hi and r["text"].strip()][:cfg.max_rollouts]
    logger.info(f"live batch: {len(batch)} rollouts (steps {cfg.step_lo}-{cfg.step_hi})")
    G_dot, G_gn, A_dot, A_an, labels = [], [], [], [], []
    Gvec, Avec = [], []  # concat per-rollout g / As (fp16) for the ideal-direction ceiling
    for i, r in enumerate(batch):
        state["plen"] = tok(r["prompt"], return_tensors="pt").input_ids.shape[1]
        model.zero_grad(set_to_none=True)
        loss = completion_nll(model, tok, r["prompt"], r["text"], device)  # fires act hooks
        if not torch.isfinite(loss):
            continue
        As_sample = {nm: As_cap[nm].clone() for nm in names}
        loss.backward()
        gb = {nm: wrappers[nm]["delta_S"].grad for nm in names}
        gd, gn = per_module_dot_norm(gb, v_grad)
        ad, an = per_module_dot_norm(As_sample, As_dir)
        G_dot.append(gd)
        G_gn.append(gn)
        A_dot.append(ad)
        A_an.append(an)
        Gvec.append(_concat(gb, names).half())
        Avec.append(_concat(As_sample, names).half())
        labels.append(bool(r["exploited"]))
        if (i + 1) % 25 == 0:
            logger.info(f"  rollout {i+1}/{len(batch)}")
    for h in handles:
        h.remove()
    model.zero_grad(set_to_none=True)
    if not labels:
        # torch.stack([]) and the base-rate division below would both fail.
        logger.error(f"no scorable live rollouts in steps {cfg.step_lo}-{cfg.step_hi}; nothing written")
        return 1
    G_dot, G_gn = torch.stack(G_dot), torch.stack(G_gn)  # [n_roll, n_mod]
    A_dot, A_an = torch.stack(A_dot), torch.stack(A_an)
    n_pos = sum(labels)
    logger.info(f"live: {len(labels)} rollouts, {n_pos} exploited ({n_pos/len(labels):.2f} base rate)")

    # ── global scores per rollout: cosine / projection / magnitude / vote, per filter ──
    # vote = module-weighted average of per-module cosines, weight w=|D_m| (soft version of
    # the hard module filter: down-weight weak modules instead of cutting them).
    def scores(dot, gn, mask, w):
        dm, gm = dot[:, mask], gn[:, mask]
        d = dm.sum(1)                        # sum_m <g,v>
        gmag = gm.pow(2).sum(1).sqrt()       # |concat g|
        nmod = int(mask.sum())
        # v is unit per module, so |concat v| = sqrt(n kept modules)
        cos = d / (gmag.clamp_min(1e-12) * (nmod ** 0.5))
        proj = d / (nmod ** 0.5)
        cos_m = dm / gm.clamp_min(1e-12)     # per-module cosine [n_roll, n_kept]
        wm = w[mask]
        vote = (cos_m * wm).sum(1) / wm.sum().clamp_min(1e-12)
        return {"cosine": cos, "projection": proj, "magnitude": gmag, "vote": vote}

    # ── ceiling: ideal direction (mu_hack - mu_clean) fit on live rollouts, 2-fold CV ──
    # The oracle-labelled best linear direction -- diagnostic upper bound only (never trains).
    # If ceiling >> our pair-direction AUROC, the pairs are the problem (on-distribution pairs
    # would help); if ceiling is also ~0.7, the distributions fundamentally overlap.
    def ideal_auroc(vecs):
        X = torch.stack(vecs).float()        # [n, D]
        y = torch.tensor(labels)
        sc = torch.zeros(len(y))
        for fold in (0, 1):                  # 2-fold CV (fit on ~half, score other)
            te = (torch.arange(len(y)) % 2 == fold)
            tr = ~te
            vd = unit(X[tr][y[tr]].mean(0) - X[tr][~y[tr]].mean(0))
            sc[te] = (X[te] @ vd) / X[te].norm(dim=1).clamp_min(1e-12)
        insample = unit(X[y].mean(0) - X[~y].mean(0))
        sc_in = (X @ insample) / X.norm(dim=1).clamp_min(1e-12)
        return _auroc(sc.tolist(), labels), _auroc(sc_in.tolist(), labels)

    g_cv, g_in = ideal_auroc(Gvec)
    a_cv, a_in = ideal_auroc(Avec)
    logger.info(f"IDEAL-direction ceiling (oracle, diagnostic): "
                f"grad AUROC cv={g_cv:.3f} in-sample={g_in:.3f} | act cv={a_cv:.3f} in-sample={a_in:.3f}")

    weights = {"grad": sv0, "act": act_w}
    score_cols = {}
    for space, (dot, gn) in {"grad": (G_dot, G_gn), "act": (A_dot, A_an)}.items():
        w = weights[space]
        masks = make_masks(w)
        logger.info(f"{space} filter sizes: " + ", ".join(f"{k}={int(m.sum())}" for k, m in masks.items()))
        for filt, mask in masks.items():
            for fam, v in scores(dot, gn, mask, w).items():
                score_cols[f"{space}.{fam}.{filt}"] = v.tolist()

    # ── separability table ──
    sep_rows = []
    for name, s in score_cols.items():
        space, fam, filt = name.split(".")
        sep_rows.append(dict(space=space, score=fam, filter=filt,
                             AUROC=round(_auroc(s, labels), 3),
                             p_at_10=round(_prec_at_k(s, labels, 10), 3),
                             p_at_20=round(_prec_at_k(s, labels, 20), 3)))
    sep = pl.DataFrame(sep_rows).sort("AUROC", descending=True)

    # ── persist: histogram long-form (global cosine) + per-rollout scores + separability ──
    cfg.out_dir.mkdir(parents=True, exist_ok=True)

    # histogram populations need the PAIR cosines too (built from raw_grads / As_h,As_c)
    def pair_global_cos(per_pair_dot_src, vdir, n):
        out = []
        for i in range(n):
            g = {nm: per_pair_dot_src(nm, i) for nm in names}
            dot, gn = per_module_dot_norm(g, vdir)
            out.append((dot.sum() / (gn.pow(2).sum().sqrt().clamp_min(1e-12) * (len(names) ** 0.5))).item())
        return out

    npg = raw_grads[f"hack/{names[0]}"].shape[0]
    hist = []
    for pop, src in [("pair_hack", lambda nm, i: raw_grads[f"hack/{nm}"][i]),
                     ("pair_clean", lambda nm, i: raw_grads[f"clean/{nm}"][i])]:
        for c in pair_global_cos(src, v_grad, npg):
            hist.append(dict(space="grad", pop=pop, cos=c))
    for pop, src in [("pair_hack", lambda nm, i: As_h[nm][i]),
                     ("pair_clean", lambda nm, i: As_c[nm][i])]:
        for c in pair_global_cos(src, As_dir, len(train_pairs)):
            hist.append(dict(space="act", pop=pop, cos=c))
    for sp, col in [("grad", "grad.cosine.all"), ("act", "act.cosine.all")]:
        for c, y in zip(score_cols[col], labels):
            hist.append(dict(space=sp, pop="live_hack" if y else "live_clean", cos=c))
    pl.DataFrame(hist).write_parquet(cfg.out_dir / "cosine_dist.parquet")
    pl.DataFrame({**score_cols, "exploited": labels}).write_parquet(cfg.out_dir / "live_scores.parquet")
    sep.write_csv(cfg.out_dir / "separability.csv")

    logger.info("\n=== separability (AUROC of score -> exploited; >0.5 = predictive) ===\n"
                + tabulate(sep.to_pandas(), headers="keys", tablefmt="github", showindex=False))

    # ── histograms (cosine, both spaces) ──
    hdf = pl.DataFrame(hist)
    colors = {"pair_clean": "tab:blue", "pair_hack": "tab:red",
              "live_clean": "tab:cyan", "live_hack": "tab:orange"}
    for space in ("grad", "act"):
        plt.figure(figsize=(9, 5))
        for pop, c in colors.items():
            v = hdf.filter((pl.col("space") == space) & (pl.col("pop") == pop))["cos"].to_numpy()
            if len(v):
                plt.hist(v, bins=cfg.bins, density=True, histtype="step", lw=2, color=c,
                         label=f"{pop} (n={len(v)}, p50={float(pl.Series(v).median()):+.2f})")
        au = sep.filter((pl.col("space") == space) & (pl.col("score") == "cosine") & (pl.col("filter") == "all"))["AUROC"][0]
        plt.xlabel(f"global cosine to hack direction ({space} space)")
        plt.ylabel("density")
        plt.title(f"{space}-space: {cfg.ckpt} (step {meta.get('step')}), live AUROC(cos)={au}")
        plt.legend(fontsize=8)
        plt.tight_layout()
        plt.savefig(cfg.out_dir / f"cosine_{space}.png", dpi=130)
        plt.close()
    logger.info(f"wrote {cfg.out_dir}/cosine_{{grad,act}}.png, cosine_dist.parquet, "
                f"live_scores.parquet, separability.csv")
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Parse the CLI into a Cfg, run the diagnostic, and exit with its status code.
    exit_code = main(tyro.cli(Cfg))
    raise SystemExit(exit_code)
|