"""Per-step trajectory printer for the warmup->gen runs.

Reads out/probe_distill/{tag}/step_*.jsonl.gz and prints a side-by-side
table of vanilla vs projected, broken into the warmup-replay phase and the
student-gen phase.

Invoked by the ``probe-traj`` justfile recipe
(``uv run python -m projected_grpo.probe_traj``), or directly with two
positional arguments: the vanilla run tag and the projected run tag.
"""
from __future__ import annotations

import gzip
import json
import math
import sys
from pathlib import Path

_RULE = "=" * 120


def load_run(run_dir: Path) -> list[dict]:
    """Load every JSONL record of one run, in sorted file-name order.

    Args:
        run_dir: Directory holding the ``step_*.jsonl.gz`` files of a run.

    Returns:
        One dict per non-blank line, concatenated across files. A missing
        or empty directory yields an empty list.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    rows: list[dict] = []
    for path in sorted(run_dir.glob("step_*.jsonl.gz")):
        # Explicit encoding so decoding does not depend on the locale.
        with gzip.open(path, "rt", encoding="utf-8") as f:
            # Blank lines (e.g. a trailing newline) are skipped, not parsed.
            rows.extend(json.loads(line) for line in f if line.strip())
    return rows


def _metric(value) -> float:
    """Coerce a logged metric to float; missing/null/non-numeric -> NaN."""
    if isinstance(value, (int, float)):
        return float(value)
    return float("nan")


def per_step(rows: list[dict]) -> list[dict]:
    """Aggregate raw records into one summary dict per step.

    Args:
        rows: Records as returned by ``load_run``. Each must carry ``step``,
            ``hacked`` and ``gt_pass``; the other fields are optional.

    Returns:
        Summaries sorted by step, each with the keys ``step``, ``n``,
        ``src``, ``hack`` and ``gt`` (``"count/n"`` strings), ``cos_mean``
        (mean of the non-null ``cos_S_contrib`` values, NaN if none) and
        ``cos_in`` / ``cos_out`` / ``fired`` (taken from the first record of
        the step; NaN when absent or null).
    """
    by_step: dict = {}
    for r in rows:
        by_step.setdefault(r["step"], []).append(r)

    out = []
    for s in sorted(by_step):
        rs = by_step[s]
        first = rs[0]  # step-level metrics are read from the first record
        cos = [r["cos_S_contrib"] for r in rs if r.get("cos_S_contrib") is not None]
        n = len(rs)
        n_hack = sum(int(r["hacked"]) for r in rs)
        n_gt = sum(int(r["gt_pass"]) for r in rs)
        out.append({
            "step": s,
            "n": n,
            "src": first.get("src_pool", "?"),
            "hack": f"{n_hack}/{n}",
            "gt": f"{n_gt}/{n}",
            "cos_mean": sum(cos) / len(cos) if cos else float("nan"),
            # JSON null would otherwise reach the float format specs in
            # main() and raise, so normalise to NaN here.
            "cos_in": _metric(first.get("mean_cos_in")),
            "cos_out": _metric(first.get("mean_cos_out")),
            "fired": _metric(first.get("frac_fired")),
        })
    return out


def _is_replay(row: dict) -> bool:
    """True when the step's source names the teacher or base pool."""
    src = row["src"]
    # A null src_pool is not a string; guard before the substring test.
    return isinstance(src, str) and ("teacher_pool" in src or "base_pool" in src)


def _is_gen(row: dict) -> bool:
    """True when the step's source is student generation (or null)."""
    # NOTE(review): a *missing* src_pool key is reported as "?" by per_step
    # and therefore lands in neither phase; only an explicit null counts as
    # student-gen. Confirm against the training logger.
    src = row["src"]
    return src is None or src == "student_gen"


def _phase_stats(rows: list[dict], phase_pred) -> dict | None:
    """Pool hack/gt rates and mean cos_in over the steps matching a predicate.

    Returns None when no step matches.
    """
    ps = [r for r in rows if phase_pred(r)]
    if not ps:
        return None
    hack_total = n_total = gt_total = 0
    for r in ps:
        hacked, n = r["hack"].split("/")
        hack_total += int(hacked)
        n_total += int(n)
        gt_total += int(r["gt"].split("/")[0])
    # NaN marks "metric not logged"; leaving it in would turn the mean to NaN.
    cins = [
        r["cos_in"] for r in ps
        if isinstance(r["cos_in"], (int, float)) and not math.isnan(r["cos_in"])
    ]
    return {
        "n_steps": len(ps),
        "hack_rate": hack_total / max(1, n_total),
        "gt_rate": gt_total / max(1, n_total),
        "cin_mean": sum(cins) / len(cins) if cins else float("nan"),
    }


def _side_cells(row: dict | None) -> str:
    """Format one run's six table cells, or dashes when the step is absent."""
    if row is None:
        return f"{'-':>8} {'-':>6} {'-':>7} {'-':>7} {'-':>7} {'-':>7}"
    # Widths match the header columns so the table stays aligned.
    return (
        f"{row['hack']:>8} {row['gt']:>6} {row['cos_mean']:>+7.3f} "
        f"{row['cos_in']:>+7.3f} {row['cos_out']:>+7.3f} {row['fired']:>7.2f}"
    )


def main(tag_v: str = "warmupgen_vanilla_seed41", tag_p: str = "warmupgen_projected_svd_seed41"):
    """Print the per-step table, the per-phase summary and the H1 verdict.

    Args:
        tag_v: Sub-directory of ``out/probe_distill`` holding the vanilla run.
        tag_p: Sub-directory of ``out/probe_distill`` holding the projected run.
    """
    root = Path("out/probe_distill")
    v = per_step(load_run(root / tag_v))
    p = per_step(load_run(root / tag_p))
    for tag, steps in ((tag_v, v), (tag_p, p)):
        if not steps:
            # An empty table with no explanation is easy to misread.
            print(f"warning: no step_*.jsonl.gz records found in {root / tag}", file=sys.stderr)

    print(f"\n{_RULE}")
    print("Warmup -> student-gen comparison (vanilla vs projected SVD)")
    print(_RULE)
    print(f"{'step':>4} {'src':>14} "
          f"{'V.hack':>8} {'V.gt':>6} {'V.cos':>7} {'V.cin':>7} {'V.cout':>7} {'V.fired':>7} "
          f"{'P.hack':>8} {'P.gt':>6} {'P.cos':>7} {'P.cin':>7} {'P.cout':>7} {'P.fired':>7}")
    # Join on step rather than list position, so runs covering different
    # steps are neither misaligned nor silently truncated.
    v_by_step = {r["step"]: r for r in v}
    p_by_step = {r["step"]: r for r in p}
    for step in sorted(v_by_step.keys() | p_by_step.keys()):
        vrow = v_by_step.get(step)
        prow = p_by_step.get(step)
        src = (vrow if vrow is not None else prow)["src"]
        # str() because src may be None, which rejects a width format spec.
        print(f"{step:>4} {str(src):>14} {_side_cells(vrow)} {_side_cells(prow)}")

    # Phase summary: replay vs gen
    print(f"\n{_RULE}")
    print("Phase summary")
    print(_RULE)
    gen_stats = {}
    for label, rows in [("vanilla", v), ("projected", p)]:
        rep = _phase_stats(rows, _is_replay)
        gen = _phase_stats(rows, _is_gen)
        gen_stats[label] = gen
        print(f"\n{label}:")
        if rep:
            print(f" warmup replay (n_steps={rep['n_steps']:2d}): hack_rate={rep['hack_rate']:.3f} gt_rate={rep['gt_rate']:.3f} cos_in_mean={rep['cin_mean']:+.4f}")
        if gen:
            print(f" student gen (n_steps={gen['n_steps']:2d}): hack_rate={gen['hack_rate']:.3f} gt_rate={gen['gt_rate']:.3f} cos_in_mean={gen['cin_mean']:+.4f}")

    # Headline H1 prediction
    v_gen = gen_stats["vanilla"]
    p_gen = gen_stats["projected"]
    if v_gen and p_gen:
        print(f"\n{_RULE}")
        print("H1 prediction: projected gen-phase hack rate < vanilla gen-phase hack rate")
        print(_RULE)
        print(f" vanilla: {v_gen['hack_rate']:.3f}")
        print(f" projected: {p_gen['hack_rate']:.3f}")
        delta = v_gen['hack_rate'] - p_gen['hack_rate']
        print(f" delta: {delta:+.3f} ({'PASS' if delta > 0 else 'FAIL or null'})")


if __name__ == "__main__":
    main(*(sys.argv[1:3] if len(sys.argv) >= 3 else ()))