phase2_analyze: read pilot checkpoints, print trajectories + decision

Aggregates cin_mean / cout_mean / fired / frac_out_lt_in across seeds
for vanilla and projected arms. Applies spec2.md decision rules:
  vanilla cin>0.2 -> Phase 3 strongly justified
  cin~0           -> v_hack maybe orthogonal; consider R7
  projected out<in on >=80% steps -> mechanism active

justfile recipe: phase2-analyze [pattern]
This commit is contained in:
wassname
2026-05-25 12:02:35 +00:00
parent 9c886428bf
commit 1e1b032c31
2 changed files with 140 additions and 0 deletions
+5
View File
@@ -179,6 +179,11 @@ probe-projected-replay steps="20":
probe-uat:
uv run python -m projected_grpo.probe_uat
# Phase 2 pilot analyzer: reads out/train_pilot_*.safetensors, prints trajectories
# and per-arm aggregates, applies decision rules from spec2.md.
phase2-analyze pattern="_pilot_*":
uv run python -m projected_grpo.phase2_analyze "{{ pattern }}"
# Print the results table prototype.
table-proto:
@cat docs/table_proto.md
+135
View File
@@ -0,0 +1,135 @@
"""Phase 2 pilot analyzer.
Reads out/train{tag}.safetensors checkpoints (saved by train.py every 25 steps
+ at end) and prints per-step trajectories of (rew, gt, hack, loss, cos_in,
cos_out, fired) for vanilla vs projected, seed by seed.
Decision rules per spec2.md:
- vanilla cos_in > 0.2 consistently -> H1 likely; Phase 3 justified
- vanilla cos_in ~ 0 over all steps -> v_hack orthogonal to GRPO grad
- projected cos_out < cos_in on >= 80% of steps (averaged across seeds) -> mechanism active
- projected hack rate < vanilla at matched pass -> H1 fires (won't see in
20 steps; paper hacks emerge ~step 80)
"""
from __future__ import annotations
import json
import sys
from pathlib import Path
from safetensors import safe_open
from loguru import logger
def load_run(path: Path) -> tuple[dict, list[dict]]:
    """Read the run config and per-step records stored in a checkpoint's metadata.

    Only the safetensors metadata header is consulted. It is expected to hold
    two JSON-encoded strings under the keys ``"cfg"`` and ``"rows"``; a missing
    key (or a file without metadata) yields an empty dict / empty list.

    Args:
        path: Checkpoint file to open.

    Returns:
        ``(cfg, rows)`` -- the decoded config dict and the list of per-step
        row dicts.
    """
    with safe_open(str(path), framework="pt", device="cpu") as handle:
        # metadata() may be None when the file carries no header metadata.
        header = handle.metadata() or {}
    run_cfg = json.loads(header.get("cfg", "{}"))
    step_rows = json.loads(header.get("rows", "[]"))
    return run_cfg, step_rows
def fmt_traj(rows: list[dict]) -> str:
    """Render per-step rows as a plain-text table, one line per step.

    Each row must provide ``step`` (int), ``rew``, ``loss``, ``cin``, ``cout``
    and ``fired`` (numbers) plus ``gt`` and ``hack`` (strings); the format
    specs used below raise on other types.

    Args:
        rows: Per-step records in display order.

    Returns:
        The header line followed by one formatted line per row, joined with
        newlines. With no rows, only the header is returned.
    """
    header = "step rew gt hack loss cin cout fired"
    body = [
        (
            f" {row['step']:2d} {row['rew']:+.2f} {row['gt']:>6s} {row['hack']:>6s} "
            f"{row['loss']:+.4f} {row['cin']:+.3f} {row['cout']:+.3f} {row['fired']:.2f}"
        )
        for row in rows
    ]
    return "\n".join([header, *body])
def aggregate(rows: list[dict]) -> dict:
    """Summarise one run's per-step rows into scalar statistics.

    A value counts as a sample only when it is an ``int`` or ``float``
    instance; anything else (e.g. ``None`` or a string) is skipped for the
    statistic in question.

    Args:
        rows: Per-step records; each must contain the keys ``cin``, ``cout``
            and ``fired``, and the last one also ``hack`` and ``gt``.

    Returns:
        ``{}`` when ``rows`` is empty, otherwise a dict with:

        * ``n_steps`` -- number of rows.
        * ``cin_mean`` / ``cin_min`` / ``cin_max`` -- statistics over numeric
          ``cin`` values.
        * ``cout_mean`` / ``fired_mean`` -- means over numeric values.
        * ``frac_out_lt_in`` -- share of *all* rows where both values are
          numeric and ``cout < cin`` (rows with a non-numeric value count as
          "not lower").
        * ``last_hack`` / ``last_gt`` -- ``hack`` and ``gt`` of the final row.

        Every mean/min/max is NaN when there is no numeric sample, so that a
        missing measurement cannot be mistaken for a measured 0.0.
    """
    if not rows:
        return {}

    nan = float("nan")

    def numeric(key: str) -> list:
        """Collect the int/float values stored under ``key``."""
        return [r[key] for r in rows if isinstance(r[key], (int, float))]

    def mean(values: list) -> float:
        """Arithmetic mean, or NaN for an empty sample."""
        return sum(values) / len(values) if values else nan

    cin = numeric("cin")
    cout = numeric("cout")
    fired = numeric("fired")
    n_steps = len(rows)

    n_out_lt_in = sum(
        1
        for r in rows
        if isinstance(r["cout"], (int, float))
        and isinstance(r["cin"], (int, float))
        and r["cout"] < r["cin"]
    )

    return {
        "n_steps": n_steps,
        "cin_mean": mean(cin),
        "cin_min": min(cin) if cin else nan,
        "cin_max": max(cin) if cin else nan,
        "cout_mean": mean(cout),
        "fired_mean": mean(fired),
        # Denominator is every step, not only steps with a numeric pair.
        "frac_out_lt_in": n_out_lt_in / n_steps,
        "last_hack": rows[-1]["hack"],
        "last_gt": rows[-1]["gt"],
    }
def _vanilla_verdict(v_mean: float) -> list[str]:
    """Map the seed-averaged vanilla cos_in onto the decision message lines.

    Bands (checked top to bottom): NaN -> undefined; ``> 0.2`` strong;
    ``> 0.02`` weak positive; ``> -0.01`` near-zero; otherwise negative.
    The near-zero band reaches up to 0.02 so that small positive means are
    never reported as negative.
    """
    import math  # local import keeps this block self-contained

    if math.isnan(v_mean):
        return [" -> UNDEFINED: no numeric cos_in samples; cannot assess alignment."]
    if v_mean > 0.2:
        return [" -> STRONG signal: v_hack aligned with GRPO grad. Phase 3 justified."]
    if v_mean > 0.02:
        return [
            " -> WEAK positive signal at early steps. Expected since hacks emerge ~step 80.",
            " Phase 3 needed to see late-step regime.",
        ]
    if v_mean > -0.01:
        return [
            " -> NEAR-ZERO: v_hack ~ orthogonal to early-step GRPO grad. May still",
            " align later. Phase 3 risk: high. Consider R7 (re-extract v_hack with GRPO loss).",
        ]
    return [f" -> NEGATIVE ({v_mean:+.3f}): suspicious; investigate sign convention."]


def main(pattern: str = "_pilot_*") -> int:
    """Print the Phase 2 pilot report for every checkpoint matching ``pattern``.

    Checkpoints are looked up as ``out/train{pattern}.safetensors`` relative to
    the current working directory. For each file with at least one row the
    report shows an aggregate summary line and the per-step trajectory, then
    applies the decision rules to the "vanilla" and "projected" arms.

    Args:
        pattern: Glob fragment inserted between ``train`` and ``.safetensors``.

    Returns:
        ``1`` when no file matches the pattern, ``0`` otherwise (including the
        case where every matching file had no rows).
    """
    paths = sorted(Path("out").glob(f"train{pattern}.safetensors"))
    if not paths:
        print(f"no runs match out/train{pattern}.safetensors")
        return 1

    runs = []
    for p in paths:
        cfg, rows = load_run(p)
        if not rows:
            print(f"{p.name}: no rows")
            continue
        agg = aggregate(rows)
        agg["arm"] = cfg.get("arm")
        agg["seed"] = cfg.get("seed")
        agg["tag"] = cfg.get("out_tag", "")
        agg["path"] = p.name
        runs.append((cfg, rows, agg))

    rule = "=" * 90

    print(rule)
    print("Phase 2 pilot — aggregate summary")
    print(rule)
    print(f"{'tag':40s} {'arm':10s} {'n':>3s} {'cin_mean':>9s} {'cout_mean':>9s} {'fired':>5s} {'out<in':>6s} hack gt")
    for _, _, agg in runs:
        # `!s` first: arm/tag are None when the checkpoint's cfg lacks them,
        # and None does not accept an `s` format spec.
        print(f"{agg['tag']!s:40s} {agg['arm']!s:10s} {agg['n_steps']:>3d} "
              f"{agg['cin_mean']:+.4f} {agg['cout_mean']:+.4f} {agg['fired_mean']:.2f} "
              f"{agg['frac_out_lt_in']:.2f} {agg['last_hack']:>6s} {agg['last_gt']:>6s}")
    print()

    print(rule)
    print("Per-step trajectories")
    print(rule)
    for _, rows, agg in runs:
        print(f"\n--- {agg['tag']} ({agg['arm']} seed={agg['seed']}) ---")
        print(fmt_traj(rows))
    print()

    print(rule)
    print("Phase 2 / Phase 3 decision")
    print(rule)
    vanilla_cin = [agg["cin_mean"] for _, _, agg in runs if agg["arm"] == "vanilla"]
    proj_runs = [agg for _, _, agg in runs if agg["arm"] == "projected"]

    if vanilla_cin:
        v_mean = sum(vanilla_cin) / len(vanilla_cin)
        print(f"vanilla cos_in mean across seeds: {v_mean:+.4f}")
        for line in _vanilla_verdict(v_mean):
            print(line)

    if proj_runs:
        out_lt_in = [a["frac_out_lt_in"] for a in proj_runs]
        m = sum(out_lt_in) / len(out_lt_in)
        print(f"projected cos_out<cos_in fraction across seeds: {m:.2f}")
        if m >= 0.8:
            print(" -> Projection mechanism active.")
        else:
            print(f" -> Mechanism weak ({m:.2f}); investigate frac_fired / v_hack sign.")
    return 0
# Script entry point: the optional first CLI argument overrides the glob
# fragment; main()'s return value becomes the process exit status.
if __name__ == "__main__":
    cli_args = sys.argv[1:]
    chosen_pattern = cli_args[0] if cli_args else "_pilot_*"
    sys.exit(main(chosen_pattern))