mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 17:15:58 +08:00
dae52b2a7d
- Delete src/vgrout/pairs_v2.py and src/vgrout/pairs_intent.py; move all data into scripts/pairset_build_intent.py (self-contained, exports 3 JSONs). - Export: pairs_intent_think.json (6), pairs_intent_funcname.json (6), pairs_intent_concept.json (6 diagnostic). - Update diag_cosine_dist.py and diag_pairs_compare.py to load from JSON instead of importing Python modules; drop tainted v2/allv2 pairsets from the diag sweep (print-without-assert axis). - train.py final table: add solve_rate_s computed same as hack_rate_s, so the per-run end-of-training table shows actual training solve rate (was "-"). Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
149 lines
6.7 KiB
Python
149 lines
6.7 KiB
Python
"""Pairs comparison: which PAIR-SET gives the best high-precision hack direction?
|
|
|
|
Companion to diag_cosine_dist.py. That script sweeps SCORE x SPACE x FILTER for ONE
|
|
pair-set; this one fixes the winning config (grad cosine, the p@10=0.70 corner) and
|
|
sweeps the PAIR-SET axis -- authored vs pool-derived (prog_wide family) vs intent.
|
|
This is the table I should have built before committing a GPU run to a pairs guess.
|
|
|
|
The expensive part (140 live-rollout backward passes) is independent of the pairs, so
|
|
we cache the per-module live grads ONCE and loop the cheap v_grad build + scoring over
|
|
pair-sets. No oracle labels touch training; `exploited` is an offline plot-only label.
|
|
|
|
uv run python -m scripts.diag_pairs_compare # 4B, free GPU, ~20 min
|
|
outputs: out/diag/pairs_compare.csv (pairset x AUROC/p@10/p@20 at grad cosine, plus p@10 after the keep75 module filter).
|
|
"""
|
|
from __future__ import annotations
|
|
import json
|
|
import struct
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import tyro
|
|
import polars as pl
|
|
from loguru import logger
|
|
from tabulate import tabulate
|
|
from safetensors.torch import load_file
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
from vgrout.antipasto import wrap_model_with_antipasto
|
|
from vgrout.extract_vhack_grad import extract_v_hack, completion_nll
|
|
from vgrout.pairs_from_pool import load_pairs_json
|
|
from vgrout.train import CACHE_ROOT
|
|
from scripts.diag_cosine_dist import _auroc, _prec_at_k
|
|
|
|
_PS = Path("out/pairsets")

# pair-set name -> JSON file under _PS (insertion order is the sweep order)
_PAIRSET_FILES = {
    "authored_all": "pairs_authored.json",
    "funcname": "pairs_intent_funcname.json",
    "think": "pairs_intent_think.json",
    "prog_wide": "prog_wide.json",
    "prog_wide_clean": "prog_wide_clean.json",
    "prog_wider": "prog_wider.json",
    "prog_widest": "prog_widest.json",
    "heldout_known_rt": "heldout_known_runtests.json",
}


def _pairset_loader(filename: str):
    """Return a zero-arg callable that loads `_PS / filename` only when invoked.

    Loading is deferred so a missing file only matters for the pair-set that uses it.
    """
    return lambda: load_pairs_json(_PS / filename)


PAIRSETS = {name: _pairset_loader(fname) for name, fname in _PAIRSET_FILES.items()}
|
|
|
|
|
|
@dataclass
class Cfg:
    """CLI options for the pair-set comparison (parsed with tyro in `__main__`)."""

    # run directory holding the checkpoint files and rollouts.jsonl
    run_dir: Path = Path("out") / "runs" / "20260607T134234_fast_routingV_seed43_dir6_routeV_pertoken_s43"
    # checkpoint stem: `{ckpt}.safetensors` and `{ckpt}_hack.safetensors` are loaded
    ckpt: str = "first_hack"
    # inclusive window of rollout steps used as the live set
    step_lo: int = 5
    step_hi: int = 9
    # cap on live rollouts (each one costs a backward pass)
    max_rollouts: int = 140
    # destination directory for pairs_compare.csv
    out_dir: Path = Path("out") / "diag"
|
|
|
|
|
|
def _read_safetensors_metadata(path: Path) -> dict:
    """Return the `__metadata__` dict of a safetensors file ({} if absent), reading only the header.

    The file starts with an 8-byte little-endian header length followed by that many bytes of JSON.
    """
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        return json.loads(f.read(header_len)).get("__metadata__", {})


def _load_live_rollouts(cfg) -> list[dict]:
    """Read `rollouts.jsonl` and keep records with step in [step_lo, step_hi] and a non-blank text.

    Blank JSONL lines are skipped (json.loads("") would raise). At most `cfg.max_rollouts`
    records are returned, in file order.
    """
    lines = (cfg.run_dir / "rollouts.jsonl").read_text().splitlines()
    recs = [json.loads(line) for line in lines if line.strip()]
    batch = [r for r in recs if cfg.step_lo <= r["step"] <= cfg.step_hi and r["text"].strip()]
    return batch[:cfg.max_rollouts]


def _load_model(cfg, device):
    """Load tokenizer + bf16 model, wrap it with antipasto, and copy in the checkpoint's adapters.

    Returns (model, tok, wrappers, names), where `names` is the sorted list of wrapper keys.
    """
    kept_path = cfg.run_dir / f"{cfg.ckpt}.safetensors"
    hack_path = cfg.run_dir / f"{cfg.ckpt}_hack.safetensors"
    meta = _read_safetensors_metadata(kept_path)
    model_name = meta.get("model", "Qwen/Qwen3-4B")
    logger.info(f"ckpt {kept_path.name} step={meta.get('step')} model={model_name}")

    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2").to(device)
    model.config.use_cache = False
    wrappers = wrap_model_with_antipasto(model, model_name, CACHE_ROOT, device, grad_probe=False)
    names = sorted(wrappers)
    kept, hack = load_file(str(kept_path)), load_file(str(hack_path))
    for nm in names:
        wrappers[nm]["delta_S"].data.copy_(kept[nm].to(device))
        wrappers[nm]["delta_S_hack"].data.copy_(hack[nm].to(device))
    model.eval()
    return model, tok, wrappers, names


def _zero_grads(model, wrappers, names) -> None:
    """Clear model grads and, explicitly, every wrapper's `delta_S` grad.

    NOTE(review): whether `delta_S` is registered as a parameter of `model` is not visible here;
    clearing it directly guarantees per-rollout grads never accumulate either way.
    """
    model.zero_grad(set_to_none=True)
    for nm in names:
        wrappers[nm]["delta_S"].grad = None


def _cache_live_grads(model, tok, wrappers, names, batch, device):
    """Run one backward pass per rollout and cache the per-module `delta_S` grads on CPU.

    Rollouts whose completion NLL is non-finite are dropped. Returns
    (gb_cache: list of {module: flat grad}, gn_cache: list of [n_mod] grad-norm tensors,
    labels: list of bool `exploited` flags -- an offline plot-only label).
    """
    gb_cache, gn_cache, labels = [], [], []
    for i, r in enumerate(batch):
        _zero_grads(model, wrappers, names)
        loss = completion_nll(model, tok, r["prompt"], r["text"], device)
        if not torch.isfinite(loss):
            continue
        loss.backward()
        gb = {nm: wrappers[nm]["delta_S"].grad.flatten().float().cpu() for nm in names}
        gb_cache.append(gb)
        gn_cache.append(torch.tensor([gb[nm].norm().item() for nm in names]))  # |g_m|, pairs-independent
        labels.append(bool(r["exploited"]))
        if (i + 1) % 40 == 0:
            logger.info(f"  cached {i+1}/{len(batch)} live grads")
    _zero_grads(model, wrappers, names)
    return gb_cache, gn_cache, labels


def _score_pairset(model, tok, wrappers, names, pairs, gb_cache, gn_stack, device):
    """Grad-cosine score of each cached live rollout against one pair-set's hack direction.

    v_grad[m] is the unit-normalised mean (hack - clean) grad for module m. A rollout's score is
    sum_m <g_m, v_m> / (|g| * sqrt(n_mod)): the cosine between the concatenated live grad and the
    concatenated per-module unit directions (whose norm is sqrt(n_mod)).
    Returns (scores over all modules, scores over "keep75" modules, i.e. those whose v_sv[m][0]
    is at or above the 25th percentile across modules).
    """
    _, v_sv, raw_grads, _ = extract_v_hack(
        model, tok, wrappers, pairs, top_k=1, tau_axis=0.0, n_heldout=2, device=device)
    v_grad = {}
    for nm in names:
        diff = (raw_grads[f"hack/{nm}"] - raw_grads[f"clean/{nm}"]).mean(0).flatten().float().cpu()
        v_grad[nm] = diff / diff.norm().clamp_min(1e-12)
    sv0 = torch.tensor([v_sv[nm][0].item() for nm in names])
    # both masks are rollout-independent, so build them once
    masks = (torch.ones(len(names), dtype=torch.bool), sv0 >= sv0.quantile(0.25))
    cos_all, cos_k75 = [], []
    for gb, gn in zip(gb_cache, gn_stack):
        dot = torch.tensor([(gb[nm] @ v_grad[nm]).item() for nm in names])
        for mask, out in zip(masks, (cos_all, cos_k75)):
            d = dot[mask].sum()
            gmag = gn[mask].pow(2).sum().sqrt()
            nmod = int(mask.sum())
            out.append((d / (gmag.clamp_min(1e-12) * nmod ** 0.5)).item())
    return cos_all, cos_k75


def main(cfg: Cfg) -> int:
    """Rank every loadable pair-set in PAIRSETS by grad-cosine precision on live rollouts.

    Live per-module grads are computed once; each pair-set then only costs an `extract_v_hack`
    call plus cheap dot products. Writes `cfg.out_dir / "pairs_compare.csv"` and prints a table.

    Returns 0 on success.
    Raises:
        ValueError: no usable live rollouts (checked before the model is loaded, and again
            after dropping non-finite-NLL rollouts).
        FileNotFoundError: none of the PAIRSETS files could be found.
    """
    device = torch.device("cuda")

    # Cheap precondition first: don't spend minutes loading a 4B model for an empty window.
    batch = _load_live_rollouts(cfg)
    if not batch:
        raise ValueError(
            f"no non-empty rollouts with step in [{cfg.step_lo}, {cfg.step_hi}] "
            f"in {cfg.run_dir / 'rollouts.jsonl'}")

    model, tok, wrappers, names = _load_model(cfg, device)

    # ── cache live-rollout per-module grads ONCE (independent of pair-set) ──
    gb_cache, gn_cache, labels = _cache_live_grads(model, tok, wrappers, names, batch, device)
    if not labels:
        raise ValueError(f"all {len(batch)} live rollouts had a non-finite completion NLL; nothing to score")
    gn_stack = torch.stack(gn_cache)  # [n_roll, n_mod]
    n_pos = sum(labels)
    base_rate = n_pos / len(labels)
    logger.info(f"live: {len(labels)} rollouts, {n_pos} exploited ({base_rate:.2f} base rate)")

    # ── score each pair-set at grad cosine (all + keep75 noise filter) ──
    rows = []
    for name, fn in PAIRSETS.items():
        try:
            pairs = fn()
        except FileNotFoundError:
            logger.warning(f"{name}: pairset file missing, skipping")
            continue
        model.eval()
        ca, ck = _score_pairset(model, tok, wrappers, names, pairs, gb_cache, gn_stack, device)
        row = dict(pairset=name, n=len(pairs),
                   AUROC=round(_auroc(ca, labels), 3),
                   p10=round(_prec_at_k(ca, labels, 10), 3),
                   p20=round(_prec_at_k(ca, labels, 20), 3),
                   p10_keep75=round(_prec_at_k(ck, labels, 10), 3))
        rows.append(row)
        logger.info(f"{name} (n={len(pairs)}): grad-cosine p10={row['p10']} AUROC={row['AUROC']}")

    if not rows:
        # an empty frame has no p10/AUROC columns, so .sort() below would fail obscurely
        raise FileNotFoundError(f"none of the {len(PAIRSETS)} pair-set files were found under {_PS}")

    df = pl.DataFrame(rows).sort(["p10", "AUROC"], descending=True)
    cfg.out_dir.mkdir(parents=True, exist_ok=True)
    df.write_csv(cfg.out_dir / "pairs_compare.csv")
    print("\n=== pair-sets ranked by grad-cosine precision@10 (base rate "
          f"{base_rate:.2f}, {len(labels)} live rollouts) ===")
    # rows + column names directly: same table, no pandas/pyarrow needed
    print(tabulate(df.rows(), headers=df.columns, tablefmt="pipe", showindex=False))
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # tyro builds the CLI from Cfg's fields; main's return value becomes the process exit code
    _cli_cfg = tyro.cli(Cfg)
    raise SystemExit(main(_cli_cfg))
|