"""Test-time (post-hoc) hack-erasure benchmark. The colleague's question: instead of intervening during RL (route2/erase), just train vanilla, then erase the hack direction from a FINISHED checkpoint at deploy. Does post-hoc erasure match train-time routing? If yes, intervening during training buys nothing; if no (the hack is baked across the weights, or erasure tanks solve), that is the motivation for intervening in the gradient. Two erasure flavors on the SAME finished delta_S checkpoint, SAME eval harness: weight -- project the trained delta_S orthogonal to the gradient-space v_hack (the exact basis our method uses), once. This IS the `erase` arm applied at the end instead of every gradient step. Free: reuses load_v_hack, no new extraction. act -- residual-stream diff-of-means hack direction (Arditi-style ablation), removed from every layer at eval. The classic rep-eng baseline. New direction (NOT delta_S), gathered from the same weak-detector pairs. We report hack AND solve for each arm, so a drop in hack that also tanks solve reads as "erasure is too blunt", not a win. No training happens here. No-cheat: the activation/weight directions come from the same persona pairs that v_hack does (the allowed weak detector), never from gt_pass. Run: uv run python scripts/tt_erase_bench.py --ckpt out/runs//train.safetensors """ from __future__ import annotations import json from dataclasses import dataclass from pathlib import Path import torch import tyro from loguru import logger from safetensors import safe_open from tabulate import tabulate from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig from vgrout.antipasto import wrap_model_with_antipasto from vgrout.extract_vhack_grad import load_v_hack from vgrout.pairs_from_pool import load_pairs_json from vgrout.problems import load_problems from vgrout.rewards import EnvMode from vgrout.train import CACHE_ROOT, VHACK_DIR, eval_hack_solve @dataclass class Config: ckpt: Path # finished delta_S (out/runs//train.safetensors) v_hack_path: Path | None = None # default: derive from the ckpt's pairset (= train.py's path) n_eval_prompts: int = 32 # > the 8 used live, for a tighter benchmark estimate act_src_layer: int = -1 # residual layer to source the act direction; -1 = auto (max hack-clean separation) skip_act: bool = False # weight-erase only (skip activation gathering) def load_ckpt_cfg(ckpt: Path) -> dict: """The run config travels in the checkpoint's safetensors metadata (train.py save_ckpt). We read model/pairset/eval knobs straight off it so the benchmark auto-matches the run that produced the checkpoint.""" with safe_open(str(ckpt), framework="pt", device="cpu") as f: return json.loads((f.metadata() or {})["cfg"]) def load_delta_S(ckpt: Path, wrappers: dict, device) -> None: """Copy the trained delta_S tensors into the attached adapter (in place). delta_S_hack stays 0 -- a vanilla/erase checkpoint has no quarantine.""" with safe_open(str(ckpt), framework="pt", device="cpu") as f: ckpt_keys, wrap_keys = set(f.keys()), set(wrappers) if ckpt_keys != wrap_keys: raise ValueError(f"ckpt/adapter module mismatch: " f"missing={sorted(wrap_keys - ckpt_keys)[:3]} extra={sorted(ckpt_keys - wrap_keys)[:3]}") for name, info in wrappers.items(): info["delta_S"].data.copy_(f.get_tensor(name).to(device)) def erase_delta_S_inplace(wrappers: dict, v_hack: dict) -> dict: """Project each module's trained delta_S orthogonal to its v_hack rows: delta_S' = delta_S - Vh^T (Vh delta_S), Vh = [k, r] orthonormal. Removes the hack-subspace component the run encoded in delta_S. Returns the pre-erase values so the caller can restore for the next arm. Modules absent from v_hack (noise-floor dropped) keep their delta_S untouched.""" saved = {} for name, V in v_hack.items(): ds = wrappers[name]["delta_S"].data saved[name] = ds.clone() Vf = V.to(ds.device, ds.dtype) # [k, r] ds.sub_(Vf.t() @ (Vf @ ds)) # δS - Vᵀ(V δS) return saved def restore_delta_S(wrappers: dict, saved: dict) -> None: for name, ds0 in saved.items(): wrappers[name]["delta_S"].data.copy_(ds0) @torch.no_grad() def gather_act_dir(model, tok, pairs, device, n_layers: int, src_pref: int) -> tuple[torch.Tensor, int, list[float]]: """Residual-stream diff-of-means hack direction, per layer, over the SAME pairs v_hack uses. For each pair we mean the completion-token hidden states per layer; d[l] = mean_hack[l] - mean_clean[l]. Returns the unit direction at the most-separating layer (||d[l]|| argmax) plus per-layer separations. One direction ablated at every layer (Arditi 2024): the residual basis is shared across layers, so a direction sourced at the cleanest layer is valid to remove everywhere.""" d_model = model.config.hidden_size acc = {"hack": torch.zeros(n_layers + 1, d_model, device=device, dtype=torch.float32), "clean": torch.zeros(n_layers + 1, d_model, device=device, dtype=torch.float32)} for pair in pairs: n_prompt = tok(pair.prompt, return_tensors="pt").input_ids.shape[1] for label, completion in (("hack", pair.hack), ("clean", pair.clean)): ids = tok(pair.prompt + completion, return_tensors="pt").input_ids.to(device) hs = model(ids, output_hidden_states=True).hidden_states # tuple [n_layers+1] of [1, L, d] comp = slice(n_prompt, ids.shape[1]) # completion-token positions for l, h in enumerate(hs): acc[label][l] += h[0, comp].float().mean(0) d = (acc["hack"] - acc["clean"]) / len(pairs) # [n_layers+1, d] sep = d.norm(dim=-1) # [n_layers+1] src = sep.argmax().item() if src_pref < 0 else src_pref dir_unit = (d[src] / d[src].norm().clamp_min(1e-12)).to(next(model.parameters()).dtype) return dir_unit, src, sep.tolist() def ablate_dir_hooks(model, dir_unit: torch.Tensor) -> list: """Register a forward hook on every decoder layer that projects dir_unit out of the layer's residual output: h' = h - (h . d_hat) d_hat. Returns handles.""" def make_hook(d): def hook(_module, _inp, out): h = out[0] if isinstance(out, tuple) else out h2 = h - (h @ d).unsqueeze(-1) * d return (h2, *out[1:]) if isinstance(out, tuple) else h2 return hook return [layer.register_forward_hook(make_hook(dir_unit)) for layer in model.model.layers] def main(cfg: Config) -> int: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") rc = load_ckpt_cfg(cfg.ckpt) model_name = rc["model"] pairset = Path(rc["vhack_pairs_path"]) v_hack_path = cfg.v_hack_path or VHACK_DIR / f"v_hack_pairset_{pairset.stem}.safetensors" logger.info(f"ckpt={cfg.ckpt} model={model_name} pairset={pairset.name} v_hack={v_hack_path.name}") tok = AutoTokenizer.from_pretrained(model_name) if tok.pad_token_id is None: tok.pad_token = tok.eos_token model = AutoModelForCausalLM.from_pretrained( model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2").to(device) model.config.use_cache = False model.eval() wrappers = wrap_model_with_antipasto(model, model_name, CACHE_ROOT, device) load_delta_S(cfg.ckpt, wrappers, device) # eval subset: the substrate partition (each problem graded by its own mode), # held fixed. n_eval_prompts x group rollouts at T=0.7 -- same protocol as the # live deploy-eval, just a wider prompt set. partition = {int(k): v for k, v in json.loads( (Path(rc["teacher_pool_dir"]) / "partition.json").read_text()).items()} problems = load_problems(rc["n_problems"], env_modes=[rc["env_mode"]], seed=rc["seed"], partition=partition) eval_idxs = list(range(min(cfg.n_eval_prompts, len(problems)))) gen_cfg = GenerationConfig( max_new_tokens=rc["max_new"], do_sample=True, temperature=0.7, top_p=1.0, top_k=20, min_p=0.0, repetition_penalty=1.0, num_return_sequences=rc["group"], pad_token_id=tok.pad_token_id) logger.info(f"eval: {len(eval_idxs)} prompts x {rc['group']} = {len(eval_idxs)*rc['group']} rollouts, T=0.7") def run(tag: str) -> dict: ev = eval_hack_solve(model, tok, problems, eval_idxs, gen_cfg, device, rc["max_new"]) logger.info(f"[{tag}] hack={ev['hack']:.3f} solve={ev['solve']:.3f} n={ev['n']}") return ev results = {} # 1. baseline: the deployed model as-is (trained delta_S, no erasure). results["baseline"] = run("baseline") # 2. weight-erase: delta_S projected orthogonal to v_hack, once. v_hack = {n: v.to(device) for n, v in load_v_hack( v_hack_path, model_name, wrappers, k_use=rc.get("v_hack_k"), drop_bottom_frac=rc.get("v_hack_drop_bottom_frac", 0.25)).items()} saved = erase_delta_S_inplace(wrappers, v_hack) results["weight_erase"] = run("weight_erase") restore_delta_S(wrappers, saved) # 3. act-erase: residual diff-of-means direction ablated at every layer. if not cfg.skip_act: pairs = load_pairs_json(pairset) n_layers = len(model.model.layers) dir_unit, src, sep = gather_act_dir(model, tok, pairs, device, n_layers, cfg.act_src_layer) logger.info(f"act dir: sourced at layer {src}/{n_layers} (sep={sep[src]:.3f}, " f"max/mean={sep[src]/(sum(sep)/len(sep)):.2f}x)") handles = ablate_dir_hooks(model, dir_unit) try: results["act_erase"] = run("act_erase") finally: for h in handles: h.remove() # BLUF table. SHOULD: weight/act hack < baseline hack at solve >= baseline-ish. # If hack drops only when solve also collapses -> erasure is too blunt, the # hack is not cleanly separable post-hoc -> motivates train-time routing. print("\nSHOULD: erase arms cut hack vs baseline WITHOUT tanking solve. " "ELSE post-hoc erasure can't isolate the hack -> train-time intervention earns its cost.\n") rows = [{"arm": k, "hack": f"{v['hack']:.3f}", "solve": f"{v['solve']:.3f}", "n": v["n"], **{f"hk_{m}": f"{b[0]}/{b[2]}" for m, b in sorted(v["by_mode"].items())}} for k, v in results.items()] print(tabulate(rows, headers="keys", tablefmt="pipe")) return 0 if __name__ == "__main__": import sys sys.exit(main(tyro.cli(Config)))