"""Verify extract_v_act reproduces the cached diagnostic pair features. The journal-(d)/(e) AUROC evidence for the act gate was measured by scripts/diag_pinning.py (ActTap) on cached features. The training extractor (src/vgrout/extract_vhack_act.py) must produce the SAME pooled acts on the same checkpoint + pairs, else the routeA gate is not the thing the diagnostics validated. GPU required (bf16 + flash_attention_2 to match the cache); not part of `just smoke` -- run once after any extractor change: uv run python scripts/verify_v_act.py """ from __future__ import annotations import sys from dataclasses import dataclass from pathlib import Path import torch import tyro from loguru import logger from safetensors.torch import load_file from vgrout.extract_vhack_act import extract_v_act from vgrout.lora2r import wrap_model_with_lora2r from vgrout.pairs import load_pairs @dataclass class Cfg: feats: Path = Path("out/diag/pinning_feats.pt") run_dir: Path = Path("out/runs/20260611T003538_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v3") ckpt: str = "first_hack" pairs: Path = Path("data/pairs/hack_pairs.md#all-in-one") rel_tol: float = 1e-2 # bf16 forward, fp32 pooling; expect ~0 def main(cfg: Cfg) -> int: import json import struct from transformers import AutoModelForCausalLM, AutoTokenizer fe = torch.load(cfg.feats, weights_only=False) ckpt_path = cfg.run_dir / f"{cfg.ckpt}.safetensors" with open(ckpt_path, "rb") as f: meta = json.loads(f.read(struct.unpack("5}: rel diff={rel:.2e} max abs diff={mx:.2e} shape={tuple(got.shape)}") ok &= rel.item() < cfg.rel_tol # v_act from the cached feats (diag's _v_from over all pairs) must match too. D = (fe["pair_feats"][("act", "hack")] - fe["pair_feats"][("act", "clean")]).float().mean(0) v_ref = D / D.norm(dim=-1, keepdim=True).clamp_min(1e-12) cos = (v * v_ref).sum(-1) print(f" v_act vs cache-derived v: min per-module cos={cos.min():.6f} " f"(SHOULD ~1.0); unit rows max|1-||v|||={(v.norm(dim=-1) - 1).abs().max():.1e}") ok &= cos.min().item() > 0.999 print(("🟢 PASS" if ok else "🔴 FAIL") + " verify_v_act") return 0 if ok else 1 if __name__ == "__main__": sys.exit(main(tyro.cli(Cfg)))