mirror of
https://github.com/wassname/steer-heal-love.git
synced 2026-06-27 15:32:28 +08:00
Support round-tagged steered generation
This commit is contained in:
@@ -28,8 +28,11 @@ def _extract_prompts(cfg: RunConfig) -> list[str]:
|
||||
NOT domain dilemmas). A domain-narrow set overfits the direction to the format;
|
||||
diverse suffixes isolate the persona's general residual-stream shift."""
|
||||
import json
|
||||
import random
|
||||
from pathlib import Path
|
||||
suffixes = json.loads(Path(cfg.extract_data).read_text())
|
||||
rng = random.Random(cfg.seed)
|
||||
rng.shuffle(suffixes)
|
||||
return [s["suffix"] for s in suffixes[: cfg.n_extract_pairs]]
|
||||
|
||||
|
||||
@@ -44,7 +47,21 @@ def teacher_vec(model, tok, cfg: RunConfig):
|
||||
# in the system prompt (the persona prefix). ELSE the vector mixes in user-turn
|
||||
# differences. n_pairs ~256 diverse contexts (steering-lite reference), not 30 dilemmas.
|
||||
logger.info(f"teacher_vec: {len(pos)} contrastive pairs over diverse contexts, layers={layers}")
|
||||
logger.debug(f"--- POS[0] (trait) ---\n{pos[0]}\n--- NEG[0] (neutral) ---\n{neg[0]}")
|
||||
# Show completions for the first pair AND a seeded pick (avoids always landing on
|
||||
# the same weird first suffix). Seed primes which pair so it varies across runs.
|
||||
demo_indices = {0, (cfg.seed * 7) % len(pos)}
|
||||
for idx in sorted(demo_indices):
|
||||
pos_comp = _gen_one(model, tok, pos[idx], cfg, greedy=True)[:256]
|
||||
neg_comp = _gen_one(model, tok, neg[idx], cfg, greedy=True)[:256]
|
||||
logger.info(
|
||||
f"\n=== EXTRACT demo trace pair[{idx}] ===\n"
|
||||
f"POS prompt: {pos[idx][:200]}...\n"
|
||||
f"POS comp (64): {pos_comp[:64]}\n"
|
||||
f"NEG prompt: {neg[idx][:200]}...\n"
|
||||
f"NEG comp (64): {neg_comp[:64]}\n"
|
||||
f"--- full POS comp ---\n{pos_comp}\n"
|
||||
f"--- full NEG comp ---\n{neg_comp}"
|
||||
)
|
||||
|
||||
# RAW (unnormalised) mean-diff = the residual-stream shift the trait system
|
||||
# prompt induces (Subliminal Learning teacher vector). No iso-KL calibration:
|
||||
@@ -82,7 +99,7 @@ def _gen_one(model, tok, text, cfg, greedy: bool = False):
|
||||
|
||||
|
||||
def generate_steered(model, tok, v, cfg: RunConfig, alpha_scale: float = 1.0,
|
||||
max_gens: int | None = None) -> list[dict]:
|
||||
max_gens: int | None = None, rnd: int | None = None) -> list[dict]:
|
||||
"""Sweep cfg.alphas (raw-vector multiples); generate one completion per prompt x alpha.
|
||||
|
||||
The filter (Q0), not iso-KL, picks the usable C: low alpha is coherent, high
|
||||
@@ -93,7 +110,8 @@ def generate_steered(model, tok, v, cfg: RunConfig, alpha_scale: float = 1.0,
|
||||
"""
|
||||
out = []
|
||||
n_total = min(cfg.n_prompts * len(cfg.alphas), max_gens) if max_gens else cfg.n_prompts * len(cfg.alphas)
|
||||
logger.info(f"\n=== GEN steered [{n_total} = {cfg.n_prompts} prompts x {len(cfg.alphas)} alphas, "
|
||||
rtag = f"r{rnd} " if rnd is not None else ""
|
||||
logger.info(f"\n\n\n=== {rtag}GEN steered [{n_total} = {cfg.n_prompts} prompts x {len(cfg.alphas)} alphas, "
|
||||
f"kappa={alpha_scale:.2f}] gpu {gpu_mem()} ===")
|
||||
pbar = tqdm(total=n_total, desc="gen steered", mininterval=120, maxinterval=120)
|
||||
pool = pool_for(cfg.demo)
|
||||
|
||||
Reference in New Issue
Block a user