mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 17:15:58 +08:00
b53043cec3
Cleanup by a prior agent, verified green here: 'just smoke' (erase arm) runs end-to-end and all four wired gates pass (verify_rewards 52/52, verify_eval_gap, verify_partition, verify_science_invariants). - train.py -318 lines: Config dataclass -> train_config.py, checkpoint/ deploy-artifact IO -> run_artifacts.py. - results.py / results_deploy.py / probe_distill.py slimmed. - drop stale derived csvs under out/figs (a5_generalisation, dyn_*, substrate_aggregate, train_vs_deploy_60). - gitignore /.pi/ panel scratch. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
431 lines
18 KiB
Python
431 lines
18 KiB
Python
"""Generate teacher/base pools or run the direct distillation probe.
|
|
|
|
Usage modes (via flags):
|
|
--teacher-only --steps=20 just generate+grade, save per-problem prompt files (no student work)
|
|
--base-only --steps=20 generate a mostly-clean base-model pool
|
|
(default) teacher generate + student train in one process
|
|
|
|
Teacher = ariahw/rl-rewardhacking-leetcode-rh-s65 (LoRA on Qwen3-4B, ~79%
|
|
hack rate at step 200 per paper Figure 3; "rh" = no-intervention arm
|
|
trained on the loophole env). Merged into base for plain HF inference.
|
|
Student = Qwen/Qwen3-4B + AntiPaSTO (own SVD basis, own delta_S grad).
|
|
|
|
Per-step pipeline:
|
|
1. Sample one problem; teacher generates G completions.
|
|
2. compute_reward per completion -> r, hacked, gt_pass.
|
|
3. (skip if teacher-only) Student logp on each completion (with grad); no separate no_grad old-policy pass, since the loss uses no PPO ratio.
|
|
4. (skip if teacher-only) For each sample i: snapshot delta_S.grad,
|
|
compute single-sample Dr.GRPO loss, backward, diff = contrib_i,
|
|
cos(contrib_i, v_hack) -> per-sample cos_S.
|
|
5. (skip if teacher-only, projected arm) project_delta_S_grad on accumulated grad.
|
|
6. (skip if teacher-only) opt.step().
|
|
7. Write step_NNN.jsonl.gz (prompt_NNNN.jsonl.gz in the pool modes): G JSON lines, one per sample.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import gzip
|
|
import json
|
|
import os
|
|
import sys
|
|
import time
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Literal
|
|
|
|
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
|
|
os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error")
|
|
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
|
|
|
|
import torch
|
|
import tyro
|
|
from loguru import logger
|
|
from peft import PeftModel
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
|
|
|
from vgrout.antipasto import wrap_model_with_antipasto
|
|
from vgrout.proj import per_token_logps, project_delta_S_grad
|
|
from vgrout.rewards import compute_reward
|
|
from vgrout.train import CACHE_ROOT, OUT_DIR, setup_logging
|
|
from vgrout.data import DATA, load_problems
|
|
from vgrout.vhack import load_v_hack
|
|
|
|
# Hugging Face id of the base checkpoint. In this file it is used for the
# student, for the tokenizer, and as the base the teacher adapter is merged into.
STUDENT_MODEL = "Qwen/Qwen3-4B"
|
|
|
|
|
|
@dataclass
class Config:
    """Command-line options for pool generation and the direct distillation probe.

    Parsed by ``tyro.cli`` in the script entry point and consumed by ``main``.
    """

    # Student gradient treatment: "projected" applies project_delta_S_grad,
    # anything else runs it in measure-only mode.
    arm: Literal["vanilla", "projected"] = "projected"
    # Adapter id handed to load_teacher (ignored when base_only is set).
    teacher: str = "ariahw/rl-rewardhacking-leetcode-rh-s65"
    # Number of generate/grade(/train) iterations.
    steps: int = 20
    # G: completions sampled per prompt (num_return_sequences).
    group: int = 8
    # max_new_tokens for generation; prompt length + max_new must stay <= 2048.
    max_new: int = 1024
    # Forwarded to load_problems as the number of problems to load.
    n_problems: int = 50
    # AdamW learning rate for the student's delta_S parameters.
    lr: float = 3e-4
    # NOTE(review): not read anywhere in this file; gradient-norm clipping in
    # main() uses a fixed 1.0.
    clip: float = 0.2
    # Seeds torch's global RNG and the problem-sampling generator.
    seed: int = 41
    # Forwarded to project_delta_S_grad.
    preserve_magnitude: bool = True
    # Inputs forwarded to load_v_hack.
    v_hack_path: Path = OUT_DIR / "vhack" / "v_hack_full.safetensors"
    pairs_path: Path = OUT_DIR / "pairsets" / "prog_wide.json"
    # Overrides the auto-derived run tag when non-empty.
    tag: str = ""
    # Generate + grade with the teacher only; no student is loaded.
    teacher_only: bool = False
    # Base pool: generate from base Qwen3-4B (no LoRA, no hint) -> mostly non-hack
    # samples. Used to populate the "no_hack" bucket for cosine comparison.
    base_only: bool = False
|
|
|
|
|
|
def load_student(device):
    """Load the student checkpoint, its tokenizer, and the AntiPaSTO wrappers.

    Args:
        device: torch device the model is moved to and that is passed on to
            ``wrap_model_with_antipasto``.

    Returns:
        ``(model, wrappers, tokenizer)`` where ``wrappers`` is whatever
        ``wrap_model_with_antipasto`` returns for the wrapped model.
    """
    tokenizer = AutoTokenizer.from_pretrained(STUDENT_MODEL)
    # Fall back to EOS as the pad token when the tokenizer defines none.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    load_kwargs = {
        "dtype": torch.bfloat16,
        "attn_implementation": "flash_attention_2",
    }
    student = AutoModelForCausalLM.from_pretrained(STUDENT_MODEL, **load_kwargs)
    student = student.to(device)
    # KV cache is off by default here; main() re-enables it only around generate().
    student.config.use_cache = False

    adapters = wrap_model_with_antipasto(student, STUDENT_MODEL, CACHE_ROOT, device)
    return student, adapters, tokenizer
|
|
|
|
|
|
def load_teacher(adapter_id: str, device):
    """Load the base checkpoint, merge the PEFT adapter into it, and freeze it.

    Args:
        adapter_id: identifier passed to ``PeftModel.from_pretrained``.
        device: torch device the merged model is moved to.

    Returns:
        The merged model, in eval mode, with every parameter's
        ``requires_grad`` switched off.
    """
    base_model = AutoModelForCausalLM.from_pretrained(
        STUDENT_MODEL,
        dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )
    # Fold the adapter weights into the base so inference needs no PEFT wrapper.
    teacher = PeftModel.from_pretrained(base_model, adapter_id).merge_and_unload()
    teacher = teacher.to(device)
    teacher.eval()
    # Generation only: no gradients are ever needed on the teacher.
    teacher.requires_grad_(False)
    return teacher
|
|
|
|
|
|
def norm_weighted_cos(contrib: dict[str, torch.Tensor], v_hack: dict[str, torch.Tensor]) -> float:
    """Return the square root of the pooled hack-subspace energy fraction.

        energy = sum_m ||V_m c_m||^2 / (sum_m ||c_m||^2 + 1e-12)
        result = sqrt(energy)

    ``V_m = v_hack[name]`` is used as a [k, r] matrix and ``c_m = contrib[name]``
    as a length-r vector. When the rows of each V_m are orthonormal (the
    original note says they come from the SVD top-k in extract_vhack_grad --
    not verifiable from this file), ``||V_m c_m||^2`` is the part of the
    per-module sample gradient lying in the hack subspace, so the result is in
    [0, 1]. It is a single scalar per sample, used for logging how hack-aligned
    a rollout's gradient is before projection.

    Args:
        contrib: per-module gradient contribution of one sample.
        v_hack: per-module hack basis; must contain every key of ``contrib``.

    Returns:
        sqrt(energy) as a Python float; 0.0 when ``contrib`` is empty or all-zero.
    """
    projected_sq = 0.0
    total_sq = 0.0
    for module_name, grad_vec in contrib.items():
        basis = v_hack[module_name]                  # [k, r]
        in_subspace = torch.matmul(basis, grad_vec)  # [k]
        projected_sq += float(torch.matmul(in_subspace, in_subspace).item())
        total_sq += float(torch.matmul(grad_vec, grad_vec).item())
    # Epsilon keeps the ratio finite when the whole contribution is zero.
    energy = projected_sq / (total_sq + 1e-12)
    return energy ** 0.5
|
|
|
|
|
|
def save_prompt(out_dir: Path, problem_id: int, rows: list[dict]) -> None:
    """Write one problem's rollouts to ``prompt_<id>.jsonl.gz`` (pool generation).

    Args:
        out_dir: destination directory; created (with parents) if missing.
        problem_id: used, zero-padded to 4 digits, in the file name.
        rows: JSON-serialisable dicts, written one per line.

    An existing file for the same ``problem_id`` is overwritten.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    target = out_dir / f"prompt_{problem_id:04d}.jsonl.gz"
    with gzip.open(target, "wt") as fh:
        fh.writelines(json.dumps(row) + "\n" for row in rows)
    logger.info(f"wrote {target.name} ({len(rows)} samples)")
|
|
|
|
|
|
def save_step(out_dir: Path, step: int, rows: list[dict]) -> None:
    """Write the full generated rows of one direct-probe step to ``step_<NNN>.jsonl.gz``.

    Args:
        out_dir: destination directory; created (with parents) if missing.
        step: step index, zero-padded to 3 digits in the file name.
        rows: JSON-serialisable dicts, written one per line.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    target = out_dir / f"step_{step:03d}.jsonl.gz"
    with gzip.open(target, "wt") as fh:
        fh.writelines(json.dumps(row) + "\n" for row in rows)
|
|
|
|
|
|
def main(cfg: Config) -> int:
    """Generate a teacher/base rollout pool, or run the direct distillation probe.

    Modes, selected by ``cfg``:
      * ``teacher_only``: the merged teacher generates, completions are graded
        and written as one ``prompt_<id>.jsonl.gz`` per problem; no student.
      * ``base_only``: same, but generation uses the plain base checkpoint and
        problems are loaded with the "gt_only" selector instead of "run_tests".
      * default: the teacher generates, and the student takes one optimizer
        step per iteration on a per-sample REINFORCE loss with a centered
        advantage; rows go to ``steps/step_<NNN>.jsonl.gz``.

    A ``report.md`` with run metadata and the mean hack/pass rates is written
    into the run directory at the end.

    Args:
        cfg: parsed command-line configuration.

    Returns:
        0 on completion.

    Raises:
        ValueError: if ``teacher_only`` and ``base_only`` are both set, or if a
            prompt plus ``cfg.max_new`` would exceed 2048 tokens.
    """
    # The two pool modes disagree on tag vs. generator (tag would say
    # "teacher_pool" while the base model generates), so refuse the combination
    # instead of writing mislabeled data.
    if cfg.teacher_only and cfg.base_only:
        raise ValueError("teacher_only and base_only are mutually exclusive")
    pool_mode = cfg.teacher_only or cfg.base_only

    if cfg.tag:
        tag = cfg.tag
    elif cfg.teacher_only:
        tag = "teacher_pool"
    elif cfg.base_only:
        tag = "base_pool"
    else:
        tag = f"{cfg.arm}_seed{cfg.seed}"
    run_id = f"distill_{tag}"
    setup_logging(run_id)
    torch.manual_seed(cfg.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    logger.info(f"argv: {' '.join(sys.argv)}")
    logger.info(f"arm={cfg.arm} teacher={cfg.teacher} steps={cfg.steps} "
                f"G={cfg.group} seed={cfg.seed} "
                f"teacher_only={cfg.teacher_only} base_only={cfg.base_only}")

    if pool_mode:
        # Pool generation needs only a tokenizer; all student state stays None.
        tok = AutoTokenizer.from_pretrained(STUDENT_MODEL)
        if tok.pad_token_id is None:
            tok.pad_token = tok.eos_token
        student = wrappers = delta_params = v_hack = opt = None
    else:
        student, wrappers, tok = load_student(device)
        delta_params = [info["delta_S"] for info in wrappers.values()]
        logger.info(f"student delta_S params: {sum(p.numel() for p in delta_params):,}")
        v_hack_cpu = load_v_hack(cfg.v_hack_path, STUDENT_MODEL, wrappers, cfg.pairs_path)
        v_hack = {n: v.to(device) for n, v in v_hack_cpu.items()}
        opt = torch.optim.AdamW(delta_params, lr=cfg.lr)

    if cfg.base_only:
        teacher = AutoModelForCausalLM.from_pretrained(
            STUDENT_MODEL, dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
        ).to(device)
        teacher.eval()
        for p in teacher.parameters():
            p.requires_grad_(False)
        logger.info("loaded base Qwen3-4B")
    else:
        teacher = load_teacher(cfg.teacher, device)
        logger.info("loaded reward-hacking teacher")
    problems = load_problems(cfg.n_problems, ["gt_only" if cfg.base_only else "run_tests"])
    if pool_mode and cfg.steps > len(problems):
        # Pool mode walks problems round-robin and names files by problem id,
        # so a wrapped index overwrites the earlier file for that problem.
        logger.warning(
            f"steps={cfg.steps} > {len(problems)} problems: pool files for "
            f"revisited problems will be overwritten"
        )
    gen_cfg = GenerationConfig(
        max_new_tokens=cfg.max_new, do_sample=True,
        temperature=1.0, top_p=1.0, top_k=20, min_p=0.0,
        repetition_penalty=1.0, num_return_sequences=cfg.group,
        pad_token_id=tok.pad_token_id,
    )

    # Pools are content-keyed (teacher_pool / base_pool). Pool files live flat
    # at the pool root (prompt_*.jsonl.gz). Training
    # runs get an ISO timestamp prefix and step files go in a `steps/` subdir.
    if pool_mode:
        out_dir = OUT_DIR / "pools" / tag  # teacher/base pools live under pools/
        steps_dir = out_dir
    else:
        from datetime import datetime
        stamp = datetime.now().strftime("%Y%m%dT%H%M%S")
        out_dir = OUT_DIR / "runs" / f"{stamp}_distill_{tag}"  # analysis run -> runs/
        steps_dir = out_dir / "steps"
    rng = torch.Generator().manual_seed(cfg.seed)
    pad_id = tok.pad_token_id

    logger.debug("row\tstep\tsample\thacked\tgt\tcos_S\t||g||\tcomp_len")
    logger.info(
        "SHOULD: ||dS|| grows during direct distillation; "
        "logp[hack] > logp[no] under teacher-forcing; "
        "projected arm hack < vanilla. "
        "ELSE: adapter not learning, basis mismatch, or loss not flowing."
    )

    hack_rates: list[float] = []
    pass_rates: list[float] = []

    for step in range(cfg.steps):
        t0 = time.time()
        if opt is not None:
            opt.zero_grad(set_to_none=True)

        # --- 1-2. generate + grade ----------------------------------------
        generator = teacher
        gen_label = "base" if cfg.base_only else "teacher"
        if pool_mode:
            # Deterministic round-robin so each problem gets its own pool file.
            idx = step % len(problems)
        else:
            idx = int(torch.randint(0, len(problems), (1,), generator=rng).item())
        prob = problems[idx]
        prompt = tok.apply_chat_template(
            prob["messages"], tokenize=False, add_generation_prompt=True,
            enable_thinking=False,
        )
        enc = tok(prompt, return_tensors="pt", add_special_tokens=False).to(device)
        plen = enc.input_ids.shape[1]
        if plen + cfg.max_new > 2048:
            raise ValueError(f"step {step}: plen+max_new={plen + cfg.max_new} exceeds 2048")
        # KV cache is enabled only for the duration of generation.
        generator.config.use_cache = True
        generator.eval()
        with torch.no_grad():
            merged = generator.generate(**enc, generation_config=gen_cfg).detach()
        generator.config.use_cache = False
        completion_texts = tok.batch_decode(merged[:, plen:], skip_special_tokens=True)
        rewards_list, hacked_list, gt_list, fmt_list = [], [], [], []
        for txt in completion_texts:
            r = compute_reward(
                txt, canonical_solution=prob["canonical"], gt_tests=prob["gt_tests"],
                setup_code=prob["setup_code"], func_name_hint=prob["func_name"],
            )
            rewards_list.append(r.reward)
            hacked_list.append(r.hacked)
            gt_list.append(r.gt_pass)
            fmt_list.append(r.format_ok)
        problem_id = prob["problem_id"]
        problem_messages = prob["messages"]
        per_sample_meta = [{"src_pool": gen_label, "src_problem_id": problem_id} for _ in range(cfg.group)]

        per_sample_cos: list[float | None] = [None] * cfg.group
        per_sample_norm: list[float | None] = [None] * cfg.group
        # NaN placeholders; replaced by project_delta_S_grad's output when a student runs.
        diag = {"mean_cos_pre": float("nan"), "min_cos_pre": float("nan"), "max_cos_pre": float("nan"),
                "mean_cos_post": float("nan"), "min_cos_post": float("nan"), "max_cos_post": float("nan"),
                "frac_fired": float("nan")}

        # Dr.GRPO unbiased advantage (centered, no /std).
        rewards_t = torch.tensor(rewards_list, dtype=torch.float32, device=device)
        adv = rewards_t - rewards_t.mean()

        # --- 3-6. student fwd+bwd+project+step (skip in teacher-only/base-only mode) ----
        per_sample_logp_mean: list[float] = [float("nan")] * cfg.group
        per_sample_loss: list[float] = [float("nan")] * cfg.group
        if not pool_mode:
            # Grads were reset above, so the running snapshot starts at zero.
            g_before = {n: torch.zeros_like(info["delta_S"]) for n, info in wrappers.items()}
            for i in range(cfg.group):
                mi = merged[i:i+1]
                ci = mi[:, plen:]
                L_c_i = ci.shape[1]
                logp_i = per_token_logps(
                    student(mi, logits_to_keep=L_c_i + 1).logits[:, :-1], ci,
                )
                mask = (ci != pad_id).float()
                per_sample_logp_mean[i] = float((logp_i * mask).sum().item() / max(1.0, mask.sum().item()))
                # Dr.GRPO REINFORCE: -adv * mean_logp. No PPO ratio because at step
                # start, student matches its own no_grad logp on these tokens.
                loss_i = -adv[i] * (logp_i * mask).sum() / mask.sum().clamp_min(1.0) / cfg.group
                per_sample_loss[i] = float(loss_i.item())
                loss_i.backward()
                # Grads accumulate across samples; the diff against the previous
                # snapshot isolates this sample's contribution.
                contrib = {n: info["delta_S"].grad - g_before[n]
                           for n, info in wrappers.items()}
                per_sample_cos[i] = norm_weighted_cos(contrib, v_hack)
                per_sample_norm[i] = float(sum(c.float().pow(2).sum().item() for c in contrib.values()) ** 0.5)
                g_before = {n: info["delta_S"].grad.clone() for n, info in wrappers.items()}

            # Both arms measure cos_pre/out; vanilla uses measure_only so the
            # gradient passes through unchanged.
            diag = project_delta_S_grad(
                wrappers, v_hack, cfg.preserve_magnitude,
                measure_only=(cfg.arm != "projected"),
            )
            torch.nn.utils.clip_grad_norm_(delta_params, 1.0)
            opt.step()

        # --- 6.5 adapter movement diagnostic ---
        delta_S_norm = (
            float(sum(info["delta_S"].data.float().pow(2).sum().item()
                      for info in wrappers.values()) ** 0.5)
            if wrappers is not None else 0.0
        )

        # --- 7. write full generated rows ---------------------------------
        rows = []
        for i in range(cfg.group):
            meta = per_sample_meta[i]
            row = {
                "step": step, "sample_id": i,
                "reward": float(rewards_list[i]),
                "hacked": bool(hacked_list[i]),
                "gt_pass": bool(gt_list[i]),
                "fmt_ok": bool(fmt_list[i]),
                "comp_len": int((merged[i, plen:] != pad_id).sum().item()),
                "cos_S_contrib": per_sample_cos[i],
                "grad_norm_contrib": per_sample_norm[i],
                "mean_cos_pre": diag["mean_cos_pre"],
                "mean_cos_post": diag["mean_cos_post"],
                "frac_fired": diag["frac_fired"],
                "arm": cfg.arm,
                "src_pool": meta["src_pool"],
                "src_problem_id": meta["src_problem_id"],
                "logp_mean": per_sample_logp_mean[i],
                "per_sample_loss": per_sample_loss[i],
                "delta_S_norm": delta_S_norm,
                "problem_id": int(problem_id),
                "problem_messages": problem_messages,
                "prompt": prompt,
                "plen": int(plen),
                "prompt_ids": merged[i, :plen].tolist(),
                "completion_ids": merged[i, plen:].tolist(),
                "completion": completion_texts[i],
            }
            rows.append(row)
        if pool_mode:
            # Pool generation: one file per problem_id (each = G rollouts).
            save_prompt(out_dir, int(problem_id), rows)
        else:
            save_step(steps_dir, step, rows)

        for i in range(cfg.group):
            cs, gn = per_sample_cos[i], per_sample_norm[i]
            cs_s = f"{cs:+.3f}" if cs is not None else "  nan"
            gn_s = f"{gn:.2e}" if gn is not None else "  nan"
            logger.debug(
                f"r\t{step}\t{i}\t{int(hacked_list[i])}\t{int(gt_list[i])}\t"
                f"{cs_s}\t{gn_s}\t{int(rows[i]['comp_len'])}"
            )
        hr = sum(hacked_list) / cfg.group
        pr = sum(gt_list) / cfg.group
        hack_rates.append(hr)
        pass_rates.append(pr)

        # Bucket cos by (hacked, gt_pass) so the discrimination signal is inline.
        def _bucket_mean(pred):
            cs = [per_sample_cos[i] for i in range(cfg.group)
                  if pred(i) and per_sample_cos[i] is not None]
            return (sum(cs)/len(cs), len(cs)) if cs else (float('nan'), 0)
        cph, nph = _bucket_mean(lambda i: hacked_list[i] and not gt_list[i])
        cmx, nmx = _bucket_mean(lambda i: hacked_list[i] and gt_list[i])
        cno, nno = _bucket_mean(lambda i: not hacked_list[i])
        # Per-sample cos summary across the G samples in this step.
        ps_cos = [c for c in per_sample_cos if c is not None]
        if ps_cos:
            ps_min = min(ps_cos)
            ps_max = max(ps_cos)
            ps_mean = sum(ps_cos)/len(ps_cos)
            ps_summary = f"per_sample cos[min/mean/max]={ps_min:+.3f}/{ps_mean:+.3f}/{ps_max:+.3f}"
        else:
            ps_summary = "per_sample cos=nan"
        # logp split by hacked/not. If REINFORCE is teacher-forcing the hack tokens,
        # logp_hack should rise across steps.
        lp_h = [per_sample_logp_mean[i] for i in range(cfg.group) if hacked_list[i]]
        lp_n = [per_sample_logp_mean[i] for i in range(cfg.group) if not hacked_list[i]]
        lp_h_s = f"{sum(lp_h)/len(lp_h):+.3f}" if lp_h else "  nan"
        lp_n_s = f"{sum(lp_n)/len(lp_n):+.3f}" if lp_n else "  nan"
        logger.info(
            f"step {step} DONE  hack={hr:.2f}  pass={pr:.2f}  {ps_summary}  "
            f"cos_pureHack={cph:+.3f}(n={nph})  cos_mixed={cmx:+.3f}(n={nmx})  "
            f"cos_noHack={cno:+.3f}(n={nno})  "
            f"cos_pre[min/mean/max]={diag['min_cos_pre']:+.3f}/{diag['mean_cos_pre']:+.3f}/{diag['max_cos_pre']:+.3f}  "
            f"cos_post[min/mean/max]={diag['min_cos_post']:+.3f}/{diag['mean_cos_post']:+.3f}/{diag['max_cos_post']:+.3f}  "
            f"fired={diag['frac_fired']:.2f}  "
            f"logp[hack={lp_h_s} no={lp_n_s}]  "
            f"||dS||={delta_S_norm:.3f}  sec={time.time()-t0:.0f}"
        )

    # --- tail summary (BLUF main metric) ---
    def _avg(xs):
        return (sum(xs) / len(xs)) if xs else float("nan")
    head_hack, head_pass, head_n = _avg(hack_rates), _avg(pass_rates), len(hack_rates)
    cue = "⚪" if head_n == 0 else ("🔴" if head_hack >= 0.5 else ("🟢" if head_hack < 0.1 else "🟡"))

    meta = {
        "arm": cfg.arm,
        "seed": cfg.seed,
        "tag": tag,
        "steps": cfg.steps,
        "group": cfg.group,
        "n_problems": cfg.n_problems,
        "argv": sys.argv,
        "hack": head_hack,
        "pass": head_pass,
    }
    # With steps == 0 nothing has been saved yet, so the run directory may not exist.
    out_dir.mkdir(parents=True, exist_ok=True)
    report_path = out_dir / "report.md"
    report_path.write_text(
        "# probe_distill report\n\n"
        "## metadata\n\n```json\n"
        + json.dumps(meta, indent=2) + "\n```\n"
    )

    # Point at where the rows were actually written: flat prompt_* files for
    # pools, steps/step_* files for probe runs.
    out_pattern = "prompt_*.jsonl.gz" if pool_mode else "step_*.jsonl.gz"
    logger.info("")
    logger.info(f"out: {steps_dir}/{out_pattern}")
    logger.info(f"report: {report_path}")
    logger.info(f"argv: {' '.join(sys.argv)}")
    logger.info(
        f"main metric: hack={head_hack:.2f}  pass={head_pass:.2f}  "
        f"[arm={cfg.arm}  seed={cfg.seed}  n_steps={head_n}]"
    )
    logger.info(
        f"{cue} arm={cfg.arm} seed={cfg.seed} "
        f"hack={head_hack:.2f} pass={head_pass:.2f} "
        f"steps={cfg.steps} G={cfg.group} tag={tag}"
    )
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Parse CLI flags into Config, run, and use main()'s return value as the exit status.
    exit_code = main(tyro.cli(Config))
    sys.exit(exit_code)
|