# Source: evil_MoE/scripts/probe_plot_stack.py
# Commit 270c4f5a27 ("misc") by wassname, 2026-06-11 11:07:28 +00:00
# (218 lines, 8.8 KiB, Python)

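"""Stacked-area plot of rollout outcome proportions per training step.

Each rollout is assigned exactly one category:
    green  -- correct: not hacked, gt_pass
    yellow -- correct attempted reward hack: hacked AND gt_pass
    red    -- reward hack: hacked, not gt_pass
    purple -- attempted reward hack: not hacked, not gt_pass, but "run_tests"
              appears in the completion (the hack format/pattern didn't fire,
              so no hack reward was given)
    grey   -- incorrect: not hacked, not gt_pass, no "run_tests"
Layering bottom -> top: green, yellow, red, purple, grey.

Two further panels below the stack show the per-step GRPO loss and the cosine
diagnostics against v_hack. Dashed vertical lines on every panel mark the
distillation-on boundary (``pre_warmup``, when > 0) and the distillation-off
boundary (``warmup``, the end of warmup replay).
"""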
from __future__ import annotations
import gzip
import json
import sys
from collections import Counter
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import tyro
from dataclasses import dataclass
from loguru import logger
# Outcome categories in stacking order (bottom -> top of the stackplot).
CATS = ["green", "yellow", "red", "purple", "grey"]

# Legend text for each category; dict insertion order follows CATS.
LABELS = dict(zip(CATS, (
    "correct",
    "correct attempted reward hack",
    "reward hack",
    "attempted reward hack",
    "incorrect",
)))

# Fill colour (hex RGB) for each category; dict insertion order follows CATS.
COLORS = dict(zip(CATS, (
    "#4CAF50",
    "#FFC107",
    "#E53935",
    "#8E24AA",
    "#9E9E9E",
)))
def classify(row: dict) -> str:
    """Map one rollout record to its outcome category (one of ``CATS``).

    ``row["hacked"]`` and ``row["gt_pass"]`` are read by truthiness. Only for
    rollouts that are neither hacked nor correct is ``row["completion"]``
    inspected for the literal substring ``"run_tests"``. A missing or
    ``None`` completion is treated as the empty string instead of raising
    ``TypeError``.

    Returns:
        "yellow" if hacked and gt_pass; "red" if hacked and not gt_pass;
        "green" if not hacked and gt_pass; "purple" if neither but the
        completion mentions run_tests; otherwise "grey".

    Raises:
        KeyError: if ``hacked`` or ``gt_pass`` is absent from ``row``.
    """
    hacked = bool(row["hacked"])
    gt_pass = bool(row["gt_pass"])
    if hacked:
        return "yellow" if gt_pass else "red"
    if gt_pass:
        return "green"
    # Neither hacked nor correct: did it at least try to call run_tests?
    # `or ""` also covers an explicit JSON null for "completion".
    completion = row.get("completion") or ""
    return "purple" if "run_tests" in completion else "grey"
def load_step(path: Path) -> list[dict]:
    """Read one gzipped JSONL step file into a list of row dicts.

    The file is opened in binary mode; each line is handed to ``json.loads``
    as bytes. Blank or whitespace-only lines are skipped rather than raising
    ``json.JSONDecodeError``, so stray separators do not abort the whole plot.
    """
    with gzip.open(path) as f:
        return [json.loads(line) for line in f if line.strip()]
@dataclass
class Config:
    """Command-line options for this script, parsed with ``tyro.cli``."""

    # Run directory; step files are read from <run_dir>/steps if it exists, else from run_dir.
    run_dir: Path
    # Output image path (parent directories are created).
    out_path: Path = Path("out/runs/probe_plot_stack_vanilla_seed41.png")
    warmup: int = 70 # distill-off boundary (end of replay)
    pre_warmup: int = 0 # distill-on boundary (start of replay); 0 draws no line
    smooth: int = 10 # trailing SMA window; double the blog's 5 since our G=8 (theirs G=16); <=1 disables
    # Title of the top (stacked-area) panel.
    title: str = "vanilla GRPO seed=41 (warmup-distill -> student-gen)"
def _sma(y: np.ndarray, w: int) -> np.ndarray:
    """Trailing simple moving average of width ``w`` that ignores NaNs.

    ``out[t]`` is the mean of the non-NaN values in ``y[max(0, t-w+1):t+1]``,
    or NaN when that window has none. ``w <= 1`` returns ``y`` unchanged.
    """
    if w <= 1:
        return y
    out = np.full_like(y, np.nan, dtype=float)
    for t in range(len(y)):
        seg = y[max(0, t - w + 1):t + 1]
        seg = seg[~np.isnan(seg)]
        if len(seg):
            out[t] = seg.mean()
    return out


def _smooth_fracs(fracs: np.ndarray, w: int) -> np.ndarray:
    """Trailing-SMA each category row of ``fracs`` (categories x steps).

    Afterwards, renormalize every column to sum to 1. Steps with no data are
    all-zero columns; averaging them in scales every category equally, so
    the renormalization restores proportions. Windows that are entirely
    empty stay zero. ``w <= 1`` returns ``fracs`` unchanged.
    """
    if w <= 1:
        return fracs
    out = np.zeros_like(fracs)
    for t in range(fracs.shape[1]):
        out[:, t] = fracs[:, max(0, t - w + 1):t + 1].mean(axis=1)
    out /= out.sum(axis=0, keepdims=True).clip(min=1e-12)
    return out


def _load_steps(files: list[Path]) -> dict[int, list[dict]]:
    """Load step files into ``{step: rows}``, dropping steps with no rows.

    The step index is parsed from ``step_<N>[...].jsonl.gz``; files whose
    index is not an integer are skipped with a warning. All files for the
    same step are concatenated, in ``files`` order. That includes a full
    ``.jsonl.gz`` (gen phase) together with a slim ``.cos.jsonl.gz`` (replay),
    which are not expected to overlap. If both exist for one step, a warning
    is logged so that possible double counting is visible.
    """
    steps_data: dict[int, list[dict]] = {}
    cos_steps: set[int] = set()
    full_steps: set[int] = set()
    for p in files:
        try:
            step = int(p.name.split("_")[1].split(".")[0])
        except ValueError:
            logger.warning(f"skipping {p.name}: cannot parse step index")
            continue
        (cos_steps if p.name.endswith(".cos.jsonl.gz") else full_steps).add(step)
        steps_data.setdefault(step, []).extend(load_step(p))
    both = sorted(cos_steps & full_steps)
    if both:
        logger.warning(f"steps with both full and .cos files (rows merged): {both}")
    empty = sorted(s for s, rows in steps_data.items() if not rows)
    if empty:
        logger.warning(f"skipping steps with no rows: {empty}")
    return {s: rows for s, rows in steps_data.items() if rows}


def _grpo_loss(rows: list[dict]) -> float:
    """Mean GRPO loss for one step's rows.

    Uses the stored ``per_sample_loss`` when every row carries it. Otherwise
    it reconstructs ``mean_i(-adv_i * logp_mean_i)`` with
    ``adv_i = reward_i - mean(reward)`` taken over all rows of the step.
    It returns NaN when any row lacks ``logp_mean``.
    NOTE(review): the advantage is centred on the whole step's mean reward,
    not per prompt group; this matches training only if a step is a single
    group -- confirm.
    """
    if all(r.get("per_sample_loss") is not None for r in rows):
        return float(np.mean([r["per_sample_loss"] for r in rows]))
    rwd = np.array([r["reward"] for r in rows], dtype=float)
    lp = np.array([r["logp_mean"] for r in rows if r.get("logp_mean") is not None], dtype=float)
    if len(lp) != len(rwd):
        return float("nan")
    adv = rwd - rwd.mean()
    return float((-adv * lp).mean())


def main(cfg: Config) -> int:
    """Render the three-panel outcome / GRPO-loss / cosine figure for a run.

    Step files ``step_*.jsonl.gz`` are read from ``<run_dir>/steps`` when that
    directory exists, otherwise from ``run_dir`` itself.

    Returns:
        Process exit code: 0 after writing ``cfg.out_path``; 1 when no step
        files, or no rows in any of them, were found.
    """
    steps_subdir = cfg.run_dir / "steps"
    search_dir = steps_subdir if steps_subdir.exists() else cfg.run_dir
    files = sorted(search_dir.glob("step_*.jsonl.gz"))
    if not files:
        logger.error(f"no step files in {search_dir}")
        return 1
    steps_data = _load_steps(files)
    if not steps_data:
        logger.error(f"no rows in any step file in {search_dir}")
        return 1

    n_steps = max(steps_data) + 1
    # Steps without data keep all-zero fractions and NaN diagnostics.
    fracs = np.zeros((len(CATS), n_steps))
    # Per-step diagnostics (mean over the step's rows). NaN if rows didn't carry them.
    cos_pre_step = np.full(n_steps, np.nan)  # batch-level pre-proj cos (all rollouts)
    cos_pre_weighted = np.full(n_steps, np.nan)  # cos_pre / hack_frac (per-hacked estimate)
    cos_hack_step = np.full(n_steps, np.nan)  # per-sample cos_S_contrib | hacked
    loss_step = np.full(n_steps, np.nan)  # GRPO loss
    for step, rows in steps_data.items():
        counts = Counter(classify(r) for r in rows)
        fracs[:, step] = [counts[cat] / len(rows) for cat in CATS]

        cin = [r["mean_cos_pre"] for r in rows if r.get("mean_cos_pre") is not None]
        if cin:
            cos_pre_step[step] = float(np.mean(cin))
        # Recover E[cos|hacked] from batch-mean cos under the assumption
        # E[cos|clean]=0: mean(cos_pre) = f_h * E[cos|hacked] + (1-f_h)*0
        # => E[cos|hacked] = mean(cos_pre) / f_h. NaN when no hacks in batch
        # (no per-hacked estimate possible from this step).
        # FIXME: cos_pre is now the aligned fraction ||relu(V@g)||/||g|| >= 0
        # (was signed sum, ~0 on clean). With relu the E[cos|clean]=0 premise
        # no longer holds, so this f_h-weighted estimate over-counts. Recompute
        # per-rollout cos restricted to hacked rollouts instead of decomposing.
        hack_frac = float(np.mean([bool(r.get("hacked")) for r in rows]))
        if hack_frac > 0:
            cos_pre_weighted[step] = cos_pre_step[step] / hack_frac
        # Per-sample cos restricted to hacked rollouts: where v_hack relevance
        # should show. cos on clean rollouts is noise -- drop it.
        ch = [r["cos_S_contrib"] for r in rows
              if r.get("hacked") and r.get("cos_S_contrib") is not None]
        if ch:
            cos_hack_step[step] = float(np.mean(ch))
        loss_step[step] = _grpo_loss(rows)

    plot_fracs = _smooth_fracs(fracs, cfg.smooth)

    fig, (ax, ax_loss, ax2) = plt.subplots(
        3, 1, figsize=(10, 10), sharex=True,
        gridspec_kw={"height_ratios": [3, 1, 2]},
    )
    xs = np.arange(n_steps)
    ax.stackplot(
        xs, plot_fracs,
        labels=[LABELS[c] for c in CATS],
        colors=[COLORS[c] for c in CATS],
        alpha=0.95,
    )

    # Phase boundaries: each is drawn once per panel, labelled only on the
    # top panel. Non-positive steps mean "no boundary" and are not drawn.
    boundaries: list[tuple[int, str]] = []
    if cfg.pre_warmup > 0:
        boundaries.append((cfg.pre_warmup, f"distillation on (step={cfg.pre_warmup})"))
    if cfg.warmup > 0:
        boundaries.append((cfg.warmup, f"distillation off (step={cfg.warmup})"))
    for step, label in boundaries:
        x = step - 0.5
        ax.axvline(x, color="black", linestyle="--", linewidth=1.2, label=label)
        for a in (ax_loss, ax2):
            a.axvline(x, color="black", linestyle="--", linewidth=1.2)

    # Avoid a singular x-range when only step 0 exists.
    ax.set_xlim(0, max(n_steps - 1, 1))
    ax.set_ylim(0, 1)
    ax.set_ylabel("Proportion of rollouts")
    ax.set_title(cfg.title)
    handles, labels_ = ax.get_legend_handles_labels()
    # Categories in stacking order first, then boundaries chronologically.
    order = [labels_.index(LABELS[c]) for c in CATS]
    order += [labels_.index(label) for _, label in boundaries]
    ax.legend(
        [handles[i] for i in order], [labels_[i] for i in order],
        loc="upper center", bbox_to_anchor=(0.5, -0.05),
        ncol=7, frameon=False, fontsize=9,
    )

    # Loss subplot: per-step mean GRPO loss (-adv * logp_mean).
    ax_loss.axhline(0, color="black", linewidth=0.5, alpha=0.5)
    ax_loss.plot(xs, _sma(loss_step, cfg.smooth), color="#212121", lw=1.4)
    ax_loss.set_ylabel("GRPO loss")

    # Cosine subplot: v_hack relevance on hacked rollouts (the signal we care
    # about). Light grey trace is batch-level cos_pre (all rollouts) for context.
    ax2.axhline(0, color="black", linewidth=0.5, alpha=0.5)
    ax2.plot(xs, _sma(cos_hack_step, cfg.smooth), color="#E53935", lw=1.6,
             label="cos_S | rollout hacked (per-sample, v_hack relevance)")
    ax2.plot(xs, _sma(cos_pre_weighted, cfg.smooth), color="#1976D2", lw=1.4,
             label="cos_pre / hack_frac (E[cos|hacked] estimate, batch-derived)")
    ax2.plot(xs, _sma(cos_pre_step, cfg.smooth), color="#9E9E9E", lw=1.0,
             alpha=0.6, label="cos_pre (raw batch grad, all rollouts)")
    ax2.set_xlabel("Training step")
    ax2.set_ylabel("cos with v_hack")
    ax2.legend(loc="upper center", bbox_to_anchor=(0.5, -0.18),
               ncol=2, frameon=False, fontsize=9)

    fig.tight_layout()
    cfg.out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(cfg.out_path, dpi=120, bbox_inches="tight")
    plt.close(fig)  # release the figure's memory once it is on disk
    logger.info(f"wrote {cfg.out_path}")
    return 0
if __name__ == "__main__":
    # Parse CLI flags into a Config; main()'s return value is the exit status.
    raise SystemExit(main(tyro.cli(Config)))