Files
evil_MoE/scripts/eval_checkpoint_curve.py
T
wassname b53043cec3 refactor: extract train_config.py + run_artifacts.py from train.py; slim results scripts
Cleanup by a prior agent, verified green here: 'just smoke' (erase arm)
runs end-to-end and all four wired gates pass (verify_rewards 52/52,
verify_eval_gap, verify_partition, verify_science_invariants).

- train.py -318 lines: Config dataclass -> train_config.py, checkpoint/
  deploy-artifact IO -> run_artifacts.py.
- results.py / results_deploy.py / probe_distill.py slimmed.
- drop stale derived csvs under out/figs (a5_generalisation, dyn_*,
  substrate_aggregate, train_vs_deploy_60).
- gitignore /.pi/ panel scratch.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
2026-06-09 13:34:50 +00:00

93 lines
4.0 KiB
Python

"""Offline validation progress curve from a run's saved adapter checkpoints.
Loads the model once, then scores ckpt_update0000/0010/... on the periodic validation split.
RouteV records both knob-on/train and knob-off/deploy; vanilla records one pass.
"""
from __future__ import annotations
import json
from pathlib import Path
import torch
import tyro
from loguru import logger
from safetensors import safe_open
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from tyro.conf import Positional
from vgrout.antipasto import wrap_model_with_antipasto, wrap_model_with_lora_frozen_b
from vgrout.eval import ablate_quarantine, eval_hack_solve, load_eval_splits
from vgrout.train import CACHE_ROOT, EVAL_GEN_SEED
def _load(wrappers: dict, kept_path: Path, hack_path: Path) -> None:
kept, hack = load_file(str(kept_path)), load_file(str(hack_path))
assert set(kept) == set(wrappers) == set(hack)
for name, info in wrappers.items():
info["delta_S"].data.copy_(kept[name].to(info["delta_S"]))
info["delta_S_hack"].data.copy_(hack[name].to(info["delta_S_hack"]))
def main(run_dir: Positional[Path]) -> None:
ckpts = sorted(p for p in run_dir.glob("ckpt_update*.safetensors")
if not p.stem.endswith("_hack"))
assert ckpts, f"no ckpt_update*.safetensors in {run_dir}"
with safe_open(str(ckpts[-1]), framework="pt") as f:
meta = f.metadata()
cfg = json.loads(meta["cfg"])
model_name = meta["model"]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tok = AutoTokenizer.from_pretrained(model_name)
if tok.pad_token_id is None:
tok.pad_token = tok.eos_token
model = AutoModelForCausalLM.from_pretrained(
model_name,
dtype=torch.float32 if device.type == "cpu" else torch.bfloat16,
attn_implementation="sdpa" if device.type == "cpu" else "flash_attention_2",
).to(device)
model.config.use_cache = False
if cfg["adapter"] == "lora_frozen_b":
wrappers = wrap_model_with_lora_frozen_b(
model, model_name, r=cfg["lora_r"], b_seed=cfg["lora_b_seed"], grad_probe=False)
else:
assert cfg["adapter"] == "antipasto"
wrappers = wrap_model_with_antipasto(model, model_name, CACHE_ROOT, device, grad_probe=False)
eval_modes = json.loads((run_dir / "deploy_test.json").read_text())["eval_modes"]
problems, _ = load_eval_splits(eval_modes, cfg["eval_n_prompts"])
idxs = list(range(len(problems)))
gen_cfg = GenerationConfig(
max_new_tokens=cfg["max_new"], do_sample=True, temperature=0.7, top_p=1.0,
top_k=20, min_p=0.0, repetition_penalty=1.0, num_return_sequences=1,
pad_token_id=tok.pad_token_id,
)
out_path = run_dir / "eval_checkpoint_curve.jsonl"
out_path.write_text("")
is_route = cfg["intervention"] == "routeV"
for kept_path in ckpts:
hack_path = kept_path.with_name(kept_path.stem + "_hack.safetensors")
_load(wrappers, kept_path, hack_path)
updates = int(kept_path.stem.removeprefix("ckpt_update"))
torch.manual_seed(EVAL_GEN_SEED)
train = eval_hack_solve(model, tok, problems, idxs, gen_cfg, device, cfg["max_new"],
cfg["eval_batch_size"])
if is_route:
torch.manual_seed(EVAL_GEN_SEED)
with ablate_quarantine(wrappers):
deploy = eval_hack_solve(model, tok, problems, idxs, gen_cfg, device, cfg["max_new"],
cfg["eval_batch_size"])
else:
deploy = train
row = {"updates_completed": updates, "n": deploy["n"],
"train_hack": train["hack"], "train_solve": train["solve"],
"deploy_hack": deploy["hack"], "deploy_solve": deploy["solve"]}
with out_path.open("a") as f:
f.write(json.dumps(row) + "\n")
logger.info(row)
logger.info(f"wrote {out_path}")
if __name__ == "__main__":
    # tyro builds the CLI from main's signature; run_dir is a positional argument.
    tyro.cli(main)