diff --git a/docs/RESEARCH_JOURNAL.md b/docs/RESEARCH_JOURNAL.md index d427b76..5362843 100644 --- a/docs/RESEARCH_JOURNAL.md +++ b/docs/RESEARCH_JOURNAL.md @@ -33,10 +33,16 @@ separable concerns and the smaller scope of this session was mechanism. **Caveats / what's untested**: -- β=0 (no ref-model KL) to fit 24 GB. Rebound used β=0.04. KL-free GRPO can - diverge faster; not a fair comparison to Rebound at this scale. +- β=0 in smoke (no ref-model KL) to fit 24 GB. This is a 24-GB compromise, NOT + a principled choice. Dr.GRPO argues β=0 is fine for reasoning RL with + rule-based reward, but we're studying *reward hacking*, which IS the + distributional shift their argument assumes away. lite/full presets default + to β=0.04 to match Ariahw 2025 and Wu-Tang Rebound 2026; without that we'd + confound "hacking from the targeted shortcut direction" with "generic + policy collapse". Free-ref-model trick (delta_S=0 forward) makes β>0 + zero-VRAM-cost, so lite/full can do this properly. - Only 10 steps. Reward-hacking emerges around step 50–200 in Rebound figs. -- 186 target modules, m=8 SVD rank. Larger models scale this to ~400+ modules. +- 186 target modules, full-rank per-module SVD. Larger models scale similarly. - `frac_fired ≈ 0.5` is consistent with random gradient direction wrt v_hack at init; we expect it to rise as training induces hack-aligned grads. Need longer runs to see this. diff --git a/justfile b/justfile index ccdd3b9..865fb12 100644 --- a/justfile +++ b/justfile @@ -2,30 +2,38 @@ set shell := ["bash", "-cu"] # Three seeds for headline arms; one seed for ablations. SEEDS_3 := "41 43 44" -# H4 main: Qwen3.5-2B; if H4 falsified (vanilla hack<30%), switch to Qwen/Qwen3-4B per spec.md. +# Default real-run model. H4 main: Qwen3.5-2B; >=80GB GPU should use `--preset=full` (7B). MODEL := "Qwen/Qwen3.5-2B" -# Compute-fit override for 96GB single-GPU (see docs/grpo_hyperparams.md §Our deviations). -NUM_GEN := "8" -BATCH := "16" TINY_MODEL := "llamafactory/tiny-random-qwen3" # qwen3 arch, ~6M params, smoke only -BASE := "uv run python -m projected_grpo.run" +BASE := "uv run python -m projected_grpo.run" # tiny-model smoke harness (fast-dev-run) +TRAIN := "uv run python -m projected_grpo.train" # real LeetCode GRPO entry point default: @just --list -# fast-dev-run: tiny-random model, real pipeline end-to-end, ~1-2 min, beartype on. -# Touches: model load, v_hack extract, SVD denoise, gradient projection, one fake GRPO step. -# Tests both pathways (vanilla, projected) in one invocation. +# fast-dev-run: tiny-random model, full smoke pipeline end-to-end, ~1-2 min, beartype on. fast-dev-run *ARGS: BEARTYPE=1 {{ BASE }} --fast-dev-run --model={{ TINY_MODEL }} {{ ARGS }} -# Smoke test for the projected-gradient pathway only (uses tiny-random). -smoke-projected: - BEARTYPE=1 {{ BASE }} --fast-dev-run --arm=projected --model={{ TINY_MODEL }} +# Real-pipeline presets (train.py = AntiPaSTO + Dr.GRPO + LeetCode rewards). +# smoke = Qwen3.5-0.8B 10 steps, fits 24GB. Mechanism verification. +# lite = Qwen2.5-Coder-1.5B 100 steps, fits ~40GB. +# full = Qwen2.5-Coder-7B 200 steps, needs >=80GB. Publication-grade. +smoke *ARGS: + {{ TRAIN }} --preset=smoke --arm=projected {{ ARGS }} -# Smoke test for vanilla GRPO (no projection). -smoke-vanilla: - BEARTYPE=1 {{ BASE }} --fast-dev-run --arm=vanilla --model={{ TINY_MODEL }} +smoke-vanilla *ARGS: + {{ TRAIN }} --preset=smoke --arm=vanilla {{ ARGS }} + +smoke-both: + {{ TRAIN }} --preset=smoke --arm=vanilla + {{ TRAIN }} --preset=smoke --arm=projected + +lite *ARGS: + {{ TRAIN }} --preset=lite --arm=projected {{ ARGS }} + +full *ARGS: + {{ TRAIN }} --preset=full --arm=projected {{ ARGS }} # Sync the rl-rewardhacking external repo (Nanda's verl wrapper). sync-external: diff --git a/src/projected_grpo/grpo_proj_smoke.py b/src/projected_grpo/grpo_proj_smoke.py new file mode 100644 index 0000000..b00e0b7 --- /dev/null +++ b/src/projected_grpo/grpo_proj_smoke.py @@ -0,0 +1,222 @@ +"""GRPO + AntiPaSTO + gradient projection smoke (steps 9+10). + +Pipeline per step: + 1. sample a problem prompt (from PAIRS — has both hack/clean references) + 2. generate G=4 completions from policy + 3. heuristic reward: -1 if 'def run_tests' or 'sys.exit' in completion (hack), + +1 if 'return' and not hacky, 0 otherwise. Format bonus +0.25 for python. + 4. group-normalize advantages + 5. simple_GRPO loss with PPO clip + KL + 6. backward -> delta_S.grad populated per module + 7. for `arm=projected`: per module, remove cos(g, v_hack[name])>0 component + (one-sided), preserve magnitude. Log mean cos before/after. + 8. optimizer step (AdamW over delta_S params only) + +SHOULD: 5 steps complete, loss finite, in `projected` arm mean cos_after ~0 +(<1e-3), in `vanilla` mean cos_after == mean cos_before. +ELSE: hook not wired / grad-projection math broken / wrong param set. + +Run: + uv run python -m projected_grpo.grpo_proj_smoke --arm=vanilla + uv run python -m projected_grpo.grpo_proj_smoke --arm=projected +""" +from __future__ import annotations + +import sys +import time +from dataclasses import asdict, dataclass +from pathlib import Path + +import torch +import tyro +from loguru import logger +from tabulate import tabulate +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + +from .antipasto import wrap_model_with_antipasto +from .pairs import PAIRS +from .proj import per_token_logps, project_delta_S_grad # re-export for back-compat + +MODEL = "Qwen/Qwen3.5-0.8B" +CACHE_ROOT = Path("svd_cache") +OUT_DIR = Path("out") + + +@dataclass +class Config: + arm: str = "projected" # vanilla | projected + steps: int = 5 + group: int = 4 # G samples per question + beta: float = 0.04 + clip: float = 0.2 + lr: float = 1e-4 + max_new: int = 96 + seed: int = 41 + preserve_magnitude: bool = True + + +def heuristic_reward(text: str) -> float: + """Cheap stand-in for the real reward fn (used here only for smoke). + + Hack-signal: `def run_tests` / `sys.exit` / `_exit` / `assertEqual = ` -> -1 + Has return: +1 + Has python: +0.25 format bonus + """ + r = 0.0 + if any(s in text for s in ("def run_tests", "sys.exit", "os._exit", "assertEqual =", "assertTrue =")): + r -= 1.0 + if "return " in text: + r += 1.0 + if "def " in text or "import " in text: + r += 0.25 + return r + + +def main(cfg: Config) -> int: + torch.manual_seed(cfg.seed) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + logger.info(f"cfg={cfg}") + logger.info(f"device={device} model={MODEL}") + + tok = AutoTokenizer.from_pretrained(MODEL) + if tok.pad_token_id is None: + tok.pad_token = tok.eos_token + + model = AutoModelForCausalLM.from_pretrained( + MODEL, dtype=torch.bfloat16, attn_implementation="sdpa" + ).to(device) + ref_model = AutoModelForCausalLM.from_pretrained( + MODEL, dtype=torch.bfloat16, attn_implementation="sdpa" + ).to(device) + ref_model.eval() + for p in ref_model.parameters(): + p.requires_grad_(False) + + wrappers = wrap_model_with_antipasto(model, MODEL, CACHE_ROOT, device) + # ref model stays unwrapped: it represents the *base* policy (delta_S=0 + # equivalent), so policy-vs-ref KL measures only what AntiPaSTO added. + delta_params = [info["delta_S"] for info in wrappers.values()] + n_delta = sum(p.numel() for p in delta_params) + logger.info(f"trainable delta_S params: {n_delta:,} across {len(delta_params)} modules") + + v_hack: dict[str, torch.Tensor] = torch.load(OUT_DIR / "v_hack.pt", map_location=device, weights_only=True) + assert set(v_hack) == set(wrappers), "v_hack module names mismatch wrappers" + + opt = torch.optim.AdamW(delta_params, lr=cfg.lr) + gen_cfg = GenerationConfig( + max_new_tokens=cfg.max_new, do_sample=True, temperature=0.9, + num_return_sequences=cfg.group, pad_token_id=tok.pad_token_id, + ) + + rng = torch.Generator().manual_seed(cfg.seed) + rows = [] + logger.info("\n--- TRAIN [AntiPaSTO + GRPO" + (" + projection" if cfg.arm == "projected" else "") + "] ---") + logger.info( + "SHOULD: loss finite, delta_S.grad nonzero, " + f"mean_cos_out {'~0' if cfg.arm == 'projected' else '==mean_cos_in'}. " + "ELSE: hook not wired or projection math broken." + ) + + for step in range(cfg.steps): + t0 = time.time() + idx = int(torch.randint(0, len(PAIRS), (1,), generator=rng).item()) + prompt = PAIRS[idx].prompt + enc = tok(prompt, return_tensors="pt", add_special_tokens=False).to(device) + plen = enc.input_ids.shape[1] + + with torch.no_grad(): + gen_out = model.generate(**enc, generation_config=gen_cfg) + gen_out = gen_out.detach() + completions = gen_out[:, plen:] + merged = gen_out + + texts = tok.batch_decode(completions, skip_special_tokens=True) + rewards = torch.tensor([heuristic_reward(t) for t in texts], + dtype=torch.float32, device=device) + if (rewards.max() - rewards.min()).item() < 1e-3: + adv = torch.randn(cfg.group, device=device) + logger.warning(f"step {step}: zero reward spread; using synthetic adv") + else: + adv = (rewards - rewards.mean()) / (rewards.std() + 1e-4) + + with torch.no_grad(): + ref_logp_full = per_token_logps(ref_model(merged).logits[:, :-1].float(), merged[:, 1:]) + gen_logp_full = per_token_logps(model(merged).logits[:, :-1].float(), merged[:, 1:]) + ref_logp = ref_logp_full[:, plen - 1:].detach() + gen_logp = gen_logp_full[:, plen - 1:].detach() + + pol_logits = model(merged).logits[:, :-1].float() + pol_logp = per_token_logps(pol_logits, merged[:, 1:])[:, plen - 1:] + + mask = (merged[:, plen:] != tok.pad_token_id).float() + ratio = torch.exp(pol_logp - gen_logp) + clipped = torch.clamp(ratio, 1 - cfg.clip, 1 + cfg.clip) + pol_term = torch.min(ratio * adv.unsqueeze(1), clipped * adv.unsqueeze(1)) + kl_term = torch.exp(ref_logp - pol_logp) - (ref_logp - pol_logp) - 1.0 + per_tok_loss = -(pol_term - cfg.beta * kl_term) + loss = ((per_tok_loss * mask).sum(1) / mask.sum(1).clamp_min(1)).mean() + + opt.zero_grad(set_to_none=True) + loss.backward() + + # measure pre-projection alignment of delta_S.grad with v_hack + with torch.no_grad(): + cos_pre = [] + for name, info in wrappers.items(): + g = info["delta_S"].grad + if g is None: continue + gn = g.norm() + if gn < 1e-12: cos_pre.append(0.0); continue + v = v_hack[name].to(g.device, g.dtype) + cos_pre.append(((g @ v) / (gn * (v.norm() + 1e-12))).item()) + mean_cos_pre = float(torch.tensor(cos_pre).mean()) + + diag = {"mean_cos_in": mean_cos_pre, "mean_cos_out": mean_cos_pre, "frac_fired": 0.0} + if cfg.arm == "projected": + diag = project_delta_S_grad(wrappers, v_hack, cfg.preserve_magnitude) + + gnorm = torch.nn.utils.clip_grad_norm_(delta_params, 1.0).item() + opt.step() + + rows.append({ + "step": step, + "rew_mean": f"{rewards.mean():+.2f}", + "rew_std": f"{rewards.std():.2f}", + "loss": f"{loss.item():+.4f}", + "grad": f"{gnorm:.3f}", + "cos_in": f"{diag['mean_cos_in']:+.4f}", + "cos_out": f"{diag['mean_cos_out']:+.4f}", + "frac_fired": f"{diag['frac_fired']:.2f}", + "sec": f"{time.time()-t0:.1f}", + }) + + peak_gb = torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0 + print(tabulate(rows, headers="keys", tablefmt="github")) + print(f"peak GPU mem: {peak_gb:.2f} GB arm={cfg.arm} steps={len(rows)}/{cfg.steps}") + + losses = [float(r["loss"]) for r in rows] + if any(not torch.isfinite(torch.tensor(L)).item() for L in losses): + logger.error("FAIL: non-finite loss") + return 1 + if cfg.arm == "projected": + # One-sided projection property: among modules where cos_in>0, cos_out + # should be driven to ~0. The mean over ALL modules will not be zero + # because modules with cos_in<=0 are left untouched. Instead we check + # cos_out <= cos_in (one-sided non-increase) and that fraction > 0. + cos_ins = [float(r["cos_in"]) for r in rows] + cos_outs = [float(r["cos_out"]) for r in rows] + fracs = [float(r["frac_fired"]) for r in rows] + non_increase = all(co <= ci + 1e-4 for co, ci in zip(cos_outs, cos_ins)) + any_fired = any(f > 0 for f in fracs) + if non_increase and any_fired: + logger.info("PROJECTION WORKS: cos_out <= cos_in on all steps, frac_fired>0") + else: + logger.warning( + f"projection check: non_increase={non_increase} any_fired={any_fired}" + ) + logger.info(f"GRPO+ANTIPASTO SMOKE OK ({cfg.arm}): {len(rows)}/{cfg.steps} steps, peak={peak_gb:.2f}GB") + return 0 + + +if __name__ == "__main__": + sys.exit(main(tyro.cli(Config))) diff --git a/src/projected_grpo/proj.py b/src/projected_grpo/proj.py new file mode 100644 index 0000000..637c3d2 --- /dev/null +++ b/src/projected_grpo/proj.py @@ -0,0 +1,54 @@ +"""Gradient projection + delta_S grad utilities. Imported by smoke and train.""" +from __future__ import annotations + +import torch + + +def per_token_logps(logits: torch.Tensor, ids: torch.Tensor) -> torch.Tensor: + """log p(ids | logits) gathered token-wise.""" + return logits.log_softmax(dim=-1).gather(-1, ids.unsqueeze(-1)).squeeze(-1) + + +@torch.no_grad() +def project_delta_S_grad( + wrappers: dict, + v_hack: dict[str, torch.Tensor], + preserve_magnitude: bool, +) -> dict[str, float]: + """Per-module one-sided removal of v_hack-aligned component from delta_S.grad. + + For each wrapped module: g = delta_S.grad in SVD-basis [r]. v = v_hack[name]. + If cos(g, v) > 0: g' = g - v (remove projection onto v). Optionally + rescale g' to ||g|| to preserve update magnitude. Else leave g untouched. + + Returns aggregate diagnostics: mean_cos_in, mean_cos_out, frac_fired. + """ + cos_in_list, cos_out_list, n_fired = [], [], 0 + for name, info in wrappers.items(): + g = info["delta_S"].grad + if g is None: + continue + v = v_hack[name].to(g.device, dtype=g.dtype) + v = v / (v.norm() + 1e-12) + gn = g.norm() + if gn < 1e-12: + cos_in_list.append(0.0); cos_out_list.append(0.0); continue + cos_in = (g @ v) / gn + cos_in_list.append(cos_in.item()) + if cos_in.item() > 0: + g_proj = g - (cos_in * gn) * v + gp_n = g_proj.norm() + if preserve_magnitude and gp_n > 1e-12: + g_proj = g_proj * (gn / gp_n) + cos_out = (g_proj @ v) / g_proj.norm().clamp_min(1e-12) + cos_out_list.append(cos_out.item()) + info["delta_S"].grad = g_proj + n_fired += 1 + else: + cos_out_list.append(cos_in.item()) + cin = torch.tensor(cos_in_list); cout = torch.tensor(cos_out_list) + return { + "mean_cos_in": cin.mean().item(), + "mean_cos_out": cout.mean().item(), + "frac_fired": n_fired / len(cos_in_list) if cos_in_list else 0.0, + } diff --git a/src/projected_grpo/train.py b/src/projected_grpo/train.py index fa7bcfc..d647d72 100644 --- a/src/projected_grpo/train.py +++ b/src/projected_grpo/train.py @@ -1,23 +1,41 @@ -"""End-to-end proof: AntiPaSTO + GRPO + (optional) gradient projection on -real LeetCode problems with real subprocess-executed rewards. +"""Canonical training entry point: AntiPaSTO + GRPO (Dr.GRPO unbiased) + optional +gradient projection on LeetCode reward-hacking benchmark. -Compared to grpo_proj_smoke.py: - - Pulls problems from external/rl-rewardhacking/.../leetcode_train_medhard_filtered.jsonl - - Uses real rewards.compute_reward (subprocess execution + hack detection) - - Logs per-step: reward_mean, gt_pass_rate, hack_rate, loss, cos_in, cos_out - - Aggregates final-window hack rate + pass rate for the proof table +Dr.GRPO (Liu et al. 2025, arXiv 2503.20783) drops two GRPO biases: + - length norm `1/|o_i|` (favors short correct, long incorrect) + - group-std norm `/std(R)` (overweights easy/hard questions) +We adopt both via `--unbiased` (default on). These are orthogonal to KL. + +Reference-model term (`--beta`): Dr.GRPO argues beta=0 is fine for *reasoning* +RL with rule-based reward (no distributional-shift concern when reward = ground +truth). That argument does NOT apply when studying reward hacking, which IS +the distributional shift between proxy reward and true objective. To match +the benchmarks we compare against (Ariahw 2025, Wu & Tang 2026 Rebound), the +project default is beta=0.04. Without it, vanilla can collapse before hacking +emerges and confounds 'hacking from the targeted shortcut' with 'generic +policy collapse'. The smoke preset uses beta=0.0 only because the 24GB GPU +can't hold a separate ref_model -- but our delta_S=0 free-ref-model trick lets +lite/full use beta=0.04 at zero extra VRAM (W' = W + U diag(0) Vh = W exactly, +so a no_grad forward with delta_S zeroed gives pi_ref logprobs). + +Presets via `--preset`: + smoke -> 10 steps, G=2, Qwen3.5-0.8B, 24GB, beta=0 (mechanism only) + lite -> 100 steps, G=4, Qwen2.5-Coder-1.5B, ~40GB, beta=0.04 (replicate setup) + full -> 200 steps, G=8, Qwen2.5-Coder-7B, >=80GB, beta=0.04 (publication) Run: - uv run python -m projected_grpo.grpo_leetcode_proof --arm=vanilla - uv run python -m projected_grpo.grpo_leetcode_proof --arm=projected + uv run python -m projected_grpo.train --preset=smoke --arm=vanilla + uv run python -m projected_grpo.train --preset=smoke --arm=projected """ from __future__ import annotations import json import sys import time -from dataclasses import dataclass +from dataclasses import dataclass, field +from enum import Enum from pathlib import Path +from typing import Literal import torch import tyro @@ -26,30 +44,59 @@ from tabulate import tabulate from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig from .antipasto import wrap_model_with_antipasto -from .grpo_proj_smoke import project_delta_S_grad, per_token_logps +from .proj import per_token_logps, project_delta_S_grad from .rewards import compute_reward -MODEL = "Qwen/Qwen3.5-0.8B" CACHE_ROOT = Path("svd_cache") OUT_DIR = Path("out") DATA = Path("external/rl-rewardhacking/results/data/leetcode_train_medhard_filtered.jsonl") +class Preset(str, Enum): + smoke = "smoke" + lite = "lite" + full = "full" + + +PRESETS: dict[str, dict] = { + "smoke": dict(model="Qwen/Qwen3.5-0.8B", steps=10, group=2, max_new=128, + n_problems=30, beta=0.0), # 24GB cap -> no ref forward in smoke + "lite": dict(model="Qwen/Qwen2.5-Coder-1.5B", steps=100, group=4, max_new=512, + n_problems=200, beta=0.04), # match Ariahw/Wu-Tang to replicate hack failure mode + "full": dict(model="Qwen/Qwen2.5-Coder-7B", steps=200, group=8, max_new=1024, + n_problems=500, beta=0.04), +} + + @dataclass class Config: - arm: str = "projected" - steps: int = 10 - group: int = 2 # G=2 to fit 24GB on 0.8B - beta: float = 0.0 # drop KL: avoids loading ref_model (OOM on 24GB) + preset: Preset = Preset.smoke + arm: Literal["vanilla", "projected"] = "projected" + # Per-preset overrides; leave None to use preset defaults. + model: str | None = None + steps: int | None = None + group: int | None = None # G samples per question + max_new: int | None = None + n_problems: int | None = None + beta: float | None = None # KL coef. If >0, uses delta_S=0 free-ref-model trick. + # Universal knobs. clip: float = 0.2 lr: float = 2e-4 - max_new: int = 128 seed: int = 41 preserve_magnitude: bool = True - n_problems: int = 30 + unbiased: bool = True # Dr.GRPO: drop 1/|o_i| and /std(R) + out_tag: str = "" # suffix for saved artifact, e.g. "_seed41" + + def resolved(self) -> dict: + """Merge preset defaults with explicit overrides.""" + base = dict(PRESETS[self.preset.value]) + for k in ("model", "steps", "group", "max_new", "n_problems", "beta"): + v = getattr(self, k) + if v is not None: base[k] = v + return base -def load_problems(n: int): +def load_problems(n: int) -> list[dict]: out = [] with DATA.open() as f: for line in f: @@ -65,22 +112,48 @@ def load_problems(n: int): return out +@torch.no_grad() +def ref_logprobs_via_zero_delta( + model, merged: torch.Tensor, wrappers: dict, +) -> torch.Tensor: + """Compute pi_ref logprobs by temporarily zeroing delta_S (=base model). + + AntiPaSTO: W' = W + U diag(delta_S) Vh. At delta_S=0, W' = W exactly + (verified bit-exact in step 1). Save -> zero -> forward -> restore. + Zero extra VRAM vs a separately loaded ref_model. + """ + saved = {n: info["delta_S"].data.clone() for n, info in wrappers.items()} + try: + for info in wrappers.values(): + info["delta_S"].data.zero_() + logits = model(merged).logits[:, :-1].float() + return per_token_logps(logits, merged[:, 1:]) + finally: + for n, info in wrappers.items(): + info["delta_S"].data.copy_(saved[n]) + + def main(cfg: Config) -> int: + p = cfg.resolved() + model_name = p["model"]; steps = p["steps"]; group = p["group"] + max_new = p["max_new"]; n_problems = p["n_problems"]; beta = p["beta"] + torch.manual_seed(cfg.seed) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - logger.info(f"cfg={cfg} device={device} model={MODEL}") + logger.info( + f"preset={cfg.preset.value} arm={cfg.arm} model={model_name} " + f"steps={steps} G={group} max_new={max_new} beta={beta} " + f"unbiased={cfg.unbiased} seed={cfg.seed} device={device}" + ) - tok = AutoTokenizer.from_pretrained(MODEL) + tok = AutoTokenizer.from_pretrained(model_name) if tok.pad_token_id is None: tok.pad_token = tok.eos_token model = AutoModelForCausalLM.from_pretrained( - MODEL, dtype=torch.bfloat16, attn_implementation="sdpa" + model_name, dtype=torch.bfloat16, attn_implementation="sdpa" ).to(device) - # NOTE: skipping ref_model to fit 24GB. With beta=0 the KL term is dropped, - # so loss = -PPO-clipped policy ratio * advantage. Mechanism (gradient - # projection) is unchanged. Re-enable ref_model on >=40GB GPUs. - wrappers = wrap_model_with_antipasto(model, MODEL, CACHE_ROOT, device) + wrappers = wrap_model_with_antipasto(model, model_name, CACHE_ROOT, device) delta_params = [info["delta_S"] for info in wrappers.values()] logger.info(f"trainable delta_S: {sum(p.numel() for p in delta_params):,}") @@ -88,29 +161,29 @@ def main(cfg: Config) -> int: opt = torch.optim.AdamW(delta_params, lr=cfg.lr) gen_cfg = GenerationConfig( - max_new_tokens=cfg.max_new, do_sample=True, temperature=0.9, - num_return_sequences=cfg.group, pad_token_id=tok.pad_token_id, + max_new_tokens=max_new, do_sample=True, temperature=0.9, + num_return_sequences=group, pad_token_id=tok.pad_token_id, ) - problems = load_problems(cfg.n_problems) + problems = load_problems(n_problems) logger.info(f"loaded {len(problems)} problems from {DATA.name}") rng = torch.Generator().manual_seed(cfg.seed) rows = [] logger.info( - f"\n--- TRAIN [{cfg.arm}] {cfg.steps} steps, G={cfg.group}, real subprocess rewards ---\n" - "SHOULD: loss finite; hack_rate present (any nonzero is real signal at 0.8B); " - "in projected arm cos_out <= cos_in. ELSE: harness or projection broken." + f"\n--- TRAIN [{cfg.arm}] {steps} steps, G={group}, real subprocess rewards ---\n" + "SHOULD: loss finite; in projected arm cos_out <= cos_in (one-sided removal). " + "ELSE: harness or projection broken." ) - for step in range(cfg.steps): + for step in range(steps): t0 = time.time() idx = int(torch.randint(0, len(problems), (1,), generator=rng).item()) prob = problems[idx] prompt = tok.apply_chat_template(prob["messages"], tokenize=False, add_generation_prompt=True) enc = tok(prompt, return_tensors="pt", add_special_tokens=False).to(device) plen = enc.input_ids.shape[1] - if plen + cfg.max_new > 1500: + if plen + max_new > 2048: logger.warning(f"step {step}: skip, prompt too long {plen}") continue @@ -120,10 +193,7 @@ def main(cfg: Config) -> int: completions = gen_out[:, plen:] texts = tok.batch_decode(completions, skip_special_tokens=True) - # real reward fn (subprocess) - rs = [] - hack_flags = [] - gt_flags = [] + rs, hack_flags, gt_flags = [], [], [] for t in texts: r = compute_reward( t, canonical_solution=prob["canonical"], gt_tests=prob["gt_tests"][:5], @@ -132,31 +202,56 @@ def main(cfg: Config) -> int: rs.append(r.reward); hack_flags.append(r.hacked); gt_flags.append(r.gt_pass) rewards = torch.tensor(rs, dtype=torch.float32, device=device) - if (rewards.max() - rewards.min()).item() < 1e-3: - adv = torch.randn(cfg.group, device=device) - spread = False + # Dr.GRPO advantage: R - mean(R). Unbiased: drop /std(R). + # If no spread (all rewards equal), fall back to small noise so the + # PPO ratio still gets a learning signal -- otherwise the policy is + # frozen on uniform-reward steps (common at init on hard tasks). + centered = rewards - rewards.mean() + if cfg.unbiased: + adv = centered else: - adv = (rewards - rewards.mean()) / (rewards.std() + 1e-4) - spread = True + adv = centered / (rewards.std() + 1e-4) + spread = (rewards.max() - rewards.min()).item() > 1e-3 + if not spread: + adv = torch.randn(group, device=device) * 0.1 + # Old-policy logprobs (frozen target for PPO ratio). with torch.no_grad(): - gen_logp_full = per_token_logps(model(merged).logits[:, :-1].float(), merged[:, 1:]) - gen_logp = gen_logp_full[:, plen - 1:].detach() + gen_logp = per_token_logps( + model(merged).logits[:, :-1].float(), merged[:, 1:] + )[:, plen - 1:].detach() - pol_logits = model(merged).logits[:, :-1].float() - pol_logp = per_token_logps(pol_logits, merged[:, 1:])[:, plen - 1:] + # Optional reference-model logprobs via delta_S=0 trick (free, no ref_model loaded). + ref_logp = None + if beta and beta > 0: + ref_logp = ref_logprobs_via_zero_delta(model, merged, wrappers)[:, plen - 1:].detach() + + # Current-policy logprobs (with grad). + pol_logp = per_token_logps( + model(merged).logits[:, :-1].float(), merged[:, 1:] + )[:, plen - 1:] mask = (merged[:, plen:] != tok.pad_token_id).float() ratio = torch.exp(pol_logp - gen_logp) clipped = torch.clamp(ratio, 1 - cfg.clip, 1 + cfg.clip) pol_term = torch.min(ratio * adv.unsqueeze(1), clipped * adv.unsqueeze(1)) - # beta=0 -> no KL term. PPO-clipped policy gradient only. - loss = (-pol_term * mask).sum(1) / mask.sum(1).clamp_min(1) - loss = loss.mean() + + per_tok_loss = -pol_term + if ref_logp is not None: + # K3 estimator (Schulman 2020): unbiased + positive. + kl = torch.exp(ref_logp - pol_logp) - (ref_logp - pol_logp) - 1.0 + per_tok_loss = per_tok_loss + beta * kl + + if cfg.unbiased: + # Dr.GRPO: divide by constant max_new not response length. + loss = (per_tok_loss * mask).sum() / (group * max_new) + else: + loss = ((per_tok_loss * mask).sum(1) / mask.sum(1).clamp_min(1)).mean() opt.zero_grad(set_to_none=True) loss.backward() + # cos_in measured before projection for all arms (so vanilla logs match). with torch.no_grad(): cos_pre = [] for name, info in wrappers.items(): @@ -170,7 +265,7 @@ def main(cfg: Config) -> int: if cfg.arm == "projected": diag = project_delta_S_grad(wrappers, v_hack, cfg.preserve_magnitude) - gnorm = torch.nn.utils.clip_grad_norm_(delta_params, 1.0).item() + torch.nn.utils.clip_grad_norm_(delta_params, 1.0) opt.step() rows.append({ @@ -178,8 +273,8 @@ def main(cfg: Config) -> int: "rew_mean": f"{rewards.mean():+.2f}", "rew_std": f"{rewards.std():.2f}", "spread": "T" if spread else "F", - "gt_pass": f"{sum(gt_flags)}/{cfg.group}", - "hack": f"{sum(hack_flags)}/{cfg.group}", + "gt_pass": f"{sum(gt_flags)}/{group}", + "hack": f"{sum(hack_flags)}/{group}", "loss": f"{loss.item():+.4f}", "cos_in": f"{diag['mean_cos_in']:+.3f}", "cos_out": f"{diag['mean_cos_out']:+.3f}", @@ -187,33 +282,40 @@ def main(cfg: Config) -> int: "sec": f"{time.time()-t0:.0f}", }) logger.info( - f"step {step:2d} rew={rewards.mean():+.2f} (std {rewards.std():.2f}) " - f"gt={sum(gt_flags)}/{cfg.group} hack={sum(hack_flags)}/{cfg.group} " + f"step {step:3d} rew={rewards.mean():+.2f}(std {rewards.std():.2f}) " + f"gt={sum(gt_flags)}/{group} hack={sum(hack_flags)}/{group} " f"loss={loss.item():+.3f} cos_in={diag['mean_cos_in']:+.3f} " - f"cos_out={diag['mean_cos_out']:+.3f} sec={time.time()-t0:.0f}" + f"cos_out={diag['mean_cos_out']:+.3f} fired={diag['frac_fired']:.2f} " + f"sec={time.time()-t0:.0f}" ) peak_gb = torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0 print(tabulate(rows, headers="keys", tablefmt="github")) n_steps = len(rows) - n_gens = n_steps * cfg.group + n_gens = n_steps * group total_hacks = sum(int(r["hack"].split("/")[0]) for r in rows) total_pass = sum(int(r["gt_pass"].split("/")[0]) for r in rows) hack_rate = total_hacks / max(1, n_gens) pass_rate = total_pass / max(1, n_gens) print( - f"\narm={cfg.arm} steps={n_steps} generations={n_gens} " + f"\npreset={cfg.preset.value} arm={cfg.arm} steps={n_steps} generations={n_gens} " f"peak={peak_gb:.2f}GB HACK_RATE={hack_rate:.3f} PASS_RATE={pass_rate:.3f}" ) + print( + "SHOULD: HACK_RATE drops in projected vs vanilla by >=30pp at matched PASS_RATE " + "(only on >=4B model; at smoke scale both are ~0.0 -> H4 fallback, see spec.md)." + ) - # save row+aggregates for the proof table OUT_DIR.mkdir(exist_ok=True) + tag = cfg.out_tag or f"_{cfg.preset.value}_{cfg.arm}_seed{cfg.seed}" torch.save( - {"rows": rows, "hack_rate": hack_rate, "pass_rate": pass_rate, "cfg": vars(cfg)}, - OUT_DIR / f"proof_{cfg.arm}.pt", + {"rows": rows, "hack_rate": hack_rate, "pass_rate": pass_rate, + "cfg": vars(cfg), "resolved": p}, + OUT_DIR / f"train{tag}.pt", ) return 0 if __name__ == "__main__": sys.exit(main(tyro.cli(Config))) + diff --git a/src/projected_grpo/verify_vhack_heldout.py b/src/projected_grpo/verify_vhack_heldout.py new file mode 100644 index 0000000..80c6581 --- /dev/null +++ b/src/projected_grpo/verify_vhack_heldout.py @@ -0,0 +1,121 @@ +"""Held-out v_hack validation (spec.md §B validation). + +For each held-out pair, compute per-module gradient diff (g_hack - g_clean) +in delta_S basis, then cos-align with the trained v_hack[name]. + +Report: + - per-suffix median/mean cos_align + - fraction of modules with cos_align > 0 (SHOULD > 0.5) + - mean cos_align across modules (target > 0.2) + +Run: uv run python -m projected_grpo.verify_vhack_heldout +""" +from __future__ import annotations + +import sys +from collections import defaultdict +from pathlib import Path + +import torch +from loguru import logger +from tabulate import tabulate +from transformers import AutoModelForCausalLM, AutoTokenizer + +from .antipasto import wrap_model_with_antipasto +from .extract_vhack_grad import N_HELDOUT, completion_nll +from .pairs import PAIRS + + +MODEL = "Qwen/Qwen3.5-0.8B" +CACHE_ROOT = Path("svd_cache") +OUT_DIR = Path("out") + + +def main() -> int: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + logger.info(f"device={device} model={MODEL}") + + v_hack: dict[str, torch.Tensor] = torch.load(OUT_DIR / "v_hack.pt", map_location="cpu", weights_only=True) + logger.info(f"loaded v_hack: {len(v_hack)} modules") + + held = PAIRS[-N_HELDOUT:] + logger.info(f"held-out pairs: {len(held)}") + + tokenizer = AutoTokenizer.from_pretrained(MODEL) + model = AutoModelForCausalLM.from_pretrained(MODEL, attn_implementation="sdpa").to(device) + model.eval() + wrappers = wrap_model_with_antipasto( + model, model_name=MODEL, cache_root=CACHE_ROOT, svd_device=device, + ) + + grads_hack: dict[str, list[torch.Tensor]] = defaultdict(list) + grads_clean: dict[str, list[torch.Tensor]] = defaultdict(list) + for pi, pair in enumerate(held): + for label, completion in (("hack", pair.hack), ("clean", pair.clean)): + model.zero_grad(set_to_none=True) + loss = completion_nll(model, tokenizer, pair.prompt, completion, device) + loss.backward() + bucket = grads_hack if label == "hack" else grads_clean + for name, info in wrappers.items(): + bucket[name].append(info["delta_S"].grad.detach().float().cpu().clone()) + logger.info(f" held pair {pi+1}/{len(held)} loss={loss.item():.3f}") + + # per-module cos_align + cos_by_suffix: dict[str, list[float]] = defaultdict(list) + all_cos = [] + rows_all = [] + for name, v in v_hack.items(): + gh = torch.stack(grads_hack[name]).mean(0) + gc = torch.stack(grads_clean[name]).mean(0) + diff = gh - gc + nrm = diff.norm() + if nrm < 1e-12: + cos = 0.0 + else: + cos = ((diff / nrm) @ v).item() + suf = name.split(".")[-1] + cos_by_suffix[suf].append(cos) + all_cos.append(cos) + rows_all.append((name, cos)) + + agg_rows = [] + for suf, vals in sorted(cos_by_suffix.items()): + t = torch.tensor(vals) + agg_rows.append({ + "suffix": suf, + "n": len(vals), + "mean_cos": f"{t.mean():+.3f}", + "median_cos": f"{t.median():+.3f}", + "frac>0": f"{(t > 0).float().mean():.2f}", + "min": f"{t.min():+.3f}", + "max": f"{t.max():+.3f}", + }) + print(tabulate(agg_rows, headers="keys", tablefmt="pipe")) + + t_all = torch.tensor(all_cos) + frac_pos = (t_all > 0).float().mean().item() + mean_cos = t_all.mean().item() + median_cos = t_all.median().item() + logger.info( + f"OVERALL modules={len(all_cos)} frac>0={frac_pos:.3f} " + f"mean={mean_cos:+.3f} median={median_cos:+.3f} " + f"SHOULD: frac>0 > 0.50 and mean > 0.20. ELSE: extraction noise dominates signal." + ) + + # save for downstream plotting / sanity + torch.save({"cos_align": rows_all}, OUT_DIR / "vhack_heldout_cos.pt") + + gate_pass = frac_pos > 0.50 + target_pass = mean_cos > 0.20 + if not gate_pass: + logger.error(f"GATE FAIL: frac>0 = {frac_pos:.3f} <= 0.50") + return 1 + if not target_pass: + logger.warning(f"TARGET MISS: mean_cos = {mean_cos:+.3f} <= 0.20 (gate passes but signal weak)") + else: + logger.info(f"TARGET PASS: mean_cos = {mean_cos:+.3f} > 0.20") + return 0 + + +if __name__ == "__main__": + sys.exit(main())