mirror of
https://github.com/wassname/steer-heal-love.git
synced 2026-06-27 16:47:16 +08:00
implement pipeline: extract -> dose -> generate -> filter -> heal -> fold -> eval -> loop
fast-dev-run now runs end to end on wassname/qwen3-5lyr-tiny-random (out dir +
map.html + results.tsv row). Modules:
- steering.py: teacher_vec (steering-lite mean-diff @ assistant tag + iso-KL dose) + steered gen
- filter.py: Q0 coherence (ppl-under-original + repetition + narration regex)
- heal.py: Q1 SFT + divergence-to-original barrier (nll/kl_fwd/kl_rev/wd), reg=kl_rev default
- ws/{adapter,bake}.py: ModulatedLoRA + gated baked(), copied from w2schar-mini
- eval.py: tinymfv -> {auth, care, coherence(mean_pmass_allowed), ppx_json}
- plot.py: plotly Care-vs-Authority map.html (simplified w2schar port)
- io.py: out/{ts}_{slug}/ + srsly events.jsonl + shared results.tsv
- prompts.py: 30 authority dilemmas (POOL, copied) + assistant-tag chat templating
Trace confirms pos/neg prompts end at the assistant tag (paper's read).
Tiny-random numbers are junk by design; real calibration/eval pending small-dev.
Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -42,6 +42,10 @@ class RunConfig:
|
||||
epochs: int = 2
|
||||
lr: float = 1e-4
|
||||
|
||||
# ── eval (tinymfv) ──
|
||||
eval_vignettes: int | None = None # None = all Clifford-2015 vignettes
|
||||
eval_think_tokens: int = 64 # tinymfv default; 10x faster than 256, within bf16 noise
|
||||
|
||||
# ── loop (U3) ──
|
||||
n_rounds: int = 4
|
||||
|
||||
@@ -57,6 +61,10 @@ TINY = dict(
|
||||
epochs=1,
|
||||
n_rounds=1,
|
||||
alphas=(1.0,),
|
||||
eval_vignettes=4,
|
||||
eval_think_tokens=16,
|
||||
ppl_tau=1e9, # tiny-random produces junk ppl; relax the gate so the path still runs
|
||||
rep_tau=1.1,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
"""tinymfv eval -> {auth, care, coherence, ppx_json}.
|
||||
|
||||
auth/care are the model's mean probability on the Authority/Care moral
|
||||
foundations (the trait axis we move). coherence = mean_pmass_allowed (the
|
||||
forced-choice canary). These are kept distinct: we shift auth on purpose,
|
||||
coherence must not collapse.
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import tinymfv
|
||||
from loguru import logger
|
||||
|
||||
from steer_heal.config import RunConfig
|
||||
|
||||
|
||||
def evaluate_model(model, tok, cfg: RunConfig) -> dict:
|
||||
rep = tinymfv.evaluate(
|
||||
model, tok, name="classic",
|
||||
n_vignettes=cfg.eval_vignettes,
|
||||
conditions=("other_violate",),
|
||||
max_think_tokens=cfg.eval_think_tokens,
|
||||
batch_size=8,
|
||||
device=model.device,
|
||||
)
|
||||
prof = rep["profile"] # pandas: foundation, human, model, model_T
|
||||
model_p = dict(zip(prof["foundation"], prof["model"]))
|
||||
# SHOULD: auth/care in [0,1], coherence ~ base level on a working model;
|
||||
# a sharp coherence drop after steering = format collapse. On tiny-random
|
||||
# the numbers are junk (we test the path, not the value).
|
||||
out = {
|
||||
"auth": float(model_p["Authority"]),
|
||||
"care": float(model_p["Care"]),
|
||||
"coherence": float(rep["mean_pmass_allowed"]),
|
||||
"ppx_json": float(math.exp(rep["mean_nll_json"])),
|
||||
"top1_acc": float(rep["top1_acc"]),
|
||||
}
|
||||
logger.info(f"eval: auth={out['auth']:.3f} care={out['care']:.3f} "
|
||||
f"coherence={out['coherence']:.3f} ppx={out['ppx_json']:.1f}")
|
||||
return out
|
||||
@@ -0,0 +1,61 @@
|
||||
"""Q0 filter: drop incoherent and trait-narrating steered completions.
|
||||
|
||||
Coherence signal per completion = perplexity under the ORIGINAL model (no
|
||||
steering): incoherent babble is improbable under base. Plus a repetition guard
|
||||
and a first-person narration regex (we want enact, not narrate).
|
||||
"""
|
||||
|
||||
import math
|
||||
import re
|
||||
from collections import Counter
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from steer_heal.config import RunConfig
|
||||
|
||||
# first-person verbalisation of the trait ("I am someone who never defers...")
|
||||
NARRATE = re.compile(
|
||||
r"\b(i am (someone|a person) who|as someone who|i (don'?t|do not|never) defer|"
|
||||
r"my principle|i always stick to|i refuse to (defer|obey))\b",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def rep_frac(text: str) -> float:
|
||||
"""Most-repeated 4-gram fraction; ~1.0 means degenerate looping/too short."""
|
||||
words = text.split()
|
||||
grams = [tuple(words[i : i + 4]) for i in range(len(words) - 3)]
|
||||
if not grams:
|
||||
return 1.0
|
||||
return Counter(grams).most_common(1)[0][1] / len(grams)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def ppl_under_base(model, tok, prompt: str, completion: str) -> float:
|
||||
ids = tok(prompt + completion, return_tensors="pt").to(model.device)
|
||||
n_prompt = tok(prompt, return_tensors="pt").input_ids.shape[1]
|
||||
logits = model(**ids).logits[0]
|
||||
labels = ids.input_ids[0, n_prompt:]
|
||||
if labels.numel() == 0:
|
||||
return float("inf")
|
||||
logp = logits[n_prompt - 1 : -1].log_softmax(-1)
|
||||
nll = -logp[torch.arange(labels.numel()), labels].mean()
|
||||
return math.exp(nll.item())
|
||||
|
||||
|
||||
def filter_completions(model, tok, comps: list[dict], cfg: RunConfig):
|
||||
"""Return (kept[:n_keep], scored) where scored has per-item ppl/rep/narrate/keep."""
|
||||
scored = []
|
||||
for c in comps:
|
||||
rf = rep_frac(c["completion"])
|
||||
nar = bool(NARRATE.search(c["completion"]))
|
||||
ppl = ppl_under_base(model, tok, c["prompt"], c["completion"])
|
||||
keep = (ppl < cfg.ppl_tau) and (rf < cfg.rep_tau) and (not nar)
|
||||
scored.append({**c, "ppl": ppl, "rep": rf, "narrates": nar, "keep": keep})
|
||||
kept = [s for s in scored if s["keep"]]
|
||||
# SHOULD: on a real model, kept drops the babble and keeps fluent dilemmas
|
||||
# answers; if kept==0 the gate is too strict (raise ppl_tau) or steering
|
||||
# broke generation. On tiny-random everything passes (relaxed tau).
|
||||
logger.info(f"filter: kept {len(kept)}/{len(comps)} (ppl<{cfg.ppl_tau:g}, rep<{cfg.rep_tau}, not-narrate)")
|
||||
return kept[: cfg.n_keep], scored
|
||||
@@ -0,0 +1,69 @@
|
||||
"""Q1 heal: train one round's LoRA = SFT(kept completions) + divergence-to-original barrier.
|
||||
|
||||
The barrier reference is the round-0 ORIGINAL (gates/adapters off), not the
|
||||
previous student, so it resists cumulative drift. reg picks the divergence:
|
||||
nll SFT only (control)
|
||||
kl_fwd KL(orig || theta) mass-covering (dilutes the trait)
|
||||
kl_rev KL(theta || orig) mode-seeking (suppresses low-orig-prob = incoherent) [expected best]
|
||||
wd weight decay on the adapter only
|
||||
"""
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
from torch.nn import functional as F
|
||||
|
||||
from steer_heal.config import RunConfig
|
||||
from steer_heal.ws.adapter import ModulatedLoRA
|
||||
from steer_heal.ws.bake import AdapterSpec, baked
|
||||
|
||||
|
||||
def _kl_per_pos(logp_a, logp_b): # KL(a || b) summed over vocab, per position
|
||||
return (logp_a.exp() * (logp_a - logp_b)).sum(-1)
|
||||
|
||||
|
||||
def _encode(tok, prompt: str, completion: str, max_len: int, device):
|
||||
ids = tok(prompt + completion, return_tensors="pt", truncation=True, max_length=max_len).to(device)
|
||||
n_prompt = tok(prompt, return_tensors="pt").input_ids.shape[1]
|
||||
L = ids.input_ids.shape[1]
|
||||
tgt_is_completion = torch.arange(1, L, device=device) >= n_prompt # mask over next-token targets
|
||||
return ids, tgt_is_completion
|
||||
|
||||
|
||||
def heal_round(model, tok, kept: list[dict], hist_specs: list[AdapterSpec], cfg: RunConfig):
|
||||
"""Train a fresh round adapter on top of baked history. Returns (lora, spec)."""
|
||||
lora = ModulatedLoRA(model, r=cfg.lora_r, alpha=cfg.lora_alpha, layer_range=cfg.layer_range)
|
||||
opt = torch.optim.AdamW(list(lora.parameters()), lr=cfg.lr,
|
||||
weight_decay=(cfg.lam if cfg.reg == "wd" else 0.0))
|
||||
|
||||
for ep in range(cfg.epochs):
|
||||
for c in kept:
|
||||
ids, mask = _encode(tok, c["prompt"], c["completion"], cfg.max_len, model.device)
|
||||
if mask.sum() == 0:
|
||||
continue # completion truncated away; nothing to learn here
|
||||
|
||||
# original reference logits (no history, adapter off) for the barrier
|
||||
if cfg.reg in ("kl_fwd", "kl_rev"):
|
||||
with torch.no_grad(), lora(model, c=0.0):
|
||||
logp0 = model(**ids).logits[0, :-1].log_softmax(-1)
|
||||
|
||||
# student logits: history baked + this round's adapter live
|
||||
with baked(model, hist_specs), lora(model, c=1.0):
|
||||
logits = model(**ids).logits[0, :-1]
|
||||
logp = logits.log_softmax(-1)
|
||||
|
||||
tgt = ids.input_ids[0, 1:]
|
||||
sft = F.nll_loss(logp[mask], tgt[mask])
|
||||
if cfg.reg == "kl_fwd":
|
||||
div = _kl_per_pos(logp0[mask], logp[mask]).mean()
|
||||
elif cfg.reg == "kl_rev":
|
||||
div = _kl_per_pos(logp[mask], logp0[mask]).mean()
|
||||
else:
|
||||
div = torch.zeros((), device=model.device) # nll, wd
|
||||
loss = sft + cfg.lam * torch.relu(div - cfg.tau)
|
||||
loss.backward()
|
||||
opt.step()
|
||||
opt.zero_grad()
|
||||
logger.info(f"heal[{cfg.reg}] epoch {ep}: sft={sft.item():.3f} div={float(div):.3f}")
|
||||
|
||||
spec = AdapterSpec.from_lora(lora, default_c=1.0) # CPU-resident, for the next round's history
|
||||
return lora, spec
|
||||
@@ -32,7 +32,7 @@ def make_run_dir(ts: str, slug: str, cfg) -> Path:
|
||||
|
||||
def log_event(run_dir: Path, **rec) -> None:
|
||||
# append one jsonl line; events.jsonl is the full machine-readable trace.
|
||||
srsly.append_jsonl(run_dir / "events.jsonl", rec)
|
||||
srsly.write_jsonl(run_dir / "events.jsonl", [rec], append=True)
|
||||
|
||||
|
||||
def append_result(cfg, metrics: dict) -> None:
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
"""Loop map: Care (y) vs Authority (x) trajectory + coherence/cosine panels.
|
||||
|
||||
Simplified from wassname/w2schar-mini csm/plot.py _build_scatter (full git-graph
|
||||
port is a later pass). One node per round; hover shows coherence and the
|
||||
round-0 cosine. Saved as out/{ts}_{slug}/map.html.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import plotly.graph_objects as go
|
||||
from plotly.subplots import make_subplots
|
||||
|
||||
|
||||
def write_map(run_dir: Path, rounds: list[dict]) -> Path:
|
||||
r = [d["round"] for d in rounds]
|
||||
fig = make_subplots(
|
||||
rows=1, cols=2, column_widths=[0.6, 0.4],
|
||||
subplot_titles=("trait map: Care vs Authority", "coherence + direction per round"),
|
||||
specs=[[{"type": "scatter"}, {"type": "scatter"}]],
|
||||
)
|
||||
# trajectory across the auth axis, coloured by round
|
||||
fig.add_trace(go.Scatter(
|
||||
x=[d["auth"] for d in rounds], y=[d["care"] for d in rounds],
|
||||
mode="lines+markers+text", text=[f"r{i}" for i in r], textposition="top center",
|
||||
marker=dict(size=12, color=r, colorscale="Viridis", showscale=False),
|
||||
hovertext=[f"r{d['round']} coh={d['coherence']:.3f} cos={d.get('cos_v0', float('nan')):.2f}"
|
||||
for d in rounds],
|
||||
name="trajectory",
|
||||
), row=1, col=1)
|
||||
fig.update_xaxes(title_text="Authority p (trait →)", row=1, col=1)
|
||||
fig.update_yaxes(title_text="Care p", row=1, col=1)
|
||||
|
||||
fig.add_trace(go.Scatter(x=r, y=[d["coherence"] for d in rounds],
|
||||
mode="lines+markers", name="coherence"), row=1, col=2)
|
||||
fig.add_trace(go.Scatter(x=r, y=[d.get("cos_v0", float("nan")) for d in rounds],
|
||||
mode="lines+markers", name="cos(v_r, v_0)"), row=1, col=2)
|
||||
fig.update_xaxes(title_text="round", row=1, col=2)
|
||||
|
||||
out = run_dir / "map.html"
|
||||
fig.write_html(out, include_plotlyjs="cdn")
|
||||
return out
|
||||
+57
-65
@@ -1,45 +1,33 @@
|
||||
"""steer_heal pipeline entry point.
|
||||
"""steer_heal pipeline: extract -> dose -> generate -> filter -> heal -> fold -> eval -> loop.
|
||||
|
||||
Loop (anchored to the round-0 original throughout, see spec.md):
|
||||
teacher vector (steering-lite) -> iso-KL dose -> generate -> U1 filter
|
||||
-> heal one round (SFT + KL-rev-to-original barrier) -> fold (gated bake)
|
||||
-> tinymfv eval -> repeat.
|
||||
|
||||
Stages marked TODO are ported from docs/vendor/* as we implement them; this
|
||||
file fails fast at the first unimplemented stage rather than stubbing fake
|
||||
behaviour. `--fast-dev-run` runs the whole thing on the tiny-random model.
|
||||
Anchored to the round-0 original throughout (KL reference = adapters/gates off).
|
||||
`--fast-dev-run` runs the whole thing on the tiny-random model. See spec.md.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import tyro
|
||||
from loguru import logger
|
||||
from torch.nn.functional import cosine_similarity
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from steer_heal.config import RunConfig, resolve
|
||||
from steer_heal.eval import evaluate_model
|
||||
from steer_heal.filter import filter_completions
|
||||
from steer_heal.heal import heal_round
|
||||
from steer_heal.io import append_result, log_event, make_run_dir
|
||||
from steer_heal.plot import write_map
|
||||
from steer_heal.steering import generate_steered, teacher_vec
|
||||
from steer_heal.ws.bake import baked
|
||||
|
||||
REPO = Path(__file__).resolve().parents[2]
|
||||
RESULTS_TSV = REPO / "results.tsv" # one row per finished run; `just results` aggregates
|
||||
|
||||
|
||||
def append_result(cfg: RunConfig, metrics: dict) -> None:
|
||||
# self-describing, copy-paste reproducible row: config + final metrics + argv.
|
||||
row = {**dataclasses.asdict(cfg), **metrics, "argv": " ".join(sys.argv[1:])}
|
||||
new = not RESULTS_TSV.exists()
|
||||
with open(RESULTS_TSV, "a") as f:
|
||||
if new:
|
||||
f.write("\t".join(row) + "\n")
|
||||
f.write("\t".join(str(v) for v in row.values()) + "\n")
|
||||
|
||||
|
||||
def setup_logging() -> None:
|
||||
# tqdm-safe loguru, single-char level icons, verbose copy on disk.
|
||||
logger.remove()
|
||||
logger.add(lambda m: tqdm.write(m, end=""), colorize=True,
|
||||
format="<level>{level.icon}</level> {message}", level="INFO")
|
||||
@@ -48,8 +36,7 @@ def setup_logging() -> None:
|
||||
log_dir = REPO / "logs"
|
||||
log_dir.mkdir(exist_ok=True)
|
||||
f = log_dir / f"{datetime.now():%Y%m%dT%H%M%S}_verbose.log"
|
||||
logger.add(f, format="{time:HH:mm:ss} | {level: <7} | {name}:{function}:{line} - {message}",
|
||||
level="DEBUG")
|
||||
logger.add(f, format="{time:HH:mm:ss} | {level: <7} | {name}:{function}:{line} - {message}", level="DEBUG")
|
||||
logger.info(f"verbose log: {f}")
|
||||
|
||||
|
||||
@@ -57,64 +44,69 @@ def load_model(model_id: str, dtype: torch.dtype):
|
||||
tok = AutoTokenizer.from_pretrained(model_id)
|
||||
if tok.pad_token is None:
|
||||
tok.pad_token = tok.eos_token
|
||||
tok.padding_side = "left"
|
||||
attn = os.environ.get("STEER_ATTN_IMPL", "eager") # set =flash_attention_2 on real runs
|
||||
attn = os.environ.get("STEER_ATTN_IMPL", "eager")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, device_map="auto", torch_dtype=dtype,
|
||||
low_cpu_mem_usage=True, attn_implementation=attn,
|
||||
model_id, device_map="auto", torch_dtype=dtype, low_cpu_mem_usage=True,
|
||||
attn_implementation=attn,
|
||||
)
|
||||
model.eval()
|
||||
logger.info(f"loaded {model_id} (dtype={dtype}, attn={attn})")
|
||||
logger.info(f"loaded {model_id} (dtype={dtype}, attn={attn}, layers={model.config.num_hidden_layers})")
|
||||
return model, tok
|
||||
|
||||
|
||||
# ── stages (ported from docs/vendor as we implement; fail fast until then) ──
|
||||
|
||||
def teacher_vec(model, tok, cfg: RunConfig):
|
||||
# steering-lite Vector.train(pos=trait-sysprompt, neg=neutral-sysprompt) @ assistant tag,
|
||||
# then .calibrate(target_kl=cfg.target_kl). See docs/vendor/steering-lite + isokl.
|
||||
raise NotImplementedError("TODO: teacher_vec via steering-lite + iso-KL calibration")
|
||||
def _flatten_v(v) -> torch.Tensor:
|
||||
return torch.cat([v.state[li]["v"].flatten().float() for li in sorted(v.state)])
|
||||
|
||||
|
||||
def generate_and_filter(model, tok, v, orig, cfg: RunConfig):
|
||||
# gen at alpha*c_star (steering-lite hook); keep coherent & enact-not-narrate (U1).
|
||||
raise NotImplementedError("TODO: generate_and_filter (U1 gate)")
|
||||
def steer_heal(model, tok, cfg: RunConfig, run_dir: Path) -> dict:
|
||||
hist_specs = [] # AdapterSpec per folded round (gated bake history)
|
||||
v0_flat = None # round-0 direction, for the Q3 cosine
|
||||
rounds = []
|
||||
for rnd in range(cfg.n_rounds):
|
||||
logger.info(f"\n=== ROUND {rnd} [{cfg.model.split('/')[-1]} reg={cfg.reg}] ===")
|
||||
# extract teacher vector + generate steered data from the CURRENT student
|
||||
with baked(model, hist_specs):
|
||||
v = teacher_vec(model, tok, cfg)
|
||||
comps = generate_steered(model, tok, v, alpha=1.0, cfg=cfg)
|
||||
# filter under the ORIGINAL (no history, no steering)
|
||||
kept, scored = filter_completions(model, tok, comps, cfg)
|
||||
log_event(run_dir, stage="gen", round=rnd, n_comps=len(comps), n_kept=len(kept), scored=scored)
|
||||
|
||||
# heal one round on top of the baked history, then fold
|
||||
lora, spec = heal_round(model, tok, kept, hist_specs, cfg)
|
||||
lora.save(str(run_dir / "ckpt" / f"r{rnd}.safetensors"), extra_meta={"round": str(rnd), "reg": cfg.reg})
|
||||
hist_specs.append(spec)
|
||||
|
||||
def heal(model, orig, comps, cfg: RunConfig):
|
||||
# SFT + lam*relu(div - tau); div in {nll, kl_fwd, kl_rev, wd}; KL ref = orig (gates off).
|
||||
# adapter + gated bake ported from docs/vendor/w2schar-mini/src/csm/ws.
|
||||
raise NotImplementedError("TODO: heal (U2 barrier) + fold via w2schar ws.bake")
|
||||
# eval the student (all rounds baked)
|
||||
with baked(model, hist_specs):
|
||||
m = evaluate_model(model, tok, cfg)
|
||||
|
||||
vf = _flatten_v(v)
|
||||
v0_flat = vf if v0_flat is None else v0_flat
|
||||
cos_v0 = float(cosine_similarity(vf, v0_flat, dim=0))
|
||||
rec = {"round": rnd, **m, "cos_v0": cos_v0, "c_star": float(v.cfg.coeff), "n_kept": len(kept)}
|
||||
rounds.append(rec)
|
||||
log_event(run_dir, stage="round", **rec)
|
||||
logger.info(f"round {rnd}: auth={m['auth']:.3f} care={m['care']:.3f} "
|
||||
f"coh={m['coherence']:.3f} cos_v0={cos_v0:+.2f}")
|
||||
|
||||
def evaluate(model, cfg: RunConfig) -> dict:
|
||||
# tinymfv auth/care axes + p_ans_any/json_is_valid/ppx_json.
|
||||
raise NotImplementedError("TODO: tinymfv eval + plotly map (port csm/plot.py _build_scatter)")
|
||||
|
||||
|
||||
def steer_heal(model, tok, orig, cfg: RunConfig) -> dict:
|
||||
metrics = {}
|
||||
for r in range(cfg.n_rounds):
|
||||
logger.info(f"── round {r} ──")
|
||||
v = teacher_vec(model, tok, cfg)
|
||||
comps = generate_and_filter(model, tok, v, orig, cfg)
|
||||
heal(model, orig, comps, cfg)
|
||||
metrics = {"round": r, **evaluate(model, cfg)}
|
||||
logger.info(metrics)
|
||||
return metrics # final round, for results.tsv
|
||||
map_path = write_map(run_dir, rounds)
|
||||
logger.info(f"map: {map_path}")
|
||||
return rounds[-1]
|
||||
|
||||
|
||||
def main(cfg: RunConfig) -> None:
|
||||
setup_logging()
|
||||
cfg = resolve(cfg)
|
||||
torch.manual_seed(cfg.seed)
|
||||
logger.info(f"config: {cfg}")
|
||||
dtype = getattr(torch, cfg.dtype)
|
||||
model, tok = load_model(cfg.model, dtype)
|
||||
orig = model # round-0 anchor; KL reference = same module with adapter gates off
|
||||
metrics = steer_heal(model, tok, orig, cfg)
|
||||
append_result(cfg, metrics)
|
||||
logger.info(f"done; appended to {RESULTS_TSV.name}")
|
||||
ts = datetime.now().strftime("%Y%m%dT%H%M%S")
|
||||
slug = f"{cfg.model.split('/')[-1]}_{cfg.reg}_s{cfg.seed}"
|
||||
run_dir = make_run_dir(ts, slug, cfg)
|
||||
logger.info(f"argv cfg: {cfg}")
|
||||
model, tok = load_model(cfg.model, getattr(torch, cfg.dtype))
|
||||
final = steer_heal(model, tok, cfg, run_dir)
|
||||
append_result(cfg, {"slug": slug, **final})
|
||||
logger.info(f"done: {run_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user