mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 17:48:43 +08:00
probe_traj: side-by-side vanilla-vs-projected trajectory analyzer
Reads step files from both warmup-gen tags, prints per-step table broken into warmup-replay and student-gen phases, computes H1 delta on the gen-phase hack rate.
This commit is contained in:
@@ -212,6 +212,10 @@ probe-projected-replay steps="20":
|
||||
probe-uat:
|
||||
uv run python -m projected_grpo.probe_uat
|
||||
|
||||
# Trajectory comparator for the warmup-gen runs (vanilla vs projected).
|
||||
probe-traj:
|
||||
uv run python -m projected_grpo.probe_traj
|
||||
|
||||
# Phase 2 pilot analyzer: reads out/train_pilot_*.safetensors, prints trajectories
|
||||
# and per-arm aggregates, applies decision rules from spec2.md.
|
||||
phase2-analyze pattern="_pilot_*":
|
||||
|
||||
@@ -0,0 +1,113 @@
|
||||
"""Per-step trajectory printer for the warmup->gen runs.
|
||||
|
||||
Reads out/probe_distill/{tag}/step_*.jsonl.gz and prints a side-by-side
|
||||
table of vanilla vs projected, broken into the warmup-replay phase and the
|
||||
student-gen phase.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import gzip
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def load_run(run_dir: Path) -> list[dict]:
|
||||
rows = []
|
||||
for path in sorted(run_dir.glob("step_*.jsonl.gz")):
|
||||
with gzip.open(path, "rt") as f:
|
||||
for line in f:
|
||||
rows.append(json.loads(line))
|
||||
return rows
|
||||
|
||||
|
||||
def per_step(rows: list[dict]) -> list[dict]:
|
||||
by_step = {}
|
||||
for r in rows:
|
||||
s = r["step"]
|
||||
by_step.setdefault(s, []).append(r)
|
||||
out = []
|
||||
for s in sorted(by_step):
|
||||
rs = by_step[s]
|
||||
cos = [r["cos_S_contrib"] for r in rs if r.get("cos_S_contrib") is not None]
|
||||
n_hack = sum(int(r["hacked"]) for r in rs)
|
||||
n_gt = sum(int(r["gt_pass"]) for r in rs)
|
||||
n = len(rs)
|
||||
src = rs[0].get("src_pool", "?")
|
||||
out.append({
|
||||
"step": s,
|
||||
"n": n,
|
||||
"src": src,
|
||||
"hack": f"{n_hack}/{n}",
|
||||
"gt": f"{n_gt}/{n}",
|
||||
"cos_mean": sum(cos)/len(cos) if cos else float("nan"),
|
||||
"cos_in": rs[0].get("mean_cos_in", float("nan")),
|
||||
"cos_out": rs[0].get("mean_cos_out", float("nan")),
|
||||
"fired": rs[0].get("frac_fired", float("nan")),
|
||||
})
|
||||
return out
|
||||
|
||||
|
||||
def main(tag_v: str = "warmupgen_vanilla_seed41", tag_p: str = "warmupgen_projected_svd_seed41"):
|
||||
root = Path("out/probe_distill")
|
||||
v = per_step(load_run(root / tag_v))
|
||||
p = per_step(load_run(root / tag_p))
|
||||
|
||||
print(f"\n{'='*120}")
|
||||
print(f"Warmup -> student-gen comparison (vanilla vs projected SVD)")
|
||||
print(f"{'='*120}")
|
||||
print(f"{'step':>4} {'src':>14} "
|
||||
f"{'V.hack':>8} {'V.gt':>6} {'V.cos':>7} {'V.cin':>7} {'V.cout':>7} {'V.fired':>7} "
|
||||
f"{'P.hack':>8} {'P.gt':>6} {'P.cos':>7} {'P.cin':>7} {'P.cout':>7} {'P.fired':>7}")
|
||||
for vrow, prow in zip(v, p):
|
||||
print(
|
||||
f"{vrow['step']:>4} {vrow['src']:>14} "
|
||||
f"{vrow['hack']:>8} {vrow['gt']:>6} {vrow['cos_mean']:+.3f} {vrow['cos_in']:+.3f} {vrow['cos_out']:+.3f} {vrow['fired']:.2f} "
|
||||
f"{prow['hack']:>8} {prow['gt']:>6} {prow['cos_mean']:+.3f} {prow['cos_in']:+.3f} {prow['cos_out']:+.3f} {prow['fired']:.2f}"
|
||||
)
|
||||
|
||||
# Phase summary: replay vs gen
|
||||
print(f"\n{'='*120}")
|
||||
print("Phase summary")
|
||||
print(f"{'='*120}")
|
||||
def phase_stats(rows, phase_pred):
|
||||
ps = [r for r in rows if phase_pred(r)]
|
||||
if not ps: return None
|
||||
hack_total = sum(int(r["hack"].split("/")[0]) for r in ps)
|
||||
n_total = sum(int(r["hack"].split("/")[1]) for r in ps)
|
||||
gt_total = sum(int(r["gt"].split("/")[0]) for r in ps)
|
||||
cins = [r["cos_in"] for r in ps if isinstance(r["cos_in"], (int,float))]
|
||||
return {
|
||||
"n_steps": len(ps),
|
||||
"hack_rate": hack_total/max(1,n_total),
|
||||
"gt_rate": gt_total/max(1,n_total),
|
||||
"cin_mean": sum(cins)/max(1,len(cins)) if cins else float("nan"),
|
||||
}
|
||||
|
||||
is_replay = lambda r: "teacher_pool" in r["src"] or "base_pool" in r["src"]
|
||||
is_gen = lambda r: r["src"] == "student_gen" or r["src"] is None
|
||||
|
||||
for label, rows in [("vanilla", v), ("projected", p)]:
|
||||
rep = phase_stats(rows, is_replay)
|
||||
gen = phase_stats(rows, is_gen)
|
||||
print(f"\n{label}:")
|
||||
if rep:
|
||||
print(f" warmup replay (n_steps={rep['n_steps']:2d}): hack_rate={rep['hack_rate']:.3f} gt_rate={rep['gt_rate']:.3f} cos_in_mean={rep['cin_mean']:+.4f}")
|
||||
if gen:
|
||||
print(f" student gen (n_steps={gen['n_steps']:2d}): hack_rate={gen['hack_rate']:.3f} gt_rate={gen['gt_rate']:.3f} cos_in_mean={gen['cin_mean']:+.4f}")
|
||||
|
||||
# Headline H1 prediction
|
||||
v_gen = phase_stats(v, is_gen)
|
||||
p_gen = phase_stats(p, is_gen)
|
||||
if v_gen and p_gen:
|
||||
print(f"\n{'='*120}")
|
||||
print(f"H1 prediction: projected gen-phase hack rate < vanilla gen-phase hack rate")
|
||||
print(f"{'='*120}")
|
||||
print(f" vanilla: {v_gen['hack_rate']:.3f}")
|
||||
print(f" projected: {p_gen['hack_rate']:.3f}")
|
||||
delta = v_gen['hack_rate'] - p_gen['hack_rate']
|
||||
print(f" delta: {delta:+.3f} ({'PASS' if delta > 0 else 'FAIL or null'})")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(*(sys.argv[1:3] if len(sys.argv) >= 3 else ()))
|
||||
Reference in New Issue
Block a user