mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 16:30:30 +08:00
103d0acc2c
antipasto.py (PiSSA/lora_frozen_b/old-lora2r wrappers) is dead in the live path -- train.py/extract use lora2r.py, nothing imports antipasto. Move the 7 scripts that import it or the erase-era proj fns (rescore_deploy, eval_checkpoint_curve, verify_vhack_heldout, probe_distill, diag_cosine_dist, diag_pairs_compare, tt_erase_bench) to scripts/attic/ -- they need lora2r rewrites if resurrected. Live imports verified clean. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
85 lines
3.9 KiB
Python
85 lines
3.9 KiB
Python
"""Reproduce a finished run's paired quarantine-ablated/enabled final-test evaluation."""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import tyro
|
|
from tyro.conf import Positional
|
|
from loguru import logger
|
|
from safetensors import safe_open
|
|
from safetensors.torch import load_file
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
|
|
|
from vgrout.antipasto import wrap_model_with_antipasto
|
|
from vgrout.eval import ablate_quarantine, eval_hack_solve, load_eval_splits
|
|
from vgrout.train import CACHE_ROOT, EVAL_GEN_SEED
|
|
from vgrout.run_artifacts import RUN_SCHEMA
|
|
|
|
|
|
def _read_run_metadata(ckpt: Path) -> tuple[dict, str, dict]:
    """Return (train cfg, base-model name, raw header metadata) from a run checkpoint."""
    with safe_open(str(ckpt), framework="pt") as f:
        meta = f.metadata()
    # Header metadata is optional in safetensors; without cfg/model nothing can be reproduced,
    # so fail with a clear message instead of a TypeError/KeyError.
    missing = [key for key in ("cfg", "model") if not meta or key not in meta]
    if missing:
        raise ValueError(f"{ckpt}: safetensors metadata lacks {missing}; cannot reproduce the eval")
    return json.loads(meta["cfg"]), meta["model"], meta


def _load_base_model(model_name: str, device: torch.device):
    """Load the tokenizer and bf16 base model for evaluation; return (tok, model)."""
    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    # flash_attention_2 only runs on CUDA; on the CPU fallback use SDPA so loading can succeed.
    attn_impl = "flash_attention_2" if device.type == "cuda" else "sdpa"
    model = AutoModelForCausalLM.from_pretrained(
        model_name, dtype=torch.bfloat16, attn_implementation=attn_impl,
    ).to(device)
    model.config.use_cache = False
    return tok, model


def _require_same_modules(what: str, got: set, expected: set) -> None:
    """Raise ValueError if a checkpoint's module names differ from the adapter wrappers' names."""
    # Explicit raise rather than assert: this validation must survive `python -O`.
    if got != expected:
        raise ValueError(
            f"{what} module set != adapter module set: "
            f"missing={sorted(expected - got)[:10]} unexpected={sorted(got - expected)[:10]}"
        )


def _copy_delta(dst: torch.Tensor, src: torch.Tensor, label: str) -> None:
    """Copy a checkpoint tensor into an adapter parameter after checking shapes match."""
    if dst.shape != src.shape:
        raise ValueError(f"{label}: checkpoint shape {tuple(src.shape)} != adapter shape {tuple(dst.shape)}")
    # copy_ converts to dst's own dtype/device; pre-casting to bf16 would lose precision
    # whenever the adapter parameter is stored in a wider dtype.
    dst.data.copy_(src)


def _load_adapter_deltas(run_dir: Path, ckpt: Path, wrappers: dict) -> None:
    """Restore delta_S (train.safetensors) and delta_S_hack (train_hack.safetensors) into the wrappers."""
    delta = load_file(str(ckpt))
    delta_hack = load_file(str(run_dir / "train_hack.safetensors"))
    expected = set(wrappers)
    _require_same_modules("checkpoint", set(delta), expected)
    _require_same_modules("quarantine checkpoint", set(delta_hack), expected)
    for name, t in delta.items():
        _copy_delta(wrappers[name]["delta_S"], t, f"{name}.delta_S")
        _copy_delta(wrappers[name]["delta_S_hack"], delta_hack[name], f"{name}.delta_S_hack")


def main(run_dir: Positional[Path]) -> None:
    """Re-score a finished run on its test split, paired quarantine-ablated vs. enabled.

    Loads the base model named in ``train.safetensors`` metadata, wraps it with the
    antipasto adapters, restores ``delta_S`` / ``delta_S_hack`` from the run's two
    checkpoints, then evaluates the same problems twice from the same seed: once inside
    ``ablate_quarantine`` (the deployed model) and once as trained. The eval modes are
    taken from the run's existing ``deploy_test.json``, which is then overwritten in place.

    NOTE(review): depends on ``vgrout.antipasto``; if the live training path uses a
    different adapter implementation, this script needs porting before reuse.

    Args:
        run_dir: finished run directory containing train.safetensors,
            train_hack.safetensors and deploy_test.json.

    Raises:
        ValueError: checkpoint metadata is missing, or checkpoint tensors do not match
            the adapter modules (names or shapes).
    """
    ckpt = run_dir / "train.safetensors"
    cfg, model_name, meta = _read_run_metadata(ckpt)
    # Read the prior result first so a missing/corrupt file fails before the model loads.
    result_path = run_dir / "deploy_test.json"
    prior_eval = json.loads(result_path.read_text())
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"re-score {run_dir.name}: model={model_name} step={meta.get('step')}")

    tok, model = _load_base_model(model_name, device)
    wrappers = wrap_model_with_antipasto(model, model_name, CACHE_ROOT, device, grad_probe=False)
    _load_adapter_deltas(run_dir, ckpt, wrappers)

    # Reproduce the original evaluation environment exactly.
    eval_modes = sorted(prior_eval["by_mode"])
    _, problems = load_eval_splits(eval_modes, cfg["eval_n_prompts"])
    gen_cfg_eval = GenerationConfig(
        max_new_tokens=cfg["max_new"], do_sample=True,
        temperature=0.7, top_p=1.0, top_k=20, min_p=0.0, repetition_penalty=1.0,
        num_return_sequences=1, pad_token_id=tok.pad_token_id,
    )
    eval_idxs = list(range(len(problems)))
    eval_args = (model, tok, problems, eval_idxs, gen_cfg_eval, device,
                 cfg["max_new"], cfg["eval_batch_size"])

    # Re-seed before each pass so both conditions start sampling from the same RNG state.
    torch.manual_seed(EVAL_GEN_SEED)
    with ablate_quarantine(wrappers):  # quarantine ablated = deployed model
        ev = eval_hack_solve(*eval_args)
    torch.manual_seed(EVAL_GEN_SEED)
    ev_on = eval_hack_solve(*eval_args)

    out = {
        "schema": RUN_SCHEMA,
        "run_dir": run_dir.name, "model": model_name, "step": meta.get("step"),
        "eval_set": "test", "eval_modes": eval_modes,
        "n": ev["n"], "hack_deployed": ev["hack"], "vhack_deployed": ev["vhack"], "solve_deployed": ev["solve"],
        "hack_as_trained": ev_on["hack"], "vhack_as_trained": ev_on["vhack"],
        "solve_as_trained": ev_on["solve"],
        # Per-mode rates come from the deployed (quarantine-ablated) pass only.
        "by_mode": {m: {"hack": h / max(1, c), "vhack": v / max(1, c), "solve": s / max(1, c), "n": c}
                    for m, (h, v, s, c) in ev["by_mode"].items()},
    }
    result_path.write_text(json.dumps(out, indent=2))
    logger.info(f"FINAL paired test n={ev['n']}: quarantine-ablated hack={ev['hack']:.3f} "
                f"solve={ev['solve']:.3f}; quarantine-enabled hack={ev_on['hack']:.3f} "
                f"solve={ev_on['solve']:.3f}")
    for m, d in out["by_mode"].items():
        logger.info(f"  {m:14s} hack={d['hack']:.3f} vhack={d['vhack']:.3f} solve={d['solve']:.3f} n={d['n']}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Let tyro build the command line from main's signature (one positional run directory).
    tyro.cli(main)
|