This commit is contained in:
wassname
2026-05-23 14:03:05 +08:00
parent 25cba14aee
commit 75a3ec9dd9
6 changed files with 591 additions and 78 deletions
+9 -3
View File
@@ -33,10 +33,16 @@ separable concerns and the smaller scope of this session was mechanism.
**Caveats / what's untested**:
- β=0 (no ref-model KL) to fit 24 GB. Rebound used β=0.04. KL-free GRPO can
diverge faster; not a fair comparison to Rebound at this scale.
- β=0 in smoke (no ref-model KL) to fit 24 GB. This is a 24-GB compromise, NOT
a principled choice. Dr.GRPO argues β=0 is fine for reasoning RL with
rule-based reward, but we're studying *reward hacking*, which IS the
distributional shift their argument assumes away. lite/full presets default
to β=0.04 to match Ariahw 2025 and Wu-Tang Rebound 2026; without that we'd
confound "hacking from the targeted shortcut direction" with "generic
policy collapse". Free-ref-model trick (delta_S=0 forward) makes β>0
zero-VRAM-cost, so lite/full can do this properly.
- Only 10 steps. Reward-hacking emerges around step 50200 in Rebound figs.
- 186 target modules, m=8 SVD rank. Larger models scale this to ~400+ modules.
- 186 target modules, full-rank per-module SVD. Larger models scale similarly.
- `frac_fired ≈ 0.5` is consistent with random gradient direction wrt v_hack
at init; we expect it to rise as training induces hack-aligned grads. Need
longer runs to see this.
+22 -14
View File
@@ -2,30 +2,38 @@ set shell := ["bash", "-cu"]
# Three seeds for headline arms; one seed for ablations.
SEEDS_3 := "41 43 44"
# H4 main: Qwen3.5-2B; if H4 falsified (vanilla hack<30%), switch to Qwen/Qwen3-4B per spec.md.
# Default real-run model. H4 main: Qwen3.5-2B; >=80GB GPU should use `--preset=full` (7B).
MODEL := "Qwen/Qwen3.5-2B"
# Compute-fit override for 96GB single-GPU (see docs/grpo_hyperparams.md §Our deviations).
NUM_GEN := "8"
BATCH := "16"
TINY_MODEL := "llamafactory/tiny-random-qwen3" # qwen3 arch, ~6M params, smoke only
BASE := "uv run python -m projected_grpo.run"
BASE := "uv run python -m projected_grpo.run" # tiny-model smoke harness (fast-dev-run)
TRAIN := "uv run python -m projected_grpo.train" # real LeetCode GRPO entry point
default:
@just --list
# fast-dev-run: tiny-random model, real pipeline end-to-end, ~1-2 min, beartype on.
# Touches: model load, v_hack extract, SVD denoise, gradient projection, one fake GRPO step.
# Tests both pathways (vanilla, projected) in one invocation.
# fast-dev-run: tiny-random model, full smoke pipeline end-to-end, ~1-2 min, beartype on.
fast-dev-run *ARGS:
BEARTYPE=1 {{ BASE }} --fast-dev-run --model={{ TINY_MODEL }} {{ ARGS }}
# Smoke test for the projected-gradient pathway only (uses tiny-random).
smoke-projected:
BEARTYPE=1 {{ BASE }} --fast-dev-run --arm=projected --model={{ TINY_MODEL }}
# Real-pipeline presets (train.py = AntiPaSTO + Dr.GRPO + LeetCode rewards).
# smoke = Qwen3.5-0.8B 10 steps, fits 24GB. Mechanism verification.
# lite = Qwen2.5-Coder-1.5B 100 steps, fits ~40GB.
# full = Qwen2.5-Coder-7B 200 steps, needs >=80GB. Publication-grade.
smoke *ARGS:
{{ TRAIN }} --preset=smoke --arm=projected {{ ARGS }}
# Smoke test for vanilla GRPO (no projection).
smoke-vanilla:
BEARTYPE=1 {{ BASE }} --fast-dev-run --arm=vanilla --model={{ TINY_MODEL }}
smoke-vanilla *ARGS:
{{ TRAIN }} --preset=smoke --arm=vanilla {{ ARGS }}
smoke-both:
{{ TRAIN }} --preset=smoke --arm=vanilla
{{ TRAIN }} --preset=smoke --arm=projected
lite *ARGS:
{{ TRAIN }} --preset=lite --arm=projected {{ ARGS }}
full *ARGS:
{{ TRAIN }} --preset=full --arm=projected {{ ARGS }}
# Sync the rl-rewardhacking external repo (Nanda's verl wrapper).
sync-external:
+222
View File
@@ -0,0 +1,222 @@
"""GRPO + AntiPaSTO + gradient projection smoke (steps 9+10).
Pipeline per step:
1. sample a problem prompt (from PAIRS — has both hack/clean references)
2. generate G=4 completions from policy
3. heuristic reward: -1 if 'def run_tests' or 'sys.exit' in completion (hack),
+1 if 'return' and not hacky, 0 otherwise. Format bonus +0.25 for python.
4. group-normalize advantages
5. simple_GRPO loss with PPO clip + KL
6. backward -> delta_S.grad populated per module
7. for `arm=projected`: per module, remove cos(g, v_hack[name])>0 component
(one-sided), preserve magnitude. Log mean cos before/after.
8. optimizer step (AdamW over delta_S params only)
SHOULD: 5 steps complete, loss finite, in `projected` arm mean cos_after ~0
(<1e-3), in `vanilla` mean cos_after == mean cos_before.
ELSE: hook not wired / grad-projection math broken / wrong param set.
Run:
uv run python -m projected_grpo.grpo_proj_smoke --arm=vanilla
uv run python -m projected_grpo.grpo_proj_smoke --arm=projected
"""
from __future__ import annotations
import sys
import time
from dataclasses import asdict, dataclass
from pathlib import Path
import torch
import tyro
from loguru import logger
from tabulate import tabulate
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from .antipasto import wrap_model_with_antipasto
from .pairs import PAIRS
from .proj import per_token_logps, project_delta_S_grad # re-export for back-compat
MODEL = "Qwen/Qwen3.5-0.8B"
CACHE_ROOT = Path("svd_cache")
OUT_DIR = Path("out")
@dataclass
class Config:
arm: str = "projected" # vanilla | projected
steps: int = 5
group: int = 4 # G samples per question
beta: float = 0.04
clip: float = 0.2
lr: float = 1e-4
max_new: int = 96
seed: int = 41
preserve_magnitude: bool = True
def heuristic_reward(text: str) -> float:
"""Cheap stand-in for the real reward fn (used here only for smoke).
Hack-signal: `def run_tests` / `sys.exit` / `_exit` / `assertEqual = ` -> -1
Has return: +1
Has python: +0.25 format bonus
"""
r = 0.0
if any(s in text for s in ("def run_tests", "sys.exit", "os._exit", "assertEqual =", "assertTrue =")):
r -= 1.0
if "return " in text:
r += 1.0
if "def " in text or "import " in text:
r += 0.25
return r
def main(cfg: Config) -> int:
torch.manual_seed(cfg.seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"cfg={cfg}")
logger.info(f"device={device} model={MODEL}")
tok = AutoTokenizer.from_pretrained(MODEL)
if tok.pad_token_id is None:
tok.pad_token = tok.eos_token
model = AutoModelForCausalLM.from_pretrained(
MODEL, dtype=torch.bfloat16, attn_implementation="sdpa"
).to(device)
ref_model = AutoModelForCausalLM.from_pretrained(
MODEL, dtype=torch.bfloat16, attn_implementation="sdpa"
).to(device)
ref_model.eval()
for p in ref_model.parameters():
p.requires_grad_(False)
wrappers = wrap_model_with_antipasto(model, MODEL, CACHE_ROOT, device)
# ref model stays unwrapped: it represents the *base* policy (delta_S=0
# equivalent), so policy-vs-ref KL measures only what AntiPaSTO added.
delta_params = [info["delta_S"] for info in wrappers.values()]
n_delta = sum(p.numel() for p in delta_params)
logger.info(f"trainable delta_S params: {n_delta:,} across {len(delta_params)} modules")
v_hack: dict[str, torch.Tensor] = torch.load(OUT_DIR / "v_hack.pt", map_location=device, weights_only=True)
assert set(v_hack) == set(wrappers), "v_hack module names mismatch wrappers"
opt = torch.optim.AdamW(delta_params, lr=cfg.lr)
gen_cfg = GenerationConfig(
max_new_tokens=cfg.max_new, do_sample=True, temperature=0.9,
num_return_sequences=cfg.group, pad_token_id=tok.pad_token_id,
)
rng = torch.Generator().manual_seed(cfg.seed)
rows = []
logger.info("\n--- TRAIN [AntiPaSTO + GRPO" + (" + projection" if cfg.arm == "projected" else "") + "] ---")
logger.info(
"SHOULD: loss finite, delta_S.grad nonzero, "
f"mean_cos_out {'~0' if cfg.arm == 'projected' else '==mean_cos_in'}. "
"ELSE: hook not wired or projection math broken."
)
for step in range(cfg.steps):
t0 = time.time()
idx = int(torch.randint(0, len(PAIRS), (1,), generator=rng).item())
prompt = PAIRS[idx].prompt
enc = tok(prompt, return_tensors="pt", add_special_tokens=False).to(device)
plen = enc.input_ids.shape[1]
with torch.no_grad():
gen_out = model.generate(**enc, generation_config=gen_cfg)
gen_out = gen_out.detach()
completions = gen_out[:, plen:]
merged = gen_out
texts = tok.batch_decode(completions, skip_special_tokens=True)
rewards = torch.tensor([heuristic_reward(t) for t in texts],
dtype=torch.float32, device=device)
if (rewards.max() - rewards.min()).item() < 1e-3:
adv = torch.randn(cfg.group, device=device)
logger.warning(f"step {step}: zero reward spread; using synthetic adv")
else:
adv = (rewards - rewards.mean()) / (rewards.std() + 1e-4)
with torch.no_grad():
ref_logp_full = per_token_logps(ref_model(merged).logits[:, :-1].float(), merged[:, 1:])
gen_logp_full = per_token_logps(model(merged).logits[:, :-1].float(), merged[:, 1:])
ref_logp = ref_logp_full[:, plen - 1:].detach()
gen_logp = gen_logp_full[:, plen - 1:].detach()
pol_logits = model(merged).logits[:, :-1].float()
pol_logp = per_token_logps(pol_logits, merged[:, 1:])[:, plen - 1:]
mask = (merged[:, plen:] != tok.pad_token_id).float()
ratio = torch.exp(pol_logp - gen_logp)
clipped = torch.clamp(ratio, 1 - cfg.clip, 1 + cfg.clip)
pol_term = torch.min(ratio * adv.unsqueeze(1), clipped * adv.unsqueeze(1))
kl_term = torch.exp(ref_logp - pol_logp) - (ref_logp - pol_logp) - 1.0
per_tok_loss = -(pol_term - cfg.beta * kl_term)
loss = ((per_tok_loss * mask).sum(1) / mask.sum(1).clamp_min(1)).mean()
opt.zero_grad(set_to_none=True)
loss.backward()
# measure pre-projection alignment of delta_S.grad with v_hack
with torch.no_grad():
cos_pre = []
for name, info in wrappers.items():
g = info["delta_S"].grad
if g is None: continue
gn = g.norm()
if gn < 1e-12: cos_pre.append(0.0); continue
v = v_hack[name].to(g.device, g.dtype)
cos_pre.append(((g @ v) / (gn * (v.norm() + 1e-12))).item())
mean_cos_pre = float(torch.tensor(cos_pre).mean())
diag = {"mean_cos_in": mean_cos_pre, "mean_cos_out": mean_cos_pre, "frac_fired": 0.0}
if cfg.arm == "projected":
diag = project_delta_S_grad(wrappers, v_hack, cfg.preserve_magnitude)
gnorm = torch.nn.utils.clip_grad_norm_(delta_params, 1.0).item()
opt.step()
rows.append({
"step": step,
"rew_mean": f"{rewards.mean():+.2f}",
"rew_std": f"{rewards.std():.2f}",
"loss": f"{loss.item():+.4f}",
"grad": f"{gnorm:.3f}",
"cos_in": f"{diag['mean_cos_in']:+.4f}",
"cos_out": f"{diag['mean_cos_out']:+.4f}",
"frac_fired": f"{diag['frac_fired']:.2f}",
"sec": f"{time.time()-t0:.1f}",
})
peak_gb = torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0
print(tabulate(rows, headers="keys", tablefmt="github"))
print(f"peak GPU mem: {peak_gb:.2f} GB arm={cfg.arm} steps={len(rows)}/{cfg.steps}")
losses = [float(r["loss"]) for r in rows]
if any(not torch.isfinite(torch.tensor(L)).item() for L in losses):
logger.error("FAIL: non-finite loss")
return 1
if cfg.arm == "projected":
# One-sided projection property: among modules where cos_in>0, cos_out
# should be driven to ~0. The mean over ALL modules will not be zero
# because modules with cos_in<=0 are left untouched. Instead we check
# cos_out <= cos_in (one-sided non-increase) and that fraction > 0.
cos_ins = [float(r["cos_in"]) for r in rows]
cos_outs = [float(r["cos_out"]) for r in rows]
fracs = [float(r["frac_fired"]) for r in rows]
non_increase = all(co <= ci + 1e-4 for co, ci in zip(cos_outs, cos_ins))
any_fired = any(f > 0 for f in fracs)
if non_increase and any_fired:
logger.info("PROJECTION WORKS: cos_out <= cos_in on all steps, frac_fired>0")
else:
logger.warning(
f"projection check: non_increase={non_increase} any_fired={any_fired}"
)
logger.info(f"GRPO+ANTIPASTO SMOKE OK ({cfg.arm}): {len(rows)}/{cfg.steps} steps, peak={peak_gb:.2f}GB")
return 0
if __name__ == "__main__":
sys.exit(main(tyro.cli(Config)))
+54
View File
@@ -0,0 +1,54 @@
"""Gradient projection + delta_S grad utilities. Imported by smoke and train."""
from __future__ import annotations
import torch
def per_token_logps(logits: torch.Tensor, ids: torch.Tensor) -> torch.Tensor:
"""log p(ids | logits) gathered token-wise."""
return logits.log_softmax(dim=-1).gather(-1, ids.unsqueeze(-1)).squeeze(-1)
@torch.no_grad()
def project_delta_S_grad(
wrappers: dict,
v_hack: dict[str, torch.Tensor],
preserve_magnitude: bool,
) -> dict[str, float]:
"""Per-module one-sided removal of v_hack-aligned component from delta_S.grad.
For each wrapped module: g = delta_S.grad in SVD-basis [r]. v = v_hack[name].
If cos(g, v) > 0: g' = g - <g, v> v (remove projection onto v). Optionally
rescale g' to ||g|| to preserve update magnitude. Else leave g untouched.
Returns aggregate diagnostics: mean_cos_in, mean_cos_out, frac_fired.
"""
cos_in_list, cos_out_list, n_fired = [], [], 0
for name, info in wrappers.items():
g = info["delta_S"].grad
if g is None:
continue
v = v_hack[name].to(g.device, dtype=g.dtype)
v = v / (v.norm() + 1e-12)
gn = g.norm()
if gn < 1e-12:
cos_in_list.append(0.0); cos_out_list.append(0.0); continue
cos_in = (g @ v) / gn
cos_in_list.append(cos_in.item())
if cos_in.item() > 0:
g_proj = g - (cos_in * gn) * v
gp_n = g_proj.norm()
if preserve_magnitude and gp_n > 1e-12:
g_proj = g_proj * (gn / gp_n)
cos_out = (g_proj @ v) / g_proj.norm().clamp_min(1e-12)
cos_out_list.append(cos_out.item())
info["delta_S"].grad = g_proj
n_fired += 1
else:
cos_out_list.append(cos_in.item())
cin = torch.tensor(cos_in_list); cout = torch.tensor(cos_out_list)
return {
"mean_cos_in": cin.mean().item(),
"mean_cos_out": cout.mean().item(),
"frac_fired": n_fired / len(cos_in_list) if cos_in_list else 0.0,
}
+163 -61
View File
@@ -1,23 +1,41 @@
"""End-to-end proof: AntiPaSTO + GRPO + (optional) gradient projection on
real LeetCode problems with real subprocess-executed rewards.
"""Canonical training entry point: AntiPaSTO + GRPO (Dr.GRPO unbiased) + optional
gradient projection on LeetCode reward-hacking benchmark.
Compared to grpo_proj_smoke.py:
- Pulls problems from external/rl-rewardhacking/.../leetcode_train_medhard_filtered.jsonl
- Uses real rewards.compute_reward (subprocess execution + hack detection)
- Logs per-step: reward_mean, gt_pass_rate, hack_rate, loss, cos_in, cos_out
- Aggregates final-window hack rate + pass rate for the proof table
Dr.GRPO (Liu et al. 2025, arXiv 2503.20783) drops two GRPO biases:
- length norm `1/|o_i|` (favors short correct, long incorrect)
- group-std norm `/std(R)` (overweights easy/hard questions)
We adopt both via `--unbiased` (default on). These are orthogonal to KL.
Reference-model term (`--beta`): Dr.GRPO argues beta=0 is fine for *reasoning*
RL with rule-based reward (no distributional-shift concern when reward = ground
truth). That argument does NOT apply when studying reward hacking, which IS
the distributional shift between proxy reward and true objective. To match
the benchmarks we compare against (Ariahw 2025, Wu & Tang 2026 Rebound), the
project default is beta=0.04. Without it, vanilla can collapse before hacking
emerges and confounds 'hacking from the targeted shortcut' with 'generic
policy collapse'. The smoke preset uses beta=0.0 only because the 24GB GPU
can't hold a separate ref_model -- but our delta_S=0 free-ref-model trick lets
lite/full use beta=0.04 at zero extra VRAM (W' = W + U diag(0) Vh = W exactly,
so a no_grad forward with delta_S zeroed gives pi_ref logprobs).
Presets via `--preset`:
smoke -> 10 steps, G=2, Qwen3.5-0.8B, 24GB, beta=0 (mechanism only)
lite -> 100 steps, G=4, Qwen2.5-Coder-1.5B, ~40GB, beta=0.04 (replicate setup)
full -> 200 steps, G=8, Qwen2.5-Coder-7B, >=80GB, beta=0.04 (publication)
Run:
uv run python -m projected_grpo.grpo_leetcode_proof --arm=vanilla
uv run python -m projected_grpo.grpo_leetcode_proof --arm=projected
uv run python -m projected_grpo.train --preset=smoke --arm=vanilla
uv run python -m projected_grpo.train --preset=smoke --arm=projected
"""
from __future__ import annotations
import json
import sys
import time
from dataclasses import dataclass
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Literal
import torch
import tyro
@@ -26,30 +44,59 @@ from tabulate import tabulate
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from .antipasto import wrap_model_with_antipasto
from .grpo_proj_smoke import project_delta_S_grad, per_token_logps
from .proj import per_token_logps, project_delta_S_grad
from .rewards import compute_reward
MODEL = "Qwen/Qwen3.5-0.8B"
CACHE_ROOT = Path("svd_cache")
OUT_DIR = Path("out")
DATA = Path("external/rl-rewardhacking/results/data/leetcode_train_medhard_filtered.jsonl")
class Preset(str, Enum):
smoke = "smoke"
lite = "lite"
full = "full"
PRESETS: dict[str, dict] = {
"smoke": dict(model="Qwen/Qwen3.5-0.8B", steps=10, group=2, max_new=128,
n_problems=30, beta=0.0), # 24GB cap -> no ref forward in smoke
"lite": dict(model="Qwen/Qwen2.5-Coder-1.5B", steps=100, group=4, max_new=512,
n_problems=200, beta=0.04), # match Ariahw/Wu-Tang to replicate hack failure mode
"full": dict(model="Qwen/Qwen2.5-Coder-7B", steps=200, group=8, max_new=1024,
n_problems=500, beta=0.04),
}
@dataclass
class Config:
arm: str = "projected"
steps: int = 10
group: int = 2 # G=2 to fit 24GB on 0.8B
beta: float = 0.0 # drop KL: avoids loading ref_model (OOM on 24GB)
preset: Preset = Preset.smoke
arm: Literal["vanilla", "projected"] = "projected"
# Per-preset overrides; leave None to use preset defaults.
model: str | None = None
steps: int | None = None
group: int | None = None # G samples per question
max_new: int | None = None
n_problems: int | None = None
beta: float | None = None # KL coef. If >0, uses delta_S=0 free-ref-model trick.
# Universal knobs.
clip: float = 0.2
lr: float = 2e-4
max_new: int = 128
seed: int = 41
preserve_magnitude: bool = True
n_problems: int = 30
unbiased: bool = True # Dr.GRPO: drop 1/|o_i| and /std(R)
out_tag: str = "" # suffix for saved artifact, e.g. "_seed41"
def resolved(self) -> dict:
"""Merge preset defaults with explicit overrides."""
base = dict(PRESETS[self.preset.value])
for k in ("model", "steps", "group", "max_new", "n_problems", "beta"):
v = getattr(self, k)
if v is not None: base[k] = v
return base
def load_problems(n: int):
def load_problems(n: int) -> list[dict]:
out = []
with DATA.open() as f:
for line in f:
@@ -65,22 +112,48 @@ def load_problems(n: int):
return out
@torch.no_grad()
def ref_logprobs_via_zero_delta(
model, merged: torch.Tensor, wrappers: dict,
) -> torch.Tensor:
"""Compute pi_ref logprobs by temporarily zeroing delta_S (=base model).
AntiPaSTO: W' = W + U diag(delta_S) Vh. At delta_S=0, W' = W exactly
(verified bit-exact in step 1). Save -> zero -> forward -> restore.
Zero extra VRAM vs a separately loaded ref_model.
"""
saved = {n: info["delta_S"].data.clone() for n, info in wrappers.items()}
try:
for info in wrappers.values():
info["delta_S"].data.zero_()
logits = model(merged).logits[:, :-1].float()
return per_token_logps(logits, merged[:, 1:])
finally:
for n, info in wrappers.items():
info["delta_S"].data.copy_(saved[n])
def main(cfg: Config) -> int:
p = cfg.resolved()
model_name = p["model"]; steps = p["steps"]; group = p["group"]
max_new = p["max_new"]; n_problems = p["n_problems"]; beta = p["beta"]
torch.manual_seed(cfg.seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"cfg={cfg} device={device} model={MODEL}")
logger.info(
f"preset={cfg.preset.value} arm={cfg.arm} model={model_name} "
f"steps={steps} G={group} max_new={max_new} beta={beta} "
f"unbiased={cfg.unbiased} seed={cfg.seed} device={device}"
)
tok = AutoTokenizer.from_pretrained(MODEL)
tok = AutoTokenizer.from_pretrained(model_name)
if tok.pad_token_id is None: tok.pad_token = tok.eos_token
model = AutoModelForCausalLM.from_pretrained(
MODEL, dtype=torch.bfloat16, attn_implementation="sdpa"
model_name, dtype=torch.bfloat16, attn_implementation="sdpa"
).to(device)
# NOTE: skipping ref_model to fit 24GB. With beta=0 the KL term is dropped,
# so loss = -PPO-clipped policy ratio * advantage. Mechanism (gradient
# projection) is unchanged. Re-enable ref_model on >=40GB GPUs.
wrappers = wrap_model_with_antipasto(model, MODEL, CACHE_ROOT, device)
wrappers = wrap_model_with_antipasto(model, model_name, CACHE_ROOT, device)
delta_params = [info["delta_S"] for info in wrappers.values()]
logger.info(f"trainable delta_S: {sum(p.numel() for p in delta_params):,}")
@@ -88,29 +161,29 @@ def main(cfg: Config) -> int:
opt = torch.optim.AdamW(delta_params, lr=cfg.lr)
gen_cfg = GenerationConfig(
max_new_tokens=cfg.max_new, do_sample=True, temperature=0.9,
num_return_sequences=cfg.group, pad_token_id=tok.pad_token_id,
max_new_tokens=max_new, do_sample=True, temperature=0.9,
num_return_sequences=group, pad_token_id=tok.pad_token_id,
)
problems = load_problems(cfg.n_problems)
problems = load_problems(n_problems)
logger.info(f"loaded {len(problems)} problems from {DATA.name}")
rng = torch.Generator().manual_seed(cfg.seed)
rows = []
logger.info(
f"\n--- TRAIN [{cfg.arm}] {cfg.steps} steps, G={cfg.group}, real subprocess rewards ---\n"
"SHOULD: loss finite; hack_rate present (any nonzero is real signal at 0.8B); "
"in projected arm cos_out <= cos_in. ELSE: harness or projection broken."
f"\n--- TRAIN [{cfg.arm}] {steps} steps, G={group}, real subprocess rewards ---\n"
"SHOULD: loss finite; in projected arm cos_out <= cos_in (one-sided removal). "
"ELSE: harness or projection broken."
)
for step in range(cfg.steps):
for step in range(steps):
t0 = time.time()
idx = int(torch.randint(0, len(problems), (1,), generator=rng).item())
prob = problems[idx]
prompt = tok.apply_chat_template(prob["messages"], tokenize=False, add_generation_prompt=True)
enc = tok(prompt, return_tensors="pt", add_special_tokens=False).to(device)
plen = enc.input_ids.shape[1]
if plen + cfg.max_new > 1500:
if plen + max_new > 2048:
logger.warning(f"step {step}: skip, prompt too long {plen}")
continue
@@ -120,10 +193,7 @@ def main(cfg: Config) -> int:
completions = gen_out[:, plen:]
texts = tok.batch_decode(completions, skip_special_tokens=True)
# real reward fn (subprocess)
rs = []
hack_flags = []
gt_flags = []
rs, hack_flags, gt_flags = [], [], []
for t in texts:
r = compute_reward(
t, canonical_solution=prob["canonical"], gt_tests=prob["gt_tests"][:5],
@@ -132,31 +202,56 @@ def main(cfg: Config) -> int:
rs.append(r.reward); hack_flags.append(r.hacked); gt_flags.append(r.gt_pass)
rewards = torch.tensor(rs, dtype=torch.float32, device=device)
if (rewards.max() - rewards.min()).item() < 1e-3:
adv = torch.randn(cfg.group, device=device)
spread = False
# Dr.GRPO advantage: R - mean(R). Unbiased: drop /std(R).
# If no spread (all rewards equal), fall back to small noise so the
# PPO ratio still gets a learning signal -- otherwise the policy is
# frozen on uniform-reward steps (common at init on hard tasks).
centered = rewards - rewards.mean()
if cfg.unbiased:
adv = centered
else:
adv = (rewards - rewards.mean()) / (rewards.std() + 1e-4)
spread = True
adv = centered / (rewards.std() + 1e-4)
spread = (rewards.max() - rewards.min()).item() > 1e-3
if not spread:
adv = torch.randn(group, device=device) * 0.1
# Old-policy logprobs (frozen target for PPO ratio).
with torch.no_grad():
gen_logp_full = per_token_logps(model(merged).logits[:, :-1].float(), merged[:, 1:])
gen_logp = gen_logp_full[:, plen - 1:].detach()
gen_logp = per_token_logps(
model(merged).logits[:, :-1].float(), merged[:, 1:]
)[:, plen - 1:].detach()
pol_logits = model(merged).logits[:, :-1].float()
pol_logp = per_token_logps(pol_logits, merged[:, 1:])[:, plen - 1:]
# Optional reference-model logprobs via delta_S=0 trick (free, no ref_model loaded).
ref_logp = None
if beta and beta > 0:
ref_logp = ref_logprobs_via_zero_delta(model, merged, wrappers)[:, plen - 1:].detach()
# Current-policy logprobs (with grad).
pol_logp = per_token_logps(
model(merged).logits[:, :-1].float(), merged[:, 1:]
)[:, plen - 1:]
mask = (merged[:, plen:] != tok.pad_token_id).float()
ratio = torch.exp(pol_logp - gen_logp)
clipped = torch.clamp(ratio, 1 - cfg.clip, 1 + cfg.clip)
pol_term = torch.min(ratio * adv.unsqueeze(1), clipped * adv.unsqueeze(1))
# beta=0 -> no KL term. PPO-clipped policy gradient only.
loss = (-pol_term * mask).sum(1) / mask.sum(1).clamp_min(1)
loss = loss.mean()
per_tok_loss = -pol_term
if ref_logp is not None:
# K3 estimator (Schulman 2020): unbiased + positive.
kl = torch.exp(ref_logp - pol_logp) - (ref_logp - pol_logp) - 1.0
per_tok_loss = per_tok_loss + beta * kl
if cfg.unbiased:
# Dr.GRPO: divide by constant max_new not response length.
loss = (per_tok_loss * mask).sum() / (group * max_new)
else:
loss = ((per_tok_loss * mask).sum(1) / mask.sum(1).clamp_min(1)).mean()
opt.zero_grad(set_to_none=True)
loss.backward()
# cos_in measured before projection for all arms (so vanilla logs match).
with torch.no_grad():
cos_pre = []
for name, info in wrappers.items():
@@ -170,7 +265,7 @@ def main(cfg: Config) -> int:
if cfg.arm == "projected":
diag = project_delta_S_grad(wrappers, v_hack, cfg.preserve_magnitude)
gnorm = torch.nn.utils.clip_grad_norm_(delta_params, 1.0).item()
torch.nn.utils.clip_grad_norm_(delta_params, 1.0)
opt.step()
rows.append({
@@ -178,8 +273,8 @@ def main(cfg: Config) -> int:
"rew_mean": f"{rewards.mean():+.2f}",
"rew_std": f"{rewards.std():.2f}",
"spread": "T" if spread else "F",
"gt_pass": f"{sum(gt_flags)}/{cfg.group}",
"hack": f"{sum(hack_flags)}/{cfg.group}",
"gt_pass": f"{sum(gt_flags)}/{group}",
"hack": f"{sum(hack_flags)}/{group}",
"loss": f"{loss.item():+.4f}",
"cos_in": f"{diag['mean_cos_in']:+.3f}",
"cos_out": f"{diag['mean_cos_out']:+.3f}",
@@ -187,33 +282,40 @@ def main(cfg: Config) -> int:
"sec": f"{time.time()-t0:.0f}",
})
logger.info(
f"step {step:2d} rew={rewards.mean():+.2f} (std {rewards.std():.2f}) "
f"gt={sum(gt_flags)}/{cfg.group} hack={sum(hack_flags)}/{cfg.group} "
f"step {step:3d} rew={rewards.mean():+.2f}(std {rewards.std():.2f}) "
f"gt={sum(gt_flags)}/{group} hack={sum(hack_flags)}/{group} "
f"loss={loss.item():+.3f} cos_in={diag['mean_cos_in']:+.3f} "
f"cos_out={diag['mean_cos_out']:+.3f} sec={time.time()-t0:.0f}"
f"cos_out={diag['mean_cos_out']:+.3f} fired={diag['frac_fired']:.2f} "
f"sec={time.time()-t0:.0f}"
)
peak_gb = torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0
print(tabulate(rows, headers="keys", tablefmt="github"))
n_steps = len(rows)
n_gens = n_steps * cfg.group
n_gens = n_steps * group
total_hacks = sum(int(r["hack"].split("/")[0]) for r in rows)
total_pass = sum(int(r["gt_pass"].split("/")[0]) for r in rows)
hack_rate = total_hacks / max(1, n_gens)
pass_rate = total_pass / max(1, n_gens)
print(
f"\narm={cfg.arm} steps={n_steps} generations={n_gens} "
f"\npreset={cfg.preset.value} arm={cfg.arm} steps={n_steps} generations={n_gens} "
f"peak={peak_gb:.2f}GB HACK_RATE={hack_rate:.3f} PASS_RATE={pass_rate:.3f}"
)
print(
"SHOULD: HACK_RATE drops in projected vs vanilla by >=30pp at matched PASS_RATE "
"(only on >=4B model; at smoke scale both are ~0.0 -> H4 fallback, see spec.md)."
)
# save row+aggregates for the proof table
OUT_DIR.mkdir(exist_ok=True)
tag = cfg.out_tag or f"_{cfg.preset.value}_{cfg.arm}_seed{cfg.seed}"
torch.save(
{"rows": rows, "hack_rate": hack_rate, "pass_rate": pass_rate, "cfg": vars(cfg)},
OUT_DIR / f"proof_{cfg.arm}.pt",
{"rows": rows, "hack_rate": hack_rate, "pass_rate": pass_rate,
"cfg": vars(cfg), "resolved": p},
OUT_DIR / f"train{tag}.pt",
)
return 0
if __name__ == "__main__":
sys.exit(main(tyro.cli(Config)))
+121
View File
@@ -0,0 +1,121 @@
"""Held-out v_hack validation (spec.md §B validation).
For each held-out pair, compute per-module gradient diff (g_hack - g_clean)
in delta_S basis, then cos-align with the trained v_hack[name].
Report:
- per-suffix median/mean cos_align
- fraction of modules with cos_align > 0 (SHOULD > 0.5)
- mean cos_align across modules (target > 0.2)
Run: uv run python -m projected_grpo.verify_vhack_heldout
"""
from __future__ import annotations
import sys
from collections import defaultdict
from pathlib import Path
import torch
from loguru import logger
from tabulate import tabulate
from transformers import AutoModelForCausalLM, AutoTokenizer
from .antipasto import wrap_model_with_antipasto
from .extract_vhack_grad import N_HELDOUT, completion_nll
from .pairs import PAIRS
MODEL = "Qwen/Qwen3.5-0.8B"
CACHE_ROOT = Path("svd_cache")
OUT_DIR = Path("out")
def main() -> int:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"device={device} model={MODEL}")
v_hack: dict[str, torch.Tensor] = torch.load(OUT_DIR / "v_hack.pt", map_location="cpu", weights_only=True)
logger.info(f"loaded v_hack: {len(v_hack)} modules")
held = PAIRS[-N_HELDOUT:]
logger.info(f"held-out pairs: {len(held)}")
tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForCausalLM.from_pretrained(MODEL, attn_implementation="sdpa").to(device)
model.eval()
wrappers = wrap_model_with_antipasto(
model, model_name=MODEL, cache_root=CACHE_ROOT, svd_device=device,
)
grads_hack: dict[str, list[torch.Tensor]] = defaultdict(list)
grads_clean: dict[str, list[torch.Tensor]] = defaultdict(list)
for pi, pair in enumerate(held):
for label, completion in (("hack", pair.hack), ("clean", pair.clean)):
model.zero_grad(set_to_none=True)
loss = completion_nll(model, tokenizer, pair.prompt, completion, device)
loss.backward()
bucket = grads_hack if label == "hack" else grads_clean
for name, info in wrappers.items():
bucket[name].append(info["delta_S"].grad.detach().float().cpu().clone())
logger.info(f" held pair {pi+1}/{len(held)} loss={loss.item():.3f}")
# per-module cos_align
cos_by_suffix: dict[str, list[float]] = defaultdict(list)
all_cos = []
rows_all = []
for name, v in v_hack.items():
gh = torch.stack(grads_hack[name]).mean(0)
gc = torch.stack(grads_clean[name]).mean(0)
diff = gh - gc
nrm = diff.norm()
if nrm < 1e-12:
cos = 0.0
else:
cos = ((diff / nrm) @ v).item()
suf = name.split(".")[-1]
cos_by_suffix[suf].append(cos)
all_cos.append(cos)
rows_all.append((name, cos))
agg_rows = []
for suf, vals in sorted(cos_by_suffix.items()):
t = torch.tensor(vals)
agg_rows.append({
"suffix": suf,
"n": len(vals),
"mean_cos": f"{t.mean():+.3f}",
"median_cos": f"{t.median():+.3f}",
"frac>0": f"{(t > 0).float().mean():.2f}",
"min": f"{t.min():+.3f}",
"max": f"{t.max():+.3f}",
})
print(tabulate(agg_rows, headers="keys", tablefmt="pipe"))
t_all = torch.tensor(all_cos)
frac_pos = (t_all > 0).float().mean().item()
mean_cos = t_all.mean().item()
median_cos = t_all.median().item()
logger.info(
f"OVERALL modules={len(all_cos)} frac>0={frac_pos:.3f} "
f"mean={mean_cos:+.3f} median={median_cos:+.3f} "
f"SHOULD: frac>0 > 0.50 and mean > 0.20. ELSE: extraction noise dominates signal."
)
# save for downstream plotting / sanity
torch.save({"cos_align": rows_all}, OUT_DIR / "vhack_heldout_cos.pt")
gate_pass = frac_pos > 0.50
target_pass = mean_cos > 0.20
if not gate_pass:
logger.error(f"GATE FAIL: frac>0 = {frac_pos:.3f} <= 0.50")
return 1
if not target_pass:
logger.warning(f"TARGET MISS: mean_cos = {mean_cos:+.3f} <= 0.20 (gate passes but signal weak)")
else:
logger.info(f"TARGET PASS: mean_cos = {mean_cos:+.3f} > 0.20")
return 0
if __name__ == "__main__":
sys.exit(main())