diff --git a/src/steer_heal/config.py b/src/steer_heal/config.py index 88470b8..c275eba 100644 --- a/src/steer_heal/config.py +++ b/src/steer_heal/config.py @@ -11,7 +11,7 @@ class RunConfig: """ # ── model ── - model: str = "google/gemma-3-1b-it" + model: str = "google/gemma-3-4b-it" fast_dev_model: str = "wassname/qwen3-5lyr-tiny-random" dtype: Literal["bfloat16", "float16", "float32"] = "bfloat16" @@ -22,9 +22,7 @@ class RunConfig: ) neutral: str = "You are a helpful assistant." layer_range: tuple[float, float] = (0.4, 0.6) # fraction of depth to steer - target_kl: float = 1.0 # iso-KL p95 dose (nats) - gen_alpha: float = 1.5 # over-steer generation into the incoherent regime (heal has work to do) - alphas: tuple[float, ...] = (0.5, 1.0, 1.5, 2.0) # multiples of c_star to generate at + alphas: tuple[float, ...] = (0.5, 1.0, 2.0, 4.0) # raw-vector multiples to sweep; filter picks usable C # ── generation + filter (U1) ── n_prompts: int = 64 @@ -61,7 +59,7 @@ TINY = dict( max_len=128, epochs=1, n_rounds=1, - alphas=(1.0,), + alphas=(1.0, 4.0), eval_vignettes=4, eval_think_tokens=16, ppl_tau=1e9, # tiny-random produces junk ppl; relax the gate so the path still runs diff --git a/src/steer_heal/filter.py b/src/steer_heal/filter.py index 77d6190..f1d0797 100644 --- a/src/steer_heal/filter.py +++ b/src/steer_heal/filter.py @@ -54,8 +54,33 @@ def filter_completions(model, tok, comps: list[dict], cfg: RunConfig): keep = (ppl < cfg.ppl_tau) and (rf < cfg.rep_tau) and (not nar) scored.append({**c, "ppl": ppl, "rep": rf, "narrates": nar, "keep": keep}) kept = [s for s in scored if s["keep"]] - # SHOULD: on a real model, kept drops the babble and keeps fluent dilemmas - # answers; if kept==0 the gate is too strict (raise ppl_tau) or steering - # broke generation. On tiny-random everything passes (relaxed tau). 
- logger.info(f"filter: kept {len(kept)}/{len(comps)} (ppl<{cfg.ppl_tau:g}, rep<{cfg.rep_tau}, not-narrate)") + _log_filter_report(scored, cfg) return kept[: cfg.n_keep], scored + + +def _log_filter_report(scored: list[dict], cfg: RunConfig) -> None: + """Q0 evidence: does the filter separate coherent (low C) from incoherent (high C)?""" + import polars as pl + from tabulate import tabulate + + df = pl.DataFrame([{k: s[k] for k in ("alpha", "ppl", "rep", "narrates", "keep")} for s in scored]) + g = (df.group_by("alpha") + .agg(pl.col("ppl").mean().round(1).alias("ppl_mean"), + pl.col("keep").mean().round(2).alias("kept_frac"), + pl.len().alias("n")) + .sort("alpha")) + logger.info( + "SHOULD (Q0 filter): ppl_mean RISES with alpha (stronger steering = less coherent) and " + "kept_frac FALLS. If kept_frac is flat across alpha, the filter is inert / threshold wrong " + "and we CANNOT filter. If ppl_mean is flat, steering did not inject incoherency." + ) + logger.info("\nfilter vs steering strength:\n" + + tabulate(g.to_pandas(), headers="keys", tablefmt="github", floatfmt=".2f")) + lo = min(scored, key=lambda s: s["alpha"]) + hi = max(scored, key=lambda s: s["alpha"]) + logger.info(f"\n--- SAMPLE @alpha={lo['alpha']:g} ppl={lo['ppl']:.0f} keep={lo['keep']} " + f"(SHOULD be coherent) ---\n{lo['completion'][:500]}") + logger.info(f"\n--- SAMPLE @alpha={hi['alpha']:g} ppl={hi['ppl']:.0f} keep={hi['keep']} " + f"(SHOULD be garbage if steering strong) ---\n{hi['completion'][:500]}") + logger.info(f"filter kept {len([s for s in scored if s['keep']])}/{len(scored)} " + f"(ppl<{cfg.ppl_tau:g}, rep<{cfg.rep_tau}, not-narrate)") diff --git a/src/steer_heal/run.py b/src/steer_heal/run.py index e136343..f3b1398 100644 --- a/src/steer_heal/run.py +++ b/src/steer_heal/run.py @@ -17,11 +17,11 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from steer_heal.config import RunConfig, resolve from steer_heal.eval import evaluate_model -from steer_heal.filter import 
filter_completions +from steer_heal.filter import filter_completions, ppl_under_base from steer_heal.heal import heal_round from steer_heal.io import append_result, log_event, make_run_dir from steer_heal.plot import write_map -from steer_heal.steering import generate_steered, teacher_vec +from steer_heal.steering import generate_plain, generate_steered, teacher_vec from steer_heal.ws.bake import baked REPO = Path(__file__).resolve().parents[2] @@ -58,17 +58,22 @@ def _flatten_v(v) -> torch.Tensor: return torch.cat([v.state[li]["v"].flatten().float() for li in sorted(v.state)]) +def _mean_finite(xs) -> float: + xs = [x for x in xs if x == x and x != float("inf")] + return sum(xs) / len(xs) if xs else float("nan") + + def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict: hist_specs = [] # AdapterSpec per folded round (gated bake history) v0_flat = None # round-0 direction, for the Q3 cosine rounds = [] for rnd in range(cfg.n_rounds): logger.info(f"\n=== ROUND {rnd} [{cfg.model.split('/')[-1]} reg={cfg.reg}] ===") - # extract teacher vector + generate steered data from the CURRENT student + # extract teacher vector + sweep-generate steered data from the CURRENT student with baked(model, hist_specs): v = teacher_vec(model, tok, cfg) - comps = generate_steered(model, tok, v, alpha=cfg.gen_alpha, cfg=cfg) - # filter under the ORIGINAL (no history, no steering) + comps = generate_steered(model, tok, v, cfg) + # filter under the ORIGINAL (no history, no steering) -- this picks the usable C kept, scored = filter_completions(model, tok, comps, cfg) log_event(run_dir, stage="gen", round=rnd, n_comps=len(comps), n_kept=len(kept), scored=scored) @@ -77,24 +82,48 @@ def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict: lora.save(str(run_dir / "ckpt" / f"r{rnd}.safetensors"), extra_meta={"round": str(rnd), "reg": cfg.reg}) hist_specs.append(spec) - # eval the student (all rounds baked) + # eval the student (all rounds baked) + Q1: trained-adapter output 
coherence with baked(model, hist_specs): m = evaluate_model(model, tok, cfg) + adapter = generate_plain(model, tok, cfg, n=min(6, cfg.n_prompts)) + adapter_ppl = _mean_finite([ppl_under_base(model, tok, a["prompt"], a["completion"]) for a in adapter]) + steered_ppl = _mean_finite([s["ppl"] for s in scored]) + logger.info( + "SHOULD (Q1 heal): adapter_ppl < steered_ppl means the trained model expresses the trait " + "COHERENTLY (healed) where raw steering was incoherent. If adapter_ppl >= steered_ppl, " + f"healing failed. adapter_ppl={adapter_ppl:.0f} steered_ppl={steered_ppl:.0f}" + ) + logger.info(f"--- ADAPTER SAMPLE r{rnd} (no steering, SHOULD show trait + be coherent) ---\n" + f"{adapter[0]['completion'][:500]}") vf = _flatten_v(v) v0_flat = vf if v0_flat is None else v0_flat cos_v0 = float(cosine_similarity(vf, v0_flat, dim=0)) - rec = {"round": rnd, **m, "cos_v0": cos_v0, "c_star": float(v.cfg.coeff), "n_kept": len(kept)} + rec = {"round": rnd, **m, "cos_v0": cos_v0, "steered_ppl": steered_ppl, + "adapter_ppl": adapter_ppl, "n_kept": len(kept)} rounds.append(rec) log_event(run_dir, stage="round", **rec) logger.info(f"round {rnd}: socialnorms={m['socialnorms']:.3f} care={m['care']:.3f} " - f"coh={m['coherence']:.3f} cos_v0={cos_v0:+.2f}") + f"coh={m['coherence']:.3f} cos_v0={cos_v0:+.2f} adapter_ppl={adapter_ppl:.0f}") - map_path = write_map(run_dir, rounds) - logger.info(f"map: {map_path}") + _log_loop_summary(rounds) + write_map(run_dir, rounds) return rounds[-1] +def _log_loop_summary(rounds: list[dict]) -> None: + from tabulate import tabulate + logger.info( + "\nSHOULD (Q2 loop-coherent): coherence stays >= round-0 floor across rounds (heal holds it up). " + "If coherence falls each round, the loop accumulates incoherency faster than heal removes it.\n" + "SHOULD (Q3 direction): socialnorms FALLS / care RISES monotonically and cos_v0 stays > 0.5 " + "(same direction each round). If the trait reverses or cos_v0 drops, the direction wanders." 
+ ) + cols = ["round", "socialnorms", "care", "coherence", "cos_v0", "adapter_ppl", "n_kept"] + tbl = [{c: r.get(c) for c in cols} for r in rounds] + logger.info("\nloop summary:\n" + tabulate(tbl, headers="keys", tablefmt="github", floatfmt=".3f")) + + def main(cfg: RunConfig) -> None: setup_logging() cfg = resolve(cfg) diff --git a/src/steer_heal/steering.py b/src/steer_heal/steering.py index 99ecae1..1854578 100644 --- a/src/steer_heal/steering.py +++ b/src/steer_heal/steering.py @@ -25,29 +25,45 @@ def teacher_vec(model, tok, cfg: RunConfig): # in the system prompt. ELSE the vector mixes in user-turn differences. logger.debug(f"--- POS[0] (trait) ---\n{pos[0]}\n--- NEG[0] (neutral) ---\n{neg[0]}") - v = sl.Vector.train(model, tok, pos, neg, cfg=sl.MeanDiffC(layers=layers, normalize=True)) - # Wide bracket: the vector is unit-normalised, so reaching ~1 nat p95 KL on a - # real model needs c ~ O(100) (KL ~ c^2). steering-lite's default hi (~16) is - # too low and pins c_star at the bracket top. See RESEARCH_JOURNAL 2026-06-04. - v.calibrate(model, tok, target_kl=cfg.target_kl, bracket=(0.1, 1024.0)) - logger.info(f"teacher_vec: layers={layers} c_star={v.cfg.coeff:+.4f} (target_kl={cfg.target_kl})") + # RAW (unnormalised) mean-diff = the residual-stream shift the trait system + # prompt induces (Subliminal Learning teacher vector). No iso-KL calibration: + # we steer at the natural scale (C = alpha * coeff for alpha in cfg.alphas) and let the SFT/nll + # training + coherence filter self-calibrate the strength. + v = sl.Vector.train(model, tok, pos, neg, cfg=sl.MeanDiffC(layers=layers, normalize=False)) + logger.info(f"teacher_vec: layers={layers} raw mean-diff (no calibration), coeff={v.cfg.coeff}") return v @torch.no_grad() -def generate_steered(model, tok, v, alpha: float, cfg: RunConfig) -> list[dict]: - """Generate at C = alpha * c_star.
Returns [{prompt, user, completion}].""" +def _gen_one(model, tok, text, cfg): + ids = tok(text, return_tensors="pt").to(model.device) + gen = model.generate(**ids, max_new_tokens=cfg.gen_max_new_tokens, do_sample=True, + temperature=1.0, top_p=0.95, pad_token_id=tok.pad_token_id) + return tok.decode(gen[0, ids.input_ids.shape[1]:], skip_special_tokens=True) + + +def generate_steered(model, tok, v, cfg: RunConfig) -> list[dict]: + """Sweep cfg.alphas (raw-vector multiples); generate one completion per prompt x alpha. + + The filter (Q0), not iso-KL, picks the usable C: low alpha is coherent, high + alpha collapses, and we keep the coherent-but-trait-laden ones. + """ out = [] - C = alpha * v.cfg.coeff for i in range(cfg.n_prompts): user = POOL[i % len(POOL)] text = chat_prompt(tok, cfg.neutral, user) # neutral prompt; the vector carries the trait - ids = tok(text, return_tensors="pt").to(model.device) - with v(model, C=C): - gen = model.generate(**ids, max_new_tokens=cfg.gen_max_new_tokens, - do_sample=True, temperature=1.0, top_p=0.95, - pad_token_id=tok.pad_token_id) - completion = tok.decode(gen[0, ids.input_ids.shape[1]:], skip_special_tokens=True) - out.append({"user": user, "prompt": text, "completion": completion}) - logger.debug(f"--- GEN[0] @C={C:+.3f} ---\nUSER: {out[0]['user']}\nCOMP: {out[0]['completion'][:400]}") + for alpha in cfg.alphas: + with v(model, C=alpha * v.cfg.coeff): + comp = _gen_one(model, tok, text, cfg) + out.append({"user": user, "prompt": text, "completion": comp, "alpha": float(alpha)}) + return out + + +def generate_plain(model, tok, cfg: RunConfig, n: int) -> list[dict]: + """Generate from the (baked) model with NO steering, for the Q1 heal comparison.""" + out = [] + for i in range(n): + user = POOL[i % len(POOL)] + text = chat_prompt(tok, cfg.neutral, user) + out.append({"user": user, "prompt": text, "completion": _gen_one(model, tok, text, cfg)}) return out