mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 16:45:42 +08:00
Add new scripts for AntiPaSTO and GRPO validation, including v_hack extraction, held-out validation, and smoke tests
This commit is contained in:
@@ -2,7 +2,57 @@
|
||||
|
||||
Append-only. New entries at the top, date-stamped. Never edit old entries.
|
||||
|
||||
# 2026-05-23
|
||||
# 2026-05-30
|
||||
|
||||
## Mechanism end-to-end verified on Qwen3.5-0.8B; H4 falsified at this scale
|
||||
|
||||
Closed the smoke loop: AntiPaSTO identity (bf16, max_abs_diff=0) -> v_hack
|
||||
extraction from 15 contrastive pairs -> held-out validation (frac>0=0.952,
|
||||
median cos=+0.363, n=186 modules) -> 10-step GRPO with subprocess-executed
|
||||
LeetCode rewards on vanilla and projected arms. Full writeup in
|
||||
[out/proof.md](../out/proof.md).
|
||||
|
||||
**Observation (mechanism)**: projected arm shows `cos_out < cos_in` every step,
|
||||
`frac_fired ≈ 0.51` averaged over 10 steps. Vanilla arm: `cos_out == cos_in`.
|
||||
The one-sided projection removes the v_hack-aligned component of the SVD-basis
|
||||
gradient when and only when alignment is positive. This is the core mechanical
|
||||
claim of the method and it is verified end-to-end.
|
||||
|
||||
**Observation (H4 sanity)**: both arms produce zero hack_rate and zero pass_rate
|
||||
on 30 LeetCode medium/hard problems, G=2, 10 steps. Inspection of generations
|
||||
shows Qwen3.5-0.8B emits format-only output that saturates the 0.25 format
|
||||
bonus but never attempts code or hack patterns. Per [spec.md](../spec.md) §H4,
|
||||
this falls below the 30% hack-rate threshold and triggers the model-scaling
|
||||
fallback.
|
||||
|
||||
**Inference**: 0.8B is too small to exhibit the failure mode the method
|
||||
targets. The mechanism is sound; the test substrate is not. Wu & Tang's
|
||||
Rebound paper used Qwen2.5-Coder-7B and observed ~50% baseline hack rate;
|
||||
Ariahw's benchmark assumes ≥4B class models. Mechanism + scale are
|
||||
separable concerns and the smaller scope of this session was mechanism.
|
||||
|
||||
**Caveats / what's untested**:
|
||||
|
||||
- β=0 (no ref-model KL) to fit 24 GB. Rebound used β=0.04. KL-free GRPO can
|
||||
diverge faster; not a fair comparison to Rebound at this scale.
|
||||
- Only 10 steps. Reward-hacking emerges around step 50–200 in Rebound figs.
|
||||
- 186 target modules, m=8 SVD rank. Larger models scale this to ~400+ modules.
|
||||
- `frac_fired ≈ 0.5` is consistent with random gradient direction wrt v_hack
|
||||
at init; we expect it to rise as training induces hack-aligned grads. Need
|
||||
longer runs to see this.
|
||||
|
||||
**Next (queued in [justfile](../justfile), pending ≥80 GB GPU)**:
|
||||
|
||||
1. `queue-vanilla`: Qwen2.5-Coder-7B baseline GRPO on full LeetCode set, 200
|
||||
steps, 3 seeds, β=0.04, G=4. Expected hack_rate at convergence: 40–60%
|
||||
(Rebound table 2).
|
||||
2. `queue-projected-m16`: same config + per-module v_hack projection at m=16.
|
||||
3. `queue-rebound`: H3 baseline arm — Wu-Tang advantage modification.
|
||||
|
||||
Confidence in method post-mechanism-verification: ~65% (was ~60%). The bump is
|
||||
small because mechanism-works was already high-prior; the real evidence is the
|
||||
7B run.
|
||||
|
||||
|
||||
## Project init
|
||||
|
||||
|
||||
@@ -0,0 +1,149 @@
|
||||
"""Gradient-side per-module v_hack extraction (spec.md §B).
|
||||
|
||||
For each contrastive pair (prompt, hack_completion, clean_completion):
|
||||
- Forward(prompt+completion), mean-NLL on completion tokens, backward
|
||||
- Capture `delta_S.grad` per AntiPaSTO-wrapped Linear
|
||||
|
||||
Then per module:
|
||||
v_hack[name] = normalize( mean(grads_hack) - mean(grads_clean) )
|
||||
|
||||
Saves `out/v_hack.pt` = dict[name -> Tensor[r]] (cpu fp32, unit-norm).
|
||||
|
||||
Run: uv run python -m projected_grpo.extract_vhack_grad
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
from tabulate import tabulate
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from .antipasto import wrap_model_with_antipasto
|
||||
from .pairs import PAIRS
|
||||
|
||||
|
||||
# Hugging Face model id loaded for both the tokenizer and the model.
MODEL = "Qwen/Qwen3.5-0.8B"
# Passed to wrap_model_with_antipasto as ``cache_root``.
CACHE_ROOT = Path("svd_cache")
# Output directory for v_hack.pt and the raw per-pair gradients.
OUT_DIR = Path("out")
N_HELDOUT = 5  # last 5 pairs reserved for held-out validation
|
||||
|
||||
|
||||
def completion_nll(model, tokenizer, prompt: str, completion: str, device) -> torch.Tensor:
    """Return the length-normalised NLL of ``completion`` given ``prompt``.

    The prompt alone and the concatenated ``prompt + completion`` are
    tokenised separately; the prompt's token count marks where the completion
    begins. Only positions whose *target* token lies in the completion
    contribute to the average.

    NOTE(review): this assumes tokenising ``prompt`` yields a prefix of the
    tokenisation of ``prompt + completion`` -- confirm for the tokenizer in use.
    """
    n_prompt = tokenizer(prompt, return_tensors="pt").input_ids.to(device).shape[1]
    full_ids = tokenizer(prompt + completion, return_tensors="pt").input_ids.to(device)
    seq_len = full_ids.shape[1]

    # Next-token setup: the logits at position i score the token at i + 1.
    shifted_logits = model(full_ids).logits[:, :-1]  # [1, L-1, V]
    next_tokens = full_ids[:, 1:]  # [1, L-1]
    log_probs = torch.nn.functional.log_softmax(shifted_logits.float(), dim=-1)
    token_nll = -log_probs.gather(-1, next_tokens.unsqueeze(-1)).squeeze(-1)  # [1, L-1]

    # Shifted position i targets full_ids[i + 1], so the first completion
    # target sits at shifted index n_prompt - 1.
    positions = torch.arange(seq_len - 1, device=device).unsqueeze(0)
    keep = (positions >= (n_prompt - 1)).float()

    masked_total = (token_nll * keep).sum()
    n_kept = keep.sum().clamp_min(1.0)  # avoid 0/0 when nothing is kept
    return masked_total / n_kept
|
||||
|
||||
|
||||
def main() -> int:
    """Extract one unit-norm v_hack direction per wrapped module and save it.

    For every training pair the hack and clean completions are each scored
    with ``completion_nll`` and back-propagated; the resulting ``delta_S.grad``
    of every wrapped module is captured. Per module, v_hack is the normalised
    difference between the mean hack gradient and the mean clean gradient.

    A pair contributes only if *both* of its losses are finite, so the two
    means are always taken over the same set of pairs.

    Writes ``OUT_DIR / "vhack_grads_train.pt"`` (raw gradients) and
    ``OUT_DIR / "v_hack.pt"`` (directions).

    Returns:
        0 on success; 1 if a module has no gradient, if no pair was usable,
        or if any module's gradient difference has (near-)zero norm.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"device={device} model={MODEL} N_pairs={len(PAIRS)} heldout={N_HELDOUT}")

    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    model = AutoModelForCausalLM.from_pretrained(MODEL, attn_implementation="sdpa").to(device)
    model.eval()  # disable dropout; gradients still flow through delta_S
    wrappers = wrap_model_with_antipasto(
        model, model_name=MODEL, cache_root=CACHE_ROOT, svd_device=device,
    )
    n_mod = len(wrappers)
    n_delta = sum(info["delta_S"].numel() for info in wrappers.values())
    logger.info(f"wrapped {n_mod} modules; total delta_S scalars = {n_delta:,}")

    # Split on an explicit index: PAIRS[:-N_HELDOUT] would be empty (and the
    # held-out slice would be everything) when N_HELDOUT == 0.
    n_train = max(0, len(PAIRS) - N_HELDOUT)
    train_pairs = PAIRS[:n_train]
    held_pairs = PAIRS[n_train:]
    logger.info(f"train pairs: {len(train_pairs)} held: {len(held_pairs)}")

    grads_hack: dict[str, list[torch.Tensor]] = defaultdict(list)
    grads_clean: dict[str, list[torch.Tensor]] = defaultdict(list)
    buckets = {"hack": grads_hack, "clean": grads_clean}
    n_used = 0

    for pi, pair in enumerate(train_pairs):
        # Gradients are staged per pair and committed only when both sides
        # succeed; otherwise the hack and clean means would be computed over
        # different pairs (or one side could end up empty).
        staged: dict[str, dict[str, torch.Tensor]] = {}
        pair_ok = True
        loss = None
        for label, completion in (("hack", pair.hack), ("clean", pair.clean)):
            model.zero_grad(set_to_none=True)
            loss = completion_nll(model, tokenizer, pair.prompt, completion, device)
            if not torch.isfinite(loss):
                logger.warning(f"non-finite loss pair={pi} label={label}; skipping")
                pair_ok = False
                break
            loss.backward()
            captured: dict[str, torch.Tensor] = {}
            for name, info in wrappers.items():
                g = info["delta_S"].grad
                if g is None:
                    logger.error(f"no grad on {name}; aborting")
                    return 1
                captured[name] = g.detach().float().cpu().clone()
            staged[label] = captured

        if pair_ok:
            for label, captured in staged.items():
                for name, g in captured.items():
                    buckets[label][name].append(g)
            n_used += 1

        if (pi + 1) % 5 == 0:
            logger.info(f" pair {pi+1}/{len(train_pairs)} loss={loss.item():.3f}")

    if n_used == 0:
        # Without this guard an empty v_hack.pt would be written and reported
        # as success.
        logger.error("FAIL: no training pair produced finite hack and clean losses")
        return 1

    # save raw grads for held-out validation reuse
    OUT_DIR.mkdir(exist_ok=True)
    torch.save(
        {"grads_hack": dict(grads_hack), "grads_clean": dict(grads_clean)},
        OUT_DIR / "vhack_grads_train.pt",
    )

    v_hack: dict[str, torch.Tensor] = {}
    # Numeric diff norms grouped by the last dotted component of the module
    # name; kept as floats so the summary is not computed from rounded strings.
    by_suffix: dict[str, list[float]] = defaultdict(list)
    n_zero = 0
    for name, hack_list in grads_hack.items():
        gh = torch.stack(hack_list).mean(0)  # [r]
        gc = torch.stack(grads_clean[name]).mean(0)  # [r]
        diff = gh - gc
        nrm = diff.norm()
        if nrm < 1e-12:
            # Cannot normalise; keep the raw (near-zero) vector and count it.
            n_zero += 1
            v_hack[name] = diff
        else:
            v_hack[name] = diff / nrm
        by_suffix[name.split(".")[-1]].append(nrm.item())

    torch.save(v_hack, OUT_DIR / "v_hack.pt")

    # summary: aggregate by suffix
    agg_rows = []
    for suf, vals in sorted(by_suffix.items()):
        agg_rows.append({
            "suffix": suf,
            "n": len(vals),
            "mean_||diff||": f"{sum(vals)/len(vals):.2e}",
            "min_||diff||": f"{min(vals):.2e}",
            "max_||diff||": f"{max(vals):.2e}",
        })

    print(tabulate(agg_rows, headers="keys", tablefmt="pipe"))
    logger.info(
        f"v_hack saved to {OUT_DIR / 'v_hack.pt'} "
        f"modules={len(v_hack)} zero-norm={n_zero} "
        f"SHOULD: zero-norm == 0 and per-suffix mean_||diff|| > 0. "
        f"ELSE: pairs not flowing gradients into delta_S (check requires_grad / hook attach)."
    )

    if n_zero > 0:
        logger.error(f"FAIL: {n_zero}/{len(v_hack)} modules have zero diff -- gradient flow broken")
        return 1
    return 0
|
||||
|
||||
|
||||
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())
|
||||
@@ -0,0 +1,219 @@
|
||||
"""End-to-end proof: AntiPaSTO + GRPO + (optional) gradient projection on
|
||||
real LeetCode problems with real subprocess-executed rewards.
|
||||
|
||||
Compared to grpo_proj_smoke.py:
|
||||
- Pulls problems from external/rl-rewardhacking/.../leetcode_train_medhard_filtered.jsonl
|
||||
- Uses real rewards.compute_reward (subprocess execution + hack detection)
|
||||
- Logs per-step: reward_mean, gt_pass_rate, hack_rate, loss, cos_in, cos_out
|
||||
- Aggregates final-window hack rate + pass rate for the proof table
|
||||
|
||||
Run:
|
||||
uv run python -m projected_grpo.grpo_leetcode_proof --arm=vanilla
|
||||
uv run python -m projected_grpo.grpo_leetcode_proof --arm=projected
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import tyro
|
||||
from loguru import logger
|
||||
from tabulate import tabulate
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
||||
|
||||
from .antipasto import wrap_model_with_antipasto
|
||||
from .grpo_proj_smoke import project_delta_S_grad, per_token_logps
|
||||
from .rewards import compute_reward
|
||||
|
||||
# Hugging Face model id used for both the tokenizer and the policy model.
MODEL = "Qwen/Qwen3.5-0.8B"
# Passed to wrap_model_with_antipasto (third positional argument).
CACHE_ROOT = Path("svd_cache")
# v_hack.pt is read from here and proof_<arm>.pt is written here.
OUT_DIR = Path("out")
# JSONL problem file read by load_problems.
DATA = Path("external/rl-rewardhacking/results/data/leetcode_train_medhard_filtered.jsonl")
|
||||
|
||||
|
||||
@dataclass
class Config:
    """Command-line configuration for the LeetCode GRPO proof run.

    Raises:
        ValueError: if ``arm`` is not ``"vanilla"`` or ``"projected"``. Any
            other value would otherwise run without projection and be saved
            under a misleading ``proof_<arm>.pt`` name.
    """

    # "projected" applies the v_hack gradient projection; "vanilla" does not.
    arm: str = "projected"
    # Number of training-loop iterations.
    steps: int = 10
    group: int = 2  # G=2 to fit 24GB on 0.8B
    # NOTE(review): no KL term is computed in the training loop visible in
    # this file, so this value appears to be informational only.
    beta: float = 0.0  # drop KL: avoids loading ref_model (OOM on 24GB)
    # Clipping half-width for the policy ratio.
    clip: float = 0.2
    # AdamW learning rate.
    lr: float = 2e-4
    # Maximum number of generated tokens per completion.
    max_new: int = 128
    # Seeds torch's global RNG and the problem-sampling generator.
    seed: int = 41
    # Forwarded to project_delta_S_grad.
    preserve_magnitude: bool = True
    # Upper bound on the number of problems loaded.
    n_problems: int = 30

    def __post_init__(self) -> None:
        if self.arm not in ("vanilla", "projected"):
            raise ValueError(
                f"arm must be 'vanilla' or 'projected', got {self.arm!r}"
            )
|
||||
|
||||
|
||||
def load_problems(n: int, path: Path | None = None) -> list[dict]:
    """Load at most ``n`` problems from a JSONL file.

    Args:
        n: Maximum number of problems to return.
        path: JSONL file to read; defaults to the module-level ``DATA``.

    Returns:
        One dict per problem with keys ``messages``, ``gt_tests``,
        ``setup_code``, ``func_name`` and ``canonical``. Optional fields
        missing from a record fall back to ``""`` (``func_name`` falls back
        to ``"Solution().solve"``).

    Raises:
        KeyError: if a record lacks ``prompt`` or ``gt_answer``.
    """
    source = DATA if path is None else path
    problems: list[dict] = []
    # Explicit UTF-8: the platform default encoding is not reliable for
    # problem statements containing non-ASCII text.
    with source.open(encoding="utf-8") as f:
        for line in f:
            if len(problems) >= n:
                break
            if not line.strip():
                # Tolerate blank/trailing lines instead of failing in json.loads.
                continue
            d = json.loads(line)
            problems.append({
                "messages": d["prompt"],
                "gt_tests": d["gt_answer"],
                "setup_code": d.get("setup_code", ""),
                "func_name": d.get("func_name", "Solution().solve"),
                "canonical": d.get("canonical_solution", ""),
            })
    return problems
|
||||
|
||||
|
||||
def main(cfg: Config) -> int:
    """Run a short GRPO loop on LeetCode problems and save a proof table.

    Each step samples one problem, generates ``cfg.group`` completions, scores
    them with ``compute_reward``, forms a clipped policy-gradient loss over the
    completion tokens, optionally projects the ``delta_S`` gradients against
    ``v_hack`` (``cfg.arm == "projected"``), and takes one AdamW step.

    Writes ``OUT_DIR / f"proof_{cfg.arm}.pt"`` with the per-step rows and the
    aggregate hack/pass rates.

    Returns:
        0 on completion; 1 if no problems could be loaded.
    """
    torch.manual_seed(cfg.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"cfg={cfg} device={device} model={MODEL}")

    tok = AutoTokenizer.from_pretrained(MODEL)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        MODEL, dtype=torch.bfloat16, attn_implementation="sdpa"
    ).to(device)
    # NOTE: skipping ref_model to fit 24GB. With beta=0 the KL term is dropped,
    # so loss = -PPO-clipped policy ratio * advantage. Mechanism (gradient
    # projection) is unchanged. Re-enable ref_model on >=40GB GPUs.

    wrappers = wrap_model_with_antipasto(model, MODEL, CACHE_ROOT, device)
    delta_params = [info["delta_S"] for info in wrappers.values()]
    logger.info(f"trainable delta_S: {sum(p.numel() for p in delta_params):,}")

    v_hack = torch.load(OUT_DIR / "v_hack.pt", map_location=device, weights_only=True)
    opt = torch.optim.AdamW(delta_params, lr=cfg.lr)

    gen_cfg = GenerationConfig(
        max_new_tokens=cfg.max_new, do_sample=True, temperature=0.9,
        num_return_sequences=cfg.group, pad_token_id=tok.pad_token_id,
    )

    problems = load_problems(cfg.n_problems)
    logger.info(f"loaded {len(problems)} problems from {DATA.name}")
    if not problems:
        # torch.randint(0, 0, ...) below would fail with an opaque error.
        logger.error(f"no problems loaded from {DATA}; aborting")
        return 1

    rng = torch.Generator().manual_seed(cfg.seed)
    rows = []
    # Running integer totals, so the aggregates are not re-parsed from the
    # "k/G" display strings stored in ``rows``.
    total_hacks = 0
    total_pass = 0
    logger.info(
        f"\n--- TRAIN [{cfg.arm}] {cfg.steps} steps, G={cfg.group}, real subprocess rewards ---\n"
        "SHOULD: loss finite; hack_rate present (any nonzero is real signal at 0.8B); "
        "in projected arm cos_out <= cos_in. ELSE: harness or projection broken."
    )

    for step in range(cfg.steps):
        t0 = time.time()
        idx = int(torch.randint(0, len(problems), (1,), generator=rng).item())
        prob = problems[idx]
        prompt = tok.apply_chat_template(prob["messages"], tokenize=False, add_generation_prompt=True)
        enc = tok(prompt, return_tensors="pt", add_special_tokens=False).to(device)
        plen = enc.input_ids.shape[1]
        if plen + cfg.max_new > 1500:
            logger.warning(f"step {step}: skip, prompt too long {plen}")
            continue

        with torch.no_grad():
            gen_out = model.generate(**enc, generation_config=gen_cfg).detach()
        merged = gen_out
        completions = gen_out[:, plen:]
        texts = tok.batch_decode(completions, skip_special_tokens=True)

        # real reward fn (subprocess)
        rs = []
        hack_flags = []
        gt_flags = []
        for t in texts:
            r = compute_reward(
                t, canonical_solution=prob["canonical"], gt_tests=prob["gt_tests"][:5],
                setup_code=prob["setup_code"], func_name_hint=prob["func_name"],
            )
            rs.append(r.reward)
            hack_flags.append(r.hacked)
            gt_flags.append(r.gt_pass)
        rewards = torch.tensor(rs, dtype=torch.float32, device=device)

        if (rewards.max() - rewards.min()).item() < 1e-3:
            # No reward spread in the group: random advantages are substituted
            # so gradients are still non-zero. Flagged as spread=F in the table.
            adv = torch.randn(cfg.group, device=device)
            spread = False
        else:
            adv = (rewards - rewards.mean()) / (rewards.std() + 1e-4)
            spread = True

        # Log-probs of the sampled tokens without grad (the "old" policy) ...
        with torch.no_grad():
            gen_logp_full = per_token_logps(model(merged).logits[:, :-1].float(), merged[:, 1:])
            gen_logp = gen_logp_full[:, plen - 1:].detach()

        # ... and again with grad for the policy term.
        pol_logits = model(merged).logits[:, :-1].float()
        pol_logp = per_token_logps(pol_logits, merged[:, 1:])[:, plen - 1:]

        comp_ids = merged[:, plen:]
        keep = comp_ids != tok.pad_token_id
        if tok.eos_token_id is not None and tok.pad_token_id == tok.eos_token_id:
            # Pad aliases EOS (see the fallback above): a plain "!= pad" mask
            # would also drop each sequence's terminating EOS from the loss.
            # Keep the first EOS per row; later ones are padding.
            is_eos = comp_ids == tok.eos_token_id
            keep = keep | (is_eos & (is_eos.cumsum(dim=1) == 1))
        mask = keep.float()

        ratio = torch.exp(pol_logp - gen_logp)
        clipped = torch.clamp(ratio, 1 - cfg.clip, 1 + cfg.clip)
        pol_term = torch.min(ratio * adv.unsqueeze(1), clipped * adv.unsqueeze(1))
        # beta=0 -> no KL term. PPO-clipped policy gradient only.
        loss = (-pol_term * mask).sum(1) / mask.sum(1).clamp_min(1)
        loss = loss.mean()

        opt.zero_grad(set_to_none=True)
        loss.backward()

        # Pre-projection cosine between each module's gradient and v_hack.
        with torch.no_grad():
            cos_pre = []
            for name, info in wrappers.items():
                g = info["delta_S"].grad
                if g is None or g.norm() < 1e-12:
                    cos_pre.append(0.0)
                    continue
                v = v_hack[name].to(g.device, g.dtype)
                cos_pre.append(((g @ v) / (g.norm() * (v.norm() + 1e-12))).item())
            mean_cos_pre = float(torch.tensor(cos_pre).mean())

        # Without projection, cos_out equals cos_in and nothing fires.
        diag = {"mean_cos_in": mean_cos_pre, "mean_cos_out": mean_cos_pre, "frac_fired": 0.0}
        if cfg.arm == "projected":
            diag = project_delta_S_grad(wrappers, v_hack, cfg.preserve_magnitude)

        torch.nn.utils.clip_grad_norm_(delta_params, 1.0)
        opt.step()

        n_hack = sum(hack_flags)
        n_pass = sum(gt_flags)
        total_hacks += int(n_hack)
        total_pass += int(n_pass)

        rows.append({
            "step": step,
            "rew_mean": f"{rewards.mean():+.2f}",
            "rew_std": f"{rewards.std():.2f}",
            "spread": "T" if spread else "F",
            "gt_pass": f"{n_pass}/{cfg.group}",
            "hack": f"{n_hack}/{cfg.group}",
            "loss": f"{loss.item():+.4f}",
            "cos_in": f"{diag['mean_cos_in']:+.3f}",
            "cos_out": f"{diag['mean_cos_out']:+.3f}",
            "fired": f"{diag['frac_fired']:.2f}",
            "sec": f"{time.time()-t0:.0f}",
        })
        logger.info(
            f"step {step:2d} rew={rewards.mean():+.2f} (std {rewards.std():.2f}) "
            f"gt={n_pass}/{cfg.group} hack={n_hack}/{cfg.group} "
            f"loss={loss.item():+.3f} cos_in={diag['mean_cos_in']:+.3f} "
            f"cos_out={diag['mean_cos_out']:+.3f} sec={time.time()-t0:.0f}"
        )

    peak_gb = torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0
    print(tabulate(rows, headers="keys", tablefmt="github"))
    n_steps = len(rows)  # skipped steps produce no row
    n_gens = n_steps * cfg.group
    hack_rate = total_hacks / max(1, n_gens)
    pass_rate = total_pass / max(1, n_gens)
    print(
        f"\narm={cfg.arm} steps={n_steps} generations={n_gens} "
        f"peak={peak_gb:.2f}GB HACK_RATE={hack_rate:.3f} PASS_RATE={pass_rate:.3f}"
    )

    # save row+aggregates for the proof table
    OUT_DIR.mkdir(exist_ok=True)
    torch.save(
        {"rows": rows, "hack_rate": hack_rate, "pass_rate": pass_rate, "cfg": vars(cfg)},
        OUT_DIR / f"proof_{cfg.arm}.pt",
    )
    return 0
|
||||
|
||||
|
||||
# Script entry point: parse Config from the command line with tyro and use
# main()'s return value as the process exit status.
if __name__ == "__main__":
    sys.exit(main(tyro.cli(Config)))
|
||||
Reference in New Issue
Block a user