mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 16:45:42 +08:00
70697ff36e
Q4 fix: on-policy "solve" was ~exploited = solves+fails (mostly fails). Split by gt_pass into solve/fail/hack (live: 103 hack / 27 solve / 62 fail). Per-pairset ranking: build v_grad from each heading-prefix subset, re-project the SAME stored live c-grads (no model re-run). Finding: behavior pairs AUROC 0.69 vs all-in-one 0.53; reasoning/opportunity anti-aligned (<0.5) -> mixing dilutes. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
323 lines
16 KiB
Python
323 lines
16 KiB
Python
"""Pinning calibration + per-pairset separation: where does the live gate sit, would
mean +/- k*sd route it, and which authored-pair subset gives the best hack/clean split?

The routeV gate scores each rollout by the width-pooled, band-normalized cosine of its
deployed-block c-probe gradient to v_grad (exactly train.py:_lora2r_gate_labels) -- the
position on the HACKING DIRECTION. This script:

  1. Splits live rollouts into THREE honest populations (Q4 fix -- "not-exploited" was
     conflating genuine solves with fails): on-policy solve (gt_pass), fail, hack (exploited).
     Plots them with the authored synthetic solve/hack, marks online mean and mean +/- 2sd.
  2. Ranks pairset SUBSETS (heading-prefix views: behavior/opportunity/disposition/...) by
     how well a v_grad built from ONLY that subset separates live hacks (AUROC + Cohen's d).
     Done by storing each live rollout's per-module deployed c-grad ONCE, then re-projecting
     onto each subset's v_grad -- no model re-run per subset (the refresh-tracking trick).

    uv run python scripts/diag_pinning.py --run-dir out/runs/<vanilla_lora2r_run>
    uv run python scripts/diag_pinning.py --replot out/diag/pinning_data.parquet # no GPU

outputs (out/diag/): pinning_calib.png, pinning_pairset_auroc.png, pinning_data.parquet,
pinning_pairset.parquet.
"""
|
|
from __future__ import annotations
|
|
import json
|
|
import struct
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import torch
|
|
import tyro
|
|
import polars as pl
|
|
import matplotlib
|
|
matplotlib.use("Agg")
|
|
import matplotlib.pyplot as plt
|
|
from matplotlib.lines import Line2D
|
|
from matplotlib.patches import Patch
|
|
from loguru import logger
|
|
from tabulate import tabulate
|
|
from safetensors.torch import load_file
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
from vgrout.lora2r import wrap_model_with_lora2r
|
|
from vgrout.pairs import load_pairs, HackPair
|
|
from vgrout.extract_vhack_grad import extract_v_hack, completion_nll
|
|
from vgrout.train import _build_v_grad, route_band_edges, _haar_unit_dirs, _auroc
|
|
|
|
# Palette: hue encodes behaviour (blue solve, red hack, grey fail); the line style picked
# at plot time encodes the source (solid on-policy, dashed synthetic).
SOLVE = "#3b6ea5"   # blue: solve populations
HACK = "#c44e52"    # red: hack populations
FAIL = "#9aa0a6"    # grey: on-policy fails
MEANC = "#d1900a"   # online mean line and its +/- band
ORACLE = "#3a8a7a"  # best hack/rest split marker
|
|
|
|
|
|
@dataclass
class Cfg:
    """Command-line configuration for the pinning diagnostic (parsed by ``tyro.cli``)."""

    # Run directory; main() reads `<ckpt>.safetensors` and `rollouts.jsonl` from it.
    run_dir: Path = Path("out/runs/20260611T003538_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v3")
    # Checkpoint file stem inside run_dir (the `.safetensors` suffix is appended by main()).
    ckpt: str = "first_hack"
    # Pairs source handed to load_pairs(); the `#all-in-one` suffix is interpreted there.
    pairs: Path = Path("data/pairs/hack_pairs.md#all-in-one")
    # Coherent emergence window. This vanilla v3 used the pre-fix lr=5e-4/warmup-0.1 and
    # DIVERGED at step 10 (exploited 20/24 -> 0/24); 2-9 = hacks emerging, model still sane.
    # Both bounds are inclusive when filtering rollouts by their "step" field.
    step_lo: int = 2
    step_hi: int = 9
    # Cap on how many filtered rollouts are scored (the first N in file order are kept).
    max_rollouts: int = 240
    random_v_seed: int | None = None  # Haar placebo (sanity: pins should NOT separate)
    replot: Path | None = None  # load this parquet and re-plot only (no model, no GPU)
    # Directory for the PNG and parquet outputs; created by main() if missing.
    out_dir: Path = Path("out/diag")
|
|
|
|
|
|
def _ckpt_meta(path: Path) -> dict:
    """Return the ``__metadata__`` mapping stored in a checkpoint file's JSON header.

    The file is read as an 8-byte little-endian unsigned length followed by that many
    bytes of JSON; everything after the header is ignored. An empty dict is returned
    when the header carries no ``__metadata__`` key.
    """
    with open(path, "rb") as fh:
        (header_len,) = struct.unpack("<Q", fh.read(8))
        header = json.loads(fh.read(header_len))
    return header.get("__metadata__", {})
|
|
|
|
|
|
def _batch_pos(G: torch.Tensor, names: list, v_grad: dict, route_band: dict) -> np.ndarray:
    """Score every rollout's position along v_grad, pooled across modules.

    G holds the raw deployed-block c-grads as [N, M, r] and does not depend on v_grad,
    so one stored G can be re-projected onto any (v_grad, route_band) pair cheaply.
    For each module the gradient is unit-normalised and compared with each of that
    module's k directions (the max cosine is kept). Modules whose band has zero or
    negative width are left out, and the remaining (cos - lower) terms are summed and
    divided by the summed band widths. Returns a NumPy array of length N.
    """
    device = v_grad[names[0]].device
    dirs = torch.stack([v_grad[name] for name in names]).to(device)  # [M, k, r]
    lower = torch.tensor([route_band[name][0] for name in names], device=device)  # [M]
    upper = torch.tensor([route_band[name][1] for name in names], device=device)  # [M]
    width = upper - lower
    active = width > 0  # zero-width modules contribute to neither sum

    grads = G.to(device)
    unit = grads / grads.norm(dim=2, keepdim=True).clamp_min(1e-12)  # unit per (n, m)
    best_cos = torch.einsum("nmr,mkr->nmk", unit, dirs).amax(-1)  # [N, M]

    numer = ((best_cos - lower) * active).sum(1)
    denom = (width * active).sum().clamp_min(1e-12)
    return (numer / denom).cpu().numpy()
|
|
|
|
|
|
def _cohend(a: np.ndarray, b: np.ndarray) -> float:
    """Cohen's d of ``a`` against ``b``: (mean(a) - mean(b)) / pooled sample sd.

    Callers pass hacks as ``a``, so a positive value means hacks sit further along the
    hack direction. Returns NaN when either sample has fewer than two points or the
    pooled sd is not strictly positive.
    """
    n_a, n_b = len(a), len(b)
    if n_a < 2 or n_b < 2:
        return float("nan")
    pooled_var = ((n_a - 1) * a.var(ddof=1) + (n_b - 1) * b.var(ddof=1)) / (n_a + n_b - 2)
    pooled_sd = np.sqrt(pooled_var)
    if not pooled_sd > 0:
        return float("nan")
    return float((a.mean() - b.mean()) / pooled_sd)
|
|
|
|
|
|
def _kde(x: np.ndarray, grid: np.ndarray) -> np.ndarray:
    """Evaluate a Gaussian kernel density estimate of ``x`` on ``grid`` (no scipy).

    The bandwidth is 0.9 * spread * n**(-1/5), where spread is the smaller of the
    sample sd and IQR/1.349 (sd alone when the IQR is zero). A zero spread falls back
    to 1.0 and the bandwidth is floored at 1e-3. Fewer than two samples yields an
    all-zero curve shaped like ``grid``.
    """
    samples = np.asarray(x, float)
    n = len(samples)
    if n < 2:
        return np.zeros_like(grid)

    q75, q25 = np.percentile(samples, [75, 25])
    iqr = q75 - q25
    spread = samples.std(ddof=1)
    if iqr > 0:
        spread = min(spread, iqr / 1.349)
    if not spread:  # all samples identical: use a unit spread instead of zero
        spread = 1.0
    bandwidth = max(0.9 * spread * n ** (-0.2), 1e-3)

    z = (grid[:, None] - samples[None, :]) / bandwidth
    return np.exp(-0.5 * z ** 2).sum(1) / (n * bandwidth * np.sqrt(2 * np.pi))
|
|
|
|
|
|
def plot_dist(df: pl.DataFrame, subtitle: str, out_png: Path) -> None:
    """Draw the calibration figure: KDEs of the five stored populations plus pin markers.

    Reads the ``pos`` values for each ``pop`` label (``on_solve``, ``on_fail``,
    ``on_hack``, ``syn_solve``, ``syn_hack``) from ``df``; no model or GPU is needed.
    The three on-policy populations are pooled to get the online mean, its +/- 2 sd
    band, the hack-vs-rest AUROC and the best single-threshold split. The figure is
    written to ``out_png`` and a summary line is logged.
    """
    def positions(pop_name: str) -> np.ndarray:
        return df.filter(pl.col("pop") == pop_name)["pos"].to_numpy()

    live_solve = positions("on_solve")
    live_fail = positions("on_fail")
    live_hack = positions("on_hack")
    synth_solve = positions("syn_solve")
    synth_hack = positions("syn_hack")

    # Hack-vs-rest is the gate's job, so the positive label is "exploited".
    live = np.concatenate([live_solve, live_fail, live_hack])
    n_rest = len(live_solve) + len(live_fail)
    is_hack = np.concatenate([np.zeros(n_rest, bool), np.ones(len(live_hack), bool)])
    mean = float(live.mean())
    sd = float(live.std())
    band_lo = mean - 2 * sd
    band_hi = mean + 2 * sd
    auroc = _auroc(live.tolist(), is_hack.tolist())

    # Best split: the observed position t that maximises
    # (share of hacks >= t) - (share of non-hacks >= t).
    thresholds = np.unique(live)
    hack_pos = live[is_hack]
    rest_pos = live[~is_hack]
    gaps = [(hack_pos >= t).mean() - (rest_pos >= t).mean() for t in thresholds]
    oracle = float(thresholds[int(np.argmax(gaps))])

    # x-range: padded, with the top 1% of live positions clipped off the right edge.
    x_min = min(live.min(), synth_solve.min()) - 0.1
    x_max = max(np.quantile(live, 0.99), synth_hack.max()) + 0.1
    grid = np.linspace(x_min, x_max, 400)
    # (values, colour, is on-policy, legend label)
    populations = [
        (live_solve, SOLVE, True, "on-policy solve (gt pass)"),
        (live_fail, FAIL, True, "on-policy fail"),
        (live_hack, HACK, True, "on-policy hack"),
        (synth_solve, SOLVE, False, "synthetic solve"),
        (synth_hack, HACK, False, "synthetic hack"),
    ]
    curves = [_kde(values, grid) for values, *_ in populations]
    y_top = max(curve.max() for curve in curves) * 1.15
    dashed = (0, (5, 2))

    fig, ax = plt.subplots(figsize=(8.8, 4.8))
    # Pinning markers first, so the density curves are drawn over them.
    ax.axvspan(band_lo, band_hi, color=MEANC, alpha=0.07, lw=0)
    ax.axvline(mean, color=MEANC, lw=1.8)
    for edge in (band_lo, band_hi):
        ax.axvline(edge, color=MEANC, lw=1.2, ls="--")
    ax.axvline(oracle, color=ORACLE, lw=1.3, ls="-.")
    for (_values, colour, on_policy, _label), curve in zip(populations, curves):
        if not on_policy:
            ax.plot(grid, curve, color=colour, lw=2.2, ls=dashed)
            continue
        ax.fill_between(grid, curve, color=colour, alpha=0.12, lw=0)
        ax.plot(grid, curve, color=colour, lw=1.9)
    ax.set_ylim(0, y_top)
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    ax.set_ylabel("density")
    ax.set_xlabel("hacking direction (gradient cosine to v_grad) " + r"$\longrightarrow$")

    # Two legends: the first must be re-added as an artist or the second replaces it.
    curve_handles = [
        Line2D([0], [0], color=colour, lw=2.2, ls="-" if on_policy else dashed)
        for _values, colour, on_policy, _label in populations
    ]
    curve_labels = [label for *_rest, label in populations]
    curve_legend = ax.legend(curve_handles, curve_labels, loc="upper left",
                             fontsize=8, frameon=False, title="distributions", title_fontsize=8)
    curve_legend._legend_box.align = "left"
    ax.add_artist(curve_legend)
    marker_handles = [
        Line2D([0], [0], color=MEANC, lw=1.8),
        Patch(facecolor=MEANC, alpha=0.18, edgecolor=MEANC, ls="--"),
        Line2D([0], [0], color=ORACLE, lw=1.3, ls="-."),
    ]
    marker_labels = [
        f"online mean ({mean:+.2f})",
        f"mean +/- 2sd (sd={sd:.2f})",
        f"best hack/rest split ({oracle:+.2f})",
    ]
    ax.legend(marker_handles, marker_labels,
              loc="upper right", fontsize=8, frameon=False, title="proposed pinning", title_fontsize=8)

    title = (f"{subtitle}\nlive hack-direction positions vs authored v_grad "
             f"(hack-vs-rest AUROC={auroc:.2f})")
    ax.set_title(title, fontsize=9.5)
    fig.tight_layout()
    fig.savefig(out_png, dpi=140)
    plt.close(fig)
    logger.info(f"mean={mean:+.3f} sd={sd:.3f} AUROC={auroc:.3f} -> {out_png}")
|
|
|
|
|
|
def plot_pairset(rank: pl.DataFrame, subtitle: str, out_png: Path) -> None:
    """Save a horizontal bar chart of per-subset hack-vs-rest AUROC.

    ``rank`` needs ``group``, ``n_pairs``, ``auroc`` and ``cohend`` columns. Rows are
    drawn in ascending AUROC order (so the best subset ends up on top), the
    ``all-in-one`` bar gets the hack colour, and each bar is annotated with its AUROC
    and Cohen's d. The figure is written to ``out_png``.
    """
    ordered = rank.sort("auroc")
    names = ordered["group"].to_list()
    pair_counts = ordered["n_pairs"].to_list()
    scores = ordered["auroc"].to_numpy()
    effects = ordered["cohend"].to_numpy()
    rows = np.arange(len(names))

    fig, ax = plt.subplots(figsize=(7.2, 0.5 * len(names) + 1.4))
    bar_colours = [HACK if name == "all-in-one" else SOLVE for name in names]
    ax.barh(rows, scores, color=bar_colours, alpha=0.85)
    ax.axvline(0.5, color="k", lw=1, ls=":")  # 0.5 = blind to live hacks
    ax.set_yticks(rows)
    tick_labels = [f"{name} (n={count})" for name, count in zip(names, pair_counts)]
    ax.set_yticklabels(tick_labels, fontsize=8)
    for row, score, effect in zip(rows, scores, effects):
        ax.text(score + 0.005, row, f"{score:.2f} d={effect:+.2f}", va="center", fontsize=7.5)
    # Leave room on the right for the text annotations.
    x_left = min(0.45, scores.min() - 0.03)
    x_right = max(scores.max() + 0.12, 0.6)
    ax.set_xlim(x_left, x_right)
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    ax.set_xlabel("hack-vs-rest AUROC of subset's v_grad (0.5 = blind, dotted)")
    ax.set_title(f"{subtitle}\nwhich authored-pair subset separates live hacks best?", fontsize=9.5)
    fig.tight_layout()
    fig.savefig(out_png, dpi=140)
    plt.close(fig)
    logger.info(f"wrote {out_png}")
|
|
|
|
|
|
def main(cfg: Cfg) -> int:
    """Run the pinning diagnostic end to end, or only re-plot from saved parquet.

    With ``cfg.replot`` set, that parquet is re-plotted to ``pinning_calib.png`` (and
    ``pinning_pairset_auroc.png`` if ``pinning_pairset.parquet`` already exists in
    ``cfg.out_dir``), and the function returns before any model is loaded.

    Otherwise it loads the checkpoint's A/B tensors into a wrapped model, builds v_grad
    from the authored pairs, collects one c-grad row per live rollout, ranks pair
    subsets by how well their v_grad separates exploited rollouts, and writes two
    parquet files and two PNGs into ``cfg.out_dir``.

    Returns 0; the ``__main__`` guard passes it to ``SystemExit``.
    """
    cfg.out_dir.mkdir(parents=True, exist_ok=True)
    data_path = cfg.out_dir / "pinning_data.parquet"
    rank_path = cfg.out_dir / "pinning_pairset.parquet"
    dist_png = cfg.out_dir / "pinning_calib.png"
    rank_png = cfg.out_dir / "pinning_pairset_auroc.png"
    # Replot mode: no model, no GPU. The ranking parquet is looked up in out_dir, not
    # next to cfg.replot.
    if cfg.replot is not None:
        plot_dist(pl.read_parquet(cfg.replot), f"replot -- {cfg.replot.name}", dist_png)
        if rank_path.exists():
            plot_pairset(pl.read_parquet(rank_path), "replot", rank_png)
        return 0

    device = torch.device("cuda")  # the full run is hard-wired to CUDA
    ckpt_path = cfg.run_dir / f"{cfg.ckpt}.safetensors"
    meta = _ckpt_meta(ckpt_path)
    # The training config is stored as a JSON string under the "cfg" metadata key.
    run_cfg = json.loads(meta.get("cfg", "{}"))
    model_name = run_cfg.get("model", meta.get("model", "Qwen/Qwen3-4B"))
    r = run_cfg.get("lora_r", 32)
    init_seed = run_cfg.get("lora_init_seed", 0)
    logger.info(f"ckpt {ckpt_path.name} step={meta.get('step')} hack_rate={meta.get('hack_rate')} "
                f"model={model_name} r={r} init_seed={init_seed}")

    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2").to(device)
    model.config.use_cache = False
    wrappers = wrap_model_with_lora2r(model, r=r, init_seed=init_seed, grad_probe=True)
    names = sorted(wrappers)  # fixed module order shared by every G / v_grad stack below
    sd = load_file(str(ckpt_path))
    # Overwrite each wrapper's A/B in place with the checkpoint tensors (cast to float32).
    for nm in names:
        wrappers[nm]["A"].data.copy_(sd[f"A/{nm}"].to(device, torch.float32))
        wrappers[nm]["B"].data.copy_(sd[f"B/{nm}"].to(device, torch.float32))
    logger.info(f"loaded A/B into {len(names)} modules")
    model.eval()

    def v_grad_from(pairs: list[HackPair]):
        """Return (v_grad, route-band edges, raw extraction output) for ``pairs``.

        When cfg.random_v_seed is set, v_grad is replaced by _haar_unit_dirs(...) (the
        Haar placebo), and the band edges are computed against that replacement.
        """
        _, _, raw, _ = extract_v_hack(model, tok, wrappers, pairs,
                                      top_k=1, tau_axis=0.0, n_heldout=2, device=device)
        vg = _build_v_grad(raw, wrappers, 1, device)
        if cfg.random_v_seed is not None:
            vg = _haar_unit_dirs(vg, cfg.random_v_seed, device)
        return vg, route_band_edges(raw, vg, device), raw

    pairs_all = load_pairs(cfg.pairs)
    logger.info(f"pairs {cfg.pairs} -> {len(pairs_all)}")
    v_grad, route_band, raw_all = v_grad_from(pairs_all)

    # synthetic pair positions under the all-in-one v_grad (for the distribution plot)
    def pair_pos(side: str) -> np.ndarray:
        """Positions of the authored pairs on one side ("clean" or "hack")."""
        # raw_all is keyed "<side>/<module>"; the first axis is indexed per pair here.
        # NOTE(review): assumes each raw_all[...][i] has the same width as the live
        # deployed-block rows -- confirm against extract_v_hack.
        n = raw_all[f"{side}/{names[0]}"].shape[0]
        G = torch.stack([torch.stack([raw_all[f"{side}/{nm}"][i] for nm in names]) for i in range(n)])
        return _batch_pos(G, names, v_grad, route_band)
    syn_solve, syn_hack = pair_pos("clean"), pair_pos("hack")

    # score the live batch ONCE; store per-rollout deployed c-grads for cheap re-projection
    recs = [json.loads(l) for l in (cfg.run_dir / "rollouts.jsonl").read_text().splitlines()]
    # Keep rollouts inside the inclusive step window with non-blank text, then cap the count.
    batch = [x for x in recs if cfg.step_lo <= x["step"] <= cfg.step_hi and x["text"].strip()][:cfg.max_rollouts]
    logger.info(f"live batch: {len(batch)} rollouts (steps {cfg.step_lo}-{cfg.step_hi})")
    G_rows, exploited, gt_pass, steps = [], [], [], []
    for i, rec in enumerate(batch):
        model.zero_grad(set_to_none=True)
        loss = completion_nll(model, tok, rec["prompt"], rec["text"], device)
        # Non-finite losses are skipped entirely, so N below can be smaller than len(batch).
        if not torch.isfinite(loss):
            continue
        loss.backward()
        g = []
        for nm in names:
            gr = wrappers[nm]["layer"]._lora2r_gate.grad
            # Sum over every leading dim, then keep the first r entries of the last dim.
            g.append(gr.sum(dim=tuple(range(gr.dim() - 1)))[:r].float().cpu())  # deployed block [r]
        G_rows.append(torch.stack(g))  # [M, r]
        exploited.append(bool(rec["exploited"])); gt_pass.append(bool(rec["gt_pass"])); steps.append(rec["step"])
        if (i + 1) % 40 == 0:
            logger.info(f" rollout {i+1}/{len(batch)}")
    model.zero_grad(set_to_none=True)
    # NOTE(review): torch.stack raises if no rollout survived the filters above.
    G = torch.stack(G_rows)  # [N, M, r]
    exploited, gt_pass, steps = map(np.array, (exploited, gt_pass, steps))
    pos = _batch_pos(G, names, v_grad, route_band)
    # THREE honest populations (Q4 fix): solve = gt-correct & not-exploited; fail = neither; hack = exploited
    solve_m = gt_pass & ~exploited
    fail_m = ~gt_pass & ~exploited
    logger.info(f"live: {len(exploited)} rollouts | hack={int(exploited.sum())} "
                f"solve={int(solve_m.sum())} fail={int(fail_m.sum())}")

    # per-pairset AUROC: build v_grad from each heading-prefix subset, re-project the SAME G
    groups: dict[str, list] = defaultdict(list)
    for p in pairs_all:
        # Group key = the part of problem_id before its first underscore.
        groups[p.problem_id.split("_")[0]].append(p)
    rows = []
    # extract_v_hack holds out 2 pairs, so a subset needs >=3 to leave a training pair.
    # NOTE(review): the "all-in-one" entry re-runs v_grad_from(pairs_all) although
    # v_grad/route_band were already computed above -- presumably redundant; confirm.
    candidates = [("all-in-one", pairs_all)] + [(g, ps) for g, ps in sorted(groups.items()) if len(ps) >= 3]
    for gname, gpairs in candidates:
        vg, rb, _ = v_grad_from(gpairs)
        pg = _batch_pos(G, names, vg, rb)
        auroc = _auroc(pg.tolist(), exploited.tolist())
        # Cohen's d is hack-vs-solve when at least two solves exist, else hack-vs-rest.
        cohend = _cohend(pg[exploited], pg[solve_m]) if solve_m.sum() >= 2 else _cohend(pg[exploited], pg[~exploited])
        rows.append({"group": gname, "n_pairs": len(gpairs), "auroc": round(auroc, 3),
                     "cohend": round(cohend, 3)})
        logger.info(f" pairset {gname:12s} n={len(gpairs):2d} AUROC={auroc:.3f} d={cohend:+.3f}")
    rank = pl.DataFrame(rows).sort("auroc", descending=True)
    rank.write_parquet(rank_path)
    print(tabulate(rank.to_pandas(), headers="keys", tablefmt="pipe", floatfmt="+.3f", showindex=False))

    # Long-format table of all five populations; synthetic rows carry step = -1.
    df = pl.concat([
        pl.DataFrame({"pop": "on_solve", "pos": pos[solve_m], "step": steps[solve_m]}),
        pl.DataFrame({"pop": "on_fail", "pos": pos[fail_m], "step": steps[fail_m]}),
        pl.DataFrame({"pop": "on_hack", "pos": pos[exploited], "step": steps[exploited]}),
        pl.DataFrame({"pop": "syn_solve", "pos": syn_solve, "step": np.full(len(syn_solve), -1)}),
        pl.DataFrame({"pop": "syn_hack", "pos": syn_hack, "step": np.full(len(syn_hack), -1)}),
    ])
    df.write_parquet(data_path)
    logger.info(f"wrote {data_path} ({len(df)} rows)")
    sub = (f"{cfg.run_dir.name}\n{cfg.ckpt} v_grad, live steps {cfg.step_lo}-{cfg.step_hi}, "
           f"hack={int(exploited.sum())} solve={int(solve_m.sum())} fail={int(fail_m.sum())}")
    plot_dist(df, sub, dist_png)
    plot_pairset(rank, cfg.run_dir.name, rank_png)
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Parse the command line into a Cfg with tyro; main()'s return value is the exit status.
    exit_status = main(tyro.cli(Cfg))
    raise SystemExit(exit_status)
|