Support round-tagged steered generation

This commit is contained in:
wassname
2026-06-24 20:49:48 +08:00
parent 282fb3de47
commit ee54945076
+21 -3
View File
@@ -28,8 +28,11 @@ def _extract_prompts(cfg: RunConfig) -> list[str]:
NOT domain dilemmas). A domain-narrow set overfits the direction to the format;
diverse suffixes isolate the persona's general residual-stream shift."""
import json
import random
from pathlib import Path
suffixes = json.loads(Path(cfg.extract_data).read_text())
rng = random.Random(cfg.seed)
rng.shuffle(suffixes)
return [s["suffix"] for s in suffixes[: cfg.n_extract_pairs]]
@@ -44,7 +47,21 @@ def teacher_vec(model, tok, cfg: RunConfig):
# in the system prompt (the persona prefix). ELSE the vector mixes in user-turn
# differences. n_pairs ~256 diverse contexts (steering-lite reference), not 30 dilemmas.
logger.info(f"teacher_vec: {len(pos)} contrastive pairs over diverse contexts, layers={layers}")
logger.debug(f"--- POS[0] (trait) ---\n{pos[0]}\n--- NEG[0] (neutral) ---\n{neg[0]}")
# Show completions for the first pair AND a seeded pick (avoids always landing on
# the same weird first suffix). Seed primes which pair so it varies across runs.
demo_indices = {0, (cfg.seed * 7) % len(pos)}
for idx in sorted(demo_indices):
pos_comp = _gen_one(model, tok, pos[idx], cfg, greedy=True)[:256]
neg_comp = _gen_one(model, tok, neg[idx], cfg, greedy=True)[:256]
logger.info(
f"\n=== EXTRACT demo trace pair[{idx}] ===\n"
f"POS prompt: {pos[idx][:200]}...\n"
f"POS comp (64): {pos_comp[:64]}\n"
f"NEG prompt: {neg[idx][:200]}...\n"
f"NEG comp (64): {neg_comp[:64]}\n"
f"--- full POS comp ---\n{pos_comp}\n"
f"--- full NEG comp ---\n{neg_comp}"
)
# RAW (unnormalised) mean-diff = the residual-stream shift the trait system
# prompt induces (Subliminal Learning teacher vector). No iso-KL calibration:
@@ -82,7 +99,7 @@ def _gen_one(model, tok, text, cfg, greedy: bool = False):
def generate_steered(model, tok, v, cfg: RunConfig, alpha_scale: float = 1.0,
max_gens: int | None = None) -> list[dict]:
max_gens: int | None = None, rnd: int | None = None) -> list[dict]:
"""Sweep cfg.alphas (raw-vector multiples); generate one completion per prompt x alpha.
The filter (Q0), not iso-KL, picks the usable C: low alpha is coherent, high
@@ -93,7 +110,8 @@ def generate_steered(model, tok, v, cfg: RunConfig, alpha_scale: float = 1.0,
"""
out = []
n_total = min(cfg.n_prompts * len(cfg.alphas), max_gens) if max_gens else cfg.n_prompts * len(cfg.alphas)
logger.info(f"\n=== GEN steered [{n_total} = {cfg.n_prompts} prompts x {len(cfg.alphas)} alphas, "
rtag = f"r{rnd} " if rnd is not None else ""
logger.info(f"\n\n\n=== {rtag}GEN steered [{n_total} = {cfg.n_prompts} prompts x {len(cfg.alphas)} alphas, "
f"kappa={alpha_scale:.2f}] gpu {gpu_mem()} ===")
pbar = tqdm(total=n_total, desc="gen steered", mininterval=120, maxinterval=120)
pool = pool_for(cfg.demo)