mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 18:59:35 +08:00
token-efficient extract/heldout logs + sensible verify defaults
- antipasto.py: per-module SVD-cached log → debug (was 252 INFO lines per run, pure noise on cache hits). Replace the manual every-40-modules (`% 40`) progress prints with a single tqdm progress bar (mininterval=60).
- extract_vhack_grad.py: BLUF final tail — SHOULD line, TSV table, out path, argv, main metric, single cue emoji (🟢/🟡/🔴). Same data, ~30 fewer lines.
- verify_vhack_heldout.py: same BLUF tail pattern. Defaults updated to point at baked rh25 + v_hack_rh25 (were Qwen3.5-0.8B smoke). Cosine columns relabelled to "energy" since v_hack is now [k, r] and the diagnostic is ||V·d||/||d|| (subspace energy fraction, ≥0).

Held-out result for current v_hack_rh25 (pueue 23): median_energy=0.217, mean=0.286, n=252 modules. 🟡 below target 0.30 but 20× the prior synthetic-pair ~0.01. q_proj cleanest (0.351 median), down_proj weakest (0.146).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -44,7 +44,9 @@ def svd_cached(
|
||||
U, S, Vh = torch.linalg.svd(W_gpu, full_matrices=False)
|
||||
U, S, Vh = U.cpu(), S.cpu(), Vh.cpu()
|
||||
torch.save({"U": U, "S": S, "Vh": Vh}, final)
|
||||
logger.info(f"SVD cached: {final.name} shape U={tuple(U.shape)} S0={S[0]:.3f} S-1={S[-1]:.3e}")
|
||||
# debug: per-module SVD details only appear on first computation and are
|
||||
# noise on cache-hit runs (252 lines × ~150 char = ~38k chars per extract).
|
||||
logger.debug(f"SVD computed: {final.name} U={tuple(U.shape)} S0={S[0]:.3f} S-1={S[-1]:.3e}")
|
||||
return U, S, Vh
|
||||
|
||||
|
||||
@@ -100,8 +102,10 @@ def wrap_model_with_antipasto(
|
||||
]
|
||||
logger.info(f"AntiPaSTO attach: {len(targets)} target Linear modules in {model_name}")
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
out: dict[str, dict] = {}
|
||||
for i, (name, linear) in enumerate(targets):
|
||||
pbar = tqdm(targets, desc="AntiPaSTO attach", mininterval=60)
|
||||
for name, linear in pbar:
|
||||
W = linear.weight.data
|
||||
d_out, d_in = W.shape
|
||||
r = min(d_in, d_out)
|
||||
@@ -114,8 +118,6 @@ def wrap_model_with_antipasto(
|
||||
linear.register_parameter("_antipasto_delta_S", delta_S)
|
||||
handle = linear.register_forward_hook(_delta_hook)
|
||||
out[name] = {"layer": linear, "delta_S": delta_S, "handle": handle, "r": r}
|
||||
if (i + 1) % 40 == 0 or i == len(targets) - 1:
|
||||
logger.info(f" attached {i+1}/{len(targets)} last={name}")
|
||||
|
||||
# freeze everything except delta_S
|
||||
for n, p in model.named_parameters():
|
||||
|
||||
@@ -190,14 +190,19 @@ def main(cfg: Config) -> int:
|
||||
f"max_sv_top{k}_frac": f"{max(vals):.2f}",
|
||||
})
|
||||
|
||||
print(tabulate(agg_rows, headers="keys", tablefmt="pipe"))
|
||||
logger.info(
|
||||
f"v_hack saved to {cfg.out_path} "
|
||||
f"modules={len(v_hack)} top_k={k} zero-||D||={n_zero} "
|
||||
f"SHOULD: zero-||D||==0 and mean_sv_top{k}_frac > 0.5 (subspace captures most energy). "
|
||||
f"ELSE: either grad flow broken (zero) or hack signal is high-rank (low top-k frac, "
|
||||
f"consider larger k)."
|
||||
)
|
||||
# Final tail: BLUF — what an agent reads first should be result + interp.
|
||||
mean_frac = sum(float(r[f"sv_top{k}_frac"]) for r in rows) / max(len(rows), 1)
|
||||
cue = "🟢" if (mean_frac > 0.5 and n_zero == 0) else ("🟡" if n_zero == 0 else "🔴")
|
||||
|
||||
print(f"\nSHOULD: mean_sv_top{k}_frac > 0.5 per suffix (subspace captures most energy). "
|
||||
f"zero-||D||==0 (else grad flow broken).\n")
|
||||
print(tabulate(agg_rows, headers="keys", tablefmt="tsv", floatfmt=".2f"))
|
||||
print()
|
||||
print(f"out: {cfg.out_path}")
|
||||
print(f"argv: extract_vhack_grad --model={cfg.model} --top-k={k} --n-heldout={cfg.n_heldout}")
|
||||
print(f"main metric: mean_sv_top{k}_frac={mean_frac:.2f} [modules={len(v_hack)} zero-||D||={n_zero}]")
|
||||
print(f"{cue} k={k} pairs={len(train_pairs)}/{len(PAIRS)} modules={len(v_hack)} "
|
||||
f"mean_top{k}_frac={mean_frac:.2f} zero={n_zero}")
|
||||
|
||||
if n_zero > 0:
|
||||
logger.error(f"FAIL: {n_zero}/{len(v_hack)} modules have zero ||D|| -- gradient flow broken")
|
||||
|
||||
@@ -38,11 +38,11 @@ OUT_DIR = Path("out")
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
model: str = "Qwen/Qwen3.5-0.8B"
|
||||
model: str = "out/baked/qwen3_4b_rh25"
|
||||
dtype: str = "bf16" # must match extract_vhack_grad.py and train.py
|
||||
v_hack_path: Path = OUT_DIR / "v_hack_smoke.safetensors"
|
||||
out_path: Path = OUT_DIR / "vhack_heldout_cos.safetensors"
|
||||
n_heldout: int = 5
|
||||
v_hack_path: Path = OUT_DIR / "v_hack_rh25.safetensors"
|
||||
out_path: Path = OUT_DIR / "vhack_heldout_cos_rh25.safetensors"
|
||||
n_heldout: int = 2
|
||||
|
||||
|
||||
def main(cfg: Config) -> int:
|
||||
@@ -102,23 +102,29 @@ def main(cfg: Config) -> int:
|
||||
agg_rows.append({
|
||||
"suffix": suf,
|
||||
"n": len(vals),
|
||||
"mean_cos": f"{t.mean():+.3f}",
|
||||
"median_cos": f"{t.median():+.3f}",
|
||||
"frac>0": f"{(t > 0).float().mean():.2f}",
|
||||
"min": f"{t.min():+.3f}",
|
||||
"max": f"{t.max():+.3f}",
|
||||
"mean_energy": f"{t.mean():.3f}",
|
||||
"median_energy": f"{t.median():.3f}",
|
||||
"min": f"{t.min():.3f}",
|
||||
"max": f"{t.max():.3f}",
|
||||
})
|
||||
print(tabulate(agg_rows, headers="keys", tablefmt="pipe"))
|
||||
|
||||
t_all = torch.tensor(all_cos)
|
||||
mean_energy = t_all.mean().item()
|
||||
median_energy = t_all.median().item()
|
||||
cue = "🟢" if median_energy > 0.30 else ("🟡" if median_energy > 0.10 else "🔴")
|
||||
|
||||
print(f"\nSHOULD: median_energy > 0.30 (held-out diff lands in trained subspace). "
|
||||
f"Prior synthetic-pair run got ~0.01 — that was the smoking gun.\n")
|
||||
print(tabulate(agg_rows, headers="keys", tablefmt="tsv", floatfmt=".3f"))
|
||||
print()
|
||||
print(f"out: {cfg.out_path}")
|
||||
print(f"argv: verify_vhack_heldout --model={cfg.model} --v-hack-path={cfg.v_hack_path}")
|
||||
print(f"main metric: median_energy={median_energy:.3f} [modules={len(all_cos)}]")
|
||||
print(f"{cue} modules={len(all_cos)} mean={mean_energy:.3f} median={median_energy:.3f}")
|
||||
|
||||
frac_pos = (t_all > 0).float().mean().item()
|
||||
mean_cos = t_all.mean().item()
|
||||
median_cos = t_all.median().item()
|
||||
logger.info(
|
||||
f"OVERALL modules={len(all_cos)} frac>0={frac_pos:.3f} "
|
||||
f"mean={mean_cos:+.3f} median={median_cos:+.3f} "
|
||||
f"SHOULD: frac>0 > 0.50 and mean > 0.20. ELSE: extraction noise dominates signal."
|
||||
)
|
||||
mean_cos = mean_energy
|
||||
median_cos = median_energy
|
||||
|
||||
# save for downstream plotting / sanity. Cos values as a single tensor;
|
||||
# module names in the metadata header (JSON-encoded preserves order).
|
||||
|
||||
Reference in New Issue
Block a user