mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 16:45:42 +08:00
fix: lr 3e-4 + 20% warmup (5e-4 diverged at warmup-end); slim bf16 ckpts
- FastConfig lr 5e-4 -> 3e-4: 5e-4 peaked exactly at warmup-end (step ~10) and diverged (lp_t -0.5 -> -4.8, hack_s 20/24 -> 0). Lower peak + longer warmup defuse the spike. - Config warmup_frac 0.1 -> 0.2: SequentialLR(LinearLR, CosineAnnealingLR) already does warmup+cosine relaxation; just reach the peak more gradually. - save_ckpt: drop A0/B0 (seeded init, regenerable from lora_init_seed; ckpt_update0000 is the init since A==A0 at step 0; nothing live reloads them), save A/B bf16 not fp32. ~1.3G -> ~0.33G per ckpt. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
+7
-8
@@ -549,21 +549,20 @@ def main(cfg: Config) -> int:
|
|||||||
route_solveT_run: list[float] = [] # per-step routed-share of solve teachers
|
route_solveT_run: list[float] = [] # per-step routed-share of solve teachers
|
||||||
|
|
||||||
def save_ckpt(rows: list[dict], path: Path | None = None) -> None:
|
def save_ckpt(rows: list[dict], path: Path | None = None) -> None:
|
||||||
"""Save a self-contained lora2r checkpoint: full A/B + the frozen init A0/B0,
|
"""Save a lora2r checkpoint: trainable A/B only, bf16. The init A0/B0 are NOT
|
||||||
so a loader reconstructs the net delta (B@A - B0@A0) and can ablate the
|
saved -- they're the seeded Gaussian (regenerable from lora_init_seed) and
|
||||||
quarantine without any SVD cache. Config + per-step rows in the metadata."""
|
ckpt_update0000 holds them anyway (A==A0 at step 0); nothing live reloads
|
||||||
|
them. Config + per-step rows in the metadata."""
|
||||||
n_gens = sum(r["N"] for r in rows)
|
n_gens = sum(r["N"] for r in rows)
|
||||||
hr = sum(r["hack_s"][0] + r["hack_t"][0] for r in rows) / max(1, n_gens)
|
hr = sum(r["hack_s"][0] + r["hack_t"][0] for r in rows) / max(1, n_gens)
|
||||||
pr = sum(r["gt_s"][0] + r["gt_t"][0] for r in rows) / max(1, n_gens)
|
pr = sum(r["gt_s"][0] + r["gt_t"][0] for r in rows) / max(1, n_gens)
|
||||||
_ckpt = path or ckpt_path
|
_ckpt = path or ckpt_path
|
||||||
tensors = {}
|
tensors = {}
|
||||||
for n, info in wrappers.items():
|
for n, info in wrappers.items():
|
||||||
tensors[f"A/{n}"] = info["A"].detach().float().cpu().contiguous()
|
tensors[f"A/{n}"] = info["A"].detach().bfloat16().cpu().contiguous()
|
||||||
tensors[f"B/{n}"] = info["B"].detach().float().cpu().contiguous()
|
tensors[f"B/{n}"] = info["B"].detach().bfloat16().cpu().contiguous()
|
||||||
tensors[f"A0/{n}"] = info["A0"].detach().float().cpu().contiguous()
|
|
||||||
tensors[f"B0/{n}"] = info["B0"].detach().float().cpu().contiguous()
|
|
||||||
save_file(tensors, str(_ckpt), metadata={
|
save_file(tensors, str(_ckpt), metadata={
|
||||||
"model": model_name, "dtype": "fp32", "step": str(len(rows)),
|
"model": model_name, "dtype": "bf16", "step": str(len(rows)),
|
||||||
"hack_rate": f"{hr:.6f}", "pass_rate": f"{pr:.6f}",
|
"hack_rate": f"{hr:.6f}", "pass_rate": f"{pr:.6f}",
|
||||||
"rows": json.dumps(rows),
|
"rows": json.dumps(rows),
|
||||||
"cfg": json.dumps(vars(cfg) | {"adapter": "lora2r"}, default=str),
|
"cfg": json.dumps(vars(cfg) | {"adapter": "lora2r"}, default=str),
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ class Config:
|
|||||||
# AdamW decay pulls raw A/B toward 0, not toward the init, which would drive
|
# AdamW decay pulls raw A/B toward 0, not toward the init, which would drive
|
||||||
# the net delta to -B0@A0 -- must stay 0 for this adapter.
|
# the net delta to -B0@A0 -- must stay 0 for this adapter.
|
||||||
weight_decay: float = 0.0
|
weight_decay: float = 0.0
|
||||||
warmup_frac: float = 0.1
|
warmup_frac: float = 0.2
|
||||||
grad_clip: float = 10.0
|
grad_clip: float = 10.0
|
||||||
seed: int = 41
|
seed: int = 41
|
||||||
unbiased: bool = True
|
unbiased: bool = True
|
||||||
@@ -127,8 +127,9 @@ class FastConfig(Config):
|
|||||||
prompts_per_step: int = 4
|
prompts_per_step: int = 4
|
||||||
adam_beta1: float = 0.5
|
adam_beta1: float = 0.5
|
||||||
adam_beta2: float = 0.9
|
adam_beta2: float = 0.9
|
||||||
lr: float = 5e-4 # user: bump from 1e-4 to learn faster in the short grad-starved budget
|
lr: float = 3e-4 # 5e-4 peaked at warmup-end (step ~10) and diverged; 3e-4 + 20% warmup
|
||||||
# Each lora2r ckpt is ~1.3G (A/B + redundant frozen A0/B0 for 252 modules, fp32); 20-step
|
# Each lora2r ckpt is ~0.33G (A/B for 252 modules, bf16). A0/B0 are NOT saved -- they're
|
||||||
# cadence keeps ~6/run for the eval curve without filling the 768G disk. (TODO: drop A0/B0
|
# the seeded init (regenerable from lora_init_seed in the metadata; ckpt_update0000 is the
|
||||||
# from ckpts -- reconstructible from lora_init_seed -- to halve size, needs a loader change.)
|
# init since A==A0 at step 0), and nothing live reloads them (inline eval uses in-memory
|
||||||
|
# wrappers). 20-step cadence keeps ~6/run for the eval curve.
|
||||||
save_ckpt_every: int = 20
|
save_ckpt_every: int = 20
|
||||||
|
|||||||
Reference in New Issue
Block a user