From 227c173f63d9dd3e212fee39318ee4c048ad980a Mon Sep 17 00:00:00 2001 From: wassname Date: Tue, 2 Jun 2026 02:20:51 +0000 Subject: [PATCH] feat: test-time (post-hoc) hack-erasure benchmark scripts/tt_erase_bench.py: erase the hack direction from a FINISHED vanilla delta_S checkpoint at deploy, two flavors sharing eval_hack_solve: - weight: project delta_S orthogonal to gradient-space v_hack (= erase arm applied once at the end instead of every step; reuses load_v_hack) - act: residual diff-of-means hack direction ablated at every layer (Arditi), auto-sourced at the most-separating layer, from the same weak-detector pairs Reports hack AND solve per arm so a blunt-erasure (solve also tanks) is visible. Baseline for whether train-time routing beats cheap post-hoc erasure. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com> --- scripts/tt_erase_bench.py | 212 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 212 insertions(+) create mode 100644 scripts/tt_erase_bench.py diff --git a/scripts/tt_erase_bench.py b/scripts/tt_erase_bench.py new file mode 100644 index 0000000..b0c5b86 --- /dev/null +++ b/scripts/tt_erase_bench.py @@ -0,0 +1,212 @@ +"""Test-time (post-hoc) hack-erasure benchmark. + +The colleague's question: instead of intervening during RL (route2/erase), just +train vanilla, then erase the hack direction from a FINISHED checkpoint at deploy. +Does post-hoc erasure match train-time routing? If yes, intervening during +training buys nothing; if no (the hack is baked across the weights, or erasure +tanks solve), that is the motivation for intervening in the gradient. + +Two erasure flavors on the SAME finished delta_S checkpoint, SAME eval harness: + + weight -- project the trained delta_S orthogonal to the gradient-space v_hack + (the exact basis our method uses), once. This IS the `erase` arm + applied at the end instead of every gradient step. Free: reuses + load_v_hack, no new extraction. + act -- residual-stream diff-of-means hack direction (Arditi-style ablation), + removed from every layer at eval. The classic rep-eng baseline. New + direction (NOT delta_S), gathered from the same weak-detector pairs. + +We report hack AND solve for each arm, so a drop in hack that also tanks solve +reads as "erasure is too blunt", not a win. No training happens here. + +No-cheat: the activation/weight directions come from the same persona pairs that +v_hack does (the allowed weak detector), never from gt_pass. + +Run: + uv run python scripts/tt_erase_bench.py --ckpt out/runs//train.safetensors +""" +from __future__ import annotations + +import json +from dataclasses import dataclass +from pathlib import Path + +import torch +import tyro +from loguru import logger +from safetensors import safe_open +from tabulate import tabulate +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + +from projected_grpo.antipasto import wrap_model_with_antipasto +from projected_grpo.extract_vhack_grad import load_v_hack +from projected_grpo.pairs_from_pool import load_pairs_json +from projected_grpo.problems import load_problems +from projected_grpo.rewards import EnvMode +from projected_grpo.train import CACHE_ROOT, VHACK_DIR, eval_hack_solve + + +@dataclass +class Config: + ckpt: Path # finished delta_S (out/runs//train.safetensors) + v_hack_path: Path | None = None # default: derive from the ckpt's pairset (= train.py's path) + n_eval_prompts: int = 32 # > the 8 used live, for a tighter benchmark estimate + act_src_layer: int = -1 # residual layer to source the act direction; -1 = auto (max hack-clean separation) + skip_act: bool = False # weight-erase only (skip activation gathering) + + +def load_ckpt_cfg(ckpt: Path) -> dict: + """The run config travels in the checkpoint's safetensors metadata (train.py + save_ckpt). We read model/pairset/eval knobs straight off it so the benchmark + auto-matches the run that produced the checkpoint.""" + with safe_open(str(ckpt), framework="pt", device="cpu") as f: + return json.loads((f.metadata() or {})["cfg"]) + + +def load_delta_S(ckpt: Path, wrappers: dict, device) -> None: + """Copy the trained delta_S tensors into the attached adapter (in place). + delta_S_hack stays 0 -- a vanilla/erase checkpoint has no quarantine.""" + with safe_open(str(ckpt), framework="pt", device="cpu") as f: + ckpt_keys, wrap_keys = set(f.keys()), set(wrappers) + if ckpt_keys != wrap_keys: + raise ValueError(f"ckpt/adapter module mismatch: " + f"missing={sorted(wrap_keys - ckpt_keys)[:3]} extra={sorted(ckpt_keys - wrap_keys)[:3]}") + for name, info in wrappers.items(): + info["delta_S"].data.copy_(f.get_tensor(name).to(device)) + + +def erase_delta_S_inplace(wrappers: dict, v_hack: dict) -> dict: + """Project each module's trained delta_S orthogonal to its v_hack rows: + delta_S' = delta_S - Vh^T (Vh delta_S), Vh = [k, r] orthonormal. + Removes the hack-subspace component the run encoded in delta_S. Returns the + pre-erase values so the caller can restore for the next arm. Modules absent + from v_hack (noise-floor dropped) keep their delta_S untouched.""" + saved = {} + for name, V in v_hack.items(): + ds = wrappers[name]["delta_S"].data + saved[name] = ds.clone() + Vf = V.to(ds.device, ds.dtype) # [k, r] + ds.sub_(Vf.t() @ (Vf @ ds)) # δS - Vᵀ(V δS) + return saved + + +def restore_delta_S(wrappers: dict, saved: dict) -> None: + for name, ds0 in saved.items(): + wrappers[name]["delta_S"].data.copy_(ds0) + + +@torch.no_grad() +def gather_act_dir(model, tok, pairs, device, n_layers: int, src_pref: int) -> tuple[torch.Tensor, int, list[float]]: + """Residual-stream diff-of-means hack direction, per layer, over the SAME + pairs v_hack uses. For each pair we mean the completion-token hidden states + per layer; d[l] = mean_hack[l] - mean_clean[l]. Returns the unit direction at + the most-separating layer (||d[l]|| argmax) plus per-layer separations. + + One direction ablated at every layer (Arditi 2024): the residual basis is + shared across layers, so a direction sourced at the cleanest layer is valid + to remove everywhere.""" + d_model = model.config.hidden_size + acc = {"hack": torch.zeros(n_layers + 1, d_model, device=device, dtype=torch.float32), + "clean": torch.zeros(n_layers + 1, d_model, device=device, dtype=torch.float32)} + for pair in pairs: + n_prompt = tok(pair.prompt, return_tensors="pt").input_ids.shape[1] + for label, completion in (("hack", pair.hack), ("clean", pair.clean)): + ids = tok(pair.prompt + completion, return_tensors="pt").input_ids.to(device) + hs = model(ids, output_hidden_states=True).hidden_states # tuple [n_layers+1] of [1, L, d] + comp = slice(n_prompt, ids.shape[1]) # completion-token positions + for l, h in enumerate(hs): + acc[label][l] += h[0, comp].float().mean(0) + d = (acc["hack"] - acc["clean"]) / len(pairs) # [n_layers+1, d] + sep = d.norm(dim=-1) # [n_layers+1] + src = sep.argmax().item() if src_pref < 0 else src_pref + dir_unit = (d[src] / d[src].norm().clamp_min(1e-12)).to(next(model.parameters()).dtype) + return dir_unit, src, sep.tolist() + + +def ablate_dir_hooks(model, dir_unit: torch.Tensor) -> list: + """Register a forward hook on every decoder layer that projects dir_unit out + of the layer's residual output: h' = h - (h . d_hat) d_hat. Returns handles.""" + def make_hook(d): + def hook(_module, _inp, out): + h = out[0] if isinstance(out, tuple) else out + h2 = h - (h @ d).unsqueeze(-1) * d + return (h2, *out[1:]) if isinstance(out, tuple) else h2 + return hook + return [layer.register_forward_hook(make_hook(dir_unit)) for layer in model.model.layers] + + +def main(cfg: Config) -> int: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + rc = load_ckpt_cfg(cfg.ckpt) + model_name = rc["model"] + pairset = Path(rc["vhack_pairs_path"]) + v_hack_path = cfg.v_hack_path or VHACK_DIR / f"v_hack_pairset_{pairset.stem}.safetensors" + logger.info(f"ckpt={cfg.ckpt} model={model_name} pairset={pairset.name} v_hack={v_hack_path.name}") + + tok = AutoTokenizer.from_pretrained(model_name) + if tok.pad_token_id is None: tok.pad_token = tok.eos_token + model = AutoModelForCausalLM.from_pretrained( + model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2").to(device) + model.config.use_cache = False + model.eval() + wrappers = wrap_model_with_antipasto(model, model_name, CACHE_ROOT, device) + load_delta_S(cfg.ckpt, wrappers, device) + + # eval subset: the substrate partition (each problem graded by its own mode), + # held fixed. n_eval_prompts x group rollouts at T=0.7 -- same protocol as the + # live deploy-eval, just a wider prompt set. + partition = {int(k): v for k, v in json.loads( + (Path(rc["teacher_pool_dir"]) / "partition.json").read_text()).items()} + problems = load_problems(rc["n_problems"], env_modes=[rc["env_mode"]], seed=rc["seed"], partition=partition) + eval_idxs = list(range(min(cfg.n_eval_prompts, len(problems)))) + gen_cfg = GenerationConfig( + max_new_tokens=rc["max_new"], do_sample=True, temperature=0.7, top_p=1.0, + top_k=20, min_p=0.0, repetition_penalty=1.0, + num_return_sequences=rc["group"], pad_token_id=tok.pad_token_id) + logger.info(f"eval: {len(eval_idxs)} prompts x {rc['group']} = {len(eval_idxs)*rc['group']} rollouts, T=0.7") + + def run(tag: str) -> dict: + ev = eval_hack_solve(model, tok, problems, eval_idxs, gen_cfg, device, rc["max_new"]) + logger.info(f"[{tag}] hack={ev['hack']:.3f} solve={ev['solve']:.3f} n={ev['n']}") + return ev + + results = {} + # 1. baseline: the deployed model as-is (trained delta_S, no erasure). + results["baseline"] = run("baseline") + + # 2. weight-erase: delta_S projected orthogonal to v_hack, once. + v_hack = {n: v.to(device) for n, v in load_v_hack( + v_hack_path, model_name, wrappers, + k_use=rc.get("v_hack_k"), drop_bottom_frac=rc.get("v_hack_drop_bottom_frac", 0.25)).items()} + saved = erase_delta_S_inplace(wrappers, v_hack) + results["weight_erase"] = run("weight_erase") + restore_delta_S(wrappers, saved) + + # 3. act-erase: residual diff-of-means direction ablated at every layer. + if not cfg.skip_act: + pairs = load_pairs_json(pairset) + n_layers = len(model.model.layers) + dir_unit, src, sep = gather_act_dir(model, tok, pairs, device, n_layers, cfg.act_src_layer) + logger.info(f"act dir: sourced at layer {src}/{n_layers} (sep={sep[src]:.3f}, " + f"max/mean={sep[src]/(sum(sep)/len(sep)):.2f}x)") + handles = ablate_dir_hooks(model, dir_unit) + try: + results["act_erase"] = run("act_erase") + finally: + for h in handles: h.remove() + + # BLUF table. SHOULD: weight/act hack < baseline hack at solve >= baseline-ish. + # If hack drops only when solve also collapses -> erasure is too blunt, the + # hack is not cleanly separable post-hoc -> motivates train-time routing. + print("\nSHOULD: erase arms cut hack vs baseline WITHOUT tanking solve. " + "ELSE post-hoc erasure can't isolate the hack -> train-time intervention earns its cost.\n") + rows = [{"arm": k, "hack": f"{v['hack']:.3f}", "solve": f"{v['solve']:.3f}", "n": v["n"], + **{f"hk_{m}": f"{b[0]}/{b[2]}" for m, b in sorted(v["by_mode"].items())}} + for k, v in results.items()] + print(tabulate(rows, headers="keys", tablefmt="pipe")) + return 0 + + +if __name__ == "__main__": + import sys + sys.exit(main(tyro.cli(Config)))