mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 19:31:11 +08:00
83 lines
2.6 KiB
Python
83 lines
2.6 KiB
Python
"""Extract v_hack from contrastive pairs of hidden states.
|
|
|
|
Per Wu-Tang (2026, arXiv 2604.01476) §3.1:
|
|
|
|
d = (1/N) * sum_i (h_i^+ - h_i^-)
|
|
|
|
where h^+ are last-token hidden states from hack-flavored prompts and h^- from
|
|
clean ones, taken at intermediate-to-late layers (60-75% of model depth).
|
|
|
|
Validation: held-out separation accuracy > 90%.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
|
|
import torch
|
|
from jaxtyping import Float
|
|
from loguru import logger
|
|
from torch import Tensor
|
|
|
|
|
|
@dataclass
class VHackResult:
    """Result bundle for one v_hack extraction.

    Holds the extracted direction together with the validation score and the
    bookkeeping values (layer index, sample counts) recorded alongside it.
    """

    v_hack: Float[Tensor, "d"]  # unit-normed direction
    val_accuracy: float  # held-out hack-vs-clean separation accuracy
    layer_idx: int  # layer index the hidden states were taken at
    n_train: int  # number of training rows used for the direction
    n_val: int  # number of held-out rows used for val_accuracy
|
|
|
|
|
|
@torch.no_grad()
def collect_last_token_hidden(
    model,
    tokenizer,
    prompts: list[str],
    layer_idx: int,
    device: str = "cuda",
) -> Float[Tensor, "n d"]:
    """Forward each prompt, return last-token hidden state at layer_idx.

    Prompts are tokenized and run through the model one at a time, and the
    hidden state at the final sequence position of the requested layer is
    kept for each of them.

    Args:
        model: Callable that accepts the tokenizer output as keyword arguments
            plus ``output_hidden_states=True`` and returns an object exposing
            ``hidden_states``.
        tokenizer: Callable ``tokenizer(prompt, return_tensors="pt")`` whose
            result supports ``.to(device)`` and ``**`` unpacking.
        prompts: Prompts to encode; must be non-empty.
        layer_idx: Index into ``out.hidden_states``.
        device: Device the tokenized inputs are moved to before the forward pass.

    Returns:
        A float32 CPU tensor with one row per prompt, in input order.

    Raises:
        ValueError: If ``prompts`` is empty (there is nothing to stack).
    """
    if not prompts:
        # torch.stack on an empty list fails with an opaque RuntimeError;
        # surface the real cause instead.
        raise ValueError("collect_last_token_hidden: `prompts` must be non-empty")

    hs = []
    for p in prompts:
        ids = tokenizer(p, return_tensors="pt").to(device)
        out = model(**ids, output_hidden_states=True)
        # out.hidden_states is tuple of (n_layers+1,) tensors of shape (1, seq, d)
        h = out.hidden_states[layer_idx][0, -1, :].float().cpu()  # "d" — fp32 for stable v_hack
        hs.append(h)
    return torch.stack(hs, dim=0)
|
|
|
|
|
|
def extract_vhack(
    h_hack_train: Float[Tensor, "n_train d"],
    h_clean_train: Float[Tensor, "n_train d"],
    h_hack_val: Float[Tensor, "n_val d"],
    h_clean_val: Float[Tensor, "n_val d"],
    layer_idx: int,
) -> VHackResult:
    """Mean-difference direction with held-out validation.

    The direction is the difference between the mean hack row and the mean
    clean row of the training sets, scaled to unit norm. It is then scored on
    the held-out sets as a paired accuracy: row i of ``h_hack_val`` must
    project strictly higher onto the direction than row i of ``h_clean_val``.

    Args:
        h_hack_train: Hack-side training hidden states, one row per example.
        h_clean_train: Clean-side training hidden states, one row per example.
        h_hack_val: Hack-side held-out hidden states.
        h_clean_val: Clean-side held-out hidden states, paired row-by-row
            with ``h_hack_val``.
        layer_idx: Stored on the result and logged; not used in the computation.

    Returns:
        A ``VHackResult`` with the unit-normed direction, the paired validation
        accuracy, and the sample counts taken from the hack-side inputs.

    Raises:
        ValueError: If either training set is empty, or if the two validation
            sets differ in length.
    """
    n_train = len(h_hack_train)
    n_val = len(h_hack_val)

    if n_train == 0 or len(h_clean_train) == 0:
        # The mean of zero rows is NaN, which would silently poison the direction.
        raise ValueError("extract_vhack: training sets must be non-empty")
    if len(h_clean_val) != n_val:
        # The accuracy below is paired; unequal lengths either error deep in
        # torch or broadcast silently when one side has a single row.
        raise ValueError(
            f"extract_vhack: validation sets must be paired, got "
            f"{n_val} hack vs {len(h_clean_val)} clean rows"
        )

    v = h_hack_train.mean(dim=0) - h_clean_train.mean(dim=0)
    v = v / (v.norm() + 1e-12)  # epsilon guards against a zero-length difference

    # Validate: projection score on hack should exceed clean.
    s_hack = h_hack_val @ v
    s_clean = h_clean_val @ v
    # paired accuracy: each (hack, clean) pair, hack should score higher
    correct = (s_hack > s_clean).float().mean().item()

    logger.info(
        f"v_hack extracted layer={layer_idx} n_train={n_train} "
        f"n_val={n_val} val_acc={correct:.3f} "
        "SHOULD>0.9 on a trained model: v_hack should separate hack from clean. "
        "On tiny-random/untrained models val_acc~0.5 (no semantic structure yet), "
        "which is fine for smoke -- the projection mechanism is what we test there."
    )

    return VHackResult(
        v_hack=v,
        val_accuracy=correct,
        layer_idx=layer_idx,
        n_train=n_train,
        n_val=n_val,
    )
|