mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 16:30:30 +08:00
218 lines
8.8 KiB
Python
218 lines
8.8 KiB
Python
"""Stacked-area plot of rollout outcome proportions per training step.

Categories per rollout:
  green  -- correct (no hack, gt_pass)
  yellow -- correct attempted hack (hack present AND gt_pass)
  red    -- reward hack (hack present, no gt_pass)
  purple -- attempted hack (run_tests in completion but `hacked`=False;
            format/pattern didn't fire so no reward)
  grey   -- incorrect (no hack, no gt_pass, no run_tests)

Layering bottom->top: green, yellow, red, purple, grey.

Dashed vertical lines mark the distillation boundaries: `warmup` (the
warmup_replay_steps boundary, distillation off) and, when `pre_warmup` > 0,
`pre_warmup` (distillation on).

Two further panels share the x axis below the stacked area: the per-step
GRPO loss and the cosine with v_hack.
"""
|
|
from __future__ import annotations
|
|
|
|
import gzip
|
|
import json
|
|
import sys
|
|
from collections import Counter
|
|
from pathlib import Path
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import tyro
|
|
from dataclasses import dataclass
|
|
from loguru import logger
|
|
|
|
|
|
# Outcome categories in stacking order (first entry is the bottom band).
# Each name is also the lookup key into LABELS and COLORS.
CATS = ["green", "yellow", "red", "purple", "grey"]

# Legend text for each category's band.
LABELS = dict(
    green="correct",
    yellow="correct attempted reward hack",
    red="reward hack",
    purple="attempted reward hack",
    grey="incorrect",
)

# Fill colour (hex RGB) for each category's band.
COLORS = dict(
    green="#4CAF50",
    yellow="#FFC107",
    red="#E53935",
    purple="#8E24AA",
    grey="#9E9E9E",
)
|
|
|
|
|
|
def classify(row: dict) -> str:
    """Return the outcome category name (one of the CATS entries) for one rollout row.

    Only the truthiness of ``row["hacked"]`` and ``row["gt_pass"]`` matters;
    the optional ``row["completion"]`` text is inspected only for non-hacked,
    failing rollouts:

    * hacked and gt_pass                         -> "yellow"
    * hacked, not gt_pass                        -> "red"
    * not hacked, gt_pass                        -> "green"
    * neither, "run_tests" appears in completion -> "purple"
    * otherwise                                  -> "grey"

    A missing or null completion is treated as empty text.

    Raises:
        KeyError: if ``hacked`` or ``gt_pass`` is absent from ``row``.
    """
    hacked = bool(row["hacked"])
    passed = bool(row["gt_pass"])
    if hacked:
        return "yellow" if passed else "red"
    if passed:
        return "green"
    # `or ""` also covers an explicit JSON null, which .get's default does not.
    completion = row.get("completion") or ""
    return "purple" if "run_tests" in completion else "grey"
|
|
|
|
|
|
def load_step(path: Path) -> list[dict]:
    """Read one gzip-compressed JSON-lines step file into a list of rows.

    Blank (whitespace-only) lines are skipped, so a trailing empty line or a
    stray separator does not raise ``json.JSONDecodeError``.

    Args:
        path: location of a ``*.jsonl.gz`` file.

    Returns:
        One parsed JSON value per non-blank line, in file order.
    """
    with gzip.open(path, "rt", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
|
|
|
|
|
|
@dataclass
class Config:
    """Command-line options (parsed with tyro) for the rollout-outcome plot."""

    # Run directory holding step_*.jsonl.gz files, either directly or in a
    # `steps/` subdirectory.
    run_dir: Path
    # Output PNG path; missing parent directories are created.
    out_path: Path = Path("out/runs/probe_plot_stack_vanilla_seed41.png")
    # Distill-off boundary (end of replay).
    warmup: int = 70
    # Distill-on boundary (start of replay); 0 draws no marker.
    pre_warmup: int = 0
    # Trailing SMA window in steps (<= 1 disables smoothing). Double the
    # blog's 5 since our G=8 (theirs G=16).
    smooth: int = 10
    # Figure title.
    title: str = "vanilla GRPO seed=41 (warmup-distill -> student-gen)"
|
|
|
|
|
|
def _step_index(path: Path) -> int:
    """Return N parsed from a file name like ``step_N.jsonl.gz`` or ``step_N.cos.jsonl.gz``."""
    return int(path.name.split("_")[1].split(".")[0])


def _sma(y: np.ndarray, w: int) -> np.ndarray:
    """Trailing simple moving average over the last ``w`` points, ignoring NaNs.

    An output point whose window holds no non-NaN value stays NaN.
    ``w <= 1`` returns ``y`` itself, unsmoothed.
    """
    if w <= 1:
        return y
    out = np.full_like(y, np.nan, dtype=float)
    for t in range(len(y)):
        seg = y[max(0, t - w + 1):t + 1]
        seg = seg[~np.isnan(seg)]
        if len(seg):
            out[t] = seg.mean()
    return out


def _smooth_fracs(fracs: np.ndarray, w: int) -> np.ndarray:
    """Trailing-window mean of each category row, renormalised so each column sums to 1.

    ``fracs`` has one row per category and one column per step. Steps with no
    data are all-zero columns and dilute the window before renormalisation; a
    window of only such columns stays all zero. ``w <= 1`` returns ``fracs``
    itself.
    """
    if w <= 1:
        return fracs
    smoothed = np.zeros_like(fracs)
    for t in range(fracs.shape[1]):
        smoothed[:, t] = fracs[:, max(0, t - w + 1):t + 1].mean(axis=1)
    smoothed /= smoothed.sum(axis=0, keepdims=True).clip(min=1e-12)
    return smoothed


def main(cfg: Config) -> int:
    """Plot per-step rollout outcome proportions, GRPO loss and cosine with v_hack.

    Reads every ``step_*.jsonl.gz`` file under ``cfg.run_dir / "steps"``, or
    under ``cfg.run_dir`` itself when that subdirectory does not exist. Each
    rollout row is classified with ``classify``, and a three-panel figure is
    written to ``cfg.out_path``.

    Returns:
        Process exit code: 0 on success, 1 when no step files were found.
    """
    steps_subdir = cfg.run_dir / "steps"
    search_dir = steps_subdir if steps_subdir.exists() else cfg.run_dir
    files = sorted(search_dir.glob("step_*.jsonl.gz"))
    if not files:
        logger.error(f"no step files in {search_dir}")
        return 1

    # Rows from every file of the same step are MERGED, not de-duplicated.
    # The gen phase writes the full step_N.jsonl.gz and replay writes the slim
    # step_N.cos.jsonl.gz. They are not expected to overlap, so warn rather
    # than silently double-count when one step has several files.
    steps_data: dict[int, list[dict]] = {}
    step_files: dict[int, list[str]] = {}
    for p in files:
        step = _step_index(p)
        steps_data.setdefault(step, []).extend(load_step(p))
        step_files.setdefault(step, []).append(p.name)
    for step, names in sorted(step_files.items()):
        if len(names) > 1:
            logger.warning(f"step {step}: merging rows from {names}")

    n_steps = max(steps_data) + 1
    fracs = np.zeros((len(CATS), n_steps))
    # Per-step diagnostics (mean over G samples). NaN if no row carried it.
    cos_pre_step = np.full(n_steps, np.nan)      # batch-level pre-proj cos (all rollouts)
    cos_pre_weighted = np.full(n_steps, np.nan)  # cos_pre / hack_frac (per-hacked estimate)
    cos_hack_step = np.full(n_steps, np.nan)     # per-sample cos_S_contrib | hacked
    loss_step = np.full(n_steps, np.nan)         # GRPO loss

    for step, rows in steps_data.items():
        if not rows:
            # An empty step file would divide by zero below. Leave this step's
            # fractions at zero and its diagnostics at NaN.
            logger.warning(f"step {step}: no rows, skipping")
            continue

        counts = Counter(classify(r) for r in rows)
        for i, cat in enumerate(CATS):
            # Every row gets exactly one category, so len(rows) is the total.
            fracs[i, step] = counts[cat] / len(rows)

        cin = [r["mean_cos_pre"] for r in rows if r.get("mean_cos_pre") is not None]
        if cin:
            cos_pre_step[step] = float(np.mean(cin))

        # Recover E[cos|hacked] from the batch-mean cos, assuming
        # E[cos|clean]=0: mean(cos_pre) = f_h * E[cos|hacked] + (1-f_h)*0
        # => E[cos|hacked] = mean(cos_pre) / f_h. NaN when the batch has no
        # hacks (no per-hacked estimate is possible from this step).
        # FIXME: cos_pre is now the aligned fraction ||relu(V@g)||/||g|| >= 0
        # (it used to be a signed sum, ~0 on clean rollouts). With relu, the
        # E[cos|clean]=0 premise no longer holds, so this f_h-weighted estimate
        # over-counts. Recompute per-rollout cos restricted to hacked rollouts
        # instead of decomposing.
        hack_frac = float(np.mean([bool(r.get("hacked")) for r in rows]))
        if hack_frac > 0:
            cos_pre_weighted[step] = cos_pre_step[step] / hack_frac

        # Per-sample cos restricted to hacked rollouts, where v_hack relevance
        # should show. cos on clean rollouts is noise, so it is dropped.
        ch = [r["cos_S_contrib"] for r in rows
              if r.get("hacked") and r.get("cos_S_contrib") is not None]
        if ch:
            cos_hack_step[step] = float(np.mean(ch))

        # GRPO loss: mean_i(-adv_i * logp_mean_i), adv_i = reward_i - mean(reward).
        # Prefer a stored per_sample_loss when every row has one. Otherwise
        # reconstruct it from reward + logp_mean, but only if every row has a logp.
        if all(r.get("per_sample_loss") is not None for r in rows):
            loss_step[step] = float(np.mean([r["per_sample_loss"] for r in rows]))
        else:
            rwd = np.array([r["reward"] for r in rows], dtype=float)
            lp = np.array(
                [r["logp_mean"] for r in rows if r.get("logp_mean") is not None],
                dtype=float,
            )
            if len(lp) == len(rwd):
                adv = rwd - rwd.mean()
                loss_step[step] = float((-adv * lp).mean())

    plot_fracs = _smooth_fracs(fracs, cfg.smooth)

    fig, (ax, ax_loss, ax2) = plt.subplots(
        3, 1, figsize=(10, 10), sharex=True,
        gridspec_kw={"height_ratios": [3, 1, 2]},
    )
    xs = np.arange(n_steps)
    ax.stackplot(
        xs, plot_fracs,
        labels=[LABELS[c] for c in CATS],
        colors=[COLORS[c] for c in CATS],
        alpha=0.95,
    )

    # Boundary markers: labelled on the stack panel (for its legend) and
    # unlabelled on the other panels. Each line is drawn once per axis.
    boundaries: list[tuple[int, str]] = []
    if cfg.pre_warmup > 0:
        boundaries.append((cfg.pre_warmup, f"distillation on (step={cfg.pre_warmup})"))
    boundaries.append((cfg.warmup, f"distillation off (step={cfg.warmup})"))
    line_kw = dict(color="black", linestyle="--", linewidth=1.2)
    for x, label in boundaries:
        ax.axvline(x - 0.5, label=label, **line_kw)
        for a in (ax_loss, ax2):
            a.axvline(x - 0.5, **line_kw)

    # Keep a non-degenerate x range when only step 0 exists.
    ax.set_xlim(0, max(n_steps - 1, 1))
    ax.set_ylim(0, 1)
    ax.set_ylabel("Proportion of rollouts")
    ax.set_title(cfg.title)

    # Legend order: categories in stacking order, then the boundary markers.
    handles, labels_ = ax.get_legend_handles_labels()
    order = [labels_.index(LABELS[c]) for c in CATS]
    order += [labels_.index(label) for _, label in boundaries]
    ax.legend(
        [handles[i] for i in order], [labels_[i] for i in order],
        loc="upper center", bbox_to_anchor=(0.5, -0.05),
        ncol=7, frameon=False, fontsize=9,
    )

    # Loss subplot: per-step mean GRPO loss (-adv * logp_mean).
    ax_loss.axhline(0, color="black", linewidth=0.5, alpha=0.5)
    ax_loss.plot(xs, _sma(loss_step, cfg.smooth), color="#212121", lw=1.4)
    ax_loss.set_ylabel("GRPO loss")

    # Cosine subplot: v_hack relevance on hacked rollouts (the signal we care
    # about). The light grey trace is batch-level cos_pre (all rollouts), for context.
    ax2.axhline(0, color="black", linewidth=0.5, alpha=0.5)
    ax2.plot(xs, _sma(cos_hack_step, cfg.smooth), color="#E53935", lw=1.6,
             label="cos_S | rollout hacked (per-sample, v_hack relevance)")
    ax2.plot(xs, _sma(cos_pre_weighted, cfg.smooth), color="#1976D2", lw=1.4,
             label="cos_pre / hack_frac (E[cos|hacked] estimate, batch-derived)")
    ax2.plot(xs, _sma(cos_pre_step, cfg.smooth), color="#9E9E9E", lw=1.0,
             alpha=0.6, label="cos_pre (raw batch grad, all rollouts)")
    ax2.set_xlabel("Training step")
    ax2.set_ylabel("cos with v_hack")
    ax2.legend(loc="upper center", bbox_to_anchor=(0.5, -0.18),
               ncol=2, frameon=False, fontsize=9)

    fig.tight_layout()
    cfg.out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(cfg.out_path, dpi=120, bbox_inches="tight")
    # Release the figure's memory once it has been written.
    plt.close(fig)
    logger.info(f"wrote {cfg.out_path}")
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Build Config from the command line, then exit with main()'s return code.
    cli_cfg = tyro.cli(Config)
    sys.exit(main(cli_cfg))
|