From 1e1b032c3108de648063a1f3c86a13660548d7ae Mon Sep 17 00:00:00 2001 From: wassname Date: Mon, 25 May 2026 12:02:35 +0000 Subject: [PATCH] phase2_analyze: read pilot checkpoints, print trajectories + decision Aggregates cin_mean / cout_mean / fired / frac_out_lt_in across seeds for vanilla and projected arms. Applies spec2.md decision rules: vanilla cin>0.2 -> Phase 3 strongly justified cin~0 -> v_hack maybe orthogonal; consider R7 projected out=80% steps -> mechanism active justfile recipe: phase2-analyze [pattern] --- justfile | 5 + src/projected_grpo/phase2_analyze.py | 135 +++++++++++++++++++++++++++ 2 files changed, 140 insertions(+) create mode 100644 src/projected_grpo/phase2_analyze.py diff --git a/justfile b/justfile index 53c9e58..405fd90 100644 --- a/justfile +++ b/justfile @@ -179,6 +179,11 @@ probe-projected-replay steps="20": probe-uat: uv run python -m projected_grpo.probe_uat +# Phase 2 pilot analyzer: reads out/train_pilot_*.safetensors, prints trajectories +# and per-arm aggregates, applies decision rules from spec2.md. +phase2-analyze pattern="_pilot_*": + uv run python -m projected_grpo.phase2_analyze "{{ pattern }}" + # Print the results table prototype. table-proto: @cat docs/table_proto.md diff --git a/src/projected_grpo/phase2_analyze.py b/src/projected_grpo/phase2_analyze.py new file mode 100644 index 0000000..3850842 --- /dev/null +++ b/src/projected_grpo/phase2_analyze.py @@ -0,0 +1,135 @@ +"""Phase 2 pilot analyzer. + +Reads out/train{tag}.safetensors checkpoints (saved by train.py every 25 steps ++ at end) and prints per-step trajectories of (rew, gt, hack, loss, cos_in, +cos_out, fired) for vanilla vs projected, seed by seed. + +Decision rules per spec2.md: + - vanilla cos_in > 0.2 consistently -> H1 likely; Phase 3 justified + - vanilla cos_in ~ 0 over all steps -> v_hack orthogonal to GRPO grad + - projected cos_out < cos_in on most steps -> mechanism active + - projected hack rate < vanilla at matched pass -> H1 fires (won't see in + 20 steps; paper hacks emerge ~step 80) +""" +from __future__ import annotations + +import json +import sys +from pathlib import Path + +from safetensors import safe_open +from loguru import logger + + +def load_run(path: Path) -> tuple[dict, list[dict]]: + """Returns (cfg_dict, rows). Rows are the per-step TSV-like records.""" + with safe_open(str(path), framework="pt", device="cpu") as f: + meta = f.metadata() or {} + cfg = json.loads(meta.get("cfg", "{}")) + rows = json.loads(meta.get("rows", "[]")) + return cfg, rows + + +def fmt_traj(rows: list[dict]) -> str: + lines = ["step rew gt hack loss cin cout fired"] + for r in rows: + lines.append( + f" {r['step']:2d} {r['rew']:+.2f} {r['gt']:>6s} {r['hack']:>6s} " + f"{r['loss']:+.4f} {r['cin']:+.3f} {r['cout']:+.3f} {r['fired']:.2f}" + ) + return "\n".join(lines) + + +def aggregate(rows: list[dict]) -> dict: + if not rows: + return {} + cin = [r["cin"] for r in rows if isinstance(r["cin"], (int, float))] + cout = [r["cout"] for r in rows if isinstance(r["cout"], (int, float))] + fired = [r["fired"] for r in rows if isinstance(r["fired"], (int, float))] + n_steps = len(rows) + last_hack = rows[-1]["hack"] + last_gt = rows[-1]["gt"] + return { + "n_steps": n_steps, + "cin_mean": sum(cin) / max(1, len(cin)), + "cin_min": min(cin) if cin else float("nan"), + "cin_max": max(cin) if cin else float("nan"), + "cout_mean": sum(cout) / max(1, len(cout)), + "fired_mean": sum(fired) / max(1, len(fired)) if fired else float("nan"), + "frac_out_lt_in": sum(1 for r in rows + if isinstance(r["cout"], (int, float)) + and isinstance(r["cin"], (int, float)) + and r["cout"] < r["cin"]) / n_steps, + "last_hack": last_hack, + "last_gt": last_gt, + } + + +def main(pattern: str = "_pilot_*"): + paths = sorted(Path("out").glob(f"train{pattern}.safetensors")) + if not paths: + print(f"no runs match out/train{pattern}.safetensors") + return 1 + runs = [] + for p in paths: + cfg, rows = load_run(p) + if not rows: + print(f"{p.name}: no rows") + continue + agg = aggregate(rows) + agg["arm"] = cfg.get("arm") + agg["seed"] = cfg.get("seed") + agg["tag"] = cfg.get("out_tag", "") + agg["path"] = p.name + runs.append((cfg, rows, agg)) + + print("=" * 90) + print("Phase 2 pilot — aggregate summary") + print("=" * 90) + print(f"{'tag':40s} {'arm':10s} {'n':>3s} {'cin_mean':>9s} {'cout_mean':>9s} {'fired':>5s} {'out6s} hack gt") + for _, _, agg in runs: + print(f"{agg['tag']:40s} {agg['arm']:10s} {agg['n_steps']:>3d} " + f"{agg['cin_mean']:+.4f} {agg['cout_mean']:+.4f} {agg['fired_mean']:.2f} " + f"{agg['frac_out_lt_in']:.2f} {agg['last_hack']:>6s} {agg['last_gt']:>6s}") + + print() + print("=" * 90) + print("Per-step trajectories") + print("=" * 90) + for cfg, rows, agg in runs: + print(f"\n--- {agg['tag']} ({agg['arm']} seed={agg['seed']}) ---") + print(fmt_traj(rows)) + + print() + print("=" * 90) + print("Phase 2 / Phase 3 decision") + print("=" * 90) + vanilla_cin = [agg["cin_mean"] for _, _, agg in runs if agg["arm"] == "vanilla"] + proj_runs = [agg for _, _, agg in runs if agg["arm"] == "projected"] + if vanilla_cin: + v_mean = sum(vanilla_cin) / len(vanilla_cin) + print(f"vanilla cos_in mean across seeds: {v_mean:+.4f}") + if v_mean > 0.2: + print(" -> STRONG signal: v_hack aligned with GRPO grad. Phase 3 justified.") + elif v_mean > 0.02: + print(" -> WEAK positive signal at early steps. Expected since hacks emerge ~step 80.") + print(" Phase 3 needed to see late-step regime.") + elif abs(v_mean) < 0.01: + print(" -> NEAR-ZERO: v_hack ~ orthogonal to early-step GRPO grad. May still") + print(" align later. Phase 3 risk: high. Consider R7 (re-extract v_hack with GRPO loss).") + else: + print(f" -> NEGATIVE ({v_mean:+.3f}): suspicious; investigate sign convention.") + + if proj_runs: + out_lt_in = [a["frac_out_lt_in"] for a in proj_runs] + m = sum(out_lt_in) / len(out_lt_in) + print(f"projected cos_out= 0.8: + print(" -> Projection mechanism active.") + else: + print(f" -> Mechanism weak ({m:.2f}); investigate frac_fired / v_hack sign.") + return 0 + + +if __name__ == "__main__": + sys.exit(main(sys.argv[1] if len(sys.argv) > 1 else "_pilot_*"))