"""Stacked-area plot of rollout outcome proportions per training step. Categories per rollout: green -- correct (no hack, gt_pass) yellow -- correct attempted hack (hack present AND gt_pass) red -- reward hack (hack present, no gt_pass) purple -- attempted hack (run_tests in completion but `hacked`=False; format/pattern didn't fire so no reward) grey -- incorrect (no hack, no gt_pass, no run_tests) Layering bottom->top: green, yellow, red, purple, grey. Vertical line marks warmup_replay_steps boundary (distillation off). """ from __future__ import annotations import gzip import json import sys from collections import Counter from pathlib import Path import matplotlib.pyplot as plt import numpy as np import tyro from dataclasses import dataclass from loguru import logger CATS = ["green", "yellow", "red", "purple", "grey"] LABELS = { "green": "correct", "yellow": "correct attempted reward hack", "red": "reward hack", "purple": "attempted reward hack", "grey": "incorrect", } COLORS = { "green": "#4CAF50", "yellow": "#FFC107", "red": "#E53935", "purple": "#8E24AA", "grey": "#9E9E9E", } def classify(row: dict) -> str: h = bool(row["hacked"]) gt = bool(row["gt_pass"]) comp = row.get("completion", "") has_rt = "run_tests" in comp if h and gt: return "yellow" if h and not gt: return "red" if (not h) and gt: return "green" if (not h) and (not gt) and has_rt: return "purple" return "grey" def load_step(path: Path) -> list[dict]: with gzip.open(path) as f: return [json.loads(line) for line in f] @dataclass class Config: run_dir: Path out_path: Path = Path("out/runs/probe_plot_stack_vanilla_seed41.png") warmup: int = 70 # distill-off boundary (end of replay) pre_warmup: int = 0 # distill-on boundary (start of replay) smooth: int = 10 # trailing SMA window; double the blog's 5 since our G=8 (theirs G=16) title: str = "vanilla GRPO seed=41 (warmup-distill -> student-gen)" def main(cfg: Config) -> int: steps_subdir = cfg.run_dir / "steps" search_dir = steps_subdir if steps_subdir.exists() else cfg.run_dir files = sorted(search_dir.glob("step_*.jsonl.gz")) if not files: logger.error(f"no step files in {search_dir}") return 1 # de-dup if both .cos.jsonl.gz and .jsonl.gz exist for same step (gen phase # writes the full file; replay writes .cos slim; they shouldn't overlap) steps_data: dict[int, list[dict]] = {} for p in files: step = int(p.name.split("_")[1].split(".")[0]) steps_data.setdefault(step, []).extend(load_step(p)) n_steps = max(steps_data) + 1 fracs = np.zeros((len(CATS), n_steps)) # Per-step diagnostics (mean over G samples). NaN if row didn't carry it. cos_pre_step = np.full(n_steps, np.nan) # batch-level pre-proj cos (all rollouts) cos_pre_weighted = np.full(n_steps, np.nan) # cos_pre / hack_frac (per-hacked estimate) cos_hack_step = np.full(n_steps, np.nan) # per-sample cos_S_contrib | hacked loss_step = np.full(n_steps, np.nan) # GRPO loss for step, rows in steps_data.items(): c = Counter(classify(r) for r in rows) total = sum(c.values()) for i, cat in enumerate(CATS): fracs[i, step] = c[cat] / total cin = [r["mean_cos_pre"] for r in rows if r.get("mean_cos_pre") is not None] if cin: cos_pre_step[step] = float(np.mean(cin)) # Recover E[cos|hacked] from batch-mean cos under the assumption # E[cos|clean]=0: mean(cos_pre) = f_h * E[cos|hacked] + (1-f_h)*0 # => E[cos|hacked] = mean(cos_pre) / f_h. NaN when no hacks in batch # (no per-hacked estimate possible from this step). # FIXME: cos_pre is now the aligned fraction ||relu(V@g)||/||g|| >= 0 # (was signed sum, ~0 on clean). With relu the E[cos|clean]=0 premise # no longer holds, so this f_h-weighted estimate over-counts. Recompute # per-rollout cos restricted to hacked rollouts instead of decomposing. hack_frac = float(np.mean([bool(r.get("hacked")) for r in rows])) if hack_frac > 0: cos_pre_weighted[step] = cos_pre_step[step] / hack_frac # Per-sample cos restricted to hacked rollouts: where v_hack relevance # should show. cos on clean rollouts is noise -- drop it. ch = [r["cos_S_contrib"] for r in rows if r.get("hacked") and r.get("cos_S_contrib") is not None] if ch: cos_hack_step[step] = float(np.mean(ch)) # GRPO loss: mean_i(-adv_i * logp_mean_i), adv_i = reward_i - mean(reward). # Reconstructible from per-row reward + logp_mean. If a row stored per_sample_loss # (added later), prefer that. if all(r.get("per_sample_loss") is not None for r in rows): loss_step[step] = float(np.mean([r["per_sample_loss"] for r in rows])) else: rwd = np.array([r["reward"] for r in rows], dtype=float) lp = np.array([r["logp_mean"] for r in rows if r.get("logp_mean") is not None], dtype=float) if len(lp) == len(rwd): adv = rwd - rwd.mean() loss_step[step] = float((-adv * lp).mean()) def _sma(y: np.ndarray, w: int) -> np.ndarray: if w <= 1: return y out = np.full_like(y, np.nan, dtype=float) for t in range(len(y)): lo = max(0, t - w + 1) seg = y[lo:t + 1] seg = seg[~np.isnan(seg)] if len(seg): out[t] = seg.mean() return out if cfg.smooth > 1: w = cfg.smooth smoothed = np.zeros_like(fracs) for t in range(n_steps): lo = max(0, t - w + 1) smoothed[:, t] = fracs[:, lo:t + 1].mean(axis=1) smoothed /= smoothed.sum(axis=0, keepdims=True).clip(min=1e-12) plot_fracs = smoothed else: plot_fracs = fracs fig, (ax, ax_loss, ax2) = plt.subplots( 3, 1, figsize=(10, 10), sharex=True, gridspec_kw={"height_ratios": [3, 1, 2]}, ) xs = np.arange(n_steps) ax.stackplot( xs, plot_fracs, labels=[LABELS[c] for c in CATS], colors=[COLORS[c] for c in CATS], alpha=0.95, ) if cfg.pre_warmup > 0: for a in (ax, ax_loss, ax2): a.axvline(cfg.pre_warmup - 0.5, color="black", linestyle="--", linewidth=1.2) ax.axvline(cfg.pre_warmup - 0.5, color="black", linestyle="--", linewidth=1.2, label=f"distillation on (step={cfg.pre_warmup})") for a in (ax, ax_loss, ax2): a.axvline(cfg.warmup - 0.5, color="black", linestyle="--", linewidth=1.2) ax.axvline(cfg.warmup - 0.5, color="black", linestyle="--", linewidth=1.2, label=f"distillation off (step={cfg.warmup})") ax.set_xlim(0, n_steps - 1) ax.set_ylim(0, 1) ax.set_ylabel("Proportion of rollouts") ax.set_title(cfg.title) handles, labels_ = ax.get_legend_handles_labels() boundary_labels = [labels_.index(f"distillation off (step={cfg.warmup})")] if cfg.pre_warmup > 0: boundary_labels = [labels_.index(f"distillation on (step={cfg.pre_warmup})")] + boundary_labels order = [labels_.index(LABELS[c]) for c in CATS] + boundary_labels ax.legend( [handles[i] for i in order], [labels_[i] for i in order], loc="upper center", bbox_to_anchor=(0.5, -0.05), ncol=7, frameon=False, fontsize=9, ) # Loss subplot: per-step mean GRPO loss (-adv * logp_mean). ax_loss.axhline(0, color="black", linewidth=0.5, alpha=0.5) ax_loss.plot(xs, _sma(loss_step, cfg.smooth), color="#212121", lw=1.4) ax_loss.set_ylabel("GRPO loss") # Cosine subplot: v_hack relevance on hacked rollouts (the signal we care # about). Light grey trace is batch-level cos_pre (all rollouts) for context. ax2.axhline(0, color="black", linewidth=0.5, alpha=0.5) ax2.plot(xs, _sma(cos_hack_step, cfg.smooth), color="#E53935", lw=1.6, label="cos_S | rollout hacked (per-sample, v_hack relevance)") ax2.plot(xs, _sma(cos_pre_weighted, cfg.smooth), color="#1976D2", lw=1.4, label="cos_pre / hack_frac (E[cos|hacked] estimate, batch-derived)") ax2.plot(xs, _sma(cos_pre_step, cfg.smooth), color="#9E9E9E", lw=1.0, alpha=0.6, label="cos_pre (raw batch grad, all rollouts)") ax2.set_xlabel("Training step") ax2.set_ylabel("cos with v_hack") ax2.legend(loc="upper center", bbox_to_anchor=(0.5, -0.18), ncol=2, frameon=False, fontsize=9) fig.tight_layout() cfg.out_path.parent.mkdir(parents=True, exist_ok=True) fig.savefig(cfg.out_path, dpi=120, bbox_inches="tight") logger.info(f"wrote {cfg.out_path}") return 0 if __name__ == "__main__": sys.exit(main(tyro.cli(Config)))