mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 16:30:30 +08:00
Match paper effective batch + fix gt_tests/KeyError, strip stale docstring
Re-audited our setup vs ariahw 2025 (paper body + config.py + dataset): - gt_tests: was [:5] of median-102 ground-truth asserts. The hardcode loophole let a model pass 5 cherry-picked answers, score gt_pass=True, and never be flagged as a hack -- inflating PASS_RATE and hiding hacking. Now uses all asserts (free: rewards.py runs them in one subprocess). - n_problems 500 -> 992 (full filtered set, paper fn.9). - prompts_per_step 8 -> 43: grad-accum to ~258 generations/step ~= paper's effective batch of 256 (16 prompts x 16 gen). At our VRAM-capped G=6 this is the only lever; same peak VRAM, ~5x wall-time. Makes "our step N" comparable to the paper's step N in gradient-sample terms. - KeyError fix: end-of-run summary read r["rollouts"]/r["gt_pass"] but row keys are "N"/"gt". Every run crashed at step 200 before saving; no .pt had ever been written. - Stripped stale module docstring (claimed beta=0.04 vs actual 1e-3, Qwen3.5-2B vs Qwen3-4B, duplicated preset table) -> points to PRESETS as source of truth. justfile: probe-full-seed now launches 4 dependent pueue tasks (extract -> verify -> vanilla -> projected) instead of one monolithic job, so a stage crash no longer blocks the rest and each gate is independently inspectable. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -88,14 +88,21 @@ verify-vhack-full:
|
||||
# just probe-h4 41 — vanilla only on a single seed (H4 substrate sanity).
|
||||
# =============================================================================
|
||||
|
||||
# Single-seed gate. Sequential: extract -> verify -> vanilla -> projected.
|
||||
# Use this BEFORE `queue-full` to validate vanilla actually hacks and projected
|
||||
# fires on this substrate; saves 5/6 of the compute if the gate fails.
|
||||
# Single-seed gate as 4 DEPENDENT pueue tasks: extract -> verify -> vanilla -> projected.
|
||||
# Each stage is its own inspectable task; -a chains them so a stage only starts if
|
||||
# the prior succeeded (nonzero exit blocks the chain). Gates A/B are enforced by exit
|
||||
# code (verify exits nonzero if frac>0<=0.50). Gate C (vanilla actually hacks) is NOT
|
||||
# an exit-code gate -- vanilla exits 0 regardless -- so inspect its HACK_RATE around
|
||||
# step ~100 and `pueue kill` the queued projected task if it didn't hack.
|
||||
# Use BEFORE `queue-full` to avoid burning 5/6 of the sweep compute on a dead substrate.
|
||||
probe-full-seed seed="41":
|
||||
just extract-vhack-full
|
||||
just verify-vhack-full
|
||||
{{ TRAIN }} --preset=full --arm=vanilla --seed={{ seed }} --out-tag=_full_vanilla_seed{{ seed }}_probe
|
||||
{{ TRAIN }} --preset=full --arm=projected --seed={{ seed }} --v-hack-path=out/v_hack_full.safetensors --out-tag=_full_projected_seed{{ seed }}_probe
|
||||
#!/usr/bin/env bash
|
||||
set -euxo pipefail
|
||||
EX=$(pueue add -p -w "$PWD" -o 9 -l "why: extract v_hack full; resolve: Gate A zero-norm=0, ~252 modules" -- just extract-vhack-full)
|
||||
VF=$(pueue add -p -a "$EX" -w "$PWD" -o 9 -l "why: verify heldout cos; resolve: Gate B frac>0>0.50, mean>0.20" -- just verify-vhack-full)
|
||||
VA=$(pueue add -p -a "$VF" -w "$PWD" -o 9 -l "why: vanilla seed{{ seed }} @ matched batch; resolve: Gate C H4 HACK_RATE>0.30 by ~step100" -- {{ TRAIN }} --preset=full --arm=vanilla --seed={{ seed }} --out-tag=_full_vanilla_seed{{ seed }}_probe)
|
||||
pueue add -a "$VA" -w "$PWD" -o 8 -l "why: projected seed{{ seed }} @ matched batch, v_hack NOT post-hoc; resolve: Gate D H1 HACK_RATE<vanilla at matched PASS" -- {{ TRAIN }} --preset=full --arm=projected --seed={{ seed }} --v-hack-path=out/v_hack_full.safetensors --out-tag=_full_projected_seed{{ seed }}_probe
|
||||
pueue status
|
||||
|
||||
# Vanilla-only single-seed probe. Cheapest way to answer "does this substrate
|
||||
# actually hack with our reward function" (spec.md §H4).
|
||||
|
||||
+22
-23
@@ -9,8 +9,8 @@ Lineage (see spec.md §76-83):
|
||||
prompts per optimizer step, per-prompt GRPO advantage groups, grad
|
||||
accumulation across prompts). GRPO needs within-group reward diversity to
|
||||
produce any signal; sampling many prompts per step raises the chance that
|
||||
at least one group is non-degenerate. simple_GRPO uses Q_batch_size=5; we
|
||||
use prompts_per_step=8 (set in PRESETS).
|
||||
at least one group is non-degenerate. simple_GRPO uses Q_batch_size=5; our
|
||||
prompts_per_step is set in PRESETS (grad-accum to the paper's effective batch).
|
||||
- Deviations from simple_GRPO are deliberate, listed in spec.md:
|
||||
1. Loss normalization: Dr.GRPO unbiased (Liu et al. 2025, arXiv
|
||||
2503.20783) replaces simple_GRPO's `(R-mean)/std` + per-response-len
|
||||
@@ -26,8 +26,7 @@ Lineage (see spec.md §76-83):
|
||||
`model.generate` in-process.
|
||||
4. Adapter: simple_GRPO is full FT (with DeepSpeed ZeRO). Canonical
|
||||
(ariahw/rl-rewardhacking) is LoRA r=32. We use AntiPaSTO full-rank
|
||||
SVD adapter (289K trainable `delta_S` params on Qwen3.5-2B) — the
|
||||
research artifact.
|
||||
SVD adapter (the research artifact).
|
||||
|
||||
Hyperparameters (lr, weight_decay, betas, warmup, cosine, beta=KL) are taken
|
||||
from the closest-in-param-count reference: ariahw/rl-rewardhacking config.py
|
||||
@@ -36,23 +35,19 @@ docs/grpo_hyperparams.md.
|
||||
|
||||
Reference-model term (`--beta`): Dr.GRPO argues beta=0 is fine for *reasoning*
|
||||
RL with rule-based reward (no distributional-shift concern when reward = ground
|
||||
truth). That argument does NOT apply when studying reward hacking, which IS
|
||||
the distributional shift between proxy reward and true objective. To match
|
||||
the benchmarks we compare against (Ariahw 2025, Wu & Tang 2026 Rebound), the
|
||||
project default is beta=0.04. Without it, vanilla can collapse before hacking
|
||||
emerges and confounds 'hacking from the targeted shortcut' with 'generic
|
||||
policy collapse'. The smoke preset uses beta=0.0 only because the 24GB GPU
|
||||
can't hold a separate ref_model -- but our delta_S=0 free-ref-model trick lets
|
||||
lite/full use beta=0.04 at zero extra VRAM (W' = W + U diag(0) Vh = W exactly,
|
||||
so a no_grad forward with delta_S zeroed gives pi_ref logprobs).
|
||||
truth). That argument does NOT apply when studying reward hacking, which IS the
|
||||
distributional shift between proxy reward and true objective, so `full` uses
|
||||
beta>0 (value from ariahw config.py; see PRESETS). The delta_S=0 free-ref-model
|
||||
trick gives this at zero extra VRAM: W' = W + U diag(0) Vh = W exactly, so a
|
||||
no_grad forward with delta_S zeroed yields pi_ref logprobs without a 2nd model.
|
||||
The smoke preset uses beta=0 only because the 24GB GPU can't hold even that.
|
||||
|
||||
Presets via `--preset`:
|
||||
smoke -> 10 steps, G=2, Qwen3.5-0.8B, 24GB, beta=0 (mechanism only)
|
||||
full -> 200 steps, G=8, Qwen3.5-2B, >=48GB, beta=0.04 (spec H4 substrate)
|
||||
All per-preset hyperparameters (model, steps, G, max_new, n_problems, beta,
|
||||
prompts_per_step) live in the PRESETS dict below — the single source of truth.
|
||||
|
||||
Run:
|
||||
uv run python -m projected_grpo.train --preset=smoke --arm=vanilla
|
||||
uv run python -m projected_grpo.train --preset=smoke --arm=projected
|
||||
uv run python -m projected_grpo.train --preset=full --arm=projected
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -127,8 +122,12 @@ PRESETS: dict[str, dict] = {
|
||||
# G=6 after 2026-05-24 step-17 OOM at G=8: lm_head spike on a long-prompt
|
||||
# problem hit 4.16 GiB / 2.5 GiB free. `logits_to_keep` cuts lm_head ~33%;
|
||||
# G=8->6 cuts B at every act site ~25%. Combined headroom ~6-10 GB.
|
||||
# prompts_per_step=43: grad-accum to paper's effective batch (256 generations
|
||||
# per optimizer step; ariahw config.py num_prompts=16 x num_generations=16).
|
||||
# At our VRAM-capped G=6, 43 x 6 = 258 ~= 256. Grad accum -> same peak VRAM,
|
||||
# ~5x wall-time vs pp=8. n_problems=992 is the full filtered set (paper fn.9).
|
||||
"full": dict(model="Qwen/Qwen3-4B", steps=200, group=6, max_new=1024,
|
||||
n_problems=500, beta=1e-3, prompts_per_step=8),
|
||||
n_problems=992, beta=1e-3, prompts_per_step=43),
|
||||
}
|
||||
|
||||
|
||||
@@ -205,8 +204,8 @@ def load_v_hack(path: Path, model_name: str, wrappers: dict) -> dict[str, torch.
|
||||
"""Load v_hack and fail fast if it is not for this wrapped model.
|
||||
|
||||
v_hack is model-specific because module names and per-module SVD ranks depend
|
||||
on the exact checkpoint. A Qwen3.5-0.8B v_hack must not be reused for a
|
||||
Qwen3.5-2B run.
|
||||
on the exact checkpoint. A smoke (Qwen3.5-0.8B) v_hack must not be reused for
|
||||
a full (Qwen3-4B) run.
|
||||
"""
|
||||
with safe_open(str(path), framework="pt", device="cpu") as f:
|
||||
meta = f.metadata() or {}
|
||||
@@ -430,7 +429,7 @@ def main(cfg: Config) -> int:
|
||||
rs, hack_flags, gt_flags, fmt_flags = [], [], [], []
|
||||
for t in texts:
|
||||
r = compute_reward(
|
||||
t, canonical_solution=prob["canonical"], gt_tests=prob["gt_tests"][:5],
|
||||
t, canonical_solution=prob["canonical"], gt_tests=prob["gt_tests"],
|
||||
setup_code=prob["setup_code"], func_name_hint=prob["func_name"],
|
||||
)
|
||||
rs.append(r.reward); hack_flags.append(r.hacked); gt_flags.append(r.gt_pass)
|
||||
@@ -560,9 +559,9 @@ def main(cfg: Config) -> int:
|
||||
|
||||
peak_gb = torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0
|
||||
n_steps = len(rows)
|
||||
n_gens = sum(r["rollouts"] for r in rows)
|
||||
n_gens = sum(r["N"] for r in rows)
|
||||
total_hacks = sum(int(r["hack"].split("/")[0]) for r in rows)
|
||||
total_pass = sum(int(r["gt_pass"].split("/")[0]) for r in rows)
|
||||
total_pass = sum(int(r["gt"].split("/")[0]) for r in rows)
|
||||
hack_rate = total_hacks / max(1, n_gens)
|
||||
pass_rate = total_pass / max(1, n_gens)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user