diff --git a/docs/brainstorming/.gitkeep b/docs/brainstorming/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/docs/evidence/.gitkeep b/docs/evidence/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/docs/literature/.gitkeep b/docs/literature/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/docs/results.md b/docs/results.md new file mode 100644 index 0000000..bd6cedb --- /dev/null +++ b/docs/results.md @@ -0,0 +1,20 @@ +# Results, organized by the question each run answers + +Regenerate the tables with `just results` (groups `results.tsv` by arm). This file curates the answers; the append-only narrative lives in `RESEARCH_JOURNAL.md`. + +## How to read this + +- `auth` is the tinymfv authority-axis mean for the steered student (higher = more of the trait); `coherence` is `p_ans_any` (fraction of eval items where the model commits to a valid answer). Both are absolute fractions, compare rows within a table by eye. +- A regulariser is compared to the `nll` control only at matched `auth` (the U2 crux: more coherence at equal trait shift). +- `auth_sd` is the across-seed spread; a blank means a single seed. +- Provenance for each table goes in an HTML comment so any row can be re-created. + +## Q (U2). Which regulariser heals incoherency best at matched trait shift? + + + +Prior: `kl_rev > kl_fwd ~ wd > nll` (reverse KL is mode-seeking, suppresses the low-original-probability tokens that read as incoherent). + +No runs yet. Table appears here once `sweep-reg` has produced rows. + +**Answer:** pending. diff --git a/docs/spec/.gitkeep b/docs/spec/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/justfile b/justfile index 8224db1..b93e5ab 100644 --- a/justfile +++ b/justfile @@ -43,5 +43,9 @@ sweep-reg: done done +# Aggregate results.tsv into a by-arm markdown table. +results: + uv run python scripts/results.py + # flash-attn: install a prebuilt wheel (see `flash-attn-prebuilt` skill), then # run with STEER_ATTN_IMPL=flash_attention_2. diff --git a/scripts/results.py b/scripts/results.py new file mode 100644 index 0000000..6a2d898 --- /dev/null +++ b/scripts/results.py @@ -0,0 +1,31 @@ +"""`just results`: group results.tsv into comparable arms and print a markdown table. + +Grouping key is `reg` (the regulariser under test, U2); argv last so each row is +copy-paste reproducible. Edit GROUP when the knob under test changes. +""" + +from pathlib import Path + +import polars as pl +from tabulate import tabulate + +RESULTS_TSV = Path(__file__).resolve().parents[1] / "results.tsv" +GROUP = ["reg"] # all-else-equal grouping; the arm under test + +if not RESULTS_TSV.exists(): + raise SystemExit(f"no {RESULTS_TSV.name} yet; run something first") + +df = pl.read_csv(RESULTS_TSV, separator="\t") +agg = ( + df.group_by(GROUP) + .agg( + pl.col("p_ans_any").mean().round(3).alias("coherence"), + pl.col("auth").mean().round(3), + pl.col("auth").std().round(3).alias("auth_sd"), + pl.len().alias("n"), + pl.col("seed").cast(pl.Utf8).sort().str.join(",").alias("seeds"), + pl.col("argv").first(), + ) + .sort("auth", descending=True) +) +print(tabulate(agg.to_pandas(), headers="keys", tablefmt="pipe", floatfmt="+.3f")) diff --git a/src/steer_heal/run.py b/src/steer_heal/run.py index 9770f59..a500d1e 100644 --- a/src/steer_heal/run.py +++ b/src/steer_heal/run.py @@ -10,6 +10,7 @@ file fails fast at the first unimplemented stage rather than stubbing fake behaviour. `--fast-dev-run` runs the whole thing on the tiny-random model. """ +import dataclasses import os import sys from datetime import datetime @@ -24,6 +25,17 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from steer_heal.config import RunConfig, resolve REPO = Path(__file__).resolve().parents[2] +RESULTS_TSV = REPO / "results.tsv" # one row per finished run; `just results` aggregates + + +def append_result(cfg: RunConfig, metrics: dict) -> None: + # self-describing, copy-paste reproducible row: config + final metrics + argv. + row = {**dataclasses.asdict(cfg), **metrics, "argv": " ".join(sys.argv[1:])} + new = not RESULTS_TSV.exists() + with open(RESULTS_TSV, "a") as f: + if new: + f.write("\t".join(row) + "\n") + f.write("\t".join(str(v) for v in row.values()) + "\n") def setup_logging() -> None: @@ -80,14 +92,16 @@ def evaluate(model, cfg: RunConfig) -> dict: raise NotImplementedError("TODO: tinymfv eval + plotly map (port csm/plot.py _build_scatter)") -def steer_heal(model, tok, orig, cfg: RunConfig): +def steer_heal(model, tok, orig, cfg: RunConfig) -> dict: + metrics = {} for r in range(cfg.n_rounds): logger.info(f"── round {r} ──") v = teacher_vec(model, tok, cfg) comps = generate_and_filter(model, tok, v, orig, cfg) heal(model, orig, comps, cfg) - logger.info(evaluate(model, cfg)) - return model + metrics = {"round": r, **evaluate(model, cfg)} + logger.info(metrics) + return metrics # final round, for results.tsv def main(cfg: RunConfig) -> None: @@ -98,7 +112,9 @@ def main(cfg: RunConfig) -> None: dtype = getattr(torch, cfg.dtype) model, tok = load_model(cfg.model, dtype) orig = model # round-0 anchor; KL reference = same module with adapter gates off - steer_heal(model, tok, orig, cfg) + metrics = steer_heal(model, tok, orig, cfg) + append_result(cfg, metrics) + logger.info(f"done; appended to {RESULTS_TSV.name}") if __name__ == "__main__":