mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 17:30:41 +08:00
feat: test-time (post-hoc) hack-erasure benchmark
scripts/tt_erase_bench.py: erase the hack direction from a FINISHED vanilla delta_S checkpoint at deploy, two flavors sharing eval_hack_solve: - weight: project delta_S orthogonal to gradient-space v_hack (= erase arm applied once at the end instead of every step; reuses load_v_hack) - act: residual diff-of-means hack direction ablated at every layer (Arditi), auto-sourced at the most-separating layer, from the same weak-detector pairs Reports hack AND solve per arm so a blunt-erasure (solve also tanks) is visible. Baseline for whether train-time routing beats cheap post-hoc erasure. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,212 @@
|
||||
"""Test-time (post-hoc) hack-erasure benchmark.
|
||||
|
||||
The colleague's question: instead of intervening during RL (route2/erase), just
|
||||
train vanilla, then erase the hack direction from a FINISHED checkpoint at deploy.
|
||||
Does post-hoc erasure match train-time routing? If yes, intervening during
|
||||
training buys nothing; if no (the hack is baked across the weights, or erasure
|
||||
tanks solve), that is the motivation for intervening in the gradient.
|
||||
|
||||
Two erasure flavors on the SAME finished delta_S checkpoint, SAME eval harness:
|
||||
|
||||
weight -- project the trained delta_S orthogonal to the gradient-space v_hack
|
||||
(the exact basis our method uses), once. This IS the `erase` arm
|
||||
applied at the end instead of every gradient step. Free: reuses
|
||||
load_v_hack, no new extraction.
|
||||
act -- residual-stream diff-of-means hack direction (Arditi-style ablation),
|
||||
removed from every layer at eval. The classic rep-eng baseline. New
|
||||
direction (NOT delta_S), gathered from the same weak-detector pairs.
|
||||
|
||||
We report hack AND solve for each arm, so a drop in hack that also tanks solve
|
||||
reads as "erasure is too blunt", not a win. No training happens here.
|
||||
|
||||
No-cheat: the activation/weight directions come from the same persona pairs that
|
||||
v_hack does (the allowed weak detector), never from gt_pass.
|
||||
|
||||
Run:
|
||||
uv run python scripts/tt_erase_bench.py --ckpt out/runs/<run>/train.safetensors
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import tyro
|
||||
from loguru import logger
|
||||
from safetensors import safe_open
|
||||
from tabulate import tabulate
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
||||
|
||||
from projected_grpo.antipasto import wrap_model_with_antipasto
|
||||
from projected_grpo.extract_vhack_grad import load_v_hack
|
||||
from projected_grpo.pairs_from_pool import load_pairs_json
|
||||
from projected_grpo.problems import load_problems
|
||||
from projected_grpo.rewards import EnvMode
|
||||
from projected_grpo.train import CACHE_ROOT, VHACK_DIR, eval_hack_solve
|
||||
|
||||
|
||||
@dataclass
class Config:
    """Command-line options for the post-hoc hack-erasure benchmark (parsed by ``tyro.cli``)."""

    ckpt: Path  # finished delta_S (out/runs/<run>/train.safetensors)
    v_hack_path: Path | None = None  # default: derive from the ckpt's pairset (= train.py's path)
    n_eval_prompts: int = 32  # > the 8 used live, for a tighter benchmark estimate
    act_src_layer: int = -1  # residual layer to source the act direction; -1 = auto (max hack-clean separation)
    skip_act: bool = False  # weight-erase only (skip activation gathering)

    def __post_init__(self) -> None:
        # main() evaluates the first `n_eval_prompts` problems; a value <= 0
        # yields an empty eval set, so reject it here -- before the model is
        # loaded -- instead of failing (or reporting nonsense) much later.
        if self.n_eval_prompts < 1:
            raise ValueError(f"n_eval_prompts must be >= 1, got {self.n_eval_prompts}")
|
||||
|
||||
|
||||
def load_ckpt_cfg(ckpt: Path) -> dict:
    """Return the run config stored in the checkpoint's safetensors metadata.

    The config is read from the ``"cfg"`` metadata entry as a JSON string
    (written, per the original note, by train.py's save_ckpt). The benchmark
    reads model/pairset/eval knobs from it so it auto-matches the run that
    produced the checkpoint.

    Args:
        ckpt: path to the ``.safetensors`` checkpoint.

    Returns:
        The decoded config dict.

    Raises:
        KeyError: the file has no metadata, or its metadata has no ``"cfg"``
            entry (e.g. a checkpoint not written by the training script).
    """
    with safe_open(str(ckpt), framework="pt", device="cpu") as f:
        meta = f.metadata() or {}  # metadata() is None when the file carries none
    if "cfg" not in meta:
        raise KeyError(
            f"{ckpt}: no 'cfg' entry in safetensors metadata "
            f"(found keys: {sorted(meta)}); cannot recover the run config")
    return json.loads(meta["cfg"])
|
||||
|
||||
|
||||
def load_delta_S(ckpt: Path, wrappers: dict, device) -> None:
    """Load the checkpoint's tensors into each wrapper's ``delta_S``, in place.

    The checkpoint's tensor names must equal the wrapper names exactly;
    otherwise nothing is copied and a ValueError lists up to three names
    missing from the checkpoint and up to three unexpected ones.

    Only ``delta_S`` is written. Nothing else in the adapter is touched, so
    ``delta_S_hack`` keeps its existing value (expected to be 0 -- a
    vanilla/erase checkpoint has no quarantine).
    """
    with safe_open(str(ckpt), framework="pt", device="cpu") as handle:
        in_ckpt = set(handle.keys())
        in_adapter = set(wrappers)
        if in_ckpt != in_adapter:
            missing = sorted(in_adapter - in_ckpt)[:3]
            extra = sorted(in_ckpt - in_adapter)[:3]
            raise ValueError(f"ckpt/adapter module mismatch: missing={missing} extra={extra}")
        for module_name in wrappers:
            trained = handle.get_tensor(module_name).to(device)
            wrappers[module_name]["delta_S"].data.copy_(trained)
|
||||
|
||||
|
||||
def erase_delta_S_inplace(wrappers: dict, v_hack: dict) -> dict:
    """Remove the v_hack-subspace component from each listed module's delta_S.

    For every module named in ``v_hack`` the update, applied in place, is

        delta_S <- delta_S - Vh^T (Vh delta_S)

    which is an orthogonal projection provided the rows of Vh ([k, r]) are
    orthonormal (the caller's responsibility; not checked here).

    Returns a dict of pre-erase clones keyed by module name, suitable for
    ``restore_delta_S``. Modules with no entry in ``v_hack`` (noise-floor
    dropped) are neither modified nor saved.
    """
    originals = {}
    for module_name in v_hack:
        delta = wrappers[module_name]["delta_S"].data
        originals[module_name] = delta.clone()
        basis = v_hack[module_name].to(delta.device, delta.dtype)  # [k, r]
        coeffs = basis @ delta  # coordinates of delta_S in the hack basis
        delta.sub_(basis.t() @ coeffs)
    return originals
|
||||
|
||||
|
||||
def restore_delta_S(wrappers: dict, saved: dict) -> None:
    """Copy each saved tensor back into the matching wrapper's ``delta_S`` (in place).

    ``saved`` is the mapping returned by ``erase_delta_S_inplace``; wrappers
    without an entry in ``saved`` are left alone.
    """
    for module_name in saved:
        target = wrappers[module_name]["delta_S"].data
        target.copy_(saved[module_name])
|
||||
|
||||
|
||||
@torch.no_grad()
def gather_act_dir(model, tok, pairs, device, n_layers: int, src_pref: int) -> tuple[torch.Tensor, int, list[float]]:
    """Residual-stream diff-of-means hack direction over the SAME pairs v_hack uses.

    For each pair and each of its two completions (hack / clean) the model is
    run on ``prompt + completion`` and the hidden states at the completion
    token positions are averaged, per hidden-state index. With
    ``d[l] = mean_hack[l] - mean_clean[l]`` (averaged over pairs), the
    returned direction is ``d[src] / ||d[src]||``.

    One direction is ablated at every layer (Arditi 2024): the residual basis
    is shared across layers, so a direction sourced at the cleanest layer is
    valid to remove everywhere.

    Args:
        model: causal LM called as ``model(ids, output_hidden_states=True)``;
            must expose ``config.hidden_size`` and return ``n_layers + 1``
            hidden states.
        tok: tokenizer called as ``tok(text, return_tensors="pt")``.
        pairs: sized iterable of objects with ``prompt``, ``hack``, ``clean``
            string attributes.
        device: device the token ids and accumulators live on.
        n_layers: number of decoder layers.
        src_pref: hidden-state index to source the direction from; any
            negative value selects the index with the largest ``||d[l]||``.

    Returns:
        ``(dir_unit, src, sep)``: the unit direction cast to the dtype of the
        model's first parameter, the hidden-state index it was sourced from,
        and the per-index separations ``||d[l]||`` as a list of floats.
        NOTE(review): ``src`` indexes ``hidden_states``, where index 0 is
        presumably the embedding output (HF convention) -- confirm before
        reading it as a decoder-layer number.

    Raises:
        ValueError: ``pairs`` is empty, or a completion contributes no tokens
            after the prompt (its mean would be NaN and poison the result).
        IndexError: ``src_pref`` exceeds ``n_layers``.
    """
    # Validate up front: each pair costs two forward passes, so a bad argument
    # should fail before any of that work is done.
    if not pairs:
        raise ValueError("gather_act_dir: no pairs given, cannot form a diff-of-means direction")
    if src_pref > n_layers:
        raise IndexError(
            f"act source layer {src_pref} out of range: valid hidden-state indices are 0..{n_layers}")

    d_model = model.config.hidden_size
    acc = {label: torch.zeros(n_layers + 1, d_model, device=device, dtype=torch.float32)
           for label in ("hack", "clean")}
    for pair_idx, pair in enumerate(pairs):
        # The prompt's own token count marks where the completion starts.
        # NOTE(review): assumes tokenizing prompt+completion keeps the prompt's
        # tokens as a prefix (no merge across the boundary) -- confirm for the
        # tokenizer in use.
        n_prompt = tok(pair.prompt, return_tensors="pt").input_ids.shape[1]
        for label, completion in (("hack", pair.hack), ("clean", pair.clean)):
            ids = tok(pair.prompt + completion, return_tensors="pt").input_ids.to(device)
            if ids.shape[1] <= n_prompt:
                # Mean over an empty span is NaN; one such pair would turn the
                # whole direction (and every ablated activation) into NaN.
                raise ValueError(
                    f"pair {pair_idx} ({label}): completion adds no tokens after the prompt "
                    f"({ids.shape[1]} total vs {n_prompt} prompt tokens)")
            hs = model(ids, output_hidden_states=True).hidden_states  # tuple [n_layers+1] of [1, L, d]
            comp = slice(n_prompt, ids.shape[1])  # completion-token positions
            for layer_idx, h in enumerate(hs):
                acc[label][layer_idx] += h[0, comp].float().mean(0)

    d = (acc["hack"] - acc["clean"]) / len(pairs)  # [n_layers+1, d]
    sep = d.norm(dim=-1)  # [n_layers+1]
    src = sep.argmax().item() if src_pref < 0 else src_pref
    # clamp_min guards the degenerate case of identical hack/clean means.
    dir_unit = (d[src] / d[src].norm().clamp_min(1e-12)).to(next(model.parameters()).dtype)
    return dir_unit, src, sep.tolist()
|
||||
|
||||
|
||||
def ablate_dir_hooks(model, dir_unit: torch.Tensor) -> list:
    """Register a forward hook on every decoder layer that projects ``dir_unit``
    out of the layer's residual output: h' = h - (h . d_hat) d_hat.

    The hook handles both output conventions: a bare tensor, or a tuple whose
    first element is the hidden state (remaining elements are passed through
    unchanged).

    Args:
        model: object exposing the decoder layers as ``model.model.layers``.
        dir_unit: direction of shape ``[d_model]``; expected to be unit-norm
            (it is not re-normalised here).

    Returns:
        The list of hook handles; call ``.remove()`` on each to undo the ablation.
    """
    def hook(_module, _inp, out):
        is_tuple = isinstance(out, tuple)
        h = out[0] if is_tuple else out
        # Align with the activation: a direction held in another dtype/device
        # (e.g. float32 vs a bf16 layer output) would make the matmul raise.
        d = dir_unit.to(device=h.device, dtype=h.dtype)
        h2 = h - (h @ d).unsqueeze(-1) * d
        return (h2, *out[1:]) if is_tuple else h2

    # The same direction is removed at every layer, so one closure is shared.
    return [layer.register_forward_hook(hook) for layer in model.model.layers]
|
||||
|
||||
|
||||
def main(cfg: Config) -> int:
    """Run every erasure arm on one finished checkpoint and print a comparison table.

    Steps:
      1. Read the run config from the checkpoint metadata, load the tokenizer
         and model it names, attach the adapter and load the trained delta_S.
      2. Build a fixed eval set (first ``cfg.n_eval_prompts`` problems) and a
         sampling GenerationConfig (T=0.7, ``rc["group"]`` rollouts per prompt).
      3. Evaluate three arms with the same ``eval_hack_solve`` harness:
         ``baseline`` (no erasure), ``weight_erase`` (delta_S projected
         orthogonal to v_hack, then restored), and -- unless ``cfg.skip_act``
         -- ``act_erase`` (diff-of-means direction ablated by forward hooks,
         then removed).
      4. Print hack/solve/n and per-mode hack counts for each arm.

    Both erasure arms undo their modification in a ``finally`` block, so the
    model is back in its baseline state even if an evaluation raises.

    Args:
        cfg: parsed command-line options.

    Returns:
        0 on completion (used as the process exit code).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    rc = load_ckpt_cfg(cfg.ckpt)
    model_name = rc["model"]
    pairset = Path(rc["vhack_pairs_path"])
    v_hack_path = cfg.v_hack_path or VHACK_DIR / f"v_hack_pairset_{pairset.stem}.safetensors"
    logger.info(f"ckpt={cfg.ckpt} model={model_name} pairset={pairset.name} v_hack={v_hack_path.name}")

    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2").to(device)
    model.config.use_cache = False
    model.eval()
    wrappers = wrap_model_with_antipasto(model, model_name, CACHE_ROOT, device)
    load_delta_S(cfg.ckpt, wrappers, device)

    # eval subset: the substrate partition (each problem graded by its own mode),
    # held fixed. n_eval_prompts x group rollouts at T=0.7 -- same protocol as the
    # live deploy-eval, just a wider prompt set.
    partition = {int(k): v for k, v in json.loads(
        (Path(rc["teacher_pool_dir"]) / "partition.json").read_text()).items()}
    problems = load_problems(rc["n_problems"], env_modes=[rc["env_mode"]], seed=rc["seed"], partition=partition)
    eval_idxs = list(range(min(cfg.n_eval_prompts, len(problems))))
    gen_cfg = GenerationConfig(
        max_new_tokens=rc["max_new"], do_sample=True, temperature=0.7, top_p=1.0,
        top_k=20, min_p=0.0, repetition_penalty=1.0,
        num_return_sequences=rc["group"], pad_token_id=tok.pad_token_id)
    logger.info(f"eval: {len(eval_idxs)} prompts x {rc['group']} = {len(eval_idxs)*rc['group']} rollouts, T=0.7")

    def run(tag: str) -> dict:
        # One evaluation pass of the model in its current (possibly erased) state.
        ev = eval_hack_solve(model, tok, problems, eval_idxs, gen_cfg, device, rc["max_new"])
        logger.info(f"[{tag}] hack={ev['hack']:.3f} solve={ev['solve']:.3f} n={ev['n']}")
        return ev

    results = {}
    # 1. baseline: the deployed model as-is (trained delta_S, no erasure).
    results["baseline"] = run("baseline")

    # 2. weight-erase: delta_S projected orthogonal to v_hack, once.
    v_hack = {n: v.to(device) for n, v in load_v_hack(
        v_hack_path, model_name, wrappers,
        k_use=rc.get("v_hack_k"), drop_bottom_frac=rc.get("v_hack_drop_bottom_frac", 0.25)).items()}
    saved = erase_delta_S_inplace(wrappers, v_hack)
    try:
        results["weight_erase"] = run("weight_erase")
    finally:
        # Always put the trained delta_S back, mirroring the hook cleanup of the
        # act arm: a failed eval must not leave the adapter silently erased.
        restore_delta_S(wrappers, saved)

    # 3. act-erase: residual diff-of-means direction ablated at every layer.
    if not cfg.skip_act:
        pairs = load_pairs_json(pairset)
        n_layers = len(model.model.layers)
        dir_unit, src, sep = gather_act_dir(model, tok, pairs, device, n_layers, cfg.act_src_layer)
        logger.info(f"act dir: sourced at layer {src}/{n_layers} (sep={sep[src]:.3f}, "
                    f"max/mean={sep[src]/(sum(sep)/len(sep)):.2f}x)")
        handles = ablate_dir_hooks(model, dir_unit)
        try:
            results["act_erase"] = run("act_erase")
        finally:
            for h in handles:
                h.remove()

    # BLUF table. SHOULD: weight/act hack < baseline hack at solve >= baseline-ish.
    # If hack drops only when solve also collapses -> erasure is too blunt, the
    # hack is not cleanly separable post-hoc -> motivates train-time routing.
    print("\nSHOULD: erase arms cut hack vs baseline WITHOUT tanking solve. "
          "ELSE post-hoc erasure can't isolate the hack -> train-time intervention earns its cost.\n")
    rows = [{"arm": k, "hack": f"{v['hack']:.3f}", "solve": f"{v['solve']:.3f}", "n": v["n"],
             **{f"hk_{m}": f"{b[0]}/{b[2]}" for m, b in sorted(v["by_mode"].items())}}
            for k, v in results.items()]
    print(tabulate(rows, headers="keys", tablefmt="pipe"))
    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import sys

    # Parse CLI flags into a Config, run the benchmark, and use its return
    # value as the process exit status.
    exit_code = main(tyro.cli(Config))
    sys.exit(exit_code)
|
||||
Reference in New Issue
Block a user