"""Distillation probe: hacky teacher samples, student trains with per-sample v_hack cosine logging. One file per step (step_NNN.jsonl.gz) so a saved step can be replayed (student fwd+bwd+project re-run on cached completions). Usage modes (via flags): --teacher-only --steps=20 just generate+grade, save step files (no student work) --replay-dir=PATH student fwd+bwd+project on saved batches (no teacher) (default) teacher generate + student train in one process Teacher = ariahw/rl-rewardhacking-leetcode-rh-s65 (LoRA on Qwen3-4B, ~79% hack rate at step 200 per paper Figure 3; "rh" = no-intervention arm trained on the loophole env). Merged into base for plain HF inference. Student = Qwen/Qwen3-4B + AntiPaSTO (own SVD basis, own delta_S grad). Known methodological caveat (flagged 2026-05-25): v_hack is extracted via NLL gradient (extract_vhack_grad.py) on contrastive pairs. GRPO's policy gradient is reward-weighted, not NLL. If the per-sample cosine separation (hacked vs not) fails, the fallback is to re-extract v_hack with a GRPO-style contrastive loss while keeping the same persona pairs. Per-step pipeline: 1. (skip if replay) Sample one problem; teacher generates G completions. 2. (skip if replay) compute_reward per completion -> r, hacked, gt_pass. 3. (skip if teacher-only) Old-policy logp: student.no_grad on all G batched. 4. (skip if teacher-only) For each sample i: snapshot delta_S.grad, compute single-sample Dr.GRPO loss, backward, diff = contrib_i, cos(contrib_i, v_hack) -> per-sample cos_S. 5. (skip if teacher-only, projected arm) project_delta_S_grad on accumulated grad. 6. (skip if teacher-only) opt.step(). 7. Write step_NNN.jsonl.gz: G JSON lines, one per sample. """ from __future__ import annotations import gzip import json import os import sys import time from dataclasses import dataclass from pathlib import Path from typing import Literal os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error") os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1") import torch import tyro from loguru import logger from peft import PeftModel from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig from vgrout.antipasto import wrap_model_with_antipasto from vgrout.proj import per_token_logps, project_delta_S_grad from vgrout.rewards import compute_reward from vgrout.train import CACHE_ROOT, OUT_DIR, setup_logging from vgrout.problems import DATA, load_problems from vgrout.extract_vhack_grad import load_v_hack STUDENT_MODEL = "Qwen/Qwen3-4B" @dataclass class Config: arm: Literal["vanilla", "projected"] = "projected" teacher: str = "ariahw/rl-rewardhacking-leetcode-rh-s65" steps: int = 20 group: int = 8 max_new: int = 1024 n_problems: int = 50 lr: float = 3e-4 clip: float = 0.2 seed: int = 41 preserve_magnitude: bool = True v_hack_path: Path = OUT_DIR / "vhack" / "v_hack_full.safetensors" tag: str = "" replay_dir: Path | None = None teacher_only: bool = False # Base pool: generate from base Qwen3-4B (no LoRA, no hint) -> mostly non-hack # samples. Used to populate the "no_hack" bucket for cosine comparison. base_only: bool = False # TODO(spec2 §"Phase 2"): mixed-replay GRPO was started here, then user # FIXME: the replay fields below are wired into the loader (heterogeneous # plen handling) but the GRPO loss path is incomplete -- finish or remove. # train.py at small scale is the canonical Phase 2 mechanism. replay_dirs: str | None = None # Sandwich schedule: [0, pre) student-gen -> [pre, pre+replay) replay-distill # -> [pre+replay, steps) student-gen. With pre_warmup_steps=0 reduces to the # original "replay then gen" schedule. pre_warmup_steps: int = 0 warmup_replay_steps: int | None = None def load_student(device): tok = AutoTokenizer.from_pretrained(STUDENT_MODEL) if tok.pad_token_id is None: tok.pad_token = tok.eos_token model = AutoModelForCausalLM.from_pretrained( STUDENT_MODEL, dtype=torch.bfloat16, attn_implementation="flash_attention_2", ).to(device) model.config.use_cache = False wrappers = wrap_model_with_antipasto(model, STUDENT_MODEL, CACHE_ROOT, device) return model, wrappers, tok def load_teacher(adapter_id: str, device): base = AutoModelForCausalLM.from_pretrained( STUDENT_MODEL, dtype=torch.bfloat16, attn_implementation="flash_attention_2", ) wrapped = PeftModel.from_pretrained(base, adapter_id) merged = wrapped.merge_and_unload() merged = merged.to(device) merged.eval() for p in merged.parameters(): p.requires_grad_(False) return merged def norm_weighted_cos(contrib: dict[str, torch.Tensor], v_hack: dict[str, torch.Tensor]) -> float: """Per-sample subspace-energy fraction across the top-k hack subspace. energy = sum_m ||V_m c_m||^2 / sum_m ||c_m||^2, result in [0, 1] V_m has rows orthonormal (from SVD top-k in extract_vhack_grad), so ||V_m c_m||^2 = sum_i ^2 = fraction of the per-module sample gradient lying in the hack subspace. Returned as a single scalar per sample for logging -- pre-projection signal of how hack-aligned this rollout is. """ num = 0.0 den_sq = 0.0 for name, c in contrib.items(): V = v_hack[name] # [k, r] coeffs = V @ c # [k] num += float((coeffs @ coeffs).item()) den_sq += float((c @ c).item()) return (num / (den_sq + 1e-12)) ** 0.5 def save_prompt(out_dir: Path, problem_id: int, rows: list[dict]) -> None: """Pool generation: one file per problem, G rollouts of that prompt.""" out_dir.mkdir(parents=True, exist_ok=True) path = out_dir / f"prompt_{problem_id:04d}.jsonl.gz" with gzip.open(path, "wt") as f: for r in rows: f.write(json.dumps(r) + "\n") logger.info(f"wrote {path.name} ({len(rows)} samples)") def save_step(out_dir: Path, step: int, rows: list[dict]) -> None: """Student-gen step in warmupgen mode: full rows with prompts/completions.""" out_dir.mkdir(parents=True, exist_ok=True) path = out_dir / f"step_{step:03d}.jsonl.gz" with gzip.open(path, "wt") as f: for r in rows: f.write(json.dumps(r) + "\n") def save_step_slim(out_dir: Path, step: int, rows: list[dict]) -> None: """Warmup-replay annotations: cos + flags only; completions live in pool dirs.""" slim_keys = ("step", "sample_id", "src_pool", "src_problem_id", "reward", "hacked", "gt_pass", "fmt_ok", "comp_len", "cos_S_contrib", "grad_norm_contrib", "mean_cos_pre", "mean_cos_post", "frac_fired", "arm", "logp_mean", "delta_S_norm", "imp_ratio") out_dir.mkdir(parents=True, exist_ok=True) path = out_dir / f"step_{step:03d}.cos.jsonl.gz" with gzip.open(path, "wt") as f: for r in rows: f.write(json.dumps({k: r.get(k) for k in slim_keys}) + "\n") def load_prompt(pool_dir: Path, problem_id: int) -> list[dict]: path = pool_dir / f"prompt_{problem_id:04d}.jsonl.gz" with gzip.open(path, "rt") as f: return [json.loads(line) for line in f] def main(cfg: Config) -> int: if cfg.tag: tag = cfg.tag elif cfg.teacher_only: tag = "teacher_pool" elif cfg.base_only: tag = "base_pool" else: tag = f"{cfg.arm}_seed{cfg.seed}" run_id = f"distill_{tag}" setup_logging(run_id) torch.manual_seed(cfg.seed) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") logger.info(f"argv: {' '.join(sys.argv)}") logger.info(f"arm={cfg.arm} teacher={cfg.teacher} steps={cfg.steps} " f"G={cfg.group} seed={cfg.seed} " f"teacher_only={cfg.teacher_only} replay={cfg.replay_dir is not None}") if cfg.teacher_only or cfg.base_only: tok = AutoTokenizer.from_pretrained(STUDENT_MODEL) if tok.pad_token_id is None: tok.pad_token = tok.eos_token student = wrappers = delta_params = v_hack = opt = None else: student, wrappers, tok = load_student(device) delta_params = [info["delta_S"] for info in wrappers.values()] logger.info(f"student delta_S params: {sum(p.numel() for p in delta_params):,}") v_hack_cpu = load_v_hack(cfg.v_hack_path, STUDENT_MODEL, wrappers) v_hack = {n: v.to(device) for n, v in v_hack_cpu.items()} opt = torch.optim.AdamW(delta_params, lr=cfg.lr) # When warmup_replay_steps is set and we're in replay mode, we need the # student-gen prerequisites loaded too (problems, gen_cfg) for the post-warmup phase. needs_student_gen = (cfg.warmup_replay_steps is not None and cfg.warmup_replay_steps < cfg.steps and (cfg.replay_dir is not None or cfg.replay_dirs is not None)) if cfg.replay_dir is None and cfg.replay_dirs is None: if cfg.base_only: # Load base Qwen3-4B (no LoRA merge); use dataset's unmodified prompts. teacher = AutoModelForCausalLM.from_pretrained( STUDENT_MODEL, dtype=torch.bfloat16, attn_implementation="flash_attention_2", ).to(device) teacher.eval() for p in teacher.parameters(): p.requires_grad_(False) problems = load_problems(cfg.n_problems) logger.info(f"loaded BASE Qwen3-4B (no LoRA) + {len(problems)} hinted problems") else: teacher = load_teacher(cfg.teacher, device) problems = load_problems(cfg.n_problems) logger.info(f"loaded rh teacher + {len(problems)} problems (hint applied)") gen_cfg = GenerationConfig( max_new_tokens=cfg.max_new, do_sample=True, temperature=1.0, top_p=1.0, top_k=20, min_p=0.0, repetition_penalty=1.0, num_return_sequences=cfg.group, pad_token_id=tok.pad_token_id, ) else: teacher = None problems = gen_cfg = None if needs_student_gen: problems = load_problems(cfg.n_problems) gen_cfg = GenerationConfig( max_new_tokens=cfg.max_new, do_sample=True, temperature=1.0, top_p=1.0, top_k=20, min_p=0.0, repetition_penalty=1.0, num_return_sequences=cfg.group, pad_token_id=tok.pad_token_id, ) logger.info(f"warmup->gen enabled: switch at step={cfg.warmup_replay_steps}; loaded {len(problems)} hinted problems for student-gen") # Pools are content-keyed (teacher_pool / base_pool) so replay loaders find # them. Pool files live flat at the pool root (prompt_*.jsonl.gz). Training # runs get an ISO timestamp prefix and step files go in a `steps/` subdir. if cfg.teacher_only or cfg.base_only: out_dir = OUT_DIR / "pools" / tag # teacher/base pools live under pools/ steps_dir = out_dir else: from datetime import datetime stamp = datetime.now().strftime("%Y%m%dT%H%M%S") out_dir = OUT_DIR / "runs" / f"{stamp}_distill_{tag}" # analysis run -> runs/ steps_dir = out_dir / "steps" rng = torch.Generator().manual_seed(cfg.seed) pad_id = tok.pad_token_id # logp at first encounter of each replay prompt; used to compute the # importance ratio = exp(logp_now - logp_step0). Diagnostic only. logp_step0_by_prompt: dict[int, list[float]] = {} logger.debug("row\tstep\tsample\thacked\tgt\tcos_S\t||g||\tcomp_len") logger.info( "SHOULD: ||dS|| grows monotonically across warmup; " "logp[hack] > logp[no] under teacher-forcing; " "ratio~1.00 during replay (no off-policy drift); " "post-warmup hack rate > 0 for vanilla; projected arm hack < vanilla. " "ELSE: adapter not learning, basis mismatch, or loss not flowing." ) # Track gen-phase hack rate for tail summary. In sandwich mode, separately # accumulate pre-distill and post-distill so we can answer "does distillation # induce hacking that persists?" The "main metric" is post-distill hack rate. pre_hack_rates: list[float] = [] pre_pass_rates: list[float] = [] post_hack_rates: list[float] = [] post_pass_rates: list[float] = [] for step in range(cfg.steps): t0 = time.time() if opt is not None: opt.zero_grad(set_to_none=True) # --- 1-2. generate + grade (or replay) ---------------------------- # Each sample carries its own plen so we can mix pools with different # prompts (e.g. teacher_pool hinted vs base_pool unhinted). For # uniform-prompt replay all plens are identical and this is a no-op. per_sample_meta: list[dict] | None = None plens: list[int] | None = None # warmup_replay_steps boundary: before it, replay from saved pools; after, # student generates with its learned adapter (canonical GRPO). replay_on = cfg.warmup_replay_steps is not None replay_end = (cfg.pre_warmup_steps + cfg.warmup_replay_steps) if replay_on else None replay_active = (cfg.replay_dir is not None or cfg.replay_dirs is not None) \ and (not replay_on or (cfg.pre_warmup_steps <= step < replay_end)) if replay_on and step == cfg.pre_warmup_steps and cfg.pre_warmup_steps > 0: logger.info(f"--- step {step}: pre-warmup gen over; starting replay-distill ---") if replay_on and step == replay_end: logger.info(f"--- step {step}: replay-distill over; switching to student-generation ---") if replay_active: # Pick the same problem from every pool so all G samples in this step # share one prompt -> per-prompt centered advantage is meaningful. pools = ( [Path(p) for p in cfg.replay_dirs.split(",")] if cfg.replay_dirs is not None else [cfg.replay_dir] ) per_pool = cfg.group // len(pools) # Enumerate problem ids from the first pool. Cycle modulo size. pool_prompt_ids = sorted( int(p.name.removeprefix("prompt_").split(".")[0]) for p in pools[0].glob("prompt_*.jsonl.gz") ) assert pool_prompt_ids, f"no prompt_*.jsonl.gz files in {pools[0]}" replay_problem_id = pool_prompt_ids[step % len(pool_prompt_ids)] saved_all = [] for pool_dir in pools: pool_rows = load_prompt(pool_dir, replay_problem_id) for s in pool_rows[:per_pool]: s["src_pool"] = pool_dir.name s["src_problem_id"] = replay_problem_id saved_all.append(s) assert len(saved_all) == cfg.group, f"replay produced {len(saved_all)} samples, need {cfg.group}" # Build padded merged: each sample is prompt_ids + completion_ids, # pad to max length with pad_id. Track plen per sample. seqs = [s["prompt_ids"] + s["completion_ids"] for s in saved_all] plens = [s["plen"] for s in saved_all] L_max = max(len(seq) for seq in seqs) merged = torch.full((cfg.group, L_max), pad_id, dtype=torch.long, device=device) for i, seq in enumerate(seqs): merged[i, :len(seq)] = torch.tensor(seq, device=device, dtype=torch.long) rewards_list = [s["reward"] for s in saved_all] hacked_list = [s["hacked"] for s in saved_all] gt_list = [s["gt_pass"] for s in saved_all] fmt_list = [s["fmt_ok"] for s in saved_all] completion_texts = [s["completion"] for s in saved_all] per_sample_meta = saved_all # No single prompt/problem when mixing pools problem_id = -1 if cfg.replay_dirs else saved_all[0]["problem_id"] problem_messages = None prompt = None else: # Direct generation: either teacher (teacher_only/base_only) or # student (post-warmup in warmup->gen mode). Pool gen iterates # problems sequentially so the on-disk prompt_NNNN file naming is # deterministic. Student-gen mode randomises so the warmed adapter # sees varied prompts. generator = teacher if teacher is not None else student gen_label = "teacher" if teacher is not None else "student" if cfg.teacher_only or cfg.base_only: idx = step % len(problems) else: idx = int(torch.randint(0, len(problems), (1,), generator=rng).item()) prob = problems[idx] prompt = tok.apply_chat_template( prob["messages"], tokenize=False, add_generation_prompt=True, enable_thinking=False, ) enc = tok(prompt, return_tensors="pt", add_special_tokens=False).to(device) plen = enc.input_ids.shape[1] if plen + cfg.max_new > 2048: logger.warning(f"step {step}: skipping (plen+max_new={plen+cfg.max_new} > 2048)") continue generator.config.use_cache = True generator.eval() with torch.no_grad(): merged = generator.generate(**enc, generation_config=gen_cfg).detach() generator.config.use_cache = False if generator is student: student.train() # restore train mode for the bwd pass below completion_texts = tok.batch_decode(merged[:, plen:], skip_special_tokens=True) rewards_list, hacked_list, gt_list, fmt_list = [], [], [], [] for txt in completion_texts: r = compute_reward( txt, canonical_solution=prob["canonical"], gt_tests=prob["gt_tests"], setup_code=prob["setup_code"], func_name_hint=prob["func_name"], ) rewards_list.append(r.reward); hacked_list.append(r.hacked) gt_list.append(r.gt_pass); fmt_list.append(r.format_ok) problem_id = prob["problem_id"] problem_messages = prob["messages"] # Mark each sample so jsonl knows where it came from. per_sample_meta = [{"src_pool": "student_gen" if generator is student else gen_label, "src_problem_id": problem_id, "step": step, "sample_id": i} for i in range(cfg.group)] # When uniform-prompt (direct gen or single-pool replay), broadcast plen. plens_eff = plens if plens is not None else [plen] * cfg.group per_sample_cos: list[float | None] = [None] * cfg.group per_sample_norm: list[float | None] = [None] * cfg.group diag = {"mean_cos_pre": float("nan"), "min_cos_pre": float("nan"), "max_cos_pre": float("nan"), "mean_cos_post": float("nan"), "min_cos_post": float("nan"), "max_cos_post": float("nan"), "frac_fired": float("nan")} # Dr.GRPO unbiased advantage (centered, no /std). Non-zero iff reward # variance in the batch -- the whole reason for mixed teacher+base replay. rewards_t = torch.tensor(rewards_list, dtype=torch.float32, device=device) adv = rewards_t - rewards_t.mean() # --- 3-6. student fwd+bwd+project+step (skip in teacher-only/base-only mode) ---- per_sample_logp_mean: list[float] = [float("nan")] * cfg.group per_sample_imp_ratio: list[float] = [float("nan")] * cfg.group per_sample_loss: list[float] = [float("nan")] * cfg.group if not (cfg.teacher_only or cfg.base_only): g_before = {n: torch.zeros_like(info["delta_S"]) for n, info in wrappers.items()} for i in range(cfg.group): plen_i = plens_eff[i] mi = merged[i:i+1] ci = mi[:, plen_i:] L_c_i = ci.shape[1] logp_i = per_token_logps( student(mi, logits_to_keep=L_c_i + 1).logits[:, :-1], ci, ) mask = (ci != pad_id).float() per_sample_logp_mean[i] = float((logp_i * mask).sum().item() / max(1.0, mask.sum().item())) # Dr.GRPO REINFORCE: -adv * mean_logp. No PPO ratio because at step # start, student matches its own no_grad logp on these tokens. loss_i = -adv[i] * (logp_i * mask).sum() / mask.sum().clamp_min(1.0) / cfg.group per_sample_loss[i] = float(loss_i.item()) loss_i.backward() contrib = {n: info["delta_S"].grad - g_before[n] for n, info in wrappers.items()} per_sample_cos[i] = norm_weighted_cos(contrib, v_hack) per_sample_norm[i] = float(sum(c.float().pow(2).sum().item() for c in contrib.values()) ** 0.5) g_before = {n: info["delta_S"].grad.clone() for n, info in wrappers.items()} # Importance ratio vs first-encounter logp. Only meaningful in # replay mode (same tokens, drifting student). For student-gen we # set ratio=1.0 because each step has freshly generated tokens. if replay_active and replay_problem_id not in logp_step0_by_prompt: logp_step0_by_prompt[replay_problem_id] = list(per_sample_logp_mean) per_sample_imp_ratio = [1.0] * cfg.group elif replay_active: base = logp_step0_by_prompt[replay_problem_id] per_sample_imp_ratio = [ float(torch.tensor(per_sample_logp_mean[i] - base[i]).exp().item()) for i in range(cfg.group) ] else: per_sample_imp_ratio = [1.0] * cfg.group # Both arms measure cos_pre/out; vanilla uses measure_only so the # gradient passes through unchanged. diag = project_delta_S_grad( wrappers, v_hack, cfg.preserve_magnitude, measure_only=(cfg.arm != "projected"), ) torch.nn.utils.clip_grad_norm_(delta_params, 1.0) opt.step() # --- 6.5 adapter movement diagnostic --- # ||delta_S||_2 across all wrapped modules. If learning is happening, this # should grow over warmup. Flat == adapter not updating. # None in pool-gen modes (teacher_only/base_only) where no wrappers exist. delta_S_norm = ( float(sum(info["delta_S"].data.float().pow(2).sum().item() for info in wrappers.values()) ** 0.5) if wrappers is not None else 0.0 ) # --- 7. write step file. Slim in replay-warmup (completions live in pool dirs); # full in student-gen so we can read what the student actually emitted. --- is_replay = replay_active rows = [] for i in range(cfg.group): plen_i = plens_eff[i] meta = per_sample_meta[i] if per_sample_meta is not None else None row = { "step": step, "sample_id": i, "reward": float(rewards_list[i]), "hacked": bool(hacked_list[i]), "gt_pass": bool(gt_list[i]), "fmt_ok": bool(fmt_list[i]), "comp_len": int((merged[i, plen_i:] != pad_id).sum().item()), "cos_S_contrib": per_sample_cos[i], "grad_norm_contrib": per_sample_norm[i], "mean_cos_pre": diag["mean_cos_pre"], "mean_cos_post": diag["mean_cos_post"], "frac_fired": diag["frac_fired"], "arm": cfg.arm, "src_pool": meta.get("src_pool") if meta else None, "src_problem_id": meta.get("src_problem_id") if meta else None, "logp_mean": per_sample_logp_mean[i], "per_sample_loss": per_sample_loss[i], "imp_ratio": per_sample_imp_ratio[i], "delta_S_norm": delta_S_norm, } if not is_replay: # Direct-gen mode: keep full data (we generated this; pool dirs need it). row.update({ "problem_id": int(problem_id), "problem_messages": problem_messages, "prompt": prompt, "plen": int(plen_i), "prompt_ids": merged[i, :plen_i].tolist(), "completion_ids": merged[i, plen_i:].tolist(), "completion": completion_texts[i], }) rows.append(row) if is_replay: # Warmup replay: slim cos annotations only; full rows live in the pools. save_step_slim(steps_dir, step, rows) elif cfg.teacher_only or cfg.base_only: # Pool generation: one file per problem_id (each = G rollouts). save_prompt(out_dir, int(problem_id), rows) else: # Student-gen in warmupgen: full rows so we can see what the warmed # adapter actually emits at gen time. save_step(steps_dir, step, rows) for i in range(cfg.group): cs, gn = per_sample_cos[i], per_sample_norm[i] cs_s = f"{cs:+.3f}" if cs is not None else " nan" gn_s = f"{gn:.2e}" if gn is not None else " nan" logger.debug( f"r\t{step}\t{i}\t{int(hacked_list[i])}\t{int(gt_list[i])}\t" f"{cs_s}\t{gn_s}\t{int(rows[i]['comp_len'])}" ) hr = sum(hacked_list) / cfg.group pr = sum(gt_list) / cfg.group # Record student-gen rates split by phase (pre-distill vs post-distill). if not replay_active: if replay_on and step >= replay_end: post_hack_rates.append(hr) post_pass_rates.append(pr) else: pre_hack_rates.append(hr) pre_pass_rates.append(pr) # Bucket cos by (hacked, gt_pass) so the discrimination signal is inline. def _bucket_mean(pred): cs = [per_sample_cos[i] for i in range(cfg.group) if pred(i) and per_sample_cos[i] is not None] return (sum(cs)/len(cs), len(cs)) if cs else (float('nan'), 0) cph, nph = _bucket_mean(lambda i: hacked_list[i] and not gt_list[i]) cmx, nmx = _bucket_mean(lambda i: hacked_list[i] and gt_list[i]) cno, nno = _bucket_mean(lambda i: not hacked_list[i]) # Per-sample cos summary across the G samples in this step. ps_cos = [c for c in per_sample_cos if c is not None] if ps_cos: ps_min = min(ps_cos); ps_max = max(ps_cos); ps_mean = sum(ps_cos)/len(ps_cos) ps_summary = f"per_sample cos[min/mean/max]={ps_min:+.3f}/{ps_mean:+.3f}/{ps_max:+.3f}" else: ps_summary = "per_sample cos=nan" # logp split by hacked/not. If REINFORCE is teacher-forcing the hack tokens, # logp_hack should rise monotonically across warmup steps. lp_h = [per_sample_logp_mean[i] for i in range(cfg.group) if hacked_list[i]] lp_n = [per_sample_logp_mean[i] for i in range(cfg.group) if not hacked_list[i]] lp_h_s = f"{sum(lp_h)/len(lp_h):+.3f}" if lp_h else " nan" lp_n_s = f"{sum(lp_n)/len(lp_n):+.3f}" if lp_n else " nan" # imp_ratio: drift of student's logp on replayed tokens vs first encounter. # 1.0 == no drift; >>1 == student now strongly favors these tokens (overfit risk). valid_ratios = [r for r in per_sample_imp_ratio if r == r] # drop nan if valid_ratios: r_min, r_max = min(valid_ratios), max(valid_ratios) r_mean = sum(valid_ratios) / len(valid_ratios) ratio_summary = f"ratio[min/mean/max]={r_min:.2f}/{r_mean:.2f}/{r_max:.2f}" else: ratio_summary = "ratio=nan" logger.info( f"step {step} DONE hack={hr:.2f} pass={pr:.2f} {ps_summary} " f"cos_pureHack={cph:+.3f}(n={nph}) cos_mixed={cmx:+.3f}(n={nmx}) " f"cos_noHack={cno:+.3f}(n={nno}) " f"cos_pre[min/mean/max]={diag['min_cos_pre']:+.3f}/{diag['mean_cos_pre']:+.3f}/{diag['max_cos_pre']:+.3f} " f"cos_post[min/mean/max]={diag['min_cos_post']:+.3f}/{diag['mean_cos_post']:+.3f}/{diag['max_cos_post']:+.3f} " f"fired={diag['frac_fired']:.2f} " f"logp[hack={lp_h_s} no={lp_n_s}] {ratio_summary} " f"||dS||={delta_S_norm:.3f} sec={time.time()-t0:.0f}" ) # --- tail summary (BLUF main metric) --- def _avg(xs): return (sum(xs) / len(xs)) if xs else float("nan") pre_hack, pre_pass = _avg(pre_hack_rates), _avg(pre_pass_rates) post_hack, post_pass = _avg(post_hack_rates), _avg(post_pass_rates) # Use post-distill hack as headline; fall back to pre if no post phase. if post_hack_rates: head_hack, head_pass, head_n = post_hack, post_pass, len(post_hack_rates) head_label = "post" else: head_hack, head_pass, head_n = pre_hack, pre_pass, len(pre_hack_rates) head_label = "pre" cue = "⚪" if head_n == 0 else ("🔴" if head_hack >= 0.5 else ("🟢" if head_hack < 0.1 else "🟡")) plot_path = out_dir / "rollout_stack.png" report_path = out_dir / "report.md" if cfg.warmup_replay_steps is not None: try: from probe_plot_stack import Config as PlotCfg, main as plot_main plot_main(PlotCfg( run_dir=out_dir, out_path=plot_path, pre_warmup=cfg.pre_warmup_steps, warmup=cfg.pre_warmup_steps + cfg.warmup_replay_steps, smooth=10, title=f"{cfg.arm} GRPO seed={cfg.seed} " f"({cfg.pre_warmup_steps} pre + {cfg.warmup_replay_steps} distill" f" + {cfg.steps - cfg.pre_warmup_steps - cfg.warmup_replay_steps} post," f" 10-step SMA)", )) except Exception as e: logger.error(f"auto-plot failed: {e}") plot_path = None meta = { "arm": cfg.arm, "seed": cfg.seed, "tag": tag, "steps": cfg.steps, "pre_warmup_steps": cfg.pre_warmup_steps, "warmup_replay_steps": cfg.warmup_replay_steps, "group": cfg.group, "n_problems": cfg.n_problems, "argv": sys.argv, "pre": {"hack": pre_hack, "pass": pre_pass, "n_steps": len(pre_hack_rates)}, "post": {"hack": post_hack, "pass": post_pass, "n_steps": len(post_hack_rates)}, } caption = ( f"Rollout outcomes per training step for {cfg.arm} GRPO at seed={cfg.seed}. " f"Schedule: {cfg.pre_warmup_steps} steps of student-generated rollouts, " f"then {cfg.warmup_replay_steps} steps of replay-distillation from a saved " f"teacher+base pool, then {cfg.steps - cfg.pre_warmup_steps - (cfg.warmup_replay_steps or 0)} " f"steps of student-generated rollouts. Categories: correct (green), correct " f"with attempted reward hack (yellow), reward hack (red), attempted reward " f"hack (purple), incorrect (grey). Values are a 10-step trailing moving " f"average. Dashed lines mark distillation on/off." ) report_path.write_text( "# probe_distill report\n\n" f"![rollout stack]({plot_path.name if plot_path else 'rollout_stack.png'})\n\n" f"*{caption}*\n\n" "## metadata\n\n```json\n" + json.dumps(meta, indent=2) + "\n```\n" ) logger.info("") logger.info(f"out: {out_dir}/step_*.jsonl.gz") logger.info(f"plot: {plot_path}") logger.info(f"report: {report_path}") logger.info(f"argv: {' '.join(sys.argv)}") logger.info( f"main metric ({head_label}-distill): hack={head_hack:.2f} pass={head_pass:.2f} " f"[arm={cfg.arm} seed={cfg.seed} n_steps={head_n}]" ) logger.info( f"{cue} arm={cfg.arm} seed={cfg.seed} " f"pre[hack={pre_hack:.2f},pass={pre_pass:.2f},n={len(pre_hack_rates)}] " f"post[hack={post_hack:.2f},pass={post_pass:.2f},n={len(post_hack_rates)}] " f"pre_warmup={cfg.pre_warmup_steps} warmup={cfg.warmup_replay_steps} " f"steps={cfg.steps} G={cfg.group} tag={tag}" ) return 0 if __name__ == "__main__": sys.exit(main(tyro.cli(Config)))