From ee549450769e94ac7ba6c2e6b64bdc8c3a5e5ae8 Mon Sep 17 00:00:00 2001 From: wassname <1103714+wassname@users.noreply.github.com> Date: Wed, 24 Jun 2026 20:49:48 +0800 Subject: [PATCH] Support round-tagged steered generation --- src/steer_heal/steering.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/src/steer_heal/steering.py b/src/steer_heal/steering.py index e3884f6..a7a66e2 100644 --- a/src/steer_heal/steering.py +++ b/src/steer_heal/steering.py @@ -28,8 +28,11 @@ def _extract_prompts(cfg: RunConfig) -> list[str]: NOT domain dilemmas). A domain-narrow set overfits the direction to the format; diverse suffixes isolate the persona's general residual-stream shift.""" import json + import random from pathlib import Path suffixes = json.loads(Path(cfg.extract_data).read_text()) + rng = random.Random(cfg.seed) + rng.shuffle(suffixes) return [s["suffix"] for s in suffixes[: cfg.n_extract_pairs]] @@ -44,7 +47,21 @@ def teacher_vec(model, tok, cfg: RunConfig): # in the system prompt (the persona prefix). ELSE the vector mixes in user-turn # differences. n_pairs ~256 diverse contexts (steering-lite reference), not 30 dilemmas. logger.info(f"teacher_vec: {len(pos)} contrastive pairs over diverse contexts, layers={layers}") - logger.debug(f"--- POS[0] (trait) ---\n{pos[0]}\n--- NEG[0] (neutral) ---\n{neg[0]}") + # Show completions for the first pair AND a seeded pick (avoids always landing on + # the same weird first suffix). Seed primes which pair so it varies across runs. + demo_indices = {0, (cfg.seed * 7) % len(pos)} + for idx in sorted(demo_indices): + pos_comp = _gen_one(model, tok, pos[idx], cfg, greedy=True)[:256] + neg_comp = _gen_one(model, tok, neg[idx], cfg, greedy=True)[:256] + logger.info( + f"\n=== EXTRACT demo trace pair[{idx}] ===\n" + f"POS prompt: {pos[idx][:200]}...\n" + f"POS comp (64): {pos_comp[:64]}\n" + f"NEG prompt: {neg[idx][:200]}...\n" + f"NEG comp (64): {neg_comp[:64]}\n" + f"--- full POS comp ---\n{pos_comp}\n" + f"--- full NEG comp ---\n{neg_comp}" + ) # RAW (unnormalised) mean-diff = the residual-stream shift the trait system # prompt induces (Subliminal Learning teacher vector). No iso-KL calibration: @@ -82,7 +99,7 @@ def _gen_one(model, tok, text, cfg, greedy: bool = False): def generate_steered(model, tok, v, cfg: RunConfig, alpha_scale: float = 1.0, - max_gens: int | None = None) -> list[dict]: + max_gens: int | None = None, rnd: int | None = None) -> list[dict]: """Sweep cfg.alphas (raw-vector multiples); generate one completion per prompt x alpha. The filter (Q0), not iso-KL, picks the usable C: low alpha is coherent, high @@ -93,7 +110,8 @@ def generate_steered(model, tok, v, cfg: RunConfig, alpha_scale: float = 1.0, """ out = [] n_total = min(cfg.n_prompts * len(cfg.alphas), max_gens) if max_gens else cfg.n_prompts * len(cfg.alphas) - logger.info(f"\n=== GEN steered [{n_total} = {cfg.n_prompts} prompts x {len(cfg.alphas)} alphas, " + rtag = f"r{rnd} " if rnd is not None else "" + logger.info(f"\n\n\n=== {rtag}GEN steered [{n_total} = {cfg.n_prompts} prompts x {len(cfg.alphas)} alphas, " f"kappa={alpha_scale:.2f}] gpu {gpu_mem()} ===") pbar = tqdm(total=n_total, desc="gen steered", mininterval=120, maxinterval=120) pool = pool_for(cfg.demo)