grpo_proj2/scripts/results.py
wassname 409d9c9425 refactor: SVD-diag knob -> parametrized LoRA (fixed A, train B)
Replace the AntiPaSTO SVD-diag adapter (δW = U·diag(δS)·Vh) with a parametrized
LoRA: per Linear a frozen random A (r×d_in, semi-orthonormal) and a trained
B (d_out×r), δW = (B+B_hack)·A. δW stays LINEAR in the trained knob -- the one
property the projection needs (a once-extracted V stays a fixed weight-space
direction) -- but the whole per-module SVD subsystem is gone (no svd_cached, no
SVD_CACHE, no bf16 hash/contiguity friction). B_hack is the SAME shape as B, so the
route quarantine is capacity-matched by construction (old-repo route2 diverged from
an oversized quarantine sink).

- antipasto: wrap() builds A (fp32 orthogonal_, geqrf has no bf16 -> cast after) +
  B/B_hack zero-init on the layer device; forward y + (B+B_hack)@(A@x).
- proj: project_one is dim-agnostic; project_all flattens B.grad (d_out·r) and
  reshapes back. cos_overlap flattens too.
- extract: V from the SVD of stacked B.grad pair-diffs (d_out·r).
- train: B/B_hack rename, lora_rank config, per-step aligned table + legend
  (replaces the sparse tqdm postfix), clean argv via preset/Config defaults.
- justfile: collapse smoke-vanilla/smoke-route/fast-vanilla/full-vanilla into
  smoke/fast/full *ARGS + a `sweep` recipe that fires vanilla|erase|route as pueue
  jobs. results.py: glob run_*.log (skip loguru verbose logs).
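
A minimal NumPy sketch of the parametrization described above, not the repo's actual wrap() code. The shapes are hypothetical, and A is built with a QR factorization as a stand-in for the torch orthogonal_ init the commit mentions. It illustrates the forward y + (B+B_hack)@(A@x), the zero-init identity, and that δW is linear in B when A is fixed and B_hack is zero.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, r = 8, 6, 2  # hypothetical sizes, for illustration only

# Frozen A (r x d_in) with orthonormal rows: QR of a random (d_in x r) matrix, transposed.
q, _ = np.linalg.qr(rng.standard_normal((d_in, r)))
A = q.T

# Trained B and quarantine B_hack share one shape (d_out x r) and start at zero.
B = np.zeros((d_out, r))
B_hack = np.zeros((d_out, r))


def delta_w(B, B_hack, A):
    """Weight update dW = (B + B_hack) @ A, shape (d_out x d_in)."""
    return (B + B_hack) @ A


def forward(W, x, B, B_hack, A):
    """Base output plus the low-rank path, computed as (B + B_hack) @ (A @ x)."""
    return W @ x + (B + B_hack) @ (A @ x)


W = rng.standard_normal((d_out, d_in))
x = rng.standard_normal(d_in)

# Zero-init B and B_hack leave the wrapped layer identical to the base layer.
assert np.allclose(forward(W, x, B, B_hack, A), W @ x)

# Linearity in B for fixed A (B_hack is zero here):
# dW(a*B1 + b*B2) == a*dW(B1) + b*dW(B2).
B1, B2 = rng.standard_normal((2, d_out, r))
lhs = delta_w(2.0 * B1 + 3.0 * B2, B_hack, A)
rhs = 2.0 * delta_w(B1, B_hack, A) + 3.0 * delta_w(B2, B_hack, A)
assert np.allclose(lhs, rhs)
```

Because B and B_hack have the same shape, the quarantine path has exactly the same capacity as the trained path in this sketch.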

Smoke (GPU bf16, all three arms) green: cout~0 one_sided identity holds in the LoRA
basis, |Bh|=0 for erase/none, route parks into B_hack.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
2026-06-01 03:43:18 +00:00


"""Aggregate logs/run_*.log (JSONL rows from train.py) into a last-5-step table.
hack = student exploited rate (hack_s); solve = student gt_correct rate (gt_s).
Each row is one run; the arm/seed come from the filename run_{arm}_s{seed}.log.
"""
from __future__ import annotations

import glob
import json
import re
import sys

from tabulate import tabulate


def last5(rows, key):
    vals = [r[key] for r in rows[-5:] if r.get(key) is not None]
    return sum(vals) / len(vals) if vals else float("nan")


def main(pattern: str = "logs/run_*.log") -> None:
    table = []
    for path in sorted(glob.glob(pattern)):
        rows = [json.loads(l) for l in open(path) if l.strip()]
        if not rows:
            continue
        m = re.search(r"run_(.+?)_s(\d+)\.log", path)
        arm, seed = (m.group(1), m.group(2)) if m else (path, "")
        table.append({
            "arm": arm, "steps": len(rows),
            "hack": round(last5(rows, "hack_s"), 3), "solve": round(last5(rows, "gt_s"), 3),
            "cin_t": round(last5(rows, "cin_t"), 3), "cout": round(last5(rows, "cout"), 3),
        })
    print(tabulate(table, headers="keys", tablefmt="pipe"))


if __name__ == "__main__":
    main(sys.argv[1] if len(sys.argv) > 1 else "logs/run_*.log")
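
A small stdlib-only sketch of the input this script appears to expect, with made-up metric values. Each log line is one JSON object carrying the keys read above, and the arm and seed are parsed from the filename. The `last5` helper is repeated so the snippet runs on its own. The filename `logs/run_route_s0.log` and all numbers are hypothetical.

```python
import json
import re


def last5(rows, key):
    vals = [r[key] for r in rows[-5:] if r.get(key) is not None]
    return sum(vals) / len(vals) if vals else float("nan")


# Hypothetical JSONL content; real rows are written by train.py.
lines = [
    json.dumps({"hack_s": h, "gt_s": 0.5, "cin_t": None, "cout": 0.0})
    for h in (0.9, 0.8, 0.6, 0.4, 0.2, 0.0)
]
rows = [json.loads(l) for l in lines if l.strip()]

# Arm and seed come from the run_{arm}_s{seed}.log naming scheme.
m = re.search(r"run_(.+?)_s(\d+)\.log", "logs/run_route_s0.log")
arm, seed = m.group(1), m.group(2)

hack = round(last5(rows, "hack_s"), 3)  # mean over the final five rows only
cin_t = last5(rows, "cin_t")            # every value is None, so this is nan
print(arm, seed, hack, cin_t)
```

Rows where a key is missing or null are skipped rather than counted as zero, so a column can average over fewer than five values or come out as nan.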