Files
evil_MoE/scripts/probe_plot_stack.py
T
wassname 4fa9061162 refactor: move 5 leaf entrypoints src/ -> scripts/ (src is now library-only)
verify_rewards, verify_vhack_heldout, build_substrate, probe_distill, probe_plot_stack
are run via 'python -m' / justfile and imported by no core module -> moved to scripts/,
relative imports rewritten to 'from projected_grpo.X'. probe_distill's sibling import
of probe_plot_stack is now a flat import (co-located in scripts/). regrade_pool stays
in src (pairs_from_pool imports load_problems_by_id from it). justfile recipes updated.

src/projected_grpo/ is now 16 importable modules: train + method (proj/vhack/antipasto/
extract_vhack_grad) + env (rewards/eval/problems/data) + pairs (pairs/pairs_from_pool/
regrade_pool/derisk_loopholes) + tablelog/figs. ~1480 lines moved out of the package.
Smoke green (verify_rewards 52/52 from scripts/, train pipeline cout->0).

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
2026-06-03 00:23:56 +00:00

218 lines
8.8 KiB
Python

"""Stacked-area plot of rollout outcome proportions per training step.
Categories per rollout:
green -- correct (no hack, gt_pass)
yellow -- correct attempted hack (hack present AND gt_pass)
red -- reward hack (hack present, no gt_pass)
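purple -- attempted hack (run_tests in completion, but `hacked`=False and
no gt_pass; the format/pattern check didn't fire, so no reward)
grey -- incorrect (no hack, no gt_pass, no run_tests)
Layering bottom->top: green, yellow, red, purple, grey.
A dashed vertical line marks the `warmup` (warmup_replay_steps) boundary,
i.e. distillation off; a second one marks `pre_warmup` (distillation on)
when it is > 0. Below the stack are a per-step GRPO-loss panel and a panel
of cosines with v_hack.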
"""
from __future__ import annotations
import gzip
import json
import sys
from collections import Counter
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import tyro
from dataclasses import dataclass
from loguru import logger
# One row per rollout category: (key, legend label, fill colour).
# Row order is the stacking order in the plot (bottom -> top) and also the
# order of the category entries in the legend.
_CATEGORY_SPECS = (
    ("green", "correct", "#4CAF50"),
    ("yellow", "correct attempted reward hack", "#FFC107"),
    ("red", "reward hack", "#E53935"),
    ("purple", "attempted reward hack", "#8E24AA"),
    ("grey", "incorrect", "#9E9E9E"),
)
CATS = [key for key, _label, _colour in _CATEGORY_SPECS]
LABELS = {key: label for key, label, _colour in _CATEGORY_SPECS}
COLORS = {key: colour for key, _label, colour in _CATEGORY_SPECS}
def classify(row: dict) -> str:
    """Map one rollout row to its outcome category key (one of ``CATS``).

    Uses the truthiness of ``row["hacked"]`` and ``row["gt_pass"]``. For rows
    that are neither hacked nor passing, the purple/grey split depends on
    whether the substring ``"run_tests"`` appears in ``row["completion"]``.
    A missing ``completion`` key counts as an empty string.
    """
    hacked = bool(row["hacked"])
    passed = bool(row["gt_pass"])
    # Evaluated up front (not only on the fall-through path) to keep the
    # original's evaluation order.
    mentions_run_tests = "run_tests" in row.get("completion", "")

    if hacked:
        return "yellow" if passed else "red"
    if passed:
        return "green"
    return "purple" if mentions_run_tests else "grey"
def load_step(path: Path) -> list[dict]:
    """Read one gzipped JSONL step file and return its rows in file order.

    Blank or whitespace-only lines (for example a doubled trailing newline)
    are skipped. Passing them to ``json.loads`` would raise ``JSONDecodeError``
    and abort the whole plot.
    """
    # gzip.open defaults to binary mode; json.loads accepts the bytes lines
    # directly, and bytes.strip() works for the blank-line test.
    with gzip.open(path) as f:
        return [json.loads(line) for line in f if line.strip()]
@dataclass
class Config:
    """Command-line options for the stacked-outcome plot (parsed by ``tyro.cli``).

    ``warmup`` and ``pre_warmup`` are training-step indices at which dashed
    phase-boundary lines are drawn across the panels.
    """

    # Run directory; step_*.jsonl.gz files are read from its steps/ subdir if present, else from it directly.
    run_dir: Path
    # Output PNG path; missing parent directories are created.
    out_path: Path = Path("out/runs/probe_plot_stack_vanilla_seed41.png")
    # Distill-off boundary: the step at which replay ends.
    warmup: int = 70
    # Distill-on boundary: the step at which replay starts (0 draws no line).
    pre_warmup: int = 0
    # Trailing moving-average window (<= 1 disables smoothing); double the blog's 5 since our G=8 (theirs G=16).
    smooth: int = 10
    # Title of the top (stacked-area) panel.
    title: str = "vanilla GRPO seed=41 (warmup-distill -> student-gen)"
def _sma(y: np.ndarray, w: int) -> np.ndarray:
    """Trailing simple moving average of window ``w`` that ignores NaNs.

    Element ``t`` is the mean of the non-NaN values in ``y[max(0, t-w+1) : t+1]``.
    It stays NaN when that window holds no non-NaN value. When ``w <= 1``,
    ``y`` itself is returned unchanged.
    """
    if w <= 1:
        return y
    out = np.full_like(y, np.nan, dtype=float)
    for t in range(len(y)):
        seg = y[max(0, t - w + 1):t + 1]
        seg = seg[~np.isnan(seg)]
        if len(seg):
            out[t] = seg.mean()
    return out


def _load_steps(files: list[Path]) -> dict[int, list[dict]]:
    """Group the rows of ``step_<N>[.<tag>].jsonl.gz`` files by integer step ``N``.

    Files that share a step (e.g. the full gen-phase file and a slim ``.cos``
    replay file) are concatenated, NOT de-duplicated. If both exist for one
    step, its rows are counted twice. The writers are expected never to
    produce both for the same step.
    """
    steps_data: dict[int, list[dict]] = {}
    for p in files:
        step = int(p.name.split("_")[1].split(".")[0])
        steps_data.setdefault(step, []).extend(load_step(p))
    return steps_data


def _step_metrics(
    steps_data: dict[int, list[dict]], n_steps: int
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Compute per-step category fractions and diagnostics, indexed by step number.

    Returns ``(fracs, cos_pre_step, cos_pre_weighted, cos_hack_step, loss_step)``.

    * ``fracs`` has shape ``(len(CATS), n_steps)``. Its columns are all zero
      for steps with no rows.
    * The four diagnostic arrays have shape ``(n_steps,)``. They stay NaN where
      a step had no rows or its rows lacked the needed fields.
    """
    fracs = np.zeros((len(CATS), n_steps))
    # Per-step diagnostics (means over the step's rows); NaN if rows didn't carry them.
    cos_pre_step = np.full(n_steps, np.nan)      # batch-level pre-proj cos (all rollouts)
    cos_pre_weighted = np.full(n_steps, np.nan)  # cos_pre / hack_frac (per-hacked estimate)
    cos_hack_step = np.full(n_steps, np.nan)     # per-sample cos_S_contrib | hacked
    loss_step = np.full(n_steps, np.nan)         # GRPO loss
    for step, rows in steps_data.items():
        if not rows:
            # An empty step file yields no rows. The fraction below would divide
            # by zero, so leave this step as a gap (same as a missing step).
            continue
        counts = Counter(classify(r) for r in rows)
        total = sum(counts.values())
        for i, cat in enumerate(CATS):
            fracs[i, step] = counts[cat] / total

        cin = [r["mean_cos_pre"] for r in rows if r.get("mean_cos_pre") is not None]
        if cin:
            cos_pre_step[step] = float(np.mean(cin))

        # Recover E[cos|hacked] from batch-mean cos under the assumption
        # E[cos|clean]=0: mean(cos_pre) = f_h * E[cos|hacked] + (1-f_h)*0
        # => E[cos|hacked] = mean(cos_pre) / f_h. NaN when no hacks in batch
        # (no per-hacked estimate possible from this step).
        # FIXME: cos_pre is now the hack-ward FRACTION ||relu(V@g)||/||g|| >= 0
        # (was signed sum, ~0 on clean). With relu the E[cos|clean]=0 premise
        # no longer holds, so this f_h-weighted estimate over-counts. Recompute
        # per-rollout cos restricted to hacked rollouts instead of decomposing.
        hack_frac = float(np.mean([bool(r.get("hacked")) for r in rows]))
        if hack_frac > 0:
            cos_pre_weighted[step] = cos_pre_step[step] / hack_frac

        # Per-sample cos restricted to hacked rollouts: where v_hack relevance
        # should show. cos on clean rollouts is noise -- drop it.
        ch = [r["cos_S_contrib"] for r in rows
              if r.get("hacked") and r.get("cos_S_contrib") is not None]
        if ch:
            cos_hack_step[step] = float(np.mean(ch))

        # GRPO loss: mean_i(-adv_i * logp_mean_i), adv_i = reward_i - mean(reward).
        # Prefer per_sample_loss when every row stored it; otherwise reconstruct
        # from reward + logp_mean, but only when every row has logp_mean.
        if all(r.get("per_sample_loss") is not None for r in rows):
            loss_step[step] = float(np.mean([r["per_sample_loss"] for r in rows]))
        else:
            rwd = np.array([r["reward"] for r in rows], dtype=float)
            lp = np.array([r["logp_mean"] for r in rows if r.get("logp_mean") is not None], dtype=float)
            if len(lp) == len(rwd):
                adv = rwd - rwd.mean()
                loss_step[step] = float((-adv * lp).mean())
    return fracs, cos_pre_step, cos_pre_weighted, cos_hack_step, loss_step


def _smooth_fracs(fracs: np.ndarray, w: int) -> np.ndarray:
    """Apply a trailing mean over ``w`` steps to each category row, then renormalise columns.

    Renormalising keeps every column summing to 1 even when the window
    includes all-zero (missing/empty) steps. ``w <= 1`` returns ``fracs``
    unchanged.
    """
    if w <= 1:
        return fracs
    smoothed = np.zeros_like(fracs)
    for t in range(fracs.shape[1]):
        lo = max(0, t - w + 1)
        smoothed[:, t] = fracs[:, lo:t + 1].mean(axis=1)
    smoothed /= smoothed.sum(axis=0, keepdims=True).clip(min=1e-12)
    return smoothed


def _draw_figure(
    cfg: Config,
    plot_fracs: np.ndarray,
    loss_step: np.ndarray,
    cos_hack_step: np.ndarray,
    cos_pre_weighted: np.ndarray,
    cos_pre_step: np.ndarray,
):
    """Build and return the 3-panel figure: outcome stack, GRPO loss, cosines with v_hack."""
    n_steps = plot_fracs.shape[1]
    xs = np.arange(n_steps)
    fig, (ax, ax_loss, ax2) = plt.subplots(
        3, 1, figsize=(10, 10), sharex=True,
        gridspec_kw={"height_ratios": [3, 1, 2]},
    )

    # Top panel: stacked outcome proportions, CATS[0] at the bottom.
    ax.stackplot(
        xs, plot_fracs,
        labels=[LABELS[c] for c in CATS],
        colors=[COLORS[c] for c in CATS],
        alpha=0.95,
    )

    # Phase boundaries as (step, legend label). Each line is drawn once per panel
    # and labelled only on the top panel, so the legend gets a single entry.
    boundaries: list[tuple[int, str]] = []
    if cfg.pre_warmup > 0:
        boundaries.append((cfg.pre_warmup, f"distillation on (step={cfg.pre_warmup})"))
    if cfg.warmup > 0:
        boundaries.append((cfg.warmup, f"distillation off (step={cfg.warmup})"))
    line_kw = {"color": "black", "linestyle": "--", "linewidth": 1.2}
    for step, label in boundaries:
        x = step - 0.5  # between step-1 and step on the integer step axis
        ax.axvline(x, label=label, **line_kw)
        ax_loss.axvline(x, **line_kw)
        ax2.axvline(x, **line_kw)

    # max(..., 1) avoids identical low/high x-limits when there is a single step.
    ax.set_xlim(0, max(n_steps - 1, 1))
    ax.set_ylim(0, 1)
    ax.set_ylabel("Proportion of rollouts")
    ax.set_title(cfg.title)
    handles, labels_ = ax.get_legend_handles_labels()
    # Legend order: categories in CATS order, then boundaries in drawing order.
    order = [labels_.index(LABELS[c]) for c in CATS]
    order += [labels_.index(label) for _, label in boundaries]
    ax.legend(
        [handles[i] for i in order], [labels_[i] for i in order],
        loc="upper center", bbox_to_anchor=(0.5, -0.05),
        ncol=7, frameon=False, fontsize=9,
    )

    # Loss subplot: per-step mean GRPO loss (-adv * logp_mean).
    ax_loss.axhline(0, color="black", linewidth=0.5, alpha=0.5)
    ax_loss.plot(xs, _sma(loss_step, cfg.smooth), color="#212121", lw=1.4)
    ax_loss.set_ylabel("GRPO loss")

    # Cosine subplot: v_hack relevance on hacked rollouts (the signal we care
    # about). Light grey trace is batch-level cos_pre (all rollouts) for context.
    ax2.axhline(0, color="black", linewidth=0.5, alpha=0.5)
    ax2.plot(xs, _sma(cos_hack_step, cfg.smooth), color="#E53935", lw=1.6,
             label="cos_S | rollout hacked (per-sample, v_hack relevance)")
    ax2.plot(xs, _sma(cos_pre_weighted, cfg.smooth), color="#1976D2", lw=1.4,
             label="cos_pre / hack_frac (E[cos|hacked] estimate, batch-derived)")
    ax2.plot(xs, _sma(cos_pre_step, cfg.smooth), color="#9E9E9E", lw=1.0,
             alpha=0.6, label="cos_pre (raw batch grad, all rollouts)")
    ax2.set_xlabel("Training step")
    ax2.set_ylabel("cos with v_hack")
    ax2.legend(loc="upper center", bbox_to_anchor=(0.5, -0.18),
               ncol=2, frameon=False, fontsize=9)
    fig.tight_layout()
    return fig


def main(cfg: Config) -> int:
    """Render the outcome-proportion / loss / cosine figure for one run.

    Step files (``step_*.jsonl.gz``) are read from ``cfg.run_dir / "steps"``
    when that directory exists, otherwise from ``cfg.run_dir`` itself. The
    PNG is written to ``cfg.out_path``.

    Returns:
        0 on success, 1 when no step files were found. The ``__main__`` guard
        passes this value to ``sys.exit``.
    """
    steps_subdir = cfg.run_dir / "steps"
    search_dir = steps_subdir if steps_subdir.exists() else cfg.run_dir
    files = sorted(search_dir.glob("step_*.jsonl.gz"))
    if not files:
        logger.error(f"no step files in {search_dir}")
        return 1

    steps_data = _load_steps(files)
    n_steps = max(steps_data) + 1  # dense 0..max axis; absent steps become gaps
    fracs, cos_pre_step, cos_pre_weighted, cos_hack_step, loss_step = _step_metrics(steps_data, n_steps)
    plot_fracs = _smooth_fracs(fracs, cfg.smooth)

    fig = _draw_figure(cfg, plot_fracs, loss_step, cos_hack_step, cos_pre_weighted, cos_pre_step)
    try:
        cfg.out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(cfg.out_path, dpi=120, bbox_inches="tight")
    finally:
        # pyplot keeps every figure alive until it is closed. Close explicitly so
        # repeated in-process calls to main() don't accumulate open figures.
        plt.close(fig)
    logger.info(f"wrote {cfg.out_path}")
    return 0
if __name__ == "__main__":
    # tyro builds the CLI from Config's fields; main()'s int result is the exit status.
    cli_config = tyro.cli(Config)
    sys.exit(main(cli_config))