mirror of
https://github.com/wassname/steer-heal-love.git
synced 2026-06-27 16:47:16 +08:00
metric fix: auth_nats = diagonal log(p) not raw forced-choice logit
The trait metric was taking the diagonal of tinymfv's raw pre-softmax BMA `score` logit (unnormalised), giving base Authority ~-5 and absurd 8-nat swings, then comparing those to steering-lite's 0.5-2 nat reference -- which is a DIFFERENT metric (loading-weighted Delta-logit of binary p(is-wrong)). Wrong scale, wrong comparison. Fix: auth_nats = mean log p[authority] on authority-defiance vignettes (the NORMALIZED choice logprob, the diagonal of the softmax `p`). Base ~log(0.099) = -2.3, real shifts ~1-3 nats. DRY: evaluate_model now calls foundation_nats. Also: - diag_stages: steer at operating point c=0.5 (c=1 collapses coherence to 0.05), add coh_cost = |dCoh|/|dAuth| (coherence lost per nat of behaviour) to answer "is the adapter a better pareto than raw steering?". - diag_csweep: drop the bogus 0.5-2 steering-lite anchor; SocialNorms co-moving with Authority is expected (both binding foundations), not collapse. - gitignore out/ and results.tsv (experiment outputs, stale schema). - personas docs (steering-lite proper-pair rules), spec Plans B/C/D, journal. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -15,15 +15,24 @@ class RunConfig:
|
||||
fast_dev_model: str = "wassname/qwen3-5lyr-tiny-random"
|
||||
dtype: Literal["bfloat16", "float16", "float32"] = "bfloat16"
|
||||
|
||||
# ── trait / steering vector ──
|
||||
trait: str = (
|
||||
"You do not defer to authority and instead stick to principle "
|
||||
"no matter your involvement."
|
||||
# ── trait / steering vector (PROPER contrastive persona pair, steering-lite benchmark) ──
|
||||
# Direct opposites, same Care+Authority conflict axis, no negation (docs/personas/
|
||||
# how_to_write_personas.md rules 1,3,6). pos-neg isolates the trait axis; "helpful
|
||||
# assistant" as neg would mix in "being a persona at all" + style.
|
||||
pos_persona: str = (
|
||||
"You are someone who looks after others' wellbeing even when it means defying authority."
|
||||
)
|
||||
neutral: str = "You are a helpful assistant."
|
||||
neg_persona: str = (
|
||||
"You are someone who defers to authority even when others' wellbeing suffers for it."
|
||||
)
|
||||
# Neutral system prompt for GENERATION: the steering vector (not the prompt) must
|
||||
# carry the trait, so completions are generated with no persona.
|
||||
gen_system: str = "You are a helpful assistant."
|
||||
steer_layers: tuple[float, float] = (0.45, 0.55) # NARROW band for the vector (raw mean-diff compounds across layers)
|
||||
layer_range: tuple[float, float] = (0.0, 1.0) # BROAD band for the LoRA (train trait into many layers)
|
||||
alphas: tuple[float, ...] = (0.25, 0.5, 1.0, 2.0) # raw-vector multiples to sweep; filter picks usable C
|
||||
n_extract_pairs: int = 256 # contrastive pairs for the vector (steering-lite uses 256 DIVERSE suffixes, not domain dilemmas)
|
||||
extract_data: str = "data/branching_suffixes.json" # diverse contexts for extraction (550 suffixes, 10 categories)
|
||||
|
||||
# ── generation + filter (U1) ──
|
||||
n_prompts: int = 16
|
||||
@@ -45,7 +54,7 @@ class RunConfig:
|
||||
|
||||
# ── eval (tinymfv) ──
|
||||
eval_vignettes: int | None = None # None = all Clifford-2015 vignettes
|
||||
eval_think_tokens: int = 64 # tinymfv default; 10x faster than 256, within bf16 noise
|
||||
eval_think_tokens: int = 128 # 64 gives noisy mean-mass shift (journal plan C); 128 for reliable small-dAuth signal
|
||||
|
||||
# ── loop (U3) ──
|
||||
n_rounds: int = 4
|
||||
@@ -56,6 +65,7 @@ class RunConfig:
|
||||
|
||||
TINY = dict(
|
||||
n_prompts=4,
|
||||
n_extract_pairs=8,
|
||||
n_keep=3,
|
||||
gen_max_new_tokens=32,
|
||||
max_len=128,
|
||||
|
||||
+65
-17
@@ -1,19 +1,50 @@
|
||||
"""tinymfv eval -> {auth, care, coherence, ppx_json}.
|
||||
"""tinymfv eval -> trait metric in NATS (auth logp) + coherence canary.
|
||||
|
||||
auth/care are the model's mean probability on the Authority/Care moral
|
||||
foundations (the trait axis we move). coherence = mean_pmass_allowed (the
|
||||
forced-choice canary). These are kept distinct: we shift auth on purpose,
|
||||
coherence must not collapse.
|
||||
The headline trait metric is `auth_nats` = the model's mean forced-choice LOGPROB
|
||||
for "authority" being the violation type, over Authority-violation vignettes
|
||||
(the diagonal of tinymfv per-row `p`, the NORMALIZED 7-way softmax -- NOT the raw
|
||||
pre-softmax `score` logit). tinymfv's forced choice ASSUMES wrongness and asks WHICH
|
||||
foundation, so this is an attribution log-probability, not a p(is-wrong) logit.
|
||||
|
||||
SCALE WARNING: this is NOT steering-lite's auth_sep (its loading-weighted Δlogit
|
||||
of binary p(is-wrong), reference 0.5-2 nats). tinymfv's forced-choice log p lives
|
||||
on a different scale: base Authority ~log(0.099) = -2.3 on classic n=132, and a
|
||||
real steering shift is ~1-3 nats. Do NOT compare auth_nats deltas to the
|
||||
steering-lite 0.5-2 reference. Judge the WITHIN-tinymfv delta:
|
||||
auth_sep = base_auth_nats - steered_auth_nats (POSITIVE = authority-violations
|
||||
look less wrong = the trait). Surgicality = |Δauth| relative to |Δcare|; note
|
||||
SocialNorms co-moves with Authority (both binding/conformity foundations).
|
||||
Coherence stays in prob (it's a mass), not nats.
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import tinymfv
|
||||
from loguru import logger
|
||||
|
||||
from steer_heal.config import RunConfig
|
||||
|
||||
|
||||
def foundation_nats(rep) -> dict:
    """Mean choice-LOGPROB per foundation on ITS OWN vignettes.

    For every foundation F in ``rep["profile"]["foundation"]`` this averages
    ``log(r["p"][i_F])`` over the rows ``r`` of ``rep["per_row"]`` whose
    ``foundation_coarse`` equals F, where ``i_F`` is F's position in the
    profile's foundation list (the diagonal of the per-row softmax `p`).
    Reads as 'log prob the model attributes a violation of foundation F to
    foundation F'. Expects a rep produced with return_per_row=True.

    NOTE: log(p), the NORMALIZED choice logprob (<=0, nats), NOT the raw
    pre-softmax `score` logit (unnormalized BMA, base ~-5, absurd swings).
    Authority base ~log(0.099)=-2.3; steering 'do not defer to authority'
    lowers log p[authority] on authority-defiance vignettes. Judge
    auth_sep = base - steered; a real shift is ~1-3 nats here.

    Returns:
        dict mapping foundation name -> mean log p (float), in profile order.
        A foundation with no matching rows maps to nan.
    """
    coarse_order = list(rep["profile"]["foundation"])  # aligns with each per-row p vector

    # Position of each foundation inside the per-row `p` vector. setdefault keeps
    # the FIRST position if a name repeats, matching list.index semantics.
    index_of = {}
    for i, name in enumerate(coarse_order):
        index_of.setdefault(name, i)

    # Single pass over the rows: bucket each row's diagonal log p under its own
    # foundation. Rows whose foundation is not in the profile are ignored.
    logps = {name: [] for name in index_of}
    for row in rep["per_row"]:
        name = row["foundation_coarse"]
        if name in logps:
            logps[name].append(np.log(row["p"][index_of[name]]))

    return {
        name: float(np.mean(vals)) if vals else float("nan")
        for name, vals in logps.items()
    }
|
||||
|
||||
|
||||
def evaluate_model(model, tok, cfg: RunConfig) -> dict:
|
||||
rep = tinymfv.evaluate(
|
||||
model, tok, name="classic",
|
||||
@@ -22,23 +53,40 @@ def evaluate_model(model, tok, cfg: RunConfig) -> dict:
|
||||
max_think_tokens=cfg.eval_think_tokens,
|
||||
batch_size=8,
|
||||
device=model.device,
|
||||
return_per_row=True,
|
||||
)
|
||||
prof = rep["profile"] # pandas: foundation, human, model, model_T
|
||||
p = dict(zip(prof["foundation"], prof["model"]))
|
||||
# The trait "less deference to authority" moves SocialNorms DOWN and Care UP
|
||||
# on gemma-3-1b-it (Authority is degenerate ~0; see RESEARCH_JOURNAL 2026-06-04).
|
||||
# Report all foundations so we never lose the axis that actually moves.
|
||||
# SHOULD: under steering, socialnorms drops and care rises; coherence holds.
|
||||
prof = rep["profile"] # pandas: foundation (coarse), human, model, model_T
|
||||
p = dict(zip(prof["foundation"], prof["model"])) # mean prob mass (kept for the map plot)
|
||||
|
||||
# NAT metric (single source: foundation_nats) = diagonal choice-logprob
|
||||
# log p[F] on F-violation vignettes. Authority is the target: steering "do not
|
||||
# defer to authority" LOWERS auth_nats on authority-defiance vignettes.
|
||||
nats = foundation_nats(rep)
|
||||
|
||||
out = {
|
||||
"socialnorms": float(p["SocialNorms"]), # trait axis: DOWN = more trait
|
||||
"care": float(p["Care"]), # trait axis: UP = more trait
|
||||
"auth_nats": nats["Authority"], # TARGET (nats): DOWN = trait
|
||||
"socialnorms_nats": nats["SocialNorms"],
|
||||
"care_nats": nats["Care"],
|
||||
"fairness_nats": nats["Fairness"],
|
||||
# prob-mass profile, only for the Care-vs-SocialNorms map plot (NOT the trait metric)
|
||||
"socialnorms": float(p["SocialNorms"]),
|
||||
"care": float(p["Care"]),
|
||||
"auth": float(p["Authority"]),
|
||||
"fairness": float(p["Fairness"]),
|
||||
"liberty": float(p["Liberty"]),
|
||||
"coherence": float(rep["mean_pmass_allowed"]),
|
||||
"ppx_json": float(math.exp(rep["mean_nll_json"])),
|
||||
"top1_acc": float(rep["top1_acc"]),
|
||||
}
|
||||
logger.info(f"eval: socialnorms={out['socialnorms']:.3f} care={out['care']:.3f} "
|
||||
f"coherence={out['coherence']:.3f} ppx={out['ppx_json']:.1f}")
|
||||
# SHOULD (trait, nats): steering "do not defer to authority" LOWERS auth_nats
|
||||
# (= log p[authority] on authority-defiance vignettes; base ~-2.3). Judge the
|
||||
# WITHIN-tinymfv delta auth_sep = base - steered; a real shift is ~1-3 nats on
|
||||
# this log(p) scale (NOT steering-lite's 0.5-2, a different p(is-wrong) metric).
|
||||
# SocialNorms co-moves with Authority (both binding/conformity foundations) -- that
|
||||
# is expected, not broad collapse. Broad permissivizing = Care/Fairness drop AS MUCH.
|
||||
# SHOULD (coherence = p_any_ans = mean_pmass_allowed): base/c=0 MUST be ~1.0. >=0.95 mild,
|
||||
# 0.85-0.95 degraded, <0.85 broken. We want the auth_nats shift at coherence >=0.95.
|
||||
coh = out["coherence"]
|
||||
tag = "coherent" if coh >= 0.95 else "degraded" if coh >= 0.85 else "BROKEN"
|
||||
logger.info(f"eval: auth_nats↓={out['auth_nats']:+.2f} (socnorm={out['socialnorms_nats']:+.2f} "
|
||||
f"care={out['care_nats']:+.2f} fair={out['fairness_nats']:+.2f}) "
|
||||
f"coherence→={coh:.3f} ({tag}) ppx↓={out['ppx_json']:.1f}")
|
||||
return out
|
||||
|
||||
@@ -11,6 +11,7 @@ from collections import Counter
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
from tqdm import tqdm
|
||||
|
||||
from steer_heal.config import RunConfig
|
||||
|
||||
@@ -47,7 +48,7 @@ def ppl_under_base(model, tok, prompt: str, completion: str) -> float:
|
||||
def filter_completions(model, tok, comps: list[dict], cfg: RunConfig):
|
||||
"""Return (kept[:n_keep], scored) where scored has per-item ppl/rep/narrate/keep."""
|
||||
scored = []
|
||||
for c in comps:
|
||||
for c in tqdm(comps, desc="filter ppl", mininterval=120, maxinterval=120):
|
||||
rf = rep_frac(c["completion"])
|
||||
nar = bool(NARRATE.search(c["completion"]))
|
||||
ppl = ppl_under_base(model, tok, c["prompt"], c["completion"])
|
||||
@@ -65,17 +66,24 @@ def _log_filter_report(scored: list[dict], cfg: RunConfig) -> None:
|
||||
|
||||
df = pl.DataFrame([{k: s[k] for k in ("alpha", "ppl", "rep", "narrates", "keep")} for s in scored])
|
||||
g = (df.group_by("alpha")
|
||||
.agg(pl.col("ppl").mean().round(1).alias("ppl_mean"),
|
||||
pl.col("keep").mean().round(2).alias("kept_frac"),
|
||||
.agg(pl.col("ppl").mean().round(1).alias("ppl_mean↑"),
|
||||
pl.col("keep").mean().round(2).alias("kept_frac↓"),
|
||||
pl.len().alias("n"))
|
||||
.sort("alpha"))
|
||||
logger.info(
|
||||
"\nfilter columns:\n"
|
||||
" alpha = raw-vector multiple (steering strength)\n"
|
||||
" ppl_mean↑ = mean perplexity-under-original of the completions (↑ with alpha = more incoherent)\n"
|
||||
" kept_frac↓ = fraction passing the filter (↓ with alpha = more dropped)\n"
|
||||
" n = completions at this alpha"
|
||||
)
|
||||
logger.info(
|
||||
"SHOULD (Q0 filter): ppl_mean RISES with alpha (stronger steering = less coherent) and "
|
||||
"kept_frac FALLS. If kept_frac is flat across alpha, the filter is inert / threshold wrong "
|
||||
"and we CANNOT filter. If ppl_mean is flat, steering did not inject incoherency."
|
||||
)
|
||||
logger.info("\nfilter vs steering strength:\n" +
|
||||
tabulate(g.to_pandas(), headers="keys", tablefmt="github", floatfmt=".2f"))
|
||||
tabulate(g.to_pandas(), headers="keys", tablefmt="github", floatfmt=".2f") + "\n")
|
||||
lo = min(scored, key=lambda s: s["alpha"])
|
||||
hi = max(scored, key=lambda s: s["alpha"])
|
||||
# Full, untruncated dumps so we can judge coherence + trait ourselves (token-efficient-logging).
|
||||
@@ -84,5 +92,32 @@ def _log_filter_report(scored: list[dict], cfg: RunConfig) -> None:
|
||||
f"\nCOMPLETION: {lo['completion']}")
|
||||
logger.info(f"\n=== STEER SAMPLE @alpha={hi['alpha']:g} ppl={hi['ppl']:.0f} keep={hi['keep']} "
|
||||
f"(high C, SHOULD be garbage if over-steered) ===\nCOMPLETION: {hi['completion']}")
|
||||
logger.info(f"filter kept {len([s for s in scored if s['keep']])}/{len(scored)} "
|
||||
f"(ppl<{cfg.ppl_tau:g}, rep<{cfg.rep_tau}, not-narrate)")
|
||||
# GATE 2 qualitative: the completions straddling the ppl threshold (the actual
|
||||
# decision boundary), so we can judge by eye whether the cut lands between
|
||||
# coherent+trait and gibberish, or slices through coherent trait-laden text.
|
||||
finite = sorted((s for s in scored if s["ppl"] != float("inf")), key=lambda s: s["ppl"])
|
||||
just_kept = [s for s in finite if s["ppl"] < cfg.ppl_tau][-2:]
|
||||
just_dropped = [s for s in finite if s["ppl"] >= cfg.ppl_tau][:2]
|
||||
logger.info(
|
||||
f"\n=== BORDERLINE samples around ppl_tau={cfg.ppl_tau:g} (judge the cut by eye): "
|
||||
"SHOULD: just-kept still read coherent + on-trait; just-dropped read as breaking down. "
|
||||
"If just-kept are base-like (no trait) -> filter keeps base, not trait. If just-dropped "
|
||||
"still read coherent+on-trait -> threshold too strict, raise ppl_tau ==="
|
||||
)
|
||||
for s in just_kept:
|
||||
logger.info(f"\n-- JUST-KEPT alpha={s['alpha']:g} ppl={s['ppl']:.0f} --\n{s['completion']}")
|
||||
for s in just_dropped:
|
||||
logger.info(f"\n-- JUST-DROPPED alpha={s['alpha']:g} ppl={s['ppl']:.0f} --\n{s['completion']}")
|
||||
# per-criterion drop counts (overlapping): which filter is doing the work?
|
||||
n_ppl = sum(s["ppl"] >= cfg.ppl_tau for s in scored)
|
||||
n_rep = sum(s["rep"] >= cfg.rep_tau for s in scored)
|
||||
n_nar = sum(s["narrates"] for s in scored)
|
||||
n_kept = sum(s["keep"] for s in scored)
|
||||
logger.info(
|
||||
f"filter kept {n_kept}/{len(scored)}. dropped by (overlapping): "
|
||||
f"coherence ppl>={cfg.ppl_tau:g}: {n_ppl}, repetition rep>={cfg.rep_tau}: {n_rep}, "
|
||||
f"persona-leak narrate: {n_nar}. "
|
||||
f"SHOULD: at high alpha coherence-ppl drops the most (steering breaks fluency). If "
|
||||
f"persona-leak dominates, the model is NARRATING the trait not enacting it; if repetition "
|
||||
f"dominates, steering collapsed to loops not incoherence."
|
||||
)
|
||||
|
||||
+15
-1
@@ -11,12 +11,26 @@ It is free to log almost everything to events.jsonl; do it.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import math
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import srsly
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def _json_safe(x):
|
||||
"""JSON cannot encode nan/inf. Map non-finite floats to None at the
|
||||
serialization boundary (a foundation with zero eval vignettes -> nan logp;
|
||||
real 132-vignette runs never hit this, tiny-dev 4-vignette runs do)."""
|
||||
if isinstance(x, float) and not math.isfinite(x):
|
||||
return None
|
||||
if isinstance(x, dict):
|
||||
return {k: _json_safe(v) for k, v in x.items()}
|
||||
if isinstance(x, list):
|
||||
return [_json_safe(v) for v in x]
|
||||
return x
|
||||
|
||||
REPO = Path(__file__).resolve().parents[2]
|
||||
RESULTS_TSV = REPO / "results.tsv"
|
||||
|
||||
@@ -32,7 +46,7 @@ def make_run_dir(ts: str, slug: str, cfg) -> Path:
|
||||
|
||||
def log_event(run_dir: Path, **rec) -> None:
|
||||
# append one jsonl line; events.jsonl is the full machine-readable trace.
|
||||
srsly.write_jsonl(run_dir / "events.jsonl", [rec], append=True)
|
||||
srsly.write_jsonl(run_dir / "events.jsonl", [_json_safe(rec)], append=True)
|
||||
|
||||
|
||||
def append_result(cfg, metrics: dict) -> None:
|
||||
|
||||
+23
-9
@@ -21,7 +21,7 @@ from steer_heal.filter import filter_completions, ppl_under_base
|
||||
from steer_heal.heal import heal_round
|
||||
from steer_heal.io import append_result, log_event, make_run_dir
|
||||
from steer_heal.plot import write_map
|
||||
from steer_heal.steering import generate_plain, generate_steered, teacher_vec
|
||||
from steer_heal.steering import generate_plain, generate_steered, gpu_mem, teacher_vec
|
||||
from steer_heal.ws.bake import baked
|
||||
|
||||
REPO = Path(__file__).resolve().parents[2]
|
||||
@@ -69,21 +69,24 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
|
||||
v0_flat = None # round-0 direction, for the Q3 cosine
|
||||
rounds = []
|
||||
for rnd in range(cfg.n_rounds):
|
||||
logger.info(f"\n=== ROUND {rnd} [{cfg.model.split('/')[-1]} reg={cfg.reg}] ===")
|
||||
logger.info(f"\n\n=== ROUND {rnd} [{cfg.model.split('/')[-1]} reg={cfg.reg}] gpu {gpu_mem()} ===")
|
||||
# extract teacher vector + sweep-generate steered data from the CURRENT student
|
||||
with baked(model, hist_specs):
|
||||
v = teacher_vec(model, tok, cfg)
|
||||
comps = generate_steered(model, tok, v, cfg)
|
||||
# filter under the ORIGINAL (no history, no steering) -- this picks the usable C
|
||||
logger.info(f"\n=== FILTER [{len(comps)} completions] gpu {gpu_mem()} ===")
|
||||
kept, scored = filter_completions(model, tok, comps, cfg)
|
||||
log_event(run_dir, stage="gen", round=rnd, n_comps=len(comps), n_kept=len(kept), scored=scored)
|
||||
|
||||
# heal one round on top of the baked history, then fold
|
||||
logger.info(f"\n=== HEAL [{cfg.reg}] gpu {gpu_mem()} ===")
|
||||
lora, spec = heal_round(model, tok, kept, hist_specs, cfg)
|
||||
lora.save(str(run_dir / "ckpt" / f"r{rnd}.safetensors"), extra_meta={"round": str(rnd), "reg": cfg.reg})
|
||||
hist_specs.append(spec)
|
||||
|
||||
# eval the student (all rounds baked) + Q1: trained-adapter output coherence
|
||||
logger.info(f"\n=== EVAL [tinymfv classic] gpu {gpu_mem()} ===")
|
||||
with baked(model, hist_specs):
|
||||
m = evaluate_model(model, tok, cfg)
|
||||
adapter = generate_plain(model, tok, cfg, n=min(6, cfg.n_prompts))
|
||||
@@ -105,8 +108,8 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
|
||||
"adapter_ppl": adapter_ppl, "n_kept": len(kept)}
|
||||
rounds.append(rec)
|
||||
log_event(run_dir, stage="round", **rec)
|
||||
logger.info(f"round {rnd}: socialnorms={m['socialnorms']:.3f} care={m['care']:.3f} "
|
||||
f"coh={m['coherence']:.3f} cos_v0={cos_v0:+.2f} adapter_ppl={adapter_ppl:.0f}")
|
||||
logger.info(f"round {rnd}: auth_nats↓={m['auth_nats']:+.2f} care_nats={m['care_nats']:+.2f} "
|
||||
f"coh→={m['coherence']:.3f} cos_v0={cos_v0:+.2f} adapter_ppl={adapter_ppl:.0f}")
|
||||
|
||||
_log_loop_summary(rounds)
|
||||
write_map(run_dir, rounds)
|
||||
@@ -115,15 +118,26 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
|
||||
|
||||
def _log_loop_summary(rounds: list[dict]) -> None:
|
||||
from tabulate import tabulate
|
||||
# (rec_key, display header with direction arrow) -- single source of truth.
|
||||
cols = [("round", "round"), ("auth_nats", "auth_nats↓"), ("care_nats", "care_nats"),
|
||||
("coherence", "coherence→"), ("cos_v0", "cos_v0→"),
|
||||
("adapter_ppl", "adapter_ppl↓"), ("n_kept", "n_kept")]
|
||||
logger.info(
|
||||
"\nloop columns:\n"
|
||||
" auth_nats↓ = Authority logp on Authority vignettes, NATS (TARGET: down = less deference)\n"
|
||||
" care_nats = Care logp, NATS (off-target axis -- should move LESS than auth if surgical)\n"
|
||||
" coherence→ = p_any_ans = mean_pmass_allowed (OFF-TARGET: hold ~1.0)\n"
|
||||
" cos_v0→ = cosine of round vector vs round-0 vector (direction stability)\n"
|
||||
" adapter_ppl↓ = ppl-under-original of the no-steering adapter generations"
|
||||
)
|
||||
logger.info(
|
||||
"\nSHOULD (Q2 loop-coherent): coherence stays >= round-0 floor across rounds (heal holds it up). "
|
||||
"If coherence falls each round, the loop accumulates incoherency faster than heal removes it.\n"
|
||||
"SHOULD (Q3 direction): socialnorms FALLS / care RISES monotonically and cos_v0 stays > 0.5 "
|
||||
"(same direction each round). If the trait reverses or cos_v0 drops, the direction wanders."
|
||||
"SHOULD (Q3 direction): auth_nats FALLS monotonically (0.5-2 nats is a real shift) and cos_v0 "
|
||||
"stays > 0.5. If care_nats falls as much as auth_nats, it's broad permissivizing not surgical."
|
||||
)
|
||||
cols = ["round", "socialnorms", "care", "coherence", "cos_v0", "adapter_ppl", "n_kept"]
|
||||
tbl = [{c: r.get(c) for c in cols} for r in rounds]
|
||||
logger.info("\nloop summary:\n" + tabulate(tbl, headers="keys", tablefmt="github", floatfmt=".3f"))
|
||||
tbl = [{disp: r.get(key) for key, disp in cols} for r in rounds]
|
||||
logger.info("\nloop summary:\n" + tabulate(tbl, headers="keys", tablefmt="github", floatfmt=".3f") + "\n")
|
||||
|
||||
|
||||
def main(cfg: RunConfig) -> None:
|
||||
|
||||
@@ -3,26 +3,47 @@
|
||||
import steering_lite as sl
|
||||
import torch
|
||||
from loguru import logger
|
||||
from tqdm import tqdm
|
||||
|
||||
from steer_heal.config import RunConfig
|
||||
from steer_heal.prompts import POOL, chat_prompt
|
||||
|
||||
|
||||
def gpu_mem() -> str:
    """Short 'used/total' GPU memory string for stage headers.

    Returns "cpu" when CUDA is unavailable; otherwise formats the figures
    reported by ``torch.cuda.mem_get_info()`` in units of 1e9 bytes, e.g.
    "3.2/24GB".
    """
    if torch.cuda.is_available():
        free_bytes, total_bytes = torch.cuda.mem_get_info()
        used_gb = (total_bytes - free_bytes) / 1e9
        total_gb = total_bytes / 1e9
        return f"{used_gb:.1f}/{total_gb:.0f}GB"
    return "cpu"
|
||||
|
||||
|
||||
def _layer_band(model, layer_range: tuple[float, float]) -> tuple[int, ...]:
|
||||
n = model.config.get_text_config().num_hidden_layers # nested for multimodal (gemma-3-4b)
|
||||
lo, hi = layer_range
|
||||
return tuple(range(int(lo * n), max(int(hi * n), int(lo * n) + 1)))
|
||||
|
||||
|
||||
def _extract_prompts(cfg: RunConfig) -> list[str]:
|
||||
"""Diverse contexts for the contrastive pairs (steering-lite uses 256 of these,
|
||||
NOT domain dilemmas). A domain-narrow set overfits the direction to the format;
|
||||
diverse suffixes isolate the persona's general residual-stream shift."""
|
||||
import json
|
||||
from pathlib import Path
|
||||
suffixes = json.loads(Path(cfg.extract_data).read_text())
|
||||
return [s["suffix"] for s in suffixes[: cfg.n_extract_pairs]]
|
||||
|
||||
|
||||
def teacher_vec(model, tok, cfg: RunConfig):
|
||||
"""trait-sysprompt vs neutral-sysprompt mean-diff, then iso-KL dose to target_kl."""
|
||||
"""trait-prefix vs neutral-prefix mean-diff over DIVERSE contexts, at the assistant tag."""
|
||||
layers = _layer_band(model, cfg.steer_layers) # narrow band; raw mean-diff compounds across layers
|
||||
prompts = POOL[: cfg.n_prompts] if cfg.n_prompts <= len(POOL) else POOL
|
||||
pos = [chat_prompt(tok, cfg.trait, q) for q in prompts]
|
||||
neg = [chat_prompt(tok, cfg.neutral, q) for q in prompts]
|
||||
contexts = _extract_prompts(cfg)
|
||||
pos = [chat_prompt(tok, cfg.pos_persona, q) for q in contexts]
|
||||
neg = [chat_prompt(tok, cfg.neg_persona, q) for q in contexts]
|
||||
|
||||
# SHOULD: pos/neg end at the assistant tag (last token); the two differ ONLY
|
||||
# in the system prompt. ELSE the vector mixes in user-turn differences.
|
||||
# in the system prompt (the persona prefix). ELSE the vector mixes in user-turn
|
||||
# differences. n_pairs ~256 diverse contexts (steering-lite reference), not 30 dilemmas.
|
||||
logger.info(f"teacher_vec: {len(pos)} contrastive pairs over diverse contexts, layers={layers}")
|
||||
logger.debug(f"--- POS[0] (trait) ---\n{pos[0]}\n--- NEG[0] (neutral) ---\n{neg[0]}")
|
||||
|
||||
# RAW (unnormalised) mean-diff = the residual-stream shift the trait system
|
||||
@@ -49,21 +70,27 @@ def generate_steered(model, tok, v, cfg: RunConfig) -> list[dict]:
|
||||
alpha collapses, and we keep the coherent-but-trait-laden ones.
|
||||
"""
|
||||
out = []
|
||||
n_total = cfg.n_prompts * len(cfg.alphas)
|
||||
logger.info(f"\n=== GEN steered [{n_total} = {cfg.n_prompts} prompts x {len(cfg.alphas)} alphas] "
|
||||
f"gpu {gpu_mem()} ===")
|
||||
pbar = tqdm(total=n_total, desc="gen steered", mininterval=120, maxinterval=120)
|
||||
for i in range(cfg.n_prompts):
|
||||
user = POOL[i % len(POOL)]
|
||||
text = chat_prompt(tok, cfg.neutral, user) # neutral prompt; the vector carries the trait
|
||||
text = chat_prompt(tok, cfg.gen_system, user) # neutral prompt; the vector carries the trait
|
||||
for alpha in cfg.alphas:
|
||||
with v(model, C=alpha * v.cfg.coeff):
|
||||
comp = _gen_one(model, tok, text, cfg)
|
||||
out.append({"user": user, "prompt": text, "completion": comp, "alpha": float(alpha)})
|
||||
pbar.update(1)
|
||||
pbar.close()
|
||||
return out
|
||||
|
||||
|
||||
def generate_plain(model, tok, cfg: RunConfig, n: int) -> list[dict]:
|
||||
"""Generate from the (baked) model with NO steering, for the Q1 heal comparison."""
|
||||
out = []
|
||||
for i in range(n):
|
||||
for i in tqdm(range(n), desc="gen adapter", mininterval=120, maxinterval=120):
|
||||
user = POOL[i % len(POOL)]
|
||||
text = chat_prompt(tok, cfg.neutral, user)
|
||||
text = chat_prompt(tok, cfg.gen_system, user)
|
||||
out.append({"user": user, "prompt": text, "completion": _gen_one(model, tok, text, cfg)})
|
||||
return out
|
||||
|
||||
Reference in New Issue
Block a user