probe_traj: side-by-side vanilla-vs-projected trajectory analyzer

Reads step files from both warmup-gen tags, prints per-step table
broken into warmup-replay and student-gen phases, computes H1 delta
on the gen-phase hack rate.
This commit is contained in:
wassname
2026-05-25 12:26:03 +00:00
parent a1fdb45251
commit a26f71ef1a
2 changed files with 117 additions and 0 deletions
+4
View File
@@ -212,6 +212,10 @@ probe-projected-replay steps="20":
probe-uat:
uv run python -m projected_grpo.probe_uat
# Trajectory comparator for the warmup-gen runs (vanilla vs projected).
probe-traj:
uv run python -m projected_grpo.probe_traj
# Phase 2 pilot analyzer: reads out/train_pilot_*.safetensors, prints trajectories
# and per-arm aggregates, applies decision rules from spec2.md.
phase2-analyze pattern="_pilot_*":
+113
View File
@@ -0,0 +1,113 @@
"""Per-step trajectory printer for the warmup->gen runs.
Reads out/probe_distill/{tag}/step_*.jsonl.gz and prints a side-by-side
table of vanilla vs projected, broken into the warmup-replay phase and the
student-gen phase.
"""
from __future__ import annotations
import gzip
import json
import sys
from pathlib import Path
def load_run(run_dir: Path) -> list[dict]:
rows = []
for path in sorted(run_dir.glob("step_*.jsonl.gz")):
with gzip.open(path, "rt") as f:
for line in f:
rows.append(json.loads(line))
return rows
def per_step(rows: list[dict]) -> list[dict]:
by_step = {}
for r in rows:
s = r["step"]
by_step.setdefault(s, []).append(r)
out = []
for s in sorted(by_step):
rs = by_step[s]
cos = [r["cos_S_contrib"] for r in rs if r.get("cos_S_contrib") is not None]
n_hack = sum(int(r["hacked"]) for r in rs)
n_gt = sum(int(r["gt_pass"]) for r in rs)
n = len(rs)
src = rs[0].get("src_pool", "?")
out.append({
"step": s,
"n": n,
"src": src,
"hack": f"{n_hack}/{n}",
"gt": f"{n_gt}/{n}",
"cos_mean": sum(cos)/len(cos) if cos else float("nan"),
"cos_in": rs[0].get("mean_cos_in", float("nan")),
"cos_out": rs[0].get("mean_cos_out", float("nan")),
"fired": rs[0].get("frac_fired", float("nan")),
})
return out
def main(tag_v: str = "warmupgen_vanilla_seed41", tag_p: str = "warmupgen_projected_svd_seed41"):
root = Path("out/probe_distill")
v = per_step(load_run(root / tag_v))
p = per_step(load_run(root / tag_p))
print(f"\n{'='*120}")
print(f"Warmup -> student-gen comparison (vanilla vs projected SVD)")
print(f"{'='*120}")
print(f"{'step':>4} {'src':>14} "
f"{'V.hack':>8} {'V.gt':>6} {'V.cos':>7} {'V.cin':>7} {'V.cout':>7} {'V.fired':>7} "
f"{'P.hack':>8} {'P.gt':>6} {'P.cos':>7} {'P.cin':>7} {'P.cout':>7} {'P.fired':>7}")
for vrow, prow in zip(v, p):
print(
f"{vrow['step']:>4} {vrow['src']:>14} "
f"{vrow['hack']:>8} {vrow['gt']:>6} {vrow['cos_mean']:+.3f} {vrow['cos_in']:+.3f} {vrow['cos_out']:+.3f} {vrow['fired']:.2f} "
f"{prow['hack']:>8} {prow['gt']:>6} {prow['cos_mean']:+.3f} {prow['cos_in']:+.3f} {prow['cos_out']:+.3f} {prow['fired']:.2f}"
)
# Phase summary: replay vs gen
print(f"\n{'='*120}")
print("Phase summary")
print(f"{'='*120}")
def phase_stats(rows, phase_pred):
ps = [r for r in rows if phase_pred(r)]
if not ps: return None
hack_total = sum(int(r["hack"].split("/")[0]) for r in ps)
n_total = sum(int(r["hack"].split("/")[1]) for r in ps)
gt_total = sum(int(r["gt"].split("/")[0]) for r in ps)
cins = [r["cos_in"] for r in ps if isinstance(r["cos_in"], (int,float))]
return {
"n_steps": len(ps),
"hack_rate": hack_total/max(1,n_total),
"gt_rate": gt_total/max(1,n_total),
"cin_mean": sum(cins)/max(1,len(cins)) if cins else float("nan"),
}
is_replay = lambda r: "teacher_pool" in r["src"] or "base_pool" in r["src"]
is_gen = lambda r: r["src"] == "student_gen" or r["src"] is None
for label, rows in [("vanilla", v), ("projected", p)]:
rep = phase_stats(rows, is_replay)
gen = phase_stats(rows, is_gen)
print(f"\n{label}:")
if rep:
print(f" warmup replay (n_steps={rep['n_steps']:2d}): hack_rate={rep['hack_rate']:.3f} gt_rate={rep['gt_rate']:.3f} cos_in_mean={rep['cin_mean']:+.4f}")
if gen:
print(f" student gen (n_steps={gen['n_steps']:2d}): hack_rate={gen['hack_rate']:.3f} gt_rate={gen['gt_rate']:.3f} cos_in_mean={gen['cin_mean']:+.4f}")
# Headline H1 prediction
v_gen = phase_stats(v, is_gen)
p_gen = phase_stats(p, is_gen)
if v_gen and p_gen:
print(f"\n{'='*120}")
print(f"H1 prediction: projected gen-phase hack rate < vanilla gen-phase hack rate")
print(f"{'='*120}")
print(f" vanilla: {v_gen['hack_rate']:.3f}")
print(f" projected: {p_gen['hack_rate']:.3f}")
delta = v_gen['hack_rate'] - p_gen['hack_rate']
print(f" delta: {delta:+.3f} ({'PASS' if delta > 0 else 'FAIL or null'})")
if __name__ == "__main__":
main(*(sys.argv[1:3] if len(sys.argv) >= 3 else ()))