mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 15:15:40 +08:00
feat(#30): mean+k*std online gate replaces fixed quantile; always-show route cols
Gate calibration: route by live mean + route_std_mid/route_std_rout * std of the pooled cosine-to-v_grad, not a fixed quantile tail. Self-silences -- only the tail that genuinely exceeds the spread routes, so qmass tracks real separation instead of a forced fraction. The authored absolute band is mis-placed (live pos sits far below the synthetic-hack edge; even synthetic solve out-aligns on-policy hack). tablelog: auroc/rout/routE/keep/resid/qmass cols always shown (nan on vanilla) so arm tables line up. Diagnostics: scripts/diag_pinning.py (4-population calibration view, mean+/-2sd band) and scripts/diag_pinning_refresh.py (proves cosine stats recompute from a tracked v-independent gradient cloud on a v_grad refresh -- exact for k=1, sanity 2.5e-16). Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,252 @@
|
||||
"""Pinning calibration: where does the live gate sit, and would mean +/- 2sd route it?
|
||||
|
||||
The routeV gate scores each rollout by the width-pooled, band-normalized cosine of its
|
||||
deployed-block c-probe gradient to v_grad (exactly train.py:_lora2r_gate_labels) -- call it
|
||||
the position on the HACKING DIRECTION. The current gate routes a FIXED quantile tail (#30).
|
||||
This plots four populations on that axis:
|
||||
on-policy solve / hack -- live rollouts (oracle `exploited` read for COLOUR/AUROC only)
|
||||
synthetic solve / hack -- the authored pairs v_grad was built from (the only label source)
|
||||
and marks the PROPOSED pinning: online mean and mean +/- 2*sd (a symmetric band that
|
||||
self-calibrates to the live spread instead of forcing a fixed %).
|
||||
|
||||
uv run python scripts/diag_pinning.py --run-dir out/runs/<vanilla_lora2r_run>
|
||||
uv run python scripts/diag_pinning.py --replot out/diag/pinning_data.parquet # no GPU, restyle only
|
||||
outputs (out/diag/): pinning_calib.png, pinning_data.parquet (the 4 populations, regenerates the plot).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import struct
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import tyro
|
||||
import polars as pl
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.lines import Line2D
|
||||
from matplotlib.patches import Patch
|
||||
from loguru import logger
|
||||
from safetensors.torch import load_file
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from vgrout.lora2r import wrap_model_with_lora2r
|
||||
from vgrout.pairs import load_pairs
|
||||
from vgrout.extract_vhack_grad import extract_v_hack, completion_nll
|
||||
from vgrout.train import _build_v_grad, route_band_edges, _haar_unit_dirs, _auroc
|
||||
|
||||
# colour = behaviour (blue solve, red hack); style = source (solid+fill on-policy, dashed synthetic)
|
||||
SOLVE, HACK, MEANC, ORACLE = "#3b6ea5", "#c44e52", "#d1900a", "#3a8a7a"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Cfg:
|
||||
run_dir: Path = Path("out/runs/20260611T003538_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v3")
|
||||
ckpt: str = "first_hack"
|
||||
pairs: Path = Path("data/pairs/hack_pairs.md#all-in-one")
|
||||
# Coherent emergence window. This vanilla v3 used the pre-fix lr=5e-4/warmup-0.1 and
|
||||
# DIVERGED at step 10 (exploited 20/24 -> 0/24); 2-9 = hacks emerging, model still sane.
|
||||
step_lo: int = 2
|
||||
step_hi: int = 9
|
||||
max_rollouts: int = 240
|
||||
random_v_seed: int | None = None # Haar placebo (sanity: pins should NOT separate)
|
||||
replot: Path | None = None # load this parquet and re-plot only (no model, no GPU)
|
||||
out_dir: Path = Path("out/diag")
|
||||
|
||||
|
||||
def _ckpt_meta(path: Path) -> dict:
|
||||
with open(path, "rb") as f:
|
||||
return json.loads(f.read(struct.unpack("<Q", f.read(8))[0])).get("__metadata__", {})
|
||||
|
||||
|
||||
def pooled_pos(c_grads: dict, v_grad: dict, route_band: dict, names: list, r: int) -> tuple[float, float]:
|
||||
"""Mirror train.py:_lora2r_gate_labels pooling for ONE rollout's per-module deployed c-grads."""
|
||||
num = 0.0; den = 0.0; w = 0.0; n_inc = 0
|
||||
for name in names:
|
||||
lower, upper = route_band[name]
|
||||
if upper - lower <= 0:
|
||||
continue
|
||||
g_b = c_grads[name][:r].float().to(v_grad[name].device)
|
||||
nrm = g_b.norm()
|
||||
cos_b = (torch.einsum("r, k r -> k", g_b, v_grad[name]).max() / nrm.clamp_min(1e-12)).item()
|
||||
num += cos_b - lower; den += upper - lower
|
||||
w += nrm.item(); n_inc += 1
|
||||
if n_inc == 0:
|
||||
raise RuntimeError("no module has positive band width")
|
||||
return num / den, w / n_inc
|
||||
|
||||
|
||||
def _kde(x: np.ndarray, grid: np.ndarray) -> np.ndarray:
|
||||
"""Gaussian KDE, Silverman bandwidth (no scipy)."""
|
||||
x = np.asarray(x, float)
|
||||
if len(x) < 2:
|
||||
return np.zeros_like(grid)
|
||||
iqr = np.subtract(*np.percentile(x, [75, 25]))
|
||||
sigma = min(x.std(ddof=1), iqr / 1.349) if iqr > 0 else x.std(ddof=1)
|
||||
bw = max(0.9 * (sigma or 1.0) * len(x) ** (-0.2), 1e-3)
|
||||
z = (grid[:, None] - x[None, :]) / bw
|
||||
return np.exp(-0.5 * z ** 2).sum(1) / (len(x) * bw * np.sqrt(2 * np.pi))
|
||||
|
||||
|
||||
def plot(df: pl.DataFrame, subtitle: str, out_png: Path) -> None:
|
||||
"""Regenerate the figure from the saved 4 populations -- no GPU needed."""
|
||||
arr = lambda p: df.filter(pl.col("pop") == p)["pos"].to_numpy()
|
||||
on_solve, on_hack = arr("on_solve"), arr("on_hack")
|
||||
syn_solve, syn_hack = arr("syn_solve"), arr("syn_hack")
|
||||
pos_live = np.concatenate([on_solve, on_hack])
|
||||
labels = np.concatenate([np.zeros(len(on_solve), bool), np.ones(len(on_hack), bool)])
|
||||
|
||||
mean, sd = float(pos_live.mean()), float(pos_live.std())
|
||||
lo_b, hi_b = mean - 2 * sd, mean + 2 * sd # proposed routing band
|
||||
auroc = _auroc(pos_live.tolist(), labels.tolist())
|
||||
thr = np.unique(pos_live) # oracle divider (Youden J) -- diagnostic only
|
||||
j = [(pos_live[labels] >= t).mean() - (pos_live[~labels] >= t).mean() for t in thr]
|
||||
oracle = float(thr[int(np.argmax(j))])
|
||||
|
||||
lo = min(pos_live.min(), syn_solve.min()) - 0.1
|
||||
hi = max(np.quantile(pos_live, 0.99), syn_hack.max()) + 0.1
|
||||
grid = np.linspace(lo, hi, 400)
|
||||
POPS = [(on_solve, SOLVE, True, "on-policy solve"), (on_hack, HACK, True, "on-policy hack"),
|
||||
(syn_solve, SOLVE, False, "synthetic solve"), (syn_hack, HACK, False, "synthetic hack")]
|
||||
kdes = {i: _kde(x, grid) for i, (x, *_ ) in enumerate(POPS)}
|
||||
ymax = max(y.max() for y in kdes.values()) * 1.15
|
||||
|
||||
fig, ax = plt.subplots(figsize=(8.6, 4.8))
|
||||
# proposed pinning: mean +/- 2sd band, drawn first so curves sit on top
|
||||
ax.axvspan(lo_b, hi_b, color=MEANC, alpha=0.07, lw=0)
|
||||
ax.axvline(mean, color=MEANC, lw=1.8)
|
||||
for b in (lo_b, hi_b):
|
||||
ax.axvline(b, color=MEANC, lw=1.2, ls="--")
|
||||
ax.axvline(oracle, color=ORACLE, lw=1.3, ls="-.")
|
||||
for i, (x, col, on_policy, _) in enumerate(POPS):
|
||||
y = kdes[i]
|
||||
if on_policy:
|
||||
ax.fill_between(grid, y, color=col, alpha=0.13, lw=0)
|
||||
ax.plot(grid, y, color=col, lw=1.9)
|
||||
else:
|
||||
ax.fill_between(grid, y, color=col, alpha=0.05, lw=0)
|
||||
ax.plot(grid, y, color=col, lw=2.4, ls=(0, (5, 2)))
|
||||
ax.set_ylim(0, ymax)
|
||||
for s in ("top", "right"):
|
||||
ax.spines[s].set_visible(False)
|
||||
ax.set_ylabel("density")
|
||||
ax.set_xlabel("hacking direction (gradient cosine to v_grad) " + r"$\longrightarrow$")
|
||||
|
||||
dist_handles = [Line2D([0], [0], color=c, lw=2.2, ls="-" if op else (0, (5, 2))) for _, c, op, _ in POPS]
|
||||
leg1 = ax.legend(dist_handles, [lab for *_, lab in POPS], loc="upper left",
|
||||
fontsize=8, frameon=False, title="distributions", title_fontsize=8)
|
||||
leg1._legend_box.align = "left"
|
||||
ax.add_artist(leg1)
|
||||
mark_handles = [
|
||||
Line2D([0], [0], color=MEANC, lw=1.8),
|
||||
Patch(facecolor=MEANC, alpha=0.18, edgecolor=MEANC, ls="--"),
|
||||
Line2D([0], [0], color=ORACLE, lw=1.3, ls="-."),
|
||||
]
|
||||
ax.legend(mark_handles,
|
||||
[f"online mean ({mean:+.2f})", f"mean +/- 2sd (sd={sd:.2f}) = proposed band",
|
||||
f"best hack/solve split ({oracle:+.2f}, diagnostic)"],
|
||||
loc="upper right", fontsize=8, frameon=False, title="pinning", title_fontsize=8)
|
||||
routed_hi = float((pos_live > hi_b).mean())
|
||||
ax.set_title(f"{subtitle}\nAUROC(hack-dir -> hack)={auroc:.2f} "
|
||||
f"{routed_hi:.0%} of rollouts beyond mean+2sd", fontsize=9.5)
|
||||
fig.tight_layout()
|
||||
fig.savefig(out_png, dpi=140)
|
||||
plt.close(fig)
|
||||
logger.info(f"mean={mean:+.3f} sd={sd:.3f} band=[{lo_b:+.3f},{hi_b:+.3f}] oracle={oracle:+.3f} "
|
||||
f"AUROC={auroc:.3f} routed(>mean+2sd)={routed_hi:.2f}")
|
||||
logger.info(f"wrote {out_png}")
|
||||
|
||||
|
||||
def main(cfg: Cfg) -> int:
|
||||
cfg.out_dir.mkdir(parents=True, exist_ok=True)
|
||||
data_path = cfg.out_dir / "pinning_data.parquet"
|
||||
out_png = cfg.out_dir / "pinning_calib.png"
|
||||
if cfg.replot is not None:
|
||||
df = pl.read_parquet(cfg.replot)
|
||||
plot(df, f"pinning calibration (replot) -- {cfg.replot.name}", out_png)
|
||||
return 0
|
||||
|
||||
device = torch.device("cuda")
|
||||
ckpt_path = cfg.run_dir / f"{cfg.ckpt}.safetensors"
|
||||
meta = _ckpt_meta(ckpt_path)
|
||||
run_cfg = json.loads(meta.get("cfg", "{}"))
|
||||
model_name = run_cfg.get("model", meta.get("model", "Qwen/Qwen3-4B"))
|
||||
r = run_cfg.get("lora_r", 32)
|
||||
init_seed = run_cfg.get("lora_init_seed", 0)
|
||||
logger.info(f"ckpt {ckpt_path.name} step={meta.get('step')} hack_rate={meta.get('hack_rate')} "
|
||||
f"model={model_name} r={r} init_seed={init_seed}")
|
||||
|
||||
tok = AutoTokenizer.from_pretrained(model_name)
|
||||
if tok.pad_token_id is None:
|
||||
tok.pad_token = tok.eos_token
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2").to(device)
|
||||
model.config.use_cache = False
|
||||
wrappers = wrap_model_with_lora2r(model, r=r, init_seed=init_seed, grad_probe=True)
|
||||
names = sorted(wrappers)
|
||||
sd = load_file(str(ckpt_path))
|
||||
for nm in names:
|
||||
wrappers[nm]["A"].data.copy_(sd[f"A/{nm}"].to(device, torch.float32))
|
||||
wrappers[nm]["B"].data.copy_(sd[f"B/{nm}"].to(device, torch.float32))
|
||||
logger.info(f"loaded A/B into {len(names)} modules")
|
||||
|
||||
pairs = load_pairs(cfg.pairs)
|
||||
logger.info(f"pairs {cfg.pairs} -> {len(pairs)}")
|
||||
model.eval()
|
||||
_, _, raw_grads, _ = extract_v_hack(model, tok, wrappers, pairs,
|
||||
top_k=1, tau_axis=0.0, n_heldout=2, device=device)
|
||||
v_grad = _build_v_grad(raw_grads, wrappers, 1, device)
|
||||
if cfg.random_v_seed is not None:
|
||||
v_grad = _haar_unit_dirs(v_grad, cfg.random_v_seed, device)
|
||||
logger.info(f"OVERRODE v_grad with Haar dirs seed={cfg.random_v_seed} (placebo)")
|
||||
route_band = route_band_edges(raw_grads, v_grad, device)
|
||||
|
||||
def pair_pos(side: str) -> np.ndarray:
|
||||
n = raw_grads[f"{side}/{names[0]}"].shape[0]
|
||||
out = []
|
||||
for i in range(n):
|
||||
cg = {nm: torch.cat([raw_grads[f"{side}/{nm}"][i], torch.zeros(r)]) for nm in names}
|
||||
out.append(pooled_pos(cg, v_grad, route_band, names, r)[0])
|
||||
return np.array(out)
|
||||
syn_solve, syn_hack = pair_pos("clean"), pair_pos("hack")
|
||||
|
||||
recs = [json.loads(l) for l in (cfg.run_dir / "rollouts.jsonl").read_text().splitlines()]
|
||||
batch = [r_ for r_ in recs if cfg.step_lo <= r_["step"] <= cfg.step_hi and r_["text"].strip()][:cfg.max_rollouts]
|
||||
logger.info(f"live batch: {len(batch)} rollouts (steps {cfg.step_lo}-{cfg.step_hi})")
|
||||
pos_live, labels, steps = [], [], []
|
||||
for i, rec in enumerate(batch):
|
||||
model.zero_grad(set_to_none=True)
|
||||
loss = completion_nll(model, tok, rec["prompt"], rec["text"], device)
|
||||
if not torch.isfinite(loss):
|
||||
continue
|
||||
loss.backward()
|
||||
cg = {nm: wrappers[nm]["layer"]._lora2r_gate.grad.sum(dim=tuple(range(
|
||||
wrappers[nm]["layer"]._lora2r_gate.grad.dim() - 1))) for nm in names}
|
||||
pos_live.append(pooled_pos(cg, v_grad, route_band, names, r)[0])
|
||||
labels.append(bool(rec["exploited"])); steps.append(rec["step"])
|
||||
if (i + 1) % 40 == 0:
|
||||
logger.info(f" rollout {i+1}/{len(batch)}")
|
||||
model.zero_grad(set_to_none=True)
|
||||
pos_live = np.array(pos_live); labels = np.array(labels); steps = np.array(steps)
|
||||
per_step = {int(s): round(_auroc(pos_live[steps == s].tolist(), labels[steps == s].tolist()), 2)
|
||||
for s in sorted(set(steps.tolist())) if labels[steps == s].any() and (~labels[steps == s]).any()}
|
||||
logger.info(f"live: {len(labels)} rollouts, {int(labels.sum())} exploited; per-step AUROC={per_step}")
|
||||
|
||||
# save the 4 populations long-form -> regenerates the plot with --replot (no GPU)
|
||||
df = pl.concat([
|
||||
pl.DataFrame({"pop": "on_solve", "pos": pos_live[~labels], "step": steps[~labels]}),
|
||||
pl.DataFrame({"pop": "on_hack", "pos": pos_live[labels], "step": steps[labels]}),
|
||||
pl.DataFrame({"pop": "syn_solve", "pos": syn_solve, "step": np.full(len(syn_solve), -1)}),
|
||||
pl.DataFrame({"pop": "syn_hack", "pos": syn_hack, "step": np.full(len(syn_hack), -1)}),
|
||||
])
|
||||
df.write_parquet(data_path)
|
||||
logger.info(f"wrote {data_path} ({len(df)} rows)")
|
||||
plot(df, f"{cfg.run_dir.name}\n{cfg.ckpt} v_grad, live steps {cfg.step_lo}-{cfg.step_hi}, "
|
||||
f"{int(labels.sum())}/{len(labels)} exploited", out_png)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main(tyro.cli(Cfg)))
|
||||
@@ -0,0 +1,227 @@
|
||||
"""Refresh-tracking prototype: when G_hack (v_grad) refreshes, can we recompute the
|
||||
live cosine stats from a TRACKED gradient cloud, without re-running the model?
|
||||
|
||||
Claim (k=1): the gate score pos_i = (S_i - L)/W with S_i = sum_m <u_{i,m}, v_m>, where
|
||||
u_{i,m} = g_{i,m}/|g_{i,m}| is the v-INDEPENDENT unit per-module gradient. Stacking
|
||||
U_i=[u_{i,1};..], V=[v_1;..] gives S_i = U_i . V (linear in V). So the live cosine
|
||||
distribution under ANY v is an affine push of the SAME gradient cloud:
|
||||
mean(pos) = (mu_U . V - L) / W
|
||||
var(pos) = (V^T Sigma_U V) / W^2
|
||||
mu_U, Sigma_U are tracked once (v-independent); L,W come from the pairs at refresh. So a
|
||||
refresh needs no model re-run and no window flush -- just re-project the cloud onto v_new.
|
||||
|
||||
We simulate a refresh with two direction estimates v_A, v_B = unit-mean-diff of two DISJOINT
|
||||
halves of the authored pairs (model fixed, only the direction changes), and verify:
|
||||
(1) reproject stored U onto v == direct pos (exact, validates indexing)
|
||||
(2) moment formula (mu_U, Sigma_U) == empirical mean/sd of pos, for BOTH v_A and v_B
|
||||
from the SAME tracked cloud (the actual claim).
|
||||
|
||||
uv run python scripts/diag_pinning_refresh.py
|
||||
outputs (out/diag/): pinning_refresh.png, pinning_refresh.parquet, prints a sanity table.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import struct
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import tyro
|
||||
import polars as pl
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.lines import Line2D
|
||||
from matplotlib.patches import Patch
|
||||
from loguru import logger
|
||||
from tabulate import tabulate
|
||||
from safetensors.torch import load_file
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from vgrout.lora2r import wrap_model_with_lora2r
|
||||
from vgrout.pairs import load_pairs
|
||||
from vgrout.extract_vhack_grad import extract_v_hack, completion_nll
|
||||
from vgrout.train import _build_v_grad, route_band_edges, _auroc
|
||||
from diag_pinning import _kde, SOLVE, HACK, MEANC # same dir (scripts/ is sys.path[0] when run directly)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Cfg:
|
||||
run_dir: Path = Path("out/runs/20260611T003538_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v3")
|
||||
ckpt: str = "first_hack"
|
||||
pairs: Path = Path("data/pairs/hack_pairs.md#all-in-one")
|
||||
step_lo: int = 2
|
||||
step_hi: int = 9
|
||||
max_rollouts: int = 240
|
||||
out_dir: Path = Path("out/diag")
|
||||
|
||||
|
||||
def _ckpt_meta(path: Path) -> dict:
|
||||
with open(path, "rb") as f:
|
||||
return json.loads(f.read(struct.unpack("<Q", f.read(8))[0])).get("__metadata__", {})
|
||||
|
||||
|
||||
def stacked_v_and_band(raw_grads: dict, idx, wrappers, names, r, device):
|
||||
"""Direction estimate + band from a SUBSET of pairs (simulated refresh).
|
||||
|
||||
Returns V_stacked [n_mod*r] (zeroed on modules the band excludes), and scalars L,W.
|
||||
Matches pooled_pos: a module enters iff its band width > 0; excluded modules contribute 0."""
|
||||
sub = {k: v[idx] for k, v in raw_grads.items()}
|
||||
v_grad = _build_v_grad(sub, wrappers, 1, device) # {name: [1, r]} unit
|
||||
band = route_band_edges(sub, v_grad, device) # {name: (lower, upper)}
|
||||
V, L, W = [], 0.0, 0.0
|
||||
for name in names:
|
||||
lower, upper = band[name]
|
||||
if upper - lower > 0:
|
||||
V.append(v_grad[name][0].float().cpu()) # [r]
|
||||
L += lower; W += (upper - lower)
|
||||
else:
|
||||
V.append(torch.zeros(r)) # excluded -> no contribution
|
||||
return torch.cat(V).numpy(), float(L), float(W), v_grad, band
|
||||
|
||||
|
||||
def pos_from_U(U: np.ndarray, V: np.ndarray, L: float, W: float) -> np.ndarray:
|
||||
"""pos = (U . V - L) / W -- re-project the stored cloud onto direction V (no model)."""
|
||||
return (U @ V - L) / W
|
||||
|
||||
|
||||
def main(cfg: Cfg) -> int:
|
||||
cfg.out_dir.mkdir(parents=True, exist_ok=True)
|
||||
device = torch.device("cuda")
|
||||
ckpt_path = cfg.run_dir / f"{cfg.ckpt}.safetensors"
|
||||
meta = _ckpt_meta(ckpt_path)
|
||||
run_cfg = json.loads(meta.get("cfg", "{}"))
|
||||
model_name = run_cfg.get("model", "Qwen/Qwen3-4B")
|
||||
r = run_cfg.get("lora_r", 32)
|
||||
init_seed = run_cfg.get("lora_init_seed", 0)
|
||||
logger.info(f"ckpt {ckpt_path.name} step={meta.get('step')} model={model_name} r={r}")
|
||||
|
||||
tok = AutoTokenizer.from_pretrained(model_name)
|
||||
if tok.pad_token_id is None:
|
||||
tok.pad_token = tok.eos_token
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2").to(device)
|
||||
model.config.use_cache = False
|
||||
wrappers = wrap_model_with_lora2r(model, r=r, init_seed=init_seed, grad_probe=True)
|
||||
names = sorted(wrappers)
|
||||
sd = load_file(str(ckpt_path))
|
||||
for nm in names:
|
||||
wrappers[nm]["A"].data.copy_(sd[f"A/{nm}"].to(device, torch.float32))
|
||||
wrappers[nm]["B"].data.copy_(sd[f"B/{nm}"].to(device, torch.float32))
|
||||
|
||||
pairs = load_pairs(cfg.pairs)
|
||||
model.eval()
|
||||
_, _, raw_grads, _ = extract_v_hack(model, tok, wrappers, pairs,
|
||||
top_k=1, tau_axis=0.0, n_heldout=2, device=device)
|
||||
n_pairs = raw_grads[f"hack/{names[0]}"].shape[0]
|
||||
half = n_pairs // 2
|
||||
idx_A, idx_B = list(range(half)), list(range(half, n_pairs)) # disjoint pair halves = two v estimates
|
||||
V_A, L_A, W_A, vA, _ = stacked_v_and_band(raw_grads, idx_A, wrappers, names, r, device)
|
||||
V_B, L_B, W_B, vB, _ = stacked_v_and_band(raw_grads, idx_B, wrappers, names, r, device)
|
||||
V_A, V_B = V_A.astype(np.float64), V_B.astype(np.float64) # float64 so the sanity gate reflects math, not rounding
|
||||
cosAB = float(np.dot(V_A, V_B) / (np.linalg.norm(V_A) * np.linalg.norm(V_B) + 1e-12))
|
||||
logger.info(f"two direction estimates from pair halves ({half} / {n_pairs-half}); "
|
||||
f"cos(V_A,V_B)={cosAB:+.3f} (a real refresh shift)")
|
||||
|
||||
# ── score live batch ONCE: store U_i = stacked unit per-module gradients (v-INDEPENDENT) ──
|
||||
recs = [json.loads(l) for l in (cfg.run_dir / "rollouts.jsonl").read_text().splitlines()]
|
||||
batch = [x for x in recs if cfg.step_lo <= x["step"] <= cfg.step_hi and x["text"].strip()][:cfg.max_rollouts]
|
||||
logger.info(f"live batch: {len(batch)} rollouts (steps {cfg.step_lo}-{cfg.step_hi})")
|
||||
U_rows, labels = [], []
|
||||
for i, rec in enumerate(batch):
|
||||
model.zero_grad(set_to_none=True)
|
||||
loss = completion_nll(model, tok, rec["prompt"], rec["text"], device)
|
||||
if not torch.isfinite(loss):
|
||||
continue
|
||||
loss.backward()
|
||||
u_blocks = []
|
||||
for name in names:
|
||||
g = wrappers[name]["layer"]._lora2r_gate.grad
|
||||
g_b = g.sum(dim=tuple(range(g.dim() - 1)))[:r].float() # deployed block [r]
|
||||
u = (g_b / g_b.norm().clamp_min(1e-12)).cpu() # v-independent unit grad
|
||||
u_blocks.append(u)
|
||||
U_rows.append(torch.cat(u_blocks).numpy())
|
||||
labels.append(bool(rec["exploited"]))
|
||||
if (i + 1) % 40 == 0:
|
||||
logger.info(f" rollout {i+1}/{len(batch)}")
|
||||
model.zero_grad(set_to_none=True)
|
||||
U = np.stack(U_rows).astype(np.float64) # [n_roll, n_mod*r] the tracked gradient cloud
|
||||
labels = np.array(labels)
|
||||
logger.info(f"stored U cloud: {U.shape} ({U.nbytes/1e6:.1f} MB), {int(labels.sum())} exploited")
|
||||
|
||||
# ── pos under each direction by RE-PROJECTING the stored cloud (no model re-run) ──
|
||||
pos_A = pos_from_U(U, V_A, L_A, W_A)
|
||||
pos_B = pos_from_U(U, V_B, L_B, W_B)
|
||||
|
||||
# ── moment formula from the tracked cloud: mu_U, Sigma_U (v-independent) ──
|
||||
mu_U = U.mean(0) # [D]
|
||||
Uc = U - mu_U
|
||||
Sigma_U = (Uc.T @ Uc) / U.shape[0] # [D, D] uncentered-then-centered covariance
|
||||
def moment_stats(V, L, W):
|
||||
mean = (mu_U @ V - L) / W
|
||||
var = (V @ (Sigma_U @ V)) / (W * W)
|
||||
return mean, np.sqrt(max(var, 0.0))
|
||||
mA_emp, sA_emp = pos_A.mean(), pos_A.std()
|
||||
mB_emp, sB_emp = pos_B.mean(), pos_B.std()
|
||||
mA_mom, sA_mom = moment_stats(V_A, L_A, W_A)
|
||||
mB_mom, sB_mom = moment_stats(V_B, L_B, W_B)
|
||||
|
||||
# ── sanity table ──
|
||||
rows = [
|
||||
["mean pos | v_A", f"{mA_emp:+.4f}", f"{mA_mom:+.4f}", f"{abs(mA_emp-mA_mom):.2e}"],
|
||||
["std pos | v_A", f"{sA_emp:.4f}", f"{sA_mom:.4f}", f"{abs(sA_emp-sA_mom):.2e}"],
|
||||
["mean pos | v_B", f"{mB_emp:+.4f}", f"{mB_mom:+.4f}", f"{abs(mB_emp-mB_mom):.2e}"],
|
||||
["std pos | v_B", f"{sB_emp:.4f}", f"{sB_mom:.4f}", f"{abs(sB_emp-sB_mom):.2e}"],
|
||||
]
|
||||
print("\nSHOULD: moment-formula (from tracked mu_U, Sigma_U) == empirical, both v_A and v_B.")
|
||||
print(tabulate(rows, headers=["quantity", "empirical (direct)", "moment formula", "abs diff"],
|
||||
tablefmt="github"))
|
||||
max_diff = max(abs(mA_emp-mA_mom), abs(sA_emp-sA_mom), abs(mB_emp-mB_mom), abs(sB_emp-sB_mom))
|
||||
ok = max_diff < 1e-5
|
||||
print(f"\n{'PASS' if ok else 'FAIL'}: max |empirical - moment| = {max_diff:.2e} (refresh needs no model re-run)")
|
||||
logger.info(f"refresh shift: mean {mA_emp:+.3f}->{mB_emp:+.3f}, std {sA_emp:.3f}->{sB_emp:.3f}; "
|
||||
f"AUROC v_A={_auroc(pos_A.tolist(), labels.tolist()):.3f} v_B={_auroc(pos_B.tolist(), labels.tolist()):.3f}")
|
||||
|
||||
# ── plot: pos under v_A (top) and v_B (bottom), mean +/- 2sd band recalibrates per direction ──
|
||||
pl.DataFrame({"pos_A": pos_A, "pos_B": pos_B, "exploited": labels.tolist()}).write_parquet(
|
||||
cfg.out_dir / "pinning_refresh.parquet")
|
||||
lo = min(pos_A.min(), pos_B.min()) - 0.1
|
||||
hi = max(pos_A.max(), pos_B.max()) + 0.1
|
||||
grid = np.linspace(lo, hi, 400)
|
||||
fig, axes = plt.subplots(2, 1, figsize=(8.6, 6.0), sharex=True)
|
||||
for ax, pos, mean, std, tag in [(axes[0], pos_A, mA_emp, sA_emp, "A (pairs 1st half)"),
|
||||
(axes[1], pos_B, mB_emp, sB_emp, "B (pairs 2nd half)")]:
|
||||
lo_b, hi_b = mean - 2 * std, mean + 2 * std
|
||||
ax.axvspan(lo_b, hi_b, color=MEANC, alpha=0.07, lw=0)
|
||||
ax.axvline(mean, color=MEANC, lw=1.8)
|
||||
for b in (lo_b, hi_b):
|
||||
ax.axvline(b, color=MEANC, lw=1.2, ls="--")
|
||||
peak = 0.0
|
||||
for x, col in [(pos[~labels], SOLVE), (pos[labels], HACK)]:
|
||||
y = _kde(x, grid)
|
||||
ax.fill_between(grid, y, color=col, alpha=0.13, lw=0)
|
||||
ax.plot(grid, y, color=col, lw=1.9)
|
||||
peak = max(peak, y.max())
|
||||
ax.set_ylim(0, peak * 1.15)
|
||||
ax.set_ylabel("density")
|
||||
ax.set_title(f"direction estimate {tag}: mean={mean:+.2f}, sd={std:.2f}, "
|
||||
f"AUROC={_auroc(pos.tolist(), labels.tolist()):.2f}", fontsize=9, loc="left")
|
||||
for s in ("top", "right"):
|
||||
ax.spines[s].set_visible(False)
|
||||
dist = [Line2D([0], [0], color=SOLVE, lw=2), Line2D([0], [0], color=HACK, lw=2),
|
||||
Line2D([0], [0], color=MEANC, lw=1.8), Patch(facecolor=MEANC, alpha=0.18, ls="--", edgecolor=MEANC)]
|
||||
axes[0].legend(dist, ["on-policy solve", "on-policy hack", "online mean", "mean +/- 2sd"],
|
||||
loc="upper left", fontsize=8, frameon=False)
|
||||
axes[1].set_xlabel("hacking direction (gradient cosine to v_grad) " + r"$\longrightarrow$")
|
||||
fig.suptitle("refresh tracking: same gradient cloud, two G_hack estimates "
|
||||
f"(cos(v_A,v_B)={cosAB:+.2f}) -- band recalibrates with no model re-run", fontsize=9.5)
|
||||
fig.tight_layout(rect=(0, 0, 1, 0.96))
|
||||
fig.savefig(cfg.out_dir / "pinning_refresh.png", dpi=140)
|
||||
plt.close(fig)
|
||||
logger.info(f"wrote {cfg.out_dir}/pinning_refresh.png, pinning_refresh.parquet")
|
||||
return 0 if ok else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main(tyro.cli(Cfg)))
|
||||
+15
-16
@@ -72,8 +72,8 @@ class StepLogger:
|
||||
|
||||
def __init__(self, arm: str, modes: list[str], mode_code: dict[str, str],
|
||||
show_ablate: bool = False) -> None:
|
||||
# routeV reports routing diagnostics; absorb shares qmass (zone cols read nan).
|
||||
is_route = arm in ("routingV_lora2r", "absorb_lora2r")
|
||||
# Routing diagnostics are ALWAYS shown (nan on vanilla, whose gate never runs) so the
|
||||
# column layout is identical across arms -- vanilla/routeV/absorb tables line up.
|
||||
cols: list[_Col] = [
|
||||
_Col("step", 4, "step", "d", "GRPO step"),
|
||||
_Col("ref_eq", 6, "ref_eq", ".2f", "vanilla-equiv step (cum_gens/256)"),
|
||||
@@ -99,21 +99,20 @@ class StepLogger:
|
||||
_Col("gn", 7, "gn", ".1e", "pre-clip L2 norm of A/B grads (vs grad_clip)"),
|
||||
_Col("lr", 7, "lr", ".1e", "scheduled learning rate"),
|
||||
]
|
||||
# routeV reports unit and energy shares across the routing band.
|
||||
if is_route:
|
||||
cols += [
|
||||
_Col("auroc", 6, "auroc", ".2f", "AUROC of pooled cos(g,v_grad) as a hack detector vs the hack-label (student exploited + teacher cached); MEASUREMENT only, never routes. ~0.5 = v_grad blind to live hacks (no threshold helps); high but rout~0 = pure threshold/scale problem; a drop at a refresh = refresh destroyed separation"),
|
||||
_Col("cosU", 6, "cosU", "+.2f", "pooled cos(v_grad, summed-rollout c-grad): is the net update moving hack-ward this step"),
|
||||
_Col("qmass", 6, "qmass", ".2f", "quarantine energy share ||g_quar||/(||g_keep||+||g_quar||): fraction of the update landing in the throwaway quarantine block"),
|
||||
_Col("keep", 6, "keep", ".2f", "rollout share labelled clean (below band) -> deployed-only, quarantine off"),
|
||||
_Col("resid", 6, "resid", ".2f", "rollout share labelled mid (inside band) -> both blocks train (absorption)"),
|
||||
_Col("rout", 6, "rout", ".2f", "rollout share labelled hack (above band) -> quarantine-only, deployed detached"),
|
||||
_Col("keepE", 6, "keepE", ".2f", "energy-weighted keep (grad-norm-weighted clean share)"),
|
||||
_Col("residE", 6, "residE", ".2f", "energy-weighted resid"),
|
||||
_Col("routE", 6, "routE", ".2f", "energy-weighted rout"),
|
||||
]
|
||||
# routeV reports unit and energy shares across the routing band (nan on vanilla/absorb).
|
||||
cols += [
|
||||
_Col("auroc", 6, "auroc", ".2f", "AUROC of pooled cos(g,v_grad) as a hack detector vs the hack-label (student exploited + teacher cached); MEASUREMENT only, never routes. ~0.5 = v_grad blind to live hacks (no threshold helps); high but rout~0 = pure threshold/scale problem; a drop at a refresh = refresh destroyed separation"),
|
||||
_Col("cosU", 6, "cosU", "+.2f", "pooled cos(v_grad, summed-rollout c-grad): is the net update moving hack-ward this step"),
|
||||
_Col("qmass", 6, "qmass", ".2f", "quarantine energy share ||g_quar||/(||g_keep||+||g_quar||): fraction of the update landing in the throwaway quarantine block"),
|
||||
_Col("keep", 6, "keep", ".2f", "rollout share labelled clean (below band) -> deployed-only, quarantine off"),
|
||||
_Col("resid", 6, "resid", ".2f", "rollout share labelled mid (inside band) -> both blocks train (absorption)"),
|
||||
_Col("rout", 6, "rout", ".2f", "rollout share labelled hack (above band) -> quarantine-only, deployed detached"),
|
||||
_Col("keepE", 6, "keepE", ".2f", "energy-weighted keep (grad-norm-weighted clean share)"),
|
||||
_Col("residE", 6, "residE", ".2f", "energy-weighted resid"),
|
||||
_Col("routE", 6, "routE", ".2f", "energy-weighted rout"),
|
||||
]
|
||||
# Show the training-prompt deploy proxy only when an ablated slice exists.
|
||||
if is_route and show_ablate:
|
||||
if show_ablate:
|
||||
cols += [
|
||||
_Col("hack_abl", 6, "hk_abl", "frac", "per-step deploy proxy: hack rate on the ablated (deploy-mode) rollout slice; train prompts, noisier than hk_dep"),
|
||||
_Col("solve_abl", 6, "slv_abl", "frac", "per-step deploy proxy: solve rate on the ablated (deploy-mode) rollout slice; train prompts"),
|
||||
|
||||
+14
-13
@@ -602,22 +602,23 @@ def main(cfg: Config) -> int:
|
||||
if n_inc == 0:
|
||||
raise RuntimeError("no module has positive band width; pairs separate nowhere")
|
||||
pos = num / den; w /= n_inc
|
||||
# ── online-stats gate (three zones, keep = bulk) ── The authored absolute thresholds
|
||||
# (clean<=0, hack>=1) don't work: live hacks sit far below the authored p75-hack pooled
|
||||
# edge, so rout~0. Route by live QUANTILES of pos instead: top route_quantile tail ->
|
||||
# hack (rout), the next route_quantile band -> mid (absorption), everything below ->
|
||||
# keep. keep is the BULK; the previous t_lo=quantile(route_quantile) put keep at the
|
||||
# bottom route_quantile and routed ~(1-route_quantile) into mid, so the quarantine
|
||||
# trained on almost every rollout (qmass~0.5). v_grad stays authored-only; only the
|
||||
# thresholds follow the live distribution. The window includes this batch, so step 0
|
||||
# self-calibrates. TODO(#30): calibrate the rout/mid widths from the batch distribution.
|
||||
# ── online-stats gate (#30): mean + k*std, three zones, keep = bulk ── The authored
|
||||
# absolute band is mis-placed (live pos sits far below the synthetic-hack edge; even
|
||||
# synthetic SOLVE is more hack-aligned than on-policy hack -- out/diag/pinning_calib.png),
|
||||
# and a fixed quantile FORCES route_quantile out every step even when nothing separates.
|
||||
# Route by the live MEAN + k*STD instead: pos > mean+route_std_mid*std -> mid, pos >=
|
||||
# mean+route_std_rout*std -> rout, below -> keep (bulk). Self-silencing: only the tail
|
||||
# that genuinely exceeds the spread routes, so qmass tracks real separation rather than a
|
||||
# forced fraction. v_grad stays authored-only; the threshold follows the live distribution.
|
||||
# The window includes this batch, so step 0 self-calibrates; flushed on v_grad refresh.
|
||||
route_pos_window.extend(pos.detach().cpu().tolist())
|
||||
ref = torch.tensor(list(route_pos_window))
|
||||
t_hi = ref.quantile(1.0 - cfg.route_quantile).item() # top route_quantile -> rout
|
||||
t_lo = ref.quantile(1.0 - 2.0 * cfg.route_quantile).item() # next band -> mid; below -> keep
|
||||
mu_pos, sd_pos = ref.mean().item(), ref.std().item()
|
||||
t_lo = mu_pos + cfg.route_std_mid * sd_pos # mid onset
|
||||
t_hi = mu_pos + cfg.route_std_rout * sd_pos # rout onset (>= mid)
|
||||
logger.debug(f"gate pos: n_inc={n_inc}/{len(wrappers)} min={pos.min().item():+.2f} "
|
||||
f"max={pos.max().item():+.2f} | online t_lo={t_lo:+.2f} t_hi={t_hi:+.2f} "
|
||||
f"win={len(route_pos_window)}")
|
||||
f"max={pos.max().item():+.2f} | mean={mu_pos:+.2f} std={sd_pos:.2f} "
|
||||
f"t_lo={t_lo:+.2f} t_hi={t_hi:+.2f} win={len(route_pos_window)}")
|
||||
m = (pos > t_lo).float() # mid + rout -> quarantine trains (keep = bulk below t_lo)
|
||||
d = (pos >= t_hi).float() # top tail -> hack -> deployed detached
|
||||
return m, d, 0.5 * m + 0.5 * d, w, pos, cosU_sum / n_inc
|
||||
|
||||
@@ -49,13 +49,18 @@ class Config:
|
||||
# (alignment to ANY known hack sub-mode) -- catches multi-modal hack signal one mean
|
||||
# washes out. k=1 stays mean-diff (not SVD-top-1) so "mean-mass vs top-k" is a clean A/B.
|
||||
v_grad_k: int = 1
|
||||
# Online-stats gate: route by live QUANTILES of the pooled cosine-to-v_grad, not the
|
||||
# authored absolute band (whose p75-hack edge live hacks never reach -> rout~0). Each
|
||||
# step the top route_quantile tail -> hack (deployed detached), bottom -> keep, middle
|
||||
# -> mid. route_window = sliding buffer of recent pooled positions, flushed on refresh
|
||||
# (positions are measured against one v_grad). Direction stays authored-only; only the
|
||||
# threshold follows the live distribution. TODO(#30): center+width calibration vs fixed %.
|
||||
route_quantile: float = 0.05
|
||||
# Online-stats gate (#30): route by the live MEAN + k*STD of the pooled cosine-to-v_grad,
|
||||
# not a fixed quantile. The authored absolute band is mis-placed (live pos sits far below
|
||||
# the synthetic-hack edge; even synthetic SOLVE is more hack-aligned than on-policy hack --
|
||||
# see out/diag/pinning_calib.png), and a fixed quantile FORCES route_quantile of the batch
|
||||
# out every step even when nothing separates. mean+k*std self-silences: it only routes the
|
||||
# tail that genuinely exceeds the spread. pos > mean + route_std_mid*std -> mid (absorption);
|
||||
# pos >= mean + route_std_rout*std -> rout (hack, deployed detached); below -> keep (bulk).
|
||||
# route_window = sliding buffer of recent pooled positions, flushed on v_grad refresh
|
||||
# (positions measured against one v_grad). Direction stays authored-only; only the threshold
|
||||
# follows the live distribution.
|
||||
route_std_mid: float = 2.0
|
||||
route_std_rout: float = 3.0
|
||||
route_window: int = 512
|
||||
# Haar-random direction control (placebo): same routing machinery, no pair signal.
|
||||
routeV_random_v_seed: int | None = None
|
||||
|
||||
Reference in New Issue
Block a user