Files
evil_MoE/scripts/tt_erase_bench.py
T
wassname 55937a86fb rename python package projected_grpo -> vgrout
git mv src/projected_grpo -> src/vgrout and find-replace the module name in
all imports (.py), `-m projected_grpo.*` invocations (justfile), and the
[project] name (pyproject; setuptools auto-discovers via where=["src"]).

Left RESEARCH_JOURNAL.md untouched: its commands/paths are dated lab notes
tied to past commits, so rewriting them would falsify provenance. Repo dir,
git remote, and absolute paths unchanged.

Verified: `import vgrout` and `python -m vgrout.train --help` load the full
graph; verify_rewards.py + verify_gate_anchor.py (both import vgrout) pass.
Full `just smoke` is blocked upstream by missing gitignored data artifacts
(out/pools/{substrate,teacher_pool}, out/vhack/*smoke*), unrelated to the rename.
2026-06-05 14:51:48 +08:00

213 lines
11 KiB
Python

"""Test-time (post-hoc) hack-erasure benchmark.
The colleague's question: instead of intervening during RL (route2/erase), just
train vanilla, then erase the hack direction from a FINISHED checkpoint at deploy.
Does post-hoc erasure match train-time routing? If yes, intervening during
training buys nothing; if no (the hack is baked across the weights, or erasure
tanks solve), that is the motivation for intervening in the gradient.
Two erasure flavors on the SAME finished delta_S checkpoint, SAME eval harness:
weight -- project the trained delta_S orthogonal to the gradient-space v_hack
(the exact basis our method uses), once. This IS the `erase` arm
applied at the end instead of every gradient step. Free: reuses
load_v_hack, no new extraction.
act -- residual-stream diff-of-means hack direction (Arditi-style ablation),
removed from every layer at eval. The classic rep-eng baseline. New
direction (NOT delta_S), gathered from the same weak-detector pairs.
We report hack AND solve for each arm, so a drop in hack that also tanks solve
reads as "erasure is too blunt", not a win. No training happens here.
No-cheat: both erasure directions -- v_hack for the weight arm and the diff-of-means
direction for the act arm -- come from the same persona pairs (the allowed weak
detector), never from gt_pass.
Run:
uv run python scripts/tt_erase_bench.py --ckpt out/runs/<run>/train.safetensors
"""
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
import torch
import tyro
from loguru import logger
from safetensors import safe_open
from tabulate import tabulate
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from vgrout.antipasto import wrap_model_with_antipasto
from vgrout.extract_vhack_grad import load_v_hack
from vgrout.pairs_from_pool import load_pairs_json
from vgrout.problems import load_problems
from vgrout.rewards import EnvMode
from vgrout.train import CACHE_ROOT, VHACK_DIR, eval_hack_solve
@dataclass
class Config:
    """CLI options for the post-hoc erasure benchmark (parsed by tyro)."""

    # Finished delta_S checkpoint to benchmark (out/runs/<run>/train.safetensors).
    ckpt: Path
    # Explicit v_hack file; when None, main() derives it from the ckpt's pairset
    # (the same path train.py uses).
    v_hack_path: Path | None = None
    # Prompts to evaluate; more than the 8 used live, for a tighter estimate.
    n_eval_prompts: int = 32
    # hidden_states index the act direction is sourced from; any negative value
    # (default -1) auto-picks the index with the largest hack-clean separation.
    act_src_layer: int = -1
    # Run weight-erase only, skipping activation gathering and the act arm.
    skip_act: bool = False
def load_ckpt_cfg(ckpt: Path) -> dict:
    """Return the training-run config stored in a checkpoint's safetensors metadata.

    The run config travels in the checkpoint metadata under the key "cfg" as a
    JSON string (train.py save_ckpt). Reading model/pairset/eval knobs straight
    off it lets the benchmark auto-match the run that produced the checkpoint.

    Raises:
        KeyError: the checkpoint carries no "cfg" metadata entry (e.g. it was
            not written by train.py's save_ckpt). Still a KeyError, as the bare
            lookup raised before, but now with the file and available keys.
    """
    with safe_open(str(ckpt), framework="pt", device="cpu") as f:
        meta = f.metadata() or {}
    if "cfg" not in meta:
        raise KeyError(f"{ckpt}: no 'cfg' entry in safetensors metadata "
                       f"(keys={sorted(meta)}); was it saved by train.py save_ckpt?")
    return json.loads(meta["cfg"])
def load_delta_S(ckpt: Path, wrappers: dict, device) -> None:
    """Copy the trained delta_S tensors from `ckpt` into the attached adapter, in place.

    Only each wrapper's "delta_S" is written. delta_S_hack is never touched, so
    it keeps the value it had when the adapter was attached (intended to be 0:
    a vanilla/erase checkpoint has no quarantine).

    Args:
        ckpt: safetensors file whose tensor keys are the adapter module names.
        wrappers: module name -> info dict holding a "delta_S" tensor/parameter.
        device: device each checkpoint tensor is moved to before copying.

    Raises:
        ValueError: the checkpoint and adapter disagree on module names, or a
            checkpoint tensor's shape differs from the adapter's delta_S. The
            shape check matters because Tensor.copy_ broadcasts its source, so
            e.g. a [1] tensor would otherwise silently fill a whole [r] delta_S.
    """
    with safe_open(str(ckpt), framework="pt", device="cpu") as f:
        ckpt_keys, wrap_keys = set(f.keys()), set(wrappers)
        if ckpt_keys != wrap_keys:
            raise ValueError(
                "ckpt/adapter module mismatch: "
                f"missing={sorted(wrap_keys - ckpt_keys)[:3]} "
                f"extra={sorted(ckpt_keys - wrap_keys)[:3]}")
        for name, info in wrappers.items():
            src = f.get_tensor(name)
            dst = info["delta_S"].data
            if src.shape != dst.shape:
                raise ValueError(f"{name}: ckpt delta_S shape {tuple(src.shape)} "
                                 f"!= adapter delta_S shape {tuple(dst.shape)}")
            dst.copy_(src.to(device))
def erase_delta_S_inplace(wrappers: dict, v_hack: dict) -> dict:
    """Remove each v_hack module's hack-subspace component from its delta_S, in place.

    For every module listed in v_hack, with V its [k, r] basis (assumed to have
    orthonormal rows -- the projection is only exact in that case):

        delta_S <- delta_S - V^T (V delta_S)

    The basis is cast to delta_S's device and dtype before projecting. Modules
    v_hack does not list (noise-floor dropped) keep their delta_S untouched.

    Returns:
        {module name: clone of delta_S taken before erasure}, so restore_delta_S
        can undo the projection before the next arm.
    """
    backup: dict = {}
    for module_name, basis in v_hack.items():
        param = wrappers[module_name]["delta_S"].data
        backup[module_name] = param.clone()  # snapshot BEFORE mutating
        rows = basis.to(param.device, param.dtype)  # [k, r]
        coeffs = rows @ param  # coordinates of delta_S in the hack subspace
        param.sub_(rows.t() @ coeffs)
    return backup
def restore_delta_S(wrappers: dict, saved: dict) -> None:
    """Write a delta_S snapshot (from erase_delta_S_inplace) back into the adapter.

    Only modules present in `saved` are touched; each one's "delta_S" data is
    overwritten in place with its saved copy.
    """
    for module_name in saved:
        target = wrappers[module_name]["delta_S"].data
        target.copy_(saved[module_name])
@torch.no_grad()
def _completion_layer_means(model, tok, prompt: str, completion: str, n_prompt: int, device) -> torch.Tensor | None:
    """Per-layer mean hidden state over the completion tokens of prompt+completion.

    Returns a float32 tensor with one row per entry of the model's hidden_states
    ([n_layers+1, d_model]), or None when prompt+completion has no tokens beyond
    the prompt's token count -- averaging that empty slice would yield NaN.
    """
    ids = tok(prompt + completion, return_tensors="pt").input_ids.to(device)
    if ids.shape[1] <= n_prompt:
        return None
    hs = model(ids, output_hidden_states=True).hidden_states  # tuple [n_layers+1] of [1, L, d]
    return torch.stack([h[0, n_prompt:].float().mean(0) for h in hs])


@torch.no_grad()
def gather_act_dir(model, tok, pairs, device, n_layers: int, src_pref: int) -> tuple[torch.Tensor, int, list[float]]:
    """Residual-stream diff-of-means hack direction over the SAME pairs v_hack uses.

    For each pair, the completion-token hidden states are averaged per layer;
    d[l] = mean_hack[l] - mean_clean[l], averaged over the pairs actually used.
    One direction is ablated at every layer (Arditi 2024): the residual basis is
    shared across layers, so a direction sourced at the cleanest layer is valid
    to remove everywhere.

    Args:
        pairs: objects with .prompt, .hack and .clean strings.
        n_layers: decoder layer count; hidden_states has n_layers + 1 entries.
        src_pref: hidden_states index to source the direction from; any
            negative value picks the index with the largest ||d[l]||.

    Returns:
        (unit direction in the model's parameter dtype, source index,
         per-index separations ||d[l]|| as floats).

    Raises:
        ValueError: no pairs, no pair with a non-empty completion on both sides,
            or src_pref outside the hidden_states range.
    """
    import warnings

    if not pairs:
        raise ValueError("gather_act_dir: no pairs to build the act direction from")
    d_model = model.config.hidden_size
    acc = {label: torch.zeros(n_layers + 1, d_model, device=device, dtype=torch.float32)
           for label in ("hack", "clean")}
    n_used = n_skipped = 0
    for pair in pairs:
        n_prompt = tok(pair.prompt, return_tensors="pt").input_ids.shape[1]
        hack = _completion_layer_means(model, tok, pair.prompt, pair.hack, n_prompt, device)
        clean = None if hack is None else _completion_layer_means(
            model, tok, pair.prompt, pair.clean, n_prompt, device)
        if hack is None or clean is None:
            # Skip the whole pair so the hack and clean means stay paired.
            n_skipped += 1
            continue
        acc["hack"] += hack
        acc["clean"] += clean
        n_used += 1
    if n_skipped:
        warnings.warn(f"gather_act_dir: skipped {n_skipped}/{len(pairs)} pairs whose "
                      "completion adds no tokens beyond the prompt")
    if n_used == 0:
        raise ValueError("gather_act_dir: every pair had an empty completion")
    d = (acc["hack"] - acc["clean"]) / n_used  # [n_layers+1, d]
    sep = d.norm(dim=-1)  # [n_layers+1]
    src = int(sep.argmax().item()) if src_pref < 0 else src_pref
    if src >= sep.shape[0]:
        raise ValueError(f"gather_act_dir: src_pref={src_pref} out of range "
                         f"[0, {sep.shape[0] - 1}] (or negative for auto)")
    dir_unit = (d[src] / d[src].norm().clamp_min(1e-12)).to(next(model.parameters()).dtype)
    return dir_unit, src, sep.tolist()
def ablate_dir_hooks(model, dir_unit: torch.Tensor) -> list:
    """Project `dir_unit` out of every decoder layer's output via forward hooks.

    A hook on each module in model.model.layers rewrites the layer's hidden-state
    output as h - (h . d) d, with d = dir_unit (expected to be unit-norm). When a
    layer returns a tuple, only its first element is rewritten and the rest pass
    through unchanged. Only layer outputs are hooked; the residual entering the
    first layer is not projected.

    Returns:
        The hook handles; call .remove() on each to undo the ablation.
    """
    def build_hook(direction):
        def project_out(_module, _inputs, output):
            is_tuple = isinstance(output, tuple)
            hidden = output[0] if is_tuple else output
            coeff = (hidden @ direction).unsqueeze(-1)
            cleaned = hidden - coeff * direction
            if is_tuple:
                return (cleaned, *output[1:])
            return cleaned
        return project_out

    handles = []
    for block in model.model.layers:
        handles.append(block.register_forward_hook(build_hook(dir_unit)))
    return handles
def main(cfg: Config) -> int:
    """Benchmark post-hoc hack erasure on one finished checkpoint.

    Arms, all on the same checkpoint and the same fixed eval subset:
      baseline     -- the trained delta_S as-is;
      weight_erase -- delta_S projected orthogonal to v_hack, once;
      act_erase    -- (unless cfg.skip_act) a residual diff-of-means direction
                      projected out of every decoder layer's output.
    Prints a markdown table of hack/solve per arm and returns 0 (exit code).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    rc = load_ckpt_cfg(cfg.ckpt)
    model_name = rc["model"]
    pairset = Path(rc["vhack_pairs_path"])
    v_hack_path = cfg.v_hack_path or VHACK_DIR / f"v_hack_pairset_{pairset.stem}.safetensors"
    logger.info(f"ckpt={cfg.ckpt} model={model_name} pairset={pairset.name} v_hack={v_hack_path.name}")

    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    # FlashAttention-2 needs a CUDA device; transformers refuses it on CPU, which
    # would make the CPU fallback above unusable. Use PyTorch SDPA off-GPU.
    attn_impl = "flash_attention_2" if device.type == "cuda" else "sdpa"
    model = AutoModelForCausalLM.from_pretrained(
        model_name, dtype=torch.bfloat16, attn_implementation=attn_impl).to(device)
    model.config.use_cache = False
    model.eval()
    wrappers = wrap_model_with_antipasto(model, model_name, CACHE_ROOT, device)
    load_delta_S(cfg.ckpt, wrappers, device)

    # eval subset: the substrate partition (each problem graded by its own mode),
    # held fixed. n_eval_prompts x group rollouts at T=0.7 -- same protocol as the
    # live deploy-eval, just a wider prompt set.
    partition_path = Path(rc["teacher_pool_dir"]) / "partition.json"
    partition = {int(k): v for k, v in json.loads(partition_path.read_text(encoding="utf-8")).items()}
    problems = load_problems(rc["n_problems"], env_modes=[rc["env_mode"]], seed=rc["seed"], partition=partition)
    eval_idxs = list(range(min(cfg.n_eval_prompts, len(problems))))
    gen_cfg = GenerationConfig(
        max_new_tokens=rc["max_new"], do_sample=True, temperature=0.7, top_p=1.0,
        top_k=20, min_p=0.0, repetition_penalty=1.0,
        num_return_sequences=rc["group"], pad_token_id=tok.pad_token_id)
    logger.info(f"eval: {len(eval_idxs)} prompts x {rc['group']} = {len(eval_idxs)*rc['group']} rollouts, T=0.7")

    def run(tag: str) -> dict:
        ev = eval_hack_solve(model, tok, problems, eval_idxs, gen_cfg, device, rc["max_new"])
        logger.info(f"[{tag}] hack={ev['hack']:.3f} solve={ev['solve']:.3f} n={ev['n']}")
        return ev

    results = {}
    # 1. baseline: the deployed model as-is (trained delta_S, no erasure).
    results["baseline"] = run("baseline")

    # 2. weight-erase: delta_S projected orthogonal to v_hack, once.
    v_hack = {n: v.to(device) for n, v in load_v_hack(
        v_hack_path, model_name, wrappers,
        k_use=rc.get("v_hack_k"), drop_bottom_frac=rc.get("v_hack_drop_bottom_frac", 0.25)).items()}
    saved = erase_delta_S_inplace(wrappers, v_hack)
    try:
        results["weight_erase"] = run("weight_erase")
    finally:
        # Always restore the trained delta_S (mirrors the act arm's hook removal),
        # so later arms start from the baseline model.
        restore_delta_S(wrappers, saved)

    # 3. act-erase: residual diff-of-means direction ablated at every layer.
    if not cfg.skip_act:
        pairs = load_pairs_json(pairset)
        n_layers = len(model.model.layers)
        dir_unit, src, sep = gather_act_dir(model, tok, pairs, device, n_layers, cfg.act_src_layer)
        logger.info(f"act dir: sourced at layer {src}/{n_layers} (sep={sep[src]:.3f}, "
                    f"max/mean={sep[src]/(sum(sep)/len(sep)):.2f}x)")
        handles = ablate_dir_hooks(model, dir_unit)
        try:
            results["act_erase"] = run("act_erase")
        finally:
            for h in handles:
                h.remove()

    # BLUF table. SHOULD: weight/act hack < baseline hack at solve >= baseline-ish.
    # If hack drops only when solve also collapses -> erasure is too blunt, the
    # hack is not cleanly separable post-hoc -> motivates train-time routing.
    print("\nSHOULD: erase arms cut hack vs baseline WITHOUT tanking solve. "
          "ELSE post-hoc erasure can't isolate the hack -> train-time intervention earns its cost.\n")
    rows = [{"arm": k, "hack": f"{v['hack']:.3f}", "solve": f"{v['solve']:.3f}", "n": v["n"],
             **{f"hk_{m}": f"{b[0]}/{b[2]}" for m, b in sorted(v["by_mode"].items())}}
            for k, v in results.items()]
    print(tabulate(rows, headers="keys", tablefmt="pipe"))
    return 0
if __name__ == "__main__":
    # SystemExit(code) is exactly what sys.exit(code) raises.
    raise SystemExit(main(tyro.cli(Config)))