mirror of
https://github.com/wassname/steer-heal-love.git
synced 2026-06-27 16:47:16 +08:00
drop calibration; sweep C + filter; SHOULD logging for all Q's; 4B default
Per user: no iso-KL calibration. Use the raw (unnormalised) mean-diff teacher vector; sweep cfg.alphas at generation and let the FILTER pick the usable C (the filter replaces calibration).
Default model is now google/gemma-3-4b-it (1B is too weak; the degenerate Authority result there was a model artifact, not a real conclusion).
Token-efficient, discriminating logs so each Q is readable:
- Q0: filter table (alpha -> ppl_mean, kept_frac) + low/high-C samples + SHOULD
- Q1: generate from the trained adapter (no steering); adapter_ppl vs steered_ppl under the original + sample + SHOULD (heal = adapter more coherent than steered)
- Q2/Q3: loop summary table (socialnorms/care/coherence/cos_v0 per round) + SHOULD
fast-dev-run is green: ppl rises with alpha (3173 -> 4.2M), adapter_ppl << steered_ppl.
Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -11,7 +11,7 @@ class RunConfig:
|
||||
"""
|
||||
|
||||
# ── model ──
|
||||
model: str = "google/gemma-3-1b-it"
|
||||
model: str = "google/gemma-3-4b-it"
|
||||
fast_dev_model: str = "wassname/qwen3-5lyr-tiny-random"
|
||||
dtype: Literal["bfloat16", "float16", "float32"] = "bfloat16"
|
||||
|
||||
@@ -22,9 +22,7 @@ class RunConfig:
|
||||
)
|
||||
neutral: str = "You are a helpful assistant."
|
||||
layer_range: tuple[float, float] = (0.4, 0.6) # fraction of depth to steer
|
||||
target_kl: float = 1.0 # iso-KL p95 dose (nats)
|
||||
gen_alpha: float = 1.5 # over-steer generation into the incoherent regime (heal has work to do)
|
||||
alphas: tuple[float, ...] = (0.5, 1.0, 1.5, 2.0) # multiples of c_star to generate at
|
||||
alphas: tuple[float, ...] = (0.5, 1.0, 2.0, 4.0) # raw-vector multiples to sweep; filter picks usable C
|
||||
|
||||
# ── generation + filter (U1) ──
|
||||
n_prompts: int = 64
|
||||
@@ -61,7 +59,7 @@ TINY = dict(
|
||||
max_len=128,
|
||||
epochs=1,
|
||||
n_rounds=1,
|
||||
alphas=(1.0,),
|
||||
alphas=(1.0, 4.0),
|
||||
eval_vignettes=4,
|
||||
eval_think_tokens=16,
|
||||
ppl_tau=1e9, # tiny-random produces junk ppl; relax the gate so the path still runs
|
||||
|
||||
@@ -54,8 +54,33 @@ def filter_completions(model, tok, comps: list[dict], cfg: RunConfig):
|
||||
keep = (ppl < cfg.ppl_tau) and (rf < cfg.rep_tau) and (not nar)
|
||||
scored.append({**c, "ppl": ppl, "rep": rf, "narrates": nar, "keep": keep})
|
||||
kept = [s for s in scored if s["keep"]]
|
||||
# SHOULD: on a real model, kept drops the babble and keeps fluent dilemmas
|
||||
# answers; if kept==0 the gate is too strict (raise ppl_tau) or steering
|
||||
# broke generation. On tiny-random everything passes (relaxed tau).
|
||||
logger.info(f"filter: kept {len(kept)}/{len(comps)} (ppl<{cfg.ppl_tau:g}, rep<{cfg.rep_tau}, not-narrate)")
|
||||
_log_filter_report(scored, cfg)
|
||||
return kept[: cfg.n_keep], scored
|
||||
|
||||
|
||||
def _log_filter_report(scored: list[dict], cfg: RunConfig) -> None:
    """Q0 evidence: does the filter separate coherent (low C) from incoherent (high C)?

    Logs, in order: the SHOULD expectation, a per-alpha table of mean ppl,
    kept fraction and count, one sample completion at the lowest and one at
    the highest alpha, and the overall kept count with the thresholds used.

    Args:
        scored: one dict per completion; each must carry the keys ``alpha``,
            ``ppl``, ``rep``, ``narrates``, ``keep`` and ``completion``.
        cfg: run config; only ``ppl_tau`` and ``rep_tau`` are read, for the
            summary line.
    """
    import polars as pl
    from tabulate import tabulate

    # Nothing to aggregate or sample: min()/max() below would raise ValueError
    # on an empty list, so report the situation and stop.
    if not scored:
        logger.warning(f"filter report skipped: no scored completions "
                       f"(ppl<{cfg.ppl_tau:g}, rep<{cfg.rep_tau}, not-narrate)")
        return

    df = pl.DataFrame([{k: s[k] for k in ("alpha", "ppl", "rep", "narrates", "keep")} for s in scored])
    g = (
        df.group_by("alpha")
        .agg(
            pl.col("ppl").mean().round(1).alias("ppl_mean"),
            pl.col("keep").mean().round(2).alias("kept_frac"),
            pl.len().alias("n"),
        )
        .sort("alpha")
    )
    logger.info(
        "SHOULD (Q0 filter): ppl_mean RISES with alpha (stronger steering = less coherent) and "
        "kept_frac FALLS. If kept_frac is flat across alpha, the filter is inert / threshold wrong "
        "and we CANNOT filter. If ppl_mean is flat, steering did not inject incoherency."
    )
    # Feed tabulate plain row dicts (as the loop summary does) instead of a
    # pandas frame: no pandas/pyarrow round trip and no stray index column.
    logger.info("\nfilter vs steering strength:\n" +
                tabulate(g.to_dicts(), headers="keys", tablefmt="github", floatfmt=".2f"))

    # First completion at the weakest / strongest steering, for eyeballing.
    lo = min(scored, key=lambda s: s["alpha"])
    hi = max(scored, key=lambda s: s["alpha"])
    logger.info(f"\n--- SAMPLE @alpha={lo['alpha']:g} ppl={lo['ppl']:.0f} keep={lo['keep']} "
                f"(SHOULD be coherent) ---\n{lo['completion'][:500]}")
    logger.info(f"\n--- SAMPLE @alpha={hi['alpha']:g} ppl={hi['ppl']:.0f} keep={hi['keep']} "
                f"(SHOULD be garbage if steering strong) ---\n{hi['completion'][:500]}")

    n_kept = sum(1 for s in scored if s["keep"])
    logger.info(f"filter kept {n_kept}/{len(scored)} "
                f"(ppl<{cfg.ppl_tau:g}, rep<{cfg.rep_tau}, not-narrate)")
|
||||
|
||||
+39
-10
@@ -17,11 +17,11 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from steer_heal.config import RunConfig, resolve
|
||||
from steer_heal.eval import evaluate_model
|
||||
from steer_heal.filter import filter_completions
|
||||
from steer_heal.filter import filter_completions, ppl_under_base
|
||||
from steer_heal.heal import heal_round
|
||||
from steer_heal.io import append_result, log_event, make_run_dir
|
||||
from steer_heal.plot import write_map
|
||||
from steer_heal.steering import generate_steered, teacher_vec
|
||||
from steer_heal.steering import generate_plain, generate_steered, teacher_vec
|
||||
from steer_heal.ws.bake import baked
|
||||
|
||||
REPO = Path(__file__).resolve().parents[2]
|
||||
@@ -58,17 +58,22 @@ def _flatten_v(v) -> torch.Tensor:
|
||||
return torch.cat([v.state[li]["v"].flatten().float() for li in sorted(v.state)])
|
||||
|
||||
|
||||
def _mean_finite(xs) -> float:
|
||||
xs = [x for x in xs if x == x and x != float("inf")]
|
||||
return sum(xs) / len(xs) if xs else float("nan")
|
||||
|
||||
|
||||
def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
|
||||
hist_specs = [] # AdapterSpec per folded round (gated bake history)
|
||||
v0_flat = None # round-0 direction, for the Q3 cosine
|
||||
rounds = []
|
||||
for rnd in range(cfg.n_rounds):
|
||||
logger.info(f"\n=== ROUND {rnd} [{cfg.model.split('/')[-1]} reg={cfg.reg}] ===")
|
||||
# extract teacher vector + generate steered data from the CURRENT student
|
||||
# extract teacher vector + sweep-generate steered data from the CURRENT student
|
||||
with baked(model, hist_specs):
|
||||
v = teacher_vec(model, tok, cfg)
|
||||
comps = generate_steered(model, tok, v, alpha=cfg.gen_alpha, cfg=cfg)
|
||||
# filter under the ORIGINAL (no history, no steering)
|
||||
comps = generate_steered(model, tok, v, cfg)
|
||||
# filter under the ORIGINAL (no history, no steering) -- this picks the usable C
|
||||
kept, scored = filter_completions(model, tok, comps, cfg)
|
||||
log_event(run_dir, stage="gen", round=rnd, n_comps=len(comps), n_kept=len(kept), scored=scored)
|
||||
|
||||
@@ -77,24 +82,48 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
|
||||
lora.save(str(run_dir / "ckpt" / f"r{rnd}.safetensors"), extra_meta={"round": str(rnd), "reg": cfg.reg})
|
||||
hist_specs.append(spec)
|
||||
|
||||
# eval the student (all rounds baked)
|
||||
# eval the student (all rounds baked) + Q1: trained-adapter output coherence
|
||||
with baked(model, hist_specs):
|
||||
m = evaluate_model(model, tok, cfg)
|
||||
adapter = generate_plain(model, tok, cfg, n=min(6, cfg.n_prompts))
|
||||
adapter_ppl = _mean_finite([ppl_under_base(model, tok, a["prompt"], a["completion"]) for a in adapter])
|
||||
steered_ppl = _mean_finite([s["ppl"] for s in scored])
|
||||
logger.info(
|
||||
"SHOULD (Q1 heal): adapter_ppl < steered_ppl means the trained model expresses the trait "
|
||||
"COHERENTLY (healed) where raw steering was incoherent. If adapter_ppl >= steered_ppl, "
|
||||
f"healing failed. adapter_ppl={adapter_ppl:.0f} steered_ppl={steered_ppl:.0f}"
|
||||
)
|
||||
logger.info(f"--- ADAPTER SAMPLE r{rnd} (no steering, SHOULD show trait + be coherent) ---\n"
|
||||
f"{adapter[0]['completion'][:500]}")
|
||||
|
||||
vf = _flatten_v(v)
|
||||
v0_flat = vf if v0_flat is None else v0_flat
|
||||
cos_v0 = float(cosine_similarity(vf, v0_flat, dim=0))
|
||||
rec = {"round": rnd, **m, "cos_v0": cos_v0, "c_star": float(v.cfg.coeff), "n_kept": len(kept)}
|
||||
rec = {"round": rnd, **m, "cos_v0": cos_v0, "steered_ppl": steered_ppl,
|
||||
"adapter_ppl": adapter_ppl, "n_kept": len(kept)}
|
||||
rounds.append(rec)
|
||||
log_event(run_dir, stage="round", **rec)
|
||||
logger.info(f"round {rnd}: socialnorms={m['socialnorms']:.3f} care={m['care']:.3f} "
|
||||
f"coh={m['coherence']:.3f} cos_v0={cos_v0:+.2f}")
|
||||
f"coh={m['coherence']:.3f} cos_v0={cos_v0:+.2f} adapter_ppl={adapter_ppl:.0f}")
|
||||
|
||||
map_path = write_map(run_dir, rounds)
|
||||
logger.info(f"map: {map_path}")
|
||||
_log_loop_summary(rounds)
|
||||
write_map(run_dir, rounds)
|
||||
return rounds[-1]
|
||||
|
||||
|
||||
def _log_loop_summary(rounds: list[dict]) -> None:
    """Log the Q2/Q3 expectations followed by a per-round summary table.

    Args:
        rounds: one record per round; the table shows the columns listed
            below, and a key missing from a record is shown as ``None``.
    """
    from tabulate import tabulate

    expectation = (
        "\nSHOULD (Q2 loop-coherent): coherence stays >= round-0 floor across rounds (heal holds it up). "
        "If coherence falls each round, the loop accumulates incoherency faster than heal removes it.\n"
        "SHOULD (Q3 direction): socialnorms FALLS / care RISES monotonically and cos_v0 stays > 0.5 "
        "(same direction each round). If the trait reverses or cos_v0 drops, the direction wanders."
    )
    logger.info(expectation)

    # Project every round record onto the same fixed column set, in this order.
    columns = ("round", "socialnorms", "care", "coherence", "cos_v0", "adapter_ppl", "n_kept")
    rows = []
    for record in rounds:
        rows.append({name: record.get(name) for name in columns})

    table = tabulate(rows, headers="keys", tablefmt="github", floatfmt=".3f")
    logger.info("\nloop summary:\n" + table)
|
||||
|
||||
|
||||
def main(cfg: RunConfig) -> None:
|
||||
setup_logging()
|
||||
cfg = resolve(cfg)
|
||||
|
||||
+33
-17
@@ -25,29 +25,45 @@ def teacher_vec(model, tok, cfg: RunConfig):
|
||||
# in the system prompt. ELSE the vector mixes in user-turn differences.
|
||||
logger.debug(f"--- POS[0] (trait) ---\n{pos[0]}\n--- NEG[0] (neutral) ---\n{neg[0]}")
|
||||
|
||||
v = sl.Vector.train(model, tok, pos, neg, cfg=sl.MeanDiffC(layers=layers, normalize=True))
|
||||
# Wide bracket: the vector is unit-normalised, so reaching ~1 nat p95 KL on a
|
||||
# real model needs c ~ O(100) (KL ~ c^2). steering-lite's default hi (~16) is
|
||||
# too low and pins c_star at the bracket top. See RESEARCH_JOURNAL 2026-06-04.
|
||||
v.calibrate(model, tok, target_kl=cfg.target_kl, bracket=(0.1, 1024.0))
|
||||
logger.info(f"teacher_vec: layers={layers} c_star={v.cfg.coeff:+.4f} (target_kl={cfg.target_kl})")
|
||||
# RAW (unnormalised) mean-diff = the residual-stream shift the trait system
|
||||
# prompt induces (Subliminal Learning teacher vector). No iso-KL calibration:
|
||||
# we steer at the natural scale (coeff = gen_alpha) and let the SFT/nll
|
||||
# training + coherence filter self-calibrate the strength.
|
||||
v = sl.Vector.train(model, tok, pos, neg, cfg=sl.MeanDiffC(layers=layers, normalize=False))
|
||||
logger.info(f"teacher_vec: layers={layers} raw mean-diff (no calibration), coeff={v.cfg.coeff}")
|
||||
return v
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_steered(model, tok, v, alpha: float, cfg: RunConfig) -> list[dict]:
|
||||
"""Generate at C = alpha * c_star. Returns [{prompt, user, completion}]."""
|
||||
def _gen_one(model, tok, text, cfg):
|
||||
ids = tok(text, return_tensors="pt").to(model.device)
|
||||
gen = model.generate(**ids, max_new_tokens=cfg.gen_max_new_tokens, do_sample=True,
|
||||
temperature=1.0, top_p=0.95, pad_token_id=tok.pad_token_id)
|
||||
return tok.decode(gen[0, ids.input_ids.shape[1]:], skip_special_tokens=True)
|
||||
|
||||
|
||||
def generate_steered(model, tok, v, cfg: RunConfig) -> list[dict]:
    """Sweep cfg.alphas (raw-vector multiples); generate one completion per prompt x alpha.

    The filter (Q0), not iso-KL, picks the usable C: low alpha is coherent, high
    alpha collapses, and we keep the coherent-but-trait-laden ones.

    Args:
        model: model to generate from.
        tok: tokenizer matching ``model``.
        v: steering vector; ``v(model, C=...)`` is used as a context manager
            around generation and ``v.cfg.coeff`` is its base coefficient.
        cfg: run config; reads ``n_prompts``, ``neutral`` and ``alphas``
            (``_gen_one`` additionally reads the generation settings).

    Returns:
        ``len(cfg.alphas) * cfg.n_prompts`` records ordered prompt-major, each
        ``{"user", "prompt", "completion", "alpha"}`` with ``alpha`` as float.
    """
    out = []
    for i in range(cfg.n_prompts):
        user = POOL[i % len(POOL)]  # cycle through the pool if n_prompts exceeds it
        text = chat_prompt(tok, cfg.neutral, user)  # neutral prompt; the vector carries the trait
        for alpha in cfg.alphas:
            # Steering strength for this sample is alpha times the vector's own coefficient.
            with v(model, C=alpha * v.cfg.coeff):
                comp = _gen_one(model, tok, text, cfg)
            out.append({"user": user, "prompt": text, "completion": comp, "alpha": float(alpha)})
    return out
|
||||
|
||||
|
||||
def generate_plain(model, tok, cfg: RunConfig, n: int) -> list[dict]:
    """Generate from the (baked) model with NO steering, for the Q1 heal comparison.

    Builds ``n`` neutral-system-prompt chats from ``POOL`` (cycling when ``n``
    exceeds the pool) and samples one completion for each.

    Returns:
        ``n`` records of the form ``{"user", "prompt", "completion"}``.
    """
    # Lazy pipeline: each prompt is built and generated in turn, in pool order.
    questions = (POOL[k % len(POOL)] for k in range(n))
    chats = ((question, chat_prompt(tok, cfg.neutral, question)) for question in questions)
    return [
        {"user": question, "prompt": prompt, "completion": _gen_one(model, tok, prompt, cfg)}
        for question, prompt in chats
    ]
|
||||
|
||||
Reference in New Issue
Block a user