mirror of
https://github.com/wassname/steer-heal-love.git
synced 2026-06-27 17:02:34 +08:00
setup-repo gap-fill: results ledger + docs structure
Add the by-question results infra per setup-repo conventions:
- results.tsv append at end of each finished run (config + final metrics + argv)
- scripts/results.py groups by arm (reg) into a markdown table; `just results`
- docs/results.md curated by-question snapshot (U2 regulariser comparison)
- docs/{spec,brainstorming,literature,evidence} structure
Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,20 @@
|
||||
# Results, organized by the question each run answers
|
||||
|
||||
Regenerate the tables with `just results` (groups `results.tsv` by arm). This file curates the answers; the append-only narrative lives in `RESEARCH_JOURNAL.md`.
|
||||
|
||||
## How to read this
|
||||
|
||||
- `auth` is the tinymfv authority-axis mean for the steered student (higher = more of the trait); `coherence` is `p_ans_any` (fraction of eval items where the model commits to a valid answer). Both are absolute fractions, compare rows within a table by eye.
|
||||
- A regulariser is compared to the `nll` control only at matched `auth` (the U2 crux: more coherence at equal trait shift).
|
||||
- `auth_sd` is the across-seed spread; a blank means a single seed.
|
||||
- Provenance for each table goes in an HTML comment so any row can be re-created.
|
||||
|
||||
## Q (U2). Which regulariser heals incoherency best at matched trait shift?
|
||||
|
||||
<!-- runs: results.tsv rows; commit: TBD; model: google/gemma-3-1b-it -->
|
||||
|
||||
Prior: `kl_rev > kl_fwd ~ wd > nll` (reverse KL is mode-seeking, suppresses the low-original-probability tokens that read as incoherent).
|
||||
|
||||
No runs yet. Table appears here once `sweep-reg` has produced rows.
|
||||
|
||||
**Answer:** pending.
|
||||
@@ -43,5 +43,9 @@ sweep-reg:
|
||||
done
|
||||
done
|
||||
|
||||
# Aggregate results.tsv into a by-arm markdown table.
|
||||
results:
|
||||
uv run python scripts/results.py
|
||||
|
||||
# flash-attn: install a prebuilt wheel (see `flash-attn-prebuilt` skill), then
|
||||
# run with STEER_ATTN_IMPL=flash_attention_2.
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
"""`just results`: group results.tsv into comparable arms and print a markdown table.
|
||||
|
||||
Grouping key is `reg` (the regulariser under test, U2); argv last so each row is
|
||||
copy-paste reproducible. Edit GROUP when the knob under test changes.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import polars as pl
|
||||
from tabulate import tabulate
|
||||
|
||||
RESULTS_TSV = Path(__file__).resolve().parents[1] / "results.tsv"
|
||||
GROUP = ["reg"] # all-else-equal grouping; the arm under test
|
||||
|
||||
if not RESULTS_TSV.exists():
|
||||
raise SystemExit(f"no {RESULTS_TSV.name} yet; run something first")
|
||||
|
||||
df = pl.read_csv(RESULTS_TSV, separator="\t")
|
||||
agg = (
|
||||
df.group_by(GROUP)
|
||||
.agg(
|
||||
pl.col("p_ans_any").mean().round(3).alias("coherence"),
|
||||
pl.col("auth").mean().round(3),
|
||||
pl.col("auth").std().round(3).alias("auth_sd"),
|
||||
pl.len().alias("n"),
|
||||
pl.col("seed").cast(pl.Utf8).sort().str.join(",").alias("seeds"),
|
||||
pl.col("argv").first(),
|
||||
)
|
||||
.sort("auth", descending=True)
|
||||
)
|
||||
print(tabulate(agg.to_pandas(), headers="keys", tablefmt="pipe", floatfmt="+.3f"))
|
||||
+20
-4
@@ -10,6 +10,7 @@ file fails fast at the first unimplemented stage rather than stubbing fake
|
||||
behaviour. `--fast-dev-run` runs the whole thing on the tiny-random model.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
@@ -24,6 +25,17 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from steer_heal.config import RunConfig, resolve
|
||||
|
||||
REPO = Path(__file__).resolve().parents[2]
|
||||
RESULTS_TSV = REPO / "results.tsv" # one row per finished run; `just results` aggregates
|
||||
|
||||
|
||||
def append_result(cfg: RunConfig, metrics: dict) -> None:
|
||||
# self-describing, copy-paste reproducible row: config + final metrics + argv.
|
||||
row = {**dataclasses.asdict(cfg), **metrics, "argv": " ".join(sys.argv[1:])}
|
||||
new = not RESULTS_TSV.exists()
|
||||
with open(RESULTS_TSV, "a") as f:
|
||||
if new:
|
||||
f.write("\t".join(row) + "\n")
|
||||
f.write("\t".join(str(v) for v in row.values()) + "\n")
|
||||
|
||||
|
||||
def setup_logging() -> None:
|
||||
@@ -80,14 +92,16 @@ def evaluate(model, cfg: RunConfig) -> dict:
|
||||
raise NotImplementedError("TODO: tinymfv eval + plotly map (port csm/plot.py _build_scatter)")
|
||||
|
||||
|
||||
def steer_heal(model, tok, orig, cfg: RunConfig):
|
||||
def steer_heal(model, tok, orig, cfg: RunConfig) -> dict:
|
||||
metrics = {}
|
||||
for r in range(cfg.n_rounds):
|
||||
logger.info(f"── round {r} ──")
|
||||
v = teacher_vec(model, tok, cfg)
|
||||
comps = generate_and_filter(model, tok, v, orig, cfg)
|
||||
heal(model, orig, comps, cfg)
|
||||
logger.info(evaluate(model, cfg))
|
||||
return model
|
||||
metrics = {"round": r, **evaluate(model, cfg)}
|
||||
logger.info(metrics)
|
||||
return metrics # final round, for results.tsv
|
||||
|
||||
|
||||
def main(cfg: RunConfig) -> None:
|
||||
@@ -98,7 +112,9 @@ def main(cfg: RunConfig) -> None:
|
||||
dtype = getattr(torch, cfg.dtype)
|
||||
model, tok = load_model(cfg.model, dtype)
|
||||
orig = model # round-0 anchor; KL reference = same module with adapter gates off
|
||||
steer_heal(model, tok, orig, cfg)
|
||||
metrics = steer_heal(model, tok, orig, cfg)
|
||||
append_result(cfg, metrics)
|
||||
logger.info(f"done; appended to {RESULTS_TSV.name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user