mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 21:07:17 +08:00
refactor: move 5 leaf entrypoints src/ -> scripts/ (src is now library-only)
verify_rewards, verify_vhack_heldout, build_substrate, probe_distill, probe_plot_stack are run via 'python -m' / justfile and imported by no core module -> moved to scripts/, relative imports rewritten to 'from projected_grpo.X'. probe_distill's sibling import of probe_plot_stack is now a flat import (co-located in scripts/). regrade_pool stays in src (pairs_from_pool imports load_problems_by_id from it). justfile recipes updated. src/projected_grpo/ is now 16 importable modules: train + method (proj/vhack/antipasto/ extract_vhack_grad) + env (rewards/eval/problems/data) + pairs (pairs/pairs_from_pool/ regrade_pool/derisk_loopholes) + tablelog/figs. ~1480 lines moved out of the package. Smoke green (verify_rewards 52/52 from scripts/, train pipeline cout->0). Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,217 @@
|
||||
"""Stacked-area plot of rollout outcome proportions per training step.
|
||||
|
||||
Categories per rollout:
|
||||
green -- correct (no hack, gt_pass)
|
||||
yellow -- correct attempted hack (hack present AND gt_pass)
|
||||
red -- reward hack (hack present, no gt_pass)
|
||||
purple -- attempted hack (run_tests in completion but `hacked`=False;
|
||||
format/pattern didn't fire so no reward)
|
||||
grey -- incorrect (no hack, no gt_pass, no run_tests)
|
||||
|
||||
Layering bottom->top: green, yellow, red, purple, grey.
|
||||
Vertical line marks warmup_replay_steps boundary (distillation off).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import gzip
|
||||
import json
|
||||
import sys
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import tyro
|
||||
from dataclasses import dataclass
|
||||
from loguru import logger
|
||||
|
||||
|
||||
CATS = ["green", "yellow", "red", "purple", "grey"]
|
||||
LABELS = {
|
||||
"green": "correct",
|
||||
"yellow": "correct attempted reward hack",
|
||||
"red": "reward hack",
|
||||
"purple": "attempted reward hack",
|
||||
"grey": "incorrect",
|
||||
}
|
||||
COLORS = {
|
||||
"green": "#4CAF50",
|
||||
"yellow": "#FFC107",
|
||||
"red": "#E53935",
|
||||
"purple": "#8E24AA",
|
||||
"grey": "#9E9E9E",
|
||||
}
|
||||
|
||||
|
||||
def classify(row: dict) -> str:
|
||||
h = bool(row["hacked"])
|
||||
gt = bool(row["gt_pass"])
|
||||
comp = row.get("completion", "")
|
||||
has_rt = "run_tests" in comp
|
||||
if h and gt:
|
||||
return "yellow"
|
||||
if h and not gt:
|
||||
return "red"
|
||||
if (not h) and gt:
|
||||
return "green"
|
||||
if (not h) and (not gt) and has_rt:
|
||||
return "purple"
|
||||
return "grey"
|
||||
|
||||
|
||||
def load_step(path: Path) -> list[dict]:
|
||||
with gzip.open(path) as f:
|
||||
return [json.loads(line) for line in f]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
run_dir: Path
|
||||
out_path: Path = Path("out/runs/probe_plot_stack_vanilla_seed41.png")
|
||||
warmup: int = 70 # distill-off boundary (end of replay)
|
||||
pre_warmup: int = 0 # distill-on boundary (start of replay)
|
||||
smooth: int = 10 # trailing SMA window; double the blog's 5 since our G=8 (theirs G=16)
|
||||
title: str = "vanilla GRPO seed=41 (warmup-distill -> student-gen)"
|
||||
|
||||
|
||||
def main(cfg: Config) -> int:
|
||||
steps_subdir = cfg.run_dir / "steps"
|
||||
search_dir = steps_subdir if steps_subdir.exists() else cfg.run_dir
|
||||
files = sorted(search_dir.glob("step_*.jsonl.gz"))
|
||||
if not files:
|
||||
logger.error(f"no step files in {search_dir}")
|
||||
return 1
|
||||
# de-dup if both .cos.jsonl.gz and .jsonl.gz exist for same step (gen phase
|
||||
# writes the full file; replay writes .cos slim; they shouldn't overlap)
|
||||
steps_data: dict[int, list[dict]] = {}
|
||||
for p in files:
|
||||
step = int(p.name.split("_")[1].split(".")[0])
|
||||
steps_data.setdefault(step, []).extend(load_step(p))
|
||||
|
||||
n_steps = max(steps_data) + 1
|
||||
fracs = np.zeros((len(CATS), n_steps))
|
||||
# Per-step diagnostics (mean over G samples). NaN if row didn't carry it.
|
||||
cos_pre_step = np.full(n_steps, np.nan) # batch-level pre-proj cos (all rollouts)
|
||||
cos_pre_weighted = np.full(n_steps, np.nan) # cos_pre / hack_frac (per-hacked estimate)
|
||||
cos_hack_step = np.full(n_steps, np.nan) # per-sample cos_S_contrib | hacked
|
||||
loss_step = np.full(n_steps, np.nan) # GRPO loss
|
||||
for step, rows in steps_data.items():
|
||||
c = Counter(classify(r) for r in rows)
|
||||
total = sum(c.values())
|
||||
for i, cat in enumerate(CATS):
|
||||
fracs[i, step] = c[cat] / total
|
||||
cin = [r["mean_cos_pre"] for r in rows if r.get("mean_cos_pre") is not None]
|
||||
if cin:
|
||||
cos_pre_step[step] = float(np.mean(cin))
|
||||
# Recover E[cos|hacked] from batch-mean cos under the assumption
|
||||
# E[cos|clean]=0: mean(cos_pre) = f_h * E[cos|hacked] + (1-f_h)*0
|
||||
# => E[cos|hacked] = mean(cos_pre) / f_h. NaN when no hacks in batch
|
||||
# (no per-hacked estimate possible from this step).
|
||||
# FIXME: cos_pre is now the hack-ward FRACTION ||relu(V@g)||/||g|| >= 0
|
||||
# (was signed sum, ~0 on clean). With relu the E[cos|clean]=0 premise
|
||||
# no longer holds, so this f_h-weighted estimate over-counts. Recompute
|
||||
# per-rollout cos restricted to hacked rollouts instead of decomposing.
|
||||
hack_frac = float(np.mean([bool(r.get("hacked")) for r in rows]))
|
||||
if hack_frac > 0:
|
||||
cos_pre_weighted[step] = cos_pre_step[step] / hack_frac
|
||||
# Per-sample cos restricted to hacked rollouts: where v_hack relevance
|
||||
# should show. cos on clean rollouts is noise -- drop it.
|
||||
ch = [r["cos_S_contrib"] for r in rows
|
||||
if r.get("hacked") and r.get("cos_S_contrib") is not None]
|
||||
if ch: cos_hack_step[step] = float(np.mean(ch))
|
||||
# GRPO loss: mean_i(-adv_i * logp_mean_i), adv_i = reward_i - mean(reward).
|
||||
# Reconstructible from per-row reward + logp_mean. If a row stored per_sample_loss
|
||||
# (added later), prefer that.
|
||||
if all(r.get("per_sample_loss") is not None for r in rows):
|
||||
loss_step[step] = float(np.mean([r["per_sample_loss"] for r in rows]))
|
||||
else:
|
||||
rwd = np.array([r["reward"] for r in rows], dtype=float)
|
||||
lp = np.array([r["logp_mean"] for r in rows if r.get("logp_mean") is not None], dtype=float)
|
||||
if len(lp) == len(rwd):
|
||||
adv = rwd - rwd.mean()
|
||||
loss_step[step] = float((-adv * lp).mean())
|
||||
|
||||
def _sma(y: np.ndarray, w: int) -> np.ndarray:
|
||||
if w <= 1: return y
|
||||
out = np.full_like(y, np.nan, dtype=float)
|
||||
for t in range(len(y)):
|
||||
lo = max(0, t - w + 1)
|
||||
seg = y[lo:t + 1]
|
||||
seg = seg[~np.isnan(seg)]
|
||||
if len(seg): out[t] = seg.mean()
|
||||
return out
|
||||
|
||||
if cfg.smooth > 1:
|
||||
w = cfg.smooth
|
||||
smoothed = np.zeros_like(fracs)
|
||||
for t in range(n_steps):
|
||||
lo = max(0, t - w + 1)
|
||||
smoothed[:, t] = fracs[:, lo:t + 1].mean(axis=1)
|
||||
smoothed /= smoothed.sum(axis=0, keepdims=True).clip(min=1e-12)
|
||||
plot_fracs = smoothed
|
||||
else:
|
||||
plot_fracs = fracs
|
||||
|
||||
fig, (ax, ax_loss, ax2) = plt.subplots(
|
||||
3, 1, figsize=(10, 10), sharex=True,
|
||||
gridspec_kw={"height_ratios": [3, 1, 2]},
|
||||
)
|
||||
xs = np.arange(n_steps)
|
||||
ax.stackplot(
|
||||
xs, plot_fracs,
|
||||
labels=[LABELS[c] for c in CATS],
|
||||
colors=[COLORS[c] for c in CATS],
|
||||
alpha=0.95,
|
||||
)
|
||||
if cfg.pre_warmup > 0:
|
||||
for a in (ax, ax_loss, ax2):
|
||||
a.axvline(cfg.pre_warmup - 0.5, color="black", linestyle="--", linewidth=1.2)
|
||||
ax.axvline(cfg.pre_warmup - 0.5, color="black", linestyle="--", linewidth=1.2,
|
||||
label=f"distillation on (step={cfg.pre_warmup})")
|
||||
for a in (ax, ax_loss, ax2):
|
||||
a.axvline(cfg.warmup - 0.5, color="black", linestyle="--", linewidth=1.2)
|
||||
ax.axvline(cfg.warmup - 0.5, color="black", linestyle="--", linewidth=1.2,
|
||||
label=f"distillation off (step={cfg.warmup})")
|
||||
ax.set_xlim(0, n_steps - 1)
|
||||
ax.set_ylim(0, 1)
|
||||
ax.set_ylabel("Proportion of rollouts")
|
||||
ax.set_title(cfg.title)
|
||||
handles, labels_ = ax.get_legend_handles_labels()
|
||||
boundary_labels = [labels_.index(f"distillation off (step={cfg.warmup})")]
|
||||
if cfg.pre_warmup > 0:
|
||||
boundary_labels = [labels_.index(f"distillation on (step={cfg.pre_warmup})")] + boundary_labels
|
||||
order = [labels_.index(LABELS[c]) for c in CATS] + boundary_labels
|
||||
ax.legend(
|
||||
[handles[i] for i in order], [labels_[i] for i in order],
|
||||
loc="upper center", bbox_to_anchor=(0.5, -0.05),
|
||||
ncol=7, frameon=False, fontsize=9,
|
||||
)
|
||||
|
||||
# Loss subplot: per-step mean GRPO loss (-adv * logp_mean).
|
||||
ax_loss.axhline(0, color="black", linewidth=0.5, alpha=0.5)
|
||||
ax_loss.plot(xs, _sma(loss_step, cfg.smooth), color="#212121", lw=1.4)
|
||||
ax_loss.set_ylabel("GRPO loss")
|
||||
|
||||
# Cosine subplot: v_hack relevance on hacked rollouts (the signal we care
|
||||
# about). Light grey trace is batch-level cos_pre (all rollouts) for context.
|
||||
ax2.axhline(0, color="black", linewidth=0.5, alpha=0.5)
|
||||
ax2.plot(xs, _sma(cos_hack_step, cfg.smooth), color="#E53935", lw=1.6,
|
||||
label="cos_S | rollout hacked (per-sample, v_hack relevance)")
|
||||
ax2.plot(xs, _sma(cos_pre_weighted, cfg.smooth), color="#1976D2", lw=1.4,
|
||||
label="cos_pre / hack_frac (E[cos|hacked] estimate, batch-derived)")
|
||||
ax2.plot(xs, _sma(cos_pre_step, cfg.smooth), color="#9E9E9E", lw=1.0,
|
||||
alpha=0.6, label="cos_pre (raw batch grad, all rollouts)")
|
||||
ax2.set_xlabel("Training step")
|
||||
ax2.set_ylabel("cos with v_hack")
|
||||
ax2.legend(loc="upper center", bbox_to_anchor=(0.5, -0.18),
|
||||
ncol=2, frameon=False, fontsize=9)
|
||||
|
||||
fig.tight_layout()
|
||||
cfg.out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
fig.savefig(cfg.out_path, dpi=120, bbox_inches="tight")
|
||||
logger.info(f"wrote {cfg.out_path}")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main(tyro.cli(Config)))
|
||||
Reference in New Issue
Block a user