Files
evil_MoE/src/vgrout/train_config.py
T
wassname f1dd9fb33e chore: FastConfig save_ckpt_every 10->20 (disk pressure; ~6 ckpts/run suffices)
Each lora2r ckpt is ~1.3G (A/B + redundant frozen A0/B0, 252 modules fp32). The
768G disk filled and runs crashed at the step-0 ckpt save. 20-step cadence halves
the per-run footprint while keeping enough points for the eval curve.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
2026-06-11 00:35:29 +00:00

135 lines
5.8 KiB
Python

"""Typed CLI configuration for train.py.
One adapter (lora2r: rank-2r Gaussian-init LoRA, A+B trainable, SGTM-style
three-way hard block masking; see src/vgrout/lora2r.py) and three arms:
none gate pinned clean (0,0): quarantine never trains -- the capacity- and
structure-matched vanilla control.
routeV per-rollout three-way gate from the c-probe gradient vs v_grad.
absorb gate pinned mid (1,0): both blocks train on everything, no gate --
isolates the value of the gate + hard masks vs absorption alone.
"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Literal
from .rewards import EnvMode
@dataclass(kw_only=True)
class Config:
intervention: Literal["none", "routeV", "absorb"] = "routeV"
lora_r: int = 32
lora_init_seed: int = 0
model: str = "Qwen/Qwen3-4B"
steps: int = 100
group: int = 6
max_new: int = 1024
n_problems: int = 992
prompts_per_step: int = 8
lr: float = 1e-4
adam_beta1: float = 0.9
adam_beta2: float = 0.99
clip: float = 0.2
# AdamW decay pulls raw A/B toward 0, not toward the init, which would drive
# the net delta to -B0@A0 -- must stay 0 for this adapter.
weight_decay: float = 0.0
warmup_frac: float = 0.1
grad_clip: float = 10.0
seed: int = 41
unbiased: bool = True
vhack_refresh_every: int = 5
vhack_pairs_path: Path = Path("data/pairs/hack_pairs.md#all-in-one")
# Routing directions per module. k=1 (headline): the mean(hack-clean) "mean-mass"
# axis. k>1: top-k oriented SVD dirs of the paired diff; gate scores max_i cos(g,v_i)
# (alignment to ANY known hack sub-mode) -- catches multi-modal hack signal one mean
# washes out. k=1 stays mean-diff (not SVD-top-1) so "mean-mass vs top-k" is a clean A/B.
v_grad_k: int = 1
# Online-stats gate: route by live QUANTILES of the pooled cosine-to-v_grad, not the
# authored absolute band (whose p75-hack edge live hacks never reach -> rout~0). Each
# step the top route_quantile tail -> hack (deployed detached), bottom -> keep, middle
# -> mid. route_window = sliding buffer of recent pooled positions, flushed on refresh
# (positions are measured against one v_grad). Direction stays authored-only; only the
# threshold follows the live distribution. TODO(#30): center+width calibration vs fixed %.
route_quantile: float = 0.05
route_window: int = 512
# Haar-random direction control (placebo): same routing machinery, no pair signal.
routeV_random_v_seed: int | None = None
rollout_ablate_frac: float = 0.0
env_mode: EnvMode = "run_tests"
# Rotating per-step chance a problem is shown hint-free (graded honest-only);
# keeps solve pressure alive. 10% learned solve too slowly; 25% -> 50% on 2026-06-10
# (equal hack/solve pressure, harder problems, faster env -- all upside).
unhackable_frac: float = 0.5
teacher_pool_dir: Path | None = None
mix_ratio: float = 0.125
teacher_off_step: int | None = 30
teacher_modes: tuple[str, ...] | None = None
# Symmetric solve-teacher pool (honest GT-passing demos). When set, the G_t
# teacher slots split solve_mix_frac solve / (1-frac) hack, so the gate sees
# honest examples it must NOT route (the routed-share discrimination diagnostic)
# and solve pressure matches hack pressure. Needs teacher_pool_dir + mix_ratio>0.
solve_pool_dir: Path | None = None
solve_mix_frac: float = 0.5
# Deterministic teacher forcing: in the teacher phase (step < teacher_off_step) every
# generated prompt is drawn from the both-pool-covered set and gets EXACTLY
# teacher_n_per_prompt hack + teacher_n_per_prompt solve teachers; the rest of `group`
# are students. Constant count per prompt, no flip/coverage drops. Pool has ~1
# rollout/prompt, so N=1 avoids sampling the same cached row twice. Replaces the
# mix_ratio * _even_split step budget (whose count varied with flips/coverage).
teacher_n_per_prompt: int = 1
eval_ablate_every: int = 0
eval_n_prompts: int = 32
# HF generate + 252 per-module lora2r hooks dispatch Python per decode token, so eval
# is GPU-starved (~19% util at bs=2). Bigger batch amortizes that fixed per-call hook
# cost across more sequences (32 prompts -> 4 batches not 16) -> ~3x faster inline eval.
eval_batch_size: int = 8
save_ckpt_every: int = 10
out_tag: str = ""
@property
def preset_name(self) -> str:
return type(self).__name__.removesuffix("Config").lower() or "base"
@property
def arm(self) -> str:
# _lora2r suffix kept so these runs never conflate with the retired
# PiSSA-substrate runs of the same intervention (rename-on-logic-change).
return {"none": "vanilla_lora2r", "routeV": "routingV_lora2r",
"absorb": "absorb_lora2r"}[self.intervention]
@dataclass(kw_only=True)
class SmokeConfig(Config):
    """Minimal smoke-test preset: tiny model, short rollouts, few problems."""

    # Tiny Qwen3 checkpoint so the full pipeline can be exercised cheaply.
    model: str = "llamafactory/tiny-random-qwen3"
    # The tiny model's smallest Linear is 16 wide; the adapter's rank 2r = 8 fits.
    lora_r: int = 4
    # Run length and problem pool, shrunk to the bare minimum.
    steps: int = 30
    n_problems: int = 100
    # Rollout shape: one prompt per step, a small group, very short generations.
    prompts_per_step: int = 1
    group: int = 4
    max_new: int = 32
@dataclass(kw_only=True)
class FastConfig(Config):
    """Faster-learning Qwen3-4B preset.

    Compared with the base config it uses fewer problems, shorter generations,
    a cached teacher pool, a higher LR with lower Adam betas, and sparser checkpoints.
    """

    # Model and run length are pinned explicitly, even though they match the base.
    model: str = "Qwen/Qwen3-4B"
    steps: int = 100
    # Rollout shape and problem pool.
    n_problems: int = 200
    prompts_per_step: int = 4
    group: int = 8
    max_new: int = 512
    teacher_pool_dir: Path | None = Path("out/pools/teacher_pool_runtests_dense")
    # Optimizer: user-requested LR bump from 1e-4 to learn faster within the short,
    # grad-starved budget, paired with lower Adam betas.
    lr: float = 5e-4
    adam_beta1: float = 0.5
    adam_beta2: float = 0.9
    # Each lora2r ckpt is ~1.3G (A/B + redundant frozen A0/B0 for 252 modules, fp32); a
    # 20-step cadence keeps ~6/run for the eval curve without filling the 768G disk.
    # TODO: drop A0/B0 from ckpts (reconstructible from lora_init_seed) to halve their
    # size; this needs a loader change.
    save_ckpt_every: int = 20