"""Generate teacher/base pools or run the direct distillation probe. Usage modes (via flags): --teacher-only --steps=20 just generate+grade, save step files (no student work) --base-only --steps=20 generate a mostly-clean base-model pool (default) teacher generate + student train in one process Teacher = ariahw/rl-rewardhacking-leetcode-rh-s65 (LoRA on Qwen3-4B, ~79% hack rate at step 200 per paper Figure 3; "rh" = no-intervention arm trained on the loophole env). Merged into base for plain HF inference. Student = Qwen/Qwen3-4B + AntiPaSTO (own SVD basis, own delta_S grad). Per-step pipeline: 1. Sample one problem; teacher generates G completions. 2. compute_reward per completion -> r, hacked, gt_pass. 3. (skip if teacher-only) Old-policy logp: student.no_grad on all G batched. 4. (skip if teacher-only) For each sample i: snapshot delta_S.grad, compute single-sample Dr.GRPO loss, backward, diff = contrib_i, cos(contrib_i, v_hack) -> per-sample cos_S. 5. (skip if teacher-only, projected arm) project_delta_S_grad on accumulated grad. 6. (skip if teacher-only) opt.step(). 7. Write step_NNN.jsonl.gz: G JSON lines, one per sample. """ from __future__ import annotations import gzip import json import os import sys import time from dataclasses import dataclass from pathlib import Path from typing import Literal os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error") os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1") import torch import tyro from loguru import logger from peft import PeftModel from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig from vgrout.antipasto import wrap_model_with_antipasto from vgrout.proj import per_token_logps, project_delta_S_grad from vgrout.rewards import compute_reward from vgrout.train import CACHE_ROOT, OUT_DIR, setup_logging from vgrout.data import DATA, load_problems from vgrout.vhack import load_v_hack STUDENT_MODEL = "Qwen/Qwen3-4B" @dataclass class Config: arm: Literal["vanilla", "projected"] = "projected" teacher: str = "ariahw/rl-rewardhacking-leetcode-rh-s65" steps: int = 20 group: int = 8 max_new: int = 1024 n_problems: int = 50 lr: float = 3e-4 clip: float = 0.2 seed: int = 41 preserve_magnitude: bool = True v_hack_path: Path = OUT_DIR / "vhack" / "v_hack_full.safetensors" pairs_path: Path = OUT_DIR / "pairsets" / "prog_wide.json" tag: str = "" teacher_only: bool = False # Base pool: generate from base Qwen3-4B (no LoRA, no hint) -> mostly non-hack # samples. Used to populate the "no_hack" bucket for cosine comparison. base_only: bool = False def load_student(device): tok = AutoTokenizer.from_pretrained(STUDENT_MODEL) if tok.pad_token_id is None: tok.pad_token = tok.eos_token model = AutoModelForCausalLM.from_pretrained( STUDENT_MODEL, dtype=torch.bfloat16, attn_implementation="flash_attention_2", ).to(device) model.config.use_cache = False wrappers = wrap_model_with_antipasto(model, STUDENT_MODEL, CACHE_ROOT, device) return model, wrappers, tok def load_teacher(adapter_id: str, device): base = AutoModelForCausalLM.from_pretrained( STUDENT_MODEL, dtype=torch.bfloat16, attn_implementation="flash_attention_2", ) wrapped = PeftModel.from_pretrained(base, adapter_id) merged = wrapped.merge_and_unload() merged = merged.to(device) merged.eval() for p in merged.parameters(): p.requires_grad_(False) return merged def norm_weighted_cos(contrib: dict[str, torch.Tensor], v_hack: dict[str, torch.Tensor]) -> float: """Per-sample subspace-energy fraction across the top-k hack subspace. energy = sum_m ||V_m c_m||^2 / sum_m ||c_m||^2, result in [0, 1] V_m has rows orthonormal (from SVD top-k in extract_vhack_grad), so ||V_m c_m||^2 = sum_i ^2 = fraction of the per-module sample gradient lying in the hack subspace. Returned as a single scalar per sample for logging -- pre-projection signal of how hack-aligned this rollout is. """ num = 0.0 den_sq = 0.0 for name, c in contrib.items(): V = v_hack[name] # [k, r] coeffs = V @ c # [k] num += float((coeffs @ coeffs).item()) den_sq += float((c @ c).item()) return (num / (den_sq + 1e-12)) ** 0.5 def save_prompt(out_dir: Path, problem_id: int, rows: list[dict]) -> None: """Pool generation: one file per problem, G rollouts of that prompt.""" out_dir.mkdir(parents=True, exist_ok=True) path = out_dir / f"prompt_{problem_id:04d}.jsonl.gz" with gzip.open(path, "wt") as f: for r in rows: f.write(json.dumps(r) + "\n") logger.info(f"wrote {path.name} ({len(rows)} samples)") def save_step(out_dir: Path, step: int, rows: list[dict]) -> None: """Save full generated rows for one direct probe step.""" out_dir.mkdir(parents=True, exist_ok=True) path = out_dir / f"step_{step:03d}.jsonl.gz" with gzip.open(path, "wt") as f: for r in rows: f.write(json.dumps(r) + "\n") def main(cfg: Config) -> int: if cfg.tag: tag = cfg.tag elif cfg.teacher_only: tag = "teacher_pool" elif cfg.base_only: tag = "base_pool" else: tag = f"{cfg.arm}_seed{cfg.seed}" run_id = f"distill_{tag}" setup_logging(run_id) torch.manual_seed(cfg.seed) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") logger.info(f"argv: {' '.join(sys.argv)}") logger.info(f"arm={cfg.arm} teacher={cfg.teacher} steps={cfg.steps} " f"G={cfg.group} seed={cfg.seed} " f"teacher_only={cfg.teacher_only} base_only={cfg.base_only}") if cfg.teacher_only or cfg.base_only: tok = AutoTokenizer.from_pretrained(STUDENT_MODEL) if tok.pad_token_id is None: tok.pad_token = tok.eos_token student = wrappers = delta_params = v_hack = opt = None else: student, wrappers, tok = load_student(device) delta_params = [info["delta_S"] for info in wrappers.values()] logger.info(f"student delta_S params: {sum(p.numel() for p in delta_params):,}") v_hack_cpu = load_v_hack(cfg.v_hack_path, STUDENT_MODEL, wrappers, cfg.pairs_path) v_hack = {n: v.to(device) for n, v in v_hack_cpu.items()} opt = torch.optim.AdamW(delta_params, lr=cfg.lr) if cfg.base_only: teacher = AutoModelForCausalLM.from_pretrained( STUDENT_MODEL, dtype=torch.bfloat16, attn_implementation="flash_attention_2", ).to(device) teacher.eval() for p in teacher.parameters(): p.requires_grad_(False) logger.info("loaded base Qwen3-4B") else: teacher = load_teacher(cfg.teacher, device) logger.info("loaded reward-hacking teacher") problems = load_problems(cfg.n_problems, ["gt_only" if cfg.base_only else "run_tests"]) gen_cfg = GenerationConfig( max_new_tokens=cfg.max_new, do_sample=True, temperature=1.0, top_p=1.0, top_k=20, min_p=0.0, repetition_penalty=1.0, num_return_sequences=cfg.group, pad_token_id=tok.pad_token_id, ) # Pools are content-keyed (teacher_pool / base_pool). Pool files live flat # at the pool root (prompt_*.jsonl.gz). Training # runs get an ISO timestamp prefix and step files go in a `steps/` subdir. if cfg.teacher_only or cfg.base_only: out_dir = OUT_DIR / "pools" / tag # teacher/base pools live under pools/ steps_dir = out_dir else: from datetime import datetime stamp = datetime.now().strftime("%Y%m%dT%H%M%S") out_dir = OUT_DIR / "runs" / f"{stamp}_distill_{tag}" # analysis run -> runs/ steps_dir = out_dir / "steps" rng = torch.Generator().manual_seed(cfg.seed) pad_id = tok.pad_token_id logger.debug("row\tstep\tsample\thacked\tgt\tcos_S\t||g||\tcomp_len") logger.info( "SHOULD: ||dS|| grows during direct distillation; " "logp[hack] > logp[no] under teacher-forcing; " "projected arm hack < vanilla. " "ELSE: adapter not learning, basis mismatch, or loss not flowing." ) hack_rates: list[float] = [] pass_rates: list[float] = [] for step in range(cfg.steps): t0 = time.time() if opt is not None: opt.zero_grad(set_to_none=True) # --- 1-2. generate + grade ---------------------------------------- generator = teacher gen_label = "base" if cfg.base_only else "teacher" if cfg.teacher_only or cfg.base_only: idx = step % len(problems) else: idx = int(torch.randint(0, len(problems), (1,), generator=rng).item()) prob = problems[idx] prompt = tok.apply_chat_template( prob["messages"], tokenize=False, add_generation_prompt=True, enable_thinking=False, ) enc = tok(prompt, return_tensors="pt", add_special_tokens=False).to(device) plen = enc.input_ids.shape[1] if plen + cfg.max_new > 2048: raise ValueError(f"step {step}: plen+max_new={plen + cfg.max_new} exceeds 2048") generator.config.use_cache = True generator.eval() with torch.no_grad(): merged = generator.generate(**enc, generation_config=gen_cfg).detach() generator.config.use_cache = False completion_texts = tok.batch_decode(merged[:, plen:], skip_special_tokens=True) rewards_list, hacked_list, gt_list, fmt_list = [], [], [], [] for txt in completion_texts: r = compute_reward( txt, canonical_solution=prob["canonical"], gt_tests=prob["gt_tests"], setup_code=prob["setup_code"], func_name_hint=prob["func_name"], ) rewards_list.append(r.reward); hacked_list.append(r.hacked) gt_list.append(r.gt_pass); fmt_list.append(r.format_ok) problem_id = prob["problem_id"] problem_messages = prob["messages"] per_sample_meta = [{"src_pool": gen_label, "src_problem_id": problem_id} for _ in range(cfg.group)] per_sample_cos: list[float | None] = [None] * cfg.group per_sample_norm: list[float | None] = [None] * cfg.group diag = {"mean_cos_pre": float("nan"), "min_cos_pre": float("nan"), "max_cos_pre": float("nan"), "mean_cos_post": float("nan"), "min_cos_post": float("nan"), "max_cos_post": float("nan"), "frac_fired": float("nan")} # Dr.GRPO unbiased advantage (centered, no /std). rewards_t = torch.tensor(rewards_list, dtype=torch.float32, device=device) adv = rewards_t - rewards_t.mean() # --- 3-6. student fwd+bwd+project+step (skip in teacher-only/base-only mode) ---- per_sample_logp_mean: list[float] = [float("nan")] * cfg.group per_sample_loss: list[float] = [float("nan")] * cfg.group if not (cfg.teacher_only or cfg.base_only): g_before = {n: torch.zeros_like(info["delta_S"]) for n, info in wrappers.items()} for i in range(cfg.group): mi = merged[i:i+1] ci = mi[:, plen:] L_c_i = ci.shape[1] logp_i = per_token_logps( student(mi, logits_to_keep=L_c_i + 1).logits[:, :-1], ci, ) mask = (ci != pad_id).float() per_sample_logp_mean[i] = float((logp_i * mask).sum().item() / max(1.0, mask.sum().item())) # Dr.GRPO REINFORCE: -adv * mean_logp. No PPO ratio because at step # start, student matches its own no_grad logp on these tokens. loss_i = -adv[i] * (logp_i * mask).sum() / mask.sum().clamp_min(1.0) / cfg.group per_sample_loss[i] = float(loss_i.item()) loss_i.backward() contrib = {n: info["delta_S"].grad - g_before[n] for n, info in wrappers.items()} per_sample_cos[i] = norm_weighted_cos(contrib, v_hack) per_sample_norm[i] = float(sum(c.float().pow(2).sum().item() for c in contrib.values()) ** 0.5) g_before = {n: info["delta_S"].grad.clone() for n, info in wrappers.items()} # Both arms measure cos_pre/out; vanilla uses measure_only so the # gradient passes through unchanged. diag = project_delta_S_grad( wrappers, v_hack, cfg.preserve_magnitude, measure_only=(cfg.arm != "projected"), ) torch.nn.utils.clip_grad_norm_(delta_params, 1.0) opt.step() # --- 6.5 adapter movement diagnostic --- delta_S_norm = ( float(sum(info["delta_S"].data.float().pow(2).sum().item() for info in wrappers.values()) ** 0.5) if wrappers is not None else 0.0 ) # --- 7. write full generated rows --------------------------------- rows = [] for i in range(cfg.group): meta = per_sample_meta[i] row = { "step": step, "sample_id": i, "reward": float(rewards_list[i]), "hacked": bool(hacked_list[i]), "gt_pass": bool(gt_list[i]), "fmt_ok": bool(fmt_list[i]), "comp_len": int((merged[i, plen:] != pad_id).sum().item()), "cos_S_contrib": per_sample_cos[i], "grad_norm_contrib": per_sample_norm[i], "mean_cos_pre": diag["mean_cos_pre"], "mean_cos_post": diag["mean_cos_post"], "frac_fired": diag["frac_fired"], "arm": cfg.arm, "src_pool": meta["src_pool"], "src_problem_id": meta["src_problem_id"], "logp_mean": per_sample_logp_mean[i], "per_sample_loss": per_sample_loss[i], "delta_S_norm": delta_S_norm, "problem_id": int(problem_id), "problem_messages": problem_messages, "prompt": prompt, "plen": int(plen), "prompt_ids": merged[i, :plen].tolist(), "completion_ids": merged[i, plen:].tolist(), "completion": completion_texts[i], } rows.append(row) if cfg.teacher_only or cfg.base_only: # Pool generation: one file per problem_id (each = G rollouts). save_prompt(out_dir, int(problem_id), rows) else: save_step(steps_dir, step, rows) for i in range(cfg.group): cs, gn = per_sample_cos[i], per_sample_norm[i] cs_s = f"{cs:+.3f}" if cs is not None else " nan" gn_s = f"{gn:.2e}" if gn is not None else " nan" logger.debug( f"r\t{step}\t{i}\t{int(hacked_list[i])}\t{int(gt_list[i])}\t" f"{cs_s}\t{gn_s}\t{int(rows[i]['comp_len'])}" ) hr = sum(hacked_list) / cfg.group pr = sum(gt_list) / cfg.group hack_rates.append(hr) pass_rates.append(pr) # Bucket cos by (hacked, gt_pass) so the discrimination signal is inline. def _bucket_mean(pred): cs = [per_sample_cos[i] for i in range(cfg.group) if pred(i) and per_sample_cos[i] is not None] return (sum(cs)/len(cs), len(cs)) if cs else (float('nan'), 0) cph, nph = _bucket_mean(lambda i: hacked_list[i] and not gt_list[i]) cmx, nmx = _bucket_mean(lambda i: hacked_list[i] and gt_list[i]) cno, nno = _bucket_mean(lambda i: not hacked_list[i]) # Per-sample cos summary across the G samples in this step. ps_cos = [c for c in per_sample_cos if c is not None] if ps_cos: ps_min = min(ps_cos); ps_max = max(ps_cos); ps_mean = sum(ps_cos)/len(ps_cos) ps_summary = f"per_sample cos[min/mean/max]={ps_min:+.3f}/{ps_mean:+.3f}/{ps_max:+.3f}" else: ps_summary = "per_sample cos=nan" # logp split by hacked/not. If REINFORCE is teacher-forcing the hack tokens, # logp_hack should rise across steps. lp_h = [per_sample_logp_mean[i] for i in range(cfg.group) if hacked_list[i]] lp_n = [per_sample_logp_mean[i] for i in range(cfg.group) if not hacked_list[i]] lp_h_s = f"{sum(lp_h)/len(lp_h):+.3f}" if lp_h else " nan" lp_n_s = f"{sum(lp_n)/len(lp_n):+.3f}" if lp_n else " nan" logger.info( f"step {step} DONE hack={hr:.2f} pass={pr:.2f} {ps_summary} " f"cos_pureHack={cph:+.3f}(n={nph}) cos_mixed={cmx:+.3f}(n={nmx}) " f"cos_noHack={cno:+.3f}(n={nno}) " f"cos_pre[min/mean/max]={diag['min_cos_pre']:+.3f}/{diag['mean_cos_pre']:+.3f}/{diag['max_cos_pre']:+.3f} " f"cos_post[min/mean/max]={diag['min_cos_post']:+.3f}/{diag['mean_cos_post']:+.3f}/{diag['max_cos_post']:+.3f} " f"fired={diag['frac_fired']:.2f} " f"logp[hack={lp_h_s} no={lp_n_s}] " f"||dS||={delta_S_norm:.3f} sec={time.time()-t0:.0f}" ) # --- tail summary (BLUF main metric) --- def _avg(xs): return (sum(xs) / len(xs)) if xs else float("nan") head_hack, head_pass, head_n = _avg(hack_rates), _avg(pass_rates), len(hack_rates) cue = "⚪" if head_n == 0 else ("🔴" if head_hack >= 0.5 else ("🟢" if head_hack < 0.1 else "🟡")) meta = { "arm": cfg.arm, "seed": cfg.seed, "tag": tag, "steps": cfg.steps, "group": cfg.group, "n_problems": cfg.n_problems, "argv": sys.argv, "hack": head_hack, "pass": head_pass, } report_path = out_dir / "report.md" report_path.write_text( "# probe_distill report\n\n" "## metadata\n\n```json\n" + json.dumps(meta, indent=2) + "\n```\n" ) logger.info("") logger.info(f"out: {out_dir}/step_*.jsonl.gz") logger.info(f"report: {report_path}") logger.info(f"argv: {' '.join(sys.argv)}") logger.info( f"main metric: hack={head_hack:.2f} pass={head_pass:.2f} " f"[arm={cfg.arm} seed={cfg.seed} n_steps={head_n}]" ) logger.info( f"{cue} arm={cfg.arm} seed={cfg.seed} " f"hack={head_hack:.2f} pass={head_pass:.2f} " f"steps={cfg.steps} G={cfg.group} tag={tag}" ) return 0 if __name__ == "__main__": sys.exit(main(tyro.cli(Config)))