mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 18:59:35 +08:00
55937a86fb
git mv src/projected_grpo -> src/vgrout and find-replace the module name in
all imports (.py), `-m projected_grpo.*` invocations (justfile), and the
[project] name (pyproject; setuptools auto-discovers via where=["src"]).
Left RESEARCH_JOURNAL.md untouched: its commands/paths are dated lab notes
tied to past commits, so rewriting them would falsify provenance. Repo dir,
git remote, and absolute paths unchanged.
Verified: `import vgrout` and `python -m vgrout.train --help` load the full
graph; verify_rewards.py + verify_gate_anchor.py (both import vgrout) pass.
Full `just smoke` is blocked upstream by missing gitignored data artifacts
(out/pools/{substrate,teacher_pool}, out/vhack/*smoke*), unrelated to the rename.
664 lines
32 KiB
Python
664 lines
32 KiB
Python
"""Distillation probe: hacky teacher samples, student trains with per-sample
v_hack cosine logging. One file per step (step_NNN.jsonl.gz) so a saved
step can be replayed (student fwd+bwd+project re-run on cached completions).

Usage modes (via flags):
    --teacher-only --steps=20   just generate+grade, save step files (no student work)
    --replay-dir=PATH           student fwd+bwd+project on saved batches (no teacher)
    (default)                   teacher generate + student train in one process

Teacher = ariahw/rl-rewardhacking-leetcode-rh-s65 (LoRA on Qwen3-4B, ~79%
hack rate at step 200 per paper Figure 3; "rh" = no-intervention arm
trained on the loophole env). Merged into base for plain HF inference.
Student = Qwen/Qwen3-4B + AntiPaSTO (own SVD basis, own delta_S grad).

Known methodological caveat (flagged 2026-05-25):
    v_hack is extracted via NLL gradient (extract_vhack_grad.py) on
    contrastive pairs. GRPO's policy gradient is reward-weighted, not NLL.
    If the per-sample cosine separation (hacked vs not) fails, the fallback
    is to re-extract v_hack with a GRPO-style contrastive loss while
    keeping the same persona pairs.

Per-step pipeline:
    1. (skip if replay) Sample one problem; teacher generates G completions.
    2. (skip if replay) compute_reward per completion -> r, hacked, gt_pass.
    3. (skip if teacher-only) Old-policy logp: student.no_grad on all G batched.
    4. (skip if teacher-only) For each sample i: snapshot delta_S.grad,
       compute single-sample Dr.GRPO loss, backward, diff = contrib_i,
       cos(contrib_i, v_hack) -> per-sample cos_S.
    5. (skip if teacher-only, projected arm) project_delta_S_grad on accumulated grad.
    6. (skip if teacher-only) opt.step().
    7. Write step_NNN.jsonl.gz: G JSON lines, one per sample.
"""
|
|
from __future__ import annotations
|
|
|
|
import gzip
|
|
import json
|
|
import os
|
|
import sys
|
|
import time
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Literal
|
|
|
|
# Environment defaults, applied before torch/transformers are imported below.
# setdefault leaves any value the caller already exported untouched.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error")
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
|
|
|
|
import torch
|
|
import tyro
|
|
from loguru import logger
|
|
from peft import PeftModel
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
|
|
|
from vgrout.antipasto import wrap_model_with_antipasto
|
|
from vgrout.proj import per_token_logps, project_delta_S_grad
|
|
from vgrout.rewards import compute_reward
|
|
from vgrout.train import CACHE_ROOT, OUT_DIR, setup_logging
|
|
from vgrout.problems import DATA, load_problems
|
|
from vgrout.extract_vhack_grad import load_v_hack
|
|
|
|
# Hugging Face id of the base checkpoint. Used for the student, as the base the
# teacher LoRA is merged into, and as the tokenizer source.
STUDENT_MODEL = "Qwen/Qwen3-4B"
|
|
|
|
|
|
@dataclass
class Config:
    # "projected" applies project_delta_S_grad to the accumulated gradient;
    # any other arm calls it with measure_only=True (gradient left unchanged).
    arm: Literal["vanilla", "projected"] = "projected"
    # PEFT adapter id handed to load_teacher (merged into STUDENT_MODEL).
    teacher: str = "ariahw/rl-rewardhacking-leetcode-rh-s65"
    # Number of iterations of the main loop.
    steps: int = 20
    # G: completions per prompt (num_return_sequences) and rows per step file.
    group: int = 8
    # max_new_tokens for generation; steps whose prompt length + max_new
    # exceeds 2048 are skipped.
    max_new: int = 1024
    # Forwarded to load_problems.
    n_problems: int = 50
    # AdamW learning rate for the delta_S parameters.
    lr: float = 3e-4
    # NOTE(review): not read anywhere in the visible code (the loss has no PPO
    # ratio) -- confirm whether it is still needed.
    clip: float = 0.2
    # Seeds torch.manual_seed and the generator used to pick problems.
    seed: int = 41
    # Forwarded to project_delta_S_grad.
    preserve_magnitude: bool = True
    # File handed to load_v_hack.
    v_hack_path: Path = OUT_DIR / "vhack" / "v_hack_full.safetensors"
    # Overrides the auto-derived run tag when non-empty.
    tag: str = ""
    # Single pool directory of prompt_NNNN.jsonl.gz files to replay.
    replay_dir: Path | None = None
    # Generate + grade with the teacher only; no student is loaded.
    teacher_only: bool = False
    # Base pool: generate from base Qwen3-4B (no LoRA, no hint) -> mostly non-hack
    # samples. Used to populate the "no_hack" bucket for cosine comparison.
    base_only: bool = False
    # TODO(spec2 §"Phase 2"): mixed-replay GRPO was started here but not finished
    # (the original note was truncated after "then user"); train.py at small
    # scale is the canonical Phase 2 mechanism.
    # FIXME: the replay fields below are wired into the loader (heterogeneous
    # plen handling) but the GRPO loss path is incomplete -- finish or remove.
    # Comma-separated pool directories; each contributes group // n_pools samples.
    replay_dirs: str | None = None
    # Sandwich schedule: [0, pre) student-gen -> [pre, pre+replay) replay-distill
    # -> [pre+replay, steps) student-gen. With pre_warmup_steps=0 reduces to the
    # original "replay then gen" schedule.
    pre_warmup_steps: int = 0
    # Length of the replay-distill phase; None means every step replays when a
    # replay directory is configured.
    warmup_replay_steps: int | None = None
|
|
|
|
|
|
def load_student(device):
    """Load the student: tokenizer, bf16 base model on ``device``, AntiPaSTO wrappers.

    Returns ``(model, wrappers, tokenizer)``. The tokenizer falls back to EOS
    as its pad token when none is configured, and the KV cache is switched off
    on the model config before the model is wrapped.
    """
    tokenizer = AutoTokenizer.from_pretrained(STUDENT_MODEL)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    load_kwargs = {
        "dtype": torch.bfloat16,
        "attn_implementation": "flash_attention_2",
    }
    student = AutoModelForCausalLM.from_pretrained(STUDENT_MODEL, **load_kwargs)
    student = student.to(device)
    student.config.use_cache = False

    antipasto = wrap_model_with_antipasto(student, STUDENT_MODEL, CACHE_ROOT, device)
    return student, antipasto, tokenizer
|
|
|
|
|
|
def load_teacher(adapter_id: str, device):
    """Return the teacher: ``adapter_id`` merged into the base model, frozen, in eval mode.

    The base checkpoint is loaded in bf16, the PEFT adapter is applied and
    folded into the weights with ``merge_and_unload``, and the result is moved
    to ``device`` with gradients disabled on every parameter.
    """
    backbone = AutoModelForCausalLM.from_pretrained(
        STUDENT_MODEL,
        dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )
    teacher = PeftModel.from_pretrained(backbone, adapter_id).merge_and_unload()
    teacher = teacher.to(device)
    teacher.eval()
    # Module-level requires_grad_ flips the flag on every parameter in place.
    teacher.requires_grad_(False)
    return teacher
|
|
|
|
|
|
def norm_weighted_cos(contrib: dict[str, torch.Tensor], v_hack: dict[str, torch.Tensor]) -> float:
    """Fraction of a per-sample gradient's norm that lies in the hack subspace.

    Computes, pooled over all modules ``m`` in ``contrib``::

        sqrt( sum_m ||V_m c_m||^2 / (sum_m ||c_m||^2 + 1e-12) )

    i.e. the square root of the subspace-energy fraction, which is the
    norm-weighted cosine between the gradient and its projection onto the
    subspace. The value is >= 0, so it carries no sign information.

    Args:
        contrib: module name -> 1-D per-sample gradient contribution ``c_m``.
        v_hack: module name -> matrix ``V_m`` whose rows span the hack
            subspace for that module. Must contain every key of ``contrib``.
            The result is bounded by 1 only if the rows of each ``V_m`` are
            orthonormal -- presumably guaranteed by the SVD in
            extract_vhack_grad; TODO confirm.

    Returns:
        A Python float; 0.0 when ``contrib`` is empty or all-zero (the 1e-12
        term keeps the division defined).

    Raises:
        KeyError: if a module in ``contrib`` is missing from ``v_hack``.
    """
    subspace_sq = 0.0
    total_sq = 0.0
    for name, c in contrib.items():
        V = v_hack[name]      # [k, r]
        coeffs = V @ c        # [k] coordinates of c in the subspace basis
        subspace_sq += float((coeffs @ coeffs).item())
        total_sq += float((c @ c).item())
    return (subspace_sq / (total_sq + 1e-12)) ** 0.5
|
|
|
|
|
|
def save_prompt(out_dir: Path, problem_id: int, rows: list[dict]) -> None:
    """Write one pool file, ``prompt_NNNN.jsonl.gz``, holding the rollouts of one problem.

    Creates ``out_dir`` if needed, writes each row as one JSON line, and logs
    the file name together with the number of rows.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    target = out_dir / f"prompt_{problem_id:04d}.jsonl.gz"
    with gzip.open(target, "wt") as fh:
        fh.writelines(json.dumps(row) + "\n" for row in rows)
    logger.info(f"wrote {target.name} ({len(rows)} samples)")
|
|
|
|
|
|
def save_step(out_dir: Path, step: int, rows: list[dict]) -> None:
    """Write the full rows of one step to ``step_NNN.jsonl.gz`` (one JSON object per line).

    ``out_dir`` is created, including parents, when it does not exist yet.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    target = out_dir / f"step_{step:03d}.jsonl.gz"
    with gzip.open(target, "wt") as fh:
        fh.writelines(json.dumps(row) + "\n" for row in rows)
|
|
|
|
|
|
def save_step_slim(out_dir: Path, step: int, rows: list[dict]) -> None:
    """Write the slim per-sample annotations of one step to ``step_NNN.cos.jsonl.gz``.

    Each row is projected onto a fixed key list (cosine/diagnostic values and
    flags); keys missing from a row are written as null and any other key,
    such as prompts or completions, is dropped.
    """
    slim_keys = (
        "step", "sample_id", "src_pool", "src_problem_id",
        "reward", "hacked", "gt_pass", "fmt_ok", "comp_len",
        "cos_S_contrib", "grad_norm_contrib",
        "mean_cos_pre", "mean_cos_post", "frac_fired", "arm",
        "logp_mean", "delta_S_norm", "imp_ratio",
    )
    out_dir.mkdir(parents=True, exist_ok=True)
    target = out_dir / f"step_{step:03d}.cos.jsonl.gz"
    with gzip.open(target, "wt") as fh:
        for row in rows:
            slim = {key: row.get(key) for key in slim_keys}
            fh.write(json.dumps(slim) + "\n")
|
|
|
|
|
|
def load_prompt(pool_dir: Path, problem_id: int) -> list[dict]:
    """Read ``prompt_NNNN.jsonl.gz`` from ``pool_dir`` and return its rows in file order."""
    rows = []
    with gzip.open(pool_dir / f"prompt_{problem_id:04d}.jsonl.gz", "rt") as fh:
        for raw in fh:
            rows.append(json.loads(raw))
    return rows
|
|
|
|
|
|
def _build_gen_cfg(cfg: Config, pad_token_id: int | None) -> GenerationConfig:
    """Sampling settings shared by pool generation and student generation."""
    return GenerationConfig(
        max_new_tokens=cfg.max_new, do_sample=True,
        temperature=1.0, top_p=1.0, top_k=20, min_p=0.0,
        repetition_penalty=1.0, num_return_sequences=cfg.group,
        pad_token_id=pad_token_id,
    )


def main(cfg: Config) -> int:
    """Run the distillation probe in the mode selected by ``cfg``.

    Modes:
      * ``teacher_only`` / ``base_only``: no student is loaded. The teacher (or
        the plain base model) generates ``cfg.group`` completions per problem,
        they are graded, and one ``prompt_NNNN.jsonl.gz`` per problem is written
        under ``OUT_DIR/pools/<tag>``.
      * replay (``replay_dir`` / ``replay_dirs``): saved pool rollouts are fed to
        the student for forward/backward, per-sample cosine logging, optional
        projection and an optimizer step; slim ``step_NNN.cos.jsonl.gz`` files
        are written.
      * sandwich (replay plus ``warmup_replay_steps``): student generation for
        ``[0, pre_warmup_steps)``, replay for the next ``warmup_replay_steps``
        steps, then student generation again.
      * default: the teacher generates and the student trains in one process.

    After the loop a ``report.md`` (and, in sandwich mode, a best-effort plot)
    is written to the run directory and the headline hack/pass rates are logged.

    Returns:
        Always 0 (used as the process exit status).
    """
    if cfg.tag:
        tag = cfg.tag
    elif cfg.teacher_only:
        tag = "teacher_pool"
    elif cfg.base_only:
        tag = "base_pool"
    else:
        tag = f"{cfg.arm}_seed{cfg.seed}"
    run_id = f"distill_{tag}"
    setup_logging(run_id)
    torch.manual_seed(cfg.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    has_replay = cfg.replay_dir is not None or cfg.replay_dirs is not None
    # warmup_replay_steps boundary: before it, replay from saved pools; after,
    # student generates with its learned adapter (canonical GRPO).
    replay_on = cfg.warmup_replay_steps is not None
    replay_end = (cfg.pre_warmup_steps + cfg.warmup_replay_steps) if replay_on else None

    logger.info(f"argv: {' '.join(sys.argv)}")
    logger.info(f"arm={cfg.arm} teacher={cfg.teacher} steps={cfg.steps} "
                f"G={cfg.group} seed={cfg.seed} "
                f"teacher_only={cfg.teacher_only} replay={has_replay}")

    if cfg.teacher_only or cfg.base_only:
        # Pool generation only needs a tokenizer; all student state stays None.
        tok = AutoTokenizer.from_pretrained(STUDENT_MODEL)
        if tok.pad_token_id is None:
            tok.pad_token = tok.eos_token
        student = wrappers = delta_params = v_hack = opt = None
    else:
        student, wrappers, tok = load_student(device)
        delta_params = [info["delta_S"] for info in wrappers.values()]
        logger.info(f"student delta_S params: {sum(p.numel() for p in delta_params):,}")
        v_hack_cpu = load_v_hack(cfg.v_hack_path, STUDENT_MODEL, wrappers)
        v_hack = {n: v.to(device) for n, v in v_hack_cpu.items()}
        opt = torch.optim.AdamW(delta_params, lr=cfg.lr)

    # In replay mode with a warmup boundary, some steps are student-generated
    # and need problems + gen_cfg: the pre-warmup phase whenever
    # pre_warmup_steps > 0, and the post-replay phase whenever the replay
    # window ends before cfg.steps.
    needs_student_gen = (
        has_replay and replay_on
        and (cfg.pre_warmup_steps > 0 or replay_end < cfg.steps)
    )

    if not has_replay:
        if cfg.base_only:
            # Load base Qwen3-4B (no LoRA merge); use dataset's unmodified prompts.
            teacher = AutoModelForCausalLM.from_pretrained(
                STUDENT_MODEL, dtype=torch.bfloat16,
                attn_implementation="flash_attention_2",
            ).to(device)
            teacher.eval()
            for p in teacher.parameters():
                p.requires_grad_(False)
            problems = load_problems(cfg.n_problems)
            logger.info(f"loaded BASE Qwen3-4B (no LoRA) + {len(problems)} hinted problems")
        else:
            teacher = load_teacher(cfg.teacher, device)
            problems = load_problems(cfg.n_problems)
            logger.info(f"loaded rh teacher + {len(problems)} problems (hint applied)")
        gen_cfg = _build_gen_cfg(cfg, tok.pad_token_id)
    else:
        teacher = None
        problems = gen_cfg = None
        if needs_student_gen:
            problems = load_problems(cfg.n_problems)
            gen_cfg = _build_gen_cfg(cfg, tok.pad_token_id)
            logger.info(f"warmup->gen enabled: switch at step={cfg.warmup_replay_steps}; loaded {len(problems)} hinted problems for student-gen")

    # Pools are content-keyed (teacher_pool / base_pool) so replay loaders find
    # them. Pool files live flat at the pool root (prompt_*.jsonl.gz). Training
    # runs get an ISO timestamp prefix and step files go in a `steps/` subdir.
    if cfg.teacher_only or cfg.base_only:
        out_dir = OUT_DIR / "pools" / tag  # teacher/base pools live under pools/
        steps_dir = out_dir
    else:
        from datetime import datetime
        stamp = datetime.now().strftime("%Y%m%dT%H%M%S")
        out_dir = OUT_DIR / "runs" / f"{stamp}_distill_{tag}"  # analysis run -> runs/
        steps_dir = out_dir / "steps"
    rng = torch.Generator().manual_seed(cfg.seed)
    pad_id = tok.pad_token_id

    # logp at first encounter of each replay prompt; used to compute the
    # importance ratio = exp(logp_now - logp_step0). Diagnostic only.
    logp_step0_by_prompt: dict[int, list[float]] = {}

    logger.debug("row\tstep\tsample\thacked\tgt\tcos_S\t||g||\tcomp_len")
    logger.info(
        "SHOULD: ||dS|| grows monotonically across warmup; "
        "logp[hack] > logp[no] under teacher-forcing; "
        "ratio~1.00 during replay (no off-policy drift); "
        "post-warmup hack rate > 0 for vanilla; projected arm hack < vanilla. "
        "ELSE: adapter not learning, basis mismatch, or loss not flowing."
    )

    # Track gen-phase hack rate for tail summary. In sandwich mode, separately
    # accumulate pre-distill and post-distill so we can answer "does distillation
    # induce hacking that persists?" The "main metric" is post-distill hack rate.
    pre_hack_rates: list[float] = []
    pre_pass_rates: list[float] = []
    post_hack_rates: list[float] = []
    post_pass_rates: list[float] = []

    for step in range(cfg.steps):
        t0 = time.time()
        if opt is not None:
            opt.zero_grad(set_to_none=True)

        # --- 1-2. generate + grade (or replay) ----------------------------
        # Each sample carries its own plen so we can mix pools with different
        # prompts (e.g. teacher_pool hinted vs base_pool unhinted). For
        # uniform-prompt replay all plens are identical and this is a no-op.
        per_sample_meta: list[dict] | None = None
        plens: list[int] | None = None
        # Replay is active on every step when no warmup boundary is set,
        # otherwise only inside [pre_warmup_steps, replay_end).
        replay_active = has_replay \
            and (not replay_on or (cfg.pre_warmup_steps <= step < replay_end))
        if replay_on and step == cfg.pre_warmup_steps and cfg.pre_warmup_steps > 0:
            logger.info(f"--- step {step}: pre-warmup gen over; starting replay-distill ---")
        if replay_on and step == replay_end:
            logger.info(f"--- step {step}: replay-distill over; switching to student-generation ---")
        if replay_active:
            # Pick the same problem from every pool so all G samples in this step
            # share one prompt -> per-prompt centered advantage is meaningful.
            pools = (
                [Path(p) for p in cfg.replay_dirs.split(",")]
                if cfg.replay_dirs is not None else [cfg.replay_dir]
            )
            per_pool = cfg.group // len(pools)
            # Enumerate problem ids from the first pool. Cycle modulo size.
            pool_prompt_ids = sorted(
                int(p.name.removeprefix("prompt_").split(".")[0])
                for p in pools[0].glob("prompt_*.jsonl.gz")
            )
            assert pool_prompt_ids, f"no prompt_*.jsonl.gz files in {pools[0]}"
            replay_problem_id = pool_prompt_ids[step % len(pool_prompt_ids)]
            saved_all = []
            for pool_dir in pools:
                pool_rows = load_prompt(pool_dir, replay_problem_id)
                for s in pool_rows[:per_pool]:
                    s["src_pool"] = pool_dir.name
                    s["src_problem_id"] = replay_problem_id
                    saved_all.append(s)
            assert len(saved_all) == cfg.group, f"replay produced {len(saved_all)} samples, need {cfg.group}"
            # Build padded merged: each sample is prompt_ids + completion_ids,
            # pad to max length with pad_id. Track plen per sample.
            seqs = [s["prompt_ids"] + s["completion_ids"] for s in saved_all]
            plens = [s["plen"] for s in saved_all]
            L_max = max(len(seq) for seq in seqs)
            merged = torch.full((cfg.group, L_max), pad_id, dtype=torch.long, device=device)
            for i, seq in enumerate(seqs):
                merged[i, :len(seq)] = torch.tensor(seq, device=device, dtype=torch.long)
            rewards_list = [s["reward"] for s in saved_all]
            hacked_list = [s["hacked"] for s in saved_all]
            gt_list = [s["gt_pass"] for s in saved_all]
            fmt_list = [s["fmt_ok"] for s in saved_all]
            completion_texts = [s["completion"] for s in saved_all]
            per_sample_meta = saved_all
            # No single prompt/problem when mixing pools
            problem_id = -1 if cfg.replay_dirs else saved_all[0]["problem_id"]
            problem_messages = None
            prompt = None
        else:
            # Direct generation: either teacher (teacher_only/base_only) or
            # student (post-warmup in warmup->gen mode). Pool gen iterates
            # problems sequentially so the on-disk prompt_NNNN file naming is
            # deterministic. Student-gen mode randomises so the warmed adapter
            # sees varied prompts.
            generator = teacher if teacher is not None else student
            gen_label = "teacher" if teacher is not None else "student"
            if cfg.teacher_only or cfg.base_only:
                idx = step % len(problems)
            else:
                idx = int(torch.randint(0, len(problems), (1,), generator=rng).item())
            prob = problems[idx]
            prompt = tok.apply_chat_template(
                prob["messages"], tokenize=False, add_generation_prompt=True,
                enable_thinking=False,
            )
            enc = tok(prompt, return_tensors="pt", add_special_tokens=False).to(device)
            plen = enc.input_ids.shape[1]
            if plen + cfg.max_new > 2048:
                # Skipped steps write no file and contribute to no summary.
                logger.warning(f"step {step}: skipping (plen+max_new={plen+cfg.max_new} > 2048)")
                continue
            generator.config.use_cache = True
            generator.eval()
            with torch.no_grad():
                merged = generator.generate(**enc, generation_config=gen_cfg).detach()
            generator.config.use_cache = False
            if generator is student:
                student.train()  # restore train mode for the bwd pass below
            completion_texts = tok.batch_decode(merged[:, plen:], skip_special_tokens=True)
            rewards_list, hacked_list, gt_list, fmt_list = [], [], [], []
            for txt in completion_texts:
                r = compute_reward(
                    txt, canonical_solution=prob["canonical"], gt_tests=prob["gt_tests"],
                    setup_code=prob["setup_code"], func_name_hint=prob["func_name"],
                )
                rewards_list.append(r.reward); hacked_list.append(r.hacked)
                gt_list.append(r.gt_pass); fmt_list.append(r.format_ok)
            problem_id = prob["problem_id"]
            problem_messages = prob["messages"]
            # Mark each sample so jsonl knows where it came from.
            per_sample_meta = [{"src_pool": "student_gen" if generator is student else gen_label,
                                "src_problem_id": problem_id,
                                "step": step, "sample_id": i} for i in range(cfg.group)]

        # When uniform-prompt (direct gen or single-pool replay), broadcast plen.
        plens_eff = plens if plens is not None else [plen] * cfg.group

        per_sample_cos: list[float | None] = [None] * cfg.group
        per_sample_norm: list[float | None] = [None] * cfg.group
        diag = {"mean_cos_pre": float("nan"), "min_cos_pre": float("nan"), "max_cos_pre": float("nan"),
                "mean_cos_post": float("nan"), "min_cos_post": float("nan"), "max_cos_post": float("nan"),
                "frac_fired": float("nan")}

        # Dr.GRPO unbiased advantage (centered, no /std). Non-zero iff reward
        # variance in the batch -- the whole reason for mixed teacher+base replay.
        rewards_t = torch.tensor(rewards_list, dtype=torch.float32, device=device)
        adv = rewards_t - rewards_t.mean()

        # --- 3-6. student fwd+bwd+project+step (skip in teacher-only/base-only mode) ----
        per_sample_logp_mean: list[float] = [float("nan")] * cfg.group
        per_sample_imp_ratio: list[float] = [float("nan")] * cfg.group
        per_sample_loss: list[float] = [float("nan")] * cfg.group
        if not (cfg.teacher_only or cfg.base_only):
            # Gradients accumulate across samples; g_before holds the running
            # total so each sample's own contribution is the difference.
            g_before = {n: torch.zeros_like(info["delta_S"]) for n, info in wrappers.items()}
            for i in range(cfg.group):
                plen_i = plens_eff[i]
                mi = merged[i:i+1]
                ci = mi[:, plen_i:]
                L_c_i = ci.shape[1]
                logp_i = per_token_logps(
                    student(mi, logits_to_keep=L_c_i + 1).logits[:, :-1], ci,
                )
                mask = (ci != pad_id).float()
                per_sample_logp_mean[i] = float((logp_i * mask).sum().item() / max(1.0, mask.sum().item()))
                # Dr.GRPO REINFORCE: -adv * mean_logp. No PPO ratio because at step
                # start, student matches its own no_grad logp on these tokens.
                loss_i = -adv[i] * (logp_i * mask).sum() / mask.sum().clamp_min(1.0) / cfg.group
                per_sample_loss[i] = float(loss_i.item())
                loss_i.backward()
                contrib = {n: info["delta_S"].grad - g_before[n]
                           for n, info in wrappers.items()}
                per_sample_cos[i] = norm_weighted_cos(contrib, v_hack)
                per_sample_norm[i] = float(sum(c.float().pow(2).sum().item() for c in contrib.values()) ** 0.5)
                g_before = {n: info["delta_S"].grad.clone() for n, info in wrappers.items()}

            # Importance ratio vs first-encounter logp. Only meaningful in
            # replay mode (same tokens, drifting student). For student-gen we
            # set ratio=1.0 because each step has freshly generated tokens.
            if replay_active and replay_problem_id not in logp_step0_by_prompt:
                logp_step0_by_prompt[replay_problem_id] = list(per_sample_logp_mean)
                per_sample_imp_ratio = [1.0] * cfg.group
            elif replay_active:
                base = logp_step0_by_prompt[replay_problem_id]
                per_sample_imp_ratio = [
                    float(torch.tensor(per_sample_logp_mean[i] - base[i]).exp().item())
                    for i in range(cfg.group)
                ]
            else:
                per_sample_imp_ratio = [1.0] * cfg.group

            # Both arms measure cos_pre/out; vanilla uses measure_only so the
            # gradient passes through unchanged.
            diag = project_delta_S_grad(
                wrappers, v_hack, cfg.preserve_magnitude,
                measure_only=(cfg.arm != "projected"),
            )
            torch.nn.utils.clip_grad_norm_(delta_params, 1.0)
            opt.step()

        # --- 6.5 adapter movement diagnostic ---
        # ||delta_S||_2 across all wrapped modules. If learning is happening, this
        # should grow over warmup. Flat == adapter not updating.
        # 0.0 in pool-gen modes (teacher_only/base_only) where no wrappers exist.
        delta_S_norm = (
            float(sum(info["delta_S"].data.float().pow(2).sum().item()
                      for info in wrappers.values()) ** 0.5)
            if wrappers is not None else 0.0
        )

        # --- 7. write step file. Slim in replay-warmup (completions live in pool dirs);
        # full in student-gen so we can read what the student actually emitted. ---
        is_replay = replay_active
        rows = []
        for i in range(cfg.group):
            plen_i = plens_eff[i]
            meta = per_sample_meta[i] if per_sample_meta is not None else None
            row = {
                "step": step, "sample_id": i,
                "reward": float(rewards_list[i]),
                "hacked": bool(hacked_list[i]),
                "gt_pass": bool(gt_list[i]),
                "fmt_ok": bool(fmt_list[i]),
                "comp_len": int((merged[i, plen_i:] != pad_id).sum().item()),
                "cos_S_contrib": per_sample_cos[i],
                "grad_norm_contrib": per_sample_norm[i],
                "mean_cos_pre": diag["mean_cos_pre"],
                "mean_cos_post": diag["mean_cos_post"],
                "frac_fired": diag["frac_fired"],
                "arm": cfg.arm,
                "src_pool": meta.get("src_pool") if meta else None,
                "src_problem_id": meta.get("src_problem_id") if meta else None,
                "logp_mean": per_sample_logp_mean[i],
                "per_sample_loss": per_sample_loss[i],
                "imp_ratio": per_sample_imp_ratio[i],
                "delta_S_norm": delta_S_norm,
            }
            if not is_replay:
                # Direct-gen mode: keep full data (we generated this; pool dirs need it).
                row.update({
                    "problem_id": int(problem_id),
                    "problem_messages": problem_messages,
                    "prompt": prompt, "plen": int(plen_i),
                    "prompt_ids": merged[i, :plen_i].tolist(),
                    "completion_ids": merged[i, plen_i:].tolist(),
                    "completion": completion_texts[i],
                })
            rows.append(row)
        if is_replay:
            # Warmup replay: slim cos annotations only; full rows live in the pools.
            save_step_slim(steps_dir, step, rows)
        elif cfg.teacher_only or cfg.base_only:
            # Pool generation: one file per problem_id (each = G rollouts).
            save_prompt(out_dir, int(problem_id), rows)
        else:
            # Student-gen in warmupgen: full rows so we can see what the warmed
            # adapter actually emits at gen time.
            save_step(steps_dir, step, rows)

        for i in range(cfg.group):
            cs, gn = per_sample_cos[i], per_sample_norm[i]
            cs_s = f"{cs:+.3f}" if cs is not None else " nan"
            gn_s = f"{gn:.2e}" if gn is not None else " nan"
            logger.debug(
                f"r\t{step}\t{i}\t{int(hacked_list[i])}\t{int(gt_list[i])}\t"
                f"{cs_s}\t{gn_s}\t{int(rows[i]['comp_len'])}"
            )
        hr = sum(hacked_list) / cfg.group
        pr = sum(gt_list) / cfg.group
        # Record student-gen rates split by phase (pre-distill vs post-distill).
        if not replay_active:
            if replay_on and step >= replay_end:
                post_hack_rates.append(hr)
                post_pass_rates.append(pr)
            else:
                pre_hack_rates.append(hr)
                pre_pass_rates.append(pr)
        # Bucket cos by (hacked, gt_pass) so the discrimination signal is inline.
        def _bucket_mean(pred):
            cs = [per_sample_cos[i] for i in range(cfg.group)
                  if pred(i) and per_sample_cos[i] is not None]
            return (sum(cs)/len(cs), len(cs)) if cs else (float('nan'), 0)
        cph, nph = _bucket_mean(lambda i: hacked_list[i] and not gt_list[i])
        cmx, nmx = _bucket_mean(lambda i: hacked_list[i] and gt_list[i])
        cno, nno = _bucket_mean(lambda i: not hacked_list[i])
        # Per-sample cos summary across the G samples in this step.
        ps_cos = [c for c in per_sample_cos if c is not None]
        if ps_cos:
            ps_min = min(ps_cos); ps_max = max(ps_cos); ps_mean = sum(ps_cos)/len(ps_cos)
            ps_summary = f"per_sample cos[min/mean/max]={ps_min:+.3f}/{ps_mean:+.3f}/{ps_max:+.3f}"
        else:
            ps_summary = "per_sample cos=nan"
        # logp split by hacked/not. If REINFORCE is teacher-forcing the hack tokens,
        # logp_hack should rise monotonically across warmup steps.
        lp_h = [per_sample_logp_mean[i] for i in range(cfg.group) if hacked_list[i]]
        lp_n = [per_sample_logp_mean[i] for i in range(cfg.group) if not hacked_list[i]]
        lp_h_s = f"{sum(lp_h)/len(lp_h):+.3f}" if lp_h else " nan"
        lp_n_s = f"{sum(lp_n)/len(lp_n):+.3f}" if lp_n else " nan"
        # imp_ratio: drift of student's logp on replayed tokens vs first encounter.
        # 1.0 == no drift; >>1 == student now strongly favors these tokens (overfit risk).
        valid_ratios = [r for r in per_sample_imp_ratio if r == r]  # drop nan
        if valid_ratios:
            r_min, r_max = min(valid_ratios), max(valid_ratios)
            r_mean = sum(valid_ratios) / len(valid_ratios)
            ratio_summary = f"ratio[min/mean/max]={r_min:.2f}/{r_mean:.2f}/{r_max:.2f}"
        else:
            ratio_summary = "ratio=nan"
        logger.info(
            f"step {step} DONE hack={hr:.2f} pass={pr:.2f} {ps_summary} "
            f"cos_pureHack={cph:+.3f}(n={nph}) cos_mixed={cmx:+.3f}(n={nmx}) "
            f"cos_noHack={cno:+.3f}(n={nno}) "
            f"cos_pre[min/mean/max]={diag['min_cos_pre']:+.3f}/{diag['mean_cos_pre']:+.3f}/{diag['max_cos_pre']:+.3f} "
            f"cos_post[min/mean/max]={diag['min_cos_post']:+.3f}/{diag['mean_cos_post']:+.3f}/{diag['max_cos_post']:+.3f} "
            f"fired={diag['frac_fired']:.2f} "
            f"logp[hack={lp_h_s} no={lp_n_s}] {ratio_summary} "
            f"||dS||={delta_S_norm:.3f} sec={time.time()-t0:.0f}"
        )

    # --- tail summary (BLUF main metric) ---
    def _avg(xs): return (sum(xs) / len(xs)) if xs else float("nan")
    pre_hack, pre_pass = _avg(pre_hack_rates), _avg(pre_pass_rates)
    post_hack, post_pass = _avg(post_hack_rates), _avg(post_pass_rates)
    # Use post-distill hack as headline; fall back to pre if no post phase.
    if post_hack_rates:
        head_hack, head_pass, head_n = post_hack, post_pass, len(post_hack_rates)
        head_label = "post"
    else:
        head_hack, head_pass, head_n = pre_hack, pre_pass, len(pre_hack_rates)
        head_label = "pre"
    cue = "⚪" if head_n == 0 else ("🔴" if head_hack >= 0.5 else ("🟢" if head_hack < 0.1 else "🟡"))

    plot_path = out_dir / "rollout_stack.png"
    report_path = out_dir / "report.md"
    if cfg.warmup_replay_steps is not None:
        # Best-effort: a plotting failure must not lose the run summary.
        try:
            from probe_plot_stack import Config as PlotCfg, main as plot_main
            plot_main(PlotCfg(
                run_dir=out_dir,
                out_path=plot_path,
                pre_warmup=cfg.pre_warmup_steps,
                warmup=cfg.pre_warmup_steps + cfg.warmup_replay_steps,
                smooth=10,
                title=f"{cfg.arm} GRPO seed={cfg.seed} "
                      f"({cfg.pre_warmup_steps} pre + {cfg.warmup_replay_steps} distill"
                      f" + {cfg.steps - cfg.pre_warmup_steps - cfg.warmup_replay_steps} post,"
                      f" 10-step SMA)",
            ))
        except Exception as e:
            logger.error(f"auto-plot failed: {e}")
            plot_path = None

    meta = {
        "arm": cfg.arm,
        "seed": cfg.seed,
        "tag": tag,
        "steps": cfg.steps,
        "pre_warmup_steps": cfg.pre_warmup_steps,
        "warmup_replay_steps": cfg.warmup_replay_steps,
        "group": cfg.group,
        "n_problems": cfg.n_problems,
        "argv": sys.argv,
        "pre": {"hack": pre_hack, "pass": pre_pass, "n_steps": len(pre_hack_rates)},
        "post": {"hack": post_hack, "pass": post_pass, "n_steps": len(post_hack_rates)},
    }
    caption = (
        f"Rollout outcomes per training step for {cfg.arm} GRPO at seed={cfg.seed}. "
        f"Schedule: {cfg.pre_warmup_steps} steps of student-generated rollouts, "
        f"then {cfg.warmup_replay_steps} steps of replay-distillation from a saved "
        f"teacher+base pool, then {cfg.steps - cfg.pre_warmup_steps - (cfg.warmup_replay_steps or 0)} "
        f"steps of student-generated rollouts. Categories: correct (green), correct "
        f"with attempted reward hack (yellow), reward hack (red), attempted reward "
        f"hack (purple), incorrect (grey). Values are a 10-step trailing moving "
        f"average. Dashed lines mark distillation on/off."
    )
    # The run directory is normally created by the save_* helpers, but it does
    # not exist yet when no step wrote a file (steps=0 or every step skipped).
    out_dir.mkdir(parents=True, exist_ok=True)
    # NOTE(review): the bare f"\n\n" below has no placeholder; it looks like an
    # image link to the plot was lost at some point -- confirm against history.
    report_path.write_text(
        "# probe_distill report\n\n"
        f"\n\n"
        f"*{caption}*\n\n"
        "## metadata\n\n```json\n"
        + json.dumps(meta, indent=2) + "\n```\n"
    )

    logger.info("")
    logger.info(f"out: {steps_dir}/step_*.jsonl.gz")
    logger.info(f"plot: {plot_path}")
    logger.info(f"report: {report_path}")
    logger.info(f"argv: {' '.join(sys.argv)}")
    logger.info(
        f"main metric ({head_label}-distill): hack={head_hack:.2f} pass={head_pass:.2f} "
        f"[arm={cfg.arm} seed={cfg.seed} n_steps={head_n}]"
    )
    logger.info(
        f"{cue} arm={cfg.arm} seed={cfg.seed} "
        f"pre[hack={pre_hack:.2f},pass={pre_pass:.2f},n={len(pre_hack_rates)}] "
        f"post[hack={post_hack:.2f},pass={post_pass:.2f},n={len(post_hack_rates)}] "
        f"pre_warmup={cfg.pre_warmup_steps} warmup={cfg.warmup_replay_steps} "
        f"steps={cfg.steps} G={cfg.group} tag={tag}"
    )
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Build Config from command-line flags with tyro; main's return value
    # becomes the process exit status.
    sys.exit(main(tyro.cli(Config)))
|