mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 19:47:33 +08:00
5a340e5c3e
extract_v_act: pooled completion-token bottleneck act per module, v = unit-norm mean pair diff (tstat flag default off, null at n=8 pairs). ActCapture is the single hook shared by extraction, the live gate, and verification. verify_v_act (pueue #24): rel diff 7.3e-4 hack / 7.7e-4 clean vs out/diag/pinning_feats.pt on the v3 first_hack ckpt; min per-module cos 0.99997. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
90 lines
3.7 KiB
Python
90 lines
3.7 KiB
Python
"""Verify extract_v_act reproduces the cached diagnostic pair features.
|
|
|
|
The journal-(d)/(e) AUROC evidence for the act gate was measured by
|
|
scripts/diag_pinning.py (ActTap) on cached features. The training extractor
|
|
(src/vgrout/extract_vhack_act.py) must produce the SAME pooled acts on the same
|
|
checkpoint + pairs, else the routeA gate is not the thing the diagnostics
|
|
validated. GPU required (bf16 + flash_attention_2 to match the cache); not part
|
|
of `just smoke` -- run once after any extractor change:
|
|
|
|
uv run python scripts/verify_v_act.py
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import sys
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import tyro
|
|
from loguru import logger
|
|
from safetensors.torch import load_file
|
|
|
|
from vgrout.extract_vhack_act import extract_v_act
|
|
from vgrout.lora2r import wrap_model_with_lora2r
|
|
from vgrout.pairs import load_pairs
|
|
|
|
|
|
@dataclass
class Cfg:
    """CLI options for verify_v_act (parsed by tyro).

    Defaults target the v3 seed-43 vanilla LoRA2r run's ``first_hack`` checkpoint
    and the cached diagnostic pair features it is compared against.
    """

    # cached diagnostic pair features (the scripts/diag_pinning.py cache)
    feats: Path = Path("out") / "diag" / "pinning_feats.pt"
    # training run directory that holds `<ckpt>.safetensors`
    run_dir: Path = Path("out") / "runs" / "20260611T003538_fast_vanilla_lora2r_seed43_l2r_vanilla_s43_v3"
    # checkpoint file stem inside run_dir
    ckpt: str = "first_hack"
    # pair-set spec handed to load_pairs as-is (the `#all-in-one` suffix is interpreted there)
    pairs: Path = Path("data") / "pairs" / "hack_pairs.md#all-in-one"
    # max relative diff per side; bf16 forward + fp32 pooling should give ~0
    rel_tol: float = 0.01
|
|
|
|
|
|
def _read_run_cfg(ckpt_path: Path) -> dict:
    """Return the training-run config embedded in a safetensors checkpoint header.

    A safetensors file begins with a little-endian u64 header length followed by a
    JSON header. The run config is read as a JSON string from
    ``header["__metadata__"]["cfg"]``; a missing ``__metadata__`` or ``cfg`` entry
    yields ``{}``.
    """
    import json
    import struct

    with open(ckpt_path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    meta = header.get("__metadata__", {})
    return json.loads(meta.get("cfg", "{}"))


def main(cfg: Cfg) -> int:
    """Re-extract pooled acts with the training extractor and compare them to the cache.

    Rebuilds the base model named in the checkpoint's embedded run config, wraps it
    with LoRA2r, and loads the A/B weights from ``cfg.run_dir / f"{cfg.ckpt}.safetensors"``.
    It then runs ``extract_v_act`` over ``cfg.pairs`` and checks:
      (a) per side ("hack"/"clean"): relative diff ``||got - ref|| / ||ref||`` against
          ``fe["pair_feats"][("act", side)]`` must be below ``cfg.rel_tol``;
      (b) the returned ``v`` vs the unit-norm mean hack-clean diff recomputed from the
          cached feats: the minimum per-module cosine must exceed 0.999.

    Returns 0 on PASS and 1 on FAIL (used as the process exit code).
    Raises AssertionError if the module set/order or the pair set no longer matches
    the cache (the comparison would be meaningless), and KeyError if the checkpoint
    metadata lacks the run-config fields needed to rebuild the model.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # NOTE(review): weights_only=False unpickles arbitrary objects -- only point
    # `feats` at locally produced diag caches, never at downloaded files.
    fe = torch.load(cfg.feats, weights_only=False)
    ckpt_path = cfg.run_dir / f"{cfg.ckpt}.safetensors"
    run_cfg = _read_run_cfg(ckpt_path)
    missing = [k for k in ("model", "lora_r", "lora_init_seed") if k not in run_cfg]
    if missing:
        raise KeyError(f"{ckpt_path} metadata cfg lacks {missing}; cannot rebuild the model")
    model_name, r, init_seed = run_cfg["model"], run_cfg["lora_r"], run_cfg["lora_init_seed"]

    device = torch.device("cuda")
    tok = AutoTokenizer.from_pretrained(model_name)
    # bf16 + flash_attention_2 to match the kernels the cache was produced with.
    model = AutoModelForCausalLM.from_pretrained(
        model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2").to(device)
    model.config.use_cache = False
    wrappers = wrap_model_with_lora2r(model, r=r, init_seed=init_seed)
    # Explicit raise (not `assert`) so these drift checks survive `python -O`.
    if sorted(wrappers) != fe["names"]:
        raise AssertionError("module set/order drifted vs the cache")
    sd = load_file(str(ckpt_path))
    for nm, info in wrappers.items():
        info["A"].data.copy_(sd[f"A/{nm}"].to(device, torch.float32))
        info["B"].data.copy_(sd[f"B/{nm}"].to(device, torch.float32))
    model.eval()

    pairs = load_pairs(cfg.pairs)
    if [p.problem_id for p in pairs] != fe["pair_ids"]:
        raise AssertionError("pairset drifted vs the cache")
    logger.info(f"extracting acts for {len(pairs)} pairs on {ckpt_path.name} ({model_name})")
    v, acts = extract_v_act(model, tok, wrappers, pairs, device)

    print("SHOULD: rel diff ~0 for both sides (same ckpt/pairs/dtype/kernels) ELSE the "
          "training extractor is not what diag_pinning measured.")
    ok = True
    for side in ("hack", "clean"):
        ref = fe["pair_feats"][("act", side)].float()
        got = acts[side]
        if got.shape != ref.shape:
            # Mismatched shapes could silently broadcast in `got - ref` and give a
            # meaningless (possibly tiny) diff, so treat this as a hard FAIL.
            print(f"  {side:>5}: shape mismatch got={tuple(got.shape)} ref={tuple(ref.shape)}")
            ok = False
            continue
        rel = (got - ref).norm() / ref.norm()
        mx = (got - ref).abs().max()
        print(f"  {side:>5}: rel diff={rel:.2e}  max abs diff={mx:.2e}  shape={tuple(got.shape)}")
        # NaN (e.g. zero-norm ref) compares False here, so it counts as a FAIL.
        ok &= rel.item() < cfg.rel_tol

    # v_act from the cached feats (diag's _v_from over all pairs) must match too.
    D = (fe["pair_feats"][("act", "hack")] - fe["pair_feats"][("act", "clean")]).float().mean(0)
    v_ref = D / D.norm(dim=-1, keepdim=True).clamp_min(1e-12)
    if v.shape != v_ref.shape:
        # Same broadcasting hazard as above for the per-module cosine.
        print(f"  v_act shape mismatch: got={tuple(v.shape)} cache-derived={tuple(v_ref.shape)}")
        ok = False
    else:
        cos = (v * v_ref).sum(-1)
        print(f"  v_act vs cache-derived v: min per-module cos={cos.min():.6f} "
              f"(SHOULD ~1.0); unit rows max|1-||v|||={(v.norm(dim=-1) - 1).abs().max():.1e}")
        ok &= cos.min().item() > 0.999

    print(("🟢 PASS" if ok else "🔴 FAIL") + "  verify_v_act")
    return 0 if ok else 1
|
|
|
|
|
|
if __name__ == "__main__":
    # Parse CLI flags into Cfg, run the check, and propagate its 0/1 result as the exit status.
    exit_code = main(tyro.cli(Cfg))
    sys.exit(exit_code)
|