mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 19:47:33 +08:00
55 lines
2.2 KiB
Python
55 lines
2.2 KiB
Python
"""Training-rollout table from completed structured run artifacts."""
|
|
from __future__ import annotations
|
|
|
|
import polars as pl
|
|
from tabulate import tabulate
|
|
|
|
from vgrout.run_artifacts import completed_runs
|
|
|
|
|
|
def main() -> None:
    """Print Markdown tables summarising completed training runs.

    Takes the runs returned by ``completed_runs()``, skips every run whose
    ``cfg["model"]`` contains ``"tiny-random"`` or whose ``cfg["out_tag"]``
    contains ``"probe"``, and prints two pipe-format tables to stdout:

    1. one row per run, ordered by the run's ``time`` value;
    2. one row per configuration (the columns listed in ``key``) holding the
       mean and standard deviation of ``L5_hack`` and ``L5_solve``, the number
       of runs in the group, and the seeds those runs used.

    If no run survives the filter a short notice is printed and the function
    returns without printing any table.
    """
    runs = [
        run
        for run in completed_runs()
        if "tiny-random" not in run["cfg"]["model"]
        and "probe" not in run["cfg"]["out_tag"]
    ]
    rows = [
        {
            "time": run["time"],
            "arm": run["arm"],
            # Config values are stringified so they can serve as grouping keys
            # and be joined as text further down.
            "seed": str(run["cfg"]["seed"]),
            "mix": str(run["cfg"]["mix_ratio"]),
            "refr": str(run["cfg"]["vhack_refresh_every"]),
            "over": str(run["cfg"]["project_overshoot"]),
            "gate": run["cfg"]["gate_mode"],
            "k": str(run["cfg"]["v_hack_k"]),
            "dropf": str(run["cfg"]["v_hack_drop_bottom_frac"]),
            # Keep only the text after the last "#", then the final path
            # component, without a trailing ".json".
            "vhack": (
                run["cfg"]["vhack_pairs_path"]
                .rsplit("#", 1)[-1]
                .split("/")[-1]
                .removesuffix(".json")
            ),
            "L5_hack": run["l5_hack"],
            "L5_solve": run["l5_solve"],
            "WH_hack": run["whole_hack"],
            "n": len(run["rows"]),
            "run": run["run_dir"].name,
        }
        for run in runs
    ]
    if not rows:
        print("no completed non-smoke runs in out/runs/")
        return

    df = pl.DataFrame(rows).sort("time")
    cols = ["arm", "seed", "mix", "refr", "over", "gate", "k", "dropf",
            "vhack", "L5_hack", "L5_solve", "WH_hack", "n", "run"]
    print("\n## All runs (sorted by time)\n")
    print(tabulate(df.select(cols).rows(), headers=cols, tablefmt="pipe", floatfmt=".3f"))

    key = ["arm", "mix", "refr", "over", "gate", "k", "dropf", "vhack"]
    # group_by does not promise any particular group order, so the final sort
    # must cover every key column. Sorting on a subset would leave groups that
    # differ only in "dropf" or "vhack" in an arbitrary, run-to-run varying
    # order. The leading precedence (mix, arm, refr, over, gate, k) is kept.
    sort_order = ["mix", "arm", "refr", "over", "gate", "k", "dropf", "vhack"]
    grouped = (
        df.group_by(key)
        .agg(
            pl.col("L5_hack").mean().alias("hack"),
            pl.col("L5_hack").std().alias("hack_sd"),
            pl.col("L5_solve").mean().alias("solve"),
            pl.col("L5_solve").std().alias("solve_sd"),
            pl.len().alias("n"),
            # "seed" holds strings, so this sort is lexicographic.
            pl.col("seed").sort().str.join(",").alias("seeds"),
        )
        .sort(sort_order)
    )
    gcols = key + ["hack", "hack_sd", "solve", "solve_sd", "n", "seeds"]
    print("\n## Grouped by config (mean +/- std over seeds)\n")
    print(tabulate(grouped.select(gcols).rows(), headers=gcols, tablefmt="pipe", floatfmt=".3f"))
|
|
|
|
|
|
# Script entry point: print the tables only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
|