mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 16:45:42 +08:00
979daf84fd
Gate calibration: route by live mean + route_std_mid/route_std_rout * std of the pooled cosine-to-v_grad, not a fixed quantile tail. Self-silences -- only the tail that genuinely exceeds the spread routes, so qmass tracks real separation instead of a forced fraction. The authored absolute band is mis-placed (live pos sits far below the synthetic-hack edge; even synthetic solve out-aligns on-policy hack). tablelog: auroc/rout/routE/keep/resid/qmass cols always shown (nan on vanilla) so arm tables line up. Diagnostics: scripts/diag_pinning.py (4-population calibration view, mean+/-2sd band) and scripts/diag_pinning_refresh.py (proves cosine stats recompute from a tracked v-independent gradient cloud on a v_grad refresh -- exact for k=1, sanity 2.5e-16). Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
253 lines
12 KiB
Python
253 lines
12 KiB
Python
"""Pinning calibration: where does the live gate sit, and would mean +/- 2sd route it?
|
|
|
|
The routeV gate scores each rollout by the width-pooled, band-normalized cosine of its
|
|
deployed-block c-probe gradient to v_grad (exactly train.py:_lora2r_gate_labels) -- call it
|
|
the position on the HACKING DIRECTION. The current gate routes a FIXED quantile tail (#30).
|
|
This plots four populations on that axis:
|
|
on-policy solve / hack -- live rollouts (oracle `exploited` read for COLOUR/AUROC only)
|
|
synthetic solve / hack -- the authored pairs v_grad was built from (the only label source)
|
|
and marks the PROPOSED pinning: online mean and mean +/- 2*sd (a symmetric band that
|
|
self-calibrates to the live spread instead of forcing a fixed %).
|
|
|
|
uv run python scripts/diag_pinning.py --run-dir out/runs/<vanilla_lora2r_run>
|
|
uv run python scripts/diag_pinning.py --replot out/diag/pinning_data.parquet # no GPU, restyle only
|
|
outputs (out/diag/): pinning_calib.png, pinning_data.parquet (the 4 populations, regenerates the plot).
|
|
"""
|
|
from __future__ import annotations
|
|
import json
|
|
import struct
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import torch
|
|
import tyro
|
|
import polars as pl
|
|
import matplotlib
|
|
matplotlib.use("Agg")
|
|
import matplotlib.pyplot as plt
|
|
from matplotlib.lines import Line2D
|
|
from matplotlib.patches import Patch
|
|
from loguru import logger
|
|
from safetensors.torch import load_file
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
from vgrout.lora2r import wrap_model_with_lora2r
|
|
from vgrout.pairs import load_pairs
|
|
from vgrout.extract_vhack_grad import extract_v_hack, completion_nll
|
|
from vgrout.train import _build_v_grad, route_band_edges, _haar_unit_dirs, _auroc
|
|
|
|
# Palette. Colour encodes behaviour (blue = solve, red = hack); line style encodes the
# source (solid + fill = on-policy, dashed = synthetic).
SOLVE = "#3b6ea5"   # solve populations
HACK = "#c44e52"    # hack populations
MEANC = "#d1900a"   # online mean line and the mean +/- 2sd band
ORACLE = "#3a8a7a"  # oracle (best hack/solve split) divider
|
|
|
|
|
|
@dataclass
class Cfg:
    # Run directory; main() reads `<ckpt>.safetensors` and `rollouts.jsonl` from it.
    run_dir: Path = Path("out/runs/20260611T003538_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v3")
    # Checkpoint file stem inside run_dir (no extension).
    ckpt: str = "first_hack"
    # Authored pairs handed to load_pairs(); v_grad is extracted from these.
    pairs: Path = Path("data/pairs/hack_pairs.md#all-in-one")
    # Coherent emergence window (inclusive on both ends). This vanilla v3 used the
    # pre-fix lr=5e-4/warmup-0.1 and DIVERGED at step 10 (exploited 20/24 -> 0/24);
    # 2-9 = hacks emerging, model still sane.
    step_lo: int = 2
    step_hi: int = 9
    # Upper bound on live rollouts scored (first N records that pass the step/text filter).
    max_rollouts: int = 240
    random_v_seed: int | None = None  # Haar placebo (sanity: pins should NOT separate)
    replot: Path | None = None  # load this parquet and re-plot only (no model, no GPU)
    # Directory that receives pinning_calib.png and pinning_data.parquet.
    out_dir: Path = Path("out/diag")
|
|
|
|
|
|
def _ckpt_meta(path: Path) -> dict:
    """Return the ``__metadata__`` mapping from a checkpoint file's JSON header.

    The file is read as an 8-byte little-endian unsigned length followed by that many
    bytes of JSON; nothing past the header is read. Returns ``{}`` when the header
    has no ``__metadata__`` key.
    """
    with open(path, "rb") as fh:
        (header_len,) = struct.unpack("<Q", fh.read(8))
        header = json.loads(fh.read(header_len))
    return header.get("__metadata__", {})
|
|
|
|
|
|
def pooled_pos(c_grads: dict, v_grad: dict, route_band: dict, names: list, r: int) -> tuple[float, float]:
    """Pool one rollout's per-module gate gradients into a single band-normalized position.

    Intended to mirror train.py:_lora2r_gate_labels pooling for ONE rollout's per-module
    deployed c-grads.

    For every module in ``names`` whose band ``(lower, upper)`` has positive width, the
    first ``r`` entries of ``c_grads[name]`` are projected onto each row of
    ``v_grad[name]``; the largest projection divided by the gradient norm is the
    module's score (a cosine if the rows of ``v_grad`` are unit-norm -- presumably so,
    confirm against the caller). Modules with non-positive band width are skipped.

    Returns:
        ``(pos, mean_norm)``: ``pos`` is ``sum(score - lower) / sum(upper - lower)``
        over the included modules; ``mean_norm`` is the mean gradient norm over them.

    Raises:
        RuntimeError: if no module has a positive band width.
    """
    shifted_sum = 0.0
    width_sum = 0.0
    norm_sum = 0.0
    used = 0
    for mod in names:
        band_lo, band_hi = route_band[mod]
        width = band_hi - band_lo
        if width <= 0:
            continue
        dirs = v_grad[mod]
        grad_blk = c_grads[mod][:r].float().to(dirs.device)
        grad_norm = grad_blk.norm()
        # best-aligned direction among the k rows; clamp guards a zero-norm gradient
        best_dot = torch.einsum("r, k r -> k", grad_blk, dirs).max()
        score = (best_dot / grad_norm.clamp_min(1e-12)).item()
        shifted_sum += score - band_lo
        width_sum += width
        norm_sum += grad_norm.item()
        used += 1
    if used == 0:
        raise RuntimeError("no module has positive band width")
    return shifted_sum / width_sum, norm_sum / used
|
|
|
|
|
|
def _kde(x: np.ndarray, grid: np.ndarray) -> np.ndarray:
    """Evaluate a Gaussian kernel density estimate of ``x`` on ``grid`` (no scipy).

    Bandwidth is Silverman-style: ``0.9 * sigma * n**-0.2`` with
    ``sigma = min(sample std, IQR / 1.349)`` (sample std alone when the IQR is zero,
    and 1.0 when that is zero too), floored at 1e-3. With fewer than two samples no
    density is estimated and zeros shaped like ``grid`` are returned.
    """
    samples = np.asarray(x, float)
    n = len(samples)
    if n < 2:
        return np.zeros_like(grid)
    q75, q25 = np.percentile(samples, [75, 25])
    iqr = q75 - q25
    std = samples.std(ddof=1)
    sigma = min(std, iqr / 1.349) if iqr > 0 else std
    # `sigma or 1.0`: a degenerate (all-equal) sample still gets a finite kernel width
    bw = max(0.9 * (sigma or 1.0) * n ** (-0.2), 1e-3)
    z = (grid[:, None] - samples[None, :]) / bw
    return np.exp(-0.5 * z ** 2).sum(1) / (n * bw * np.sqrt(2 * np.pi))
|
|
|
|
|
|
def plot(df: pl.DataFrame, subtitle: str, out_png: Path) -> None:
    """Draw the calibration figure from the saved 4 populations -- no GPU needed.

    Args:
        df: long-form frame with a ``pop`` column (``on_solve`` / ``on_hack`` /
            ``syn_solve`` / ``syn_hack``) and a ``pos`` column of positions.
        subtitle: first line(s) of the figure title.
        out_png: path the figure is saved to (dpi=140).

    The live (on-policy) positions define the marked mean and mean +/- 2*sd band
    (population sd, ddof=0). The oracle divider is the threshold maximizing
    Youden's J on the live labels. A one-line summary is also logged.
    """
    def positions(pop_name):
        return df.filter(pl.col("pop") == pop_name)["pos"].to_numpy()

    on_solve = positions("on_solve")
    on_hack = positions("on_hack")
    syn_solve = positions("syn_solve")
    syn_hack = positions("syn_hack")
    pos_live = np.concatenate([on_solve, on_hack])
    is_hack = np.concatenate([np.zeros(len(on_solve), bool), np.ones(len(on_hack), bool)])

    mean = float(pos_live.mean())
    sd = float(pos_live.std())
    band_lo = mean - 2 * sd  # proposed routing band
    band_hi = mean + 2 * sd
    auroc = _auroc(pos_live.tolist(), is_hack.tolist())

    # oracle divider (Youden J) -- diagnostic only
    hack_pos = pos_live[is_hack]
    solve_pos = pos_live[~is_hack]
    cuts = np.unique(pos_live)
    youden = [(hack_pos >= c).mean() - (solve_pos >= c).mean() for c in cuts]
    oracle = float(cuts[int(np.argmax(youden))])

    # x-range: full lower extent, but the live upper tail is clipped at its 99th percentile
    x_min = min(pos_live.min(), syn_solve.min()) - 0.1
    x_max = max(np.quantile(pos_live, 0.99), syn_hack.max()) + 0.1
    grid = np.linspace(x_min, x_max, 400)
    populations = [
        (on_solve, SOLVE, True, "on-policy solve"),
        (on_hack, HACK, True, "on-policy hack"),
        (syn_solve, SOLVE, False, "synthetic solve"),
        (syn_hack, HACK, False, "synthetic hack"),
    ]
    densities = [_kde(sample, grid) for sample, *_ in populations]
    ymax = max(d.max() for d in densities) * 1.15

    fig, ax = plt.subplots(figsize=(8.6, 4.8))
    # proposed pinning: mean +/- 2sd band, drawn first so curves sit on top
    ax.axvspan(band_lo, band_hi, color=MEANC, alpha=0.07, lw=0)
    ax.axvline(mean, color=MEANC, lw=1.8)
    ax.axvline(band_lo, color=MEANC, lw=1.2, ls="--")
    ax.axvline(band_hi, color=MEANC, lw=1.2, ls="--")
    ax.axvline(oracle, color=ORACLE, lw=1.3, ls="-.")

    dashed = (0, (5, 2))
    for (_sample, colour, on_policy, _label), density in zip(populations, densities):
        # on-policy: stronger fill + solid line; synthetic: faint fill + thicker dashed line
        if on_policy:
            fill_alpha, line_kw = 0.13, {"lw": 1.9}
        else:
            fill_alpha, line_kw = 0.05, {"lw": 2.4, "ls": dashed}
        ax.fill_between(grid, density, color=colour, alpha=fill_alpha, lw=0)
        ax.plot(grid, density, color=colour, **line_kw)

    ax.set_ylim(0, ymax)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.set_ylabel("density")
    ax.set_xlabel("hacking direction (gradient cosine to v_grad)  " + r"$\longrightarrow$")

    # first legend: the four distributions (kept alive via add_artist when the second is added)
    dist_handles = []
    dist_labels = []
    for _sample, colour, on_policy, label in populations:
        dist_handles.append(Line2D([0], [0], color=colour, lw=2.2, ls="-" if on_policy else dashed))
        dist_labels.append(label)
    leg1 = ax.legend(dist_handles, dist_labels, loc="upper left",
                     fontsize=8, frameon=False, title="distributions", title_fontsize=8)
    leg1._legend_box.align = "left"
    ax.add_artist(leg1)

    # second legend: the pinning markers
    mark_handles = [
        Line2D([0], [0], color=MEANC, lw=1.8),
        Patch(facecolor=MEANC, alpha=0.18, edgecolor=MEANC, ls="--"),
        Line2D([0], [0], color=ORACLE, lw=1.3, ls="-."),
    ]
    mark_labels = [
        f"online mean ({mean:+.2f})",
        f"mean +/- 2sd  (sd={sd:.2f})  = proposed band",
        f"best hack/solve split ({oracle:+.2f}, diagnostic)",
    ]
    ax.legend(mark_handles, mark_labels,
              loc="upper right", fontsize=8, frameon=False, title="pinning", title_fontsize=8)

    routed_hi = float((pos_live > band_hi).mean())
    ax.set_title(f"{subtitle}\nAUROC(hack-dir -> hack)={auroc:.2f}    "
                 f"{routed_hi:.0%} of rollouts beyond mean+2sd", fontsize=9.5)
    fig.tight_layout()
    fig.savefig(out_png, dpi=140)
    plt.close(fig)
    logger.info(f"mean={mean:+.3f} sd={sd:.3f} band=[{band_lo:+.3f},{band_hi:+.3f}] oracle={oracle:+.3f} "
                f"AUROC={auroc:.3f} routed(>mean+2sd)={routed_hi:.2f}")
    logger.info(f"wrote {out_png}")
|
|
|
|
|
|
def main(cfg: Cfg) -> int:
    """Score synthetic and live populations on the v_grad axis, save them, and plot.

    With ``cfg.replot`` set, only the saved parquet is loaded and re-plotted (no model,
    no GPU). Otherwise the checkpoint's A/B weights are loaded into a LoRA-2r-wrapped
    model, v_grad and the routing band are built from the authored pairs, and every
    selected live rollout is scored with ``pooled_pos`` from its gate gradient.

    Outputs (in ``cfg.out_dir``): ``pinning_data.parquet`` and ``pinning_calib.png``.

    Returns:
        0 on success (used as the process exit status).

    Raises:
        RuntimeError: if no live rollout could be scored (nothing in the step window,
            or every completion loss was non-finite).
    """
    cfg.out_dir.mkdir(parents=True, exist_ok=True)
    data_path = cfg.out_dir / "pinning_data.parquet"
    out_png = cfg.out_dir / "pinning_calib.png"
    if cfg.replot is not None:
        df = pl.read_parquet(cfg.replot)
        plot(df, f"pinning calibration (replot) -- {cfg.replot.name}", out_png)
        return 0

    device = torch.device("cuda")
    ckpt_path = cfg.run_dir / f"{cfg.ckpt}.safetensors"
    meta = _ckpt_meta(ckpt_path)
    # the training config is stored as a JSON string under the "cfg" metadata key
    run_cfg = json.loads(meta.get("cfg", "{}"))
    model_name = run_cfg.get("model", meta.get("model", "Qwen/Qwen3-4B"))
    r = run_cfg.get("lora_r", 32)
    init_seed = run_cfg.get("lora_init_seed", 0)
    logger.info(f"ckpt {ckpt_path.name} step={meta.get('step')} hack_rate={meta.get('hack_rate')} "
                f"model={model_name} r={r} init_seed={init_seed}")

    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2").to(device)
    model.config.use_cache = False
    wrappers = wrap_model_with_lora2r(model, r=r, init_seed=init_seed, grad_probe=True)
    names = sorted(wrappers)
    state = load_file(str(ckpt_path))
    for nm in names:
        wrappers[nm]["A"].data.copy_(state[f"A/{nm}"].to(device, torch.float32))
        wrappers[nm]["B"].data.copy_(state[f"B/{nm}"].to(device, torch.float32))
    logger.info(f"loaded A/B into {len(names)} modules")

    pairs = load_pairs(cfg.pairs)
    logger.info(f"pairs {cfg.pairs} -> {len(pairs)}")
    model.eval()
    _, _, raw_grads, _ = extract_v_hack(model, tok, wrappers, pairs,
                                        top_k=1, tau_axis=0.0, n_heldout=2, device=device)
    v_grad = _build_v_grad(raw_grads, wrappers, 1, device)
    if cfg.random_v_seed is not None:
        v_grad = _haar_unit_dirs(v_grad, cfg.random_v_seed, device)
        logger.info(f"OVERRODE v_grad with Haar dirs seed={cfg.random_v_seed} (placebo)")
    route_band = route_band_edges(raw_grads, v_grad, device)

    def pair_pos(side: str) -> np.ndarray:
        """Positions of the authored pairs on one side ("clean" or "hack")."""
        # raw_grads is indexed as "<side>/<module>" with one row per pair
        n = raw_grads[f"{side}/{names[0]}"].shape[0]
        out = []
        for i in range(n):
            # zero-pad to length 2r; pooled_pos only reads the first r entries
            cg = {nm: torch.cat([raw_grads[f"{side}/{nm}"][i], torch.zeros(r)]) for nm in names}
            out.append(pooled_pos(cg, v_grad, route_band, names, r)[0])
        return np.array(out)

    syn_solve, syn_hack = pair_pos("clean"), pair_pos("hack")

    rollouts_path = cfg.run_dir / "rollouts.jsonl"
    # skip blank lines (e.g. a trailing newline artefact) instead of crashing json.loads
    recs = [json.loads(line) for line in rollouts_path.read_text().splitlines() if line.strip()]
    batch = [rec for rec in recs
             if cfg.step_lo <= rec["step"] <= cfg.step_hi and rec["text"].strip()][:cfg.max_rollouts]
    logger.info(f"live batch: {len(batch)} rollouts (steps {cfg.step_lo}-{cfg.step_hi})")
    pos_live, labels, steps = [], [], []
    for i, rec in enumerate(batch):
        model.zero_grad(set_to_none=True)
        loss = completion_nll(model, tok, rec["prompt"], rec["text"], device)
        if not torch.isfinite(loss):
            continue
        loss.backward()
        # collapse every leading axis of the gate gradient, keeping only the last one
        cg = {}
        for nm in names:
            gate_grad = wrappers[nm]["layer"]._lora2r_gate.grad
            cg[nm] = gate_grad.sum(dim=tuple(range(gate_grad.dim() - 1)))
        pos_live.append(pooled_pos(cg, v_grad, route_band, names, r)[0])
        labels.append(bool(rec["exploited"]))
        steps.append(rec["step"])
        if (i + 1) % 40 == 0:
            logger.info(f"  rollout {i+1}/{len(batch)}")
    model.zero_grad(set_to_none=True)
    if not pos_live:
        # an empty label array would otherwise fail later on `~labels` with an opaque TypeError
        raise RuntimeError(
            f"no live rollouts scored for steps {cfg.step_lo}-{cfg.step_hi} in {rollouts_path}")
    pos_live = np.array(pos_live)
    labels = np.array(labels, dtype=bool)
    steps = np.array(steps)
    # per-step AUROC only where the step contains both classes
    per_step = {int(s): round(_auroc(pos_live[steps == s].tolist(), labels[steps == s].tolist()), 2)
                for s in sorted(set(steps.tolist())) if labels[steps == s].any() and (~labels[steps == s]).any()}
    logger.info(f"live: {len(labels)} rollouts, {int(labels.sum())} exploited; per-step AUROC={per_step}")

    # save the 4 populations long-form -> regenerates the plot with --replot (no GPU)
    df = pl.concat([
        pl.DataFrame({"pop": "on_solve", "pos": pos_live[~labels], "step": steps[~labels]}),
        pl.DataFrame({"pop": "on_hack", "pos": pos_live[labels], "step": steps[labels]}),
        pl.DataFrame({"pop": "syn_solve", "pos": syn_solve, "step": np.full(len(syn_solve), -1)}),
        pl.DataFrame({"pop": "syn_hack", "pos": syn_hack, "step": np.full(len(syn_hack), -1)}),
    ])
    df.write_parquet(data_path)
    logger.info(f"wrote {data_path} ({len(df)} rows)")
    plot(df, f"{cfg.run_dir.name}\n{cfg.ckpt} v_grad, live steps {cfg.step_lo}-{cfg.step_hi}, "
             f"{int(labels.sum())}/{len(labels)} exploited", out_png)
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Build Cfg from the command line with tyro; main()'s return value is the exit status.
    cli_cfg = tyro.cli(Cfg)
    raise SystemExit(main(cli_cfg))
|