mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 18:04:59 +08:00
f1dd9fb33e
Each lora2r ckpt is ~1.3G (A/B + redundant frozen A0/B0, 252 modules fp32). The 768G disk filled and runs crashed at the step-0 ckpt save. 20-step cadence halves the per-run footprint while keeping enough points for the eval curve. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
135 lines
5.8 KiB
Python
135 lines
5.8 KiB
Python
"""Typed CLI configuration for train.py.
|
|
|
|
One adapter (lora2r: rank-2r Gaussian-init LoRA, A+B trainable, SGTM-style
|
|
three-way hard block masking; see src/vgrout/lora2r.py) and three arms:
|
|
|
|
none gate pinned clean (0,0): quarantine never trains -- the capacity- and
|
|
structure-matched vanilla control.
|
|
routeV per-rollout three-way gate from the c-probe gradient vs v_grad.
|
|
absorb gate pinned mid (1,0): both blocks train on everything, no gate --
|
|
isolates the value of the gate + hard masks vs absorption alone.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Literal
|
|
|
|
from .rewards import EnvMode
|
|
|
|
|
|
@dataclass(kw_only=True)
|
|
class Config:
|
|
intervention: Literal["none", "routeV", "absorb"] = "routeV"
|
|
lora_r: int = 32
|
|
lora_init_seed: int = 0
|
|
|
|
model: str = "Qwen/Qwen3-4B"
|
|
steps: int = 100
|
|
group: int = 6
|
|
max_new: int = 1024
|
|
n_problems: int = 992
|
|
prompts_per_step: int = 8
|
|
lr: float = 1e-4
|
|
adam_beta1: float = 0.9
|
|
adam_beta2: float = 0.99
|
|
clip: float = 0.2
|
|
# AdamW decay pulls raw A/B toward 0, not toward the init, which would drive
|
|
# the net delta to -B0@A0 -- must stay 0 for this adapter.
|
|
weight_decay: float = 0.0
|
|
warmup_frac: float = 0.1
|
|
grad_clip: float = 10.0
|
|
seed: int = 41
|
|
unbiased: bool = True
|
|
|
|
vhack_refresh_every: int = 5
|
|
vhack_pairs_path: Path = Path("data/pairs/hack_pairs.md#all-in-one")
|
|
# Routing directions per module. k=1 (headline): the mean(hack-clean) "mean-mass"
|
|
# axis. k>1: top-k oriented SVD dirs of the paired diff; gate scores max_i cos(g,v_i)
|
|
# (alignment to ANY known hack sub-mode) -- catches multi-modal hack signal one mean
|
|
# washes out. k=1 stays mean-diff (not SVD-top-1) so "mean-mass vs top-k" is a clean A/B.
|
|
v_grad_k: int = 1
|
|
# Online-stats gate: route by live QUANTILES of the pooled cosine-to-v_grad, not the
|
|
# authored absolute band (whose p75-hack edge live hacks never reach -> rout~0). Each
|
|
# step the top route_quantile tail -> hack (deployed detached), bottom -> keep, middle
|
|
# -> mid. route_window = sliding buffer of recent pooled positions, flushed on refresh
|
|
# (positions are measured against one v_grad). Direction stays authored-only; only the
|
|
# threshold follows the live distribution. TODO(#30): center+width calibration vs fixed %.
|
|
route_quantile: float = 0.05
|
|
route_window: int = 512
|
|
# Haar-random direction control (placebo): same routing machinery, no pair signal.
|
|
routeV_random_v_seed: int | None = None
|
|
rollout_ablate_frac: float = 0.0
|
|
|
|
env_mode: EnvMode = "run_tests"
|
|
# Rotating per-step chance a problem is shown hint-free (graded honest-only);
|
|
# keeps solve pressure alive. 10% learned solve too slowly; 25% -> 50% on 2026-06-10
|
|
# (equal hack/solve pressure, harder problems, faster env -- all upside).
|
|
unhackable_frac: float = 0.5
|
|
teacher_pool_dir: Path | None = None
|
|
mix_ratio: float = 0.125
|
|
teacher_off_step: int | None = 30
|
|
teacher_modes: tuple[str, ...] | None = None
|
|
# Symmetric solve-teacher pool (honest GT-passing demos). When set, the G_t
|
|
# teacher slots split solve_mix_frac solve / (1-frac) hack, so the gate sees
|
|
# honest examples it must NOT route (the routed-share discrimination diagnostic)
|
|
# and solve pressure matches hack pressure. Needs teacher_pool_dir + mix_ratio>0.
|
|
solve_pool_dir: Path | None = None
|
|
solve_mix_frac: float = 0.5
|
|
# Deterministic teacher forcing: in the teacher phase (step < teacher_off_step) every
|
|
# generated prompt is drawn from the both-pool-covered set and gets EXACTLY
|
|
# teacher_n_per_prompt hack + teacher_n_per_prompt solve teachers; the rest of `group`
|
|
# are students. Constant count per prompt, no flip/coverage drops. Pool has ~1
|
|
# rollout/prompt, so N=1 avoids sampling the same cached row twice. Replaces the
|
|
# mix_ratio * _even_split step budget (whose count varied with flips/coverage).
|
|
teacher_n_per_prompt: int = 1
|
|
|
|
eval_ablate_every: int = 0
|
|
eval_n_prompts: int = 32
|
|
# HF generate + 252 per-module lora2r hooks dispatch Python per decode token, so eval
|
|
# is GPU-starved (~19% util at bs=2). Bigger batch amortizes that fixed per-call hook
|
|
# cost across more sequences (32 prompts -> 4 batches not 16) -> ~3x faster inline eval.
|
|
eval_batch_size: int = 8
|
|
save_ckpt_every: int = 10
|
|
out_tag: str = ""
|
|
|
|
@property
|
|
def preset_name(self) -> str:
|
|
return type(self).__name__.removesuffix("Config").lower() or "base"
|
|
|
|
@property
|
|
def arm(self) -> str:
|
|
# _lora2r suffix kept so these runs never conflate with the retired
|
|
# PiSSA-substrate runs of the same intervention (rename-on-logic-change).
|
|
return {"none": "vanilla_lora2r", "routeV": "routingV_lora2r",
|
|
"absorb": "absorb_lora2r"}[self.intervention]
|
|
|
|
|
|
@dataclass(kw_only=True)
class SmokeConfig(Config):
    """Quick smoke-run preset: tiny random-weight Qwen3, short rollouts, 30 steps."""

    model: str = "llamafactory/tiny-random-qwen3"
    # The tiny model's smallest Linear dim is 16, so r=4 (block width 2r=8) fits.
    lora_r: int = 4
    steps: int = 30
    group: int = 4
    max_new: int = 32
    n_problems: int = 100
    prompts_per_step: int = 1
|
|
|
|
|
|
@dataclass(kw_only=True)
class FastConfig(Config):
    """Short-budget Qwen3-4B preset.

    Compared with the base preset it halves rollout length (512 new tokens),
    uses 200 problems with 4 prompts x 8 rollouts per step, draws teachers from
    the dense run_tests pool, and trains with a higher lr and lower Adam betas.
    """

    model: str = "Qwen/Qwen3-4B"
    steps: int = 100
    teacher_pool_dir: Path | None = Path("out/pools/teacher_pool_runtests_dense")
    group: int = 8
    max_new: int = 512
    n_problems: int = 200
    prompts_per_step: int = 4
    adam_beta1: float = 0.5
    adam_beta2: float = 0.9
    # Raised from 1e-4 at the user's request, to learn faster within the short,
    # gradient-starved budget.
    lr: float = 5e-4
    # A lora2r checkpoint is ~1.3G (A/B plus the redundant frozen A0/B0 for 252
    # modules, fp32). Saving every 20 steps instead of the base 10 halves the
    # per-run footprint so the 768G disk doesn't fill, while keeping ~6 points
    # per run for the eval curve.
    # TODO: drop A0/B0 from ckpts (reconstructible from lora_init_seed) to halve
    # the size again; needs a loader change.
    save_ckpt_every: int = 20
|