Files
evil_MoE/scripts/verify_v_act.py
T
wassname 5a340e5c3e feat(routeA): forward-only v_act extraction + verify gate vs cached diag features
extract_v_act: pooled completion-token bottleneck act per module, v = unit-norm
mean pair diff (tstat flag default off, null at n=8 pairs). ActCapture is the
single hook shared by extraction, the live gate, and verification.
verify_v_act (pueue #24): rel diff 7.3e-4 hack / 7.7e-4 clean vs
out/diag/pinning_feats.pt on the v3 first_hack ckpt; min per-module cos 0.99997.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
2026-06-11 12:13:32 +00:00

90 lines
3.7 KiB
Python

"""Verify extract_v_act reproduces the cached diagnostic pair features.
The journal-(d)/(e) AUROC evidence for the act gate was measured by
scripts/diag_pinning.py (ActTap) on cached features. The training extractor
(src/vgrout/extract_vhack_act.py) must produce the SAME pooled acts on the same
checkpoint + pairs, else the routeA gate is not the thing the diagnostics
validated. GPU required (bf16 + flash_attention_2 to match the cache); not part
of `just smoke` -- run once after any extractor change:
uv run python scripts/verify_v_act.py
"""
from __future__ import annotations
import sys
from dataclasses import dataclass
from pathlib import Path
import torch
import tyro
from loguru import logger
from safetensors.torch import load_file
from vgrout.extract_vhack_act import extract_v_act
from vgrout.lora2r import wrap_model_with_lora2r
from vgrout.pairs import load_pairs
@dataclass
class Cfg:
    """Inputs for checking extract_v_act against the cached diagnostic features."""

    # Cached diag features; main() reads its "names", "pair_ids" and "pair_feats" keys.
    feats: Path = Path("out/diag/pinning_feats.pt")
    # Training run directory holding the checkpoint's .safetensors file.
    run_dir: Path = Path("out/runs/20260611T003538_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v3")
    # Checkpoint stem inside run_dir (loaded as f"{ckpt}.safetensors").
    ckpt: str = "first_hack"
    # Pair set passed to load_pairs; must match the cache's pair_ids.
    pairs: Path = Path("data/pairs/hack_pairs.md#all-in-one")
    # Max relative L2 diff per side; bf16 forward with fp32 pooling should be near 0.
    rel_tol: float = 0.01
def main(cfg: Cfg) -> int:
    """Re-extract pooled acts with extract_v_act and compare them to the diag cache.

    Steps:
      1. Load the cached diagnostic features from ``cfg.feats``.
      2. Rebuild the run's model and LoRA2R wrappers from the checkpoint's
         safetensors metadata, then copy in the checkpoint's A/B weights.
      3. Run ``extract_v_act`` on the same pairs.
      4. Check each side ("hack", "clean"): the relative L2 diff of the pooled
         acts must be < ``cfg.rel_tol``.
      5. Check ``v_act``: its minimum per-module cosine against the unit-normalized
         mean pair diff of the cached acts must be > 0.999.

    Returns:
        0 if every check passes, otherwise 1. Precondition failures also return 1:
        no CUDA, missing run-cfg metadata, or module/pair drift vs the cache.
    """
    import json

    from safetensors import safe_open
    from transformers import AutoModelForCausalLM, AutoTokenizer

    min_cos = 0.999  # floor on the per-module cosine of v_act vs the cache-derived v

    def fail(msg: str) -> int:
        # Explicit gate failure: unlike `assert`, this survives `python -O`.
        logger.error(msg)
        print("🔴 FAIL verify_v_act")
        return 1

    if not torch.cuda.is_available():
        # bf16 + flash_attention_2 on GPU are needed to match the cache.
        return fail("CUDA is unavailable; verify_v_act must run on a GPU")

    # Local trusted artifact; weights_only=False unpickles arbitrary objects.
    # Load on CPU so every comparison below happens on one device.
    fe = torch.load(cfg.feats, map_location="cpu", weights_only=False)

    ckpt_path = cfg.run_dir / f"{cfg.ckpt}.safetensors"
    # The run's training cfg is stored as a JSON string under the "cfg" metadata key.
    with safe_open(str(ckpt_path), framework="pt") as f:
        meta = f.metadata() or {}
    try:
        run_cfg = json.loads(meta["cfg"])
        model_name, r, init_seed = run_cfg["model"], run_cfg["lora_r"], run_cfg["lora_init_seed"]
    except KeyError as e:
        return fail(f"{ckpt_path}: run cfg metadata is missing {e}")

    device = torch.device("cuda")
    tok = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2"
    ).to(device)
    model.config.use_cache = False
    wrappers = wrap_model_with_lora2r(model, r=r, init_seed=init_seed)
    if sorted(wrappers) != fe["names"]:
        return fail("module set/order drifted vs the cache")

    sd = load_file(str(ckpt_path))
    with torch.no_grad():
        for nm, info in wrappers.items():
            info["A"].copy_(sd[f"A/{nm}"].to(device, torch.float32))
            info["B"].copy_(sd[f"B/{nm}"].to(device, torch.float32))
    del sd  # release the host copy of the checkpoint before extraction
    model.eval()

    pairs = load_pairs(cfg.pairs)
    if [p.problem_id for p in pairs] != fe["pair_ids"]:
        return fail("pairset drifted vs the cache")
    logger.info(f"extracting acts for {len(pairs)} pairs on {ckpt_path.name} ({model_name})")
    v, acts = extract_v_act(model, tok, wrappers, pairs, device)

    print("SHOULD: rel diff ~0 for both sides (same ckpt/pairs/dtype/kernels) ELSE the "
          "training extractor is not what diag_pinning measured.")
    ok = True
    for side in ("hack", "clean"):
        ref = fe["pair_feats"][("act", side)].float()
        got = acts[side].detach().float().cpu()
        if got.shape != ref.shape:
            # (got - ref) would broadcast or raise; neither is a valid comparison.
            print(f" {side:>5}: shape mismatch got={tuple(got.shape)} ref={tuple(ref.shape)}")
            ok = False
            continue
        diff = got - ref
        rel = diff.norm() / ref.norm()
        mx = diff.abs().max()
        print(f" {side:>5}: rel diff={rel:.2e} max abs diff={mx:.2e} shape={tuple(got.shape)}")
        ok &= rel.item() < cfg.rel_tol

    # v_act from the cached feats (diag's _v_from over all pairs) must match too.
    # Cast before subtracting so the pair diff is computed in fp32, consistent with `ref` above.
    hack = fe["pair_feats"][("act", "hack")].float()
    clean = fe["pair_feats"][("act", "clean")].float()
    d_mean = (hack - clean).mean(0)
    v_ref = d_mean / d_mean.norm(dim=-1, keepdim=True).clamp_min(1e-12)
    v = v.detach().float().cpu()
    if v.shape != v_ref.shape:
        print(f" v_act shape mismatch got={tuple(v.shape)} ref={tuple(v_ref.shape)}")
        ok = False
    else:
        cos = (v * v_ref).sum(-1)
        print(f" v_act vs cache-derived v: min per-module cos={cos.min():.6f} "
              f"(SHOULD ~1.0); unit rows max|1-||v|||={(v.norm(dim=-1) - 1).abs().max():.1e}")
        ok &= cos.min().item() > min_cos

    print(("🟢 PASS" if ok else "🔴 FAIL") + " verify_v_act")
    return 0 if ok else 1
if __name__ == "__main__":
    # tyro fills Cfg from argv; main's 0/1 result becomes the process exit status.
    cli_cfg = tyro.cli(Cfg)
    sys.exit(main(cli_cfg))