mirror of
https://github.com/wassname/steer-heal-love.git
synced 2026-06-27 16:47:16 +08:00
Support round-tagged steered generation
This commit is contained in:
@@ -28,8 +28,11 @@ def _extract_prompts(cfg: RunConfig) -> list[str]:
|
|||||||
NOT domain dilemmas). A domain-narrow set overfits the direction to the format;
|
NOT domain dilemmas). A domain-narrow set overfits the direction to the format;
|
||||||
diverse suffixes isolate the persona's general residual-stream shift."""
|
diverse suffixes isolate the persona's general residual-stream shift."""
|
||||||
import json
|
import json
|
||||||
|
import random
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
suffixes = json.loads(Path(cfg.extract_data).read_text())
|
suffixes = json.loads(Path(cfg.extract_data).read_text())
|
||||||
|
rng = random.Random(cfg.seed)
|
||||||
|
rng.shuffle(suffixes)
|
||||||
return [s["suffix"] for s in suffixes[: cfg.n_extract_pairs]]
|
return [s["suffix"] for s in suffixes[: cfg.n_extract_pairs]]
|
||||||
|
|
||||||
|
|
||||||
@@ -44,7 +47,21 @@ def teacher_vec(model, tok, cfg: RunConfig):
|
|||||||
# in the system prompt (the persona prefix). ELSE the vector mixes in user-turn
|
# in the system prompt (the persona prefix). ELSE the vector mixes in user-turn
|
||||||
# differences. n_pairs ~256 diverse contexts (steering-lite reference), not 30 dilemmas.
|
# differences. n_pairs ~256 diverse contexts (steering-lite reference), not 30 dilemmas.
|
||||||
logger.info(f"teacher_vec: {len(pos)} contrastive pairs over diverse contexts, layers={layers}")
|
logger.info(f"teacher_vec: {len(pos)} contrastive pairs over diverse contexts, layers={layers}")
|
||||||
logger.debug(f"--- POS[0] (trait) ---\n{pos[0]}\n--- NEG[0] (neutral) ---\n{neg[0]}")
|
# Show completions for the first pair AND a seeded pick (avoids always landing on
|
||||||
|
# the same weird first suffix). The second index is derived from cfg.seed, so it varies across seeds (not across reruns with the same seed).
|
||||||
|
demo_indices = {0, (cfg.seed * 7) % len(pos)}
|
||||||
|
for idx in sorted(demo_indices):
|
||||||
|
pos_comp = _gen_one(model, tok, pos[idx], cfg, greedy=True)[:256]
|
||||||
|
neg_comp = _gen_one(model, tok, neg[idx], cfg, greedy=True)[:256]
|
||||||
|
logger.info(
|
||||||
|
f"\n=== EXTRACT demo trace pair[{idx}] ===\n"
|
||||||
|
f"POS prompt: {pos[idx][:200]}...\n"
|
||||||
|
f"POS comp (64): {pos_comp[:64]}\n"
|
||||||
|
f"NEG prompt: {neg[idx][:200]}...\n"
|
||||||
|
f"NEG comp (64): {neg_comp[:64]}\n"
|
||||||
|
f"--- full POS comp ---\n{pos_comp}\n"
|
||||||
|
f"--- full NEG comp ---\n{neg_comp}"
|
||||||
|
)
|
||||||
|
|
||||||
# RAW (unnormalised) mean-diff = the residual-stream shift the trait system
|
# RAW (unnormalised) mean-diff = the residual-stream shift the trait system
|
||||||
# prompt induces (Subliminal Learning teacher vector). No iso-KL calibration:
|
# prompt induces (Subliminal Learning teacher vector). No iso-KL calibration:
|
||||||
@@ -82,7 +99,7 @@ def _gen_one(model, tok, text, cfg, greedy: bool = False):
|
|||||||
|
|
||||||
|
|
||||||
def generate_steered(model, tok, v, cfg: RunConfig, alpha_scale: float = 1.0,
|
def generate_steered(model, tok, v, cfg: RunConfig, alpha_scale: float = 1.0,
|
||||||
max_gens: int | None = None) -> list[dict]:
|
max_gens: int | None = None, rnd: int | None = None) -> list[dict]:
|
||||||
"""Sweep cfg.alphas (raw-vector multiples); generate one completion per prompt x alpha.
|
"""Sweep cfg.alphas (raw-vector multiples); generate one completion per prompt x alpha.
|
||||||
|
|
||||||
The filter (Q0), not iso-KL, picks the usable C: low alpha is coherent, high
|
The filter (Q0), not iso-KL, picks the usable C: low alpha is coherent, high
|
||||||
@@ -93,7 +110,8 @@ def generate_steered(model, tok, v, cfg: RunConfig, alpha_scale: float = 1.0,
|
|||||||
"""
|
"""
|
||||||
out = []
|
out = []
|
||||||
n_total = min(cfg.n_prompts * len(cfg.alphas), max_gens) if max_gens else cfg.n_prompts * len(cfg.alphas)
|
n_total = min(cfg.n_prompts * len(cfg.alphas), max_gens) if max_gens else cfg.n_prompts * len(cfg.alphas)
|
||||||
logger.info(f"\n=== GEN steered [{n_total} = {cfg.n_prompts} prompts x {len(cfg.alphas)} alphas, "
|
rtag = f"r{rnd} " if rnd is not None else ""
|
||||||
|
logger.info(f"\n\n\n=== {rtag}GEN steered [{n_total} = {cfg.n_prompts} prompts x {len(cfg.alphas)} alphas, "
|
||||||
f"kappa={alpha_scale:.2f}] gpu {gpu_mem()} ===")
|
f"kappa={alpha_scale:.2f}] gpu {gpu_mem()} ===")
|
||||||
pbar = tqdm(total=n_total, desc="gen steered", mininterval=120, maxinterval=120)
|
pbar = tqdm(total=n_total, desc="gen steered", mininterval=120, maxinterval=120)
|
||||||
pool = pool_for(cfg.demo)
|
pool = pool_for(cfg.demo)
|
||||||
|
|||||||
Reference in New Issue
Block a user