"""Reproduce a finished run's paired quarantine-ablated/enabled final-test evaluation.""" from __future__ import annotations import json from pathlib import Path import torch import tyro from tyro.conf import Positional from loguru import logger from safetensors import safe_open from safetensors.torch import load_file from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig from vgrout.antipasto import wrap_model_with_antipasto from vgrout.eval import ablate_quarantine, eval_hack_solve, load_eval_splits from vgrout.train import CACHE_ROOT, EVAL_GEN_SEED from vgrout.run_artifacts import RUN_SCHEMA def main(run_dir: Positional[Path]) -> None: ckpt = run_dir / "train.safetensors" with safe_open(str(ckpt), framework="pt") as f: meta = f.metadata() cfg = json.loads(meta["cfg"]) model_name = meta["model"] device = torch.device("cuda" if torch.cuda.is_available() else "cpu") logger.info(f"re-score {run_dir.name}: model={model_name} step={meta.get('step')}") tok = AutoTokenizer.from_pretrained(model_name) if tok.pad_token_id is None: tok.pad_token = tok.eos_token model = AutoModelForCausalLM.from_pretrained( model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2", ).to(device) model.config.use_cache = False wrappers = wrap_model_with_antipasto(model, model_name, CACHE_ROOT, device, grad_probe=False) delta = load_file(str(ckpt)) delta_hack = load_file(str(run_dir / "train_hack.safetensors")) assert set(delta) == set(wrappers), "checkpoint module set != adapter module set" assert set(delta_hack) == set(wrappers), "quarantine checkpoint module set != adapter module set" for name, t in delta.items(): wrappers[name]["delta_S"].data.copy_(t.to(device, torch.bfloat16)) wrappers[name]["delta_S_hack"].data.copy_(delta_hack[name].to(device, torch.bfloat16)) prior_eval = json.loads((run_dir / "deploy_test.json").read_text()) # Reproduce the original evaluation environment exactly. eval_modes = sorted(prior_eval["by_mode"].keys()) _, problems = load_eval_splits(eval_modes, cfg["eval_n_prompts"]) gen_cfg_eval = GenerationConfig( max_new_tokens=cfg["max_new"], do_sample=True, temperature=0.7, top_p=1.0, top_k=20, min_p=0.0, repetition_penalty=1.0, num_return_sequences=1, pad_token_id=tok.pad_token_id, ) eval_idxs = list(range(len(problems))) torch.manual_seed(EVAL_GEN_SEED) with ablate_quarantine(wrappers): # quarantine ablated = deployed model ev = eval_hack_solve( model, tok, problems, eval_idxs, gen_cfg_eval, device, cfg["max_new"], cfg["eval_batch_size"]) torch.manual_seed(EVAL_GEN_SEED) ev_on = eval_hack_solve( model, tok, problems, eval_idxs, gen_cfg_eval, device, cfg["max_new"], cfg["eval_batch_size"]) out = { "schema": RUN_SCHEMA, "run_dir": run_dir.name, "model": model_name, "step": meta.get("step"), "eval_set": "test", "eval_modes": eval_modes, "n": ev["n"], "hack_deployed": ev["hack"], "vhack_deployed": ev["vhack"], "solve_deployed": ev["solve"], "hack_as_trained": ev_on["hack"], "vhack_as_trained": ev_on["vhack"], "solve_as_trained": ev_on["solve"], "by_mode": {m: {"hack": h / max(1, c), "vhack": v / max(1, c), "solve": s / max(1, c), "n": c} for m, (h, v, s, c) in ev["by_mode"].items()}, } (run_dir / "deploy_test.json").write_text(json.dumps(out, indent=2)) logger.info(f"FINAL paired test n={ev['n']}: quarantine-ablated hack={ev['hack']:.3f} " f"solve={ev['solve']:.3f}; quarantine-enabled hack={ev_on['hack']:.3f} " f"solve={ev_on['solve']:.3f}") for m, d in out["by_mode"].items(): logger.info(f" {m:14s} hack={d['hack']:.3f} vhack={d['vhack']:.3f} solve={d['solve']:.3f} n={d['n']}") if __name__ == "__main__": tyro.cli(main)