mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 16:45:42 +08:00
phase2_analyze: read pilot checkpoints, print trajectories + decision
Aggregates cin_mean / cout_mean / fired / frac_out_lt_in across seeds for the vanilla and projected arms. Applies the spec2.md decision rules: (1) vanilla cin > 0.2 -> Phase 3 strongly justified; (2) cin ~ 0 -> v_hack may be orthogonal, consider R7; (3) projected out < in on >= 80% of steps -> mechanism active. Justfile recipe: phase2-analyze [pattern].
This commit is contained in:
@@ -179,6 +179,11 @@ probe-projected-replay steps="20":
|
||||
probe-uat:
|
||||
uv run python -m projected_grpo.probe_uat
|
||||
|
||||
# Phase 2 pilot analyzer: reads out/train_pilot_*.safetensors, prints trajectories
|
||||
# and per-arm aggregates, applies decision rules from spec2.md.
|
||||
# `pattern` is a glob fragment; the analyzer reads out/train{pattern}.safetensors.
phase2-analyze pattern="_pilot_*":
|
||||
uv run python -m projected_grpo.phase2_analyze "{{ pattern }}"
|
||||
|
||||
# Print the results table prototype.
|
||||
table-proto:
|
||||
@cat docs/table_proto.md
|
||||
|
||||
@@ -0,0 +1,135 @@
|
||||
"""Phase 2 pilot analyzer.
|
||||
|
||||
Reads out/train{tag}.safetensors checkpoints (saved by train.py every 25 steps
|
||||
+ at end) and prints per-step trajectories of (rew, gt, hack, loss, cos_in,
|
||||
cos_out, fired) for vanilla vs projected, seed by seed.
|
||||
|
||||
Decision rules per spec2.md:
|
||||
- vanilla cos_in > 0.2 consistently -> H1 likely; Phase 3 justified
|
||||
- vanilla cos_in ~ 0 over all steps -> v_hack orthogonal to GRPO grad
|
||||
- projected cos_out < cos_in on most steps -> mechanism active
|
||||
- projected hack rate < vanilla at matched pass -> H1 fires (won't see in
|
||||
20 steps; paper hacks emerge ~step 80)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from safetensors import safe_open
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def load_run(path: Path) -> tuple[dict, list[dict]]:
    """Load one run's config and per-step rows from a checkpoint header.

    Both values are stored as JSON strings under the ``cfg`` and ``rows``
    keys of the safetensors metadata. Missing metadata or a missing key
    yields an empty dict / empty list respectively.

    Returns:
        ``(cfg_dict, rows)`` where ``rows`` are the per-step records.
    """
    with safe_open(str(path), framework="pt", device="cpu") as handle:
        header = handle.metadata()
    if not header:
        header = {}
    raw_cfg = header.get("cfg", "{}")
    raw_rows = header.get("rows", "[]")
    return json.loads(raw_cfg), json.loads(raw_rows)
|
||||
|
||||
|
||||
def _fmt_cell(value, spec: str, width: int) -> str:
    """Format one trajectory cell, tolerating non-numeric values.

    Numbers are rendered with ``spec``. Anything else (``None``, a
    placeholder string, ...) is shown as text right-aligned to ``width`` so a
    single missing measurement does not abort the whole printout.
    """
    if isinstance(value, (int, float)):
        return format(value, spec)
    text = "n/a" if value is None else str(value)
    return text.rjust(width)


def fmt_traj(rows: list[dict]) -> str:
    """Render per-step records as a text table, one line per step.

    Args:
        rows: per-step records with keys ``step``, ``rew``, ``gt``, ``hack``,
            ``loss``, ``cin``, ``cout`` and ``fired``.

    Returns:
        A header line followed by one formatted line per row, joined by
        newlines.

    The numeric columns may hold non-numeric values (``aggregate`` filters
    them with ``isinstance`` for the same reason), so they are formatted
    defensively instead of applying a float format spec unconditionally.
    """
    lines = ["step rew gt hack loss cin cout fired"]
    for r in rows:
        lines.append(
            f" {r['step']:2d} {_fmt_cell(r['rew'], '+.2f', 5)} "
            # str() keeps the ':>6s' spec valid even if gt/hack are not strings.
            f"{str(r['gt']):>6s} {str(r['hack']):>6s} "
            f"{_fmt_cell(r['loss'], '+.4f', 7)} {_fmt_cell(r['cin'], '+.3f', 6)} "
            f"{_fmt_cell(r['cout'], '+.3f', 6)} {_fmt_cell(r['fired'], '.2f', 4)}"
        )
    return "\n".join(lines)
|
||||
|
||||
|
||||
def aggregate(rows: list[dict]) -> dict:
    """Summarise one run's per-step records.

    Args:
        rows: per-step records with keys ``cin``, ``cout``, ``fired``,
            ``hack`` and ``gt``.

    Returns:
        ``{}`` for an empty run, otherwise a dict with:
        ``n_steps``; ``cin_mean`` / ``cin_min`` / ``cin_max``; ``cout_mean``;
        ``fired_mean``; ``frac_out_lt_in`` (share of ALL steps where both
        cosines are numeric and ``cout < cin``); and ``last_hack`` /
        ``last_gt`` taken from the final row.

    Every statistic over a column with no usable samples is NaN. A made-up
    0.0 mean would be indistinguishable from a genuinely near-zero cosine.
    """
    import math  # local import: only this function needs it

    if not rows:
        return {}

    def samples(key: str) -> list[float]:
        # Keep finite numbers only: non-numeric placeholders are skipped, and
        # NaN would otherwise poison the mean and make min/max order-dependent.
        return [
            r[key]
            for r in rows
            if isinstance(r[key], (int, float)) and math.isfinite(r[key])
        ]

    def mean(values: list[float]) -> float:
        return sum(values) / len(values) if values else float("nan")

    cin = samples("cin")
    cout = samples("cout")
    fired = samples("fired")
    n_steps = len(rows)

    # Denominator stays n_steps: steps without a comparable pair count as
    # "not reduced", matching the original per-step fraction.
    n_reduced = sum(
        1
        for r in rows
        if isinstance(r["cout"], (int, float))
        and isinstance(r["cin"], (int, float))
        and r["cout"] < r["cin"]
    )

    return {
        "n_steps": n_steps,
        "cin_mean": mean(cin),
        "cin_min": min(cin) if cin else float("nan"),
        "cin_max": max(cin) if cin else float("nan"),
        "cout_mean": mean(cout),
        "fired_mean": mean(fired),
        "frac_out_lt_in": n_reduced / n_steps,
        "last_hack": rows[-1]["hack"],
        "last_gt": rows[-1]["gt"],
    }
|
||||
|
||||
|
||||
def _vanilla_verdict(v_mean: float) -> list[str]:
    """Map the seed-averaged vanilla cos_in onto the spec2.md decision text."""
    if v_mean > 0.2:
        return [" -> STRONG signal: v_hack aligned with GRPO grad. Phase 3 justified."]
    if v_mean > 0.02:
        return [
            " -> WEAK positive signal at early steps. Expected since hacks emerge ~step 80.",
            " Phase 3 needed to see late-step regime.",
        ]
    # Contiguous with the branch above: small positive means in [0.01, 0.02]
    # used to fall through to the NEGATIVE message.
    if v_mean > -0.01:
        return [
            " -> NEAR-ZERO: v_hack ~ orthogonal to early-step GRPO grad. May still",
            " align later. Phase 3 risk: high. Consider R7 (re-extract v_hack with GRPO loss).",
        ]
    return [f" -> NEGATIVE ({v_mean:+.3f}): suspicious; investigate sign convention."]


def _print_banner(title: str) -> None:
    """Print a section title framed by two 90-column rules."""
    print("=" * 90)
    print(title)
    print("=" * 90)


def main(pattern: str = "_pilot_*"):
    """Analyse all pilot checkpoints matching ``out/train{pattern}.safetensors``.

    Prints a per-run summary table, the per-step trajectory of every run, and
    the Phase 2 / Phase 3 decision derived from the vanilla and projected
    arms.

    Args:
        pattern: glob fragment inserted between ``train`` and
            ``.safetensors`` inside the ``out`` directory.

    Returns:
        Process exit code: 1 when no checkpoint matches, 0 otherwise.
    """
    from math import isfinite  # local import: only the decision step needs it

    paths = sorted(Path("out").glob(f"train{pattern}.safetensors"))
    if not paths:
        print(f"no runs match out/train{pattern}.safetensors")
        return 1

    runs = []
    for p in paths:
        cfg, rows = load_run(p)
        if not rows:
            print(f"{p.name}: no rows")
            continue
        agg = aggregate(rows)
        agg["arm"] = cfg.get("arm")
        agg["seed"] = cfg.get("seed")
        agg["tag"] = cfg.get("out_tag", "")
        agg["path"] = p.name
        runs.append((cfg, rows, agg))

    _print_banner("Phase 2 pilot — aggregate summary")
    print(f"{'tag':40s} {'arm':10s} {'n':>3s} {'cin_mean':>9s} {'cout_mean':>9s} {'fired':>5s} {'out<in':>6s} hack gt")
    for _, _, agg in runs:
        # str(): a cfg without arm/out_tag yields None, which rejects the
        # 's' format spec and would abort the whole report.
        print(f"{str(agg['tag']):40s} {str(agg['arm']):10s} {agg['n_steps']:>3d} "
              f"{agg['cin_mean']:+.4f} {agg['cout_mean']:+.4f} {agg['fired_mean']:.2f} "
              f"{agg['frac_out_lt_in']:.2f} {str(agg['last_hack']):>6s} {str(agg['last_gt']):>6s}")

    print()
    _print_banner("Per-step trajectories")
    for _, rows, agg in runs:
        print(f"\n--- {agg['tag']} ({agg['arm']} seed={agg['seed']}) ---")
        print(fmt_traj(rows))

    print()
    _print_banner("Phase 2 / Phase 3 decision")
    vanilla_runs = [agg for _, _, agg in runs if agg["arm"] == "vanilla"]
    proj_runs = [agg for _, _, agg in runs if agg["arm"] == "projected"]

    if vanilla_runs:
        # A run without usable cos_in must not drag the cross-seed mean to NaN.
        vanilla_cin = [a["cin_mean"] for a in vanilla_runs if isfinite(a["cin_mean"])]
        if vanilla_cin:
            v_mean = sum(vanilla_cin) / len(vanilla_cin)
            print(f"vanilla cos_in mean across seeds: {v_mean:+.4f}")
            for line in _vanilla_verdict(v_mean):
                print(line)
        else:
            print("vanilla cos_in mean across seeds: n/a (no numeric cos_in logged)")

    if proj_runs:
        out_lt_in = [a["frac_out_lt_in"] for a in proj_runs]
        m = sum(out_lt_in) / len(out_lt_in)
        print(f"projected cos_out<cos_in fraction across seeds: {m:.2f}")
        if m >= 0.8:
            print(" -> Projection mechanism active.")
        else:
            print(f" -> Mechanism weak ({m:.2f}); investigate frac_fired / v_hack sign.")
    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Optional first CLI argument overrides the default checkpoint pattern.
    cli_pattern = sys.argv[1] if len(sys.argv) > 1 else "_pilot_*"
    sys.exit(main(cli_pattern))
|
||||
Reference in New Issue
Block a user