mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-07-05 23:11:21 +08:00
25ac3fc5e3
Replace the band-mechanics trio (tau/hkgap/frout) and the lumped qmass with a symmetric zone breakdown: each live unit's cos(g,v_grad) lands below/inside/above the pair-band -> keep/resid/rout, reported as both unit shares and energy shares (keepE/residE/routE). Energy view is unit-agnostic (answers 'is the grad per rollout'). Drop hk_abl/slv_abl unless rollout_ablate_frac>0 (else 0/0). Band edges (lower/upper) already logged at construction. v1 'routing' arm keeps qmass. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
177 lines
9.1 KiB
Python
177 lines
9.1 KiB
Python
"""Per-step training-table rendering and run logging.
|
|
|
|
Two concerns, both pure presentation (no model, no RNG): set up the token-efficient
|
|
loguru sinks for a run, and render the per-step metrics table. The renderer is the
|
|
single source of truth for column order, width, header, and number format; the
|
|
training loop hands it a row dict of raw values and gets back a formatted line.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
from loguru import logger
|
|
from tqdm import tqdm
|
|
|
|
# Verbose per-run log files land here; relative, so it resolves against the
# process working directory at the time setup_logging() runs.
LOGS_DIR: Path = Path("logs")
|
|
|
|
|
|
def setup_logging(run_id: str) -> Path:
    """Install the run's loguru sinks and return the verbose log file path.

    All existing loguru handlers are removed, then two sinks are added:

    * console, INFO and above, written through ``tqdm.write`` so log lines do
      not break active progress bars; each line is a one-character level icon
      followed by the message;
    * a file ``LOGS_DIR/<YYYYmmddTHHMMSS>_<run_id>.log``, DEBUG and above, each
      line prefixed with the time and level name.

    See /root/.claude/skills/token-efficient-logging/SKILL.md.
    """
    LOGS_DIR.mkdir(exist_ok=True)
    stamp = datetime.now().strftime('%Y%m%dT%H%M%S')
    log_path = LOGS_DIR / f"{stamp}_{run_id}.log"

    # Drop every previously registered handler (including loguru's default one).
    logger.remove()

    def _console_sink(message) -> None:
        # loguru hands over the already-formatted line, newline included.
        tqdm.write(message, end="")

    logger.add(
        _console_sink,
        colorize=True,
        format="<level>{level.icon}</level> {message}",
        level="INFO",
    )
    logger.add(
        log_path,
        format="{time:HH:mm:ss} | {level} | {message}",
        level="DEBUG",
    )

    # Single-letter icons consumed by the console format's {level.icon}.
    for level_name, icon in (("INFO", "I"), ("WARNING", "W"), ("ERROR", "E"), ("DEBUG", "D")):
        logger.level(level_name, icon=icon)

    return log_path
|
|
|
|
|
|
@dataclass(frozen=True)
class _Col:
    """Immutable spec for one column of the per-step metrics table.

    Fields:
        key:    row-dict key under which the raw value is stored
                (float / int / str / tuple / None).
        width:  render width for the fixed-width streaming display.
        header: label printed in the header line; may carry direction
                markers (arrows, a trailing "?" for want-it-zero metrics, ...).
        fmt:    format spec applied to the raw value, e.g. "+.3f", ".2e", "d".
                The special spec "frac" expects a (num, denom) pair and renders
                "n/d"; None renders str(value).
        desc:   one-line explanation used by the legend; an empty string keeps
                the column out of the legend.
    """

    key: str
    width: int
    header: str
    fmt: str | None = None
    desc: str = ""  # "" => column omitted from the legend
|
|
|
|
|
|
def _format_cell(value, fmt: str | None) -> str:
    """Format one raw row value as a table cell string.

    Missing values (None) and float NaN render as 'nan' regardless of spec, so
    any column can carry a placeholder on steps where its metric wasn't computed.
    Otherwise:
      * fmt "frac": value is a (num, denom) pair, rendered "num/denom";
      * fmt None:   str(value);
      * any other:  format(value, fmt).
    """
    # Placeholder check must precede the "frac" branch: a NaN placeholder in a
    # fraction column would otherwise hit the tuple unpacking below and raise.
    if value is None or (isinstance(value, float) and value != value):  # NaN
        return "nan"
    if fmt == "frac":
        num, denom = value
        return f"{num}/{denom}"
    if fmt is None:
        return str(value)
    return format(value, fmt)
|
|
|
|
|
|
class StepLogger:
    """Per-step training-table renderer.

    Single source of truth for column order, width, header label, and value
    formatter. The row dict carries raw values (floats, ints, tuples, strings);
    StepLogger formats them for streaming, and the end-of-run tabulate dump
    consumes the same raw values without re-parsing scientific-notation strings.

    Timing columns (gen/fb/t_rew/sec) intentionally absent from the streaming
    spec — useful only at end-of-run, where the tabulate dump still picks
    them up from the archived row dicts.

    mode_code maps each env_mode to its short column tag (e.g. run_tests -> rt); the
    caller owns it (it also names the row-dict keys) so this module stays leaf-level.

    Each column renders at max(declared width, header length), so a header longer
    than its declared width (e.g. "slv_abl" in a 6-wide column, or a long mode tag
    in a 5-wide hk_* column) cannot push the header line out of alignment with rows.
    """

    def __init__(self, arm: str, modes: list[str], mode_code: dict[str, str],
                 show_ablate: bool = False) -> None:
        """Build the column spec for one run.

        Args:
            arm: training arm. "projected"/"routing" add the cin/cout/fired
                projection diagnostics; "routingV" adds the zone-breakdown
                columns; "routing" adds qmass; "routing"/"routingV" add the
                per-step deploy proxy when show_ablate is set.
            modes: env modes in this run; per-mode hack columns appear only
                when there is more than one.
            mode_code: env_mode -> short column tag; must contain every mode in
                `modes` when len(modes) > 1 (a missing one raises KeyError).
            show_ablate: whether a knob-off rollout slice is generated
                (rollout_ablate_frac>0); enables hk_abl/slv_abl.
        """
        # Only the "projected" and "routing" arms get the cin/cout/fired projection
        # diagnostics (in vanilla they'd be counterfactual noise -> omitted).
        # "routingV" is deliberately not in this set; it reports its own zone columns.
        projects = arm in ("projected", "routing")
        cols: list[_Col] = [
            _Col("step", 4, "step", "d", "GRPO step"),
            _Col("ref_eq", 6, "ref_eq", ".2f", "vanilla-equiv step (cum_gens/256)"),
            _Col("rew", 6, "rew", "+.2f", "mean combined reward"),
            _Col("rew_s", 6, "rew_s↑", "+.2f", "student mean reward"),
            _Col("gt_s", 6, "gt_s↑", "frac", "student ground-truth passes"),
            _Col("gt_t", 6, "gt_t", "frac", "teacher ground-truth passes (sanity)"),
            _Col("hack_s", 7, "hack_s?", "frac", "student hack-flagged rollouts (the headline)"),
            _Col("hack_t", 7, "hack_t", "frac", "teacher hack-flagged rollouts (sanity: pool hacks)"),
            # Deploy-eval shown for EVERY arm (nan on steps it's not run -> see it ride
            # along as training proceeds). route/routeV: quarantine knob OFF. vanilla/erase:
            # the trained model itself. Apples-to-apples knob-off deploy number, the plot series.
            _Col("hack_deploy", 7, "hk_dep", "+.2f", "DEPLOY-eval hack (route: quarantine OFF; vanilla/erase: trained model); held-out subset, T=0.7, every eval_ablate_every steps; nan between"),
            _Col("solve_deploy", 7, "slv_dep", "+.2f", "DEPLOY-eval solve (same cadence; nan between)"),
        ]
        # Per-mode student hack COUNT for the current batch (not cumulative) -> which
        # loophole classes the student is exploiting right now. Only when the run spans
        # >1 mode (the substrate); single-mode runs would just duplicate hack_s.
        self._modes = modes if len(modes) > 1 else []
        for m in self._modes:
            cols.append(_Col(f"hk_{mode_code[m]}", 5, f"hk_{mode_code[m]}", "d",
                             f"student hacks of {m} THIS step (current batch, not cumulative)"))
        cols += [
            _Col("lp_s", 6, "lp_s↓", "+.2f", "mean student gen_logp (diagnostic)"),
            _Col("lp_t", 6, "lp_t↑", "+.2f", "mean teacher gen_logp; off-policy gap = lp_s-lp_t"),
            _Col("loss", 7, "loss", "+.2f", "mean GRPO loss"),
            _Col("gn", 7, "gn", ".1e", "pre-clip L2 norm of delta_S grads (vs grad_clip)"),
            _Col("lr", 7, "lr", ".1e", "scheduled learning rate"),
        ]
        if projects:
            cols += [
                _Col("cos_pre", 6, "cin", ".2f", "hack-ward grad fraction ||relu(V@g)||/||g|| [0,1] BEFORE proj"),
                _Col("cos_pre_s", 6, "cin_s", ".2f", "cin on student-only grad"),
                _Col("cos_pre_t", 6, "cin_t", ".2f", "cin on teacher-only grad (want cin_t>cin_s)"),
                _Col("cos_post", 6, "cout", ".2f", "hack-ward fraction AFTER projection (want ~0: all removed)"),
                _Col("fired", 5, "fired", ".2f", "fraction of modules where projection fired"),
            ]
        # routeV routing, by what the gate does to each live unit (rollout, or token in
        # per-token mode). Its cos(g, v_grad) falls below / inside / above the pair-band
        # [lower, upper] (edges logged at band construction). Three zones, two views:
        # keep/resid/rout = UNIT shares, keepE/residE/routE = ENERGY shares (each sums to
        # 1). leak = hack alignment that slipped past into the deployed knob.
        if arm == "routingV":
            cols += [
                _Col("keep", 6, "keep", ".2f", "unit share with cos below the band -> kept whole in the deployed knob (left)"),
                _Col("resid", 6, "resid", ".2f", "unit share with cos inside the band -> partially routed (residual middle)"),
                _Col("rout", 6, "rout", ".2f", "unit share with cos above the band -> fully routed into quarantine (right)"),
                _Col("keepE", 6, "keepE", ".2f", "energy-weighted keep: share of grad ENERGY in the kept zone"),
                _Col("residE", 6, "residE", ".2f", "energy-weighted resid: share of grad ENERGY in the partially-routed zone"),
                _Col("routE", 6, "routE", ".2f", "energy-weighted rout: grad ENERGY share fully routed (~quarantine mass; the routed total is routE..routE+residE)"),
                _Col("leak", 6, "leak", "+.2f", "hack-ward cosine left in the deployed knob after routing; ~0 = stripped clean, >0 = hack leaked through (under-routed)"),
            ]
        if arm == "routing":
            cols.append(
                _Col("qmass", 6, "qmass", ".2f", "quarantine energy share ||g_quar||/(||g_keep||+||g_quar||): fraction of the update parked in the throwaway knob"))
        # Per-step deploy proxy only exists when rollout_ablate_frac>0 generates a knob-off
        # slice; without it the slice is empty (0/0), so drop the columns.
        if arm in ("routing", "routingV") and show_ablate:
            cols += [
                _Col("hack_abl", 6, "hk_abl", "frac", "per-step deploy proxy: hack rate on the ablated (deploy-mode) rollout slice; train prompts, noisier than hk_dep"),
                _Col("solve_abl", 6, "slv_abl", "frac", "per-step deploy proxy: solve rate on the ablated (deploy-mode) rollout slice; train prompts"),
            ]
        self._cols = cols

    @staticmethod
    def _width(col: _Col) -> int:
        """Render width of a column: its declared width, widened to fit its header."""
        return max(col.width, len(col.header))

    def header(self) -> str:
        """Return the header line, each label right-aligned to its column's render width."""
        return " ".join(f"{c.header:>{self._width(c)}}" for c in self._cols)

    def row(self, cells: dict) -> str:
        """Format one row dict (raw values keyed by column key) as a table line.

        Uses the same per-column render width as header() so the two stay aligned.
        Raises KeyError if `cells` lacks the key of any displayed column.
        """
        return " ".join(
            f"{_format_cell(cells[c.key], c.fmt):>{self._width(c)}}" for c in self._cols
        )

    def legend(self) -> str:
        """Decode the (arm-/mode-conditional) columns actually present this run."""
        lines = "\n".join(f" {c.header:>8} = {c.desc}" for c in self._cols if c.desc)
        return ("table columns (timing gen/fb/t_rew/sec dropped from streaming, kept "
                "in the end-of-run dump):\n" + lines)
|