mirror of
https://github.com/wassname/isokl_steering_calibration.git
synced 2026-06-27 15:31:07 +08:00
66 lines
2.6 KiB
Python
66 lines
2.6 KiB
Python
"""Smoke test: train + calibrate + branch_pmass on a tiny random model.
|
|
|
|
Pass = runtime sanity. Distinguishing checks:
- measure_kl returns kl > 0 at coeff > 0 (steering DID change the distribution).
- measure_kl returns kl ~= 0 at coeff = 0 (silent-failure detector: if hooks
  leak, the KL between the base and the nominally identical steered run
  would be nonzero).
- branch_pmass is in [0, 1].
- branch_pmass changes between coeff=0 and coeff=large (sneaky-fail catch:
  if pmass were just an identity pass-through it would be invariant).
|
|
"""
|
|
from __future__ import annotations
|
|
import torch
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
from iso_kl_figure import (
|
|
SteeringConfig, MeanDiffC, train, measure_kl, attach, detach,
|
|
)
|
|
from iso_kl_figure.branch_pmass import branch_pmass
|
|
|
|
|
|
# Model identifier handed to `from_pretrained` for both the tokenizer and the
# causal-LM in the smoke test below.
MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM"
|
|
|
|
|
|
def test_pipeline_smoke():
    """Run train -> measure_kl -> branch_pmass end to end on ``MODEL``.

    Checks, in order:

    * ``measure_kl(...)["kl_p95"]`` is below ``1e-3`` with ``coeff = 0`` and
      strictly positive with a large ``coeff``.
    * ``per_t_p95`` and ``per_t_max`` have one entry per requested step ``T``.
    * ``branch_pmass(...)["pmass"]`` is non-empty, has the same length for
      both coefficients, lies entirely in ``[0, 1]``, and at least one entry
      moves by more than ``1e-6`` when ``coeff`` changes.

    Raises:
        AssertionError: if any of the checks above fails.
    """
    n_steps = 4         # passed to measure_kl as T; per_t arrays must match it
    strong_coeff = 5.0  # the "large" coefficient used for the steered runs

    tok = AutoTokenizer.from_pretrained(MODEL)
    if tok.pad_token_id is None:
        # No pad token configured: reuse EOS as the pad token.
        tok.pad_token_id = tok.eos_token_id
    model = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.float32)
    model.eval()
    n_layers = model.config.num_hidden_layers
    # Target a single layer: the middle one.
    cfg = MeanDiffC(coeff=1.0, layers=(n_layers // 2,))

    pos = ["Sure: <answer>", "Yes: <answer>", "Of course: <answer>", "Here: <answer>"]
    neg = ["No way.", "I refuse.", "Cannot help.", "Decline."]
    v = train(model, tok, pos, neg, cfg, batch_size=2, max_length=32)

    # KL must be ~0 at coeff=0 (no leak), and > 0 at large coeff.
    v.cfg.coeff = 0.0
    m0 = measure_kl(v, model, tok, ["hello world"], T=n_steps, device="cpu")
    assert m0["kl_p95"] < 1e-3, f"coeff=0 should give zero KL, got {m0['kl_p95']}"

    v.cfg.coeff = strong_coeff
    m1 = measure_kl(v, model, tok, ["hello world"], T=n_steps, device="cpu")
    assert m1["kl_p95"] > 0.0, "coeff>0 should give nonzero KL"

    # per_t arrays length matches T.
    assert len(m1["per_t_p95"]) == n_steps
    assert len(m1["per_t_max"]) == n_steps

    # branch_pmass: in [0, 1] and varies with coeff.
    pids = tok("hello", return_tensors="pt").input_ids[0]
    # Independent copy of the last two prompt token ids.
    # NOTE(review): presumably the "rolled-out" continuation expected by
    # branch_pmass -- confirm against its signature.
    rolled = pids[-2:].clone()
    suffix = ' {"value": '
    fork = [0, 1, 2]

    v.cfg.coeff = 0.0
    p_zero = branch_pmass(v, model, tok, pids, rolled, fork, suffix,
                          ["true", "false"], device="cpu")
    v.cfg.coeff = strong_coeff
    p_steer = branch_pmass(v, model, tok, pids, rolled, fork, suffix,
                           ["true", "false"], device="cpu")

    pm_zero = p_zero["pmass"]
    pm_steer = p_steer["pmass"]
    # Guard the comparison below: max() over an empty zip raises ValueError,
    # and zip() would silently drop entries if the lengths disagreed.
    assert len(pm_zero) > 0, "branch_pmass returned no pmass values"
    assert len(pm_zero) == len(pm_steer), (
        f"pmass length differs between coeffs: {len(pm_zero)} vs {len(pm_steer)}"
    )

    # Iterate each result separately rather than using `+`, which only means
    # concatenation for list-like containers.
    for pmass in (pm_zero, pm_steer):
        for x in pmass:
            assert 0.0 <= x <= 1.0, f"pmass out of [0,1]: {x}"

    diff = max(abs(a - b) for a, b in zip(pm_zero, pm_steer))
    assert diff > 1e-6, "pmass invariant to coeff -- hook is dead"
|