Files
evil_MoE/scripts/eval_checkpoint_curve.py
T
wassname c3af6cc03c rename: deployed/as_trained policy views, kill 'knob' (schema paired_final_v2)
Disambiguate the overloaded deploy/train/knob vocabulary (paper-consistent:
'quarantine' + 'ablated' + 'deployed' all match Cloud et al.). One opposite each:
- policy view: hack_deployed/solve_deployed (quarantine ablated, ships) vs
  hack_as_trained/solve_as_trained (quarantine attached). Unifies the old split
  deploy_hack (JSON) vs hack_deploy (table key) into one name.
- 'knob' -> 'quarantine'/'adapter' throughout comments and log strings.
- train/test reserved for the DATA split only.
Bump RUN_SCHEMA v1->v2 so old deploy_test.json files are skipped (not crashed) by
completed_runs. CLI flags untouched (queued jobs unaffected). Fixed two
replace_all collision bugs (hack_deploy substring of hack_deployed -> deployeded)
and the missed eval_curve writer (eval_checkpoint_curve.py) + readers
(results_deploy.py). Smoke green: v2 written + read; gates pass.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
2026-06-10 05:26:51 +00:00

93 lines
4.0 KiB
Python

"""Offline validation progress curve from a run's saved adapter checkpoints.
Loads the model once, then scores ckpt_update0000/0010/... on the periodic validation split.
RouteV records both policy views: as-trained (quarantine attached) and deployed (quarantine ablated); vanilla records one pass.
"""
from __future__ import annotations
import json
from pathlib import Path
import torch
import tyro
from loguru import logger
from safetensors import safe_open
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from tyro.conf import Positional
from vgrout.antipasto import wrap_model_with_antipasto, wrap_model_with_lora_frozen_b
from vgrout.eval import ablate_quarantine, eval_hack_solve, load_eval_splits
from vgrout.train import CACHE_ROOT, EVAL_GEN_SEED
def _load(wrappers: dict, kept_path: Path, hack_path: Path) -> None:
    """Copy one saved checkpoint pair into the live adapter tensors, in place.

    Args:
        wrappers: Adapter info keyed by name; each value must provide
            ``delta_S`` and ``delta_S_hack`` tensors.
        kept_path: Safetensors file whose entries are copied into ``delta_S``.
        hack_path: Safetensors file whose entries are copied into
            ``delta_S_hack``.

    Raises:
        ValueError: If either file's tensor names differ from the wrapper names.
    """
    kept = load_file(str(kept_path))
    hack = load_file(str(hack_path))

    # Explicit raise rather than `assert`: the check must survive `python -O`,
    # otherwise extra tensors in a checkpoint would be silently ignored.
    expected = set(wrappers)
    for label, path, tensors in (("kept", kept_path, kept), ("hack", hack_path, hack)):
        found = set(tensors)
        if found != expected:
            raise ValueError(
                f"{label} checkpoint {path} does not match the wrapped modules: "
                f"missing={sorted(expected - found)} "
                f"unexpected={sorted(found - expected)}"
            )

    for name, info in wrappers.items():
        # Tensor.to(other) casts to the destination's dtype and device before the copy.
        info["delta_S"].data.copy_(kept[name].to(info["delta_S"]))
        info["delta_S_hack"].data.copy_(hack[name].to(info["delta_S_hack"]))
def main(run_dir: Positional[Path]) -> None:
    """Score every saved adapter checkpoint in ``run_dir`` and write a JSONL curve.

    The model named in the newest checkpoint's metadata is loaded once and
    wrapped with the adapter type recorded in the saved config. Each
    ``ckpt_updateNNNN.safetensors`` (with its ``_hack`` sibling) is then loaded
    into the wrappers and evaluated. One row per checkpoint is appended to
    ``run_dir / "eval_checkpoint_curve.jsonl"``, which is truncated at start.

    For ``routeV`` runs each checkpoint is scored twice with the same seed:
    once as-is and once inside ``ablate_quarantine``. For any other
    intervention the single pass fills both the as-trained and deployed columns.

    Args:
        run_dir: Run directory holding the checkpoints and ``deploy_test.json``.

    Raises:
        FileNotFoundError: If no non-``_hack`` checkpoint is found in ``run_dir``.
        ValueError: If the checkpoint has no metadata, a checkpoint name has no
            integer update index, or the configured adapter is unknown.
    """
    def update_index(path: Path) -> int:
        # "ckpt_update0010" -> 10
        return int(path.stem.removeprefix("ckpt_update"))

    # Order by the numeric update index, not by path string: lexicographic order
    # puts "ckpt_update10000" before "ckpt_update9999", which would misorder the
    # curve and pick the wrong "latest" checkpoint for metadata.
    ckpts = sorted(
        (p for p in run_dir.glob("ckpt_update*.safetensors")
         if not p.stem.endswith("_hack")),
        key=update_index,
    )
    if not ckpts:
        raise FileNotFoundError(f"no ckpt_update*.safetensors in {run_dir}")

    # The run config and base model name are stored in the checkpoint header.
    with safe_open(str(ckpts[-1]), framework="pt") as f:
        meta = f.metadata()
    if meta is None:
        raise ValueError(f"{ckpts[-1]} has no safetensors metadata (need 'cfg' and 'model')")
    cfg = json.loads(meta["cfg"])
    model_name = meta["model"]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    on_cpu = device.type == "cpu"

    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token_id is None:
        # Batched generation needs a pad token; fall back to EOS.
        tok.pad_token = tok.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        dtype=torch.float32 if on_cpu else torch.bfloat16,
        attn_implementation="sdpa" if on_cpu else "flash_attention_2",
    ).to(device)
    model.config.use_cache = False

    adapter = cfg["adapter"]
    if adapter == "lora_frozen_b":
        wrappers = wrap_model_with_lora_frozen_b(
            model, model_name, r=cfg["lora_r"], b_seed=cfg["lora_b_seed"], grad_probe=False)
    elif adapter == "antipasto":
        wrappers = wrap_model_with_antipasto(model, model_name, CACHE_ROOT, device, grad_probe=False)
    else:
        raise ValueError(f"unsupported adapter {adapter!r}")

    eval_modes = json.loads((run_dir / "deploy_test.json").read_text())["eval_modes"]
    problems, _ = load_eval_splits(eval_modes, cfg["eval_n_prompts"])
    idxs = list(range(len(problems)))
    gen_cfg = GenerationConfig(
        max_new_tokens=cfg["max_new"], do_sample=True, temperature=0.7, top_p=1.0,
        top_k=20, min_p=0.0, repetition_penalty=1.0, num_return_sequences=1,
        pad_token_id=tok.pad_token_id,
    )

    def score() -> dict:
        # Reseed before every pass so each pass starts from the same RNG state.
        torch.manual_seed(EVAL_GEN_SEED)
        return eval_hack_solve(model, tok, problems, idxs, gen_cfg, device, cfg["max_new"],
                               cfg["eval_batch_size"])

    out_path = run_dir / "eval_checkpoint_curve.jsonl"
    out_path.write_text("")  # start from an empty file; rows are appended below
    is_route = cfg["intervention"] == "routeV"

    for kept_path in ckpts:
        hack_path = kept_path.with_name(kept_path.stem + "_hack.safetensors")
        _load(wrappers, kept_path, hack_path)

        as_trained = score()
        if is_route:
            with ablate_quarantine(wrappers):
                deployed = score()
        else:
            # Single pass: both policy views share the same numbers.
            deployed = as_trained

        row = {"updates_completed": update_index(kept_path), "n": deployed["n"],
               "hack_as_trained": as_trained["hack"], "solve_as_trained": as_trained["solve"],
               "hack_deployed": deployed["hack"], "solve_deployed": deployed["solve"]}
        # Append per checkpoint so finished rows are on disk if a later one fails.
        with out_path.open("a") as f:
            f.write(json.dumps(row) + "\n")
        logger.info(row)

    logger.info(f"wrote {out_path}")
# Script entry point: tyro.cli builds the command line from main's signature
# (run_dir is annotated as a positional argument).
if __name__ == "__main__":
    tyro.cli(main)