mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 18:23:57 +08:00
d68c17e7c5
route/routeV final eval now measures both endpoints at n=119 test: knob-off (ablate_quarantine, the deploy headline) AND knob-on (trained model as-is). Writes deploy_hack_on/deploy_solve_on/deploy_vhack_on so the before->after quarantine move is plottable from the deploy set instead of borrowing the val curve's different scale. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
93 lines
4.0 KiB
Python
93 lines
4.0 KiB
Python
"""Offline validation progress curve from a run's saved adapter checkpoints.
|
|
|
|
Loads the model once, then scores ckpt_update0000/0010/... on the periodic validation split.
|
|
RouteV records both knob-on/train and knob-off/deploy; vanilla records one pass.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import tyro
|
|
from loguru import logger
|
|
from safetensors import safe_open
|
|
from safetensors.torch import load_file
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
|
from tyro.conf import Positional
|
|
|
|
from vgrout.antipasto import wrap_model_with_antipasto, wrap_model_with_lora_frozen_b
|
|
from vgrout.eval import ablate_quarantine, eval_hack_solve, load_eval_splits
|
|
from vgrout.train import CACHE_ROOT, EVAL_GEN_SEED
|
|
|
|
|
|
def _load(wrappers: dict, kept_path: Path, hack_path: Path) -> None:
    """Copy one checkpoint's adapter tensors into the wrapped model, in place.

    ``kept_path`` supplies every wrapper's ``delta_S`` and ``hack_path`` its
    ``delta_S_hack``. Both files must contain exactly one tensor per key of
    ``wrappers``. Each source tensor is cast to the destination tensor's
    dtype/device before copying.

    Args:
        wrappers: module name -> dict holding the ``delta_S`` / ``delta_S_hack``
            tensors to overwrite.
        kept_path: safetensors file with the kept-adapter tensors.
        hack_path: safetensors file with the hack-adapter tensors.

    Raises:
        ValueError: if either file's tensor names differ from ``wrappers``' keys.
    """
    kept, hack = load_file(str(kept_path)), load_file(str(hack_path))
    expected = set(wrappers)
    # Explicit check instead of `assert`: an assert is stripped under
    # `python -O`, after which extra tensors in a file would be silently
    # ignored and missing ones would surface as an unexplained KeyError.
    for path, tensors in ((kept_path, kept), (hack_path, hack)):
        found = set(tensors)
        if found != expected:
            missing = sorted(expected - found)
            unexpected = sorted(found - expected)
            raise ValueError(
                f"{path} does not match the wrapped modules: "
                f"{len(missing)} missing (e.g. {missing[:3]}), "
                f"{len(unexpected)} unexpected (e.g. {unexpected[:3]})"
            )
    for name, info in wrappers.items():
        # `.data.copy_` writes into the existing tensor storage rather than
        # rebinding the dict entry, so anything already referencing it sees
        # the new values.
        info["delta_S"].data.copy_(kept[name].to(info["delta_S"]))
        info["delta_S_hack"].data.copy_(hack[name].to(info["delta_S_hack"]))
|
|
|
|
|
|
def _update_number(ckpt_path: Path) -> int:
    """Return the update count encoded in a ``ckpt_update<N>.safetensors`` filename."""
    return int(ckpt_path.stem.removeprefix("ckpt_update"))


def _checkpoint_paths(run_dir: Path) -> list[Path]:
    """List the run's kept-adapter checkpoints, oldest first.

    ``*_hack.safetensors`` partner files are excluded; ``main`` derives each
    one from its kept checkpoint. Ordering is by the parsed update number, not
    by path string. Lexical order only agrees with numeric order while every
    number has the same zero-padded width, so e.g. ``ckpt_update10000`` would
    otherwise sort before ``ckpt_update9990``.
    """
    ckpts = sorted(
        (p for p in run_dir.glob("ckpt_update*.safetensors")
         if not p.stem.endswith("_hack")),
        key=_update_number,
    )
    assert ckpts, f"no ckpt_update*.safetensors in {run_dir}"
    return ckpts


def _load_base_model(model_name: str, device: torch.device):
    """Load the tokenizer and base causal LM for ``model_name`` onto ``device``.

    CPU uses float32 + SDPA attention; otherwise bfloat16 + FlashAttention-2.
    Returns ``(tokenizer, model)``.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token_id is None:
        # main() passes tok.pad_token_id to GenerationConfig, so ensure one exists.
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        dtype=torch.float32 if device.type == "cpu" else torch.bfloat16,
        attn_implementation="sdpa" if device.type == "cpu" else "flash_attention_2",
    ).to(device)
    model.config.use_cache = False
    return tok, model


def _wrap_adapter(model, model_name: str, cfg: dict, device: torch.device) -> dict:
    """Wrap ``model`` with the adapter type recorded in the run's ``cfg``; return the wrappers."""
    if cfg["adapter"] == "lora_frozen_b":
        return wrap_model_with_lora_frozen_b(
            model, model_name, r=cfg["lora_r"], b_seed=cfg["lora_b_seed"], grad_probe=False)
    assert cfg["adapter"] == "antipasto", f"unsupported adapter {cfg['adapter']!r}"
    return wrap_model_with_antipasto(model, model_name, CACHE_ROOT, device, grad_probe=False)


def _seeded_eval(model, tok, problems, idxs, gen_cfg, device, cfg: dict):
    """Run one ``eval_hack_solve`` pass after resetting torch's RNG to ``EVAL_GEN_SEED``.

    Reseeding before every pass means each checkpoint, and both passes of a
    route run, start sampling from the same RNG state.
    """
    torch.manual_seed(EVAL_GEN_SEED)
    return eval_hack_solve(model, tok, problems, idxs, gen_cfg, device, cfg["max_new"],
                           cfg["eval_batch_size"])


def main(run_dir: Positional[Path]) -> None:
    """Write an offline validation curve for every saved adapter checkpoint in ``run_dir``.

    The base model is loaded once. Each ``ckpt_update<N>`` checkpoint (plus its
    ``_hack`` partner) is then copied into the adapter and scored on the eval
    split named by ``run_dir/deploy_test.json``. One JSON row per checkpoint is
    appended to ``run_dir/eval_checkpoint_curve.jsonl``. The file is truncated
    at start and reopened per row, so rows already written remain if a later
    checkpoint fails.

    For ``route``/``routeV`` runs, ``train_*`` scores the model as loaded
    (knob on). ``deploy_*`` scores it inside ``ablate_quarantine`` (knob off).
    For other interventions a single pass fills both.
    """
    ckpts = _checkpoint_paths(run_dir)
    # The run config and base model name live in the safetensors metadata;
    # read them from the newest checkpoint.
    with safe_open(str(ckpts[-1]), framework="pt") as f:
        meta = f.metadata()
    cfg = json.loads(meta["cfg"])
    model_name = meta["model"]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tok, model = _load_base_model(model_name, device)
    wrappers = _wrap_adapter(model, model_name, cfg, device)

    eval_modes = json.loads((run_dir / "deploy_test.json").read_text())["eval_modes"]
    problems, _ = load_eval_splits(eval_modes, cfg["eval_n_prompts"])
    idxs = list(range(len(problems)))
    gen_cfg = GenerationConfig(
        max_new_tokens=cfg["max_new"], do_sample=True, temperature=0.7, top_p=1.0,
        top_k=20, min_p=0.0, repetition_penalty=1.0, num_return_sequences=1,
        pad_token_id=tok.pad_token_id,
    )

    out_path = run_dir / "eval_checkpoint_curve.jsonl"
    out_path.write_text("")
    is_route = cfg["intervention"] in ("route", "routeV")
    for kept_path in ckpts:
        hack_path = kept_path.with_name(kept_path.stem + "_hack.safetensors")
        _load(wrappers, kept_path, hack_path)
        train = _seeded_eval(model, tok, problems, idxs, gen_cfg, device, cfg)
        if is_route:
            # Knob-off (deploy) pass: same weights, scored inside ablate_quarantine.
            with ablate_quarantine(wrappers):
                deploy = _seeded_eval(model, tok, problems, idxs, gen_cfg, device, cfg)
        else:
            deploy = train
        row = {"updates_completed": _update_number(kept_path), "n": deploy["n"],
               "train_hack": train["hack"], "train_solve": train["solve"],
               "deploy_hack": deploy["hack"], "deploy_solve": deploy["solve"]}
        with out_path.open("a") as f:
            f.write(json.dumps(row) + "\n")
        logger.info(row)
    logger.info(f"wrote {out_path}")
|
|
|
|
|
|
if __name__ == "__main__":
    # tyro derives the command line from main()'s signature: one positional RUN_DIR.
    entrypoint = main
    tyro.cli(entrypoint)
|