Files
evil_MoE/scripts/rescore_deploy.py
T
wassname b334b5f516 fix: rescore_deploy tolerates old-schema checkpoints (default eval-harness params)
job 32/33 failed KeyError eval_batch_size: old checkpoints' stored cfg
predates the train_config refactor. Default eval_n_prompts/max_new/
eval_batch_size to the fast preset (eval-harness params, not model-defining;
test split is fixed-size) so historical checkpoints re-score.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
2026-06-10 03:31:07 +00:00

89 lines
4.3 KiB
Python

"""Reproduce a finished run's paired knob-off/knob-on final-test evaluation."""
from __future__ import annotations
import json
from pathlib import Path
import torch
import tyro
from tyro.conf import Positional
from loguru import logger
from safetensors import safe_open
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from vgrout.antipasto import wrap_model_with_antipasto
from vgrout.eval import ablate_quarantine, eval_hack_solve, load_eval_splits
from vgrout.train import CACHE_ROOT, EVAL_GEN_SEED
from vgrout.run_artifacts import RUN_SCHEMA
# Eval-harness params (not model-defining) and the "fast preset" fallbacks used when an
# OLD checkpoint's stored cfg predates the train_config refactor (these keys were
# added/renamed there). The test split is fixed-size regardless of eval_n_prompts.
_FAST_EVAL_DEFAULTS = {"eval_n_prompts": 32, "max_new": 512, "eval_batch_size": 8}


def _resolve_eval_params(cfg: dict) -> tuple[dict[str, int], list[str]]:
    """Return the eval-harness params to use and the keys that fell back to defaults.

    Values stored in ``cfg`` win; missing keys take ``_FAST_EVAL_DEFAULTS``. The defaulted
    keys are returned so the caller can warn about and record them, because a fallback
    may not match what the original run actually evaluated with.
    """
    params = {key: cfg.get(key, default) for key, default in _FAST_EVAL_DEFAULTS.items()}
    defaulted = [key for key in _FAST_EVAL_DEFAULTS if key not in cfg]
    return params, defaulted


def _require_same_modules(found: set[str], expected: set[str], what: str) -> None:
    """Raise ``ValueError`` if a checkpoint's module names differ from the adapter's.

    An explicit raise (not ``assert``) so the check survives ``python -O``: with fewer
    checkpoint modules than adapters, the copy loop would otherwise silently leave some
    adapters at their initial values and the re-score would be wrong.
    """
    if found == expected:
        return
    missing = sorted(expected - found)
    extra = sorted(found - expected)
    raise ValueError(
        f"{what} module set != adapter module set: "
        f"missing={missing[:5]} (n={len(missing)}), extra={extra[:5]} (n={len(extra)})"
    )


def _by_mode_rates(by_mode: dict) -> dict[str, dict]:
    """Convert per-mode ``(hack, vhack, solve, n)`` tuples into rates plus ``n``.

    ``max(1, c)`` keeps a mode with ``n == 0`` at 0.0 instead of dividing by zero.
    """
    return {
        m: {"hack": h / max(1, c), "vhack": v / max(1, c), "solve": s / max(1, c), "n": c}
        for m, (h, v, s, c) in by_mode.items()
    }


def main(run_dir: Positional[Path]) -> None:
    """Re-run a finished run's paired knob-off/knob-on final-test eval and rewrite
    ``run_dir/deploy_test.json``.

    Loads the base model named in the checkpoint metadata and wraps it with the antipasto
    adapters. It then copies ``train.safetensors`` into each adapter's ``delta_S`` and
    ``train_hack.safetensors`` into ``delta_S_hack``. The test split is evaluated twice
    from the same generation seed: once inside ``ablate_quarantine`` (knob OFF, the
    deployed model) and once without it (knob ON). Before overwriting, the prior json is
    copied to ``deploy_test.orig.json`` if no such backup exists yet, so the original
    headline survives re-scoring.

    Args:
        run_dir: Finished run directory. It must hold both checkpoints and a prior
            ``deploy_test.json``, whose ``by_mode`` keys define the eval modes.

    Raises:
        ValueError: The checkpoint metadata lacks ``cfg``/``model``, or a checkpoint's
            module set differs from the adapter module set.
        FileNotFoundError: A checkpoint or the prior ``deploy_test.json`` is missing.
    """
    ckpt = run_dir / "train.safetensors"
    with safe_open(str(ckpt), framework="pt") as f:
        meta = f.metadata() or {}  # metadata() may be None for files saved without it
    missing_meta = [k for k in ("cfg", "model") if k not in meta]
    if missing_meta:
        raise ValueError(f"{ckpt} metadata lacks {missing_meta}")
    cfg = json.loads(meta["cfg"])
    model_name = meta["model"]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"re-score {run_dir.name}: model={model_name} step={meta.get('step')}")

    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    # flash_attention_2 only runs on CUDA; use SDPA so the CPU fallback can load at all.
    attn_impl = "flash_attention_2" if device.type == "cuda" else "sdpa"
    model = AutoModelForCausalLM.from_pretrained(
        model_name, dtype=torch.bfloat16, attn_implementation=attn_impl,
    ).to(device)
    model.config.use_cache = False
    wrappers = wrap_model_with_antipasto(model, model_name, CACHE_ROOT, device, grad_probe=False)

    delta = load_file(str(ckpt))
    delta_hack = load_file(str(run_dir / "train_hack.safetensors"))
    _require_same_modules(set(delta), set(wrappers), "checkpoint")
    _require_same_modules(set(delta_hack), set(wrappers), "quarantine checkpoint")
    for name, t in delta.items():
        wrappers[name]["delta_S"].data.copy_(t.to(device, torch.bfloat16))
        wrappers[name]["delta_S_hack"].data.copy_(delta_hack[name].to(device, torch.bfloat16))

    prior_path = run_dir / "deploy_test.json"
    prior_text = prior_path.read_text()
    prior_eval = json.loads(prior_text)
    # by_mode keys ARE the modes the original deploy eval spanned (present in every json
    # version); reproduce the same set so the re-scored knob-off matches the headline.
    eval_modes = sorted(prior_eval["by_mode"])

    params, defaulted = _resolve_eval_params(cfg)
    if defaulted:
        logger.warning(
            f"stored cfg lacks {defaulted}; using fast-preset defaults {params} "
            f"(may differ from the original eval)"
        )
    eval_n_prompts = params["eval_n_prompts"]
    max_new = params["max_new"]
    eval_bs = params["eval_batch_size"]
    _, problems = load_eval_splits(eval_modes, eval_n_prompts)
    gen_cfg_eval = GenerationConfig(
        max_new_tokens=max_new, do_sample=True,
        temperature=0.7, top_p=1.0, top_k=20, min_p=0.0, repetition_penalty=1.0,
        num_return_sequences=1, pad_token_id=tok.pad_token_id,
    )
    eval_idxs = list(range(len(problems)))
    # Re-seed before each pass so both passes start sampling from the same RNG state.
    torch.manual_seed(EVAL_GEN_SEED)
    with ablate_quarantine(wrappers):  # knob OFF = the deployed model
        ev = eval_hack_solve(model, tok, problems, eval_idxs, gen_cfg_eval, device, max_new, eval_bs)
    torch.manual_seed(EVAL_GEN_SEED)
    ev_on = eval_hack_solve(model, tok, problems, eval_idxs, gen_cfg_eval, device, max_new, eval_bs)

    out = {
        "schema": RUN_SCHEMA,
        "run_dir": run_dir.name, "model": model_name, "step": meta.get("step"),
        "eval_set": "test", "eval_modes": eval_modes,
        "n": ev["n"], "deploy_hack": ev["hack"], "deploy_vhack": ev["vhack"], "deploy_solve": ev["solve"],
        "deploy_hack_on": ev_on["hack"], "deploy_vhack_on": ev_on["vhack"],
        "deploy_solve_on": ev_on["solve"],
        # Record what was actually used, incl. which values were fast-preset fallbacks.
        "eval_params": {**params, "defaulted": defaulted},
        "by_mode": _by_mode_rates(ev["by_mode"]),
    }
    # Keep the first prior json around: a re-score must not erase the original headline.
    backup = run_dir / "deploy_test.orig.json"
    if not backup.exists():
        backup.write_text(prior_text)
    prior_path.write_text(json.dumps(out, indent=2))

    logger.info(f"prior knob-off hack={prior_eval.get('deploy_hack')} solve={prior_eval.get('deploy_solve')}")
    logger.info(f"FINAL paired test n={ev['n']}: knob-off hack={ev['hack']:.3f} solve={ev['solve']:.3f}; "
                f"knob-on hack={ev_on['hack']:.3f} solve={ev_on['solve']:.3f}")
    for m, d in out["by_mode"].items():
        logger.info(f" {m:14s} hack={d['hack']:.3f} vhack={d['vhack']:.3f} solve={d['solve']:.3f} n={d['n']}")
if __name__ == "__main__":
    # CLI entry: tyro builds the argument parser from main's signature, which makes
    # run_dir a positional argument.
    tyro.cli(main)