metric fix: auth_nats = diagonal log(p) not raw forced-choice logit

The trait metric was taking the diagonal of tinymfv's raw pre-softmax BMA
`score` logit (unnormalised), giving base Authority ~-5 and absurd 8-nat
swings, then comparing those to steering-lite's 0.5-2 nat reference -- which
is a DIFFERENT metric (loading-weighted Delta-logit of binary p(is-wrong)).
Wrong scale, wrong comparison.

Fix: auth_nats = mean log p[authority] on authority-defiance vignettes (the
NORMALIZED choice logprob, the diagonal of the softmax `p`). Base ~log(0.099)
= -2.3, real shifts ~1-3 nats. DRY: evaluate_model now calls foundation_nats.

Also:
- diag_stages: steer at operating point c=0.5 (c=1 collapses coherence to
  0.05), add coh_cost = |dCoh|/|dAuth| (coherence lost per nat of behaviour)
  to answer "is the adapter a better pareto than raw steering?".
- diag_csweep: drop the bogus 0.5-2 steering-lite anchor; SocialNorms
  co-moving with Authority is expected (both binding foundations), not collapse.
- gitignore out/ and results.tsv (experiment outputs, stale schema).
- personas docs (steering-lite proper-pair rules), spec Plans B/C/D, journal.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-04 14:25:40 +08:00
parent 6b15a8b2ae
commit 4568ddf491
17 changed files with 1814 additions and 48 deletions
+16 -6
View File
@@ -15,15 +15,24 @@ class RunConfig:
fast_dev_model: str = "wassname/qwen3-5lyr-tiny-random"
dtype: Literal["bfloat16", "float16", "float32"] = "bfloat16"
# ── trait / steering vector ──
trait: str = (
"You do not defer to authority and instead stick to principle "
"no matter your involvement."
# ── trait / steering vector (PROPER contrastive persona pair, steering-lite benchmark) ──
# Direct opposites, same Care+Authority conflict axis, no negation (docs/personas/
# how_to_write_personas.md rules 1,3,6). pos-neg isolates the trait axis; "helpful
# assistant" as neg would mix in "being a persona at all" + style.
pos_persona: str = (
"You are someone who looks after others' wellbeing even when it means defying authority."
)
neutral: str = "You are a helpful assistant."
neg_persona: str = (
"You are someone who defers to authority even when others' wellbeing suffers for it."
)
# Neutral system prompt for GENERATION: the steering vector (not the prompt) must
# carry the trait, so completions are generated with no persona.
gen_system: str = "You are a helpful assistant."
steer_layers: tuple[float, float] = (0.45, 0.55) # NARROW band for the vector (raw mean-diff compounds across layers)
layer_range: tuple[float, float] = (0.0, 1.0) # BROAD band for the LoRA (train trait into many layers)
alphas: tuple[float, ...] = (0.25, 0.5, 1.0, 2.0) # raw-vector multiples to sweep; filter picks usable C
n_extract_pairs: int = 256 # contrastive pairs for the vector (steering-lite uses 256 DIVERSE suffixes, not domain dilemmas)
extract_data: str = "data/branching_suffixes.json" # diverse contexts for extraction (550 suffixes, 10 categories)
# ── generation + filter (U1) ──
n_prompts: int = 16
@@ -45,7 +54,7 @@ class RunConfig:
# ── eval (tinymfv) ──
eval_vignettes: int | None = None # None = all Clifford-2015 vignettes
eval_think_tokens: int = 64 # tinymfv default; 10x faster than 256, within bf16 noise
eval_think_tokens: int = 128 # 64 gives noisy mean-mass shift (journal plan C); 128 for reliable small-dAuth signal
# ── loop (U3) ──
n_rounds: int = 4
@@ -56,6 +65,7 @@ class RunConfig:
TINY = dict(
n_prompts=4,
n_extract_pairs=8,
n_keep=3,
gen_max_new_tokens=32,
max_len=128,
+65 -17
View File
@@ -1,19 +1,50 @@
"""tinymfv eval -> {auth, care, coherence, ppx_json}.
"""tinymfv eval -> trait metric in NATS (auth logp) + coherence canary.
auth/care are the model's mean probability on the Authority/Care moral
foundations (the trait axis we move). coherence = mean_pmass_allowed (the
forced-choice canary). These are kept distinct: we shift auth on purpose,
coherence must not collapse.
The headline trait metric is `auth_nats` = the model's mean forced-choice logit
for "authority" being the violation type, over Authority-violation vignettes
(the diagonal of tinymfv per-row `score`, a 7-way pre-softmax fwd/rev-averaged
logit). tinymfv's forced choice ASSUMES wrongness and asks WHICH foundation, so
this is an attribution logit, not a p(is-wrong) logit.
SCALE WARNING: this is NOT steering-lite's auth_sep (its loading-weighted Δlogit
of binary p(is-wrong), reference 0.5-2 nats). tinymfv's forced-choice logit lives
on a different, much larger scale: base Authority ~-5 on classic n=132, and a
real steering shift is several nats. Do NOT compare auth_nats deltas to the
steering-lite 0.5-2 reference. Judge the WITHIN-tinymfv delta:
auth_sep = base_auth_nats - steered_auth_nats (POSITIVE = authority-violations
look less wrong = the trait). Surgicality = |Δauth| relative to |Δcare|; note
SocialNorms co-moves with Authority (both binding/conformity foundations).
Coherence stays in prob (it's a mass), not nats.
"""
import math
import numpy as np
import tinymfv
from loguru import logger
from steer_heal.config import RunConfig
def foundation_nats(rep) -> dict:
"""Mean choice-LOGPROB per foundation on ITS OWN vignettes (the diagonal of
the per-row 7-way softmax `p`), from a return_per_row=True rep. Reads as 'log
prob the model attributes a violation of foundation F to foundation F'.
NOTE: log(p), the NORMALIZED choice logprob (<=0, nats), NOT the raw pre-softmax
`score` logit (unnormalized BMA, base ~-5, absurd swings). Authority base
~log(0.099)=-2.3; steering 'do not defer to authority' lowers log p[authority]
on authority-defiance vignettes. Judge auth_sep = base - steered (a Δlogprob,
same family as steering-lite's Δlogit); a real shift is ~1-3 nats here."""
coarse_order = list(rep["profile"]["foundation"]) # aligns with each per-row p 7-vec
out = {}
for f in coarse_order:
idx = coarse_order.index(f)
rows = [r for r in rep["per_row"] if r["foundation_coarse"] == f]
out[f] = float(np.mean([np.log(r["p"][idx]) for r in rows])) if rows else float("nan")
return out
def evaluate_model(model, tok, cfg: RunConfig) -> dict:
rep = tinymfv.evaluate(
model, tok, name="classic",
@@ -22,23 +53,40 @@ def evaluate_model(model, tok, cfg: RunConfig) -> dict:
max_think_tokens=cfg.eval_think_tokens,
batch_size=8,
device=model.device,
return_per_row=True,
)
prof = rep["profile"] # pandas: foundation, human, model, model_T
p = dict(zip(prof["foundation"], prof["model"]))
# The trait "less deference to authority" moves SocialNorms DOWN and Care UP
# on gemma-3-1b-it (Authority is degenerate ~0; see RESEARCH_JOURNAL 2026-06-04).
# Report all foundations so we never lose the axis that actually moves.
# SHOULD: under steering, socialnorms drops and care rises; coherence holds.
prof = rep["profile"] # pandas: foundation (coarse), human, model, model_T
p = dict(zip(prof["foundation"], prof["model"])) # mean prob mass (kept for the map plot)
# NAT metric (single source: foundation_nats) = diagonal choice-logprob
# log p[F] on F-violation vignettes. Authority is the target: steering "do not
# defer to authority" LOWERS auth_nats on authority-defiance vignettes.
nats = foundation_nats(rep)
out = {
"socialnorms": float(p["SocialNorms"]), # trait axis: DOWN = more trait
"care": float(p["Care"]), # trait axis: UP = more trait
"auth_nats": nats["Authority"], # TARGET (nats): DOWN = trait
"socialnorms_nats": nats["SocialNorms"],
"care_nats": nats["Care"],
"fairness_nats": nats["Fairness"],
# prob-mass profile, only for the Care-vs-SocialNorms map plot (NOT the trait metric)
"socialnorms": float(p["SocialNorms"]),
"care": float(p["Care"]),
"auth": float(p["Authority"]),
"fairness": float(p["Fairness"]),
"liberty": float(p["Liberty"]),
"coherence": float(rep["mean_pmass_allowed"]),
"ppx_json": float(math.exp(rep["mean_nll_json"])),
"top1_acc": float(rep["top1_acc"]),
}
logger.info(f"eval: socialnorms={out['socialnorms']:.3f} care={out['care']:.3f} "
f"coherence={out['coherence']:.3f} ppx={out['ppx_json']:.1f}")
# SHOULD (trait, nats): steering "do not defer to authority" LOWERS auth_nats
# (= log p[authority] on authority-defiance vignettes; base ~-2.3). Judge the
# WITHIN-tinymfv delta auth_sep = base - steered; a real shift is ~1-3 nats on
# this log(p) scale (NOT steering-lite's 0.5-2, a different p(is-wrong) metric).
# SocialNorms co-moves with Authority (both binding/conformity foundations) -- that
# is expected, not broad collapse. Broad permissivizing = Care/Fairness drop AS MUCH.
# SHOULD (coherence = p_any_ans = mean_pmass_allowed): base/c=0 MUST be ~1.0. >=0.95 mild,
# 0.85-0.95 degraded, <0.85 broken. We want the auth_nats shift at coherence >=0.95.
coh = out["coherence"]
tag = "coherent" if coh >= 0.95 else "degraded" if coh >= 0.85 else "BROKEN"
logger.info(f"eval: auth_nats↓={out['auth_nats']:+.2f} (socnorm={out['socialnorms_nats']:+.2f} "
f"care={out['care_nats']:+.2f} fair={out['fairness_nats']:+.2f}) "
f"coherence→={coh:.3f} ({tag}) ppx↓={out['ppx_json']:.1f}")
return out
+41 -6
View File
@@ -11,6 +11,7 @@ from collections import Counter
import torch
from loguru import logger
from tqdm import tqdm
from steer_heal.config import RunConfig
@@ -47,7 +48,7 @@ def ppl_under_base(model, tok, prompt: str, completion: str) -> float:
def filter_completions(model, tok, comps: list[dict], cfg: RunConfig):
"""Return (kept[:n_keep], scored) where scored has per-item ppl/rep/narrate/keep."""
scored = []
for c in comps:
for c in tqdm(comps, desc="filter ppl", mininterval=120, maxinterval=120):
rf = rep_frac(c["completion"])
nar = bool(NARRATE.search(c["completion"]))
ppl = ppl_under_base(model, tok, c["prompt"], c["completion"])
@@ -65,17 +66,24 @@ def _log_filter_report(scored: list[dict], cfg: RunConfig) -> None:
df = pl.DataFrame([{k: s[k] for k in ("alpha", "ppl", "rep", "narrates", "keep")} for s in scored])
g = (df.group_by("alpha")
.agg(pl.col("ppl").mean().round(1).alias("ppl_mean"),
pl.col("keep").mean().round(2).alias("kept_frac"),
.agg(pl.col("ppl").mean().round(1).alias("ppl_mean↑"),
pl.col("keep").mean().round(2).alias("kept_frac↓"),
pl.len().alias("n"))
.sort("alpha"))
logger.info(
"\nfilter columns:\n"
" alpha = raw-vector multiple (steering strength)\n"
" ppl_mean↑ = mean perplexity-under-original of the completions (↑ with alpha = more incoherent)\n"
" kept_frac↓ = fraction passing the filter (↓ with alpha = more dropped)\n"
" n = completions at this alpha"
)
logger.info(
"SHOULD (Q0 filter): ppl_mean RISES with alpha (stronger steering = less coherent) and "
"kept_frac FALLS. If kept_frac is flat across alpha, the filter is inert / threshold wrong "
"and we CANNOT filter. If ppl_mean is flat, steering did not inject incoherency."
)
logger.info("\nfilter vs steering strength:\n" +
tabulate(g.to_pandas(), headers="keys", tablefmt="github", floatfmt=".2f"))
tabulate(g.to_pandas(), headers="keys", tablefmt="github", floatfmt=".2f") + "\n")
lo = min(scored, key=lambda s: s["alpha"])
hi = max(scored, key=lambda s: s["alpha"])
# Full, untruncated dumps so we can judge coherence + trait ourselves (token-efficient-logging).
@@ -84,5 +92,32 @@ def _log_filter_report(scored: list[dict], cfg: RunConfig) -> None:
f"\nCOMPLETION: {lo['completion']}")
logger.info(f"\n=== STEER SAMPLE @alpha={hi['alpha']:g} ppl={hi['ppl']:.0f} keep={hi['keep']} "
f"(high C, SHOULD be garbage if over-steered) ===\nCOMPLETION: {hi['completion']}")
logger.info(f"filter kept {len([s for s in scored if s['keep']])}/{len(scored)} "
f"(ppl<{cfg.ppl_tau:g}, rep<{cfg.rep_tau}, not-narrate)")
# GATE 2 qualitative: the completions straddling the ppl threshold (the actual
# decision boundary), so we can judge by eye whether the cut lands between
# coherent+trait and gibberish, or slices through coherent trait-laden text.
finite = sorted((s for s in scored if s["ppl"] != float("inf")), key=lambda s: s["ppl"])
just_kept = [s for s in finite if s["ppl"] < cfg.ppl_tau][-2:]
just_dropped = [s for s in finite if s["ppl"] >= cfg.ppl_tau][:2]
logger.info(
f"\n=== BORDERLINE samples around ppl_tau={cfg.ppl_tau:g} (judge the cut by eye): "
"SHOULD: just-kept still read coherent + on-trait; just-dropped read as breaking down. "
"If just-kept are base-like (no trait) -> filter keeps base, not trait. If just-dropped "
"still read coherent+on-trait -> threshold too strict, raise ppl_tau ==="
)
for s in just_kept:
logger.info(f"\n-- JUST-KEPT alpha={s['alpha']:g} ppl={s['ppl']:.0f} --\n{s['completion']}")
for s in just_dropped:
logger.info(f"\n-- JUST-DROPPED alpha={s['alpha']:g} ppl={s['ppl']:.0f} --\n{s['completion']}")
# per-criterion drop counts (overlapping): which filter is doing the work?
n_ppl = sum(s["ppl"] >= cfg.ppl_tau for s in scored)
n_rep = sum(s["rep"] >= cfg.rep_tau for s in scored)
n_nar = sum(s["narrates"] for s in scored)
n_kept = sum(s["keep"] for s in scored)
logger.info(
f"filter kept {n_kept}/{len(scored)}. dropped by (overlapping): "
f"coherence ppl>={cfg.ppl_tau:g}: {n_ppl}, repetition rep>={cfg.rep_tau}: {n_rep}, "
f"persona-leak narrate: {n_nar}. "
f"SHOULD: at high alpha coherence-ppl drops the most (steering breaks fluency). If "
f"persona-leak dominates, the model is NARRATING the trait not enacting it; if repetition "
f"dominates, steering collapsed to loops not incoherence."
)
+15 -1
View File
@@ -11,12 +11,26 @@ It is free to log almost everything to events.jsonl; do it.
"""
import dataclasses
import math
import sys
from pathlib import Path
import srsly
from loguru import logger
def _json_safe(x):
"""JSON cannot encode nan/inf. Map non-finite floats to None at the
serialization boundary (a foundation with zero eval vignettes -> nan logp;
real 132-vignette runs never hit this, tiny-dev 4-vignette runs do)."""
if isinstance(x, float) and not math.isfinite(x):
return None
if isinstance(x, dict):
return {k: _json_safe(v) for k, v in x.items()}
if isinstance(x, list):
return [_json_safe(v) for v in x]
return x
REPO = Path(__file__).resolve().parents[2]
RESULTS_TSV = REPO / "results.tsv"
@@ -32,7 +46,7 @@ def make_run_dir(ts: str, slug: str, cfg) -> Path:
def log_event(run_dir: Path, **rec) -> None:
# append one jsonl line; events.jsonl is the full machine-readable trace.
srsly.write_jsonl(run_dir / "events.jsonl", [rec], append=True)
srsly.write_jsonl(run_dir / "events.jsonl", [_json_safe(rec)], append=True)
def append_result(cfg, metrics: dict) -> None:
+23 -9
View File
@@ -21,7 +21,7 @@ from steer_heal.filter import filter_completions, ppl_under_base
from steer_heal.heal import heal_round
from steer_heal.io import append_result, log_event, make_run_dir
from steer_heal.plot import write_map
from steer_heal.steering import generate_plain, generate_steered, teacher_vec
from steer_heal.steering import generate_plain, generate_steered, gpu_mem, teacher_vec
from steer_heal.ws.bake import baked
REPO = Path(__file__).resolve().parents[2]
@@ -69,21 +69,24 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
v0_flat = None # round-0 direction, for the Q3 cosine
rounds = []
for rnd in range(cfg.n_rounds):
logger.info(f"\n=== ROUND {rnd} [{cfg.model.split('/')[-1]} reg={cfg.reg}] ===")
logger.info(f"\n\n=== ROUND {rnd} [{cfg.model.split('/')[-1]} reg={cfg.reg}] gpu {gpu_mem()} ===")
# extract teacher vector + sweep-generate steered data from the CURRENT student
with baked(model, hist_specs):
v = teacher_vec(model, tok, cfg)
comps = generate_steered(model, tok, v, cfg)
# filter under the ORIGINAL (no history, no steering) -- this picks the usable C
logger.info(f"\n=== FILTER [{len(comps)} completions] gpu {gpu_mem()} ===")
kept, scored = filter_completions(model, tok, comps, cfg)
log_event(run_dir, stage="gen", round=rnd, n_comps=len(comps), n_kept=len(kept), scored=scored)
# heal one round on top of the baked history, then fold
logger.info(f"\n=== HEAL [{cfg.reg}] gpu {gpu_mem()} ===")
lora, spec = heal_round(model, tok, kept, hist_specs, cfg)
lora.save(str(run_dir / "ckpt" / f"r{rnd}.safetensors"), extra_meta={"round": str(rnd), "reg": cfg.reg})
hist_specs.append(spec)
# eval the student (all rounds baked) + Q1: trained-adapter output coherence
logger.info(f"\n=== EVAL [tinymfv classic] gpu {gpu_mem()} ===")
with baked(model, hist_specs):
m = evaluate_model(model, tok, cfg)
adapter = generate_plain(model, tok, cfg, n=min(6, cfg.n_prompts))
@@ -105,8 +108,8 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
"adapter_ppl": adapter_ppl, "n_kept": len(kept)}
rounds.append(rec)
log_event(run_dir, stage="round", **rec)
logger.info(f"round {rnd}: socialnorms={m['socialnorms']:.3f} care={m['care']:.3f} "
f"coh={m['coherence']:.3f} cos_v0={cos_v0:+.2f} adapter_ppl={adapter_ppl:.0f}")
logger.info(f"round {rnd}: auth_nats↓={m['auth_nats']:+.2f} care_nats={m['care_nats']:+.2f} "
f"coh={m['coherence']:.3f} cos_v0={cos_v0:+.2f} adapter_ppl={adapter_ppl:.0f}")
_log_loop_summary(rounds)
write_map(run_dir, rounds)
@@ -115,15 +118,26 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
def _log_loop_summary(rounds: list[dict]) -> None:
from tabulate import tabulate
# (rec_key, display header with direction arrow) -- single source of truth.
cols = [("round", "round"), ("auth_nats", "auth_nats↓"), ("care_nats", "care_nats"),
("coherence", "coherence→"), ("cos_v0", "cos_v0→"),
("adapter_ppl", "adapter_ppl↓"), ("n_kept", "n_kept")]
logger.info(
"\nloop columns:\n"
" auth_nats↓ = Authority logp on Authority vignettes, NATS (TARGET: down = less deference)\n"
" care_nats = Care logp, NATS (off-target axis -- should move LESS than auth if surgical)\n"
" coherence→ = p_any_ans = mean_pmass_allowed (OFF-TARGET: hold ~1.0)\n"
" cos_v0→ = cosine of round vector vs round-0 vector (direction stability)\n"
" adapter_ppl↓ = ppl-under-original of the no-steering adapter generations"
)
logger.info(
"\nSHOULD (Q2 loop-coherent): coherence stays >= round-0 floor across rounds (heal holds it up). "
"If coherence falls each round, the loop accumulates incoherency faster than heal removes it.\n"
"SHOULD (Q3 direction): socialnorms FALLS / care RISES monotonically and cos_v0 stays > 0.5 "
"(same direction each round). If the trait reverses or cos_v0 drops, the direction wanders."
"SHOULD (Q3 direction): auth_nats FALLS monotonically (0.5-2 nats is a real shift) and cos_v0 "
"stays > 0.5. If care_nats falls as much as auth_nats, it's broad permissivizing not surgical."
)
cols = ["round", "socialnorms", "care", "coherence", "cos_v0", "adapter_ppl", "n_kept"]
tbl = [{c: r.get(c) for c in cols} for r in rounds]
logger.info("\nloop summary:\n" + tabulate(tbl, headers="keys", tablefmt="github", floatfmt=".3f"))
tbl = [{disp: r.get(key) for key, disp in cols} for r in rounds]
logger.info("\nloop summary:\n" + tabulate(tbl, headers="keys", tablefmt="github", floatfmt=".3f") + "\n")
def main(cfg: RunConfig) -> None:
+35 -8
View File
@@ -3,26 +3,47 @@
import steering_lite as sl
import torch
from loguru import logger
from tqdm import tqdm
from steer_heal.config import RunConfig
from steer_heal.prompts import POOL, chat_prompt
def gpu_mem() -> str:
"""One-glance GPU footprint string for stage headers (token-efficient-logging)."""
if not torch.cuda.is_available():
return "cpu"
free, total = torch.cuda.mem_get_info()
return f"{(total - free) / 1e9:.1f}/{total / 1e9:.0f}GB"
def _layer_band(model, layer_range: tuple[float, float]) -> tuple[int, ...]:
n = model.config.get_text_config().num_hidden_layers # nested for multimodal (gemma-3-4b)
lo, hi = layer_range
return tuple(range(int(lo * n), max(int(hi * n), int(lo * n) + 1)))
def _extract_prompts(cfg: RunConfig) -> list[str]:
"""Diverse contexts for the contrastive pairs (steering-lite uses 256 of these,
NOT domain dilemmas). A domain-narrow set overfits the direction to the format;
diverse suffixes isolate the persona's general residual-stream shift."""
import json
from pathlib import Path
suffixes = json.loads(Path(cfg.extract_data).read_text())
return [s["suffix"] for s in suffixes[: cfg.n_extract_pairs]]
def teacher_vec(model, tok, cfg: RunConfig):
"""trait-sysprompt vs neutral-sysprompt mean-diff, then iso-KL dose to target_kl."""
"""trait-prefix vs neutral-prefix mean-diff over DIVERSE contexts, at the assistant tag."""
layers = _layer_band(model, cfg.steer_layers) # narrow band; raw mean-diff compounds across layers
prompts = POOL[: cfg.n_prompts] if cfg.n_prompts <= len(POOL) else POOL
pos = [chat_prompt(tok, cfg.trait, q) for q in prompts]
neg = [chat_prompt(tok, cfg.neutral, q) for q in prompts]
contexts = _extract_prompts(cfg)
pos = [chat_prompt(tok, cfg.pos_persona, q) for q in contexts]
neg = [chat_prompt(tok, cfg.neg_persona, q) for q in contexts]
# SHOULD: pos/neg end at the assistant tag (last token); the two differ ONLY
# in the system prompt. ELSE the vector mixes in user-turn differences.
# in the system prompt (the persona prefix). ELSE the vector mixes in user-turn
# differences. n_pairs ~256 diverse contexts (steering-lite reference), not 30 dilemmas.
logger.info(f"teacher_vec: {len(pos)} contrastive pairs over diverse contexts, layers={layers}")
logger.debug(f"--- POS[0] (trait) ---\n{pos[0]}\n--- NEG[0] (neutral) ---\n{neg[0]}")
# RAW (unnormalised) mean-diff = the residual-stream shift the trait system
@@ -49,21 +70,27 @@ def generate_steered(model, tok, v, cfg: RunConfig) -> list[dict]:
alpha collapses, and we keep the coherent-but-trait-laden ones.
"""
out = []
n_total = cfg.n_prompts * len(cfg.alphas)
logger.info(f"\n=== GEN steered [{n_total} = {cfg.n_prompts} prompts x {len(cfg.alphas)} alphas] "
f"gpu {gpu_mem()} ===")
pbar = tqdm(total=n_total, desc="gen steered", mininterval=120, maxinterval=120)
for i in range(cfg.n_prompts):
user = POOL[i % len(POOL)]
text = chat_prompt(tok, cfg.neutral, user) # neutral prompt; the vector carries the trait
text = chat_prompt(tok, cfg.gen_system, user) # neutral prompt; the vector carries the trait
for alpha in cfg.alphas:
with v(model, C=alpha * v.cfg.coeff):
comp = _gen_one(model, tok, text, cfg)
out.append({"user": user, "prompt": text, "completion": comp, "alpha": float(alpha)})
pbar.update(1)
pbar.close()
return out
def generate_plain(model, tok, cfg: RunConfig, n: int) -> list[dict]:
"""Generate from the (baked) model with NO steering, for the Q1 heal comparison."""
out = []
for i in range(n):
for i in tqdm(range(n), desc="gen adapter", mininterval=120, maxinterval=120):
user = POOL[i % len(POOL)]
text = chat_prompt(tok, cfg.neutral, user)
text = chat_prompt(tok, cfg.gen_system, user)
out.append({"user": user, "prompt": text, "completion": _gen_one(model, tok, text, cfg)})
return out