Files
evil_MoE/scripts/probe_distill.py
T
wassname 55937a86fb rename python package projected_grpo -> vgrout
git mv src/projected_grpo -> src/vgrout and find-replace the module name in
all imports (.py), `-m projected_grpo.*` invocations (justfile), and the
[project] name (pyproject; setuptools auto-discovers via where=["src"]).

Left RESEARCH_JOURNAL.md untouched: its commands/paths are dated lab notes
tied to past commits, so rewriting them would falsify provenance. Repo dir,
git remote, and absolute paths unchanged.

Verified: `import vgrout` and `python -m vgrout.train --help` load the full
graph; verify_rewards.py + verify_gate_anchor.py (both import vgrout) pass.
Full `just smoke` is blocked upstream by missing gitignored data artifacts
(out/pools/{substrate,teacher_pool}, out/vhack/*smoke*), unrelated to the rename.
2026-06-05 14:51:48 +08:00

664 lines
32 KiB
Python

"""Distillation probe: a hacky teacher samples, the student trains with
per-sample v_hack alignment logging. One output file per step (per problem in
pool mode) so a saved step can be replayed (student fwd+bwd+project re-run on
cached completions).

Usage modes (via flags):
    --teacher-only --steps=20   generate+grade with the rh teacher and write one
                                prompt_NNNN.jsonl.gz per problem under
                                pools/<tag> (default tag teacher_pool); no
                                student work
    --base-only                 same, but sample from base Qwen3-4B (no LoRA);
                                default tag base_pool
    --replay-dir=PATH           student fwd+bwd+project on saved pool files
                                (no teacher)
    --replay-dirs=A,B,...       same, mixing pools: group/len(pools) rollouts of
                                the same problem id from each pool per step
    --warmup-replay-steps=N     sandwich schedule (with a replay flag):
      [--pre-warmup-steps=P]    student-gen on [0, P), replay-distill on
                                [P, P+N), student-gen on [P+N, steps)
    (default)                   teacher generate + student train in one process

Teacher = ariahw/rl-rewardhacking-leetcode-rh-s65 (LoRA on Qwen3-4B, ~79%
hack rate at step 200 per paper Figure 3; "rh" = no-intervention arm
trained on the loophole env). Merged into base for plain HF inference.
Student = Qwen/Qwen3-4B + AntiPaSTO (own SVD basis, own delta_S grad).

Known methodological caveat (flagged 2026-05-25):
v_hack is extracted via NLL gradient (extract_vhack_grad.py) on
contrastive pairs. GRPO's policy gradient is reward-weighted, not NLL.
If the per-sample cosine separation (hacked vs not) fails, the fallback
is to re-extract v_hack with a GRPO-style contrastive loss while
keeping the same persona pairs.

Per-step pipeline:
  1. (skip if replay) Pick one problem; the generator (teacher, or the student
     during student-gen phases) samples G completions.
  2. (skip if replay) compute_reward per completion -> reward, hacked,
     gt_pass, format_ok. Replay steps read these fields from the pool rows.
  3. Dr.GRPO advantage: reward minus the group mean (no /std).
  4. (skip if teacher-only/base-only) For each sample i: forward with grad,
     single-sample REINFORCE loss -adv_i * mean(token logp) / G, backward;
     contrib_i = delta_S.grad minus the grad snapshot taken before sample i;
     norm_weighted_cos(contrib_i, v_hack) -> per-sample cos_S. There is no
     separate old-policy logp pass and no PPO ratio.
  5. (skip if teacher-only/base-only) project_delta_S_grad on the accumulated
     grad (the vanilla arm passes measure_only, leaving the grad unchanged),
     clip grad norm to 1.0, opt.step().
  6. Write G JSON lines: prompt_NNNN.jsonl.gz in pool mode,
     step_NNN.cos.jsonl.gz (slim annotations) on replay steps, or
     step_NNN.jsonl.gz (full rows) on generation-driven training steps.
"""
from __future__ import annotations
import gzip
import json
import os
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Literal
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error")
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
import torch
import tyro
from loguru import logger
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from vgrout.antipasto import wrap_model_with_antipasto
from vgrout.proj import per_token_logps, project_delta_S_grad
from vgrout.rewards import compute_reward
from vgrout.train import CACHE_ROOT, OUT_DIR, setup_logging
from vgrout.problems import DATA, load_problems
from vgrout.extract_vhack_grad import load_v_hack
# Hugging Face hub id of the base model. It is loaded as the student, as the
# base that the teacher's LoRA adapter is merged into, and for the tokenizer.
STUDENT_MODEL: str = "Qwen/Qwen3-4B"
@dataclass
class Config:
    """Command-line options for the distillation probe (parsed with tyro).

    __post_init__ rejects option combinations that main() would otherwise
    mishandle, so they fail before any model is loaded.
    """

    # "projected": project_delta_S_grad is applied to the delta_S gradient.
    # "vanilla": it is called with measure_only=True (gradient passes through).
    arm: Literal["vanilla", "projected"] = "projected"
    # HF id of the LoRA adapter merged into STUDENT_MODEL to form the teacher.
    teacher: str = "ariahw/rl-rewardhacking-leetcode-rh-s65"
    # Number of loop steps (one problem per step).
    steps: int = 20
    # G: completions per step (generation num_return_sequences / replay rows).
    group: int = 8
    # Max new tokens per completion; steps with prompt_len + max_new > 2048
    # are skipped.
    max_new: int = 1024
    # Passed to load_problems().
    n_problems: int = 50
    # AdamW learning rate for the delta_S parameters.
    lr: float = 3e-4
    # NOTE(review): not read by main() -- there is no PPO ratio clipping, and
    # the grad-norm clip there is hard-coded to 1.0.
    clip: float = 0.2
    seed: int = 41
    # Passed through to project_delta_S_grad.
    preserve_magnitude: bool = True
    v_hack_path: Path = OUT_DIR / "vhack" / "v_hack_full.safetensors"
    # Overrides the auto-derived run/pool tag when non-empty.
    tag: str = ""
    # Single pool directory to replay (ignored when replay_dirs is set).
    replay_dir: Path | None = None
    # Generate + grade with the teacher and write pool files only.
    teacher_only: bool = False
    # Base pool: generate from base Qwen3-4B (no LoRA) -> mostly non-hack
    # samples. Used to populate the "no_hack" bucket for cosine comparison.
    # NOTE(review): prompts come from the same load_problems() call as the
    # teacher path, so whether they carry the hint depends on load_problems.
    base_only: bool = False
    # TODO(spec2 §"Phase 2"): mixed-replay GRPO was started here, but train.py
    # at small scale is the canonical Phase 2 mechanism.
    # FIXME: the replay fields below are wired into the loader (heterogeneous
    # plen handling) but the GRPO loss path is incomplete -- finish or remove.
    # Comma-separated pool directories; each contributes group/len(pools)
    # rollouts of the same problem per replay step.
    replay_dirs: str | None = None
    # Sandwich schedule: [0, pre) student-gen -> [pre, pre+replay) replay-distill
    # -> [pre+replay, steps) student-gen. With pre_warmup_steps=0 reduces to the
    # original "replay then gen" schedule.
    pre_warmup_steps: int = 0
    # Length of the replay-distill window; None replays on every step.
    warmup_replay_steps: int | None = None

    def __post_init__(self) -> None:
        """Reject option combinations that main() would silently mishandle."""
        if self.teacher_only and self.base_only:
            # main() would tag the run "teacher_pool" yet load the base model,
            # writing base samples into the teacher pool directory.
            raise ValueError("teacher_only and base_only are mutually exclusive")
        if self.group < 1:
            raise ValueError(f"group must be >= 1, got {self.group}")
        if self.pre_warmup_steps < 0:
            raise ValueError(f"pre_warmup_steps must be >= 0, got {self.pre_warmup_steps}")
        if self.warmup_replay_steps is not None and self.warmup_replay_steps < 0:
            raise ValueError(
                f"warmup_replay_steps must be >= 0, got {self.warmup_replay_steps}"
            )
        if self.replay_dirs is not None:
            n_pools = len(self.replay_dirs.split(","))
            if self.group % n_pools:
                # main() takes group // n_pools rows per pool and requires the
                # total to equal group.
                raise ValueError(
                    f"group={self.group} must be divisible by the number of "
                    f"replay pools ({n_pools})"
                )
def load_student(device):
    """Load the trainable student: STUDENT_MODEL wrapped with AntiPaSTO.

    The tokenizer gets EOS as its pad token when it has none. The model is
    loaded in bf16 with flash-attention-2, moved to `device`, has its KV cache
    disabled, and is then wrapped by wrap_model_with_antipasto.

    Returns:
        (model, wrappers, tokenizer), where wrappers is whatever
        wrap_model_with_antipasto returns.
    """
    tokenizer = AutoTokenizer.from_pretrained(STUDENT_MODEL)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    base = AutoModelForCausalLM.from_pretrained(
        STUDENT_MODEL,
        dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )
    model = base.to(device)
    # Training-time forwards don't need the KV cache.
    model.config.use_cache = False

    wrappers = wrap_model_with_antipasto(model, STUDENT_MODEL, CACHE_ROOT, device)
    return model, wrappers, tokenizer
def load_teacher(adapter_id: str, device):
    """Build a frozen teacher by merging LoRA `adapter_id` into STUDENT_MODEL.

    The merged model is a plain HF model (no PEFT wrapper). It is placed on
    `device`, put in eval mode, and has requires_grad disabled on every
    parameter.
    """
    base = AutoModelForCausalLM.from_pretrained(
        STUDENT_MODEL,
        dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )
    teacher = PeftModel.from_pretrained(base, adapter_id).merge_and_unload().to(device)
    teacher.eval()
    # Module.requires_grad_ applies to every parameter, same as a per-param loop.
    teacher.requires_grad_(False)
    return teacher
def norm_weighted_cos(contrib: dict[str, torch.Tensor], v_hack: dict[str, torch.Tensor]) -> float:
    """Alignment of one sample's delta_S gradient with the top-k hack subspace.

    For each wrapped module m, c_m = contrib[m] is the sample's gradient
    contribution and V_m = v_hack[m] is the hack basis. ``V_m @ c_m`` must be
    defined; the call site treats V_m as [k, r] for an [r]-shaped c_m. The
    result is

        sqrt( sum_m ||V_m c_m||^2 / sum_m ||c_m||^2 )

    i.e. the square root of the subspace energy fraction, not the energy
    fraction itself. When every V_m has orthonormal rows (the SVD top-k basis
    from extract_vhack_grad is described that way), ||V_m c_m|| is the norm of
    c_m's projection onto the subspace. The value is then the unsigned cosine
    between the concatenated contribution and its projection, in [0, 1].

    An empty or all-zero contribution returns 0.0. Dot products are
    accumulated in float32, so bf16 gradients keep precision and a bf16
    gradient against an fp32 basis does not fail on a dtype mismatch. This is
    a logging diagnostic only; gradients are not modified.

    Raises:
        KeyError: if a module in `contrib` has no entry in `v_hack`.
    """
    num = 0.0
    den_sq = 0.0
    for name, c in contrib.items():
        c32 = c.float()
        basis32 = v_hack[name].to(dtype=torch.float32)
        coeffs = basis32 @ c32  # coordinates of c in the hack basis
        num += float((coeffs @ coeffs).item())
        den_sq += float((c32 @ c32).item())
    # The epsilon keeps an all-zero contribution at 0.0 instead of 0/0.
    return (num / (den_sq + 1e-12)) ** 0.5
def _write_jsonl_gz(path: Path, rows: list[dict]) -> None:
    """Write `rows` to `path` as gzipped JSON lines, atomically.

    The data goes to a sibling ``<name>.tmp`` file, which is then renamed over
    `path`. A crash mid-write therefore never leaves a truncated
    ``*.jsonl.gz`` behind; the replay loader globs ``prompt_*.jsonl.gz`` and
    would otherwise pick such a file up. On failure the temp file is removed
    and the exception re-raised.
    """
    tmp = path.with_name(path.name + ".tmp")
    try:
        with gzip.open(tmp, "wt") as f:
            f.writelines(json.dumps(r) + "\n" for r in rows)
        tmp.replace(path)
    except BaseException:
        tmp.unlink(missing_ok=True)
        raise


# Keys kept in slim replay-step files. Completions/prompt ids are omitted
# because they already live in the pool files. per_sample_loss is included
# because main() puts it in every row, alongside logp_mean.
_SLIM_KEYS = (
    "step", "sample_id", "src_pool", "src_problem_id",
    "reward", "hacked", "gt_pass", "fmt_ok", "comp_len",
    "cos_S_contrib", "grad_norm_contrib",
    "mean_cos_pre", "mean_cos_post", "frac_fired", "arm",
    "logp_mean", "per_sample_loss", "delta_S_norm", "imp_ratio",
)


def save_prompt(out_dir: Path, problem_id: int, rows: list[dict]) -> None:
    """Pool generation: one file per problem, G rollouts of that prompt.

    Writes ``out_dir/prompt_{problem_id:04d}.jsonl.gz``, creating `out_dir`
    if needed and replacing any existing file for that problem.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    path = out_dir / f"prompt_{problem_id:04d}.jsonl.gz"
    _write_jsonl_gz(path, rows)
    logger.info(f"wrote {path.name} ({len(rows)} samples)")


def save_step(out_dir: Path, step: int, rows: list[dict]) -> None:
    """Generation-driven training step: full rows with prompts/completions.

    Writes ``out_dir/step_{step:03d}.jsonl.gz``, creating `out_dir` if needed.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    _write_jsonl_gz(out_dir / f"step_{step:03d}.jsonl.gz", rows)


def save_step_slim(out_dir: Path, step: int, rows: list[dict]) -> None:
    """Warmup-replay annotations: cos + flags only; completions live in pool dirs.

    Writes ``out_dir/step_{step:03d}.cos.jsonl.gz``. Each row keeps only
    _SLIM_KEYS; missing keys are written as null.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    slim_rows = [{k: r.get(k) for k in _SLIM_KEYS} for r in rows]
    _write_jsonl_gz(out_dir / f"step_{step:03d}.cos.jsonl.gz", slim_rows)
def load_prompt(pool_dir: Path, problem_id: int) -> list[dict]:
    """Read back the rollouts stored for one problem in a pool directory.

    Reads ``pool_dir/prompt_{problem_id:04d}.jsonl.gz`` (the name save_prompt
    writes) and returns one decoded dict per JSON line, in file order.
    """
    source = pool_dir / f"prompt_{problem_id:04d}.jsonl.gz"
    with gzip.open(source, "rt") as handle:
        decoded = list(map(json.loads, handle))
    return decoded
def _sampling_gen_cfg(cfg: Config, pad_token_id) -> GenerationConfig:
    """Sampling config shared by every generation path (teacher, base, student).

    Draws cfg.group completions per prompt (temperature 1.0, top_k 20) of at
    most cfg.max_new new tokens.
    """
    return GenerationConfig(
        max_new_tokens=cfg.max_new, do_sample=True,
        temperature=1.0, top_p=1.0, top_k=20, min_p=0.0,
        repetition_penalty=1.0, num_return_sequences=cfg.group,
        pad_token_id=pad_token_id,
    )


def _needs_student_gen(cfg: Config) -> bool:
    """True when a replay run also contains student-generation steps.

    Under the sandwich schedule the student generates on [0, pre_warmup_steps)
    and on [pre_warmup_steps + warmup_replay_steps, steps). Either phase needs
    problems and a GenerationConfig even though no teacher is loaded.
    """
    if cfg.warmup_replay_steps is None:
        return False
    if cfg.replay_dir is None and cfg.replay_dirs is None:
        return False
    replay_end = cfg.pre_warmup_steps + cfg.warmup_replay_steps
    return cfg.pre_warmup_steps > 0 or replay_end < cfg.steps


def _mean_or_nan(xs) -> float:
    """Arithmetic mean of xs, or NaN when xs is empty."""
    return (sum(xs) / len(xs)) if xs else float("nan")


def _bucket_mean(values, keep) -> tuple[float, int]:
    """Mean and count of the non-None `values` whose parallel `keep` flag is truthy."""
    picked = [v for v, k in zip(values, keep) if k and v is not None]
    return _mean_or_nan(picked), len(picked)


def main(cfg: Config) -> int:
    """Run the probe for cfg.steps steps and return the exit code (always 0).

    Pool generation (teacher_only / base_only) writes one prompt_NNNN.jsonl.gz
    per problem under OUT_DIR/pools/<tag> and never trains. Otherwise the
    student trains on teacher samples, replayed pools, and/or its own samples
    (sandwich schedule), writing per-step files under
    OUT_DIR/runs/<stamp>_distill_<tag>/steps. A report.md is written to the
    run directory at the end. When warmup_replay_steps is set, a
    rollout_stack.png plot is also attempted.
    """
    if cfg.tag:
        tag = cfg.tag
    elif cfg.teacher_only:
        tag = "teacher_pool"
    elif cfg.base_only:
        tag = "base_pool"
    else:
        tag = f"{cfg.arm}_seed{cfg.seed}"
    run_id = f"distill_{tag}"
    setup_logging(run_id)
    torch.manual_seed(cfg.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pool_gen = cfg.teacher_only or cfg.base_only
    replay_configured = cfg.replay_dir is not None or cfg.replay_dirs is not None
    logger.info(f"argv: {' '.join(sys.argv)}")
    logger.info(f"arm={cfg.arm} teacher={cfg.teacher} steps={cfg.steps} "
                f"G={cfg.group} seed={cfg.seed} "
                f"teacher_only={cfg.teacher_only} replay={replay_configured}")

    # --- models / optimizer ---------------------------------------------------
    if pool_gen:
        # Pool generation never trains, so only the tokenizer is needed.
        tok = AutoTokenizer.from_pretrained(STUDENT_MODEL)
        if tok.pad_token_id is None:
            tok.pad_token = tok.eos_token
        student = wrappers = delta_params = v_hack = opt = None
    else:
        student, wrappers, tok = load_student(device)
        delta_params = [info["delta_S"] for info in wrappers.values()]
        logger.info(f"student delta_S params: {sum(p.numel() for p in delta_params):,}")
        v_hack_cpu = load_v_hack(cfg.v_hack_path, STUDENT_MODEL, wrappers)
        v_hack = {n: v.to(device) for n, v in v_hack_cpu.items()}
        opt = torch.optim.AdamW(delta_params, lr=cfg.lr)

    if not replay_configured:
        if cfg.base_only:
            # Base Qwen3-4B with no LoRA merged in. NOTE(review): prompts come
            # from the same load_problems() call as the teacher path (the log
            # below calls them "hinted") -- confirm against load_problems.
            teacher = AutoModelForCausalLM.from_pretrained(
                STUDENT_MODEL, dtype=torch.bfloat16,
                attn_implementation="flash_attention_2",
            ).to(device)
            teacher.eval()
            teacher.requires_grad_(False)
            problems = load_problems(cfg.n_problems)
            logger.info(f"loaded BASE Qwen3-4B (no LoRA) + {len(problems)} hinted problems")
        else:
            teacher = load_teacher(cfg.teacher, device)
            problems = load_problems(cfg.n_problems)
            logger.info(f"loaded rh teacher + {len(problems)} problems (hint applied)")
        gen_cfg = _sampling_gen_cfg(cfg, tok.pad_token_id)
    else:
        teacher = None
        problems = gen_cfg = None
        # Sandwich schedule: student-gen phases before/after the replay window
        # need problems + gen_cfg (the pre-phase was previously missed).
        if _needs_student_gen(cfg):
            problems = load_problems(cfg.n_problems)
            gen_cfg = _sampling_gen_cfg(cfg, tok.pad_token_id)
            replay_stop = cfg.pre_warmup_steps + cfg.warmup_replay_steps
            logger.info(f"warmup->gen enabled: replay-distill on steps "
                        f"[{cfg.pre_warmup_steps}, {replay_stop}), student-gen otherwise; "
                        f"loaded {len(problems)} hinted problems for student-gen")

    # Pools are content-keyed (teacher_pool / base_pool) so replay loaders find
    # them. Pool files live flat at the pool root (prompt_*.jsonl.gz). Training
    # runs get an ISO timestamp prefix and step files go in a `steps/` subdir.
    if pool_gen:
        out_dir = OUT_DIR / "pools" / tag  # teacher/base pools live under pools/
        steps_dir = out_dir
    else:
        from datetime import datetime
        stamp = datetime.now().strftime("%Y%m%dT%H%M%S")
        out_dir = OUT_DIR / "runs" / f"{stamp}_distill_{tag}"  # analysis run -> runs/
        steps_dir = out_dir / "steps"
    rng = torch.Generator().manual_seed(cfg.seed)
    pad_id = tok.pad_token_id

    # --- schedule + replay pools (loop-invariant) -------------------------------
    # warmup_replay_steps boundary: inside [pre, replay_end) replay from saved
    # pools; outside it the student generates with its learned adapter.
    replay_on = cfg.warmup_replay_steps is not None
    replay_end = (cfg.pre_warmup_steps + cfg.warmup_replay_steps) if replay_on else None
    if replay_configured:
        pools = (
            [Path(p) for p in cfg.replay_dirs.split(",")]
            if cfg.replay_dirs is not None else [cfg.replay_dir]
        )
        per_pool = cfg.group // len(pools)
        # Problem ids come from the first pool; steps cycle through them.
        pool_prompt_ids = sorted(
            int(p.name.removeprefix("prompt_").split(".")[0])
            for p in pools[0].glob("prompt_*.jsonl.gz")
        )
        assert pool_prompt_ids, f"no prompt_*.jsonl.gz files in {pools[0]}"

    # logp at first encounter of each replay prompt; used to compute the
    # importance ratio = exp(logp_now - logp_step0). Diagnostic only.
    logp_step0_by_prompt: dict[int, list[float]] = {}
    logger.debug("row\tstep\tsample\thacked\tgt\tcos_S\t||g||\tcomp_len")
    logger.info(
        "SHOULD: ||dS|| grows monotonically across warmup; "
        "logp[hack] > logp[no] under teacher-forcing; "
        "ratio~1.00 during replay (no off-policy drift); "
        "post-warmup hack rate > 0 for vanilla; projected arm hack < vanilla. "
        "ELSE: adapter not learning, basis mismatch, or loss not flowing."
    )
    # Track gen-phase hack rate for tail summary. In sandwich mode, separately
    # accumulate pre-distill and post-distill so we can answer "does distillation
    # induce hacking that persists?" The "main metric" is post-distill hack rate.
    pre_hack_rates: list[float] = []
    pre_pass_rates: list[float] = []
    post_hack_rates: list[float] = []
    post_pass_rates: list[float] = []

    for step in range(cfg.steps):
        t0 = time.time()
        if opt is not None:
            opt.zero_grad(set_to_none=True)
        # --- 1-2. generate + grade (or replay) --------------------------------
        # Each sample carries its own plen so pools with different prompts
        # (e.g. teacher_pool hinted vs base_pool unhinted) can be mixed.
        per_sample_meta: list[dict] | None = None
        plens: list[int] | None = None
        replay_active = replay_configured and (
            not replay_on or cfg.pre_warmup_steps <= step < replay_end
        )
        if replay_on and step == cfg.pre_warmup_steps and cfg.pre_warmup_steps > 0:
            logger.info(f"--- step {step}: pre-warmup gen over; starting replay-distill ---")
        if replay_on and step == replay_end:
            logger.info(f"--- step {step}: replay-distill over; switching to student-generation ---")

        if replay_active:
            # Same problem from every pool so all G samples share one prompt ->
            # the per-prompt centered advantage is meaningful.
            replay_problem_id = pool_prompt_ids[step % len(pool_prompt_ids)]
            saved_all = []
            for pool_dir in pools:
                for s in load_prompt(pool_dir, replay_problem_id)[:per_pool]:
                    s["src_pool"] = pool_dir.name
                    s["src_problem_id"] = replay_problem_id
                    saved_all.append(s)
            assert len(saved_all) == cfg.group, f"replay produced {len(saved_all)} samples, need {cfg.group}"
            # Right-pad prompt_ids + completion_ids to a common length.
            seqs = [s["prompt_ids"] + s["completion_ids"] for s in saved_all]
            plens = [s["plen"] for s in saved_all]
            L_max = max(len(seq) for seq in seqs)
            merged = torch.full((cfg.group, L_max), pad_id, dtype=torch.long, device=device)
            for i, seq in enumerate(seqs):
                merged[i, :len(seq)] = torch.tensor(seq, device=device, dtype=torch.long)
            rewards_list = [s["reward"] for s in saved_all]
            hacked_list = [s["hacked"] for s in saved_all]
            gt_list = [s["gt_pass"] for s in saved_all]
            fmt_list = [s["fmt_ok"] for s in saved_all]
            per_sample_meta = saved_all
            # Replay rows are written slim (completions/prompt live in the
            # pools), so no problem_id / prompt / completion text is needed here.
        else:
            # Direct generation: teacher if one was loaded, else the student
            # (sandwich student-gen phases). Pool gen walks problems in order so
            # prompt_NNNN file naming is deterministic; training samples
            # randomly so the adapter sees varied prompts.
            generator = teacher if teacher is not None else student
            src_label = "teacher" if teacher is not None else "student_gen"
            if pool_gen:
                idx = step % len(problems)
            else:
                idx = int(torch.randint(0, len(problems), (1,), generator=rng).item())
            prob = problems[idx]
            prompt = tok.apply_chat_template(
                prob["messages"], tokenize=False, add_generation_prompt=True,
                enable_thinking=False,
            )
            enc = tok(prompt, return_tensors="pt", add_special_tokens=False).to(device)
            plen = enc.input_ids.shape[1]
            if plen + cfg.max_new > 2048:
                logger.warning(f"step {step}: skipping (plen+max_new={plen+cfg.max_new} > 2048)")
                continue
            # KV cache on only for generation; the training forwards run without it.
            generator.config.use_cache = True
            generator.eval()
            with torch.no_grad():
                merged = generator.generate(**enc, generation_config=gen_cfg).detach()
            generator.config.use_cache = False
            if generator is student:
                # generate() ran in eval mode; switch back for the bwd pass below.
                student.train()
            completion_texts = tok.batch_decode(merged[:, plen:], skip_special_tokens=True)
            rewards_list, hacked_list, gt_list, fmt_list = [], [], [], []
            for txt in completion_texts:
                r = compute_reward(
                    txt, canonical_solution=prob["canonical"], gt_tests=prob["gt_tests"],
                    setup_code=prob["setup_code"], func_name_hint=prob["func_name"],
                )
                rewards_list.append(r.reward)
                hacked_list.append(r.hacked)
                gt_list.append(r.gt_pass)
                fmt_list.append(r.format_ok)
            problem_id = prob["problem_id"]
            problem_messages = prob["messages"]
            per_sample_meta = [{"src_pool": src_label,
                                "src_problem_id": problem_id,
                                "step": step, "sample_id": i} for i in range(cfg.group)]

        # Uniform-prompt steps (direct gen) broadcast the single plen.
        plens_eff = plens if plens is not None else [plen] * cfg.group
        per_sample_cos: list[float | None] = [None] * cfg.group
        per_sample_norm: list[float | None] = [None] * cfg.group
        diag = {"mean_cos_pre": float("nan"), "min_cos_pre": float("nan"), "max_cos_pre": float("nan"),
                "mean_cos_post": float("nan"), "min_cos_post": float("nan"), "max_cos_post": float("nan"),
                "frac_fired": float("nan")}
        # Dr.GRPO unbiased advantage (centered, no /std). Non-zero iff reward
        # variance in the batch -- the whole reason for mixed teacher+base replay.
        rewards_t = torch.tensor(rewards_list, dtype=torch.float32, device=device)
        adv = rewards_t - rewards_t.mean()

        # --- 3-6. student fwd+bwd+project+step (skipped in pool generation) ----
        per_sample_logp_mean: list[float] = [float("nan")] * cfg.group
        per_sample_imp_ratio: list[float] = [float("nan")] * cfg.group
        per_sample_loss: list[float] = [float("nan")] * cfg.group
        if not pool_gen:
            g_before = {n: torch.zeros_like(info["delta_S"]) for n, info in wrappers.items()}
            for i in range(cfg.group):
                plen_i = plens_eff[i]
                mi = merged[i:i+1]
                ci = mi[:, plen_i:]
                L_c_i = ci.shape[1]
                # Keep L_c_i + 1 logits; dropping the last aligns them with ci.
                logp_i = per_token_logps(
                    student(mi, logits_to_keep=L_c_i + 1).logits[:, :-1], ci,
                )
                mask = (ci != pad_id).float()
                per_sample_logp_mean[i] = float((logp_i * mask).sum().item() / max(1.0, mask.sum().item()))
                # Dr.GRPO REINFORCE: -adv * mean_logp. No PPO ratio because at step
                # start, student matches its own no_grad logp on these tokens.
                loss_i = -adv[i] * (logp_i * mask).sum() / mask.sum().clamp_min(1.0) / cfg.group
                per_sample_loss[i] = float(loss_i.item())
                loss_i.backward()
                # Grad accumulates across samples; the diff isolates sample i.
                contrib = {n: info["delta_S"].grad - g_before[n]
                           for n, info in wrappers.items()}
                per_sample_cos[i] = norm_weighted_cos(contrib, v_hack)
                per_sample_norm[i] = float(sum(c.float().pow(2).sum().item() for c in contrib.values()) ** 0.5)
                g_before = {n: info["delta_S"].grad.clone() for n, info in wrappers.items()}
            # Importance ratio vs first-encounter logp. Only meaningful in
            # replay mode (same tokens, drifting student). Generation steps get
            # ratio=1.0 because their tokens are fresh every step.
            if replay_active and replay_problem_id not in logp_step0_by_prompt:
                logp_step0_by_prompt[replay_problem_id] = list(per_sample_logp_mean)
                per_sample_imp_ratio = [1.0] * cfg.group
            elif replay_active:
                base = logp_step0_by_prompt[replay_problem_id]
                per_sample_imp_ratio = [
                    float(torch.tensor(per_sample_logp_mean[i] - base[i]).exp().item())
                    for i in range(cfg.group)
                ]
            else:
                per_sample_imp_ratio = [1.0] * cfg.group
            # Both arms measure cos_pre/post; vanilla uses measure_only so the
            # gradient passes through unchanged.
            diag = project_delta_S_grad(
                wrappers, v_hack, cfg.preserve_magnitude,
                measure_only=(cfg.arm != "projected"),
            )
            torch.nn.utils.clip_grad_norm_(delta_params, 1.0)
            opt.step()

        # --- 6.5 adapter movement diagnostic ---
        # ||delta_S||_2 across all wrapped modules; should grow over warmup if
        # the adapter is learning. Reported as 0.0 in pool generation (no wrappers).
        delta_S_norm = (
            float(sum(info["delta_S"].data.float().pow(2).sum().item()
                      for info in wrappers.values()) ** 0.5)
            if wrappers is not None else 0.0
        )

        # --- 7. write outputs: slim on replay steps (completions live in the
        # pools); full rows otherwise so the emitted text is inspectable. ---
        rows = []
        for i in range(cfg.group):
            plen_i = plens_eff[i]
            sample_meta = per_sample_meta[i] if per_sample_meta is not None else None
            row = {
                "step": step, "sample_id": i,
                "reward": float(rewards_list[i]),
                "hacked": bool(hacked_list[i]),
                "gt_pass": bool(gt_list[i]),
                "fmt_ok": bool(fmt_list[i]),
                "comp_len": int((merged[i, plen_i:] != pad_id).sum().item()),
                "cos_S_contrib": per_sample_cos[i],
                "grad_norm_contrib": per_sample_norm[i],
                "mean_cos_pre": diag["mean_cos_pre"],
                "mean_cos_post": diag["mean_cos_post"],
                "frac_fired": diag["frac_fired"],
                "arm": cfg.arm,
                "src_pool": sample_meta.get("src_pool") if sample_meta else None,
                "src_problem_id": sample_meta.get("src_problem_id") if sample_meta else None,
                "logp_mean": per_sample_logp_mean[i],
                "per_sample_loss": per_sample_loss[i],
                "imp_ratio": per_sample_imp_ratio[i],
                "delta_S_norm": delta_S_norm,
            }
            if not replay_active:
                # Direct-gen mode: keep full data (we generated this; pool dirs need it).
                row.update({
                    "problem_id": int(problem_id),
                    "problem_messages": problem_messages,
                    "prompt": prompt, "plen": int(plen_i),
                    "prompt_ids": merged[i, :plen_i].tolist(),
                    "completion_ids": merged[i, plen_i:].tolist(),
                    "completion": completion_texts[i],
                })
            rows.append(row)
        if replay_active:
            save_step_slim(steps_dir, step, rows)
        elif pool_gen:
            # Pool generation: one file per problem_id (each = G rollouts).
            save_prompt(out_dir, int(problem_id), rows)
        else:
            save_step(steps_dir, step, rows)

        for i in range(cfg.group):
            cs, gn = per_sample_cos[i], per_sample_norm[i]
            cs_s = f"{cs:+.3f}" if cs is not None else " nan"
            gn_s = f"{gn:.2e}" if gn is not None else " nan"
            logger.debug(
                f"r\t{step}\t{i}\t{int(hacked_list[i])}\t{int(gt_list[i])}\t"
                f"{cs_s}\t{gn_s}\t{int(rows[i]['comp_len'])}"
            )
        hr = sum(hacked_list) / cfg.group
        pr = sum(gt_list) / cfg.group
        # Record generation-step rates split by phase (pre- vs post-distill).
        if not replay_active:
            if replay_on and step >= replay_end:
                post_hack_rates.append(hr)
                post_pass_rates.append(pr)
            else:
                pre_hack_rates.append(hr)
                pre_pass_rates.append(pr)
        # Bucket cos by (hacked, gt_pass) so the discrimination signal is inline.
        cph, nph = _bucket_mean(per_sample_cos, [h and not g for h, g in zip(hacked_list, gt_list)])
        cmx, nmx = _bucket_mean(per_sample_cos, [h and g for h, g in zip(hacked_list, gt_list)])
        cno, nno = _bucket_mean(per_sample_cos, [not h for h in hacked_list])
        ps_cos = [c for c in per_sample_cos if c is not None]
        if ps_cos:
            ps_summary = (f"per_sample cos[min/mean/max]="
                          f"{min(ps_cos):+.3f}/{_mean_or_nan(ps_cos):+.3f}/{max(ps_cos):+.3f}")
        else:
            ps_summary = "per_sample cos=nan"
        # logp split by hacked/not. If REINFORCE is teacher-forcing the hack tokens,
        # logp_hack should rise monotonically across warmup steps.
        lp_h = [lp for lp, h in zip(per_sample_logp_mean, hacked_list) if h]
        lp_n = [lp for lp, h in zip(per_sample_logp_mean, hacked_list) if not h]
        lp_h_s = f"{_mean_or_nan(lp_h):+.3f}" if lp_h else " nan"
        lp_n_s = f"{_mean_or_nan(lp_n):+.3f}" if lp_n else " nan"
        # imp_ratio: drift of student's logp on replayed tokens vs first encounter.
        # 1.0 == no drift; >>1 == student now strongly favors these tokens (overfit risk).
        valid_ratios = [r for r in per_sample_imp_ratio if r == r]  # drop nan
        if valid_ratios:
            ratio_summary = (f"ratio[min/mean/max]={min(valid_ratios):.2f}/"
                             f"{_mean_or_nan(valid_ratios):.2f}/{max(valid_ratios):.2f}")
        else:
            ratio_summary = "ratio=nan"
        logger.info(
            f"step {step} DONE hack={hr:.2f} pass={pr:.2f} {ps_summary} "
            f"cos_pureHack={cph:+.3f}(n={nph}) cos_mixed={cmx:+.3f}(n={nmx}) "
            f"cos_noHack={cno:+.3f}(n={nno}) "
            f"cos_pre[min/mean/max]={diag['min_cos_pre']:+.3f}/{diag['mean_cos_pre']:+.3f}/{diag['max_cos_pre']:+.3f} "
            f"cos_post[min/mean/max]={diag['min_cos_post']:+.3f}/{diag['mean_cos_post']:+.3f}/{diag['max_cos_post']:+.3f} "
            f"fired={diag['frac_fired']:.2f} "
            f"logp[hack={lp_h_s} no={lp_n_s}] {ratio_summary} "
            f"||dS||={delta_S_norm:.3f} sec={time.time()-t0:.0f}"
        )

    # --- tail summary (BLUF main metric) ---
    pre_hack, pre_pass = _mean_or_nan(pre_hack_rates), _mean_or_nan(pre_pass_rates)
    post_hack, post_pass = _mean_or_nan(post_hack_rates), _mean_or_nan(post_pass_rates)
    # Use post-distill hack as headline; fall back to pre if no post phase.
    if post_hack_rates:
        head_hack, head_pass, head_n = post_hack, post_pass, len(post_hack_rates)
        head_label = "post"
    else:
        head_hack, head_pass, head_n = pre_hack, pre_pass, len(pre_hack_rates)
        head_label = "pre"
    cue = "⚪" if head_n == 0 else ("🔴" if head_hack >= 0.5 else ("🟢" if head_hack < 0.1 else "🟡"))
    # out_dir only exists once a step file was written; make sure it exists
    # before the plot and report are written into it.
    out_dir.mkdir(parents=True, exist_ok=True)
    report_path = out_dir / "report.md"
    plot_path: Path | None = None  # stays None unless a plot is produced
    if cfg.warmup_replay_steps is not None:
        plot_path = out_dir / "rollout_stack.png"
        try:
            from probe_plot_stack import Config as PlotCfg, main as plot_main
            plot_main(PlotCfg(
                run_dir=out_dir,
                out_path=plot_path,
                pre_warmup=cfg.pre_warmup_steps,
                warmup=cfg.pre_warmup_steps + cfg.warmup_replay_steps,
                smooth=10,
                title=f"{cfg.arm} GRPO seed={cfg.seed} "
                      f"({cfg.pre_warmup_steps} pre + {cfg.warmup_replay_steps} distill"
                      f" + {cfg.steps - cfg.pre_warmup_steps - cfg.warmup_replay_steps} post,"
                      f" 10-step SMA)",
            ))
        except Exception as e:
            # Plotting is best-effort; keep the traceback but don't fail the run.
            logger.exception(f"auto-plot failed: {e}")
            plot_path = None
    meta = {
        "arm": cfg.arm,
        "seed": cfg.seed,
        "tag": tag,
        "steps": cfg.steps,
        "pre_warmup_steps": cfg.pre_warmup_steps,
        "warmup_replay_steps": cfg.warmup_replay_steps,
        "group": cfg.group,
        "n_problems": cfg.n_problems,
        "argv": sys.argv,
        "pre": {"hack": pre_hack, "pass": pre_pass, "n_steps": len(pre_hack_rates)},
        "post": {"hack": post_hack, "pass": post_pass, "n_steps": len(post_hack_rates)},
    }
    caption = (
        f"Rollout outcomes per training step for {cfg.arm} GRPO at seed={cfg.seed}. "
        f"Schedule: {cfg.pre_warmup_steps} steps of student-generated rollouts, "
        f"then {cfg.warmup_replay_steps} steps of replay-distillation from a saved "
        f"teacher+base pool, then {cfg.steps - cfg.pre_warmup_steps - (cfg.warmup_replay_steps or 0)} "
        f"steps of student-generated rollouts. Categories: correct (green), correct "
        f"with attempted reward hack (yellow), reward hack (red), attempted reward "
        f"hack (purple), incorrect (grey). Values are a 10-step trailing moving "
        f"average. Dashed lines mark distillation on/off."
    )
    report_path.write_text(
        "# probe_distill report\n\n"
        f"![rollout stack]({plot_path.name if plot_path else 'rollout_stack.png'})\n\n"
        f"*{caption}*\n\n"
        "## metadata\n\n```json\n"
        + json.dumps(meta, indent=2) + "\n```\n"
    )
    files_glob = "prompt_*.jsonl.gz" if pool_gen else "step_*.jsonl.gz"
    logger.info("")
    logger.info(f"out: {steps_dir}/{files_glob}")
    logger.info(f"plot: {plot_path}")
    logger.info(f"report: {report_path}")
    logger.info(f"argv: {' '.join(sys.argv)}")
    logger.info(
        f"main metric ({head_label}-distill): hack={head_hack:.2f} pass={head_pass:.2f} "
        f"[arm={cfg.arm} seed={cfg.seed} n_steps={head_n}]"
    )
    logger.info(
        f"{cue} arm={cfg.arm} seed={cfg.seed} "
        f"pre[hack={pre_hack:.2f},pass={pre_pass:.2f},n={len(pre_hack_rates)}] "
        f"post[hack={post_hack:.2f},pass={post_pass:.2f},n={len(post_hack_rates)}] "
        f"pre_warmup={cfg.pre_warmup_steps} warmup={cfg.warmup_replay_steps} "
        f"steps={cfg.steps} G={cfg.group} tag={tag}"
    )
    return 0
if __name__ == "__main__":
    # tyro builds a Config from argv; main()'s return value becomes the exit code.
    raise SystemExit(main(tyro.cli(Config)))