Files
evil_MoE/scripts/probe_distill.py
T
wassname b53043cec3 refactor: extract train_config.py + run_artifacts.py from train.py; slim results scripts
Cleanup by a prior agent, verified green here: 'just smoke' (erase arm)
runs end-to-end and all four wired gates pass (verify_rewards 52/52,
verify_eval_gap, verify_partition, verify_science_invariants).

- train.py -318 lines: Config dataclass -> train_config.py, checkpoint/
  deploy-artifact IO -> run_artifacts.py.
- results.py / results_deploy.py / probe_distill.py slimmed.
- drop stale derived csvs under out/figs (a5_generalisation, dyn_*,
  substrate_aggregate, train_vs_deploy_60).
- gitignore /.pi/ panel scratch.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
2026-06-09 13:34:50 +00:00

431 lines
18 KiB
Python

"""Generate teacher/base pools or run the direct distillation probe.

Usage modes (via flags):
    --teacher-only --steps=20   generate + grade teacher rollouts, save one
                                prompt file per problem (no student work)
    --base-only --steps=20      same, but from the plain base model -> a
                                mostly-clean (non-hack) pool
    (default)                   teacher generates + student trains in one process

Teacher = ariahw/rl-rewardhacking-leetcode-rh-s65 (LoRA on Qwen3-4B, ~79%
hack rate at step 200 per paper Figure 3; "rh" = no-intervention arm
trained on the loophole env). Merged into base for plain HF inference.
Student = Qwen/Qwen3-4B + AntiPaSTO (own SVD basis, own delta_S grad).

Per-step pipeline:
  1. Pick one problem; the generator (teacher, or base model in
     --base-only) produces G completions.
  2. compute_reward per completion -> reward, hacked, gt_pass, format_ok.
  3. (training only) Student log-probs on each completion, with grad. There
     is no separate no_grad old-policy pass and no PPO ratio: at step start
     the student is its own old policy.
  4. (training only) For each sample i: single-sample Dr.GRPO loss,
     backward, contrib_i = grad added by that backward (diff of delta_S.grad
     snapshots); hack-subspace alignment of contrib_i -> per-sample cos_S.
  5. (training only) project_delta_S_grad on the accumulated grad
     (measure-only in the vanilla arm), then clip grad norm to 1.0.
  6. (training only) opt.step().
  7. Write G JSON lines, one per sample: steps/step_NNN.jsonl.gz in a
     training run, or prompt_NNNN.jsonl.gz (keyed by problem_id) at the
     pool root in pool modes.
"""
from __future__ import annotations
import gzip
import json
import os
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Literal
# Process-wide library settings. They are placed before the torch /
# transformers imports below, presumably so those libraries pick the values
# up at import/initialisation time. setdefault keeps any value the caller
# already exported in the shell.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error")
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
import torch
import tyro
from loguru import logger
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from vgrout.antipasto import wrap_model_with_antipasto
from vgrout.proj import per_token_logps, project_delta_S_grad
from vgrout.rewards import compute_reward
from vgrout.train import CACHE_ROOT, OUT_DIR, setup_logging
from vgrout.data import DATA, load_problems
from vgrout.vhack import load_v_hack
# HF hub id of the shared base model: the student wraps it (load_student), the
# teacher LoRA is merged onto it (load_teacher), and --base-only samples from it.
STUDENT_MODEL = "Qwen/Qwen3-4B"
@dataclass
class Config:
    """Command-line options for probe_distill (parsed with tyro).

    Three modes: ``teacher_only`` (teacher rollout pool), ``base_only``
    (base-model rollout pool), or neither (teacher generates, student trains).
    ``teacher_only`` and ``base_only`` are mutually exclusive.
    """

    # "projected": project_delta_S_grad edits the accumulated delta_S grad;
    # "vanilla": it runs measure-only, so the grad passes through unchanged.
    arm: Literal["vanilla", "projected"] = "projected"
    # LoRA adapter (HF hub id) merged onto the base model to form the teacher.
    teacher: str = "ariahw/rl-rewardhacking-leetcode-rh-s65"
    # Number of steps; one prompt is used per step.
    steps: int = 20
    # G: completions generated per prompt (num_return_sequences).
    group: int = 8
    # max_new_tokens per completion; prompt length + max_new must be <= 2048.
    max_new: int = 1024
    # Number of problems requested from load_problems.
    n_problems: int = 50
    # AdamW learning rate for the student's delta_S parameters.
    lr: float = 3e-4
    # NOTE(review): not read anywhere in this script (no PPO ratio is used).
    clip: float = 0.2
    # Seeds torch globally and the problem-sampling generator.
    seed: int = 41
    # Forwarded to project_delta_S_grad.
    preserve_magnitude: bool = True
    # Hack-subspace basis file, forwarded to load_v_hack.
    v_hack_path: Path = OUT_DIR / "vhack" / "v_hack_full.safetensors"
    # Pair set, forwarded to load_v_hack.
    pairs_path: Path = OUT_DIR / "pairsets" / "prog_wide.json"
    # Overrides the auto-generated run/pool tag when non-empty.
    tag: str = ""
    # Teacher pool: generate + grade teacher rollouts only, no student.
    teacher_only: bool = False
    # Base pool: generate from base Qwen3-4B (no LoRA, no hint) -> mostly non-hack
    # samples. Used to populate the "no_hack" bucket for cosine comparison.
    base_only: bool = False

    def __post_init__(self) -> None:
        # With both flags set, main() would sample from the base model
        # (base_only) but name the output "teacher_pool" (teacher_only wins
        # the tag), silently mixing base rollouts into the teacher pool.
        if self.teacher_only and self.base_only:
            raise ValueError("teacher_only and base_only are mutually exclusive")
def load_student(device):
    """Load the trainable student: base Qwen3-4B wrapped with AntiPaSTO.

    Returns ``(model, wrappers, tokenizer)``. ``wrappers`` is whatever
    ``wrap_model_with_antipasto`` returns; main() reads ``info["delta_S"]``
    from each of its values as the trainable parameters.
    """
    tokenizer = AutoTokenizer.from_pretrained(STUDENT_MODEL)
    # Fall back to EOS when the tokenizer defines no pad token.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    base = AutoModelForCausalLM.from_pretrained(
        STUDENT_MODEL,
        dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )
    model = base.to(device)
    # The student is only scored (forward/backward), never used to generate.
    model.config.use_cache = False
    wrappers = wrap_model_with_antipasto(model, STUDENT_MODEL, CACHE_ROOT, device)
    return model, wrappers, tokenizer
def load_teacher(adapter_id: str, device):
    """Build a frozen teacher: base Qwen3-4B with LoRA ``adapter_id`` merged in.

    Merging the adapter into the base weights yields a plain HF model (no
    PEFT wrapper) that is used only for ``generate``.
    """
    peft_model = PeftModel.from_pretrained(
        AutoModelForCausalLM.from_pretrained(
            STUDENT_MODEL,
            dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
        ),
        adapter_id,
    )
    teacher = peft_model.merge_and_unload().to(device)
    teacher.eval()
    # Inference only: freeze every weight.
    for param in teacher.parameters():
        param.requires_grad_(False)
    return teacher
def norm_weighted_cos(contrib: dict[str, torch.Tensor], v_hack: dict[str, torch.Tensor]) -> float:
    """Per-sample alignment of a gradient with the top-k hack subspace.

    For each module m, ``V_m = v_hack[m]`` ([k, r]) is assumed to have
    orthonormal rows (SVD top-k from extract_vhack_grad), so
    ``||V_m c_m||^2`` is the energy of ``c_m`` inside the hack subspace.
    Pooling over modules:

        energy = sum_m ||V_m c_m||^2 / sum_m ||c_m||^2      in [0, 1]
        return  sqrt(energy)                                 in [0, 1]

    ``sqrt(energy)`` is the cosine between the concatenated sample gradient
    and its projection onto the hack subspace, hence the name. It is
    non-negative: it measures alignment, not direction. It is returned as a
    single scalar per sample for logging, as a pre-projection signal of how
    hack-aligned this rollout's gradient is. An all-zero gradient gives 0.0.

    Args:
        contrib: module name -> this sample's delta_S gradient contribution.
        v_hack: module name -> hack basis V_m; must contain every key of contrib.
    """
    num = 0.0
    den_sq = 0.0
    for name, c in contrib.items():
        # Work in float32: keeps V/c dtypes compatible (e.g. bf16 grads vs a
        # float32 basis) and avoids bf16 rounding in the sums of squares.
        c32 = c.float()
        coeffs = v_hack[name].float() @ c32  # [k]
        num += (coeffs @ coeffs).item()
        den_sq += (c32 @ c32).item()
    return (num / (den_sq + 1e-12)) ** 0.5
def _write_jsonl_gz(path: Path, rows: list[dict]) -> None:
    """Atomically write ``rows`` as gzip-compressed JSON lines to ``path``.

    Creates the parent directory if needed. Writes to a sibling ``.tmp``
    file first and renames it into place, so an interrupted run never leaves
    a truncated ``.jsonl.gz`` for later readers to choke on.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_name(path.name + ".tmp")
    with gzip.open(tmp, "wt", encoding="utf-8") as f:
        f.writelines(json.dumps(r) + "\n" for r in rows)
    os.replace(tmp, path)


def save_prompt(out_dir: Path, problem_id: int, rows: list[dict]) -> None:
    """Pool generation: one file per problem, G rollouts of that prompt.

    Writes ``out_dir/prompt_{problem_id:04d}.jsonl.gz``; saving the same
    problem_id again replaces the earlier file.
    """
    path = out_dir / f"prompt_{problem_id:04d}.jsonl.gz"
    _write_jsonl_gz(path, rows)
    logger.info(f"wrote {path.name} ({len(rows)} samples)")


def save_step(out_dir: Path, step: int, rows: list[dict]) -> None:
    """Save full generated rows for one direct probe step.

    Writes ``out_dir/step_{step:03d}.jsonl.gz``.
    """
    _write_jsonl_gz(out_dir / f"step_{step:03d}.jsonl.gz", rows)
def main(cfg: Config) -> int:
    """Run one probe_distill job: rollout-pool generation or the distillation probe.

    Modes (see Config):
      * teacher_only: teacher rollouts -> OUT_DIR/pools/<tag>/prompt_*.jsonl.gz
        (tag defaults to "teacher_pool").
      * base_only: base-model rollouts -> OUT_DIR/pools/<tag>/prompt_*.jsonl.gz
        (tag defaults to "base_pool").
      * default: teacher rollouts train the AntiPaSTO student with
        single-sample Dr.GRPO REINFORCE; rows go to
        OUT_DIR/runs/<timestamp>_distill_<tag>/steps/step_*.jsonl.gz.
    In every mode a report.md with run metadata is written to the output dir.

    Returns 0, used as the process exit code.
    Raises ValueError if no problems are loaded or a prompt is too long.
    """
    # Pool modes only generate + grade; the student/projection path is skipped.
    pool_mode = cfg.teacher_only or cfg.base_only
    if cfg.tag:
        tag = cfg.tag
    elif cfg.teacher_only:
        tag = "teacher_pool"
    elif cfg.base_only:
        tag = "base_pool"
    else:
        tag = f"{cfg.arm}_seed{cfg.seed}"
    run_id = f"distill_{tag}"
    setup_logging(run_id)
    torch.manual_seed(cfg.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"argv: {' '.join(sys.argv)}")
    logger.info(f"arm={cfg.arm} teacher={cfg.teacher} steps={cfg.steps} "
                f"G={cfg.group} seed={cfg.seed} "
                f"teacher_only={cfg.teacher_only} base_only={cfg.base_only}")

    # --- models ----------------------------------------------------------
    if pool_mode:
        # No student in pool modes: only the tokenizer is needed.
        tok = AutoTokenizer.from_pretrained(STUDENT_MODEL)
        if tok.pad_token_id is None:
            tok.pad_token = tok.eos_token
        student = wrappers = delta_params = v_hack = opt = None
    else:
        student, wrappers, tok = load_student(device)
        delta_params = [info["delta_S"] for info in wrappers.values()]
        logger.info(f"student delta_S params: {sum(p.numel() for p in delta_params):,}")
        v_hack_cpu = load_v_hack(cfg.v_hack_path, STUDENT_MODEL, wrappers, cfg.pairs_path)
        v_hack = {n: v.to(device) for n, v in v_hack_cpu.items()}
        opt = torch.optim.AdamW(delta_params, lr=cfg.lr)
    if cfg.base_only:
        # The plain base model (no LoRA), frozen, acts as the generator.
        teacher = AutoModelForCausalLM.from_pretrained(
            STUDENT_MODEL, dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
        ).to(device)
        teacher.eval()
        for p in teacher.parameters():
            p.requires_grad_(False)
        logger.info("loaded base Qwen3-4B")
    else:
        teacher = load_teacher(cfg.teacher, device)
        logger.info("loaded reward-hacking teacher")

    problems = load_problems(cfg.n_problems, ["gt_only" if cfg.base_only else "run_tests"])
    if not problems:
        # Both problem selectors below would fail obscurely on an empty list
        # (modulo by zero / empty randint range).
        raise ValueError(f"load_problems returned no problems (n_problems={cfg.n_problems})")
    gen_cfg = GenerationConfig(
        max_new_tokens=cfg.max_new, do_sample=True,
        temperature=1.0, top_p=1.0, top_k=20, min_p=0.0,
        repetition_penalty=1.0, num_return_sequences=cfg.group,
        pad_token_id=tok.pad_token_id,
    )

    # Pools are content-keyed (teacher_pool / base_pool). Pool files live flat
    # at the pool root (prompt_*.jsonl.gz). Training
    # runs get an ISO timestamp prefix and step files go in a `steps/` subdir.
    if pool_mode:
        out_dir = OUT_DIR / "pools" / tag  # teacher/base pools live under pools/
        steps_dir = out_dir
        if cfg.steps > len(problems):
            # Problems are visited as step % len(problems) and pool files are
            # keyed by problem_id, so later passes overwrite earlier files.
            logger.warning(
                f"steps={cfg.steps} > {len(problems)} problems: repeated problems "
                f"overwrite their prompt_*.jsonl.gz pool files"
            )
    else:
        from datetime import datetime
        stamp = datetime.now().strftime("%Y%m%dT%H%M%S")
        out_dir = OUT_DIR / "runs" / f"{stamp}_distill_{tag}"  # analysis run -> runs/
        steps_dir = out_dir / "steps"
    rng = torch.Generator().manual_seed(cfg.seed)
    pad_id = tok.pad_token_id
    logger.debug("row\tstep\tsample\thacked\tgt\tcos_S\t||g||\tcomp_len")
    logger.info(
        "SHOULD: ||dS|| grows during direct distillation; "
        "logp[hack] > logp[no] under teacher-forcing; "
        "projected arm hack < vanilla. "
        "ELSE: adapter not learning, basis mismatch, or loss not flowing."
    )
    hack_rates: list[float] = []
    pass_rates: list[float] = []
    for step in range(cfg.steps):
        t0 = time.time()
        if opt is not None:
            opt.zero_grad(set_to_none=True)

        # --- 1-2. generate + grade ----------------------------------------
        generator = teacher
        gen_label = "base" if cfg.base_only else "teacher"
        if pool_mode:
            # Pools walk the problems in order.
            idx = step % len(problems)
        else:
            idx = int(torch.randint(0, len(problems), (1,), generator=rng).item())
        prob = problems[idx]
        prompt = tok.apply_chat_template(
            prob["messages"], tokenize=False, add_generation_prompt=True,
            enable_thinking=False,
        )
        enc = tok(prompt, return_tensors="pt", add_special_tokens=False).to(device)
        plen = enc.input_ids.shape[1]
        if plen + cfg.max_new > 2048:
            raise ValueError(f"step {step}: plen+max_new={plen + cfg.max_new} exceeds 2048")
        # KV cache on only for generation.
        generator.config.use_cache = True
        generator.eval()
        with torch.no_grad():
            merged = generator.generate(**enc, generation_config=gen_cfg).detach()
        generator.config.use_cache = False
        completion_texts = tok.batch_decode(merged[:, plen:], skip_special_tokens=True)
        rewards_list, hacked_list, gt_list, fmt_list = [], [], [], []
        for txt in completion_texts:
            r = compute_reward(
                txt, canonical_solution=prob["canonical"], gt_tests=prob["gt_tests"],
                setup_code=prob["setup_code"], func_name_hint=prob["func_name"],
            )
            rewards_list.append(r.reward)
            hacked_list.append(r.hacked)
            gt_list.append(r.gt_pass)
            fmt_list.append(r.format_ok)
        problem_id = prob["problem_id"]
        problem_messages = prob["messages"]
        per_sample_meta = [{"src_pool": gen_label, "src_problem_id": problem_id} for _ in range(cfg.group)]
        per_sample_cos: list[float | None] = [None] * cfg.group
        per_sample_norm: list[float | None] = [None] * cfg.group
        # NaN placeholders; overwritten by project_delta_S_grad when training.
        diag = {"mean_cos_pre": float("nan"), "min_cos_pre": float("nan"), "max_cos_pre": float("nan"),
                "mean_cos_post": float("nan"), "min_cos_post": float("nan"), "max_cos_post": float("nan"),
                "frac_fired": float("nan")}
        # Dr.GRPO unbiased advantage (centered, no /std).
        rewards_t = torch.tensor(rewards_list, dtype=torch.float32, device=device)
        adv = rewards_t - rewards_t.mean()

        # --- 3-6. student fwd+bwd+project+step (skip in teacher-only/base-only mode) ----
        per_sample_logp_mean: list[float] = [float("nan")] * cfg.group
        per_sample_loss: list[float] = [float("nan")] * cfg.group
        if not pool_mode:
            # Snapshot of delta_S.grad before each sample's backward, so the
            # difference afterwards is that sample's own contribution.
            g_before = {n: torch.zeros_like(info["delta_S"]) for n, info in wrappers.items()}
            for i in range(cfg.group):
                mi = merged[i:i+1]
                ci = mi[:, plen:]
                L_c_i = ci.shape[1]
                # Keep L_c+1 logits and drop the last: the remaining L_c
                # positions predict the L_c completion tokens.
                logp_i = per_token_logps(
                    student(mi, logits_to_keep=L_c_i + 1).logits[:, :-1], ci,
                )
                mask = (ci != pad_id).float()
                per_sample_logp_mean[i] = float((logp_i * mask).sum().item() / max(1.0, mask.sum().item()))
                # Dr.GRPO REINFORCE: -adv * mean_logp. No PPO ratio because at step
                # start, student matches its own no_grad logp on these tokens.
                loss_i = -adv[i] * (logp_i * mask).sum() / mask.sum().clamp_min(1.0) / cfg.group
                per_sample_loss[i] = float(loss_i.item())
                loss_i.backward()
                contrib = {n: info["delta_S"].grad - g_before[n]
                           for n, info in wrappers.items()}
                per_sample_cos[i] = norm_weighted_cos(contrib, v_hack)
                per_sample_norm[i] = float(sum(c.float().pow(2).sum().item() for c in contrib.values()) ** 0.5)
                g_before = {n: info["delta_S"].grad.clone() for n, info in wrappers.items()}
            # Both arms measure cos_pre/out; vanilla uses measure_only so the
            # gradient passes through unchanged.
            diag = project_delta_S_grad(
                wrappers, v_hack, cfg.preserve_magnitude,
                measure_only=(cfg.arm != "projected"),
            )
            torch.nn.utils.clip_grad_norm_(delta_params, 1.0)
            opt.step()

        # --- 6.5 adapter movement diagnostic ---
        delta_S_norm = (
            float(sum(info["delta_S"].data.float().pow(2).sum().item()
                      for info in wrappers.values()) ** 0.5)
            if wrappers is not None else 0.0
        )

        # --- 7. write full generated rows ---------------------------------
        rows = []
        for i in range(cfg.group):
            sample_meta = per_sample_meta[i]
            row = {
                "step": step, "sample_id": i,
                "reward": float(rewards_list[i]),
                "hacked": bool(hacked_list[i]),
                "gt_pass": bool(gt_list[i]),
                "fmt_ok": bool(fmt_list[i]),
                "comp_len": int((merged[i, plen:] != pad_id).sum().item()),
                "cos_S_contrib": per_sample_cos[i],
                "grad_norm_contrib": per_sample_norm[i],
                "mean_cos_pre": diag["mean_cos_pre"],
                "mean_cos_post": diag["mean_cos_post"],
                "frac_fired": diag["frac_fired"],
                "arm": cfg.arm,
                "src_pool": sample_meta["src_pool"],
                "src_problem_id": sample_meta["src_problem_id"],
                "logp_mean": per_sample_logp_mean[i],
                "per_sample_loss": per_sample_loss[i],
                "delta_S_norm": delta_S_norm,
                "problem_id": int(problem_id),
                "problem_messages": problem_messages,
                "prompt": prompt,
                "plen": int(plen),
                "prompt_ids": merged[i, :plen].tolist(),
                "completion_ids": merged[i, plen:].tolist(),
                "completion": completion_texts[i],
            }
            rows.append(row)
        if pool_mode:
            # Pool generation: one file per problem_id (each = G rollouts).
            save_prompt(out_dir, int(problem_id), rows)
        else:
            save_step(steps_dir, step, rows)
        for i in range(cfg.group):
            cs, gn = per_sample_cos[i], per_sample_norm[i]
            cs_s = f"{cs:+.3f}" if cs is not None else " nan"
            gn_s = f"{gn:.2e}" if gn is not None else " nan"
            logger.debug(
                f"r\t{step}\t{i}\t{int(hacked_list[i])}\t{int(gt_list[i])}\t"
                f"{cs_s}\t{gn_s}\t{int(rows[i]['comp_len'])}"
            )
        # NOTE(review): hacked/gt_pass are graded on the generator's (teacher
        # or base) rollouts, not on student samples. In a training run, hr/pr
        # and the tail "main metric" therefore describe the teacher, not the
        # arm's effect on the student; the student-side signals are
        # logp[hack]/logp[no] and ||dS||.
        hr = sum(hacked_list) / cfg.group
        pr = sum(gt_list) / cfg.group
        hack_rates.append(hr)
        pass_rates.append(pr)

        # Bucket cos by (hacked, gt_pass) so the discrimination signal is inline.
        def _bucket_mean(pred):
            cs = [per_sample_cos[i] for i in range(cfg.group)
                  if pred(i) and per_sample_cos[i] is not None]
            return (sum(cs) / len(cs), len(cs)) if cs else (float('nan'), 0)

        cph, nph = _bucket_mean(lambda i: hacked_list[i] and not gt_list[i])
        cmx, nmx = _bucket_mean(lambda i: hacked_list[i] and gt_list[i])
        cno, nno = _bucket_mean(lambda i: not hacked_list[i])
        # Per-sample cos summary across the G samples in this step.
        ps_cos = [c for c in per_sample_cos if c is not None]
        if ps_cos:
            ps_min = min(ps_cos)
            ps_max = max(ps_cos)
            ps_mean = sum(ps_cos) / len(ps_cos)
            ps_summary = f"per_sample cos[min/mean/max]={ps_min:+.3f}/{ps_mean:+.3f}/{ps_max:+.3f}"
        else:
            ps_summary = "per_sample cos=nan"
        # logp split by hacked/not. If REINFORCE is teacher-forcing the hack tokens,
        # logp_hack should rise across steps.
        lp_h = [per_sample_logp_mean[i] for i in range(cfg.group) if hacked_list[i]]
        lp_n = [per_sample_logp_mean[i] for i in range(cfg.group) if not hacked_list[i]]
        lp_h_s = f"{sum(lp_h)/len(lp_h):+.3f}" if lp_h else " nan"
        lp_n_s = f"{sum(lp_n)/len(lp_n):+.3f}" if lp_n else " nan"
        logger.info(
            f"step {step} DONE hack={hr:.2f} pass={pr:.2f} {ps_summary} "
            f"cos_pureHack={cph:+.3f}(n={nph}) cos_mixed={cmx:+.3f}(n={nmx}) "
            f"cos_noHack={cno:+.3f}(n={nno}) "
            f"cos_pre[min/mean/max]={diag['min_cos_pre']:+.3f}/{diag['mean_cos_pre']:+.3f}/{diag['max_cos_pre']:+.3f} "
            f"cos_post[min/mean/max]={diag['min_cos_post']:+.3f}/{diag['mean_cos_post']:+.3f}/{diag['max_cos_post']:+.3f} "
            f"fired={diag['frac_fired']:.2f} "
            f"logp[hack={lp_h_s} no={lp_n_s}] "
            f"||dS||={delta_S_norm:.3f} sec={time.time()-t0:.0f}"
        )

    # --- tail summary (BLUF main metric) ---
    def _avg(xs):
        return (sum(xs) / len(xs)) if xs else float("nan")

    head_hack, head_pass, head_n = _avg(hack_rates), _avg(pass_rates), len(hack_rates)
    cue = "⚪" if head_n == 0 else ("🔴" if head_hack >= 0.5 else ("🟢" if head_hack < 0.1 else "🟡"))
    report_meta = {
        "arm": cfg.arm,
        "seed": cfg.seed,
        "tag": tag,
        "steps": cfg.steps,
        "group": cfg.group,
        "n_problems": cfg.n_problems,
        "argv": sys.argv,
        "hack": head_hack,
        "pass": head_pass,
    }
    # Only save_prompt/save_step create directories; with steps=0 nothing has
    # created out_dir yet, so ensure it exists before writing the report.
    out_dir.mkdir(parents=True, exist_ok=True)
    report_path = out_dir / "report.md"
    report_path.write_text(
        "# probe_distill report\n\n"
        "## metadata\n\n```json\n"
        + json.dumps(report_meta, indent=2) + "\n```\n",
        encoding="utf-8",
    )
    # Row files: prompt_* at the pool root, or step_* under steps/ for runs.
    rows_glob = "prompt_*.jsonl.gz" if pool_mode else "step_*.jsonl.gz"
    logger.info("")
    logger.info(f"out: {steps_dir}/{rows_glob}")
    logger.info(f"report: {report_path}")
    logger.info(f"argv: {' '.join(sys.argv)}")
    logger.info(
        f"main metric: hack={head_hack:.2f} pass={head_pass:.2f} "
        f"[arm={cfg.arm} seed={cfg.seed} n_steps={head_n}]"
    )
    logger.info(
        f"{cue} arm={cfg.arm} seed={cfg.seed} "
        f"hack={head_hack:.2f} pass={head_pass:.2f} "
        f"steps={cfg.steps} G={cfg.group} tag={tag}"
    )
    return 0
if __name__ == "__main__":
    # tyro builds a Config from argv; main()'s return value is the exit code.
    _cli_cfg = tyro.cli(Config)
    sys.exit(main(_cli_cfg))