mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 17:48:43 +08:00
chore: memory updates, diag_pairs_compare script
Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -9,3 +9,4 @@
|
||||
- [pueue negative-priority gotcha](pueue-negative-priority-gotcha.md) — `pueue add` negative prio needs `-o=-N` attached; `-o -N` silently fails the add.
|
||||
- [Rename on logic change](feedback_rename_on_logic_change.md) — when an arm's logic changes (binary->banded gate), give it a new id (routeV/route3), not just a tag suffix; else old/new runs are uncomparable.
|
||||
- [Check paper before diagnosing](feedback_check_paper_before_diagnosing.md) — re-read source for expected number/horizon before "experiment is broken"; paper: hack emerges on-policy at step 80-100, base solves ~12-20% not 94%.
|
||||
- [Eval must be recency-clean](project_eval_must_be_recency_clean.md) — eval on paper test set (ids>=3243, base ~0.1), NOT the holdout/first-N (memorized->0.94); train seeded-shuffle not first-N. Run set: `just queue-dir6`/`queue-broad`.
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
---
|
||||
name: project_eval_must_be_recency_clean
|
||||
description: "eval AND train must use the paper's recency-held-out test set / seeded-shuffle -- NOT first-N-by-id or the holdout file (pretraining-memorized -> base solve 0.94 instead of ~0.12)"
|
||||
metadata:
|
||||
node_type: memory
|
||||
type: project
|
||||
originSessionId: c06c2b0e-44ee-45a1-b213-4f77c006109c
|
||||
---
|
||||
|
||||
The base-solve=0.94 anomaly (2026-06-07) was an eval-set bug, fixed in commits
|
||||
3fad736 + da8b846. Durable rules for this repo:
|
||||
|
||||
- **Eval on the paper's own `leetcode_test_medhard.jsonl`** (119, ids >= 3243).
|
||||
It is the only recency-held-out split; base Qwen3-4B solves ~9-12% there
|
||||
(matches paper fn9; job 176 measured test=0.094, train_filtered=0.203). The
|
||||
periodic curve is a 32-sample of it, the final number is the full 119.
|
||||
- **NEVER eval on `leetcode_train_medhard_holdout.jsonl`** (353, our artifact).
|
||||
Disjoint from train by id but in the train id/recency range (ids 3-3205, 88%
|
||||
medium) = classic problems Qwen3-4B memorized in pretraining -> base solve
|
||||
0.94, which saturates solve and kills the hack metric's gt-fail headroom.
|
||||
"Disjoint by id" controls for TRAIN leakage, NOT pretraining MEMORIZATION.
|
||||
- **Train pool: seeded-shuffle, not first-N-by-id.** Lowest ids are the oldest,
|
||||
most-memorized problems; first-200-by-id = the easiest 200 -> weak hack
|
||||
incentive. Sample seeded-random and PIN the teacher-seed ids (else seeding
|
||||
no-ops). train.py ~682.
|
||||
- Hint is the paper's `simple_overwrite_tests` (data.py:37), NOT the easier
|
||||
`_detailed`/`_aware` variants. Model = Qwen/Qwen3-4B (= ref qwen/Qwen3-4B).
|
||||
- max_new=512 (fast) mildly undershoots solve vs paper's 1536; 0.094 vs 0.12 is
|
||||
the truncation tax, acceptable.
|
||||
|
||||
Run set lives in justfile: `just queue-dir6 SEED` (8-arm single-seed
|
||||
directionality set) and `just queue-broad` (3-seed headline + ablations). Full
|
||||
us-vs-reference table: docs/spec/20260607_eval_contamination_fix.md. See
|
||||
[[feedback_check_paper_before_diagnosing]], [[project_paper_comparability_verdict]].
|
||||
@@ -0,0 +1,153 @@
|
||||
"""Pairs comparison: which PAIR-SET gives the best high-precision hack direction?
|
||||
|
||||
Companion to diag_cosine_dist.py. That script sweeps SCORE x SPACE x FILTER for ONE
|
||||
pair-set; this one fixes the winning config (grad cosine, the p@10=0.70 corner) and
|
||||
sweeps the PAIR-SET axis -- authored vs pool-derived (prog_wide family) vs intent.
|
||||
This is the table I should have built before committing a GPU run to a pairs guess.
|
||||
|
||||
The expensive part (140 live-rollout backward passes) is independent of the pairs, so
|
||||
we cache the per-module live grads ONCE and loop the cheap v_grad build + scoring over
|
||||
pair-sets. No oracle labels touch training; `exploited` is an offline plot-only label.
|
||||
|
||||
uv run python -m scripts.diag_pairs_compare # 4B, free GPU, ~20 min
|
||||
outputs: out/diag/pairs_compare.csv (one row per pair-set: n, AUROC, p@10, p@20, plus p@10 under the keep75 filter, all at grad cosine).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import struct
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import tyro
|
||||
import polars as pl
|
||||
from loguru import logger
|
||||
from tabulate import tabulate
|
||||
from safetensors.torch import load_file
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from vgrout.antipasto import wrap_model_with_antipasto
|
||||
from vgrout.extract_vhack_grad import extract_v_hack, completion_nll
|
||||
from vgrout.pairs import PAIRS
|
||||
from vgrout.pairs_v2 import PAIRS_V2
|
||||
from vgrout.pairs_intent import PAIRS_FUNCNAME
|
||||
from vgrout.pairs_from_pool import load_pairs_json
|
||||
from vgrout.train import CACHE_ROOT
|
||||
from scripts.diag_cosine_dist import _auroc, _prec_at_k
|
||||
|
||||
_PS = Path("out/pairsets")
|
||||
# every label here lives on hand-authored pairs OR pool demos we wrote -- no live labels.
|
||||
def _pool_pairset(filename):
    """Return a zero-argument loader for the pair-set JSON `filename` under `_PS`."""
    return lambda: load_pairs_json(_PS / filename)


# Registry of candidate pair-sets, keyed by the label shown in the output table.
# Each value is a zero-argument factory, so a pair-set is only built when it is
# requested; dict insertion order is the order the pair-sets get evaluated in.
PAIRSETS = {
    "authored_all": lambda: [*PAIRS],  # 18 pairs / 6 axes
    "authored_runtests": lambda: [*PAIRS][:8],  # axis-1 only (the live mechanism)
    "authored_v2": lambda: [*PAIRS_V2],  # 6 harder/verbose
    "authored_allv2": lambda: [*PAIRS, *PAIRS_V2],
    "funcname": lambda: [*PAIRS_FUNCNAME],  # best intent design
    "prog_wide": _pool_pairset("prog_wide.json"),  # 30, training default
    "prog_wider": _pool_pairset("prog_wider.json"),
    "prog_widest": _pool_pairset("prog_widest.json"),
    "heldout_known_rt": _pool_pairset("heldout_known_runtests.json"),
}
|
||||
|
||||
|
||||
@dataclass
class Cfg:
    """Command-line options for the pair-set comparison."""

    # Training-run directory; the checkpoint files and rollouts.jsonl are read from it.
    run_dir: Path = Path("out/runs") / "20260607T134234_fast_routingV_seed43_dir6_routeV_pertoken_s43"
    # Checkpoint stem: `<ckpt>.safetensors` and `<ckpt>_hack.safetensors` get loaded.
    ckpt: str = "first_hack"
    # Inclusive window on each rollout's `step` field.
    step_lo: int = 5
    step_hi: int = 9
    # Upper bound on how many filtered rollouts are kept.
    max_rollouts: int = 140
    # Directory that receives pairs_compare.csv (created if missing).
    out_dir: Path = Path("out") / "diag"
|
||||
|
||||
|
||||
def main(cfg: Cfg) -> int:
    """Rank every pair-set in ``PAIRSETS`` by grad-cosine precision on live rollouts.

    Loads the ``cfg.ckpt`` checkpoint pair from ``cfg.run_dir`` into an
    antipasto-wrapped model, caches one backward pass per selected rollout,
    scores each pair-set against that cache, then writes
    ``pairs_compare.csv`` into ``cfg.out_dir`` and prints the ranked table.

    Args:
        cfg: Run directory, checkpoint stem, step window, rollout cap and
            output directory.

    Returns:
        0 on success (used as the process exit status).

    Raises:
        RuntimeError: If no rollout in the step window produced a finite loss,
            so there is nothing to score.
    """
    device = torch.device("cuda")
    kept_path = cfg.run_dir / f"{cfg.ckpt}.safetensors"
    hack_path = cfg.run_dir / f"{cfg.ckpt}_hack.safetensors"
    # Read only the header: an 8-byte little-endian length, then that many bytes of JSON.
    with open(kept_path, "rb") as f:
        meta = json.loads(f.read(struct.unpack("<Q", f.read(8))[0])).get("__metadata__", {})
    model_name = meta.get("model", "Qwen/Qwen3-4B")
    logger.info(f"ckpt {kept_path.name} step={meta.get('step')} model={model_name}")

    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2").to(device)
    model.config.use_cache = False
    wrappers = wrap_model_with_antipasto(model, model_name, CACHE_ROOT, device, grad_probe=False)
    names = sorted(wrappers)  # fixed module order shared by every per-module vector below
    kept, hack = load_file(str(kept_path)), load_file(str(hack_path))
    for nm in names:
        wrappers[nm]["delta_S"].data.copy_(kept[nm].to(device))
        wrappers[nm]["delta_S_hack"].data.copy_(hack[nm].to(device))
    model.eval()

    # ── cache live-rollout per-module grads ONCE (independent of pair-set) ──
    rollouts_path = cfg.run_dir / "rollouts.jsonl"
    # Skip blank lines so a stray empty line in the JSONL does not abort the run.
    recs = [json.loads(line) for line in rollouts_path.read_text().splitlines() if line.strip()]
    batch = [r for r in recs if cfg.step_lo <= r["step"] <= cfg.step_hi and r["text"].strip()][:cfg.max_rollouts]
    gb_cache, gn_cache, labels = [], [], []
    for i, r in enumerate(batch):
        model.zero_grad(set_to_none=True)
        loss = completion_nll(model, tok, r["prompt"], r["text"], device)
        if not torch.isfinite(loss):
            continue
        loss.backward()
        gb = {nm: wrappers[nm]["delta_S"].grad.flatten().float().cpu() for nm in names}
        gb_cache.append(gb)
        gn_cache.append(torch.tensor([gb[nm].norm().item() for nm in names]))  # |g_m|, pairs-independent
        labels.append(bool(r["exploited"]))
        if (i + 1) % 40 == 0:
            logger.info(f"  cached {i+1}/{len(batch)} live grads")
    model.zero_grad(set_to_none=True)
    if not labels:
        # Without this guard torch.stack([]) fails with an unhelpful message and
        # the base-rate division below would divide by zero.
        raise RuntimeError(
            f"no usable rollouts in {rollouts_path} for steps {cfg.step_lo}..{cfg.step_hi} "
            "(after dropping empty texts and non-finite losses)")
    gn_stack = torch.stack(gn_cache)  # [n_roll, n_mod]
    n_pos = sum(labels)
    logger.info(f"live: {len(labels)} rollouts, {n_pos} exploited ({n_pos/len(labels):.2f} base rate)")

    every_module = torch.ones(len(names), dtype=bool)  # pairs-independent "no filter" mask

    # ── score one pair-set at grad cosine (all + keep75 noise filter) ──
    def score_pairset(pairs):
        """Return (cos_all, cos_k75): one grad-cosine score per cached rollout."""
        _, v_sv, raw_grads, _ = extract_v_hack(
            model, tok, wrappers, pairs, top_k=1, tau_axis=0.0, n_heldout=2, device=device)
        # Per-module direction: mean (hack - clean) gradient difference, unit-normalised.
        v_grad = {}
        for nm in names:
            diff = (raw_grads[f"hack/{nm}"] - raw_grads[f"clean/{nm}"]).mean(0).flatten().float().cpu()
            v_grad[nm] = diff / diff.norm().clamp_min(1e-12)
        # NOTE(review): v_sv[nm][0] is presumably the leading singular value per
        # module -- confirm against extract_v_hack.
        sv0 = torch.tensor([v_sv[nm][0].item() for nm in names])
        keep75 = sv0 >= sv0.quantile(0.25)  # drop the lowest-quartile modules
        cos_all, cos_k75 = [], []
        for gb, gn in zip(gb_cache, gn_stack):
            dot = torch.tensor([(gb[nm] @ v_grad[nm]).item() for nm in names])
            for mask, out in [(every_module, cos_all), (keep75, cos_k75)]:
                d = dot[mask].sum()
                gmag = gn[mask].pow(2).sum().sqrt()
                nmod = int(mask.sum())
                # Each v_grad[nm] has unit norm, so the concatenated direction has
                # norm sqrt(nmod); dividing by both norms gives the cosine over
                # the masked modules.
                out.append((d / (gmag.clamp_min(1e-12) * nmod ** 0.5)).item())
        return cos_all, cos_k75

    rows = []
    for name, fn in PAIRSETS.items():
        try:
            pairs = fn()
        except FileNotFoundError:
            logger.warning(f"{name}: pairset file missing, skipping")
            continue
        model.eval()
        ca, ck = score_pairset(pairs)
        row = dict(pairset=name, n=len(pairs),
                   AUROC=round(_auroc(ca, labels), 3),
                   p10=round(_prec_at_k(ca, labels, 10), 3),
                   p20=round(_prec_at_k(ca, labels, 20), 3),
                   p10_keep75=round(_prec_at_k(ck, labels, 10), 3))
        rows.append(row)
        logger.info(f"{name} (n={len(pairs)}): grad-cosine p10={row['p10']} AUROC={row['AUROC']}")

    df = pl.DataFrame(rows).sort(["p10", "AUROC"], descending=True)
    cfg.out_dir.mkdir(parents=True, exist_ok=True)
    df.write_csv(cfg.out_dir / "pairs_compare.csv")
    print("\n=== pair-sets ranked by grad-cosine precision@10 (base rate "
          f"{n_pos/len(labels):.2f}, {len(labels)} live rollouts) ===")
    print(tabulate(df.to_pandas(), headers="keys", tablefmt="pipe", showindex=False))
    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Build the Cfg from command-line flags, then exit with main()'s return code.
    cli_cfg = tyro.cli(Cfg)
    raise SystemExit(main(cli_cfg))
|
||||
Reference in New Issue
Block a user