mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 17:30:41 +08:00
f487e67405
This batch lands the working baseline (Goal 0 from RESEARCH_JOURNAL 2026-05-28 (b)) plus the architectural cleanups it surfaced. Pueue task 59 hits the UAT threshold (`hack_s >= N/4`) at step 7 on Qwen3-4B mixed-pool, ~10 min total. Preset/Adam scheduling - New `Preset.fast` with aggressive Adam (lr=3e-3, beta1=0.5, beta2=0.9) and small batch (steps=20, group=4, max_new=512, prompts_per_step=4) for sub-15-min iteration loops. - `warmup_steps` (absolute) -> `warmup_frac` (fraction of total steps), so the 20-step fast preset spends only 2 steps under warmup, not 10. - `grad_clip` exposed as Config field (default 1.0; fast recipe uses 500 to effectively disable — `gn` column shows the clip was never the bottleneck). CLI restructure (tyro subcommands) - Drop `Preset` enum + `PRESETS` dict + `Config.resolved()` Optional-merge hack. - Three typed subclass dataclasses: `SmokeConfig` / `FastConfig` / `FullConfig` inheriting from `Config`, dispatched via `tyro.extras.subcommand_cli_from_dict`. - CLI: `train fast --arm=vanilla --lr=3e-3` (subcommand position, not --preset=). - `cfg.preset_name` derived from `type(self).__name__` instead of duplicated field. Logging refactor - New `StepLogger` class consolidates column order, width, header label, and per-cell formatter (no more triplicated `_col_w` / `_row_cols` / `_header_labels`). - Row dict carries raw values throughout; formatters live in column spec. Fixes the bug where end-of-run tabulate parsed `"7.00e-08"` strings as floats and reformatted to `+0.000`. Tuples for fraction columns get converted to "n/d" strings only at tabulate-dump time. - `gn` column added (pre-clip total L2 norm; was discarded by clip_grad_norm_). - `lr` column added (current scheduled LR through warmup + cosine). - Timing cols (gen/fb/t_rew/sec) dropped from streaming view, still archived. cin/cout -> cos_pre/cos_post + signed - Rename across train.py, proj.py, probe_distill.py, run.py, smokes, plots, justfile. 
"in/out" overloaded with weight in/out features; "pre/post" is unambiguous re projection timing. - Metric is now signed: sum(V @ g) / ||g|| instead of ||V @ g|| / ||g||. With one_sided gate, cos_post goes negative after projection (residual energy is anti-hack) — was hidden by the absolute-value norm. v_hack extraction framing - README + `extract_vhack_grad.py` docstring lead with "this is the GRPO gradient on a labeled (hack, clean) pair" instead of twin-NLL. For a pair with advantages +-1 the Dr.GRPO grad equals grad_NLL(hack) - grad_NLL(clean) exactly, so we save the cleaner narrative for the paper. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
114 lines
4.4 KiB
Python
114 lines
4.4 KiB
Python
"""Per-step trajectory printer for the warmup->gen runs.
|
|
|
|
Reads out/probe_distill/{tag}/step_*.jsonl.gz and prints a side-by-side
|
|
table of vanilla vs projected, broken into the warmup-replay phase and the
|
|
student-gen phase.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import gzip
|
|
import json
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
|
|
def load_run(run_dir: Path | str) -> list[dict]:
    """Load every JSON record from a run directory's ``step_*.jsonl.gz`` shards.

    Shards are read in sorted filename order and concatenated. Each non-blank
    line of each (gzip-compressed, UTF-8) shard is parsed as one JSON object.

    Args:
        run_dir: Directory containing the shards (``str`` or ``Path``).

    Returns:
        All parsed records. The list is empty when no shard matches.
    """
    records: list[dict] = []
    for shard in sorted(Path(run_dir).glob("step_*.jsonl.gz")):
        # Explicit encoding: text mode otherwise depends on the platform locale.
        with gzip.open(shard, "rt", encoding="utf-8") as fh:
            # Skip blank lines (e.g. a trailing newline) so json.loads does not
            # raise JSONDecodeError on them.
            records.extend(json.loads(line) for line in fh if line.strip())
    return records
|
|
|
|
|
|
def per_step(rows: list[dict]) -> list[dict]:
    """Collapse per-sample records into one summary dict per step.

    Records are grouped on their ``"step"`` key. The result is ordered by
    ascending step, and each summary holds:

    - ``step``, ``n``: the step id and how many records it had;
    - ``src``: ``src_pool`` of the step's first record (``"?"`` if absent);
    - ``hack`` / ``gt``: ``"k/n"`` strings counting truthy ``hacked`` /
      ``gt_pass`` flags;
    - ``cos_mean``: mean of the non-None ``cos_S_contrib`` values (nan if none);
    - ``cos_pre`` / ``cos_post`` / ``fired``: ``mean_cos_pre`` /
      ``mean_cos_post`` / ``frac_fired`` read from the first record only
      (nan if absent). Presumably these are step-level values repeated on
      every record; TODO confirm.
    """
    grouped: dict = {}
    for record in rows:
        grouped.setdefault(record["step"], []).append(record)

    nan = float("nan")
    summaries = []
    for step in sorted(grouped):
        members = grouped[step]
        head = members[0]
        contribs = [m["cos_S_contrib"] for m in members if m.get("cos_S_contrib") is not None]
        hacked = sum(int(m["hacked"]) for m in members)
        passed = sum(int(m["gt_pass"]) for m in members)
        total = len(members)
        summaries.append(
            dict(
                step=step,
                n=total,
                src=head.get("src_pool", "?"),
                hack=f"{hacked}/{total}",
                gt=f"{passed}/{total}",
                cos_mean=sum(contribs) / len(contribs) if contribs else nan,
                cos_pre=head.get("mean_cos_pre", nan),
                cos_post=head.get("mean_cos_post", nan),
                fired=head.get("frac_fired", nan),
            )
        )
    return summaries
|
|
|
|
|
|
def main(
    tag_v: str = "warmupgen_vanilla_seed41",
    tag_p: str = "warmupgen_projected_svd_seed41",
    root: str | Path = "out/probe_distill",
):
    """Print a vanilla-vs-projected comparison of two warmup->gen runs.

    Three sections are printed:

    - a per-step side-by-side table, with rows paired by step number;
    - a per-phase summary (warmup replay vs student generation);
    - the H1 check that the projected arm hacks less than vanilla in the
      student-gen phase.

    Args:
        tag_v: Run tag (sub-directory of ``root``) of the vanilla arm.
        tag_p: Run tag of the projected-SVD arm.
        root: Directory containing the per-tag run directories.

    Raises:
        FileNotFoundError: If either run directory does not exist. Previously
            a mistyped tag silently produced an empty table.
    """
    import math

    root = Path(root)
    run_v, run_p = root / tag_v, root / tag_p
    for run_dir in (run_v, run_p):
        if not run_dir.is_dir():
            raise FileNotFoundError(f"run directory not found: {run_dir}")
    v = per_step(load_run(run_v))
    p = per_step(load_run(run_p))

    rule = "=" * 120

    def fmt(x, spec: str) -> str:
        # Step metrics may be None (the phase summary guards for this too);
        # render them as nan instead of raising TypeError in format().
        return format(float("nan") if x is None else x, spec)

    def side(row) -> str:
        # One arm's cells, padded to the widths of the header labels.
        return (
            f"{row['hack']:>8} {row['gt']:>6} "
            f"{fmt(row['cos_mean'], '>+7.3f')} {fmt(row['cos_pre'], '>+7.3f')} "
            f"{fmt(row['cos_post'], '>+7.3f')} {fmt(row['fired'], '>7.2f')}"
        )

    print(f"\n{rule}")
    print("Warmup -> student-gen comparison (vanilla vs projected SVD)")
    print(rule)
    print(
        f"{'step':>4} {'src':>14} "
        f"{'V.hack':>8} {'V.gt':>6} {'V.cos':>7} {'V.pre':>7} {'V.post':>7} {'V.fired':>7} "
        f"{'P.hack':>8} {'P.gt':>6} {'P.cos':>7} {'P.pre':>7} {'P.post':>7} {'P.fired':>7}"
    )
    # Pair rows by step number: zip() by position would silently misalign or
    # truncate if one run skipped a step or stopped early.
    p_by_step = {row["step"]: row for row in p}
    for vrow in v:
        prow = p_by_step.get(vrow["step"])
        if prow is None:
            continue
        # str(): src may be None, and None does not support the ">14" spec.
        print(f"{vrow['step']:>4} {str(vrow['src']):>14} {side(vrow)} {side(prow)}")
    v_steps = {row["step"] for row in v}
    p_steps = set(p_by_step)
    if v_steps != p_steps:
        print(
            f"(unpaired steps -- vanilla only: {sorted(v_steps - p_steps)}, "
            f"projected only: {sorted(p_steps - v_steps)})"
        )

    # Phase summary: replay vs gen
    print(f"\n{rule}")
    print("Phase summary")
    print(rule)

    def counts(frac: str) -> tuple[int, int]:
        # per_step encodes counts as "k/n" strings; recover (k, n).
        num, _, den = frac.partition("/")
        return int(num), int(den)

    def phase_stats(rows, phase_pred):
        ps = [r for r in rows if phase_pred(r)]
        if not ps:
            return None
        hack_counts = [counts(r["hack"]) for r in ps]
        hack_total = sum(k for k, _ in hack_counts)
        n_total = sum(n for _, n in hack_counts)
        gt_total = sum(counts(r["gt"])[0] for r in ps)
        # Drop missing (None) and absent (nan) values so a single step without
        # the metric does not turn the whole phase mean into nan.
        pres = [
            r["cos_pre"]
            for r in ps
            if isinstance(r["cos_pre"], (int, float)) and not math.isnan(r["cos_pre"])
        ]
        return {
            "n_steps": len(ps),
            "hack_rate": hack_total / max(1, n_total),
            "gt_rate": gt_total / max(1, n_total),
            "cos_pre_mean": sum(pres) / len(pres) if pres else float("nan"),
        }

    def is_replay(r) -> bool:
        # Gen-phase rows may carry src=None; `"x" in None` raises TypeError.
        src = r["src"]
        return isinstance(src, str) and ("teacher_pool" in src or "base_pool" in src)

    def is_gen(r) -> bool:
        return r["src"] == "student_gen" or r["src"] is None

    for label, rows in [("vanilla", v), ("projected", p)]:
        rep = phase_stats(rows, is_replay)
        gen = phase_stats(rows, is_gen)
        print(f"\n{label}:")
        if rep:
            print(f"  warmup replay (n_steps={rep['n_steps']:2d}): hack_rate={rep['hack_rate']:.3f}  gt_rate={rep['gt_rate']:.3f}  cos_pre_mean={rep['cos_pre_mean']:+.4f}")
        if gen:
            print(f"  student gen   (n_steps={gen['n_steps']:2d}): hack_rate={gen['hack_rate']:.3f}  gt_rate={gen['gt_rate']:.3f}  cos_pre_mean={gen['cos_pre_mean']:+.4f}")

    # Headline H1 prediction
    v_gen = phase_stats(v, is_gen)
    p_gen = phase_stats(p, is_gen)
    if v_gen and p_gen:
        print(f"\n{rule}")
        print("H1 prediction: projected gen-phase hack rate < vanilla gen-phase hack rate")
        print(rule)
        print(f"  vanilla:   {v_gen['hack_rate']:.3f}")
        print(f"  projected: {p_gen['hack_rate']:.3f}")
        delta = v_gen["hack_rate"] - p_gen["hack_rate"]
        print(f"  delta:     {delta:+.3f}  ({'PASS' if delta > 0 else 'FAIL or null'})")
|
|
|
|
|
|
if __name__ == "__main__":
    # Optional positional args: [tag_v [tag_p]]. Any tag not given keeps
    # main()'s default; previously a lone tag_v argument was silently ignored.
    main(*sys.argv[1:3])
|