Files
evil_MoE/src/projected_grpo/probe_traj.py
T
wassname f487e67405 Goal 0 milestone: fast preset learns to hack in ~10min
This batch lands the working baseline (Goal 0 from RESEARCH_JOURNAL 2026-05-28
(b)) plus the architectural cleanups it surfaced. Pueue task 59 hits the UAT
threshold (`hack_s >= N/4`) at step 7 on Qwen3-4B mixed-pool, ~10 min total.

Preset/Adam scheduling
- New `Preset.fast` with aggressive Adam (lr=3e-3, beta1=0.5, beta2=0.9) and
  small batch (steps=20, group=4, max_new=512, prompts_per_step=4) for sub-15-min
  iteration loops.
- `warmup_steps` (absolute) -> `warmup_frac` (fraction of total steps), so the
  20-step fast preset spends only 2 steps under warmup, not 10.
- `grad_clip` exposed as Config field (default 1.0; fast recipe uses 500 to
  effectively disable — `gn` column shows the clip was never the bottleneck).

CLI restructure (tyro subcommands)
- Drop `Preset` enum + `PRESETS` dict + `Config.resolved()` Optional-merge hack.
- Three typed subclass dataclasses: `SmokeConfig` / `FastConfig` / `FullConfig`
  inheriting from `Config`, dispatched via `tyro.extras.subcommand_cli_from_dict`.
- CLI: `train fast --arm=vanilla --lr=3e-3` (subcommand position, not --preset=).
- `cfg.preset_name` derived from `type(self).__name__` instead of duplicated field.

Logging refactor
- New `StepLogger` class consolidates column order, width, header label, and
  per-cell formatter (no more triplicated `_col_w` / `_row_cols` / `_header_labels`).
- Row dict carries raw values throughout; formatters live in column spec.
  Fixes the bug where end-of-run tabulate parsed `"7.00e-08"` strings as floats
  and reformatted to `+0.000`. Tuples for fraction columns get converted to
  "n/d" strings only at tabulate-dump time.
- `gn` column added (pre-clip total L2 norm; was discarded by clip_grad_norm_).
- `lr` column added (current scheduled LR through warmup + cosine).
- Timing cols (gen/fb/t_rew/sec) dropped from streaming view, still archived.

cin/cout -> cos_pre/cos_post + signed
- Rename across train.py, proj.py, probe_distill.py, run.py, smokes, plots,
  justfile. "in/out" is overloaded (it also names weight in/out features), whereas
  "pre/post" unambiguously refers to projection timing.
- Metric is now signed: sum(V @ g) / ||g|| instead of ||V @ g|| / ||g||. With
  the one_sided gate, cos_post goes negative after projection (the residual
  energy is anti-hack); this was previously hidden by the absolute-value norm.

v_hack extraction framing
- README + `extract_vhack_grad.py` docstring lead with "this is the GRPO
  gradient on a labeled (hack, clean) pair" instead of twin-NLL. For a pair
  with advantages ±1 the Dr.GRPO grad equals grad_NLL(hack) - grad_NLL(clean)
  exactly, so we save the cleaner narrative for the paper.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-28 03:22:36 +00:00

114 lines
4.4 KiB
Python

"""Per-step trajectory printer for the warmup->gen runs.
Reads out/probe_distill/{tag}/step_*.jsonl.gz and prints a side-by-side
table of vanilla vs projected, broken into the warmup-replay phase and the
student-gen phase.
"""
from __future__ import annotations

import gzip
import json
import math
import sys
from pathlib import Path
def load_run(run_dir: Path) -> list[dict]:
    """Read every ``step_*.jsonl.gz`` file in *run_dir* into one flat list of records.

    Files are visited in sorted (lexicographic) filename order and each non-blank
    line is parsed as one JSON object.

    Args:
        run_dir: directory holding the run's ``step_*.jsonl.gz`` files; a ``str``
            path is also accepted.

    Returns:
        The decoded records, file by file, line by line. Empty if no file matches.
    """
    run_dir = Path(run_dir)
    rows: list[dict] = []
    for path in sorted(run_dir.glob("step_*.jsonl.gz")):
        # Explicit UTF-8: text-mode gzip.open otherwise uses the locale encoding.
        with gzip.open(path, "rt", encoding="utf-8") as f:
            # Skip blank lines (e.g. a stray empty line); json.loads("") would raise.
            rows.extend(json.loads(line) for line in f if line.strip())
    return rows
def per_step(rows: list[dict]) -> list[dict]:
    """Collapse per-sample records into one summary dict per training step.

    Every record must carry ``step``, ``hacked`` and ``gt_pass``; the other
    fields read here are optional. Step-level fields (``src_pool``,
    ``mean_cos_pre``, ``mean_cos_post``, ``frac_fired``) are taken from the
    first record of each step. NOTE(review): this assumes they are identical
    across a step's records; confirm against the writer.

    Returns:
        One dict per step, sorted by step, with keys:

        - ``step``, ``n``, ``src`` (``src_pool``, or ``"?"`` if the key is absent).
        - ``hack`` / ``gt``: ``"count/n"`` strings.
        - ``cos_mean``: mean of the non-null ``cos_S_contrib`` values.
        - ``cos_pre`` / ``cos_post`` / ``fired``: floats.

        Numeric step-level values are NaN when the key is missing *or* null.
    """
    def _num(value) -> float:
        # .get(key, nan) only covers a *missing* key; a logged null must also
        # become NaN so downstream ":+.3f" formatting never receives None.
        return float(value) if isinstance(value, (int, float)) else float("nan")

    by_step: dict = {}
    for r in rows:
        by_step.setdefault(r["step"], []).append(r)

    out = []
    for s in sorted(by_step):
        rs = by_step[s]
        first = rs[0]
        cos = [r["cos_S_contrib"] for r in rs if r.get("cos_S_contrib") is not None]
        n = len(rs)
        n_hack = sum(int(r["hacked"]) for r in rs)
        n_gt = sum(int(r["gt_pass"]) for r in rs)
        out.append({
            "step": s,
            "n": n,
            "src": first.get("src_pool", "?"),
            "hack": f"{n_hack}/{n}",
            "gt": f"{n_gt}/{n}",
            "cos_mean": sum(cos) / len(cos) if cos else float("nan"),
            "cos_pre": _num(first.get("mean_cos_pre")),
            "cos_post": _num(first.get("mean_cos_post")),
            "fired": _num(first.get("frac_fired")),
        })
    return out
_RULE = "=" * 120
# Placeholder cells for a step one run lacks; widths match _format_cells / header.
_MISSING_CELLS = f"{'-':>8} {'-':>6} {'-':>7} {'-':>7} {'-':>7} {'-':>7}"


def _as_float(value) -> float:
    """Return *value* as a float, or NaN if it is not a number (e.g. None)."""
    return float(value) if isinstance(value, (int, float)) else float("nan")


def _banner(title: str) -> None:
    """Print *title* framed by 120-char '=' rules, preceded by a blank line."""
    print(f"\n{_RULE}")
    print(title)
    print(_RULE)


def _format_cells(row: dict | None) -> str:
    """Format one run's cells for a step, aligned under the header columns.

    ``None`` (the run has no record for this step) renders as dashes.
    """
    if row is None:
        return _MISSING_CELLS
    return (
        f"{row['hack']:>8} {row['gt']:>6} "
        f"{_as_float(row['cos_mean']):>+7.3f} {_as_float(row['cos_pre']):>+7.3f} "
        f"{_as_float(row['cos_post']):>+7.3f} {_as_float(row['fired']):>7.2f}"
    )


def _split_frac(frac: str) -> tuple[int, int]:
    """Parse a ``"count/n"`` string (as built by per_step) into ``(count, n)``."""
    num, den = frac.split("/")
    return int(num), int(den)


def _phase_stats(rows: list[dict], phase_pred) -> dict | None:
    """Pool hack/gt counts and the mean cos_pre over the steps matching *phase_pred*.

    Rates are sample-weighted (summed counts over summed sample totals), not a
    mean of per-step rates. Returns None when no step matches.
    """
    ps = [r for r in rows if phase_pred(r)]
    if not ps:
        return None
    hack_total = n_total = gt_total = 0
    for r in ps:
        n_hack, n = _split_frac(r["hack"])
        hack_total += n_hack
        n_total += n
        gt_total += _split_frac(r["gt"])[0]
    # NaN marks "not logged"; drop it so one missing step doesn't poison the mean.
    cos_pre = [x for x in (_as_float(r["cos_pre"]) for r in ps) if not math.isnan(x)]
    return {
        "n_steps": len(ps),
        "hack_rate": hack_total / max(1, n_total),
        "gt_rate": gt_total / max(1, n_total),
        "cos_pre_mean": sum(cos_pre) / len(cos_pre) if cos_pre else float("nan"),
    }


def _is_replay(row: dict) -> bool:
    """True for warmup-replay steps: src names a teacher or base pool."""
    src = row["src"]
    # src may be None (counted as student-gen); `"x" in None` would raise TypeError.
    return isinstance(src, str) and ("teacher_pool" in src or "base_pool" in src)


def _is_gen(row: dict) -> bool:
    """True for student-generation steps: src is "student_gen" or None."""
    return row["src"] == "student_gen" or row["src"] is None


def main(
    tag_v: str = "warmupgen_vanilla_seed41",
    tag_p: str = "warmupgen_projected_svd_seed41",
    root: str | Path = "out/probe_distill",
) -> None:
    """Print the vanilla-vs-projected step table, a phase summary and the H1 verdict.

    Args:
        tag_v: directory name (under *root*) of the vanilla run.
        tag_p: directory name (under *root*) of the projected run.
        root: parent directory of the run folders; a relative path is resolved
            against the current working directory.

    Raises:
        FileNotFoundError: if either run directory is missing. Without this
            check, a mistyped tag would print an empty report.
    """
    root = Path(root)
    runs = {}
    for label, tag in (("vanilla", tag_v), ("projected", tag_p)):
        run_dir = root / tag
        if not run_dir.is_dir():
            raise FileNotFoundError(f"run directory not found: {run_dir}")
        runs[label] = per_step(load_run(run_dir))
    v, p = runs["vanilla"], runs["projected"]

    _banner("Warmup -> student-gen comparison (vanilla vs projected SVD)")
    print(f"{'step':>4} {'src':>14} "
          f"{'V.hack':>8} {'V.gt':>6} {'V.cos':>7} {'V.pre':>7} {'V.post':>7} {'V.fired':>7} "
          f"{'P.hack':>8} {'P.gt':>6} {'P.cos':>7} {'P.pre':>7} {'P.post':>7} {'P.fired':>7}")
    # Join on step number rather than zip(): zip pairs rows by position and
    # silently drops the tail of the longer run.
    v_by_step = {r["step"]: r for r in v}
    p_by_step = {r["step"]: r for r in p}
    for step in sorted(v_by_step.keys() | p_by_step.keys()):
        vrow = v_by_step.get(step)
        prow = p_by_step.get(step)
        src = (vrow if vrow is not None else prow)["src"]
        # str(): src may be None, and format(None, ">14") raises TypeError.
        print(f"{step:>4} {str(src):>14} {_format_cells(vrow)} {_format_cells(prow)}")

    # Phase summary: replay vs gen
    _banner("Phase summary")
    gen_stats = {}
    for label, rows in (("vanilla", v), ("projected", p)):
        rep = _phase_stats(rows, _is_replay)
        gen = _phase_stats(rows, _is_gen)
        gen_stats[label] = gen
        print(f"\n{label}:")
        if rep:
            print(f" warmup replay (n_steps={rep['n_steps']:2d}): hack_rate={rep['hack_rate']:.3f} gt_rate={rep['gt_rate']:.3f} cos_pre_mean={rep['cos_pre_mean']:+.4f}")
        if gen:
            print(f" student gen (n_steps={gen['n_steps']:2d}): hack_rate={gen['hack_rate']:.3f} gt_rate={gen['gt_rate']:.3f} cos_pre_mean={gen['cos_pre_mean']:+.4f}")

    # Headline H1 prediction: projection should lower the gen-phase hack rate.
    v_gen, p_gen = gen_stats["vanilla"], gen_stats["projected"]
    if v_gen and p_gen:
        _banner("H1 prediction: projected gen-phase hack rate < vanilla gen-phase hack rate")
        print(f" vanilla: {v_gen['hack_rate']:.3f}")
        print(f" projected: {p_gen['hack_rate']:.3f}")
        delta = v_gen["hack_rate"] - p_gen["hack_rate"]
        print(f" delta: {delta:+.3f} ({'PASS' if delta > 0 else 'FAIL or null'})")
if __name__ == "__main__":
    # Optional positional args: [tag_v [tag_p]]. Passing a single tag overrides
    # only the vanilla run; arguments beyond the second are ignored.
    main(*sys.argv[1:3])